├── README.md
├── assets
│   ├── selection.png
│   ├── teaser1024.png
│   ├── teaser256.png
│   └── teaser512.png
├── benchmark_generation_mamba_simple.py
├── benchmarks
│   └── benchmark_generation_mamba_simple.py
├── configs
│   ├── cifar10_S_DiM.py
│   ├── imagenet256_H_DiM.py
│   ├── imagenet256_L_DiM.py
│   ├── imagenet512_H_DiM_ft.py
│   ├── imagenet512_H_DiM_upsample_3x_test.py
│   └── imagenet512_H_DiM_upsample_test.py
├── csrc
│   └── selective_scan
│       ├── reverse_scan.cuh
│       ├── selective_scan.cpp
│       ├── selective_scan.h
│       ├── selective_scan_bwd_bf16_complex.cu
│       ├── selective_scan_bwd_bf16_real.cu
│       ├── selective_scan_bwd_fp16_complex.cu
│       ├── selective_scan_bwd_fp16_real.cu
│       ├── selective_scan_bwd_fp32_complex.cu
│       ├── selective_scan_bwd_fp32_real.cu
│       ├── selective_scan_bwd_kernel.cuh
│       ├── selective_scan_common.h
│       ├── selective_scan_fwd_bf16.cu
│       ├── selective_scan_fwd_fp16.cu
│       ├── selective_scan_fwd_fp32.cu
│       ├── selective_scan_fwd_kernel.cuh
│       ├── static_switch.h
│       └── uninitialized_copy.cuh
├── dpm_solver_pp.py
├── dpm_solver_pytorch.py
├── environment.yaml
├── eval.py
├── eval_ldm.py
├── eval_ldm_discrete.py
├── eval_t2i_discrete.py
├── evals
│   └── lm_harness_eval.py
├── libs
│   ├── __init__.py
│   ├── autoencoder.py
│   ├── clip.py
│   ├── timm.py
│   ├── uvit.py
│   └── uvit_t2i.py
├── main.pdf
├── main.png
├── mamba_attn_diff
│   ├── models
│   │   ├── __init__.py
│   │   ├── adapter_attn4mamba.py
│   │   ├── attention.py
│   │   ├── freeu.py
│   │   ├── mamba_2d.py
│   │   ├── normalization.py
│   │   ├── upsample_guidance.py
│   │   └── vim_module.py
│   └── utils
│       ├── backup_code.py
│       └── init_weights.py
├── mamba_ssm
│   ├── __init__.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── config_mamba.py
│   │   └── mixer_seq_simple.py
│   ├── modules
│   │   ├── __init__.py
│   │   └── mamba_simple.py
│   ├── ops
│   │   ├── __init__.py
│   │   ├── selective_scan_interface.py
│   │   └── triton
│   │       ├── __init__.py
│   │       ├── layernorm.py
│   │       └── selective_state_update.py
│   └── utils
│       ├── __init__.py
│       ├── generation.py
│       └── hf.py
├── sample_t2i_discrete.py
├── scripts
│   ├── extract_empty_feature.py
│   ├── extract_imagenet_feature.py
│   ├── extract_mscoco_feature.py
│   └── extract_test_prompt_feature.py
├── sde.py
├── setup.py
├── tools
│   ├── __init__.py
│   ├── fid_score.py
│   └── inception.py
├── train.py
├── train_ldm.py
├── train_ldm_discrete.py
├── train_t2i_discrete.py
├── utils.py
└── uvit_datasets.py
/README.md:
--------------------------------------------------------------------------------
1 | # DiM: Diffusion Mamba for Efficient High-Resolution Image Synthesis
2 |
3 | The official implementation of our paper [DiM: Diffusion Mamba for Efficient High-Resolution Image Synthesis](https://arxiv.org/abs/2405.14224).
4 |
5 | ![ImageNet 256x256 samples](assets/teaser256.png)
6 |
7 | ![ImageNet 512x512 samples](assets/teaser512.png)
8 |
9 | ![1024x1024 samples](assets/teaser1024.png)
10 |
11 | ## Method Overview
12 |
13 | ![Method overview](assets/selection.png)
14 |
15 | ## Acknowledgement
16 |
17 | This code is mainly built on [U-ViT](https://github.com/baofff/U-ViT) and [Mamba](https://github.com/state-spaces/mamba).
18 |
19 | Installing Mamba may take considerable effort. If you encounter problems, the [issues in Mamba](https://github.com/state-spaces/mamba/issues) may be very helpful.
20 |
21 | ## Installation
22 |
23 | ```bash
24 | # create env:
25 | conda env create -f environment.yaml
26 |
27 | # if you want to update the existing `mamba-attn` env with the contents of `~/mamba_attn/environment.yaml`:
28 | conda env update --name mamba-attn --file ~/mamba_attn/environment.yaml --prune
29 |
30 | # Switch to the correct environment
31 | conda activate mamba-attn
32 | conda install chardet
33 |
34 | # Compile Mamba. This step may take a long time, please be patient.
35 | # You need to successfully install causal-conv1d first.
36 | CAUSAL_CONV1D_FORCE_BUILD=TRUE pip install --user -e .
37 | # If compilation fails, you can copy the files in './build/' from another server that has compiled successfully; the --user flag may be necessary.
38 |
39 | # Optional: if you have only 8 A100s to train the Huge model with a batch size of 768, I recommend installing DeepSpeed to reduce the required GPU memory:
40 | pip install deepspeed
41 | ```
42 |
43 | **Frequently Asked Questions:**
44 |
45 | - If you encounter errors like `ModuleNotFoundError: No module named 'selective_scan_cuda'`:
46 |
47 | **Answer**: you need to correctly **install and compile** Mamba:
48 |
49 | ```bash
50 | pip install causal-conv1d==1.2.0.post2 # The version may differ depending on your CUDA version
51 | CAUSAL_CONV1D_FORCE_BUILD=TRUE pip install --user -e .
52 | ```
53 |
54 | - Failed compilation:
55 |
56 |   - Error: `The detected CUDA version mismatches the version that was used to compile PyTorch. Please make sure to use the same CUDA versions.`
57 |
58 |     **Answer**: you need to reinstall PyTorch built for your CUDA version:
59 |
60 | ```bash
61 | # For example, on CUDA 11.8:
62 | conda install pytorch==2.1.1 torchvision==0.16.1 torchaudio==2.1.1 pytorch-cuda=11.8 -c pytorch -c nvidia
63 | # Then, compile the Mamba in this project again:
64 | CAUSAL_CONV1D_FORCE_BUILD=TRUE pip install --user -e .
65 | ```
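
After a successful build, a quick way to confirm the CUDA extension is usable is to import it and run a tiny forward pass. A minimal sketch, assuming a CUDA GPU is available; `d_model=64` and the tensor shapes are illustrative only, not values used by this repo:

```python
# Post-install sanity check (a sketch, not part of the repo).
import torch
import selective_scan_cuda  # an ImportError here means the extension was not compiled
from mamba_ssm.modules.mamba_simple import Mamba

block = Mamba(d_model=64).to("cuda").half()
x = torch.randn(2, 16, 64, device="cuda", dtype=torch.float16)  # (batch, seqlen, d_model)
print(block(x).shape)  # expected: torch.Size([2, 16, 64])
```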
66 |
67 | ## Preparation Before Training and Evaluation
68 |
69 | Please follow the section of the same name in [U-ViT](https://github.com/baofff/U-ViT).
70 |
71 | ## Checkpoints
72 |
73 | | Model | FID | training iterations | batch size |
74 | | :----------------------------------------------------------: | :------: | :-----------------: | :--------: |
75 | | [ImageNet 256x256 (Huge/2)](https://drive.google.com/drive/folders/1TTEXKKhnJcEV9jeZbZYlXjiPyV87ZhE0?usp=sharing) | 2.40 | 425K | 768 |
76 | | [ImageNet 256x256 (Huge/2)](https://drive.google.com/drive/folders/1ETllUm8Dpd8-vDHefQEXEWF9whdbyhL5?usp=sharing) | **2.21** | 625K | 768 |
77 | | [ImageNet 512x512 (fine-tuned Huge/2)](https://drive.google.com/drive/folders/1lupf4_dj4tWCpycnraGrgqh4P-6yK5Xe?usp=sharing) | 3.94 | 64K (fine-tune) | 240 |
78 |
79 | **About the checkpoint files:**
80 |
81 | - **We use `nnet_ema.pth` for evaluation instead of `nnet.pth`.**
82 |
83 | - **`nnet.pth` is the trained model, while `nnet_ema.pth` is the EMA of the model weights (a sketch of the update rule is below).**
84 |
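An EMA copy of the weights is a slowly-moving average of the trained weights, which typically yields better samples than the raw weights. A minimal sketch of the update rule; the decay value 0.9999 is an illustrative assumption, not the value used by this repo's trainer:

```python
# Sketch of an EMA weight update (decay is an assumed value).
import torch

@torch.no_grad()
def ema_update(nnet, nnet_ema, decay=0.9999):
    for p, p_ema in zip(nnet.parameters(), nnet_ema.parameters()):
        # p_ema <- decay * p_ema + (1 - decay) * p
        p_ema.mul_(decay).add_(p, alpha=1.0 - decay)
```
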
85 | ## Evaluation
86 |
87 | **Use `eval_ldm_discrete.py` for evaluation and generating images with CFG**
88 |
89 | ```sh
90 | # ImageNet 256x256 Huge, 425K
91 | # If your model checkpoint path is not 'workdir/imagenet256_H_DiM/default/ckpts/425000.ckpt/nnet_ema.pth', you can change the path after '--nnet_path='
92 | accelerate launch --multi_gpu --gpu_ids 0,1,2,3,4,5,6,7 --main_process_port 20039 --num_processes 8 --mixed_precision bf16 ./eval_ldm_discrete.py --config=configs/imagenet256_H_DiM.py --nnet_path='workdir/imagenet256_H_DiM/default/ckpts/425000.ckpt/nnet_ema.pth'
93 |
94 | # ImageNet 512x512 Huge
95 | # The 512x512 images generated for evaluation take ~22 GB of disk space.
96 | # So I recommend setting `config.sample.path` in the `imagenet512_H_DiM_ft` config if space for temporary files is tight.
97 | accelerate launch --multi_gpu --gpu_ids 0,1,2,3,4,5,6,7 --main_process_port 20039 --num_processes 8 --mixed_precision bf16 ./eval_ldm_discrete.py --config=configs/imagenet512_H_DiM_ft.py --nnet_path='workdir/imagenet512_H_DiM_ft/default/ckpts/64000.ckpt/nnet_ema.pth'
98 |
99 | # ImageNet 512x512 Huge, upsample 2x; the generated images go to `workdir/imagenet512_H_DiM_ft/test_tmp`, which is set in the config.
100 | accelerate launch --multi_gpu --gpu_ids 0,1,2,3,4,5,6,7 --main_process_port 20039 --num_processes 8 --mixed_precision bf16 ./eval_ldm_discrete.py --config=configs/imagenet512_H_DiM_upsample_test.py --nnet_path='workdir/imagenet512_H_DiM_ft/default/ckpts/64000.ckpt/nnet_ema.pth'
101 |
102 | # ImageNet 512x512 Huge, upsample 3x; the generated images go to `workdir/imagenet512_H_DiM_ft/test_tmp`, which is set in the config.
103 | accelerate launch --multi_gpu --gpu_ids 0,1,2,3,4,5,6,7 --main_process_port 20039 --num_processes 8 --mixed_precision bf16 ./eval_ldm_discrete.py --config=configs/imagenet512_H_DiM_upsample_3x_test.py --nnet_path='workdir/imagenet512_H_DiM_ft/default/ckpts/64000.ckpt/nnet_ema.pth'
104 | ```
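
The FID numbers in the table above can also be computed offline from a directory of generated images using this repo's `tools/fid_score.py` (which `eval.py` imports). A hedged sketch: the reference-statistics path below is a placeholder, and the exact argument list should be checked against `tools/fid_score.py`:

```python
# Sketch of standalone FID computation (paths are placeholders).
from tools.fid_score import calculate_fid_given_paths

fid = calculate_fid_given_paths((
    'assets/fid_stats/reference_stats.npz',  # placeholder reference-statistics path
    'workdir/imagenet256_H_DiM/test_tmp',    # directory of generated images
))
print(fid)
```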
105 |
106 | ## Training
107 |
108 | ```sh
109 | # Cifar 32x32 Small
110 | accelerate launch --multi_gpu --num_processes 8 --mixed_precision fp16 ./train.py --config=configs/cifar10_S_DiM.py
111 |
112 | # ImageNet 256x256 Large
113 | accelerate launch --multi_gpu --num_processes 8 --mixed_precision bf16 ./train_ldm_discrete.py --config=configs/imagenet256_L_DiM.py
114 |
115 | # ImageNet 256x256 Huge (DeepSpeed ZeRO-2 for memory-efficient training)
116 | accelerate launch --multi_gpu --num_processes 8 --mixed_precision bf16 ./train_ldm_discrete.py --config=configs/imagenet256_H_DiM.py
117 |
118 | # ImageNet 512x512 Huge (DeepSpeed ZeRO-2 for memory-efficient training)
119 | # Fine-tuning: carefully check that the pre-trained weights are at
120 | # `workdir/imagenet256_H_DiM/default/ckpts/425000.ckpt/nnet_ema.pth`.
121 | # This location is set in the config file: `config.nnet.pretrained_path`.
122 | # If this checkpoint is missing, no pre-trained weights will be loaded.
123 | accelerate launch --multi_gpu --num_processes 8 --mixed_precision bf16 ./train_ldm_discrete.py --config=configs/imagenet512_H_DiM_ft.py
124 | ```
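
All hyperparameters above come from the `ml_collections` config files, so a config can be inspected (or tweaked programmatically) before launching a run. A small sketch using fields defined in this repo's configs:

```python
# Inspect a training config before launching (a sketch).
from configs.imagenet256_H_DiM import get_config

config = get_config()
print(config.nnet.name)         # Mamba_DiT_H_2
print(config.train.batch_size)  # 768
print(config.sample.algorithm)  # dpm_solver
```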
125 |
126 | ## Citation
127 |
128 | ```bibtex
129 | @misc{teng2024dim,
130 | title={DiM: Diffusion Mamba for Efficient High-Resolution Image Synthesis},
131 | author={Yao Teng and Yue Wu and Han Shi and Xuefei Ning and Guohao Dai and Yu Wang and Zhenguo Li and Xihui Liu},
132 | year={2024},
133 | eprint={2405.14224},
134 | archivePrefix={arXiv},
135 | primaryClass={cs.CV}
136 | }
137 | ```
138 |
139 |
--------------------------------------------------------------------------------
/assets/selection.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tyshiwo1/DiM-DiffusionMamba/292b2285de2f979cf5ac84ab90cbee20cdf1f2f3/assets/selection.png
--------------------------------------------------------------------------------
/assets/teaser1024.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tyshiwo1/DiM-DiffusionMamba/292b2285de2f979cf5ac84ab90cbee20cdf1f2f3/assets/teaser1024.png
--------------------------------------------------------------------------------
/assets/teaser256.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tyshiwo1/DiM-DiffusionMamba/292b2285de2f979cf5ac84ab90cbee20cdf1f2f3/assets/teaser256.png
--------------------------------------------------------------------------------
/assets/teaser512.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tyshiwo1/DiM-DiffusionMamba/292b2285de2f979cf5ac84ab90cbee20cdf1f2f3/assets/teaser512.png
--------------------------------------------------------------------------------
/benchmark_generation_mamba_simple.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2023, Tri Dao, Albert Gu.
2 |
3 | import argparse
4 | import time
5 | import json
6 |
7 | import torch
8 | import torch.nn.functional as F
9 |
10 | from einops import rearrange
11 |
12 | from transformers import AutoTokenizer, AutoModelForCausalLM
13 |
14 | from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
15 |
16 |
17 | parser = argparse.ArgumentParser(description="Generation benchmarking")
18 | parser.add_argument("--model-name", type=str, default="state-spaces/mamba-130m")
19 | parser.add_argument("--prompt", type=str, default=None)
20 | parser.add_argument("--promptlen", type=int, default=100)
21 | parser.add_argument("--genlen", type=int, default=100)
22 | parser.add_argument("--temperature", type=float, default=1.0)
23 | parser.add_argument("--topk", type=int, default=1)
24 | parser.add_argument("--topp", type=float, default=1.0)
25 | parser.add_argument("--repetition-penalty", type=float, default=1.0)
26 | parser.add_argument("--batch", type=int, default=1)
27 | parser.add_argument("--cache_dir", type=str, default="./checkpoints")
28 | args = parser.parse_args()
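# Example invocation (an illustrative command using only the flags defined above):
#   python benchmark_generation_mamba_simple.py \
#       --model-name state-spaces/mamba-130m --promptlen 512 --genlen 128 --batch 8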
29 |
30 | repeats = 3
31 | device = "cuda"
32 | dtype = torch.float16
33 |
34 | cache_dir = args.cache_dir
35 |
36 | print(f"Loading model {args.model_name}")
37 | is_mamba = args.model_name.startswith("state-spaces/mamba-")
38 | if is_mamba:
39 | tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b", cache_dir=cache_dir)
40 | model = MambaLMHeadModel.from_pretrained(args.model_name, device=device, dtype=dtype, cache_dir=cache_dir)
41 | else:
42 | tokenizer = AutoTokenizer.from_pretrained(args.model_name, cache_dir=cache_dir)
43 | model = AutoModelForCausalLM.from_pretrained(args.model_name, device_map={"": device}, torch_dtype=dtype, cache_dir=cache_dir)
44 | model.eval()
45 | print(f"Number of parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")
46 |
47 | torch.random.manual_seed(0)
48 | if args.prompt is None:
49 | input_ids = torch.randint(1, 1000, (args.batch, args.promptlen), dtype=torch.long, device="cuda")
50 | attn_mask = torch.ones_like(input_ids, dtype=torch.long, device="cuda")
51 | else:
52 | tokens = tokenizer(args.prompt, return_tensors="pt")
53 | input_ids = tokens.input_ids.to(device=device)
54 | attn_mask = tokens.attention_mask.to(device=device)
55 | max_length = input_ids.shape[1] + args.genlen
56 |
57 | # for i, n in model.named_parameters():
58 | #     print(i, n.shape, n.dtype)
59 | # print(model)  # debugging leftovers, disabled so the benchmark below actually runs
60 |
61 |
62 | if is_mamba:
63 | fn = lambda: model.generate(
64 | input_ids=input_ids,
65 | max_length=max_length,
66 | cg=True,
67 | return_dict_in_generate=True,
68 | output_scores=True,
69 | enable_timing=False,
70 | temperature=args.temperature,
71 | top_k=args.topk,
72 | top_p=args.topp,
73 | repetition_penalty=args.repetition_penalty,
74 | )
75 | else:
76 | fn = lambda: model.generate(
77 | input_ids=input_ids,
78 | attention_mask=attn_mask,
79 | max_length=max_length,
80 | return_dict_in_generate=True,
81 | pad_token_id=tokenizer.eos_token_id,
82 | do_sample=True,
83 | temperature=args.temperature,
84 | top_k=args.topk,
85 | top_p=args.topp,
86 | repetition_penalty=args.repetition_penalty,
87 | )
88 | out = fn()
89 | if args.prompt is not None:
90 | print(tokenizer.batch_decode(out.sequences.tolist()))
91 |
92 | torch.cuda.synchronize()
93 | start = time.time()
94 | for _ in range(repeats):
95 | fn()
96 | torch.cuda.synchronize()
97 | print(f"Prompt length: {len(input_ids[0])}, generation length: {len(out.sequences[0]) - len(input_ids[0])}")
98 | print(f"{args.model_name} prompt processing + decoding time: {(time.time() - start) / repeats * 1000:.0f}ms")
99 |
100 | '''
101 | # state-spaces/mamba-130m architecture:
102 | MambaLMHeadModel(
103 | (backbone): MixerModel(
104 | (embedding): Embedding(50280, 768)
105 | (layers): ModuleList(
106 | (0-23): 24 x Block(
107 | (mixer): Mamba(
108 | (in_proj): Linear(in_features=768, out_features=3072, bias=False)
109 | (conv1d): Conv1d(1536, 1536, kernel_size=(4,), stride=(1,), padding=(3,), groups=1536)
110 | (act): SiLU()
111 | (x_proj): Linear(in_features=1536, out_features=80, bias=False)
112 | (dt_proj): Linear(in_features=48, out_features=1536, bias=True)
113 | (out_proj): Linear(in_features=1536, out_features=768, bias=False)
114 | )
115 | (norm): RMSNorm()
116 | )
117 | )
118 | (norm_f): RMSNorm()
119 | )
120 | (lm_head): Linear(in_features=768, out_features=50280, bias=False)
121 | )
122 |
123 | '''
--------------------------------------------------------------------------------
/benchmarks/benchmark_generation_mamba_simple.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2023, Tri Dao, Albert Gu.
2 |
3 | import argparse
4 | import time
5 | import json
6 |
7 | import torch
8 | import torch.nn.functional as F
9 |
10 | from einops import rearrange
11 |
12 | from transformers import AutoTokenizer, AutoModelForCausalLM
13 |
14 | from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
15 |
16 |
17 | parser = argparse.ArgumentParser(description="Generation benchmarking")
18 | parser.add_argument("--model-name", type=str, default="state-spaces/mamba-130m")
19 | parser.add_argument("--prompt", type=str, default=None)
20 | parser.add_argument("--promptlen", type=int, default=100)
21 | parser.add_argument("--genlen", type=int, default=100)
22 | parser.add_argument("--temperature", type=float, default=1.0)
23 | parser.add_argument("--topk", type=int, default=1)
24 | parser.add_argument("--topp", type=float, default=1.0)
25 | parser.add_argument("--minp", type=float, default=0.0)
26 | parser.add_argument("--repetition-penalty", type=float, default=1.0)
27 | parser.add_argument("--batch", type=int, default=1)
28 | args = parser.parse_args()
29 |
30 | repeats = 3
31 | device = "cuda"
32 | dtype = torch.float16
33 |
34 | print(f"Loading model {args.model_name}")
35 | is_mamba = args.model_name.startswith("state-spaces/mamba-")
36 | if is_mamba:
37 | tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
38 | model = MambaLMHeadModel.from_pretrained(args.model_name, device=device, dtype=dtype)
39 | else:
40 | tokenizer = AutoTokenizer.from_pretrained(args.model_name)
41 | model = AutoModelForCausalLM.from_pretrained(args.model_name, device_map={"": device}, torch_dtype=dtype)
42 | model.eval()
43 | print(f"Number of parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")
44 |
45 | torch.random.manual_seed(0)
46 | if args.prompt is None:
47 | input_ids = torch.randint(1, 1000, (args.batch, args.promptlen), dtype=torch.long, device="cuda")
48 | attn_mask = torch.ones_like(input_ids, dtype=torch.long, device="cuda")
49 | else:
50 | tokens = tokenizer(args.prompt, return_tensors="pt")
51 | input_ids = tokens.input_ids.to(device=device)
52 | attn_mask = tokens.attention_mask.to(device=device)
53 | max_length = input_ids.shape[1] + args.genlen
54 |
55 | if is_mamba:
56 | fn = lambda: model.generate(
57 | input_ids=input_ids,
58 | max_length=max_length,
59 | cg=True,
60 | return_dict_in_generate=True,
61 | output_scores=True,
62 | enable_timing=False,
63 | temperature=args.temperature,
64 | top_k=args.topk,
65 | top_p=args.topp,
66 | min_p=args.minp,
67 | repetition_penalty=args.repetition_penalty,
68 | )
69 | else:
70 | fn = lambda: model.generate(
71 | input_ids=input_ids,
72 | attention_mask=attn_mask,
73 | max_length=max_length,
74 | return_dict_in_generate=True,
75 | pad_token_id=tokenizer.eos_token_id,
76 | do_sample=True,
77 | temperature=args.temperature,
78 | top_k=args.topk,
79 | top_p=args.topp,
80 | repetition_penalty=args.repetition_penalty,
81 | )
82 | out = fn()
83 | if args.prompt is not None:
84 | print(tokenizer.batch_decode(out.sequences.tolist()))
85 |
86 | torch.cuda.synchronize()
87 | start = time.time()
88 | for _ in range(repeats):
89 | fn()
90 | torch.cuda.synchronize()
91 | print(f"Prompt length: {len(input_ids[0])}, generation length: {len(out.sequences[0]) - len(input_ids[0])}")
92 | print(f"{args.model_name} prompt processing + decoding time: {(time.time() - start) / repeats * 1000:.0f}ms")
93 |
--------------------------------------------------------------------------------
/configs/cifar10_S_DiM.py:
--------------------------------------------------------------------------------
1 | import ml_collections
2 |
3 |
4 | def d(**kwargs):
5 | """Helper of creating a config dict."""
6 | return ml_collections.ConfigDict(initial_dictionary=kwargs)
7 |
8 |
9 | def get_config():
10 | config = ml_collections.ConfigDict()
11 |
12 | config.seed = 1234
13 | config.pred = 'noise_pred'
14 |
15 | n_steps = 500000
16 | config.train = d(
17 | n_steps=n_steps,
18 | batch_size=128,
19 | mode='uncond',
20 | log_interval=10,
21 | eval_interval=5000,
22 | save_interval=50000,
23 | )
24 |
25 | lr = 0.0002
26 | config.optimizer = d(
27 | name='adamw',
28 | lr=lr,
29 | weight_decay=0.03,
30 | betas=(0.99, 0.999),
31 | )
32 |
33 | config.lr_scheduler = d(
34 | name='customized',
35 | warmup_steps=2500,
36 | )
37 |
38 | learned_sigma = False
39 | latent_size = 32
40 | in_channels = 3
41 | config.nnet = d(
42 | name='Mamba_DiT_S_2',
43 | attention_head_dim=512//1, num_attention_heads=1, num_layers=25,
44 | in_channels=in_channels,
45 | num_embeds_ada_norm=10,
46 | sample_size=latent_size,
47 | activation_fn="gelu-approximate",
48 | attention_bias=True,
49 | norm_elementwise_affine=False,
50 | norm_type="ada_norm_single", #"layer_norm"
51 | out_channels=in_channels*2 if learned_sigma else in_channels,
52 | patch_size=2,
53 | mamba_d_state=16,
54 | mamba_d_conv=3,
55 | mamba_expand=2,
56 | use_bidirectional_rnn=False,
57 | mamba_type='enc',
58 | nested_order=0,
59 | is_uconnect=True,
60 | no_ff=True,
61 | use_conv1d=True,
62 | is_extra_tokens=True,
63 | rms=True,
64 | use_pad_token=True,
65 | use_a4m_adapter=True,
66 | drop_path_rate=0.0,
67 | encoder_start_blk_id=1, #0
68 | kv_as_one_token_idx=-1,
69 | num_2d_enc_dec_layers=6,
70 | pad_token_schedules=['dec_split', 'lateral'],
71 | is_absorb=False,
72 | use_adapter_modules=True,
73 | sequence_schedule='dilated',
74 | sub_sequence_schedule=['reverse_single', 'layerwise_cross'],
75 | pos_encoding_type='learnable',
76 | scan_pattern_len=3,
77 | is_align_exchange_q_kv=False,
78 | )
79 |
80 | config.dataset = d(
81 | name='cifar10',
82 | path='assets/datasets/cifar10',
83 | random_flip=True,
84 | )
85 |
86 | config.sample = d(
87 | sample_steps=1000,
88 | n_samples=50000,
89 | mini_batch_size=500,
90 | algorithm='euler_maruyama_sde',
91 | path=''
92 | )
93 |
94 | return config
95 |
--------------------------------------------------------------------------------
/configs/imagenet256_H_DiM.py:
--------------------------------------------------------------------------------
1 | import ml_collections
2 |
3 |
4 | def d(**kwargs):
5 | """Helper of creating a config dict."""
6 | return ml_collections.ConfigDict(initial_dictionary=kwargs)
7 |
8 |
9 | def get_config():
10 | config = ml_collections.ConfigDict()
11 |
12 | config.seed = 1234
13 | config.pred = 'noise_pred'
14 | config.z_shape = (4, 32, 32)
15 |
16 | config.autoencoder = d(
17 | pretrained_path='assets/stable-diffusion/autoencoder_kl_ema.pth'
18 | )
19 |
20 | # config.gradient_accumulation_steps=2 # 1
21 | config.max_grad_norm = 1.0
22 |
23 | config.train = d(
24 | n_steps=750000, # 300000
25 | batch_size=768,
26 | mode='cond',
27 | log_interval=10,
28 | eval_interval=5000,
29 | save_interval=25000, # 50000
30 | )
31 |
32 | config.optimizer = d(
33 | name='adamw',
34 | lr=0.0002,
35 | weight_decay=0.03,
36 | betas=(0.99, 0.99),
37 | eps=1e-15,
38 | )
39 |
40 | config.lr_scheduler = d(
41 | name='customized',
42 | warmup_steps=5000,
43 | )
44 |
45 | learned_sigma = False
46 | latent_size = 32
47 | in_channels = 4 # 3
48 | config.nnet = d(
49 | name='Mamba_DiT_H_2',
50 | attention_head_dim=1536//1, num_attention_heads=1, num_layers=49,
51 | in_channels=in_channels,
52 | num_embeds_ada_norm=1000,
53 | sample_size=latent_size,
54 | activation_fn="gelu-approximate", #"gelu-approximate",
55 | attention_bias=True,
56 | norm_elementwise_affine=False,
57 | norm_type="ada_norm_single", #"layer_norm",
58 | out_channels=in_channels*2 if learned_sigma else in_channels,
59 | patch_size=2,
60 | mamba_d_state=16,
61 | mamba_d_conv=3,
62 | mamba_expand=2,
63 | use_bidirectional_rnn=False,
64 | mamba_type='enc',
65 | nested_order=0,
66 | is_uconnect=True,
67 | no_ff=True,
68 | use_conv1d=True,
69 | is_extra_tokens=True,
70 | rms=True,
71 | use_pad_token=True,
72 | use_a4m_adapter=True,
73 | drop_path_rate=0.0,
74 | encoder_start_blk_id=1,
75 | kv_as_one_token_idx=-1,
76 | num_2d_enc_dec_layers=6,
77 | pad_token_schedules=['dec_split', 'lateral'],
78 | is_absorb=False,
79 | use_adapter_modules=True,
80 | sequence_schedule='dilated',
81 | sub_sequence_schedule=['reverse_single', 'layerwise_cross'],
82 | pos_encoding_type='learnable',
83 |         scan_pattern_len=4 - 1,
84 | is_align_exchange_q_kv=False,
85 | is_random_patterns=False,
86 | )
87 | config.gradient_checkpointing = False
88 |
89 | config.dataset = d(
90 | name='imagenet',
91 | path='assets/datasets/ImageNet',
92 | resolution=256,
93 | cfg=True,
94 | p_uncond=0.1,
95 | )
96 |
97 | config.sample = d(
98 | sample_steps=50,
99 | n_samples=50000,
100 | mini_batch_size=25, # the decoder is large
101 | algorithm='dpm_solver',
102 | cfg=True,
103 | scale=0.4,
104 | path=''
105 | )
106 |
107 | return config
108 |
--------------------------------------------------------------------------------
/configs/imagenet256_L_DiM.py:
--------------------------------------------------------------------------------
1 | import ml_collections
2 |
3 |
4 | def d(**kwargs):
5 | """Helper of creating a config dict."""
6 | return ml_collections.ConfigDict(initial_dictionary=kwargs)
7 |
8 |
9 | def get_config():
10 | config = ml_collections.ConfigDict()
11 |
12 | config.seed = 1234
13 | config.pred = 'noise_pred'
14 | config.z_shape = (4, 32, 32)
15 |
16 | config.autoencoder = d(
17 | pretrained_path='assets/stable-diffusion/autoencoder_kl_ema.pth'
18 | )
19 |
20 | # config.gradient_accumulation_steps=2 # 1
21 | config.max_grad_norm = 1.0
22 |
23 | config.train = d(
24 | n_steps=300000, # 300000
25 | batch_size=1024,
26 | mode='cond',
27 | log_interval=10,
28 | eval_interval=5000,
29 | save_interval=25000, # 50000
30 | )
31 |
32 | config.optimizer = d(
33 | name='adamw',
34 | lr=0.0002,
35 | weight_decay=0.03,
36 | betas=(0.99, 0.99),
37 | eps=1e-15,
38 | )
39 |
40 | config.lr_scheduler = d(
41 | name='customized',
42 | warmup_steps=5000,
43 | )
44 |
45 | learned_sigma = False
46 | latent_size = 32
47 | in_channels = 4 # 3
48 | config.nnet = d(
49 | name='Mamba_DiT_H_2',
50 | attention_head_dim=1024//16, num_attention_heads=16, num_layers=49,
51 | in_channels=in_channels,
52 | num_embeds_ada_norm=1000,
53 | sample_size=latent_size,
54 | activation_fn="gelu-approximate", #"gelu-approximate",
55 | attention_bias=True,
56 | norm_elementwise_affine=False,
57 | norm_type="ada_norm_single", #"layer_norm",
58 | out_channels=in_channels*2 if learned_sigma else in_channels,
59 | patch_size=2,
60 | mamba_d_state=16,
61 | mamba_d_conv=3,
62 | mamba_expand=2,
63 | use_bidirectional_rnn=False,
64 | mamba_type='enc',
65 | nested_order=0,
66 | is_uconnect=True,
67 | no_ff=True,
68 | use_conv1d=True,
69 | is_extra_tokens=True,
70 | rms=True,
71 | use_pad_token=True,
72 | use_a4m_adapter=True,
73 | drop_path_rate=0.0,
74 | encoder_start_blk_id=1,
75 | kv_as_one_token_idx=-1,
76 | num_2d_enc_dec_layers=6,
77 | pad_token_schedules=['dec_split', 'lateral'],
78 | is_absorb=False,
79 | use_adapter_modules=True,
80 | sequence_schedule='dilated',
81 | sub_sequence_schedule=['reverse_single', 'layerwise_cross'],
82 | pos_encoding_type='learnable',
83 |         scan_pattern_len=4 - 1,
84 | is_align_exchange_q_kv=False,
85 | is_random_patterns=False,
86 | )
87 | config.gradient_checkpointing = False
88 |
89 | config.dataset = d(
90 | name='imagenet',
91 | path='assets/datasets/ImageNet',
92 | resolution=256,
93 | cfg=True,
94 | p_uncond=0.15, # aligned with u-vit
95 | )
96 |
97 | config.sample = d(
98 | sample_steps=50,
99 | n_samples=50000,
100 | mini_batch_size=25, # the decoder is large
101 | algorithm='dpm_solver',
102 | cfg=True,
103 | scale=0.4,
104 | path=''
105 | )
106 |
107 | return config
108 |
--------------------------------------------------------------------------------
/configs/imagenet512_H_DiM_ft.py:
--------------------------------------------------------------------------------
1 | import ml_collections
2 |
3 |
4 | def d(**kwargs):
5 | """Helper of creating a config dict."""
6 | return ml_collections.ConfigDict(initial_dictionary=kwargs)
7 |
8 |
9 | def get_config():
10 | config = ml_collections.ConfigDict()
11 |
12 | config.seed = 1234
13 | config.pred = 'noise_pred'
14 | config.z_shape = (4, 64, 64)
15 |
16 | config.autoencoder = d(
17 | pretrained_path='assets/stable-diffusion/autoencoder_kl_ema.pth'
18 | )
19 |
20 | config.gradient_accumulation_steps=2 # 1
21 | config.max_grad_norm = 1.0
22 |
23 | config.train = d(
24 | n_steps=64000, #300000,
25 | batch_size=240, #1024,
26 | mode='cond',
27 | log_interval=10,
28 | eval_interval=1000, # 5000,
29 | save_interval=10000, # 50000
30 | )
31 |
32 | config.optimizer = d(
33 | name='adamw',
34 | lr=0.0001 , #0.0002,
35 | weight_decay=0, #0.03,
36 | betas=(0.99, 0.99),
37 | eps=1e-15,
38 | )
39 |
40 | config.lr_scheduler = d(
41 | name='customized',
42 | warmup_steps=2500, # 1, #5000,
43 | )
44 |
45 | learned_sigma = False
46 | latent_size = 64 #32
47 | in_channels = 4 # 3
48 | config.nnet = d(
49 | name='Mamba_DiT_H_2',
50 | attention_head_dim=1536//1, num_attention_heads=1, num_layers=49,
51 | in_channels=in_channels,
52 | num_embeds_ada_norm=1000,
53 | sample_size=latent_size,
54 | activation_fn="gelu-approximate",
55 | attention_bias=True,
56 | norm_elementwise_affine=False,
57 | norm_type="ada_norm_single", #"layer_norm",
58 | out_channels=in_channels*2 if learned_sigma else in_channels,
59 | patch_size=2,
60 | mamba_d_state=16,
61 | mamba_d_conv=3,
62 | mamba_expand=2,
63 | use_bidirectional_rnn=False,
64 | mamba_type='enc',
65 | nested_order=0,
66 | is_uconnect=True,
67 | no_ff=True,
68 | use_conv1d=True,
69 | is_extra_tokens=True,
70 | rms=True,
71 | use_pad_token=True,
72 | use_a4m_adapter=True,
73 | drop_path_rate=0.0,
74 | encoder_start_blk_id=1,
75 | kv_as_one_token_idx=-1,
76 | num_2d_enc_dec_layers=6,
77 | pad_token_schedules=['dec_split', 'rho_pad'], #['dec_split', 'lateral'],
78 | is_absorb=False,
79 | use_adapter_modules=True,
80 | sequence_schedule='dilated',
81 | sub_sequence_schedule=['reverse_single', 'layerwise_cross'],
82 | pos_encoding_type='learnable',
83 |         scan_pattern_len=4 - 1,
84 | is_align_exchange_q_kv=False,
85 | is_random_patterns=False,
86 | pretrained_path = 'workdir/imagenet256_H_DiM/default/ckpts/425000.ckpt/nnet_ema.pth',
87 | multi_times=2,
88 | pattern_type='base',
89 | )
90 | config.gradient_checkpointing = False
91 |
92 | config.dataset = d(
93 | name='imagenet512_features',
94 | path='assets/datasets/imagenet512_features',
95 | cfg=True,
96 | p_uncond=0.15
97 | )
98 |
99 | config.sample = d(
100 | sample_steps=50,
101 | n_samples=50000,
102 | mini_batch_size=25, # the decoder is large
103 | algorithm='dpm_solver',
104 | cfg=True,
105 | scale=0.7,
106 | path=''
107 | )
108 |
109 | return config
110 |
--------------------------------------------------------------------------------
/configs/imagenet512_H_DiM_upsample_3x_test.py:
--------------------------------------------------------------------------------
1 | import ml_collections
2 | import os
3 |
4 | def d(**kwargs):
5 | """Helper of creating a config dict."""
6 | return ml_collections.ConfigDict(initial_dictionary=kwargs)
7 |
8 |
9 | def get_config():
10 | config = ml_collections.ConfigDict()
11 |
12 | config.seed = 1234
13 | config.pred = 'noise_pred'
14 | # config.z_shape = (4, 64, 64)
15 |
16 | config.autoencoder = d(
17 | pretrained_path='assets/stable-diffusion/autoencoder_kl_ema.pth'
18 | )
19 |
20 | config.gradient_accumulation_steps=2 # 1
21 | config.max_grad_norm = 1.0
22 |
23 | config.train = d(
24 | n_steps=300000,
25 | batch_size=128,
26 | mode='cond',
27 | log_interval=10,
28 | eval_interval=1000,
29 | save_interval=5000,
30 | )
31 |
32 | config.optimizer = d(
33 | name='adamw',
34 | lr=0.0001,
35 | weight_decay=0.03,
36 | betas=(0.99, 0.99),
37 | eps=1e-15,
38 | )
39 |
40 | config.lr_scheduler = d(
41 | name='customized',
42 | warmup_steps=1,
43 | )
44 |
45 | config.ug_theta = 1
46 | config.ug_eta = 0.3
47 | config.ug_T = 1
48 |
49 | base_resolution = 512
50 | patch_size = 2
51 | multi_times = 3 #2
52 | resolution = int(base_resolution * multi_times)
53 | coco_multi_scale = [ resolution, resolution //2 , base_resolution, ]
54 |
55 | learned_sigma = False
56 | latent_size = resolution // 8
57 | config.z_shape = (4, latent_size, latent_size)
58 | in_channels = 4
59 | config.nnet = d(
60 | name='Mamba_DiT_H_2',
61 | attention_head_dim=1536//1, num_attention_heads=1, num_layers=49,
62 | in_channels=in_channels,
63 | num_embeds_ada_norm=1000,
64 | sample_size=latent_size,
65 | activation_fn="gelu-approximate",
66 | attention_bias=True,
67 | norm_elementwise_affine=False,
68 | norm_type="ada_norm_single", #"layer_norm",
69 | out_channels=in_channels*2 if learned_sigma else in_channels,
70 | patch_size=patch_size,
71 | mamba_d_state=16,
72 | mamba_d_conv=3,
73 | mamba_expand=2,
74 | use_bidirectional_rnn=False,
75 | mamba_type='enc',
76 | nested_order=0,
77 | is_uconnect=True,
78 | no_ff=True,
79 | use_conv1d=True,
80 | is_extra_tokens=True,
81 | rms=True,
82 | use_pad_token=True,
83 | use_a4m_adapter=True,
84 | drop_path_rate=0.0,
85 | encoder_start_blk_id=1,
86 | kv_as_one_token_idx=-1,
87 | num_2d_enc_dec_layers=6,
88 | pad_token_schedules=['dec_split', 'rho_pad'],
89 | is_absorb=False,
90 | use_adapter_modules=True,
91 | sequence_schedule='dilated',
92 | sub_sequence_schedule=['reverse_single', 'layerwise_cross'],
93 | pos_encoding_type='learnable',
94 |         scan_pattern_len=4 - 1,
95 | is_align_exchange_q_kv=False,
96 | is_random_patterns=False,
97 | pretrained_path = 'workdir/imagenet512_H_DiM_ft/default/ckpts/64000.ckpt/nnet_ema.pth',
98 | multi_times=multi_times,
99 | pattern_type='base',
100 | is_freeu=False, freeu_param=(0.15, 0.1, 1.1, 1.2), # (0.3, 0.2, 1.1, 1.2)
101 |         num_patches = [ (i //8 //patch_size)**2 for i in coco_multi_scale],  # latent tokens per scale: ((res // 8) // patch_size) ** 2
102 | is_skip_tune=True, skip_tune_param = (0.82, 1.0), #(0.82, 1.0),
103 | )
104 | config.gradient_checkpointing = False
105 |
106 | config.dataset = d(
107 | name='imagenet512_features',
108 | path='assets/datasets/imagenet512_features',
109 | cfg=True,
110 | p_uncond=0.15
111 | )
112 | vis_label = None
113 | config.sample = d(
114 | sample_steps=50,
115 | n_samples=50000,
116 | mini_batch_size=1, # the decoder is large
117 | algorithm='dpm_solver_upsample_g',
118 | cfg=True,
119 | scale=3.0,
120 | path='workdir/imagenet512_H_DiM_ft/test_tmp/' + (
121 | str(vis_label) if vis_label is not None else 'all_class'
122 | ) + '_' + str( resolution ),
123 | vis_label=vis_label,
124 | )
125 |
126 | if not os.path.exists(config.sample.path):
127 | os.makedirs(config.sample.path)
128 |
129 | return config
130 |
--------------------------------------------------------------------------------
/configs/imagenet512_H_DiM_upsample_test.py:
--------------------------------------------------------------------------------
1 | import ml_collections
2 | import os
3 |
4 | def d(**kwargs):
5 | """Helper of creating a config dict."""
6 | return ml_collections.ConfigDict(initial_dictionary=kwargs)
7 |
8 |
9 | def get_config():
10 | config = ml_collections.ConfigDict()
11 |
12 | config.seed = 1234
13 | config.pred = 'noise_pred'
14 | # config.z_shape = (4, 64, 64)
15 |
16 | config.autoencoder = d(
17 | pretrained_path='assets/stable-diffusion/autoencoder_kl_ema.pth'
18 | )
19 |
20 | config.gradient_accumulation_steps=2 # 1
21 | config.max_grad_norm = 1.0
22 |
23 | config.train = d(
24 | n_steps=300000,
25 | batch_size=128,
26 | mode='cond',
27 | log_interval=10,
28 | eval_interval=1000,
29 | save_interval=5000,
30 | )
31 |
32 | config.optimizer = d(
33 | name='adamw',
34 | lr=0.0001,
35 | weight_decay=0.03,
36 | betas=(0.99, 0.99),
37 | eps=1e-15,
38 | )
39 |
40 | config.lr_scheduler = d(
41 | name='customized',
42 | warmup_steps=1,
43 | )
44 |
45 | config.ug_theta = 1
46 | config.ug_eta = 0.3
47 | config.ug_T = 1
48 |
49 | base_resolution = 512
50 | patch_size = 2
51 | multi_times = 2
52 | resolution = int(base_resolution * multi_times)
53 | coco_multi_scale = [ resolution, resolution //2 , base_resolution, ]
54 |
55 | learned_sigma = False
56 | latent_size = resolution // 8
57 | config.z_shape = (4, latent_size, latent_size)
58 | in_channels = 4
59 | config.nnet = d(
60 | name='Mamba_DiT_H_2',
61 | attention_head_dim=1536//1, num_attention_heads=1, num_layers=49,
62 | in_channels=in_channels,
63 | num_embeds_ada_norm=1000,
64 | sample_size=latent_size,
65 | activation_fn="gelu-approximate",
66 | attention_bias=True,
67 | norm_elementwise_affine=False,
68 | norm_type="ada_norm_single", #"layer_norm",
69 | out_channels=in_channels*2 if learned_sigma else in_channels,
70 | patch_size=patch_size,
71 | mamba_d_state=16,
72 | mamba_d_conv=3,
73 | mamba_expand=2,
74 | use_bidirectional_rnn=False,
75 | mamba_type='enc',
76 | nested_order=0,
77 | is_uconnect=True,
78 | no_ff=True,
79 | use_conv1d=True,
80 | is_extra_tokens=True,
81 | rms=True,
82 | use_pad_token=True,
83 | use_a4m_adapter=True,
84 | drop_path_rate=0.0,
85 | encoder_start_blk_id=1,
86 | kv_as_one_token_idx=-1,
87 | num_2d_enc_dec_layers=6,
88 | pad_token_schedules=['dec_split', 'rho_pad'],
89 | is_absorb=False,
90 | use_adapter_modules=True,
91 | sequence_schedule='dilated',
92 | sub_sequence_schedule=['reverse_single', 'layerwise_cross'],
93 | pos_encoding_type='learnable',
94 |         scan_pattern_len=4 - 1,
95 | is_align_exchange_q_kv=False,
96 | is_random_patterns=False,
97 | pretrained_path = 'workdir/imagenet512_H_DiM_ft/default/ckpts/64000.ckpt/nnet_ema.pth',
98 | multi_times=multi_times,
99 | pattern_type='base',
100 | is_freeu=False, freeu_param=(0.15, 0.1, 1.1, 1.2), # (0.3, 0.2, 1.1, 1.2)
101 |         num_patches = [ (i //8 //patch_size)**2 for i in coco_multi_scale],  # latent tokens per scale: ((res // 8) // patch_size) ** 2
102 | is_skip_tune=True, skip_tune_param = (0.82, 1.0), #(0.82, 1.0),
103 | )
104 | config.gradient_checkpointing = False
105 |
106 | config.dataset = d(
107 | name='imagenet512_features',
108 | path='assets/datasets/imagenet512_features',
109 | cfg=True,
110 | p_uncond=0.15
111 | )
112 | vis_label = None
113 | # rabbit 330
114 | # panda 388
115 | # lion 291
116 | # cat 283
117 | config.sample = d(
118 | sample_steps=50,
119 | n_samples=50000,
120 | mini_batch_size=1, # the decoder is large
121 | algorithm='dpm_solver_upsample_g',
122 | cfg=True,
123 | scale=3.0,
124 | path='workdir/imagenet512_H_DiM_ft/test_tmp/' + (
125 | str(vis_label) if vis_label is not None else 'all_class'
126 | ) + '_' + str( resolution ),
127 | vis_label=vis_label,
128 | )
129 |
130 | if not os.path.exists(config.sample.path):
131 | os.makedirs(config.sample.path)
132 |
133 | return config
134 |
--------------------------------------------------------------------------------
/csrc/selective_scan/selective_scan.h:
--------------------------------------------------------------------------------
1 | /******************************************************************************
2 | * Copyright (c) 2023, Tri Dao.
3 | ******************************************************************************/
4 |
5 | #pragma once
6 |
7 | ////////////////////////////////////////////////////////////////////////////////////////////////////
8 |
9 | struct SSMScanParamsBase {
10 | using index_t = uint32_t;
11 |
12 | int batch, seqlen, n_chunks;
13 | index_t a_batch_stride;
14 | index_t b_batch_stride;
15 | index_t out_batch_stride;
16 |
17 | // Common data pointers.
18 | void *__restrict__ a_ptr;
19 | void *__restrict__ b_ptr;
20 | void *__restrict__ out_ptr;
21 | void *__restrict__ x_ptr;
22 | };
23 |
24 | ////////////////////////////////////////////////////////////////////////////////////////////////////
25 |
26 | struct SSMParamsBase {
27 | using index_t = uint32_t;
28 |
29 | int batch, dim, seqlen, dstate, n_groups, n_chunks;
30 | int dim_ngroups_ratio;
31 | bool is_variable_B;
32 | bool is_variable_C;
33 |
34 | bool delta_softplus;
35 |
36 | index_t A_d_stride;
37 | index_t A_dstate_stride;
38 | index_t B_batch_stride;
39 | index_t B_d_stride;
40 | index_t B_dstate_stride;
41 | index_t B_group_stride;
42 | index_t C_batch_stride;
43 | index_t C_d_stride;
44 | index_t C_dstate_stride;
45 | index_t C_group_stride;
46 | index_t u_batch_stride;
47 | index_t u_d_stride;
48 | index_t delta_batch_stride;
49 | index_t delta_d_stride;
50 | index_t z_batch_stride;
51 | index_t z_d_stride;
52 | index_t out_batch_stride;
53 | index_t out_d_stride;
54 | index_t out_z_batch_stride;
55 | index_t out_z_d_stride;
56 |
57 | // Common data pointers.
58 | void *__restrict__ A_ptr;
59 | void *__restrict__ B_ptr;
60 | void *__restrict__ C_ptr;
61 | void *__restrict__ D_ptr;
62 | void *__restrict__ u_ptr;
63 | void *__restrict__ delta_ptr;
64 | void *__restrict__ delta_bias_ptr;
65 | void *__restrict__ out_ptr;
66 | void *__restrict__ x_ptr;
67 | void *__restrict__ z_ptr;
68 | void *__restrict__ out_z_ptr;
69 | };
70 |
71 | struct SSMParamsBwd: public SSMParamsBase {
72 | index_t dout_batch_stride;
73 | index_t dout_d_stride;
74 | index_t dA_d_stride;
75 | index_t dA_dstate_stride;
76 | index_t dB_batch_stride;
77 | index_t dB_group_stride;
78 | index_t dB_d_stride;
79 | index_t dB_dstate_stride;
80 | index_t dC_batch_stride;
81 | index_t dC_group_stride;
82 | index_t dC_d_stride;
83 | index_t dC_dstate_stride;
84 | index_t du_batch_stride;
85 | index_t du_d_stride;
86 | index_t dz_batch_stride;
87 | index_t dz_d_stride;
88 | index_t ddelta_batch_stride;
89 | index_t ddelta_d_stride;
90 |
91 | // Common data pointers.
92 | void *__restrict__ dout_ptr;
93 | void *__restrict__ dA_ptr;
94 | void *__restrict__ dB_ptr;
95 | void *__restrict__ dC_ptr;
96 | void *__restrict__ dD_ptr;
97 | void *__restrict__ du_ptr;
98 | void *__restrict__ dz_ptr;
99 | void *__restrict__ ddelta_ptr;
100 | void *__restrict__ ddelta_bias_ptr;
101 | };
102 |
--------------------------------------------------------------------------------
/csrc/selective_scan/selective_scan_bwd_bf16_complex.cu:
--------------------------------------------------------------------------------
1 | /******************************************************************************
2 | * Copyright (c) 2023, Tri Dao.
3 | ******************************************************************************/
4 |
5 | // Split into multiple files to compile in parallel
6 |
7 | #include "selective_scan_bwd_kernel.cuh"
8 |
9 | template void selective_scan_bwd_cuda<at::BFloat16, complex_t>(SSMParamsBwd &params, cudaStream_t stream);
--------------------------------------------------------------------------------
/csrc/selective_scan/selective_scan_bwd_bf16_real.cu:
--------------------------------------------------------------------------------
1 | /******************************************************************************
2 | * Copyright (c) 2023, Tri Dao.
3 | ******************************************************************************/
4 |
5 | // Split into multiple files to compile in parallel
6 |
7 | #include "selective_scan_bwd_kernel.cuh"
8 |
9 | template void selective_scan_bwd_cuda<at::BFloat16, float>(SSMParamsBwd &params, cudaStream_t stream);
--------------------------------------------------------------------------------
/csrc/selective_scan/selective_scan_bwd_fp16_complex.cu:
--------------------------------------------------------------------------------
1 | /******************************************************************************
2 | * Copyright (c) 2023, Tri Dao.
3 | ******************************************************************************/
4 |
5 | // Split into multiple files to compile in parallel
6 |
7 | #include "selective_scan_bwd_kernel.cuh"
8 |
9 | template void selective_scan_bwd_cuda<at::Half, complex_t>(SSMParamsBwd &params, cudaStream_t stream);
--------------------------------------------------------------------------------
/csrc/selective_scan/selective_scan_bwd_fp16_real.cu:
--------------------------------------------------------------------------------
1 | /******************************************************************************
2 | * Copyright (c) 2023, Tri Dao.
3 | ******************************************************************************/
4 |
5 | // Split into multiple files to compile in parallel
6 |
7 | #include "selective_scan_bwd_kernel.cuh"
8 |
9 | template void selective_scan_bwd_cuda<at::Half, float>(SSMParamsBwd &params, cudaStream_t stream);
--------------------------------------------------------------------------------
/csrc/selective_scan/selective_scan_bwd_fp32_complex.cu:
--------------------------------------------------------------------------------
1 | /******************************************************************************
2 | * Copyright (c) 2023, Tri Dao.
3 | ******************************************************************************/
4 |
5 | // Split into multiple files to compile in parallel
6 |
7 | #include "selective_scan_bwd_kernel.cuh"
8 |
9 | template void selective_scan_bwd_cuda<float, complex_t>(SSMParamsBwd &params, cudaStream_t stream);
--------------------------------------------------------------------------------
/csrc/selective_scan/selective_scan_bwd_fp32_real.cu:
--------------------------------------------------------------------------------
1 | /******************************************************************************
2 | * Copyright (c) 2023, Tri Dao.
3 | ******************************************************************************/
4 |
5 | // Split into multiple files to compile in parallel
6 |
7 | #include "selective_scan_bwd_kernel.cuh"
8 |
9 | template void selective_scan_bwd_cuda<float, float>(SSMParamsBwd &params, cudaStream_t stream);
--------------------------------------------------------------------------------
/csrc/selective_scan/selective_scan_common.h:
--------------------------------------------------------------------------------
1 | /******************************************************************************
2 | * Copyright (c) 2023, Tri Dao.
3 | ******************************************************************************/
4 |
5 | #pragma once
6 |
7 | #include <cuda_bf16.h>
8 | #include <cuda_fp16.h>
9 | #include <c10/util/complex.h>  // For scalar_value_type
10 |
11 | #define MAX_DSTATE 256
12 |
13 | using complex_t = c10::complex<float>;
14 |
15 | inline __device__ float2 operator+(const float2 & a, const float2 & b){
16 | return {a.x + b.x, a.y + b.y};
17 | }
18 |
19 | inline __device__ float3 operator+(const float3 &a, const float3 &b) {
20 | return {a.x + b.x, a.y + b.y, a.z + b.z};
21 | }
22 |
23 | inline __device__ float4 operator+(const float4 & a, const float4 & b){
24 | return {a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w};
25 | }
26 |
27 | ////////////////////////////////////////////////////////////////////////////////////////////////////
28 |
29 | template<int BYTES> struct BytesToType {};
30 |
31 | template<> struct BytesToType<16> {
32 | using Type = uint4;
33 | static_assert(sizeof(Type) == 16);
34 | };
35 |
36 | template<> struct BytesToType<8> {
37 | using Type = uint64_t;
38 | static_assert(sizeof(Type) == 8);
39 | };
40 |
41 | template<> struct BytesToType<4> {
42 | using Type = uint32_t;
43 | static_assert(sizeof(Type) == 4);
44 | };
45 |
46 | template<> struct BytesToType<2> {
47 | using Type = uint16_t;
48 | static_assert(sizeof(Type) == 2);
49 | };
50 |
51 | template<> struct BytesToType<1> {
52 | using Type = uint8_t;
53 | static_assert(sizeof(Type) == 1);
54 | };
55 |
56 | ////////////////////////////////////////////////////////////////////////////////////////////////////
57 |
58 | template<typename scalar_t, int N>
59 | struct Converter{
60 | static inline __device__ void to_float(const scalar_t (&src)[N], float (&dst)[N]) {
61 | #pragma unroll
62 | for (int i = 0; i < N; ++i) { dst[i] = src[i]; }
63 | }
64 | };
65 |
66 | template<int N>
67 | struct Converter<at::Half, N>{
68 |     static inline __device__ void to_float(const at::Half (&src)[N], float (&dst)[N]) {
69 |         static_assert(N % 2 == 0);
70 |         auto &src2 = reinterpret_cast<const half2 (&)[N / 2]>(src);
71 |         auto &dst2 = reinterpret_cast<float2 (&)[N / 2]>(dst);
72 | #pragma unroll
73 | for (int i = 0; i < N / 2; ++i) { dst2[i] = __half22float2(src2[i]); }
74 | }
75 | };
76 |
77 | #if __CUDA_ARCH__ >= 800
78 | template<int N>
79 | struct Converter<at::BFloat16, N>{
80 |     static inline __device__ void to_float(const at::BFloat16 (&src)[N], float (&dst)[N]) {
81 |         static_assert(N % 2 == 0);
82 |         auto &src2 = reinterpret_cast<const nv_bfloat162 (&)[N / 2]>(src);
83 |         auto &dst2 = reinterpret_cast<float2 (&)[N / 2]>(dst);
84 | #pragma unroll
85 | for (int i = 0; i < N / 2; ++i) { dst2[i] = __bfloat1622float2(src2[i]); }
86 | }
87 | };
88 | #endif
89 |
90 | ////////////////////////////////////////////////////////////////////////////////////////////////////
91 |
92 | // From https://stackoverflow.com/questions/9860711/cucomplex-h-and-exp
93 | // and https://forums.developer.nvidia.com/t/complex-number-exponential-function/24696
94 | __device__ __forceinline__ complex_t cexp2f(complex_t z) {
95 | float t = exp2f(z.real_);
96 | float c, s;
97 | sincosf(z.imag_, &s, &c);
98 | return complex_t(c * t, s * t);
99 | }
100 |
101 | __device__ __forceinline__ complex_t cexpf(complex_t z) {
102 | float t = expf(z.real_);
103 | float c, s;
104 | sincosf(z.imag_, &s, &c);
105 | return complex_t(c * t, s * t);
106 | }
107 |
108 | template<typename scalar_t> struct SSMScanOp;
109 |
110 | template<>
111 | struct SSMScanOp<float> {
112 | __device__ __forceinline__ float2 operator()(const float2 &ab0, const float2 &ab1) const {
113 | return make_float2(ab1.x * ab0.x, ab1.x * ab0.y + ab1.y);
114 | }
115 | };
116 |
117 | template<>
118 | struct SSMScanOp<complex_t> {
119 | __device__ __forceinline__ float4 operator()(const float4 &ab0, const float4 &ab1) const {
120 | complex_t a0 = complex_t(ab0.x, ab0.y);
121 | complex_t b0 = complex_t(ab0.z, ab0.w);
122 | complex_t a1 = complex_t(ab1.x, ab1.y);
123 | complex_t b1 = complex_t(ab1.z, ab1.w);
124 | complex_t out_a = a1 * a0;
125 | complex_t out_b = a1 * b0 + b1;
126 | return make_float4(out_a.real_, out_a.imag_, out_b.real_, out_b.imag_);
127 | }
128 | };
129 |
130 | // A stateful callback functor that maintains a running prefix to be applied
131 | // during consecutive scan operations.
132 | template <typename scalar_t> struct SSMScanPrefixCallbackOp {
133 |     using scan_t = std::conditional_t<std::is_same_v<scalar_t, float>, float2, float4>;
134 | scan_t running_prefix;
135 | // Constructor
136 | __device__ SSMScanPrefixCallbackOp(scan_t running_prefix_) : running_prefix(running_prefix_) {}
137 | // Callback operator to be entered by the first warp of threads in the block.
138 | // Thread-0 is responsible for returning a value for seeding the block-wide scan.
139 | __device__ scan_t operator()(scan_t block_aggregate) {
140 | scan_t old_prefix = running_prefix;
141 | running_prefix = SSMScanOp()(running_prefix, block_aggregate);
142 | return old_prefix;
143 | }
144 | };
145 |
146 | ////////////////////////////////////////////////////////////////////////////////////////////////////
147 |
148 | template<typename Ktraits>
149 | inline __device__ void load_input(typename Ktraits::input_t *u,
150 | typename Ktraits::input_t (&u_vals)[Ktraits::kNItems],
151 | typename Ktraits::BlockLoadT::TempStorage &smem_load,
152 | int seqlen) {
153 | if constexpr (Ktraits::kIsEvenLen) {
154 |         auto& smem_load_vec = reinterpret_cast<typename Ktraits::BlockLoadVecT::TempStorage&>(smem_load);
155 |         using vec_t = typename Ktraits::vec_t;
156 |         Ktraits::BlockLoadVecT(smem_load_vec).Load(
157 |             reinterpret_cast<vec_t*>(u),
158 |             reinterpret_cast<vec_t(&)[Ktraits::kNLoads]>(u_vals)
159 |         );
160 | } else {
161 | Ktraits::BlockLoadT(smem_load).Load(u, u_vals, seqlen, 0.f);
162 | }
163 | }
164 |
165 | template<typename Ktraits>
166 | inline __device__ void load_weight(typename Ktraits::input_t *Bvar,
167 | typename Ktraits::weight_t (&B_vals)[Ktraits::kNItems],
168 | typename Ktraits::BlockLoadWeightT::TempStorage &smem_load_weight,
169 | int seqlen) {
170 | constexpr int kNItems = Ktraits::kNItems;
171 | if constexpr (!Ktraits::kIsComplex) {
172 | typename Ktraits::input_t B_vals_load[kNItems];
173 | if constexpr (Ktraits::kIsEvenLen) {
174 |             auto& smem_load_weight_vec = reinterpret_cast<typename Ktraits::BlockLoadWeightVecT::TempStorage&>(smem_load_weight);
175 |             using vec_t = typename Ktraits::vec_t;
176 |             Ktraits::BlockLoadWeightVecT(smem_load_weight_vec).Load(
177 |                 reinterpret_cast<vec_t*>(Bvar),
178 |                 reinterpret_cast<vec_t(&)[Ktraits::kNLoads]>(B_vals_load)
179 |             );
180 | } else {
181 | Ktraits::BlockLoadWeightT(smem_load_weight).Load(Bvar, B_vals_load, seqlen, 0.f);
182 | }
183 | // #pragma unroll
184 | // for (int i = 0; i < kNItems; ++i) { B_vals[i] = B_vals_load[i]; }
185 | Converter::to_float(B_vals_load, B_vals);
186 | } else {
187 | typename Ktraits::input_t B_vals_load[kNItems * 2];
188 | if constexpr (Ktraits::kIsEvenLen) {
189 |             auto& smem_load_weight_vec = reinterpret_cast<typename Ktraits::BlockLoadWeightVecT::TempStorage&>(smem_load_weight);
190 |             using vec_t = typename Ktraits::vec_t;
191 |             Ktraits::BlockLoadWeightVecT(smem_load_weight_vec).Load(
192 |                 reinterpret_cast<vec_t*>(Bvar),
193 |                 reinterpret_cast<vec_t(&)[Ktraits::kNLoads * 2]>(B_vals_load)
194 |             );
195 | } else {
196 | Ktraits::BlockLoadWeightT(smem_load_weight).Load(Bvar, B_vals_load, seqlen, 0.f);
197 | }
198 | #pragma unroll
199 | for (int i = 0; i < kNItems; ++i) { B_vals[i] = complex_t(B_vals_load[i * 2], B_vals_load[i * 2 + 1]); }
200 | }
201 | }
202 |
203 | template<typename Ktraits>
204 | inline __device__ void store_output(typename Ktraits::input_t *out,
205 | const float (&out_vals)[Ktraits::kNItems],
206 | typename Ktraits::BlockStoreT::TempStorage &smem_store,
207 | int seqlen) {
208 | typename Ktraits::input_t write_vals[Ktraits::kNItems];
209 | #pragma unroll
210 | for (int i = 0; i < Ktraits::kNItems; ++i) { write_vals[i] = out_vals[i]; }
211 | if constexpr (Ktraits::kIsEvenLen) {
212 |         auto& smem_store_vec = reinterpret_cast<typename Ktraits::BlockStoreVecT::TempStorage&>(smem_store);
213 |         using vec_t = typename Ktraits::vec_t;
214 |         Ktraits::BlockStoreVecT(smem_store_vec).Store(
215 |             reinterpret_cast<vec_t*>(out),
216 |             reinterpret_cast<vec_t(&)[Ktraits::kNLoads]>(write_vals)
217 |         );
218 | } else {
219 | Ktraits::BlockStoreT(smem_store).Store(out, write_vals, seqlen);
220 | }
221 | }
222 |
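The `SSMScanOp<float>` above combines `(a, b)` pairs as `(a1 * a0, a1 * b0 + b1)`; this operator is associative, which is what lets the recurrence `x_t = a_t * x_{t-1} + b_t` be evaluated as a parallel prefix scan. A small Python reference sketch (not part of the repo) that checks this equivalence:

```python
# The associative combine used by SSMScanOp<float>, checked against the
# sequential recurrence x_t = a_t * x_{t-1} + b_t (with x_0 = 0).
from functools import reduce

def combine(ab0, ab1):
    a0, b0 = ab0
    a1, b1 = ab1
    return (a1 * a0, a1 * b0 + b1)

a = [0.9, 0.8, 0.7]
b = [1.0, 2.0, 3.0]

x = 0.0
for at, bt in zip(a, b):   # sequential evaluation
    x = at * x + bt

_, x_scan = reduce(combine, zip(a, b))  # same result via the associative operator
assert abs(x - x_scan) < 1e-12
```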
--------------------------------------------------------------------------------
/csrc/selective_scan/selective_scan_fwd_bf16.cu:
--------------------------------------------------------------------------------
1 | /******************************************************************************
2 | * Copyright (c) 2023, Tri Dao.
3 | ******************************************************************************/
4 |
5 | // Split into multiple files to compile in parallel
6 |
7 | #include "selective_scan_fwd_kernel.cuh"
8 |
9 | template void selective_scan_fwd_cuda<at::BFloat16, float>(SSMParamsBase &params, cudaStream_t stream);
10 | template void selective_scan_fwd_cuda<at::BFloat16, complex_t>(SSMParamsBase &params, cudaStream_t stream);
--------------------------------------------------------------------------------
/csrc/selective_scan/selective_scan_fwd_fp16.cu:
--------------------------------------------------------------------------------
1 | /******************************************************************************
2 | * Copyright (c) 2023, Tri Dao.
3 | ******************************************************************************/
4 |
5 | // Split into multiple files to compile in parallel
6 |
7 | #include "selective_scan_fwd_kernel.cuh"
8 |
9 | template void selective_scan_fwd_cuda<at::Half, float>(SSMParamsBase &params, cudaStream_t stream);
10 | template void selective_scan_fwd_cuda<at::Half, complex_t>(SSMParamsBase &params, cudaStream_t stream);
--------------------------------------------------------------------------------
/csrc/selective_scan/selective_scan_fwd_fp32.cu:
--------------------------------------------------------------------------------
1 | /******************************************************************************
2 | * Copyright (c) 2023, Tri Dao.
3 | ******************************************************************************/
4 |
5 | // Split into multiple files to compile in parallel
6 |
7 | #include "selective_scan_fwd_kernel.cuh"
8 |
9 | template void selective_scan_fwd_cuda<float, float>(SSMParamsBase &params, cudaStream_t stream);
10 | template void selective_scan_fwd_cuda<float, complex_t>(SSMParamsBase &params, cudaStream_t stream);
--------------------------------------------------------------------------------
/csrc/selective_scan/static_switch.h:
--------------------------------------------------------------------------------
1 | // Inspired by https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h
2 | // and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h
3 |
4 | #pragma once
5 |
6 | /// @param COND - a boolean expression to switch by
7 | /// @param CONST_NAME - a name given for the constexpr bool variable.
8 | /// @param ... - code to execute for true and false
9 | ///
10 | /// Usage:
11 | /// ```
12 | /// BOOL_SWITCH(flag, BoolConst, [&] {
13 | /// some_function(...);
14 | /// });
15 | /// ```
16 | #define BOOL_SWITCH(COND, CONST_NAME, ...) \
17 | [&] { \
18 | if (COND) { \
19 | constexpr bool CONST_NAME = true; \
20 | return __VA_ARGS__(); \
21 | } else { \
22 | constexpr bool CONST_NAME = false; \
23 | return __VA_ARGS__(); \
24 | } \
25 | }()
26 |
--------------------------------------------------------------------------------
/csrc/selective_scan/uninitialized_copy.cuh:
--------------------------------------------------------------------------------
1 | /******************************************************************************
2 | * Copyright (c) 2011-2022, NVIDIA CORPORATION. All rights reserved.
3 | *
4 | * Redistribution and use in source and binary forms, with or without
5 | * modification, are permitted provided that the following conditions are met:
6 | * * Redistributions of source code must retain the above copyright
7 | * notice, this list of conditions and the following disclaimer.
8 | * * Redistributions in binary form must reproduce the above copyright
9 | * notice, this list of conditions and the following disclaimer in the
10 | * documentation and/or other materials provided with the distribution.
11 | * * Neither the name of the NVIDIA CORPORATION nor the
12 | * names of its contributors may be used to endorse or promote products
13 | * derived from this software without specific prior written permission.
14 | *
15 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
16 | * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
17 | * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
18 | * ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
19 | * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
20 | * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
21 | * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
22 | * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
23 | * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
24 | * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
25 | *
26 | ******************************************************************************/
27 |
28 | #pragma once
29 |
30 | #include <cub/config.cuh>
31 |
32 | #include <cuda/std/type_traits>
33 |
34 |
35 | namespace detail
36 | {
37 |
38 | #if defined(_NVHPC_CUDA)
39 | template <typename T, typename U>
40 | __host__ __device__ void uninitialized_copy(T *ptr, U &&val)
41 | {
42 | // NVBug 3384810
43 | new (ptr) T(::cuda::std::forward<U>(val));
44 | }
45 | #else
46 | template <typename T,
47 |           typename U,
48 |           typename ::cuda::std::enable_if<
49 |             ::cuda::std::is_trivially_copyable<T>::value,
50 |             int
51 |           >::type = 0>
52 | __host__ __device__ void uninitialized_copy(T *ptr, U &&val)
53 | {
54 | *ptr = ::cuda::std::forward<U>(val);
55 | }
56 |
57 | template <typename T,
58 |           typename U,
59 |           typename ::cuda::std::enable_if<
60 |             !::cuda::std::is_trivially_copyable<T>::value,
61 |             int
62 |           >::type = 0>
63 | __host__ __device__ void uninitialized_copy(T *ptr, U &&val)
64 | {
65 | new (ptr) T(::cuda::std::forward<U>(val));
66 | }
67 | #endif
68 |
69 | } // namespace detail
70 |
--------------------------------------------------------------------------------
/environment.yaml:
--------------------------------------------------------------------------------
1 | name: mamba-attn
2 | channels:
3 | - pytorch
4 | - defaults
5 | dependencies:
6 | - python=3.9
7 | - pip=22.3.1
8 | - cudatoolkit=11.8
9 | - pip:
10 | - torch==2.1.1
11 | - torchvision
12 | - packaging
13 | - tb-nightly
14 | - gradio==3.33.1
15 | - albumentations==1.3.0
16 | - opencv-contrib-python
17 | - imageio==2.9.0
18 | - imageio-ffmpeg==0.4.2
19 | - pytorch-lightning==1.5.0
20 | - omegaconf==2.3.0
21 | - test-tube>=0.7.5
22 | - streamlit==1.12.1
23 | - einops==0.6.0
24 | - transformers==4.36.2
25 | - webdataset==0.2.5
26 | - kornia==0.6
27 | - open_clip_torch==2.16.0
28 | - invisible-watermark>=0.1.5
29 | - streamlit-drawable-canvas==0.8.0
30 | - torchmetrics==0.6.0
31 | - timm==0.6.12
32 | - addict==2.4.0
33 | - yapf==0.32.0
34 | - prettytable==3.6.0
35 | - safetensors==0.3.1
36 | - basicsr==1.4.2
37 | - accelerate==0.17.0
38 | - decord==0.6.0
39 | - diffusers==0.25
40 | - moviepy==1.0.3
41 | - opencv_python==4.7.0.68
42 | - Pillow==9.4.0
43 | - scikit_image==0.19.3
44 | - scipy==1.10.1
45 | - tensorboardX==2.6
46 | - tqdm==4.64.1
47 | - numpy
48 | - datasets
49 | - ipython
50 | - seaborn
51 | - pycocotools
52 | - ipdb
53 | - matplotlib
54 | - flow_vis
55 | - charset-normalizer
56 | - ml_collections
57 | - wandb
58 | - causal-conv1d==1.2.0.post2
59 | - galore-torch
--------------------------------------------------------------------------------
/eval.py:
--------------------------------------------------------------------------------
1 | from tools.fid_score import calculate_fid_given_paths
2 | import ml_collections
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | from torch import multiprocessing as mp
7 | import accelerate
8 | import utils
9 | import sde
10 | from uvit_datasets import get_dataset
11 | import tempfile
12 | from dpm_solver_pytorch import NoiseScheduleVP, model_wrapper, DPM_Solver
13 | from absl import logging
14 | import builtins
15 |
16 | from mamba_attn_diff.models.upsample_guidance import make_ufg_nnet
17 |
18 | def evaluate(config):
19 | if config.get('benchmark', False):
20 | torch.backends.cudnn.benchmark = True
21 | torch.backends.cudnn.deterministic = False
22 |
23 | mp.set_start_method('spawn')
24 | accelerator = accelerate.Accelerator()
25 | device = accelerator.device
26 | accelerate.utils.set_seed(config.seed, device_specific=True)
27 | logging.info(f'Process {accelerator.process_index} using device: {device}')
28 |
29 | config.mixed_precision = accelerator.mixed_precision
30 | config = ml_collections.FrozenConfigDict(config)
31 | if accelerator.is_main_process:
32 | utils.set_logger(log_level='info', fname=config.output_path)
33 | else:
34 | utils.set_logger(log_level='error')
35 | builtins.print = lambda *args: None
36 |
37 | dataset = get_dataset(**config.dataset)
38 |
39 | nnet = utils.get_nnet(**config.nnet)
40 | nnet = accelerator.prepare(nnet)
41 | logging.info(f'load nnet from {config.nnet_path}')
42 | nnet = accelerator.unwrap_model(nnet)
43 | nnet.load_state_dict(torch.load(config.nnet_path, map_location='cpu'))
44 | nnet.eval()
45 | print(nnet, (config.sample.algorithm if config.get('scheduler', False) else 'dpm_solver'))
46 |
47 | def cfg_nnet(x, timestep, y, **kwargs):
48 | _cond = nnet(x, timestep, y=y, **kwargs)
49 | _uncond = nnet(x, timestep, y=torch.tensor([dataset.K] * x.size(0), device=device), **kwargs)
50 | _cond = _cond.sample if not isinstance(_cond, torch.Tensor) else _cond
51 | _uncond = _uncond.sample if not isinstance(_uncond, torch.Tensor) else _uncond
52 | return _cond + config.sample.scale * (_cond - _uncond)
53 |
54 | def uncfg_nnet(x, timestep, y=None, **kwargs):
55 | _uncfg = nnet(x, timestep, **kwargs)
56 | _uncfg = _uncfg.sample if not isinstance(_uncfg, torch.Tensor) else _uncfg
57 | return _uncfg
58 |
59 | if 'cfg' in config.sample and config.sample.cfg and config.sample.scale > 0: # classifier free guidance
60 | logging.info(f'Use classifier free guidance with scale={config.sample.scale}')
61 | score_model = sde.ScoreModel(cfg_nnet, pred=config.pred, sde=sde.VPSDE())
62 | else:
63 | score_model = sde.ScoreModel(uncfg_nnet, pred=config.pred, sde=sde.VPSDE())
64 |
65 |
66 | logging.info(config.sample)
67 | assert os.path.exists(dataset.fid_stat)
68 | logging.info(f'sample: n_samples={config.sample.n_samples}, mode={config.train.mode}, mixed_precision={config.mixed_precision}')
69 |
70 | def sample_fn(_n_samples):
71 | if config.sample.algorithm == 'dpm_solver_upsample_g':
72 | m = 2
73 | data_shape = tuple([dataset.data_shape[0]] + [ i*m for i in dataset.data_shape[1:] ])
74 | x_init = torch.randn(_n_samples, *data_shape, device=device)
75 | else:
76 | x_init = torch.randn(_n_samples, *dataset.data_shape, device=device)
77 |
78 | if config.train.mode == 'uncond':
79 | kwargs = dict()
80 | elif config.train.mode == 'cond':
81 | kwargs = dict(y=dataset.sample_label(_n_samples, device=device))
82 | else:
83 | raise NotImplementedError
84 |
85 | if config.sample.algorithm == 'euler_maruyama_sde':
86 | rsde = sde.ReverseSDE(score_model)
87 | return sde.euler_maruyama(rsde, x_init, config.sample.sample_steps, verbose=accelerator.is_main_process, **kwargs)
88 | elif config.sample.algorithm == 'euler_maruyama_ode':
89 | rsde = sde.ODE(score_model)
90 | return sde.euler_maruyama(rsde, x_init, config.sample.sample_steps, verbose=accelerator.is_main_process, **kwargs)
91 | elif config.sample.algorithm == 'dpm_solver_upsample_g':
92 | noise_schedule = NoiseScheduleVP(schedule='linear')
93 | sde_entity = sde.VPSDE()
94 |
95 | normed_timesteps = torch.arange(1000, dtype=x_init.dtype, device=device).flip(0) / 999
96 | normed_timesteps[-1] = 1e-5
97 | model_fn = make_ufg_nnet(
98 | cfg_nnet,
99 | uncfg_nnet,
100 | normed_timesteps,
101 | sde_entity.cum_alpha,
102 | sde_entity.cum_beta,
103 | sde_entity.snr,
104 | m=2,
105 | **kwargs,
106 | )
107 |
108 | dpm_solver = DPM_Solver(model_fn, noise_schedule)
109 | return dpm_solver.sample(
110 | x_init,
111 | steps=config.sample.sample_steps,
112 | eps=1e-4,
113 | adaptive_step_size=False,
114 | fast_version=True,
115 | )
116 | elif config.sample.algorithm == 'dpm_solver':
117 | noise_schedule = NoiseScheduleVP(schedule='linear')
118 | model_fn = model_wrapper(
119 | score_model.noise_pred,
120 | noise_schedule,
121 | time_input_type='0',
122 | model_kwargs=kwargs
123 | )
124 | dpm_solver = DPM_Solver(model_fn, noise_schedule)
125 | return dpm_solver.sample(
126 | x_init,
127 | steps=config.sample.sample_steps,
128 | eps=1e-4,
129 | adaptive_step_size=False,
130 | fast_version=True,
131 | )
132 | else:
133 | raise NotImplementedError
134 |
135 | with tempfile.TemporaryDirectory() as temp_path:
136 | path = config.sample.path or temp_path
137 | if accelerator.is_main_process:
138 | os.makedirs(path, exist_ok=True)
139 | utils.sample2dir(accelerator, path, config.sample.n_samples, config.sample.mini_batch_size, sample_fn, dataset.unpreprocess)
140 | if accelerator.is_main_process:
141 | fid = calculate_fid_given_paths((dataset.fid_stat, path))
142 | logging.info(f'nnet_path={config.nnet_path}, fid={fid}')
143 |
144 |
145 | from absl import flags
146 | from absl import app
147 | from ml_collections import config_flags
148 | import os
149 |
150 |
151 | FLAGS = flags.FLAGS
152 | config_flags.DEFINE_config_file(
153 | "config", None, "Training configuration.", lock_config=False)
154 | flags.mark_flags_as_required(["config"])
155 | flags.DEFINE_string("nnet_path", None, "The nnet to evaluate.")
156 | flags.DEFINE_string("output_path", None, "The path to output log.")
157 |
158 |
159 | def main(argv):
160 | config = FLAGS.config
161 | config.nnet_path = FLAGS.nnet_path
162 | config.output_path = FLAGS.output_path
163 | evaluate(config)
164 |
165 |
166 | if __name__ == "__main__":
167 | app.run(main)
168 |
--------------------------------------------------------------------------------
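A minimal sketch of the classifier-free guidance combination that `cfg_nnet` in `eval.py` applies, with a hypothetical `toy_nnet` standing in for the real noise-prediction network (the names below are illustrative, not part of the repo):

```python
import torch

def toy_nnet(x, y):
    # Hypothetical stand-in for the diffusion network's noise prediction.
    return x * (1.0 + 0.01 * y.float().view(-1, 1, 1, 1))

x = torch.randn(4, 3, 8, 8)
y = torch.arange(4)              # class labels
y_null = torch.full((4,), 10)    # "empty" label, analogous to dataset.K

scale = 0.4
_cond = toy_nnet(x, y)
_uncond = toy_nnet(x, y_null)
# Same combination as cfg_nnet: push the conditional prediction further
# away from the unconditional one by `scale` times their gap.
guided = _cond + scale * (_cond - _uncond)
print(guided.shape)  # torch.Size([4, 3, 8, 8])
```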
/eval_ldm.py:
--------------------------------------------------------------------------------
1 | from tools.fid_score import calculate_fid_given_paths
2 | import ml_collections
3 | import torch
4 | from torch import multiprocessing as mp
5 | import accelerate
6 | import utils
7 | import sde
8 | from uvit_datasets import get_dataset
9 | import tempfile
10 | from dpm_solver_pytorch import NoiseScheduleVP, model_wrapper, DPM_Solver
11 | from absl import logging
12 | import builtins
13 | import libs.autoencoder
14 |
15 | from mamba_attn_diff.models.upsample_guidance import make_ufg_nnet
16 |
17 | def evaluate(config):
18 | if config.get('benchmark', False):
19 | torch.backends.cudnn.benchmark = True
20 | torch.backends.cudnn.deterministic = False
21 |
22 | mp.set_start_method('spawn')
23 | accelerator = accelerate.Accelerator()
24 | device = accelerator.device
25 | accelerate.utils.set_seed(config.seed, device_specific=True)
26 | logging.info(f'Process {accelerator.process_index} using device: {device}')
27 |
28 | config.mixed_precision = accelerator.mixed_precision
29 | config = ml_collections.FrozenConfigDict(config)
30 | if accelerator.is_main_process:
31 | utils.set_logger(log_level='info', fname=config.output_path)
32 | else:
33 | utils.set_logger(log_level='error')
34 | builtins.print = lambda *args: None
35 |
36 | dataset = get_dataset(**config.dataset)
37 |
38 | nnet = utils.get_nnet(**config.nnet)
39 | nnet = accelerator.prepare(nnet)
40 | logging.info(f'load nnet from {config.nnet_path}')
41 | accelerator.unwrap_model(nnet).load_state_dict(torch.load(config.nnet_path, map_location='cpu'))
42 | nnet.eval()
43 |
44 | autoencoder = libs.autoencoder.get_model(config.autoencoder.pretrained_path)
45 | autoencoder.to(device)
46 |
47 | @torch.cuda.amp.autocast()
48 | def encode(_batch):
49 | return autoencoder.encode(_batch)
50 |
51 | @torch.cuda.amp.autocast()
52 | def decode(_batch):
53 | return autoencoder.decode(_batch)
54 |
55 | def decode_large_batch(_batch):
56 | decode_mini_batch_size = 50 # use a small batch size since the decoder is large
57 | xs = []
58 | pt = 0
59 | for _decode_mini_batch_size in utils.amortize(_batch.size(0), decode_mini_batch_size):
60 | x = decode(_batch[pt: pt + _decode_mini_batch_size])
61 | pt += _decode_mini_batch_size
62 | xs.append(x)
63 | xs = torch.concat(xs, dim=0)
64 | assert xs.size(0) == _batch.size(0)
65 | return xs
66 |
67 | def cfg_nnet(x, timestep, y, **kwargs):
68 | _cond = nnet(x, timestep, y=y, **kwargs)
69 | _uncond = nnet(x, timestep, y=torch.tensor([dataset.K] * x.size(0), device=device), **kwargs)
70 | _cond = _cond.sample if not isinstance(_cond, torch.Tensor) else _cond
71 | _uncond = _uncond.sample if not isinstance(_uncond, torch.Tensor) else _uncond
72 | return _cond + config.sample.scale * (_cond - _uncond)
73 |
74 | def uncfg_nnet(x, timestep, y=None, **kwargs):
75 | _uncfg = nnet(x, timestep, **kwargs)
76 | _uncfg = _uncfg.sample if not isinstance(_uncfg, torch.Tensor) else _uncfg
77 | return _uncfg
78 |
79 | if 'cfg' in config.sample and config.sample.cfg and config.sample.scale > 0: # classifier free guidance
80 | logging.info(f'Use classifier free guidance with scale={config.sample.scale}')
81 | score_model = sde.ScoreModel(cfg_nnet, pred=config.pred, sde=sde.VPSDE()) #
82 | else:
83 | score_model = sde.ScoreModel(uncfg_nnet, pred=config.pred, sde=sde.VPSDE())
84 |
85 | logging.info(config.sample)
86 | assert os.path.exists(dataset.fid_stat)
87 | logging.info(f'sample: n_samples={config.sample.n_samples}, mode={config.train.mode}, mixed_precision={config.mixed_precision}')
88 |
89 | def sample_fn(_n_samples):
90 | if config.sample.algorithm == 'dpm_solver_upsample_g':
91 | m = 2
92 | data_shape = tuple([config.z_shape[0]] + [ int(i*m) for i in config.z_shape[1:] ]) # config.z_shape[1:]
93 | _z_init = torch.randn(_n_samples, *data_shape, device=device)
94 | else:
95 | _z_init = torch.randn(_n_samples, *config.z_shape, device=device)
96 |
97 | if config.train.mode == 'uncond':
98 | kwargs = dict()
99 | elif config.train.mode == 'cond':
100 | kwargs = dict(y=dataset.sample_label(_n_samples, device=device))
101 | else:
102 | raise NotImplementedError
103 |
104 | if config.sample.algorithm == 'euler_maruyama_sde':
105 | _z = sde.euler_maruyama(sde.ReverseSDE(score_model), _z_init, config.sample.sample_steps, verbose=accelerator.is_main_process, **kwargs)
106 | elif config.sample.algorithm == 'euler_maruyama_ode':
107 | _z = sde.euler_maruyama(sde.ODE(score_model), _z_init, config.sample.sample_steps, verbose=accelerator.is_main_process, **kwargs)
108 | elif config.sample.algorithm == 'dpm_solver_upsample_g':
109 | noise_schedule = NoiseScheduleVP(schedule='linear')
110 | sde_entity = sde.VPSDE()
111 |
112 | normed_timesteps = torch.arange(1000, dtype=_z_init.dtype, device=device).flip(0) / 999
113 | normed_timesteps[-1] = 1e-5
114 | model_fn = make_ufg_nnet(
115 | cfg_nnet,
116 | uncfg_nnet,
117 | normed_timesteps,
118 | sde_entity.cum_alpha,
119 | sde_entity.cum_beta,
120 | sde_entity.snr,
121 | m=m,
122 | **kwargs,
123 | )
124 |
125 | dpm_solver = DPM_Solver(model_fn, noise_schedule)
126 | _z = dpm_solver.sample(
127 | _z_init,
128 | steps=config.sample.sample_steps,
129 | eps=1e-4,
130 | adaptive_step_size=False,
131 | fast_version=True,
132 | )
133 | elif config.sample.algorithm == 'dpm_solver':
134 | noise_schedule = NoiseScheduleVP(schedule='linear')
135 | model_fn = model_wrapper(
136 | score_model.noise_pred, #
137 | noise_schedule,
138 | time_input_type='0',
139 | model_kwargs=kwargs
140 | )
141 | dpm_solver = DPM_Solver(model_fn, noise_schedule)
142 | _z = dpm_solver.sample(
143 | _z_init,
144 | steps=config.sample.sample_steps,
145 | eps=1e-4,
146 | adaptive_step_size=False,
147 | fast_version=True,
148 | )
149 | else:
150 | raise NotImplementedError
151 | return decode_large_batch(_z)
152 |
153 | with tempfile.TemporaryDirectory() as temp_path:
154 | path = config.sample.path or temp_path
155 | if accelerator.is_main_process:
156 | os.makedirs(path, exist_ok=True)
157 | logging.info(f'Samples are saved in {path}')
158 | utils.sample2dir(accelerator, path, config.sample.n_samples, config.sample.mini_batch_size, sample_fn, dataset.unpreprocess)
159 | if accelerator.is_main_process:
160 | fid = calculate_fid_given_paths((dataset.fid_stat, path))
161 | logging.info(f'nnet_path={config.nnet_path}, fid={fid}')
162 |
163 |
164 | from absl import flags
165 | from absl import app
166 | from ml_collections import config_flags
167 | import os
168 |
169 |
170 | FLAGS = flags.FLAGS
171 | config_flags.DEFINE_config_file(
172 | "config", None, "Training configuration.", lock_config=False)
173 | flags.mark_flags_as_required(["config"])
174 | flags.DEFINE_string("nnet_path", None, "The nnet to evaluate.")
175 | flags.DEFINE_string("output_path", None, "The path to output log.")
176 |
177 |
178 | def main(argv):
179 | config = FLAGS.config
180 | config.nnet_path = FLAGS.nnet_path
181 | config.output_path = FLAGS.output_path
182 | evaluate(config)
183 |
184 |
185 | if __name__ == "__main__":
186 | app.run(main)
187 |
--------------------------------------------------------------------------------
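The `decode_large_batch` helper above leans on `utils.amortize`, which is not shown in this dump; the sketch below assumes its contract is simply "yield chunk sizes that sum to the total", and uses a hypothetical `decode` in place of the VAE decoder:

```python
import torch

def amortize(n_samples, batch_size):
    # Assumed contract of utils.amortize: chunk sizes summing to n_samples,
    # each at most batch_size.
    k, r = divmod(n_samples, batch_size)
    return k * [batch_size] + ([r] if r > 0 else [])

def decode(z):
    # Hypothetical stand-in for the (large) VAE decoder: 2x upsampling.
    return z.repeat_interleave(2, dim=-1).repeat_interleave(2, dim=-2)

z = torch.randn(130, 4, 32, 32)
xs, pt = [], 0
for sz in amortize(z.size(0), 50):  # chunks of 50, 50, 30
    xs.append(decode(z[pt: pt + sz]))
    pt += sz
x = torch.cat(xs, dim=0)
assert x.size(0) == z.size(0)
```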
/eval_t2i_discrete.py:
--------------------------------------------------------------------------------
1 | from tools.fid_score import calculate_fid_given_paths
2 | import ml_collections
3 | import torch
4 | from torch import multiprocessing as mp
5 | import accelerate
6 | from torch.utils.data import DataLoader
7 | import utils
8 | from uvit_datasets import get_dataset
9 | import tempfile
10 | from dpm_solver_pp import NoiseScheduleVP, DPM_Solver
11 | from absl import logging
12 | import builtins
13 | import einops
14 | import libs.autoencoder
15 |
16 |
17 | def stable_diffusion_beta_schedule(linear_start=0.00085, linear_end=0.0120, n_timestep=1000):
18 | _betas = (
19 | torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2
20 | )
21 | return _betas.numpy()
22 |
23 |
24 | def evaluate(config):
25 | if config.get('benchmark', False):
26 | torch.backends.cudnn.benchmark = True
27 | torch.backends.cudnn.deterministic = False
28 |
29 | mp.set_start_method('spawn')
30 | accelerator = accelerate.Accelerator()
31 | device = accelerator.device
32 | accelerate.utils.set_seed(config.seed, device_specific=True)
33 | logging.info(f'Process {accelerator.process_index} using device: {device}')
34 |
35 | config.mixed_precision = accelerator.mixed_precision
36 | config = ml_collections.FrozenConfigDict(config)
37 | if accelerator.is_main_process:
38 | utils.set_logger(log_level='info', fname=config.output_path)
39 | else:
40 | utils.set_logger(log_level='error')
41 | builtins.print = lambda *args: None
42 |
43 | dataset = get_dataset(**config.dataset)
44 | test_dataset = dataset.get_split(split='test', labeled=True) # for sampling
45 | test_dataset_loader = DataLoader(test_dataset, batch_size=config.sample.mini_batch_size, shuffle=True,
46 | drop_last=True, num_workers=8, pin_memory=True, persistent_workers=True)
47 |
48 | nnet = utils.get_nnet(**config.nnet)
49 | nnet, test_dataset_loader = accelerator.prepare(nnet, test_dataset_loader)
50 | logging.info(f'load nnet from {config.nnet_path}')
51 | accelerator.unwrap_model(nnet).load_state_dict(torch.load(config.nnet_path, map_location='cpu'))
52 | nnet.eval()
53 |
54 | def cfg_nnet(x, timesteps, context):
55 | _cond = nnet(x, timesteps, context=context)
56 | if config.sample.scale == 0:
57 | return _cond
58 | _empty_context = torch.tensor(dataset.empty_context, device=device)
59 | _empty_context = einops.repeat(_empty_context, 'L D -> B L D', B=x.size(0))
60 | _uncond = nnet(x, timesteps, context=_empty_context)
61 | return _cond + config.sample.scale * (_cond - _uncond)
62 |
63 | autoencoder = libs.autoencoder.get_model(**config.autoencoder)
64 | autoencoder.to(device)
65 |
66 | @torch.cuda.amp.autocast()
67 | def encode(_batch):
68 | return autoencoder.encode(_batch)
69 |
70 | @torch.cuda.amp.autocast()
71 | def decode(_batch):
72 | return autoencoder.decode(_batch)
73 |
74 | def decode_large_batch(_batch):
75 | decode_mini_batch_size = 50 # use a small batch size since the decoder is large
76 | xs = []
77 | pt = 0
78 | for _decode_mini_batch_size in utils.amortize(_batch.size(0), decode_mini_batch_size):
79 | x = decode(_batch[pt: pt + _decode_mini_batch_size])
80 | pt += _decode_mini_batch_size
81 | xs.append(x)
82 | xs = torch.concat(xs, dim=0)
83 | assert xs.size(0) == _batch.size(0)
84 | return xs
85 |
86 | def get_context_generator():
87 | while True:
88 | for data in test_dataset_loader:
89 | _, _context = data
90 | yield _context
91 |
92 | context_generator = get_context_generator()
93 |
94 | _betas = stable_diffusion_beta_schedule()
95 | N = len(_betas)
96 |
97 | logging.info(config.sample)
98 | assert os.path.exists(dataset.fid_stat)
99 | logging.info(f'sample: n_samples={config.sample.n_samples}, mode=t2i, mixed_precision={config.mixed_precision}')
100 |
101 | def dpm_solver_sample(_n_samples, _sample_steps, **kwargs):
102 | _z_init = torch.randn(_n_samples, *config.z_shape, device=device)
103 | noise_schedule = NoiseScheduleVP(schedule='discrete', betas=torch.tensor(_betas, device=device).float())
104 |
105 | def model_fn(x, t_continuous):
106 | t = t_continuous * N
107 | return cfg_nnet(x, t, **kwargs)
108 |
109 | dpm_solver = DPM_Solver(model_fn, noise_schedule, predict_x0=True, thresholding=False)
110 | _z = dpm_solver.sample(_z_init, steps=_sample_steps, eps=1. / N, T=1.)
111 | return decode_large_batch(_z)
112 |
113 | def sample_fn(_n_samples):
114 | _context = next(context_generator)
115 | assert _context.size(0) == _n_samples
116 | return dpm_solver_sample(_n_samples, config.sample.sample_steps, context=_context)
117 |
118 | with tempfile.TemporaryDirectory() as temp_path:
119 | path = config.sample.path or temp_path
120 | if accelerator.is_main_process:
121 | os.makedirs(path, exist_ok=True)
122 | logging.info(f'Samples are saved in {path}')
123 | utils.sample2dir(accelerator, path, config.sample.n_samples, config.sample.mini_batch_size, sample_fn, dataset.unpreprocess)
124 | if accelerator.is_main_process:
125 | fid = calculate_fid_given_paths((dataset.fid_stat, path))
126 | logging.info(f'nnet_path={config.nnet_path}, fid={fid}')
127 |
128 |
129 | from absl import flags
130 | from absl import app
131 | from ml_collections import config_flags
132 | import os
133 |
134 |
135 | FLAGS = flags.FLAGS
136 | config_flags.DEFINE_config_file(
137 | "config", None, "Training configuration.", lock_config=False)
138 | flags.mark_flags_as_required(["config"])
139 | flags.DEFINE_string("nnet_path", None, "The nnet to evaluate.")
140 | flags.DEFINE_string("output_path", None, "The path to output log.")
141 |
142 |
143 | def main(argv):
144 | config = FLAGS.config
145 | config.nnet_path = FLAGS.nnet_path
146 | config.output_path = FLAGS.output_path
147 | evaluate(config)
148 |
149 |
150 | if __name__ == "__main__":
151 | app.run(main)
152 |
--------------------------------------------------------------------------------
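For reference, a small sketch of what `stable_diffusion_beta_schedule` produces and how `model_fn` above maps DPM-Solver's continuous time back to the discrete timesteps the network was trained on (`t = t_continuous * N`); the numbers follow from the defaults in the file:

```python
import torch

linear_start, linear_end, N = 0.00085, 0.0120, 1000
betas = torch.linspace(linear_start ** 0.5, linear_end ** 0.5, N,
                       dtype=torch.float64) ** 2
print(betas[0].item(), betas[-1].item())  # ~0.00085, ~0.012

# alphas_cumprod decays from ~1 toward 0 as noise accumulates.
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)
print(alphas_cumprod[0].item(), alphas_cumprod[-1].item())

# Continuous time in (1/N, 1] maps to the discrete timestep scale.
t_continuous = torch.tensor([1.0 / N, 0.5, 1.0])
print(t_continuous * N)  # tensor([  1., 500., 1000.])
```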
/evals/lm_harness_eval.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | import transformers
4 | from transformers import AutoTokenizer
5 |
6 | from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
7 |
8 | from lm_eval.api.model import LM
9 | from lm_eval.models.huggingface import HFLM
10 | from lm_eval.api.registry import register_model
11 | from lm_eval.__main__ import cli_evaluate
12 |
13 |
14 | @register_model("mamba")
15 | class MambaEvalWrapper(HFLM):
16 |
17 | AUTO_MODEL_CLASS = transformers.AutoModelForCausalLM
18 |
19 | def __init__(self, pretrained="state-spaces/mamba-2.8b", max_length=2048, batch_size=None, device="cuda",
20 | dtype=torch.float16):
21 | LM.__init__(self)
22 | self._model = MambaLMHeadModel.from_pretrained(pretrained, device=device, dtype=dtype)
23 | self.tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
24 | self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
25 | self.vocab_size = self.tokenizer.vocab_size
26 | self._batch_size = int(batch_size) if batch_size is not None else 64
27 | self._max_length = max_length
28 | self._device = torch.device(device)
29 |
30 | @property
31 | def batch_size(self):
32 | return self._batch_size
33 |
34 | def _model_generate(self, context, max_length, stop, **generation_kwargs):
35 | raise NotImplementedError()
36 |
37 |
38 | if __name__ == "__main__":
39 | cli_evaluate()
40 |
--------------------------------------------------------------------------------
/libs/__init__.py:
--------------------------------------------------------------------------------
1 | # codes from third party
2 |
--------------------------------------------------------------------------------
/libs/clip.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | from transformers import CLIPTokenizer, CLIPTextModel
3 |
4 |
5 | class AbstractEncoder(nn.Module):
6 | def __init__(self):
7 | super().__init__()
8 |
9 | def encode(self, *args, **kwargs):
10 | raise NotImplementedError
11 |
12 |
13 | class FrozenCLIPEmbedder(AbstractEncoder):
14 | """Uses the CLIP transformer encoder for text (from Hugging Face)"""
15 | def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77):
16 | super().__init__()
17 | self.tokenizer = CLIPTokenizer.from_pretrained(version)
18 | self.transformer = CLIPTextModel.from_pretrained(version)
19 | self.device = device
20 | self.max_length = max_length
21 | self.freeze()
22 |
23 | def freeze(self):
24 | self.transformer = self.transformer.eval()
25 | for param in self.parameters():
26 | param.requires_grad = False
27 |
28 | def forward(self, text):
29 | batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
30 | return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
31 | tokens = batch_encoding["input_ids"].to(self.device)
32 | outputs = self.transformer(input_ids=tokens)
33 |
34 | z = outputs.last_hidden_state
35 | return z
36 |
37 | def encode(self, text):
38 | return self(text)
39 |
--------------------------------------------------------------------------------
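Usage of `FrozenCLIPEmbedder` is straightforward; a quick sketch, assuming the `openai/clip-vit-large-patch14` weights can be fetched from the Hugging Face hub (pass `device="cpu"` on a CPU-only machine):

```python
from libs.clip import FrozenCLIPEmbedder

embedder = FrozenCLIPEmbedder(device="cpu")
z = embedder.encode(["a corgi wearing sunglasses"])
# One sequence of max_length=77 token embeddings at CLIP-L width 768.
print(z.shape)  # torch.Size([1, 77, 768])
```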
/libs/timm.py:
--------------------------------------------------------------------------------
1 | # code from timm 0.3.2
2 | import torch
3 | import torch.nn as nn
4 | import math
5 | import warnings
6 |
7 |
8 | def _no_grad_trunc_normal_(tensor, mean, std, a, b):
9 | # Cut & paste from PyTorch official master until it's in a few official releases - RW
10 | # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
11 | def norm_cdf(x):
12 | # Computes standard normal cumulative distribution function
13 | return (1. + math.erf(x / math.sqrt(2.))) / 2.
14 |
15 | if (mean < a - 2 * std) or (mean > b + 2 * std):
16 | warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
17 | "The distribution of values may be incorrect.",
18 | stacklevel=2)
19 |
20 | with torch.no_grad():
21 | # Values are generated by using a truncated uniform distribution and
22 | # then using the inverse CDF for the normal distribution.
23 | # Get upper and lower cdf values
24 | l = norm_cdf((a - mean) / std)
25 | u = norm_cdf((b - mean) / std)
26 |
27 | # Uniformly fill tensor with values from [l, u], then translate to
28 | # [2l-1, 2u-1].
29 | tensor.uniform_(2 * l - 1, 2 * u - 1)
30 |
31 | # Use inverse cdf transform for normal distribution to get truncated
32 | # standard normal
33 | tensor.erfinv_()
34 |
35 | # Transform to proper mean, std
36 | tensor.mul_(std * math.sqrt(2.))
37 | tensor.add_(mean)
38 |
39 | # Clamp to ensure it's in the proper range
40 | tensor.clamp_(min=a, max=b)
41 | return tensor
42 |
43 |
44 | def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
45 | # type: (Tensor, float, float, float, float) -> Tensor
46 | r"""Fills the input Tensor with values drawn from a truncated
47 | normal distribution. The values are effectively drawn from the
48 | normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
49 | with values outside :math:`[a, b]` redrawn until they are within
50 | the bounds. The method used for generating the random values works
51 | best when :math:`a \leq \text{mean} \leq b`.
52 | Args:
53 | tensor: an n-dimensional `torch.Tensor`
54 | mean: the mean of the normal distribution
55 | std: the standard deviation of the normal distribution
56 | a: the minimum cutoff value
57 | b: the maximum cutoff value
58 | Examples:
59 | >>> w = torch.empty(3, 5)
60 | >>> nn.init.trunc_normal_(w)
61 | """
62 | return _no_grad_trunc_normal_(tensor, mean, std, a, b)
63 |
64 |
65 | def drop_path(x, drop_prob: float = 0., training: bool = False):
66 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
67 |
68 | This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
69 | the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
70 | See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
71 | changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
72 | 'survival rate' as the argument.
73 |
74 | """
75 | if drop_prob == 0. or not training:
76 | return x
77 | keep_prob = 1 - drop_prob
78 | shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
79 | random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
80 | random_tensor.floor_() # binarize
81 | output = x.div(keep_prob) * random_tensor
82 | return output
83 |
84 |
85 | class DropPath(nn.Module):
86 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
87 | """
88 | def __init__(self, drop_prob=None):
89 | super(DropPath, self).__init__()
90 | self.drop_prob = drop_prob
91 |
92 | def forward(self, x):
93 | return drop_path(x, self.drop_prob, self.training)
94 |
95 |
96 | class Mlp(nn.Module):
97 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
98 | super().__init__()
99 | out_features = out_features or in_features
100 | hidden_features = hidden_features or in_features
101 | self.fc1 = nn.Linear(in_features, hidden_features)
102 | self.act = act_layer()
103 | self.fc2 = nn.Linear(hidden_features, out_features)
104 | self.drop = nn.Dropout(drop)
105 |
106 | def forward(self, x):
107 | x = self.fc1(x)
108 | x = self.act(x)
109 | x = self.drop(x)
110 | x = self.fc2(x)
111 | x = self.drop(x)
112 | return x
113 |
--------------------------------------------------------------------------------
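A short check of the `DropPath` behavior defined above: during training each sample's residual branch is either zeroed or rescaled by `1/keep_prob`, which keeps the expected value unchanged; at eval time it is the identity:

```python
import torch
from libs.timm import DropPath

torch.manual_seed(0)
dp = DropPath(drop_prob=0.5)

dp.train()
x = torch.ones(8, 4)
out = dp(x)
print(out[:, 0])  # each row is either 0.0 (dropped) or 2.0 (= 1/keep_prob)

dp.eval()
assert torch.equal(dp(x), x)  # identity at inference time
```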
/libs/uvit.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import math
4 | from .timm import trunc_normal_, Mlp
5 | import einops
6 | import torch.utils.checkpoint
7 |
8 | from mamba_attn_diff.utils.init_weights import _init_weights_mamba, pos_embed_inteplot
9 |
10 | if hasattr(torch.nn.functional, 'scaled_dot_product_attention'):
11 | ATTENTION_MODE = 'flash'
12 | else:
13 | try:
14 | import xformers
15 | import xformers.ops
16 | ATTENTION_MODE = 'xformers'
17 | except ImportError:
18 | ATTENTION_MODE = 'math'
19 | print(f'attention mode is {ATTENTION_MODE}')
20 | # ATTENTION_MODE = 'math'
21 |
22 | def timestep_embedding(timesteps, dim, max_period=10000):
23 | """
24 | Create sinusoidal timestep embeddings.
25 |
26 | :param timesteps: a 1-D Tensor of N indices, one per batch element.
27 | These may be fractional.
28 | :param dim: the dimension of the output.
29 | :param max_period: controls the minimum frequency of the embeddings.
30 | :return: an [N x dim] Tensor of positional embeddings.
31 | """
32 | half = dim // 2
33 | freqs = torch.exp(
34 | -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
35 | ).to(device=timesteps.device)
36 | args = timesteps[:, None].float() * freqs[None]
37 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
38 | if dim % 2:
39 | embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
40 | return embedding
41 |
42 |
43 | def patchify(imgs, patch_size):
44 | x = einops.rearrange(imgs, 'B C (h p1) (w p2) -> B (h w) (p1 p2 C)', p1=patch_size, p2=patch_size)
45 | return x
46 |
47 |
48 | def unpatchify(x, channels=3):
49 | patch_size = int((x.shape[2] // channels) ** 0.5)
50 | h = w = int(x.shape[1] ** .5)
51 | assert h * w == x.shape[1] and patch_size ** 2 * channels == x.shape[2]
52 | x = einops.rearrange(x, 'B (h w) (p1 p2 C) -> B C (h p1) (w p2)', h=h, p1=patch_size, p2=patch_size)
53 | return x
54 |
55 |
56 | class Attention(nn.Module):
57 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
58 | super().__init__()
59 | self.num_heads = num_heads
60 | head_dim = dim // num_heads
61 | self.scale = qk_scale or head_dim ** -0.5
62 |
63 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
64 | self.attn_drop = nn.Dropout(attn_drop)
65 | self.proj = nn.Linear(dim, dim)
66 | self.proj_drop = nn.Dropout(proj_drop)
67 |
68 | def forward(self, x):
69 | B, L, C = x.shape
70 |
71 | qkv = self.qkv(x)
72 | if ATTENTION_MODE == 'flash':
73 | qkv = einops.rearrange(qkv, 'B L (K H D) -> K B H L D', K=3, H=self.num_heads).float()
74 | q, k, v = qkv[0], qkv[1], qkv[2] # B H L D
75 | x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
76 | x = einops.rearrange(x, 'B H L D -> B L (H D)')
77 | elif ATTENTION_MODE == 'xformers':
78 | qkv = einops.rearrange(qkv, 'B L (K H D) -> K B L H D', K=3, H=self.num_heads)
79 | q, k, v = qkv[0], qkv[1], qkv[2] # B L H D
80 | x = xformers.ops.memory_efficient_attention(q, k, v)
81 | x = einops.rearrange(x, 'B L H D -> B L (H D)', H=self.num_heads)
82 | elif ATTENTION_MODE == 'math':
83 | qkv = einops.rearrange(qkv, 'B L (K H D) -> K B H L D', K=3, H=self.num_heads)
84 | q, k, v = qkv[0], qkv[1], qkv[2] # B H L D
85 | attn = (q @ k.transpose(-2, -1)) * self.scale
86 | attn = attn.softmax(dim=-1)
87 | attn = self.attn_drop(attn)
88 | x = (attn @ v).transpose(1, 2).reshape(B, L, C)
89 | else:
90 | raise NotImplementedError
91 |
92 | x = self.proj(x)
93 | x = self.proj_drop(x)
94 | return x
95 |
96 |
97 | class Block(nn.Module):
98 |
99 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None,
100 | act_layer=nn.GELU, norm_layer=nn.LayerNorm, skip=False, use_checkpoint=False):
101 | super().__init__()
102 | self.norm1 = norm_layer(dim)
103 | self.attn = Attention(
104 | dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale)
105 | self.norm2 = norm_layer(dim)
106 | mlp_hidden_dim = int(dim * mlp_ratio)
107 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer)
108 | self.skip_linear = nn.Linear(2 * dim, dim) if skip else None
109 | self.use_checkpoint = use_checkpoint
110 |
111 | def forward(self, x, skip=None):
112 | if self.use_checkpoint:
113 | return torch.utils.checkpoint.checkpoint(self._forward, x, skip)
114 | else:
115 | return self._forward(x, skip)
116 |
117 | def _forward(self, x, skip=None):
118 | if self.skip_linear is not None:
119 | x = self.skip_linear(torch.cat([x, skip], dim=-1))
120 | x = x + self.attn(self.norm1(x))
121 | x = x + self.mlp(self.norm2(x))
122 | return x
123 |
124 |
125 | class PatchEmbed(nn.Module):
126 | """ Image to Patch Embedding
127 | """
128 | def __init__(self, patch_size, in_chans=3, embed_dim=768):
129 | super().__init__()
130 | self.patch_size = patch_size
131 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
132 |
133 | def forward(self, x):
134 | B, C, H, W = x.shape
135 | assert H % self.patch_size == 0 and W % self.patch_size == 0
136 | x = self.proj(x).flatten(2).transpose(1, 2)
137 | return x
138 |
139 |
140 | class UViT(nn.Module):
141 | def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4.,
142 | qkv_bias=False, qk_scale=None, norm_layer=nn.LayerNorm, mlp_time_embed=False, num_classes=-1,
143 | use_checkpoint=False, conv=True, skip=True, is_last_double_channel=False,
144 | hook_intermediate_out=False, **kwargs):
145 | super().__init__()
146 | self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
147 | self.num_classes = num_classes
148 | self.in_chans = in_chans
149 | self.hook_intermediate_out = hook_intermediate_out
150 | out_chans = in_chans * 3 if is_last_double_channel else in_chans
151 |
152 | self.patch_embed = PatchEmbed(patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
153 | num_patches = (img_size // patch_size) ** 2
154 | self.num_patches = num_patches
155 |
156 | self.time_embed = nn.Sequential(
157 | nn.Linear(embed_dim, 4 * embed_dim),
158 | nn.SiLU(),
159 | nn.Linear(4 * embed_dim, embed_dim),
160 | ) if mlp_time_embed else nn.Identity()
161 |
162 | if self.num_classes > 0:
163 | self.label_emb = nn.Embedding(self.num_classes, embed_dim)
164 | self.extras = 2
165 | else:
166 | self.extras = 1
167 |
168 | self.pos_embed = nn.Parameter(torch.zeros(1, self.extras + num_patches, embed_dim))
169 |
170 | self.in_blocks = nn.ModuleList([
171 | Block(
172 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
173 | norm_layer=norm_layer, use_checkpoint=use_checkpoint)
174 | for _ in range(depth // 2)])
175 |
176 | self.mid_block = Block(
177 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
178 | norm_layer=norm_layer, use_checkpoint=use_checkpoint)
179 |
180 | self.out_blocks = nn.ModuleList([
181 | Block(
182 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
183 | norm_layer=norm_layer, skip=skip, use_checkpoint=use_checkpoint)
184 | for _ in range(depth // 2)])
185 |
186 | self.norm = norm_layer(embed_dim)
187 | self.patch_dim = patch_size ** 2 * in_chans
188 | self.decoder_pred = nn.Linear(embed_dim, self.patch_dim, bias=True)
189 | self.final_layer = nn.Conv2d(self.in_chans, out_chans, 3, padding=1) if conv else nn.Identity()
190 |
191 | trunc_normal_(self.pos_embed, std=.02)
192 | self.apply(self._init_weights)
193 |
194 | def _init_weights(self, m):
195 | if isinstance(m, nn.Linear):
196 | trunc_normal_(m.weight, std=.02)
197 | if isinstance(m, nn.Linear) and m.bias is not None:
198 | nn.init.constant_(m.bias, 0)
199 | elif isinstance(m, nn.LayerNorm):
200 | nn.init.constant_(m.bias, 0)
201 | nn.init.constant_(m.weight, 1.0)
202 |
203 | @torch.jit.ignore
204 | def no_weight_decay(self):
205 | return {'pos_embed'}
206 |
207 | def forward(self, x, timesteps, y=None, use_pos_inteplot=False):
208 | x = self.patch_embed(x)
209 | B, L, D = x.shape
210 |
211 | height, width = int(L**0.5), int(L**0.5)
212 |
213 | if self.hook_intermediate_out and self.training:
214 | intermediate_out = dict()
215 | intermediate_out['intermediate_outputs'] = []
216 | layer_count = 0
217 |
218 | time_token = self.time_embed(timestep_embedding(timesteps, self.embed_dim))
219 | time_token = time_token.unsqueeze(dim=1)
220 | x = torch.cat((time_token, x), dim=1)
221 |
222 | if y is not None:
223 | label_emb = self.label_emb(y)
224 | label_emb = label_emb.unsqueeze(dim=1)
225 | x = torch.cat((label_emb, x), dim=1)
226 |
227 | pos_embed = self.pos_embed
228 | if use_pos_inteplot:
229 | pos_embed = pos_embed_inteplot(
230 | cur_pos_embed=None,
231 | pretrained_pos_embed=pos_embed,
232 | extra_len=self.extras,
233 | cur_size=(height, width),
234 | )
235 |
236 | x = x + pos_embed
237 |
238 | if self.hook_intermediate_out and self.training:
239 | intermediate_out['intermediate_outputs'].append(self.predict_head(x, L))
240 | layer_count += 1
241 |
242 | skips = []
243 | for blk in self.in_blocks:
244 | x = blk(x)
245 | skips.append(x)
246 |
247 | if self.hook_intermediate_out and self.training:
248 | intermediate_out['intermediate_outputs'].append(self.predict_head(x, L))
249 | layer_count += 1
250 |
251 | x = self.mid_block(x)
252 |
253 | if self.hook_intermediate_out and self.training:
254 | intermediate_out['intermediate_outputs'].append(self.predict_head(x, L))
255 | layer_count += 1
256 |
257 | for idx, blk in enumerate(self.out_blocks):
258 | x = blk(x, skips.pop())
259 |
260 | if self.hook_intermediate_out and self.training and idx < len(self.out_blocks) - 1:
261 | intermediate_out['intermediate_outputs'].append(self.predict_head(x, L))
262 | layer_count += 1
263 |
264 | x = self.predict_head(x, L)
265 |
266 | if self.hook_intermediate_out and self.training:
267 | intermediate_out['num_pred_layer'] = layer_count
268 | return x, intermediate_out
269 | else:
270 | return x
271 |
272 | def predict_head(self, x, L):
273 | x = self.norm(x)
274 | x = self.decoder_pred(x)
275 | x = x[:, -L:, :]
276 | x = unpatchify(x, self.in_chans)
277 | x = self.final_layer(x)
278 | return x
279 |
--------------------------------------------------------------------------------
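A round-trip sanity check for the `patchify`/`unpatchify` pair and the sinusoidal `timestep_embedding` defined in `libs/uvit.py`:

```python
import torch
from libs.uvit import patchify, unpatchify, timestep_embedding

imgs = torch.randn(2, 3, 32, 32)
tokens = patchify(imgs, patch_size=4)
print(tokens.shape)  # torch.Size([2, 64, 48]): (32/4)^2 tokens of 4*4*3 dims
assert torch.equal(unpatchify(tokens, channels=3), imgs)  # lossless round trip

emb = timestep_embedding(torch.tensor([0, 10, 999]), dim=768)
print(emb.shape)  # torch.Size([3, 768])
```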
/libs/uvit_t2i.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import math
4 | from .timm import trunc_normal_, Mlp
5 | import einops
6 | import torch.utils.checkpoint
7 |
8 | if hasattr(torch.nn.functional, 'scaled_dot_product_attention'):
9 | ATTENTION_MODE = 'flash'
10 | else:
11 | try:
12 | import xformers
13 | import xformers.ops
14 | ATTENTION_MODE = 'xformers'
15 | except ImportError:
16 | ATTENTION_MODE = 'math'
17 | print(f'attention mode is {ATTENTION_MODE}')
18 |
19 |
20 | def timestep_embedding(timesteps, dim, max_period=10000):
21 | """
22 | Create sinusoidal timestep embeddings.
23 |
24 | :param timesteps: a 1-D Tensor of N indices, one per batch element.
25 | These may be fractional.
26 | :param dim: the dimension of the output.
27 | :param max_period: controls the minimum frequency of the embeddings.
28 | :return: an [N x dim] Tensor of positional embeddings.
29 | """
30 | half = dim // 2
31 | freqs = torch.exp(
32 | -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
33 | ).to(device=timesteps.device)
34 | args = timesteps[:, None].float() * freqs[None]
35 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
36 | if dim % 2:
37 | embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
38 | return embedding
39 |
40 |
41 | def patchify(imgs, patch_size):
42 | x = einops.rearrange(imgs, 'B C (h p1) (w p2) -> B (h w) (p1 p2 C)', p1=patch_size, p2=patch_size)
43 | return x
44 |
45 |
46 | def unpatchify(x, channels=3):
47 | patch_size = int((x.shape[2] // channels) ** 0.5)
48 | h = w = int(x.shape[1] ** .5)
49 | assert h * w == x.shape[1] and patch_size ** 2 * channels == x.shape[2]
50 | x = einops.rearrange(x, 'B (h w) (p1 p2 C) -> B C (h p1) (w p2)', h=h, p1=patch_size, p2=patch_size)
51 | return x
52 |
53 |
54 | class Attention(nn.Module):
55 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
56 | super().__init__()
57 | self.num_heads = num_heads
58 | head_dim = dim // num_heads
59 | self.scale = qk_scale or head_dim ** -0.5
60 |
61 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
62 | self.attn_drop = nn.Dropout(attn_drop)
63 | self.proj = nn.Linear(dim, dim)
64 | self.proj_drop = nn.Dropout(proj_drop)
65 |
66 | def forward(self, x):
67 | B, L, C = x.shape
68 |
69 | qkv = self.qkv(x)
70 | if ATTENTION_MODE == 'flash':
71 | qkv = einops.rearrange(qkv, 'B L (K H D) -> K B H L D', K=3, H=self.num_heads).float()
72 | q, k, v = qkv[0], qkv[1], qkv[2] # B H L D
73 | x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
74 | x = einops.rearrange(x, 'B H L D -> B L (H D)')
75 | elif ATTENTION_MODE == 'xformers':
76 | qkv = einops.rearrange(qkv, 'B L (K H D) -> K B L H D', K=3, H=self.num_heads)
77 | q, k, v = qkv[0], qkv[1], qkv[2] # B L H D
78 | x = xformers.ops.memory_efficient_attention(q, k, v)
79 | x = einops.rearrange(x, 'B L H D -> B L (H D)', H=self.num_heads)
80 | elif ATTENTION_MODE == 'math':
81 | qkv = einops.rearrange(qkv, 'B L (K H D) -> K B H L D', K=3, H=self.num_heads)
82 | q, k, v = qkv[0], qkv[1], qkv[2] # B H L D
83 | attn = (q @ k.transpose(-2, -1)) * self.scale
84 | attn = attn.softmax(dim=-1)
85 | attn = self.attn_drop(attn)
86 | x = (attn @ v).transpose(1, 2).reshape(B, L, C)
87 | else:
88 | raise NotImplementedError
89 |
90 | x = self.proj(x)
91 | x = self.proj_drop(x)
92 | return x
93 |
94 |
95 | class Block(nn.Module):
96 |
97 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None,
98 | act_layer=nn.GELU, norm_layer=nn.LayerNorm, skip=False, use_checkpoint=False):
99 | super().__init__()
100 | self.norm1 = norm_layer(dim)
101 | self.attn = Attention(
102 | dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale)
103 | self.norm2 = norm_layer(dim)
104 | mlp_hidden_dim = int(dim * mlp_ratio)
105 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer)
106 | self.skip_linear = nn.Linear(2 * dim, dim) if skip else None
107 | self.use_checkpoint = use_checkpoint
108 |
109 | def forward(self, x, skip=None):
110 | if self.use_checkpoint:
111 | return torch.utils.checkpoint.checkpoint(self._forward, x, skip)
112 | else:
113 | return self._forward(x, skip)
114 |
115 | def _forward(self, x, skip=None):
116 | if self.skip_linear is not None:
117 | x = self.skip_linear(torch.cat([x, skip], dim=-1))
118 | x = x + self.attn(self.norm1(x))
119 | x = x + self.mlp(self.norm2(x))
120 | return x
121 |
122 |
123 | class PatchEmbed(nn.Module):
124 | """ Image to Patch Embedding
125 | """
126 | def __init__(self, patch_size, in_chans=3, embed_dim=768):
127 | super().__init__()
128 | self.patch_size = patch_size
129 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
130 |
131 | def forward(self, x):
132 | B, C, H, W = x.shape
133 | assert H % self.patch_size == 0 and W % self.patch_size == 0
134 | x = self.proj(x).flatten(2).transpose(1, 2)
135 | return x
136 |
137 |
138 | class UViT(nn.Module):
139 | def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4.,
140 | qkv_bias=False, qk_scale=None, norm_layer=nn.LayerNorm, mlp_time_embed=False, use_checkpoint=False,
141 | clip_dim=768, num_clip_token=77, conv=True, skip=True):
142 | super().__init__()
143 | self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
144 | self.in_chans = in_chans
145 |
146 | self.patch_embed = PatchEmbed(patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
147 | num_patches = (img_size // patch_size) ** 2
148 |
149 | self.time_embed = nn.Sequential(
150 | nn.Linear(embed_dim, 4 * embed_dim),
151 | nn.SiLU(),
152 | nn.Linear(4 * embed_dim, embed_dim),
153 | ) if mlp_time_embed else nn.Identity()
154 |
155 | self.context_embed = nn.Linear(clip_dim, embed_dim)
156 |
157 | self.extras = 1 + num_clip_token
158 |
159 | self.pos_embed = nn.Parameter(torch.zeros(1, self.extras + num_patches, embed_dim))
160 |
161 | self.in_blocks = nn.ModuleList([
162 | Block(
163 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
164 | norm_layer=norm_layer, use_checkpoint=use_checkpoint)
165 | for _ in range(depth // 2)])
166 |
167 | self.mid_block = Block(
168 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
169 | norm_layer=norm_layer, use_checkpoint=use_checkpoint)
170 |
171 | self.out_blocks = nn.ModuleList([
172 | Block(
173 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
174 | norm_layer=norm_layer, skip=skip, use_checkpoint=use_checkpoint)
175 | for _ in range(depth // 2)])
176 |
177 | self.norm = norm_layer(embed_dim)
178 | self.patch_dim = patch_size ** 2 * in_chans
179 | self.decoder_pred = nn.Linear(embed_dim, self.patch_dim, bias=True)
180 | self.final_layer = nn.Conv2d(self.in_chans, self.in_chans, 3, padding=1) if conv else nn.Identity()
181 |
182 | trunc_normal_(self.pos_embed, std=.02)
183 | self.apply(self._init_weights)
184 |
185 | def _init_weights(self, m):
186 | if isinstance(m, nn.Linear):
187 | trunc_normal_(m.weight, std=.02)
188 | if isinstance(m, nn.Linear) and m.bias is not None:
189 | nn.init.constant_(m.bias, 0)
190 | elif isinstance(m, nn.LayerNorm):
191 | nn.init.constant_(m.bias, 0)
192 | nn.init.constant_(m.weight, 1.0)
193 |
194 | @torch.jit.ignore
195 | def no_weight_decay(self):
196 | return {'pos_embed'}
197 |
198 | def forward(self, x, timesteps, context):
199 | x = self.patch_embed(x)
200 | B, L, D = x.shape
201 |
202 | time_token = self.time_embed(timestep_embedding(timesteps, self.embed_dim))
203 | time_token = time_token.unsqueeze(dim=1)
204 | context_token = self.context_embed(context)
205 | x = torch.cat((time_token, context_token, x), dim=1)
206 | x = x + self.pos_embed
207 |
208 | skips = []
209 | for blk in self.in_blocks:
210 | x = blk(x)
211 | skips.append(x)
212 |
213 | x = self.mid_block(x)
214 |
215 | for blk in self.out_blocks:
216 | x = blk(x, skips.pop())
217 |
218 | x = self.norm(x)
219 | x = self.decoder_pred(x)
220 | assert x.size(1) == self.extras + L
221 | x = x[:, self.extras:, :]
222 | x = unpatchify(x, self.in_chans)
223 | x = self.final_layer(x)
224 | return x
225 |
--------------------------------------------------------------------------------
/main.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tyshiwo1/DiM-DiffusionMamba/292b2285de2f979cf5ac84ab90cbee20cdf1f2f3/main.pdf
--------------------------------------------------------------------------------
/main.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tyshiwo1/DiM-DiffusionMamba/292b2285de2f979cf5ac84ab90cbee20cdf1f2f3/main.png
--------------------------------------------------------------------------------
/mamba_attn_diff/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tyshiwo1/DiM-DiffusionMamba/292b2285de2f979cf5ac84ab90cbee20cdf1f2f3/mamba_attn_diff/models/__init__.py
--------------------------------------------------------------------------------
/mamba_attn_diff/models/freeu.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.fft import fftn, fftshift, ifftn, ifftshift
3 |
4 |
5 | def fourier_filter(x_in: "torch.Tensor", threshold: int, scale: int):
6 | """Fourier filter as introduced in FreeU (https://arxiv.org/abs/2309.11497).
7 |
8 | This version of the method comes from here:
9 | https://github.com/huggingface/diffusers/pull/5164#issuecomment-1732638706
10 | """
11 | x = x_in
12 | B, L, C = x.shape
13 |
14 | # Non-power of 2 images must be float32
15 | if (L & (L - 1)) != 0:
16 | x = x.to(dtype=torch.float32)
17 |
18 | # FFT
19 | x_freq = fftn(x, dim=1)
20 | x_freq = fftshift(x_freq, dim=1)
21 |
22 | B, L, C = x_freq.shape
23 | mask = torch.ones((B, L, C), device=x.device)
24 |
25 | crow = L // 2
26 | mask[..., crow - threshold : crow + threshold, :] = scale
27 | x_freq = x_freq * mask
28 |
29 | # IFFT
30 | x_freq = ifftshift(x_freq, dim=1)
31 | x_filtered = ifftn(x_freq, dim=1).real
32 |
33 | return x_filtered.to(dtype=x_in.dtype)
34 |
35 |
36 | def apply_freeu(
37 | resolution_idx,
38 | hidden_states, res_hidden_states,
39 | s1=0.6, s2=0.4, b1=1.1, b2=1.2,
40 | encoder_start_blk_id=1,
41 | num_layers=49,
42 | extra_len=2,
43 | ):
44 | """Applies the FreeU mechanism as introduced in https:
45 | //arxiv.org/abs/2309.11497. Adapted from the official code repository: https://github.com/ChenyangSi/FreeU.
46 |
47 | Args:
48 | resolution_idx (`int`): Integer denoting the UNet block where FreeU is being applied.
49 | hidden_states (`torch.Tensor`): Inputs to the underlying block.
50 | res_hidden_states (`torch.Tensor`): Features from the skip block corresponding to the underlying block.
51 | s1 (`float`): Scaling factor for stage 1 to attenuate the contributions of the skip features.
52 | s2 (`float`): Scaling factor for stage 2 to attenuate the contributions of the skip features.
53 | b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features.
54 | b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features.
55 |
56 | pipe.enable_freeu(s1=0.6, s2=0.4, b1=1.1, b2=1.2)
57 |
58 | for i, up_block_type in enumerate(up_block_types):
59 | resolution_idx=i,
60 | """
61 | # if resolution_idx == encoder_start_blk_id + (num_layers - encoder_start_blk_id)//2 + 0:
62 | # # print(resolution_idx)
63 | # num_half_channels = hidden_states.shape[-1] // 2
64 | # hidden_states[..., :num_half_channels] = hidden_states[..., :num_half_channels] * b1
65 | # res_hidden_states = fourier_filter(res_hidden_states, threshold=1, scale=s1, )
66 | # elif resolution_idx == encoder_start_blk_id + (num_layers - encoder_start_blk_id)//2 + 1:
67 | # # print(resolution_idx)
68 | # num_half_channels = hidden_states.shape[-1] // 2
69 | # hidden_states[..., :num_half_channels] = hidden_states[..., :num_half_channels] * b2
70 | # res_hidden_states = fourier_filter(res_hidden_states, threshold=1, scale=s2, )
71 |
72 | if resolution_idx == encoder_start_blk_id + (num_layers - encoder_start_blk_id)//2 + 0:
73 | s = s1
74 | b = b1
75 | elif resolution_idx <= encoder_start_blk_id + (num_layers - encoder_start_blk_id)//2 + 1:
76 | s = s2
77 | b = b2
78 |
79 | if resolution_idx >= encoder_start_blk_id + (num_layers - encoder_start_blk_id)//2 + 0 and \
80 | resolution_idx <= encoder_start_blk_id + (num_layers - encoder_start_blk_id)//2 + 1 :
81 |
82 | hidden_mean = hidden_states[:, extra_len:, :].mean(-1).unsqueeze(-1)
83 | B = hidden_mean.shape[0]
84 | hidden_max, _ = torch.max(hidden_mean.view(B, -1), dim=-1)
85 | hidden_min, _ = torch.min(hidden_mean.view(B, -1), dim=-1)
86 | hidden_max = hidden_max.unsqueeze(-1).unsqueeze(-1)
87 | hidden_min = hidden_min.unsqueeze(-1).unsqueeze(-1)
88 | hidden_mean = (hidden_mean - hidden_min) / (hidden_max - hidden_min)
89 |
90 |
91 | hidden_states[:, extra_len:, :] = hidden_states[:, extra_len:, :] * ((b - 1 ) * hidden_mean + 1)
92 | res_hidden_states[:, extra_len:, :] = fourier_filter(res_hidden_states[:, extra_len:, :], threshold=1, scale=s, )
93 |
94 |
95 |
96 | return hidden_states, res_hidden_states
--------------------------------------------------------------------------------
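`fourier_filter` rescales only the low-frequency band of the token sequence; a quick sketch showing that the shape is preserved and that `scale=1` is a no-op (up to FFT round-off):

```python
import torch
from mamba_attn_diff.models.freeu import fourier_filter

x = torch.randn(2, 16, 8)  # length 16 is a power of two: no float32 upcast
y = fourier_filter(x, threshold=1, scale=0.5)
print(y.shape)  # torch.Size([2, 16, 8])

# With scale=1 the frequency mask is all ones, so the input is recovered.
assert torch.allclose(fourier_filter(x, threshold=1, scale=1), x, atol=1e-5)
```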
/mamba_attn_diff/models/upsample_guidance.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | from einops import rearrange
6 |
7 | # Upsample Guidance: Scale Up Diffusion Models without Training: https://arxiv.org/abs/2404.01709
8 |
9 | def model_wrapper(model, noise_schedule=None, is_cond_classifier=False, classifier_fn=None, classifier_scale=1.,
10 | time_input_type='1', total_N=1000, model_kwargs={}):
11 |
12 | def get_model_input_time(t_continuous):
13 | """
14 | Convert the continuous-time `t_continuous` (in [epsilon, T]) to the model input time.
15 | """
16 | if time_input_type == '0':
17 | # discrete_type == '0' means that the model is continuous-time model.
18 | # For continuous-time DPMs, the continuous time equals to the discrete time.
19 | return t_continuous
20 | elif time_input_type == '1':
21 | # Type-1 discrete label, as detailed in the Appendix of DPM-Solver.
22 | return 1000. * torch.max(t_continuous - 1. / total_N, torch.zeros_like(t_continuous).to(t_continuous))
23 | elif time_input_type == '2':
24 | # Type-2 discrete label, as detailed in the Appendix of DPM-Solver.
25 | max_N = (total_N - 1) / total_N * 1000.
26 | return max_N * t_continuous
27 | else:
28 | raise ValueError("Unsupported time input type {}, must be '0' or '1' or '2'".format(time_input_type))
29 |
30 | def cond_fn(x, t_discrete, y):
31 | """
32 | Compute the gradient of the classifier, multiplied with the scale of the classifier guidance.
33 | """
34 | assert y is not None
35 | with torch.enable_grad():
36 | x_in = x.detach().requires_grad_(True)
37 | logits = classifier_fn(x_in, t_discrete)
38 | log_probs = F.log_softmax(logits, dim=-1)
39 | selected = log_probs[range(len(logits)), y.view(-1)]
40 | return classifier_scale * torch.autograd.grad(selected.sum(), x_in)[0]
41 |
42 | def model_fn(x, t_continuous):
43 | """
44 | The noise prediction model function that is used for DPM-Solver.
45 | """
46 | if is_cond_classifier:
47 | y = model_kwargs.get("y", None)
48 | if y is None:
49 | raise ValueError("For classifier guidance, the label y has to be in the input.")
50 | t_discrete = get_model_input_time(t_continuous)
51 | noise_uncond = model(x, t_discrete, **model_kwargs)
52 | noise_uncond = noise_uncond.sample if not isinstance(noise_uncond, torch.Tensor) else noise_uncond
53 | cond_grad = cond_fn(x, t_discrete, y)
54 | sigma_t = noise_schedule.marginal_std(t_continuous)
55 | dims = len(cond_grad.shape) - 1
56 | return noise_uncond - sigma_t[(...,) + (None,) * dims] * cond_grad
57 | else:
58 | t_discrete = get_model_input_time(t_continuous)
59 | model_output = model(x, t_discrete, **model_kwargs)
60 | model_output = model_output.sample if not isinstance(model_output, torch.Tensor) else model_output
61 | return model_output
62 |
63 | return model_fn
64 |
65 | def _init_taus(snrs, timesteps, m=2):
66 | scaled_snrs = snrs * m
67 | approx_snr, taus = torch.min(
68 | (scaled_snrs.reshape(-1, 1) < snrs.reshape(1, -1)) * snrs.reshape(1, -1),
69 | dim=-1,
70 | )
71 | taus[approx_snr == 0] = 0
72 | taus = timesteps[taus]
73 | return taus
74 |
75 | def get_tau(preset_snrs, snr_func, t_continuous, m=2, return_indices=False):
76 | cur_scaled_snr = snr_func(t_continuous) * (m**2)
77 | cur_scaled_snr = cur_scaled_snr.reshape(-1, 1)
78 | preset_snrs = preset_snrs.reshape(1, -1)
79 | tau = (cur_scaled_snr - preset_snrs).abs().min(dim=-1).indices
80 | if return_indices:
81 | return tau
82 |
83 | tau = 1 - tau / (preset_snrs.shape[-1] - 1 )
84 | tau = tau.clamp(max=t_continuous[0])
85 | return tau
86 |
87 | def _init_timestep_idx_map(timesteps):
88 | timestep_idx_map = { t.item(): idx for idx, t in enumerate(timesteps) }
89 | return timestep_idx_map
90 |
91 | def make_ufg_nnet(
92 | cfg_model_func, uncfg_model_func, timesteps, alpha_t_func, beta_t_func, snr_func,
93 | N=999, m=2,
94 | ug_theta=1, ug_eta = 0.3, ug_T = 1,
95 | **kwargs,
96 | ):
97 | # print(timesteps, 'timesteps')
98 | preset_snrs = snr_func(timesteps)
99 |
100 | def ufg_nnet(x, t_continuous):
101 |
102 | tau_continuous = get_tau(preset_snrs, snr_func, t_continuous, m=m)
103 | t = t_continuous * N
104 | tau = tau_continuous * N
105 |
106 | cfg_pred = cfg_model_func(x, timestep=t, use_pos_inteplot=True, **kwargs)
107 | cfg_pred = cfg_pred.sample if not isinstance(cfg_pred, torch.Tensor) else cfg_pred
108 |
109 | w_t = append_dims(get_wt(t_continuous, theta=ug_theta, eta=ug_eta, T=ug_T), len(x.shape))
110 |
111 | if (w_t > 0).any():
112 | print([tau[0], t[0], alpha_t_func( t_continuous[0] ), beta_t_func( t_continuous[0] ), t_continuous], ug_theta, ug_eta, ug_T )
113 |
114 | resized_pred_noise = get_resized_input_predicted_noise(
115 | cfg_model_func, x,
116 | alpha_t_func( t_continuous ),
117 | beta_t_func( t_continuous ),
118 | tau,
119 | m=m,
120 | **kwargs,
121 | )
122 |
123 | g_pred = cfg_pred + w_t * upsample_guidance(
124 | resized_pred_noise,
125 | cfg_pred, #uncon_pred,
126 | m=m,
127 | )
128 | else:
129 | print(t_continuous )
130 | g_pred = cfg_pred
131 |
132 | return g_pred
133 |
134 | return ufg_nnet
135 |
136 | def slide_denoise(model_func, x, t, m=2, y=None, **kwargs):
137 | x = rearrange(x, 'b c (p h) (q w) -> (b p q) c h w ', p=m, q=m)
138 | t = t.reshape(-1, 1).expand(-1, m*m).reshape(-1)
139 | y = y.reshape(-1, 1).expand(-1, m*m).reshape(-1) if y is not None else None
140 | model_output = model_func(x, timestep=t, y=y, **kwargs)
141 | model_output = model_output.sample if not isinstance(model_output, torch.Tensor) else model_output
142 | model_output = rearrange(model_output, '(b p q) c h w -> b c (p h) (q w)', p=m, q=m)
143 | return model_output
144 |
145 | def upsample_guidance(resized_input_predicted_noise, predicted_noise, m=2):
146 | g_pred = resize_func(
147 | resized_input_predicted_noise / m - resize_func(
148 | predicted_noise, scale=1/m,
149 | ),
150 | scale = m,
151 | )
152 | return g_pred
153 |
154 | def get_resized_input_predicted_noise(model_func, x, alpha_t, beta_t, tau, m=2, **kwargs):
155 | p = get_P(alpha_t, beta_t, m=m)
156 | p = append_dims(p, len(x.shape))
157 | x = resize_func(x, scale=1/m) / (
158 | p ** 0.5
159 | )
160 | model_output = model_func(x, timestep=tau, use_pos_inteplot=True, **kwargs)
161 | return model_output.sample if not isinstance(model_output, torch.Tensor) else model_output
162 |
163 | def get_wt(t, theta=1, eta=0.6, T=1):
164 | def h(x):
165 | return 1*(x >= 0)
166 |
167 | w_t = theta * h(t - (1 - eta)*T)
168 | return w_t
169 |
170 | def get_P(alpha_t, beta_t, m=2, ):
171 | return alpha_t + beta_t / (m**2)
172 |
173 | def resize_func(x, scale):
174 | if scale > 1:
175 | return F.interpolate(
176 | x,
177 | scale_factor=scale,
178 | mode='nearest', #align_corners=False,
179 | )
180 | elif scale < 1:
181 | rscale = int(1 / scale)
182 | x = rearrange(x, 'b c (h p) (w q) -> b c h w (p q)', p=rscale, q=rscale )
183 | x = x.mean(-1)
184 | return x
185 | else:
186 | return x
187 |
188 | def append_dims(x, new_shape_len):
189 | return x.reshape(*x.shape, *([1]*(new_shape_len - len(x.shape))))
--------------------------------------------------------------------------------
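A minimal usage sketch of the low-level helpers above (assuming the repo root is on `PYTHONPATH`; all tensors are random stand-ins). `resize_func` downsamples by `m` with average pooling and upsamples with nearest-neighbour interpolation; `get_wt` gates the guidance so it is only active while `t >= (1 - eta) * T`, i.e. the early, high-noise part of sampling; `upsample_guidance` turns the low-resolution prediction into a correction on the full-resolution one:

```python
import torch
from mamba_attn_diff.models.upsample_guidance import (
    resize_func, get_wt, upsample_guidance,
)

x = torch.randn(1, 4, 64, 64)            # a high-resolution latent
low = resize_func(x, scale=0.5)          # 2x2 average pooling -> (1, 4, 32, 32)
back = resize_func(low, scale=2)         # nearest-neighbour upsample -> (1, 4, 64, 64)

eps_hi = torch.randn(1, 4, 64, 64)       # stand-in: model prediction on x
eps_lo = torch.randn(1, 4, 32, 32)       # stand-in: prediction on the resized input
g = upsample_guidance(eps_lo, eps_hi, m=2)   # correction, same shape as eps_hi

t = torch.tensor([0.9, 0.5])
print(get_wt(t, theta=1, eta=0.3, T=1))  # tensor([1, 0]): active only for t >= 0.7
print(back.shape, g.shape)
```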
/mamba_attn_diff/utils/backup_code.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import glob
4 | import torch
5 | import shutil
6 | import logging
7 | import datetime
8 | from torch.utils.tensorboard import SummaryWriter
9 |
10 | UNSAVED_DIRS = ['outputs', 'checkpoint', 'checkpoints', 'workdir', 'build', '.git', '__pycache__', 'assets', 'samples']
11 |
12 | def backup_code(work_dir, verbose=False):
13 | base_dir = './' #os.path.dirname(os.path.abspath(__file__))
14 |
15 | dir_list = ["*.py", ]
16 | for file in os.listdir(base_dir):
17 | sub_dir = os.path.join(base_dir, file)
18 | if os.path.isdir(sub_dir):
19 | if file in UNSAVED_DIRS:
20 | continue
21 |
22 | for root, dirs, files in os.walk(sub_dir):
23 | for dir_name in dirs:
24 | dir_list.append(os.path.join(root, dir_name)+"/*.py")
25 |
26 | elif file.split('.')[-1] == 'py':
27 |             pass  # top-level *.py files are already matched by the initial "*.py" pattern
28 |
29 | # print(dir_list)
30 |
31 | for pattern in dir_list:
32 | for file in glob.glob(pattern):
33 | src = os.path.join(base_dir, file)
34 | dst = os.path.join(work_dir, 'backup', os.path.dirname(file))
35 | # print(base_dir, src, dst)
36 |
37 | if verbose:
38 | logging.info('Copying %s -> %s' % (os.path.relpath(src), os.path.relpath(dst)))
39 |
40 | os.makedirs(dst, exist_ok=True)
41 | shutil.copy2(src, dst)
42 |
43 |
44 |
--------------------------------------------------------------------------------
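A usage sketch of `backup_code` (`workdir/exp0` is a placeholder): it globs every `*.py` under the current directory, skips anything in `UNSAVED_DIRS`, and copies the files into `<work_dir>/backup/` while preserving the directory layout:

```python
import logging
from mamba_attn_diff.utils.backup_code import backup_code

logging.basicConfig(level=logging.INFO)
backup_code('workdir/exp0', verbose=True)   # hypothetical experiment directory
# workdir/exp0/backup/... now mirrors the repo's *.py files.
```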
/mamba_attn_diff/utils/init_weights.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import math
4 | import numpy as np
5 | import torch.nn.functional as F
6 |
7 | from einops import rearrange
8 |
9 | def pos_embed_into_2d(pos_embed, extra_len):
10 | extra_tokens = pos_embed[:, :extra_len]
11 | pos_embed = pos_embed[:, extra_len:]
12 | num_patches = pos_embed.shape[-2]
13 | H = int(num_patches ** 0.5)
14 | W = num_patches // H
15 | assert H * W == num_patches
16 | pos_embed = pos_embed.reshape(
17 | pos_embed.shape[0], H, W, -1)
18 | return pos_embed, extra_tokens
19 |
20 | def pos_embed_inteplot(cur_pos_embed, pretrained_pos_embed, extra_len, cur_size=None):
21 | if cur_pos_embed is not None:
22 | cur_pos_embed, _ = pos_embed_into_2d(cur_pos_embed, extra_len)
23 | cur_size = cur_pos_embed.shape[-3:-1]
24 |
25 | ori_pretrained_pos_embed = pretrained_pos_embed
26 | pretrained_pos_embed, extra_tokens = pos_embed_into_2d(pretrained_pos_embed, extra_len)
27 | if pretrained_pos_embed.shape[-3:-1] == cur_size:
28 | return ori_pretrained_pos_embed
29 |
30 | pretrained_pos_embed = F.interpolate(
31 | rearrange(pretrained_pos_embed, 'b h w d -> b d h w'),
32 | size= cur_size,
33 | mode='bilinear', align_corners=False,
34 | )
35 | pretrained_pos_embed = rearrange(pretrained_pos_embed, 'b d h w -> b (h w) d')
36 | pretrained_pos_embed = torch.cat([extra_tokens, pretrained_pos_embed], dim=-2)
37 | return pretrained_pos_embed
38 |
39 | def _init_weight_norm_fc_conv(model):
40 | for name, module in model.named_modules():
41 | if isinstance_str(module, "Conv2d") and ("adapter" in name):
42 | module.__class__ = make_weight_norm_conv2d_nobias(module.__class__)
43 | module._init_scale()
44 |
45 | if isinstance_str(module, "Linear") and ("adapter" in name):
46 | module.__class__ = make_weight_norm_fc_nobias(module.__class__)
47 | module._init_scale()
48 |
49 | return model
50 |
51 | def make_weight_norm_conv2d_nobias(block_class):
52 | class WNConv2d(nn.Conv2d):
53 | def _init_scale(self):
54 | self.scale = nn.Parameter(self.weight.new_ones(*self.weight.shape[1:]))
55 |
56 | def forward(self, x):
57 |
58 | fan_in = self.weight[0].numel()
59 | weight = weight_normalize(self.weight) / np.sqrt(fan_in) * self.scale
60 |
61 | return F.conv2d(x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
62 |
63 | return WNConv2d
64 |
65 | def make_weight_norm_fc_nobias(block_class):
66 | class WNLinear(nn.Linear):
67 | def _init_scale(self):
68 | self.scale = nn.Parameter(self.weight.new_ones(*self.weight.shape[1:]))
69 |
70 | def forward(self, x):
71 |
72 | fan_in = self.weight[0].numel()
73 | weight = weight_normalize(self.weight) / np.sqrt(fan_in) * self.scale
74 |
75 | return F.linear(x, weight, self.bias)
76 |
77 | return WNLinear
78 |
79 | def weight_normalize(x, eps=1e-4):
80 | dim = list(range(1, x.ndim))
81 | n = torch.linalg.vector_norm(x, dim=dim, keepdim=True)
82 | alpha = np.sqrt(n.numel() / x.numel())
83 | return x / torch.add(eps, n, alpha=alpha)
84 |
85 | def _init_weights_mamba(
86 | module,
87 | n_layer,
88 | initializer_range=0.02, # Now only used for embedding layer.
89 | rescale_prenorm_residual=True,
90 | n_residuals_per_layer=1, # Change to 2 if we have MLP
91 | ):
92 | if isinstance(module, nn.Linear):
93 | if module.bias is not None:
94 | if not getattr(module.bias, "_no_reinit", False):
95 | nn.init.zeros_(module.bias)
96 | elif isinstance(module, nn.Embedding):
97 | nn.init.normal_(module.weight, std=initializer_range)
98 |
99 | if rescale_prenorm_residual:
100 | # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
101 | # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
102 | # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
103 | # > -- GPT-2 :: https://openai.com/blog/better-language-models/
104 | #
105 | # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
106 | for name, p in module.named_parameters():
107 | if name in ["out_proj.weight", "fc2.weight"]:
108 | # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
109 | # Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
110 | # We need to reinit p since this code could be called multiple times
111 | # Having just p *= scale would repeatedly scale it down
112 | nn.init.kaiming_uniform_(p, a=math.sqrt(5))
113 | with torch.no_grad():
114 | p /= math.sqrt(n_residuals_per_layer * n_layer)
115 |
116 | def init_embedders(model):
117 |
118 | for name, module in model.named_modules():
119 | if 'class_embedder' in name.lower():
120 | if isinstance(module, nn.Embedding):
121 | nn.init.normal_(module.weight, std=0.02)
122 | # print('class_embedder', module.weight.shape)
123 | elif 'timestep_embedder' in name.lower():
124 | if isinstance(module, nn.Linear) :
125 | nn.init.normal_(module.weight, std=0.02)
126 |
127 | def init_adaLN_modulation_layers(model):
128 |
129 | for name, module in model.named_modules():
130 |
131 | if 'blocks' in name.lower() and ('norm' in name.lower() or "adaln_single" in name.lower()):
132 | if isinstance(module, nn.Linear) :
133 | nn.init.constant_(module.weight, 0)
134 | nn.init.constant_(module.bias, 0)
135 |
136 | def initialize_weights(model):
137 | # Initialize transformer layers:
138 | for name, module in model.named_modules():
139 | if 'mamba' in name.lower():
140 | continue
141 |
142 | if 'embed' in name.lower():
143 | continue
144 |
145 | if isinstance(module, nn.Linear):
146 | torch.nn.init.xavier_uniform_(module.weight)
147 | if module.bias is not None:
148 | nn.init.constant_(module.bias, 0)
149 |
150 | for name, module in model.named_modules():
151 | if 'pos_embed.proj' in name.lower() or "proj_in" in name.lower():
152 | if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d):
153 | w = module.weight.data
154 | nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
155 | nn.init.constant_(module.bias, 0)
156 |
157 | init_embedders(model)
158 | init_adaLN_modulation_layers(model)
159 |
160 | for name, module in model.named_modules():
161 | if 'out' in name.lower():
162 | if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d):
163 | nn.init.constant_(module.weight, 0)
164 | if module.bias is not None:
165 | nn.init.constant_(module.bias, 0)
166 |
167 | def isinstance_str(x: object, cls_name: str):
168 | """
169 | Checks whether x has any class *named* cls_name in its ancestry.
170 | Doesn't require access to the class's implementation.
171 |
172 | Useful for patching!
173 | """
174 |
175 | for _cls in x.__class__.__mro__:
176 | if _cls.__name__ == cls_name:
177 | return True
178 |
179 | return False
--------------------------------------------------------------------------------
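`pos_embed_inteplot` is what allows a checkpoint trained at one resolution to be loaded at another: the flat positional table is reshaped to its 2D patch grid, bilinearly interpolated to the new grid, and the extra (non-patch) tokens are re-attached in front. A minimal sketch with made-up shapes:

```python
import torch
from mamba_attn_diff.utils.init_weights import pos_embed_inteplot

extra_len = 1                                             # e.g. one extra token
pretrained = torch.randn(1, extra_len + 16 * 16, 768)     # trained on a 16x16 patch grid

# Interpolate to a 32x32 grid (a hypothetical 2x-resolution target):
resized = pos_embed_inteplot(None, pretrained, extra_len, cur_size=(32, 32))
print(resized.shape)                                      # torch.Size([1, 1025, 768])
```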
/mamba_ssm/__init__.py:
--------------------------------------------------------------------------------
1 | __version__ = "1.2.0.post1"
2 |
3 | from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn
4 | from mamba_ssm.modules.mamba_simple import Mamba
5 | from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
6 |
--------------------------------------------------------------------------------
/mamba_ssm/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tyshiwo1/DiM-DiffusionMamba/292b2285de2f979cf5ac84ab90cbee20cdf1f2f3/mamba_ssm/models/__init__.py
--------------------------------------------------------------------------------
/mamba_ssm/models/config_mamba.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass, field
2 |
3 |
4 | @dataclass
5 | class MambaConfig:
6 |
7 | d_model: int = 2560
8 | n_layer: int = 64
9 | vocab_size: int = 50277
10 | ssm_cfg: dict = field(default_factory=dict)
11 | rms_norm: bool = True
12 | residual_in_fp32: bool = True
13 | fused_add_norm: bool = True
14 | pad_vocab_size_multiple: int = 8
15 | tie_embeddings: bool = True
16 |
--------------------------------------------------------------------------------
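The defaults above match the largest released Mamba language model; smaller variants only override the size fields. A quick sketch (the 768/24 sizing roughly corresponds to the published 130M model):

```python
from mamba_ssm.models.config_mamba import MambaConfig

cfg = MambaConfig(d_model=768, n_layer=24)                # roughly mamba-130m sizing
print(cfg.vocab_size, cfg.rms_norm, cfg.fused_add_norm)   # 50277 True True
```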
/mamba_ssm/models/mixer_seq_simple.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2023, Albert Gu, Tri Dao.
2 |
3 | import math
4 | from functools import partial
5 | import json
6 | import os
7 |
8 | from collections import namedtuple
9 |
10 | import torch
11 | import torch.nn as nn
12 |
13 | from mamba_ssm.models.config_mamba import MambaConfig
14 | from mamba_ssm.modules.mamba_simple import Mamba, Block
15 | from mamba_ssm.utils.generation import GenerationMixin
16 | from mamba_ssm.utils.hf import load_config_hf, load_state_dict_hf
17 |
18 | try:
19 | from mamba_ssm.ops.triton.layernorm import RMSNorm, layer_norm_fn, rms_norm_fn
20 | except ImportError:
21 | RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None
22 |
23 |
24 | def create_block(
25 | d_model,
26 | ssm_cfg=None,
27 | norm_epsilon=1e-5,
28 | rms_norm=False,
29 | residual_in_fp32=False,
30 | fused_add_norm=False,
31 | layer_idx=None,
32 | device=None,
33 | dtype=None,
34 | ):
35 | if ssm_cfg is None:
36 | ssm_cfg = {}
37 | factory_kwargs = {"device": device, "dtype": dtype}
38 | mixer_cls = partial(Mamba, layer_idx=layer_idx, **ssm_cfg, **factory_kwargs)
39 | norm_cls = partial(
40 | nn.LayerNorm if not rms_norm else RMSNorm, eps=norm_epsilon, **factory_kwargs
41 | )
42 | block = Block(
43 | d_model,
44 | mixer_cls,
45 | norm_cls=norm_cls,
46 | fused_add_norm=fused_add_norm,
47 | residual_in_fp32=residual_in_fp32,
48 | )
49 | block.layer_idx = layer_idx
50 | return block
51 |
52 |
53 | # https://github.com/huggingface/transformers/blob/c28d04e9e252a1a099944e325685f14d242ecdcd/src/transformers/models/gpt2/modeling_gpt2.py#L454
54 | def _init_weights(
55 | module,
56 | n_layer,
57 | initializer_range=0.02, # Now only used for embedding layer.
58 | rescale_prenorm_residual=True,
59 | n_residuals_per_layer=1, # Change to 2 if we have MLP
60 | ):
61 | if isinstance(module, nn.Linear):
62 | if module.bias is not None:
63 | if not getattr(module.bias, "_no_reinit", False):
64 | nn.init.zeros_(module.bias)
65 | elif isinstance(module, nn.Embedding):
66 | nn.init.normal_(module.weight, std=initializer_range)
67 |
68 | if rescale_prenorm_residual:
69 | # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
70 | # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
71 | # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
72 | # > -- GPT-2 :: https://openai.com/blog/better-language-models/
73 | #
74 | # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
75 | for name, p in module.named_parameters():
76 | if name in ["out_proj.weight", "fc2.weight"]:
77 | # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
78 | # Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
79 | # We need to reinit p since this code could be called multiple times
80 | # Having just p *= scale would repeatedly scale it down
81 | nn.init.kaiming_uniform_(p, a=math.sqrt(5))
82 | with torch.no_grad():
83 | p /= math.sqrt(n_residuals_per_layer * n_layer)
84 |
85 |
86 | class MixerModel(nn.Module):
87 | def __init__(
88 | self,
89 | d_model: int,
90 | n_layer: int,
91 | vocab_size: int,
92 | ssm_cfg=None,
93 | norm_epsilon: float = 1e-5,
94 | rms_norm: bool = False,
95 | initializer_cfg=None,
96 | fused_add_norm=False,
97 | residual_in_fp32=False,
98 | device=None,
99 | dtype=None,
100 | ) -> None:
101 | factory_kwargs = {"device": device, "dtype": dtype}
102 | super().__init__()
103 | self.residual_in_fp32 = residual_in_fp32
104 |
105 | self.embedding = nn.Embedding(vocab_size, d_model, **factory_kwargs)
106 |
107 | # We change the order of residual and layer norm:
108 | # Instead of LN -> Attn / MLP -> Add, we do:
109 | # Add -> LN -> Attn / MLP / Mixer, returning both the residual branch (output of Add) and
110 | # the main branch (output of MLP / Mixer). The model definition is unchanged.
111 | # This is for performance reason: we can fuse add + layer_norm.
112 | self.fused_add_norm = fused_add_norm
113 | if self.fused_add_norm:
114 | if layer_norm_fn is None or rms_norm_fn is None:
115 | raise ImportError("Failed to import Triton LayerNorm / RMSNorm kernels")
116 |
117 | self.layers = nn.ModuleList(
118 | [
119 | create_block(
120 | d_model,
121 | ssm_cfg=ssm_cfg,
122 | norm_epsilon=norm_epsilon,
123 | rms_norm=rms_norm,
124 | residual_in_fp32=residual_in_fp32,
125 | fused_add_norm=fused_add_norm,
126 | layer_idx=i,
127 | **factory_kwargs,
128 | )
129 | for i in range(n_layer)
130 | ]
131 | )
132 |
133 | self.norm_f = (nn.LayerNorm if not rms_norm else RMSNorm)(
134 | d_model, eps=norm_epsilon, **factory_kwargs
135 | )
136 |
137 | self.apply(
138 | partial(
139 | _init_weights,
140 | n_layer=n_layer,
141 | **(initializer_cfg if initializer_cfg is not None else {}),
142 | )
143 | )
144 |
145 | def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
146 | return {
147 | i: layer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
148 | for i, layer in enumerate(self.layers)
149 | }
150 |
151 | def forward(self, input_ids, inference_params=None):
152 | hidden_states = self.embedding(input_ids)
153 | residual = None
154 | for layer in self.layers:
155 | hidden_states, residual = layer(
156 | hidden_states, residual, inference_params=inference_params
157 | )
158 | if not self.fused_add_norm:
159 | residual = (hidden_states + residual) if residual is not None else hidden_states
160 | hidden_states = self.norm_f(residual.to(dtype=self.norm_f.weight.dtype))
161 | else:
162 | # Set prenorm=False here since we don't need the residual
163 | fused_add_norm_fn = rms_norm_fn if isinstance(self.norm_f, RMSNorm) else layer_norm_fn
164 | hidden_states = fused_add_norm_fn(
165 | hidden_states,
166 | self.norm_f.weight,
167 | self.norm_f.bias,
168 | eps=self.norm_f.eps,
169 | residual=residual,
170 | prenorm=False,
171 | residual_in_fp32=self.residual_in_fp32,
172 | )
173 | return hidden_states
174 |
175 |
176 | class MambaLMHeadModel(nn.Module, GenerationMixin):
177 |
178 | def __init__(
179 | self,
180 | config: MambaConfig,
181 | initializer_cfg=None,
182 | device=None,
183 | dtype=None,
184 | ) -> None:
185 | self.config = config
186 | d_model = config.d_model
187 | n_layer = config.n_layer
188 | vocab_size = config.vocab_size
189 | ssm_cfg = config.ssm_cfg
190 | rms_norm = config.rms_norm
191 | residual_in_fp32 = config.residual_in_fp32
192 | fused_add_norm = config.fused_add_norm
193 | pad_vocab_size_multiple = config.pad_vocab_size_multiple
194 | factory_kwargs = {"device": device, "dtype": dtype}
195 |
196 | super().__init__()
197 | if vocab_size % pad_vocab_size_multiple != 0:
198 | vocab_size += pad_vocab_size_multiple - (vocab_size % pad_vocab_size_multiple)
199 | self.backbone = MixerModel(
200 | d_model=d_model,
201 | n_layer=n_layer,
202 | vocab_size=vocab_size,
203 | ssm_cfg=ssm_cfg,
204 | rms_norm=rms_norm,
205 | initializer_cfg=initializer_cfg,
206 | fused_add_norm=fused_add_norm,
207 | residual_in_fp32=residual_in_fp32,
208 | **factory_kwargs,
209 | )
210 | self.lm_head = nn.Linear(d_model, vocab_size, bias=False, **factory_kwargs)
211 |
212 | # Initialize weights and apply final processing
213 | self.apply(
214 | partial(
215 | _init_weights,
216 | n_layer=n_layer,
217 | **(initializer_cfg if initializer_cfg is not None else {}),
218 | )
219 | )
220 | self.tie_weights()
221 |
222 | def tie_weights(self):
223 | if self.config.tie_embeddings:
224 | self.lm_head.weight = self.backbone.embedding.weight
225 |
226 | def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
227 | return self.backbone.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
228 |
229 | def forward(self, input_ids, position_ids=None, inference_params=None, num_last_tokens=0):
230 | """
231 |         "position_ids" is only there for compatibility with the Transformers generation API. We don't use it.
232 | num_last_tokens: if > 0, only return the logits for the last n tokens
233 | """
234 | hidden_states = self.backbone(input_ids, inference_params=inference_params)
235 | if num_last_tokens > 0:
236 | hidden_states = hidden_states[:, -num_last_tokens:]
237 | lm_logits = self.lm_head(hidden_states)
238 | CausalLMOutput = namedtuple("CausalLMOutput", ["logits"])
239 | return CausalLMOutput(logits=lm_logits)
240 |
241 | @classmethod
242 | def from_pretrained(cls, pretrained_model_name, device=None, dtype=None, **kwargs):
243 | config_data = load_config_hf(pretrained_model_name)
244 | config = MambaConfig(**config_data)
245 | model = cls(config, device=device, dtype=dtype, **kwargs)
246 | model.load_state_dict(load_state_dict_hf(pretrained_model_name, device=device, dtype=dtype))
247 | return model
248 |
249 | def save_pretrained(self, save_directory):
250 | """
251 | Minimal implementation of save_pretrained for MambaLMHeadModel.
252 | Save the model and its configuration file to a directory.
253 | """
254 | # Ensure save_directory exists
255 | os.makedirs(save_directory, exist_ok=True)
256 |
257 | # Save the model's state_dict
258 | model_path = os.path.join(save_directory, 'pytorch_model.bin')
259 | torch.save(self.state_dict(), model_path)
260 |
261 | # Save the configuration of the model
262 | config_path = os.path.join(save_directory, 'config.json')
263 | with open(config_path, 'w') as f:
264 | json.dump(self.config.__dict__, f)
265 |
--------------------------------------------------------------------------------
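A smoke test of `MambaLMHeadModel` with a tiny, made-up configuration (requires a CUDA GPU plus the compiled selective-scan and Triton kernels, since the defaults enable `rms_norm` and `fused_add_norm`):

```python
import torch
from mamba_ssm.models.config_mamba import MambaConfig
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel

config = MambaConfig(d_model=128, n_layer=2, vocab_size=1000)   # toy sizes
model = MambaLMHeadModel(config, device="cuda", dtype=torch.float16)

input_ids = torch.randint(0, 1000, (2, 16), device="cuda")
logits = model(input_ids).logits        # CausalLMOutput namedtuple
print(logits.shape)                     # torch.Size([2, 16, 1000])
```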
/mamba_ssm/modules/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tyshiwo1/DiM-DiffusionMamba/292b2285de2f979cf5ac84ab90cbee20cdf1f2f3/mamba_ssm/modules/__init__.py
--------------------------------------------------------------------------------
/mamba_ssm/ops/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tyshiwo1/DiM-DiffusionMamba/292b2285de2f979cf5ac84ab90cbee20cdf1f2f3/mamba_ssm/ops/__init__.py
--------------------------------------------------------------------------------
/mamba_ssm/ops/triton/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tyshiwo1/DiM-DiffusionMamba/292b2285de2f979cf5ac84ab90cbee20cdf1f2f3/mamba_ssm/ops/triton/__init__.py
--------------------------------------------------------------------------------
/mamba_ssm/ops/triton/selective_state_update.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2023, Tri Dao.
2 |
3 | """We want triton==2.1.0 for this
4 | """
5 |
6 | import math
7 | import torch
8 | import torch.nn.functional as F
9 |
10 | import triton
11 | import triton.language as tl
12 |
13 | from einops import rearrange, repeat
14 |
15 |
16 | @triton.heuristics({"HAS_DT_BIAS": lambda args: args["dt_bias_ptr"] is not None})
17 | @triton.heuristics({"HAS_D": lambda args: args["D_ptr"] is not None})
18 | @triton.heuristics({"HAS_Z": lambda args: args["z_ptr"] is not None})
19 | @triton.heuristics({"BLOCK_SIZE_DSTATE": lambda args: triton.next_power_of_2(args["dstate"])})
20 | @triton.jit
21 | def _selective_scan_update_kernel(
22 | # Pointers to matrices
23 | state_ptr, x_ptr, dt_ptr, dt_bias_ptr, A_ptr, B_ptr, C_ptr, D_ptr, z_ptr, out_ptr,
24 | # Matrix dimensions
25 | batch, dim, dstate,
26 | # Strides
27 | stride_state_batch, stride_state_dim, stride_state_dstate,
28 | stride_x_batch, stride_x_dim,
29 | stride_dt_batch, stride_dt_dim,
30 | stride_dt_bias_dim,
31 | stride_A_dim, stride_A_dstate,
32 | stride_B_batch, stride_B_dstate,
33 | stride_C_batch, stride_C_dstate,
34 | stride_D_dim,
35 | stride_z_batch, stride_z_dim,
36 | stride_out_batch, stride_out_dim,
37 | # Meta-parameters
38 | DT_SOFTPLUS: tl.constexpr,
39 | BLOCK_SIZE_M: tl.constexpr,
40 | HAS_DT_BIAS: tl.constexpr,
41 | HAS_D: tl.constexpr,
42 | HAS_Z: tl.constexpr,
43 | BLOCK_SIZE_DSTATE: tl.constexpr,
44 | ):
45 | pid_m = tl.program_id(axis=0)
46 | pid_b = tl.program_id(axis=1)
47 | state_ptr += pid_b * stride_state_batch
48 | x_ptr += pid_b * stride_x_batch
49 | dt_ptr += pid_b * stride_dt_batch
50 | B_ptr += pid_b * stride_B_batch
51 | C_ptr += pid_b * stride_C_batch
52 | if HAS_Z:
53 | z_ptr += pid_b * stride_z_batch
54 | out_ptr += pid_b * stride_out_batch
55 |
56 | offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
57 | offs_n = tl.arange(0, BLOCK_SIZE_DSTATE)
58 | state_ptrs = state_ptr + (offs_m[:, None] * stride_state_dim + offs_n[None, :] * stride_state_dstate)
59 | x_ptrs = x_ptr + offs_m * stride_x_dim
60 | dt_ptrs = dt_ptr + offs_m * stride_dt_dim
61 | if HAS_DT_BIAS:
62 | dt_bias_ptrs = dt_bias_ptr + offs_m * stride_dt_bias_dim
63 | A_ptrs = A_ptr + (offs_m[:, None] * stride_A_dim + offs_n[None, :] * stride_A_dstate)
64 | B_ptrs = B_ptr + offs_n * stride_B_dstate
65 | C_ptrs = C_ptr + offs_n * stride_C_dstate
66 | if HAS_D:
67 | D_ptrs = D_ptr + offs_m * stride_D_dim
68 | if HAS_Z:
69 | z_ptrs = z_ptr + offs_m * stride_z_dim
70 | out_ptrs = out_ptr + offs_m * stride_out_dim
71 |
72 | state = tl.load(state_ptrs, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate), other=0.0)
73 | x = tl.load(x_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
74 | dt = tl.load(dt_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
75 | if HAS_DT_BIAS:
76 | dt += tl.load(dt_bias_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
77 | if DT_SOFTPLUS:
78 | dt = tl.where(dt <= 20.0, tl.math.log1p(tl.exp(dt)), dt)
79 | A = tl.load(A_ptrs, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate), other=0.0).to(tl.float32)
80 | dA = tl.exp(A * dt[:, None])
81 | B = tl.load(B_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32)
82 | C = tl.load(C_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32)
83 | if HAS_D:
84 | D = tl.load(D_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
85 | if HAS_Z:
86 | z = tl.load(z_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
87 |
88 | dB = B[None, :] * dt[:, None]
89 | state = state * dA + dB * x[:, None]
90 | tl.store(state_ptrs, state, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate))
91 | out = tl.sum(state * C[None, :], axis=1)
92 | if HAS_D:
93 | out += x * D
94 | if HAS_Z:
95 | out *= z * tl.sigmoid(z)
96 | tl.store(out_ptrs, out, mask=offs_m < dim)
97 |
98 |
99 | def selective_state_update(state, x, dt, A, B, C, D=None, z=None, dt_bias=None, dt_softplus=False):
100 | """
101 | Argument:
102 | state: (batch, dim, dstate)
103 | x: (batch, dim)
104 | dt: (batch, dim)
105 | A: (dim, dstate)
106 | B: (batch, dstate)
107 | C: (batch, dstate)
108 | D: (dim,)
109 | z: (batch, dim)
110 | dt_bias: (dim,)
111 | Return:
112 | out: (batch, dim)
113 | """
114 | batch, dim, dstate = state.shape
115 | assert x.shape == (batch, dim)
116 | assert dt.shape == x.shape
117 | assert A.shape == (dim, dstate)
118 | assert B.shape == (batch, dstate)
119 | assert C.shape == B.shape
120 | if D is not None:
121 | assert D.shape == (dim,)
122 | if z is not None:
123 | assert z.shape == x.shape
124 | if dt_bias is not None:
125 | assert dt_bias.shape == (dim,)
126 | out = torch.empty_like(x)
127 | grid = lambda META: (triton.cdiv(dim, META['BLOCK_SIZE_M']), batch)
128 | z_strides = ((z.stride(0), z.stride(1)) if z is not None else (0, 0))
129 | # We don't want autotune since it will overwrite the state
130 | # We instead tune by hand.
131 | BLOCK_SIZE_M, num_warps = ((32, 4) if dstate <= 16
132 | else ((16, 4) if dstate <= 32 else
133 | ((8, 4) if dstate <= 64 else
134 | ((4, 4) if dstate <= 128 else
135 | ((4, 8))))))
136 | with torch.cuda.device(x.device.index):
137 | _selective_scan_update_kernel[grid](
138 | state, x, dt, dt_bias, A, B, C, D, z, out,
139 | batch, dim, dstate,
140 | state.stride(0), state.stride(1), state.stride(2),
141 | x.stride(0), x.stride(1),
142 | dt.stride(0), dt.stride(1),
143 | dt_bias.stride(0) if dt_bias is not None else 0,
144 | A.stride(0), A.stride(1),
145 | B.stride(0), B.stride(1),
146 | C.stride(0), C.stride(1),
147 | D.stride(0) if D is not None else 0,
148 | z_strides[0], z_strides[1],
149 | out.stride(0), out.stride(1),
150 | dt_softplus,
151 | BLOCK_SIZE_M,
152 | num_warps=num_warps,
153 | )
154 | return out
155 |
156 |
157 | def selective_state_update_ref(state, x, dt, A, B, C, D=None, z=None, dt_bias=None, dt_softplus=False):
158 | """
159 | Argument:
160 | state: (batch, dim, dstate)
161 | x: (batch, dim)
162 | dt: (batch, dim)
163 | A: (dim, dstate)
164 | B: (batch, dstate)
165 | C: (batch, dstate)
166 | D: (dim,)
167 | z: (batch, dim)
168 | dt_bias: (dim,)
169 | Return:
170 | out: (batch, dim)
171 | """
172 | batch, dim, dstate = state.shape
173 | assert x.shape == (batch, dim)
174 | assert dt.shape == x.shape
175 | assert A.shape == (dim, dstate)
176 | assert B.shape == (batch, dstate)
177 | assert C.shape == B.shape
178 | if D is not None:
179 | assert D.shape == (dim,)
180 | if z is not None:
181 | assert z.shape == x.shape
182 | if dt_bias is not None:
183 | assert dt_bias.shape == (dim,)
184 |         dt = dt + dt_bias
185 | dt = F.softplus(dt) if dt_softplus else dt
186 | dA = torch.exp(rearrange(dt, "b d -> b d 1") * A) # (batch, dim, dstate)
187 | dB = rearrange(dt, "b d -> b d 1") * rearrange(B, "b n -> b 1 n") # (batch, dim, dstate)
188 |     state.copy_(state * dA + dB * rearrange(x, "b d -> b d 1"))  # (batch, dim, dstate)
189 | out = torch.einsum("bdn,bn->bd", state.to(C.dtype), C)
190 | if D is not None:
191 | out += (x * D).to(out.dtype)
192 | return (out if z is None else out * F.silu(z)).to(x.dtype)
193 |
--------------------------------------------------------------------------------
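The reference implementation spells out the single-step recurrence computed by the Triton kernel: `state <- state * exp(dt * A) + (dt * B) * x`, then `out = <state, C> (+ D * x)`, optionally gated by `silu(z)`. A quick check with toy shapes:

```python
import torch
from mamba_ssm.ops.triton.selective_state_update import selective_state_update_ref

batch, dim, dstate = 2, 8, 4
state = torch.zeros(batch, dim, dstate)
x, dt = torch.randn(batch, dim), torch.rand(batch, dim)
A = -torch.rand(dim, dstate)            # negative A keeps exp(dt * A) <= 1
B, C = torch.randn(batch, dstate), torch.randn(batch, dstate)

out = selective_state_update_ref(state, x, dt, A, B, C, dt_softplus=True)
print(out.shape)                        # torch.Size([2, 8]); `state` is updated in place
```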
/mamba_ssm/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tyshiwo1/DiM-DiffusionMamba/292b2285de2f979cf5ac84ab90cbee20cdf1f2f3/mamba_ssm/utils/__init__.py
--------------------------------------------------------------------------------
/mamba_ssm/utils/hf.py:
--------------------------------------------------------------------------------
1 | import json
2 |
3 | import torch
4 |
5 | from transformers.utils import WEIGHTS_NAME, CONFIG_NAME
6 | from transformers.utils.hub import cached_file
7 |
8 |
9 | def load_config_hf(model_name):
10 | resolved_archive_file = cached_file(model_name, CONFIG_NAME, _raise_exceptions_for_missing_entries=False)
11 | return json.load(open(resolved_archive_file))
12 |
13 |
14 | def load_state_dict_hf(model_name, device=None, dtype=None):
15 | # If not fp32, then we don't want to load directly to the GPU
16 | mapped_device = "cpu" if dtype not in [torch.float32, None] else device
17 | resolved_archive_file = cached_file(model_name, WEIGHTS_NAME, _raise_exceptions_for_missing_entries=False)
18 |     state_dict = torch.load(resolved_archive_file, map_location=mapped_device)
19 | # Convert dtype before moving to GPU to save memory
20 | if dtype is not None:
21 | state_dict = {k: v.to(dtype=dtype) for k, v in state_dict.items()}
22 | state_dict = {k: v.to(device=device) for k, v in state_dict.items()}
23 | return state_dict
24 |
--------------------------------------------------------------------------------
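Usage sketch of the two hub helpers (downloads from the Hugging Face hub; `state-spaces/mamba-130m` is one of the published checkpoints):

```python
import torch
from mamba_ssm.utils.hf import load_config_hf, load_state_dict_hf

name = "state-spaces/mamba-130m"
config = load_config_hf(name)                               # plain dict from config.json
state_dict = load_state_dict_hf(name, dtype=torch.float16)  # loaded on CPU, cast to fp16
print(config["d_model"], len(state_dict))
```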
/sample_t2i_discrete.py:
--------------------------------------------------------------------------------
1 | import ml_collections
2 | import torch
3 | from torch import multiprocessing as mp
4 | import accelerate
5 | import utils
6 | from datasets import get_dataset
7 | from dpm_solver_pp import NoiseScheduleVP, DPM_Solver
8 | from absl import logging
9 | import builtins
10 | import einops
11 | import libs.autoencoder
12 | import libs.clip
13 | from torchvision.utils import save_image
14 |
15 |
16 | def stable_diffusion_beta_schedule(linear_start=0.00085, linear_end=0.0120, n_timestep=1000):
17 | _betas = (
18 | torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2
19 | )
20 | return _betas.numpy()
21 |
22 |
23 | def evaluate(config):
24 | if config.get('benchmark', False):
25 | torch.backends.cudnn.benchmark = True
26 | torch.backends.cudnn.deterministic = False
27 |
28 | mp.set_start_method('spawn')
29 | accelerator = accelerate.Accelerator()
30 | device = accelerator.device
31 | accelerate.utils.set_seed(config.seed, device_specific=True)
32 | logging.info(f'Process {accelerator.process_index} using device: {device}')
33 |
34 | config.mixed_precision = accelerator.mixed_precision
35 | config = ml_collections.FrozenConfigDict(config)
36 | if accelerator.is_main_process:
37 | utils.set_logger(log_level='info')
38 | else:
39 | utils.set_logger(log_level='error')
40 | builtins.print = lambda *args: None
41 |
42 | dataset = get_dataset(**config.dataset)
43 |
44 | with open(config.input_path, 'r') as f:
45 | prompts = f.read().strip().split('\n')
46 |
47 | print(prompts)
48 |
49 | clip = libs.clip.FrozenCLIPEmbedder()
50 | clip.eval()
51 | clip.to(device)
52 |
53 | contexts = clip.encode(prompts)
54 |
55 | nnet = utils.get_nnet(**config.nnet)
56 | nnet = accelerator.prepare(nnet)
57 | logging.info(f'load nnet from {config.nnet_path}')
58 | accelerator.unwrap_model(nnet).load_state_dict(torch.load(config.nnet_path, map_location='cpu'))
59 | nnet.eval()
60 |
61 | def cfg_nnet(x, timesteps, context):
62 | _cond = nnet(x, timesteps, context=context)
63 | if config.sample.scale == 0:
64 | return _cond
65 | _empty_context = torch.tensor(dataset.empty_context, device=device)
66 | _empty_context = einops.repeat(_empty_context, 'L D -> B L D', B=x.size(0))
67 | _uncond = nnet(x, timesteps, context=_empty_context)
68 | return _cond + config.sample.scale * (_cond - _uncond)
69 |
70 | autoencoder = libs.autoencoder.get_model(config.autoencoder.pretrained_path)
71 | autoencoder.to(device)
72 |
73 | @torch.cuda.amp.autocast()
74 | def encode(_batch):
75 | return autoencoder.encode(_batch)
76 |
77 | @torch.cuda.amp.autocast()
78 | def decode(_batch):
79 | return autoencoder.decode(_batch)
80 |
81 | _betas = stable_diffusion_beta_schedule()
82 | N = len(_betas)
83 |
84 | logging.info(config.sample)
85 | logging.info(f'mixed_precision={config.mixed_precision}')
86 | logging.info(f'N={N}')
87 |
88 | z_init = torch.randn(contexts.size(0), *config.z_shape, device=device)
89 | noise_schedule = NoiseScheduleVP(schedule='discrete', betas=torch.tensor(_betas, device=device).float())
90 |
91 | def model_fn(x, t_continuous):
92 | t = t_continuous * N
93 | return cfg_nnet(x, t, context=contexts)
94 |
95 | dpm_solver = DPM_Solver(model_fn, noise_schedule, predict_x0=True, thresholding=False)
96 | z = dpm_solver.sample(z_init, steps=config.sample.sample_steps, eps=1. / N, T=1.)
97 | samples = dataset.unpreprocess(decode(z))
98 |
99 | os.makedirs(config.output_path, exist_ok=True)
100 | for sample, prompt in zip(samples, prompts):
101 | save_image(sample, os.path.join(config.output_path, f"{prompt}.png"))
102 |
103 |
104 |
105 | from absl import flags
106 | from absl import app
107 | from ml_collections import config_flags
108 | import os
109 |
110 |
111 | FLAGS = flags.FLAGS
112 | config_flags.DEFINE_config_file(
113 | "config", None, "Training configuration.", lock_config=False)
114 | flags.mark_flags_as_required(["config"])
115 | flags.DEFINE_string("nnet_path", None, "The nnet to evaluate.")
116 | flags.DEFINE_string("output_path", None, "The path to output images.")
117 | flags.DEFINE_string("input_path", None, "The path to input texts.")
118 |
119 |
120 | def main(argv):
121 | config = FLAGS.config
122 | config.nnet_path = FLAGS.nnet_path
123 | config.output_path = FLAGS.output_path
124 | config.input_path = FLAGS.input_path
125 | evaluate(config)
126 |
127 |
128 | if __name__ == "__main__":
129 | app.run(main)
130 |
--------------------------------------------------------------------------------
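`cfg_nnet` above implements standard classifier-free guidance, eps = eps_cond + s * (eps_cond - eps_uncond) with s = `config.sample.scale`. A self-contained sketch of just the combination rule:

```python
import torch

def cfg_combine(eps_cond, eps_uncond, scale):
    # scale == 0 reduces to the conditional prediction alone,
    # matching the early return in cfg_nnet.
    return eps_cond + scale * (eps_cond - eps_uncond)

eps_cond = torch.randn(2, 4, 32, 32)
eps_uncond = torch.randn(2, 4, 32, 32)
print(cfg_combine(eps_cond, eps_uncond, 1.5).shape)   # torch.Size([2, 4, 32, 32])
```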
/scripts/extract_empty_feature.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import os
3 | import numpy as np
4 | import libs.autoencoder
5 | import libs.clip
6 | from datasets import MSCOCODatabase
7 | import argparse
8 | from tqdm import tqdm
9 |
10 |
11 | def main():
12 | prompts = [
13 | '',
14 | ]
15 |
16 | device = 'cuda'
17 | clip = libs.clip.FrozenCLIPEmbedder()
18 | clip.eval()
19 | clip.to(device)
20 |
21 | save_dir = f'assets/datasets/coco256_features'
22 | latent = clip.encode(prompts)
23 | print(latent.shape)
24 | c = latent[0].detach().cpu().numpy()
25 | np.save(os.path.join(save_dir, f'empty_context.npy'), c)
26 |
27 |
28 | if __name__ == '__main__':
29 | main()
30 |
--------------------------------------------------------------------------------
/scripts/extract_imagenet_feature.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import numpy as np
3 | import torch
4 | from uvit_datasets import ImageNet
5 | from torch.utils.data import DataLoader
6 | from libs.autoencoder import get_model
7 | import argparse
8 | from tqdm import tqdm
9 | torch.manual_seed(0)
10 | np.random.seed(0)
11 |
12 |
13 | def main(resolution=256):
14 | parser = argparse.ArgumentParser()
15 | parser.add_argument('path')
16 | args = parser.parse_args()
17 |
18 | dataset = ImageNet(path=args.path, resolution=resolution, random_flip=False)
19 | train_dataset = dataset.get_split(split='train', labeled=True)
20 | train_dataset_loader = DataLoader(train_dataset, batch_size=256, shuffle=False, drop_last=False,
21 | num_workers=8, pin_memory=True, persistent_workers=True)
22 |
23 | model = get_model('assets/stable-diffusion/autoencoder_kl.pth')
24 | model = nn.DataParallel(model)
25 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
26 | model.to(device)
27 |
28 | # features = []
29 | # labels = []
30 |
31 | idx = 0
32 | for batch in tqdm(train_dataset_loader):
33 | img, label = batch
34 | img = torch.cat([img, img.flip(dims=[-1])], dim=0)
35 | img = img.to(device)
36 | moments = model(img, fn='encode_moments')
37 | moments = moments.detach().cpu().numpy()
38 |
39 | label = torch.cat([label, label], dim=0)
40 | label = label.detach().cpu().numpy()
41 |
42 | for moment, lb in zip(moments, label):
43 | # np.save(f'assets/datasets/imagenet{resolution}_features/{idx}.npy', (moment, lb))
44 |             np.savez(f'assets/datasets/imagenet{resolution}_features/{idx}.npy', z=moment, label=lb)  # np.savez appends '.npz', so the file is '{idx}.npy.npz'
45 | idx += 1
46 |
47 | print(f'save {idx} files')
48 |
49 | # features = np.concatenate(features, axis=0)
50 | # labels = np.concatenate(labels, axis=0)
51 | # print(f'features.shape={features.shape}')
52 | # print(f'labels.shape={labels.shape}')
53 | # np.save(f'imagenet{resolution}_features.npy', features)
54 | # np.save(f'imagenet{resolution}_labels.npy', labels)
55 |
56 |
57 | if __name__ == "__main__":
58 | main()
59 |
--------------------------------------------------------------------------------
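Reading one saved sample back (as noted in the code, `np.savez` appends `.npz`, so the files written above actually end in `.npy.npz`):

```python
import numpy as np

# Index 0 is just an example; the path layout follows the script above.
data = np.load('assets/datasets/imagenet256_features/0.npy.npz')
moment, label = data['z'], data['label']
print(moment.shape, int(label))
```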
/scripts/extract_mscoco_feature.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import os
3 | import numpy as np
4 | import libs.autoencoder
5 | import libs.clip
6 | from datasets import MSCOCODatabase
7 | import argparse
8 | from tqdm import tqdm
9 |
10 |
11 | def main(resolution=256):
12 | parser = argparse.ArgumentParser()
13 | parser.add_argument('--split', default='train')
14 | args = parser.parse_args()
15 | print(args)
16 |
17 |
18 | if args.split == "train":
19 | datas = MSCOCODatabase(root='assets/datasets/coco/train2014',
20 | annFile='assets/datasets/coco/annotations/captions_train2014.json',
21 | size=resolution)
22 | save_dir = f'assets/datasets/coco{resolution}_features/train'
23 | elif args.split == "val":
24 | datas = MSCOCODatabase(root='assets/datasets/coco/val2014',
25 | annFile='assets/datasets/coco/annotations/captions_val2014.json',
26 | size=resolution)
27 | save_dir = f'assets/datasets/coco{resolution}_features/val'
28 | else:
29 |         raise NotImplementedError(f"Unknown split: {args.split}")
30 |
31 | device = "cuda"
32 | os.makedirs(save_dir)
33 |
34 | autoencoder = libs.autoencoder.get_model('assets/stable-diffusion/autoencoder_kl.pth')
35 | autoencoder.to(device)
36 | clip = libs.clip.FrozenCLIPEmbedder()
37 | clip.eval()
38 | clip.to(device)
39 |
40 | with torch.no_grad():
41 | for idx, data in tqdm(enumerate(datas)):
42 | x, captions = data
43 |
44 | if len(x.shape) == 3:
45 | x = x[None, ...]
46 | x = torch.tensor(x, device=device)
47 | moments = autoencoder(x, fn='encode_moments').squeeze(0)
48 | moments = moments.detach().cpu().numpy()
49 | np.save(os.path.join(save_dir, f'{idx}.npy'), moments)
50 |
51 | latent = clip.encode(captions)
52 | for i in range(len(latent)):
53 | c = latent[i].detach().cpu().numpy()
54 | np.save(os.path.join(save_dir, f'{idx}_{i}.npy'), c)
55 |
56 |
57 | if __name__ == '__main__':
58 | main()
59 |
--------------------------------------------------------------------------------
/scripts/extract_test_prompt_feature.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import os
3 | import numpy as np
4 | import libs.autoencoder
5 | import libs.clip
6 | from datasets import MSCOCODatabase
7 | import argparse
8 | from tqdm import tqdm
9 |
10 |
11 | def main():
12 | prompts = [
13 | 'A green train is coming down the tracks.',
14 | 'A group of skiers are preparing to ski down a mountain.',
15 | 'A small kitchen with a low ceiling.',
16 | 'A group of elephants walking in muddy water.',
17 | 'A living area with a television and a table.',
18 | 'A road with traffic lights, street lights and cars.',
19 | 'A bus driving in a city area with traffic signs.',
20 | 'A bus pulls over to the curb close to an intersection.',
21 | 'A group of people are walking and one is holding an umbrella.',
22 | 'A baseball player taking a swing at an incoming ball.',
23 | 'A city street line with brick buildings and trees.',
24 | 'A close up of a plate of broccoli and sauce.',
25 | ]
26 |
27 | device = 'cuda'
28 | clip = libs.clip.FrozenCLIPEmbedder()
29 | clip.eval()
30 | clip.to(device)
31 |
32 | save_dir = f'assets/datasets/coco256_features/run_vis'
33 | latent = clip.encode(prompts)
34 | for i in range(len(latent)):
35 | c = latent[i].detach().cpu().numpy()
36 | np.save(os.path.join(save_dir, f'{i}.npy'), (prompts[i], c))
37 |
38 |
39 | if __name__ == '__main__':
40 | main()
41 |
--------------------------------------------------------------------------------
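Because each file stores a `(prompt, context)` tuple, NumPy pickles it as an object array, so reading it back needs `allow_pickle=True`:

```python
import numpy as np

prompt, context = np.load('assets/datasets/coco256_features/run_vis/0.npy',
                          allow_pickle=True)
print(prompt, context.shape)
```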
/setup.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2023, Albert Gu, Tri Dao.
2 | import sys
3 | import warnings
4 | import os
5 | import re
6 | import ast
7 | from pathlib import Path
8 | from packaging.version import parse, Version
9 | import platform
10 | import shutil
11 |
12 | from setuptools import setup, find_packages
13 | import subprocess
14 |
15 | import urllib.request
16 | import urllib.error
17 | from wheel.bdist_wheel import bdist_wheel as _bdist_wheel
18 |
19 | import torch
20 | from torch.utils.cpp_extension import (
21 | BuildExtension,
22 | CppExtension,
23 | CUDAExtension,
24 | CUDA_HOME,
25 | )
26 |
27 |
28 | with open("README.md", "r", encoding="utf-8") as fh:
29 | long_description = fh.read()
30 |
31 |
32 | # ninja build does not work unless include_dirs are abs path
33 | this_dir = os.path.dirname(os.path.abspath(__file__))
34 |
35 | PACKAGE_NAME = "mamba_ssm"
36 |
37 | BASE_WHEEL_URL = "https://github.com/state-spaces/mamba/releases/download/{tag_name}/{wheel_name}"
38 |
39 | # FORCE_BUILD: Force a fresh build locally, instead of attempting to find prebuilt wheels
40 | # SKIP_CUDA_BUILD: Intended to allow CI to use a simple `python setup.py sdist` run to copy over raw files, without any cuda compilation
41 | FORCE_BUILD = os.getenv("MAMBA_FORCE_BUILD", "FALSE") == "TRUE"
42 | SKIP_CUDA_BUILD = os.getenv("MAMBA_SKIP_CUDA_BUILD", "FALSE") == "TRUE"
43 | # For CI, we want the option to build with C++11 ABI since the nvcr images use C++11 ABI
44 | FORCE_CXX11_ABI = os.getenv("MAMBA_FORCE_CXX11_ABI", "FALSE") == "TRUE"
45 |
46 |
47 | def get_platform():
48 | """
49 | Returns the platform name as used in wheel filenames.
50 | """
51 | if sys.platform.startswith("linux"):
52 | return "linux_x86_64"
53 | elif sys.platform == "darwin":
54 | mac_version = ".".join(platform.mac_ver()[0].split(".")[:2])
55 | return f"macosx_{mac_version}_x86_64"
56 | elif sys.platform == "win32":
57 | return "win_amd64"
58 | else:
59 | raise ValueError("Unsupported platform: {}".format(sys.platform))
60 |
61 |
62 | def get_cuda_bare_metal_version(cuda_dir):
63 | raw_output = subprocess.check_output(
64 | [cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True
65 | )
66 | output = raw_output.split()
67 | release_idx = output.index("release") + 1
68 | bare_metal_version = parse(output[release_idx].split(",")[0])
69 |
70 | return raw_output, bare_metal_version
71 |
72 |
73 | def check_if_cuda_home_none(global_option: str) -> None:
74 | if CUDA_HOME is not None:
75 | return
76 | # warn instead of error because user could be downloading prebuilt wheels, so nvcc won't be necessary
77 | # in that case.
78 | warnings.warn(
79 | f"{global_option} was requested, but nvcc was not found. Are you sure your environment has nvcc available? "
80 | "If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, "
81 | "only images whose names contain 'devel' will provide nvcc."
82 | )
83 |
84 |
85 | def append_nvcc_threads(nvcc_extra_args):
86 | return nvcc_extra_args + ["--threads", "4"]
87 |
88 |
89 | cmdclass = {}
90 | ext_modules = []
91 |
92 | if not SKIP_CUDA_BUILD:
93 | print("\n\ntorch.__version__ = {}\n\n".format(torch.__version__))
94 | TORCH_MAJOR = int(torch.__version__.split(".")[0])
95 | TORCH_MINOR = int(torch.__version__.split(".")[1])
96 |
97 | check_if_cuda_home_none(PACKAGE_NAME)
98 |     # Check if CUDA 11 is installed for compute capability 8.0
99 | cc_flag = []
100 | if CUDA_HOME is not None:
101 | _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME)
102 | if bare_metal_version < Version("11.6"):
103 | raise RuntimeError(
104 | f"{PACKAGE_NAME} is only supported on CUDA 11.6 and above. "
105 | "Note: make sure nvcc has a supported version by running nvcc -V."
106 | )
107 |
108 | cc_flag.append("-gencode")
109 | cc_flag.append("arch=compute_53,code=sm_53")
110 | cc_flag.append("-gencode")
111 | cc_flag.append("arch=compute_62,code=sm_62")
112 | cc_flag.append("-gencode")
113 | cc_flag.append("arch=compute_70,code=sm_70")
114 | cc_flag.append("-gencode")
115 | cc_flag.append("arch=compute_72,code=sm_72")
116 | cc_flag.append("-gencode")
117 | cc_flag.append("arch=compute_80,code=sm_80")
118 | cc_flag.append("-gencode")
119 | cc_flag.append("arch=compute_87,code=sm_87")
120 | if bare_metal_version >= Version("11.8"):
121 | cc_flag.append("-gencode")
122 | cc_flag.append("arch=compute_90,code=sm_90")
123 |
124 | # HACK: The compiler flag -D_GLIBCXX_USE_CXX11_ABI is set to be the same as
125 | # torch._C._GLIBCXX_USE_CXX11_ABI
126 | # https://github.com/pytorch/pytorch/blob/8472c24e3b5b60150096486616d98b7bea01500b/torch/utils/cpp_extension.py#L920
127 | if FORCE_CXX11_ABI:
128 | torch._C._GLIBCXX_USE_CXX11_ABI = True
129 |
130 | ext_modules.append(
131 | CUDAExtension(
132 | name="selective_scan_cuda",
133 | sources=[
134 | "csrc/selective_scan/selective_scan.cpp",
135 | "csrc/selective_scan/selective_scan_fwd_fp32.cu",
136 | "csrc/selective_scan/selective_scan_fwd_fp16.cu",
137 | "csrc/selective_scan/selective_scan_fwd_bf16.cu",
138 | "csrc/selective_scan/selective_scan_bwd_fp32_real.cu",
139 | "csrc/selective_scan/selective_scan_bwd_fp32_complex.cu",
140 | "csrc/selective_scan/selective_scan_bwd_fp16_real.cu",
141 | "csrc/selective_scan/selective_scan_bwd_fp16_complex.cu",
142 | "csrc/selective_scan/selective_scan_bwd_bf16_real.cu",
143 | "csrc/selective_scan/selective_scan_bwd_bf16_complex.cu",
144 | ],
145 | extra_compile_args={
146 | "cxx": ["-O3", "-std=c++17"],
147 | "nvcc": append_nvcc_threads(
148 | [
149 | "-O3",
150 | "-std=c++17",
151 | "-U__CUDA_NO_HALF_OPERATORS__",
152 | "-U__CUDA_NO_HALF_CONVERSIONS__",
153 | "-U__CUDA_NO_BFLOAT16_OPERATORS__",
154 | "-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
155 | "-U__CUDA_NO_BFLOAT162_OPERATORS__",
156 | "-U__CUDA_NO_BFLOAT162_CONVERSIONS__",
157 | "--expt-relaxed-constexpr",
158 | "--expt-extended-lambda",
159 | "--use_fast_math",
160 | "--ptxas-options=-v",
161 | "-lineinfo",
162 | ]
163 | + cc_flag
164 | ),
165 | },
166 | include_dirs=[Path(this_dir) / "csrc" / "selective_scan"],
167 | )
168 | )
169 |
170 |
171 | def get_package_version():
172 | with open(Path(this_dir) / PACKAGE_NAME / "__init__.py", "r") as f:
173 | version_match = re.search(r"^__version__\s*=\s*(.*)$", f.read(), re.MULTILINE)
174 | public_version = ast.literal_eval(version_match.group(1))
175 | local_version = os.environ.get("MAMBA_LOCAL_VERSION")
176 | if local_version:
177 | return f"{public_version}+{local_version}"
178 | else:
179 | return str(public_version)
180 |
181 |
182 | def get_wheel_url():
183 | # Determine the version numbers that will be used to determine the correct wheel
184 | # We're using the CUDA version used to build torch, not the one currently installed
185 | # _, cuda_version_raw = get_cuda_bare_metal_version(CUDA_HOME)
186 | torch_cuda_version = parse(torch.version.cuda)
187 | torch_version_raw = parse(torch.__version__)
188 | # For CUDA 11, we only compile for CUDA 11.8, and for CUDA 12 we only compile for CUDA 12.2
189 | # to save CI time. Minor versions should be compatible.
190 | torch_cuda_version = parse("11.8") if torch_cuda_version.major == 11 else parse("12.2")
191 | python_version = f"cp{sys.version_info.major}{sys.version_info.minor}"
192 | platform_name = get_platform()
193 | mamba_ssm_version = get_package_version()
194 | # cuda_version = f"{cuda_version_raw.major}{cuda_version_raw.minor}"
195 | cuda_version = f"{torch_cuda_version.major}{torch_cuda_version.minor}"
196 | torch_version = f"{torch_version_raw.major}.{torch_version_raw.minor}"
197 | cxx11_abi = str(torch._C._GLIBCXX_USE_CXX11_ABI).upper()
198 |
199 | # Determine wheel URL based on CUDA version, torch version, python version and OS
200 | wheel_filename = f"{PACKAGE_NAME}-{mamba_ssm_version}+cu{cuda_version}torch{torch_version}cxx11abi{cxx11_abi}-{python_version}-{python_version}-{platform_name}.whl"
201 | wheel_url = BASE_WHEEL_URL.format(
202 | tag_name=f"v{mamba_ssm_version}", wheel_name=wheel_filename
203 | )
204 | return wheel_url, wheel_filename
205 |
206 |
207 | class CachedWheelsCommand(_bdist_wheel):
208 | """
209 |     The CachedWheelsCommand plugs into the default bdist wheel, which is run by pip when it cannot
210 |     find an existing wheel (which is currently the case for all installs). We use
211 |     the environment parameters to detect whether there is already a pre-built version of a compatible
212 |     wheel available and, if so, short-circuit the standard full build pipeline.
213 | """
214 |
215 | def run(self):
216 | if FORCE_BUILD:
217 | return super().run()
218 |
219 | wheel_url, wheel_filename = get_wheel_url()
220 | print("Guessing wheel URL: ", wheel_url)
221 | try:
222 | urllib.request.urlretrieve(wheel_url, wheel_filename)
223 |
224 | # Make the archive
225 | # Lifted from the root wheel processing command
226 | # https://github.com/pypa/wheel/blob/cf71108ff9f6ffc36978069acb28824b44ae028e/src/wheel/bdist_wheel.py#LL381C9-L381C85
227 | if not os.path.exists(self.dist_dir):
228 | os.makedirs(self.dist_dir)
229 |
230 | impl_tag, abi_tag, plat_tag = self.get_tag()
231 | archive_basename = f"{self.wheel_dist_name}-{impl_tag}-{abi_tag}-{plat_tag}"
232 |
233 | wheel_path = os.path.join(self.dist_dir, archive_basename + ".whl")
234 | print("Raw wheel path", wheel_path)
235 | shutil.move(wheel_filename, wheel_path)
236 | except urllib.error.HTTPError:
237 | print("Precompiled wheel not found. Building from source...")
238 | # If the wheel could not be downloaded, build from source
239 | super().run()
240 |
241 |
242 | setup(
243 | name=PACKAGE_NAME,
244 | version=get_package_version(),
245 | packages=find_packages(
246 | exclude=(
247 | "build",
248 | "csrc",
249 | "include",
250 | "tests",
251 | "dist",
252 | "docs",
253 | "benchmarks",
254 | "mamba_ssm.egg-info",
255 | )
256 | ),
257 | author="Tri Dao, Albert Gu",
258 | author_email="tri@tridao.me, agu@cs.cmu.edu",
259 | description="Mamba state-space model",
260 | long_description=long_description,
261 | long_description_content_type="text/markdown",
262 | url="https://github.com/state-spaces/mamba",
263 | classifiers=[
264 | "Programming Language :: Python :: 3",
265 | "License :: OSI Approved :: BSD License",
266 | "Operating System :: Unix",
267 | ],
268 | ext_modules=ext_modules,
269 | cmdclass={"bdist_wheel": CachedWheelsCommand, "build_ext": BuildExtension}
270 | if ext_modules
271 | else {
272 | "bdist_wheel": CachedWheelsCommand,
273 | },
274 | python_requires=">=3.7",
275 | install_requires=[
276 | "torch",
277 | "packaging",
278 | "ninja",
279 | "einops",
280 | "triton",
281 | "transformers",
282 | # "causal_conv1d>=1.2.0",
283 | ],
284 | )
285 |
--------------------------------------------------------------------------------
/tools/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tyshiwo1/DiM-DiffusionMamba/292b2285de2f979cf5ac84ab90cbee20cdf1f2f3/tools/__init__.py
--------------------------------------------------------------------------------
/tools/fid_score.py:
--------------------------------------------------------------------------------
1 | """Calculates the Frechet Inception Distance (FID) to evaluate GANs
2 |
3 | The FID metric calculates the distance between two distributions of images.
4 | Typically, we have summary statistics (mean & covariance matrix) of one
5 | of these distributions, while the 2nd distribution is given by a GAN.
6 |
7 | When run as a stand-alone program, it compares the distribution of
8 | images that are stored as PNG/JPEG at a specified location with a
9 | distribution given by summary statistics (in pickle format).
10 |
11 | The FID is calculated by assuming that X_1 and X_2 are the activations of
12 | the pool_3 layer of the inception net for generated samples and real world
13 | samples respectively.
14 |
15 | See --help to see further details.
16 |
17 | Code adapted from https://github.com/bioinf-jku/TTUR to use PyTorch instead
18 | of TensorFlow
19 |
20 | Copyright 2018 Institute of Bioinformatics, JKU Linz
21 |
22 | Licensed under the Apache License, Version 2.0 (the "License");
23 | you may not use this file except in compliance with the License.
24 | You may obtain a copy of the License at
25 |
26 | http://www.apache.org/licenses/LICENSE-2.0
27 |
28 | Unless required by applicable law or agreed to in writing, software
29 | distributed under the License is distributed on an "AS IS" BASIS,
30 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
31 | See the License for the specific language governing permissions and
32 | limitations under the License.
33 | """
34 | import os
35 | import pathlib
36 |
37 | import numpy as np
38 | import torch
39 | import torchvision.transforms as TF
40 | from PIL import Image
41 | from scipy import linalg
42 | from torch.nn.functional import adaptive_avg_pool2d
43 |
44 | try:
45 | from tqdm import tqdm
46 | except ImportError:
47 | # If tqdm is not available, provide a mock version of it
48 | def tqdm(x):
49 | return x
50 |
51 | from .inception import InceptionV3
52 |
53 |
54 | IMAGE_EXTENSIONS = {'bmp', 'jpg', 'jpeg', 'pgm', 'png', 'ppm',
55 | 'tif', 'tiff', 'webp'}
56 |
57 |
58 | class ImagePathDataset(torch.utils.data.Dataset):
59 | def __init__(self, files, transforms=None):
60 | self.files = files
61 | self.transforms = transforms
62 |
63 | def __len__(self):
64 | return len(self.files)
65 |
66 | def __getitem__(self, i):
67 | path = self.files[i]
68 | img = Image.open(path).convert('RGB')
69 | if self.transforms is not None:
70 | img = self.transforms(img)
71 | return img
72 |
73 |
74 | def get_activations(files, model, batch_size=50, dims=2048, device='cpu', num_workers=8):
75 | """Calculates the activations of the pool_3 layer for all images.
76 |
77 | Params:
78 | -- files : List of image files paths
79 | -- model : Instance of inception model
80 | -- batch_size : Batch size of images for the model to process at once.
81 | Make sure that the number of samples is a multiple of
82 | the batch size, otherwise some samples are ignored. This
83 | behavior is retained to match the original FID score
84 | implementation.
85 | -- dims : Dimensionality of features returned by Inception
86 | -- device : Device to run calculations
87 | -- num_workers : Number of parallel dataloader workers
88 |
89 | Returns:
90 | -- A numpy array of dimension (num images, dims) that contains the
91 | activations of the given tensor when feeding inception with the
92 | query tensor.
93 | """
94 | model.eval()
95 |
96 | if batch_size > len(files):
97 | print(('Warning: batch size is bigger than the data size. '
98 | 'Setting batch size to data size'))
99 | batch_size = len(files)
100 |
101 | dataset = ImagePathDataset(files, transforms=TF.ToTensor())
102 | dataloader = torch.utils.data.DataLoader(dataset,
103 | batch_size=batch_size,
104 | shuffle=False,
105 | drop_last=False,
106 | num_workers=num_workers)
107 |
108 | pred_arr = np.empty((len(files), dims))
109 |
110 | start_idx = 0
111 |
112 | for batch in tqdm(dataloader):
113 | batch = batch.to(device)
114 |
115 | with torch.no_grad():
116 | pred = model(batch)[0]
117 |
118 | # If model output is not scalar, apply global spatial average pooling.
119 |         # This happens if you choose a dimensionality not equal to 2048.
120 | if pred.size(2) != 1 or pred.size(3) != 1:
121 | pred = adaptive_avg_pool2d(pred, output_size=(1, 1))
122 |
123 | pred = pred.squeeze(3).squeeze(2).cpu().numpy()
124 |
125 | pred_arr[start_idx:start_idx + pred.shape[0]] = pred
126 |
127 | start_idx = start_idx + pred.shape[0]
128 |
129 | return pred_arr
130 |
131 |
132 | def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
133 | """Numpy implementation of the Frechet Distance.
134 | The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)
135 | and X_2 ~ N(mu_2, C_2) is
136 | d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).
137 |
138 | Stable version by Dougal J. Sutherland.
139 |
140 | Params:
141 |     -- mu1   : Numpy array containing the activations of a layer of the
142 |                inception net (like returned by the function 'get_activations')
143 |                for generated samples.
144 |     -- mu2   : The sample mean over activations, precalculated on a
145 |                representative data set.
146 |     -- sigma1: The covariance matrix over activations for generated samples.
147 |     -- sigma2: The covariance matrix over activations, precalculated on a
148 |                representative data set.
149 |
150 | Returns:
151 | -- : The Frechet Distance.
152 | """
153 |
154 | mu1 = np.atleast_1d(mu1)
155 | mu2 = np.atleast_1d(mu2)
156 |
157 | sigma1 = np.atleast_2d(sigma1)
158 | sigma2 = np.atleast_2d(sigma2)
159 |
160 | assert mu1.shape == mu2.shape, \
161 | 'Training and test mean vectors have different lengths'
162 | assert sigma1.shape == sigma2.shape, \
163 | 'Training and test covariances have different dimensions'
164 |
165 | diff = mu1 - mu2
166 |
167 | # Product might be almost singular
168 | covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
169 | if not np.isfinite(covmean).all():
170 | msg = ('fid calculation produces singular product; '
171 | 'adding %s to diagonal of cov estimates') % eps
172 | print(msg)
173 | offset = np.eye(sigma1.shape[0]) * eps
174 | covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
175 |
176 | # Numerical error might give slight imaginary component
177 | if np.iscomplexobj(covmean):
178 | if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
179 | m = np.max(np.abs(covmean.imag))
180 | raise ValueError('Imaginary component {}'.format(m))
181 | covmean = covmean.real
182 |
183 | tr_covmean = np.trace(covmean)
184 |
185 | return (diff.dot(diff) + np.trace(sigma1)
186 | + np.trace(sigma2) - 2 * tr_covmean)
187 |
188 |
189 | def calculate_activation_statistics(files, model, batch_size=50, dims=2048,
190 | device='cpu', num_workers=8):
191 | """Calculation of the statistics used by the FID.
192 | Params:
193 | -- files : List of image files paths
194 | -- model : Instance of inception model
195 | -- batch_size : The images numpy array is split into batches with
196 | batch size batch_size. A reasonable batch size
197 | depends on the hardware.
198 | -- dims : Dimensionality of features returned by Inception
199 | -- device : Device to run calculations
200 | -- num_workers : Number of parallel dataloader workers
201 |
202 | Returns:
203 | -- mu : The mean over samples of the activations of the pool_3 layer of
204 | the inception model.
205 | -- sigma : The covariance matrix of the activations of the pool_3 layer of
206 | the inception model.
207 | """
208 | act = get_activations(files, model, batch_size, dims, device, num_workers)
209 | mu = np.mean(act, axis=0)
210 | sigma = np.cov(act, rowvar=False)
211 | return mu, sigma
212 |
213 |
214 | def compute_statistics_of_path(path, model, batch_size, dims, device, num_workers=8):
215 | if path.endswith('.npz'):
216 | with np.load(path) as f:
217 | m, s = f['mu'][:], f['sigma'][:]
218 | else:
219 | path = pathlib.Path(path)
220 | files = sorted([file for ext in IMAGE_EXTENSIONS
221 | for file in path.glob('*.{}'.format(ext))])
222 | m, s = calculate_activation_statistics(files, model, batch_size,
223 | dims, device, num_workers)
224 |
225 | return m, s
226 |
227 |
228 | def save_statistics_of_path(path, out_path, device=None, batch_size=50, dims=2048, num_workers=8):
229 | if device is None:
230 | device = torch.device('cuda' if (torch.cuda.is_available()) else 'cpu')
231 | else:
232 | device = torch.device(device)
233 | block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims]
234 | model = InceptionV3([block_idx]).to(device)
235 | m1, s1 = compute_statistics_of_path(path, model, batch_size, dims, device, num_workers)
236 | np.savez(out_path, mu=m1, sigma=s1)
237 |
238 |
239 | def calculate_fid_given_paths(paths, device=None, batch_size=50, dims=2048, num_workers=8):
240 | """Calculates the FID of two paths"""
241 | if device is None:
242 | device = torch.device('cuda' if (torch.cuda.is_available()) else 'cpu')
243 | else:
244 | device = torch.device(device)
245 |
246 | for p in paths:
247 | if not os.path.exists(p):
248 | raise RuntimeError('Invalid path: %s' % p)
249 |
250 | block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims]
251 |
252 | model = InceptionV3([block_idx]).to(device)
253 |
254 | m1, s1 = compute_statistics_of_path(paths[0], model, batch_size,
255 | dims, device, num_workers)
256 | m2, s2 = compute_statistics_of_path(paths[1], model, batch_size,
257 | dims, device, num_workers)
258 | fid_value = calculate_frechet_distance(m1, s1, m2, s2)
259 |
260 | return fid_value
261 |
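To sanity-check `calculate_frechet_distance`, note that when the two covariances are equal the trace term vanishes and the distance collapses to `||mu_1 - mu_2||^2`. A toy verification, assuming the repo root is on `PYTHONPATH` (the activation dimensionality is reduced to 8 only to keep it fast):

```python
import numpy as np
from tools.fid_score import calculate_frechet_distance

rng = np.random.default_rng(0)
act1 = rng.normal(0.0, 1.0, size=(10000, 8))  # stand-in "activations"
act2 = rng.normal(0.5, 1.0, size=(10000, 8))  # mean shifted by 0.5 per dim

mu1, sigma1 = act1.mean(axis=0), np.cov(act1, rowvar=False)
mu2, sigma2 = act2.mean(axis=0), np.cov(act2, rowvar=False)

fid = calculate_frechet_distance(mu1, sigma1, mu2, sigma2)
# With (near-)equal covariances the formula reduces to ||mu_1 - mu_2||^2,
# so this should print a value close to 8 * 0.5**2 = 2.0.
print(fid)
```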
--------------------------------------------------------------------------------
/tools/inception.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import torchvision
5 |
6 | try:
7 | from torchvision.models.utils import load_state_dict_from_url
8 | except ImportError:
9 | from torch.utils.model_zoo import load_url as load_state_dict_from_url
10 |
11 | # Inception weights ported to Pytorch from
12 | # http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
13 | FID_WEIGHTS_URL = 'https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth' # noqa: E501
14 |
15 |
16 | class InceptionV3(nn.Module):
17 | """Pretrained InceptionV3 network returning feature maps"""
18 |
19 | # Index of default block of inception to return,
20 | # corresponds to output of final average pooling
21 | DEFAULT_BLOCK_INDEX = 3
22 |
23 | # Maps feature dimensionality to their output blocks indices
24 | BLOCK_INDEX_BY_DIM = {
25 | 64: 0, # First max pooling features
26 |         192: 1,  # Second max pooling features
27 | 768: 2, # Pre-aux classifier features
28 | 2048: 3 # Final average pooling features
29 | }
30 |
31 | def __init__(self,
32 | output_blocks=(DEFAULT_BLOCK_INDEX,),
33 | resize_input=True,
34 | normalize_input=True,
35 | requires_grad=False,
36 | use_fid_inception=True):
37 | """Build pretrained InceptionV3
38 |
39 | Parameters
40 | ----------
41 | output_blocks : list of int
42 | Indices of blocks to return features of. Possible values are:
43 | - 0: corresponds to output of first max pooling
44 | - 1: corresponds to output of second max pooling
45 | - 2: corresponds to output which is fed to aux classifier
46 | - 3: corresponds to output of final average pooling
47 | resize_input : bool
48 | If true, bilinearly resizes input to width and height 299 before
49 | feeding input to model. As the network without fully connected
50 | layers is fully convolutional, it should be able to handle inputs
51 | of arbitrary size, so resizing might not be strictly needed
52 | normalize_input : bool
53 | If true, scales the input from range (0, 1) to the range the
54 | pretrained Inception network expects, namely (-1, 1)
55 | requires_grad : bool
56 | If true, parameters of the model require gradients. Possibly useful
57 | for finetuning the network
58 | use_fid_inception : bool
59 |             If true, uses the pretrained Inception model used in TensorFlow's
60 | FID implementation. If false, uses the pretrained Inception model
61 | available in torchvision. The FID Inception model has different
62 | weights and a slightly different structure from torchvision's
63 | Inception model. If you want to compute FID scores, you are
64 | strongly advised to set this parameter to true to get comparable
65 | results.
66 | """
67 | super(InceptionV3, self).__init__()
68 |
69 | self.resize_input = resize_input
70 | self.normalize_input = normalize_input
71 | self.output_blocks = sorted(output_blocks)
72 | self.last_needed_block = max(output_blocks)
73 |
74 | assert self.last_needed_block <= 3, \
75 | 'Last possible output block index is 3'
76 |
77 | self.blocks = nn.ModuleList()
78 |
79 | if use_fid_inception:
80 | inception = fid_inception_v3()
81 | else:
82 | inception = _inception_v3(pretrained=True)
83 |
84 | # Block 0: input to maxpool1
85 | block0 = [
86 | inception.Conv2d_1a_3x3,
87 | inception.Conv2d_2a_3x3,
88 | inception.Conv2d_2b_3x3,
89 | nn.MaxPool2d(kernel_size=3, stride=2)
90 | ]
91 | self.blocks.append(nn.Sequential(*block0))
92 |
93 | # Block 1: maxpool1 to maxpool2
94 | if self.last_needed_block >= 1:
95 | block1 = [
96 | inception.Conv2d_3b_1x1,
97 | inception.Conv2d_4a_3x3,
98 | nn.MaxPool2d(kernel_size=3, stride=2)
99 | ]
100 | self.blocks.append(nn.Sequential(*block1))
101 |
102 | # Block 2: maxpool2 to aux classifier
103 | if self.last_needed_block >= 2:
104 | block2 = [
105 | inception.Mixed_5b,
106 | inception.Mixed_5c,
107 | inception.Mixed_5d,
108 | inception.Mixed_6a,
109 | inception.Mixed_6b,
110 | inception.Mixed_6c,
111 | inception.Mixed_6d,
112 | inception.Mixed_6e,
113 | ]
114 | self.blocks.append(nn.Sequential(*block2))
115 |
116 | # Block 3: aux classifier to final avgpool
117 | if self.last_needed_block >= 3:
118 | block3 = [
119 | inception.Mixed_7a,
120 | inception.Mixed_7b,
121 | inception.Mixed_7c,
122 | nn.AdaptiveAvgPool2d(output_size=(1, 1))
123 | ]
124 | self.blocks.append(nn.Sequential(*block3))
125 |
126 | for param in self.parameters():
127 | param.requires_grad = requires_grad
128 |
129 | def forward(self, inp):
130 | """Get Inception feature maps
131 |
132 | Parameters
133 | ----------
134 |     inp : torch.Tensor
135 | Input tensor of shape Bx3xHxW. Values are expected to be in
136 | range (0, 1)
137 |
138 | Returns
139 | -------
140 |     List of torch.Tensor, corresponding to the selected output
141 | block, sorted ascending by index
142 | """
143 | outp = []
144 | x = inp
145 |
146 | if self.resize_input:
147 | x = F.interpolate(x,
148 | size=(299, 299),
149 | mode='bilinear',
150 | align_corners=False)
151 |
152 | if self.normalize_input:
153 | x = 2 * x - 1 # Scale from range (0, 1) to range (-1, 1)
154 |
155 | for idx, block in enumerate(self.blocks):
156 | x = block(x)
157 | if idx in self.output_blocks:
158 | outp.append(x)
159 |
160 | if idx == self.last_needed_block:
161 | break
162 |
163 | return outp
164 |
165 |
166 | def _inception_v3(*args, **kwargs):
167 | """Wraps `torchvision.models.inception_v3`
168 |
169 |     Skips default weight initialization if supported by torchvision version.
170 | See https://github.com/mseitzer/pytorch-fid/issues/28.
171 | """
172 | try:
173 | version = tuple(map(int, torchvision.__version__.split('.')[:2]))
174 | except ValueError:
175 | # Just a caution against weird version strings
176 | version = (0,)
177 |
178 | if version >= (0, 6):
179 | kwargs['init_weights'] = False
180 |
181 | return torchvision.models.inception_v3(*args, **kwargs)
182 |
183 |
184 | def fid_inception_v3():
185 | """Build pretrained Inception model for FID computation
186 |
187 | The Inception model for FID computation uses a different set of weights
188 | and has a slightly different structure than torchvision's Inception.
189 |
190 | This method first constructs torchvision's Inception and then patches the
191 | necessary parts that are different in the FID Inception model.
192 | """
193 | inception = _inception_v3(num_classes=1008,
194 | aux_logits=False,
195 | pretrained=False)
196 | inception.Mixed_5b = FIDInceptionA(192, pool_features=32)
197 | inception.Mixed_5c = FIDInceptionA(256, pool_features=64)
198 | inception.Mixed_5d = FIDInceptionA(288, pool_features=64)
199 | inception.Mixed_6b = FIDInceptionC(768, channels_7x7=128)
200 | inception.Mixed_6c = FIDInceptionC(768, channels_7x7=160)
201 | inception.Mixed_6d = FIDInceptionC(768, channels_7x7=160)
202 | inception.Mixed_6e = FIDInceptionC(768, channels_7x7=192)
203 | inception.Mixed_7b = FIDInceptionE_1(1280)
204 | inception.Mixed_7c = FIDInceptionE_2(2048)
205 |
206 | state_dict = load_state_dict_from_url(FID_WEIGHTS_URL, progress=True)
207 | inception.load_state_dict(state_dict)
208 | return inception
209 |
210 |
211 | class FIDInceptionA(torchvision.models.inception.InceptionA):
212 | """InceptionA block patched for FID computation"""
213 | def __init__(self, in_channels, pool_features):
214 | super(FIDInceptionA, self).__init__(in_channels, pool_features)
215 |
216 | def forward(self, x):
217 | branch1x1 = self.branch1x1(x)
218 |
219 | branch5x5 = self.branch5x5_1(x)
220 | branch5x5 = self.branch5x5_2(branch5x5)
221 |
222 | branch3x3dbl = self.branch3x3dbl_1(x)
223 | branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
224 | branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)
225 |
226 |         # Patch: TensorFlow's average pool does not use the padded zeros in
227 | # its average calculation
228 | branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
229 | count_include_pad=False)
230 | branch_pool = self.branch_pool(branch_pool)
231 |
232 | outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool]
233 | return torch.cat(outputs, 1)
234 |
235 |
236 | class FIDInceptionC(torchvision.models.inception.InceptionC):
237 | """InceptionC block patched for FID computation"""
238 | def __init__(self, in_channels, channels_7x7):
239 | super(FIDInceptionC, self).__init__(in_channels, channels_7x7)
240 |
241 | def forward(self, x):
242 | branch1x1 = self.branch1x1(x)
243 |
244 | branch7x7 = self.branch7x7_1(x)
245 | branch7x7 = self.branch7x7_2(branch7x7)
246 | branch7x7 = self.branch7x7_3(branch7x7)
247 |
248 | branch7x7dbl = self.branch7x7dbl_1(x)
249 | branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl)
250 | branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl)
251 | branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl)
252 | branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl)
253 |
254 |         # Patch: TensorFlow's average pool does not use the padded zeros in
255 | # its average calculation
256 | branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
257 | count_include_pad=False)
258 | branch_pool = self.branch_pool(branch_pool)
259 |
260 | outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool]
261 | return torch.cat(outputs, 1)
262 |
263 |
264 | class FIDInceptionE_1(torchvision.models.inception.InceptionE):
265 | """First InceptionE block patched for FID computation"""
266 | def __init__(self, in_channels):
267 | super(FIDInceptionE_1, self).__init__(in_channels)
268 |
269 | def forward(self, x):
270 | branch1x1 = self.branch1x1(x)
271 |
272 | branch3x3 = self.branch3x3_1(x)
273 | branch3x3 = [
274 | self.branch3x3_2a(branch3x3),
275 | self.branch3x3_2b(branch3x3),
276 | ]
277 | branch3x3 = torch.cat(branch3x3, 1)
278 |
279 | branch3x3dbl = self.branch3x3dbl_1(x)
280 | branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
281 | branch3x3dbl = [
282 | self.branch3x3dbl_3a(branch3x3dbl),
283 | self.branch3x3dbl_3b(branch3x3dbl),
284 | ]
285 | branch3x3dbl = torch.cat(branch3x3dbl, 1)
286 |
287 |         # Patch: TensorFlow's average pool does not use the padded zeros in
288 | # its average calculation
289 | branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
290 | count_include_pad=False)
291 | branch_pool = self.branch_pool(branch_pool)
292 |
293 | outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
294 | return torch.cat(outputs, 1)
295 |
296 |
297 | class FIDInceptionE_2(torchvision.models.inception.InceptionE):
298 | """Second InceptionE block patched for FID computation"""
299 | def __init__(self, in_channels):
300 | super(FIDInceptionE_2, self).__init__(in_channels)
301 |
302 | def forward(self, x):
303 | branch1x1 = self.branch1x1(x)
304 |
305 | branch3x3 = self.branch3x3_1(x)
306 | branch3x3 = [
307 | self.branch3x3_2a(branch3x3),
308 | self.branch3x3_2b(branch3x3),
309 | ]
310 | branch3x3 = torch.cat(branch3x3, 1)
311 |
312 | branch3x3dbl = self.branch3x3dbl_1(x)
313 | branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
314 | branch3x3dbl = [
315 | self.branch3x3dbl_3a(branch3x3dbl),
316 | self.branch3x3dbl_3b(branch3x3dbl),
317 | ]
318 | branch3x3dbl = torch.cat(branch3x3dbl, 1)
319 |
320 | # Patch: The FID Inception model uses max pooling instead of average
321 | # pooling. This is likely an error in this specific Inception
322 | # implementation, as other Inception models use average pooling here
323 | # (which matches the description in the paper).
324 | branch_pool = F.max_pool2d(x, kernel_size=3, stride=1, padding=1)
325 | branch_pool = self.branch_pool(branch_pool)
326 |
327 | outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
328 | return torch.cat(outputs, 1)
329 |
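As a quick usage sketch (assuming the repo root is on `PYTHONPATH`; instantiating the model downloads the FID Inception weights on first use), the 2048-dimensional pool_3 features used for FID can be extracted like this:

```python
import torch
from tools.inception import InceptionV3

# dims=2048 maps to block index 3 (final average pooling), the standard FID choice.
block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[2048]
model = InceptionV3([block_idx]).eval()

with torch.no_grad():
    x = torch.rand(4, 3, 64, 64)         # values in (0, 1); any spatial size works
    feats = model(x)[0]                  # shape: (4, 2048, 1, 1)
    feats = feats.squeeze(3).squeeze(2)  # shape: (4, 2048)
print(feats.shape)
```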
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import sde
2 | import ml_collections
3 | import torch
4 | from torch import multiprocessing as mp
5 | from uvit_datasets import get_dataset
6 | from torchvision.utils import make_grid, save_image
7 | import utils
8 | import einops
9 | from torch.utils._pytree import tree_map
10 | import accelerate
11 | from torch.utils.data import DataLoader
12 | from tqdm.auto import tqdm
13 | from dpm_solver_pytorch import NoiseScheduleVP, model_wrapper, DPM_Solver
14 | import tempfile
15 | from tools.fid_score import calculate_fid_given_paths
16 | from absl import logging
17 | import builtins
18 | import os
19 | import wandb
20 |
21 | import time
22 |
23 | from accelerate import DistributedDataParallelKwargs
24 |
25 | from sde import get_sde
26 | from mamba_attn_diff.utils.backup_code import backup_code
27 |
28 | def train(config):
29 | if config.get('benchmark', False):
30 | torch.backends.cudnn.benchmark = True
31 | torch.backends.cudnn.deterministic = False
32 |
33 | ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
34 | mp.set_start_method('spawn')
35 | accelerator = accelerate.Accelerator() # kwargs_handlers=[ddp_kwargs]
36 | device = accelerator.device
37 | accelerate.utils.set_seed(config.seed, device_specific=True)
38 | logging.info(f'Process {accelerator.process_index} using device: {device}')
39 |
40 | config.mixed_precision = accelerator.mixed_precision
41 | config = ml_collections.FrozenConfigDict(config)
42 |
43 | assert config.train.batch_size % accelerator.num_processes == 0
44 | mini_batch_size = config.train.batch_size // accelerator.num_processes
45 |
46 | if accelerator.is_main_process:
47 | os.makedirs(config.ckpt_root, exist_ok=True)
48 | os.makedirs(config.sample_dir, exist_ok=True)
49 | accelerator.wait_for_everyone()
50 | if accelerator.is_main_process:
51 | workdir = config.workdir
52 |
53 | backup_code(os.path.abspath(workdir))
54 |
55 | wandb.init(dir=os.path.abspath(workdir),
56 | project=f'uvit_{config.dataset.name}', config=config.to_dict(),
57 | name=config.hparams, job_type='train', mode='offline')
58 | utils.set_logger(log_level='info', fname=os.path.join(config.workdir, 'output.log'))
59 | logging.info(config)
60 | else:
61 | pass
62 |
63 | dataset = get_dataset(**config.dataset)
64 | assert os.path.exists(dataset.fid_stat)
65 | train_dataset = dataset.get_split(split='train', labeled=config.train.mode == 'cond')
66 | train_dataset_loader = DataLoader(train_dataset, batch_size=mini_batch_size, shuffle=True, drop_last=True,
67 | num_workers=8, pin_memory=True, persistent_workers=True)
68 |
69 | train_state = utils.initialize_train_state(config, device)
70 | if hasattr(train_state.nnet, 'enable_gradient_checkpointing') and config.get('gradient_checkpointing', False):
71 | train_state.nnet.enable_gradient_checkpointing()
72 |
73 | nnet, nnet_ema, optimizer, train_dataset_loader = accelerator.prepare(
74 | train_state.nnet, train_state.nnet_ema, train_state.optimizer, train_dataset_loader)
75 | lr_scheduler = train_state.lr_scheduler
76 | train_state.resume(config.ckpt_root)
77 |
78 |
79 | def get_data_generator():
80 | while True:
81 | for data in tqdm(train_dataset_loader, disable=not accelerator.is_main_process, desc='epoch'):
82 | yield data
83 |
84 | data_generator = get_data_generator()
85 |
86 |
87 |     # build the noise schedulers and wrap the networks as score models
88 | scheduler_config = dict(name='vpsde') if not config.get('scheduler', False) else config.scheduler
89 | scheduler = get_sde(device=device, **scheduler_config) # sde.VPSDE()
90 | scheduler_ema = get_sde(device=device, **scheduler_config)
91 | score_model = sde.ScoreModel(nnet, pred=config.pred, sde=scheduler)
92 | score_model_ema = sde.ScoreModel(nnet_ema, pred=config.pred, sde=scheduler_ema)
93 |
94 |
95 | def train_step(_batch, iter_rate):
96 | _metrics = dict()
97 | optimizer.zero_grad()
98 | if config.train.mode == 'uncond':
99 | loss = sde.LSimple(score_model, _batch, pred=config.pred, iter_rate=iter_rate)
100 | elif config.train.mode == 'cond':
101 | loss = sde.LSimple(score_model, _batch[0], pred=config.pred, y=_batch[1], iter_rate=iter_rate)
102 | else:
103 | raise NotImplementedError(config.train.mode)
104 | _metrics['loss'] = accelerator.gather(loss.detach()).mean()
105 | accelerator.backward(loss.mean())
106 | if 'grad_clip' in config and config.grad_clip > 0:
107 | accelerator.clip_grad_norm_(nnet.parameters(), max_norm=config.grad_clip)
108 | optimizer.step()
109 | lr_scheduler.step()
110 | train_state.ema_update(config.get('ema_rate', 0.9999))
111 | train_state.step += 1
112 | return dict(lr=train_state.optimizer.param_groups[0]['lr'], **_metrics)
113 |
114 |
115 | def eval_step(n_samples, sample_steps, algorithm):
116 | logging.info(f'eval_step: n_samples={n_samples}, sample_steps={sample_steps}, algorithm={algorithm}, '
117 | f'mini_batch_size={config.sample.mini_batch_size}')
118 |
119 | def sample_fn(_n_samples):
120 | _x_init = torch.randn(_n_samples, *dataset.data_shape, device=device)
121 | if config.train.mode == 'uncond':
122 | kwargs = dict()
123 | elif config.train.mode == 'cond':
124 | kwargs = dict(y=dataset.sample_label(_n_samples, device=device))
125 | else:
126 | raise NotImplementedError
127 |
128 | if algorithm == 'euler_maruyama_sde':
129 | return sde.euler_maruyama(sde.ReverseSDE(score_model_ema), _x_init, sample_steps, **kwargs)
130 | elif algorithm == 'euler_maruyama_ode':
131 | return sde.euler_maruyama(sde.ODE(score_model_ema), _x_init, sample_steps, **kwargs)
132 | elif algorithm == 'dpm_solver':
133 | noise_schedule = NoiseScheduleVP(schedule='linear')
134 | model_fn = model_wrapper(
135 | score_model_ema.noise_pred,
136 | noise_schedule,
137 | time_input_type='0',
138 | model_kwargs=kwargs
139 | )
140 | dpm_solver = DPM_Solver(model_fn, noise_schedule)
141 | return dpm_solver.sample(
142 | _x_init,
143 | steps=sample_steps,
144 | eps=1e-4,
145 | adaptive_step_size=False,
146 | fast_version=True,
147 | )
148 | elif algorithm in ['edm', 'ddim', 'ddpm']:
149 | return sde.diffusers_denoising(
150 | score_model_ema, scheduler.noise_scheduler, _x_init,
151 | config.sample.sample_steps,
152 | device=device, **kwargs)
153 | else:
154 | raise NotImplementedError
155 |
156 | with tempfile.TemporaryDirectory() as temp_path:
157 | path = config.sample.path or temp_path
158 | if accelerator.is_main_process:
159 | os.makedirs(path, exist_ok=True)
160 | utils.sample2dir(accelerator, path, n_samples, config.sample.mini_batch_size, sample_fn, dataset.unpreprocess)
161 |
162 | _fid = 0
163 | if accelerator.is_main_process:
164 | _fid = calculate_fid_given_paths((dataset.fid_stat, path))
165 | logging.info(f'step={train_state.step} fid{n_samples}={_fid}')
166 | with open(os.path.join(config.workdir, 'eval.log'), 'a') as f:
167 | print(f'step={train_state.step} fid{n_samples}={_fid}', file=f)
168 | wandb.log({f'fid{n_samples}': _fid}, step=train_state.step)
169 | _fid = torch.tensor(_fid, device=device)
170 | _fid = accelerator.reduce(_fid, reduction='sum')
171 |
172 | return _fid.item()
173 |
174 | logging.info(f'Start fitting, step={train_state.step}, mixed_precision={config.mixed_precision}')
175 |
176 | step_fid = []
177 | while train_state.step < config.train.n_steps:
178 | nnet.train()
179 | batch = tree_map(lambda x: x.to(device), next(data_generator))
180 |         metrics = train_step(batch, train_state.step * 1. / config.train.n_steps)
181 |
182 | nnet.eval()
183 | if accelerator.is_main_process and train_state.step % config.train.log_interval == 0:
184 | logging.info(utils.dct2str(dict(step=train_state.step, **metrics)))
185 | logging.info(config.workdir)
186 | wandb.log(metrics, step=train_state.step)
187 |
188 | if train_state.step % config.train.eval_interval == 0:
189 | # torch.cuda.empty_cache()
190 | if accelerator.is_main_process:
191 | logging.info('Save a grid of images...')
192 |
193 | with torch.no_grad():
194 | x_init = torch.randn(100, *dataset.data_shape, device=device)
195 | if config.train.mode == 'cond':
196 | y = einops.repeat(torch.arange(10, device=device) % dataset.K, 'nrow -> (nrow ncol)', ncol=10)
197 |
198 | if config.sample.algorithm in ['edm', 'ddim', 'ddpm']:
199 | if config.train.mode == 'uncond':
200 | samples = sde.diffusers_denoising(
201 | score_model_ema, scheduler.noise_scheduler, x_init=x_init,
202 | sample_steps=config.sample.sample_steps, device=device,)
203 | elif config.train.mode == 'cond':
204 | samples = sde.diffusers_denoising(
205 | score_model_ema, scheduler.noise_scheduler, x_init=x_init,
206 | sample_steps=config.sample.sample_steps, device=device, y=y,
207 | do_classifier_free_guidance=True, cfg_weight=1.5)
208 | else:
209 | if config.train.mode == 'uncond':
210 | samples = sde.euler_maruyama(sde.ODE(score_model_ema), x_init=x_init, sample_steps=50)
211 | elif config.train.mode == 'cond':
212 | samples = sde.euler_maruyama(sde.ODE(score_model_ema), x_init=x_init, sample_steps=50, y=y)
213 | else:
214 | raise NotImplementedError
215 |
216 | if accelerator.is_main_process:
217 | samples = make_grid(dataset.unpreprocess(samples), 10)
218 | save_image(samples, os.path.join(config.sample_dir, f'{train_state.step}.png'))
219 | wandb.log({'samples': wandb.Image(samples)}, step=train_state.step)
220 | else:
221 | pass
222 |
223 | accelerator.wait_for_everyone()
224 |
225 | if train_state.step % config.train.save_interval == 0 or train_state.step == config.train.n_steps:
226 | logging.info(f'Save and eval checkpoint {train_state.step}...')
227 | if accelerator.local_process_index == 0:
228 | train_state.save(os.path.join(config.ckpt_root, f'{train_state.step}.ckpt'))
229 | accelerator.wait_for_everyone()
230 | fid = eval_step(n_samples=10000, sample_steps=50,
231 |                 algorithm=(config.sample.algorithm if config.get('scheduler', False) else 'dpm_solver'))  # calculate FID of the saved checkpoint
232 | step_fid.append((train_state.step, fid))
233 |
234 | accelerator.wait_for_everyone()
235 |
236 | logging.info(f'Finish fitting, step={train_state.step}')
237 | logging.info(f'step_fid: {step_fid}')
238 | step_best = sorted(step_fid, key=lambda x: x[1])[0][0]
239 | logging.info(f'step_best: {step_best}')
240 | train_state.load(os.path.join(config.ckpt_root, f'{step_best}.ckpt'))
241 | del metrics
242 | accelerator.wait_for_everyone()
243 | eval_step(n_samples=config.sample.n_samples, sample_steps=config.sample.sample_steps, algorithm=config.sample.algorithm)
244 |
245 |
246 |
247 | from absl import flags
248 | from absl import app
249 | from ml_collections import config_flags
250 | import sys
251 | from pathlib import Path
252 |
253 |
254 | FLAGS = flags.FLAGS
255 | config_flags.DEFINE_config_file(
256 | "config", None, "Training configuration.", lock_config=False)
257 | flags.mark_flags_as_required(["config"])
258 | flags.DEFINE_string("workdir", None, "Work unit directory.")
259 |
260 |
261 | def get_config_name():
262 | argv = sys.argv
263 | for i in range(1, len(argv)):
264 | if argv[i].startswith('--config='):
265 | return Path(argv[i].split('=')[-1]).stem
266 |
267 |
268 | def get_hparams():
269 | argv = sys.argv
270 | lst = []
271 | for i in range(1, len(argv)):
272 | assert '=' in argv[i]
273 | if argv[i].startswith('--config.') and not argv[i].startswith('--config.dataset.path'):
274 | hparam, val = argv[i].split('=')
275 | hparam = hparam.split('.')[-1]
276 | if hparam.endswith('path'):
277 | val = Path(val).stem
278 | lst.append(f'{hparam}={val}')
279 | hparams = '-'.join(lst)
280 | if hparams == '':
281 | hparams = 'default'
282 | return hparams
283 |
284 |
285 | def main(argv):
286 | config = FLAGS.config
287 | config.config_name = get_config_name()
288 | config.hparams = get_hparams()
289 | config.workdir = FLAGS.workdir or os.path.join('workdir', config.config_name, config.hparams)
290 | config.ckpt_root = os.path.join(config.workdir, 'ckpts')
291 | config.sample_dir = os.path.join(config.workdir, 'samples')
292 | train(config)
293 |
294 |
295 | if __name__ == "__main__":
296 | app.run(main)
297 |
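The run name and default `workdir` are derived purely from the command line. Below is a minimal self-contained sketch mirroring the logic of `get_hparams()` above; `hparams_from_argv` is an illustrative name, not part of the repo:

```python
from pathlib import Path

def hparams_from_argv(argv):
    # Mirrors get_hparams(): keep only --config.* overrides (except the
    # dataset path), drop the prefix, and join them with '-'.
    lst = []
    for arg in argv[1:]:
        if arg.startswith('--config.') and not arg.startswith('--config.dataset.path'):
            hparam, val = arg.split('=')
            hparam = hparam.split('.')[-1]
            if hparam.endswith('path'):
                val = Path(val).stem
            lst.append(f'{hparam}={val}')
    return '-'.join(lst) or 'default'

argv = ['train.py', '--config=configs/cifar10_S_DiM.py',
        '--config.train.batch_size=128', '--config.seed=42']
print(hparams_from_argv(argv))  # -> batch_size=128-seed=42
# Default workdir: workdir/cifar10_S_DiM/batch_size=128-seed=42
```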
--------------------------------------------------------------------------------