├── Kohya Trainer XL Runpod.ipynb
├── Kohya_SDXL_Inference.ipynb
├── LICENSE.md
├── README.md
├── XTI_hijack.py
├── bitsandbytes_windows
│   ├── cextension.py
│   ├── libbitsandbytes_cpu.dll
│   ├── libbitsandbytes_cuda116.dll
│   └── main.py
├── fast-kohya-trainer.ipynb
├── fine_tune.py
├── finetune
│   ├── blip
│   │   ├── blip.py
│   │   ├── med.py
│   │   ├── med_config.json
│   │   └── vit.py
│   ├── clean_captions_and_tags.py
│   ├── hypernetwork_nai.py
│   ├── make_captions.py
│   ├── make_captions_by_git.py
│   ├── merge_all_to_metadata.py
│   ├── merge_captions_to_metadata.py
│   ├── merge_dd_tags_to_metadata.py
│   ├── prepare_buckets_latents.py
│   └── tag_images_by_wd14_tagger.py
├── gen_img_diffusers.py
├── kohya-LoRA-dreambooth.ipynb
├── kohya-LoRA-finetuner.ipynb
├── kohya-LoRA-trainer-XL.ipynb
├── kohya-dreambooth.ipynb
├── kohya-trainer-XL.ipynb
├── kohya-trainer.ipynb
├── library
│   ├── __init__.py
│   ├── config_util.py
│   ├── custom_train_functions.py
│   ├── huggingface_util.py
│   ├── lpw_stable_diffusion.py
│   ├── model_util.py
│   ├── train_util.py
│   └── utils.py
├── networks
│   ├── check_lora_weights.py
│   ├── extract_lora_from_models.py
│   ├── lora.py
│   ├── lora_interrogator.py
│   ├── merge_lora.py
│   ├── merge_lora_old.py
│   ├── resize_lora.py
│   └── svd_merge_lora.py
├── requirements.txt
├── setup.py
├── tools
│   ├── canny.py
│   ├── code-snippet.ipynb
│   ├── convert_diffusers20_original_sd.py
│   ├── detect_face_rotate.py
│   ├── merge_block_weighted.py
│   ├── merge_vae.py
│   ├── original_control_net.py
│   └── resize_images_to_resolution.py
├── train_db.py
├── train_network.py
├── train_textual_inversion.py
└── train_textual_inversion_XTI.py
/LICENSE.md:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
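As the appendix above describes, the boilerplate notice is wrapped in the comment syntax of the file it accompanies; for a Python source file in this repository it might look like the following sketch (the bracketed fields remain placeholders to be replaced with the actual year and copyright holder):

# Copyright [yyyy] [name of copyright owner]
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.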
202 | -------------------------------------------------------------------------------- /XTI_hijack.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from typing import Union, List, Optional, Dict, Any, Tuple 3 | from diffusers.models.unet_2d_condition import UNet2DConditionOutput 4 | 5 | def unet_forward_XTI(self, 6 | sample: torch.FloatTensor, 7 | timestep: Union[torch.Tensor, float, int], 8 | encoder_hidden_states: torch.Tensor, 9 | class_labels: Optional[torch.Tensor] = None, 10 | return_dict: bool = True, 11 | ) -> Union[UNet2DConditionOutput, Tuple]: 12 | r""" 13 | Args: 14 | sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor 15 | timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps 16 | encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states 17 | return_dict (`bool`, *optional*, defaults to `True`): 18 | Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple. 19 | 20 | Returns: 21 | [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`: 22 | [`~models.unet_2d_condition.UNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When 23 | returning a tuple, the first element is the sample tensor. 24 | """ 25 | # By default samples have to be AT least a multiple of the overall upsampling factor. 26 | # The overall upsampling factor is equal to 2 ** (# num of upsampling layears). 27 | # However, the upsampling interpolation output size can be forced to fit any upsampling size 28 | # on the fly if necessary. 29 | default_overall_up_factor = 2**self.num_upsamplers 30 | 31 | # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` 32 | forward_upsample_size = False 33 | upsample_size = None 34 | 35 | if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]): 36 | logger.info("Forward upsample size to force interpolation output size.") 37 | forward_upsample_size = True 38 | 39 | # 0. center input if necessary 40 | if self.config.center_input_sample: 41 | sample = 2 * sample - 1.0 42 | 43 | # 1. time 44 | timesteps = timestep 45 | if not torch.is_tensor(timesteps): 46 | # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can 47 | # This would be a good case for the `match` statement (Python 3.10+) 48 | is_mps = sample.device.type == "mps" 49 | if isinstance(timestep, float): 50 | dtype = torch.float32 if is_mps else torch.float64 51 | else: 52 | dtype = torch.int32 if is_mps else torch.int64 53 | timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) 54 | elif len(timesteps.shape) == 0: 55 | timesteps = timesteps[None].to(sample.device) 56 | 57 | # broadcast to batch dimension in a way that's compatible with ONNX/Core ML 58 | timesteps = timesteps.expand(sample.shape[0]) 59 | 60 | t_emb = self.time_proj(timesteps) 61 | 62 | # timesteps does not contain any weights and will always return f32 tensors 63 | # but time_embedding might actually be running in fp16. so we need to cast here. 64 | # there might be better ways to encapsulate this. 
65 | t_emb = t_emb.to(dtype=self.dtype) 66 | emb = self.time_embedding(t_emb) 67 | 68 | if self.config.num_class_embeds is not None: 69 | if class_labels is None: 70 | raise ValueError("class_labels should be provided when num_class_embeds > 0") 71 | class_emb = self.class_embedding(class_labels).to(dtype=self.dtype) 72 | emb = emb + class_emb 73 | 74 | # 2. pre-process 75 | sample = self.conv_in(sample) 76 | 77 | # 3. down 78 | down_block_res_samples = (sample,) 79 | down_i = 0 80 | for downsample_block in self.down_blocks: 81 | if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: 82 | sample, res_samples = downsample_block( 83 | hidden_states=sample, 84 | temb=emb, 85 | encoder_hidden_states=encoder_hidden_states[down_i:down_i+2], 86 | ) 87 | down_i += 2 88 | else: 89 | sample, res_samples = downsample_block(hidden_states=sample, temb=emb) 90 | 91 | down_block_res_samples += res_samples 92 | 93 | # 4. mid 94 | sample = self.mid_block(sample, emb, encoder_hidden_states=encoder_hidden_states[6]) 95 | 96 | # 5. up 97 | up_i = 7 98 | for i, upsample_block in enumerate(self.up_blocks): 99 | is_final_block = i == len(self.up_blocks) - 1 100 | 101 | res_samples = down_block_res_samples[-len(upsample_block.resnets) :] 102 | down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] 103 | 104 | # if we have not reached the final block and need to forward the 105 | # upsample size, we do it here 106 | if not is_final_block and forward_upsample_size: 107 | upsample_size = down_block_res_samples[-1].shape[2:] 108 | 109 | if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: 110 | sample = upsample_block( 111 | hidden_states=sample, 112 | temb=emb, 113 | res_hidden_states_tuple=res_samples, 114 | encoder_hidden_states=encoder_hidden_states[up_i:up_i+3], 115 | upsample_size=upsample_size, 116 | ) 117 | up_i += 3 118 | else: 119 | sample = upsample_block( 120 | hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size 121 | ) 122 | # 6. 
post-process 123 | sample = self.conv_norm_out(sample) 124 | sample = self.conv_act(sample) 125 | sample = self.conv_out(sample) 126 | 127 | if not return_dict: 128 | return (sample,) 129 | 130 | return UNet2DConditionOutput(sample=sample) 131 | 132 | def downblock_forward_XTI( 133 | self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None, cross_attention_kwargs=None 134 | ): 135 | output_states = () 136 | i = 0 137 | 138 | for resnet, attn in zip(self.resnets, self.attentions): 139 | if self.training and self.gradient_checkpointing: 140 | 141 | def create_custom_forward(module, return_dict=None): 142 | def custom_forward(*inputs): 143 | if return_dict is not None: 144 | return module(*inputs, return_dict=return_dict) 145 | else: 146 | return module(*inputs) 147 | 148 | return custom_forward 149 | 150 | hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) 151 | hidden_states = torch.utils.checkpoint.checkpoint( 152 | create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states[i] 153 | )[0] 154 | else: 155 | hidden_states = resnet(hidden_states, temb) 156 | hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states[i]).sample 157 | 158 | output_states += (hidden_states,) 159 | i += 1 160 | 161 | if self.downsamplers is not None: 162 | for downsampler in self.downsamplers: 163 | hidden_states = downsampler(hidden_states) 164 | 165 | output_states += (hidden_states,) 166 | 167 | return hidden_states, output_states 168 | 169 | def upblock_forward_XTI( 170 | self, 171 | hidden_states, 172 | res_hidden_states_tuple, 173 | temb=None, 174 | encoder_hidden_states=None, 175 | upsample_size=None, 176 | ): 177 | i = 0 178 | for resnet, attn in zip(self.resnets, self.attentions): 179 | # pop res hidden states 180 | res_hidden_states = res_hidden_states_tuple[-1] 181 | res_hidden_states_tuple = res_hidden_states_tuple[:-1] 182 | hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) 183 | 184 | if self.training and self.gradient_checkpointing: 185 | 186 | def create_custom_forward(module, return_dict=None): 187 | def custom_forward(*inputs): 188 | if return_dict is not None: 189 | return module(*inputs, return_dict=return_dict) 190 | else: 191 | return module(*inputs) 192 | 193 | return custom_forward 194 | 195 | hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) 196 | hidden_states = torch.utils.checkpoint.checkpoint( 197 | create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states[i] 198 | )[0] 199 | else: 200 | hidden_states = resnet(hidden_states, temb) 201 | hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states[i]).sample 202 | 203 | i += 1 204 | 205 | if self.upsamplers is not None: 206 | for upsampler in self.upsamplers: 207 | hidden_states = upsampler(hidden_states, upsample_size) 208 | 209 | return hidden_states -------------------------------------------------------------------------------- /bitsandbytes_windows/cextension.py: -------------------------------------------------------------------------------- 1 | import ctypes as ct 2 | from pathlib import Path 3 | from warnings import warn 4 | 5 | from .cuda_setup.main import evaluate_cuda_setup 6 | 7 | 8 | class CUDALibrary_Singleton(object): 9 | _instance = None 10 | 11 | def __init__(self): 12 | raise RuntimeError("Call get_instance() instead") 13 | 14 | def initialize(self): 15 | binary_name = 
evaluate_cuda_setup() 16 | package_dir = Path(__file__).parent 17 | binary_path = package_dir / binary_name 18 | 19 | if not binary_path.exists(): 20 | print(f"CUDA SETUP: TODO: compile library for specific version: {binary_name}") 21 | legacy_binary_name = "libbitsandbytes.so" 22 | print(f"CUDA SETUP: Defaulting to {legacy_binary_name}...") 23 | binary_path = package_dir / legacy_binary_name 24 | if not binary_path.exists(): 25 | print('CUDA SETUP: CUDA detection failed. Either CUDA driver not installed, CUDA not installed, or you have multiple conflicting CUDA libraries!') 26 | print('CUDA SETUP: If you compiled from source, try again with `make CUDA_VERSION=DETECTED_CUDA_VERSION` for example, `make CUDA_VERSION=113`.') 27 | raise Exception('CUDA SETUP: Setup Failed!') 28 | # self.lib = ct.cdll.LoadLibrary(binary_path) 29 | self.lib = ct.cdll.LoadLibrary(str(binary_path)) # $$$ 30 | else: 31 | print(f"CUDA SETUP: Loading binary {binary_path}...") 32 | # self.lib = ct.cdll.LoadLibrary(binary_path) 33 | self.lib = ct.cdll.LoadLibrary(str(binary_path)) # $$$ 34 | 35 | @classmethod 36 | def get_instance(cls): 37 | if cls._instance is None: 38 | cls._instance = cls.__new__(cls) 39 | cls._instance.initialize() 40 | return cls._instance 41 | 42 | 43 | lib = CUDALibrary_Singleton.get_instance().lib 44 | try: 45 | lib.cadam32bit_g32 46 | lib.get_context.restype = ct.c_void_p 47 | lib.get_cusparse.restype = ct.c_void_p 48 | COMPILED_WITH_CUDA = True 49 | except AttributeError: 50 | warn( 51 | "The installed version of bitsandbytes was compiled without GPU support. " 52 | "8-bit optimizers and GPU quantization are unavailable." 53 | ) 54 | COMPILED_WITH_CUDA = False 55 | -------------------------------------------------------------------------------- /bitsandbytes_windows/libbitsandbytes_cpu.dll: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Linaqruf/kohya-trainer/c2a9dc897ac0634450efb257a89577efa2d2487a/bitsandbytes_windows/libbitsandbytes_cpu.dll -------------------------------------------------------------------------------- /bitsandbytes_windows/libbitsandbytes_cuda116.dll: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Linaqruf/kohya-trainer/c2a9dc897ac0634450efb257a89577efa2d2487a/bitsandbytes_windows/libbitsandbytes_cuda116.dll -------------------------------------------------------------------------------- /bitsandbytes_windows/main.py: -------------------------------------------------------------------------------- 1 | """ 2 | extract factors the build is dependent on: 3 | [X] compute capability 4 | [ ] TODO: Q - What if we have multiple GPUs of different makes? 5 | - CUDA version 6 | - Software: 7 | - CPU-only: only CPU quantization functions (no optimizer, no matrix multiple) 8 | - CuBLAS-LT: full-build 8-bit optimizer 9 | - no CuBLAS-LT: no 8-bit matrix multiplication (`nomatmul`) 10 | 11 | evaluation: 12 | - if paths faulty, return meaningful error 13 | - else: 14 | - determine CUDA version 15 | - determine capabilities 16 | - based on that set the default path 17 | """ 18 | 19 | import ctypes 20 | 21 | from .paths import determine_cuda_runtime_lib_path 22 | 23 | 24 | def check_cuda_result(cuda, result_val): 25 | # 3. Check for CUDA errors 26 | if result_val != 0: 27 | error_str = ctypes.c_char_p() 28 | cuda.cuGetErrorString(result_val, ctypes.byref(error_str)) 29 | print(f"CUDA exception! 
Error code: {error_str.value.decode()}") 30 | 31 | def get_cuda_version(cuda, cudart_path): 32 | # https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART____VERSION.html#group__CUDART____VERSION 33 | try: 34 | cudart = ctypes.CDLL(cudart_path) 35 | except OSError: 36 | # TODO: shouldn't we error or at least warn here? 37 | print(f'ERROR: libcudart.so could not be read from path: {cudart_path}!') 38 | return None 39 | 40 | version = ctypes.c_int() 41 | check_cuda_result(cuda, cudart.cudaRuntimeGetVersion(ctypes.byref(version))) 42 | version = int(version.value) 43 | major = version//1000 44 | minor = (version-(major*1000))//10 45 | 46 | if major < 11: 47 | print('CUDA SETUP: CUDA version lower than 11 are currently not supported for LLM.int8(). You will be only to use 8-bit optimizers and quantization routines!!') 48 | 49 | return f'{major}{minor}' 50 | 51 | 52 | def get_cuda_lib_handle(): 53 | # 1. find libcuda.so library (GPU driver) (/usr/lib) 54 | try: 55 | cuda = ctypes.CDLL("libcuda.so") 56 | except OSError: 57 | # TODO: shouldn't we error or at least warn here? 58 | print('CUDA SETUP: WARNING! libcuda.so not found! Do you have a CUDA driver installed? If you are on a cluster, make sure you are on a CUDA machine!') 59 | return None 60 | check_cuda_result(cuda, cuda.cuInit(0)) 61 | 62 | return cuda 63 | 64 | 65 | def get_compute_capabilities(cuda): 66 | """ 67 | 1. find libcuda.so library (GPU driver) (/usr/lib) 68 | init_device -> init variables -> call function by reference 69 | 2. call extern C function to determine CC 70 | (https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__DEVICE__DEPRECATED.html) 71 | 3. Check for CUDA errors 72 | https://stackoverflow.com/questions/14038589/what-is-the-canonical-way-to-check-for-errors-using-the-cuda-runtime-api 73 | # bits taken from https://gist.github.com/f0k/63a664160d016a491b2cbea15913d549 74 | """ 75 | 76 | 77 | nGpus = ctypes.c_int() 78 | cc_major = ctypes.c_int() 79 | cc_minor = ctypes.c_int() 80 | 81 | device = ctypes.c_int() 82 | 83 | check_cuda_result(cuda, cuda.cuDeviceGetCount(ctypes.byref(nGpus))) 84 | ccs = [] 85 | for i in range(nGpus.value): 86 | check_cuda_result(cuda, cuda.cuDeviceGet(ctypes.byref(device), i)) 87 | ref_major = ctypes.byref(cc_major) 88 | ref_minor = ctypes.byref(cc_minor) 89 | # 2. call extern C function to determine CC 90 | check_cuda_result( 91 | cuda, cuda.cuDeviceComputeCapability(ref_major, ref_minor, device) 92 | ) 93 | ccs.append(f"{cc_major.value}.{cc_minor.value}") 94 | 95 | return ccs 96 | 97 | 98 | # def get_compute_capability()-> Union[List[str, ...], None]: # FIXME: error 99 | def get_compute_capability(cuda): 100 | """ 101 | Extracts the highest compute capbility from all available GPUs, as compute 102 | capabilities are downwards compatible. If no GPUs are detected, it returns 103 | None. 104 | """ 105 | ccs = get_compute_capabilities(cuda) 106 | if ccs is not None: 107 | # TODO: handle different compute capabilities; for now, take the max 108 | return ccs[-1] 109 | return None 110 | 111 | 112 | def evaluate_cuda_setup(): 113 | print('') 114 | print('='*35 + 'BUG REPORT' + '='*35) 115 | print('Welcome to bitsandbytes. 
For bug reports, please submit your error trace to: https://github.com/TimDettmers/bitsandbytes/issues') 116 | print('For effortless bug reporting copy-paste your error into this form: https://docs.google.com/forms/d/e/1FAIpQLScPB8emS3Thkp66nvqwmjTEgxp8Y9ufuWTzFyr9kJ5AoI47dQ/viewform?usp=sf_link') 117 | print('='*80) 118 | return "libbitsandbytes_cuda116.dll" # $$$ 119 | 120 | binary_name = "libbitsandbytes_cpu.so" 121 | #if not torch.cuda.is_available(): 122 | #print('No GPU detected. Loading CPU library...') 123 | #return binary_name 124 | 125 | cudart_path = determine_cuda_runtime_lib_path() 126 | if cudart_path is None: 127 | print( 128 | "WARNING: No libcudart.so found! Install CUDA or the cudatoolkit package (anaconda)!" 129 | ) 130 | return binary_name 131 | 132 | print(f"CUDA SETUP: CUDA runtime path found: {cudart_path}") 133 | cuda = get_cuda_lib_handle() 134 | cc = get_compute_capability(cuda) 135 | print(f"CUDA SETUP: Highest compute capability among GPUs detected: {cc}") 136 | cuda_version_string = get_cuda_version(cuda, cudart_path) 137 | 138 | 139 | if cc == '': 140 | print( 141 | "WARNING: No GPU detected! Check your CUDA paths. Processing to load CPU-only library..." 142 | ) 143 | return binary_name 144 | 145 | # 7.5 is the minimum CC vor cublaslt 146 | has_cublaslt = cc in ["7.5", "8.0", "8.6"] 147 | 148 | # TODO: 149 | # (1) CUDA missing cases (no CUDA installed by CUDA driver (nvidia-smi accessible) 150 | # (2) Multiple CUDA versions installed 151 | 152 | # we use ls -l instead of nvcc to determine the cuda version 153 | # since most installations will have the libcudart.so installed, but not the compiler 154 | print(f'CUDA SETUP: Detected CUDA version {cuda_version_string}') 155 | 156 | def get_binary_name(): 157 | "if not has_cublaslt (CC < 7.5), then we have to choose _nocublaslt.so" 158 | bin_base_name = "libbitsandbytes_cuda" 159 | if has_cublaslt: 160 | return f"{bin_base_name}{cuda_version_string}.so" 161 | else: 162 | return f"{bin_base_name}{cuda_version_string}_nocublaslt.so" 163 | 164 | binary_name = get_binary_name() 165 | 166 | return binary_name 167 | -------------------------------------------------------------------------------- /finetune/blip/blip.py: -------------------------------------------------------------------------------- 1 | ''' 2 | * Copyright (c) 2022, salesforce.com, inc. 3 | * All rights reserved. 
4 | * SPDX-License-Identifier: BSD-3-Clause 5 | * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | * By Junnan Li 7 | ''' 8 | import warnings 9 | warnings.filterwarnings("ignore") 10 | 11 | # from models.vit import VisionTransformer, interpolate_pos_embed 12 | # from models.med import BertConfig, BertModel, BertLMHeadModel 13 | from blip.vit import VisionTransformer, interpolate_pos_embed 14 | from blip.med import BertConfig, BertModel, BertLMHeadModel 15 | from transformers import BertTokenizer 16 | 17 | import torch 18 | from torch import nn 19 | import torch.nn.functional as F 20 | 21 | import os 22 | from urllib.parse import urlparse 23 | from timm.models.hub import download_cached_file 24 | 25 | class BLIP_Base(nn.Module): 26 | def __init__(self, 27 | med_config = 'configs/med_config.json', 28 | image_size = 224, 29 | vit = 'base', 30 | vit_grad_ckpt = False, 31 | vit_ckpt_layer = 0, 32 | ): 33 | """ 34 | Args: 35 | med_config (str): path for the mixture of encoder-decoder model's configuration file 36 | image_size (int): input image size 37 | vit (str): model size of vision transformer 38 | """ 39 | super().__init__() 40 | 41 | self.visual_encoder, vision_width = create_vit(vit,image_size, vit_grad_ckpt, vit_ckpt_layer) 42 | self.tokenizer = init_tokenizer() 43 | med_config = BertConfig.from_json_file(med_config) 44 | med_config.encoder_width = vision_width 45 | self.text_encoder = BertModel(config=med_config, add_pooling_layer=False) 46 | 47 | 48 | def forward(self, image, caption, mode): 49 | 50 | assert mode in ['image', 'text', 'multimodal'], "mode parameter must be image, text, or multimodal" 51 | text = self.tokenizer(caption, return_tensors="pt").to(image.device) 52 | 53 | if mode=='image': 54 | # return image features 55 | image_embeds = self.visual_encoder(image) 56 | return image_embeds 57 | 58 | elif mode=='text': 59 | # return text features 60 | text_output = self.text_encoder(text.input_ids, attention_mask = text.attention_mask, 61 | return_dict = True, mode = 'text') 62 | return text_output.last_hidden_state 63 | 64 | elif mode=='multimodal': 65 | # return multimodel features 66 | image_embeds = self.visual_encoder(image) 67 | image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device) 68 | 69 | text.input_ids[:,0] = self.tokenizer.enc_token_id 70 | output = self.text_encoder(text.input_ids, 71 | attention_mask = text.attention_mask, 72 | encoder_hidden_states = image_embeds, 73 | encoder_attention_mask = image_atts, 74 | return_dict = True, 75 | ) 76 | return output.last_hidden_state 77 | 78 | 79 | 80 | class BLIP_Decoder(nn.Module): 81 | def __init__(self, 82 | med_config = 'configs/med_config.json', 83 | image_size = 384, 84 | vit = 'base', 85 | vit_grad_ckpt = False, 86 | vit_ckpt_layer = 0, 87 | prompt = 'a picture of ', 88 | ): 89 | """ 90 | Args: 91 | med_config (str): path for the mixture of encoder-decoder model's configuration file 92 | image_size (int): input image size 93 | vit (str): model size of vision transformer 94 | """ 95 | super().__init__() 96 | 97 | self.visual_encoder, vision_width = create_vit(vit,image_size, vit_grad_ckpt, vit_ckpt_layer) 98 | self.tokenizer = init_tokenizer() 99 | med_config = BertConfig.from_json_file(med_config) 100 | med_config.encoder_width = vision_width 101 | self.text_decoder = BertLMHeadModel(config=med_config) 102 | 103 | self.prompt = prompt 104 | self.prompt_length = len(self.tokenizer(self.prompt).input_ids)-1 105 | 106 | 107 | 
def forward(self, image, caption): 108 | 109 | image_embeds = self.visual_encoder(image) 110 | image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device) 111 | 112 | text = self.tokenizer(caption, padding='longest', truncation=True, max_length=40, return_tensors="pt").to(image.device) 113 | 114 | text.input_ids[:,0] = self.tokenizer.bos_token_id 115 | 116 | decoder_targets = text.input_ids.masked_fill(text.input_ids == self.tokenizer.pad_token_id, -100) 117 | decoder_targets[:,:self.prompt_length] = -100 118 | 119 | decoder_output = self.text_decoder(text.input_ids, 120 | attention_mask = text.attention_mask, 121 | encoder_hidden_states = image_embeds, 122 | encoder_attention_mask = image_atts, 123 | labels = decoder_targets, 124 | return_dict = True, 125 | ) 126 | loss_lm = decoder_output.loss 127 | 128 | return loss_lm 129 | 130 | def generate(self, image, sample=False, num_beams=3, max_length=30, min_length=10, top_p=0.9, repetition_penalty=1.0): 131 | image_embeds = self.visual_encoder(image) 132 | 133 | if not sample: 134 | image_embeds = image_embeds.repeat_interleave(num_beams,dim=0) 135 | 136 | image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device) 137 | model_kwargs = {"encoder_hidden_states": image_embeds, "encoder_attention_mask":image_atts} 138 | 139 | prompt = [self.prompt] * image.size(0) 140 | input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids.to(image.device) 141 | input_ids[:,0] = self.tokenizer.bos_token_id 142 | input_ids = input_ids[:, :-1] 143 | 144 | if sample: 145 | #nucleus sampling 146 | outputs = self.text_decoder.generate(input_ids=input_ids, 147 | max_length=max_length, 148 | min_length=min_length, 149 | do_sample=True, 150 | top_p=top_p, 151 | num_return_sequences=1, 152 | eos_token_id=self.tokenizer.sep_token_id, 153 | pad_token_id=self.tokenizer.pad_token_id, 154 | repetition_penalty=1.1, 155 | **model_kwargs) 156 | else: 157 | #beam search 158 | outputs = self.text_decoder.generate(input_ids=input_ids, 159 | max_length=max_length, 160 | min_length=min_length, 161 | num_beams=num_beams, 162 | eos_token_id=self.tokenizer.sep_token_id, 163 | pad_token_id=self.tokenizer.pad_token_id, 164 | repetition_penalty=repetition_penalty, 165 | **model_kwargs) 166 | 167 | captions = [] 168 | for output in outputs: 169 | caption = self.tokenizer.decode(output, skip_special_tokens=True) 170 | captions.append(caption[len(self.prompt):]) 171 | return captions 172 | 173 | 174 | def blip_decoder(pretrained='',**kwargs): 175 | model = BLIP_Decoder(**kwargs) 176 | if pretrained: 177 | model,msg = load_checkpoint(model,pretrained) 178 | assert(len(msg.missing_keys)==0) 179 | return model 180 | 181 | def blip_feature_extractor(pretrained='',**kwargs): 182 | model = BLIP_Base(**kwargs) 183 | if pretrained: 184 | model,msg = load_checkpoint(model,pretrained) 185 | assert(len(msg.missing_keys)==0) 186 | return model 187 | 188 | def init_tokenizer(): 189 | tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') 190 | tokenizer.add_special_tokens({'bos_token':'[DEC]'}) 191 | tokenizer.add_special_tokens({'additional_special_tokens':['[ENC]']}) 192 | tokenizer.enc_token_id = tokenizer.additional_special_tokens_ids[0] 193 | return tokenizer 194 | 195 | 196 | def create_vit(vit, image_size, use_grad_checkpointing=False, ckpt_layer=0, drop_path_rate=0): 197 | 198 | assert vit in ['base', 'large'], "vit parameter must be base or large" 199 | if vit=='base': 200 | vision_width = 768 201 | visual_encoder = 
VisionTransformer(img_size=image_size, patch_size=16, embed_dim=vision_width, depth=12, 202 | num_heads=12, use_grad_checkpointing=use_grad_checkpointing, ckpt_layer=ckpt_layer, 203 | drop_path_rate=0 or drop_path_rate 204 | ) 205 | elif vit=='large': 206 | vision_width = 1024 207 | visual_encoder = VisionTransformer(img_size=image_size, patch_size=16, embed_dim=vision_width, depth=24, 208 | num_heads=16, use_grad_checkpointing=use_grad_checkpointing, ckpt_layer=ckpt_layer, 209 | drop_path_rate=0.1 or drop_path_rate 210 | ) 211 | return visual_encoder, vision_width 212 | 213 | def is_url(url_or_filename): 214 | parsed = urlparse(url_or_filename) 215 | return parsed.scheme in ("http", "https") 216 | 217 | def load_checkpoint(model,url_or_filename): 218 | if is_url(url_or_filename): 219 | cached_file = download_cached_file(url_or_filename, check_hash=False, progress=True) 220 | checkpoint = torch.load(cached_file, map_location='cpu') 221 | elif os.path.isfile(url_or_filename): 222 | checkpoint = torch.load(url_or_filename, map_location='cpu') 223 | else: 224 | raise RuntimeError('checkpoint url or path is invalid') 225 | 226 | state_dict = checkpoint['model'] 227 | 228 | state_dict['visual_encoder.pos_embed'] = interpolate_pos_embed(state_dict['visual_encoder.pos_embed'],model.visual_encoder) 229 | if 'visual_encoder_m.pos_embed' in model.state_dict().keys(): 230 | state_dict['visual_encoder_m.pos_embed'] = interpolate_pos_embed(state_dict['visual_encoder_m.pos_embed'], 231 | model.visual_encoder_m) 232 | for key in model.state_dict().keys(): 233 | if key in state_dict.keys(): 234 | if state_dict[key].shape!=model.state_dict()[key].shape: 235 | del state_dict[key] 236 | 237 | msg = model.load_state_dict(state_dict,strict=False) 238 | print('load checkpoint from %s'%url_or_filename) 239 | return model,msg 240 | 241 | -------------------------------------------------------------------------------- /finetune/blip/med_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "architectures": [ 3 | "BertModel" 4 | ], 5 | "attention_probs_dropout_prob": 0.1, 6 | "hidden_act": "gelu", 7 | "hidden_dropout_prob": 0.1, 8 | "hidden_size": 768, 9 | "initializer_range": 0.02, 10 | "intermediate_size": 3072, 11 | "layer_norm_eps": 1e-12, 12 | "max_position_embeddings": 512, 13 | "model_type": "bert", 14 | "num_attention_heads": 12, 15 | "num_hidden_layers": 12, 16 | "pad_token_id": 0, 17 | "type_vocab_size": 2, 18 | "vocab_size": 30524, 19 | "encoder_width": 768, 20 | "add_cross_attention": true 21 | } 22 | -------------------------------------------------------------------------------- /finetune/blip/vit.py: -------------------------------------------------------------------------------- 1 | ''' 2 | * Copyright (c) 2022, salesforce.com, inc. 3 | * All rights reserved. 
4 | * SPDX-License-Identifier: BSD-3-Clause 5 | * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | * By Junnan Li 7 | * Based on timm code base 8 | * https://github.com/rwightman/pytorch-image-models/tree/master/timm 9 | ''' 10 | 11 | import torch 12 | import torch.nn as nn 13 | import torch.nn.functional as F 14 | from functools import partial 15 | 16 | from timm.models.vision_transformer import _cfg, PatchEmbed 17 | from timm.models.registry import register_model 18 | from timm.models.layers import trunc_normal_, DropPath 19 | from timm.models.helpers import named_apply, adapt_input_conv 20 | 21 | from fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper 22 | 23 | class Mlp(nn.Module): 24 | """ MLP as used in Vision Transformer, MLP-Mixer and related networks 25 | """ 26 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 27 | super().__init__() 28 | out_features = out_features or in_features 29 | hidden_features = hidden_features or in_features 30 | self.fc1 = nn.Linear(in_features, hidden_features) 31 | self.act = act_layer() 32 | self.fc2 = nn.Linear(hidden_features, out_features) 33 | self.drop = nn.Dropout(drop) 34 | 35 | def forward(self, x): 36 | x = self.fc1(x) 37 | x = self.act(x) 38 | x = self.drop(x) 39 | x = self.fc2(x) 40 | x = self.drop(x) 41 | return x 42 | 43 | 44 | class Attention(nn.Module): 45 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): 46 | super().__init__() 47 | self.num_heads = num_heads 48 | head_dim = dim // num_heads 49 | # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights 50 | self.scale = qk_scale or head_dim ** -0.5 51 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 52 | self.attn_drop = nn.Dropout(attn_drop) 53 | self.proj = nn.Linear(dim, dim) 54 | self.proj_drop = nn.Dropout(proj_drop) 55 | self.attn_gradients = None 56 | self.attention_map = None 57 | 58 | def save_attn_gradients(self, attn_gradients): 59 | self.attn_gradients = attn_gradients 60 | 61 | def get_attn_gradients(self): 62 | return self.attn_gradients 63 | 64 | def save_attention_map(self, attention_map): 65 | self.attention_map = attention_map 66 | 67 | def get_attention_map(self): 68 | return self.attention_map 69 | 70 | def forward(self, x, register_hook=False): 71 | B, N, C = x.shape 72 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 73 | q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) 74 | 75 | attn = (q @ k.transpose(-2, -1)) * self.scale 76 | attn = attn.softmax(dim=-1) 77 | attn = self.attn_drop(attn) 78 | 79 | if register_hook: 80 | self.save_attention_map(attn) 81 | attn.register_hook(self.save_attn_gradients) 82 | 83 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 84 | x = self.proj(x) 85 | x = self.proj_drop(x) 86 | return x 87 | 88 | 89 | class Block(nn.Module): 90 | 91 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., 92 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, use_grad_checkpointing=False): 93 | super().__init__() 94 | self.norm1 = norm_layer(dim) 95 | self.attn = Attention( 96 | dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) 97 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here 98 | 
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() 99 | self.norm2 = norm_layer(dim) 100 | mlp_hidden_dim = int(dim * mlp_ratio) 101 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 102 | 103 | if use_grad_checkpointing: 104 | self.attn = checkpoint_wrapper(self.attn) 105 | self.mlp = checkpoint_wrapper(self.mlp) 106 | 107 | def forward(self, x, register_hook=False): 108 | x = x + self.drop_path(self.attn(self.norm1(x), register_hook=register_hook)) 109 | x = x + self.drop_path(self.mlp(self.norm2(x))) 110 | return x 111 | 112 | 113 | class VisionTransformer(nn.Module): 114 | """ Vision Transformer 115 | A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` - 116 | https://arxiv.org/abs/2010.11929 117 | """ 118 | def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12, 119 | num_heads=12, mlp_ratio=4., qkv_bias=True, qk_scale=None, representation_size=None, 120 | drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=None, 121 | use_grad_checkpointing=False, ckpt_layer=0): 122 | """ 123 | Args: 124 | img_size (int, tuple): input image size 125 | patch_size (int, tuple): patch size 126 | in_chans (int): number of input channels 127 | num_classes (int): number of classes for classification head 128 | embed_dim (int): embedding dimension 129 | depth (int): depth of transformer 130 | num_heads (int): number of attention heads 131 | mlp_ratio (int): ratio of mlp hidden dim to embedding dim 132 | qkv_bias (bool): enable bias for qkv if True 133 | qk_scale (float): override default qk scale of head_dim ** -0.5 if set 134 | representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set 135 | drop_rate (float): dropout rate 136 | attn_drop_rate (float): attention dropout rate 137 | drop_path_rate (float): stochastic depth rate 138 | norm_layer: (nn.Module): normalization layer 139 | """ 140 | super().__init__() 141 | self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models 142 | norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) 143 | 144 | self.patch_embed = PatchEmbed( 145 | img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) 146 | 147 | num_patches = self.patch_embed.num_patches 148 | 149 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) 150 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim)) 151 | self.pos_drop = nn.Dropout(p=drop_rate) 152 | 153 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule 154 | self.blocks = nn.ModuleList([ 155 | Block( 156 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, 157 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, 158 | use_grad_checkpointing=(use_grad_checkpointing and i>=depth-ckpt_layer) 159 | ) 160 | for i in range(depth)]) 161 | self.norm = norm_layer(embed_dim) 162 | 163 | trunc_normal_(self.pos_embed, std=.02) 164 | trunc_normal_(self.cls_token, std=.02) 165 | self.apply(self._init_weights) 166 | 167 | def _init_weights(self, m): 168 | if isinstance(m, nn.Linear): 169 | trunc_normal_(m.weight, std=.02) 170 | if isinstance(m, nn.Linear) and m.bias is not None: 171 | nn.init.constant_(m.bias, 0) 172 | elif isinstance(m, nn.LayerNorm): 173 | nn.init.constant_(m.bias, 0) 174 | nn.init.constant_(m.weight, 
1.0) 175 | 176 | @torch.jit.ignore 177 | def no_weight_decay(self): 178 | return {'pos_embed', 'cls_token'} 179 | 180 | def forward(self, x, register_blk=-1): 181 | B = x.shape[0] 182 | x = self.patch_embed(x) 183 | 184 | cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks 185 | x = torch.cat((cls_tokens, x), dim=1) 186 | 187 | x = x + self.pos_embed[:,:x.size(1),:] 188 | x = self.pos_drop(x) 189 | 190 | for i,blk in enumerate(self.blocks): 191 | x = blk(x, register_blk==i) 192 | x = self.norm(x) 193 | 194 | return x 195 | 196 | @torch.jit.ignore() 197 | def load_pretrained(self, checkpoint_path, prefix=''): 198 | _load_weights(self, checkpoint_path, prefix) 199 | 200 | 201 | @torch.no_grad() 202 | def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str = ''): 203 | """ Load weights from .npz checkpoints for official Google Brain Flax implementation 204 | """ 205 | import numpy as np 206 | 207 | def _n2p(w, t=True): 208 | if w.ndim == 4 and w.shape[0] == w.shape[1] == w.shape[2] == 1: 209 | w = w.flatten() 210 | if t: 211 | if w.ndim == 4: 212 | w = w.transpose([3, 2, 0, 1]) 213 | elif w.ndim == 3: 214 | w = w.transpose([2, 0, 1]) 215 | elif w.ndim == 2: 216 | w = w.transpose([1, 0]) 217 | return torch.from_numpy(w) 218 | 219 | w = np.load(checkpoint_path) 220 | if not prefix and 'opt/target/embedding/kernel' in w: 221 | prefix = 'opt/target/' 222 | 223 | if hasattr(model.patch_embed, 'backbone'): 224 | # hybrid 225 | backbone = model.patch_embed.backbone 226 | stem_only = not hasattr(backbone, 'stem') 227 | stem = backbone if stem_only else backbone.stem 228 | stem.conv.weight.copy_(adapt_input_conv(stem.conv.weight.shape[1], _n2p(w[f'{prefix}conv_root/kernel']))) 229 | stem.norm.weight.copy_(_n2p(w[f'{prefix}gn_root/scale'])) 230 | stem.norm.bias.copy_(_n2p(w[f'{prefix}gn_root/bias'])) 231 | if not stem_only: 232 | for i, stage in enumerate(backbone.stages): 233 | for j, block in enumerate(stage.blocks): 234 | bp = f'{prefix}block{i + 1}/unit{j + 1}/' 235 | for r in range(3): 236 | getattr(block, f'conv{r + 1}').weight.copy_(_n2p(w[f'{bp}conv{r + 1}/kernel'])) 237 | getattr(block, f'norm{r + 1}').weight.copy_(_n2p(w[f'{bp}gn{r + 1}/scale'])) 238 | getattr(block, f'norm{r + 1}').bias.copy_(_n2p(w[f'{bp}gn{r + 1}/bias'])) 239 | if block.downsample is not None: 240 | block.downsample.conv.weight.copy_(_n2p(w[f'{bp}conv_proj/kernel'])) 241 | block.downsample.norm.weight.copy_(_n2p(w[f'{bp}gn_proj/scale'])) 242 | block.downsample.norm.bias.copy_(_n2p(w[f'{bp}gn_proj/bias'])) 243 | embed_conv_w = _n2p(w[f'{prefix}embedding/kernel']) 244 | else: 245 | embed_conv_w = adapt_input_conv( 246 | model.patch_embed.proj.weight.shape[1], _n2p(w[f'{prefix}embedding/kernel'])) 247 | model.patch_embed.proj.weight.copy_(embed_conv_w) 248 | model.patch_embed.proj.bias.copy_(_n2p(w[f'{prefix}embedding/bias'])) 249 | model.cls_token.copy_(_n2p(w[f'{prefix}cls'], t=False)) 250 | pos_embed_w = _n2p(w[f'{prefix}Transformer/posembed_input/pos_embedding'], t=False) 251 | if pos_embed_w.shape != model.pos_embed.shape: 252 | pos_embed_w = resize_pos_embed( # resize pos embedding when different size from pretrained weights 253 | pos_embed_w, model.pos_embed, getattr(model, 'num_tokens', 1), model.patch_embed.grid_size) 254 | model.pos_embed.copy_(pos_embed_w) 255 | model.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale'])) 256 | model.norm.bias.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/bias'])) 257 | # if isinstance(model.head, 
nn.Linear) and model.head.bias.shape[0] == w[f'{prefix}head/bias'].shape[-1]: 258 | # model.head.weight.copy_(_n2p(w[f'{prefix}head/kernel'])) 259 | # model.head.bias.copy_(_n2p(w[f'{prefix}head/bias'])) 260 | # if isinstance(getattr(model.pre_logits, 'fc', None), nn.Linear) and f'{prefix}pre_logits/bias' in w: 261 | # model.pre_logits.fc.weight.copy_(_n2p(w[f'{prefix}pre_logits/kernel'])) 262 | # model.pre_logits.fc.bias.copy_(_n2p(w[f'{prefix}pre_logits/bias'])) 263 | for i, block in enumerate(model.blocks.children()): 264 | block_prefix = f'{prefix}Transformer/encoderblock_{i}/' 265 | mha_prefix = block_prefix + 'MultiHeadDotProductAttention_1/' 266 | block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale'])) 267 | block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias'])) 268 | block.attn.qkv.weight.copy_(torch.cat([ 269 | _n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('query', 'key', 'value')])) 270 | block.attn.qkv.bias.copy_(torch.cat([ 271 | _n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('query', 'key', 'value')])) 272 | block.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1)) 273 | block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias'])) 274 | for r in range(2): 275 | getattr(block.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/kernel'])) 276 | getattr(block.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/bias'])) 277 | block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/scale'])) 278 | block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/bias'])) 279 | 280 | 281 | def interpolate_pos_embed(pos_embed_checkpoint, visual_encoder): 282 | # interpolate position embedding 283 | embedding_size = pos_embed_checkpoint.shape[-1] 284 | num_patches = visual_encoder.patch_embed.num_patches 285 | num_extra_tokens = visual_encoder.pos_embed.shape[-2] - num_patches 286 | # height (== width) for the checkpoint position embedding 287 | orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5) 288 | # height (== width) for the new position embedding 289 | new_size = int(num_patches ** 0.5) 290 | 291 | if orig_size!=new_size: 292 | # class_token and dist_token are kept unchanged 293 | extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] 294 | # only the position tokens are interpolated 295 | pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] 296 | pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2) 297 | pos_tokens = torch.nn.functional.interpolate( 298 | pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False) 299 | pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) 300 | new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) 301 | print('reshape position embedding from %d to %d'%(orig_size ** 2,new_size ** 2)) 302 | 303 | return new_pos_embed 304 | else: 305 | return pos_embed_checkpoint -------------------------------------------------------------------------------- /finetune/clean_captions_and_tags.py: -------------------------------------------------------------------------------- 1 | # このスクリプトのライセンスは、Apache License 2.0とします 2 | # (c) 2022 Kohya S. 
@kohya_ss 3 | 4 | import argparse 5 | import glob 6 | import os 7 | import json 8 | import re 9 | 10 | from tqdm import tqdm 11 | 12 | PATTERN_HAIR_LENGTH = re.compile(r', (long|short|medium) hair, ') 13 | PATTERN_HAIR_CUT = re.compile(r', (bob|hime) cut, ') 14 | PATTERN_HAIR = re.compile(r', ([\w\-]+) hair, ') 15 | PATTERN_WORD = re.compile(r', ([\w\-]+|hair ornament), ') 16 | 17 | # 複数人がいるとき、複数の髪色や目の色が定義されていれば削除する 18 | PATTERNS_REMOVE_IN_MULTI = [ 19 | PATTERN_HAIR_LENGTH, 20 | PATTERN_HAIR_CUT, 21 | re.compile(r', [\w\-]+ eyes, '), 22 | re.compile(r', ([\w\-]+ sleeves|sleeveless), '), 23 | # 複数の髪型定義がある場合は削除する 24 | re.compile( 25 | r', (ponytail|braid|ahoge|twintails|[\w\-]+ bun|single hair bun|single side bun|two side up|two tails|[\w\-]+ braid|sidelocks), '), 26 | ] 27 | 28 | 29 | def clean_tags(image_key, tags): 30 | # replace '_' to ' ' 31 | tags = tags.replace('^_^', '^@@@^') 32 | tags = tags.replace('_', ' ') 33 | tags = tags.replace('^@@@^', '^_^') 34 | 35 | # remove rating: deepdanbooruのみ 36 | tokens = tags.split(", rating") 37 | if len(tokens) == 1: 38 | # WD14 taggerのときはこちらになるのでメッセージは出さない 39 | # print("no rating:") 40 | # print(f"{image_key} {tags}") 41 | pass 42 | else: 43 | if len(tokens) > 2: 44 | print("multiple ratings:") 45 | print(f"{image_key} {tags}") 46 | tags = tokens[0] 47 | 48 | tags = ", " + tags.replace(", ", ", , ") + ", " # カンマ付きで検索をするための身も蓋もない対策 49 | 50 | # 複数の人物がいる場合は髪色等のタグを削除する 51 | if 'girls' in tags or 'boys' in tags: 52 | for pat in PATTERNS_REMOVE_IN_MULTI: 53 | found = pat.findall(tags) 54 | if len(found) > 1: # 二つ以上、タグがある 55 | tags = pat.sub("", tags) 56 | 57 | # 髪の特殊対応 58 | srch_hair_len = PATTERN_HAIR_LENGTH.search(tags) # 髪の長さタグは例外なので避けておく(全員が同じ髪の長さの場合) 59 | if srch_hair_len: 60 | org = srch_hair_len.group() 61 | tags = PATTERN_HAIR_LENGTH.sub(", @@@, ", tags) 62 | 63 | found = PATTERN_HAIR.findall(tags) 64 | if len(found) > 1: 65 | tags = PATTERN_HAIR.sub("", tags) 66 | 67 | if srch_hair_len: 68 | tags = tags.replace(", @@@, ", org) # 戻す 69 | 70 | # white shirtとshirtみたいな重複タグの削除 71 | found = PATTERN_WORD.findall(tags) 72 | for word in found: 73 | if re.search(f", ((\w+) )+{word}, ", tags): 74 | tags = tags.replace(f", {word}, ", "") 75 | 76 | tags = tags.replace(", , ", ", ") 77 | assert tags.startswith(", ") and tags.endswith(", ") 78 | tags = tags[2:-2] 79 | return tags 80 | 81 | 82 | # 上から順に検索、置換される 83 | # ('置換元文字列', '置換後文字列') 84 | CAPTION_REPLACEMENTS = [ 85 | ('anime anime', 'anime'), 86 | ('young ', ''), 87 | ('anime girl', 'girl'), 88 | ('cartoon female', 'girl'), 89 | ('cartoon lady', 'girl'), 90 | ('cartoon character', 'girl'), # a or ~s 91 | ('cartoon woman', 'girl'), 92 | ('cartoon women', 'girls'), 93 | ('cartoon girl', 'girl'), 94 | ('anime female', 'girl'), 95 | ('anime lady', 'girl'), 96 | ('anime character', 'girl'), # a or ~s 97 | ('anime woman', 'girl'), 98 | ('anime women', 'girls'), 99 | ('lady', 'girl'), 100 | ('female', 'girl'), 101 | ('woman', 'girl'), 102 | ('women', 'girls'), 103 | ('people', 'girls'), 104 | ('person', 'girl'), 105 | ('a cartoon figure', 'a figure'), 106 | ('a cartoon image', 'an image'), 107 | ('a cartoon picture', 'a picture'), 108 | ('an anime cartoon image', 'an image'), 109 | ('a cartoon anime drawing', 'a drawing'), 110 | ('a cartoon drawing', 'a drawing'), 111 | ('girl girl', 'girl'), 112 | ] 113 | 114 | 115 | def clean_caption(caption): 116 | for rf, rt in CAPTION_REPLACEMENTS: 117 | replaced = True 118 | while replaced: 119 | bef = caption 120 | caption = caption.replace(rf, rt) 121 | replaced = bef 
!= caption 122 | return caption 123 | 124 | 125 | def main(args): 126 | if os.path.exists(args.in_json): 127 | print(f"loading existing metadata: {args.in_json}") 128 | with open(args.in_json, "rt", encoding='utf-8') as f: 129 | metadata = json.load(f) 130 | else: 131 | print("no metadata / メタデータファイルがありません") 132 | return 133 | 134 | print("cleaning captions and tags.") 135 | image_keys = list(metadata.keys()) 136 | for image_key in tqdm(image_keys): 137 | tags = metadata[image_key].get('tags') 138 | if tags is None: 139 | print(f"image does not have tags / メタデータにタグがありません: {image_key}") 140 | else: 141 | org = tags 142 | tags = clean_tags(image_key, tags) 143 | metadata[image_key]['tags'] = tags 144 | if args.debug and org != tags: 145 | print("FROM: " + org) 146 | print("TO: " + tags) 147 | 148 | caption = metadata[image_key].get('caption') 149 | if caption is None: 150 | print(f"image does not have caption / メタデータにキャプションがありません: {image_key}") 151 | else: 152 | org = caption 153 | caption = clean_caption(caption) 154 | metadata[image_key]['caption'] = caption 155 | if args.debug and org != caption: 156 | print("FROM: " + org) 157 | print("TO: " + caption) 158 | 159 | # metadataを書き出して終わり 160 | print(f"writing metadata: {args.out_json}") 161 | with open(args.out_json, "wt", encoding='utf-8') as f: 162 | json.dump(metadata, f, indent=2) 163 | print("done!") 164 | 165 | 166 | def setup_parser() -> argparse.ArgumentParser: 167 | parser = argparse.ArgumentParser() 168 | # parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ") 169 | parser.add_argument("in_json", type=str, help="metadata file to input / 読み込むメタデータファイル") 170 | parser.add_argument("out_json", type=str, help="metadata file to output / メタデータファイル書き出し先") 171 | parser.add_argument("--debug", action="store_true", help="debug mode") 172 | 173 | return parser 174 | 175 | 176 | if __name__ == '__main__': 177 | parser = setup_parser() 178 | 179 | args, unknown = parser.parse_known_args() 180 | if len(unknown) == 1: 181 | print("WARNING: train_data_dir argument is removed. This script will not work with three arguments in future. 
Please specify two arguments: in_json and out_json.") 182 | print("All captions and tags in the metadata are processed.") 183 | print("警告: train_data_dir引数は不要になりました。将来的には三つの引数を指定すると動かなくなる予定です。読み込み元のメタデータと書き出し先の二つの引数だけ指定してください。") 184 | print("メタデータ内のすべてのキャプションとタグが処理されます。") 185 | args.in_json = args.out_json 186 | args.out_json = unknown[0] 187 | elif len(unknown) > 0: 188 | raise ValueError(f"error: unrecognized arguments: {unknown}") 189 | 190 | main(args) 191 | -------------------------------------------------------------------------------- /finetune/hypernetwork_nai.py: -------------------------------------------------------------------------------- 1 | # NAI compatible 2 | 3 | import torch 4 | 5 | 6 | class HypernetworkModule(torch.nn.Module): 7 | def __init__(self, dim, multiplier=1.0): 8 | super().__init__() 9 | 10 | linear1 = torch.nn.Linear(dim, dim * 2) 11 | linear2 = torch.nn.Linear(dim * 2, dim) 12 | linear1.weight.data.normal_(mean=0.0, std=0.01) 13 | linear1.bias.data.zero_() 14 | linear2.weight.data.normal_(mean=0.0, std=0.01) 15 | linear2.bias.data.zero_() 16 | linears = [linear1, linear2] 17 | 18 | self.linear = torch.nn.Sequential(*linears) 19 | self.multiplier = multiplier 20 | 21 | def forward(self, x): 22 | return x + self.linear(x) * self.multiplier 23 | 24 | 25 | class Hypernetwork(torch.nn.Module): 26 | enable_sizes = [320, 640, 768, 1280] 27 | # return self.modules[Hypernetwork.enable_sizes.index(size)] 28 | 29 | def __init__(self, multiplier=1.0) -> None: 30 | super().__init__() 31 | self.modules = [] 32 | for size in Hypernetwork.enable_sizes: 33 | self.modules.append((HypernetworkModule(size, multiplier), HypernetworkModule(size, multiplier))) 34 | self.register_module(f"{size}_0", self.modules[-1][0]) 35 | self.register_module(f"{size}_1", self.modules[-1][1]) 36 | 37 | def apply_to_stable_diffusion(self, text_encoder, vae, unet): 38 | blocks = unet.input_blocks + [unet.middle_block] + unet.output_blocks 39 | for block in blocks: 40 | for subblk in block: 41 | if 'SpatialTransformer' in str(type(subblk)): 42 | for tf_block in subblk.transformer_blocks: 43 | for attn in [tf_block.attn1, tf_block.attn2]: 44 | size = attn.context_dim 45 | if size in Hypernetwork.enable_sizes: 46 | attn.hypernetwork = self 47 | else: 48 | attn.hypernetwork = None 49 | 50 | def apply_to_diffusers(self, text_encoder, vae, unet): 51 | blocks = unet.down_blocks + [unet.mid_block] + unet.up_blocks 52 | for block in blocks: 53 | if hasattr(block, 'attentions'): 54 | for subblk in block.attentions: 55 | if 'SpatialTransformer' in str(type(subblk)) or 'Transformer2DModel' in str(type(subblk)): # 0.6.0 and 0.7~ 56 | for tf_block in subblk.transformer_blocks: 57 | for attn in [tf_block.attn1, tf_block.attn2]: 58 | size = attn.to_k.in_features 59 | if size in Hypernetwork.enable_sizes: 60 | attn.hypernetwork = self 61 | else: 62 | attn.hypernetwork = None 63 | return True # TODO error checking 64 | 65 | def forward(self, x, context): 66 | size = context.shape[-1] 67 | assert size in Hypernetwork.enable_sizes 68 | module = self.modules[Hypernetwork.enable_sizes.index(size)] 69 | return module[0].forward(context), module[1].forward(context) 70 | 71 | def load_from_state_dict(self, state_dict): 72 | # old ver to new ver 73 | changes = { 74 | 'linear1.bias': 'linear.0.bias', 75 | 'linear1.weight': 'linear.0.weight', 76 | 'linear2.bias': 'linear.1.bias', 77 | 'linear2.weight': 'linear.1.weight', 78 | } 79 | for key_from, key_to in changes.items(): 80 | if key_from in state_dict: 81 | 
state_dict[key_to] = state_dict[key_from] 82 | del state_dict[key_from] 83 | 84 | for size, sd in state_dict.items(): 85 | if type(size) == int: 86 | self.modules[Hypernetwork.enable_sizes.index(size)][0].load_state_dict(sd[0], strict=True) 87 | self.modules[Hypernetwork.enable_sizes.index(size)][1].load_state_dict(sd[1], strict=True) 88 | return True 89 | 90 | def get_state_dict(self): 91 | state_dict = {} 92 | for i, size in enumerate(Hypernetwork.enable_sizes): 93 | sd0 = self.modules[i][0].state_dict() 94 | sd1 = self.modules[i][1].state_dict() 95 | state_dict[size] = [sd0, sd1] 96 | return state_dict 97 | -------------------------------------------------------------------------------- /finetune/make_captions.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import glob 3 | import os 4 | import json 5 | import random 6 | 7 | from pathlib import Path 8 | from PIL import Image 9 | from tqdm import tqdm 10 | import numpy as np 11 | import torch 12 | from torchvision import transforms 13 | from torchvision.transforms.functional import InterpolationMode 14 | from blip.blip import blip_decoder 15 | import library.train_util as train_util 16 | 17 | DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 18 | 19 | 20 | IMAGE_SIZE = 384 21 | 22 | # 正方形でいいのか? という気がするがソースがそうなので 23 | IMAGE_TRANSFORM = transforms.Compose([ 24 | transforms.Resize((IMAGE_SIZE, IMAGE_SIZE), interpolation=InterpolationMode.BICUBIC), 25 | transforms.ToTensor(), 26 | transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)) 27 | ]) 28 | 29 | # 共通化したいが微妙に処理が異なる…… 30 | class ImageLoadingTransformDataset(torch.utils.data.Dataset): 31 | def __init__(self, image_paths): 32 | self.images = image_paths 33 | 34 | def __len__(self): 35 | return len(self.images) 36 | 37 | def __getitem__(self, idx): 38 | img_path = self.images[idx] 39 | 40 | try: 41 | image = Image.open(img_path).convert("RGB") 42 | # convert to tensor temporarily so dataloader will accept it 43 | tensor = IMAGE_TRANSFORM(image) 44 | except Exception as e: 45 | print(f"Could not load image path / 画像を読み込めません: {img_path}, error: {e}") 46 | return None 47 | 48 | return (tensor, img_path) 49 | 50 | 51 | def collate_fn_remove_corrupted(batch): 52 | """Collate function that allows to remove corrupted examples in the 53 | dataloader. It expects that the dataloader returns 'None' when that occurs. 54 | The 'None's in the batch are removed. 
55 | """ 56 | # Filter out all the Nones (corrupted examples) 57 | batch = list(filter(lambda x: x is not None, batch)) 58 | return batch 59 | 60 | 61 | def main(args): 62 | # fix the seed for reproducibility 63 | seed = args.seed # + utils.get_rank() 64 | torch.manual_seed(seed) 65 | np.random.seed(seed) 66 | random.seed(seed) 67 | 68 | if not os.path.exists("blip"): 69 | args.train_data_dir = os.path.abspath(args.train_data_dir) # convert to absolute path 70 | 71 | cwd = os.getcwd() 72 | print('Current Working Directory is: ', cwd) 73 | os.chdir('finetune') 74 | 75 | print(f"load images from {args.train_data_dir}") 76 | train_data_dir = Path(args.train_data_dir) 77 | image_paths = train_util.glob_images_pathlib(train_data_dir, args.recursive) 78 | print(f"found {len(image_paths)} images.") 79 | 80 | print(f"loading BLIP caption: {args.caption_weights}") 81 | model = blip_decoder(pretrained=args.caption_weights, image_size=IMAGE_SIZE, vit='large', med_config="./blip/med_config.json") 82 | model.eval() 83 | model = model.to(DEVICE) 84 | print("BLIP loaded") 85 | 86 | # captioningする 87 | def run_batch(path_imgs): 88 | imgs = torch.stack([im for _, im in path_imgs]).to(DEVICE) 89 | 90 | with torch.no_grad(): 91 | if args.beam_search: 92 | captions = model.generate(imgs, sample=False, num_beams=args.num_beams, 93 | max_length=args.max_length, min_length=args.min_length) 94 | else: 95 | captions = model.generate(imgs, sample=True, top_p=args.top_p, max_length=args.max_length, min_length=args.min_length) 96 | 97 | for (image_path, _), caption in zip(path_imgs, captions): 98 | with open(os.path.splitext(image_path)[0] + args.caption_extension, "wt", encoding='utf-8') as f: 99 | f.write(caption + "\n") 100 | if args.debug: 101 | print(image_path, caption) 102 | 103 | # 読み込みの高速化のためにDataLoaderを使うオプション 104 | if args.max_data_loader_n_workers is not None: 105 | dataset = ImageLoadingTransformDataset(image_paths) 106 | data = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, shuffle=False, 107 | num_workers=args.max_data_loader_n_workers, collate_fn=collate_fn_remove_corrupted, drop_last=False) 108 | else: 109 | data = [[(None, ip)] for ip in image_paths] 110 | 111 | b_imgs = [] 112 | for data_entry in tqdm(data, smoothing=0.0): 113 | for data in data_entry: 114 | if data is None: 115 | continue 116 | 117 | img_tensor, image_path = data 118 | if img_tensor is None: 119 | try: 120 | raw_image = Image.open(image_path) 121 | if raw_image.mode != 'RGB': 122 | raw_image = raw_image.convert("RGB") 123 | img_tensor = IMAGE_TRANSFORM(raw_image) 124 | except Exception as e: 125 | print(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}") 126 | continue 127 | 128 | b_imgs.append((image_path, img_tensor)) 129 | if len(b_imgs) >= args.batch_size: 130 | run_batch(b_imgs) 131 | b_imgs.clear() 132 | if len(b_imgs) > 0: 133 | run_batch(b_imgs) 134 | 135 | print("done!") 136 | 137 | 138 | def setup_parser() -> argparse.ArgumentParser: 139 | parser = argparse.ArgumentParser() 140 | parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ") 141 | parser.add_argument("--caption_weights", type=str, default="https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_caption.pth", 142 | help="BLIP caption weights (model_large_caption.pth) / BLIP captionの重みファイル(model_large_caption.pth)") 143 | parser.add_argument("--caption_extention", type=str, default=None, 144 | help="extension of caption file (for backward compatibility) / 
出力されるキャプションファイルの拡張子(スペルミスしていたのを残してあります)") 145 | parser.add_argument("--caption_extension", type=str, default=".caption", help="extension of caption file / 出力されるキャプションファイルの拡張子") 146 | parser.add_argument("--beam_search", action="store_true", 147 | help="use beam search (default Nucleus sampling) / beam searchを使う(このオプション未指定時はNucleus sampling)") 148 | parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ") 149 | parser.add_argument("--max_data_loader_n_workers", type=int, default=None, 150 | help="enable image reading by DataLoader with this number of workers (faster) / DataLoaderによる画像読み込みを有効にしてこのワーカー数を適用する(読み込みを高速化)") 151 | parser.add_argument("--num_beams", type=int, default=1, help="num of beams in beam search /beam search時のビーム数(多いと精度が上がるが時間がかかる)") 152 | parser.add_argument("--top_p", type=float, default=0.9, help="top_p in Nucleus sampling / Nucleus sampling時のtop_p") 153 | parser.add_argument("--max_length", type=int, default=75, help="max length of caption / captionの最大長") 154 | parser.add_argument("--min_length", type=int, default=5, help="min length of caption / captionの最小長") 155 | parser.add_argument('--seed', default=42, type=int, help='seed for reproducibility / 再現性を確保するための乱数seed') 156 | parser.add_argument("--debug", action="store_true", help="debug mode") 157 | parser.add_argument("--recursive", action="store_true", help="search for images in subfolders recursively") 158 | 159 | return parser 160 | 161 | 162 | if __name__ == '__main__': 163 | parser = setup_parser() 164 | 165 | args = parser.parse_args() 166 | 167 | # スペルミスしていたオプションを復元する 168 | if args.caption_extention is not None: 169 | args.caption_extension = args.caption_extention 170 | 171 | main(args) 172 | -------------------------------------------------------------------------------- /finetune/make_captions_by_git.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import re 4 | 5 | from pathlib import Path 6 | from PIL import Image 7 | from tqdm import tqdm 8 | import torch 9 | from transformers import AutoProcessor, AutoModelForCausalLM 10 | from transformers.generation.utils import GenerationMixin 11 | 12 | import library.train_util as train_util 13 | 14 | 15 | DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 16 | 17 | PATTERN_REPLACE = [ 18 | re.compile(r'(has|with|and) the (words?|letters?|name) (" ?[^"]*"|\w+)( ?(is )?(on|in) (the |her |their |him )?\w+)?'), 19 | re.compile(r'(with a sign )?that says ?(" ?[^"]*"|\w+)( ?on it)?'), 20 | re.compile(r"(with a sign )?that says ?(' ?(i'm)?[^']*'|\w+)( ?on it)?"), 21 | re.compile(r'with the number \d+ on (it|\w+ \w+)'), 22 | re.compile(r'with the words "'), 23 | re.compile(r'word \w+ on it'), 24 | re.compile(r'that says the word \w+ on it'), 25 | re.compile('that says\'the word "( on it)?'), 26 | ] 27 | 28 | # 誤検知しまくりの with the word xxxx を消す 29 | 30 | 31 | def remove_words(captions, debug): 32 | removed_caps = [] 33 | for caption in captions: 34 | cap = caption 35 | for pat in PATTERN_REPLACE: 36 | cap = pat.sub("", cap) 37 | if debug and cap != caption: 38 | print(caption) 39 | print(cap) 40 | removed_caps.append(cap) 41 | return removed_caps 42 | 43 | 44 | def collate_fn_remove_corrupted(batch): 45 | """Collate function that allows to remove corrupted examples in the 46 | dataloader. It expects that the dataloader returns 'None' when that occurs. 47 | The 'None's in the batch are removed. 
48 | """ 49 | # Filter out all the Nones (corrupted examples) 50 | batch = list(filter(lambda x: x is not None, batch)) 51 | return batch 52 | 53 | 54 | def main(args): 55 | # GITにバッチサイズが1より大きくても動くようにパッチを当てる: transformers 4.26.0用 56 | org_prepare_input_ids_for_generation = GenerationMixin._prepare_input_ids_for_generation 57 | curr_batch_size = [args.batch_size] # ループの最後で件数がbatch_size未満になるので入れ替えられるように 58 | 59 | # input_idsがバッチサイズと同じ件数である必要がある:バッチサイズはこの関数から参照できないので外から渡す 60 | # ここより上で置き換えようとするとすごく大変 61 | def _prepare_input_ids_for_generation_patch(self, bos_token_id, encoder_outputs): 62 | input_ids = org_prepare_input_ids_for_generation(self, bos_token_id, encoder_outputs) 63 | if input_ids.size()[0] != curr_batch_size[0]: 64 | input_ids = input_ids.repeat(curr_batch_size[0], 1) 65 | return input_ids 66 | GenerationMixin._prepare_input_ids_for_generation = _prepare_input_ids_for_generation_patch 67 | 68 | print(f"load images from {args.train_data_dir}") 69 | train_data_dir = Path(args.train_data_dir) 70 | image_paths = train_util.glob_images_pathlib(train_data_dir, args.recursive) 71 | print(f"found {len(image_paths)} images.") 72 | 73 | # できればcacheに依存せず明示的にダウンロードしたい 74 | print(f"loading GIT: {args.model_id}") 75 | git_processor = AutoProcessor.from_pretrained(args.model_id) 76 | git_model = AutoModelForCausalLM.from_pretrained(args.model_id).to(DEVICE) 77 | print("GIT loaded") 78 | 79 | # captioningする 80 | def run_batch(path_imgs): 81 | imgs = [im for _, im in path_imgs] 82 | 83 | curr_batch_size[0] = len(path_imgs) 84 | inputs = git_processor(images=imgs, return_tensors="pt").to(DEVICE) # 画像はpil形式 85 | generated_ids = git_model.generate(pixel_values=inputs.pixel_values, max_length=args.max_length) 86 | captions = git_processor.batch_decode(generated_ids, skip_special_tokens=True) 87 | 88 | if args.remove_words: 89 | captions = remove_words(captions, args.debug) 90 | 91 | for (image_path, _), caption in zip(path_imgs, captions): 92 | with open(os.path.splitext(image_path)[0] + args.caption_extension, "wt", encoding='utf-8') as f: 93 | f.write(caption + "\n") 94 | if args.debug: 95 | print(image_path, caption) 96 | 97 | # 読み込みの高速化のためにDataLoaderを使うオプション 98 | if args.max_data_loader_n_workers is not None: 99 | dataset = train_util.ImageLoadingDataset(image_paths) 100 | data = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, shuffle=False, 101 | num_workers=args.max_data_loader_n_workers, collate_fn=collate_fn_remove_corrupted, drop_last=False) 102 | else: 103 | data = [[(None, ip)] for ip in image_paths] 104 | 105 | b_imgs = [] 106 | for data_entry in tqdm(data, smoothing=0.0): 107 | for data in data_entry: 108 | if data is None: 109 | continue 110 | 111 | image, image_path = data 112 | if image is None: 113 | try: 114 | image = Image.open(image_path) 115 | if image.mode != 'RGB': 116 | image = image.convert("RGB") 117 | except Exception as e: 118 | print(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}") 119 | continue 120 | 121 | b_imgs.append((image_path, image)) 122 | if len(b_imgs) >= args.batch_size: 123 | run_batch(b_imgs) 124 | b_imgs.clear() 125 | 126 | if len(b_imgs) > 0: 127 | run_batch(b_imgs) 128 | 129 | print("done!") 130 | 131 | 132 | def setup_parser() -> argparse.ArgumentParser: 133 | parser = argparse.ArgumentParser() 134 | parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ") 135 | parser.add_argument("--caption_extension", type=str, default=".caption", help="extension of caption file / 
出力されるキャプションファイルの拡張子") 136 | parser.add_argument("--model_id", type=str, default="microsoft/git-large-textcaps", 137 | help="model id for GIT in Hugging Face / 使用するGITのHugging FaceのモデルID") 138 | parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ") 139 | parser.add_argument("--max_data_loader_n_workers", type=int, default=None, 140 | help="enable image reading by DataLoader with this number of workers (faster) / DataLoaderによる画像読み込みを有効にしてこのワーカー数を適用する(読み込みを高速化)") 141 | parser.add_argument("--max_length", type=int, default=50, help="max length of caption / captionの最大長") 142 | parser.add_argument("--remove_words", action="store_true", 143 | help="remove like `with the words xxx` from caption / `with the words xxx`のような部分をキャプションから削除する") 144 | parser.add_argument("--debug", action="store_true", help="debug mode") 145 | parser.add_argument("--recursive", action="store_true", help="search for images in subfolders recursively") 146 | 147 | return parser 148 | 149 | 150 | if __name__ == '__main__': 151 | parser = setup_parser() 152 | 153 | args = parser.parse_args() 154 | main(args) 155 | -------------------------------------------------------------------------------- /finetune/merge_all_to_metadata.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | import re 5 | from pathlib import Path 6 | from typing import List 7 | from tqdm import tqdm 8 | from collections import Counter 9 | import library.train_util as train_util 10 | 11 | TAGS_EXT = ".txt" 12 | CAPTION_EXT = ".caption" 13 | 14 | PATTERN_HAIR_LENGTH = re.compile(r', (long|short|medium) hair, ') 15 | PATTERN_HAIR_CUT = re.compile(r', (bob|hime) cut, ') 16 | PATTERN_HAIR = re.compile(r', ([\w\-]+) hair, ') 17 | PATTERN_WORD = re.compile(r', ([\w\-]+|hair ornament), ') 18 | 19 | PATTERNS_REMOVE_IN_MULTI = [ 20 | PATTERN_HAIR_LENGTH, 21 | PATTERN_HAIR_CUT, 22 | re.compile(r', [\w\-]+ eyes, '), 23 | re.compile(r', ([\w\-]+ sleeves|sleeveless), '), 24 | re.compile( 25 | r', (ponytail|braid|ahoge|twintails|[\w\-]+ bun|single hair bun|single side bun|two side up|two tails|[\w\-]+ braid|sidelocks), '), 26 | ] 27 | 28 | CAPTION_REPLACEMENTS = [ 29 | ('anime anime', 'anime'), 30 | ('young ', ''), 31 | ('anime girl', 'girl'), 32 | ('cartoon female', 'girl'), 33 | ('cartoon lady', 'girl'), 34 | ('cartoon character', 'girl'), 35 | ('cartoon woman', 'girl'), 36 | ('cartoon women', 'girls'), 37 | ('cartoon girl', 'girl'), 38 | ('anime female', 'girl'), 39 | ('anime lady', 'girl'), 40 | ('anime character', 'girl'), 41 | ('anime woman', 'girl'), 42 | ('anime women', 'girls'), 43 | ('lady', 'girl'), 44 | ('female', 'girl'), 45 | ('woman', 'girl'), 46 | ('women', 'girls'), 47 | ('people', 'girls'), 48 | ('person', 'girl'), 49 | ('a cartoon figure', 'a figure'), 50 | ('a cartoon image', 'an image'), 51 | ('a cartoon picture', 'a picture'), 52 | ('an anime cartoon image', 'an image'), 53 | ('a cartoon anime drawing', 'a drawing'), 54 | ('a cartoon drawing', 'a drawing'), 55 | ('girl girl', 'girl'), 56 | ] 57 | 58 | def clean_tags(image_key, tags): 59 | tags = tags.replace('^_^', '^@@@^') 60 | tags = tags.replace('_', ' ') 61 | tags = tags.replace('^@@@^', '^_^') 62 | 63 | tokens = tags.split(", rating") 64 | if len(tokens) == 1: 65 | pass 66 | else: 67 | if len(tokens) > 2: 68 | print("multiple ratings:") 69 | print(f"{image_key} {tags}") 70 | tags = tokens[0] 71 | 72 | tags = ", " + tags.replace(", ", ", , ") + ", " 73 | 74 | if 
'girls' in tags or 'boys' in tags: 75 | for pat in PATTERNS_REMOVE_IN_MULTI: 76 | found = pat.findall(tags) 77 | if len(found) > 1: 78 | tags = pat.sub("", tags) 79 | 80 | srch_hair_len = PATTERN_HAIR_LENGTH.search(tags) 81 | if srch_hair_len: 82 | org = srch_hair_len.group() 83 | tags = PATTERN_HAIR_LENGTH.sub(", @@@, ", tags) 84 | 85 | found = PATTERN_HAIR.findall(tags) 86 | if len(found) > 1: 87 | tags = PATTERN_HAIR.sub("", tags) 88 | 89 | if srch_hair_len: 90 | tags = tags.replace(", @@@, ", org) 91 | 92 | found = PATTERN_WORD.findall(tags) 93 | for word in found: 94 | if re.search(f", ((\w+) )+{word}, ", tags): 95 | tags = tags.replace(f", {word}, ", "") 96 | 97 | tags = tags.replace(", , ", ", ") 98 | assert tags.startswith(", ") and tags.endswith(", ") 99 | tags = tags[2:-2] 100 | return tags 101 | 102 | def clean_caption(caption): 103 | for rf, rt in CAPTION_REPLACEMENTS: 104 | replaced = True 105 | while replaced: 106 | bef = caption 107 | caption = caption.replace(rf, rt) 108 | replaced = bef != caption 109 | return caption 110 | 111 | def count_files(image_paths, metadata): 112 | counts = Counter({'_captions': 0, '_tags': 0}) 113 | 114 | for image_key in metadata: 115 | if 'tags' not in metadata[image_key]: 116 | counts['_tags'] += 1 117 | if 'caption' not in metadata[image_key]: 118 | counts['_captions'] += 1 119 | 120 | return counts 121 | 122 | def report_counts(counts, total_files): 123 | for key, value in counts.items(): 124 | if value == total_files: 125 | print(f"No {key.replace('_', '')} found for any of the {total_files} images") 126 | elif value == 0: 127 | print(f"All {total_files} images have {key.replace('_', '')}") 128 | else: 129 | print(f"{total_files - value}/{total_files} images have {key.replace('_', '')}") 130 | 131 | def merge_metadata(image_paths, metadata, full_path): 132 | for image_path in tqdm(image_paths): 133 | tags_path = image_path.with_suffix(TAGS_EXT) 134 | if not tags_path.exists(): 135 | tags_path = image_path.joinpath(TAGS_EXT) 136 | 137 | caption_path = image_path.with_suffix(CAPTION_EXT) 138 | if not caption_path.exists(): 139 | caption_path = image_path.joinpath(CAPTION_EXT) 140 | 141 | image_key = str(image_path) if full_path else image_path.stem 142 | if image_key not in metadata: 143 | metadata[image_key] = {} 144 | 145 | if tags_path.is_file(): 146 | tags = tags_path.read_text(encoding='utf-8').strip() 147 | metadata[image_key]['tags'] = tags 148 | 149 | if caption_path.is_file(): 150 | caption = caption_path.read_text(encoding='utf-8').strip() 151 | metadata[image_key]['caption'] = caption 152 | 153 | counts = count_files(image_paths, metadata) 154 | report_counts(counts, len(image_paths)) 155 | 156 | return metadata 157 | 158 | def clean_metadata(metadata): 159 | image_keys = list(metadata.keys()) 160 | for image_key in tqdm(image_keys): 161 | tags = metadata[image_key].get('tags') 162 | if tags is not None: 163 | org = tags 164 | tags = clean_tags(image_key, tags) 165 | metadata[image_key]['tags'] = tags 166 | 167 | caption = metadata[image_key].get('caption') 168 | if caption is not None: 169 | org = caption 170 | caption = clean_caption(caption) 171 | metadata[image_key]['caption'] = caption 172 | 173 | return metadata 174 | 175 | def main(args): 176 | assert not args.recursive or (args.recursive and args.full_path), "--recursive requires --full_path!" 
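# Why --recursive requires --full_path (explanatory note, inferred from the code below):
# metadata keys default to the image file stem, so recursing into subfolders could produce
# duplicate keys for same-named files; using the full path keeps each key unique
# (see the image_key assignment in merge_metadata above).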
177 | 178 | train_data_dir_path = Path(args.train_data_dir) 179 | image_paths: List[Path] = train_util.glob_images_pathlib(train_data_dir_path, args.recursive) 180 | print(f"Found {len(image_paths)} images.") 181 | 182 | if args.in_json is not None: 183 | print(f"Loading existing metadata: {args.in_json}") 184 | metadata = json.loads(Path(args.in_json).read_text(encoding='utf-8')) 185 | print("Metadata for existing images will be overwritten") 186 | else: 187 | print("Creating a new metadata file") 188 | metadata = {} 189 | 190 | print("Merging tags and captions into metadata json.") 191 | metadata = merge_metadata(image_paths, metadata, args.full_path) 192 | 193 | if args.clean_caption: 194 | print("Cleaning captions and tags.") 195 | metadata = clean_metadata(metadata) 196 | 197 | if args.debug: 198 | print("Debug: image_key, tags, caption") 199 | for image_key, data in metadata.items(): 200 | print(image_key, data['tags'], data['caption']) 201 | 202 | print(f"Writing metadata: {args.out_json}") 203 | Path(args.out_json).write_text(json.dumps(metadata, indent=2), encoding='utf-8') 204 | print("Done!") 205 | 206 | def setup_parser() -> argparse.ArgumentParser: 207 | parser = argparse.ArgumentParser() 208 | parser.add_argument("train_data_dir", type=str, help="directory for train images") 209 | parser.add_argument("out_json", type=str, help="metadata file to output") 210 | parser.add_argument("--in_json", type=str, 211 | help="metadata file to input (if omitted and out_json exists, existing out_json is read)") 212 | parser.add_argument("--full_path", action="store_true", 213 | help="use full path as image-key in metadata (supports multiple directories)") 214 | parser.add_argument("--recursive", action="store_true", 215 | help="recursively search for training tags and captions in all child folders of train_data_dir") 216 | parser.add_argument("--debug", action="store_true", help="debug mode") 217 | parser.add_argument("--clean_caption", action="store_true", help="clean captions and tags in metadata") 218 | 219 | return parser 220 | 221 | if __name__ == '__main__': 222 | parser = setup_parser() 223 | 224 | args = parser.parse_args() 225 | 226 | main(args) 227 | -------------------------------------------------------------------------------- /finetune/merge_captions_to_metadata.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | from pathlib import Path 4 | from typing import List 5 | from tqdm import tqdm 6 | import library.train_util as train_util 7 | import os 8 | 9 | def main(args): 10 | assert not args.recursive or (args.recursive and args.full_path), "recursive requires full_path / recursiveはfull_pathと同時に指定してください" 11 | 12 | train_data_dir_path = Path(args.train_data_dir) 13 | image_paths: List[Path] = train_util.glob_images_pathlib(train_data_dir_path, args.recursive) 14 | print(f"found {len(image_paths)} images.") 15 | 16 | if args.in_json is None and Path(args.out_json).is_file(): 17 | args.in_json = args.out_json 18 | 19 | if args.in_json is not None: 20 | print(f"loading existing metadata: {args.in_json}") 21 | metadata = json.loads(Path(args.in_json).read_text(encoding='utf-8')) 22 | print("captions for existing images will be overwritten / 既存の画像のキャプションは上書きされます") 23 | else: 24 | print("new metadata will be created / 新しいメタデータファイルが作成されます") 25 | metadata = {} 26 | 27 | print("merge caption texts to metadata json.") 28 | for image_path in tqdm(image_paths): 29 | caption_path = 
image_path.with_suffix(args.caption_extension) 30 | caption = caption_path.read_text(encoding='utf-8').strip() 31 | 32 | if not os.path.exists(caption_path): 33 | caption_path = os.path.join(image_path, args.caption_extension) 34 | 35 | image_key = str(image_path) if args.full_path else image_path.stem 36 | if image_key not in metadata: 37 | metadata[image_key] = {} 38 | 39 | metadata[image_key]['caption'] = caption 40 | if args.debug: 41 | print(image_key, caption) 42 | 43 | # metadataを書き出して終わり 44 | print(f"writing metadata: {args.out_json}") 45 | Path(args.out_json).write_text(json.dumps(metadata, indent=2), encoding='utf-8') 46 | print("done!") 47 | 48 | 49 | def setup_parser() -> argparse.ArgumentParser: 50 | parser = argparse.ArgumentParser() 51 | parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ") 52 | parser.add_argument("out_json", type=str, help="metadata file to output / メタデータファイル書き出し先") 53 | parser.add_argument("--in_json", type=str, 54 | help="metadata file to input (if omitted and out_json exists, existing out_json is read) / 読み込むメタデータファイル(省略時、out_jsonが存在すればそれを読み込む)") 55 | parser.add_argument("--caption_extention", type=str, default=None, 56 | help="extension of caption file (for backward compatibility) / 読み込むキャプションファイルの拡張子(スペルミスしていたのを残してあります)") 57 | parser.add_argument("--caption_extension", type=str, default=".caption", help="extension of caption file / 読み込むキャプションファイルの拡張子") 58 | parser.add_argument("--full_path", action="store_true", 59 | help="use full path as image-key in metadata (supports multiple directories) / メタデータで画像キーをフルパスにする(複数の学習画像ディレクトリに対応)") 60 | parser.add_argument("--recursive", action="store_true", 61 | help="recursively look for training tags in all child folders of train_data_dir / train_data_dirのすべての子フォルダにある学習タグを再帰的に探す") 62 | parser.add_argument("--debug", action="store_true", help="debug mode") 63 | 64 | return parser 65 | 66 | 67 | if __name__ == '__main__': 68 | parser = setup_parser() 69 | 70 | args = parser.parse_args() 71 | 72 | # スペルミスしていたオプションを復元する 73 | if args.caption_extention is not None: 74 | args.caption_extension = args.caption_extention 75 | 76 | main(args) 77 | -------------------------------------------------------------------------------- /finetune/merge_dd_tags_to_metadata.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | from pathlib import Path 4 | from typing import List 5 | from tqdm import tqdm 6 | import library.train_util as train_util 7 | import os 8 | 9 | def main(args): 10 | assert not args.recursive or (args.recursive and args.full_path), "recursive requires full_path / recursiveはfull_pathと同時に指定してください" 11 | 12 | train_data_dir_path = Path(args.train_data_dir) 13 | image_paths: List[Path] = train_util.glob_images_pathlib(train_data_dir_path, args.recursive) 14 | print(f"found {len(image_paths)} images.") 15 | 16 | if args.in_json is None and Path(args.out_json).is_file(): 17 | args.in_json = args.out_json 18 | 19 | if args.in_json is not None: 20 | print(f"loading existing metadata: {args.in_json}") 21 | metadata = json.loads(Path(args.in_json).read_text(encoding='utf-8')) 22 | print("tags data for existing images will be overwritten / 既存の画像のタグは上書きされます") 23 | else: 24 | print("new metadata will be created / 新しいメタデータファイルが作成されます") 25 | metadata = {} 26 | 27 | print("merge tags to metadata json.") 28 | for image_path in tqdm(image_paths): 29 | tags_path = image_path.with_suffix(args.caption_extension) 30 | 
tags = tags_path.read_text(encoding='utf-8').strip() 31 | 32 | if not os.path.exists(tags_path): 33 | tags_path = os.path.join(image_path, args.caption_extension) 34 | 35 | image_key = str(image_path) if args.full_path else image_path.stem 36 | if image_key not in metadata: 37 | metadata[image_key] = {} 38 | 39 | metadata[image_key]['tags'] = tags 40 | if args.debug: 41 | print(image_key, tags) 42 | 43 | # metadataを書き出して終わり 44 | print(f"writing metadata: {args.out_json}") 45 | Path(args.out_json).write_text(json.dumps(metadata, indent=2), encoding='utf-8') 46 | 47 | print("done!") 48 | 49 | 50 | def setup_parser() -> argparse.ArgumentParser: 51 | parser = argparse.ArgumentParser() 52 | parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ") 53 | parser.add_argument("out_json", type=str, help="metadata file to output / メタデータファイル書き出し先") 54 | parser.add_argument("--in_json", type=str, 55 | help="metadata file to input (if omitted and out_json exists, existing out_json is read) / 読み込むメタデータファイル(省略時、out_jsonが存在すればそれを読み込む)") 56 | parser.add_argument("--full_path", action="store_true", 57 | help="use full path as image-key in metadata (supports multiple directories) / メタデータで画像キーをフルパスにする(複数の学習画像ディレクトリに対応)") 58 | parser.add_argument("--recursive", action="store_true", 59 | help="recursively look for training tags in all child folders of train_data_dir / train_data_dirのすべての子フォルダにある学習タグを再帰的に探す") 60 | parser.add_argument("--caption_extension", type=str, default=".txt", 61 | help="extension of caption (tag) file / 読み込むキャプション(タグ)ファイルの拡張子") 62 | parser.add_argument("--debug", action="store_true", help="debug mode, print tags") 63 | 64 | return parser 65 | 66 | 67 | if __name__ == '__main__': 68 | parser = setup_parser() 69 | 70 | args = parser.parse_args() 71 | main(args) 72 | -------------------------------------------------------------------------------- /finetune/prepare_buckets_latents.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import json 4 | 5 | from pathlib import Path 6 | from typing import List 7 | from tqdm import tqdm 8 | import numpy as np 9 | from PIL import Image 10 | import cv2 11 | import torch 12 | from torchvision import transforms 13 | 14 | import library.model_util as model_util 15 | import library.train_util as train_util 16 | 17 | DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 18 | 19 | IMAGE_TRANSFORMS = transforms.Compose( 20 | [ 21 | transforms.ToTensor(), 22 | transforms.Normalize([0.5], [0.5]), 23 | ] 24 | ) 25 | 26 | 27 | def collate_fn_remove_corrupted(batch): 28 | """Collate function that allows to remove corrupted examples in the 29 | dataloader. It expects that the dataloader returns 'None' when that occurs. 30 | The 'None's in the batch are removed. 
31 | """ 32 | # Filter out all the Nones (corrupted examples) 33 | batch = list(filter(lambda x: x is not None, batch)) 34 | return batch 35 | 36 | 37 | def get_latents(vae, images, weight_dtype): 38 | img_tensors = [IMAGE_TRANSFORMS(image) for image in images] 39 | img_tensors = torch.stack(img_tensors) 40 | img_tensors = img_tensors.to(DEVICE, weight_dtype) 41 | with torch.no_grad(): 42 | latents = vae.encode(img_tensors).latent_dist.sample().float().to("cpu").numpy() 43 | return latents 44 | 45 | 46 | def get_npz_filename_wo_ext(data_dir, image_key, is_full_path, flip, recursive): 47 | if is_full_path: 48 | base_name = os.path.splitext(os.path.basename(image_key))[0] 49 | relative_path = os.path.relpath(os.path.dirname(image_key), data_dir) 50 | else: 51 | base_name = image_key 52 | relative_path = "" 53 | 54 | if flip: 55 | base_name += '_flip' 56 | 57 | if recursive and relative_path: 58 | return os.path.join(data_dir, relative_path, base_name) 59 | else: 60 | return os.path.join(data_dir, base_name) 61 | 62 | 63 | 64 | def main(args): 65 | # assert args.bucket_reso_steps % 8 == 0, f"bucket_reso_steps must be divisible by 8 / bucket_reso_stepは8で割り切れる必要があります" 66 | if args.bucket_reso_steps % 8 > 0: 67 | print(f"resolution of buckets in training time is a multiple of 8 / 学習時の各bucketの解像度は8単位になります") 68 | 69 | train_data_dir_path = Path(args.train_data_dir) 70 | image_paths: List[str] = [str(p) for p in train_util.glob_images_pathlib(train_data_dir_path, args.recursive)] 71 | print(f"found {len(image_paths)} images.") 72 | 73 | if os.path.exists(args.in_json): 74 | print(f"loading existing metadata: {args.in_json}") 75 | with open(args.in_json, "rt", encoding='utf-8') as f: 76 | metadata = json.load(f) 77 | else: 78 | print(f"no metadata / メタデータファイルがありません: {args.in_json}") 79 | return 80 | 81 | weight_dtype = torch.float32 82 | if args.mixed_precision == "fp16": 83 | weight_dtype = torch.float16 84 | elif args.mixed_precision == "bf16": 85 | weight_dtype = torch.bfloat16 86 | 87 | vae = model_util.load_vae(args.model_name_or_path, weight_dtype) 88 | vae.eval() 89 | vae.to(DEVICE, dtype=weight_dtype) 90 | 91 | # bucketのサイズを計算する 92 | max_reso = tuple([int(t) for t in args.max_resolution.split(',')]) 93 | assert len(max_reso) == 2, f"illegal resolution (not 'width,height') / 画像サイズに誤りがあります。'幅,高さ'で指定してください: {args.max_resolution}" 94 | 95 | bucket_manager = train_util.BucketManager(args.bucket_no_upscale, max_reso, 96 | args.min_bucket_reso, args.max_bucket_reso, args.bucket_reso_steps) 97 | if not args.bucket_no_upscale: 98 | bucket_manager.make_buckets() 99 | else: 100 | print("min_bucket_reso and max_bucket_reso are ignored if bucket_no_upscale is set, because bucket reso is defined by image size automatically / bucket_no_upscaleが指定された場合は、bucketの解像度は画像サイズから自動計算されるため、min_bucket_resoとmax_bucket_resoは無視されます") 101 | 102 | # 画像をひとつずつ適切なbucketに割り当てながらlatentを計算する 103 | img_ar_errors = [] 104 | 105 | def process_batch(is_last): 106 | for bucket in bucket_manager.buckets: 107 | if (is_last and len(bucket) > 0) or len(bucket) >= args.batch_size: 108 | latents = get_latents(vae, [img for _, img in bucket], weight_dtype) 109 | assert latents.shape[2] == bucket[0][1].shape[0] // 8 and latents.shape[3] == bucket[0][1].shape[1] // 8, \ 110 | f"latent shape {latents.shape}, {bucket[0][1].shape}" 111 | 112 | for (image_key, _), latent in zip(bucket, latents): 113 | npz_file_name = get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, False, args.recursive) 114 | np.savez(npz_file_name, 
latent) 115 | 116 | # flip 117 | if args.flip_aug: 118 | latents = get_latents(vae, [img[:, ::-1].copy() for _, img in bucket], weight_dtype) # copyがないとTensor変換できない 119 | 120 | for (image_key, _), latent in zip(bucket, latents): 121 | npz_file_name = get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, True, args.recursive) 122 | np.savez(npz_file_name, latent) 123 | else: 124 | # remove existing flipped npz 125 | for image_key, _ in bucket: 126 | npz_file_name = get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, True, args.recursive) + ".npz" 127 | if os.path.isfile(npz_file_name): 128 | print(f"remove existing flipped npz / 既存のflipされたnpzファイルを削除します: {npz_file_name}") 129 | os.remove(npz_file_name) 130 | 131 | bucket.clear() 132 | 133 | # 読み込みの高速化のためにDataLoaderを使うオプション 134 | if args.max_data_loader_n_workers is not None: 135 | dataset = train_util.ImageLoadingDataset(image_paths) 136 | data = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, 137 | num_workers=args.max_data_loader_n_workers, collate_fn=collate_fn_remove_corrupted, drop_last=False) 138 | else: 139 | data = [[(None, ip)] for ip in image_paths] 140 | 141 | bucket_counts = {} 142 | for data_entry in tqdm(data, smoothing=0.0): 143 | if data_entry[0] is None: 144 | continue 145 | 146 | img_tensor, image_path = data_entry[0] 147 | if img_tensor is not None: 148 | image = transforms.functional.to_pil_image(img_tensor) 149 | else: 150 | try: 151 | image = Image.open(image_path) 152 | if image.mode != 'RGB': 153 | image = image.convert("RGB") 154 | except Exception as e: 155 | print(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}") 156 | continue 157 | 158 | image_key = image_path if args.full_path else os.path.splitext(os.path.basename(image_path))[0] 159 | if image_key not in metadata: 160 | metadata[image_key] = {} 161 | 162 | # 本当はこのあとの部分もDataSetに持っていけば高速化できるがいろいろ大変 163 | 164 | reso, resized_size, ar_error = bucket_manager.select_bucket(image.width, image.height) 165 | img_ar_errors.append(abs(ar_error)) 166 | bucket_counts[reso] = bucket_counts.get(reso, 0) + 1 167 | 168 | # メタデータに記録する解像度はlatent単位とするので、8単位で切り捨て 169 | metadata[image_key]['train_resolution'] = (reso[0] - reso[0] % 8, reso[1] - reso[1] % 8) 170 | 171 | if not args.bucket_no_upscale: 172 | # upscaleを行わないときには、resize後のサイズは、bucketのサイズと、縦横どちらかが同じであることを確認する 173 | assert resized_size[0] == reso[0] or resized_size[1] == reso[ 174 | 1], f"internal error, resized size not match: {reso}, {resized_size}, {image.width}, {image.height}" 175 | assert resized_size[0] >= reso[0] and resized_size[1] >= reso[ 176 | 1], f"internal error, resized size too small: {reso}, {resized_size}, {image.width}, {image.height}" 177 | 178 | assert resized_size[0] >= reso[0] and resized_size[1] >= reso[ 179 | 1], f"internal error resized size is small: {resized_size}, {reso}" 180 | 181 | # 既に存在するファイルがあればshapeを確認して同じならskipする 182 | if args.skip_existing: 183 | npz_files = [get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, False, args.recursive) + ".npz"] 184 | if args.flip_aug: 185 | npz_files.append(get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, True, args.recursive) + ".npz") 186 | 187 | found = True 188 | for npz_file in npz_files: 189 | if not os.path.exists(npz_file): 190 | found = False 191 | break 192 | 193 | dat = np.load(npz_file)['arr_0'] 194 | if dat.shape[1] != reso[1] // 8 or dat.shape[2] != reso[0] // 8: # latentsのshapeを確認 195 | found = False 196 | break 197 | if 
found: 198 | continue 199 | 200 | # 画像をリサイズしてトリミングする 201 | # PILにinter_areaがないのでcv2で…… 202 | image = np.array(image) 203 | if resized_size[0] != image.shape[1] or resized_size[1] != image.shape[0]: # リサイズ処理が必要? 204 | image = cv2.resize(image, resized_size, interpolation=cv2.INTER_AREA) 205 | 206 | if resized_size[0] > reso[0]: 207 | trim_size = resized_size[0] - reso[0] 208 | image = image[:, trim_size//2:trim_size//2 + reso[0]] 209 | 210 | if resized_size[1] > reso[1]: 211 | trim_size = resized_size[1] - reso[1] 212 | image = image[trim_size//2:trim_size//2 + reso[1]] 213 | 214 | assert image.shape[0] == reso[1] and image.shape[1] == reso[0], f"internal error, illegal trimmed size: {image.shape}, {reso}" 215 | 216 | # # debug 217 | # cv2.imwrite(f"r:\\test\\img_{len(img_ar_errors)}.jpg", image[:, :, ::-1]) 218 | 219 | # バッチへ追加 220 | bucket_manager.add_image(reso, (image_key, image)) 221 | 222 | # バッチを推論するか判定して推論する 223 | process_batch(False) 224 | 225 | # 残りを処理する 226 | process_batch(True) 227 | 228 | bucket_manager.sort() 229 | for i, reso in enumerate(bucket_manager.resos): 230 | count = bucket_counts.get(reso, 0) 231 | if count > 0: 232 | print(f"bucket {i} {reso}: {count}") 233 | img_ar_errors = np.array(img_ar_errors) 234 | print(f"mean ar error: {np.mean(img_ar_errors)}") 235 | 236 | # metadataを書き出して終わり 237 | print(f"writing metadata: {args.out_json}") 238 | with open(args.out_json, "wt", encoding='utf-8') as f: 239 | json.dump(metadata, f, indent=2) 240 | print("done!") 241 | 242 | 243 | def setup_parser() -> argparse.ArgumentParser: 244 | parser = argparse.ArgumentParser() 245 | parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ") 246 | parser.add_argument("in_json", type=str, help="metadata file to input / 読み込むメタデータファイル") 247 | parser.add_argument("out_json", type=str, help="metadata file to output / メタデータファイル書き出し先") 248 | parser.add_argument("model_name_or_path", type=str, help="model name or path to encode latents / latentを取得するためのモデル") 249 | parser.add_argument("--v2", action='store_true', 250 | help='not used (for backward compatibility) / 使用されません(互換性のため残してあります)') 251 | parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ") 252 | parser.add_argument("--max_data_loader_n_workers", type=int, default=None, 253 | help="enable image reading by DataLoader with this number of workers (faster) / DataLoaderによる画像読み込みを有効にしてこのワーカー数を適用する(読み込みを高速化)") 254 | parser.add_argument("--max_resolution", type=str, default="512,512", 255 | help="max resolution in fine tuning (width,height) / fine tuning時の最大画像サイズ 「幅,高さ」(使用メモリ量に関係します)") 256 | parser.add_argument("--min_bucket_reso", type=int, default=256, help="minimum resolution for buckets / bucketの最小解像度") 257 | parser.add_argument("--max_bucket_reso", type=int, default=1024, help="maximum resolution for buckets / bucketの最小解像度") 258 | parser.add_argument("--bucket_reso_steps", type=int, default=64, 259 | help="steps of resolution for buckets, divisible by 8 is recommended / bucketの解像度の単位、8で割り切れる値を推奨します") 260 | parser.add_argument("--bucket_no_upscale", action="store_true", 261 | help="make bucket for each image without upscaling / 画像を拡大せずbucketを作成します") 262 | parser.add_argument("--mixed_precision", type=str, default="no", 263 | choices=["no", "fp16", "bf16"], help="use mixed precision / 混合精度を使う場合、その精度") 264 | parser.add_argument("--full_path", action="store_true", 265 | help="use full path as image-key in metadata (supports multiple directories) / 
メタデータで画像キーをフルパスにする(複数の学習画像ディレクトリに対応)") 266 | parser.add_argument("--flip_aug", action="store_true", 267 | help="flip augmentation, save latents for flipped images / 左右反転した画像もlatentを取得、保存する") 268 | parser.add_argument("--skip_existing", action="store_true", 269 | help="skip images if npz already exists (both normal and flipped exists if flip_aug is enabled) / npzが既に存在する画像をスキップする(flip_aug有効時は通常、反転の両方が存在する画像をスキップ)") 270 | parser.add_argument("--recursive", action="store_true", 271 | help="recursively look for training tags in all child folders of train_data_dir / train_data_dirのすべての子フォルダにある学習タグを再帰的に探す") 272 | 273 | return parser 274 | 275 | 276 | if __name__ == '__main__': 277 | parser = setup_parser() 278 | 279 | args = parser.parse_args() 280 | main(args) 281 | -------------------------------------------------------------------------------- /finetune/tag_images_by_wd14_tagger.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import csv 3 | import glob 4 | import os 5 | 6 | from PIL import Image 7 | import cv2 8 | from tqdm import tqdm 9 | import numpy as np 10 | from tensorflow.keras.models import load_model 11 | from huggingface_hub import hf_hub_download 12 | import torch 13 | from pathlib import Path 14 | 15 | import library.train_util as train_util 16 | 17 | # from wd14 tagger 18 | IMAGE_SIZE = 448 19 | 20 | # wd-v1-4-swinv2-tagger-v2 / wd-v1-4-vit-tagger / wd-v1-4-vit-tagger-v2/ wd-v1-4-convnext-tagger / wd-v1-4-convnext-tagger-v2 21 | DEFAULT_WD14_TAGGER_REPO = 'SmilingWolf/wd-v1-4-convnext-tagger-v2' 22 | FILES = ["keras_metadata.pb", "saved_model.pb", "selected_tags.csv"] 23 | SUB_DIR = "variables" 24 | SUB_DIR_FILES = ["variables.data-00000-of-00001", "variables.index"] 25 | CSV_FILE = FILES[-1] 26 | 27 | def preprocess_image(image): 28 | image = np.array(image) 29 | image = image[:, :, ::-1] # RGB->BGR 30 | 31 | # pad to square 32 | size = max(image.shape[0:2]) 33 | pad_x = size - image.shape[1] 34 | pad_y = size - image.shape[0] 35 | pad_l = pad_x // 2 36 | pad_t = pad_y // 2 37 | image = np.pad(image, ((pad_t, pad_y - pad_t), (pad_l, pad_x - pad_l), (0, 0)), mode='constant', constant_values=255) 38 | 39 | interp = cv2.INTER_AREA if size > IMAGE_SIZE else cv2.INTER_LANCZOS4 40 | image = cv2.resize(image, (IMAGE_SIZE, IMAGE_SIZE), interpolation=interp) 41 | 42 | image = image.astype(np.float32) 43 | return image 44 | 45 | class ImageLoadingPrepDataset(torch.utils.data.Dataset): 46 | def __init__(self, image_paths): 47 | self.images = image_paths 48 | 49 | def __len__(self): 50 | return len(self.images) 51 | 52 | def __getitem__(self, idx): 53 | img_path = str(self.images[idx]) 54 | 55 | try: 56 | image = Image.open(img_path).convert("RGB") 57 | image = preprocess_image(image) 58 | tensor = torch.tensor(image) 59 | except Exception as e: 60 | print(f"Could not load image path / 画像を読み込めません: {img_path}, error: {e}") 61 | return None 62 | 63 | return (tensor, img_path) 64 | 65 | def collate_fn_remove_corrupted(batch): 66 | """Collate function that allows to remove corrupted examples in the 67 | dataloader. It expects that the dataloader returns 'None' when that occurs. 68 | The 'None's in the batch are removed. 
69 | """ 70 | # Filter out all the Nones (corrupted examples) 71 | batch = list(filter(lambda x: x is not None, batch)) 72 | return batch 73 | 74 | def main(args): 75 | # hf_hub_downloadをそのまま使うとsymlink関係で問題があるらしいので、キャッシュディレクトリとforce_filenameを指定してなんとかする 76 | # depreacatedの警告が出るけどなくなったらその時 77 | # https://github.com/toriato/stable-diffusion-webui-wd14-tagger/issues/22 78 | if not os.path.exists(args.model_dir) or args.force_download: 79 | print(f"downloading wd14 tagger model from hf_hub. id: {args.repo_id}") 80 | for file in FILES: 81 | hf_hub_download(args.repo_id, file, cache_dir=args.model_dir, force_download=True, force_filename=file) 82 | for file in SUB_DIR_FILES: 83 | hf_hub_download(args.repo_id, file, subfolder=SUB_DIR, cache_dir=os.path.join( 84 | args.model_dir, SUB_DIR), force_download=True, force_filename=file) 85 | else: 86 | print("using existing wd14 tagger model") 87 | 88 | # 画像を読み込む 89 | model = load_model(args.model_dir) 90 | 91 | # label_names = pd.read_csv("2022_0000_0899_6549/selected_tags.csv") 92 | # 依存ライブラリを増やしたくないので自力で読むよ 93 | 94 | with open(os.path.join(args.model_dir, CSV_FILE), "r", encoding="utf-8") as f: 95 | reader = csv.reader(f) 96 | l = [row for row in reader] 97 | header = l[0] # tag_id,name,category,count 98 | rows = l[1:] 99 | assert header[0] == 'tag_id' and header[1] == 'name' and header[2] == 'category', f"unexpected csv format: {header}" 100 | 101 | general_tags = [row[1] for row in rows[1:] if row[2] == '0'] 102 | character_tags = [row[1] for row in rows[1:] if row[2] == '4'] 103 | 104 | # 画像を読み込む 105 | 106 | train_data_dir = Path(args.train_data_dir) 107 | image_paths = train_util.glob_images_pathlib(train_data_dir, args.recursive) 108 | print(f"found {len(image_paths)} images.") 109 | 110 | tag_freq = {} 111 | 112 | undesired_tags = set(args.undesired_tags.split(',')) 113 | 114 | def run_batch(path_imgs): 115 | imgs = np.array([im for _, im in path_imgs]) 116 | 117 | probs = model(imgs, training=False) 118 | probs = probs.numpy() 119 | 120 | for (image_path, _), prob in zip(path_imgs, probs): 121 | # 最初の4つはratingなので無視する 122 | # # First 4 labels are actually ratings: pick one with argmax 123 | # ratings_names = label_names[:4] 124 | # rating_index = ratings_names["probs"].argmax() 125 | # found_rating = ratings_names[rating_index: rating_index + 1][["name", "probs"]] 126 | 127 | # それ以降はタグなのでconfidenceがthresholdより高いものを追加する 128 | # Everything else is tags: pick any where prediction confidence > threshold 129 | combined_tags = [] 130 | general_tag_text = "" 131 | character_tag_text = "" 132 | for i, p in enumerate(prob[4:]): 133 | if i < len(general_tags) and p >= args.general_threshold: 134 | tag_name = general_tags[i].replace('_', ' ') if args.remove_underscore else general_tags[i] 135 | if tag_name not in undesired_tags: 136 | tag_freq[tag_name] = tag_freq.get(tag_name, 0) + 1 137 | general_tag_text += ", " + tag_name 138 | combined_tags.append(tag_name) 139 | elif i >= len(general_tags) and p >= args.character_threshold: 140 | tag_name = character_tags[i - len(general_tags)].replace('_', ' ') if args.remove_underscore else character_tags[i - len(general_tags)] 141 | if tag_name not in undesired_tags: 142 | tag_freq[tag_name] = tag_freq.get(tag_name, 0) + 1 143 | character_tag_text += ", " + tag_name 144 | combined_tags.append(tag_name) 145 | 146 | if len(general_tag_text) > 0: 147 | general_tag_text = general_tag_text[2:] 148 | 149 | if len(character_tag_text) > 0: 150 | character_tag_text = character_tag_text[2:] 151 | 152 | tag_text = ', 
'.join(combined_tags) 153 | 154 | with open(os.path.splitext(image_path)[0] + args.caption_extension, "wt", encoding='utf-8') as f: 155 | f.write(tag_text + '\n') 156 | if args.debug: 157 | print(f"\n{image_path}:\n Character tags: {character_tag_text}\n General tags: {general_tag_text}") 158 | 159 | 160 | # 読み込みの高速化のためにDataLoaderを使うオプション 161 | if args.max_data_loader_n_workers is not None: 162 | dataset = ImageLoadingPrepDataset(image_paths) 163 | data = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, shuffle=False, 164 | num_workers=args.max_data_loader_n_workers, collate_fn=collate_fn_remove_corrupted, drop_last=False) 165 | else: 166 | data = [[(None, ip)] for ip in image_paths] 167 | 168 | b_imgs = [] 169 | for data_entry in tqdm(data, smoothing=0.0): 170 | for data in data_entry: 171 | if data is None: 172 | continue 173 | 174 | image, image_path = data 175 | if image is not None: 176 | image = image.detach().numpy() 177 | else: 178 | try: 179 | image = Image.open(image_path) 180 | if image.mode != 'RGB': 181 | image = image.convert("RGB") 182 | image = preprocess_image(image) 183 | except Exception as e: 184 | print(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}") 185 | continue 186 | b_imgs.append((image_path, image)) 187 | 188 | if len(b_imgs) >= args.batch_size: 189 | b_imgs = [(str(image_path), image) for image_path, image in b_imgs] # Convert image_path to string 190 | run_batch(b_imgs) 191 | b_imgs.clear() 192 | 193 | if len(b_imgs) > 0: 194 | b_imgs = [(str(image_path), image) for image_path, image in b_imgs] # Convert image_path to string 195 | run_batch(b_imgs) 196 | 197 | if args.frequency_tags: 198 | sorted_tags = sorted(tag_freq.items(), key=lambda x: x[1], reverse=True) 199 | print("\nTag frequencies:") 200 | for tag, freq in sorted_tags: 201 | print(f"{tag}: {freq}") 202 | 203 | print("done!") 204 | 205 | 206 | if __name__ == '__main__': 207 | parser = argparse.ArgumentParser() 208 | parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ") 209 | parser.add_argument("--repo_id", type=str, default=DEFAULT_WD14_TAGGER_REPO, 210 | help="repo id for wd14 tagger on Hugging Face / Hugging Faceのwd14 taggerのリポジトリID") 211 | parser.add_argument("--model_dir", type=str, default="wd14_tagger_model", 212 | help="directory to store wd14 tagger model / wd14 taggerのモデルを格納するディレクトリ") 213 | parser.add_argument("--force_download", action='store_true', 214 | help="force downloading wd14 tagger models / wd14 taggerのモデルを再ダウンロードします") 215 | parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ") 216 | parser.add_argument("--max_data_loader_n_workers", type=int, default=None, 217 | help="enable image reading by DataLoader with this number of workers (faster) / DataLoaderによる画像読み込みを有効にしてこのワーカー数を適用する(読み込みを高速化)") 218 | parser.add_argument("--caption_extention", type=str, default=None, 219 | help="extension of caption file (for backward compatibility) / 出力されるキャプションファイルの拡張子(スペルミスしていたのを残してあります)") 220 | parser.add_argument("--caption_extension", type=str, default=".txt", help="extension of caption file / 出力されるキャプションファイルの拡張子") 221 | parser.add_argument("--general_threshold", type=float, default=0.35, help="threshold of confidence to add a tag for general category") 222 | parser.add_argument("--character_threshold", type=float, default=0.35, help="threshold of confidence to add a tag for character category") 223 | parser.add_argument("--recursive", action="store_true", help="search 
for images in subfolders recursively") 224 | parser.add_argument("--remove_underscore", action="store_true", help="replace underscores with spaces in the output tags") 225 | parser.add_argument("--debug", action="store_true", help="debug mode") 226 | parser.add_argument("--undesired_tags", type=str, default="", help="comma-separated list of undesired tags to remove from the output") 227 | parser.add_argument('--frequency_tags', action='store_true', help='Show frequency of tags for images') 228 | 229 | args = parser.parse_args() 230 | 231 | # スペルミスしていたオプションを復元する 232 | if args.caption_extention is not None: 233 | args.caption_extension = args.caption_extention 234 | 235 | main(args) 236 | -------------------------------------------------------------------------------- /library/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Linaqruf/kohya-trainer/c2a9dc897ac0634450efb257a89577efa2d2487a/library/__init__.py -------------------------------------------------------------------------------- /library/custom_train_functions.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import argparse 3 | 4 | def apply_snr_weight(loss, timesteps, noise_scheduler, gamma): 5 | alphas_cumprod = noise_scheduler.alphas_cumprod 6 | sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod) 7 | sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - alphas_cumprod) 8 | alpha = sqrt_alphas_cumprod 9 | sigma = sqrt_one_minus_alphas_cumprod 10 | all_snr = (alpha / sigma) ** 2 11 | snr = torch.stack([all_snr[t] for t in timesteps]) 12 | gamma_over_snr = torch.div(torch.ones_like(snr)*gamma,snr) 13 | snr_weight = torch.minimum(gamma_over_snr,torch.ones_like(gamma_over_snr)).float() #from paper 14 | loss = loss * snr_weight 15 | return loss 16 | 17 | def add_custom_train_arguments(parser: argparse.ArgumentParser): 18 | parser.add_argument("--min_snr_gamma", type=float, default=None, help="gamma for reducing the weight of high loss timesteps. Lower numbers have stronger effect. 5 is recommended by paper. 
/ 低いタイムステップでの高いlossに対して重みを減らすためのgamma値、低いほど効果が強く、論文では5が推奨") 19 | -------------------------------------------------------------------------------- /library/huggingface_util.py: -------------------------------------------------------------------------------- 1 | from typing import * 2 | from huggingface_hub import HfApi 3 | from pathlib import Path 4 | import argparse 5 | import os 6 | 7 | from library.utils import fire_in_thread 8 | 9 | 10 | def exists_repo( 11 | repo_id: str, repo_type: str, revision: str = "main", token: str = None 12 | ): 13 | api = HfApi( 14 | token=token, 15 | ) 16 | try: 17 | api.repo_info(repo_id=repo_id, revision=revision, repo_type=repo_type) 18 | return True 19 | except: 20 | return False 21 | 22 | 23 | def upload( 24 | args: argparse.Namespace, 25 | src: Union[str, Path, bytes, BinaryIO], 26 | dest_suffix: str = "", 27 | force_sync_upload: bool = False, 28 | ): 29 | repo_id = args.huggingface_repo_id 30 | repo_type = args.huggingface_repo_type 31 | token = args.huggingface_token 32 | path_in_repo = args.huggingface_path_in_repo + dest_suffix 33 | private = args.huggingface_repo_visibility is None or args.huggingface_repo_visibility != "public" 34 | api = HfApi(token=token) 35 | if not exists_repo(repo_id=repo_id, repo_type=repo_type, token=token): 36 | api.create_repo(repo_id=repo_id, repo_type=repo_type, private=private) 37 | 38 | is_folder = (type(src) == str and os.path.isdir(src)) or ( 39 | isinstance(src, Path) and src.is_dir() 40 | ) 41 | 42 | def uploader(): 43 | if is_folder: 44 | api.upload_folder( 45 | repo_id=repo_id, 46 | repo_type=repo_type, 47 | folder_path=src, 48 | path_in_repo=path_in_repo, 49 | ) 50 | else: 51 | api.upload_file( 52 | repo_id=repo_id, 53 | repo_type=repo_type, 54 | path_or_fileobj=src, 55 | path_in_repo=path_in_repo, 56 | ) 57 | 58 | if args.async_upload and not force_sync_upload: 59 | fire_in_thread(uploader) 60 | else: 61 | uploader() 62 | 63 | 64 | def list_dir( 65 | repo_id: str, 66 | subfolder: str, 67 | repo_type: str, 68 | revision: str = "main", 69 | token: str = None, 70 | ): 71 | api = HfApi( 72 | token=token, 73 | ) 74 | repo_info = api.repo_info(repo_id=repo_id, revision=revision, repo_type=repo_type) 75 | file_list = [ 76 | file for file in repo_info.siblings if file.rfilename.startswith(subfolder) 77 | ] 78 | return file_list 79 | -------------------------------------------------------------------------------- /library/utils.py: -------------------------------------------------------------------------------- 1 | import threading 2 | from typing import * 3 | 4 | 5 | def fire_in_thread(f, *args, **kwargs): 6 | threading.Thread(target=f, args=args, kwargs=kwargs).start() -------------------------------------------------------------------------------- /networks/check_lora_weights.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import torch 4 | from safetensors.torch import load_file 5 | 6 | 7 | def main(file): 8 | print(f"loading: {file}") 9 | if os.path.splitext(file)[1] == '.safetensors': 10 | sd = load_file(file) 11 | else: 12 | sd = torch.load(file, map_location='cpu') 13 | 14 | values = [] 15 | 16 | keys = list(sd.keys()) 17 | for key in keys: 18 | if 'lora_up' in key or 'lora_down' in key: 19 | values.append((key, sd[key])) 20 | print(f"number of LoRA modules: {len(values)}") 21 | 22 | for key, value in values: 23 | value = value.to(torch.float32) 24 | print(f"{key},{str(tuple(value.size())).replace(', ', 
'-')},{torch.mean(torch.abs(value))},{torch.min(torch.abs(value))}") 25 | 26 | 27 | def setup_parser() -> argparse.ArgumentParser: 28 | parser = argparse.ArgumentParser() 29 | parser.add_argument("file", type=str, help="model file to check / 重みを確認するモデルファイル") 30 | 31 | return parser 32 | 33 | 34 | if __name__ == '__main__': 35 | parser = setup_parser() 36 | 37 | args = parser.parse_args() 38 | 39 | main(args.file) 40 | -------------------------------------------------------------------------------- /networks/extract_lora_from_models.py: -------------------------------------------------------------------------------- 1 | # extract approximating LoRA by svd from two SD models 2 | # The code is based on https://github.com/cloneofsimo/lora/blob/develop/lora_diffusion/cli_svd.py 3 | # Thanks to cloneofsimo! 4 | 5 | import argparse 6 | import os 7 | import torch 8 | from safetensors.torch import load_file, save_file 9 | from tqdm import tqdm 10 | import library.model_util as model_util 11 | import lora 12 | 13 | 14 | CLAMP_QUANTILE = 0.99 15 | MIN_DIFF = 1e-6 16 | 17 | 18 | def save_to_file(file_name, model, state_dict, dtype): 19 | if dtype is not None: 20 | for key in list(state_dict.keys()): 21 | if type(state_dict[key]) == torch.Tensor: 22 | state_dict[key] = state_dict[key].to(dtype) 23 | 24 | if os.path.splitext(file_name)[1] == '.safetensors': 25 | save_file(model, file_name) 26 | else: 27 | torch.save(model, file_name) 28 | 29 | 30 | def svd(args): 31 | def str_to_dtype(p): 32 | if p == 'float': 33 | return torch.float 34 | if p == 'fp16': 35 | return torch.float16 36 | if p == 'bf16': 37 | return torch.bfloat16 38 | return None 39 | 40 | save_dtype = str_to_dtype(args.save_precision) 41 | 42 | print(f"loading SD model : {args.model_org}") 43 | text_encoder_o, _, unet_o = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.model_org) 44 | print(f"loading SD model : {args.model_tuned}") 45 | text_encoder_t, _, unet_t = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.model_tuned) 46 | 47 | # create LoRA network to extract weights: Use dim (rank) as alpha 48 | if args.conv_dim is None: 49 | kwargs = {} 50 | else: 51 | kwargs = {"conv_dim": args.conv_dim, "conv_alpha": args.conv_dim} 52 | 53 | lora_network_o = lora.create_network(1.0, args.dim, args.dim, None, text_encoder_o, unet_o, **kwargs) 54 | lora_network_t = lora.create_network(1.0, args.dim, args.dim, None, text_encoder_t, unet_t, **kwargs) 55 | assert len(lora_network_o.text_encoder_loras) == len( 56 | lora_network_t.text_encoder_loras), f"model version is different (SD1.x vs SD2.x) / それぞれのモデルのバージョンが違います(SD1.xベースとSD2.xベース) " 57 | 58 | # get diffs 59 | diffs = {} 60 | text_encoder_different = False 61 | for i, (lora_o, lora_t) in enumerate(zip(lora_network_o.text_encoder_loras, lora_network_t.text_encoder_loras)): 62 | lora_name = lora_o.lora_name 63 | module_o = lora_o.org_module 64 | module_t = lora_t.org_module 65 | diff = module_t.weight - module_o.weight 66 | 67 | # Text Encoder might be same 68 | if torch.max(torch.abs(diff)) > MIN_DIFF: 69 | text_encoder_different = True 70 | 71 | diff = diff.float() 72 | diffs[lora_name] = diff 73 | 74 | if not text_encoder_different: 75 | print("Text encoder is same. 
Extract U-Net only.") 76 | lora_network_o.text_encoder_loras = [] 77 | diffs = {} 78 | 79 | for i, (lora_o, lora_t) in enumerate(zip(lora_network_o.unet_loras, lora_network_t.unet_loras)): 80 | lora_name = lora_o.lora_name 81 | module_o = lora_o.org_module 82 | module_t = lora_t.org_module 83 | diff = module_t.weight - module_o.weight 84 | diff = diff.float() 85 | 86 | if args.device: 87 | diff = diff.to(args.device) 88 | 89 | diffs[lora_name] = diff 90 | 91 | # make LoRA with svd 92 | print("calculating by svd") 93 | lora_weights = {} 94 | with torch.no_grad(): 95 | for lora_name, mat in tqdm(list(diffs.items())): 96 | # if args.conv_dim is None, diffs do not include LoRAs for conv2d-3x3 97 | conv2d = (len(mat.size()) == 4) 98 | kernel_size = None if not conv2d else mat.size()[2:4] 99 | conv2d_3x3 = conv2d and kernel_size != (1, 1) 100 | 101 | rank = args.dim if not conv2d_3x3 or args.conv_dim is None else args.conv_dim 102 | out_dim, in_dim = mat.size()[0:2] 103 | 104 | if args.device: 105 | mat = mat.to(args.device) 106 | 107 | # print(lora_name, mat.size(), mat.device, rank, in_dim, out_dim) 108 | rank = min(rank, in_dim, out_dim) # LoRA rank cannot exceed the original dim 109 | 110 | if conv2d: 111 | if conv2d_3x3: 112 | mat = mat.flatten(start_dim=1) 113 | else: 114 | mat = mat.squeeze() 115 | 116 | U, S, Vh = torch.linalg.svd(mat) 117 | 118 | U = U[:, :rank] 119 | S = S[:rank] 120 | U = U @ torch.diag(S) 121 | 122 | Vh = Vh[:rank, :] 123 | 124 | dist = torch.cat([U.flatten(), Vh.flatten()]) 125 | hi_val = torch.quantile(dist, CLAMP_QUANTILE) 126 | low_val = -hi_val 127 | 128 | U = U.clamp(low_val, hi_val) 129 | Vh = Vh.clamp(low_val, hi_val) 130 | 131 | if conv2d: 132 | U = U.reshape(out_dim, rank, 1, 1) 133 | Vh = Vh.reshape(rank, in_dim, kernel_size[0], kernel_size[1]) 134 | 135 | U = U.to("cpu").contiguous() 136 | Vh = Vh.to("cpu").contiguous() 137 | 138 | lora_weights[lora_name] = (U, Vh) 139 | 140 | # make state dict for LoRA 141 | lora_sd = {} 142 | for lora_name, (up_weight, down_weight) in lora_weights.items(): 143 | lora_sd[lora_name + '.lora_up.weight'] = up_weight 144 | lora_sd[lora_name + '.lora_down.weight'] = down_weight 145 | lora_sd[lora_name + '.alpha'] = torch.tensor(down_weight.size()[0]) 146 | 147 | # load state dict to LoRA and save it 148 | lora_network_save, lora_sd = lora.create_network_from_weights(1.0, None, None, text_encoder_o, unet_o, weights_sd=lora_sd) 149 | lora_network_save.apply_to(text_encoder_o, unet_o) # create internal module references for state_dict 150 | 151 | info = lora_network_save.load_state_dict(lora_sd) 152 | print(f"Loading extracted LoRA weights: {info}") 153 | 154 | dir_name = os.path.dirname(args.save_to) 155 | if dir_name and not os.path.exists(dir_name): 156 | os.makedirs(dir_name, exist_ok=True) 157 | 158 | # minimum metadata 159 | metadata = {"ss_network_module": "networks.lora", "ss_network_dim": str(args.dim), "ss_network_alpha": str(args.dim)} 160 | 161 | lora_network_save.save_weights(args.save_to, save_dtype, metadata) 162 | print(f"LoRA weights are saved to: {args.save_to}") 163 | 164 | 165 | def setup_parser() -> argparse.ArgumentParser: 166 | parser = argparse.ArgumentParser() 167 | parser.add_argument("--v2", action='store_true', 168 | help='load Stable Diffusion v2.x model / Stable Diffusion 2.xのモデルを読み込む') 169 | parser.add_argument("--save_precision", type=str, default=None, 170 | choices=[None, "float", "fp16", "bf16"], help="precision in saving, same to merging if omitted / 保存時に精度を変更して保存する、省略時はfloat") 171 | 
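# A minimal, self-contained sketch of the low-rank extraction performed in svd() above: each
# weight difference is factored by SVD, truncated to the requested rank, the singular values are
# folded into the "up" factor, and both factors are clamped at CLAMP_QUANTILE. The 320x768 shape
# and the rank below are illustrative values only.
import torch

def low_rank_approx(diff: torch.Tensor, rank: int, clamp_quantile: float = 0.99):
    """Return (up, down) so that up @ down approximates diff at the given rank."""
    rank = min(rank, *diff.shape)            # rank cannot exceed the original dimensions
    U, S, Vh = torch.linalg.svd(diff)
    up = U[:, :rank] @ torch.diag(S[:rank])  # fold singular values into the up factor
    down = Vh[:rank, :]
    hi = torch.quantile(torch.cat([up.flatten(), down.flatten()]), clamp_quantile)
    return up.clamp(-hi, hi), down.clamp(-hi, hi)

if __name__ == "__main__":
    diff = torch.randn(320, 768)             # e.g. a Linear weight delta (tuned - original)
    up, down = low_rank_approx(diff, rank=4)
    rel_err = torch.norm(diff - up @ down) / torch.norm(diff)
    print(f"relative error of the rank-4 approximation: {rel_err:.3f}")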
parser.add_argument("--model_org", type=str, default=None, 172 | help="Stable Diffusion original model: ckpt or safetensors file / 元モデル、ckptまたはsafetensors") 173 | parser.add_argument("--model_tuned", type=str, default=None, 174 | help="Stable Diffusion tuned model, LoRA is difference of `original to tuned`: ckpt or safetensors file / 派生モデル(生成されるLoRAは元→派生の差分になります)、ckptまたはsafetensors") 175 | parser.add_argument("--save_to", type=str, default=None, 176 | help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors") 177 | parser.add_argument("--dim", type=int, default=4, help="dimension (rank) of LoRA (default 4) / LoRAの次元数(rank)(デフォルト4)") 178 | parser.add_argument("--conv_dim", type=int, default=None, 179 | help="dimension (rank) of LoRA for Conv2d-3x3 (default None, disabled) / LoRAのConv2d-3x3の次元数(rank)(デフォルトNone、適用なし)") 180 | parser.add_argument("--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う") 181 | 182 | return parser 183 | 184 | 185 | if __name__ == '__main__': 186 | parser = setup_parser() 187 | 188 | args = parser.parse_args() 189 | svd(args) 190 | -------------------------------------------------------------------------------- /networks/lora_interrogator.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | from tqdm import tqdm 4 | from library import model_util 5 | import argparse 6 | from transformers import CLIPTokenizer 7 | import torch 8 | 9 | import library.model_util as model_util 10 | import lora 11 | 12 | TOKENIZER_PATH = "openai/clip-vit-large-patch14" 13 | V2_STABLE_DIFFUSION_PATH = "stabilityai/stable-diffusion-2" # ここからtokenizerだけ使う 14 | 15 | DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 16 | 17 | 18 | def interrogate(args): 19 | # いろいろ準備する 20 | print(f"loading SD model: {args.sd_model}") 21 | text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.sd_model) 22 | 23 | print(f"loading LoRA: {args.model}") 24 | network = lora.create_network_from_weights(1.0, args.model, vae, text_encoder, unet) 25 | 26 | # text encoder向けの重みがあるかチェックする:本当はlora側でやるのがいい 27 | has_te_weight = False 28 | for key in network.weights_sd.keys(): 29 | if 'lora_te' in key: 30 | has_te_weight = True 31 | break 32 | if not has_te_weight: 33 | print("This LoRA does not have modules for Text Encoder, cannot interrogate / このLoRAはText Encoder向けのモジュールがないため調査できません") 34 | return 35 | del vae 36 | 37 | print("loading tokenizer") 38 | if args.v2: 39 | tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(V2_STABLE_DIFFUSION_PATH, subfolder="tokenizer") 40 | else: 41 | tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(TOKENIZER_PATH) # , model_max_length=max_token_length + 2) 42 | 43 | text_encoder.to(DEVICE) 44 | text_encoder.eval() 45 | unet.to(DEVICE) 46 | unet.eval() # U-Netは呼び出さないので不要だけど 47 | 48 | # トークンをひとつひとつ当たっていく 49 | token_id_start = 0 50 | token_id_end = max(tokenizer.all_special_ids) 51 | print(f"interrogate tokens are: {token_id_start} to {token_id_end}") 52 | 53 | def get_all_embeddings(text_encoder): 54 | embs = [] 55 | with torch.no_grad(): 56 | for token_id in tqdm(range(token_id_start, token_id_end + 1, args.batch_size)): 57 | batch = [] 58 | for tid in range(token_id, min(token_id_end + 1, token_id + args.batch_size)): 59 | tokens = [tokenizer.bos_token_id, tid, tokenizer.eos_token_id] 60 | # tokens = [tid] # こちらは結果がいまひとつ 61 | batch.append(tokens) 62 | 63 | # batch_embs = 
text_encoder(torch.tensor(batch).to(DEVICE))[0].to("cpu") # bos/eosも含めたほうが差が出るようだ [:, 1] 64 | # clip skip対応 65 | batch = torch.tensor(batch).to(DEVICE) 66 | if args.clip_skip is None: 67 | encoder_hidden_states = text_encoder(batch)[0] 68 | else: 69 | enc_out = text_encoder(batch, output_hidden_states=True, return_dict=True) 70 | encoder_hidden_states = enc_out['hidden_states'][-args.clip_skip] 71 | encoder_hidden_states = text_encoder.text_model.final_layer_norm(encoder_hidden_states) 72 | encoder_hidden_states = encoder_hidden_states.to("cpu") 73 | 74 | embs.extend(encoder_hidden_states) 75 | return torch.stack(embs) 76 | 77 | print("get original text encoder embeddings.") 78 | orig_embs = get_all_embeddings(text_encoder) 79 | 80 | network.apply_to(text_encoder, unet, True, len(network.unet_loras) > 0) 81 | network.to(DEVICE) 82 | network.eval() 83 | 84 | print("You can ignore warning messages start with '_IncompatibleKeys' (LoRA model does not have alpha because trained by older script) / '_IncompatibleKeys'の警告は無視して構いません(以前のスクリプトで学習されたLoRAモデルのためalphaの定義がありません)") 85 | print("get text encoder embeddings with lora.") 86 | lora_embs = get_all_embeddings(text_encoder) 87 | 88 | # 比べる:とりあえず単純に差分の絶対値で 89 | print("comparing...") 90 | diffs = {} 91 | for i, (orig_emb, lora_emb) in enumerate(zip(orig_embs, tqdm(lora_embs))): 92 | diff = torch.mean(torch.abs(orig_emb - lora_emb)) 93 | # diff = torch.mean(torch.cosine_similarity(orig_emb, lora_emb, dim=1)) # うまく検出できない 94 | diff = float(diff.detach().to('cpu').numpy()) 95 | diffs[token_id_start + i] = diff 96 | 97 | diffs_sorted = sorted(diffs.items(), key=lambda x: -x[1]) 98 | 99 | # 結果を表示する 100 | print("top 100:") 101 | for i, (token, diff) in enumerate(diffs_sorted[:100]): 102 | # if diff < 1e-6: 103 | # break 104 | string = tokenizer.convert_tokens_to_string(tokenizer.convert_ids_to_tokens([token])) 105 | print(f"[{i:3d}]: {token:5d} {string:<20s}: {diff:.5f}") 106 | 107 | 108 | def setup_parser() -> argparse.ArgumentParser: 109 | parser = argparse.ArgumentParser() 110 | parser.add_argument("--v2", action='store_true', 111 | help='load Stable Diffusion v2.x model / Stable Diffusion 2.xのモデルを読み込む') 112 | parser.add_argument("--sd_model", type=str, default=None, 113 | help="Stable Diffusion model to load: ckpt or safetensors file / 読み込むSDのモデル、ckptまたはsafetensors") 114 | parser.add_argument("--model", type=str, default=None, 115 | help="LoRA model to interrogate: ckpt or safetensors file / 調査するLoRAモデル、ckptまたはsafetensors") 116 | parser.add_argument("--batch_size", type=int, default=16, 117 | help="batch size for processing with Text Encoder / Text Encoderで処理するときのバッチサイズ") 118 | parser.add_argument("--clip_skip", type=int, default=None, 119 | help="use output of nth layer from back of text encoder (n>=1) / text encoderの後ろからn番目の層の出力を用いる(nは1以上)") 120 | 121 | return parser 122 | 123 | 124 | if __name__ == '__main__': 125 | parser = setup_parser() 126 | 127 | args = parser.parse_args() 128 | interrogate(args) 129 | -------------------------------------------------------------------------------- /networks/merge_lora.py: -------------------------------------------------------------------------------- 1 | import math 2 | import argparse 3 | import os 4 | import torch 5 | from safetensors.torch import load_file, save_file 6 | import library.model_util as model_util 7 | import lora 8 | 9 | 10 | def load_state_dict(file_name, dtype): 11 | if os.path.splitext(file_name)[1] == ".safetensors": 12 | sd = load_file(file_name) 13 | else: 14 | sd = torch.load(file_name, 
map_location="cpu") 15 | for key in list(sd.keys()): 16 | if type(sd[key]) == torch.Tensor: 17 | sd[key] = sd[key].to(dtype) 18 | return sd 19 | 20 | 21 | def save_to_file(file_name, model, state_dict, dtype): 22 | if dtype is not None: 23 | for key in list(state_dict.keys()): 24 | if type(state_dict[key]) == torch.Tensor: 25 | state_dict[key] = state_dict[key].to(dtype) 26 | 27 | if os.path.splitext(file_name)[1] == ".safetensors": 28 | save_file(model, file_name) 29 | else: 30 | torch.save(model, file_name) 31 | 32 | 33 | def merge_to_sd_model(text_encoder, unet, models, ratios, merge_dtype): 34 | text_encoder.to(merge_dtype) 35 | unet.to(merge_dtype) 36 | 37 | # create module map 38 | name_to_module = {} 39 | for i, root_module in enumerate([text_encoder, unet]): 40 | if i == 0: 41 | prefix = lora.LoRANetwork.LORA_PREFIX_TEXT_ENCODER 42 | target_replace_modules = lora.LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE 43 | else: 44 | prefix = lora.LoRANetwork.LORA_PREFIX_UNET 45 | target_replace_modules = ( 46 | lora.LoRANetwork.UNET_TARGET_REPLACE_MODULE + lora.LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 47 | ) 48 | 49 | for name, module in root_module.named_modules(): 50 | if module.__class__.__name__ in target_replace_modules: 51 | for child_name, child_module in module.named_modules(): 52 | if child_module.__class__.__name__ == "Linear" or child_module.__class__.__name__ == "Conv2d": 53 | lora_name = prefix + "." + name + "." + child_name 54 | lora_name = lora_name.replace(".", "_") 55 | name_to_module[lora_name] = child_module 56 | 57 | for model, ratio in zip(models, ratios): 58 | print(f"loading: {model}") 59 | lora_sd = load_state_dict(model, merge_dtype) 60 | 61 | print(f"merging...") 62 | for key in lora_sd.keys(): 63 | if "lora_down" in key: 64 | up_key = key.replace("lora_down", "lora_up") 65 | alpha_key = key[: key.index("lora_down")] + "alpha" 66 | 67 | # find original module for this lora 68 | module_name = ".".join(key.split(".")[:-2]) # remove trailing ".lora_down.weight" 69 | if module_name not in name_to_module: 70 | print(f"no module found for LoRA weight: {key}") 71 | continue 72 | module = name_to_module[module_name] 73 | # print(f"apply {key} to {module}") 74 | 75 | down_weight = lora_sd[key] 76 | up_weight = lora_sd[up_key] 77 | 78 | dim = down_weight.size()[0] 79 | alpha = lora_sd.get(alpha_key, dim) 80 | scale = alpha / dim 81 | 82 | # W <- W + U * D 83 | weight = module.weight 84 | # print(module_name, down_weight.size(), up_weight.size()) 85 | if len(weight.size()) == 2: 86 | # linear 87 | weight = weight + ratio * (up_weight @ down_weight) * scale 88 | elif down_weight.size()[2:4] == (1, 1): 89 | # conv2d 1x1 90 | weight = ( 91 | weight 92 | + ratio 93 | * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3) 94 | * scale 95 | ) 96 | else: 97 | # conv2d 3x3 98 | conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3) 99 | # print(conved.size(), weight.size(), module.stride, module.padding) 100 | weight = weight + ratio * conved * scale 101 | 102 | module.weight = torch.nn.Parameter(weight) 103 | 104 | 105 | def merge_lora_models(models, ratios, merge_dtype): 106 | base_alphas = {} # alpha for merged model 107 | base_dims = {} 108 | 109 | merged_sd = {} 110 | for model, ratio in zip(models, ratios): 111 | print(f"loading: {model}") 112 | lora_sd = load_state_dict(model, merge_dtype) 113 | 114 | # get alpha and dim 115 | alphas = {} # alpha for current model 116 | dims = {} 
# dims for current model 117 | for key in lora_sd.keys(): 118 | if "alpha" in key: 119 | lora_module_name = key[: key.rfind(".alpha")] 120 | alpha = float(lora_sd[key].detach().numpy()) 121 | alphas[lora_module_name] = alpha 122 | if lora_module_name not in base_alphas: 123 | base_alphas[lora_module_name] = alpha 124 | elif "lora_down" in key: 125 | lora_module_name = key[: key.rfind(".lora_down")] 126 | dim = lora_sd[key].size()[0] 127 | dims[lora_module_name] = dim 128 | if lora_module_name not in base_dims: 129 | base_dims[lora_module_name] = dim 130 | 131 | for lora_module_name in dims.keys(): 132 | if lora_module_name not in alphas: 133 | alpha = dims[lora_module_name] 134 | alphas[lora_module_name] = alpha 135 | if lora_module_name not in base_alphas: 136 | base_alphas[lora_module_name] = alpha 137 | 138 | print(f"dim: {list(set(dims.values()))}, alpha: {list(set(alphas.values()))}") 139 | 140 | # merge 141 | print(f"merging...") 142 | for key in lora_sd.keys(): 143 | if "alpha" in key: 144 | continue 145 | 146 | lora_module_name = key[: key.rfind(".lora_")] 147 | 148 | base_alpha = base_alphas[lora_module_name] 149 | alpha = alphas[lora_module_name] 150 | 151 | scale = math.sqrt(alpha / base_alpha) * ratio 152 | 153 | if key in merged_sd: 154 | assert ( 155 | merged_sd[key].size() == lora_sd[key].size() 156 | ), f"weights shape mismatch merging v1 and v2, different dims? / 重みのサイズが合いません。v1とv2、または次元数の異なるモデルはマージできません" 157 | merged_sd[key] = merged_sd[key] + lora_sd[key] * scale 158 | else: 159 | merged_sd[key] = lora_sd[key] * scale 160 | 161 | # set alpha to sd 162 | for lora_module_name, alpha in base_alphas.items(): 163 | key = lora_module_name + ".alpha" 164 | merged_sd[key] = torch.tensor(alpha) 165 | 166 | print("merged model") 167 | print(f"dim: {list(set(base_dims.values()))}, alpha: {list(set(base_alphas.values()))}") 168 | 169 | return merged_sd 170 | 171 | 172 | def merge(args): 173 | assert len(args.models) == len(args.ratios), f"number of models must be equal to number of ratios / モデルの数と重みの数は合わせてください" 174 | 175 | def str_to_dtype(p): 176 | if p == "float": 177 | return torch.float 178 | if p == "fp16": 179 | return torch.float16 180 | if p == "bf16": 181 | return torch.bfloat16 182 | return None 183 | 184 | merge_dtype = str_to_dtype(args.precision) 185 | save_dtype = str_to_dtype(args.save_precision) 186 | if save_dtype is None: 187 | save_dtype = merge_dtype 188 | 189 | if args.sd_model is not None: 190 | print(f"loading SD model: {args.sd_model}") 191 | 192 | text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.sd_model) 193 | 194 | merge_to_sd_model(text_encoder, unet, args.models, args.ratios, merge_dtype) 195 | 196 | print(f"saving SD model to: {args.save_to}") 197 | model_util.save_stable_diffusion_checkpoint(args.v2, args.save_to, text_encoder, unet, args.sd_model, 0, 0, save_dtype, vae) 198 | else: 199 | state_dict = merge_lora_models(args.models, args.ratios, merge_dtype) 200 | 201 | print(f"saving model to: {args.save_to}") 202 | save_to_file(args.save_to, state_dict, state_dict, save_dtype) 203 | 204 | 205 | def setup_parser() -> argparse.ArgumentParser: 206 | parser = argparse.ArgumentParser() 207 | parser.add_argument("--v2", action="store_true", help="load Stable Diffusion v2.x model / Stable Diffusion 2.xのモデルを読み込む") 208 | parser.add_argument( 209 | "--save_precision", 210 | type=str, 211 | default=None, 212 | choices=[None, "float", "fp16", "bf16"], 213 | help="precision in saving, same to merging if omitted / 
保存時に精度を変更して保存する、省略時はマージ時の精度と同じ", 214 | ) 215 | parser.add_argument( 216 | "--precision", 217 | type=str, 218 | default="float", 219 | choices=["float", "fp16", "bf16"], 220 | help="precision in merging (float is recommended) / マージの計算時の精度(floatを推奨)", 221 | ) 222 | parser.add_argument( 223 | "--sd_model", 224 | type=str, 225 | default=None, 226 | help="Stable Diffusion model to load: ckpt or safetensors file, merge LoRA models if omitted / 読み込むモデル、ckptまたはsafetensors。省略時はLoRAモデル同士をマージする", 227 | ) 228 | parser.add_argument( 229 | "--save_to", type=str, default=None, help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors" 230 | ) 231 | parser.add_argument( 232 | "--models", type=str, nargs="*", help="LoRA models to merge: ckpt or safetensors file / マージするLoRAモデル、ckptまたはsafetensors" 233 | ) 234 | parser.add_argument("--ratios", type=float, nargs="*", help="ratios for each model / それぞれのLoRAモデルの比率") 235 | 236 | return parser 237 | 238 | 239 | if __name__ == "__main__": 240 | parser = setup_parser() 241 | 242 | args = parser.parse_args() 243 | merge(args) 244 | -------------------------------------------------------------------------------- /networks/merge_lora_old.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | import argparse 4 | import os 5 | import torch 6 | from safetensors.torch import load_file, save_file 7 | import library.model_util as model_util 8 | import lora 9 | 10 | 11 | def load_state_dict(file_name, dtype): 12 | if os.path.splitext(file_name)[1] == '.safetensors': 13 | sd = load_file(file_name) 14 | else: 15 | sd = torch.load(file_name, map_location='cpu') 16 | for key in list(sd.keys()): 17 | if type(sd[key]) == torch.Tensor: 18 | sd[key] = sd[key].to(dtype) 19 | return sd 20 | 21 | 22 | def save_to_file(file_name, model, state_dict, dtype): 23 | if dtype is not None: 24 | for key in list(state_dict.keys()): 25 | if type(state_dict[key]) == torch.Tensor: 26 | state_dict[key] = state_dict[key].to(dtype) 27 | 28 | if os.path.splitext(file_name)[1] == '.safetensors': 29 | save_file(model, file_name) 30 | else: 31 | torch.save(model, file_name) 32 | 33 | 34 | def merge_to_sd_model(text_encoder, unet, models, ratios, merge_dtype): 35 | text_encoder.to(merge_dtype) 36 | unet.to(merge_dtype) 37 | 38 | # create module map 39 | name_to_module = {} 40 | for i, root_module in enumerate([text_encoder, unet]): 41 | if i == 0: 42 | prefix = lora.LoRANetwork.LORA_PREFIX_TEXT_ENCODER 43 | target_replace_modules = lora.LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE 44 | else: 45 | prefix = lora.LoRANetwork.LORA_PREFIX_UNET 46 | target_replace_modules = lora.LoRANetwork.UNET_TARGET_REPLACE_MODULE 47 | 48 | for name, module in root_module.named_modules(): 49 | if module.__class__.__name__ in target_replace_modules: 50 | for child_name, child_module in module.named_modules(): 51 | if child_module.__class__.__name__ == "Linear" or (child_module.__class__.__name__ == "Conv2d" and child_module.kernel_size == (1, 1)): 52 | lora_name = prefix + '.' + name + '.' 
+ child_name 53 | lora_name = lora_name.replace('.', '_') 54 | name_to_module[lora_name] = child_module 55 | 56 | for model, ratio in zip(models, ratios): 57 | print(f"loading: {model}") 58 | lora_sd = load_state_dict(model, merge_dtype) 59 | 60 | print(f"merging...") 61 | for key in lora_sd.keys(): 62 | if "lora_down" in key: 63 | up_key = key.replace("lora_down", "lora_up") 64 | alpha_key = key[:key.index("lora_down")] + 'alpha' 65 | 66 | # find original module for this lora 67 | module_name = '.'.join(key.split('.')[:-2]) # remove trailing ".lora_down.weight" 68 | if module_name not in name_to_module: 69 | print(f"no module found for LoRA weight: {key}") 70 | continue 71 | module = name_to_module[module_name] 72 | # print(f"apply {key} to {module}") 73 | 74 | down_weight = lora_sd[key] 75 | up_weight = lora_sd[up_key] 76 | 77 | dim = down_weight.size()[0] 78 | alpha = lora_sd.get(alpha_key, dim) 79 | scale = alpha / dim 80 | 81 | # W <- W + U * D 82 | weight = module.weight 83 | if len(weight.size()) == 2: 84 | # linear 85 | weight = weight + ratio * (up_weight @ down_weight) * scale 86 | else: 87 | # conv2d 88 | weight = weight + ratio * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3) * scale 89 | 90 | module.weight = torch.nn.Parameter(weight) 91 | 92 | 93 | def merge_lora_models(models, ratios, merge_dtype): 94 | merged_sd = {} 95 | 96 | alpha = None 97 | dim = None 98 | for model, ratio in zip(models, ratios): 99 | print(f"loading: {model}") 100 | lora_sd = load_state_dict(model, merge_dtype) 101 | 102 | print(f"merging...") 103 | for key in lora_sd.keys(): 104 | if 'alpha' in key: 105 | if key in merged_sd: 106 | assert merged_sd[key] == lora_sd[key], f"alpha mismatch / alphaが異なる場合、現時点ではマージできません" 107 | else: 108 | alpha = lora_sd[key].detach().numpy() 109 | merged_sd[key] = lora_sd[key] 110 | else: 111 | if key in merged_sd: 112 | assert merged_sd[key].size() == lora_sd[key].size( 113 | ), f"weights shape mismatch merging v1 and v2, different dims? 
/ 重みのサイズが合いません。v1とv2、または次元数の異なるモデルはマージできません" 114 | merged_sd[key] = merged_sd[key] + lora_sd[key] * ratio 115 | else: 116 | if "lora_down" in key: 117 | dim = lora_sd[key].size()[0] 118 | merged_sd[key] = lora_sd[key] * ratio 119 | 120 | print(f"dim (rank): {dim}, alpha: {alpha}") 121 | if alpha is None: 122 | alpha = dim 123 | 124 | return merged_sd, dim, alpha 125 | 126 | 127 | def merge(args): 128 | assert len(args.models) == len(args.ratios), f"number of models must be equal to number of ratios / モデルの数と重みの数は合わせてください" 129 | 130 | def str_to_dtype(p): 131 | if p == 'float': 132 | return torch.float 133 | if p == 'fp16': 134 | return torch.float16 135 | if p == 'bf16': 136 | return torch.bfloat16 137 | return None 138 | 139 | merge_dtype = str_to_dtype(args.precision) 140 | save_dtype = str_to_dtype(args.save_precision) 141 | if save_dtype is None: 142 | save_dtype = merge_dtype 143 | 144 | if args.sd_model is not None: 145 | print(f"loading SD model: {args.sd_model}") 146 | 147 | text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.sd_model) 148 | 149 | merge_to_sd_model(text_encoder, unet, args.models, args.ratios, merge_dtype) 150 | 151 | print(f"saving SD model to: {args.save_to}") 152 | model_util.save_stable_diffusion_checkpoint(args.v2, args.save_to, text_encoder, unet, 153 | args.sd_model, 0, 0, save_dtype, vae) 154 | else: 155 | state_dict, _, _ = merge_lora_models(args.models, args.ratios, merge_dtype) 156 | 157 | print(f"saving model to: {args.save_to}") 158 | save_to_file(args.save_to, state_dict, state_dict, save_dtype) 159 | 160 | 161 | def setup_parser() -> argparse.ArgumentParser: 162 | parser = argparse.ArgumentParser() 163 | parser.add_argument("--v2", action='store_true', 164 | help='load Stable Diffusion v2.x model / Stable Diffusion 2.xのモデルを読み込む') 165 | parser.add_argument("--save_precision", type=str, default=None, 166 | choices=[None, "float", "fp16", "bf16"], help="precision in saving, same to merging if omitted / 保存時に精度を変更して保存する、省略時はマージ時の精度と同じ") 167 | parser.add_argument("--precision", type=str, default="float", 168 | choices=["float", "fp16", "bf16"], help="precision in merging (float is recommended) / マージの計算時の精度(floatを推奨)") 169 | parser.add_argument("--sd_model", type=str, default=None, 170 | help="Stable Diffusion model to load: ckpt or safetensors file, merge LoRA models if omitted / 読み込むモデル、ckptまたはsafetensors。省略時はLoRAモデル同士をマージする") 171 | parser.add_argument("--save_to", type=str, default=None, 172 | help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors") 173 | parser.add_argument("--models", type=str, nargs='*', 174 | help="LoRA models to merge: ckpt or safetensors file / マージするLoRAモデル、ckptまたはsafetensors") 175 | parser.add_argument("--ratios", type=float, nargs='*', 176 | help="ratios for each model / それぞれのLoRAモデルの比率") 177 | 178 | return parser 179 | 180 | 181 | if __name__ == '__main__': 182 | parser = setup_parser() 183 | 184 | args = parser.parse_args() 185 | merge(args) 186 | -------------------------------------------------------------------------------- /networks/resize_lora.py: -------------------------------------------------------------------------------- 1 | # Convert LoRA to different rank approximation (should only be used to go to lower rank) 2 | # This code is based off the extract_lora_from_models.py file which is based on https://github.com/cloneofsimo/lora/blob/develop/lora_diffusion/cli_svd.py 3 | # Thanks to cloneofsimo 4 | 5 | import argparse 6 | import torch 7 | from 
safetensors.torch import load_file, save_file, safe_open 8 | from tqdm import tqdm 9 | from library import train_util, model_util 10 | import numpy as np 11 | 12 | MIN_SV = 1e-6 13 | 14 | # Model save and load functions 15 | 16 | def load_state_dict(file_name, dtype): 17 | if model_util.is_safetensors(file_name): 18 | sd = load_file(file_name) 19 | with safe_open(file_name, framework="pt") as f: 20 | metadata = f.metadata() 21 | else: 22 | sd = torch.load(file_name, map_location='cpu') 23 | metadata = None 24 | 25 | for key in list(sd.keys()): 26 | if type(sd[key]) == torch.Tensor: 27 | sd[key] = sd[key].to(dtype) 28 | 29 | return sd, metadata 30 | 31 | 32 | def save_to_file(file_name, model, state_dict, dtype, metadata): 33 | if dtype is not None: 34 | for key in list(state_dict.keys()): 35 | if type(state_dict[key]) == torch.Tensor: 36 | state_dict[key] = state_dict[key].to(dtype) 37 | 38 | if model_util.is_safetensors(file_name): 39 | save_file(model, file_name, metadata) 40 | else: 41 | torch.save(model, file_name) 42 | 43 | 44 | # Indexing functions 45 | 46 | def index_sv_cumulative(S, target): 47 | original_sum = float(torch.sum(S)) 48 | cumulative_sums = torch.cumsum(S, dim=0)/original_sum 49 | index = int(torch.searchsorted(cumulative_sums, target)) + 1 50 | index = max(1, min(index, len(S)-1)) 51 | 52 | return index 53 | 54 | 55 | def index_sv_fro(S, target): 56 | S_squared = S.pow(2) 57 | s_fro_sq = float(torch.sum(S_squared)) 58 | sum_S_squared = torch.cumsum(S_squared, dim=0)/s_fro_sq 59 | index = int(torch.searchsorted(sum_S_squared, target**2)) + 1 60 | index = max(1, min(index, len(S)-1)) 61 | 62 | return index 63 | 64 | 65 | def index_sv_ratio(S, target): 66 | max_sv = S[0] 67 | min_sv = max_sv/target 68 | index = int(torch.sum(S > min_sv).item()) 69 | index = max(1, min(index, len(S)-1)) 70 | 71 | return index 72 | 73 | 74 | # Modified from Kohaku-blueleaf's extract/merge functions 75 | def extract_conv(weight, lora_rank, dynamic_method, dynamic_param, device, scale=1): 76 | out_size, in_size, kernel_size, _ = weight.size() 77 | U, S, Vh = torch.linalg.svd(weight.reshape(out_size, -1).to(device)) 78 | 79 | param_dict = rank_resize(S, lora_rank, dynamic_method, dynamic_param, scale) 80 | lora_rank = param_dict["new_rank"] 81 | 82 | U = U[:, :lora_rank] 83 | S = S[:lora_rank] 84 | U = U @ torch.diag(S) 85 | Vh = Vh[:lora_rank, :] 86 | 87 | param_dict["lora_down"] = Vh.reshape(lora_rank, in_size, kernel_size, kernel_size).cpu() 88 | param_dict["lora_up"] = U.reshape(out_size, lora_rank, 1, 1).cpu() 89 | del U, S, Vh, weight 90 | return param_dict 91 | 92 | 93 | def extract_linear(weight, lora_rank, dynamic_method, dynamic_param, device, scale=1): 94 | out_size, in_size = weight.size() 95 | 96 | U, S, Vh = torch.linalg.svd(weight.to(device)) 97 | 98 | param_dict = rank_resize(S, lora_rank, dynamic_method, dynamic_param, scale) 99 | lora_rank = param_dict["new_rank"] 100 | 101 | U = U[:, :lora_rank] 102 | S = S[:lora_rank] 103 | U = U @ torch.diag(S) 104 | Vh = Vh[:lora_rank, :] 105 | 106 | param_dict["lora_down"] = Vh.reshape(lora_rank, in_size).cpu() 107 | param_dict["lora_up"] = U.reshape(out_size, lora_rank).cpu() 108 | del U, S, Vh, weight 109 | return param_dict 110 | 111 | 112 | def merge_conv(lora_down, lora_up, device): 113 | in_rank, in_size, kernel_size, k_ = lora_down.shape 114 | out_size, out_rank, _, _ = lora_up.shape 115 | assert in_rank == out_rank and kernel_size == k_, f"rank {in_rank} {out_rank} or kernel {kernel_size} {k_} mismatch" 116 | 117 | lora_down = 
lora_down.to(device) 118 | lora_up = lora_up.to(device) 119 | 120 | merged = lora_up.reshape(out_size, -1) @ lora_down.reshape(in_rank, -1) 121 | weight = merged.reshape(out_size, in_size, kernel_size, kernel_size) 122 | del lora_up, lora_down 123 | return weight 124 | 125 | 126 | def merge_linear(lora_down, lora_up, device): 127 | in_rank, in_size = lora_down.shape 128 | out_size, out_rank = lora_up.shape 129 | assert in_rank == out_rank, f"rank {in_rank} {out_rank} mismatch" 130 | 131 | lora_down = lora_down.to(device) 132 | lora_up = lora_up.to(device) 133 | 134 | weight = lora_up @ lora_down 135 | del lora_up, lora_down 136 | return weight 137 | 138 | 139 | # Calculate new rank 140 | 141 | def rank_resize(S, rank, dynamic_method, dynamic_param, scale=1): 142 | param_dict = {} 143 | 144 | if dynamic_method=="sv_ratio": 145 | # Calculate new dim and alpha based off ratio 146 | new_rank = index_sv_ratio(S, dynamic_param) + 1 147 | new_alpha = float(scale*new_rank) 148 | 149 | elif dynamic_method=="sv_cumulative": 150 | # Calculate new dim and alpha based off cumulative sum 151 | new_rank = index_sv_cumulative(S, dynamic_param) + 1 152 | new_alpha = float(scale*new_rank) 153 | 154 | elif dynamic_method=="sv_fro": 155 | # Calculate new dim and alpha based off sqrt sum of squares 156 | new_rank = index_sv_fro(S, dynamic_param) + 1 157 | new_alpha = float(scale*new_rank) 158 | else: 159 | new_rank = rank 160 | new_alpha = float(scale*new_rank) 161 | 162 | 163 | if S[0] <= MIN_SV: # Zero matrix, set dim to 1 164 | new_rank = 1 165 | new_alpha = float(scale*new_rank) 166 | elif new_rank > rank: # cap max rank at rank 167 | new_rank = rank 168 | new_alpha = float(scale*new_rank) 169 | 170 | 171 | # Calculate resize info 172 | s_sum = torch.sum(torch.abs(S)) 173 | s_rank = torch.sum(torch.abs(S[:new_rank])) 174 | 175 | S_squared = S.pow(2) 176 | s_fro = torch.sqrt(torch.sum(S_squared)) 177 | s_red_fro = torch.sqrt(torch.sum(S_squared[:new_rank])) 178 | fro_percent = float(s_red_fro/s_fro) 179 | 180 | param_dict["new_rank"] = new_rank 181 | param_dict["new_alpha"] = new_alpha 182 | param_dict["sum_retained"] = (s_rank)/s_sum 183 | param_dict["fro_retained"] = fro_percent 184 | param_dict["max_ratio"] = S[0]/S[new_rank - 1] 185 | 186 | return param_dict 187 | 188 | 189 | def resize_lora_model(lora_sd, new_rank, save_dtype, device, dynamic_method, dynamic_param, verbose): 190 | network_alpha = None 191 | network_dim = None 192 | verbose_str = "\n" 193 | fro_list = [] 194 | 195 | # Extract loaded lora dim and alpha 196 | for key, value in lora_sd.items(): 197 | if network_alpha is None and 'alpha' in key: 198 | network_alpha = value 199 | if network_dim is None and 'lora_down' in key and len(value.size()) == 2: 200 | network_dim = value.size()[0] 201 | if network_alpha is not None and network_dim is not None: 202 | break 203 | if network_alpha is None: 204 | network_alpha = network_dim 205 | 206 | scale = network_alpha/network_dim 207 | 208 | if dynamic_method: 209 | print(f"Dynamically determining new alphas and dims based off {dynamic_method}: {dynamic_param}, max rank is {new_rank}") 210 | 211 | lora_down_weight = None 212 | lora_up_weight = None 213 | 214 | o_lora_sd = lora_sd.copy() 215 | block_down_name = None 216 | block_up_name = None 217 | 218 | with torch.no_grad(): 219 | for key, value in tqdm(lora_sd.items()): 220 | weight_name = None 221 | if 'lora_down' in key: 222 | block_down_name = key.split(".")[0] 223 | weight_name = key.split(".")[-1] 224 | lora_down_weight = value 225 | else: 226 
| continue 227 | 228 | # find corresponding lora_up and alpha 229 | block_up_name = block_down_name 230 | lora_up_weight = lora_sd.get(block_up_name + '.lora_up.' + weight_name, None) 231 | lora_alpha = lora_sd.get(block_down_name + '.alpha', None) 232 | 233 | weights_loaded = (lora_down_weight is not None and lora_up_weight is not None) 234 | 235 | if weights_loaded: 236 | 237 | conv2d = (len(lora_down_weight.size()) == 4) 238 | if lora_alpha is None: 239 | scale = 1.0 240 | else: 241 | scale = lora_alpha/lora_down_weight.size()[0] 242 | 243 | if conv2d: 244 | full_weight_matrix = merge_conv(lora_down_weight, lora_up_weight, device) 245 | param_dict = extract_conv(full_weight_matrix, new_rank, dynamic_method, dynamic_param, device, scale) 246 | else: 247 | full_weight_matrix = merge_linear(lora_down_weight, lora_up_weight, device) 248 | param_dict = extract_linear(full_weight_matrix, new_rank, dynamic_method, dynamic_param, device, scale) 249 | 250 | if verbose: 251 | max_ratio = param_dict['max_ratio'] 252 | sum_retained = param_dict['sum_retained'] 253 | fro_retained = param_dict['fro_retained'] 254 | if not np.isnan(fro_retained): 255 | fro_list.append(float(fro_retained)) 256 | 257 | verbose_str+=f"{block_down_name:75} | " 258 | verbose_str+=f"sum(S) retained: {sum_retained:.1%}, fro retained: {fro_retained:.1%}, max(S) ratio: {max_ratio:0.1f}" 259 | 260 | if verbose and dynamic_method: 261 | verbose_str+=f", dynamic | dim: {param_dict['new_rank']}, alpha: {param_dict['new_alpha']}\n" 262 | else: 263 | verbose_str+=f"\n" 264 | 265 | new_alpha = param_dict['new_alpha'] 266 | o_lora_sd[block_down_name + "." + "lora_down.weight"] = param_dict["lora_down"].to(save_dtype).contiguous() 267 | o_lora_sd[block_up_name + "." + "lora_up.weight"] = param_dict["lora_up"].to(save_dtype).contiguous() 268 | o_lora_sd[block_up_name + "." 
"alpha"] = torch.tensor(param_dict['new_alpha']).to(save_dtype) 269 | 270 | block_down_name = None 271 | block_up_name = None 272 | lora_down_weight = None 273 | lora_up_weight = None 274 | weights_loaded = False 275 | del param_dict 276 | 277 | if verbose: 278 | print(verbose_str) 279 | 280 | print(f"Average Frobenius norm retention: {np.mean(fro_list):.2%} | std: {np.std(fro_list):0.3f}") 281 | print("resizing complete") 282 | return o_lora_sd, network_dim, new_alpha 283 | 284 | 285 | def resize(args): 286 | 287 | def str_to_dtype(p): 288 | if p == 'float': 289 | return torch.float 290 | if p == 'fp16': 291 | return torch.float16 292 | if p == 'bf16': 293 | return torch.bfloat16 294 | return None 295 | 296 | if args.dynamic_method and not args.dynamic_param: 297 | raise Exception("If using dynamic_method, then dynamic_param is required") 298 | 299 | merge_dtype = str_to_dtype('float') # matmul method above only seems to work in float32 300 | save_dtype = str_to_dtype(args.save_precision) 301 | if save_dtype is None: 302 | save_dtype = merge_dtype 303 | 304 | print("loading Model...") 305 | lora_sd, metadata = load_state_dict(args.model, merge_dtype) 306 | 307 | print("Resizing Lora...") 308 | state_dict, old_dim, new_alpha = resize_lora_model(lora_sd, args.new_rank, save_dtype, args.device, args.dynamic_method, args.dynamic_param, args.verbose) 309 | 310 | # update metadata 311 | if metadata is None: 312 | metadata = {} 313 | 314 | comment = metadata.get("ss_training_comment", "") 315 | 316 | if not args.dynamic_method: 317 | metadata["ss_training_comment"] = f"dimension is resized from {old_dim} to {args.new_rank}; {comment}" 318 | metadata["ss_network_dim"] = str(args.new_rank) 319 | metadata["ss_network_alpha"] = str(new_alpha) 320 | else: 321 | metadata["ss_training_comment"] = f"Dynamic resize with {args.dynamic_method}: {args.dynamic_param} from {old_dim}; {comment}" 322 | metadata["ss_network_dim"] = 'Dynamic' 323 | metadata["ss_network_alpha"] = 'Dynamic' 324 | 325 | model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata) 326 | metadata["sshs_model_hash"] = model_hash 327 | metadata["sshs_legacy_hash"] = legacy_hash 328 | 329 | print(f"saving model to: {args.save_to}") 330 | save_to_file(args.save_to, state_dict, state_dict, save_dtype, metadata) 331 | 332 | 333 | def setup_parser() -> argparse.ArgumentParser: 334 | parser = argparse.ArgumentParser() 335 | 336 | parser.add_argument("--save_precision", type=str, default=None, 337 | choices=[None, "float", "fp16", "bf16"], help="precision in saving, float if omitted / 保存時の精度、未指定時はfloat") 338 | parser.add_argument("--new_rank", type=int, default=4, 339 | help="Specify rank of output LoRA / 出力するLoRAのrank (dim)") 340 | parser.add_argument("--save_to", type=str, default=None, 341 | help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors") 342 | parser.add_argument("--model", type=str, default=None, 343 | help="LoRA model to resize at to new rank: ckpt or safetensors file / 読み込むLoRAモデル、ckptまたはsafetensors") 344 | parser.add_argument("--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う") 345 | parser.add_argument("--verbose", action="store_true", 346 | help="Display verbose resizing information / rank変更時の詳細情報を出力する") 347 | parser.add_argument("--dynamic_method", type=str, default=None, choices=[None, "sv_ratio", "sv_fro", "sv_cumulative"], 348 | help="Specify dynamic resizing method, --new_rank is used as a hard limit for max rank") 
349 | parser.add_argument("--dynamic_param", type=float, default=None, 350 | help="Specify target for dynamic reduction") 351 | 352 | return parser 353 | 354 | 355 | if __name__ == '__main__': 356 | parser = setup_parser() 357 | 358 | args = parser.parse_args() 359 | resize(args) 360 | -------------------------------------------------------------------------------- /networks/svd_merge_lora.py: -------------------------------------------------------------------------------- 1 | 2 | import math 3 | import argparse 4 | import os 5 | import torch 6 | from safetensors.torch import load_file, save_file 7 | from tqdm import tqdm 8 | import library.model_util as model_util 9 | import lora 10 | 11 | 12 | CLAMP_QUANTILE = 0.99 13 | 14 | 15 | def load_state_dict(file_name, dtype): 16 | if os.path.splitext(file_name)[1] == '.safetensors': 17 | sd = load_file(file_name) 18 | else: 19 | sd = torch.load(file_name, map_location='cpu') 20 | for key in list(sd.keys()): 21 | if type(sd[key]) == torch.Tensor: 22 | sd[key] = sd[key].to(dtype) 23 | return sd 24 | 25 | 26 | def save_to_file(file_name, state_dict, dtype): 27 | if dtype is not None: 28 | for key in list(state_dict.keys()): 29 | if type(state_dict[key]) == torch.Tensor: 30 | state_dict[key] = state_dict[key].to(dtype) 31 | 32 | if os.path.splitext(file_name)[1] == '.safetensors': 33 | save_file(state_dict, file_name) 34 | else: 35 | torch.save(state_dict, file_name) 36 | 37 | 38 | def merge_lora_models(models, ratios, new_rank, new_conv_rank, device, merge_dtype): 39 | print(f"new rank: {new_rank}, new conv rank: {new_conv_rank}") 40 | merged_sd = {} 41 | for model, ratio in zip(models, ratios): 42 | print(f"loading: {model}") 43 | lora_sd = load_state_dict(model, merge_dtype) 44 | 45 | # merge 46 | print(f"merging...") 47 | for key in tqdm(list(lora_sd.keys())): 48 | if 'lora_down' not in key: 49 | continue 50 | 51 | lora_module_name = key[:key.rfind(".lora_down")] 52 | 53 | down_weight = lora_sd[key] 54 | network_dim = down_weight.size()[0] 55 | 56 | up_weight = lora_sd[lora_module_name + '.lora_up.weight'] 57 | alpha = lora_sd.get(lora_module_name + '.alpha', network_dim) 58 | 59 | in_dim = down_weight.size()[1] 60 | out_dim = up_weight.size()[0] 61 | conv2d = len(down_weight.size()) == 4 62 | kernel_size = None if not conv2d else down_weight.size()[2:4] 63 | # print(lora_module_name, network_dim, alpha, in_dim, out_dim, kernel_size) 64 | 65 | # make original weight if not exist 66 | if lora_module_name not in merged_sd: 67 | weight = torch.zeros((out_dim, in_dim, *kernel_size) if conv2d else (out_dim, in_dim), dtype=merge_dtype) 68 | if device: 69 | weight = weight.to(device) 70 | else: 71 | weight = merged_sd[lora_module_name] 72 | 73 | # merge to weight 74 | if device: 75 | up_weight = up_weight.to(device) 76 | down_weight = down_weight.to(device) 77 | 78 | # W <- W + U * D 79 | scale = (alpha / network_dim) 80 | 81 | if device: # and isinstance(scale, torch.Tensor): 82 | scale = scale.to(device) 83 | 84 | if not conv2d: # linear 85 | weight = weight + ratio * (up_weight @ down_weight) * scale 86 | elif kernel_size == (1, 1): 87 | weight = weight + ratio * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2) 88 | ).unsqueeze(2).unsqueeze(3) * scale 89 | else: 90 | conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3) 91 | weight = weight + ratio * conved * scale 92 | 93 | merged_sd[lora_module_name] = weight 94 | 95 | # extract from merged weights 96 | print("extract new lora...") 
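# A minimal sketch of why the step below re-runs SVD: the sum of two rank-r LoRA deltas is in
# general rank 2r, so after accumulating the full-size merged weights the script factors each
# module back down to --new_rank. Shapes, ranks and ratios here are illustrative values only.
import torch

if __name__ == "__main__":
    torch.manual_seed(0)
    delta_a = torch.randn(64, 4) @ torch.randn(4, 128)   # rank-4 delta from LoRA A
    delta_b = torch.randn(64, 4) @ torch.randn(4, 128)   # rank-4 delta from LoRA B
    merged = 0.5 * delta_a + 0.5 * delta_b
    print("rank of the merged delta:", torch.linalg.matrix_rank(merged).item())  # typically 8

    U, S, Vh = torch.linalg.svd(merged)
    up = U[:, :4] @ torch.diag(S[:4])                    # best rank-4 approximation
    down = Vh[:4, :]
    rel_err = torch.norm(merged - up @ down) / torch.norm(merged)
    print(f"relative error at rank 4: {rel_err:.3f}")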
97 | merged_lora_sd = {} 98 | with torch.no_grad(): 99 | for lora_module_name, mat in tqdm(list(merged_sd.items())): 100 | conv2d = (len(mat.size()) == 4) 101 | kernel_size = None if not conv2d else mat.size()[2:4] 102 | conv2d_3x3 = conv2d and kernel_size != (1, 1) 103 | out_dim, in_dim = mat.size()[0:2] 104 | 105 | if conv2d: 106 | if conv2d_3x3: 107 | mat = mat.flatten(start_dim=1) 108 | else: 109 | mat = mat.squeeze() 110 | 111 | module_new_rank = new_conv_rank if conv2d_3x3 else new_rank 112 | module_new_rank = min(module_new_rank, in_dim, out_dim) # LoRA rank cannot exceed the original dim 113 | 114 | U, S, Vh = torch.linalg.svd(mat) 115 | 116 | U = U[:, :module_new_rank] 117 | S = S[:module_new_rank] 118 | U = U @ torch.diag(S) 119 | 120 | Vh = Vh[:module_new_rank, :] 121 | 122 | dist = torch.cat([U.flatten(), Vh.flatten()]) 123 | hi_val = torch.quantile(dist, CLAMP_QUANTILE) 124 | low_val = -hi_val 125 | 126 | U = U.clamp(low_val, hi_val) 127 | Vh = Vh.clamp(low_val, hi_val) 128 | 129 | if conv2d: 130 | U = U.reshape(out_dim, module_new_rank, 1, 1) 131 | Vh = Vh.reshape(module_new_rank, in_dim, kernel_size[0], kernel_size[1]) 132 | 133 | up_weight = U 134 | down_weight = Vh 135 | 136 | merged_lora_sd[lora_module_name + '.lora_up.weight'] = up_weight.to("cpu").contiguous() 137 | merged_lora_sd[lora_module_name + '.lora_down.weight'] = down_weight.to("cpu").contiguous() 138 | merged_lora_sd[lora_module_name + '.alpha'] = torch.tensor(module_new_rank) 139 | 140 | return merged_lora_sd 141 | 142 | 143 | def merge(args): 144 | assert len(args.models) == len(args.ratios), f"number of models must be equal to number of ratios / モデルの数と重みの数は合わせてください" 145 | 146 | def str_to_dtype(p): 147 | if p == 'float': 148 | return torch.float 149 | if p == 'fp16': 150 | return torch.float16 151 | if p == 'bf16': 152 | return torch.bfloat16 153 | return None 154 | 155 | merge_dtype = str_to_dtype(args.precision) 156 | save_dtype = str_to_dtype(args.save_precision) 157 | if save_dtype is None: 158 | save_dtype = merge_dtype 159 | 160 | new_conv_rank = args.new_conv_rank if args.new_conv_rank is not None else args.new_rank 161 | state_dict = merge_lora_models(args.models, args.ratios, args.new_rank, new_conv_rank, args.device, merge_dtype) 162 | 163 | print(f"saving model to: {args.save_to}") 164 | save_to_file(args.save_to, state_dict, save_dtype) 165 | 166 | 167 | def setup_parser() -> argparse.ArgumentParser: 168 | parser = argparse.ArgumentParser() 169 | parser.add_argument("--save_precision", type=str, default=None, 170 | choices=[None, "float", "fp16", "bf16"], help="precision in saving, same to merging if omitted / 保存時に精度を変更して保存する、省略時はマージ時の精度と同じ") 171 | parser.add_argument("--precision", type=str, default="float", 172 | choices=["float", "fp16", "bf16"], help="precision in merging (float is recommended) / マージの計算時の精度(floatを推奨)") 173 | parser.add_argument("--save_to", type=str, default=None, 174 | help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors") 175 | parser.add_argument("--models", type=str, nargs='*', 176 | help="LoRA models to merge: ckpt or safetensors file / マージするLoRAモデル、ckptまたはsafetensors") 177 | parser.add_argument("--ratios", type=float, nargs='*', 178 | help="ratios for each model / それぞれのLoRAモデルの比率") 179 | parser.add_argument("--new_rank", type=int, default=4, 180 | help="Specify rank of output LoRA / 出力するLoRAのrank (dim)") 181 | parser.add_argument("--new_conv_rank", type=int, default=None, 182 | help="Specify rank of output LoRA for Conv2d 3x3, None for 
same as new_rank / 出力するConv2D 3x3 LoRAのrank (dim)、Noneでnew_rankと同じ") 183 | parser.add_argument("--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う") 184 | 185 | return parser 186 | 187 | 188 | if __name__ == '__main__': 189 | parser = setup_parser() 190 | 191 | args = parser.parse_args() 192 | merge(args) 193 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | accelerate==0.15.0 2 | transformers==4.26.0 3 | ftfy==6.1.1 4 | albumentations==1.3.1 5 | opencv-python==4.8.0.76 6 | einops==0.6.0 7 | diffusers==0.10.2 8 | pytorch-lightning==1.9.0 9 | bitsandbytes==0.41.3.post2 10 | tensorflow 11 | safetensors==0.4.1 12 | # gradio==3.16.2 13 | # altair==4.2.2 14 | # easygui==0.98.3 15 | toml==0.10.2 16 | voluptuous==0.13.1 17 | # for BLIP captioning 18 | requests==2.31.0 19 | timm==0.6.12 20 | fairscale==0.4.13 21 | # for WD14 captioning 22 | huggingface-hub==0.13.4 23 | # for kohya trainer colab 24 | gallery-dl==1.25.1 25 | gdown==4.7.1 26 | imjoy-elfinder==0.1.61 27 | dadaptation==1.5 28 | lion-pytorch==0.0.6 29 | # for network module 30 | # locon==0.0.4 31 | lycoris-lora==0.1.4 32 | xformers==0.0.22.post7 33 | # for kohya_ss library 34 | . 35 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup(name = "library", packages = find_packages()) -------------------------------------------------------------------------------- /tools/canny.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import cv2 3 | 4 | 5 | def canny(args): 6 | img = cv2.imread(args.input) 7 | img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) 8 | 9 | canny_img = cv2.Canny(img, args.thres1, args.thres2) 10 | # canny_img = 255 - canny_img 11 | 12 | cv2.imwrite(args.output, canny_img) 13 | print("done!") 14 | 15 | 16 | def setup_parser() -> argparse.ArgumentParser: 17 | parser = argparse.ArgumentParser() 18 | parser.add_argument("--input", type=str, default=None, help="input path") 19 | parser.add_argument("--output", type=str, default=None, help="output path") 20 | parser.add_argument("--thres1", type=int, default=32, help="thres1") 21 | parser.add_argument("--thres2", type=int, default=224, help="thres2") 22 | 23 | return parser 24 | 25 | 26 | if __name__ == '__main__': 27 | parser = setup_parser() 28 | 29 | args = parser.parse_args() 30 | canny(args) 31 | -------------------------------------------------------------------------------- /tools/convert_diffusers20_original_sd.py: -------------------------------------------------------------------------------- 1 | # convert Diffusers v1.x/v2.0 model to original Stable Diffusion 2 | 3 | import argparse 4 | import os 5 | import torch 6 | from diffusers import StableDiffusionPipeline 7 | 8 | import library.model_util as model_util 9 | 10 | 11 | def convert(args): 12 | # 引数を確認する 13 | load_dtype = torch.float16 if args.fp16 else None 14 | 15 | save_dtype = None 16 | if args.fp16 or args.save_precision_as == "fp16": 17 | save_dtype = torch.float16 18 | elif args.bf16 or args.save_precision_as == "bf16": 19 | save_dtype = torch.bfloat16 20 | elif args.float or args.save_precision_as == "float": 21 | save_dtype = torch.float 22 | 23 | is_load_ckpt = os.path.isfile(args.model_to_load) 24 
| is_save_ckpt = len(os.path.splitext(args.model_to_save)[1]) > 0 25 | 26 | assert not is_load_ckpt or args.v1 != args.v2, f"v1 or v2 is required to load checkpoint / checkpointの読み込みにはv1/v2指定が必要です" 27 | assert is_save_ckpt or args.reference_model is not None, f"reference model is required to save as Diffusers / Diffusers形式での保存には参照モデルが必要です" 28 | 29 | # モデルを読み込む 30 | msg = "checkpoint" if is_load_ckpt else ("Diffusers" + (" as fp16" if args.fp16 else "")) 31 | print(f"loading {msg}: {args.model_to_load}") 32 | 33 | if is_load_ckpt: 34 | v2_model = args.v2 35 | text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(v2_model, args.model_to_load) 36 | else: 37 | pipe = StableDiffusionPipeline.from_pretrained(args.model_to_load, torch_dtype=load_dtype, tokenizer=None, safety_checker=None) 38 | text_encoder = pipe.text_encoder 39 | vae = pipe.vae 40 | unet = pipe.unet 41 | 42 | if args.v1 == args.v2: 43 | # 自動判定する 44 | v2_model = unet.config.cross_attention_dim == 1024 45 | print("checking model version: model is " + ('v2' if v2_model else 'v1')) 46 | else: 47 | v2_model = not args.v1 48 | 49 | # 変換して保存する 50 | msg = ("checkpoint" + ("" if save_dtype is None else f" in {save_dtype}")) if is_save_ckpt else "Diffusers" 51 | print(f"converting and saving as {msg}: {args.model_to_save}") 52 | 53 | if is_save_ckpt: 54 | original_model = args.model_to_load if is_load_ckpt else None 55 | key_count = model_util.save_stable_diffusion_checkpoint(v2_model, args.model_to_save, text_encoder, unet, 56 | original_model, args.epoch, args.global_step, save_dtype, vae) 57 | print(f"model saved. total converted state_dict keys: {key_count}") 58 | else: 59 | print(f"copy scheduler/tokenizer config from: {args.reference_model}") 60 | model_util.save_diffusers_checkpoint(v2_model, args.model_to_save, text_encoder, unet, args.reference_model, vae, args.use_safetensors) 61 | print(f"model saved.") 62 | 63 | 64 | def setup_parser() -> argparse.ArgumentParser: 65 | parser = argparse.ArgumentParser() 66 | parser.add_argument("--v1", action='store_true', 67 | help='load v1.x model (v1 or v2 is required to load checkpoint) / 1.xのモデルを読み込む') 68 | parser.add_argument("--v2", action='store_true', 69 | help='load v2.0 model (v1 or v2 is required to load checkpoint) / 2.0のモデルを読み込む') 70 | parser.add_argument("--fp16", action='store_true', 71 | help='load as fp16 (Diffusers only) and save as fp16 (checkpoint only) / fp16形式で読み込み(Diffusers形式のみ対応)、保存する(checkpointのみ対応)') 72 | parser.add_argument("--bf16", action='store_true', help='save as bf16 (checkpoint only) / bf16形式で保存する(checkpointのみ対応)') 73 | parser.add_argument("--float", action='store_true', 74 | help='save as float (checkpoint only) / float(float32)形式で保存する(checkpointのみ対応)') 75 | parser.add_argument("--save_precision_as", type=str, default="no", choices=["fp16", "bf16", "float"], 76 | help="save precision") 77 | parser.add_argument("--epoch", type=int, default=0, help='epoch to write to checkpoint / checkpointに記録するepoch数の値') 78 | parser.add_argument("--global_step", type=int, default=0, 79 | help='global_step to write to checkpoint / checkpointに記録するglobal_stepの値') 80 | parser.add_argument("--reference_model", type=str, default=None, 81 | help="reference model for schduler/tokenizer, required in saving Diffusers, copy schduler/tokenizer from this / scheduler/tokenizerのコピー元のDiffusersモデル、Diffusers形式で保存するときに必要") 82 | parser.add_argument("--use_safetensors", action='store_true', 83 | help="use safetensors format to save Diffusers model (checkpoint depends on 
the file extension) / Duffusersモデルをsafetensors形式で保存する(checkpointは拡張子で自動判定)") 84 | 85 | parser.add_argument("model_to_load", type=str, default=None, 86 | help="model to load: checkpoint file or Diffusers model's directory / 読み込むモデル、checkpointかDiffusers形式モデルのディレクトリ") 87 | parser.add_argument("model_to_save", type=str, default=None, 88 | help="model to save: checkpoint (with extension) or Diffusers model's directory (without extension) / 変換後のモデル、拡張子がある場合はcheckpoint、ない場合はDiffusesモデルとして保存") 89 | return parser 90 | 91 | 92 | if __name__ == '__main__': 93 | parser = setup_parser() 94 | 95 | args = parser.parse_args() 96 | convert(args) 97 | -------------------------------------------------------------------------------- /tools/detect_face_rotate.py: -------------------------------------------------------------------------------- 1 | # このスクリプトのライセンスは、train_dreambooth.pyと同じくApache License 2.0とします 2 | # (c) 2022 Kohya S. @kohya_ss 3 | 4 | # 横長の画像から顔検出して正立するように回転し、そこを中心に正方形に切り出す 5 | 6 | # v2: extract max face if multiple faces are found 7 | # v3: add crop_ratio option 8 | # v4: add multiple faces extraction and min/max size 9 | 10 | import argparse 11 | import math 12 | import cv2 13 | import glob 14 | import os 15 | from anime_face_detector import create_detector 16 | from tqdm import tqdm 17 | import numpy as np 18 | 19 | KP_REYE = 11 20 | KP_LEYE = 19 21 | 22 | SCORE_THRES = 0.90 23 | 24 | 25 | def detect_faces(detector, image, min_size): 26 | preds = detector(image) # bgr 27 | # print(len(preds)) 28 | 29 | faces = [] 30 | for pred in preds: 31 | bb = pred['bbox'] 32 | score = bb[-1] 33 | if score < SCORE_THRES: 34 | continue 35 | 36 | left, top, right, bottom = bb[:4] 37 | cx = int((left + right) / 2) 38 | cy = int((top + bottom) / 2) 39 | fw = int(right - left) 40 | fh = int(bottom - top) 41 | 42 | lex, ley = pred['keypoints'][KP_LEYE, 0:2] 43 | rex, rey = pred['keypoints'][KP_REYE, 0:2] 44 | angle = math.atan2(ley - rey, lex - rex) 45 | angle = angle / math.pi * 180 46 | 47 | faces.append((cx, cy, fw, fh, angle)) 48 | 49 | faces.sort(key=lambda x: max(x[2], x[3]), reverse=True) # 大きい順 50 | return faces 51 | 52 | 53 | def rotate_image(image, angle, cx, cy): 54 | h, w = image.shape[0:2] 55 | rot_mat = cv2.getRotationMatrix2D((cx, cy), angle, 1.0) 56 | 57 | # # 回転する分、すこし画像サイズを大きくする→とりあえず無効化 58 | # nh = max(h, int(w * math.sin(angle))) 59 | # nw = max(w, int(h * math.sin(angle))) 60 | # if nh > h or nw > w: 61 | # pad_y = nh - h 62 | # pad_t = pad_y // 2 63 | # pad_x = nw - w 64 | # pad_l = pad_x // 2 65 | # m = np.array([[0, 0, pad_l], 66 | # [0, 0, pad_t]]) 67 | # rot_mat = rot_mat + m 68 | # h, w = nh, nw 69 | # cx += pad_l 70 | # cy += pad_t 71 | 72 | result = cv2.warpAffine(image, rot_mat, (w, h), flags=cv2.INTER_LINEAR, borderMode=cv2.BORDER_REFLECT) 73 | return result, cx, cy 74 | 75 | 76 | def process(args): 77 | assert (not args.resize_fit) or args.resize_face_size is None, f"resize_fit and resize_face_size can't be specified both / resize_fitとresize_face_sizeはどちらか片方しか指定できません" 78 | assert args.crop_ratio is None or args.resize_face_size is None, f"crop_ratio指定時はresize_face_sizeは指定できません" 79 | 80 | # アニメ顔検出モデルを読み込む 81 | print("loading face detector.") 82 | detector = create_detector('yolov3') 83 | 84 | # cropの引数を解析する 85 | if args.crop_size is None: 86 | crop_width = crop_height = None 87 | else: 88 | tokens = args.crop_size.split(',') 89 | assert len(tokens) == 2, f"crop_size must be 'width,height' / crop_sizeは'幅,高さ'で指定してください" 90 | crop_width, crop_height = [int(t) for t in tokens] 91 | 92 | 
if args.crop_ratio is None: 93 | crop_h_ratio = crop_v_ratio = None 94 | else: 95 | tokens = args.crop_ratio.split(',') 96 | assert len(tokens) == 2, f"crop_ratio must be 'horizontal,vertical' / crop_ratioは'幅,高さ'の倍率で指定してください" 97 | crop_h_ratio, crop_v_ratio = [float(t) for t in tokens] 98 | 99 | # 画像を処理する 100 | print("processing.") 101 | output_extension = ".png" 102 | 103 | os.makedirs(args.dst_dir, exist_ok=True) 104 | paths = glob.glob(os.path.join(args.src_dir, "*.png")) + glob.glob(os.path.join(args.src_dir, "*.jpg")) + \ 105 | glob.glob(os.path.join(args.src_dir, "*.webp")) 106 | for path in tqdm(paths): 107 | basename = os.path.splitext(os.path.basename(path))[0] 108 | 109 | # image = cv2.imread(path) # 日本語ファイル名でエラーになる 110 | image = cv2.imdecode(np.fromfile(path, np.uint8), cv2.IMREAD_UNCHANGED) 111 | if len(image.shape) == 2: 112 | image = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR) 113 | if image.shape[2] == 4: 114 | print(f"image has alpha. ignore / 画像の透明度が設定されているため無視します: {path}") 115 | image = image[:, :, :3].copy() # copyをしないと内部的に透明度情報が付いたままになるらしい 116 | 117 | h, w = image.shape[:2] 118 | 119 | faces = detect_faces(detector, image, args.multiple_faces) 120 | for i, face in enumerate(faces): 121 | cx, cy, fw, fh, angle = face 122 | face_size = max(fw, fh) 123 | if args.min_size is not None and face_size < args.min_size: 124 | continue 125 | if args.max_size is not None and face_size >= args.max_size: 126 | continue 127 | face_suffix = f"_{i+1:02d}" if args.multiple_faces else "" 128 | 129 | # オプション指定があれば回転する 130 | face_img = image 131 | if args.rotate: 132 | face_img, cx, cy = rotate_image(face_img, angle, cx, cy) 133 | 134 | # オプション指定があれば顔を中心に切り出す 135 | if crop_width is not None or crop_h_ratio is not None: 136 | cur_crop_width, cur_crop_height = crop_width, crop_height 137 | if crop_h_ratio is not None: 138 | cur_crop_width = int(face_size * crop_h_ratio + .5) 139 | cur_crop_height = int(face_size * crop_v_ratio + .5) 140 | 141 | # リサイズを必要なら行う 142 | scale = 1.0 143 | if args.resize_face_size is not None: 144 | # 顔サイズを基準にリサイズする 145 | scale = args.resize_face_size / face_size 146 | if scale < cur_crop_width / w: 147 | print( 148 | f"image width too small in face size based resizing / 顔を基準にリサイズすると画像の幅がcrop sizeより小さい(顔が相対的に大きすぎる)ので顔サイズが変わります: {path}") 149 | scale = cur_crop_width / w 150 | if scale < cur_crop_height / h: 151 | print( 152 | f"image height too small in face size based resizing / 顔を基準にリサイズすると画像の高さがcrop sizeより小さい(顔が相対的に大きすぎる)ので顔サイズが変わります: {path}") 153 | scale = cur_crop_height / h 154 | elif crop_h_ratio is not None: 155 | # 倍率指定の時にはリサイズしない 156 | pass 157 | else: 158 | # 切り出しサイズ指定あり 159 | if w < cur_crop_width: 160 | print(f"image width too small/ 画像の幅がcrop sizeより小さいので画質が劣化します: {path}") 161 | scale = cur_crop_width / w 162 | if h < cur_crop_height: 163 | print(f"image height too small/ 画像の高さがcrop sizeより小さいので画質が劣化します: {path}") 164 | scale = cur_crop_height / h 165 | if args.resize_fit: 166 | scale = max(cur_crop_width / w, cur_crop_height / h) 167 | 168 | if scale != 1.0: 169 | w = int(w * scale + .5) 170 | h = int(h * scale + .5) 171 | face_img = cv2.resize(face_img, (w, h), interpolation=cv2.INTER_AREA if scale < 1.0 else cv2.INTER_LANCZOS4) 172 | cx = int(cx * scale + .5) 173 | cy = int(cy * scale + .5) 174 | fw = int(fw * scale + .5) 175 | fh = int(fh * scale + .5) 176 | 177 | cur_crop_width = min(cur_crop_width, face_img.shape[1]) 178 | cur_crop_height = min(cur_crop_height, face_img.shape[0]) 179 | 180 | x = cx - cur_crop_width // 2 181 | cx = cur_crop_width // 2 182 
| if x < 0: 183 | cx = cx + x 184 | x = 0 185 | elif x + cur_crop_width > w: 186 | cx = cx + (x + cur_crop_width - w) 187 | x = w - cur_crop_width 188 | face_img = face_img[:, x:x+cur_crop_width] 189 | 190 | y = cy - cur_crop_height // 2 191 | cy = cur_crop_height // 2 192 | if y < 0: 193 | cy = cy + y 194 | y = 0 195 | elif y + cur_crop_height > h: 196 | cy = cy + (y + cur_crop_height - h) 197 | y = h - cur_crop_height 198 | face_img = face_img[y:y + cur_crop_height] 199 | 200 | # # debug 201 | # print(path, cx, cy, angle) 202 | # crp = cv2.resize(image, (image.shape[1]//8, image.shape[0]//8)) 203 | # cv2.imshow("image", crp) 204 | # if cv2.waitKey() == 27: 205 | # break 206 | # cv2.destroyAllWindows() 207 | 208 | # debug 209 | if args.debug: 210 | cv2.rectangle(face_img, (cx-fw//2, cy-fh//2), (cx+fw//2, cy+fh//2), (255, 0, 255), fw//20) 211 | 212 | _, buf = cv2.imencode(output_extension, face_img) 213 | with open(os.path.join(args.dst_dir, f"{basename}{face_suffix}_{cx:04d}_{cy:04d}_{fw:04d}_{fh:04d}{output_extension}"), "wb") as f: 214 | buf.tofile(f) 215 | 216 | 217 | def setup_parser() -> argparse.ArgumentParser: 218 | parser = argparse.ArgumentParser() 219 | parser.add_argument("--src_dir", type=str, help="directory to load images / 画像を読み込むディレクトリ") 220 | parser.add_argument("--dst_dir", type=str, help="directory to save images / 画像を保存するディレクトリ") 221 | parser.add_argument("--rotate", action="store_true", help="rotate images to align faces / 顔が正立するように画像を回転する") 222 | parser.add_argument("--resize_fit", action="store_true", 223 | help="resize to fit smaller side after cropping / 切り出し後の画像の短辺がcrop_sizeにあうようにリサイズする") 224 | parser.add_argument("--resize_face_size", type=int, default=None, 225 | help="resize image before cropping by face size / 切り出し前に顔がこのサイズになるようにリサイズする") 226 | parser.add_argument("--crop_size", type=str, default=None, 227 | help="crop images with 'width,height' pixels, face centered / 顔を中心として'幅,高さ'のサイズで切り出す") 228 | parser.add_argument("--crop_ratio", type=str, default=None, 229 | help="crop images with 'horizontal,vertical' ratio to face, face centered / 顔を中心として顔サイズの'幅倍率,高さ倍率'のサイズで切り出す") 230 | parser.add_argument("--min_size", type=int, default=None, 231 | help="minimum face size to output (included) / 処理対象とする顔の最小サイズ(この値以上)") 232 | parser.add_argument("--max_size", type=int, default=None, 233 | help="maximum face size to output (excluded) / 処理対象とする顔の最大サイズ(この値未満)") 234 | parser.add_argument("--multiple_faces", action="store_true", 235 | help="output each faces / 複数の顔が見つかった場合、それぞれを切り出す") 236 | parser.add_argument("--debug", action="store_true", help="render rect for face / 処理後画像の顔位置に矩形を描画します") 237 | 238 | return parser 239 | 240 | 241 | if __name__ == '__main__': 242 | parser = setup_parser() 243 | 244 | args = parser.parse_args() 245 | 246 | process(args) 247 | -------------------------------------------------------------------------------- /tools/merge_block_weighted.py: -------------------------------------------------------------------------------- 1 | # original code: https://github.com/eyriewow/merge-models 2 | 3 | import os 4 | import argparse 5 | import re 6 | import torch 7 | from tqdm import tqdm 8 | 9 | 10 | NUM_INPUT_BLOCKS = 12 11 | NUM_MID_BLOCK = 1 12 | NUM_OUTPUT_BLOCKS = 12 13 | NUM_TOTAL_BLOCKS = NUM_INPUT_BLOCKS + NUM_MID_BLOCK + NUM_OUTPUT_BLOCKS 14 | 15 | 16 | def merge(args): 17 | if args.weights is None: 18 | weights = None 19 | else: 20 | weights = [float(w) for w in args.weights.split(',')] 21 | if len(weights) != NUM_TOTAL_BLOCKS: 22 | print(f"weights 
value must be {NUM_TOTAL_BLOCKS}.") 23 | return 24 | 25 | device = args.device 26 | print("loading", args.model_0) 27 | model_0 = torch.load(args.model_0, map_location=device) 28 | print("loading", args.model_1) 29 | model_1 = torch.load(args.model_1, map_location=device) 30 | theta_0 = model_0["state_dict"] 31 | theta_1 = model_1["state_dict"] 32 | alpha = args.base_alpha 33 | 34 | output_file = f'{args.output}-{str(alpha)[2:] + "0"}-bw.ckpt' 35 | 36 | # check if output file already exists, ask to overwrite 37 | if os.path.isfile(output_file): 38 | print("Output file already exists. Overwrite? (y/n)") 39 | while True: 40 | overwrite = input() 41 | if overwrite == "y": 42 | break 43 | elif overwrite == "n": 44 | print("Exiting...") 45 | return 46 | else: 47 | print("Please enter y or n") 48 | 49 | re_inp = re.compile(r'\.input_blocks\.(\d+)\.') # 12 50 | re_mid = re.compile(r'\.middle_block\.(\d+)\.') # 1 51 | re_out = re.compile(r'\.output_blocks\.(\d+)\.') # 12 52 | 53 | for key in (tqdm(theta_0.keys(), desc="Stage 1/2") if not args.verbose else theta_0.keys()): 54 | if "model" in key and key in theta_1: 55 | current_alpha = alpha 56 | 57 | # check weighted and U-Net or not 58 | if weights is not None and 'model.diffusion_model.' in key: 59 | # check block index 60 | weight_index = -1 61 | 62 | if 'time_embed' in key: 63 | weight_index = 0 # before input blocks 64 | elif '.out.' in key: 65 | weight_index = NUM_TOTAL_BLOCKS - 1 # after output blocks 66 | else: 67 | m = re_inp.search(key) 68 | if m: 69 | inp_idx = int(m.groups()[0]) 70 | weight_index = inp_idx 71 | else: 72 | m = re_mid.search(key) 73 | if m: 74 | weight_index = NUM_INPUT_BLOCKS 75 | else: 76 | m = re_out.search(key) 77 | if m: 78 | out_idx = int(m.groups()[0]) 79 | weight_index = NUM_INPUT_BLOCKS + NUM_MID_BLOCK + out_idx 80 | 81 | if weight_index >= NUM_TOTAL_BLOCKS: 82 | print(f"error. 
illegal block index: {key}") 83 | if weight_index >= 0: 84 | current_alpha = weights[weight_index] 85 | if args.verbose: 86 | print(f"weighted '{key}': {current_alpha}") 87 | 88 | theta_0[key] = (1 - current_alpha) * theta_0[key] + current_alpha * theta_1[key] 89 | 90 | for key in tqdm(theta_1.keys(), desc="Stage 2/2"): 91 | if "model" in key and key not in theta_0: 92 | theta_0[key] = theta_1[key] 93 | 94 | print("Saving...") 95 | 96 | torch.save({"state_dict": theta_0}, output_file) 97 | 98 | print("Done!") 99 | 100 | 101 | if __name__ == '__main__': 102 | parser = argparse.ArgumentParser(description="Merge two models with weights for each block") 103 | parser.add_argument("model_0", type=str, help="Path to model 0") 104 | parser.add_argument("model_1", type=str, help="Path to model 1") 105 | parser.add_argument("--base_alpha", type=float, 106 | help="Alpha value (for model 0) except U-Net, optional, defaults to 0.5", default=0.5, required=False) 107 | parser.add_argument("--output", type=str, help="Output file name, without extension", default="merged", required=False) 108 | parser.add_argument("--device", type=str, help="Device to use, defaults to cpu", default="cpu", required=False) 109 | parser.add_argument("--weights", type=str, 110 | help=f"comma separated {NUM_TOTAL_BLOCKS} weights value (for model 0) for each U-Net block", default=None, required=False) 111 | parser.add_argument("--verbose", action='store_true', help="show each block weight", required=False) 112 | 113 | args = parser.parse_args() 114 | merge(args) 115 | -------------------------------------------------------------------------------- /tools/merge_vae.py: -------------------------------------------------------------------------------- 1 | # License of this file is ASL 2.0 2 | 3 | import argparse 4 | import torch 5 | 6 | 7 | VAE_PREFIX = "first_stage_model." 8 | 9 | # copy from convert_diffusers_to_original_stable_diffusion.py ASL 2.0 10 | 11 | # ================# 12 | # VAE Conversion # 13 | # ================# 14 | 15 | vae_conversion_map = [ 16 | # (stable-diffusion, HF Diffusers) 17 | ("nin_shortcut", "conv_shortcut"), 18 | ("norm_out", "conv_norm_out"), 19 | ("mid.attn_1.", "mid_block.attentions.0."), 20 | ] 21 | 22 | for i in range(4): 23 | # down_blocks have two resnets 24 | for j in range(2): 25 | hf_down_prefix = f"encoder.down_blocks.{i}.resnets.{j}." 26 | sd_down_prefix = f"encoder.down.{i}.block.{j}." 27 | vae_conversion_map.append((sd_down_prefix, hf_down_prefix)) 28 | 29 | if i < 3: 30 | hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0." 31 | sd_downsample_prefix = f"down.{i}.downsample." 32 | vae_conversion_map.append((sd_downsample_prefix, hf_downsample_prefix)) 33 | 34 | hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0." 35 | sd_upsample_prefix = f"up.{3-i}.upsample." 36 | vae_conversion_map.append((sd_upsample_prefix, hf_upsample_prefix)) 37 | 38 | # up_blocks have three resnets 39 | # also, up blocks in hf are numbered in reverse from sd 40 | for j in range(3): 41 | hf_up_prefix = f"decoder.up_blocks.{i}.resnets.{j}." 42 | sd_up_prefix = f"decoder.up.{3-i}.block.{j}." 43 | vae_conversion_map.append((sd_up_prefix, hf_up_prefix)) 44 | 45 | # this part accounts for mid blocks in both the encoder and the decoder 46 | for i in range(2): 47 | hf_mid_res_prefix = f"mid_block.resnets.{i}." 48 | sd_mid_res_prefix = f"mid.block_{i+1}." 
49 | vae_conversion_map.append((sd_mid_res_prefix, hf_mid_res_prefix)) 50 | 51 | 52 | vae_conversion_map_attn = [ 53 | # (stable-diffusion, HF Diffusers) 54 | ("norm.", "group_norm."), 55 | ("q.", "query."), 56 | ("k.", "key."), 57 | ("v.", "value."), 58 | ("proj_out.", "proj_attn."), 59 | ] 60 | 61 | 62 | def reshape_weight_for_sd(w): 63 | # convert HF linear weights to SD conv2d weights 64 | return w.reshape(*w.shape, 1, 1) 65 | 66 | 67 | def convert_vae_state_dict(vae_state_dict): 68 | mapping = {k: k for k in vae_state_dict.keys()} 69 | for k, v in mapping.items(): 70 | for sd_part, hf_part in vae_conversion_map: 71 | v = v.replace(hf_part, sd_part) 72 | mapping[k] = v 73 | for k, v in mapping.items(): 74 | if "attentions" in k: 75 | for sd_part, hf_part in vae_conversion_map_attn: 76 | v = v.replace(hf_part, sd_part) 77 | mapping[k] = v 78 | new_state_dict = {v: vae_state_dict[k] for k, v in mapping.items()} 79 | weights_to_convert = ["q", "k", "v", "proj_out"] 80 | for k, v in new_state_dict.items(): 81 | for weight_name in weights_to_convert: 82 | if f"mid.attn_1.{weight_name}.weight" in k: 83 | print(f"Reshaping {k} for SD format") 84 | new_state_dict[k] = reshape_weight_for_sd(v) 85 | return new_state_dict 86 | 87 | 88 | def convert_diffusers_vae(vae_path): 89 | vae_state_dict = torch.load(vae_path, map_location="cpu") 90 | vae_state_dict = convert_vae_state_dict(vae_state_dict) 91 | return vae_state_dict 92 | 93 | 94 | def merge_vae(ckpt, vae, output): 95 | print(f"load checkpoint: {ckpt}") 96 | model = torch.load(ckpt, map_location="cpu") 97 | sd = model['state_dict'] 98 | 99 | full_model = False 100 | 101 | print(f"load VAE: {vae}") 102 | if vae.endswith(".bin"): 103 | print("convert diffusers VAE to stablediffusion") 104 | vae_sd = convert_diffusers_vae(vae) 105 | else: 106 | vae_model = torch.load(vae, map_location="cpu") 107 | vae_sd = vae_model['state_dict'] 108 | 109 | # vae only or full model 110 | for vae_key in vae_sd: 111 | if vae_key.startswith(VAE_PREFIX): 112 | full_model = True 113 | break 114 | 115 | count = 0 116 | for vae_key in vae_sd: 117 | sd_key = vae_key 118 | if full_model: 119 | if not sd_key.startswith(VAE_PREFIX): 120 | continue 121 | else: 122 | if sd_key not in sd: 123 | sd_key = VAE_PREFIX + sd_key 124 | if sd_key not in sd: 125 | print(f"key not exists in model: {vae_key}") 126 | continue 127 | sd[sd_key] = vae_sd[vae_key] 128 | count += 1 129 | print(f"{count} weights are copied") 130 | 131 | print(f"saving checkpoint to: {output}") 132 | torch.save(model, output) 133 | 134 | 135 | if __name__ == "__main__": 136 | parser = argparse.ArgumentParser() 137 | parser.add_argument("ckpt", type=str, help="target checkpoint to replace VAE / マージ対象のモデルcheckpoint") 138 | parser.add_argument("vae", type=str, help="VAE/model checkpoint to merge / マージするVAEまたはモデルのcheckpoint") 139 | parser.add_argument("output", type=str, help="output checkoint / 出力先checkpoint") 140 | args = parser.parse_args() 141 | 142 | merge_vae(args.ckpt, args.vae, args.output) 143 | -------------------------------------------------------------------------------- /tools/original_control_net.py: -------------------------------------------------------------------------------- 1 | from typing import List, NamedTuple, Any 2 | import numpy as np 3 | import cv2 4 | import torch 5 | from safetensors.torch import load_file 6 | 7 | from diffusers import UNet2DConditionModel 8 | from diffusers.models.unet_2d_condition import UNet2DConditionOutput 9 | 10 | import library.model_util as model_util 11 | 12 
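For orientation before the class and helper definitions that follow, here is a minimal, self-contained sketch of the hint-preprocessing path implemented later in this file: `load_preprocess("canny_63_191")` builds a Canny edge detector with the default thresholds 63 and 191, and `preprocess_ctrl_net_hint_image` converts the hint into a normalized NCHW float tensor. The random stand-in image and the grayscale-to-RGB expansion of the edge map are assumptions made only for this illustration, not code from the repository.

```python
# Minimal sketch of the hint-preprocessing path in tools/original_control_net.py,
# assuming a random stand-in image. The thresholds 63/191 mirror the defaults used
# by load_preprocess("canny_..."); the GRAY->RGB expansion is an assumption here.
import cv2
import numpy as np
import torch

rgb = np.random.randint(0, 256, (512, 512, 3), dtype=np.uint8)  # stand-in hint image

# Canny edge hint with the default thresholds 63/191
edges = cv2.Canny(cv2.cvtColor(rgb, cv2.COLOR_RGB2GRAY), 63, 191)
hint = cv2.cvtColor(edges, cv2.COLOR_GRAY2RGB)  # expand edges to 3 channels (sketch assumption)

# same steps as preprocess_ctrl_net_hint_image: scale to [0, 1], RGB->BGR, HWC->NCHW
hint = hint.astype(np.float32) / 255.0
hint = hint[:, :, ::-1].copy()
hint = torch.from_numpy(hint[None].transpose(0, 3, 1, 2))
print(hint.shape)  # torch.Size([1, 3, 512, 512])
```

In `get_guided_hints` below, tensors of this shape are batched and passed through `control_model.input_hint_block` to produce the guided hints consumed by `call_unet_and_control_net`.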
| 13 | class ControlNetInfo(NamedTuple): 14 | unet: Any 15 | net: Any 16 | prep: Any 17 | weight: float 18 | ratio: float 19 | 20 | 21 | class ControlNet(torch.nn.Module): 22 | def __init__(self) -> None: 23 | super().__init__() 24 | 25 | # make control model 26 | self.control_model = torch.nn.Module() 27 | 28 | dims = [320, 320, 320, 320, 640, 640, 640, 1280, 1280, 1280, 1280, 1280] 29 | zero_convs = torch.nn.ModuleList() 30 | for i, dim in enumerate(dims): 31 | sub_list = torch.nn.ModuleList([torch.nn.Conv2d(dim, dim, 1)]) 32 | zero_convs.append(sub_list) 33 | self.control_model.add_module("zero_convs", zero_convs) 34 | 35 | middle_block_out = torch.nn.Conv2d(1280, 1280, 1) 36 | self.control_model.add_module("middle_block_out", torch.nn.ModuleList([middle_block_out])) 37 | 38 | dims = [16, 16, 32, 32, 96, 96, 256, 320] 39 | strides = [1, 1, 2, 1, 2, 1, 2, 1] 40 | prev_dim = 3 41 | input_hint_block = torch.nn.Sequential() 42 | for i, (dim, stride) in enumerate(zip(dims, strides)): 43 | input_hint_block.append(torch.nn.Conv2d(prev_dim, dim, 3, stride, 1)) 44 | if i < len(dims) - 1: 45 | input_hint_block.append(torch.nn.SiLU()) 46 | prev_dim = dim 47 | self.control_model.add_module("input_hint_block", input_hint_block) 48 | 49 | 50 | def load_control_net(v2, unet, model): 51 | device = unet.device 52 | 53 | # control sdからキー変換しつつU-Netに対応する部分のみ取り出し、DiffusersのU-Netに読み込む 54 | # state dictを読み込む 55 | print(f"ControlNet: loading control SD model : {model}") 56 | 57 | if model_util.is_safetensors(model): 58 | ctrl_sd_sd = load_file(model) 59 | else: 60 | ctrl_sd_sd = torch.load(model, map_location='cpu') 61 | ctrl_sd_sd = ctrl_sd_sd.pop("state_dict", ctrl_sd_sd) 62 | 63 | # 重みをU-Netに読み込めるようにする。ControlNetはSD版のstate dictなので、それを読み込む 64 | is_difference = "difference" in ctrl_sd_sd 65 | print("ControlNet: loading difference") 66 | 67 | # ControlNetには存在しないキーがあるので、まず現在のU-NetでSD版の全keyを作っておく 68 | # またTransfer Controlの元weightとなる 69 | ctrl_unet_sd_sd = model_util.convert_unet_state_dict_to_sd(v2, unet.state_dict()) 70 | 71 | # 元のU-Netに影響しないようにコピーする。またprefixが付いていないので付ける 72 | for key in list(ctrl_unet_sd_sd.keys()): 73 | ctrl_unet_sd_sd["model.diffusion_model." 
+ key] = ctrl_unet_sd_sd.pop(key).clone() 74 | 75 | zero_conv_sd = {} 76 | for key in list(ctrl_sd_sd.keys()): 77 | if key.startswith("control_"): 78 | unet_key = "model.diffusion_" + key[len("control_"):] 79 | if unet_key not in ctrl_unet_sd_sd: # zero conv 80 | zero_conv_sd[key] = ctrl_sd_sd[key] 81 | continue 82 | if is_difference: # Transfer Control 83 | ctrl_unet_sd_sd[unet_key] += ctrl_sd_sd[key].to(device, dtype=unet.dtype) 84 | else: 85 | ctrl_unet_sd_sd[unet_key] = ctrl_sd_sd[key].to(device, dtype=unet.dtype) 86 | 87 | unet_config = model_util.create_unet_diffusers_config(v2) 88 | ctrl_unet_du_sd = model_util.convert_ldm_unet_checkpoint(v2, ctrl_unet_sd_sd, unet_config) # DiffUsers版ControlNetのstate dict 89 | 90 | # ControlNetのU-Netを作成する 91 | ctrl_unet = UNet2DConditionModel(**unet_config) 92 | info = ctrl_unet.load_state_dict(ctrl_unet_du_sd) 93 | print("ControlNet: loading Control U-Net:", info) 94 | 95 | # U-Net以外のControlNetを作成する 96 | # TODO support middle only 97 | ctrl_net = ControlNet() 98 | info = ctrl_net.load_state_dict(zero_conv_sd) 99 | print("ControlNet: loading ControlNet:", info) 100 | 101 | ctrl_unet.to(unet.device, dtype=unet.dtype) 102 | ctrl_net.to(unet.device, dtype=unet.dtype) 103 | return ctrl_unet, ctrl_net 104 | 105 | 106 | def load_preprocess(prep_type: str): 107 | if prep_type is None or prep_type.lower() == "none": 108 | return None 109 | 110 | if prep_type.startswith("canny"): 111 | args = prep_type.split("_") 112 | th1 = int(args[1]) if len(args) >= 2 else 63 113 | th2 = int(args[2]) if len(args) >= 3 else 191 114 | 115 | def canny(img): 116 | img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY) 117 | return cv2.Canny(img, th1, th2) 118 | return canny 119 | 120 | print("Unsupported prep type:", prep_type) 121 | return None 122 | 123 | 124 | def preprocess_ctrl_net_hint_image(image): 125 | image = np.array(image).astype(np.float32) / 255.0 126 | image = image[:, :, ::-1].copy() # rgb to bgr 127 | image = image[None].transpose(0, 3, 1, 2) # nchw 128 | image = torch.from_numpy(image) 129 | return image # 0 to 1 130 | 131 | 132 | def get_guided_hints(control_nets: List[ControlNetInfo], num_latent_input, b_size, hints): 133 | guided_hints = [] 134 | for i, cnet_info in enumerate(control_nets): 135 | # hintは 1枚目の画像のcnet1, 1枚目の画像のcnet2, 1枚目の画像のcnet3, 2枚目の画像のcnet1, 2枚目の画像のcnet2 ... 
と並んでいること 136 | b_hints = [] 137 | if len(hints) == 1: # すべて同じ画像をhintとして使う 138 | hint = hints[0] 139 | if cnet_info.prep is not None: 140 | hint = cnet_info.prep(hint) 141 | hint = preprocess_ctrl_net_hint_image(hint) 142 | b_hints = [hint for _ in range(b_size)] 143 | else: 144 | for bi in range(b_size): 145 | hint = hints[(bi * len(control_nets) + i) % len(hints)] 146 | if cnet_info.prep is not None: 147 | hint = cnet_info.prep(hint) 148 | hint = preprocess_ctrl_net_hint_image(hint) 149 | b_hints.append(hint) 150 | b_hints = torch.cat(b_hints, dim=0) 151 | b_hints = b_hints.to(cnet_info.unet.device, dtype=cnet_info.unet.dtype) 152 | 153 | guided_hint = cnet_info.net.control_model.input_hint_block(b_hints) 154 | guided_hints.append(guided_hint) 155 | return guided_hints 156 | 157 | 158 | def call_unet_and_control_net(step, num_latent_input, original_unet, control_nets: List[ControlNetInfo], guided_hints, current_ratio, sample, timestep, encoder_hidden_states): 159 | # ControlNet 160 | # 複数のControlNetの場合は、出力をマージするのではなく交互に適用する 161 | cnet_cnt = len(control_nets) 162 | cnet_idx = step % cnet_cnt 163 | cnet_info = control_nets[cnet_idx] 164 | 165 | # print(current_ratio, cnet_info.prep, cnet_info.weight, cnet_info.ratio) 166 | if cnet_info.ratio < current_ratio: 167 | return original_unet(sample, timestep, encoder_hidden_states) 168 | 169 | guided_hint = guided_hints[cnet_idx] 170 | guided_hint = guided_hint.repeat((num_latent_input, 1, 1, 1)) 171 | outs = unet_forward(True, cnet_info.net, cnet_info.unet, guided_hint, None, sample, timestep, encoder_hidden_states) 172 | outs = [o * cnet_info.weight for o in outs] 173 | 174 | # U-Net 175 | return unet_forward(False, cnet_info.net, original_unet, None, outs, sample, timestep, encoder_hidden_states) 176 | 177 | 178 | """ 179 | # これはmergeのバージョン 180 | # ControlNet 181 | cnet_outs_list = [] 182 | for i, cnet_info in enumerate(control_nets): 183 | # print(current_ratio, cnet_info.prep, cnet_info.weight, cnet_info.ratio) 184 | if cnet_info.ratio < current_ratio: 185 | continue 186 | guided_hint = guided_hints[i] 187 | outs = unet_forward(True, cnet_info.net, cnet_info.unet, guided_hint, None, sample, timestep, encoder_hidden_states) 188 | for i in range(len(outs)): 189 | outs[i] *= cnet_info.weight 190 | 191 | cnet_outs_list.append(outs) 192 | 193 | count = len(cnet_outs_list) 194 | if count == 0: 195 | return original_unet(sample, timestep, encoder_hidden_states) 196 | 197 | # sum of controlnets 198 | for i in range(1, count): 199 | cnet_outs_list[0] += cnet_outs_list[i] 200 | 201 | # U-Net 202 | return unet_forward(False, cnet_info.net, original_unet, None, cnet_outs_list[0], sample, timestep, encoder_hidden_states) 203 | """ 204 | 205 | 206 | def unet_forward(is_control_net, control_net: ControlNet, unet: UNet2DConditionModel, guided_hint, ctrl_outs, sample, timestep, encoder_hidden_states): 207 | # copy from UNet2DConditionModel 208 | default_overall_up_factor = 2**unet.num_upsamplers 209 | 210 | forward_upsample_size = False 211 | upsample_size = None 212 | 213 | if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]): 214 | print("Forward upsample size to force interpolation output size.") 215 | forward_upsample_size = True 216 | 217 | # 0. center input if necessary 218 | if unet.config.center_input_sample: 219 | sample = 2 * sample - 1.0 220 | 221 | # 1. time 222 | timesteps = timestep 223 | if not torch.is_tensor(timesteps): 224 | # TODO: this requires sync between CPU and GPU. 
So try to pass timesteps as tensors if you can 225 | # This would be a good case for the `match` statement (Python 3.10+) 226 | is_mps = sample.device.type == "mps" 227 | if isinstance(timestep, float): 228 | dtype = torch.float32 if is_mps else torch.float64 229 | else: 230 | dtype = torch.int32 if is_mps else torch.int64 231 | timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) 232 | elif len(timesteps.shape) == 0: 233 | timesteps = timesteps[None].to(sample.device) 234 | 235 | # broadcast to batch dimension in a way that's compatible with ONNX/Core ML 236 | timesteps = timesteps.expand(sample.shape[0]) 237 | 238 | t_emb = unet.time_proj(timesteps) 239 | 240 | # timesteps does not contain any weights and will always return f32 tensors 241 | # but time_embedding might actually be running in fp16. so we need to cast here. 242 | # there might be better ways to encapsulate this. 243 | t_emb = t_emb.to(dtype=unet.dtype) 244 | emb = unet.time_embedding(t_emb) 245 | 246 | outs = [] # output of ControlNet 247 | zc_idx = 0 248 | 249 | # 2. pre-process 250 | sample = unet.conv_in(sample) 251 | if is_control_net: 252 | sample += guided_hint 253 | outs.append(control_net.control_model.zero_convs[zc_idx][0](sample)) # , emb, encoder_hidden_states)) 254 | zc_idx += 1 255 | 256 | # 3. down 257 | down_block_res_samples = (sample,) 258 | for downsample_block in unet.down_blocks: 259 | if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: 260 | sample, res_samples = downsample_block( 261 | hidden_states=sample, 262 | temb=emb, 263 | encoder_hidden_states=encoder_hidden_states, 264 | ) 265 | else: 266 | sample, res_samples = downsample_block(hidden_states=sample, temb=emb) 267 | if is_control_net: 268 | for rs in res_samples: 269 | outs.append(control_net.control_model.zero_convs[zc_idx][0](rs)) # , emb, encoder_hidden_states)) 270 | zc_idx += 1 271 | 272 | down_block_res_samples += res_samples 273 | 274 | # 4. mid 275 | sample = unet.mid_block(sample, emb, encoder_hidden_states=encoder_hidden_states) 276 | if is_control_net: 277 | outs.append(control_net.control_model.middle_block_out[0](sample)) 278 | return outs 279 | 280 | if not is_control_net: 281 | sample += ctrl_outs.pop() 282 | 283 | # 5. 
up 284 | for i, upsample_block in enumerate(unet.up_blocks): 285 | is_final_block = i == len(unet.up_blocks) - 1 286 | 287 | res_samples = down_block_res_samples[-len(upsample_block.resnets):] 288 | down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] 289 | 290 | if not is_control_net and len(ctrl_outs) > 0: 291 | res_samples = list(res_samples) 292 | apply_ctrl_outs = ctrl_outs[-len(res_samples):] 293 | ctrl_outs = ctrl_outs[:-len(res_samples)] 294 | for j in range(len(res_samples)): 295 | res_samples[j] = res_samples[j] + apply_ctrl_outs[j] 296 | res_samples = tuple(res_samples) 297 | 298 | # if we have not reached the final block and need to forward the 299 | # upsample size, we do it here 300 | if not is_final_block and forward_upsample_size: 301 | upsample_size = down_block_res_samples[-1].shape[2:] 302 | 303 | if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: 304 | sample = upsample_block( 305 | hidden_states=sample, 306 | temb=emb, 307 | res_hidden_states_tuple=res_samples, 308 | encoder_hidden_states=encoder_hidden_states, 309 | upsample_size=upsample_size, 310 | ) 311 | else: 312 | sample = upsample_block( 313 | hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size 314 | ) 315 | # 6. post-process 316 | sample = unet.conv_norm_out(sample) 317 | sample = unet.conv_act(sample) 318 | sample = unet.conv_out(sample) 319 | 320 | return UNet2DConditionOutput(sample=sample) 321 | -------------------------------------------------------------------------------- /tools/resize_images_to_resolution.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | import cv2 4 | import argparse 5 | import shutil 6 | import math 7 | from PIL import Image 8 | import numpy as np 9 | 10 | 11 | def resize_images(src_img_folder, dst_img_folder, max_resolution="512x512", divisible_by=2, interpolation=None, save_as_png=False, copy_associated_files=False): 12 | # Split the max_resolution string by "," and strip any whitespaces 13 | max_resolutions = [res.strip() for res in max_resolution.split(',')] 14 | 15 | # # Calculate max_pixels from max_resolution string 16 | # max_pixels = int(max_resolution.split("x")[0]) * int(max_resolution.split("x")[1]) 17 | 18 | # Create destination folder if it does not exist 19 | if not os.path.exists(dst_img_folder): 20 | os.makedirs(dst_img_folder) 21 | 22 | # Select interpolation method 23 | if interpolation == 'lanczos4': 24 | cv2_interpolation = cv2.INTER_LANCZOS4 25 | elif interpolation == 'cubic': 26 | cv2_interpolation = cv2.INTER_CUBIC 27 | else: 28 | cv2_interpolation = cv2.INTER_AREA 29 | 30 | # Iterate through all files in src_img_folder 31 | img_exts = (".png", ".jpg", ".jpeg", ".webp", ".bmp") # copy from train_util.py 32 | for filename in os.listdir(src_img_folder): 33 | # Check if the image is png, jpg or webp etc... 34 | if not filename.endswith(img_exts): 35 | # Copy the file to the destination folder if not png, jpg or webp etc (.txt or .caption or etc.) 
36 | shutil.copy(os.path.join(src_img_folder, filename), os.path.join(dst_img_folder, filename)) 37 | continue 38 | 39 | # Load image 40 | # img = cv2.imread(os.path.join(src_img_folder, filename)) 41 | image = Image.open(os.path.join(src_img_folder, filename)) 42 | if not image.mode == "RGB": 43 | image = image.convert("RGB") 44 | img = np.array(image, np.uint8) 45 | 46 | base, _ = os.path.splitext(filename) 47 | for max_resolution in max_resolutions: 48 | # Calculate max_pixels from max_resolution string 49 | max_pixels = int(max_resolution.split("x")[0]) * int(max_resolution.split("x")[1]) 50 | 51 | # Calculate current number of pixels 52 | current_pixels = img.shape[0] * img.shape[1] 53 | 54 | # Check if the image needs resizing 55 | if current_pixels > max_pixels: 56 | # Calculate scaling factor 57 | scale_factor = max_pixels / current_pixels 58 | 59 | # Calculate new dimensions 60 | new_height = int(img.shape[0] * math.sqrt(scale_factor)) 61 | new_width = int(img.shape[1] * math.sqrt(scale_factor)) 62 | 63 | # Resize image 64 | img = cv2.resize(img, (new_width, new_height), interpolation=cv2_interpolation) 65 | else: 66 | new_height, new_width = img.shape[0:2] 67 | 68 | # Calculate the new height and width that are divisible by divisible_by (with/without resizing) 69 | new_height = new_height if new_height % divisible_by == 0 else new_height - new_height % divisible_by 70 | new_width = new_width if new_width % divisible_by == 0 else new_width - new_width % divisible_by 71 | 72 | # Center crop the image to the calculated dimensions 73 | y = int((img.shape[0] - new_height) / 2) 74 | x = int((img.shape[1] - new_width) / 2) 75 | img = img[y:y + new_height, x:x + new_width] 76 | 77 | # Split filename into base and extension 78 | new_filename = base + '+' + max_resolution + ('.png' if save_as_png else '.jpg') 79 | 80 | # Save resized image in dst_img_folder 81 | # cv2.imwrite(os.path.join(dst_img_folder, new_filename), img, [cv2.IMWRITE_JPEG_QUALITY, 100]) 82 | image = Image.fromarray(img) 83 | image.save(os.path.join(dst_img_folder, new_filename), quality=100) 84 | 85 | proc = "Resized" if current_pixels > max_pixels else "Saved" 86 | print(f"{proc} image: {filename} with size {img.shape[0]}x{img.shape[1]} as {new_filename}") 87 | 88 | # If other files with same basename, copy them with resolution suffix 89 | if copy_associated_files: 90 | asoc_files = glob.glob(os.path.join(src_img_folder, base + ".*")) 91 | for asoc_file in asoc_files: 92 | ext = os.path.splitext(asoc_file)[1] 93 | if ext in img_exts: 94 | continue 95 | for max_resolution in max_resolutions: 96 | new_asoc_file = base + '+' + max_resolution + ext 97 | print(f"Copy {asoc_file} as {new_asoc_file}") 98 | shutil.copy(os.path.join(src_img_folder, asoc_file), os.path.join(dst_img_folder, new_asoc_file)) 99 | 100 | 101 | def setup_parser() -> argparse.ArgumentParser: 102 | parser = argparse.ArgumentParser( 103 | description='Resize images in a folder to a specified max resolution(s) / 指定されたフォルダ内の画像を指定した最大画像サイズ(面積)以下にアスペクト比を維持したままリサイズします') 104 | parser.add_argument('src_img_folder', type=str, help='Source folder containing the images / 元画像のフォルダ') 105 | parser.add_argument('dst_img_folder', type=str, help='Destination folder to save the resized images / リサイズ後の画像を保存するフォルダ') 106 | parser.add_argument('--max_resolution', type=str, 107 | help='Maximum resolution(s) in the format "512x512,384x384, etc, etc" / 最大画像サイズをカンマ区切りで指定 ("512x512,384x384, etc, etc" など)', default="512x512,384x384,256x256,128x128") 108 | 
parser.add_argument('--divisible_by', type=int, 109 | help='Ensure new dimensions are divisible by this value / リサイズ後の画像のサイズをこの値で割り切れるようにします', default=1) 110 | parser.add_argument('--interpolation', type=str, choices=['area', 'cubic', 'lanczos4'], 111 | default='area', help='Interpolation method for resizing / リサイズ時の補間方法') 112 | parser.add_argument('--save_as_png', action='store_true', help='Save as png format / png形式で保存') 113 | parser.add_argument('--copy_associated_files', action='store_true', 114 | help='Copy files with same base name as images (captions etc) / 画像と同じファイル名(拡張子を除く)のファイルもコピーする') 115 | 116 | return parser 117 | 118 | 119 | def main(): 120 | parser = setup_parser() 121 | 122 | args = parser.parse_args() 123 | resize_images(args.src_img_folder, args.dst_img_folder, args.max_resolution, 124 | args.divisible_by, args.interpolation, args.save_as_png, args.copy_associated_files) 125 | 126 | 127 | if __name__ == '__main__': 128 | main() 129 | --------------------------------------------------------------------------------
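As a usage illustration for the script above, the resizing routine can also be driven programmatically. The sketch below calls `resize_images()` directly with the same options the CLI exposes; the directory paths are placeholders, and importing the module this way assumes the repository root is on PYTHONPATH.

```python
# Sketch only: programmatic use of tools/resize_images_to_resolution.py.
# Paths are placeholders; the import assumes the repo root is on PYTHONPATH.
from tools.resize_images_to_resolution import resize_images

resize_images(
    src_img_folder="train_data/raw",       # placeholder source folder
    dst_img_folder="train_data/resized",   # placeholder destination folder
    max_resolution="512x512,384x384",      # each image is saved once per listed resolution
    divisible_by=64,                       # crop so width/height are multiples of 64
    interpolation="lanczos4",              # 'area' (default), 'cubic' or 'lanczos4'
    save_as_png=True,                      # otherwise images are saved as .jpg
    copy_associated_files=True,            # copy captions etc. with a resolution suffix
)
```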