├── .gitignore ├── patches ├── shared.py ├── external_pr │ ├── sd_hijack_checkpoint.py │ ├── dadapt_test │ │ └── install.py │ ├── dataset.py │ ├── ui.py │ ├── textual_inversion.py │ └── hypernetwork.py ├── hnutil.py ├── textual_inversion.py ├── hashes_backup.py ├── ddpm_hijack.py ├── clip_hijack.py ├── tbutils.py ├── ui.py ├── scheduler.py ├── dataset.py ├── hypernetworks.py └── hypernetwork.py ├── README.md └── scripts └── hypernetwork-extensions.py /.gitignore: -------------------------------------------------------------------------------- 1 | 2 | *.pyc 3 | -------------------------------------------------------------------------------- /patches/shared.py: -------------------------------------------------------------------------------- 1 | 2 | from modules.shared import cmd_opts, opts 3 | import modules.shared 4 | 5 | version_flag = hasattr(modules.shared, 'loaded_hypernetwork') 6 | 7 | def reload_hypernetworks(): 8 | from .hypernetwork import list_hypernetworks, load_hypernetwork 9 | modules.shared.hypernetworks = list_hypernetworks(cmd_opts.hypernetwork_dir) 10 | if hasattr(modules.shared, 'loaded_hypernetwork'): 11 | load_hypernetwork(opts.sd_hypernetwork) 12 | 13 | 14 | try: 15 | modules.shared.reload_hypernetworks = reload_hypernetworks 16 | except: 17 | pass 18 | -------------------------------------------------------------------------------- /patches/external_pr/sd_hijack_checkpoint.py: -------------------------------------------------------------------------------- 1 | from torch.utils.checkpoint import checkpoint 2 | 3 | 4 | def BasicTransformerBlock_forward(self, x, context=None): 5 | return checkpoint(self._forward, x, context) 6 | 7 | def AttentionBlock_forward(self, x): 8 | return checkpoint(self._forward, x) 9 | 10 | def ResBlock_forward(self, x, emb): 11 | return checkpoint(self._forward, x, emb) 12 | 13 | 14 | try: 15 | import ldm.modules.attention 16 | import ldm.modules.diffusionmodules.model 17 | import ldm.modules.diffusionmodules.openaimodel 18 | ldm.modules.attention.BasicTransformerBlock.forward = BasicTransformerBlock_forward 19 | ldm.modules.diffusionmodules.openaimodel.ResBlock.forward = ResBlock_forward 20 | ldm.modules.diffusionmodules.openaimodel.AttentionBlock.forward = AttentionBlock_forward 21 | except: 22 | pass 23 | -------------------------------------------------------------------------------- /patches/external_pr/dadapt_test/install.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | def install_or_import() -> bool: 4 | try: 5 | import pip 6 | try: 7 | import dadaptation 8 | except (ModuleNotFoundError, ImportError): 9 | print("Trying to install dadaptation...") 10 | pip.main(['install', 'dadaptation']) 11 | return True 12 | except (ModuleNotFoundError, ImportError): 13 | print("Cannot found pip!") 14 | return False 15 | return True 16 | 17 | 18 | def get_dadapt_adam(optimizer_name=None): 19 | if install_or_import(): 20 | if optimizer_name is None or optimizer_name in ['DAdaptAdamW', 'AdamW', 'DAdaptAdam', 'Adam']: # Adam-dadapt implementation 21 | try: 22 | from dadaptation.dadapt_adam import DAdaptAdam 23 | return DAdaptAdam 24 | except (ModuleNotFoundError, ImportError): 25 | print('Cannot use DAdaptAdam!') 26 | elif optimizer_name == 'DAdaptSGD': 27 | try: 28 | from dadaptation.dadapt_sgd import DAdaptSGD 29 | return DAdaptSGD 30 | except (ModuleNotFoundError, ImportError): 31 | print('Cannot use DAdaptSGD!') 32 | elif optimizer_name == 'DAdaptAdagrad': 33 | try: 34 | from dadaptation.dadapt_adagrad 
import DAdaptAdaGrad 35 | return DAdaptAdaGrad 36 | except (ModuleNotFoundError, ImportError): 37 | print('Cannot use DAdaptAdaGrad!') 38 | from torch.optim import AdamW 39 | return AdamW 40 | -------------------------------------------------------------------------------- /patches/hnutil.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | import modules.shared 4 | 5 | 6 | def find_self(self): 7 | for k, v in modules.shared.hypernetworks.items(): 8 | if v == self: 9 | return k 10 | return None 11 | 12 | 13 | def optim_to(optim:torch.optim.Optimizer, device="cpu"): 14 | def inplace_move(obj: torch.Tensor, target): 15 | if hasattr(obj, 'data'): 16 | obj.data = obj.data.to(target) 17 | if hasattr(obj, '_grad') and obj._grad is not None: 18 | obj._grad.data = obj._grad.data.to(target) 19 | if isinstance(optim, torch.optim.Optimizer): 20 | for param in optim.state.values(): 21 | if isinstance(param, torch.Tensor): 22 | inplace_move(param, device) 23 | elif isinstance(param, dict): 24 | for subparams in param.values(): 25 | inplace_move(subparams, device) 26 | torch.cuda.empty_cache() 27 | 28 | 29 | def parse_dropout_structure(layer_structure, use_dropout, last_layer_dropout): 30 | if layer_structure is None: 31 | layer_structure = [1, 2, 1] 32 | if not use_dropout: 33 | return [0] * len(layer_structure) 34 | dropout_values = [0] 35 | dropout_values.extend([0.3] * (len(layer_structure) - 3)) 36 | if last_layer_dropout: 37 | dropout_values.append(0.3) 38 | else: 39 | dropout_values.append(0) 40 | dropout_values.append(0) 41 | return dropout_values 42 | 43 | 44 | def get_closest(val): 45 | i, j = divmod(val,64) 46 | return i*64 + (j!=0) * 64 -------------------------------------------------------------------------------- /patches/textual_inversion.py: -------------------------------------------------------------------------------- 1 | import csv 2 | import os 3 | 4 | import modules.textual_inversion.textual_inversion 5 | from modules import shared 6 | 7 | delayed_values = {} 8 | 9 | 10 | def write_loss(log_directory, filename, step, epoch_len, values): 11 | if shared.opts.training_write_csv_every == 0: 12 | return 13 | 14 | if (step + 1) % shared.opts.training_write_csv_every != 0: 15 | return 16 | write_csv_header = False if os.path.exists(os.path.join(log_directory, filename)) else True 17 | try: 18 | with open(os.path.join(log_directory, filename), "a+", newline='') as fout: 19 | csv_writer = csv.DictWriter(fout, fieldnames=["step", "epoch", "epoch_step", *(values.keys())]) 20 | 21 | if write_csv_header: 22 | csv_writer.writeheader() 23 | if log_directory + filename in delayed_values: 24 | delayed = delayed_values[log_directory + filename] 25 | for step, epoch, epoch_step, values in delayed: 26 | csv_writer.writerow({ 27 | "step": step, 28 | "epoch": epoch, 29 | "epoch_step": epoch_step + 1, 30 | **values, 31 | }) 32 | delayed.clear() 33 | epoch = step // epoch_len 34 | epoch_step = step % epoch_len 35 | csv_writer.writerow({ 36 | "step": step + 1, 37 | "epoch": epoch, 38 | "epoch_step": epoch_step + 1, 39 | **values, 40 | }) 41 | except OSError: 42 | epoch, epoch_step = divmod(step, epoch_len) 43 | if log_directory + filename in delayed_values: 44 | delayed_values[log_directory + filename].append((step + 1, epoch, epoch_step, values)) 45 | else: 46 | delayed_values[log_directory + filename] = [(step+1, epoch, epoch_step, values)] 47 | 48 | modules.textual_inversion.textual_inversion.write_loss = write_loss 
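# Note on the patched write_loss above (illustrative summary, not original code): when the CSV
# file cannot be written (OSError), the row is buffered in delayed_values keyed by
# log_directory + filename, and flushed before the next row that does get written. For example,
# with epoch_len=100 and step=149, divmod(step, epoch_len) gives epoch=1, epoch_step=49, so the
# buffered tuple is (150, 1, 49, values).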
-------------------------------------------------------------------------------- /patches/hashes_backup.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | import json 3 | import os.path 4 | 5 | import filelock 6 | 7 | # This is full copy of modules/hashes. This will be only loaded if compatibility issue happens due to version mismatch. 8 | cache_filename = "cache.json" 9 | cache_data = None 10 | blksize = 1 << 20 11 | 12 | def dump_cache(): 13 | with filelock.FileLock(cache_filename+".lock"): 14 | with open(cache_filename, "w", encoding="utf8") as file: 15 | json.dump(cache_data, file, indent=4) 16 | 17 | 18 | def cache(subsection): 19 | global cache_data 20 | 21 | if cache_data is None: 22 | with filelock.FileLock(cache_filename+".lock"): 23 | if not os.path.isfile(cache_filename): 24 | cache_data = {} 25 | else: 26 | with open(cache_filename, "r", encoding="utf8") as file: 27 | cache_data = json.load(file) 28 | 29 | s = cache_data.get(subsection, {}) 30 | cache_data[subsection] = s 31 | 32 | return s 33 | 34 | 35 | def calculate_sha256(filename): 36 | hash_sha256 = hashlib.sha256() 37 | global blksize 38 | with open(filename, "rb") as f: 39 | for chunk in iter(lambda: f.read(blksize), b""): 40 | hash_sha256.update(chunk) 41 | 42 | return hash_sha256.hexdigest() 43 | 44 | 45 | def sha256_from_cache(filename, title): 46 | hashes = cache("hashes") 47 | ondisk_mtime = os.path.getmtime(filename) 48 | 49 | if title not in hashes: 50 | return None 51 | 52 | cached_sha256 = hashes[title].get("sha256", None) 53 | cached_mtime = hashes[title].get("mtime", 0) 54 | 55 | if ondisk_mtime > cached_mtime or cached_sha256 is None: 56 | return None 57 | 58 | return cached_sha256 59 | 60 | 61 | def sha256(filename, title): 62 | hashes = cache("hashes") 63 | 64 | sha256_value = sha256_from_cache(filename, title) 65 | if sha256_value is not None: 66 | return sha256_value 67 | 68 | print(f"Calculating sha256 for {filename}: ", end='') 69 | sha256_value = calculate_sha256(filename) 70 | print(f"{sha256_value}") 71 | 72 | hashes[title] = { 73 | "mtime": os.path.getmtime(filename), 74 | "sha256": sha256_value, 75 | } 76 | 77 | dump_cache() 78 | 79 | return sha256_value 80 | 81 | 82 | 83 | 84 | 85 | -------------------------------------------------------------------------------- /patches/ddpm_hijack.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import ldm.models.diffusion.ddpm 3 | from modules import shared 4 | 5 | 6 | class Scheduler: 7 | """ Proportional Noise Step Scheduler""" 8 | def __init__(self, cycle_step=128, repeat=True): 9 | self.disabled = True 10 | self.cycle_step = int(cycle_step) 11 | self.repeat = repeat 12 | self.run_assertion() 13 | 14 | def __call__(self, value, step): 15 | if self.disabled: 16 | return value 17 | if self.repeat: 18 | step %= self.cycle_step 19 | return max(1, int(value * step / self.cycle_step)) 20 | else: 21 | return value if step >= self.cycle_step else max(1, int(value * step / self.cycle_step)) 22 | 23 | def run_assertion(self): 24 | assert type(self.cycle_step) is int 25 | assert type(self.repeat) is bool 26 | assert not self.repeat or self.cycle_step > 0 27 | 28 | def set(self, cycle_step=-1, repeat=-1, disabled=True): 29 | self.disabled = disabled 30 | if cycle_step >= 0: 31 | self.cycle_step = int(cycle_step) 32 | if repeat != -1: 33 | self.repeat = repeat 34 | self.run_assertion() 35 | 36 | 37 | training_scheduler = Scheduler(cycle_step=-1, 
repeat=False) 38 | 39 | 40 | def get_current(value, step=None): 41 | if step is None: 42 | if hasattr(shared, 'accessible_hypernetwork'): 43 | hypernetwork = shared.accessible_hypernetwork 44 | else: 45 | return value 46 | if hasattr(hypernetwork, 'step') and hypernetwork.training and hypernetwork.step is not None: 47 | return training_scheduler(value, hypernetwork.step) 48 | return value 49 | return max(1, training_scheduler(value, step)) 50 | 51 | 52 | def set_scheduler(cycle_step, repeat, enabled=False): 53 | global training_scheduler 54 | training_scheduler.set(cycle_step, repeat, not enabled) 55 | 56 | 57 | def forward(self, x, c, *args, **kwargs): 58 | t = torch.randint(0, get_current(self.num_timesteps), (x.shape[0],), device=self.device).long() 59 | if self.model.conditioning_key is not None: 60 | assert c is not None 61 | if self.cond_stage_trainable: 62 | c = self.get_learned_conditioning(c) 63 | if self.shorten_cond_schedule: # TODO: drop this option 64 | tc = self.cond_ids[t].to(self.device) 65 | c = self.q_sample(x_start=c, t=tc, noise=torch.randn_like(c.float())) 66 | return self.p_losses(x, c, t, *args, **kwargs) 67 | 68 | 69 | 70 | 71 | ldm.models.diffusion.ddpm.LatentDiffusion.forward = forward 72 | -------------------------------------------------------------------------------- /patches/clip_hijack.py: -------------------------------------------------------------------------------- 1 | from modules import sd_hijack_clip, sd_hijack, shared 2 | from modules.sd_hijack import StableDiffusionModelHijack, EmbeddingsWithFixes, apply_optimizations 3 | try: 4 | from modules.sd_hijack import fix_checkpoint 5 | def clear_any_hijacks(): 6 | StableDiffusionModelHijack.hijack = default_hijack 7 | except (ModuleNotFoundError, ImportError): 8 | from modules.sd_hijack_checkpoint import add, remove 9 | def fix_checkpoint(): 10 | add() 11 | 12 | def clear_any_hijacks(): 13 | remove() 14 | StableDiffusionModelHijack.hijack = default_hijack 15 | 16 | 17 | import ldm.modules.encoders.modules 18 | 19 | default_hijack = StableDiffusionModelHijack.hijack 20 | 21 | def trigger_sd_hijack(enabled, pretrained_key): 22 | clear_any_hijacks() 23 | if not enabled or pretrained_key == '': 24 | pretrained_key = 'openai/clip-vit-large-patch14' 25 | StableDiffusionModelHijack.hijack = create_lambda(pretrained_key) 26 | print("Hijacked clip text model!") 27 | sd_hijack.model_hijack.undo_hijack(shared.sd_model) 28 | sd_hijack.model_hijack.hijack(shared.sd_model) 29 | if not enabled: 30 | StableDiffusionModelHijack.hijack = default_hijack 31 | 32 | 33 | 34 | 35 | def create_lambda(model): 36 | def hijack_lambda(self, m): 37 | if type(m.cond_stage_model) == ldm.modules.encoders.modules.FrozenCLIPEmbedder: 38 | from transformers import CLIPTextModel, CLIPTokenizer 39 | print(f"Changing CLIP model to {model}") 40 | try: 41 | m.cond_stage_model.transformer = CLIPTextModel.from_pretrained( 42 | model).to(m.cond_stage_model.transformer.device) 43 | m.cond_stage_model.transformer.requires_grad_(False) 44 | m.cond_stage_model.tokenizer = CLIPTokenizer.from_pretrained( 45 | model) 46 | except: 47 | print(f"Cannot initiate from given model key {model}!") 48 | 49 | model_embeddings = m.cond_stage_model.transformer.text_model.embeddings 50 | model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self) 51 | m.cond_stage_model = sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords(m.cond_stage_model, self) 52 | 53 | self.optimization_method = apply_optimizations() 54 | 55 | self.clip = 
m.cond_stage_model 56 | 57 | fix_checkpoint() 58 | 59 | 60 | def flatten(el): 61 | flattened = [flatten(children) for children in el.children()] 62 | res = [el] 63 | for c in flattened: 64 | res += c 65 | return res 66 | 67 | self.layers = flatten(m) 68 | else: 69 | print("CLIP change can be only applied to FrozenCLIPEmbedder class") 70 | return default_hijack(self, m) 71 | return hijack_lambda 72 | -------------------------------------------------------------------------------- /patches/tbutils.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | import torch 5 | from torch.utils.tensorboard import SummaryWriter 6 | 7 | from modules import shared 8 | 9 | 10 | def tensorboard_setup(log_directory): 11 | os.makedirs(os.path.join(log_directory, "tensorboard"), exist_ok=True) 12 | return SummaryWriter( 13 | log_dir=os.path.join(log_directory, "tensorboard"), 14 | flush_secs=shared.opts.training_tensorboard_flush_every) 15 | 16 | def tensorboard_log_hyperparameter(tensorboard_writer:SummaryWriter, **kwargs): 17 | for keys in kwargs: 18 | if type(kwargs[keys]) not in [bool, str, float, int,None]: 19 | kwargs[keys] = str(kwargs[keys]) 20 | tensorboard_writer.add_hparams({ 21 | 'lr' : kwargs.get('lr', 0.01), 22 | 'GA steps' : kwargs.get('GA_steps', 1), 23 | 'bsize' : kwargs.get('batch_size', 1), 24 | 'layer structure' : kwargs.get('layer_structure', '1,2,1'), 25 | 'activation' : kwargs.get('activation', 'Linear'), 26 | 'weight_init' : kwargs.get('weight_init', 'Normal'), 27 | 'dropout_structure' : kwargs.get('dropout_structure', '0,0,0'), 28 | 'steps' : kwargs.get('max_steps', 10000), 29 | 'latent sampling': kwargs.get('latent_sampling_method', 'once'), 30 | 'template file': kwargs.get('template', 'nothing'), 31 | 'CosineAnnealing' : kwargs.get('CosineAnnealing', False), 32 | 'beta_repeat epoch': kwargs.get('beta_repeat_epoch', 0), 33 | 'epoch_mult':kwargs.get('epoch_mult', 1), 34 | 'warmup_step' : kwargs.get('warmup', 5), 35 | 'min_lr' : kwargs.get('min_lr', 6e-7), 36 | 'decay' : kwargs.get('gamma_rate', 1), 37 | 'adamW' : kwargs.get('adamW_opts', False), 38 | 'adamW_decay' : kwargs.get('adamW_decay', 0.01), 39 | 'adamW_beta1' : kwargs.get('adamW_beta_1', 0.9), 40 | 'adamW_beta2': kwargs.get('adamW_beta_2', 0.99), 41 | 'adamW_eps': kwargs.get('adamW_eps', 1e-8), 42 | 'gradient_clip_opt':kwargs.get('gradient_clip', 'None'), 43 | 'gradient_clip_value' : kwargs.get('gradient_clip_value', 1e-1), 44 | 'gradient_clip_norm' : kwargs.get('gradient_clip_norm_type', 2) 45 | }, 46 | {'hparam/loss' : kwargs.get('loss', 0.0)} 47 | ) 48 | def tensorboard_add(tensorboard_writer:SummaryWriter, loss, global_step, step, learn_rate, epoch_num, base_name=""): 49 | prefix = base_name + "/" if base_name else "" 50 | tensorboard_add_scaler(tensorboard_writer, prefix+"Loss/train", loss, global_step) 51 | tensorboard_add_scaler(tensorboard_writer, prefix+f"Loss/train/epoch-{epoch_num}", loss, step) 52 | tensorboard_add_scaler(tensorboard_writer, prefix+"Learn rate/train", learn_rate, global_step) 53 | tensorboard_add_scaler(tensorboard_writer, prefix+f"Learn rate/train/epoch-{epoch_num}", learn_rate, step) 54 | 55 | 56 | def tensorboard_add_scaler(tensorboard_writer:SummaryWriter, tag, value, step): 57 | tensorboard_writer.add_scalar(tag=tag, 58 | scalar_value=value, global_step=step) 59 | 60 | 61 | def tensorboard_add_image(tensorboard_writer:SummaryWriter, tag, pil_image, step, base_name=""): 62 | # Convert a pil image to a torch tensor 63 | 
prefix = base_name + "/" if base_name else "" 64 | img_tensor = torch.as_tensor(np.array(pil_image, copy=True)) 65 | img_tensor = img_tensor.view(pil_image.size[1], pil_image.size[0], 66 | len(pil_image.getbands())) 67 | img_tensor = img_tensor.permute((2, 0, 1)) 68 | 69 | tensorboard_writer.add_image(prefix+tag, img_tensor, global_step=step) 70 | -------------------------------------------------------------------------------- /patches/ui.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from modules import shared 4 | from .hypernetwork import Hypernetwork, load_hypernetwork 5 | 6 | 7 | def create_hypernetwork_load(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False, dropout_structure=None, optional_info=None, 8 | weight_init_seed=None, normal_std=0.01, skip_connection=False): 9 | # Remove illegal characters from name. 10 | name = "".join( x for x in name if (x.isalnum() or x in "._- ")) 11 | assert name, "Name cannot be empty!" 12 | fn = os.path.join(shared.cmd_opts.hypernetwork_dir, f"{name}.pt") 13 | if not overwrite_old: 14 | assert not os.path.exists(fn), f"file {fn} already exists" 15 | 16 | if type(layer_structure) == str: 17 | layer_structure = [float(x.strip()) for x in layer_structure.split(",")] 18 | 19 | if dropout_structure and type(dropout_structure) == str: 20 | dropout_structure = [float(x.strip()) for x in dropout_structure.split(",")] 21 | normal_std = float(normal_std) 22 | assert normal_std > 0, "Normal Standard Deviation should be bigger than 0!" 23 | hypernet = Hypernetwork( 24 | name=name, 25 | enable_sizes=[int(x) for x in enable_sizes], 26 | layer_structure=layer_structure, 27 | activation_func=activation_func, 28 | weight_init=weight_init, 29 | add_layer_norm=add_layer_norm, 30 | use_dropout=use_dropout, 31 | dropout_structure=dropout_structure if use_dropout and dropout_structure else [0] * len(layer_structure), 32 | optional_info=optional_info, 33 | generation_seed=weight_init_seed if weight_init_seed != -1 else None, 34 | normal_std=normal_std, 35 | skip_connection=skip_connection 36 | ) 37 | hypernet.save(fn) 38 | shared.reload_hypernetworks() 39 | hypernet = load_hypernetwork(name) 40 | assert hypernet is not None, f"Cannot load from {name}!" 41 | return hypernet 42 | 43 | 44 | def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False, dropout_structure=None, optional_info=None, 45 | weight_init_seed=None, normal_std=0.01, skip_connection=False): 46 | # Remove illegal characters from name. 47 | name = "".join( x for x in name if (x.isalnum() or x in "._- ")) 48 | assert name, "Name cannot be empty!" 49 | fn = os.path.join(shared.cmd_opts.hypernetwork_dir, f"{name}.pt") 50 | if not overwrite_old: 51 | assert not os.path.exists(fn), f"file {fn} already exists" 52 | 53 | if type(layer_structure) == str: 54 | layer_structure = [float(x.strip()) for x in layer_structure.split(",")] 55 | 56 | if dropout_structure and type(dropout_structure) == str: 57 | dropout_structure = [float(x.strip()) for x in dropout_structure.split(",")] 58 | normal_std = float(normal_std) 59 | assert normal_std >= 0, "Normal Standard Deviation should be bigger than 0!" 
60 | hypernet = Hypernetwork(
61 | name=name,
62 | enable_sizes=[int(x) for x in enable_sizes],
63 | layer_structure=layer_structure,
64 | activation_func=activation_func,
65 | weight_init=weight_init,
66 | add_layer_norm=add_layer_norm,
67 | use_dropout=use_dropout,
68 | dropout_structure=dropout_structure if use_dropout and dropout_structure else [0] * len(layer_structure),
69 | optional_info=optional_info,
70 | generation_seed=weight_init_seed if weight_init_seed != -1 else None,
71 | normal_std=normal_std,
72 | skip_connection=skip_connection
73 | )
74 | hypernet.save(fn)
75 | 
76 | shared.reload_hypernetworks()
77 | 
78 | return name, f"Created: {fn}", ""
79 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Hypernetwork-MonkeyPatch-Extension
2 | Extension that patches Hypernetwork structures and training
3 | ![image](https://user-images.githubusercontent.com/35677394/210898033-44da3cdb-a501-4cb3-a176-07ff8548d699.png)
4 | 
5 | ![image](https://user-images.githubusercontent.com/35677394/203494809-9874c123-fca7-4d14-9995-63dc8772c920.png)
6 | 
7 | For Hypernetwork structure, see https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/4334
8 | 
9 | For Variable Dropout, see https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/4288
10 | 
11 | 
12 | ### The Train_Beta (now train_gamma) tab allows some more options with improved training.
13 | ![image](https://user-images.githubusercontent.com/35677394/203494907-68e0ef39-4d8c-42de-ba2e-65590375c435.png)
14 | 
15 | ### Features
16 | 
17 | **No-Crop Training**
18 | ![image](https://user-images.githubusercontent.com/35677394/203495373-cef04677-cdd6-43b0-ba42-d7c0f3d5a78f.png)
19 | You can train without cropping images.
20 | This feature is now implemented in the original webui too! :partying_face:
21 | 
22 | **Fix OSError while training**
23 | 
24 | **Unload Optimizer while generating Previews**
25 | 
26 | **Tensorboard integration and Tuning**
27 | 
28 | **Residual-Block based Hypernetwork (in beta test)**
29 | 
30 | 
31 | ### Create_Beta_hypernetwork allows creating beta hypernetworks.
32 | 
33 | Beta hypernetworks can contain additional information and a specified dropout structure. They will load without the extension too, but the dropout structure won't be loaded, so training won't behave as intended. Generating images should work identically.
34 | 
35 | This extension also overrides how the webui loads and finds hypernetworks, in order to use variable dropout rates and other additions.
36 | Thus, a hypernetwork created with a variable dropout rate might not work correctly in the original webui.
37 | 
38 | At this point it should load without problems, except that variable dropout rates cannot be used in the original webui. If you have a problem loading hypernetworks, please create an issue; I can submit a PR to the original branch so it loads these beta-type hypernetworks correctly.
39 | 
40 | ### Training features are in the train_gamma tab.
41 | ![image](https://user-images.githubusercontent.com/35677394/204087550-94b8e7fb-70cb-4157-96bc-e022340901c9.png)
42 | 
43 | If you're unsure about the options, just enable every checkbox and keep the default values.
44 | 
45 | 
46 | ### CosineAnnealingWarmupRestarts
47 | ![image](https://user-images.githubusercontent.com/35677394/204087530-b7938e7e-ebe5-4326-b5cd-25480645a11b.png)
48 | 
49 | This also fixes some CUDA memory issues. Currently both Beta and Gamma training are working very well, as far as I can tell.
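A minimal usage sketch of the scheduler behind this feature (`patches/scheduler.py`). This is illustrative only: the toy model, the hyperparameter values, and the `from patches.scheduler import ...` path are assumptions, not code taken from the extension's UI flow.

```python
# Illustrative wiring of CosineAnnealingWarmUpRestarts into a plain training loop.
import torch
from patches.scheduler import CosineAnnealingWarmUpRestarts  # assumes this repo is importable

model = torch.nn.Linear(8, 8)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

scheduler = CosineAnnealingWarmUpRestarts(
    optimizer,
    first_cycle_steps=1000,  # length of the first cosine cycle
    cycle_mult=1.0,          # keep every cycle the same length
    max_lr=1e-3,             # peak LR reached right after warmup
    min_lr=1e-6,             # LR at the start and end of each cycle
    warmup_steps=50,         # linear warmup from min_lr to max_lr
    gamma=0.9,               # shrink the peak LR by 10% on every restart
)

for step in range(3000):
    optimizer.zero_grad()
    loss = model(torch.randn(4, 8)).pow(2).mean()  # dummy loss
    loss.backward()
    optimizer.step()
    scheduler.step()         # one scheduler step per optimizer step
```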
50 | 
51 | 
52 | ### Hyperparameter Tuning
53 | ![image](https://user-images.githubusercontent.com/35677394/212574147-22a32b03-6544-4aee-9ac7-fdefd2b7ee56.png)
54 | Now you can save hypernetwork generation / training settings and load them in the train_tuning tab. This allows combining hypernetwork structures and training setups to find out what works best.
55 | 
56 | ### CLIP change test tab
57 | ![image](https://user-images.githubusercontent.com/35677394/212574217-3dd08007-e33f-4179-96e9-5a90bccd4907.png)
58 | Now you can select the CLIP model. The difference is significant, but whether it is better or not is unknown.
59 | 
60 | 
61 | ## Residual hypernetwork?
62 | The ResNet concept of returning x + f(x) from a layer instead of f(x) is available as an option. The original webui does not support this, so you cannot load such a hypernetwork without the extension.
63 | Unlike an expanding structure (1 -> 2 -> 1), a shrinking structure (1 -> 0.1 -> 1) loses information in the initial phase. In this case we additionally need to train a transformation that compresses and decompresses the input. This is currently only in code; it is not offered in the UI by default.
64 | 
65 | ## D-Adaptation
66 | D-Adaptation is currently available for hypernetwork training. You can use it by enabling the advanced AdamW parameter option and checking the checkbox.
67 | The recommended LR is 1.0; only change it if required. Other features have not been tested in combination with this one.
68 | The code is based on:
69 | https://github.com/facebookresearch/dadaptation
70 | 
71 | ### Planned features
72 | Training option loading and tuning for textual inversion
73 | 
74 | D-Adaptation for textual inversion
75 | 
76 | Adan and more optimizer options
77 | 
78 | D-Adaptation repository update matching
79 | 
80 | 
81 | ### Some personal research
82 | 
83 | We cannot apply convolution to attention. It does do something, but the hypernetwork here only affects attention in latent space, which is different from an 'attention map', i.e. an already decoded (image BW vector) form of attention. The same goes for SENet, unfortunately.
84 | 
--------------------------------------------------------------------------------
/patches/scheduler.py:
--------------------------------------------------------------------------------
1 | import math
2 | from torch.optim.lr_scheduler import _LRScheduler
3 | 
4 | 
5 | class CosineAnnealingWarmUpRestarts(_LRScheduler):
6 | # see https://github.com/katsura-jp/pytorch-cosine-annealing-with-warmup
7 | """
8 | optimizer (Optimizer): Wrapped optimizer.
9 | first_cycle_steps (int): First cycle step size.
10 | cycle_mult(float): Cycle steps magnification. Default: 1.
11 | max_lr(float): First cycle's max learning rate. Default: 0.1.
12 | min_lr(float): Min learning rate. Default: 0.001.
13 | warmup_steps(int): Linear warmup step size. Default: 0.
14 | gamma(float): Decrease rate of max learning rate by cycle. Default: 1.
15 | last_epoch (int): The index of last epoch. Default: -1.
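Illustrative example (not part of the original docstring): with first_cycle_steps=1000, warmup_steps=50, max_lr=1e-3, min_lr=1e-6 and cycle_mult=1.0, the learning rate rises linearly from 1e-6 to 1e-3 over the first 50 steps, follows a cosine curve back down to 1e-6 by step 1000, and then restarts; with gamma=0.9 each new cycle peaks at 90% of the previous cycle's max_lr.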
16 | """ 17 | 18 | def __init__(self, 19 | optimizer, 20 | first_cycle_steps: int, 21 | cycle_mult: float = 1., 22 | max_lr: float = 0.1, 23 | min_lr: float = 0.001, 24 | warmup_steps: int = 0, 25 | gamma: float = 1., 26 | last_epoch: int = -1 27 | ): 28 | assert warmup_steps < first_cycle_steps 29 | 30 | self.first_cycle_steps = first_cycle_steps # first cycle step size 31 | self.cycle_mult = cycle_mult # cycle steps magnification 32 | self.base_max_lr = max_lr # first max learning rate 33 | self.max_lr = max_lr # max learning rate in the current cycle 34 | self.min_lr = min_lr # min learning rate 35 | self.warmup_steps = warmup_steps # warmup step size 36 | self.gamma = gamma # decrease rate of max learning rate by cycle 37 | 38 | self.cur_cycle_steps = first_cycle_steps # first cycle step size 39 | self.cycle = 0 # cycle count 40 | self.step_in_cycle = last_epoch # step size of the current cycle 41 | 42 | super(CosineAnnealingWarmUpRestarts, self).__init__(optimizer, last_epoch) 43 | 44 | # set learning rate min_lr 45 | self.init_lr() 46 | 47 | def init_lr(self): 48 | self.base_lrs = [] 49 | for param_group in self.optimizer.param_groups: 50 | param_group['lr'] = self.min_lr 51 | self.base_lrs.append(self.min_lr) 52 | 53 | def get_lr(self): 54 | if self.step_in_cycle == -1: 55 | return self.base_lrs 56 | elif self.step_in_cycle < self.warmup_steps: 57 | return [(self.max_lr - base_lr) * self.step_in_cycle / self.warmup_steps + base_lr for base_lr in 58 | self.base_lrs] 59 | else: 60 | return [base_lr + (self.max_lr - base_lr) \ 61 | * (1 + math.cos(math.pi * (self.step_in_cycle - self.warmup_steps) / (self.cur_cycle_steps - self.warmup_steps))) / 2 62 | for base_lr in self.base_lrs] 63 | 64 | def step(self, epoch=None): 65 | if epoch is None: 66 | epoch = self.last_epoch + 1 67 | self.step_in_cycle = self.step_in_cycle + 1 68 | if self.step_in_cycle >= self.cur_cycle_steps: 69 | self.cycle += 1 70 | self.step_in_cycle = self.step_in_cycle - self.cur_cycle_steps 71 | self.cur_cycle_steps = int( 72 | (self.cur_cycle_steps - self.warmup_steps) * self.cycle_mult) + self.warmup_steps 73 | else: 74 | if epoch >= self.first_cycle_steps: 75 | if self.cycle_mult == 1.: 76 | self.step_in_cycle = epoch % self.first_cycle_steps 77 | self.cycle = epoch // self.first_cycle_steps 78 | else: 79 | n = int(math.log((epoch / self.first_cycle_steps * (self.cycle_mult - 1) + 1), self.cycle_mult)) 80 | self.cycle = n 81 | self.step_in_cycle = epoch - int( 82 | self.first_cycle_steps * (self.cycle_mult ** n - 1) / (self.cycle_mult - 1)) 83 | self.cur_cycle_steps = self.first_cycle_steps * self.cycle_mult ** (n) 84 | else: 85 | self.cur_cycle_steps = self.first_cycle_steps 86 | self.step_in_cycle = epoch 87 | 88 | self.max_lr = self.base_max_lr * (self.gamma ** self.cycle) 89 | self.last_epoch = math.floor(epoch) 90 | for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()): 91 | param_group['lr'] = lr 92 | 93 | def is_EOC(self, epoch=None): 94 | saved_cycle = self.cycle 95 | expect_cycle = saved_cycle 96 | step_in_cycle_2 = self.step_in_cycle 97 | cur_cycle_step_2 = self.cur_cycle_steps 98 | if epoch is None: 99 | step_in_cycle_2 = step_in_cycle_2 + 1 100 | if step_in_cycle_2 >= cur_cycle_step_2: 101 | expect_cycle += 1 102 | else: 103 | if epoch >= self.first_cycle_steps: 104 | if self.cycle_mult == 1.: 105 | expect_cycle = epoch // self.first_cycle_steps 106 | else: 107 | n = int(math.log((epoch / self.first_cycle_steps * (self.cycle_mult - 1) + 1), self.cycle_mult)) 108 | expect_cycle = n 109 
| ''' returns if current cycle is end of cycle''' 110 | return expect_cycle > saved_cycle 111 | -------------------------------------------------------------------------------- /patches/dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import re 4 | 5 | import PIL 6 | import torch 7 | import tqdm 8 | import numpy as np 9 | from PIL import Image 10 | from .hnutil import get_closest 11 | from torch.utils.data import Dataset 12 | from torchvision import transforms 13 | 14 | from modules import shared, devices 15 | from modules.textual_inversion.dataset import DatasetEntry, re_numbers_at_start 16 | 17 | 18 | class PersonalizedBase(Dataset): 19 | def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, device=None, template_file=None, include_cond=False, batch_size=1): 20 | re_word = re.compile(shared.opts.dataset_filename_word_regex) if len(shared.opts.dataset_filename_word_regex) > 0 else None 21 | 22 | self.placeholder_token = placeholder_token 23 | 24 | self.batch_size = batch_size 25 | self.width = width 26 | self.height = height 27 | self.flip = transforms.RandomHorizontalFlip(p=flip_p) 28 | 29 | self.dataset = [] 30 | 31 | with open(template_file, "r") as file: 32 | lines = [x.strip() for x in file.readlines()] 33 | 34 | self.lines = lines 35 | 36 | assert data_root, 'dataset directory not specified' 37 | assert os.path.isdir(data_root), "Dataset directory doesn't exist" 38 | assert os.listdir(data_root), "Dataset directory is empty" 39 | 40 | cond_model = shared.sd_model.cond_stage_model 41 | 42 | self.image_paths = [os.path.join(data_root, file_path) for file_path in os.listdir(data_root)] * batch_size 43 | print("Preparing dataset...") 44 | for path in tqdm.tqdm(self.image_paths): 45 | try: 46 | image = Image.open(path).convert('RGB') 47 | w, h = image.size 48 | r = max(1, w / self.width, h / self.height) # divide by this 49 | amp = min(self.width / w, self.height / h) # if amp < 1, then ignore, else, multiply. 
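# Worked example of the resizing above (illustrative): with width=height=512 and a 768x1024 source,
# r = max(1, 1.5, 2.0) = 2.0 and amp = min(0.667, 0.5) = 0.5, so the upscale branch below is skipped;
# w, h become 384, 512 after dividing by r, and get_closest() keeps 384x512 since both are already
# multiples of 64.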
50 | if amp > 1: 51 | w, h = w * amp, h * amp 52 | w, h = int(w/r), int(h/r) 53 | w, h = get_closest(w), get_closest(h) 54 | image = image.resize((w,h), PIL.Image.LANCZOS) 55 | 56 | except Exception: 57 | continue 58 | 59 | text_filename = os.path.splitext(path)[0] + ".txt" 60 | filename = os.path.basename(path) 61 | 62 | if os.path.exists(text_filename): 63 | with open(text_filename, "r", encoding="utf8") as file: 64 | filename_text = file.read() 65 | else: 66 | filename_text = os.path.splitext(filename)[0] 67 | filename_text = re.sub(re_numbers_at_start, '', filename_text) 68 | if re_word: 69 | tokens = re_word.findall(filename_text) 70 | filename_text = (shared.opts.dataset_filename_join_string or "").join(tokens) 71 | 72 | npimage = np.array(image).astype(np.uint8) 73 | npimage = (npimage / 127.5 - 1.0).astype(np.float32) 74 | 75 | torchdata = torch.from_numpy(npimage).to(device=device, dtype=torch.float32) 76 | torchdata = torch.moveaxis(torchdata, 2, 0) 77 | 78 | init_latent = model.get_first_stage_encoding(model.encode_first_stage(torchdata.unsqueeze(dim=0))).squeeze() 79 | init_latent = init_latent.to(devices.cpu) 80 | 81 | entry = DatasetEntry(filename=path, filename_text=filename_text, latent=init_latent) 82 | 83 | if include_cond: 84 | entry.cond_text = self.create_text(filename_text) 85 | entry.cond = cond_model([entry.cond_text]).to(devices.cpu).squeeze(0) 86 | 87 | self.dataset.append(entry) 88 | 89 | assert len(self.dataset) > 0, "No images have been found in the dataset." 90 | self.length = len(self.dataset) * repeats // batch_size 91 | 92 | self.dataset_length = len(self.dataset) 93 | self.indexes = None 94 | self.random = np.random.default_rng(42) 95 | self.shuffle() 96 | 97 | def shuffle(self): 98 | self.indexes = self.random.permutation(self.dataset_length) 99 | 100 | def create_text(self, filename_text): 101 | text = random.choice(self.lines) 102 | text = text.replace("[name]", self.placeholder_token) 103 | tags = filename_text.split(',') 104 | if shared.opts.tag_drop_out != 0: 105 | tags = [t for t in tags if random.random() > shared.opts.tag_drop_out] 106 | if shared.opts.shuffle_tags: 107 | random.shuffle(tags) 108 | text = text.replace("[filewords]", ','.join(tags)) 109 | return text 110 | 111 | def __len__(self): 112 | return self.length 113 | 114 | def __getitem__(self, i): 115 | res = [] 116 | 117 | for j in range(self.batch_size): 118 | position = i * self.batch_size + j 119 | if position % len(self.indexes) == 0: 120 | self.shuffle() 121 | 122 | index = self.indexes[position % len(self.indexes)] 123 | entry = self.dataset[index] 124 | 125 | if entry.cond is None: 126 | entry.cond_text = self.create_text(entry.filename_text) 127 | 128 | res.append(entry) 129 | 130 | return res -------------------------------------------------------------------------------- /scripts/hypernetwork-extensions.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from modules.call_queue import wrap_gradio_call 4 | from modules.hypernetworks.ui import keys 5 | import modules.scripts as scripts 6 | from modules import script_callbacks, shared 7 | import gradio as gr 8 | 9 | from modules.ui import gr_show 10 | import patches.clip_hijack as clip_hijack 11 | import patches.textual_inversion as textual_inversion 12 | import patches.ui as ui 13 | import patches.shared as shared_patch 14 | import patches.external_pr.ui as external_patch_ui 15 | 16 | setattr(shared.opts,'pin_memory', False) 17 | 18 | def create_extension_tab(params=None): 
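# Builds the "Create Beta hypernetwork" tab; it is registered with the webui via
# script_callbacks.on_ui_train_tabs(create_extension_tab) near the bottom of this script.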
19 | with gr.Tab(label="Create Beta hypernetwork") as create_beta: 20 | new_hypernetwork_name = gr.Textbox(label="Name") 21 | new_hypernetwork_sizes = gr.CheckboxGroup(label="Modules", value=["768", "320", "640", "1024", "1280"], 22 | choices=["768", "320", "640", "1024", "1280"]) 23 | new_hypernetwork_layer_structure = gr.Textbox("1, 2, 1", label="Enter hypernetwork layer structure", 24 | placeholder="1st and last digit must be 1. ex:'1, 2, 1'") 25 | new_hypernetwork_activation_func = gr.Dropdown(value="linear", 26 | label="Select activation function of hypernetwork. Recommended : Swish / Linear(none)", 27 | choices=keys) 28 | new_hypernetwork_initialization_option = gr.Dropdown(value="Normal", 29 | label="Select Layer weights initialization. Recommended: Kaiming for relu-like, Xavier for sigmoid-like, Normal otherwise", 30 | choices=["Normal", "KaimingUniform", "KaimingNormal", 31 | "XavierUniform", "XavierNormal"]) 32 | show_additional_options = gr.Checkbox( 33 | label='Show advanced options') 34 | with gr.Row(visible=False) as weight_options: 35 | generation_seed = gr.Number(label='Weight initialization seed, set -1 for default', value=-1, precision=0) 36 | normal_std = gr.Textbox(label="Standard Deviation for Normal weight initialization", placeholder="must be positive float", value="0.01") 37 | show_additional_options.change( 38 | fn=lambda show: gr_show(show), 39 | inputs=[show_additional_options], 40 | outputs=[weight_options],) 41 | new_hypernetwork_add_layer_norm = gr.Checkbox(label="Add layer normalization") 42 | new_hypernetwork_use_dropout = gr.Checkbox( 43 | label="Use dropout. Might improve training when dataset is small / limited.") 44 | new_hypernetwork_dropout_structure = gr.Textbox("0, 0, 0", 45 | label="Enter hypernetwork Dropout structure (or empty). Recommended : 0~0.35 incrementing sequence: 0, 0.05, 0.15", 46 | placeholder="1st and last digit must be 0 and values should be between 0 and 1. ex:'0, 0.01, 0'") 47 | skip_connection = gr.Checkbox(label="Use skip-connection. 
Won't work without extension!") 48 | optional_info = gr.Textbox("", label="Optional information about Hypernetwork", placeholder="Training information, dateset, etc") 49 | overwrite_old_hypernetwork = gr.Checkbox(value=False, label="Overwrite Old Hypernetwork") 50 | 51 | with gr.Row(): 52 | with gr.Column(scale=3): 53 | gr.HTML(value="") 54 | 55 | with gr.Column(): 56 | create_hypernetwork = gr.Button(value="Create hypernetwork", variant='primary') 57 | setting_name = gr.Textbox(label="Setting file name", value="") 58 | save_setting = gr.Button(value="Save hypernetwork setting to file") 59 | ti_output = gr.Text(elem_id="ti_output2", value="", show_label=False) 60 | ti_outcome = gr.HTML(elem_id="ti_error2", value="") 61 | 62 | 63 | 64 | save_setting.click( 65 | fn=wrap_gradio_call(external_patch_ui.save_hypernetwork_setting), 66 | inputs=[ 67 | setting_name, 68 | new_hypernetwork_sizes, 69 | overwrite_old_hypernetwork, 70 | new_hypernetwork_layer_structure, 71 | new_hypernetwork_activation_func, 72 | new_hypernetwork_initialization_option, 73 | new_hypernetwork_add_layer_norm, 74 | new_hypernetwork_use_dropout, 75 | new_hypernetwork_dropout_structure, 76 | optional_info, 77 | generation_seed if generation_seed.visible else None, 78 | normal_std if normal_std.visible else 0.01, 79 | skip_connection], 80 | outputs=[ 81 | ti_output, 82 | ti_outcome, 83 | ] 84 | ) 85 | create_hypernetwork.click( 86 | fn=ui.create_hypernetwork, 87 | inputs=[ 88 | new_hypernetwork_name, 89 | new_hypernetwork_sizes, 90 | overwrite_old_hypernetwork, 91 | new_hypernetwork_layer_structure, 92 | new_hypernetwork_activation_func, 93 | new_hypernetwork_initialization_option, 94 | new_hypernetwork_add_layer_norm, 95 | new_hypernetwork_use_dropout, 96 | new_hypernetwork_dropout_structure, 97 | optional_info, 98 | generation_seed if generation_seed.visible else None, 99 | normal_std if normal_std.visible else 0.01, 100 | skip_connection 101 | ], 102 | outputs=[ 103 | new_hypernetwork_name, 104 | ti_output, 105 | ti_outcome, 106 | ] 107 | ) 108 | return [(create_beta, "Create_beta", "create_beta")] 109 | 110 | 111 | def create_extension_tab2(params=None): 112 | with gr.Blocks(analytics_enabled=False) as CLIP_test_interface: 113 | with gr.Tab(label="CLIP-test") as clip_test: 114 | with gr.Row(): 115 | clipTextModelPath = gr.Textbox("openai/clip-vit-large-patch14", label="CLIP Text models. Set to empty to not change.") 116 | # see https://huggingface.co/openai/clip-vit-large-patch14 and related pages to find model. 117 | change_model = gr.Checkbox(label="Enable clip model change. 
This will be triggered from next model changes.") 118 | change_model.change( 119 | fn=clip_hijack.trigger_sd_hijack, 120 | inputs=[ 121 | change_model, 122 | clipTextModelPath 123 | ], 124 | outputs=[] 125 | ) 126 | with gr.Row(): 127 | def track_vram_usage(*args): 128 | import torch 129 | import gc 130 | torch.cuda.empty_cache() 131 | gc.collect() 132 | for obj in gc.get_objects(): 133 | try: 134 | if torch.is_tensor(obj) or (hasattr(obj, 'data') and torch.is_tensor(obj.data)): 135 | if obj.is_cuda: 136 | print(type(obj), obj.size()) 137 | except: pass 138 | track_vram_usage_button = gr.Button(value="Track VRAM usage") 139 | track_vram_usage_button.click( 140 | fn = track_vram_usage, 141 | inputs=[], 142 | outputs=[] 143 | ) 144 | return [(CLIP_test_interface, "CLIP_test", "clip_test")] 145 | 146 | def on_ui_settings(): 147 | shared.opts.add_option("disable_ema", 148 | shared.OptionInfo(False, "Detach grad from conditioning models", 149 | section=('training', "Training"))) 150 | if not hasattr(shared.opts, 'training_enable_tensorboard'): 151 | shared.opts.add_option("training_enable_tensorboard", 152 | shared.OptionInfo(False, "Enable tensorboard logging", 153 | section=('training', "Training"))) 154 | 155 | #script_callbacks.on_ui_train_tabs(create_training_tab) # Deprecate Beta Training 156 | script_callbacks.on_ui_train_tabs(create_extension_tab) 157 | script_callbacks.on_ui_train_tabs(external_patch_ui.on_train_gamma_tab) 158 | script_callbacks.on_ui_train_tabs(external_patch_ui.on_train_tuning) 159 | script_callbacks.on_ui_tabs(create_extension_tab2) 160 | script_callbacks.on_ui_settings(on_ui_settings) 161 | class Script(scripts.Script): 162 | def title(self): 163 | return "Hypernetwork Monkey Patch" 164 | 165 | def show(self, _): 166 | return scripts.AlwaysVisible 167 | -------------------------------------------------------------------------------- /patches/external_pr/dataset.py: -------------------------------------------------------------------------------- 1 | # source:https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/4886/files 2 | 3 | import os 4 | import sys 5 | 6 | import numpy as np 7 | import PIL 8 | import torch 9 | from PIL import Image 10 | from torch.utils.data import Dataset, DataLoader, Sampler 11 | from torchvision import transforms 12 | 13 | from ..hnutil import get_closest 14 | from collections import defaultdict 15 | from random import Random 16 | import tqdm 17 | from modules import devices, shared 18 | import re 19 | 20 | from ldm.modules.distributions.distributions import DiagonalGaussianDistribution 21 | 22 | re_numbers_at_start = re.compile(r"^[-\d]+\s*") 23 | 24 | random_state_manager = Random(None) 25 | shuffle = random_state_manager.shuffle 26 | choice = random_state_manager.choice 27 | choices = random_state_manager.choices 28 | randrange = random_state_manager.randrange 29 | 30 | 31 | def set_rng(seed=None): 32 | random_state_manager.seed(seed) 33 | 34 | 35 | class DatasetEntry: 36 | def __init__(self, filename=None, filename_text=None, latent_dist=None, latent_sample=None, cond=None, 37 | cond_text=None, pixel_values=None, weight=None): 38 | self.filename = filename 39 | self.filename_text = filename_text 40 | self.latent_dist = latent_dist 41 | self.latent_sample = latent_sample 42 | self.cond = cond 43 | self.cond_text = cond_text 44 | self.pixel_values = pixel_values 45 | self.weight = weight 46 | 47 | 48 | class PersonalizedBase(Dataset): 49 | def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", 
model=None, 50 | cond_model=None, device=None, template_file=None, include_cond=False, batch_size=1, gradient_step=1, 51 | shuffle_tags=False, tag_drop_out=0, latent_sampling_method='once', latent_sampling_std=-1, manual_seed=-1, use_weight=False): 52 | re_word = re.compile(shared.opts.dataset_filename_word_regex) if len( 53 | shared.opts.dataset_filename_word_regex) > 0 else None 54 | if manual_seed == -1: 55 | seed = randrange(sys.maxsize) 56 | set_rng(seed) # reset forked RNG state when we create dataset. 57 | print(f"Dataset seed was set to f{seed}") 58 | else: 59 | set_rng(manual_seed) 60 | print(f"Dataset seed was set to f{manual_seed}") 61 | self.placeholder_token = placeholder_token 62 | 63 | self.width = width 64 | self.height = height 65 | self.flip = transforms.RandomHorizontalFlip(p=flip_p) 66 | 67 | self.dataset = [] 68 | 69 | with open(template_file, "r") as file: 70 | lines = [x.strip() for x in file.readlines()] 71 | 72 | self.lines = lines 73 | 74 | assert data_root, 'dataset directory not specified' 75 | assert os.path.isdir(data_root), "Dataset directory doesn't exist" 76 | assert os.listdir(data_root), "Dataset directory is empty" 77 | 78 | self.image_paths = [os.path.join(data_root, file_path) for file_path in os.listdir(data_root)] # We assert batch size > 1 can work, by having multiple same-size images 79 | # But note that we can't stack tensors with other size. so it's not working now. 80 | self.shuffle_tags = shuffle_tags 81 | self.tag_drop_out = tag_drop_out 82 | groups = defaultdict(list) 83 | 84 | print("Preparing dataset...") 85 | _i = 0 86 | for path in tqdm.tqdm(self.image_paths): 87 | if shared.state.interrupted: 88 | raise Exception("inturrupted") 89 | try: # apply variable size here 90 | image = Image.open(path).convert('RGB') 91 | w, h = image.size 92 | r = max(1, w / self.width, h / self.height) # divide by this 93 | amp = min(self.width / w, self.height / h) # if amp < 1, then ignore, else, multiply. 
94 | if amp > 1: 95 | w, h = w * amp, h * amp 96 | w, h = int(w/r), int(h/r) 97 | w, h = get_closest(w), get_closest(h) 98 | image = image.resize((w,h), PIL.Image.LANCZOS) 99 | except Exception: 100 | continue 101 | 102 | text_filename = os.path.splitext(path)[0] + ".txt" 103 | filename = os.path.basename(path) 104 | 105 | if os.path.exists(text_filename): 106 | with open(text_filename, "r", encoding="utf8") as file: 107 | filename_text = file.read() 108 | else: 109 | filename_text = os.path.splitext(filename)[0] 110 | filename_text = re.sub(re_numbers_at_start, '', filename_text) 111 | if re_word: 112 | tokens = re_word.findall(filename_text) 113 | filename_text = (shared.opts.dataset_filename_join_string or "").join(tokens) 114 | 115 | npimage = np.array(image).astype(np.uint8) 116 | npimage = (npimage / 127.5 - 1.0).astype(np.float32) 117 | 118 | torchdata = torch.from_numpy(npimage).permute(2, 0, 1).to(device=device, dtype=torch.float32) 119 | 120 | with torch.autocast("cuda"): 121 | latent_dist = model.encode_first_stage(torchdata.unsqueeze(dim=0)) 122 | latent_sample = model.get_first_stage_encoding(latent_dist).squeeze().to(devices.cpu) 123 | weight = torch.ones_like(latent_sample) 124 | if latent_sampling_method == "once" or ( 125 | latent_sampling_method == "deterministic" and not isinstance(latent_dist, 126 | DiagonalGaussianDistribution)): 127 | latent_sampling_method = "once" 128 | entry = DatasetEntry(filename=path, filename_text=filename_text, latent_sample=latent_sample) 129 | elif latent_sampling_method == "deterministic": 130 | # Works only for DiagonalGaussianDistribution 131 | latent_dist.std = 0 132 | entry = DatasetEntry(filename=path, filename_text=filename_text, latent_sample=latent_sample) 133 | elif latent_sampling_method == "random": 134 | if latent_sampling_std != -1: 135 | assert latent_sampling_std > 0, f"Cannnot apply negative standard deviation {latent_sampling_std}" 136 | print(f"Applying patch, clipping std from {torch.max(latent_dist.std).item()} to {latent_sampling_std}...") 137 | latent_dist.std.clip_(latent_sampling_std) 138 | entry = DatasetEntry(filename=path, filename_text=filename_text, latent_dist=latent_dist) 139 | else: 140 | raise RuntimeError("Entry was undefined because of undefined latent sampling method!") 141 | alpha_channel = None 142 | if use_weight and 'A' in image.getbands(): 143 | alpha_channel = image.getchannel('A') 144 | if use_weight and alpha_channel is not None: 145 | channels, *latent_size = latent_sample.shape 146 | weight_img = alpha_channel.resize(latent_size) 147 | npweight = np.array(weight_img).astype(np.float32) 148 | #Repeat for every channel in the latent sample 149 | weight = torch.tensor([npweight] * channels).reshape([channels] + latent_size) 150 | #Normalize the weight to a minimum of 0 and a mean of 1, that way the loss will be comparable to default. 151 | weight -= weight.min() 152 | weight /= weight.mean() 153 | elif use_weight: 154 | #If an image does not have a alpha channel, add a ones weight map anyway so we can stack it later 155 | weight = torch.ones_like(latent_sample) 156 | entry.weight = weight 157 | if not (self.tag_drop_out != 0 or self.shuffle_tags): 158 | entry.cond_text = self.create_text(filename_text) 159 | 160 | if include_cond and not (self.tag_drop_out != 0 or self.shuffle_tags): 161 | with torch.autocast("cuda"): 162 | entry.cond = cond_model([entry.cond_text]).to(devices.cpu).squeeze(0) 163 | groups[image.size].append(_i) #record indexes of images in dataset into group. 
When we pull batch, try using single group to make torch.stack work. 164 | _i += 1 165 | self.dataset.append(entry) 166 | del torchdata 167 | del latent_dist 168 | del latent_sample 169 | self.groups = list(groups.values()) 170 | self.length = len(self.dataset) 171 | assert self.length > 0, "No images have been found in the dataset." 172 | self.batch_size = min(batch_size, self.length) 173 | self.gradient_step = min(gradient_step, self.length // self.batch_size) 174 | self.latent_sampling_method = latent_sampling_method 175 | 176 | def create_text(self, filename_text): 177 | text = choice(self.lines) 178 | tags = filename_text.split(',') 179 | if self.tag_drop_out != 0: 180 | tags = [t for t in tags if random_state_manager.random() > self.tag_drop_out] 181 | if self.shuffle_tags: 182 | shuffle(tags) 183 | text = text.replace("[filewords]", ','.join(tags)) 184 | text = text.replace("[name]", self.placeholder_token) 185 | return text 186 | 187 | def __len__(self): 188 | return self.length 189 | 190 | def __getitem__(self, i): 191 | entry = self.dataset[i] 192 | if self.tag_drop_out != 0 or self.shuffle_tags: 193 | entry.cond_text = self.create_text(entry.filename_text) 194 | if self.latent_sampling_method == "random": 195 | entry.latent_sample = shared.sd_model.get_first_stage_encoding(entry.latent_dist).to(devices.cpu) 196 | if entry.weight is None: 197 | entry.weight = torch.ones_like(entry.latent_sample) 198 | return entry 199 | 200 | class GroupedBatchSampler(Sampler): 201 | # See https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/6620 202 | def __init__(self, data_source: PersonalizedBase, batch_size: int): 203 | n = len(data_source) 204 | self.groups = data_source.groups 205 | self.len = n_batch = n // batch_size 206 | expected = [len(g) / n * n_batch * batch_size for g in data_source.groups] 207 | self.base = [int(e) // batch_size for e in expected] 208 | self.n_rand_batches = n_batch - sum(self.base) 209 | self.probs = [e % batch_size/self.n_rand_batches/batch_size if self.n_rand_batches > 0 else 0 for e in expected] 210 | self.batch_size = batch_size 211 | 212 | 213 | def __len__(self): 214 | return self.len 215 | 216 | def __iter__(self): 217 | b = self.batch_size 218 | batches = [] 219 | for g in self.groups: 220 | shuffle(g) 221 | batches.extend(g[i*b:(i+1)*b] for i in range(len(g) // b)) 222 | for _ in range(self.n_rand_batches): 223 | rand_group = choices(self.groups, self.probs)[0] 224 | batches.append(choices(rand_group, k=b)) 225 | shuffle(batches) 226 | yield from batches 227 | 228 | class PersonalizedDataLoader(DataLoader): 229 | def __init__(self, dataset, latent_sampling_method="once", batch_size=1, pin_memory=False): 230 | super(PersonalizedDataLoader, self).__init__(dataset, batch_sampler=GroupedBatchSampler(dataset, batch_size), pin_memory=pin_memory) 231 | if latent_sampling_method == "random": 232 | self.collate_fn = collate_wrapper_random 233 | else: 234 | self.collate_fn = collate_wrapper 235 | 236 | 237 | class BatchLoader: 238 | def __init__(self, data): 239 | self.cond_text = [entry.cond_text for entry in data] 240 | self.cond = [entry.cond for entry in data] 241 | self.latent_sample = torch.stack([entry.latent_sample for entry in data]).squeeze(1) 242 | self.weight = torch.stack([entry.weight for entry in data]).squeeze(1) 243 | self.filename = [entry.filename for entry in data] 244 | # self.emb_index = [entry.emb_index for entry in data] 245 | # print(self.latent_sample.device) 246 | 247 | def pin_memory(self): 248 | self.latent_sample = 
self.latent_sample.pin_memory() 249 | return self 250 | 251 | 252 | def collate_wrapper(batch): 253 | return BatchLoader(batch) 254 | 255 | 256 | class BatchLoaderRandom(BatchLoader): 257 | def __init__(self, data): 258 | super().__init__(data) 259 | 260 | def pin_memory(self): 261 | return self 262 | 263 | 264 | def collate_wrapper_random(batch): 265 | return BatchLoaderRandom(batch) 266 | -------------------------------------------------------------------------------- /patches/hypernetworks.py: -------------------------------------------------------------------------------- 1 | import ast 2 | import os.path 3 | 4 | import torch 5 | 6 | from modules import devices, shared 7 | from .hnutil import find_self 8 | from .shared import version_flag 9 | 10 | lazy_load = False # when this is enabled, HNs will be loaded when required. 11 | if not hasattr(devices, 'cond_cast_unet'): 12 | raise RuntimeError("Cannot find cond_cast_unet attribute, please update your webui version!") 13 | 14 | 15 | class DynamicDict(dict): # Brief dict that dynamically unloads Hypernetworks if required. 16 | def __init__(self, **kwargs): 17 | super().__init__(**kwargs) 18 | self.current = None 19 | self.hash = None 20 | self.dict = {**kwargs} 21 | 22 | def prepare(self, key, value): 23 | if lazy_load and self.current is not None and ( 24 | key != self.current): # or filename is identical, but somehow hash is changed? 25 | self.current.to('cpu') 26 | self.current = value 27 | if self.current is not None: 28 | self.current.to(devices.device) 29 | 30 | def __getitem__(self, item): 31 | value = self.dict[item] 32 | self.prepare(item, value) 33 | return value 34 | 35 | def __setitem__(self, key, value): 36 | if key in self.dict: 37 | return 38 | self.dict[key] = value 39 | 40 | def __contains__(self, item): 41 | return item in self.dict 42 | 43 | 44 | available_opts = DynamicDict() # string -> HN itself. 45 | 46 | 47 | # Behavior definition. 48 | # [[], [], []] -> sequential processing 49 | # [{"A" : 0.8, "B" : 0.1}] -> parallel processing. with weighted sum in this case, A = 8/9 effect, B = 1/9 effect. 50 | # [("A", 0.2), ("B", 0.4)] -> tuple is used to specify strength. 51 | # [{"A", "B", "C"}] -> parallel, but having same effects (set) 52 | # ["A", "B", []] -> sequential processing 53 | # [{"A":0.6}, "B", "C"] -> sequential, dict with single value will be considered as strength modification. 54 | # [["A"], {"B"}, "C"] -> singletons are equal to items without covers, nested singleton will not be parsed, because its inefficient. 55 | # {{'Aa' : 0.2, 'Ab' : 0.8} : 0.8, 'B' : 0.1} (X) -> {"{'Aa' : 0.2, 'Ab' : 0.8}" : 0.8, 'B' : 0.1} (O), When you want complex setups in parallel, you need to cover them with "". You can use backslash too. 56 | 57 | 58 | # Testing parsing function. 
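# Illustrative expected results (assuming hypernetworks named "A" and "B" exist):
#   Forward.parse('["A", "B"]')            -> SequentialForward: apply A, then B
#   Forward.parse('{"A": 0.8, "B": 0.2}')  -> ParallelForward: weighted sum, A at 0.8 and B at 0.2
#   Forward.parse('("A", 0.5)')            -> SingularForward: A applied at strength 0.5
#   Forward.parse('{"A", "B"}')            -> ParallelForward with equal weights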
59 | 60 | def test_parsing(string=None): 61 | def test(arg): 62 | print(arg) 63 | try: 64 | obj = str(Forward.parse(arg)) 65 | print(obj) 66 | except Exception as e: 67 | print(e) 68 | 69 | if string: 70 | test(string) 71 | else: 72 | for strings in ["[[], [], []]", "[{\"A\" : 0.8, \"B\" : 0.1}]", '[("A", 0.2), ("B", 0.4)]', '[{"A", "B", "C"}]', 73 | '[{"A":0.6}, "B", "C"]', '[["A"], {"B"}, "C"]', 74 | '{"{\'Aa\' : 0.2, \'Ab\' : 0.8}" : 0.8, \'B\' : 0.1}']: 75 | test(strings) 76 | 77 | 78 | class Forward: 79 | def __init__(self, **kwargs): 80 | self.name = "defaultForward" if 'name' not in kwargs else kwargs['name'] 81 | pass 82 | 83 | def __call__(self, *args, **kwargs): 84 | raise NotImplementedError 85 | 86 | def set_multiplier(self, *args, **kwargs): 87 | pass 88 | 89 | def extra_name(self): 90 | if version_flag: 91 | return "" 92 | found = find_self(self) 93 | if found is not None: 94 | return f" " 95 | return f" " 96 | 97 | @staticmethod 98 | def parse(arg, name=None): 99 | arg = Forward.unpack(arg) 100 | arg = Forward.eval(arg) 101 | if Forward.isSingleTon(arg): 102 | return SingularForward(*Forward.parseSingleTon(arg)) 103 | elif Forward.isParallel(arg): 104 | return ParallelForward(Forward.parseParallel(arg), name=name) 105 | elif Forward.isSequential(arg): 106 | return SequentialForward(Forward.parseSequential(arg), name=name) 107 | raise ValueError(f"Cannot parse {arg} into sequences!") 108 | 109 | @staticmethod 110 | def unpack(arg): # stop using ({({{((a))}})}) please 111 | if len(arg) == 1 and type(arg) in (set, list, tuple): 112 | return Forward.unpack(list(arg)[0]) 113 | if len(arg) == 1 and type(arg) is dict: 114 | key = list(arg.keys())[0] 115 | if arg[key] == 1: 116 | return Forward.unpack(key) 117 | return arg 118 | 119 | @staticmethod 120 | def eval(arg): # from "{something}", parse as etc form. 121 | if arg is None: 122 | raise ValueError("None cannot be evaluated!") 123 | try: 124 | newarg = ast.literal_eval(arg) 125 | if type(arg) is str and arg.startswith(("{", "[", "(")) and newarg is not None: 126 | if not newarg: 127 | raise RuntimeError(f"Cannot eval false object {arg}!") 128 | return newarg 129 | except ValueError: 130 | return arg 131 | return arg 132 | 133 | @staticmethod 134 | def isSingleTon( 135 | arg): # Very strict. This applies strength to HN, which cannot happen in combined networks. Only weighting is allowed in complex process. 136 | if type(arg) is str and not arg.startswith(('[', '(', '{')): # Strict. only accept str 137 | return True 138 | elif type( 139 | arg) is dict: # Strict. only accept {str : int/float} - Strength modification can only happen for str. 140 | return len(arg) == 1 and all(type(value) in (int, float) for value in arg.values()) and all( 141 | type(k) is str for k in arg) 142 | elif type(arg) in (list, set): 143 | return len(arg) == 1 and all(type(x) is str for x in arg) 144 | elif type(arg) is tuple: 145 | return len(arg) == 2 and type(arg[0]) is str and type(arg[1]) in (int, float) 146 | return False 147 | 148 | @staticmethod 149 | def parseSingleTon(sequence): # accepts sequence, returns str, float pair. This is Strict. 150 | if type(sequence) in (list, dict, set): 151 | assert len(sequence) == 1, f"SingularForward only accepts singletons, but given {sequence}!" 152 | key = list(sequence)[0] 153 | if type(sequence) is dict: 154 | assert type(key) is str, f"Strength modification only accepts single Hypernetwork, but given {key}!" 
155 | return key, sequence[key] 156 | else: 157 | key = list(key)[0] 158 | return key, 1 159 | elif type(sequence) is tuple: 160 | assert len(sequence) == 2, f"Tuple with non-couple {sequence} encountered in SingularForward!" 161 | assert type( 162 | sequence[0]) is str, f"Strength modification only accepts single Hypernetwork, but given {sequence[0]}!" 163 | assert type(sequence[1]) in (int, float), f"Strength tuple only accepts Numbers, but given {sequence[1]}!" 164 | return sequence[0], sequence[1] 165 | else: 166 | assert type( 167 | sequence) is str, f"Strength modification only accepts single Hypernetwork, but given {sequence}!" 168 | return sequence, 1 169 | 170 | @staticmethod 171 | def isParallel( 172 | arg): # Parallel, or Sequential processing is not strict, it can have {"String covered sequence or just HN String" : weight, ... 173 | if type(arg) in (dict, set) and len(arg) > 1: 174 | if type(arg) is set: 175 | return all(type(key) is str for key in 176 | arg), f"All keys should be Hypernetwork Name/Sequence for Set but given :{arg}" 177 | else: 178 | arg: dict 179 | return all(type(key) is str for key in 180 | arg.keys()), f"All keys should be Hypernetwork Name/Sequence for Set but given :{arg}" 181 | else: 182 | return False 183 | 184 | @staticmethod 185 | def parseParallel(sequence): # accepts sequence, returns {"Name or sequence" : weight...} 186 | assert len(sequence) > 1, f"Length of sequence {sequence} was not enough for parallel!" 187 | if type(sequence) is set: # only allows hashable types. otherwise it should be supplied as string cover 188 | assert all(type(key) in (str, tuple) for key in 189 | sequence), f"All keys should be Hypernetwork Name/Sequence for Set but given :{sequence}" 190 | return {key: 1 / len(sequence) for key in sequence} 191 | elif type(sequence) is dict: 192 | assert all(type(key) in (str, tuple) for key in 193 | sequence.keys()), f"All keys should be Hypernetwork Name/Sequence for Dict but given :{sequence}" 194 | assert all(type(value) in (int, float) for value in 195 | sequence.values()), f"All values should be int/float for Dict but given :{sequence}" 196 | return sequence 197 | else: 198 | raise ValueError(f"Cannot parse parallel sequence {sequence}!") 199 | 200 | @staticmethod 201 | def isSequential(arg): 202 | if type(arg) is list and len(arg) > 0: 203 | return True 204 | return False 205 | 206 | @staticmethod 207 | def parseSequential(sequence): # accepts sequence, only checks if its list, then returns sequence. 208 | if type(sequence) is list and len(sequence) > 0: 209 | return sequence 210 | else: 211 | raise ValueError(f"Cannot parse non-list sequence {sequence}!") 212 | 213 | def shorthash(self): 214 | return '0000000000' 215 | 216 | from .hypernetwork import Hypernetwork 217 | 218 | 219 | def find_non_hash_key(target): 220 | closest = [x for x in shared.hypernetworks if x.rsplit('(', 1)[0] == target or x == target] 221 | if closest: 222 | return shared.hypernetworks[closest[0]] 223 | raise KeyError(f"{target} is not found in Hypernetworks!") 224 | 225 | 226 | class SingularForward(Forward): 227 | 228 | def __init__(self, processor, strength): 229 | assert processor != 'defaultForward', "Cannot use name defaultForward!" 230 | super(SingularForward, self).__init__() 231 | self.name = processor 232 | self.processor = processor 233 | self.strength = strength 234 | # parse. We expect parsing Singletons or (k,v) pair here, which is HN Name and Strength. 
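# Example (name is illustrative, not from the original file): SingularForward("style_hn", 0.75)
# resolves "style_hn" through find_non_hash_key, caches the loaded Hypernetwork in available_opts,
# and __call__ then applies its k/v layers with multiplier 0.75.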
235 | hn = Hypernetwork() 236 | try: 237 | hn.load(find_non_hash_key(self.processor)) 238 | except: 239 | global lazy_load 240 | lazy_load = True 241 | print("Encountered CUDA Memory Error, will unload HNs, speed might go down severely!") 242 | hn.load(find_non_hash_key(self.processor)) 243 | available_opts[self.processor] = hn 244 | # assert self.processor in available_opts, f"Hypernetwork named {processor} is not ready!" 245 | assert 0 <= self.strength <= 1, "Strength must be between 0 and 1!" 246 | print(f"SingularForward <{self.name}, {self.strength}>") 247 | 248 | def __call__(self, context_k, context_v=None, layer=None): 249 | if self.processor in available_opts: 250 | context_layers = available_opts[self.processor].layers.get(context_k.shape[2], None) 251 | if context_v is None: 252 | context_v = context_k 253 | if context_layers is None: 254 | return context_k, context_v 255 | #if layer is not None and hasattr(layer, 'hyper_k') and hasattr(layer, 'hyper_v'): 256 | # layer.hyper_k = context_layers[0], layer.hyper_v = context_layers[1] 257 | return devices.cond_cast_unet(context_layers[0](devices.cond_cast_float(context_k), multiplier=self.strength)),\ 258 | devices.cond_cast_unet(context_layers[1](devices.cond_cast_float(context_v), multiplier=self.strength)) 259 | # define forward_strength, which invokes HNModule with specified strength. 260 | # Note : we share same HN if it is called multiple time, which means you might not be able to train it via this structure. 261 | raise KeyError(f"Key {self.processor} is not found in cached Hypernetworks!") 262 | 263 | def __str__(self): 264 | return "SingularForward>" + str(self.processor) 265 | 266 | 267 | class ParallelForward(Forward): 268 | 269 | def __init__(self, sequence, name=None): 270 | self.name = "ParallelForwardHypernet" if name is None else name 271 | self.callers = {} 272 | self.weights = {} 273 | super(ParallelForward, self).__init__() 274 | # parse 275 | for keys in sequence: 276 | self.callers[keys] = Forward.parse(keys) 277 | self.weights[keys] = sequence[keys] / sum(sequence.values()) 278 | print(str(self)) 279 | 280 | def __call__(self, context, context_v=None, layer=None): 281 | ctx_k, ctx_v = torch.zeros_like(context, device=context.device), torch.zeros_like(context, 282 | device=context.device) 283 | for key in self.callers: 284 | k, v = self.callers[key](context, context_v, layer=layer) 285 | ctx_k += k * self.weights[key] 286 | ctx_v += v * self.weights[key] 287 | return ctx_k, ctx_v 288 | 289 | def __str__(self): 290 | return "ParallelForward>" + str({str(k): str(v) for (k, v) in self.callers.items()}) 291 | 292 | 293 | class SequentialForward(Forward): 294 | def __init__(self, sequence, name=None): 295 | self.name = "SequentialForwardHypernet" if name is None else name 296 | self.callers = [] 297 | super(SequentialForward, self).__init__() 298 | for keys in sequence: 299 | self.callers.append(Forward.parse(keys)) 300 | print(str(self)) 301 | 302 | def __call__(self, context, context_v=None, layer=None): 303 | if context_v is None: 304 | context_v = context 305 | for keys in self.callers: 306 | context, context_v = keys(context, context_v, layer=layer) 307 | return context, context_v 308 | 309 | def __str__(self): 310 | return "SequentialForward>" + str([str(x) for x in self.callers]) 311 | 312 | 313 | class EmptyForward(Forward): 314 | def __init__(self): 315 | super().__init__() 316 | self.name = None 317 | 318 | def __call__(self, context, context_v=None, layer=None): 319 | if context_v is None: 320 | context_v 
= context 321 | return context, context_v 322 | 323 | def __str__(self): 324 | return "EmptyForward" 325 | 326 | 327 | def load(filename): 328 | with open(filename, 'r') as file: 329 | return Forward.parse(file.read(), name=os.path.basename(filename)) 330 | -------------------------------------------------------------------------------- /patches/external_pr/ui.py: -------------------------------------------------------------------------------- 1 | import gc 2 | import html 3 | import json 4 | import os 5 | import random 6 | 7 | from modules import shared, sd_hijack, devices 8 | from modules.call_queue import wrap_gradio_call, wrap_gradio_gpu_call 9 | from modules.paths import script_path 10 | from modules.ui import create_refresh_button, gr_show 11 | from .textual_inversion import train_embedding as train_embedding_external 12 | from .hypernetwork import train_hypernetwork as train_hypernetwork_external, train_hypernetwork_tuning 13 | import gradio as gr 14 | 15 | 16 | def train_hypernetwork_ui(*args): 17 | initial_hypernetwork = None 18 | if hasattr(shared, 'loaded_hypernetwork'): 19 | initial_hypernetwork = shared.loaded_hypernetwork 20 | else: 21 | shared.loaded_hypernetworks = [] 22 | assert not shared.cmd_opts.lowvram, 'Training models with lowvram is not possible' 23 | 24 | try: 25 | sd_hijack.undo_optimizations() 26 | 27 | hypernetwork, filename = train_hypernetwork_external(*args) 28 | 29 | res = f""" 30 | Training {'interrupted' if shared.state.interrupted else 'finished'} at {hypernetwork.step} steps. 31 | Hypernetwork saved to {html.escape(filename)} 32 | """ 33 | return res, "" 34 | except Exception: 35 | raise 36 | finally: 37 | if hasattr(shared, 'loaded_hypernetwork'): 38 | shared.loaded_hypernetwork = initial_hypernetwork 39 | else: 40 | shared.loaded_hypernetworks = [] 41 | # check hypernetwork is bounded then delete it 42 | if locals().get('hypernetwork', None) is not None: 43 | del hypernetwork 44 | gc.collect() 45 | shared.sd_model.cond_stage_model.to(devices.device) 46 | shared.sd_model.first_stage_model.to(devices.device) 47 | sd_hijack.apply_optimizations() 48 | 49 | 50 | def train_hypernetwork_ui_tuning(*args): 51 | initial_hypernetwork = None 52 | if hasattr(shared, 'loaded_hypernetwork'): 53 | initial_hypernetwork = shared.loaded_hypernetwork 54 | else: 55 | shared.loaded_hypernetworks = [] 56 | 57 | assert not shared.cmd_opts.lowvram, 'Training models with lowvram is not possible' 58 | 59 | try: 60 | sd_hijack.undo_optimizations() 61 | 62 | train_hypernetwork_tuning(*args) 63 | 64 | res = f""" 65 | Training {'interrupted' if shared.state.interrupted else 'finished'}. 
66 | """ 67 | return res, "" 68 | except Exception: 69 | raise 70 | finally: 71 | if hasattr(shared, 'loaded_hypernetwork'): 72 | shared.loaded_hypernetwork = initial_hypernetwork 73 | else: 74 | shared.loaded_hypernetworks = [] 75 | shared.sd_model.cond_stage_model.to(devices.device) 76 | shared.sd_model.first_stage_model.to(devices.device) 77 | sd_hijack.apply_optimizations() 78 | 79 | 80 | def save_training_setting(*args): 81 | save_file_name, learn_rate, batch_size, gradient_step, training_width, \ 82 | training_height, steps, shuffle_tags, tag_drop_out, latent_sampling_method, \ 83 | template_file, use_beta_scheduler, beta_repeat_epoch, epoch_mult, warmup, min_lr, \ 84 | gamma_rate, use_beta_adamW_checkbox, save_when_converge, create_when_converge, \ 85 | adamw_weight_decay, adamw_beta_1, adamw_beta_2, adamw_eps, show_gradient_clip_checkbox, \ 86 | gradient_clip_opt, optional_gradient_clip_value, optional_gradient_norm_type, latent_sampling_std,\ 87 | noise_training_scheduler_enabled, noise_training_scheduler_repeat, noise_training_scheduler_cycle, loss_opt, use_dadaptation, dadapt_growth_factor, use_weight = args 88 | dumped_locals = locals() 89 | dumped_locals.pop('args') 90 | filename = (str(random.randint(0, 1024)) if save_file_name == '' else save_file_name) + '_train_' + '.json' 91 | filename = os.path.join(shared.cmd_opts.hypernetwork_dir, filename) 92 | with open(filename, 'w') as file: 93 | print(dumped_locals) 94 | json.dump(dumped_locals, file) 95 | print(f"File saved as {filename}") 96 | return filename, "" 97 | 98 | 99 | def save_hypernetwork_setting(*args): 100 | save_file_name, enable_sizes, overwrite_old, layer_structure, activation_func, weight_init, add_layer_norm, use_dropout, dropout_structure, optional_info, weight_init_seed, normal_std, skip_connection = args 101 | dumped_locals = locals() 102 | dumped_locals.pop('args') 103 | filename = (str(random.randint(0, 1024)) if save_file_name == '' else save_file_name) + '_hypernetwork_' + '.json' 104 | filename = os.path.join(shared.cmd_opts.hypernetwork_dir, filename) 105 | with open(filename, 'w') as file: 106 | print(dumped_locals) 107 | json.dump(dumped_locals, file) 108 | print(f"File saved as {filename}") 109 | return filename, "" 110 | 111 | 112 | def on_train_gamma_tab(params=None): 113 | dummy_component = gr.Label(visible=False) 114 | with gr.Tab(label="Train_Gamma") as train_gamma: 115 | gr.HTML( 116 | value="
Train an embedding or Hypernetwork; you must specify a directory [wiki]
") 117 | with gr.Row(): 118 | train_embedding_name = gr.Dropdown(label='Embedding', elem_id="train_embedding", choices=sorted( 119 | sd_hijack.model_hijack.embedding_db.word_embeddings.keys())) 120 | create_refresh_button(train_embedding_name, 121 | sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings, lambda: { 122 | "choices": sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())}, 123 | "refresh_train_embedding_name") 124 | with gr.Row(): 125 | train_hypernetwork_name = gr.Dropdown(label='Hypernetwork', elem_id="train_hypernetwork", 126 | choices=[x for x in shared.hypernetworks.keys()]) 127 | create_refresh_button(train_hypernetwork_name, shared.reload_hypernetworks, 128 | lambda: {"choices": sorted([x for x in shared.hypernetworks.keys()])}, 129 | "refresh_train_hypernetwork_name") 130 | with gr.Row(): 131 | embedding_learn_rate = gr.Textbox(label='Embedding Learning rate', 132 | placeholder="Embedding Learning rate", value="0.005") 133 | hypernetwork_learn_rate = gr.Textbox(label='Hypernetwork Learning rate', 134 | placeholder="Hypernetwork Learning rate", value="0.00004") 135 | use_beta_scheduler_checkbox = gr.Checkbox( 136 | label='Show advanced learn rate scheduler options') 137 | use_beta_adamW_checkbox = gr.Checkbox( 138 | label='Show advanced adamW parameter options)') 139 | show_gradient_clip_checkbox = gr.Checkbox( 140 | label='Show Gradient Clipping Options(for both)') 141 | show_noise_options = gr.Checkbox( 142 | label='Show Noise Scheduler Options(for both)') 143 | with gr.Row(visible=False) as adamW_options: 144 | use_dadaptation = gr.Checkbox(label="Uses D-Adaptation(LR Free) AdamW. Recommended LR is 1.0 at base") 145 | adamw_weight_decay = gr.Textbox(label="AdamW weight decay parameter", placeholder="default = 0.01", 146 | value="0.01") 147 | adamw_beta_1 = gr.Textbox(label="AdamW beta1 parameter", placeholder="default = 0.9", value="0.9") 148 | adamw_beta_2 = gr.Textbox(label="AdamW beta2 parameter", placeholder="default = 0.99", value="0.99") 149 | adamw_eps = gr.Textbox(label="AdamW epsilon parameter", placeholder="default = 1e-8", value="1e-8") 150 | with gr.Row(visible=False) as dadapt_growth_options: 151 | dadapt_growth_factor = gr.Number(value=-1, label='Growth factor limiting, use value like 1.02 or leave it as -1') 152 | with gr.Row(visible=False) as beta_scheduler_options: 153 | use_beta_scheduler = gr.Checkbox(label='Use CosineAnnealingWarmupRestarts Scheduler') 154 | beta_repeat_epoch = gr.Textbox(label='Steps for cycle', placeholder="Cycles every nth Step", value="64") 155 | epoch_mult = gr.Textbox(label='Step multiplier per cycle', placeholder="Step length multiplier every cycle", 156 | value="1") 157 | warmup = gr.Textbox(label='Warmup step per cycle', placeholder="CosineAnnealing lr increase step", 158 | value="5") 159 | min_lr = gr.Textbox(label='Minimum learning rate', 160 | placeholder="restricts decay value, but does not restrict gamma rate decay", 161 | value="6e-7") 162 | gamma_rate = gr.Textbox(label='Decays learning rate every cycle', 163 | placeholder="Value should be in (0-1]", value="1") 164 | with gr.Row(visible=False) as beta_scheduler_options2: 165 | save_converge_opt = gr.Checkbox(label="Saves when every cycle finishes") 166 | generate_converge_opt = gr.Checkbox(label="Generates image when every cycle finishes") 167 | with gr.Row(visible=False) as gradient_clip_options: 168 | gradient_clip_opt = gr.Radio(label="Gradient Clipping Options", choices=["None", "limit", "norm"]) 169 | optional_gradient_clip_value = 
gr.Textbox(label="Limiting value", value="1e-1") 170 | optional_gradient_norm_type = gr.Textbox(label="Norm type", value="2") 171 | with gr.Row(visible=False) as noise_scheduler_options: 172 | noise_training_scheduler_enabled = gr.Checkbox(label="Use Noise training scheduler(test)") 173 | noise_training_scheduler_repeat = gr.Checkbox(label="Restarts noise scheduler, or linear") 174 | noise_training_scheduler_cycle = gr.Number(label="Restarts noise scheduler every nth epoch") 175 | use_weight = gr.Checkbox(label="Uses image alpha(transparency) channel for adjusting loss") 176 | # change by feedback 177 | use_dadaptation.change( 178 | fn=lambda show: gr_show(show), 179 | inputs=[use_dadaptation], 180 | outputs=[dadapt_growth_options] 181 | ) 182 | show_noise_options.change( 183 | fn = lambda show:gr_show(show), 184 | inputs = [show_noise_options], 185 | outputs = [noise_scheduler_options] 186 | ) 187 | use_beta_adamW_checkbox.change( 188 | fn=lambda show: gr_show(show), 189 | inputs=[use_beta_adamW_checkbox], 190 | outputs=[adamW_options], 191 | ) 192 | use_beta_scheduler_checkbox.change( 193 | fn=lambda show: gr_show(show), 194 | inputs=[use_beta_scheduler_checkbox], 195 | outputs=[beta_scheduler_options], 196 | ) 197 | use_beta_scheduler_checkbox.change( 198 | fn=lambda show: gr_show(show), 199 | inputs=[use_beta_scheduler_checkbox], 200 | outputs=[beta_scheduler_options2], 201 | ) 202 | show_gradient_clip_checkbox.change( 203 | fn=lambda show: gr_show(show), 204 | inputs=[show_gradient_clip_checkbox], 205 | outputs=[gradient_clip_options], 206 | ) 207 | move_optim_when_generate = gr.Checkbox(label="Unload Optimizer when generating preview(hypernetwork)", 208 | value=True) 209 | batch_size = gr.Number(label='Batch size', value=1, precision=0) 210 | gradient_step = gr.Number(label='Gradient accumulation steps', value=1, precision=0) 211 | dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images") 212 | log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", 213 | value="textual_inversion") 214 | template_file = gr.Textbox(label='Prompt template file', 215 | value=os.path.join(script_path, "textual_inversion_templates", 216 | "style_filewords.txt")) 217 | training_width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512) 218 | training_height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512) 219 | steps = gr.Number(label='Max steps', value=100000, precision=0) 220 | create_image_every = gr.Number(label='Save an image to log directory every N steps, 0 to disable', 221 | value=500, precision=0) 222 | save_embedding_every = gr.Number( 223 | label='Save a copy of embedding to log directory every N steps, 0 to disable', value=500, precision=0) 224 | save_image_with_stored_embedding = gr.Checkbox(label='Save images with embedding in PNG chunks', value=True) 225 | preview_from_txt2img = gr.Checkbox( 226 | label='Read parameters (prompt, etc...) 
from txt2img tab when making previews', value=False) 227 | with gr.Row(): 228 | shuffle_tags = gr.Checkbox(label="Shuffle tags by ',' when creating prompts.", value=False) 229 | tag_drop_out = gr.Slider(minimum=0, maximum=1, step=0.1, label="Drop out tags when creating prompts.", 230 | value=0) 231 | with gr.Row(): 232 | latent_sampling_method = gr.Radio(label='Choose latent sampling method', value="once", 233 | choices=['once', 'deterministic', 'random']) 234 | latent_sampling_std_value = gr.Number(label="Standard deviation for sampling", value=-1) 235 | with gr.Row(): 236 | loss_opt = gr.Radio(label="loss type", value="loss", 237 | choices=['loss', 'loss_simple', 'loss_vlb']) 238 | with gr.Row(): 239 | save_training_option = gr.Button(value="Save training setting") 240 | save_file_name = gr.Textbox(label="File name to save setting as", value="") 241 | load_training_option = gr.Textbox( 242 | label="Load training option from saved json file. This will override settings above", value="") 243 | with gr.Row(): 244 | interrupt_training = gr.Button(value="Interrupt") 245 | train_hypernetwork = gr.Button(value="Train Hypernetwork", variant='primary') 246 | train_embedding = gr.Button(value="Train Embedding", variant='primary') 247 | ti_output = gr.Text(elem_id="ti_output3", value="", show_label=False) 248 | ti_outcome = gr.HTML(elem_id="ti_error3", value="") 249 | 250 | # Full path to .json or simple names are recommended. 251 | save_training_option.click( 252 | fn=wrap_gradio_call(save_training_setting), 253 | inputs=[ 254 | save_file_name, 255 | hypernetwork_learn_rate, 256 | batch_size, 257 | gradient_step, 258 | training_width, 259 | training_height, 260 | steps, 261 | shuffle_tags, 262 | tag_drop_out, 263 | latent_sampling_method, 264 | template_file, 265 | use_beta_scheduler, 266 | beta_repeat_epoch, 267 | epoch_mult, 268 | warmup, 269 | min_lr, 270 | gamma_rate, 271 | use_beta_adamW_checkbox, 272 | save_converge_opt, 273 | generate_converge_opt, 274 | adamw_weight_decay, 275 | adamw_beta_1, 276 | adamw_beta_2, 277 | adamw_eps, 278 | show_gradient_clip_checkbox, 279 | gradient_clip_opt, 280 | optional_gradient_clip_value, 281 | optional_gradient_norm_type, 282 | latent_sampling_std_value, 283 | noise_training_scheduler_enabled, 284 | noise_training_scheduler_repeat, 285 | noise_training_scheduler_cycle, 286 | loss_opt, 287 | use_dadaptation, 288 | dadapt_growth_factor, 289 | use_weight 290 | ], 291 | outputs=[ 292 | ti_output, 293 | ti_outcome, 294 | ] 295 | ) 296 | train_embedding.click( 297 | fn=wrap_gradio_gpu_call(train_embedding_external, extra_outputs=[gr.update()]), 298 | _js="start_training_textual_inversion", 299 | inputs=[ 300 | dummy_component, 301 | train_embedding_name, 302 | embedding_learn_rate, 303 | batch_size, 304 | gradient_step, 305 | dataset_directory, 306 | log_directory, 307 | training_width, 308 | training_height, 309 | steps, 310 | shuffle_tags, 311 | tag_drop_out, 312 | latent_sampling_method, 313 | create_image_every, 314 | save_embedding_every, 315 | template_file, 316 | save_image_with_stored_embedding, 317 | preview_from_txt2img, 318 | *params.txt2img_preview_params, 319 | use_beta_scheduler, 320 | beta_repeat_epoch, 321 | epoch_mult, 322 | warmup, 323 | min_lr, 324 | gamma_rate, 325 | save_converge_opt, 326 | generate_converge_opt, 327 | move_optim_when_generate, 328 | use_beta_adamW_checkbox, 329 | adamw_weight_decay, 330 | adamw_beta_1, 331 | adamw_beta_2, 332 | adamw_eps, 333 | show_gradient_clip_checkbox, 334 | gradient_clip_opt, 335 | 
optional_gradient_clip_value, 336 | optional_gradient_norm_type, 337 | latent_sampling_std_value, 338 | use_weight 339 | ], 340 | outputs=[ 341 | ti_output, 342 | ti_outcome, 343 | ] 344 | ) 345 | 346 | train_hypernetwork.click( 347 | fn=wrap_gradio_gpu_call(train_hypernetwork_ui, extra_outputs=[gr.update()]), 348 | _js="start_training_textual_inversion", 349 | inputs=[ 350 | dummy_component, 351 | train_hypernetwork_name, 352 | hypernetwork_learn_rate, 353 | batch_size, 354 | gradient_step, 355 | dataset_directory, 356 | log_directory, 357 | training_width, 358 | training_height, 359 | steps, 360 | shuffle_tags, 361 | tag_drop_out, 362 | latent_sampling_method, 363 | create_image_every, 364 | save_embedding_every, 365 | template_file, 366 | preview_from_txt2img, 367 | *params.txt2img_preview_params, 368 | use_beta_scheduler, 369 | beta_repeat_epoch, 370 | epoch_mult, 371 | warmup, 372 | min_lr, 373 | gamma_rate, 374 | save_converge_opt, 375 | generate_converge_opt, 376 | move_optim_when_generate, 377 | use_beta_adamW_checkbox, 378 | adamw_weight_decay, 379 | adamw_beta_1, 380 | adamw_beta_2, 381 | adamw_eps, 382 | show_gradient_clip_checkbox, 383 | gradient_clip_opt, 384 | optional_gradient_clip_value, 385 | optional_gradient_norm_type, 386 | latent_sampling_std_value, 387 | noise_training_scheduler_enabled, 388 | noise_training_scheduler_repeat, 389 | noise_training_scheduler_cycle, 390 | load_training_option, 391 | loss_opt, 392 | use_dadaptation, 393 | dadapt_growth_factor, 394 | use_weight 395 | ], 396 | outputs=[ 397 | ti_output, 398 | ti_outcome, 399 | ] 400 | ) 401 | 402 | interrupt_training.click( 403 | fn=lambda: shared.state.interrupt(), 404 | inputs=[], 405 | outputs=[], 406 | ) 407 | return [(train_gamma, "Train Gamma", "train_gamma")] 408 | 409 | 410 | def on_train_tuning(params=None): 411 | dummy_component = gr.Label(visible=False) 412 | with gr.Tab(label="Train_Tuning") as train_tuning: 413 | gr.HTML( 414 | value="
Train Hypernetwork; you must specify a directory [wiki]
") 415 | with gr.Row(): 416 | train_hypernetwork_name = gr.Dropdown(label='Hypernetwork', elem_id="train_hypernetwork", 417 | choices=[x for x in shared.hypernetworks.keys()]) 418 | create_refresh_button(train_hypernetwork_name, shared.reload_hypernetworks, 419 | lambda: {"choices": sorted([x for x in shared.hypernetworks.keys()])}, 420 | "refresh_train_hypernetwork_name") 421 | optional_new_hypernetwork_name = gr.Textbox( 422 | label="Hypernetwork name to create, leave it empty to use selected", value="") 423 | with gr.Row(): 424 | load_hypernetworks_option = gr.Textbox( 425 | label="Load Hypernetwork creation option from saved json file", 426 | placeholder=". filename cannot have ',' inside, and files should be splitted by ','.", value="") 427 | with gr.Row(): 428 | load_training_options = gr.Textbox( 429 | label="Load training option(s) from saved json file", 430 | placeholder=". filename cannot have ',' inside, and files should be splitted by ','.", value="") 431 | move_optim_when_generate = gr.Checkbox(label="Unload Optimizer when generating preview(hypernetwork)", 432 | value=True) 433 | dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images") 434 | log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", 435 | value="textual_inversion") 436 | create_image_every = gr.Number(label='Save an image to log directory every N steps, 0 to disable', 437 | value=500, precision=0) 438 | save_model_every = gr.Number( 439 | label='Save a copy of model to log directory every N steps, 0 to disable', value=500, precision=0) 440 | preview_from_txt2img = gr.Checkbox( 441 | label='Read parameters (prompt, etc...) from txt2img tab when making previews', value=False) 442 | manual_dataset_seed = gr.Number( 443 | label="Manual dataset seed", value=-1, precision=0 444 | ) 445 | with gr.Row(): 446 | interrupt_training = gr.Button(value="Interrupt") 447 | train_hypernetwork = gr.Button(value="Train Hypernetwork", variant='primary') 448 | ti_output = gr.Text(elem_id="ti_output4", value="", show_label=False) 449 | ti_outcome = gr.HTML(elem_id="ti_error4", value="") 450 | train_hypernetwork.click( 451 | fn=wrap_gradio_gpu_call(train_hypernetwork_ui_tuning, extra_outputs=[gr.update()]), 452 | _js="start_training_textual_inversion", 453 | inputs=[ 454 | dummy_component, 455 | train_hypernetwork_name, 456 | dataset_directory, 457 | log_directory, 458 | create_image_every, 459 | save_model_every, 460 | preview_from_txt2img, 461 | *params.txt2img_preview_params, 462 | move_optim_when_generate, 463 | optional_new_hypernetwork_name, 464 | load_hypernetworks_option, 465 | load_training_options, 466 | manual_dataset_seed 467 | ], 468 | outputs=[ 469 | ti_output, 470 | ti_outcome, 471 | ] 472 | ) 473 | 474 | interrupt_training.click( 475 | fn=lambda: shared.state.interrupt(), 476 | inputs=[], 477 | outputs=[], 478 | ) 479 | return [(train_tuning, "Train Tuning", "train_tuning")] 480 | -------------------------------------------------------------------------------- /patches/external_pr/textual_inversion.py: -------------------------------------------------------------------------------- 1 | import csv 2 | import datetime 3 | import gc 4 | import html 5 | import os 6 | import sys 7 | import traceback 8 | 9 | import torch 10 | import tqdm 11 | from PIL import PngImagePlugin 12 | 13 | from modules import shared, devices, sd_models, images, processing, sd_samplers, sd_hijack, sd_hijack_checkpoint 14 | from 
modules.textual_inversion.image_embedding import caption_image_overlay, insert_image_data_embed, embedding_to_b64 15 | from modules.textual_inversion.learn_schedule import LearnRateScheduler 16 | from modules.textual_inversion.textual_inversion import save_embedding 17 | from .dataset import PersonalizedBase, PersonalizedDataLoader 18 | from ..hnutil import optim_to 19 | from ..scheduler import CosineAnnealingWarmUpRestarts 20 | from ..tbutils import tensorboard_setup, tensorboard_add_image 21 | 22 | # apply OsError avoid here 23 | delayed_values = {} 24 | 25 | 26 | def write_loss(log_directory, filename, step, epoch_len, values): 27 | if shared.opts.training_write_csv_every == 0: 28 | return 29 | 30 | if step % shared.opts.training_write_csv_every != 0: 31 | return 32 | write_csv_header = False if os.path.exists(os.path.join(log_directory, filename)) else True 33 | try: 34 | with open(os.path.join(log_directory, filename), "a+", newline='') as fout: 35 | csv_writer = csv.DictWriter(fout, fieldnames=["step", "epoch", "epoch_step", *(values.keys())]) 36 | 37 | if write_csv_header: 38 | csv_writer.writeheader() 39 | if log_directory + filename in delayed_values: 40 | delayed = delayed_values[log_directory + filename] 41 | for step, epoch, epoch_step, values in delayed: 42 | csv_writer.writerow({ 43 | "step": step, 44 | "epoch": epoch, 45 | "epoch_step": epoch_step, 46 | **values, 47 | }) 48 | delayed.clear() 49 | epoch, epoch_step = divmod(step - 1, epoch_len) 50 | csv_writer.writerow({ 51 | "step": step, 52 | "epoch": epoch, 53 | "epoch_step": epoch_step, 54 | **values, 55 | }) 56 | except OSError: 57 | epoch, epoch_step = divmod(step - 1, epoch_len) 58 | if log_directory + filename in delayed_values: 59 | delayed_values[log_directory + filename].append((step, epoch, epoch_step, values)) 60 | else: 61 | delayed_values[log_directory + filename] = [(step, epoch, epoch_step, values)] 62 | 63 | 64 | def validate_train_inputs(model_name, learn_rate, batch_size, gradient_step, data_root, template_file, steps, 65 | save_model_every, create_image_every, log_directory, name="embedding"): 66 | assert model_name, f"{name} not selected" 67 | assert learn_rate, "Learning rate is empty or 0" 68 | assert isinstance(batch_size, int), "Batch size must be integer" 69 | assert batch_size > 0, "Batch size must be positive" 70 | assert isinstance(gradient_step, int), "Gradient accumulation step must be integer" 71 | assert gradient_step > 0, "Gradient accumulation step must be positive" 72 | assert data_root, "Dataset directory is empty" 73 | assert os.path.isdir(data_root), "Dataset directory doesn't exist" 74 | assert os.listdir(data_root), "Dataset directory is empty" 75 | assert template_file, "Prompt template file is empty" 76 | assert os.path.isfile(template_file), "Prompt template file doesn't exist" 77 | assert steps, "Max steps is empty or 0" 78 | assert isinstance(steps, int), "Max steps must be integer" 79 | assert steps > 0, "Max steps must be positive" 80 | assert isinstance(save_model_every, int), "Save {name} must be integer" 81 | assert save_model_every >= 0, "Save {name} must be positive or 0" 82 | assert isinstance(create_image_every, int), "Create image must be integer" 83 | assert create_image_every >= 0, "Create image must be positive or 0" 84 | if save_model_every or create_image_every: 85 | assert log_directory, "Log directory is empty" 86 | 87 | 88 | def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_step, data_root, log_directory, 89 | training_width, 90 | 
training_height, steps, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, 91 | save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, 92 | preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, 93 | preview_seed, preview_width, preview_height, 94 | use_beta_scheduler=False, beta_repeat_epoch=4000, epoch_mult=1, warmup=10, min_lr=1e-7, 95 | gamma_rate=1, save_when_converge=False, create_when_converge=False, 96 | move_optimizer=True, 97 | use_adamw_parameter=False, adamw_weight_decay=0.01, adamw_beta_1=0.9, adamw_beta_2=0.99, 98 | adamw_eps=1e-8, 99 | use_grad_opts=False, gradient_clip_opt='None', optional_gradient_clip_value=1e01, 100 | optional_gradient_norm_type=2, latent_sampling_std=-1, use_weight=False 101 | ): 102 | save_embedding_every = save_embedding_every or 0 103 | create_image_every = create_image_every or 0 104 | validate_train_inputs(embedding_name, learn_rate, batch_size, gradient_step, data_root, template_file, steps, 105 | save_embedding_every, create_image_every, log_directory, name="embedding") 106 | try: 107 | if use_adamw_parameter: 108 | adamw_weight_decay, adamw_beta_1, adamw_beta_2, adamw_eps = [float(x) for x in 109 | [adamw_weight_decay, adamw_beta_1, 110 | adamw_beta_2, adamw_eps]] 111 | assert 0 <= adamw_weight_decay, "Weight decay paramter should be larger or equal than zero!" 112 | assert (all(0 <= x <= 1 for x in [adamw_beta_1, adamw_beta_2, 113 | adamw_eps])), "Cannot use negative or >1 number for adamW parameters!" 114 | adamW_kwarg_dict = { 115 | 'weight_decay': adamw_weight_decay, 116 | 'betas': (adamw_beta_1, adamw_beta_2), 117 | 'eps': adamw_eps 118 | } 119 | print('Using custom AdamW parameters') 120 | else: 121 | adamW_kwarg_dict = { 122 | 'weight_decay': 0.01, 123 | 'betas': (0.9, 0.99), 124 | 'eps': 1e-8 125 | } 126 | if use_beta_scheduler: 127 | print("Using Beta Scheduler") 128 | beta_repeat_epoch = int(beta_repeat_epoch) 129 | assert beta_repeat_epoch > 0, f"Cannot use too small cycle {beta_repeat_epoch}!" 130 | min_lr = float(min_lr) 131 | assert min_lr < 1, f"Cannot use minimum lr with {min_lr}!" 132 | gamma_rate = float(gamma_rate) 133 | print(f"Using learn rate decay(per cycle) of {gamma_rate}") 134 | assert 0 <= gamma_rate <= 1, f"Cannot use gamma rate with {gamma_rate}!" 135 | epoch_mult = float(epoch_mult) 136 | assert 1 <= epoch_mult, "Cannot use epoch multiplier smaller than 1!" 137 | warmup = int(warmup) 138 | assert warmup >= 1, "Warmup epoch should be larger than 0!" 
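# Illustrative settings (mirroring the UI defaults in external_pr/ui.py, not prescriptive):
# beta_repeat_epoch=64, epoch_mult=1, warmup=5, min_lr=6e-7, gamma_rate=1 gives cosine cycles of
# 64 steps with a 5-step warmup, an LR floor of 6e-7, and no per-cycle decay (gamma_rate == 1).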
139 | print(f"Save when converges : {save_when_converge}") 140 | print(f"Generate image when converges : {create_when_converge}") 141 | else: 142 | beta_repeat_epoch = 4000 143 | epoch_mult = 1 144 | warmup = 10 145 | min_lr = 1e-7 146 | gamma_rate = 1 147 | save_when_converge = False 148 | create_when_converge = False 149 | except ValueError: 150 | raise RuntimeError("Cannot use advanced LR scheduler settings!") 151 | if use_grad_opts and gradient_clip_opt != "None": 152 | try: 153 | optional_gradient_clip_value = float(optional_gradient_clip_value) 154 | except ValueError: 155 | raise RuntimeError(f"Cannot convert invalid gradient clipping value {optional_gradient_clip_value})") 156 | if gradient_clip_opt == "Norm": 157 | try: 158 | grad_norm = int(optional_gradient_norm_type) 159 | except ValueError: 160 | raise RuntimeError(f"Cannot convert invalid gradient norm type {optional_gradient_norm_type})") 161 | assert grad_norm >= 0, f"P-norm cannot be calculated from negative number {grad_norm}" 162 | 163 | def gradient_clipping(arg1): 164 | torch.nn.utils.clip_grad_norm_(arg1, optional_gradient_clip_value, optional_gradient_norm_type) 165 | return 166 | else: 167 | def gradient_clipping(arg1): 168 | torch.nn.utils.clip_grad_value_(arg1, optional_gradient_clip_value) 169 | return 170 | else: 171 | def gradient_clipping(arg1): 172 | return 173 | # Function gradient clipping is inplace(_) operation. 174 | shared.state.job = "train-embedding" 175 | shared.state.textinfo = "Initializing textual inversion training..." 176 | shared.state.job_count = steps 177 | 178 | filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt') 179 | 180 | log_directory = os.path.join(log_directory, datetime.datetime.now().strftime("%Y-%m-%d"), embedding_name) 181 | unload = shared.opts.unload_models_when_training 182 | 183 | if save_embedding_every > 0 or save_when_converge: 184 | embedding_dir = os.path.join(log_directory, "embeddings") 185 | os.makedirs(embedding_dir, exist_ok=True) 186 | else: 187 | embedding_dir = None 188 | 189 | if create_image_every > 0 or create_when_converge: 190 | images_dir = os.path.join(log_directory, "images") 191 | os.makedirs(images_dir, exist_ok=True) 192 | else: 193 | images_dir = None 194 | 195 | if (create_image_every > 0 or create_when_converge) and save_image_with_stored_embedding: 196 | images_embeds_dir = os.path.join(log_directory, "image_embeddings") 197 | os.makedirs(images_embeds_dir, exist_ok=True) 198 | else: 199 | images_embeds_dir = None 200 | 201 | hijack = sd_hijack.model_hijack 202 | 203 | embedding = hijack.embedding_db.word_embeddings[embedding_name] 204 | checkpoint = sd_models.select_checkpoint() 205 | 206 | initial_step = embedding.step or 0 207 | if initial_step >= steps: 208 | shared.state.textinfo = f"Model has already been trained beyond specified max steps" 209 | return embedding, filename 210 | scheduler = LearnRateScheduler(learn_rate, steps, initial_step) 211 | 212 | # dataset loading may take a while, so input validations and early returns should be done before this 213 | shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..." 
214 | old_parallel_processing_allowed = shared.parallel_processing_allowed 215 | 216 | tensorboard_writer = None 217 | if shared.opts.training_enable_tensorboard: 218 | print("Tensorboard logging enabled") 219 | tensorboard_writer = tensorboard_setup(log_directory) 220 | 221 | pin_memory = shared.opts.pin_memory 222 | detach_grad = shared.opts.disable_ema # test code that removes EMA 223 | if detach_grad: 224 | print("Disabling training for staged models!") 225 | shared.sd_model.cond_stage_model.requires_grad_(False) 226 | shared.sd_model.first_stage_model.requires_grad_(False) 227 | torch.cuda.empty_cache() 228 | ds = PersonalizedBase(data_root=data_root, width=training_width, 229 | height=training_height, 230 | repeats=shared.opts.training_image_repeats_per_epoch, 231 | placeholder_token=embedding_name, model=shared.sd_model, 232 | cond_model=shared.sd_model.cond_stage_model, 233 | device=devices.device, template_file=template_file, 234 | batch_size=batch_size, gradient_step=gradient_step, 235 | shuffle_tags=shuffle_tags, tag_drop_out=tag_drop_out, 236 | latent_sampling_method=latent_sampling_method, 237 | latent_sampling_std=latent_sampling_std, use_weight=use_weight) 238 | 239 | latent_sampling_method = ds.latent_sampling_method 240 | 241 | dl = PersonalizedDataLoader(ds, latent_sampling_method=latent_sampling_method, 242 | batch_size=ds.batch_size, pin_memory=pin_memory) 243 | if unload: 244 | shared.parallel_processing_allowed = False 245 | shared.sd_model.first_stage_model.to(devices.cpu) 246 | 247 | embedding.vec.requires_grad_(True) 248 | optimizer_name = 'AdamW' # hardcoded optimizer name now 249 | if use_adamw_parameter: 250 | optimizer = torch.optim.AdamW(params=[embedding.vec], lr=scheduler.learn_rate, **adamW_kwarg_dict) 251 | else: 252 | optimizer = torch.optim.AdamW(params=[embedding.vec], lr=scheduler.learn_rate, weight_decay=0.0) 253 | 254 | if os.path.exists( 255 | filename + '.optim'): # This line must be changed if Optimizer type can be different from saved optimizer. 
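# Resume sketch: the companion '<embedding>.pt.optim' file is expected to hold a dict with
# 'hash' (embedding.checksum()) and 'optimizer_state_dict'; the state below is only restored
# when the stored hash matches the current embedding, otherwise training continues with the
# freshly created optimizer state.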
256 | try: 257 | optimizer_saved_dict = torch.load(filename + '.optim', map_location='cpu') 258 | if embedding.checksum() == optimizer_saved_dict.get('hash', None): 259 | optimizer_state_dict = optimizer_saved_dict.get('optimizer_state_dict', None) 260 | if optimizer_state_dict is not None: 261 | optimizer.load_state_dict(optimizer_state_dict) 262 | print("Loaded existing optimizer from checkpoint") 263 | except RuntimeError as e: 264 | print("Cannot resume from saved optimizer!") 265 | print(e) 266 | else: 267 | print("No saved optimizer exists in checkpoint") 268 | if move_optimizer: 269 | optim_to(optimizer, devices.device) 270 | if use_beta_scheduler: 271 | scheduler_beta = CosineAnnealingWarmUpRestarts(optimizer=optimizer, first_cycle_steps=beta_repeat_epoch, 272 | cycle_mult=epoch_mult, max_lr=scheduler.learn_rate, 273 | warmup_steps=warmup, min_lr=min_lr, gamma=gamma_rate) 274 | scheduler_beta.last_epoch = embedding.step - 1 275 | else: 276 | scheduler_beta = None 277 | for pg in optimizer.param_groups: 278 | pg['lr'] = scheduler.learn_rate 279 | 280 | scaler = torch.cuda.amp.GradScaler() 281 | 282 | batch_size = ds.batch_size 283 | gradient_step = ds.gradient_step 284 | # n steps = batch_size * gradient_step * n image processed 285 | steps_per_epoch = len(ds) // batch_size // gradient_step 286 | max_steps_per_epoch = len(ds) // batch_size - (len(ds) // batch_size) % gradient_step 287 | loss_step = 0 288 | _loss_step = 0 # internal 289 | 290 | last_saved_file = "" 291 | last_saved_image = "" 292 | forced_filename = "" 293 | embedding_yet_to_be_embedded = False 294 | 295 | is_training_inpainting_model = shared.sd_model.model.conditioning_key in {'hybrid', 'concat'} 296 | img_c = None 297 | 298 | pbar = tqdm.tqdm(total=steps - initial_step) 299 | if hasattr(sd_hijack_checkpoint, 'add'): 300 | sd_hijack_checkpoint.add() 301 | try: 302 | for i in range((steps - initial_step) * gradient_step): 303 | if scheduler.finished: 304 | break 305 | if shared.state.interrupted: 306 | break 307 | for j, batch in enumerate(dl): 308 | # works as a drop_last=True for gradient accumulation 309 | if j == max_steps_per_epoch: 310 | break 311 | if use_beta_scheduler: 312 | scheduler_beta.step(embedding.step) 313 | else: 314 | scheduler.apply(optimizer, embedding.step) 315 | if scheduler.finished: 316 | break 317 | if shared.state.interrupted: 318 | break 319 | 320 | with devices.autocast(): 321 | x = batch.latent_sample.to(devices.device, non_blocking=pin_memory) 322 | if use_weight: 323 | w = batch.weight.to(devices.device, non_blocking=pin_memory) 324 | shared.sd_model.cond_stage_model.to(devices.device) 325 | c = shared.sd_model.cond_stage_model(batch.cond_text) 326 | if is_training_inpainting_model: 327 | if img_c is None: 328 | img_c = processing.txt2img_image_conditioning(shared.sd_model, c, training_width, 329 | training_height) 330 | 331 | cond = {"c_concat": [img_c], "c_crossattn": [c]} 332 | else: 333 | cond = c 334 | if use_weight: 335 | loss = shared.sd_model.weighted_forward(x, c, w)[0] / gradient_step 336 | del w 337 | else: 338 | loss = shared.sd_model.forward(x, cond)[0] / gradient_step 339 | del x 340 | _loss_step += loss.item() 341 | scaler.scale(loss).backward() 342 | # go back until we reach gradient accumulation steps 343 | if (j + 1) % gradient_step != 0: 344 | continue 345 | gradient_clipping(embedding.vec) 346 | try: 347 | scaler.step(optimizer) 348 | except AssertionError: 349 | raise RuntimeError("This error happens because None of the template used embedding's text!") 350 | 
scaler.update() 351 | embedding.step += 1 352 | pbar.update() 353 | optimizer.zero_grad(set_to_none=True) 354 | loss_step = _loss_step 355 | _loss_step = 0 356 | 357 | steps_done = embedding.step + 1 358 | 359 | epoch_num = embedding.step // steps_per_epoch 360 | epoch_step = embedding.step % steps_per_epoch 361 | 362 | pbar.set_description(f"[Epoch {epoch_num}: {epoch_step + 1}/{steps_per_epoch}]loss: {loss_step:.7f}") 363 | if embedding_dir is not None and ( 364 | (use_beta_scheduler and scheduler_beta.is_EOC(embedding.step) and save_when_converge) or ( 365 | save_embedding_every > 0 and steps_done % save_embedding_every == 0)): 366 | # Before saving, change name to match current checkpoint. 367 | embedding_name_every = f'{embedding_name}-{steps_done}' 368 | last_saved_file = os.path.join(embedding_dir, f'{embedding_name_every}.pt') 369 | # if shared.opts.save_optimizer_state: 370 | # embedding.optimizer_state_dict = optimizer.state_dict() 371 | save_embedding(embedding, optimizer, checkpoint, embedding_name_every, last_saved_file, 372 | remove_cached_checksum=True) 373 | embedding_yet_to_be_embedded = True 374 | 375 | write_loss(log_directory, "textual_inversion_loss.csv", embedding.step, steps_per_epoch, { 376 | "loss": f"{loss_step:.7f}", 377 | "learn_rate": scheduler.learn_rate 378 | }) 379 | 380 | if images_dir is not None and ( 381 | (use_beta_scheduler and scheduler_beta.is_EOC(embedding.step) and create_when_converge) or ( 382 | create_image_every > 0 and steps_done % create_image_every == 0)): 383 | forced_filename = f'{embedding_name}-{steps_done}' 384 | last_saved_image = os.path.join(images_dir, forced_filename) 385 | rng_state = torch.get_rng_state() 386 | cuda_rng_state = None 387 | if torch.cuda.is_available(): 388 | cuda_rng_state = torch.cuda.get_rng_state_all() 389 | if move_optimizer: 390 | optim_to(optimizer, devices.cpu) 391 | gc.collect() 392 | shared.sd_model.first_stage_model.to(devices.device) 393 | 394 | p = processing.StableDiffusionProcessingTxt2Img( 395 | sd_model=shared.sd_model, 396 | do_not_save_grid=True, 397 | do_not_save_samples=True, 398 | do_not_reload_embeddings=True, 399 | ) 400 | 401 | if preview_from_txt2img: 402 | p.prompt = preview_prompt 403 | p.negative_prompt = preview_negative_prompt 404 | p.steps = preview_steps 405 | p.sampler_name = sd_samplers.samplers[preview_sampler_index].name 406 | p.cfg_scale = preview_cfg_scale 407 | p.seed = preview_seed 408 | p.width = preview_width 409 | p.height = preview_height 410 | else: 411 | p.prompt = batch.cond_text[0] 412 | p.steps = 20 413 | p.width = training_width 414 | p.height = training_height 415 | 416 | preview_text = p.prompt 417 | if hasattr(p, 'disable_extra_networks'): 418 | p.disable_extra_networks = True 419 | is_patched = True 420 | else: 421 | is_patched = False 422 | processed = processing.process_images(p) 423 | image = processed.images[0] if len(processed.images) > 0 else None 424 | 425 | if move_optimizer: 426 | optim_to(optimizer, devices.device) 427 | if image is not None: 428 | if hasattr(shared.state, 'assign_current_image'): 429 | shared.state.assign_current_image(image) 430 | else: 431 | shared.state.current_image = image 432 | last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, 433 | shared.opts.samples_format, 434 | processed.infotexts[0], p=p, 435 | forced_filename=forced_filename, 436 | save_to_dirs=False) 437 | last_saved_image += f", prompt: {preview_text}" 438 | if shared.opts.training_enable_tensorboard and 
shared.opts.training_tensorboard_save_images: 439 | tensorboard_add_image(tensorboard_writer, f"Validation at epoch {epoch_num}", image, 440 | embedding.step) 441 | 442 | if save_image_with_stored_embedding and os.path.exists( 443 | last_saved_file) and embedding_yet_to_be_embedded: 444 | 445 | last_saved_image_chunks = os.path.join(images_embeds_dir, f'{embedding_name}-{steps_done}.png') 446 | 447 | info = PngImagePlugin.PngInfo() 448 | data = torch.load(last_saved_file) 449 | info.add_text("sd-ti-embedding", embedding_to_b64(data)) 450 | 451 | title = "<{}>".format(data.get('name', '???')) 452 | 453 | try: 454 | vectorSize = list(data['string_to_param'].values())[0].shape[0] 455 | except Exception as e: 456 | vectorSize = '?' 457 | 458 | checkpoint = sd_models.select_checkpoint() 459 | footer_left = checkpoint.model_name 460 | footer_mid = '[{}]'.format( 461 | checkpoint.shorthash if hasattr(checkpoint, 'shorthash') else checkpoint.hash) 462 | footer_right = '{}v {}s'.format(vectorSize, steps_done) 463 | 464 | captioned_image = caption_image_overlay(image, title, footer_left, footer_mid, footer_right) 465 | captioned_image = insert_image_data_embed(captioned_image, data) 466 | 467 | captioned_image.save(last_saved_image_chunks, "PNG", pnginfo=info) 468 | embedding_yet_to_be_embedded = False 469 | if unload: 470 | shared.sd_model.first_stage_model.to(devices.cpu) 471 | torch.set_rng_state(rng_state) 472 | if torch.cuda.is_available(): 473 | torch.cuda.set_rng_state_all(cuda_rng_state) 474 | last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, 475 | shared.opts.samples_format, 476 | processed.infotexts[0], p=p, 477 | forced_filename=forced_filename, 478 | save_to_dirs=False) 479 | last_saved_image += f", prompt: {preview_text}" 480 | 481 | shared.state.job_no = embedding.step 482 | 483 | shared.state.textinfo = f""" 484 |

485 | Loss: {loss_step:.7f}<br/>
486 | Step: {steps_done}<br/>
487 | Last prompt: {html.escape(batch.cond_text[0])}<br/>
488 | Last saved embedding: {html.escape(last_saved_file)}<br/>
489 | Last saved image: {html.escape(last_saved_image)}<br/>
490 |

491 | """ 492 | filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt') 493 | save_embedding(embedding, optimizer, checkpoint, embedding_name, filename, remove_cached_checksum=True) 494 | except Exception: 495 | print(traceback.format_exc(), file=sys.stderr) 496 | pass 497 | finally: 498 | pbar.leave = False 499 | pbar.close() 500 | shared.sd_model.first_stage_model.to(devices.device) 501 | shared.parallel_processing_allowed = old_parallel_processing_allowed 502 | if hasattr(sd_hijack_checkpoint, 'remove'): 503 | sd_hijack_checkpoint.remove() 504 | return embedding, filename 505 | -------------------------------------------------------------------------------- /patches/hypernetwork.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import inspect 3 | import os 4 | import sys 5 | import traceback 6 | 7 | import torch 8 | from torch.nn.init import normal_, xavier_uniform_, zeros_, xavier_normal_, kaiming_uniform_, kaiming_normal_ 9 | 10 | try: 11 | from modules.hashes import sha256 12 | except (ImportError, ModuleNotFoundError): 13 | print("modules.hashes is not found, will use backup module from extension!") 14 | from .hashes_backup import sha256 15 | 16 | import modules.hypernetworks.hypernetwork 17 | from modules import devices, shared, sd_models 18 | from .hnutil import parse_dropout_structure, find_self 19 | from .shared import version_flag 20 | 21 | def init_weight(layer, weight_init="Normal", normal_std=0.01, activation_func="relu"): 22 | w, b = layer.weight.data, layer.bias.data 23 | if weight_init == "Normal" or type(layer) == torch.nn.LayerNorm: 24 | normal_(w, mean=0.0, std=normal_std) 25 | normal_(b, mean=0.0, std=0) 26 | elif weight_init == 'XavierUniform': 27 | xavier_uniform_(w) 28 | zeros_(b) 29 | elif weight_init == 'XavierNormal': 30 | xavier_normal_(w) 31 | zeros_(b) 32 | elif weight_init == 'KaimingUniform': 33 | kaiming_uniform_(w, nonlinearity='leaky_relu' if 'leakyrelu' == activation_func else 'relu') 34 | zeros_(b) 35 | elif weight_init == 'KaimingNormal': 36 | kaiming_normal_(w, nonlinearity='leaky_relu' if 'leakyrelu' == activation_func else 'relu') 37 | zeros_(b) 38 | else: 39 | raise KeyError(f"Key {weight_init} is not defined as initialization!") 40 | 41 | 42 | class ResBlock(torch.nn.Module): 43 | """Residual Block""" 44 | def __init__(self, n_inputs, n_outputs, activation_func, weight_init, add_layer_norm, dropout_p, normal_std, device=None, state_dict=None, **kwargs): 45 | super().__init__() 46 | self.n_outputs = n_outputs 47 | self.upsample_layer = None 48 | self.upsample = kwargs.get("upsample_model", None) 49 | if self.upsample == "Linear": 50 | self.upsample_layer = torch.nn.Linear(n_inputs, n_outputs, bias=False) 51 | linears = [torch.nn.Linear(n_inputs, n_outputs)] 52 | init_weight(linears[0], weight_init, normal_std, activation_func) 53 | if add_layer_norm: 54 | linears.append(torch.nn.LayerNorm(n_outputs)) 55 | init_weight(linears[1], weight_init, normal_std, activation_func) 56 | if dropout_p > 0: 57 | linears.append(torch.nn.Dropout(p=dropout_p)) 58 | if activation_func == "linear" or activation_func is None: 59 | pass 60 | elif activation_func in HypernetworkModule.activation_dict: 61 | linears.append(HypernetworkModule.activation_dict[activation_func]()) 62 | else: 63 | raise RuntimeError(f'hypernetwork uses an unsupported activation function: {activation_func}') 64 | self.linear = torch.nn.Sequential(*linears) 65 | if state_dict is not None: 66 | 
self.load_state_dict(state_dict) 67 | if device is not None: 68 | self.to(device) 69 | 70 | def trainables(self, train=False): 71 | layer_structure = [] 72 | for layer in self.linear: 73 | if train: 74 | layer.train() 75 | else: 76 | layer.eval() 77 | if type(layer) == torch.nn.Linear or type(layer) == torch.nn.LayerNorm: 78 | layer_structure += [layer.weight, layer.bias] 79 | return layer_structure 80 | 81 | def forward(self, x, **kwargs): 82 | if self.upsample_layer is None: 83 | interpolated = torch.nn.functional.interpolate(x, size=self.n_outputs, mode="nearest-exact") 84 | else: 85 | interpolated = self.upsample_layer(x) 86 | return interpolated + self.linear(x) 87 | 88 | 89 | 90 | class HypernetworkModule(torch.nn.Module): 91 | multiplier = 1.0 92 | activation_dict = { 93 | "linear": torch.nn.Identity, 94 | "relu": torch.nn.ReLU, 95 | "leakyrelu": torch.nn.LeakyReLU, 96 | "elu": torch.nn.ELU, 97 | "swish": torch.nn.Hardswish, 98 | "tanh": torch.nn.Tanh, 99 | "sigmoid": torch.nn.Sigmoid, 100 | } 101 | activation_dict.update({cls_name.lower(): cls_obj for cls_name, cls_obj in inspect.getmembers(torch.nn.modules.activation) if inspect.isclass(cls_obj) and cls_obj.__module__ == 'torch.nn.modules.activation'}) 102 | 103 | def __init__(self, dim, state_dict=None, layer_structure=None, activation_func=None, weight_init='Normal', 104 | add_layer_norm=False, activate_output=False, dropout_structure=None, device=None, generation_seed=None, normal_std=0.01, **kwargs): 105 | super().__init__() 106 | self.skip_connection = skip_connection = kwargs.get('skip_connection', False) 107 | upsample_linear = kwargs.get('upsample_linear', None) 108 | assert layer_structure is not None, "layer_structure must not be None" 109 | assert layer_structure[0] == 1, "Multiplier Sequence should start with size 1!" 110 | assert layer_structure[-1] == 1, "Multiplier Sequence should end with size 1!" 111 | # instead of throwing error, maybe try warning. first value is always not used. 112 | if not (skip_connection or dropout_structure is None or dropout_structure[0] == dropout_structure[-1] == 0): 113 | print("Dropout sequence does not starts or ends with zero.") 114 | # assert skip_connection or dropout_structure is None or dropout_structure[0] == dropout_structure[-1] == 0, "Dropout Sequence should start and end with probability 0!" 115 | assert dropout_structure is None or len(dropout_structure) == len(layer_structure), "Dropout Sequence should match length with layer structure!" 
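# Illustrative pairing that satisfies the checks above (not from the original file):
# layer_structure=[1, 2, 2, 1] with dropout_structure=[0, 0.3, 0.3, 0] -- same length, and the
# structures start/end with 1 and 0 respectively; see the worked mapping in the comment further down.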
116 | 117 | linears = [] 118 | if skip_connection: 119 | if generation_seed is not None: 120 | torch.manual_seed(generation_seed) 121 | for i in range(len(layer_structure) - 1): 122 | if skip_connection: 123 | n_inputs, n_outputs = int(dim * layer_structure[i]), int(dim * layer_structure[i+1]) 124 | dropout_p = dropout_structure[i+1] 125 | if activation_func is None: 126 | activation_func = "linear" 127 | linears.append(ResBlock(n_inputs, n_outputs, activation_func, weight_init, add_layer_norm, dropout_p, normal_std, device, upsample_model=upsample_linear)) 128 | continue 129 | 130 | # Add a fully-connected layer 131 | linears.append(torch.nn.Linear(int(dim * layer_structure[i]), int(dim * layer_structure[i+1]))) 132 | 133 | # Add an activation func except last layer 134 | if activation_func == "linear" or activation_func is None or (i >= len(layer_structure) - 2 and not activate_output): 135 | pass 136 | elif activation_func in self.activation_dict: 137 | linears.append(self.activation_dict[activation_func]()) 138 | else: 139 | raise RuntimeError(f'hypernetwork uses an unsupported activation function: {activation_func}') 140 | 141 | # Add layer normalization 142 | if add_layer_norm: 143 | linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i+1]))) 144 | 145 | # Everything should be now parsed into dropout structure, and applied here. 146 | # Since we only have dropouts after layers, dropout structure should start with 0 and end with 0. 147 | if dropout_structure is not None and dropout_structure[i+1] > 0: 148 | assert 0 < dropout_structure[i+1] < 1, "Dropout probability should be 0 or float between 0 and 1!" 149 | linears.append(torch.nn.Dropout(p=dropout_structure[i+1])) 150 | # Code explanation : [1, 2, 1] -> dropout is missing when last_layer_dropout is false. [1, 2, 2, 1] -> [0, 0.3, 0, 0], when its True, [0, 0.3, 0.3, 0]. 
151 | 152 | self.linear = torch.nn.Sequential(*linears) 153 | 154 | if state_dict is not None: 155 | self.fix_old_state_dict(state_dict) 156 | self.load_state_dict(state_dict) 157 | elif not skip_connection: 158 | if generation_seed is not None: 159 | torch.manual_seed(generation_seed) 160 | for layer in self.linear: 161 | if type(layer) == torch.nn.Linear or type(layer) == torch.nn.LayerNorm: 162 | w, b = layer.weight.data, layer.bias.data 163 | if weight_init == "Normal" or type(layer) == torch.nn.LayerNorm: 164 | normal_(w, mean=0.0, std=normal_std) 165 | normal_(b, mean=0.0, std=0) 166 | elif weight_init == 'XavierUniform': 167 | xavier_uniform_(w) 168 | zeros_(b) 169 | elif weight_init == 'XavierNormal': 170 | xavier_normal_(w) 171 | zeros_(b) 172 | elif weight_init == 'KaimingUniform': 173 | kaiming_uniform_(w, nonlinearity='leaky_relu' if 'leakyrelu' == activation_func else 'relu') 174 | zeros_(b) 175 | elif weight_init == 'KaimingNormal': 176 | kaiming_normal_(w, nonlinearity='leaky_relu' if 'leakyrelu' == activation_func else 'relu') 177 | zeros_(b) 178 | else: 179 | raise KeyError(f"Key {weight_init} is not defined as initialization!") 180 | if device is None: 181 | self.to(devices.device) 182 | else: 183 | self.to(device) 184 | 185 | 186 | def fix_old_state_dict(self, state_dict): 187 | changes = { 188 | 'linear1.bias': 'linear.0.bias', 189 | 'linear1.weight': 'linear.0.weight', 190 | 'linear2.bias': 'linear.1.bias', 191 | 'linear2.weight': 'linear.1.weight', 192 | } 193 | 194 | for fr, to in changes.items(): 195 | x = state_dict.get(fr, None) 196 | if x is None: 197 | continue 198 | 199 | del state_dict[fr] 200 | state_dict[to] = x 201 | 202 | def forward(self, x, multiplier=None): 203 | if self.skip_connection: 204 | if self.training: 205 | return self.linear(x) 206 | else: 207 | resnet_result = self.linear(x) 208 | residual = resnet_result - x 209 | if multiplier is None or not isinstance(multiplier, (int, float)): 210 | multiplier = self.multiplier if not version_flag else HypernetworkModule.multiplier 211 | return x + multiplier * residual # interpolate 212 | if multiplier is None or not isinstance(multiplier, (int, float)): 213 | return x + self.linear(x) * ((self.multiplier if not version_flag else HypernetworkModule.multiplier) if not self.training else 1) 214 | return x + self.linear(x) * multiplier 215 | 216 | def trainables(self, train=False): 217 | layer_structure = [] 218 | self.train(train) 219 | for layer in self.linear: 220 | if train: 221 | layer.train() 222 | else: 223 | layer.eval() 224 | if type(layer) == torch.nn.Linear or type(layer) == torch.nn.LayerNorm: 225 | layer_structure += [layer.weight, layer.bias] 226 | elif type(layer) == ResBlock: 227 | layer_structure += layer.trainables(train) 228 | return layer_structure 229 | 230 | def set_train(self,mode=True): 231 | self.train(mode) 232 | for layer in self.linear: 233 | if mode: 234 | layer.train(mode) 235 | else: 236 | layer.eval() 237 | 238 | 239 | class Hypernetwork: 240 | filename = None 241 | name = None 242 | 243 | def __init__(self, name=None, enable_sizes=None, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False, activate_output=False, **kwargs): 244 | self.filename = None 245 | self.name = name 246 | self.layers = {} 247 | self.step = 0 248 | self.sd_checkpoint = None 249 | self.sd_checkpoint_name = None 250 | self.layer_structure = layer_structure 251 | self.activation_func = activation_func 252 | self.weight_init = weight_init 253 | 
self.add_layer_norm = add_layer_norm 254 | self.use_dropout = use_dropout 255 | self.activate_output = activate_output 256 | self.last_layer_dropout = kwargs['last_layer_dropout'] if 'last_layer_dropout' in kwargs else True 257 | self.optimizer_name = None 258 | self.optimizer_state_dict = None 259 | self.dropout_structure = kwargs['dropout_structure'] if 'dropout_structure' in kwargs and use_dropout else None 260 | self.optional_info = kwargs.get('optional_info', None) 261 | self.skip_connection = kwargs.get('skip_connection', False) 262 | self.upsample_linear = kwargs.get('upsample_linear', None) 263 | self.training = False 264 | generation_seed = kwargs.get('generation_seed', None) 265 | normal_std = kwargs.get('normal_std', 0.01) 266 | if self.dropout_structure is None: 267 | self.dropout_structure = parse_dropout_structure(self.layer_structure, self.use_dropout, self.last_layer_dropout) 268 | 269 | for size in enable_sizes or []: 270 | self.layers[size] = ( 271 | HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init, 272 | self.add_layer_norm, self.activate_output, dropout_structure=self.dropout_structure, generation_seed=generation_seed, normal_std=normal_std, skip_connection=self.skip_connection, 273 | upsample_linear=self.upsample_linear), 274 | HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init, 275 | self.add_layer_norm, self.activate_output, dropout_structure=self.dropout_structure, generation_seed=generation_seed, normal_std=normal_std, skip_connection=self.skip_connection, 276 | upsample_linear=self.upsample_linear), 277 | ) 278 | self.eval() 279 | 280 | def weights(self, train=False): 281 | self.training = train 282 | res = [] 283 | for k, layers in self.layers.items(): 284 | for layer in layers: 285 | res += layer.trainables(train) 286 | return res 287 | 288 | def eval(self): 289 | self.training = False 290 | for k, layers in self.layers.items(): 291 | for layer in layers: 292 | layer.eval() 293 | layer.set_train(False) 294 | 295 | def train(self, mode=True): 296 | self.training = mode 297 | for k, layers in self.layers.items(): 298 | for layer in layers: 299 | layer.set_train(mode) 300 | 301 | def detach_grad(self): 302 | for k, layers in self.layers.items(): 303 | for layer in layers: 304 | layer.requires_grad_(False) 305 | 306 | def shorthash(self): 307 | sha256v = sha256(self.filename, f'hypernet/{self.name}') 308 | return sha256v[0:10] 309 | 310 | def extra_name(self): 311 | if version_flag: 312 | return "" 313 | found = find_self(self) 314 | if found is not None: 315 | return f" " 316 | return f" " 317 | 318 | def save(self, filename): 319 | state_dict = {} 320 | optimizer_saved_dict = {} 321 | 322 | for k, v in self.layers.items(): 323 | state_dict[k] = (v[0].state_dict(), v[1].state_dict()) 324 | 325 | state_dict['step'] = self.step 326 | state_dict['name'] = self.name 327 | state_dict['layer_structure'] = self.layer_structure 328 | state_dict['activation_func'] = self.activation_func 329 | state_dict['is_layer_norm'] = self.add_layer_norm 330 | state_dict['weight_initialization'] = self.weight_init 331 | state_dict['sd_checkpoint'] = self.sd_checkpoint 332 | state_dict['sd_checkpoint_name'] = self.sd_checkpoint_name 333 | state_dict['activate_output'] = self.activate_output 334 | state_dict['use_dropout'] = self.use_dropout 335 | state_dict['dropout_structure'] = self.dropout_structure 336 | state_dict['last_layer_dropout'] = (self.dropout_structure[-2] != 0) if self.dropout_structure is 
not None else self.last_layer_dropout 337 | state_dict['optional_info'] = self.optional_info if self.optional_info else None 338 | state_dict['skip_connection'] = self.skip_connection 339 | state_dict['upsample_linear'] = self.upsample_linear 340 | 341 | if self.optimizer_name is not None: 342 | optimizer_saved_dict['optimizer_name'] = self.optimizer_name 343 | 344 | torch.save(state_dict, filename) 345 | if shared.opts.save_optimizer_state and self.optimizer_state_dict: 346 | optimizer_saved_dict['hash'] = self.shorthash() # this is necessary 347 | optimizer_saved_dict['optimizer_state_dict'] = self.optimizer_state_dict 348 | torch.save(optimizer_saved_dict, filename + '.optim') 349 | 350 | def load(self, filename): 351 | self.filename = filename 352 | if self.name is None: 353 | self.name = os.path.splitext(os.path.basename(filename))[0] 354 | 355 | state_dict = torch.load(filename, map_location='cpu') 356 | 357 | self.layer_structure = state_dict.get('layer_structure', [1, 2, 1]) 358 | print(self.layer_structure) 359 | optional_info = state_dict.get('optional_info', None) 360 | if optional_info is not None: 361 | self.optional_info = optional_info 362 | self.activation_func = state_dict.get('activation_func', None) 363 | self.weight_init = state_dict.get('weight_initialization', 'Normal') 364 | self.add_layer_norm = state_dict.get('is_layer_norm', False) 365 | self.dropout_structure = state_dict.get('dropout_structure', None) 366 | self.use_dropout = True if self.dropout_structure is not None and any(self.dropout_structure) else state_dict.get('use_dropout', False) 367 | self.activate_output = state_dict.get('activate_output', True) 368 | self.last_layer_dropout = state_dict.get('last_layer_dropout', False) # Silent fix for HNs before 4918eb6 369 | self.skip_connection = state_dict.get('skip_connection', False) 370 | self.upsample_linear = state_dict.get('upsample_linear', False) 371 | # Dropout structure should have same length as layer structure, Every digits should be in [0,1), and last digit must be 0. 
372 | if self.dropout_structure is None: 373 | self.dropout_structure = parse_dropout_structure(self.layer_structure, self.use_dropout, self.last_layer_dropout) 374 | if hasattr(shared.opts, 'print_hypernet_extra') and shared.opts.print_hypernet_extra: 375 | if optional_info is not None: 376 | print(f"INFO:\n {optional_info}\n") 377 | print(f"Activation function is {self.activation_func}") 378 | print(f"Weight initialization is {self.weight_init}") 379 | print(f"Layer norm is set to {self.add_layer_norm}") 380 | print(f"Dropout usage is set to {self.use_dropout}") 381 | print(f"Activate last layer is set to {self.activate_output}") 382 | print(f"Dropout structure is set to {self.dropout_structure}") 383 | optimizer_saved_dict = torch.load(self.filename + '.optim', map_location = 'cpu') if os.path.exists(self.filename + '.optim') else {} 384 | self.optimizer_name = state_dict.get('optimizer_name', 'AdamW') 385 | 386 | if optimizer_saved_dict.get('hash', None) == self.shorthash() or optimizer_saved_dict.get('hash', None) == sd_models.model_hash(filename): 387 | self.optimizer_state_dict = optimizer_saved_dict.get('optimizer_state_dict', None) 388 | else: 389 | self.optimizer_state_dict = None 390 | if self.optimizer_state_dict: 391 | self.optimizer_name = optimizer_saved_dict.get('optimizer_name', 'AdamW') 392 | print("Loaded existing optimizer from checkpoint") 393 | print(f"Optimizer name is {self.optimizer_name}") 394 | else: 395 | print("No saved optimizer exists in checkpoint") 396 | 397 | for size, sd in state_dict.items(): 398 | if type(size) == int: 399 | self.layers[size] = ( 400 | HypernetworkModule(size, sd[0], self.layer_structure, self.activation_func, self.weight_init, 401 | self.add_layer_norm, self.activate_output, self.dropout_structure, skip_connection=self.skip_connection, upsample_linear=self.upsample_linear), 402 | HypernetworkModule(size, sd[1], self.layer_structure, self.activation_func, self.weight_init, 403 | self.add_layer_norm, self.activate_output, self.dropout_structure, skip_connection=self.skip_connection, upsample_linear=self.upsample_linear), 404 | ) 405 | 406 | self.name = state_dict.get('name', self.name) 407 | self.step = state_dict.get('step', 0) 408 | self.sd_checkpoint = state_dict.get('sd_checkpoint', None) 409 | self.sd_checkpoint_name = state_dict.get('sd_checkpoint_name', None) 410 | self.eval() 411 | 412 | def to(self, device): 413 | for k, layers in self.layers.items(): 414 | for layer in layers: 415 | layer.to(device) 416 | 417 | return self 418 | 419 | def set_multiplier(self, multiplier): 420 | for k, layers in self.layers.items(): 421 | for layer in layers: 422 | layer.multiplier = multiplier 423 | 424 | return self 425 | 426 | def __call__(self, context, *args, **kwargs): 427 | return self.forward(context, *args, **kwargs) 428 | 429 | def forward(self, context, context_v=None, layer=None): 430 | context_layers = self.layers.get(context.shape[2], None) 431 | if context_v is None: 432 | context_v = context 433 | if context_layers is None: 434 | return context, context_v 435 | if layer is not None and hasattr(layer, 'hyper_k') and hasattr(layer, 'hyper_v'): 436 | layer.hyper_k = context_layers[0] 437 | layer.hyper_v = context_layers[1] 438 | transform_k, transform_v = context_layers[0](context), context_layers[1](context_v) 439 | return transform_k, transform_v 440 | 441 | 442 | def list_hypernetworks(path): 443 | res = {} 444 | for filename in sorted(glob.iglob(os.path.join(path, '**/*.pt'), recursive=True)): 445 | name = 
os.path.splitext(os.path.basename(filename))[0] 446 | idx = 0 447 | while name in res: 448 | idx += 1 449 | name = name + f"({idx})" 450 | # Prevent a hypothetical "None.pt" from being listed. 451 | if name != "None": 452 | res[name] = filename 453 | for filename in glob.iglob(os.path.join(path, '**/*.hns'), recursive=True): 454 | name = os.path.splitext(os.path.basename(filename))[0] 455 | if name != "None": 456 | res[name] = filename 457 | return res 458 | 459 | def find_closest_first(keyset, target): 460 | for keys in keyset: 461 | if target == keys.rsplit('(', 1)[0]: 462 | return keys 463 | return None 464 | 465 | 466 | 467 | def load_hypernetwork(filename): 468 | hypernetwork = None 469 | path = shared.hypernetworks.get(filename, None) 470 | if path is None: 471 | filename = find_closest_first(shared.hypernetworks.keys(), filename) 472 | path = shared.hypernetworks.get(filename, None) 473 | print(path) 474 | # Prevent any file named "None.pt" from being loaded. 475 | if path is not None and filename != "None": 476 | print(f"Loading hypernetwork {filename}") 477 | if path.endswith(".pt"): 478 | try: 479 | hypernetwork = Hypernetwork() 480 | hypernetwork.load(path) 481 | if hasattr(shared, 'loaded_hypernetwork'): 482 | shared.loaded_hypernetwork = hypernetwork 483 | else: 484 | return hypernetwork 485 | 486 | except Exception: 487 | print(f"Error loading hypernetwork {path}", file=sys.stderr) 488 | print(traceback.format_exc(), file=sys.stderr) 489 | elif path.endswith(".hns"): 490 | # Load Hypernetwork processing 491 | try: 492 | from .hypernetworks import load as load_hns 493 | if hasattr(shared, 'loaded_hypernetwork'): 494 | shared.loaded_hypernetwork = load_hns(path) 495 | else: 496 | hypernetwork = load_hns(path) 497 | print(f"Loaded Hypernetwork Structure {path}") 498 | return hypernetwork 499 | except Exception: 500 | print(f"Error loading hypernetwork processing file {path}", file=sys.stderr) 501 | print(traceback.format_exc(), file=sys.stderr) 502 | else: 503 | print(f"Tried to load unknown file extension: {filename}") 504 | else: 505 | if hasattr(shared, 'loaded_hypernetwork'): 506 | if shared.loaded_hypernetwork is not None: 507 | print(f"Unloading hypernetwork") 508 | shared.loaded_hypernetwork = None 509 | return hypernetwork 510 | 511 | 512 | def apply_hypernetwork(hypernetwork, context, layer=None): 513 | if hypernetwork is None: 514 | return context, context 515 | if isinstance(hypernetwork, Hypernetwork): 516 | hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None) 517 | if hypernetwork_layers is None: 518 | return context, context 519 | if layer is not None: 520 | layer.hyper_k = hypernetwork_layers[0] 521 | layer.hyper_v = hypernetwork_layers[1] 522 | 523 | context_k = devices.cond_cast_unet(hypernetwork_layers[0](devices.cond_cast_float(context))) 524 | context_v = devices.cond_cast_unet(hypernetwork_layers[1](devices.cond_cast_float(context))) 525 | return context_k, context_v 526 | context_k, context_v = hypernetwork(context, layer=layer) 527 | return context_k, context_v 528 | 529 | 530 | def apply_single_hypernetwork(hypernetwork, context_k, context_v, layer=None): 531 | if hypernetwork is None: 532 | return context_k, context_v 533 | if isinstance(hypernetwork, Hypernetwork): 534 | hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context_k.shape[2], None) 535 | if hypernetwork_layers is None: 536 | return context_k, context_v 537 | if layer is not None: 538 | layer.hyper_k = 
hypernetwork_layers[0] 539 | layer.hyper_v = hypernetwork_layers[1] 540 | 541 | context_k = devices.cond_cast_unet(hypernetwork_layers[0](devices.cond_cast_float(context_k))) 542 | context_v = devices.cond_cast_unet(hypernetwork_layers[1](devices.cond_cast_float(context_v))) 543 | return context_k, context_v 544 | context_k, context_v = hypernetwork(context_k, context_v, layer=layer) 545 | return context_k, context_v 546 | 547 | 548 | def apply_strength(value=None): 549 | HypernetworkModule.multiplier = value if value is not None else shared.opts.sd_hypernetwork_strength 550 | 551 | 552 | def apply_hypernetwork_strength(p, x, xs): 553 | apply_strength(x) 554 | 555 | 556 | modules.hypernetworks.hypernetwork.list_hypernetworks = list_hypernetworks 557 | modules.hypernetworks.hypernetwork.load_hypernetwork = load_hypernetwork 558 | if hasattr(modules.hypernetworks.hypernetwork, 'apply_hypernetwork'): 559 | modules.hypernetworks.hypernetwork.apply_hypernetwork = apply_hypernetwork 560 | else: 561 | modules.hypernetworks.hypernetwork.apply_single_hypernetwork = apply_single_hypernetwork 562 | if hasattr(modules.hypernetworks.hypernetwork, 'apply_strength'): 563 | modules.hypernetworks.hypernetwork.apply_strength = apply_strength 564 | modules.hypernetworks.hypernetwork.Hypernetwork = Hypernetwork 565 | modules.hypernetworks.hypernetwork.HypernetworkModule = HypernetworkModule 566 | try: 567 | import scripts.xy_grid 568 | if hasattr(scripts.xy_grid, 'apply_hypernetwork_strength'): 569 | scripts.xy_grid.apply_hypernetwork_strength = apply_hypernetwork_strength 570 | except (ModuleNotFoundError, ImportError): 571 | pass 572 | 573 | -------------------------------------------------------------------------------- /patches/external_pr/hypernetwork.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import gc 3 | import html 4 | import json 5 | import os 6 | import sys 7 | import time 8 | import traceback 9 | from collections import defaultdict, deque 10 | 11 | import torch 12 | import tqdm 13 | 14 | from modules import shared, sd_models, devices, processing, sd_samplers 15 | from modules.hypernetworks.hypernetwork import optimizer_dict, stack_conds, save_hypernetwork 16 | from modules.textual_inversion import textual_inversion 17 | from modules.textual_inversion.learn_schedule import LearnRateScheduler 18 | from ..tbutils import tensorboard_setup, tensorboard_add, tensorboard_add_image, tensorboard_log_hyperparameter 19 | from .textual_inversion import validate_train_inputs, write_loss 20 | from ..hypernetwork import Hypernetwork, load_hypernetwork 21 | from . 
import sd_hijack_checkpoint 22 | from ..hnutil import optim_to 23 | from ..ui import create_hypernetwork_load 24 | from ..scheduler import CosineAnnealingWarmUpRestarts 25 | from .dataset import PersonalizedBase, PersonalizedDataLoader 26 | from ..ddpm_hijack import set_scheduler 27 | 28 | 29 | def get_lr_from_optimizer(optimizer: torch.optim.Optimizer): 30 | return optimizer.param_groups[0].get('d', 1) * optimizer.param_groups[0].get('lr', 1) 31 | 32 | 33 | def set_accessible(obj): 34 | setattr(shared, 'accessible_hypernetwork', obj) 35 | if hasattr(shared, 'loaded_hypernetworks'): 36 | shared.loaded_hypernetworks.clear() 37 | shared.loaded_hypernetworks = [obj,] 38 | 39 | 40 | def remove_accessible(): 41 | delattr(shared, 'accessible_hypernetwork') 42 | if hasattr(shared, 'loaded_hypernetworks'): 43 | shared.loaded_hypernetworks.clear() 44 | 45 | def get_training_option(filename): 46 | print(filename) 47 | if os.path.exists(os.path.join(shared.cmd_opts.hypernetwork_dir, filename)) and os.path.isfile( 48 | os.path.join(shared.cmd_opts.hypernetwork_dir, filename)): 49 | filename = os.path.join(shared.cmd_opts.hypernetwork_dir, filename) 50 | elif os.path.exists(filename) and os.path.isfile(filename): 51 | filename = filename 52 | elif os.path.exists(os.path.join(shared.cmd_opts.hypernetwork_dir, filename + '.json')) and os.path.isfile( 53 | os.path.join(shared.cmd_opts.hypernetwork_dir, filename + '.json')): 54 | filename = os.path.join(shared.cmd_opts.hypernetwork_dir, filename + '.json') 55 | else: 56 | return False 57 | print(f"Loading setting from {filename}!") 58 | with open(filename, 'r') as file: 59 | obj = json.load(file) 60 | return obj 61 | 62 | 63 | def prepare_training_hypernetwork(hypernetwork_name, learn_rate=0.1, use_adamw_parameter=False, use_dadaptation=False, dadapt_growth_factor=-1, **adamW_kwarg_dict): 64 | """ returns hypernetwork object binded with optimizer""" 65 | hypernetwork = load_hypernetwork(hypernetwork_name) 66 | hypernetwork.to(devices.device) 67 | assert hypernetwork is not None, f"Cannot load {hypernetwork_name}!" 68 | if not isinstance(hypernetwork, Hypernetwork): 69 | raise RuntimeError("Cannot perform training for Hypernetwork structure pipeline!") 70 | set_accessible(hypernetwork) 71 | weights = hypernetwork.weights(True) 72 | hypernetwork_name = hypernetwork_name.rsplit('(', 1)[0] 73 | filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt') 74 | # Here we use optimizer from saved HN, or we can specify as UI option. 75 | if hypernetwork.optimizer_name == 'DAdaptAdamW': 76 | use_dadaptation = True 77 | optimizer = None 78 | optimizer_name = 'AdamW' 79 | # Here we use optimizer from saved HN, or we can specify as UI option. 
80 | if hypernetwork.optimizer_name in optimizer_dict: 81 | if use_adamw_parameter: 82 | if hypernetwork.optimizer_name != 'AdamW' and hypernetwork.optimizer_name != 'DAdaptAdamW': 83 | raise NotImplementedError(f"Cannot use adamW paramters for optimizer {hypernetwork.optimizer_name}!") 84 | if use_dadaptation: 85 | from .dadapt_test.install import get_dadapt_adam 86 | optim_class = get_dadapt_adam(hypernetwork.optimizer_name) 87 | if optim_class != torch.optim.AdamW: 88 | print('Optimizer class is ' + str(optim_class)) 89 | optimizer = optim_class(params=weights, lr=learn_rate, decouple=True, growth_rate = float('inf') if dadapt_growth_factor < 0 else dadapt_growth_factor, **adamW_kwarg_dict) 90 | hypernetwork.optimizer_name = 'DAdaptAdamW' 91 | else: 92 | optimizer = torch.optim.AdamW(params=weights, lr=learn_rate, **adamW_kwarg_dict) 93 | else: 94 | optimizer = torch.optim.AdamW(params=weights, lr=learn_rate, **adamW_kwarg_dict) 95 | else: 96 | optimizer = optimizer_dict[hypernetwork.optimizer_name](params=weights, lr=learn_rate) 97 | optimizer_name = hypernetwork.optimizer_name 98 | else: 99 | print(f"Optimizer type {hypernetwork.optimizer_name} is not defined!") 100 | if use_dadaptation: 101 | from .dadapt_test.install import get_dadapt_adam 102 | optim_class = get_dadapt_adam(hypernetwork.optimizer_name) 103 | if optim_class != torch.optim.AdamW: 104 | optimizer = optim_class(params=weights, lr=learn_rate, decouple=True, growth_rate = float('inf') if dadapt_growth_factor < 0 else dadapt_growth_factor, **adamW_kwarg_dict) 105 | optimizer_name = 'DAdaptAdamW' 106 | hypernetwork.optimizer_name = 'DAdaptAdamW' 107 | if optimizer is None: 108 | optimizer = torch.optim.AdamW(params=weights, lr=learn_rate, **adamW_kwarg_dict) 109 | optimizer_name = 'AdamW' 110 | if hypernetwork.optimizer_state_dict: # This line must be changed if Optimizer type can be different from saved optimizer. 111 | try: 112 | optimizer.load_state_dict(hypernetwork.optimizer_state_dict) 113 | optim_to(optimizer, devices.device) 114 | print('Loaded optimizer successfully!') 115 | except RuntimeError as e: 116 | print("Cannot resume from saved optimizer!") 117 | print(e) 118 | 119 | return hypernetwork, optimizer, weights, optimizer_name 120 | 121 | def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, log_directory, 122 | training_width, training_height, steps, shuffle_tags, tag_drop_out, latent_sampling_method, 123 | create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, 124 | preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, 125 | preview_width, preview_height, 126 | use_beta_scheduler=False, beta_repeat_epoch=4000, epoch_mult=1, warmup=10, min_lr=1e-7, 127 | gamma_rate=1, save_when_converge=False, create_when_converge=False, 128 | move_optimizer=True, 129 | use_adamw_parameter=False, adamw_weight_decay=0.01, adamw_beta_1=0.9, adamw_beta_2=0.99, 130 | adamw_eps=1e-8, 131 | use_grad_opts=False, gradient_clip_opt='None', optional_gradient_clip_value=1e01, 132 | optional_gradient_norm_type=2, latent_sampling_std=-1, 133 | noise_training_scheduler_enabled=False, noise_training_scheduler_repeat=False, noise_training_scheduler_cycle=128, 134 | load_training_options='', loss_opt='loss_simple', use_dadaptation=False, dadapt_growth_factor=-1, use_weight=False 135 | ): 136 | # images allows training previews to have infotext. Importing it at the top causes a circular import problem. 
137 | from modules import images 138 | if load_training_options != '': 139 | dump: dict = get_training_option(load_training_options) 140 | if dump and dump is not None: 141 | print(f"Loading from {load_training_options}") 142 | learn_rate = dump['learn_rate'] 143 | batch_size = dump['batch_size'] 144 | gradient_step = dump['gradient_step'] 145 | training_width = dump['training_width'] 146 | training_height = dump['training_height'] 147 | steps = dump['steps'] 148 | shuffle_tags = dump['shuffle_tags'] 149 | tag_drop_out = dump['tag_drop_out'] 150 | save_when_converge = dump['save_when_converge'] 151 | create_when_converge = dump['create_when_converge'] 152 | latent_sampling_method = dump['latent_sampling_method'] 153 | template_file = dump['template_file'] 154 | use_beta_scheduler = dump['use_beta_scheduler'] 155 | beta_repeat_epoch = dump['beta_repeat_epoch'] 156 | epoch_mult = dump['epoch_mult'] 157 | warmup = dump['warmup'] 158 | min_lr = dump['min_lr'] 159 | gamma_rate = dump['gamma_rate'] 160 | use_adamw_parameter = dump['use_beta_adamW_checkbox'] 161 | adamw_weight_decay = dump['adamw_weight_decay'] 162 | adamw_beta_1 = dump['adamw_beta_1'] 163 | adamw_beta_2 = dump['adamw_beta_2'] 164 | adamw_eps = dump['adamw_eps'] 165 | use_grad_opts = dump['show_gradient_clip_checkbox'] 166 | gradient_clip_opt = dump['gradient_clip_opt'] 167 | optional_gradient_clip_value = dump['optional_gradient_clip_value'] 168 | optional_gradient_norm_type = dump['optional_gradient_norm_type'] 169 | latent_sampling_std = dump.get('latent_sampling_std', -1) 170 | noise_training_scheduler_enabled = dump.get('noise_training_scheduler_enabled', False) 171 | noise_training_scheduler_repeat = dump.get('noise_training_scheduler_repeat', False) 172 | noise_training_scheduler_cycle = dump.get('noise_training_scheduler_cycle', 128) 173 | loss_opt = dump.get('loss_opt', 'loss_simple') 174 | use_dadaptation = dump.get('use_dadaptation', False) 175 | dadapt_growth_factor = dump.get('dadapt_growth_factor', -1) 176 | use_weight = dump.get('use_weight', False) 177 | try: 178 | if use_adamw_parameter: 179 | adamw_weight_decay, adamw_beta_1, adamw_beta_2, adamw_eps = [float(x) for x in 180 | [adamw_weight_decay, adamw_beta_1, 181 | adamw_beta_2, adamw_eps]] 182 | assert 0 <= adamw_weight_decay, "Weight decay paramter should be larger or equal than zero!" 183 | assert (all(0 <= x <= 1 for x in [adamw_beta_1, adamw_beta_2, 184 | adamw_eps])), "Cannot use negative or >1 number for adamW parameters!" 185 | adamW_kwarg_dict = { 186 | 'weight_decay': adamw_weight_decay, 187 | 'betas': (adamw_beta_1, adamw_beta_2), 188 | 'eps': adamw_eps 189 | } 190 | print('Using custom AdamW parameters') 191 | else: 192 | adamW_kwarg_dict = { 193 | 'weight_decay': 0.01, 194 | 'betas': (0.9, 0.99), 195 | 'eps': 1e-8 196 | } 197 | if use_beta_scheduler: 198 | print("Using Beta Scheduler") 199 | beta_repeat_epoch = int(float(beta_repeat_epoch)) 200 | assert beta_repeat_epoch > 0, f"Cannot use too small cycle {beta_repeat_epoch}!" 201 | min_lr = float(min_lr) 202 | assert min_lr < 1, f"Cannot use minimum lr with {min_lr}!" 203 | gamma_rate = float(gamma_rate) 204 | print(f"Using learn rate decay(per cycle) of {gamma_rate}") 205 | assert 0 <= gamma_rate <= 1, f"Cannot use gamma rate with {gamma_rate}!" 206 | epoch_mult = float(epoch_mult) 207 | assert 1 <= epoch_mult, "Cannot use epoch multiplier smaller than 1!" 208 | warmup = int(float(warmup)) 209 | assert warmup >= 1, "Warmup epoch should be larger than 0!" 
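# Sketch of how these validated values are used, assuming the scheduler in
# patches/scheduler.py follows the usual warmup-then-cosine-restart behaviour: they are
# handed to CosineAnnealingWarmUpRestarts further below as first_cycle_steps=beta_repeat_epoch,
# cycle_mult=epoch_mult, warmup_steps=warmup, min_lr=min_lr and gamma=gamma_rate. With the
# signature defaults (4000, 1, 10, 1e-7, 1), the learn rate climbs from min_lr during the
# warmup portion of each cycle, cosine-decays back to min_lr over the rest of the 4000-step
# cycle, and restarts with no per-cycle decay because gamma_rate is 1.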
210 | print(f"Save when converges : {save_when_converge}") 211 | print(f"Generate image when converges : {create_when_converge}") 212 | else: 213 | beta_repeat_epoch = 4000 214 | epoch_mult = 1 215 | warmup = 10 216 | min_lr = 1e-7 217 | gamma_rate = 1 218 | save_when_converge = False 219 | create_when_converge = False 220 | except ValueError as e: 221 | raise RuntimeError("Cannot use advanced LR scheduler settings! "+ str(e)) 222 | if noise_training_scheduler_enabled: 223 | set_scheduler(noise_training_scheduler_cycle, noise_training_scheduler_repeat, True) 224 | print(f"Noise training scheduler is now ready for {noise_training_scheduler_cycle}, {noise_training_scheduler_repeat}!") 225 | else: 226 | set_scheduler(-1, False, False) 227 | if use_grad_opts and gradient_clip_opt != "None": 228 | try: 229 | optional_gradient_clip_value = float(optional_gradient_clip_value) 230 | except ValueError: 231 | raise RuntimeError(f"Cannot convert invalid gradient clipping value {optional_gradient_clip_value})") 232 | if gradient_clip_opt == "Norm": 233 | try: 234 | grad_norm = int(float(optional_gradient_norm_type)) 235 | except ValueError: 236 | raise RuntimeError(f"Cannot convert invalid gradient norm type {optional_gradient_norm_type})") 237 | assert grad_norm >= 0, f"P-norm cannot be calculated from negative number {grad_norm}" 238 | print( 239 | f"Using gradient clipping by Norm, norm type {optional_gradient_norm_type}, norm limit {optional_gradient_clip_value}") 240 | 241 | def gradient_clipping(arg1): 242 | torch.nn.utils.clip_grad_norm_(arg1, optional_gradient_clip_value, optional_gradient_norm_type) 243 | return 244 | else: 245 | print(f"Using gradient clipping by Value, limit {optional_gradient_clip_value}") 246 | 247 | def gradient_clipping(arg1): 248 | torch.nn.utils.clip_grad_value_(arg1, optional_gradient_clip_value) 249 | return 250 | else: 251 | def gradient_clipping(arg1): 252 | return 253 | save_hypernetwork_every = save_hypernetwork_every or 0 254 | create_image_every = create_image_every or 0 255 | if not os.path.isfile(template_file): 256 | template_file = textual_inversion.textual_inversion_templates.get(template_file, None) 257 | if template_file is not None: 258 | template_file = template_file.path 259 | else: 260 | raise AssertionError(f"Cannot find {template_file}!") 261 | validate_train_inputs(hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, template_file, steps, save_hypernetwork_every, create_image_every, log_directory, name="hypernetwork") 262 | shared.state.job = "train-hypernetwork" 263 | shared.state.textinfo = "Initializing hypernetwork training..." 
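# A minimal sketch of the JSON file that load_training_options / get_training_option()
# expects (key names are the ones read via dump[...] near the top of this function; the
# values below are placeholders, not recommendations):
# {
#   "learn_rate": "0.00005", "batch_size": 2, "gradient_step": 4,
#   "training_width": 512, "training_height": 512, "steps": 20000,
#   "shuffle_tags": true, "tag_drop_out": 0.1,
#   "save_when_converge": true, "create_when_converge": true,
#   "latent_sampling_method": "deterministic", "template_file": "hypernetwork.txt",
#   "use_beta_scheduler": true, "beta_repeat_epoch": 4000, "epoch_mult": 1,
#   "warmup": 10, "min_lr": 1e-7, "gamma_rate": 1,
#   "use_beta_adamW_checkbox": false, "adamw_weight_decay": 0.01,
#   "adamw_beta_1": 0.9, "adamw_beta_2": 0.99, "adamw_eps": 1e-8,
#   "show_gradient_clip_checkbox": false, "gradient_clip_opt": "None",
#   "optional_gradient_clip_value": 1.0, "optional_gradient_norm_type": 2
# }
# Optional keys (latent_sampling_std, noise_training_scheduler_enabled/repeat/cycle,
# loss_opt, use_dadaptation, dadapt_growth_factor, use_weight) fall back to the
# dump.get() defaults when omitted.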
264 | shared.state.job_count = steps 265 | tmp_scheduler = LearnRateScheduler(learn_rate, steps, 0) 266 | hypernetwork, optimizer, weights, optimizer_name = prepare_training_hypernetwork(hypernetwork_name, tmp_scheduler.learn_rate, use_adamw_parameter, use_dadaptation,dadapt_growth_factor, **adamW_kwarg_dict) 267 | 268 | hypernetwork_name = hypernetwork_name.rsplit('(', 1)[0] 269 | filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt') 270 | 271 | log_directory = os.path.join(log_directory, datetime.datetime.now().strftime("%Y-%m-%d"), hypernetwork_name) 272 | unload = shared.opts.unload_models_when_training 273 | 274 | if save_hypernetwork_every > 0 or save_when_converge: 275 | hypernetwork_dir = os.path.join(log_directory, "hypernetworks") 276 | os.makedirs(hypernetwork_dir, exist_ok=True) 277 | else: 278 | hypernetwork_dir = None 279 | 280 | if create_image_every > 0 or create_when_converge: 281 | images_dir = os.path.join(log_directory, "images") 282 | os.makedirs(images_dir, exist_ok=True) 283 | else: 284 | images_dir = None 285 | 286 | checkpoint = sd_models.select_checkpoint() 287 | 288 | initial_step = hypernetwork.step or 0 289 | if initial_step >= steps: 290 | shared.state.textinfo = f"Model has already been trained beyond specified max steps" 291 | return hypernetwork, filename 292 | 293 | scheduler = LearnRateScheduler(learn_rate, steps, initial_step) 294 | if shared.opts.training_enable_tensorboard: 295 | print("Tensorboard logging enabled") 296 | tensorboard_writer = tensorboard_setup(log_directory) 297 | else: 298 | tensorboard_writer = None 299 | # dataset loading may take a while, so input validations and early returns should be done before this 300 | shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..." 
301 | detach_grad = shared.opts.disable_ema # test code that removes EMA 302 | if detach_grad: 303 | print("Disabling training for staged models!") 304 | shared.sd_model.cond_stage_model.requires_grad_(False) 305 | shared.sd_model.first_stage_model.requires_grad_(False) 306 | torch.cuda.empty_cache() 307 | pin_memory = shared.opts.pin_memory 308 | 309 | ds = PersonalizedBase(data_root=data_root, width=training_width, 310 | height=training_height, 311 | repeats=shared.opts.training_image_repeats_per_epoch, 312 | placeholder_token=hypernetwork_name, model=shared.sd_model, 313 | cond_model=shared.sd_model.cond_stage_model, 314 | device=devices.device, template_file=template_file, 315 | include_cond=True, batch_size=batch_size, 316 | gradient_step=gradient_step, shuffle_tags=shuffle_tags, 317 | tag_drop_out=tag_drop_out, 318 | latent_sampling_method=latent_sampling_method, 319 | latent_sampling_std=latent_sampling_std, 320 | use_weight=use_weight) 321 | 322 | latent_sampling_method = ds.latent_sampling_method 323 | 324 | dl = PersonalizedDataLoader(ds, latent_sampling_method=latent_sampling_method, 325 | batch_size=ds.batch_size, pin_memory=pin_memory) 326 | old_parallel_processing_allowed = shared.parallel_processing_allowed 327 | 328 | if unload: 329 | shared.parallel_processing_allowed = False 330 | shared.sd_model.cond_stage_model.to(devices.cpu) 331 | shared.sd_model.first_stage_model.to(devices.cpu) 332 | 333 | if use_beta_scheduler: 334 | scheduler_beta = CosineAnnealingWarmUpRestarts(optimizer=optimizer, first_cycle_steps=beta_repeat_epoch, 335 | cycle_mult=epoch_mult, max_lr=scheduler.learn_rate, 336 | warmup_steps=warmup, min_lr=min_lr, gamma=gamma_rate) 337 | scheduler_beta.last_epoch = hypernetwork.step - 1 338 | else: 339 | scheduler_beta = None 340 | for pg in optimizer.param_groups: 341 | pg['lr'] = scheduler.learn_rate 342 | scaler = torch.cuda.amp.GradScaler() 343 | 344 | batch_size = ds.batch_size 345 | gradient_step = ds.gradient_step 346 | # n steps = batch_size * gradient_step * n image processed 347 | steps_per_epoch = len(ds) // batch_size // gradient_step 348 | max_steps_per_epoch = len(ds) // batch_size - (len(ds) // batch_size) % gradient_step 349 | loss_step = 0 350 | _loss_step = 0 # internal 351 | # size = len(ds.indexes) 352 | loss_dict = defaultdict(lambda: deque(maxlen=1024)) 353 | # losses = torch.zeros((size,)) 354 | # previous_mean_losses = [0] 355 | # previous_mean_loss = 0 356 | # print("Mean loss of {} elements".format(size)) 357 | 358 | steps_without_grad = 0 359 | 360 | last_saved_file = "" 361 | last_saved_image = "" 362 | forced_filename = "" 363 | if hasattr(sd_hijack_checkpoint, 'add'): 364 | sd_hijack_checkpoint.add() 365 | pbar = tqdm.tqdm(total=steps - initial_step) 366 | try: 367 | for i in range((steps - initial_step) * gradient_step): 368 | if scheduler.finished or hypernetwork.step > steps: 369 | break 370 | if shared.state.interrupted: 371 | break 372 | for j, batch in enumerate(dl): 373 | # works as a drop_last=True for gradient accumulation 374 | if j == max_steps_per_epoch: 375 | break 376 | if use_beta_scheduler: 377 | scheduler_beta.step(hypernetwork.step) 378 | else: 379 | scheduler.apply(optimizer, hypernetwork.step) 380 | if scheduler.finished: 381 | break 382 | if shared.state.interrupted: 383 | break 384 | 385 | with torch.autocast("cuda"): 386 | x = batch.latent_sample.to(devices.device, non_blocking=pin_memory) 387 | if use_weight: 388 | w = batch.weight.to(devices.device, non_blocking=pin_memory) 389 | if tag_drop_out != 0 or 
shuffle_tags: 390 | shared.sd_model.cond_stage_model.to(devices.device) 391 | c = shared.sd_model.cond_stage_model(batch.cond_text).to(devices.device, 392 | non_blocking=pin_memory) 393 | shared.sd_model.cond_stage_model.to(devices.cpu) 394 | else: 395 | c = stack_conds(batch.cond).to(devices.device, non_blocking=pin_memory) 396 | if use_weight: 397 | loss = shared.sd_model.weighted_forward(x, c, w)[0] 398 | else: 399 | _, losses = shared.sd_model.forward(x, c) 400 | loss = losses['val/' + loss_opt] 401 | for filenames in batch.filename: 402 | loss_dict[filenames].append(loss.detach().item()) 403 | loss /= gradient_step 404 | assert not torch.isnan(loss), "Loss is NaN" 405 | del x 406 | del c 407 | 408 | _loss_step += loss.item() 409 | scaler.scale(loss).backward() 410 | batch.latent_sample.to(devices.cpu) 411 | # go back until we reach gradient accumulation steps 412 | if (j + 1) % gradient_step != 0: 413 | continue 414 | gradient_clipping(weights) 415 | # print(f"grad:{weights[0].grad.detach().cpu().abs().mean().item():.7f}") 416 | # scaler.unscale_(optimizer) 417 | # print(f"grad:{weights[0].grad.detach().cpu().abs().mean().item():.15f}") 418 | # torch.nn.utils.clip_grad_norm_(weights, max_norm=1.0) 419 | # print(f"grad:{weights[0].grad.detach().cpu().abs().mean().item():.15f}") 420 | try: 421 | scaler.step(optimizer) 422 | except AssertionError: 423 | optimizer.param_groups[0]['capturable'] = True 424 | scaler.step(optimizer) 425 | scaler.update() 426 | hypernetwork.step += 1 427 | pbar.update() 428 | optimizer.zero_grad(set_to_none=True) 429 | loss_step = _loss_step 430 | _loss_step = 0 431 | 432 | steps_done = hypernetwork.step + 1 433 | 434 | epoch_num = hypernetwork.step // steps_per_epoch 435 | epoch_step = hypernetwork.step % steps_per_epoch 436 | 437 | description = f"Training hypernetwork [Epoch {epoch_num}: {epoch_step + 1}/{steps_per_epoch}]loss: {loss_step:.7f}" 438 | pbar.set_description(description) 439 | if hypernetwork_dir is not None and ( 440 | (use_beta_scheduler and scheduler_beta.is_EOC(hypernetwork.step) and save_when_converge) or ( 441 | save_hypernetwork_every > 0 and steps_done % save_hypernetwork_every == 0)): 442 | # Before saving, change name to match current checkpoint. 443 | hypernetwork_name_every = f'{hypernetwork_name}-{steps_done}' 444 | last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork_name_every}.pt') 445 | hypernetwork.optimizer_name = optimizer_name 446 | if shared.opts.save_optimizer_state: 447 | hypernetwork.optimizer_state_dict = optimizer.state_dict() 448 | save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, last_saved_file) 449 | hypernetwork.optimizer_state_dict = None # dereference it after saving, to save memory. 
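# Note: with the patched Hypernetwork class, save() writes the weights to the .pt file and,
# when shared.opts.save_optimizer_state is enabled and an optimizer state dict is attached
# (as done above), also a companion '<name>.pt.optim' file tagged with the network's short
# hash, which load() checks before restoring the optimizer state on the next run.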
450 | 451 | write_loss(log_directory, "hypernetwork_loss.csv", hypernetwork.step, steps_per_epoch, 452 | { 453 | "loss": f"{loss_step:.7f}", 454 | "learn_rate": get_lr_from_optimizer(optimizer) 455 | }) 456 | if shared.opts.training_enable_tensorboard: 457 | epoch_num = hypernetwork.step // len(ds) 458 | epoch_step = hypernetwork.step - (epoch_num * len(ds)) + 1 459 | mean_loss = sum(sum(x) for x in loss_dict.values()) / sum(len(x) for x in loss_dict.values()) 460 | tensorboard_add(tensorboard_writer, loss=mean_loss, global_step=hypernetwork.step, step=epoch_step, 461 | learn_rate=scheduler.learn_rate if not use_beta_scheduler else 462 | get_lr_from_optimizer(optimizer), epoch_num=epoch_num) 463 | if images_dir is not None and ( 464 | use_beta_scheduler and scheduler_beta.is_EOC(hypernetwork.step) and create_when_converge) or ( 465 | create_image_every > 0 and steps_done % create_image_every == 0): 466 | set_scheduler(-1, False, False) 467 | forced_filename = f'{hypernetwork_name}-{steps_done}' 468 | last_saved_image = os.path.join(images_dir, forced_filename) 469 | rng_state = torch.get_rng_state() 470 | cuda_rng_state = None 471 | if torch.cuda.is_available(): 472 | cuda_rng_state = torch.cuda.get_rng_state_all() 473 | hypernetwork.eval() 474 | if move_optimizer: 475 | optim_to(optimizer, devices.cpu) 476 | gc.collect() 477 | shared.sd_model.cond_stage_model.to(devices.device) 478 | shared.sd_model.first_stage_model.to(devices.device) 479 | 480 | p = processing.StableDiffusionProcessingTxt2Img( 481 | sd_model=shared.sd_model, 482 | do_not_save_grid=True, 483 | do_not_save_samples=True, 484 | ) 485 | if hasattr(p, 'disable_extra_networks'): 486 | p.disable_extra_networks = True 487 | is_patched = True 488 | else: 489 | is_patched = False 490 | if preview_from_txt2img: 491 | p.prompt = preview_prompt + (hypernetwork.extra_name() if not is_patched else "") 492 | print(p.prompt) 493 | p.negative_prompt = preview_negative_prompt 494 | p.steps = preview_steps 495 | p.sampler_name = sd_samplers.samplers[preview_sampler_index].name 496 | p.cfg_scale = preview_cfg_scale 497 | p.seed = preview_seed 498 | p.width = preview_width 499 | p.height = preview_height 500 | else: 501 | p.prompt = batch.cond_text[0] + (hypernetwork.extra_name() if not is_patched else "") 502 | p.steps = 20 503 | p.width = training_width 504 | p.height = training_height 505 | 506 | preview_text = p.prompt 507 | 508 | processed = processing.process_images(p) 509 | image = processed.images[0] if len(processed.images) > 0 else None 510 | if shared.opts.training_enable_tensorboard and shared.opts.training_tensorboard_save_images: 511 | tensorboard_add_image(tensorboard_writer, f"Validation at epoch {epoch_num}", image, 512 | hypernetwork.step) 513 | 514 | if unload: 515 | shared.sd_model.cond_stage_model.to(devices.cpu) 516 | shared.sd_model.first_stage_model.to(devices.cpu) 517 | torch.set_rng_state(rng_state) 518 | if torch.cuda.is_available(): 519 | torch.cuda.set_rng_state_all(cuda_rng_state) 520 | hypernetwork.train() 521 | if move_optimizer: 522 | optim_to(optimizer, devices.device) 523 | if noise_training_scheduler_enabled: 524 | set_scheduler(noise_training_scheduler_cycle, noise_training_scheduler_repeat, True) 525 | if image is not None: 526 | if hasattr(shared.state, 'assign_current_image'): 527 | shared.state.assign_current_image(image) 528 | else: 529 | shared.state.current_image = image 530 | last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, 531 | 
shared.opts.samples_format, 532 | processed.infotexts[0], p=p, 533 | forced_filename=forced_filename, 534 | save_to_dirs=False) 535 | last_saved_image += f", prompt: {preview_text}" 536 | set_accessible(hypernetwork) 537 | 538 | shared.state.job_no = hypernetwork.step 539 | 540 | shared.state.textinfo = f""" 541 |
<p>
542 | Loss: {loss_step:.7f}<br/>
543 | Step: {steps_done}<br/>
544 | Last prompt: {html.escape(batch.cond_text[0])}<br/>
545 | Last saved hypernetwork: {html.escape(last_saved_file)}<br/>
546 | Last saved image: {html.escape(last_saved_image)}<br/>
547 | </p>
548 | """ 549 | except Exception: 550 | print(traceback.format_exc(), file=sys.stderr) 551 | finally: 552 | pbar.leave = False 553 | pbar.close() 554 | if hypernetwork is not None: 555 | hypernetwork.eval() 556 | shared.parallel_processing_allowed = old_parallel_processing_allowed 557 | if hasattr(sd_hijack_checkpoint, 'remove'): 558 | sd_hijack_checkpoint.remove() 559 | set_scheduler(-1, False, False) 560 | remove_accessible() 561 | gc.collect() 562 | torch.cuda.empty_cache() 563 | filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt') 564 | hypernetwork.optimizer_name = optimizer_name 565 | if shared.opts.save_optimizer_state: 566 | hypernetwork.optimizer_state_dict = optimizer.state_dict() 567 | save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, filename) 568 | del optimizer 569 | hypernetwork.optimizer_state_dict = None # dereference it after saving, to save memory. 570 | shared.sd_model.cond_stage_model.to(devices.device) 571 | shared.sd_model.first_stage_model.to(devices.device) 572 | 573 | return hypernetwork, filename 574 | 575 | 576 | def internal_clean_training(hypernetwork_name, data_root, log_directory, 577 | create_image_every, save_hypernetwork_every, 578 | preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, 579 | preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height, 580 | move_optimizer=True, 581 | load_hypernetworks_option='', load_training_options='', manual_dataset_seed=-1, 582 | setting_tuple=None): 583 | # images allows training previews to have infotext. Importing it at the top causes a circular import problem. 584 | from modules import images 585 | base_hypernetwork_name = hypernetwork_name 586 | manual_seed = int(manual_dataset_seed) 587 | if setting_tuple is not None: 588 | setting_suffix = f"_{setting_tuple[0]}_{setting_tuple[1]}" 589 | else: 590 | setting_suffix = time.strftime('%Y%m%d%H%M%S') 591 | if load_hypernetworks_option != '': 592 | dump_hyper: dict = get_training_option(load_hypernetworks_option) 593 | hypernetwork_name = hypernetwork_name + setting_suffix 594 | enable_sizes = dump_hyper['enable_sizes'] 595 | overwrite_old = dump_hyper['overwrite_old'] 596 | layer_structure = dump_hyper['layer_structure'] 597 | activation_func = dump_hyper['activation_func'] 598 | weight_init = dump_hyper['weight_init'] 599 | add_layer_norm = dump_hyper['add_layer_norm'] 600 | use_dropout = dump_hyper['use_dropout'] 601 | dropout_structure = dump_hyper['dropout_structure'] 602 | optional_info = dump_hyper['optional_info'] 603 | weight_init_seed = dump_hyper['weight_init_seed'] 604 | normal_std = dump_hyper['normal_std'] 605 | skip_connection = dump_hyper['skip_connection'] 606 | hypernetwork = create_hypernetwork_load(hypernetwork_name, enable_sizes, overwrite_old, layer_structure, 607 | activation_func, weight_init, add_layer_norm, use_dropout, 608 | dropout_structure, optional_info, weight_init_seed, normal_std, 609 | skip_connection) 610 | else: 611 | hypernetwork = load_hypernetwork(hypernetwork_name) 612 | hypernetwork_name = hypernetwork_name.rsplit('(', 1)[0] + setting_suffix 613 | hypernetwork.save(os.path.join(shared.cmd_opts.hypernetwork_dir, f"{hypernetwork_name}.pt")) 614 | shared.reload_hypernetworks() 615 | hypernetwork = load_hypernetwork(hypernetwork_name) 616 | if load_training_options != '': 617 | dump: dict = get_training_option(load_training_options) 618 | if dump and dump is not None: 619 | learn_rate = dump['learn_rate'] 620 | batch_size = 
dump['batch_size'] 621 | gradient_step = dump['gradient_step'] 622 | training_width = dump['training_width'] 623 | training_height = dump['training_height'] 624 | steps = dump['steps'] 625 | shuffle_tags = dump['shuffle_tags'] 626 | tag_drop_out = dump['tag_drop_out'] 627 | save_when_converge = dump['save_when_converge'] 628 | create_when_converge = dump['create_when_converge'] 629 | latent_sampling_method = dump['latent_sampling_method'] 630 | template_file = dump['template_file'] 631 | use_beta_scheduler = dump['use_beta_scheduler'] 632 | beta_repeat_epoch = dump['beta_repeat_epoch'] 633 | epoch_mult = dump['epoch_mult'] 634 | warmup = dump['warmup'] 635 | min_lr = dump['min_lr'] 636 | gamma_rate = dump['gamma_rate'] 637 | use_adamw_parameter = dump['use_beta_adamW_checkbox'] 638 | adamw_weight_decay = dump['adamw_weight_decay'] 639 | adamw_beta_1 = dump['adamw_beta_1'] 640 | adamw_beta_2 = dump['adamw_beta_2'] 641 | adamw_eps = dump['adamw_eps'] 642 | use_grad_opts = dump['show_gradient_clip_checkbox'] 643 | gradient_clip_opt = dump['gradient_clip_opt'] 644 | optional_gradient_clip_value = dump['optional_gradient_clip_value'] 645 | optional_gradient_norm_type = dump['optional_gradient_norm_type'] 646 | latent_sampling_std = dump.get('latent_sampling_std', -1) 647 | noise_training_scheduler_enabled = dump.get('noise_training_scheduler_enabled', False) 648 | noise_training_scheduler_repeat = dump.get('noise_training_scheduler_repeat', False) 649 | noise_training_scheduler_cycle = dump.get('noise_training_scheduler_cycle', 128) 650 | loss_opt = dump.get('loss_opt', 'loss_simple') 651 | use_dadaptation = dump.get('use_dadaptation', False) 652 | dadapt_growth_factor = dump.get('dadapt_growth_factor', -1) 653 | use_weight = dump.get('use_weight', False) 654 | else: 655 | raise RuntimeError(f"Cannot load from {load_training_options}!") 656 | else: 657 | raise RuntimeError(f"Cannot load from {load_training_options}!") 658 | try: 659 | if use_adamw_parameter: 660 | adamw_weight_decay, adamw_beta_1, adamw_beta_2, adamw_eps = [float(x) for x in 661 | [adamw_weight_decay, adamw_beta_1, 662 | adamw_beta_2, adamw_eps]] 663 | assert 0 <= adamw_weight_decay, "Weight decay paramter should be larger or equal than zero!" 664 | assert (all(0 <= x <= 1 for x in [adamw_beta_1, adamw_beta_2, 665 | adamw_eps])), "Cannot use negative or >1 number for adamW parameters!" 666 | adamW_kwarg_dict = { 667 | 'weight_decay': adamw_weight_decay, 668 | 'betas': (adamw_beta_1, adamw_beta_2), 669 | 'eps': adamw_eps 670 | } 671 | print('Using custom AdamW parameters') 672 | else: 673 | adamW_kwarg_dict = { 674 | 'weight_decay': 0.01, 675 | 'betas': (0.9, 0.99), 676 | 'eps': 1e-8 677 | } 678 | if use_beta_scheduler: 679 | print("Using Beta Scheduler") 680 | beta_repeat_epoch = int(float(beta_repeat_epoch)) 681 | assert beta_repeat_epoch > 0, f"Cannot use too small cycle {beta_repeat_epoch}!" 682 | min_lr = float(min_lr) 683 | assert min_lr < 1, f"Cannot use minimum lr with {min_lr}!" 684 | gamma_rate = float(gamma_rate) 685 | print(f"Using learn rate decay(per cycle) of {gamma_rate}") 686 | assert 0 <= gamma_rate <= 1, f"Cannot use gamma rate with {gamma_rate}!" 687 | epoch_mult = float(epoch_mult) 688 | assert 1 <= epoch_mult, "Cannot use epoch multiplier smaller than 1!" 689 | warmup = int(float(warmup)) 690 | assert warmup >= 1, "Warmup epoch should be larger than 0!" 
691 | print(f"Save when converges : {save_when_converge}") 692 | print(f"Generate image when converges : {create_when_converge}") 693 | else: 694 | beta_repeat_epoch = 4000 695 | epoch_mult = 1 696 | warmup = 10 697 | min_lr = 1e-7 698 | gamma_rate = 1 699 | save_when_converge = False 700 | create_when_converge = False 701 | except ValueError: 702 | raise RuntimeError("Cannot use advanced LR scheduler settings!") 703 | if use_grad_opts and gradient_clip_opt != "None": 704 | try: 705 | optional_gradient_clip_value = float(optional_gradient_clip_value) 706 | except ValueError: 707 | raise RuntimeError(f"Cannot convert invalid gradient clipping value {optional_gradient_clip_value})") 708 | if gradient_clip_opt == "Norm": 709 | try: 710 | grad_norm = int(float(optional_gradient_norm_type)) 711 | except ValueError: 712 | raise RuntimeError(f"Cannot convert invalid gradient norm type {optional_gradient_norm_type})") 713 | assert grad_norm >= 0, f"P-norm cannot be calculated from negative number {grad_norm}" 714 | print( 715 | f"Using gradient clipping by Norm, norm type {optional_gradient_norm_type}, norm limit {optional_gradient_clip_value}") 716 | 717 | def gradient_clipping(arg1): 718 | torch.nn.utils.clip_grad_norm_(arg1, optional_gradient_clip_value, optional_gradient_norm_type) 719 | return 720 | else: 721 | print(f"Using gradient clipping by Value, limit {optional_gradient_clip_value}") 722 | 723 | def gradient_clipping(arg1): 724 | torch.nn.utils.clip_grad_value_(arg1, optional_gradient_clip_value) 725 | return 726 | else: 727 | def gradient_clipping(arg1): 728 | return 729 | if noise_training_scheduler_enabled: 730 | set_scheduler(noise_training_scheduler_cycle, noise_training_scheduler_repeat, True) 731 | print(f"Noise training scheduler is now ready for {noise_training_scheduler_cycle}, {noise_training_scheduler_repeat}!") 732 | else: 733 | set_scheduler(-1, False, False) 734 | save_hypernetwork_every = save_hypernetwork_every or 0 735 | create_image_every = create_image_every or 0 736 | if not os.path.isfile(template_file): 737 | template_file = textual_inversion.textual_inversion_templates.get(template_file, None) 738 | if template_file is not None: 739 | template_file = template_file.path 740 | else: 741 | raise AssertionError(f"Cannot find {template_file}!") 742 | validate_train_inputs(hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, template_file, steps, save_hypernetwork_every, create_image_every, log_directory, name="hypernetwork") 743 | hypernetwork.to(devices.device) 744 | assert hypernetwork is not None, f"Cannot load {hypernetwork_name}!" 745 | if not isinstance(hypernetwork, Hypernetwork): 746 | raise RuntimeError("Cannot perform training for Hypernetwork structure pipeline!") 747 | set_accessible(hypernetwork) 748 | shared.state.job = "train-hypernetwork" 749 | shared.state.textinfo = "Initializing hypernetwork training..." 
750 | shared.state.job_count = steps 751 | 752 | hypernetwork_name = hypernetwork_name.rsplit('(', 1)[0] 753 | filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt') 754 | base_log_directory = log_directory 755 | log_directory = os.path.join(log_directory, datetime.datetime.now().strftime("%Y-%m-%d"), hypernetwork_name) 756 | unload = shared.opts.unload_models_when_training 757 | 758 | if save_hypernetwork_every > 0 or save_when_converge: 759 | hypernetwork_dir = os.path.join(log_directory, "hypernetworks") 760 | os.makedirs(hypernetwork_dir, exist_ok=True) 761 | else: 762 | hypernetwork_dir = None 763 | 764 | if create_image_every > 0 or create_when_converge: 765 | images_dir = os.path.join(log_directory, "images") 766 | os.makedirs(images_dir, exist_ok=True) 767 | else: 768 | images_dir = None 769 | 770 | checkpoint = sd_models.select_checkpoint() 771 | 772 | initial_step = hypernetwork.step or 0 773 | if initial_step >= steps: 774 | shared.state.textinfo = f"Model has already been trained beyond specified max steps" 775 | return hypernetwork, filename 776 | 777 | scheduler = LearnRateScheduler(learn_rate, steps, initial_step) 778 | if shared.opts.training_enable_tensorboard: 779 | print("Tensorboard logging enabled") 780 | tensorboard_writer = tensorboard_setup(os.path.join(base_log_directory, base_hypernetwork_name)) 781 | 782 | else: 783 | tensorboard_writer = None 784 | # dataset loading may take a while, so input validations and early returns should be done before this 785 | shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..." 786 | detach_grad = shared.opts.disable_ema # test code that removes EMA 787 | if detach_grad: 788 | print("Disabling training for staged models!") 789 | shared.sd_model.cond_stage_model.requires_grad_(False) 790 | shared.sd_model.first_stage_model.requires_grad_(False) 791 | torch.cuda.empty_cache() 792 | pin_memory = shared.opts.pin_memory 793 | ds = PersonalizedBase(data_root=data_root, width=training_width, 794 | height=training_height, 795 | repeats=shared.opts.training_image_repeats_per_epoch, 796 | placeholder_token=hypernetwork_name, model=shared.sd_model, 797 | cond_model=shared.sd_model.cond_stage_model, 798 | device=devices.device, template_file=template_file, 799 | include_cond=True, batch_size=batch_size, 800 | gradient_step=gradient_step, shuffle_tags=shuffle_tags, 801 | tag_drop_out=tag_drop_out, 802 | latent_sampling_method=latent_sampling_method, 803 | latent_sampling_std=latent_sampling_std, 804 | manual_seed=manual_seed, 805 | use_weight=use_weight) 806 | 807 | latent_sampling_method = ds.latent_sampling_method 808 | 809 | dl = PersonalizedDataLoader(ds, latent_sampling_method=latent_sampling_method, 810 | batch_size=ds.batch_size, pin_memory=pin_memory) 811 | old_parallel_processing_allowed = shared.parallel_processing_allowed 812 | 813 | if unload: 814 | shared.parallel_processing_allowed = False 815 | shared.sd_model.cond_stage_model.to(devices.cpu) 816 | shared.sd_model.first_stage_model.to(devices.cpu) 817 | 818 | weights = hypernetwork.weights(True) 819 | optimizer_name = hypernetwork.optimizer_name 820 | if hypernetwork.optimizer_name == 'DAdaptAdamW': 821 | use_dadaptation = True 822 | optimizer = None 823 | # Here we use optimizer from saved HN, or we can specify as UI option. 
824 | if hypernetwork.optimizer_name in optimizer_dict: 825 | if use_adamw_parameter: 826 | if hypernetwork.optimizer_name != 'AdamW' and hypernetwork.optimizer_name != 'DAdaptAdamW': 827 | raise RuntimeError(f"Cannot use adamW paramters for optimizer {hypernetwork.optimizer_name}!") 828 | if use_dadaptation: 829 | from .dadapt_test.install import get_dadapt_adam 830 | optim_class = get_dadapt_adam(hypernetwork.optimizer_name) 831 | if optim_class != torch.optim.AdamW: 832 | optimizer = optim_class(params=weights, lr=scheduler.learn_rate, growth_rate = float('inf') if dadapt_growth_factor < 0 else dadapt_growth_factor, decouple=True, **adamW_kwarg_dict) 833 | else: 834 | optimizer = torch.optim.AdamW(params=weights, lr=scheduler.learn_rate, **adamW_kwarg_dict) 835 | else: 836 | optimizer = torch.optim.AdamW(params=weights, lr=scheduler.learn_rate, **adamW_kwarg_dict) 837 | else: 838 | optimizer = optimizer_dict[hypernetwork.optimizer_name](params=weights, lr=scheduler.learn_rate) 839 | optimizer_name = hypernetwork.optimizer_name 840 | else: 841 | print(f"Optimizer type {hypernetwork.optimizer_name} is not defined!") 842 | if use_dadaptation: 843 | from .dadapt_test.install import get_dadapt_adam 844 | optim_class = get_dadapt_adam(hypernetwork.optimizer_name) 845 | if optim_class != torch.optim.AdamW: 846 | optimizer = optim_class(params=weights, lr=scheduler.learn_rate, growth_rate = float('inf') if dadapt_growth_factor < 0 else dadapt_growth_factor, decouple=True, **adamW_kwarg_dict) 847 | optimizer_name = 'DAdaptAdamW' 848 | if optimizer is None: 849 | optimizer = torch.optim.AdamW(params=weights, lr=scheduler.learn_rate, **adamW_kwarg_dict) 850 | optimizer_name = 'AdamW' 851 | 852 | 853 | 854 | if hypernetwork.optimizer_state_dict: # This line must be changed if Optimizer type can be different from saved optimizer. 
855 | try: 856 | optimizer.load_state_dict(hypernetwork.optimizer_state_dict) 857 | except RuntimeError as e: 858 | print("Cannot resume from saved optimizer!") 859 | print(e) 860 | optim_to(optimizer, devices.device) 861 | if use_beta_scheduler: 862 | scheduler_beta = CosineAnnealingWarmUpRestarts(optimizer=optimizer, first_cycle_steps=beta_repeat_epoch, 863 | cycle_mult=epoch_mult, max_lr=scheduler.learn_rate, 864 | warmup_steps=warmup, min_lr=min_lr, gamma=gamma_rate) 865 | scheduler_beta.last_epoch = hypernetwork.step - 1 866 | else: 867 | scheduler_beta = None 868 | for pg in optimizer.param_groups: 869 | pg['lr'] = scheduler.learn_rate 870 | scaler = torch.cuda.amp.GradScaler() 871 | 872 | batch_size = ds.batch_size 873 | gradient_step = ds.gradient_step 874 | # n steps = batch_size * gradient_step * n image processed 875 | steps_per_epoch = len(ds) // batch_size // gradient_step 876 | max_steps_per_epoch = len(ds) // batch_size - (len(ds) // batch_size) % gradient_step 877 | loss_step = 0 878 | _loss_step = 0 # internal 879 | # size = len(ds.indexes) 880 | loss_dict = defaultdict(lambda: deque(maxlen=1024)) 881 | # losses = torch.zeros((size,)) 882 | # previous_mean_losses = [0] 883 | # previous_mean_loss = 0 884 | # print("Mean loss of {} elements".format(size)) 885 | 886 | steps_without_grad = 0 887 | 888 | last_saved_file = "" 889 | last_saved_image = "" 890 | forced_filename = "" 891 | if hasattr(sd_hijack_checkpoint, 'add'): 892 | sd_hijack_checkpoint.add() 893 | pbar = tqdm.tqdm(total=steps - initial_step) 894 | try: 895 | for i in range((steps - initial_step) * gradient_step): 896 | if scheduler.finished or hypernetwork.step > steps: 897 | break 898 | if shared.state.interrupted: 899 | break 900 | for j, batch in enumerate(dl): 901 | # works as a drop_last=True for gradient accumulation 902 | if j == max_steps_per_epoch: 903 | break 904 | if use_beta_scheduler: 905 | scheduler_beta.step(hypernetwork.step) 906 | else: 907 | scheduler.apply(optimizer, hypernetwork.step) 908 | if scheduler.finished: 909 | break 910 | if shared.state.interrupted: 911 | break 912 | 913 | with torch.autocast("cuda"): 914 | x = batch.latent_sample.to(devices.device, non_blocking=pin_memory) 915 | if use_weight: 916 | w = batch.weight.to(devices.device, non_blocking=pin_memory) 917 | if tag_drop_out != 0 or shuffle_tags: 918 | shared.sd_model.cond_stage_model.to(devices.device) 919 | c = shared.sd_model.cond_stage_model(batch.cond_text).to(devices.device, 920 | non_blocking=pin_memory) 921 | shared.sd_model.cond_stage_model.to(devices.cpu) 922 | else: 923 | c = stack_conds(batch.cond).to(devices.device, non_blocking=pin_memory) 924 | if use_weight: 925 | loss = shared.sd_model.weighted_forward(x, c, w)[0] 926 | else: 927 | _, losses = shared.sd_model.forward(x, c) 928 | loss = losses['val/' + loss_opt] 929 | for filenames in batch.filename: 930 | loss_dict[filenames].append(loss.detach().item()) 931 | loss /= gradient_step 932 | del x 933 | del c 934 | 935 | _loss_step += loss.item() 936 | scaler.scale(loss).backward() 937 | batch.latent_sample.to(devices.cpu) 938 | # go back until we reach gradient accumulation steps 939 | if (j + 1) % gradient_step != 0: 940 | continue 941 | gradient_clipping(weights) 942 | # print(f"grad:{weights[0].grad.detach().cpu().abs().mean().item():.7f}") 943 | # scaler.unscale_(optimizer) 944 | # print(f"grad:{weights[0].grad.detach().cpu().abs().mean().item():.15f}") 945 | # torch.nn.utils.clip_grad_norm_(weights, max_norm=1.0) 946 | # 
print(f"grad:{weights[0].grad.detach().cpu().abs().mean().item():.15f}") 947 | try: 948 | scaler.step(optimizer) 949 | except AssertionError: 950 | optimizer.param_groups[0]['capturable'] = True 951 | scaler.step(optimizer) 952 | scaler.update() 953 | hypernetwork.step += 1 954 | pbar.update() 955 | optimizer.zero_grad(set_to_none=True) 956 | loss_step = _loss_step 957 | _loss_step = 0 958 | 959 | steps_done = hypernetwork.step + 1 960 | 961 | epoch_num = hypernetwork.step // steps_per_epoch 962 | epoch_step = hypernetwork.step % steps_per_epoch 963 | 964 | description = f"Training hypernetwork [Epoch {epoch_num}: {epoch_step + 1}/{steps_per_epoch}]loss: {loss_step:.7f}" 965 | pbar.set_description(description) 966 | if hypernetwork_dir is not None and ( 967 | (use_beta_scheduler and scheduler_beta.is_EOC(hypernetwork.step) and save_when_converge) or ( 968 | save_hypernetwork_every > 0 and steps_done % save_hypernetwork_every == 0)): 969 | # Before saving, change name to match current checkpoint. 970 | hypernetwork_name_every = f'{hypernetwork_name}-{steps_done}' 971 | last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork_name_every}.pt') 972 | hypernetwork.optimizer_name = optimizer_name 973 | if shared.opts.save_optimizer_state: 974 | hypernetwork.optimizer_state_dict = optimizer.state_dict() 975 | save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, last_saved_file) 976 | hypernetwork.optimizer_state_dict = None # dereference it after saving, to save memory. 977 | 978 | write_loss(log_directory, "hypernetwork_loss.csv", hypernetwork.step, steps_per_epoch, 979 | { 980 | "loss": f"{loss_step:.7f}", 981 | "learn_rate": get_lr_from_optimizer(optimizer) 982 | }) 983 | if shared.opts.training_enable_tensorboard: 984 | epoch_num = hypernetwork.step // len(ds) 985 | epoch_step = hypernetwork.step - (epoch_num * len(ds)) + 1 986 | mean_loss = sum(sum(x) for x in loss_dict.values()) / sum(len(x) for x in loss_dict.values()) 987 | tensorboard_add(tensorboard_writer, loss=mean_loss, global_step=hypernetwork.step, step=epoch_step, 988 | learn_rate=scheduler.learn_rate if not use_beta_scheduler else 989 | get_lr_from_optimizer(optimizer), epoch_num=epoch_num, base_name=hypernetwork_name) 990 | if images_dir is not None and ( 991 | use_beta_scheduler and scheduler_beta.is_EOC(hypernetwork.step) and create_when_converge) or ( 992 | create_image_every > 0 and steps_done % create_image_every == 0): 993 | set_scheduler(-1, False, False) 994 | forced_filename = f'{hypernetwork_name}-{steps_done}' 995 | last_saved_image = os.path.join(images_dir, forced_filename) 996 | rng_state = torch.get_rng_state() 997 | cuda_rng_state = None 998 | if torch.cuda.is_available(): 999 | cuda_rng_state = torch.cuda.get_rng_state_all() 1000 | hypernetwork.eval() 1001 | if move_optimizer: 1002 | optim_to(optimizer, devices.cpu) 1003 | shared.sd_model.cond_stage_model.to(devices.device) 1004 | shared.sd_model.first_stage_model.to(devices.device) 1005 | 1006 | p = processing.StableDiffusionProcessingTxt2Img( 1007 | sd_model=shared.sd_model, 1008 | do_not_save_grid=True, 1009 | do_not_save_samples=True, 1010 | ) 1011 | if hasattr(p, 'disable_extra_networks'): 1012 | p.disable_extra_networks = True 1013 | is_patched = True 1014 | else: 1015 | is_patched = False 1016 | if preview_from_txt2img: 1017 | p.prompt = preview_prompt + (hypernetwork.extra_name() if not is_patched else "") 1018 | p.negative_prompt = preview_negative_prompt 1019 | p.steps = preview_steps 1020 | p.sampler_name = 
sd_samplers.samplers[preview_sampler_index].name 1021 | p.cfg_scale = preview_cfg_scale 1022 | p.seed = preview_seed 1023 | p.width = preview_width 1024 | p.height = preview_height 1025 | else: 1026 | p.prompt = batch.cond_text[0] + (hypernetwork.extra_name() if not is_patched else "") 1027 | p.steps = 20 1028 | p.width = training_width 1029 | p.height = training_height 1030 | 1031 | preview_text = p.prompt 1032 | 1033 | processed = processing.process_images(p) 1034 | image = processed.images[0] if len(processed.images) > 0 else None 1035 | if shared.opts.training_enable_tensorboard and shared.opts.training_tensorboard_save_images: 1036 | tensorboard_add_image(tensorboard_writer, f"Validation at epoch {epoch_num}", image, 1037 | hypernetwork.step, base_name=hypernetwork_name) 1038 | 1039 | if unload: 1040 | shared.sd_model.cond_stage_model.to(devices.cpu) 1041 | shared.sd_model.first_stage_model.to(devices.cpu) 1042 | torch.set_rng_state(rng_state) 1043 | if torch.cuda.is_available(): 1044 | torch.cuda.set_rng_state_all(cuda_rng_state) 1045 | hypernetwork.train() 1046 | if move_optimizer: 1047 | optim_to(optimizer, devices.device) 1048 | if noise_training_scheduler_enabled: 1049 | set_scheduler(noise_training_scheduler_cycle, noise_training_scheduler_repeat, True) 1050 | if image is not None: 1051 | if hasattr(shared.state, 'assign_current_image'): 1052 | shared.state.assign_current_image(image) 1053 | else: 1054 | shared.state.current_image = image 1055 | last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, 1056 | shared.opts.samples_format, 1057 | processed.infotexts[0], p=p, 1058 | forced_filename=forced_filename, 1059 | save_to_dirs=False) 1060 | last_saved_image += f", prompt: {preview_text}" 1061 | set_accessible(hypernetwork) 1062 | 1063 | shared.state.job_no = hypernetwork.step 1064 | 1065 | shared.state.textinfo = f""" 1066 |

<p>
1067 | Loss: {loss_step:.7f}<br/>
1068 | Step: {steps_done}<br/>
1069 | Last prompt: {html.escape(batch.cond_text[0])}<br/>
1070 | Last saved hypernetwork: {html.escape(last_saved_file)}<br/>
1071 | Last saved image: {html.escape(last_saved_image)}<br/>
1072 | </p>

1073 | """ 1074 | except Exception: 1075 | if pbar is not None: 1076 | pbar.set_description(traceback.format_exc()) 1077 | shared.state.textinfo = traceback.format_exc() 1078 | print(traceback.format_exc(), file=sys.stderr) 1079 | finally: 1080 | pbar.leave = False 1081 | pbar.close() 1082 | hypernetwork.eval() 1083 | set_scheduler(-1, False, False) 1084 | shared.parallel_processing_allowed = old_parallel_processing_allowed 1085 | remove_accessible() 1086 | if hasattr(sd_hijack_checkpoint, 'remove'): 1087 | sd_hijack_checkpoint.remove() 1088 | if shared.opts.training_enable_tensorboard: 1089 | mean_loss = sum(sum(x) for x in loss_dict.values()) / sum(len(x) for x in loss_dict.values()) if sum(len(x) for x in loss_dict.values()) > 0 else 0 1090 | tensorboard_log_hyperparameter(tensorboard_writer, lr=learn_rate, 1091 | GA_steps=gradient_step, 1092 | batch_size=batch_size, 1093 | layer_structure=hypernetwork.layer_structure, 1094 | activation=hypernetwork.activation_func, 1095 | weight_init=hypernetwork.weight_init, 1096 | dropout_structure=hypernetwork.dropout_structure, 1097 | max_steps=steps, 1098 | latent_sampling_method=latent_sampling_method, 1099 | template=template_file, 1100 | CosineAnnealing=use_beta_scheduler, 1101 | beta_repeat_epoch=beta_repeat_epoch, 1102 | epoch_mult=epoch_mult, 1103 | warmup=warmup, 1104 | min_lr=min_lr, 1105 | gamma_rate=gamma_rate, 1106 | adamW_opts=use_adamw_parameter, 1107 | adamW_decay=adamw_weight_decay, 1108 | adamW_beta_1=adamw_beta_1, 1109 | adamW_beta_2=adamw_beta_2, 1110 | adamW_eps=adamw_eps, 1111 | gradient_clip=gradient_clip_opt, 1112 | gradient_clip_value=optional_gradient_clip_value, 1113 | gradient_clip_norm_type=optional_gradient_norm_type, 1114 | loss=mean_loss, 1115 | base_hypernetwork_name=hypernetwork_name 1116 | ) 1117 | filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt') 1118 | hypernetwork.optimizer_name = optimizer_name 1119 | if shared.opts.save_optimizer_state: 1120 | hypernetwork.optimizer_state_dict = optimizer.state_dict() 1121 | save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, filename) 1122 | del optimizer 1123 | hypernetwork.optimizer_state_dict = None # dereference it after saving, to save memory. 1124 | shared.sd_model.cond_stage_model.to(devices.device) 1125 | shared.sd_model.first_stage_model.to(devices.device) 1126 | gc.collect() 1127 | torch.cuda.empty_cache() 1128 | return hypernetwork, filename 1129 | 1130 | 1131 | def train_hypernetwork_tuning(id_task, hypernetwork_name, data_root, log_directory, 1132 | create_image_every, save_hypernetwork_every, preview_from_txt2img, preview_prompt, 1133 | preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, 1134 | preview_seed, 1135 | preview_width, preview_height, 1136 | move_optimizer=True, 1137 | optional_new_hypernetwork_name='', load_hypernetworks_options='', 1138 | load_training_options='', manual_dataset_seed=-1): 1139 | load_hypernetworks_options = load_hypernetworks_options.split(',') 1140 | load_training_options = load_training_options.split(',') 1141 | # images allows training previews to have infotext. Importing it at the top causes a circular import problem. 
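# The nested loops below sweep every (hypernetwork preset, training preset) pair parsed from the comma-separated option strings,
# skip entries that fail to load, and launch one internal_clean_training run per combination, returning early if the user interrupts.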
1142 | for _i, load_hypernetworks_option in enumerate(load_hypernetworks_options): 1143 | load_hypernetworks_option = load_hypernetworks_option.strip(' ') 1144 | if load_hypernetworks_option != '' and get_training_option(load_hypernetworks_option) is False: 1145 | print(f"Cannot load from {load_hypernetworks_option}!") 1146 | continue 1147 | for _j, load_training_option in enumerate(load_training_options): 1148 | load_training_option = load_training_option.strip(' ') 1149 | if get_training_option(load_training_option) is False: 1150 | print(f"Cannot load from {load_training_option}!") 1151 | continue 1152 | internal_clean_training( 1153 | hypernetwork_name if load_hypernetworks_option == '' else optional_new_hypernetwork_name, 1154 | data_root, 1155 | log_directory, 1156 | create_image_every, 1157 | save_hypernetwork_every, 1158 | preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, 1159 | preview_cfg_scale, preview_seed, preview_width, preview_height, 1160 | move_optimizer, 1161 | load_hypernetworks_option, load_training_option, manual_dataset_seed, setting_tuple=(_i, _j)) 1162 | if shared.state.interrupted: 1163 | return None, None 1164 | --------------------------------------------------------------------------------