├── README.MD ├── arguments.py ├── configs ├── cogvideox_2b.yaml ├── cogvideox_2b_lora.yaml ├── cogvideox_5b.yaml ├── cogvideox_5b_lora.yaml ├── inference │ ├── inference_single_identity.yaml │ └── inference_two_identities.yaml └── training │ ├── sft_single_identity.yaml │ └── sft_two_identities.yaml ├── data_video.py ├── diffusion_video.py ├── dit_video_concat.py ├── examples ├── cropped_images │ ├── 1.png │ ├── 2.png │ ├── 3.png │ ├── 4.png │ ├── 5.png │ └── 6.png ├── images │ ├── 3_stars_woman_Taylor_Swift_3.png │ ├── 43_stars_man_Leonardo_DiCaprio_3.png │ ├── 69_politicians_woman_Tulsi_Gabbard_4.png │ ├── 72_politicians_woman_Tulsi_Gabbard_2.png │ ├── 73_politicians_woman_Harris_3.png │ ├── 80_normal_man_5.jpg │ └── 93_normal_woman_3.jpg ├── results │ ├── 1.gif │ ├── 2.gif │ ├── 3.gif │ ├── 4.gif │ ├── 5.gif │ └── 6.gif ├── single_identity.txt └── two_identities.txt ├── finetune_single_identity.sh ├── finetune_two_identities.sh ├── inference_single_identity.sh ├── inference_two_identities.sh ├── inference_wan.py ├── pipelines └── wan_video.py ├── requirements.txt ├── sample_video.py ├── sgm ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-310.pyc │ ├── util.cpython-310.pyc │ └── webds.cpython-310.pyc ├── lr_scheduler.py ├── models │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-310.pyc │ │ └── autoencoder.cpython-310.pyc │ └── autoencoder.py ├── modules │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-310.pyc │ │ ├── attention.cpython-310.pyc │ │ ├── cp_enc_dec.cpython-310.pyc │ │ ├── ema.cpython-310.pyc │ │ └── video_attention.cpython-310.pyc │ ├── attention.py │ ├── autoencoding │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-310.pyc │ │ │ └── temporal_ae.cpython-310.pyc │ │ ├── losses │ │ │ ├── __init__.py │ │ │ ├── discriminator_loss.py │ │ │ ├── lpips.py │ │ │ └── video_loss.py │ │ ├── lpips │ │ │ ├── __init__.py │ │ │ ├── __pycache__ │ │ │ │ ├── __init__.cpython-310.pyc │ │ │ │ └── util.cpython-310.pyc │ │ │ ├── loss │ │ │ │ ├── .gitignore │ │ │ │ ├── LICENSE │ │ │ │ ├── __init__.py │ │ │ │ ├── __pycache__ │ │ │ │ │ ├── __init__.cpython-310.pyc │ │ │ │ │ └── lpips.cpython-310.pyc │ │ │ │ └── lpips.py │ │ │ ├── model │ │ │ │ ├── LICENSE │ │ │ │ ├── __init__.py │ │ │ │ └── model.py │ │ │ ├── util.py │ │ │ └── vqperceptual.py │ │ ├── magvit2_pytorch.py │ │ ├── regularizers │ │ │ ├── __init__.py │ │ │ ├── __pycache__ │ │ │ │ ├── __init__.cpython-310.pyc │ │ │ │ └── base.cpython-310.pyc │ │ │ ├── base.py │ │ │ ├── finite_scalar_quantization.py │ │ │ ├── lookup_free_quantization.py │ │ │ └── quantize.py │ │ ├── temporal_ae.py │ │ └── vqvae │ │ │ ├── movq_dec_3d.py │ │ │ ├── movq_dec_3d_dev.py │ │ │ ├── movq_enc_3d.py │ │ │ ├── movq_modules.py │ │ │ ├── quantize.py │ │ │ └── vqvae_blocks.py │ ├── cp_enc_dec.py │ ├── diffusionmodules │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-310.pyc │ │ │ ├── denoiser.cpython-310.pyc │ │ │ ├── denoiser_scaling.cpython-310.pyc │ │ │ ├── denoiser_weighting.cpython-310.pyc │ │ │ ├── discretizer.cpython-310.pyc │ │ │ ├── guiders.cpython-310.pyc │ │ │ ├── lora.cpython-310.pyc │ │ │ ├── loss.cpython-310.pyc │ │ │ ├── model.cpython-310.pyc │ │ │ ├── openaimodel.cpython-310.pyc │ │ │ ├── sampling.cpython-310.pyc │ │ │ ├── sampling_utils.cpython-310.pyc │ │ │ ├── sigma_sampling.cpython-310.pyc │ │ │ ├── util.cpython-310.pyc │ │ │ └── wrappers.cpython-310.pyc │ │ ├── denoiser.py │ │ ├── denoiser_scaling.py │ │ ├── denoiser_weighting.py │ │ ├── discretizer.py │ │ ├── guiders.py │ │ ├── 
lora.py │ │ ├── loss.py │ │ ├── model.py │ │ ├── openaimodel.py │ │ ├── sampling.py │ │ ├── sampling_utils.py │ │ ├── sigma_sampling.py │ │ ├── util.py │ │ └── wrappers.py │ ├── distributions │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-310.pyc │ │ │ └── distributions.cpython-310.pyc │ │ └── distributions.py │ ├── ema.py │ ├── encoders │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-310.pyc │ │ │ └── modules.cpython-310.pyc │ │ └── modules.py │ └── video_attention.py ├── util.py └── webds.py ├── train_concat-id_wan2.1.py ├── train_video.py └── vae_modules ├── __pycache__ ├── autoencoder.cpython-310.pyc ├── cp_enc_dec.cpython-310.pyc ├── ema.cpython-310.pyc ├── regularizers.cpython-310.pyc └── utils.cpython-310.pyc ├── attention.py ├── autoencoder.py ├── cp_enc_dec.py ├── ema.py ├── regularizers.py └── utils.py /arguments.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import torch 4 | import json 5 | import warnings 6 | import omegaconf 7 | from omegaconf import OmegaConf 8 | from sat.helpers import print_rank0 9 | from sat import mpu 10 | from sat.arguments import set_random_seed 11 | from sat.arguments import add_training_args, add_evaluation_args, add_data_args 12 | import torch.distributed 13 | 14 | 15 | def add_model_config_args(parser): 16 | """Model arguments""" 17 | 18 | group = parser.add_argument_group("model", "model configuration") 19 | group.add_argument("--base", type=str, nargs="*", help="config for input and saving") 20 | group.add_argument( 21 | "--model-parallel-size", type=int, default=1, help="size of the model parallel. only use if you are an expert." 22 | ) 23 | group.add_argument("--force-pretrain", action="store_true") 24 | group.add_argument("--device", type=int, default=-1) 25 | group.add_argument("--debug", action="store_true") 26 | group.add_argument("--log-image", type=bool, default=True) 27 | 28 | return parser 29 | 30 | 31 | def add_sampling_config_args(parser): 32 | """Sampling configurations""" 33 | 34 | group = parser.add_argument_group("sampling", "Sampling Configurations") 35 | group.add_argument("--output-dir", type=str, default="samples") 36 | group.add_argument("--input-dir", type=str, default=None) 37 | group.add_argument("--input-type", type=str, default="cli") 38 | group.add_argument("--input-file", type=str, default="input.txt") 39 | group.add_argument("--final-size", type=int, default=2048) 40 | group.add_argument("--sdedit", action="store_true") 41 | group.add_argument("--grid-num-rows", type=int, default=1) 42 | group.add_argument("--force-inference", action="store_true") 43 | group.add_argument("--lcm_steps", type=int, default=None) 44 | group.add_argument("--sampling-num-frames", type=int, default=32) 45 | group.add_argument("--sampling-fps", type=int, default=8) 46 | group.add_argument("--only-save-latents", type=bool, default=False) 47 | group.add_argument("--only-log-video-latents", type=bool, default=False) 48 | group.add_argument("--latent-channels", type=int, default=32) 49 | group.add_argument("--image2video", action="store_true") 50 | group.add_argument("--contextimage2video", action="store_true") 51 | 52 | return parser 53 | 54 | 55 | def get_args(args_list=None, parser=None): 56 | """Parse all the args.""" 57 | if parser is None: 58 | parser = argparse.ArgumentParser(description="sat") 59 | else: 60 | assert isinstance(parser, argparse.ArgumentParser) 61 | parser = add_model_config_args(parser) 62 | parser = 
add_sampling_config_args(parser)
 63 |     parser = add_training_args(parser)
 64 |     parser = add_evaluation_args(parser)
 65 |     parser = add_data_args(parser)
 66 | 
 67 |     import deepspeed
 68 | 
 69 |     parser = deepspeed.add_config_arguments(parser)
 70 | 
 71 |     args = parser.parse_args(args_list)
 72 |     args = process_config_to_args(args)
 73 | 
 74 |     if not args.train_data:
 75 |         print_rank0("No training data specified", level="WARNING")
 76 | 
 77 |     assert (args.train_iters is None) or (args.epochs is None), "only one of train_iters and epochs should be set."
 78 |     if args.train_iters is None and args.epochs is None:
 79 |         args.train_iters = 10000 # default 10k iters
 80 |         print_rank0("No train_iters (recommended) or epochs specified, use default 10k iters.", level="WARNING")
 81 | 
 82 |     args.cuda = torch.cuda.is_available()
 83 | 
 84 |     args.rank = int(os.getenv("RANK", "0"))
 85 |     args.world_size = int(os.getenv("WORLD_SIZE", "1"))
 86 |     if args.local_rank is None:
 87 |         args.local_rank = int(os.getenv("LOCAL_RANK", "0")) # torchrun
 88 | 
 89 |     if args.device == -1:
 90 |         if torch.cuda.device_count() == 0:
 91 |             args.device = "cpu"
 92 |         elif args.local_rank is not None:
 93 |             args.device = args.local_rank
 94 |         else:
 95 |             args.device = args.rank % torch.cuda.device_count()
 96 | 
 97 |     if args.local_rank != args.device and args.mode != "inference":
 98 |         raise ValueError(
 99 |             "LOCAL_RANK (default 0) and args.device inconsistent. "
100 |             "This can only happen in inference mode. "
101 |             "Please use CUDA_VISIBLE_DEVICES=x for single-GPU training. "
102 |         )
103 | 
104 |     if args.rank == 0:
105 |         print_rank0("using world size: {}".format(args.world_size))
106 | 
107 |     if args.train_data_weights is not None:
108 |         assert len(args.train_data_weights) == len(args.train_data)
109 | 
110 |     if args.mode != "inference": # training with deepspeed
111 |         args.deepspeed = True
112 |         if args.deepspeed_config is None: # not specified
113 |             deepspeed_config_path = os.path.join(
114 |                 os.path.dirname(__file__), "training", f"deepspeed_zero{args.zero_stage}.json"
115 |             )
116 |             with open(deepspeed_config_path) as file:
117 |                 args.deepspeed_config = json.load(file)
118 |             override_deepspeed_config = True
119 |         else:
120 |             override_deepspeed_config = False
121 | 
122 |     assert not (args.fp16 and args.bf16), "cannot specify both fp16 and bf16."
123 | 124 | if args.zero_stage > 0 and not args.fp16 and not args.bf16: 125 | print_rank0("Automatically set fp16=True to use ZeRO.") 126 | args.fp16 = True 127 | args.bf16 = False 128 | 129 | if args.deepspeed: 130 | if args.checkpoint_activations: 131 | args.deepspeed_activation_checkpointing = True 132 | else: 133 | args.deepspeed_activation_checkpointing = False 134 | if args.deepspeed_config is not None: 135 | deepspeed_config = args.deepspeed_config 136 | 137 | if override_deepspeed_config: # not specify deepspeed_config, use args 138 | if args.fp16: 139 | deepspeed_config["fp16"]["enabled"] = True 140 | elif args.bf16: 141 | deepspeed_config["bf16"]["enabled"] = True 142 | deepspeed_config["fp16"]["enabled"] = False 143 | else: 144 | deepspeed_config["fp16"]["enabled"] = False 145 | deepspeed_config["train_micro_batch_size_per_gpu"] = args.batch_size 146 | deepspeed_config["gradient_accumulation_steps"] = args.gradient_accumulation_steps 147 | optimizer_params_config = deepspeed_config["optimizer"]["params"] 148 | optimizer_params_config["lr"] = args.lr 149 | optimizer_params_config["weight_decay"] = args.weight_decay 150 | else: # override args with values in deepspeed_config 151 | if args.rank == 0: 152 | print_rank0("Will override arguments with manually specified deepspeed_config!") 153 | if "fp16" in deepspeed_config and deepspeed_config["fp16"]["enabled"]: 154 | args.fp16 = True 155 | else: 156 | args.fp16 = False 157 | if "bf16" in deepspeed_config and deepspeed_config["bf16"]["enabled"]: 158 | args.bf16 = True 159 | else: 160 | args.bf16 = False 161 | if "train_micro_batch_size_per_gpu" in deepspeed_config: 162 | args.batch_size = deepspeed_config["train_micro_batch_size_per_gpu"] 163 | if "gradient_accumulation_steps" in deepspeed_config: 164 | args.gradient_accumulation_steps = deepspeed_config["gradient_accumulation_steps"] 165 | else: 166 | args.gradient_accumulation_steps = None 167 | if "optimizer" in deepspeed_config: 168 | optimizer_params_config = deepspeed_config["optimizer"].get("params", {}) 169 | args.lr = optimizer_params_config.get("lr", args.lr) 170 | args.weight_decay = optimizer_params_config.get("weight_decay", args.weight_decay) 171 | args.deepspeed_config = deepspeed_config 172 | 173 | # initialize distributed and random seed because it always seems to be necessary. 174 | initialize_distributed(args) 175 | args.seed = args.seed + mpu.get_data_parallel_rank() 176 | set_random_seed(args.seed) 177 | return args 178 | 179 | 180 | def initialize_distributed(args): 181 | """Initialize torch.distributed.""" 182 | if torch.distributed.is_initialized(): 183 | if mpu.model_parallel_is_initialized(): 184 | if args.model_parallel_size != mpu.get_model_parallel_world_size(): 185 | raise ValueError( 186 | "model_parallel_size is inconsistent with prior configuration." 187 | "We currently do not support changing model_parallel_size." 188 | ) 189 | return False 190 | else: 191 | if args.model_parallel_size > 1: 192 | warnings.warn( 193 | "model_parallel_size > 1 but torch.distributed is not initialized via SAT." 194 | "Please carefully make sure the correctness on your own." 
195 | ) 196 | mpu.initialize_model_parallel(args.model_parallel_size) 197 | return True 198 | # the automatic assignment of devices has been moved to arguments.py 199 | if args.device == "cpu": 200 | pass 201 | else: 202 | torch.cuda.set_device(args.device) 203 | # Call the init process 204 | init_method = "tcp://" 205 | args.master_ip = os.getenv("MASTER_ADDR", "localhost") 206 | 207 | if args.world_size == 1: 208 | from sat.helpers import get_free_port 209 | 210 | default_master_port = str(get_free_port()) 211 | else: 212 | default_master_port = "6000" 213 | args.master_port = os.getenv("MASTER_PORT", default_master_port) 214 | init_method += args.master_ip + ":" + args.master_port 215 | torch.distributed.init_process_group( 216 | backend=args.distributed_backend, world_size=args.world_size, rank=args.rank, init_method=init_method 217 | ) 218 | 219 | # Set the model-parallel / data-parallel communicators. 220 | mpu.initialize_model_parallel(args.model_parallel_size) 221 | 222 | # Set vae context parallel group equal to model parallel group 223 | from sgm.util import set_context_parallel_group, initialize_context_parallel 224 | 225 | if args.model_parallel_size <= 2: 226 | set_context_parallel_group(args.model_parallel_size, mpu.get_model_parallel_group()) 227 | else: 228 | initialize_context_parallel(2) 229 | # mpu.initialize_model_parallel(1) 230 | # Optional DeepSpeed Activation Checkpointing Features 231 | if args.deepspeed: 232 | import deepspeed 233 | 234 | deepspeed.init_distributed( 235 | dist_backend=args.distributed_backend, world_size=args.world_size, rank=args.rank, init_method=init_method 236 | ) 237 | # # It seems that it has no negative influence to configure it even without using checkpointing. 238 | # deepspeed.checkpointing.configure(mpu, deepspeed_config=args.deepspeed_config, num_checkpoints=args.num_layers) 239 | else: 240 | # in model-only mode, we don't want to init deepspeed, but we still need to init the rng tracker for model_parallel, just because we save the seed by default when dropout. 
241 | try: 242 | import deepspeed 243 | from deepspeed.runtime.activation_checkpointing.checkpointing import ( 244 | _CUDA_RNG_STATE_TRACKER, 245 | _MODEL_PARALLEL_RNG_TRACKER_NAME, 246 | ) 247 | 248 | _CUDA_RNG_STATE_TRACKER.add(_MODEL_PARALLEL_RNG_TRACKER_NAME, 1) # default seed 1 249 | except Exception as e: 250 | from sat.helpers import print_rank0 251 | 252 | print_rank0(str(e), level="DEBUG") 253 | 254 | return True 255 | 256 | 257 | def process_config_to_args(args): 258 | """Fetch args from only --base""" 259 | 260 | configs = [OmegaConf.load(cfg) for cfg in args.base] 261 | config = OmegaConf.merge(*configs) 262 | 263 | args_config = config.pop("args", OmegaConf.create()) 264 | for key in args_config: 265 | if isinstance(args_config[key], omegaconf.DictConfig) or isinstance(args_config[key], omegaconf.ListConfig): 266 | arg = OmegaConf.to_object(args_config[key]) 267 | else: 268 | arg = args_config[key] 269 | if hasattr(args, key): 270 | setattr(args, key, arg) 271 | 272 | if "model" in config: 273 | model_config = config.pop("model", OmegaConf.create()) 274 | args.model_config = model_config 275 | if "deepspeed" in config: 276 | deepspeed_config = config.pop("deepspeed", OmegaConf.create()) 277 | args.deepspeed_config = OmegaConf.to_object(deepspeed_config) 278 | if "data" in config: 279 | data_config = config.pop("data", OmegaConf.create()) 280 | args.data_config = data_config 281 | 282 | return args 283 | -------------------------------------------------------------------------------- /configs/cogvideox_2b.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | scale_factor: 1.15258426 3 | disable_first_stage_autocast: true 4 | log_keys: 5 | - txt 6 | 7 | denoiser_config: 8 | target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser 9 | params: 10 | num_idx: 1000 11 | quantize_c_noise: False 12 | 13 | weighting_config: 14 | target: sgm.modules.diffusionmodules.denoiser_weighting.EpsWeighting 15 | scaling_config: 16 | target: sgm.modules.diffusionmodules.denoiser_scaling.VideoScaling 17 | discretization_config: 18 | target: sgm.modules.diffusionmodules.discretizer.ZeroSNRDDPMDiscretization 19 | params: 20 | shift_scale: 3.0 21 | 22 | network_config: 23 | target: dit_video_concat.DiffusionTransformer 24 | params: 25 | time_embed_dim: 512 26 | elementwise_affine: True 27 | num_frames: 49 28 | time_compressed_rate: 4 29 | latent_width: 90 30 | latent_height: 60 31 | num_layers: 30 32 | patch_size: 2 33 | in_channels: 16 34 | out_channels: 16 35 | hidden_size: 1920 36 | adm_in_channels: 256 37 | num_attention_heads: 30 38 | 39 | transformer_args: 40 | checkpoint_activations: True ## using gradient checkpointing 41 | vocab_size: 1 42 | max_sequence_length: 64 43 | layernorm_order: pre 44 | skip_init: false 45 | model_parallel_size: 1 46 | is_decoder: false 47 | 48 | modules: 49 | pos_embed_config: 50 | target: dit_video_concat.Basic3DPositionEmbeddingMixin 51 | params: 52 | text_length: 226 53 | height_interpolation: 1.875 54 | width_interpolation: 1.875 55 | 56 | patch_embed_config: 57 | target: dit_video_concat.ImagePatchEmbeddingMixin 58 | params: 59 | text_hidden_size: 4096 60 | 61 | adaln_layer_config: 62 | target: dit_video_concat.AdaLNMixin 63 | params: 64 | qk_ln: True 65 | 66 | final_layer_config: 67 | target: dit_video_concat.FinalLayerMixin 68 | 69 | conditioner_config: 70 | target: sgm.modules.GeneralConditioner 71 | params: 72 | emb_models: 73 | - is_trainable: false 74 | input_key: txt 75 | ucg_rate: 0.1 76 | target: 
sgm.modules.encoders.modules.FrozenT5Embedder 77 | params: 78 | model_dir: "/workspace/intern/yongzhong/pre-trained-models/t5-v1_1-xxl" 79 | max_length: 226 80 | 81 | first_stage_config: 82 | target: vae_modules.autoencoder.VideoAutoencoderInferenceWrapper 83 | params: 84 | cp_size: 1 85 | ckpt_path: "/workspace/intern/yongzhong/pre-trained-models/CogVideoX-2b-sat/vae/3d-vae.pt" 86 | ignore_keys: [ 'loss' ] 87 | 88 | loss_config: 89 | target: torch.nn.Identity 90 | 91 | regularizer_config: 92 | target: vae_modules.regularizers.DiagonalGaussianRegularizer 93 | 94 | encoder_config: 95 | target: vae_modules.cp_enc_dec.ContextParallelEncoder3D 96 | params: 97 | double_z: true 98 | z_channels: 16 99 | resolution: 256 100 | in_channels: 3 101 | out_ch: 3 102 | ch: 128 103 | ch_mult: [ 1, 2, 2, 4 ] 104 | attn_resolutions: [ ] 105 | num_res_blocks: 3 106 | dropout: 0.0 107 | gather_norm: True 108 | 109 | decoder_config: 110 | target: vae_modules.cp_enc_dec.ContextParallelDecoder3D 111 | params: 112 | double_z: True 113 | z_channels: 16 114 | resolution: 256 115 | in_channels: 3 116 | out_ch: 3 117 | ch: 128 118 | ch_mult: [ 1, 2, 2, 4 ] 119 | attn_resolutions: [ ] 120 | num_res_blocks: 3 121 | dropout: 0.0 122 | gather_norm: False 123 | 124 | loss_fn_config: 125 | target: sgm.modules.diffusionmodules.loss.VideoDiffusionLoss 126 | params: 127 | offset_noise_level: 0 128 | sigma_sampler_config: 129 | target: sgm.modules.diffusionmodules.sigma_sampling.DiscreteSampling 130 | params: 131 | uniform_sampling: True 132 | num_idx: 1000 133 | discretization_config: 134 | target: sgm.modules.diffusionmodules.discretizer.ZeroSNRDDPMDiscretization 135 | params: 136 | shift_scale: 3.0 137 | 138 | sampler_config: 139 | target: sgm.modules.diffusionmodules.sampling.VPSDEDPMPP2MSampler 140 | params: 141 | num_steps: 50 142 | verbose: True 143 | 144 | discretization_config: 145 | target: sgm.modules.diffusionmodules.discretizer.ZeroSNRDDPMDiscretization 146 | params: 147 | shift_scale: 3.0 148 | 149 | guider_config: 150 | target: sgm.modules.diffusionmodules.guiders.DynamicCFG 151 | params: 152 | scale: 6 153 | exp: 5 154 | num_steps: 50 -------------------------------------------------------------------------------- /configs/cogvideox_2b_lora.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | scale_factor: 1.15258426 3 | disable_first_stage_autocast: true 4 | not_trainable_prefixes: ['all'] ## Using Lora 5 | log_keys: 6 | - txt 7 | 8 | denoiser_config: 9 | target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser 10 | params: 11 | num_idx: 1000 12 | quantize_c_noise: False 13 | 14 | weighting_config: 15 | target: sgm.modules.diffusionmodules.denoiser_weighting.EpsWeighting 16 | scaling_config: 17 | target: sgm.modules.diffusionmodules.denoiser_scaling.VideoScaling 18 | discretization_config: 19 | target: sgm.modules.diffusionmodules.discretizer.ZeroSNRDDPMDiscretization 20 | params: 21 | shift_scale: 3.0 22 | 23 | network_config: 24 | target: dit_video_concat.DiffusionTransformer 25 | params: 26 | time_embed_dim: 512 27 | elementwise_affine: True 28 | num_frames: 49 29 | time_compressed_rate: 4 30 | latent_width: 90 31 | latent_height: 60 32 | num_layers: 30 33 | patch_size: 2 34 | in_channels: 16 35 | out_channels: 16 36 | hidden_size: 1920 37 | adm_in_channels: 256 38 | num_attention_heads: 30 39 | 40 | transformer_args: 41 | checkpoint_activations: True ## using gradient checkpointing 42 | vocab_size: 1 43 | max_sequence_length: 64 44 | layernorm_order: 
pre 45 | skip_init: false 46 | model_parallel_size: 1 47 | is_decoder: false 48 | 49 | modules: 50 | pos_embed_config: 51 | target: dit_video_concat.Basic3DPositionEmbeddingMixin 52 | params: 53 | text_length: 226 54 | height_interpolation: 1.875 55 | width_interpolation: 1.875 56 | 57 | lora_config: 58 | target: sat.model.finetune.lora2.LoraMixin 59 | params: 60 | r: 128 61 | 62 | patch_embed_config: 63 | target: dit_video_concat.ImagePatchEmbeddingMixin 64 | params: 65 | text_hidden_size: 4096 66 | 67 | adaln_layer_config: 68 | target: dit_video_concat.AdaLNMixin 69 | params: 70 | qk_ln: True 71 | 72 | final_layer_config: 73 | target: dit_video_concat.FinalLayerMixin 74 | 75 | conditioner_config: 76 | target: sgm.modules.GeneralConditioner 77 | params: 78 | emb_models: 79 | - is_trainable: false 80 | input_key: txt 81 | ucg_rate: 0.1 82 | target: sgm.modules.encoders.modules.FrozenT5Embedder 83 | params: 84 | model_dir: "t5-v1_1-xxl" 85 | max_length: 226 86 | 87 | first_stage_config: 88 | target: vae_modules.autoencoder.VideoAutoencoderInferenceWrapper 89 | params: 90 | cp_size: 1 91 | ckpt_path: "cogvideox-2b-sat/vae/3d-vae.pt" 92 | ignore_keys: [ 'loss' ] 93 | 94 | loss_config: 95 | target: torch.nn.Identity 96 | 97 | regularizer_config: 98 | target: vae_modules.regularizers.DiagonalGaussianRegularizer 99 | 100 | encoder_config: 101 | target: vae_modules.cp_enc_dec.ContextParallelEncoder3D 102 | params: 103 | double_z: true 104 | z_channels: 16 105 | resolution: 256 106 | in_channels: 3 107 | out_ch: 3 108 | ch: 128 109 | ch_mult: [ 1, 2, 2, 4 ] 110 | attn_resolutions: [ ] 111 | num_res_blocks: 3 112 | dropout: 0.0 113 | gather_norm: True 114 | 115 | decoder_config: 116 | target: vae_modules.cp_enc_dec.ContextParallelDecoder3D 117 | params: 118 | double_z: True 119 | z_channels: 16 120 | resolution: 256 121 | in_channels: 3 122 | out_ch: 3 123 | ch: 128 124 | ch_mult: [ 1, 2, 2, 4 ] 125 | attn_resolutions: [ ] 126 | num_res_blocks: 3 127 | dropout: 0.0 128 | gather_norm: False 129 | 130 | loss_fn_config: 131 | target: sgm.modules.diffusionmodules.loss.VideoDiffusionLoss 132 | params: 133 | offset_noise_level: 0 134 | sigma_sampler_config: 135 | target: sgm.modules.diffusionmodules.sigma_sampling.DiscreteSampling 136 | params: 137 | uniform_sampling: True 138 | num_idx: 1000 139 | discretization_config: 140 | target: sgm.modules.diffusionmodules.discretizer.ZeroSNRDDPMDiscretization 141 | params: 142 | shift_scale: 3.0 143 | 144 | sampler_config: 145 | target: sgm.modules.diffusionmodules.sampling.VPSDEDPMPP2MSampler 146 | params: 147 | num_steps: 50 148 | verbose: True 149 | 150 | discretization_config: 151 | target: sgm.modules.diffusionmodules.discretizer.ZeroSNRDDPMDiscretization 152 | params: 153 | shift_scale: 3.0 154 | 155 | guider_config: 156 | target: sgm.modules.diffusionmodules.guiders.DynamicCFG 157 | params: 158 | scale: 6 159 | exp: 5 160 | num_steps: 50 -------------------------------------------------------------------------------- /configs/cogvideox_5b.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | scale_factor: 0.7 3 | disable_first_stage_autocast: true 4 | ref_image_dropout: 0.1 # the probability of dropping reference images 5 | log_keys: 6 | - txt 7 | 8 | denoiser_config: 9 | target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser 10 | params: 11 | num_idx: 1000 12 | quantize_c_noise: False 13 | 14 | weighting_config: 15 | target: sgm.modules.diffusionmodules.denoiser_weighting.EpsWeighting 16 | 
scaling_config: 17 | target: sgm.modules.diffusionmodules.denoiser_scaling.VideoScaling 18 | discretization_config: 19 | target: sgm.modules.diffusionmodules.discretizer.ZeroSNRDDPMDiscretization 20 | params: 21 | shift_scale: 1.0 # different from cogvideox_2b_infer.yaml 22 | 23 | network_config: 24 | target: dit_video_concat.DiffusionTransformer 25 | params: 26 | time_embed_dim: 512 27 | elementwise_affine: True 28 | num_frames: 49 29 | time_compressed_rate: 4 30 | latent_width: 90 31 | latent_height: 60 32 | num_layers: 42 # different from cogvideox_2b_infer.yaml 33 | patch_size: 2 34 | in_channels: 16 35 | out_channels: 16 36 | hidden_size: 3072 # different from cogvideox_2b_infer.yaml 37 | adm_in_channels: 256 38 | num_attention_heads: 48 # different from cogvideox_2b_infer.yaml 39 | 40 | transformer_args: 41 | checkpoint_activations: True 42 | vocab_size: 1 43 | max_sequence_length: 64 44 | layernorm_order: pre 45 | skip_init: false 46 | model_parallel_size: 1 47 | is_decoder: false 48 | 49 | modules: 50 | pos_embed_config: 51 | target: dit_video_concat.Rotary3DPositionEmbeddingMixin # different from cogvideox_2b_infer.yaml 52 | params: 53 | hidden_size_head: 64 54 | text_length: 226 55 | 56 | patch_embed_config: 57 | target: dit_video_concat.ImagePatchEmbeddingMixin 58 | params: 59 | text_hidden_size: 4096 60 | 61 | adaln_layer_config: 62 | target: dit_video_concat.AdaLNMixin 63 | params: 64 | qk_ln: True 65 | 66 | final_layer_config: 67 | target: dit_video_concat.FinalLayerMixin 68 | 69 | conditioner_config: 70 | target: sgm.modules.GeneralConditioner 71 | params: 72 | emb_models: 73 | - is_trainable: false 74 | input_key: txt 75 | ucg_rate: 0.1 76 | target: sgm.modules.encoders.modules.FrozenT5Embedder 77 | params: 78 | model_dir: "./models/t5-v1_1-xxl" 79 | max_length: 226 80 | 81 | first_stage_config: 82 | target: vae_modules.autoencoder.VideoAutoencoderInferenceWrapper 83 | params: 84 | cp_size: 1 85 | ckpt_path: "./models/vae/3d-vae.pt" 86 | ignore_keys: [ 'loss' ] 87 | 88 | loss_config: 89 | target: torch.nn.Identity 90 | 91 | regularizer_config: 92 | target: vae_modules.regularizers.DiagonalGaussianRegularizer 93 | 94 | encoder_config: 95 | target: vae_modules.cp_enc_dec.ContextParallelEncoder3D 96 | params: 97 | double_z: true 98 | z_channels: 16 99 | resolution: 256 100 | in_channels: 3 101 | out_ch: 3 102 | ch: 128 103 | ch_mult: [ 1, 2, 2, 4 ] 104 | attn_resolutions: [ ] 105 | num_res_blocks: 3 106 | dropout: 0.0 107 | gather_norm: True 108 | 109 | decoder_config: 110 | target: vae_modules.cp_enc_dec.ContextParallelDecoder3D 111 | params: 112 | double_z: True 113 | z_channels: 16 114 | resolution: 256 115 | in_channels: 3 116 | out_ch: 3 117 | ch: 128 118 | ch_mult: [ 1, 2, 2, 4 ] 119 | attn_resolutions: [ ] 120 | num_res_blocks: 3 121 | dropout: 0.0 122 | gather_norm: False 123 | 124 | loss_fn_config: 125 | target: sgm.modules.diffusionmodules.loss.VideoDiffusionLoss 126 | params: 127 | offset_noise_level: 0 128 | sigma_sampler_config: 129 | target: sgm.modules.diffusionmodules.sigma_sampling.DiscreteSampling 130 | params: 131 | uniform_sampling: True 132 | num_idx: 1000 133 | discretization_config: 134 | target: sgm.modules.diffusionmodules.discretizer.ZeroSNRDDPMDiscretization 135 | params: 136 | shift_scale: 1.0 # different from cogvideox_2b_infer.yaml 137 | 138 | # sampler_config: 139 | # target: sgm.modules.diffusionmodules.sampling.VPSDEDPMPP2MSampler 140 | # params: 141 | # num_steps: 50 142 | # verbose: True 143 | 144 | # discretization_config: 145 | # target: 
sgm.modules.diffusionmodules.discretizer.ZeroSNRDDPMDiscretization 146 | # params: 147 | # shift_scale: 1.0 # different from cogvideox_2b_infer.yaml 148 | 149 | # guider_config: 150 | # target: sgm.modules.diffusionmodules.guiders.DynamicCFG 151 | # params: 152 | # scale: 6 153 | # exp: 5 154 | # num_steps: 50 155 | 156 | sampler_config: 157 | target: sgm.modules.diffusionmodules.sampling.VPSDEDPMPP2MSampler_cfg 158 | params: 159 | num_steps: 50 160 | verbose: True 161 | 162 | discretization_config: 163 | target: sgm.modules.diffusionmodules.discretizer.ZeroSNRDDPMDiscretization 164 | params: 165 | shift_scale: 1.0 # different from cogvideox_2b_infer.yaml 166 | 167 | guider_config: 168 | target: sgm.modules.diffusionmodules.guiders.VanillaCFG 169 | params: 170 | scale: 6 171 | # exp: 5 172 | # num_steps: 50 -------------------------------------------------------------------------------- /configs/cogvideox_5b_lora.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | scale_factor: 0.7 # different from cogvideox_2b_infer.yaml 3 | disable_first_stage_autocast: true 4 | not_trainable_prefixes: ['all'] # Using Lora 5 | log_keys: 6 | - txt 7 | 8 | denoiser_config: 9 | target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser 10 | params: 11 | num_idx: 1000 12 | quantize_c_noise: False 13 | 14 | weighting_config: 15 | target: sgm.modules.diffusionmodules.denoiser_weighting.EpsWeighting 16 | scaling_config: 17 | target: sgm.modules.diffusionmodules.denoiser_scaling.VideoScaling 18 | discretization_config: 19 | target: sgm.modules.diffusionmodules.discretizer.ZeroSNRDDPMDiscretization 20 | params: 21 | shift_scale: 1.0 # different from cogvideox_2b_infer.yaml 22 | 23 | network_config: 24 | target: dit_video_concat.DiffusionTransformer 25 | params: 26 | time_embed_dim: 512 27 | elementwise_affine: True 28 | num_frames: 49 29 | time_compressed_rate: 4 30 | latent_width: 90 31 | latent_height: 60 32 | num_layers: 42 # different from cogvideox_2b_infer.yaml 33 | patch_size: 2 34 | in_channels: 16 35 | out_channels: 16 36 | hidden_size: 3072 # different from cogvideox_2b_infer.yaml 37 | adm_in_channels: 256 38 | num_attention_heads: 48 # different from cogvideox_2b_infer.yaml 39 | 40 | transformer_args: 41 | checkpoint_activations: True 42 | vocab_size: 1 43 | max_sequence_length: 64 44 | layernorm_order: pre 45 | skip_init: false 46 | model_parallel_size: 1 47 | is_decoder: false 48 | 49 | modules: 50 | pos_embed_config: 51 | target: dit_video_concat.Rotary3DPositionEmbeddingMixin # different from cogvideox_2b_infer.yaml 52 | params: 53 | hidden_size_head: 64 54 | text_length: 226 55 | 56 | lora_config: # Using Lora 57 | target: sat.model.finetune.lora2.LoraMixin 58 | params: 59 | r: 128 60 | 61 | patch_embed_config: 62 | target: dit_video_concat.ImagePatchEmbeddingMixin 63 | params: 64 | text_hidden_size: 4096 65 | 66 | adaln_layer_config: 67 | target: dit_video_concat.AdaLNMixin 68 | params: 69 | qk_ln: True 70 | 71 | final_layer_config: 72 | target: dit_video_concat.FinalLayerMixin 73 | 74 | conditioner_config: 75 | target: sgm.modules.GeneralConditioner 76 | params: 77 | emb_models: 78 | - is_trainable: false 79 | input_key: txt 80 | ucg_rate: 0.1 81 | target: sgm.modules.encoders.modules.FrozenT5Embedder 82 | params: 83 | model_dir: "t5-v1_1-xxl" 84 | max_length: 226 85 | 86 | first_stage_config: 87 | target: vae_modules.autoencoder.VideoAutoencoderInferenceWrapper 88 | params: 89 | cp_size: 1 90 | ckpt_path: "cogvideox-5b-sat/vae/3d-vae.pt" 91 
| ignore_keys: [ 'loss' ] 92 | 93 | loss_config: 94 | target: torch.nn.Identity 95 | 96 | regularizer_config: 97 | target: vae_modules.regularizers.DiagonalGaussianRegularizer 98 | 99 | encoder_config: 100 | target: vae_modules.cp_enc_dec.ContextParallelEncoder3D 101 | params: 102 | double_z: true 103 | z_channels: 16 104 | resolution: 256 105 | in_channels: 3 106 | out_ch: 3 107 | ch: 128 108 | ch_mult: [ 1, 2, 2, 4 ] 109 | attn_resolutions: [ ] 110 | num_res_blocks: 3 111 | dropout: 0.0 112 | gather_norm: True 113 | 114 | decoder_config: 115 | target: vae_modules.cp_enc_dec.ContextParallelDecoder3D 116 | params: 117 | double_z: True 118 | z_channels: 16 119 | resolution: 256 120 | in_channels: 3 121 | out_ch: 3 122 | ch: 128 123 | ch_mult: [ 1, 2, 2, 4 ] 124 | attn_resolutions: [ ] 125 | num_res_blocks: 3 126 | dropout: 0.0 127 | gather_norm: False 128 | 129 | loss_fn_config: 130 | target: sgm.modules.diffusionmodules.loss.VideoDiffusionLoss 131 | params: 132 | offset_noise_level: 0 133 | sigma_sampler_config: 134 | target: sgm.modules.diffusionmodules.sigma_sampling.DiscreteSampling 135 | params: 136 | uniform_sampling: True 137 | num_idx: 1000 138 | discretization_config: 139 | target: sgm.modules.diffusionmodules.discretizer.ZeroSNRDDPMDiscretization 140 | params: 141 | shift_scale: 1.0 # different from cogvideox_2b_infer.yaml 142 | 143 | sampler_config: 144 | target: sgm.modules.diffusionmodules.sampling.VPSDEDPMPP2MSampler 145 | params: 146 | num_steps: 50 147 | verbose: True 148 | 149 | discretization_config: 150 | target: sgm.modules.diffusionmodules.discretizer.ZeroSNRDDPMDiscretization 151 | params: 152 | shift_scale: 1.0 153 | 154 | guider_config: 155 | target: sgm.modules.diffusionmodules.guiders.DynamicCFG 156 | params: 157 | scale: 6 158 | exp: 5 159 | num_steps: 50 -------------------------------------------------------------------------------- /configs/inference/inference_single_identity.yaml: -------------------------------------------------------------------------------- 1 | args: 2 | image2video: False # True for image2video, False for text2video 3 | contextimage2video: True 4 | latent_channels: 16 5 | mode: inference 6 | load: "./models/single-identity/" 7 | batch_size: 1 8 | input_type: txt 9 | input_file: ./examples/single_identity.txt 10 | sampling_image_size: [480, 720] 11 | sampling_num_frames: 13 # Must be 13, 11 or 9 12 | sampling_fps: 8 13 | # fp16: True # For CogVideoX-2B 14 | bf16: True # For CogVideoX-5B and CoGVideoX-5B-I2V 15 | output_dir: outputs/single-identity 16 | force_inference: True 17 | -------------------------------------------------------------------------------- /configs/inference/inference_two_identities.yaml: -------------------------------------------------------------------------------- 1 | args: 2 | image2video: False # True for image2video, False for text2video 3 | contextimage2video: True 4 | latent_channels: 16 5 | mode: inference 6 | load: "./models/two-identities/" 7 | batch_size: 1 8 | input_type: txt 9 | input_file: ./examples/two_identities.txt 10 | sampling_image_size: [480, 720] 11 | sampling_num_frames: 13 # Must be 13, 11 or 9 12 | sampling_fps: 8 13 | # fp16: True # For CogVideoX-2B 14 | bf16: True # For CogVideoX-5B and CoGVideoX-5B-I2V 15 | output_dir: outputs/two-identities 16 | force_inference: True 17 | -------------------------------------------------------------------------------- /configs/training/sft_single_identity.yaml: -------------------------------------------------------------------------------- 1 | args: 2 | 
checkpoint_activations: True # using gradient checkpointing
 3 |   model_parallel_size: 1
 4 |   wandb: False
 5 |   wandb_project_name: "single-identity"
 6 |   experiment_name: ID-Preserving-Generation
 7 |   mode: finetune
 8 |   load: "models/single-identity"
 9 |   no_load_rng: True
10 |   # lr_decay_style: cosine # use the cosine lr
11 |   train_iters: 1000 # Suggested: more than 1000 for LoRA; for SFT, 500 is enough
12 |   eval_iters: 1
13 |   eval_interval: 100
14 |   eval_batch_size: 1
15 |   save: exps
16 |   save_interval: 100
17 |   log_interval: 20
18 |   train_data: [ "{path}/training-data.json" ] # Train data path
19 |   valid_data: [ "{path}/validation-data.json" ] # Validation data path; can be the same as train_data (not recommended)
20 |   split: 1,0,0
21 |   num_workers: 8
22 |   force_train: True
23 |   only_log_video_latents: False
24 | 
25 | data:
26 |   target: data_video.FaceJsonMultiPerSFTDataset
27 |   params:
28 |     video_size: [ 480, 720 ]
29 |     fps: 8
30 |     max_num_frames: 49
31 |     skip_frms_num: 3.
32 | 
33 | 
34 | deepspeed:
35 |   # Minimum of 16 videos per batch across ALL GPUs; this setting is for 8 x A100 GPUs
36 |   train_micro_batch_size_per_gpu: 1
37 |   gradient_accumulation_steps: 1
38 |   steps_per_print: 50
39 |   gradient_clipping: 0.1
40 |   zero_optimization:
41 |     stage: 2
42 |     cpu_offload: false
43 |     contiguous_gradients: false
44 |     overlap_comm: true
45 |     reduce_scatter: true
46 |     reduce_bucket_size: 1000000000
47 |     allgather_bucket_size: 1000000000
48 |     load_from_fp32_weights: false
49 |   zero_allow_untested_optimizer: true
50 |   bf16:
51 |     enabled: True # Set to False for CogVideoX-2B and True for CogVideoX-5B
52 |   fp16:
53 |     enabled: False # Set to True for CogVideoX-2B and False for CogVideoX-5B
54 |     loss_scale: 0
55 |     loss_scale_window: 400
56 |     hysteresis: 2
57 |     min_loss_scale: 1
58 | 
59 |   optimizer:
60 |     type: sat.ops.FusedEmaAdam
61 |     params:
62 |       lr: 0.00001 # Between 1E-3 and 5E-4 for LoRA; 1E-5 for SFT
63 |       betas: [ 0.9, 0.95 ]
64 |       eps: 1e-8
65 |       weight_decay: 1e-4
66 |   activation_checkpointing:
67 |     partition_activations: false
68 |     contiguous_memory_optimization: false
69 |   wall_clock_breakdown: false
--------------------------------------------------------------------------------
/configs/training/sft_two_identities.yaml:
--------------------------------------------------------------------------------
 1 | args:
 2 |   checkpoint_activations: True # using gradient checkpointing
 3 |   model_parallel_size: 1
 4 |   wandb: False
 5 |   wandb_project_name: "two-identities"
 6 |   experiment_name: ID-Preserving-Generation-2identities
 7 |   mode: finetune
 8 |   load: "models/two-identities"
 9 |   no_load_rng: True
10 |   # lr_decay_style: cosine # use the cosine lr
11 |   train_iters: 1000 # Suggested: more than 1000 for LoRA; for SFT, 500 is enough
12 |   eval_iters: 1
13 |   eval_interval: 100
14 |   eval_batch_size: 1
15 |   save: exps-2identities
16 |   save_interval: 100
17 |   log_interval: 20
18 |   train_data: [ "{path}/training-data.json" ] # Train data path
19 |   valid_data: [ "{path}/validation-data.json" ] # Validation data path; can be the same as train_data (not recommended)
20 |   split: 1,0,0
21 |   num_workers: 8
22 |   force_train: True
23 |   only_log_video_latents: False
24 | 
25 | data:
26 |   target: data_video.FaceJsonMultiPerSFTDataset
27 |   params:
28 |     video_size: [ 480, 720 ]
29 |     fps: 8
30 |     max_num_frames: 49
31 |     skip_frms_num: 3.
32 | 33 | 34 | deepspeed: 35 | # Minimum for 16 videos per batch for ALL GPUs, This setting is for 8 x A100 GPUs 36 | train_micro_batch_size_per_gpu: 1 37 | gradient_accumulation_steps: 1 38 | steps_per_print: 50 39 | gradient_clipping: 0.1 40 | zero_optimization: 41 | stage: 2 42 | cpu_offload: false 43 | contiguous_gradients: false 44 | overlap_comm: true 45 | reduce_scatter: true 46 | reduce_bucket_size: 1000000000 47 | allgather_bucket_size: 1000000000 48 | load_from_fp32_weights: false 49 | zero_allow_untested_optimizer: true 50 | bf16: 51 | enabled: True # For CogVideoX-2B Turn to False and For CogVideoX-5B Turn to True 52 | fp16: 53 | enabled: False # For CogVideoX-2B Turn to True and For CogVideoX-5B Turn to False 54 | loss_scale: 0 55 | loss_scale_window: 400 56 | hysteresis: 2 57 | min_loss_scale: 1 58 | 59 | optimizer: 60 | type: sat.ops.FusedEmaAdam 61 | params: 62 | lr: 0.00001 # Between 1E-3 and 5E-4 For Lora and 1E-5 For SFT 63 | betas: [ 0.9, 0.95 ] 64 | eps: 1e-8 65 | weight_decay: 1e-4 66 | activation_checkpointing: 67 | partition_activations: false 68 | contiguous_memory_optimization: false 69 | wall_clock_breakdown: false -------------------------------------------------------------------------------- /examples/cropped_images/1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ML-GSAI/Concat-ID/b15cb95af7efd1c557ee850f5adb116d8feae8b5/examples/cropped_images/1.png -------------------------------------------------------------------------------- /examples/cropped_images/2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ML-GSAI/Concat-ID/b15cb95af7efd1c557ee850f5adb116d8feae8b5/examples/cropped_images/2.png -------------------------------------------------------------------------------- /examples/cropped_images/3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ML-GSAI/Concat-ID/b15cb95af7efd1c557ee850f5adb116d8feae8b5/examples/cropped_images/3.png -------------------------------------------------------------------------------- /examples/cropped_images/4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ML-GSAI/Concat-ID/b15cb95af7efd1c557ee850f5adb116d8feae8b5/examples/cropped_images/4.png -------------------------------------------------------------------------------- /examples/cropped_images/5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ML-GSAI/Concat-ID/b15cb95af7efd1c557ee850f5adb116d8feae8b5/examples/cropped_images/5.png -------------------------------------------------------------------------------- /examples/cropped_images/6.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ML-GSAI/Concat-ID/b15cb95af7efd1c557ee850f5adb116d8feae8b5/examples/cropped_images/6.png -------------------------------------------------------------------------------- /examples/images/3_stars_woman_Taylor_Swift_3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ML-GSAI/Concat-ID/b15cb95af7efd1c557ee850f5adb116d8feae8b5/examples/images/3_stars_woman_Taylor_Swift_3.png -------------------------------------------------------------------------------- /examples/images/43_stars_man_Leonardo_DiCaprio_3.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/ML-GSAI/Concat-ID/b15cb95af7efd1c557ee850f5adb116d8feae8b5/examples/images/43_stars_man_Leonardo_DiCaprio_3.png -------------------------------------------------------------------------------- /examples/images/69_politicians_woman_Tulsi_Gabbard_4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ML-GSAI/Concat-ID/b15cb95af7efd1c557ee850f5adb116d8feae8b5/examples/images/69_politicians_woman_Tulsi_Gabbard_4.png -------------------------------------------------------------------------------- /examples/images/72_politicians_woman_Tulsi_Gabbard_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ML-GSAI/Concat-ID/b15cb95af7efd1c557ee850f5adb116d8feae8b5/examples/images/72_politicians_woman_Tulsi_Gabbard_2.png -------------------------------------------------------------------------------- /examples/images/73_politicians_woman_Harris_3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ML-GSAI/Concat-ID/b15cb95af7efd1c557ee850f5adb116d8feae8b5/examples/images/73_politicians_woman_Harris_3.png -------------------------------------------------------------------------------- /examples/images/80_normal_man_5.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ML-GSAI/Concat-ID/b15cb95af7efd1c557ee850f5adb116d8feae8b5/examples/images/80_normal_man_5.jpg -------------------------------------------------------------------------------- /examples/images/93_normal_woman_3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ML-GSAI/Concat-ID/b15cb95af7efd1c557ee850f5adb116d8feae8b5/examples/images/93_normal_woman_3.jpg -------------------------------------------------------------------------------- /examples/results/1.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ML-GSAI/Concat-ID/b15cb95af7efd1c557ee850f5adb116d8feae8b5/examples/results/1.gif -------------------------------------------------------------------------------- /examples/results/2.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ML-GSAI/Concat-ID/b15cb95af7efd1c557ee850f5adb116d8feae8b5/examples/results/2.gif -------------------------------------------------------------------------------- /examples/results/3.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ML-GSAI/Concat-ID/b15cb95af7efd1c557ee850f5adb116d8feae8b5/examples/results/3.gif -------------------------------------------------------------------------------- /examples/results/4.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ML-GSAI/Concat-ID/b15cb95af7efd1c557ee850f5adb116d8feae8b5/examples/results/4.gif -------------------------------------------------------------------------------- /examples/results/5.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ML-GSAI/Concat-ID/b15cb95af7efd1c557ee850f5adb116d8feae8b5/examples/results/5.gif 
--------------------------------------------------------------------------------
/examples/results/6.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ML-GSAI/Concat-ID/b15cb95af7efd1c557ee850f5adb116d8feae8b5/examples/results/6.gif
--------------------------------------------------------------------------------
/examples/single_identity.txt:
--------------------------------------------------------------------------------
 1 | A man standing next to an airplane, engaged in a conversation on his cell phone. He is wearing sunglasses and a black top, and he appears to be talking seriously. The airplane has a green stripe running along its side, and there is a large engine visible behind him. The man seems to be standing near the entrance of the airplane, possibly preparing to board or just having disembarked. The setting suggests that he might be at an airport or a private airfield. The overall atmosphere of the video is professional and focused, with the man's attire and the presence of the airplane indicating a business or travel context.@@examples/images/80_normal_man_5.jpg
 2 | A woman adorned with a delicate flower crown is standing amidst a field of gently swaying wildflowers. Her eyes sparkle with a serene gaze, and a faint smile graces her lips, suggesting a moment of peaceful contentment. The shot is framed from the waist up, highlighting the gentle breeze lightly tousling her hair. The background reveals an expansive meadow under a bright blue sky, capturing the tranquility of a sunny afternoon.@@examples/images/3_stars_woman_Taylor_Swift_3.png
--------------------------------------------------------------------------------
/examples/two_identities.txt:
--------------------------------------------------------------------------------
 1 | Two individuals are studying together for an upcoming exam. They sit in a quiet library, discussing the material, explaining difficult concepts to each other, and quizzing one another on key points. Highlight how collaboration can enhance learning and deepen understanding.@@examples/images/43_stars_man_Leonardo_DiCaprio_3.png@@examples/images/3_stars_woman_Taylor_Swift_3.png
 2 | 
--------------------------------------------------------------------------------
/finetune_single_identity.sh:
--------------------------------------------------------------------------------
 1 | #! /bin/bash
 2 | 
 3 | echo "RUN on $(hostname), CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
 4 | 
 5 | run_cmd="PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True torchrun --standalone --nproc_per_node=8 train_video.py --base configs/cogvideox_5b.yaml configs/training/sft_single_identity.yaml --seed 42"
 6 | 
 7 | echo ${run_cmd}
 8 | eval ${run_cmd}
 9 | 
10 | echo "DONE on `hostname`"
--------------------------------------------------------------------------------
/finetune_two_identities.sh:
--------------------------------------------------------------------------------
 1 | #! 
/bin/bash 2 | 3 | echo "RUN on $(hostname), CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES" 4 | 5 | run_cmd="PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True torchrun --standalone --nproc_per_node=8 train_video.py --base configs/cogvideox_5b.yaml configs/training/sft_two_identities.yaml --seed 42" 6 | 7 | echo ${run_cmd} 8 | eval ${run_cmd} 9 | 10 | echo "DONE on `hostname`" -------------------------------------------------------------------------------- /inference_single_identity.sh: -------------------------------------------------------------------------------- 1 | python sample_video.py --base configs/cogvideox_5b.yaml configs/inference/inference_single_identity.yaml --seed 42 -------------------------------------------------------------------------------- /inference_two_identities.sh: -------------------------------------------------------------------------------- 1 | python sample_video.py --base configs/cogvideox_5b.yaml configs/inference/inference_two_identities.yaml --seed 42 -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==2.1.0 2 | albucore==0.0.20 3 | albumentations==1.4.21 4 | annotated-types==0.7.0 5 | antlr4-python3-runtime==4.9.3 6 | anyio==4.4.0 7 | argon2-cffi==23.1.0 8 | argon2-cffi-bindings==21.2.0 9 | arrow==1.3.0 10 | asttokens==2.4.1 11 | astunparse==1.6.3 12 | async-lru==2.0.4 13 | attrs==24.2.0 14 | audioread==3.0.1 15 | babel==2.16.0 16 | beartype==0.19.0 17 | beautifulsoup4==4.12.3 18 | black==24.8.0 19 | bleach==6.1.0 20 | blis==0.7.11 21 | boto3==1.35.64 22 | botocore==1.35.64 23 | braceexpand==0.1.7 24 | catalogue==2.0.10 25 | certifi==2024.7.4 26 | cffi==1.17.0 27 | charset-normalizer==3.3.2 28 | click==8.1.7 29 | cloudpathlib==0.18.1 30 | cmake==3.30.2 31 | coloredlogs==15.0.1 32 | comm==0.2.2 33 | confection==0.1.5 34 | contourpy==1.2.1 35 | cpm-kernels==1.0.11 36 | cycler==0.12.1 37 | cymem==2.0.8 38 | Cython==3.0.11 39 | datasets==3.1.0 40 | debugpy==1.8.5 41 | decorator==5.1.1 42 | decord==0.6.0 43 | deepspeed==0.15.4 44 | defusedxml==0.7.1 45 | dill==0.3.8 46 | dm-tree==0.1.8 47 | docker-pycreds==0.4.0 48 | easydict==1.13 49 | einops==0.8.0 50 | eval_type_backport==0.2.0 51 | exceptiongroup==1.2.2 52 | execnet==2.1.1 53 | executing==2.0.1 54 | expecttest==0.1.3 55 | fastjsonschema==2.20.0 56 | ffmpeg==1.4 57 | filelock==3.15.4 58 | flatbuffers==24.3.25 59 | fonttools==4.53.1 60 | fqdn==1.5.1 61 | fsspec==2024.6.1 62 | gast==0.6.0 63 | gdown==5.2.0 64 | gitdb==4.0.11 65 | GitPython==3.1.43 66 | h11==0.14.0 67 | hjson==3.1.0 68 | httpcore==1.0.5 69 | httpx==0.27.0 70 | huggingface-hub==0.26.2 71 | humanfriendly==10.0 72 | hypothesis==5.35.1 73 | idna==3.7 74 | igraph==0.11.6 75 | imageio==2.36.0 76 | imageio-ffmpeg==0.5.1 77 | iniconfig==2.0.0 78 | insightface==0.7.3 79 | intel-openmp==2021.4.0 80 | ipykernel==6.29.5 81 | ipython==8.26.0 82 | isoduration==20.11.0 83 | isort==5.13.2 84 | jedi==0.19.1 85 | Jinja2==3.1.4 86 | jmespath==1.0.1 87 | joblib==1.4.2 88 | json5==0.9.25 89 | jsonpointer==3.0.0 90 | jsonschema==4.23.0 91 | jsonschema-specifications==2023.12.1 92 | jupyter-events==0.10.0 93 | jupyter-lsp==2.2.5 94 | jupyter_client==8.6.2 95 | jupyter_core==5.7.2 96 | jupyter_server==2.14.2 97 | jupyter_server_terminals==0.5.3 98 | jupyterlab==4.2.4 99 | jupyterlab-tensorboard-pro==4.0.0 100 | jupyterlab_code_formatter==3.0.2 101 | jupyterlab_pygments==0.3.0 102 | jupyterlab_server==2.27.3 103 | jupytext==1.16.4 
104 | kiwisolver==1.4.5 105 | kornia==0.7.4 106 | kornia_rs==0.1.7 107 | langcodes==3.4.0 108 | language_data==1.2.0 109 | lazy_loader==0.4 110 | librosa==0.10.1 111 | lintrunner==0.12.5 112 | looseversion==1.3.0 113 | marisa-trie==1.2.0 114 | matplotlib==3.9.2 115 | matplotlib-inline==0.1.7 116 | mdit-py-plugins==0.4.1 117 | mdurl==0.1.2 118 | mistune==3.0.2 119 | mkl==2021.1.1 120 | mkl-devel==2021.1.1 121 | mkl-include==2021.1.1 122 | mock==5.1.0 123 | mpmath==1.3.0 124 | msgpack==1.0.8 125 | multiprocess==0.70.16 126 | murmurhash==1.0.10 127 | mypy-extensions==1.0.0 128 | nbclient==0.10.0 129 | nbconvert==7.16.4 130 | nbformat==5.10.4 131 | nest-asyncio==1.6.0 132 | networkx==3.3 133 | ninja==1.11.1.1 134 | notebook==7.2.1 135 | notebook_shim==0.2.4 136 | numpy==1.24.4 137 | nvidia-cublas-cu12==12.4.5.8 138 | nvidia-cuda-cupti-cu12==12.4.127 139 | nvidia-cuda-nvrtc-cu12==12.4.127 140 | nvidia-cuda-runtime-cu12==12.4.127 141 | nvidia-cudnn-cu12==9.1.0.70 142 | nvidia-cufft-cu12==11.2.1.3 143 | nvidia-curand-cu12==10.3.5.147 144 | nvidia-cusolver-cu12==11.6.1.9 145 | nvidia-cusparse-cu12==12.3.1.170 146 | nvidia-dali-cuda120==1.40.0 147 | nvidia-ml-py==12.560.30 148 | nvidia-modelopt==0.15.0 149 | nvidia-nccl-cu12==2.21.5 150 | nvidia-nvimgcodec-cu12==0.3.0.5 151 | nvidia-nvjitlink-cu12==12.4.127 152 | nvidia-nvtx-cu12==12.4.127 153 | nvidia-pyindex==1.0.9 154 | omegaconf==2.3.0 155 | onnxruntime==1.20.0 156 | opencv-fixer==0.2.5 157 | opencv-python-headless==4.10.0.84 158 | opt-einsum==3.3.0 159 | optree==0.12.1 160 | overrides==7.7.0 161 | pandocfilters==1.5.1 162 | parso==0.8.4 163 | pexpect==4.9.0 164 | pillow==10.4.0 165 | platformdirs==4.2.2 166 | pluggy==1.5.0 167 | pooch==1.8.2 168 | preshed==3.0.9 169 | prettytable==3.12.0 170 | prometheus_client==0.20.0 171 | prompt_toolkit==3.0.47 172 | protobuf==4.24.4 173 | psutil==6.0.0 174 | ptyprocess==0.7.0 175 | PuLP==2.9.0 176 | pure_eval==0.2.3 177 | py-cpuinfo==9.0.0 178 | pybind11==2.13.4 179 | pybind11_global==2.13.4 180 | pycparser==2.22 181 | pydantic==2.8.2 182 | pydantic_core==2.20.1 183 | Pygments==2.18.0 184 | pyparsing==3.1.2 185 | PySocks==1.7.1 186 | pytest==8.1.1 187 | pytest-flakefinder==1.1.0 188 | pytest-rerunfailures==14.0 189 | pytest-shard==0.1.2 190 | pytest-xdist==3.6.1 191 | python-dateutil==2.9.0.post0 192 | python-hostlist==1.23.0 193 | python-json-logger==2.0.7 194 | pytorch-lightning==2.4.0 195 | PyYAML==6.0.2 196 | pyzmq==26.1.0 197 | referencing==0.35.1 198 | regex==2024.7.24 199 | requests==2.32.3 200 | rfc3339-validator==0.1.4 201 | rfc3986-validator==0.1.1 202 | rich==13.7.1 203 | rpds-py==0.20.0 204 | s3transfer==0.10.3 205 | safetensors==0.4.5 206 | scikit-image==0.24.0 207 | scikit-learn==1.5.1 208 | scipy==1.14.1 209 | Send2Trash==1.8.3 210 | sentencepiece==0.2.0 211 | sentry-sdk==2.18.0 212 | setproctitle==1.3.4 213 | shellingham==1.5.4 214 | simsimd==6.0.7 215 | six==1.16.0 216 | smart-open==7.0.4 217 | smmap==5.0.1 218 | sniffio==1.3.1 219 | sortedcontainers==2.4.0 220 | soundfile==0.12.1 221 | soupsieve==2.6 222 | soxr==0.4.0 223 | spacy==3.7.5 224 | spacy-legacy==3.0.12 225 | spacy-loggers==1.0.5 226 | srsly==2.4.8 227 | stack-data==0.6.3 228 | stringzilla==3.10.10 229 | SwissArmyTransformer==0.4.12 230 | sympy==1.13.1 231 | tabulate==0.9.0 232 | tbb==2021.13.1 233 | tensorboard==2.16.2 234 | tensorboard-data-server==0.7.2 235 | tensorboardX==2.6.2.2 236 | terminado==0.18.1 237 | texttable==1.7.0 238 | thinc==8.2.5 239 | threadpoolctl==3.5.0 240 | tifffile==2024.9.20 241 | tinycss2==1.3.0 242 | 
tokenizers==0.20.3 243 | tomli==2.0.1 244 | torch==2.5.1 245 | torchmetrics==1.6.0 246 | torchvision==0.20.1 247 | tornado==6.2 248 | tqdm==4.66.5 249 | traitlets==5.14.3 250 | transformers==4.46.3 251 | triton==3.1.0 252 | typer==0.12.4 253 | types-dataclasses==0.6.6 254 | types-python-dateutil==2.9.0.20240316 255 | typing_extensions==4.12.2 256 | uri-template==1.3.0 257 | wandb==0.18.7 258 | wasabi==1.1.3 259 | wcwidth==0.2.13 260 | weasel==0.4.1 261 | webencodings==0.5.1 262 | websocket-client==1.8.0 263 | Werkzeug==3.0.3 264 | wrapt==1.16.0 265 | xdoctest==1.0.2 266 | xformer==1.0.1 267 | xxhash==3.5.0 268 | -------------------------------------------------------------------------------- /sgm/__init__.py: -------------------------------------------------------------------------------- 1 | from .models import AutoencodingEngine 2 | from .util import get_configs_path, instantiate_from_config 3 | 4 | __version__ = "0.1.0" 5 | -------------------------------------------------------------------------------- /sgm/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ML-GSAI/Concat-ID/b15cb95af7efd1c557ee850f5adb116d8feae8b5/sgm/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /sgm/__pycache__/util.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ML-GSAI/Concat-ID/b15cb95af7efd1c557ee850f5adb116d8feae8b5/sgm/__pycache__/util.cpython-310.pyc -------------------------------------------------------------------------------- /sgm/__pycache__/webds.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ML-GSAI/Concat-ID/b15cb95af7efd1c557ee850f5adb116d8feae8b5/sgm/__pycache__/webds.cpython-310.pyc -------------------------------------------------------------------------------- /sgm/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class LambdaWarmUpCosineScheduler: 5 | """ 6 | note: use with a base_lr of 1.0 7 | """ 8 | 9 | def __init__( 10 | self, 11 | warm_up_steps, 12 | lr_min, 13 | lr_max, 14 | lr_start, 15 | max_decay_steps, 16 | verbosity_interval=0, 17 | ): 18 | self.lr_warm_up_steps = warm_up_steps 19 | self.lr_start = lr_start 20 | self.lr_min = lr_min 21 | self.lr_max = lr_max 22 | self.lr_max_decay_steps = max_decay_steps 23 | self.last_lr = 0.0 24 | self.verbosity_interval = verbosity_interval 25 | 26 | def schedule(self, n, **kwargs): 27 | if self.verbosity_interval > 0: 28 | if n % self.verbosity_interval == 0: 29 | print(f"current step: {n}, recent lr-multiplier: {self.last_lr}") 30 | if n < self.lr_warm_up_steps: 31 | lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start 32 | self.last_lr = lr 33 | return lr 34 | else: 35 | t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps) 36 | t = min(t, 1.0) 37 | lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * (1 + np.cos(t * np.pi)) 38 | self.last_lr = lr 39 | return lr 40 | 41 | def __call__(self, n, **kwargs): 42 | return self.schedule(n, **kwargs) 43 | 44 | 45 | class LambdaWarmUpCosineScheduler2: 46 | """ 47 | supports repeated iterations, configurable via lists 48 | note: use with a base_lr of 1.0. 
49 | """ 50 | 51 | def __init__(self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0): 52 | assert len(warm_up_steps) == len(f_min) == len(f_max) == len(f_start) == len(cycle_lengths) 53 | self.lr_warm_up_steps = warm_up_steps 54 | self.f_start = f_start 55 | self.f_min = f_min 56 | self.f_max = f_max 57 | self.cycle_lengths = cycle_lengths 58 | self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths)) 59 | self.last_f = 0.0 60 | self.verbosity_interval = verbosity_interval 61 | 62 | def find_in_interval(self, n): 63 | interval = 0 64 | for cl in self.cum_cycles[1:]: 65 | if n <= cl: 66 | return interval 67 | interval += 1 68 | 69 | def schedule(self, n, **kwargs): 70 | cycle = self.find_in_interval(n) 71 | n = n - self.cum_cycles[cycle] 72 | if self.verbosity_interval > 0: 73 | if n % self.verbosity_interval == 0: 74 | print(f"current step: {n}, recent lr-multiplier: {self.last_f}, " f"current cycle {cycle}") 75 | if n < self.lr_warm_up_steps[cycle]: 76 | f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle] 77 | self.last_f = f 78 | return f 79 | else: 80 | t = (n - self.lr_warm_up_steps[cycle]) / (self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle]) 81 | t = min(t, 1.0) 82 | f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * (1 + np.cos(t * np.pi)) 83 | self.last_f = f 84 | return f 85 | 86 | def __call__(self, n, **kwargs): 87 | return self.schedule(n, **kwargs) 88 | 89 | 90 | class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2): 91 | def schedule(self, n, **kwargs): 92 | cycle = self.find_in_interval(n) 93 | n = n - self.cum_cycles[cycle] 94 | if self.verbosity_interval > 0: 95 | if n % self.verbosity_interval == 0: 96 | print(f"current step: {n}, recent lr-multiplier: {self.last_f}, " f"current cycle {cycle}") 97 | 98 | if n < self.lr_warm_up_steps[cycle]: 99 | f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle] 100 | self.last_f = f 101 | return f 102 | else: 103 | f = ( 104 | self.f_min[cycle] 105 | + (self.f_max[cycle] - self.f_min[cycle]) 106 | * (self.cycle_lengths[cycle] - n) 107 | / (self.cycle_lengths[cycle]) 108 | ) 109 | self.last_f = f 110 | return f 111 | -------------------------------------------------------------------------------- /sgm/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .autoencoder import AutoencodingEngine 2 | -------------------------------------------------------------------------------- /sgm/models/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ML-GSAI/Concat-ID/b15cb95af7efd1c557ee850f5adb116d8feae8b5/sgm/models/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /sgm/models/__pycache__/autoencoder.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ML-GSAI/Concat-ID/b15cb95af7efd1c557ee850f5adb116d8feae8b5/sgm/models/__pycache__/autoencoder.cpython-310.pyc -------------------------------------------------------------------------------- /sgm/modules/__init__.py: -------------------------------------------------------------------------------- 1 | from .encoders.modules import GeneralConditioner 2 | 3 | UNCONDITIONAL_CONFIG = { 4 | "target": "sgm.modules.GeneralConditioner", 5 | 
"params": {"emb_models": []}, 6 | } 7 | -------------------------------------------------------------------------------- /sgm/modules/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ML-GSAI/Concat-ID/b15cb95af7efd1c557ee850f5adb116d8feae8b5/sgm/modules/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /sgm/modules/__pycache__/attention.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ML-GSAI/Concat-ID/b15cb95af7efd1c557ee850f5adb116d8feae8b5/sgm/modules/__pycache__/attention.cpython-310.pyc -------------------------------------------------------------------------------- /sgm/modules/__pycache__/cp_enc_dec.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ML-GSAI/Concat-ID/b15cb95af7efd1c557ee850f5adb116d8feae8b5/sgm/modules/__pycache__/cp_enc_dec.cpython-310.pyc -------------------------------------------------------------------------------- /sgm/modules/__pycache__/ema.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ML-GSAI/Concat-ID/b15cb95af7efd1c557ee850f5adb116d8feae8b5/sgm/modules/__pycache__/ema.cpython-310.pyc -------------------------------------------------------------------------------- /sgm/modules/__pycache__/video_attention.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ML-GSAI/Concat-ID/b15cb95af7efd1c557ee850f5adb116d8feae8b5/sgm/modules/__pycache__/video_attention.cpython-310.pyc -------------------------------------------------------------------------------- /sgm/modules/autoencoding/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ML-GSAI/Concat-ID/b15cb95af7efd1c557ee850f5adb116d8feae8b5/sgm/modules/autoencoding/__init__.py -------------------------------------------------------------------------------- /sgm/modules/autoencoding/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ML-GSAI/Concat-ID/b15cb95af7efd1c557ee850f5adb116d8feae8b5/sgm/modules/autoencoding/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /sgm/modules/autoencoding/__pycache__/temporal_ae.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ML-GSAI/Concat-ID/b15cb95af7efd1c557ee850f5adb116d8feae8b5/sgm/modules/autoencoding/__pycache__/temporal_ae.cpython-310.pyc -------------------------------------------------------------------------------- /sgm/modules/autoencoding/losses/__init__.py: -------------------------------------------------------------------------------- 1 | __all__ = [ 2 | "GeneralLPIPSWithDiscriminator", 3 | "LatentLPIPS", 4 | ] 5 | 6 | from .discriminator_loss import GeneralLPIPSWithDiscriminator 7 | from .lpips import LatentLPIPS 8 | from .video_loss import VideoAutoencoderLoss 9 | -------------------------------------------------------------------------------- /sgm/modules/autoencoding/losses/lpips.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from ....util import default, instantiate_from_config 5 | from ..lpips.loss.lpips import LPIPS 6 | 7 | 8 | class LatentLPIPS(nn.Module): 9 | def __init__( 10 | self, 11 | decoder_config, 12 | perceptual_weight=1.0, 13 | latent_weight=1.0, 14 | scale_input_to_tgt_size=False, 15 | scale_tgt_to_input_size=False, 16 | perceptual_weight_on_inputs=0.0, 17 | ): 18 | super().__init__() 19 | self.scale_input_to_tgt_size = scale_input_to_tgt_size 20 | self.scale_tgt_to_input_size = scale_tgt_to_input_size 21 | self.init_decoder(decoder_config) 22 | self.perceptual_loss = LPIPS().eval() 23 | self.perceptual_weight = perceptual_weight 24 | self.latent_weight = latent_weight 25 | self.perceptual_weight_on_inputs = perceptual_weight_on_inputs 26 | 27 | def init_decoder(self, config): 28 | self.decoder = instantiate_from_config(config) 29 | if hasattr(self.decoder, "encoder"): 30 | del self.decoder.encoder 31 | 32 | def forward(self, latent_inputs, latent_predictions, image_inputs, split="train"): 33 | log = dict() 34 | loss = (latent_inputs - latent_predictions) ** 2 35 | log[f"{split}/latent_l2_loss"] = loss.mean().detach() 36 | image_reconstructions = None 37 | if self.perceptual_weight > 0.0: 38 | image_reconstructions = self.decoder.decode(latent_predictions) 39 | image_targets = self.decoder.decode(latent_inputs) 40 | perceptual_loss = self.perceptual_loss(image_targets.contiguous(), image_reconstructions.contiguous()) 41 | loss = self.latent_weight * loss.mean() + self.perceptual_weight * perceptual_loss.mean() 42 | log[f"{split}/perceptual_loss"] = perceptual_loss.mean().detach() 43 | 44 | if self.perceptual_weight_on_inputs > 0.0: 45 | image_reconstructions = default(image_reconstructions, self.decoder.decode(latent_predictions)) 46 | if self.scale_input_to_tgt_size: 47 | image_inputs = torch.nn.functional.interpolate( 48 | image_inputs, 49 | image_reconstructions.shape[2:], 50 | mode="bicubic", 51 | antialias=True, 52 | ) 53 | elif self.scale_tgt_to_input_size: 54 | image_reconstructions = torch.nn.functional.interpolate( 55 | image_reconstructions, 56 | image_inputs.shape[2:], 57 | mode="bicubic", 58 | antialias=True, 59 | ) 60 | 61 | perceptual_loss2 = self.perceptual_loss(image_inputs.contiguous(), image_reconstructions.contiguous()) 62 | loss = loss + self.perceptual_weight_on_inputs * perceptual_loss2.mean() 63 | log[f"{split}/perceptual_loss_on_inputs"] = perceptual_loss2.mean().detach() 64 | return loss, log 65 | -------------------------------------------------------------------------------- /sgm/modules/autoencoding/lpips/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ML-GSAI/Concat-ID/b15cb95af7efd1c557ee850f5adb116d8feae8b5/sgm/modules/autoencoding/lpips/__init__.py -------------------------------------------------------------------------------- /sgm/modules/autoencoding/lpips/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ML-GSAI/Concat-ID/b15cb95af7efd1c557ee850f5adb116d8feae8b5/sgm/modules/autoencoding/lpips/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /sgm/modules/autoencoding/lpips/__pycache__/util.cpython-310.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/ML-GSAI/Concat-ID/b15cb95af7efd1c557ee850f5adb116d8feae8b5/sgm/modules/autoencoding/lpips/__pycache__/util.cpython-310.pyc -------------------------------------------------------------------------------- /sgm/modules/autoencoding/lpips/loss/.gitignore: -------------------------------------------------------------------------------- 1 | vgg.pth -------------------------------------------------------------------------------- /sgm/modules/autoencoding/lpips/loss/LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2018, Richard Zhang, Phillip Isola, Alexei A. Efros, Eli Shechtman, Oliver Wang 2 | All rights reserved. 3 | 4 | Redistribution and use in source and binary forms, with or without 5 | modification, are permitted provided that the following conditions are met: 6 | 7 | * Redistributions of source code must retain the above copyright notice, this 8 | list of conditions and the following disclaimer. 9 | 10 | * Redistributions in binary form must reproduce the above copyright notice, 11 | this list of conditions and the following disclaimer in the documentation 12 | and/or other materials provided with the distribution. 13 | 14 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 15 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 16 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 17 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 18 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 19 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 20 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 21 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 22 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 23 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
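A minimal usage sketch for the LPIPS metric defined in loss/lpips.py just below, assuming it is run from the repository root with network access (the first call fetches the torchvision VGG-16 backbone and the vgg.pth linear weights that the .gitignore above excludes); inputs are image batches scaled to roughly [-1, 1]:

import torch
from sgm.modules.autoencoding.lpips.loss.lpips import LPIPS

perceptual = LPIPS().eval()                 # VGG-16 features plus the pretrained linear heads, frozen
x = torch.rand(2, 3, 256, 256) * 2 - 1      # reconstruction batch in [-1, 1]
y = torch.rand(2, 3, 256, 256) * 2 - 1      # reference batch in [-1, 1]
with torch.no_grad():
    dist = perceptual(x, y)                 # shape (2, 1, 1, 1); lower means perceptually closer
print(dist.flatten())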
-------------------------------------------------------------------------------- /sgm/modules/autoencoding/lpips/loss/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ML-GSAI/Concat-ID/b15cb95af7efd1c557ee850f5adb116d8feae8b5/sgm/modules/autoencoding/lpips/loss/__init__.py -------------------------------------------------------------------------------- /sgm/modules/autoencoding/lpips/loss/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ML-GSAI/Concat-ID/b15cb95af7efd1c557ee850f5adb116d8feae8b5/sgm/modules/autoencoding/lpips/loss/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /sgm/modules/autoencoding/lpips/loss/__pycache__/lpips.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ML-GSAI/Concat-ID/b15cb95af7efd1c557ee850f5adb116d8feae8b5/sgm/modules/autoencoding/lpips/loss/__pycache__/lpips.cpython-310.pyc -------------------------------------------------------------------------------- /sgm/modules/autoencoding/lpips/loss/lpips.py: -------------------------------------------------------------------------------- 1 | """Stripped version of https://github.com/richzhang/PerceptualSimilarity/tree/master/models""" 2 | 3 | from collections import namedtuple 4 | 5 | import torch 6 | import torch.nn as nn 7 | from torchvision import models 8 | 9 | from ..util import get_ckpt_path 10 | 11 | 12 | class LPIPS(nn.Module): 13 | # Learned perceptual metric 14 | def __init__(self, use_dropout=True): 15 | super().__init__() 16 | self.scaling_layer = ScalingLayer() 17 | self.chns = [64, 128, 256, 512, 512] # vg16 features 18 | self.net = vgg16(pretrained=True, requires_grad=False) 19 | self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout) 20 | self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout) 21 | self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout) 22 | self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout) 23 | self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout) 24 | self.load_from_pretrained() 25 | for param in self.parameters(): 26 | param.requires_grad = False 27 | 28 | def load_from_pretrained(self, name="vgg_lpips"): 29 | ckpt = get_ckpt_path(name, "sgm/modules/autoencoding/lpips/loss") 30 | self.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False) 31 | print("loaded pretrained LPIPS loss from {}".format(ckpt)) 32 | 33 | @classmethod 34 | def from_pretrained(cls, name="vgg_lpips"): 35 | if name != "vgg_lpips": 36 | raise NotImplementedError 37 | model = cls() 38 | ckpt = get_ckpt_path(name) 39 | model.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False) 40 | return model 41 | 42 | def forward(self, input, target): 43 | in0_input, in1_input = (self.scaling_layer(input), self.scaling_layer(target)) 44 | outs0, outs1 = self.net(in0_input), self.net(in1_input) 45 | feats0, feats1, diffs = {}, {}, {} 46 | lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4] 47 | for kk in range(len(self.chns)): 48 | feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(outs1[kk]) 49 | diffs[kk] = (feats0[kk] - feats1[kk]) ** 2 50 | 51 | res = [spatial_average(lins[kk].model(diffs[kk]), keepdim=True) for kk in range(len(self.chns))] 52 | val = res[0] 53 | for l 
in range(1, len(self.chns)): 54 | val += res[l] 55 | return val 56 | 57 | 58 | class ScalingLayer(nn.Module): 59 | def __init__(self): 60 | super(ScalingLayer, self).__init__() 61 | self.register_buffer("shift", torch.Tensor([-0.030, -0.088, -0.188])[None, :, None, None]) 62 | self.register_buffer("scale", torch.Tensor([0.458, 0.448, 0.450])[None, :, None, None]) 63 | 64 | def forward(self, inp): 65 | return (inp - self.shift) / self.scale 66 | 67 | 68 | class NetLinLayer(nn.Module): 69 | """A single linear layer which does a 1x1 conv""" 70 | 71 | def __init__(self, chn_in, chn_out=1, use_dropout=False): 72 | super(NetLinLayer, self).__init__() 73 | layers = ( 74 | [ 75 | nn.Dropout(), 76 | ] 77 | if (use_dropout) 78 | else [] 79 | ) 80 | layers += [ 81 | nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False), 82 | ] 83 | self.model = nn.Sequential(*layers) 84 | 85 | 86 | class vgg16(torch.nn.Module): 87 | def __init__(self, requires_grad=False, pretrained=True): 88 | super(vgg16, self).__init__() 89 | vgg_pretrained_features = models.vgg16(pretrained=pretrained).features 90 | self.slice1 = torch.nn.Sequential() 91 | self.slice2 = torch.nn.Sequential() 92 | self.slice3 = torch.nn.Sequential() 93 | self.slice4 = torch.nn.Sequential() 94 | self.slice5 = torch.nn.Sequential() 95 | self.N_slices = 5 96 | for x in range(4): 97 | self.slice1.add_module(str(x), vgg_pretrained_features[x]) 98 | for x in range(4, 9): 99 | self.slice2.add_module(str(x), vgg_pretrained_features[x]) 100 | for x in range(9, 16): 101 | self.slice3.add_module(str(x), vgg_pretrained_features[x]) 102 | for x in range(16, 23): 103 | self.slice4.add_module(str(x), vgg_pretrained_features[x]) 104 | for x in range(23, 30): 105 | self.slice5.add_module(str(x), vgg_pretrained_features[x]) 106 | if not requires_grad: 107 | for param in self.parameters(): 108 | param.requires_grad = False 109 | 110 | def forward(self, X): 111 | h = self.slice1(X) 112 | h_relu1_2 = h 113 | h = self.slice2(h) 114 | h_relu2_2 = h 115 | h = self.slice3(h) 116 | h_relu3_3 = h 117 | h = self.slice4(h) 118 | h_relu4_3 = h 119 | h = self.slice5(h) 120 | h_relu5_3 = h 121 | vgg_outputs = namedtuple("VggOutputs", ["relu1_2", "relu2_2", "relu3_3", "relu4_3", "relu5_3"]) 122 | out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3) 123 | return out 124 | 125 | 126 | def normalize_tensor(x, eps=1e-10): 127 | norm_factor = torch.sqrt(torch.sum(x**2, dim=1, keepdim=True)) 128 | return x / (norm_factor + eps) 129 | 130 | 131 | def spatial_average(x, keepdim=True): 132 | return x.mean([2, 3], keepdim=keepdim) 133 | -------------------------------------------------------------------------------- /sgm/modules/autoencoding/lpips/model/LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2017, Jun-Yan Zhu and Taesung Park 2 | All rights reserved. 3 | 4 | Redistribution and use in source and binary forms, with or without 5 | modification, are permitted provided that the following conditions are met: 6 | 7 | * Redistributions of source code must retain the above copyright notice, this 8 | list of conditions and the following disclaimer. 9 | 10 | * Redistributions in binary form must reproduce the above copyright notice, 11 | this list of conditions and the following disclaimer in the documentation 12 | and/or other materials provided with the distribution. 
13 | 14 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 15 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 16 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 17 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 18 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 19 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 20 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 21 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 22 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 23 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 24 | 25 | 26 | --------------------------- LICENSE FOR pix2pix -------------------------------- 27 | BSD License 28 | 29 | For pix2pix software 30 | Copyright (c) 2016, Phillip Isola and Jun-Yan Zhu 31 | All rights reserved. 32 | 33 | Redistribution and use in source and binary forms, with or without 34 | modification, are permitted provided that the following conditions are met: 35 | 36 | * Redistributions of source code must retain the above copyright notice, this 37 | list of conditions and the following disclaimer. 38 | 39 | * Redistributions in binary form must reproduce the above copyright notice, 40 | this list of conditions and the following disclaimer in the documentation 41 | and/or other materials provided with the distribution. 42 | 43 | ----------------------------- LICENSE FOR DCGAN -------------------------------- 44 | BSD License 45 | 46 | For dcgan.torch software 47 | 48 | Copyright (c) 2015, Facebook, Inc. All rights reserved. 49 | 50 | Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: 51 | 52 | Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. 53 | 54 | Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. 55 | 56 | Neither the name Facebook nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. 57 | 58 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
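A minimal usage sketch for the PatchGAN discriminator defined in model/model.py just below, assuming 256x256 inputs (the output resolution depends on n_layers and the input size); logit maps of this kind are what the hinge_d_loss and vanilla_d_loss helpers in lpips/vqperceptual.py further below take as logits_real / logits_fake:

import torch
from sgm.modules.autoencoding.lpips.model.model import NLayerDiscriminator, weights_init

disc = NLayerDiscriminator(input_nc=3, ndf=64, n_layers=3, use_actnorm=False)
disc.apply(weights_init)            # DCGAN-style normal init for the conv and batch-norm layers
x = torch.randn(2, 3, 256, 256)     # batch of real or reconstructed images
logits = disc(x)                    # (2, 1, 30, 30) for this input size: one real/fake logit per patch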
-------------------------------------------------------------------------------- /sgm/modules/autoencoding/lpips/model/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ML-GSAI/Concat-ID/b15cb95af7efd1c557ee850f5adb116d8feae8b5/sgm/modules/autoencoding/lpips/model/__init__.py -------------------------------------------------------------------------------- /sgm/modules/autoencoding/lpips/model/model.py: -------------------------------------------------------------------------------- 1 | import functools 2 | 3 | import torch.nn as nn 4 | 5 | from ..util import ActNorm 6 | 7 | 8 | def weights_init(m): 9 | classname = m.__class__.__name__ 10 | if classname.find("Conv") != -1: 11 | try: 12 | nn.init.normal_(m.weight.data, 0.0, 0.02) 13 | except: 14 | nn.init.normal_(m.conv.weight.data, 0.0, 0.02) 15 | elif classname.find("BatchNorm") != -1: 16 | nn.init.normal_(m.weight.data, 1.0, 0.02) 17 | nn.init.constant_(m.bias.data, 0) 18 | 19 | 20 | class NLayerDiscriminator(nn.Module): 21 | """Defines a PatchGAN discriminator as in Pix2Pix 22 | --> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py 23 | """ 24 | 25 | def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False): 26 | """Construct a PatchGAN discriminator 27 | Parameters: 28 | input_nc (int) -- the number of channels in input images 29 | ndf (int) -- the number of filters in the last conv layer 30 | n_layers (int) -- the number of conv layers in the discriminator 31 | norm_layer -- normalization layer 32 | """ 33 | super(NLayerDiscriminator, self).__init__() 34 | if not use_actnorm: 35 | norm_layer = nn.BatchNorm2d 36 | else: 37 | norm_layer = ActNorm 38 | if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters 39 | use_bias = norm_layer.func != nn.BatchNorm2d 40 | else: 41 | use_bias = norm_layer != nn.BatchNorm2d 42 | 43 | kw = 4 44 | padw = 1 45 | sequence = [ 46 | nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), 47 | nn.LeakyReLU(0.2, True), 48 | ] 49 | nf_mult = 1 50 | nf_mult_prev = 1 51 | for n in range(1, n_layers): # gradually increase the number of filters 52 | nf_mult_prev = nf_mult 53 | nf_mult = min(2**n, 8) 54 | sequence += [ 55 | nn.Conv2d( 56 | ndf * nf_mult_prev, 57 | ndf * nf_mult, 58 | kernel_size=kw, 59 | stride=2, 60 | padding=padw, 61 | bias=use_bias, 62 | ), 63 | norm_layer(ndf * nf_mult), 64 | nn.LeakyReLU(0.2, True), 65 | ] 66 | 67 | nf_mult_prev = nf_mult 68 | nf_mult = min(2**n_layers, 8) 69 | sequence += [ 70 | nn.Conv2d( 71 | ndf * nf_mult_prev, 72 | ndf * nf_mult, 73 | kernel_size=kw, 74 | stride=1, 75 | padding=padw, 76 | bias=use_bias, 77 | ), 78 | norm_layer(ndf * nf_mult), 79 | nn.LeakyReLU(0.2, True), 80 | ] 81 | 82 | sequence += [ 83 | nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw) 84 | ] # output 1 channel prediction map 85 | self.main = nn.Sequential(*sequence) 86 | 87 | def forward(self, input): 88 | """Standard forward.""" 89 | return self.main(input) 90 | -------------------------------------------------------------------------------- /sgm/modules/autoencoding/lpips/util.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | import os 3 | 4 | import requests 5 | import torch 6 | import torch.nn as nn 7 | from tqdm import tqdm 8 | 9 | URL_MAP = {"vgg_lpips": "https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1"} 10 | 11 | CKPT_MAP 
= {"vgg_lpips": "vgg.pth"} 12 | 13 | MD5_MAP = {"vgg_lpips": "d507d7349b931f0638a25a48a722f98a"} 14 | 15 | 16 | def download(url, local_path, chunk_size=1024): 17 | os.makedirs(os.path.split(local_path)[0], exist_ok=True) 18 | with requests.get(url, stream=True) as r: 19 | total_size = int(r.headers.get("content-length", 0)) 20 | with tqdm(total=total_size, unit="B", unit_scale=True) as pbar: 21 | with open(local_path, "wb") as f: 22 | for data in r.iter_content(chunk_size=chunk_size): 23 | if data: 24 | f.write(data) 25 | pbar.update(chunk_size) 26 | 27 | 28 | def md5_hash(path): 29 | with open(path, "rb") as f: 30 | content = f.read() 31 | return hashlib.md5(content).hexdigest() 32 | 33 | 34 | def get_ckpt_path(name, root, check=False): 35 | assert name in URL_MAP 36 | path = os.path.join(root, CKPT_MAP[name]) 37 | if not os.path.exists(path) or (check and not md5_hash(path) == MD5_MAP[name]): 38 | print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path)) 39 | download(URL_MAP[name], path) 40 | md5 = md5_hash(path) 41 | assert md5 == MD5_MAP[name], md5 42 | return path 43 | 44 | 45 | class ActNorm(nn.Module): 46 | def __init__(self, num_features, logdet=False, affine=True, allow_reverse_init=False): 47 | assert affine 48 | super().__init__() 49 | self.logdet = logdet 50 | self.loc = nn.Parameter(torch.zeros(1, num_features, 1, 1)) 51 | self.scale = nn.Parameter(torch.ones(1, num_features, 1, 1)) 52 | self.allow_reverse_init = allow_reverse_init 53 | 54 | self.register_buffer("initialized", torch.tensor(0, dtype=torch.uint8)) 55 | 56 | def initialize(self, input): 57 | with torch.no_grad(): 58 | flatten = input.permute(1, 0, 2, 3).contiguous().view(input.shape[1], -1) 59 | mean = flatten.mean(1).unsqueeze(1).unsqueeze(2).unsqueeze(3).permute(1, 0, 2, 3) 60 | std = flatten.std(1).unsqueeze(1).unsqueeze(2).unsqueeze(3).permute(1, 0, 2, 3) 61 | 62 | self.loc.data.copy_(-mean) 63 | self.scale.data.copy_(1 / (std + 1e-6)) 64 | 65 | def forward(self, input, reverse=False): 66 | if reverse: 67 | return self.reverse(input) 68 | if len(input.shape) == 2: 69 | input = input[:, :, None, None] 70 | squeeze = True 71 | else: 72 | squeeze = False 73 | 74 | _, _, height, width = input.shape 75 | 76 | if self.training and self.initialized.item() == 0: 77 | self.initialize(input) 78 | self.initialized.fill_(1) 79 | 80 | h = self.scale * (input + self.loc) 81 | 82 | if squeeze: 83 | h = h.squeeze(-1).squeeze(-1) 84 | 85 | if self.logdet: 86 | log_abs = torch.log(torch.abs(self.scale)) 87 | logdet = height * width * torch.sum(log_abs) 88 | logdet = logdet * torch.ones(input.shape[0]).to(input) 89 | return h, logdet 90 | 91 | return h 92 | 93 | def reverse(self, output): 94 | if self.training and self.initialized.item() == 0: 95 | if not self.allow_reverse_init: 96 | raise RuntimeError( 97 | "Initializing ActNorm in reverse direction is " 98 | "disabled by default. Use allow_reverse_init=True to enable." 
99 | ) 100 | else: 101 | self.initialize(output) 102 | self.initialized.fill_(1) 103 | 104 | if len(output.shape) == 2: 105 | output = output[:, :, None, None] 106 | squeeze = True 107 | else: 108 | squeeze = False 109 | 110 | h = output / self.scale - self.loc 111 | 112 | if squeeze: 113 | h = h.squeeze(-1).squeeze(-1) 114 | return h 115 | -------------------------------------------------------------------------------- /sgm/modules/autoencoding/lpips/vqperceptual.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | 5 | def hinge_d_loss(logits_real, logits_fake): 6 | loss_real = torch.mean(F.relu(1.0 - logits_real)) 7 | loss_fake = torch.mean(F.relu(1.0 + logits_fake)) 8 | d_loss = 0.5 * (loss_real + loss_fake) 9 | return d_loss 10 | 11 | 12 | def vanilla_d_loss(logits_real, logits_fake): 13 | d_loss = 0.5 * ( 14 | torch.mean(torch.nn.functional.softplus(-logits_real)) + torch.mean(torch.nn.functional.softplus(logits_fake)) 15 | ) 16 | return d_loss 17 | -------------------------------------------------------------------------------- /sgm/modules/autoencoding/regularizers/__init__.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | from typing import Any, Tuple 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | from ....modules.distributions.distributions import DiagonalGaussianDistribution 9 | from .base import AbstractRegularizer 10 | 11 | 12 | class DiagonalGaussianRegularizer(AbstractRegularizer): 13 | def __init__(self, sample: bool = True): 14 | super().__init__() 15 | self.sample = sample 16 | 17 | def get_trainable_parameters(self) -> Any: 18 | yield from () 19 | 20 | def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]: 21 | log = dict() 22 | posterior = DiagonalGaussianDistribution(z) 23 | if self.sample: 24 | z = posterior.sample() 25 | else: 26 | z = posterior.mode() 27 | kl_loss = posterior.kl() 28 | kl_loss = torch.sum(kl_loss) / kl_loss.shape[0] 29 | log["kl_loss"] = kl_loss 30 | return z, log 31 | -------------------------------------------------------------------------------- /sgm/modules/autoencoding/regularizers/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ML-GSAI/Concat-ID/b15cb95af7efd1c557ee850f5adb116d8feae8b5/sgm/modules/autoencoding/regularizers/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /sgm/modules/autoencoding/regularizers/__pycache__/base.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ML-GSAI/Concat-ID/b15cb95af7efd1c557ee850f5adb116d8feae8b5/sgm/modules/autoencoding/regularizers/__pycache__/base.cpython-310.pyc -------------------------------------------------------------------------------- /sgm/modules/autoencoding/regularizers/base.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | from typing import Any, Tuple 3 | 4 | import torch 5 | import torch.nn.functional as F 6 | from torch import nn 7 | 8 | 9 | class AbstractRegularizer(nn.Module): 10 | def __init__(self): 11 | super().__init__() 12 | 13 | def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]: 14 | raise NotImplementedError() 15 | 16 | 
@abstractmethod 17 | def get_trainable_parameters(self) -> Any: 18 | raise NotImplementedError() 19 | 20 | 21 | class IdentityRegularizer(AbstractRegularizer): 22 | def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]: 23 | return z, dict() 24 | 25 | def get_trainable_parameters(self) -> Any: 26 | yield from () 27 | 28 | 29 | def measure_perplexity(predicted_indices: torch.Tensor, num_centroids: int) -> Tuple[torch.Tensor, torch.Tensor]: 30 | # src: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py 31 | # eval cluster perplexity. when perplexity == num_embeddings then all clusters are used exactly equally 32 | encodings = F.one_hot(predicted_indices, num_centroids).float().reshape(-1, num_centroids) 33 | avg_probs = encodings.mean(0) 34 | perplexity = (-(avg_probs * torch.log(avg_probs + 1e-10)).sum()).exp() 35 | cluster_use = torch.sum(avg_probs > 0) 36 | return perplexity, cluster_use 37 | -------------------------------------------------------------------------------- /sgm/modules/autoencoding/regularizers/finite_scalar_quantization.py: -------------------------------------------------------------------------------- 1 | """ 2 | Finite Scalar Quantization: VQ-VAE Made Simple - https://arxiv.org/abs/2309.15505 3 | Code adapted from Jax version in Appendix A.1 4 | """ 5 | 6 | from typing import List, Optional 7 | 8 | import torch 9 | import torch.nn as nn 10 | from torch.nn import Module 11 | from torch import Tensor, int32 12 | from torch.cuda.amp import autocast 13 | 14 | from einops import rearrange, pack, unpack 15 | 16 | # helper functions 17 | 18 | 19 | def exists(v): 20 | return v is not None 21 | 22 | 23 | def default(*args): 24 | for arg in args: 25 | if exists(arg): 26 | return arg 27 | return None 28 | 29 | 30 | def pack_one(t, pattern): 31 | return pack([t], pattern) 32 | 33 | 34 | def unpack_one(t, ps, pattern): 35 | return unpack(t, ps, pattern)[0] 36 | 37 | 38 | # tensor helpers 39 | 40 | 41 | def round_ste(z: Tensor) -> Tensor: 42 | """Round with straight through gradients.""" 43 | zhat = z.round() 44 | return z + (zhat - z).detach() 45 | 46 | 47 | # main class 48 | 49 | 50 | class FSQ(Module): 51 | def __init__( 52 | self, 53 | levels: List[int], 54 | dim: Optional[int] = None, 55 | num_codebooks=1, 56 | keep_num_codebooks_dim: Optional[bool] = None, 57 | scale: Optional[float] = None, 58 | ): 59 | super().__init__() 60 | _levels = torch.tensor(levels, dtype=int32) 61 | self.register_buffer("_levels", _levels, persistent=False) 62 | 63 | _basis = torch.cumprod(torch.tensor([1] + levels[:-1]), dim=0, dtype=int32) 64 | self.register_buffer("_basis", _basis, persistent=False) 65 | 66 | self.scale = scale 67 | 68 | codebook_dim = len(levels) 69 | self.codebook_dim = codebook_dim 70 | 71 | effective_codebook_dim = codebook_dim * num_codebooks 72 | self.num_codebooks = num_codebooks 73 | self.effective_codebook_dim = effective_codebook_dim 74 | 75 | keep_num_codebooks_dim = default(keep_num_codebooks_dim, num_codebooks > 1) 76 | assert not (num_codebooks > 1 and not keep_num_codebooks_dim) 77 | self.keep_num_codebooks_dim = keep_num_codebooks_dim 78 | 79 | self.dim = default(dim, len(_levels) * num_codebooks) 80 | 81 | has_projections = self.dim != effective_codebook_dim 82 | self.project_in = nn.Linear(self.dim, effective_codebook_dim) if has_projections else nn.Identity() 83 | self.project_out = nn.Linear(effective_codebook_dim, self.dim) if has_projections else nn.Identity() 84 | self.has_projections = has_projections 85 | 86 | 
self.codebook_size = self._levels.prod().item() 87 | 88 | implicit_codebook = self.indices_to_codes(torch.arange(self.codebook_size), project_out=False) 89 | self.register_buffer("implicit_codebook", implicit_codebook, persistent=False) 90 | 91 | def bound(self, z: Tensor, eps: float = 1e-3) -> Tensor: 92 | """Bound `z`, an array of shape (..., d).""" 93 | half_l = (self._levels - 1) * (1 + eps) / 2 94 | offset = torch.where(self._levels % 2 == 0, 0.5, 0.0) 95 | shift = (offset / half_l).atanh() 96 | return (z + shift).tanh() * half_l - offset 97 | 98 | def quantize(self, z: Tensor) -> Tensor: 99 | """Quantizes z, returns quantized zhat, same shape as z.""" 100 | quantized = round_ste(self.bound(z)) 101 | half_width = self._levels // 2 # Renormalize to [-1, 1]. 102 | return quantized / half_width 103 | 104 | def _scale_and_shift(self, zhat_normalized: Tensor) -> Tensor: 105 | half_width = self._levels // 2 106 | return (zhat_normalized * half_width) + half_width 107 | 108 | def _scale_and_shift_inverse(self, zhat: Tensor) -> Tensor: 109 | half_width = self._levels // 2 110 | return (zhat - half_width) / half_width 111 | 112 | def codes_to_indices(self, zhat: Tensor) -> Tensor: 113 | """Converts a `code` to an index in the codebook.""" 114 | assert zhat.shape[-1] == self.codebook_dim 115 | zhat = self._scale_and_shift(zhat) 116 | return (zhat * self._basis).sum(dim=-1).to(int32) 117 | 118 | def indices_to_codes(self, indices: Tensor, project_out=True) -> Tensor: 119 | """Inverse of `codes_to_indices`.""" 120 | 121 | is_img_or_video = indices.ndim >= (3 + int(self.keep_num_codebooks_dim)) 122 | 123 | indices = rearrange(indices, "... -> ... 1") 124 | codes_non_centered = (indices // self._basis) % self._levels 125 | codes = self._scale_and_shift_inverse(codes_non_centered) 126 | 127 | if self.keep_num_codebooks_dim: 128 | codes = rearrange(codes, "... c d -> ... (c d)") 129 | 130 | if project_out: 131 | codes = self.project_out(codes) 132 | 133 | if is_img_or_video: 134 | codes = rearrange(codes, "b ... d -> b d ...") 135 | 136 | return codes 137 | 138 | @autocast(enabled=False) 139 | def forward(self, z: Tensor) -> Tensor: 140 | """ 141 | einstein notation 142 | b - batch 143 | n - sequence (or flattened spatial dimensions) 144 | d - feature dimension 145 | c - number of codebook dim 146 | """ 147 | 148 | is_img_or_video = z.ndim >= 4 149 | 150 | # standardize image or video into (batch, seq, dimension) 151 | 152 | if is_img_or_video: 153 | z = rearrange(z, "b d ... -> b ... d") 154 | z, ps = pack_one(z, "b * d") 155 | 156 | assert z.shape[-1] == self.dim, f"expected dimension of {self.dim} but found dimension of {z.shape[-1]}" 157 | 158 | z = self.project_in(z) 159 | 160 | z = rearrange(z, "b n (c d) -> b n c d", c=self.num_codebooks) 161 | 162 | codes = self.quantize(z) 163 | indices = self.codes_to_indices(codes) 164 | 165 | codes = rearrange(codes, "b n c d -> b n (c d)") 166 | 167 | out = self.project_out(codes) 168 | 169 | # reconstitute image or video dimensions 170 | 171 | if is_img_or_video: 172 | out = unpack_one(out, ps, "b * d") 173 | out = rearrange(out, "b ... d -> b d ...") 174 | 175 | indices = unpack_one(indices, ps, "b * c") 176 | 177 | if not self.keep_num_codebooks_dim: 178 | indices = rearrange(indices, "... 
1 -> ...") 179 | 180 | return out, indices 181 | -------------------------------------------------------------------------------- /sgm/modules/autoencoding/regularizers/lookup_free_quantization.py: -------------------------------------------------------------------------------- 1 | """ 2 | Lookup Free Quantization 3 | Proposed in https://arxiv.org/abs/2310.05737 4 | 5 | In the simplest setup, each dimension is quantized into {-1, 1}. 6 | An entropy penalty is used to encourage utilization. 7 | """ 8 | 9 | from math import log2, ceil 10 | from collections import namedtuple 11 | 12 | import torch 13 | from torch import nn, einsum 14 | import torch.nn.functional as F 15 | from torch.nn import Module 16 | from torch.cuda.amp import autocast 17 | 18 | from einops import rearrange, reduce, pack, unpack 19 | 20 | # constants 21 | 22 | Return = namedtuple("Return", ["quantized", "indices", "entropy_aux_loss"]) 23 | 24 | LossBreakdown = namedtuple("LossBreakdown", ["per_sample_entropy", "batch_entropy", "commitment"]) 25 | 26 | # helper functions 27 | 28 | 29 | def exists(v): 30 | return v is not None 31 | 32 | 33 | def default(*args): 34 | for arg in args: 35 | if exists(arg): 36 | return arg() if callable(arg) else arg 37 | return None 38 | 39 | 40 | def pack_one(t, pattern): 41 | return pack([t], pattern) 42 | 43 | 44 | def unpack_one(t, ps, pattern): 45 | return unpack(t, ps, pattern)[0] 46 | 47 | 48 | # entropy 49 | 50 | 51 | def log(t, eps=1e-5): 52 | return t.clamp(min=eps).log() 53 | 54 | 55 | def entropy(prob): 56 | return (-prob * log(prob)).sum(dim=-1) 57 | 58 | 59 | # class 60 | 61 | 62 | class LFQ(Module): 63 | def __init__( 64 | self, 65 | *, 66 | dim=None, 67 | codebook_size=None, 68 | entropy_loss_weight=0.1, 69 | commitment_loss_weight=0.25, 70 | diversity_gamma=1.0, 71 | straight_through_activation=nn.Identity(), 72 | num_codebooks=1, 73 | keep_num_codebooks_dim=None, 74 | codebook_scale=1.0, # for residual LFQ, codebook scaled down by 2x at each layer 75 | frac_per_sample_entropy=1.0, # make less than 1. 
to only use a random fraction of the probs for per sample entropy 76 | ): 77 | super().__init__() 78 | 79 | # some assert validations 80 | 81 | assert exists(dim) or exists(codebook_size), "either dim or codebook_size must be specified for LFQ" 82 | assert ( 83 | not exists(codebook_size) or log2(codebook_size).is_integer() 84 | ), f"your codebook size must be a power of 2 for lookup free quantization (suggested {2 ** ceil(log2(codebook_size))})" 85 | 86 | codebook_size = default(codebook_size, lambda: 2**dim) 87 | codebook_dim = int(log2(codebook_size)) 88 | 89 | codebook_dims = codebook_dim * num_codebooks 90 | dim = default(dim, codebook_dims) 91 | 92 | has_projections = dim != codebook_dims 93 | self.project_in = nn.Linear(dim, codebook_dims) if has_projections else nn.Identity() 94 | self.project_out = nn.Linear(codebook_dims, dim) if has_projections else nn.Identity() 95 | self.has_projections = has_projections 96 | 97 | self.dim = dim 98 | self.codebook_dim = codebook_dim 99 | self.num_codebooks = num_codebooks 100 | 101 | keep_num_codebooks_dim = default(keep_num_codebooks_dim, num_codebooks > 1) 102 | assert not (num_codebooks > 1 and not keep_num_codebooks_dim) 103 | self.keep_num_codebooks_dim = keep_num_codebooks_dim 104 | 105 | # straight through activation 106 | 107 | self.activation = straight_through_activation 108 | 109 | # entropy aux loss related weights 110 | 111 | assert 0 < frac_per_sample_entropy <= 1.0 112 | self.frac_per_sample_entropy = frac_per_sample_entropy 113 | 114 | self.diversity_gamma = diversity_gamma 115 | self.entropy_loss_weight = entropy_loss_weight 116 | 117 | # codebook scale 118 | 119 | self.codebook_scale = codebook_scale 120 | 121 | # commitment loss 122 | 123 | self.commitment_loss_weight = commitment_loss_weight 124 | 125 | # for no auxiliary loss, during inference 126 | 127 | self.register_buffer("mask", 2 ** torch.arange(codebook_dim - 1, -1, -1)) 128 | self.register_buffer("zero", torch.tensor(0.0), persistent=False) 129 | 130 | # codes 131 | 132 | all_codes = torch.arange(codebook_size) 133 | bits = ((all_codes[..., None].int() & self.mask) != 0).float() 134 | codebook = self.bits_to_codes(bits) 135 | 136 | self.register_buffer("codebook", codebook, persistent=False) 137 | 138 | def bits_to_codes(self, bits): 139 | return bits * self.codebook_scale * 2 - self.codebook_scale 140 | 141 | @property 142 | def dtype(self): 143 | return self.codebook.dtype 144 | 145 | def indices_to_codes(self, indices, project_out=True): 146 | is_img_or_video = indices.ndim >= (3 + int(self.keep_num_codebooks_dim)) 147 | 148 | if not self.keep_num_codebooks_dim: 149 | indices = rearrange(indices, "... -> ... 1") 150 | 151 | # indices to codes, which are bits of either -1 or 1 152 | 153 | bits = ((indices[..., None].int() & self.mask) != 0).to(self.dtype) 154 | 155 | codes = self.bits_to_codes(bits) 156 | 157 | codes = rearrange(codes, "... c d -> ... (c d)") 158 | 159 | # whether to project codes out to original dimensions 160 | # if the input feature dimensions were not log2(codebook size) 161 | 162 | if project_out: 163 | codes = self.project_out(codes) 164 | 165 | # rearrange codes back to original shape 166 | 167 | if is_img_or_video: 168 | codes = rearrange(codes, "b ... 
d -> b d ...") 169 | 170 | return codes 171 | 172 | @autocast(enabled=False) 173 | def forward( 174 | self, 175 | x, 176 | inv_temperature=100.0, 177 | return_loss_breakdown=False, 178 | mask=None, 179 | ): 180 | """ 181 | einstein notation 182 | b - batch 183 | n - sequence (or flattened spatial dimensions) 184 | d - feature dimension, which is also log2(codebook size) 185 | c - number of codebook dim 186 | """ 187 | 188 | x = x.float() 189 | 190 | is_img_or_video = x.ndim >= 4 191 | 192 | # standardize image or video into (batch, seq, dimension) 193 | 194 | if is_img_or_video: 195 | x = rearrange(x, "b d ... -> b ... d") 196 | x, ps = pack_one(x, "b * d") 197 | 198 | assert x.shape[-1] == self.dim, f"expected dimension of {self.dim} but received {x.shape[-1]}" 199 | 200 | x = self.project_in(x) 201 | 202 | # split out number of codebooks 203 | 204 | x = rearrange(x, "b n (c d) -> b n c d", c=self.num_codebooks) 205 | 206 | # quantize by eq 3. 207 | 208 | original_input = x 209 | 210 | codebook_value = torch.ones_like(x) * self.codebook_scale 211 | quantized = torch.where(x > 0, codebook_value, -codebook_value) 212 | 213 | # use straight-through gradients (optionally with custom activation fn) if training 214 | 215 | if self.training: 216 | x = self.activation(x) 217 | x = x + (quantized - x).detach() 218 | else: 219 | x = quantized 220 | 221 | # calculate indices 222 | 223 | indices = reduce((x > 0).int() * self.mask.int(), "b n c d -> b n c", "sum") 224 | 225 | # entropy aux loss 226 | 227 | if self.training: 228 | # the same as euclidean distance up to a constant 229 | distance = -2 * einsum("... i d, j d -> ... i j", original_input, self.codebook) 230 | 231 | prob = (-distance * inv_temperature).softmax(dim=-1) 232 | 233 | # account for mask 234 | 235 | if exists(mask): 236 | prob = prob[mask] 237 | else: 238 | prob = rearrange(prob, "b n ... -> (b n) ...") 239 | 240 | # whether to only use a fraction of probs, for reducing memory 241 | 242 | if self.frac_per_sample_entropy < 1.0: 243 | num_tokens = prob.shape[0] 244 | num_sampled_tokens = int(num_tokens * self.frac_per_sample_entropy) 245 | rand_mask = torch.randn(num_tokens).argsort(dim=-1) < num_sampled_tokens 246 | per_sample_probs = prob[rand_mask] 247 | else: 248 | per_sample_probs = prob 249 | 250 | # calculate per sample entropy 251 | 252 | per_sample_entropy = entropy(per_sample_probs).mean() 253 | 254 | # distribution over all available tokens in the batch 255 | 256 | avg_prob = reduce(per_sample_probs, "... c d -> c d", "mean") 257 | codebook_entropy = entropy(avg_prob).mean() 258 | 259 | # 1. entropy will be nudged to be low for each code, to encourage the network to output confident predictions 260 | # 2. 
codebook entropy will be nudged to be high, to encourage all codes to be uniformly used within the batch 261 | 262 | entropy_aux_loss = per_sample_entropy - self.diversity_gamma * codebook_entropy 263 | else: 264 | # if not training, just return dummy 0 265 | entropy_aux_loss = per_sample_entropy = codebook_entropy = self.zero 266 | 267 | # commit loss 268 | 269 | if self.training: 270 | commit_loss = F.mse_loss(original_input, quantized.detach(), reduction="none") 271 | 272 | if exists(mask): 273 | commit_loss = commit_loss[mask] 274 | 275 | commit_loss = commit_loss.mean() 276 | else: 277 | commit_loss = self.zero 278 | 279 | # merge back codebook dim 280 | 281 | x = rearrange(x, "b n c d -> b n (c d)") 282 | 283 | # project out to feature dimension if needed 284 | 285 | x = self.project_out(x) 286 | 287 | # reconstitute image or video dimensions 288 | 289 | if is_img_or_video: 290 | x = unpack_one(x, ps, "b * d") 291 | x = rearrange(x, "b ... d -> b d ...") 292 | 293 | indices = unpack_one(indices, ps, "b * c") 294 | 295 | # whether to remove single codebook dim 296 | 297 | if not self.keep_num_codebooks_dim: 298 | indices = rearrange(indices, "... 1 -> ...") 299 | 300 | # complete aux loss 301 | 302 | aux_loss = entropy_aux_loss * self.entropy_loss_weight + commit_loss * self.commitment_loss_weight 303 | 304 | ret = Return(x, indices, aux_loss) 305 | 306 | if not return_loss_breakdown: 307 | return ret 308 | 309 | return ret, LossBreakdown(per_sample_entropy, codebook_entropy, commit_loss) 310 | -------------------------------------------------------------------------------- /sgm/modules/autoencoding/temporal_ae.py: -------------------------------------------------------------------------------- 1 | from typing import Callable, Iterable, Union 2 | 3 | import torch 4 | from einops import rearrange, repeat 5 | 6 | from sgm.modules.diffusionmodules.model import ( 7 | XFORMERS_IS_AVAILABLE, 8 | AttnBlock, 9 | Decoder, 10 | MemoryEfficientAttnBlock, 11 | ResnetBlock, 12 | ) 13 | from sgm.modules.diffusionmodules.openaimodel import ResBlock, timestep_embedding 14 | from sgm.modules.video_attention import VideoTransformerBlock 15 | from sgm.util import partialclass 16 | 17 | 18 | class VideoResBlock(ResnetBlock): 19 | def __init__( 20 | self, 21 | out_channels, 22 | *args, 23 | dropout=0.0, 24 | video_kernel_size=3, 25 | alpha=0.0, 26 | merge_strategy="learned", 27 | **kwargs, 28 | ): 29 | super().__init__(out_channels=out_channels, dropout=dropout, *args, **kwargs) 30 | if video_kernel_size is None: 31 | video_kernel_size = [3, 1, 1] 32 | self.time_stack = ResBlock( 33 | channels=out_channels, 34 | emb_channels=0, 35 | dropout=dropout, 36 | dims=3, 37 | use_scale_shift_norm=False, 38 | use_conv=False, 39 | up=False, 40 | down=False, 41 | kernel_size=video_kernel_size, 42 | use_checkpoint=False, 43 | skip_t_emb=True, 44 | ) 45 | 46 | self.merge_strategy = merge_strategy 47 | if self.merge_strategy == "fixed": 48 | self.register_buffer("mix_factor", torch.Tensor([alpha])) 49 | elif self.merge_strategy == "learned": 50 | self.register_parameter("mix_factor", torch.nn.Parameter(torch.Tensor([alpha]))) 51 | else: 52 | raise ValueError(f"unknown merge strategy {self.merge_strategy}") 53 | 54 | def get_alpha(self, bs): 55 | if self.merge_strategy == "fixed": 56 | return self.mix_factor 57 | elif self.merge_strategy == "learned": 58 | return torch.sigmoid(self.mix_factor) 59 | else: 60 | raise NotImplementedError() 61 | 62 | def forward(self, x, temb, skip_video=False, timesteps=None): 63 | if 
timesteps is None: 64 | timesteps = self.timesteps 65 | 66 | b, c, h, w = x.shape 67 | 68 | x = super().forward(x, temb) 69 | 70 | if not skip_video: 71 | x_mix = rearrange(x, "(b t) c h w -> b c t h w", t=timesteps) 72 | 73 | x = rearrange(x, "(b t) c h w -> b c t h w", t=timesteps) 74 | 75 | x = self.time_stack(x, temb) 76 | 77 | alpha = self.get_alpha(bs=b // timesteps) 78 | x = alpha * x + (1.0 - alpha) * x_mix 79 | 80 | x = rearrange(x, "b c t h w -> (b t) c h w") 81 | return x 82 | 83 | 84 | class AE3DConv(torch.nn.Conv2d): 85 | def __init__(self, in_channels, out_channels, video_kernel_size=3, *args, **kwargs): 86 | super().__init__(in_channels, out_channels, *args, **kwargs) 87 | if isinstance(video_kernel_size, Iterable): 88 | padding = [int(k // 2) for k in video_kernel_size] 89 | else: 90 | padding = int(video_kernel_size // 2) 91 | 92 | self.time_mix_conv = torch.nn.Conv3d( 93 | in_channels=out_channels, 94 | out_channels=out_channels, 95 | kernel_size=video_kernel_size, 96 | padding=padding, 97 | ) 98 | 99 | def forward(self, input, timesteps, skip_video=False): 100 | x = super().forward(input) 101 | if skip_video: 102 | return x 103 | x = rearrange(x, "(b t) c h w -> b c t h w", t=timesteps) 104 | x = self.time_mix_conv(x) 105 | return rearrange(x, "b c t h w -> (b t) c h w") 106 | 107 | 108 | class VideoBlock(AttnBlock): 109 | def __init__(self, in_channels: int, alpha: float = 0, merge_strategy: str = "learned"): 110 | super().__init__(in_channels) 111 | # no context, single headed, as in base class 112 | self.time_mix_block = VideoTransformerBlock( 113 | dim=in_channels, 114 | n_heads=1, 115 | d_head=in_channels, 116 | checkpoint=False, 117 | ff_in=True, 118 | attn_mode="softmax", 119 | ) 120 | 121 | time_embed_dim = self.in_channels * 4 122 | self.video_time_embed = torch.nn.Sequential( 123 | torch.nn.Linear(self.in_channels, time_embed_dim), 124 | torch.nn.SiLU(), 125 | torch.nn.Linear(time_embed_dim, self.in_channels), 126 | ) 127 | 128 | self.merge_strategy = merge_strategy 129 | if self.merge_strategy == "fixed": 130 | self.register_buffer("mix_factor", torch.Tensor([alpha])) 131 | elif self.merge_strategy == "learned": 132 | self.register_parameter("mix_factor", torch.nn.Parameter(torch.Tensor([alpha]))) 133 | else: 134 | raise ValueError(f"unknown merge strategy {self.merge_strategy}") 135 | 136 | def forward(self, x, timesteps, skip_video=False): 137 | if skip_video: 138 | return super().forward(x) 139 | 140 | x_in = x 141 | x = self.attention(x) 142 | h, w = x.shape[2:] 143 | x = rearrange(x, "b c h w -> b (h w) c") 144 | 145 | x_mix = x 146 | num_frames = torch.arange(timesteps, device=x.device) 147 | num_frames = repeat(num_frames, "t -> b t", b=x.shape[0] // timesteps) 148 | num_frames = rearrange(num_frames, "b t -> (b t)") 149 | t_emb = timestep_embedding(num_frames, self.in_channels, repeat_only=False) 150 | emb = self.video_time_embed(t_emb) # b, n_channels 151 | emb = emb[:, None, :] 152 | x_mix = x_mix + emb 153 | 154 | alpha = self.get_alpha() 155 | x_mix = self.time_mix_block(x_mix, timesteps=timesteps) 156 | x = alpha * x + (1.0 - alpha) * x_mix # alpha merge 157 | 158 | x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w) 159 | x = self.proj_out(x) 160 | 161 | return x_in + x 162 | 163 | def get_alpha( 164 | self, 165 | ): 166 | if self.merge_strategy == "fixed": 167 | return self.mix_factor 168 | elif self.merge_strategy == "learned": 169 | return torch.sigmoid(self.mix_factor) 170 | else: 171 | raise NotImplementedError(f"unknown merge strategy 
{self.merge_strategy}") 172 | 173 | 174 | class MemoryEfficientVideoBlock(MemoryEfficientAttnBlock): 175 | def __init__(self, in_channels: int, alpha: float = 0, merge_strategy: str = "learned"): 176 | super().__init__(in_channels) 177 | # no context, single headed, as in base class 178 | self.time_mix_block = VideoTransformerBlock( 179 | dim=in_channels, 180 | n_heads=1, 181 | d_head=in_channels, 182 | checkpoint=False, 183 | ff_in=True, 184 | attn_mode="softmax-xformers", 185 | ) 186 | 187 | time_embed_dim = self.in_channels * 4 188 | self.video_time_embed = torch.nn.Sequential( 189 | torch.nn.Linear(self.in_channels, time_embed_dim), 190 | torch.nn.SiLU(), 191 | torch.nn.Linear(time_embed_dim, self.in_channels), 192 | ) 193 | 194 | self.merge_strategy = merge_strategy 195 | if self.merge_strategy == "fixed": 196 | self.register_buffer("mix_factor", torch.Tensor([alpha])) 197 | elif self.merge_strategy == "learned": 198 | self.register_parameter("mix_factor", torch.nn.Parameter(torch.Tensor([alpha]))) 199 | else: 200 | raise ValueError(f"unknown merge strategy {self.merge_strategy}") 201 | 202 | def forward(self, x, timesteps, skip_time_block=False): 203 | if skip_time_block: 204 | return super().forward(x) 205 | 206 | x_in = x 207 | x = self.attention(x) 208 | h, w = x.shape[2:] 209 | x = rearrange(x, "b c h w -> b (h w) c") 210 | 211 | x_mix = x 212 | num_frames = torch.arange(timesteps, device=x.device) 213 | num_frames = repeat(num_frames, "t -> b t", b=x.shape[0] // timesteps) 214 | num_frames = rearrange(num_frames, "b t -> (b t)") 215 | t_emb = timestep_embedding(num_frames, self.in_channels, repeat_only=False) 216 | emb = self.video_time_embed(t_emb) # b, n_channels 217 | emb = emb[:, None, :] 218 | x_mix = x_mix + emb 219 | 220 | alpha = self.get_alpha() 221 | x_mix = self.time_mix_block(x_mix, timesteps=timesteps) 222 | x = alpha * x + (1.0 - alpha) * x_mix # alpha merge 223 | 224 | x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w) 225 | x = self.proj_out(x) 226 | 227 | return x_in + x 228 | 229 | def get_alpha( 230 | self, 231 | ): 232 | if self.merge_strategy == "fixed": 233 | return self.mix_factor 234 | elif self.merge_strategy == "learned": 235 | return torch.sigmoid(self.mix_factor) 236 | else: 237 | raise NotImplementedError(f"unknown merge strategy {self.merge_strategy}") 238 | 239 | 240 | def make_time_attn( 241 | in_channels, 242 | attn_type="vanilla", 243 | attn_kwargs=None, 244 | alpha: float = 0, 245 | merge_strategy: str = "learned", 246 | ): 247 | assert attn_type in [ 248 | "vanilla", 249 | "vanilla-xformers", 250 | ], f"attn_type {attn_type} not supported for spatio-temporal attention" 251 | print(f"making spatial and temporal attention of type '{attn_type}' with {in_channels} in_channels") 252 | if not XFORMERS_IS_AVAILABLE and attn_type == "vanilla-xformers": 253 | print( 254 | f"Attention mode '{attn_type}' is not available. Falling back to vanilla attention. " 255 | f"This is not a problem in Pytorch >= 2.0. 
FYI, you are running with PyTorch version {torch.__version__}" 256 | ) 257 | attn_type = "vanilla" 258 | 259 | if attn_type == "vanilla": 260 | assert attn_kwargs is None 261 | return partialclass(VideoBlock, in_channels, alpha=alpha, merge_strategy=merge_strategy) 262 | elif attn_type == "vanilla-xformers": 263 | print(f"building MemoryEfficientAttnBlock with {in_channels} in_channels...") 264 | return partialclass( 265 | MemoryEfficientVideoBlock, 266 | in_channels, 267 | alpha=alpha, 268 | merge_strategy=merge_strategy, 269 | ) 270 | else: 271 | return NotImplementedError() 272 | 273 | 274 | class Conv2DWrapper(torch.nn.Conv2d): 275 | def forward(self, input: torch.Tensor, **kwargs) -> torch.Tensor: 276 | return super().forward(input) 277 | 278 | 279 | class VideoDecoder(Decoder): 280 | available_time_modes = ["all", "conv-only", "attn-only"] 281 | 282 | def __init__( 283 | self, 284 | *args, 285 | video_kernel_size: Union[int, list] = 3, 286 | alpha: float = 0.0, 287 | merge_strategy: str = "learned", 288 | time_mode: str = "conv-only", 289 | **kwargs, 290 | ): 291 | self.video_kernel_size = video_kernel_size 292 | self.alpha = alpha 293 | self.merge_strategy = merge_strategy 294 | self.time_mode = time_mode 295 | assert ( 296 | self.time_mode in self.available_time_modes 297 | ), f"time_mode parameter has to be in {self.available_time_modes}" 298 | super().__init__(*args, **kwargs) 299 | 300 | def get_last_layer(self, skip_time_mix=False, **kwargs): 301 | if self.time_mode == "attn-only": 302 | raise NotImplementedError("TODO") 303 | else: 304 | return self.conv_out.time_mix_conv.weight if not skip_time_mix else self.conv_out.weight 305 | 306 | def _make_attn(self) -> Callable: 307 | if self.time_mode not in ["conv-only", "only-last-conv"]: 308 | return partialclass( 309 | make_time_attn, 310 | alpha=self.alpha, 311 | merge_strategy=self.merge_strategy, 312 | ) 313 | else: 314 | return super()._make_attn() 315 | 316 | def _make_conv(self) -> Callable: 317 | if self.time_mode != "attn-only": 318 | return partialclass(AE3DConv, video_kernel_size=self.video_kernel_size) 319 | else: 320 | return Conv2DWrapper 321 | 322 | def _make_resblock(self) -> Callable: 323 | if self.time_mode not in ["attn-only", "only-last-conv"]: 324 | return partialclass( 325 | VideoResBlock, 326 | video_kernel_size=self.video_kernel_size, 327 | alpha=self.alpha, 328 | merge_strategy=self.merge_strategy, 329 | ) 330 | else: 331 | return super()._make_resblock() 332 | -------------------------------------------------------------------------------- /sgm/modules/autoencoding/vqvae/quantize.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | from torch import einsum 6 | from einops import rearrange 7 | 8 | 9 | class VectorQuantizer2(nn.Module): 10 | """ 11 | Improved version over VectorQuantizer, can be used as a drop-in replacement. Mostly 12 | avoids costly matrix multiplications and allows for post-hoc remapping of indices. 13 | """ 14 | 15 | # NOTE: due to a bug the beta term was applied to the wrong term. for 16 | # backwards compatibility we use the buggy version by default, but you can 17 | # specify legacy=False to fix it. 
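    # In terms of the VQ-VAE objective (van den Oord et al., 2017), with sg() the
    # stop-gradient operator, z the encoder output and z_q its quantized version:
    #   commitment term: ||z - sg(z_q)||^2   (gradients reach the encoder)
    #   codebook term:   ||z_q - sg(z)||^2   (gradients reach the embeddings)
    # legacy=True (the default) keeps the historical loss = commitment + beta * codebook,
    # i.e. beta weights the codebook term; legacy=False uses the paper's weighting,
    # loss = beta * commitment + codebook. See the two branches in forward() below.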
18 | def __init__(self, n_e, e_dim, beta, remap=None, unknown_index="random", sane_index_shape=False, legacy=True): 19 | super().__init__() 20 | self.n_e = n_e 21 | self.e_dim = e_dim 22 | self.beta = beta 23 | self.legacy = legacy 24 | 25 | self.embedding = nn.Embedding(self.n_e, self.e_dim) 26 | self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e) 27 | 28 | self.remap = remap 29 | if self.remap is not None: 30 | self.register_buffer("used", torch.tensor(np.load(self.remap))) 31 | self.re_embed = self.used.shape[0] 32 | self.unknown_index = unknown_index # "random" or "extra" or integer 33 | if self.unknown_index == "extra": 34 | self.unknown_index = self.re_embed 35 | self.re_embed = self.re_embed + 1 36 | print( 37 | f"Remapping {self.n_e} indices to {self.re_embed} indices. " 38 | f"Using {self.unknown_index} for unknown indices." 39 | ) 40 | else: 41 | self.re_embed = n_e 42 | 43 | self.sane_index_shape = sane_index_shape 44 | 45 | def remap_to_used(self, inds): 46 | ishape = inds.shape 47 | assert len(ishape) > 1 48 | inds = inds.reshape(ishape[0], -1) 49 | used = self.used.to(inds) 50 | match = (inds[:, :, None] == used[None, None, ...]).long() 51 | new = match.argmax(-1) 52 | unknown = match.sum(2) < 1 53 | if self.unknown_index == "random": 54 | new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape).to(device=new.device) 55 | else: 56 | new[unknown] = self.unknown_index 57 | return new.reshape(ishape) 58 | 59 | def unmap_to_all(self, inds): 60 | ishape = inds.shape 61 | assert len(ishape) > 1 62 | inds = inds.reshape(ishape[0], -1) 63 | used = self.used.to(inds) 64 | if self.re_embed > self.used.shape[0]: # extra token 65 | inds[inds >= self.used.shape[0]] = 0 # simply set to zero 66 | back = torch.gather(used[None, :][inds.shape[0] * [0], :], 1, inds) 67 | return back.reshape(ishape) 68 | 69 | def forward(self, z, temp=None, rescale_logits=False, return_logits=False): 70 | assert temp is None or temp == 1.0, "Only for interface compatible with Gumbel" 71 | assert rescale_logits == False, "Only for interface compatible with Gumbel" 72 | assert return_logits == False, "Only for interface compatible with Gumbel" 73 | # reshape z -> (batch, height, width, channel) and flatten 74 | z = rearrange(z, "b c h w -> b h w c").contiguous() 75 | z_flattened = z.view(-1, self.e_dim) 76 | # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z 77 | 78 | d = ( 79 | torch.sum(z_flattened**2, dim=1, keepdim=True) 80 | + torch.sum(self.embedding.weight**2, dim=1) 81 | - 2 * torch.einsum("bd,dn->bn", z_flattened, rearrange(self.embedding.weight, "n d -> d n")) 82 | ) 83 | 84 | min_encoding_indices = torch.argmin(d, dim=1) 85 | z_q = self.embedding(min_encoding_indices).view(z.shape) 86 | perplexity = None 87 | min_encodings = None 88 | 89 | # compute loss for embedding 90 | if not self.legacy: 91 | loss = self.beta * torch.mean((z_q.detach() - z) ** 2) + torch.mean((z_q - z.detach()) ** 2) 92 | else: 93 | loss = torch.mean((z_q.detach() - z) ** 2) + self.beta * torch.mean((z_q - z.detach()) ** 2) 94 | 95 | # preserve gradients 96 | z_q = z + (z_q - z).detach() 97 | 98 | # reshape back to match original input shape 99 | z_q = rearrange(z_q, "b h w c -> b c h w").contiguous() 100 | 101 | if self.remap is not None: 102 | min_encoding_indices = min_encoding_indices.reshape(z.shape[0], -1) # add batch axis 103 | min_encoding_indices = self.remap_to_used(min_encoding_indices) 104 | min_encoding_indices = min_encoding_indices.reshape(-1, 1) # flatten 105 | 106 
| if self.sane_index_shape: 107 | min_encoding_indices = min_encoding_indices.reshape(z_q.shape[0], z_q.shape[2], z_q.shape[3]) 108 | 109 | return z_q, loss, (perplexity, min_encodings, min_encoding_indices) 110 | 111 | def get_codebook_entry(self, indices, shape): 112 | # shape specifying (batch, height, width, channel) 113 | if self.remap is not None: 114 | indices = indices.reshape(shape[0], -1) # add batch axis 115 | indices = self.unmap_to_all(indices) 116 | indices = indices.reshape(-1) # flatten again 117 | 118 | # get quantized latent vectors 119 | z_q = self.embedding(indices) 120 | 121 | if shape is not None: 122 | z_q = z_q.view(shape) 123 | # reshape back to match original input shape 124 | z_q = z_q.permute(0, 3, 1, 2).contiguous() 125 | 126 | return z_q 127 | 128 | 129 | class GumbelQuantize(nn.Module): 130 | """ 131 | credit to @karpathy: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py (thanks!) 132 | Gumbel Softmax trick quantizer 133 | Categorical Reparameterization with Gumbel-Softmax, Jang et al. 2016 134 | https://arxiv.org/abs/1611.01144 135 | """ 136 | 137 | def __init__( 138 | self, 139 | num_hiddens, 140 | embedding_dim, 141 | n_embed, 142 | straight_through=True, 143 | kl_weight=5e-4, 144 | temp_init=1.0, 145 | use_vqinterface=True, 146 | remap=None, 147 | unknown_index="random", 148 | ): 149 | super().__init__() 150 | 151 | self.embedding_dim = embedding_dim 152 | self.n_embed = n_embed 153 | 154 | self.straight_through = straight_through 155 | self.temperature = temp_init 156 | self.kl_weight = kl_weight 157 | 158 | self.proj = nn.Conv2d(num_hiddens, n_embed, 1) 159 | self.embed = nn.Embedding(n_embed, embedding_dim) 160 | 161 | self.use_vqinterface = use_vqinterface 162 | 163 | self.remap = remap 164 | if self.remap is not None: 165 | self.register_buffer("used", torch.tensor(np.load(self.remap))) 166 | self.re_embed = self.used.shape[0] 167 | self.unknown_index = unknown_index # "random" or "extra" or integer 168 | if self.unknown_index == "extra": 169 | self.unknown_index = self.re_embed 170 | self.re_embed = self.re_embed + 1 171 | print( 172 | f"Remapping {self.n_embed} indices to {self.re_embed} indices. " 173 | f"Using {self.unknown_index} for unknown indices." 174 | ) 175 | else: 176 | self.re_embed = n_embed 177 | 178 | def remap_to_used(self, inds): 179 | ishape = inds.shape 180 | assert len(ishape) > 1 181 | inds = inds.reshape(ishape[0], -1) 182 | used = self.used.to(inds) 183 | match = (inds[:, :, None] == used[None, None, ...]).long() 184 | new = match.argmax(-1) 185 | unknown = match.sum(2) < 1 186 | if self.unknown_index == "random": 187 | new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape).to(device=new.device) 188 | else: 189 | new[unknown] = self.unknown_index 190 | return new.reshape(ishape) 191 | 192 | def unmap_to_all(self, inds): 193 | ishape = inds.shape 194 | assert len(ishape) > 1 195 | inds = inds.reshape(ishape[0], -1) 196 | used = self.used.to(inds) 197 | if self.re_embed > self.used.shape[0]: # extra token 198 | inds[inds >= self.used.shape[0]] = 0 # simply set to zero 199 | back = torch.gather(used[None, :][inds.shape[0] * [0], :], 1, inds) 200 | return back.reshape(ishape) 201 | 202 | def forward(self, z, temp=None, return_logits=False): 203 | # force hard = True when we are in eval mode, as we must quantize. 
actually, always true seems to work 204 | hard = self.straight_through if self.training else True 205 | temp = self.temperature if temp is None else temp 206 | 207 | logits = self.proj(z) 208 | if self.remap is not None: 209 | # continue only with used logits 210 | full_zeros = torch.zeros_like(logits) 211 | logits = logits[:, self.used, ...] 212 | 213 | soft_one_hot = F.gumbel_softmax(logits, tau=temp, dim=1, hard=hard) 214 | if self.remap is not None: 215 | # go back to all entries but unused set to zero 216 | full_zeros[:, self.used, ...] = soft_one_hot 217 | soft_one_hot = full_zeros 218 | z_q = einsum("b n h w, n d -> b d h w", soft_one_hot, self.embed.weight) 219 | 220 | # + kl divergence to the prior loss 221 | qy = F.softmax(logits, dim=1) 222 | diff = self.kl_weight * torch.sum(qy * torch.log(qy * self.n_embed + 1e-10), dim=1).mean() 223 | 224 | ind = soft_one_hot.argmax(dim=1) 225 | if self.remap is not None: 226 | ind = self.remap_to_used(ind) 227 | if self.use_vqinterface: 228 | if return_logits: 229 | return z_q, diff, (None, None, ind), logits 230 | return z_q, diff, (None, None, ind) 231 | return z_q, diff, ind 232 | 233 | def get_codebook_entry(self, indices, shape): 234 | b, h, w, c = shape 235 | assert b * h * w == indices.shape[0] 236 | indices = rearrange(indices, "(b h w) -> b h w", b=b, h=h, w=w) 237 | if self.remap is not None: 238 | indices = self.unmap_to_all(indices) 239 | one_hot = F.one_hot(indices, num_classes=self.n_embed).permute(0, 3, 1, 2).float() 240 | z_q = einsum("b n h w, n d -> b d h w", one_hot, self.embed.weight) 241 | return z_q 242 | -------------------------------------------------------------------------------- /sgm/modules/cp_enc_dec.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.distributed 4 | import torch.nn as nn 5 | from ..util import ( 6 | get_context_parallel_group, 7 | get_context_parallel_rank, 8 | get_context_parallel_world_size, 9 | ) 10 | 11 | _USE_CP = True 12 | 13 | 14 | def cast_tuple(t, length=1): 15 | return t if isinstance(t, tuple) else ((t,) * length) 16 | 17 | 18 | def divisible_by(num, den): 19 | return (num % den) == 0 20 | 21 | 22 | def is_odd(n): 23 | return not divisible_by(n, 2) 24 | 25 | 26 | def exists(v): 27 | return v is not None 28 | 29 | 30 | def pair(t): 31 | return t if isinstance(t, tuple) else (t, t) 32 | 33 | 34 | def get_timestep_embedding(timesteps, embedding_dim): 35 | """ 36 | This matches the implementation in Denoising Diffusion Probabilistic Models: 37 | From Fairseq. 38 | Build sinusoidal embeddings. 39 | This matches the implementation in tensor2tensor, but differs slightly 40 | from the description in Section 3.5 of "Attention Is All You Need". 
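    Concretely, for a 1-D tensor of timesteps t and an even embedding_dim d, the
    output has shape (len(t), d) with
        emb[i, j]       = sin(t_i * 10000^(-j / (d/2 - 1)))   for 0 <= j < d/2
        emb[i, d/2 + j] = cos(t_i * 10000^(-j / (d/2 - 1)))
    and a zero column is appended when d is odd.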
41 | """ 42 | assert len(timesteps.shape) == 1 43 | 44 | half_dim = embedding_dim // 2 45 | emb = math.log(10000) / (half_dim - 1) 46 | emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb) 47 | emb = emb.to(device=timesteps.device) 48 | emb = timesteps.float()[:, None] * emb[None, :] 49 | emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) 50 | if embedding_dim % 2 == 1: # zero pad 51 | emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)) 52 | return emb 53 | 54 | 55 | def nonlinearity(x): 56 | # swish 57 | return x * torch.sigmoid(x) 58 | 59 | 60 | def leaky_relu(p=0.1): 61 | return nn.LeakyReLU(p) 62 | 63 | 64 | def _split(input_, dim): 65 | cp_world_size = get_context_parallel_world_size() 66 | 67 | if cp_world_size == 1: 68 | return input_ 69 | 70 | cp_rank = get_context_parallel_rank() 71 | 72 | # print('in _split, cp_rank:', cp_rank, 'input_size:', input_.shape) 73 | 74 | inpu_first_frame_ = input_.transpose(0, dim)[:1].transpose(0, dim).contiguous() 75 | input_ = input_.transpose(0, dim)[1:].transpose(0, dim).contiguous() 76 | dim_size = input_.size()[dim] // cp_world_size 77 | 78 | input_list = torch.split(input_, dim_size, dim=dim) 79 | output = input_list[cp_rank] 80 | 81 | if cp_rank == 0: 82 | output = torch.cat([inpu_first_frame_, output], dim=dim) 83 | output = output.contiguous() 84 | 85 | # print('out _split, cp_rank:', cp_rank, 'output_size:', output.shape) 86 | 87 | return output 88 | 89 | 90 | def _gather(input_, dim): 91 | cp_world_size = get_context_parallel_world_size() 92 | 93 | # Bypass the function if context parallel is 1 94 | if cp_world_size == 1: 95 | return input_ 96 | 97 | group = get_context_parallel_group() 98 | cp_rank = get_context_parallel_rank() 99 | 100 | # print('in _gather, cp_rank:', cp_rank, 'input_size:', input_.shape) 101 | 102 | input_first_frame_ = input_.transpose(0, dim)[:1].transpose(0, dim).contiguous() 103 | if cp_rank == 0: 104 | input_ = input_.transpose(0, dim)[1:].transpose(0, dim).contiguous() 105 | 106 | tensor_list = [torch.empty_like(torch.cat([input_first_frame_, input_], dim=dim))] + [ 107 | torch.empty_like(input_) for _ in range(cp_world_size - 1) 108 | ] 109 | 110 | if cp_rank == 0: 111 | input_ = torch.cat([input_first_frame_, input_], dim=dim) 112 | 113 | tensor_list[cp_rank] = input_ 114 | torch.distributed.all_gather(tensor_list, input_, group=group) 115 | 116 | output = torch.cat(tensor_list, dim=dim).contiguous() 117 | 118 | # print('out _gather, cp_rank:', cp_rank, 'output_size:', output.shape) 119 | 120 | return output 121 | 122 | 123 | def _conv_split(input_, dim, kernel_size): 124 | cp_world_size = get_context_parallel_world_size() 125 | 126 | # Bypass the function if context parallel is 1 127 | if cp_world_size == 1: 128 | return input_ 129 | 130 | # print('in _conv_split, cp_rank:', cp_rank, 'input_size:', input_.shape) 131 | 132 | cp_rank = get_context_parallel_rank() 133 | 134 | dim_size = (input_.size()[dim] - kernel_size) // cp_world_size 135 | 136 | if cp_rank == 0: 137 | output = input_.transpose(dim, 0)[: dim_size + kernel_size].transpose(dim, 0) 138 | else: 139 | output = input_.transpose(dim, 0)[cp_rank * dim_size + 1 : (cp_rank + 1) * dim_size + kernel_size].transpose( 140 | dim, 0 141 | ) 142 | output = output.contiguous() 143 | 144 | # print('out _conv_split, cp_rank:', cp_rank, 'input_size:', output.shape) 145 | 146 | return output 147 | 148 | 149 | def _conv_gather(input_, dim, kernel_size): 150 | cp_world_size = get_context_parallel_world_size() 151 | 152 | # Bypass the function if 
context parallel is 1 153 | if cp_world_size == 1: 154 | return input_ 155 | 156 | group = get_context_parallel_group() 157 | cp_rank = get_context_parallel_rank() 158 | 159 | # print('in _conv_gather, cp_rank:', cp_rank, 'input_size:', input_.shape) 160 | 161 | input_first_kernel_ = input_.transpose(0, dim)[:kernel_size].transpose(0, dim).contiguous() 162 | if cp_rank == 0: 163 | input_ = input_.transpose(0, dim)[kernel_size:].transpose(0, dim).contiguous() 164 | else: 165 | input_ = input_.transpose(0, dim)[kernel_size - 1 :].transpose(0, dim).contiguous() 166 | 167 | tensor_list = [torch.empty_like(torch.cat([input_first_kernel_, input_], dim=dim))] + [ 168 | torch.empty_like(input_) for _ in range(cp_world_size - 1) 169 | ] 170 | if cp_rank == 0: 171 | input_ = torch.cat([input_first_kernel_, input_], dim=dim) 172 | 173 | tensor_list[cp_rank] = input_ 174 | torch.distributed.all_gather(tensor_list, input_, group=group) 175 | 176 | # Note: torch.cat already creates a contiguous tensor. 177 | output = torch.cat(tensor_list, dim=dim).contiguous() 178 | 179 | # print('out _conv_gather, cp_rank:', cp_rank, 'input_size:', output.shape) 180 | 181 | return output 182 | -------------------------------------------------------------------------------- /sgm/modules/diffusionmodules/__init__.py: -------------------------------------------------------------------------------- 1 | from .denoiser import Denoiser 2 | from .discretizer import Discretization 3 | from .model import Decoder, Encoder, Model 4 | from .openaimodel import UNetModel 5 | from .sampling import BaseDiffusionSampler 6 | from .wrappers import OpenAIWrapper 7 | -------------------------------------------------------------------------------- /sgm/modules/diffusionmodules/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ML-GSAI/Concat-ID/b15cb95af7efd1c557ee850f5adb116d8feae8b5/sgm/modules/diffusionmodules/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /sgm/modules/diffusionmodules/__pycache__/denoiser.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ML-GSAI/Concat-ID/b15cb95af7efd1c557ee850f5adb116d8feae8b5/sgm/modules/diffusionmodules/__pycache__/denoiser.cpython-310.pyc -------------------------------------------------------------------------------- /sgm/modules/diffusionmodules/__pycache__/denoiser_scaling.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ML-GSAI/Concat-ID/b15cb95af7efd1c557ee850f5adb116d8feae8b5/sgm/modules/diffusionmodules/__pycache__/denoiser_scaling.cpython-310.pyc -------------------------------------------------------------------------------- /sgm/modules/diffusionmodules/__pycache__/denoiser_weighting.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ML-GSAI/Concat-ID/b15cb95af7efd1c557ee850f5adb116d8feae8b5/sgm/modules/diffusionmodules/__pycache__/denoiser_weighting.cpython-310.pyc -------------------------------------------------------------------------------- /sgm/modules/diffusionmodules/__pycache__/discretizer.cpython-310.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/ML-GSAI/Concat-ID/b15cb95af7efd1c557ee850f5adb116d8feae8b5/sgm/modules/diffusionmodules/__pycache__/discretizer.cpython-310.pyc -------------------------------------------------------------------------------- /sgm/modules/diffusionmodules/__pycache__/guiders.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ML-GSAI/Concat-ID/b15cb95af7efd1c557ee850f5adb116d8feae8b5/sgm/modules/diffusionmodules/__pycache__/guiders.cpython-310.pyc -------------------------------------------------------------------------------- /sgm/modules/diffusionmodules/__pycache__/lora.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ML-GSAI/Concat-ID/b15cb95af7efd1c557ee850f5adb116d8feae8b5/sgm/modules/diffusionmodules/__pycache__/lora.cpython-310.pyc -------------------------------------------------------------------------------- /sgm/modules/diffusionmodules/__pycache__/loss.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ML-GSAI/Concat-ID/b15cb95af7efd1c557ee850f5adb116d8feae8b5/sgm/modules/diffusionmodules/__pycache__/loss.cpython-310.pyc -------------------------------------------------------------------------------- /sgm/modules/diffusionmodules/__pycache__/model.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ML-GSAI/Concat-ID/b15cb95af7efd1c557ee850f5adb116d8feae8b5/sgm/modules/diffusionmodules/__pycache__/model.cpython-310.pyc -------------------------------------------------------------------------------- /sgm/modules/diffusionmodules/__pycache__/openaimodel.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ML-GSAI/Concat-ID/b15cb95af7efd1c557ee850f5adb116d8feae8b5/sgm/modules/diffusionmodules/__pycache__/openaimodel.cpython-310.pyc -------------------------------------------------------------------------------- /sgm/modules/diffusionmodules/__pycache__/sampling.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ML-GSAI/Concat-ID/b15cb95af7efd1c557ee850f5adb116d8feae8b5/sgm/modules/diffusionmodules/__pycache__/sampling.cpython-310.pyc -------------------------------------------------------------------------------- /sgm/modules/diffusionmodules/__pycache__/sampling_utils.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ML-GSAI/Concat-ID/b15cb95af7efd1c557ee850f5adb116d8feae8b5/sgm/modules/diffusionmodules/__pycache__/sampling_utils.cpython-310.pyc -------------------------------------------------------------------------------- /sgm/modules/diffusionmodules/__pycache__/sigma_sampling.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ML-GSAI/Concat-ID/b15cb95af7efd1c557ee850f5adb116d8feae8b5/sgm/modules/diffusionmodules/__pycache__/sigma_sampling.cpython-310.pyc -------------------------------------------------------------------------------- /sgm/modules/diffusionmodules/__pycache__/util.cpython-310.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/ML-GSAI/Concat-ID/b15cb95af7efd1c557ee850f5adb116d8feae8b5/sgm/modules/diffusionmodules/__pycache__/util.cpython-310.pyc -------------------------------------------------------------------------------- /sgm/modules/diffusionmodules/__pycache__/wrappers.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ML-GSAI/Concat-ID/b15cb95af7efd1c557ee850f5adb116d8feae8b5/sgm/modules/diffusionmodules/__pycache__/wrappers.cpython-310.pyc -------------------------------------------------------------------------------- /sgm/modules/diffusionmodules/denoiser.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Union 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | from ...util import append_dims, instantiate_from_config 7 | 8 | 9 | class Denoiser(nn.Module): 10 | def __init__(self, weighting_config, scaling_config): 11 | super().__init__() 12 | 13 | self.weighting = instantiate_from_config(weighting_config) 14 | self.scaling = instantiate_from_config(scaling_config) 15 | 16 | def possibly_quantize_sigma(self, sigma): 17 | return sigma 18 | 19 | def possibly_quantize_c_noise(self, c_noise): 20 | return c_noise 21 | 22 | def w(self, sigma): 23 | return self.weighting(sigma) 24 | 25 | def forward( 26 | self, 27 | network: nn.Module, 28 | input: torch.Tensor, 29 | sigma: torch.Tensor, 30 | cond: Dict, 31 | **additional_model_inputs, 32 | ) -> torch.Tensor: 33 | sigma = self.possibly_quantize_sigma(sigma) 34 | sigma_shape = sigma.shape 35 | sigma = append_dims(sigma, input.ndim) 36 | c_skip, c_out, c_in, c_noise = self.scaling(sigma, **additional_model_inputs) 37 | c_noise = self.possibly_quantize_c_noise(c_noise.reshape(sigma_shape)) 38 | 39 | return network(input * c_in, c_noise, cond, **additional_model_inputs) * c_out + input * c_skip 40 | 41 | 42 | class DiscreteDenoiser(Denoiser): 43 | def __init__( 44 | self, 45 | weighting_config, 46 | scaling_config, 47 | num_idx, 48 | discretization_config, 49 | do_append_zero=False, 50 | quantize_c_noise=True, 51 | flip=True, 52 | ): 53 | super().__init__(weighting_config, scaling_config) 54 | sigmas = instantiate_from_config(discretization_config)(num_idx, do_append_zero=do_append_zero, flip=flip) 55 | self.sigmas = sigmas 56 | # self.register_buffer("sigmas", sigmas) 57 | self.quantize_c_noise = quantize_c_noise 58 | 59 | def sigma_to_idx(self, sigma): 60 | dists = sigma - self.sigmas.to(sigma.device)[:, None] 61 | return dists.abs().argmin(dim=0).view(sigma.shape) 62 | 63 | def idx_to_sigma(self, idx): 64 | return self.sigmas.to(idx.device)[idx] 65 | 66 | def possibly_quantize_sigma(self, sigma): 67 | return self.idx_to_sigma(self.sigma_to_idx(sigma)) 68 | 69 | def possibly_quantize_c_noise(self, c_noise): 70 | if self.quantize_c_noise: 71 | return self.sigma_to_idx(c_noise) 72 | else: 73 | return c_noise 74 | -------------------------------------------------------------------------------- /sgm/modules/diffusionmodules/denoiser_scaling.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from typing import Any, Tuple 3 | 4 | import torch 5 | 6 | 7 | class DenoiserScaling(ABC): 8 | @abstractmethod 9 | def __call__(self, sigma: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: 10 | pass 11 | 12 | 13 | class EDMScaling: 14 | def __init__(self, sigma_data: float = 0.5): 15 | 
self.sigma_data = sigma_data 16 | 17 | def __call__(self, sigma: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: 18 | c_skip = self.sigma_data**2 / (sigma**2 + self.sigma_data**2) 19 | c_out = sigma * self.sigma_data / (sigma**2 + self.sigma_data**2) ** 0.5 20 | c_in = 1 / (sigma**2 + self.sigma_data**2) ** 0.5 21 | c_noise = 0.25 * sigma.log() 22 | return c_skip, c_out, c_in, c_noise 23 | 24 | 25 | class EpsScaling: 26 | def __call__(self, sigma: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: 27 | c_skip = torch.ones_like(sigma, device=sigma.device) 28 | c_out = -sigma 29 | c_in = 1 / (sigma**2 + 1.0) ** 0.5 30 | c_noise = sigma.clone() 31 | return c_skip, c_out, c_in, c_noise 32 | 33 | 34 | class VScaling: 35 | def __call__(self, sigma: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: 36 | c_skip = 1.0 / (sigma**2 + 1.0) 37 | c_out = -sigma / (sigma**2 + 1.0) ** 0.5 38 | c_in = 1.0 / (sigma**2 + 1.0) ** 0.5 39 | c_noise = sigma.clone() 40 | return c_skip, c_out, c_in, c_noise 41 | 42 | 43 | class VScalingWithEDMcNoise(DenoiserScaling): 44 | def __call__(self, sigma: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: 45 | c_skip = 1.0 / (sigma**2 + 1.0) 46 | c_out = -sigma / (sigma**2 + 1.0) ** 0.5 47 | c_in = 1.0 / (sigma**2 + 1.0) ** 0.5 48 | c_noise = 0.25 * sigma.log() 49 | return c_skip, c_out, c_in, c_noise 50 | 51 | 52 | class VideoScaling: # similar to VScaling 53 | def __call__( 54 | self, alphas_cumprod_sqrt: torch.Tensor, **additional_model_inputs 55 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: 56 | c_skip = alphas_cumprod_sqrt 57 | c_out = -((1 - alphas_cumprod_sqrt**2) ** 0.5) 58 | c_in = torch.ones_like(alphas_cumprod_sqrt, device=alphas_cumprod_sqrt.device) 59 | c_noise = additional_model_inputs["idx"].clone() 60 | return c_skip, c_out, c_in, c_noise 61 | -------------------------------------------------------------------------------- /sgm/modules/diffusionmodules/denoiser_weighting.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class UnitWeighting: 5 | def __call__(self, sigma): 6 | return torch.ones_like(sigma, device=sigma.device) 7 | 8 | 9 | class EDMWeighting: 10 | def __init__(self, sigma_data=0.5): 11 | self.sigma_data = sigma_data 12 | 13 | def __call__(self, sigma): 14 | return (sigma**2 + self.sigma_data**2) / (sigma * self.sigma_data) ** 2 15 | 16 | 17 | class VWeighting(EDMWeighting): 18 | def __init__(self): 19 | super().__init__(sigma_data=1.0) 20 | 21 | 22 | class EpsWeighting: 23 | def __call__(self, sigma): 24 | return sigma**-2.0 25 | -------------------------------------------------------------------------------- /sgm/modules/diffusionmodules/discretizer.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | from functools import partial 3 | 4 | import numpy as np 5 | import torch 6 | 7 | from ...modules.diffusionmodules.util import make_beta_schedule 8 | from ...util import append_zero 9 | 10 | 11 | def generate_roughly_equally_spaced_steps(num_substeps: int, max_step: int) -> np.ndarray: 12 | return np.linspace(max_step - 1, 0, num_substeps, endpoint=False).astype(int)[::-1] 13 | 14 | 15 | class Discretization: 16 | def __call__(self, n, do_append_zero=True, device="cpu", flip=False, return_idx=False): 17 | if return_idx: 18 | sigmas, idx = self.get_sigmas(n, 
device=device, return_idx=return_idx) 19 | else: 20 | sigmas = self.get_sigmas(n, device=device, return_idx=return_idx) 21 | sigmas = append_zero(sigmas) if do_append_zero else sigmas 22 | if return_idx: 23 | return sigmas if not flip else torch.flip(sigmas, (0,)), idx 24 | else: 25 | return sigmas if not flip else torch.flip(sigmas, (0,)) 26 | 27 | @abstractmethod 28 | def get_sigmas(self, n, device): 29 | pass 30 | 31 | 32 | class EDMDiscretization(Discretization): 33 | def __init__(self, sigma_min=0.002, sigma_max=80.0, rho=7.0): 34 | self.sigma_min = sigma_min 35 | self.sigma_max = sigma_max 36 | self.rho = rho 37 | 38 | def get_sigmas(self, n, device="cpu"): 39 | ramp = torch.linspace(0, 1, n, device=device) 40 | min_inv_rho = self.sigma_min ** (1 / self.rho) 41 | max_inv_rho = self.sigma_max ** (1 / self.rho) 42 | sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** self.rho 43 | return sigmas 44 | 45 | 46 | class LegacyDDPMDiscretization(Discretization): 47 | def __init__( 48 | self, 49 | linear_start=0.00085, 50 | linear_end=0.0120, 51 | num_timesteps=1000, 52 | ): 53 | super().__init__() 54 | self.num_timesteps = num_timesteps 55 | betas = make_beta_schedule("linear", num_timesteps, linear_start=linear_start, linear_end=linear_end) 56 | alphas = 1.0 - betas 57 | self.alphas_cumprod = np.cumprod(alphas, axis=0) 58 | self.to_torch = partial(torch.tensor, dtype=torch.float32) 59 | 60 | def get_sigmas(self, n, device="cpu"): 61 | if n < self.num_timesteps: 62 | timesteps = generate_roughly_equally_spaced_steps(n, self.num_timesteps) 63 | alphas_cumprod = self.alphas_cumprod[timesteps] 64 | elif n == self.num_timesteps: 65 | alphas_cumprod = self.alphas_cumprod 66 | else: 67 | raise ValueError 68 | 69 | to_torch = partial(torch.tensor, dtype=torch.float32, device=device) 70 | sigmas = to_torch((1 - alphas_cumprod) / alphas_cumprod) ** 0.5 71 | return torch.flip(sigmas, (0,)) # sigma_t: 14.4 -> 0.029 72 | 73 | 74 | class ZeroSNRDDPMDiscretization(Discretization): 75 | def __init__( 76 | self, 77 | linear_start=0.00085, 78 | linear_end=0.0120, 79 | num_timesteps=1000, 80 | shift_scale=1.0, # noise schedule t_n -> t_m: logSNR(t_m) = logSNR(t_n) - log(shift_scale) 81 | keep_start=False, 82 | post_shift=False, 83 | ): 84 | super().__init__() 85 | if keep_start and not post_shift: 86 | linear_start = linear_start / (shift_scale + (1 - shift_scale) * linear_start) 87 | self.num_timesteps = num_timesteps 88 | betas = make_beta_schedule("linear", num_timesteps, linear_start=linear_start, linear_end=linear_end) 89 | alphas = 1.0 - betas 90 | self.alphas_cumprod = np.cumprod(alphas, axis=0) 91 | self.to_torch = partial(torch.tensor, dtype=torch.float32) 92 | 93 | # SNR shift 94 | if not post_shift: 95 | self.alphas_cumprod = self.alphas_cumprod / (shift_scale + (1 - shift_scale) * self.alphas_cumprod) 96 | 97 | self.post_shift = post_shift 98 | self.shift_scale = shift_scale 99 | 100 | def get_sigmas(self, n, device="cpu", return_idx=False): 101 | if n < self.num_timesteps: 102 | timesteps = generate_roughly_equally_spaced_steps(n, self.num_timesteps) 103 | alphas_cumprod = self.alphas_cumprod[timesteps] 104 | elif n == self.num_timesteps: 105 | alphas_cumprod = self.alphas_cumprod 106 | else: 107 | raise ValueError 108 | 109 | to_torch = partial(torch.tensor, dtype=torch.float32, device=device) 110 | alphas_cumprod = to_torch(alphas_cumprod) 111 | alphas_cumprod_sqrt = alphas_cumprod.sqrt() 112 | alphas_cumprod_sqrt_0 = alphas_cumprod_sqrt[0].clone() 113 | alphas_cumprod_sqrt_T = 
alphas_cumprod_sqrt[-1].clone() 114 | 115 | alphas_cumprod_sqrt -= alphas_cumprod_sqrt_T 116 | alphas_cumprod_sqrt *= alphas_cumprod_sqrt_0 / (alphas_cumprod_sqrt_0 - alphas_cumprod_sqrt_T) 117 | 118 | if self.post_shift: 119 | alphas_cumprod_sqrt = ( 120 | alphas_cumprod_sqrt**2 / (self.shift_scale + (1 - self.shift_scale) * alphas_cumprod_sqrt**2) 121 | ) ** 0.5 122 | 123 | if return_idx: 124 | return torch.flip(alphas_cumprod_sqrt, (0,)), timesteps 125 | else: 126 | return torch.flip(alphas_cumprod_sqrt, (0,)) # sqrt(alpha_t): 0 -> 0.99 127 | -------------------------------------------------------------------------------- /sgm/modules/diffusionmodules/guiders.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from abc import ABC, abstractmethod 3 | from typing import Dict, List, Optional, Tuple, Union 4 | from functools import partial 5 | import math 6 | 7 | import torch 8 | from einops import rearrange, repeat 9 | 10 | from ...util import append_dims, default, instantiate_from_config 11 | 12 | 13 | class Guider(ABC): 14 | @abstractmethod 15 | def __call__(self, x: torch.Tensor, sigma: float) -> torch.Tensor: 16 | pass 17 | 18 | def prepare_inputs(self, x: torch.Tensor, s: float, c: Dict, uc: Dict) -> Tuple[torch.Tensor, float, Dict]: 19 | pass 20 | 21 | 22 | class VanillaCFG: 23 | """ 24 | implements parallelized CFG 25 | """ 26 | 27 | def __init__(self, scale, dyn_thresh_config=None): 28 | self.scale = scale 29 | scale_schedule = lambda scale, sigma: scale # independent of step 30 | self.scale_schedule = partial(scale_schedule, scale) 31 | self.dyn_thresh = instantiate_from_config( 32 | default( 33 | dyn_thresh_config, 34 | {"target": "sgm.modules.diffusionmodules.sampling_utils.NoDynamicThresholding"}, 35 | ) 36 | ) 37 | 38 | def __call__(self, x, sigma, scale=None): 39 | x_u, x_c = x.chunk(2) 40 | scale_value = default(scale, self.scale_schedule(sigma)) 41 | x_pred = self.dyn_thresh(x_u, x_c, scale_value) 42 | return x_pred 43 | 44 | def prepare_inputs(self, x, s, c, uc): 45 | c_out = dict() 46 | 47 | for k in c: 48 | if k in ["vector", "crossattn", "concat"]: 49 | c_out[k] = torch.cat((uc[k], c[k]), 0) 50 | else: 51 | assert c[k] == uc[k] 52 | c_out[k] = c[k] 53 | return torch.cat([x] * 2), torch.cat([s] * 2), c_out 54 | 55 | 56 | class DynamicCFG(VanillaCFG): 57 | def __init__(self, scale, exp, num_steps, dyn_thresh_config=None): 58 | super().__init__(scale, dyn_thresh_config) 59 | scale_schedule = ( 60 | lambda scale, sigma, step_index: 1 + scale * (1 - math.cos(math.pi * (step_index / num_steps) ** exp)) / 2 61 | ) 62 | self.scale_schedule = partial(scale_schedule, scale) 63 | self.dyn_thresh = instantiate_from_config( 64 | default( 65 | dyn_thresh_config, 66 | {"target": "sgm.modules.diffusionmodules.sampling_utils.NoDynamicThresholding"}, 67 | ) 68 | ) 69 | 70 | def __call__(self, x, sigma, step_index, scale=None): 71 | x_u, x_c = x.chunk(2) 72 | scale_value = self.scale_schedule(sigma, step_index.item()) 73 | x_pred = self.dyn_thresh(x_u, x_c, scale_value) 74 | return x_pred 75 | 76 | 77 | class IdentityGuider: 78 | def __call__(self, x, sigma): 79 | return x 80 | 81 | def prepare_inputs(self, x, s, c, uc): 82 | c_out = dict() 83 | 84 | for k in c: 85 | c_out[k] = c[k] 86 | 87 | return x, s, c_out 88 | -------------------------------------------------------------------------------- /sgm/modules/diffusionmodules/loss.py: -------------------------------------------------------------------------------- 1 | from typing 
import List, Optional, Union 2 | 3 | import torch 4 | import torch.nn as nn 5 | from omegaconf import ListConfig 6 | from ...util import append_dims, instantiate_from_config 7 | from ...modules.autoencoding.lpips.loss.lpips import LPIPS 8 | from sat import mpu 9 | 10 | 11 | class StandardDiffusionLoss(nn.Module): 12 | def __init__( 13 | self, 14 | sigma_sampler_config, 15 | type="l2", 16 | offset_noise_level=0.0, 17 | batch2model_keys: Optional[Union[str, List[str], ListConfig]] = None, 18 | ): 19 | super().__init__() 20 | 21 | assert type in ["l2", "l1", "lpips"] 22 | 23 | self.sigma_sampler = instantiate_from_config(sigma_sampler_config) 24 | 25 | self.type = type 26 | self.offset_noise_level = offset_noise_level 27 | 28 | if type == "lpips": 29 | self.lpips = LPIPS().eval() 30 | 31 | if not batch2model_keys: 32 | batch2model_keys = [] 33 | 34 | if isinstance(batch2model_keys, str): 35 | batch2model_keys = [batch2model_keys] 36 | 37 | self.batch2model_keys = set(batch2model_keys) 38 | 39 | def __call__(self, network, denoiser, conditioner, input, batch): 40 | cond = conditioner(batch) 41 | additional_model_inputs = {key: batch[key] for key in self.batch2model_keys.intersection(batch)} 42 | 43 | sigmas = self.sigma_sampler(input.shape[0]).to(input.device) 44 | noise = torch.randn_like(input) 45 | if self.offset_noise_level > 0.0: 46 | noise = ( 47 | noise + append_dims(torch.randn(input.shape[0]).to(input.device), input.ndim) * self.offset_noise_level 48 | ) 49 | noise = noise.to(input.dtype) 50 | noised_input = input.float() + noise * append_dims(sigmas, input.ndim) 51 | model_output = denoiser(network, noised_input, sigmas, cond, **additional_model_inputs) 52 | w = append_dims(denoiser.w(sigmas), input.ndim) 53 | return self.get_loss(model_output, input, w) 54 | 55 | def get_loss(self, model_output, target, w): 56 | if self.type == "l2": 57 | return torch.mean((w * (model_output - target) ** 2).reshape(target.shape[0], -1), 1) 58 | elif self.type == "l1": 59 | return torch.mean((w * (model_output - target).abs()).reshape(target.shape[0], -1), 1) 60 | elif self.type == "lpips": 61 | loss = self.lpips(model_output, target).reshape(-1) 62 | return loss 63 | 64 | 65 | class VideoDiffusionLoss(StandardDiffusionLoss): 66 | def __init__(self, block_scale=None, block_size=None, min_snr_value=None, fixed_frames=0, **kwargs): 67 | self.fixed_frames = fixed_frames 68 | self.block_scale = block_scale 69 | self.block_size = block_size 70 | self.min_snr_value = min_snr_value 71 | super().__init__(**kwargs) 72 | 73 | def __call__(self, network, denoiser, conditioner, input, batch): 74 | cond = conditioner(batch) 75 | additional_model_inputs = {key: batch[key] for key in self.batch2model_keys.intersection(batch)} 76 | 77 | alphas_cumprod_sqrt, idx = self.sigma_sampler(input.shape[0], return_idx=True) 78 | alphas_cumprod_sqrt = alphas_cumprod_sqrt.to(input.device) 79 | idx = idx.to(input.device) 80 | 81 | noise = torch.randn_like(input) 82 | 83 | # broadcast noise 84 | mp_size = mpu.get_model_parallel_world_size() 85 | global_rank = torch.distributed.get_rank() // mp_size 86 | src = global_rank * mp_size 87 | torch.distributed.broadcast(idx, src=src, group=mpu.get_model_parallel_group()) 88 | torch.distributed.broadcast(noise, src=src, group=mpu.get_model_parallel_group()) 89 | torch.distributed.broadcast(alphas_cumprod_sqrt, src=src, group=mpu.get_model_parallel_group()) 90 | 91 | additional_model_inputs["idx"] = idx 92 | 93 | if 'context_image' in batch.keys(): 94 | 
additional_model_inputs["context_image"] = batch['context_image'] 95 | 96 | if self.offset_noise_level > 0.0: 97 | noise = ( 98 | noise + append_dims(torch.randn(input.shape[0]).to(input.device), input.ndim) * self.offset_noise_level 99 | ) 100 | 101 | noised_input = input.float() * append_dims(alphas_cumprod_sqrt, input.ndim) + noise * append_dims( 102 | (1 - alphas_cumprod_sqrt**2) ** 0.5, input.ndim 103 | ) 104 | 105 | if "concat_images" in batch.keys(): 106 | cond["concat"] = batch["concat_images"] 107 | 108 | # [2, 13, 16, 60, 90],[2] dict_keys(['crossattn', 'concat']) dict_keys(['idx']) 109 | model_output = denoiser(network, noised_input, alphas_cumprod_sqrt, cond, **additional_model_inputs) 110 | w = append_dims(1 / (1 - alphas_cumprod_sqrt**2), input.ndim) # v-pred 111 | 112 | if self.min_snr_value is not None: 113 | w = min(w, self.min_snr_value) 114 | return self.get_loss(model_output, input, w) 115 | 116 | def get_loss(self, model_output, target, w): 117 | if self.type == "l2": 118 | return torch.mean((w * (model_output - target) ** 2).reshape(target.shape[0], -1), 1) 119 | elif self.type == "l1": 120 | return torch.mean((w * (model_output - target).abs()).reshape(target.shape[0], -1), 1) 121 | elif self.type == "lpips": 122 | loss = self.lpips(model_output, target).reshape(-1) 123 | return loss 124 | -------------------------------------------------------------------------------- /sgm/modules/diffusionmodules/sampling_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from scipy import integrate 3 | 4 | from ...util import append_dims 5 | from einops import rearrange 6 | 7 | 8 | class NoDynamicThresholding: 9 | def __call__(self, uncond, cond, scale): 10 | scale = append_dims(scale, cond.ndim) if isinstance(scale, torch.Tensor) else scale 11 | return uncond + scale * (cond - uncond) 12 | 13 | 14 | class StaticThresholding: 15 | def __call__(self, uncond, cond, scale): 16 | result = uncond + scale * (cond - uncond) 17 | result = torch.clamp(result, min=-1.0, max=1.0) 18 | return result 19 | 20 | 21 | def dynamic_threshold(x, p=0.95): 22 | N, T, C, H, W = x.shape 23 | x = rearrange(x, "n t c h w -> n c (t h w)") 24 | l, r = x.quantile(q=torch.tensor([1 - p, p], device=x.device), dim=-1, keepdim=True) 25 | s = torch.maximum(-l, r) 26 | threshold_mask = (s > 1).expand(-1, -1, H * W * T) 27 | if threshold_mask.any(): 28 | x = torch.where(threshold_mask, x.clamp(min=-1 * s, max=s), x) 29 | x = rearrange(x, "n c (t h w) -> n t c h w", t=T, h=H, w=W) 30 | return x 31 | 32 | 33 | def dynamic_thresholding2(x0): 34 | p = 0.995 # A hyperparameter in the paper of "Imagen" [1]. 35 | origin_dtype = x0.dtype 36 | x0 = x0.to(torch.float32) 37 | s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1) 38 | s = append_dims(torch.maximum(s, torch.ones_like(s).to(s.device)), x0.dim()) 39 | x0 = torch.clamp(x0, -s, s) # / s 40 | return x0.to(origin_dtype) 41 | 42 | 43 | def latent_dynamic_thresholding(x0): 44 | p = 0.9995 45 | origin_dtype = x0.dtype 46 | x0 = x0.to(torch.float32) 47 | s = torch.quantile(torch.abs(x0), p, dim=2) 48 | s = append_dims(s, x0.dim()) 49 | x0 = torch.clamp(x0, -s, s) / s 50 | return x0.to(origin_dtype) 51 | 52 | 53 | def dynamic_thresholding3(x0): 54 | p = 0.995 # A hyperparameter in the paper of "Imagen" [1]. 
55 | origin_dtype = x0.dtype 56 | x0 = x0.to(torch.float32) 57 | s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1) 58 | s = append_dims(torch.maximum(s, torch.ones_like(s).to(s.device)), x0.dim()) 59 | x0 = torch.clamp(x0, -s, s) # / s 60 | return x0.to(origin_dtype) 61 | 62 | 63 | class DynamicThresholding: 64 | def __call__(self, uncond, cond, scale): 65 | mean = uncond.mean() 66 | std = uncond.std() 67 | result = uncond + scale * (cond - uncond) 68 | result_mean, result_std = result.mean(), result.std() 69 | result = (result - result_mean) / result_std * std 70 | # result = dynamic_thresholding3(result) 71 | return result 72 | 73 | 74 | class DynamicThresholdingV1: 75 | def __init__(self, scale_factor): 76 | self.scale_factor = scale_factor 77 | 78 | def __call__(self, uncond, cond, scale): 79 | result = uncond + scale * (cond - uncond) 80 | unscaled_result = result / self.scale_factor 81 | B, T, C, H, W = unscaled_result.shape 82 | flattened = rearrange(unscaled_result, "b t c h w -> b c (t h w)") 83 | means = flattened.mean(dim=2).unsqueeze(2) 84 | recentered = flattened - means 85 | magnitudes = recentered.abs().max() 86 | normalized = recentered / magnitudes 87 | thresholded = latent_dynamic_thresholding(normalized) 88 | denormalized = thresholded * magnitudes 89 | uncentered = denormalized + means 90 | unflattened = rearrange(uncentered, "b c (t h w) -> b t c h w", t=T, h=H, w=W) 91 | scaled_result = unflattened * self.scale_factor 92 | return scaled_result 93 | 94 | 95 | class DynamicThresholdingV2: 96 | def __call__(self, uncond, cond, scale): 97 | B, T, C, H, W = uncond.shape 98 | diff = cond - uncond 99 | mim_target = uncond + diff * 4.0 100 | cfg_target = uncond + diff * 8.0 101 | 102 | mim_flattened = rearrange(mim_target, "b t c h w -> b c (t h w)") 103 | cfg_flattened = rearrange(cfg_target, "b t c h w -> b c (t h w)") 104 | mim_means = mim_flattened.mean(dim=2).unsqueeze(2) 105 | cfg_means = cfg_flattened.mean(dim=2).unsqueeze(2) 106 | mim_centered = mim_flattened - mim_means 107 | cfg_centered = cfg_flattened - cfg_means 108 | 109 | mim_scaleref = mim_centered.std(dim=2).unsqueeze(2) 110 | cfg_scaleref = cfg_centered.std(dim=2).unsqueeze(2) 111 | 112 | cfg_renormalized = cfg_centered / cfg_scaleref * mim_scaleref 113 | 114 | result = cfg_renormalized + cfg_means 115 | unflattened = rearrange(result, "b c (t h w) -> b t c h w", t=T, h=H, w=W) 116 | 117 | return unflattened 118 | 119 | 120 | def linear_multistep_coeff(order, t, i, j, epsrel=1e-4): 121 | if order - 1 > i: 122 | raise ValueError(f"Order {order} too high for step {i}") 123 | 124 | def fn(tau): 125 | prod = 1.0 126 | for k in range(order): 127 | if j == k: 128 | continue 129 | prod *= (tau - t[i - k]) / (t[i - j] - t[i - k]) 130 | return prod 131 | 132 | return integrate.quad(fn, t[i], t[i + 1], epsrel=epsrel)[0] 133 | 134 | 135 | def get_ancestral_step(sigma_from, sigma_to, eta=1.0): 136 | if not eta: 137 | return sigma_to, 0.0 138 | sigma_up = torch.minimum( 139 | sigma_to, 140 | eta * (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5, 141 | ) 142 | sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5 143 | return sigma_down, sigma_up 144 | 145 | 146 | def to_d(x, sigma, denoised): 147 | return (x - denoised) / append_dims(sigma, x.ndim) 148 | 149 | 150 | def to_neg_log_sigma(sigma): 151 | return sigma.log().neg() 152 | 153 | 154 | def to_sigma(neg_log_sigma): 155 | return neg_log_sigma.neg().exp() 156 | 
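
# --- Editor's note: minimal usage sketch, not part of the original file. ---
# It shows roughly how the helpers above combine during one Euler-ancestral
# style update: classifier-free guidance via NoDynamicThresholding, conversion
# to an ODE derivative with to_d, and the ancestral step sizes from
# get_ancestral_step. The tensors below are toy stand-ins for real model outputs.
if __name__ == "__main__":
    torch.manual_seed(0)
    x = torch.randn(1, 4, 8, 8)                        # current noisy latent
    denoised_uncond = torch.randn_like(x)              # stand-in: denoiser output w/o conditioning
    denoised_cond = torch.randn_like(x)                # stand-in: denoiser output w/ conditioning

    # plain CFG combine: uncond + scale * (cond - uncond)
    denoised = NoDynamicThresholding()(denoised_uncond, denoised_cond, scale=7.5)

    sigma_from, sigma_to = torch.tensor(1.0), torch.tensor(0.8)
    d = to_d(x, sigma_from, denoised)                  # dx/dsigma at the current sigma
    sigma_down, sigma_up = get_ancestral_step(sigma_from, sigma_to, eta=1.0)
    x_next = x + d * (sigma_down - sigma_from)         # deterministic Euler part
    x_next = x_next + torch.randn_like(x) * sigma_up   # ancestral noise injection
    print(x_next.shape)  # torch.Size([1, 4, 8, 8])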
-------------------------------------------------------------------------------- /sgm/modules/diffusionmodules/sigma_sampling.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.distributed 3 | 4 | from sat import mpu 5 | 6 | from ...util import default, instantiate_from_config 7 | 8 | 9 | class EDMSampling: 10 | def __init__(self, p_mean=-1.2, p_std=1.2): 11 | self.p_mean = p_mean 12 | self.p_std = p_std 13 | 14 | def __call__(self, n_samples, rand=None): 15 | log_sigma = self.p_mean + self.p_std * default(rand, torch.randn((n_samples,))) 16 | return log_sigma.exp() 17 | 18 | 19 | class DiscreteSampling: 20 | def __init__(self, discretization_config, num_idx, do_append_zero=False, flip=True, uniform_sampling=False): 21 | self.num_idx = num_idx 22 | self.sigmas = instantiate_from_config(discretization_config)(num_idx, do_append_zero=do_append_zero, flip=flip) 23 | world_size = mpu.get_data_parallel_world_size() 24 | self.uniform_sampling = uniform_sampling 25 | if self.uniform_sampling: 26 | i = 1 27 | while True: 28 | if world_size % i != 0 or num_idx % (world_size // i) != 0: 29 | i += 1 30 | else: 31 | self.group_num = world_size // i 32 | break 33 | 34 | assert self.group_num > 0 35 | assert world_size % self.group_num == 0 36 | self.group_width = world_size // self.group_num # the number of rank in one group 37 | self.sigma_interval = self.num_idx // self.group_num 38 | 39 | def idx_to_sigma(self, idx): 40 | return self.sigmas[idx] 41 | 42 | def __call__(self, n_samples, rand=None, return_idx=False): 43 | if self.uniform_sampling: 44 | rank = mpu.get_data_parallel_rank() 45 | group_index = rank // self.group_width 46 | idx = default( 47 | rand, 48 | torch.randint( 49 | group_index * self.sigma_interval, (group_index + 1) * self.sigma_interval, (n_samples,) 50 | ), 51 | ) 52 | else: 53 | idx = default( 54 | rand, 55 | torch.randint(0, self.num_idx, (n_samples,)), 56 | ) 57 | if return_idx: 58 | return self.idx_to_sigma(idx), idx 59 | else: 60 | return self.idx_to_sigma(idx) 61 | 62 | 63 | class PartialDiscreteSampling: 64 | def __init__(self, discretization_config, total_num_idx, partial_num_idx, do_append_zero=False, flip=True): 65 | self.total_num_idx = total_num_idx 66 | self.partial_num_idx = partial_num_idx 67 | self.sigmas = instantiate_from_config(discretization_config)( 68 | total_num_idx, do_append_zero=do_append_zero, flip=flip 69 | ) 70 | 71 | def idx_to_sigma(self, idx): 72 | return self.sigmas[idx] 73 | 74 | def __call__(self, n_samples, rand=None): 75 | idx = default( 76 | rand, 77 | # torch.randint(self.total_num_idx-self.partial_num_idx, self.total_num_idx, (n_samples,)), 78 | torch.randint(0, self.partial_num_idx, (n_samples,)), 79 | ) 80 | return self.idx_to_sigma(idx) 81 | -------------------------------------------------------------------------------- /sgm/modules/diffusionmodules/wrappers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from packaging import version 4 | 5 | OPENAIUNETWRAPPER = "sgm.modules.diffusionmodules.wrappers.OpenAIWrapper" 6 | 7 | 8 | class IdentityWrapper(nn.Module): 9 | def __init__(self, diffusion_model, compile_model: bool = False, dtype: torch.dtype = torch.float32): 10 | super().__init__() 11 | compile = ( 12 | torch.compile 13 | if (version.parse(torch.__version__) >= version.parse("2.0.0")) and compile_model 14 | else lambda x: x 15 | ) 16 | self.diffusion_model = 
compile(diffusion_model) 17 | self.dtype = dtype 18 | 19 | def forward(self, *args, **kwargs): 20 | return self.diffusion_model(*args, **kwargs) 21 | 22 | 23 | class OpenAIWrapper(IdentityWrapper): 24 | def forward(self, x: torch.Tensor, t: torch.Tensor, c: dict, **kwargs) -> torch.Tensor: 25 | for key in c: 26 | c[key] = c[key].to(self.dtype) 27 | 28 | if x.dim() == 4: 29 | x = torch.cat((x, c.get("concat", torch.Tensor([]).type_as(x))), dim=1) 30 | elif x.dim() == 5: 31 | x = torch.cat((x, c.get("concat", torch.Tensor([]).type_as(x))), dim=2) 32 | else: 33 | raise ValueError("Input tensor must be 4D or 5D") 34 | 35 | return self.diffusion_model( 36 | x, 37 | timesteps=t, 38 | context=c.get("crossattn", None), 39 | y=c.get("vector", None), 40 | **kwargs, 41 | ) 42 | -------------------------------------------------------------------------------- /sgm/modules/distributions/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ML-GSAI/Concat-ID/b15cb95af7efd1c557ee850f5adb116d8feae8b5/sgm/modules/distributions/__init__.py -------------------------------------------------------------------------------- /sgm/modules/distributions/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ML-GSAI/Concat-ID/b15cb95af7efd1c557ee850f5adb116d8feae8b5/sgm/modules/distributions/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /sgm/modules/distributions/__pycache__/distributions.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ML-GSAI/Concat-ID/b15cb95af7efd1c557ee850f5adb116d8feae8b5/sgm/modules/distributions/__pycache__/distributions.cpython-310.pyc -------------------------------------------------------------------------------- /sgm/modules/distributions/distributions.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | 5 | class AbstractDistribution: 6 | def sample(self): 7 | raise NotImplementedError() 8 | 9 | def mode(self): 10 | raise NotImplementedError() 11 | 12 | 13 | class DiracDistribution(AbstractDistribution): 14 | def __init__(self, value): 15 | self.value = value 16 | 17 | def sample(self): 18 | return self.value 19 | 20 | def mode(self): 21 | return self.value 22 | 23 | 24 | class DiagonalGaussianDistribution(object): 25 | def __init__(self, parameters, deterministic=False): 26 | self.parameters = parameters 27 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) 28 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0) 29 | self.deterministic = deterministic 30 | self.std = torch.exp(0.5 * self.logvar) 31 | self.var = torch.exp(self.logvar) 32 | if self.deterministic: 33 | self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device) 34 | 35 | def sample(self): 36 | # x = self.mean + self.std * torch.randn(self.mean.shape).to( 37 | # device=self.parameters.device 38 | # ) 39 | x = self.mean + self.std * torch.randn_like(self.mean).to(device=self.parameters.device) 40 | return x 41 | 42 | def kl(self, other=None): 43 | if self.deterministic: 44 | return torch.Tensor([0.0]) 45 | else: 46 | if other is None: 47 | return 0.5 * torch.sum( 48 | torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, 49 | dim=[1, 2, 3], 50 | ) 51 | else: 52 | return 0.5 * torch.sum( 53 
| torch.pow(self.mean - other.mean, 2) / other.var 54 | + self.var / other.var 55 | - 1.0 56 | - self.logvar 57 | + other.logvar, 58 | dim=[1, 2, 3], 59 | ) 60 | 61 | def nll(self, sample, dims=[1, 2, 3]): 62 | if self.deterministic: 63 | return torch.Tensor([0.0]) 64 | logtwopi = np.log(2.0 * np.pi) 65 | return 0.5 * torch.sum( 66 | logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, 67 | dim=dims, 68 | ) 69 | 70 | def mode(self): 71 | return self.mean 72 | 73 | 74 | def normal_kl(mean1, logvar1, mean2, logvar2): 75 | """ 76 | source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12 77 | Compute the KL divergence between two gaussians. 78 | Shapes are automatically broadcasted, so batches can be compared to 79 | scalars, among other use cases. 80 | """ 81 | tensor = None 82 | for obj in (mean1, logvar1, mean2, logvar2): 83 | if isinstance(obj, torch.Tensor): 84 | tensor = obj 85 | break 86 | assert tensor is not None, "at least one argument must be a Tensor" 87 | 88 | # Force variances to be Tensors. Broadcasting helps convert scalars to 89 | # Tensors, but it does not work for torch.exp(). 90 | logvar1, logvar2 = [x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) for x in (logvar1, logvar2)] 91 | 92 | return 0.5 * ( 93 | -1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2) + ((mean1 - mean2) ** 2) * torch.exp(-logvar2) 94 | ) 95 | -------------------------------------------------------------------------------- /sgm/modules/ema.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class LitEma(nn.Module): 6 | def __init__(self, model, decay=0.9999, use_num_upates=True): 7 | super().__init__() 8 | if decay < 0.0 or decay > 1.0: 9 | raise ValueError("Decay must be between 0 and 1") 10 | 11 | self.m_name2s_name = {} 12 | self.register_buffer("decay", torch.tensor(decay, dtype=torch.float32)) 13 | self.register_buffer( 14 | "num_updates", 15 | torch.tensor(0, dtype=torch.int) if use_num_upates else torch.tensor(-1, dtype=torch.int), 16 | ) 17 | 18 | for name, p in model.named_parameters(): 19 | if p.requires_grad: 20 | # remove as '.'-character is not allowed in buffers 21 | s_name = name.replace(".", "") 22 | self.m_name2s_name.update({name: s_name}) 23 | self.register_buffer(s_name, p.clone().detach().data) 24 | 25 | self.collected_params = [] 26 | 27 | def reset_num_updates(self): 28 | del self.num_updates 29 | self.register_buffer("num_updates", torch.tensor(0, dtype=torch.int)) 30 | 31 | def forward(self, model): 32 | decay = self.decay 33 | 34 | if self.num_updates >= 0: 35 | self.num_updates += 1 36 | decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates)) 37 | 38 | one_minus_decay = 1.0 - decay 39 | 40 | with torch.no_grad(): 41 | m_param = dict(model.named_parameters()) 42 | shadow_params = dict(self.named_buffers()) 43 | 44 | for key in m_param: 45 | if m_param[key].requires_grad: 46 | sname = self.m_name2s_name[key] 47 | shadow_params[sname] = shadow_params[sname].type_as(m_param[key]) 48 | shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key])) 49 | else: 50 | assert not key in self.m_name2s_name 51 | 52 | def copy_to(self, model): 53 | m_param = dict(model.named_parameters()) 54 | shadow_params = dict(self.named_buffers()) 55 | for key in m_param: 56 | if m_param[key].requires_grad: 57 | 
m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data) 58 | else: 59 | assert not key in self.m_name2s_name 60 | 61 | def store(self, parameters): 62 | """ 63 | Save the current parameters for restoring later. 64 | Args: 65 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 66 | temporarily stored. 67 | """ 68 | self.collected_params = [param.clone() for param in parameters] 69 | 70 | def restore(self, parameters): 71 | """ 72 | Restore the parameters stored with the `store` method. 73 | Useful to validate the model with EMA parameters without affecting the 74 | original optimization process. Store the parameters before the 75 | `copy_to` method. After validation (or model saving), use this to 76 | restore the former parameters. 77 | Args: 78 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 79 | updated with the stored parameters. 80 | """ 81 | for c_param, param in zip(self.collected_params, parameters): 82 | param.data.copy_(c_param.data) 83 | -------------------------------------------------------------------------------- /sgm/modules/encoders/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ML-GSAI/Concat-ID/b15cb95af7efd1c557ee850f5adb116d8feae8b5/sgm/modules/encoders/__init__.py -------------------------------------------------------------------------------- /sgm/modules/encoders/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ML-GSAI/Concat-ID/b15cb95af7efd1c557ee850f5adb116d8feae8b5/sgm/modules/encoders/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /sgm/modules/encoders/__pycache__/modules.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ML-GSAI/Concat-ID/b15cb95af7efd1c557ee850f5adb116d8feae8b5/sgm/modules/encoders/__pycache__/modules.cpython-310.pyc -------------------------------------------------------------------------------- /sgm/modules/encoders/modules.py: -------------------------------------------------------------------------------- 1 | import math 2 | from contextlib import nullcontext 3 | from functools import partial 4 | from typing import Dict, List, Optional, Tuple, Union 5 | 6 | import kornia 7 | import numpy as np 8 | import torch 9 | import torch.nn as nn 10 | from einops import rearrange, repeat 11 | from omegaconf import ListConfig 12 | from torch.utils.checkpoint import checkpoint 13 | from transformers import ( 14 | T5EncoderModel, 15 | T5Tokenizer, 16 | ) 17 | 18 | from ...util import ( 19 | append_dims, 20 | autocast, 21 | count_params, 22 | default, 23 | disabled_train, 24 | expand_dims_like, 25 | instantiate_from_config, 26 | ) 27 | 28 | 29 | class AbstractEmbModel(nn.Module): 30 | def __init__(self): 31 | super().__init__() 32 | self._is_trainable = None 33 | self._ucg_rate = None 34 | self._input_key = None 35 | 36 | @property 37 | def is_trainable(self) -> bool: 38 | return self._is_trainable 39 | 40 | @property 41 | def ucg_rate(self) -> Union[float, torch.Tensor]: 42 | return self._ucg_rate 43 | 44 | @property 45 | def input_key(self) -> str: 46 | return self._input_key 47 | 48 | @is_trainable.setter 49 | def is_trainable(self, value: bool): 50 | self._is_trainable = value 51 | 52 | @ucg_rate.setter 53 | def ucg_rate(self, value: Union[float, torch.Tensor]): 54 | 
self._ucg_rate = value 55 | 56 | @input_key.setter 57 | def input_key(self, value: str): 58 | self._input_key = value 59 | 60 | @is_trainable.deleter 61 | def is_trainable(self): 62 | del self._is_trainable 63 | 64 | @ucg_rate.deleter 65 | def ucg_rate(self): 66 | del self._ucg_rate 67 | 68 | @input_key.deleter 69 | def input_key(self): 70 | del self._input_key 71 | 72 | 73 | class GeneralConditioner(nn.Module): 74 | OUTPUT_DIM2KEYS = {2: "vector", 3: "crossattn", 4: "concat", 5: "concat"} 75 | KEY2CATDIM = {"vector": 1, "crossattn": 2, "concat": 1} 76 | 77 | def __init__(self, emb_models: Union[List, ListConfig], cor_embs=[], cor_p=[]): 78 | super().__init__() 79 | embedders = [] 80 | for n, embconfig in enumerate(emb_models): 81 | embedder = instantiate_from_config(embconfig) 82 | assert isinstance( 83 | embedder, AbstractEmbModel 84 | ), f"embedder model {embedder.__class__.__name__} has to inherit from AbstractEmbModel" 85 | embedder.is_trainable = embconfig.get("is_trainable", False) 86 | embedder.ucg_rate = embconfig.get("ucg_rate", 0.0) 87 | if not embedder.is_trainable: 88 | embedder.train = disabled_train 89 | for param in embedder.parameters(): 90 | param.requires_grad = False 91 | embedder.eval() 92 | print( 93 | f"Initialized embedder #{n}: {embedder.__class__.__name__} " 94 | f"with {count_params(embedder, False)} params. Trainable: {embedder.is_trainable}" 95 | ) 96 | 97 | if "input_key" in embconfig: 98 | embedder.input_key = embconfig["input_key"] 99 | elif "input_keys" in embconfig: 100 | embedder.input_keys = embconfig["input_keys"] 101 | else: 102 | raise KeyError(f"need either 'input_key' or 'input_keys' for embedder {embedder.__class__.__name__}") 103 | 104 | embedder.legacy_ucg_val = embconfig.get("legacy_ucg_value", None) 105 | if embedder.legacy_ucg_val is not None: 106 | embedder.ucg_prng = np.random.RandomState() 107 | 108 | embedders.append(embedder) 109 | self.embedders = nn.ModuleList(embedders) 110 | 111 | if len(cor_embs) > 0: 112 | assert len(cor_p) == 2 ** len(cor_embs) 113 | self.cor_embs = cor_embs 114 | self.cor_p = cor_p 115 | 116 | def possibly_get_ucg_val(self, embedder: AbstractEmbModel, batch: Dict) -> Dict: 117 | assert embedder.legacy_ucg_val is not None 118 | p = embedder.ucg_rate 119 | val = embedder.legacy_ucg_val 120 | for i in range(len(batch[embedder.input_key])): 121 | if embedder.ucg_prng.choice(2, p=[1 - p, p]): 122 | batch[embedder.input_key][i] = val 123 | return batch 124 | 125 | def surely_get_ucg_val(self, embedder: AbstractEmbModel, batch: Dict, cond_or_not) -> Dict: 126 | assert embedder.legacy_ucg_val is not None 127 | val = embedder.legacy_ucg_val 128 | for i in range(len(batch[embedder.input_key])): 129 | if cond_or_not[i]: 130 | batch[embedder.input_key][i] = val 131 | return batch 132 | 133 | def get_single_embedding( 134 | self, 135 | embedder, 136 | batch, 137 | output, 138 | cond_or_not: Optional[np.ndarray] = None, 139 | force_zero_embeddings: Optional[List] = None, 140 | ): 141 | embedding_context = nullcontext if embedder.is_trainable else torch.no_grad 142 | with embedding_context(): 143 | if hasattr(embedder, "input_key") and (embedder.input_key is not None): 144 | if embedder.legacy_ucg_val is not None: 145 | if cond_or_not is None: 146 | batch = self.possibly_get_ucg_val(embedder, batch) 147 | else: 148 | batch = self.surely_get_ucg_val(embedder, batch, cond_or_not) 149 | emb_out = embedder(batch[embedder.input_key]) 150 | elif hasattr(embedder, "input_keys"): 151 | emb_out = embedder(*[batch[k] for k in 
embedder.input_keys]) 152 | assert isinstance( 153 | emb_out, (torch.Tensor, list, tuple) 154 | ), f"encoder outputs must be tensors or a sequence, but got {type(emb_out)}" 155 | if not isinstance(emb_out, (list, tuple)): 156 | emb_out = [emb_out] 157 | for emb in emb_out: 158 | out_key = self.OUTPUT_DIM2KEYS[emb.dim()] 159 | if embedder.ucg_rate > 0.0 and embedder.legacy_ucg_val is None: 160 | if cond_or_not is None: 161 | emb = ( 162 | expand_dims_like( 163 | torch.bernoulli((1.0 - embedder.ucg_rate) * torch.ones(emb.shape[0], device=emb.device)), 164 | emb, 165 | ) 166 | * emb 167 | ) 168 | else: 169 | emb = ( 170 | expand_dims_like( 171 | torch.tensor(1 - cond_or_not, dtype=emb.dtype, device=emb.device), 172 | emb, 173 | ) 174 | * emb 175 | ) 176 | if hasattr(embedder, "input_key") and embedder.input_key in force_zero_embeddings: 177 | emb = torch.zeros_like(emb) 178 | if out_key in output: 179 | output[out_key] = torch.cat((output[out_key], emb), self.KEY2CATDIM[out_key]) 180 | else: 181 | output[out_key] = emb 182 | return output 183 | 184 | def forward(self, batch: Dict, force_zero_embeddings: Optional[List] = None) -> Dict: 185 | output = dict() 186 | if force_zero_embeddings is None: 187 | force_zero_embeddings = [] 188 | 189 | if len(self.cor_embs) > 0: 190 | batch_size = len(batch[list(batch.keys())[0]]) 191 | rand_idx = np.random.choice(len(self.cor_p), size=(batch_size,), p=self.cor_p) 192 | for emb_idx in self.cor_embs: 193 | cond_or_not = rand_idx % 2 194 | rand_idx //= 2 195 | output = self.get_single_embedding( 196 | self.embedders[emb_idx], 197 | batch, 198 | output=output, 199 | cond_or_not=cond_or_not, 200 | force_zero_embeddings=force_zero_embeddings, 201 | ) 202 | 203 | for i, embedder in enumerate(self.embedders): 204 | if i in self.cor_embs: 205 | continue 206 | output = self.get_single_embedding( 207 | embedder, batch, output=output, force_zero_embeddings=force_zero_embeddings 208 | ) 209 | return output 210 | 211 | def get_unconditional_conditioning(self, batch_c, batch_uc=None, force_uc_zero_embeddings=None): 212 | if force_uc_zero_embeddings is None: 213 | force_uc_zero_embeddings = [] 214 | ucg_rates = list() 215 | for embedder in self.embedders: 216 | ucg_rates.append(embedder.ucg_rate) 217 | embedder.ucg_rate = 0.0 218 | cor_embs = self.cor_embs 219 | cor_p = self.cor_p 220 | self.cor_embs = [] 221 | self.cor_p = [] 222 | 223 | c = self(batch_c) 224 | uc = self(batch_c if batch_uc is None else batch_uc, force_uc_zero_embeddings) 225 | 226 | for embedder, rate in zip(self.embedders, ucg_rates): 227 | embedder.ucg_rate = rate 228 | self.cor_embs = cor_embs 229 | self.cor_p = cor_p 230 | 231 | return c, uc 232 | 233 | 234 | class FrozenT5Embedder(AbstractEmbModel): 235 | """Uses the T5 transformer encoder for text""" 236 | 237 | def __init__( 238 | self, 239 | model_dir="google/t5-v1_1-xxl", 240 | device="cuda", 241 | max_length=77, 242 | freeze=True, 243 | cache_dir=None, 244 | ): 245 | super().__init__() 246 | if model_dir != "google/t5-v1_1-xxl": 247 | self.tokenizer = T5Tokenizer.from_pretrained(model_dir) 248 | self.transformer = T5EncoderModel.from_pretrained(model_dir) 249 | else: 250 | self.tokenizer = T5Tokenizer.from_pretrained(model_dir, cache_dir=cache_dir) 251 | self.transformer = T5EncoderModel.from_pretrained(model_dir, cache_dir=cache_dir) 252 | self.device = device 253 | self.max_length = max_length 254 | if freeze: 255 | self.freeze() 256 | 257 | def freeze(self): 258 | self.transformer = self.transformer.eval() 259 | 260 | for param in 
self.parameters(): 261 | param.requires_grad = False 262 | 263 | # @autocast 264 | def forward(self, text): 265 | batch_encoding = self.tokenizer( 266 | text, 267 | truncation=True, 268 | max_length=self.max_length, 269 | return_length=True, 270 | return_overflowing_tokens=False, 271 | padding="max_length", 272 | return_tensors="pt", 273 | ) 274 | tokens = batch_encoding["input_ids"].to(self.device) 275 | with torch.autocast("cuda", enabled=False): 276 | outputs = self.transformer(input_ids=tokens) 277 | z = outputs.last_hidden_state 278 | return z 279 | 280 | def encode(self, text): 281 | return self(text) 282 | -------------------------------------------------------------------------------- /sgm/modules/video_attention.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from ..modules.attention import * 4 | from ..modules.diffusionmodules.util import AlphaBlender, linear, timestep_embedding 5 | 6 | 7 | class TimeMixSequential(nn.Sequential): 8 | def forward(self, x, context=None, timesteps=None): 9 | for layer in self: 10 | x = layer(x, context, timesteps) 11 | 12 | return x 13 | 14 | 15 | class VideoTransformerBlock(nn.Module): 16 | ATTENTION_MODES = { 17 | "softmax": CrossAttention, 18 | "softmax-xformers": MemoryEfficientCrossAttention, 19 | } 20 | 21 | def __init__( 22 | self, 23 | dim, 24 | n_heads, 25 | d_head, 26 | dropout=0.0, 27 | context_dim=None, 28 | gated_ff=True, 29 | checkpoint=True, 30 | timesteps=None, 31 | ff_in=False, 32 | inner_dim=None, 33 | attn_mode="softmax", 34 | disable_self_attn=False, 35 | disable_temporal_crossattention=False, 36 | switch_temporal_ca_to_sa=False, 37 | ): 38 | super().__init__() 39 | 40 | attn_cls = self.ATTENTION_MODES[attn_mode] 41 | 42 | self.ff_in = ff_in or inner_dim is not None 43 | if inner_dim is None: 44 | inner_dim = dim 45 | 46 | assert int(n_heads * d_head) == inner_dim 47 | 48 | self.is_res = inner_dim == dim 49 | 50 | if self.ff_in: 51 | self.norm_in = nn.LayerNorm(dim) 52 | self.ff_in = FeedForward(dim, dim_out=inner_dim, dropout=dropout, glu=gated_ff) 53 | 54 | self.timesteps = timesteps 55 | self.disable_self_attn = disable_self_attn 56 | if self.disable_self_attn: 57 | self.attn1 = attn_cls( 58 | query_dim=inner_dim, 59 | heads=n_heads, 60 | dim_head=d_head, 61 | context_dim=context_dim, 62 | dropout=dropout, 63 | ) # is a cross-attention 64 | else: 65 | self.attn1 = attn_cls( 66 | query_dim=inner_dim, heads=n_heads, dim_head=d_head, dropout=dropout 67 | ) # is a self-attention 68 | 69 | self.ff = FeedForward(inner_dim, dim_out=dim, dropout=dropout, glu=gated_ff) 70 | 71 | if disable_temporal_crossattention: 72 | if switch_temporal_ca_to_sa: 73 | raise ValueError 74 | else: 75 | self.attn2 = None 76 | else: 77 | self.norm2 = nn.LayerNorm(inner_dim) 78 | if switch_temporal_ca_to_sa: 79 | self.attn2 = attn_cls( 80 | query_dim=inner_dim, heads=n_heads, dim_head=d_head, dropout=dropout 81 | ) # is a self-attention 82 | else: 83 | self.attn2 = attn_cls( 84 | query_dim=inner_dim, 85 | context_dim=context_dim, 86 | heads=n_heads, 87 | dim_head=d_head, 88 | dropout=dropout, 89 | ) # is self-attn if context is none 90 | 91 | self.norm1 = nn.LayerNorm(inner_dim) 92 | self.norm3 = nn.LayerNorm(inner_dim) 93 | self.switch_temporal_ca_to_sa = switch_temporal_ca_to_sa 94 | 95 | self.checkpoint = checkpoint 96 | if self.checkpoint: 97 | print(f"{self.__class__.__name__} is using checkpointing") 98 | 99 | def forward(self, x: torch.Tensor, context: torch.Tensor = None, timesteps: int 
= None) -> torch.Tensor: 100 | if self.checkpoint: 101 | return checkpoint(self._forward, x, context, timesteps) 102 | else: 103 | return self._forward(x, context, timesteps=timesteps) 104 | 105 | def _forward(self, x, context=None, timesteps=None): 106 | assert self.timesteps or timesteps 107 | assert not (self.timesteps and timesteps) or self.timesteps == timesteps 108 | timesteps = self.timesteps or timesteps 109 | B, S, C = x.shape 110 | x = rearrange(x, "(b t) s c -> (b s) t c", t=timesteps) 111 | 112 | if self.ff_in: 113 | x_skip = x 114 | x = self.ff_in(self.norm_in(x)) 115 | if self.is_res: 116 | x += x_skip 117 | 118 | if self.disable_self_attn: 119 | x = self.attn1(self.norm1(x), context=context) + x 120 | else: 121 | x = self.attn1(self.norm1(x)) + x 122 | 123 | if self.attn2 is not None: 124 | if self.switch_temporal_ca_to_sa: 125 | x = self.attn2(self.norm2(x)) + x 126 | else: 127 | x = self.attn2(self.norm2(x), context=context) + x 128 | x_skip = x 129 | x = self.ff(self.norm3(x)) 130 | if self.is_res: 131 | x += x_skip 132 | 133 | x = rearrange(x, "(b s) t c -> (b t) s c", s=S, b=B // timesteps, c=C, t=timesteps) 134 | return x 135 | 136 | def get_last_layer(self): 137 | return self.ff.net[-1].weight 138 | 139 | 140 | str_to_dtype = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16} 141 | 142 | 143 | class SpatialVideoTransformer(SpatialTransformer): 144 | def __init__( 145 | self, 146 | in_channels, 147 | n_heads, 148 | d_head, 149 | depth=1, 150 | dropout=0.0, 151 | use_linear=False, 152 | context_dim=None, 153 | use_spatial_context=False, 154 | timesteps=None, 155 | merge_strategy: str = "fixed", 156 | merge_factor: float = 0.5, 157 | time_context_dim=None, 158 | ff_in=False, 159 | checkpoint=False, 160 | time_depth=1, 161 | attn_mode="softmax", 162 | disable_self_attn=False, 163 | disable_temporal_crossattention=False, 164 | max_time_embed_period: int = 10000, 165 | dtype="fp32", 166 | ): 167 | super().__init__( 168 | in_channels, 169 | n_heads, 170 | d_head, 171 | depth=depth, 172 | dropout=dropout, 173 | attn_type=attn_mode, 174 | use_checkpoint=checkpoint, 175 | context_dim=context_dim, 176 | use_linear=use_linear, 177 | disable_self_attn=disable_self_attn, 178 | ) 179 | self.time_depth = time_depth 180 | self.depth = depth 181 | self.max_time_embed_period = max_time_embed_period 182 | 183 | time_mix_d_head = d_head 184 | n_time_mix_heads = n_heads 185 | 186 | time_mix_inner_dim = int(time_mix_d_head * n_time_mix_heads) 187 | 188 | inner_dim = n_heads * d_head 189 | if use_spatial_context: 190 | time_context_dim = context_dim 191 | 192 | self.time_stack = nn.ModuleList( 193 | [ 194 | VideoTransformerBlock( 195 | inner_dim, 196 | n_time_mix_heads, 197 | time_mix_d_head, 198 | dropout=dropout, 199 | context_dim=time_context_dim, 200 | timesteps=timesteps, 201 | checkpoint=checkpoint, 202 | ff_in=ff_in, 203 | inner_dim=time_mix_inner_dim, 204 | attn_mode=attn_mode, 205 | disable_self_attn=disable_self_attn, 206 | disable_temporal_crossattention=disable_temporal_crossattention, 207 | ) 208 | for _ in range(self.depth) 209 | ] 210 | ) 211 | 212 | assert len(self.time_stack) == len(self.transformer_blocks) 213 | 214 | self.use_spatial_context = use_spatial_context 215 | self.in_channels = in_channels 216 | 217 | time_embed_dim = self.in_channels * 4 218 | self.time_pos_embed = nn.Sequential( 219 | linear(self.in_channels, time_embed_dim), 220 | nn.SiLU(), 221 | linear(time_embed_dim, self.in_channels), 222 | ) 223 | 224 | self.time_mixer = 
AlphaBlender(alpha=merge_factor, merge_strategy=merge_strategy) 225 | self.dtype = str_to_dtype[dtype] 226 | 227 | def forward( 228 | self, 229 | x: torch.Tensor, 230 | context: Optional[torch.Tensor] = None, 231 | time_context: Optional[torch.Tensor] = None, 232 | timesteps: Optional[int] = None, 233 | image_only_indicator: Optional[torch.Tensor] = None, 234 | ) -> torch.Tensor: 235 | _, _, h, w = x.shape 236 | x_in = x 237 | spatial_context = None 238 | if exists(context): 239 | spatial_context = context 240 | 241 | if self.use_spatial_context: 242 | assert context.ndim == 3, f"n dims of spatial context should be 3 but are {context.ndim}" 243 | 244 | time_context = context 245 | time_context_first_timestep = time_context[::timesteps] 246 | time_context = repeat(time_context_first_timestep, "b ... -> (b n) ...", n=h * w) 247 | elif time_context is not None and not self.use_spatial_context: 248 | time_context = repeat(time_context, "b ... -> (b n) ...", n=h * w) 249 | if time_context.ndim == 2: 250 | time_context = rearrange(time_context, "b c -> b 1 c") 251 | 252 | x = self.norm(x) 253 | if not self.use_linear: 254 | x = self.proj_in(x) 255 | x = rearrange(x, "b c h w -> b (h w) c") 256 | if self.use_linear: 257 | x = self.proj_in(x) 258 | 259 | num_frames = torch.arange(timesteps, device=x.device) 260 | num_frames = repeat(num_frames, "t -> b t", b=x.shape[0] // timesteps) 261 | num_frames = rearrange(num_frames, "b t -> (b t)") 262 | t_emb = timestep_embedding( 263 | num_frames, 264 | self.in_channels, 265 | repeat_only=False, 266 | max_period=self.max_time_embed_period, 267 | dtype=self.dtype, 268 | ) 269 | emb = self.time_pos_embed(t_emb) 270 | emb = emb[:, None, :] 271 | 272 | for it_, (block, mix_block) in enumerate(zip(self.transformer_blocks, self.time_stack)): 273 | x = block( 274 | x, 275 | context=spatial_context, 276 | ) 277 | 278 | x_mix = x 279 | x_mix = x_mix + emb 280 | 281 | x_mix = mix_block(x_mix, context=time_context, timesteps=timesteps) 282 | x = self.time_mixer( 283 | x_spatial=x, 284 | x_temporal=x_mix, 285 | image_only_indicator=image_only_indicator, 286 | ) 287 | if self.use_linear: 288 | x = self.proj_out(x) 289 | x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w) 290 | if not self.use_linear: 291 | x = self.proj_out(x) 292 | out = x + x_in 293 | return out 294 | -------------------------------------------------------------------------------- /train_video.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | from functools import partial 4 | import numpy as np 5 | import torch.distributed 6 | from omegaconf import OmegaConf 7 | import imageio 8 | import time 9 | 10 | import torch 11 | 12 | from sat import mpu 13 | from sat.training.deepspeed_training import training_main 14 | 15 | from sgm.util import get_obj_from_str, isheatmap 16 | 17 | from diffusion_video import SATVideoDiffusionEngine 18 | from arguments import get_args 19 | 20 | from einops import rearrange 21 | 22 | try: 23 | import wandb 24 | except ImportError: 25 | print("warning: wandb not installed") 26 | 27 | 28 | def print_debug(args, s): 29 | if args.debug: 30 | s = f"RANK:[{torch.distributed.get_rank()}]:" + s 31 | print(s) 32 | 33 | 34 | def save_texts(texts, save_dir, iterations): 35 | output_path = os.path.join(save_dir, f"{str(iterations).zfill(8)}") 36 | with open(output_path, "w", encoding="utf-8") as f: 37 | for text in texts: 38 | f.write(text + "\n") 39 | 40 | 41 | def save_video_as_grid_and_mp4(video_batch: 
torch.Tensor, save_path: str, T: int, fps: int = 5, args=None, key=None): 42 | os.makedirs(save_path, exist_ok=True) 43 | 44 | for i, vid in enumerate(video_batch): 45 | gif_frames = [] 46 | for frame in vid: 47 | frame = rearrange(frame, "c h w -> h w c") 48 | frame = (255.0 * frame).cpu().numpy().astype(np.uint8) 49 | gif_frames.append(frame) 50 | now_save_path = os.path.join(save_path, f"{i:06d}.mp4") 51 | with imageio.get_writer(now_save_path, fps=fps) as writer: 52 | for frame in gif_frames: 53 | writer.append_data(frame) 54 | if args is not None and args.wandb: 55 | # wandb.log( 56 | # {key + f"_video_{i}": wandb.Video(now_save_path, fps=fps, format="mp4")}, step=args.iteration + 1 57 | # ) 58 | wandb.log( 59 | {key + f"_video_{i}": wandb.Video(now_save_path, fps=fps, format="mp4")}, step=args.iteration 60 | ) 61 | 62 | 63 | def log_video(batch, model, args, only_log_video_latents=False): 64 | texts = batch["txt"] 65 | text_save_dir = os.path.join(args.save, "video_texts") 66 | os.makedirs(text_save_dir, exist_ok=True) 67 | save_texts(texts, text_save_dir, args.iteration) 68 | 69 | gpu_autocast_kwargs = { 70 | "enabled": torch.is_autocast_enabled(), 71 | "dtype": torch.get_autocast_gpu_dtype(), 72 | "cache_enabled": torch.is_autocast_cache_enabled(), 73 | } 74 | with torch.no_grad(), torch.cuda.amp.autocast(**gpu_autocast_kwargs): 75 | videos = model.log_video(batch, only_log_video_latents=only_log_video_latents) 76 | 77 | if torch.distributed.get_rank() == 0: 78 | root = os.path.join(args.save, "video") 79 | 80 | if only_log_video_latents: 81 | root = os.path.join(root, "latents") 82 | filename = "{}_gs-{:06}".format("latents", args.iteration) 83 | path = os.path.join(root, filename) 84 | os.makedirs(os.path.split(path)[0], exist_ok=True) 85 | os.makedirs(path, exist_ok=True) 86 | torch.save(videos["latents"], os.path.join(path, "latent.pt")) 87 | else: 88 | for k in videos: 89 | N = videos[k].shape[0] 90 | if not isheatmap(videos[k]): 91 | videos[k] = videos[k][:N] 92 | if isinstance(videos[k], torch.Tensor): 93 | videos[k] = videos[k].detach().float().cpu() 94 | if not isheatmap(videos[k]): 95 | videos[k] = torch.clamp(videos[k], -1.0, 1.0) 96 | 97 | num_frames = batch["num_frames"][0] 98 | fps = batch["fps"][0].cpu().item() 99 | if only_log_video_latents: 100 | root = os.path.join(root, "latents") 101 | filename = "{}_gs-{:06}".format("latents", args.iteration) 102 | path = os.path.join(root, filename) 103 | os.makedirs(os.path.split(path)[0], exist_ok=True) 104 | os.makedirs(path, exist_ok=True) 105 | torch.save(videos["latents"], os.path.join(path, "latents.pt")) 106 | else: 107 | for k in videos: 108 | current_time = time.localtime() 109 | formatted_time = time.strftime("%Y-%m-%d-%H:%M:%S", current_time) 110 | samples = (videos[k] + 1.0) / 2.0 111 | filename = "{}_gs-{:06}_{}".format(k, args.iteration, formatted_time) 112 | 113 | path = os.path.join(root, filename) 114 | os.makedirs(os.path.split(path)[0], exist_ok=True) 115 | save_video_as_grid_and_mp4(samples, path, num_frames // fps, fps, args, k) 116 | 117 | 118 | def broad_cast_batch(batch): 119 | mp_size = mpu.get_model_parallel_world_size() 120 | global_rank = torch.distributed.get_rank() // mp_size 121 | src = global_rank * mp_size 122 | 123 | if batch["mp4"] is not None: 124 | broadcast_shape = [batch["mp4"].shape, batch["fps"].shape, batch["num_frames"].shape, batch["face_image"].shape] 125 | else: 126 | broadcast_shape = None 127 | 128 | txt = [batch["txt"], broadcast_shape] 129 | 
torch.distributed.broadcast_object_list(txt, src=src, group=mpu.get_model_parallel_group()) 130 | batch["txt"] = txt[0] 131 | 132 | mp4_shape = txt[1][0] 133 | fps_shape = txt[1][1] 134 | num_frames_shape = txt[1][2] 135 | face_image_shape = txt[1][3] 136 | 137 | if mpu.get_model_parallel_rank() != 0: 138 | batch["mp4"] = torch.zeros(mp4_shape, device="cuda") 139 | batch["fps"] = torch.zeros(fps_shape, device="cuda", dtype=torch.long) 140 | batch["num_frames"] = torch.zeros(num_frames_shape, device="cuda", dtype=torch.long) 141 | batch["face_image"] = torch.zeros(face_image_shape, device="cuda") 142 | 143 | torch.distributed.broadcast(batch["mp4"], src=src, group=mpu.get_model_parallel_group()) 144 | torch.distributed.broadcast(batch["fps"], src=src, group=mpu.get_model_parallel_group()) 145 | torch.distributed.broadcast(batch["num_frames"], src=src, group=mpu.get_model_parallel_group()) 146 | torch.distributed.broadcast(batch["face_image"], src=src, group=mpu.get_model_parallel_group()) 147 | return batch 148 | 149 | 150 | def forward_step_eval(data_iterator, model, args, timers, only_log_video_latents=False, data_class=None): 151 | if mpu.get_model_parallel_rank() == 0: 152 | timers("data loader").start() 153 | batch_video = next(data_iterator) 154 | timers("data loader").stop() 155 | 156 | if len(batch_video["mp4"].shape) == 6: 157 | b, v = batch_video["mp4"].shape[:2] 158 | batch_video["mp4"] = batch_video["mp4"].view(-1, *batch_video["mp4"].shape[2:]) 159 | batch_video["face_image"] = batch_video["face_image"].view(-1, *batch_video["face_image"].shape[2:]) 160 | txt = [] 161 | for i in range(b): 162 | for j in range(v): 163 | txt.append(batch_video["txt"][j][i]) 164 | batch_video["txt"] = txt 165 | 166 | for key in batch_video: 167 | if isinstance(batch_video[key], torch.Tensor): 168 | batch_video[key] = batch_video[key].cuda() 169 | else: 170 | batch_video = {"mp4": None, "fps": None, "num_frames": None, "txt": None, "face_image": None} 171 | broad_cast_batch(batch_video) 172 | if mpu.get_data_parallel_rank() == 0: 173 | log_video(batch_video, model, args, only_log_video_latents=only_log_video_latents) 174 | 175 | batch_video["global_step"] = args.iteration 176 | loss, loss_dict = model.shared_step(batch_video) 177 | for k in loss_dict: 178 | if loss_dict[k].dtype == torch.bfloat16: 179 | loss_dict[k] = loss_dict[k].to(torch.float32) 180 | return loss, loss_dict 181 | 182 | 183 | def forward_step(data_iterator, model, args, timers, data_class=None): 184 | if mpu.get_model_parallel_rank() == 0: 185 | timers("data loader").start() 186 | batch = next(data_iterator) 187 | timers("data loader").stop() 188 | for key in batch: 189 | if isinstance(batch[key], torch.Tensor): 190 | batch[key] = batch[key].cuda() 191 | 192 | if torch.distributed.get_rank() == 0: 193 | if not os.path.exists(os.path.join(args.save, "training_config.yaml")): 194 | configs = [OmegaConf.load(cfg) for cfg in args.base] 195 | config = OmegaConf.merge(*configs) 196 | os.makedirs(args.save, exist_ok=True) 197 | OmegaConf.save(config=config, f=os.path.join(args.save, "training_config.yaml")) 198 | else: 199 | batch = {"mp4": None, "fps": None, "num_frames": None, "txt": None, "face_image": None} 200 | 201 | batch["global_step"] = args.iteration 202 | 203 | broad_cast_batch(batch) 204 | 205 | loss, loss_dict = model.shared_step(batch) 206 | 207 | return loss, loss_dict 208 | 209 | 210 | if __name__ == "__main__": 211 | if "OMPI_COMM_WORLD_LOCAL_RANK" in os.environ: 212 | os.environ["LOCAL_RANK"] = 
os.environ["OMPI_COMM_WORLD_LOCAL_RANK"] 213 | os.environ["WORLD_SIZE"] = os.environ["OMPI_COMM_WORLD_SIZE"] 214 | os.environ["RANK"] = os.environ["OMPI_COMM_WORLD_RANK"] 215 | 216 | py_parser = argparse.ArgumentParser(add_help=False) 217 | known, args_list = py_parser.parse_known_args() 218 | args = get_args(args_list) 219 | # print(f'args: {args_list}') 220 | # xx 221 | args = argparse.Namespace(**vars(args), **vars(known)) 222 | 223 | data_class = get_obj_from_str(args.data_config["target"]) 224 | create_dataset_function = partial(data_class.create_dataset_function, **args.data_config["params"]) 225 | 226 | import yaml 227 | 228 | configs = [] 229 | for config in args.base: 230 | with open(config, "r") as f: 231 | base_config = yaml.safe_load(f) 232 | configs.append(base_config) 233 | args.log_config = configs 234 | 235 | training_main( 236 | args, 237 | model_cls=SATVideoDiffusionEngine, 238 | forward_step_function=partial(forward_step, data_class=data_class), 239 | forward_step_eval=partial( 240 | forward_step_eval, data_class=data_class, only_log_video_latents=args.only_log_video_latents 241 | ), 242 | create_dataset_function=create_dataset_function, 243 | ) 244 | -------------------------------------------------------------------------------- /vae_modules/__pycache__/autoencoder.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ML-GSAI/Concat-ID/b15cb95af7efd1c557ee850f5adb116d8feae8b5/vae_modules/__pycache__/autoencoder.cpython-310.pyc -------------------------------------------------------------------------------- /vae_modules/__pycache__/cp_enc_dec.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ML-GSAI/Concat-ID/b15cb95af7efd1c557ee850f5adb116d8feae8b5/vae_modules/__pycache__/cp_enc_dec.cpython-310.pyc -------------------------------------------------------------------------------- /vae_modules/__pycache__/ema.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ML-GSAI/Concat-ID/b15cb95af7efd1c557ee850f5adb116d8feae8b5/vae_modules/__pycache__/ema.cpython-310.pyc -------------------------------------------------------------------------------- /vae_modules/__pycache__/regularizers.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ML-GSAI/Concat-ID/b15cb95af7efd1c557ee850f5adb116d8feae8b5/vae_modules/__pycache__/regularizers.cpython-310.pyc -------------------------------------------------------------------------------- /vae_modules/__pycache__/utils.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ML-GSAI/Concat-ID/b15cb95af7efd1c557ee850f5adb116d8feae8b5/vae_modules/__pycache__/utils.cpython-310.pyc -------------------------------------------------------------------------------- /vae_modules/ema.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class LitEma(nn.Module): 6 | def __init__(self, model, decay=0.9999, use_num_upates=True): 7 | super().__init__() 8 | if decay < 0.0 or decay > 1.0: 9 | raise ValueError("Decay must be between 0 and 1") 10 | 11 | self.m_name2s_name = {} 12 | self.register_buffer("decay", torch.tensor(decay, dtype=torch.float32)) 13 | self.register_buffer( 14 | 
"num_updates", 15 | torch.tensor(0, dtype=torch.int) if use_num_upates else torch.tensor(-1, dtype=torch.int), 16 | ) 17 | 18 | for name, p in model.named_parameters(): 19 | if p.requires_grad: 20 | # remove as '.'-character is not allowed in buffers 21 | s_name = name.replace(".", "") 22 | self.m_name2s_name.update({name: s_name}) 23 | self.register_buffer(s_name, p.clone().detach().data) 24 | 25 | self.collected_params = [] 26 | 27 | def reset_num_updates(self): 28 | del self.num_updates 29 | self.register_buffer("num_updates", torch.tensor(0, dtype=torch.int)) 30 | 31 | def forward(self, model): 32 | decay = self.decay 33 | 34 | if self.num_updates >= 0: 35 | self.num_updates += 1 36 | decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates)) 37 | 38 | one_minus_decay = 1.0 - decay 39 | 40 | with torch.no_grad(): 41 | m_param = dict(model.named_parameters()) 42 | shadow_params = dict(self.named_buffers()) 43 | 44 | for key in m_param: 45 | if m_param[key].requires_grad: 46 | sname = self.m_name2s_name[key] 47 | shadow_params[sname] = shadow_params[sname].type_as(m_param[key]) 48 | shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key])) 49 | else: 50 | assert not key in self.m_name2s_name 51 | 52 | def copy_to(self, model): 53 | m_param = dict(model.named_parameters()) 54 | shadow_params = dict(self.named_buffers()) 55 | for key in m_param: 56 | if m_param[key].requires_grad: 57 | m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data) 58 | else: 59 | assert not key in self.m_name2s_name 60 | 61 | def store(self, parameters): 62 | """ 63 | Save the current parameters for restoring later. 64 | Args: 65 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 66 | temporarily stored. 67 | """ 68 | self.collected_params = [param.clone() for param in parameters] 69 | 70 | def restore(self, parameters): 71 | """ 72 | Restore the parameters stored with the `store` method. 73 | Useful to validate the model with EMA parameters without affecting the 74 | original optimization process. Store the parameters before the 75 | `copy_to` method. After validation (or model saving), use this to 76 | restore the former parameters. 77 | Args: 78 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 79 | updated with the stored parameters. 
80 | """ 81 | for c_param, param in zip(self.collected_params, parameters): 82 | param.data.copy_(c_param.data) 83 | -------------------------------------------------------------------------------- /vae_modules/regularizers.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | from typing import Any, Tuple 3 | 4 | import numpy as np 5 | import torch 6 | import torch.nn.functional as F 7 | from torch import nn 8 | 9 | 10 | class DiagonalGaussianDistribution(object): 11 | def __init__(self, parameters, deterministic=False): 12 | self.parameters = parameters 13 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) 14 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0) 15 | self.deterministic = deterministic 16 | self.std = torch.exp(0.5 * self.logvar) 17 | self.var = torch.exp(self.logvar) 18 | if self.deterministic: 19 | self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device) 20 | 21 | def sample(self): 22 | # x = self.mean + self.std * torch.randn(self.mean.shape).to( 23 | # device=self.parameters.device 24 | # ) 25 | x = self.mean + self.std * torch.randn_like(self.mean) 26 | return x 27 | 28 | def kl(self, other=None): 29 | if self.deterministic: 30 | return torch.Tensor([0.0]) 31 | else: 32 | if other is None: 33 | return 0.5 * torch.sum( 34 | torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, 35 | dim=[1, 2, 3], 36 | ) 37 | else: 38 | return 0.5 * torch.sum( 39 | torch.pow(self.mean - other.mean, 2) / other.var 40 | + self.var / other.var 41 | - 1.0 42 | - self.logvar 43 | + other.logvar, 44 | dim=[1, 2, 3], 45 | ) 46 | 47 | def nll(self, sample, dims=[1, 2, 3]): 48 | if self.deterministic: 49 | return torch.Tensor([0.0]) 50 | logtwopi = np.log(2.0 * np.pi) 51 | return 0.5 * torch.sum( 52 | logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, 53 | dim=dims, 54 | ) 55 | 56 | def mode(self): 57 | return self.mean 58 | 59 | 60 | class AbstractRegularizer(nn.Module): 61 | def __init__(self): 62 | super().__init__() 63 | 64 | def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]: 65 | raise NotImplementedError() 66 | 67 | @abstractmethod 68 | def get_trainable_parameters(self) -> Any: 69 | raise NotImplementedError() 70 | 71 | 72 | class IdentityRegularizer(AbstractRegularizer): 73 | def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]: 74 | return z, dict() 75 | 76 | def get_trainable_parameters(self) -> Any: 77 | yield from () 78 | 79 | 80 | def measure_perplexity(predicted_indices: torch.Tensor, num_centroids: int) -> Tuple[torch.Tensor, torch.Tensor]: 81 | # src: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py 82 | # eval cluster perplexity. 
when perplexity == num_embeddings then all clusters are used exactly equally 83 | encodings = F.one_hot(predicted_indices, num_centroids).float().reshape(-1, num_centroids) 84 | avg_probs = encodings.mean(0) 85 | perplexity = (-(avg_probs * torch.log(avg_probs + 1e-10)).sum()).exp() 86 | cluster_use = torch.sum(avg_probs > 0) 87 | return perplexity, cluster_use 88 | 89 | 90 | class DiagonalGaussianRegularizer(AbstractRegularizer): 91 | def __init__(self, sample: bool = True): 92 | super().__init__() 93 | self.sample = sample 94 | 95 | def get_trainable_parameters(self) -> Any: 96 | yield from () 97 | 98 | def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]: 99 | log = dict() 100 | posterior = DiagonalGaussianDistribution(z) 101 | if self.sample: 102 | z = posterior.sample() 103 | else: 104 | z = posterior.mode() 105 | kl_loss = posterior.kl() 106 | kl_loss = torch.sum(kl_loss) / kl_loss.shape[0] 107 | log["kl_loss"] = kl_loss 108 | return z, log 109 | --------------------------------------------------------------------------------
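
Appended note (not part of the repository): a minimal usage sketch for the DiagonalGaussianRegularizer defined in vae_modules/regularizers.py above. The toy encoder, its channel counts, the random input, and the loss weighting below are illustrative assumptions only, not code from this repo.

import torch
from torch import nn

from vae_modules.regularizers import DiagonalGaussianRegularizer

# Toy encoder (illustrative): it emits 2 * latent_channels feature maps so that
# DiagonalGaussianDistribution can split them into a mean and a log-variance
# along dim=1, matching torch.chunk(parameters, 2, dim=1) in the code above.
latent_channels = 4
encoder = nn.Conv2d(3, 2 * latent_channels, kernel_size=3, padding=1)
regularizer = DiagonalGaussianRegularizer(sample=True)

x = torch.randn(2, 3, 32, 32)        # (batch, channels, height, width), random demo input
moments = encoder(x)                 # (2, 8, 32, 32): mean and logvar stacked on dim 1
z, log = regularizer(moments)        # z: (2, 4, 32, 32), a latent sampled from the posterior
kl_loss = log["kl_loss"]             # scalar KL penalty, averaged over the batch
recon_like_loss = z.pow(2).mean()    # placeholder for a real reconstruction loss
total_loss = recon_like_loss + 1e-6 * kl_loss

The returned kl_loss follows the implementation above: the per-sample KL terms are summed over the channel and spatial dimensions (dim=[1, 2, 3]) and then divided by the batch size, so the regularizer hands back a single scalar that can be weighted and added to a reconstruction objective.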