├── mamba ├── mamba_ssm │ ├── ops │ │ ├── __init__.py │ │ └── triton │ │ │ ├── __init__.py │ │ │ └── selective_state_update.py │ ├── utils │ │ ├── __init__.py │ │ └── hf.py │ ├── models │ │ ├── __init__.py │ │ └── mixer_seq_simple.py │ ├── modules │ │ └── __init__.py │ └── __init__.py ├── AUTHORS ├── assets │ └── selection.png ├── csrc │ └── selective_scan │ │ ├── selective_scan_bwd_fp32_real.cu │ │ ├── selective_scan_bwd_fp16_real.cu │ │ ├── selective_scan_bwd_bf16_real.cu │ │ ├── selective_scan_bwd_fp16_complex.cu │ │ ├── selective_scan_bwd_fp32_complex.cu │ │ ├── selective_scan_bwd_bf16_complex.cu │ │ ├── selective_scan_fwd_fp32.cu │ │ ├── selective_scan_fwd_fp16.cu │ │ ├── selective_scan_fwd_bf16.cu │ │ ├── static_switch.h │ │ ├── uninitialized_copy.cuh │ │ ├── selective_scan.h │ │ └── selective_scan_common.h ├── test_mamba_module.py ├── evals │ └── lm_harness_eval.py ├── tests │ └── ops │ │ └── triton │ │ └── test_selective_state_update.py ├── benchmarks │ └── benchmark_generation_mamba_simple.py ├── README.md ├── setup.py └── LICENSE ├── causal-conv1d ├── AUTHORS ├── README.md ├── causal_conv1d │ ├── __init__.py │ └── causal_conv1d_interface.py ├── csrc │ ├── static_switch.h │ ├── causal_conv1d.h │ ├── causal_conv1d_common.h │ ├── causal_conv1d_update.cu │ └── causal_conv1d.cpp ├── LICENSE ├── tests │ └── test_causal_conv1d.py └── setup.py ├── tools ├── __pycache__ │ ├── constants.cpython-310.pyc │ ├── dataset.cpython-310.pyc │ ├── fid_score.cpython-310.pyc │ ├── inception.cpython-310.pyc │ └── webdataset.cpython-310.pyc ├── webdataset.py ├── dataset.py ├── fid_score.py └── inception.py ├── diffusion ├── __pycache__ │ ├── respace.cpython-310.pyc │ ├── __init__.cpython-310.pyc │ ├── diffusion_utils.cpython-310.pyc │ └── gaussian_diffusion.cpython-310.pyc ├── __init__.py ├── diffusion_utils.py ├── respace.py └── timestep_sampler.py ├── scripts ├── train_celeba_middle.sh ├── train_wds_huge_256.sh └── train_imagenet_huge_256.sh ├── parallel_train.sh ├── requirements.txt ├── clip.py ├── README.md ├── sample.py ├── t5.py └── test.py /mamba/mamba_ssm/ops/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /mamba/mamba_ssm/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /mamba/mamba_ssm/models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /mamba/mamba_ssm/modules/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /mamba/mamba_ssm/ops/triton/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /causal-conv1d/AUTHORS: -------------------------------------------------------------------------------- 1 | Tri Dao, tri@tridao.me 2 | -------------------------------------------------------------------------------- /mamba/AUTHORS: -------------------------------------------------------------------------------- 1 | Tri Dao, tri@tridao.me 2 | Albert Gu, agu@andrew.cmu.edu 3 | -------------------------------------------------------------------------------- 
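The tree above shows two vendored CUDA extensions (`causal-conv1d` and `mamba`) alongside the diffusion training and sampling code. As a quick orientation, here is a minimal sketch of the import surface these packages expose once installed; it only uses names that appear in the `__init__.py` files and scripts reproduced below, and it assumes both extensions have been built with `pip install -e` as described in the README.

```python
from causal_conv1d import causal_conv1d_fn, causal_conv1d_update   # fused causal depthwise conv kernels
from mamba_ssm import Mamba                                         # selective state space block
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel      # language-model wrapper around Mamba
from diffusion import create_diffusion                               # Gaussian diffusion factory used by train.py / sample.py

# The diffusion object is built from a respacing string, e.g. 250 sampling
# steps taken out of the 1000-step training schedule (as done in sample.py).
diffusion = create_diffusion(timestep_respacing="250")
```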
/causal-conv1d/README.md: -------------------------------------------------------------------------------- 1 | # Causal depthwise conv1d in CUDA with a PyTorch interface 2 | -------------------------------------------------------------------------------- /mamba/assets/selection.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/feizc/Vespa/HEAD/mamba/assets/selection.png -------------------------------------------------------------------------------- /tools/__pycache__/constants.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/feizc/Vespa/HEAD/tools/__pycache__/constants.cpython-310.pyc -------------------------------------------------------------------------------- /tools/__pycache__/dataset.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/feizc/Vespa/HEAD/tools/__pycache__/dataset.cpython-310.pyc -------------------------------------------------------------------------------- /tools/__pycache__/fid_score.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/feizc/Vespa/HEAD/tools/__pycache__/fid_score.cpython-310.pyc -------------------------------------------------------------------------------- /tools/__pycache__/inception.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/feizc/Vespa/HEAD/tools/__pycache__/inception.cpython-310.pyc -------------------------------------------------------------------------------- /diffusion/__pycache__/respace.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/feizc/Vespa/HEAD/diffusion/__pycache__/respace.cpython-310.pyc -------------------------------------------------------------------------------- /tools/__pycache__/webdataset.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/feizc/Vespa/HEAD/tools/__pycache__/webdataset.cpython-310.pyc -------------------------------------------------------------------------------- /diffusion/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/feizc/Vespa/HEAD/diffusion/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /diffusion/__pycache__/diffusion_utils.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/feizc/Vespa/HEAD/diffusion/__pycache__/diffusion_utils.cpython-310.pyc -------------------------------------------------------------------------------- /causal-conv1d/causal_conv1d/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = "1.0.0" 2 | 3 | from causal_conv1d.causal_conv1d_interface import causal_conv1d_fn, causal_conv1d_update 4 | -------------------------------------------------------------------------------- /diffusion/__pycache__/gaussian_diffusion.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/feizc/Vespa/HEAD/diffusion/__pycache__/gaussian_diffusion.cpython-310.pyc 
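The `causal_conv1d` package above exposes two entry points: `causal_conv1d_fn` for processing a whole sequence and `causal_conv1d_update` for single-step decoding against a rolling state. The sketch below is a hedged usage example, with tensor shapes taken from the interface docstrings further down in this dump; it assumes the CUDA extension has been compiled and a GPU is available.

```python
import torch
from causal_conv1d import causal_conv1d_fn, causal_conv1d_update

batch, dim, seqlen, width = 2, 64, 128, 4
x = torch.randn(batch, dim, seqlen, device="cuda", dtype=torch.float16)
weight = torch.randn(dim, width, device="cuda", dtype=torch.float16)  # one filter per channel (depthwise)
bias = torch.randn(dim, device="cuda", dtype=torch.float16)

# Full-sequence path: causal (left-padded) depthwise conv1d, optionally fused with SiLU.
out = causal_conv1d_fn(x, weight, bias, activation="silu")            # (batch, dim, seqlen)

# Incremental decoding path: conv_state holds the last `width` inputs per channel
# and is updated in place as one new timestep is fed at a time.
conv_state = torch.zeros(batch, dim, width, device="cuda", dtype=torch.float16)
x_t = torch.randn(batch, dim, device="cuda", dtype=torch.float16)
out_t = causal_conv1d_update(x_t, conv_state, weight, bias, activation="silu")  # (batch, dim)
```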
-------------------------------------------------------------------------------- /mamba/mamba_ssm/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = "1.0.1" 2 | 3 | from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn, bimamba_inner_fn 4 | from mamba_ssm.modules.mamba_simple import Mamba 5 | from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel 6 | -------------------------------------------------------------------------------- /mamba/csrc/selective_scan/selective_scan_bwd_fp32_real.cu: -------------------------------------------------------------------------------- 1 | /****************************************************************************** 2 | * Copyright (c) 2023, Tri Dao. 3 | ******************************************************************************/ 4 | 5 | // Split into multiple files to compile in paralell 6 | 7 | #include "selective_scan_bwd_kernel.cuh" 8 | 9 | template void selective_scan_bwd_cuda(SSMParamsBwd ¶ms, cudaStream_t stream); -------------------------------------------------------------------------------- /mamba/csrc/selective_scan/selective_scan_bwd_fp16_real.cu: -------------------------------------------------------------------------------- 1 | /****************************************************************************** 2 | * Copyright (c) 2023, Tri Dao. 3 | ******************************************************************************/ 4 | 5 | // Split into multiple files to compile in paralell 6 | 7 | #include "selective_scan_bwd_kernel.cuh" 8 | 9 | template void selective_scan_bwd_cuda(SSMParamsBwd ¶ms, cudaStream_t stream); -------------------------------------------------------------------------------- /mamba/csrc/selective_scan/selective_scan_bwd_bf16_real.cu: -------------------------------------------------------------------------------- 1 | /****************************************************************************** 2 | * Copyright (c) 2023, Tri Dao. 3 | ******************************************************************************/ 4 | 5 | // Split into multiple files to compile in paralell 6 | 7 | #include "selective_scan_bwd_kernel.cuh" 8 | 9 | template void selective_scan_bwd_cuda(SSMParamsBwd ¶ms, cudaStream_t stream); -------------------------------------------------------------------------------- /mamba/csrc/selective_scan/selective_scan_bwd_fp16_complex.cu: -------------------------------------------------------------------------------- 1 | /****************************************************************************** 2 | * Copyright (c) 2023, Tri Dao. 3 | ******************************************************************************/ 4 | 5 | // Split into multiple files to compile in paralell 6 | 7 | #include "selective_scan_bwd_kernel.cuh" 8 | 9 | template void selective_scan_bwd_cuda(SSMParamsBwd ¶ms, cudaStream_t stream); -------------------------------------------------------------------------------- /mamba/csrc/selective_scan/selective_scan_bwd_fp32_complex.cu: -------------------------------------------------------------------------------- 1 | /****************************************************************************** 2 | * Copyright (c) 2023, Tri Dao. 
3 | ******************************************************************************/ 4 | 5 | // Split into multiple files to compile in paralell 6 | 7 | #include "selective_scan_bwd_kernel.cuh" 8 | 9 | template void selective_scan_bwd_cuda(SSMParamsBwd ¶ms, cudaStream_t stream); -------------------------------------------------------------------------------- /mamba/csrc/selective_scan/selective_scan_bwd_bf16_complex.cu: -------------------------------------------------------------------------------- 1 | /****************************************************************************** 2 | * Copyright (c) 2023, Tri Dao. 3 | ******************************************************************************/ 4 | 5 | // Split into multiple files to compile in paralell 6 | 7 | #include "selective_scan_bwd_kernel.cuh" 8 | 9 | template void selective_scan_bwd_cuda(SSMParamsBwd ¶ms, cudaStream_t stream); -------------------------------------------------------------------------------- /mamba/test_mamba_module.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from mamba_ssm import Mamba 3 | 4 | batch, length, dim = 2, 64, 768 5 | x = torch.randn(batch, length, dim).to("cuda") 6 | model = Mamba( 7 | # This module uses roughly 3 * expand * d_model^2 parameters 8 | d_model=dim, # Model dimension d_model 9 | d_state=16, # SSM state expansion factor # 64 10 | d_conv=4, # Local convolution width 11 | expand=2, # Block expansion factor 12 | use_fast_path=False, 13 | ).to("cuda") 14 | y = model(x) 15 | assert y.shape == x.shape 16 | -------------------------------------------------------------------------------- /scripts/train_celeba_middle.sh: -------------------------------------------------------------------------------- 1 | torchrun --nnodes=1 --nproc_per_node=8 train.py \ 2 | --model DRWKV-L/2 \ 3 | --dataset-type celeba \ 4 | --data-path /TrainData/Multimodal/zhengcong.fei/dis/data/CelebA \ 5 | --image-size 64 \ 6 | --resize-only True \ 7 | --global-batch-size 128 \ 8 | --accum_iter 8 \ 9 | --epochs 30 \ 10 | --lr 1e-5 \ 11 | --warmup_epochs 0 \ 12 | --eval_steps 2000 \ 13 | --ckpt-every 4000 \ 14 | --global-seed 12433 \ 15 | --resume /TrainData/Multimodal/zhengcong.fei/diff-rwkv/results/DRWKV-M-2-celeba-uncond-64/checkpoints/0044000.pt -------------------------------------------------------------------------------- /mamba/csrc/selective_scan/selective_scan_fwd_fp32.cu: -------------------------------------------------------------------------------- 1 | /****************************************************************************** 2 | * Copyright (c) 2023, Tri Dao. 3 | ******************************************************************************/ 4 | 5 | // Split into multiple files to compile in paralell 6 | 7 | #include "selective_scan_fwd_kernel.cuh" 8 | 9 | template void selective_scan_fwd_cuda(SSMParamsBase ¶ms, cudaStream_t stream); 10 | template void selective_scan_fwd_cuda(SSMParamsBase ¶ms, cudaStream_t stream); -------------------------------------------------------------------------------- /mamba/csrc/selective_scan/selective_scan_fwd_fp16.cu: -------------------------------------------------------------------------------- 1 | /****************************************************************************** 2 | * Copyright (c) 2023, Tri Dao. 
3 | ******************************************************************************/ 4 | 5 | // Split into multiple files to compile in paralell 6 | 7 | #include "selective_scan_fwd_kernel.cuh" 8 | 9 | template void selective_scan_fwd_cuda(SSMParamsBase ¶ms, cudaStream_t stream); 10 | template void selective_scan_fwd_cuda(SSMParamsBase ¶ms, cudaStream_t stream); -------------------------------------------------------------------------------- /mamba/csrc/selective_scan/selective_scan_fwd_bf16.cu: -------------------------------------------------------------------------------- 1 | /****************************************************************************** 2 | * Copyright (c) 2023, Tri Dao. 3 | ******************************************************************************/ 4 | 5 | // Split into multiple files to compile in paralell 6 | 7 | #include "selective_scan_fwd_kernel.cuh" 8 | 9 | template void selective_scan_fwd_cuda(SSMParamsBase ¶ms, cudaStream_t stream); 10 | template void selective_scan_fwd_cuda(SSMParamsBase ¶ms, cudaStream_t stream); -------------------------------------------------------------------------------- /parallel_train.sh: -------------------------------------------------------------------------------- 1 | torchrun --nnodes=2 --nproc_per_node=8 --node_rank $RANK --rdzv-id=vespa \ 2 | --rdzv-endpoint=10.0.13.24:12345 \ 3 | accelerate_train.py \ 4 | --model VeSpa-H/2 \ 5 | --dataset-type wds \ 6 | --image-only True \ 7 | --data-path /maindata/data/shared/multimodal/public/dataset_gen/mj580w_wds2 \ 8 | --anna-path /maindata/data/shared/multimodal/public/dataset_gen/mj580w_wds2 \ 9 | --image-size 512 \ 10 | --text_encoder_type t5 \ 11 | --global-batch-size 16 \ 12 | --epochs 20 \ 13 | --warmup_epochs 0 \ 14 | --accum_iter 8 \ 15 | --eval_steps 10 \ 16 | --lr 1e-4 \ 17 | --latent_space True \ 18 | --global-seed 42 19 | -------------------------------------------------------------------------------- /scripts/train_wds_huge_256.sh: -------------------------------------------------------------------------------- 1 | torchrun --nnodes=1 --nproc_per_node=8 train.py \ 2 | --model VeSpa-H/2 \ 3 | --dataset-type wds \ 4 | --image-only True \ 5 | --data-path /maindata/data/shared/multimodal/public/dataset_gen/mj580w_wds2 \ 6 | --anna-path /maindata/data/shared/multimodal/public/dataset_gen/mj580w_wds2 \ 7 | --image-size 512 \ 8 | --text_encoder_type t5 \ 9 | --global-batch-size 32 \ 10 | --epochs 20 \ 11 | --warmup_epochs 0 \ 12 | --accum_iter 8 \ 13 | --eval_steps 10000000 \ 14 | --lr 1e-4 \ 15 | --latent_space True \ 16 | --global-seed 45 \ 17 | --resume /maindata/data/shared/multimodal/zhengcong.fei/code/vespa/results/VeSpa-H-2-wds-image-True/checkpoints/ckpt2.pt -------------------------------------------------------------------------------- /scripts/train_imagenet_huge_256.sh: -------------------------------------------------------------------------------- 1 | torchrun --master_port 2995 --nnodes=1 --nproc_per_node=8 train.py \ 2 | --model VeSpa-H/2 \ 3 | --dataset-type imagenet \ 4 | --image-only True \ 5 | --data-path /maindata/data/shared/multimodal/zhengcong.fei/code/vespa/data/imagenet_tag.json \ 6 | --anna-path /maindata/data/shared/multimodal/zhengcong.fei/code/vespa/data/imagenet_tag.json \ 7 | --image-size 256 \ 8 | --text_encoder_type t5 \ 9 | --global-batch-size 64 \ 10 | --epochs 20 \ 11 | --warmup_epochs 0 \ 12 | --accum_iter 8 \ 13 | --eval_steps 10000000 \ 14 | --lr 1e-4 \ 15 | --latent_space True \ 16 | --global-seed 520 \ 17 | --resume 
/maindata/data/shared/multimodal/zhengcong.fei/code/vespa/results/VeSpa-H-2-imagenet-image-True/checkpoints/ckpt2.pt -------------------------------------------------------------------------------- /mamba/mamba_ssm/utils/hf.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | import torch 4 | 5 | from transformers.utils import WEIGHTS_NAME, CONFIG_NAME 6 | from transformers.utils.hub import cached_file 7 | 8 | 9 | def load_config_hf(model_name): 10 | resolved_archive_file = cached_file(model_name, CONFIG_NAME, _raise_exceptions_for_missing_entries=False) 11 | return json.load(open(resolved_archive_file)) 12 | 13 | 14 | def load_state_dict_hf(model_name, device=None, dtype=None): 15 | # If not fp32, then we don't want to load directly to the GPU 16 | mapped_device = "cpu" if dtype not in [torch.float32, None] else device 17 | resolved_archive_file = cached_file(model_name, WEIGHTS_NAME, _raise_exceptions_for_missing_entries=False) 18 | return torch.load(resolved_archive_file, map_location=mapped_device) 19 | # Convert dtype before moving to GPU to save memory 20 | if dtype is not None: 21 | state_dict = {k: v.to(dtype=dtype) for k, v in state_dict.items()} 22 | state_dict = {k: v.to(device=device) for k, v in state_dict.items()} 23 | return state_dict 24 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==1.13.1 2 | torchvision==0.14.1 3 | accelerate==0.21.0 4 | aiohttp==3.8.4 5 | aiosignal==1.3.1 6 | async-timeout==4.0.2 7 | attrs==22.2.0 8 | bitsandbytes==0.37.0 9 | cchardet==2.1.7 10 | chardet==5.1.0 11 | contourpy==1.0.7 12 | cycler==0.11.0 13 | filelock==3.9.0 14 | fonttools==4.38.0 15 | frozenlist==1.3.3 16 | huggingface-hub==0.16.4 17 | importlib-resources==5.12.0 18 | kiwisolver==1.4.4 19 | matplotlib==3.7.0 20 | multidict==6.0.4 21 | openai==0.27.0 22 | packaging==23.0 23 | psutil==5.9.4 24 | pycocotools==2.0.6 25 | pyparsing==3.0.9 26 | python-dateutil==2.8.2 27 | pyyaml==6.0 28 | regex==2022.10.31 29 | tokenizers 30 | triton==2.1.0 31 | tqdm==4.64.1 32 | transformers==4.34.0 33 | timm==0.6.13 34 | torchaudio==0.13.1 35 | spacy==3.5.1 36 | webdataset==0.2.48 37 | scikit-learn==1.2.2 38 | scipy==1.10.1 39 | yarl==1.8.2 40 | zipp==3.14.0 41 | omegaconf==2.3.0 42 | opencv-python==4.7.0.72 43 | iopath==0.1.10 44 | decord==0.6.0 45 | tenacity==8.2.2 46 | peft 47 | pycocoevalcap 48 | sentence-transformers 49 | umap-learn 50 | notebook 51 | gradio==3.24.1 52 | gradio-client==0.0.8 53 | wandb 54 | opencv-python-headless 55 | datasets 56 | pytorch-lightning 57 | deepspeed==0.10.0 58 | einops 59 | tensorboard 60 | xformers 61 | apex 62 | einops_exts 63 | jsonlines 64 | pyclipper 65 | onnxruntime 66 | shapely 67 | zhconv 68 | paddleocr 69 | paddlepaddle -------------------------------------------------------------------------------- /causal-conv1d/csrc/static_switch.h: -------------------------------------------------------------------------------- 1 | // Inspired by https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h 2 | // and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h 3 | 4 | #pragma once 5 | 6 | /// @param COND - a boolean expression to switch by 7 | /// @param CONST_NAME - a name given for the constexpr bool variable. 8 | /// @param ... 
- code to execute for true and false 9 | /// 10 | /// Usage: 11 | /// ``` 12 | /// BOOL_SWITCH(flag, BoolConst, [&] { 13 | /// some_function(...); 14 | /// }); 15 | /// ``` 16 | #define BOOL_SWITCH(COND, CONST_NAME, ...) \ 17 | [&] { \ 18 | if (COND) { \ 19 | static constexpr bool CONST_NAME = true; \ 20 | return __VA_ARGS__(); \ 21 | } else { \ 22 | static constexpr bool CONST_NAME = false; \ 23 | return __VA_ARGS__(); \ 24 | } \ 25 | }() 26 | -------------------------------------------------------------------------------- /mamba/csrc/selective_scan/static_switch.h: -------------------------------------------------------------------------------- 1 | // Inspired by https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h 2 | // and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h 3 | 4 | #pragma once 5 | 6 | /// @param COND - a boolean expression to switch by 7 | /// @param CONST_NAME - a name given for the constexpr bool variable. 8 | /// @param ... - code to execute for true and false 9 | /// 10 | /// Usage: 11 | /// ``` 12 | /// BOOL_SWITCH(flag, BoolConst, [&] { 13 | /// some_function(...); 14 | /// }); 15 | /// ``` 16 | #define BOOL_SWITCH(COND, CONST_NAME, ...) \ 17 | [&] { \ 18 | if (COND) { \ 19 | constexpr bool CONST_NAME = true; \ 20 | return __VA_ARGS__(); \ 21 | } else { \ 22 | constexpr bool CONST_NAME = false; \ 23 | return __VA_ARGS__(); \ 24 | } \ 25 | }() 26 | -------------------------------------------------------------------------------- /mamba/evals/lm_harness_eval.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | import transformers 4 | from transformers import AutoTokenizer 5 | 6 | from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel 7 | 8 | from lm_eval.api.model import LM 9 | from lm_eval.models.huggingface import HFLM 10 | from lm_eval.api.registry import register_model 11 | from lm_eval.__main__ import cli_evaluate 12 | 13 | 14 | @register_model("mamba") 15 | class MambaEvalWrapper(HFLM): 16 | 17 | AUTO_MODEL_CLASS = transformers.AutoModelForCausalLM 18 | 19 | def __init__(self, pretrained="state-spaces/mamba-2.8b", max_length=2048, batch_size=None, device="cuda", 20 | dtype=torch.float16): 21 | LM.__init__(self) 22 | self._model = MambaLMHeadModel.from_pretrained(pretrained, device=device, dtype=dtype) 23 | self.tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b") 24 | self.tokenizer.pad_token_id = self.tokenizer.eos_token_id 25 | self.vocab_size = self.tokenizer.vocab_size 26 | self._batch_size = batch_size if batch_size is None else 64 27 | self._max_length = max_length 28 | self._device = torch.device(device) 29 | 30 | @property 31 | def batch_size(self): 32 | return self._batch_size 33 | 34 | def _model_generate(self, context, max_length, stop, **generation_kwargs): 35 | raise NotImplementedError() 36 | 37 | 38 | if __name__ == "__main__": 39 | cli_evaluate() 40 | -------------------------------------------------------------------------------- /clip.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch.nn as nn 3 | from transformers import CLIPTokenizer, CLIPTextModel 4 | 5 | 6 | class AbstractEncoder(nn.Module): 7 | def __init__(self): 8 | super().__init__() 9 | 10 | def encode(self, *args, **kwargs): 11 | raise NotImplementedError 12 | 13 | 14 | class FrozenCLIPEmbedder(AbstractEncoder): 15 | """Uses the CLIP transformer encoder for text (from Hugging Face)""" 16 | def 
__init__(self, path, device="cuda", max_length=77): 17 | super().__init__() 18 | self.tokenizer = CLIPTokenizer.from_pretrained(os.path.join(path, 'tokenizer')) 19 | self.transformer = CLIPTextModel.from_pretrained(os.path.join(path, 'text_encoder')).to(device) 20 | 21 | self.device = device 22 | self.max_length = max_length 23 | self.freeze() 24 | 25 | def freeze(self): 26 | self.transformer = self.transformer.eval() 27 | for param in self.parameters(): 28 | param.requires_grad = False 29 | 30 | def forward(self, text): 31 | batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True, 32 | return_overflowing_tokens=False, padding="max_length", return_tensors="pt") 33 | tokens = batch_encoding["input_ids"].to(self.device) 34 | outputs = self.transformer(input_ids=tokens) 35 | 36 | z = outputs.last_hidden_state 37 | return z 38 | 39 | def encode(self, text): 40 | return self(text) 41 | -------------------------------------------------------------------------------- /causal-conv1d/csrc/causal_conv1d.h: -------------------------------------------------------------------------------- 1 | /****************************************************************************** 2 | * Copyright (c) 2023, Tri Dao. 3 | ******************************************************************************/ 4 | 5 | #pragma once 6 | 7 | //////////////////////////////////////////////////////////////////////////////////////////////////// 8 | 9 | struct ConvParamsBase { 10 | using index_t = uint32_t; 11 | 12 | int batch, dim, seqlen, width; 13 | bool silu_activation; 14 | 15 | index_t x_batch_stride; 16 | index_t x_c_stride; 17 | index_t x_l_stride; 18 | index_t weight_c_stride; 19 | index_t weight_width_stride; 20 | index_t out_batch_stride; 21 | index_t out_c_stride; 22 | index_t out_l_stride; 23 | 24 | index_t conv_state_batch_stride; 25 | index_t conv_state_c_stride; 26 | index_t conv_state_l_stride; 27 | 28 | // Common data pointers. 29 | void *__restrict__ x_ptr; 30 | void *__restrict__ weight_ptr; 31 | void *__restrict__ bias_ptr; 32 | void *__restrict__ out_ptr; 33 | 34 | void *__restrict__ conv_state_ptr; 35 | }; 36 | 37 | struct ConvParamsBwd: public ConvParamsBase { 38 | index_t dx_batch_stride; 39 | index_t dx_c_stride; 40 | index_t dx_l_stride; 41 | index_t dweight_c_stride; 42 | index_t dweight_width_stride; 43 | index_t dout_batch_stride; 44 | index_t dout_c_stride; 45 | index_t dout_l_stride; 46 | 47 | // Common data pointers. 48 | void *__restrict__ dx_ptr; 49 | void *__restrict__ dweight_ptr; 50 | void *__restrict__ dbias_ptr; 51 | void *__restrict__ dout_ptr; 52 | }; 53 | 54 | -------------------------------------------------------------------------------- /causal-conv1d/LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2022, the respective contributors, as shown by the AUTHORS file. 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | * Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | * Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 
15 | 16 | * Neither the name of the copyright holder nor the names of its 17 | contributors may be used to endorse or promote products derived from 18 | this software without specific prior written permission. 19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 30 | -------------------------------------------------------------------------------- /diffusion/__init__.py: -------------------------------------------------------------------------------- 1 | # Modified from OpenAI's diffusion repos 2 | # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py 3 | # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion 4 | # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py 5 | 6 | from . import gaussian_diffusion as gd 7 | from .respace import SpacedDiffusion, space_timesteps 8 | 9 | 10 | def create_diffusion( 11 | timestep_respacing, 12 | noise_schedule="linear", 13 | use_kl=False, 14 | sigma_small=False, 15 | predict_xstart=False, 16 | learn_sigma=True, 17 | rescale_learned_sigmas=False, 18 | diffusion_steps=1000 19 | ): 20 | betas = gd.get_named_beta_schedule(noise_schedule, diffusion_steps) 21 | if use_kl: 22 | loss_type = gd.LossType.RESCALED_KL 23 | elif rescale_learned_sigmas: 24 | loss_type = gd.LossType.RESCALED_MSE 25 | else: 26 | loss_type = gd.LossType.MSE 27 | if timestep_respacing is None or timestep_respacing == "": 28 | timestep_respacing = [diffusion_steps] 29 | return SpacedDiffusion( 30 | use_timesteps=space_timesteps(diffusion_steps, timestep_respacing), 31 | betas=betas, 32 | model_mean_type=( 33 | gd.ModelMeanType.EPSILON if not predict_xstart else gd.ModelMeanType.START_X 34 | ), 35 | model_var_type=( 36 | ( 37 | gd.ModelVarType.FIXED_LARGE 38 | if not sigma_small 39 | else gd.ModelVarType.FIXED_SMALL 40 | ) 41 | if not learn_sigma 42 | else gd.ModelVarType.LEARNED_RANGE 43 | ), 44 | loss_type=loss_type 45 | # rescale_timesteps=rescale_timesteps, 46 | ) 47 | -------------------------------------------------------------------------------- /causal-conv1d/csrc/causal_conv1d_common.h: -------------------------------------------------------------------------------- 1 | /****************************************************************************** 2 | * Copyright (c) 2023, Tri Dao. 
3 | ******************************************************************************/ 4 | 5 | #pragma once 6 | 7 | #include 8 | #include 9 | 10 | //////////////////////////////////////////////////////////////////////////////////////////////////// 11 | 12 | template struct BytesToType {}; 13 | 14 | template<> struct BytesToType<16> { 15 | using Type = uint4; 16 | static_assert(sizeof(Type) == 16); 17 | }; 18 | 19 | template<> struct BytesToType<8> { 20 | using Type = uint64_t; 21 | static_assert(sizeof(Type) == 8); 22 | }; 23 | 24 | template<> struct BytesToType<4> { 25 | using Type = uint32_t; 26 | static_assert(sizeof(Type) == 4); 27 | }; 28 | 29 | template<> struct BytesToType<2> { 30 | using Type = uint16_t; 31 | static_assert(sizeof(Type) == 2); 32 | }; 33 | 34 | template<> struct BytesToType<1> { 35 | using Type = uint8_t; 36 | static_assert(sizeof(Type) == 1); 37 | }; 38 | 39 | //////////////////////////////////////////////////////////////////////////////////////////////////// 40 | 41 | template 42 | struct SumOp { 43 | __device__ inline T operator()(T const & x, T const & y) { return x + y; } 44 | }; 45 | 46 | template 47 | struct Allreduce { 48 | static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4); 49 | template 50 | static __device__ inline T run(T x, Operator &op) { 51 | constexpr int OFFSET = THREADS / 2; 52 | x = op(x, __shfl_xor_sync(uint32_t(-1), x, OFFSET)); 53 | return Allreduce::run(x, op); 54 | } 55 | }; 56 | 57 | template<> 58 | struct Allreduce<2> { 59 | template 60 | static __device__ inline T run(T x, Operator &op) { 61 | x = op(x, __shfl_xor_sync(uint32_t(-1), x, 1)); 62 | return x; 63 | } 64 | }; 65 | -------------------------------------------------------------------------------- /mamba/tests/ops/triton/test_selective_state_update.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2023, Tri Dao. 
2 | 3 | import math 4 | 5 | import torch 6 | import torch.nn.functional as F 7 | import pytest 8 | 9 | from einops import rearrange 10 | 11 | from mamba_ssm.ops.triton.selective_state_update import selective_state_update, selective_state_update_ref 12 | 13 | 14 | @pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16]) 15 | # @pytest.mark.parametrize('itype', [torch.float16]) 16 | @pytest.mark.parametrize("has_z", [False, True]) 17 | # @pytest.mark.parametrize('has_z', [True]) 18 | @pytest.mark.parametrize("dstate", [16, 32, 64]) 19 | # @pytest.mark.parametrize("dstate", [16]) 20 | @pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096]) 21 | # @pytest.mark.parametrize("dim", [2048]) 22 | def test_selective_state_update(dim, dstate, has_z, itype): 23 | device = "cuda" 24 | rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (5e-3, 1e-2) 25 | if itype == torch.bfloat16: 26 | rtol, atol = 1e-2, 5e-2 27 | # set seed 28 | torch.random.manual_seed(0) 29 | batch_size = 2 30 | state = torch.randn(batch_size, dim, dstate, dtype=itype, device=device) 31 | x = torch.randn(batch_size, dim, device=device, dtype=itype) 32 | dt = torch.randn(batch_size, dim, device=device, dtype=itype) 33 | dt_bias = torch.rand(dim, device=device) - 4.0 34 | A = -torch.rand(dim, dstate, device=device) - 1.0 35 | B = torch.randn(batch_size, dstate, device=device) 36 | C = torch.randn(batch_size, dstate, device=device) 37 | D = torch.randn(dim, device=device) 38 | if has_z: 39 | z = torch.randn_like(x) 40 | else: 41 | z = None 42 | state_ref = state.detach().clone() 43 | out = selective_state_update(state, x, dt, A, B, C, D=D, z=z, dt_bias=dt_bias, dt_softplus=True) 44 | out_ref = selective_state_update_ref(state_ref, x, dt, A, B, C, D=D, z=z, dt_bias=dt_bias, dt_softplus=True) 45 | 46 | print(f"Output max diff: {(out - out_ref).abs().max().item()}") 47 | print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") 48 | assert torch.allclose(state, state_ref, rtol=rtol, atol=atol) 49 | assert torch.allclose(out, out_ref, rtol=rtol, atol=atol) 50 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## Vespa🐝: Video Diffusion State Space Models 2 | 3 | This repo contains PyTorch model definitions, pre-trained weights, and training/sampling code for our paper on video diffusion state space models. 4 | Our model uses CLIP/T5 as the text encoder and a Mamba-based diffusion model. 5 | Its distinctive advantage lies in its reduced spatial complexity, which makes it well suited to processing long videos or high-resolution images without the need for window operations. 6 | 7 | The following samples are generated by the [model](https://huggingface.co/feizhengcong/VeSpa-M) with the prompt "sad". 8 | 9 | ![sad](https://github.com/feizc/Vespa/assets/37614046/5bcd0cba-9cb0-4cba-ab36-801539722709) 10 | 11 | 12 | ### 1. Environments 13 | 14 | - Python 3.10 15 | - `conda create -n your_env_name python=3.10` 16 | 17 | - Requirements file 18 | - `pip install -r requirements.txt` 19 | 20 | - Install ``causal_conv1d`` and ``mamba`` 21 | - `pip install -e causal_conv1d` 22 | - `pip install -e mamba` 23 | 24 | 25 | ### 2. Training 26 | 27 | We provide a training script for VeSpa in [`train.py`](train.py). This script can be used to train video diffusion state space models.
28 | 29 | To launch VeSpa-M/2 (64x64) training in raw pixel space with `N` GPUs on one node: 30 | 31 | ```bash 32 | torchrun --nnodes=1 --nproc_per_node=N train.py \ 33 | --model VeSpa-M/2 \ 34 | --model-type video \ 35 | --dataset-type ucf \ 36 | --data-path /path/to/data \ 37 | --anna-path /path/to/annotation \ 38 | --image-size 64 \ 39 | --lr 1e-4 40 | ``` 41 | 42 | 43 | ### 3. Evaluation 44 | 45 | We include a [`sample.py`](sample.py) script that samples images from a trained VeSpa model. In addition, we support other evaluation metrics, e.g., FLOPs and model parameters, in the [`test.py`](test.py) script. 46 | 47 | ```bash 48 | python sample.py \ 49 | --model VeSpa-M/2 \ 50 | --ckpt /path/to/model \ 51 | --image-size 64 \ 52 | --prompt sad 53 | ``` 54 | 55 | ### 4. BibTeX 56 | 57 | ```bibtex 58 | @article{FeiVespa2024, 59 | title={Video Diffusion State Space Models}, 60 | author={Zhengcong Fei and Mingyuan Fan and Yujun Liu and Changqian Yu and Junshi Huang}, 61 | year={2024}, 62 | journal={arXiv preprint}, 63 | } 64 | ``` 65 | 66 | 67 | ### 5. Acknowledgments 68 | 69 | The codebase is based on the awesome [DiS](https://github.com/feizc/DiS), [DiT](https://github.com/facebookresearch/DiT), [mamba](https://github.com/state-spaces/mamba), [U-ViT](https://github.com/baofff/U-ViT), and [Vim](https://github.com/hustvl/Vim) repos. 70 | 71 | 72 | 73 | 74 | 75 | 76 | 77 | -------------------------------------------------------------------------------- /mamba/csrc/selective_scan/uninitialized_copy.cuh: -------------------------------------------------------------------------------- 1 | /****************************************************************************** 2 | * Copyright (c) 2011-2022, NVIDIA CORPORATION. All rights reserved. 3 | * 4 | * Redistribution and use in source and binary forms, with or without 5 | * modification, are permitted provided that the following conditions are met: 6 | * * Redistributions of source code must retain the above copyright 7 | * notice, this list of conditions and the following disclaimer. 8 | * * Redistributions in binary form must reproduce the above copyright 9 | * notice, this list of conditions and the following disclaimer in the 10 | * documentation and/or other materials provided with the distribution. 11 | * * Neither the name of the NVIDIA CORPORATION nor the 12 | * names of its contributors may be used to endorse or promote products 13 | * derived from this software without specific prior written permission. 14 | * 15 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 16 | * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 17 | * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 18 | * ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY 19 | * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 20 | * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 21 | * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND 22 | * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 23 | * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 24 | * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
25 | * 26 | ******************************************************************************/ 27 | 28 | #pragma once 29 | 30 | #include 31 | 32 | #include 33 | 34 | 35 | namespace detail 36 | { 37 | 38 | #if defined(_NVHPC_CUDA) 39 | template 40 | __host__ __device__ void uninitialized_copy(T *ptr, U &&val) 41 | { 42 | // NVBug 3384810 43 | new (ptr) T(::cuda::std::forward(val)); 44 | } 45 | #else 46 | template ::value, 50 | int 51 | >::type = 0> 52 | __host__ __device__ void uninitialized_copy(T *ptr, U &&val) 53 | { 54 | *ptr = ::cuda::std::forward(val); 55 | } 56 | 57 | template ::value, 61 | int 62 | >::type = 0> 63 | __host__ __device__ void uninitialized_copy(T *ptr, U &&val) 64 | { 65 | new (ptr) T(::cuda::std::forward(val)); 66 | } 67 | #endif 68 | 69 | } // namespace detail 70 | -------------------------------------------------------------------------------- /mamba/csrc/selective_scan/selective_scan.h: -------------------------------------------------------------------------------- 1 | /****************************************************************************** 2 | * Copyright (c) 2023, Tri Dao. 3 | ******************************************************************************/ 4 | 5 | #pragma once 6 | 7 | //////////////////////////////////////////////////////////////////////////////////////////////////// 8 | 9 | struct SSMScanParamsBase { 10 | using index_t = uint32_t; 11 | 12 | int batch, seqlen, n_chunks; 13 | index_t a_batch_stride; 14 | index_t b_batch_stride; 15 | index_t out_batch_stride; 16 | 17 | // Common data pointers. 18 | void *__restrict__ a_ptr; 19 | void *__restrict__ b_ptr; 20 | void *__restrict__ out_ptr; 21 | void *__restrict__ x_ptr; 22 | }; 23 | 24 | //////////////////////////////////////////////////////////////////////////////////////////////////// 25 | 26 | struct SSMParamsBase { 27 | using index_t = uint32_t; 28 | 29 | int batch, dim, seqlen, dstate, n_groups, n_chunks; 30 | int dim_ngroups_ratio; 31 | bool is_variable_B; 32 | bool is_variable_C; 33 | 34 | bool delta_softplus; 35 | 36 | index_t A_d_stride; 37 | index_t A_dstate_stride; 38 | index_t B_batch_stride; 39 | index_t B_d_stride; 40 | index_t B_dstate_stride; 41 | index_t B_group_stride; 42 | index_t C_batch_stride; 43 | index_t C_d_stride; 44 | index_t C_dstate_stride; 45 | index_t C_group_stride; 46 | index_t u_batch_stride; 47 | index_t u_d_stride; 48 | index_t delta_batch_stride; 49 | index_t delta_d_stride; 50 | index_t z_batch_stride; 51 | index_t z_d_stride; 52 | index_t out_batch_stride; 53 | index_t out_d_stride; 54 | index_t out_z_batch_stride; 55 | index_t out_z_d_stride; 56 | 57 | // Common data pointers. 
58 | void *__restrict__ A_ptr; 59 | void *__restrict__ B_ptr; 60 | void *__restrict__ C_ptr; 61 | void *__restrict__ D_ptr; 62 | void *__restrict__ u_ptr; 63 | void *__restrict__ delta_ptr; 64 | void *__restrict__ delta_bias_ptr; 65 | void *__restrict__ out_ptr; 66 | void *__restrict__ x_ptr; 67 | void *__restrict__ z_ptr; 68 | void *__restrict__ out_z_ptr; 69 | }; 70 | 71 | struct SSMParamsBwd: public SSMParamsBase { 72 | index_t dout_batch_stride; 73 | index_t dout_d_stride; 74 | index_t dA_d_stride; 75 | index_t dA_dstate_stride; 76 | index_t dB_batch_stride; 77 | index_t dB_group_stride; 78 | index_t dB_d_stride; 79 | index_t dB_dstate_stride; 80 | index_t dC_batch_stride; 81 | index_t dC_group_stride; 82 | index_t dC_d_stride; 83 | index_t dC_dstate_stride; 84 | index_t du_batch_stride; 85 | index_t du_d_stride; 86 | index_t dz_batch_stride; 87 | index_t dz_d_stride; 88 | index_t ddelta_batch_stride; 89 | index_t ddelta_d_stride; 90 | 91 | // Common data pointers. 92 | void *__restrict__ dout_ptr; 93 | void *__restrict__ dA_ptr; 94 | void *__restrict__ dB_ptr; 95 | void *__restrict__ dC_ptr; 96 | void *__restrict__ dD_ptr; 97 | void *__restrict__ du_ptr; 98 | void *__restrict__ dz_ptr; 99 | void *__restrict__ ddelta_ptr; 100 | void *__restrict__ ddelta_bias_ptr; 101 | }; 102 | -------------------------------------------------------------------------------- /mamba/benchmarks/benchmark_generation_mamba_simple.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023, Tri Dao, Albert Gu. 2 | 3 | import argparse 4 | import time 5 | import json 6 | 7 | import torch 8 | import torch.nn.functional as F 9 | 10 | from einops import rearrange 11 | 12 | from transformers import AutoTokenizer, AutoModelForCausalLM 13 | 14 | from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel 15 | 16 | 17 | parser = argparse.ArgumentParser(description="Generation benchmarking") 18 | parser.add_argument("--model-name", type=str, default="state-spaces/mamba-130m") 19 | parser.add_argument("--prompt", type=str, default=None) 20 | parser.add_argument("--promptlen", type=int, default=100) 21 | parser.add_argument("--genlen", type=int, default=100) 22 | parser.add_argument("--temperature", type=float, default=1.0) 23 | parser.add_argument("--topk", type=int, default=1) 24 | parser.add_argument("--topp", type=float, default=1.0) 25 | parser.add_argument("--batch", type=int, default=1) 26 | args = parser.parse_args() 27 | 28 | repeats = 3 29 | device = "cuda" 30 | dtype = torch.float16 31 | 32 | print(f"Loading model {args.model_name}") 33 | is_mamba = args.model_name.startswith("state-spaces/mamba-") or "mamba" in args.model_name 34 | 35 | if is_mamba: 36 | tokenizer = AutoTokenizer.from_pretrained("/home/zhulianghui/VisionProjects/mamba/ckpts/gpt-neox-20b-tokenizer") 37 | model = MambaLMHeadModel.from_pretrained(args.model_name, device=device, dtype=dtype) 38 | else: 39 | tokenizer = AutoTokenizer.from_pretrained(args.model_name) 40 | model = AutoModelForCausalLM.from_pretrained(args.model_name, device_map={"": device}, torch_dtype=dtype) 41 | model.eval() 42 | print(f"Number of parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad)}") 43 | 44 | torch.random.manual_seed(0) 45 | if args.prompt is None: 46 | input_ids = torch.randint(1, 1000, (args.batch, args.promptlen), dtype=torch.long, device="cuda") 47 | attn_mask = torch.ones_like(input_ids, dtype=torch.long, device="cuda") 48 | else: 49 | tokens = tokenizer(args.prompt, 
return_tensors="pt") 50 | input_ids = tokens.input_ids.to(device=device) 51 | attn_mask = tokens.attention_mask.to(device=device) 52 | max_length = input_ids.shape[1] + args.genlen 53 | 54 | if is_mamba: 55 | fn = lambda: model.generate( 56 | input_ids=input_ids, 57 | max_length=max_length, 58 | cg=True, 59 | return_dict_in_generate=True, 60 | output_scores=True, 61 | enable_timing=False, 62 | temperature=args.temperature, 63 | top_k=args.topk, 64 | top_p=args.topp, 65 | ) 66 | else: 67 | fn = lambda: model.generate( 68 | input_ids=input_ids, 69 | attention_mask=attn_mask, 70 | max_length=max_length, 71 | return_dict_in_generate=True, 72 | pad_token_id=tokenizer.eos_token_id, 73 | do_sample=True, 74 | temperature=args.temperature, 75 | top_k=args.topk, 76 | top_p=args.topp, 77 | ) 78 | out = fn() 79 | if args.prompt is not None: 80 | print(tokenizer.batch_decode(out.sequences.tolist())) 81 | 82 | torch.cuda.synchronize() 83 | start = time.time() 84 | for _ in range(repeats): 85 | fn() 86 | torch.cuda.synchronize() 87 | print(f"Prompt length: {len(input_ids[0])}, generation length: {len(out.sequences[0]) - len(input_ids[0])}") 88 | print(f"{args.model_name} prompt processing + decoding time: {(time.time() - start) / repeats * 1000:.0f}ms") 89 | -------------------------------------------------------------------------------- /diffusion/diffusion_utils.py: -------------------------------------------------------------------------------- 1 | # Modified from OpenAI's diffusion repos 2 | # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py 3 | # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion 4 | # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py 5 | 6 | import torch as th 7 | import numpy as np 8 | 9 | 10 | def normal_kl(mean1, logvar1, mean2, logvar2): 11 | """ 12 | Compute the KL divergence between two gaussians. 13 | Shapes are automatically broadcasted, so batches can be compared to 14 | scalars, among other use cases. 15 | """ 16 | tensor = None 17 | for obj in (mean1, logvar1, mean2, logvar2): 18 | if isinstance(obj, th.Tensor): 19 | tensor = obj 20 | break 21 | assert tensor is not None, "at least one argument must be a Tensor" 22 | 23 | # Force variances to be Tensors. Broadcasting helps convert scalars to 24 | # Tensors, but it does not work for th.exp(). 25 | logvar1, logvar2 = [ 26 | x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor) 27 | for x in (logvar1, logvar2) 28 | ] 29 | 30 | return 0.5 * ( 31 | -1.0 32 | + logvar2 33 | - logvar1 34 | + th.exp(logvar1 - logvar2) 35 | + ((mean1 - mean2) ** 2) * th.exp(-logvar2) 36 | ) 37 | 38 | 39 | def approx_standard_normal_cdf(x): 40 | """ 41 | A fast approximation of the cumulative distribution function of the 42 | standard normal. 43 | """ 44 | return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3)))) 45 | 46 | 47 | def continuous_gaussian_log_likelihood(x, *, means, log_scales): 48 | """ 49 | Compute the log-likelihood of a continuous Gaussian distribution. 50 | :param x: the targets 51 | :param means: the Gaussian mean Tensor. 52 | :param log_scales: the Gaussian log stddev Tensor. 53 | :return: a tensor like x of log probabilities (in nats). 
54 | """ 55 | centered_x = x - means 56 | inv_stdv = th.exp(-log_scales) 57 | normalized_x = centered_x * inv_stdv 58 | log_probs = th.distributions.Normal(th.zeros_like(x), th.ones_like(x)).log_prob(normalized_x) 59 | return log_probs 60 | 61 | 62 | def discretized_gaussian_log_likelihood(x, *, means, log_scales): 63 | """ 64 | Compute the log-likelihood of a Gaussian distribution discretizing to a 65 | given image. 66 | :param x: the target images. It is assumed that this was uint8 values, 67 | rescaled to the range [-1, 1]. 68 | :param means: the Gaussian mean Tensor. 69 | :param log_scales: the Gaussian log stddev Tensor. 70 | :return: a tensor like x of log probabilities (in nats). 71 | """ 72 | assert x.shape == means.shape == log_scales.shape 73 | centered_x = x - means 74 | inv_stdv = th.exp(-log_scales) 75 | plus_in = inv_stdv * (centered_x + 1.0 / 255.0) 76 | cdf_plus = approx_standard_normal_cdf(plus_in) 77 | min_in = inv_stdv * (centered_x - 1.0 / 255.0) 78 | cdf_min = approx_standard_normal_cdf(min_in) 79 | log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12)) 80 | log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12)) 81 | cdf_delta = cdf_plus - cdf_min 82 | log_probs = th.where( 83 | x < -0.999, 84 | log_cdf_plus, 85 | th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))), 86 | ) 87 | assert log_probs.shape == x.shape 88 | return log_probs 89 | -------------------------------------------------------------------------------- /causal-conv1d/causal_conv1d/causal_conv1d_interface.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023, Tri Dao. 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | 6 | 7 | import causal_conv1d_cuda 8 | 9 | 10 | class CausalConv1dFn(torch.autograd.Function): 11 | @staticmethod 12 | def forward(ctx, x, weight, bias=None, activation=None): 13 | if activation not in [None, "silu", "swish"]: 14 | raise NotImplementedError("activation must be None, silu, or swish") 15 | if x.stride(2) != 1 and x.stride(1) != 1: 16 | x = x.contiguous() 17 | bias = bias.contiguous() if bias is not None else None 18 | ctx.save_for_backward(x, weight, bias) 19 | ctx.activation = activation in ["silu", "swish"] 20 | out = causal_conv1d_cuda.causal_conv1d_fwd(x, weight, bias, ctx.activation) 21 | return out 22 | 23 | @staticmethod 24 | def backward(ctx, dout): 25 | x, weight, bias = ctx.saved_tensors 26 | if dout.stride(2) != 1 and dout.stride(1) != 1: 27 | dout = dout.contiguous() 28 | # The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the 29 | # backward of conv1d with the backward of chunk). 30 | # Here we just pass in None and dx will be allocated in the C++ code. 
31 | dx, dweight, dbias = causal_conv1d_cuda.causal_conv1d_bwd( 32 | x, weight, bias, dout, None, ctx.activation 33 | ) 34 | return dx, dweight, dbias if bias is not None else None, None 35 | 36 | 37 | def causal_conv1d_fn(x, weight, bias=None, activation=None): 38 | """ 39 | x: (batch, dim, seqlen) 40 | weight: (dim, width) 41 | bias: (dim,) 42 | activation: either None or "silu" or "swish" 43 | 44 | out: (batch, dim, seqlen) 45 | """ 46 | return CausalConv1dFn.apply(x, weight, bias, activation) 47 | 48 | 49 | def causal_conv1d_ref(x, weight, bias=None, activation=None): 50 | """ 51 | x: (batch, dim, seqlen) 52 | weight: (dim, width) 53 | bias: (dim,) 54 | 55 | out: (batch, dim, seqlen) 56 | """ 57 | if activation not in [None, "silu", "swish"]: 58 | raise NotImplementedError("activation must be None, silu, or swish") 59 | dtype_in = x.dtype 60 | x = x.to(weight.dtype) 61 | seqlen = x.shape[-1] 62 | dim, width = weight.shape 63 | out = F.conv1d(x, weight.unsqueeze(1), bias, padding=width - 1, groups=dim) 64 | out = out[..., :seqlen] 65 | return (out if activation is None else F.silu(out)).to(dtype=dtype_in) 66 | 67 | 68 | def causal_conv1d_update(x, conv_state, weight, bias=None, activation=None): 69 | """ 70 | x: (batch, dim) 71 | conv_state: (batch, dim, width) 72 | weight: (dim, width) 73 | bias: (dim,) 74 | 75 | out: (batch, dim) 76 | """ 77 | if activation not in [None, "silu", "swish"]: 78 | raise NotImplementedError("activation must be None, silu, or swish") 79 | activation = activation in ["silu", "swish"] 80 | return causal_conv1d_cuda.causal_conv1d_update(x, conv_state, weight, bias, activation) 81 | 82 | 83 | def causal_conv1d_update_ref(x, conv_state, weight, bias=None, activation=None): 84 | """ 85 | x: (batch, dim) 86 | conv_state: (batch, dim, width) 87 | weight: (dim, width) 88 | bias: (dim,) 89 | 90 | out: (batch, dim) 91 | """ 92 | if activation not in [None, "silu", "swish"]: 93 | raise NotImplementedError("activation must be None, silu, or swish") 94 | dtype_in = x.dtype 95 | batch, dim = x.shape 96 | width = weight.shape[1] 97 | assert conv_state.shape == (batch, dim, width) 98 | assert weight.shape == (dim, width) 99 | conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1)) # Update state (B D W) 100 | conv_state[:, :, -1] = x 101 | out = torch.sum(conv_state * weight, dim=-1) # (B D) 102 | if bias is not None: 103 | out += bias 104 | return (out if activation is None else F.silu(out)).to(dtype=dtype_in) 105 | -------------------------------------------------------------------------------- /causal-conv1d/csrc/causal_conv1d_update.cu: -------------------------------------------------------------------------------- 1 | /****************************************************************************** 2 | * Copyright (c) 2023, Tri Dao. 
3 | ******************************************************************************/ 4 | 5 | #include 6 | #include 7 | #include // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK 8 | 9 | #include 10 | #include 11 | 12 | #include "causal_conv1d.h" 13 | #include "causal_conv1d_common.h" 14 | #include "static_switch.h" 15 | 16 | template 17 | struct Causal_conv1d_update_kernel_traits { 18 | using input_t = input_t_; 19 | using weight_t = weight_t_; 20 | static constexpr int kNThreads = kNThreads_; 21 | static constexpr int kWidth = kWidth_; 22 | static constexpr int kNBytes = sizeof(input_t); 23 | static_assert(kNBytes == 2 || kNBytes == 4); 24 | }; 25 | 26 | template 27 | __global__ __launch_bounds__(Ktraits::kNThreads) 28 | void causal_conv1d_update_kernel(ConvParamsBase params) { 29 | constexpr int kWidth = Ktraits::kWidth; 30 | constexpr int kNThreads = Ktraits::kNThreads; 31 | using input_t = typename Ktraits::input_t; 32 | using weight_t = typename Ktraits::weight_t; 33 | 34 | const int tidx = threadIdx.x; 35 | const int batch_id = blockIdx.x; 36 | const int channel_id = blockIdx.y * kNThreads + tidx; 37 | input_t *x = reinterpret_cast(params.x_ptr) + batch_id * params.x_batch_stride 38 | + channel_id * params.x_c_stride; 39 | input_t *conv_state = reinterpret_cast(params.conv_state_ptr) + batch_id * params.conv_state_batch_stride 40 | + channel_id * params.conv_state_c_stride; 41 | weight_t *weight = reinterpret_cast(params.weight_ptr) + channel_id * params.weight_c_stride; 42 | input_t *out = reinterpret_cast(params.out_ptr) + batch_id * params.out_batch_stride 43 | + channel_id * params.out_c_stride; 44 | float bias_val = params.bias_ptr == nullptr || channel_id >= params.dim ? 0.f : float(reinterpret_cast(params.bias_ptr)[channel_id]); 45 | 46 | float weight_vals[kWidth] = {0}; 47 | if (channel_id < params.dim) { 48 | #pragma unroll 49 | for (int i = 0; i < kWidth; ++i) { weight_vals[i] = float(weight[i * params.weight_width_stride]); } 50 | } 51 | 52 | float x_vals[kWidth] = {0}; 53 | if (channel_id < params.dim) { 54 | #pragma unroll 55 | for (int i = 0; i < kWidth - 1; ++i) { x_vals[i] = float(conv_state[(i + 1) * params.conv_state_l_stride]); } 56 | x_vals[kWidth - 1] = float(x[0]); 57 | #pragma unroll 58 | for (int i = 0; i < kWidth; ++i) { conv_state[i * params.conv_state_l_stride] = input_t(x_vals[i]); } 59 | } 60 | 61 | float out_val = bias_val; 62 | #pragma unroll 63 | for (int i = 0; i < kWidth; ++i) { out_val += weight_vals[i] * x_vals[i]; } 64 | if (params.silu_activation) { out_val = out_val / (1 + expf(-out_val)); } 65 | if (channel_id < params.dim) { out[0] = input_t(out_val); } 66 | } 67 | 68 | template 69 | void causal_conv1d_update_launch(ConvParamsBase ¶ms, cudaStream_t stream) { 70 | using Ktraits = Causal_conv1d_update_kernel_traits; 71 | dim3 grid(params.batch, (params.dim + kNThreads - 1) / kNThreads); 72 | auto kernel = &causal_conv1d_update_kernel; 73 | kernel<<>>(params); 74 | C10_CUDA_KERNEL_LAUNCH_CHECK(); 75 | } 76 | 77 | template 78 | void causal_conv1d_update_cuda(ConvParamsBase ¶ms, cudaStream_t stream) { 79 | if (params.width == 2) { 80 | causal_conv1d_update_launch<64, 2, input_t, weight_t>(params, stream); 81 | } else if (params.width == 3) { 82 | causal_conv1d_update_launch<64, 3, input_t, weight_t>(params, stream); 83 | } else if (params.width == 4) { 84 | causal_conv1d_update_launch<64, 4, input_t, weight_t>(params, stream); 85 | } 86 | } 87 | 88 | template void causal_conv1d_update_cuda(ConvParamsBase ¶ms, cudaStream_t stream); 89 | template 
void causal_conv1d_update_cuda(ConvParamsBase ¶ms, cudaStream_t stream); 90 | template void causal_conv1d_update_cuda(ConvParamsBase ¶ms, cudaStream_t stream); 91 | template void causal_conv1d_update_cuda(ConvParamsBase ¶ms, cudaStream_t stream); 92 | template void causal_conv1d_update_cuda(ConvParamsBase ¶ms, cudaStream_t stream); 93 | template void causal_conv1d_update_cuda(ConvParamsBase ¶ms, cudaStream_t stream); 94 | template void causal_conv1d_update_cuda(ConvParamsBase ¶ms, cudaStream_t stream); 95 | template void causal_conv1d_update_cuda(ConvParamsBase ¶ms, cudaStream_t stream); 96 | template void causal_conv1d_update_cuda(ConvParamsBase ¶ms, cudaStream_t stream); -------------------------------------------------------------------------------- /sample.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import argparse 3 | import random 4 | torch.backends.cuda.matmul.allow_tf32 = True 5 | torch.backends.cudnn.allow_tf32 = True 6 | 7 | 8 | from torchvision.utils import save_image 9 | from diffusers.models import AutoencoderKL 10 | 11 | from diffusion import create_diffusion 12 | from models_vespa import VeSpa_image_models, VeSpa_video_models 13 | from clip import FrozenCLIPEmbedder 14 | from t5 import T5Embedder 15 | 16 | 17 | def main(args): 18 | print("Sample images from a trained vespa model.") 19 | # Setup PyTorch: 20 | torch.manual_seed(args.seed) 21 | torch.set_grad_enabled(False) 22 | device = "cuda" if torch.cuda.is_available() else "cpu" 23 | 24 | if args.latent_space == True: 25 | img_size = args.image_size // 8 26 | channels = 4 27 | else: 28 | img_size=args.image_size 29 | channels = 3 30 | 31 | if args.text_encoder_type == 'clip': 32 | num_clip_token = 77 33 | clip_dim = 768 34 | else: 35 | num_clip_token = 120 36 | clip_dim = 4096 37 | 38 | if args.model_type == 'image': 39 | model = VeSpa_image_models[args.model]( 40 | img_size=img_size, 41 | channels=channels, 42 | num_clip_token=num_clip_token, 43 | clip_dim=clip_dim, 44 | ) 45 | else: 46 | model = VeSpa_video_models[args.model]( 47 | img_size=img_size, 48 | channels=channels, 49 | enable_temporal_layers= not args.image_only, 50 | num_clip_token=num_clip_token, 51 | clip_dim=clip_dim, 52 | ) 53 | 54 | checkponit = torch.load(args.ckpt, map_location=lambda storage, loc: storage)['ema'] 55 | model.load_state_dict(checkponit) 56 | model = model.to(device) 57 | model.eval() 58 | 59 | diffusion = create_diffusion(str(args.num_sampling_steps)) 60 | if args.latent_space == True: 61 | vae = AutoencoderKL.from_pretrained(args.vae_path).to(device) 62 | 63 | if args.text_encoder_type == 'clip': 64 | text_encoder = FrozenCLIPEmbedder( 65 | path='/maindata/data/shared/multimodal/zhengcong.fei/ckpts/playground', 66 | device=device, 67 | ) 68 | text_encoder.eval() 69 | text_encoder = text_encoder.to(device) 70 | elif args.text_encoder_type == 't5': 71 | t5_path = '/maindata/data/shared/multimodal/zhengcong.fei/ckpts/DeepFloyd/t5-v1_1-xxl' 72 | text_encoder = T5Embedder(device='cuda', local_cache=True, cache_dir=t5_path) 73 | else: 74 | pass 75 | 76 | n = 16 77 | y = ['tiger cat',] * n 78 | # text = ['Skiing',] * n 79 | 80 | if args.latent_space == True: 81 | z = torch.randn(n, 4, args.image_size//8, args.image_size//8, device=device) 82 | else: 83 | z = torch.randn(n, 3, args.image_size, args.image_size, device=device) 84 | 85 | # Setup classifier-free guidance: 86 | # z = torch.cat([z, z], 0) 87 | 88 | with torch.no_grad(): 89 | if args.text_encoder_type == 'clip': 90 | 
context = text_encoder.encode(y) 91 | else: 92 | context, _ = text_encoder.get_text_embeddings(y) 93 | context = context.float() 94 | 95 | if args.image_only == True: 96 | model_kwargs = dict(context=context,) 97 | else: 98 | model_kwargs = dict(context=context, f=8) 99 | # Sample images: 100 | samples = diffusion.p_sample_loop( 101 | model.forward_with_cfg, z.shape, z, clip_denoised=False, model_kwargs=model_kwargs, progress=True, device=device 102 | ) 103 | eval_samples, _ = samples.chunk(2, dim=0) 104 | # eval_samples = samples 105 | 106 | if args.latent_space == True: 107 | eval_samples = vae.decode(eval_samples / 0.18215).sample 108 | 109 | save_image(eval_samples, "sample.png", nrow=8, normalize=True, value_range=(-1, 1)) 110 | 111 | 112 | 113 | if __name__ == "__main__": 114 | parser = argparse.ArgumentParser() 115 | parser.add_argument("--model", type=str, default="VeSpa-H/2") 116 | parser.add_argument("--model-type", type=str, default="image") 117 | parser.add_argument("--text_encoder_type", type=str, choices=['clip', 't5'], default='t5') 118 | parser.add_argument("--image-size", type=int, choices=[32, 64, 256, 512], default=256) 119 | parser.add_argument("--image-only", type=bool, default=True) 120 | parser.add_argument("--cfg-scale", type=float, default=1.5) 121 | parser.add_argument("--num-sampling-steps", type=int, default=250) 122 | parser.add_argument("--seed", type=int, default=42) 123 | parser.add_argument("--ckpt", type=str, default="/maindata/data/shared/multimodal/zhengcong.fei/code/vespa/results/VeSpa-H-2-imagenet-image-True/checkpoints/0033000.pt",) 124 | # parser.add_argument("--ckpt", type=str, default="/TrainData/Multimodal/zhengcong.fei/vespa/results/VeSpa-M-2-face-video-False/checkpoints/0024000.pt",) 125 | # parser.add_argument("--ckpt", type=str, default="/TrainData/Multimodal/zhengcong.fei/vespa/results/VeSpa-M-2-ucf-video-False/checkpoints/0024000.pt",) 126 | parser.add_argument('--latent_space', type=bool, default=True,) 127 | parser.add_argument('--vae_path', type=str, default='/maindata/data/shared/multimodal/zhengcong.fei/ckpts/playground/vae') 128 | args = parser.parse_args() 129 | 130 | main(args) 131 | -------------------------------------------------------------------------------- /diffusion/respace.py: -------------------------------------------------------------------------------- 1 | # Modified from OpenAI's diffusion repos 2 | # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py 3 | # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion 4 | # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py 5 | 6 | import numpy as np 7 | import torch as th 8 | 9 | from .gaussian_diffusion import GaussianDiffusion 10 | 11 | 12 | def space_timesteps(num_timesteps, section_counts): 13 | """ 14 | Create a list of timesteps to use from an original diffusion process, 15 | given the number of timesteps we want to take from equally-sized portions 16 | of the original process. 17 | For example, if there's 300 timesteps and the section counts are [10,15,20] 18 | then the first 100 timesteps are strided to be 10 timesteps, the second 100 19 | are strided to be 15 timesteps, and the final 100 are strided to be 20. 20 | If the stride is a string starting with "ddim", then the fixed striding 21 | from the DDIM paper is used, and only one section is allowed. 22 | :param num_timesteps: the number of diffusion steps in the original 23 | process to divide up. 
24 | :param section_counts: either a list of numbers, or a string containing 25 | comma-separated numbers, indicating the step count 26 | per section. As a special case, use "ddimN" where N 27 | is a number of steps to use the striding from the 28 | DDIM paper. 29 | :return: a set of diffusion steps from the original process to use. 30 | """ 31 | if isinstance(section_counts, str): 32 | if section_counts.startswith("ddim"): 33 | desired_count = int(section_counts[len("ddim") :]) 34 | for i in range(1, num_timesteps): 35 | if len(range(0, num_timesteps, i)) == desired_count: 36 | return set(range(0, num_timesteps, i)) 37 | raise ValueError( 38 | f"cannot create exactly {num_timesteps} steps with an integer stride" 39 | ) 40 | section_counts = [int(x) for x in section_counts.split(",")] 41 | size_per = num_timesteps // len(section_counts) 42 | extra = num_timesteps % len(section_counts) 43 | start_idx = 0 44 | all_steps = [] 45 | for i, section_count in enumerate(section_counts): 46 | size = size_per + (1 if i < extra else 0) 47 | if size < section_count: 48 | raise ValueError( 49 | f"cannot divide section of {size} steps into {section_count}" 50 | ) 51 | if section_count <= 1: 52 | frac_stride = 1 53 | else: 54 | frac_stride = (size - 1) / (section_count - 1) 55 | cur_idx = 0.0 56 | taken_steps = [] 57 | for _ in range(section_count): 58 | taken_steps.append(start_idx + round(cur_idx)) 59 | cur_idx += frac_stride 60 | all_steps += taken_steps 61 | start_idx += size 62 | return set(all_steps) 63 | 64 | 65 | class SpacedDiffusion(GaussianDiffusion): 66 | """ 67 | A diffusion process which can skip steps in a base diffusion process. 68 | :param use_timesteps: a collection (sequence or set) of timesteps from the 69 | original diffusion process to retain. 70 | :param kwargs: the kwargs to create the base diffusion process. 71 | """ 72 | 73 | def __init__(self, use_timesteps, **kwargs): 74 | self.use_timesteps = set(use_timesteps) 75 | self.timestep_map = [] 76 | self.original_num_steps = len(kwargs["betas"]) 77 | 78 | base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa 79 | last_alpha_cumprod = 1.0 80 | new_betas = [] 81 | for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod): 82 | if i in self.use_timesteps: 83 | new_betas.append(1 - alpha_cumprod / last_alpha_cumprod) 84 | last_alpha_cumprod = alpha_cumprod 85 | self.timestep_map.append(i) 86 | kwargs["betas"] = np.array(new_betas) 87 | super().__init__(**kwargs) 88 | 89 | def p_mean_variance( 90 | self, model, *args, **kwargs 91 | ): # pylint: disable=signature-differs 92 | return super().p_mean_variance(self._wrap_model(model), *args, **kwargs) 93 | 94 | def training_losses( 95 | self, model, *args, **kwargs 96 | ): # pylint: disable=signature-differs 97 | return super().training_losses(self._wrap_model(model), *args, **kwargs) 98 | 99 | def condition_mean(self, cond_fn, *args, **kwargs): 100 | return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs) 101 | 102 | def condition_score(self, cond_fn, *args, **kwargs): 103 | return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs) 104 | 105 | def _wrap_model(self, model): 106 | if isinstance(model, _WrappedModel): 107 | return model 108 | return _WrappedModel( 109 | model, self.timestep_map, self.original_num_steps 110 | ) 111 | 112 | def _scale_timesteps(self, t): 113 | # Scaling is done by the wrapped model. 
114 | return t 115 | 116 | 117 | class _WrappedModel: 118 | def __init__(self, model, timestep_map, original_num_steps): 119 | self.model = model 120 | self.timestep_map = timestep_map 121 | # self.rescale_timesteps = rescale_timesteps 122 | self.original_num_steps = original_num_steps 123 | 124 | def __call__(self, x, ts, **kwargs): 125 | map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype) 126 | new_ts = map_tensor[ts] 127 | # if self.rescale_timesteps: 128 | # new_ts = new_ts.float() * (1000.0 / self.original_num_steps) 129 | return self.model(x, new_ts, **kwargs) 130 | -------------------------------------------------------------------------------- /mamba/README.md: -------------------------------------------------------------------------------- 1 | # Mamba 2 | 3 | ![Mamba](assets/selection.png "Selective State Space") 4 | > **Mamba: Linear-Time Sequence Modeling with Selective State Spaces**\ 5 | > Albert Gu*, Tri Dao*\ 6 | > Paper: https://arxiv.org/abs/2312.00752 7 | 8 | ## About 9 | 10 | Mamba is a new state space model architecture showing promising performance on information-dense data such as language modeling, where previous subquadratic models fall short of Transformers. 11 | It is based on the line of progress on [structured state space models](https://github.com/state-spaces/s4), 12 | with an efficient hardware-aware design and implementation in the spirit of [FlashAttention](https://github.com/Dao-AILab/flash-attention). 13 | 14 | ## Installation 15 | 16 | - `pip install causal-conv1d`: an efficient implementation of a simple causal Conv1d layer used inside the Mamba block. 17 | - `pip install mamba-ssm`: the core Mamba package. 18 | 19 | It can also be built from source with `pip install .` from this repository. 20 | 21 | If `pip` complains about PyTorch versions, try passing `--no-build-isolation` to `pip`. 22 | 23 | Other requirements: 24 | - Linux 25 | - NVIDIA GPU 26 | - PyTorch 1.12+ 27 | - CUDA 11.6+ 28 | 29 | ## Usage 30 | 31 | We expose several levels of interface with the Mamba model. 32 | 33 | ### Selective SSM 34 | 35 | Mamba is based on a selective SSM layer, which is the focus of the paper (Section 3; Algorithm 2). 36 | 37 | Source: [ops/selective_scan_interface.py](mamba_ssm/ops/selective_scan_interface.py). 38 | 39 | ### Mamba Block 40 | 41 | The main module of this repository is the Mamba architecture block wrapping the selective SSM. 42 | 43 | Source: [modules/mamba_simple.py](mamba_ssm/modules/mamba_simple.py). 44 | 45 | Usage: 46 | ``` 47 | from mamba_ssm import Mamba 48 | 49 | batch, length, dim = 2, 64, 16 50 | x = torch.randn(batch, length, dim).to("cuda") 51 | model = Mamba( 52 | # This module uses roughly 3 * expand * d_model^2 parameters 53 | d_model=dim, # Model dimension d_model 54 | d_state=16, # SSM state expansion factor 55 | d_conv=4, # Local convolution width 56 | expand=2, # Block expansion factor 57 | ).to("cuda") 58 | y = model(x) 59 | assert y.shape == x.shape 60 | ``` 61 | 62 | ### Mamba Language Model 63 | 64 | Finally, we provide an example of a complete language model: a deep sequence model backbone (with repeating Mamba blocks) + language model head. 65 | 66 | Source: [models/mixer_seq_simple.py](mamba_ssm/models/mixer_seq_simple.py). 67 | 68 | This is an example of how to integrate Mamba into an end-to-end neural network. 69 | This example is used in the generation scripts below. 
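For reference, a minimal generation sketch. It is only an illustration: it assumes the `MambaLMHeadModel.from_pretrained` / `generate` interface used by the benchmark script below, and the GPT-NeoX tokenizer that the pretrained checkpoints were trained with; see `benchmarks/benchmark_generation_mamba_simple.py` for the exact arguments used in this repo.

```
import torch
from transformers import AutoTokenizer
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel

device = "cuda"
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
model = MambaLMHeadModel.from_pretrained("state-spaces/mamba-130m", device=device, dtype=torch.float16)

prompt = "Mamba is a state space model that"
input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)

out = model.generate(input_ids=input_ids, max_length=64, temperature=0.9, top_p=0.9)
# Depending on the version, generate may return a tensor or a dataclass with .sequences
sequences = out.sequences if hasattr(out, "sequences") else out
print(tokenizer.batch_decode(sequences))
```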
70 | 71 | 72 | 73 | ## Pretrained Models 74 | 75 | Pretrained models are uploaded to 76 | [HuggingFace](https://huggingface.co/state-spaces): `mamba-130m`, `mamba-370m`, 77 | `mamba-790m`, `mamba-1.4b`, `mamba-2.8b`. 78 | 79 | The models will be autodownloaded by the generation script below. 80 | 81 | These models were trained on the [Pile](https://huggingface.co/datasets/EleutherAI/pile), and follow the standard model dimensions described by GPT-3 and followed by many open source models: 82 | 83 | | Parameters | Layers | Model dim. | 84 | |------------|--------|------------| 85 | | 130M | 12 | 768 | 86 | | 370M | 24 | 1024 | 87 | | 790M | 24 | 1536 | 88 | | 1.4B | 24 | 2048 | 89 | | 2.8B | 32 | 2560 | 90 | 91 | (The layer count of Mamba should be doubled, as two Mamba blocks are needed for each "layer" (MHA block + MLP block) of a Transformer.) 92 | 93 | Note: these are base models trained only for 300B tokens, without any form of downstream modification (instruction tuning, etc.). 94 | Performance is expected to be comparable or better than other architectures trained on similar data, but not to match larger or fine-tuned models. 95 | 96 | 97 | ## Evaluations 98 | 99 | To run zero-shot evaluations of models (corresponding to Table 3 of the paper), 100 | we use the 101 | [lm-evaluation-harness](https://github.com/EleutherAI/lm-evaluation-harness/tree/big-refactor) 102 | library. 103 | 104 | 1. Pull the `lm-evaluation-harness` repo by `git submodule update --init 105 | --recursive`. We use the `big-refactor` branch. 106 | 2. Install `lm-evaluation-harness`: `pip install -e 3rdparty/lm-evaluation-harness` 107 | 3. Run evaluation with (more documentation at the [lm-evaluation-harness](https://github.com/EleutherAI/lm-evaluation-harness/tree/big-refactor) repo): 108 | ``` 109 | python evals/lm_harness_eval.py --model mamba --model_args pretrained=state-spaces/mamba-130m --tasks lambada_openai,hellaswag,piqa,arc_easy,arc_challenge,winogrande --device cuda --batch_size 64 110 | python evals/lm_harness_eval.py --model hf --model_args pretrained=EleutherAI/pythia-160m --tasks lambada_openai,hellaswag,piqa,arc_easy,arc_challenge,winogrande --device cuda --batch_size 64 111 | ``` 112 | 113 | Note that the result of each task might differ from reported values by 0.1-0.3 due to noise in the evaluation process. 114 | 115 | ## Inference 116 | 117 | The script [benchmarks/benchmark_generation_mamba_simple.py](benchmarks/benchmark_generation_mamba_simple.py) 118 | 1. autoloads a model from the HuggingFace Hub, 119 | 2. generates completions of a user-specified prompt, 120 | 3. benchmarks the inference speed of this generation. 121 | 122 | Other configurable options include the top-p (nucleus sampling) probability, and the softmax temperature. 123 | 124 | ### Examples 125 | 126 | To test generation latency (e.g. batch size = 1) with different sampling strategies: 127 | 128 | ``` 129 | python benchmarks/benchmark_generation_mamba_simple.py --model-name "state-spaces/mamba-2.8b" --prompt "My cat wrote all this CUDA code for a new language model and" --topp 0.9 --temperature 0.5 130 | python benchmarks/benchmark_generation_mamba_simple.py --model-name "EleutherAI/pythia-2.8b" --prompt "My cat wrote all this CUDA code for a new language model and" --topp 0.9 --temperature 0.5 131 | ``` 132 | 133 | To test generation throughput with random prompts (e.g. 
large batch size): 134 | ``` 135 | python benchmarks/benchmark_generation_mamba_simple.py --model-name "state-spaces/mamba-2.8b" --batch 128 136 | python benchmarks/benchmark_generation_mamba_simple.py --model-name "EleutherAI/pythia-2.8b" --batch 128 137 | ``` 138 | 139 | ## Citation 140 | 141 | If you use this codebase, or otherwise found our work valuable, please cite Mamba: 142 | ``` 143 | @article{mamba, 144 | title={Mamba: Linear-Time Sequence Modeling with Selective State Spaces}, 145 | author={Gu, Albert and Dao, Tri}, 146 | journal={arXiv preprint arXiv:2312.00752}, 147 | year={2023} 148 | } 149 | ``` 150 | -------------------------------------------------------------------------------- /diffusion/timestep_sampler.py: -------------------------------------------------------------------------------- 1 | # Modified from OpenAI's diffusion repos 2 | # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py 3 | # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion 4 | # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py 5 | 6 | from abc import ABC, abstractmethod 7 | 8 | import numpy as np 9 | import torch as th 10 | import torch.distributed as dist 11 | 12 | 13 | def create_named_schedule_sampler(name, diffusion): 14 | """ 15 | Create a ScheduleSampler from a library of pre-defined samplers. 16 | :param name: the name of the sampler. 17 | :param diffusion: the diffusion object to sample for. 18 | """ 19 | if name == "uniform": 20 | return UniformSampler(diffusion) 21 | elif name == "loss-second-moment": 22 | return LossSecondMomentResampler(diffusion) 23 | else: 24 | raise NotImplementedError(f"unknown schedule sampler: {name}") 25 | 26 | 27 | class ScheduleSampler(ABC): 28 | """ 29 | A distribution over timesteps in the diffusion process, intended to reduce 30 | variance of the objective. 31 | By default, samplers perform unbiased importance sampling, in which the 32 | objective's mean is unchanged. 33 | However, subclasses may override sample() to change how the resampled 34 | terms are reweighted, allowing for actual changes in the objective. 35 | """ 36 | 37 | @abstractmethod 38 | def weights(self): 39 | """ 40 | Get a numpy array of weights, one per diffusion step. 41 | The weights needn't be normalized, but must be positive. 42 | """ 43 | 44 | def sample(self, batch_size, device): 45 | """ 46 | Importance-sample timesteps for a batch. 47 | :param batch_size: the number of timesteps. 48 | :param device: the torch device to save to. 49 | :return: a tuple (timesteps, weights): 50 | - timesteps: a tensor of timestep indices. 51 | - weights: a tensor of weights to scale the resulting losses. 52 | """ 53 | w = self.weights() 54 | p = w / np.sum(w) 55 | indices_np = np.random.choice(len(p), size=(batch_size,), p=p) 56 | indices = th.from_numpy(indices_np).long().to(device) 57 | weights_np = 1 / (len(p) * p[indices_np]) 58 | weights = th.from_numpy(weights_np).float().to(device) 59 | return indices, weights 60 | 61 | 62 | class UniformSampler(ScheduleSampler): 63 | def __init__(self, diffusion): 64 | self.diffusion = diffusion 65 | self._weights = np.ones([diffusion.num_timesteps]) 66 | 67 | def weights(self): 68 | return self._weights 69 | 70 | 71 | class LossAwareSampler(ScheduleSampler): 72 | def update_with_local_losses(self, local_ts, local_losses): 73 | """ 74 | Update the reweighting using losses from a model. 
75 | Call this method from each rank with a batch of timesteps and the 76 | corresponding losses for each of those timesteps. 77 | This method will perform synchronization to make sure all of the ranks 78 | maintain the exact same reweighting. 79 | :param local_ts: an integer Tensor of timesteps. 80 | :param local_losses: a 1D Tensor of losses. 81 | """ 82 | batch_sizes = [ 83 | th.tensor([0], dtype=th.int32, device=local_ts.device) 84 | for _ in range(dist.get_world_size()) 85 | ] 86 | dist.all_gather( 87 | batch_sizes, 88 | th.tensor([len(local_ts)], dtype=th.int32, device=local_ts.device), 89 | ) 90 | 91 | # Pad all_gather batches to be the maximum batch size. 92 | batch_sizes = [x.item() for x in batch_sizes] 93 | max_bs = max(batch_sizes) 94 | 95 | timestep_batches = [th.zeros(max_bs).to(local_ts) for bs in batch_sizes] 96 | loss_batches = [th.zeros(max_bs).to(local_losses) for bs in batch_sizes] 97 | dist.all_gather(timestep_batches, local_ts) 98 | dist.all_gather(loss_batches, local_losses) 99 | timesteps = [ 100 | x.item() for y, bs in zip(timestep_batches, batch_sizes) for x in y[:bs] 101 | ] 102 | losses = [x.item() for y, bs in zip(loss_batches, batch_sizes) for x in y[:bs]] 103 | self.update_with_all_losses(timesteps, losses) 104 | 105 | @abstractmethod 106 | def update_with_all_losses(self, ts, losses): 107 | """ 108 | Update the reweighting using losses from a model. 109 | Sub-classes should override this method to update the reweighting 110 | using losses from the model. 111 | This method directly updates the reweighting without synchronizing 112 | between workers. It is called by update_with_local_losses from all 113 | ranks with identical arguments. Thus, it should have deterministic 114 | behavior to maintain state across workers. 115 | :param ts: a list of int timesteps. 116 | :param losses: a list of float losses, one per timestep. 117 | """ 118 | 119 | 120 | class LossSecondMomentResampler(LossAwareSampler): 121 | def __init__(self, diffusion, history_per_term=10, uniform_prob=0.001): 122 | self.diffusion = diffusion 123 | self.history_per_term = history_per_term 124 | self.uniform_prob = uniform_prob 125 | self._loss_history = np.zeros( 126 | [diffusion.num_timesteps, history_per_term], dtype=np.float64 127 | ) 128 | self._loss_counts = np.zeros([diffusion.num_timesteps], dtype=np.int) 129 | 130 | def weights(self): 131 | if not self._warmed_up(): 132 | return np.ones([self.diffusion.num_timesteps], dtype=np.float64) 133 | weights = np.sqrt(np.mean(self._loss_history ** 2, axis=-1)) 134 | weights /= np.sum(weights) 135 | weights *= 1 - self.uniform_prob 136 | weights += self.uniform_prob / len(weights) 137 | return weights 138 | 139 | def update_with_all_losses(self, ts, losses): 140 | for t, loss in zip(ts, losses): 141 | if self._loss_counts[t] == self.history_per_term: 142 | # Shift out the oldest loss term. 
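                # (The per-timestep history acts as a ring buffer of length history_per_term:
                # drop the oldest entry, shift the rest left, and write the new loss into the last slot.)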
143 | self._loss_history[t, :-1] = self._loss_history[t, 1:] 144 | self._loss_history[t, -1] = loss 145 | else: 146 | self._loss_history[t, self._loss_counts[t]] = loss 147 | self._loss_counts[t] += 1 148 | 149 | def _warmed_up(self): 150 | return (self._loss_counts == self.history_per_term).all() 151 | -------------------------------------------------------------------------------- /tools/webdataset.py: -------------------------------------------------------------------------------- 1 | import webdataset as wds 2 | import logging 3 | import torch 4 | import random 5 | import math 6 | import json,os,re 7 | import io 8 | 9 | from PIL import Image 10 | from torchvision import transforms 11 | from webdataset.tariterators import base_plus_ext, url_opener, tar_file_expander, valid_sample 12 | from functools import partial 13 | 14 | """ 15 | Set hyper-parameter for wds.shuffle 16 | """ 17 | 18 | _SHARD_SHUFFLE_SIZE = 2000 19 | _SHARD_SHUFFLE_INITIAL = 500 20 | _SAMPLE_SHUFFLE_SIZE = 5000 21 | _SAMPLE_SHUFFLE_INITIAL = 1000 22 | 23 | class WebdatasetFilter: 24 | def __init__(self, min_size=1024, max_pwatermark=0.5): 25 | self.min_size = min_size 26 | self.max_pwatermark = max_pwatermark 27 | 28 | def __call__(self, x): 29 | try: 30 | if "json" in x: 31 | x_json = json.loads(x["json"]) 32 | filter_size = (x_json.get("original_width", 0.0) or 0.0) >= self.min_size and x_json.get( 33 | "original_height", 0 34 | ) >= self.min_size 35 | filter_watermark = (x_json.get("pwatermark", 1.0) or 1.0) <= self.max_pwatermark 36 | return filter_size and filter_watermark 37 | else: 38 | return False 39 | except Exception: 40 | return False 41 | 42 | 43 | def filter_no_caption_or_no_image(sample): 44 | has_caption = ('txt' in sample) 45 | has_image = ('png' in sample or 'jpg' in sample or 'jpeg' in sample or 'webp' in sample) 46 | return has_caption and has_image 47 | 48 | def log_and_continue(exn): 49 | """Call in an exception handler to ignore any exception, issue a warning, and continue.""" 50 | logging.warning(f'Handling webdataset error ({repr(exn)}). Ignoring.') 51 | return True 52 | 53 | def group_by_keys_nothrow(data, keys=base_plus_ext, lcase=True, suffixes=None, handler=None): 54 | """Return function over iterator that groups key, value pairs into samples. 
55 | 56 | :param keys: function that splits the key into key and extension (base_plus_ext) 57 | :param lcase: convert suffixes to lower case (Default value = True) 58 | """ 59 | current_sample = None 60 | for filesample in data: 61 | assert isinstance(filesample, dict) 62 | fname, value = filesample["fname"], filesample["data"] 63 | prefix, suffix = keys(fname) 64 | if prefix is None: 65 | continue 66 | if lcase: 67 | suffix = suffix.lower() 68 | # FIXME webdataset version throws if suffix in current_sample, but we have a potential for 69 | # this happening in the current LAION400m dataset if a tar ends with same prefix as the next 70 | # begins, rare, but can happen since prefix aren't unique across tar files in that dataset 71 | if current_sample is None or prefix != current_sample["__key__"] or suffix in current_sample: 72 | if valid_sample(current_sample): 73 | yield current_sample 74 | current_sample = dict(__key__=prefix, __url__=filesample["__url__"]) 75 | if suffixes is None or suffix in suffixes: 76 | current_sample[suffix] = value 77 | if valid_sample(current_sample): 78 | yield current_sample 79 | 80 | def tarfile_to_samples_nothrow(src, handler=log_and_continue): 81 | # NOTE this is a re-impl of the webdataset impl with group_by_keys that doesn't throw 82 | streams = url_opener(src, handler=handler) 83 | files = tar_file_expander(streams, handler=handler) 84 | samples = group_by_keys_nothrow(files, handler=handler) 85 | 86 | return samples 87 | 88 | class wds_process: 89 | def __init__(self, transform=None): 90 | if transform == None: 91 | self.transform = transforms.Compose([ 92 | transforms.Resize((512, 512)), 93 | transforms.RandomHorizontalFlip(), 94 | transforms.ToTensor(), 95 | transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True) 96 | ]) 97 | else: 98 | self.transform = transform 99 | 100 | def __call__(self, sample): 101 | base64_str = sample['jpg'] 102 | img = Image.open(io.BytesIO(base64_str)).convert("RGB") 103 | img = self.transform(img) 104 | # img.save('1.png') 105 | json_line = sample['json'] 106 | text = json.loads(json_line)['caption'] 107 | return img, text 108 | 109 | class Webdataset_Vespa: 110 | def __init__( 111 | self, 112 | args, 113 | transform, 114 | world_size: int, 115 | num_train_examples: int, 116 | per_gpu_batch_size: int, 117 | global_batch_size: int, 118 | num_workers: int, 119 | pin_memory: bool = False, 120 | persistent_workers: bool = False, 121 | ): 122 | self.args = args 123 | tar_list = os.listdir(args.anna_path) 124 | urls = [os.path.join(args.anna_path, f) for f in tar_list] 125 | process = wds_process(transform) 126 | 127 | self.dataset = wds.DataPipeline( 128 | wds.SimpleShardList(urls), 129 | # at this point we have an iterator over all the shards 130 | wds.shuffle(len(urls)), 131 | # add wds.split_by_node here if you are using multiple nodes 132 | wds.split_by_worker, 133 | wds.split_by_node, 134 | # at this point, we have an iterator over the shards assigned to each worker 135 | wds.tarfile_to_samples(), 136 | # this shuffles the samples in memory 137 | wds.shuffle(1000), 138 | # this decodes the images and json 139 | wds.map(process), 140 | wds.shuffle(1000), 141 | wds.batched(int(args.global_batch_size // world_size),) 142 | ) 143 | 144 | num_worker_batches = math.ceil(num_train_examples / (global_batch_size * num_workers)) # per dataloader worker 145 | num_batches = num_worker_batches * num_workers 146 | num_samples = num_batches * global_batch_size 147 | 148 | self.dataloader = wds.WebLoader( 149 | 
self.dataset, 150 | batch_size=None, 151 | shuffle=False, 152 | num_workers=num_workers, 153 | pin_memory=pin_memory, 154 | persistent_workers=persistent_workers, 155 | ) 156 | # add meta-data to dataloader instance for convenience 157 | self.dataloader.num_batches = num_batches 158 | self.dataloader.num_train_images = num_samples 159 | 160 | @property 161 | def train_dataset(self): 162 | return self.dataset 163 | 164 | @property 165 | def train_dataloader(self): 166 | return self.dataloader -------------------------------------------------------------------------------- /mamba/mamba_ssm/ops/triton/selective_state_update.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023, Tri Dao. 2 | 3 | """We want triton==2.1.0 for this 4 | """ 5 | 6 | import math 7 | import torch 8 | import torch.nn.functional as F 9 | 10 | import triton 11 | import triton.language as tl 12 | 13 | from einops import rearrange, repeat 14 | 15 | 16 | @triton.heuristics({"HAS_DT_BIAS": lambda args: args["dt_bias_ptr"] is not None}) 17 | @triton.heuristics({"HAS_D": lambda args: args["D_ptr"] is not None}) 18 | @triton.heuristics({"HAS_Z": lambda args: args["z_ptr"] is not None}) 19 | @triton.heuristics({"BLOCK_SIZE_DSTATE": lambda args: triton.next_power_of_2(args["dstate"])}) 20 | @triton.jit 21 | def _selective_scan_update_kernel( 22 | # Pointers to matrices 23 | state_ptr, x_ptr, dt_ptr, dt_bias_ptr, A_ptr, B_ptr, C_ptr, D_ptr, z_ptr, out_ptr, 24 | # Matrix dimensions 25 | batch, dim, dstate, 26 | # Strides 27 | stride_state_batch, stride_state_dim, stride_state_dstate, 28 | stride_x_batch, stride_x_dim, 29 | stride_dt_batch, stride_dt_dim, 30 | stride_dt_bias_dim, 31 | stride_A_dim, stride_A_dstate, 32 | stride_B_batch, stride_B_dstate, 33 | stride_C_batch, stride_C_dstate, 34 | stride_D_dim, 35 | stride_z_batch, stride_z_dim, 36 | stride_out_batch, stride_out_dim, 37 | # Meta-parameters 38 | DT_SOFTPLUS: tl.constexpr, 39 | BLOCK_SIZE_M: tl.constexpr, 40 | HAS_DT_BIAS: tl.constexpr, 41 | HAS_D: tl.constexpr, 42 | HAS_Z: tl.constexpr, 43 | BLOCK_SIZE_DSTATE: tl.constexpr, 44 | ): 45 | pid_m = tl.program_id(axis=0) 46 | pid_b = tl.program_id(axis=1) 47 | state_ptr += pid_b * stride_state_batch 48 | x_ptr += pid_b * stride_x_batch 49 | dt_ptr += pid_b * stride_dt_batch 50 | B_ptr += pid_b * stride_B_batch 51 | C_ptr += pid_b * stride_C_batch 52 | if HAS_Z: 53 | z_ptr += pid_b * stride_z_batch 54 | out_ptr += pid_b * stride_out_batch 55 | 56 | offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) 57 | offs_n = tl.arange(0, BLOCK_SIZE_DSTATE) 58 | state_ptrs = state_ptr + (offs_m[:, None] * stride_state_dim + offs_n[None, :] * stride_state_dstate) 59 | x_ptrs = x_ptr + offs_m * stride_x_dim 60 | dt_ptrs = dt_ptr + offs_m * stride_dt_dim 61 | if HAS_DT_BIAS: 62 | dt_bias_ptrs = dt_bias_ptr + offs_m * stride_dt_bias_dim 63 | A_ptrs = A_ptr + (offs_m[:, None] * stride_A_dim + offs_n[None, :] * stride_A_dstate) 64 | B_ptrs = B_ptr + offs_n * stride_B_dstate 65 | C_ptrs = C_ptr + offs_n * stride_C_dstate 66 | if HAS_D: 67 | D_ptrs = D_ptr + offs_m * stride_D_dim 68 | if HAS_Z: 69 | z_ptrs = z_ptr + offs_m * stride_z_dim 70 | out_ptrs = out_ptr + offs_m * stride_out_dim 71 | 72 | state = tl.load(state_ptrs, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate), other=0.0) 73 | x = tl.load(x_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) 74 | dt = tl.load(dt_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) 75 | if HAS_DT_BIAS: 76 | dt += 
tl.load(dt_bias_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) 77 | if DT_SOFTPLUS: 78 | dt = tl.log(1.0 + tl.exp(dt)) 79 | A = tl.load(A_ptrs, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate), other=0.0).to(tl.float32) 80 | dA = tl.exp(A * dt[:, None]) 81 | B = tl.load(B_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32) 82 | C = tl.load(C_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32) 83 | if HAS_D: 84 | D = tl.load(D_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) 85 | if HAS_Z: 86 | z = tl.load(z_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) 87 | 88 | dB = B[None, :] * dt[:, None] 89 | state = state * dA + dB * x[:, None] 90 | tl.store(state_ptrs, state, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate)) 91 | out = tl.sum(state * C[None, :], axis=1) 92 | if HAS_D: 93 | out += x * D 94 | if HAS_Z: 95 | out *= z * tl.sigmoid(z) 96 | tl.store(out_ptrs, out, mask=offs_m < dim) 97 | 98 | 99 | def selective_state_update(state, x, dt, A, B, C, D=None, z=None, dt_bias=None, dt_softplus=False): 100 | """ 101 | Argument: 102 | state: (batch, dim, dstate) 103 | x: (batch, dim) 104 | dt: (batch, dim) 105 | A: (dim, dstate) 106 | B: (batch, dstate) 107 | C: (batch, dstate) 108 | D: (dim,) 109 | z: (batch, dim) 110 | dt_bias: (dim,) 111 | Return: 112 | out: (batch, dim) 113 | """ 114 | batch, dim, dstate = state.shape 115 | assert x.shape == (batch, dim) 116 | assert dt.shape == x.shape 117 | assert A.shape == (dim, dstate) 118 | assert B.shape == (batch, dstate) 119 | assert C.shape == B.shape 120 | if D is not None: 121 | assert D.shape == (dim,) 122 | if z is not None: 123 | assert z.shape == x.shape 124 | if dt_bias is not None: 125 | assert dt_bias.shape == (dim,) 126 | out = torch.empty_like(x) 127 | grid = lambda META: (triton.cdiv(dim, META['BLOCK_SIZE_M']), batch) 128 | z_strides = ((z.stride(0), z.stride(1)) if z is not None else (0, 0)) 129 | # We don't want autotune since it will overwrite the state 130 | # We instead tune by hand. 
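    # Heuristic used below: as dstate grows, shrink BLOCK_SIZE_M (the number of channels handled
    # per program) so each program's (BLOCK_SIZE_M, dstate) tile stays roughly the same size,
    # and only raise num_warps for the largest states.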
131 | BLOCK_SIZE_M, num_warps = ((32, 4) if dstate <= 16 132 | else ((16, 4) if dstate <= 32 else 133 | ((8, 4) if dstate <= 64 else 134 | ((4, 4) if dstate <= 128 else 135 | ((4, 8)))))) 136 | with torch.cuda.device(x.device.index): 137 | _selective_scan_update_kernel[grid]( 138 | state, x, dt, dt_bias, A, B, C, D, z, out, 139 | batch, dim, dstate, 140 | state.stride(0), state.stride(1), state.stride(2), 141 | x.stride(0), x.stride(1), 142 | dt.stride(0), dt.stride(1), 143 | dt_bias.stride(0) if dt_bias is not None else 0, 144 | A.stride(0), A.stride(1), 145 | B.stride(0), B.stride(1), 146 | C.stride(0), C.stride(1), 147 | D.stride(0) if D is not None else 0, 148 | z_strides[0], z_strides[1], 149 | out.stride(0), out.stride(1), 150 | dt_softplus, 151 | BLOCK_SIZE_M, 152 | num_warps=num_warps, 153 | ) 154 | return out 155 | 156 | 157 | def selective_state_update_ref(state, x, dt, A, B, C, D=None, z=None, dt_bias=None, dt_softplus=False): 158 | """ 159 | Argument: 160 | state: (batch, dim, dstate) 161 | x: (batch, dim) 162 | dt: (batch, dim) 163 | A: (dim, dstate) 164 | B: (batch, dstate) 165 | C: (batch, dstate) 166 | D: (dim,) 167 | z: (batch, dim) 168 | dt_bias: (dim,) 169 | Return: 170 | out: (batch, dim) 171 | """ 172 | batch, dim, dstate = state.shape 173 | assert x.shape == (batch, dim) 174 | assert dt.shape == x.shape 175 | assert A.shape == (dim, dstate) 176 | assert B.shape == (batch, dstate) 177 | assert C.shape == B.shape 178 | if D is not None: 179 | assert D.shape == (dim,) 180 | if z is not None: 181 | assert z.shape == x.shape 182 | if dt_bias is not None: 183 | assert dt_bias.shape == (dim,) 184 | dt = dt + dt_bias 185 | dt = F.softplus(dt) if dt_softplus else dt 186 | dA = torch.exp(rearrange(dt, "b d -> b d 1") * A) # (batch, dim, dstate) 187 | dB = rearrange(dt, "b d -> b d 1") * rearrange(B, "b n -> b 1 n") # (batch, dim, dstate) 188 | state.copy_(state * dA + dB * rearrange(x, "b d -> b d 1")) # (batch, dim, dstate 189 | out = torch.einsum("bdn,bn->bd", state.to(C.dtype), C) 190 | if D is not None: 191 | out += (x * D).to(out.dtype) 192 | return (out if z is None else out * F.silu(z)).to(x.dtype) 193 | -------------------------------------------------------------------------------- /causal-conv1d/tests/test_causal_conv1d.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2023, Tri Dao. 
2 | 3 | import math 4 | 5 | import torch 6 | import pytest 7 | 8 | from einops import rearrange 9 | 10 | from causal_conv1d.causal_conv1d_interface import causal_conv1d_fn, causal_conv1d_ref 11 | from causal_conv1d.causal_conv1d_interface import causal_conv1d_update, causal_conv1d_update_ref 12 | 13 | 14 | @pytest.mark.parametrize("channel_last", [False, True]) 15 | # @pytest.mark.parametrize('channel_last', [True]) 16 | @pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16]) 17 | # @pytest.mark.parametrize('itype', [torch.float16]) 18 | @pytest.mark.parametrize("silu_activation", [False, True]) 19 | # @pytest.mark.parametrize('silu_activation', [True]) 20 | @pytest.mark.parametrize("has_bias", [False, True]) 21 | # @pytest.mark.parametrize('has_bias', [True]) 22 | @pytest.mark.parametrize("width", [2, 3, 4]) 23 | # @pytest.mark.parametrize('width', [2]) 24 | @pytest.mark.parametrize( 25 | "seqlen", [8, 16, 32, 64, 128, 151, 256, 372, 512, 784, 1024, 1134, 2048, 4096] 26 | ) 27 | # @pytest.mark.parametrize('seqlen', [8, 16, 32, 64, 128, 256, 512, 784, 1024, 2048, 4096]) 28 | # @pytest.mark.parametrize('seqlen', [128]) 29 | def test_causal_conv1d(seqlen, width, has_bias, silu_activation, itype, channel_last): 30 | device = "cuda" 31 | rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3) 32 | if itype == torch.bfloat16: 33 | rtol, atol = 1e-2, 5e-2 34 | rtolw, atolw = (1e-3, 1e-3) 35 | # set seed 36 | torch.random.manual_seed(0) 37 | batch_size = 2 38 | # batch_size = 1 39 | dim = 4096 + 32 # Try dim not divisible by 64 40 | # dim = 64 41 | if not channel_last: 42 | x = torch.randn(batch_size, 4096 + dim + 64, seqlen, device=device, dtype=itype)[:, 4096:4096 + dim, :].requires_grad_() 43 | else: 44 | x = rearrange( 45 | torch.randn(batch_size, seqlen, 4096 + dim + 64, device=device, dtype=itype)[:, :, 4096:4096 + dim], "b s d -> b d s" 46 | ).requires_grad_() 47 | weight = torch.randn(dim, width, device=device, dtype=torch.float32, requires_grad=True) 48 | if has_bias: 49 | bias = torch.randn(dim, device=device, dtype=torch.float32, requires_grad=True) 50 | else: 51 | bias = None 52 | x_ref = x.detach().clone().requires_grad_() 53 | weight_ref = weight.detach().clone().requires_grad_() 54 | bias_ref = bias.detach().clone().requires_grad_() if bias is not None else None 55 | activation = None if not silu_activation else "silu" 56 | out = causal_conv1d_fn(x, weight, bias, activation=activation) 57 | out_ref = causal_conv1d_ref(x_ref, weight_ref, bias_ref, activation=activation) 58 | 59 | print(f"Output max diff: {(out - out_ref).abs().max().item()}") 60 | print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") 61 | assert torch.allclose(out, out_ref, rtol=rtol, atol=atol) 62 | 63 | g = torch.randn_like(out) 64 | out_ref.backward(g) 65 | out.backward(g) 66 | 67 | print(f"dx max diff: {(x.grad - x_ref.grad).abs().max().item()}") 68 | print(f"dweight max diff: {(weight.grad - weight_ref.grad).abs().max().item()}") 69 | if has_bias: 70 | print(f"dbias max diff: {(bias.grad - bias_ref.grad).abs().max().item()}") 71 | 72 | assert torch.allclose(x.grad, x_ref.grad.to(dtype=itype), rtol=rtol, atol=atol) 73 | assert torch.allclose(weight.grad, weight_ref.grad, rtol=rtolw, atol=atolw) 74 | if has_bias: 75 | assert torch.allclose(bias.grad, bias_ref.grad, rtol=rtolw, atol=atolw) 76 | 77 | 78 | @pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16]) 79 | # @pytest.mark.parametrize('itype', [torch.float16]) 80 | 
@pytest.mark.parametrize("silu_activation", [False, True]) 81 | # @pytest.mark.parametrize('silu_activation', [False]) 82 | @pytest.mark.parametrize("has_bias", [False, True]) 83 | # @pytest.mark.parametrize('has_bias', [True]) 84 | @pytest.mark.parametrize("width", [2, 3, 4]) 85 | # @pytest.mark.parametrize('width', [2]) 86 | @pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096]) 87 | # @pytest.mark.parametrize("dim", [2048]) 88 | def test_causal_conv1d_update(dim, width, has_bias, silu_activation, itype): 89 | device = "cuda" 90 | rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3) 91 | if itype == torch.bfloat16: 92 | rtol, atol = 1e-2, 5e-2 93 | rtolw, atolw = (1e-3, 1e-3) 94 | # set seed 95 | torch.random.manual_seed(0) 96 | batch_size = 2 97 | # batch_size = 1 98 | # dim = 64 99 | x = torch.randn(batch_size, dim, device=device, dtype=itype) 100 | conv_state = torch.randn(batch_size, dim, width, device=device, dtype=itype) 101 | weight = torch.randn(dim, width, device=device, dtype=torch.float32, requires_grad=True) 102 | if has_bias: 103 | bias = torch.randn(dim, device=device, dtype=torch.float32, requires_grad=True) 104 | else: 105 | bias = None 106 | conv_state_ref = conv_state.detach().clone() 107 | activation = None if not silu_activation else "silu" 108 | out = causal_conv1d_update(x, conv_state, weight, bias, activation=activation) 109 | out_ref = causal_conv1d_update_ref(x, conv_state_ref, weight, bias, activation=activation) 110 | 111 | print(f"Output max diff: {(out - out_ref).abs().max().item()}") 112 | print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") 113 | assert torch.equal(conv_state, conv_state_ref) 114 | assert torch.allclose(out, out_ref, rtol=rtol, atol=atol) 115 | 116 | 117 | # @pytest.mark.parametrize("channel_last", [False, True]) 118 | @pytest.mark.parametrize('channel_last', [True]) 119 | # @pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16]) 120 | @pytest.mark.parametrize('itype', [torch.bfloat16]) 121 | # @pytest.mark.parametrize("silu_activation", [False, True]) 122 | @pytest.mark.parametrize('silu_activation', [True]) 123 | # @pytest.mark.parametrize("has_bias", [False, True]) 124 | @pytest.mark.parametrize('has_bias', [True]) 125 | # @pytest.mark.parametrize("width", [2, 3, 4]) 126 | @pytest.mark.parametrize('width', [4]) 127 | @pytest.mark.parametrize( 128 | # "seqlen", [8, 16, 32, 64, 128, 151, 256, 372, 512, 784, 1024, 1134, 2048, 4096] 129 | "seqlen", [2048] 130 | ) 131 | # @pytest.mark.parametrize('seqlen', [8, 16, 32, 64, 128, 256, 512, 784, 1024, 2048, 4096]) 132 | # @pytest.mark.parametrize('seqlen', [128]) 133 | def test_causal_conv1d_race_condition(seqlen, width, has_bias, silu_activation, itype, channel_last): 134 | device = "cuda" 135 | # set seed 136 | torch.random.manual_seed(0) 137 | batch_size = 2 138 | # batch_size = 1 139 | dim = 4096 + 32 # Try dim not divisible by 64 140 | # dim = 64 141 | if not channel_last: 142 | x = torch.randn(batch_size, 4096 + dim + 64, seqlen, device=device, dtype=itype)[:, 4096:4096 + dim, :].requires_grad_() 143 | else: 144 | x = rearrange( 145 | torch.randn(batch_size, seqlen, 4096 + dim + 64, device=device, dtype=itype)[:, :, 4096:4096 + dim], "b s d -> b d s" 146 | ).requires_grad_() 147 | weight = torch.randn(dim, width, device=device, dtype=torch.float32, requires_grad=True) 148 | if has_bias: 149 | bias = torch.randn(dim, device=device, dtype=torch.float32, requires_grad=True) 150 | else: 151 | bias = None 152 | activation = None if not 
silu_activation else "silu" 153 | out0 = causal_conv1d_fn(x, weight, bias, activation=activation) 154 | g = torch.randn_like(out0) 155 | dx0, dw0, db0 = torch.autograd.grad(out0, (x, weight, bias), g) 156 | dw_atol = 1e-4 157 | db_atol = 1e-4 158 | 159 | for i in range(10000): 160 | out = causal_conv1d_fn(x, weight, bias, activation=activation) 161 | dx, dw, db = torch.autograd.grad(out, (x, weight, bias), g) 162 | dw_equal = torch.allclose(dw, dw0, atol=dw_atol) 163 | # if not dw_equal: 164 | # breakpoint() 165 | if has_bias: 166 | db_equal = torch.allclose(db, db0, atol=db_atol) 167 | # if not db_equal: 168 | # breakpoint() 169 | assert torch.equal(out, out0) 170 | assert torch.equal(dx, dx0) 171 | assert dw_equal 172 | if has_bias: 173 | assert dw_equal 174 | -------------------------------------------------------------------------------- /tools/dataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import einops 4 | import random 5 | import numpy as np 6 | import json 7 | from PIL import Image 8 | from random import choice 9 | 10 | from datasets import load_dataset 11 | from torch.utils.data import Dataset 12 | from datasets import load_dataset 13 | import torchvision.transforms as transforms 14 | from decord import VideoReader 15 | 16 | 17 | 18 | class CelebADataset(Dataset): 19 | def __init__(self, data_path, transform): 20 | data = load_dataset(data_path) 21 | self.data = data['train'] 22 | self.transform = transform 23 | 24 | def __len__(self): 25 | return len(self.data) 26 | 27 | def __getitem__(self, index): 28 | image = self.data[index]['image'].convert("RGB") 29 | return self.transform(image), torch.tensor(index).long() 30 | 31 | 32 | 33 | def center_crop(width, height, img): 34 | resample = {'box': Image.BOX, 'lanczos': Image.LANCZOS}['lanczos'] 35 | crop = np.min(img.shape[:2]) 36 | img = img[(img.shape[0] - crop) // 2: (img.shape[0] + crop) // 2, 37 | (img.shape[1] - crop) // 2: (img.shape[1] + crop) // 2] 38 | try: 39 | img = Image.fromarray(img, 'RGB') 40 | except: 41 | img = Image.fromarray(img) 42 | img = img.resize((width, height), resample) 43 | 44 | return np.array(img).astype(np.uint8) 45 | 46 | 47 | class MSCOCODataset(Dataset): 48 | def __init__(self, root, annFile, transform, ): 49 | from pycocotools.coco import COCO 50 | self.root = root 51 | 52 | self.coco = COCO(annFile) 53 | self.keys = list(sorted(self.coco.imgs.keys())) 54 | self.transform = transform 55 | 56 | def _load_image(self, key: int): 57 | path = self.coco.loadImgs(key)[0]["file_name"] 58 | return Image.open(os.path.join(self.root, path)).convert("RGB") 59 | 60 | def _load_target(self, key: int): 61 | return self.coco.loadAnns(self.coco.getAnnIds(key)) 62 | 63 | def __len__(self): 64 | return len(self.keys) 65 | 66 | def __getitem__(self, index): 67 | key = self.keys[index] 68 | image = self._load_image(key) 69 | image = self.transform(image) 70 | 71 | anns = self._load_target(key) 72 | target = [] 73 | for ann in anns: 74 | target.append(ann['caption']) 75 | 76 | return image, choice(target) 77 | 78 | 79 | import io 80 | class wds_process: 81 | def __init__(self, transform=None): 82 | if transform == None: 83 | self.transform = transforms.Compose([ 84 | transforms.Resize((512, 512)), 85 | transforms.RandomHorizontalFlip(), 86 | transforms.ToTensor(), 87 | transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True) 88 | ]) 89 | else: 90 | self.transform = transform 91 | 92 | def __call__(self, sample): 93 | 
base64_str = sample['jpg'] 94 | img = Image.open(io.BytesIO(base64_str)).convert("RGB") 95 | img = self.transform(img) 96 | # img.save('1.png') 97 | json_line = sample['json'] 98 | text = json.loads(json_line)['caption'] 99 | return img, text 100 | 101 | 102 | class MJDataset(Dataset): 103 | def __init__(self, path, transform): 104 | with open(path, 'r') as f: 105 | self.data = json.load(f) 106 | self.key_list = [key for key in self.data.keys()] 107 | self.transform = transform 108 | 109 | def __len__(self): 110 | return len(self.key_list) 111 | 112 | def __getitem__(self, index): 113 | img_path = self.key_list[index] 114 | img = Image.open(img_path).convert("RGB") 115 | img = self.transform(img) 116 | txt = self.data[img_path]['caption'] 117 | return img, txt 118 | 119 | 120 | 121 | 122 | class TagImageNetDataset(Dataset): 123 | def __init__( 124 | self, 125 | path, 126 | transform, 127 | ): 128 | with open(path, 'r') as f: 129 | self.data = json.load(f) 130 | self.transform = transform 131 | 132 | def __len__(self): 133 | return len(self.data) 134 | 135 | def __getitem__(self, index): 136 | img_path = self.data[index]['image'] 137 | img = Image.open(img_path).convert("RGB") 138 | img = self.transform(img) 139 | txt = self.data[index]['text'] 140 | return img, txt 141 | 142 | 143 | 144 | class UCFDataset(Dataset): 145 | def __init__( 146 | self, 147 | data_path, 148 | sample_size=64, 149 | sample_stride=4, 150 | sample_n_frames=8, 151 | is_image=True, 152 | ): 153 | with open(data_path, 'r') as f: 154 | self.dataset = json.load(f) 155 | 156 | self.sample_stride = sample_stride 157 | self.sample_n_frames = sample_n_frames 158 | 159 | self.is_image = is_image 160 | self.length = len(self.dataset) 161 | 162 | sample_size = tuple(sample_size) if not isinstance(sample_size, int) else (sample_size, sample_size) 163 | self.pixel_transforms = transforms.Compose([ 164 | transforms.RandomHorizontalFlip(), 165 | transforms.Resize(sample_size[0]), 166 | transforms.CenterCrop(sample_size), 167 | transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True), 168 | ]) 169 | 170 | def get_batch(self, idx): 171 | name = self.dataset[idx]['text'] 172 | video_reader = VideoReader(self.dataset[idx]['video']) 173 | video_length = len(video_reader) 174 | if not self.is_image: 175 | clip_length = min(video_length, (self.sample_n_frames - 1) * self.sample_stride + 1) 176 | start_idx = random.randint(0, video_length - clip_length) 177 | batch_index = np.linspace(start_idx, start_idx + clip_length - 1, self.sample_n_frames, dtype=int) 178 | else: 179 | batch_index = [random.randint(0, video_length - 1)] 180 | 181 | pixel_values = torch.from_numpy(video_reader.get_batch(batch_index).asnumpy()).permute(0, 3, 1, 2).contiguous() 182 | pixel_values = pixel_values / 255. 
183 | del video_reader 184 | 185 | if self.is_image: 186 | pixel_values = pixel_values[0] 187 | 188 | return pixel_values, name 189 | 190 | 191 | def __len__(self): 192 | return self.length 193 | 194 | def __getitem__(self, idx): 195 | pixel_values, name = self.get_batch(idx) 196 | pixel_values = self.pixel_transforms(pixel_values) 197 | return pixel_values, name 198 | 199 | 200 | 201 | 202 | class FaceDataset(Dataset): 203 | def __init__( 204 | self, 205 | data_path, 206 | sample_size=64, 207 | sample_stride=4, 208 | sample_n_frames=8, 209 | is_image=True, 210 | ): 211 | with open(data_path, 'r') as f: 212 | self.dataset = json.load(f) 213 | 214 | self.sample_stride = sample_stride 215 | self.sample_n_frames = sample_n_frames 216 | 217 | self.is_image = is_image 218 | self.length = len(self.dataset) 219 | 220 | sample_size = tuple(sample_size) if not isinstance(sample_size, int) else (sample_size, sample_size) 221 | self.pixel_transforms = transforms.Compose([ 222 | transforms.RandomHorizontalFlip(), 223 | transforms.Resize(sample_size[0]), 224 | transforms.CenterCrop(sample_size), 225 | transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True), 226 | ]) 227 | 228 | def get_batch(self, idx): 229 | name = self.dataset[idx]['text'] 230 | video_reader = VideoReader(self.dataset[idx]['video']) 231 | video_length = len(video_reader) 232 | if not self.is_image: 233 | clip_length = min(video_length, (self.sample_n_frames - 1) * self.sample_stride + 1) 234 | start_idx = random.randint(0, video_length - clip_length) 235 | batch_index = np.linspace(start_idx, start_idx + clip_length - 1, self.sample_n_frames, dtype=int) 236 | else: 237 | batch_index = [random.randint(0, video_length - 1)] 238 | 239 | pixel_values = torch.from_numpy(video_reader.get_batch(batch_index).asnumpy()).permute(0, 3, 1, 2).contiguous() 240 | pixel_values = pixel_values / 255. 241 | del video_reader 242 | 243 | if self.is_image: 244 | pixel_values = pixel_values[0] 245 | 246 | return pixel_values, name 247 | 248 | 249 | def __len__(self): 250 | return self.length 251 | 252 | def __getitem__(self, idx): 253 | pixel_values, name = self.get_batch(idx) 254 | pixel_values = self.pixel_transforms(pixel_values) 255 | return pixel_values, name 256 | 257 | 258 | -------------------------------------------------------------------------------- /mamba/csrc/selective_scan/selective_scan_common.h: -------------------------------------------------------------------------------- 1 | /****************************************************************************** 2 | * Copyright (c) 2023, Tri Dao. 
3 |  ******************************************************************************/
4 | 
5 | #pragma once
6 | 
7 | #include <cuda_bf16.h>
8 | #include <cuda_fp16.h>
9 | #include <c10/util/complex.h>  // For scalar_value_type
10 | 
11 | #define MAX_DSTATE 256
12 | 
13 | using complex_t = c10::complex<float>;
14 | 
15 | inline __device__ float2 operator+(const float2 & a, const float2 & b){
16 |     return {a.x + b.x, a.y + b.y};
17 | }
18 | 
19 | inline __device__ float3 operator+(const float3 &a, const float3 &b) {
20 |     return {a.x + b.x, a.y + b.y, a.z + b.z};
21 | }
22 | 
23 | inline __device__ float4 operator+(const float4 & a, const float4 & b){
24 |     return {a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w};
25 | }
26 | 
27 | ////////////////////////////////////////////////////////////////////////////////////////////////////
28 | 
29 | template<int BYTES> struct BytesToType {};
30 | 
31 | template<> struct BytesToType<16> {
32 |     using Type = uint4;
33 |     static_assert(sizeof(Type) == 16);
34 | };
35 | 
36 | template<> struct BytesToType<8> {
37 |     using Type = uint64_t;
38 |     static_assert(sizeof(Type) == 8);
39 | };
40 | 
41 | template<> struct BytesToType<4> {
42 |     using Type = uint32_t;
43 |     static_assert(sizeof(Type) == 4);
44 | };
45 | 
46 | template<> struct BytesToType<2> {
47 |     using Type = uint16_t;
48 |     static_assert(sizeof(Type) == 2);
49 | };
50 | 
51 | template<> struct BytesToType<1> {
52 |     using Type = uint8_t;
53 |     static_assert(sizeof(Type) == 1);
54 | };
55 | 
56 | ////////////////////////////////////////////////////////////////////////////////////////////////////
57 | 
58 | template<typename scalar_t, int N>
59 | struct Converter{
60 |     static inline __device__ void to_float(const scalar_t (&src)[N], float (&dst)[N]) {
61 |         #pragma unroll
62 |         for (int i = 0; i < N; ++i) { dst[i] = src[i]; }
63 |     }
64 | };
65 | 
66 | template<int N>
67 | struct Converter<at::Half, N>{
68 |     static inline __device__ void to_float(const at::Half (&src)[N], float (&dst)[N]) {
69 |         static_assert(N % 2 == 0);
70 |         auto &src2 = reinterpret_cast<const half2 (&)[N / 2]>(src);
71 |         auto &dst2 = reinterpret_cast<float2 (&)[N / 2]>(dst);
72 |         #pragma unroll
73 |         for (int i = 0; i < N / 2; ++i) { dst2[i] = __half22float2(src2[i]); }
74 |     }
75 | };
76 | 
77 | #if __CUDA_ARCH__ >= 800
78 | template<int N>
79 | struct Converter<at::BFloat16, N>{
80 |     static inline __device__ void to_float(const at::BFloat16 (&src)[N], float (&dst)[N]) {
81 |         static_assert(N % 2 == 0);
82 |         auto &src2 = reinterpret_cast<const nv_bfloat162 (&)[N / 2]>(src);
83 |         auto &dst2 = reinterpret_cast<float2 (&)[N / 2]>(dst);
84 |         #pragma unroll
85 |         for (int i = 0; i < N / 2; ++i) { dst2[i] = __bfloat1622float2(src2[i]); }
86 |     }
87 | };
88 | #endif
89 | 
90 | ////////////////////////////////////////////////////////////////////////////////////////////////////
91 | 
92 | // From https://stackoverflow.com/questions/9860711/cucomplex-h-and-exp
93 | // and https://forums.developer.nvidia.com/t/complex-number-exponential-function/24696
94 | __device__ __forceinline__ complex_t cexp2f(complex_t z) {
95 |     float t = exp2f(z.real_);
96 |     float c, s;
97 |     sincosf(z.imag_, &s, &c);
98 |     return complex_t(c * t, s * t);
99 | }
100 | 
101 | __device__ __forceinline__ complex_t cexpf(complex_t z) {
102 |     float t = expf(z.real_);
103 |     float c, s;
104 |     sincosf(z.imag_, &s, &c);
105 |     return complex_t(c * t, s * t);
106 | }
107 | 
108 | template<typename scalar_t> struct SSMScanOp;
109 | 
110 | template<>
111 | struct SSMScanOp<float> {
112 |     __device__ __forceinline__ float2 operator()(const float2 &ab0, const float2 &ab1) const {
113 |         return make_float2(ab1.x * ab0.x, ab1.x * ab0.y + ab1.y);
114 |     }
115 | };
116 | 
117 | template<>
118 | struct SSMScanOp<complex_t> {
119 |     __device__ __forceinline__ float4 operator()(const float4 &ab0, const float4 &ab1) const {
120 |         complex_t a0 = complex_t(ab0.x, ab0.y);
121 |         complex_t b0 = complex_t(ab0.z, ab0.w);
122 |         complex_t a1 = complex_t(ab1.x, ab1.y);
123 |         complex_t b1 = complex_t(ab1.z, ab1.w);
124 |         complex_t out_a = a1 * a0;
125 |         complex_t out_b = a1 * b0 + b1;
126 |         return make_float4(out_a.real_, out_a.imag_, out_b.real_, out_b.imag_);
127 |     }
128 | };
129 | 
130 | // A stateful callback functor that maintains a running prefix to be applied
131 | // during consecutive scan operations.
132 | template <typename scalar_t> struct SSMScanPrefixCallbackOp {
133 |     using scan_t = std::conditional_t<std::is_same_v<scalar_t, float>, float2, float4>;
134 |     scan_t running_prefix;
135 |     // Constructor
136 |     __device__ SSMScanPrefixCallbackOp(scan_t running_prefix_) : running_prefix(running_prefix_) {}
137 |     // Callback operator to be entered by the first warp of threads in the block.
138 |     // Thread-0 is responsible for returning a value for seeding the block-wide scan.
139 |     __device__ scan_t operator()(scan_t block_aggregate) {
140 |         scan_t old_prefix = running_prefix;
141 |         running_prefix = SSMScanOp<scalar_t>()(running_prefix, block_aggregate);
142 |         return old_prefix;
143 |     }
144 | };
145 | 
146 | ////////////////////////////////////////////////////////////////////////////////////////////////////
147 | 
148 | template<typename Ktraits>
149 | inline __device__ void load_input(typename Ktraits::input_t *u,
150 |                                   typename Ktraits::input_t (&u_vals)[Ktraits::kNItems],
151 |                                   typename Ktraits::BlockLoadT::TempStorage &smem_load,
152 |                                   int seqlen) {
153 |     if constexpr (Ktraits::kIsEvenLen) {
154 |         auto& smem_load_vec = reinterpret_cast<typename Ktraits::BlockLoadVecT::TempStorage&>(smem_load);
155 |         using vec_t = typename Ktraits::vec_t;
156 |         Ktraits::BlockLoadVecT(smem_load_vec).Load(
157 |             reinterpret_cast<vec_t*>(u),
158 |             reinterpret_cast<vec_t(&)[Ktraits::kNLoads]>(u_vals)
159 |         );
160 |     } else {
161 |         Ktraits::BlockLoadT(smem_load).Load(u, u_vals, seqlen, 0.f);
162 |     }
163 | }
164 | 
165 | template<typename Ktraits>
166 | inline __device__ void load_weight(typename Ktraits::input_t *Bvar,
167 |                                    typename Ktraits::weight_t (&B_vals)[Ktraits::kNItems],
168 |                                    typename Ktraits::BlockLoadWeightT::TempStorage &smem_load_weight,
169 |                                    int seqlen) {
170 |     constexpr int kNItems = Ktraits::kNItems;
171 |     if constexpr (!Ktraits::kIsComplex) {
172 |         typename Ktraits::input_t B_vals_load[kNItems];
173 |         if constexpr (Ktraits::kIsEvenLen) {
174 |             auto& smem_load_weight_vec = reinterpret_cast<typename Ktraits::BlockLoadWeightVecT::TempStorage&>(smem_load_weight);
175 |             using vec_t = typename Ktraits::vec_t;
176 |             Ktraits::BlockLoadWeightVecT(smem_load_weight_vec).Load(
177 |                 reinterpret_cast<vec_t*>(Bvar),
178 |                 reinterpret_cast<vec_t(&)[Ktraits::kNLoads]>(B_vals_load)
179 |             );
180 |         } else {
181 |             Ktraits::BlockLoadWeightT(smem_load_weight).Load(Bvar, B_vals_load, seqlen, 0.f);
182 |         }
183 |         // #pragma unroll
184 |         // for (int i = 0; i < kNItems; ++i) { B_vals[i] = B_vals_load[i]; }
185 |         Converter<typename Ktraits::input_t, kNItems>::to_float(B_vals_load, B_vals);
186 |     } else {
187 |         typename Ktraits::input_t B_vals_load[kNItems * 2];
188 |         if constexpr (Ktraits::kIsEvenLen) {
189 |             auto& smem_load_weight_vec = reinterpret_cast<typename Ktraits::BlockLoadWeightVecT::TempStorage&>(smem_load_weight);
190 |             using vec_t = typename Ktraits::vec_t;
191 |             Ktraits::BlockLoadWeightVecT(smem_load_weight_vec).Load(
192 |                 reinterpret_cast<vec_t*>(Bvar),
193 |                 reinterpret_cast<vec_t(&)[Ktraits::kNLoads * 2]>(B_vals_load)
194 |             );
195 |         } else {
196 |             Ktraits::BlockLoadWeightT(smem_load_weight).Load(Bvar, B_vals_load, seqlen, 0.f);
197 |         }
198 |         #pragma unroll
199 |         for (int i = 0; i < kNItems; ++i) { B_vals[i] = complex_t(B_vals_load[i * 2], B_vals_load[i * 2 + 1]); }
200 |     }
201 | }
202 | 
203 | template<typename Ktraits>
204 | inline __device__ void store_output(typename Ktraits::input_t *out,
205 |                                     const float (&out_vals)[Ktraits::kNItems],
206 |                                     typename Ktraits::BlockStoreT::TempStorage &smem_store,
207 |                                     int seqlen) {
208 |     typename Ktraits::input_t write_vals[Ktraits::kNItems];
209 |     #pragma unroll
210 |     for (int i = 0; i < Ktraits::kNItems; ++i) { write_vals[i] = out_vals[i]; }
211 |     if constexpr (Ktraits::kIsEvenLen) {
212 |         auto& smem_store_vec = reinterpret_cast<typename Ktraits::BlockStoreVecT::TempStorage&>(smem_store);
213 |         using vec_t = typename Ktraits::vec_t;
214 |         Ktraits::BlockStoreVecT(smem_store_vec).Store(
215 |             reinterpret_cast<vec_t*>(out),
216 |             reinterpret_cast<vec_t(&)[Ktraits::kNLoads]>(write_vals)
217 |         );
218 |     } else {
219 |         Ktraits::BlockStoreT(smem_store).Store(out, write_vals, seqlen);
220 |     }
221 | }
222 | 
-------------------------------------------------------------------------------- /mamba/mamba_ssm/models/mixer_seq_simple.py: --------------------------------------------------------------------------------
1 | # Copyright (c) 2023, Albert Gu, Tri Dao.
2 | 
3 | import math
4 | from functools import partial
5 | 
6 | from collections import namedtuple
7 | 
8 | import torch
9 | import torch.nn as nn
10 | 
11 | from mamba_ssm.modules.mamba_simple import Mamba, Block
12 | from mamba_ssm.utils.generation import GenerationMixin
13 | from mamba_ssm.utils.hf import load_config_hf, load_state_dict_hf
14 | 
15 | try:
16 |     from mamba_ssm.ops.triton.layernorm import RMSNorm, layer_norm_fn, rms_norm_fn
17 | except ImportError:
18 |     RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None
19 | 
20 | 
21 | def create_block(
22 |     d_model,
23 |     ssm_cfg=None,
24 |     norm_epsilon=1e-5,
25 |     rms_norm=False,
26 |     residual_in_fp32=False,
27 |     fused_add_norm=False,
28 |     layer_idx=None,
29 |     device=None,
30 |     dtype=None,
31 | ):
32 |     if ssm_cfg is None:
33 |         ssm_cfg = {}
34 |     factory_kwargs = {"device": device, "dtype": dtype}
35 |     mixer_cls = partial(Mamba, layer_idx=layer_idx, **ssm_cfg, **factory_kwargs)
36 |     norm_cls = partial(
37 |         nn.LayerNorm if not rms_norm else RMSNorm, eps=norm_epsilon, **factory_kwargs
38 |     )
39 |     block = Block(
40 |         d_model,
41 |         mixer_cls,
42 |         norm_cls=norm_cls,
43 |         fused_add_norm=fused_add_norm,
44 |         residual_in_fp32=residual_in_fp32,
45 |     )
46 |     block.layer_idx = layer_idx
47 |     return block
48 | 
49 | 
50 | # https://github.com/huggingface/transformers/blob/c28d04e9e252a1a099944e325685f14d242ecdcd/src/transformers/models/gpt2/modeling_gpt2.py#L454
51 | def _init_weights(
52 |     module,
53 |     n_layer,
54 |     initializer_range=0.02,  # Now only used for embedding layer.
55 |     rescale_prenorm_residual=True,
56 |     n_residuals_per_layer=1,  # Change to 2 if we have MLP
57 | ):
58 |     if isinstance(module, nn.Linear):
59 |         if module.bias is not None:
60 |             if not getattr(module.bias, "_no_reinit", False):
61 |                 nn.init.zeros_(module.bias)
62 |     elif isinstance(module, nn.Embedding):
63 |         nn.init.normal_(module.weight, std=initializer_range)
64 | 
65 |     if rescale_prenorm_residual:
66 |         # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
67 |         # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
68 |         # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
69 | # > -- GPT-2 :: https://openai.com/blog/better-language-models/ 70 | # 71 | # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py 72 | for name, p in module.named_parameters(): 73 | if name in ["out_proj.weight", "fc2.weight"]: 74 | # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block 75 | # Following Pytorch init, except scale by 1/sqrt(2 * n_layer) 76 | # We need to reinit p since this code could be called multiple times 77 | # Having just p *= scale would repeatedly scale it down 78 | nn.init.kaiming_uniform_(p, a=math.sqrt(5)) 79 | with torch.no_grad(): 80 | p /= math.sqrt(n_residuals_per_layer * n_layer) 81 | 82 | 83 | class MixerModel(nn.Module): 84 | def __init__( 85 | self, 86 | d_model: int, 87 | n_layer: int, 88 | vocab_size: int, 89 | ssm_cfg=None, 90 | norm_epsilon: float = 1e-5, 91 | rms_norm: bool = False, 92 | initializer_cfg=None, 93 | fused_add_norm=False, 94 | residual_in_fp32=False, 95 | device=None, 96 | dtype=None, 97 | ) -> None: 98 | factory_kwargs = {"device": device, "dtype": dtype} 99 | super().__init__() 100 | self.residual_in_fp32 = residual_in_fp32 101 | 102 | self.embedding = nn.Embedding(vocab_size, d_model, **factory_kwargs) 103 | 104 | # We change the order of residual and layer norm: 105 | # Instead of LN -> Attn / MLP -> Add, we do: 106 | # Add -> LN -> Attn / MLP / Mixer, returning both the residual branch (output of Add) and 107 | # the main branch (output of MLP / Mixer). The model definition is unchanged. 108 | # This is for performance reason: we can fuse add + layer_norm. 109 | self.fused_add_norm = fused_add_norm 110 | if self.fused_add_norm: 111 | if layer_norm_fn is None or rms_norm_fn is None: 112 | raise ImportError("Failed to import Triton LayerNorm / RMSNorm kernels") 113 | 114 | self.layers = nn.ModuleList( 115 | [ 116 | create_block( 117 | d_model, 118 | ssm_cfg=ssm_cfg, 119 | norm_epsilon=norm_epsilon, 120 | rms_norm=rms_norm, 121 | residual_in_fp32=residual_in_fp32, 122 | fused_add_norm=fused_add_norm, 123 | layer_idx=i, 124 | **factory_kwargs, 125 | ) 126 | for i in range(n_layer) 127 | ] 128 | ) 129 | 130 | self.norm_f = (nn.LayerNorm if not rms_norm else RMSNorm)( 131 | d_model, eps=norm_epsilon, **factory_kwargs 132 | ) 133 | 134 | self.apply( 135 | partial( 136 | _init_weights, 137 | n_layer=n_layer, 138 | **(initializer_cfg if initializer_cfg is not None else {}), 139 | ) 140 | ) 141 | 142 | def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs): 143 | return { 144 | i: layer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs) 145 | for i, layer in enumerate(self.layers) 146 | } 147 | 148 | def forward(self, input_ids, inference_params=None): 149 | hidden_states = self.embedding(input_ids) 150 | residual = None 151 | for layer in self.layers: 152 | hidden_states, residual = layer( 153 | hidden_states, residual, inference_params=inference_params 154 | ) 155 | if not self.fused_add_norm: 156 | residual = (hidden_states + residual) if residual is not None else hidden_states 157 | hidden_states = self.norm_f(residual.to(dtype=self.norm_f.weight.dtype)) 158 | else: 159 | # Set prenorm=False here since we don't need the residual 160 | fused_add_norm_fn = rms_norm_fn if isinstance(self.norm_f, RMSNorm) else layer_norm_fn 161 | hidden_states = fused_add_norm_fn( 162 | hidden_states, 163 | self.norm_f.weight, 164 | self.norm_f.bias, 165 | eps=self.norm_f.eps, 166 | residual=residual, 167 | prenorm=False, 168 | 
residual_in_fp32=self.residual_in_fp32, 169 | ) 170 | return hidden_states 171 | 172 | 173 | class MambaLMHeadModel(nn.Module, GenerationMixin): 174 | 175 | def __init__( 176 | self, 177 | d_model: int, 178 | n_layer: int, 179 | vocab_size: int, 180 | initializer_cfg=None, 181 | pad_vocab_size_multiple: int = 1, 182 | device=None, 183 | dtype=None, 184 | **backbone_kwargs, 185 | ) -> None: 186 | factory_kwargs = {"device": device, "dtype": dtype} 187 | super().__init__() 188 | if vocab_size % pad_vocab_size_multiple != 0: 189 | vocab_size += pad_vocab_size_multiple - (vocab_size % pad_vocab_size_multiple) 190 | self.backbone = MixerModel( 191 | d_model=d_model, 192 | n_layer=n_layer, 193 | vocab_size=vocab_size, 194 | initializer_cfg=initializer_cfg, 195 | **backbone_kwargs, 196 | **factory_kwargs, 197 | ) 198 | self.lm_head = nn.Linear(d_model, vocab_size, bias=False, **factory_kwargs) 199 | 200 | # Initialize weights and apply final processing 201 | self.apply( 202 | partial( 203 | _init_weights, 204 | n_layer=n_layer, 205 | **(initializer_cfg if initializer_cfg is not None else {}), 206 | ) 207 | ) 208 | self.tie_weights() 209 | 210 | def tie_weights(self): 211 | self.lm_head.weight = self.backbone.embedding.weight 212 | 213 | def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs): 214 | return self.backbone.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs) 215 | 216 | def forward(self, input_ids, position_ids=None, inference_params=None, num_last_tokens=0): 217 | """ 218 | "position_ids" is just to be compatible with Transformer generation. We don't use it. 219 | num_last_tokens: if > 0, only return the logits for the last n tokens 220 | """ 221 | hidden_states = self.backbone(input_ids, inference_params=inference_params) 222 | if num_last_tokens > 0: 223 | hidden_states = hidden_states[:, -num_last_tokens:] 224 | lm_logits = self.lm_head(hidden_states) 225 | CausalLMOutput = namedtuple("CausalLMOutput", ["logits"]) 226 | return CausalLMOutput(logits=lm_logits) 227 | 228 | @classmethod 229 | def from_pretrained(cls, pretrained_model_name, device=None, dtype=None, **kwargs): 230 | config = load_config_hf(pretrained_model_name) 231 | model = cls(**config, device=device, dtype=dtype, **kwargs) 232 | model.load_state_dict(load_state_dict_hf(pretrained_model_name, device=device, dtype=dtype)) 233 | return model 234 | -------------------------------------------------------------------------------- /tools/fid_score.py: -------------------------------------------------------------------------------- 1 | """Calculates the Frechet Inception Distance (FID) to evalulate GANs 2 | 3 | The FID metric calculates the distance between two distributions of images. 4 | Typically, we have summary statistics (mean & covariance matrix) of one 5 | of these distributions, while the 2nd distribution is given by a GAN. 6 | 7 | When run as a stand-alone program, it compares the distribution of 8 | images that are stored as PNG/JPEG at a specified location with a 9 | distribution given by summary statistics (in pickle format). 10 | 11 | The FID is calculated by assuming that X_1 and X_2 are the activations of 12 | the pool_3 layer of the inception net for generated samples and real world 13 | samples respectively. 14 | 15 | See --help to see further details. 
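A minimal programmatic sketch of the same computation (both paths are
placeholders for directories of real and generated images):

    from tools.fid_score import calculate_fid_given_paths
    fid_value = calculate_fid_given_paths(('path/to/real_images', 'path/to/generated_images'))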
16 | 17 | Code apapted from https://github.com/bioinf-jku/TTUR to use PyTorch instead 18 | of Tensorflow 19 | 20 | Copyright 2018 Institute of Bioinformatics, JKU Linz 21 | 22 | Licensed under the Apache License, Version 2.0 (the "License"); 23 | you may not use this file except in compliance with the License. 24 | You may obtain a copy of the License at 25 | 26 | http://www.apache.org/licenses/LICENSE-2.0 27 | 28 | Unless required by applicable law or agreed to in writing, software 29 | distributed under the License is distributed on an "AS IS" BASIS, 30 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 31 | See the License for the specific language governing permissions and 32 | limitations under the License. 33 | """ 34 | import os 35 | import pathlib 36 | 37 | import numpy as np 38 | import torch 39 | import torchvision.transforms as TF 40 | from PIL import Image 41 | from scipy import linalg 42 | from torch.nn.functional import adaptive_avg_pool2d 43 | 44 | try: 45 | from tqdm import tqdm 46 | except ImportError: 47 | # If tqdm is not available, provide a mock version of it 48 | def tqdm(x): 49 | return x 50 | 51 | from .inception import InceptionV3 52 | 53 | 54 | IMAGE_EXTENSIONS = {'bmp', 'jpg', 'jpeg', 'pgm', 'png', 'ppm', 55 | 'tif', 'tiff', 'webp'} 56 | 57 | 58 | class ImagePathDataset(torch.utils.data.Dataset): 59 | def __init__(self, files, transforms=None): 60 | self.files = files 61 | self.transforms = transforms 62 | 63 | def __len__(self): 64 | return len(self.files) 65 | 66 | def __getitem__(self, i): 67 | path = self.files[i] 68 | img = Image.open(path).convert('RGB') 69 | if self.transforms is not None: 70 | img = self.transforms(img) 71 | return img 72 | 73 | 74 | def get_activations(files, model, batch_size=50, dims=2048, device='cpu', num_workers=8): 75 | """Calculates the activations of the pool_3 layer for all images. 76 | 77 | Params: 78 | -- files : List of image files paths 79 | -- model : Instance of inception model 80 | -- batch_size : Batch size of images for the model to process at once. 81 | Make sure that the number of samples is a multiple of 82 | the batch size, otherwise some samples are ignored. This 83 | behavior is retained to match the original FID score 84 | implementation. 85 | -- dims : Dimensionality of features returned by Inception 86 | -- device : Device to run calculations 87 | -- num_workers : Number of parallel dataloader workers 88 | 89 | Returns: 90 | -- A numpy array of dimension (num images, dims) that contains the 91 | activations of the given tensor when feeding inception with the 92 | query tensor. 93 | """ 94 | model.eval() 95 | 96 | if batch_size > len(files): 97 | print(('Warning: batch size is bigger than the data size. ' 98 | 'Setting batch size to data size')) 99 | batch_size = len(files) 100 | 101 | dataset = ImagePathDataset(files, transforms=TF.ToTensor()) 102 | dataloader = torch.utils.data.DataLoader(dataset, 103 | batch_size=batch_size, 104 | shuffle=False, 105 | drop_last=False, 106 | num_workers=num_workers) 107 | 108 | pred_arr = np.empty((len(files), dims)) 109 | 110 | start_idx = 0 111 | 112 | for batch in tqdm(dataloader): 113 | batch = batch.to(device) 114 | 115 | with torch.no_grad(): 116 | pred = model(batch)[0] 117 | 118 | # If model output is not scalar, apply global spatial average pooling. 119 | # This happens if you choose a dimensionality not equal 2048. 
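        # Shape note: `pred` is (batch_size, dims, h, w) here; after the optional pooling
        # and the squeezes below it becomes (batch_size, dims), one row per image in `pred_arr`.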
120 | if pred.size(2) != 1 or pred.size(3) != 1: 121 | pred = adaptive_avg_pool2d(pred, output_size=(1, 1)) 122 | 123 | pred = pred.squeeze(3).squeeze(2).cpu().numpy() 124 | 125 | pred_arr[start_idx:start_idx + pred.shape[0]] = pred 126 | 127 | start_idx = start_idx + pred.shape[0] 128 | 129 | return pred_arr 130 | 131 | 132 | def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6): 133 | """Numpy implementation of the Frechet Distance. 134 | The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1) 135 | and X_2 ~ N(mu_2, C_2) is 136 | d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)). 137 | 138 | Stable version by Dougal J. Sutherland. 139 | 140 | Params: 141 | -- mu1 : Numpy array containing the activations of a layer of the 142 | inception net (like returned by the function 'get_predictions') 143 | for generated samples. 144 | -- mu2 : The sample mean over activations, precalculated on an 145 | representative data set. 146 | -- sigma1: The covariance matrix over activations for generated samples. 147 | -- sigma2: The covariance matrix over activations, precalculated on an 148 | representative data set. 149 | 150 | Returns: 151 | -- : The Frechet Distance. 152 | """ 153 | 154 | mu1 = np.atleast_1d(mu1) 155 | mu2 = np.atleast_1d(mu2) 156 | 157 | sigma1 = np.atleast_2d(sigma1) 158 | sigma2 = np.atleast_2d(sigma2) 159 | 160 | assert mu1.shape == mu2.shape, \ 161 | 'Training and test mean vectors have different lengths' 162 | assert sigma1.shape == sigma2.shape, \ 163 | 'Training and test covariances have different dimensions' 164 | 165 | diff = mu1 - mu2 166 | 167 | # Product might be almost singular 168 | covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False) 169 | if not np.isfinite(covmean).all(): 170 | msg = ('fid calculation produces singular product; ' 171 | 'adding %s to diagonal of cov estimates') % eps 172 | print(msg) 173 | offset = np.eye(sigma1.shape[0]) * eps 174 | covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset)) 175 | 176 | # Numerical error might give slight imaginary component 177 | if np.iscomplexobj(covmean): 178 | if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3): 179 | m = np.max(np.abs(covmean.imag)) 180 | raise ValueError('Imaginary component {}'.format(m)) 181 | covmean = covmean.real 182 | 183 | tr_covmean = np.trace(covmean) 184 | 185 | return (diff.dot(diff) + np.trace(sigma1) 186 | + np.trace(sigma2) - 2 * tr_covmean) 187 | 188 | 189 | def calculate_activation_statistics(files, model, batch_size=50, dims=2048, 190 | device='cpu', num_workers=8): 191 | """Calculation of the statistics used by the FID. 192 | Params: 193 | -- files : List of image files paths 194 | -- model : Instance of inception model 195 | -- batch_size : The images numpy array is split into batches with 196 | batch size batch_size. A reasonable batch size 197 | depends on the hardware. 198 | -- dims : Dimensionality of features returned by Inception 199 | -- device : Device to run calculations 200 | -- num_workers : Number of parallel dataloader workers 201 | 202 | Returns: 203 | -- mu : The mean over samples of the activations of the pool_3 layer of 204 | the inception model. 205 | -- sigma : The covariance matrix of the activations of the pool_3 layer of 206 | the inception model. 
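    A minimal usage sketch (assuming `files` and an InceptionV3 `model` have already
    been built, as in compute_statistics_of_path below):

        mu, sigma = calculate_activation_statistics(files, model)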
207 | """ 208 | act = get_activations(files, model, batch_size, dims, device, num_workers) 209 | mu = np.mean(act, axis=0) 210 | sigma = np.cov(act, rowvar=False) 211 | return mu, sigma 212 | 213 | 214 | def compute_statistics_of_path(path, model, batch_size, dims, device, num_workers=8): 215 | if path.endswith('.npz'): 216 | with np.load(path) as f: 217 | m, s = f['mu'][:], f['sigma'][:] 218 | else: 219 | path = pathlib.Path(path) 220 | files = sorted([file for ext in IMAGE_EXTENSIONS 221 | for file in path.glob('*.{}'.format(ext))]) 222 | m, s = calculate_activation_statistics(files, model, batch_size, 223 | dims, device, num_workers) 224 | 225 | return m, s 226 | 227 | 228 | def save_statistics_of_path(path, out_path, device=None, batch_size=50, dims=2048, num_workers=8): 229 | if device is None: 230 | device = torch.device('cuda' if (torch.cuda.is_available()) else 'cpu') 231 | else: 232 | device = torch.device(device) 233 | block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims] 234 | model = InceptionV3([block_idx]).to(device) 235 | m1, s1 = compute_statistics_of_path(path, model, batch_size, dims, device, num_workers) 236 | np.savez(out_path, mu=m1, sigma=s1) 237 | 238 | 239 | def calculate_fid_given_paths(paths, device=None, batch_size=50, dims=2048, num_workers=8): 240 | """Calculates the FID of two paths""" 241 | if device is None: 242 | device = torch.device('cuda' if (torch.cuda.is_available()) else 'cpu') 243 | else: 244 | device = torch.device(device) 245 | 246 | for p in paths: 247 | if not os.path.exists(p): 248 | raise RuntimeError('Invalid path: %s' % p) 249 | 250 | block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims] 251 | 252 | model = InceptionV3([block_idx]).to(device) 253 | 254 | m1, s1 = compute_statistics_of_path(paths[0], model, batch_size, 255 | dims, device, num_workers) 256 | m2, s2 = compute_statistics_of_path(paths[1], model, batch_size, 257 | dims, device, num_workers) 258 | print(m1, s1, m2, s2) 259 | fid_value = calculate_frechet_distance(m1, s1, m2, s2) 260 | 261 | return fid_value 262 | -------------------------------------------------------------------------------- /causal-conv1d/setup.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023, Tri Dao. 
2 | import sys 3 | import warnings 4 | import os 5 | import re 6 | import ast 7 | from pathlib import Path 8 | from packaging.version import parse, Version 9 | import platform 10 | 11 | from setuptools import setup, find_packages 12 | import subprocess 13 | 14 | import urllib.request 15 | import urllib.error 16 | from wheel.bdist_wheel import bdist_wheel as _bdist_wheel 17 | 18 | import torch 19 | from torch.utils.cpp_extension import ( 20 | BuildExtension, 21 | CppExtension, 22 | CUDAExtension, 23 | CUDA_HOME, 24 | ) 25 | 26 | 27 | with open("README.md", "r", encoding="utf-8") as fh: 28 | long_description = fh.read() 29 | 30 | 31 | # ninja build does not work unless include_dirs are abs path 32 | this_dir = os.path.dirname(os.path.abspath(__file__)) 33 | 34 | PACKAGE_NAME = "causal_conv1d" 35 | 36 | BASE_WHEEL_URL = "https://github.com/Dao-AILab/causal-conv1d/releases/download/{tag_name}/{wheel_name}" 37 | 38 | # FORCE_BUILD: Force a fresh build locally, instead of attempting to find prebuilt wheels 39 | # SKIP_CUDA_BUILD: Intended to allow CI to use a simple `python setup.py sdist` run to copy over raw files, without any cuda compilation 40 | FORCE_BUILD = os.getenv("CAUSAL_CONV1D_FORCE_BUILD", "FALSE") == "TRUE" 41 | SKIP_CUDA_BUILD = os.getenv("CAUSAL_CONV1D_SKIP_CUDA_BUILD", "FALSE") == "TRUE" 42 | # For CI, we want the option to build with C++11 ABI since the nvcr images use C++11 ABI 43 | FORCE_CXX11_ABI = os.getenv("CAUSAL_CONV1D_FORCE_CXX11_ABI", "FALSE") == "TRUE" 44 | 45 | 46 | def get_platform(): 47 | """ 48 | Returns the platform name as used in wheel filenames. 49 | """ 50 | if sys.platform.startswith("linux"): 51 | return "linux_x86_64" 52 | elif sys.platform == "darwin": 53 | mac_version = ".".join(platform.mac_ver()[0].split(".")[:2]) 54 | return f"macosx_{mac_version}_x86_64" 55 | elif sys.platform == "win32": 56 | return "win_amd64" 57 | else: 58 | raise ValueError("Unsupported platform: {}".format(sys.platform)) 59 | 60 | 61 | def get_cuda_bare_metal_version(cuda_dir): 62 | raw_output = subprocess.check_output( 63 | [cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True 64 | ) 65 | output = raw_output.split() 66 | release_idx = output.index("release") + 1 67 | bare_metal_version = parse(output[release_idx].split(",")[0]) 68 | 69 | return raw_output, bare_metal_version 70 | 71 | 72 | def check_if_cuda_home_none(global_option: str) -> None: 73 | if CUDA_HOME is not None: 74 | return 75 | # warn instead of error because user could be downloading prebuilt wheels, so nvcc won't be necessary 76 | # in that case. 77 | warnings.warn( 78 | f"{global_option} was requested, but nvcc was not found. Are you sure your environment has nvcc available? " 79 | "If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, " 80 | "only images whose names contain 'devel' will provide nvcc." 
81 | ) 82 | 83 | 84 | def append_nvcc_threads(nvcc_extra_args): 85 | return nvcc_extra_args + ["--threads", "4"] 86 | 87 | 88 | cmdclass = {} 89 | ext_modules = [] 90 | 91 | if not SKIP_CUDA_BUILD: 92 | print("\n\ntorch.__version__ = {}\n\n".format(torch.__version__)) 93 | TORCH_MAJOR = int(torch.__version__.split(".")[0]) 94 | TORCH_MINOR = int(torch.__version__.split(".")[1]) 95 | 96 | check_if_cuda_home_none("causal_conv1d") 97 | # Check, if CUDA11 is installed for compute capability 8.0 98 | cc_flag = [] 99 | if CUDA_HOME is not None: 100 | _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME) 101 | if bare_metal_version < Version("11.6"): 102 | raise RuntimeError( 103 | "causal_conv1d is only supported on CUDA 11.6 and above. " 104 | "Note: make sure nvcc has a supported version by running nvcc -V." 105 | ) 106 | 107 | cc_flag.append("-gencode") 108 | cc_flag.append("arch=compute_70,code=sm_70") 109 | cc_flag.append("-gencode") 110 | cc_flag.append("arch=compute_80,code=sm_80") 111 | if bare_metal_version >= Version("11.8"): 112 | cc_flag.append("-gencode") 113 | cc_flag.append("arch=compute_90,code=sm_90") 114 | 115 | # HACK: The compiler flag -D_GLIBCXX_USE_CXX11_ABI is set to be the same as 116 | # torch._C._GLIBCXX_USE_CXX11_ABI 117 | # https://github.com/pytorch/pytorch/blob/8472c24e3b5b60150096486616d98b7bea01500b/torch/utils/cpp_extension.py#L920 118 | if FORCE_CXX11_ABI: 119 | torch._C._GLIBCXX_USE_CXX11_ABI = True 120 | 121 | ext_modules.append( 122 | CUDAExtension( 123 | name="causal_conv1d_cuda", 124 | sources=[ 125 | "csrc/causal_conv1d.cpp", 126 | "csrc/causal_conv1d_fwd.cu", 127 | "csrc/causal_conv1d_bwd.cu", 128 | "csrc/causal_conv1d_update.cu", 129 | ], 130 | extra_compile_args={ 131 | "cxx": ["-O3"], 132 | "nvcc": append_nvcc_threads( 133 | [ 134 | "-O3", 135 | "-U__CUDA_NO_HALF_OPERATORS__", 136 | "-U__CUDA_NO_HALF_CONVERSIONS__", 137 | "-U__CUDA_NO_BFLOAT16_OPERATORS__", 138 | "-U__CUDA_NO_BFLOAT16_CONVERSIONS__", 139 | "-U__CUDA_NO_BFLOAT162_OPERATORS__", 140 | "-U__CUDA_NO_BFLOAT162_CONVERSIONS__", 141 | "--expt-relaxed-constexpr", 142 | "--expt-extended-lambda", 143 | "--use_fast_math", 144 | "--ptxas-options=-v", 145 | "-lineinfo", 146 | ] 147 | + cc_flag 148 | ), 149 | }, 150 | include_dirs=[this_dir], 151 | ) 152 | ) 153 | 154 | 155 | def get_package_version(): 156 | with open(Path(this_dir) / "causal_conv1d" / "__init__.py", "r") as f: 157 | version_match = re.search(r"^__version__\s*=\s*(.*)$", f.read(), re.MULTILINE) 158 | public_version = ast.literal_eval(version_match.group(1)) 159 | local_version = os.environ.get("CAUSAL_CONV1D_LOCAL_VERSION") 160 | if local_version: 161 | return f"{public_version}+{local_version}" 162 | else: 163 | return str(public_version) 164 | 165 | 166 | def get_wheel_url(): 167 | # Determine the version numbers that will be used to determine the correct wheel 168 | # We're using the CUDA version used to build torch, not the one currently installed 169 | # _, cuda_version_raw = get_cuda_bare_metal_version(CUDA_HOME) 170 | torch_cuda_version = parse(torch.version.cuda) 171 | torch_version_raw = parse(torch.__version__) 172 | # For CUDA 11, we only compile for CUDA 11.8, and for CUDA 12 we only compile for CUDA 12.2 173 | # to save CI time. Minor versions should be compatible. 
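    # For instance (hypothetical versions/platform), the resulting wheel filename would look like:
    #   causal_conv1d-1.0.0+cu118torch2.1cxx11abiFALSE-cp310-cp310-linux_x86_64.whl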
174 | torch_cuda_version = parse("11.8") if torch_cuda_version.major == 11 else parse("12.2") 175 | python_version = f"cp{sys.version_info.major}{sys.version_info.minor}" 176 | platform_name = get_platform() 177 | causal_conv1d_version = get_package_version() 178 | # cuda_version = f"{cuda_version_raw.major}{cuda_version_raw.minor}" 179 | cuda_version = f"{torch_cuda_version.major}{torch_cuda_version.minor}" 180 | torch_version = f"{torch_version_raw.major}.{torch_version_raw.minor}" 181 | cxx11_abi = str(torch._C._GLIBCXX_USE_CXX11_ABI).upper() 182 | 183 | # Determine wheel URL based on CUDA version, torch version, python version and OS 184 | wheel_filename = f"{PACKAGE_NAME}-{causal_conv1d_version}+cu{cuda_version}torch{torch_version}cxx11abi{cxx11_abi}-{python_version}-{python_version}-{platform_name}.whl" 185 | wheel_url = BASE_WHEEL_URL.format( 186 | tag_name=f"v{causal_conv1d_version}", wheel_name=wheel_filename 187 | ) 188 | return wheel_url, wheel_filename 189 | 190 | 191 | class CachedWheelsCommand(_bdist_wheel): 192 | """ 193 | The CachedWheelsCommand plugs into the default bdist wheel, which is ran by pip when it cannot 194 | find an existing wheel (which is currently the case for all installs). We use 195 | the environment parameters to detect whether there is already a pre-built version of a compatible 196 | wheel available and short-circuits the standard full build pipeline. 197 | """ 198 | 199 | def run(self): 200 | if FORCE_BUILD: 201 | return super().run() 202 | 203 | wheel_url, wheel_filename = get_wheel_url() 204 | print("Guessing wheel URL: ", wheel_url) 205 | try: 206 | urllib.request.urlretrieve(wheel_url, wheel_filename) 207 | 208 | # Make the archive 209 | # Lifted from the root wheel processing command 210 | # https://github.com/pypa/wheel/blob/cf71108ff9f6ffc36978069acb28824b44ae028e/src/wheel/bdist_wheel.py#LL381C9-L381C85 211 | if not os.path.exists(self.dist_dir): 212 | os.makedirs(self.dist_dir) 213 | 214 | impl_tag, abi_tag, plat_tag = self.get_tag() 215 | archive_basename = f"{self.wheel_dist_name}-{impl_tag}-{abi_tag}-{plat_tag}" 216 | 217 | wheel_path = os.path.join(self.dist_dir, archive_basename + ".whl") 218 | print("Raw wheel path", wheel_path) 219 | os.rename(wheel_filename, wheel_path) 220 | except urllib.error.HTTPError: 221 | print("Precompiled wheel not found. 
Building from source...") 222 | # If the wheel could not be downloaded, build from source 223 | super().run() 224 | 225 | 226 | setup( 227 | name=PACKAGE_NAME, 228 | version=get_package_version(), 229 | packages=find_packages( 230 | exclude=( 231 | "build", 232 | "csrc", 233 | "include", 234 | "tests", 235 | "dist", 236 | "docs", 237 | "benchmarks", 238 | "causal_conv1d.egg-info", 239 | ) 240 | ), 241 | author="Tri Dao", 242 | author_email="tri@tridao.me", 243 | description="Causal depthwise conv1d in CUDA, with a PyTorch interface", 244 | long_description=long_description, 245 | long_description_content_type="text/markdown", 246 | url="https://github.com/Dao-AILab/causal-conv1d", 247 | classifiers=[ 248 | "Programming Language :: Python :: 3", 249 | "License :: OSI Approved :: BSD License", 250 | "Operating System :: Unix", 251 | ], 252 | ext_modules=ext_modules, 253 | cmdclass={"bdist_wheel": CachedWheelsCommand, "build_ext": BuildExtension} 254 | if ext_modules 255 | else { 256 | "bdist_wheel": CachedWheelsCommand, 257 | }, 258 | python_requires=">=3.7", 259 | install_requires=[ 260 | "torch", 261 | "packaging", 262 | "ninja", 263 | ], 264 | ) 265 | -------------------------------------------------------------------------------- /t5.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import os 3 | import re 4 | import html 5 | import urllib.parse as ul 6 | 7 | import ftfy 8 | import torch 9 | from bs4 import BeautifulSoup 10 | from transformers import T5EncoderModel, AutoTokenizer 11 | from huggingface_hub import hf_hub_download 12 | 13 | class T5Embedder: 14 | 15 | available_models = ['t5-v1_1-xxl'] 16 | bad_punct_regex = re.compile(r'['+'#®•©™&@·º½¾¿¡§~'+'\)'+'\('+'\]'+'\['+'\}'+'\{'+'\|'+'\\'+'\/'+'\*' + r']{1,}') # noqa 17 | 18 | def __init__(self, device, dir_or_name='t5-v1_1-xxl', *, local_cache=False, cache_dir=None, hf_token=None, use_text_preprocessing=True, 19 | t5_model_kwargs=None, torch_dtype=None, use_offload_folder=None, model_max_length=120): 20 | self.device = torch.device(device) 21 | self.torch_dtype = torch_dtype or torch.bfloat16 22 | if t5_model_kwargs is None: 23 | t5_model_kwargs = {'low_cpu_mem_usage': True, 'torch_dtype': self.torch_dtype} 24 | if use_offload_folder is not None: 25 | t5_model_kwargs['offload_folder'] = use_offload_folder 26 | t5_model_kwargs['device_map'] = { 27 | 'shared': self.device, 28 | 'encoder.embed_tokens': self.device, 29 | 'encoder.block.0': self.device, 30 | 'encoder.block.1': self.device, 31 | 'encoder.block.2': self.device, 32 | 'encoder.block.3': self.device, 33 | 'encoder.block.4': self.device, 34 | 'encoder.block.5': self.device, 35 | 'encoder.block.6': self.device, 36 | 'encoder.block.7': self.device, 37 | 'encoder.block.8': self.device, 38 | 'encoder.block.9': self.device, 39 | 'encoder.block.10': self.device, 40 | 'encoder.block.11': self.device, 41 | 'encoder.block.12': 'disk', 42 | 'encoder.block.13': 'disk', 43 | 'encoder.block.14': 'disk', 44 | 'encoder.block.15': 'disk', 45 | 'encoder.block.16': 'disk', 46 | 'encoder.block.17': 'disk', 47 | 'encoder.block.18': 'disk', 48 | 'encoder.block.19': 'disk', 49 | 'encoder.block.20': 'disk', 50 | 'encoder.block.21': 'disk', 51 | 'encoder.block.22': 'disk', 52 | 'encoder.block.23': 'disk', 53 | 'encoder.final_layer_norm': 'disk', 54 | 'encoder.dropout': 'disk', 55 | } 56 | else: 57 | t5_model_kwargs['device_map'] = {'shared': self.device, 'encoder': self.device} 58 | 59 | self.use_text_preprocessing = 
use_text_preprocessing
60 |         self.hf_token = hf_token
61 |         self.cache_dir = cache_dir or os.path.expanduser('~/.cache/IF_')
62 |         self.dir_or_name = dir_or_name
63 |         tokenizer_path, path = dir_or_name, dir_or_name
64 | 
65 |         """
66 |         if local_cache:
67 |             cache_dir = os.path.join(self.cache_dir, dir_or_name)
68 |             tokenizer_path, path = cache_dir, cache_dir
69 |         elif dir_or_name in self.available_models:
70 |             cache_dir = os.path.join(self.cache_dir, dir_or_name)
71 |             for filename in [
72 |                 'config.json', 'special_tokens_map.json', 'spiece.model', 'tokenizer_config.json',
73 |                 'pytorch_model.bin.index.json', 'pytorch_model-00001-of-00002.bin', 'pytorch_model-00002-of-00002.bin'
74 |             ]:
75 |                 hf_hub_download(repo_id=f'DeepFloyd/{dir_or_name}', filename=filename, cache_dir=cache_dir,
76 |                                 force_filename=filename, token=self.hf_token)
77 |             tokenizer_path, path = cache_dir, cache_dir
78 |         else:
79 |             cache_dir = os.path.join(self.cache_dir, 't5-v1_1-xxl')
80 |             for filename in [
81 |                 'config.json', 'special_tokens_map.json', 'spiece.model', 'tokenizer_config.json',
82 |             ]:
83 |                 hf_hub_download(repo_id='DeepFloyd/t5-v1_1-xxl', filename=filename, cache_dir=cache_dir,
84 |                                 force_filename=filename, token=self.hf_token)
85 |             tokenizer_path = cache_dir
86 | 
87 |         print(tokenizer_path)
88 |         """
89 | 
90 |         self.tokenizer = AutoTokenizer.from_pretrained(cache_dir)
91 |         self.model = T5EncoderModel.from_pretrained(cache_dir, **t5_model_kwargs).eval()
92 |         self.model_max_length = model_max_length
93 | 
94 |     def get_text_embeddings(self, texts):
95 |         texts = [self.text_preprocessing(text) for text in texts]
96 | 
97 |         text_tokens_and_mask = self.tokenizer(
98 |             texts,
99 |             max_length=self.model_max_length,
100 |             padding='max_length',
101 |             truncation=True,
102 |             return_attention_mask=True,
103 |             add_special_tokens=True,
104 |             return_tensors='pt'
105 |         )
106 | 
107 |         text_tokens_and_mask['input_ids'] = text_tokens_and_mask['input_ids']
108 |         text_tokens_and_mask['attention_mask'] = text_tokens_and_mask['attention_mask']
109 | 
110 |         with torch.no_grad():
111 |             text_encoder_embs = self.model(
112 |                 input_ids=text_tokens_and_mask['input_ids'].to(self.device),
113 |                 attention_mask=text_tokens_and_mask['attention_mask'].to(self.device),
114 |             )['last_hidden_state'].detach()
115 |         return text_encoder_embs, text_tokens_and_mask['attention_mask'].to(self.device)
116 | 
117 |     def text_preprocessing(self, text):
118 |         if self.use_text_preprocessing:
119 |             # The exact text cleaning as was in the training stage:
120 |             text = self.clean_caption(text)
121 |             text = self.clean_caption(text)
122 |             return text
123 |         else:
124 |             return text.lower().strip()
125 | 
126 |     @staticmethod
127 |     def basic_clean(text):
128 |         text = ftfy.fix_text(text)
129 |         text = html.unescape(html.unescape(text))
130 |         return text.strip()
131 | 
132 |     def clean_caption(self, caption):
133 |         caption = str(caption)
134 |         caption = ul.unquote_plus(caption)
135 |         caption = caption.strip().lower()
136 |         caption = re.sub('<person>', 'person', caption)
137 |         # urls:
138 |         caption = re.sub(
139 |             r'\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))',  # noqa
140 |             '', caption)  # regex for urls
141 |         caption = re.sub(
142 |             r'\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))',  # noqa
143 |             '', caption)  # regex for urls
144 |         # html:
145 |         caption = BeautifulSoup(caption, features='html.parser').text
146 | 
147 |         # @
148 |         caption = re.sub(r'@[\w\d]+\b', '', caption)
149 | 
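        # Illustrative effect of the Unicode-range filters below: a caption such as
        # 'a photo of a 猫' would come out as 'a photo of a ' (U+732B falls in the
        # CJK Unified Ideographs range U+4E00–U+9FFF removed here).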
150 | # 31C0—31EF CJK Strokes 151 | # 31F0—31FF Katakana Phonetic Extensions 152 | # 3200—32FF Enclosed CJK Letters and Months 153 | # 3300—33FF CJK Compatibility 154 | # 3400—4DBF CJK Unified Ideographs Extension A 155 | # 4DC0—4DFF Yijing Hexagram Symbols 156 | # 4E00—9FFF CJK Unified Ideographs 157 | caption = re.sub(r'[\u31c0-\u31ef]+', '', caption) 158 | caption = re.sub(r'[\u31f0-\u31ff]+', '', caption) 159 | caption = re.sub(r'[\u3200-\u32ff]+', '', caption) 160 | caption = re.sub(r'[\u3300-\u33ff]+', '', caption) 161 | caption = re.sub(r'[\u3400-\u4dbf]+', '', caption) 162 | caption = re.sub(r'[\u4dc0-\u4dff]+', '', caption) 163 | caption = re.sub(r'[\u4e00-\u9fff]+', '', caption) 164 | ####################################################### 165 | 166 | # все виды тире / all types of dash --> "-" 167 | caption = re.sub( 168 | r'[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+', # noqa 169 | '-', caption) 170 | 171 | # кавычки к одному стандарту 172 | caption = re.sub(r'[`´«»“”¨]', '"', caption) 173 | caption = re.sub(r'[‘’]', "'", caption) 174 | 175 | # " 176 | caption = re.sub(r'"?', '', caption) 177 | # & 178 | caption = re.sub(r'&', '', caption) 179 | 180 | # ip adresses: 181 | caption = re.sub(r'\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}', ' ', caption) 182 | 183 | # article ids: 184 | caption = re.sub(r'\d:\d\d\s+$', '', caption) 185 | 186 | # \n 187 | caption = re.sub(r'\\n', ' ', caption) 188 | 189 | # "#123" 190 | caption = re.sub(r'#\d{1,3}\b', '', caption) 191 | # "#12345.." 192 | caption = re.sub(r'#\d{5,}\b', '', caption) 193 | # "123456.." 194 | caption = re.sub(r'\b\d{6,}\b', '', caption) 195 | # filenames: 196 | caption = re.sub(r'[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)', '', caption) 197 | 198 | # 199 | caption = re.sub(r'[\"\']{2,}', r'"', caption) # """AUSVERKAUFT""" 200 | caption = re.sub(r'[\.]{2,}', r' ', caption) # """AUSVERKAUFT""" 201 | 202 | caption = re.sub(self.bad_punct_regex, r' ', caption) # ***AUSVERKAUFT***, #AUSVERKAUFT 203 | caption = re.sub(r'\s+\.\s+', r' ', caption) # " . " 204 | 205 | # this-is-my-cute-cat / this_is_my_cute_cat 206 | regex2 = re.compile(r'(?:\-|\_)') 207 | if len(re.findall(regex2, caption)) > 3: 208 | caption = re.sub(regex2, ' ', caption) 209 | 210 | caption = self.basic_clean(caption) 211 | 212 | caption = re.sub(r'\b[a-zA-Z]{1,3}\d{3,15}\b', '', caption) # jc6640 213 | caption = re.sub(r'\b[a-zA-Z]+\d+[a-zA-Z]+\b', '', caption) # jc6640vc 214 | caption = re.sub(r'\b\d+[a-zA-Z]+\d+\b', '', caption) # 6640vc231 215 | 216 | caption = re.sub(r'(worldwide\s+)?(free\s+)?shipping', '', caption) 217 | caption = re.sub(r'(free\s)?download(\sfree)?', '', caption) 218 | caption = re.sub(r'\bclick\b\s(?:for|on)\s\w+', '', caption) 219 | caption = re.sub(r'\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?', '', caption) 220 | caption = re.sub(r'\bpage\s+\d+\b', '', caption) 221 | 222 | caption = re.sub(r'\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b', r' ', caption) # j2d1a2a... 
223 | 224 | caption = re.sub(r'\b\d+\.?\d*[xх×]\d+\.?\d*\b', '', caption) 225 | 226 | caption = re.sub(r'\b\s+\:\s+', r': ', caption) 227 | caption = re.sub(r'(\D[,\./])\b', r'\1 ', caption) 228 | caption = re.sub(r'\s+', ' ', caption) 229 | 230 | caption.strip() 231 | 232 | caption = re.sub(r'^[\"\']([\w\W]+)[\"\']$', r'\1', caption) 233 | caption = re.sub(r'^[\'\_,\-\:;]', r'', caption) 234 | caption = re.sub(r'[\'\_,\-\:\-\+]$', r'', caption) 235 | caption = re.sub(r'^\.\S+$', '', caption) 236 | 237 | return caption.strip() -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torchvision 4 | 5 | 6 | 7 | def test_timeembedding(): 8 | from models_dis import timestep_embedding 9 | times_steps = torch.randint(1, 100, (1,)) 10 | print(timestep_embedding(times_steps, 1000)) 11 | 12 | 13 | 14 | def test_cifar10(): 15 | data_path = "/TrainData/Multimodal/zhengcong.fei/dis/data" 16 | cifar10 = torchvision.datasets.CIFAR10( 17 | root=data_path, 18 | train=True, 19 | download=False 20 | ) 21 | cifar10_test = torchvision.datasets.CIFAR10( 22 | root=data_path, 23 | train=False, 24 | download=False 25 | ) 26 | print(cifar10) 27 | print(cifar10_test[0]) 28 | 29 | 30 | 31 | def test_imagenet1k(): 32 | data_path = '/TrainData/Multimodal/public/datasets/ImageNet/train' 33 | import torchvision.datasets as datasets 34 | dataset_train = datasets.ImageFolder(data_path) 35 | print(dataset_train[0]) 36 | 37 | 38 | 39 | def test_celeba(): 40 | from datasets import load_dataset 41 | data_path = "/TrainData/Multimodal/zhengcong.fei/dis/data/CelebA" 42 | dataset = load_dataset(data_path) 43 | # dataset = dataset['train'] 44 | # dataset = dataset.map(lambda e: e['image'].convert('RGB'), batched=True) 45 | #print(dataset[0]) 46 | print(dataset['train'][0].keys()) 47 | #print(dataset['train'][0]['image'].convert("RGB")) 48 | # print(len(dataset['train'])) 49 | 50 | 51 | def test_fid_score(): 52 | from tools.fid_score import calculate_fid_given_paths 53 | path1 = '/TrainData/Multimodal/zhengcong.fei/dis/results/cond_cifar10_small/his' 54 | path2 = '/TrainData/Multimodal/zhengcong.fei/dis/results/uncond_cifar10_small/his' 55 | fid = calculate_fid_given_paths((path1, path2)) 56 | 57 | 58 | 59 | def test_vae(): 60 | from diffusers.models import AutoencoderKL 61 | vae_path = '/TrainData/Multimodal/zhengcong.fei/dis/vae' 62 | vae = AutoencoderKL.from_pretrained(vae_path) 63 | 64 | 65 | def test_clip(): 66 | from transformers import CLIPTokenizer, CLIPTextModel 67 | clip_path = '/TrainData/Multimodal/michael.fan/ckpts/sdxl-turbo/text_encoder' 68 | tokenizer_path = '/TrainData/Multimodal/michael.fan/ckpts/sdxl-turbo/tokenizer' 69 | tokenizer = CLIPTokenizer.from_pretrained(tokenizer_path) 70 | transformer = CLIPTextModel.from_pretrained(clip_path) 71 | 72 | text = ['HighJump'] 73 | batch_encoding = tokenizer(text, truncation=True, max_length=77, return_length=True, 74 | return_overflowing_tokens=False, padding="max_length", return_tensors="pt") 75 | tokens = batch_encoding["input_ids"] 76 | print(tokens.size()) 77 | tokens = tokenizer.convert_ids_to_tokens(tokens.tolist()[0]) 78 | print(tokens) 79 | 80 | 81 | def test_t5(): 82 | from t5 import T5Embedder 83 | t5_path = '/TrainData/Multimodal/michael.fan/ckpts/DeepFloyd/t5-v1_1-xxl' 84 | 85 | t5 = T5Embedder(device='cuda', local_cache=True, cache_dir=t5_path) 86 | 87 | prompts = ['state space models for video 
generation', 'mamba model is good'] 88 | with torch.no_grad(): 89 | caption_embs, emb_masks = t5.get_text_embeddings(prompts) 90 | caption_embs = caption_embs.float() # [:, None] 91 | print(caption_embs.size(), emb_masks.size()) 92 | 93 | 94 | def test_vespa(): 95 | from models_vespa import timestep_embedding, VeSpa_models 96 | from thop import profile 97 | 98 | for k, v in VeSpa_models.items(): 99 | print(k) 100 | model = v(img_size=32).cuda() 101 | input_image = torch.randn(1, 3, 32, 32).cuda() 102 | times_steps = torch.randint(1, 100, (1,)).cuda() 103 | context = torch.randn(1, 77, 768).cuda() 104 | flops, _ = profile(model, inputs=(input_image, times_steps, context)) 105 | # out = model(x=input_image, timesteps=times_steps) 106 | #print(out.size()) 107 | print('FLOPs = ' + str(flops * 2/1000**3) + 'G') 108 | 109 | parameters_sum = sum(x.numel() for x in model.parameters()) 110 | print(parameters_sum / 1000000.0, "M") 111 | 112 | 113 | 114 | def test_coco(): 115 | from tools.dataset import MSCOCODataset 116 | annafile = '/TrainData/Multimodal/will.zhang/data/coco2014/annotations/captions_train2014.json' 117 | root = '/TrainData/Multimodal/will.zhang/data/coco2014/train2014/train2014' 118 | dataset = MSCOCODataset(root=root, annFile=annafile,) 119 | print(dataset[0]) 120 | 121 | 122 | 123 | def test_mjdataset(): 124 | from tools.dataset import MJDataset 125 | data_path = '/TrainData/Multimodal/public/datasets_gen/mj580w/cleaned_mj_580w.json' 126 | dataset = MJDataset(path=data_path) 127 | print(dataset[0]) 128 | 129 | 130 | def test_video(): 131 | from einops import rearrange 132 | f = 6 133 | frames = torch.randn(4 * f, 64*64, 768) 134 | print(frames.size()) 135 | frames = rearrange(frames, "(b f) n d -> (b n) f d", f=f) 136 | print(frames.size()) 137 | frames = rearrange(frames, "(b n) f d -> (b f) n d", b=4) 138 | print(frames.size()) 139 | 140 | 141 | def ucf_dataset_create(): 142 | data_path = '/TrainData/Multimodal/public/datasets_gen/video_dataset/UCF-101' 143 | import json 144 | file_list_path = os.listdir(data_path) 145 | print(file_list_path) 146 | 147 | video_test_list = [] 148 | for file in file_list_path: 149 | avi_path_list = os.listdir(os.path.join(data_path, file)) 150 | for avi_path in avi_path_list: 151 | video_test_list.append( 152 | { 153 | "video": os.path.join(data_path, file, avi_path), 154 | "text": file, 155 | } 156 | ) 157 | print(len(video_test_list)) 158 | target_path = '/TrainData/Multimodal/zhengcong.fei/vespa/data/ucf.json' 159 | with open(target_path, 'w') as f: 160 | json.dump(video_test_list, f, indent=4) 161 | 162 | 163 | def test_ucf_dataset(): 164 | from tools.dataset import UCFDataset 165 | data_path = '/TrainData/Multimodal/zhengcong.fei/vespa/data/ucf.json' 166 | dataset = UCFDataset(data_path, is_image=False) 167 | print(dataset[1][0].size()) 168 | # ([8, 3, 64, 64]) 169 | 170 | 171 | 172 | def test_video_vespa(): 173 | from models_vespa import VeSpa_video_models 174 | model = VeSpa_video_models['VeSpa-M/2']( 175 | img_size=64, 176 | channels=32, 177 | enable_temporal_layers=True, 178 | ) 179 | print(model) 180 | parameters_sum = sum(x.numel() for x in model.parameters()) 181 | print(parameters_sum / 1000000.0, "M") 182 | 183 | 184 | 185 | def face_create(): 186 | import json 187 | data_path = '/TrainData/Multimodal/public/datasets_gen/video_dataset/face/training_AV/RAVDESS/train' 188 | file_list_path = os.listdir(data_path) 189 | print(file_list_path) 190 | video_test_list = [] 191 | 192 | for file in file_list_path: 193 | video_test_list.append( 
194 | { 195 | "video": os.path.join(data_path, file), 196 | "text": file[:-4].split('_')[2].lower(), 197 | } 198 | ) 199 | # print(video_test_list) 200 | 201 | print(len(video_test_list)) 202 | target_path = '/TrainData/Multimodal/zhengcong.fei/vespa/data/face.json' 203 | with open(target_path, 'w') as f: 204 | json.dump(video_test_list, f, indent=4) 205 | from tools.dataset import FaceDataset 206 | dataset = FaceDataset(target_path, is_image=False) 207 | print(dataset[1][0].size()) 208 | 209 | 210 | import torchvision.transforms as transforms 211 | from PIL import Image 212 | import io 213 | import json 214 | 215 | 216 | 217 | def wds_dataset(): 218 | print('test wds dataset') 219 | 220 | train_shards_path_or_url = ['/maindata/data/shared/multimodal/public/dataset_gen/mj580w_wds2/00074.tar', 221 | '/maindata/data/shared/multimodal/public/dataset_gen/mj580w_wds2/00782.tar',] 222 | #train_shards_path_or_url = "/maindata/data/shared/multimodal/public/dataset_gen/mj580w_wds2/0{0001..0003}.tar" 223 | import webdataset as wds 224 | import io, base64 225 | from PIL import Image 226 | import json 227 | """ 228 | from tools.dataset import WdsImageTextDataset 229 | dataset = WdsImageTextDataset(urls=train_shards_path_or_url, rank=0, world_size=8) 230 | dataloader = torch.utils.data.DataLoader(dataset, batch_size=2, num_workers=4) 231 | for batch in dataloader: 232 | print(batch[0].size()) 233 | break 234 | """ 235 | 236 | from tools.dataset import wds_process 237 | process = wds_process() 238 | 239 | # dataset = wds.WebDataset(train_shards_path_or_url).map(process).batched(batch_size, partial=False) 240 | 241 | dataset = wds.DataPipeline( 242 | wds.SimpleShardList(train_shards_path_or_url), 243 | # at this point we have an iterator over all the shards 244 | wds.shuffle(100), 245 | # add wds.split_by_node here if you are using multiple nodes 246 | wds.split_by_worker, 247 | wds.split_by_node, 248 | # at this point, we have an iterator over the shards assigned to each worker 249 | wds.tarfile_to_samples(), 250 | # this shuffles the samples in memory 251 | wds.shuffle(1000), 252 | # this decodes the images and json 253 | wds.map(process), 254 | wds.shuffle(1000), 255 | wds.batched(2) 256 | ) 257 | #loader = wds.WebLoader(dataset, num_workers=4) 258 | #loader = loader.ddp_equalize(dataset_size // batch_size) 259 | for data in dataset: 260 | print(data[0].size()) 261 | break 262 | 263 | 264 | def test_tag_imagenet(): 265 | from tools.constants import IMAGENET2012_CLASSES 266 | # print(IMAGENET2012_CLASSES) 267 | data_path = '/maindata/data/shared/multimodal/public/dataset_img_only/imagenet/data/train' 268 | from tqdm import tqdm 269 | import os 270 | import json 271 | 272 | tgt_path = '/maindata/data/shared/multimodal/zhengcong.fei/code/vespa/data/imagenet_tag.json' 273 | file_list = [] 274 | i = 0 275 | for file_path, _, file_names in os.walk(data_path): 276 | # print(file_names) 277 | print(i) 278 | i += 1 279 | for file_name in file_names: 280 | tag_id = file_name.split('_')[0] 281 | tag = IMAGENET2012_CLASSES[tag_id] 282 | #print(tag) 283 | image_path = os.path.join(file_path, file_name) 284 | file_list.append( 285 | { 286 | "text": tag, 287 | "image": image_path, 288 | } 289 | ) 290 | # break 291 | # break 292 | #if i > 10: break 293 | 294 | with open(tgt_path, 'w') as f: 295 | json.dump(file_list, f, indent=4) 296 | 297 | test_tag_imagenet() 298 | 299 | # wds_dataset() 300 | # face_create() 301 | # test_video_vespa() 302 | # test_ucf_dataset() 303 | # test_video() 304 | # test_mjdataset() 305 | # 
test_clip() 306 | # test_vespa() 307 | # test_coco() 308 | # test_cifar10() 309 | # test_imagenet1k() 310 | # test_celeba() 311 | # test_fid_score() 312 | # test_vae() 313 | # test_t5() -------------------------------------------------------------------------------- /mamba/setup.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023, Albert Gu, Tri Dao. 2 | import sys 3 | import warnings 4 | import os 5 | import re 6 | import ast 7 | from pathlib import Path 8 | from packaging.version import parse, Version 9 | import platform 10 | import shutil 11 | 12 | from setuptools import setup, find_packages 13 | import subprocess 14 | 15 | import urllib.request 16 | import urllib.error 17 | from wheel.bdist_wheel import bdist_wheel as _bdist_wheel 18 | 19 | import torch 20 | from torch.utils.cpp_extension import ( 21 | BuildExtension, 22 | CppExtension, 23 | CUDAExtension, 24 | CUDA_HOME, 25 | ) 26 | 27 | 28 | with open("README.md", "r", encoding="utf-8") as fh: 29 | long_description = fh.read() 30 | 31 | 32 | # ninja build does not work unless include_dirs are abs path 33 | this_dir = os.path.dirname(os.path.abspath(__file__)) 34 | 35 | PACKAGE_NAME = "mamba_ssm" 36 | 37 | BASE_WHEEL_URL = "https://github.com/state-spaces/mamba/releases/download/{tag_name}/{wheel_name}" 38 | 39 | # FORCE_BUILD: Force a fresh build locally, instead of attempting to find prebuilt wheels 40 | # SKIP_CUDA_BUILD: Intended to allow CI to use a simple `python setup.py sdist` run to copy over raw files, without any cuda compilation 41 | FORCE_BUILD = os.getenv("MAMBA_FORCE_BUILD", "FALSE") == "TRUE" 42 | SKIP_CUDA_BUILD = os.getenv("MAMBA_SKIP_CUDA_BUILD", "FALSE") == "TRUE" 43 | # For CI, we want the option to build with C++11 ABI since the nvcr images use C++11 ABI 44 | FORCE_CXX11_ABI = os.getenv("MAMBA_FORCE_CXX11_ABI", "FALSE") == "TRUE" 45 | 46 | 47 | def get_platform(): 48 | """ 49 | Returns the platform name as used in wheel filenames. 50 | """ 51 | if sys.platform.startswith("linux"): 52 | return "linux_x86_64" 53 | elif sys.platform == "darwin": 54 | mac_version = ".".join(platform.mac_ver()[0].split(".")[:2]) 55 | return f"macosx_{mac_version}_x86_64" 56 | elif sys.platform == "win32": 57 | return "win_amd64" 58 | else: 59 | raise ValueError("Unsupported platform: {}".format(sys.platform)) 60 | 61 | 62 | def get_cuda_bare_metal_version(cuda_dir): 63 | raw_output = subprocess.check_output( 64 | [cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True 65 | ) 66 | output = raw_output.split() 67 | release_idx = output.index("release") + 1 68 | bare_metal_version = parse(output[release_idx].split(",")[0]) 69 | 70 | return raw_output, bare_metal_version 71 | 72 | 73 | def check_if_cuda_home_none(global_option: str) -> None: 74 | if CUDA_HOME is not None: 75 | return 76 | # warn instead of error because user could be downloading prebuilt wheels, so nvcc won't be necessary 77 | # in that case. 78 | warnings.warn( 79 | f"{global_option} was requested, but nvcc was not found. Are you sure your environment has nvcc available? " 80 | "If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, " 81 | "only images whose names contain 'devel' will provide nvcc." 
82 | ) 83 | 84 | 85 | def append_nvcc_threads(nvcc_extra_args): 86 | return nvcc_extra_args + ["--threads", "4"] 87 | 88 | 89 | cmdclass = {} 90 | ext_modules = [] 91 | 92 | if not SKIP_CUDA_BUILD: 93 | print("\n\ntorch.__version__ = {}\n\n".format(torch.__version__)) 94 | TORCH_MAJOR = int(torch.__version__.split(".")[0]) 95 | TORCH_MINOR = int(torch.__version__.split(".")[1]) 96 | 97 | check_if_cuda_home_none(PACKAGE_NAME) 98 | # Check, if CUDA11 is installed for compute capability 8.0 99 | cc_flag = [] 100 | if CUDA_HOME is not None: 101 | _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME) 102 | if bare_metal_version < Version("11.6"): 103 | raise RuntimeError( 104 | f"{PACKAGE_NAME} is only supported on CUDA 11.6 and above. " 105 | "Note: make sure nvcc has a supported version by running nvcc -V." 106 | ) 107 | 108 | cc_flag.append("-gencode") 109 | cc_flag.append("arch=compute_70,code=sm_70") 110 | cc_flag.append("-gencode") 111 | cc_flag.append("arch=compute_80,code=sm_80") 112 | if bare_metal_version >= Version("11.8"): 113 | cc_flag.append("-gencode") 114 | cc_flag.append("arch=compute_90,code=sm_90") 115 | 116 | # HACK: The compiler flag -D_GLIBCXX_USE_CXX11_ABI is set to be the same as 117 | # torch._C._GLIBCXX_USE_CXX11_ABI 118 | # https://github.com/pytorch/pytorch/blob/8472c24e3b5b60150096486616d98b7bea01500b/torch/utils/cpp_extension.py#L920 119 | if FORCE_CXX11_ABI: 120 | torch._C._GLIBCXX_USE_CXX11_ABI = True 121 | 122 | ext_modules.append( 123 | CUDAExtension( 124 | name="selective_scan_cuda", 125 | sources=[ 126 | "csrc/selective_scan/selective_scan.cpp", 127 | "csrc/selective_scan/selective_scan_fwd_fp32.cu", 128 | "csrc/selective_scan/selective_scan_fwd_fp16.cu", 129 | "csrc/selective_scan/selective_scan_fwd_bf16.cu", 130 | "csrc/selective_scan/selective_scan_bwd_fp32_real.cu", 131 | "csrc/selective_scan/selective_scan_bwd_fp32_complex.cu", 132 | "csrc/selective_scan/selective_scan_bwd_fp16_real.cu", 133 | "csrc/selective_scan/selective_scan_bwd_fp16_complex.cu", 134 | "csrc/selective_scan/selective_scan_bwd_bf16_real.cu", 135 | "csrc/selective_scan/selective_scan_bwd_bf16_complex.cu", 136 | ], 137 | extra_compile_args={ 138 | "cxx": ["-O3", "-std=c++17"], 139 | "nvcc": append_nvcc_threads( 140 | [ 141 | "-O3", 142 | "-std=c++17", 143 | "-U__CUDA_NO_HALF_OPERATORS__", 144 | "-U__CUDA_NO_HALF_CONVERSIONS__", 145 | "-U__CUDA_NO_BFLOAT16_OPERATORS__", 146 | "-U__CUDA_NO_BFLOAT16_CONVERSIONS__", 147 | "-U__CUDA_NO_BFLOAT162_OPERATORS__", 148 | "-U__CUDA_NO_BFLOAT162_CONVERSIONS__", 149 | "--expt-relaxed-constexpr", 150 | "--expt-extended-lambda", 151 | "--use_fast_math", 152 | "--ptxas-options=-v", 153 | "-lineinfo", 154 | ] 155 | + cc_flag 156 | ), 157 | }, 158 | include_dirs=[Path(this_dir) / "csrc" / "selective_scan"], 159 | ) 160 | ) 161 | 162 | 163 | def get_package_version(): 164 | with open(Path(this_dir) / PACKAGE_NAME / "__init__.py", "r") as f: 165 | version_match = re.search(r"^__version__\s*=\s*(.*)$", f.read(), re.MULTILINE) 166 | public_version = ast.literal_eval(version_match.group(1)) 167 | local_version = os.environ.get("MAMBA_LOCAL_VERSION") 168 | if local_version: 169 | return f"{public_version}+{local_version}" 170 | else: 171 | return str(public_version) 172 | 173 | 174 | def get_wheel_url(): 175 | # Determine the version numbers that will be used to determine the correct wheel 176 | # We're using the CUDA version used to build torch, not the one currently installed 177 | # _, cuda_version_raw = get_cuda_bare_metal_version(CUDA_HOME) 178 | 
torch_cuda_version = parse(torch.version.cuda) 179 | torch_version_raw = parse(torch.__version__) 180 | # For CUDA 11, we only compile for CUDA 11.8, and for CUDA 12 we only compile for CUDA 12.2 181 | # to save CI time. Minor versions should be compatible. 182 | torch_cuda_version = parse("11.8") if torch_cuda_version.major == 11 else parse("12.2") 183 | python_version = f"cp{sys.version_info.major}{sys.version_info.minor}" 184 | platform_name = get_platform() 185 | mamba_ssm_version = get_package_version() 186 | # cuda_version = f"{cuda_version_raw.major}{cuda_version_raw.minor}" 187 | cuda_version = f"{torch_cuda_version.major}{torch_cuda_version.minor}" 188 | torch_version = f"{torch_version_raw.major}.{torch_version_raw.minor}" 189 | cxx11_abi = str(torch._C._GLIBCXX_USE_CXX11_ABI).upper() 190 | 191 | # Determine wheel URL based on CUDA version, torch version, python version and OS 192 | wheel_filename = f"{PACKAGE_NAME}-{mamba_ssm_version}+cu{cuda_version}torch{torch_version}cxx11abi{cxx11_abi}-{python_version}-{python_version}-{platform_name}.whl" 193 | wheel_url = BASE_WHEEL_URL.format( 194 | tag_name=f"v{mamba_ssm_version}", wheel_name=wheel_filename 195 | ) 196 | return wheel_url, wheel_filename 197 | 198 | 199 | class CachedWheelsCommand(_bdist_wheel): 200 | """ 201 | The CachedWheelsCommand plugs into the default bdist wheel, which is ran by pip when it cannot 202 | find an existing wheel (which is currently the case for all installs). We use 203 | the environment parameters to detect whether there is already a pre-built version of a compatible 204 | wheel available and short-circuits the standard full build pipeline. 205 | """ 206 | 207 | def run(self): 208 | if FORCE_BUILD: 209 | return super().run() 210 | 211 | wheel_url, wheel_filename = get_wheel_url() 212 | print("Guessing wheel URL: ", wheel_url) 213 | try: 214 | urllib.request.urlretrieve(wheel_url, wheel_filename) 215 | 216 | # Make the archive 217 | # Lifted from the root wheel processing command 218 | # https://github.com/pypa/wheel/blob/cf71108ff9f6ffc36978069acb28824b44ae028e/src/wheel/bdist_wheel.py#LL381C9-L381C85 219 | if not os.path.exists(self.dist_dir): 220 | os.makedirs(self.dist_dir) 221 | 222 | impl_tag, abi_tag, plat_tag = self.get_tag() 223 | archive_basename = f"{self.wheel_dist_name}-{impl_tag}-{abi_tag}-{plat_tag}" 224 | 225 | wheel_path = os.path.join(self.dist_dir, archive_basename + ".whl") 226 | print("Raw wheel path", wheel_path) 227 | shutil.move(wheel_filename, wheel_path) 228 | except urllib.error.HTTPError: 229 | print("Precompiled wheel not found. 
Building from source...") 230 | # If the wheel could not be downloaded, build from source 231 | super().run() 232 | 233 | 234 | setup( 235 | name=PACKAGE_NAME, 236 | version=get_package_version(), 237 | packages=find_packages( 238 | exclude=( 239 | "build", 240 | "csrc", 241 | "include", 242 | "tests", 243 | "dist", 244 | "docs", 245 | "benchmarks", 246 | "mamba_ssm.egg-info", 247 | ) 248 | ), 249 | author="Tri Dao, Albert Gu", 250 | author_email="tri@tridao.me, agu@cs.cmu.edu", 251 | description="Mamba state-space model", 252 | long_description=long_description, 253 | long_description_content_type="text/markdown", 254 | url="https://github.com/state-spaces/mamba", 255 | classifiers=[ 256 | "Programming Language :: Python :: 3", 257 | "License :: OSI Approved :: BSD License", 258 | "Operating System :: Unix", 259 | ], 260 | ext_modules=ext_modules, 261 | cmdclass={"bdist_wheel": CachedWheelsCommand, "build_ext": BuildExtension} 262 | if ext_modules 263 | else { 264 | "bdist_wheel": CachedWheelsCommand, 265 | }, 266 | python_requires=">=3.7", 267 | install_requires=[ 268 | "torch", 269 | "packaging", 270 | "ninja", 271 | "einops", 272 | "triton", 273 | "transformers", 274 | "causal_conv1d", 275 | ], 276 | ) 277 | -------------------------------------------------------------------------------- /mamba/LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. 
For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright 2023 Tri Dao, Albert Gu 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /tools/inception.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torchvision 5 | 6 | try: 7 | from torchvision.models.utils import load_state_dict_from_url 8 | except ImportError: 9 | from torch.utils.model_zoo import load_url as load_state_dict_from_url 10 | 11 | # Inception weights ported to Pytorch from 12 | # http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz 13 | FID_WEIGHTS_URL = 'https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth' # noqa: E501 14 | 15 | 16 | class InceptionV3(nn.Module): 17 | """Pretrained InceptionV3 network returning feature maps""" 18 | 19 | # Index of default block of inception to return, 20 | # corresponds to output of final average pooling 21 | DEFAULT_BLOCK_INDEX = 3 22 | 23 | # Maps feature dimensionality to their output blocks indices 24 | BLOCK_INDEX_BY_DIM = { 25 | 64: 0, # First max pooling features 26 | 192: 1, # Second max pooling featurs 27 | 768: 2, # Pre-aux classifier features 28 | 2048: 3 # Final average pooling features 29 | } 30 | 31 | def __init__(self, 32 | output_blocks=(DEFAULT_BLOCK_INDEX,), 33 | resize_input=True, 34 | normalize_input=True, 35 | requires_grad=False, 36 | use_fid_inception=True): 37 | """Build pretrained InceptionV3 38 | 39 | Parameters 40 | ---------- 41 | output_blocks : list of int 42 | Indices of blocks to return features of. Possible values are: 43 | - 0: corresponds to output of first max pooling 44 | - 1: corresponds to output of second max pooling 45 | - 2: corresponds to output which is fed to aux classifier 46 | - 3: corresponds to output of final average pooling 47 | resize_input : bool 48 | If true, bilinearly resizes input to width and height 299 before 49 | feeding input to model. As the network without fully connected 50 | layers is fully convolutional, it should be able to handle inputs 51 | of arbitrary size, so resizing might not be strictly needed 52 | normalize_input : bool 53 | If true, scales the input from range (0, 1) to the range the 54 | pretrained Inception network expects, namely (-1, 1) 55 | requires_grad : bool 56 | If true, parameters of the model require gradients. Possibly useful 57 | for finetuning the network 58 | use_fid_inception : bool 59 | If true, uses the pretrained Inception model used in Tensorflow's 60 | FID implementation. If false, uses the pretrained Inception model 61 | available in torchvision. The FID Inception model has different 62 | weights and a slightly different structure from torchvision's 63 | Inception model. If you want to compute FID scores, you are 64 | strongly advised to set this parameter to true to get comparable 65 | results. 
66 | """ 67 | super(InceptionV3, self).__init__() 68 | 69 | self.resize_input = resize_input 70 | self.normalize_input = normalize_input 71 | self.output_blocks = sorted(output_blocks) 72 | self.last_needed_block = max(output_blocks) 73 | 74 | assert self.last_needed_block <= 3, \ 75 | 'Last possible output block index is 3' 76 | 77 | self.blocks = nn.ModuleList() 78 | 79 | if use_fid_inception: 80 | inception = fid_inception_v3() 81 | else: 82 | inception = _inception_v3(pretrained=True) 83 | 84 | # Block 0: input to maxpool1 85 | block0 = [ 86 | inception.Conv2d_1a_3x3, 87 | inception.Conv2d_2a_3x3, 88 | inception.Conv2d_2b_3x3, 89 | nn.MaxPool2d(kernel_size=3, stride=2) 90 | ] 91 | self.blocks.append(nn.Sequential(*block0)) 92 | 93 | # Block 1: maxpool1 to maxpool2 94 | if self.last_needed_block >= 1: 95 | block1 = [ 96 | inception.Conv2d_3b_1x1, 97 | inception.Conv2d_4a_3x3, 98 | nn.MaxPool2d(kernel_size=3, stride=2) 99 | ] 100 | self.blocks.append(nn.Sequential(*block1)) 101 | 102 | # Block 2: maxpool2 to aux classifier 103 | if self.last_needed_block >= 2: 104 | block2 = [ 105 | inception.Mixed_5b, 106 | inception.Mixed_5c, 107 | inception.Mixed_5d, 108 | inception.Mixed_6a, 109 | inception.Mixed_6b, 110 | inception.Mixed_6c, 111 | inception.Mixed_6d, 112 | inception.Mixed_6e, 113 | ] 114 | self.blocks.append(nn.Sequential(*block2)) 115 | 116 | # Block 3: aux classifier to final avgpool 117 | if self.last_needed_block >= 3: 118 | block3 = [ 119 | inception.Mixed_7a, 120 | inception.Mixed_7b, 121 | inception.Mixed_7c, 122 | nn.AdaptiveAvgPool2d(output_size=(1, 1)) 123 | ] 124 | self.blocks.append(nn.Sequential(*block3)) 125 | 126 | for param in self.parameters(): 127 | param.requires_grad = requires_grad 128 | 129 | def forward(self, inp): 130 | """Get Inception feature maps 131 | 132 | Parameters 133 | ---------- 134 | inp : torch.autograd.Variable 135 | Input tensor of shape Bx3xHxW. Values are expected to be in 136 | range (0, 1) 137 | 138 | Returns 139 | ------- 140 | List of torch.autograd.Variable, corresponding to the selected output 141 | block, sorted ascending by index 142 | """ 143 | outp = [] 144 | x = inp 145 | 146 | if self.resize_input: 147 | x = F.interpolate(x, 148 | size=(299, 299), 149 | mode='bilinear', 150 | align_corners=False) 151 | 152 | if self.normalize_input: 153 | x = 2 * x - 1 # Scale from range (0, 1) to range (-1, 1) 154 | 155 | for idx, block in enumerate(self.blocks): 156 | x = block(x) 157 | if idx in self.output_blocks: 158 | outp.append(x) 159 | 160 | if idx == self.last_needed_block: 161 | break 162 | 163 | return outp 164 | 165 | 166 | def _inception_v3(*args, **kwargs): 167 | """Wraps `torchvision.models.inception_v3` 168 | 169 | Skips default weight inititialization if supported by torchvision version. 170 | See https://github.com/mseitzer/pytorch-fid/issues/28. 171 | """ 172 | try: 173 | version = tuple(map(int, torchvision.__version__.split('.')[:2])) 174 | except ValueError: 175 | # Just a caution against weird version strings 176 | version = (0,) 177 | 178 | if version >= (0, 6): 179 | kwargs['init_weights'] = False 180 | 181 | return torchvision.models.inception_v3(*args, **kwargs) 182 | 183 | 184 | def fid_inception_v3(): 185 | """Build pretrained Inception model for FID computation 186 | 187 | The Inception model for FID computation uses a different set of weights 188 | and has a slightly different structure than torchvision's Inception. 
189 | 190 | This method first constructs torchvision's Inception and then patches the 191 | necessary parts that are different in the FID Inception model. 192 | """ 193 | inception = _inception_v3(num_classes=1008, 194 | aux_logits=False, 195 | pretrained=False) 196 | inception.Mixed_5b = FIDInceptionA(192, pool_features=32) 197 | inception.Mixed_5c = FIDInceptionA(256, pool_features=64) 198 | inception.Mixed_5d = FIDInceptionA(288, pool_features=64) 199 | inception.Mixed_6b = FIDInceptionC(768, channels_7x7=128) 200 | inception.Mixed_6c = FIDInceptionC(768, channels_7x7=160) 201 | inception.Mixed_6d = FIDInceptionC(768, channels_7x7=160) 202 | inception.Mixed_6e = FIDInceptionC(768, channels_7x7=192) 203 | inception.Mixed_7b = FIDInceptionE_1(1280) 204 | inception.Mixed_7c = FIDInceptionE_2(2048) 205 | 206 | state_dict = load_state_dict_from_url(FID_WEIGHTS_URL, progress=True) 207 | inception.load_state_dict(state_dict) 208 | return inception 209 | 210 | 211 | class FIDInceptionA(torchvision.models.inception.InceptionA): 212 | """InceptionA block patched for FID computation""" 213 | def __init__(self, in_channels, pool_features): 214 | super(FIDInceptionA, self).__init__(in_channels, pool_features) 215 | 216 | def forward(self, x): 217 | branch1x1 = self.branch1x1(x) 218 | 219 | branch5x5 = self.branch5x5_1(x) 220 | branch5x5 = self.branch5x5_2(branch5x5) 221 | 222 | branch3x3dbl = self.branch3x3dbl_1(x) 223 | branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) 224 | branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl) 225 | 226 | # Patch: Tensorflow's average pool does not use the padded zero's in 227 | # its average calculation 228 | branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, 229 | count_include_pad=False) 230 | branch_pool = self.branch_pool(branch_pool) 231 | 232 | outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool] 233 | return torch.cat(outputs, 1) 234 | 235 | 236 | class FIDInceptionC(torchvision.models.inception.InceptionC): 237 | """InceptionC block patched for FID computation""" 238 | def __init__(self, in_channels, channels_7x7): 239 | super(FIDInceptionC, self).__init__(in_channels, channels_7x7) 240 | 241 | def forward(self, x): 242 | branch1x1 = self.branch1x1(x) 243 | 244 | branch7x7 = self.branch7x7_1(x) 245 | branch7x7 = self.branch7x7_2(branch7x7) 246 | branch7x7 = self.branch7x7_3(branch7x7) 247 | 248 | branch7x7dbl = self.branch7x7dbl_1(x) 249 | branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl) 250 | branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl) 251 | branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl) 252 | branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl) 253 | 254 | # Patch: Tensorflow's average pool does not use the padded zero's in 255 | # its average calculation 256 | branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, 257 | count_include_pad=False) 258 | branch_pool = self.branch_pool(branch_pool) 259 | 260 | outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool] 261 | return torch.cat(outputs, 1) 262 | 263 | 264 | class FIDInceptionE_1(torchvision.models.inception.InceptionE): 265 | """First InceptionE block patched for FID computation""" 266 | def __init__(self, in_channels): 267 | super(FIDInceptionE_1, self).__init__(in_channels) 268 | 269 | def forward(self, x): 270 | branch1x1 = self.branch1x1(x) 271 | 272 | branch3x3 = self.branch3x3_1(x) 273 | branch3x3 = [ 274 | self.branch3x3_2a(branch3x3), 275 | self.branch3x3_2b(branch3x3), 276 | ] 277 | branch3x3 = torch.cat(branch3x3, 1) 278 | 279 | branch3x3dbl = 
self.branch3x3dbl_1(x) 280 | branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) 281 | branch3x3dbl = [ 282 | self.branch3x3dbl_3a(branch3x3dbl), 283 | self.branch3x3dbl_3b(branch3x3dbl), 284 | ] 285 | branch3x3dbl = torch.cat(branch3x3dbl, 1) 286 | 287 | # Patch: Tensorflow's average pool does not use the padded zero's in 288 | # its average calculation 289 | branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, 290 | count_include_pad=False) 291 | branch_pool = self.branch_pool(branch_pool) 292 | 293 | outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool] 294 | return torch.cat(outputs, 1) 295 | 296 | 297 | class FIDInceptionE_2(torchvision.models.inception.InceptionE): 298 | """Second InceptionE block patched for FID computation""" 299 | def __init__(self, in_channels): 300 | super(FIDInceptionE_2, self).__init__(in_channels) 301 | 302 | def forward(self, x): 303 | branch1x1 = self.branch1x1(x) 304 | 305 | branch3x3 = self.branch3x3_1(x) 306 | branch3x3 = [ 307 | self.branch3x3_2a(branch3x3), 308 | self.branch3x3_2b(branch3x3), 309 | ] 310 | branch3x3 = torch.cat(branch3x3, 1) 311 | 312 | branch3x3dbl = self.branch3x3dbl_1(x) 313 | branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) 314 | branch3x3dbl = [ 315 | self.branch3x3dbl_3a(branch3x3dbl), 316 | self.branch3x3dbl_3b(branch3x3dbl), 317 | ] 318 | branch3x3dbl = torch.cat(branch3x3dbl, 1) 319 | 320 | # Patch: The FID Inception model uses max pooling instead of average 321 | # pooling. This is likely an error in this specific Inception 322 | # implementation, as other Inception models use average pooling here 323 | # (which matches the description in the paper). 324 | branch_pool = F.max_pool2d(x, kernel_size=3, stride=1, padding=1) 325 | branch_pool = self.branch_pool(branch_pool) 326 | 327 | outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool] 328 | return torch.cat(outputs, 1) 329 | -------------------------------------------------------------------------------- /causal-conv1d/csrc/causal_conv1d.cpp: -------------------------------------------------------------------------------- 1 | /****************************************************************************** 2 | * Copyright (c) 2023, Tri Dao. 3 | ******************************************************************************/ 4 | 5 | #include 6 | #include 7 | #include 8 | #include 9 | 10 | #include "causal_conv1d.h" 11 | 12 | #define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")") 13 | 14 | #define DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(ITYPE, NAME, ...) \ 15 | if (ITYPE == at::ScalarType::Half) { \ 16 | using input_t = at::Half; \ 17 | __VA_ARGS__(); \ 18 | } else if (ITYPE == at::ScalarType::BFloat16) { \ 19 | using input_t = at::BFloat16; \ 20 | __VA_ARGS__(); \ 21 | } else if (ITYPE == at::ScalarType::Float) { \ 22 | using input_t = float; \ 23 | __VA_ARGS__(); \ 24 | } else { \ 25 | AT_ERROR(#NAME, " not implemented for input type '", toString(ITYPE), "'"); \ 26 | } 27 | 28 | #define DISPATCH_WTYPE_FLOAT_AND_HALF_AND_BF16(WTYPE, NAME, ...) 
\ 29 | if (WTYPE == at::ScalarType::Half) { \ 30 | using weight_t = at::Half; \ 31 | __VA_ARGS__(); \ 32 | } else if (WTYPE == at::ScalarType::BFloat16) { \ 33 | using weight_t = at::BFloat16; \ 34 | __VA_ARGS__(); \ 35 | } else if (WTYPE == at::ScalarType::Float) { \ 36 | using weight_t = float; \ 37 | __VA_ARGS__(); \ 38 | } else { \ 39 | AT_ERROR(#NAME, " not implemented for weight type '", toString(WTYPE), "'"); \ 40 | } 41 | 42 | template 43 | void causal_conv1d_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream); 44 | template 45 | void causal_conv1d_channellast_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream); 46 | 47 | template 48 | void causal_conv1d_bwd_cuda(ConvParamsBwd ¶ms, cudaStream_t stream); 49 | template 50 | void causal_conv1d_channellast_bwd_cuda(ConvParamsBwd ¶ms, cudaStream_t stream); 51 | 52 | template 53 | void causal_conv1d_update_cuda(ConvParamsBase ¶ms, cudaStream_t stream); 54 | 55 | void set_conv_params_fwd(ConvParamsBase ¶ms, 56 | // sizes 57 | const size_t batch, 58 | const size_t dim, 59 | const size_t seqlen, 60 | const size_t width, 61 | // device pointers 62 | const at::Tensor x, 63 | const at::Tensor weight, 64 | const at::Tensor out, 65 | void* bias_ptr, 66 | bool silu_activation) { 67 | 68 | // Reset the parameters 69 | memset(¶ms, 0, sizeof(params)); 70 | 71 | params.batch = batch; 72 | params.dim = dim; 73 | params.seqlen = seqlen; 74 | params.width = width; 75 | 76 | params.silu_activation = silu_activation; 77 | 78 | // Set the pointers and strides. 79 | params.x_ptr = x.data_ptr(); 80 | params.weight_ptr = weight.data_ptr(); 81 | params.bias_ptr = bias_ptr; 82 | params.out_ptr = out.data_ptr(); 83 | // All stride are in elements, not bytes. 84 | params.x_batch_stride = x.stride(0); 85 | params.x_c_stride = x.stride(1); 86 | params.x_l_stride = x.stride(-1); 87 | params.weight_c_stride = weight.stride(0); 88 | params.weight_width_stride = weight.stride(1); 89 | params.out_batch_stride = out.stride(0); 90 | params.out_c_stride = out.stride(1); 91 | params.out_l_stride = out.stride(-1); 92 | } 93 | 94 | 95 | void set_conv_params_bwd(ConvParamsBwd ¶ms, 96 | // sizes 97 | const size_t batch, 98 | const size_t dim, 99 | const size_t seqlen, 100 | const size_t width, 101 | // device pointers 102 | const at::Tensor x, 103 | const at::Tensor weight, 104 | void* bias_ptr, 105 | const at::Tensor dout, 106 | const at::Tensor dx, 107 | const at::Tensor dweight, 108 | void* dbias_ptr, 109 | bool silu_activation) { 110 | // Pass in "dout" instead of "out", we're not gonna use "out" at all. 111 | set_conv_params_fwd(params, batch, dim, seqlen, width, 112 | x, weight, dout, bias_ptr, silu_activation); 113 | 114 | // Set the pointers and strides. 115 | params.dout_ptr = dout.data_ptr(); 116 | params.dx_ptr = dx.data_ptr(); 117 | params.dweight_ptr = dweight.data_ptr(); 118 | params.dbias_ptr = dbias_ptr; 119 | // All stride are in elements, not bytes. 
120 | params.dout_batch_stride = dout.stride(0); 121 | params.dout_c_stride = dout.stride(1); 122 | params.dout_l_stride = dout.stride(2); 123 | params.dweight_c_stride = dweight.stride(0); 124 | params.dweight_width_stride = dweight.stride(1); 125 | params.dx_batch_stride = dx.stride(0); 126 | params.dx_c_stride = dx.stride(1); 127 | params.dx_l_stride = dx.stride(2); 128 | } 129 | 130 | at::Tensor 131 | causal_conv1d_fwd(const at::Tensor &x, const at::Tensor &weight, 132 | const c10::optional &bias_, 133 | bool silu_activation) { 134 | auto input_type = x.scalar_type(); 135 | auto weight_type = weight.scalar_type(); 136 | TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16); 137 | TORCH_CHECK(weight_type == at::ScalarType::Float || weight_type == at::ScalarType::Half || weight_type == at::ScalarType::BFloat16); 138 | 139 | TORCH_CHECK(x.is_cuda()); 140 | TORCH_CHECK(weight.is_cuda()); 141 | 142 | const auto sizes = x.sizes(); 143 | const int batch_size = sizes[0]; 144 | const int dim = sizes[1]; 145 | const int seqlen = sizes[2]; 146 | const int width = weight.size(-1); 147 | 148 | CHECK_SHAPE(x, batch_size, dim, seqlen); 149 | CHECK_SHAPE(weight, dim, width); 150 | 151 | TORCH_CHECK(x.stride(2) == 1 || x.stride(1) == 1); 152 | const bool is_channel_last = x.stride(1) == 1 && x.stride(2) > 1; 153 | 154 | if (is_channel_last) { 155 | TORCH_CHECK(dim % 8 == 0, "causal_conv1d only supports channel dimension divisible by 8 for now"); 156 | } 157 | TORCH_CHECK(width >= 2 && width <= 4, "causal_conv1d only supports width between 2 and 4"); 158 | 159 | 160 | if (bias_.has_value()) { 161 | auto bias = bias_.value(); 162 | TORCH_CHECK(bias.scalar_type() == weight_type); 163 | TORCH_CHECK(bias.is_cuda()); 164 | TORCH_CHECK(bias.stride(-1) == 1); 165 | CHECK_SHAPE(bias, dim); 166 | } 167 | 168 | at::Tensor out = torch::empty_like(x); 169 | 170 | ConvParamsBase params; 171 | set_conv_params_fwd(params, batch_size, dim, seqlen, width, x, weight, out, 172 | bias_.has_value() ? 
bias_.value().data_ptr() : nullptr, 173 | silu_activation); 174 | 175 | // Otherwise the kernel will be launched from cuda:0 device 176 | // Cast to char to avoid compiler warning about narrowing 177 | at::cuda::CUDAGuard device_guard{(char)x.get_device()}; 178 | auto stream = at::cuda::getCurrentCUDAStream().stream(); 179 | DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(x.scalar_type(), "causal_conv1d_fwd", [&] { 180 | DISPATCH_WTYPE_FLOAT_AND_HALF_AND_BF16(weight.scalar_type(), "causal_conv1d_fwd", [&] { 181 | if (!is_channel_last) { 182 | causal_conv1d_fwd_cuda(params, stream); 183 | } else { 184 | causal_conv1d_channellast_fwd_cuda(params, stream); 185 | } 186 | }); 187 | }); 188 | return out; 189 | } 190 | 191 | std::vector 192 | causal_conv1d_bwd(const at::Tensor &x, const at::Tensor &weight, 193 | const c10::optional &bias_, 194 | at::Tensor &dout, 195 | c10::optional &dx_, 196 | bool silu_activation) { 197 | auto input_type = x.scalar_type(); 198 | auto weight_type = weight.scalar_type(); 199 | TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16); 200 | TORCH_CHECK(weight_type == at::ScalarType::Float || weight_type == at::ScalarType::Half || weight_type == at::ScalarType::BFloat16); 201 | 202 | TORCH_CHECK(x.is_cuda()); 203 | TORCH_CHECK(weight.is_cuda()); 204 | TORCH_CHECK(dout.is_cuda()); 205 | 206 | const auto sizes = x.sizes(); 207 | const int batch_size = sizes[0]; 208 | const int dim = sizes[1]; 209 | const int seqlen = sizes[2]; 210 | const int width = weight.size(-1); 211 | 212 | TORCH_CHECK(width >= 2 && width <= 4, "causal_conv1d only supports width between 2 and 4"); 213 | 214 | CHECK_SHAPE(x, batch_size, dim, seqlen); 215 | CHECK_SHAPE(weight, dim, width); 216 | CHECK_SHAPE(dout, batch_size, dim, seqlen); 217 | 218 | TORCH_CHECK(x.stride(2) == 1 || x.stride(1) == 1); 219 | const bool is_channel_last = x.stride(1) == 1 && x.stride(2) > 1; 220 | if (!is_channel_last && dout.stride(2) != 1) { dout = dout.contiguous(); } 221 | if (is_channel_last && dout.stride(1) != 1) { dout = dout.transpose(-1, -2).contiguous().transpose(-1, -2); } 222 | 223 | if (bias_.has_value()) { 224 | auto bias = bias_.value(); 225 | TORCH_CHECK(bias.scalar_type() == weight_type); 226 | TORCH_CHECK(bias.is_cuda()); 227 | TORCH_CHECK(bias.stride(-1) == 1); 228 | CHECK_SHAPE(bias, dim); 229 | } 230 | 231 | at::Tensor dx; 232 | if (dx_.has_value()) { 233 | dx = dx_.value(); 234 | TORCH_CHECK(dx.scalar_type() == input_type); 235 | TORCH_CHECK(dx.is_cuda()); 236 | CHECK_SHAPE(dx, batch_size, dim, seqlen); 237 | if (!is_channel_last) { TORCH_CHECK(dx.stride(2) == 1); } 238 | if (is_channel_last) { TORCH_CHECK(dx.stride(1) == 1); } 239 | } else { 240 | dx = torch::empty_like(x); 241 | } 242 | 243 | // Otherwise the kernel will be launched from cuda:0 device 244 | // Cast to char to avoid compiler warning about narrowing 245 | at::cuda::CUDAGuard device_guard{(char)x.get_device()}; 246 | 247 | at::Tensor dweight = torch::zeros_like(weight, weight.options().dtype(at::kFloat)); 248 | at::Tensor dbias; 249 | if (bias_.has_value()) { dbias = torch::zeros_like(bias_.value(), bias_.value().options().dtype(at::kFloat)); } 250 | 251 | ConvParamsBwd params; 252 | set_conv_params_bwd(params, batch_size, dim, seqlen, width, 253 | x, weight, bias_.has_value() ? bias_.value().data_ptr() : nullptr, 254 | dout, dx, dweight, bias_.has_value() ? 
dbias.data_ptr() : nullptr, 255 | silu_activation); 256 | 257 | auto stream = at::cuda::getCurrentCUDAStream().stream(); 258 | DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(x.scalar_type(), "causal_conv1d_bwd", [&] { 259 | DISPATCH_WTYPE_FLOAT_AND_HALF_AND_BF16(weight.scalar_type(), "causal_conv1d_bwd", [&] { 260 | if (!is_channel_last) { 261 | causal_conv1d_bwd_cuda(params, stream); 262 | } else { 263 | causal_conv1d_channellast_bwd_cuda(params, stream); 264 | } 265 | }); 266 | }); 267 | return {dx, dweight.to(weight.dtype()), bias_.has_value() ? dbias.to(bias_.value().dtype()) : dbias}; 268 | } 269 | 270 | at::Tensor 271 | causal_conv1d_update(const at::Tensor &x, 272 | const at::Tensor &conv_state, 273 | const at::Tensor &weight, 274 | const c10::optional &bias_, 275 | bool silu_activation) { 276 | auto input_type = x.scalar_type(); 277 | auto weight_type = weight.scalar_type(); 278 | TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16); 279 | TORCH_CHECK(weight_type == at::ScalarType::Float || weight_type == at::ScalarType::Half || weight_type == at::ScalarType::BFloat16); 280 | TORCH_CHECK(conv_state.scalar_type() == input_type); 281 | 282 | TORCH_CHECK(x.is_cuda()); 283 | TORCH_CHECK(conv_state.is_cuda()); 284 | TORCH_CHECK(weight.is_cuda()); 285 | 286 | const auto sizes = x.sizes(); 287 | const int batch_size = sizes[0]; 288 | const int dim = sizes[1]; 289 | const int width = weight.size(-1); 290 | 291 | CHECK_SHAPE(x, batch_size, dim); 292 | CHECK_SHAPE(conv_state, batch_size, dim, width); 293 | CHECK_SHAPE(weight, dim, width); 294 | 295 | TORCH_CHECK(width >= 2 && width <= 4, "causal_conv1d only supports width between 2 and 4"); 296 | 297 | if (bias_.has_value()) { 298 | auto bias = bias_.value(); 299 | TORCH_CHECK(bias.scalar_type() == weight_type); 300 | TORCH_CHECK(bias.is_cuda()); 301 | TORCH_CHECK(bias.stride(-1) == 1); 302 | CHECK_SHAPE(bias, dim); 303 | } 304 | 305 | at::Tensor out = torch::empty_like(x); 306 | 307 | ConvParamsBase params; 308 | set_conv_params_fwd(params, batch_size, dim, /*seqlen=*/1, width, x, weight, out, 309 | bias_.has_value() ? bias_.value().data_ptr() : nullptr, 310 | silu_activation); 311 | params.conv_state_ptr = conv_state.data_ptr(); 312 | // All stride are in elements, not bytes. 313 | params.conv_state_batch_stride = conv_state.stride(0); 314 | params.conv_state_c_stride = conv_state.stride(1); 315 | params.conv_state_l_stride = conv_state.stride(2); 316 | 317 | // Otherwise the kernel will be launched from cuda:0 device 318 | // Cast to char to avoid compiler warning about narrowing 319 | at::cuda::CUDAGuard device_guard{(char)x.get_device()}; 320 | auto stream = at::cuda::getCurrentCUDAStream().stream(); 321 | DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(x.scalar_type(), "causal_conv1d_update", [&] { 322 | DISPATCH_WTYPE_FLOAT_AND_HALF_AND_BF16(weight.scalar_type(), "causal_conv1d_update", [&] { 323 | causal_conv1d_update_cuda(params, stream); 324 | }); 325 | }); 326 | return out; 327 | } 328 | 329 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 330 | m.def("causal_conv1d_fwd", &causal_conv1d_fwd, "Causal conv1d forward"); 331 | m.def("causal_conv1d_bwd", &causal_conv1d_bwd, "Causal conv1d backward"); 332 | m.def("causal_conv1d_update", &causal_conv1d_update, "Causal conv1d update"); 333 | } 334 | --------------------------------------------------------------------------------
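
Editor's note: the `PYBIND11_MODULE` block above exposes `causal_conv1d_fwd`, `causal_conv1d_bwd`, and `causal_conv1d_update`, but the tensor semantics are only implied by the shape checks (`x`: (batch, dim, seqlen), `weight`: (dim, width) with width between 2 and 4, optional `bias`: (dim,), optional SiLU, and a (batch, dim, width) `conv_state` for single-token updates). The sketch below is a minimal, hedged PyTorch reference for what those kernels appear to compute, inferred from these checks; the names `causal_conv1d_ref` and `causal_conv1d_update_ref` are illustrative and are not part of the package.

    # Hedged reference sketch, not the library's implementation.
    # Assumptions: depthwise (per-channel) filters of shape (dim, width),
    # left zero-padding for causality, optional bias and SiLU, matching the
    # argument checks in causal_conv1d.cpp.
    import torch
    import torch.nn.functional as F

    def causal_conv1d_ref(x, weight, bias=None, silu_activation=False):
        """x: (batch, dim, seqlen), weight: (dim, width).
        out[:, :, t] depends only on x[:, :, :t + 1]."""
        batch, dim, seqlen = x.shape
        width = weight.shape[-1]
        # groups=dim makes the conv depthwise; padding then truncation keeps it causal.
        out = F.conv1d(x, weight.unsqueeze(1), bias=bias, padding=width - 1, groups=dim)
        out = out[..., :seqlen]
        return F.silu(out) if silu_activation else out

    def causal_conv1d_update_ref(x, conv_state, weight, bias=None, silu_activation=False):
        """Single decoding step: x: (batch, dim), conv_state: (batch, dim, width).
        Shift the state left by one, append the new token, apply the same filter."""
        conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1))
        conv_state[:, :, -1] = x
        out = torch.einsum("bdw,dw->bd", conv_state, weight)
        if bias is not None:
            out = out + bias
        return F.silu(out) if silu_activation else out

Under these assumptions, `causal_conv1d_ref(torch.randn(2, 16, 128), torch.randn(16, 4))` mirrors the full-sequence `causal_conv1d_fwd` path, while the update variant is what an autoregressive decode loop would call once per generated token; whether the fused CUDA kernels match this reference bit-for-bit has not been verified here.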