├── .github └── workflows │ ├── ci.yml │ ├── ci_gpu.yml │ └── ci_tests.yml ├── .gitignore ├── .gitmodules ├── CMakeLists.txt ├── LICENSE ├── Makefile ├── README.md ├── dev ├── cpu │ └── matmul_forward.c ├── cuda │ ├── Makefile │ ├── README.md │ ├── adamw.cu │ ├── attention_backward.cu │ ├── attention_forward.cu │ ├── benchmark_on_modal.py │ ├── classifier_fused.cu │ ├── common.h │ ├── crossentropy_forward.cu │ ├── crossentropy_softmax_backward.cu │ ├── encoder_backward.cu │ ├── encoder_forward.cu │ ├── fused_residual_forward.cu │ ├── gelu_backward.cu │ ├── gelu_forward.cu │ ├── global_norm.cu │ ├── layernorm_backward.cu │ ├── layernorm_forward.cu │ ├── matmul_backward.cu │ ├── matmul_backward_bias.cu │ ├── matmul_forward.cu │ ├── nccl_all_reduce.cu │ ├── permute.cu │ ├── residual_forward.cu │ ├── softmax_forward.cu │ └── trimat_forward.cu ├── data │ ├── README.md │ ├── data_common.py │ ├── edu_fineweb.sh │ ├── fineweb.py │ ├── fineweb.sh │ ├── hellaswag.py │ ├── mmlu.py │ ├── tinyshakespeare.py │ └── tinystories.py ├── download_starter_pack.sh ├── eval │ ├── README.md │ ├── export_hf.py │ ├── run_eval.sh │ └── summarize_eval.py ├── loss_checker_ci.py ├── test │ ├── Makefile │ ├── device_file_io.cu │ ├── test_dataloader.c │ └── test_outlier_detector.c ├── unistd.h └── vislog.ipynb ├── doc └── layernorm │ ├── layernorm.c │ ├── layernorm.md │ └── layernorm.py ├── llmc ├── CMakeLists.txt ├── adamw.cuh ├── attention.cuh ├── cublas_common.h ├── cuda_common.h ├── cuda_utils.cuh ├── cudnn_att.cpp ├── cudnn_att.h ├── dataloader.h ├── encoder.cuh ├── fused_classifier.cuh ├── gelu.cuh ├── global_norm.cuh ├── layernorm.cuh ├── logger.h ├── matmul.cuh ├── mfu.h ├── outlier_detector.h ├── rand.h ├── sampler.h ├── schedulers.h ├── tokenizer.h ├── utils.h └── zero.cuh ├── llmcpp ├── CMakeLists.txt ├── README.md ├── cuda_profile_util.hpp ├── gpt.hpp ├── gpt2.hpp ├── gpt_optim.cpp ├── gpt_optim.cu ├── gpt_test.cpp ├── gpt_test.cu ├── nn.hpp ├── nn_test.cpp ├── nn_test.cu ├── optim.hpp ├── optim_test.cpp ├── tensor_types.hpp ├── tensor_util.hpp ├── test_eigen_cpu.cpp ├── test_eigen_gpu.cu ├── test_gpt2.cpp ├── train_gpt2.cpp └── train_gpt2.cu ├── profile_gpt2.cu ├── profile_gpt2cu.py ├── requirements.txt ├── scripts ├── README.md ├── multi_node │ ├── run_gpt2_124M_fs.sbatch │ ├── run_gpt2_124M_mpi.sh │ └── run_gpt2_124M_tcp.sbatch ├── pyrun_gpt2_124M.sh ├── run_gpt2_124M.sh ├── run_gpt2_1558M.sh ├── run_gpt2_350M.sh ├── run_gpt2_774M.sh └── run_gpt3_125M.sh ├── test_gpt2.c ├── test_gpt2.cu ├── test_gpt2_fp32.cu ├── train_gpt2.c ├── train_gpt2.cu ├── train_gpt2.py ├── train_gpt2_fp32.cu └── train_llama3.py /.github/workflows/ci_gpu.yml: -------------------------------------------------------------------------------- 1 | name: GPU Builds and Tests 2 | 3 | on: 4 | create: 5 | workflow_dispatch: 6 | push: 7 | branches: 8 | - master 9 | pull_request: 10 | branches: 11 | - master 12 | 13 | jobs: 14 | build-and-test-gpu: 15 | runs-on: ubicloud-gpu-standard-1-latest 16 | 17 | steps: 18 | - name: Checkout code 19 | uses: actions/checkout@v4 20 | 21 | - name: Install OpenMP 22 | run: sudo apt-get update && sudo apt-get install -y libomp-dev 23 | 24 | - name: Install dependencies 25 | run: pip install -r requirements.txt 26 | 27 | - name: Run preprocessing 28 | run: python dev/data/tinyshakespeare.py 29 | 30 | - name: Train model 31 | run: python train_gpt2.py 32 | 33 | - name: Compile training and testing program 34 | run: make test_gpt2cu train_gpt2cu test_gpt2fp32cu train_gpt2fp32cu 35 | 36 | - name: Train model 
(With OpenMP) 37 | run: OMP_NUM_THREADS=8 ./train_gpt2cu 38 | 39 | - name: Train model (FP32) with gpt2_124M.bin 40 | run: | 41 | PRECISION=FP32 make train_gpt2cu 42 | ./train_gpt2cu -b 1 -t 64 -d 256 -l 0.0001 -v 200 -s 200 -a 1 -x 10 -r 0 -f 0 -e "gpt2_124M.bin" 43 | 44 | - name: Test for percent loss differential for FP32 45 | run: | 46 | PRECISION=FP32 make train_gpt2cu 47 | ./train_gpt2cu -b 1 -t 64 -d 256 -l 0.0001 -v 200 -s 200 -a 1 -x 10 -r 0 -f 0 -e "gpt2_124M.bin" > train_gpt2cu_fp32_precision.txt 48 | python dev/loss_checker_ci.py -f train_gpt2cu_fp32_precision.txt -s 20 -e 28 -a 5.0 49 | 50 | - name: Build FP32 precision 51 | run: PRECISION=FP32 make test_gpt2cu profile_gpt2cu 52 | 53 | - name: Run default 54 | run: ./test_gpt2cu 55 | 56 | - name: Run no recompute GeLU 57 | run: ./test_gpt2cu -r 0 58 | 59 | - name: Run recompute LN 60 | run: ./test_gpt2cu -r 2 61 | 62 | - name: Build BF16 precision 63 | run: PRECISION=BF16 make train_gpt2cu test_gpt2cu profile_gpt2cu 64 | 65 | - name: Run default 66 | run: ./test_gpt2cu 67 | 68 | - name: Run no recompute GeLU 69 | run: ./test_gpt2cu -r 0 70 | 71 | - name: Run no master weights 72 | run: ./test_gpt2cu -w 0 73 | 74 | - name: Run recompute LN 75 | run: ./test_gpt2cu -r 2 76 | 77 | - name: Train model fp32 (With OpenMP) 78 | run: OMP_NUM_THREADS=8 ./train_gpt2fp32cu 79 | 80 | - name: Execute testing program (With OpenMP) 81 | run: OMP_NUM_THREADS=8 ./test_gpt2cu 82 | 83 | - name: Execute testing program fp32 (With OpenMP) 84 | run: OMP_NUM_THREADS=8 ./test_gpt2fp32cu 85 | 86 | - name: Compile training and testing program without OpenMP 87 | run: NO_OMP=1 make test_gpt2cu train_gpt2cu test_gpt2fp32cu train_gpt2fp32cu 88 | 89 | - name: Train model (No OpenMP) 90 | run: NO_OMP=1 ./train_gpt2cu 91 | 92 | - name: Train model fp32 (No OpenMP) 93 | run: NO_OMP=1 ./train_gpt2fp32cu 94 | 95 | - name: Execute testing program (No OpenMP) 96 | run: ./test_gpt2cu -b 32 97 | 98 | - name: Execute testing program fp32 (No OpenMP) 99 | run: ./test_gpt2fp32cu 100 | 101 | - name: Install cuDNN-frontend 102 | run: 103 | git clone https://github.com/NVIDIA/cudnn-frontend.git 104 | 105 | - name: Build with cuDNN 106 | run: USE_CUDNN=1 make test_gpt2cu train_gpt2cu test_gpt2fp32cu train_gpt2fp32cu 107 | 108 | - name: Train model with cuDNN 109 | run: ./train_gpt2cu 110 | 111 | - name: Train model fp32 with cuDNN 112 | run: ./train_gpt2fp32cu 113 | 114 | - name: Execute testing program with cuDNN 115 | run: ./test_gpt2cu 116 | 117 | - name: Execute testing program fp32 with cuDNN 118 | run: ./test_gpt2fp32cu 119 | 120 | unit-tests-gpu: 121 | runs-on: ubicloud-gpu-standard-1-latest 122 | 123 | steps: 124 | - name: Checkout code 125 | uses: actions/checkout@v4 126 | 127 | - name: Test Device<->File IO 128 | run: cd dev/test && nvcc -o device_file_io device_file_io.cu && ./device_file_io 129 | -------------------------------------------------------------------------------- /.github/workflows/ci_tests.yml: -------------------------------------------------------------------------------- 1 | name: Unit, Static and other Tests 2 | 3 | on: 4 | create: 5 | workflow_dispatch: 6 | push: 7 | branches: 8 | - master 9 | pull_request: 10 | branches: 11 | - master 12 | 13 | jobs: 14 | dataloader_test: 15 | runs-on: ubuntu-latest 16 | 17 | steps: 18 | - name: Checkout code 19 | uses: actions/checkout@v4 20 | 21 | - name: test the dataloader without / with sanitize address 22 | run: | 23 | cd dev/test 24 | make PRECISION=BF16 test_dataloader 25 | ./test_dataloader 26 | 
make clean 27 | make PRECISION=BF16 TEST_CFLAGS="-fsanitize=address -fno-omit-frame-pointer" test_dataloader 28 | ./test_dataloader 29 | 30 | ptx_and_sass_files: 31 | runs-on: ubuntu-latest 32 | container: 33 | image: nvidia/cuda:12.4.1-devel-ubuntu22.04 34 | 35 | steps: 36 | - name: Checkout code 37 | uses: actions/checkout@v4 38 | 39 | - name: Install OpenMP and OpenMPI 40 | run: apt-get update && apt-get install -y libomp-dev libopenmpi-dev 41 | 42 | - name: Generate ptx/sass files and upload them to persistent storage 43 | run: | 44 | mkdir -p dev/cuda/ptx_sass_logs 45 | make train_gpt2cu 46 | cuobjdump --dump-ptx train_gpt2cu > dev/cuda/train_gpt2cu.ptx 47 | cuobjdump --dump-sass train_gpt2cu > dev/cuda/train_gpt2cu.sass 48 | cd dev/cuda 49 | make -j all_ptx 50 | make -j all_sass 51 | cp *.ptx ptx_sass_logs/ 52 | cp *.sass ptx_sass_logs/ 53 | ls ptx_sass_logs/ 54 | 55 | - name: Generate ptx/sass files for A100 and upload them to persistent storage 56 | run: | 57 | mkdir -p dev/cuda/ptx_sass_logs_A100 58 | make train_gpt2cu GPU_COMPUTE_CAPABILITY=80 59 | cuobjdump --dump-ptx train_gpt2cu > dev/cuda/train_gpt2cu.ptx 60 | cuobjdump --dump-sass train_gpt2cu > dev/cuda/train_gpt2cu.sass 61 | cd dev/cuda 62 | make -j GPU_COMPUTE_CAPABILITY=80 all_ptx 63 | make -j GPU_COMPUTE_CAPABILITY=80 all_sass 64 | cp *.ptx ptx_sass_logs_A100/ 65 | cp *.sass ptx_sass_logs_A100/ 66 | ls ptx_sass_logs_A100/ 67 | 68 | - name: Generate ptx/sass files for H100 and upload them to persistent storage 69 | run: | 70 | mkdir -p dev/cuda/ptx_sass_logs_H100 71 | make train_gpt2cu GPU_COMPUTE_CAPABILITY=90 72 | cuobjdump --dump-ptx train_gpt2cu > dev/cuda/train_gpt2cu.ptx 73 | cuobjdump --dump-sass train_gpt2cu > dev/cuda/train_gpt2cu.sass 74 | cd dev/cuda 75 | make -j GPU_COMPUTE_CAPABILITY=90 all_ptx 76 | make -j GPU_COMPUTE_CAPABILITY=90 all_sass 77 | cp *.ptx ptx_sass_logs_H100/ 78 | cp *.sass ptx_sass_logs_H100/ 79 | ls ptx_sass_logs_H100/ 80 | 81 | - name: Upload ptx/sass files 82 | uses: actions/upload-artifact@v4 83 | with: 84 | name: ptx_sass_files 85 | path: dev/cuda/ptx_sass_logs/ 86 | retention-days: 30 # days to retain 87 | 88 | - name: Upload ptx/sass files for A100 89 | uses: actions/upload-artifact@v4 90 | with: 91 | name: ptx_sass_files_A100 92 | path: dev/cuda/ptx_sass_logs_A100/ 93 | retention-days: 30 # days to retain 94 | 95 | - name: Upload ptx/sass files for H100 96 | uses: actions/upload-artifact@v4 97 | with: 98 | name: ptx_sass_files_H100 99 | path: dev/cuda/ptx_sass_logs_H100/ 100 | retention-days: 30 # days to retain -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # dot files and such 2 | .vscode 3 | .venv 4 | 5 | # .bin files generated by Python 6 | *.bin 7 | 8 | # data directories 9 | dev/data/__pycache__/ 10 | dev/data/fineweb10B/ 11 | dev/data/hellaswag/ 12 | dev/data/mmlu/ 13 | dev/data/tinyshakespeare/ 14 | dev/data/tinystories/ 15 | 16 | # binaries 17 | test_gpt2 18 | test_gpt2cu 19 | test_gpt2fp32cu 20 | train_gpt2 21 | train_gpt2cu 22 | train_gpt2fp32cu 23 | profile_gpt2cu 24 | dev/cuda/*_forward 25 | dev/cuda/*_backward 26 | dev/cuda/classifier_fused 27 | dev/cuda/adamw 28 | dev/cuda/matmul_backward_bias 29 | dev/cuda/nccl_all_reduce 30 | dev/cuda/global_norm 31 | *.obj 32 | *.exe 33 | *.o 34 | 35 | # log files 36 | *.log 37 | 38 | # clion files 39 | .idea 40 | cmake-build-* 41 | build 42 | 
-------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "third_party/abseil-cpp"] 2 | path = third_party/abseil-cpp 3 | url = https://github.com/abseil/abseil-cpp.git 4 | [submodule "third_party/eigen"] 5 | path = third_party/eigen 6 | url = https://gitlab.com/libeigen/eigen.git 7 | [submodule "third_party/googletest"] 8 | path = third_party/googletest 9 | url = https://github.com/google/googletest.git 10 | -------------------------------------------------------------------------------- /CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.14) 2 | project(llm.cpp LANGUAGES C CXX CUDA) 3 | 4 | set(CMAKE_CXX_STANDARD 17) 5 | set(CMAKE_CUDA_STANDARD 17) 6 | set(BUILD_SHARED_LIBS OFF) 7 | # add_compile_options(-Ofast -march=native) 8 | # set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -Ofast -march=native") 9 | # set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Ofast -march=native") 10 | 11 | find_program(CCACHE_PROGRAM ccache) 12 | if (CCACHE_PROGRAM) 13 | set(CMAKE_C_COMPILER_LAUNCHER ccache) 14 | set(CMAKE_CXX_COMPILER_LAUNCHER ccache) 15 | set(CMAKE_CUDA_COMPILER_LAUNCHER ccache) 16 | endif () 17 | 18 | enable_testing() 19 | include_directories(.) 20 | 21 | # Abseil 22 | set(ABSL_PROPAGATE_CXX_STD ON) 23 | add_subdirectory(third_party/abseil-cpp) 24 | 25 | # GoogleTest 26 | add_subdirectory(third_party/googletest) 27 | 28 | # Eigen 29 | set(EIGEN3_INCLUDE_DIR third_party/eigen) 30 | add_definitions(-DEIGEN_DONT_PARALLELIZE) 31 | #add_definitions(-DEIGEN_DONT_VECTORIZE) 32 | add_definitions(-DEIGEN_USE_THREADS) 33 | include_directories(${EIGEN3_INCLUDE_DIR}) 34 | 35 | add_subdirectory(llmc) 36 | add_subdirectory(llmcpp) 37 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Andrej Karpathy 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # llm.cpp 2 | 项目 fork 自 karpathy 的 [llm.c](https://github.com/karpathy/llm.c),使用 C++(with Eigen) 来复现 GPT-2,支持 CPU/CUDA 计算。 3 | - 所有的计算部分都通过 Eigen Tensor 完成,所以同样一份代码通过简单地切换 Device 就可完成 CPU/CUDA 的计算 4 | - 这里实现的 GPT-2 与 PyTorch 版本是完全对齐的 5 | - 值得注意的是,CPU 版本比 PyTorch 快大约 20%,但是 GPU 版本比 PyTorch GPU 慢得多,主要原因是 Eigen 的 Tensor 不支持 BatchMatmul 6 | 7 | 8 | This repo is forked from karpathy's [llm.c](https://github.com/karpathy/llm.c), using C++ (with Eigen) to reproduce GPT-2. 9 | 10 | - All calculations are done through the Eigen Tensor Module, so the same code can be used for CPU/CUDA calculations by simply switching the Device. 11 | - Currently, this repo has reproduced GPT-2 and the results are completely aligned with the PyTorch version. 12 | - It is worth noting that CPU calculations are about 20% faster than PyTorch, while GPU calculations are still far behind PyTorch on the GPU, mainly because the Eigen Tensor Module does not support BatchMatmul. 13 | 14 | ## quick start (CPU) 15 | 16 | ```bash 17 | pip install -r requirements.txt 18 | python dev/data/tinyshakespeare.py 19 | python train_gpt2.py 20 | mkdir build && cd build 21 | cmake .. 22 | make train_gpt2_cpu 23 | cd ../ 24 | ./build/llmcpp/train_gpt2_cpu 25 | ``` 26 | 27 | The above lines 28 | - (1) download the [tinyshakespeare](https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt) dataset, 29 | tokenize it with the GPT-2 Tokenizer 30 | - (2) download and save the GPT-2 (124M) weights 31 | - (3) init from them in C++ and train for 40 steps on tinyshakespeare with AdamW (using batch size 4, context length only 64), evaluate validation loss, and sample some text. The output looks like this on my LMDE3 (Intel® Core™ i7-10700K CPU @ 3.80GHz × 8): 32 | 33 | ``` 34 | [GPT-2] 35 | max_seq_len: 1024 36 | vocab_size: 50257 37 | padded_vocab_size: 50304 38 | num_layers: 12 39 | num_heads: 12 40 | channels: 768 41 | num_parameters: 124475904(474 MB) 42 | train dataset num_batches: 1192 43 | val dataset num_batches: 128 44 | num_activations: 82723584(315 MB) 45 | val loss 5.325413 46 | step 0: train loss 5.356086 (took 786.515755 ms) 47 | step 1: train loss 4.300581 (took 677.340087 ms) 48 | step 2: train loss 4.623053 (took 674.843167 ms) 49 | step 3: train loss 4.599307 (took 673.189660 ms) 50 | ... (truncated) ... 51 | step 39: train loss 3.972404 (took 749.386021 ms) 52 | val loss 4.017484 53 | generating: 54 | --- 55 | Requinetarius, 56 | Which; supreme, but 57 | Commands jest in vain for ever. 58 | 59 | <|endoftext|>Lady: 60 | No, heavens, 61 | I were not to haste 62 | To retire valorously and look nobly in the face, 63 | Before this 64 | UNHISILIUS UNDERDEINTS 65 | 66 | --- 67 | step 40: train loss 4.378605 (took 692.830391 ms) 68 | final 40 iters avg: 692.974 ms 69 | ``` 70 | 71 | ## quick start (1 GPU, fp32 only) 72 | ```bash 73 | mkdir build && cd build 74 | cmake .. 75 | make train_gpt2_gpu 76 | cd ../ 77 | ./build/llmcpp/train_gpt2_gpu 78 | ``` 79 | 80 | 81 | ## datasets 82 | 83 | The data files inside `/dev/data/(dataset).py` are responsible for downloading, tokenizing and saving the tokens to .bin files, easily readable from C. 
So for example when you run: 84 | 85 | ```bash 86 | python dev/data/tinyshakespeare.py 87 | ``` 88 | 89 | We download and tokenize the [tinyshakespeare](https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt) dataset. The output of this looks like this: 90 | 91 | ``` 92 | writing 32,768 tokens to ./dev/data/tinyshakespeare/tiny_shakespeare_val.bin 93 | writing 305,260 tokens to ./dev/data/tinyshakespeare/tiny_shakespeare_train.bin 94 | ``` 95 | 96 | The .bin files contain a short header (1024 bytes) and then a stream of tokens in uint16, indicating the token ids with the GPT-2 tokenizer. More datasets are available in `/dev/data`. 97 | 98 | ## test 99 | 100 | I am also attaching a simple unit test for making sure our C++ code agrees with the PyTorch code. On the CPU as an example, compile and run with: 101 | 102 | ```bash 103 | mkdir build && cd build 104 | cmake .. 105 | make test_gpt2_cpu 106 | cd ../ 107 | ./build/llmcpp/test_gpt2_cpu 108 | ``` 109 | 110 | This now loads the `gpt2_124M_debug_state.bin` file that gets written by train_gpt2.py, runs a forward pass, compares the logits and loss with the PyTorch reference implementation, then it does 10 iterations of training with Adam and makes sure the losses match PyTorch. 111 | This tests both the fp32 path and the mixed precision path. The test should pass and print `overall okay: 1`. 112 | 113 | 114 | ## license 115 | 116 | MIT 117 | -------------------------------------------------------------------------------- /dev/cuda/Makefile: -------------------------------------------------------------------------------- 1 | # Makefile for building dev/cuda kernels 2 | # Collects all the make commands in one file but each file also 3 | # has the compile and run commands in the header comments section. 4 | 5 | # Find nvcc (NVIDIA CUDA compiler) 6 | NVCC := $(shell which nvcc 2>/dev/null) 7 | ifeq ($(NVCC),) 8 | $(error nvcc not found.) 
9 | endif 10 | 11 | ifneq ($(CI),true) # if not in CI, then use the GPU query 12 | ifndef GPU_COMPUTE_CAPABILITY # set to defaults if: make GPU_COMPUTE_CAPABILITY= 13 | GPU_COMPUTE_CAPABILITY = $(shell __nvcc_device_query) # assume if NVCC is present, then this likely is too 14 | GPU_COMPUTE_CAPABILITY := $(strip $(GPU_COMPUTE_CAPABILITY)) 15 | endif 16 | endif 17 | 18 | # Compiler flags 19 | ifeq ($(GPU_COMPUTE_CAPABILITY),) # set to defaults if: make GPU_COMPUTE_CAPABILITY= 20 | CFLAGS = -O3 --use_fast_math 21 | else 22 | CFLAGS = -O3 --use_fast_math --generate-code arch=compute_$(GPU_COMPUTE_CAPABILITY),code=[compute_$(GPU_COMPUTE_CAPABILITY),sm_$(GPU_COMPUTE_CAPABILITY)] 23 | endif 24 | 25 | NVCCFLAGS = -lcublas -lcublasLt -std=c++17 26 | MPI_PATHS = -I/usr/lib/x86_64-linux-gnu/openmpi/include -L/usr/lib/x86_64-linux-gnu/openmpi/lib/ 27 | 28 | # Default rule for our CUDA files 29 | %: %.cu 30 | $(NVCC) $(CFLAGS) $(NVCCFLAGS) $< -o $@ 31 | 32 | # Build all targets 33 | TARGETS = adamw attention_backward attention_forward classifier_fused crossentropy_forward crossentropy_softmax_backward encoder_backward encoder_forward gelu_backward gelu_forward layernorm_backward layernorm_forward matmul_backward matmul_backward_bias matmul_forward nccl_all_reduce residual_forward softmax_forward trimat_forward fused_residual_forward global_norm permute 34 | 35 | all: $(TARGETS) 36 | all_ptx: $(TARGETS:%=%.ptx) 37 | all_sass: $(TARGETS:%=%.sass) 38 | 39 | # Individual targets: forward pass 40 | attention_forward: attention_forward.cu 41 | classifier_fused: classifier_fused.cu 42 | crossentropy_forward: crossentropy_forward.cu 43 | encoder_forward: encoder_forward.cu 44 | gelu_forward: gelu_forward.cu 45 | layernorm_forward: layernorm_forward.cu 46 | fused_residual_forward: fused_residual_forward.cu 47 | residual_forward: residual_forward.cu 48 | softmax_forward: softmax_forward.cu 49 | trimat_forward: trimat_forward.cu 50 | # matmul fwd/bwd also uses OpenMP (optionally) and cuBLASLt libs 51 | matmul_forward: matmul_forward.cu 52 | $(NVCC) $(CFLAGS) $(NVCCFLAGS) -Xcompiler -fopenmp matmul_forward.cu -o matmul_forward 53 | 54 | # Individual targets: backward pass 55 | attention_backward: attention_backward.cu 56 | crossentropy_softmax_backward: crossentropy_softmax_backward.cu 57 | encoder_backward: encoder_backward.cu 58 | gelu_backward: gelu_backward.cu 59 | layernorm_backward: layernorm_backward.cu 60 | matmul_backward_bias: matmul_backward_bias.cu 61 | matmul_backward: matmul_backward.cu 62 | $(NVCC) $(CFLAGS) $(NVCCFLAGS) -Xcompiler -fopenmp matmul_backward.cu -o matmul_backward 63 | 64 | # Update kernels 65 | adamw: adamw.cu 66 | global_norm: global_norm.cu 67 | 68 | permute: permute.cu 69 | 70 | # NCCL communication kernels 71 | nccl_all_reduce: nccl_all_reduce.cu 72 | $(NVCC) -lmpi -lnccl $(NVCCFLAGS) $(MPI_PATHS) nccl_all_reduce.cu -o nccl_all_reduce 73 | 74 | # Generate PTX using cuobjdump 75 | %.ptx: % 76 | cuobjdump --dump-ptx $< > $@ 77 | 78 | # Generate SASS using cuobjdump 79 | %.sass: % 80 | cuobjdump --dump-sass $< > $@ 81 | 82 | # Run all targets 83 | run_all: all 84 | @for target in $(TARGETS); do \ 85 | echo "\n========================================"; \ 86 | echo "Running $$target ..."; \ 87 | echo "========================================\n"; \ 88 | ./$$target; \ 89 | done 90 | 91 | # Clean up 92 | clean: 93 | rm -f $(TARGETS) *.ptx *.sass 94 | -------------------------------------------------------------------------------- /dev/cuda/README.md: 
-------------------------------------------------------------------------------- 1 | # dev/cuda 2 | 3 | This directory is scratch space for developing various versions of the needed CUDA kernels. Each file develops a kernel, and usually multiple versions of that kernel that differ in running time and in code or time complexity. 4 | 5 | See the top of each file for how to compile and run the kernel. Alternatively, the commands are also all grouped in the `Makefile` in this directory for convenience. 6 | 7 | For example, we can look at the top of `layernorm_forward.cu` to build the forward pass kernels for the LayerNorm: 8 | 9 | ```bash 10 | nvcc -O3 --use_fast_math -lcublas -lcublasLt layernorm_forward.cu -o layernorm_forward 11 | ``` 12 | 13 | or simply 14 | 15 | ```bash 16 | make layernorm_forward 17 | ``` 18 | 19 | The comments at the top then document the different versions of this kernel available; these are usually ordered by increasing complexity and decreasing running time. For example, after inspecting the comments at the top of the file, we can run the most naive kernel as: 20 | 21 | ```bash 22 | ./layernorm_forward 1 23 | ``` 24 | 25 | You'll see that this first forwards the reference code on the CPU, then it runs kernel 1 on the GPU, compares the results to check for correctness, and then runs a number of configurations of this kernel (most often and most notably the block size) to time the kernel in these launch configurations. We can then run one of the faster kernels (kernel 4) instead: 26 | 27 | ```bash 28 | ./layernorm_forward 4 29 | ``` 30 | 31 | You'll see that this matches all the CPU results but runs much, much faster. The typical process from here on is to copy-paste the kernel that ran fastest, adjust it manually (e.g. to hardcode the best block size) and drop it into the training code file, e.g. `train_gpt2.cu`. 32 | 33 | To add a new version of a kernel, add the kernel to the corresponding file and adjust the docs. To add a new kernel, add the new file and adjust the Makefile. Run `make clean` to clean up binaries from your directory. 34 | 35 | If you do not have a GPU or are having trouble with CUDA dependencies, you can run the benchmarks on the [Modal platform](http://modal.com). For example, to run the benchmark for the attention forward pass on an A100 GPU with 80GB of memory, you can run the following command: 36 | 37 | ```bash 38 | GPU_MEM=80 modal run benchmark_on_modal.py --compile-command "nvcc -O3 --use_fast_math attention_forward.cu -o attention_forward -lcublas" --run-command "./attention_forward 1" 39 | ``` 40 | -------------------------------------------------------------------------------- /dev/cuda/benchmark_on_modal.py: -------------------------------------------------------------------------------- 1 | """ 2 | Script for running benchmarks on the Modal platform. 3 | This is useful for folks who do not have access to expensive GPUs locally. 4 | Example usage for cuda kernels: 5 | GPU_MEM=80 modal run benchmark_on_modal.py \ 6 | --compile-command "nvcc -O3 --use_fast_math attention_forward.cu -o attention_forward -lcublas" \ 7 | --run-command "./attention_forward 1" 8 | OR if you want to use cuDNN etc. 
9 | 10 | 11 | For training the gpt2 model with cuDNN use: 12 | GPU_MEM=80 modal run dev/cuda/benchmark_on_modal.py \ 13 | --compile-command "make train_gpt2cu USE_CUDNN=1" 14 | --run-command "./train_gpt2cu -i dev/data/tinyshakespeare/tiny_shakespeare_train.bin -j dev/data/tinyshakespeare/tiny_shakespeare_val.bin -v 250 -s 250 -g 144 -f shakespeare.log -b 4" 15 | 16 | 17 | For profiling using nsight system: 18 | GPU_MEM=80 modal run dev/cuda/benchmark_on_modal.py \ 19 | --compile-command "make train_gpt2cu USE_CUDNN=1" \ 20 | --run-command "nsys profile --cuda-graph-trace=graph --python-backtrace=cuda --cuda-memory-usage=true \ 21 | ./train_gpt2cu -i dev/data/tinyshakespeare/tiny_shakespeare_train.bin \ 22 | -j dev/data/tinyshakespeare/tiny_shakespeare_val.bin -v 250 -s 250 -g 144 -f shakespeare.log -b 4" 23 | 24 | For more nsys profiling specifics and command options, take a look at: https://docs.nvidia.com/nsight-systems/2024.2/UserGuide/ 25 | -> To profile the report using a GUI, download NVIDIA NSight System GUI version (this software can run on all OS, so you download it locally) 26 | 27 | NOTE: Currently there is a bug in the profiling using nsight system which produces a unrecognized GPU UUId error on the command line but it 28 | does not actually interfere with the model training and validation. The report (that you download) is still generated and can be viewed from Nsight Systems 29 | """ 30 | import subprocess 31 | import os 32 | import sys 33 | import datetime 34 | 35 | import modal 36 | from modal import Image, Stub 37 | GPU_NAME_TO_MODAL_CLASS_MAP = { 38 | "H100": modal.gpu.H100, 39 | "A100": modal.gpu.A100, 40 | "A10G": modal.gpu.A10G, 41 | } 42 | N_GPUS = int(os.environ.get("N_GPUS", 1)) 43 | GPU_MEM = int(os.environ.get("GPU_MEM", 40)) 44 | GPU_NAME = os.environ.get("GPU_NAME", "A100") 45 | GPU_CONFIG = GPU_NAME_TO_MODAL_CLASS_MAP[GPU_NAME](count=N_GPUS, size=str(GPU_MEM) + 'GB') 46 | 47 | APP_NAME = "llm.c benchmark run" 48 | 49 | image = ( 50 | Image.from_registry("totallyvyom/cuda-env:latest-2") 51 | .pip_install("huggingface_hub==0.20.3", "hf-transfer==0.1.5") 52 | .env( 53 | dict( 54 | HUGGINGFACE_HUB_CACHE="/pretrained", 55 | HF_HUB_ENABLE_HF_TRANSFER="1", 56 | TQDM_DISABLE="true", 57 | ) 58 | ) 59 | .run_commands( 60 | "wget -q https://github.com/Kitware/CMake/releases/download/v3.28.1/cmake-3.28.1-Linux-x86_64.sh", 61 | "bash cmake-3.28.1-Linux-x86_64.sh --skip-license --prefix=/usr/local", 62 | "rm cmake-3.28.1-Linux-x86_64.sh", 63 | "ln -s /usr/local/bin/cmake /usr/bin/cmake",) 64 | .run_commands( 65 | "apt-get install -y --allow-change-held-packages libcudnn8 libcudnn8-dev", 66 | "apt-get install -y openmpi-bin openmpi-doc libopenmpi-dev kmod sudo", 67 | "git clone https://github.com/NVIDIA/cudnn-frontend.git /root/cudnn-frontend", 68 | "cd /root/cudnn-frontend && mkdir build && cd build && cmake .. 
&& make" 69 | ) 70 | .run_commands( 71 | "wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/cuda-ubuntu2204.pin && \ 72 | mv cuda-ubuntu2204.pin /etc/apt/preferences.d/cuda-repository-pin-600 && \ 73 | apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/3bf863cc.pub && \ 74 | add-apt-repository \"deb https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/ /\" && \ 75 | apt-get update" 76 | ).run_commands( 77 | "apt-get install -y nsight-systems-2023.3.3" 78 | ) 79 | ) 80 | 81 | stub = modal.App(APP_NAME) 82 | 83 | def execute_command(command: str): 84 | command_args = command.split(" ") 85 | print(f"{command_args = }") 86 | subprocess.run(command_args, stdout=sys.stdout, stderr=subprocess.STDOUT) 87 | 88 | @stub.function( 89 | gpu=GPU_CONFIG, 90 | image=image, 91 | allow_concurrent_inputs=4, 92 | container_idle_timeout=900, 93 | mounts=[modal.Mount.from_local_dir("./", remote_path="/root/")], 94 | # Instead of 'cuda-env' put your volume name that you create from 'modal volume create {volume-name}' 95 | # This enables the profiling reports to be saved on the volume that you can download by using: 96 | # 'modal volume get {volume-name} {/output_file_name} 97 | # For example right now, when profiling using this command "nsys profile --trace=cuda,nvtx --cuda-graph-trace=graph --python-backtrace=cuda --cuda-memory-usage=true" you would get your report 98 | # using in a directory in your volume, where the name contains the timestamp unique id. 99 | # This script will generate a "report1_{timestamp} folder in volume" 100 | # and you can download it with 'modal volume get {volume-name} report1_{timestamp} 101 | volumes={"/cuda-env": modal.Volume.from_name("cuda-env")}, 102 | ) 103 | def run_benchmark(compile_command: str, run_command: str): 104 | execute_command("pwd") 105 | execute_command("ls") 106 | execute_command(compile_command) 107 | execute_command(run_command) 108 | # Use this section if you want to profile using nsight system and install the reports on your volume to be locally downloaded 109 | timestamp = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S") 110 | 111 | execute_command("mkdir report1_" + timestamp) 112 | execute_command("mv /root/report1.nsys-rep /root/report1_" + timestamp + "/") 113 | execute_command("mv /root/report1.qdstrm /root/report1_" + timestamp + "/") 114 | execute_command("mv /root/report1_" + timestamp + "/" + " /cuda-env/") 115 | 116 | return None 117 | 118 | @stub.local_entrypoint() 119 | def inference_main(compile_command: str, run_command: str): 120 | results = run_benchmark.remote(compile_command, run_command) 121 | return results -------------------------------------------------------------------------------- /dev/cuda/crossentropy_forward.cu: -------------------------------------------------------------------------------- 1 | /* 2 | Kernels for crossentropy forward pass. 
3 | 4 | Compile example: 5 | nvcc -O3 --use_fast_math -lcublas -lcublasLt crossentropy_forward.cu -o crossentropy_forward 6 | 7 | version 1 is a straight-forward port from CPU code to kernel, parallel over B,T 8 | ./crossentropy_forward 1 9 | */ 10 | 11 | #include 12 | #include 13 | #include 14 | #include "common.h" 15 | 16 | // ---------------------------------------------------------------------------- 17 | // CPU code reference 18 | 19 | void crossentropy_forward_cpu(float* losses, 20 | const float* probs, const int* targets, 21 | int B, int T, int V) { 22 | // output: losses is (B,T) of the individual losses at each position 23 | // input: probs are (B,T,V) of the probabilities 24 | // input: targets is (B,T) of integers giving the correct index in logits 25 | for (int b = 0; b < B; b++) { 26 | for (int t = 0; t < T; t++) { 27 | // loss = -log(probs[target]) 28 | const float* probs_bt = probs + b * T * V + t * V; 29 | int ix = targets[b * T + t]; 30 | losses[b * T + t] = -logf(probs_bt[ix]); 31 | } 32 | } 33 | } 34 | 35 | // ---------------------------------------------------------------------------- 36 | // GPU kernels 37 | 38 | __global__ void crossentropy_forward_kernel1(float* losses, 39 | const float* probs, const int* targets, 40 | int B, int T, int V) { 41 | int i = blockIdx.x * blockDim.x + threadIdx.x; 42 | if (i < B * T) { 43 | int b = i / T; 44 | int t = i % T; 45 | const float* probs_bt = probs + b * T * V + t * V; 46 | int ix = targets[b * T + t]; 47 | losses[b * T + t] = -logf(probs_bt[ix]); 48 | } 49 | } 50 | 51 | // ---------------------------------------------------------------------------- 52 | // kernel launcher 53 | 54 | void crossentropy_forward1(float* losses, 55 | const float* probs, const int* targets, 56 | int B, int T, int V, 57 | const int block_size) { 58 | const int N = B * T; 59 | const int grid_size = ceil_div(N, block_size); 60 | crossentropy_forward_kernel1<<>>(losses, probs, targets, B, T, V); 61 | cudaCheck(cudaGetLastError()); 62 | } 63 | 64 | // kernel version dispatch 65 | void crossentropy_forward(int kernel_num, 66 | float* losses, 67 | const float* probs, const int* targets, 68 | int B, int T, int V, 69 | const int block_size) { 70 | switch (kernel_num) { 71 | case 1: 72 | crossentropy_forward1(losses, probs, targets, B, T, V, block_size); 73 | break; 74 | default: 75 | printf("Invalid kernel number\n"); 76 | exit(1); 77 | } 78 | } 79 | 80 | // ---------------------------------------------------------------------------- 81 | 82 | int main(int argc, char **argv) { 83 | srand(0); 84 | 85 | int B = 8; 86 | int T = 1024; 87 | int V = 50257; 88 | 89 | int deviceIdx = 0; 90 | cudaCheck(cudaSetDevice(deviceIdx)); 91 | 92 | // create host memory of random numbers 93 | float* out = (float*)malloc(B * T * sizeof(float)); 94 | float* probs = make_random_float_01(B * T * V); 95 | int* targets = make_random_int(B * T, V); 96 | 97 | // move to GPU 98 | float* d_out; 99 | float* d_probs; 100 | int* d_targets; 101 | cudaCheck(cudaMalloc(&d_out, B * T * sizeof(float))); 102 | cudaCheck(cudaMalloc(&d_probs, B * T * V * sizeof(float))); 103 | cudaCheck(cudaMalloc(&d_targets, B * T * sizeof(int))); 104 | cudaCheck(cudaMemcpy(d_probs, probs, B * T * V * sizeof(float), cudaMemcpyHostToDevice)); 105 | cudaCheck(cudaMemcpy(d_targets, targets, B * T * sizeof(int), cudaMemcpyHostToDevice)); 106 | 107 | // read kernel_num from command line 108 | int kernel_num = 1; 109 | if (argc > 1) { 110 | kernel_num = atoi(argv[1]); 111 | } 112 | printf("Using kernel %d\n", 
kernel_num); 113 | 114 | // first check the correctness of the kernel 115 | crossentropy_forward_cpu(out, probs, targets, B, T, V); 116 | // time the kernel at different block sizes 117 | int block_sizes[] = {32, 64, 128, 256, 512, 1024}; 118 | 119 | for (int j = 0; j < sizeof(block_sizes) / sizeof(int); j++) { 120 | int block_size = block_sizes[j]; 121 | printf("Checking block size %d.\n", block_size); 122 | crossentropy_forward(kernel_num, d_out, d_probs, d_targets, B, T, V, block_size); 123 | validate_result(d_out, out, "out", B * T, 1e-5f); 124 | } 125 | 126 | printf("All results match. Starting benchmarks.\n\n"); 127 | 128 | for (int j = 0; j < sizeof(block_sizes) / sizeof(int); j++) { 129 | int block_size = block_sizes[j]; 130 | 131 | int repeat_times = 1000; 132 | float elapsed_time = benchmark_kernel(repeat_times, crossentropy_forward, 133 | kernel_num, d_out, d_probs, d_targets, 134 | B, T, V, block_size); 135 | 136 | printf("block_size %4d | time %.4f ms | per token %.2f ns\n", block_size, elapsed_time, elapsed_time * 1'000'000 / (B*T)); 137 | } 138 | 139 | // free memory 140 | free(out); 141 | free(probs); 142 | free(targets); 143 | cudaCheck(cudaFree(d_out)); 144 | cudaCheck(cudaFree(d_probs)); 145 | cudaCheck(cudaFree(d_targets)); 146 | 147 | return 0; 148 | } -------------------------------------------------------------------------------- /dev/cuda/crossentropy_softmax_backward.cu: -------------------------------------------------------------------------------- 1 | /* 2 | Kernels for crossentropy forward pass. 3 | 4 | Compile example: 5 | nvcc -O3 --use_fast_math -lcublas -lcublasLt crossentropy_softmax_backward.cu -o crossentropy_softmax_backward 6 | 7 | version 1 is a straight-forward port from CPU code to kernel, parallel over B,T 8 | ./crossentropy_softmax_backward 1 9 | */ 10 | 11 | #include 12 | #include 13 | #include 14 | #include "common.h" 15 | 16 | // ---------------------------------------------------------------------------- 17 | // CPU code reference 18 | 19 | void crossentropy_softmax_backward_cpu(float* dlogits, 20 | const float* dlosses, const float* probs, const int* targets, 21 | int B, int T, int V) { 22 | // backwards through both softmax and crossentropy 23 | for (int b = 0; b < B; b++) { 24 | for (int t = 0; t < T; t++) { 25 | float* dlogits_bt = dlogits + b * T * V + t * V; 26 | const float* probs_bt = probs + b * T * V + t * V; 27 | float dloss = dlosses[b * T + t]; 28 | int ix = targets[b * T + t]; 29 | for (int i = 0; i < V; i++) { 30 | float p = probs_bt[i]; 31 | float indicator = i == ix ? 1.0f : 0.0f; 32 | dlogits_bt[i] += (p - indicator) * dloss; 33 | } 34 | } 35 | } 36 | } 37 | 38 | // ---------------------------------------------------------------------------- 39 | // GPU kernels 40 | 41 | // naive kernel that just parallelizes over B,T,V 42 | __global__ void crossentropy_softmax_backward_kernel1(float* dlogits, 43 | const float* dlosses, const float* probs, const int* targets, 44 | int B, int T, int V) { 45 | int i = blockIdx.x * blockDim.x + threadIdx.x; 46 | if (i < B * T * V) { 47 | int b = i / (T * V); 48 | int t = (i / V) % T; 49 | int v = i % V; 50 | float* dlogits_bt = dlogits + b * T * V + t * V; 51 | const float* probs_bt = probs + b * T * V + t * V; 52 | float dloss = dlosses[b * T + t]; 53 | int ix = targets[b * T + t]; 54 | float p = probs_bt[v]; 55 | float indicator = v == ix ? 
1.0f : 0.0f; 56 | dlogits_bt[v] += (p - indicator) * dloss; 57 | } 58 | } 59 | 60 | // ---------------------------------------------------------------------------- 61 | // kernel launcher 62 | 63 | void crossentropy_softmax_backward1(float* dlogits, 64 | const float* dlosses, const float* probs, const int* targets, 65 | int B, int T, int V, 66 | const int block_size) { 67 | const int N = B * T * V; 68 | const int grid_size = ceil_div(N, block_size); 69 | crossentropy_softmax_backward_kernel1<<>>(dlogits, dlosses, probs, targets, B, T, V); 70 | cudaCheck(cudaGetLastError()); 71 | } 72 | 73 | // kernel version dispatch 74 | void crossentropy_softmax_backward(int kernel_num, 75 | float* dlogits, 76 | const float* dlosses, const float* probs, const int* targets, 77 | int B, int T, int V, 78 | const int block_size) { 79 | switch (kernel_num) { 80 | case 1: 81 | crossentropy_softmax_backward1(dlogits, dlosses, probs, targets, B, T, V, block_size); 82 | break; 83 | default: 84 | printf("Invalid kernel number\n"); 85 | exit(1); 86 | } 87 | } 88 | 89 | // ---------------------------------------------------------------------------- 90 | 91 | int main(int argc, char **argv) { 92 | srand(0); 93 | 94 | int B = 8; 95 | int T = 1024; 96 | int V = 50257; 97 | 98 | int deviceIdx = 0; 99 | cudaCheck(cudaSetDevice(deviceIdx)); 100 | 101 | // create host memory of random numbers 102 | float* probs = make_random_float_01(B * T * V); 103 | int* targets = make_random_int(B * T, V); 104 | float* dlosses = make_random_float(B * T); 105 | float* dlogits = make_zeros_float(B * T * V); 106 | 107 | // move to GPU 108 | float* d_probs; 109 | int* d_targets; 110 | float* d_dlosses; 111 | float* d_dlogits; 112 | cudaCheck(cudaMalloc(&d_probs, B * T * V * sizeof(float))); 113 | cudaCheck(cudaMalloc(&d_targets, B * T * sizeof(int))); 114 | cudaCheck(cudaMalloc(&d_dlosses, B * T * sizeof(float))); 115 | cudaCheck(cudaMalloc(&d_dlogits, B * T * V * sizeof(float))); 116 | cudaCheck(cudaMemcpy(d_probs, probs, B * T * V * sizeof(float), cudaMemcpyHostToDevice)); 117 | cudaCheck(cudaMemcpy(d_targets, targets, B * T * sizeof(int), cudaMemcpyHostToDevice)); 118 | cudaCheck(cudaMemcpy(d_dlosses, dlosses, B * T * sizeof(float), cudaMemcpyHostToDevice)); 119 | 120 | // read kernel_num from command line 121 | int kernel_num = 1; 122 | if (argc > 1) { 123 | kernel_num = atoi(argv[1]); 124 | } 125 | printf("Using kernel %d\n", kernel_num); 126 | 127 | // first check the correctness of the kernel 128 | crossentropy_softmax_backward_cpu(dlogits, dlosses, probs, targets, B, T, V); 129 | 130 | // time the kernel at different block sizes 131 | int block_sizes[] = {32, 64, 128, 256, 512, 1024}; 132 | 133 | for (int j = 0; j < sizeof(block_sizes) / sizeof(int); j++) { 134 | int block_size = block_sizes[j]; 135 | cudaCheck(cudaMemset(d_dlogits, 0, B * T * V * sizeof(float))); 136 | printf("Checking block size %d.\n", block_size); 137 | crossentropy_softmax_backward(kernel_num, d_dlogits, d_dlosses, d_probs, d_targets, B, T, V, block_size); 138 | validate_result(d_dlogits, dlogits, "dlogits", B * T * V, 1e-5f); 139 | } 140 | 141 | printf("All results match. 
Starting benchmarks.\n\n"); 142 | 143 | for (int j = 0; j < sizeof(block_sizes) / sizeof(int); j++) { 144 | int block_size = block_sizes[j]; 145 | 146 | int repeat_times = 100; 147 | float elapsed_time = benchmark_kernel(repeat_times, crossentropy_softmax_backward, 148 | kernel_num, d_dlogits, d_dlosses, d_probs, d_targets, 149 | B, T, V, block_size); 150 | 151 | printf("block_size %4d | time %.4f ms | per token %.2f µs\n", block_size, elapsed_time, elapsed_time * 1'000 / (B*T)); 152 | } 153 | 154 | // free memory 155 | free(probs); 156 | free(targets); 157 | free(dlosses); 158 | free(dlogits); 159 | cudaCheck(cudaFree(d_probs)); 160 | cudaCheck(cudaFree(d_targets)); 161 | cudaCheck(cudaFree(d_dlosses)); 162 | cudaCheck(cudaFree(d_dlogits)); 163 | 164 | return 0; 165 | } -------------------------------------------------------------------------------- /dev/cuda/gelu_forward.cu: -------------------------------------------------------------------------------- 1 | /* 2 | Kernels for gelu forward pass. 3 | 4 | Compile example: 5 | nvcc -O3 --use_fast_math -lcublas -lcublasLt gelu_forward.cu -o gelu_forward 6 | 7 | If encountering "error: identifier "M_PI" is undefined", add the following lines to the top of the file: 8 | 9 | #define _USE_MATH_DEFINES 10 | #include OR #include 11 | 12 | version 1 is naive CPU port 13 | ./gelu_forward 1 14 | 15 | version 2 is bfloat16 with the Packed128 data structure 16 | ./gelu_forward 2 17 | */ 18 | 19 | #include 20 | #include 21 | #include 22 | 23 | #define ENABLE_BF16 24 | #include "common.h" 25 | 26 | // ---------------------------------------------------------------------------- 27 | // CPU code reference 28 | 29 | #define GELU_SCALING_FACTOR sqrtf(2.0f / M_PI) 30 | 31 | void gelu_forward_cpu(float* out, const float* inp, int N) { 32 | for (int i = 0; i < N; i++) { 33 | float x = inp[i]; 34 | float cube = 0.044715f * x * x * x; 35 | out[i] = 0.5f * x * (1.0f + tanhf(GELU_SCALING_FACTOR * (x + cube))); 36 | } 37 | } 38 | 39 | // ---------------------------------------------------------------------------- 40 | // GPU kernels 41 | 42 | // elementwise ops are nice and ez 43 | __global__ void gelu_forward_kernel1(floatX* out, const floatX* inp, int N) { 44 | int i = blockIdx.x * blockDim.x + threadIdx.x; 45 | if (i < N) { 46 | float xi = inp[i]; 47 | float cube = 0.044715f * xi * xi * xi; 48 | out[i] = 0.5f * xi * (1.0f + tanhf(GELU_SCALING_FACTOR * (xi + cube))); 49 | } 50 | } 51 | 52 | // elementwise ops are nice and ez 53 | __global__ void gelu_forward_kernel2(floatX* out, const floatX* inp, int N) { 54 | int i = (blockIdx.x * blockDim.x + threadIdx.x) * x128::size; 55 | if (i < N) { 56 | x128 packed_out; 57 | x128 packed_inp = load128cs(inp + i); // load and do not keep in cache 58 | for(int k = 0; k < packed_inp.size; ++k) { 59 | float xi = (float)packed_inp[k]; 60 | float cube = 0.044715f * xi * xi * xi; 61 | packed_out[k] = (floatX)(0.5f * xi * (1.0f + tanhf(GELU_SCALING_FACTOR * (xi + cube)))); 62 | } 63 | // store instead of storecs (without cache streaming) in case it is useful for the 64 | // data to be in the cache for the next operation after this GeLU 65 | store128(out + i, packed_out); 66 | } 67 | } 68 | 69 | // ---------------------------------------------------------------------------- 70 | // kernel launcher 71 | 72 | void gelu_forward1(floatX* out, const floatX* inp, int N, const int block_size) { 73 | const int grid_size = ceil_div(N, block_size); 74 | gelu_forward_kernel1<<>>(out, inp, N); 75 | cudaCheck(cudaGetLastError()); 76 | } 77 | 78 
| void gelu_forward2(floatX* out, const floatX* inp, int N, const int block_size) { 79 | const int grid_size = ceil_div(N, block_size * x128::size); 80 | gelu_forward_kernel2<<>>(out, inp, N); 81 | cudaCheck(cudaGetLastError()); 82 | } 83 | 84 | // kernel version dispatch 85 | void gelu_forward(int kernel_num, 86 | floatX* out, 87 | const floatX* inp, 88 | int B, int T, int C, 89 | int block_size) { 90 | switch (kernel_num) { 91 | case 1: 92 | gelu_forward1(out, inp, B * T * C, block_size); 93 | break; 94 | case 2: 95 | gelu_forward2(out, inp, B * T * C, block_size); 96 | break; 97 | default: 98 | printf("Invalid kernel number\n"); 99 | exit(1); 100 | } 101 | } 102 | 103 | // ---------------------------------------------------------------------------- 104 | 105 | int main(int argc, const char **argv) { 106 | setup_main(); 107 | 108 | int B = 8; 109 | int T = 1024; 110 | int C = 768; 111 | 112 | // create host memory of random numbers 113 | float* out = (float*)malloc(B * T * C * sizeof(float)); 114 | float* inp = make_random_float(B * T * C); 115 | 116 | // read kernel_num from command line 117 | int kernel_num = 1; 118 | if (argc > 1) { 119 | kernel_num = atoi(argv[1]); 120 | } 121 | printf("Using kernel %d\n", kernel_num); 122 | 123 | // first check the correctness of the kernel 124 | gelu_forward_cpu(out, inp, B * T * C); 125 | 126 | // move to GPU 127 | floatX* d_out; 128 | floatX* d_inp; 129 | cudaCheck(cudaMalloc(&d_out, B * T * C * sizeof(floatX))); 130 | cudaCheck(cudaMalloc(&d_inp, B * T * C * sizeof(floatX))); 131 | cudaCheck(memcpy_convert(d_inp, inp, B * T * C)); 132 | 133 | // time the kernel at different block sizes 134 | int block_sizes[] = {32, 64, 128, 256, 512, 1024}; 135 | for (int j = 0; j < sizeof(block_sizes) / sizeof(int); j++) { 136 | int block_size = block_sizes[j]; 137 | printf("Checking block size %d.\n", block_size); 138 | gelu_forward(kernel_num, d_out, d_inp, B, T, C, block_size); 139 | #if !defined(ENABLE_BF16) && !defined(ENABLE_FP16) 140 | float tol = 1e-5; 141 | #else 142 | float tol = 1e-2f; 143 | #endif 144 | validate_result(d_out, out, "out", B * T * C, tol); 145 | } 146 | 147 | printf("All results match. Starting benchmarks.\n\n"); 148 | 149 | for (int j = 0; j < sizeof(block_sizes) / sizeof(int); j++) { 150 | int block_size = block_sizes[j]; 151 | 152 | int repeat_times = 1000; 153 | 154 | float elapsed_time = benchmark_kernel(repeat_times, gelu_forward, 155 | kernel_num, d_out, d_inp, 156 | B, T, C, block_size); 157 | 158 | // napkin math: estimate the memory bandwidth achieved 159 | // for each (B,T,C) output element, we do 1 read and 1 write, 4 bytes each 160 | // and e.g. A100 40GB PCIe is advertised at 1,555GB/s 161 | long memory_ops = B * T * C * 2 * (int)sizeof(floatX); 162 | float memory_bandwidth = memory_ops / elapsed_time / 1e6; 163 | 164 | printf("block_size %4d | time %.4f ms | bandwidth %.2f GB/s\n", block_size, elapsed_time, memory_bandwidth); 165 | } 166 | 167 | // free memory 168 | free(out); 169 | free(inp); 170 | 171 | cudaCheck(cudaFree(d_out)); 172 | cudaCheck(cudaFree(d_inp)); 173 | return 0; 174 | } -------------------------------------------------------------------------------- /dev/cuda/residual_forward.cu: -------------------------------------------------------------------------------- 1 | /* 2 | Kernels for residual forward pass. 
3 | 4 | Compile example: 5 | nvcc -O3 --use_fast_math -lcublas -lcublasLt residual_forward.cu -o residual_forward 6 | 7 | version 1 is naive port from CPU code to kernel 8 | ./residual_forward 1 9 | version 2 packs input into 128 bit memory reads 10 | ./residual_forward 2 11 | */ 12 | 13 | #include 14 | #include 15 | #include 16 | 17 | #define ENABLE_BF16 18 | #include "common.h" 19 | 20 | // ---------------------------------------------------------------------------- 21 | // CPU code reference lol 22 | 23 | void residual_forward_cpu(float* out, const float* inp1, const float* inp2, int N) { 24 | for (int i = 0; i < N; i++) { 25 | out[i] = inp1[i] + inp2[i]; 26 | } 27 | } 28 | 29 | // ---------------------------------------------------------------------------- 30 | // GPU kernels 31 | 32 | // elementwise ops are nice and ez 33 | __global__ void residual_forward_kernel1(floatX* out, const floatX* inp1, const floatX* inp2, int N) { 34 | int idx = blockIdx.x * blockDim.x + threadIdx.x; 35 | if (idx < N) { 36 | out[idx] = (floatX)((float)inp1[idx] + (float)inp2[idx]); 37 | } 38 | } 39 | 40 | __global__ void residual_forward_kernel2(floatX* out, const floatX* inp1, const floatX* inp2, int N) { 41 | int idx = (blockIdx.x * blockDim.x + threadIdx.x) * x128::size; 42 | if (idx < N) { 43 | x128 packed_out; 44 | x128 packed_inp1 = load128cs(inp1 + idx); 45 | x128 packed_inp2 = load128cs(inp2 + idx); 46 | for (int k = 0; k < packed_inp1.size; ++k) 47 | { 48 | packed_out[k] = (floatX)((float)packed_inp1[k] + (float)packed_inp2[k]); 49 | } 50 | store128(out + idx, packed_out); 51 | } 52 | } 53 | 54 | // ---------------------------------------------------------------------------- 55 | // kernel launcher 56 | 57 | void residual_forward1(floatX* out, const floatX* inp1, const floatX* inp2, int N, const int block_size) { 58 | const int grid_size = ceil_div(N, block_size); 59 | residual_forward_kernel1<<>>(out, inp1, inp2, N); 60 | cudaCheck(cudaGetLastError()); 61 | } 62 | 63 | void residual_forward2(floatX* out, const floatX* inp1, const floatX* inp2, int N, const int block_size) { 64 | const int grid_size = ceil_div(N, (int)(block_size * x128::size)); 65 | residual_forward_kernel2<<>>(out, inp1, inp2, N); 66 | cudaCheck(cudaGetLastError()); 67 | } 68 | 69 | // kernel version dispatch 70 | void residual_forward(int kernel_num, 71 | floatX* out, 72 | const floatX* inp1, 73 | const floatX* inp2, 74 | int N, 75 | int block_size) { 76 | switch (kernel_num) { 77 | case 1: 78 | residual_forward1(out, inp1, inp2, N, block_size); 79 | break; 80 | case 2: 81 | residual_forward2(out, inp1, inp2, N, block_size); 82 | break; 83 | default: 84 | printf("Invalid kernel number\n"); 85 | exit(1); 86 | } 87 | } 88 | 89 | // ---------------------------------------------------------------------------- 90 | 91 | int main(int argc, char **argv) { 92 | setup_main(); 93 | 94 | int B = 8; 95 | int T = 1024; 96 | int C = 768; 97 | 98 | // create host memory of random numbers 99 | float* out = (float*)malloc(B * T * C * sizeof(float)); 100 | float* inp1 = make_random_float(B * T * C); 101 | float* inp2 = make_random_float(B * T * C); 102 | 103 | // move to GPU 104 | floatX* d_out; 105 | floatX* d_inp1; 106 | floatX* d_inp2; 107 | cudaCheck(cudaMalloc(&d_out, B * T * C * sizeof(floatX))); 108 | cudaCheck(cudaMalloc(&d_inp1, B * T * C * sizeof(floatX))); 109 | cudaCheck(cudaMalloc(&d_inp2, B * T * C * sizeof(floatX))); 110 | cudaCheck(memcpy_convert(d_inp1, inp1, B * T * C)); 111 | cudaCheck(memcpy_convert(d_inp2, inp2, B * T * C)); 
112 | 113 | // read kernel_num from command line 114 | int kernel_num = 1; 115 | if (argc > 1) { 116 | kernel_num = atoi(argv[1]); 117 | } 118 | printf("Using kernel %d\n", kernel_num); 119 | 120 | // first check the correctness of the kernel 121 | residual_forward_cpu(out, inp1, inp2, B * T * C); 122 | 123 | 124 | // time the kernel at different block sizes 125 | int block_sizes[] = {32, 64, 128, 256, 512, 1024}; 126 | 127 | for (int j = 0; j < sizeof(block_sizes) / sizeof(int); j++) { 128 | int block_size = block_sizes[j]; 129 | printf("Checking block size %d.\n", block_size); 130 | residual_forward(kernel_num, d_out, d_inp1, d_inp2, B * T * C, block_size); 131 | #if !defined(ENABLE_BF16) && !defined(ENABLE_FP16) 132 | float tol = 1e-5; 133 | #else 134 | float tol = 1e-2f; 135 | #endif 136 | validate_result(d_out, out, "out", B * T * C, tol); 137 | } 138 | 139 | printf("All results match. Starting benchmarks.\n\n"); 140 | 141 | for (int j = 0; j < sizeof(block_sizes) / sizeof(int); j++) { 142 | int block_size = block_sizes[j]; 143 | 144 | int repeat_times = 1000; 145 | float elapsed_time = benchmark_kernel(repeat_times, residual_forward, 146 | kernel_num, d_out, d_inp1, d_inp2, B * T * C, block_size 147 | ); 148 | 149 | // napkin math: estimate the memory bandwidth achieved 150 | // for each (B,T,C) output element, we do 2 read and 1 write, 4 bytes each 151 | // and e.g. A100 40GB PCIe is advertised at 1,555GB/s 152 | long memory_ops = B * T * C * 3 * 4; 153 | float memory_bandwidth = memory_ops / elapsed_time / 1e6; 154 | 155 | printf("block_size %4d | time %.4f ms | bandwidth %.2f GB/s\n", block_size, elapsed_time, memory_bandwidth); 156 | } 157 | 158 | // free memory 159 | free(out); 160 | free(inp1); 161 | free(inp2); 162 | cudaCheck(cudaFree(d_out)); 163 | cudaCheck(cudaFree(d_inp1)); 164 | cudaCheck(cudaFree(d_inp2)); 165 | 166 | return 0; 167 | } 168 | -------------------------------------------------------------------------------- /dev/data/README.md: -------------------------------------------------------------------------------- 1 | # dev/data organization 2 | 3 | The idea is that each dataset has a .py file here in the root of `dev/data`, and each dataset then creates a directory here, and writes and caches anything inside that directory. So for example: 4 | 5 | - running `python tinystories.py` will create a directory `tinystories` with its .bin files inside it 6 | - running `python tinyshakespeare.py` will create a directory `tinyshakespeare` with its .bin files inside it 7 | 8 | And so on. This way we can nicely organize multiple datasets here, share common utilities between them, and then point the .py/.c code in the root of the project accordingly to these. 
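As a quick reference for what these .bin token files look like on disk (mirroring `write_datafile` in `data_common.py` below): a header of 256 int32s whose first three entries are a magic number (20240520), a version (1), and the token count, followed by the tokens themselves as uint16. A minimal reader sketch in Python — the helper name `read_datafile` is chosen here purely for illustration and is not part of the repo — could look like:

```python
import numpy as np

def read_datafile(filename):
    # read back a token .bin file written by write_datafile (sketch)
    with open(filename, "rb") as f:
        header = np.frombuffer(f.read(256 * 4), dtype=np.int32)  # 1024-byte header
        assert header[0] == 20240520, "magic number mismatch"
        assert header[1] == 1, "unsupported version"
        num_tokens = int(header[2])
        tokens = np.frombuffer(f.read(num_tokens * 2), dtype=np.uint16)
    return tokens
```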
9 | -------------------------------------------------------------------------------- /dev/data/data_common.py: -------------------------------------------------------------------------------- 1 | """ 2 | Common utilities for the datasets 3 | """ 4 | 5 | import requests 6 | from tqdm import tqdm 7 | import numpy as np 8 | 9 | 10 | def download_file(url: str, fname: str, chunk_size=1024): 11 | """Helper function to download a file from a given url""" 12 | resp = requests.get(url, stream=True) 13 | total = int(resp.headers.get("content-length", 0)) 14 | with open(fname, "wb") as file, tqdm( 15 | desc=fname, 16 | total=total, 17 | unit="iB", 18 | unit_scale=True, 19 | unit_divisor=1024, 20 | ) as bar: 21 | for data in resp.iter_content(chunk_size=chunk_size): 22 | size = file.write(data) 23 | bar.update(size) 24 | 25 | 26 | def write_datafile(filename, toks): 27 | """ 28 | Saves token data as a .bin file, for reading in C. 29 | - First comes a header with 256 int32s 30 | - The tokens follow, each as a uint16 31 | """ 32 | assert len(toks) < 2**31, "token count too large" # ~2.1B tokens 33 | # construct the header 34 | header = np.zeros(256, dtype=np.int32) 35 | header[0] = 20240520 # magic 36 | header[1] = 1 # version 37 | header[2] = len(toks) # number of tokens after the 256*4 bytes of header (each 2 bytes as uint16) 38 | # construct the tokens numpy array, if not already 39 | if not isinstance(toks, np.ndarray) or not toks.dtype == np.uint16: 40 | # validate that no token exceeds a uint16 41 | maxtok = 2**16 42 | assert all(0 <= t < maxtok for t in toks), "token dictionary too large for uint16" 43 | toks_np = np.array(toks, dtype=np.uint16) 44 | else: 45 | toks_np = toks 46 | # write to file 47 | print(f"writing {len(toks):,} tokens to {filename}") 48 | with open(filename, "wb") as f: 49 | f.write(header.tobytes()) 50 | f.write(toks_np.tobytes()) 51 | 52 | def write_evalfile(filename, datas): 53 | """ 54 | Saves eval data as a .bin file, for reading in C. 55 | Used for multiple-choice style evals, e.g. HellaSwag and MMLU 56 | - First comes a header with 256 int32s 57 | - The examples follow, each example is a stream of uint16_t: 58 | - delimiter of 2**16-1, i.e. 65,535 59 | - , bytes encoding this example, allowing efficient skip to next 60 | - , the index of the example in the dataset 61 | -