├── .gitignore ├── AUTHORS ├── CHANGELOG ├── LICENSE ├── README.md ├── assets ├── banner.png ├── benchmark.png └── logo.png ├── benchmarks ├── benchmark.py ├── benchmark_conv1d.py └── benchmark_flashfftconv.py ├── csrc └── flashfftconv │ ├── .gitignore │ ├── butterfly │ ├── butterfly.h │ ├── butterfly_cuda.cu │ ├── butterfly_cuda_bf16.cu │ ├── butterfly_ifft_cuda.cu │ ├── butterfly_ifft_cuda_bf16.cu │ ├── butterfly_padded_cuda.cu │ ├── butterfly_padded_cuda_bf16.cu │ ├── butterfly_padded_ifft_cuda.cu │ ├── butterfly_padded_ifft_cuda_bf16.cu │ └── shared.h │ ├── conv1d │ ├── conv1d.h │ ├── conv1d_bhl.cu │ ├── conv1d_blh.cu │ ├── conv1d_bwd_cuda_bhl.cu │ ├── conv1d_bwd_cuda_blh.cu │ └── shared.h │ ├── monarch.cpp │ ├── monarch_cuda │ ├── kernels_bf16 │ │ ├── monarch_cuda_16_16_16_bwd_complex_kernel_bf16.h │ │ ├── monarch_cuda_16_16_16_bwd_kernel_bf16.h │ │ ├── monarch_cuda_16_16_16_complex_kernel_bf16.h │ │ ├── monarch_cuda_16_16_16_kernel_bf16.h │ │ ├── monarch_cuda_16_32_32_bwd_complex_kernel_bf16.h │ │ ├── monarch_cuda_16_32_32_bwd_kernel_bf16.h │ │ ├── monarch_cuda_16_32_32_complex_kernel_bf16.h │ │ ├── monarch_cuda_16_32_32_kernel_bf16.h │ │ ├── monarch_cuda_32_16_16_bwd_complex_kernel_bf16.h │ │ ├── monarch_cuda_32_16_16_bwd_kernel_bf16.h │ │ ├── monarch_cuda_32_16_16_complex_kernel_bf16.h │ │ ├── monarch_cuda_32_16_16_kernel_bf16.h │ │ ├── monarch_cuda_32_32_32_bwd_complex_kernel_bf16.h │ │ ├── monarch_cuda_32_32_32_bwd_kernel_bf16.h │ │ ├── monarch_cuda_32_32_32_complex_kernel_bf16.h │ │ ├── monarch_cuda_32_32_32_kernel_bf16.h │ │ ├── monarch_cuda_bwd_kernel_bf16.h │ │ ├── monarch_cuda_bwd_kernel_r2r_bf16.h │ │ ├── monarch_cuda_kernel_bf16.h │ │ ├── monarch_cuda_kernel_r2r_bf16.h │ │ ├── monarch_cuda_shared_bf16.h │ │ ├── monarch_cuda_shared_bf16_no_float_shm.h │ │ ├── monarch_cuda_shared_r2r_bf16.h │ │ └── shared │ │ │ ├── monarch_cuda_shared_bf16_complex_mul.h │ │ │ ├── monarch_cuda_shared_bf16_load_frags.h │ │ │ └── monarch_cuda_shared_bf16_matmuls.h │ ├── kernels_fp16 │ │ ├── monarch_cuda_16_16_16_bwd_complex_kernel.h │ │ ├── monarch_cuda_16_16_16_bwd_kernel.h │ │ ├── monarch_cuda_16_16_16_bwd_kernel_fp16_bf16_inp.h │ │ ├── monarch_cuda_16_16_16_complex_kernel.h │ │ ├── monarch_cuda_16_16_16_kernel.h │ │ ├── monarch_cuda_16_16_16_kernel_fp16_bf16_inp.h │ │ ├── monarch_cuda_16_32_32_bwd_complex_kernel.h │ │ ├── monarch_cuda_16_32_32_bwd_kernel.h │ │ ├── monarch_cuda_16_32_32_complex_kernel.h │ │ ├── monarch_cuda_16_32_32_kernel.h │ │ ├── monarch_cuda_32_16_16_bwd_complex_kernel.h │ │ ├── monarch_cuda_32_16_16_bwd_kernel.h │ │ ├── monarch_cuda_32_16_16_complex_kernel.h │ │ ├── monarch_cuda_32_16_16_kernel.h │ │ ├── monarch_cuda_32_16_16_kernel_fp16_bf16_inp.h │ │ ├── monarch_cuda_32_32_32_bwd_complex_kernel.h │ │ ├── monarch_cuda_32_32_32_bwd_kernel.h │ │ ├── monarch_cuda_32_32_32_complex_kernel.h │ │ ├── monarch_cuda_32_32_32_complex_truncated_kernel.h │ │ ├── monarch_cuda_32_32_32_kernel.h │ │ ├── monarch_cuda_bwd_kernel.h │ │ ├── monarch_cuda_bwd_kernel_r2r.h │ │ ├── monarch_cuda_kernel.h │ │ ├── monarch_cuda_kernel_r2r.h │ │ ├── monarch_cuda_shared.h │ │ ├── monarch_cuda_shared_r2r.h │ │ ├── monarch_cuda_shared_truncated.h │ │ └── shared │ │ │ ├── monarch_cuda_shared_fp16_complex_mul.h │ │ │ ├── monarch_cuda_shared_fp16_load_frags.h │ │ │ └── monarch_cuda_shared_fp16_matmuls.h │ ├── monarch_bwd.h │ ├── monarch_bwd_complex.h │ ├── monarch_bwd_r2r.h │ ├── monarch_cuda_interface_bwd.cu │ ├── monarch_cuda_interface_bwd_bf16.cu │ ├── monarch_cuda_interface_bwd_bf16_complex.cu │ ├── 
monarch_cuda_interface_bwd_complex.cu │ ├── monarch_cuda_interface_bwd_r2r.cu │ ├── monarch_cuda_interface_bwd_r2r_bf16.cu │ ├── monarch_cuda_interface_fwd.cu │ ├── monarch_cuda_interface_fwd_bf16.cu │ ├── monarch_cuda_interface_fwd_bf16_complex.cu │ ├── monarch_cuda_interface_fwd_complex.cu │ ├── monarch_cuda_interface_fwd_r2r.cu │ ├── monarch_cuda_interface_fwd_r2r_bf16.cu │ ├── monarch_fwd.h │ ├── monarch_fwd_complex.h │ └── monarch_fwd_r2r.h │ └── setup.py ├── examples ├── README.md ├── bert │ ├── README.md │ ├── benchmark.py │ ├── benchmark_fwd.py │ ├── bert_layers.py │ ├── bert_padding.py │ ├── blockdiag_linear.py │ ├── blockdiag_multiply.py │ ├── configs │ │ ├── m2-110M-flashfftconv.yaml │ │ └── m2-110M.yaml │ ├── configuration_bert.py │ ├── create_bert.py │ ├── hyena_utils.py │ ├── monarch_mixer_sequence_mixer.py │ ├── monarch_mixer_sequence_mixer_flashfftconv.py │ ├── requirements.txt │ └── structured_linear.py ├── hyena-dna │ ├── README.md │ ├── benchmark.py │ ├── benchmark_flash_dna_fwd.py │ ├── benchmark_hyena_dna_fwd.py │ ├── huggingface.py │ ├── hyenadna_flashfftconv.py │ └── hyenadna_standalone.py ├── hyena │ ├── README.md │ ├── benchmark.py │ ├── benchmark_fwd.py │ ├── configs │ │ ├── callbacks │ │ │ └── base.yaml │ │ ├── config.yaml │ │ ├── dataset │ │ │ └── thepile.yaml │ │ ├── experiment │ │ │ ├── base.yaml │ │ │ └── pile │ │ │ │ ├── base.yaml │ │ │ │ ├── hyena-flashfft.yaml │ │ │ │ └── hyena.yaml │ │ ├── loader │ │ │ └── default.yaml │ │ ├── model │ │ │ ├── base.yaml │ │ │ ├── layer │ │ │ │ ├── ff.yaml │ │ │ │ ├── h3-conv.yaml │ │ │ │ ├── h3.yaml │ │ │ │ ├── hyena-filter.yaml │ │ │ │ ├── hyena-flashfft.yaml │ │ │ │ ├── hyena.yaml │ │ │ │ ├── id.yaml │ │ │ │ ├── long-conv.yaml │ │ │ │ ├── mha.yaml │ │ │ │ ├── s4_simple.yaml │ │ │ │ ├── s4d.yaml │ │ │ │ ├── transformer.yaml │ │ │ │ └── vit.yaml │ │ │ └── long-conv.yaml │ │ ├── optimizer │ │ │ ├── adam.yaml │ │ │ ├── adamw.yaml │ │ │ ├── lamb.yaml │ │ │ └── sgd.yaml │ │ ├── pipeline │ │ │ └── thepile.yaml │ │ ├── scheduler │ │ │ ├── cosine_warmup.yaml │ │ │ └── cosine_warmup_timm.yaml │ │ ├── task │ │ │ └── lm.yaml │ │ └── trainer │ │ │ ├── debug.yaml │ │ │ ├── default.yaml │ │ │ ├── full.yaml │ │ │ └── lm.yaml │ ├── flash-attention │ │ ├── .gitignore │ │ ├── .gitmodules │ │ ├── AUTHORS │ │ ├── LICENSE │ │ ├── MANIFEST.in │ │ ├── Makefile │ │ ├── README.md │ │ ├── assets │ │ │ ├── flashattn_banner.jpg │ │ │ ├── flashattn_banner.pdf │ │ │ ├── flashattn_memory.jpg │ │ │ ├── flashattn_speedup.jpg │ │ │ ├── flashattn_speedup_3090.jpg │ │ │ ├── flashattn_speedup_a100_d128.jpg │ │ │ ├── flashattn_speedup_t4.jpg │ │ │ ├── flashattn_speedup_t4_fwd.jpg │ │ │ ├── gpt2_training_curve.jpg │ │ │ ├── gpt2_training_efficiency.jpg │ │ │ ├── gpt3_training_curve.jpg │ │ │ └── gpt3_training_efficiency.jpg │ │ ├── benchmarks │ │ │ ├── benchmark_causal.py │ │ │ └── benchmark_flash_attention.py │ │ ├── csrc │ │ │ ├── flash_attn │ │ │ │ ├── fmha_api.cpp │ │ │ │ └── src │ │ │ │ │ ├── fmha.h │ │ │ │ │ ├── fmha │ │ │ │ │ ├── gemm.h │ │ │ │ │ ├── gmem_tile.h │ │ │ │ │ ├── kernel_traits.h │ │ │ │ │ ├── mask.h │ │ │ │ │ ├── smem_tile.h │ │ │ │ │ ├── softmax.h │ │ │ │ │ └── utils.h │ │ │ │ │ ├── fmha_block_dgrad_fp16_kernel_loop.sm80.cu │ │ │ │ │ ├── fmha_block_dgrad_kernel_1xN_loop.h │ │ │ │ │ ├── fmha_block_fprop_fp16_kernel.sm80.cu │ │ │ │ │ ├── fmha_block_fprop_kernel_1xN.h │ │ │ │ │ ├── fmha_blockmask.h │ │ │ │ │ ├── fmha_bwd_hdim128.cu │ │ │ │ │ ├── fmha_bwd_hdim32.cu │ │ │ │ │ ├── fmha_bwd_hdim64.cu │ │ │ │ │ ├── fmha_bwd_launch_template.h │ │ │ 
│ │ ├── fmha_dgrad_kernel_1xN_loop.h │ │ │ │ │ ├── fmha_fprop_kernel_1xN.h │ │ │ │ │ ├── fmha_fwd_hdim128.cu │ │ │ │ │ ├── fmha_fwd_hdim32.cu │ │ │ │ │ ├── fmha_fwd_hdim64.cu │ │ │ │ │ ├── fmha_fwd_launch_template.h │ │ │ │ │ ├── fmha_kernel.h │ │ │ │ │ ├── fmha_utils.h │ │ │ │ │ ├── philox.cuh │ │ │ │ │ └── static_switch.h │ │ │ ├── ft_attention │ │ │ │ ├── README.md │ │ │ │ ├── cuda_bf16_fallbacks.cuh │ │ │ │ ├── cuda_bf16_wrapper.h │ │ │ │ ├── decoder_masked_multihead_attention.cu │ │ │ │ ├── decoder_masked_multihead_attention.h │ │ │ │ ├── decoder_masked_multihead_attention_template.hpp │ │ │ │ ├── decoder_masked_multihead_attention_utils.h │ │ │ │ ├── ft_attention.cpp │ │ │ │ └── setup.py │ │ │ ├── fused_dense_lib │ │ │ │ ├── README.md │ │ │ │ ├── fused_dense.cpp │ │ │ │ ├── fused_dense_cuda.cu │ │ │ │ └── setup.py │ │ │ ├── fused_softmax │ │ │ │ ├── fused_softmax.cpp │ │ │ │ ├── scaled_masked_softmax.h │ │ │ │ ├── scaled_masked_softmax_cuda.cu │ │ │ │ ├── scaled_upper_triang_masked_softmax.h │ │ │ │ ├── scaled_upper_triang_masked_softmax_cuda.cu │ │ │ │ ├── setup.py │ │ │ │ └── type_shim.h │ │ │ ├── layer_norm │ │ │ │ ├── README.md │ │ │ │ ├── ln.h │ │ │ │ ├── ln_api.cpp │ │ │ │ ├── ln_bwd_1024.cu │ │ │ │ ├── ln_bwd_1280.cu │ │ │ │ ├── ln_bwd_1536.cu │ │ │ │ ├── ln_bwd_2048.cu │ │ │ │ ├── ln_bwd_256.cu │ │ │ │ ├── ln_bwd_2560.cu │ │ │ │ ├── ln_bwd_3072.cu │ │ │ │ ├── ln_bwd_4096.cu │ │ │ │ ├── ln_bwd_512.cu │ │ │ │ ├── ln_bwd_5120.cu │ │ │ │ ├── ln_bwd_6144.cu │ │ │ │ ├── ln_bwd_768.cu │ │ │ │ ├── ln_bwd_kernels.cuh │ │ │ │ ├── ln_fwd_1024.cu │ │ │ │ ├── ln_fwd_1280.cu │ │ │ │ ├── ln_fwd_1536.cu │ │ │ │ ├── ln_fwd_2048.cu │ │ │ │ ├── ln_fwd_256.cu │ │ │ │ ├── ln_fwd_2560.cu │ │ │ │ ├── ln_fwd_3072.cu │ │ │ │ ├── ln_fwd_4096.cu │ │ │ │ ├── ln_fwd_512.cu │ │ │ │ ├── ln_fwd_5120.cu │ │ │ │ ├── ln_fwd_6144.cu │ │ │ │ ├── ln_fwd_768.cu │ │ │ │ ├── ln_fwd_kernels.cuh │ │ │ │ ├── ln_kernel_traits.h │ │ │ │ ├── ln_utils.cuh │ │ │ │ ├── setup.py │ │ │ │ └── static_switch.h │ │ │ ├── rotary │ │ │ │ ├── rotary.cpp │ │ │ │ ├── rotary_cuda.cu │ │ │ │ └── setup.py │ │ │ └── xentropy │ │ │ │ ├── README.md │ │ │ │ ├── interface.cpp │ │ │ │ ├── setup.py │ │ │ │ └── xentropy_kernel.cu │ │ ├── flash_attn │ │ │ ├── __init__.py │ │ │ ├── bert_padding.py │ │ │ ├── flash_attention.py │ │ │ ├── flash_attn_interface.py │ │ │ ├── flash_attn_triton.py │ │ │ ├── flash_attn_triton_og.py │ │ │ ├── flash_blocksparse_attention.py │ │ │ ├── flash_blocksparse_attn_interface.py │ │ │ ├── fused_softmax.py │ │ │ ├── layers │ │ │ │ ├── __init__.py │ │ │ │ ├── patch_embed.py │ │ │ │ └── rotary.py │ │ │ ├── losses │ │ │ │ ├── __init__.py │ │ │ │ └── cross_entropy.py │ │ │ ├── models │ │ │ │ ├── __init__.py │ │ │ │ ├── bert.py │ │ │ │ ├── gpt.py │ │ │ │ ├── opt.py │ │ │ │ └── vit.py │ │ │ ├── modules │ │ │ │ ├── __init__.py │ │ │ │ ├── block.py │ │ │ │ ├── embedding.py │ │ │ │ ├── mha.py │ │ │ │ └── mlp.py │ │ │ ├── ops │ │ │ │ ├── __init__.py │ │ │ │ ├── fused_dense.py │ │ │ │ ├── gelu_activation.py │ │ │ │ ├── layer_norm.py │ │ │ │ ├── rms_norm.py │ │ │ │ └── triton │ │ │ │ │ ├── k_activations.py │ │ │ │ │ ├── linear.py │ │ │ │ │ └── mlp.py │ │ │ └── utils │ │ │ │ ├── __init__.py │ │ │ │ ├── benchmark.py │ │ │ │ ├── distributed.py │ │ │ │ ├── generation.py │ │ │ │ └── pretrained.py │ │ ├── setup.py │ │ ├── tests │ │ │ ├── losses │ │ │ │ ├── test_cross_entropy.py │ │ │ │ └── test_cross_entropy_parallel.py │ │ │ ├── models │ │ │ │ ├── test_bert.py │ │ │ │ ├── test_gpt.py │ │ │ │ ├── test_gpt_generation.py │ │ │ │ ├── 
test_gpt_generation_parallel.py │ │ │ │ ├── test_gpt_parallel.py │ │ │ │ ├── test_opt.py │ │ │ │ └── test_vit.py │ │ │ ├── modules │ │ │ │ ├── test_block_parallel.py │ │ │ │ ├── test_embedding_parallel.py │ │ │ │ └── test_mha_parallel.py │ │ │ ├── ops │ │ │ │ ├── test_dropout_layer_norm.py │ │ │ │ ├── test_fused_dense.py │ │ │ │ └── test_fused_dense_parallel.py │ │ │ ├── test_flash_attn.py │ │ │ └── test_rotary.py │ │ ├── training │ │ │ ├── Dockerfile │ │ │ ├── README.md │ │ │ ├── configs │ │ │ │ ├── callbacks │ │ │ │ │ ├── causality-monitor.yaml │ │ │ │ │ ├── default.yaml │ │ │ │ │ ├── ema.yaml │ │ │ │ │ ├── flop-count.yaml │ │ │ │ │ ├── gpu-monitor.yaml │ │ │ │ │ ├── model-summary.yaml │ │ │ │ │ ├── none.yaml │ │ │ │ │ ├── norm-monitor.yaml │ │ │ │ │ ├── params-log.yaml │ │ │ │ │ └── wandb.yaml │ │ │ │ ├── config.yaml │ │ │ │ ├── datamodule │ │ │ │ │ ├── openwebtext.yaml │ │ │ │ │ └── thepile.yaml │ │ │ │ ├── experiment │ │ │ │ │ ├── owt │ │ │ │ │ │ ├── base.yaml │ │ │ │ │ │ ├── gpt2l-flash.yaml │ │ │ │ │ │ ├── gpt2l-hf.yaml │ │ │ │ │ │ ├── gpt2l.yaml │ │ │ │ │ │ ├── gpt2m-flash.yaml │ │ │ │ │ │ ├── gpt2m-hf.yaml │ │ │ │ │ │ ├── gpt2m.yaml │ │ │ │ │ │ ├── gpt2s-flash.yaml │ │ │ │ │ │ ├── gpt2s-hf.yaml │ │ │ │ │ │ ├── gpt2s.yaml │ │ │ │ │ │ ├── gpt2xl-flash.yaml │ │ │ │ │ │ ├── gpt2xl-hf.yaml │ │ │ │ │ │ └── gpt2xl.yaml │ │ │ │ │ └── pile │ │ │ │ │ │ ├── base.yaml │ │ │ │ │ │ ├── gpt3-2.7B-flash-8k.yaml │ │ │ │ │ │ ├── gpt3-2.7B-flash-hdim128-rotary-8k.yaml │ │ │ │ │ │ ├── gpt3-2.7B-flash-hdim128-rotary.yaml │ │ │ │ │ │ ├── gpt3-2.7B-flash-hdim128.yaml │ │ │ │ │ │ ├── gpt3-2.7B-flash-rotary-8k.yaml │ │ │ │ │ │ ├── gpt3-2.7B-flash-rotary.yaml │ │ │ │ │ │ ├── gpt3-2.7B-flash.yaml │ │ │ │ │ │ ├── gpt3-2.7B-hf-hdim128.yaml │ │ │ │ │ │ ├── gpt3-2.7B-hf.yaml │ │ │ │ │ │ ├── gpt3l-flash-8k.yaml │ │ │ │ │ │ ├── gpt3l-flash-rotary-30B.yaml │ │ │ │ │ │ ├── gpt3l-flash-rotary-8k.yaml │ │ │ │ │ │ ├── gpt3l-flash-rotary.yaml │ │ │ │ │ │ ├── gpt3l-flash.yaml │ │ │ │ │ │ ├── gpt3l-hf.yaml │ │ │ │ │ │ ├── gpt3m-flash-8k.yaml │ │ │ │ │ │ ├── gpt3m-flash-rotary-30B.yaml │ │ │ │ │ │ ├── gpt3m-flash-rotary-8k.yaml │ │ │ │ │ │ ├── gpt3m-flash-rotary.yaml │ │ │ │ │ │ ├── gpt3m-flash.yaml │ │ │ │ │ │ ├── gpt3m-hf.yaml │ │ │ │ │ │ ├── gpt3s-flash-8k.yaml │ │ │ │ │ │ ├── gpt3s-flash-rotary-30B.yaml │ │ │ │ │ │ ├── gpt3s-flash-rotary-8k.yaml │ │ │ │ │ │ ├── gpt3s-flash-rotary.yaml │ │ │ │ │ │ ├── gpt3s-flash.yaml │ │ │ │ │ │ ├── gpt3s-hf.yaml │ │ │ │ │ │ ├── gpt3xl-flash-8k.yaml │ │ │ │ │ │ ├── gpt3xl-flash-rotary-60B.yaml │ │ │ │ │ │ ├── gpt3xl-flash-rotary-8k.yaml │ │ │ │ │ │ ├── gpt3xl-flash-rotary.yaml │ │ │ │ │ │ ├── gpt3xl-flash.yaml │ │ │ │ │ │ └── gpt3xl-hf.yaml │ │ │ │ ├── logger │ │ │ │ │ ├── comet.yaml │ │ │ │ │ ├── csv.yaml │ │ │ │ │ ├── many_loggers.yaml │ │ │ │ │ ├── mlflow.yaml │ │ │ │ │ ├── neptune.yaml │ │ │ │ │ ├── tensorboard.yaml │ │ │ │ │ └── wandb.yaml │ │ │ │ ├── metrics │ │ │ │ │ ├── acc.yaml │ │ │ │ │ ├── acc_ignore_index.yaml │ │ │ │ │ ├── acctop5.yaml │ │ │ │ │ ├── mse.yaml │ │ │ │ │ ├── num-tokens.yaml │ │ │ │ │ └── perplexity.yaml │ │ │ │ ├── mode │ │ │ │ │ ├── debug.yaml │ │ │ │ │ ├── default.yaml │ │ │ │ │ ├── exp.yaml │ │ │ │ │ ├── profile.yaml │ │ │ │ │ └── smoke.yaml │ │ │ │ ├── model │ │ │ │ │ ├── gpt2-hf.yaml │ │ │ │ │ ├── gpt2.yaml │ │ │ │ │ └── gpt2model │ │ │ │ │ │ ├── gpt2-large.yaml │ │ │ │ │ │ ├── gpt2-medium.yaml │ │ │ │ │ │ ├── gpt2-small.yaml │ │ │ │ │ │ └── gpt2-xlarge.yaml │ │ │ │ ├── optimizer │ │ │ │ │ ├── adam.yaml │ │ │ │ │ ├── adamw-apex-distributed.yaml │ │ │ │ 
│ ├── adamw-apex-zero.yaml │ │ │ │ │ ├── adamw-apex.yaml │ │ │ │ │ ├── adamw-zero.yaml │ │ │ │ │ ├── adamw.yaml │ │ │ │ │ ├── fusedlamb-ds.yaml │ │ │ │ │ ├── fusedlamb.yaml │ │ │ │ │ └── sgd.yaml │ │ │ │ ├── scheduler │ │ │ │ │ ├── cosine-warmup-timm.yaml │ │ │ │ │ ├── cosine-warmup.yaml │ │ │ │ │ ├── invsqrt.yaml │ │ │ │ │ ├── linear-warmup.yaml │ │ │ │ │ ├── multi-step.yaml │ │ │ │ │ ├── plateau.yaml │ │ │ │ │ ├── poly-warmup.yaml │ │ │ │ │ └── step.yaml │ │ │ │ ├── task │ │ │ │ │ └── sequence-model.yaml │ │ │ │ └── trainer │ │ │ │ │ ├── all_params.yaml │ │ │ │ │ ├── ddp.yaml │ │ │ │ │ ├── debug.yaml │ │ │ │ │ └── default.yaml │ │ │ ├── run.py │ │ │ ├── src │ │ │ │ ├── callbacks │ │ │ │ │ ├── __init__.py │ │ │ │ │ ├── causality_monitor.py │ │ │ │ │ ├── ema.py │ │ │ │ │ ├── flop_count.py │ │ │ │ │ ├── gpu_affinity.py │ │ │ │ │ ├── loss_scale_monitor.py │ │ │ │ │ ├── model_checkpoint.py │ │ │ │ │ ├── norm_monitor.py │ │ │ │ │ ├── params_log.py │ │ │ │ │ ├── speed_monitor.py │ │ │ │ │ └── wandb_callbacks.py │ │ │ │ ├── datamodules │ │ │ │ │ ├── datasets │ │ │ │ │ │ ├── detokenizer.py │ │ │ │ │ │ └── lm_dataset.py │ │ │ │ │ ├── fault_tolerant_sampler.py │ │ │ │ │ ├── imagenet.py │ │ │ │ │ ├── language_modeling_hf.py │ │ │ │ │ └── timm_mixup.py │ │ │ │ ├── distributed │ │ │ │ │ └── ddp_comm_hooks.py │ │ │ │ ├── eval.py │ │ │ │ ├── metrics │ │ │ │ │ ├── accuracy.py │ │ │ │ │ ├── num_tokens.py │ │ │ │ │ └── perplexity.py │ │ │ │ ├── models │ │ │ │ │ └── modules │ │ │ │ │ │ └── seq_common.py │ │ │ │ ├── optim │ │ │ │ │ ├── param_grouping.py │ │ │ │ │ └── timm_lr_scheduler.py │ │ │ │ ├── tasks │ │ │ │ │ └── seq.py │ │ │ │ ├── train.py │ │ │ │ └── utils │ │ │ │ │ ├── checkpoint.py │ │ │ │ │ ├── ddp_zero1.py │ │ │ │ │ ├── ddp_zero2.py │ │ │ │ │ ├── distributed.py │ │ │ │ │ ├── ema.py │ │ │ │ │ ├── flops.py │ │ │ │ │ ├── gpu_affinity.py │ │ │ │ │ └── utils.py │ │ │ └── tests │ │ │ │ └── datamodules │ │ │ │ └── test_language_modeling_hf.py │ │ └── usage.md │ └── src │ │ ├── callbacks │ │ ├── norms.py │ │ ├── params.py │ │ ├── progressive_resizing.py │ │ ├── timer.py │ │ └── wandb.py │ │ ├── dataloaders │ │ ├── README.md │ │ ├── __init__.py │ │ ├── base.py │ │ ├── basic.py │ │ ├── datasets │ │ │ ├── detokenizer.py │ │ │ └── lm_dataset.py │ │ ├── et.py │ │ ├── fault_tolerant_sampler.py │ │ ├── language_modeling_hf.py │ │ ├── lm.py │ │ ├── lra.py │ │ ├── synthetics.py │ │ ├── utils │ │ │ ├── cifar_augmentations.py │ │ │ ├── timm_mixup.py │ │ │ └── vocabulary.py │ │ └── vision.py │ │ ├── models │ │ ├── baselines │ │ │ └── vit_all.py │ │ ├── nn │ │ │ ├── __init__.py │ │ │ ├── adaptive_softmax.py │ │ │ ├── components.py │ │ │ ├── dxt.py │ │ │ ├── gate.py │ │ │ ├── residual.py │ │ │ └── utils.py │ │ └── sequence │ │ │ ├── __init__.py │ │ │ ├── base.py │ │ │ ├── block.py │ │ │ ├── block_fft.py │ │ │ ├── ff.py │ │ │ ├── h3.py │ │ │ ├── h3_conv.py │ │ │ ├── hyena-flashfft.py │ │ │ ├── hyena.py │ │ │ ├── long_conv.py │ │ │ ├── long_conv_kernel.py │ │ │ ├── long_conv_lm.py │ │ │ ├── mha.py │ │ │ ├── model.py │ │ │ ├── pool.py │ │ │ ├── simple_lm.py │ │ │ └── ssm │ │ │ ├── dplr.py │ │ │ ├── hippo.py │ │ │ ├── s4_simple.py │ │ │ ├── s4d.py │ │ │ ├── ss_kernel.py │ │ │ ├── ss_kernel_diag.py │ │ │ └── ss_kernel_shift.py │ │ ├── ops │ │ ├── krylov.py │ │ ├── toeplitz.py │ │ ├── unroll.py │ │ └── vandermonde.py │ │ ├── tasks │ │ ├── decoders.py │ │ ├── encoders.py │ │ ├── metrics.py │ │ ├── tasks.py │ │ └── torchmetrics.py │ │ └── utils │ │ ├── __init__.py │ │ ├── config.py │ │ ├── distributed.py │ │ ├── optim │ │ ├── 
lamb.py │ │ └── schedulers.py │ │ ├── optim_groups.py │ │ ├── permutations.py │ │ ├── registry.py │ │ └── train.py └── long-convs │ ├── README.md │ └── flashfftconv_long_convs.py ├── flashfftconv ├── __init__.py ├── conv.py ├── depthwise_1d.py └── sparse_conv.py ├── flashfftconv_long_convs.py ├── rand.py ├── setup.py ├── standalone_cifar.py └── tests ├── test_conv1d.py └── test_flashfftconv.py /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | 3 | # Byte-compiled / optimized / DLL files 4 | __pycache__/ 5 | *.py[cod] 6 | *$py.class 7 | 8 | # C extensions 9 | *.so 10 | 11 | # Distribution / packaging 12 | .Python 13 | env/ 14 | build/ 15 | develop-eggs/ 16 | dist/ 17 | downloads/ 18 | eggs/ 19 | .eggs/ 20 | lib/ 21 | lib64/ 22 | parts/ 23 | sdist/ 24 | var/ 25 | wheels/ 26 | *.egg-info/ 27 | .installed.cfg 28 | *.egg 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | .hypothesis/ 50 | 51 | # Translations 52 | *.mo 53 | *.pot 54 | 55 | # Django stuff: 56 | *.log 57 | local_settings.py 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # dotenv 85 | .env 86 | 87 | # virtualenv 88 | .venv 89 | venv/ 90 | ENV/ 91 | 92 | # Spyder project settings 93 | .spyderproject 94 | .spyproject 95 | 96 | # Rope project settings 97 | .ropeproject 98 | 99 | # mkdocs documentation 100 | /site 101 | 102 | # mypy 103 | .mypy_cache/ 104 | 105 | # input data, saved log, checkpoints 106 | data/ 107 | input/ 108 | saved/ 109 | datasets/ 110 | 111 | # editor, os cache directory 112 | .vscode/ 113 | .idea/ 114 | __MACOSX/ -------------------------------------------------------------------------------- /AUTHORS: -------------------------------------------------------------------------------- 1 | Dan Fu, danfu@cs.stanford.edu 2 | Hermann Kumbong, kumboh@stanford.edu -------------------------------------------------------------------------------- /CHANGELOG: -------------------------------------------------------------------------------- 1 | 11-21-23: Support fp32 weights with fp16 inputs for short depthwise convs 2 | 3 | 11-13-23: Initial release! 
-------------------------------------------------------------------------------- /assets/banner.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HazyResearch/flash-fft-conv/b8771028717f46d5b22cbb8e12833f35033d621b/assets/banner.png -------------------------------------------------------------------------------- /assets/benchmark.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HazyResearch/flash-fft-conv/b8771028717f46d5b22cbb8e12833f35033d621b/assets/benchmark.png -------------------------------------------------------------------------------- /assets/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HazyResearch/flash-fft-conv/b8771028717f46d5b22cbb8e12833f35033d621b/assets/logo.png -------------------------------------------------------------------------------- /csrc/flashfftconv/.gitignore: -------------------------------------------------------------------------------- 1 | *.npy 2 | *.json 3 | *.png 4 | 5 | */*.npy 6 | */*.json 7 | */*.png 8 | 9 | *.DS_Store 10 | */*.DS_Store -------------------------------------------------------------------------------- /csrc/flashfftconv/butterfly/shared.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2023 Dan Fu, Hermann Kumbong 2 | 3 | #include 4 | 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include 11 | using namespace nvcuda; 12 | 13 | using complex_half_t = typename c10::complex; 14 | using complex_bhalf_t = typename c10::complex; 15 | 16 | #define WMMA_M 16 17 | #define WMMA_N 16 18 | #define WMMA_K 16 19 | #define WARP_SIZE 32 20 | 21 | #ifndef MONARCH_CUDA_H_ 22 | #define MONARCH_CUDA_H_ 23 | 24 | __device__ __forceinline__ float2 25 | 26 | operator+( float2 lhs, float2 rhs) 27 | 28 | { 29 | 30 | float2 res = { lhs.x + rhs.x , lhs.y + rhs.y }; 31 | 32 | return res; 33 | 34 | } 35 | 36 | 37 | __device__ __forceinline__ float2 38 | 39 | operator-( float2 lhs, float2 rhs) 40 | 41 | { 42 | 43 | float2 res = { lhs.x - rhs.x , lhs.y - rhs.y }; 44 | 45 | return res; 46 | 47 | } 48 | 49 | __device__ __forceinline__ float2 50 | 51 | operator*( float2 lhs, float2 rhs) 52 | 53 | { 54 | 55 | float2 res = { lhs.x * rhs.x , lhs.y * rhs.y }; 56 | 57 | return res; 58 | 59 | } 60 | #endif -------------------------------------------------------------------------------- /csrc/flashfftconv/conv1d/conv1d.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2023 Dan Fu, Hermann Kumbong 2 | 3 | #include 4 | 5 | #include 6 | 7 | 8 | #define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor") 9 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 10 | #define CHECK_IS_HALF_OR_BFLOAT_OR_FLOAT(x) TORCH_CHECK(x.dtype() == torch::kFloat16 || x.dtype() == torch::kBFloat16 || x.dtype() == torch::kFloat32, #x " must be float16 or bfloat16 or float32") 11 | #define CHECK_SAME_TYPE(x, y) TORCH_CHECK(x.dtype() == y.dtype(), #x " and " #y " must have the same dtype") 12 | 13 | #define CHECK_INPUT(x) \ 14 | CHECK_CUDA(x); \ 15 | CHECK_CONTIGUOUS(x); \ 16 | CHECK_IS_HALF_OR_BFLOAT_OR_FLOAT(x) 17 | 18 | torch::Tensor conv1d_cuda_bhl( 19 | torch::Tensor u, 20 | torch::Tensor weight, 21 | torch::Tensor bias, 22 | uint padding); 23 | 24 | torch::Tensor conv1d_cuda_blh( 25 | 
torch::Tensor u, 26 | torch::Tensor weight, 27 | torch::Tensor bias, 28 | uint padding); 29 | 30 | std::vector conv1d_backward_bhl_cuda( 31 | torch::Tensor dout, 32 | torch::Tensor input, 33 | torch::Tensor weight, 34 | torch::Tensor bias, 35 | uint padding 36 | ); 37 | 38 | std::vector conv1d_backward_blh_cuda( 39 | torch::Tensor dout, 40 | torch::Tensor input, 41 | torch::Tensor weight, 42 | torch::Tensor bias, 43 | uint padding 44 | ); 45 | 46 | 47 | torch::Tensor conv1d_fwd( 48 | torch::Tensor u, 49 | torch::Tensor weight, 50 | torch::Tensor bias, 51 | uint padding, 52 | bool is_bhl) 53 | { 54 | CHECK_INPUT(u); 55 | CHECK_INPUT(weight); 56 | CHECK_INPUT(bias); 57 | CHECK_SAME_TYPE(weight, bias); 58 | 59 | int k; 60 | 61 | if(is_bhl){ 62 | k = weight.size(1); 63 | }else{ 64 | k = weight.size(0); 65 | } 66 | 67 | TORCH_CHECK(k % 2 == 1, "Filter size must be odd number"); 68 | 69 | if(is_bhl){ 70 | return conv1d_cuda_bhl(u, weight, bias, padding); 71 | }else{ 72 | return conv1d_cuda_blh(u, weight, bias, padding); 73 | } 74 | } 75 | 76 | std::vector conv1d_bwd( 77 | torch::Tensor dout, 78 | torch::Tensor input, 79 | torch::Tensor weight, 80 | torch::Tensor bias, 81 | uint padding, 82 | bool is_bhl) 83 | { 84 | CHECK_INPUT(dout); 85 | CHECK_INPUT(input); 86 | CHECK_INPUT(weight); 87 | CHECK_INPUT(bias); 88 | CHECK_SAME_TYPE(weight, bias); 89 | CHECK_SAME_TYPE(dout, input); 90 | 91 | if(is_bhl){ 92 | return conv1d_backward_bhl_cuda(dout, input, weight, bias, padding); 93 | } else{ 94 | return conv1d_backward_blh_cuda(dout, input, weight, bias, padding); 95 | } 96 | } -------------------------------------------------------------------------------- /examples/README.md: -------------------------------------------------------------------------------- 1 | # FlashFFTConv Example Models 2 | 3 | This folder contains example architectures using FlashFFTConv. 4 | 5 | We currently have examples for the following models: 6 | * [M2-BERT](bert) 7 | * [Hyena](hyena) 8 | * [HyenaDNA](hyena-dna) 9 | * [Long Convs](long-convs) 10 | 11 | Please check out the READMEs in each sub-folder to learn more about each model and get examples of how to use them. -------------------------------------------------------------------------------- /examples/bert/README.md: -------------------------------------------------------------------------------- 1 | # Monarch Mixer BERT 2 | 3 | This folder shows an example of adapting M2-BERT to use FlashFFTConv. 4 | The original files are sourced from the [M2-BERT](https://github.com/HazyResearch/m2/tree/main/bert/src) implementation. 5 | 6 | ## Requirements 7 | 8 | Install model-specific requirements: 9 | ``` 10 | pip install -r requirements.txt 11 | ``` 12 | 13 | ## Usage 14 | 15 | We have sample configs for M2-BERT models of different sizes that you can benchmark: 16 | ``` 17 | python benchmark_fwd.py configs/m2-110M.yaml 18 | python benchmark_fwd.py configs/m2-110M-flashfftconv.yaml 19 | ``` 20 | 21 | ## Changes to Use FlashFFTConv in M2-BERT 22 | 23 | We describe the changes necessary to use FlashFFTConv in M2-BERT: 24 | 25 | Create an instance of `FlashFFTConv` in `BERTEncoder`. 
In [bert_layers.py](bert_layers.py), lines 294-301: 26 | ```Python 27 | seqlen = config.max_position_embeddings 28 | if config.use_flashfftconv: 29 | self.flashfftconv = FlashFFTConv(seqlen * 2, dtype=torch.float16) # 2x for padding, may need bfloat16 30 | self.layer = nn.ModuleList( 31 | [copy.deepcopy(layer) for _ in range(config.num_hidden_layers)]) 32 | if config.use_flashfftconv: 33 | for layer in self.layer: 34 | layer.attention.flashfftconv = self.flashfftconv # add it to the layers 35 | ``` 36 | 37 | Then, we adapt the actual sequence mixer to use `flashfftconv` in [monarch_mixer_sequence_mixer_flashfftconv.py](monarch_mixer_sequence_mixer_flashfftconv.py). 38 | 39 | We make a couple more optimizations: 40 | * We use our fast depthwise kernel. 41 | * We introduce an "inference mode" that simply loads the convolution kernel from weights, instead of recomputing it every time (which is especially expensive for short kernels). An alternative is to use a fast kernel to generate the convolution kernel, as in the [M2 repo](https://github.com/HazyResearch/m2/tree/main/csrc/flashmm). -------------------------------------------------------------------------------- /examples/bert/bert_padding.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/bert_padding.py 2 | # Adapted from https://github.com/mlcommons/training_results_v1.1/blob/main/NVIDIA/benchmarks/bert/implementations/pytorch/padding.py 3 | 4 | """ 5 | Functions for padding and unpadding 6 | """ 7 | 8 | from typing import Tuple 9 | 10 | import torch 11 | 12 | class IndexPutFirstAxis(torch.autograd.Function): 13 | 14 | @staticmethod 15 | def forward(ctx, values: torch.Tensor, indices: torch.Tensor, 16 | first_axis_dim) -> torch.Tensor: 17 | ctx.save_for_backward(indices) 18 | assert indices.ndim == 1 19 | assert values.ndim >= 2 20 | output = torch.zeros(first_axis_dim, 21 | *values.shape[1:], 22 | device=values.device, 23 | dtype=values.dtype) 24 | output[indices] = values 25 | return output 26 | 27 | @staticmethod 28 | def backward(ctx, 29 | grad_output: torch.Tensor) -> Tuple[torch.Tensor, None, None]: 30 | indices, = ctx.saved_tensors 31 | grad_values = grad_output[indices] 32 | return grad_values, None, None 33 | 34 | 35 | index_put_first_axis = IndexPutFirstAxis.apply -------------------------------------------------------------------------------- /examples/bert/configs/m2-110M-flashfftconv.yaml: -------------------------------------------------------------------------------- 1 | max_seq_len: 128 2 | batch_size: 128 3 | dtype: half 4 | 5 | model: 6 | name: bert 7 | pretrained_model_name: bert-base-uncased 8 | tokenizer_name: bert-base-uncased 9 | model_config: 10 | num_attention_heads: 12 11 | num_hidden_layers: 12 12 | attention_probs_dropout_prob: 0.0 13 | max_position_embeddings: 128 14 | 15 | monarch_mixer_sequence_mixing: True 16 | use_flashfftconv: True 17 | inference_mode: True 18 | long_conv_l_max: 128 19 | hyena_w: 10 20 | hyena_wd: 0.1 21 | hyena_emb_dim: 5 22 | hyena_filter_order: 128 23 | 24 | bidirectional: True 25 | residual_long_conv: False 26 | 27 | use_glu_mlp: False 28 | use_monarch_mlp: False 29 | monarch_mlp_nblocks: 4 30 | use_positional_encodings: True -------------------------------------------------------------------------------- /examples/bert/configs/m2-110M.yaml: -------------------------------------------------------------------------------- 1 | max_seq_len: 128 2 | batch_size: 128 3 
| dtype: half 4 | 5 | model: 6 | name: bert 7 | pretrained_model_name: bert-base-uncased 8 | tokenizer_name: bert-base-uncased 9 | model_config: 10 | num_attention_heads: 12 11 | num_hidden_layers: 12 12 | attention_probs_dropout_prob: 0.0 13 | max_position_embeddings: 128 14 | 15 | monarch_mixer_sequence_mixing: True 16 | long_conv_l_max: 128 17 | hyena_w: 10 18 | hyena_wd: 0.1 19 | hyena_emb_dim: 5 20 | hyena_filter_order: 128 21 | 22 | bidirectional: False 23 | residual_long_conv: False 24 | 25 | use_glu_mlp: False 26 | use_monarch_mlp: False 27 | monarch_mlp_nblocks: 4 28 | use_positional_encodings: True -------------------------------------------------------------------------------- /examples/bert/requirements.txt: -------------------------------------------------------------------------------- 1 | mosaicml[nlp,wandb]==0.14.1 2 | mosaicml-streaming==0.4.1 3 | hydra-core==1.2.0 -------------------------------------------------------------------------------- /examples/hyena-dna/README.md: -------------------------------------------------------------------------------- 1 | # Hyena-DNA 2 | 3 | This folder shows an example of adapting Hyena-DNA to use FlashFFTConv. 4 | 5 | ## Requirements 6 | 7 | This code downloads the config from HuggingFace, so you need to have git lfs installed. 8 | 9 | To check, run: 10 | ``` 11 | git lfs install 12 | ``` 13 | If it fails, you can install git lfs using your favorite package manager. 14 | 15 | ## Usage 16 | 17 | We have sample scripts to benchmark Hyena DNA using PyTorch, vs. using FlashFFTConv: 18 | ``` 19 | python benchmark_hyena_dna_fwd.py 20 | python benchmark_flash_dna_fwd.py 21 | ``` 22 | 23 | ## Changes to Use FlashFFTConv in Hyena 24 | 25 | We describe the changes necessary to use FlashFFTConv in HyenaDNA: 26 | 27 | Create an instance of `FlashFFTConv` in `LMBackbone`. In [hyenadna_flashfftconv.py](hyenadna_flashfftconv.py), lines 716-721: 28 | ```Python 29 | seqlen = layer['l_max'] 30 | seqlen = next_power_of_2(seqlen) * 2 31 | self.flashfftconv = FlashFFTConv(seqlen, dtype=torch.float16) # may need bfloat16 32 | 33 | for layer in self.layers: 34 | layer.mixer.flashfftconv = self.flashfftconv 35 | ``` 36 | 37 | Note that HyenaDNA does not use sequence lengths that are powers of two, so we need to find the next closest power of two (lines 688-689). 38 | 39 | Then, we adapt the Hyena layers to use the `flashfftconv` variable (lines 269-289). 40 | 41 | We make a couple more optimizations: 42 | * We use our fast depthwise kernel. 43 | * We introduce an "inference mode" that simply loads the convolution kernel from weights, instead of recomputing it every time. An alternative is to use a fast kernel to generate the convolution kernel, as in the [M2 repo](https://github.com/HazyResearch/m2/tree/main/csrc/flashmm). 44 | * In this benchmarking code, the weights have different names than in the PyTorch code, so the model will not load pretrained weights out of the box. We are working on a minimal example that can load the pretrained weights. 
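
For reference, the power-of-two rounding mentioned above can be sketched as follows. This is only an illustrative helper, assuming a simple `math.log2`-based implementation; the actual `next_power_of_2` used in [hyenadna_flashfftconv.py](hyenadna_flashfftconv.py) may be written differently, and the `l_max` value below is just a placeholder (the 450k-token HyenaDNA model).

```Python
import math

def next_power_of_2(n: int) -> int:
    # smallest power of two that is >= n
    return 1 if n <= 1 else 2 ** math.ceil(math.log2(n))

# HyenaDNA sequence lengths are not powers of two, so round up,
# then double for padding before constructing FlashFFTConv.
l_max = 450_000                        # placeholder, e.g. the 450k model
seqlen = next_power_of_2(l_max) * 2    # 524_288 * 2 = 1_048_576
```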
-------------------------------------------------------------------------------- /examples/hyena-dna/benchmark_flash_dna_fwd.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from benchmark import benchmark_forward, pytorch_profiler 3 | 4 | from huggingface import load_model 5 | import sys 6 | 7 | ''' 8 | model options: 9 | 'hyenadna-tiny-1k-seqlen' # fine-tune on colab ok 10 | 'hyenadna-tiny-1k-seqlen-d256' 11 | 'hyenadna-tiny-16k-seqlen-d128' 12 | 'hyenadna-small-32k-seqlen' 13 | 'hyenadna-medium-160k-seqlen' # inference only on colab 14 | 'hyenadna-medium-450k-seqlen' # inference only on colab 15 | 'hyenadna-large-1m-seqlen' # inference only on colab 16 | ''' 17 | 18 | model_name = 'hyenadna-large-1m-seqlen' 19 | B = 4 20 | repeats = 10 21 | use_flash = True 22 | 23 | model, tokenizer, max_length = load_model(model_name, use_flash=use_flash) 24 | 25 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 26 | 27 | #### Single embedding example #### 28 | 29 | # create a sample 450k long, prepare 30 | sequence = 'ACTG' * int(max_length/4) 31 | tok_seq = tokenizer(sequence) 32 | tok_seq = tok_seq["input_ids"] # grab ids 33 | 34 | # place on device, convert to tensor 35 | tok_seq = torch.LongTensor(tok_seq).repeat(B, 1) # unsqueeze for batch dim 36 | tok_seq = tok_seq.to(device) 37 | 38 | # prep model and forward 39 | model.to(device) 40 | model = model.half() 41 | model.eval() 42 | 43 | def run_model(model, tok_seq): 44 | return model(tok_seq) 45 | 46 | with torch.autocast(device_type='cuda', dtype=torch.float16, enabled=True): 47 | with torch.no_grad(): 48 | run_model(model, tok_seq) 49 | 50 | torch.cuda.empty_cache() 51 | 52 | with torch.no_grad(): 53 | _, ret = benchmark_forward(run_model, model, tok_seq, repeats=repeats, verbose=True, amp_dtype=torch.float16, amp=True) 54 | 55 | time = ret._mean 56 | print('Time: ', time) 57 | print('Tokens/ms: ', (tok_seq.shape[0] * tok_seq.shape[1])/time/1000) 58 | print('Seqs/s: ', B/time) 59 | 60 | # pytorch_profiler(run_model, model, tok_seq, backward=False, cpu=True, trace_filename=f'dna_fwd_{model_name}_flash_{use_flash}.json') -------------------------------------------------------------------------------- /examples/hyena-dna/benchmark_hyena_dna_fwd.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from benchmark import benchmark_forward, pytorch_profiler 3 | 4 | from huggingface import load_model 5 | import sys 6 | 7 | ''' 8 | model options: 9 | 'hyenadna-tiny-1k-seqlen' # fine-tune on colab ok 10 | 'hyenadna-tiny-1k-seqlen-d256' 11 | 'hyenadna-tiny-16k-seqlen-d128' 12 | 'hyenadna-small-32k-seqlen' 13 | 'hyenadna-medium-160k-seqlen' # inference only on colab 14 | 'hyenadna-medium-450k-seqlen' # inference only on colab 15 | 'hyenadna-large-1m-seqlen' # inference only on colab 16 | ''' 17 | 18 | model_name = 'hyenadna-large-1m-seqlen' 19 | B = 1 20 | repeats = 10 21 | use_flash = False 22 | 23 | model, tokenizer, max_length = load_model(model_name, use_flash=use_flash) 24 | 25 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 26 | 27 | #### Single embedding example #### 28 | 29 | # create a sample 450k long, prepare 30 | sequence = 'ACTG' * int(max_length/4) 31 | tok_seq = tokenizer(sequence) 32 | if use_flash: 33 | tok_seq = tok_seq["input_ids"][:-2] # grab ids 34 | else: 35 | tok_seq = tok_seq["input_ids"] 36 | 37 | # place on device, convert to tensor 38 | tok_seq = torch.LongTensor(tok_seq).repeat(B, 1) # unsqueeze for 
batch dim 39 | tok_seq = tok_seq.to(device) 40 | 41 | # prep model and forward 42 | model.to(device) 43 | model.eval() 44 | 45 | def run_model(model, tok_seq): 46 | return model(tok_seq) 47 | 48 | with torch.autocast(device_type='cuda', dtype=torch.float16, enabled=True): 49 | with torch.no_grad(): 50 | run_model(model, tok_seq) 51 | 52 | torch.cuda.empty_cache() 53 | 54 | with torch.no_grad(): 55 | _, ret = benchmark_forward(run_model, model, tok_seq, repeats=repeats, verbose=True, amp_dtype=torch.float16, amp=True) 56 | 57 | time = ret._mean 58 | print('Time: ', time) 59 | print('Tokens/ms: ', (tok_seq.shape[0] * tok_seq.shape[1])/time/1000) 60 | print('Seqs/s: ', B/time) 61 | 62 | # pytorch_profiler(run_model, model, tok_seq, backward=False, cpu=True, trace_filename=f'dna_fwd_{model_name}_flash_{use_flash}.json') -------------------------------------------------------------------------------- /examples/hyena/README.md: -------------------------------------------------------------------------------- 1 | # Hyena 2 | 3 | This folder shows an example of adapting Hyena to use FlashFFTConv. 4 | The original files are sourced from [safari](https://github.com/HazyResearch/safari). 5 | 6 | ## Requirements 7 | 8 | Install model-specific requirements. See the [safari repo](https://github.com/HazyResearch/safari/tree/main) for instructions. 9 | 10 | This code depends on an old version of FlashAttention (0.2.8) for the MLP interface. 11 | 12 | ## Usage 13 | 14 | We have sample configs for Hyena models of different sizes that you can benchmark: 15 | ``` 16 | python benchmark_fwd.py experiment=pile/hyena.yaml 17 | python benchmark_fwd.py experiment=pile/hyena-flashfft.yaml 18 | ``` 19 | 20 | ## Changes to Use FlashFFTConv in Hyena 21 | 22 | We describe the changes necessary to use FlashFFTConv in Hyena: 23 | 24 | Create an instance of `FlashFFTConv` in `LMBackbone`. In [src/models/sequence/long_conv_lm.py](src/models/sequence/long_conv_lm.py), lines 193-197: 25 | ```Python 26 | if use_flashfftconv: 27 | self.flashfftconv = FlashFFTConv(layer['l_max'] * 2, dtype=torch.float16) 28 | 29 | for layer in self.layers: 30 | layer.mixer.flashfftconv = self.flashfftconv 31 | ``` 32 | 33 | Then, we adapt Hyena to use the `flashfftconv` variable in [src/models/sequence/hyena-flashfft.py](src/models/sequence/hyena-flashfft.py). 34 | 35 | We make a couple more optimizations: 36 | * We use our fast depthwise kernel. 37 | * We introduce an "inference mode" that simply loads the convolution kernel from weights, instead of recomputing it every time. An alternative is to use a fast kernel to generate the convolution kernel, as in the [M2 repo](https://github.com/HazyResearch/m2/tree/main/csrc/flashmm). 
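
To make the wiring concrete, here is a minimal sketch of a toy mixer that uses a single shared `FlashFFTConv` instance injected from the backbone, as described above. The mixer class and the `conv(x, k)` call signature (fp16/bf16 input of shape `(B, H, L)`, fp32 filter of shape `(H, L)`) are assumptions for illustration and may differ from the actual layer code in [src/models/sequence/hyena-flashfft.py](src/models/sequence/hyena-flashfft.py).

```Python
import torch
import torch.nn as nn
from flashfftconv import FlashFFTConv

class ToyMixer(nn.Module):
    """Hypothetical stand-in for the Hyena sequence mixer (not the real layer)."""
    def __init__(self, d_model, l_max):
        super().__init__()
        self.k = nn.Parameter(torch.randn(d_model, l_max))  # long-conv filter
        self.flashfftconv = None                             # shared instance, injected below

    def forward(self, x):
        # x: (B, d_model, L) in fp16/bf16; the filter is kept in fp32
        # (assumed FlashFFTConv call signature: conv(x, k))
        return self.flashfftconv(x, self.k.float())

# Backbone-side wiring, mirroring the snippet above: one FlashFFTConv shared by all layers.
d_model, l_max = 864, 4096   # illustrative values from the pile/hyena configs below
layers = nn.ModuleList([ToyMixer(d_model, l_max) for _ in range(2)])
flashfftconv = FlashFFTConv(l_max * 2, dtype=torch.float16)  # 2x the sequence length for padding
for layer in layers:
    layer.flashfftconv = flashfftconv
```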
-------------------------------------------------------------------------------- /examples/hyena/configs/callbacks/base.yaml: -------------------------------------------------------------------------------- 1 | learning_rate_monitor: 2 | # _target_: pytorch_lightning.callbacks.LearningRateMonitor 3 | logging_interval: ${train.interval} 4 | 5 | timer: 6 | # _target_: callbacks.timer.Timer 7 | step: True 8 | inter_step: False 9 | epoch: True 10 | val: True 11 | 12 | params: 13 | # _target_: callbacks.params.ParamsLog 14 | total: True 15 | trainable: True 16 | fixed: True 17 | -------------------------------------------------------------------------------- /examples/hyena/configs/dataset/thepile.yaml: -------------------------------------------------------------------------------- 1 | _name_: the_pile 2 | dataset_name: the_pile 3 | dataset_config_name: null 4 | tokenizer_name: gpt2 5 | cache_dir: /data/the_pile/cache 6 | max_length: 2048 7 | add_eos: True 8 | batch_size: 4 # per GPU 9 | batch_size_eval: ${eval:${.batch_size} * 2} 10 | num_workers: 64 # For preprocessing only 11 | use_shmem: False 12 | shuffle: True 13 | pin_memory: True 14 | __train_len: ${div_up:374337375694, ${.max_length}} 15 | __l_max: ${.max_length} -------------------------------------------------------------------------------- /examples/hyena/configs/experiment/base.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - /pipeline: mnist 4 | - /model: long-conv 5 | 6 | # This file is a bare bones config for an experiment for illustration, consisting of a pipeline and model backbone 7 | -------------------------------------------------------------------------------- /examples/hyena/configs/experiment/pile/base.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - /pipeline: thepile 4 | - override /scheduler: cosine_warmup_timm 5 | 6 | trainer: 7 | accelerator: gpu 8 | devices: 8 9 | num_nodes: 1 10 | accumulate_grad_batches: ${div_up:${train.global_batch_size}, ${eval:${trainer.devices} * ${dataset.batch_size} * ${trainer.num_nodes}}} 11 | max_steps: 800000 12 | val_check_interval: ${eval:2000 * ${.accumulate_grad_batches}} 13 | check_val_every_n_epoch: null # We don't care about epoch boundary 14 | precision: bf16 15 | gradient_clip_val: 1.0 16 | strategy: null 17 | 18 | dataset: 19 | batch_size: 16 20 | max_length: 2048 21 | 22 | scheduler: 23 | t_in_epochs: False 24 | t_initial: 600000 25 | warmup_lr_init: 1e-6 26 | warmup_t: ${eval:${trainer.max_steps} * 0.01} 27 | lr_min: ${eval:0.1 * ${optimizer.lr}} 28 | 29 | optimizer: 30 | lr: 6e-4 31 | weight_decay: 0.1 32 | 33 | train: 34 | gpu_mem: ${eval:"round(float(__import__('subprocess').check_output('nvidia-smi -i 0 --query-gpu=memory.total --format=csv,noheader,nounits', shell=True).strip().decode()) / 1000)"} 35 | seed: 2222 36 | global_batch_size: 256 37 | 38 | eval: 39 | log_on_step: True # don't wait to the end of the epoch to log -------------------------------------------------------------------------------- /examples/hyena/configs/experiment/pile/hyena-flashfft.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - /experiment/pile/base.yaml 4 | 5 | dataset: 6 | max_length: 4096 7 | 8 | model: 9 | _name_: lm 10 | d_model: 864 11 | n_layer: 18 12 | d_inner: ${eval:2*${.d_model}} 13 | vocab_size: 50257 14 | resid_dropout: 0.0 15 | 
embed_dropout: 0.1 16 | layer: 17 | _name_: hyena-flashfft 18 | emb_dim: 33 19 | filter_order: 64 20 | local_order: 3 21 | l_max: ${dataset.max_length} 22 | modulate: True 23 | w: 14 24 | lr: ${optimizer.lr} 25 | lr_pos_emb: ${optimizer.lr} 26 | fused_mlp: True 27 | residual_in_fp32: True 28 | pad_vocab_size_multiple: 8 29 | use_flashfftconv: True 30 | 31 | batch_size: 8 32 | dtype: half 33 | -------------------------------------------------------------------------------- /examples/hyena/configs/experiment/pile/hyena.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - /experiment/pile/base.yaml 4 | 5 | dataset: 6 | max_length: 4096 7 | 8 | model: 9 | _name_: lm 10 | d_model: 864 11 | n_layer: 18 12 | d_inner: ${eval:2*${.d_model}} 13 | vocab_size: 50257 14 | resid_dropout: 0.0 15 | embed_dropout: 0.1 16 | layer: 17 | _name_: hyena 18 | emb_dim: 33 19 | filter_order: 64 20 | local_order: 3 21 | l_max: ${dataset.max_length} 22 | modulate: True 23 | w: 14 24 | lr: ${optimizer.lr} 25 | lr_pos_emb: ${optimizer.lr} 26 | residual_in_fp32: True 27 | pad_vocab_size_multiple: 8 28 | 29 | batch_size: 8 30 | dtype: half 31 | -------------------------------------------------------------------------------- /examples/hyena/configs/loader/default.yaml: -------------------------------------------------------------------------------- 1 | batch_size: 50 2 | num_workers: 4 3 | pin_memory: True 4 | drop_last: True # We set this to true because of the recurrent state mechanism -------------------------------------------------------------------------------- /examples/hyena/configs/model/base.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - layer: long-conv 3 | 4 | _name_: model 5 | prenorm: true 6 | transposed: false 7 | n_layers: 4 8 | d_model: 256 9 | residual: R 10 | pool: 11 | _name_: pool 12 | stride: 1 13 | expand: null 14 | norm: layer 15 | dropout: 0.0 16 | tie_dropout: false 17 | track_norms: true # Logs to wandb 18 | 19 | # Optional encoder/decoder, e.g. 
add positional embeddings or padding masks 20 | encoder: null 21 | decoder: null 22 | -------------------------------------------------------------------------------- /examples/hyena/configs/model/layer/ff.yaml: -------------------------------------------------------------------------------- 1 | _name_: ff 2 | expand: 4 3 | dropout: null 4 | transposed: False 5 | dropout: 0.0 6 | tie_dropout: ${model.tie_dropout,null} 7 | -------------------------------------------------------------------------------- /examples/hyena/configs/model/layer/h3-conv.yaml: -------------------------------------------------------------------------------- 1 | _name_: h3-conv 2 | head_dim: 1 3 | learning_rate: ${eval:"min(0.001, ${optimizer.lr})"} 4 | kernel_dropout: 0.2 5 | lam: 0.003 -------------------------------------------------------------------------------- /examples/hyena/configs/model/layer/h3.yaml: -------------------------------------------------------------------------------- 1 | _name_: h3 2 | d_state: 64 3 | head_dim: 1 4 | mode: diag 5 | measure: diag-lin 6 | lr: ${eval:"min(0.001, ${optimizer.lr})"} -------------------------------------------------------------------------------- /examples/hyena/configs/model/layer/hyena-filter.yaml: -------------------------------------------------------------------------------- 1 | _name_: hyena-filter 2 | emb_dim: 3 # dim of input to MLP, augments with positional encoding 3 | order: 16 # width of the implicit MLP 4 | fused_fft_conv: false 5 | # seq_len: ${dataset.__l_max} 6 | lr: 1e-3 7 | lr_pos_emb: 1e-5 8 | dropout: 0.0 9 | w: 1 # frequency of periodic activations 10 | wd: 0 # weight decay of kernel parameters 11 | bias: true 12 | normalized: False 13 | num_inner_mlps: 2 -------------------------------------------------------------------------------- /examples/hyena/configs/model/layer/hyena-flashfft.yaml: -------------------------------------------------------------------------------- 1 | _name_: hyena-flashfft 2 | l_max: 1024 3 | order: 2 4 | filter_order: 64 5 | num_heads: 1 6 | inner_factor: 1 7 | num_blocks: 1 8 | fused_bias_fc: false 9 | outer_mixing: false 10 | dropout: 0.0 11 | filter_dropout: 0.0 12 | filter_cls: 'hyena-filter' 13 | post_order_ffn: false 14 | jit_filter: false 15 | short_filter_order: 3 16 | activation: "id" 17 | inference_mode: False -------------------------------------------------------------------------------- /examples/hyena/configs/model/layer/hyena.yaml: -------------------------------------------------------------------------------- 1 | _name_: hyena 2 | l_max: 1024 3 | order: 2 4 | filter_order: 64 5 | num_heads: 1 6 | inner_factor: 1 7 | num_blocks: 1 8 | fused_bias_fc: false 9 | outer_mixing: false 10 | dropout: 0.0 11 | filter_dropout: 0.0 12 | filter_cls: 'hyena-filter' 13 | post_order_ffn: false 14 | jit_filter: false 15 | short_filter_order: 3 16 | activation: "id" -------------------------------------------------------------------------------- /examples/hyena/configs/model/layer/id.yaml: -------------------------------------------------------------------------------- 1 | _name_: id 2 | -------------------------------------------------------------------------------- /examples/hyena/configs/model/layer/long-conv.yaml: -------------------------------------------------------------------------------- 1 | _name_: long-conv 2 | channels: 2 3 | causal: True 4 | lam: 0.003 5 | kernel_dropout: 0.2 6 | bidirectional: false 7 | activation: gelu 8 | postact: glu 9 | initializer: null 10 | weight_norm: false 11 | # dropout: 
${model.dropout} # Same as null 12 | tie_dropout: ${oc.select:model.tie_dropout,null} 13 | l_max: ${oc.select:dataset.__l_max,null} # Grab dataset length if exists, otherwise set to 1024 14 | verbose: true 15 | -------------------------------------------------------------------------------- /examples/hyena/configs/model/layer/mha.yaml: -------------------------------------------------------------------------------- 1 | _name_: mha 2 | causal: true 3 | n_heads: 8 4 | dropout: null 5 | bias: True 6 | add_bias_kv: False 7 | add_zero_attn: False 8 | kdim: null 9 | vdim: null 10 | -------------------------------------------------------------------------------- /examples/hyena/configs/model/layer/s4_simple.yaml: -------------------------------------------------------------------------------- 1 | _name_: s4_simple 2 | d_state: 64 3 | channels: 1 4 | bidirectional: false 5 | activation: gelu 6 | postact: null 7 | initializer: null 8 | weight_norm: false 9 | dropout: ${..dropout} # Same as null 10 | dt_min: 0.001 11 | dt_max: 0.1 12 | lr: 0.001 13 | learn_a: true 14 | learn_theta: true 15 | theta_scale: false 16 | learn_dt: false 17 | zero_order_hold: false 18 | use_initial: false 19 | trap_rule: false 20 | linear: false 21 | skip_connection: true 22 | repr: cont 23 | param_norm: 'none' -------------------------------------------------------------------------------- /examples/hyena/configs/model/layer/s4d.yaml: -------------------------------------------------------------------------------- 1 | _name_: s4d 2 | d_state: 64 3 | lr: ${eval:"min(0.001, ${optimizer.lr})"} -------------------------------------------------------------------------------- /examples/hyena/configs/model/layer/transformer.yaml: -------------------------------------------------------------------------------- 1 | - _name_: mha 2 | causal: true 3 | n_heads: 8 4 | dropout: null 5 | bias: True 6 | add_bias_kv: False 7 | add_zero_attn: False 8 | kdim: null 9 | vdim: null 10 | - _name_: ff 11 | expand: 4 12 | activation: gelu 13 | dropout: ${...dropout} # Same as null 14 | -------------------------------------------------------------------------------- /examples/hyena/configs/model/layer/vit.yaml: -------------------------------------------------------------------------------- 1 | _name_: vit 2 | num_heads: 8 3 | qkv_bias: False 4 | qk_scale: null 5 | attn_drop: 0.0 6 | packed_linear: true 7 | linear_cfg: null 8 | -------------------------------------------------------------------------------- /examples/hyena/configs/model/long-conv.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - base 3 | - override layer: long-conv 4 | 5 | transposed: false # Actually faster than "true" 6 | tie_dropout: true 7 | -------------------------------------------------------------------------------- /examples/hyena/configs/optimizer/adam.yaml: -------------------------------------------------------------------------------- 1 | # _target_: torch.optim.Adam 2 | _name_: adam 3 | lr: 0.001 # Initial learning rate 4 | # weight_decay: 0.0 # Weight decay for adam|lamb; should use AdamW instead if desired 5 | betas: [0.9, 0.999] 6 | -------------------------------------------------------------------------------- /examples/hyena/configs/optimizer/adamw.yaml: -------------------------------------------------------------------------------- 1 | # _target_: torch.optim.AdamW 2 | _name_: adamw 3 | lr: 0.001 # Initial learning rate 4 | weight_decay: 0.00 # Weight decay 5 | betas: [0.9, 0.999] 6 | 
-------------------------------------------------------------------------------- /examples/hyena/configs/optimizer/lamb.yaml: -------------------------------------------------------------------------------- 1 | # _target_: utils.lamb.JITLamb 2 | _name_: lamb 3 | lr: 0.01 # Initial learning rate 4 | weight_decay: 0.0 # Weight decay for adam|lamb 5 | -------------------------------------------------------------------------------- /examples/hyena/configs/optimizer/sgd.yaml: -------------------------------------------------------------------------------- 1 | # _target_: torch.optim.SGD 2 | _name_: sgd 3 | lr: 0.001 # Initial learning rate 4 | momentum: 0.9 5 | weight_decay: 0.0 # Weight decay for adam|lamb 6 | -------------------------------------------------------------------------------- /examples/hyena/configs/pipeline/thepile.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - /trainer: default 4 | - /loader: default 5 | - /dataset: thepile 6 | - /optimizer: adamw 7 | - /scheduler: cosine_warmup 8 | - /callbacks: [base, checkpoint] 9 | 10 | train: 11 | monitor: val/loss 12 | mode: min 13 | 14 | task: 15 | _name_: lm 16 | loss: cross_entropy 17 | torchmetrics: ['perplexity', 'num_tokens'] 18 | 19 | encoder: null 20 | decoder: null 21 | -------------------------------------------------------------------------------- /examples/hyena/configs/scheduler/cosine_warmup.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | train: 3 | interval: step 4 | scheduler: 5 | # _target_: transformers.get_cosine_schedule_with_warmup 6 | _name_: cosine_warmup 7 | num_warmup_steps: 1000 8 | num_training_steps: 40000 9 | -------------------------------------------------------------------------------- /examples/hyena/configs/scheduler/cosine_warmup_timm.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | train: 3 | interval: step 4 | scheduler: 5 | # _target_: transformers.get_cosine_schedule_with_warmup 6 | _name_: cosine_warmup_timm 7 | t_in_epochs: False 8 | t_initial: 300 9 | lr_min: 1e-5 10 | warmup_lr_init: 1e-6 11 | warmup_t: 10 12 | -------------------------------------------------------------------------------- /examples/hyena/configs/task/lm.yaml: -------------------------------------------------------------------------------- 1 | _name_: lm 2 | # loss: cross_entropy # Handled by task: cross entropy loss 3 | metrics: ppl 4 | -------------------------------------------------------------------------------- /examples/hyena/configs/trainer/debug.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - default 3 | 4 | gpus: 1 5 | min_epochs: 1 6 | max_epochs: 10 7 | 8 | # prints 9 | progress_bar_refresh_rate: null 10 | weights_summary: full 11 | profiler: null 12 | 13 | # debugs 14 | fast_dev_run: False 15 | num_sanity_val_steps: 2 16 | overfit_batches: 0 17 | limit_train_batches: 0.1 18 | limit_val_batches: 0.1 19 | limit_test_batches: 0.1 20 | track_grad_norm: -1 21 | terminate_on_nan: False 22 | -------------------------------------------------------------------------------- /examples/hyena/configs/trainer/default.yaml: -------------------------------------------------------------------------------- 1 | _target_: pytorch_lightning.Trainer 2 | 3 | devices: 1 4 | accelerator: gpu 5 | accumulate_grad_batches: 1 # Gradient accumulation every n batches 6 | 
max_epochs: 200 7 | # accelerator: ddp # Automatically set if gpus > 1 8 | gradient_clip_val: 0.0 9 | log_every_n_steps: 10 10 | limit_train_batches: 1.0 # train on full dataset, can be used to toggle quick run 11 | limit_val_batches: 1.0 # validate on full dataset, can be used to toggle quick run 12 | -------------------------------------------------------------------------------- /examples/hyena/configs/trainer/full.yaml: -------------------------------------------------------------------------------- 1 | _target_: pytorch_lightning.Trainer 2 | 3 | # default values for all trainer parameters 4 | checkpoint_callback: True 5 | default_root_dir: null 6 | gradient_clip_val: 0.0 7 | process_position: 0 8 | num_nodes: 1 9 | num_processes: 1 10 | gpus: null 11 | auto_select_gpus: False 12 | tpu_cores: null 13 | log_gpu_memory: null 14 | overfit_batches: 0.0 15 | track_grad_norm: -1 16 | check_val_every_n_epoch: 1 17 | fast_dev_run: False 18 | accumulate_grad_batches: 1 19 | max_epochs: 1 20 | min_epochs: 1 21 | max_steps: null 22 | min_steps: null 23 | limit_train_batches: 1.0 24 | limit_val_batches: 1.0 25 | limit_test_batches: 1.0 26 | val_check_interval: 1.0 27 | flush_logs_every_n_steps: 100 28 | log_every_n_steps: 50 29 | accelerator: null 30 | sync_batchnorm: False 31 | precision: 32 32 | weights_summary: "top" 33 | weights_save_path: null 34 | num_sanity_val_steps: 2 35 | truncated_bptt_steps: null 36 | resume_from_checkpoint: null 37 | profiler: null 38 | benchmark: False 39 | deterministic: False 40 | reload_dataloaders_every_epoch: False 41 | auto_lr_find: False 42 | replace_sampler_ddp: True 43 | terminate_on_nan: False 44 | auto_scale_batch_size: False 45 | prepare_data_per_node: True 46 | plugins: null 47 | amp_backend: "native" 48 | amp_level: "O2" 49 | move_metrics_to_cpu: False 50 | -------------------------------------------------------------------------------- /examples/hyena/configs/trainer/lm.yaml: -------------------------------------------------------------------------------- 1 | accumulate_grad_batches: 1 2 | # accelerator: null # set to 'ddp' for distributed 3 | # amp_backend: native # 'native' | 'apex' 4 | gpus: 8 5 | max_epochs: 50 6 | gradient_clip_val: 0.0 # Gradient clipping 7 | log_every_n_steps: 10 8 | precision: 16 9 | progress_bar_refresh_rate: 1 10 | weights_summary: top # Set to 'full' to see every layer 11 | track_grad_norm: -1 # Set to 2 to track norms of gradients 12 | limit_train_batches: 1.0 13 | limit_val_batches: 1.0 14 | # We use the dataloader from Transformer-XL to ensure adjacent minibatches 15 | # are from text that are next to each other. 16 | # So that dataloader has to deal with DDP, and we don't want PL to handle 17 | # that. 
18 | replace_sampler_ddp: False 19 | -------------------------------------------------------------------------------- /examples/hyena/flash-attention/.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | 5 | # C extensions 6 | *.so 7 | 8 | # Distribution / packaging 9 | bin/ 10 | build/ 11 | develop-eggs/ 12 | dist/ 13 | eggs/ 14 | lib/ 15 | lib64/ 16 | parts/ 17 | sdist/ 18 | var/ 19 | *.egg-info/ 20 | .installed.cfg 21 | *.egg 22 | -------------------------------------------------------------------------------- /examples/hyena/flash-attention/.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "csrc/flash_attn/cutlass"] 2 | path = csrc/flash_attn/cutlass 3 | url = https://github.com/NVIDIA/cutlass.git 4 | -------------------------------------------------------------------------------- /examples/hyena/flash-attention/AUTHORS: -------------------------------------------------------------------------------- 1 | Tri Dao, trid@stanford.edu 2 | Dan Fu, danfu@cs.stanford.edu -------------------------------------------------------------------------------- /examples/hyena/flash-attention/LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2022, the respective contributors, as shown by the AUTHORS file. 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | * Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | * Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | * Neither the name of the copyright holder nor the names of its 17 | contributors may be used to endorse or promote products derived from 18 | this software without specific prior written permission. 19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
30 | -------------------------------------------------------------------------------- /examples/hyena/flash-attention/MANIFEST.in: -------------------------------------------------------------------------------- 1 | recursive-include csrc *.cu 2 | recursive-include csrc *.h 3 | recursive-include csrc *.cuh 4 | recursive-include csrc *.cpp 5 | 6 | recursive-include flash_attn *.cu 7 | recursive-include flash_attn *.h 8 | recursive-include flash_attn *.cuh 9 | recursive-include flash_attn *.cpp 10 | -------------------------------------------------------------------------------- /examples/hyena/flash-attention/Makefile: -------------------------------------------------------------------------------- 1 | 2 | clean_dist: 3 | rm -rf dist/* 4 | 5 | create_dist: clean_dist 6 | python setup.py sdist 7 | 8 | upload_package: create_dist 9 | twine upload dist/* 10 | -------------------------------------------------------------------------------- /examples/hyena/flash-attention/assets/flashattn_banner.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HazyResearch/flash-fft-conv/b8771028717f46d5b22cbb8e12833f35033d621b/examples/hyena/flash-attention/assets/flashattn_banner.jpg -------------------------------------------------------------------------------- /examples/hyena/flash-attention/assets/flashattn_banner.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HazyResearch/flash-fft-conv/b8771028717f46d5b22cbb8e12833f35033d621b/examples/hyena/flash-attention/assets/flashattn_banner.pdf -------------------------------------------------------------------------------- /examples/hyena/flash-attention/assets/flashattn_memory.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HazyResearch/flash-fft-conv/b8771028717f46d5b22cbb8e12833f35033d621b/examples/hyena/flash-attention/assets/flashattn_memory.jpg -------------------------------------------------------------------------------- /examples/hyena/flash-attention/assets/flashattn_speedup.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HazyResearch/flash-fft-conv/b8771028717f46d5b22cbb8e12833f35033d621b/examples/hyena/flash-attention/assets/flashattn_speedup.jpg -------------------------------------------------------------------------------- /examples/hyena/flash-attention/assets/flashattn_speedup_3090.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HazyResearch/flash-fft-conv/b8771028717f46d5b22cbb8e12833f35033d621b/examples/hyena/flash-attention/assets/flashattn_speedup_3090.jpg -------------------------------------------------------------------------------- /examples/hyena/flash-attention/assets/flashattn_speedup_a100_d128.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HazyResearch/flash-fft-conv/b8771028717f46d5b22cbb8e12833f35033d621b/examples/hyena/flash-attention/assets/flashattn_speedup_a100_d128.jpg -------------------------------------------------------------------------------- /examples/hyena/flash-attention/assets/flashattn_speedup_t4.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/HazyResearch/flash-fft-conv/b8771028717f46d5b22cbb8e12833f35033d621b/examples/hyena/flash-attention/assets/flashattn_speedup_t4.jpg -------------------------------------------------------------------------------- /examples/hyena/flash-attention/assets/flashattn_speedup_t4_fwd.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HazyResearch/flash-fft-conv/b8771028717f46d5b22cbb8e12833f35033d621b/examples/hyena/flash-attention/assets/flashattn_speedup_t4_fwd.jpg -------------------------------------------------------------------------------- /examples/hyena/flash-attention/assets/gpt2_training_curve.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HazyResearch/flash-fft-conv/b8771028717f46d5b22cbb8e12833f35033d621b/examples/hyena/flash-attention/assets/gpt2_training_curve.jpg -------------------------------------------------------------------------------- /examples/hyena/flash-attention/assets/gpt2_training_efficiency.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HazyResearch/flash-fft-conv/b8771028717f46d5b22cbb8e12833f35033d621b/examples/hyena/flash-attention/assets/gpt2_training_efficiency.jpg -------------------------------------------------------------------------------- /examples/hyena/flash-attention/assets/gpt3_training_curve.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HazyResearch/flash-fft-conv/b8771028717f46d5b22cbb8e12833f35033d621b/examples/hyena/flash-attention/assets/gpt3_training_curve.jpg -------------------------------------------------------------------------------- /examples/hyena/flash-attention/assets/gpt3_training_efficiency.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HazyResearch/flash-fft-conv/b8771028717f46d5b22cbb8e12833f35033d621b/examples/hyena/flash-attention/assets/gpt3_training_efficiency.jpg -------------------------------------------------------------------------------- /examples/hyena/flash-attention/csrc/flash_attn/src/fmha_blockmask.h: -------------------------------------------------------------------------------- 1 | /****************************************************************************** 2 | * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved. 3 | * 4 | * Redistribution and use in source and binary forms, with or without 5 | * modification, are permitted provided that the following conditions are met: 6 | * * Redistributions of source code must retain the above copyright 7 | * notice, this list of conditions and the following disclaimer. 8 | * * Redistributions in binary form must reproduce the above copyright 9 | * notice, this list of conditions and the following disclaimer in the 10 | * documentation and/or other materials provided with the distribution. 11 | * * Neither the name of the NVIDIA CORPORATION nor the 12 | * names of its contributors may be used to endorse or promote products 13 | * derived from this software without specific prior written permission. 14 | * 15 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 16 | * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 17 | * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 18 | * DISCLAIMED. 
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY 19 | * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 20 | * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 21 | * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND 22 | * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 23 | * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 24 | * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 25 | * 26 | ******************************************************************************/ 27 | 28 | #pragma once 29 | 30 | #include 31 | #include 32 | #include 33 | #include 34 | #include 35 | #include 36 | 37 | namespace fmha { 38 | 39 | //////////////////////////////////////////////////////////////////////////////////////////////////// 40 | 41 | struct Blockmask { 42 | 43 | template 44 | __device__ Blockmask(const Params ¶ms, int loop_step_idx) : 45 | blockmask_ptr(params.blockmask + loop_step_idx * params.seqlen_q / 16) { 46 | } 47 | 48 | __device__ int mask_val(int block_row_idx) const { 49 | return blockmask_ptr[block_row_idx]; 50 | } 51 | 52 | const int *blockmask_ptr; 53 | }; 54 | 55 | //////////////////////////////////////////////////////////////////////////////////////////////////// 56 | 57 | } // namespace fmha 58 | -------------------------------------------------------------------------------- /examples/hyena/flash-attention/csrc/flash_attn/src/fmha_bwd_hdim128.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2022, Tri Dao. 2 | 3 | // Splitting the different head dimensions to different files to speed up compilation. 4 | 5 | #include "fmha_bwd_launch_template.h" 6 | 7 | void run_fmha_bwd_hdim128(FMHA_dgrad_params ¶ms, cudaStream_t stream, const bool configure) { 8 | FP16_SWITCH(params.is_bf16, ([&] { 9 | using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 8, 0x100u, elem_type>; 10 | run_fmha_bwd_loop(params, stream, configure); 11 | })); 12 | } -------------------------------------------------------------------------------- /examples/hyena/flash-attention/csrc/flash_attn/src/fmha_bwd_hdim32.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2022, Tri Dao. 2 | 3 | // Splitting the different head dimensions to different files to speed up compilation. 4 | 5 | #include "fmha_bwd_launch_template.h" 6 | 7 | void run_fmha_bwd_hdim32(FMHA_dgrad_params ¶ms, cudaStream_t stream, const bool configure) { 8 | FP16_SWITCH(params.is_bf16, ([&] { 9 | if (params.seqlen_k == 128) { 10 | using Kernel_traits = FMHA_kernel_traits<128, 32, 16, 1, 8, 0x08u, elem_type>; 11 | run_fmha_bwd_loop(params, stream, configure); 12 | } else if (params.seqlen_k >= 256) { 13 | using Kernel_traits = FMHA_kernel_traits<256, 32, 16, 1, 8, 0x08u, elem_type>; 14 | run_fmha_bwd_loop(params, stream, configure); 15 | } 16 | })); 17 | } -------------------------------------------------------------------------------- /examples/hyena/flash-attention/csrc/flash_attn/src/fmha_bwd_hdim64.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2022, Tri Dao. 2 | 3 | // Splitting the different head dimensions to different files to speed up compilation. 
4 | 5 | #include "fmha_bwd_launch_template.h" 6 | 7 | void run_fmha_bwd_hdim64(FMHA_dgrad_params ¶ms, cudaStream_t stream, const bool configure) { 8 | FP16_SWITCH(params.is_bf16, ([&] { 9 | auto dprops = at::cuda::getCurrentDeviceProperties(); 10 | if (params.seqlen_k == 128) { 11 | using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 8, 0x08u, elem_type>; 12 | run_fmha_bwd_loop(params, stream, configure); 13 | } else if (params.seqlen_k >= 256) { 14 | if (dprops->major == 8 && dprops->minor == 0) { 15 | // Don't share smem for K & V, and don't keep V in registers 16 | // This speeds things up by 2-3% by avoiding register spills, but it 17 | // uses more shared memory, which is fine on A100 but not other GPUs. 18 | // For other GPUs, we keep V in registers. 19 | using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 8, 0x100u, elem_type>; 20 | run_fmha_bwd_loop(params, stream, configure); 21 | } else if (dprops->major == 8 && dprops->minor > 0) { 22 | using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 8, 0x08u, elem_type>; 23 | run_fmha_bwd_loop(params, stream, configure); 24 | } else if (dprops->major == 7 && dprops->minor == 5) { 25 | using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 8, 0x08u, elem_type>; 26 | run_fmha_bwd_loop(params, stream, configure); 27 | } 28 | } 29 | })); 30 | } -------------------------------------------------------------------------------- /examples/hyena/flash-attention/csrc/flash_attn/src/fmha_fwd_hdim128.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2022, Tri Dao. 2 | 3 | // Splitting the different head dimensions to different files to speed up compilation. 4 | 5 | #include "fmha_fwd_launch_template.h" 6 | 7 | void run_fmha_fwd_hdim128(Launch_params &launch_params) { 8 | FP16_SWITCH(launch_params.params.is_bf16, ([&] { 9 | using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 4, 0x08u, elem_type>; 10 | run_fmha_fwd_loop(launch_params); 11 | })); 12 | } -------------------------------------------------------------------------------- /examples/hyena/flash-attention/csrc/flash_attn/src/fmha_fwd_hdim32.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2022, Tri Dao. 2 | 3 | // Splitting the different head dimensions to different files to speed up compilation. 4 | 5 | #include "fmha_fwd_launch_template.h" 6 | 7 | void run_fmha_fwd_hdim32(Launch_params &launch_params) { 8 | FP16_SWITCH(launch_params.params.is_bf16, ([&] { 9 | if (launch_params.params.seqlen_k == 128) { 10 | using Kernel_traits = FMHA_kernel_traits<128, 32, 16, 1, 4, 0x08u, elem_type>; 11 | run_fmha_fwd_loop(launch_params); 12 | } else if (launch_params.params.seqlen_k >= 256) { 13 | using Kernel_traits = FMHA_kernel_traits<256, 32, 16, 1, 4, 0x08u, elem_type>; 14 | run_fmha_fwd_loop(launch_params); 15 | } 16 | })); 17 | } -------------------------------------------------------------------------------- /examples/hyena/flash-attention/csrc/flash_attn/src/fmha_fwd_hdim64.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2022, Tri Dao. 2 | 3 | // Splitting the different head dimensions to different files to speed up compilation. 
4 | 5 | #include "fmha_fwd_launch_template.h" 6 | 7 | void run_fmha_fwd_hdim64(Launch_params &launch_params) { 8 | FP16_SWITCH(launch_params.params.is_bf16, ([&] { 9 | if (launch_params.params.seqlen_k == 128) { 10 | using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u, elem_type>; 11 | run_fmha_fwd_loop(launch_params); 12 | } else if (launch_params.params.seqlen_k >= 256) { 13 | using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 4, 0x08u, elem_type>; 14 | run_fmha_fwd_loop(launch_params); 15 | } 16 | })); 17 | } 18 | -------------------------------------------------------------------------------- /examples/hyena/flash-attention/csrc/flash_attn/src/static_switch.h: -------------------------------------------------------------------------------- 1 | // Inspired by https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h 2 | // and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h 3 | // and https://github.com/facebookresearch/xformers/blob/main/xformers/csrc/attention/cuda/fmha/gemm_kernel_utils.h#L8 4 | 5 | #pragma once 6 | 7 | /// @param COND - a boolean expression to switch by 8 | /// @param CONST_NAME - a name given for the constexpr bool variable. 9 | /// @param ... - code to execute for true and false 10 | /// 11 | /// Usage: 12 | /// ``` 13 | /// BOOL_SWITCH(flag, BoolConst, ([&] { 14 | /// some_function(...); 15 | /// })); 16 | /// ``` 17 | /// We need "({" and "})" to make sure that the code is a single argument being passed to the macro. 18 | #define BOOL_SWITCH(COND, CONST_NAME, F) \ 19 | { \ 20 | if (COND) { \ 21 | constexpr bool CONST_NAME = true; \ 22 | F(); \ 23 | } else { \ 24 | constexpr bool CONST_NAME = false; \ 25 | F(); \ 26 | } \ 27 | } 28 | 29 | // modified from BOOL_SWITCH 30 | // because MSVC cannot handle std::conditional with constexpr variable 31 | #define FP16_SWITCH(COND, F) \ 32 | { \ 33 | if (COND) { \ 34 | using elem_type = __nv_bfloat16; \ 35 | F(); \ 36 | } else { \ 37 | using elem_type = __half; \ 38 | F(); \ 39 | } \ 40 | } 41 | -------------------------------------------------------------------------------- /examples/hyena/flash-attention/csrc/ft_attention/README.md: -------------------------------------------------------------------------------- 1 | # Attention kernel from FasterTransformer 2 | 3 | This CUDA extension wraps the single-query attention [kernel](https://github.com/NVIDIA/FasterTransformer/blob/release/v5.2.1_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp) from 4 | FasterTransformer v5.2.1 for benchmarking purpose. 5 | 6 | ```sh 7 | cd csrc/ft_attention && pip install . 8 | ``` 9 | -------------------------------------------------------------------------------- /examples/hyena/flash-attention/csrc/ft_attention/cuda_bf16_wrapper.h: -------------------------------------------------------------------------------- 1 | // Downloaded from from FasterTransformer v5.2.1 2 | // https://github.com/NVIDIA/FasterTransformer/blob/release/v5.2.1_tag/src/fastertransformer/utils/cuda_bf16_wrapper.h 3 | /* 4 | * Copyright (c) 2019-2022, NVIDIA CORPORATION. All rights reserved. 5 | * 6 | * Licensed under the Apache License, Version 2.0 (the "License"); 7 | * you may not use this file except in compliance with the License. 
8 | * You may obtain a copy of the License at 9 | * 10 | * http://www.apache.org/licenses/LICENSE-2.0 11 | * 12 | * Unless required by applicable law or agreed to in writing, software 13 | * distributed under the License is distributed on an "AS IS" BASIS, 14 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | * See the License for the specific language governing permissions and 16 | * limitations under the License. 17 | */ 18 | 19 | #pragma once 20 | 21 | #ifdef ENABLE_BF16 22 | #include 23 | #endif 24 | -------------------------------------------------------------------------------- /examples/hyena/flash-attention/csrc/fused_dense_lib/README.md: -------------------------------------------------------------------------------- 1 | This CUDA extension implements fused matmul + bias (forward and backward), and fused matmul + bias + gelu 2 | (forward and backward), adapted from Apex's 3 | [FusedDense](https://github.com/NVIDIA/apex/tree/master/apex/fused_dense). 4 | We make it work for bfloat16. 5 | 6 | For best performance, you should use CUDA >= 11.8. CuBLAS versions before 7 | this doesn't have the best matmul + bias + gelu performance for bfloat16. 8 | 9 | It has only been tested on A100s. 10 | 11 | ```sh 12 | cd csrc/fused_dense_lib && pip install . 13 | ``` 14 | -------------------------------------------------------------------------------- /examples/hyena/flash-attention/csrc/fused_dense_lib/setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | import subprocess 3 | 4 | import torch 5 | from setuptools import setup 6 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME 7 | 8 | 9 | def get_cuda_bare_metal_version(cuda_dir): 10 | raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True) 11 | output = raw_output.split() 12 | release_idx = output.index("release") + 1 13 | release = output[release_idx].split(".") 14 | bare_metal_major = release[0] 15 | bare_metal_minor = release[1][0] 16 | 17 | return raw_output, bare_metal_major, bare_metal_minor 18 | 19 | 20 | def append_nvcc_threads(nvcc_extra_args): 21 | _, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(CUDA_HOME) 22 | if int(bare_metal_major) >= 11 and int(bare_metal_minor) >= 2: 23 | return nvcc_extra_args + ["--threads", "4"] 24 | return nvcc_extra_args 25 | 26 | 27 | setup( 28 | name='fused_dense_lib', 29 | ext_modules=[ 30 | CUDAExtension( 31 | name='fused_dense_lib', 32 | sources=['fused_dense.cpp', 'fused_dense_cuda.cu'], 33 | extra_compile_args={ 34 | 'cxx': ['-O3',], 35 | 'nvcc': append_nvcc_threads(['-O3']) 36 | } 37 | ) 38 | ], 39 | cmdclass={ 40 | 'build_ext': BuildExtension 41 | }) 42 | 43 | -------------------------------------------------------------------------------- /examples/hyena/flash-attention/csrc/fused_softmax/setup.py: -------------------------------------------------------------------------------- 1 | # Copied from https://github.com/NVIDIA/apex/tree/master/csrc/megatron 2 | # We add the case where seqlen = 4k and seqlen = 8k 3 | import os 4 | import subprocess 5 | 6 | import torch 7 | from setuptools import setup 8 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME 9 | 10 | 11 | def get_cuda_bare_metal_version(cuda_dir): 12 | raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True) 13 | output = raw_output.split() 14 | release_idx = output.index("release") + 1 15 | release = 
output[release_idx].split(".") 16 | bare_metal_major = release[0] 17 | bare_metal_minor = release[1][0] 18 | 19 | return raw_output, bare_metal_major, bare_metal_minor 20 | 21 | 22 | def append_nvcc_threads(nvcc_extra_args): 23 | _, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(CUDA_HOME) 24 | if int(bare_metal_major) >= 11 and int(bare_metal_minor) >= 2: 25 | return nvcc_extra_args + ["--threads", "4"] 26 | return nvcc_extra_args 27 | 28 | 29 | cc_flag = [] 30 | cc_flag.append("-gencode") 31 | cc_flag.append("arch=compute_70,code=sm_70") 32 | cc_flag.append("-gencode") 33 | cc_flag.append("arch=compute_80,code=sm_80") 34 | 35 | setup( 36 | name='fused_softmax_lib', 37 | ext_modules=[ 38 | CUDAExtension( 39 | name='fused_softmax_lib', 40 | sources=['fused_softmax.cpp', 'scaled_masked_softmax_cuda.cu', 'scaled_upper_triang_masked_softmax_cuda.cu'], 41 | extra_compile_args={ 42 | 'cxx': ['-O3',], 43 | 'nvcc': append_nvcc_threads(['-O3', '--use_fast_math'] + cc_flag) 44 | } 45 | ) 46 | ], 47 | cmdclass={ 48 | 'build_ext': BuildExtension 49 | }) 50 | -------------------------------------------------------------------------------- /examples/hyena/flash-attention/csrc/fused_softmax/type_shim.h: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #define DISPATCH_HALF_AND_BFLOAT(TYPE, NAME, ...) \ 4 | switch(TYPE) \ 5 | { \ 6 | case at::ScalarType::Half: \ 7 | { \ 8 | using scalar_t = at::Half; \ 9 | __VA_ARGS__; \ 10 | break; \ 11 | } \ 12 | case at::ScalarType::BFloat16: \ 13 | { \ 14 | using scalar_t = at::BFloat16; \ 15 | __VA_ARGS__; \ 16 | break; \ 17 | } \ 18 | default: \ 19 | AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \ 20 | } 21 | -------------------------------------------------------------------------------- /examples/hyena/flash-attention/csrc/layer_norm/README.md: -------------------------------------------------------------------------------- 1 | This CUDA extension implements fused dropout + residual + LayerNorm, building on 2 | Apex's [FastLayerNorm](https://github.com/NVIDIA/apex/tree/master/apex/contrib/layer_norm). 3 | We add dropout and residual, and make it work for both pre-norm and post-norm architecture. 4 | We also make it work for more hidden dimensions (all dimensions divisible by 8, up to 6144). 5 | We also implement RMSNorm as an option. 6 | 7 | If you want to use it for dimensions larger than 6k, please file an issue. 8 | 9 | This extension has only been tested on A100s. 10 | 11 | ```sh 12 | cd csrc/layer_norm && pip install . 13 | ``` 14 | -------------------------------------------------------------------------------- /examples/hyena/flash-attention/csrc/layer_norm/ln_bwd_1024.cu: -------------------------------------------------------------------------------- 1 | #include "ln_bwd_kernels.cuh" 2 | 3 | // Create backward launch function and register. 
Macro signature: 4 | // HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL 5 | 6 | REGISTER_BWD_LAUNCHER( 1024, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4); 7 | REGISTER_BWD_LAUNCHER( 1024, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4); 8 | REGISTER_BWD_LAUNCHER( 1024, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4); 9 | REGISTER_BWD_LAUNCHER( 1024, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4); 10 | REGISTER_BWD_LAUNCHER( 1024, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4); 11 | REGISTER_BWD_LAUNCHER( 1024, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4); 12 | REGISTER_BWD_LAUNCHER( 1024, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4); 13 | REGISTER_BWD_LAUNCHER( 1024, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4); 14 | REGISTER_BWD_LAUNCHER( 1024, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4); 15 | REGISTER_BWD_LAUNCHER( 1024, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4); 16 | -------------------------------------------------------------------------------- /examples/hyena/flash-attention/csrc/layer_norm/ln_bwd_1280.cu: -------------------------------------------------------------------------------- 1 | #include "ln_bwd_kernels.cuh" 2 | 3 | // Create backward launch function and register. Macro signature: 4 | // HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL 5 | 6 | REGISTER_BWD_LAUNCHER( 1280, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4); 7 | REGISTER_BWD_LAUNCHER( 1280, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4); 8 | REGISTER_BWD_LAUNCHER( 1280, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4); 9 | REGISTER_BWD_LAUNCHER( 1280, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4); 10 | REGISTER_BWD_LAUNCHER( 1280, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4); 11 | REGISTER_BWD_LAUNCHER( 1280, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4); 12 | REGISTER_BWD_LAUNCHER( 1280, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4); 13 | REGISTER_BWD_LAUNCHER( 1280, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4); 14 | REGISTER_BWD_LAUNCHER( 1280, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4); 15 | REGISTER_BWD_LAUNCHER( 1280, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4); 16 | -------------------------------------------------------------------------------- /examples/hyena/flash-attention/csrc/layer_norm/ln_bwd_1536.cu: -------------------------------------------------------------------------------- 1 | #include "ln_bwd_kernels.cuh" 2 | 3 | // Create backward launch function and register. 
Macro signature: 4 | // HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL 5 | 6 | REGISTER_BWD_LAUNCHER( 1536, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4); 7 | REGISTER_BWD_LAUNCHER( 1536, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4); 8 | REGISTER_BWD_LAUNCHER( 1536, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 8, 4); 9 | REGISTER_BWD_LAUNCHER( 1536, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 8, 4); 10 | REGISTER_BWD_LAUNCHER( 1536, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 8, 4); 11 | REGISTER_BWD_LAUNCHER( 1536, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 8, 4); 12 | REGISTER_BWD_LAUNCHER( 1536, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 8, 4); 13 | REGISTER_BWD_LAUNCHER( 1536, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 8, 4); 14 | REGISTER_BWD_LAUNCHER( 1536, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 8, 4); 15 | REGISTER_BWD_LAUNCHER( 1536, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 8, 4); 16 | -------------------------------------------------------------------------------- /examples/hyena/flash-attention/csrc/layer_norm/ln_bwd_2048.cu: -------------------------------------------------------------------------------- 1 | #include "ln_bwd_kernels.cuh" 2 | 3 | // Create backward launch function and register. Macro signature: 4 | // HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL 5 | 6 | REGISTER_BWD_LAUNCHER( 2048, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4); 7 | REGISTER_BWD_LAUNCHER( 2048, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4); 8 | REGISTER_BWD_LAUNCHER( 2048, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4); 9 | REGISTER_BWD_LAUNCHER( 2048, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4); 10 | REGISTER_BWD_LAUNCHER( 2048, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4); 11 | REGISTER_BWD_LAUNCHER( 2048, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4); 12 | REGISTER_BWD_LAUNCHER( 2048, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4); 13 | REGISTER_BWD_LAUNCHER( 2048, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4); 14 | REGISTER_BWD_LAUNCHER( 2048, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4); 15 | REGISTER_BWD_LAUNCHER( 2048, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4); -------------------------------------------------------------------------------- /examples/hyena/flash-attention/csrc/layer_norm/ln_bwd_256.cu: -------------------------------------------------------------------------------- 1 | #include "ln_bwd_kernels.cuh" 2 | 3 | // Create backward launch function and register. 
Macro signature: 4 | // HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL 5 | 6 | REGISTER_BWD_LAUNCHER( 256, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4); 7 | REGISTER_BWD_LAUNCHER( 256, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4); 8 | REGISTER_BWD_LAUNCHER( 256, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4); 9 | REGISTER_BWD_LAUNCHER( 256, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4); 10 | REGISTER_BWD_LAUNCHER( 256, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4); 11 | REGISTER_BWD_LAUNCHER( 256, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4); 12 | REGISTER_BWD_LAUNCHER( 256, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4); 13 | REGISTER_BWD_LAUNCHER( 256, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4); 14 | REGISTER_BWD_LAUNCHER( 256, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4); 15 | REGISTER_BWD_LAUNCHER( 256, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4); 16 | -------------------------------------------------------------------------------- /examples/hyena/flash-attention/csrc/layer_norm/ln_bwd_2560.cu: -------------------------------------------------------------------------------- 1 | #include "ln_bwd_kernels.cuh" 2 | 3 | // Create backward launch function and register. Macro signature: 4 | // HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL 5 | 6 | REGISTER_BWD_LAUNCHER( 2560, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4); 7 | REGISTER_BWD_LAUNCHER( 2560, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4); 8 | REGISTER_BWD_LAUNCHER( 2560, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 8, 4); 9 | REGISTER_BWD_LAUNCHER( 2560, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 8, 4); 10 | REGISTER_BWD_LAUNCHER( 2560, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 8, 4); 11 | REGISTER_BWD_LAUNCHER( 2560, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 8, 4); 12 | REGISTER_BWD_LAUNCHER( 2560, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 8, 4); 13 | REGISTER_BWD_LAUNCHER( 2560, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 8, 4); 14 | REGISTER_BWD_LAUNCHER( 2560, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 8, 4); 15 | REGISTER_BWD_LAUNCHER( 2560, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 8, 4); 16 | -------------------------------------------------------------------------------- /examples/hyena/flash-attention/csrc/layer_norm/ln_bwd_3072.cu: -------------------------------------------------------------------------------- 1 | #include "ln_bwd_kernels.cuh" 2 | 3 | // Create backward launch function and register. 
Macro signature: 4 | // HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL 5 | 6 | REGISTER_BWD_LAUNCHER( 3072, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4); 7 | REGISTER_BWD_LAUNCHER( 3072, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4); 8 | REGISTER_BWD_LAUNCHER( 3072, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4); 9 | REGISTER_BWD_LAUNCHER( 3072, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4); 10 | REGISTER_BWD_LAUNCHER( 3072, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4); 11 | REGISTER_BWD_LAUNCHER( 3072, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4); 12 | REGISTER_BWD_LAUNCHER( 3072, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4); 13 | REGISTER_BWD_LAUNCHER( 3072, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4); 14 | REGISTER_BWD_LAUNCHER( 3072, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4); 15 | REGISTER_BWD_LAUNCHER( 3072, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4); -------------------------------------------------------------------------------- /examples/hyena/flash-attention/csrc/layer_norm/ln_bwd_4096.cu: -------------------------------------------------------------------------------- 1 | #include "ln_bwd_kernels.cuh" 2 | 3 | // Create backward launch function and register. Macro signature: 4 | // HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL 5 | 6 | REGISTER_BWD_LAUNCHER( 4096, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4); 7 | REGISTER_BWD_LAUNCHER( 4096, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4); 8 | REGISTER_BWD_LAUNCHER( 4096, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4); 9 | REGISTER_BWD_LAUNCHER( 4096, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4); 10 | REGISTER_BWD_LAUNCHER( 4096, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4); 11 | REGISTER_BWD_LAUNCHER( 4096, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4); 12 | REGISTER_BWD_LAUNCHER( 4096, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4); 13 | REGISTER_BWD_LAUNCHER( 4096, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4); 14 | REGISTER_BWD_LAUNCHER( 4096, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4); 15 | REGISTER_BWD_LAUNCHER( 4096, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4); -------------------------------------------------------------------------------- /examples/hyena/flash-attention/csrc/layer_norm/ln_bwd_512.cu: -------------------------------------------------------------------------------- 1 | #include "ln_bwd_kernels.cuh" 2 | 3 | // Create backward launch function and register. 
Macro signature: 4 | // HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL 5 | 6 | REGISTER_BWD_LAUNCHER( 512, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4); 7 | REGISTER_BWD_LAUNCHER( 512, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4); 8 | REGISTER_BWD_LAUNCHER( 512, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4); 9 | REGISTER_BWD_LAUNCHER( 512, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4); 10 | REGISTER_BWD_LAUNCHER( 512, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4); 11 | REGISTER_BWD_LAUNCHER( 512, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4); 12 | REGISTER_BWD_LAUNCHER( 512, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4); 13 | REGISTER_BWD_LAUNCHER( 512, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4); 14 | REGISTER_BWD_LAUNCHER( 512, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4); 15 | REGISTER_BWD_LAUNCHER( 512, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4); 16 | -------------------------------------------------------------------------------- /examples/hyena/flash-attention/csrc/layer_norm/ln_bwd_5120.cu: -------------------------------------------------------------------------------- 1 | #include "ln_bwd_kernels.cuh" 2 | 3 | // Create backward launch function and register. Macro signature: 4 | // HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL 5 | 6 | REGISTER_BWD_LAUNCHER( 5120, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4); 7 | REGISTER_BWD_LAUNCHER( 5120, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4); 8 | REGISTER_BWD_LAUNCHER( 5120, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4); 9 | REGISTER_BWD_LAUNCHER( 5120, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4); 10 | REGISTER_BWD_LAUNCHER( 5120, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4); 11 | REGISTER_BWD_LAUNCHER( 5120, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4); 12 | REGISTER_BWD_LAUNCHER( 5120, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4); 13 | REGISTER_BWD_LAUNCHER( 5120, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4); 14 | REGISTER_BWD_LAUNCHER( 5120, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4); 15 | REGISTER_BWD_LAUNCHER( 5120, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4); -------------------------------------------------------------------------------- /examples/hyena/flash-attention/csrc/layer_norm/ln_bwd_6144.cu: -------------------------------------------------------------------------------- 1 | #include "ln_bwd_kernels.cuh" 2 | 3 | // Create backward launch function and register. 
Macro signature: 4 | // HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL 5 | 6 | REGISTER_BWD_LAUNCHER( 6144, fp32, fp32, fp32, fp32, fp32, 1, 1, 8, 16, 4); 7 | REGISTER_BWD_LAUNCHER( 6144, fp16, fp32, fp32, fp32, fp32, 1, 1, 8, 16, 4); 8 | REGISTER_BWD_LAUNCHER( 6144, fp32, fp16, fp32, fp16, fp32, 1, 1, 8, 16, 4); 9 | REGISTER_BWD_LAUNCHER( 6144, fp16, fp16, fp32, fp16, fp32, 1, 1, 8, 16, 4); 10 | REGISTER_BWD_LAUNCHER( 6144, fp32, fp16, fp16, fp16, fp32, 1, 1, 8, 16, 4); 11 | REGISTER_BWD_LAUNCHER( 6144, fp32, bf16, fp32, bf16, fp32, 1, 1, 8, 16, 4); 12 | REGISTER_BWD_LAUNCHER( 6144, bf16, bf16, fp32, bf16, fp32, 1, 1, 8, 16, 4); 13 | REGISTER_BWD_LAUNCHER( 6144, fp32, bf16, bf16, bf16, fp32, 1, 1, 8, 16, 4); 14 | REGISTER_BWD_LAUNCHER( 6144, fp16, fp16, fp16, fp16, fp32, 1, 1, 8, 16, 4); 15 | REGISTER_BWD_LAUNCHER( 6144, bf16, bf16, bf16, bf16, fp32, 1, 1, 8, 16, 4); -------------------------------------------------------------------------------- /examples/hyena/flash-attention/csrc/layer_norm/ln_bwd_768.cu: -------------------------------------------------------------------------------- 1 | #include "ln_bwd_kernels.cuh" 2 | 3 | // Create backward launch function and register. Macro signature: 4 | // HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL 5 | 6 | REGISTER_BWD_LAUNCHER( 768, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4); 7 | REGISTER_BWD_LAUNCHER( 768, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4); 8 | REGISTER_BWD_LAUNCHER( 768, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4); 9 | REGISTER_BWD_LAUNCHER( 768, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4); 10 | REGISTER_BWD_LAUNCHER( 768, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4); 11 | REGISTER_BWD_LAUNCHER( 768, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4); 12 | REGISTER_BWD_LAUNCHER( 768, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4); 13 | REGISTER_BWD_LAUNCHER( 768, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4); 14 | REGISTER_BWD_LAUNCHER( 768, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4); 15 | REGISTER_BWD_LAUNCHER( 768, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4); 16 | -------------------------------------------------------------------------------- /examples/hyena/flash-attention/csrc/layer_norm/ln_fwd_1024.cu: -------------------------------------------------------------------------------- 1 | #include "ln_fwd_kernels.cuh" 2 | 3 | // Create forward launch function and register. 
Macro signature: 4 | // HIDDEN_SIZE, WTYPE, ITYPE, RYTPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG 5 | 6 | REGISTER_FWD_LAUNCHER( 1024, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16); 7 | REGISTER_FWD_LAUNCHER( 1024, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16); 8 | REGISTER_FWD_LAUNCHER( 1024, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16); 9 | REGISTER_FWD_LAUNCHER( 1024, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16); 10 | REGISTER_FWD_LAUNCHER( 1024, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16); 11 | REGISTER_FWD_LAUNCHER( 1024, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16); 12 | REGISTER_FWD_LAUNCHER( 1024, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16); 13 | REGISTER_FWD_LAUNCHER( 1024, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16); 14 | REGISTER_FWD_LAUNCHER( 1024, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16); 15 | REGISTER_FWD_LAUNCHER( 1024, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16); 16 | -------------------------------------------------------------------------------- /examples/hyena/flash-attention/csrc/layer_norm/ln_fwd_1280.cu: -------------------------------------------------------------------------------- 1 | #include "ln_fwd_kernels.cuh" 2 | 3 | // Create forward launch function and register. Macro signature: 4 | // HIDDEN_SIZE, WTYPE, ITYPE, RYTPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG 5 | 6 | REGISTER_FWD_LAUNCHER( 1280, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16); 7 | REGISTER_FWD_LAUNCHER( 1280, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16); 8 | REGISTER_FWD_LAUNCHER( 1280, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16); 9 | REGISTER_FWD_LAUNCHER( 1280, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16); 10 | REGISTER_FWD_LAUNCHER( 1280, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16); 11 | REGISTER_FWD_LAUNCHER( 1280, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16); 12 | REGISTER_FWD_LAUNCHER( 1280, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16); 13 | REGISTER_FWD_LAUNCHER( 1280, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16); 14 | REGISTER_FWD_LAUNCHER( 1280, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16); 15 | REGISTER_FWD_LAUNCHER( 1280, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16); 16 | -------------------------------------------------------------------------------- /examples/hyena/flash-attention/csrc/layer_norm/ln_fwd_1536.cu: -------------------------------------------------------------------------------- 1 | #include "ln_fwd_kernels.cuh" 2 | 3 | // Create forward launch function and register. 
Macro signature: 4 | // HIDDEN_SIZE, WTYPE, ITYPE, RYTPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG 5 | 6 | REGISTER_FWD_LAUNCHER( 1536, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16); 7 | REGISTER_FWD_LAUNCHER( 1536, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16); 8 | REGISTER_FWD_LAUNCHER( 1536, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16); 9 | REGISTER_FWD_LAUNCHER( 1536, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16); 10 | REGISTER_FWD_LAUNCHER( 1536, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16); 11 | REGISTER_FWD_LAUNCHER( 1536, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16); 12 | REGISTER_FWD_LAUNCHER( 1536, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16); 13 | REGISTER_FWD_LAUNCHER( 1536, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16); 14 | REGISTER_FWD_LAUNCHER( 1536, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16); 15 | REGISTER_FWD_LAUNCHER( 1536, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16); 16 | -------------------------------------------------------------------------------- /examples/hyena/flash-attention/csrc/layer_norm/ln_fwd_2048.cu: -------------------------------------------------------------------------------- 1 | #include "ln_fwd_kernels.cuh" 2 | 3 | // Create forward launch function and register. Macro signature: 4 | // HIDDEN_SIZE, WTYPE, ITYPE, RYTPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG 5 | 6 | REGISTER_FWD_LAUNCHER( 2048, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16); 7 | REGISTER_FWD_LAUNCHER( 2048, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16); 8 | REGISTER_FWD_LAUNCHER( 2048, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16); 9 | REGISTER_FWD_LAUNCHER( 2048, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16); 10 | REGISTER_FWD_LAUNCHER( 2048, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16); 11 | REGISTER_FWD_LAUNCHER( 2048, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16); 12 | REGISTER_FWD_LAUNCHER( 2048, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16); 13 | REGISTER_FWD_LAUNCHER( 2048, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16); 14 | REGISTER_FWD_LAUNCHER( 2048, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16); 15 | REGISTER_FWD_LAUNCHER( 2048, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16); 16 | -------------------------------------------------------------------------------- /examples/hyena/flash-attention/csrc/layer_norm/ln_fwd_256.cu: -------------------------------------------------------------------------------- 1 | #include "ln_fwd_kernels.cuh" 2 | 3 | // Create forward launch function and register. 
Macro signature: 4 | // HIDDEN_SIZE, WTYPE, ITYPE, RYTPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG 5 | 6 | REGISTER_FWD_LAUNCHER( 256, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16); 7 | REGISTER_FWD_LAUNCHER( 256, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16); 8 | REGISTER_FWD_LAUNCHER( 256, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16); 9 | REGISTER_FWD_LAUNCHER( 256, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16); 10 | REGISTER_FWD_LAUNCHER( 256, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16); 11 | REGISTER_FWD_LAUNCHER( 256, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16); 12 | REGISTER_FWD_LAUNCHER( 256, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16); 13 | REGISTER_FWD_LAUNCHER( 256, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16); 14 | REGISTER_FWD_LAUNCHER( 256, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16); 15 | REGISTER_FWD_LAUNCHER( 256, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16); 16 | -------------------------------------------------------------------------------- /examples/hyena/flash-attention/csrc/layer_norm/ln_fwd_2560.cu: -------------------------------------------------------------------------------- 1 | #include "ln_fwd_kernels.cuh" 2 | 3 | // Create forward launch function and register. Macro signature: 4 | // HIDDEN_SIZE, WTYPE, ITYPE, RYTPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG 5 | 6 | REGISTER_FWD_LAUNCHER( 2560, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16); 7 | REGISTER_FWD_LAUNCHER( 2560, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16); 8 | REGISTER_FWD_LAUNCHER( 2560, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16); 9 | REGISTER_FWD_LAUNCHER( 2560, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16); 10 | REGISTER_FWD_LAUNCHER( 2560, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16); 11 | REGISTER_FWD_LAUNCHER( 2560, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16); 12 | REGISTER_FWD_LAUNCHER( 2560, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16); 13 | REGISTER_FWD_LAUNCHER( 2560, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16); 14 | REGISTER_FWD_LAUNCHER( 2560, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16); 15 | REGISTER_FWD_LAUNCHER( 2560, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16); 16 | -------------------------------------------------------------------------------- /examples/hyena/flash-attention/csrc/layer_norm/ln_fwd_3072.cu: -------------------------------------------------------------------------------- 1 | #include "ln_fwd_kernels.cuh" 2 | 3 | // Create forward launch function and register. 
Macro signature: 4 | // HIDDEN_SIZE, WTYPE, ITYPE, RYTPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG 5 | 6 | REGISTER_FWD_LAUNCHER( 3072, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16); 7 | REGISTER_FWD_LAUNCHER( 3072, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16); 8 | REGISTER_FWD_LAUNCHER( 3072, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 16); 9 | REGISTER_FWD_LAUNCHER( 3072, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 16); 10 | REGISTER_FWD_LAUNCHER( 3072, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 16); 11 | REGISTER_FWD_LAUNCHER( 3072, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 16); 12 | REGISTER_FWD_LAUNCHER( 3072, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 16); 13 | REGISTER_FWD_LAUNCHER( 3072, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 16); 14 | REGISTER_FWD_LAUNCHER( 3072, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 16); 15 | REGISTER_FWD_LAUNCHER( 3072, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 16); 16 | -------------------------------------------------------------------------------- /examples/hyena/flash-attention/csrc/layer_norm/ln_fwd_4096.cu: -------------------------------------------------------------------------------- 1 | #include "ln_fwd_kernels.cuh" 2 | 3 | // Create forward launch function and register. Macro signature: 4 | // HIDDEN_SIZE, WTYPE, ITYPE, RYTPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG 5 | 6 | REGISTER_FWD_LAUNCHER( 4096, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16); 7 | REGISTER_FWD_LAUNCHER( 4096, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16); 8 | REGISTER_FWD_LAUNCHER( 4096, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 16); 9 | REGISTER_FWD_LAUNCHER( 4096, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 16); 10 | REGISTER_FWD_LAUNCHER( 4096, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 16); 11 | REGISTER_FWD_LAUNCHER( 4096, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 16); 12 | REGISTER_FWD_LAUNCHER( 4096, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 16); 13 | REGISTER_FWD_LAUNCHER( 4096, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 16); 14 | REGISTER_FWD_LAUNCHER( 4096, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 16); 15 | REGISTER_FWD_LAUNCHER( 4096, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 16); 16 | -------------------------------------------------------------------------------- /examples/hyena/flash-attention/csrc/layer_norm/ln_fwd_512.cu: -------------------------------------------------------------------------------- 1 | #include "ln_fwd_kernels.cuh" 2 | 3 | // Create forward launch function and register. 
Macro signature: 4 | // HIDDEN_SIZE, WTYPE, ITYPE, RYTPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG 5 | 6 | REGISTER_FWD_LAUNCHER( 512, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16); 7 | REGISTER_FWD_LAUNCHER( 512, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16); 8 | REGISTER_FWD_LAUNCHER( 512, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16); 9 | REGISTER_FWD_LAUNCHER( 512, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16); 10 | REGISTER_FWD_LAUNCHER( 512, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16); 11 | REGISTER_FWD_LAUNCHER( 512, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16); 12 | REGISTER_FWD_LAUNCHER( 512, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16); 13 | REGISTER_FWD_LAUNCHER( 512, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16); 14 | REGISTER_FWD_LAUNCHER( 512, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16); 15 | REGISTER_FWD_LAUNCHER( 512, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16); 16 | -------------------------------------------------------------------------------- /examples/hyena/flash-attention/csrc/layer_norm/ln_fwd_5120.cu: -------------------------------------------------------------------------------- 1 | #include "ln_fwd_kernels.cuh" 2 | 3 | // Create forward launch function and register. Macro signature: 4 | // HIDDEN_SIZE, WTYPE, ITYPE, RYTPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG 5 | 6 | REGISTER_FWD_LAUNCHER( 5120, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16); 7 | REGISTER_FWD_LAUNCHER( 5120, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16); 8 | REGISTER_FWD_LAUNCHER( 5120, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 16); 9 | REGISTER_FWD_LAUNCHER( 5120, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 16); 10 | REGISTER_FWD_LAUNCHER( 5120, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 16); 11 | REGISTER_FWD_LAUNCHER( 5120, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 16); 12 | REGISTER_FWD_LAUNCHER( 5120, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 16); 13 | REGISTER_FWD_LAUNCHER( 5120, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 16); 14 | REGISTER_FWD_LAUNCHER( 5120, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 16); 15 | REGISTER_FWD_LAUNCHER( 5120, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 16); 16 | -------------------------------------------------------------------------------- /examples/hyena/flash-attention/csrc/layer_norm/ln_fwd_6144.cu: -------------------------------------------------------------------------------- 1 | #include "ln_fwd_kernels.cuh" 2 | 3 | // Create forward launch function and register. 
Macro signature: 4 | // HIDDEN_SIZE, WTYPE, ITYPE, RYTPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG 5 | 6 | REGISTER_FWD_LAUNCHER( 6144, fp32, fp32, fp32, fp32, fp32, 1, 1, 8, 16); 7 | REGISTER_FWD_LAUNCHER( 6144, fp16, fp32, fp32, fp32, fp32, 1, 1, 8, 16); 8 | REGISTER_FWD_LAUNCHER( 6144, fp32, fp16, fp32, fp16, fp32, 1, 1, 8, 16); 9 | REGISTER_FWD_LAUNCHER( 6144, fp16, fp16, fp32, fp16, fp32, 1, 1, 8, 16); 10 | REGISTER_FWD_LAUNCHER( 6144, fp32, fp16, fp16, fp16, fp32, 1, 1, 8, 16); 11 | REGISTER_FWD_LAUNCHER( 6144, fp32, bf16, fp32, bf16, fp32, 1, 1, 8, 16); 12 | REGISTER_FWD_LAUNCHER( 6144, bf16, bf16, fp32, bf16, fp32, 1, 1, 8, 16); 13 | REGISTER_FWD_LAUNCHER( 6144, fp32, bf16, bf16, bf16, fp32, 1, 1, 8, 16); 14 | REGISTER_FWD_LAUNCHER( 6144, fp16, fp16, fp16, fp16, fp32, 1, 1, 8, 16); 15 | REGISTER_FWD_LAUNCHER( 6144, bf16, bf16, bf16, bf16, fp32, 1, 1, 8, 16); 16 | -------------------------------------------------------------------------------- /examples/hyena/flash-attention/csrc/layer_norm/ln_fwd_768.cu: -------------------------------------------------------------------------------- 1 | #include "ln_fwd_kernels.cuh" 2 | 3 | // Create forward launch function and register. Macro signature: 4 | // HIDDEN_SIZE, WTYPE, ITYPE, RYTPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG 5 | 6 | REGISTER_FWD_LAUNCHER( 768, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16); 7 | REGISTER_FWD_LAUNCHER( 768, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16); 8 | REGISTER_FWD_LAUNCHER( 768, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16); 9 | REGISTER_FWD_LAUNCHER( 768, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16); 10 | REGISTER_FWD_LAUNCHER( 768, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16); 11 | REGISTER_FWD_LAUNCHER( 768, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16); 12 | REGISTER_FWD_LAUNCHER( 768, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16); 13 | REGISTER_FWD_LAUNCHER( 768, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16); 14 | REGISTER_FWD_LAUNCHER( 768, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16); 15 | REGISTER_FWD_LAUNCHER( 768, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16); 16 | -------------------------------------------------------------------------------- /examples/hyena/flash-attention/csrc/layer_norm/static_switch.h: -------------------------------------------------------------------------------- 1 | // Inspired by https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h 2 | // and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h 3 | 4 | #pragma once 5 | 6 | /// @param COND - a boolean expression to switch by 7 | /// @param CONST_NAME - a name given for the constexpr bool variable. 8 | /// @param ... - code to execute for true and false 9 | /// 10 | /// Usage: 11 | /// ``` 12 | /// BOOL_SWITCH(flag, BoolConst, [&] { 13 | /// some_function(...); 14 | /// }); 15 | /// ``` 16 | #define BOOL_SWITCH(COND, CONST_NAME, ...) \ 17 | [&] { \ 18 | if (COND) { \ 19 | constexpr bool CONST_NAME = true; \ 20 | return __VA_ARGS__(); \ 21 | } else { \ 22 | constexpr bool CONST_NAME = false; \ 23 | return __VA_ARGS__(); \ 24 | } \ 25 | }() 26 | -------------------------------------------------------------------------------- /examples/hyena/flash-attention/csrc/rotary/rotary.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | #define CHECK_DEVICE(x) TORCH_CHECK(x.device().type() == torch::kCUDA, #x " must be on CUDA") 5 | #define CHECK_SHAPE(x, ...) 
TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")") 6 | 7 | void apply_rotary_cuda(const torch::Tensor x1, const torch::Tensor x2, 8 | const torch::Tensor cos, const torch::Tensor sin, 9 | torch::Tensor out1, torch::Tensor out2, 10 | const bool conj); 11 | 12 | void apply_rotary(const torch::Tensor x1, const torch::Tensor x2, 13 | const torch::Tensor cos, const torch::Tensor sin, 14 | torch::Tensor out1, torch::Tensor out2, 15 | const bool conj) { 16 | CHECK_DEVICE(x1); CHECK_DEVICE(x2); 17 | CHECK_DEVICE(cos); CHECK_DEVICE(sin); 18 | CHECK_DEVICE(out1); CHECK_DEVICE(out2); 19 | TORCH_CHECK(x1.dtype() == x2.dtype()); 20 | TORCH_CHECK(cos.dtype() == sin.dtype()); 21 | TORCH_CHECK(out1.dtype() == out2.dtype()); 22 | TORCH_CHECK(x1.dtype() == cos.dtype()); 23 | TORCH_CHECK(x1.dtype() == out1.dtype()); 24 | TORCH_CHECK(x1.sizes() == x2.sizes()); 25 | TORCH_CHECK(cos.sizes() == sin.sizes()); 26 | TORCH_CHECK(out1.sizes() == out2.sizes()); 27 | 28 | // Otherwise the kernel will be launched from cuda:0 device 29 | // Cast to char to avoid compiler warning about narrowing 30 | at::cuda::CUDAGuard device_guard{(char)x1.get_device()}; 31 | 32 | apply_rotary_cuda(x1, x2, cos, sin, out1, out2, conj); 33 | } 34 | 35 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 36 | m.def("apply_rotary", &apply_rotary, "Apply rotary embedding"); 37 | } 38 | -------------------------------------------------------------------------------- /examples/hyena/flash-attention/csrc/rotary/rotary_cuda.cu: -------------------------------------------------------------------------------- 1 | #include <torch/extension.h> 2 | #include <ATen/native/TensorIterator.h> 3 | #include <ATen/native/cuda/Loops.cuh> 4 | 5 | void apply_rotary_cuda(const torch::Tensor x1, const torch::Tensor x2, 6 | const torch::Tensor cos, const torch::Tensor sin, 7 | torch::Tensor out1, torch::Tensor out2, 8 | const bool conj) { 9 | auto iter = at::TensorIteratorConfig() 10 | .add_output(out1) 11 | .add_output(out2) 12 | .add_input(x1) 13 | .add_input(x2) 14 | .add_input(cos) 15 | .add_input(sin) 16 | .check_all_same_dtype(false) 17 | .promote_inputs_to_common_dtype(false) 18 | .build(); 19 | 20 | if (!conj) { 21 | AT_DISPATCH_FLOATING_TYPES_AND2(at::kBFloat16, at::kHalf, x1.scalar_type(), "rotary_kernel", [&] { 22 | at::native::gpu_kernel_multiple_outputs( 23 | iter, [] GPU_LAMBDA (scalar_t x1, scalar_t x2, scalar_t cos, 24 | scalar_t sin) -> thrust::tuple<scalar_t, scalar_t> { 25 | scalar_t out1 = float(x1) * float(cos) - float(x2) * float(sin); 26 | scalar_t out2 = float(x1) * float(sin) + float(x2) * float(cos); 27 | return {out1, out2}; 28 | }); 29 | }); 30 | } else { 31 | AT_DISPATCH_FLOATING_TYPES_AND2(at::kBFloat16, at::kHalf, x1.scalar_type(), "rotary_kernel", [&] { 32 | at::native::gpu_kernel_multiple_outputs( 33 | iter, [] GPU_LAMBDA (scalar_t x1, scalar_t x2, scalar_t cos, 34 | scalar_t sin) -> thrust::tuple<scalar_t, scalar_t> { 35 | scalar_t out1 = float(x1) * float(cos) + float(x2) * float(sin); 36 | scalar_t out2 = -float(x1) * float(sin) + float(x2) * float(cos); 37 | return {out1, out2}; 38 | }); 39 | }); 40 | } 41 | } -------------------------------------------------------------------------------- /examples/hyena/flash-attention/csrc/xentropy/README.md: -------------------------------------------------------------------------------- 1 | This CUDA extension implements optimized cross-entropy loss, adapted from Apex's 2 | [Xentropy](https://github.com/NVIDIA/apex/tree/master/apex/contrib/xentropy). 3 | We make it work for bfloat16 and support in-place backward to save memory.
4 | 5 | It has only been tested on A100s. 6 | 7 | ```sh 8 | cd csrc/xentropy && pip install . 9 | ``` 10 | -------------------------------------------------------------------------------- /examples/hyena/flash-attention/csrc/xentropy/interface.cpp: -------------------------------------------------------------------------------- 1 | #include <torch/extension.h> 2 | 3 | // CUDA forward declarations 4 | std::vector<at::Tensor> softmax_xentropy_cuda( 5 | const at::Tensor &input, 6 | const at::Tensor &labels, 7 | const float smoothing, 8 | const int total_classes); 9 | 10 | at::Tensor softmax_xentropy_backward_cuda( 11 | const at::Tensor &grad_loss, 12 | at::Tensor &logits, 13 | const at::Tensor &max_log_sum_exp, 14 | const at::Tensor &labels, 15 | const float smoothing, 16 | const bool inplace, 17 | const int total_classes); 18 | 19 | // C++ interface 20 | 21 | #define CHECK_CUDA(x) AT_ASSERTM(x.is_cuda(), #x " must be a CUDA tensor") 22 | #define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous") 23 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 24 | 25 | std::vector<at::Tensor> softmax_xentropy_forward( 26 | const at::Tensor &input, 27 | const at::Tensor &labels, 28 | const float smoothing, 29 | const int total_classes=-1) { 30 | // For tensor parallel cross entropy with smoothing, we want to pass in the total number 31 | // of classes so that smoothing can be applied correctly. If total_classes=-1, use the 32 | // last dimension of the input tensor. 33 | CHECK_INPUT(input); 34 | CHECK_INPUT(labels); 35 | 36 | return softmax_xentropy_cuda(input, labels, smoothing, total_classes); 37 | } 38 | 39 | at::Tensor softmax_xentropy_backward( 40 | const at::Tensor &grad_loss, 41 | at::Tensor &logits, 42 | const at::Tensor &max_log_sum_exp, 43 | const at::Tensor &labels, 44 | const float smoothing, 45 | const bool inplace, 46 | const int total_classes=-1) { 47 | CHECK_INPUT(grad_loss); 48 | CHECK_INPUT(logits); 49 | CHECK_INPUT(max_log_sum_exp); 50 | CHECK_INPUT(labels); 51 | 52 | return softmax_xentropy_backward_cuda(grad_loss, logits, max_log_sum_exp, labels, 53 | smoothing, inplace, total_classes); 54 | } 55 | 56 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 57 | m.def("forward", &softmax_xentropy_forward, "Softmax cross entropy loss with label smoothing forward (CUDA)", py::arg("input"), py::arg("labels"), py::arg("smoothing"), py::arg("total_classes")=-1); 58 | m.def("backward", &softmax_xentropy_backward, "Softmax cross entropy loss with label smoothing backward (CUDA)", py::arg("grad_loss"), py::arg("logits"), py::arg("max_log_sum_exp"), py::arg("labels"), py::arg("smoothing"), py::arg("inplace"), py::arg("total_classes")=-1); 59 | } 60 | -------------------------------------------------------------------------------- /examples/hyena/flash-attention/flash_attn/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HazyResearch/flash-fft-conv/b8771028717f46d5b22cbb8e12833f35033d621b/examples/hyena/flash-attention/flash_attn/__init__.py -------------------------------------------------------------------------------- /examples/hyena/flash-attention/flash_attn/layers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HazyResearch/flash-fft-conv/b8771028717f46d5b22cbb8e12833f35033d621b/examples/hyena/flash-attention/flash_attn/layers/__init__.py --------------------------------------------------------------------------------
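Relating back to the xentropy extension above (csrc/xentropy): in practice it is consumed through a thin Python wrapper rather than by calling the pybind functions directly. The sketch below follows the usage in tests/losses/test_cross_entropy.py later in this dump (the import path, `label_smoothing`, and `inplace_backward` arguments come from that test); the tensor shapes and smoothing value are illustrative only, and the extension is assumed to have been built with the `pip install .` command shown in the README.

```python
# Minimal usage sketch for the xentropy extension via flash_attn's wrapper
# (assumes `cd csrc/xentropy && pip install .` was run and a CUDA device is available).
import torch
from flash_attn.losses.cross_entropy import CrossEntropyLossApex

vocab_size = 50257
logits = torch.randn(8 * 128, vocab_size, device='cuda', dtype=torch.float16, requires_grad=True)
labels = torch.randint(0, vocab_size, (8 * 128,), dtype=torch.long, device='cuda')
labels[:10] = -100  # -100 marks ignored positions, same convention as torch.nn.CrossEntropyLoss

# inplace_backward=True overwrites the logits buffer with the gradient to save memory
loss_fn = CrossEntropyLossApex(label_smoothing=0.1, inplace_backward=True)
loss = loss_fn(logits, labels)
loss.backward()
```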
/examples/hyena/flash-attention/flash_attn/layers/patch_embed.py: -------------------------------------------------------------------------------- 1 | # We use the same API as https://github.com/rwightman/pytorch-image-models/blob/v0.6.11/timm/models/layers/patch_embed.py 2 | # But we use nn.Linear instead of Conv2d and it's about 8x faster. 3 | 4 | from functools import partial 5 | 6 | import torch.nn as nn 7 | from torch import _assert 8 | from torch.nn.modules.utils import _pair 9 | 10 | from einops import rearrange 11 | 12 | try: 13 | from flash_attn.ops.fused_dense import FusedDense 14 | except ImportError: 15 | FusedDense = None 16 | 17 | 18 | class PatchEmbed(nn.Module): 19 | """ 2D Image to Patch Embedding 20 | """ 21 | def __init__( 22 | self, 23 | img_size=224, 24 | patch_size=16, 25 | in_chans=3, 26 | embed_dim=768, 27 | norm_layer=None, 28 | flatten=True, 29 | bias=True, 30 | fused_bias_fc=False, 31 | ): 32 | super().__init__() 33 | img_size = _pair(img_size) 34 | patch_size = _pair(patch_size) 35 | self.img_size = img_size 36 | self.patch_size = patch_size 37 | self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1]) 38 | self.num_patches = self.grid_size[0] * self.grid_size[1] 39 | self.flatten = flatten 40 | if fused_bias_fc and FusedDense is None: 41 | raise ImportError('fused_dense is not installed') 42 | 43 | linear_cls = nn.Linear if not fused_bias_fc or not bias else FusedDense 44 | self.proj = linear_cls(in_chans * patch_size[0] * patch_size[1], embed_dim, bias=bias) 45 | self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() 46 | 47 | def forward(self, x): 48 | _, _, H, W = x.shape 49 | _assert(H == self.img_size[0], f"Input image height ({H}) doesn't match model ({self.img_size[0]}).") 50 | _assert(W == self.img_size[1], f"Input image width ({W}) doesn't match model ({self.img_size[1]}).") 51 | x = self.proj(rearrange(x, 'b c (h p1) (w p2) -> b h w (c p1 p2)', 52 | p1=self.patch_size[0], p2=self.patch_size[1])) 53 | if self.flatten: 54 | x = rearrange(x, 'b h w c -> b (h w) c') 55 | x = self.norm(x) 56 | return x 57 | -------------------------------------------------------------------------------- /examples/hyena/flash-attention/flash_attn/losses/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HazyResearch/flash-fft-conv/b8771028717f46d5b22cbb8e12833f35033d621b/examples/hyena/flash-attention/flash_attn/losses/__init__.py -------------------------------------------------------------------------------- /examples/hyena/flash-attention/flash_attn/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HazyResearch/flash-fft-conv/b8771028717f46d5b22cbb8e12833f35033d621b/examples/hyena/flash-attention/flash_attn/models/__init__.py -------------------------------------------------------------------------------- /examples/hyena/flash-attention/flash_attn/modules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HazyResearch/flash-fft-conv/b8771028717f46d5b22cbb8e12833f35033d621b/examples/hyena/flash-attention/flash_attn/modules/__init__.py -------------------------------------------------------------------------------- /examples/hyena/flash-attention/flash_attn/modules/mlp.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, Tri Dao. 
2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | try: 8 | from flash_attn.ops.fused_dense import FusedMLP, ParallelFusedMLP 9 | except ImportError: 10 | FusedMLP, ParallelFusedMLP = None, None 11 | 12 | 13 | class Mlp(nn.Module): 14 | 15 | def __init__(self, in_features, hidden_features=None, out_features=None, activation=F.gelu, 16 | return_residual=False, device=None, dtype=None): 17 | factory_kwargs = {'device': device, 'dtype': dtype} 18 | super().__init__() 19 | out_features = out_features or in_features 20 | hidden_features = hidden_features or in_features 21 | self.return_residual = return_residual 22 | self.fc1 = nn.Linear(in_features, hidden_features, **factory_kwargs) 23 | self.activation = activation 24 | self.fc2 = nn.Linear(hidden_features, out_features, **factory_kwargs) 25 | 26 | def forward(self, x): 27 | y = self.fc1(x) 28 | y = self.activation(y) 29 | y = self.fc2(y) 30 | return y if not self.return_residual else (y, x) 31 | -------------------------------------------------------------------------------- /examples/hyena/flash-attention/flash_attn/ops/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HazyResearch/flash-fft-conv/b8771028717f46d5b22cbb8e12833f35033d621b/examples/hyena/flash-attention/flash_attn/ops/__init__.py -------------------------------------------------------------------------------- /examples/hyena/flash-attention/flash_attn/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HazyResearch/flash-fft-conv/b8771028717f46d5b22cbb8e12833f35033d621b/examples/hyena/flash-attention/flash_attn/utils/__init__.py -------------------------------------------------------------------------------- /examples/hyena/flash-attention/flash_attn/utils/pretrained.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from transformers.utils import WEIGHTS_NAME, WEIGHTS_INDEX_NAME 4 | from transformers.utils import is_remote_url 5 | from transformers.modeling_utils import load_state_dict 6 | from transformers.utils.hub import cached_file, get_checkpoint_shard_files 7 | 8 | 9 | def state_dict_from_pretrained(model_name, device=None, dtype=None): 10 | # If not fp32, then we don't want to load directly to the GPU 11 | mapped_device = 'cpu' if dtype not in [torch.float32, None] else device 12 | is_sharded = False 13 | resolved_archive_file = cached_file(model_name, WEIGHTS_NAME, 14 | _raise_exceptions_for_missing_entries=False) 15 | if resolved_archive_file is None: 16 | resolved_archive_file = cached_file(model_name, WEIGHTS_INDEX_NAME, 17 | _raise_exceptions_for_missing_entries=False) 18 | if resolved_archive_file is not None: 19 | is_sharded = True 20 | if resolved_archive_file is None: 21 | raise EnvironmentError(f"Model name {model_name} was not found.") 22 | if is_sharded: 23 | # resolved_archive_file becomes a list of files that point to the different 24 | # checkpoint shards in this case. 
25 | resolved_archive_file, sharded_metadata = get_checkpoint_shard_files( 26 | model_name, resolved_archive_file 27 | ) 28 | state_dict = {} 29 | for sharded_file in resolved_archive_file: 30 | state_dict.update(torch.load(sharded_file, map_location=mapped_device)) 31 | else: 32 | state_dict = torch.load(cached_file(model_name, WEIGHTS_NAME), map_location=device) 33 | # Convert dtype before moving to GPU to save memory 34 | if dtype is not None: 35 | state_dict = {k: v.to(dtype=dtype) for k, v in state_dict.items()} 36 | state_dict = {k: v.to(device=device) for k, v in state_dict.items()} 37 | return state_dict 38 | -------------------------------------------------------------------------------- /examples/hyena/flash-attention/tests/losses/test_cross_entropy.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | import pytest 6 | 7 | from einops import rearrange 8 | 9 | from flash_attn.losses.cross_entropy import CrossEntropyLossApex 10 | 11 | is_sm8x = torch.cuda.get_device_capability('cuda')[0] >= 8 12 | 13 | 14 | @pytest.mark.parametrize('dtype', [torch.float16, torch.float32] + ([torch.bfloat16] if is_sm8x else [])) 15 | # @pytest.mark.parametrize('dtype', [torch.float16]) 16 | @pytest.mark.parametrize('inplace_backward', [False, True]) 17 | # @pytest.mark.parametrize('inplace_backward', [False]) 18 | @pytest.mark.parametrize('smoothing', [0.0, 0.9]) 19 | @pytest.mark.parametrize('vocab_size', [50257]) 20 | def test_cross_entropy_loss_apex(vocab_size, smoothing, inplace_backward, dtype): 21 | device = 'cuda' 22 | rtol, atol = (1e-5, 1e-6) if dtype == torch.float32 else (1e-3, 1e-4) 23 | # set seed 24 | torch.random.manual_seed(0) 25 | batch_size = 8 26 | seqlen = 128 27 | x_pt = torch.randn(batch_size * seqlen, vocab_size, device=device, dtype=dtype, requires_grad=True) 28 | x = x_pt.detach().clone().requires_grad_() 29 | y = torch.randint(0, vocab_size, (batch_size * seqlen,), dtype=torch.long, device=device) 30 | y[torch.randperm(batch_size * seqlen)[:10]] = -100 31 | model_pt = torch.nn.CrossEntropyLoss(label_smoothing=smoothing) 32 | model = CrossEntropyLossApex(label_smoothing=smoothing, inplace_backward=inplace_backward) 33 | out = model(x, y) 34 | out_pt = model_pt(x_pt.float(), y) 35 | assert torch.allclose(out, out_pt, rtol=rtol, atol=atol) 36 | 37 | g = torch.randn_like(out) 38 | out_pt.backward(g) 39 | out.backward(g) 40 | assert torch.allclose(x.grad, x_pt.grad, rtol=rtol, atol=atol) 41 | -------------------------------------------------------------------------------- /examples/hyena/flash-attention/tests/models/test_vit.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | import torch 4 | import pytest 5 | 6 | from timm.models.vision_transformer import vit_base_patch16_224 7 | 8 | from flash_attn.models.vit import vit_base_patch16_224 as flash_vit_base_patch16_224 9 | 10 | 11 | @pytest.mark.parametrize('fused_mlp', [False, True]) 12 | # @pytest.mark.parametrize('fused_mlp', [False]) 13 | @pytest.mark.parametrize('optimized', [False, True]) 14 | # @pytest.mark.parametrize('optimized', [True]) 15 | def test_vit(optimized, fused_mlp): 16 | """Check that our implementation of ViT matches the timm's implementation: 17 | the output of our forward pass in fp16 should be around the same as 18 | timm' forward pass in fp16, when compared to timm's forward pass in fp32. 
19 | """ 20 | dtype = torch.float16 21 | device = 'cuda' 22 | 23 | kwargs = {} 24 | if optimized: 25 | kwargs = dict(use_flash_attn=True, fused_bias_fc=True, fused_dropout_add_ln=True) 26 | kwargs['fused_mlp'] = fused_mlp 27 | model = flash_vit_base_patch16_224(**kwargs).to(device=device, dtype=dtype) 28 | 29 | model_ref = vit_base_patch16_224(pretrained=True).to(device=device) 30 | model_timm = vit_base_patch16_224(pretrained=True).to(device=device, dtype=dtype) 31 | 32 | model.load_state_dict(model_ref.state_dict()) 33 | 34 | model.eval() 35 | model_ref.eval() 36 | model_timm.eval() 37 | 38 | torch.manual_seed(0) 39 | batch_size = 2 40 | x = torch.randn(batch_size, 3, 224, 224, device=device, dtype=dtype) 41 | out = model(x) 42 | out_timm = model_timm(x) 43 | out_ref = model_ref(x.float()) 44 | 45 | print(f'Output max diff: {(out - out_ref).abs().max().item()}') 46 | print(f'Output mean diff: {(out - out_ref).abs().mean().item()}') 47 | print(f'timm fp16 max diff: {(out_timm - out_ref).abs().max().item()}') 48 | print(f'timm fp16 mean diff: {(out_timm - out_ref).abs().mean().item()}') 49 | rtol = 2 if not fused_mlp else 4 50 | assert (out - out_ref).abs().max().item() < rtol * (out_timm - out_ref).abs().max().item() 51 | -------------------------------------------------------------------------------- /examples/hyena/flash-attention/tests/test_rotary.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | import pytest 6 | 7 | from einops import rearrange 8 | 9 | from flash_attn.layers.rotary import apply_rotary_emb_func, apply_rotary_emb_torch 10 | 11 | 12 | is_sm8x = torch.cuda.get_device_capability('cuda') >= (8, 0) 13 | 14 | @pytest.mark.parametrize('dtype', ([torch.float16] if not is_sm8x else [torch.float16, torch.bfloat16])) 15 | # @pytest.mark.parametrize('dtype', ([torch.float16])) 16 | @pytest.mark.parametrize('rotary_fraction', [1.0, 0.5]) 17 | # @pytest.mark.parametrize('rotary_fraction', [0.5]) 18 | @pytest.mark.parametrize('inplace', [False, True]) 19 | # @pytest.mark.parametrize('inplace', [False]) 20 | def test_rotary_single_tensor(inplace, rotary_fraction, dtype): 21 | rtol = 1e-3 22 | batch_size = 32 23 | nheads = 4 24 | seqlen = 217 25 | headdim = 128 26 | x = torch.randn(batch_size, seqlen, nheads, headdim, dtype=dtype, device='cuda', 27 | requires_grad=True) 28 | x_pt = x.detach().clone().requires_grad_() 29 | rotary_dim = int(rotary_fraction * headdim) 30 | assert rotary_dim % 2 == 0 31 | angle = torch.randn(seqlen, rotary_dim // 2, device='cuda') 32 | cos = torch.cos(angle).to(dtype=dtype) 33 | sin = torch.sin(angle).to(dtype=dtype) 34 | out = apply_rotary_emb_func(x, cos, sin, inplace) 35 | out_pt = apply_rotary_emb_torch(x_pt, cos, sin) 36 | # Numerical error if we just do any arithmetic 37 | atol = ((out + 0.3 - 0.3) - out).abs().max().item() 38 | assert torch.allclose(out, out_pt, rtol=rtol, atol=2 * atol) 39 | g = torch.randn_like(out) 40 | g_pt = g.clone() # If inplace=True, we might modify the gradient inplace 41 | out.backward(g) 42 | out_pt.backward(g_pt) 43 | atol = ((x_pt.grad + 0.3 - 0.3) - x_pt.grad).abs().max().item() 44 | assert torch.allclose(x.grad, x_pt.grad, rtol=rtol, atol=2 * atol) 45 | -------------------------------------------------------------------------------- /examples/hyena/flash-attention/training/configs/callbacks/causality-monitor.yaml: -------------------------------------------------------------------------------- 1 | 
causality-monitor: 2 | _target_: src.callbacks.causality_monitor.CausalityMonitor -------------------------------------------------------------------------------- /examples/hyena/flash-attention/training/configs/callbacks/default.yaml: -------------------------------------------------------------------------------- 1 | # rich_progress_bar: 2 | # _target_: pytorch_lightning.callbacks.RichProgressBar 3 | 4 | rich_model_summary: 5 | _target_: pytorch_lightning.callbacks.RichModelSummary 6 | 7 | model_checkpoint: 8 | _target_: pytorch_lightning.callbacks.ModelCheckpoint 9 | monitor: "val/acc" # name of the logged metric which determines when model is improving 10 | mode: "max" # can be "max" or "min" 11 | save_top_k: 1 # save k best models (determined by above metric) 12 | save_last: True # additionally always save model from last epoch 13 | verbose: False 14 | dirpath: ${oc.env:CHECKPOINT_DIR,checkpoints}/${oc.select:name,''} 15 | filename: "epoch_{epoch:03d}" 16 | auto_insert_metric_name: False 17 | 18 | early_stopping: 19 | _target_: pytorch_lightning.callbacks.EarlyStopping 20 | monitor: "val/acc" # name of the logged metric which determines when model is improving 21 | mode: "max" # can be "max" or "min" 22 | patience: 100 # how many epochs of not improving until training stops 23 | min_delta: 0 # minimum change in the monitored metric needed to qualify as an improvement 24 | 25 | learning_rate_monitor: 26 | _target_: pytorch_lightning.callbacks.LearningRateMonitor 27 | logging_interval: step 28 | 29 | speed_monitor: 30 | _target_: src.callbacks.speed_monitor.SpeedMonitor 31 | intra_step_time: True 32 | inter_step_time: True 33 | epoch_time: True 34 | 35 | loss_scale_monitor: 36 | _target_: src.callbacks.loss_scale_monitor.LossScaleMonitor 37 | 38 | params_log: 39 | _target_: src.callbacks.params_log.ParamsLog 40 | total_params_log: True 41 | trainable_params_log: True 42 | non_trainable_params_log: True 43 | 44 | gpu_affinity: 45 | _target_: src.callbacks.gpu_affinity.GpuAffinity 46 | -------------------------------------------------------------------------------- /examples/hyena/flash-attention/training/configs/callbacks/ema.yaml: -------------------------------------------------------------------------------- 1 | ema: 2 | _target_: src.callbacks.ema.EMACallback 3 | decay: ??? 4 | use_num_updates: False 5 | -------------------------------------------------------------------------------- /examples/hyena/flash-attention/training/configs/callbacks/flop-count.yaml: -------------------------------------------------------------------------------- 1 | flop_count: 2 | _target_: src.callbacks.flop_count.FlopCount 3 | profilers: ['fvcore'] 4 | input_size: [3, 224, 224] 5 | device: null 6 | -------------------------------------------------------------------------------- /examples/hyena/flash-attention/training/configs/callbacks/gpu-monitor.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - default.yaml 3 | 4 | gpu_stats_monitor: 5 | _target_: pytorch_lightning.callbacks.GPUStatsMonitor 6 | # [2021-08-13] TD: I just want the intra_step_size but it'll error if I 7 | # don't have memory_utilization and gpu_utilization. 8 | # Maybe I should write a callback with just the intra_step_size. 
9 | memory_utilization: True 10 | gpu_utilization: True 11 | intra_step_time: True 12 | -------------------------------------------------------------------------------- /examples/hyena/flash-attention/training/configs/callbacks/model-summary.yaml: -------------------------------------------------------------------------------- 1 | model_summary: 2 | _target_: pytorch_lightning.callbacks.RichModelSummary 3 | -------------------------------------------------------------------------------- /examples/hyena/flash-attention/training/configs/callbacks/none.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HazyResearch/flash-fft-conv/b8771028717f46d5b22cbb8e12833f35033d621b/examples/hyena/flash-attention/training/configs/callbacks/none.yaml -------------------------------------------------------------------------------- /examples/hyena/flash-attention/training/configs/callbacks/norm-monitor.yaml: -------------------------------------------------------------------------------- 1 | norm_monitor: 2 | _target_: src.callbacks.norm_monitor.NormMonitor 3 | -------------------------------------------------------------------------------- /examples/hyena/flash-attention/training/configs/callbacks/params-log.yaml: -------------------------------------------------------------------------------- 1 | params_log: 2 | _target_: src.callbacks.params_log.ParamsLog 3 | total_params_log: True 4 | trainable_params_log: True 5 | non_trainable_params_log: True 6 | -------------------------------------------------------------------------------- /examples/hyena/flash-attention/training/configs/callbacks/wandb.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - default.yaml 3 | 4 | watch_model: 5 | _target_: src.callbacks.wandb_callbacks.WatchModel 6 | log: "all" 7 | log_freq: 100 8 | 9 | upload_code_as_artifact: 10 | _target_: src.callbacks.wandb_callbacks.UploadCodeAsArtifact 11 | code_dir: ${work_dir}/src 12 | 13 | upload_ckpts_as_artifact: 14 | _target_: src.callbacks.wandb_callbacks.UploadCheckpointsAsArtifact 15 | ckpt_dir: "checkpoints/" 16 | upload_best_only: True 17 | 18 | log_f1_precision_recall_heatmap: 19 | _target_: src.callbacks.wandb_callbacks.LogF1PrecRecHeatmap 20 | 21 | log_confusion_matrix: 22 | _target_: src.callbacks.wandb_callbacks.LogConfusionMatrix 23 | 24 | log_image_predictions: 25 | _target_: src.callbacks.wandb_callbacks.LogImagePredictions 26 | num_samples: 8 27 | -------------------------------------------------------------------------------- /examples/hyena/flash-attention/training/configs/config.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # specify here default training configuration 4 | defaults: 5 | - _self_ 6 | - trainer: default 7 | - optimizer: adamw 8 | - scheduler: null 9 | - task: sequence-model 10 | - model: null 11 | - datamodule: null 12 | - callbacks: default # set this to null if you don't want to use callbacks 13 | - metrics: null 14 | - logger: null # set logger here or use command line (e.g. 
`python run.py logger=wandb`) 15 | 16 | - mode: default 17 | 18 | - experiment: null 19 | - hparams_search: null 20 | 21 | # enable color logging 22 | - override hydra/hydra_logging: colorlog 23 | - override hydra/job_logging: colorlog 24 | 25 | # path to original working directory 26 | # hydra hijacks working directory by changing it to the current log directory, 27 | # so it's useful to have this path as a special variable 28 | # https://hydra.cc/docs/next/tutorials/basic/running_your_app/working_directory 29 | work_dir: ${hydra:runtime.cwd} 30 | 31 | # path to folder with data 32 | data_dir: ${work_dir}/data/ 33 | 34 | # pretty print config at the start of the run using Rich library 35 | print_config: True 36 | 37 | # disable python warnings if they annoy you 38 | ignore_warnings: True 39 | 40 | # check performance on test set, using the best model achieved during training 41 | # lightning chooses best model based on metric specified in checkpoint callback 42 | test_after_training: True 43 | 44 | resume: False 45 | 46 | # seed for random number generators in pytorch, numpy and python.random 47 | seed: null 48 | 49 | # name of the run, accessed by loggers 50 | name: null 51 | -------------------------------------------------------------------------------- /examples/hyena/flash-attention/training/configs/datamodule/openwebtext.yaml: -------------------------------------------------------------------------------- 1 | _target_: src.datamodules.language_modeling_hf.LMDataModule 2 | dataset_name: openwebtext 3 | dataset_config_name: null 4 | tokenizer_name: gpt2 5 | cache_dir: ${oc.env:DATA_DIR,${data_dir}}/openwebtext/cache 6 | max_length: 1024 7 | val_ratio: 0.0005 8 | val_split_seed: 2357 9 | add_eos: True 10 | batch_size: 8 # per GPU 11 | batch_size_eval: ${eval:${.batch_size} * 2} 12 | num_workers: 32 # For preprocessing only 13 | shuffle: True 14 | pin_memory: True 15 | __train_len: ${div_up:9035582198, ${.max_length}} 16 | -------------------------------------------------------------------------------- /examples/hyena/flash-attention/training/configs/datamodule/thepile.yaml: -------------------------------------------------------------------------------- 1 | _target_: src.datamodules.language_modeling_hf.LMDataModule 2 | dataset_name: the_pile 3 | dataset_config_name: null 4 | tokenizer_name: gpt2 5 | cache_dir: ${oc.env:DATA_DIR,${data_dir}}/the_pile/cache 6 | max_length: 2048 7 | add_eos: True 8 | batch_size: 4 # per GPU 9 | batch_size_eval: ${eval:${.batch_size} * 2} 10 | num_workers: 64 # For preprocessing only 11 | use_shmem: False 12 | shuffle: True 13 | pin_memory: True 14 | __train_len: ${div_up:374337375694, ${.max_length}} 15 | -------------------------------------------------------------------------------- /examples/hyena/flash-attention/training/configs/experiment/owt/gpt2l-flash.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - /experiment/owt/gpt2m-flash.yaml 4 | - override /model/gpt2model: gpt2-large 5 | # TD [2022-08-03] Surprisingly it's faster to use the ZeRO optimizer than just AdamW. 6 | # Still, fairscale is even faster and uses less memory. 7 | # I think it's because Pytorch is using ZeRO stage 1 and fairscale is using ZeRO stage 2? 8 | # However, fairscale has issues with saving checkpoint (either OOM or very 9 | # slow since it goes through the CPU?). 
Fairscale says Pytorch ZeRO is the 10 | # upstream version of OSS 11 | # https://github.com/facebookresearch/fairscale/issues/937 12 | # Pytorch ZeRO as also very slow for saving checkpoints due to 13 | # consolidate_state_dict(), but I've fixed it to save separate checkpoint per GPU. 14 | - override /optimizer: adamw-zero 15 | 16 | # FusedAdam doesn't seem to speed things up here, time per global step 17 | # (i.e. batch size 512) on 8 A100s is around 2056ms for both AdamW and FusedAdam. 18 | # This could be because each GPU is only doing the optimizer step for 1 / 19 | # world_size of the parameters. 20 | # Maybe the bottleneck here is the NCCL call to exchange parameters (ZeRO). 21 | # - override /optimizer: adamw-apex-zero 22 | 23 | # Can enable mlp_chekcpoint_lvl to fit batch_size 16 on A100 40GB 24 | # model: 25 | # config: 26 | # # mlp_checkpoint_lvl: ${eval:"[1] * 18 + [2] * 18"} 27 | # mlp_checkpoint_lvl: 1 28 | 29 | datamodule: 30 | # batch_size: 16 31 | batch_size: ${eval:"4 if ${train.gpu_mem} < 24 else (8 if ${train.gpu_mem} < 40 else (16 if ${train.gpu_mem} < 80 else 32))"} 32 | 33 | trainer: 34 | # strategy: null 35 | # strategy: ${eval:"None if ${trainer.devices} == 1 else 'ddp_sharded'"} 36 | strategy: 37 | _target_: src.utils.ddp_zero1.DDPStrategyZero1 38 | find_unused_parameters: False 39 | gradient_as_bucket_view: True 40 | # TD [2022-08-03] Deepspeed makes the ppl curve go wild 41 | # strategy: deepspeed_stage_1 42 | -------------------------------------------------------------------------------- /examples/hyena/flash-attention/training/configs/experiment/owt/gpt2l-hf.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - /experiment/owt/gpt2m-hf.yaml 4 | - override /model/gpt2model: gpt2-large 5 | - override /optimizer: adamw-zero 6 | 7 | datamodule: 8 | batch_size: 2 9 | 10 | trainer: 11 | strategy: 12 | _target_: src.utils.ddp_zero1.DDPStrategyZero1 13 | find_unused_parameters: False 14 | gradient_as_bucket_view: True 15 | -------------------------------------------------------------------------------- /examples/hyena/flash-attention/training/configs/experiment/owt/gpt2l.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - /experiment/owt/gpt2m.yaml 4 | - override /model/gpt2model: gpt2-large 5 | - override /optimizer: adamw-zero 6 | 7 | datamodule: 8 | batch_size: 4 # Per GPU 9 | 10 | trainer: 11 | strategy: 12 | _target_: src.utils.ddp_zero1.DDPStrategyZero1 13 | find_unused_parameters: False 14 | gradient_as_bucket_view: True 15 | -------------------------------------------------------------------------------- /examples/hyena/flash-attention/training/configs/experiment/owt/gpt2m-flash.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - /experiment/owt/gpt2s-flash.yaml 4 | - override /model/gpt2model: gpt2-medium 5 | 6 | # Can enable mlp_checkpoint_lvl to fit batch_size 32 to A100 40GB 7 | # model: 8 | # config: 9 | # mlp_checkpoint_lvl: 1 10 | 11 | datamodule: 12 | # batch_size: 32 13 | batch_size: ${eval:"8 if ${train.gpu_mem} < 24 else (16 if ${train.gpu_mem} < 40 else (32 if ${train.gpu_mem} < 80 else 64))"} 14 | 15 | train: 16 | optimizer: 17 | lr: 1.5e-4 18 | -------------------------------------------------------------------------------- /examples/hyena/flash-attention/training/configs/experiment/owt/gpt2m-hf.yaml: 
-------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - /experiment/owt/gpt2s-hf.yaml 4 | - override /model/gpt2model: gpt2-medium 5 | 6 | datamodule: 7 | batch_size: 4 8 | 9 | train: 10 | optimizer: 11 | lr: 1.5e-4 12 | -------------------------------------------------------------------------------- /examples/hyena/flash-attention/training/configs/experiment/owt/gpt2m.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - /experiment/owt/gpt2s.yaml 4 | - override /model/gpt2model: gpt2-medium 5 | 6 | datamodule: 7 | batch_size: 8 # Per GPU 8 | 9 | train: 10 | optimizer: 11 | lr: 1.5e-4 12 | -------------------------------------------------------------------------------- /examples/hyena/flash-attention/training/configs/experiment/owt/gpt2s-flash.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - /experiment/owt/base.yaml 4 | - override /model: gpt2 5 | - override /model/gpt2model: gpt2-small 6 | 7 | model: 8 | config: 9 | # n_positions is already set to ${datamodule.max_length} 10 | residual_in_fp32: True 11 | use_flash_attn: True 12 | fused_bias_fc: True 13 | fused_mlp: True 14 | fused_dropout_add_ln: True 15 | pad_vocab_size_multiple: 8 16 | 17 | datamodule: 18 | # batch_size: 64 19 | batch_size: ${eval:"16 if ${train.gpu_mem} < 24 else (32 if ${train.gpu_mem} < 40 else 64)"} 20 | -------------------------------------------------------------------------------- /examples/hyena/flash-attention/training/configs/experiment/owt/gpt2s-hf.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - /experiment/owt/base.yaml 4 | - override /model: gpt2-hf 5 | - override /model/gpt2model: gpt2-small 6 | - override /callbacks: [default, norm-monitor, flop-count] 7 | 8 | datamodule: 9 | batch_size: 8 10 | 11 | train: 12 | # Use the standard torch.nn.CrossEntropyLoss 13 | loss_fn: null 14 | 15 | callbacks: 16 | flop_count: 17 | input_size: 18 | - ${datamodule.max_length} 19 | input_dtype: 20 | # It's surprisingly hard to get hydra to return torch.long since it's not a callable 21 | _target_: torch.__getattribute__ 22 | _args_: 23 | - long 24 | -------------------------------------------------------------------------------- /examples/hyena/flash-attention/training/configs/experiment/owt/gpt2s.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - /experiment/owt/base.yaml 4 | - override /model: gpt2 5 | - override /model/gpt2model: gpt2-small 6 | 7 | datamodule: 8 | batch_size: ${eval:"4 if ${train.gpu_mem} < 24 else (8 if ${train.gpu_mem} < 40 else 16)"} 9 | -------------------------------------------------------------------------------- /examples/hyena/flash-attention/training/configs/experiment/owt/gpt2xl-flash.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - /experiment/owt/gpt2l-flash.yaml 4 | - override /model/gpt2model: gpt2-xlarge 5 | 6 | # Can enable mlp_checkpoint_lvl to fit to A100 40GB 7 | # model: 8 | # config: 9 | # # mlp_checkpoint_lvl: ${eval:"[1] * 18 + [2] * 18"} 10 | # mlp_checkpoint_lvl: 1 11 | 12 | datamodule: 13 | batch_size: ${eval:"2 if ${train.gpu_mem} < 24 else (4 if ${train.gpu_mem} < 40 else (8 if ${train.gpu_mem} < 80 else 
16))"} 14 | # With adamw-zero optimizer, on A100 40GB: 15 | # checkpoint_lvl=1, batch size = 4: mem 37GB, 4650ms / batch of 512 (285ms * 15 + 375ms * 1) 16 | # checkpoint_lvl=1, batch size = 8: mem 46GB, 4330ms / batch of 512 (530ms * 7 + 620ms * 1) 17 | # checkpoint_lvl=2, batch size = 8: mem 41GB, 4570ms / batch of 512 (560ms * 7 + 650ms * 1) 18 | # With adamw-apex-distributed optimizer: 19 | # checkpoint_lvl=1, batch size = 8: mem 41.5GB, 4500ms / batch of 512 (550ms * 7 + 650ms * 1) 20 | # checkpoint_lvl=1 for 24 layers and checkpoint_lvl=2 for 24 layers, 21 | # batch size = 8: mem 39GB, 4640ms / batch of 512 (565ms * 7 + 675ms * 1) 22 | -------------------------------------------------------------------------------- /examples/hyena/flash-attention/training/configs/experiment/owt/gpt2xl-hf.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - /experiment/owt/gpt2l-hf.yaml 4 | - override /model/gpt2model: gpt2-xlarge 5 | 6 | datamodule: 7 | batch_size: 1 8 | -------------------------------------------------------------------------------- /examples/hyena/flash-attention/training/configs/experiment/owt/gpt2xl.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - /experiment/owt/gpt2m.yaml 4 | - override /model/gpt2model: gpt2-xlarge 5 | - override /optimizer: adamw-zero 6 | 7 | datamodule: 8 | batch_size: 2 # Per GPU 9 | 10 | trainer: 11 | strategy: 12 | _target_: src.utils.ddp_zero1.DDPStrategyZero1 13 | find_unused_parameters: False 14 | gradient_as_bucket_view: True 15 | -------------------------------------------------------------------------------- /examples/hyena/flash-attention/training/configs/experiment/pile/gpt3-2.7B-flash-8k.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - /experiment/pile/gpt3xl-flash-8k.yaml 4 | 5 | model: 6 | config: 7 | n_embd: 2560 8 | n_head: 32 9 | n_layer: 32 10 | initializer_range: ${eval:"(2 / (${.n_embd} * 5)) ** 0.5"} 11 | mlp_checkpoint_lvl: 0 12 | 13 | datamodule: 14 | batch_size: ${eval:"1 if ${train.gpu_mem} < 40 else (2 if ${train.gpu_mem} < 80 else 4)"} 15 | 16 | train: 17 | optimizer: 18 | lr: 1.6e-4 19 | -------------------------------------------------------------------------------- /examples/hyena/flash-attention/training/configs/experiment/pile/gpt3-2.7B-flash-hdim128-rotary-8k.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - /experiment/pile/gpt3xl-flash-rotary-8k.yaml 4 | 5 | model: 6 | config: 7 | n_embd: 2560 8 | n_head: 20 9 | n_layer: 32 10 | initializer_range: ${eval:"(2 / (${.n_embd} * 5)) ** 0.5"} 11 | mlp_checkpoint_lvl: 0 12 | 13 | datamodule: 14 | batch_size: ${eval:"1 if ${train.gpu_mem} < 24 else (2 if ${train.gpu_mem} < 40 else (4 if ${train.gpu_mem} < 80 else 8))"} 15 | 16 | train: 17 | optimizer: 18 | lr: 1.6e-4 19 | -------------------------------------------------------------------------------- /examples/hyena/flash-attention/training/configs/experiment/pile/gpt3-2.7B-flash-hdim128-rotary.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - /experiment/pile/gpt3xl-flash-rotary.yaml 4 | 5 | model: 6 | config: 7 | n_embd: 2560 8 | n_head: 20 9 | n_layer: 32 10 | initializer_range: ${eval:"(2 / (${.n_embd} * 5)) ** 0.5"} 11 | 
mlp_checkpoint_lvl: 0 12 | 13 | datamodule: 14 | batch_size: ${eval:"4 if ${train.gpu_mem} < 24 else (8 if ${train.gpu_mem} < 40 else (16 if ${train.gpu_mem} < 80 else 32))"} 15 | 16 | train: 17 | optimizer: 18 | lr: 1.6e-4 19 | -------------------------------------------------------------------------------- /examples/hyena/flash-attention/training/configs/experiment/pile/gpt3-2.7B-flash-hdim128.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - /experiment/pile/gpt3xl-flash.yaml 4 | 5 | model: 6 | config: 7 | n_embd: 2560 8 | n_head: 20 # Headdim 128 is faster than headdim 80 9 | n_layer: 32 10 | initializer_range: ${eval:"(2 / (${.n_embd} * 5)) ** 0.5"} 11 | mlp_checkpoint_lvl: 0 12 | 13 | datamodule: 14 | batch_size: ${eval:"1 if ${train.gpu_mem} < 40 else (2 if ${train.gpu_mem} < 80 else 4)"} 15 | 16 | train: 17 | optimizer: 18 | lr: 1.6e-4 19 | -------------------------------------------------------------------------------- /examples/hyena/flash-attention/training/configs/experiment/pile/gpt3-2.7B-flash-rotary-8k.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - /experiment/pile/gpt3xl-flash-rotary-8k.yaml 4 | 5 | model: 6 | config: 7 | n_embd: 2560 8 | n_head: 32 9 | n_layer: 32 10 | initializer_range: ${eval:"(2 / (${.n_embd} * 5)) ** 0.5"} 11 | mlp_checkpoint_lvl: 0 12 | 13 | datamodule: 14 | batch_size: ${eval:"1 if ${train.gpu_mem} < 24 else (2 if ${train.gpu_mem} < 40 else (4 if ${train.gpu_mem} < 80 else 8))"} 15 | 16 | train: 17 | optimizer: 18 | lr: 1.6e-4 19 | -------------------------------------------------------------------------------- /examples/hyena/flash-attention/training/configs/experiment/pile/gpt3-2.7B-flash-rotary.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - /experiment/pile/gpt3xl-flash-rotary.yaml 4 | 5 | model: 6 | config: 7 | n_embd: 2560 8 | n_head: 32 9 | n_layer: 32 10 | initializer_range: ${eval:"(2 / (${.n_embd} * 5)) ** 0.5"} 11 | mlp_checkpoint_lvl: 0 12 | 13 | datamodule: 14 | batch_size: ${eval:"4 if ${train.gpu_mem} < 24 else (8 if ${train.gpu_mem} < 40 else (16 if ${train.gpu_mem} < 80 else 32))"} 15 | 16 | train: 17 | optimizer: 18 | lr: 1.6e-4 19 | -------------------------------------------------------------------------------- /examples/hyena/flash-attention/training/configs/experiment/pile/gpt3-2.7B-flash.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - /experiment/pile/gpt3xl-flash.yaml 4 | 5 | model: 6 | config: 7 | n_embd: 2560 8 | n_head: 32 9 | n_layer: 32 10 | initializer_range: ${eval:"(2 / (${.n_embd} * 5)) ** 0.5"} 11 | mlp_checkpoint_lvl: 0 12 | 13 | datamodule: 14 | batch_size: ${eval:"1 if ${train.gpu_mem} < 40 else (2 if ${train.gpu_mem} < 80 else 4)"} 15 | 16 | train: 17 | optimizer: 18 | lr: 1.6e-4 19 | -------------------------------------------------------------------------------- /examples/hyena/flash-attention/training/configs/experiment/pile/gpt3-2.7B-hf-hdim128.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - /experiment/pile/gpt3xl-hf.yaml 4 | 5 | model: 6 | config: 7 | n_embd: 2560 8 | n_head: 128 9 | n_layer: 32 10 | 11 | # OOM on A100 80GB even with batch_size = 1 12 | datamodule: 13 | batch_size: 1 14 | 15 | train: 
16 | optimizer: 17 | lr: 1.6e-4 18 | -------------------------------------------------------------------------------- /examples/hyena/flash-attention/training/configs/experiment/pile/gpt3-2.7B-hf.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - /experiment/pile/gpt3xl-hf.yaml 4 | 5 | model: 6 | config: 7 | n_embd: 2560 8 | n_head: 32 9 | n_layer: 32 10 | 11 | datamodule: 12 | batch_size: 1 13 | 14 | train: 15 | optimizer: 16 | lr: 1.6e-4 17 | -------------------------------------------------------------------------------- /examples/hyena/flash-attention/training/configs/experiment/pile/gpt3l-flash-8k.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - /experiment/pile/gpt3l-flash.yaml 4 | 5 | datamodule: 6 | max_length: 8192 7 | batch_size: ${eval:"1 if ${train.gpu_mem} < 40 else (2 if ${train.gpu_mem} < 80 else 4)"} 8 | 9 | train: 10 | global_batch_size: 64 11 | -------------------------------------------------------------------------------- /examples/hyena/flash-attention/training/configs/experiment/pile/gpt3l-flash-rotary-30B.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - /experiment/pile/gpt3l-flash-rotary.yaml 4 | 5 | trainer: 6 | max_steps: 60000 7 | 8 | train: 9 | scheduler: 10 | t_initial: ${trainer.max_steps} 11 | -------------------------------------------------------------------------------- /examples/hyena/flash-attention/training/configs/experiment/pile/gpt3l-flash-rotary-8k.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - /experiment/pile/gpt3l-flash-8k.yaml 4 | 5 | model: 6 | config: 7 | max_position_embeddings: 0 # Disable absolute position embedding 8 | rotary_emb_fraction: 0.5 9 | -------------------------------------------------------------------------------- /examples/hyena/flash-attention/training/configs/experiment/pile/gpt3l-flash-rotary.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - /experiment/pile/gpt3l-flash.yaml 4 | 5 | model: 6 | config: 7 | max_position_embeddings: 0 # Disable absolute position embedding 8 | rotary_emb_fraction: 0.5 9 | -------------------------------------------------------------------------------- /examples/hyena/flash-attention/training/configs/experiment/pile/gpt3l-flash.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - /experiment/pile/gpt3s-flash.yaml 4 | - override /optimizer: adamw-zero 5 | 6 | model: 7 | config: 8 | n_embd: 1536 9 | n_head: 16 10 | n_layer: 24 11 | # mlp_checkpoint_lvl: 1 # To fit batch_size 8 12 | 13 | datamodule: 14 | batch_size: ${eval:"2 if ${train.gpu_mem} < 24 else (4 if ${train.gpu_mem} < 40 else (8 if ${train.gpu_mem} < 80 else 16))"} 15 | 16 | train: 17 | optimizer: 18 | lr: 2.5e-4 19 | 20 | trainer: 21 | strategy: 22 | _target_: src.utils.ddp_zero1.DDPStrategyZero1 23 | find_unused_parameters: False 24 | gradient_as_bucket_view: True 25 | -------------------------------------------------------------------------------- /examples/hyena/flash-attention/training/configs/experiment/pile/gpt3l-hf.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - 
/experiment/pile/gpt3s-hf.yaml 4 | 5 | model: 6 | config: 7 | n_embd: 1536 8 | n_head: 16 9 | n_layer: 24 10 | 11 | datamodule: 12 | batch_size: 2 13 | 14 | train: 15 | optimizer: 16 | lr: 2.5e-4 17 | -------------------------------------------------------------------------------- /examples/hyena/flash-attention/training/configs/experiment/pile/gpt3m-flash-8k.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - /experiment/pile/gpt3m-flash.yaml 4 | 5 | datamodule: 6 | max_length: 8192 7 | batch_size: ${eval:"2 if ${train.gpu_mem} < 24 else (4 if ${train.gpu_mem} < 40 else 8)"} 8 | 9 | train: 10 | global_batch_size: 64 11 | -------------------------------------------------------------------------------- /examples/hyena/flash-attention/training/configs/experiment/pile/gpt3m-flash-rotary-30B.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - /experiment/pile/gpt3m-flash-rotary.yaml 4 | 5 | trainer: 6 | max_steps: 60000 7 | 8 | train: 9 | scheduler: 10 | t_initial: ${trainer.max_steps} 11 | -------------------------------------------------------------------------------- /examples/hyena/flash-attention/training/configs/experiment/pile/gpt3m-flash-rotary-8k.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - /experiment/pile/gpt3m-flash-8k.yaml 4 | 5 | model: 6 | config: 7 | max_position_embeddings: 0 # Disable absolute position embedding 8 | rotary_emb_fraction: 0.5 9 | -------------------------------------------------------------------------------- /examples/hyena/flash-attention/training/configs/experiment/pile/gpt3m-flash-rotary.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - /experiment/pile/gpt3m-flash.yaml 4 | 5 | model: 6 | config: 7 | max_position_embeddings: 0 # Disable absolute position embedding 8 | rotary_emb_fraction: 0.5 9 | -------------------------------------------------------------------------------- /examples/hyena/flash-attention/training/configs/experiment/pile/gpt3m-flash.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - /experiment/pile/gpt3s-flash.yaml 4 | - override /model/gpt2model: gpt2-medium 5 | 6 | # Can enable mlp_checkpoint_lvl to fit batch_size 16 to A100 40GB 7 | # model: 8 | # config: 9 | # mlp_checkpoint_lvl: 1 10 | 11 | datamodule: 12 | batch_size: ${eval:"4 if ${train.gpu_mem} < 24 else (8 if ${train.gpu_mem} < 40 else (16 if ${train.gpu_mem} < 80 else 32))"} 13 | 14 | train: 15 | optimizer: 16 | lr: 3.0e-4 17 | -------------------------------------------------------------------------------- /examples/hyena/flash-attention/training/configs/experiment/pile/gpt3m-hf.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - /experiment/pile/gpt3s-hf.yaml 4 | - override /model/gpt2model: gpt2-medium 5 | 6 | datamodule: 7 | batch_size: 4 8 | 9 | train: 10 | optimizer: 11 | lr: 3.0e-4 12 | -------------------------------------------------------------------------------- /examples/hyena/flash-attention/training/configs/experiment/pile/gpt3s-flash-8k.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - 
/experiment/pile/gpt3s-flash.yaml 4 | 5 | datamodule: 6 | max_length: 8192 7 | batch_size: ${eval:"2 if ${train.gpu_mem} < 24 else (4 if ${train.gpu_mem} < 40 else 8)"} 8 | 9 | train: 10 | global_batch_size: 64 11 | -------------------------------------------------------------------------------- /examples/hyena/flash-attention/training/configs/experiment/pile/gpt3s-flash-rotary-30B.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - /experiment/pile/gpt3s-flash-rotary.yaml 4 | 5 | trainer: 6 | max_steps: 60000 7 | 8 | train: 9 | scheduler: 10 | t_initial: ${trainer.max_steps} 11 | -------------------------------------------------------------------------------- /examples/hyena/flash-attention/training/configs/experiment/pile/gpt3s-flash-rotary-8k.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - /experiment/pile/gpt3s-flash-8k.yaml 4 | 5 | model: 6 | config: 7 | max_position_embeddings: 0 # Disable absolute position embedding 8 | rotary_emb_fraction: 0.5 9 | -------------------------------------------------------------------------------- /examples/hyena/flash-attention/training/configs/experiment/pile/gpt3s-flash-rotary.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - /experiment/pile/gpt3s-flash.yaml 4 | 5 | model: 6 | config: 7 | max_position_embeddings: 0 # Disable absolute position embedding 8 | rotary_emb_fraction: 0.5 9 | -------------------------------------------------------------------------------- /examples/hyena/flash-attention/training/configs/experiment/pile/gpt3s-flash.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - /experiment/pile/base.yaml 4 | - override /model: gpt2 5 | - override /model/gpt2model: gpt2-small 6 | 7 | model: 8 | config: 9 | # n_positions is already set to ${datamodule.max_length} 10 | residual_in_fp32: True 11 | use_flash_attn: True 12 | fused_dropout_add_ln: True 13 | fused_mlp: True 14 | fused_bias_fc: True 15 | pad_vocab_size_multiple: 8 16 | 17 | datamodule: 18 | batch_size: ${eval:"8 if ${train.gpu_mem} < 24 else (16 if ${train.gpu_mem} < 40 else 32)"} 19 | -------------------------------------------------------------------------------- /examples/hyena/flash-attention/training/configs/experiment/pile/gpt3s-hf.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - /experiment/pile/base.yaml 4 | - override /model: gpt2-hf 5 | - override /model/gpt2model: gpt2-small 6 | 7 | datamodule: 8 | batch_size: 8 9 | 10 | train: 11 | # Use the standard torch.nn.CrossEntropyLoss 12 | loss_fn: null 13 | -------------------------------------------------------------------------------- /examples/hyena/flash-attention/training/configs/experiment/pile/gpt3xl-flash-8k.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - /experiment/pile/gpt3xl-flash.yaml 4 | 5 | datamodule: 6 | max_length: 8192 7 | batch_size: ${eval:"1 if ${train.gpu_mem} < 40 else (2 if ${train.gpu_mem} < 80 else 4)"} 8 | 9 | train: 10 | global_batch_size: 128 11 | -------------------------------------------------------------------------------- 
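A brief aside on the `${eval:"..."}` expressions that recur throughout these pile experiment configs (the GPU-memory-dependent `batch_size` and the `initializer_range` formula): they are resolved by an `eval` resolver registered in the training code, which is not part of this dump. The plain-Python sketch below only illustrates the arithmetic those expressions encode; the function names are invented for the example.

```python
# Stand-ins for two ${eval:...} expressions used in the pile configs (illustrative only).

def per_gpu_batch_size(gpu_mem: int) -> int:
    # gpt3xl-flash.yaml: 1 if gpu_mem < 24 else (2 if gpu_mem < 40 else (4 if gpu_mem < 80 else 8))
    return 1 if gpu_mem < 24 else (2 if gpu_mem < 40 else (4 if gpu_mem < 80 else 8))

def initializer_range(n_embd: int) -> float:
    # gpt3-2.7B configs: (2 / (n_embd * 5)) ** 0.5
    return (2 / (n_embd * 5)) ** 0.5

print(per_gpu_batch_size(40))    # 4 per GPU on a 40GB card
print(per_gpu_batch_size(80))    # 8 per GPU on an 80GB card
print(initializer_range(2560))   # 0.0125 for the 2.7B configs (n_embd = 2560)
```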
/examples/hyena/flash-attention/training/configs/experiment/pile/gpt3xl-flash-rotary-60B.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - /experiment/pile/gpt3xl-flash-rotary.yaml 4 | 5 | trainer: 6 | max_steps: 60000 7 | 8 | train: 9 | scheduler: 10 | t_initial: ${trainer.max_steps} 11 | -------------------------------------------------------------------------------- /examples/hyena/flash-attention/training/configs/experiment/pile/gpt3xl-flash-rotary-8k.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - /experiment/pile/gpt3xl-flash-8k.yaml 4 | 5 | model: 6 | config: 7 | max_position_embeddings: 0 # Disable absolute position embedding 8 | rotary_emb_fraction: 0.5 9 | -------------------------------------------------------------------------------- /examples/hyena/flash-attention/training/configs/experiment/pile/gpt3xl-flash-rotary.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - /experiment/pile/gpt3xl-flash.yaml 4 | 5 | model: 6 | config: 7 | max_position_embeddings: 0 # Disable absolute position embedding 8 | rotary_emb_fraction: 0.5 9 | -------------------------------------------------------------------------------- /examples/hyena/flash-attention/training/configs/experiment/pile/gpt3xl-flash.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - /experiment/pile/gpt3s-flash.yaml 4 | - override /optimizer: adamw-zero 5 | 6 | model: 7 | config: 8 | n_embd: 2048 9 | n_head: 16 10 | n_layer: 24 11 | 12 | datamodule: 13 | batch_size: ${eval:"1 if ${train.gpu_mem} < 24 else (2 if ${train.gpu_mem} < 40 else (4 if ${train.gpu_mem} < 80 else 8))"} 14 | 15 | train: 16 | global_batch_size: 512 17 | optimizer: 18 | lr: 2.0e-4 19 | scheduler: 20 | t_initial: 300000 21 | 22 | trainer: 23 | strategy: 24 | _target_: src.utils.ddp_zero1.DDPStrategyZero1 25 | find_unused_parameters: False 26 | gradient_as_bucket_view: True 27 | max_steps: 400000 28 | val_check_interval: ${eval:1000 * ${.accumulate_grad_batches}} 29 | 30 | callbacks: 31 | model_checkpoint: 32 | every_n_train_steps: 1000 33 | model_checkpoint_progress: 34 | every_n_train_steps: 12500 35 | fault_tolerant: False # Saving takes too long 36 | -------------------------------------------------------------------------------- /examples/hyena/flash-attention/training/configs/experiment/pile/gpt3xl-hf.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - /experiment/pile/gpt3s-hf.yaml 4 | - override /optimizer: adamw-zero 5 | 6 | model: 7 | config: 8 | n_embd: 2048 9 | n_head: 16 10 | n_layer: 24 11 | 12 | datamodule: 13 | batch_size: 2 14 | 15 | train: 16 | global_batch_size: 512 17 | optimizer: 18 | lr: 2.0e-4 19 | scheduler: 20 | t_initial: 300000 21 | 22 | trainer: 23 | strategy: 24 | _target_: src.utils.ddp_zero1.DDPStrategyZero1 25 | find_unused_parameters: False 26 | gradient_as_bucket_view: True 27 | max_steps: 400000 28 | val_check_interval: ${eval:1000 * ${.accumulate_grad_batches}} 29 | 30 | callbacks: 31 | model_checkpoint: 32 | every_n_train_steps: 1000 33 | model_checkpoint_progress: 34 | every_n_train_steps: 12500 35 | fault_tolerant: False # Saving takes too long 36 | 
-------------------------------------------------------------------------------- /examples/hyena/flash-attention/training/configs/logger/comet.yaml: -------------------------------------------------------------------------------- 1 | # https://www.comet.ml 2 | 3 | comet: 4 | _target_: pytorch_lightning.loggers.comet.CometLogger 5 | api_key: ${oc.env:COMET_API_TOKEN} # api key is loaded from environment variable 6 | project_name: "template-tests" 7 | experiment_name: ${name} 8 | -------------------------------------------------------------------------------- /examples/hyena/flash-attention/training/configs/logger/csv.yaml: -------------------------------------------------------------------------------- 1 | # csv logger built in lightning 2 | 3 | csv: 4 | _target_: pytorch_lightning.loggers.csv_logs.CSVLogger 5 | save_dir: "." 6 | name: "csv/" 7 | version: ${name} 8 | prefix: "" 9 | -------------------------------------------------------------------------------- /examples/hyena/flash-attention/training/configs/logger/many_loggers.yaml: -------------------------------------------------------------------------------- 1 | # train with many loggers at once 2 | 3 | defaults: 4 | # - comet.yaml 5 | - csv.yaml 6 | # - mlflow.yaml 7 | # - neptune.yaml 8 | # - tensorboard.yaml 9 | - wandb.yaml 10 | -------------------------------------------------------------------------------- /examples/hyena/flash-attention/training/configs/logger/mlflow.yaml: -------------------------------------------------------------------------------- 1 | # https://mlflow.org 2 | 3 | mlflow: 4 | _target_: pytorch_lightning.loggers.mlflow.MLFlowLogger 5 | experiment_name: ${name} 6 | tracking_uri: null 7 | tags: null 8 | save_dir: ./mlruns 9 | prefix: "" 10 | artifact_location: null 11 | -------------------------------------------------------------------------------- /examples/hyena/flash-attention/training/configs/logger/neptune.yaml: -------------------------------------------------------------------------------- 1 | # https://neptune.ai 2 | 3 | neptune: 4 | _target_: pytorch_lightning.loggers.neptune.NeptuneLogger 5 | api_key: ${oc.env:NEPTUNE_API_TOKEN} # api key is loaded from environment variable 6 | project_name: your_name/template-tests 7 | close_after_fit: True 8 | offline_mode: False 9 | experiment_name: ${name} 10 | experiment_id: null 11 | prefix: "" 12 | -------------------------------------------------------------------------------- /examples/hyena/flash-attention/training/configs/logger/tensorboard.yaml: -------------------------------------------------------------------------------- 1 | # https://www.tensorflow.org/tensorboard/ 2 | 3 | tensorboard: 4 | _target_: pytorch_lightning.loggers.tensorboard.TensorBoardLogger 5 | save_dir: "tensorboard/" 6 | name: "default" 7 | version: ${name} 8 | log_graph: False 9 | default_hp_metric: True 10 | prefix: "" 11 | -------------------------------------------------------------------------------- /examples/hyena/flash-attention/training/configs/logger/wandb.yaml: -------------------------------------------------------------------------------- 1 | # https://wandb.ai 2 | 3 | wandb: 4 | _target_: pytorch_lightning.loggers.wandb.WandbLogger 5 | project: attention 6 | name: ${name} 7 | save_dir: "." 8 | mode: online # set offline to store all logs only locally 9 | id: ${oc.select:name} # pass correct id to resume experiment! 
10 | # entity: "" # set to name of your wandb team or just remove it 11 | log_model: False 12 | prefix: "" 13 | job_type: "train" 14 | group: "" 15 | tags: [] 16 | -------------------------------------------------------------------------------- /examples/hyena/flash-attention/training/configs/metrics/acc.yaml: -------------------------------------------------------------------------------- 1 | # @package eval.metrics 2 | acc: 3 | _target_: src.metrics.accuracy.AccuracyMine 4 | -------------------------------------------------------------------------------- /examples/hyena/flash-attention/training/configs/metrics/acc_ignore_index.yaml: -------------------------------------------------------------------------------- 1 | # @package eval.metrics 2 | acc: 3 | _target_: torchmetrics.Accuracy 4 | ignore_index: -100 5 | -------------------------------------------------------------------------------- /examples/hyena/flash-attention/training/configs/metrics/acctop5.yaml: -------------------------------------------------------------------------------- 1 | # @package eval.metrics 2 | acctop5: 3 | _target_: src.metrics.accuracy.AccuracyMine 4 | top_k: 5 5 | -------------------------------------------------------------------------------- /examples/hyena/flash-attention/training/configs/metrics/mse.yaml: -------------------------------------------------------------------------------- 1 | # @package eval.metrics 2 | mse: 3 | _target_: torchmetrics.MeanSquaredError 4 | -------------------------------------------------------------------------------- /examples/hyena/flash-attention/training/configs/metrics/num-tokens.yaml: -------------------------------------------------------------------------------- 1 | # @package eval.metrics 2 | num-tokens: 3 | _target_: src.metrics.num_tokens.NumTokens 4 | -------------------------------------------------------------------------------- /examples/hyena/flash-attention/training/configs/metrics/perplexity.yaml: -------------------------------------------------------------------------------- 1 | # @package eval.metrics 2 | ppl: 3 | _target_: src.metrics.perplexity.Perplexity 4 | -------------------------------------------------------------------------------- /examples/hyena/flash-attention/training/configs/mode/debug.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # run in debug mode with: 4 | # `python run.py mode=debug` 5 | 6 | defaults: 7 | - override /trainer: debug.yaml 8 | 9 | debug_mode: True 10 | 11 | hydra: 12 | # sets level of all command line loggers to 'DEBUG' 13 | verbose: True 14 | 15 | # https://hydra.cc/docs/tutorials/basic/running_your_app/logging/ 16 | # sets level of only chosen command line loggers to 'DEBUG' 17 | # verbose: [src.train, src.utils.utils] 18 | 19 | # sets output paths for all file logs to 'logs/debug/' 20 | run: 21 | dir: ${oc.env:RESULT_DIR,${work_dir}/logs}/debug/${now:%Y-%m-%d}/${now:%H-%M-%S} 22 | sweep: 23 | dir: ${oc.env:RESULT_DIR,${work_dir}/logs}/debug/multirun_${now:%Y-%m-%d_%H-%M-%S} 24 | subdir: ${hydra.job.num} 25 | 26 | # disable rich config printing, since it will be already printed by hydra when `verbose: True` 27 | print_config: False 28 | -------------------------------------------------------------------------------- /examples/hyena/flash-attention/training/configs/mode/default.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # default running mode 4 | 5 | default_mode: True 6 | 7 
| hydra: 8 | # default output paths for all file logs 9 | run: 10 | dir: ${oc.env:RESULT_DIR,${work_dir}/logs}/runs/${now:%Y-%m-%d}/${now:%H-%M-%S} 11 | sweep: 12 | dir: ${oc.env:RESULT_DIR,${work_dir}/logs}/multiruns/${now:%Y-%m-%d_%H-%M-%S} 13 | subdir: ${hydra.job.num} 14 | -------------------------------------------------------------------------------- /examples/hyena/flash-attention/training/configs/mode/exp.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # run in experiment mode with: 4 | # `python run.py mode=exp name=experiment_name` 5 | 6 | experiment_mode: True 7 | 8 | # allows for custom naming of the experiment 9 | name: ??? 10 | 11 | hydra: 12 | # sets output paths for all file logs to `logs/experiment/name' 13 | run: 14 | dir: ${oc.env:RESULT_DIR,${work_dir}/logs}/experiments/${name} 15 | sweep: 16 | dir: ${oc.env:RESULT_DIR,${work_dir}/logs}/experiments/${name} 17 | subdir: ${hydra.job.num} 18 | -------------------------------------------------------------------------------- /examples/hyena/flash-attention/training/configs/mode/profile.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | # Run the Pytorch profiler 3 | 4 | trainer: 5 | profiler: 6 | _target_: pytorch_lightning.profilers.PyTorchProfiler 7 | dirpath: ${hydra.run.dir} 8 | schedule: 9 | _target_: torch.profiler.schedule 10 | wait: 5 11 | warmup: 5 12 | active: 5 13 | use_cuda: True 14 | max_steps: 20 15 | 16 | logger: 17 | wandb: 18 | mode: disabled 19 | 20 | callbacks: 21 | model_checkpoint: null 22 | model_checkpoint_progress: null 23 | early_stopping: null 24 | 25 | hydra: 26 | # sets output paths for all file logs to 'logs/profile/' 27 | run: 28 | dir: ${oc.env:RESULT_DIR,${work_dir}/logs}/profile/${now:%Y-%m-%d}/${now:%H-%M-%S} 29 | sweep: 30 | dir: ${oc.env:RESULT_DIR,${work_dir}/logs}/profile/multirun_${now:%Y-%m-%d_%H-%M-%S} 31 | subdir: ${hydra.job.num} 32 | -------------------------------------------------------------------------------- /examples/hyena/flash-attention/training/configs/mode/smoke.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | # Smoke test: disable logging and model checkpointing 3 | 4 | logger: 5 | wandb: 6 | mode: disabled 7 | 8 | callbacks: 9 | model_checkpoint: null 10 | model_checkpoint_progress: null 11 | 12 | hydra: 13 | # https://hydra.cc/docs/tutorials/basic/running_your_app/logging/ 14 | # sets level of only chosen command line loggers to 'DEBUG' 15 | # verbose: [src.train, src.utils.utils] 16 | 17 | # sets output paths for all file logs to 'logs/debug/' 18 | run: 19 | dir: ${oc.env:RESULT_DIR,${work_dir}/logs}/debug/${now:%Y-%m-%d}/${now:%H-%M-%S} 20 | sweep: 21 | dir: ${oc.env:RESULT_DIR,${work_dir}/logs}/debug/multirun_${now:%Y-%m-%d_%H-%M-%S} 22 | subdir: ${hydra.job.num} 23 | -------------------------------------------------------------------------------- /examples/hyena/flash-attention/training/configs/model/gpt2-hf.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _self_ 3 | - gpt2model: gpt2-small 4 | 5 | _target_: transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel 6 | _recursive_: True 7 | config: 8 | _target_: transformers.GPT2Config 9 | # Mistral's config: https://github.com/stanford-crfm/mistral/blob/main/conf/models/gpt2-small.yaml 10 | # However, reorder_and_upcast_attn slows things down 11 | 
reorder_and_upcast_attn: false 12 | scale_attn_by_inverse_layer_idx: true 13 | n_positions: ${datamodule.max_length} 14 | -------------------------------------------------------------------------------- /examples/hyena/flash-attention/training/configs/model/gpt2.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _self_ 3 | - gpt2model: gpt2-small 4 | 5 | _target_: flash_attn.models.gpt.GPTLMHeadModel 6 | _recursive_: True 7 | config: 8 | _target_: transformers.GPT2Config 9 | # Mistral's config: # https://github.com/stanford-crfm/mistral/blob/main/conf/models/mistral-small.yaml 10 | # However, reorder_and_upcast_attn slows things down 11 | reorder_and_upcast_attn: false 12 | scale_attn_by_inverse_layer_idx: true 13 | n_positions: ${datamodule.max_length} 14 | -------------------------------------------------------------------------------- /examples/hyena/flash-attention/training/configs/model/gpt2model/gpt2-large.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | model: 3 | config: 4 | n_embd: 1280 5 | n_head: 20 6 | n_layer: 36 7 | -------------------------------------------------------------------------------- /examples/hyena/flash-attention/training/configs/model/gpt2model/gpt2-medium.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | model: 3 | config: 4 | n_embd: 1024 5 | n_head: 16 6 | n_layer: 24 7 | -------------------------------------------------------------------------------- /examples/hyena/flash-attention/training/configs/model/gpt2model/gpt2-small.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | model: 3 | config: 4 | n_embd: 768 5 | n_head: 12 6 | n_layer: 12 7 | -------------------------------------------------------------------------------- /examples/hyena/flash-attention/training/configs/model/gpt2model/gpt2-xlarge.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | model: 3 | config: 4 | n_embd: 1600 5 | n_head: 25 6 | n_layer: 48 7 | -------------------------------------------------------------------------------- /examples/hyena/flash-attention/training/configs/optimizer/adam.yaml: -------------------------------------------------------------------------------- 1 | # @package train.optimizer 2 | _target_: torch.optim.Adam 3 | -------------------------------------------------------------------------------- /examples/hyena/flash-attention/training/configs/optimizer/adamw-apex-distributed.yaml: -------------------------------------------------------------------------------- 1 | # @package train.optimizer 2 | _target_: apex.contrib.optimizers.distributed_fused_adam.DistributedFusedAdam 3 | adam_w_mode: True 4 | -------------------------------------------------------------------------------- /examples/hyena/flash-attention/training/configs/optimizer/adamw-apex-zero.yaml: -------------------------------------------------------------------------------- 1 | # @package train.optimizer 2 | _target_: torch.distributed.optim.ZeroRedundancyOptimizer 3 | _recursive_: True 4 | optimizer_class: 5 | _target_: apex.optimizers.FusedAdam 6 | _partial_: True 7 | adam_w_mode: True 8 | -------------------------------------------------------------------------------- /examples/hyena/flash-attention/training/configs/optimizer/adamw-apex.yaml: 
-------------------------------------------------------------------------------- 1 | # @package train.optimizer 2 | _target_: apex.optimizers.FusedAdam 3 | adam_w_mode: True 4 | -------------------------------------------------------------------------------- /examples/hyena/flash-attention/training/configs/optimizer/adamw-zero.yaml: -------------------------------------------------------------------------------- 1 | # @package train.optimizer 2 | _target_: torch.distributed.optim.ZeroRedundancyOptimizer 3 | _recursive_: True 4 | optimizer_class: 5 | _target_: torch.optim.__getattribute__ 6 | _args_: 7 | - "AdamW" 8 | -------------------------------------------------------------------------------- /examples/hyena/flash-attention/training/configs/optimizer/adamw.yaml: -------------------------------------------------------------------------------- 1 | # @package train.optimizer 2 | _target_: torch.optim.AdamW 3 | -------------------------------------------------------------------------------- /examples/hyena/flash-attention/training/configs/optimizer/fusedlamb-ds.yaml: -------------------------------------------------------------------------------- 1 | # @package train.optimizer 2 | _target_: deepspeed.ops.lamb.FusedLamb 3 | -------------------------------------------------------------------------------- /examples/hyena/flash-attention/training/configs/optimizer/fusedlamb.yaml: -------------------------------------------------------------------------------- 1 | # @package train.optimizer 2 | _target_: apex.optimizers.FusedLAMB 3 | -------------------------------------------------------------------------------- /examples/hyena/flash-attention/training/configs/optimizer/sgd.yaml: -------------------------------------------------------------------------------- 1 | # @package train.optimizer 2 | _target_: torch.optim.SGD 3 | -------------------------------------------------------------------------------- /examples/hyena/flash-attention/training/configs/scheduler/cosine-warmup-timm.yaml: -------------------------------------------------------------------------------- 1 | # @package train.scheduler 2 | _target_: src.optim.timm_lr_scheduler.TimmCosineLRScheduler 3 | -------------------------------------------------------------------------------- /examples/hyena/flash-attention/training/configs/scheduler/cosine-warmup.yaml: -------------------------------------------------------------------------------- 1 | # @package train.scheduler 2 | _target_: transformers.get_cosine_schedule_with_warmup 3 | -------------------------------------------------------------------------------- /examples/hyena/flash-attention/training/configs/scheduler/invsqrt.yaml: -------------------------------------------------------------------------------- 1 | # @package train.scheduler 2 | _target_: src.optim.lr_scheduler.InvSqrt 3 | num_warmup_steps: ??? 
4 | -------------------------------------------------------------------------------- /examples/hyena/flash-attention/training/configs/scheduler/linear-warmup.yaml: -------------------------------------------------------------------------------- 1 | # @package train.scheduler 2 | _target_: transformers.get_linear_schedule_with_warmup 3 | -------------------------------------------------------------------------------- /examples/hyena/flash-attention/training/configs/scheduler/multi-step.yaml: -------------------------------------------------------------------------------- 1 | # @package train.scheduler 2 | _target_: torch.optim.lr_scheduler.MultiStepLR 3 | -------------------------------------------------------------------------------- /examples/hyena/flash-attention/training/configs/scheduler/plateau.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | train: 3 | scheduler_interval: epoch 4 | scheduler_monitor: ??? 5 | scheduler: 6 | _target_: torch.optim.lr_scheduler.ReduceLROnPlateau 7 | factor: 0.2 # Decay factor when ReduceLROnPlateau is used 8 | patience: 20 9 | min_lr: 0.0 # Minimum learning rate during annealing 10 | -------------------------------------------------------------------------------- /examples/hyena/flash-attention/training/configs/scheduler/poly-warmup.yaml: -------------------------------------------------------------------------------- 1 | # @package train.scheduler 2 | _target_: transformers.get_polynomial_decay_schedule_with_warmup 3 | -------------------------------------------------------------------------------- /examples/hyena/flash-attention/training/configs/scheduler/step.yaml: -------------------------------------------------------------------------------- 1 | # @package train.scheduler 2 | _target_: torch.optim.lr_scheduler.StepLR 3 | step_size: ??? 
4 | -------------------------------------------------------------------------------- /examples/hyena/flash-attention/training/configs/task/sequence-model.yaml: -------------------------------------------------------------------------------- 1 | _target_: src.tasks.seq.SequenceModel 2 | -------------------------------------------------------------------------------- /examples/hyena/flash-attention/training/configs/trainer/all_params.yaml: -------------------------------------------------------------------------------- 1 | _target_: pytorch_lightning.Trainer 2 | 3 | # default values for all trainer parameters 4 | checkpoint_callback: True 5 | default_root_dir: null 6 | gradient_clip_val: 0.0 7 | process_position: 0 8 | num_nodes: 1 9 | num_processes: 1 10 | gpus: null 11 | auto_select_gpus: False 12 | tpu_cores: null 13 | log_gpu_memory: null 14 | overfit_batches: 0.0 15 | track_grad_norm: -1 16 | check_val_every_n_epoch: 1 17 | fast_dev_run: False 18 | accumulate_grad_batches: 1 19 | max_epochs: 1 20 | min_epochs: 1 21 | max_steps: null 22 | min_steps: null 23 | limit_train_batches: 1.0 24 | limit_val_batches: 1.0 25 | limit_test_batches: 1.0 26 | val_check_interval: 1.0 27 | flush_logs_every_n_steps: 100 28 | log_every_n_steps: 50 29 | accelerator: null 30 | sync_batchnorm: False 31 | precision: 32 32 | weights_summary: "top" 33 | weights_save_path: null 34 | num_sanity_val_steps: 2 35 | truncated_bptt_steps: null 36 | resume_from_checkpoint: null 37 | profiler: null 38 | benchmark: False 39 | deterministic: False 40 | reload_dataloaders_every_epoch: False 41 | auto_lr_find: False 42 | replace_sampler_ddp: True 43 | terminate_on_nan: False 44 | auto_scale_batch_size: False 45 | prepare_data_per_node: True 46 | plugins: null 47 | amp_backend: "native" 48 | amp_level: "O2" 49 | move_metrics_to_cpu: False 50 | -------------------------------------------------------------------------------- /examples/hyena/flash-attention/training/configs/trainer/ddp.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - default.yaml 3 | 4 | accelerator: gpu 5 | devices: 4 6 | strategy: ddp 7 | -------------------------------------------------------------------------------- /examples/hyena/flash-attention/training/configs/trainer/debug.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - default.yaml 3 | 4 | gpus: 0 5 | 6 | min_epochs: 1 7 | max_epochs: 2 8 | 9 | # prints 10 | weights_summary: "full" 11 | profiler: null 12 | 13 | # debugs 14 | fast_dev_run: true 15 | num_sanity_val_steps: 2 16 | overfit_batches: 0 17 | limit_train_batches: 1.0 18 | limit_val_batches: 1.0 19 | limit_test_batches: 1.0 20 | track_grad_norm: -1 21 | terminate_on_nan: true 22 | -------------------------------------------------------------------------------- /examples/hyena/flash-attention/training/configs/trainer/default.yaml: -------------------------------------------------------------------------------- 1 | _target_: pytorch_lightning.Trainer 2 | 3 | # set `gpu` to train on GPU, null to train on CPU only 4 | accelerator: null 5 | 6 | min_epochs: 1 7 | max_epochs: 1000 8 | -------------------------------------------------------------------------------- /examples/hyena/flash-attention/training/run.py: -------------------------------------------------------------------------------- 1 | from typing import Callable 2 | 3 | import dotenv 4 | import hydra 5 | from omegaconf import OmegaConf, DictConfig 6 | 7 | # load environment 
variables from `.env` file if it exists 8 | # recursively searches for `.env` in all folders starting from work dir 9 | dotenv.load_dotenv(override=True) 10 | 11 | OmegaConf.register_new_resolver('eval', eval) 12 | OmegaConf.register_new_resolver('div_up', lambda x, y: (x + y - 1) // y) 13 | # Delay the evaluation until we have the datamodule 14 | # So we want the resolver to yield the same string. 15 | OmegaConf.register_new_resolver('datamodule', lambda attr: '${datamodule:' + str(attr) + '}') 16 | 17 | # Turn on TensorFloat32 18 | import torch.backends 19 | torch.backends.cuda.matmul.allow_tf32 = True 20 | torch.backends.cudnn.allow_tf32 = True 21 | 22 | 23 | def dictconfig_filter_key(d: DictConfig, fn: Callable) -> DictConfig: 24 | """Only keep keys where fn(key) is True. Support nested DictConfig. 25 | """ 26 | # Using d.items_ex(resolve=False) instead of d.items() since we want to keep the 27 | # ${datamodule:foo} unresolved for now. 28 | return DictConfig({k: dictconfig_filter_key(v, fn) if isinstance(v, DictConfig) else v 29 | # for k, v in d.items_ex(resolve=False) if fn(k)}) 30 | for k, v in d.items() if fn(k)}) 31 | 32 | 33 | @hydra.main(config_path="configs/", config_name="config.yaml") 34 | def main(config: DictConfig): 35 | 36 | # Remove config keys that start with '__'. These are meant to be used only in computing 37 | # other entries in the config. 38 | config = dictconfig_filter_key(config, lambda k: not k.startswith('__')) 39 | 40 | # Imports should be nested inside @hydra.main to optimize tab completion 41 | # Read more here: https://github.com/facebookresearch/hydra/issues/934 42 | from src.train import train 43 | from src.eval import evaluate 44 | from src.utils import utils 45 | 46 | # A couple of optional utilities: 47 | # - disabling python warnings 48 | # - forcing debug-friendly configuration 49 | # - verifying experiment name is set when running in experiment mode 50 | # You can safely get rid of this line if you don't want those 51 | utils.extras(config) 52 | 53 | # Pretty print config using Rich library 54 | if config.get("print_config"): 55 | utils.print_config(config, resolve=True) 56 | 57 | # Train model 58 | mode = config.get('mode', 'train') 59 | if mode not in ['train', 'eval']: 60 | raise NotImplementedError(f'mode {mode} not supported') 61 | if mode == 'train': 62 | return train(config) 63 | elif mode == 'eval': 64 | return evaluate(config) 65 | 66 | 67 | if __name__ == "__main__": 68 | main() 69 | -------------------------------------------------------------------------------- /examples/hyena/flash-attention/training/src/callbacks/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HazyResearch/flash-fft-conv/b8771028717f46d5b22cbb8e12833f35033d621b/examples/hyena/flash-attention/training/src/callbacks/__init__.py -------------------------------------------------------------------------------- /examples/hyena/flash-attention/training/src/callbacks/causality_monitor.py: -------------------------------------------------------------------------------- 1 | 2 | import pytorch_lightning as pl 3 | from pytorch_lightning import Callback 4 | from pytorch_lightning.utilities import rank_zero_only 5 | 6 | import torch 7 | from torch.autograd import grad 8 | 9 | class CausalityMonitor(Callback): 10 | r"""Monitor causality of a model by tracking gradient leakage forward in time. 11 | In a fully causal model, dy[k]du[s] ~= 0 for all k < s. 
12 | 13 | Args: 14 | seq_len (int): Length of the sequence to monitor. 15 | input_dim (int): Dimension of the input to monitor. If 0, the callback assumes 16 | the task to be language modeling, and skips the embedding layer. If > 0, 17 | input_dim is interpreted as the input channel dimension, i.e. D with 18 | dummy input of dimension [B, L, D]. 19 | 20 | Notes: 21 | This callback assumes that `pl_module.model` has a `net` or `s4seq` attribute, 22 | indicating the primary model to monitor. For LMs, `net` or `s4seq` should 23 | be after the embedding layer. 24 | """ 25 | 26 | def __init__(self, seq_len: int = 10, input_dim: int = 0): 27 | super().__init__() 28 | self.seq_len = seq_len 29 | self.input_dim = input_dim 30 | 31 | @rank_zero_only 32 | def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: 33 | model = pl_module.model 34 | 35 | with torch.enable_grad(): 36 | if self.input_dim == 0: 37 | # [MP] LongTensors cannot have gradients - we start from post 38 | # embedding in the LM case 39 | input_dim = model.d_model 40 | x = torch.randn((2, self.seq_len, input_dim), \ 41 | requires_grad=True).to(pl_module.device) 42 | # [DF] HACK: we need to get the layer that comes after the embedding 43 | if hasattr(model, 'net'): 44 | y = model.net(x) 45 | else: 46 | y = model.s4seq(x) 47 | else: 48 | x = torch.randn(1, self.seq_len, self.input_dim, \ 49 | requires_grad=True).to(pl_module.device) 50 | y = model(x) 51 | 52 | stats = {} 53 | for i in range(self.seq_len): 54 | # total gradients flowing from y_i to x 55 | g = grad(y[0,0,i].mean(), x, retain_graph=True, allow_unused=True)[0] 56 | g = g[0,i+1:,:].abs().mean() 57 | stats[f'stats/causality_{i}'] = g.item() 58 | 59 | if trainer.loggers is not None: 60 | for logger in trainer.loggers: 61 | logger.log_metrics(stats, step=trainer.global_step) 62 | -------------------------------------------------------------------------------- /examples/hyena/flash-attention/training/src/callbacks/flop_count.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/rwightman/pytorch-image-models/blob/master/benchmark.py 2 | from typing import Any, List, Sequence 3 | 4 | import torch 5 | 6 | from pytorch_lightning import Callback, Trainer, LightningModule 7 | from pytorch_lightning.utilities import rank_zero_only 8 | from pytorch_lightning.utilities.parsing import AttributeDict 9 | 10 | from src.utils.flops import has_deepspeed_profiling, has_fvcore_profiling 11 | from src.utils.flops import profile_deepspeed, profile_fvcore 12 | 13 | 14 | class FlopCount(Callback): 15 | """Count the number of FLOPs used by the model 16 | """ 17 | def __init__(self, profilers: List[str] = ['fvcore', 'deepspeed'], 18 | input_size: tuple = (3, 224, 224), input_dtype=torch.float32, device=None): 19 | if not isinstance(profilers, Sequence): 20 | profilers = [profilers] 21 | if any(p not in ['fvcore', 'deepspeed'] for p in profilers): 22 | raise NotImplementedError('Only support fvcore and deepspeed profilers') 23 | if 'fvcore' in profilers and not has_fvcore_profiling: 24 | raise ImportError('fvcore is not installed.
Install it by running `pip install fvcore`') 25 | elif 'deepspeed' in profilers and not has_deepspeed_profiling: 26 | raise ImportError('deepspeed is not installed') 27 | super().__init__() 28 | self.profilers = profilers 29 | self.input_size = tuple(input_size) 30 | self.input_dtype = input_dtype 31 | self.device = device 32 | 33 | @rank_zero_only 34 | def on_fit_start(self, trainer: Trainer, pl_module: LightningModule) -> None: 35 | if 'fvcore' in self.profilers: 36 | _, macs, _, acts = profile_fvcore(pl_module.to(self.device), input_size=self.input_size, 37 | input_dtype=self.input_dtype, detailed=True) 38 | trainer.logger.log_hyperparams({'GMACs': macs * 1e-9, 'MActs': acts * 1e-6}) 39 | if 'deepspeed' in self.profilers: 40 | macs, _= profile_deepspeed(pl_module.to(self.device), input_size=self.input_size, 41 | input_dtype=self.input_dtype, detailed=True) 42 | if 'fvcore' not in self.profilers: # fvcore's MACs seem more accurate 43 | trainer.logger.log_hyperparams({'GMACs': macs * 1e-9}) 44 | -------------------------------------------------------------------------------- /examples/hyena/flash-attention/training/src/callbacks/gpu_affinity.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from pytorch_lightning import Callback, Trainer, LightningModule 4 | 5 | import logging 6 | 7 | log = logging.getLogger(__name__) # We want a logger for each process, not just the rank 0 8 | 9 | 10 | def l2_promote(): 11 | import ctypes 12 | _libcudart = ctypes.CDLL('libcudart.so') 13 | # Set device limit on the current device 14 | # cudaLimitMaxL2FetchGranularity = 0x05 15 | pValue = ctypes.cast((ctypes.c_int*1)(), ctypes.POINTER(ctypes.c_int)) 16 | _libcudart.cudaDeviceSetLimit(ctypes.c_int(0x05), ctypes.c_int(128)) 17 | _libcudart.cudaDeviceGetLimit(pValue, ctypes.c_int(0x05)) 18 | assert pValue.contents.value == 128 19 | 20 | 21 | def set_affinity(trainer): 22 | try: 23 | from src.utils.gpu_affinity import set_affinity 24 | nproc_per_node = torch.cuda.device_count() 25 | affinity = set_affinity(trainer.local_rank, nproc_per_node, 'socket_unique_continuous') 26 | log.info(f'{trainer.local_rank}: thread affinity: {affinity}') 27 | # TD [2022-05-07] Somehow calling this causes GPU 0 to allocate extra ~800MB of memory per 28 | # number of GPUs (e.g., 6.4GB of extra memory in a 8-GPU setup). H/t Dan. 29 | # l2_promote() 30 | except: 31 | pass 32 | 33 | 34 | class GpuAffinity(Callback): 35 | """Set GPU affinity and increase the L2 fetch granularity. 36 | Adapted from https://github.com/NVIDIA/DeepLearningExamples/tree/master/PyTorch/LanguageModeling/Transformer-XL 37 | """ 38 | 39 | def setup(self, trainer: Trainer, pl_module: LightningModule, stage=None) -> None: 40 | set_affinity(trainer) 41 | -------------------------------------------------------------------------------- /examples/hyena/flash-attention/training/src/callbacks/loss_scale_monitor.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/Lightning-AI/lightning/blob/master/src/pytorch_lightning/callbacks/lr_monitor.py. 2 | from typing import Any 3 | 4 | from pytorch_lightning import Callback, Trainer 5 | from pytorch_lightning.utilities import rank_zero_only 6 | from pytorch_lightning.strategies import DeepSpeedStrategy 7 | 8 | 9 | class LossScaleMonitor(Callback): 10 | """Monitor the loss scale for AMP (fp16). 
11 | """ 12 | 13 | # Use on_before_optimizer_step instead of on_train_batch_start since there might be 14 | # gradient accumulation and we only care about the loss scale when it could change (i.e., 15 | # optimizer.step). 16 | @rank_zero_only 17 | def on_before_optimizer_step(self, trainer: Trainer, *args: Any, **kwargs: Any) -> None: 18 | if not trainer._logger_connector.should_update_logs: 19 | return 20 | stats = {} 21 | if isinstance(trainer.strategy, DeepSpeedStrategy): 22 | stats = {'scalar/scale': trainer.model.optimizer.loss_scale} 23 | if hasattr(trainer, 'precision_plugin') and hasattr(trainer.precision_plugin, 'scaler'): 24 | scaler = trainer.precision_plugin.scaler 25 | if scaler is not None: 26 | stats = { 27 | 'scaler/scale': scaler.get_scale(), 28 | 'scaler/growth_tracker': scaler._get_growth_tracker(), 29 | } 30 | if stats and trainer.loggers is not None: 31 | for logger in trainer.loggers: 32 | logger.log_metrics(stats, step=trainer.fit_loop.epoch_loop._batches_that_stepped) 33 | -------------------------------------------------------------------------------- /examples/hyena/flash-attention/training/src/callbacks/model_checkpoint.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/Lightning-AI/lightning/blob/master/src/pytorch_lightning/callbacks/fault_tolerance.py 2 | from typing import Any 3 | from pathlib import Path 4 | 5 | import pytorch_lightning as pl 6 | 7 | 8 | class ModelCheckpointMine(pl.callbacks.model_checkpoint.ModelCheckpoint): 9 | 10 | def __init__(self, *args, fault_tolerant=False, **kwargs): 11 | super().__init__(*args, **kwargs) 12 | self.fault_tolerant = fault_tolerant 13 | 14 | def on_exception(self, trainer: "pl.Trainer", *_: Any, **__: Any) -> None: 15 | if self.fault_tolerant: 16 | # overwrite if necessary 17 | trainer.save_checkpoint(str(Path(self.dirpath) / '.pl_auto_save.ckpt')) 18 | 19 | # def teardown(self, trainer: "pl.Trainer", *_: Any, **__: Any) -> None: 20 | # if self.fault_tolerant: 21 | # trainer.strategy.remove_checkpoint(str(Path(self.dirpath) / '.pl_auto_save.ckpt')) 22 | 23 | 24 | # TD [2022-07-17] I was trying to make resuming from standard checkpoint fault-tolerant. 25 | # However, when it resumes it's off by 1 iteration. My attempt to fix it in seq.py (below) didn't work. 26 | # So I decided to just copy _FaultToleranceCheckpoint and just save on_exception. 27 | 28 | # def on_save_checkpoint(self, checkpoint): 29 | # # TD [2022-07-12] The "completed" counter is off by 1 so when it resumes 30 | # # it's off by 1 iteration. However, the data is still off by 1 iteration, probably 31 | # # because the dataloader_state_dict['counter'] is off by @batch_size, and idk how 32 | # # to fix it cleanly. 
33 | # checkpoint['loops']['fit_loop']['epoch_loop.batch_progress']['total']['completed'] += 1 34 | # checkpoint['loops']['fit_loop']['epoch_loop.batch_progress']['current']['completed'] += 1 35 | # checkpoint['loops']['fit_loop']['epoch_loop.state_dict']['_batches_that_stepped'] += 1 36 | # checkpoint['loops']['fit_loop']['epoch_loop.state_dict']['dataloader_state_dict'][0]['state'][0]['num_batches_fetched'] += 1 37 | -------------------------------------------------------------------------------- /examples/hyena/flash-attention/training/src/callbacks/params_log.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | from pytorch_lightning import Callback, Trainer, LightningModule 4 | from pytorch_lightning.utilities import rank_zero_only 5 | from pytorch_lightning.utilities.parsing import AttributeDict 6 | 7 | 8 | class ParamsLog(Callback): 9 | """Log the number of parameters of the model 10 | """ 11 | def __init__(self, total_params_log: bool = True, trainable_params_log: bool = True, 12 | non_trainable_params_log: bool = True): 13 | super().__init__() 14 | self._log_stats = AttributeDict( 15 | { 16 | 'total_params_log': total_params_log, 17 | 'trainable_params_log': trainable_params_log, 18 | 'non_trainable_params_log': non_trainable_params_log, 19 | } 20 | ) 21 | 22 | @rank_zero_only 23 | def on_fit_start(self, trainer: Trainer, pl_module: LightningModule) -> None: 24 | logs = {} 25 | if self._log_stats.total_params_log: 26 | logs["model/params_total"] = sum(p.numel() for p in pl_module.parameters()) 27 | if self._log_stats.trainable_params_log: 28 | logs["model/params_trainable"] = sum(p.numel() for p in pl_module.parameters() 29 | if p.requires_grad) 30 | if self._log_stats.non_trainable_params_log: 31 | logs["model/params_not_trainable"] = sum(p.numel() for p in pl_module.parameters() 32 | if not p.requires_grad) 33 | if trainer.logger is not None: 34 | trainer.logger.log_hyperparams(logs) 35 | -------------------------------------------------------------------------------- /examples/hyena/flash-attention/training/src/datamodules/datasets/detokenizer.py: -------------------------------------------------------------------------------- 1 | # Copied from https://github.com/stanford-crfm/mistral/blob/main/src/corpora/detokenization.py 2 | # Which was originally from https://github.com/NVIDIA/Megatron-LM/blob/aed2f75e209e525c842aec7c044af7acae2a4614/tasks/zeroshot_gpt/detokenizer.py 3 | 4 | """ 5 | Handle detokenization for different dataset for zero-shot LM evaluation. 6 | """ 7 | import re 8 | 9 | 10 | def wikitext_detokenize(string: str) -> str: 11 | """ 12 | Wikitext is whitespace tokenized and we remove these whitespaces. 13 | Taken from https://github.com/NVIDIA/Megatron-LM/blob/main/tasks/zeroshot_gpt2/detokenizer.py 14 | """ 15 | # Contractions 16 | string = string.replace("s '", "s'") 17 | string = re.sub(r"/' [0-9]/", r"/'[0-9]/", string) 18 | 19 | # Number Separators 20 | string = string.replace(" @-@ ", "-") 21 | string = string.replace(" @,@ ", ",") 22 | string = string.replace(" @.@ ", ".") 23 | 24 | # Punctuation 25 | string = string.replace(" : ", ": ") 26 | string = string.replace(" ; ", "; ") 27 | string = string.replace(" . ", ". ") 28 | string = string.replace(" ! ", "! ") 29 | string = string.replace(" ? ", "? 
") 30 | string = string.replace(" , ", ", ") 31 | 32 | # Double Brackets 33 | string = re.sub(r"\(\s*([^\)]*?)\s*\)", r"(\1)", string) 34 | string = re.sub(r"\[\s*([^\]]*?)\s*\]", r"[\1]", string) 35 | string = re.sub(r"{\s*([^}]*?)\s*}", r"{\1}", string) 36 | string = re.sub(r"\"\s*([^\"]*?)\s*\"", r'"\1"', string) 37 | string = re.sub(r"'\s*([^']*?)\s*'", r"'\1'", string) 38 | 39 | # Miscellaneous 40 | string = string.replace("= = = =", "====") 41 | string = string.replace("= = =", "===") 42 | string = string.replace("= =", "==") 43 | string = string.replace(" " + chr(176) + " ", chr(176)) 44 | string = string.replace(" \n", "\n") 45 | string = string.replace("\n ", "\n") 46 | string = string.replace(" N ", " 1 ") 47 | string = string.replace(" 's", "'s") 48 | 49 | return string 50 | 51 | 52 | # Set Registry for Various Datasets 53 | DATASET_TOKENIZATION_REGISTRY = {"wikitext": wikitext_detokenize} 54 | -------------------------------------------------------------------------------- /examples/hyena/flash-attention/training/src/datamodules/datasets/lm_dataset.py: -------------------------------------------------------------------------------- 1 | # Inspired by https://github.com/NVIDIA/Megatron-LM/blob/main/tasks/zeroshot_gpt/datasets.py 2 | # Except we don't pad the last block and don't use overlapping eval 3 | # And we return both the input and the target 4 | import math 5 | import numpy as np 6 | 7 | import torch 8 | 9 | 10 | class LMDataset(torch.utils.data.Dataset): 11 | 12 | def __init__(self, tokens, seq_len, drop_last=True): 13 | """tokens should be a numpy array 14 | """ 15 | self.seq_len = seq_len 16 | ntokens = len(tokens) 17 | if drop_last: 18 | ntokens = ((ntokens - 1) // seq_len) * seq_len + 1 19 | self.ntokens = ntokens 20 | # We're careful not to slice tokens, since it could be a memmap'ed array or H5 dataset, 21 | # and slicing would load it to memory. 22 | self.tokens = tokens 23 | self.total_sequences = math.ceil((self.ntokens - 1) / self.seq_len) 24 | 25 | def __len__(self): 26 | return self.total_sequences 27 | 28 | def __getitem__(self, idx): 29 | start_idx = idx * self.seq_len 30 | seq_len = min(self.seq_len, self.ntokens - 1 - start_idx) 31 | data = torch.as_tensor(self.tokens[start_idx:(start_idx + seq_len + 1)].astype(np.int64)) 32 | return data[:-1], data[1:].clone() 33 | -------------------------------------------------------------------------------- /examples/hyena/flash-attention/training/src/datamodules/timm_mixup.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from timm.data import Mixup 4 | from timm.data.mixup import mixup_target 5 | 6 | 7 | class TimmMixup(Mixup): 8 | """ Wrap timm.data.Mixup that avoids the assert that batch size must be even. 
9 | """ 10 | def __call__(self, x, target): 11 | if self.mode == 'elem': 12 | lam = self._mix_elem(x) 13 | elif self.mode == 'pair': 14 | # We move the assert from the beginning of the function to here 15 | assert len(x) % 2 == 0, 'Batch size should be even when using this' 16 | lam = self._mix_pair(x) 17 | else: 18 | lam = self._mix_batch(x) 19 | target = mixup_target(target, self.num_classes, lam, self.label_smoothing, x.device) 20 | return x, target 21 | -------------------------------------------------------------------------------- /examples/hyena/flash-attention/training/src/distributed/ddp_comm_hooks.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://pytorch.org/docs/stable/_modules/torch/distributed/algorithms/ddp_comm_hooks/default_hooks.html 2 | # We divide by world_size first before converting to fp16, so it's safer. 3 | from typing import Any, Callable 4 | 5 | import torch 6 | import torch.distributed as dist 7 | 8 | 9 | def fp16_compress_hook( 10 | process_group: dist.ProcessGroup, bucket: dist.GradBucket 11 | ) -> torch.futures.Future[torch.Tensor]: 12 | """ 13 | This DDP communication hook implements a simple gradient compression 14 | approach that casts ``GradBucket`` tensor to half-precision floating-point format (``torch.float16``) 15 | and then divides it by the process group size. 16 | It allreduces those ``float16`` gradient tensors. Once compressed gradient 17 | tensors are allreduced, the chained callback ``decompress`` casts it back to the input data type (such as ``float32``). 18 | 19 | Example:: 20 | >>> ddp_model.register_comm_hook(process_group, fp16_compress_hook) 21 | """ 22 | group_to_use = process_group if process_group is not None else dist.group.WORLD 23 | world_size = group_to_use.size() 24 | 25 | # Divide first before converting to fp16 26 | # Use out argument to fuse the division and the conversion. 27 | compressed_tensor = torch.div(bucket.buffer(), world_size, 28 | out=torch.empty_like(bucket.buffer(), dtype=torch.float16)) 29 | 30 | fut = dist.all_reduce( 31 | compressed_tensor, group=group_to_use, async_op=True 32 | ).get_future() 33 | 34 | def decompress(fut): 35 | decompressed_tensor = bucket.buffer() 36 | # Decompress in place to reduce the peak memory. 37 | # See: https://github.com/pytorch/pytorch/issues/45968 38 | decompressed_tensor.copy_(fut.value()[0]) 39 | return decompressed_tensor 40 | 41 | # TODO: maybe have a backoff strategy: check if the buffer has inf / NaN, in that case 42 | # resend with fp32? 43 | return fut.then(decompress) 44 | -------------------------------------------------------------------------------- /examples/hyena/flash-attention/training/src/metrics/accuracy.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import Tensor 3 | 4 | from torchmetrics import Metric, Accuracy 5 | 6 | 7 | class AccuracyMine(Accuracy): 8 | """Wrap torchmetrics.Accuracy to take argmax of y in case of Mixup. 
9 | """ 10 | def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore 11 | super().update(preds, target.argmax(dim=-1) if target.is_floating_point() else target) 12 | -------------------------------------------------------------------------------- /examples/hyena/flash-attention/training/src/metrics/num_tokens.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, Optional 2 | 3 | import torch 4 | from torch import Tensor 5 | 6 | from torchmetrics import Metric 7 | 8 | 9 | class NumTokens(Metric): 10 | """Keep track of how many tokens we've seen. 11 | """ 12 | # TODO: how do we prevent the reset between the epochs? The reset happens on the 1st batch 13 | # of the next epoch. 14 | # Right now the hack is that we override reset(), which would mess up the forward method. 15 | # We then override forward to do the right thing. 16 | 17 | is_differentiable = False 18 | higher_is_better = False 19 | full_state_update = False 20 | count: Tensor 21 | 22 | def __init__(self, **kwargs: Dict[str, Any]): 23 | super().__init__(**kwargs) 24 | self.add_state("count", default=torch.tensor(0, dtype=torch.int64), dist_reduce_fx="sum", 25 | persistent=True) # We want the count to be saved to state-dict 26 | 27 | def update(self, preds: Tensor, target: Tensor, loss: Optional[Tensor] = None) -> None: # type: ignore 28 | self.count += target.numel() 29 | 30 | def compute(self) -> Tensor: 31 | return self.count 32 | 33 | def reset(self): 34 | count = self.count 35 | super().reset() 36 | self.count = count 37 | 38 | # Adapted from https://github.com/Lightning-AI/metrics/blob/master/src/torchmetrics/metric.py 39 | def _forward_reduce_state_update(self, *args: Any, **kwargs: Any) -> Any: 40 | """forward computation using single call to `update` to calculate the metric value on the current batch and 41 | accumulate global state. 42 | This can be done when the global metric state is a sinple reduction of batch states. 43 | """ 44 | self.update(*args, **kwargs) 45 | return self.compute() 46 | -------------------------------------------------------------------------------- /examples/hyena/flash-attention/training/src/optim/timm_lr_scheduler.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.optim import Optimizer 3 | 4 | from timm.scheduler import CosineLRScheduler 5 | 6 | 7 | # We need to subclass torch.optim.lr_scheduler._LRScheduler, or Pytorch-lightning will complain 8 | class TimmCosineLRScheduler(CosineLRScheduler, torch.optim.lr_scheduler._LRScheduler): 9 | """ Wrap timm.scheduler.CosineLRScheduler so we can call scheduler.step() without passing in epoch. 10 | It supports resuming as well. 11 | """ 12 | 13 | def __init__(self, *args, **kwargs): 14 | super().__init__(*args, **kwargs) 15 | self._last_epoch = -1 16 | self.step(epoch=0) 17 | 18 | def step(self, epoch=None): 19 | if epoch is None: 20 | self._last_epoch += 1 21 | else: 22 | self._last_epoch = epoch 23 | # We call either step or step_update, depending on whether we're using the scheduler every 24 | # epoch or every step. 25 | # Otherwise, lightning will always call step (i.e., meant for each epoch), and if we set 26 | # scheduler interval to "step", then the learning rate update will be wrong. 
27 | if self.t_in_epochs: 28 | super().step(epoch=self._last_epoch) 29 | else: 30 | super().step_update(num_updates=self._last_epoch) 31 | -------------------------------------------------------------------------------- /examples/hyena/flash-attention/training/src/utils/flops.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/rwightman/pytorch-image-models/blob/master/benchmark.py 2 | import torch 3 | 4 | try: 5 | from deepspeed.profiling.flops_profiler import get_model_profile 6 | has_deepspeed_profiling = True 7 | except ImportError as e: 8 | has_deepspeed_profiling = False 9 | 10 | try: 11 | from fvcore.nn import FlopCountAnalysis, flop_count_str, flop_count_table 12 | from fvcore.nn import ActivationCountAnalysis 13 | has_fvcore_profiling = True 14 | except ImportError as e: 15 | FlopCountAnalysis = None 16 | ActivationCountAnalysis = None 17 | has_fvcore_profiling = False 18 | 19 | 20 | def profile_deepspeed(model, input_size=(3, 224, 224), input_dtype=torch.float32, 21 | batch_size=1, detailed=False): 22 | device, dtype = next(model.parameters()).device, next(model.parameters()).dtype 23 | flops, macs, params = get_model_profile( 24 | model=model, 25 | args=torch.zeros((batch_size,) + input_size, device=device, dtype=input_dtype), 26 | print_profile=detailed, # prints the model graph with the measured profile attached to each module 27 | detailed=detailed, # print the detailed profile 28 | warm_up=10, # the number of warm-ups before measuring the time of each module 29 | as_string=False, # print raw numbers (e.g. 1000) or as human-readable strings (e.g. 1k) 30 | output_file=None, # path to the output file. If None, the profiler prints to stdout. 31 | ignore_modules=None) # the list of modules to ignore in the profiling 32 | return macs, 0 # no activation count in DS 33 | 34 | 35 | def profile_fvcore(model, input_size=(3, 224, 224), input_dtype=torch.float32, max_depth=4, 36 | batch_size=1, detailed=False, force_cpu=False): 37 | if force_cpu: 38 | model = model.to('cpu') 39 | device, dtype = next(model.parameters()).device, next(model.parameters()).dtype 40 | example_input = torch.zeros((batch_size,) + input_size, device=device, dtype=input_dtype) 41 | fca = FlopCountAnalysis(model, example_input) 42 | aca = ActivationCountAnalysis(model, example_input) 43 | if detailed: 44 | print(flop_count_table(fca, max_depth=max_depth)) 45 | return fca, fca.total(), aca, aca.total() 46 | -------------------------------------------------------------------------------- /examples/hyena/src/callbacks/norms.py: -------------------------------------------------------------------------------- 1 | import pytorch_lightning as pl 2 | from pytorch_lightning.utilities import rank_zero_only 3 | from pytorch_lightning.utilities.parsing import AttributeDict 4 | from omegaconf import OmegaConf 5 | 6 | class TrackNorms(pl.Callback): 7 | 8 | # TODO do callbacks happen before or after the method in the main LightningModule? 9 | # @rank_zero_only # needed? 
10 | def on_after_training_step(self, batch, batch_idx, trainer: pl.Trainer, pl_module: pl.LightningModule): 11 | # Log extra metrics 12 | metrics = {} 13 | 14 | if hasattr(pl_module, "_grad_norms"): 15 | metrics.update(pl_module._grad_norms) 16 | 17 | self.log_dict( 18 | metrics, 19 | on_step=True, 20 | on_epoch=False, 21 | prog_bar=False, 22 | add_dataloader_idx=False, 23 | sync_dist=True, 24 | ) 25 | 26 | 27 | def on_after_backward(self, trainer: pl.Trainer, pl_module: pl.LightningModule): 28 | # example to inspect gradient information in tensorboard 29 | if OmegaConf.select(trainer.hparams, 'trainer.track_grad_norms'): # TODO dot notation should work with omegaconf? 30 | norms = {} 31 | for name, p in pl_module.named_parameters(): 32 | if p.grad is None: 33 | continue 34 | 35 | # param_norm = float(p.grad.data.norm(norm_type)) 36 | param_norm = (p.grad.data ** 2).mean() # mean of squared gradients; avoids torch.mean since torch is not imported in this module 37 | norms[f"grad_norm.{name}"] = param_norm 38 | pl_module._grad_norms = norms 39 | 40 | -------------------------------------------------------------------------------- /examples/hyena/src/callbacks/params.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | import pytorch_lightning as pl 4 | from pytorch_lightning.utilities import rank_zero_only 5 | from pytorch_lightning.utilities.parsing import AttributeDict 6 | 7 | 8 | class ParamsLog(pl.Callback): 9 | """ Log the number of parameters of the model """ 10 | def __init__( 11 | self, 12 | total: bool = True, 13 | trainable: bool = True, 14 | fixed: bool = True, 15 | ): 16 | super().__init__() 17 | self._log_stats = AttributeDict( 18 | { 19 | 'total_params_log': total, 20 | 'trainable_params_log': trainable, 21 | 'non_trainable_params_log': fixed, 22 | } 23 | ) 24 | 25 | @rank_zero_only 26 | def on_fit_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None: 27 | logs = {} 28 | if self._log_stats.total_params_log: 29 | logs["params/total"] = sum(p.numel() for p in pl_module.parameters()) 30 | if self._log_stats.trainable_params_log: 31 | logs["params/trainable"] = sum(p.numel() for p in pl_module.parameters() 32 | if p.requires_grad) 33 | if self._log_stats.non_trainable_params_log: 34 | logs["params/fixed"] = sum(p.numel() for p in pl_module.parameters() 35 | if not p.requires_grad) 36 | if trainer.logger: 37 | trainer.logger.log_hyperparams(logs) 38 | -------------------------------------------------------------------------------- /examples/hyena/src/dataloaders/README.md: -------------------------------------------------------------------------------- 1 | # Overview 2 | 3 | Basic datasets including MNIST and CIFAR will auto-download. Source code for these datamodules is in [basic.py](basic.py). 4 | 5 | By default, data is downloaded to `./data/`, where `.` is the top-level directory of this repository (e.g. 'safari'). 6 | 7 | ## Advanced Usage 8 | 9 | After downloading and preparing data, the paths can be configured in several ways. 10 | 11 | 1. Suppose that you want to download all data to a different folder, for example a different disk. 12 | The data path can be configured by setting the environment variable `DATA_PATH`, which defaults to `./data`. 13 | 14 | 2. For fine-grained control over the path of a particular dataset, set `dataset.data_dir` in the config.
For example, if the LRA ListOps files are located in `/home/lra/listops-1000/` instead of the default `./data/listops/`, 15 | pass in `+dataset.data_dir=/home/lra/listops-1000` on the command line or modify the config file directly. 16 | 17 | 3. As a simple workaround, softlinks can be set, e.g. `ln -s /home/lra/listops-1000 ./data/listops` 18 | 19 | 20 | # Data Preparation 21 | 22 | [LRA](#long-range-arena-lra) must be manually downloaded. 23 | 24 | By default, these should go under `$DATA_PATH/`, which defaults to `./data`. For the remainder of this README, these are used interchangeably. 25 | 26 | ## Long Range Arena (LRA) 27 | 28 | LRA can be downloaded from the [GitHub page](https://github.com/google-research/long-range-arena). 29 | These datasets should be organized as follows: 30 | ``` 31 | $DATA_PATH/ 32 | pathfinder/ 33 | pathfinder32/ 34 | pathfinder64/ 35 | pathfinder128/ 36 | pathfinder256/ 37 | aan/ 38 | listops/ 39 | ``` 40 | The other two datasets in the suite ("Image" or grayscale sequential CIFAR-10; "Text" or char-level IMDB sentiment classification) are both auto-downloaded. -------------------------------------------------------------------------------- /examples/hyena/src/dataloaders/__init__.py: -------------------------------------------------------------------------------- 1 | from . import basic, et, lra, language_modeling_hf, synthetics, vision 2 | from .base import SequenceDataset 3 | -------------------------------------------------------------------------------- /examples/hyena/src/dataloaders/datasets/detokenizer.py: -------------------------------------------------------------------------------- 1 | # Copied from https://github.com/stanford-crfm/mistral/blob/main/src/corpora/detokenization.py 2 | # Which was originally from https://github.com/NVIDIA/Megatron-LM/blob/aed2f75e209e525c842aec7c044af7acae2a4614/tasks/zeroshot_gpt/detokenizer.py 3 | 4 | """ 5 | Handle detokenization for different dataset for zero-shot LM evaluation. 6 | """ 7 | import re 8 | 9 | 10 | def wikitext_detokenize(string: str) -> str: 11 | """ 12 | Wikitext is whitespace tokenized and we remove these whitespaces. 13 | Taken from https://github.com/NVIDIA/Megatron-LM/blob/main/tasks/zeroshot_gpt2/detokenizer.py 14 | """ 15 | # Contractions 16 | string = string.replace("s '", "s'") 17 | string = re.sub(r"/' [0-9]/", r"/'[0-9]/", string) 18 | 19 | # Number Separators 20 | string = string.replace(" @-@ ", "-") 21 | string = string.replace(" @,@ ", ",") 22 | string = string.replace(" @.@ ", ".") 23 | 24 | # Punctuation 25 | string = string.replace(" : ", ": ") 26 | string = string.replace(" ; ", "; ") 27 | string = string.replace(" . ", ". ") 28 | string = string.replace(" ! ", "! ") 29 | string = string.replace(" ? ", "? 
") 30 | string = string.replace(" , ", ", ") 31 | 32 | # Double Brackets 33 | string = re.sub(r"\(\s*([^\)]*?)\s*\)", r"(\1)", string) 34 | string = re.sub(r"\[\s*([^\]]*?)\s*\]", r"[\1]", string) 35 | string = re.sub(r"{\s*([^}]*?)\s*}", r"{\1}", string) 36 | string = re.sub(r"\"\s*([^\"]*?)\s*\"", r'"\1"', string) 37 | string = re.sub(r"'\s*([^']*?)\s*'", r"'\1'", string) 38 | 39 | # Miscellaneous 40 | string = string.replace("= = = =", "====") 41 | string = string.replace("= = =", "===") 42 | string = string.replace("= =", "==") 43 | string = string.replace(" " + chr(176) + " ", chr(176)) 44 | string = string.replace(" \n", "\n") 45 | string = string.replace("\n ", "\n") 46 | string = string.replace(" N ", " 1 ") 47 | string = string.replace(" 's", "'s") 48 | 49 | return string 50 | 51 | 52 | # Set Registry for Various Datasets 53 | DATASET_TOKENIZATION_REGISTRY = {"wikitext": wikitext_detokenize} -------------------------------------------------------------------------------- /examples/hyena/src/dataloaders/datasets/lm_dataset.py: -------------------------------------------------------------------------------- 1 | # Inspired by https://github.com/NVIDIA/Megatron-LM/blob/main/tasks/zeroshot_gpt/datasets.py 2 | # Except we don't pad the last block and don't use overlapping eval 3 | # And we return both the input and the target 4 | import math 5 | import numpy as np 6 | 7 | import torch 8 | 9 | 10 | class LMDataset(torch.utils.data.Dataset): 11 | 12 | def __init__(self, tokens, seq_len, drop_last=True): 13 | """tokens should be a numpy array 14 | """ 15 | self.seq_len = seq_len 16 | ntokens = len(tokens) 17 | if drop_last: 18 | ntokens = ((ntokens - 1) // seq_len) * seq_len + 1 19 | self.ntokens = ntokens 20 | # We're careful not to slice tokens, since it could be a memmap'ed array or H5 dataset, 21 | # and slicing would load it to memory. 22 | self.tokens = tokens 23 | self.total_sequences = math.ceil((self.ntokens - 1) / self.seq_len) 24 | 25 | def __len__(self): 26 | return self.total_sequences 27 | 28 | def __getitem__(self, idx): 29 | start_idx = idx * self.seq_len 30 | seq_len = min(self.seq_len, self.ntokens - 1 - start_idx) 31 | data = torch.as_tensor(self.tokens[start_idx:(start_idx + seq_len + 1)].astype(np.int64)) 32 | return data[:-1], data[1:].clone() -------------------------------------------------------------------------------- /examples/hyena/src/dataloaders/utils/timm_mixup.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from timm.data import Mixup 4 | from timm.data.mixup import mixup_target 5 | 6 | 7 | class TimmMixup(Mixup): 8 | """ Wrap timm.data.Mixup that avoids the assert that batch size must be even. 
9 | """ 10 | def __call__(self, x, target, *args): 11 | if self.mode == 'elem': 12 | lam = self._mix_elem(x) 13 | elif self.mode == 'pair': 14 | # We move the assert from the beginning of the function to here 15 | assert len(x) % 2 == 0, 'Batch size should be even when using this' 16 | lam = self._mix_pair(x) 17 | else: 18 | lam = self._mix_batch(x) 19 | # Another change is to set the right device here 20 | target = mixup_target(target, self.num_classes, lam, self.label_smoothing, 21 | device=target.device) 22 | return x, target, *args -------------------------------------------------------------------------------- /examples/hyena/src/models/nn/__init__.py: -------------------------------------------------------------------------------- 1 | from .components import LinearActivation, Activation, Normalization, DropoutNd 2 | -------------------------------------------------------------------------------- /examples/hyena/src/models/sequence/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import SequenceModule, TransposedModule 2 | from .model import SequenceModel 3 | from .ff import FF 4 | -------------------------------------------------------------------------------- /examples/hyena/src/models/sequence/ff.py: -------------------------------------------------------------------------------- 1 | """ Implementation of FFN block in the style of Transformers """ 2 | 3 | from functools import partial 4 | from torch import nn 5 | from src.models.sequence.base import SequenceModule 6 | from src.models.nn import LinearActivation, DropoutNd 7 | 8 | class FF(SequenceModule): 9 | def __init__(self, d_input, expand=2, d_output=None, transposed=False, activation='gelu', initializer=None, dropout=0.0, tie_dropout=False): 10 | super().__init__() 11 | self.d_output = d_input if d_output is None else d_output 12 | self.transposed = transposed 13 | d_inner = expand * d_input 14 | 15 | linear1 = LinearActivation( 16 | d_input, d_inner, 17 | transposed=transposed, 18 | activation=activation, 19 | initializer=initializer, 20 | activate=True, 21 | ) 22 | dropout_cls = partial(DropoutNd, transposed=self.transposed) if tie_dropout else nn.Dropout 23 | # dropout_cls = nn.Dropout2d if self.transposed else nn.Dropout 24 | drop = dropout_cls(dropout) if dropout > 0.0 else nn.Identity() 25 | 26 | linear2 = LinearActivation( 27 | d_inner, self.d_output, 28 | transposed=transposed, 29 | activation=None, 30 | initializer=initializer, 31 | activate=False, 32 | ) 33 | 34 | self.ff = nn.Sequential( 35 | linear1, 36 | drop, 37 | linear2, 38 | ) 39 | 40 | def forward(self, x, *args, **kwargs): 41 | return self.ff(x), None 42 | 43 | def step(self, x, state, **kwargs): 44 | # x: [batch, d_input] 45 | if self.transposed: 46 | # expects: [batch, d_input, seq_len] 47 | return self.ff(x.unsqueeze(-1)).squeeze(-1), state 48 | else: 49 | return self.ff(x), state 50 | 51 | -------------------------------------------------------------------------------- /examples/hyena/src/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .config import is_list, is_dict, to_list, to_dict, get_class, instantiate 2 | -------------------------------------------------------------------------------- /examples/hyena/src/utils/registry.py: -------------------------------------------------------------------------------- 1 | optimizer = { 2 | "adam": "torch.optim.Adam", 3 | "adamw": "torch.optim.AdamW", 4 | "rmsprop": "torch.optim.RMSprop", 5 | 
"sgd": "torch.optim.SGD", 6 | "lamb": "src.utils.optim.lamb.JITLamb", 7 | } 8 | 9 | scheduler = { 10 | "constant": "transformers.get_constant_schedule", 11 | "plateau": "torch.optim.lr_scheduler.ReduceLROnPlateau", 12 | "step": "torch.optim.lr_scheduler.StepLR", 13 | "multistep": "torch.optim.lr_scheduler.MultiStepLR", 14 | "cosine": "torch.optim.lr_scheduler.CosineAnnealingLR", 15 | "constant_warmup": "transformers.get_constant_schedule_with_warmup", 16 | "linear_warmup": "transformers.get_linear_schedule_with_warmup", 17 | "cosine_warmup": "transformers.get_cosine_schedule_with_warmup", 18 | "cosine_warmup_timm": "src.utils.optim.schedulers.TimmCosineLRScheduler", 19 | } 20 | 21 | model = { 22 | # Backbones from this repo 23 | "model": "src.models.sequence.SequenceModel", 24 | "lm": "src.models.sequence.long_conv_lm.ConvLMHeadModel", 25 | "lm_simple": "src.models.sequence.simple_lm.SimpleLMHeadModel", 26 | "vit_b_16": "src.models.baselines.vit_all.vit_base_patch16_224", 27 | } 28 | 29 | layer = { 30 | "id": "src.models.sequence.base.SequenceIdentity", 31 | "ff": "src.models.sequence.ff.FF", 32 | "mha": "src.models.sequence.mha.MultiheadAttention", 33 | "s4d": "src.models.sequence.ssm.s4d.S4D", 34 | "s4_simple": "src.models.sequence.ssm.s4_simple.SimpleS4Wrapper", 35 | "long-conv": "src.models.sequence.long_conv.LongConv", 36 | "h3": "src.models.sequence.h3.H3", 37 | "h3-conv": "src.models.sequence.h3_conv.H3Conv", 38 | "hyena": "src.models.sequence.hyena.HyenaOperator", 39 | "hyena-flashfft": "src.models.sequence.hyena-flashfft.FlashHyenaOperator", 40 | "hyena-filter": "src.models.sequence.hyena.HyenaFilter", 41 | "vit": "src.models.sequence.mha.VitAttention", 42 | } 43 | 44 | callbacks = { 45 | "timer": "src.callbacks.timer.Timer", 46 | "params": "src.callbacks.params.ParamsLog", 47 | "learning_rate_monitor": "pytorch_lightning.callbacks.LearningRateMonitor", 48 | "model_checkpoint": "pytorch_lightning.callbacks.ModelCheckpoint", 49 | "early_stopping": "pytorch_lightning.callbacks.EarlyStopping", 50 | "swa": "pytorch_lightning.callbacks.StochasticWeightAveraging", 51 | "rich_model_summary": "pytorch_lightning.callbacks.RichModelSummary", 52 | "rich_progress_bar": "pytorch_lightning.callbacks.RichProgressBar", 53 | "progressive_resizing": "src.callbacks.progressive_resizing.ProgressiveResizing", 54 | } 55 | -------------------------------------------------------------------------------- /examples/long-convs/README.md: -------------------------------------------------------------------------------- 1 | # Long Convs 2 | 3 | This folder shows an example of adapting a Long Conv backbone, as presented in [Simple Hardware-Efficient Long Convolutions for Sequence Modeling](https://arxiv.org/abs/2302.06646), to use FlashFFTConv. 4 | These files are sourced from the standalone example in the main repo. 5 | 6 | ## Changes to Use FlashFFTConv in Long Convs 7 | 8 | We describe the changes necessary to use FlashFFTConv in M2-BERT: 9 | 10 | Create an instance of `FlashFFTConv` in `LongConvModel`. In [flashfftconv_long_convs.py](flashfftconv_long_convs.py), lines 113-122: 11 | ```Python 12 | self.flashfftconv = FlashFFTConv(2048) 13 | 14 | ... 15 | 16 | for _ in self.layers: 17 | layer = LongConv(d_model, L=1024, dropout=dropout, **conv_kwargs) 18 | layer.flashfftconv = self.flashfftconv 19 | self.conv_layers.append(layer) 20 | ``` 21 | 22 | Note that we set FFT size to twice the sequence length, to create a bidirectional convolution. 
-------------------------------------------------------------------------------- /flashfftconv/__init__.py: -------------------------------------------------------------------------------- 1 | from .conv import FlashFFTConv 2 | from .depthwise_1d import FlashDepthWiseConv1d -------------------------------------------------------------------------------- /flashfftconv/depthwise_1d.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023, Dan Fu and Hermann Kumbong. 2 | import torch 3 | import math 4 | from monarch_cuda import conv1d_forward, conv1d_backward 5 | from einops import rearrange 6 | 7 | class conv1dFunc(torch.autograd.Function): 8 | @staticmethod 9 | def forward(ctx, input, weights, bias, padding, is_bhl=True): 10 | outputs = conv1d_forward(input, weights, bias, padding, is_bhl) 11 | ctx.padding = padding 12 | ctx.is_bhl = is_bhl 13 | ctx.save_for_backward(input, weights, bias) 14 | return outputs 15 | 16 | @staticmethod 17 | def backward(ctx, dout): 18 | input, weight, bias = ctx.saved_tensors 19 | dout = dout.contiguous() 20 | du, dk, dbias = conv1d_backward(dout, input, weight, bias, ctx.padding, ctx.is_bhl) 21 | return du, dk, dbias, None, None 22 | 23 | #TODO: initialization 24 | class FlashDepthWiseConv1d(torch.nn.Module): 25 | def __init__(self, channels, kernel_size, padding, weights, bias, is_bhl=True, device=None, dtype=None): 26 | factory_kwargs = {'device': device, 'dtype': dtype} 27 | super(FlashDepthWiseConv1d, self).__init__() 28 | self.d = channels 29 | self.k = kernel_size 30 | self.padding = padding 31 | self.is_bhl = is_bhl 32 | if is_bhl: 33 | self.weights = torch.nn.Parameter(weights.squeeze()) 34 | else: 35 | self.weights = torch.nn.Parameter(rearrange(weights.squeeze(), 'd k -> k d').detach().clone().contiguous()) 36 | self.bias = torch.nn.Parameter(bias.detach().clone().contiguous()) 37 | self.reset_parameters(weights, bias) 38 | 39 | #TODO: initialization 40 | def reset_parameters(self, weights, bias): 41 | pass 42 | # stdv = 1.0 / math.sqrt(self.state_size) 43 | # for weight in self.parameters(): 44 | # weight.data.uniform_(-stdv, +stdv) 45 | 46 | #current format for the weights is transpose of what is used in nn.Conv1d 47 | #[HK]: load the weights for the conv1d layer and then transpose them 48 | def load_state_dict(self, state_dict, strict: bool = True): 49 | pass 50 | 51 | #[HK]: transpose the weights before saving so that they can be loaded in nn.Conv1d 52 | def save_state_dict(self): 53 | pass 54 | 55 | def forward(self, input): 56 | return conv1dFunc.apply(input, self.weights, self.bias, self.padding, self.is_bhl) -------------------------------------------------------------------------------- /flashfftconv/sparse_conv.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023, Dan Fu and Hermann Kumbong. 2 | import torch 3 | ''' 4 | Example implementations of partial and frequency-sparse convolutions. 5 | These are just PyTorch examples, not optimized versions. 
6 | ''' 7 | 8 | class PartialFFTConv(torch.nn.Module): 9 | def __init__(self, N_partial): 10 | super().__init__() 11 | self.N_partial = N_partial 12 | 13 | def forward(self, x, k): 14 | L = x.shape[-1] 15 | N = 2 * L 16 | x_dtype = x.dtype 17 | x_f = torch.fft.rfft(x.float(), n = N) 18 | k_f = torch.fft.rfft(k[..., :self.N_partial], n = N) 19 | y_f = x_f * k_f 20 | y = torch.fft.irfft(y_f, n = N)[..., :L].to(x_dtype) 21 | 22 | return y 23 | 24 | class FrequencySparseFFTConv(torch.nn.Module): 25 | def __init__(self, N_partial): 26 | super().__init__() 27 | self.N_partial = N_partial 28 | 29 | def forward(self, x, k): 30 | L = x.shape[-1] 31 | N = 2 * L 32 | x_dtype = x.dtype 33 | x_f = torch.fft.rfft(x.float(), n = N) 34 | k_f = torch.fft.rfft(k, n = N) 35 | k_f[..., self.N_partial // 2:] = 0 36 | y_f = x_f * k_f 37 | y = torch.fft.irfft(y_f, n = N)[..., :L].to(x_dtype) 38 | 39 | return y -------------------------------------------------------------------------------- /rand.py: -------------------------------------------------------------------------------- 1 | from flashfftconv import FlashDepthWiseConv1d 2 | import torch.nn as nn 3 | import torch 4 | 5 | conv1d_torch = nn.Conv1d( 6 | in_channels=512*3, 7 | out_channels=512*3, 8 | kernel_size=3, 9 | groups=512*3, 10 | padding=2, 11 | dtype=torch.float32 12 | ).cuda() 13 | 14 | flash_conv1d = FlashDepthWiseConv1d( 15 | channels=512*3, 16 | kernel_size=3, 17 | padding=1, 18 | weights=conv1d_torch.weight, 19 | bias=conv1d_torch.bias, 20 | dtype=torch.float32 21 | ).cuda() 22 | 23 | x = torch.rand(1, 1536, 2048, requires_grad=True).cuda() 24 | y = torch.rand(1, 1536, 2048, requires_grad=True).cuda() 25 | 26 | out_torch = conv1d_torch(x) 27 | out_flash = flash_conv1d(x) 28 | 29 | criterion = nn.MSELoss().cuda() 30 | optimizer = torch.optim.AdamW(flash_conv1d.parameters()) 31 | scaler = torch.cuda.amp.GradScaler() 32 | 33 | with torch.autocast(device_type='cuda', dtype=torch.float16): 34 | optimizer.zero_grad() 35 | logits = flash_conv1d(x) 36 | loss = criterion(logits, y) 37 | 38 | scaler.scale(loss).backward() 39 | scaler.step(optimizer) 40 | scaler.update() -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | 3 | if __name__ == "__main__": 4 | setup(name='flashfftconv', 5 | version='0.0.0', 6 | description='Fast FFT algorithms for convolutions', 7 | url='https://github.com/HazyResearch/flash-fft-conv', 8 | author='Dan Fu, Hermann Kumbong', 9 | author_email='danfu@cs.stanford.edu', 10 | license='Apache 2.0', 11 | packages=['flashfftconv']) --------------------------------------------------------------------------------
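As a closing note, here is a small usage sketch for the example modules defined in `flashfftconv/sparse_conv.py` above. The batch/head/length sizes and the choice of `N_partial` are assumptions for illustration, and importing the package assumes the CUDA extension has been built:

```Python
# Illustrative usage of the PyTorch example modules in flashfftconv/sparse_conv.py.
import torch
from flashfftconv.sparse_conv import PartialFFTConv, FrequencySparseFFTConv

B, H, L = 2, 8, 1024
x = torch.randn(B, H, L)
k = torch.randn(H, L)

# With N_partial = L, PartialFFTConv uses the full kernel, i.e. a standard
# zero-padded FFT convolution.
full = PartialFFTConv(N_partial=L)(x, k)

# Smaller N_partial truncates the kernel in the time domain...
partial = PartialFFTConv(N_partial=128)(x, k)
# ...while FrequencySparseFFTConv instead keeps only the lowest
# N_partial // 2 frequency bins of the kernel.
freq_sparse = FrequencySparseFFTConv(N_partial=128)(x, k)

print(full.shape, partial.shape, freq_sparse.shape)  # each (2, 8, 1024)
```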