├── flash_attn ├── ops │ └── __init__.py ├── utils │ └── __init__.py ├── layers │ ├── __init__.py │ └── patch_embed.py ├── losses │ └── __init__.py ├── models │ └── __init__.py ├── modules │ └── __init__.py └── __init__.py ├── AUTHORS ├── training ├── configs │ ├── callbacks │ │ ├── none.yaml │ │ ├── norm-monitor.yaml │ │ ├── model-summary.yaml │ │ ├── causality-monitor.yaml │ │ ├── ema.yaml │ │ ├── flop-count.yaml │ │ ├── params-log.yaml │ │ ├── gpu-monitor.yaml │ │ ├── wandb.yaml │ │ └── default.yaml │ ├── task │ │ └── sequence-model.yaml │ ├── optimizer │ │ ├── adam.yaml │ │ ├── sgd.yaml │ │ ├── adamw.yaml │ │ ├── fusedlamb.yaml │ │ ├── fusedlamb-ds.yaml │ │ ├── adamw-apex.yaml │ │ ├── adamw-apex-distributed.yaml │ │ ├── adamw-zero.yaml │ │ └── adamw-apex-zero.yaml │ ├── metrics │ │ ├── mse.yaml │ │ ├── acc.yaml │ │ ├── perplexity.yaml │ │ ├── num-tokens.yaml │ │ ├── acctop5.yaml │ │ └── acc_ignore_index.yaml │ ├── scheduler │ │ ├── multi-step.yaml │ │ ├── step.yaml │ │ ├── cosine-warmup.yaml │ │ ├── linear-warmup.yaml │ │ ├── invsqrt.yaml │ │ ├── poly-warmup.yaml │ │ ├── cosine-warmup-timm.yaml │ │ └── plateau.yaml │ ├── trainer │ │ ├── ddp.yaml │ │ ├── default.yaml │ │ ├── debug.yaml │ │ └── all_params.yaml │ ├── model │ │ ├── gpt2model │ │ │ ├── gpt2-large.yaml │ │ │ ├── gpt2-medium.yaml │ │ │ ├── gpt2-small.yaml │ │ │ └── gpt2-xlarge.yaml │ │ ├── gpt2.yaml │ │ └── gpt2-hf.yaml │ ├── experiment │ │ ├── owt │ │ │ ├── gpt2xl-hf.yaml │ │ │ ├── gpt2m-hf.yaml │ │ │ ├── gpt2m.yaml │ │ │ ├── gpt2s.yaml │ │ │ ├── gpt2l-hf.yaml │ │ │ ├── gpt2l.yaml │ │ │ ├── gpt2xl.yaml │ │ │ ├── gpt2m-flash.yaml │ │ │ ├── gpt2s-flash.yaml │ │ │ ├── gpt2s-hf.yaml │ │ │ ├── gpt2xl-flash.yaml │ │ │ └── gpt2l-flash.yaml │ │ └── pile │ │ │ ├── gpt3l-flash-rotary-30B.yaml │ │ │ ├── gpt3l-flash-rotary.yaml │ │ │ ├── gpt3m-flash-rotary-30B.yaml │ │ │ ├── gpt3m-flash-rotary.yaml │ │ │ ├── gpt3s-flash-rotary-30B.yaml │ │ │ ├── gpt3s-flash-rotary.yaml │ │ │ ├── gpt3m-hf.yaml │ │ │ ├── gpt3xl-flash-rotary-60B.yaml │ │ │ ├── gpt3xl-flash-rotary.yaml │ │ │ ├── gpt3l-flash-rotary-8k.yaml │ │ │ ├── gpt3m-flash-rotary-8k.yaml │ │ │ ├── gpt3s-flash-rotary-8k.yaml │ │ │ ├── gpt3xl-flash-rotary-8k.yaml │ │ │ ├── gpt3l-hf.yaml │ │ │ ├── gpt3-2.7B-hf.yaml │ │ │ ├── gpt3l-flash-8k.yaml │ │ │ ├── gpt3m-flash-8k.yaml │ │ │ ├── gpt3s-flash-8k.yaml │ │ │ ├── gpt3xl-flash-8k.yaml │ │ │ ├── gpt3s-hf.yaml │ │ │ ├── gpt3-2.7B-hf-hdim128.yaml │ │ │ ├── gpt3-2.7B-flash.yaml │ │ │ ├── gpt3-2.7B-flash-8k.yaml │ │ │ ├── gpt3m-flash.yaml │ │ │ ├── gpt3-2.7B-flash-hdim128.yaml │ │ │ ├── gpt3-2.7B-flash-rotary-8k.yaml │ │ │ ├── gpt3-2.7B-flash-rotary.yaml │ │ │ ├── gpt3-2.7B-flash-hdim128-rotary.yaml │ │ │ ├── gpt3-2.7B-flash-hdim128-rotary-8k.yaml │ │ │ ├── gpt3s-flash.yaml │ │ │ ├── gpt3l-flash.yaml │ │ │ ├── gpt3xl-hf.yaml │ │ │ └── gpt3xl-flash.yaml │ ├── logger │ │ ├── csv.yaml │ │ ├── many_loggers.yaml │ │ ├── comet.yaml │ │ ├── mlflow.yaml │ │ ├── tensorboard.yaml │ │ ├── neptune.yaml │ │ └── wandb.yaml │ ├── mode │ │ ├── default.yaml │ │ ├── exp.yaml │ │ ├── smoke.yaml │ │ ├── profile.yaml │ │ └── debug.yaml │ ├── datamodule │ │ ├── thepile.yaml │ │ └── openwebtext.yaml │ └── config.yaml └── src │ ├── callbacks │ ├── __init__.py │ ├── loss_scale_monitor.py │ ├── params_log.py │ ├── gpu_affinity.py │ ├── model_checkpoint.py │ ├── flop_count.py │ └── causality_monitor.py │ ├── metrics │ ├── accuracy.py │ └── num_tokens.py │ ├── datamodules │ ├── timm_mixup.py │ └── datasets │ │ ├── lm_dataset.py │ │ └── detokenizer.py │ ├── optim │ └── 
timm_lr_scheduler.py │ ├── distributed │ └── ddp_comm_hooks.py │ └── utils │ └── flops.py ├── assets ├── flashattn_banner.jpg ├── flashattn_banner.pdf ├── flashattn_memory.jpg ├── flashattn_speedup.jpg ├── flashattention_logo.png ├── flashattn_speedup_t4.jpg ├── gpt2_training_curve.jpg ├── gpt3_training_curve.jpg ├── flashattn_speedup_3090.jpg ├── flashattn_speedup_t4_fwd.jpg ├── gpt2_training_efficiency.jpg ├── gpt3_training_efficiency.jpg ├── flashattn_speedup_a100_d128.jpg ├── flash2_a100_fwd_bwd_benchmark.png └── flash2_h100_fwd_bwd_benchmark.png ├── Makefile ├── csrc ├── flash_attn_with_bias_and_mask │ └── src │ │ ├── cuda_utils.h │ │ ├── fmha_fwd_hdim128.cu │ │ ├── fmha_bwd_with_mask_bias_hdim128.cu │ │ ├── fmha_bwd_hdim128.cu │ │ ├── utils.h │ │ ├── fmha_fwd_hdim32.cu │ │ ├── fmha_fwd_hdim64.cu │ │ ├── fmha_bwd_hdim32.cu │ │ ├── fmha_bwd_with_mask_bias_hdim32.cu │ │ ├── fmha_fwd_with_mask_bias_hdim32.cu │ │ ├── cuda_utils.cu │ │ ├── fmha_fwd_with_mask_bias_hdim128.cu │ │ ├── fmha_fwd_with_mask_bias_hdim64.cu │ │ ├── fmha_bwd_with_mask_bias_hdim64.cu │ │ ├── fmha_bwd_hdim64.cu │ │ ├── utils.cu │ │ └── random_utils.h ├── flash_attn │ └── src │ │ ├── calc_reduced_attn_scores_dispatch │ │ ├── hdim128_fp16_sm80.cu │ │ ├── hdim160_fp16_sm80.cu │ │ ├── hdim192_fp16_sm80.cu │ │ ├── hdim224_fp16_sm80.cu │ │ ├── hdim256_fp16_sm80.cu │ │ ├── hdim32_fp16_sm80.cu │ │ ├── hdim64_fp16_sm80.cu │ │ ├── hdim96_fp16_sm80.cu │ │ ├── hdim32_bf16_sm80.cu │ │ ├── hdim64_bf16_sm80.cu │ │ ├── hdim96_bf16_sm80.cu │ │ ├── hdim128_bf16_sm80.cu │ │ ├── hdim160_bf16_sm80.cu │ │ ├── hdim192_bf16_sm80.cu │ │ ├── hdim224_bf16_sm80.cu │ │ └── hdim256_bf16_sm80.cu │ │ ├── cuda_utils.cu │ │ ├── cuda_utils.h │ │ ├── block_info.h │ │ └── random_utils.h ├── xentropy │ ├── README.md │ └── interface.cpp ├── ft_attention │ ├── README.md │ └── cuda_bf16_wrapper.h ├── fused_dense_lib │ ├── README.md │ └── setup.py ├── layer_norm │ ├── README.md │ ├── ln_fwd_1024.cu │ ├── ln_fwd_1280.cu │ ├── ln_fwd_1536.cu │ ├── ln_fwd_2048.cu │ ├── ln_fwd_256.cu │ ├── ln_fwd_2560.cu │ ├── ln_fwd_3072.cu │ ├── ln_fwd_4096.cu │ ├── ln_fwd_512.cu │ ├── ln_fwd_5120.cu │ ├── ln_fwd_6144.cu │ ├── ln_fwd_7168.cu │ ├── ln_fwd_768.cu │ ├── ln_fwd_8192.cu │ ├── ln_bwd_2048.cu │ ├── ln_bwd_3072.cu │ ├── ln_bwd_4096.cu │ ├── ln_bwd_5120.cu │ ├── ln_bwd_6144.cu │ ├── ln_bwd_7168.cu │ ├── ln_bwd_8192.cu │ ├── ln_bwd_1536.cu │ ├── ln_bwd_256.cu │ ├── ln_bwd_2560.cu │ ├── ln_bwd_512.cu │ ├── ln_bwd_768.cu │ ├── ln_bwd_1024.cu │ ├── ln_bwd_1280.cu │ ├── ln_parallel_fwd_256.cu │ ├── ln_parallel_fwd_1024.cu │ ├── ln_parallel_fwd_1280.cu │ ├── ln_parallel_fwd_1536.cu │ ├── ln_parallel_fwd_2048.cu │ ├── ln_parallel_fwd_2560.cu │ ├── ln_parallel_fwd_3072.cu │ ├── ln_parallel_fwd_4096.cu │ ├── ln_parallel_fwd_512.cu │ ├── ln_parallel_fwd_5120.cu │ ├── ln_parallel_fwd_6144.cu │ ├── ln_parallel_fwd_7168.cu │ ├── ln_parallel_fwd_768.cu │ ├── ln_parallel_fwd_8192.cu │ ├── ln_parallel_bwd_2048.cu │ ├── ln_parallel_bwd_3072.cu │ ├── ln_parallel_bwd_6144.cu │ ├── ln_parallel_bwd_7168.cu │ ├── ln_parallel_bwd_8192.cu │ ├── ln_parallel_bwd_1536.cu │ ├── ln_parallel_bwd_256.cu │ ├── ln_parallel_bwd_2560.cu │ ├── ln_parallel_bwd_512.cu │ ├── ln_parallel_bwd_768.cu │ ├── ln_parallel_bwd_1024.cu │ ├── ln_parallel_bwd_1280.cu │ ├── ln_parallel_bwd_4096.cu │ ├── ln_parallel_bwd_5120.cu │ └── static_switch.h ├── flashmask_v2 │ ├── flash_fwd_combine.cu │ ├── print_val.cu │ ├── cuda_check.h │ └── copy_sm90_bulk_reduce.hpp ├── flash_attn_v3 │ ├── flash_fwd_combine.cu │ ├── 
cuda_check.h │ └── copy_sm90_bulk_reduce.hpp ├── fused_softmax │ ├── type_shim.h │ └── setup.py └── rotary │ ├── rotary.cpp │ └── rotary_cuda.cu ├── MANIFEST.in ├── .gitmodules ├── .gitignore ├── LICENSE └── tests ├── losses └── test_cross_entropy.py ├── test_rotary.py └── models └── test_vit.py /flash_attn/ops/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /flash_attn/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /flash_attn/layers/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /flash_attn/losses/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /flash_attn/models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /flash_attn/modules/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /AUTHORS: -------------------------------------------------------------------------------- 1 | Tri Dao, trid@cs.stanford.edu -------------------------------------------------------------------------------- /training/configs/callbacks/none.yaml: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /training/src/callbacks/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /training/configs/task/sequence-model.yaml: -------------------------------------------------------------------------------- 1 | _target_: src.tasks.seq.SequenceModel 2 | -------------------------------------------------------------------------------- /training/configs/optimizer/adam.yaml: -------------------------------------------------------------------------------- 1 | # @package train.optimizer 2 | _target_: torch.optim.Adam 3 | -------------------------------------------------------------------------------- /training/configs/optimizer/sgd.yaml: -------------------------------------------------------------------------------- 1 | # @package train.optimizer 2 | _target_: torch.optim.SGD 3 | -------------------------------------------------------------------------------- /training/configs/optimizer/adamw.yaml: -------------------------------------------------------------------------------- 1 | # @package train.optimizer 2 | _target_: torch.optim.AdamW 3 | -------------------------------------------------------------------------------- /training/configs/optimizer/fusedlamb.yaml: -------------------------------------------------------------------------------- 1 | # @package train.optimizer 2 | _target_: apex.optimizers.FusedLAMB 3 | -------------------------------------------------------------------------------- /assets/flashattn_banner.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/PaddlePaddle/flash-attention/HEAD/assets/flashattn_banner.jpg -------------------------------------------------------------------------------- /assets/flashattn_banner.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PaddlePaddle/flash-attention/HEAD/assets/flashattn_banner.pdf -------------------------------------------------------------------------------- /assets/flashattn_memory.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PaddlePaddle/flash-attention/HEAD/assets/flashattn_memory.jpg -------------------------------------------------------------------------------- /training/configs/callbacks/norm-monitor.yaml: -------------------------------------------------------------------------------- 1 | norm_monitor: 2 | _target_: src.callbacks.norm_monitor.NormMonitor 3 | -------------------------------------------------------------------------------- /training/configs/metrics/mse.yaml: -------------------------------------------------------------------------------- 1 | # @package eval.metrics 2 | mse: 3 | _target_: torchmetrics.MeanSquaredError 4 | -------------------------------------------------------------------------------- /training/configs/optimizer/fusedlamb-ds.yaml: -------------------------------------------------------------------------------- 1 | # @package train.optimizer 2 | _target_: deepspeed.ops.lamb.FusedLamb 3 | -------------------------------------------------------------------------------- /assets/flashattn_speedup.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PaddlePaddle/flash-attention/HEAD/assets/flashattn_speedup.jpg -------------------------------------------------------------------------------- /training/configs/metrics/acc.yaml: -------------------------------------------------------------------------------- 1 | # @package eval.metrics 2 | acc: 3 | _target_: src.metrics.accuracy.AccuracyMine 4 | -------------------------------------------------------------------------------- /assets/flashattention_logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PaddlePaddle/flash-attention/HEAD/assets/flashattention_logo.png -------------------------------------------------------------------------------- /assets/flashattn_speedup_t4.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PaddlePaddle/flash-attention/HEAD/assets/flashattn_speedup_t4.jpg -------------------------------------------------------------------------------- /assets/gpt2_training_curve.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PaddlePaddle/flash-attention/HEAD/assets/gpt2_training_curve.jpg -------------------------------------------------------------------------------- /assets/gpt3_training_curve.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PaddlePaddle/flash-attention/HEAD/assets/gpt3_training_curve.jpg -------------------------------------------------------------------------------- /training/configs/callbacks/model-summary.yaml: -------------------------------------------------------------------------------- 1 | model_summary: 2 | _target_: 
pytorch_lightning.callbacks.RichModelSummary 3 | -------------------------------------------------------------------------------- /training/configs/metrics/perplexity.yaml: -------------------------------------------------------------------------------- 1 | # @package eval.metrics 2 | ppl: 3 | _target_: src.metrics.perplexity.Perplexity 4 | -------------------------------------------------------------------------------- /training/configs/scheduler/multi-step.yaml: -------------------------------------------------------------------------------- 1 | # @package train.scheduler 2 | _target_: torch.optim.lr_scheduler.MultiStepLR 3 | -------------------------------------------------------------------------------- /training/configs/trainer/ddp.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - default.yaml 3 | 4 | accelerator: gpu 5 | devices: 4 6 | strategy: ddp 7 | -------------------------------------------------------------------------------- /assets/flashattn_speedup_3090.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PaddlePaddle/flash-attention/HEAD/assets/flashattn_speedup_3090.jpg -------------------------------------------------------------------------------- /training/configs/callbacks/causality-monitor.yaml: -------------------------------------------------------------------------------- 1 | causality-monitor: 2 | _target_: src.callbacks.causality_monitor.CausalityMonitor -------------------------------------------------------------------------------- /training/configs/metrics/num-tokens.yaml: -------------------------------------------------------------------------------- 1 | # @package eval.metrics 2 | num-tokens: 3 | _target_: src.metrics.num_tokens.NumTokens 4 | -------------------------------------------------------------------------------- /training/configs/scheduler/step.yaml: -------------------------------------------------------------------------------- 1 | # @package train.scheduler 2 | _target_: torch.optim.lr_scheduler.StepLR 3 | step_size: ??? 4 | -------------------------------------------------------------------------------- /assets/flashattn_speedup_t4_fwd.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PaddlePaddle/flash-attention/HEAD/assets/flashattn_speedup_t4_fwd.jpg -------------------------------------------------------------------------------- /assets/gpt2_training_efficiency.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PaddlePaddle/flash-attention/HEAD/assets/gpt2_training_efficiency.jpg -------------------------------------------------------------------------------- /assets/gpt3_training_efficiency.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PaddlePaddle/flash-attention/HEAD/assets/gpt3_training_efficiency.jpg -------------------------------------------------------------------------------- /training/configs/callbacks/ema.yaml: -------------------------------------------------------------------------------- 1 | ema: 2 | _target_: src.callbacks.ema.EMACallback 3 | decay: ??? 
4 | use_num_updates: False 5 | -------------------------------------------------------------------------------- /training/configs/optimizer/adamw-apex.yaml: -------------------------------------------------------------------------------- 1 | # @package train.optimizer 2 | _target_: apex.optimizers.FusedAdam 3 | adam_w_mode: True 4 | -------------------------------------------------------------------------------- /training/configs/scheduler/cosine-warmup.yaml: -------------------------------------------------------------------------------- 1 | # @package train.scheduler 2 | _target_: transformers.get_cosine_schedule_with_warmup 3 | -------------------------------------------------------------------------------- /training/configs/scheduler/linear-warmup.yaml: -------------------------------------------------------------------------------- 1 | # @package train.scheduler 2 | _target_: transformers.get_linear_schedule_with_warmup 3 | -------------------------------------------------------------------------------- /assets/flashattn_speedup_a100_d128.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PaddlePaddle/flash-attention/HEAD/assets/flashattn_speedup_a100_d128.jpg -------------------------------------------------------------------------------- /training/configs/metrics/acctop5.yaml: -------------------------------------------------------------------------------- 1 | # @package eval.metrics 2 | acctop5: 3 | _target_: src.metrics.accuracy.AccuracyMine 4 | top_k: 5 5 | -------------------------------------------------------------------------------- /training/configs/scheduler/invsqrt.yaml: -------------------------------------------------------------------------------- 1 | # @package train.scheduler 2 | _target_: src.optim.lr_scheduler.InvSqrt 3 | num_warmup_steps: ??? 
4 | -------------------------------------------------------------------------------- /training/configs/scheduler/poly-warmup.yaml: -------------------------------------------------------------------------------- 1 | # @package train.scheduler 2 | _target_: transformers.get_polynomial_decay_schedule_with_warmup 3 | -------------------------------------------------------------------------------- /assets/flash2_a100_fwd_bwd_benchmark.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PaddlePaddle/flash-attention/HEAD/assets/flash2_a100_fwd_bwd_benchmark.png -------------------------------------------------------------------------------- /assets/flash2_h100_fwd_bwd_benchmark.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PaddlePaddle/flash-attention/HEAD/assets/flash2_h100_fwd_bwd_benchmark.png -------------------------------------------------------------------------------- /training/configs/metrics/acc_ignore_index.yaml: -------------------------------------------------------------------------------- 1 | # @package eval.metrics 2 | acc: 3 | _target_: torchmetrics.Accuracy 4 | ignore_index: -100 5 | -------------------------------------------------------------------------------- /training/configs/scheduler/cosine-warmup-timm.yaml: -------------------------------------------------------------------------------- 1 | # @package train.scheduler 2 | _target_: src.optim.timm_lr_scheduler.TimmCosineLRScheduler 3 | -------------------------------------------------------------------------------- /training/configs/model/gpt2model/gpt2-large.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | model: 3 | config: 4 | n_embd: 1280 5 | n_head: 20 6 | n_layer: 36 7 | -------------------------------------------------------------------------------- /training/configs/model/gpt2model/gpt2-medium.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | model: 3 | config: 4 | n_embd: 1024 5 | n_head: 16 6 | n_layer: 24 7 | -------------------------------------------------------------------------------- /training/configs/model/gpt2model/gpt2-small.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | model: 3 | config: 4 | n_embd: 768 5 | n_head: 12 6 | n_layer: 12 7 | -------------------------------------------------------------------------------- /training/configs/model/gpt2model/gpt2-xlarge.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | model: 3 | config: 4 | n_embd: 1600 5 | n_head: 25 6 | n_layer: 48 7 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | 2 | clean_dist: 3 | rm -rf dist/* 4 | 5 | create_dist: clean_dist 6 | python setup.py sdist 7 | 8 | upload_package: create_dist 9 | twine upload dist/* 10 | -------------------------------------------------------------------------------- /training/configs/callbacks/flop-count.yaml: -------------------------------------------------------------------------------- 1 | flop_count: 2 | _target_: src.callbacks.flop_count.FlopCount 3 | profilers: ['fvcore'] 4 | input_size: [3, 224, 224] 5 | device: null 6 | 
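The gpt2model configs a few entries above (gpt2-small through gpt2-xlarge) only set n_embd, n_head, and n_layer. As a quick sanity check on what those sizes imply, the sketch below estimates parameter counts with the standard 12·L·d² transformer approximation; the vocabulary size (50257) and position count (1024) are GPT-2's usual values, assumed here rather than taken from this listing, so treat the outputs as ballpark figures only.

```python
# Rough parameter-count estimate for the gpt2model configs above.
# VOCAB_SIZE and N_POSITIONS are GPT-2's usual values, assumed here
# (they are not part of the gpt2model YAML files).

CONFIGS = {
    "gpt2-small":  dict(n_embd=768,  n_head=12, n_layer=12),
    "gpt2-medium": dict(n_embd=1024, n_head=16, n_layer=24),
    "gpt2-large":  dict(n_embd=1280, n_head=20, n_layer=36),
    "gpt2-xlarge": dict(n_embd=1600, n_head=25, n_layer=48),
}

VOCAB_SIZE = 50257   # assumed
N_POSITIONS = 1024   # assumed


def approx_params(n_embd: int, n_layer: int, **_) -> int:
    """~12 * n_layer * n_embd**2 for the attention/MLP blocks, plus embeddings."""
    blocks = 12 * n_layer * n_embd**2
    embeddings = (VOCAB_SIZE + N_POSITIONS) * n_embd
    return blocks + embeddings


for name, cfg in CONFIGS.items():
    print(f"{name:12s} ~{approx_params(**cfg) / 1e6:,.0f}M parameters")
# Ballpark: ~124M, ~355M, ~773M, and ~1.56B respectively, i.e. the usual
# GPT-2 small/medium/large/xl sizes.
```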
-------------------------------------------------------------------------------- /training/configs/optimizer/adamw-apex-distributed.yaml: -------------------------------------------------------------------------------- 1 | # @package train.optimizer 2 | _target_: apex.contrib.optimizers.distributed_fused_adam.DistributedFusedAdam 3 | adam_w_mode: True 4 | -------------------------------------------------------------------------------- /training/configs/callbacks/params-log.yaml: -------------------------------------------------------------------------------- 1 | params_log: 2 | _target_: src.callbacks.params_log.ParamsLog 3 | total_params_log: True 4 | trainable_params_log: True 5 | non_trainable_params_log: True 6 | -------------------------------------------------------------------------------- /training/configs/trainer/default.yaml: -------------------------------------------------------------------------------- 1 | _target_: pytorch_lightning.Trainer 2 | 3 | # set `gpu` to train on GPU, null to train on CPU only 4 | accelerator: null 5 | 6 | min_epochs: 1 7 | max_epochs: 1000 8 | -------------------------------------------------------------------------------- /training/configs/experiment/owt/gpt2xl-hf.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - /experiment/owt/gpt2l-hf.yaml 4 | - override /model/gpt2model: gpt2-xlarge 5 | 6 | datamodule: 7 | batch_size: 1 8 | -------------------------------------------------------------------------------- /training/configs/logger/csv.yaml: -------------------------------------------------------------------------------- 1 | # csv logger built in lightning 2 | 3 | csv: 4 | _target_: pytorch_lightning.loggers.csv_logs.CSVLogger 5 | save_dir: "." 
6 | name: "csv/" 7 | version: ${name} 8 | prefix: "" 9 | -------------------------------------------------------------------------------- /training/configs/logger/many_loggers.yaml: -------------------------------------------------------------------------------- 1 | # train with many loggers at once 2 | 3 | defaults: 4 | # - comet.yaml 5 | - csv.yaml 6 | # - mlflow.yaml 7 | # - neptune.yaml 8 | # - tensorboard.yaml 9 | - wandb.yaml 10 | -------------------------------------------------------------------------------- /training/configs/optimizer/adamw-zero.yaml: -------------------------------------------------------------------------------- 1 | # @package train.optimizer 2 | _target_: torch.distributed.optim.ZeroRedundancyOptimizer 3 | _recursive_: True 4 | optimizer_class: 5 | _target_: torch.optim.__getattribute__ 6 | _args_: 7 | - "AdamW" 8 | -------------------------------------------------------------------------------- /training/configs/experiment/owt/gpt2m-hf.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - /experiment/owt/gpt2s-hf.yaml 4 | - override /model/gpt2model: gpt2-medium 5 | 6 | datamodule: 7 | batch_size: 4 8 | 9 | train: 10 | optimizer: 11 | lr: 1.5e-4 12 | -------------------------------------------------------------------------------- /training/configs/experiment/pile/gpt3l-flash-rotary-30B.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - /experiment/pile/gpt3l-flash-rotary.yaml 4 | 5 | trainer: 6 | max_steps: 60000 7 | 8 | train: 9 | scheduler: 10 | t_initial: ${trainer.max_steps} 11 | -------------------------------------------------------------------------------- /training/configs/experiment/pile/gpt3l-flash-rotary.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - /experiment/pile/gpt3l-flash.yaml 4 | 5 | model: 6 | config: 7 | max_position_embeddings: 0 # Disable absolute position embedding 8 | rotary_emb_fraction: 0.5 9 | -------------------------------------------------------------------------------- /training/configs/experiment/pile/gpt3m-flash-rotary-30B.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - /experiment/pile/gpt3m-flash-rotary.yaml 4 | 5 | trainer: 6 | max_steps: 60000 7 | 8 | train: 9 | scheduler: 10 | t_initial: ${trainer.max_steps} 11 | -------------------------------------------------------------------------------- /training/configs/experiment/pile/gpt3m-flash-rotary.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - /experiment/pile/gpt3m-flash.yaml 4 | 5 | model: 6 | config: 7 | max_position_embeddings: 0 # Disable absolute position embedding 8 | rotary_emb_fraction: 0.5 9 | -------------------------------------------------------------------------------- /training/configs/experiment/pile/gpt3s-flash-rotary-30B.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - /experiment/pile/gpt3s-flash-rotary.yaml 4 | 5 | trainer: 6 | max_steps: 60000 7 | 8 | train: 9 | scheduler: 10 | t_initial: ${trainer.max_steps} 11 | -------------------------------------------------------------------------------- /training/configs/experiment/pile/gpt3s-flash-rotary.yaml: 
-------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - /experiment/pile/gpt3s-flash.yaml 4 | 5 | model: 6 | config: 7 | max_position_embeddings: 0 # Disable absolute position embedding 8 | rotary_emb_fraction: 0.5 9 | -------------------------------------------------------------------------------- /training/configs/optimizer/adamw-apex-zero.yaml: -------------------------------------------------------------------------------- 1 | # @package train.optimizer 2 | _target_: torch.distributed.optim.ZeroRedundancyOptimizer 3 | _recursive_: True 4 | optimizer_class: 5 | _target_: apex.optimizers.FusedAdam 6 | _partial_: True 7 | adam_w_mode: True 8 | -------------------------------------------------------------------------------- /training/configs/experiment/owt/gpt2m.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - /experiment/owt/gpt2s.yaml 4 | - override /model/gpt2model: gpt2-medium 5 | 6 | datamodule: 7 | batch_size: 8 # Per GPU 8 | 9 | train: 10 | optimizer: 11 | lr: 1.5e-4 12 | -------------------------------------------------------------------------------- /training/configs/experiment/pile/gpt3m-hf.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - /experiment/pile/gpt3s-hf.yaml 4 | - override /model/gpt2model: gpt2-medium 5 | 6 | datamodule: 7 | batch_size: 4 8 | 9 | train: 10 | optimizer: 11 | lr: 3.0e-4 12 | -------------------------------------------------------------------------------- /training/configs/experiment/pile/gpt3xl-flash-rotary-60B.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - /experiment/pile/gpt3xl-flash-rotary.yaml 4 | 5 | trainer: 6 | max_steps: 60000 7 | 8 | train: 9 | scheduler: 10 | t_initial: ${trainer.max_steps} 11 | -------------------------------------------------------------------------------- /training/configs/experiment/pile/gpt3xl-flash-rotary.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - /experiment/pile/gpt3xl-flash.yaml 4 | 5 | model: 6 | config: 7 | max_position_embeddings: 0 # Disable absolute position embedding 8 | rotary_emb_fraction: 0.5 9 | -------------------------------------------------------------------------------- /training/configs/experiment/pile/gpt3l-flash-rotary-8k.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - /experiment/pile/gpt3l-flash-8k.yaml 4 | 5 | model: 6 | config: 7 | max_position_embeddings: 0 # Disable absolute position embedding 8 | rotary_emb_fraction: 0.5 9 | -------------------------------------------------------------------------------- /training/configs/experiment/pile/gpt3m-flash-rotary-8k.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - /experiment/pile/gpt3m-flash-8k.yaml 4 | 5 | model: 6 | config: 7 | max_position_embeddings: 0 # Disable absolute position embedding 8 | rotary_emb_fraction: 0.5 9 | -------------------------------------------------------------------------------- /training/configs/experiment/pile/gpt3s-flash-rotary-8k.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - 
/experiment/pile/gpt3s-flash-8k.yaml 4 | 5 | model: 6 | config: 7 | max_position_embeddings: 0 # Disable absolute position embedding 8 | rotary_emb_fraction: 0.5 9 | -------------------------------------------------------------------------------- /training/configs/experiment/pile/gpt3xl-flash-rotary-8k.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - /experiment/pile/gpt3xl-flash-8k.yaml 4 | 5 | model: 6 | config: 7 | max_position_embeddings: 0 # Disable absolute position embedding 8 | rotary_emb_fraction: 0.5 9 | -------------------------------------------------------------------------------- /training/configs/logger/comet.yaml: -------------------------------------------------------------------------------- 1 | # https://www.comet.ml 2 | 3 | comet: 4 | _target_: pytorch_lightning.loggers.comet.CometLogger 5 | api_key: ${oc.env:COMET_API_TOKEN} # api key is loaded from environment variable 6 | project_name: "template-tests" 7 | experiment_name: ${name} 8 | -------------------------------------------------------------------------------- /training/configs/logger/mlflow.yaml: -------------------------------------------------------------------------------- 1 | # https://mlflow.org 2 | 3 | mlflow: 4 | _target_: pytorch_lightning.loggers.mlflow.MLFlowLogger 5 | experiment_name: ${name} 6 | tracking_uri: null 7 | tags: null 8 | save_dir: ./mlruns 9 | prefix: "" 10 | artifact_location: null 11 | -------------------------------------------------------------------------------- /csrc/flash_attn_with_bias_and_mask/src/cuda_utils.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include "cuda_runtime.h" 5 | #include "fmha_utils.h" 6 | 7 | static int GetCurrentDeviceId(); 8 | 9 | static int GetCudaDeviceCount(); 10 | 11 | cudaDeviceProp* GetDeviceProperties(int id); 12 | -------------------------------------------------------------------------------- /training/configs/experiment/owt/gpt2s.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - /experiment/owt/base.yaml 4 | - override /model: gpt2 5 | - override /model/gpt2model: gpt2-small 6 | 7 | datamodule: 8 | batch_size: ${eval:"4 if ${train.gpu_mem} < 24 else (8 if ${train.gpu_mem} < 40 else 16)"} 9 | -------------------------------------------------------------------------------- /csrc/flash_attn/src/calc_reduced_attn_scores_dispatch/hdim128_fp16_sm80.cu: -------------------------------------------------------------------------------- 1 | #include "launch_template.h" 2 | 3 | namespace reduced_scores { 4 | template<> 5 | void run_(Params ¶ms, cudaStream_t stream) { 6 | run_hdim128(params, stream); 7 | } 8 | } // namespace reduced_scores -------------------------------------------------------------------------------- /csrc/flash_attn/src/calc_reduced_attn_scores_dispatch/hdim160_fp16_sm80.cu: -------------------------------------------------------------------------------- 1 | #include "launch_template.h" 2 | 3 | namespace reduced_scores { 4 | template<> 5 | void run_(Params ¶ms, cudaStream_t stream) { 6 | run_hdim160(params, stream); 7 | } 8 | } // namespace reduced_scores -------------------------------------------------------------------------------- /csrc/flash_attn/src/calc_reduced_attn_scores_dispatch/hdim192_fp16_sm80.cu: -------------------------------------------------------------------------------- 1 | #include 
"launch_template.h" 2 | 3 | namespace reduced_scores { 4 | template<> 5 | void run_(Params ¶ms, cudaStream_t stream) { 6 | run_hdim192(params, stream); 7 | } 8 | } // namespace reduced_scores -------------------------------------------------------------------------------- /csrc/flash_attn/src/calc_reduced_attn_scores_dispatch/hdim224_fp16_sm80.cu: -------------------------------------------------------------------------------- 1 | #include "launch_template.h" 2 | 3 | namespace reduced_scores { 4 | template<> 5 | void run_(Params ¶ms, cudaStream_t stream) { 6 | run_hdim224(params, stream); 7 | } 8 | } // namespace reduced_scores -------------------------------------------------------------------------------- /csrc/flash_attn/src/calc_reduced_attn_scores_dispatch/hdim256_fp16_sm80.cu: -------------------------------------------------------------------------------- 1 | #include "launch_template.h" 2 | 3 | namespace reduced_scores { 4 | template<> 5 | void run_(Params ¶ms, cudaStream_t stream) { 6 | run_hdim256(params, stream); 7 | } 8 | } // namespace reduced_scores -------------------------------------------------------------------------------- /csrc/flash_attn/src/calc_reduced_attn_scores_dispatch/hdim32_fp16_sm80.cu: -------------------------------------------------------------------------------- 1 | #include "launch_template.h" 2 | 3 | namespace reduced_scores { 4 | template<> 5 | void run_(Params ¶ms, cudaStream_t stream) { 6 | run_hdim32(params, stream); 7 | } 8 | } // namespace reduced_scores -------------------------------------------------------------------------------- /csrc/flash_attn/src/calc_reduced_attn_scores_dispatch/hdim64_fp16_sm80.cu: -------------------------------------------------------------------------------- 1 | #include "launch_template.h" 2 | 3 | namespace reduced_scores { 4 | template<> 5 | void run_(Params ¶ms, cudaStream_t stream) { 6 | run_hdim64(params, stream); 7 | } 8 | } // namespace reduced_scores -------------------------------------------------------------------------------- /csrc/flash_attn/src/calc_reduced_attn_scores_dispatch/hdim96_fp16_sm80.cu: -------------------------------------------------------------------------------- 1 | #include "launch_template.h" 2 | 3 | namespace reduced_scores { 4 | template<> 5 | void run_(Params ¶ms, cudaStream_t stream) { 6 | run_hdim96(params, stream); 7 | } 8 | } // namespace reduced_scores -------------------------------------------------------------------------------- /training/configs/experiment/pile/gpt3l-hf.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - /experiment/pile/gpt3s-hf.yaml 4 | 5 | model: 6 | config: 7 | n_embd: 1536 8 | n_head: 16 9 | n_layer: 24 10 | 11 | datamodule: 12 | batch_size: 2 13 | 14 | train: 15 | optimizer: 16 | lr: 2.5e-4 17 | -------------------------------------------------------------------------------- /csrc/flash_attn/src/calc_reduced_attn_scores_dispatch/hdim32_bf16_sm80.cu: -------------------------------------------------------------------------------- 1 | #include "launch_template.h" 2 | 3 | namespace reduced_scores { 4 | template<> 5 | void run_(Params ¶ms, cudaStream_t stream) { 6 | run_hdim32(params, stream); 7 | } 8 | } // namespace reduced_scores -------------------------------------------------------------------------------- /csrc/flash_attn/src/calc_reduced_attn_scores_dispatch/hdim64_bf16_sm80.cu: -------------------------------------------------------------------------------- 1 | #include 
"launch_template.h" 2 | 3 | namespace reduced_scores { 4 | template<> 5 | void run_(Params ¶ms, cudaStream_t stream) { 6 | run_hdim64(params, stream); 7 | } 8 | } // namespace reduced_scores -------------------------------------------------------------------------------- /csrc/flash_attn/src/calc_reduced_attn_scores_dispatch/hdim96_bf16_sm80.cu: -------------------------------------------------------------------------------- 1 | #include "launch_template.h" 2 | 3 | namespace reduced_scores { 4 | template<> 5 | void run_(Params ¶ms, cudaStream_t stream) { 6 | run_hdim96(params, stream); 7 | } 8 | } // namespace reduced_scores -------------------------------------------------------------------------------- /training/configs/experiment/pile/gpt3-2.7B-hf.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - /experiment/pile/gpt3xl-hf.yaml 4 | 5 | model: 6 | config: 7 | n_embd: 2560 8 | n_head: 32 9 | n_layer: 32 10 | 11 | datamodule: 12 | batch_size: 1 13 | 14 | train: 15 | optimizer: 16 | lr: 1.6e-4 17 | -------------------------------------------------------------------------------- /training/configs/experiment/pile/gpt3l-flash-8k.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - /experiment/pile/gpt3l-flash.yaml 4 | 5 | datamodule: 6 | max_length: 8192 7 | batch_size: ${eval:"1 if ${train.gpu_mem} < 40 else (2 if ${train.gpu_mem} < 80 else 4)"} 8 | 9 | train: 10 | global_batch_size: 64 11 | -------------------------------------------------------------------------------- /training/configs/experiment/pile/gpt3m-flash-8k.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - /experiment/pile/gpt3m-flash.yaml 4 | 5 | datamodule: 6 | max_length: 8192 7 | batch_size: ${eval:"2 if ${train.gpu_mem} < 24 else (4 if ${train.gpu_mem} < 40 else 8)"} 8 | 9 | train: 10 | global_batch_size: 64 11 | -------------------------------------------------------------------------------- /training/configs/experiment/pile/gpt3s-flash-8k.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - /experiment/pile/gpt3s-flash.yaml 4 | 5 | datamodule: 6 | max_length: 8192 7 | batch_size: ${eval:"2 if ${train.gpu_mem} < 24 else (4 if ${train.gpu_mem} < 40 else 8)"} 8 | 9 | train: 10 | global_batch_size: 64 11 | -------------------------------------------------------------------------------- /csrc/flash_attn/src/calc_reduced_attn_scores_dispatch/hdim128_bf16_sm80.cu: -------------------------------------------------------------------------------- 1 | #include "launch_template.h" 2 | 3 | namespace reduced_scores { 4 | template<> 5 | void run_(Params ¶ms, cudaStream_t stream) { 6 | run_hdim128(params, stream); 7 | } 8 | } // namespace reduced_scores -------------------------------------------------------------------------------- /csrc/flash_attn/src/calc_reduced_attn_scores_dispatch/hdim160_bf16_sm80.cu: -------------------------------------------------------------------------------- 1 | #include "launch_template.h" 2 | 3 | namespace reduced_scores { 4 | template<> 5 | void run_(Params ¶ms, cudaStream_t stream) { 6 | run_hdim160(params, stream); 7 | } 8 | } // namespace reduced_scores -------------------------------------------------------------------------------- 
/csrc/flash_attn/src/calc_reduced_attn_scores_dispatch/hdim192_bf16_sm80.cu: -------------------------------------------------------------------------------- 1 | #include "launch_template.h" 2 | 3 | namespace reduced_scores { 4 | template<> 5 | void run_(Params ¶ms, cudaStream_t stream) { 6 | run_hdim192(params, stream); 7 | } 8 | } // namespace reduced_scores -------------------------------------------------------------------------------- /csrc/flash_attn/src/calc_reduced_attn_scores_dispatch/hdim224_bf16_sm80.cu: -------------------------------------------------------------------------------- 1 | #include "launch_template.h" 2 | 3 | namespace reduced_scores { 4 | template<> 5 | void run_(Params ¶ms, cudaStream_t stream) { 6 | run_hdim224(params, stream); 7 | } 8 | } // namespace reduced_scores -------------------------------------------------------------------------------- /csrc/flash_attn/src/calc_reduced_attn_scores_dispatch/hdim256_bf16_sm80.cu: -------------------------------------------------------------------------------- 1 | #include "launch_template.h" 2 | 3 | namespace reduced_scores { 4 | template<> 5 | void run_(Params ¶ms, cudaStream_t stream) { 6 | run_hdim256(params, stream); 7 | } 8 | } // namespace reduced_scores -------------------------------------------------------------------------------- /training/configs/experiment/pile/gpt3xl-flash-8k.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - /experiment/pile/gpt3xl-flash.yaml 4 | 5 | datamodule: 6 | max_length: 8192 7 | batch_size: ${eval:"1 if ${train.gpu_mem} < 40 else (2 if ${train.gpu_mem} < 80 else 4)"} 8 | 9 | train: 10 | global_batch_size: 128 11 | -------------------------------------------------------------------------------- /training/configs/logger/tensorboard.yaml: -------------------------------------------------------------------------------- 1 | # https://www.tensorflow.org/tensorboard/ 2 | 3 | tensorboard: 4 | _target_: pytorch_lightning.loggers.tensorboard.TensorBoardLogger 5 | save_dir: "tensorboard/" 6 | name: "default" 7 | version: ${name} 8 | log_graph: False 9 | default_hp_metric: True 10 | prefix: "" 11 | -------------------------------------------------------------------------------- /training/configs/experiment/pile/gpt3s-hf.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - /experiment/pile/base.yaml 4 | - override /model: gpt2-hf 5 | - override /model/gpt2model: gpt2-small 6 | 7 | datamodule: 8 | batch_size: 8 9 | 10 | train: 11 | # Use the standard torch.nn.CrossEntropyLoss 12 | loss_fn: null 13 | -------------------------------------------------------------------------------- /training/configs/scheduler/plateau.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | train: 3 | scheduler_interval: epoch 4 | scheduler_monitor: ??? 
5 | scheduler: 6 | _target_: torch.optim.lr_scheduler.ReduceLROnPlateau 7 | factor: 0.2 # Decay factor when ReduceLROnPlateau is used 8 | patience: 20 9 | min_lr: 0.0 # Minimum learning rate during annealing 10 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | recursive-include csrc *.cu 2 | recursive-include csrc *.h 3 | recursive-include csrc *.cuh 4 | recursive-include csrc *.cpp 5 | recursive-include csrc *.hpp 6 | 7 | recursive-include flash_attn *.cu 8 | recursive-include flash_attn *.h 9 | recursive-include flash_attn *.cuh 10 | recursive-include flash_attn *.cpp 11 | recursive-include flash_attn *.hpp 12 | -------------------------------------------------------------------------------- /csrc/xentropy/README.md: -------------------------------------------------------------------------------- 1 | This CUDA extension implements optimized cross-entropy loss, adapted from Apex's 2 | [Xentropy](https://github.com/NVIDIA/apex/tree/master/apex/contrib/xentropy). 3 | We make it work for bfloat16 and support in-place backward to save memory. 4 | 5 | It has only been tested on A100s. 6 | 7 | ```sh 8 | cd csrc/xentropy && pip install . 9 | ``` 10 | -------------------------------------------------------------------------------- /training/configs/experiment/pile/gpt3-2.7B-hf-hdim128.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - /experiment/pile/gpt3xl-hf.yaml 4 | 5 | model: 6 | config: 7 | n_embd: 2560 8 | n_head: 128 9 | n_layer: 32 10 | 11 | # OOM on A100 80GB even with batch_size = 1 12 | datamodule: 13 | batch_size: 1 14 | 15 | train: 16 | optimizer: 17 | lr: 1.6e-4 18 | -------------------------------------------------------------------------------- /training/configs/logger/neptune.yaml: -------------------------------------------------------------------------------- 1 | # https://neptune.ai 2 | 3 | neptune: 4 | _target_: pytorch_lightning.loggers.neptune.NeptuneLogger 5 | api_key: ${oc.env:NEPTUNE_API_TOKEN} # api key is loaded from environment variable 6 | project_name: your_name/template-tests 7 | close_after_fit: True 8 | offline_mode: False 9 | experiment_name: ${name} 10 | experiment_id: null 11 | prefix: "" 12 | -------------------------------------------------------------------------------- /training/configs/experiment/owt/gpt2l-hf.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - /experiment/owt/gpt2m-hf.yaml 4 | - override /model/gpt2model: gpt2-large 5 | - override /optimizer: adamw-zero 6 | 7 | datamodule: 8 | batch_size: 2 9 | 10 | trainer: 11 | strategy: 12 | _target_: src.utils.ddp_zero1.DDPStrategyZero1 13 | find_unused_parameters: False 14 | gradient_as_bucket_view: True 15 | -------------------------------------------------------------------------------- /training/configs/mode/default.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # default running mode 4 | 5 | default_mode: True 6 | 7 | hydra: 8 | # default output paths for all file logs 9 | run: 10 | dir: ${oc.env:RESULT_DIR,${work_dir}/logs}/runs/${now:%Y-%m-%d}/${now:%H-%M-%S} 11 | sweep: 12 | dir: ${oc.env:RESULT_DIR,${work_dir}/logs}/multiruns/${now:%Y-%m-%d_%H-%M-%S} 13 | subdir: ${hydra.job.num} 14 | 
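Several experiment configs above (for example the gpt3l/m/s-flash-8k.yaml files) pick their per-GPU batch_size with `${eval:"..."}` interpolations keyed on train.gpu_mem, and thepile.yaml further down computes `__train_len` with `${div_up:...}`. These are custom OmegaConf resolvers; the code that registers them lives in the training sources, which are not part of this listing, so the snippet below is only a sketch of how such resolvers are typically wired up, with illustrative values.

```python
# Sketch: backing the ${eval:"..."} and ${div_up:...} interpolations used in
# these configs with custom OmegaConf resolvers. The real registration lives
# in the training code (not shown in this listing); names and values below
# are illustrative.
from omegaconf import OmegaConf

OmegaConf.register_new_resolver("eval", eval)
OmegaConf.register_new_resolver("div_up", lambda x, y: (x + y - 1) // y)

cfg = OmegaConf.create(
    {
        "train": {"gpu_mem": 80},  # assumed to be set elsewhere in the real config
        "datamodule": {
            "max_length": 8192,
            # Same pattern as gpt3l-flash-8k.yaml: choose batch size by GPU memory.
            "batch_size": '${eval:"1 if ${train.gpu_mem} < 40 else (2 if ${train.gpu_mem} < 80 else 4)"}',
            # Same pattern as thepile.yaml: number of max_length chunks in the corpus.
            "__train_len": "${div_up:374337375694, ${.max_length}}",
        },
    }
)

print(cfg.datamodule.batch_size)      # resolves to 4 for an 80GB GPU under this rule
print(cfg.datamodule["__train_len"])  # number of 8192-token sequences in ~374B tokens
```

Resolution happens lazily on access, which is why these interpolations can reference values such as train.gpu_mem that are only filled in when the full config is composed.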
--------------------------------------------------------------------------------
/training/configs/experiment/owt/gpt2l.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 | defaults:
3 |   - /experiment/owt/gpt2m.yaml
4 |   - override /model/gpt2model: gpt2-large
5 |   - override /optimizer: adamw-zero
6 | 
7 | datamodule:
8 |   batch_size: 4 # Per GPU
9 | 
10 | trainer:
11 |   strategy:
12 |     _target_: src.utils.ddp_zero1.DDPStrategyZero1
13 |     find_unused_parameters: False
14 |     gradient_as_bucket_view: True
15 | 
--------------------------------------------------------------------------------
/training/configs/experiment/owt/gpt2xl.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 | defaults:
3 |   - /experiment/owt/gpt2m.yaml
4 |   - override /model/gpt2model: gpt2-xlarge
5 |   - override /optimizer: adamw-zero
6 | 
7 | datamodule:
8 |   batch_size: 2 # Per GPU
9 | 
10 | trainer:
11 |   strategy:
12 |     _target_: src.utils.ddp_zero1.DDPStrategyZero1
13 |     find_unused_parameters: False
14 |     gradient_as_bucket_view: True
15 | 
--------------------------------------------------------------------------------
/training/configs/trainer/debug.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 |   - default.yaml
3 | 
4 | gpus: 0
5 | 
6 | min_epochs: 1
7 | max_epochs: 2
8 | 
9 | # prints
10 | weights_summary: "full"
11 | profiler: null
12 | 
13 | # debugs
14 | fast_dev_run: true
15 | num_sanity_val_steps: 2
16 | overfit_batches: 0
17 | limit_train_batches: 1.0
18 | limit_val_batches: 1.0
19 | limit_test_batches: 1.0
20 | track_grad_norm: -1
21 | terminate_on_nan: true
22 | 
--------------------------------------------------------------------------------
/csrc/ft_attention/README.md:
--------------------------------------------------------------------------------
1 | # Attention kernel from FasterTransformer
2 | 
3 | This CUDA extension wraps the single-query attention [kernel](https://github.com/NVIDIA/FasterTransformer/blob/release/v5.2.1_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp) from
4 | FasterTransformer v5.2.1 for benchmarking purposes.
5 | 
6 | ```sh
7 | cd csrc/ft_attention && pip install .
8 | ```
9 | 
--------------------------------------------------------------------------------
/training/src/metrics/accuracy.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import Tensor
3 | 
4 | from torchmetrics import Metric, Accuracy
5 | 
6 | 
7 | class AccuracyMine(Accuracy):
8 |     """Wrap torchmetrics.Accuracy to take argmax of y in case of Mixup.
9 |     """
10 |     def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore
11 |         super().update(preds, target.argmax(dim=-1) if target.is_floating_point() else target)
12 | 
--------------------------------------------------------------------------------
/training/configs/callbacks/gpu-monitor.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 |   - default.yaml
3 | 
4 | gpu_stats_monitor:
5 |   _target_: pytorch_lightning.callbacks.GPUStatsMonitor
6 |   # [2021-08-13] TD: I just want the intra_step_size but it'll error if I
7 |   # don't have memory_utilization and gpu_utilization.
8 |   # Maybe I should write a callback with just the intra_step_size.
9 |   memory_utilization: True
10 |   gpu_utilization: True
11 |   intra_step_time: True
12 | 
--------------------------------------------------------------------------------
/flash_attn/__init__.py:
--------------------------------------------------------------------------------
1 | __version__ = "2.0.8"
2 | 
3 | from flash_attn.flash_attn_interface import flash_attn_func
4 | from flash_attn.flash_attn_interface import flash_attn_kvpacked_func
5 | from flash_attn.flash_attn_interface import flash_attn_qkvpacked_func
6 | from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func
7 | from flash_attn.flash_attn_interface import flash_attn_varlen_kvpacked_func
8 | from flash_attn.flash_attn_interface import flash_attn_varlen_func
9 | 
--------------------------------------------------------------------------------
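flash_attn/__init__.py above pins the package version and re-exports the public attention functions from flash_attn.flash_attn_interface. That interface module is not included in this listing, so the snippet below follows the commonly documented calling convention for the 2.x series (packed QKV of shape (batch, seqlen, 3, nheads, headdim), fp16/bf16 tensors on a CUDA device) and should be read as an illustrative sketch rather than authoritative documentation for this fork.

```python
# Sketch: calling the packed-QKV entry point re-exported above.
# Requires a CUDA GPU and fp16/bf16 tensors; the argument names follow the
# commonly documented flash-attn 2.x signature and may differ in this fork.
import torch

from flash_attn import flash_attn_qkvpacked_func

batch, seqlen, nheads, headdim = 2, 1024, 12, 64
qkv = torch.randn(
    batch, seqlen, 3, nheads, headdim,
    device="cuda", dtype=torch.float16, requires_grad=True,
)

# Causal self-attention over the packed QKV tensor.
out = flash_attn_qkvpacked_func(qkv, dropout_p=0.0, causal=True)
print(out.shape)      # torch.Size([2, 1024, 12, 64])

out.sum().backward()  # the fused kernel also provides the backward pass
```

The varlen variants re-exported alongside it follow the same idea but, in the upstream 2.x API, operate on sequences flattened across the batch plus cumulative sequence-length tensors.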
/training/configs/experiment/pile/gpt3-2.7B-flash.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 | defaults:
3 |   - /experiment/pile/gpt3xl-flash.yaml
4 | 
5 | model:
6 |   config:
7 |     n_embd: 2560
8 |     n_head: 32
9 |     n_layer: 32
10 |     initializer_range: ${eval:"(2 / (${.n_embd} * 5)) ** 0.5"}
11 |     mlp_checkpoint_lvl: 0
12 | 
13 | datamodule:
14 |   batch_size: ${eval:"1 if ${train.gpu_mem} < 40 else (2 if ${train.gpu_mem} < 80 else 4)"}
15 | 
16 | train:
17 |   optimizer:
18 |     lr: 1.6e-4
19 | 
--------------------------------------------------------------------------------
/training/configs/logger/wandb.yaml:
--------------------------------------------------------------------------------
1 | # https://wandb.ai
2 | 
3 | wandb:
4 |   _target_: pytorch_lightning.loggers.wandb.WandbLogger
5 |   project: attention
6 |   name: ${name}
7 |   save_dir: "."
8 |   mode: online # set offline to store all logs only locally
9 |   id: ${oc.select:name} # pass correct id to resume experiment!
10 |   # entity: "" # set to name of your wandb team or just remove it
11 |   log_model: False
12 |   prefix: ""
13 |   job_type: "train"
14 |   group: ""
15 |   tags: []
16 | 
--------------------------------------------------------------------------------
/training/configs/experiment/pile/gpt3-2.7B-flash-8k.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 | defaults:
3 |   - /experiment/pile/gpt3xl-flash-8k.yaml
4 | 
5 | model:
6 |   config:
7 |     n_embd: 2560
8 |     n_head: 32
9 |     n_layer: 32
10 |     initializer_range: ${eval:"(2 / (${.n_embd} * 5)) ** 0.5"}
11 |     mlp_checkpoint_lvl: 0
12 | 
13 | datamodule:
14 |   batch_size: ${eval:"1 if ${train.gpu_mem} < 40 else (2 if ${train.gpu_mem} < 80 else 4)"}
15 | 
16 | train:
17 |   optimizer:
18 |     lr: 1.6e-4
19 | 
--------------------------------------------------------------------------------
/training/configs/model/gpt2.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 |   - _self_
3 |   - gpt2model: gpt2-small
4 | 
5 | _target_: flash_attn.models.gpt.GPTLMHeadModel
6 | _recursive_: True
7 | config:
8 |   _target_: transformers.GPT2Config
9 |   # Mistral's config: # https://github.com/stanford-crfm/mistral/blob/main/conf/models/mistral-small.yaml
10 |   # However, reorder_and_upcast_attn slows things down
11 |   reorder_and_upcast_attn: false
12 |   scale_attn_by_inverse_layer_idx: true
13 |   n_positions: ${datamodule.max_length}
14 | 
--------------------------------------------------------------------------------
/training/configs/experiment/pile/gpt3m-flash.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 | defaults:
3 |   - /experiment/pile/gpt3s-flash.yaml
4 |   - override /model/gpt2model: gpt2-medium
5 | 
6 | # Can enable mlp_checkpoint_lvl to fit batch_size 16 to A100 40GB
7 | # model:
8 | #   config:
9 | #     mlp_checkpoint_lvl: 1
10 | 
11 | datamodule:
12 |   batch_size: ${eval:"4 if ${train.gpu_mem} < 24 else (8 if ${train.gpu_mem} < 40 else (16 if ${train.gpu_mem} < 80 else 32))"}
13 | 
14 | train:
15 |   optimizer:
16 |     lr: 3.0e-4
17 | 
--------------------------------------------------------------------------------
/training/configs/datamodule/thepile.yaml:
--------------------------------------------------------------------------------
1 | _target_: src.datamodules.language_modeling_hf.LMDataModule
2 | dataset_name: the_pile
3 | dataset_config_name: null
4 | tokenizer_name: gpt2
5 | cache_dir: ${oc.env:DATA_DIR,${data_dir}}/the_pile/cache
6 | max_length: 2048
7 | add_eos: True
8 | batch_size: 4 # per GPU
9 | batch_size_eval: ${eval:${.batch_size} * 2}
10 | num_workers: 64 # For preprocessing only
11 | use_shmem: False
12 | shuffle: True
13 | pin_memory: True
14 | __train_len: ${div_up:374337375694, ${.max_length}}
15 | 
--------------------------------------------------------------------------------
/training/configs/mode/exp.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 | 
3 | # run in experiment mode with:
4 | # `python run.py mode=exp name=experiment_name`
5 | 
6 | experiment_mode: True
7 | 
8 | # allows for custom naming of the experiment
9 | name: ???
10 | 11 | hydra: 12 | # sets output paths for all file logs to `logs/experiment/name' 13 | run: 14 | dir: ${oc.env:RESULT_DIR,${work_dir}/logs}/experiments/${name} 15 | sweep: 16 | dir: ${oc.env:RESULT_DIR,${work_dir}/logs}/experiments/${name} 17 | subdir: ${hydra.job.num} 18 | -------------------------------------------------------------------------------- /training/configs/model/gpt2-hf.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _self_ 3 | - gpt2model: gpt2-small 4 | 5 | _target_: transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel 6 | _recursive_: True 7 | config: 8 | _target_: transformers.GPT2Config 9 | # Mistral's config: https://github.com/stanford-crfm/mistral/blob/main/conf/models/gpt2-small.yaml 10 | # However, reorder_and_upcast_attn slows things down 11 | reorder_and_upcast_attn: false 12 | scale_attn_by_inverse_layer_idx: true 13 | n_positions: ${datamodule.max_length} 14 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "csrc/cutlass"] 2 | path = csrc/cutlass 3 | url = https://github.com/NVIDIA/cutlass.git 4 | [submodule "csrc/flash_attn_with_bias_and_mask/cutlass"] 5 | path = csrc/flash_attn_with_bias_and_mask/cutlass 6 | url = https://github.com/NVIDIA/cutlass.git 7 | [submodule "csrc/flash_attn_v3/cutlass"] 8 | path = csrc/flash_attn_v3/cutlass 9 | url = https://github.com/NVIDIA/cutlass.git 10 | [submodule "csrc/flashmask_v2/cutlass"] 11 | path = csrc/flashmask_v2/cutlass 12 | url = https://github.com/NVIDIA/cutlass.git 13 | -------------------------------------------------------------------------------- /training/configs/experiment/owt/gpt2m-flash.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - /experiment/owt/gpt2s-flash.yaml 4 | - override /model/gpt2model: gpt2-medium 5 | 6 | # Can enable mlp_checkpoint_lvl to fit batch_size 32 to A100 40GB 7 | # model: 8 | # config: 9 | # mlp_checkpoint_lvl: 1 10 | 11 | datamodule: 12 | # batch_size: 32 13 | batch_size: ${eval:"8 if ${train.gpu_mem} < 24 else (16 if ${train.gpu_mem} < 40 else (32 if ${train.gpu_mem} < 80 else 64))"} 14 | 15 | train: 16 | optimizer: 17 | lr: 1.5e-4 18 | -------------------------------------------------------------------------------- /training/configs/experiment/pile/gpt3-2.7B-flash-hdim128.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - /experiment/pile/gpt3xl-flash.yaml 4 | 5 | model: 6 | config: 7 | n_embd: 2560 8 | n_head: 20 # Headdim 128 is faster than headdim 80 9 | n_layer: 32 10 | initializer_range: ${eval:"(2 / (${.n_embd} * 5)) ** 0.5"} 11 | mlp_checkpoint_lvl: 0 12 | 13 | datamodule: 14 | batch_size: ${eval:"1 if ${train.gpu_mem} < 40 else (2 if ${train.gpu_mem} < 80 else 4)"} 15 | 16 | train: 17 | optimizer: 18 | lr: 1.6e-4 19 | -------------------------------------------------------------------------------- /csrc/flash_attn_with_bias_and_mask/src/fmha_fwd_hdim128.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2022, Tri Dao. 2 | 3 | // Splitting the different head dimensions to different files to speed up compilation. 
4 | 5 | #include "fmha_fwd_launch_template.h" 6 | 7 | void run_fmha_fwd_hdim128(Launch_params &launch_params) { 8 | FP16_SWITCH(launch_params.params.is_bf16, ([&] { 9 | using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 4, 0x08u, elem_type>; 10 | run_fmha_fwd_loop(launch_params); 11 | })); 12 | } -------------------------------------------------------------------------------- /training/configs/experiment/pile/gpt3-2.7B-flash-rotary-8k.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - /experiment/pile/gpt3xl-flash-rotary-8k.yaml 4 | 5 | model: 6 | config: 7 | n_embd: 2560 8 | n_head: 32 9 | n_layer: 32 10 | initializer_range: ${eval:"(2 / (${.n_embd} * 5)) ** 0.5"} 11 | mlp_checkpoint_lvl: 0 12 | 13 | datamodule: 14 | batch_size: ${eval:"1 if ${train.gpu_mem} < 24 else (2 if ${train.gpu_mem} < 40 else (4 if ${train.gpu_mem} < 80 else 8))"} 15 | 16 | train: 17 | optimizer: 18 | lr: 1.6e-4 19 | -------------------------------------------------------------------------------- /training/configs/experiment/pile/gpt3-2.7B-flash-rotary.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - /experiment/pile/gpt3xl-flash-rotary.yaml 4 | 5 | model: 6 | config: 7 | n_embd: 2560 8 | n_head: 32 9 | n_layer: 32 10 | initializer_range: ${eval:"(2 / (${.n_embd} * 5)) ** 0.5"} 11 | mlp_checkpoint_lvl: 0 12 | 13 | datamodule: 14 | batch_size: ${eval:"4 if ${train.gpu_mem} < 24 else (8 if ${train.gpu_mem} < 40 else (16 if ${train.gpu_mem} < 80 else 32))"} 15 | 16 | train: 17 | optimizer: 18 | lr: 1.6e-4 19 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | 5 | # C extensions 6 | *.so 7 | 8 | # Distribution / packaging 9 | bin/ 10 | build/ 11 | develop-eggs/ 12 | dist/ 13 | eggs/ 14 | lib/ 15 | lib64/ 16 | parts/ 17 | sdist/ 18 | var/ 19 | *.egg-info/ 20 | .installed.cfg 21 | *.egg 22 | 23 | # IDE-related 24 | .idea/ 25 | .vscode/ 26 | 27 | csrc/flash_attn/src/*_sm80.cu 28 | csrc/flash_attn_v3/*_sm90.cu 29 | csrc/flash_attn_v3/instantiations/ 30 | csrc/flashmask_v2/*_sm90.cu 31 | csrc/flashmask_v2/instantiations/ 32 | -------------------------------------------------------------------------------- /training/configs/experiment/pile/gpt3-2.7B-flash-hdim128-rotary.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - /experiment/pile/gpt3xl-flash-rotary.yaml 4 | 5 | model: 6 | config: 7 | n_embd: 2560 8 | n_head: 20 9 | n_layer: 32 10 | initializer_range: ${eval:"(2 / (${.n_embd} * 5)) ** 0.5"} 11 | mlp_checkpoint_lvl: 0 12 | 13 | datamodule: 14 | batch_size: ${eval:"4 if ${train.gpu_mem} < 24 else (8 if ${train.gpu_mem} < 40 else (16 if ${train.gpu_mem} < 80 else 32))"} 15 | 16 | train: 17 | optimizer: 18 | lr: 1.6e-4 19 | -------------------------------------------------------------------------------- /training/configs/experiment/pile/gpt3-2.7B-flash-hdim128-rotary-8k.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - /experiment/pile/gpt3xl-flash-rotary-8k.yaml 4 | 5 | model: 6 | config: 7 | n_embd: 2560 8 | n_head: 20 9 | n_layer: 32 10 | initializer_range: ${eval:"(2 / (${.n_embd} * 5)) ** 
0.5"} 11 | mlp_checkpoint_lvl: 0 12 | 13 | datamodule: 14 | batch_size: ${eval:"1 if ${train.gpu_mem} < 24 else (2 if ${train.gpu_mem} < 40 else (4 if ${train.gpu_mem} < 80 else 8))"} 15 | 16 | train: 17 | optimizer: 18 | lr: 1.6e-4 19 | -------------------------------------------------------------------------------- /csrc/flash_attn_with_bias_and_mask/src/fmha_bwd_with_mask_bias_hdim128.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2022, Tri Dao. 2 | 3 | #include "fmha_bwd_with_mask_bias_launch_template.h" 4 | 5 | bool run_fmha_bwd_with_mask_bias_hdim128(FMHA_dgrad_params ¶ms, cudaStream_t stream) { 6 | bool status = true; 7 | FP16_SWITCH(params.is_bf16, ([&] { 8 | using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 8, 0x100u, elem_type>; 9 | status = run_fmha_dgrad_fp16_sm80_loop_(params, stream); 10 | })); 11 | return status; 12 | } 13 | -------------------------------------------------------------------------------- /training/configs/datamodule/openwebtext.yaml: -------------------------------------------------------------------------------- 1 | _target_: src.datamodules.language_modeling_hf.LMDataModule 2 | dataset_name: openwebtext 3 | dataset_config_name: null 4 | tokenizer_name: gpt2 5 | cache_dir: ${oc.env:DATA_DIR,${data_dir}}/openwebtext/cache 6 | max_length: 1024 7 | val_ratio: 0.0005 8 | val_split_seed: 2357 9 | add_eos: True 10 | batch_size: 8 # per GPU 11 | batch_size_eval: ${eval:${.batch_size} * 2} 12 | num_workers: 32 # For preprocessing only 13 | shuffle: True 14 | pin_memory: True 15 | __train_len: ${div_up:9035582198, ${.max_length}} 16 | -------------------------------------------------------------------------------- /csrc/flash_attn_with_bias_and_mask/src/fmha_bwd_hdim128.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2022, Tri Dao. 2 | 3 | // Splitting the different head dimensions to different files to speed up compilation. 4 | 5 | #include "fmha_bwd_launch_template.h" 6 | 7 | void run_fmha_bwd_hdim128(FMHA_dgrad_params ¶ms, cudaStream_t stream, const bool configure) { 8 | FP16_SWITCH(params.is_bf16, ([&] { 9 | using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 8, 0x100u, elem_type>; 10 | run_fmha_bwd_loop(params, stream, configure); 11 | })); 12 | } -------------------------------------------------------------------------------- /csrc/fused_dense_lib/README.md: -------------------------------------------------------------------------------- 1 | This CUDA extension implements fused matmul + bias (forward and backward), and fused matmul + bias + gelu 2 | (forward and backward), adapted from Apex's 3 | [FusedDense](https://github.com/NVIDIA/apex/tree/master/apex/fused_dense). 4 | We make it work for bfloat16. 5 | 6 | For best performance, you should use CUDA >= 11.8. CuBLAS versions before 7 | this doesn't have the best matmul + bias + gelu performance for bfloat16. 8 | 9 | It has only been tested on A100s. 10 | 11 | ```sh 12 | cd csrc/fused_dense_lib && pip install . 
13 | ``` 14 | -------------------------------------------------------------------------------- /training/configs/experiment/pile/gpt3s-flash.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - /experiment/pile/base.yaml 4 | - override /model: gpt2 5 | - override /model/gpt2model: gpt2-small 6 | 7 | model: 8 | config: 9 | # n_positions is already set to ${datamodule.max_length} 10 | residual_in_fp32: True 11 | use_flash_attn: True 12 | fused_dropout_add_ln: True 13 | fused_mlp: True 14 | fused_bias_fc: True 15 | pad_vocab_size_multiple: 8 16 | 17 | datamodule: 18 | batch_size: ${eval:"8 if ${train.gpu_mem} < 24 else (16 if ${train.gpu_mem} < 40 else 32)"} 19 | -------------------------------------------------------------------------------- /csrc/flash_attn_with_bias_and_mask/src/utils.h: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | 7 | #include "fmha_utils.h" 8 | 9 | void SetZero(void *ptr, size_t sizeof_type, std::initializer_list shapes, cudaStream_t stream); 10 | 11 | template 12 | void SetConstValue(void *ptr, T value, size_t n, cudaStream_t stream); 13 | 14 | void Float2Half(void *float_ptr, void *half_ptr, size_t n, cudaStream_t stream); 15 | void Float2BF16(void *float_ptr, void *bf16_ptr, size_t n, cudaStream_t stream); 16 | -------------------------------------------------------------------------------- /training/configs/experiment/owt/gpt2s-flash.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - /experiment/owt/base.yaml 4 | - override /model: gpt2 5 | - override /model/gpt2model: gpt2-small 6 | 7 | model: 8 | config: 9 | # n_positions is already set to ${datamodule.max_length} 10 | residual_in_fp32: True 11 | use_flash_attn: True 12 | fused_bias_fc: True 13 | fused_mlp: True 14 | fused_dropout_add_ln: True 15 | pad_vocab_size_multiple: 8 16 | 17 | datamodule: 18 | # batch_size: 64 19 | batch_size: ${eval:"16 if ${train.gpu_mem} < 24 else (32 if ${train.gpu_mem} < 40 else 64)"} 20 | -------------------------------------------------------------------------------- /training/configs/experiment/owt/gpt2s-hf.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - /experiment/owt/base.yaml 4 | - override /model: gpt2-hf 5 | - override /model/gpt2model: gpt2-small 6 | - override /callbacks: [default, norm-monitor, flop-count] 7 | 8 | datamodule: 9 | batch_size: 8 10 | 11 | train: 12 | # Use the standard torch.nn.CrossEntropyLoss 13 | loss_fn: null 14 | 15 | callbacks: 16 | flop_count: 17 | input_size: 18 | - ${datamodule.max_length} 19 | input_dtype: 20 | # It's surprisingly hard to get hydra to return torch.long since it's not a callable 21 | _target_: torch.__getattribute__ 22 | _args_: 23 | - long 24 | -------------------------------------------------------------------------------- /training/configs/experiment/pile/gpt3l-flash.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - /experiment/pile/gpt3s-flash.yaml 4 | - override /optimizer: adamw-zero 5 | 6 | model: 7 | config: 8 | n_embd: 1536 9 | n_head: 16 10 | n_layer: 24 11 | # mlp_checkpoint_lvl: 1 # To fit batch_size 8 12 | 13 | datamodule: 14 | batch_size: ${eval:"2 if ${train.gpu_mem} < 24 else (4 if ${train.gpu_mem} 
< 40 else (8 if ${train.gpu_mem} < 80 else 16))"} 15 | 16 | train: 17 | optimizer: 18 | lr: 2.5e-4 19 | 20 | trainer: 21 | strategy: 22 | _target_: src.utils.ddp_zero1.DDPStrategyZero1 23 | find_unused_parameters: False 24 | gradient_as_bucket_view: True 25 | -------------------------------------------------------------------------------- /csrc/layer_norm/README.md: -------------------------------------------------------------------------------- 1 | This CUDA extension implements fused dropout + residual + LayerNorm, building on 2 | Apex's [FastLayerNorm](https://github.com/NVIDIA/apex/tree/master/apex/contrib/layer_norm). 3 | Major changes: 4 | - Add dropout and residual. 5 | - Make it work for both pre-norm and post-norm architecture. 6 | - Support more hidden dimensions (all dimensions divisible by 8, up to 8192). 7 | - Implement RMSNorm as an option. 8 | - Support layer norm with parallel residual (e.g., GPT-J, GPT-NeoX, PaLM). 9 | 10 | If you want to use it for dimensions larger than 8k, please file an issue. 11 | 12 | This extension has only been tested on A100s. 13 | 14 | ```sh 15 | cd csrc/layer_norm && pip install . 16 | ``` 17 | -------------------------------------------------------------------------------- /training/configs/mode/smoke.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | # Smoke test: disable logging and model checkpointing 3 | 4 | logger: 5 | wandb: 6 | mode: disabled 7 | 8 | callbacks: 9 | model_checkpoint: null 10 | model_checkpoint_progress: null 11 | 12 | hydra: 13 | # https://hydra.cc/docs/tutorials/basic/running_your_app/logging/ 14 | # sets level of only chosen command line loggers to 'DEBUG' 15 | # verbose: [src.train, src.utils.utils] 16 | 17 | # sets output paths for all file logs to 'logs/debug/' 18 | run: 19 | dir: ${oc.env:RESULT_DIR,${work_dir}/logs}/debug/${now:%Y-%m-%d}/${now:%H-%M-%S} 20 | sweep: 21 | dir: ${oc.env:RESULT_DIR,${work_dir}/logs}/debug/multirun_${now:%Y-%m-%d_%H-%M-%S} 22 | subdir: ${hydra.job.num} 23 | -------------------------------------------------------------------------------- /training/src/datamodules/timm_mixup.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from timm.data import Mixup 4 | from timm.data.mixup import mixup_target 5 | 6 | 7 | class TimmMixup(Mixup): 8 | """ Wrap timm.data.Mixup that avoids the assert that batch size must be even. 
9 | """ 10 | def __call__(self, x, target): 11 | if self.mode == 'elem': 12 | lam = self._mix_elem(x) 13 | elif self.mode == 'pair': 14 | # We move the assert from the beginning of the function to here 15 | assert len(x) % 2 == 0, 'Batch size should be even when using this' 16 | lam = self._mix_pair(x) 17 | else: 18 | lam = self._mix_batch(x) 19 | target = mixup_target(target, self.num_classes, lam, self.label_smoothing, x.device) 20 | return x, target 21 | -------------------------------------------------------------------------------- /training/configs/callbacks/wandb.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - default.yaml 3 | 4 | watch_model: 5 | _target_: src.callbacks.wandb_callbacks.WatchModel 6 | log: "all" 7 | log_freq: 100 8 | 9 | upload_code_as_artifact: 10 | _target_: src.callbacks.wandb_callbacks.UploadCodeAsArtifact 11 | code_dir: ${work_dir}/src 12 | 13 | upload_ckpts_as_artifact: 14 | _target_: src.callbacks.wandb_callbacks.UploadCheckpointsAsArtifact 15 | ckpt_dir: "checkpoints/" 16 | upload_best_only: True 17 | 18 | log_f1_precision_recall_heatmap: 19 | _target_: src.callbacks.wandb_callbacks.LogF1PrecRecHeatmap 20 | 21 | log_confusion_matrix: 22 | _target_: src.callbacks.wandb_callbacks.LogConfusionMatrix 23 | 24 | log_image_predictions: 25 | _target_: src.callbacks.wandb_callbacks.LogImagePredictions 26 | num_samples: 8 27 | -------------------------------------------------------------------------------- /csrc/flash_attn_with_bias_and_mask/src/fmha_fwd_hdim32.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2022, Tri Dao. 2 | 3 | // Splitting the different head dimensions to different files to speed up compilation. 4 | 5 | #include "fmha_fwd_launch_template.h" 6 | 7 | void run_fmha_fwd_hdim32(Launch_params &launch_params) { 8 | FP16_SWITCH(launch_params.params.is_bf16, ([&] { 9 | if (launch_params.params.seqlen_k == 128) { 10 | using Kernel_traits = FMHA_kernel_traits<128, 32, 16, 1, 4, 0x08u, elem_type>; 11 | run_fmha_fwd_loop(launch_params); 12 | } else if (launch_params.params.seqlen_k >= 256) { 13 | using Kernel_traits = FMHA_kernel_traits<256, 32, 16, 1, 4, 0x08u, elem_type>; 14 | run_fmha_fwd_loop(launch_params); 15 | } 16 | })); 17 | } 18 | -------------------------------------------------------------------------------- /csrc/flash_attn_with_bias_and_mask/src/fmha_fwd_hdim64.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2022, Tri Dao. 2 | 3 | // Splitting the different head dimensions to different files to speed up compilation. 4 | 5 | #include "fmha_fwd_launch_template.h" 6 | 7 | void run_fmha_fwd_hdim64(Launch_params &launch_params) { 8 | FP16_SWITCH(launch_params.params.is_bf16, ([&] { 9 | if (launch_params.params.seqlen_k == 128) { 10 | using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u, elem_type>; 11 | run_fmha_fwd_loop(launch_params); 12 | } else if (launch_params.params.seqlen_k >= 256) { 13 | using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 4, 0x08u, elem_type>; 14 | run_fmha_fwd_loop(launch_params); 15 | } 16 | })); 17 | } 18 | -------------------------------------------------------------------------------- /csrc/flash_attn_with_bias_and_mask/src/fmha_bwd_hdim32.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2022, Tri Dao. 
2 | 3 | // Splitting the different head dimensions to different files to speed up compilation. 4 | 5 | #include "fmha_bwd_launch_template.h" 6 | 7 | void run_fmha_bwd_hdim32(FMHA_dgrad_params ¶ms, cudaStream_t stream, const bool configure) { 8 | FP16_SWITCH(params.is_bf16, ([&] { 9 | if (params.seqlen_k == 128) { 10 | using Kernel_traits = FMHA_kernel_traits<128, 32, 16, 1, 8, 0x08u, elem_type>; 11 | run_fmha_bwd_loop(params, stream, configure); 12 | } else if (params.seqlen_k >= 256) { 13 | using Kernel_traits = FMHA_kernel_traits<256, 32, 16, 1, 8, 0x08u, elem_type>; 14 | run_fmha_bwd_loop(params, stream, configure); 15 | } 16 | })); 17 | } 18 | -------------------------------------------------------------------------------- /csrc/flash_attn_with_bias_and_mask/src/fmha_bwd_with_mask_bias_hdim32.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2022, Tri Dao. 2 | 3 | #include "fmha_bwd_with_mask_bias_launch_template.h" 4 | 5 | bool run_fmha_bwd_with_mask_bias_hdim32(FMHA_dgrad_params ¶ms, cudaStream_t stream) { 6 | bool status = true; 7 | FP16_SWITCH(params.is_bf16, ([&] { 8 | if( params.seqlen_k == 128 ) { 9 | using Kernel_traits = FMHA_kernel_traits<128, 32, 16, 1, 8, 0x08u, elem_type>; 10 | status = run_fmha_dgrad_fp16_sm80_loop_(params, stream); 11 | } else if( params.seqlen_k >= 256 ) { 12 | using Kernel_traits = FMHA_kernel_traits<256, 32, 16, 1, 8, 0x08u, elem_type>; 13 | status = run_fmha_dgrad_fp16_sm80_loop_(params, stream); 14 | } 15 | })); 16 | return status; 17 | } 18 | -------------------------------------------------------------------------------- /training/configs/mode/profile.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | # Run the Pytorch profiler 3 | 4 | trainer: 5 | profiler: 6 | _target_: pytorch_lightning.profilers.PyTorchProfiler 7 | dirpath: ${hydra.run.dir} 8 | schedule: 9 | _target_: torch.profiler.schedule 10 | wait: 5 11 | warmup: 5 12 | active: 5 13 | use_cuda: True 14 | max_steps: 20 15 | 16 | logger: 17 | wandb: 18 | mode: disabled 19 | 20 | callbacks: 21 | model_checkpoint: null 22 | model_checkpoint_progress: null 23 | early_stopping: null 24 | 25 | hydra: 26 | # sets output paths for all file logs to 'logs/profile/' 27 | run: 28 | dir: ${oc.env:RESULT_DIR,${work_dir}/logs}/profile/${now:%Y-%m-%d}/${now:%H-%M-%S} 29 | sweep: 30 | dir: ${oc.env:RESULT_DIR,${work_dir}/logs}/profile/multirun_${now:%Y-%m-%d_%H-%M-%S} 31 | subdir: ${hydra.job.num} 32 | -------------------------------------------------------------------------------- /training/configs/experiment/pile/gpt3xl-hf.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - /experiment/pile/gpt3s-hf.yaml 4 | - override /optimizer: adamw-zero 5 | 6 | model: 7 | config: 8 | n_embd: 2048 9 | n_head: 16 10 | n_layer: 24 11 | 12 | datamodule: 13 | batch_size: 2 14 | 15 | train: 16 | global_batch_size: 512 17 | optimizer: 18 | lr: 2.0e-4 19 | scheduler: 20 | t_initial: 300000 21 | 22 | trainer: 23 | strategy: 24 | _target_: src.utils.ddp_zero1.DDPStrategyZero1 25 | find_unused_parameters: False 26 | gradient_as_bucket_view: True 27 | max_steps: 400000 28 | val_check_interval: ${eval:1000 * ${.accumulate_grad_batches}} 29 | 30 | callbacks: 31 | model_checkpoint: 32 | every_n_train_steps: 1000 33 | model_checkpoint_progress: 34 | every_n_train_steps: 12500 35 | fault_tolerant: False # Saving takes too long 36 | 
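The experiment and datamodule configs above lean heavily on two custom OmegaConf resolvers that plain Hydra does not ship with: `${eval:...}` (used to pick a per-GPU `batch_size` from `${train.gpu_mem}` and to derive `val_check_interval`) and `${div_up:...}` (used to compute `__train_len` from the dataset token count and `max_length`). The snippet below is a minimal sketch of how such resolvers could be registered so these configs resolve; the training entry point in this repo presumably registers its own equivalents, so the registration site and the standalone example config here are assumptions, not the repo's actual code.

```python
# Sketch only: register resolvers matching the ${eval:...} and ${div_up:...}
# syntax used in the training configs. Assumes omegaconf >= 2.1 (nested
# interpolations inside resolver arguments).
from omegaconf import OmegaConf

# ${eval:"<python expression>"} -> evaluate the already-interpolated expression,
# e.g. choosing a per-GPU batch size based on train.gpu_mem.
OmegaConf.register_new_resolver("eval", eval)

# ${div_up:a, b} -> ceiling division, e.g. number of training sequences from
# a total token count and datamodule.max_length.
OmegaConf.register_new_resolver("div_up", lambda a, b: (a + b - 1) // b)

if __name__ == "__main__":
    # Hypothetical standalone config mirroring the patterns in the YAML files.
    cfg = OmegaConf.create({
        "train": {"gpu_mem": 40},
        "batch_size": '${eval:"8 if ${train.gpu_mem} < 24 else (16 if ${train.gpu_mem} < 40 else 32)"}',
        "num_sequences": "${div_up:9035582198, 1024}",
    })
    print(cfg.batch_size)      # -> 32 for a 40GB GPU in this sketch
    print(cfg.num_sequences)   # -> ceil(9035582198 / 1024)
```

With resolvers like these in place, Hydra composes the YAML files as usual and the `${eval:...}` / `${div_up:...}` expressions are only evaluated when the corresponding config values are accessed.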
-------------------------------------------------------------------------------- /training/configs/mode/debug.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # run in debug mode with: 4 | # `python run.py mode=debug` 5 | 6 | defaults: 7 | - override /trainer: debug.yaml 8 | 9 | debug_mode: True 10 | 11 | hydra: 12 | # sets level of all command line loggers to 'DEBUG' 13 | verbose: True 14 | 15 | # https://hydra.cc/docs/tutorials/basic/running_your_app/logging/ 16 | # sets level of only chosen command line loggers to 'DEBUG' 17 | # verbose: [src.train, src.utils.utils] 18 | 19 | # sets output paths for all file logs to 'logs/debug/' 20 | run: 21 | dir: ${oc.env:RESULT_DIR,${work_dir}/logs}/debug/${now:%Y-%m-%d}/${now:%H-%M-%S} 22 | sweep: 23 | dir: ${oc.env:RESULT_DIR,${work_dir}/logs}/debug/multirun_${now:%Y-%m-%d_%H-%M-%S} 24 | subdir: ${hydra.job.num} 25 | 26 | # disable rich config printing, since it will be already printed by hydra when `verbose: True` 27 | print_config: False 28 | -------------------------------------------------------------------------------- /csrc/ft_attention/cuda_bf16_wrapper.h: -------------------------------------------------------------------------------- 1 | // Downloaded from from FasterTransformer v5.2.1 2 | // https://github.com/NVIDIA/FasterTransformer/blob/release/v5.2.1_tag/src/fastertransformer/utils/cuda_bf16_wrapper.h 3 | /* 4 | * Copyright (c) 2019-2022, NVIDIA CORPORATION. All rights reserved. 5 | * 6 | * Licensed under the Apache License, Version 2.0 (the "License"); 7 | * you may not use this file except in compliance with the License. 8 | * You may obtain a copy of the License at 9 | * 10 | * http://www.apache.org/licenses/LICENSE-2.0 11 | * 12 | * Unless required by applicable law or agreed to in writing, software 13 | * distributed under the License is distributed on an "AS IS" BASIS, 14 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | * See the License for the specific language governing permissions and 16 | * limitations under the License. 17 | */ 18 | 19 | #pragma once 20 | 21 | #ifdef ENABLE_BF16 22 | #include 23 | #endif 24 | -------------------------------------------------------------------------------- /csrc/flashmask_v2/flash_fwd_combine.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2024, Tri Dao. 2 | // Splitting the different head dimensions to different files to speed up compilation. 3 | 4 | #include "flash_fwd_combine_launch_template.h" 5 | 6 | template void run_mha_fwd_combine_(Flash_fwd_params ¶ms, cudaStream_t stream, bool enable_pdl); 7 | template void run_mha_fwd_combine_(Flash_fwd_params ¶ms, cudaStream_t stream, bool enable_pdl); 8 | 9 | template void run_mha_fwd_combine_(Flash_fwd_params ¶ms, cudaStream_t stream, bool enable_pdl); 10 | template void run_mha_fwd_combine_(Flash_fwd_params ¶ms, cudaStream_t stream, bool enable_pdl); 11 | 12 | template void run_mha_fwd_combine_(Flash_fwd_params ¶ms, cudaStream_t stream, bool enable_pdl); 13 | template void run_mha_fwd_combine_(Flash_fwd_params ¶ms, cudaStream_t stream, bool enable_pdl); 14 | -------------------------------------------------------------------------------- /csrc/layer_norm/ln_fwd_1024.cu: -------------------------------------------------------------------------------- 1 | #include "ln_fwd_kernels.cuh" 2 | 3 | // Create forward launch function and register. 
Macro signature: 4 | // HIDDEN_SIZE, WTYPE, ITYPE, RYTPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG 5 | 6 | REGISTER_FWD_LAUNCHER( 1024, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16); 7 | REGISTER_FWD_LAUNCHER( 1024, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16); 8 | REGISTER_FWD_LAUNCHER( 1024, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16); 9 | REGISTER_FWD_LAUNCHER( 1024, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16); 10 | REGISTER_FWD_LAUNCHER( 1024, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16); 11 | REGISTER_FWD_LAUNCHER( 1024, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16); 12 | REGISTER_FWD_LAUNCHER( 1024, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16); 13 | REGISTER_FWD_LAUNCHER( 1024, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16); 14 | REGISTER_FWD_LAUNCHER( 1024, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16); 15 | REGISTER_FWD_LAUNCHER( 1024, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16); 16 | -------------------------------------------------------------------------------- /csrc/layer_norm/ln_fwd_1280.cu: -------------------------------------------------------------------------------- 1 | #include "ln_fwd_kernels.cuh" 2 | 3 | // Create forward launch function and register. Macro signature: 4 | // HIDDEN_SIZE, WTYPE, ITYPE, RYTPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG 5 | 6 | REGISTER_FWD_LAUNCHER( 1280, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16); 7 | REGISTER_FWD_LAUNCHER( 1280, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16); 8 | REGISTER_FWD_LAUNCHER( 1280, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16); 9 | REGISTER_FWD_LAUNCHER( 1280, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16); 10 | REGISTER_FWD_LAUNCHER( 1280, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16); 11 | REGISTER_FWD_LAUNCHER( 1280, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16); 12 | REGISTER_FWD_LAUNCHER( 1280, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16); 13 | REGISTER_FWD_LAUNCHER( 1280, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16); 14 | REGISTER_FWD_LAUNCHER( 1280, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16); 15 | REGISTER_FWD_LAUNCHER( 1280, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16); 16 | -------------------------------------------------------------------------------- /csrc/layer_norm/ln_fwd_1536.cu: -------------------------------------------------------------------------------- 1 | #include "ln_fwd_kernels.cuh" 2 | 3 | // Create forward launch function and register. 
Macro signature: 4 | // HIDDEN_SIZE, WTYPE, ITYPE, RYTPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG 5 | 6 | REGISTER_FWD_LAUNCHER( 1536, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16); 7 | REGISTER_FWD_LAUNCHER( 1536, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16); 8 | REGISTER_FWD_LAUNCHER( 1536, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16); 9 | REGISTER_FWD_LAUNCHER( 1536, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16); 10 | REGISTER_FWD_LAUNCHER( 1536, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16); 11 | REGISTER_FWD_LAUNCHER( 1536, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16); 12 | REGISTER_FWD_LAUNCHER( 1536, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16); 13 | REGISTER_FWD_LAUNCHER( 1536, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16); 14 | REGISTER_FWD_LAUNCHER( 1536, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16); 15 | REGISTER_FWD_LAUNCHER( 1536, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16); 16 | -------------------------------------------------------------------------------- /csrc/layer_norm/ln_fwd_2048.cu: -------------------------------------------------------------------------------- 1 | #include "ln_fwd_kernels.cuh" 2 | 3 | // Create forward launch function and register. Macro signature: 4 | // HIDDEN_SIZE, WTYPE, ITYPE, RYTPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG 5 | 6 | REGISTER_FWD_LAUNCHER( 2048, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16); 7 | REGISTER_FWD_LAUNCHER( 2048, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16); 8 | REGISTER_FWD_LAUNCHER( 2048, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16); 9 | REGISTER_FWD_LAUNCHER( 2048, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16); 10 | REGISTER_FWD_LAUNCHER( 2048, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16); 11 | REGISTER_FWD_LAUNCHER( 2048, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16); 12 | REGISTER_FWD_LAUNCHER( 2048, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16); 13 | REGISTER_FWD_LAUNCHER( 2048, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16); 14 | REGISTER_FWD_LAUNCHER( 2048, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16); 15 | REGISTER_FWD_LAUNCHER( 2048, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16); 16 | -------------------------------------------------------------------------------- /csrc/layer_norm/ln_fwd_256.cu: -------------------------------------------------------------------------------- 1 | #include "ln_fwd_kernels.cuh" 2 | 3 | // Create forward launch function and register. Macro signature: 4 | // HIDDEN_SIZE, WTYPE, ITYPE, RYTPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG 5 | 6 | REGISTER_FWD_LAUNCHER( 256, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16); 7 | REGISTER_FWD_LAUNCHER( 256, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16); 8 | REGISTER_FWD_LAUNCHER( 256, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16); 9 | REGISTER_FWD_LAUNCHER( 256, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16); 10 | REGISTER_FWD_LAUNCHER( 256, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16); 11 | REGISTER_FWD_LAUNCHER( 256, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16); 12 | REGISTER_FWD_LAUNCHER( 256, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16); 13 | REGISTER_FWD_LAUNCHER( 256, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16); 14 | REGISTER_FWD_LAUNCHER( 256, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16); 15 | REGISTER_FWD_LAUNCHER( 256, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16); 16 | -------------------------------------------------------------------------------- /csrc/layer_norm/ln_fwd_2560.cu: -------------------------------------------------------------------------------- 1 | #include "ln_fwd_kernels.cuh" 2 | 3 | // Create forward launch function and register. 
Macro signature: 4 | // HIDDEN_SIZE, WTYPE, ITYPE, RYTPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG 5 | 6 | REGISTER_FWD_LAUNCHER( 2560, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16); 7 | REGISTER_FWD_LAUNCHER( 2560, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16); 8 | REGISTER_FWD_LAUNCHER( 2560, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16); 9 | REGISTER_FWD_LAUNCHER( 2560, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16); 10 | REGISTER_FWD_LAUNCHER( 2560, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16); 11 | REGISTER_FWD_LAUNCHER( 2560, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16); 12 | REGISTER_FWD_LAUNCHER( 2560, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16); 13 | REGISTER_FWD_LAUNCHER( 2560, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16); 14 | REGISTER_FWD_LAUNCHER( 2560, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16); 15 | REGISTER_FWD_LAUNCHER( 2560, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16); 16 | -------------------------------------------------------------------------------- /csrc/layer_norm/ln_fwd_3072.cu: -------------------------------------------------------------------------------- 1 | #include "ln_fwd_kernels.cuh" 2 | 3 | // Create forward launch function and register. Macro signature: 4 | // HIDDEN_SIZE, WTYPE, ITYPE, RYTPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG 5 | 6 | REGISTER_FWD_LAUNCHER( 3072, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16); 7 | REGISTER_FWD_LAUNCHER( 3072, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16); 8 | REGISTER_FWD_LAUNCHER( 3072, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 16); 9 | REGISTER_FWD_LAUNCHER( 3072, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 16); 10 | REGISTER_FWD_LAUNCHER( 3072, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 16); 11 | REGISTER_FWD_LAUNCHER( 3072, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 16); 12 | REGISTER_FWD_LAUNCHER( 3072, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 16); 13 | REGISTER_FWD_LAUNCHER( 3072, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 16); 14 | REGISTER_FWD_LAUNCHER( 3072, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 16); 15 | REGISTER_FWD_LAUNCHER( 3072, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 16); 16 | -------------------------------------------------------------------------------- /csrc/layer_norm/ln_fwd_4096.cu: -------------------------------------------------------------------------------- 1 | #include "ln_fwd_kernels.cuh" 2 | 3 | // Create forward launch function and register. 
Macro signature: 4 | // HIDDEN_SIZE, WTYPE, ITYPE, RYTPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG 5 | 6 | REGISTER_FWD_LAUNCHER( 4096, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16); 7 | REGISTER_FWD_LAUNCHER( 4096, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16); 8 | REGISTER_FWD_LAUNCHER( 4096, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 16); 9 | REGISTER_FWD_LAUNCHER( 4096, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 16); 10 | REGISTER_FWD_LAUNCHER( 4096, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 16); 11 | REGISTER_FWD_LAUNCHER( 4096, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 16); 12 | REGISTER_FWD_LAUNCHER( 4096, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 16); 13 | REGISTER_FWD_LAUNCHER( 4096, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 16); 14 | REGISTER_FWD_LAUNCHER( 4096, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 16); 15 | REGISTER_FWD_LAUNCHER( 4096, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 16); 16 | -------------------------------------------------------------------------------- /csrc/layer_norm/ln_fwd_512.cu: -------------------------------------------------------------------------------- 1 | #include "ln_fwd_kernels.cuh" 2 | 3 | // Create forward launch function and register. Macro signature: 4 | // HIDDEN_SIZE, WTYPE, ITYPE, RYTPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG 5 | 6 | REGISTER_FWD_LAUNCHER( 512, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16); 7 | REGISTER_FWD_LAUNCHER( 512, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16); 8 | REGISTER_FWD_LAUNCHER( 512, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16); 9 | REGISTER_FWD_LAUNCHER( 512, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16); 10 | REGISTER_FWD_LAUNCHER( 512, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16); 11 | REGISTER_FWD_LAUNCHER( 512, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16); 12 | REGISTER_FWD_LAUNCHER( 512, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16); 13 | REGISTER_FWD_LAUNCHER( 512, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16); 14 | REGISTER_FWD_LAUNCHER( 512, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16); 15 | REGISTER_FWD_LAUNCHER( 512, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16); 16 | -------------------------------------------------------------------------------- /csrc/layer_norm/ln_fwd_5120.cu: -------------------------------------------------------------------------------- 1 | #include "ln_fwd_kernels.cuh" 2 | 3 | // Create forward launch function and register. Macro signature: 4 | // HIDDEN_SIZE, WTYPE, ITYPE, RYTPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG 5 | 6 | REGISTER_FWD_LAUNCHER( 5120, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16); 7 | REGISTER_FWD_LAUNCHER( 5120, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16); 8 | REGISTER_FWD_LAUNCHER( 5120, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 16); 9 | REGISTER_FWD_LAUNCHER( 5120, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 16); 10 | REGISTER_FWD_LAUNCHER( 5120, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 16); 11 | REGISTER_FWD_LAUNCHER( 5120, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 16); 12 | REGISTER_FWD_LAUNCHER( 5120, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 16); 13 | REGISTER_FWD_LAUNCHER( 5120, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 16); 14 | REGISTER_FWD_LAUNCHER( 5120, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 16); 15 | REGISTER_FWD_LAUNCHER( 5120, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 16); 16 | -------------------------------------------------------------------------------- /csrc/layer_norm/ln_fwd_6144.cu: -------------------------------------------------------------------------------- 1 | #include "ln_fwd_kernels.cuh" 2 | 3 | // Create forward launch function and register. 
Macro signature: 4 | // HIDDEN_SIZE, WTYPE, ITYPE, RYTPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG 5 | 6 | REGISTER_FWD_LAUNCHER( 6144, fp32, fp32, fp32, fp32, fp32, 1, 1, 8, 16); 7 | REGISTER_FWD_LAUNCHER( 6144, fp16, fp32, fp32, fp32, fp32, 1, 1, 8, 16); 8 | REGISTER_FWD_LAUNCHER( 6144, fp32, fp16, fp32, fp16, fp32, 1, 1, 8, 16); 9 | REGISTER_FWD_LAUNCHER( 6144, fp16, fp16, fp32, fp16, fp32, 1, 1, 8, 16); 10 | REGISTER_FWD_LAUNCHER( 6144, fp32, fp16, fp16, fp16, fp32, 1, 1, 8, 16); 11 | REGISTER_FWD_LAUNCHER( 6144, fp32, bf16, fp32, bf16, fp32, 1, 1, 8, 16); 12 | REGISTER_FWD_LAUNCHER( 6144, bf16, bf16, fp32, bf16, fp32, 1, 1, 8, 16); 13 | REGISTER_FWD_LAUNCHER( 6144, fp32, bf16, bf16, bf16, fp32, 1, 1, 8, 16); 14 | REGISTER_FWD_LAUNCHER( 6144, fp16, fp16, fp16, fp16, fp32, 1, 1, 8, 16); 15 | REGISTER_FWD_LAUNCHER( 6144, bf16, bf16, bf16, bf16, fp32, 1, 1, 8, 16); 16 | -------------------------------------------------------------------------------- /csrc/layer_norm/ln_fwd_7168.cu: -------------------------------------------------------------------------------- 1 | #include "ln_fwd_kernels.cuh" 2 | 3 | // Create forward launch function and register. Macro signature: 4 | // HIDDEN_SIZE, WTYPE, ITYPE, RYTPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG 5 | 6 | REGISTER_FWD_LAUNCHER( 7168, fp32, fp32, fp32, fp32, fp32, 1, 1, 8, 16); 7 | REGISTER_FWD_LAUNCHER( 7168, fp16, fp32, fp32, fp32, fp32, 1, 1, 8, 16); 8 | REGISTER_FWD_LAUNCHER( 7168, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 16); 9 | REGISTER_FWD_LAUNCHER( 7168, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 16); 10 | REGISTER_FWD_LAUNCHER( 7168, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 16); 11 | REGISTER_FWD_LAUNCHER( 7168, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 16); 12 | REGISTER_FWD_LAUNCHER( 7168, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 16); 13 | REGISTER_FWD_LAUNCHER( 7168, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 16); 14 | REGISTER_FWD_LAUNCHER( 7168, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 16); 15 | REGISTER_FWD_LAUNCHER( 7168, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 16); 16 | -------------------------------------------------------------------------------- /csrc/layer_norm/ln_fwd_768.cu: -------------------------------------------------------------------------------- 1 | #include "ln_fwd_kernels.cuh" 2 | 3 | // Create forward launch function and register. Macro signature: 4 | // HIDDEN_SIZE, WTYPE, ITYPE, RYTPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG 5 | 6 | REGISTER_FWD_LAUNCHER( 768, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16); 7 | REGISTER_FWD_LAUNCHER( 768, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16); 8 | REGISTER_FWD_LAUNCHER( 768, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16); 9 | REGISTER_FWD_LAUNCHER( 768, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16); 10 | REGISTER_FWD_LAUNCHER( 768, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16); 11 | REGISTER_FWD_LAUNCHER( 768, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16); 12 | REGISTER_FWD_LAUNCHER( 768, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16); 13 | REGISTER_FWD_LAUNCHER( 768, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16); 14 | REGISTER_FWD_LAUNCHER( 768, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16); 15 | REGISTER_FWD_LAUNCHER( 768, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16); 16 | -------------------------------------------------------------------------------- /csrc/layer_norm/ln_fwd_8192.cu: -------------------------------------------------------------------------------- 1 | #include "ln_fwd_kernels.cuh" 2 | 3 | // Create forward launch function and register. 
Macro signature: 4 | // HIDDEN_SIZE, WTYPE, ITYPE, RYTPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG 5 | 6 | REGISTER_FWD_LAUNCHER( 8192, fp32, fp32, fp32, fp32, fp32, 1, 1, 8, 16); 7 | REGISTER_FWD_LAUNCHER( 8192, fp16, fp32, fp32, fp32, fp32, 1, 1, 8, 16); 8 | REGISTER_FWD_LAUNCHER( 8192, fp32, fp16, fp32, fp16, fp32, 1, 1, 8, 16); 9 | REGISTER_FWD_LAUNCHER( 8192, fp16, fp16, fp32, fp16, fp32, 1, 1, 8, 16); 10 | REGISTER_FWD_LAUNCHER( 8192, fp32, fp16, fp16, fp16, fp32, 1, 1, 8, 16); 11 | REGISTER_FWD_LAUNCHER( 8192, fp32, bf16, fp32, bf16, fp32, 1, 1, 8, 16); 12 | REGISTER_FWD_LAUNCHER( 8192, bf16, bf16, fp32, bf16, fp32, 1, 1, 8, 16); 13 | REGISTER_FWD_LAUNCHER( 8192, fp32, bf16, bf16, bf16, fp32, 1, 1, 8, 16); 14 | REGISTER_FWD_LAUNCHER( 8192, fp16, fp16, fp16, fp16, fp32, 1, 1, 8, 16); 15 | REGISTER_FWD_LAUNCHER( 8192, bf16, bf16, bf16, bf16, fp32, 1, 1, 8, 16); 16 | -------------------------------------------------------------------------------- /training/configs/experiment/pile/gpt3xl-flash.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - /experiment/pile/gpt3s-flash.yaml 4 | - override /optimizer: adamw-zero 5 | 6 | model: 7 | config: 8 | n_embd: 2048 9 | n_head: 16 10 | n_layer: 24 11 | 12 | datamodule: 13 | batch_size: ${eval:"1 if ${train.gpu_mem} < 24 else (2 if ${train.gpu_mem} < 40 else (4 if ${train.gpu_mem} < 80 else 8))"} 14 | 15 | train: 16 | global_batch_size: 512 17 | optimizer: 18 | lr: 2.0e-4 19 | scheduler: 20 | t_initial: 300000 21 | 22 | trainer: 23 | strategy: 24 | _target_: src.utils.ddp_zero1.DDPStrategyZero1 25 | find_unused_parameters: False 26 | gradient_as_bucket_view: True 27 | max_steps: 400000 28 | val_check_interval: ${eval:1000 * ${.accumulate_grad_batches}} 29 | 30 | callbacks: 31 | model_checkpoint: 32 | every_n_train_steps: 1000 33 | model_checkpoint_progress: 34 | every_n_train_steps: 12500 35 | fault_tolerant: False # Saving takes too long 36 | -------------------------------------------------------------------------------- /csrc/flash_attn_v3/flash_fwd_combine.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2024, Tri Dao. 2 | // Splitting the different head dimensions to different files to speed up compilation. 
3 | 4 | #include "flash_fwd_combine_launch_template.h" 5 | 6 | template void run_mha_fwd_combine_(Flash_fwd_params ¶ms, cudaStream_t stream, bool enable_pdl); 7 | template void run_mha_fwd_combine_(Flash_fwd_params ¶ms, cudaStream_t stream, bool enable_pdl); 8 | 9 | template void run_mha_fwd_combine_(Flash_fwd_params ¶ms, cudaStream_t stream, bool enable_pdl); 10 | template void run_mha_fwd_combine_(Flash_fwd_params ¶ms, cudaStream_t stream, bool enable_pdl); 11 | 12 | template void run_mha_fwd_combine_(Flash_fwd_params ¶ms, cudaStream_t stream, bool enable_pdl); 13 | template void run_mha_fwd_combine_(Flash_fwd_params ¶ms, cudaStream_t stream, bool enable_pdl); 14 | -------------------------------------------------------------------------------- /csrc/flashmask_v2/print_val.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include "utils.h" 3 | 4 | namespace flash{ 5 | __global__ void print_addr_value(int* base, size_t offset_bytes) { 6 | int* ptr = (int*)((char*)base + offset_bytes); 7 | printf("Value at address %p: %d\n", ptr, *ptr); 8 | } 9 | 10 | __global__ void print_addr_value_ordered(int* base, size_t start_offset_bytes, int count) { 11 | int tid = blockIdx.x * blockDim.x + threadIdx.x; 12 | int total_threads = gridDim.x * blockDim.x; 13 | 14 | // 按线程ID顺序打印,避免输出混乱 15 | for (int current_thread = 0; current_thread < total_threads; current_thread++) { 16 | if (tid == current_thread && tid < count) { 17 | size_t offset_bytes = start_offset_bytes + tid * sizeof(int); 18 | int* ptr = (int*)((char*)base + offset_bytes); 19 | printf("Thread %d - Value at address %p (offset %zu): %d\n", 20 | tid, ptr, offset_bytes, *ptr); 21 | } 22 | __syncthreads(); // 同步保证顺序 23 | } 24 | } 25 | } -------------------------------------------------------------------------------- /csrc/layer_norm/ln_bwd_2048.cu: -------------------------------------------------------------------------------- 1 | #include "ln_bwd_kernels.cuh" 2 | 3 | // Create backward launch function and register. Macro signature: 4 | // HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL 5 | 6 | REGISTER_BWD_LAUNCHER( 2048, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4); 7 | REGISTER_BWD_LAUNCHER( 2048, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4); 8 | REGISTER_BWD_LAUNCHER( 2048, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4); 9 | REGISTER_BWD_LAUNCHER( 2048, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4); 10 | REGISTER_BWD_LAUNCHER( 2048, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4); 11 | REGISTER_BWD_LAUNCHER( 2048, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4); 12 | REGISTER_BWD_LAUNCHER( 2048, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4); 13 | REGISTER_BWD_LAUNCHER( 2048, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4); 14 | REGISTER_BWD_LAUNCHER( 2048, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4); 15 | REGISTER_BWD_LAUNCHER( 2048, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4); -------------------------------------------------------------------------------- /csrc/layer_norm/ln_bwd_3072.cu: -------------------------------------------------------------------------------- 1 | #include "ln_bwd_kernels.cuh" 2 | 3 | // Create backward launch function and register. 
Macro signature: 4 | // HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL 5 | 6 | REGISTER_BWD_LAUNCHER( 3072, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4); 7 | REGISTER_BWD_LAUNCHER( 3072, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4); 8 | REGISTER_BWD_LAUNCHER( 3072, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4); 9 | REGISTER_BWD_LAUNCHER( 3072, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4); 10 | REGISTER_BWD_LAUNCHER( 3072, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4); 11 | REGISTER_BWD_LAUNCHER( 3072, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4); 12 | REGISTER_BWD_LAUNCHER( 3072, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4); 13 | REGISTER_BWD_LAUNCHER( 3072, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4); 14 | REGISTER_BWD_LAUNCHER( 3072, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4); 15 | REGISTER_BWD_LAUNCHER( 3072, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4); -------------------------------------------------------------------------------- /csrc/layer_norm/ln_bwd_4096.cu: -------------------------------------------------------------------------------- 1 | #include "ln_bwd_kernels.cuh" 2 | 3 | // Create backward launch function and register. Macro signature: 4 | // HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL 5 | 6 | REGISTER_BWD_LAUNCHER( 4096, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4); 7 | REGISTER_BWD_LAUNCHER( 4096, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4); 8 | REGISTER_BWD_LAUNCHER( 4096, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4); 9 | REGISTER_BWD_LAUNCHER( 4096, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4); 10 | REGISTER_BWD_LAUNCHER( 4096, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4); 11 | REGISTER_BWD_LAUNCHER( 4096, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4); 12 | REGISTER_BWD_LAUNCHER( 4096, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4); 13 | REGISTER_BWD_LAUNCHER( 4096, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4); 14 | REGISTER_BWD_LAUNCHER( 4096, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4); 15 | REGISTER_BWD_LAUNCHER( 4096, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4); -------------------------------------------------------------------------------- /csrc/layer_norm/ln_bwd_5120.cu: -------------------------------------------------------------------------------- 1 | #include "ln_bwd_kernels.cuh" 2 | 3 | // Create backward launch function and register. 
Macro signature: 4 | // HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL 5 | 6 | REGISTER_BWD_LAUNCHER( 5120, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4); 7 | REGISTER_BWD_LAUNCHER( 5120, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4); 8 | REGISTER_BWD_LAUNCHER( 5120, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4); 9 | REGISTER_BWD_LAUNCHER( 5120, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4); 10 | REGISTER_BWD_LAUNCHER( 5120, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4); 11 | REGISTER_BWD_LAUNCHER( 5120, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4); 12 | REGISTER_BWD_LAUNCHER( 5120, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4); 13 | REGISTER_BWD_LAUNCHER( 5120, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4); 14 | REGISTER_BWD_LAUNCHER( 5120, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4); 15 | REGISTER_BWD_LAUNCHER( 5120, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4); -------------------------------------------------------------------------------- /csrc/layer_norm/ln_bwd_6144.cu: -------------------------------------------------------------------------------- 1 | #include "ln_bwd_kernels.cuh" 2 | 3 | // Create backward launch function and register. Macro signature: 4 | // HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL 5 | 6 | REGISTER_BWD_LAUNCHER( 6144, fp32, fp32, fp32, fp32, fp32, 1, 1, 8, 16, 4); 7 | REGISTER_BWD_LAUNCHER( 6144, fp16, fp32, fp32, fp32, fp32, 1, 1, 8, 16, 4); 8 | REGISTER_BWD_LAUNCHER( 6144, fp32, fp16, fp32, fp16, fp32, 1, 1, 8, 16, 4); 9 | REGISTER_BWD_LAUNCHER( 6144, fp16, fp16, fp32, fp16, fp32, 1, 1, 8, 16, 4); 10 | REGISTER_BWD_LAUNCHER( 6144, fp32, fp16, fp16, fp16, fp32, 1, 1, 8, 16, 4); 11 | REGISTER_BWD_LAUNCHER( 6144, fp32, bf16, fp32, bf16, fp32, 1, 1, 8, 16, 4); 12 | REGISTER_BWD_LAUNCHER( 6144, bf16, bf16, fp32, bf16, fp32, 1, 1, 8, 16, 4); 13 | REGISTER_BWD_LAUNCHER( 6144, fp32, bf16, bf16, bf16, fp32, 1, 1, 8, 16, 4); 14 | REGISTER_BWD_LAUNCHER( 6144, fp16, fp16, fp16, fp16, fp32, 1, 1, 8, 16, 4); 15 | REGISTER_BWD_LAUNCHER( 6144, bf16, bf16, bf16, bf16, fp32, 1, 1, 8, 16, 4); -------------------------------------------------------------------------------- /csrc/layer_norm/ln_bwd_7168.cu: -------------------------------------------------------------------------------- 1 | #include "ln_bwd_kernels.cuh" 2 | 3 | // Create backward launch function and register. 
Macro signature: 4 | // HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL 5 | 6 | REGISTER_BWD_LAUNCHER( 7168, fp32, fp32, fp32, fp32, fp32, 1, 1, 8, 16, 4); 7 | REGISTER_BWD_LAUNCHER( 7168, fp16, fp32, fp32, fp32, fp32, 1, 1, 8, 16, 4); 8 | REGISTER_BWD_LAUNCHER( 7168, fp32, fp16, fp32, fp16, fp32, 1, 1, 8, 8, 4); 9 | REGISTER_BWD_LAUNCHER( 7168, fp16, fp16, fp32, fp16, fp32, 1, 1, 8, 8, 4); 10 | REGISTER_BWD_LAUNCHER( 7168, fp32, fp16, fp16, fp16, fp32, 1, 1, 8, 8, 4); 11 | REGISTER_BWD_LAUNCHER( 7168, fp32, bf16, fp32, bf16, fp32, 1, 1, 8, 8, 4); 12 | REGISTER_BWD_LAUNCHER( 7168, bf16, bf16, fp32, bf16, fp32, 1, 1, 8, 8, 4); 13 | REGISTER_BWD_LAUNCHER( 7168, fp32, bf16, bf16, bf16, fp32, 1, 1, 8, 8, 4); 14 | REGISTER_BWD_LAUNCHER( 7168, fp16, fp16, fp16, fp16, fp32, 1, 1, 8, 8, 4); 15 | REGISTER_BWD_LAUNCHER( 7168, bf16, bf16, bf16, bf16, fp32, 1, 1, 8, 8, 4); -------------------------------------------------------------------------------- /csrc/layer_norm/ln_bwd_8192.cu: -------------------------------------------------------------------------------- 1 | #include "ln_bwd_kernels.cuh" 2 | 3 | // Create backward launch function and register. Macro signature: 4 | // HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL 5 | 6 | REGISTER_BWD_LAUNCHER( 8192, fp32, fp32, fp32, fp32, fp32, 1, 1, 8, 16, 4); 7 | REGISTER_BWD_LAUNCHER( 8192, fp16, fp32, fp32, fp32, fp32, 1, 1, 8, 16, 4); 8 | REGISTER_BWD_LAUNCHER( 8192, fp32, fp16, fp32, fp16, fp32, 1, 1, 8, 16, 4); 9 | REGISTER_BWD_LAUNCHER( 8192, fp16, fp16, fp32, fp16, fp32, 1, 1, 8, 16, 4); 10 | REGISTER_BWD_LAUNCHER( 8192, fp32, fp16, fp16, fp16, fp32, 1, 1, 8, 16, 4); 11 | REGISTER_BWD_LAUNCHER( 8192, fp32, bf16, fp32, bf16, fp32, 1, 1, 8, 16, 4); 12 | REGISTER_BWD_LAUNCHER( 8192, bf16, bf16, fp32, bf16, fp32, 1, 1, 8, 16, 4); 13 | REGISTER_BWD_LAUNCHER( 8192, fp32, bf16, bf16, bf16, fp32, 1, 1, 8, 16, 4); 14 | REGISTER_BWD_LAUNCHER( 8192, fp16, fp16, fp16, fp16, fp32, 1, 1, 8, 16, 4); 15 | REGISTER_BWD_LAUNCHER( 8192, bf16, bf16, bf16, bf16, fp32, 1, 1, 8, 16, 4); -------------------------------------------------------------------------------- /csrc/layer_norm/ln_bwd_1536.cu: -------------------------------------------------------------------------------- 1 | #include "ln_bwd_kernels.cuh" 2 | 3 | // Create backward launch function and register. 
Macro signature: 4 | // HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL 5 | 6 | REGISTER_BWD_LAUNCHER( 1536, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4); 7 | REGISTER_BWD_LAUNCHER( 1536, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4); 8 | REGISTER_BWD_LAUNCHER( 1536, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 8, 4); 9 | REGISTER_BWD_LAUNCHER( 1536, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 8, 4); 10 | REGISTER_BWD_LAUNCHER( 1536, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 8, 4); 11 | REGISTER_BWD_LAUNCHER( 1536, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 8, 4); 12 | REGISTER_BWD_LAUNCHER( 1536, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 8, 4); 13 | REGISTER_BWD_LAUNCHER( 1536, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 8, 4); 14 | REGISTER_BWD_LAUNCHER( 1536, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 8, 4); 15 | REGISTER_BWD_LAUNCHER( 1536, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 8, 4); 16 | -------------------------------------------------------------------------------- /csrc/layer_norm/ln_bwd_256.cu: -------------------------------------------------------------------------------- 1 | #include "ln_bwd_kernels.cuh" 2 | 3 | // Create backward launch function and register. Macro signature: 4 | // HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL 5 | 6 | REGISTER_BWD_LAUNCHER( 256, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4); 7 | REGISTER_BWD_LAUNCHER( 256, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4); 8 | REGISTER_BWD_LAUNCHER( 256, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4); 9 | REGISTER_BWD_LAUNCHER( 256, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4); 10 | REGISTER_BWD_LAUNCHER( 256, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4); 11 | REGISTER_BWD_LAUNCHER( 256, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4); 12 | REGISTER_BWD_LAUNCHER( 256, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4); 13 | REGISTER_BWD_LAUNCHER( 256, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4); 14 | REGISTER_BWD_LAUNCHER( 256, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4); 15 | REGISTER_BWD_LAUNCHER( 256, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4); 16 | -------------------------------------------------------------------------------- /csrc/layer_norm/ln_bwd_2560.cu: -------------------------------------------------------------------------------- 1 | #include "ln_bwd_kernels.cuh" 2 | 3 | // Create backward launch function and register. 
Macro signature: 4 | // HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL 5 | 6 | REGISTER_BWD_LAUNCHER( 2560, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4); 7 | REGISTER_BWD_LAUNCHER( 2560, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4); 8 | REGISTER_BWD_LAUNCHER( 2560, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 8, 4); 9 | REGISTER_BWD_LAUNCHER( 2560, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 8, 4); 10 | REGISTER_BWD_LAUNCHER( 2560, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 8, 4); 11 | REGISTER_BWD_LAUNCHER( 2560, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 8, 4); 12 | REGISTER_BWD_LAUNCHER( 2560, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 8, 4); 13 | REGISTER_BWD_LAUNCHER( 2560, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 8, 4); 14 | REGISTER_BWD_LAUNCHER( 2560, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 8, 4); 15 | REGISTER_BWD_LAUNCHER( 2560, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 8, 4); 16 | -------------------------------------------------------------------------------- /csrc/layer_norm/ln_bwd_512.cu: -------------------------------------------------------------------------------- 1 | #include "ln_bwd_kernels.cuh" 2 | 3 | // Create backward launch function and register. Macro signature: 4 | // HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL 5 | 6 | REGISTER_BWD_LAUNCHER( 512, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4); 7 | REGISTER_BWD_LAUNCHER( 512, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4); 8 | REGISTER_BWD_LAUNCHER( 512, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4); 9 | REGISTER_BWD_LAUNCHER( 512, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4); 10 | REGISTER_BWD_LAUNCHER( 512, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4); 11 | REGISTER_BWD_LAUNCHER( 512, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4); 12 | REGISTER_BWD_LAUNCHER( 512, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4); 13 | REGISTER_BWD_LAUNCHER( 512, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4); 14 | REGISTER_BWD_LAUNCHER( 512, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4); 15 | REGISTER_BWD_LAUNCHER( 512, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4); 16 | -------------------------------------------------------------------------------- /csrc/layer_norm/ln_bwd_768.cu: -------------------------------------------------------------------------------- 1 | #include "ln_bwd_kernels.cuh" 2 | 3 | // Create backward launch function and register. 
Macro signature: 4 | // HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL 5 | 6 | REGISTER_BWD_LAUNCHER( 768, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4); 7 | REGISTER_BWD_LAUNCHER( 768, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4); 8 | REGISTER_BWD_LAUNCHER( 768, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4); 9 | REGISTER_BWD_LAUNCHER( 768, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4); 10 | REGISTER_BWD_LAUNCHER( 768, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4); 11 | REGISTER_BWD_LAUNCHER( 768, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4); 12 | REGISTER_BWD_LAUNCHER( 768, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4); 13 | REGISTER_BWD_LAUNCHER( 768, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4); 14 | REGISTER_BWD_LAUNCHER( 768, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4); 15 | REGISTER_BWD_LAUNCHER( 768, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4); 16 | -------------------------------------------------------------------------------- /csrc/layer_norm/ln_bwd_1024.cu: -------------------------------------------------------------------------------- 1 | #include "ln_bwd_kernels.cuh" 2 | 3 | // Create backward launch function and register. Macro signature: 4 | // HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL 5 | 6 | REGISTER_BWD_LAUNCHER( 1024, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4); 7 | REGISTER_BWD_LAUNCHER( 1024, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4); 8 | REGISTER_BWD_LAUNCHER( 1024, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4); 9 | REGISTER_BWD_LAUNCHER( 1024, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4); 10 | REGISTER_BWD_LAUNCHER( 1024, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4); 11 | REGISTER_BWD_LAUNCHER( 1024, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4); 12 | REGISTER_BWD_LAUNCHER( 1024, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4); 13 | REGISTER_BWD_LAUNCHER( 1024, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4); 14 | REGISTER_BWD_LAUNCHER( 1024, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4); 15 | REGISTER_BWD_LAUNCHER( 1024, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4); 16 | -------------------------------------------------------------------------------- /csrc/layer_norm/ln_bwd_1280.cu: -------------------------------------------------------------------------------- 1 | #include "ln_bwd_kernels.cuh" 2 | 3 | // Create backward launch function and register. 
Macro signature: 4 | // HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL 5 | 6 | REGISTER_BWD_LAUNCHER( 1280, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4); 7 | REGISTER_BWD_LAUNCHER( 1280, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4); 8 | REGISTER_BWD_LAUNCHER( 1280, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4); 9 | REGISTER_BWD_LAUNCHER( 1280, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4); 10 | REGISTER_BWD_LAUNCHER( 1280, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4); 11 | REGISTER_BWD_LAUNCHER( 1280, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4); 12 | REGISTER_BWD_LAUNCHER( 1280, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4); 13 | REGISTER_BWD_LAUNCHER( 1280, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4); 14 | REGISTER_BWD_LAUNCHER( 1280, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4); 15 | REGISTER_BWD_LAUNCHER( 1280, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4); 16 | -------------------------------------------------------------------------------- /csrc/layer_norm/ln_parallel_fwd_256.cu: -------------------------------------------------------------------------------- 1 | #include "ln_parallel_residual_fwd_kernels.cuh" 2 | 3 | // Create forward launch function and register. Macro signature: 4 | // HIDDEN_SIZE, WTYPE, ITYPE, RYTPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG 5 | 6 | REGISTER_PARALLEL_FWD_LAUNCHER( 256, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16); 7 | REGISTER_PARALLEL_FWD_LAUNCHER( 256, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16); 8 | REGISTER_PARALLEL_FWD_LAUNCHER( 256, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16); 9 | REGISTER_PARALLEL_FWD_LAUNCHER( 256, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16); 10 | REGISTER_PARALLEL_FWD_LAUNCHER( 256, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16); 11 | REGISTER_PARALLEL_FWD_LAUNCHER( 256, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16); 12 | REGISTER_PARALLEL_FWD_LAUNCHER( 256, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16); 13 | REGISTER_PARALLEL_FWD_LAUNCHER( 256, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16); 14 | REGISTER_PARALLEL_FWD_LAUNCHER( 256, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16); 15 | REGISTER_PARALLEL_FWD_LAUNCHER( 256, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16); -------------------------------------------------------------------------------- /csrc/layer_norm/ln_parallel_fwd_1024.cu: -------------------------------------------------------------------------------- 1 | #include "ln_parallel_residual_fwd_kernels.cuh" 2 | 3 | // Create forward launch function and register. 
Macro signature: 4 | // HIDDEN_SIZE, WTYPE, ITYPE, RYTPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG 5 | 6 | REGISTER_PARALLEL_FWD_LAUNCHER( 1024, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16); 7 | REGISTER_PARALLEL_FWD_LAUNCHER( 1024, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16); 8 | REGISTER_PARALLEL_FWD_LAUNCHER( 1024, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16); 9 | REGISTER_PARALLEL_FWD_LAUNCHER( 1024, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16); 10 | REGISTER_PARALLEL_FWD_LAUNCHER( 1024, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16); 11 | REGISTER_PARALLEL_FWD_LAUNCHER( 1024, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16); 12 | REGISTER_PARALLEL_FWD_LAUNCHER( 1024, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16); 13 | REGISTER_PARALLEL_FWD_LAUNCHER( 1024, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16); 14 | REGISTER_PARALLEL_FWD_LAUNCHER( 1024, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16); 15 | REGISTER_PARALLEL_FWD_LAUNCHER( 1024, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16); 16 | -------------------------------------------------------------------------------- /csrc/layer_norm/ln_parallel_fwd_1280.cu: -------------------------------------------------------------------------------- 1 | #include "ln_parallel_residual_fwd_kernels.cuh" 2 | 3 | // Create forward launch function and register. Macro signature: 4 | // HIDDEN_SIZE, WTYPE, ITYPE, RYTPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG 5 | 6 | REGISTER_PARALLEL_FWD_LAUNCHER( 1280, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16); 7 | REGISTER_PARALLEL_FWD_LAUNCHER( 1280, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16); 8 | REGISTER_PARALLEL_FWD_LAUNCHER( 1280, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16); 9 | REGISTER_PARALLEL_FWD_LAUNCHER( 1280, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16); 10 | REGISTER_PARALLEL_FWD_LAUNCHER( 1280, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16); 11 | REGISTER_PARALLEL_FWD_LAUNCHER( 1280, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16); 12 | REGISTER_PARALLEL_FWD_LAUNCHER( 1280, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16); 13 | REGISTER_PARALLEL_FWD_LAUNCHER( 1280, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16); 14 | REGISTER_PARALLEL_FWD_LAUNCHER( 1280, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16); 15 | REGISTER_PARALLEL_FWD_LAUNCHER( 1280, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16); 16 | -------------------------------------------------------------------------------- /csrc/layer_norm/ln_parallel_fwd_1536.cu: -------------------------------------------------------------------------------- 1 | #include "ln_parallel_residual_fwd_kernels.cuh" 2 | 3 | // Create forward launch function and register. 
Macro signature: 4 | // HIDDEN_SIZE, WTYPE, ITYPE, RYTPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG 5 | 6 | REGISTER_PARALLEL_FWD_LAUNCHER( 1536, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16); 7 | REGISTER_PARALLEL_FWD_LAUNCHER( 1536, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16); 8 | REGISTER_PARALLEL_FWD_LAUNCHER( 1536, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16); 9 | REGISTER_PARALLEL_FWD_LAUNCHER( 1536, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16); 10 | REGISTER_PARALLEL_FWD_LAUNCHER( 1536, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16); 11 | REGISTER_PARALLEL_FWD_LAUNCHER( 1536, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16); 12 | REGISTER_PARALLEL_FWD_LAUNCHER( 1536, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16); 13 | REGISTER_PARALLEL_FWD_LAUNCHER( 1536, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16); 14 | REGISTER_PARALLEL_FWD_LAUNCHER( 1536, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16); 15 | REGISTER_PARALLEL_FWD_LAUNCHER( 1536, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16); 16 | -------------------------------------------------------------------------------- /csrc/layer_norm/ln_parallel_fwd_2048.cu: -------------------------------------------------------------------------------- 1 | #include "ln_parallel_residual_fwd_kernels.cuh" 2 | 3 | // Create forward launch function and register. Macro signature: 4 | // HIDDEN_SIZE, WTYPE, ITYPE, RYTPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG 5 | 6 | REGISTER_PARALLEL_FWD_LAUNCHER( 2048, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16); 7 | REGISTER_PARALLEL_FWD_LAUNCHER( 2048, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16); 8 | REGISTER_PARALLEL_FWD_LAUNCHER( 2048, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16); 9 | REGISTER_PARALLEL_FWD_LAUNCHER( 2048, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16); 10 | REGISTER_PARALLEL_FWD_LAUNCHER( 2048, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16); 11 | REGISTER_PARALLEL_FWD_LAUNCHER( 2048, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16); 12 | REGISTER_PARALLEL_FWD_LAUNCHER( 2048, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16); 13 | REGISTER_PARALLEL_FWD_LAUNCHER( 2048, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16); 14 | REGISTER_PARALLEL_FWD_LAUNCHER( 2048, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16); 15 | REGISTER_PARALLEL_FWD_LAUNCHER( 2048, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16); 16 | -------------------------------------------------------------------------------- /csrc/layer_norm/ln_parallel_fwd_2560.cu: -------------------------------------------------------------------------------- 1 | #include "ln_parallel_residual_fwd_kernels.cuh" 2 | 3 | // Create forward launch function and register. 
Macro signature: 4 | // HIDDEN_SIZE, WTYPE, ITYPE, RYTPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG 5 | 6 | REGISTER_PARALLEL_FWD_LAUNCHER( 2560, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16); 7 | REGISTER_PARALLEL_FWD_LAUNCHER( 2560, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16); 8 | REGISTER_PARALLEL_FWD_LAUNCHER( 2560, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16); 9 | REGISTER_PARALLEL_FWD_LAUNCHER( 2560, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16); 10 | REGISTER_PARALLEL_FWD_LAUNCHER( 2560, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16); 11 | REGISTER_PARALLEL_FWD_LAUNCHER( 2560, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16); 12 | REGISTER_PARALLEL_FWD_LAUNCHER( 2560, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16); 13 | REGISTER_PARALLEL_FWD_LAUNCHER( 2560, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16); 14 | REGISTER_PARALLEL_FWD_LAUNCHER( 2560, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16); 15 | REGISTER_PARALLEL_FWD_LAUNCHER( 2560, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16); 16 | -------------------------------------------------------------------------------- /csrc/layer_norm/ln_parallel_fwd_3072.cu: -------------------------------------------------------------------------------- 1 | #include "ln_parallel_residual_fwd_kernels.cuh" 2 | 3 | // Create forward launch function and register. Macro signature: 4 | // HIDDEN_SIZE, WTYPE, ITYPE, RYTPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG 5 | 6 | REGISTER_PARALLEL_FWD_LAUNCHER( 3072, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16); 7 | REGISTER_PARALLEL_FWD_LAUNCHER( 3072, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16); 8 | REGISTER_PARALLEL_FWD_LAUNCHER( 3072, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 16); 9 | REGISTER_PARALLEL_FWD_LAUNCHER( 3072, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 16); 10 | REGISTER_PARALLEL_FWD_LAUNCHER( 3072, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 16); 11 | REGISTER_PARALLEL_FWD_LAUNCHER( 3072, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 16); 12 | REGISTER_PARALLEL_FWD_LAUNCHER( 3072, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 16); 13 | REGISTER_PARALLEL_FWD_LAUNCHER( 3072, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 16); 14 | REGISTER_PARALLEL_FWD_LAUNCHER( 3072, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 16); 15 | REGISTER_PARALLEL_FWD_LAUNCHER( 3072, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 16); 16 | -------------------------------------------------------------------------------- /csrc/layer_norm/ln_parallel_fwd_4096.cu: -------------------------------------------------------------------------------- 1 | #include "ln_parallel_residual_fwd_kernels.cuh" 2 | 3 | // Create forward launch function and register. 
Macro signature: 4 | // HIDDEN_SIZE, WTYPE, ITYPE, RYTPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG 5 | 6 | REGISTER_PARALLEL_FWD_LAUNCHER( 4096, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16); 7 | REGISTER_PARALLEL_FWD_LAUNCHER( 4096, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16); 8 | REGISTER_PARALLEL_FWD_LAUNCHER( 4096, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 16); 9 | REGISTER_PARALLEL_FWD_LAUNCHER( 4096, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 16); 10 | REGISTER_PARALLEL_FWD_LAUNCHER( 4096, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 16); 11 | REGISTER_PARALLEL_FWD_LAUNCHER( 4096, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 16); 12 | REGISTER_PARALLEL_FWD_LAUNCHER( 4096, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 16); 13 | REGISTER_PARALLEL_FWD_LAUNCHER( 4096, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 16); 14 | REGISTER_PARALLEL_FWD_LAUNCHER( 4096, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 16); 15 | REGISTER_PARALLEL_FWD_LAUNCHER( 4096, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 16); 16 | -------------------------------------------------------------------------------- /csrc/layer_norm/ln_parallel_fwd_512.cu: -------------------------------------------------------------------------------- 1 | #include "ln_parallel_residual_fwd_kernels.cuh" 2 | 3 | // Create forward launch function and register. Macro signature: 4 | // HIDDEN_SIZE, WTYPE, ITYPE, RYTPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG 5 | 6 | REGISTER_PARALLEL_FWD_LAUNCHER( 512, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16); 7 | REGISTER_PARALLEL_FWD_LAUNCHER( 512, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16); 8 | REGISTER_PARALLEL_FWD_LAUNCHER( 512, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16); 9 | REGISTER_PARALLEL_FWD_LAUNCHER( 512, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16); 10 | REGISTER_PARALLEL_FWD_LAUNCHER( 512, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16); 11 | REGISTER_PARALLEL_FWD_LAUNCHER( 512, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16); 12 | REGISTER_PARALLEL_FWD_LAUNCHER( 512, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16); 13 | REGISTER_PARALLEL_FWD_LAUNCHER( 512, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16); 14 | REGISTER_PARALLEL_FWD_LAUNCHER( 512, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16); 15 | REGISTER_PARALLEL_FWD_LAUNCHER( 512, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16); 16 | -------------------------------------------------------------------------------- /csrc/layer_norm/ln_parallel_fwd_5120.cu: -------------------------------------------------------------------------------- 1 | #include "ln_parallel_residual_fwd_kernels.cuh" 2 | 3 | // Create forward launch function and register. 
Macro signature: 4 | // HIDDEN_SIZE, WTYPE, ITYPE, RYTPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG 5 | 6 | REGISTER_PARALLEL_FWD_LAUNCHER( 5120, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16); 7 | REGISTER_PARALLEL_FWD_LAUNCHER( 5120, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16); 8 | REGISTER_PARALLEL_FWD_LAUNCHER( 5120, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 16); 9 | REGISTER_PARALLEL_FWD_LAUNCHER( 5120, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 16); 10 | REGISTER_PARALLEL_FWD_LAUNCHER( 5120, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 16); 11 | REGISTER_PARALLEL_FWD_LAUNCHER( 5120, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 16); 12 | REGISTER_PARALLEL_FWD_LAUNCHER( 5120, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 16); 13 | REGISTER_PARALLEL_FWD_LAUNCHER( 5120, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 16); 14 | REGISTER_PARALLEL_FWD_LAUNCHER( 5120, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 16); 15 | REGISTER_PARALLEL_FWD_LAUNCHER( 5120, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 16); 16 | -------------------------------------------------------------------------------- /csrc/layer_norm/ln_parallel_fwd_6144.cu: -------------------------------------------------------------------------------- 1 | #include "ln_parallel_residual_fwd_kernels.cuh" 2 | 3 | // Create forward launch function and register. Macro signature: 4 | // HIDDEN_SIZE, WTYPE, ITYPE, RYTPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG 5 | 6 | REGISTER_PARALLEL_FWD_LAUNCHER( 6144, fp32, fp32, fp32, fp32, fp32, 1, 1, 8, 16); 7 | REGISTER_PARALLEL_FWD_LAUNCHER( 6144, fp16, fp32, fp32, fp32, fp32, 1, 1, 8, 16); 8 | REGISTER_PARALLEL_FWD_LAUNCHER( 6144, fp32, fp16, fp32, fp16, fp32, 1, 1, 8, 16); 9 | REGISTER_PARALLEL_FWD_LAUNCHER( 6144, fp16, fp16, fp32, fp16, fp32, 1, 1, 8, 16); 10 | REGISTER_PARALLEL_FWD_LAUNCHER( 6144, fp32, fp16, fp16, fp16, fp32, 1, 1, 8, 16); 11 | REGISTER_PARALLEL_FWD_LAUNCHER( 6144, fp32, bf16, fp32, bf16, fp32, 1, 1, 8, 16); 12 | REGISTER_PARALLEL_FWD_LAUNCHER( 6144, bf16, bf16, fp32, bf16, fp32, 1, 1, 8, 16); 13 | REGISTER_PARALLEL_FWD_LAUNCHER( 6144, fp32, bf16, bf16, bf16, fp32, 1, 1, 8, 16); 14 | REGISTER_PARALLEL_FWD_LAUNCHER( 6144, fp16, fp16, fp16, fp16, fp32, 1, 1, 8, 16); 15 | REGISTER_PARALLEL_FWD_LAUNCHER( 6144, bf16, bf16, bf16, bf16, fp32, 1, 1, 8, 16); 16 | -------------------------------------------------------------------------------- /csrc/layer_norm/ln_parallel_fwd_7168.cu: -------------------------------------------------------------------------------- 1 | #include "ln_parallel_residual_fwd_kernels.cuh" 2 | 3 | // Create forward launch function and register. 
Macro signature: 4 | // HIDDEN_SIZE, WTYPE, ITYPE, RYTPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG 5 | 6 | REGISTER_PARALLEL_FWD_LAUNCHER( 7168, fp32, fp32, fp32, fp32, fp32, 1, 1, 8, 16); 7 | REGISTER_PARALLEL_FWD_LAUNCHER( 7168, fp16, fp32, fp32, fp32, fp32, 1, 1, 8, 16); 8 | REGISTER_PARALLEL_FWD_LAUNCHER( 7168, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 16); 9 | REGISTER_PARALLEL_FWD_LAUNCHER( 7168, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 16); 10 | REGISTER_PARALLEL_FWD_LAUNCHER( 7168, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 16); 11 | REGISTER_PARALLEL_FWD_LAUNCHER( 7168, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 16); 12 | REGISTER_PARALLEL_FWD_LAUNCHER( 7168, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 16); 13 | REGISTER_PARALLEL_FWD_LAUNCHER( 7168, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 16); 14 | REGISTER_PARALLEL_FWD_LAUNCHER( 7168, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 16); 15 | REGISTER_PARALLEL_FWD_LAUNCHER( 7168, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 16); 16 | -------------------------------------------------------------------------------- /csrc/layer_norm/ln_parallel_fwd_768.cu: -------------------------------------------------------------------------------- 1 | #include "ln_parallel_residual_fwd_kernels.cuh" 2 | 3 | // Create forward launch function and register. Macro signature: 4 | // HIDDEN_SIZE, WTYPE, ITYPE, RYTPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG 5 | 6 | REGISTER_PARALLEL_FWD_LAUNCHER( 768, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16); 7 | REGISTER_PARALLEL_FWD_LAUNCHER( 768, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16); 8 | REGISTER_PARALLEL_FWD_LAUNCHER( 768, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16); 9 | REGISTER_PARALLEL_FWD_LAUNCHER( 768, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16); 10 | REGISTER_PARALLEL_FWD_LAUNCHER( 768, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16); 11 | REGISTER_PARALLEL_FWD_LAUNCHER( 768, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16); 12 | REGISTER_PARALLEL_FWD_LAUNCHER( 768, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16); 13 | REGISTER_PARALLEL_FWD_LAUNCHER( 768, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16); 14 | REGISTER_PARALLEL_FWD_LAUNCHER( 768, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16); 15 | REGISTER_PARALLEL_FWD_LAUNCHER( 768, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16); 16 | -------------------------------------------------------------------------------- /csrc/layer_norm/ln_parallel_fwd_8192.cu: -------------------------------------------------------------------------------- 1 | #include "ln_parallel_residual_fwd_kernels.cuh" 2 | 3 | // Create forward launch function and register. 
Macro signature: 4 | // HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG 5 | 6 | REGISTER_PARALLEL_FWD_LAUNCHER( 8192, fp32, fp32, fp32, fp32, fp32, 1, 1, 8, 16); 7 | REGISTER_PARALLEL_FWD_LAUNCHER( 8192, fp16, fp32, fp32, fp32, fp32, 1, 1, 8, 16); 8 | REGISTER_PARALLEL_FWD_LAUNCHER( 8192, fp32, fp16, fp32, fp16, fp32, 1, 1, 8, 16); 9 | REGISTER_PARALLEL_FWD_LAUNCHER( 8192, fp16, fp16, fp32, fp16, fp32, 1, 1, 8, 16); 10 | REGISTER_PARALLEL_FWD_LAUNCHER( 8192, fp32, fp16, fp16, fp16, fp32, 1, 1, 8, 16); 11 | REGISTER_PARALLEL_FWD_LAUNCHER( 8192, fp32, bf16, fp32, bf16, fp32, 1, 1, 8, 16); 12 | REGISTER_PARALLEL_FWD_LAUNCHER( 8192, bf16, bf16, fp32, bf16, fp32, 1, 1, 8, 16); 13 | REGISTER_PARALLEL_FWD_LAUNCHER( 8192, fp32, bf16, bf16, bf16, fp32, 1, 1, 8, 16); 14 | REGISTER_PARALLEL_FWD_LAUNCHER( 8192, fp16, fp16, fp16, fp16, fp32, 1, 1, 8, 16); 15 | REGISTER_PARALLEL_FWD_LAUNCHER( 8192, bf16, bf16, bf16, bf16, fp32, 1, 1, 8, 16); 16 | -------------------------------------------------------------------------------- /csrc/flash_attn_v3/cuda_check.h: -------------------------------------------------------------------------------- 1 | /****************************************************************************** 2 | * Copyright (c) 2024, Tri Dao. 3 | ******************************************************************************/ 4 | 5 | #pragma once 6 | 7 | #include <cstdio> 8 | #include <cstdlib> 9 | 10 | #define CHECK_CUDA(call) \ 11 | do { \ 12 | cudaError_t status_ = call; \ 13 | if (status_ != cudaSuccess) { \ 14 | fprintf(stderr, "CUDA error (%s:%d): %s\n", __FILE__, __LINE__, cudaGetErrorString(status_)); \ 15 | exit(1); \ 16 | } \ 17 | } while(0) 18 | 19 | #define CHECK_CUDA_KERNEL_LAUNCH() CHECK_CUDA(cudaGetLastError()) 20 | -------------------------------------------------------------------------------- /csrc/flashmask_v2/cuda_check.h: -------------------------------------------------------------------------------- 1 | /****************************************************************************** 2 | * Copyright (c) 2024, Tri Dao. 
3 | ******************************************************************************/ 4 | 5 | #pragma once 6 | 7 | #include <cstdio> 8 | #include <cstdlib> 9 | 10 | #define CHECK_CUDA(call) \ 11 | do { \ 12 | cudaError_t status_ = call; \ 13 | if (status_ != cudaSuccess) { \ 14 | fprintf(stderr, "CUDA error (%s:%d): %s\n", __FILE__, __LINE__, cudaGetErrorString(status_)); \ 15 | exit(1); \ 16 | } \ 17 | } while(0) 18 | 19 | #define CHECK_CUDA_KERNEL_LAUNCH() CHECK_CUDA(cudaGetLastError()) 20 | -------------------------------------------------------------------------------- /training/configs/experiment/owt/gpt2xl-flash.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - /experiment/owt/gpt2l-flash.yaml 4 | - override /model/gpt2model: gpt2-xlarge 5 | 6 | # Can enable mlp_checkpoint_lvl to fit to A100 40GB 7 | # model: 8 | # config: 9 | # # mlp_checkpoint_lvl: ${eval:"[1] * 18 + [2] * 18"} 10 | # mlp_checkpoint_lvl: 1 11 | 12 | datamodule: 13 | batch_size: ${eval:"2 if ${train.gpu_mem} < 24 else (4 if ${train.gpu_mem} < 40 else (8 if ${train.gpu_mem} < 80 else 16))"} 14 | # With adamw-zero optimizer, on A100 40GB: 15 | # checkpoint_lvl=1, batch size = 4: mem 37GB, 4650ms / batch of 512 (285ms * 15 + 375ms * 1) 16 | # checkpoint_lvl=1, batch size = 8: mem 46GB, 4330ms / batch of 512 (530ms * 7 + 620ms * 1) 17 | # checkpoint_lvl=2, batch size = 8: mem 41GB, 4570ms / batch of 512 (560ms * 7 + 650ms * 1) 18 | # With adamw-apex-distributed optimizer: 19 | # checkpoint_lvl=1, batch size = 8: mem 41.5GB, 4500ms / batch of 512 (550ms * 7 + 650ms * 1) 20 | # checkpoint_lvl=1 for 24 layers and checkpoint_lvl=2 for 24 layers, 21 | # batch size = 8: mem 39GB, 4640ms / batch of 512 (565ms * 7 + 675ms * 1) 22 | -------------------------------------------------------------------------------- /csrc/layer_norm/ln_parallel_bwd_2048.cu: -------------------------------------------------------------------------------- 1 | #include "ln_parallel_residual_bwd_kernels.cuh" 2 | 3 | // Create backward launch function and register. Macro signature: 4 | // HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL 5 | 6 | REGISTER_PARALLEL_BWD_LAUNCHER( 2048, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4); 7 | REGISTER_PARALLEL_BWD_LAUNCHER( 2048, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4); 8 | REGISTER_PARALLEL_BWD_LAUNCHER( 2048, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4); 9 | REGISTER_PARALLEL_BWD_LAUNCHER( 2048, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4); 10 | REGISTER_PARALLEL_BWD_LAUNCHER( 2048, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4); 11 | REGISTER_PARALLEL_BWD_LAUNCHER( 2048, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4); 12 | REGISTER_PARALLEL_BWD_LAUNCHER( 2048, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4); 13 | REGISTER_PARALLEL_BWD_LAUNCHER( 2048, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4); 14 | REGISTER_PARALLEL_BWD_LAUNCHER( 2048, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4); 15 | REGISTER_PARALLEL_BWD_LAUNCHER( 2048, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4); -------------------------------------------------------------------------------- /csrc/layer_norm/ln_parallel_bwd_3072.cu: -------------------------------------------------------------------------------- 1 | #include "ln_parallel_residual_bwd_kernels.cuh" 2 | 3 | // Create backward launch function and register. 
Macro signature: 4 | // HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL 5 | 6 | REGISTER_PARALLEL_BWD_LAUNCHER( 3072, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4); 7 | REGISTER_PARALLEL_BWD_LAUNCHER( 3072, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4); 8 | REGISTER_PARALLEL_BWD_LAUNCHER( 3072, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4); 9 | REGISTER_PARALLEL_BWD_LAUNCHER( 3072, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4); 10 | REGISTER_PARALLEL_BWD_LAUNCHER( 3072, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4); 11 | REGISTER_PARALLEL_BWD_LAUNCHER( 3072, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4); 12 | REGISTER_PARALLEL_BWD_LAUNCHER( 3072, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4); 13 | REGISTER_PARALLEL_BWD_LAUNCHER( 3072, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4); 14 | REGISTER_PARALLEL_BWD_LAUNCHER( 3072, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4); 15 | REGISTER_PARALLEL_BWD_LAUNCHER( 3072, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4); -------------------------------------------------------------------------------- /csrc/layer_norm/ln_parallel_bwd_6144.cu: -------------------------------------------------------------------------------- 1 | #include "ln_parallel_residual_bwd_kernels.cuh" 2 | 3 | // Create backward launch function and register. Macro signature: 4 | // HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL 5 | 6 | REGISTER_PARALLEL_BWD_LAUNCHER( 6144, fp32, fp32, fp32, fp32, fp32, 1, 1, 8, 16, 4); 7 | REGISTER_PARALLEL_BWD_LAUNCHER( 6144, fp16, fp32, fp32, fp32, fp32, 1, 1, 8, 16, 4); 8 | REGISTER_PARALLEL_BWD_LAUNCHER( 6144, fp32, fp16, fp32, fp16, fp32, 1, 1, 8, 16, 4); 9 | REGISTER_PARALLEL_BWD_LAUNCHER( 6144, fp16, fp16, fp32, fp16, fp32, 1, 1, 8, 16, 4); 10 | REGISTER_PARALLEL_BWD_LAUNCHER( 6144, fp32, fp16, fp16, fp16, fp32, 1, 1, 8, 16, 4); 11 | REGISTER_PARALLEL_BWD_LAUNCHER( 6144, fp32, bf16, fp32, bf16, fp32, 1, 1, 8, 16, 4); 12 | REGISTER_PARALLEL_BWD_LAUNCHER( 6144, bf16, bf16, fp32, bf16, fp32, 1, 1, 8, 16, 4); 13 | REGISTER_PARALLEL_BWD_LAUNCHER( 6144, fp32, bf16, bf16, bf16, fp32, 1, 1, 8, 16, 4); 14 | REGISTER_PARALLEL_BWD_LAUNCHER( 6144, fp16, fp16, fp16, fp16, fp32, 1, 1, 8, 16, 4); 15 | REGISTER_PARALLEL_BWD_LAUNCHER( 6144, bf16, bf16, bf16, bf16, fp32, 1, 1, 8, 16, 4); -------------------------------------------------------------------------------- /csrc/layer_norm/ln_parallel_bwd_7168.cu: -------------------------------------------------------------------------------- 1 | #include "ln_parallel_residual_bwd_kernels.cuh" 2 | 3 | // Create backward launch function and register. 
Macro signature: 4 | // HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL 5 | 6 | REGISTER_PARALLEL_BWD_LAUNCHER( 7168, fp32, fp32, fp32, fp32, fp32, 1, 1, 8, 16, 4); 7 | REGISTER_PARALLEL_BWD_LAUNCHER( 7168, fp16, fp32, fp32, fp32, fp32, 1, 1, 8, 16, 4); 8 | REGISTER_PARALLEL_BWD_LAUNCHER( 7168, fp32, fp16, fp32, fp16, fp32, 1, 1, 8, 8, 4); 9 | REGISTER_PARALLEL_BWD_LAUNCHER( 7168, fp16, fp16, fp32, fp16, fp32, 1, 1, 8, 8, 4); 10 | REGISTER_PARALLEL_BWD_LAUNCHER( 7168, fp32, fp16, fp16, fp16, fp32, 1, 1, 8, 8, 4); 11 | REGISTER_PARALLEL_BWD_LAUNCHER( 7168, fp32, bf16, fp32, bf16, fp32, 1, 1, 8, 8, 4); 12 | REGISTER_PARALLEL_BWD_LAUNCHER( 7168, bf16, bf16, fp32, bf16, fp32, 1, 1, 8, 8, 4); 13 | REGISTER_PARALLEL_BWD_LAUNCHER( 7168, fp32, bf16, bf16, bf16, fp32, 1, 1, 8, 8, 4); 14 | REGISTER_PARALLEL_BWD_LAUNCHER( 7168, fp16, fp16, fp16, fp16, fp32, 1, 1, 8, 8, 4); 15 | REGISTER_PARALLEL_BWD_LAUNCHER( 7168, bf16, bf16, bf16, bf16, fp32, 1, 1, 8, 8, 4); -------------------------------------------------------------------------------- /csrc/layer_norm/ln_parallel_bwd_8192.cu: -------------------------------------------------------------------------------- 1 | #include "ln_parallel_residual_bwd_kernels.cuh" 2 | 3 | // Create backward launch function and register. Macro signature: 4 | // HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL 5 | 6 | REGISTER_PARALLEL_BWD_LAUNCHER( 8192, fp32, fp32, fp32, fp32, fp32, 1, 1, 8, 16, 4); 7 | REGISTER_PARALLEL_BWD_LAUNCHER( 8192, fp16, fp32, fp32, fp32, fp32, 1, 1, 8, 16, 4); 8 | REGISTER_PARALLEL_BWD_LAUNCHER( 8192, fp32, fp16, fp32, fp16, fp32, 1, 1, 8, 16, 4); 9 | REGISTER_PARALLEL_BWD_LAUNCHER( 8192, fp16, fp16, fp32, fp16, fp32, 1, 1, 8, 16, 4); 10 | REGISTER_PARALLEL_BWD_LAUNCHER( 8192, fp32, fp16, fp16, fp16, fp32, 1, 1, 8, 16, 4); 11 | REGISTER_PARALLEL_BWD_LAUNCHER( 8192, fp32, bf16, fp32, bf16, fp32, 1, 1, 8, 16, 4); 12 | REGISTER_PARALLEL_BWD_LAUNCHER( 8192, bf16, bf16, fp32, bf16, fp32, 1, 1, 8, 16, 4); 13 | REGISTER_PARALLEL_BWD_LAUNCHER( 8192, fp32, bf16, bf16, bf16, fp32, 1, 1, 8, 16, 4); 14 | REGISTER_PARALLEL_BWD_LAUNCHER( 8192, fp16, fp16, fp16, fp16, fp32, 1, 1, 8, 16, 4); 15 | REGISTER_PARALLEL_BWD_LAUNCHER( 8192, bf16, bf16, bf16, bf16, fp32, 1, 1, 8, 16, 4); -------------------------------------------------------------------------------- /csrc/flash_attn_with_bias_and_mask/src/fmha_fwd_with_mask_bias_hdim32.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2022, Tri Dao. 
2 | 3 | #include "fmha_fwd_with_mask_bias_launch_template.h" 4 | 5 | bool run_fmha_fwd_with_mask_bias_hdim32(Launch_params &launch_params, 6 | const bool configure) { 7 | bool status = false; 8 | FP16_SWITCH(launch_params.params.is_bf16, ([&] { 9 | if (launch_params.params.seqlen_k == 128) { 10 | using Kernel_traits = FMHA_kernel_traits<128, 32, 16, 1, 4, 0x08u, elem_type>; 11 | status = run_fmha_fp16_sm80_loop_(launch_params, configure); 12 | } else if (launch_params.params.seqlen_k == 256) { 13 | using Kernel_traits = FMHA_kernel_traits<256, 32, 16, 1, 4, 0x08u, elem_type>; 14 | status = run_fmha_fp16_sm80_loop_(launch_params, configure); 15 | } else { 16 | using Kernel_traits = FMHA_kernel_traits<256, 32, 16, 1, 4, 0x08u, elem_type>; 17 | status = run_fmha_fp16_sm80_loop_(launch_params, configure); 18 | } 19 | })); 20 | return status; 21 | } 22 | -------------------------------------------------------------------------------- /csrc/layer_norm/ln_parallel_bwd_1536.cu: -------------------------------------------------------------------------------- 1 | #include "ln_parallel_residual_bwd_kernels.cuh" 2 | 3 | // Create backward launch function and register. Macro signature: 4 | // HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL 5 | 6 | REGISTER_PARALLEL_BWD_LAUNCHER( 1536, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4); 7 | REGISTER_PARALLEL_BWD_LAUNCHER( 1536, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4); 8 | REGISTER_PARALLEL_BWD_LAUNCHER( 1536, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 8, 4); 9 | REGISTER_PARALLEL_BWD_LAUNCHER( 1536, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 8, 4); 10 | REGISTER_PARALLEL_BWD_LAUNCHER( 1536, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 8, 4); 11 | REGISTER_PARALLEL_BWD_LAUNCHER( 1536, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 8, 4); 12 | REGISTER_PARALLEL_BWD_LAUNCHER( 1536, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 8, 4); 13 | REGISTER_PARALLEL_BWD_LAUNCHER( 1536, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 8, 4); 14 | REGISTER_PARALLEL_BWD_LAUNCHER( 1536, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 8, 4); 15 | REGISTER_PARALLEL_BWD_LAUNCHER( 1536, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 8, 4); 16 | -------------------------------------------------------------------------------- /csrc/layer_norm/ln_parallel_bwd_256.cu: -------------------------------------------------------------------------------- 1 | #include "ln_parallel_residual_bwd_kernels.cuh" 2 | 3 | // Create backward launch function and register. 
Macro signature: 4 | // HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL 5 | 6 | REGISTER_PARALLEL_BWD_LAUNCHER( 256, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4); 7 | REGISTER_PARALLEL_BWD_LAUNCHER( 256, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4); 8 | REGISTER_PARALLEL_BWD_LAUNCHER( 256, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4); 9 | REGISTER_PARALLEL_BWD_LAUNCHER( 256, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4); 10 | REGISTER_PARALLEL_BWD_LAUNCHER( 256, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4); 11 | REGISTER_PARALLEL_BWD_LAUNCHER( 256, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4); 12 | REGISTER_PARALLEL_BWD_LAUNCHER( 256, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4); 13 | REGISTER_PARALLEL_BWD_LAUNCHER( 256, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4); 14 | REGISTER_PARALLEL_BWD_LAUNCHER( 256, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4); 15 | REGISTER_PARALLEL_BWD_LAUNCHER( 256, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4); 16 | -------------------------------------------------------------------------------- /csrc/layer_norm/ln_parallel_bwd_2560.cu: -------------------------------------------------------------------------------- 1 | #include "ln_parallel_residual_bwd_kernels.cuh" 2 | 3 | // Create backward launch function and register. Macro signature: 4 | // HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL 5 | 6 | REGISTER_PARALLEL_BWD_LAUNCHER( 2560, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4); 7 | REGISTER_PARALLEL_BWD_LAUNCHER( 2560, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4); 8 | REGISTER_PARALLEL_BWD_LAUNCHER( 2560, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 8, 4); 9 | REGISTER_PARALLEL_BWD_LAUNCHER( 2560, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 8, 4); 10 | REGISTER_PARALLEL_BWD_LAUNCHER( 2560, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 8, 4); 11 | REGISTER_PARALLEL_BWD_LAUNCHER( 2560, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 8, 4); 12 | REGISTER_PARALLEL_BWD_LAUNCHER( 2560, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 8, 4); 13 | REGISTER_PARALLEL_BWD_LAUNCHER( 2560, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 8, 4); 14 | REGISTER_PARALLEL_BWD_LAUNCHER( 2560, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 8, 4); 15 | REGISTER_PARALLEL_BWD_LAUNCHER( 2560, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 8, 4); 16 | -------------------------------------------------------------------------------- /csrc/layer_norm/ln_parallel_bwd_512.cu: -------------------------------------------------------------------------------- 1 | #include "ln_parallel_residual_bwd_kernels.cuh" 2 | 3 | // Create backward launch function and register. 
Macro signature: 4 | // HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL 5 | 6 | REGISTER_PARALLEL_BWD_LAUNCHER( 512, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4); 7 | REGISTER_PARALLEL_BWD_LAUNCHER( 512, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4); 8 | REGISTER_PARALLEL_BWD_LAUNCHER( 512, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4); 9 | REGISTER_PARALLEL_BWD_LAUNCHER( 512, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4); 10 | REGISTER_PARALLEL_BWD_LAUNCHER( 512, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4); 11 | REGISTER_PARALLEL_BWD_LAUNCHER( 512, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4); 12 | REGISTER_PARALLEL_BWD_LAUNCHER( 512, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4); 13 | REGISTER_PARALLEL_BWD_LAUNCHER( 512, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4); 14 | REGISTER_PARALLEL_BWD_LAUNCHER( 512, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4); 15 | REGISTER_PARALLEL_BWD_LAUNCHER( 512, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4); 16 | -------------------------------------------------------------------------------- /csrc/layer_norm/ln_parallel_bwd_768.cu: -------------------------------------------------------------------------------- 1 | #include "ln_parallel_residual_bwd_kernels.cuh" 2 | 3 | // Create backward launch function and register. Macro signature: 4 | // HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL 5 | 6 | REGISTER_PARALLEL_BWD_LAUNCHER( 768, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4); 7 | REGISTER_PARALLEL_BWD_LAUNCHER( 768, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4); 8 | REGISTER_PARALLEL_BWD_LAUNCHER( 768, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4); 9 | REGISTER_PARALLEL_BWD_LAUNCHER( 768, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4); 10 | REGISTER_PARALLEL_BWD_LAUNCHER( 768, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4); 11 | REGISTER_PARALLEL_BWD_LAUNCHER( 768, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4); 12 | REGISTER_PARALLEL_BWD_LAUNCHER( 768, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4); 13 | REGISTER_PARALLEL_BWD_LAUNCHER( 768, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4); 14 | REGISTER_PARALLEL_BWD_LAUNCHER( 768, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4); 15 | REGISTER_PARALLEL_BWD_LAUNCHER( 768, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4); 16 | -------------------------------------------------------------------------------- /csrc/layer_norm/ln_parallel_bwd_1024.cu: -------------------------------------------------------------------------------- 1 | #include "ln_parallel_residual_bwd_kernels.cuh" 2 | 3 | // Create backward launch function and register. 
Macro signature: 4 | // HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL 5 | 6 | REGISTER_PARALLEL_BWD_LAUNCHER( 1024, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4); 7 | REGISTER_PARALLEL_BWD_LAUNCHER( 1024, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4); 8 | REGISTER_PARALLEL_BWD_LAUNCHER( 1024, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4); 9 | REGISTER_PARALLEL_BWD_LAUNCHER( 1024, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4); 10 | REGISTER_PARALLEL_BWD_LAUNCHER( 1024, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4); 11 | REGISTER_PARALLEL_BWD_LAUNCHER( 1024, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4); 12 | REGISTER_PARALLEL_BWD_LAUNCHER( 1024, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4); 13 | REGISTER_PARALLEL_BWD_LAUNCHER( 1024, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4); 14 | REGISTER_PARALLEL_BWD_LAUNCHER( 1024, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4); 15 | REGISTER_PARALLEL_BWD_LAUNCHER( 1024, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4); 16 | -------------------------------------------------------------------------------- /csrc/layer_norm/ln_parallel_bwd_1280.cu: -------------------------------------------------------------------------------- 1 | #include "ln_parallel_residual_bwd_kernels.cuh" 2 | 3 | // Create backward launch function and register. Macro signature: 4 | // HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL 5 | 6 | REGISTER_PARALLEL_BWD_LAUNCHER( 1280, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4); 7 | REGISTER_PARALLEL_BWD_LAUNCHER( 1280, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4); 8 | REGISTER_PARALLEL_BWD_LAUNCHER( 1280, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4); 9 | REGISTER_PARALLEL_BWD_LAUNCHER( 1280, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4); 10 | REGISTER_PARALLEL_BWD_LAUNCHER( 1280, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4); 11 | REGISTER_PARALLEL_BWD_LAUNCHER( 1280, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4); 12 | REGISTER_PARALLEL_BWD_LAUNCHER( 1280, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4); 13 | REGISTER_PARALLEL_BWD_LAUNCHER( 1280, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4); 14 | REGISTER_PARALLEL_BWD_LAUNCHER( 1280, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4); 15 | REGISTER_PARALLEL_BWD_LAUNCHER( 1280, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4); 16 | -------------------------------------------------------------------------------- /csrc/layer_norm/ln_parallel_bwd_4096.cu: -------------------------------------------------------------------------------- 1 | #include "ln_parallel_residual_bwd_kernels.cuh" 2 | 3 | // Create backward launch function and register. 
Macro signature: 4 | // HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL 5 | 6 | // Use 8 warps otherwise there's a lot of register spilling 7 | 8 | REGISTER_PARALLEL_BWD_LAUNCHER( 4096, fp32, fp32, fp32, fp32, fp32, 1, 1, 8, 16, 4); 9 | REGISTER_PARALLEL_BWD_LAUNCHER( 4096, fp16, fp32, fp32, fp32, fp32, 1, 1, 8, 16, 4); 10 | REGISTER_PARALLEL_BWD_LAUNCHER( 4096, fp32, fp16, fp32, fp16, fp32, 1, 1, 8, 16, 4); 11 | REGISTER_PARALLEL_BWD_LAUNCHER( 4096, fp16, fp16, fp32, fp16, fp32, 1, 1, 8, 16, 4); 12 | REGISTER_PARALLEL_BWD_LAUNCHER( 4096, fp32, fp16, fp16, fp16, fp32, 1, 1, 8, 16, 4); 13 | REGISTER_PARALLEL_BWD_LAUNCHER( 4096, fp32, bf16, fp32, bf16, fp32, 1, 1, 8, 16, 4); 14 | REGISTER_PARALLEL_BWD_LAUNCHER( 4096, bf16, bf16, fp32, bf16, fp32, 1, 1, 8, 16, 4); 15 | REGISTER_PARALLEL_BWD_LAUNCHER( 4096, fp32, bf16, bf16, bf16, fp32, 1, 1, 8, 16, 4); 16 | REGISTER_PARALLEL_BWD_LAUNCHER( 4096, fp16, fp16, fp16, fp16, fp32, 1, 1, 8, 16, 4); 17 | REGISTER_PARALLEL_BWD_LAUNCHER( 4096, bf16, bf16, bf16, bf16, fp32, 1, 1, 8, 16, 4); -------------------------------------------------------------------------------- /csrc/layer_norm/ln_parallel_bwd_5120.cu: -------------------------------------------------------------------------------- 1 | #include "ln_parallel_residual_bwd_kernels.cuh" 2 | 3 | // Create backward launch function and register. Macro signature: 4 | // HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL 5 | 6 | // Use 8 warps otherwise there's a lot of register spilling 7 | 8 | REGISTER_PARALLEL_BWD_LAUNCHER( 5120, fp32, fp32, fp32, fp32, fp32, 1, 1, 8, 16, 4); 9 | REGISTER_PARALLEL_BWD_LAUNCHER( 5120, fp16, fp32, fp32, fp32, fp32, 1, 1, 8, 16, 4); 10 | REGISTER_PARALLEL_BWD_LAUNCHER( 5120, fp32, fp16, fp32, fp16, fp32, 1, 1, 8, 8, 4); 11 | REGISTER_PARALLEL_BWD_LAUNCHER( 5120, fp16, fp16, fp32, fp16, fp32, 1, 1, 8, 8, 4); 12 | REGISTER_PARALLEL_BWD_LAUNCHER( 5120, fp32, fp16, fp16, fp16, fp32, 1, 1, 8, 8, 4); 13 | REGISTER_PARALLEL_BWD_LAUNCHER( 5120, fp32, bf16, fp32, bf16, fp32, 1, 1, 8, 8, 4); 14 | REGISTER_PARALLEL_BWD_LAUNCHER( 5120, bf16, bf16, fp32, bf16, fp32, 1, 1, 8, 8, 4); 15 | REGISTER_PARALLEL_BWD_LAUNCHER( 5120, fp32, bf16, bf16, bf16, fp32, 1, 1, 8, 8, 4); 16 | REGISTER_PARALLEL_BWD_LAUNCHER( 5120, fp16, fp16, fp16, fp16, fp32, 1, 1, 8, 8, 4); 17 | REGISTER_PARALLEL_BWD_LAUNCHER( 5120, bf16, bf16, bf16, bf16, fp32, 1, 1, 8, 8, 4); -------------------------------------------------------------------------------- /csrc/fused_softmax/type_shim.h: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #define DISPATCH_HALF_AND_BFLOAT(TYPE, NAME, ...) 
\ 4 | switch(TYPE) \ 5 | { \ 6 | case at::ScalarType::Half: \ 7 | { \ 8 | using scalar_t = at::Half; \ 9 | __VA_ARGS__; \ 10 | break; \ 11 | } \ 12 | case at::ScalarType::BFloat16: \ 13 | { \ 14 | using scalar_t = at::BFloat16; \ 15 | __VA_ARGS__; \ 16 | break; \ 17 | } \ 18 | default: \ 19 | AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \ 20 | } 21 | -------------------------------------------------------------------------------- /training/configs/trainer/all_params.yaml: -------------------------------------------------------------------------------- 1 | _target_: pytorch_lightning.Trainer 2 | 3 | # default values for all trainer parameters 4 | checkpoint_callback: True 5 | default_root_dir: null 6 | gradient_clip_val: 0.0 7 | process_position: 0 8 | num_nodes: 1 9 | num_processes: 1 10 | gpus: null 11 | auto_select_gpus: False 12 | tpu_cores: null 13 | log_gpu_memory: null 14 | overfit_batches: 0.0 15 | track_grad_norm: -1 16 | check_val_every_n_epoch: 1 17 | fast_dev_run: False 18 | accumulate_grad_batches: 1 19 | max_epochs: 1 20 | min_epochs: 1 21 | max_steps: null 22 | min_steps: null 23 | limit_train_batches: 1.0 24 | limit_val_batches: 1.0 25 | limit_test_batches: 1.0 26 | val_check_interval: 1.0 27 | flush_logs_every_n_steps: 100 28 | log_every_n_steps: 50 29 | accelerator: null 30 | sync_batchnorm: False 31 | precision: 32 32 | weights_summary: "top" 33 | weights_save_path: null 34 | num_sanity_val_steps: 2 35 | truncated_bptt_steps: null 36 | resume_from_checkpoint: null 37 | profiler: null 38 | benchmark: False 39 | deterministic: False 40 | reload_dataloaders_every_epoch: False 41 | auto_lr_find: False 42 | replace_sampler_ddp: True 43 | terminate_on_nan: False 44 | auto_scale_batch_size: False 45 | prepare_data_per_node: True 46 | plugins: null 47 | amp_backend: "native" 48 | amp_level: "O2" 49 | move_metrics_to_cpu: False 50 | -------------------------------------------------------------------------------- /training/src/optim/timm_lr_scheduler.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.optim import Optimizer 3 | 4 | from timm.scheduler import CosineLRScheduler 5 | 6 | 7 | # We need to subclass torch.optim.lr_scheduler._LRScheduler, or Pytorch-lightning will complain 8 | class TimmCosineLRScheduler(CosineLRScheduler, torch.optim.lr_scheduler._LRScheduler): 9 | """ Wrap timm.scheduler.CosineLRScheduler so we can call scheduler.step() without passing in epoch. 10 | It supports resuming as well. 11 | """ 12 | 13 | def __init__(self, *args, **kwargs): 14 | super().__init__(*args, **kwargs) 15 | self._last_epoch = -1 16 | self.step(epoch=0) 17 | 18 | def step(self, epoch=None): 19 | if epoch is None: 20 | self._last_epoch += 1 21 | else: 22 | self._last_epoch = epoch 23 | # We call either step or step_update, depending on whether we're using the scheduler every 24 | # epoch or every step. 25 | # Otherwise, lightning will always call step (i.e., meant for each epoch), and if we set 26 | # scheduler interval to "step", then the learning rate update will be wrong. 
27 | if self.t_in_epochs: 28 | super().step(epoch=self._last_epoch) 29 | else: 30 | super().step_update(num_updates=self._last_epoch) 31 | -------------------------------------------------------------------------------- /training/src/datamodules/datasets/lm_dataset.py: -------------------------------------------------------------------------------- 1 | # Inspired by https://github.com/NVIDIA/Megatron-LM/blob/main/tasks/zeroshot_gpt/datasets.py 2 | # Except we don't pad the last block and don't use overlapping eval 3 | # And we return both the input and the target 4 | import math 5 | import numpy as np 6 | 7 | import torch 8 | 9 | 10 | class LMDataset(torch.utils.data.Dataset): 11 | 12 | def __init__(self, tokens, seq_len, drop_last=True): 13 | """tokens should be a numpy array 14 | """ 15 | self.seq_len = seq_len 16 | ntokens = len(tokens) 17 | if drop_last: 18 | ntokens = ((ntokens - 1) // seq_len) * seq_len + 1 19 | self.ntokens = ntokens 20 | # We're careful not to slice tokens, since it could be a memmap'ed array or H5 dataset, 21 | # and slicing would load it to memory. 22 | self.tokens = tokens 23 | self.total_sequences = math.ceil((self.ntokens - 1) / self.seq_len) 24 | 25 | def __len__(self): 26 | return self.total_sequences 27 | 28 | def __getitem__(self, idx): 29 | start_idx = idx * self.seq_len 30 | seq_len = min(self.seq_len, self.ntokens - 1 - start_idx) 31 | data = torch.as_tensor(self.tokens[start_idx:(start_idx + seq_len + 1)].astype(np.int64)) 32 | return data[:-1], data[1:].clone() 33 | -------------------------------------------------------------------------------- /csrc/flash_attn_with_bias_and_mask/src/cuda_utils.cu: -------------------------------------------------------------------------------- 1 | #include "cuda_utils.h" 2 | 3 | #include <cuda.h> 4 | #include <cuda_runtime.h> 5 | #include <memory> 6 | #include <mutex> 7 | #include <vector> 8 | 9 | static std::once_flag g_device_props_size_init_flag; 10 | static std::vector<std::unique_ptr<std::once_flag>> g_device_props_init_flags; 11 | static std::vector<cudaDeviceProp> g_device_props; 12 | 13 | static int GetCurrentDeviceId() { 14 | int device_id; 15 | FMHA_CHECK_CUDA(cudaGetDevice(&device_id)); 16 | return device_id; 17 | } 18 | 19 | static int GetCudaDeviceCount() { 20 | int count; 21 | FMHA_CHECK_CUDA(cudaGetDeviceCount(&count)); 22 | return count; 23 | } 24 | 25 | cudaDeviceProp* GetDeviceProperties(int id) { 26 | std::call_once(g_device_props_size_init_flag, [&] { 27 | int gpu_num = 0; 28 | gpu_num = GetCudaDeviceCount(); 29 | g_device_props_init_flags.resize(gpu_num); 30 | g_device_props.resize(gpu_num); 31 | for (int i = 0; i < gpu_num; ++i) { 32 | g_device_props_init_flags[i] = std::make_unique<std::once_flag>(); 33 | } 34 | }); 35 | 36 | if (id == -1) { 37 | id = GetCurrentDeviceId(); 38 | } 39 | 40 | std::call_once(*(g_device_props_init_flags[id]), [&] { 41 | FMHA_CHECK_CUDA(cudaGetDeviceProperties(&g_device_props[id], id)); 42 | }); 43 | 44 | return &g_device_props[id]; 45 | } 46 | 47 | -------------------------------------------------------------------------------- /csrc/fused_dense_lib/setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | import subprocess 3 | from packaging.version import parse, Version 4 | 5 | import torch 6 | from setuptools import setup 7 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME 8 | 9 | 10 | def get_cuda_bare_metal_version(cuda_dir): 11 | raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True) 12 | output = raw_output.split() 13 | release_idx = 
output.index("release") + 1 14 | bare_metal_version = parse(output[release_idx].split(",")[0]) 15 | 16 | return raw_output, bare_metal_version 17 | 18 | 19 | def append_nvcc_threads(nvcc_extra_args): 20 | _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME) 21 | if bare_metal_version >= Version("11.2"): 22 | return nvcc_extra_args + ["--threads", "4"] 23 | return nvcc_extra_args 24 | 25 | 26 | setup( 27 | name='fused_dense_lib', 28 | ext_modules=[ 29 | CUDAExtension( 30 | name='fused_dense_lib', 31 | sources=['fused_dense.cpp', 'fused_dense_cuda.cu'], 32 | extra_compile_args={ 33 | 'cxx': ['-O3',], 34 | 'nvcc': append_nvcc_threads(['-O3']) 35 | } 36 | ) 37 | ], 38 | cmdclass={ 39 | 'build_ext': BuildExtension 40 | }) 41 | 42 | -------------------------------------------------------------------------------- /csrc/layer_norm/static_switch.h: -------------------------------------------------------------------------------- 1 | // Inspired by https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h 2 | // and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h 3 | 4 | #pragma once 5 | 6 | /// @param COND - a boolean expression to switch by 7 | /// @param CONST_NAME - a name given for the constexpr bool variable. 8 | /// @param ... - code to execute for true and false 9 | /// 10 | /// Usage: 11 | /// ``` 12 | /// BOOL_SWITCH(flag, BoolConst, [&] { 13 | /// some_function(...); 14 | /// }); 15 | /// ``` 16 | #define BOOL_SWITCH(COND, CONST_NAME, ...) \ 17 | [&] { \ 18 | if (COND) { \ 19 | constexpr bool CONST_NAME = true; \ 20 | return __VA_ARGS__(); \ 21 | } else { \ 22 | constexpr bool CONST_NAME = false; \ 23 | return __VA_ARGS__(); \ 24 | } \ 25 | }() 26 | -------------------------------------------------------------------------------- /csrc/flash_attn/src/cuda_utils.cu: -------------------------------------------------------------------------------- 1 | #include "cuda_utils.h" 2 | 3 | #include <cuda.h> 4 | #include <cuda_runtime.h> 5 | #include <memory> 6 | #include <mutex> 7 | #include <vector> 8 | 9 | #if !FLASH_ATTN_WITH_TORCH 10 | namespace at { 11 | namespace cuda { 12 | static std::once_flag g_device_props_size_init_flag; 13 | static std::vector<std::unique_ptr<std::once_flag>> g_device_props_init_flags; 14 | static std::vector<cudaDeviceProp> g_device_props; 15 | 16 | static int GetCurrentDeviceId() { 17 | int device_id; 18 | C10_CUDA_CHECK(cudaGetDevice(&device_id)); 19 | return device_id; 20 | } 21 | 22 | static int GetCudaDeviceCount() { 23 | int count; 24 | C10_CUDA_CHECK(cudaGetDeviceCount(&count)); 25 | return count; 26 | } 27 | 28 | cudaDeviceProp* getCurrentDeviceProperties(int id) { 29 | std::call_once(g_device_props_size_init_flag, [&] { 30 | int gpu_num = 0; 31 | gpu_num = GetCudaDeviceCount(); 32 | g_device_props_init_flags.resize(gpu_num); 33 | g_device_props.resize(gpu_num); 34 | for (int i = 0; i < gpu_num; ++i) { 35 | g_device_props_init_flags[i] = std::make_unique<std::once_flag>(); 36 | } 37 | }); 38 | 39 | if (id == -1) { 40 | id = GetCurrentDeviceId(); 41 | } 42 | 43 | std::call_once(*(g_device_props_init_flags[id]), [&] { 44 | C10_CUDA_CHECK(cudaGetDeviceProperties(&g_device_props[id], id)); 45 | }); 46 | 47 | return &g_device_props[id]; 48 | } 49 | } // namespace cuda 50 | } // namespace at 51 | #endif 52 | -------------------------------------------------------------------------------- /csrc/flash_attn_with_bias_and_mask/src/fmha_fwd_with_mask_bias_hdim128.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2022, Tri Dao. 
2 | 3 | #include "fmha_fwd_with_mask_bias_launch_template.h" 4 | 5 | bool run_fmha_fwd_with_mask_bias_hdim128(Launch_params &launch_params, 6 | const bool configure) { 7 | bool status = true; 8 | auto dprops = GetDeviceProperties(-1); 9 | FP16_SWITCH(launch_params.params.is_bf16, ([&] { 10 | if (launch_params.params.seqlen_k == 128) { 11 | using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 4, 0x08u, elem_type>; 12 | status = run_fmha_fp16_sm80_loop_(launch_params, configure); 13 | } else { 14 | if (dprops->major == 8 && dprops->minor == 0 && !launch_params.is_dropout) { 15 | // TD [2022-06-05] Keep K in registers to reduce register spilling 16 | // Gives about 6% speedup compared to using block size 128. 17 | using Kernel_traits = FMHA_kernel_traits<256, 128, 16, 1, 4, 0x18u, elem_type>; 18 | status = run_fmha_fp16_sm80_loop_(launch_params, configure); 19 | } else { // Need to use the same block size as backward 20 | using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 4, 0x08u, elem_type>; 21 | status = run_fmha_fp16_sm80_loop_(launch_params, configure); 22 | } 23 | } 24 | })); 25 | return status; 26 | } 27 | -------------------------------------------------------------------------------- /training/configs/config.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # specify here default training configuration 4 | defaults: 5 | - _self_ 6 | - trainer: default 7 | - optimizer: adamw 8 | - scheduler: null 9 | - task: sequence-model 10 | - model: null 11 | - datamodule: null 12 | - callbacks: default # set this to null if you don't want to use callbacks 13 | - metrics: null 14 | - logger: null # set logger here or use command line (e.g. `python run.py logger=wandb`) 15 | 16 | - mode: default 17 | 18 | - experiment: null 19 | - hparams_search: null 20 | 21 | # enable color logging 22 | - override hydra/hydra_logging: colorlog 23 | - override hydra/job_logging: colorlog 24 | 25 | # path to original working directory 26 | # hydra hijacks working directory by changing it to the current log directory, 27 | # so it's useful to have this path as a special variable 28 | # https://hydra.cc/docs/next/tutorials/basic/running_your_app/working_directory 29 | work_dir: ${hydra:runtime.cwd} 30 | 31 | # path to folder with data 32 | data_dir: ${work_dir}/data/ 33 | 34 | # pretty print config at the start of the run using Rich library 35 | print_config: True 36 | 37 | # disable python warnings if they annoy you 38 | ignore_warnings: True 39 | 40 | # check performance on test set, using the best model achieved during training 41 | # lightning chooses best model based on metric specified in checkpoint callback 42 | test_after_training: True 43 | 44 | resume: False 45 | 46 | # seed for random number generators in pytorch, numpy and python.random 47 | seed: null 48 | 49 | # name of the run, accessed by loggers 50 | name: null 51 | -------------------------------------------------------------------------------- /training/src/callbacks/loss_scale_monitor.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/Lightning-AI/lightning/blob/master/src/pytorch_lightning/callbacks/lr_monitor.py. 
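# A rough sketch of where the logged numbers come from (values are illustrative only):
#     scaler = torch.cuda.amp.GradScaler()
#     scaler.get_scale()            # current loss scale, e.g. 65536.0
#     scaler._get_growth_tracker()  # steps since the scale last changed (same private call used below)
# The callback reads these from the trainer's precision plugin (or the loss scale from the DeepSpeed
# optimizer) and forwards them to the experiment loggers.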
2 | from typing import Any 3 | 4 | from pytorch_lightning import Callback, Trainer 5 | from pytorch_lightning.utilities import rank_zero_only 6 | from pytorch_lightning.strategies import DeepSpeedStrategy 7 | 8 | 9 | class LossScaleMonitor(Callback): 10 | """Monitor the loss scale for AMP (fp16). 11 | """ 12 | 13 | # Use on_before_optimizer_step instead of on_train_batch_start since there might be 14 | # gradient accumulation and we only care about the loss scale when it could change (i.e., 15 | # optimizer.step). 16 | @rank_zero_only 17 | def on_before_optimizer_step(self, trainer: Trainer, *args: Any, **kwargs: Any) -> None: 18 | if not trainer._logger_connector.should_update_logs: 19 | return 20 | stats = {} 21 | if isinstance(trainer.strategy, DeepSpeedStrategy): 22 | stats = {'scalar/scale': trainer.model.optimizer.loss_scale} 23 | if hasattr(trainer, 'precision_plugin') and hasattr(trainer.precision_plugin, 'scaler'): 24 | scaler = trainer.precision_plugin.scaler 25 | if scaler is not None: 26 | stats = { 27 | 'scaler/scale': scaler.get_scale(), 28 | 'scaler/growth_tracker': scaler._get_growth_tracker(), 29 | } 30 | if stats and trainer.loggers is not None: 31 | for logger in trainer.loggers: 32 | logger.log_metrics(stats, step=trainer.fit_loop.epoch_loop._batches_that_stepped) 33 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2022, the respective contributors, as shown by the AUTHORS file. 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | * Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | * Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | * Neither the name of the copyright holder nor the names of its 17 | contributors may be used to endorse or promote products derived from 18 | this software without specific prior written permission. 19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
30 | -------------------------------------------------------------------------------- /training/src/callbacks/params_log.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | from pytorch_lightning import Callback, Trainer, LightningModule 4 | from pytorch_lightning.utilities import rank_zero_only 5 | from pytorch_lightning.utilities.parsing import AttributeDict 6 | 7 | 8 | class ParamsLog(Callback): 9 | """Log the number of parameters of the model 10 | """ 11 | def __init__(self, total_params_log: bool = True, trainable_params_log: bool = True, 12 | non_trainable_params_log: bool = True): 13 | super().__init__() 14 | self._log_stats = AttributeDict( 15 | { 16 | 'total_params_log': total_params_log, 17 | 'trainable_params_log': trainable_params_log, 18 | 'non_trainable_params_log': non_trainable_params_log, 19 | } 20 | ) 21 | 22 | @rank_zero_only 23 | def on_fit_start(self, trainer: Trainer, pl_module: LightningModule) -> None: 24 | logs = {} 25 | if self._log_stats.total_params_log: 26 | logs["model/params_total"] = sum(p.numel() for p in pl_module.parameters()) 27 | if self._log_stats.trainable_params_log: 28 | logs["model/params_trainable"] = sum(p.numel() for p in pl_module.parameters() 29 | if p.requires_grad) 30 | if self._log_stats.non_trainable_params_log: 31 | logs["model/params_not_trainable"] = sum(p.numel() for p in pl_module.parameters() 32 | if not p.requires_grad) 33 | if trainer.logger is not None: 34 | trainer.logger.log_hyperparams(logs) 35 | -------------------------------------------------------------------------------- /training/src/callbacks/gpu_affinity.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from pytorch_lightning import Callback, Trainer, LightningModule 4 | 5 | import logging 6 | 7 | log = logging.getLogger(__name__) # We want a logger for each process, not just the rank 0 8 | 9 | 10 | def l2_promote(): 11 | import ctypes 12 | _libcudart = ctypes.CDLL('libcudart.so') 13 | # Set device limit on the current device 14 | # cudaLimitMaxL2FetchGranularity = 0x05 15 | pValue = ctypes.cast((ctypes.c_int*1)(), ctypes.POINTER(ctypes.c_int)) 16 | _libcudart.cudaDeviceSetLimit(ctypes.c_int(0x05), ctypes.c_int(128)) 17 | _libcudart.cudaDeviceGetLimit(pValue, ctypes.c_int(0x05)) 18 | assert pValue.contents.value == 128 19 | 20 | 21 | def set_affinity(trainer): 22 | try: 23 | from src.utils.gpu_affinity import set_affinity 24 | nproc_per_node = torch.cuda.device_count() 25 | affinity = set_affinity(trainer.local_rank, nproc_per_node, 'socket_unique_continuous') 26 | log.info(f'{trainer.local_rank}: thread affinity: {affinity}') 27 | # TD [2022-05-07] Somehow calling this causes GPU 0 to allocate extra ~800MB of memory per 28 | # number of GPUs (e.g., 6.4GB of extra memory in a 8-GPU setup). H/t Dan. 29 | # l2_promote() 30 | except: 31 | pass 32 | 33 | 34 | class GpuAffinity(Callback): 35 | """Set GPU affinity and increase the L2 fetch granularity. 
36 | Adapted from https://github.com/NVIDIA/DeepLearningExamples/tree/master/PyTorch/LanguageModeling/Transformer-XL 37 | """ 38 | 39 | def setup(self, trainer: Trainer, pl_module: LightningModule, stage=None) -> None: 40 | set_affinity(trainer) 41 | -------------------------------------------------------------------------------- /csrc/flash_attn_with_bias_and_mask/src/fmha_fwd_with_mask_bias_hdim64.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2022, Tri Dao. 2 | 3 | #include "fmha_fwd_with_mask_bias_launch_template.h" 4 | 5 | bool run_fmha_fwd_with_mask_bias_hdim64(Launch_params &launch_params, 6 | const bool configure) { 7 | bool status = true; 8 | auto dprops = GetDeviceProperties(-1); 9 | FP16_SWITCH(launch_params.params.is_bf16, ([&] { 10 | if (launch_params.params.seqlen_k == 128) { 11 | using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u, elem_type>; 12 | status = run_fmha_fp16_sm80_loop_(launch_params, configure); 13 | } else if (launch_params.params.seqlen_k >= 256) { 14 | if (dprops->major == 8 && dprops->minor >= 0) { 15 | using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 4, 0x08u, elem_type>; 16 | status = run_fmha_fp16_sm80_loop_(launch_params, configure); 17 | } else if (dprops->major == 7 && dprops->minor == 5) { 18 | if (launch_params.is_dropout) { // Need to use the same block size as backward 19 | using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u, elem_type>; 20 | status = run_fmha_fp16_sm80_loop_(launch_params, configure); 21 | } else { 22 | using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 4, 0x08u, elem_type>; 23 | status = run_fmha_fp16_sm80_loop_(launch_params, configure); 24 | } 25 | } 26 | } 27 | })); 28 | return status; 29 | } 30 | -------------------------------------------------------------------------------- /training/configs/callbacks/default.yaml: -------------------------------------------------------------------------------- 1 | # rich_progress_bar: 2 | # _target_: pytorch_lightning.callbacks.RichProgressBar 3 | 4 | rich_model_summary: 5 | _target_: pytorch_lightning.callbacks.RichModelSummary 6 | 7 | model_checkpoint: 8 | _target_: pytorch_lightning.callbacks.ModelCheckpoint 9 | monitor: "val/acc" # name of the logged metric which determines when model is improving 10 | mode: "max" # can be "max" or "min" 11 | save_top_k: 1 # save k best models (determined by above metric) 12 | save_last: True # additionally always save model from last epoch 13 | verbose: False 14 | dirpath: ${oc.env:CHECKPOINT_DIR,checkpoints}/${oc.select:name,''} 15 | filename: "epoch_{epoch:03d}" 16 | auto_insert_metric_name: False 17 | 18 | early_stopping: 19 | _target_: pytorch_lightning.callbacks.EarlyStopping 20 | monitor: "val/acc" # name of the logged metric which determines when model is improving 21 | mode: "max" # can be "max" or "min" 22 | patience: 100 # how many epochs of not improving until training stops 23 | min_delta: 0 # minimum change in the monitored metric needed to qualify as an improvement 24 | 25 | learning_rate_monitor: 26 | _target_: pytorch_lightning.callbacks.LearningRateMonitor 27 | logging_interval: step 28 | 29 | speed_monitor: 30 | _target_: src.callbacks.speed_monitor.SpeedMonitor 31 | intra_step_time: True 32 | inter_step_time: True 33 | epoch_time: True 34 | 35 | loss_scale_monitor: 36 | _target_: src.callbacks.loss_scale_monitor.LossScaleMonitor 37 | 38 | params_log: 39 | _target_: src.callbacks.params_log.ParamsLog 40 | total_params_log: True 41 | 
trainable_params_log: True 42 | non_trainable_params_log: True 43 | 44 | gpu_affinity: 45 | _target_: src.callbacks.gpu_affinity.GpuAffinity 46 | -------------------------------------------------------------------------------- /tests/losses/test_cross_entropy.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | import pytest 6 | 7 | from einops import rearrange 8 | 9 | from flash_attn.losses.cross_entropy import CrossEntropyLossApex 10 | 11 | is_sm8x = torch.cuda.get_device_capability('cuda')[0] >= 8 12 | 13 | 14 | @pytest.mark.parametrize('dtype', [torch.float16, torch.float32] + ([torch.bfloat16] if is_sm8x else [])) 15 | # @pytest.mark.parametrize('dtype', [torch.float16]) 16 | @pytest.mark.parametrize('inplace_backward', [False, True]) 17 | # @pytest.mark.parametrize('inplace_backward', [False]) 18 | @pytest.mark.parametrize('smoothing', [0.0, 0.9]) 19 | @pytest.mark.parametrize('vocab_size', [50257]) 20 | def test_cross_entropy_loss_apex(vocab_size, smoothing, inplace_backward, dtype): 21 | device = 'cuda' 22 | rtol, atol = (1e-5, 1e-6) if dtype == torch.float32 else (1e-3, 1e-4) 23 | # set seed 24 | torch.random.manual_seed(0) 25 | batch_size = 8 26 | seqlen = 128 27 | x_pt = torch.randn(batch_size * seqlen, vocab_size, device=device, dtype=dtype, requires_grad=True) 28 | x = x_pt.detach().clone().requires_grad_() 29 | y = torch.randint(0, vocab_size, (batch_size * seqlen,), dtype=torch.long, device=device) 30 | y[torch.randperm(batch_size * seqlen)[:10]] = -100 31 | model_pt = torch.nn.CrossEntropyLoss(label_smoothing=smoothing) 32 | model = CrossEntropyLossApex(label_smoothing=smoothing, inplace_backward=inplace_backward) 33 | out = model(x, y) 34 | out_pt = model_pt(x_pt.float(), y) 35 | assert torch.allclose(out, out_pt, rtol=rtol, atol=atol) 36 | 37 | g = torch.randn_like(out) 38 | out_pt.backward(g) 39 | out.backward(g) 40 | assert torch.allclose(x.grad, x_pt.grad, rtol=rtol, atol=atol) 41 | -------------------------------------------------------------------------------- /csrc/flash_attn_with_bias_and_mask/src/fmha_bwd_with_mask_bias_hdim64.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2022, Tri Dao. 2 | 3 | #include "fmha_bwd_with_mask_bias_launch_template.h" 4 | 5 | bool run_fmha_bwd_with_mask_bias_hdim64(FMHA_dgrad_params ¶ms, cudaStream_t stream) { 6 | bool status = true; 7 | auto dprops = GetDeviceProperties(-1); 8 | FP16_SWITCH(params.is_bf16, ([&] { 9 | if( params.seqlen_k == 128 ) { 10 | using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 8, 0x08u, elem_type>; 11 | status = run_fmha_dgrad_fp16_sm80_loop_(params, stream); 12 | } else if( params.seqlen_k >= 256 ) { 13 | if (dprops->major == 8 && dprops->minor == 0) { 14 | // Don't share smem for K & V, and don't keep V in registers 15 | // This speeds things up by 2-3% by avoiding register spills, but it 16 | // uses more shared memory, which is fine on A100 but not other GPUs. 17 | // For other GPUs, we keep V in registers. 
18 | using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 8, 0x100u, elem_type>; 19 | status = run_fmha_dgrad_fp16_sm80_loop_(params, stream); 20 | } else if (dprops->major == 8 && dprops->minor > 0) { 21 | using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 8, 0x08u, elem_type>; 22 | status = run_fmha_dgrad_fp16_sm80_loop_(params, stream); 23 | } else if (dprops->major == 7 && dprops->minor == 5) { 24 | using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 8, 0x08u, elem_type>; 25 | status = run_fmha_dgrad_fp16_sm80_loop_(params, stream); 26 | } 27 | } 28 | })); 29 | return status; 30 | } 31 | -------------------------------------------------------------------------------- /csrc/flash_attn_with_bias_and_mask/src/fmha_bwd_hdim64.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2022, Tri Dao. 2 | 3 | // Splitting the different head dimensions to different files to speed up compilation. 4 | 5 | #include "fmha_bwd_launch_template.h" 6 | #include "cuda_utils.h" 7 | 8 | void run_fmha_bwd_hdim64(FMHA_dgrad_params ¶ms, cudaStream_t stream, const bool configure) { 9 | FP16_SWITCH(params.is_bf16, ([&] { 10 | auto dprops = GetDeviceProperties(-1); 11 | if (params.seqlen_k == 128) { 12 | using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 8, 0x08u, elem_type>; 13 | run_fmha_bwd_loop(params, stream, configure); 14 | } else if (params.seqlen_k >= 256) { 15 | if (dprops->major == 8 && dprops->minor == 0) { 16 | // Don't share smem for K & V, and don't keep V in registers 17 | // This speeds things up by 2-3% by avoiding register spills, but it 18 | // uses more shared memory, which is fine on A100 but not other GPUs. 19 | // For other GPUs, we keep V in registers. 20 | using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 8, 0x100u, elem_type>; 21 | run_fmha_bwd_loop(params, stream, configure); 22 | } else if (dprops->major == 8 && dprops->minor > 0) { 23 | using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 8, 0x08u, elem_type>; 24 | run_fmha_bwd_loop(params, stream, configure); 25 | } else if (dprops->major == 7 && dprops->minor == 5) { 26 | using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 8, 0x08u, elem_type>; 27 | run_fmha_bwd_loop(params, stream, configure); 28 | } 29 | } 30 | })); 31 | } 32 | -------------------------------------------------------------------------------- /csrc/fused_softmax/setup.py: -------------------------------------------------------------------------------- 1 | # Copied from https://github.com/NVIDIA/apex/tree/master/csrc/megatron 2 | # We add the case where seqlen = 4k and seqlen = 8k 3 | import os 4 | import subprocess 5 | 6 | import torch 7 | from setuptools import setup 8 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME 9 | 10 | 11 | def get_cuda_bare_metal_version(cuda_dir): 12 | raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True) 13 | output = raw_output.split() 14 | release_idx = output.index("release") + 1 15 | release = output[release_idx].split(".") 16 | bare_metal_major = release[0] 17 | bare_metal_minor = release[1][0] 18 | 19 | return raw_output, bare_metal_major, bare_metal_minor 20 | 21 | 22 | def append_nvcc_threads(nvcc_extra_args): 23 | _, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(CUDA_HOME) 24 | if int(bare_metal_major) >= 11 and int(bare_metal_minor) >= 2: 25 | return nvcc_extra_args + ["--threads", "4"] 26 | return nvcc_extra_args 27 | 28 | 29 | cc_flag = [] 30 | 
cc_flag.append("-gencode") 31 | cc_flag.append("arch=compute_70,code=sm_70") 32 | cc_flag.append("-gencode") 33 | cc_flag.append("arch=compute_80,code=sm_80") 34 | 35 | setup( 36 | name='fused_softmax_lib', 37 | ext_modules=[ 38 | CUDAExtension( 39 | name='fused_softmax_lib', 40 | sources=['fused_softmax.cpp', 'scaled_masked_softmax_cuda.cu', 'scaled_upper_triang_masked_softmax_cuda.cu'], 41 | extra_compile_args={ 42 | 'cxx': ['-O3',], 43 | 'nvcc': append_nvcc_threads(['-O3', '--use_fast_math'] + cc_flag) 44 | } 45 | ) 46 | ], 47 | cmdclass={ 48 | 'build_ext': BuildExtension 49 | }) 50 | -------------------------------------------------------------------------------- /training/src/metrics/num_tokens.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, Optional 2 | 3 | import torch 4 | from torch import Tensor 5 | 6 | from torchmetrics import Metric 7 | 8 | 9 | class NumTokens(Metric): 10 | """Keep track of how many tokens we've seen. 11 | """ 12 | # TODO: how do we prevent the reset between the epochs? The reset happens on the 1st batch 13 | # of the next epoch. 14 | # Right now the hack is that we override reset(), which would mess up the forward method. 15 | # We then override forward to do the right thing. 16 | 17 | is_differentiable = False 18 | higher_is_better = False 19 | full_state_update = False 20 | count: Tensor 21 | 22 | def __init__(self, **kwargs: Dict[str, Any]): 23 | super().__init__(**kwargs) 24 | self.add_state("count", default=torch.tensor(0, dtype=torch.int64), dist_reduce_fx="sum", 25 | persistent=True) # We want the count to be saved to state-dict 26 | 27 | def update(self, preds: Tensor, target: Tensor, loss: Optional[Tensor] = None) -> None: # type: ignore 28 | self.count += target.numel() 29 | 30 | def compute(self) -> Tensor: 31 | return self.count 32 | 33 | def reset(self): 34 | count = self.count 35 | super().reset() 36 | self.count = count 37 | 38 | # Adapted from https://github.com/Lightning-AI/metrics/blob/master/src/torchmetrics/metric.py 39 | def _forward_reduce_state_update(self, *args: Any, **kwargs: Any) -> Any: 40 | """forward computation using single call to `update` to calculate the metric value on the current batch and 41 | accumulate global state. 42 | This can be done when the global metric state is a sinple reduction of batch states. 
43 | """ 44 | self.update(*args, **kwargs) 45 | return self.compute() 46 | -------------------------------------------------------------------------------- /csrc/flash_attn/src/cuda_utils.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include "cuda.h" 5 | #include "cuda_runtime.h" 6 | 7 | #if !FLASH_ATTN_WITH_TORCH 8 | //////////////////////////////////////////////////////////////////////////////////////////////////// 9 | 10 | #define C10_CUDA_CHECK( call ) \ 11 | do { \ 12 | cudaError_t status_ = call; \ 13 | if( status_ != cudaSuccess ) { \ 14 | fprintf( stderr, \ 15 | "CUDA error (%s:%d): %s\n", \ 16 | __FILE__, \ 17 | __LINE__, \ 18 | cudaGetErrorString( status_ ) ); \ 19 | exit( 1 ); \ 20 | } \ 21 | } while( 0 ) 22 | 23 | #define C10_CUDA_KERNEL_LAUNCH_CHECK() C10_CUDA_CHECK(cudaGetLastError()) 24 | 25 | //////////////////////////////////////////////////////////////////////////////////////////////////// 26 | 27 | namespace at { 28 | namespace cuda { 29 | static int GetCurrentDeviceId(); 30 | 31 | static int GetCudaDeviceCount(); 32 | 33 | cudaDeviceProp* getCurrentDeviceProperties(int id = -1); 34 | } // namespace cuda 35 | } // namespace at 36 | #endif 37 | -------------------------------------------------------------------------------- /tests/test_rotary.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | import pytest 6 | 7 | from einops import rearrange 8 | 9 | from flash_attn.layers.rotary import apply_rotary_emb_func, apply_rotary_emb_torch 10 | 11 | 12 | is_sm8x = torch.cuda.get_device_capability('cuda') >= (8, 0) 13 | 14 | @pytest.mark.parametrize('dtype', ([torch.float16] if not is_sm8x else [torch.float16, torch.bfloat16])) 15 | # @pytest.mark.parametrize('dtype', ([torch.float16])) 16 | @pytest.mark.parametrize('rotary_fraction', [1.0, 0.5]) 17 | # @pytest.mark.parametrize('rotary_fraction', [0.5]) 18 | @pytest.mark.parametrize('inplace', [False, True]) 19 | # @pytest.mark.parametrize('inplace', [False]) 20 | def test_rotary_single_tensor(inplace, rotary_fraction, dtype): 21 | rtol = 1e-3 22 | batch_size = 32 23 | nheads = 4 24 | seqlen = 217 25 | headdim = 128 26 | x = torch.randn(batch_size, seqlen, nheads, headdim, dtype=dtype, device='cuda', 27 | requires_grad=True) 28 | x_pt = x.detach().clone().requires_grad_() 29 | rotary_dim = int(rotary_fraction * headdim) 30 | assert rotary_dim % 2 == 0 31 | angle = torch.randn(seqlen, rotary_dim // 2, device='cuda') 32 | cos = torch.cos(angle).to(dtype=dtype) 33 | sin = torch.sin(angle).to(dtype=dtype) 34 | out = apply_rotary_emb_func(x, cos, sin, inplace) 35 | out_pt = apply_rotary_emb_torch(x_pt, cos, sin) 36 | # Numerical error if we just do any arithmetic 37 | atol = ((out + 0.3 - 0.3) - out).abs().max().item() 38 | assert torch.allclose(out, out_pt, rtol=rtol, atol=2 * atol) 39 | g = torch.randn_like(out) 40 | g_pt = g.clone() # If inplace=True, we might modify the gradient inplace 41 | out.backward(g) 42 | out_pt.backward(g_pt) 43 | atol = ((x_pt.grad + 0.3 - 0.3) - x_pt.grad).abs().max().item() 44 | assert torch.allclose(x.grad, x_pt.grad, rtol=rtol, atol=2 * atol) 45 | -------------------------------------------------------------------------------- /csrc/flash_attn/src/block_info.h: -------------------------------------------------------------------------------- 1 | 
/****************************************************************************** 2 | * Copyright (c) 2023, Tri Dao. 3 | ******************************************************************************/ 4 | 5 | #pragma once 6 | 7 | namespace flash { 8 | 9 | //////////////////////////////////////////////////////////////////////////////////////////////////// 10 | 11 | template 12 | struct BlockInfo { 13 | 14 | template 15 | __device__ BlockInfo(const Params ¶ms, const int bidb) 16 | : sum_s_q(!Varlen || params.varlen_padded_input || params.cu_seqlens_q == nullptr ? -1 : params.cu_seqlens_q[bidb]) 17 | , sum_s_k(!Varlen || params.varlen_padded_input || params.cu_seqlens_k == nullptr ? -1 : params.cu_seqlens_k[bidb]) 18 | , actual_seqlen_q(!Varlen || params.cu_seqlens_q == nullptr ? params.seqlen_q : params.cu_seqlens_q[bidb + 1] - params.cu_seqlens_q[bidb]) 19 | , actual_seqlen_k(!Varlen || params.cu_seqlens_k == nullptr ? params.seqlen_k : params.cu_seqlens_k[bidb + 1] - params.cu_seqlens_k[bidb]) 20 | { 21 | } 22 | 23 | template 24 | inline __device__ index_t q_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const { 25 | return sum_s_q == -1 ? bidb * batch_stride : uint32_t(sum_s_q) * row_stride; 26 | } 27 | 28 | template 29 | inline __device__ index_t k_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const { 30 | return sum_s_k == -1 ? bidb * batch_stride : uint32_t(sum_s_k) * row_stride; 31 | } 32 | 33 | const int sum_s_q; 34 | const int sum_s_k; 35 | const uint32_t actual_seqlen_q; 36 | const uint32_t actual_seqlen_k; 37 | }; 38 | 39 | //////////////////////////////////////////////////////////////////////////////////////////////////// 40 | 41 | } // namespace flash 42 | -------------------------------------------------------------------------------- /csrc/rotary/rotary.cpp: -------------------------------------------------------------------------------- 1 | /****************************************************************************** 2 | * Copyright (c) 2023, Tri Dao. 3 | ******************************************************************************/ 4 | 5 | #include 6 | #include 7 | 8 | #define CHECK_DEVICE(x) TORCH_CHECK(x.device().type() == torch::kCUDA, #x " must be on CUDA") 9 | #define CHECK_SHAPE(x, ...) 
TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")") 10 | 11 | void apply_rotary_cuda(const torch::Tensor x1, const torch::Tensor x2, 12 | const torch::Tensor cos, const torch::Tensor sin, 13 | torch::Tensor out1, torch::Tensor out2, 14 | const bool conj); 15 | 16 | void apply_rotary(const torch::Tensor x1, const torch::Tensor x2, 17 | const torch::Tensor cos, const torch::Tensor sin, 18 | torch::Tensor out1, torch::Tensor out2, 19 | const bool conj) { 20 | CHECK_DEVICE(x1); CHECK_DEVICE(x2); 21 | CHECK_DEVICE(cos); CHECK_DEVICE(sin); 22 | CHECK_DEVICE(out1); CHECK_DEVICE(out2); 23 | TORCH_CHECK(x1.dtype() == x2.dtype()); 24 | TORCH_CHECK(cos.dtype() == sin.dtype()); 25 | TORCH_CHECK(out1.dtype() == out2.dtype()); 26 | TORCH_CHECK(x1.dtype() == cos.dtype()); 27 | TORCH_CHECK(x1.dtype() == out1.dtype()); 28 | TORCH_CHECK(x1.sizes() == x2.sizes()); 29 | TORCH_CHECK(cos.sizes() == sin.sizes()); 30 | TORCH_CHECK(out1.sizes() == out2.sizes()); 31 | 32 | // Otherwise the kernel will be launched from cuda:0 device 33 | // Cast to char to avoid compiler warning about narrowing 34 | at::cuda::CUDAGuard device_guard{(char)x1.get_device()}; 35 | 36 | apply_rotary_cuda(x1, x2, cos, sin, out1, out2, conj); 37 | } 38 | 39 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 40 | m.def("apply_rotary", &apply_rotary, "Apply rotary embedding"); 41 | } 42 | -------------------------------------------------------------------------------- /csrc/flash_attn_v3/copy_sm90_bulk_reduce.hpp: -------------------------------------------------------------------------------- 1 | /****************************************************************************** 2 | * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. 
3 | ******************************************************************************/ 4 | 5 | #pragma once 6 | 7 | #include 8 | 9 | namespace cute 10 | { 11 | 12 | //////////////////////////////////////////////////////////////////////////////////////////////////// 13 | 14 | struct SM90_BULK_REDUCE_ADD 15 | { 16 | CUTE_HOST_DEVICE static void 17 | copy(float const* smem_ptr, 18 | float * gmem_ptr, int32_t store_bytes) 19 | { 20 | #if defined(CUTE_ARCH_TMA_SM90_ENABLED) 21 | uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); 22 | asm volatile("cp.reduce.async.bulk.global.shared::cta.bulk_group.add.f32 [%0], [%1], %2;\n" 23 | : 24 | : "l"(gmem_ptr), "r"(smem_int_ptr), "r"(store_bytes) 25 | : "memory"); 26 | #else 27 | CUTE_INVALID_CONTROL_PATH("Trying to use BULK_REDUCE_ADD without CUTE_ARCH_TMA_SM90_ENABLED."); 28 | #endif 29 | } 30 | 31 | CUTE_HOST_DEVICE static void 32 | copy(float const* smem_ptr, 33 | float * gmem_ptr, int32_t store_bytes, uint64_t cache_hint) 34 | { 35 | #if defined(CUTE_ARCH_TMA_SM90_ENABLED) 36 | uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); 37 | asm volatile("cp.reduce.async.bulk.global.shared::cta.bulk_group.L2::cache_hint.add.f32 [%0], [%1], %2, %3;\n" 38 | : 39 | : "l"(gmem_ptr), "r"(smem_int_ptr), "r"(store_bytes), "l"(cache_hint) 40 | : "memory"); 41 | #else 42 | CUTE_INVALID_CONTROL_PATH("Trying to use BULK_REDUCE_ADD without CUTE_ARCH_TMA_SM90_ENABLED."); 43 | #endif 44 | } 45 | }; 46 | 47 | //////////////////////////////////////////////////////////////////////////////////////////////////// 48 | 49 | } // end namespace cute 50 | -------------------------------------------------------------------------------- /csrc/flashmask_v2/copy_sm90_bulk_reduce.hpp: -------------------------------------------------------------------------------- 1 | /****************************************************************************** 2 | * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. 
3 | ******************************************************************************/ 4 | 5 | #pragma once 6 | 7 | #include 8 | 9 | namespace cute 10 | { 11 | 12 | //////////////////////////////////////////////////////////////////////////////////////////////////// 13 | 14 | struct SM90_BULK_REDUCE_ADD 15 | { 16 | CUTE_HOST_DEVICE static void 17 | copy(float const* smem_ptr, 18 | float * gmem_ptr, int32_t store_bytes) 19 | { 20 | #if defined(CUTE_ARCH_TMA_SM90_ENABLED) 21 | uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); 22 | asm volatile("cp.reduce.async.bulk.global.shared::cta.bulk_group.add.f32 [%0], [%1], %2;\n" 23 | : 24 | : "l"(gmem_ptr), "r"(smem_int_ptr), "r"(store_bytes) 25 | : "memory"); 26 | #else 27 | CUTE_INVALID_CONTROL_PATH("Trying to use BULK_REDUCE_ADD without CUTE_ARCH_TMA_SM90_ENABLED."); 28 | #endif 29 | } 30 | 31 | CUTE_HOST_DEVICE static void 32 | copy(float const* smem_ptr, 33 | float * gmem_ptr, int32_t store_bytes, uint64_t cache_hint) 34 | { 35 | #if defined(CUTE_ARCH_TMA_SM90_ENABLED) 36 | uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); 37 | asm volatile("cp.reduce.async.bulk.global.shared::cta.bulk_group.L2::cache_hint.add.f32 [%0], [%1], %2, %3;\n" 38 | : 39 | : "l"(gmem_ptr), "r"(smem_int_ptr), "r"(store_bytes), "l"(cache_hint) 40 | : "memory"); 41 | #else 42 | CUTE_INVALID_CONTROL_PATH("Trying to use BULK_REDUCE_ADD without CUTE_ARCH_TMA_SM90_ENABLED."); 43 | #endif 44 | } 45 | }; 46 | 47 | //////////////////////////////////////////////////////////////////////////////////////////////////// 48 | 49 | } // end namespace cute 50 | -------------------------------------------------------------------------------- /training/configs/experiment/owt/gpt2l-flash.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - /experiment/owt/gpt2m-flash.yaml 4 | - override /model/gpt2model: gpt2-large 5 | # TD [2022-08-03] Surprisingly it's faster to use the ZeRO optimizer than just AdamW. 6 | # Still, fairscale is even faster and uses less memory. 7 | # I think it's because Pytorch is using ZeRO stage 1 and fairscale is using ZeRO stage 2? 8 | # However, fairscale has issues with saving checkpoint (either OOM or very 9 | # slow since it goes through the CPU?). Fairscale says Pytorch ZeRO is the 10 | # upstream version of OSS 11 | # https://github.com/facebookresearch/fairscale/issues/937 12 | # Pytorch ZeRO as also very slow for saving checkpoints due to 13 | # consolidate_state_dict(), but I've fixed it to save separate checkpoint per GPU. 14 | - override /optimizer: adamw-zero 15 | 16 | # FusedAdam doesn't seem to speed things up here, time per global step 17 | # (i.e. batch size 512) on 8 A100s is around 2056ms for both AdamW and FusedAdam. 18 | # This could be because each GPU is only doing the optimizer step for 1 / 19 | # world_size of the parameters. 20 | # Maybe the bottleneck here is the NCCL call to exchange parameters (ZeRO). 
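# (If you'd rather not edit this file, the same switch can presumably be made from the command line
# in the usual Hydra way, e.g. `python run.py experiment=owt/gpt2l-flash optimizer=adamw-apex-zero`;
# the commented-out override below is the in-file equivalent.)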
21 | # - override /optimizer: adamw-apex-zero 22 | 23 | # Can enable mlp_checkpoint_lvl to fit batch_size 16 on A100 40GB 24 | # model: 25 | # config: 26 | # # mlp_checkpoint_lvl: ${eval:"[1] * 18 + [2] * 18"} 27 | # mlp_checkpoint_lvl: 1 28 | 29 | datamodule: 30 | # batch_size: 16 31 | batch_size: ${eval:"4 if ${train.gpu_mem} < 24 else (8 if ${train.gpu_mem} < 40 else (16 if ${train.gpu_mem} < 80 else 32))"} 32 | 33 | trainer: 34 | # strategy: null 35 | # strategy: ${eval:"None if ${trainer.devices} == 1 else 'ddp_sharded'"} 36 | strategy: 37 | _target_: src.utils.ddp_zero1.DDPStrategyZero1 38 | find_unused_parameters: False 39 | gradient_as_bucket_view: True 40 | # TD [2022-08-03] Deepspeed makes the ppl curve go wild 41 | # strategy: deepspeed_stage_1 42 | -------------------------------------------------------------------------------- /training/src/distributed/ddp_comm_hooks.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://pytorch.org/docs/stable/_modules/torch/distributed/algorithms/ddp_comm_hooks/default_hooks.html 2 | # We divide by world_size first before converting to fp16, so it's safer. 3 | from typing import Any, Callable 4 | 5 | import torch 6 | import torch.distributed as dist 7 | 8 | 9 | def fp16_compress_hook( 10 | process_group: dist.ProcessGroup, bucket: dist.GradBucket 11 | ) -> torch.futures.Future[torch.Tensor]: 12 | """ 13 | This DDP communication hook implements a simple gradient compression 14 | approach that divides the ``GradBucket`` tensor by the process group size 15 | and then casts it to half-precision floating-point format (``torch.float16``). 16 | It allreduces those ``float16`` gradient tensors. Once compressed gradient 17 | tensors are allreduced, the chained callback ``decompress`` casts them back to the input data type (such as ``float32``). 18 | 19 | Example:: 20 | >>> ddp_model.register_comm_hook(process_group, fp16_compress_hook) 21 | """ 22 | group_to_use = process_group if process_group is not None else dist.group.WORLD 23 | world_size = group_to_use.size() 24 | 25 | # Divide first before converting to fp16 26 | # Use out argument to fuse the division and the conversion. 27 | compressed_tensor = torch.div(bucket.buffer(), world_size, 28 | out=torch.empty_like(bucket.buffer(), dtype=torch.float16)) 29 | 30 | fut = dist.all_reduce( 31 | compressed_tensor, group=group_to_use, async_op=True 32 | ).get_future() 33 | 34 | def decompress(fut): 35 | decompressed_tensor = bucket.buffer() 36 | # Decompress in place to reduce the peak memory. 37 | # See: https://github.com/pytorch/pytorch/issues/45968 38 | decompressed_tensor.copy_(fut.value()[0]) 39 | return decompressed_tensor 40 | 41 | # TODO: maybe have a backoff strategy: check if the buffer has inf / NaN, in that case 42 | # resend with fp32? 
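# Why divide before casting (worked example, numbers purely illustrative): fp16 overflows above
# ~65504, so with world_size = 8 and per-rank gradient entries around 1e4, allreducing in fp16 first
# would give ~8e4 -> inf, whereas dividing first keeps each summand near 1.25e3, safely in range.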
43 | return fut.then(decompress) 44 | -------------------------------------------------------------------------------- /tests/models/test_vit.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | import torch 4 | import pytest 5 | 6 | from timm.models.vision_transformer import vit_base_patch16_224 7 | 8 | from flash_attn.models.vit import vit_base_patch16_224 as flash_vit_base_patch16_224 9 | 10 | 11 | @pytest.mark.parametrize('fused_mlp', [False, True]) 12 | # @pytest.mark.parametrize('fused_mlp', [False]) 13 | @pytest.mark.parametrize('optimized', [False, True]) 14 | # @pytest.mark.parametrize('optimized', [True]) 15 | def test_vit(optimized, fused_mlp): 16 | """Check that our implementation of ViT matches the timm's implementation: 17 | the output of our forward pass in fp16 should be around the same as 18 | timm' forward pass in fp16, when compared to timm's forward pass in fp32. 19 | """ 20 | dtype = torch.float16 21 | device = 'cuda' 22 | 23 | kwargs = {} 24 | if optimized: 25 | kwargs = dict(use_flash_attn=True, fused_bias_fc=True, fused_dropout_add_ln=True) 26 | kwargs['fused_mlp'] = fused_mlp 27 | model = flash_vit_base_patch16_224(**kwargs).to(device=device, dtype=dtype) 28 | 29 | model_ref = vit_base_patch16_224(pretrained=True).to(device=device) 30 | model_timm = vit_base_patch16_224(pretrained=True).to(device=device, dtype=dtype) 31 | 32 | model.load_state_dict(model_ref.state_dict()) 33 | 34 | model.eval() 35 | model_ref.eval() 36 | model_timm.eval() 37 | 38 | torch.manual_seed(0) 39 | batch_size = 2 40 | x = torch.randn(batch_size, 3, 224, 224, device=device, dtype=dtype) 41 | out = model(x) 42 | out_timm = model_timm(x) 43 | out_ref = model_ref(x.float()) 44 | 45 | print(f'Output max diff: {(out - out_ref).abs().max().item()}') 46 | print(f'Output mean diff: {(out - out_ref).abs().mean().item()}') 47 | print(f'timm fp16 max diff: {(out_timm - out_ref).abs().max().item()}') 48 | print(f'timm fp16 mean diff: {(out_timm - out_ref).abs().mean().item()}') 49 | rtol = 2 if not fused_mlp else 4 50 | assert (out - out_ref).abs().max().item() < rtol * (out_timm - out_ref).abs().max().item() 51 | -------------------------------------------------------------------------------- /training/src/callbacks/model_checkpoint.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/Lightning-AI/lightning/blob/master/src/pytorch_lightning/callbacks/fault_tolerance.py 2 | from typing import Any 3 | from pathlib import Path 4 | 5 | import pytorch_lightning as pl 6 | 7 | 8 | class ModelCheckpointMine(pl.callbacks.model_checkpoint.ModelCheckpoint): 9 | 10 | def __init__(self, *args, fault_tolerant=False, **kwargs): 11 | super().__init__(*args, **kwargs) 12 | self.fault_tolerant = fault_tolerant 13 | 14 | def on_exception(self, trainer: "pl.Trainer", *_: Any, **__: Any) -> None: 15 | if self.fault_tolerant: 16 | # overwrite if necessary 17 | trainer.save_checkpoint(str(Path(self.dirpath) / '.pl_auto_save.ckpt')) 18 | 19 | # def teardown(self, trainer: "pl.Trainer", *_: Any, **__: Any) -> None: 20 | # if self.fault_tolerant: 21 | # trainer.strategy.remove_checkpoint(str(Path(self.dirpath) / '.pl_auto_save.ckpt')) 22 | 23 | 24 | # TD [2022-07-17] I was trying to make resuming from standard checkpoint fault-tolerant. 25 | # However, when it resumes it's off by 1 iteration. My attempt to fix it in seq.py (below) didn't work. 
26 | # So I decided to just copy _FaultToleranceCheckpoint and just save on_exception. 27 | 28 | # def on_save_checkpoint(self, checkpoint): 29 | # # TD [2022-07-12] The "completed" counter is off by 1 so when it resumes 30 | # # it's off by 1 iteration. However, the data is still off by 1 iteration, probably 31 | # # because the dataloader_state_dict['counter'] is off by @batch_size, and idk how 32 | # # to fix it cleanly. 33 | # checkpoint['loops']['fit_loop']['epoch_loop.batch_progress']['total']['completed'] += 1 34 | # checkpoint['loops']['fit_loop']['epoch_loop.batch_progress']['current']['completed'] += 1 35 | # checkpoint['loops']['fit_loop']['epoch_loop.state_dict']['_batches_that_stepped'] += 1 36 | # checkpoint['loops']['fit_loop']['epoch_loop.state_dict']['dataloader_state_dict'][0]['state'][0]['num_batches_fetched'] += 1 37 | -------------------------------------------------------------------------------- /csrc/flash_attn_with_bias_and_mask/src/utils.cu: -------------------------------------------------------------------------------- 1 | #include "utils.h" 2 | 3 | void SetZero(void *ptr, size_t sizeof_type, std::initializer_list shapes, cudaStream_t stream) { 4 | size_t n = sizeof_type; 5 | for (int s : shapes) n *= s; 6 | FMHA_CHECK_CUDA(cudaMemsetAsync(ptr, 0, n, stream)); 7 | } 8 | 9 | template 10 | static __global__ void FillConstantKernel(T *ptr, T value, size_t n) { 11 | auto idx = static_cast(blockDim.x) * blockIdx.x + threadIdx.x; 12 | if (idx < n) { 13 | ptr[idx] = value; 14 | } 15 | } 16 | 17 | template 18 | void SetConstValue(void *ptr, T value, size_t n, cudaStream_t stream) { 19 | constexpr auto kNumThreads = 1024; 20 | auto block = (n + kNumThreads - 1) / kNumThreads; 21 | FillConstantKernel<<>>(static_cast(ptr), value, n); 22 | } 23 | 24 | template 25 | void SetConstValue(void *ptr, float value, size_t n, cudaStream_t stream); 26 | 27 | static __global__ void _float2half(float *float_ptr, __half *half_ptr, size_t n) { 28 | const int idx = threadIdx.x + blockDim.x * blockIdx.x; 29 | if (idx < n) { 30 | half_ptr[idx] = __float2half(float_ptr[idx]); 31 | } 32 | } 33 | 34 | void Float2Half(void *float_ptr, void *half_ptr, size_t n, cudaStream_t stream) { 35 | constexpr auto kNumThreads = 1024; 36 | auto block = (n + kNumThreads - 1) / kNumThreads; 37 | _float2half<<>>(static_cast(float_ptr), static_cast<__half *>(half_ptr), n); 38 | } 39 | 40 | static __global__ void _float2bfloat16(float *float_ptr, __nv_bfloat16 *bf16_ptr, size_t n) { 41 | const int idx = threadIdx.x + blockDim.x * blockIdx.x; 42 | if (idx < n) { 43 | bf16_ptr[idx] = __float2bfloat16(float_ptr[idx]); 44 | } 45 | } 46 | 47 | void Float2BF16(void *float_ptr, void *bf16_ptr, size_t n, cudaStream_t stream) { 48 | constexpr auto kNumThreads = 1024; 49 | auto block = (n + kNumThreads - 1) / kNumThreads; 50 | _float2bfloat16<<>>(static_cast(float_ptr), static_cast<__nv_bfloat16 *>(bf16_ptr), n); 51 | } 52 | -------------------------------------------------------------------------------- /csrc/rotary/rotary_cuda.cu: -------------------------------------------------------------------------------- 1 | /****************************************************************************** 2 | * Copyright (c) 2023, Tri Dao. 
3 | ******************************************************************************/ 4 | 5 | #include 6 | #include 7 | #include 8 | 9 | void apply_rotary_cuda(const torch::Tensor x1, const torch::Tensor x2, 10 | const torch::Tensor cos, const torch::Tensor sin, 11 | torch::Tensor out1, torch::Tensor out2, 12 | const bool conj) { 13 | auto iter = at::TensorIteratorConfig() 14 | .add_output(out1) 15 | .add_output(out2) 16 | .add_input(x1) 17 | .add_input(x2) 18 | .add_input(cos) 19 | .add_input(sin) 20 | .check_all_same_dtype(false) 21 | .promote_inputs_to_common_dtype(false) 22 | .build(); 23 | 24 | if (!conj) { 25 | AT_DISPATCH_FLOATING_TYPES_AND2(at::kBFloat16, at::kHalf, x1.scalar_type(), "rotary_kernel", [&] { 26 | at::native::gpu_kernel_multiple_outputs( 27 | iter, [] GPU_LAMBDA (scalar_t x1, scalar_t x2, scalar_t cos, 28 | scalar_t sin) -> thrust::tuple { 29 | scalar_t out1 = float(x1) * float(cos) - float(x2) * float(sin); 30 | scalar_t out2 = float(x1) * float(sin) + float(x2) * float(cos); 31 | return {out1, out2}; 32 | }); 33 | }); 34 | } else { 35 | AT_DISPATCH_FLOATING_TYPES_AND2(at::kBFloat16, at::kHalf, x1.scalar_type(), "rotary_kernel", [&] { 36 | at::native::gpu_kernel_multiple_outputs( 37 | iter, [] GPU_LAMBDA (scalar_t x1, scalar_t x2, scalar_t cos, 38 | scalar_t sin) -> thrust::tuple { 39 | scalar_t out1 = float(x1) * float(cos) + float(x2) * float(sin); 40 | scalar_t out2 = -float(x1) * float(sin) + float(x2) * float(cos); 41 | return {out1, out2}; 42 | }); 43 | }); 44 | } 45 | } -------------------------------------------------------------------------------- /training/src/datamodules/datasets/detokenizer.py: -------------------------------------------------------------------------------- 1 | # Copied from https://github.com/stanford-crfm/mistral/blob/main/src/corpora/detokenization.py 2 | # Which was originally from https://github.com/NVIDIA/Megatron-LM/blob/aed2f75e209e525c842aec7c044af7acae2a4614/tasks/zeroshot_gpt/detokenizer.py 3 | 4 | """ 5 | Handle detokenization for different dataset for zero-shot LM evaluation. 6 | """ 7 | import re 8 | 9 | 10 | def wikitext_detokenize(string: str) -> str: 11 | """ 12 | Wikitext is whitespace tokenized and we remove these whitespaces. 13 | Taken from https://github.com/NVIDIA/Megatron-LM/blob/main/tasks/zeroshot_gpt2/detokenizer.py 14 | """ 15 | # Contractions 16 | string = string.replace("s '", "s'") 17 | string = re.sub(r"/' [0-9]/", r"/'[0-9]/", string) 18 | 19 | # Number Separators 20 | string = string.replace(" @-@ ", "-") 21 | string = string.replace(" @,@ ", ",") 22 | string = string.replace(" @.@ ", ".") 23 | 24 | # Punctuation 25 | string = string.replace(" : ", ": ") 26 | string = string.replace(" ; ", "; ") 27 | string = string.replace(" . ", ". ") 28 | string = string.replace(" ! ", "! ") 29 | string = string.replace(" ? ", "? 
") 30 | string = string.replace(" , ", ", ") 31 | 32 | # Double Brackets 33 | string = re.sub(r"\(\s*([^\)]*?)\s*\)", r"(\1)", string) 34 | string = re.sub(r"\[\s*([^\]]*?)\s*\]", r"[\1]", string) 35 | string = re.sub(r"{\s*([^}]*?)\s*}", r"{\1}", string) 36 | string = re.sub(r"\"\s*([^\"]*?)\s*\"", r'"\1"', string) 37 | string = re.sub(r"'\s*([^']*?)\s*'", r"'\1'", string) 38 | 39 | # Miscellaneous 40 | string = string.replace("= = = =", "====") 41 | string = string.replace("= = =", "===") 42 | string = string.replace("= =", "==") 43 | string = string.replace(" " + chr(176) + " ", chr(176)) 44 | string = string.replace(" \n", "\n") 45 | string = string.replace("\n ", "\n") 46 | string = string.replace(" N ", " 1 ") 47 | string = string.replace(" 's", "'s") 48 | 49 | return string 50 | 51 | 52 | # Set Registry for Various Datasets 53 | DATASET_TOKENIZATION_REGISTRY = {"wikitext": wikitext_detokenize} 54 | -------------------------------------------------------------------------------- /flash_attn/layers/patch_embed.py: -------------------------------------------------------------------------------- 1 | # We use the same API as https://github.com/rwightman/pytorch-image-models/blob/v0.6.11/timm/models/layers/patch_embed.py 2 | # But we use nn.Linear instead of Conv2d and it's about 8x faster. 3 | 4 | from functools import partial 5 | 6 | import torch.nn as nn 7 | from torch import _assert 8 | from torch.nn.modules.utils import _pair 9 | 10 | from einops import rearrange 11 | 12 | try: 13 | from flash_attn.ops.fused_dense import FusedDense 14 | except ImportError: 15 | FusedDense = None 16 | 17 | 18 | class PatchEmbed(nn.Module): 19 | """ 2D Image to Patch Embedding 20 | """ 21 | def __init__( 22 | self, 23 | img_size=224, 24 | patch_size=16, 25 | in_chans=3, 26 | embed_dim=768, 27 | norm_layer=None, 28 | flatten=True, 29 | bias=True, 30 | fused_bias_fc=False, 31 | ): 32 | super().__init__() 33 | img_size = _pair(img_size) 34 | patch_size = _pair(patch_size) 35 | self.img_size = img_size 36 | self.patch_size = patch_size 37 | self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1]) 38 | self.num_patches = self.grid_size[0] * self.grid_size[1] 39 | self.flatten = flatten 40 | if fused_bias_fc and FusedDense is None: 41 | raise ImportError('fused_dense is not installed') 42 | 43 | linear_cls = nn.Linear if not fused_bias_fc or not bias else FusedDense 44 | self.proj = linear_cls(in_chans * patch_size[0] * patch_size[1], embed_dim, bias=bias) 45 | self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() 46 | 47 | def forward(self, x): 48 | _, _, H, W = x.shape 49 | _assert(H == self.img_size[0], f"Input image height ({H}) doesn't match model ({self.img_size[0]}).") 50 | _assert(W == self.img_size[1], f"Input image width ({W}) doesn't match model ({self.img_size[1]}).") 51 | x = self.proj(rearrange(x, 'b c (h p1) (w p2) -> b h w (c p1 p2)', 52 | p1=self.patch_size[0], p2=self.patch_size[1])) 53 | if self.flatten: 54 | x = rearrange(x, 'b h w c -> b (h w) c') 55 | x = self.norm(x) 56 | return x 57 | -------------------------------------------------------------------------------- /training/src/utils/flops.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/rwightman/pytorch-image-models/blob/master/benchmark.py 2 | import torch 3 | 4 | try: 5 | from deepspeed.profiling.flops_profiler import get_model_profile 6 | has_deepspeed_profiling = True 7 | except ImportError as e: 8 | 
has_deepspeed_profiling = False 9 | 10 | try: 11 | from fvcore.nn import FlopCountAnalysis, flop_count_str, flop_count_table 12 | from fvcore.nn import ActivationCountAnalysis 13 | has_fvcore_profiling = True 14 | except ImportError as e: 15 | FlopCountAnalysis = None 16 | ActivationCountAnalysis = None 17 | has_fvcore_profiling = False 18 | 19 | 20 | def profile_deepspeed(model, input_size=(3, 224, 224), input_dtype=torch.float32, 21 | batch_size=1, detailed=False): 22 | device, dtype = next(model.parameters()).device, next(model.parameters()).dtype 23 | flops, macs, params = get_model_profile( 24 | model=model, 25 | args=torch.zeros((batch_size,) + input_size, device=device, dtype=input_dtype), 26 | print_profile=detailed, # prints the model graph with the measured profile attached to each module 27 | detailed=detailed, # print the detailed profile 28 | warm_up=10, # the number of warm-ups before measuring the time of each module 29 | as_string=False, # print raw numbers (e.g. 1000) or as human-readable strings (e.g. 1k) 30 | output_file=None, # path to the output file. If None, the profiler prints to stdout. 31 | ignore_modules=None) # the list of modules to ignore in the profiling 32 | return macs, 0 # no activation count in DS 33 | 34 | 35 | def profile_fvcore(model, input_size=(3, 224, 224), input_dtype=torch.float32, max_depth=4, 36 | batch_size=1, detailed=False, force_cpu=False): 37 | if force_cpu: 38 | model = model.to('cpu') 39 | device, dtype = next(model.parameters()).device, next(model.parameters()).dtype 40 | example_input = torch.zeros((batch_size,) + input_size, device=device, dtype=input_dtype) 41 | fca = FlopCountAnalysis(model, example_input) 42 | aca = ActivationCountAnalysis(model, example_input) 43 | if detailed: 44 | print(flop_count_table(fca, max_depth=max_depth)) 45 | return fca, fca.total(), aca, aca.total() 46 | -------------------------------------------------------------------------------- /training/src/callbacks/flop_count.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/rwightman/pytorch-image-models/blob/master/benchmark.py 2 | from typing import Any, List, Sequence 3 | 4 | import torch 5 | 6 | from pytorch_lightning import Callback, Trainer, LightningModule 7 | from pytorch_lightning.utilities import rank_zero_only 8 | from pytorch_lightning.utilities.parsing import AttributeDict 9 | 10 | from src.utils.flops import has_deepspeed_profiling, has_fvcore_profiling 11 | from src.utils.flops import profile_deepspeed, profile_fvcore 12 | 13 | 14 | class FlopCount(Callback): 15 | """Counter the number of FLOPs used by the model 16 | """ 17 | def __init__(self, profilers: List[str] = ['fvcore', 'deepspeed'], 18 | input_size: tuple = (3, 224, 224), input_dtype=torch.float32, device=None): 19 | if not isinstance(profilers, Sequence): 20 | profilers = [profilers] 21 | if any(p not in ['fvcore', 'deepspeed'] for p in profilers): 22 | raise NotImplementedError('Only support fvcore and deepspeed profilers') 23 | if 'fvcore' in profilers and not has_fvcore_profiling: 24 | raise ImportError('fvcore is not installed. 
Install it by running `pip install fvcore`') 25 | elif 'deepspeed' in profilers and not has_deepspeed_profiling: 26 | raise ImportError('deepspeed is not installed') 27 | super().__init__() 28 | self.profilers = profilers 29 | self.input_size = tuple(input_size) 30 | self.input_dtype = input_dtype 31 | self.device = device 32 | 33 | @rank_zero_only 34 | def on_fit_start(self, trainer: Trainer, pl_module: LightningModule) -> None: 35 | if 'fvcore' in self.profilers: 36 | _, macs, _, acts = profile_fvcore(pl_module.to(self.device), input_size=self.input_size, 37 | input_dtype=self.input_dtype, detailed=True) 38 | trainer.logger.log_hyperparams({'GMACs': macs * 1e-9, 'MActs': acts * 1e-6}) 39 | if 'deepspeed' in self.profilers: 40 | macs, _= profile_deepspeed(pl_module.to(self.device), input_size=self.input_size, 41 | input_dtype=self.input_dtype, detailed=True) 42 | if 'fvcore' not in self.profilers: # fvcore's MACs seem more accurate 43 | trainer.logger.log_hyperparams({'GMACs': macs * 1e-9}) 44 | -------------------------------------------------------------------------------- /csrc/flash_attn_with_bias_and_mask/src/random_utils.h: -------------------------------------------------------------------------------- 1 | // Stores RNG state values. Passed as a kernel argument. 2 | // See Note [CUDA Graph-safe RNG states]. 3 | // 4 | // The raw definition lives in its own file so jit codegen can easily copy it. 5 | 6 | #pragma once 7 | 8 | #include 9 | #include 10 | 11 | struct PhiloxCudaState { 12 | PhiloxCudaState() = default; 13 | // Called if graph capture is not underway 14 | PhiloxCudaState(uint64_t seed, 15 | uint64_t offset) { 16 | seed_ = seed; 17 | offset_.val = offset; 18 | } 19 | // Called if graph capture is underway 20 | PhiloxCudaState(uint64_t seed, 21 | int64_t* offset_extragraph, 22 | uint32_t offset_intragraph) { 23 | seed_ = seed; 24 | offset_.ptr = offset_extragraph; 25 | offset_intragraph_ = offset_intragraph; 26 | captured_ = true; 27 | } 28 | 29 | // Public members, directly accessible by philox::unpack. 30 | // If we made them private with getters/setters, the getters/setters 31 | // would have to be __device__, and we can't declare __device__ in ATen. 32 | union Payload { 33 | uint64_t val; 34 | int64_t* ptr; 35 | }; 36 | 37 | uint64_t seed_ = 0; 38 | Payload offset_; 39 | uint32_t offset_intragraph_ = 0; 40 | bool captured_ = false; 41 | }; 42 | 43 | 44 | // In-kernel call to retrieve philox seed and offset from a PhiloxCudaState instance whether 45 | // that instance was created with graph capture underway or not. 46 | // See Note [CUDA Graph-safe RNG states]. 47 | // 48 | // The raw definition lives in its own file so jit codegen can easily copy it. 49 | #if defined(__CUDA_ACC__) or defined(__CUDA_ARCH__) 50 | #define DEVICE __device__ 51 | #else 52 | #define DEVICE 53 | #endif 54 | 55 | namespace philox { 56 | 57 | inline DEVICE std::tuple 58 | unpack(PhiloxCudaState arg) { 59 | if (arg.captured_) { 60 | // static_cast avoids "warning: invalid narrowing conversion from "long" to "unsigned long". 61 | // *(arg.offset_.ptr) is a broadcast load of a single int64_t to the entire kernel. 62 | // For most threads' reads it will hit in cache, so it shouldn't hurt performance. 
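// For reference, the two construction modes defined above behave as follows:
//   PhiloxCudaState(seed, offset)                         -> captured_ == false; the offset is stored by value.
//   PhiloxCudaState(seed, offset_ptr, offset_intragraph)  -> captured_ == true; the offset is read through the
//                                                            pointer at kernel run time (CUDA-graph safe).
// unpack() hides that difference from the calling kernel.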
--------------------------------------------------------------------------------
/csrc/flash_attn_with_bias_and_mask/src/random_utils.h:
--------------------------------------------------------------------------------
// Stores RNG state values. Passed as a kernel argument.
// See Note [CUDA Graph-safe RNG states].
//
// The raw definition lives in its own file so jit codegen can easily copy it.

#pragma once

#include <cstdint>
#include <tuple>

struct PhiloxCudaState {
    PhiloxCudaState() = default;
    // Called if graph capture is not underway
    PhiloxCudaState(uint64_t seed,
                    uint64_t offset) {
        seed_ = seed;
        offset_.val = offset;
    }
    // Called if graph capture is underway
    PhiloxCudaState(uint64_t seed,
                    int64_t* offset_extragraph,
                    uint32_t offset_intragraph) {
        seed_ = seed;
        offset_.ptr = offset_extragraph;
        offset_intragraph_ = offset_intragraph;
        captured_ = true;
    }

    // Public members, directly accessible by philox::unpack.
    // If we made them private with getters/setters, the getters/setters
    // would have to be __device__, and we can't declare __device__ in ATen.
    union Payload {
        uint64_t val;
        int64_t* ptr;
    };

    uint64_t seed_ = 0;
    Payload offset_;
    uint32_t offset_intragraph_ = 0;
    bool captured_ = false;
};


// In-kernel call to retrieve philox seed and offset from a PhiloxCudaState instance whether
// that instance was created with graph capture underway or not.
// See Note [CUDA Graph-safe RNG states].
//
// The raw definition lives in its own file so jit codegen can easily copy it.
#if defined(__CUDACC__) || defined(__CUDA_ARCH__)
#define DEVICE __device__
#else
#define DEVICE
#endif

namespace philox {

inline DEVICE std::tuple<uint64_t, uint64_t>
unpack(PhiloxCudaState arg) {
    if (arg.captured_) {
        // static_cast avoids "warning: invalid narrowing conversion from "long" to "unsigned long".
        // *(arg.offset_.ptr) is a broadcast load of a single int64_t to the entire kernel.
        // For most threads' reads it will hit in cache, so it shouldn't hurt performance.
        return std::make_tuple(arg.seed_, static_cast<uint64_t>(*(arg.offset_.ptr) + arg.offset_intragraph_));
    } else {
        return std::make_tuple(arg.seed_, arg.offset_.val);
    }
}

} // namespace philox

--------------------------------------------------------------------------------
/csrc/xentropy/interface.cpp:
--------------------------------------------------------------------------------
#include <torch/extension.h>

// CUDA forward declarations
std::vector<at::Tensor> softmax_xentropy_cuda(
    const at::Tensor &input,
    const at::Tensor &labels,
    const float smoothing,
    const int total_classes);

at::Tensor softmax_xentropy_backward_cuda(
    const at::Tensor &grad_loss,
    at::Tensor &logits,
    const at::Tensor &max_log_sum_exp,
    const at::Tensor &labels,
    const float smoothing,
    const bool inplace,
    const int total_classes);

// C++ interface

#define CHECK_CUDA(x) AT_ASSERTM(x.is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)

std::vector<at::Tensor> softmax_xentropy_forward(
    const at::Tensor &input,
    const at::Tensor &labels,
    const float smoothing,
    const int total_classes=-1) {
    // For tensor parallel cross entropy with smoothing, we want to pass in the total number
    // of classes so that smoothing can be applied correctly. If total_classes=-1, use the
    // last dimension of the input tensor.
    CHECK_INPUT(input);
    CHECK_INPUT(labels);

    return softmax_xentropy_cuda(input, labels, smoothing, total_classes);
}

at::Tensor softmax_xentropy_backward(
    const at::Tensor &grad_loss,
    at::Tensor &logits,
    const at::Tensor &max_log_sum_exp,
    const at::Tensor &labels,
    const float smoothing,
    const bool inplace,
    const int total_classes=-1) {
    CHECK_INPUT(grad_loss);
    CHECK_INPUT(logits);
    CHECK_INPUT(max_log_sum_exp);
    CHECK_INPUT(labels);

    return softmax_xentropy_backward_cuda(grad_loss, logits, max_log_sum_exp, labels,
                                          smoothing, inplace, total_classes);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("forward", &softmax_xentropy_forward,
          "Softmax cross entropy loss with label smoothing forward (CUDA)",
          py::arg("input"), py::arg("labels"), py::arg("smoothing"),
          py::arg("total_classes")=-1);
    m.def("backward", &softmax_xentropy_backward,
          "Softmax cross entropy loss with label smoothing backward (CUDA)",
          py::arg("grad_loss"), py::arg("logits"), py::arg("max_log_sum_exp"), py::arg("labels"),
          py::arg("smoothing"), py::arg("inplace"), py::arg("total_classes")=-1);
}
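
A sketch of a Python-side autograd wrapper around the extension bound above. The compiled module name `xentropy_cuda_lib` and the assumption that `forward()` returns the per-token losses together with `max_log_sum_exp` are inferred from the signatures in interface.cpp and should be treated as illustrative, not as the repository's exact wrapper.

import torch
import xentropy_cuda_lib  # assumed name of the module built from csrc/xentropy


class SoftmaxCrossEntropy(torch.autograd.Function):
    @staticmethod
    def forward(ctx, logits, labels, smoothing=0.0, total_classes=-1, inplace_backward=False):
        # Assumed return order: per-token losses and the max-log-sum-exp needed by backward.
        losses, max_log_sum_exp = xentropy_cuda_lib.forward(logits, labels, smoothing, total_classes)
        ctx.save_for_backward(logits, max_log_sum_exp, labels)
        ctx.smoothing = smoothing
        ctx.total_classes = total_classes
        ctx.inplace_backward = inplace_backward
        return losses

    @staticmethod
    def backward(ctx, grad_loss):
        logits, max_log_sum_exp, labels = ctx.saved_tensors
        # Argument order mirrors the pybind signature: grad_loss, logits, max_log_sum_exp,
        # labels, smoothing, inplace, total_classes.
        grad_logits = xentropy_cuda_lib.backward(
            grad_loss.contiguous(), logits, max_log_sum_exp, labels,
            ctx.smoothing, ctx.inplace_backward, ctx.total_classes)
        return grad_logits, None, None, None, None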
--------------------------------------------------------------------------------
/csrc/flash_attn/src/random_utils.h:
--------------------------------------------------------------------------------
// Stores RNG state values. Passed as a kernel argument.
// See Note [CUDA Graph-safe RNG states].
//
// The raw definition lives in its own file so jit codegen can easily copy it.

#pragma once

#include <cstdint>
#include <tuple>

#if !FLASH_ATTN_WITH_TORCH
namespace at {

struct PhiloxCudaState {
    PhiloxCudaState() = default;
    // Called if graph capture is not underway
    PhiloxCudaState(uint64_t seed,
                    uint64_t offset) {
        seed_ = seed;
        offset_.val = offset;
    }
    // Called if graph capture is underway
    PhiloxCudaState(uint64_t seed,
                    int64_t* offset_extragraph,
                    uint32_t offset_intragraph) {
        seed_ = seed;
        offset_.ptr = offset_extragraph;
        offset_intragraph_ = offset_intragraph;
        captured_ = true;
    }

    // Public members, directly accessible by philox::unpack.
    // If we made them private with getters/setters, the getters/setters
    // would have to be __device__, and we can't declare __device__ in ATen.
    union Payload {
        uint64_t val;
        int64_t* ptr;
    };

    uint64_t seed_ = 0;
    Payload offset_;
    uint32_t offset_intragraph_ = 0;
    bool captured_ = false;
};

} // namespace at
#endif


// In-kernel call to retrieve philox seed and offset from a PhiloxCudaState instance whether
// that instance was created with graph capture underway or not.
// See Note [CUDA Graph-safe RNG states].
//
// The raw definition lives in its own file so jit codegen can easily copy it.
#if defined(__CUDACC__) || defined(__CUDA_ARCH__)
#define DEVICE __device__
#else
#define DEVICE
#endif

#if !FLASH_ATTN_WITH_TORCH
namespace at {
namespace cuda {
namespace philox {

inline DEVICE std::tuple<uint64_t, uint64_t>
unpack(PhiloxCudaState arg) {
    if (arg.captured_) {
        // static_cast avoids "warning: invalid narrowing conversion from "long" to "unsigned long".
        // *(arg.offset_.ptr) is a broadcast load of a single int64_t to the entire kernel.
        // For most threads' reads it will hit in cache, so it shouldn't hurt performance.
        return std::make_tuple(arg.seed_, static_cast<uint64_t>(*(arg.offset_.ptr) + arg.offset_intragraph_));
    } else {
        return std::make_tuple(arg.seed_, arg.offset_.val);
    }
}

} // namespace philox
} // namespace cuda
} // namespace at
#endif
24 | """ 25 | 26 | def __init__(self, seq_len: int = 10, input_dim: int = 0): 27 | super().__init__() 28 | self.seq_len = seq_len 29 | self.input_dim = input_dim 30 | 31 | @rank_zero_only 32 | def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: 33 | model = pl_module.model 34 | 35 | with torch.enable_grad(): 36 | if self.input_dim == 0: 37 | # [MP] LongTensors cannot have gradients - we start from post 38 | # embedding in the LM case 39 | input_dim = model.d_model 40 | x = torch.randn((2, self.seq_len, input_dim), \ 41 | requires_grad=True).to(pl_module.device) 42 | # [DF] HACK: we need to get the layer that comes after the embedding 43 | if hasattr(model, 'net'): 44 | y = model.net(x) 45 | else: 46 | y = model.s4seq(x) 47 | else: 48 | x = torch.randn(1, self.seq_len, self.input_dim, \ 49 | requires_grad=True).to(pl_module.device) 50 | y = model(x) 51 | 52 | stats = {} 53 | for i in range(self.seq_len): 54 | # total gradients flowing from y_i to x 55 | g = grad(y[0,0,i].mean(), x, retain_graph=True, allow_unused=True)[0] 56 | g = g[0,i+1:,:].abs().mean() 57 | stats[f'stats/causality_{i}'] = g.item() 58 | 59 | if trainer.loggers is not None: 60 | for logger in trainer.loggers: 61 | logger.log_metrics(stats, step=trainer.global_step) 62 | --------------------------------------------------------------------------------