├── README.md
├── assets
│   ├── selection.png
│   ├── teaser1024.png
│   ├── teaser256.png
│   └── teaser512.png
├── benchmark_generation_mamba_simple.py
├── benchmarks
│   └── benchmark_generation_mamba_simple.py
├── configs
│   ├── cifar10_S_DiM.py
│   ├── imagenet256_H_DiM.py
│   ├── imagenet256_L_DiM.py
│   ├── imagenet512_H_DiM_ft.py
│   ├── imagenet512_H_DiM_upsample_3x_test.py
│   └── imagenet512_H_DiM_upsample_test.py
├── csrc
│   └── selective_scan
│       ├── reverse_scan.cuh
│       ├── selective_scan.cpp
│       ├── selective_scan.h
│       ├── selective_scan_bwd_bf16_complex.cu
│       ├── selective_scan_bwd_bf16_real.cu
│       ├── selective_scan_bwd_fp16_complex.cu
│       ├── selective_scan_bwd_fp16_real.cu
│       ├── selective_scan_bwd_fp32_complex.cu
│       ├── selective_scan_bwd_fp32_real.cu
│       ├── selective_scan_bwd_kernel.cuh
│       ├── selective_scan_common.h
│       ├── selective_scan_fwd_bf16.cu
│       ├── selective_scan_fwd_fp16.cu
│       ├── selective_scan_fwd_fp32.cu
│       ├── selective_scan_fwd_kernel.cuh
│       ├── static_switch.h
│       └── uninitialized_copy.cuh
├── dpm_solver_pp.py
├── dpm_solver_pytorch.py
├── environment.yaml
├── eval.py
├── eval_ldm.py
├── eval_ldm_discrete.py
├── eval_t2i_discrete.py
├── evals
│   └── lm_harness_eval.py
├── libs
│   ├── __init__.py
│   ├── autoencoder.py
│   ├── clip.py
│   ├── timm.py
│   ├── uvit.py
│   └── uvit_t2i.py
├── main.pdf
├── main.png
├── mamba_attn_diff
│   ├── models
│   │   ├── __init__.py
│   │   ├── adapter_attn4mamba.py
│   │   ├── attention.py
│   │   ├── freeu.py
│   │   ├── mamba_2d.py
│   │   ├── normalization.py
│   │   ├── upsample_guidance.py
│   │   └── vim_module.py
│   └── utils
│       ├── backup_code.py
│       └── init_weights.py
├── mamba_ssm
│   ├── __init__.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── config_mamba.py
│   │   └── mixer_seq_simple.py
│   ├── modules
│   │   ├── __init__.py
│   │   └── mamba_simple.py
│   ├── ops
│   │   ├── __init__.py
│   │   ├── selective_scan_interface.py
│   │   └── triton
│   │       ├── __init__.py
│   │       ├── layernorm.py
│   │       └── selective_state_update.py
│   └── utils
│       ├── __init__.py
│       ├── generation.py
│       └── hf.py
├── sample_t2i_discrete.py
├── scripts
│   ├── extract_empty_feature.py
│   ├── extract_imagenet_feature.py
│   ├── extract_mscoco_feature.py
│   └── extract_test_prompt_feature.py
├── sde.py
├── setup.py
├── tools
│   ├── __init__.py
│   ├── fid_score.py
│   └── inception.py
├── train.py
├── train_ldm.py
├── train_ldm_discrete.py
├── train_t2i_discrete.py
├── utils.py
└── uvit_datasets.py

/README.md:
--------------------------------------------------------------------------------
# DiM: Diffusion Mamba for Efficient High-Resolution Image Synthesis

The official implementation of our paper [DiM: Diffusion Mamba for Efficient High-Resolution Image Synthesis](https://arxiv.org/abs/2405.14224).

![drawing](assets/teaser256.png)

![drawing](assets/teaser512.png)

![drawing](assets/teaser1024.png)

## Method Overview

![drawing](assets/selection.png)

## Acknowledgements

This code is mainly built on [U-ViT](https://github.com/baofff/U-ViT) and [Mamba](https://github.com/state-spaces/mamba).

Installing Mamba may take considerable effort. If you encounter problems, the [issues in the Mamba repository](https://github.com/state-spaces/mamba/issues) may be very helpful.

## Installation

```bash
# Create the env:
conda env create -f environment.yaml

# If you want to update the env `mamba` with the contents of `~/mamba_attn/environment.yaml`:
conda env update --name mamba --file ~/mamba_attn/environment.yaml --prune

# Switch to the correct environment
conda activate mamba-attn
conda install chardet

# Compile Mamba. This step may take a long time; please be patient.
# You need to successfully install causal-conv1d first.
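# For example (the exact version may differ on your setup; see the FAQ below):
#   pip install causal-conv1d==1.2.0.post2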
CAUSAL_CONV1D_FORCE_BUILD=TRUE pip install --user -e .
# If compilation fails, you can copy the files in './build/' from another server where compilation succeeded; the --user flag may be necessary.

# Optional: if you only have 8 A100s to train the Huge model with a batch size of 768, I recommend installing DeepSpeed to reduce the required GPU memory:
pip install deepspeed
```

**Frequently Asked Questions:**

- If you encounter errors like `ModuleNotFoundError: No module named 'selective_scan_cuda'`:

  **Answer**: You need to correctly **install and compile** Mamba:

  ```bash
  pip install causal-conv1d==1.2.0.post2  # The version may differ depending on your CUDA version
  CAUSAL_CONV1D_FORCE_BUILD=TRUE pip install --user -e .
  ```

- Failed compilation:

  - The detected CUDA version mismatches the version that was used to **compile** PyTorch. Please make sure to use the same CUDA versions:

    **Answer**: You need to reinstall PyTorch with the matching version:

    ```bash
    # For example, on CUDA 11.8:
    conda install pytorch==2.1.1 torchvision==0.16.1 torchaudio==2.1.1 pytorch-cuda=11.8 -c pytorch -c nvidia
    # Then compile the Mamba in this project again:
    CAUSAL_CONV1D_FORCE_BUILD=TRUE pip install --user -e .
    ```

## Preparation Before Training and Evaluation

Please follow [U-ViT](https://github.com/baofff/U-ViT) (the section with the same title there).

## Checkpoints

| Model | FID | Training iterations | Batch size |
| :----------------------------------------------------------: | :------: | :-----------------: | :--------: |
| [ImageNet 256x256 (Huge/2)](https://drive.google.com/drive/folders/1TTEXKKhnJcEV9jeZbZYlXjiPyV87ZhE0?usp=sharing) | 2.40 | 425K | 768 |
| [ImageNet 256x256 (Huge/2)](https://drive.google.com/drive/folders/1ETllUm8Dpd8-vDHefQEXEWF9whdbyhL5?usp=sharing) | **2.21** | 625K | 768 |
| [ImageNet 512x512 (fine-tuned Huge/2)](https://drive.google.com/drive/folders/1lupf4_dj4tWCpycnraGrgqh4P-6yK5Xe?usp=sharing) | 3.94 | Fine-tune | 240 |

**About the checkpoint files:**

- **We use `nnet_ema.pth` for evaluation instead of `nnet.pth`.**

- **`nnet.pth` is the trained model, while `nnet_ema.pth` is the EMA of the model weights.**

## Evaluation

**Use `eval_ldm_discrete.py` for evaluation and for generating images with CFG.**

```sh
# ImageNet 256x256 Huge, 425K
# If your model checkpoint path is not 'workdir/imagenet256_H_DiM/default/ckpts/425000.ckpt/nnet_ema.pth', you can change the path after '--nnet_path='
accelerate launch --multi_gpu --gpu_ids 0,1,2,3,4,5,6,7 --main_process_port 20039 --num_processes 8 --mixed_precision bf16 ./eval_ldm_discrete.py --config=configs/imagenet256_H_DiM.py --nnet_path='workdir/imagenet256_H_DiM/default/ckpts/425000.ckpt/nnet_ema.pth'

# ImageNet 512x512 Huge
# The generated 512x512 images for evaluation take ~22 GB of disk space,
# so I recommend setting `config.sample.path` in the config `imagenet512_H_DiM_ft` if space for temporary files is tight.
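# For example (hypothetical path -- use any disk with enough free space), edit
# configs/imagenet512_H_DiM_ft.py and set path='/mnt/large_disk/dim512_tmp' inside config.sample.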
accelerate launch --multi_gpu --gpu_ids 0,1,2,3,4,5,6,7 --main_process_port 20039 --num_processes 8 --mixed_precision bf16 ./eval_ldm_discrete.py --config=configs/imagenet512_H_DiM_ft.py --nnet_path='workdir/imagenet512_H_DiM_ft/default/ckpts/64000.ckpt/nnet_ema.pth'

# ImageNet 512x512 Huge, upsample 2x; the generated images go to `workdir/imagenet512_H_DiM_ft/test_tmp`, which is set in the config.
accelerate launch --multi_gpu --gpu_ids 0,1,2,3,4,5,6,7 --main_process_port 20039 --num_processes 8 --mixed_precision bf16 ./eval_ldm_discrete.py --config=configs/imagenet512_H_DiM_upsample_test.py --nnet_path='workdir/imagenet512_H_DiM_ft/default/ckpts/64000.ckpt/nnet_ema.pth'

# ImageNet 512x512 Huge, upsample 3x; the generated images go to `workdir/imagenet512_H_DiM_ft/test_tmp`, which is set in the config.
accelerate launch --multi_gpu --gpu_ids 0,1,2,3,4,5,6,7 --main_process_port 20039 --num_processes 8 --mixed_precision bf16 ./eval_ldm_discrete.py --config=configs/imagenet512_H_DiM_upsample_3x_test.py --nnet_path='workdir/imagenet512_H_DiM_ft/default/ckpts/64000.ckpt/nnet_ema.pth'
```

## Training

```sh
# CIFAR-10 32x32 Small
accelerate launch --multi_gpu --num_processes 8 --mixed_precision fp16 ./train.py --config=configs/cifar10_S_DiM.py

# ImageNet 256x256 Large
accelerate launch --multi_gpu --num_processes 8 --mixed_precision bf16 ./train_ldm_discrete.py --config=configs/imagenet256_L_DiM.py

# ImageNet 256x256 Huge (DeepSpeed ZeRO-2 for memory-efficient training)
accelerate launch --multi_gpu --num_processes 8 --mixed_precision bf16 ./train_ldm_discrete.py --config=configs/imagenet256_H_DiM.py

# ImageNet 512x512 Huge (DeepSpeed ZeRO-2 for memory-efficient training)
# Fine-tuning: carefully check that the pre-trained weights are at
# `workdir/imagenet256_H_DiM/default/ckpts/425000.ckpt/nnet_ema.pth`.
# This location is set in the config file: `config.nnet.pretrained_path`.
# If no such checkpoint exists, no pre-trained weights will be loaded.
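# A quick sanity check before launching (path taken from config.nnet.pretrained_path):
#   ls workdir/imagenet256_H_DiM/default/ckpts/425000.ckpt/nnet_ema.pth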
accelerate launch --multi_gpu --num_processes 8 --mixed_precision bf16 ./train_ldm_discrete.py --config=configs/imagenet512_H_DiM_ft.py
```

## Citation

```
@misc{teng2024dim,
      title={DiM: Diffusion Mamba for Efficient High-Resolution Image Synthesis},
      author={Yao Teng and Yue Wu and Han Shi and Xuefei Ning and Guohao Dai and Yu Wang and Zhenguo Li and Xihui Liu},
      year={2024},
      eprint={2405.14224},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}
```

--------------------------------------------------------------------------------
/assets/selection.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tyshiwo1/DiM-DiffusionMamba/292b2285de2f979cf5ac84ab90cbee20cdf1f2f3/assets/selection.png
--------------------------------------------------------------------------------
/assets/teaser1024.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tyshiwo1/DiM-DiffusionMamba/292b2285de2f979cf5ac84ab90cbee20cdf1f2f3/assets/teaser1024.png
--------------------------------------------------------------------------------
/assets/teaser256.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tyshiwo1/DiM-DiffusionMamba/292b2285de2f979cf5ac84ab90cbee20cdf1f2f3/assets/teaser256.png
--------------------------------------------------------------------------------
/assets/teaser512.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tyshiwo1/DiM-DiffusionMamba/292b2285de2f979cf5ac84ab90cbee20cdf1f2f3/assets/teaser512.png
--------------------------------------------------------------------------------
/benchmark_generation_mamba_simple.py:
--------------------------------------------------------------------------------
# Copyright (c) 2023, Tri Dao, Albert Gu.
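# Example invocation (a sketch; all flags are the argparse options defined below):
#   python benchmark_generation_mamba_simple.py --model-name state-spaces/mamba-130m \
#       --promptlen 128 --genlen 128 --batch 8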

import argparse
import time
import json

import torch
import torch.nn.functional as F

from einops import rearrange

from transformers import AutoTokenizer, AutoModelForCausalLM

from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel


parser = argparse.ArgumentParser(description="Generation benchmarking")
parser.add_argument("--model-name", type=str, default="state-spaces/mamba-130m")
parser.add_argument("--prompt", type=str, default=None)
parser.add_argument("--promptlen", type=int, default=100)
parser.add_argument("--genlen", type=int, default=100)
parser.add_argument("--temperature", type=float, default=1.0)
parser.add_argument("--topk", type=int, default=1)
parser.add_argument("--topp", type=float, default=1.0)
parser.add_argument("--repetition-penalty", type=float, default=1.0)
parser.add_argument("--batch", type=int, default=1)
parser.add_argument("--cache_dir", type=str, default="./checkpoints")
args = parser.parse_args()

repeats = 3
device = "cuda"
dtype = torch.float16

cache_dir = args.cache_dir

print(f"Loading model {args.model_name}")
is_mamba = args.model_name.startswith("state-spaces/mamba-")
if is_mamba:
    tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b", cache_dir=cache_dir)
    model = MambaLMHeadModel.from_pretrained(args.model_name, device=device, dtype=dtype, cache_dir=cache_dir)
else:
    tokenizer = AutoTokenizer.from_pretrained(args.model_name, cache_dir=cache_dir)
    model = AutoModelForCausalLM.from_pretrained(args.model_name, device_map={"": device}, torch_dtype=dtype, cache_dir=cache_dir)
model.eval()
print(f"Number of parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")

torch.random.manual_seed(0)
if args.prompt is None:
    input_ids = torch.randint(1, 1000, (args.batch, args.promptlen), dtype=torch.long, device="cuda")
    attn_mask = torch.ones_like(input_ids, dtype=torch.long, device="cuda")
else:
    tokens = tokenizer(args.prompt, return_tensors="pt")
    input_ids = tokens.input_ids.to(device=device)
    attn_mask = tokens.attention_mask.to(device=device)
max_length = input_ids.shape[1] + args.genlen

# for i, n in model.named_parameters():
#     print(i, n.shape, n.dtype)
# print(model)

if is_mamba:
    fn = lambda: model.generate(
        input_ids=input_ids,
        max_length=max_length,
        cg=True,
        return_dict_in_generate=True,
        output_scores=True,
        enable_timing=False,
        temperature=args.temperature,
        top_k=args.topk,
        top_p=args.topp,
        repetition_penalty=args.repetition_penalty,
    )
else:
    fn = lambda: model.generate(
        input_ids=input_ids,
        attention_mask=attn_mask,
        max_length=max_length,
        return_dict_in_generate=True,
        pad_token_id=tokenizer.eos_token_id,
        do_sample=True,
        temperature=args.temperature,
        top_k=args.topk,
        top_p=args.topp,
        repetition_penalty=args.repetition_penalty,
    )
out = fn()
if args.prompt is not None:
    print(tokenizer.batch_decode(out.sequences.tolist()))

torch.cuda.synchronize()
start = time.time()
for _ in range(repeats):
    fn()
torch.cuda.synchronize()
print(f"Prompt length: {len(input_ids[0])}, generation length: {len(out.sequences[0]) - len(input_ids[0])}")
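# The time reported below is the mean over `repeats` end-to-end generate() calls,
# so it includes both prompt processing and token-by-token decoding.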
print(f"{args.model_name} prompt processing + decoding time: {(time.time() - start) / repeats * 1000:.0f}ms") 99 | 100 | ''' 101 | # 130mb 102 | MambaLMHeadModel( 103 | (backbone): MixerModel( 104 | (embedding): Embedding(50280, 768) 105 | (layers): ModuleList( 106 | (0-23): 24 x Block( 107 | (mixer): Mamba( 108 | (in_proj): Linear(in_features=768, out_features=3072, bias=False) 109 | (conv1d): Conv1d(1536, 1536, kernel_size=(4,), stride=(1,), padding=(3,), groups=1536) 110 | (act): SiLU() 111 | (x_proj): Linear(in_features=1536, out_features=80, bias=False) 112 | (dt_proj): Linear(in_features=48, out_features=1536, bias=True) 113 | (out_proj): Linear(in_features=1536, out_features=768, bias=False) 114 | ) 115 | (norm): RMSNorm() 116 | ) 117 | ) 118 | (norm_f): RMSNorm() 119 | ) 120 | (lm_head): Linear(in_features=768, out_features=50280, bias=False) 121 | ) 122 | 123 | ''' -------------------------------------------------------------------------------- /benchmarks/benchmark_generation_mamba_simple.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023, Tri Dao, Albert Gu. 2 | 3 | import argparse 4 | import time 5 | import json 6 | 7 | import torch 8 | import torch.nn.functional as F 9 | 10 | from einops import rearrange 11 | 12 | from transformers import AutoTokenizer, AutoModelForCausalLM 13 | 14 | from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel 15 | 16 | 17 | parser = argparse.ArgumentParser(description="Generation benchmarking") 18 | parser.add_argument("--model-name", type=str, default="state-spaces/mamba-130m") 19 | parser.add_argument("--prompt", type=str, default=None) 20 | parser.add_argument("--promptlen", type=int, default=100) 21 | parser.add_argument("--genlen", type=int, default=100) 22 | parser.add_argument("--temperature", type=float, default=1.0) 23 | parser.add_argument("--topk", type=int, default=1) 24 | parser.add_argument("--topp", type=float, default=1.0) 25 | parser.add_argument("--minp", type=float, default=0.0) 26 | parser.add_argument("--repetition-penalty", type=float, default=1.0) 27 | parser.add_argument("--batch", type=int, default=1) 28 | args = parser.parse_args() 29 | 30 | repeats = 3 31 | device = "cuda" 32 | dtype = torch.float16 33 | 34 | print(f"Loading model {args.model_name}") 35 | is_mamba = args.model_name.startswith("state-spaces/mamba-") 36 | if is_mamba: 37 | tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b") 38 | model = MambaLMHeadModel.from_pretrained(args.model_name, device=device, dtype=dtype) 39 | else: 40 | tokenizer = AutoTokenizer.from_pretrained(args.model_name) 41 | model = AutoModelForCausalLM.from_pretrained(args.model_name, device_map={"": device}, torch_dtype=dtype) 42 | model.eval() 43 | print(f"Number of parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad)}") 44 | 45 | torch.random.manual_seed(0) 46 | if args.prompt is None: 47 | input_ids = torch.randint(1, 1000, (args.batch, args.promptlen), dtype=torch.long, device="cuda") 48 | attn_mask = torch.ones_like(input_ids, dtype=torch.long, device="cuda") 49 | else: 50 | tokens = tokenizer(args.prompt, return_tensors="pt") 51 | input_ids = tokens.input_ids.to(device=device) 52 | attn_mask = tokens.attention_mask.to(device=device) 53 | max_length = input_ids.shape[1] + args.genlen 54 | 55 | if is_mamba: 56 | fn = lambda: model.generate( 57 | input_ids=input_ids, 58 | max_length=max_length, 59 | cg=True, 60 | return_dict_in_generate=True, 61 | output_scores=True, 62 | 
enable_timing=False, 63 | temperature=args.temperature, 64 | top_k=args.topk, 65 | top_p=args.topp, 66 | min_p=args.minp, 67 | repetition_penalty=args.repetition_penalty, 68 | ) 69 | else: 70 | fn = lambda: model.generate( 71 | input_ids=input_ids, 72 | attention_mask=attn_mask, 73 | max_length=max_length, 74 | return_dict_in_generate=True, 75 | pad_token_id=tokenizer.eos_token_id, 76 | do_sample=True, 77 | temperature=args.temperature, 78 | top_k=args.topk, 79 | top_p=args.topp, 80 | repetition_penalty=args.repetition_penalty, 81 | ) 82 | out = fn() 83 | if args.prompt is not None: 84 | print(tokenizer.batch_decode(out.sequences.tolist())) 85 | 86 | torch.cuda.synchronize() 87 | start = time.time() 88 | for _ in range(repeats): 89 | fn() 90 | torch.cuda.synchronize() 91 | print(f"Prompt length: {len(input_ids[0])}, generation length: {len(out.sequences[0]) - len(input_ids[0])}") 92 | print(f"{args.model_name} prompt processing + decoding time: {(time.time() - start) / repeats * 1000:.0f}ms") 93 | -------------------------------------------------------------------------------- /configs/cifar10_S_DiM.py: -------------------------------------------------------------------------------- 1 | import ml_collections 2 | 3 | 4 | def d(**kwargs): 5 | """Helper of creating a config dict.""" 6 | return ml_collections.ConfigDict(initial_dictionary=kwargs) 7 | 8 | 9 | def get_config(): 10 | config = ml_collections.ConfigDict() 11 | 12 | config.seed = 1234 13 | config.pred = 'noise_pred' 14 | 15 | n_steps = 500000 16 | config.train = d( 17 | n_steps=n_steps, 18 | batch_size=128, 19 | mode='uncond', 20 | log_interval=10, 21 | eval_interval=5000, 22 | save_interval=50000, 23 | ) 24 | 25 | lr = 0.0002 26 | config.optimizer = d( 27 | name='adamw', 28 | lr=lr, 29 | weight_decay=0.03, 30 | betas=(0.99, 0.999), 31 | ) 32 | 33 | config.lr_scheduler = d( 34 | name='customized', 35 | warmup_steps=2500, 36 | ) 37 | 38 | learned_sigma = False 39 | latent_size = 32 40 | in_channels = 3 41 | config.nnet = d( 42 | name='Mamba_DiT_S_2', 43 | attention_head_dim=512//1, num_attention_heads=1, num_layers=25, 44 | in_channels=in_channels, 45 | num_embeds_ada_norm=10, 46 | sample_size=latent_size, 47 | activation_fn="gelu-approximate", 48 | attention_bias=True, 49 | norm_elementwise_affine=False, 50 | norm_type="ada_norm_single", #"layer_norm" 51 | out_channels=in_channels*2 if learned_sigma else in_channels, 52 | patch_size=2, 53 | mamba_d_state=16, 54 | mamba_d_conv=3, 55 | mamba_expand=2, 56 | use_bidirectional_rnn=False, 57 | mamba_type='enc', 58 | nested_order=0, 59 | is_uconnect=True, 60 | no_ff=True, 61 | use_conv1d=True, 62 | is_extra_tokens=True, 63 | rms=True, 64 | use_pad_token=True, 65 | use_a4m_adapter=True, 66 | drop_path_rate=0.0, 67 | encoder_start_blk_id=1, #0 68 | kv_as_one_token_idx=-1, 69 | num_2d_enc_dec_layers=6, 70 | pad_token_schedules=['dec_split', 'lateral'], 71 | is_absorb=False, 72 | use_adapter_modules=True, 73 | sequence_schedule='dilated', 74 | sub_sequence_schedule=['reverse_single', 'layerwise_cross'], 75 | pos_encoding_type='learnable', 76 | scan_pattern_len=3, 77 | is_align_exchange_q_kv=False, 78 | ) 79 | 80 | config.dataset = d( 81 | name='cifar10', 82 | path='assets/datasets/cifar10', 83 | random_flip=True, 84 | ) 85 | 86 | config.sample = d( 87 | sample_steps=1000, 88 | n_samples=50000, 89 | mini_batch_size=500, 90 | algorithm='euler_maruyama_sde', 91 | path='' 92 | ) 93 | 94 | return config 95 | -------------------------------------------------------------------------------- 
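Each config module in `configs/` exposes a `get_config()` entry point consumed by the launch scripts; a minimal sketch of loading one directly (assuming the repo root is on `PYTHONPATH`):

```python
# Minimal sketch: load and inspect a config module outside train.py/eval.py.
from configs.cifar10_S_DiM import get_config

config = get_config()
print(config.nnet.name)         # 'Mamba_DiT_S_2'
print(config.train.batch_size)  # 128
print(config.sample.algorithm)  # 'euler_maruyama_sde'
```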
/configs/imagenet256_H_DiM.py: -------------------------------------------------------------------------------- 1 | import ml_collections 2 | 3 | 4 | def d(**kwargs): 5 | """Helper of creating a config dict.""" 6 | return ml_collections.ConfigDict(initial_dictionary=kwargs) 7 | 8 | 9 | def get_config(): 10 | config = ml_collections.ConfigDict() 11 | 12 | config.seed = 1234 13 | config.pred = 'noise_pred' 14 | config.z_shape = (4, 32, 32) 15 | 16 | config.autoencoder = d( 17 | pretrained_path='assets/stable-diffusion/autoencoder_kl_ema.pth' 18 | ) 19 | 20 | # config.gradient_accumulation_steps=2 # 1 21 | config.max_grad_norm = 1.0 22 | 23 | config.train = d( 24 | n_steps=750000, # 300000 25 | batch_size=768, 26 | mode='cond', 27 | log_interval=10, 28 | eval_interval=5000, 29 | save_interval=25000, # 50000 30 | ) 31 | 32 | config.optimizer = d( 33 | name='adamw', 34 | lr=0.0002, 35 | weight_decay=0.03, 36 | betas=(0.99, 0.99), 37 | eps=1e-15, 38 | ) 39 | 40 | config.lr_scheduler = d( 41 | name='customized', 42 | warmup_steps=5000, 43 | ) 44 | 45 | learned_sigma = False 46 | latent_size = 32 47 | in_channels = 4 # 3 48 | config.nnet = d( 49 | name='Mamba_DiT_H_2', 50 | attention_head_dim=1536//1, num_attention_heads=1, num_layers=49, 51 | in_channels=in_channels, 52 | num_embeds_ada_norm=1000, 53 | sample_size=latent_size, 54 | activation_fn="gelu-approximate", #"gelu-approximate", 55 | attention_bias=True, 56 | norm_elementwise_affine=False, 57 | norm_type="ada_norm_single", #"layer_norm", 58 | out_channels=in_channels*2 if learned_sigma else in_channels, 59 | patch_size=2, 60 | mamba_d_state=16, 61 | mamba_d_conv=3, 62 | mamba_expand=2, 63 | use_bidirectional_rnn=False, 64 | mamba_type='enc', 65 | nested_order=0, 66 | is_uconnect=True, 67 | no_ff=True, 68 | use_conv1d=True, 69 | is_extra_tokens=True, 70 | rms=True, 71 | use_pad_token=True, 72 | use_a4m_adapter=True, 73 | drop_path_rate=0.0, 74 | encoder_start_blk_id=1, 75 | kv_as_one_token_idx=-1, 76 | num_2d_enc_dec_layers=6, 77 | pad_token_schedules=['dec_split', 'lateral'], 78 | is_absorb=False, 79 | use_adapter_modules=True, 80 | sequence_schedule='dilated', 81 | sub_sequence_schedule=['reverse_single', 'layerwise_cross'], 82 | pos_encoding_type='learnable', 83 | scan_pattern_len=4 -1, 84 | is_align_exchange_q_kv=False, 85 | is_random_patterns=False, 86 | ) 87 | config.gradient_checkpointing = False 88 | 89 | config.dataset = d( 90 | name='imagenet', 91 | path='assets/datasets/ImageNet', 92 | resolution=256, 93 | cfg=True, 94 | p_uncond=0.1, 95 | ) 96 | 97 | config.sample = d( 98 | sample_steps=50, 99 | n_samples=50000, 100 | mini_batch_size=25, # the decoder is large 101 | algorithm='dpm_solver', 102 | cfg=True, 103 | scale=0.4, 104 | path='' 105 | ) 106 | 107 | return config 108 | -------------------------------------------------------------------------------- /configs/imagenet256_L_DiM.py: -------------------------------------------------------------------------------- 1 | import ml_collections 2 | 3 | 4 | def d(**kwargs): 5 | """Helper of creating a config dict.""" 6 | return ml_collections.ConfigDict(initial_dictionary=kwargs) 7 | 8 | 9 | def get_config(): 10 | config = ml_collections.ConfigDict() 11 | 12 | config.seed = 1234 13 | config.pred = 'noise_pred' 14 | config.z_shape = (4, 32, 32) 15 | 16 | config.autoencoder = d( 17 | pretrained_path='assets/stable-diffusion/autoencoder_kl_ema.pth' 18 | ) 19 | 20 | # config.gradient_accumulation_steps=2 # 1 21 | config.max_grad_norm = 1.0 22 | 23 | config.train = d( 24 | n_steps=300000, 
# 300000 25 | batch_size=1024, 26 | mode='cond', 27 | log_interval=10, 28 | eval_interval=5000, 29 | save_interval=25000, # 50000 30 | ) 31 | 32 | config.optimizer = d( 33 | name='adamw', 34 | lr=0.0002, 35 | weight_decay=0.03, 36 | betas=(0.99, 0.99), 37 | eps=1e-15, 38 | ) 39 | 40 | config.lr_scheduler = d( 41 | name='customized', 42 | warmup_steps=5000, 43 | ) 44 | 45 | learned_sigma = False 46 | latent_size = 32 47 | in_channels = 4 # 3 48 | config.nnet = d( 49 | name='Mamba_DiT_H_2', 50 | attention_head_dim=1024//16, num_attention_heads=16, num_layers=49, 51 | in_channels=in_channels, 52 | num_embeds_ada_norm=1000, 53 | sample_size=latent_size, 54 | activation_fn="gelu-approximate", #"gelu-approximate", 55 | attention_bias=True, 56 | norm_elementwise_affine=False, 57 | norm_type="ada_norm_single", #"layer_norm", 58 | out_channels=in_channels*2 if learned_sigma else in_channels, 59 | patch_size=2, 60 | mamba_d_state=16, 61 | mamba_d_conv=3, 62 | mamba_expand=2, 63 | use_bidirectional_rnn=False, 64 | mamba_type='enc', 65 | nested_order=0, 66 | is_uconnect=True, 67 | no_ff=True, 68 | use_conv1d=True, 69 | is_extra_tokens=True, 70 | rms=True, 71 | use_pad_token=True, 72 | use_a4m_adapter=True, 73 | drop_path_rate=0.0, 74 | encoder_start_blk_id=1, 75 | kv_as_one_token_idx=-1, 76 | num_2d_enc_dec_layers=6, 77 | pad_token_schedules=['dec_split', 'lateral'], 78 | is_absorb=False, 79 | use_adapter_modules=True, 80 | sequence_schedule='dilated', 81 | sub_sequence_schedule=['reverse_single', 'layerwise_cross'], 82 | pos_encoding_type='learnable', 83 | scan_pattern_len=4 -1, 84 | is_align_exchange_q_kv=False, 85 | is_random_patterns=False, 86 | ) 87 | config.gradient_checkpointing = False 88 | 89 | config.dataset = d( 90 | name='imagenet', 91 | path='assets/datasets/ImageNet', 92 | resolution=256, 93 | cfg=True, 94 | p_uncond=0.15, # aligned with u-vit 95 | ) 96 | 97 | config.sample = d( 98 | sample_steps=50, 99 | n_samples=50000, 100 | mini_batch_size=25, # the decoder is large 101 | algorithm='dpm_solver', 102 | cfg=True, 103 | scale=0.4, 104 | path='' 105 | ) 106 | 107 | return config 108 | -------------------------------------------------------------------------------- /configs/imagenet512_H_DiM_ft.py: -------------------------------------------------------------------------------- 1 | import ml_collections 2 | 3 | 4 | def d(**kwargs): 5 | """Helper of creating a config dict.""" 6 | return ml_collections.ConfigDict(initial_dictionary=kwargs) 7 | 8 | 9 | def get_config(): 10 | config = ml_collections.ConfigDict() 11 | 12 | config.seed = 1234 13 | config.pred = 'noise_pred' 14 | config.z_shape = (4, 64, 64) 15 | 16 | config.autoencoder = d( 17 | pretrained_path='assets/stable-diffusion/autoencoder_kl_ema.pth' 18 | ) 19 | 20 | config.gradient_accumulation_steps=2 # 1 21 | config.max_grad_norm = 1.0 22 | 23 | config.train = d( 24 | n_steps=64000, #300000, 25 | batch_size=240, #1024, 26 | mode='cond', 27 | log_interval=10, 28 | eval_interval=1000, # 5000, 29 | save_interval=10000, # 50000 30 | ) 31 | 32 | config.optimizer = d( 33 | name='adamw', 34 | lr=0.0001 , #0.0002, 35 | weight_decay=0, #0.03, 36 | betas=(0.99, 0.99), 37 | eps=1e-15, 38 | ) 39 | 40 | config.lr_scheduler = d( 41 | name='customized', 42 | warmup_steps=2500, # 1, #5000, 43 | ) 44 | 45 | learned_sigma = False 46 | latent_size = 64 #32 47 | in_channels = 4 # 3 48 | config.nnet = d( 49 | name='Mamba_DiT_H_2', 50 | attention_head_dim=1536//1, num_attention_heads=1, num_layers=49, 51 | in_channels=in_channels, 52 | 
num_embeds_ada_norm=1000, 53 | sample_size=latent_size, 54 | activation_fn="gelu-approximate", 55 | attention_bias=True, 56 | norm_elementwise_affine=False, 57 | norm_type="ada_norm_single", #"layer_norm", 58 | out_channels=in_channels*2 if learned_sigma else in_channels, 59 | patch_size=2, 60 | mamba_d_state=16, 61 | mamba_d_conv=3, 62 | mamba_expand=2, 63 | use_bidirectional_rnn=False, 64 | mamba_type='enc', 65 | nested_order=0, 66 | is_uconnect=True, 67 | no_ff=True, 68 | use_conv1d=True, 69 | is_extra_tokens=True, 70 | rms=True, 71 | use_pad_token=True, 72 | use_a4m_adapter=True, 73 | drop_path_rate=0.0, 74 | encoder_start_blk_id=1, 75 | kv_as_one_token_idx=-1, 76 | num_2d_enc_dec_layers=6, 77 | pad_token_schedules=['dec_split', 'rho_pad'], #['dec_split', 'lateral'], 78 | is_absorb=False, 79 | use_adapter_modules=True, 80 | sequence_schedule='dilated', 81 | sub_sequence_schedule=['reverse_single', 'layerwise_cross'], 82 | pos_encoding_type='learnable', 83 | scan_pattern_len=4 -1, 84 | is_align_exchange_q_kv=False, 85 | is_random_patterns=False, 86 | pretrained_path = 'workdir/imagenet256_H_DiM/default/ckpts/425000.ckpt/nnet_ema.pth', 87 | multi_times=2, 88 | pattern_type='base', 89 | ) 90 | config.gradient_checkpointing = False 91 | 92 | config.dataset = d( 93 | name='imagenet512_features', 94 | path='assets/datasets/imagenet512_features', 95 | cfg=True, 96 | p_uncond=0.15 97 | ) 98 | 99 | config.sample = d( 100 | sample_steps=50, 101 | n_samples=50000, 102 | mini_batch_size=25, # the decoder is large 103 | algorithm='dpm_solver', 104 | cfg=True, 105 | scale=0.7, 106 | path='' 107 | ) 108 | 109 | return config 110 | -------------------------------------------------------------------------------- /configs/imagenet512_H_DiM_upsample_3x_test.py: -------------------------------------------------------------------------------- 1 | import ml_collections 2 | import os 3 | 4 | def d(**kwargs): 5 | """Helper of creating a config dict.""" 6 | return ml_collections.ConfigDict(initial_dictionary=kwargs) 7 | 8 | 9 | def get_config(): 10 | config = ml_collections.ConfigDict() 11 | 12 | config.seed = 1234 13 | config.pred = 'noise_pred' 14 | # config.z_shape = (4, 64, 64) 15 | 16 | config.autoencoder = d( 17 | pretrained_path='assets/stable-diffusion/autoencoder_kl_ema.pth' 18 | ) 19 | 20 | config.gradient_accumulation_steps=2 # 1 21 | config.max_grad_norm = 1.0 22 | 23 | config.train = d( 24 | n_steps=300000, 25 | batch_size=128, 26 | mode='cond', 27 | log_interval=10, 28 | eval_interval=1000, 29 | save_interval=5000, 30 | ) 31 | 32 | config.optimizer = d( 33 | name='adamw', 34 | lr=0.0001, 35 | weight_decay=0.03, 36 | betas=(0.99, 0.99), 37 | eps=1e-15, 38 | ) 39 | 40 | config.lr_scheduler = d( 41 | name='customized', 42 | warmup_steps=1, 43 | ) 44 | 45 | config.ug_theta = 1 46 | config.ug_eta = 0.3 47 | config.ug_T = 1 48 | 49 | base_resolution = 512 50 | patch_size = 2 51 | multi_times = 3 #2 52 | resolution = int(base_resolution * multi_times) 53 | coco_multi_scale = [ resolution, resolution //2 , base_resolution, ] 54 | 55 | learned_sigma = False 56 | latent_size = resolution // 8 57 | config.z_shape = (4, latent_size, latent_size) 58 | in_channels = 4 59 | config.nnet = d( 60 | name='Mamba_DiT_H_2', 61 | attention_head_dim=1536//1, num_attention_heads=1, num_layers=49, 62 | in_channels=in_channels, 63 | num_embeds_ada_norm=1000, 64 | sample_size=latent_size, 65 | activation_fn="gelu-approximate", 66 | attention_bias=True, 67 | norm_elementwise_affine=False, 68 | norm_type="ada_norm_single", 
#"layer_norm", 69 | out_channels=in_channels*2 if learned_sigma else in_channels, 70 | patch_size=patch_size, 71 | mamba_d_state=16, 72 | mamba_d_conv=3, 73 | mamba_expand=2, 74 | use_bidirectional_rnn=False, 75 | mamba_type='enc', 76 | nested_order=0, 77 | is_uconnect=True, 78 | no_ff=True, 79 | use_conv1d=True, 80 | is_extra_tokens=True, 81 | rms=True, 82 | use_pad_token=True, 83 | use_a4m_adapter=True, 84 | drop_path_rate=0.0, 85 | encoder_start_blk_id=1, 86 | kv_as_one_token_idx=-1, 87 | num_2d_enc_dec_layers=6, 88 | pad_token_schedules=['dec_split', 'rho_pad'], 89 | is_absorb=False, 90 | use_adapter_modules=True, 91 | sequence_schedule='dilated', 92 | sub_sequence_schedule=['reverse_single', 'layerwise_cross'], 93 | pos_encoding_type='learnable', 94 | scan_pattern_len=4 -1, 95 | is_align_exchange_q_kv=False, 96 | is_random_patterns=False, 97 | pretrained_path = 'workdir/imagenet512_H_DiM_ft/default/ckpts/64000.ckpt/nnet_ema.pth', 98 | multi_times=multi_times, 99 | pattern_type='base', 100 | is_freeu=False, freeu_param=(0.15, 0.1, 1.1, 1.2), # (0.3, 0.2, 1.1, 1.2) 101 | num_patches = [ (i //8 //patch_size)**2 for i in coco_multi_scale], 102 | is_skip_tune=True, skip_tune_param = (0.82, 1.0), #(0.82, 1.0), 103 | ) 104 | config.gradient_checkpointing = False 105 | 106 | config.dataset = d( 107 | name='imagenet512_features', 108 | path='assets/datasets/imagenet512_features', 109 | cfg=True, 110 | p_uncond=0.15 111 | ) 112 | vis_label = None 113 | config.sample = d( 114 | sample_steps=50, 115 | n_samples=50000, 116 | mini_batch_size=1, # the decoder is large 117 | algorithm='dpm_solver_upsample_g', 118 | cfg=True, 119 | scale=3.0, 120 | path='workdir/imagenet512_H_DiM_ft/test_tmp/' + ( 121 | str(vis_label) if vis_label is not None else 'all_class' 122 | ) + '_' + str( resolution ), 123 | vis_label=vis_label, 124 | ) 125 | 126 | if not os.path.exists(config.sample.path): 127 | os.makedirs(config.sample.path) 128 | 129 | return config 130 | -------------------------------------------------------------------------------- /configs/imagenet512_H_DiM_upsample_test.py: -------------------------------------------------------------------------------- 1 | import ml_collections 2 | import os 3 | 4 | def d(**kwargs): 5 | """Helper of creating a config dict.""" 6 | return ml_collections.ConfigDict(initial_dictionary=kwargs) 7 | 8 | 9 | def get_config(): 10 | config = ml_collections.ConfigDict() 11 | 12 | config.seed = 1234 13 | config.pred = 'noise_pred' 14 | # config.z_shape = (4, 64, 64) 15 | 16 | config.autoencoder = d( 17 | pretrained_path='assets/stable-diffusion/autoencoder_kl_ema.pth' 18 | ) 19 | 20 | config.gradient_accumulation_steps=2 # 1 21 | config.max_grad_norm = 1.0 22 | 23 | config.train = d( 24 | n_steps=300000, 25 | batch_size=128, 26 | mode='cond', 27 | log_interval=10, 28 | eval_interval=1000, 29 | save_interval=5000, 30 | ) 31 | 32 | config.optimizer = d( 33 | name='adamw', 34 | lr=0.0001, 35 | weight_decay=0.03, 36 | betas=(0.99, 0.99), 37 | eps=1e-15, 38 | ) 39 | 40 | config.lr_scheduler = d( 41 | name='customized', 42 | warmup_steps=1, 43 | ) 44 | 45 | config.ug_theta = 1 46 | config.ug_eta = 0.3 47 | config.ug_T = 1 48 | 49 | base_resolution = 512 50 | patch_size = 2 51 | multi_times = 2 52 | resolution = int(base_resolution * multi_times) 53 | coco_multi_scale = [ resolution, resolution //2 , base_resolution, ] 54 | 55 | learned_sigma = False 56 | latent_size = resolution // 8 57 | config.z_shape = (4, latent_size, latent_size) 58 | in_channels = 4 59 | config.nnet = d( 60 | 
        name='Mamba_DiT_H_2',
        attention_head_dim=1536//1, num_attention_heads=1, num_layers=49,
        in_channels=in_channels,
        num_embeds_ada_norm=1000,
        sample_size=latent_size,
        activation_fn="gelu-approximate",
        attention_bias=True,
        norm_elementwise_affine=False,
        norm_type="ada_norm_single", #"layer_norm",
        out_channels=in_channels*2 if learned_sigma else in_channels,
        patch_size=patch_size,
        mamba_d_state=16,
        mamba_d_conv=3,
        mamba_expand=2,
        use_bidirectional_rnn=False,
        mamba_type='enc',
        nested_order=0,
        is_uconnect=True,
        no_ff=True,
        use_conv1d=True,
        is_extra_tokens=True,
        rms=True,
        use_pad_token=True,
        use_a4m_adapter=True,
        drop_path_rate=0.0,
        encoder_start_blk_id=1,
        kv_as_one_token_idx=-1,
        num_2d_enc_dec_layers=6,
        pad_token_schedules=['dec_split', 'rho_pad'],
        is_absorb=False,
        use_adapter_modules=True,
        sequence_schedule='dilated',
        sub_sequence_schedule=['reverse_single', 'layerwise_cross'],
        pos_encoding_type='learnable',
        scan_pattern_len=4 -1,
        is_align_exchange_q_kv=False,
        is_random_patterns=False,
        pretrained_path = 'workdir/imagenet512_H_DiM_ft/default/ckpts/64000.ckpt/nnet_ema.pth',
        multi_times=multi_times,
        pattern_type='base',
        is_freeu=False, freeu_param=(0.15, 0.1, 1.1, 1.2), # (0.3, 0.2, 1.1, 1.2)
        num_patches = [ (i //8 //patch_size)**2 for i in coco_multi_scale],
        is_skip_tune=True, skip_tune_param = (0.82, 1.0), #(0.82, 1.0),
    )
    config.gradient_checkpointing = False

    config.dataset = d(
        name='imagenet512_features',
        path='assets/datasets/imagenet512_features',
        cfg=True,
        p_uncond=0.15
    )
    vis_label = None
    # rabbit 330
    # panda 388
    # lion 291
    # cat 283
    config.sample = d(
        sample_steps=50,
        n_samples=50000,
        mini_batch_size=1, # the decoder is large
        algorithm='dpm_solver_upsample_g',
        cfg=True,
        scale=3.0,
        path='workdir/imagenet512_H_DiM_ft/test_tmp/' + (
            str(vis_label) if vis_label is not None else 'all_class'
        ) + '_' + str( resolution ),
        vis_label=vis_label,
    )

    if not os.path.exists(config.sample.path):
        os.makedirs(config.sample.path)

    return config
--------------------------------------------------------------------------------
/csrc/selective_scan/selective_scan.h:
--------------------------------------------------------------------------------
/******************************************************************************
 * Copyright (c) 2023, Tri Dao.
 ******************************************************************************/

#pragma once

////////////////////////////////////////////////////////////////////////////////////////////////////

struct SSMScanParamsBase {
    using index_t = uint32_t;

    int batch, seqlen, n_chunks;
    index_t a_batch_stride;
    index_t b_batch_stride;
    index_t out_batch_stride;

    // Common data pointers.
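    // (Annotation, an assumption from the field names: a/b are the scan
    // coefficients, out is the result, and x holds per-chunk running state.)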
    void *__restrict__ a_ptr;
    void *__restrict__ b_ptr;
    void *__restrict__ out_ptr;
    void *__restrict__ x_ptr;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

struct SSMParamsBase {
    using index_t = uint32_t;

    int batch, dim, seqlen, dstate, n_groups, n_chunks;
    int dim_ngroups_ratio;
    bool is_variable_B;
    bool is_variable_C;

    bool delta_softplus;

    index_t A_d_stride;
    index_t A_dstate_stride;
    index_t B_batch_stride;
    index_t B_d_stride;
    index_t B_dstate_stride;
    index_t B_group_stride;
    index_t C_batch_stride;
    index_t C_d_stride;
    index_t C_dstate_stride;
    index_t C_group_stride;
    index_t u_batch_stride;
    index_t u_d_stride;
    index_t delta_batch_stride;
    index_t delta_d_stride;
    index_t z_batch_stride;
    index_t z_d_stride;
    index_t out_batch_stride;
    index_t out_d_stride;
    index_t out_z_batch_stride;
    index_t out_z_d_stride;

    // Common data pointers.
    void *__restrict__ A_ptr;
    void *__restrict__ B_ptr;
    void *__restrict__ C_ptr;
    void *__restrict__ D_ptr;
    void *__restrict__ u_ptr;
    void *__restrict__ delta_ptr;
    void *__restrict__ delta_bias_ptr;
    void *__restrict__ out_ptr;
    void *__restrict__ x_ptr;
    void *__restrict__ z_ptr;
    void *__restrict__ out_z_ptr;
};

struct SSMParamsBwd: public SSMParamsBase {
    index_t dout_batch_stride;
    index_t dout_d_stride;
    index_t dA_d_stride;
    index_t dA_dstate_stride;
    index_t dB_batch_stride;
    index_t dB_group_stride;
    index_t dB_d_stride;
    index_t dB_dstate_stride;
    index_t dC_batch_stride;
    index_t dC_group_stride;
    index_t dC_d_stride;
    index_t dC_dstate_stride;
    index_t du_batch_stride;
    index_t du_d_stride;
    index_t dz_batch_stride;
    index_t dz_d_stride;
    index_t ddelta_batch_stride;
    index_t ddelta_d_stride;

    // Common data pointers.
    void *__restrict__ dout_ptr;
    void *__restrict__ dA_ptr;
    void *__restrict__ dB_ptr;
    void *__restrict__ dC_ptr;
    void *__restrict__ dD_ptr;
    void *__restrict__ du_ptr;
    void *__restrict__ dz_ptr;
    void *__restrict__ ddelta_ptr;
    void *__restrict__ ddelta_bias_ptr;
};
--------------------------------------------------------------------------------
/csrc/selective_scan/selective_scan_bwd_bf16_complex.cu:
--------------------------------------------------------------------------------
/******************************************************************************
 * Copyright (c) 2023, Tri Dao.
 ******************************************************************************/

// Split into multiple files to compile in parallel

#include "selective_scan_bwd_kernel.cuh"

template void selective_scan_bwd_cuda<at::BFloat16, complex_t>(SSMParamsBwd &params, cudaStream_t stream);
--------------------------------------------------------------------------------
/csrc/selective_scan/selective_scan_bwd_bf16_real.cu:
--------------------------------------------------------------------------------
/******************************************************************************
 * Copyright (c) 2023, Tri Dao.
 ******************************************************************************/

// Split into multiple files to compile in parallel

#include "selective_scan_bwd_kernel.cuh"

template void selective_scan_bwd_cuda<at::BFloat16, float>(SSMParamsBwd &params, cudaStream_t stream);
--------------------------------------------------------------------------------
/csrc/selective_scan/selective_scan_bwd_fp16_complex.cu:
--------------------------------------------------------------------------------
/******************************************************************************
 * Copyright (c) 2023, Tri Dao.
 ******************************************************************************/

// Split into multiple files to compile in parallel

#include "selective_scan_bwd_kernel.cuh"

template void selective_scan_bwd_cuda<at::Half, complex_t>(SSMParamsBwd &params, cudaStream_t stream);
--------------------------------------------------------------------------------
/csrc/selective_scan/selective_scan_bwd_fp16_real.cu:
--------------------------------------------------------------------------------
/******************************************************************************
 * Copyright (c) 2023, Tri Dao.
 ******************************************************************************/

// Split into multiple files to compile in parallel

#include "selective_scan_bwd_kernel.cuh"

template void selective_scan_bwd_cuda<at::Half, float>(SSMParamsBwd &params, cudaStream_t stream);
--------------------------------------------------------------------------------
/csrc/selective_scan/selective_scan_bwd_fp32_complex.cu:
--------------------------------------------------------------------------------
/******************************************************************************
 * Copyright (c) 2023, Tri Dao.
 ******************************************************************************/

// Split into multiple files to compile in parallel

#include "selective_scan_bwd_kernel.cuh"

template void selective_scan_bwd_cuda<float, complex_t>(SSMParamsBwd &params, cudaStream_t stream);
--------------------------------------------------------------------------------
/csrc/selective_scan/selective_scan_bwd_fp32_real.cu:
--------------------------------------------------------------------------------
/******************************************************************************
 * Copyright (c) 2023, Tri Dao.
 ******************************************************************************/

// Split into multiple files to compile in parallel

#include "selective_scan_bwd_kernel.cuh"

template void selective_scan_bwd_cuda<float, float>(SSMParamsBwd &params, cudaStream_t stream);
--------------------------------------------------------------------------------
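The common header that follows defines `SSMScanOp`, the associative combine for the linear recurrence h_t = a_t * h_{t-1} + b_t used by the block-wide scan; a small numerical sketch of that rule (illustrative Python, not repo code):

```python
# Sketch of the SSMScanOp combine rule on plain floats.
# Composing two steps (a0, b0) then (a1, b1) of h <- a*h + b yields:
def combine(ab0, ab1):
    a0, b0 = ab0
    a1, b1 = ab1
    return (a1 * a0, a1 * b0 + b1)

# Associativity is what makes a parallel (blocked) scan valid:
x, y, z = (0.5, 1.0), (0.25, 2.0), (2.0, 0.5)
assert combine(combine(x, y), z) == combine(x, combine(y, z))  # (0.25, 5.0)
```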
/csrc/selective_scan/selective_scan_common.h:
--------------------------------------------------------------------------------
/******************************************************************************
 * Copyright (c) 2023, Tri Dao.
 ******************************************************************************/

#pragma once

#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <c10/util/complex.h>  // For scalar_value_type

#define MAX_DSTATE 256

using complex_t = c10::complex<float>;

inline __device__ float2 operator+(const float2 & a, const float2 & b){
    return {a.x + b.x, a.y + b.y};
}

inline __device__ float3 operator+(const float3 &a, const float3 &b) {
    return {a.x + b.x, a.y + b.y, a.z + b.z};
}

inline __device__ float4 operator+(const float4 & a, const float4 & b){
    return {a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w};
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template<int BYTES> struct BytesToType {};

template<> struct BytesToType<16> {
    using Type = uint4;
    static_assert(sizeof(Type) == 16);
};

template<> struct BytesToType<8> {
    using Type = uint64_t;
    static_assert(sizeof(Type) == 8);
};

template<> struct BytesToType<4> {
    using Type = uint32_t;
    static_assert(sizeof(Type) == 4);
};

template<> struct BytesToType<2> {
    using Type = uint16_t;
    static_assert(sizeof(Type) == 2);
};

template<> struct BytesToType<1> {
    using Type = uint8_t;
    static_assert(sizeof(Type) == 1);
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename scalar_t, int N>
struct Converter{
    static inline __device__ void to_float(const scalar_t (&src)[N], float (&dst)[N]) {
        #pragma unroll
        for (int i = 0; i < N; ++i) { dst[i] = src[i]; }
    }
};

template<int N>
struct Converter<at::Half, N>{
    static inline __device__ void to_float(const at::Half (&src)[N], float (&dst)[N]) {
        static_assert(N % 2 == 0);
        auto &src2 = reinterpret_cast<const half2 (&)[N / 2]>(src);
        auto &dst2 = reinterpret_cast<float2 (&)[N / 2]>(dst);
        #pragma unroll
        for (int i = 0; i < N / 2; ++i) { dst2[i] = __half22float2(src2[i]); }
    }
};

#if __CUDA_ARCH__ >= 800
template<int N>
struct Converter<at::BFloat16, N>{
    static inline __device__ void to_float(const at::BFloat16 (&src)[N], float (&dst)[N]) {
        static_assert(N % 2 == 0);
        auto &src2 = reinterpret_cast<const nv_bfloat162 (&)[N / 2]>(src);
        auto &dst2 = reinterpret_cast<float2 (&)[N / 2]>(dst);
        #pragma unroll
        for (int i = 0; i < N / 2; ++i) { dst2[i] = __bfloat1622float2(src2[i]); }
    }
};
#endif

////////////////////////////////////////////////////////////////////////////////////////////////////

// From https://stackoverflow.com/questions/9860711/cucomplex-h-and-exp
// and https://forums.developer.nvidia.com/t/complex-number-exponential-function/24696
__device__ __forceinline__ complex_t cexp2f(complex_t z) {
    float t = exp2f(z.real_);
    float c, s;
    sincosf(z.imag_, &s, &c);
    return complex_t(c * t, s * t);
}

__device__ __forceinline__ complex_t cexpf(complex_t z) {
    float t = expf(z.real_);
    float c, s;
    sincosf(z.imag_, &s, &c);
    return complex_t(c * t, s * t);
}

template<typename scalar_t> struct SSMScanOp;

template<>
struct SSMScanOp<float> {
    __device__ __forceinline__ float2 operator()(const float2 &ab0, const float2 &ab1) const {
        return make_float2(ab1.x * ab0.x, ab1.x * ab0.y + ab1.y);
    }
};
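// The complex variant below applies the same combine rule, with each (a, b)
// pair packed as two complex numbers in a float4.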
template<>
struct SSMScanOp<complex_t> {
    __device__ __forceinline__ float4 operator()(const float4 &ab0, const float4 &ab1) const {
        complex_t a0 = complex_t(ab0.x, ab0.y);
        complex_t b0 = complex_t(ab0.z, ab0.w);
        complex_t a1 = complex_t(ab1.x, ab1.y);
        complex_t b1 = complex_t(ab1.z, ab1.w);
        complex_t out_a = a1 * a0;
        complex_t out_b = a1 * b0 + b1;
        return make_float4(out_a.real_, out_a.imag_, out_b.real_, out_b.imag_);
    }
};

// A stateful callback functor that maintains a running prefix to be applied
// during consecutive scan operations.
template <typename scalar_t> struct SSMScanPrefixCallbackOp {
    using scan_t = std::conditional_t<std::is_same_v<scalar_t, float>, float2, float4>;
    scan_t running_prefix;
    // Constructor
    __device__ SSMScanPrefixCallbackOp(scan_t running_prefix_) : running_prefix(running_prefix_) {}
    // Callback operator to be entered by the first warp of threads in the block.
    // Thread-0 is responsible for returning a value for seeding the block-wide scan.
    __device__ scan_t operator()(scan_t block_aggregate) {
        scan_t old_prefix = running_prefix;
        running_prefix = SSMScanOp<scalar_t>()(running_prefix, block_aggregate);
        return old_prefix;
    }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename Ktraits>
inline __device__ void load_input(typename Ktraits::input_t *u,
                                  typename Ktraits::input_t (&u_vals)[Ktraits::kNItems],
                                  typename Ktraits::BlockLoadT::TempStorage &smem_load,
                                  int seqlen) {
    if constexpr (Ktraits::kIsEvenLen) {
        auto& smem_load_vec = reinterpret_cast<typename Ktraits::BlockLoadVecT::TempStorage&>(smem_load);
        using vec_t = typename Ktraits::vec_t;
        Ktraits::BlockLoadVecT(smem_load_vec).Load(
            reinterpret_cast<vec_t*>(u),
            reinterpret_cast<vec_t(&)[Ktraits::kNLoads]>(u_vals)
        );
    } else {
        Ktraits::BlockLoadT(smem_load).Load(u, u_vals, seqlen, 0.f);
    }
}

template<typename Ktraits>
inline __device__ void load_weight(typename Ktraits::input_t *Bvar,
                                   typename Ktraits::weight_t (&B_vals)[Ktraits::kNItems],
                                   typename Ktraits::BlockLoadWeightT::TempStorage &smem_load_weight,
                                   int seqlen) {
    constexpr int kNItems = Ktraits::kNItems;
    if constexpr (!Ktraits::kIsComplex) {
        typename Ktraits::input_t B_vals_load[kNItems];
        if constexpr (Ktraits::kIsEvenLen) {
            auto& smem_load_weight_vec = reinterpret_cast<typename Ktraits::BlockLoadWeightVecT::TempStorage&>(smem_load_weight);
            using vec_t = typename Ktraits::vec_t;
            Ktraits::BlockLoadWeightVecT(smem_load_weight_vec).Load(
                reinterpret_cast<vec_t*>(Bvar),
                reinterpret_cast<vec_t(&)[Ktraits::kNLoads]>(B_vals_load)
            );
        } else {
            Ktraits::BlockLoadWeightT(smem_load_weight).Load(Bvar, B_vals_load, seqlen, 0.f);
        }
        // #pragma unroll
        // for (int i = 0; i < kNItems; ++i) { B_vals[i] = B_vals_load[i]; }
        Converter<typename Ktraits::input_t, kNItems>::to_float(B_vals_load, B_vals);
    } else {
        typename Ktraits::input_t B_vals_load[kNItems * 2];
        if constexpr (Ktraits::kIsEvenLen) {
            auto& smem_load_weight_vec = reinterpret_cast<typename Ktraits::BlockLoadWeightVecT::TempStorage&>(smem_load_weight);
            using vec_t = typename Ktraits::vec_t;
            Ktraits::BlockLoadWeightVecT(smem_load_weight_vec).Load(
                reinterpret_cast<vec_t*>(Bvar),
                reinterpret_cast<vec_t(&)[Ktraits::kNLoads * 2]>(B_vals_load)
            );
        } else {
            Ktraits::BlockLoadWeightT(smem_load_weight).Load(Bvar, B_vals_load, seqlen, 0.f);
        }
        #pragma unroll
        for (int i = 0; i < kNItems; ++i) { B_vals[i] = complex_t(B_vals_load[i * 2], B_vals_load[i * 2 + 1]); }
    }
}
template<typename Ktraits>
inline __device__ void store_output(typename Ktraits::input_t *out,
                                    const float (&out_vals)[Ktraits::kNItems],
                                    typename Ktraits::BlockStoreT::TempStorage &smem_store,
                                    int seqlen) {
    typename Ktraits::input_t write_vals[Ktraits::kNItems];
    #pragma unroll
    for (int i = 0; i < Ktraits::kNItems; ++i) { write_vals[i] = out_vals[i]; }
    if constexpr (Ktraits::kIsEvenLen) {
        auto& smem_store_vec = reinterpret_cast<typename Ktraits::BlockStoreVecT::TempStorage&>(smem_store);
        using vec_t = typename Ktraits::vec_t;
        Ktraits::BlockStoreVecT(smem_store_vec).Store(
            reinterpret_cast<vec_t*>(out),
            reinterpret_cast<vec_t(&)[Ktraits::kNLoads]>(write_vals)
        );
    } else {
        Ktraits::BlockStoreT(smem_store).Store(out, write_vals, seqlen);
    }
}
--------------------------------------------------------------------------------
/csrc/selective_scan/selective_scan_fwd_bf16.cu:
--------------------------------------------------------------------------------
/******************************************************************************
 * Copyright (c) 2023, Tri Dao.
 ******************************************************************************/

// Split into multiple files to compile in parallel

#include "selective_scan_fwd_kernel.cuh"

template void selective_scan_fwd_cuda<at::BFloat16, float>(SSMParamsBase &params, cudaStream_t stream);
template void selective_scan_fwd_cuda<at::BFloat16, complex_t>(SSMParamsBase &params, cudaStream_t stream);
--------------------------------------------------------------------------------
/csrc/selective_scan/selective_scan_fwd_fp16.cu:
--------------------------------------------------------------------------------
/******************************************************************************
 * Copyright (c) 2023, Tri Dao.
 ******************************************************************************/

// Split into multiple files to compile in parallel

#include "selective_scan_fwd_kernel.cuh"

template void selective_scan_fwd_cuda<at::Half, float>(SSMParamsBase &params, cudaStream_t stream);
template void selective_scan_fwd_cuda<at::Half, complex_t>(SSMParamsBase &params, cudaStream_t stream);
--------------------------------------------------------------------------------
/csrc/selective_scan/selective_scan_fwd_fp32.cu:
--------------------------------------------------------------------------------
/******************************************************************************
 * Copyright (c) 2023, Tri Dao.
 ******************************************************************************/

// Split into multiple files to compile in parallel

#include "selective_scan_fwd_kernel.cuh"

template void selective_scan_fwd_cuda<float, float>(SSMParamsBase &params, cudaStream_t stream);
template void selective_scan_fwd_cuda<float, complex_t>(SSMParamsBase &params, cudaStream_t stream);
--------------------------------------------------------------------------------
/csrc/selective_scan/static_switch.h:
--------------------------------------------------------------------------------
// Inspired by https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h
// and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h

#pragma once

/// @param COND - a boolean expression to switch by
/// @param CONST_NAME - a name given for the constexpr bool variable.
/// @param ... - code to execute for true and false
///
/// Usage:
/// ```
/// BOOL_SWITCH(flag, BoolConst, [&] {
///     some_function(...);
/// });
/// ```
#define BOOL_SWITCH(COND, CONST_NAME, ...)
\ 17 | [&] { \ 18 | if (COND) { \ 19 | constexpr bool CONST_NAME = true; \ 20 | return __VA_ARGS__(); \ 21 | } else { \ 22 | constexpr bool CONST_NAME = false; \ 23 | return __VA_ARGS__(); \ 24 | } \ 25 | }() 26 | -------------------------------------------------------------------------------- /csrc/selective_scan/uninitialized_copy.cuh: -------------------------------------------------------------------------------- 1 | /****************************************************************************** 2 | * Copyright (c) 2011-2022, NVIDIA CORPORATION. All rights reserved. 3 | * 4 | * Redistribution and use in source and binary forms, with or without 5 | * modification, are permitted provided that the following conditions are met: 6 | * * Redistributions of source code must retain the above copyright 7 | * notice, this list of conditions and the following disclaimer. 8 | * * Redistributions in binary form must reproduce the above copyright 9 | * notice, this list of conditions and the following disclaimer in the 10 | * documentation and/or other materials provided with the distribution. 11 | * * Neither the name of the NVIDIA CORPORATION nor the 12 | * names of its contributors may be used to endorse or promote products 13 | * derived from this software without specific prior written permission. 14 | * 15 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 16 | * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 17 | * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 18 | * ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY 19 | * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 20 | * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 21 | * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND 22 | * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 23 | * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 24 | * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
25 | * 26 | ******************************************************************************/ 27 | 28 | #pragma once 29 | 30 | #include 31 | 32 | #include 33 | 34 | 35 | namespace detail 36 | { 37 | 38 | #if defined(_NVHPC_CUDA) 39 | template 40 | __host__ __device__ void uninitialized_copy(T *ptr, U &&val) 41 | { 42 | // NVBug 3384810 43 | new (ptr) T(::cuda::std::forward(val)); 44 | } 45 | #else 46 | template ::value, 50 | int 51 | >::type = 0> 52 | __host__ __device__ void uninitialized_copy(T *ptr, U &&val) 53 | { 54 | *ptr = ::cuda::std::forward(val); 55 | } 56 | 57 | template ::value, 61 | int 62 | >::type = 0> 63 | __host__ __device__ void uninitialized_copy(T *ptr, U &&val) 64 | { 65 | new (ptr) T(::cuda::std::forward(val)); 66 | } 67 | #endif 68 | 69 | } // namespace detail 70 | -------------------------------------------------------------------------------- /environment.yaml: -------------------------------------------------------------------------------- 1 | name: mamba-attn 2 | channels: 3 | - pytorch 4 | - defaults 5 | dependencies: 6 | - python=3.9 7 | - pip=22.3.1 8 | - cudatoolkit=11.8 9 | - pip: 10 | - torch==2.1.1 11 | - torchvision 12 | - packaging 13 | - tb-nightly 14 | - gradio==3.33.1 15 | - albumentations==1.3.0 16 | - opencv-contrib-python 17 | - imageio==2.9.0 18 | - imageio-ffmpeg==0.4.2 19 | - pytorch-lightning==1.5.0 20 | - omegaconf==2.3.0 21 | - test-tube>=0.7.5 22 | - streamlit==1.12.1 23 | - einops==0.6.0 24 | - transformers==4.36.2 25 | - webdataset==0.2.5 26 | - kornia==0.6 27 | - open_clip_torch==2.16.0 28 | - invisible-watermark>=0.1.5 29 | - streamlit-drawable-canvas==0.8.0 30 | - torchmetrics==0.6.0 31 | - timm==0.6.12 32 | - addict==2.4.0 33 | - yapf==0.32.0 34 | - prettytable==3.6.0 35 | - safetensors==0.3.1 36 | - basicsr==1.4.2 37 | - accelerate==0.17.0 38 | - decord==0.6.0 39 | - diffusers==0.25 40 | - moviepy==1.0.3 41 | - opencv_python==4.7.0.68 42 | - Pillow==9.4.0 43 | - scikit_image==0.19.3 44 | - scipy==1.10.1 45 | - tensorboardX==2.6 46 | - tqdm==4.64.1 47 | - numpy 48 | - datasets 49 | - ipython 50 | - seaborn 51 | - pycocotools 52 | - ipdb 53 | - matplotlib 54 | - flow_vis 55 | - charset-normalizer 56 | - ml_collections 57 | - wandb 58 | - causal-conv1d==1.2.0.post2 59 | - galore-torch -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | from tools.fid_score import calculate_fid_given_paths 2 | import ml_collections 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from torch import multiprocessing as mp 7 | import accelerate 8 | import utils 9 | import sde 10 | from uvit_datasets import get_dataset 11 | import tempfile 12 | from dpm_solver_pytorch import NoiseScheduleVP, model_wrapper, DPM_Solver 13 | from absl import logging 14 | import builtins 15 | 16 | from mamba_attn_diff.models.upsample_guidance import make_ufg_nnet 17 | 18 | def evaluate(config): 19 | if config.get('benchmark', False): 20 | torch.backends.cudnn.benchmark = True 21 | torch.backends.cudnn.deterministic = False 22 | 23 | mp.set_start_method('spawn') 24 | accelerator = accelerate.Accelerator() 25 | device = accelerator.device 26 | accelerate.utils.set_seed(config.seed, device_specific=True) 27 | logging.info(f'Process {accelerator.process_index} using device: {device}') 28 | 29 | config.mixed_precision = accelerator.mixed_precision 30 | config = ml_collections.FrozenConfigDict(config) 31 | if 
accelerator.is_main_process: 32 | utils.set_logger(log_level='info', fname=config.output_path) 33 | else: 34 | utils.set_logger(log_level='error') 35 | builtins.print = lambda *args: None 36 | 37 | dataset = get_dataset(**config.dataset) 38 | 39 | nnet = utils.get_nnet(**config.nnet) 40 | nnet = accelerator.prepare(nnet) 41 | logging.info(f'load nnet from {config.nnet_path}') 42 | nnet = accelerator.unwrap_model(nnet) 43 | nnet.load_state_dict(torch.load(config.nnet_path, map_location='cpu')) 44 | nnet.eval() 45 | print(nnet, (config.sample.algorithm if config.get('scheduler', False) else 'dpm_solver')) 46 | 47 | def cfg_nnet(x, timestep, y, **kwargs): 48 | _cond = nnet(x, timestep, y=y, **kwargs) 49 | _uncond = nnet(x, timestep, y=torch.tensor([dataset.K] * x.size(0), device=device), **kwargs) 50 | _cond = _cond.sample if not isinstance(_cond, torch.Tensor) else _cond 51 | _uncond = _uncond.sample if not isinstance(_uncond, torch.Tensor) else _uncond 52 | return _cond + config.sample.scale * (_cond - _uncond) 53 | 54 | def uncfg_nnet(x, timestep, y=None, **kwargs): 55 | _uncfg = nnet(x, timestep, **kwargs) 56 | _uncfg = _uncfg.sample if not isinstance(_uncfg, torch.Tensor) else _uncfg 57 | return _uncfg 58 | 59 | if 'cfg' in config.sample and config.sample.cfg and config.sample.scale > 0: # classifier free guidance 60 | logging.info(f'Use classifier free guidance with scale={config.sample.scale}') 61 | score_model = sde.ScoreModel(cfg_nnet, pred=config.pred, sde=sde.VPSDE()) 62 | else: 63 | score_model = sde.ScoreModel(uncfg_nnet, pred=config.pred, sde=sde.VPSDE()) 64 | 65 | 66 | logging.info(config.sample) 67 | assert os.path.exists(dataset.fid_stat) 68 | logging.info(f'sample: n_samples={config.sample.n_samples}, mode={config.train.mode}, mixed_precision={config.mixed_precision}') 69 | 70 | def sample_fn(_n_samples): 71 | if config.sample.algorithm == 'dpm_solver_upsample_g': 72 | m = 2 73 | data_shape = tuple([dataset.data_shape[0]] + [ i*m for i in dataset.data_shape[1:] ]) 74 | x_init = torch.randn(_n_samples, *data_shape, device=device) 75 | else: 76 | x_init = torch.randn(_n_samples, *dataset.data_shape, device=device) 77 | 78 | if config.train.mode == 'uncond': 79 | kwargs = dict() 80 | elif config.train.mode == 'cond': 81 | kwargs = dict(y=dataset.sample_label(_n_samples, device=device)) 82 | else: 83 | raise NotImplementedError 84 | 85 | if config.sample.algorithm == 'euler_maruyama_sde': 86 | rsde = sde.ReverseSDE(score_model) 87 | return sde.euler_maruyama(rsde, x_init, config.sample.sample_steps, verbose=accelerator.is_main_process, **kwargs) 88 | elif config.sample.algorithm == 'euler_maruyama_ode': 89 | rsde = sde.ODE(score_model) 90 | return sde.euler_maruyama(rsde, x_init, config.sample.sample_steps, verbose=accelerator.is_main_process, **kwargs) 91 | elif config.sample.algorithm == 'dpm_solver_upsample_g': 92 | noise_schedule = NoiseScheduleVP(schedule='linear') 93 | sde_entity = sde.VPSDE() 94 | 95 | normed_timesteps = torch.arange(1000, dtype=x_init.dtype, device=device).flip(0) / 999 96 | normed_timesteps[-1] = 1e-5 97 | model_fn = make_ufg_nnet( 98 | cfg_nnet, 99 | uncfg_nnet, 100 | normed_timesteps, 101 | sde_entity.cum_alpha, 102 | sde_entity.cum_beta, 103 | sde_entity.snr, 104 | m=2, 105 | **kwargs, 106 | ) 107 | 108 | dpm_solver = DPM_Solver(model_fn, noise_schedule) 109 | return dpm_solver.sample( 110 | x_init, 111 | steps=config.sample.sample_steps, 112 | eps=1e-4, 113 | adaptive_step_size=False, 114 | fast_version=True, 115 | ) 116 | elif 
config.sample.algorithm == 'dpm_solver': 117 | noise_schedule = NoiseScheduleVP(schedule='linear') 118 | model_fn = model_wrapper( 119 | score_model.noise_pred, 120 | noise_schedule, 121 | time_input_type='0', 122 | model_kwargs=kwargs 123 | ) 124 | dpm_solver = DPM_Solver(model_fn, noise_schedule) 125 | return dpm_solver.sample( 126 | x_init, 127 | steps=config.sample.sample_steps, 128 | eps=1e-4, 129 | adaptive_step_size=False, 130 | fast_version=True, 131 | ) 132 | else: 133 | raise NotImplementedError 134 | 135 | with tempfile.TemporaryDirectory() as temp_path: 136 | path = config.sample.path or temp_path 137 | if accelerator.is_main_process: 138 | os.makedirs(path, exist_ok=True) 139 | utils.sample2dir(accelerator, path, config.sample.n_samples, config.sample.mini_batch_size, sample_fn, dataset.unpreprocess) 140 | if accelerator.is_main_process: 141 | fid = calculate_fid_given_paths((dataset.fid_stat, path)) 142 | logging.info(f'nnet_path={config.nnet_path}, fid={fid}') 143 | 144 | 145 | from absl import flags 146 | from absl import app 147 | from ml_collections import config_flags 148 | import os 149 | 150 | 151 | FLAGS = flags.FLAGS 152 | config_flags.DEFINE_config_file( 153 | "config", None, "Training configuration.", lock_config=False) 154 | flags.mark_flags_as_required(["config"]) 155 | flags.DEFINE_string("nnet_path", None, "The nnet to evaluate.") 156 | flags.DEFINE_string("output_path", None, "The path to output log.") 157 | 158 | 159 | def main(argv): 160 | config = FLAGS.config 161 | config.nnet_path = FLAGS.nnet_path 162 | config.output_path = FLAGS.output_path 163 | evaluate(config) 164 | 165 | 166 | if __name__ == "__main__": 167 | app.run(main) 168 | -------------------------------------------------------------------------------- /eval_ldm.py: -------------------------------------------------------------------------------- 1 | from tools.fid_score import calculate_fid_given_paths 2 | import ml_collections 3 | import torch 4 | from torch import multiprocessing as mp 5 | import accelerate 6 | import utils 7 | import sde 8 | from uvit_datasets import get_dataset 9 | import tempfile 10 | from dpm_solver_pytorch import NoiseScheduleVP, model_wrapper, DPM_Solver 11 | from absl import logging 12 | import builtins 13 | import libs.autoencoder 14 | 15 | from mamba_attn_diff.models.upsample_guidance import make_ufg_nnet 16 | 17 | def evaluate(config): 18 | if config.get('benchmark', False): 19 | torch.backends.cudnn.benchmark = True 20 | torch.backends.cudnn.deterministic = False 21 | 22 | mp.set_start_method('spawn') 23 | accelerator = accelerate.Accelerator() 24 | device = accelerator.device 25 | accelerate.utils.set_seed(config.seed, device_specific=True) 26 | logging.info(f'Process {accelerator.process_index} using device: {device}') 27 | 28 | config.mixed_precision = accelerator.mixed_precision 29 | config = ml_collections.FrozenConfigDict(config) 30 | if accelerator.is_main_process: 31 | utils.set_logger(log_level='info', fname=config.output_path) 32 | else: 33 | utils.set_logger(log_level='error') 34 | builtins.print = lambda *args: None 35 | 36 | dataset = get_dataset(**config.dataset) 37 | 38 | nnet = utils.get_nnet(**config.nnet) 39 | nnet = accelerator.prepare(nnet) 40 | logging.info(f'load nnet from {config.nnet_path}') 41 | accelerator.unwrap_model(nnet).load_state_dict(torch.load(config.nnet_path, map_location='cpu')) 42 | nnet.eval() 43 | 44 | autoencoder = libs.autoencoder.get_model(config.autoencoder.pretrained_path) 45 | autoencoder.to(device) 46 | 47 | 
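# Note: the autocast-wrapped encode/decode helpers below keep the VAE consistent with the
# accelerator's mixed-precision policy, and decode_large_batch splits decoding into
# mini-batches of 50 latents because the decoder is large.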
@torch.cuda.amp.autocast() 48 | def encode(_batch): 49 | return autoencoder.encode(_batch) 50 | 51 | @torch.cuda.amp.autocast() 52 | def decode(_batch): 53 | return autoencoder.decode(_batch) 54 | 55 | def decode_large_batch(_batch): 56 | decode_mini_batch_size = 50 # use a small batch size since the decoder is large 57 | xs = [] 58 | pt = 0 59 | for _decode_mini_batch_size in utils.amortize(_batch.size(0), decode_mini_batch_size): 60 | x = decode(_batch[pt: pt + _decode_mini_batch_size]) 61 | pt += _decode_mini_batch_size 62 | xs.append(x) 63 | xs = torch.concat(xs, dim=0) 64 | assert xs.size(0) == _batch.size(0) 65 | return xs 66 | 67 | def cfg_nnet(x, timestep, y, **kwargs): 68 | _cond = nnet(x, timestep, y=y, **kwargs) 69 | _uncond = nnet(x, timestep, y=torch.tensor([dataset.K] * x.size(0), device=device), **kwargs) 70 | _cond = _cond.sample if not isinstance(_cond, torch.Tensor) else _cond 71 | _uncond = _uncond.sample if not isinstance(_uncond, torch.Tensor) else _uncond 72 | return _cond + config.sample.scale * (_cond - _uncond) 73 | 74 | def uncfg_nnet(x, timestep, y=None, **kwargs): 75 | _uncfg = nnet(x, timestep, **kwargs) 76 | _uncfg = _uncfg.sample if not isinstance(_uncfg, torch.Tensor) else _uncfg 77 | return _uncfg 78 | 79 | if 'cfg' in config.sample and config.sample.cfg and config.sample.scale > 0: # classifier free guidance 80 | logging.info(f'Use classifier free guidance with scale={config.sample.scale}') 81 | score_model = sde.ScoreModel(cfg_nnet, pred=config.pred, sde=sde.VPSDE()) # 82 | else: 83 | score_model = sde.ScoreModel(uncfg_nnet, pred=config.pred, sde=sde.VPSDE()) 84 | 85 | logging.info(config.sample) 86 | assert os.path.exists(dataset.fid_stat) 87 | logging.info(f'sample: n_samples={config.sample.n_samples}, mode={config.train.mode}, mixed_precision={config.mixed_precision}') 88 | 89 | def sample_fn(_n_samples): 90 | if config.sample.algorithm == 'dpm_solver_upsample_g': 91 | m = 2 92 | data_shape = tuple([config.z_shape[0]] + [ int(i*m) for i in config.z_shape[1:] ]) # config.z_shape[1:] 93 | _z_init = torch.randn(_n_samples, *data_shape, device=device) 94 | else: 95 | _z_init = torch.randn(_n_samples, *config.z_shape, device=device) 96 | 97 | if config.train.mode == 'uncond': 98 | kwargs = dict() 99 | elif config.train.mode == 'cond': 100 | kwargs = dict(y=dataset.sample_label(_n_samples, device=device)) 101 | else: 102 | raise NotImplementedError 103 | 104 | if config.sample.algorithm == 'euler_maruyama_sde': 105 | _z = sde.euler_maruyama(sde.ReverseSDE(score_model), _z_init, config.sample.sample_steps, verbose=accelerator.is_main_process, **kwargs) 106 | elif config.sample.algorithm == 'euler_maruyama_ode': 107 | _z = sde.euler_maruyama(sde.ODE(score_model), _z_init, config.sample.sample_steps, verbose=accelerator.is_main_process, **kwargs) 108 | elif config.sample.algorithm == 'dpm_solver_upsample_g': 109 | noise_schedule = NoiseScheduleVP(schedule='linear') 110 | sde_entity = sde.VPSDE() 111 | 112 | normed_timesteps = torch.arange(1000, dtype=_z_init.dtype, device=device).flip(0) / 999 113 | normed_timesteps[-1] = 1e-5 114 | model_fn = make_ufg_nnet( 115 | cfg_nnet, 116 | uncfg_nnet, 117 | normed_timesteps, 118 | sde_entity.cum_alpha, 119 | sde_entity.cum_beta, 120 | sde_entity.snr, 121 | m=m, 122 | **kwargs, 123 | ) 124 | 125 | dpm_solver = DPM_Solver(model_fn, noise_schedule) 126 | _z = dpm_solver.sample( 127 | _z_init, 128 | steps=config.sample.sample_steps, 129 | eps=1e-4, 130 | adaptive_step_size=False, 131 | fast_version=True, 132 | ) 133 | elif 
config.sample.algorithm == 'dpm_solver': 134 | noise_schedule = NoiseScheduleVP(schedule='linear') 135 | model_fn = model_wrapper( 136 | score_model.noise_pred, # 137 | noise_schedule, 138 | time_input_type='0', 139 | model_kwargs=kwargs 140 | ) 141 | dpm_solver = DPM_Solver(model_fn, noise_schedule) 142 | _z = dpm_solver.sample( 143 | _z_init, 144 | steps=config.sample.sample_steps, 145 | eps=1e-4, 146 | adaptive_step_size=False, 147 | fast_version=True, 148 | ) 149 | else: 150 | raise NotImplementedError 151 | return decode_large_batch(_z) 152 | 153 | with tempfile.TemporaryDirectory() as temp_path: 154 | path = config.sample.path or temp_path 155 | if accelerator.is_main_process: 156 | os.makedirs(path, exist_ok=True) 157 | logging.info(f'Samples are saved in {path}') 158 | utils.sample2dir(accelerator, path, config.sample.n_samples, config.sample.mini_batch_size, sample_fn, dataset.unpreprocess) 159 | if accelerator.is_main_process: 160 | fid = calculate_fid_given_paths((dataset.fid_stat, path)) 161 | logging.info(f'nnet_path={config.nnet_path}, fid={fid}') 162 | 163 | 164 | from absl import flags 165 | from absl import app 166 | from ml_collections import config_flags 167 | import os 168 | 169 | 170 | FLAGS = flags.FLAGS 171 | config_flags.DEFINE_config_file( 172 | "config", None, "Training configuration.", lock_config=False) 173 | flags.mark_flags_as_required(["config"]) 174 | flags.DEFINE_string("nnet_path", None, "The nnet to evaluate.") 175 | flags.DEFINE_string("output_path", None, "The path to output log.") 176 | 177 | 178 | def main(argv): 179 | config = FLAGS.config 180 | config.nnet_path = FLAGS.nnet_path 181 | config.output_path = FLAGS.output_path 182 | evaluate(config) 183 | 184 | 185 | if __name__ == "__main__": 186 | app.run(main) 187 | -------------------------------------------------------------------------------- /eval_t2i_discrete.py: -------------------------------------------------------------------------------- 1 | from tools.fid_score import calculate_fid_given_paths 2 | import ml_collections 3 | import torch 4 | from torch import multiprocessing as mp 5 | import accelerate 6 | from torch.utils.data import DataLoader 7 | import utils 8 | from datasets import get_dataset 9 | import tempfile 10 | from dpm_solver_pp import NoiseScheduleVP, DPM_Solver 11 | from absl import logging 12 | import builtins 13 | import einops 14 | import libs.autoencoder 15 | 16 | 17 | def stable_diffusion_beta_schedule(linear_start=0.00085, linear_end=0.0120, n_timestep=1000): 18 | _betas = ( 19 | torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2 20 | ) 21 | return _betas.numpy() 22 | 23 | 24 | def evaluate(config): 25 | if config.get('benchmark', False): 26 | torch.backends.cudnn.benchmark = True 27 | torch.backends.cudnn.deterministic = False 28 | 29 | mp.set_start_method('spawn') 30 | accelerator = accelerate.Accelerator() 31 | device = accelerator.device 32 | accelerate.utils.set_seed(config.seed, device_specific=True) 33 | logging.info(f'Process {accelerator.process_index} using device: {device}') 34 | 35 | config.mixed_precision = accelerator.mixed_precision 36 | config = ml_collections.FrozenConfigDict(config) 37 | if accelerator.is_main_process: 38 | utils.set_logger(log_level='info', fname=config.output_path) 39 | else: 40 | utils.set_logger(log_level='error') 41 | builtins.print = lambda *args: None 42 | 43 | dataset = get_dataset(**config.dataset) 44 | test_dataset = dataset.get_split(split='test', labeled=True) # for sampling 45 | 
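# Sampling is text-conditioned: get_context_generator() below cycles this loader and
# yields batches of precomputed text-context embeddings that are passed to cfg_nnet
# as `context`.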
test_dataset_loader = DataLoader(test_dataset, batch_size=config.sample.mini_batch_size, shuffle=True, 46 | drop_last=True, num_workers=8, pin_memory=True, persistent_workers=True) 47 | 48 | nnet = utils.get_nnet(**config.nnet) 49 | nnet, test_dataset_loader = accelerator.prepare(nnet, test_dataset_loader) 50 | logging.info(f'load nnet from {config.nnet_path}') 51 | accelerator.unwrap_model(nnet).load_state_dict(torch.load(config.nnet_path, map_location='cpu')) 52 | nnet.eval() 53 | 54 | def cfg_nnet(x, timesteps, context): 55 | _cond = nnet(x, timesteps, context=context) 56 | if config.sample.scale == 0: 57 | return _cond 58 | _empty_context = torch.tensor(dataset.empty_context, device=device) 59 | _empty_context = einops.repeat(_empty_context, 'L D -> B L D', B=x.size(0)) 60 | _uncond = nnet(x, timesteps, context=_empty_context) 61 | return _cond + config.sample.scale * (_cond - _uncond) 62 | 63 | autoencoder = libs.autoencoder.get_model(**config.autoencoder) 64 | autoencoder.to(device) 65 | 66 | @torch.cuda.amp.autocast() 67 | def encode(_batch): 68 | return autoencoder.encode(_batch) 69 | 70 | @torch.cuda.amp.autocast() 71 | def decode(_batch): 72 | return autoencoder.decode(_batch) 73 | 74 | def decode_large_batch(_batch): 75 | decode_mini_batch_size = 50 # use a small batch size since the decoder is large 76 | xs = [] 77 | pt = 0 78 | for _decode_mini_batch_size in utils.amortize(_batch.size(0), decode_mini_batch_size): 79 | x = decode(_batch[pt: pt + _decode_mini_batch_size]) 80 | pt += _decode_mini_batch_size 81 | xs.append(x) 82 | xs = torch.concat(xs, dim=0) 83 | assert xs.size(0) == _batch.size(0) 84 | return xs 85 | 86 | def get_context_generator(): 87 | while True: 88 | for data in test_dataset_loader: 89 | _, _context = data 90 | yield _context 91 | 92 | context_generator = get_context_generator() 93 | 94 | _betas = stable_diffusion_beta_schedule() 95 | N = len(_betas) 96 | 97 | logging.info(config.sample) 98 | assert os.path.exists(dataset.fid_stat) 99 | logging.info(f'sample: n_samples={config.sample.n_samples}, mode=t2i, mixed_precision={config.mixed_precision}') 100 | 101 | def dpm_solver_sample(_n_samples, _sample_steps, **kwargs): 102 | _z_init = torch.randn(_n_samples, *config.z_shape, device=device) 103 | noise_schedule = NoiseScheduleVP(schedule='discrete', betas=torch.tensor(_betas, device=device).float()) 104 | 105 | def model_fn(x, t_continuous): 106 | t = t_continuous * N 107 | return cfg_nnet(x, t, **kwargs) 108 | 109 | dpm_solver = DPM_Solver(model_fn, noise_schedule, predict_x0=True, thresholding=False) 110 | _z = dpm_solver.sample(_z_init, steps=_sample_steps, eps=1. / N, T=1.) 
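        # Sampling happens entirely in latent space: DPM-Solver integrates from T=1. down to
        # eps=1/N against the discrete stable-diffusion beta schedule (N=1000 steps), and the
        # resulting latents _z are decoded to images in mini-batches below.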
111 | return decode_large_batch(_z) 112 | 113 | def sample_fn(_n_samples): 114 | _context = next(context_generator) 115 | assert _context.size(0) == _n_samples 116 | return dpm_solver_sample(_n_samples, config.sample.sample_steps, context=_context) 117 | 118 | with tempfile.TemporaryDirectory() as temp_path: 119 | path = config.sample.path or temp_path 120 | if accelerator.is_main_process: 121 | os.makedirs(path, exist_ok=True) 122 | logging.info(f'Samples are saved in {path}') 123 | utils.sample2dir(accelerator, path, config.sample.n_samples, config.sample.mini_batch_size, sample_fn, dataset.unpreprocess) 124 | if accelerator.is_main_process: 125 | fid = calculate_fid_given_paths((dataset.fid_stat, path)) 126 | logging.info(f'nnet_path={config.nnet_path}, fid={fid}') 127 | 128 | 129 | from absl import flags 130 | from absl import app 131 | from ml_collections import config_flags 132 | import os 133 | 134 | 135 | FLAGS = flags.FLAGS 136 | config_flags.DEFINE_config_file( 137 | "config", None, "Training configuration.", lock_config=False) 138 | flags.mark_flags_as_required(["config"]) 139 | flags.DEFINE_string("nnet_path", None, "The nnet to evaluate.") 140 | flags.DEFINE_string("output_path", None, "The path to output log.") 141 | 142 | 143 | def main(argv): 144 | config = FLAGS.config 145 | config.nnet_path = FLAGS.nnet_path 146 | config.output_path = FLAGS.output_path 147 | evaluate(config) 148 | 149 | 150 | if __name__ == "__main__": 151 | app.run(main) 152 | -------------------------------------------------------------------------------- /evals/lm_harness_eval.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | import transformers 4 | from transformers import AutoTokenizer 5 | 6 | from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel 7 | 8 | from lm_eval.api.model import LM 9 | from lm_eval.models.huggingface import HFLM 10 | from lm_eval.api.registry import register_model 11 | from lm_eval.__main__ import cli_evaluate 12 | 13 | 14 | @register_model("mamba") 15 | class MambaEvalWrapper(HFLM): 16 | 17 | AUTO_MODEL_CLASS = transformers.AutoModelForCausalLM 18 | 19 | def __init__(self, pretrained="state-spaces/mamba-2.8b", max_length=2048, batch_size=None, device="cuda", 20 | dtype=torch.float16): 21 | LM.__init__(self) 22 | self._model = MambaLMHeadModel.from_pretrained(pretrained, device=device, dtype=dtype) 23 | self.tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b") 24 | self.tokenizer.pad_token_id = self.tokenizer.eos_token_id 25 | self.vocab_size = self.tokenizer.vocab_size 26 | self._batch_size = int(batch_size) if batch_size is not None else 64 27 | self._max_length = max_length 28 | self._device = torch.device(device) 29 | 30 | @property 31 | def batch_size(self): 32 | return self._batch_size 33 | 34 | def _model_generate(self, context, max_length, stop, **generation_kwargs): 35 | raise NotImplementedError() 36 | 37 | 38 | if __name__ == "__main__": 39 | cli_evaluate() 40 | -------------------------------------------------------------------------------- /libs/__init__.py: -------------------------------------------------------------------------------- 1 | # codes from third party 2 | -------------------------------------------------------------------------------- /libs/clip.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from transformers import CLIPTokenizer, CLIPTextModel 3 | 4 | 5 | class AbstractEncoder(nn.Module): 6 | def 
__init__(self): 7 | super().__init__() 8 | 9 | def encode(self, *args, **kwargs): 10 | raise NotImplementedError 11 | 12 | 13 | class FrozenCLIPEmbedder(AbstractEncoder): 14 | """Uses the CLIP transformer encoder for text (from Hugging Face)""" 15 | def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77): 16 | super().__init__() 17 | self.tokenizer = CLIPTokenizer.from_pretrained(version) 18 | self.transformer = CLIPTextModel.from_pretrained(version) 19 | self.device = device 20 | self.max_length = max_length 21 | self.freeze() 22 | 23 | def freeze(self): 24 | self.transformer = self.transformer.eval() 25 | for param in self.parameters(): 26 | param.requires_grad = False 27 | 28 | def forward(self, text): 29 | batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True, 30 | return_overflowing_tokens=False, padding="max_length", return_tensors="pt") 31 | tokens = batch_encoding["input_ids"].to(self.device) 32 | outputs = self.transformer(input_ids=tokens) 33 | 34 | z = outputs.last_hidden_state 35 | return z 36 | 37 | def encode(self, text): 38 | return self(text) 39 | -------------------------------------------------------------------------------- /libs/timm.py: -------------------------------------------------------------------------------- 1 | # code from timm 0.3.2 2 | import torch 3 | import torch.nn as nn 4 | import math 5 | import warnings 6 | 7 | 8 | def _no_grad_trunc_normal_(tensor, mean, std, a, b): 9 | # Cut & paste from PyTorch official master until it's in a few official releases - RW 10 | # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf 11 | def norm_cdf(x): 12 | # Computes standard normal cumulative distribution function 13 | return (1. + math.erf(x / math.sqrt(2.))) / 2. 14 | 15 | if (mean < a - 2 * std) or (mean > b + 2 * std): 16 | warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " 17 | "The distribution of values may be incorrect.", 18 | stacklevel=2) 19 | 20 | with torch.no_grad(): 21 | # Values are generated by using a truncated uniform distribution and 22 | # then using the inverse CDF for the normal distribution. 23 | # Get upper and lower cdf values 24 | l = norm_cdf((a - mean) / std) 25 | u = norm_cdf((b - mean) / std) 26 | 27 | # Uniformly fill tensor with values from [l, u], then translate to 28 | # [2l-1, 2u-1]. 29 | tensor.uniform_(2 * l - 1, 2 * u - 1) 30 | 31 | # Use inverse cdf transform for normal distribution to get truncated 32 | # standard normal 33 | tensor.erfinv_() 34 | 35 | # Transform to proper mean, std 36 | tensor.mul_(std * math.sqrt(2.)) 37 | tensor.add_(mean) 38 | 39 | # Clamp to ensure it's in the proper range 40 | tensor.clamp_(min=a, max=b) 41 | return tensor 42 | 43 | 44 | def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.): 45 | # type: (Tensor, float, float, float, float) -> Tensor 46 | r"""Fills the input Tensor with values drawn from a truncated 47 | normal distribution. The values are effectively drawn from the 48 | normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)` 49 | with values outside :math:`[a, b]` redrawn until they are within 50 | the bounds. The method used for generating the random values works 51 | best when :math:`a \leq \text{mean} \leq b`. 
52 | Args: 53 | tensor: an n-dimensional `torch.Tensor` 54 | mean: the mean of the normal distribution 55 | std: the standard deviation of the normal distribution 56 | a: the minimum cutoff value 57 | b: the maximum cutoff value 58 | Examples: 59 | >>> w = torch.empty(3, 5) 60 | >>> nn.init.trunc_normal_(w) 61 | """ 62 | return _no_grad_trunc_normal_(tensor, mean, std, a, b) 63 | 64 | 65 | def drop_path(x, drop_prob: float = 0., training: bool = False): 66 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 67 | 68 | This is the same as the DropConnect impl I created for EfficientNet, etc networks, however, 69 | the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... 70 | See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for 71 | changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 72 | 'survival rate' as the argument. 73 | 74 | """ 75 | if drop_prob == 0. or not training: 76 | return x 77 | keep_prob = 1 - drop_prob 78 | shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets 79 | random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) 80 | random_tensor.floor_() # binarize 81 | output = x.div(keep_prob) * random_tensor 82 | return output 83 | 84 | 85 | class DropPath(nn.Module): 86 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 87 | """ 88 | def __init__(self, drop_prob=None): 89 | super(DropPath, self).__init__() 90 | self.drop_prob = drop_prob 91 | 92 | def forward(self, x): 93 | return drop_path(x, self.drop_prob, self.training) 94 | 95 | 96 | class Mlp(nn.Module): 97 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 98 | super().__init__() 99 | out_features = out_features or in_features 100 | hidden_features = hidden_features or in_features 101 | self.fc1 = nn.Linear(in_features, hidden_features) 102 | self.act = act_layer() 103 | self.fc2 = nn.Linear(hidden_features, out_features) 104 | self.drop = nn.Dropout(drop) 105 | 106 | def forward(self, x): 107 | x = self.fc1(x) 108 | x = self.act(x) 109 | x = self.drop(x) 110 | x = self.fc2(x) 111 | x = self.drop(x) 112 | return x 113 | -------------------------------------------------------------------------------- /libs/uvit.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import math 4 | from .timm import trunc_normal_, Mlp 5 | import einops 6 | import torch.utils.checkpoint 7 | 8 | from mamba_attn_diff.utils.init_weights import _init_weights_mamba, pos_embed_inteplot 9 | 10 | if hasattr(torch.nn.functional, 'scaled_dot_product_attention'): 11 | ATTENTION_MODE = 'flash' 12 | else: 13 | try: 14 | import xformers 15 | import xformers.ops 16 | ATTENTION_MODE = 'xformers' 17 | except: 18 | ATTENTION_MODE = 'math' 19 | print(f'attention mode is {ATTENTION_MODE}') 20 | # ATTENTION_MODE = 'math' 21 | 22 | def timestep_embedding(timesteps, dim, max_period=10000): 23 | """ 24 | Create sinusoidal timestep embeddings. 25 | 26 | :param timesteps: a 1-D Tensor of N indices, one per batch element. 27 | These may be fractional. 28 | :param dim: the dimension of the output. 29 | :param max_period: controls the minimum frequency of the embeddings. 30 | :return: an [N x dim] Tensor of positional embeddings. 
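    A quick shape check (illustrative only; the values themselves depend on max_period):
        >>> emb = timestep_embedding(torch.tensor([0., 10., 999.]), dim=128)
        >>> tuple(emb.shape)
        (3, 128)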
31 | """ 32 | half = dim // 2 33 | freqs = torch.exp( 34 | -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half 35 | ).to(device=timesteps.device) 36 | args = timesteps[:, None].float() * freqs[None] 37 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) 38 | if dim % 2: 39 | embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) 40 | return embedding 41 | 42 | 43 | def patchify(imgs, patch_size): 44 | x = einops.rearrange(imgs, 'B C (h p1) (w p2) -> B (h w) (p1 p2 C)', p1=patch_size, p2=patch_size) 45 | return x 46 | 47 | 48 | def unpatchify(x, channels=3): 49 | patch_size = int((x.shape[2] // channels) ** 0.5) 50 | h = w = int(x.shape[1] ** .5) 51 | assert h * w == x.shape[1] and patch_size ** 2 * channels == x.shape[2] 52 | x = einops.rearrange(x, 'B (h w) (p1 p2 C) -> B C (h p1) (w p2)', h=h, p1=patch_size, p2=patch_size) 53 | return x 54 | 55 | 56 | class Attention(nn.Module): 57 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): 58 | super().__init__() 59 | self.num_heads = num_heads 60 | head_dim = dim // num_heads 61 | self.scale = qk_scale or head_dim ** -0.5 62 | 63 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 64 | self.attn_drop = nn.Dropout(attn_drop) 65 | self.proj = nn.Linear(dim, dim) 66 | self.proj_drop = nn.Dropout(proj_drop) 67 | 68 | def forward(self, x): 69 | B, L, C = x.shape 70 | 71 | qkv = self.qkv(x) 72 | if ATTENTION_MODE == 'flash': 73 | qkv = einops.rearrange(qkv, 'B L (K H D) -> K B H L D', K=3, H=self.num_heads).float() 74 | q, k, v = qkv[0], qkv[1], qkv[2] # B H L D 75 | x = torch.nn.functional.scaled_dot_product_attention(q, k, v) 76 | x = einops.rearrange(x, 'B H L D -> B L (H D)') 77 | elif ATTENTION_MODE == 'xformers': 78 | qkv = einops.rearrange(qkv, 'B L (K H D) -> K B L H D', K=3, H=self.num_heads) 79 | q, k, v = qkv[0], qkv[1], qkv[2] # B L H D 80 | x = xformers.ops.memory_efficient_attention(q, k, v) 81 | x = einops.rearrange(x, 'B L H D -> B L (H D)', H=self.num_heads) 82 | elif ATTENTION_MODE == 'math': 83 | qkv = einops.rearrange(qkv, 'B L (K H D) -> K B H L D', K=3, H=self.num_heads) 84 | q, k, v = qkv[0], qkv[1], qkv[2] # B H L D 85 | attn = (q @ k.transpose(-2, -1)) * self.scale 86 | attn = attn.softmax(dim=-1) 87 | attn = self.attn_drop(attn) 88 | x = (attn @ v).transpose(1, 2).reshape(B, L, C) 89 | else: 90 | raise NotImplemented 91 | 92 | x = self.proj(x) 93 | x = self.proj_drop(x) 94 | return x 95 | 96 | 97 | class Block(nn.Module): 98 | 99 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, 100 | act_layer=nn.GELU, norm_layer=nn.LayerNorm, skip=False, use_checkpoint=False): 101 | super().__init__() 102 | self.norm1 = norm_layer(dim) 103 | self.attn = Attention( 104 | dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale) 105 | self.norm2 = norm_layer(dim) 106 | mlp_hidden_dim = int(dim * mlp_ratio) 107 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer) 108 | self.skip_linear = nn.Linear(2 * dim, dim) if skip else None 109 | self.use_checkpoint = use_checkpoint 110 | 111 | def forward(self, x, skip=None): 112 | if self.use_checkpoint: 113 | return torch.utils.checkpoint.checkpoint(self._forward, x, skip) 114 | else: 115 | return self._forward(x, skip) 116 | 117 | def _forward(self, x, skip=None): 118 | if self.skip_linear is not None: 119 | x = self.skip_linear(torch.cat([x, skip], dim=-1)) 120 | x = x + self.attn(self.norm1(x)) 121 | x = x + 
self.mlp(self.norm2(x)) 122 | return x 123 | 124 | 125 | class PatchEmbed(nn.Module): 126 | """ Image to Patch Embedding 127 | """ 128 | def __init__(self, patch_size, in_chans=3, embed_dim=768): 129 | super().__init__() 130 | self.patch_size = patch_size 131 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) 132 | 133 | def forward(self, x): 134 | B, C, H, W = x.shape 135 | assert H % self.patch_size == 0 and W % self.patch_size == 0 136 | x = self.proj(x).flatten(2).transpose(1, 2) 137 | return x 138 | 139 | 140 | class UViT(nn.Module): 141 | def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4., 142 | qkv_bias=False, qk_scale=None, norm_layer=nn.LayerNorm, mlp_time_embed=False, num_classes=-1, 143 | use_checkpoint=False, conv=True, skip=True, is_last_double_channel=False, 144 | hook_intermediate_out=False, **kwargs): 145 | super().__init__() 146 | self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models 147 | self.num_classes = num_classes 148 | self.in_chans = in_chans 149 | self.hook_intermediate_out = hook_intermediate_out 150 | out_chans = in_chans * 3 if is_last_double_channel else in_chans 151 | 152 | self.patch_embed = PatchEmbed(patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) 153 | num_patches = (img_size // patch_size) ** 2 154 | self.num_patches = num_patches 155 | 156 | self.time_embed = nn.Sequential( 157 | nn.Linear(embed_dim, 4 * embed_dim), 158 | nn.SiLU(), 159 | nn.Linear(4 * embed_dim, embed_dim), 160 | ) if mlp_time_embed else nn.Identity() 161 | 162 | if self.num_classes > 0: 163 | self.label_emb = nn.Embedding(self.num_classes, embed_dim) 164 | self.extras = 2 165 | else: 166 | self.extras = 1 167 | 168 | self.pos_embed = nn.Parameter(torch.zeros(1, self.extras + num_patches, embed_dim)) 169 | 170 | self.in_blocks = nn.ModuleList([ 171 | Block( 172 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, 173 | norm_layer=norm_layer, use_checkpoint=use_checkpoint) 174 | for _ in range(depth // 2)]) 175 | 176 | self.mid_block = Block( 177 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, 178 | norm_layer=norm_layer, use_checkpoint=use_checkpoint) 179 | 180 | self.out_blocks = nn.ModuleList([ 181 | Block( 182 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, 183 | norm_layer=norm_layer, skip=skip, use_checkpoint=use_checkpoint) 184 | for _ in range(depth // 2)]) 185 | 186 | self.norm = norm_layer(embed_dim) 187 | self.patch_dim = patch_size ** 2 * in_chans 188 | self.decoder_pred = nn.Linear(embed_dim, self.patch_dim, bias=True) 189 | self.final_layer = nn.Conv2d(self.in_chans, out_chans, 3, padding=1) if conv else nn.Identity() 190 | 191 | trunc_normal_(self.pos_embed, std=.02) 192 | self.apply(self._init_weights) 193 | 194 | def _init_weights(self, m): 195 | if isinstance(m, nn.Linear): 196 | trunc_normal_(m.weight, std=.02) 197 | if isinstance(m, nn.Linear) and m.bias is not None: 198 | nn.init.constant_(m.bias, 0) 199 | elif isinstance(m, nn.LayerNorm): 200 | nn.init.constant_(m.bias, 0) 201 | nn.init.constant_(m.weight, 1.0) 202 | 203 | @torch.jit.ignore 204 | def no_weight_decay(self): 205 | return {'pos_embed'} 206 | 207 | def forward(self, x, timesteps, y=None, use_pos_inteplot=False): 208 | x = self.patch_embed(x) 209 | B, L, D = x.shape 210 | 211 | height, width = int(L**0.5), 
int(L**0.5) 212 | 213 | if self.hook_intermediate_out and self.training: 214 | intermediate_out = dict() 215 | intermediate_out['intermediate_outputs'] = [] 216 | layer_count = 0 217 | 218 | time_token = self.time_embed(timestep_embedding(timesteps, self.embed_dim)) 219 | time_token = time_token.unsqueeze(dim=1) 220 | x = torch.cat((time_token, x), dim=1) 221 | 222 | if y is not None: 223 | label_emb = self.label_emb(y) 224 | label_emb = label_emb.unsqueeze(dim=1) 225 | x = torch.cat((label_emb, x), dim=1) 226 | 227 | pos_embed = self.pos_embed 228 | if use_pos_inteplot: 229 | pos_embed = pos_embed_inteplot( 230 | cur_pos_embed=None, 231 | pretrained_pos_embed=pos_embed, 232 | extra_len=self.extras, 233 | cur_size=(height, width), 234 | ) 235 | 236 | x = x + pos_embed 237 | 238 | if self.hook_intermediate_out and self.training: 239 | intermediate_out['intermediate_outputs'].append(self.predict_head(x, L)) 240 | layer_count += 1 241 | 242 | skips = [] 243 | for blk in self.in_blocks: 244 | x = blk(x) 245 | skips.append(x) 246 | 247 | if self.hook_intermediate_out and self.training: 248 | intermediate_out['intermediate_outputs'].append(self.predict_head(x, L)) 249 | layer_count += 1 250 | 251 | x = self.mid_block(x) 252 | 253 | if self.hook_intermediate_out and self.training: 254 | intermediate_out['intermediate_outputs'].append(self.predict_head(x, L)) 255 | layer_count += 1 256 | 257 | for idx, blk in enumerate(self.out_blocks): 258 | x = blk(x, skips.pop()) 259 | 260 | if self.hook_intermediate_out and self.training and idx < len(self.out_blocks) - 1: 261 | intermediate_out['intermediate_outputs'].append(self.predict_head(x, L)) 262 | layer_count += 1 263 | 264 | x = self.predict_head(x, L) 265 | 266 | if self.hook_intermediate_out and self.training: 267 | intermediate_out['num_pred_layer'] = layer_count 268 | return x, intermediate_out 269 | else: 270 | return x 271 | 272 | def predict_head(self, x, L): 273 | x = self.norm(x) 274 | x = self.decoder_pred(x) 275 | x = x[:, -L:, :] 276 | x = unpatchify(x, self.in_chans) 277 | x = self.final_layer(x) 278 | return x 279 | -------------------------------------------------------------------------------- /libs/uvit_t2i.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import math 4 | from .timm import trunc_normal_, Mlp 5 | import einops 6 | import torch.utils.checkpoint 7 | 8 | if hasattr(torch.nn.functional, 'scaled_dot_product_attention'): 9 | ATTENTION_MODE = 'flash' 10 | else: 11 | try: 12 | import xformers 13 | import xformers.ops 14 | ATTENTION_MODE = 'xformers' 15 | except: 16 | ATTENTION_MODE = 'math' 17 | print(f'attention mode is {ATTENTION_MODE}') 18 | 19 | 20 | def timestep_embedding(timesteps, dim, max_period=10000): 21 | """ 22 | Create sinusoidal timestep embeddings. 23 | 24 | :param timesteps: a 1-D Tensor of N indices, one per batch element. 25 | These may be fractional. 26 | :param dim: the dimension of the output. 27 | :param max_period: controls the minimum frequency of the embeddings. 28 | :return: an [N x dim] Tensor of positional embeddings. 
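    For instance, with two timesteps and dim=64 the result has shape (2, 64):
    dim//2 cosine features followed by dim//2 sine features per timestep.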
29 | """ 30 | half = dim // 2 31 | freqs = torch.exp( 32 | -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half 33 | ).to(device=timesteps.device) 34 | args = timesteps[:, None].float() * freqs[None] 35 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) 36 | if dim % 2: 37 | embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) 38 | return embedding 39 | 40 | 41 | def patchify(imgs, patch_size): 42 | x = einops.rearrange(imgs, 'B C (h p1) (w p2) -> B (h w) (p1 p2 C)', p1=patch_size, p2=patch_size) 43 | return x 44 | 45 | 46 | def unpatchify(x, channels=3): 47 | patch_size = int((x.shape[2] // channels) ** 0.5) 48 | h = w = int(x.shape[1] ** .5) 49 | assert h * w == x.shape[1] and patch_size ** 2 * channels == x.shape[2] 50 | x = einops.rearrange(x, 'B (h w) (p1 p2 C) -> B C (h p1) (w p2)', h=h, p1=patch_size, p2=patch_size) 51 | return x 52 | 53 | 54 | class Attention(nn.Module): 55 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): 56 | super().__init__() 57 | self.num_heads = num_heads 58 | head_dim = dim // num_heads 59 | self.scale = qk_scale or head_dim ** -0.5 60 | 61 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 62 | self.attn_drop = nn.Dropout(attn_drop) 63 | self.proj = nn.Linear(dim, dim) 64 | self.proj_drop = nn.Dropout(proj_drop) 65 | 66 | def forward(self, x): 67 | B, L, C = x.shape 68 | 69 | qkv = self.qkv(x) 70 | if ATTENTION_MODE == 'flash': 71 | qkv = einops.rearrange(qkv, 'B L (K H D) -> K B H L D', K=3, H=self.num_heads).float() 72 | q, k, v = qkv[0], qkv[1], qkv[2] # B H L D 73 | x = torch.nn.functional.scaled_dot_product_attention(q, k, v) 74 | x = einops.rearrange(x, 'B H L D -> B L (H D)') 75 | elif ATTENTION_MODE == 'xformers': 76 | qkv = einops.rearrange(qkv, 'B L (K H D) -> K B L H D', K=3, H=self.num_heads) 77 | q, k, v = qkv[0], qkv[1], qkv[2] # B L H D 78 | x = xformers.ops.memory_efficient_attention(q, k, v) 79 | x = einops.rearrange(x, 'B L H D -> B L (H D)', H=self.num_heads) 80 | elif ATTENTION_MODE == 'math': 81 | qkv = einops.rearrange(qkv, 'B L (K H D) -> K B H L D', K=3, H=self.num_heads) 82 | q, k, v = qkv[0], qkv[1], qkv[2] # B H L D 83 | attn = (q @ k.transpose(-2, -1)) * self.scale 84 | attn = attn.softmax(dim=-1) 85 | attn = self.attn_drop(attn) 86 | x = (attn @ v).transpose(1, 2).reshape(B, L, C) 87 | else: 88 | raise NotImplemented 89 | 90 | x = self.proj(x) 91 | x = self.proj_drop(x) 92 | return x 93 | 94 | 95 | class Block(nn.Module): 96 | 97 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, 98 | act_layer=nn.GELU, norm_layer=nn.LayerNorm, skip=False, use_checkpoint=False): 99 | super().__init__() 100 | self.norm1 = norm_layer(dim) 101 | self.attn = Attention( 102 | dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale) 103 | self.norm2 = norm_layer(dim) 104 | mlp_hidden_dim = int(dim * mlp_ratio) 105 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer) 106 | self.skip_linear = nn.Linear(2 * dim, dim) if skip else None 107 | self.use_checkpoint = use_checkpoint 108 | 109 | def forward(self, x, skip=None): 110 | if self.use_checkpoint: 111 | return torch.utils.checkpoint.checkpoint(self._forward, x, skip) 112 | else: 113 | return self._forward(x, skip) 114 | 115 | def _forward(self, x, skip=None): 116 | if self.skip_linear is not None: 117 | x = self.skip_linear(torch.cat([x, skip], dim=-1)) 118 | x = x + self.attn(self.norm1(x)) 119 | x = x + 
self.mlp(self.norm2(x)) 120 | return x 121 | 122 | 123 | class PatchEmbed(nn.Module): 124 | """ Image to Patch Embedding 125 | """ 126 | def __init__(self, patch_size, in_chans=3, embed_dim=768): 127 | super().__init__() 128 | self.patch_size = patch_size 129 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) 130 | 131 | def forward(self, x): 132 | B, C, H, W = x.shape 133 | assert H % self.patch_size == 0 and W % self.patch_size == 0 134 | x = self.proj(x).flatten(2).transpose(1, 2) 135 | return x 136 | 137 | 138 | class UViT(nn.Module): 139 | def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4., 140 | qkv_bias=False, qk_scale=None, norm_layer=nn.LayerNorm, mlp_time_embed=False, use_checkpoint=False, 141 | clip_dim=768, num_clip_token=77, conv=True, skip=True): 142 | super().__init__() 143 | self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models 144 | self.in_chans = in_chans 145 | 146 | self.patch_embed = PatchEmbed(patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) 147 | num_patches = (img_size // patch_size) ** 2 148 | 149 | self.time_embed = nn.Sequential( 150 | nn.Linear(embed_dim, 4 * embed_dim), 151 | nn.SiLU(), 152 | nn.Linear(4 * embed_dim, embed_dim), 153 | ) if mlp_time_embed else nn.Identity() 154 | 155 | self.context_embed = nn.Linear(clip_dim, embed_dim) 156 | 157 | self.extras = 1 + num_clip_token 158 | 159 | self.pos_embed = nn.Parameter(torch.zeros(1, self.extras + num_patches, embed_dim)) 160 | 161 | self.in_blocks = nn.ModuleList([ 162 | Block( 163 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, 164 | norm_layer=norm_layer, use_checkpoint=use_checkpoint) 165 | for _ in range(depth // 2)]) 166 | 167 | self.mid_block = Block( 168 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, 169 | norm_layer=norm_layer, use_checkpoint=use_checkpoint) 170 | 171 | self.out_blocks = nn.ModuleList([ 172 | Block( 173 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, 174 | norm_layer=norm_layer, skip=skip, use_checkpoint=use_checkpoint) 175 | for _ in range(depth // 2)]) 176 | 177 | self.norm = norm_layer(embed_dim) 178 | self.patch_dim = patch_size ** 2 * in_chans 179 | self.decoder_pred = nn.Linear(embed_dim, self.patch_dim, bias=True) 180 | self.final_layer = nn.Conv2d(self.in_chans, self.in_chans, 3, padding=1) if conv else nn.Identity() 181 | 182 | trunc_normal_(self.pos_embed, std=.02) 183 | self.apply(self._init_weights) 184 | 185 | def _init_weights(self, m): 186 | if isinstance(m, nn.Linear): 187 | trunc_normal_(m.weight, std=.02) 188 | if isinstance(m, nn.Linear) and m.bias is not None: 189 | nn.init.constant_(m.bias, 0) 190 | elif isinstance(m, nn.LayerNorm): 191 | nn.init.constant_(m.bias, 0) 192 | nn.init.constant_(m.weight, 1.0) 193 | 194 | @torch.jit.ignore 195 | def no_weight_decay(self): 196 | return {'pos_embed'} 197 | 198 | def forward(self, x, timesteps, context): 199 | x = self.patch_embed(x) 200 | B, L, D = x.shape 201 | 202 | time_token = self.time_embed(timestep_embedding(timesteps, self.embed_dim)) 203 | time_token = time_token.unsqueeze(dim=1) 204 | context_token = self.context_embed(context) 205 | x = torch.cat((time_token, context_token, x), dim=1) 206 | x = x + self.pos_embed 207 | 208 | skips = [] 209 | for blk in self.in_blocks: 210 | x = blk(x) 211 | skips.append(x) 212 | 
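        # U-Net-style long skips: each encoder block's output pushed onto `skips` above is
        # popped in reverse by the decoder blocks and fused via skip_linear
        # (a Linear(2*dim, dim) over the channel-wise concatenation).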
213 | x = self.mid_block(x) 214 | 215 | for blk in self.out_blocks: 216 | x = blk(x, skips.pop()) 217 | 218 | x = self.norm(x) 219 | x = self.decoder_pred(x) 220 | assert x.size(1) == self.extras + L 221 | x = x[:, self.extras:, :] 222 | x = unpatchify(x, self.in_chans) 223 | x = self.final_layer(x) 224 | return x 225 | -------------------------------------------------------------------------------- /main.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tyshiwo1/DiM-DiffusionMamba/292b2285de2f979cf5ac84ab90cbee20cdf1f2f3/main.pdf -------------------------------------------------------------------------------- /main.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tyshiwo1/DiM-DiffusionMamba/292b2285de2f979cf5ac84ab90cbee20cdf1f2f3/main.png -------------------------------------------------------------------------------- /mamba_attn_diff/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tyshiwo1/DiM-DiffusionMamba/292b2285de2f979cf5ac84ab90cbee20cdf1f2f3/mamba_attn_diff/models/__init__.py -------------------------------------------------------------------------------- /mamba_attn_diff/models/freeu.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.fft import fftn, fftshift, ifftn, ifftshift 3 | 4 | 5 | def fourier_filter(x_in: "torch.Tensor", threshold: int, scale: int): 6 | """Fourier filter as introduced in FreeU (https://arxiv.org/abs/2309.11497). 7 | 8 | This version of the method comes from here: 9 | https://github.com/huggingface/diffusers/pull/5164#issuecomment-1732638706 10 | """ 11 | x = x_in 12 | B, L, C = x.shape 13 | 14 | # Non-power of 2 images must be float32 15 | if (L & (L - 1)) != 0 : 16 | x = x.to(dtype=torch.float32) 17 | 18 | # FFT 19 | x_freq = fftn(x, dim=1) 20 | x_freq = fftshift(x_freq, dim=1) 21 | 22 | B, L, C = x_freq.shape 23 | mask = torch.ones((B, L, C), device=x.device) 24 | 25 | crow = L // 2 26 | mask[..., crow - threshold : crow + threshold, :] = scale 27 | x_freq = x_freq * mask 28 | 29 | # IFFT 30 | x_freq = ifftshift(x_freq, dim=1) 31 | x_filtered = ifftn(x_freq, dim=1).real 32 | 33 | return x_filtered.to(dtype=x_in.dtype) 34 | 35 | 36 | def apply_freeu( 37 | resolution_idx, 38 | hidden_states, res_hidden_states, 39 | s1=0.6, s2=0.4, b1=1.1, b2=1.2, 40 | encoder_start_blk_id=1, 41 | num_layers=49, 42 | extra_len=2, 43 | ): 44 | """Applies the FreeU mechanism as introduced in https: 45 | //arxiv.org/abs/2309.11497. Adapted from the official code repository: https://github.com/ChenyangSi/FreeU. 46 | 47 | Args: 48 | resolution_idx (`int`): Integer denoting the UNet block where FreeU is being applied. 49 | hidden_states (`torch.Tensor`): Inputs to the underlying block. 50 | res_hidden_states (`torch.Tensor`): Features from the skip block corresponding to the underlying block. 51 | s1 (`float`): Scaling factor for stage 1 to attenuate the contributions of the skip features. 52 | s2 (`float`): Scaling factor for stage 2 to attenuate the contributions of the skip features. 53 | b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features. 54 | b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features. 
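        Returns:
            The (possibly rescaled) `hidden_states` and `res_hidden_states`.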
55 | 56 | pipe.enable_freeu(s1=0.6, s2=0.4, b1=1.1, b2=1.2) 57 | 58 | for i, up_block_type in enumerate(up_block_types): 59 | resolution_idx=i, 60 | """ 61 | # if resolution_idx == encoder_start_blk_id + (num_layers - encoder_start_blk_id)//2 + 0: 62 | # # print(resolution_idx) 63 | # num_half_channels = hidden_states.shape[-1] // 2 64 | # hidden_states[..., :num_half_channels] = hidden_states[..., :num_half_channels] * b1 65 | # res_hidden_states = fourier_filter(res_hidden_states, threshold=1, scale=s1, ) 66 | # elif resolution_idx == encoder_start_blk_id + (num_layers - encoder_start_blk_id)//2 + 1: 67 | # # print(resolution_idx) 68 | # num_half_channels = hidden_states.shape[-1] // 2 69 | # hidden_states[..., :num_half_channels] = hidden_states[..., :num_half_channels] * b2 70 | # res_hidden_states = fourier_filter(res_hidden_states, threshold=1, scale=s2, ) 71 | 72 | if resolution_idx == encoder_start_blk_id + (num_layers - encoder_start_blk_id)//2 + 0: 73 | s = s1 74 | b = b1 75 | elif resolution_idx <= encoder_start_blk_id + (num_layers - encoder_start_blk_id)//2 + 1: 76 | s = s2 77 | b = b2 78 | 79 | if resolution_idx >= encoder_start_blk_id + (num_layers - encoder_start_blk_id)//2 + 0 and \ 80 | resolution_idx <= encoder_start_blk_id + (num_layers - encoder_start_blk_id)//2 + 1 : 81 | 82 | hidden_mean = hidden_states[:, extra_len:, :].mean(-1).unsqueeze(-1) 83 | B = hidden_mean.shape[0] 84 | hidden_max, _ = torch.max(hidden_mean.view(B, -1), dim=-1) 85 | hidden_min, _ = torch.min(hidden_mean.view(B, -1), dim=-1) 86 | hidden_max = hidden_max.unsqueeze(-1).unsqueeze(-1) 87 | hidden_min = hidden_min.unsqueeze(-1).unsqueeze(-1) 88 | hidden_mean = (hidden_mean - hidden_min) / (hidden_max - hidden_min) 89 | 90 | 91 | hidden_states[:, extra_len:, :] = hidden_states[:, extra_len:, :] * ((b - 1 ) * hidden_mean + 1) 92 | res_hidden_states[:, extra_len:, :] = fourier_filter(res_hidden_states[:, extra_len:, :], threshold=1, scale=s, ) 93 | 94 | 95 | 96 | return hidden_states, res_hidden_states -------------------------------------------------------------------------------- /mamba_attn_diff/models/upsample_guidance.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from einops import rearrange 6 | 7 | # Upsample Guidance: Scale Up Diffusion Models without Training: https://arxiv.org/abs/2404.01709 8 | 9 | def model_wrapper(model, noise_schedule=None, is_cond_classifier=False, classifier_fn=None, classifier_scale=1., 10 | time_input_type='1', total_N=1000, model_kwargs={}): 11 | 12 | def get_model_input_time(t_continuous): 13 | """ 14 | Convert the continuous-time `t_continuous` (in [epsilon, T]) to the model input time. 15 | """ 16 | if time_input_type == '0': 17 | # discrete_type == '0' means that the model is continuous-time model. 18 | # For continuous-time DPMs, the continuous time equals to the discrete time. 19 | return t_continuous 20 | elif time_input_type == '1': 21 | # Type-1 discrete label, as detailed in the Appendix of DPM-Solver. 22 | return 1000. * torch.max(t_continuous - 1. / total_N, torch.zeros_like(t_continuous).to(t_continuous)) 23 | elif time_input_type == '2': 24 | # Type-2 discrete label, as detailed in the Appendix of DPM-Solver. 25 | max_N = (total_N - 1) / total_N * 1000. 
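            # e.g. with total_N = 1000, max_N = 999.0, so t_continuous = 1.0 maps to discrete step 999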
26 | return max_N * t_continuous 27 | else: 28 | raise ValueError("Unsupported time input type {}, must be '0' or '1' or '2'".format(time_input_type)) 29 | 30 | def cond_fn(x, t_discrete, y): 31 | """ 32 | Compute the gradient of the classifier, multiplied by the scale of the classifier guidance. 33 | """ 34 | assert y is not None 35 | with torch.enable_grad(): 36 | x_in = x.detach().requires_grad_(True) 37 | logits = classifier_fn(x_in, t_discrete) 38 | log_probs = F.log_softmax(logits, dim=-1) 39 | selected = log_probs[range(len(logits)), y.view(-1)] 40 | return classifier_scale * torch.autograd.grad(selected.sum(), x_in)[0] 41 | 42 | def model_fn(x, t_continuous): 43 | """ 44 | The noise prediction model function that is used for DPM-Solver. 45 | """ 46 | if is_cond_classifier: 47 | y = model_kwargs.get("y", None) 48 | if y is None: 49 | raise ValueError("For classifier guidance, the label y has to be in the input.") 50 | t_discrete = get_model_input_time(t_continuous) 51 | noise_uncond = model(x, t_discrete, **model_kwargs) 52 | noise_uncond = noise_uncond.sample if not isinstance(noise_uncond, torch.Tensor) else noise_uncond 53 | cond_grad = cond_fn(x, t_discrete, y) 54 | sigma_t = noise_schedule.marginal_std(t_continuous) 55 | dims = len(cond_grad.shape) - 1 56 | return noise_uncond - sigma_t[(...,) + (None,) * dims] * cond_grad 57 | else: 58 | t_discrete = get_model_input_time(t_continuous) 59 | model_output = model(x, t_discrete, **model_kwargs) 60 | model_output = model_output.sample if not isinstance(model_output, torch.Tensor) else model_output 61 | return model_output 62 | 63 | return model_fn 64 | 65 | def _init_taus(snrs, timesteps, m=2): 66 | scaled_snrs = snrs * m 67 | approx_snr, taus = torch.min( 68 | (scaled_snrs.reshape(-1, 1) < snrs.reshape(1, -1)) * snrs.reshape(1, -1), 69 | dim=-1, 70 | ) 71 | taus[approx_snr == 0] = 0 72 | taus = timesteps[taus] 73 | return taus 74 | 75 | def get_tau(preset_snrs, snr_func, t_continuous, m=2, return_indices=False): 76 | cur_scaled_snr = snr_func(t_continuous) * (m**2) 77 | cur_scaled_snr = cur_scaled_snr.reshape(-1, 1) 78 | preset_snrs = preset_snrs.reshape(1, -1) 79 | tau = (cur_scaled_snr - preset_snrs).abs().min(dim=-1).indices 80 | if return_indices: 81 | return tau 82 | 83 | tau = 1 - tau / (preset_snrs.shape[-1] - 1 ) 84 | tau = tau.clamp(max=t_continuous[0]) 85 | return tau 86 | 87 | def _init_timestep_idx_map(timesteps): 88 | timestep_idx_map = { t.item(): idx for idx, t in enumerate(timesteps) } 89 | return timestep_idx_map 90 | 91 | def make_ufg_nnet( 92 | cfg_model_func, uncfg_model_func, timesteps, alpha_t_func, beta_t_func, snr_func, 93 | N=999, m=2, 94 | ug_theta=1, ug_eta = 0.3, ug_T = 1, 95 | **kwargs, 96 | ): 97 | # print(timesteps, 'timesteps') 98 | preset_snrs = snr_func(timesteps) 99 | 100 | def ufg_nnet(x, t_continuous): 101 | 102 | tau_continuous = get_tau(preset_snrs, snr_func, t_continuous, m=m) 103 | t = t_continuous * N 104 | tau = tau_continuous * N 105 | 106 | cfg_pred = cfg_model_func(x, timestep=t, use_pos_inteplot=True, **kwargs) 107 | cfg_pred = cfg_pred.sample if not isinstance(cfg_pred, torch.Tensor) else cfg_pred 108 | 109 | w_t = append_dims(get_wt(t_continuous, theta=ug_theta, eta=ug_eta, T=ug_T), len(x.shape)) 110 | 111 | if (w_t > 0).any(): 112 | print([tau[0], t[0], alpha_t_func( t_continuous[0] ), beta_t_func( t_continuous[0] ), t_continuous], ug_theta, ug_eta, ug_T ) 113 | 114 | resized_pred_noise = get_resized_input_predicted_noise( 115 | cfg_model_func, x, 116 | alpha_t_func(
t_continuous ), 117 | beta_t_func( t_continuous ), 118 | tau, 119 | m=m, 120 | **kwargs, 121 | ) 122 | 123 | g_pred = cfg_pred + w_t * upsample_guidance( 124 | resized_pred_noise, 125 | cfg_pred, #uncon_pred, 126 | m=m, 127 | ) 128 | else: 129 | print(t_continuous ) 130 | g_pred = cfg_pred 131 | 132 | return g_pred 133 | 134 | return ufg_nnet 135 | 136 | def slide_denoise(model_func, x, t, m=2, y=None, **kwargs): 137 | x = rearrange(x, 'b c (p h) (q w) -> (b p q) c h w ', p=m, q=m) 138 | t = t.reshape(-1, 1).expand(-1, m*m).reshape(-1) 139 | y = y.reshape(-1, 1).expand(-1, m*m).reshape(-1) if y is not None else None 140 | model_output = model_func(x, timestep=t, y=y, **kwargs) 141 | model_output = model_output.sample if not isinstance(model_output, torch.Tensor) else model_output 142 | model_output = rearrange(model_output, '(b p q) c h w -> b c (p h) (q w)', p=m, q=m) 143 | return model_output 144 | 145 | def upsample_guidance(resized_input_predicted_noise, predicted_noise, m=2): 146 | g_pred = resize_func( 147 | resized_input_predicted_noise / m - resize_func( 148 | predicted_noise, scale=1/m, 149 | ), 150 | scale = m, 151 | ) 152 | return g_pred 153 | 154 | def get_resized_input_predicted_noise(model_func, x, alpha_t, beta_t, tau, m=2, **kwargs): 155 | p = get_P(alpha_t, beta_t, m=m) 156 | p = append_dims(p, len(x.shape)) 157 | x = resize_func(x, scale=1/m) / ( 158 | p ** 0.5 159 | ) 160 | model_output = model_func(x, timestep=tau, use_pos_inteplot=True, **kwargs) 161 | return model_output.sample if not isinstance(model_output, torch.Tensor) else model_output 162 | 163 | def get_wt(t, theta=1, eta=0.6, T=1): 164 | def h(x): 165 | return 1*(x >= 0) 166 | 167 | w_t = theta * h(t - (1 - eta)*T) 168 | return w_t 169 | 170 | def get_P(alpha_t, beta_t, m=2, ): 171 | return alpha_t + beta_t / (m**2) 172 | 173 | def resize_func(x, scale): 174 | if scale > 1: 175 | return F.interpolate( 176 | x, 177 | scale_factor=scale, 178 | mode='nearest', #align_corners=False, 179 | ) 180 | elif scale < 1: 181 | rscale = int(1 / scale) 182 | x = rearrange(x, 'b c (h p) (w q) -> b c h w (p q)', p=rscale, q=rscale ) 183 | x = x.mean(-1) 184 | return x 185 | else: 186 | return x 187 | 188 | def append_dims(x, new_shape_len): 189 | return x.reshape(*x.shape, *([1]*(new_shape_len - len(x.shape)))) -------------------------------------------------------------------------------- /mamba_attn_diff/utils/backup_code.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import glob 4 | import torch 5 | import shutil 6 | import logging 7 | import datetime 8 | from torch.utils.tensorboard import SummaryWriter 9 | 10 | UNSAVED_DIRS = ['outputs', 'checkpoint', 'checkpoints', 'workdir', 'build', '.git', '__pycache__', 'assets', 'samples'] 11 | 12 | def backup_code(work_dir, verbose=False): 13 | base_dir = './' #os.path.dirname(os.path.abspath(__file__)) 14 | 15 | dir_list = ["*.py", ] 16 | for file in os.listdir(base_dir): 17 | sub_dir = os.path.join(base_dir, file) 18 | if os.path.isdir(sub_dir): 19 | if file in UNSAVED_DIRS: 20 | continue 21 | 22 | for root, dirs, files in os.walk(sub_dir): 23 | for dir_name in dirs: 24 | dir_list.append(os.path.join(root, dir_name)+"/*.py") 25 | 26 | elif file.split('.')[-1] == 'py': 27 | pass 28 | 29 | # print(dir_list) 30 | 31 | for pattern in dir_list: 32 | for file in glob.glob(pattern): 33 | src = os.path.join(base_dir, file) 34 | dst = os.path.join(work_dir, 'backup', os.path.dirname(file)) 35 | # print(base_dir, src, 
dst) 36 | 37 | if verbose: 38 | logging.info('Copying %s -> %s' % (os.path.relpath(src), os.path.relpath(dst))) 39 | 40 | os.makedirs(dst, exist_ok=True) 41 | shutil.copy2(src, dst) 42 | 43 | 44 | -------------------------------------------------------------------------------- /mamba_attn_diff/utils/init_weights.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import math 4 | import numpy as np 5 | import torch.nn.functional as F 6 | 7 | from einops import rearrange 8 | 9 | def pos_embed_into_2d(pos_embed, extra_len): 10 | extra_tokens = pos_embed[:, :extra_len] 11 | pos_embed = pos_embed[:, extra_len:] 12 | num_patches = pos_embed.shape[-2] 13 | H = int(num_patches ** 0.5) 14 | W = num_patches // H 15 | assert H * W == num_patches 16 | pos_embed = pos_embed.reshape( 17 | pos_embed.shape[0], H, W, -1) 18 | return pos_embed, extra_tokens 19 | 20 | def pos_embed_inteplot(cur_pos_embed, pretrained_pos_embed, extra_len, cur_size=None): 21 | if cur_pos_embed is not None: 22 | cur_pos_embed, _ = pos_embed_into_2d(cur_pos_embed, extra_len) 23 | cur_size = cur_pos_embed.shape[-3:-1] 24 | 25 | ori_pretrained_pos_embed = pretrained_pos_embed 26 | pretrained_pos_embed, extra_tokens = pos_embed_into_2d(pretrained_pos_embed, extra_len) 27 | if pretrained_pos_embed.shape[-3:-1] == cur_size: 28 | return ori_pretrained_pos_embed 29 | 30 | pretrained_pos_embed = F.interpolate( 31 | rearrange(pretrained_pos_embed, 'b h w d -> b d h w'), 32 | size= cur_size, 33 | mode='bilinear', align_corners=False, 34 | ) 35 | pretrained_pos_embed = rearrange(pretrained_pos_embed, 'b d h w -> b (h w) d') 36 | pretrained_pos_embed = torch.cat([extra_tokens, pretrained_pos_embed], dim=-2) 37 | return pretrained_pos_embed 38 | 39 | def _init_weight_norm_fc_conv(model): 40 | for name, module in model.named_modules(): 41 | if isinstance_str(module, "Conv2d") and ("adapter" in name): 42 | module.__class__ = make_weight_norm_conv2d_nobias(module.__class__) 43 | module._init_scale() 44 | 45 | if isinstance_str(module, "Linear") and ("adapter" in name): 46 | module.__class__ = make_weight_norm_fc_nobias(module.__class__) 47 | module._init_scale() 48 | 49 | return model 50 | 51 | def make_weight_norm_conv2d_nobias(block_class): 52 | class WNConv2d(nn.Conv2d): 53 | def _init_scale(self): 54 | self.scale = nn.Parameter(self.weight.new_ones(*self.weight.shape[1:])) 55 | 56 | def forward(self, x): 57 | 58 | fan_in = self.weight[0].numel() 59 | weight = weight_normalize(self.weight) / np.sqrt(fan_in) * self.scale 60 | 61 | return F.conv2d(x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups) 62 | 63 | return WNConv2d 64 | 65 | def make_weight_norm_fc_nobias(block_class): 66 | class WNLinear(nn.Linear): 67 | def _init_scale(self): 68 | self.scale = nn.Parameter(self.weight.new_ones(*self.weight.shape[1:])) 69 | 70 | def forward(self, x): 71 | 72 | fan_in = self.weight[0].numel() 73 | weight = weight_normalize(self.weight) / np.sqrt(fan_in) * self.scale 74 | 75 | return F.linear(x, weight, self.bias) 76 | 77 | return WNLinear 78 | 79 | def weight_normalize(x, eps=1e-4): 80 | dim = list(range(1, x.ndim)) 81 | n = torch.linalg.vector_norm(x, dim=dim, keepdim=True) 82 | alpha = np.sqrt(n.numel() / x.numel()) 83 | return x / torch.add(eps, n, alpha=alpha) 84 | 85 | def _init_weights_mamba( 86 | module, 87 | n_layer, 88 | initializer_range=0.02, # Now only used for embedding layer. 
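# rescale_prenorm_residual: if True, apply the GPT-2-style rescaling of the
# residual-path projections (out_proj.weight / fc2.weight are divided by
# sqrt(n_residuals_per_layer * n_layer)), implemented in the loop below.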
89 | rescale_prenorm_residual=True, 90 | n_residuals_per_layer=1, # Change to 2 if we have MLP 91 | ): 92 | if isinstance(module, nn.Linear): 93 | if module.bias is not None: 94 | if not getattr(module.bias, "_no_reinit", False): 95 | nn.init.zeros_(module.bias) 96 | elif isinstance(module, nn.Embedding): 97 | nn.init.normal_(module.weight, std=initializer_range) 98 | 99 | if rescale_prenorm_residual: 100 | # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme: 101 | # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale 102 | # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers. 103 | # > -- GPT-2 :: https://openai.com/blog/better-language-models/ 104 | # 105 | # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py 106 | for name, p in module.named_parameters(): 107 | if name in ["out_proj.weight", "fc2.weight"]: 108 | # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block 109 | # Following Pytorch init, except scale by 1/sqrt(2 * n_layer) 110 | # We need to reinit p since this code could be called multiple times 111 | # Having just p *= scale would repeatedly scale it down 112 | nn.init.kaiming_uniform_(p, a=math.sqrt(5)) 113 | with torch.no_grad(): 114 | p /= math.sqrt(n_residuals_per_layer * n_layer) 115 | 116 | def init_embedders(model): 117 | 118 | for name, module in model.named_modules(): 119 | if 'class_embedder' in name.lower(): 120 | if isinstance(module, nn.Embedding): 121 | nn.init.normal_(module.weight, std=0.02) 122 | # print('class_embedder', module.weight.shape) 123 | elif 'timestep_embedder' in name.lower(): 124 | if isinstance(module, nn.Linear) : 125 | nn.init.normal_(module.weight, std=0.02) 126 | 127 | def init_adaLN_modulation_layers(model): 128 | 129 | for name, module in model.named_modules(): 130 | 131 | if 'blocks' in name.lower() and ('norm' in name.lower() or "adaln_single" in name.lower()): 132 | if isinstance(module, nn.Linear) : 133 | nn.init.constant_(module.weight, 0) 134 | nn.init.constant_(module.bias, 0) 135 | 136 | def initialize_weights(model): 137 | # Initialize transformer layers: 138 | for name, module in model.named_modules(): 139 | if 'mamba' in name.lower(): 140 | continue 141 | 142 | if 'embed' in name.lower(): 143 | continue 144 | 145 | if isinstance(module, nn.Linear): 146 | torch.nn.init.xavier_uniform_(module.weight) 147 | if module.bias is not None: 148 | nn.init.constant_(module.bias, 0) 149 | 150 | for name, module in model.named_modules(): 151 | if 'pos_embed.proj' in name.lower() or "proj_in" in name.lower(): 152 | if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d): 153 | w = module.weight.data 154 | nn.init.xavier_uniform_(w.view([w.shape[0], -1])) 155 | nn.init.constant_(module.bias, 0) 156 | 157 | init_embedders(model) 158 | init_adaLN_modulation_layers(model) 159 | 160 | for name, module in model.named_modules(): 161 | if 'out' in name.lower(): 162 | if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d): 163 | nn.init.constant_(module.weight, 0) 164 | if module.bias is not None: 165 | nn.init.constant_(module.bias, 0) 166 | 167 | def isinstance_str(x: object, cls_name: str): 168 | """ 169 | Checks whether x has any class *named* cls_name in its ancestry. 170 | Doesn't require access to the class's implementation. 171 | 172 | Useful for patching! 
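    Example (illustrative):
        isinstance_str(nn.Conv2d(3, 3, 1), "Conv2d")  # True, without importing the class itself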
173 | """ 174 | 175 | for _cls in x.__class__.__mro__: 176 | if _cls.__name__ == cls_name: 177 | return True 178 | 179 | return False -------------------------------------------------------------------------------- /mamba_ssm/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = "1.2.0.post1" 2 | 3 | from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn 4 | from mamba_ssm.modules.mamba_simple import Mamba 5 | from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel 6 | -------------------------------------------------------------------------------- /mamba_ssm/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tyshiwo1/DiM-DiffusionMamba/292b2285de2f979cf5ac84ab90cbee20cdf1f2f3/mamba_ssm/models/__init__.py -------------------------------------------------------------------------------- /mamba_ssm/models/config_mamba.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | 3 | 4 | @dataclass 5 | class MambaConfig: 6 | 7 | d_model: int = 2560 8 | n_layer: int = 64 9 | vocab_size: int = 50277 10 | ssm_cfg: dict = field(default_factory=dict) 11 | rms_norm: bool = True 12 | residual_in_fp32: bool = True 13 | fused_add_norm: bool = True 14 | pad_vocab_size_multiple: int = 8 15 | tie_embeddings: bool = True 16 | -------------------------------------------------------------------------------- /mamba_ssm/models/mixer_seq_simple.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023, Albert Gu, Tri Dao. 2 | 3 | import math 4 | from functools import partial 5 | import json 6 | import os 7 | 8 | from collections import namedtuple 9 | 10 | import torch 11 | import torch.nn as nn 12 | 13 | from mamba_ssm.models.config_mamba import MambaConfig 14 | from mamba_ssm.modules.mamba_simple import Mamba, Block 15 | from mamba_ssm.utils.generation import GenerationMixin 16 | from mamba_ssm.utils.hf import load_config_hf, load_state_dict_hf 17 | 18 | try: 19 | from mamba_ssm.ops.triton.layernorm import RMSNorm, layer_norm_fn, rms_norm_fn 20 | except ImportError: 21 | RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None 22 | 23 | 24 | def create_block( 25 | d_model, 26 | ssm_cfg=None, 27 | norm_epsilon=1e-5, 28 | rms_norm=False, 29 | residual_in_fp32=False, 30 | fused_add_norm=False, 31 | layer_idx=None, 32 | device=None, 33 | dtype=None, 34 | ): 35 | if ssm_cfg is None: 36 | ssm_cfg = {} 37 | factory_kwargs = {"device": device, "dtype": dtype} 38 | mixer_cls = partial(Mamba, layer_idx=layer_idx, **ssm_cfg, **factory_kwargs) 39 | norm_cls = partial( 40 | nn.LayerNorm if not rms_norm else RMSNorm, eps=norm_epsilon, **factory_kwargs 41 | ) 42 | block = Block( 43 | d_model, 44 | mixer_cls, 45 | norm_cls=norm_cls, 46 | fused_add_norm=fused_add_norm, 47 | residual_in_fp32=residual_in_fp32, 48 | ) 49 | block.layer_idx = layer_idx 50 | return block 51 | 52 | 53 | # https://github.com/huggingface/transformers/blob/c28d04e9e252a1a099944e325685f14d242ecdcd/src/transformers/models/gpt2/modeling_gpt2.py#L454 54 | def _init_weights( 55 | module, 56 | n_layer, 57 | initializer_range=0.02, # Now only used for embedding layer. 
58 | rescale_prenorm_residual=True, 59 | n_residuals_per_layer=1, # Change to 2 if we have MLP 60 | ): 61 | if isinstance(module, nn.Linear): 62 | if module.bias is not None: 63 | if not getattr(module.bias, "_no_reinit", False): 64 | nn.init.zeros_(module.bias) 65 | elif isinstance(module, nn.Embedding): 66 | nn.init.normal_(module.weight, std=initializer_range) 67 | 68 | if rescale_prenorm_residual: 69 | # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme: 70 | # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale 71 | # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers. 72 | # > -- GPT-2 :: https://openai.com/blog/better-language-models/ 73 | # 74 | # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py 75 | for name, p in module.named_parameters(): 76 | if name in ["out_proj.weight", "fc2.weight"]: 77 | # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block 78 | # Following Pytorch init, except scale by 1/sqrt(2 * n_layer) 79 | # We need to reinit p since this code could be called multiple times 80 | # Having just p *= scale would repeatedly scale it down 81 | nn.init.kaiming_uniform_(p, a=math.sqrt(5)) 82 | with torch.no_grad(): 83 | p /= math.sqrt(n_residuals_per_layer * n_layer) 84 | 85 | 86 | class MixerModel(nn.Module): 87 | def __init__( 88 | self, 89 | d_model: int, 90 | n_layer: int, 91 | vocab_size: int, 92 | ssm_cfg=None, 93 | norm_epsilon: float = 1e-5, 94 | rms_norm: bool = False, 95 | initializer_cfg=None, 96 | fused_add_norm=False, 97 | residual_in_fp32=False, 98 | device=None, 99 | dtype=None, 100 | ) -> None: 101 | factory_kwargs = {"device": device, "dtype": dtype} 102 | super().__init__() 103 | self.residual_in_fp32 = residual_in_fp32 104 | 105 | self.embedding = nn.Embedding(vocab_size, d_model, **factory_kwargs) 106 | 107 | # We change the order of residual and layer norm: 108 | # Instead of LN -> Attn / MLP -> Add, we do: 109 | # Add -> LN -> Attn / MLP / Mixer, returning both the residual branch (output of Add) and 110 | # the main branch (output of MLP / Mixer). The model definition is unchanged. 111 | # This is for performance reason: we can fuse add + layer_norm. 
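        # Concretely, the unfused reference path (used in forward() below when
        # fused_add_norm is False) is:
        #   residual = hidden_states + residual
        #   hidden_states = norm(residual)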
112 | self.fused_add_norm = fused_add_norm 113 | if self.fused_add_norm: 114 | if layer_norm_fn is None or rms_norm_fn is None: 115 | raise ImportError("Failed to import Triton LayerNorm / RMSNorm kernels") 116 | 117 | self.layers = nn.ModuleList( 118 | [ 119 | create_block( 120 | d_model, 121 | ssm_cfg=ssm_cfg, 122 | norm_epsilon=norm_epsilon, 123 | rms_norm=rms_norm, 124 | residual_in_fp32=residual_in_fp32, 125 | fused_add_norm=fused_add_norm, 126 | layer_idx=i, 127 | **factory_kwargs, 128 | ) 129 | for i in range(n_layer) 130 | ] 131 | ) 132 | 133 | self.norm_f = (nn.LayerNorm if not rms_norm else RMSNorm)( 134 | d_model, eps=norm_epsilon, **factory_kwargs 135 | ) 136 | 137 | self.apply( 138 | partial( 139 | _init_weights, 140 | n_layer=n_layer, 141 | **(initializer_cfg if initializer_cfg is not None else {}), 142 | ) 143 | ) 144 | 145 | def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs): 146 | return { 147 | i: layer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs) 148 | for i, layer in enumerate(self.layers) 149 | } 150 | 151 | def forward(self, input_ids, inference_params=None): 152 | hidden_states = self.embedding(input_ids) 153 | residual = None 154 | for layer in self.layers: 155 | hidden_states, residual = layer( 156 | hidden_states, residual, inference_params=inference_params 157 | ) 158 | if not self.fused_add_norm: 159 | residual = (hidden_states + residual) if residual is not None else hidden_states 160 | hidden_states = self.norm_f(residual.to(dtype=self.norm_f.weight.dtype)) 161 | else: 162 | # Set prenorm=False here since we don't need the residual 163 | fused_add_norm_fn = rms_norm_fn if isinstance(self.norm_f, RMSNorm) else layer_norm_fn 164 | hidden_states = fused_add_norm_fn( 165 | hidden_states, 166 | self.norm_f.weight, 167 | self.norm_f.bias, 168 | eps=self.norm_f.eps, 169 | residual=residual, 170 | prenorm=False, 171 | residual_in_fp32=self.residual_in_fp32, 172 | ) 173 | return hidden_states 174 | 175 | 176 | class MambaLMHeadModel(nn.Module, GenerationMixin): 177 | 178 | def __init__( 179 | self, 180 | config: MambaConfig, 181 | initializer_cfg=None, 182 | device=None, 183 | dtype=None, 184 | ) -> None: 185 | self.config = config 186 | d_model = config.d_model 187 | n_layer = config.n_layer 188 | vocab_size = config.vocab_size 189 | ssm_cfg = config.ssm_cfg 190 | rms_norm = config.rms_norm 191 | residual_in_fp32 = config.residual_in_fp32 192 | fused_add_norm = config.fused_add_norm 193 | pad_vocab_size_multiple = config.pad_vocab_size_multiple 194 | factory_kwargs = {"device": device, "dtype": dtype} 195 | 196 | super().__init__() 197 | if vocab_size % pad_vocab_size_multiple != 0: 198 | vocab_size += pad_vocab_size_multiple - (vocab_size % pad_vocab_size_multiple) 199 | self.backbone = MixerModel( 200 | d_model=d_model, 201 | n_layer=n_layer, 202 | vocab_size=vocab_size, 203 | ssm_cfg=ssm_cfg, 204 | rms_norm=rms_norm, 205 | initializer_cfg=initializer_cfg, 206 | fused_add_norm=fused_add_norm, 207 | residual_in_fp32=residual_in_fp32, 208 | **factory_kwargs, 209 | ) 210 | self.lm_head = nn.Linear(d_model, vocab_size, bias=False, **factory_kwargs) 211 | 212 | # Initialize weights and apply final processing 213 | self.apply( 214 | partial( 215 | _init_weights, 216 | n_layer=n_layer, 217 | **(initializer_cfg if initializer_cfg is not None else {}), 218 | ) 219 | ) 220 | self.tie_weights() 221 | 222 | def tie_weights(self): 223 | if self.config.tie_embeddings: 224 | self.lm_head.weight = 
self.backbone.embedding.weight 225 | 226 | def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs): 227 | return self.backbone.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs) 228 | 229 | def forward(self, input_ids, position_ids=None, inference_params=None, num_last_tokens=0): 230 | """ 231 | "position_ids" is just to be compatible with Transformer generation. We don't use it. 232 | num_last_tokens: if > 0, only return the logits for the last n tokens 233 | """ 234 | hidden_states = self.backbone(input_ids, inference_params=inference_params) 235 | if num_last_tokens > 0: 236 | hidden_states = hidden_states[:, -num_last_tokens:] 237 | lm_logits = self.lm_head(hidden_states) 238 | CausalLMOutput = namedtuple("CausalLMOutput", ["logits"]) 239 | return CausalLMOutput(logits=lm_logits) 240 | 241 | @classmethod 242 | def from_pretrained(cls, pretrained_model_name, device=None, dtype=None, **kwargs): 243 | config_data = load_config_hf(pretrained_model_name) 244 | config = MambaConfig(**config_data) 245 | model = cls(config, device=device, dtype=dtype, **kwargs) 246 | model.load_state_dict(load_state_dict_hf(pretrained_model_name, device=device, dtype=dtype)) 247 | return model 248 | 249 | def save_pretrained(self, save_directory): 250 | """ 251 | Minimal implementation of save_pretrained for MambaLMHeadModel. 252 | Save the model and its configuration file to a directory. 253 | """ 254 | # Ensure save_directory exists 255 | os.makedirs(save_directory, exist_ok=True) 256 | 257 | # Save the model's state_dict 258 | model_path = os.path.join(save_directory, 'pytorch_model.bin') 259 | torch.save(self.state_dict(), model_path) 260 | 261 | # Save the configuration of the model 262 | config_path = os.path.join(save_directory, 'config.json') 263 | with open(config_path, 'w') as f: 264 | json.dump(self.config.__dict__, f) 265 | -------------------------------------------------------------------------------- /mamba_ssm/modules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tyshiwo1/DiM-DiffusionMamba/292b2285de2f979cf5ac84ab90cbee20cdf1f2f3/mamba_ssm/modules/__init__.py -------------------------------------------------------------------------------- /mamba_ssm/ops/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tyshiwo1/DiM-DiffusionMamba/292b2285de2f979cf5ac84ab90cbee20cdf1f2f3/mamba_ssm/ops/__init__.py -------------------------------------------------------------------------------- /mamba_ssm/ops/triton/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tyshiwo1/DiM-DiffusionMamba/292b2285de2f979cf5ac84ab90cbee20cdf1f2f3/mamba_ssm/ops/triton/__init__.py -------------------------------------------------------------------------------- /mamba_ssm/ops/triton/selective_state_update.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023, Tri Dao. 
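# This file implements one recurrent decoding step of the selective SSM. In
# pseudo-code, per channel d and state n (after the optional dt_bias/softplus on dt,
# see selective_state_update_ref below for the exact reference):
#   state[d, n] <- exp(A[d, n] * dt[d]) * state[d, n] + dt[d] * B[n] * x[d]
#   out[d]       = sum_n state[d, n] * C[n]  (+ D[d] * x[d]; gated by silu(z[d]) if z is given)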
2 | 3 | """We want triton==2.1.0 for this 4 | """ 5 | 6 | import math 7 | import torch 8 | import torch.nn.functional as F 9 | 10 | import triton 11 | import triton.language as tl 12 | 13 | from einops import rearrange, repeat 14 | 15 | 16 | @triton.heuristics({"HAS_DT_BIAS": lambda args: args["dt_bias_ptr"] is not None}) 17 | @triton.heuristics({"HAS_D": lambda args: args["D_ptr"] is not None}) 18 | @triton.heuristics({"HAS_Z": lambda args: args["z_ptr"] is not None}) 19 | @triton.heuristics({"BLOCK_SIZE_DSTATE": lambda args: triton.next_power_of_2(args["dstate"])}) 20 | @triton.jit 21 | def _selective_scan_update_kernel( 22 | # Pointers to matrices 23 | state_ptr, x_ptr, dt_ptr, dt_bias_ptr, A_ptr, B_ptr, C_ptr, D_ptr, z_ptr, out_ptr, 24 | # Matrix dimensions 25 | batch, dim, dstate, 26 | # Strides 27 | stride_state_batch, stride_state_dim, stride_state_dstate, 28 | stride_x_batch, stride_x_dim, 29 | stride_dt_batch, stride_dt_dim, 30 | stride_dt_bias_dim, 31 | stride_A_dim, stride_A_dstate, 32 | stride_B_batch, stride_B_dstate, 33 | stride_C_batch, stride_C_dstate, 34 | stride_D_dim, 35 | stride_z_batch, stride_z_dim, 36 | stride_out_batch, stride_out_dim, 37 | # Meta-parameters 38 | DT_SOFTPLUS: tl.constexpr, 39 | BLOCK_SIZE_M: tl.constexpr, 40 | HAS_DT_BIAS: tl.constexpr, 41 | HAS_D: tl.constexpr, 42 | HAS_Z: tl.constexpr, 43 | BLOCK_SIZE_DSTATE: tl.constexpr, 44 | ): 45 | pid_m = tl.program_id(axis=0) 46 | pid_b = tl.program_id(axis=1) 47 | state_ptr += pid_b * stride_state_batch 48 | x_ptr += pid_b * stride_x_batch 49 | dt_ptr += pid_b * stride_dt_batch 50 | B_ptr += pid_b * stride_B_batch 51 | C_ptr += pid_b * stride_C_batch 52 | if HAS_Z: 53 | z_ptr += pid_b * stride_z_batch 54 | out_ptr += pid_b * stride_out_batch 55 | 56 | offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) 57 | offs_n = tl.arange(0, BLOCK_SIZE_DSTATE) 58 | state_ptrs = state_ptr + (offs_m[:, None] * stride_state_dim + offs_n[None, :] * stride_state_dstate) 59 | x_ptrs = x_ptr + offs_m * stride_x_dim 60 | dt_ptrs = dt_ptr + offs_m * stride_dt_dim 61 | if HAS_DT_BIAS: 62 | dt_bias_ptrs = dt_bias_ptr + offs_m * stride_dt_bias_dim 63 | A_ptrs = A_ptr + (offs_m[:, None] * stride_A_dim + offs_n[None, :] * stride_A_dstate) 64 | B_ptrs = B_ptr + offs_n * stride_B_dstate 65 | C_ptrs = C_ptr + offs_n * stride_C_dstate 66 | if HAS_D: 67 | D_ptrs = D_ptr + offs_m * stride_D_dim 68 | if HAS_Z: 69 | z_ptrs = z_ptr + offs_m * stride_z_dim 70 | out_ptrs = out_ptr + offs_m * stride_out_dim 71 | 72 | state = tl.load(state_ptrs, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate), other=0.0) 73 | x = tl.load(x_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) 74 | dt = tl.load(dt_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) 75 | if HAS_DT_BIAS: 76 | dt += tl.load(dt_bias_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) 77 | if DT_SOFTPLUS: 78 | dt = tl.where(dt <= 20.0, tl.math.log1p(tl.exp(dt)), dt) 79 | A = tl.load(A_ptrs, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate), other=0.0).to(tl.float32) 80 | dA = tl.exp(A * dt[:, None]) 81 | B = tl.load(B_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32) 82 | C = tl.load(C_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32) 83 | if HAS_D: 84 | D = tl.load(D_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) 85 | if HAS_Z: 86 | z = tl.load(z_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) 87 | 88 | dB = B[None, :] * dt[:, None] 89 | state = state * dA + dB * x[:, None] 90 | tl.store(state_ptrs, state, mask=(offs_m[:, None] < dim) & (offs_n[None, 
:] < dstate)) 91 | out = tl.sum(state * C[None, :], axis=1) 92 | if HAS_D: 93 | out += x * D 94 | if HAS_Z: 95 | out *= z * tl.sigmoid(z) 96 | tl.store(out_ptrs, out, mask=offs_m < dim) 97 | 98 | 99 | def selective_state_update(state, x, dt, A, B, C, D=None, z=None, dt_bias=None, dt_softplus=False): 100 | """ 101 | Argument: 102 | state: (batch, dim, dstate) 103 | x: (batch, dim) 104 | dt: (batch, dim) 105 | A: (dim, dstate) 106 | B: (batch, dstate) 107 | C: (batch, dstate) 108 | D: (dim,) 109 | z: (batch, dim) 110 | dt_bias: (dim,) 111 | Return: 112 | out: (batch, dim) 113 | """ 114 | batch, dim, dstate = state.shape 115 | assert x.shape == (batch, dim) 116 | assert dt.shape == x.shape 117 | assert A.shape == (dim, dstate) 118 | assert B.shape == (batch, dstate) 119 | assert C.shape == B.shape 120 | if D is not None: 121 | assert D.shape == (dim,) 122 | if z is not None: 123 | assert z.shape == x.shape 124 | if dt_bias is not None: 125 | assert dt_bias.shape == (dim,) 126 | out = torch.empty_like(x) 127 | grid = lambda META: (triton.cdiv(dim, META['BLOCK_SIZE_M']), batch) 128 | z_strides = ((z.stride(0), z.stride(1)) if z is not None else (0, 0)) 129 | # We don't want autotune since it will overwrite the state 130 | # We instead tune by hand. 131 | BLOCK_SIZE_M, num_warps = ((32, 4) if dstate <= 16 132 | else ((16, 4) if dstate <= 32 else 133 | ((8, 4) if dstate <= 64 else 134 | ((4, 4) if dstate <= 128 else 135 | ((4, 8)))))) 136 | with torch.cuda.device(x.device.index): 137 | _selective_scan_update_kernel[grid]( 138 | state, x, dt, dt_bias, A, B, C, D, z, out, 139 | batch, dim, dstate, 140 | state.stride(0), state.stride(1), state.stride(2), 141 | x.stride(0), x.stride(1), 142 | dt.stride(0), dt.stride(1), 143 | dt_bias.stride(0) if dt_bias is not None else 0, 144 | A.stride(0), A.stride(1), 145 | B.stride(0), B.stride(1), 146 | C.stride(0), C.stride(1), 147 | D.stride(0) if D is not None else 0, 148 | z_strides[0], z_strides[1], 149 | out.stride(0), out.stride(1), 150 | dt_softplus, 151 | BLOCK_SIZE_M, 152 | num_warps=num_warps, 153 | ) 154 | return out 155 | 156 | 157 | def selective_state_update_ref(state, x, dt, A, B, C, D=None, z=None, dt_bias=None, dt_softplus=False): 158 | """ 159 | Argument: 160 | state: (batch, dim, dstate) 161 | x: (batch, dim) 162 | dt: (batch, dim) 163 | A: (dim, dstate) 164 | B: (batch, dstate) 165 | C: (batch, dstate) 166 | D: (dim,) 167 | z: (batch, dim) 168 | dt_bias: (dim,) 169 | Return: 170 | out: (batch, dim) 171 | """ 172 | batch, dim, dstate = state.shape 173 | assert x.shape == (batch, dim) 174 | assert dt.shape == x.shape 175 | assert A.shape == (dim, dstate) 176 | assert B.shape == (batch, dstate) 177 | assert C.shape == B.shape 178 | if D is not None: 179 | assert D.shape == (dim,) 180 | if z is not None: 181 | assert z.shape == x.shape 182 | if dt_bias is not None: 183 | assert dt_bias.shape == (dim,) 184 | dt = dt + dt_bias 185 | dt = F.softplus(dt) if dt_softplus else dt 186 | dA = torch.exp(rearrange(dt, "b d -> b d 1") * A) # (batch, dim, dstate) 187 | dB = rearrange(dt, "b d -> b d 1") * rearrange(B, "b n -> b 1 n") # (batch, dim, dstate) 188 | state.copy_(state * dA + dB * rearrange(x, "b d -> b d 1")) # (batch, dim, dstate 189 | out = torch.einsum("bdn,bn->bd", state.to(C.dtype), C) 190 | if D is not None: 191 | out += (x * D).to(out.dtype) 192 | return (out if z is None else out * F.silu(z)).to(x.dtype) 193 | -------------------------------------------------------------------------------- /mamba_ssm/utils/__init__.py: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/tyshiwo1/DiM-DiffusionMamba/292b2285de2f979cf5ac84ab90cbee20cdf1f2f3/mamba_ssm/utils/__init__.py -------------------------------------------------------------------------------- /mamba_ssm/utils/hf.py: --------------------------------------------------------------------------------
1 | import json
2 | 
3 | import torch
4 | 
5 | from transformers.utils import WEIGHTS_NAME, CONFIG_NAME
6 | from transformers.utils.hub import cached_file
7 | 
8 | 
9 | def load_config_hf(model_name):
10 | resolved_archive_file = cached_file(model_name, CONFIG_NAME, _raise_exceptions_for_missing_entries=False)
11 | return json.load(open(resolved_archive_file))
12 | 
13 | 
14 | def load_state_dict_hf(model_name, device=None, dtype=None):
15 | # If not fp32, then we don't want to load directly to the GPU
16 | mapped_device = "cpu" if dtype not in [torch.float32, None] else device
17 | resolved_archive_file = cached_file(model_name, WEIGHTS_NAME, _raise_exceptions_for_missing_entries=False)
18 | state_dict = torch.load(resolved_archive_file, map_location=mapped_device)
19 | # Convert dtype before moving to GPU to save memory
20 | if dtype is not None:
21 | state_dict = {k: v.to(dtype=dtype) for k, v in state_dict.items()}
22 | state_dict = {k: v.to(device=device) for k, v in state_dict.items()}
23 | return state_dict
24 | -------------------------------------------------------------------------------- /sample_t2i_discrete.py: --------------------------------------------------------------------------------
1 | import ml_collections
2 | import torch
3 | from torch import multiprocessing as mp
4 | import accelerate
5 | import utils
6 | from datasets import get_dataset
7 | from dpm_solver_pp import NoiseScheduleVP, DPM_Solver
8 | from absl import logging
9 | import builtins
10 | import einops
11 | import libs.autoencoder
12 | import libs.clip
13 | from torchvision.utils import save_image
14 | 
15 | 
16 | def stable_diffusion_beta_schedule(linear_start=0.00085, linear_end=0.0120, n_timestep=1000):
17 | _betas = (
18 | torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2
19 | )
20 | return _betas.numpy()
21 | 
22 | 
23 | def evaluate(config):
24 | if config.get('benchmark', False):
25 | torch.backends.cudnn.benchmark = True
26 | torch.backends.cudnn.deterministic = False
27 | 
28 | mp.set_start_method('spawn')
29 | accelerator = accelerate.Accelerator()
30 | device = accelerator.device
31 | accelerate.utils.set_seed(config.seed, device_specific=True)
32 | logging.info(f'Process {accelerator.process_index} using device: {device}')
33 | 
34 | config.mixed_precision = accelerator.mixed_precision
35 | config = ml_collections.FrozenConfigDict(config)
36 | if accelerator.is_main_process:
37 | utils.set_logger(log_level='info')
38 | else:
39 | utils.set_logger(log_level='error')
40 | builtins.print = lambda *args: None
41 | 
42 | dataset = get_dataset(**config.dataset)
43 | 
44 | with open(config.input_path, 'r') as f:
45 | prompts = f.read().strip().split('\n')
46 | 
47 | print(prompts)
48 | 
49 | clip = libs.clip.FrozenCLIPEmbedder()
50 | clip.eval()
51 | clip.to(device)
52 | 
53 | contexts = clip.encode(prompts)
54 | 
55 | nnet = utils.get_nnet(**config.nnet)
56 | nnet = accelerator.prepare(nnet)
57 | logging.info(f'load nnet from {config.nnet_path}')
58 | accelerator.unwrap_model(nnet).load_state_dict(torch.load(config.nnet_path, map_location='cpu'))
59 | nnet.eval()
60 | 
61 | def cfg_nnet(x, 
timesteps, context): 62 | _cond = nnet(x, timesteps, context=context) 63 | if config.sample.scale == 0: 64 | return _cond 65 | _empty_context = torch.tensor(dataset.empty_context, device=device) 66 | _empty_context = einops.repeat(_empty_context, 'L D -> B L D', B=x.size(0)) 67 | _uncond = nnet(x, timesteps, context=_empty_context) 68 | return _cond + config.sample.scale * (_cond - _uncond) 69 | 70 | autoencoder = libs.autoencoder.get_model(config.autoencoder.pretrained_path) 71 | autoencoder.to(device) 72 | 73 | @torch.cuda.amp.autocast() 74 | def encode(_batch): 75 | return autoencoder.encode(_batch) 76 | 77 | @torch.cuda.amp.autocast() 78 | def decode(_batch): 79 | return autoencoder.decode(_batch) 80 | 81 | _betas = stable_diffusion_beta_schedule() 82 | N = len(_betas) 83 | 84 | logging.info(config.sample) 85 | logging.info(f'mixed_precision={config.mixed_precision}') 86 | logging.info(f'N={N}') 87 | 88 | z_init = torch.randn(contexts.size(0), *config.z_shape, device=device) 89 | noise_schedule = NoiseScheduleVP(schedule='discrete', betas=torch.tensor(_betas, device=device).float()) 90 | 91 | def model_fn(x, t_continuous): 92 | t = t_continuous * N 93 | return cfg_nnet(x, t, context=contexts) 94 | 95 | dpm_solver = DPM_Solver(model_fn, noise_schedule, predict_x0=True, thresholding=False) 96 | z = dpm_solver.sample(z_init, steps=config.sample.sample_steps, eps=1. / N, T=1.) 97 | samples = dataset.unpreprocess(decode(z)) 98 | 99 | os.makedirs(config.output_path, exist_ok=True) 100 | for sample, prompt in zip(samples, prompts): 101 | save_image(sample, os.path.join(config.output_path, f"{prompt}.png")) 102 | 103 | 104 | 105 | from absl import flags 106 | from absl import app 107 | from ml_collections import config_flags 108 | import os 109 | 110 | 111 | FLAGS = flags.FLAGS 112 | config_flags.DEFINE_config_file( 113 | "config", None, "Training configuration.", lock_config=False) 114 | flags.mark_flags_as_required(["config"]) 115 | flags.DEFINE_string("nnet_path", None, "The nnet to evaluate.") 116 | flags.DEFINE_string("output_path", None, "The path to output images.") 117 | flags.DEFINE_string("input_path", None, "The path to input texts.") 118 | 119 | 120 | def main(argv): 121 | config = FLAGS.config 122 | config.nnet_path = FLAGS.nnet_path 123 | config.output_path = FLAGS.output_path 124 | config.input_path = FLAGS.input_path 125 | evaluate(config) 126 | 127 | 128 | if __name__ == "__main__": 129 | app.run(main) 130 | -------------------------------------------------------------------------------- /scripts/extract_empty_feature.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import numpy as np 4 | import libs.autoencoder 5 | import libs.clip 6 | from datasets import MSCOCODatabase 7 | import argparse 8 | from tqdm import tqdm 9 | 10 | 11 | def main(): 12 | prompts = [ 13 | '', 14 | ] 15 | 16 | device = 'cuda' 17 | clip = libs.clip.FrozenCLIPEmbedder() 18 | clip.eval() 19 | clip.to(device) 20 | 21 | save_dir = f'assets/datasets/coco256_features' 22 | latent = clip.encode(prompts) 23 | print(latent.shape) 24 | c = latent[0].detach().cpu().numpy() 25 | np.save(os.path.join(save_dir, f'empty_context.npy'), c) 26 | 27 | 28 | if __name__ == '__main__': 29 | main() 30 | -------------------------------------------------------------------------------- /scripts/extract_imagenet_feature.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import numpy as np 3 | 
import torch 4 | from uvit_datasets import ImageNet 5 | from torch.utils.data import DataLoader 6 | from libs.autoencoder import get_model 7 | import argparse 8 | from tqdm import tqdm 9 | torch.manual_seed(0) 10 | np.random.seed(0) 11 | 12 | 13 | def main(resolution=256): 14 | parser = argparse.ArgumentParser() 15 | parser.add_argument('path') 16 | args = parser.parse_args() 17 | 18 | dataset = ImageNet(path=args.path, resolution=resolution, random_flip=False) 19 | train_dataset = dataset.get_split(split='train', labeled=True) 20 | train_dataset_loader = DataLoader(train_dataset, batch_size=256, shuffle=False, drop_last=False, 21 | num_workers=8, pin_memory=True, persistent_workers=True) 22 | 23 | model = get_model('assets/stable-diffusion/autoencoder_kl.pth') 24 | model = nn.DataParallel(model) 25 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") 26 | model.to(device) 27 | 28 | # features = [] 29 | # labels = [] 30 | 31 | idx = 0 32 | for batch in tqdm(train_dataset_loader): 33 | img, label = batch 34 | img = torch.cat([img, img.flip(dims=[-1])], dim=0) 35 | img = img.to(device) 36 | moments = model(img, fn='encode_moments') 37 | moments = moments.detach().cpu().numpy() 38 | 39 | label = torch.cat([label, label], dim=0) 40 | label = label.detach().cpu().numpy() 41 | 42 | for moment, lb in zip(moments, label): 43 | # np.save(f'assets/datasets/imagenet{resolution}_features/{idx}.npy', (moment, lb)) 44 | np.savez(f'assets/datasets/imagenet{resolution}_features/{idx}.npy', z=moment, label=lb) 45 | idx += 1 46 | 47 | print(f'save {idx} files') 48 | 49 | # features = np.concatenate(features, axis=0) 50 | # labels = np.concatenate(labels, axis=0) 51 | # print(f'features.shape={features.shape}') 52 | # print(f'labels.shape={labels.shape}') 53 | # np.save(f'imagenet{resolution}_features.npy', features) 54 | # np.save(f'imagenet{resolution}_labels.npy', labels) 55 | 56 | 57 | if __name__ == "__main__": 58 | main() 59 | -------------------------------------------------------------------------------- /scripts/extract_mscoco_feature.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import numpy as np 4 | import libs.autoencoder 5 | import libs.clip 6 | from datasets import MSCOCODatabase 7 | import argparse 8 | from tqdm import tqdm 9 | 10 | 11 | def main(resolution=256): 12 | parser = argparse.ArgumentParser() 13 | parser.add_argument('--split', default='train') 14 | args = parser.parse_args() 15 | print(args) 16 | 17 | 18 | if args.split == "train": 19 | datas = MSCOCODatabase(root='assets/datasets/coco/train2014', 20 | annFile='assets/datasets/coco/annotations/captions_train2014.json', 21 | size=resolution) 22 | save_dir = f'assets/datasets/coco{resolution}_features/train' 23 | elif args.split == "val": 24 | datas = MSCOCODatabase(root='assets/datasets/coco/val2014', 25 | annFile='assets/datasets/coco/annotations/captions_val2014.json', 26 | size=resolution) 27 | save_dir = f'assets/datasets/coco{resolution}_features/val' 28 | else: 29 | raise NotImplementedError("ERROR!") 30 | 31 | device = "cuda" 32 | os.makedirs(save_dir) 33 | 34 | autoencoder = libs.autoencoder.get_model('assets/stable-diffusion/autoencoder_kl.pth') 35 | autoencoder.to(device) 36 | clip = libs.clip.FrozenCLIPEmbedder() 37 | clip.eval() 38 | clip.to(device) 39 | 40 | with torch.no_grad(): 41 | for idx, data in tqdm(enumerate(datas)): 42 | x, captions = data 43 | 44 | if len(x.shape) == 3: 45 | x = x[None, ...] 
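            # the autoencoder expects a 4D batched input, hence the leading dim added above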
46 | x = torch.tensor(x, device=device) 47 | moments = autoencoder(x, fn='encode_moments').squeeze(0) 48 | moments = moments.detach().cpu().numpy() 49 | np.save(os.path.join(save_dir, f'{idx}.npy'), moments) 50 | 51 | latent = clip.encode(captions) 52 | for i in range(len(latent)): 53 | c = latent[i].detach().cpu().numpy() 54 | np.save(os.path.join(save_dir, f'{idx}_{i}.npy'), c) 55 | 56 | 57 | if __name__ == '__main__': 58 | main() 59 | -------------------------------------------------------------------------------- /scripts/extract_test_prompt_feature.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import numpy as np 4 | import libs.autoencoder 5 | import libs.clip 6 | from datasets import MSCOCODatabase 7 | import argparse 8 | from tqdm import tqdm 9 | 10 | 11 | def main(): 12 | prompts = [ 13 | 'A green train is coming down the tracks.', 14 | 'A group of skiers are preparing to ski down a mountain.', 15 | 'A small kitchen with a low ceiling.', 16 | 'A group of elephants walking in muddy water.', 17 | 'A living area with a television and a table.', 18 | 'A road with traffic lights, street lights and cars.', 19 | 'A bus driving in a city area with traffic signs.', 20 | 'A bus pulls over to the curb close to an intersection.', 21 | 'A group of people are walking and one is holding an umbrella.', 22 | 'A baseball player taking a swing at an incoming ball.', 23 | 'A city street line with brick buildings and trees.', 24 | 'A close up of a plate of broccoli and sauce.', 25 | ] 26 | 27 | device = 'cuda' 28 | clip = libs.clip.FrozenCLIPEmbedder() 29 | clip.eval() 30 | clip.to(device) 31 | 32 | save_dir = f'assets/datasets/coco256_features/run_vis' 33 | latent = clip.encode(prompts) 34 | for i in range(len(latent)): 35 | c = latent[i].detach().cpu().numpy() 36 | np.save(os.path.join(save_dir, f'{i}.npy'), (prompts[i], c)) 37 | 38 | 39 | if __name__ == '__main__': 40 | main() 41 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023, Albert Gu, Tri Dao. 
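# Build overview: unless MAMBA_FORCE_BUILD is set, installation first tries to
# download a prebuilt selective_scan_cuda wheel matching the local CUDA / torch /
# Python / C++ ABI combination (see get_wheel_url and CachedWheelsCommand below),
# and only falls back to compiling csrc/selective_scan from source.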
2 | import sys 3 | import warnings 4 | import os 5 | import re 6 | import ast 7 | from pathlib import Path 8 | from packaging.version import parse, Version 9 | import platform 10 | import shutil 11 | 12 | from setuptools import setup, find_packages 13 | import subprocess 14 | 15 | import urllib.request 16 | import urllib.error 17 | from wheel.bdist_wheel import bdist_wheel as _bdist_wheel 18 | 19 | import torch 20 | from torch.utils.cpp_extension import ( 21 | BuildExtension, 22 | CppExtension, 23 | CUDAExtension, 24 | CUDA_HOME, 25 | ) 26 | 27 | 28 | with open("README.md", "r", encoding="utf-8") as fh: 29 | long_description = fh.read() 30 | 31 | 32 | # ninja build does not work unless include_dirs are abs path 33 | this_dir = os.path.dirname(os.path.abspath(__file__)) 34 | 35 | PACKAGE_NAME = "mamba_ssm" 36 | 37 | BASE_WHEEL_URL = "https://github.com/state-spaces/mamba/releases/download/{tag_name}/{wheel_name}" 38 | 39 | # FORCE_BUILD: Force a fresh build locally, instead of attempting to find prebuilt wheels 40 | # SKIP_CUDA_BUILD: Intended to allow CI to use a simple `python setup.py sdist` run to copy over raw files, without any cuda compilation 41 | FORCE_BUILD = os.getenv("MAMBA_FORCE_BUILD", "FALSE") == "TRUE" 42 | SKIP_CUDA_BUILD = os.getenv("MAMBA_SKIP_CUDA_BUILD", "FALSE") == "TRUE" 43 | # For CI, we want the option to build with C++11 ABI since the nvcr images use C++11 ABI 44 | FORCE_CXX11_ABI = os.getenv("MAMBA_FORCE_CXX11_ABI", "FALSE") == "TRUE" 45 | 46 | 47 | def get_platform(): 48 | """ 49 | Returns the platform name as used in wheel filenames. 50 | """ 51 | if sys.platform.startswith("linux"): 52 | return "linux_x86_64" 53 | elif sys.platform == "darwin": 54 | mac_version = ".".join(platform.mac_ver()[0].split(".")[:2]) 55 | return f"macosx_{mac_version}_x86_64" 56 | elif sys.platform == "win32": 57 | return "win_amd64" 58 | else: 59 | raise ValueError("Unsupported platform: {}".format(sys.platform)) 60 | 61 | 62 | def get_cuda_bare_metal_version(cuda_dir): 63 | raw_output = subprocess.check_output( 64 | [cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True 65 | ) 66 | output = raw_output.split() 67 | release_idx = output.index("release") + 1 68 | bare_metal_version = parse(output[release_idx].split(",")[0]) 69 | 70 | return raw_output, bare_metal_version 71 | 72 | 73 | def check_if_cuda_home_none(global_option: str) -> None: 74 | if CUDA_HOME is not None: 75 | return 76 | # warn instead of error because user could be downloading prebuilt wheels, so nvcc won't be necessary 77 | # in that case. 78 | warnings.warn( 79 | f"{global_option} was requested, but nvcc was not found. Are you sure your environment has nvcc available? " 80 | "If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, " 81 | "only images whose names contain 'devel' will provide nvcc." 
82 | ) 83 | 84 | 85 | def append_nvcc_threads(nvcc_extra_args): 86 | return nvcc_extra_args + ["--threads", "4"] 87 | 88 | 89 | cmdclass = {} 90 | ext_modules = [] 91 | 92 | if not SKIP_CUDA_BUILD: 93 | print("\n\ntorch.__version__ = {}\n\n".format(torch.__version__)) 94 | TORCH_MAJOR = int(torch.__version__.split(".")[0]) 95 | TORCH_MINOR = int(torch.__version__.split(".")[1]) 96 | 97 | check_if_cuda_home_none(PACKAGE_NAME) 98 | # Check, if CUDA11 is installed for compute capability 8.0 99 | cc_flag = [] 100 | if CUDA_HOME is not None: 101 | _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME) 102 | if bare_metal_version < Version("11.6"): 103 | raise RuntimeError( 104 | f"{PACKAGE_NAME} is only supported on CUDA 11.6 and above. " 105 | "Note: make sure nvcc has a supported version by running nvcc -V." 106 | ) 107 | 108 | cc_flag.append("-gencode") 109 | cc_flag.append("arch=compute_53,code=sm_53") 110 | cc_flag.append("-gencode") 111 | cc_flag.append("arch=compute_62,code=sm_62") 112 | cc_flag.append("-gencode") 113 | cc_flag.append("arch=compute_70,code=sm_70") 114 | cc_flag.append("-gencode") 115 | cc_flag.append("arch=compute_72,code=sm_72") 116 | cc_flag.append("-gencode") 117 | cc_flag.append("arch=compute_80,code=sm_80") 118 | cc_flag.append("-gencode") 119 | cc_flag.append("arch=compute_87,code=sm_87") 120 | if bare_metal_version >= Version("11.8"): 121 | cc_flag.append("-gencode") 122 | cc_flag.append("arch=compute_90,code=sm_90") 123 | 124 | # HACK: The compiler flag -D_GLIBCXX_USE_CXX11_ABI is set to be the same as 125 | # torch._C._GLIBCXX_USE_CXX11_ABI 126 | # https://github.com/pytorch/pytorch/blob/8472c24e3b5b60150096486616d98b7bea01500b/torch/utils/cpp_extension.py#L920 127 | if FORCE_CXX11_ABI: 128 | torch._C._GLIBCXX_USE_CXX11_ABI = True 129 | 130 | ext_modules.append( 131 | CUDAExtension( 132 | name="selective_scan_cuda", 133 | sources=[ 134 | "csrc/selective_scan/selective_scan.cpp", 135 | "csrc/selective_scan/selective_scan_fwd_fp32.cu", 136 | "csrc/selective_scan/selective_scan_fwd_fp16.cu", 137 | "csrc/selective_scan/selective_scan_fwd_bf16.cu", 138 | "csrc/selective_scan/selective_scan_bwd_fp32_real.cu", 139 | "csrc/selective_scan/selective_scan_bwd_fp32_complex.cu", 140 | "csrc/selective_scan/selective_scan_bwd_fp16_real.cu", 141 | "csrc/selective_scan/selective_scan_bwd_fp16_complex.cu", 142 | "csrc/selective_scan/selective_scan_bwd_bf16_real.cu", 143 | "csrc/selective_scan/selective_scan_bwd_bf16_complex.cu", 144 | ], 145 | extra_compile_args={ 146 | "cxx": ["-O3", "-std=c++17"], 147 | "nvcc": append_nvcc_threads( 148 | [ 149 | "-O3", 150 | "-std=c++17", 151 | "-U__CUDA_NO_HALF_OPERATORS__", 152 | "-U__CUDA_NO_HALF_CONVERSIONS__", 153 | "-U__CUDA_NO_BFLOAT16_OPERATORS__", 154 | "-U__CUDA_NO_BFLOAT16_CONVERSIONS__", 155 | "-U__CUDA_NO_BFLOAT162_OPERATORS__", 156 | "-U__CUDA_NO_BFLOAT162_CONVERSIONS__", 157 | "--expt-relaxed-constexpr", 158 | "--expt-extended-lambda", 159 | "--use_fast_math", 160 | "--ptxas-options=-v", 161 | "-lineinfo", 162 | ] 163 | + cc_flag 164 | ), 165 | }, 166 | include_dirs=[Path(this_dir) / "csrc" / "selective_scan"], 167 | ) 168 | ) 169 | 170 | 171 | def get_package_version(): 172 | with open(Path(this_dir) / PACKAGE_NAME / "__init__.py", "r") as f: 173 | version_match = re.search(r"^__version__\s*=\s*(.*)$", f.read(), re.MULTILINE) 174 | public_version = ast.literal_eval(version_match.group(1)) 175 | local_version = os.environ.get("MAMBA_LOCAL_VERSION") 176 | if local_version: 177 | return f"{public_version}+{local_version}" 
178 | else: 179 | return str(public_version) 180 | 181 | 182 | def get_wheel_url(): 183 | # Determine the version numbers that will be used to determine the correct wheel 184 | # We're using the CUDA version used to build torch, not the one currently installed 185 | # _, cuda_version_raw = get_cuda_bare_metal_version(CUDA_HOME) 186 | torch_cuda_version = parse(torch.version.cuda) 187 | torch_version_raw = parse(torch.__version__) 188 | # For CUDA 11, we only compile for CUDA 11.8, and for CUDA 12 we only compile for CUDA 12.2 189 | # to save CI time. Minor versions should be compatible. 190 | torch_cuda_version = parse("11.8") if torch_cuda_version.major == 11 else parse("12.2") 191 | python_version = f"cp{sys.version_info.major}{sys.version_info.minor}" 192 | platform_name = get_platform() 193 | mamba_ssm_version = get_package_version() 194 | # cuda_version = f"{cuda_version_raw.major}{cuda_version_raw.minor}" 195 | cuda_version = f"{torch_cuda_version.major}{torch_cuda_version.minor}" 196 | torch_version = f"{torch_version_raw.major}.{torch_version_raw.minor}" 197 | cxx11_abi = str(torch._C._GLIBCXX_USE_CXX11_ABI).upper() 198 | 199 | # Determine wheel URL based on CUDA version, torch version, python version and OS 200 | wheel_filename = f"{PACKAGE_NAME}-{mamba_ssm_version}+cu{cuda_version}torch{torch_version}cxx11abi{cxx11_abi}-{python_version}-{python_version}-{platform_name}.whl" 201 | wheel_url = BASE_WHEEL_URL.format( 202 | tag_name=f"v{mamba_ssm_version}", wheel_name=wheel_filename 203 | ) 204 | return wheel_url, wheel_filename 205 | 206 | 207 | class CachedWheelsCommand(_bdist_wheel): 208 | """ 209 | The CachedWheelsCommand plugs into the default bdist wheel, which is ran by pip when it cannot 210 | find an existing wheel (which is currently the case for all installs). We use 211 | the environment parameters to detect whether there is already a pre-built version of a compatible 212 | wheel available and short-circuits the standard full build pipeline. 213 | """ 214 | 215 | def run(self): 216 | if FORCE_BUILD: 217 | return super().run() 218 | 219 | wheel_url, wheel_filename = get_wheel_url() 220 | print("Guessing wheel URL: ", wheel_url) 221 | try: 222 | urllib.request.urlretrieve(wheel_url, wheel_filename) 223 | 224 | # Make the archive 225 | # Lifted from the root wheel processing command 226 | # https://github.com/pypa/wheel/blob/cf71108ff9f6ffc36978069acb28824b44ae028e/src/wheel/bdist_wheel.py#LL381C9-L381C85 227 | if not os.path.exists(self.dist_dir): 228 | os.makedirs(self.dist_dir) 229 | 230 | impl_tag, abi_tag, plat_tag = self.get_tag() 231 | archive_basename = f"{self.wheel_dist_name}-{impl_tag}-{abi_tag}-{plat_tag}" 232 | 233 | wheel_path = os.path.join(self.dist_dir, archive_basename + ".whl") 234 | print("Raw wheel path", wheel_path) 235 | shutil.move(wheel_filename, wheel_path) 236 | except urllib.error.HTTPError: 237 | print("Precompiled wheel not found. 
Building from source...")
238 | # If the wheel could not be downloaded, build from source
239 | super().run()
240 | 
241 | 
242 | setup(
243 | name=PACKAGE_NAME,
244 | version=get_package_version(),
245 | packages=find_packages(
246 | exclude=(
247 | "build",
248 | "csrc",
249 | "include",
250 | "tests",
251 | "dist",
252 | "docs",
253 | "benchmarks",
254 | "mamba_ssm.egg-info",
255 | )
256 | ),
257 | author="Tri Dao, Albert Gu",
258 | author_email="tri@tridao.me, agu@cs.cmu.edu",
259 | description="Mamba state-space model",
260 | long_description=long_description,
261 | long_description_content_type="text/markdown",
262 | url="https://github.com/state-spaces/mamba",
263 | classifiers=[
264 | "Programming Language :: Python :: 3",
265 | "License :: OSI Approved :: BSD License",
266 | "Operating System :: Unix",
267 | ],
268 | ext_modules=ext_modules,
269 | cmdclass={"bdist_wheel": CachedWheelsCommand, "build_ext": BuildExtension}
270 | if ext_modules
271 | else {
272 | "bdist_wheel": CachedWheelsCommand,
273 | },
274 | python_requires=">=3.7",
275 | install_requires=[
276 | "torch",
277 | "packaging",
278 | "ninja",
279 | "einops",
280 | "triton",
281 | "transformers",
282 | # "causal_conv1d>=1.2.0",
283 | ],
284 | )
285 | -------------------------------------------------------------------------------- /tools/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tyshiwo1/DiM-DiffusionMamba/292b2285de2f979cf5ac84ab90cbee20cdf1f2f3/tools/__init__.py -------------------------------------------------------------------------------- /tools/fid_score.py: --------------------------------------------------------------------------------
1 | """Calculates the Frechet Inception Distance (FID) to evaluate GANs
2 | 
3 | The FID metric calculates the distance between two distributions of images.
4 | Typically, we have summary statistics (mean & covariance matrix) of one
5 | of these distributions, while the 2nd distribution is given by a GAN.
6 | 
7 | When run as a stand-alone program, it compares the distribution of
8 | images that are stored as PNG/JPEG at a specified location with a
9 | distribution given by summary statistics (in pickle format).
10 | 
11 | The FID is calculated by assuming that X_1 and X_2 are the activations of
12 | the pool_3 layer of the inception net for generated samples and real world
13 | samples respectively.
14 | 
15 | See --help to see further details.
16 | 
17 | Code adapted from https://github.com/bioinf-jku/TTUR to use PyTorch instead
18 | of Tensorflow
19 | 
20 | Copyright 2018 Institute of Bioinformatics, JKU Linz
21 | 
22 | Licensed under the Apache License, Version 2.0 (the "License");
23 | you may not use this file except in compliance with the License.
24 | You may obtain a copy of the License at
25 | 
26 | http://www.apache.org/licenses/LICENSE-2.0
27 | 
28 | Unless required by applicable law or agreed to in writing, software
29 | distributed under the License is distributed on an "AS IS" BASIS,
30 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
31 | See the License for the specific language governing permissions and
32 | limitations under the License. 
33 | """ 34 | import os 35 | import pathlib 36 | 37 | import numpy as np 38 | import torch 39 | import torchvision.transforms as TF 40 | from PIL import Image 41 | from scipy import linalg 42 | from torch.nn.functional import adaptive_avg_pool2d 43 | 44 | try: 45 | from tqdm import tqdm 46 | except ImportError: 47 | # If tqdm is not available, provide a mock version of it 48 | def tqdm(x): 49 | return x 50 | 51 | from .inception import InceptionV3 52 | 53 | 54 | IMAGE_EXTENSIONS = {'bmp', 'jpg', 'jpeg', 'pgm', 'png', 'ppm', 55 | 'tif', 'tiff', 'webp'} 56 | 57 | 58 | class ImagePathDataset(torch.utils.data.Dataset): 59 | def __init__(self, files, transforms=None): 60 | self.files = files 61 | self.transforms = transforms 62 | 63 | def __len__(self): 64 | return len(self.files) 65 | 66 | def __getitem__(self, i): 67 | path = self.files[i] 68 | img = Image.open(path).convert('RGB') 69 | if self.transforms is not None: 70 | img = self.transforms(img) 71 | return img 72 | 73 | 74 | def get_activations(files, model, batch_size=50, dims=2048, device='cpu', num_workers=8): 75 | """Calculates the activations of the pool_3 layer for all images. 76 | 77 | Params: 78 | -- files : List of image files paths 79 | -- model : Instance of inception model 80 | -- batch_size : Batch size of images for the model to process at once. 81 | Make sure that the number of samples is a multiple of 82 | the batch size, otherwise some samples are ignored. This 83 | behavior is retained to match the original FID score 84 | implementation. 85 | -- dims : Dimensionality of features returned by Inception 86 | -- device : Device to run calculations 87 | -- num_workers : Number of parallel dataloader workers 88 | 89 | Returns: 90 | -- A numpy array of dimension (num images, dims) that contains the 91 | activations of the given tensor when feeding inception with the 92 | query tensor. 93 | """ 94 | model.eval() 95 | 96 | if batch_size > len(files): 97 | print(('Warning: batch size is bigger than the data size. ' 98 | 'Setting batch size to data size')) 99 | batch_size = len(files) 100 | 101 | dataset = ImagePathDataset(files, transforms=TF.ToTensor()) 102 | dataloader = torch.utils.data.DataLoader(dataset, 103 | batch_size=batch_size, 104 | shuffle=False, 105 | drop_last=False, 106 | num_workers=num_workers) 107 | 108 | pred_arr = np.empty((len(files), dims)) 109 | 110 | start_idx = 0 111 | 112 | for batch in tqdm(dataloader): 113 | batch = batch.to(device) 114 | 115 | with torch.no_grad(): 116 | pred = model(batch)[0] 117 | 118 | # If model output is not scalar, apply global spatial average pooling. 119 | # This happens if you choose a dimensionality not equal 2048. 120 | if pred.size(2) != 1 or pred.size(3) != 1: 121 | pred = adaptive_avg_pool2d(pred, output_size=(1, 1)) 122 | 123 | pred = pred.squeeze(3).squeeze(2).cpu().numpy() 124 | 125 | pred_arr[start_idx:start_idx + pred.shape[0]] = pred 126 | 127 | start_idx = start_idx + pred.shape[0] 128 | 129 | return pred_arr 130 | 131 | 132 | def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6): 133 | """Numpy implementation of the Frechet Distance. 134 | The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1) 135 | and X_2 ~ N(mu_2, C_2) is 136 | d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)). 137 | 138 | Stable version by Dougal J. Sutherland. 139 | 140 | Params: 141 | -- mu1 : Numpy array containing the activations of a layer of the 142 | inception net (like returned by the function 'get_predictions') 143 | for generated samples. 
144 | -- mu2 : The sample mean over activations, precalculated on a 145 | representative data set. 146 | -- sigma1: The covariance matrix over activations for generated samples. 147 | -- sigma2: The covariance matrix over activations, precalculated on a 148 | representative data set. 149 | 150 | Returns: 151 | -- : The Frechet Distance. 152 | """ 153 | 154 | mu1 = np.atleast_1d(mu1) 155 | mu2 = np.atleast_1d(mu2) 156 | 157 | sigma1 = np.atleast_2d(sigma1) 158 | sigma2 = np.atleast_2d(sigma2) 159 | 160 | assert mu1.shape == mu2.shape, \ 161 | 'Training and test mean vectors have different lengths' 162 | assert sigma1.shape == sigma2.shape, \ 163 | 'Training and test covariances have different dimensions' 164 | 165 | diff = mu1 - mu2 166 | 167 | # Product might be almost singular 168 | covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False) 169 | if not np.isfinite(covmean).all(): 170 | msg = ('fid calculation produces singular product; ' 171 | 'adding %s to diagonal of cov estimates') % eps 172 | print(msg) 173 | offset = np.eye(sigma1.shape[0]) * eps 174 | covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset)) 175 | 176 | # Numerical error might give slight imaginary component 177 | if np.iscomplexobj(covmean): 178 | if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3): 179 | m = np.max(np.abs(covmean.imag)) 180 | raise ValueError('Imaginary component {}'.format(m)) 181 | covmean = covmean.real 182 | 183 | tr_covmean = np.trace(covmean) 184 | 185 | return (diff.dot(diff) + np.trace(sigma1) 186 | + np.trace(sigma2) - 2 * tr_covmean) 187 | 188 | 189 | def calculate_activation_statistics(files, model, batch_size=50, dims=2048, 190 | device='cpu', num_workers=8): 191 | """Calculation of the statistics used by the FID. 192 | Params: 193 | -- files : List of image file paths 194 | -- model : Instance of inception model 195 | -- batch_size : The images numpy array is split into batches with 196 | batch size batch_size. A reasonable batch size 197 | depends on the hardware. 198 | -- dims : Dimensionality of features returned by Inception 199 | -- device : Device to run calculations 200 | -- num_workers : Number of parallel dataloader workers 201 | 202 | Returns: 203 | -- mu : The mean over samples of the activations of the pool_3 layer of 204 | the inception model. 205 | -- sigma : The covariance matrix of the activations of the pool_3 layer of 206 | the inception model.
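Because `calculate_frechet_distance` implements the closed form d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)), it can be sanity-checked against diagonal covariances, where the matrix square root is elementwise. A minimal check, assuming the repo root is importable:

```python
import numpy as np
from tools.fid_score import calculate_frechet_distance

rng = np.random.default_rng(0)
mu1, mu2 = rng.normal(size=4), rng.normal(size=4)
d1, d2 = rng.uniform(0.5, 2.0, size=4), rng.uniform(0.5, 2.0, size=4)

# For diagonal covariances, sqrt(C1*C2) is just the elementwise square
# root of the product of the diagonals, giving a simple reference value.
expected = np.sum((mu1 - mu2) ** 2) + np.sum(d1 + d2 - 2.0 * np.sqrt(d1 * d2))
got = calculate_frechet_distance(mu1, np.diag(d1), mu2, np.diag(d2))
assert np.isclose(got, expected)
```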
207 | """ 208 | act = get_activations(files, model, batch_size, dims, device, num_workers) 209 | mu = np.mean(act, axis=0) 210 | sigma = np.cov(act, rowvar=False) 211 | return mu, sigma 212 | 213 | 214 | def compute_statistics_of_path(path, model, batch_size, dims, device, num_workers=8): 215 | if path.endswith('.npz'): 216 | with np.load(path) as f: 217 | m, s = f['mu'][:], f['sigma'][:] 218 | else: 219 | path = pathlib.Path(path) 220 | files = sorted([file for ext in IMAGE_EXTENSIONS 221 | for file in path.glob('*.{}'.format(ext))]) 222 | m, s = calculate_activation_statistics(files, model, batch_size, 223 | dims, device, num_workers) 224 | 225 | return m, s 226 | 227 | 228 | def save_statistics_of_path(path, out_path, device=None, batch_size=50, dims=2048, num_workers=8): 229 | if device is None: 230 | device = torch.device('cuda' if (torch.cuda.is_available()) else 'cpu') 231 | else: 232 | device = torch.device(device) 233 | block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims] 234 | model = InceptionV3([block_idx]).to(device) 235 | m1, s1 = compute_statistics_of_path(path, model, batch_size, dims, device, num_workers) 236 | np.savez(out_path, mu=m1, sigma=s1) 237 | 238 | 239 | def calculate_fid_given_paths(paths, device=None, batch_size=50, dims=2048, num_workers=8): 240 | """Calculates the FID of two paths""" 241 | if device is None: 242 | device = torch.device('cuda' if (torch.cuda.is_available()) else 'cpu') 243 | else: 244 | device = torch.device(device) 245 | 246 | for p in paths: 247 | if not os.path.exists(p): 248 | raise RuntimeError('Invalid path: %s' % p) 249 | 250 | block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims] 251 | 252 | model = InceptionV3([block_idx]).to(device) 253 | 254 | m1, s1 = compute_statistics_of_path(paths[0], model, batch_size, 255 | dims, device, num_workers) 256 | m2, s2 = compute_statistics_of_path(paths[1], model, batch_size, 257 | dims, device, num_workers) 258 | fid_value = calculate_frechet_distance(m1, s1, m2, s2) 259 | 260 | return fid_value 261 | -------------------------------------------------------------------------------- /tools/inception.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torchvision 5 | 6 | try: 7 | from torchvision.models.utils import load_state_dict_from_url 8 | except ImportError: 9 | from torch.utils.model_zoo import load_url as load_state_dict_from_url 10 | 11 | # Inception weights ported to Pytorch from 12 | # http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz 13 | FID_WEIGHTS_URL = 'https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth' # noqa: E501 14 | 15 | 16 | class InceptionV3(nn.Module): 17 | """Pretrained InceptionV3 network returning feature maps""" 18 | 19 | # Index of default block of inception to return, 20 | # corresponds to output of final average pooling 21 | DEFAULT_BLOCK_INDEX = 3 22 | 23 | # Maps feature dimensionality to their output blocks indices 24 | BLOCK_INDEX_BY_DIM = { 25 | 64: 0, # First max pooling features 26 | 192: 1, # Second max pooling featurs 27 | 768: 2, # Pre-aux classifier features 28 | 2048: 3 # Final average pooling features 29 | } 30 | 31 | def __init__(self, 32 | output_blocks=(DEFAULT_BLOCK_INDEX,), 33 | resize_input=True, 34 | normalize_input=True, 35 | requires_grad=False, 36 | use_fid_inception=True): 37 | """Build pretrained InceptionV3 38 | 39 | Parameters 40 | ---------- 41 | 
output_blocks : list of int 42 | Indices of blocks to return features of. Possible values are: 43 | - 0: corresponds to output of first max pooling 44 | - 1: corresponds to output of second max pooling 45 | - 2: corresponds to output which is fed to aux classifier 46 | - 3: corresponds to output of final average pooling 47 | resize_input : bool 48 | If true, bilinearly resizes input to width and height 299 before 49 | feeding input to model. As the network without fully connected 50 | layers is fully convolutional, it should be able to handle inputs 51 | of arbitrary size, so resizing might not be strictly needed 52 | normalize_input : bool 53 | If true, scales the input from range (0, 1) to the range the 54 | pretrained Inception network expects, namely (-1, 1) 55 | requires_grad : bool 56 | If true, parameters of the model require gradients. Possibly useful 57 | for finetuning the network 58 | use_fid_inception : bool 59 | If true, uses the pretrained Inception model used in Tensorflow's 60 | FID implementation. If false, uses the pretrained Inception model 61 | available in torchvision. The FID Inception model has different 62 | weights and a slightly different structure from torchvision's 63 | Inception model. If you want to compute FID scores, you are 64 | strongly advised to set this parameter to true to get comparable 65 | results. 66 | """ 67 | super(InceptionV3, self).__init__() 68 | 69 | self.resize_input = resize_input 70 | self.normalize_input = normalize_input 71 | self.output_blocks = sorted(output_blocks) 72 | self.last_needed_block = max(output_blocks) 73 | 74 | assert self.last_needed_block <= 3, \ 75 | 'Last possible output block index is 3' 76 | 77 | self.blocks = nn.ModuleList() 78 | 79 | if use_fid_inception: 80 | inception = fid_inception_v3() 81 | else: 82 | inception = _inception_v3(pretrained=True) 83 | 84 | # Block 0: input to maxpool1 85 | block0 = [ 86 | inception.Conv2d_1a_3x3, 87 | inception.Conv2d_2a_3x3, 88 | inception.Conv2d_2b_3x3, 89 | nn.MaxPool2d(kernel_size=3, stride=2) 90 | ] 91 | self.blocks.append(nn.Sequential(*block0)) 92 | 93 | # Block 1: maxpool1 to maxpool2 94 | if self.last_needed_block >= 1: 95 | block1 = [ 96 | inception.Conv2d_3b_1x1, 97 | inception.Conv2d_4a_3x3, 98 | nn.MaxPool2d(kernel_size=3, stride=2) 99 | ] 100 | self.blocks.append(nn.Sequential(*block1)) 101 | 102 | # Block 2: maxpool2 to aux classifier 103 | if self.last_needed_block >= 2: 104 | block2 = [ 105 | inception.Mixed_5b, 106 | inception.Mixed_5c, 107 | inception.Mixed_5d, 108 | inception.Mixed_6a, 109 | inception.Mixed_6b, 110 | inception.Mixed_6c, 111 | inception.Mixed_6d, 112 | inception.Mixed_6e, 113 | ] 114 | self.blocks.append(nn.Sequential(*block2)) 115 | 116 | # Block 3: aux classifier to final avgpool 117 | if self.last_needed_block >= 3: 118 | block3 = [ 119 | inception.Mixed_7a, 120 | inception.Mixed_7b, 121 | inception.Mixed_7c, 122 | nn.AdaptiveAvgPool2d(output_size=(1, 1)) 123 | ] 124 | self.blocks.append(nn.Sequential(*block3)) 125 | 126 | for param in self.parameters(): 127 | param.requires_grad = requires_grad 128 | 129 | def forward(self, inp): 130 | """Get Inception feature maps 131 | 132 | Parameters 133 | ---------- 134 | inp : torch.autograd.Variable 135 | Input tensor of shape Bx3xHxW. 
Values are expected to be in 136 | range (0, 1) 137 | 138 | Returns 139 | ------- 140 | List of torch.autograd.Variable, corresponding to the selected output 141 | block, sorted ascending by index 142 | """ 143 | outp = [] 144 | x = inp 145 | 146 | if self.resize_input: 147 | x = F.interpolate(x, 148 | size=(299, 299), 149 | mode='bilinear', 150 | align_corners=False) 151 | 152 | if self.normalize_input: 153 | x = 2 * x - 1 # Scale from range (0, 1) to range (-1, 1) 154 | 155 | for idx, block in enumerate(self.blocks): 156 | x = block(x) 157 | if idx in self.output_blocks: 158 | outp.append(x) 159 | 160 | if idx == self.last_needed_block: 161 | break 162 | 163 | return outp 164 | 165 | 166 | def _inception_v3(*args, **kwargs): 167 | """Wraps `torchvision.models.inception_v3` 168 | 169 | Skips default weight initialization if supported by torchvision version. 170 | See https://github.com/mseitzer/pytorch-fid/issues/28. 171 | """ 172 | try: 173 | version = tuple(map(int, torchvision.__version__.split('.')[:2])) 174 | except ValueError: 175 | # Just a caution against weird version strings 176 | version = (0,) 177 | 178 | if version >= (0, 6): 179 | kwargs['init_weights'] = False 180 | 181 | return torchvision.models.inception_v3(*args, **kwargs) 182 | 183 | 184 | def fid_inception_v3(): 185 | """Build pretrained Inception model for FID computation 186 | 187 | The Inception model for FID computation uses a different set of weights 188 | and has a slightly different structure than torchvision's Inception. 189 | 190 | This method first constructs torchvision's Inception and then patches the 191 | necessary parts that are different in the FID Inception model. 192 | """ 193 | inception = _inception_v3(num_classes=1008, 194 | aux_logits=False, 195 | pretrained=False) 196 | inception.Mixed_5b = FIDInceptionA(192, pool_features=32) 197 | inception.Mixed_5c = FIDInceptionA(256, pool_features=64) 198 | inception.Mixed_5d = FIDInceptionA(288, pool_features=64) 199 | inception.Mixed_6b = FIDInceptionC(768, channels_7x7=128) 200 | inception.Mixed_6c = FIDInceptionC(768, channels_7x7=160) 201 | inception.Mixed_6d = FIDInceptionC(768, channels_7x7=160) 202 | inception.Mixed_6e = FIDInceptionC(768, channels_7x7=192) 203 | inception.Mixed_7b = FIDInceptionE_1(1280) 204 | inception.Mixed_7c = FIDInceptionE_2(2048) 205 | 206 | state_dict = load_state_dict_from_url(FID_WEIGHTS_URL, progress=True) 207 | inception.load_state_dict(state_dict) 208 | return inception 209 | 210 | 211 | class FIDInceptionA(torchvision.models.inception.InceptionA): 212 | """InceptionA block patched for FID computation""" 213 | def __init__(self, in_channels, pool_features): 214 | super(FIDInceptionA, self).__init__(in_channels, pool_features) 215 | 216 | def forward(self, x): 217 | branch1x1 = self.branch1x1(x) 218 | 219 | branch5x5 = self.branch5x5_1(x) 220 | branch5x5 = self.branch5x5_2(branch5x5) 221 | 222 | branch3x3dbl = self.branch3x3dbl_1(x) 223 | branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) 224 | branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl) 225 | 226 | # Patch: Tensorflow's average pool does not use the padded zeros in 227 | # its average calculation 228 | branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, 229 | count_include_pad=False) 230 | branch_pool = self.branch_pool(branch_pool) 231 | 232 | outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool] 233 | return torch.cat(outputs, 1) 234 | 235 | 236 | class FIDInceptionC(torchvision.models.inception.InceptionC): 237 | """InceptionC block patched
for FID computation""" 238 | def __init__(self, in_channels, channels_7x7): 239 | super(FIDInceptionC, self).__init__(in_channels, channels_7x7) 240 | 241 | def forward(self, x): 242 | branch1x1 = self.branch1x1(x) 243 | 244 | branch7x7 = self.branch7x7_1(x) 245 | branch7x7 = self.branch7x7_2(branch7x7) 246 | branch7x7 = self.branch7x7_3(branch7x7) 247 | 248 | branch7x7dbl = self.branch7x7dbl_1(x) 249 | branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl) 250 | branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl) 251 | branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl) 252 | branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl) 253 | 254 | # Patch: Tensorflow's average pool does not use the padded zero's in 255 | # its average calculation 256 | branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, 257 | count_include_pad=False) 258 | branch_pool = self.branch_pool(branch_pool) 259 | 260 | outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool] 261 | return torch.cat(outputs, 1) 262 | 263 | 264 | class FIDInceptionE_1(torchvision.models.inception.InceptionE): 265 | """First InceptionE block patched for FID computation""" 266 | def __init__(self, in_channels): 267 | super(FIDInceptionE_1, self).__init__(in_channels) 268 | 269 | def forward(self, x): 270 | branch1x1 = self.branch1x1(x) 271 | 272 | branch3x3 = self.branch3x3_1(x) 273 | branch3x3 = [ 274 | self.branch3x3_2a(branch3x3), 275 | self.branch3x3_2b(branch3x3), 276 | ] 277 | branch3x3 = torch.cat(branch3x3, 1) 278 | 279 | branch3x3dbl = self.branch3x3dbl_1(x) 280 | branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) 281 | branch3x3dbl = [ 282 | self.branch3x3dbl_3a(branch3x3dbl), 283 | self.branch3x3dbl_3b(branch3x3dbl), 284 | ] 285 | branch3x3dbl = torch.cat(branch3x3dbl, 1) 286 | 287 | # Patch: Tensorflow's average pool does not use the padded zero's in 288 | # its average calculation 289 | branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, 290 | count_include_pad=False) 291 | branch_pool = self.branch_pool(branch_pool) 292 | 293 | outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool] 294 | return torch.cat(outputs, 1) 295 | 296 | 297 | class FIDInceptionE_2(torchvision.models.inception.InceptionE): 298 | """Second InceptionE block patched for FID computation""" 299 | def __init__(self, in_channels): 300 | super(FIDInceptionE_2, self).__init__(in_channels) 301 | 302 | def forward(self, x): 303 | branch1x1 = self.branch1x1(x) 304 | 305 | branch3x3 = self.branch3x3_1(x) 306 | branch3x3 = [ 307 | self.branch3x3_2a(branch3x3), 308 | self.branch3x3_2b(branch3x3), 309 | ] 310 | branch3x3 = torch.cat(branch3x3, 1) 311 | 312 | branch3x3dbl = self.branch3x3dbl_1(x) 313 | branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) 314 | branch3x3dbl = [ 315 | self.branch3x3dbl_3a(branch3x3dbl), 316 | self.branch3x3dbl_3b(branch3x3dbl), 317 | ] 318 | branch3x3dbl = torch.cat(branch3x3dbl, 1) 319 | 320 | # Patch: The FID Inception model uses max pooling instead of average 321 | # pooling. This is likely an error in this specific Inception 322 | # implementation, as other Inception models use average pooling here 323 | # (which matches the description in the paper). 
324 | branch_pool = F.max_pool2d(x, kernel_size=3, stride=1, padding=1) 325 | branch_pool = self.branch_pool(branch_pool) 326 | 327 | outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool] 328 | return torch.cat(outputs, 1) 329 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import sde 2 | import ml_collections 3 | import torch 4 | from torch import multiprocessing as mp 5 | from uvit_datasets import get_dataset 6 | from torchvision.utils import make_grid, save_image 7 | import utils 8 | import einops 9 | from torch.utils._pytree import tree_map 10 | import accelerate 11 | from torch.utils.data import DataLoader 12 | from tqdm.auto import tqdm 13 | from dpm_solver_pytorch import NoiseScheduleVP, model_wrapper, DPM_Solver 14 | import tempfile 15 | from tools.fid_score import calculate_fid_given_paths 16 | from absl import logging 17 | import builtins 18 | import os 19 | import wandb 20 | 21 | import time 22 | 23 | from accelerate import DistributedDataParallelKwargs 24 | 25 | from sde import get_sde 26 | from mamba_attn_diff.utils.backup_code import backup_code 27 | 28 | def train(config): 29 | if config.get('benchmark', False): 30 | torch.backends.cudnn.benchmark = True 31 | torch.backends.cudnn.deterministic = False 32 | 33 | ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True) 34 | mp.set_start_method('spawn') 35 | accelerator = accelerate.Accelerator() # kwargs_handlers=[ddp_kwargs] 36 | device = accelerator.device 37 | accelerate.utils.set_seed(config.seed, device_specific=True) 38 | logging.info(f'Process {accelerator.process_index} using device: {device}') 39 | 40 | config.mixed_precision = accelerator.mixed_precision 41 | config = ml_collections.FrozenConfigDict(config) 42 | 43 | assert config.train.batch_size % accelerator.num_processes == 0 44 | mini_batch_size = config.train.batch_size // accelerator.num_processes 45 | 46 | if accelerator.is_main_process: 47 | os.makedirs(config.ckpt_root, exist_ok=True) 48 | os.makedirs(config.sample_dir, exist_ok=True) 49 | accelerator.wait_for_everyone() 50 | if accelerator.is_main_process: 51 | workdir = config.workdir 52 | 53 | backup_code(os.path.abspath(workdir)) 54 | 55 | wandb.init(dir=os.path.abspath(workdir), 56 | project=f'uvit_{config.dataset.name}', config=config.to_dict(), 57 | name=config.hparams, job_type='train', mode='offline') 58 | utils.set_logger(log_level='info', fname=os.path.join(config.workdir, 'output.log')) 59 | logging.info(config) 60 | else: 61 | pass 62 | 63 | dataset = get_dataset(**config.dataset) 64 | assert os.path.exists(dataset.fid_stat) 65 | train_dataset = dataset.get_split(split='train', labeled=config.train.mode == 'cond') 66 | train_dataset_loader = DataLoader(train_dataset, batch_size=mini_batch_size, shuffle=True, drop_last=True, 67 | num_workers=8, pin_memory=True, persistent_workers=True) 68 | 69 | train_state = utils.initialize_train_state(config, device) 70 | if hasattr(train_state.nnet, 'enable_gradient_checkpointing') and config.get('gradient_checkpointing', False): 71 | train_state.nnet.enable_gradient_checkpointing() 72 | 73 | nnet, nnet_ema, optimizer, train_dataset_loader = accelerator.prepare( 74 | train_state.nnet, train_state.nnet_ema, train_state.optimizer, train_dataset_loader) 75 | lr_scheduler = train_state.lr_scheduler 76 | train_state.resume(config.ckpt_root) 77 | 78 | 79 | def get_data_generator(): 80 | while True: 81 | for data in 
tqdm(train_dataset_loader, disable=not accelerator.is_main_process, desc='epoch'): 82 | yield data 83 | 84 | data_generator = get_data_generator() 85 | 86 | 87 | # set the score_model to train 88 | scheduler_config = dict(name='vpsde') if not config.get('scheduler', False) else config.scheduler 89 | scheduler = get_sde(device=device, **scheduler_config) # sde.VPSDE() 90 | scheduler_ema = get_sde(device=device, **scheduler_config) 91 | score_model = sde.ScoreModel(nnet, pred=config.pred, sde=scheduler) 92 | score_model_ema = sde.ScoreModel(nnet_ema, pred=config.pred, sde=scheduler_ema) 93 | 94 | 95 | def train_step(_batch, iter_rate): 96 | _metrics = dict() 97 | optimizer.zero_grad() 98 | if config.train.mode == 'uncond': 99 | loss = sde.LSimple(score_model, _batch, pred=config.pred, iter_rate=iter_rate) 100 | elif config.train.mode == 'cond': 101 | loss = sde.LSimple(score_model, _batch[0], pred=config.pred, y=_batch[1], iter_rate=iter_rate) 102 | else: 103 | raise NotImplementedError(config.train.mode) 104 | _metrics['loss'] = accelerator.gather(loss.detach()).mean() 105 | accelerator.backward(loss.mean()) 106 | if 'grad_clip' in config and config.grad_clip > 0: 107 | accelerator.clip_grad_norm_(nnet.parameters(), max_norm=config.grad_clip) 108 | optimizer.step() 109 | lr_scheduler.step() 110 | train_state.ema_update(config.get('ema_rate', 0.9999)) 111 | train_state.step += 1 112 | return dict(lr=train_state.optimizer.param_groups[0]['lr'], **_metrics) 113 | 114 | 115 | def eval_step(n_samples, sample_steps, algorithm): 116 | logging.info(f'eval_step: n_samples={n_samples}, sample_steps={sample_steps}, algorithm={algorithm}, ' 117 | f'mini_batch_size={config.sample.mini_batch_size}') 118 | 119 | def sample_fn(_n_samples): 120 | _x_init = torch.randn(_n_samples, *dataset.data_shape, device=device) 121 | if config.train.mode == 'uncond': 122 | kwargs = dict() 123 | elif config.train.mode == 'cond': 124 | kwargs = dict(y=dataset.sample_label(_n_samples, device=device)) 125 | else: 126 | raise NotImplementedError 127 | 128 | if algorithm == 'euler_maruyama_sde': 129 | return sde.euler_maruyama(sde.ReverseSDE(score_model_ema), _x_init, sample_steps, **kwargs) 130 | elif algorithm == 'euler_maruyama_ode': 131 | return sde.euler_maruyama(sde.ODE(score_model_ema), _x_init, sample_steps, **kwargs) 132 | elif algorithm == 'dpm_solver': 133 | noise_schedule = NoiseScheduleVP(schedule='linear') 134 | model_fn = model_wrapper( 135 | score_model_ema.noise_pred, 136 | noise_schedule, 137 | time_input_type='0', 138 | model_kwargs=kwargs 139 | ) 140 | dpm_solver = DPM_Solver(model_fn, noise_schedule) 141 | return dpm_solver.sample( 142 | _x_init, 143 | steps=sample_steps, 144 | eps=1e-4, 145 | adaptive_step_size=False, 146 | fast_version=True, 147 | ) 148 | elif algorithm in ['edm', 'ddim', 'ddpm']: 149 | return sde.diffusers_denoising( 150 | score_model_ema, scheduler.noise_scheduler, _x_init, 151 | config.sample.sample_steps, 152 | device=device, **kwargs) 153 | else: 154 | raise NotImplementedError 155 | 156 | with tempfile.TemporaryDirectory() as temp_path: 157 | path = config.sample.path or temp_path 158 | if accelerator.is_main_process: 159 | os.makedirs(path, exist_ok=True) 160 | utils.sample2dir(accelerator, path, n_samples, config.sample.mini_batch_size, sample_fn, dataset.unpreprocess) 161 | 162 | _fid = 0 163 | if accelerator.is_main_process: 164 | _fid = calculate_fid_given_paths((dataset.fid_stat, path)) 165 | logging.info(f'step={train_state.step} fid{n_samples}={_fid}') 166 | with 
open(os.path.join(config.workdir, 'eval.log'), 'a') as f: 167 | print(f'step={train_state.step} fid{n_samples}={_fid}', file=f) 168 | wandb.log({f'fid{n_samples}': _fid}, step=train_state.step) 169 | _fid = torch.tensor(_fid, device=device) 170 | _fid = accelerator.reduce(_fid, reduction='sum') 171 | 172 | return _fid.item() 173 | 174 | logging.info(f'Start fitting, step={train_state.step}, mixed_precision={config.mixed_precision}') 175 | 176 | step_fid = [] 177 | while train_state.step < config.train.n_steps: 178 | nnet.train() 179 | batch = tree_map(lambda x: x.to(device), next(data_generator)) 180 | metrics = train_step(batch, train_state.step*1. / config.train.n_steps) 181 | 182 | nnet.eval() 183 | if accelerator.is_main_process and train_state.step % config.train.log_interval == 0: 184 | logging.info(utils.dct2str(dict(step=train_state.step, **metrics))) 185 | logging.info(config.workdir) 186 | wandb.log(metrics, step=train_state.step) 187 | 188 | if train_state.step % config.train.eval_interval == 0: 189 | # torch.cuda.empty_cache() 190 | if accelerator.is_main_process: 191 | logging.info('Save a grid of images...') 192 | 193 | with torch.no_grad(): 194 | x_init = torch.randn(100, *dataset.data_shape, device=device) 195 | if config.train.mode == 'cond': 196 | y = einops.repeat(torch.arange(10, device=device) % dataset.K, 'nrow -> (nrow ncol)', ncol=10) 197 | 198 | if config.sample.algorithm in ['edm', 'ddim', 'ddpm']: 199 | if config.train.mode == 'uncond': 200 | samples = sde.diffusers_denoising( 201 | score_model_ema, scheduler.noise_scheduler, x_init=x_init, 202 | sample_steps=config.sample.sample_steps, device=device,) 203 | elif config.train.mode == 'cond': 204 | samples = sde.diffusers_denoising( 205 | score_model_ema, scheduler.noise_scheduler, x_init=x_init, 206 | sample_steps=config.sample.sample_steps, device=device, y=y, 207 | do_classifier_free_guidance=True, cfg_weight=1.5) 208 | else: 209 | if config.train.mode == 'uncond': 210 | samples = sde.euler_maruyama(sde.ODE(score_model_ema), x_init=x_init, sample_steps=50) 211 | elif config.train.mode == 'cond': 212 | samples = sde.euler_maruyama(sde.ODE(score_model_ema), x_init=x_init, sample_steps=50, y=y) 213 | else: 214 | raise NotImplementedError 215 | 216 | if accelerator.is_main_process: 217 | samples = make_grid(dataset.unpreprocess(samples), 10) 218 | save_image(samples, os.path.join(config.sample_dir, f'{train_state.step}.png')) 219 | wandb.log({'samples': wandb.Image(samples)}, step=train_state.step) 220 | else: 221 | pass 222 | 223 | accelerator.wait_for_everyone() 224 | 225 | if train_state.step % config.train.save_interval == 0 or train_state.step == config.train.n_steps: 226 | logging.info(f'Save and eval checkpoint {train_state.step}...') 227 | if accelerator.local_process_index == 0: 228 | train_state.save(os.path.join(config.ckpt_root, f'{train_state.step}.ckpt')) 229 | accelerator.wait_for_everyone() 230 | fid = eval_step(n_samples=10000, sample_steps=50, 231 | algorithm=(config.sample.algorithm if config.get('scheduler', False) else 'dpm_solver') ) # calculate fid of the saved checkpoint 232 | step_fid.append((train_state.step, fid)) 233 | 234 | accelerator.wait_for_everyone() 235 | 236 | logging.info(f'Finish fitting, step={train_state.step}') 237 | logging.info(f'step_fid: {step_fid}') 238 | step_best = sorted(step_fid, key=lambda x: x[1])[0][0] 239 | logging.info(f'step_best: {step_best}') 240 | train_state.load(os.path.join(config.ckpt_root, f'{step_best}.ckpt')) 241 | del metrics 242 | 
accelerator.wait_for_everyone() 243 | eval_step(n_samples=config.sample.n_samples, sample_steps=config.sample.sample_steps, algorithm=config.sample.algorithm) 244 | 245 | 246 | 247 | from absl import flags 248 | from absl import app 249 | from ml_collections import config_flags 250 | import sys 251 | from pathlib import Path 252 | 253 | 254 | FLAGS = flags.FLAGS 255 | config_flags.DEFINE_config_file( 256 | "config", None, "Training configuration.", lock_config=False) 257 | flags.mark_flags_as_required(["config"]) 258 | flags.DEFINE_string("workdir", None, "Work unit directory.") 259 | 260 | 261 | def get_config_name(): 262 | argv = sys.argv 263 | for i in range(1, len(argv)): 264 | if argv[i].startswith('--config='): 265 | return Path(argv[i].split('=')[-1]).stem 266 | 267 | 268 | def get_hparams(): 269 | argv = sys.argv 270 | lst = [] 271 | for i in range(1, len(argv)): 272 | assert '=' in argv[i] 273 | if argv[i].startswith('--config.') and not argv[i].startswith('--config.dataset.path'): 274 | hparam, val = argv[i].split('=') 275 | hparam = hparam.split('.')[-1] 276 | if hparam.endswith('path'): 277 | val = Path(val).stem 278 | lst.append(f'{hparam}={val}') 279 | hparams = '-'.join(lst) 280 | if hparams == '': 281 | hparams = 'default' 282 | return hparams 283 | 284 | 285 | def main(argv): 286 | config = FLAGS.config 287 | config.config_name = get_config_name() 288 | config.hparams = get_hparams() 289 | config.workdir = FLAGS.workdir or os.path.join('workdir', config.config_name, config.hparams) 290 | config.ckpt_root = os.path.join(config.workdir, 'ckpts') 291 | config.sample_dir = os.path.join(config.workdir, 'samples') 292 | train(config) 293 | 294 | 295 | if __name__ == "__main__": 296 | app.run(main) 297 | --------------------------------------------------------------------------------
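In the training loop above, `train_step` calls `train_state.ema_update(config.get('ema_rate', 0.9999))` on every iteration, and evaluation later loads the resulting EMA weights. `utils.py` is not part of this listing, so the sketch below only illustrates the conventional EMA rule that call presumably applies; treat it as an assumption, not the repository's exact implementation:

```python
import torch

@torch.no_grad()
def ema_update(nnet, nnet_ema, rate=0.9999):
    # Conventional exponential moving average of weights (assumed to match
    # what train_state.ema_update does; utils.py is not shown here):
    # ema <- rate * ema + (1 - rate) * current
    for p, p_ema in zip(nnet.parameters(), nnet_ema.parameters()):
        p_ema.mul_(rate).add_(p, alpha=1.0 - rate)
```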
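Finally, `get_hparams` above derives the run name, and therefore the default `workdir`, purely from `--config.*` overrides on the command line. A quick illustration that could be run inside `train.py`'s namespace; the flag values are hypothetical:

```python
import sys

# Hypothetical invocation of train.py with two config overrides.
sys.argv = ['train.py',
            '--config=configs/imagenet256_H_DiM.py',
            '--config.train.batch_size=768',
            '--config.optimizer.lr=2e-4']

# Only '--config.*' flags are kept and reduced to their last dotted key,
# so this run is named 'batch_size=768-lr=2e-4' and the default workdir is
# 'workdir/imagenet256_H_DiM/batch_size=768-lr=2e-4'; with no overrides,
# the name falls back to 'default'.
print(get_hparams())  # -> 'batch_size=768-lr=2e-4'
```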