├── .dockerignore ├── .gitignore ├── LICENSE ├── README.md ├── docker └── Dockerfile ├── docker_launch.sh ├── megatron ├── __init__.py ├── arguments.py ├── checkpointing.py ├── config.py ├── fused_kernels │ ├── __init__.py │ ├── megatron_fused_kernels │ │ ├── __init__.py │ │ ├── compat.h │ │ ├── fused_weight_gradient_dense.cpp │ │ ├── fused_weight_gradient_dense_cuda.cu │ │ ├── helpers.cpp │ │ ├── layer_norm_cuda.cpp │ │ ├── layer_norm_cuda_kernel.cu │ │ ├── load_cpp_extensions.py │ │ ├── load_kernel.py │ │ ├── scaled_masked_softmax.cpp │ │ ├── scaled_masked_softmax.h │ │ ├── scaled_masked_softmax_cuda.cu │ │ ├── scaled_softmax.cpp │ │ ├── scaled_softmax_cuda.cu │ │ ├── scaled_upper_triang_masked_softmax.cpp │ │ ├── scaled_upper_triang_masked_softmax.h │ │ ├── scaled_upper_triang_masked_softmax_cuda.cu │ │ ├── tests │ │ │ └── __init__.py │ │ └── type_shim.h │ └── setup.py ├── global_vars.py ├── initialize.py ├── memory.py ├── microbatches.py ├── model │ ├── __init__.py │ ├── distributed.py │ ├── enums.py │ ├── fused_bias_gelu.py │ ├── fused_bias_sqrelu.py │ ├── fused_layer_norm.py │ ├── fused_softmax.py │ ├── gpt_model.py │ ├── language_model.py │ ├── module.py │ ├── positional_embeddings.py │ ├── transformer.py │ └── utils.py ├── mpu │ ├── __init__.py │ ├── communication.py │ ├── cross_entropy.py │ ├── cross_entropy_parallel.py │ ├── initialize.py │ ├── layers.py │ ├── mappings.py │ ├── random.py │ └── utils.py ├── text_generation │ ├── __init__.py │ ├── api.py │ ├── forward_step.py │ ├── generation.py │ ├── inference_params.py │ ├── sampling.py │ └── tokenization.py ├── text_generation_server.py ├── tokenizer │ ├── __init__.py │ └── tokenizer.py ├── training.py └── utils.py ├── prompts.md ├── requirements.txt ├── run_text_generation_server.py ├── run_text_generation_server.sh └── tool_use ├── __init__.py ├── experiment_pipeline └── __init__.py └── megatron └── megatron_utils.py /.dockerignore: -------------------------------------------------------------------------------- 1 | 8b_chat_model_release.tar 2 | 8b_base_model_release.tar 3 | 8b_chat_model_release/ 4 | 8b_base_model_release/ 5 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.env 2 | 3 | .ruff_cache 4 | .DS_Store 5 | 6 | # Byte-compiled / optimized / DLL files 7 | *__pycache__* 8 | .mypy_cache/ 9 | *.py[cod] 10 | *$py.class 11 | lm_cache 12 | 13 | # C extensions 14 | *.so 15 | 16 | # Distribution / packaging 17 | .Python 18 | build/ 19 | develop-eggs/ 20 | dist/ 21 | downloads/ 22 | eggs/ 23 | .eggs/ 24 | .next/ 25 | lib/ 26 | lib64/ 27 | parts/ 28 | sdist/ 29 | var/ 30 | wheels/ 31 | pip-wheel-metadata/ 32 | share/python-wheels/ 33 | *.egg-info/ 34 | .installed.cfg 35 | *.egg 36 | env 37 | MANIFEST 38 | venv/ 39 | 40 | # PyCharm 41 | .idea/ 42 | 43 | # VSCode 44 | .vscode/ 45 | 46 | # Ignore swapfiles 47 | .*.swp 48 | .\#* 49 | 50 | # virtualenv 51 | .python-version 52 | 53 | # VSCode extension 54 | *.vsix 55 | 56 | # Environment variables 57 | .env 58 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | Persimmon-8B User Guide 2 | ========== 3 | This repo contains inference code for [Persimmon-8B](https://www.adept.ai/blog/persimmon-8b), the new LLM from Adept. 
4 | 5 | Downloading the Checkpoint 6 | -------- 7 | 8 | The model checkpoints are stored on our public OCI bucket and can be downloaded using `wget`. 9 | The base model is not fine-tuned and is released under an Apache 2.0 license. 10 | The chat model is fine-tuned and is released under a CC-BY-NC 4.0 license. 11 | 12 | Base: 13 | https://axtkn4xl5cip.objectstorage.us-phoenix-1.oci.customer-oci.com/n/axtkn4xl5cip/b/adept-public-data/o/8b_base_model_release.tar 14 | md5sum: cd0320cba9efad9ccd18e9ec4d16ae1b 15 | 16 | Chat: 17 | https://axtkn4xl5cip.objectstorage.us-phoenix-1.oci.customer-oci.com/n/axtkn4xl5cip/b/adept-public-data/o/8b_chat_model_release.tar 18 | md5sum: 663aeace07269c44e90f4e8bcd07f32a 19 | 20 | Untar the model into its own directory via `tar -xvf 8b_base_model_release.tar` or `tar -xvf 8b_chat_model_release.tar` 21 | 22 | The scripts are set up to expect the model folder to be placed within the code directory, but you can place it elsewhere and modify the scripts accordingly. 23 | 24 | Building Docker 25 | ----------- 26 | 27 | Build the docker that will include all the necessary dependencies (and then some!) using the included Dockerfile: 28 | 29 | ``` 30 | docker build -f docker/Dockerfile -t 'adeptdocker' . 31 | ``` 32 | 33 | Running Docker 34 | ---------- 35 | Ensure that the variable `MODEL_DIR` in `run_text_generation_server.sh` is set to the location of the model directory. By default it is set to `MODEL_DIR=8b_chat_model_release`, which is the default name for the chat model. (For the base model, change this line to `MODEL_DIR=8b_base_model_release`.) 36 | 37 | Running `sh docker_launch.sh` will start a model server that you can query via: 38 | 39 | ``` 40 | curl '
/api' -X 'PUT' -H 'Content-Type: application/json; charset=UTF-8' -d '{"prompts": ["human: Hello, how are you?\n\nadept:"], "tokens_to_generate": 128, "top_p": 0.9, "random_seed": 1234, "logprobs": false}' 41 | ``` 42 | 43 | 44 | Notes 45 | ----- 46 | 47 | * The chat model is fine-tuned to expect inputs of the form: `human: {prompt}\n\nadept:`[^1]. To ensure best performance from this model, please use this format! You can see an example of this in the curl command above. To automatically wrap single-turn input prompts with this structure, you can modify the definition of `megatron/text_generation/api.py::generate_and_post_process` so that the default value for the argument `process_prompts_for_chat` is set to `True`. 48 | * We are releasing the model with tensor parallelism of 1. In this configuration, the model requires an 80GB GPU to run naively. 49 | It should be possible to fit the model on a 40GB card by removing the unused embeddings and reducing the maximum sequence length 50 | (at the top of `run_text_generation_server.py`). 51 | Quantization to 8-bit or lower would make also it fit with plenty of room to spare. 52 | * We included the `.vocab` file so you can browse the vocabulary in plain text - this file is otherwise unused. 53 | 54 | 55 | Citation 56 | -------- 57 | 58 | If you use this model in your work, please use the following BibTeX citation: 59 | ```bibtex 60 | @misc{persimmon-8b, 61 | author = {Elsen, Erich and Odena, Augustus and Nye, Maxwell and Ta\c{s}\i{}rlar, Sa\u{g}nak and Dao, Tri and Hawthorne, Curtis and Moparthi, Deepak and Somani, Arushi}, 62 | title = {Releasing {Persimmon-8B}}, 63 | url = {https://www.adept.ai/blog/persimmon-8b}, 64 | year = {2023} 65 | } 66 | ``` 67 | 68 | 69 | [^1]: Subsequent inputs should have the form `human: {prompt}\n\nadept: {output}\n\nhuman: {follow_up}\n\nadept:` 70 | -------------------------------------------------------------------------------- /docker/Dockerfile: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023, ADEPT AI LABS INC. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
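# Image overview: starting from the NGC PyTorch 23.04 base, this Dockerfile
# builds UCX (v1.14.1) and UCC (v1.2.0) from source with CUDA support,
# installs FlashAttention 2.0.0.post1 plus the rotary / xentropy / ft_attention
# extensions, pre-builds the megatron fused CUDA kernels, and installs the
# Python requirements used by the text generation server.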
14 | 15 | FROM nvcr.io/nvidia/pytorch:23.04-py3 16 | 17 | RUN apt-get update -y && \ 18 | DEBIAN_FRONTEND=noninteractive TZ=Etc/UTC apt-get -y install tzdata build-essential devscripts debhelper fakeroot 19 | 20 | 21 | WORKDIR /opt/hpcx 22 | RUN rm -rf ucx && \ 23 | git clone --recursive https://github.com/openucx/ucx.git && \ 24 | pushd ucx && \ 25 | git fetch --all --tags && \ 26 | git checkout tags/v1.14.1 && \ 27 | ./autogen.sh && \ 28 | mkdir UCX_BUILD && \ 29 | ./contrib/configure-release-mt --prefix=/opt/hpcx/ucx/UCX_BUILD/ --with-cuda=/usr/local/cuda/ && \ 30 | make -j && \ 31 | make install && \ 32 | popd 33 | 34 | RUN rm -rf ucc && \ 35 | git clone --recursive https://github.com/openucx/ucc.git && \ 36 | pushd ucc && \ 37 | git fetch --all --tags && \ 38 | git checkout tags/v1.2.0 && \ 39 | ./autogen.sh && \ 40 | mkdir UCC_BUILD && \ 41 | ./configure --prefix=/opt/hpcx/ucc/UCC_BUILD --with-ucx=/opt/hpcx/ucx/UCX_BUILD/ --with-nccl=/usr --with-cuda=/usr/local/cuda/ --with-nvcc-gencode="-gencode=arch=compute_80,code=sm_80 -gencode=arch=compute_80,code=compute_80 -gencode=arch=compute_86,code=sm_86 -gencode=arch=compute_86,code=compute_86 -gencode=arch=compute_90,code=sm_90 -gencode=arch=compute_90,code=compute_90"&& \ 42 | make -j && \ 43 | make install && \ 44 | popd 45 | 46 | ENV LD_LIBRARY_PATH=/opt/hpcx/ucx/UCX_BUILD/lib:/opt/hpcx/ucc/UCC_BUILD/lib:$LD_LIBRARY_PATH 47 | ENV UCX_HOME=/opt/hpcx/ucx/UCX_BUILD/ 48 | ENV UCC_HOME=/opt/hpcx/ucc/UCC_BUILD/ 49 | ENV WITH_CUDA=/usr/local/cuda 50 | 51 | WORKDIR /workspace 52 | 53 | # Install FlashAttention 54 | RUN pip install flash-attn==2.0.0.post1 55 | 56 | # Install rotary embedding, cross entropy, and FT't attention kernel 57 | # [2022-11-08] TD: Check out a specific version to make build reproducible 58 | RUN git clone https://github.com/HazyResearch/flash-attention \ 59 | && cd flash-attention && git checkout b8020d73c9e068665586989883083a4a5429a443 \ 60 | && cd csrc/rotary && pip install . && cd ../../ \ 61 | && cd csrc/xentropy && pip install . && cd ../../ \ 62 | && cd csrc/ft_attention && pip install . && cd ../../ \ 63 | && cd .. && rm -rf flash-attention 64 | 65 | 66 | COPY megatron/fused_kernels/ megatron/fused_kernels/ 67 | ENV PATH="${PATH}:/opt/hpcx/ompi/bin" LD_LIBRARY_PATH="${LD_LIBRARY_PATH}:/opt/hpcx/ompi/lib" 68 | RUN ldconfig /opt/hpcx/ompi/bin 69 | RUN TORCH_CUDA_ARCH_LIST="" cd megatron/fused_kernels/ && python setup.py install sdist && cd ../.. 70 | RUN cd /usr/local/lib/python3.8/dist-packages/megatron_fused_kernels-0.0.0-py3.8-linux-x86_64.egg/; mv *.so megatron_fused_kernels/; 71 | RUN rm -rf megatron/fused_kernels 72 | 73 | # Install apt-get dependencies for pip requirements. 74 | ENV DEBIAN_FRONTEND=noninteractive 75 | RUN curl -fsSL https://deb.nodesource.com/setup_18.x | bash - && \ 76 | apt-get update -y && apt-get install -y nodejs iputils-ping \ 77 | wget ca-certificates tzdata zip locales \ 78 | && locale-gen en_US en_US.UTF-8 en_US.utf8 && dpkg-reconfigure --frontend noninteractive locales \ 79 | && npm install pm2 -g \ 80 | && pip install --upgrade pip setuptools \ 81 | && rm -rf /var/lib/apt/lists/* 82 | 83 | # Change locale for click 84 | ENV LANG=C.UTF-8 LANGUAGE=en_US.en LC_ALL=C.UTF-8 85 | 86 | # Install requirements & cleanup. 
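# (requirements.txt is copied in from the build context and removed again
# once the install below finishes.)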
87 | COPY requirements.txt requirements.txt 88 | RUN pip install --upgrade pip setuptools && \ 89 | pip install -r requirements.txt && rm requirements.txt 90 | -------------------------------------------------------------------------------- /docker_launch.sh: -------------------------------------------------------------------------------- 1 | sudo docker run \ 2 | --rm \ 3 | --gpus all \ 4 | --device=/dev/infiniband \ 5 | --ipc=host \ 6 | --ulimit memlock=-1 \ 7 | --ulimit stack=67108864 \ 8 | --env PYTHONPATH="." \ 9 | -v $(pwd)/:/Adept_Inference/ \ 10 | --network=host \ 11 | --name=adept_inference \ 12 | adeptdocker \ 13 | bash -c "cd /Adept_Inference/; sh run_text_generation_server.sh"; 14 | -------------------------------------------------------------------------------- /megatron/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) 2023 ADEPT AI LABS INC. 3 | # This file is based on code by the authors denoted below and has been modified from its original version. 4 | # 5 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. 6 | # 7 | # Licensed under the Apache License, Version 2.0 (the "License"); 8 | # you may not use this file except in compliance with the License. 9 | # You may obtain a copy of the License at 10 | # 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | # 13 | # Unless required by applicable law or agreed to in writing, software 14 | # distributed under the License is distributed on an "AS IS" BASIS, 15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 16 | # See the License for the specific language governing permissions and 17 | # limitations under the License. 18 | import torch 19 | 20 | from .global_vars import get_args 21 | from .global_vars import get_current_global_batch_size 22 | from .global_vars import get_num_microbatches 23 | from .global_vars import update_num_microbatches 24 | from .global_vars import get_tokenizer 25 | from .global_vars import get_tensorboard_writer 26 | from .global_vars import get_adlr_autoresume 27 | from .global_vars import get_timers 28 | from .global_vars import get_global_memory_buffer 29 | 30 | from .utils import print_rank_0, is_last_rank, print_rank_last 31 | from .initialize import initialize_megatron 32 | -------------------------------------------------------------------------------- /megatron/config.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) 2023, ADEPT AI LABS INC. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Config objects. 17 | 18 | Provides a bridge between the flat `args` namespace and having different 19 | configs at any layer of the model. 20 | 21 | Config objects can be initialized manually, but to match current `args` 22 | behavior, they can be initialized using the `from_args` method. 
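For illustration, a minimal (hypothetical) config could look like the following;
the class and argument names here are invented and do not appear elsewhere in
the repo:

    @dataclasses.dataclass
    class LayerConfig(Config):
        hidden_size: int
        num_layers: int

    # Reads args.encoder_hidden_size and args.encoder_num_layers.
    layer_config = LayerConfig.from_args(get_args(), arg_prefix="encoder_")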
23 | """ 24 | 25 | import argparse 26 | import dataclasses 27 | from typing import Type, TypeVar 28 | 29 | T = TypeVar("T", bound="Config") 30 | 31 | DISABLE_PREFIX = lambda: dataclasses.field(metadata={"no_from_args_prefix": True}) 32 | 33 | 34 | @dataclasses.dataclass 35 | class Config: 36 | @classmethod 37 | def from_args(cls: Type[T], args: argparse.Namespace, arg_prefix: str = None) -> T: 38 | """Initialize a Config object using an `args` object. 39 | 40 | Args: 41 | args: Parsed arguments from `megatron.get_args()`. The field names 42 | in the Config object will be populated by `args` entries of the 43 | same name. 44 | arg_prefix: If provided, will use this prefix when finding the 45 | field in `args`. For example, if the field name is `hidden_size` 46 | and arg_prefix is `blah_`, the field will be populated by the arg 47 | `blah_hidden_size`. 48 | However, if the field has the metadata field `no_from_args_prefix` 49 | set to True, the prefix will not be added. This is to support 50 | fields that are globally applicable. 51 | """ 52 | field_values = {} 53 | for field in dataclasses.fields(cls): 54 | if issubclass(field.type, Config): 55 | field_values[field.name] = field.type.from_args(args, arg_prefix) 56 | else: 57 | if arg_prefix: 58 | prefixed_field_name = arg_prefix + field.name 59 | if arg_prefix and not ( 60 | "no_from_args_prefix" in field.metadata and field.metadata["no_from_args_prefix"] 61 | ): 62 | arg_name = prefixed_field_name 63 | else: 64 | arg_name = field.name 65 | if prefixed_field_name in vars(args): 66 | raise ValueError( 67 | f"{field.name} has no_from_args_prefix set, but {prefixed_field_name} exists in args. This is likely a mistake." 68 | ) 69 | else: 70 | arg_name = field.name 71 | 72 | if arg_name not in vars(args): 73 | raise ValueError( 74 | f"{arg_name} not found in args when attempting to construct {cls.__name__} from args." 75 | ) 76 | field_values[field.name] = vars(args)[arg_name] 77 | return cls(**field_values) 78 | -------------------------------------------------------------------------------- /megatron/fused_kernels/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pathlib 3 | import signal 4 | import multiprocessing 5 | import time 6 | import importlib 7 | import sys 8 | from torch.utils import cpp_extension 9 | from megatron.fused_kernels.megatron_fused_kernels.load_kernel import ( 10 | KernelBuildInfo, 11 | ALL_BUILD_KERNELS, 12 | JIT_BUILD_KERNELS, 13 | ) 14 | 15 | 16 | class CompilationTimeoutError(Exception): 17 | pass 18 | 19 | 20 | def _load_kernel(kernel_build_info): 21 | kernel_build_info.load() 22 | 23 | 24 | def load(force_build_fused_kernels=False): 25 | """Load fused kernels.""" 26 | if force_build_fused_kernels: 27 | for kernel_build_info in ALL_BUILD_KERNELS.values(): 28 | _load_kernel(kernel_build_info) 29 | else: 30 | # Just comile the kernels that we need to JIT compile. 
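# (JIT_BUILD_KERNELS is defined in load_kernel.py and is currently an empty
# list, so this branch builds nothing unless kernels are added to it.)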
31 | for kernel_name in JIT_BUILD_KERNELS: 32 | kernel_build_info = ALL_BUILD_KERNELS[kernel_name] 33 | _load_kernel(kernel_build_info) 34 | -------------------------------------------------------------------------------- /megatron/fused_kernels/megatron_fused_kernels/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/persimmon-ai-labs/adept-inference/61743d07cfb151dadb0cd2ae9de8f7325c4e828a/megatron/fused_kernels/megatron_fused_kernels/__init__.py -------------------------------------------------------------------------------- /megatron/fused_kernels/megatron_fused_kernels/compat.h: -------------------------------------------------------------------------------- 1 | /* coding=utf-8 2 | * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | /*This code is copied fron NVIDIA apex: 18 | * https://github.com/NVIDIA/apex 19 | * with minor changes. */ 20 | 21 | 22 | 23 | #ifndef TORCH_CHECK 24 | #define TORCH_CHECK AT_CHECK 25 | #endif 26 | 27 | #ifdef VERSION_GE_1_3 28 | #define DATA_PTR data_ptr 29 | #else 30 | #define DATA_PTR data 31 | #endif 32 | -------------------------------------------------------------------------------- /megatron/fused_kernels/megatron_fused_kernels/fused_weight_gradient_dense.cpp: -------------------------------------------------------------------------------- 1 | /* Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. 
*/ 2 | 3 | #include 4 | #include 5 | 6 | #include 7 | #include 8 | 9 | #include "type_shim.h" 10 | 11 | 12 | template 13 | int wgrad_gemm_accum_fp32_cuda(T *input, T *d_output, float *d_weight, int in_dim, int hidden_dim, int out_dim); 14 | 15 | void wgrad_gemm_accum_fp32(const at::Tensor input, const at::Tensor d_output, at::Tensor d_weight) { 16 | at::Tensor input_2d, d_output_2d; 17 | // input tensor: collapse to the first dim 18 | auto in_sizes = input.sizes(); 19 | if (input.dim() > 2) { 20 | input_2d = input.view({-1, in_sizes[in_sizes.size() - 1]}); 21 | } else { 22 | input_2d = input; 23 | } 24 | // d_output tensor: collapse to the first dim 25 | auto d_out_sizes = d_output.sizes(); 26 | if (d_output.dim() > 2) { 27 | d_output_2d = d_output.view({-1, d_out_sizes[d_out_sizes.size() - 1]}); 28 | } else { 29 | d_output_2d = d_output; 30 | } 31 | 32 | int hidden_dim = input_2d.size(0); 33 | int in_dim = input_2d.size(1); 34 | int out_dim = d_weight.size(0); 35 | 36 | DISPATCH_HALF_BFLOAT_AND_FLOAT(input_2d.scalar_type(), "wgrad_gemm_accum_fp32", 37 | int result = wgrad_gemm_accum_fp32_cuda( 38 | input_2d.data_ptr(), 39 | d_output_2d.data_ptr(), 40 | d_weight.data_ptr(), 41 | in_dim, 42 | hidden_dim, 43 | out_dim); 44 | ); 45 | } 46 | 47 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 48 | m.def("wgrad_gemm_accum_fp32", &wgrad_gemm_accum_fp32, "wgrad gemm accum in fp32"); 49 | } 50 | -------------------------------------------------------------------------------- /megatron/fused_kernels/megatron_fused_kernels/fused_weight_gradient_dense_cuda.cu: -------------------------------------------------------------------------------- 1 | /* Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. */ 2 | 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | 11 | /* Includes, cuda */ 12 | #include 13 | #include 14 | 15 | 16 | // BF16 Tensor core wrapper around cublas GEMMEx 17 | cublasStatus_t gemmex_wrapper( 18 | cublasHandle_t handle, 19 | cublasOperation_t transa, 20 | cublasOperation_t transb, 21 | int m, 22 | int n, 23 | int k, 24 | const float* alpha, 25 | at::BFloat16* A, 26 | int lda, 27 | at::BFloat16* B, 28 | int ldb, 29 | const float* beta, 30 | float* C, 31 | int ldc) { 32 | return cublasGemmEx( 33 | handle, 34 | transa, 35 | transb, 36 | m, 37 | n, 38 | k, 39 | alpha, 40 | A, 41 | CUDA_R_16BF, 42 | lda, 43 | B, 44 | CUDA_R_16BF, 45 | ldb, 46 | beta, 47 | C, 48 | CUDA_R_32F, 49 | ldc, 50 | CUDA_R_32F, 51 | CUBLAS_GEMM_DEFAULT_TENSOR_OP); 52 | } 53 | 54 | // FP16 Tensor core wrapper around cublas GEMMEx 55 | cublasStatus_t gemmex_wrapper( 56 | cublasHandle_t handle, 57 | cublasOperation_t transa, 58 | cublasOperation_t transb, 59 | int m, 60 | int n, 61 | int k, 62 | const float* alpha, 63 | at::Half* A, 64 | int lda, 65 | at::Half* B, 66 | int ldb, 67 | const float* beta, 68 | float* C, 69 | int ldc) { 70 | return cublasGemmEx( 71 | handle, 72 | transa, 73 | transb, 74 | m, 75 | n, 76 | k, 77 | alpha, 78 | A, 79 | CUDA_R_16F, 80 | lda, 81 | B, 82 | CUDA_R_16F, 83 | ldb, 84 | beta, 85 | C, 86 | CUDA_R_32F, 87 | ldc, 88 | CUDA_R_32F, 89 | CUBLAS_GEMM_DEFAULT_TENSOR_OP); 90 | } 91 | 92 | // FP32 Tensor core wrapper around cublas GEMMEx 93 | cublasStatus_t gemmex_wrapper( 94 | cublasHandle_t handle, 95 | cublasOperation_t transa, 96 | cublasOperation_t transb, 97 | int m, 98 | int n, 99 | int k, 100 | const float* alpha, 101 | float* A, 102 | int lda, 103 | float* B, 104 | int ldb, 105 | const float* beta, 106 | float* C, 107 | int ldc) { 108 | 
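  // FP32 wrapper: A, B, and C are all CUDA_R_32F below, so unlike the BF16/FP16
  // wrappers above there is no 16-bit input type; accumulation is likewise FP32.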
return cublasGemmEx( 109 | handle, 110 | transa, 111 | transb, 112 | m, 113 | n, 114 | k, 115 | alpha, 116 | A, 117 | CUDA_R_32F, 118 | lda, 119 | B, 120 | CUDA_R_32F, 121 | ldb, 122 | beta, 123 | C, 124 | CUDA_R_32F, 125 | ldc, 126 | CUDA_R_32F, 127 | CUBLAS_GEMM_DEFAULT_TENSOR_OP); 128 | } 129 | 130 | template 131 | int wgrad_gemm_accum_fp32_cuda(T *input, T *d_output, float *d_weight, int in_dim, int hidden_dim, int out_dim) { 132 | cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle(); 133 | cudaStream_t stream; 134 | cublasGetStream(handle, &stream); 135 | const float alpha = 1.0; 136 | const float beta = 1.0; 137 | int status = 1; 138 | 139 | status = gemmex_wrapper( 140 | handle, 141 | CUBLAS_OP_N, 142 | CUBLAS_OP_T, 143 | in_dim, 144 | out_dim, 145 | hidden_dim, 146 | &alpha, 147 | input, 148 | in_dim, 149 | d_output, 150 | out_dim, 151 | &beta, 152 | d_weight, 153 | in_dim); 154 | return status; 155 | } 156 | 157 | template int wgrad_gemm_accum_fp32_cuda(at::Half *input, at::Half *d_output, float *d_weight, int in_dim, int hidden_dim, int out_dim); 158 | template int wgrad_gemm_accum_fp32_cuda(at::BFloat16 *input, at::BFloat16 *d_output, float *d_weight, int in_dim, int hidden_dim, int out_dim); 159 | template int wgrad_gemm_accum_fp32_cuda(float *input, float *d_output, float *d_weight, int in_dim, int hidden_dim, int out_dim); 160 | -------------------------------------------------------------------------------- /megatron/fused_kernels/megatron_fused_kernels/layer_norm_cuda.cpp: -------------------------------------------------------------------------------- 1 | /* coding=utf-8 2 | * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | /*This code is copied fron NVIDIA apex: 18 | * https://github.com/NVIDIA/apex 19 | * with minor changes. 
*/ 20 | 21 | #include 22 | #include 23 | #include 24 | #include "compat.h" 25 | 26 | namespace { 27 | 28 | void compute_n1_n2( 29 | at::Tensor input, 30 | at::IntArrayRef normalized_shape, 31 | int& n1, 32 | int& n2) { 33 | int idiff = input.ndimension() - normalized_shape.size(); 34 | n2 = 1; 35 | for (int i = 0; i < (int)normalized_shape.size(); ++i) { 36 | assert( input.sizes()[i+idiff] == normalized_shape[i] ); 37 | n2 *= normalized_shape[i]; 38 | } 39 | n1 = 1; 40 | for (int i = 0; i < idiff; ++i) { 41 | n1 *= input.sizes()[i]; 42 | } 43 | } 44 | 45 | void check_args( 46 | at::IntArrayRef normalized_shape, 47 | at::Tensor gamma, 48 | at::Tensor beta 49 | ) 50 | { 51 | TORCH_CHECK(!gamma.defined() || gamma.sizes().equals(normalized_shape)); 52 | TORCH_CHECK(!beta.defined() || beta.sizes().equals(normalized_shape)); 53 | } 54 | 55 | void check_args( 56 | at::Tensor input, 57 | at::IntArrayRef normalized_shape, 58 | int& n1, 59 | int& n2 60 | ) 61 | { 62 | int64_t normalized_ndim = normalized_shape.size(); 63 | 64 | if (normalized_ndim < 1) { 65 | std::stringstream ss; 66 | ss << "Expected normalized_shape to be at least 1-dimensional, i.e., " 67 | << "containing at least one element, but got normalized_shape=" 68 | << normalized_shape; 69 | throw std::runtime_error(ss.str()); 70 | } 71 | 72 | auto input_shape = input.sizes(); 73 | auto input_ndim = input.dim(); 74 | 75 | if (input_ndim < normalized_ndim || 76 | !input_shape.slice(input_ndim - normalized_ndim).equals(normalized_shape)) { 77 | std::stringstream ss; 78 | ss << "Given normalized_shape=" << normalized_shape 79 | << ", expected input with shape [*"; 80 | for (auto size : normalized_shape) { 81 | ss << ", " << size; 82 | } 83 | ss << "], but got input of size" << input_shape; 84 | throw std::runtime_error(ss.str()); 85 | } 86 | 87 | compute_n1_n2(input,normalized_shape,n1,n2); 88 | } 89 | 90 | 91 | void check_args( 92 | at::Tensor input, 93 | at::IntArrayRef normalized_shape, 94 | at::Tensor gamma, 95 | at::Tensor beta, 96 | int& n1, 97 | int& n2 98 | ) 99 | { 100 | check_args(input,normalized_shape,n1,n2); 101 | check_args(normalized_shape,gamma,beta); 102 | } 103 | } 104 | 105 | void cuda_layer_norm( 106 | at::Tensor* output, 107 | at::Tensor* mean, 108 | at::Tensor* invvar, 109 | at::Tensor* input, 110 | int n1, 111 | int n2, 112 | at::IntArrayRef normalized_shape, 113 | at::Tensor* gamma, 114 | at::Tensor* beta, 115 | double epsilon); 116 | 117 | #define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor") 118 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 119 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 120 | 121 | std::vector layer_norm_affine( 122 | at::Tensor input, 123 | at::IntArrayRef normalized_shape, 124 | at::Tensor gamma, 125 | at::Tensor beta, 126 | double epsilon) { 127 | 128 | CHECK_INPUT(input); 129 | CHECK_INPUT(gamma); 130 | CHECK_INPUT(beta); 131 | int n1, n2; 132 | check_args(input, normalized_shape, gamma, beta, n1, n2); 133 | 134 | at::Tensor output = at::empty_like( 135 | input, gamma.options().dtype(gamma.scalar_type())); 136 | at::Tensor mean = at::empty( 137 | {n1}, input.options().dtype(at::ScalarType::Float)); 138 | at::Tensor invvar = at::empty_like(mean); 139 | 140 | cuda_layer_norm(&output, &mean, &invvar, &input, n1, n2, 141 | normalized_shape, &gamma, &beta, epsilon); 142 | 143 | return {output, mean, invvar}; 144 | 145 | } 146 | 147 | 148 | void cuda_layer_norm_gradient( 149 | at::Tensor* dout, 150 | at::Tensor* mean, 151 
| at::Tensor* invvar, 152 | at::Tensor* input, 153 | int n1, 154 | int n2, 155 | at::IntArrayRef normalized_shape, 156 | at::Tensor* gamma, 157 | at::Tensor* beta, 158 | double epsilon, 159 | at::Tensor* grad_input, 160 | at::Tensor* grad_gamma, 161 | at::Tensor* grad_beta 162 | ); 163 | 164 | std::vector layer_norm_gradient_affine( 165 | at::Tensor dout, 166 | at::Tensor mean, 167 | at::Tensor invvar, 168 | at::Tensor input, 169 | at::IntArrayRef normalized_shape, 170 | at::Tensor gamma, 171 | at::Tensor beta, 172 | double epsilon) { 173 | 174 | CHECK_INPUT(dout); 175 | CHECK_INPUT(mean); 176 | CHECK_INPUT(invvar); 177 | CHECK_INPUT(input); 178 | CHECK_INPUT(gamma); 179 | CHECK_INPUT(beta); 180 | int n1, n2; 181 | check_args(input, normalized_shape, gamma, beta, n1, n2); 182 | 183 | at::Tensor grad_input = at::empty_like(input); 184 | at::Tensor grad_gamma = at::empty_like(gamma); 185 | at::Tensor grad_beta = at::empty_like(beta); 186 | 187 | cuda_layer_norm_gradient(&dout, &mean, &invvar, &input, n1, n2, 188 | normalized_shape, &gamma, &beta, epsilon, 189 | &grad_input, &grad_gamma, &grad_beta); 190 | 191 | return {grad_input, grad_gamma, grad_beta}; 192 | 193 | } 194 | 195 | 196 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 197 | m.def("forward_affine", &layer_norm_affine, 198 | "LayerNorm forward (CUDA)"); 199 | m.def("backward_affine", &layer_norm_gradient_affine, 200 | "LayerNorm backward (CUDA)"); 201 | } 202 | -------------------------------------------------------------------------------- /megatron/fused_kernels/megatron_fused_kernels/load_cpp_extensions.py: -------------------------------------------------------------------------------- 1 | import pathlib 2 | from pybind11.setup_helpers import Pybind11Extension 3 | 4 | srcpath = pathlib.Path(__file__).parent.absolute() 5 | 6 | 7 | def get_helpers_extension(): 8 | """Get the helpers pybind11 extension.""" 9 | return [ 10 | Pybind11Extension( 11 | name="helpers", 12 | sources=[str(srcpath / "helpers.cpp")], 13 | language="c++", 14 | ) 15 | ] 16 | -------------------------------------------------------------------------------- /megatron/fused_kernels/megatron_fused_kernels/load_kernel.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pathlib 3 | from dataclasses import dataclass 4 | from typing import List 5 | from torch.utils import cpp_extension 6 | from pathlib import Path 7 | 8 | 9 | def _create_build_dir(buildpath): 10 | try: 11 | os.mkdir(buildpath) 12 | except OSError: 13 | if not os.path.isdir(buildpath): 14 | print(f"Creation of the build directory {buildpath} failed") 15 | 16 | 17 | @dataclass 18 | class KernelBuildInfo: 19 | name: str 20 | sources: List[Path] 21 | build_directory: str 22 | extra_cflags: List[str] 23 | extra_cuda_cflags: List[str] 24 | verbose: bool 25 | 26 | def __init__(self, name, sources, build_directory, extra_cflags, extra_cuda_cflags, verbose): 27 | self.name = name 28 | self.extra_cflags = extra_cflags 29 | self.extra_cuda_cflags = extra_cuda_cflags 30 | self.verbose = verbose 31 | if not isinstance(build_directory, Path): 32 | build_directory = Path(build_directory) 33 | self.build_directory = build_directory 34 | for i, source in enumerate(sources): 35 | if not isinstance(source, Path): 36 | sources[i] = Path(source) 37 | self.sources = sources 38 | 39 | def __repr__(self): 40 | return f"KernelBuildInfo(name={self.name}, sources={self.sources}, build_directory={self.build_directory}, extra_cflags={self.extra_cflags}, 
extra_cuda_cflags={self.extra_cuda_cflags}, verbose={self.verbose})\n" 41 | 42 | def to_setup_cpp_extension(self) -> cpp_extension.CUDAExtension: 43 | sources = [str(source) for source in self.sources] 44 | return cpp_extension.CUDAExtension( 45 | name=self.name, 46 | sources=sources, 47 | extra_compile_args={ 48 | "cxx": self.extra_cflags, 49 | "nvcc": self.extra_cuda_cflags, 50 | }, 51 | is_python_module=True, 52 | ) 53 | 54 | def load(self): 55 | os.environ["TORCH_CUDA_ARCH_LIST"] = "" 56 | _create_build_dir(self.build_directory) 57 | _ = cpp_extension.load( 58 | name=self.name, 59 | sources=self.sources, 60 | build_directory=Path(self.build_directory), 61 | extra_cflags=self.extra_cflags, 62 | extra_cuda_cflags=self.extra_cuda_cflags, 63 | verbose=self.verbose, 64 | ) 65 | 66 | 67 | srcpath = pathlib.Path(__file__).parent.absolute() 68 | 69 | BUILD_DIR = srcpath / "build" 70 | 71 | BASE_CFLAGS = [ 72 | "-O3", 73 | "-llibtorch_python", 74 | ] 75 | BASE_CUDA_CFLAGS = [ 76 | "-O3", 77 | "--use_fast_math", 78 | "-gencode", 79 | "arch=compute_80,code=sm_80", 80 | ] 81 | 82 | BASE_MASKED_SOFTMAX_FUSION_CUDA_CFLAGS = BASE_CUDA_CFLAGS + [ 83 | "-U__CUDA_NO_HALF_OPERATORS__", 84 | "-U__CUDA_NO_HALF_CONVERSIONS__", 85 | "--expt-relaxed-constexpr", 86 | "--expt-extended-lambda", 87 | ] 88 | 89 | # These are the kernels that we need to build JIT as they have 90 | # some issues in installing with pip. Note: They are defined in the 91 | # ALL_BUILD_KERNELS dictionary below. 92 | JIT_BUILD_KERNELS = [] 93 | 94 | ALL_BUILD_KERNELS = { 95 | "scaled_upper_triang_masked_softmax": KernelBuildInfo( 96 | name="scaled_upper_triang_masked_softmax_cuda", 97 | sources=[ 98 | srcpath / "scaled_upper_triang_masked_softmax.cpp", 99 | srcpath / "scaled_upper_triang_masked_softmax_cuda.cu", 100 | ], 101 | build_directory=BUILD_DIR, 102 | extra_cflags=BASE_CFLAGS, 103 | extra_cuda_cflags=BASE_MASKED_SOFTMAX_FUSION_CUDA_CFLAGS, 104 | verbose=True, 105 | ), 106 | "scaled_masked_softmax_cuda": KernelBuildInfo( 107 | name="scaled_masked_softmax_cuda", 108 | sources=[srcpath / "scaled_masked_softmax.cpp", srcpath / "scaled_masked_softmax_cuda.cu"], 109 | build_directory=BUILD_DIR, 110 | extra_cflags=BASE_CFLAGS, 111 | extra_cuda_cflags=BASE_MASKED_SOFTMAX_FUSION_CUDA_CFLAGS, 112 | verbose=True, 113 | ), 114 | "scaled_softmax_cuda": KernelBuildInfo( 115 | name="scaled_softmax_cuda", 116 | sources=[srcpath / "scaled_softmax.cpp", srcpath / "scaled_softmax_cuda.cu"], 117 | build_directory=BUILD_DIR, 118 | extra_cflags=BASE_CFLAGS, 119 | extra_cuda_cflags=BASE_MASKED_SOFTMAX_FUSION_CUDA_CFLAGS, 120 | verbose=True, 121 | ), 122 | "fused_mix_prec_layer_norm_cuda": KernelBuildInfo( 123 | name="fused_mix_prec_layer_norm_cuda", 124 | sources=[srcpath / "layer_norm_cuda.cpp", srcpath / "layer_norm_cuda_kernel.cu"], 125 | build_directory=BUILD_DIR, 126 | extra_cflags=BASE_CFLAGS, 127 | extra_cuda_cflags=BASE_CUDA_CFLAGS + ["-maxrregcount=50"], 128 | verbose=True, 129 | ), 130 | "fused_dense_cuda": KernelBuildInfo( 131 | name="fused_dense_cuda", 132 | sources=[srcpath / "fused_weight_gradient_dense.cpp", srcpath / "fused_weight_gradient_dense_cuda.cu"], 133 | build_directory=BUILD_DIR, 134 | extra_cflags=BASE_CFLAGS, 135 | extra_cuda_cflags=BASE_CUDA_CFLAGS, 136 | verbose=True, 137 | ), 138 | } 139 | -------------------------------------------------------------------------------- /megatron/fused_kernels/megatron_fused_kernels/scaled_masked_softmax.cpp: -------------------------------------------------------------------------------- 1 
| /* coding=utf-8 2 | * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | #include 18 | #include 19 | #include 20 | 21 | namespace multihead_attn { 22 | namespace fused_softmax { 23 | namespace scaled_masked_softmax { 24 | 25 | torch::Tensor fwd_cuda( 26 | torch::Tensor const& input, 27 | torch::Tensor const& mask, 28 | float scale_factor); 29 | 30 | torch::Tensor bwd_cuda( 31 | torch::Tensor const& output_grads, 32 | torch::Tensor const& softmax_results, 33 | float scale_factor); 34 | 35 | int get_batch_per_block_cuda( 36 | int query_seq_len, 37 | int key_seq_len, 38 | int batches, 39 | int attn_heads); 40 | 41 | torch::Tensor fwd( 42 | torch::Tensor const& input, 43 | torch::Tensor const& mask, 44 | float scale_factor) { 45 | AT_ASSERTM(input.dim() == 4, "expected 4D tensor"); 46 | AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) || 47 | (input.scalar_type() == at::ScalarType::BFloat16), 48 | "Only fp16 and bf16 are supported"); 49 | AT_ASSERTM(mask.dim() == 4, "expected 4D tensor"); 50 | 51 | return fwd_cuda(input, mask, scale_factor); 52 | } 53 | 54 | torch::Tensor bwd( 55 | torch::Tensor const& output_grads, 56 | torch::Tensor const& softmax_results, 57 | float scale_factor) { 58 | 59 | AT_ASSERTM(output_grads.dim() == 4, "expected 3D tensor"); 60 | AT_ASSERTM(softmax_results.dim() == 4, "expected 3D tensor"); 61 | 62 | AT_ASSERTM((output_grads.scalar_type() == at::ScalarType::Half) || 63 | (output_grads.scalar_type() == at::ScalarType::BFloat16), 64 | "Only fp16 and bf16 are supported"); 65 | AT_ASSERTM((softmax_results.scalar_type() == at::ScalarType::Half) || 66 | (softmax_results.scalar_type() == at::ScalarType::BFloat16), 67 | "Only fp16 and bf16 are supported"); 68 | 69 | return bwd_cuda(output_grads, softmax_results, scale_factor); 70 | } 71 | 72 | int get_batch_per_block( 73 | int query_seq_len, 74 | int key_seq_len, 75 | int batches, 76 | int attn_heads) { 77 | return get_batch_per_block_cuda(query_seq_len, key_seq_len, batches, attn_heads); 78 | } 79 | 80 | } // end namespace scaled_masked_softmax 81 | } // end namespace fused_softmax 82 | } // end namespace multihead_attn 83 | 84 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 85 | m.def("forward", 86 | &multihead_attn::fused_softmax::scaled_masked_softmax::fwd, 87 | "Self Multihead Attention scaled, time masked softmax -- Forward."); 88 | 89 | m.def("backward", 90 | &multihead_attn::fused_softmax::scaled_masked_softmax::bwd, 91 | "Self Multihead Attention scaled, time masked softmax -- Backward."); 92 | 93 | m.def("get_batch_per_block", 94 | &multihead_attn::fused_softmax::scaled_masked_softmax::get_batch_per_block, 95 | "Return Batch per block size." 
96 | ); 97 | } 98 | -------------------------------------------------------------------------------- /megatron/fused_kernels/megatron_fused_kernels/scaled_masked_softmax_cuda.cu: -------------------------------------------------------------------------------- 1 | /* coding=utf-8 2 | * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | #include 18 | #include 19 | #include 20 | #include 21 | #include 22 | #include 23 | #include 24 | #include "scaled_masked_softmax.h" 25 | #include "type_shim.h" 26 | 27 | namespace multihead_attn { 28 | namespace fused_softmax { 29 | namespace scaled_masked_softmax { 30 | 31 | int get_batch_per_block_cuda(int query_seq_len, int key_seq_len, int batches, int attn_heads){ 32 | return get_batch_per_block(query_seq_len, key_seq_len, batches, attn_heads); 33 | } 34 | 35 | 36 | torch::Tensor fwd_cuda( 37 | torch::Tensor const& input, 38 | torch::Tensor const& mask, 39 | float scale_factor) 40 | { 41 | // input is a 4d tensor with dimensions [batches, attn_heads, seq_len, seq_len] 42 | const int batches = input.size(0); 43 | const int pad_batches = mask.size(0); 44 | const int attn_heads = input.size(1); 45 | const int query_seq_len = input.size(2); 46 | const int key_seq_len = input.size(3); 47 | TORCH_INTERNAL_ASSERT(key_seq_len <= 8192); 48 | TORCH_INTERNAL_ASSERT(query_seq_len > 1); 49 | TORCH_INTERNAL_ASSERT(pad_batches == 1 || pad_batches == batches); 50 | TORCH_INTERNAL_ASSERT(mask.size(1) == 1); 51 | TORCH_INTERNAL_ASSERT(mask.size(2) == query_seq_len); 52 | TORCH_INTERNAL_ASSERT(mask.size(3) == key_seq_len); 53 | 54 | // Output 55 | auto act_options = input.options().requires_grad(false); 56 | torch::Tensor softmax_results = 57 | torch::empty({batches, attn_heads, query_seq_len, key_seq_len}, act_options); 58 | 59 | // Softmax Intermediate Result Ptr 60 | void* input_ptr = static_cast(input.data_ptr()); 61 | void* mask_ptr = static_cast(mask.data_ptr()); 62 | void* softmax_results_ptr = static_cast(softmax_results.data_ptr()); 63 | 64 | DISPATCH_HALF_AND_BFLOAT( 65 | input.scalar_type(), 66 | "dispatch_scaled_masked_softmax_forward", 67 | dispatch_scaled_masked_softmax_forward( 68 | reinterpret_cast(softmax_results_ptr), 69 | reinterpret_cast(input_ptr), 70 | reinterpret_cast(mask_ptr), 71 | scale_factor, 72 | query_seq_len, 73 | key_seq_len, 74 | batches, 75 | attn_heads, 76 | pad_batches); 77 | ); 78 | return softmax_results; 79 | } 80 | 81 | torch::Tensor bwd_cuda( 82 | torch::Tensor const& output_grads_, 83 | torch::Tensor const& softmax_results_, 84 | float scale_factor) { 85 | 86 | auto output_grads = output_grads_.contiguous(); 87 | auto softmax_results = softmax_results_.contiguous(); 88 | 89 | //output grads is a 4d tensor with dimensions [batches, attn_heads, seq_len, seq_len] 90 | const int batches = output_grads.size(0); 91 | const int attn_heads = output_grads.size(1); 92 | const int query_seq_len = output_grads.size(2); 93 | const 
int key_seq_len = output_grads.size(3); 94 | 95 | void* output_grads_ptr = static_cast(output_grads.data_ptr()); 96 | 97 | //Softmax Grad 98 | DISPATCH_HALF_AND_BFLOAT( 99 | output_grads_.scalar_type(), 100 | "dispatch_scaled_masked_softmax_backward", 101 | dispatch_scaled_masked_softmax_backward( 102 | reinterpret_cast(output_grads_ptr), 103 | reinterpret_cast(output_grads_ptr), 104 | reinterpret_cast(softmax_results.data_ptr()), 105 | scale_factor, 106 | query_seq_len, 107 | key_seq_len, 108 | batches, 109 | attn_heads); 110 | ); 111 | 112 | //backward pass is completely in-place 113 | return output_grads; 114 | } 115 | } 116 | } 117 | } 118 | -------------------------------------------------------------------------------- /megatron/fused_kernels/megatron_fused_kernels/scaled_softmax.cpp: -------------------------------------------------------------------------------- 1 | /* coding=utf-8 2 | * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | #include 18 | #include 19 | #include 20 | 21 | namespace multihead_attn { 22 | namespace fused_softmax { 23 | namespace scaled_softmax { 24 | 25 | torch::Tensor fwd_cuda( 26 | torch::Tensor const& input, 27 | float scale_factor); 28 | 29 | torch::Tensor bwd_cuda( 30 | torch::Tensor const& output_grads, 31 | torch::Tensor const& softmax_results, 32 | float scale_factor); 33 | 34 | torch::Tensor fwd( 35 | torch::Tensor const& input, 36 | float scale_factor) { 37 | AT_ASSERTM(input.dim() == 4, "expected 4D tensor"); 38 | AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) || 39 | (input.scalar_type() == at::ScalarType::BFloat16), 40 | "Only fp16 and bf16 are supported"); 41 | 42 | return fwd_cuda(input, scale_factor); 43 | } 44 | 45 | torch::Tensor bwd( 46 | torch::Tensor const& output_grads, 47 | torch::Tensor const& softmax_results, 48 | float scale_factor) { 49 | 50 | AT_ASSERTM(output_grads.dim() == 4, "expected 3D tensor"); 51 | AT_ASSERTM(softmax_results.dim() == 4, "expected 3D tensor"); 52 | 53 | AT_ASSERTM((output_grads.scalar_type() == at::ScalarType::Half) || 54 | (output_grads.scalar_type() == at::ScalarType::BFloat16), 55 | "Only fp16 and bf16 are supported"); 56 | AT_ASSERTM((softmax_results.scalar_type() == at::ScalarType::Half) || 57 | (softmax_results.scalar_type() == at::ScalarType::BFloat16), 58 | "Only fp16 and bf16 are supported"); 59 | 60 | return bwd_cuda(output_grads, softmax_results, scale_factor); 61 | } 62 | 63 | } // end namespace scaled_softmax 64 | } // end namespace fused_softmax 65 | } // end namespace multihead_attn 66 | 67 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 68 | m.def("forward", 69 | &multihead_attn::fused_softmax::scaled_softmax::fwd, 70 | "Self Multihead Attention scaled, softmax -- Forward."); 71 | m.def("backward", 72 | &multihead_attn::fused_softmax::scaled_softmax::bwd, 73 | "Self Multihead Attention scaled, softmax -- Backward."); 74 | } 75 | 
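
The Python-facing entry points of this extension are the `forward`/`backward` bindings registered above (in the repository these are presumably consumed by `megatron/model/fused_softmax.py`, whose contents are not reproduced in this listing). As a rough illustration only, with an invented wrapper class name and only the module name and call signatures taken from the PYBIND11 bindings above, a minimal autograd wrapper could look like:

```
# Illustrative sketch only: the class name is invented; the extension module
# name and the forward/backward signatures come from the bindings above.
import torch
import scaled_softmax_cuda  # built from scaled_softmax.cpp / scaled_softmax_cuda.cu


class ScaledSoftmaxSketch(torch.autograd.Function):
    @staticmethod
    def forward(ctx, inputs, scale):
        # inputs: [batches, attn_heads, query_seq_len, key_seq_len] in fp16/bf16.
        softmax_results = scaled_softmax_cuda.forward(inputs, scale)
        ctx.save_for_backward(softmax_results)
        ctx.scale = scale
        return softmax_results

    @staticmethod
    def backward(ctx, output_grads):
        (softmax_results,) = ctx.saved_tensors
        # The CUDA kernel computes the gradient in place on (a contiguous copy
        # of) output_grads and returns it as the input gradient.
        input_grads = scaled_softmax_cuda.backward(output_grads, softmax_results, ctx.scale)
        return input_grads, None  # no gradient w.r.t. the scale factor
```

`ScaledSoftmaxSketch.apply(scores, scale)` would then stand in for a plain `torch.softmax(scores * scale, dim=-1)` on 4D fp16/bf16 inputs, keeping the scaling and softmax fused in one kernel.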
-------------------------------------------------------------------------------- /megatron/fused_kernels/megatron_fused_kernels/scaled_softmax_cuda.cu: -------------------------------------------------------------------------------- 1 | /* coding=utf-8 2 | * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | #include 18 | #include 19 | #include 20 | #include 21 | #include 22 | #include 23 | #include 24 | #include "scaled_masked_softmax.h" 25 | #include "type_shim.h" 26 | 27 | namespace multihead_attn { 28 | namespace fused_softmax { 29 | namespace scaled_softmax { 30 | 31 | torch::Tensor fwd_cuda( 32 | torch::Tensor const& input, 33 | float scale_factor) 34 | { 35 | // input is a 4d tensor with dimensions [batches, attn_heads, seq_len, seq_len] 36 | const int batches = input.size(0); 37 | const int attn_heads = input.size(1); 38 | const int query_seq_len = input.size(2); 39 | const int key_seq_len = input.size(3); 40 | TORCH_INTERNAL_ASSERT(key_seq_len <= 8192); 41 | TORCH_INTERNAL_ASSERT(query_seq_len > 1); 42 | 43 | // Output 44 | auto act_options = input.options().requires_grad(false); 45 | torch::Tensor softmax_results = 46 | torch::empty({batches, attn_heads, query_seq_len, key_seq_len}, act_options); 47 | 48 | // Softmax Intermediate Result Ptr 49 | void* input_ptr = static_cast(input.data_ptr()); 50 | void* softmax_results_ptr = static_cast(softmax_results.data_ptr()); 51 | 52 | DISPATCH_HALF_AND_BFLOAT( 53 | input.scalar_type(), 54 | "dispatch_scaled_softmax_forward", 55 | dispatch_scaled_softmax_forward( 56 | reinterpret_cast(softmax_results_ptr), 57 | reinterpret_cast(input_ptr), 58 | scale_factor, 59 | query_seq_len, 60 | key_seq_len, 61 | batches, 62 | attn_heads); 63 | ); 64 | return softmax_results; 65 | } 66 | 67 | torch::Tensor bwd_cuda( 68 | torch::Tensor const& output_grads_, 69 | torch::Tensor const& softmax_results_, 70 | float scale_factor) { 71 | 72 | auto output_grads = output_grads_.contiguous(); 73 | auto softmax_results = softmax_results_.contiguous(); 74 | 75 | //output grads is a 4d tensor with dimensions [batches, attn_heads, seq_len, seq_len] 76 | const int batches = output_grads.size(0); 77 | const int attn_heads = output_grads.size(1); 78 | const int query_seq_len = output_grads.size(2); 79 | const int key_seq_len = output_grads.size(3); 80 | 81 | void* output_grads_ptr = static_cast(output_grads.data_ptr()); 82 | 83 | //Softmax Grad 84 | DISPATCH_HALF_AND_BFLOAT( 85 | output_grads_.scalar_type(), 86 | "dispatch_scaled_masked_softmax_backward", 87 | dispatch_scaled_masked_softmax_backward( 88 | reinterpret_cast(output_grads_ptr), 89 | reinterpret_cast(output_grads_ptr), 90 | reinterpret_cast(softmax_results.data_ptr()), 91 | scale_factor, 92 | query_seq_len, 93 | key_seq_len, 94 | batches, 95 | attn_heads); 96 | ); 97 | 98 | //backward pass is completely in-place 99 | return output_grads; 100 | } 101 | } 102 | } 103 | } 104 | 
-------------------------------------------------------------------------------- /megatron/fused_kernels/megatron_fused_kernels/scaled_upper_triang_masked_softmax.cpp: -------------------------------------------------------------------------------- 1 | /* coding=utf-8 2 | * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | #include 18 | #include 19 | #include 20 | 21 | namespace multihead_attn { 22 | namespace fused_softmax { 23 | namespace scaled_upper_triang_masked_softmax { 24 | 25 | torch::Tensor fwd_cuda( 26 | torch::Tensor const& input, 27 | float scale_factor); 28 | 29 | torch::Tensor bwd_cuda( 30 | torch::Tensor const& output_grads, 31 | torch::Tensor const& softmax_results, 32 | float scale_factor); 33 | 34 | torch::Tensor fwd(torch::Tensor const& input, float scale_factor) { 35 | AT_ASSERTM(input.dim() == 3, "expected 3D tensor"); 36 | AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) || 37 | (input.scalar_type() == at::ScalarType::BFloat16), 38 | "Only fp16 and bf16 are supported"); 39 | 40 | return fwd_cuda(input, scale_factor); 41 | } 42 | 43 | torch::Tensor bwd( 44 | torch::Tensor const& output_grads, 45 | torch::Tensor const& softmax_results, 46 | float scale_factor) { 47 | 48 | AT_ASSERTM(output_grads.dim() == 3, "expected 3D tensor"); 49 | AT_ASSERTM(softmax_results.dim() == 3, "expected 3D tensor"); 50 | 51 | AT_ASSERTM((output_grads.scalar_type() == at::ScalarType::Half) || 52 | (output_grads.scalar_type() == at::ScalarType::BFloat16), 53 | "Only fp16 and bf16 are supported"); 54 | AT_ASSERTM((softmax_results.scalar_type() == at::ScalarType::Half) || 55 | (softmax_results.scalar_type() == at::ScalarType::BFloat16), 56 | "Only fp16 and bf16 are supported"); 57 | 58 | return bwd_cuda(output_grads, softmax_results, scale_factor); 59 | } 60 | 61 | } // end namespace scaled_upper_triang_masked_softmax 62 | } // end namespace fused_softmax 63 | } // end namespace multihead_attn 64 | 65 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 66 | m.def("forward", 67 | &multihead_attn::fused_softmax::scaled_upper_triang_masked_softmax::fwd, 68 | "Self Multihead Attention scaled, time masked softmax -- Forward."); 69 | m.def("backward", 70 | &multihead_attn::fused_softmax::scaled_upper_triang_masked_softmax::bwd, 71 | "Self Multihead Attention scaled, time masked softmax -- Backward."); 72 | } 73 | -------------------------------------------------------------------------------- /megatron/fused_kernels/megatron_fused_kernels/scaled_upper_triang_masked_softmax_cuda.cu: -------------------------------------------------------------------------------- 1 | /* coding=utf-8 2 | * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 
6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | #include 18 | #include 19 | #include 20 | #include 21 | #include 22 | #include 23 | #include 24 | #include "scaled_upper_triang_masked_softmax.h" 25 | #include "type_shim.h" 26 | 27 | namespace multihead_attn { 28 | namespace fused_softmax { 29 | namespace scaled_upper_triang_masked_softmax { 30 | 31 | torch::Tensor fwd_cuda( 32 | torch::Tensor const& input, 33 | float scale_factor) 34 | { 35 | // input is a 3d tensor with dimensions [attn_batches, seq_len, seq_len] 36 | const int attn_batches = input.size(0); 37 | const int seq_len = input.size(1); 38 | TORCH_INTERNAL_ASSERT(seq_len <= 8192); 39 | 40 | // Output 41 | auto act_options = input.options().requires_grad(false); 42 | torch::Tensor softmax_results = 43 | torch::empty({attn_batches, seq_len, seq_len}, act_options); 44 | 45 | // Softmax Intermediate Result Ptr 46 | void* input_ptr = static_cast(input.data_ptr()); 47 | void* softmax_results_ptr = static_cast(softmax_results.data_ptr()); 48 | 49 | DISPATCH_HALF_AND_BFLOAT( 50 | input.scalar_type(), 51 | "dispatch_scaled_upper_triang_masked_softmax_forward", 52 | dispatch_scaled_upper_triang_masked_softmax_forward( 53 | reinterpret_cast(softmax_results_ptr), 54 | reinterpret_cast(input_ptr), 55 | scale_factor, 56 | seq_len, 57 | seq_len, 58 | attn_batches); 59 | ); 60 | return softmax_results; 61 | } 62 | 63 | 64 | torch::Tensor bwd_cuda( 65 | torch::Tensor const& output_grads_, 66 | torch::Tensor const& softmax_results_, 67 | float scale_factor) { 68 | 69 | auto output_grads = output_grads_.contiguous(); 70 | auto softmax_results = softmax_results_.contiguous(); 71 | 72 | //output grads is a 3d tensor with dimensions [attn_batches, seq_len, seq_len] 73 | const int attn_batches = output_grads.size(0); 74 | const int seq_len = output_grads.size(1); 75 | TORCH_INTERNAL_ASSERT(output_grads.size(1) == output_grads.size(2)); 76 | 77 | void* output_grads_ptr = static_cast(output_grads.data_ptr()); 78 | 79 | //Softmax Grad 80 | DISPATCH_HALF_AND_BFLOAT( 81 | output_grads_.scalar_type(), 82 | "dispatch_scaled_upper_triang_masked_softmax_backward", 83 | dispatch_scaled_upper_triang_masked_softmax_backward( 84 | reinterpret_cast(output_grads_ptr), 85 | reinterpret_cast(output_grads_ptr), 86 | reinterpret_cast(softmax_results.data_ptr()), 87 | scale_factor, 88 | seq_len, 89 | seq_len, 90 | attn_batches); 91 | ); 92 | 93 | //backward pass is completely in-place 94 | return output_grads; 95 | } 96 | } 97 | } 98 | } 99 | -------------------------------------------------------------------------------- /megatron/fused_kernels/megatron_fused_kernels/tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/persimmon-ai-labs/adept-inference/61743d07cfb151dadb0cd2ae9de8f7325c4e828a/megatron/fused_kernels/megatron_fused_kernels/tests/__init__.py -------------------------------------------------------------------------------- /megatron/fused_kernels/megatron_fused_kernels/type_shim.h: -------------------------------------------------------------------------------- 1 | /* 
coding=utf-8 2 | * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | 18 | #include 19 | #include "compat.h" 20 | 21 | 22 | #define DISPATCH_HALF_AND_BFLOAT(TYPE, NAME, ...) \ 23 | switch(TYPE) \ 24 | { \ 25 | case at::ScalarType::Half: \ 26 | { \ 27 | using scalar_t = at::Half; \ 28 | __VA_ARGS__; \ 29 | break; \ 30 | } \ 31 | case at::ScalarType::BFloat16: \ 32 | { \ 33 | using scalar_t = at::BFloat16; \ 34 | __VA_ARGS__; \ 35 | break; \ 36 | } \ 37 | default: \ 38 | AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \ 39 | } 40 | 41 | 42 | #define DISPATCH_HALF_BFLOAT_AND_FLOAT(TYPE, NAME, ...) \ 43 | switch(TYPE) \ 44 | { \ 45 | case at::ScalarType::Half: \ 46 | { \ 47 | using scalar_t = at::Half; \ 48 | __VA_ARGS__; \ 49 | break; \ 50 | } \ 51 | case at::ScalarType::BFloat16: \ 52 | { \ 53 | using scalar_t = at::BFloat16; \ 54 | __VA_ARGS__; \ 55 | break; \ 56 | } \ 57 | case at::ScalarType::Float: \ 58 | { \ 59 | using scalar_t = float; \ 60 | __VA_ARGS__; \ 61 | break; \ 62 | } \ 63 | default: \ 64 | AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \ 65 | } 66 | 67 | 68 | 69 | #define DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(TYPEIN, TYPEOUT, NAME, ...) 
\ 70 | switch(TYPEIN) \ 71 | { \ 72 | case at::ScalarType::Float: \ 73 | { \ 74 | using scalar_t_in = float; \ 75 | switch(TYPEOUT) \ 76 | { \ 77 | case at::ScalarType::Float: \ 78 | { \ 79 | using scalar_t_out = float; \ 80 | __VA_ARGS__; \ 81 | break; \ 82 | } \ 83 | case at::ScalarType::Half: \ 84 | { \ 85 | using scalar_t_out = at::Half; \ 86 | __VA_ARGS__; \ 87 | break; \ 88 | } \ 89 | case at::ScalarType::BFloat16: \ 90 | { \ 91 | using scalar_t_out = at::BFloat16; \ 92 | __VA_ARGS__; \ 93 | break; \ 94 | } \ 95 | default: \ 96 | AT_ERROR(#NAME, " not implemented for '", toString(TYPEOUT), "'"); \ 97 | } \ 98 | break; \ 99 | } \ 100 | case at::ScalarType::Half: \ 101 | { \ 102 | using scalar_t_in = at::Half; \ 103 | using scalar_t_out = at::Half; \ 104 | __VA_ARGS__; \ 105 | break; \ 106 | } \ 107 | case at::ScalarType::BFloat16: \ 108 | { \ 109 | using scalar_t_in = at::BFloat16; \ 110 | using scalar_t_out = at::BFloat16; \ 111 | __VA_ARGS__; \ 112 | break; \ 113 | } \ 114 | default: \ 115 | AT_ERROR(#NAME, " not implemented for '", toString(TYPEIN), "'"); \ 116 | } 117 | -------------------------------------------------------------------------------- /megatron/fused_kernels/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | from torch.utils import cpp_extension 3 | from megatron_fused_kernels.load_kernel import ALL_BUILD_KERNELS, JIT_BUILD_KERNELS 4 | from megatron_fused_kernels.load_cpp_extensions import get_helpers_extension 5 | 6 | 7 | def get_kernel_extensions(): 8 | # remove 'fused_dense_cuda' from the list of kernels to build 9 | # as it is not used in the current version of Megatron-LM 10 | # and it causes compilation errors when doing ninja installs. 11 | # We'll need to JIT compile it instead. 12 | kernels_to_build = [ 13 | build_info.to_setup_cpp_extension() for k, build_info in ALL_BUILD_KERNELS.items() if k not in JIT_BUILD_KERNELS 14 | ] 15 | return kernels_to_build 16 | 17 | 18 | setup( 19 | name="megatron_fused_kernels", 20 | packages=find_packages(exclude=("tests",)), 21 | packages_dir={"megatron_fused_kernels": "megatron_fused_kernels"}, 22 | ext_modules=get_kernel_extensions() + get_helpers_extension(), 23 | cmdclass={"build_ext": cpp_extension.BuildExtension}, 24 | zip_safe=False, 25 | author="Adept AI", 26 | author_email="", 27 | ) 28 | -------------------------------------------------------------------------------- /megatron/global_vars.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) 2023 ADEPT AI LABS INC. 3 | # This file is based on code by the authors denoted below and has been modified from its original version. 4 | # 5 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. 6 | # 7 | # Licensed under the Apache License, Version 2.0 (the "License"); 8 | # you may not use this file except in compliance with the License. 9 | # You may obtain a copy of the License at 10 | # 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | # 13 | # Unless required by applicable law or agreed to in writing, software 14 | # distributed under the License is distributed on an "AS IS" BASIS, 15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 16 | # See the License for the specific language governing permissions and 17 | # limitations under the License. 
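setup.py above pre-builds every kernel except the ones listed in JIT_BUILD_KERNELS, which are compiled at runtime by load_kernel.py / load_cpp_extensions.py (not included in this dump). As a hedged, generic sketch of what such a JIT build looks like with PyTorch's extension loader, using one of the softmax kernels purely as an illustration and a made-up compiler flag:

```
# Generic JIT-build sketch; the real source lists and flags live in
# load_kernel.py / load_cpp_extensions.py, which are not shown here.
from torch.utils.cpp_extension import load

scaled_upper_triang_masked_softmax_cuda = load(
    name="scaled_upper_triang_masked_softmax_cuda",
    sources=[
        "megatron/fused_kernels/megatron_fused_kernels/scaled_upper_triang_masked_softmax.cpp",
        "megatron/fused_kernels/megatron_fused_kernels/scaled_upper_triang_masked_softmax_cuda.cu",
    ],
    extra_cuda_cflags=["-O3"],  # illustrative flag, not the repo's actual setting
    verbose=True,
)
```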
18 | 19 | """Megatron global variables.""" 20 | 21 | import functools 22 | import operator 23 | import os 24 | import sys 25 | import time 26 | from functools import reduce 27 | from pathlib import Path 28 | import requests 29 | import torch 30 | import yaml 31 | 32 | from megatron.tokenizer import build_tokenizer 33 | 34 | from .microbatches import build_num_microbatches_calculator 35 | 36 | _GLOBAL_ARGS = None 37 | _GLOBAL_NUM_MICROBATCHES_CALCULATOR = None 38 | _GLOBAL_TOKENIZER = None 39 | _GLOBAL_TENSORBOARD_WRITER = None 40 | _GLOBAL_ADLR_AUTORESUME = None 41 | _GLOBAL_TIMERS = None 42 | 43 | _GLOBAL_MEMORY_BUFFER = None 44 | 45 | 46 | def get_args(): 47 | """Return arguments.""" 48 | _ensure_var_is_initialized(_GLOBAL_ARGS, "args") 49 | return _GLOBAL_ARGS 50 | 51 | 52 | def get_num_microbatches(): 53 | return _GLOBAL_NUM_MICROBATCHES_CALCULATOR.get() 54 | 55 | 56 | def get_current_global_batch_size(): 57 | return _GLOBAL_NUM_MICROBATCHES_CALCULATOR.get_current_global_batch_size() 58 | 59 | 60 | def update_num_microbatches(consumed_samples, consistency_check=True): 61 | _GLOBAL_NUM_MICROBATCHES_CALCULATOR.update(consumed_samples, consistency_check) 62 | 63 | 64 | def get_tokenizer(): 65 | """Return tokenizer.""" 66 | _ensure_var_is_initialized(_GLOBAL_TOKENIZER, "tokenizer") 67 | return _GLOBAL_TOKENIZER 68 | 69 | 70 | def get_tensorboard_writer(): 71 | """Return tensorboard writer. It can be None so no need 72 | to check if it is initialized.""" 73 | return _GLOBAL_TENSORBOARD_WRITER 74 | 75 | 76 | def get_adlr_autoresume(): 77 | """ADLR autoresume object. It can be None so no need 78 | to check if it is initialized.""" 79 | return _GLOBAL_ADLR_AUTORESUME 80 | 81 | 82 | def get_timers(): 83 | """Return timers.""" 84 | _ensure_var_is_initialized(_GLOBAL_TIMERS, "timers") 85 | return _GLOBAL_TIMERS 86 | 87 | 88 | def get_global_memory_buffer(): 89 | _ensure_var_is_initialized(_GLOBAL_MEMORY_BUFFER, "global memory buffer") 90 | return _GLOBAL_MEMORY_BUFFER 91 | 92 | 93 | def set_global_variables(args): 94 | """Set args, tokenizer, tensorboard-writer, adlr-autoresume, and timers.""" 95 | 96 | assert args is not None 97 | 98 | _ensure_var_is_not_initialized(_GLOBAL_ARGS, "args") 99 | set_args(args) 100 | 101 | _build_num_microbatches_calculator(args) 102 | if args.vocab_file or args.sp_model_file: 103 | _ = _build_tokenizer(args) 104 | _set_tensorboard_writer(args) 105 | _set_adlr_autoresume(args) 106 | _set_timers() 107 | _set_global_memory_buffer() 108 | 109 | 110 | def set_args(args): 111 | global _GLOBAL_ARGS 112 | _GLOBAL_ARGS = args 113 | 114 | 115 | def _build_num_microbatches_calculator(args): 116 | 117 | global _GLOBAL_NUM_MICROBATCHES_CALCULATOR 118 | _ensure_var_is_not_initialized( 119 | _GLOBAL_NUM_MICROBATCHES_CALCULATOR, "num microbatches calculator" 120 | ) 121 | 122 | _GLOBAL_NUM_MICROBATCHES_CALCULATOR = build_num_microbatches_calculator(args) 123 | 124 | 125 | def _build_tokenizer(args): 126 | """Initialize tokenizer.""" 127 | global _GLOBAL_TOKENIZER 128 | _ensure_var_is_not_initialized(_GLOBAL_TOKENIZER, "tokenizer") 129 | _GLOBAL_TOKENIZER = build_tokenizer(args) 130 | return _GLOBAL_TOKENIZER 131 | 132 | 133 | def rebuild_tokenizer(args): 134 | global _GLOBAL_TOKENIZER 135 | _GLOBAL_TOKENIZER = None 136 | return _build_tokenizer(args) 137 | 138 | 139 | def _set_tensorboard_writer(args): 140 | """Set tensorboard writer.""" 141 | global _GLOBAL_TENSORBOARD_WRITER 142 | _ensure_var_is_not_initialized(_GLOBAL_TENSORBOARD_WRITER, "tensorboard writer") 143 | 144 | if ( 
145 | hasattr(args, "tensorboard_dir") 146 | and args.tensorboard_dir 147 | and args.rank == (args.world_size - 1) 148 | ): 149 | try: 150 | from torch.utils.tensorboard import ( 151 | SummaryWriter, 152 | ) # pylint: disable=import-outside-toplevel 153 | 154 | print("> setting tensorboard ...") 155 | _GLOBAL_TENSORBOARD_WRITER = SummaryWriter( 156 | log_dir=args.tensorboard_dir, 157 | max_queue=args.tensorboard_queue_size, 158 | ) 159 | except ModuleNotFoundError: 160 | print( 161 | "WARNING: TensorBoard writing requested but is not " 162 | "available (are you using PyTorch 1.1.0 or later?), " 163 | "no TensorBoard logs will be written.", 164 | flush=True, 165 | ) 166 | 167 | 168 | def _set_adlr_autoresume(args): 169 | """Initialize ADLR autoresume.""" 170 | global _GLOBAL_ADLR_AUTORESUME 171 | _ensure_var_is_not_initialized(_GLOBAL_ADLR_AUTORESUME, "adlr autoresume") 172 | 173 | if args.adlr_autoresume: 174 | if args.rank == 0: 175 | print("enabling autoresume ...", flush=True) 176 | sys.path.append(os.environ.get("SUBMIT_SCRIPTS", ".")) 177 | try: 178 | from userlib.auto_resume import ( 179 | AutoResume, 180 | ) # pylint: disable=import-outside-toplevel 181 | except BaseException: # pylint: disable=broad-except 182 | print("ADLR autoresume is not available, exiting ...") 183 | sys.exit() 184 | 185 | _GLOBAL_ADLR_AUTORESUME = AutoResume 186 | 187 | 188 | def _set_timers(): 189 | """Initialize timers.""" 190 | global _GLOBAL_TIMERS 191 | _ensure_var_is_not_initialized(_GLOBAL_TIMERS, "timers") 192 | _GLOBAL_TIMERS = Timers() 193 | 194 | 195 | def _set_global_memory_buffer(): 196 | """Initialize global buffer""" 197 | global _GLOBAL_MEMORY_BUFFER 198 | _ensure_var_is_not_initialized(_GLOBAL_MEMORY_BUFFER, "global memory buffer") 199 | _GLOBAL_MEMORY_BUFFER = GlobalMemoryBuffer() 200 | 201 | 202 | def _ensure_var_is_initialized(var, name): 203 | """Make sure the input variable is not None.""" 204 | assert var is not None, f"{name} is not initialized." 205 | 206 | 207 | def _ensure_var_is_not_initialized(var, name): 208 | """Make sure the input variable is not None.""" 209 | assert var is None, f"{name} is already initialized." 210 | 211 | 212 | class _Timer: 213 | """Timer.""" 214 | 215 | def __init__(self, name): 216 | self.name_ = name 217 | self.elapsed_ = 0.0 218 | self.started_ = False 219 | self.start_time = time.time() 220 | 221 | def start(self): 222 | """Start the timer.""" 223 | # this import has to be here because of circular dependencies. 224 | from megatron.mpu import ( 225 | get_data_parallel_group, 226 | ) # pylint: disable=import-outside-toplevel 227 | 228 | assert not self.started_, "timer has already been started" 229 | torch.distributed.barrier(get_data_parallel_group()) 230 | torch.cuda.synchronize() 231 | self.start_time = time.time() 232 | self.started_ = True 233 | 234 | def stop(self): 235 | """Stop the timer.""" 236 | # this import has to be here because of circular dependencies. 
237 | from megatron.mpu import ( 238 | get_data_parallel_group, 239 | ) # pylint: disable=import-outside-toplevel 240 | 241 | assert self.started_, "timer is not started" 242 | torch.distributed.barrier(get_data_parallel_group()) 243 | torch.cuda.synchronize() 244 | self.elapsed_ += time.time() - self.start_time 245 | self.started_ = False 246 | 247 | def reset(self): 248 | """Reset timer.""" 249 | self.elapsed_ = 0.0 250 | self.started_ = False 251 | 252 | def elapsed(self, reset=True): 253 | """Calculate the elapsed time.""" 254 | started = self.started_ 255 | # If the timing in progress, end it first. 256 | if self.started_: 257 | self.stop() 258 | # Get the elapsed time. 259 | elapsed = self.elapsed_ 260 | # Reset the elapsed time 261 | if reset: 262 | self.reset() 263 | # If timing was in progress, set it back. 264 | if started: 265 | self.start() 266 | return elapsed 267 | 268 | 269 | class Timers: 270 | """Group of timers.""" 271 | 272 | def __init__(self): 273 | self.timers = {} 274 | 275 | def __call__(self, name): 276 | if name not in self.timers: 277 | self.timers[name] = _Timer(name) 278 | return self.timers[name] 279 | 280 | def write(self, names, iteration, normalizer=1.0, reset=False): 281 | """Write timers to a tensorboard writer""" 282 | # currently when using add_scalars, 283 | # torch.utils.add_scalars makes each timer its own run, which 284 | # polutes the runs list, so we just add each as a scalar 285 | assert normalizer > 0.0 286 | for name in names: 287 | value = self.timers[name].elapsed(reset=reset) / normalizer 288 | key = f"timers/{name}-(s)" 289 | 290 | def log(self, names, normalizer=1.0, reset=True): 291 | """Log a group of timers.""" 292 | assert normalizer > 0.0 293 | string = "time (ms)" 294 | for name in names: 295 | elapsed_time = self.timers[name].elapsed(reset=reset) * 1000.0 / normalizer 296 | string += f" | {name}: {elapsed_time:.2f}" 297 | if torch.distributed.is_initialized(): 298 | if torch.distributed.get_rank() == (torch.distributed.get_world_size() - 1): 299 | print(string, flush=True) 300 | else: 301 | print(string, flush=True) 302 | 303 | 304 | class GlobalMemoryBuffer: 305 | """Global buffer to avoid dynamic memory allocations. 306 | Caller should ensure that buffers of the same name 307 | are not used concurrently.""" 308 | 309 | def __init__(self): 310 | self.buffer = {} 311 | 312 | def get_tensor(self, tensor_shape, dtype, name): 313 | required_len = reduce(operator.mul, tensor_shape, 1) 314 | if ( 315 | self.buffer.get((name, dtype), None) is None 316 | or self.buffer[(name, dtype)].numel() < required_len 317 | ): 318 | self.buffer[(name, dtype)] = torch.empty( 319 | required_len, 320 | dtype=dtype, 321 | device=torch.cuda.current_device(), 322 | requires_grad=False, 323 | ) 324 | 325 | return self.buffer[(name, dtype)][0:required_len].view(*tensor_shape) 326 | -------------------------------------------------------------------------------- /megatron/memory.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) 2023 ADEPT AI LABS INC. 3 | # This file is based on code by the authors denoted below and has been modified from its original version. 4 | # 5 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. 6 | # 7 | # Licensed under the Apache License, Version 2.0 (the "License"); 8 | # you may not use this file except in compliance with the License. 
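GlobalMemoryBuffer above hands out views into one cached CUDA allocation per (name, dtype) pair and grows it on demand; the class docstring's warning means two tensors that are live at the same time must not share a name. A minimal sketch of that contract (the buffer name here is made up, and a CUDA device is assumed since allocation happens on torch.cuda.current_device()):

```
# Minimal sketch of the GlobalMemoryBuffer contract above (CUDA required).
import torch
from megatron.global_vars import GlobalMemoryBuffer

buf = GlobalMemoryBuffer()
a = buf.get_tensor((4, 1024), torch.bfloat16, "attn-scratch")  # allocates 4096 elements
b = buf.get_tensor((8, 1024), torch.bfloat16, "attn-scratch")  # needs 8192, so it reallocates
c = buf.get_tensor((2, 1024), torch.bfloat16, "attn-scratch")  # view into the same storage as b
# b and c alias the same cached storage (a still holds the older, smaller
# allocation), which is why concurrent use of one name is unsafe.
```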
9 | # You may obtain a copy of the License at 10 | # 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | # 13 | # Unless required by applicable law or agreed to in writing, software 14 | # distributed under the License is distributed on an "AS IS" BASIS, 15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 16 | # See the License for the specific language governing permissions and 17 | # limitations under the License. 18 | 19 | 20 | import torch 21 | 22 | 23 | # A dictionary of all the memory buffers allocated. 24 | _MEM_BUFFS = dict() 25 | 26 | 27 | def allocate_mem_buff(name, numel, dtype, track_usage): 28 | """Allocate a memory buffer.""" 29 | assert name not in _MEM_BUFFS, "memory buffer {} already allocated.".format(name) 30 | _MEM_BUFFS[name] = MemoryBuffer(name, numel, dtype, track_usage) 31 | return _MEM_BUFFS[name] 32 | 33 | 34 | class MemoryBuffer: 35 | """Contiguous memory buffer. 36 | Allocate a contiguous memory of type `dtype` and size `numel`. It is 37 | used to reduce memory fragmentation. 38 | 39 | Usage: After the allocation, the `_start` index is set tot the first 40 | index of the memory. A memory chunk starting from `_start` index 41 | can be `allocated` for an input tensor, with the elements of the 42 | tensor being coppied. The buffer can be reused by resetting the 43 | `_start` index. 44 | 45 | """ 46 | 47 | def __init__(self, name, numel, dtype, track_usage): 48 | if torch.distributed.get_rank() == 0: 49 | element_size = torch.tensor([], dtype=dtype).element_size() 50 | print( 51 | "> building the {} memory buffer with {} num elements " 52 | "and {} dtype ({:.1f} MB)...".format(name, numel, dtype, numel * element_size / 1024 / 1024), 53 | flush=True, 54 | ) 55 | self.name = name 56 | self.numel = numel 57 | self.dtype = dtype 58 | self.data = torch.empty(self.numel, dtype=self.dtype, device=torch.cuda.current_device(), requires_grad=False) 59 | 60 | # Index tracking the start of the free memory. 61 | self._start = 0 62 | 63 | # Values used for tracking usage. 64 | self.track_usage = track_usage 65 | if self.track_usage: 66 | self.in_use_value = 0.0 67 | self.total_value = 0.0 68 | 69 | def reset(self): 70 | """Reset the buffer start index to the beginning of the buffer.""" 71 | self._start = 0 72 | 73 | def is_in_use(self): 74 | """Whether the current buffer hold on to any memory.""" 75 | return self._start > 0 76 | 77 | def numel_in_use(self): 78 | """Return number of elements in use.""" 79 | return self._start 80 | 81 | def add(self, tensor): 82 | """Allocate a chunk of memory from the buffer to tensor and copy 83 | the values.""" 84 | assert tensor.dtype == self.dtype, "Input tensor type {} different from buffer type {}".format( 85 | tensor.dtype, self.dtype 86 | ) 87 | # Number of elements of the input tensor. 88 | tensor_numel = torch.numel(tensor) 89 | new_start = self._start + tensor_numel 90 | assert new_start <= self.numel, "Not enough memory left in the buffer ({} > {})".format( 91 | tensor_numel, self.numel - self._start 92 | ) 93 | # New tensor is a view into the memory. 94 | new_tensor = self.data[self._start : new_start] 95 | self._start = new_start 96 | new_tensor = new_tensor.view(tensor.shape) 97 | new_tensor.copy_(tensor) 98 | # Return a pointer to the new tensor. 
99 | return new_tensor 100 | 101 | def get_data(self): 102 | """Return the data currently in use.""" 103 | if self.track_usage: 104 | self.in_use_value += float(self._start) 105 | self.total_value += float(self.numel) 106 | return self.data[: self._start] 107 | 108 | def print_average_usage(self): 109 | """Print memory usage average over time. We would like this value 110 | to be as high as possible.""" 111 | assert self.track_usage, "You need to enable track usage." 112 | if torch.distributed.get_rank() == 0: 113 | print( 114 | " > usage of {} memory buffer: {:.2f} %".format( 115 | self.name, self.in_use_value * 100.0 / self.total_value 116 | ), 117 | flush=True, 118 | ) 119 | -------------------------------------------------------------------------------- /megatron/microbatches.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) 2023 ADEPT AI LABS INC. 3 | # This file is based on code by the authors denoted below and has been modified from its original version. 4 | # 5 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. 6 | # 7 | # Licensed under the Apache License, Version 2.0 (the "License"); 8 | # you may not use this file except in compliance with the License. 9 | # You may obtain a copy of the License at 10 | # 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | # 13 | # Unless required by applicable law or agreed to in writing, software 14 | # distributed under the License is distributed on an "AS IS" BASIS, 15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 16 | # See the License for the specific language governing permissions and 17 | # limitations under the License. 18 | 19 | """Megatron number of micro-batches calculators.""" 20 | 21 | from abc import ABC 22 | from abc import abstractmethod 23 | 24 | 25 | def build_num_microbatches_calculator(args): 26 | 27 | # Constant num micro-batches. 
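memory.py's allocate_mem_buff / MemoryBuffer above implement a simple bump allocator: add copies a tensor into the next free slice and returns a view, and reset rewinds the whole buffer for reuse. An illustrative sketch (buffer name and sizes are made up; it assumes torch.distributed is initialized, because the constructor and the usage report log on rank 0, and that a CUDA device is present):

```
# Sketch of the memory.py bump-allocator API above. Needs an initialized
# torch.distributed process group and a CUDA device.
import torch
from megatron.memory import allocate_mem_buff

buf = allocate_mem_buff("activation-scratch", numel=1 << 20, dtype=torch.bfloat16, track_usage=True)

x = torch.randn(256, 1024, dtype=torch.bfloat16, device="cuda")
x_view = buf.add(x)        # copies x into the buffer and returns a view shaped like x
_ = buf.get_data()         # also records usage statistics because track_usage=True
buf.reset()                # rewind: the entire buffer is free again
buf.print_average_usage()  # reports the in-use percentage on rank 0
```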
28 | if args.rampup_batch_size is None: 29 | num_microbatches_calculator = ConstantNumMicroBatches( 30 | args.global_batch_size, args.micro_batch_size, args.data_parallel_size 31 | ) 32 | if args.rank == 0: 33 | print( 34 | "setting number of micro-batches to constant {}".format(num_microbatches_calculator.get()), flush=True 35 | ) 36 | 37 | else: 38 | assert len(args.rampup_batch_size) == 3, ( 39 | "expected the following " 40 | "format: --rampup-batch-size " 41 | " " 42 | ) 43 | start_batch_size = int(args.rampup_batch_size[0]) 44 | batch_size_increment = int(args.rampup_batch_size[1]) 45 | ramup_samples = int(args.rampup_batch_size[2]) 46 | if args.rank == 0: 47 | print( 48 | "will use batch size rampup starting from global batch " 49 | "size {} to global batch size {} with batch size increments " 50 | "{} over {} samples.".format( 51 | start_batch_size, args.global_batch_size, batch_size_increment, ramup_samples 52 | ), 53 | flush=True, 54 | ) 55 | num_microbatches_calculator = RampupBatchsizeNumMicroBatches( 56 | start_batch_size, 57 | batch_size_increment, 58 | ramup_samples, 59 | args.global_batch_size, 60 | args.micro_batch_size, 61 | args.data_parallel_size, 62 | ) 63 | 64 | return num_microbatches_calculator 65 | 66 | 67 | class NumMicroBatchesCalculator(ABC): 68 | def __init__(self): 69 | self.num_micro_batches = None 70 | self.current_global_batch_size = None 71 | 72 | def get(self): 73 | return self.num_micro_batches 74 | 75 | def get_current_global_batch_size(self): 76 | return self.current_global_batch_size 77 | 78 | @abstractmethod 79 | def update(self, consumed_samples, consistency_check): 80 | pass 81 | 82 | 83 | class ConstantNumMicroBatches(NumMicroBatchesCalculator): 84 | def __init__(self, global_batch_size, micro_batch_size, data_parallel_size): 85 | micro_batch_times_data_parallel = micro_batch_size * data_parallel_size 86 | assert ( 87 | global_batch_size % micro_batch_times_data_parallel == 0 88 | ), "global batch size ({}) is not divisible by micro batch size ({})" " times data parallel size ({})".format( 89 | global_batch_size, micro_batch_size, data_parallel_size 90 | ) 91 | self.num_micro_batches = global_batch_size // micro_batch_times_data_parallel 92 | assert self.num_micro_batches >= 1 93 | self.current_global_batch_size = global_batch_size 94 | 95 | def update(self, consumed_samples, consistency_check): 96 | pass 97 | 98 | 99 | class RampupBatchsizeNumMicroBatches(NumMicroBatchesCalculator): 100 | def __init__( 101 | self, 102 | start_batch_size, 103 | batch_size_increment, 104 | ramup_samples, 105 | global_batch_size, 106 | micro_batch_size, 107 | data_parallel_size, 108 | ): 109 | """Batch size ramp up. 110 | Over 111 | steps = (global-batch-size - start-batch-size) / batch_size_increment 112 | increment batch size from start-batch-size to global-batch-size using 113 | rampup-samples / steps 114 | samples. 115 | Arguments: 116 | start_batch_size: global batch size to start with 117 | batch_size_increment: global batch size increments 118 | ramup_samples: number of samples to use ramp up global 119 | batch size from `start_batch_size` to `global_batch_size` 120 | global_batch_size: global batch size post rampup 121 | micro_batch_size: micro batch size 122 | data_parallel_size: data parallel size. 
123 | """ 124 | 125 | self.micro_batch_size = micro_batch_size 126 | self.data_parallel_size = data_parallel_size 127 | self.micro_batch_times_data_parallel_size = self.micro_batch_size * self.data_parallel_size 128 | assert self.micro_batch_times_data_parallel_size > 0 129 | 130 | assert start_batch_size > 0 131 | self.start_batch_size = start_batch_size 132 | 133 | assert global_batch_size > 0 134 | self.global_batch_size = global_batch_size 135 | diff_batch_size = self.global_batch_size - self.start_batch_size 136 | assert diff_batch_size >= 0 137 | assert batch_size_increment > 0 138 | self.batch_size_increment = batch_size_increment 139 | assert diff_batch_size % batch_size_increment == 0, ( 140 | "expected " 141 | "global batch size interval ({}) to be divisible by global batch " 142 | "size increment ({})".format(diff_batch_size, batch_size_increment) 143 | ) 144 | 145 | num_increments = diff_batch_size // self.batch_size_increment 146 | self.ramup_samples = ramup_samples 147 | assert self.ramup_samples >= 0 148 | self.rampup_samples_per_increment = self.ramup_samples / num_increments 149 | 150 | # Initialize number of microbatches. 151 | self.update(0, False) 152 | 153 | def update(self, consumed_samples, consistency_check): 154 | 155 | if consumed_samples > self.ramup_samples: 156 | self.current_global_batch_size = self.global_batch_size 157 | else: 158 | steps = int(consumed_samples / self.rampup_samples_per_increment) 159 | self.current_global_batch_size = self.start_batch_size + steps * self.batch_size_increment 160 | assert self.current_global_batch_size <= self.global_batch_size 161 | 162 | if consistency_check: 163 | assert self.current_global_batch_size % self.micro_batch_times_data_parallel_size == 0, ( 164 | "current global " 165 | "batch size ({}) is not divisible by micro-batch-size ({}) times" 166 | "data parallel size ({})".format( 167 | self.current_global_batch_size, self.micro_batch_size, self.data_parallel_size 168 | ) 169 | ) 170 | self.num_micro_batches = self.current_global_batch_size // self.micro_batch_times_data_parallel_size 171 | -------------------------------------------------------------------------------- /megatron/model/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | from .fused_layer_norm import MixedFusedLayerNorm as LayerNorm 17 | 18 | from .distributed import DistributedDataParallel 19 | from .gpt_model import GPTModel 20 | from .language_model import get_language_model 21 | from .module import Float16Module 22 | from .enums import ModelType 23 | -------------------------------------------------------------------------------- /megatron/model/enums.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) 2023 ADEPT AI LABS INC. 
3 | # This file is based on code by the authors denoted below and has been modified from its original version. 4 | # 5 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. 6 | # 7 | # Licensed under the Apache License, Version 2.0 (the "License"); 8 | # you may not use this file except in compliance with the License. 9 | # You may obtain a copy of the License at 10 | # 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | # 13 | # Unless required by applicable law or agreed to in writing, software 14 | # distributed under the License is distributed on an "AS IS" BASIS, 15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 16 | # See the License for the specific language governing permissions and 17 | # limitations under the License. 18 | 19 | import enum 20 | 21 | 22 | class ModelType(enum.Enum): 23 | encoder_or_decoder = 1 24 | encoder_and_decoder = 2 25 | 26 | 27 | class LayerType(enum.Enum): 28 | encoder = 1 29 | decoder = 2 30 | 31 | 32 | class AttnType(enum.Enum): 33 | self_attn = 1 34 | cross_attn = 2 35 | 36 | 37 | class AttnMaskType(enum.Enum): 38 | padding = 1 39 | causal = 2 40 | -------------------------------------------------------------------------------- /megatron/model/fused_bias_gelu.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | import torch 17 | 18 | 19 | ###### BIAS GELU FUSION/ NO AUTOGRAD ################ 20 | # 1/sqrt(2*pi)-> 0.3989423 21 | # 1/sqrt(2) -> 0.70710678 22 | # sqrt(2/pi) -> 0.79788456 23 | # this function is tanh approximation of gelu 24 | # actual gelu is: 25 | # x * 0.5 * (1.0 + torch.erf(x * 0.70710678)) 26 | 27 | 28 | @torch.jit.script 29 | def bias_gelu(bias, y): 30 | x = bias + y 31 | return x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))) 32 | 33 | 34 | # gradient of tanh approximation of gelu 35 | # gradient of actual gelu is: 36 | # 0.5 * (1. 
+ torch.erf(x * 0.70710678)) + 0.3989423 * x * torch.exp(-0.5 * x * x) 37 | @torch.jit.script 38 | def bias_gelu_back(g, bias, y): 39 | x = bias + y 40 | tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)) 41 | # sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243 42 | ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out) 43 | return ff * g 44 | 45 | 46 | class GeLUFunction(torch.autograd.Function): 47 | @staticmethod 48 | # bias is an optional argument 49 | def forward(ctx, input, bias): 50 | ctx.save_for_backward(input, bias) 51 | return bias_gelu(bias, input) 52 | 53 | @staticmethod 54 | def backward(ctx, grad_output): 55 | input, bias = ctx.saved_tensors 56 | tmp = bias_gelu_back(grad_output, bias, input) 57 | return tmp, tmp 58 | 59 | 60 | bias_gelu_impl = GeLUFunction.apply 61 | -------------------------------------------------------------------------------- /megatron/model/fused_bias_sqrelu.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) 2023 ADEPT AI LABS INC. 3 | 4 | import torch 5 | 6 | 7 | ###### BIAS SQRELU FUSION/ NO AUTOGRAD ################ 8 | 9 | 10 | @torch.jit.script 11 | def bias_sqrelu(bias, y): 12 | x = bias + y 13 | relud_x = torch.relu(x) 14 | return relud_x * relud_x 15 | 16 | 17 | @torch.jit.script 18 | def bias_sqrelu_back(g, bias, y): 19 | x = bias + y 20 | return g * 2 * torch.relu(x) 21 | 22 | 23 | class SqReLUFunction(torch.autograd.Function): 24 | @staticmethod 25 | # bias is an optional argument 26 | def forward(ctx, input, bias): 27 | ctx.save_for_backward(input, bias) 28 | return bias_sqrelu(bias, input) 29 | 30 | @staticmethod 31 | def backward(ctx, grad_output): 32 | input, bias = ctx.saved_tensors 33 | tmp = bias_sqrelu_back(grad_output, bias, input) 34 | return tmp, tmp 35 | 36 | 37 | bias_sqrelu_impl = SqReLUFunction.apply 38 | -------------------------------------------------------------------------------- /megatron/model/fused_layer_norm.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """This code is copied fron NVIDIA apex: 17 | https://github.com/NVIDIA/apex 18 | with some changes. 
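bias_gelu above is the tanh approximation of GELU, so it can be sanity-checked against torch.nn.functional.gelu(..., approximate="tanh"), available in recent PyTorch releases; the check below assumes the megatron package imports cleanly in your environment.

```
# Quick numerical check (illustrative) of the fused tanh-approximation GELU.
import torch
import torch.nn.functional as F
from megatron.model.fused_bias_gelu import bias_gelu_impl

x = torch.randn(64, 4096)
bias = torch.randn(4096)

fused = bias_gelu_impl(x, bias)                      # note the (input, bias) argument order
reference = F.gelu(x + bias, approximate="tanh")
print(torch.allclose(fused, reference, atol=1e-5))   # expected: True
```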
""" 19 | 20 | import numbers 21 | import torch 22 | from torch.nn.parameter import Parameter 23 | from torch.nn import init 24 | import importlib 25 | 26 | from megatron.mpu import make_viewless_tensor 27 | from megatron.mpu import set_tensor_model_parallel_attributes 28 | 29 | try: 30 | from apex.contrib.layer_norm.layer_norm import FastLayerNormFN 31 | 32 | HAVE_PERSIST_LAYER_NORM = True 33 | except: 34 | HAVE_PERSIST_LAYER_NORM = False 35 | 36 | global fused_mix_prec_layer_norm_cuda 37 | fused_mix_prec_layer_norm_cuda = None 38 | 39 | 40 | class FusedLayerNormAffineFunction(torch.autograd.Function): 41 | @staticmethod 42 | def forward(ctx, input, weight, bias, normalized_shape, eps): 43 | 44 | ctx.normalized_shape = normalized_shape 45 | ctx.eps = eps 46 | input_ = input.contiguous() 47 | weight_ = weight.contiguous() 48 | bias_ = bias.contiguous() 49 | output, mean, invvar = fused_mix_prec_layer_norm_cuda.forward_affine( 50 | input_, ctx.normalized_shape, weight_, bias_, ctx.eps 51 | ) 52 | ctx.save_for_backward(input_, weight_, bias_, mean, invvar) 53 | 54 | return output 55 | 56 | @staticmethod 57 | def backward(ctx, grad_output): 58 | 59 | input_, weight_, bias_, mean, invvar = ctx.saved_tensors 60 | grad_input = grad_weight = grad_bias = None 61 | grad_input, grad_weight, grad_bias = fused_mix_prec_layer_norm_cuda.backward_affine( 62 | grad_output.contiguous(), mean, invvar, input_, ctx.normalized_shape, weight_, bias_, ctx.eps 63 | ) 64 | 65 | return grad_input, grad_weight, grad_bias, None, None 66 | 67 | 68 | class MixedFusedLayerNorm(torch.nn.Module): 69 | def __init__( 70 | self, 71 | normalized_shape, 72 | eps=1e-5, 73 | no_persist_layer_norm=True, 74 | sequence_parallel=False, 75 | is_tensor_parallel_unique=False, 76 | reduce_tensor_parallel_grads=False, 77 | ): 78 | super(MixedFusedLayerNorm, self).__init__() 79 | 80 | global fused_mix_prec_layer_norm_cuda 81 | fused_mix_prec_layer_norm_cuda = importlib.import_module( 82 | "megatron_fused_kernels.fused_mix_prec_layer_norm_cuda" 83 | ) 84 | 85 | # List of hiddens sizes supported in the persistent layer norm kernel 86 | # If the hidden size is not supported, fall back to the non-persistent 87 | # kernel. 
88 | persist_ln_hidden_sizes = [ 89 | 1024, 90 | 1536, 91 | 2048, 92 | 2304, 93 | 3072, 94 | 3840, 95 | 4096, 96 | 5120, 97 | 6144, 98 | 8192, 99 | 10240, 100 | 12288, 101 | 12800, 102 | 15360, 103 | 16384, 104 | 18432, 105 | 20480, 106 | 24576, 107 | 25600, 108 | 30720, 109 | 32768, 110 | 40960, 111 | 49152, 112 | 65536, 113 | ] 114 | if normalized_shape not in persist_ln_hidden_sizes or not HAVE_PERSIST_LAYER_NORM: 115 | no_persist_layer_norm = True 116 | 117 | if isinstance(normalized_shape, numbers.Integral): 118 | normalized_shape = (normalized_shape,) 119 | self.normalized_shape = torch.Size(normalized_shape) 120 | self.eps = eps 121 | 122 | self.weight = Parameter(torch.Tensor(*normalized_shape)) 123 | set_tensor_model_parallel_attributes( 124 | self.weight, 125 | is_tensor_parallel_unique=is_tensor_parallel_unique, 126 | reduce_tensor_parallel_grads=reduce_tensor_parallel_grads, 127 | ) 128 | 129 | self.bias = Parameter(torch.Tensor(*normalized_shape)) 130 | set_tensor_model_parallel_attributes( 131 | self.bias, 132 | is_tensor_parallel_unique=is_tensor_parallel_unique, 133 | reduce_tensor_parallel_grads=reduce_tensor_parallel_grads, 134 | ) 135 | 136 | self.reset_parameters() 137 | self.no_persist_layer_norm = no_persist_layer_norm 138 | self.sequence_parallel = sequence_parallel 139 | 140 | # set sequence parallelism flag on weight and bias parameters 141 | setattr(self.weight, "sequence_parallel", self.sequence_parallel) 142 | setattr(self.bias, "sequence_parallel", self.sequence_parallel) 143 | 144 | def reset_parameters(self): 145 | 146 | init.ones_(self.weight) 147 | init.zeros_(self.bias) 148 | 149 | def forward(self, input): 150 | # Doing this shape check here is important because the two kernels below don't handle shapes the same way, 151 | # so we need to check that we're passing them what we expect. 152 | assert len(self.normalized_shape) == 1 153 | if input.shape[-1] != self.normalized_shape[0]: 154 | raise ValueError(f"Expected input shape ending in {self.normalized_shape=}, but got {input.shape=}") 155 | 156 | if self.no_persist_layer_norm: 157 | return FusedLayerNormAffineFunction.apply(input, self.weight, self.bias, self.normalized_shape, self.eps) 158 | else: 159 | output = FastLayerNormFN.apply(input, self.weight, self.bias, self.eps) 160 | 161 | # Apex's fast layer norm function outputs a 'view' tensor (i.e., has 162 | # a populated '_base' field). This will result in schedule.py's 163 | # deallocate_output_tensor() throwing an error, so a viewless tensor is 164 | # created to prevent this. 165 | output = make_viewless_tensor(inp=output, requires_grad=input.requires_grad, keep_graph=True) 166 | 167 | return output 168 | -------------------------------------------------------------------------------- /megatron/model/fused_softmax.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) 2023 ADEPT AI LABS INC. 3 | # This file is based on code by the authors denoted below and has been modified from its original version. 4 | # 5 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. 6 | # 7 | # Licensed under the Apache License, Version 2.0 (the "License"); 8 | # you may not use this file except in compliance with the License. 
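A hedged sketch of the MixedFusedLayerNorm API above (exported as LayerNorm in megatron/model/__init__.py): it needs the compiled megatron_fused_kernels extensions and a CUDA device, so treat this as an illustration of the expected shapes rather than a standalone test.

```
# Illustrative only: requires the built megatron_fused_kernels extensions
# and a CUDA device.
import torch
from megatron.model import LayerNorm  # alias for MixedFusedLayerNorm

ln = LayerNorm(4096, eps=1e-5).cuda()
x = torch.randn(128, 2, 4096, device="cuda")  # last dim must match normalized_shape
y = ln(x)                                     # a mismatched last dim raises ValueError
```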
9 | # You may obtain a copy of the License at 10 | # 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | # 13 | # Unless required by applicable law or agreed to in writing, software 14 | # distributed under the License is distributed on an "AS IS" BASIS, 15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 16 | # See the License for the specific language governing permissions and 17 | # limitations under the License. 18 | 19 | 20 | import torch 21 | import torch.nn as nn 22 | from megatron.model.enums import AttnMaskType 23 | 24 | 25 | class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function): 26 | """ 27 | Fused operation which performs following three operations in sequence 28 | 1. Scale the tensor. 29 | 2. Apply upper triangular mask (typically used in gpt models). 30 | 3. Perform softmax. 31 | """ 32 | 33 | @staticmethod 34 | def forward(ctx, inputs, scale): 35 | from megatron_fused_kernels import scaled_upper_triang_masked_softmax_cuda 36 | 37 | scale_t = torch.tensor([scale]) 38 | softmax_results = scaled_upper_triang_masked_softmax_cuda.forward( 39 | inputs, scale_t[0] 40 | ) 41 | 42 | ctx.save_for_backward(softmax_results, scale_t) 43 | return softmax_results 44 | 45 | @staticmethod 46 | def backward(ctx, output_grads): 47 | from megatron_fused_kernels import scaled_upper_triang_masked_softmax_cuda 48 | 49 | softmax_results, scale_t = ctx.saved_tensors 50 | input_grads = scaled_upper_triang_masked_softmax_cuda.backward( 51 | output_grads, softmax_results, scale_t[0] 52 | ) 53 | 54 | return input_grads, None 55 | 56 | 57 | class ScaledSoftmax(torch.autograd.Function): 58 | """ 59 | Fused operation which performs following two operations in sequence 60 | 1. Scale the tensor. 61 | 2. Perform softmax. 62 | """ 63 | 64 | @staticmethod 65 | def forward(ctx, inputs, scale): 66 | from megatron_fused_kernels import scaled_softmax_cuda 67 | 68 | scale_t = torch.tensor([scale]) 69 | 70 | softmax_results = scaled_softmax_cuda.forward(inputs, scale_t[0]) 71 | ctx.save_for_backward(softmax_results, scale_t) 72 | return softmax_results 73 | 74 | @staticmethod 75 | def backward(ctx, output_grads): 76 | from megatron_fused_kernels import scaled_softmax_cuda 77 | 78 | softmax_results, scale_t = ctx.saved_tensors 79 | 80 | input_grads = scaled_softmax_cuda.backward( 81 | output_grads, softmax_results, scale_t[0] 82 | ) 83 | return input_grads, None, None 84 | 85 | 86 | class FusedScaleSoftmax(nn.Module): 87 | """ 88 | fused operation: scaling + mask + softmax 89 | 90 | Arguments: 91 | input_in_fp16: flag to indicate if input in fp16 data format. 92 | input_in_bf16: flag to indicate if input in bf16 data format. 93 | attn_mask_type: attention mask type (pad or causal) 94 | scaled_masked_softmax_fusion: flag to indicate user want to use softmax fusion 95 | softmax_in_fp32: if true, softmax in performed at fp32 precision. 96 | scale: scaling factor used in input tensor scaling. 97 | """ 98 | 99 | def __init__( 100 | self, 101 | input_in_fp16, 102 | input_in_bf16, 103 | attn_mask_type, 104 | scaled_masked_softmax_fusion, 105 | softmax_in_fp32, 106 | scale, 107 | ): 108 | super(FusedScaleSoftmax, self).__init__() 109 | self.input_in_fp16 = input_in_fp16 110 | self.input_in_bf16 = input_in_bf16 111 | assert not ( 112 | self.input_in_fp16 and self.input_in_bf16 113 | ), "both fp16 and bf16 flags cannot be active at the same time." 
114 | self.input_in_float16 = self.input_in_fp16 or self.input_in_bf16 115 | self.attn_mask_type = attn_mask_type 116 | self.scaled_masked_softmax_fusion = scaled_masked_softmax_fusion 117 | self.softmax_in_fp32 = softmax_in_fp32 118 | self.scale = scale 119 | 120 | assert ( 121 | self.scale is None or softmax_in_fp32 122 | ), "softmax should be in fp32 when scaled" 123 | 124 | def forward(self, input): 125 | # [b, np, sq, sk] 126 | assert input.dim() == 4 127 | 128 | if self.is_kernel_available(*input.size()): 129 | return self.forward_fused_softmax(input) 130 | else: 131 | return self.forward_torch_softmax(input) 132 | 133 | def is_kernel_available(self, b, np, sq, sk): 134 | attn_batches = b * np 135 | if ( 136 | self.scaled_masked_softmax_fusion # user want to fuse 137 | and self.input_in_float16 # input must be fp16 138 | and 16 < sk <= 8192 # sk must be 16 ~ 8192 139 | and sq % 4 == 0 # sq must be divisor of 4 140 | and sk % 4 == 0 # sk must be divisor of 4 141 | and attn_batches % 4 == 0 # np * b must be divisor of 4 142 | ): 143 | if 0 <= sk <= 8192: 144 | batch_per_block = self.get_batch_per_block(sq, sk, b, np) 145 | 146 | if self.attn_mask_type == AttnMaskType.causal: 147 | if attn_batches % batch_per_block == 0: 148 | return True 149 | else: 150 | if sq % batch_per_block == 0 and sk % batch_per_block == 0: 151 | return True 152 | return False 153 | 154 | def forward_fused_softmax(self, input): 155 | b, np, sq, sk = input.size() 156 | scale = self.scale if self.scale is not None else 1.0 157 | 158 | if self.attn_mask_type == AttnMaskType.causal: 159 | assert sq == sk, "causal mask is only for self attention" 160 | 161 | # input is 3D tensor (attn_batches, sq, sk) 162 | input = input.view(-1, sq, sk) 163 | probs = ScaledUpperTriangMaskedSoftmax.apply(input, scale) 164 | return probs.view(b, np, sq, sk) 165 | else: 166 | # input is 4D tensor (b, np, sq, sk) 167 | return ScaledSoftmax.apply(input, scale) 168 | 169 | def forward_torch_softmax(self, input): 170 | b, np, sq, sk = input.size() 171 | if self.input_in_float16 and self.softmax_in_fp32: 172 | input = input.float() 173 | 174 | if self.scale is not None: 175 | input = input * self.scale 176 | 177 | if self.attn_mask_type == AttnMaskType.causal: 178 | assert sq == sk, "causal mask is only for self attention" 179 | attention_mask = torch.tril( 180 | torch.ones((1, input.shape[-1], input.shape[-1]), device=input.device) 181 | ).view(1, 1, input.shape[-1], input.shape[-1]) 182 | # invert mask so that 0 -> 1 and 1 -> 0 183 | attention_mask = attention_mask < 0.5 184 | 185 | # assign very low logit to masked entries 186 | input.masked_fill_(attention_mask, -10000.0) 187 | 188 | probs = torch.nn.Softmax(dim=-1)(input) 189 | 190 | if self.input_in_float16 and self.softmax_in_fp32: 191 | if self.input_in_fp16: 192 | probs = probs.half() 193 | else: 194 | probs = probs.bfloat16() 195 | 196 | return probs 197 | 198 | @staticmethod 199 | def get_batch_per_block(sq, sk, b, np): 200 | from megatron_fused_kernels import scaled_masked_softmax_cuda 201 | 202 | return scaled_masked_softmax_cuda.get_batch_per_block(sq, sk, b, np) 203 | -------------------------------------------------------------------------------- /megatron/model/gpt_model.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # 3 | # Copyright (c) 2023 ADEPT AI LABS INC. 4 | # This file is based on code by the authors denoted below and has been modified from its original version. 
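FusedScaleSoftmax above falls back to forward_torch_softmax whenever the fused kernel is unavailable or disabled; the sketch below exercises that eager path on CPU in fp32 with a causal mask (again assuming the megatron package imports in your environment).

```
# Illustrative use of the eager fallback path (fusion disabled, fp32, CPU).
import torch
from megatron.model.enums import AttnMaskType
from megatron.model.fused_softmax import FusedScaleSoftmax

softmax = FusedScaleSoftmax(
    input_in_fp16=False,
    input_in_bf16=False,
    attn_mask_type=AttnMaskType.causal,
    scaled_masked_softmax_fusion=False,  # force forward_torch_softmax
    softmax_in_fp32=True,
    scale=None,
)
scores = torch.randn(2, 8, 16, 16)  # [b, np, sq, sk] with sq == sk for the causal mask
probs = softmax(scores)             # rows sum to 1; the upper triangle is masked out
```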
5 | # 6 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. 7 | # 8 | # Licensed under the Apache License, Version 2.0 (the "License"); 9 | # you may not use this file except in compliance with the License. 10 | # You may obtain a copy of the License at 11 | # 12 | # http://www.apache.org/licenses/LICENSE-2.0 13 | # 14 | # Unless required by applicable law or agreed to in writing, software 15 | # distributed under the License is distributed on an "AS IS" BASIS, 16 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 17 | # See the License for the specific language governing permissions and 18 | # limitations under the License. 19 | 20 | """GPT-2 model.""" 21 | 22 | import torch 23 | 24 | from megatron import get_args 25 | from megatron import mpu 26 | from .module import MegatronModule 27 | 28 | from .enums import AttnMaskType 29 | from .language_model import parallel_lm_logits 30 | from .language_model import get_language_model 31 | from .utils import init_method_normal 32 | from .utils import scaled_init_method_normal 33 | 34 | try: 35 | from megatron.mpu.cross_entropy_parallel import CrossEntropyLossParallel 36 | from einops import rearrange 37 | except ImportError: 38 | CrossEntropyLossParallel = None 39 | rearrange = None 40 | 41 | 42 | def post_language_model_processing( 43 | lm_output, labels, logit_weights, parallel_output, fp16_lm_cross_entropy, use_fast_cross_entropy 44 | ): 45 | 46 | # Output. Format [s b h] 47 | output = parallel_lm_logits(lm_output, logit_weights, parallel_output) 48 | if labels is None: 49 | # [s b h] => [b s h] 50 | return output.transpose(0, 1).contiguous() 51 | else: 52 | # [b s] => [s b] 53 | labels = labels.transpose(0, 1).contiguous() 54 | if use_fast_cross_entropy: 55 | loss_fn = CrossEntropyLossParallel(reduction="none", inplace_backward=True) 56 | loss = rearrange( 57 | loss_fn(rearrange(output, "s b ... -> (s b) ..."), rearrange(labels, "s b ... -> (s b) ...")), 58 | "(s b) ... 
-> s b ...", 59 | s=output.shape[0], 60 | ) 61 | elif fp16_lm_cross_entropy: 62 | assert output.dtype == torch.half 63 | loss = mpu.vocab_parallel_cross_entropy(output, labels) 64 | else: 65 | loss = mpu.vocab_parallel_cross_entropy(output.float(), labels) 66 | 67 | # [s b] => [b, s] 68 | loss = loss.transpose(0, 1).contiguous() 69 | 70 | argmax = mpu.layers.argmax_parallel(output) 71 | accuracy = labels == argmax 72 | 73 | accuracy = accuracy.transpose(0, 1).contiguous() 74 | argmax = argmax.transpose(0, 1).contiguous() 75 | 76 | return {"argmax": argmax, "loss": loss, "accuracy": accuracy} 77 | 78 | 79 | class GPTModel(MegatronModule): 80 | """GPT-2 Language model.""" 81 | 82 | def __init__( 83 | self, 84 | num_tokentypes=0, 85 | parallel_output=True, 86 | pre_process=True, 87 | post_process=True, 88 | continuous_embed_input_size=None, 89 | ): 90 | args = get_args() 91 | super(GPTModel, self).__init__(share_word_embeddings=not args.untie_embeddings) 92 | self.parallel_output = parallel_output 93 | self.pre_process = pre_process 94 | self.post_process = post_process 95 | self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy 96 | self.share_word_embeddings = not args.untie_embeddings 97 | self.use_fast_cross_entropy = args.use_fast_cross_entropy 98 | if self.use_fast_cross_entropy: 99 | if CrossEntropyLossParallel is None: 100 | raise ImportError("xentropy CUDA extension is not installed") 101 | if rearrange is None: 102 | raise ImportError("einops is not installed") 103 | 104 | self.language_model, self._language_model_key = get_language_model( 105 | num_tokentypes=num_tokentypes, 106 | add_pooler=False, 107 | encoder_attn_mask_type=AttnMaskType.causal, 108 | init_method=init_method_normal(args.init_method_std), 109 | scaled_init_method=scaled_init_method_normal(args.init_method_std, args.num_layers), 110 | pre_process=self.pre_process, 111 | post_process=self.post_process, 112 | continuous_embed_input_size=continuous_embed_input_size, 113 | ) 114 | 115 | self.initialize_word_embeddings(init_method_normal) 116 | 117 | def set_input_tensor(self, input_tensor): 118 | """See megatron.model.transformer.set_input_tensor()""" 119 | self.language_model.set_input_tensor(input_tensor) 120 | 121 | def forward( 122 | self, 123 | input_ids, 124 | position_ids, 125 | labels=None, 126 | tokentype_ids=None, 127 | inference_params=None, 128 | lm_logits_mask=None, 129 | ): 130 | lm_output = self.language_model(input_ids, position_ids, inference_params=inference_params) 131 | 132 | if lm_logits_mask is not None: 133 | # lm_logits_mask has shape [s] while lm_output has shape [s, b, h]. 134 | lm_output = lm_output[lm_logits_mask, :, :] 135 | 136 | del tokentype_ids 137 | if self.post_process: 138 | return post_language_model_processing( 139 | lm_output, 140 | labels, 141 | self.word_embeddings_weight(), 142 | self.parallel_output, 143 | self.fp16_lm_cross_entropy, 144 | use_fast_cross_entropy=self.use_fast_cross_entropy, 145 | ) 146 | else: 147 | return lm_output 148 | 149 | def state_dict_for_save_checkpoint(self, destination=None, prefix="", keep_vars=False): 150 | 151 | state_dict_ = {} 152 | state_dict_[self._language_model_key] = self.language_model.state_dict_for_save_checkpoint( 153 | destination, prefix, keep_vars 154 | ) 155 | # Save word_embeddings. 156 | # When it is the last stage of pipelining or if we are not sharing word 157 | # embeddings, then we need to save the variable to the state dict. 
158 | if self.post_process and (not self.pre_process or not self.share_word_embeddings): 159 | state_dict_[self._word_embeddings_for_head_key] = self.word_embeddings.state_dict( 160 | destination, prefix, keep_vars 161 | ) 162 | return state_dict_ 163 | 164 | def load_state_dict(self, state_dict, strict=True): 165 | """Customized load.""" 166 | 167 | # Load word_embeddings. 168 | if (self.post_process and not self.pre_process) or not self.share_word_embeddings: 169 | if self._word_embeddings_for_head_key in state_dict: 170 | self.word_embeddings.load_state_dict(state_dict[self._word_embeddings_for_head_key], strict=strict) 171 | if self._language_model_key in state_dict: 172 | state_dict = state_dict[self._language_model_key] 173 | self.language_model.load_state_dict(state_dict, strict=strict) 174 | -------------------------------------------------------------------------------- /megatron/model/positional_embeddings.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023, ADEPT AI LABS INC. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Positional embeddings modules.""" 15 | import torch 16 | import math 17 | 18 | 19 | class RotaryEmbedding(torch.nn.Module): 20 | def __init__(self, dim, base=10000, precision=torch.bfloat16): 21 | super().__init__() 22 | inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim)) 23 | self.register_buffer("inv_freq", inv_freq) 24 | self.seq_len_cached = None 25 | self.cos_cached = None 26 | self.sin_cached = None 27 | self.precision = precision 28 | 29 | def forward(self, x, seq_dim=1, seq_len=None): 30 | if seq_len is None: 31 | seq_len = x.shape[seq_dim] 32 | if seq_len != self.seq_len_cached: 33 | self.seq_len_cached = seq_len 34 | t = torch.arange(seq_len, device=x.device).type_as(self.inv_freq) 35 | freqs = torch.einsum("i,j->ij", t, self.inv_freq) 36 | emb = torch.cat((freqs, freqs), dim=-1).to(x.device) 37 | if self.precision == torch.bfloat16: 38 | emb = emb.float() 39 | self.cos_cached = emb.cos()[:, None, None, :] 40 | self.sin_cached = emb.sin()[:, None, None, :] 41 | if self.precision == torch.bfloat16: 42 | self.cos_cached = self.cos_cached.bfloat16() 43 | self.sin_cached = self.sin_cached.bfloat16() 44 | return self.cos_cached, self.sin_cached 45 | 46 | 47 | def rotate_half(x): 48 | x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :] 49 | return torch.cat((-x2, x1), dim=x1.ndim - 1) # dim=-1 triggers a bug in earlier torch versions 50 | 51 | 52 | @torch.jit.script 53 | def apply_rotary_pos_emb(q, k, cos, sin, offset: int = 0): 54 | cos, sin = ( 55 | cos[offset : q.shape[0] + offset, ...], 56 | sin[offset : q.shape[0] + offset, ...], 57 | ) 58 | return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin) 59 | 60 | 61 | def apply_rotary_pos_emb_torch(q, k, cos, sin, offset: int = 0): # jitting fails with bf16 62 | cos, sin = ( 63 | cos[offset : q.shape[0] + offset, ...], 64 | sin[offset : 
q.shape[0] + offset, ...], 65 | ) 66 | return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin) 67 | -------------------------------------------------------------------------------- /megatron/model/utils.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) 2023 ADEPT AI LABS INC. 3 | # This file is based on code by the authors denoted below and has been modified from its original version. 4 | # 5 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. 6 | # 7 | # Licensed under the Apache License, Version 2.0 (the "License"); 8 | # you may not use this file except in compliance with the License. 9 | # You may obtain a copy of the License at 10 | # 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | # 13 | # Unless required by applicable law or agreed to in writing, software 14 | # distributed under the License is distributed on an "AS IS" BASIS, 15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 16 | # See the License for the specific language governing permissions and 17 | # limitations under the License. 18 | 19 | """Utilities for models.""" 20 | 21 | import math 22 | import socket 23 | 24 | import torch 25 | from numpy.random import default_rng, SeedSequence 26 | import multiprocessing 27 | import concurrent.futures 28 | import numpy as np 29 | from typing import Sequence, List 30 | import functools 31 | 32 | from megatron import get_args 33 | 34 | from megatron import mpu 35 | 36 | 37 | def init_method_normal(sigma): 38 | """Init method based on N(0, sigma).""" 39 | 40 | def fill(rng: np.random.Generator, shape: Sequence[int]): 41 | n_threads = 16 42 | n_elements = np.prod(shape) 43 | values = np.empty(n_elements) 44 | assert n_elements % n_threads == 0, "Number of elements must be a multiple of number of threads!" 45 | step = np.ceil(n_elements / n_threads).astype(np.int_) 46 | 47 | # TODO(erich): hacky way to get new generator seeds, we should use spawn as soon as we have 48 | # numpy 1.25 and not this!!! 49 | seeds = rng.integers(0, 1 << 63, n_threads) 50 | 51 | _random_generators = [np.random.default_rng(s) for s in seeds] 52 | 53 | executor = concurrent.futures.ThreadPoolExecutor(n_threads) 54 | 55 | def _fill(generator, out, first, last): 56 | out[first:last] = generator.normal(loc=0.0, scale=sigma, size=step) 57 | 58 | futures = {} 59 | for i in range(n_threads): 60 | args = (_fill, _random_generators[i], values, i * step, (i + 1) * step) 61 | futures[executor.submit(*args)] = i 62 | concurrent.futures.wait(futures) 63 | return np.reshape(values, shape) 64 | 65 | return fill 66 | 67 | 68 | # TODO(erich): combine these two functions to reduce code duplication 69 | def scaled_init_method_normal(sigma, num_layers): 70 | """Init method based on N(0, sigma/sqrt(2*num_layers).""" 71 | std = sigma / math.sqrt(2.0 * num_layers) 72 | 73 | def fill(rng: np.random.Generator, shape: Sequence[int]): 74 | n_threads = 16 75 | n_elements = np.prod(shape) 76 | values = np.empty(n_elements) 77 | assert n_elements % n_threads == 0, "Number of elements must be a multiple of number of threads!" 78 | step = np.ceil(n_elements / n_threads).astype(np.int_) 79 | 80 | # TODO(erich): hacky way to get new generator seeds, we should use spawn as soon as we have 81 | # numpy 1.25 and not this!!! 
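RotaryEmbedding above caches the cos/sin tables keyed on sequence length, and apply_rotary_pos_emb rotates query/key tensors whose leading dimension is the sequence (the usual [sq, b, np, hn] layout). A small usage sketch in fp32, so the scripted path runs without the bf16 conversions, assuming the package imports in your environment:

```
# Usage sketch of the rotary position embedding helpers above.
import torch
from megatron.model.positional_embeddings import RotaryEmbedding, apply_rotary_pos_emb

seq_len, batch, heads, head_dim = 16, 2, 4, 64
q = torch.randn(seq_len, batch, heads, head_dim)
k = torch.randn(seq_len, batch, heads, head_dim)

rope = RotaryEmbedding(head_dim, precision=torch.float32)
cos, sin = rope(q, seq_dim=0)   # cached tables of shape [seq_len, 1, 1, head_dim]
q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin, offset=0)
```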
82 | seeds = rng.integers(0, 1 << 63, n_threads) 83 | 84 | _random_generators = [np.random.default_rng(s) for s in seeds] 85 | executor = concurrent.futures.ThreadPoolExecutor(n_threads) 86 | 87 | def _fill(generator, out, first, last): 88 | out[first:last] = generator.normal(loc=0.0, scale=std, size=step) 89 | 90 | futures = {} 91 | for i in range(n_threads): 92 | args = (_fill, _random_generators[i], values, i * step, (i + 1) * step) 93 | futures[executor.submit(*args)] = i 94 | concurrent.futures.wait(futures) 95 | return np.reshape(values, shape) 96 | 97 | return fill 98 | 99 | 100 | def check_shapes(tensor, expected_shapes=None, expected_ndim=None, expected_dtype=None): 101 | """Check that the passed-in `tensor` has the expected shape and ndim. Should an expected dim be None, it is not checked. 102 | 103 | Args: 104 | tensor: The tensor to shape-check. 105 | expected_shapes: The prefix of the shapes to expect in `tensor`. 106 | expected_ndim: The expected number of dimensions in `tensor`. 107 | expected_dtype: Expected dtype. 108 | """ 109 | assert tensor is not None, "tensor is None" 110 | 111 | if expected_shapes is None and expected_ndim is None and expected_dtype is None: 112 | return 113 | 114 | if expected_dtype is not None: 115 | assert tensor.dtype == expected_dtype 116 | 117 | if expected_ndim is not None: 118 | assert tensor.ndim == expected_ndim, f"Unexpected ndims detected. Got: {tensor.ndim} expected: {expected_ndim}" 119 | 120 | expected_shapes = list(expected_shapes or []) 121 | 122 | if expected_ndim is not None: 123 | assert ( 124 | len(expected_shapes) <= tensor.ndim 125 | ), f"Asking to check too many shapes. Got: {expected_shapes} expected: <= {tensor.ndim}" 126 | 127 | t_shapes = list(tensor.shape[: len(expected_shapes)]) 128 | for i, e in enumerate(expected_shapes): 129 | if e is None: 130 | t_shapes[i] = None 131 | assert t_shapes == expected_shapes, f"Mismatch detected. Got: {t_shapes} expected: {expected_shapes}" 132 | 133 | 134 | def get_linear_layer(rows, columns, init_method): 135 | """Simple linear layer with weight initialization.""" 136 | layer = torch.nn.Linear(rows, columns) 137 | if get_args().perform_initialization: 138 | init_method(layer.weight) 139 | with torch.no_grad(): 140 | layer.bias.zero_() 141 | return layer 142 | 143 | 144 | @torch.jit.script 145 | def gelu_impl(x): 146 | """OpenAI's gelu implementation.""" 147 | return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * x * (1.0 + 0.044715 * x * x))) 148 | 149 | 150 | def openai_gelu(x): 151 | return gelu_impl(x) 152 | 153 | 154 | # This is actually Python equivalent of torch.nn.functional.gelu(), also with type hints for ONNX exporter 155 | @torch.jit.script 156 | def erf_gelu(x): 157 | return x * 0.5 * (torch.erf(x / 1.41421).to(dtype=x.dtype) + torch.ones_like(x).to(dtype=x.dtype)) 158 | 159 | 160 | def print_named_parameters(model): 161 | """Print a summary of the parameters in a model.""" 162 | prefix = "" 163 | if torch.distributed.is_initialized(): 164 | # Print on only the first data parallel rank, but on all tensor/pipeline parallel ranks. 
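        # Note: should_print gates output on data-parallel rank 0 only; the prefix below tags each
        # printed line with the tensor- or pipeline-parallel rank so output from different model
        # shards can be told apart.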
165 | should_print = mpu.get_data_parallel_rank() == 0 166 | if mpu.get_tensor_model_parallel_world_size() > 1: 167 | prefix = f"tensor-rank: {mpu.get_tensor_model_parallel_rank()} | " 168 | if mpu.get_pipeline_model_parallel_world_size() > 1: 169 | prefix = f"pipeline-rank: {mpu.get_pipeline_model_parallel_rank()} | " 170 | else: 171 | should_print = get_args().rank == 0 172 | 173 | if not should_print: 174 | return 175 | 176 | print(f"{prefix} > {type(model).__name__} parameters: ", flush=True) 177 | for name, param in model.named_parameters(): 178 | if mpu.param_is_tensor_parallel_unique(param): 179 | print(f"{prefix}{name=}, {param.shape=}, norm={torch.norm(param.data.float()).item()}", flush=True) 180 | 181 | 182 | def sync_data_parallel_replicated_parameters(a, b): 183 | pass 184 | 185 | 186 | def sync_tensor_parallel_replicated_parameters(a, b): 187 | pass 188 | 189 | 190 | def _sync_replicated_parameters(a, b, c, d, e): 191 | 192 | def _sync(p): 193 | pass 194 | 195 | 196 | class ReplicationMismatchError(Exception): 197 | pass 198 | 199 | 200 | class TensorParallelReplicationMismatchError(ReplicationMismatchError): 201 | pass 202 | 203 | 204 | class DataParallelReplicationMismatchError(ReplicationMismatchError): 205 | pass 206 | 207 | 208 | def validate_replicated_parameters(m, o): 209 | 210 | def collect_and_validate(a, b, c, d, prefix): 211 | pass 212 | 213 | 214 | def _validate_replicated_parameters( 215 | *, a, b, c, d, prefix="" 216 | ): 217 | pass 218 | -------------------------------------------------------------------------------- /megatron/mpu/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) 2023 ADEPT AI LABS INC. 3 | # This file is based on code by the authors denoted below and has been modified from its original version. 4 | # 5 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. 6 | # 7 | # Licensed under the Apache License, Version 2.0 (the "License"); 8 | # you may not use this file except in compliance with the License. 9 | # You may obtain a copy of the License at 10 | # 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | # 13 | # Unless required by applicable law or agreed to in writing, software 14 | # distributed under the License is distributed on an "AS IS" BASIS, 15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 16 | # See the License for the specific language governing permissions and 17 | # limitations under the License. 
18 | 19 | """Model parallel utility interface.""" 20 | 21 | from .cross_entropy import vocab_parallel_cross_entropy 22 | 23 | from .initialize import is_unitialized 24 | from .initialize import destroy_model_parallel 25 | from .initialize import force_communicator_creation 26 | from .initialize import get_data_parallel_group 27 | from .initialize import get_data_parallel_rank 28 | from .initialize import get_data_parallel_world_size 29 | from .initialize import get_embedding_group 30 | from .initialize import get_position_embedding_group 31 | from .initialize import get_model_parallel_group 32 | from .initialize import get_tensor_model_parallel_group 33 | from .initialize import get_pipeline_model_parallel_group 34 | from .initialize import get_tensor_model_parallel_rank, set_tensor_model_parallel_rank 35 | from .initialize import get_pipeline_model_parallel_rank, set_pipeline_model_parallel_rank 36 | from .initialize import is_pipeline_first_stage, is_pipeline_last_stage 37 | from .initialize import is_rank_in_embedding_group 38 | from .initialize import is_rank_in_position_embedding_group 39 | from .initialize import is_pipeline_stage_before_split, is_pipeline_stage_after_split 40 | from .initialize import is_pipeline_stage_at_split 41 | from .initialize import get_num_layers 42 | from .initialize import get_tensor_model_parallel_src_rank 43 | from .initialize import get_data_parallel_src_rank 44 | from .initialize import get_pipeline_model_parallel_first_rank 45 | from .initialize import get_pipeline_model_parallel_last_rank 46 | from .initialize import get_pipeline_model_parallel_next_rank 47 | from .initialize import get_pipeline_model_parallel_prev_rank 48 | from .initialize import get_tensor_model_parallel_world_size, set_tensor_model_parallel_world_size 49 | from .initialize import get_pipeline_model_parallel_world_size, set_pipeline_model_parallel_world_size 50 | from .initialize import get_virtual_pipeline_model_parallel_rank, set_virtual_pipeline_model_parallel_rank 51 | from .initialize import initialize_model_parallel 52 | from .initialize import model_parallel_is_initialized 53 | 54 | from .layers import LinearWithGradAccumulationAndAsyncCommunication 55 | from .layers import ColumnParallelLinear 56 | from .layers import ParallelLinear 57 | from .layers import RowParallelLinear 58 | from .layers import VocabParallelEmbedding 59 | from .layers import ( 60 | set_tensor_model_parallel_attributes, 61 | set_defaults_if_not_set_tensor_model_parallel_attributes, 62 | copy_tensor_model_parallel_attributes, 63 | param_is_tensor_parallel_unique, 64 | param_is_tensor_parallel_replicated, 65 | ) 66 | 67 | from .mappings import reduce_backward_from_tensor_model_parallel_region 68 | from .mappings import reduce_forward_from_tensor_model_parallel_region 69 | from .mappings import scatter_to_tensor_model_parallel_region 70 | from .mappings import gather_from_tensor_model_parallel_region 71 | from .mappings import scatter_to_sequence_parallel_region 72 | from .mappings import gather_from_sequence_parallel_region 73 | from .mappings import reduce_scatter_to_sequence_parallel_region 74 | 75 | from .random import checkpoint 76 | from .random import get_cuda_rng_tracker 77 | from .random import model_parallel_cuda_manual_seed 78 | from .random import gather_split_1d_tensor 79 | from .random import split_tensor_into_1d_equal_chunks 80 | from .random import make_viewless_tensor 81 | from .random import assert_viewless_tensor 82 | from .random import safely_set_viewless_tensor_data 83 | 84 | from 
.utils import divide 85 | from .utils import split_tensor_along_last_dim 86 | -------------------------------------------------------------------------------- /megatron/mpu/communication.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) 2023 ADEPT AI LABS INC. 3 | # This file is based on code by the authors denoted below and has been modified from its original version. 4 | # 5 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. 6 | # 7 | # Licensed under the Apache License, Version 2.0 (the "License"); 8 | # you may not use this file except in compliance with the License. 9 | # You may obtain a copy of the License at 10 | # 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | # 13 | # Unless required by applicable law or agreed to in writing, software 14 | # distributed under the License is distributed on an "AS IS" BASIS, 15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 16 | # See the License for the specific language governing permissions and 17 | # limitations under the License. 18 | 19 | """Communications utilities.""" 20 | 21 | 22 | from typing import Sequence, Union 23 | import collections 24 | import torch 25 | 26 | from megatron import mpu 27 | from megatron.model.utils import check_shapes 28 | 29 | 30 | # TODO: use functions from megatron/p2p 31 | def recv_from_prev_pipeline_rank_(recv_buffer=None): 32 | """Receive from previous pipeline stage and update the 33 | input buffer inplace.""" 34 | if not mpu.is_pipeline_first_stage(): 35 | assert recv_buffer is not None 36 | recv_prev_op = torch.distributed.P2POp( 37 | torch.distributed.irecv, recv_buffer, mpu.get_pipeline_model_parallel_prev_rank() 38 | ) 39 | reqs = torch.distributed.batch_isend_irecv([recv_prev_op]) 40 | for req in reqs: 41 | req.wait() 42 | # To protect against race condition when using batch_isend_irecv(). 43 | torch.cuda.synchronize() 44 | 45 | 46 | # TODO: use functions from megatron/p2p 47 | def send_to_next_pipeline_rank(tensor=None): 48 | """Send output to the next pipeline stage.""" 49 | if not mpu.is_pipeline_last_stage(): 50 | assert tensor is not None 51 | send_next_op = torch.distributed.P2POp( 52 | torch.distributed.isend, tensor, mpu.get_pipeline_model_parallel_next_rank() 53 | ) 54 | reqs = torch.distributed.batch_isend_irecv([send_next_op]) 55 | for req in reqs: 56 | req.wait() 57 | # To protect against race condition when using batch_isend_irecv(). 58 | torch.cuda.synchronize() 59 | 60 | 61 | def _is_cuda(tensor): 62 | """Check if a tensor is not none and is cuda.""" 63 | assert tensor is not None 64 | assert tensor.is_cuda 65 | 66 | 67 | def _is_cuda_contiguous(tensor): 68 | """Check if a tensor is not none, is cuda, and is contiguous.""" 69 | _is_cuda(tensor) 70 | assert tensor.is_contiguous() 71 | 72 | 73 | def _check_shapes(tensor: torch.Tensor, size: Union[int, Sequence[int]], dtype: torch.dtype): 74 | """Check that the sending and receiver tensors will be compatible.""" 75 | # This logic is required because torch.empty will promote a size of int to [int] and there are lots 76 | # of callers that rely on this convention. 
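    # e.g. a caller passing size=5 is treated the same as one passing size=[5], mirroring how
    # torch.empty(5) and torch.empty([5]) allocate identically shaped tensors.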
77 | shape = size if isinstance(size, collections.abc.Sequence) else [size] 78 | 79 | check_shapes(tensor, expected_shapes=shape, expected_ndim=len(shape), expected_dtype=dtype) 80 | 81 | 82 | def broadcast_from_last_pipeline_stage(size, dtype, tensor=None): 83 | """Broadcast a tensor from last pipeline stage to all ranks.""" 84 | 85 | is_last_stage = mpu.is_pipeline_last_stage(ignore_virtual=True) 86 | # If first stage and last state are the same, then there is no 87 | # pipeline parallelism and no need to communicate. 88 | if mpu.is_pipeline_first_stage() and is_last_stage: 89 | return tensor 90 | 91 | if is_last_stage: 92 | _is_cuda_contiguous(tensor) 93 | _check_shapes(tensor, size, dtype) 94 | else: 95 | tensor = torch.empty(size, dtype=dtype, device=torch.cuda.current_device()) 96 | # Get the group and corresponding source rank. 97 | src = mpu.get_pipeline_model_parallel_last_rank() 98 | group = mpu.get_pipeline_model_parallel_group() 99 | torch.distributed.broadcast(tensor, src, group) 100 | 101 | return tensor 102 | 103 | 104 | def broadcast_from_last_to_first_pipeline_stage(size, dtype, tensor=None): 105 | """Broadcast tensor values from last stage into the first stage.""" 106 | 107 | is_last_stage = mpu.is_pipeline_last_stage() 108 | is_first_stage = mpu.is_pipeline_first_stage() 109 | # If first stage and last state are the same, then there is no 110 | # pipeline parallelism and no need to communicate. 111 | if is_first_stage and is_last_stage: 112 | return tensor 113 | # Only first and last stage pipeline stages need to be involved. 114 | if is_last_stage or is_first_stage: 115 | if is_last_stage: 116 | _is_cuda_contiguous(tensor) 117 | _check_shapes(tensor, size, dtype) 118 | else: 119 | tensor = torch.empty(size, dtype=dtype, device=torch.cuda.current_device()) 120 | src = mpu.get_pipeline_model_parallel_last_rank() 121 | group = mpu.get_embedding_group() 122 | # Broadcast from last stage into the first stage. 123 | torch.distributed.broadcast(tensor, src, group) 124 | else: 125 | tensor = None 126 | 127 | return tensor 128 | 129 | 130 | def copy_from_last_to_first_pipeline_stage(size, dtype, tensor=None): 131 | """Copy tensor values from last stage into the first stage. 132 | Note that the input tensor is updated in place.""" 133 | 134 | is_last_stage = mpu.is_pipeline_last_stage() 135 | is_first_stage = mpu.is_pipeline_first_stage() 136 | # If first stage and last state are the same, then there is no 137 | # pipeline parallelism and no need to communicate. 138 | if is_first_stage and is_last_stage: 139 | return 140 | # Only first and last stage pipeline stages need to be involved. 141 | if is_last_stage or is_first_stage: 142 | _is_cuda(tensor) 143 | is_contiguous = tensor.is_contiguous() 144 | src = mpu.get_pipeline_model_parallel_last_rank() 145 | group = mpu.get_embedding_group() 146 | if is_contiguous: 147 | tensor_ = tensor 148 | else: 149 | if is_last_stage: 150 | tensor_ = tensor.contiguous() 151 | else: 152 | tensor_ = torch.empty(size, dtype=dtype, device=torch.cuda.current_device()) 153 | # Broadcast from last stage into the first stage. 154 | torch.distributed.broadcast(tensor_, src, group) 155 | # Update the first stage tensor 156 | if is_first_stage and not is_contiguous: 157 | tensor[...] = tensor_ 158 | 159 | 160 | def broadcast_tensor(size, dtype, tensor=None, rank=0, device=0): 161 | """Given size and type of a tensor on all ranks and the tensor value 162 | only on a specific rank, broadcast from that rank to all other ranks. 
163 | Args: 164 | size: size of the tensor 165 | dtype: type of the tensor 166 | tensor: tensor to be broadcasted 167 | rank: primary rank for broadcasting 168 | device: device of the tensor. If not set to None, then we use cuda.current_device(). 169 | Default is 0, since we use cuda.current_device() to get the device. 170 | """ 171 | if device is not None: 172 | device = torch.cuda.current_device() 173 | if torch.distributed.get_rank() == rank: 174 | if device is not None: 175 | _is_cuda_contiguous(tensor) 176 | _check_shapes(tensor, size, dtype) 177 | else: 178 | tensor = torch.empty(size, dtype=dtype, device=device) 179 | torch.distributed.broadcast(tensor, rank) 180 | return tensor 181 | 182 | 183 | def broadcast_list(size, dtype, list_values=None, rank=0, device=0): 184 | """Broadcast a list of values with a given type. 185 | Args: 186 | size: size of the list 187 | dtype: dtype of the list 188 | list_values: list of values to be broadcasted 189 | rank: primary rank for broadcasting 190 | device: device of the tensor. If not set to None, then we use cuda.current_device(). 191 | Default is 0, since we use cuda.current_device() to get the device. 192 | """ 193 | tensor = None 194 | if device is not None: 195 | device = torch.cuda.current_device() 196 | if torch.distributed.get_rank() == rank: 197 | tensor = torch.tensor(list_values, dtype=dtype, device=device) 198 | return broadcast_tensor(size, dtype, tensor=tensor, rank=rank, device=device) 199 | 200 | 201 | def broadcast_int_list(size, int_list=None, rank=0, device=0): 202 | """Broadcast a list of interger values. 203 | Args: 204 | size: size of the list 205 | int_list: list of values to be broadcasted 206 | rank: primary rank for broadcasting 207 | device: device of the tensor. If not set to None, then we use cuda.current_device(). 208 | Default is 0, since we use cuda.current_device() to get the device. 209 | """ 210 | if device is not None: 211 | device = torch.cuda.current_device() 212 | return broadcast_list(size, torch.int64, list_values=int_list, rank=rank, device=device) 213 | 214 | 215 | def broadcast_float_list(size, float_list=None, rank=0, device=0): 216 | """Broadcast a list of float values. 217 | Args: 218 | size: size of the list 219 | float_list: list of values to be broadcasted 220 | rank: primary rank for broadcasting 221 | device: device of the tensor. If not set to None, then we use cuda.current_device(). 222 | Default is 0, since we use cuda.current_device() to get the device. 223 | """ 224 | if device is not None: 225 | device = torch.cuda.current_device() 226 | return broadcast_list(size, torch.float32, list_values=float_list, rank=rank, device=device) 227 | -------------------------------------------------------------------------------- /megatron/mpu/cross_entropy.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) 2023 ADEPT AI LABS INC. 3 | # This file is based on code by the authors denoted below and has been modified from its original version. 4 | # 5 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. 6 | # 7 | # Licensed under the Apache License, Version 2.0 (the "License"); 8 | # you may not use this file except in compliance with the License. 
9 | # You may obtain a copy of the License at 10 | # 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | # 13 | # Unless required by applicable law or agreed to in writing, software 14 | # distributed under the License is distributed on an "AS IS" BASIS, 15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 16 | # See the License for the specific language governing permissions and 17 | # limitations under the License. 18 | 19 | 20 | import torch 21 | 22 | from .initialize import ( 23 | get_tensor_model_parallel_group, 24 | get_tensor_model_parallel_rank, 25 | get_tensor_model_parallel_world_size, 26 | ) 27 | from .utils import VocabUtility 28 | 29 | 30 | class _VocabParallelCrossEntropy(torch.autograd.Function): 31 | @staticmethod 32 | def forward(ctx, vocab_parallel_logits, target): 33 | 34 | # Maximum value along vocab dimension across all GPUs. 35 | logits_max = torch.max(vocab_parallel_logits, dim=-1)[0] 36 | torch.distributed.all_reduce( 37 | logits_max, op=torch.distributed.ReduceOp.MAX, group=get_tensor_model_parallel_group() 38 | ) 39 | # Subtract the maximum value. 40 | vocab_parallel_logits.sub_(logits_max.unsqueeze(dim=-1)) 41 | 42 | # Get the partition's vocab indecies 43 | get_vocab_range = VocabUtility.vocab_range_from_per_partition_vocab_size 44 | partition_vocab_size = vocab_parallel_logits.size()[-1] 45 | rank = get_tensor_model_parallel_rank() 46 | world_size = get_tensor_model_parallel_world_size() 47 | vocab_start_index, vocab_end_index = get_vocab_range(partition_vocab_size, rank, world_size) 48 | 49 | # Create a mask of valid vocab ids (1 means it needs to be masked). 50 | target_mask = (target < vocab_start_index) | (target >= vocab_end_index) 51 | masked_target = target.clone() - vocab_start_index 52 | masked_target[target_mask] = 0 53 | 54 | # Get predicted-logits = logits[target]. 55 | # For Simplicity, we convert logits to a 2-D tensor with size 56 | # [*, partition-vocab-size] and target to a 1-D tensor of size [*]. 57 | logits_2d = vocab_parallel_logits.view(-1, partition_vocab_size) 58 | masked_target_1d = masked_target.view(-1) 59 | arange_1d = torch.arange(start=0, end=logits_2d.size()[0], device=logits_2d.device) 60 | predicted_logits_1d = logits_2d[arange_1d, masked_target_1d] 61 | predicted_logits_1d = predicted_logits_1d.clone().contiguous() 62 | predicted_logits = predicted_logits_1d.view_as(target) 63 | predicted_logits[target_mask] = 0.0 64 | # All reduce is needed to get the chunks from other GPUs. 65 | torch.distributed.all_reduce( 66 | predicted_logits, op=torch.distributed.ReduceOp.SUM, group=get_tensor_model_parallel_group() 67 | ) 68 | 69 | # Sum of exponential of logits along vocab dimension across all GPUs. 70 | exp_logits = vocab_parallel_logits 71 | torch.exp(vocab_parallel_logits, out=exp_logits) 72 | sum_exp_logits = exp_logits.sum(dim=-1) 73 | torch.distributed.all_reduce( 74 | sum_exp_logits, op=torch.distributed.ReduceOp.SUM, group=get_tensor_model_parallel_group() 75 | ) 76 | 77 | # Loss = log(sum(exp(logits))) - predicted-logit. 78 | loss = torch.log(sum_exp_logits) - predicted_logits 79 | 80 | # Store softmax, target-mask and masked-target for backward pass. 81 | exp_logits.div_(sum_exp_logits.unsqueeze(dim=-1)) 82 | ctx.save_for_backward(exp_logits, target_mask, masked_target_1d) 83 | 84 | return loss 85 | 86 | @staticmethod 87 | def backward(ctx, grad_output): 88 | 89 | # Retreive tensors from the forward path. 
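        # Note: the gradient of the loss w.r.t. the logits is softmax(logits) - one_hot(target).
        # `softmax` was saved in the forward pass; the one-hot subtraction below is applied only
        # where the target falls in this rank's vocab partition (target_mask == 0), and the result
        # is scaled by the incoming grad_output.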
90 | softmax, target_mask, masked_target_1d = ctx.saved_tensors 91 | 92 | # All the inputs have softmax as thier gradient. 93 | grad_input = softmax 94 | # For simplicity, work with the 2D gradient. 95 | partition_vocab_size = softmax.size()[-1] 96 | grad_2d = grad_input.view(-1, partition_vocab_size) 97 | 98 | # Add the gradient from matching classes. 99 | arange_1d = torch.arange(start=0, end=grad_2d.size()[0], device=grad_2d.device) 100 | grad_2d[arange_1d, masked_target_1d] -= 1.0 - target_mask.view(-1).float() 101 | 102 | # Finally elementwise multiplication with the output gradients. 103 | grad_input.mul_(grad_output.unsqueeze(dim=-1)) 104 | 105 | return grad_input, None 106 | 107 | 108 | def vocab_parallel_cross_entropy(vocab_parallel_logits, target): 109 | """Helper function for the cross entropy.""" 110 | return _VocabParallelCrossEntropy.apply(vocab_parallel_logits, target) 111 | -------------------------------------------------------------------------------- /megatron/mpu/cross_entropy_parallel.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 ADEPT AI LABS INC. 2 | # This file is based on code by the authors denoted below and has been modified from its original version. 3 | # 4 | # Taken from https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/losses/cross_entropy_parallel.py 5 | # But we import from Megatron instead of from Apex 6 | # HazyResearch is licensed under BSD 3.0 7 | 8 | # Inspired by https://github.com/NVIDIA/apex/blob/master/apex/transformer/tensor_parallel/cross_entropy.py 9 | # But we make it much faster: we compute the local loss and the LSE, and by exchanging the LSE and 10 | # the losses we can get the global loss. There's no need to do it step by step 11 | # (compute local max, exchange, compute exp, compute local sum, exchange, etc.) 12 | import torch 13 | import torch.nn as nn 14 | 15 | import xentropy_cuda_lib 16 | 17 | from .initialize import get_tensor_model_parallel_group 18 | from .initialize import get_tensor_model_parallel_rank 19 | from .initialize import get_tensor_model_parallel_world_size 20 | from .utils import VocabUtility 21 | 22 | # `all_gather_into_tensor` and `reduce_scatter_tensor` are new placeholders for 23 | # `_all_gather_base` and `_reduce_scatter_base`. They require the most recent 24 | # version of PyTorch. The following 4 lines are for backward compatibility with 25 | # older PyTorch. 
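# Note: on older PyTorch builds the public names are simply aliased to the private
# `_all_gather_base` / `_reduce_scatter_base` collectives.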
26 | if "all_gather_into_tensor" not in dir(torch.distributed): 27 | torch.distributed.all_gather_into_tensor = torch.distributed._all_gather_base 28 | if "reduce_scatter_tensor" not in dir(torch.distributed): 29 | torch.distributed.reduce_scatter_tensor = torch.distributed._reduce_scatter_base 30 | 31 | 32 | class SoftmaxCrossEntropyLossParallelFn(torch.autograd.Function): 33 | @staticmethod 34 | def forward(ctx, logits_parallel, labels, smoothing=0.0, ignored_index=-100, inplace_backward=False): 35 | """ 36 | logits_parallel: (batch, vocab_size / world_size) 37 | labels: (batch,) 38 | """ 39 | batch, partition_vocab_size = logits_parallel.shape 40 | assert labels.shape == (batch,) 41 | rank = get_tensor_model_parallel_rank() 42 | world_size = get_tensor_model_parallel_world_size() 43 | 44 | if world_size == 1: 45 | losses, lse = xentropy_cuda_lib.forward(logits_parallel, labels, smoothing) 46 | losses.masked_fill_(labels == ignored_index, 0) 47 | labels_local = labels 48 | else: 49 | vocab_start_index, vocab_end_index = VocabUtility.vocab_range_from_per_partition_vocab_size( 50 | partition_vocab_size, get_tensor_model_parallel_rank(), get_tensor_model_parallel_world_size() 51 | ) 52 | 53 | # Create a mask of valid vocab ids (1 means it needs to be masked). 54 | labels_mask = (labels < vocab_start_index) | (labels >= vocab_end_index) 55 | ignored_mask = labels == ignored_index 56 | labels_local = torch.where(ignored_mask, labels, labels - vocab_start_index) 57 | masked_labels = labels_local.clone() 58 | masked_labels[labels_mask] = ignored_index 59 | 60 | losses, lse_local = xentropy_cuda_lib.forward(logits_parallel, masked_labels, smoothing) 61 | assert lse_local.shape == (batch,) 62 | assert losses.shape == (batch,) 63 | losses.masked_fill_(masked_labels == ignored_index, 0) 64 | 65 | lse_allgather = torch.empty(world_size, batch, dtype=lse_local.dtype, device=lse_local.device) 66 | handle_lse = torch.distributed.all_gather_into_tensor( 67 | lse_allgather, lse_local.contiguous(), group=get_tensor_model_parallel_group(), async_op=True 68 | ) 69 | handle_losses = torch.distributed.all_reduce( 70 | losses, op=torch.distributed.ReduceOp.SUM, group=get_tensor_model_parallel_group(), async_op=True 71 | ) 72 | handle_lse.wait() 73 | lse = torch.logsumexp(lse_allgather, dim=0) 74 | # The losses are going to be lse_local - predicted_logit, we just have to subtract 75 | # the lse_local and add the lse (global). 
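            # Note: each label's owning rank is labels // partition_vocab_size, so lse_local is
            # re-read per sample from the rank that actually held that label's logit; the final
            # correction below is losses += lse_global - lse_local.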
76 | rank_per_sample = torch.div(labels, partition_vocab_size, rounding_mode="floor") 77 | lse_local = lse_allgather[rank_per_sample, torch.arange(batch, device=lse_allgather.device)] 78 | 79 | handle_losses.wait() 80 | losses += lse - lse_local 81 | losses.masked_fill_(ignored_mask, 0) 82 | 83 | ctx.save_for_backward(logits_parallel, lse, labels_local) 84 | ctx.smoothing = smoothing 85 | ctx.ignored_index = ignored_index 86 | ctx.inplace_backward = inplace_backward 87 | return losses 88 | 89 | @staticmethod 90 | def backward(ctx, grad_loss): 91 | logits_parallel, lse, labels = ctx.saved_tensors 92 | if not grad_loss.is_contiguous(): 93 | grad_loss = grad_loss.contiguous() 94 | grad_loss.masked_fill_(labels == ctx.ignored_index, 0) 95 | grad_logits = xentropy_cuda_lib.backward( 96 | grad_loss, logits_parallel, lse, labels, ctx.smoothing, ctx.inplace_backward 97 | ) 98 | return grad_logits, None, None, None, None, None 99 | 100 | 101 | class CrossEntropyLossParallel(nn.Module): 102 | def __init__(self, ignore_index=-100, reduction="mean", label_smoothing=0.0, inplace_backward=False): 103 | super().__init__() 104 | if reduction not in ["mean", "none"]: 105 | raise NotImplementedError("Only support reduction = 'mean' or 'none'") 106 | self.ignore_index = ignore_index 107 | self.reduction = reduction 108 | self.label_smoothing = label_smoothing 109 | self.inplace_backward = inplace_backward 110 | 111 | def forward(self, input, target): 112 | assert input.is_cuda and target.is_cuda 113 | # SoftmaxCrossEntropyLoss implicitly casts to float 114 | loss = SoftmaxCrossEntropyLossParallelFn.apply( 115 | input, target, self.label_smoothing, self.ignore_index, self.inplace_backward 116 | ) 117 | if self.reduction == "mean": 118 | return loss.sum() / (target != self.ignore_index).sum() 119 | else: 120 | return loss 121 | -------------------------------------------------------------------------------- /megatron/mpu/mappings.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) 2023 ADEPT AI LABS INC. 3 | # This file is based on code by the authors denoted below and has been modified from its original version. 4 | # 5 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. 6 | # 7 | # Licensed under the Apache License, Version 2.0 (the "License"); 8 | # you may not use this file except in compliance with the License. 9 | # You may obtain a copy of the License at 10 | # 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | # 13 | # Unless required by applicable law or agreed to in writing, software 14 | # distributed under the License is distributed on an "AS IS" BASIS, 15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 16 | # See the License for the specific language governing permissions and 17 | # limitations under the License. 18 | 19 | import torch 20 | 21 | from .initialize import ( 22 | get_tensor_model_parallel_group, 23 | get_tensor_model_parallel_rank, 24 | get_tensor_model_parallel_world_size, 25 | ) 26 | from .utils import split_tensor_along_last_dim 27 | 28 | 29 | def _reduce(input_): 30 | """All-reduce the input tensor across model parallel group.""" 31 | 32 | # Bypass the function if we are using only 1 GPU. 33 | if get_tensor_model_parallel_world_size() == 1: 34 | return input_ 35 | 36 | # All-reduce, requires contiguous tensor. 
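    # Note: all_reduce operates in place, so the contiguous tensor is returned directly afterwards.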
37 | input_ = input_.contiguous() 38 | torch.distributed.all_reduce(input_, group=get_tensor_model_parallel_group()) 39 | 40 | return input_ 41 | 42 | 43 | def _split_along_last_dim(input_): 44 | """Split the tensor along its last dimension and keep the 45 | corresponding slice.""" 46 | 47 | world_size = get_tensor_model_parallel_world_size() 48 | # Bypass the function if we are using only 1 GPU. 49 | if world_size == 1: 50 | return input_ 51 | 52 | # Split along last dimension. 53 | input_list = split_tensor_along_last_dim(input_, world_size) 54 | 55 | # Note: torch.split does not create contiguous tensors by default. 56 | rank = get_tensor_model_parallel_rank() 57 | output = input_list[rank].contiguous() 58 | 59 | return output 60 | 61 | 62 | def _split_along_first_dim(input_): 63 | """Split the tensor along its first dimension and keep the 64 | corresponding slice.""" 65 | 66 | world_size = get_tensor_model_parallel_world_size() 67 | # Bypass the function if we are using only 1 GPU. 68 | if world_size == 1: 69 | return input_ 70 | 71 | # Split along first dimension. 72 | dim_size = input_.size()[0] 73 | assert ( 74 | dim_size % world_size == 0 75 | ), f"First dimension of the tensor should be divisible by tensor parallel size. {dim_size} % {world_size} != 0" 76 | local_dim_size = dim_size // world_size 77 | rank = get_tensor_model_parallel_rank() 78 | dim_offset = rank * local_dim_size 79 | 80 | output = input_[dim_offset : dim_offset + local_dim_size].contiguous() 81 | 82 | return output 83 | 84 | 85 | def _gather_along_last_dim(input_): 86 | """Gather tensors and concatinate along the last dimension.""" 87 | 88 | world_size = get_tensor_model_parallel_world_size() 89 | # Bypass the function if we are using only 1 GPU. 90 | if world_size == 1: 91 | return input_ 92 | 93 | # Size and dimension. 94 | last_dim = input_.dim() - 1 95 | rank = get_tensor_model_parallel_rank() 96 | 97 | tensor_list = [torch.empty_like(input_) for _ in range(world_size)] 98 | tensor_list[rank] = input_ 99 | torch.distributed.all_gather(tensor_list, input_, group=get_tensor_model_parallel_group()) 100 | 101 | # Note: torch.cat already creates a contiguous tensor. 102 | output = torch.cat(tensor_list, dim=last_dim).contiguous() 103 | 104 | return output 105 | 106 | 107 | def _gather_along_first_dim(input_): 108 | """Gather tensors and concatinate along the first dimension.""" 109 | 110 | world_size = get_tensor_model_parallel_world_size() 111 | # Bypass the function if we are using only 1 GPU. 112 | if world_size == 1: 113 | return input_ 114 | 115 | dim_size = list(input_.size()) 116 | dim_size[0] = dim_size[0] * world_size 117 | 118 | output = torch.empty(dim_size, dtype=input_.dtype, device=torch.cuda.current_device()) 119 | torch.distributed._all_gather_base(output, input_.contiguous(), group=get_tensor_model_parallel_group()) 120 | 121 | return output 122 | 123 | 124 | def _reduce_scatter_along_first_dim(input_): 125 | """Reduce-scatter the input tensor across model parallel group.""" 126 | world_size = get_tensor_model_parallel_world_size() 127 | # Bypass the function if we are using only 1 GPU. 
128 | if world_size == 1: 129 | return input_ 130 | 131 | dim_size = list(input_.size()) 132 | assert dim_size[0] % world_size == 0, "First dimension of the tensor should be divisible by tensor parallel size" 133 | 134 | dim_size[0] = dim_size[0] // world_size 135 | 136 | output = torch.empty(dim_size, dtype=input_.dtype, device=torch.cuda.current_device()) 137 | torch.distributed._reduce_scatter_base(output, input_.contiguous(), group=get_tensor_model_parallel_group()) 138 | return output 139 | 140 | 141 | class _ReduceBackwardFromTensorModelParallelRegion(torch.autograd.Function): 142 | """All-reduce the backward pass input from the model parallel region.""" 143 | 144 | @staticmethod 145 | def symbolic(graph, input_): 146 | return input_ 147 | 148 | @staticmethod 149 | def forward(ctx, input_): 150 | return input_ 151 | 152 | @staticmethod 153 | def backward(ctx, grad_output): 154 | return _reduce(grad_output) 155 | 156 | 157 | class _ReduceForwardFromTensorModelParallelRegion(torch.autograd.Function): 158 | """All-reduce the input from the model parallel region.""" 159 | 160 | @staticmethod 161 | def symbolic(graph, input_): 162 | return _reduce(input_) 163 | 164 | @staticmethod 165 | def forward(ctx, input_): 166 | return _reduce(input_) 167 | 168 | @staticmethod 169 | def backward(ctx, grad_output): 170 | return grad_output 171 | 172 | 173 | class _ScatterToModelParallelRegion(torch.autograd.Function): 174 | """Split the input and keep only the corresponding chuck to the rank.""" 175 | 176 | @staticmethod 177 | def symbolic(graph, input_): 178 | return _split_along_last_dim(input_) 179 | 180 | @staticmethod 181 | def forward(ctx, input_): 182 | return _split_along_last_dim(input_) 183 | 184 | @staticmethod 185 | def backward(ctx, grad_output): 186 | return _gather_along_last_dim(grad_output) 187 | 188 | 189 | class _GatherFromModelParallelRegion(torch.autograd.Function): 190 | """Gather the input from model parallel region and concatinate.""" 191 | 192 | @staticmethod 193 | def symbolic(graph, input_): 194 | return _gather_along_last_dim(input_) 195 | 196 | @staticmethod 197 | def forward(ctx, input_): 198 | return _gather_along_last_dim(input_) 199 | 200 | @staticmethod 201 | def backward(ctx, grad_output): 202 | return _split_along_last_dim(grad_output) 203 | 204 | 205 | class _ScatterToSequenceParallelRegion(torch.autograd.Function): 206 | """Split the input and keep only the corresponding chuck to the rank.""" 207 | 208 | @staticmethod 209 | def symbolic(graph, input_): 210 | return _split_along_first_dim(input_) 211 | 212 | @staticmethod 213 | def forward(ctx, input_): 214 | return _split_along_first_dim(input_) 215 | 216 | @staticmethod 217 | def backward(ctx, grad_output): 218 | return _gather_along_first_dim(grad_output) 219 | 220 | 221 | class _GatherFromSequenceParallelRegion(torch.autograd.Function): 222 | """Gather the input from sequence parallel region and concatinate.""" 223 | 224 | @staticmethod 225 | def symbolic(graph, input_, tensor_parallel_output_grad=True): 226 | return _gather_along_first_dim(input_) 227 | 228 | @staticmethod 229 | def forward(ctx, input_, tensor_parallel_output_grad=True): 230 | ctx.tensor_parallel_output_grad = tensor_parallel_output_grad 231 | return _gather_along_first_dim(input_) 232 | 233 | @staticmethod 234 | def backward(ctx, grad_output): 235 | tensor_parallel_output_grad = ctx.tensor_parallel_output_grad 236 | 237 | # If the computation graph after the gather operation is 238 | # in the tensor parallel mode, output gradients need to 
reduce 239 | # scattered and whereas if the computation is duplicated, 240 | # output gradients need to be scattered. 241 | if tensor_parallel_output_grad: 242 | return _reduce_scatter_along_first_dim(grad_output), None 243 | else: 244 | return _split_along_first_dim(grad_output), None 245 | 246 | 247 | class _ReduceScatterToSequenceParallelRegion(torch.autograd.Function): 248 | """Reduce scatter the input from the model parallel region.""" 249 | 250 | @staticmethod 251 | def symbolic(graph, input_): 252 | return _reduce_scatter_along_first_dim(input_) 253 | 254 | @staticmethod 255 | def forward(ctx, input_): 256 | return _reduce_scatter_along_first_dim(input_) 257 | 258 | @staticmethod 259 | def backward(ctx, grad_output): 260 | return _gather_along_first_dim(grad_output) 261 | 262 | 263 | # ----------------- 264 | # Helper functions. 265 | # ----------------- 266 | 267 | 268 | def reduce_backward_from_tensor_model_parallel_region(input_): 269 | return _ReduceBackwardFromTensorModelParallelRegion.apply(input_) 270 | 271 | 272 | def reduce_forward_from_tensor_model_parallel_region(input_): 273 | return _ReduceForwardFromTensorModelParallelRegion.apply(input_) 274 | 275 | 276 | def scatter_to_tensor_model_parallel_region(input_): 277 | return _ScatterToModelParallelRegion.apply(input_) 278 | 279 | 280 | def gather_from_tensor_model_parallel_region(input_): 281 | return _GatherFromModelParallelRegion.apply(input_) 282 | 283 | 284 | def scatter_to_sequence_parallel_region(input_): 285 | return _ScatterToSequenceParallelRegion.apply(input_) 286 | 287 | 288 | def gather_from_sequence_parallel_region(input_, tensor_parallel_output_grad=True): 289 | return _GatherFromSequenceParallelRegion.apply(input_, tensor_parallel_output_grad) 290 | 291 | 292 | def reduce_scatter_to_sequence_parallel_region(input_): 293 | return _ReduceScatterToSequenceParallelRegion.apply(input_) 294 | -------------------------------------------------------------------------------- /megatron/mpu/utils.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) 2023 ADEPT AI LABS INC. 3 | # This file is based on code by the authors denoted below and has been modified from its original version. 4 | # 5 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. 6 | # 7 | # Licensed under the Apache License, Version 2.0 (the "License"); 8 | # you may not use this file except in compliance with the License. 9 | # You may obtain a copy of the License at 10 | # 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | # 13 | # Unless required by applicable law or agreed to in writing, software 14 | # distributed under the License is distributed on an "AS IS" BASIS, 15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 16 | # See the License for the specific language governing permissions and 17 | # limitations under the License. 
18 | 19 | 20 | import torch 21 | 22 | 23 | def ensure_divisibility(numerator, denominator): 24 | """Ensure that numerator is divisible by the denominator.""" 25 | assert numerator % denominator == 0, "{} is not divisible by {}".format(numerator, denominator) 26 | 27 | 28 | def divide(numerator, denominator): 29 | """Ensure that numerator is divisible by the denominator and return 30 | the division value.""" 31 | ensure_divisibility(numerator, denominator) 32 | return numerator // denominator 33 | 34 | 35 | def split_tensor_along_last_dim(tensor, num_partitions, contiguous_split_chunks=False): 36 | """Split a tensor along its last dimension. 37 | Arguments: 38 | tensor: input tensor. 39 | num_partitions: number of partitions to split the tensor 40 | contiguous_split_chunks: If True, make each chunk contiguous 41 | in memory. 42 | """ 43 | # Get the size and dimension. 44 | last_dim = tensor.dim() - 1 45 | last_dim_size = divide(tensor.size()[last_dim], num_partitions) 46 | # Split. 47 | tensor_list = torch.split(tensor, last_dim_size, dim=last_dim) 48 | # Note: torch.split does not create contiguous tensors by default. 49 | if contiguous_split_chunks: 50 | return tuple(chunk.contiguous() for chunk in tensor_list) 51 | 52 | return tensor_list 53 | 54 | 55 | class VocabUtility: 56 | """Split the vocabulary into `world_size` chunks amd return the 57 | first and last index of the vocabulary belonging to the `rank` 58 | partition: Note that indecies in [fist, last)""" 59 | 60 | @staticmethod 61 | def vocab_range_from_per_partition_vocab_size(per_partition_vocab_size, rank, world_size): 62 | index_f = rank * per_partition_vocab_size 63 | index_l = index_f + per_partition_vocab_size 64 | return index_f, index_l 65 | 66 | @staticmethod 67 | def vocab_range_from_global_vocab_size(global_vocab_size, rank, world_size): 68 | per_partition_vocab_size = divide(global_vocab_size, world_size) 69 | return VocabUtility.vocab_range_from_per_partition_vocab_size(per_partition_vocab_size, rank, world_size) 70 | -------------------------------------------------------------------------------- /megatron/text_generation/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) 2023 ADEPT AI LABS INC. 3 | # This file is based on code by the authors denoted below and has been modified from its original version. 4 | # 5 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. 6 | # 7 | # Licensed under the Apache License, Version 2.0 (the "License"); 8 | # you may not use this file except in compliance with the License. 9 | # You may obtain a copy of the License at 10 | # 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | # 13 | # Unless required by applicable law or agreed to in writing, software 14 | # distributed under the License is distributed on an "AS IS" BASIS, 15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 16 | # See the License for the specific language governing permissions and 17 | # limitations under the License. 18 | 19 | 20 | from .api import generate, generate_and_post_process 21 | -------------------------------------------------------------------------------- /megatron/text_generation/api.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) 2023 ADEPT AI LABS INC. 3 | # This file is based on code by the authors denoted below and has been modified from its original version. 
4 | # 5 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. 6 | # 7 | # Licensed under the Apache License, Version 2.0 (the "License"); 8 | # you may not use this file except in compliance with the License. 9 | # You may obtain a copy of the License at 10 | # 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | # 13 | # Unless required by applicable law or agreed to in writing, software 14 | # distributed under the License is distributed on an "AS IS" BASIS, 15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 16 | # See the License for the specific language governing permissions and 17 | # limitations under the License. 18 | 19 | """Inference API.""" 20 | 21 | 22 | import base64 23 | from typing import List, Optional, Tuple, Any, Dict 24 | 25 | import numpy as np 26 | import torch 27 | from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP 28 | 29 | from megatron.model.module import MegatronModule 30 | from megatron import mpu 31 | from megatron.model import DistributedDataParallel as LocalDDP 32 | from megatron.model import Float16Module 33 | 34 | from megatron.tokenizer.tokenizer import AbstractTokenizer 35 | from megatron.mpu.communication import broadcast_float_list, broadcast_int_list 36 | from megatron.text_generation.generation import ( 37 | generate_tokens_probs_and_return_on_first_stage, 38 | score_and_return_on_first_stage, 39 | ) 40 | from megatron.text_generation.inference_params import InferenceParams 41 | from megatron.text_generation.tokenization import ( 42 | convert_generations_to_human_readable_tokens, 43 | tokenize_prompts, 44 | ) 45 | from megatron.utils import unwrap_model 46 | import numpy.typing as npt 47 | 48 | def preprocess_prompts(prompts: Optional[List[List[str]]]) -> Optional[List[List[str]]]: 49 | """ 50 | Accepts a list of list of subprompts, returns a list of list of processed subprompts. 51 | """ 52 | if prompts is None: 53 | return None 54 | processed_prompts = [] 55 | for prompt in prompts: 56 | processed_subprompts = [] 57 | for subprompt in prompt: 58 | processed_subprompts.append(f"human: {subprompt.strip()}\n\nadept:") 59 | processed_prompts.append(processed_subprompts) 60 | return processed_prompts 61 | 62 | 63 | def generate_and_post_process( 64 | model: MegatronModule, 65 | params_dtype: torch.dtype, 66 | max_position_embeddings: int, 67 | termination_id: int, 68 | tokenizer: AbstractTokenizer, 69 | prompts: Optional[List[List[str]]] = None, 70 | max_tokens_to_generate: int = 0, 71 | inference_params: Optional[InferenceParams] = None, 72 | return_output_log_probs: bool = False, 73 | return_all_log_probs: bool = False, 74 | log_prob_tokens: Optional[torch.Tensor] = None, 75 | top_k_sampling: int = 0, 76 | top_p_sampling: float = 0.0, 77 | temperature: float = 1.0, 78 | add_BOS: bool = False, 79 | random_seed: int = -1, 80 | process_prompts_for_chat: bool = False 81 | ) -> Optional[ 82 | Tuple[ 83 | List[str], 84 | Optional[Any], 85 | List[List[int]], 86 | List[str], 87 | Optional[str], 88 | Optional[Any], 89 | List[List[str]], 90 | Optional[Any], 91 | ] 92 | ]: 93 | """Run inference and post-process outputs, i.e., detokenize, 94 | move to cpu and convert to list. 95 | 96 | prompts: a list of list of strings, where each element in the outer list represents a single sample that 97 | is an item in the batch. A single sample is represented as a list of strings. 98 | """ 99 | # Pre-process the prompts. 
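    # Note: when process_prompts_for_chat is set, each subprompt is wrapped into the chat template
    # "human: <subprompt>\n\nadept:" (see preprocess_prompts above) before tokenization.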
100 | if process_prompts_for_chat: 101 | prompts = preprocess_prompts(prompts) 102 | 103 | # Main inference. 104 | outputs = generate( 105 | model, 106 | max_position_embeddings=max_position_embeddings, 107 | params_dtype=params_dtype, 108 | termination_id=termination_id, 109 | tokenizer=tokenizer, 110 | prompts=prompts, 111 | max_tokens_to_generate=max_tokens_to_generate, 112 | inference_params=inference_params, 113 | return_output_log_probs=return_output_log_probs, 114 | return_all_log_probs=return_all_log_probs, 115 | log_prob_tokens=log_prob_tokens, 116 | top_k_sampling=top_k_sampling, 117 | top_p_sampling=top_p_sampling, 118 | temperature=temperature, 119 | add_BOS=add_BOS, 120 | random_seed=random_seed, 121 | ) 122 | 123 | # Only post-process on first stage. 124 | if mpu.is_pipeline_first_stage(): 125 | 126 | all_tokens = outputs["tokens"].cpu().numpy().tolist() 127 | 128 | raw_prompts_plus_generations = [tokenizer.detokenize(ts) for ts in all_tokens] 129 | processed_generations = [ 130 | tokenizer.detokenize(toks) for toks in outputs.get("generated_tokens", []) 131 | ] 132 | 133 | processed_generated_tokens = [ 134 | tokenizer.encode(g) for g in processed_generations 135 | ] 136 | 137 | human_readable_tokens = convert_generations_to_human_readable_tokens( 138 | processed_generated_tokens, "" 139 | ) 140 | 141 | output_log_probs = None 142 | if return_output_log_probs: 143 | output_log_probs = outputs["output_log_probs"].cpu().numpy().tolist() 144 | 145 | all_log_probs = None 146 | if return_all_log_probs: 147 | all_log_probs = outputs["all_log_probs"].cpu().numpy().tolist() 148 | 149 | return ( 150 | raw_prompts_plus_generations, 151 | output_log_probs, 152 | all_tokens, 153 | processed_generations, 154 | all_log_probs, 155 | human_readable_tokens, 156 | ) 157 | 158 | return None 159 | 160 | 161 | def generate( 162 | model: MegatronModule, 163 | max_position_embeddings: int, 164 | params_dtype: torch.dtype, 165 | termination_id: int, 166 | tokenizer: AbstractTokenizer, 167 | prompts: Optional[List[List[str]]] = None, 168 | max_tokens_to_generate: int = 0, 169 | inference_params: Optional[InferenceParams] = None, 170 | return_output_log_probs: bool = False, 171 | return_all_log_probs: bool = False, 172 | log_prob_tokens: Optional[torch.Tensor] = None, 173 | top_k_sampling: int = 0, 174 | top_p_sampling: float = 0.0, 175 | temperature: float = 1.0, 176 | add_BOS: bool = False, 177 | random_seed: int = -1, 178 | ) -> Dict[str, Any]: 179 | """Given prompts and input parameters, run inference and return: 180 | tokens: prompts plus the generated tokens. 181 | lengths: length of the prompt + generations. Note that we can 182 | discard tokens in the tokens tensor that are after the 183 | corresponding length. 184 | output_log_probs: log probs of the tokens. 185 | all_log_probs: log probs of all the vocab (or just log_prob_tokens if provided) 186 | tokens for each generated token position. 187 | """ 188 | num_log_prob_tokens = 0 if log_prob_tokens is None else len(log_prob_tokens) 189 | 190 | # Make sure input params are avaialble to all ranks. 
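    # Note: the scalar sampling parameters are packed into a single float tensor and broadcast
    # from rank 0, then unpacked below, so every rank runs generation with identical settings.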
191 | values = [ 192 | max_tokens_to_generate, 193 | return_output_log_probs, 194 | top_k_sampling, 195 | top_p_sampling, 196 | temperature, 197 | add_BOS, 198 | random_seed, 199 | return_all_log_probs, 200 | num_log_prob_tokens, 201 | ] 202 | values_float_tensor = broadcast_float_list(len(values), float_list=values) 203 | max_tokens_to_generate = int(values_float_tensor[0].item()) 204 | return_output_log_probs = bool(values_float_tensor[1].item()) 205 | top_k_sampling = int(values_float_tensor[2].item()) 206 | top_p_sampling = values_float_tensor[3].item() 207 | temperature = values_float_tensor[4].item() 208 | add_BOS = bool(values_float_tensor[5].item()) 209 | random_seed = int(values_float_tensor[6].item()) 210 | return_all_log_probs = bool(values_float_tensor[7].item()) 211 | num_log_prob_tokens = int(values_float_tensor[8].item()) 212 | 213 | if return_all_log_probs and num_log_prob_tokens > 0: 214 | # Do another broadcast for the log_prob_tokens. 215 | log_prob_tokens = broadcast_int_list( 216 | num_log_prob_tokens, int_list=log_prob_tokens 217 | ) 218 | else: 219 | log_prob_tokens = None 220 | 221 | if random_seed != -1: 222 | torch.random.manual_seed(random_seed) 223 | 224 | # Tokenize prompts and get the batch. 225 | # Note that these tensors are broadcasted to all ranks. 226 | if torch.distributed.get_rank() == 0: 227 | assert prompts is not None 228 | 229 | context_tokens_tensor, context_length_tensor = tokenize_prompts( 230 | prompts=prompts, 231 | max_tokens_to_generate=max_tokens_to_generate, 232 | max_position_embeddings=max_position_embeddings, 233 | add_BOS=add_BOS, 234 | ) 235 | 236 | batch_size = context_tokens_tensor.shape[0] 237 | num_sub_sequences = context_tokens_tensor.shape[1] 238 | 239 | assert num_sub_sequences == 1 240 | # Remove subsequence dim 241 | context_length_tensor = torch.squeeze(context_length_tensor, dim=1) 242 | context_tokens_tensor = torch.squeeze(context_tokens_tensor, dim=1) 243 | 244 | if max_tokens_to_generate == 0: 245 | assert inference_params is not None 246 | return score_and_return_on_first_stage( 247 | model, 248 | context_tokens_tensor, 249 | context_length_tensor, 250 | inference_params, 251 | max_position_embeddings=max_position_embeddings, 252 | ) 253 | 254 | # Main inference function. 255 | # Note that the outputs are available on the first stage. 256 | assert inference_params is not None 257 | # Added termination_id to support the case that we want to terminate the 258 | # generation once that id is generated. 
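    # (Presumably each sample stops early once termination_id is sampled; otherwise it runs for
    # the full max_tokens_to_generate budget.)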
259 | 260 | outputs = generate_tokens_probs_and_return_on_first_stage( 261 | model, 262 | context_tokens_tensor, 263 | context_length_tensor, 264 | inference_params=inference_params, 265 | max_position_embeddings=max_position_embeddings, 266 | termination_id=termination_id, 267 | vocab_size=tokenizer.vocab_size, 268 | return_output_log_probs=return_output_log_probs, 269 | log_prob_tokens=log_prob_tokens, 270 | return_all_log_probs=return_all_log_probs, 271 | top_k=top_k_sampling, 272 | top_p=top_p_sampling, 273 | temperature=temperature, 274 | ) 275 | # Now we can figure out what actually got generated and return that too 276 | generated_tokens = [] 277 | context_lengths = context_length_tensor.detach().cpu().numpy().tolist() 278 | contexts = context_tokens_tensor.detach().cpu().numpy().tolist() 279 | lengths_cpu = outputs["lengths"].detach().cpu().numpy().tolist() 280 | 281 | for b in range(batch_size): 282 | assert lengths_cpu[b] > context_lengths[b] 283 | gen = contexts[b][context_lengths[b] : lengths_cpu[b]] 284 | generated_tokens.append(gen) 285 | 286 | outputs["generated_tokens"] = generated_tokens 287 | 288 | return outputs 289 | -------------------------------------------------------------------------------- /megatron/text_generation/forward_step.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) 2023 ADEPT AI LABS INC. 3 | # This file is based on code by the authors denoted below and has been modified from its original version. 4 | # 5 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. 6 | # 7 | # Licensed under the Apache License, Version 2.0 (the "License"); 8 | # you may not use this file except in compliance with the License. 9 | # You may obtain a copy of the License at 10 | # 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | # 13 | # Unless required by applicable law or agreed to in writing, software 14 | # distributed under the License is distributed on an "AS IS" BASIS, 15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 16 | # See the License for the specific language governing permissions and 17 | # limitations under the License. 18 | 19 | """Forward step utilities.""" 20 | 21 | from typing import Optional, Dict, Any, List 22 | from collections.abc import Iterable 23 | 24 | from torch import dtype 25 | import torch 26 | from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP 27 | 28 | from megatron import get_args, mpu 29 | from megatron.text_generation.inference_params import InferenceParams 30 | from megatron.model import DistributedDataParallel as LocalDDP 31 | from megatron.model import Float16Module 32 | from megatron.model.module import MegatronModule 33 | from megatron.model.gpt_model import GPTModel 34 | from megatron.utils import unwrap_model 35 | 36 | from megatron.mpu.communication import recv_from_prev_pipeline_rank_, send_to_next_pipeline_rank 37 | 38 | 39 | class ForwardStep: 40 | """Forward step function with all the communications. 41 | We use a class here to hide the inference parameters 42 | from the outside caller.""" 43 | 44 | def __init__( 45 | self, 46 | model: MegatronModule, 47 | max_batch_size: int, 48 | max_sequence_len: int, 49 | inference_params: InferenceParams, 50 | ): 51 | """Set values so we don't need to do it multiple times.""" 52 | # Make sure model is in eval mode. 
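        # Note: an Iterable here would be a list of model chunks from the interleaved pipeline
        # schedule, which this inference path does not support.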
53 | assert not isinstance(model, Iterable), "interleaving schedule is not supported for inference" 54 | model.eval() 55 | self.model = model 56 | # Initialize inference parameters. 57 | self.inference_params = inference_params 58 | # Pipelining arguments. 59 | args = get_args() 60 | self.pipeline_size_larger_than_one = args.pipeline_model_parallel_size > 1 61 | # Threshold of pipelining. 62 | self.pipelining_batch_x_seqlen = args.inference_batch_times_seqlen_threshold 63 | 64 | def __call__( 65 | self, 66 | tokens: torch.Tensor, 67 | position_ids: torch.Tensor, 68 | lm_logits_mask: Optional[torch.Tensor] = None, 69 | ) -> Optional[Dict[str, Any]]: 70 | """Invocation of the forward methods. Note that self.inference_params 71 | is being modified by the forward step.""" 72 | # Pipelining case. 73 | if self.pipeline_size_larger_than_one: 74 | current_batch_x_seqlen = tokens.size(0) * tokens.size(1) 75 | if current_batch_x_seqlen >= self.pipelining_batch_x_seqlen: 76 | raise ValueError("We deleted _with_pipelining_forward_step") 77 | 78 | return _no_pipelining_forward_step( 79 | self.model, 80 | tokens, 81 | position_ids, 82 | self.inference_params, 83 | lm_logits_mask=lm_logits_mask, 84 | ) 85 | 86 | 87 | def _get_recv_buffer_dtype(args: Any) -> dtype: 88 | """Receive happens between the layers.""" 89 | if args.fp32_residual_connection: 90 | return torch.float 91 | return args.params_dtype 92 | 93 | 94 | def _allocate_recv_buffer(batch_size: int, sequence_length: int) -> torch.Tensor: 95 | """Receive happens between the layers with size [s, b, h].""" 96 | if mpu.is_pipeline_first_stage(): 97 | return None 98 | args = get_args() 99 | recv_size = (sequence_length, batch_size, args.hidden_size) 100 | return torch.empty(recv_size, dtype=_get_recv_buffer_dtype(args), device=torch.cuda.current_device()) 101 | 102 | 103 | def _model_forward_step( 104 | model: MegatronModule, 105 | tokens: torch.Tensor, 106 | position_ids: torch.Tensor, 107 | inference_params: InferenceParams, 108 | lm_logits_mask: Optional[torch.Tensor] = None, 109 | ) -> Dict[str, Any]: 110 | # Run a simple forward pass. 111 | unwrapped_model = unwrap_model(model, (torchDDP, LocalDDP, Float16Module)) 112 | outputs: Dict[str, Any] = {} 113 | if isinstance(unwrapped_model, GPTModel): 114 | outputs = model(tokens, position_ids, lm_logits_mask=lm_logits_mask, inference_params=inference_params) 115 | else: 116 | assert False, "Unknown model type!" + str(type(unwrapped_model)) 117 | 118 | if not isinstance(outputs, dict): 119 | outputs = {"logits": outputs} 120 | 121 | return outputs 122 | 123 | 124 | def _forward_step_helper( 125 | model: MegatronModule, 126 | tokens: torch.Tensor, 127 | position_ids: torch.Tensor, 128 | inference_params: InferenceParams, 129 | recv_buffer: Optional[torch.Tensor] = None, 130 | lm_logits_mask: Optional[torch.Tensor] = None, 131 | ) -> Dict[str, Any]: 132 | """Single forward step. Update the allocate memory flag so 133 | only the first time the memory is allocated.""" 134 | batch_size = tokens.size(0) 135 | sequence_length = tokens.size(1) 136 | if recv_buffer is None: 137 | recv_buffer = _allocate_recv_buffer(batch_size, sequence_length) 138 | 139 | # Receive from previous stage. 140 | recv_from_prev_pipeline_rank_(recv_buffer) 141 | 142 | # Forward pass through the model. 
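    # Note: on non-first pipeline stages, recv_buffer holds the activations received above and is
    # wired in as the model's input tensor; on the first stage it is None, so the model presumably
    # starts from its own token embeddings.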
143 | model.set_input_tensor(recv_buffer) 144 | outputs = _model_forward_step( 145 | model, 146 | tokens, 147 | position_ids, 148 | inference_params=inference_params, 149 | lm_logits_mask=lm_logits_mask, 150 | ) 151 | 152 | # Send output to the next stage. 153 | send_to_next_pipeline_rank(outputs) 154 | 155 | return outputs 156 | 157 | 158 | def _no_pipelining_forward_step( 159 | model: MegatronModule, 160 | tokens: torch.Tensor, 161 | position_ids: torch.Tensor, 162 | inference_params: InferenceParams, 163 | recv_buffer: Optional[torch.Tensor] = None, 164 | lm_logits_mask: Optional[torch.Tensor] = None, 165 | ) -> Optional[Dict[str, Any]]: 166 | """If recv_buffer is none, we will allocate one on the fly.""" 167 | # Run a simple forward pass. 168 | outputs = _forward_step_helper( 169 | model, 170 | tokens, 171 | position_ids, 172 | inference_params, 173 | recv_buffer=recv_buffer, 174 | lm_logits_mask=lm_logits_mask, 175 | ) 176 | # Update the sequence length offset. 177 | if inference_params is not None: 178 | inference_params.sequence_len_offset += tokens.size(1) 179 | 180 | lm_output = None 181 | if mpu.is_pipeline_last_stage(): 182 | lm_output = outputs 183 | 184 | return lm_output 185 | -------------------------------------------------------------------------------- /megatron/text_generation/inference_params.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Dict, Tuple, Any 2 | 3 | import torch 4 | from torch import Tensor 5 | 6 | 7 | class InferenceParams: 8 | """Inference parameters that are passed to the main model in order 9 | to efficienly calculate and store the context during inference.""" 10 | 11 | def __init__( 12 | self, 13 | max_batch_size: int, 14 | max_sequence_len: int, 15 | lengths_per_sample: Optional[Tensor] = None, 16 | fused_ft_kernel: bool = False, 17 | ) -> None: 18 | # fused_ft_kernel: whether to use FasterTransformer fused single-query attention kernel. 19 | """Note that offsets are set to zero and we always set the 20 | flag to allocate memory. After the first call, make sure to 21 | set this flag to False.""" 22 | self.max_sequence_len = max_sequence_len 23 | self.max_batch_size = max_batch_size 24 | self.sequence_len_offset = 0 25 | self.batch_size_offset = 0 26 | self.key_value_memory_dict: Dict[int, Tuple[torch.Tensor, torch.Tensor]] = {} 27 | # This is used incase of an encoder-decoder model. The encoder output can be cached 28 | # and reused for multiple decoder forward passes. 29 | self.encoder_hidden_state: Dict[Any, Any] = {} 30 | self.return_encoder_hidden_state = False 31 | self.fused_ft_kernel = fused_ft_kernel 32 | self.lengths_per_sample: Tensor = lengths_per_sample 33 | # Raise import error at initialization time instead of the 1st generation time. 
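        # Illustrative lifecycle of this cache (a hedged sketch with made-up sizes; the
        # real call sites live in the text_generation modules):
        #
        #   params = InferenceParams(max_batch_size=1, max_sequence_len=4096)
        #   # prefill: the whole prompt is processed with sequence_len_offset == 0
        #   # decode:  each forward step advances sequence_len_offset by the number of
        #   #          tokens it processed, so cached keys/values are reused
        #   params.reset()  # before serving the next request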
34 | if fused_ft_kernel: 35 | try: 36 | # pylint: disable=import-outside-toplevel,unused-import 37 | import ft_attention # type: ignore 38 | except ImportError as exc: 39 | raise ImportError("Please install ft_attention from the FlashAttention repo.") from exc 40 | 41 | def reset(self) -> None: 42 | self.sequence_len_offset = 0 43 | self.batch_size_offset = 0 44 | 45 | def swap_key_value_dict(self, batch_idx: Any) -> None: 46 | "swap between batches" 47 | if len(self.key_value_memory_dict) == 0: 48 | raise ValueError("should not swap when dict in empty") 49 | 50 | for layer_number in self.key_value_memory_dict.keys(): # pylint: disable=consider-using-dict-items 51 | inference_key_memory, inference_value_memory = self.key_value_memory_dict[layer_number] 52 | assert len(batch_idx) == inference_key_memory.shape[1] ## make sure batch size is the same 53 | new_inference_key_memory = inference_key_memory[:, batch_idx] 54 | new_inference_value_memory = inference_value_memory[:, batch_idx] 55 | self.key_value_memory_dict[layer_number] = (new_inference_key_memory, new_inference_value_memory) 56 | -------------------------------------------------------------------------------- /megatron/text_generation/sampling.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) 2023 ADEPT AI LABS INC. 3 | # This file is based on code by the authors denoted below and has been modified from its original version. 4 | # 5 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. 6 | # 7 | # Licensed under the Apache License, Version 2.0 (the "License"); 8 | # you may not use this file except in compliance with the License. 9 | # You may obtain a copy of the License at 10 | # 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | # 13 | # Unless required by applicable law or agreed to in writing, software 14 | # distributed under the License is distributed on an "AS IS" BASIS, 15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 16 | # See the License for the specific language governing permissions and 17 | # limitations under the License. 18 | 19 | """Sampling utilities. 20 | Part of this code is inspired by: 21 | - https://github.com/ari-holtzman/degen/blob/master/gen.py 22 | - https://huggingface.co/transformers/_modules/transformers/generation_logits_process.html 23 | """ 24 | 25 | 26 | from typing import Optional 27 | import torch 28 | 29 | 30 | def modify_logits_for_top_k_filtering(logits: torch.Tensor, top_k: int) -> None: 31 | """Set the logits for none top-k values to -inf.""" 32 | 33 | filter_ = logits < torch.topk(logits, top_k)[0][..., -1, None] 34 | logits.masked_fill_(filter_, float("-Inf")) 35 | 36 | 37 | def modify_logits_for_top_p_filtering(logits: torch.Tensor, top_p: float) -> None: 38 | """Set the logits for none top-p values to -inf.""" 39 | 40 | # First sort and calculate cumulative sum of probabilities. 41 | sorted_logits, sorted_indices = torch.sort(logits, descending=True) 42 | cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1) 43 | 44 | # Filteration based on the cumulative sum. 45 | filter_ = cumulative_probs > top_p 46 | # This shift by 1 is weird and I cannot justify it. This existed 47 | # in the original implementation: 48 | # https://github.com/ari-holtzman/degen/blob/master/gen.py 49 | # and I guess it is needed so keeping it for now. 50 | filter_[:, 1:] = filter_[:, :-1].clone() 51 | # Make sure we at least have one token to select from. 
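    # Worked example (illustrative numbers): for sorted probabilities
    # [0.6, 0.25, 0.1, 0.05] and top_p = 0.7, the cumulative sums are
    # [0.6, 0.85, 0.95, 1.0]; after the shift above the filter keeps the first two
    # tokens (the smallest set reaching top_p) and masks the remaining logits to -inf.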
52 | filter_[..., 0] = 0 53 | 54 | # Fill in the filtered part 55 | filter_ = filter_.scatter(1, sorted_indices, filter_) 56 | logits.masked_fill_(filter_, float("-Inf")) 57 | 58 | 59 | def sample( 60 | logits: torch.Tensor, top_k: int = 0, top_p: float = 0.0, temperature: float = 1.0, vocab_size: Optional[int] = None 61 | ) -> torch.Tensor: 62 | """Sample and generate a token. 63 | Note: logits has the dimension [b, v] where b is the batch size 64 | and v is the vocabulary size. 65 | If vocab_size is provided, we will make sure the sample that is 66 | generated is in [0, vocab-size). This will avoid out of vocabulary 67 | generations due to padding. 68 | """ 69 | 70 | # Check logits for consistency. 71 | assert logits.ndim == 2, "expected the logits to be of [b, v] shape." 72 | assert logits.type() == "torch.cuda.FloatTensor", f"input logits should be floats. Got: {logits.type()}" 73 | 74 | # Greedy is just simple argmax. 75 | if top_k == 1: 76 | assert top_p == 0.0, "cannot set both greedy and top-p samplings." 77 | samples = torch.argmax(logits, dim=-1) 78 | 79 | # Top-k or top-p sampling. 80 | else: 81 | # Clone so we do not modify the inputs, 82 | logits = logits.clone() 83 | # Apply temperature in place. 84 | if temperature != 1.0: 85 | logits.div_(temperature) 86 | 87 | if top_k > 1: 88 | assert top_p == 0.0, "cannot set both top-k and top-p samplings." 89 | assert top_k <= logits.size(1), "top-k is larger than logit size." 90 | if vocab_size: 91 | assert top_k < vocab_size, "top-k is larger than vocab size." 92 | modify_logits_for_top_k_filtering(logits, top_k) 93 | 94 | elif top_p > 0.0: 95 | assert top_p <= 1.0, "top-p should be in (0, 1]." 96 | modify_logits_for_top_p_filtering(logits, top_p) 97 | 98 | # After filtering, we need to recalculate the distribution. 99 | probs = logits.softmax(dim=-1) 100 | samples = torch.multinomial(probs, num_samples=1).view(-1) 101 | 102 | # If vocab size is provided, make sure the samples are in 103 | # in the range [0, vocab-size). 104 | if vocab_size: 105 | samples = torch.clamp(samples, min=0, max=(vocab_size - 1)) 106 | 107 | return samples 108 | -------------------------------------------------------------------------------- /megatron/text_generation/tokenization.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) 2023 ADEPT AI LABS INC. 3 | # This file is based on code by the authors denoted below and has been modified from its original version. 4 | # 5 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. 6 | # 7 | # Licensed under the Apache License, Version 2.0 (the "License"); 8 | # you may not use this file except in compliance with the License. 9 | # You may obtain a copy of the License at 10 | # 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | # 13 | # Unless required by applicable law or agreed to in writing, software 14 | # distributed under the License is distributed on an "AS IS" BASIS, 15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 16 | # See the License for the specific language governing permissions and 17 | # limitations under the License. 
18 | 19 | """Tokenization utilities.""" 20 | 21 | import re 22 | from typing import List, Tuple, Optional, Any, NamedTuple, Dict 23 | import numpy as np 24 | 25 | import torch 26 | 27 | from megatron import get_args, get_tokenizer 28 | from megatron.tokenizer import AbstractTokenizer 29 | from megatron.mpu.communication import broadcast_int_list, broadcast_tensor 30 | 31 | TEXT_REPR_BBOX_OPEN = "" 32 | TEXT_REPR_BBOX_CLOSE = "" 33 | TEXT_REPR_POINT_OPEN = "" 34 | TEXT_REPR_POINT_CLOSE = "" 35 | 36 | 37 | def convert_generations_to_human_readable_tokens( 38 | generations: List[List[int]], bos_token: str 39 | ) -> List[List[str]]: 40 | """Convert the list of integers that a model outputs into a human-readable list of tokens. 41 | Args: 42 | generations: One list per batch, each of which contains a list of integers to detokenize. 43 | bos_token: The BOS token that we are using. 44 | Return: 45 | A list of lists. 46 | """ 47 | new_generations = [] 48 | tokenizer = get_tokenizer() 49 | 50 | for generation in generations: 51 | tokens: List[str] = [] 52 | for i, int_token in enumerate(generation): 53 | token = tokenizer.inv_vocab[int_token] 54 | # convert underscore into an empty string when it is first. 55 | if token[0] == "▁" and (i == 0 or tokens[i - 1] == bos_token): 56 | token = token[1:] 57 | # continue processing normally. 58 | token = re.sub("▁", " ", token) 59 | tokens.append(token) 60 | new_generations.append(tokens) 61 | return new_generations 62 | 63 | 64 | # ====================================================================== TOKENIZATION # 65 | 66 | 67 | def tokenize_prompts( 68 | prompts: Optional[List[List[str]]], 69 | max_tokens_to_generate: int, 70 | max_position_embeddings: int, 71 | add_BOS: bool, 72 | rank: int = 0, 73 | ) -> Tuple[torch.Tensor, torch.Tensor]: 74 | """Tokenize prompts and make them avaiable on all ranks.""" 75 | assert add_BOS is not None 76 | 77 | # On all ranks set to None so we can pass them to functions 78 | sizes_list = None 79 | prompts_tokens_cuda_long_tensor = None 80 | prompts_length_cuda_long_tensor = None 81 | 82 | # On the specified rank, build the above. 83 | if torch.distributed.get_rank() == rank: 84 | assert prompts is not None 85 | assert max_tokens_to_generate is not None 86 | # Tensor of tokens padded and their unpadded length. 87 | ( 88 | prompts_tokens_cuda_long_tensor, 89 | prompts_length_cuda_long_tensor, 90 | ) = _tokenize_prompts_and_batch( 91 | prompts, 92 | max_tokens_to_generate, 93 | max_position_embeddings, 94 | add_BOS, 95 | ) 96 | # We need the sizes of these tensors for the broadcast 97 | sizes_list = [ 98 | prompts_tokens_cuda_long_tensor.size(0), # Batch size 99 | prompts_tokens_cuda_long_tensor.size(1), # Num subsequences 100 | prompts_tokens_cuda_long_tensor.size(2), # Sequence length 101 | ] 102 | # First, broadcast the sizes. 103 | sizes_tensor = broadcast_int_list(3, int_list=sizes_list, rank=rank) 104 | 105 | # Now that we have the sizes, we can broadcast the tokens 106 | # and length tensors. 
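    # Note: only `rank` built the token tensors above; on every other rank they are
    # still None here, so the two broadcasts below are what populate them on those
    # ranks -- which is why the three sizes had to be shared first, to shape the
    # receiving buffers consistently across ranks.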
107 | sizes = sizes_tensor.tolist() 108 | prompts_tokens_cuda_long_tensor = broadcast_tensor( 109 | sizes, torch.int64, tensor=prompts_tokens_cuda_long_tensor, rank=rank 110 | ) 111 | prompts_length_cuda_long_tensor = broadcast_tensor( 112 | sizes[:2], torch.int64, tensor=prompts_length_cuda_long_tensor, rank=rank 113 | ) 114 | 115 | return prompts_tokens_cuda_long_tensor, prompts_length_cuda_long_tensor 116 | 117 | 118 | def _tokenize_prompts_and_batch( 119 | prompts: List[List[str]], 120 | max_tokens_to_generate: int, 121 | max_position_embeddings: int, 122 | add_BOS: bool, # Same issue with types as above 123 | ) -> Tuple[torch.Tensor, torch.Tensor]: 124 | """ 125 | Given a set of prompts and number of tokens to generate: 126 | - tokenize prompts 127 | - set the sequence length to be the max of length of prompts 128 | plus the number of tokens we would like to generate 129 | - pad all the sequences to this length so we can convert them 130 | into a 3D tensor. 131 | """ 132 | args = get_args() 133 | # Tokenize all the prompts. 134 | tokenizer = get_tokenizer() 135 | 136 | transformed_prompt_tokens = [ 137 | [tokenizer.tokenize(prompt) for prompt in prompt_seq] for prompt_seq in prompts 138 | ] 139 | 140 | if add_BOS: 141 | if args.add_bos_prompt_token is not None: 142 | bos_token = tokenizer.vocab[args.add_bos_prompt_token] 143 | else: 144 | bos_token = tokenizer.eod 145 | prompts_tokens = [ 146 | [[bos_token] + x for x in prompt_seq] 147 | for prompt_seq in transformed_prompt_tokens 148 | ] 149 | else: 150 | prompts_tokens = transformed_prompt_tokens 151 | 152 | # Now we have a list of list of tokens which each list has a different 153 | # size. We want to extend this list to: 154 | # - incorporate the tokens that need to be generated 155 | # - make all the sequences equal length. 156 | # Get the prompts length. 157 | 158 | prompts_length = [ 159 | [len(x) for x in prompts_tokens_seq] for prompts_tokens_seq in prompts_tokens 160 | ] 161 | # Get the max prompts length. 162 | max_prompt_len: int = np.max(prompts_length) 163 | # Number of tokens in the each sample of the batch. 164 | samples_length = min( 165 | max_prompt_len + max_tokens_to_generate, max_position_embeddings 166 | ) 167 | if ( 168 | max_prompt_len + max_tokens_to_generate > max_position_embeddings 169 | and torch.distributed.get_rank() == 0 170 | ): 171 | print( 172 | f"Max subsequence prompt length of {max_prompt_len} + max tokens to generate {max_tokens_to_generate}", 173 | f"exceeds context length of {max_position_embeddings}. Will generate as many tokens as possible.", 174 | ) 175 | # Now update the list of list to be of the same size: samples_length. 176 | for prompt_tokens_seq, prompts_length_seq in zip(prompts_tokens, prompts_length): 177 | for prompt_tokens, prompt_length in zip(prompt_tokens_seq, prompts_length_seq): 178 | if len(prompt_tokens) > samples_length: 179 | raise ValueError( 180 | "Length of subsequence prompt exceeds sequence length." 181 | ) 182 | padding_size = samples_length - prompt_length 183 | prompt_tokens.extend([tokenizer.eod] * padding_size) 184 | 185 | # Now we are in a structured format, we can convert to tensors. 
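    # Worked example (illustrative numbers): for two prompts of 5 and 9 tokens with
    # max_tokens_to_generate = 20, samples_length = min(9 + 20, max_position_embeddings)
    # = 29 (assuming the context length is not exceeded), so both token lists are
    # right-padded with the EOD token to 29 entries, and the tensors built below have
    # shapes [batch, subsequences, 29] and [batch, subsequences] respectively.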
186 | prompts_tokens_tensor = torch.cuda.LongTensor(prompts_tokens) 187 | prompts_length_tensor = torch.cuda.LongTensor(prompts_length) 188 | 189 | return prompts_tokens_tensor, prompts_length_tensor 190 | -------------------------------------------------------------------------------- /megatron/tokenizer/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | 17 | from .tokenizer import build_tokenizer, AbstractTokenizer 18 | -------------------------------------------------------------------------------- /megatron/tokenizer/tokenizer.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | """Megatron tokenizers.""" 17 | 18 | from abc import ABCMeta 19 | from abc import abstractmethod 20 | 21 | from typing import Dict, Any, List, Optional 22 | import sentencepiece as spm # type: ignore 23 | 24 | 25 | class AbstractTokenizer(metaclass=ABCMeta): 26 | """Abstract class for tokenizer.""" 27 | 28 | def __init__(self, name: str) -> None: 29 | self.name = name 30 | super().__init__() 31 | 32 | @property 33 | @abstractmethod 34 | def vocab_size(self) -> int: 35 | """Number of distinct tokens in the vocabulary.""" 36 | pass 37 | 38 | @property 39 | @abstractmethod 40 | def vocab(self) -> Dict[str, int]: 41 | """Dictionary from vocab text token to id token.""" 42 | pass 43 | 44 | @property 45 | @abstractmethod 46 | def inv_vocab(self) -> Dict[int, str]: 47 | """Dictionary from vocab id token to text token.""" 48 | pass 49 | 50 | @abstractmethod 51 | def tokenize(self, text: str) -> List[int]: 52 | """Tokenize the text.""" 53 | pass 54 | 55 | def encode(self, text: str) -> List[int]: 56 | """Encode the text.""" 57 | return self.tokenize(text) 58 | 59 | def __call__(self, text: str) -> List[int]: 60 | """Syntactic sugar for tokenize.""" 61 | return self.tokenize(text) 62 | 63 | @abstractmethod 64 | def detokenize(self, token_ids: List[int]) -> str: 65 | """Transform tokens back to a string.""" 66 | pass 67 | 68 | def decode(self, ids: List[int]) -> str: 69 | """Decode the ids.""" 70 | return self.detokenize(ids) 71 | 72 | @property 73 | def cls(self) -> int: 74 | raise NotImplementedError(f"CLS is not provided for {self.name} tokenizer") 75 | 76 | @property 77 | def sep(self) -> int: 78 | raise NotImplementedError(f"SEP is not provided for {self.name} tokenizer") 79 | 80 | @property 81 | def pad(self) -> int: 82 | raise NotImplementedError(f"PAD is not provided for {self.name} tokenizer") 83 | 84 | @property 85 | def eod(self) -> int: 86 | raise NotImplementedError(f"EOD is not provided for {self.name} tokenizer") 87 | 88 | @property 89 | def eod_token_text(self) -> str: 90 | """The EOD token string.""" 91 | return self.decode([self.eod]) 92 | 93 | @property 94 | def mask(self) -> int: 95 | """Get the mask token.""" 96 | raise NotImplementedError(f"MASK is not provided for {self.name} tokenizer") 97 | 98 | @property 99 | def pad_token(self) -> str: 100 | """Get the pad token.""" 101 | raise NotImplementedError 102 | 103 | @property 104 | def max_len_single_sentence(self) -> int: 105 | """Get the max length of a single sentence.""" 106 | raise NotImplementedError 107 | 108 | def __len__(self) -> int: 109 | """Get the length of the tokenizer.""" 110 | raise NotImplementedError 111 | 112 | @property 113 | def eos_token_id(self) -> int: 114 | """Get the id of the EOS token.""" 115 | raise NotImplementedError 116 | 117 | @property 118 | def eos_token(self) -> str: 119 | """Get the EOS token.""" 120 | raise NotImplementedError 121 | 122 | 123 | class _SentencePieceTokenizer(AbstractTokenizer): 124 | """Sentece piece tokenizer.""" 125 | 126 | def __init__(self, model_file: str) -> None: 127 | name = "Sentence Piece" 128 | super().__init__(name) 129 | 130 | # no-member is firing here but the code works fine! 
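        # Illustrative construction (hedged sketch; the model path matches the one
        # passed via --sp-model-file in run_text_generation_server.sh, the input
        # string is a placeholder):
        #   tok = _SentencePieceTokenizer("8b_chat_model_release/adept_vocab.model")
        #   ids = tok.tokenize("hello world")   # -> List[int]
        #   text = tok.detokenize(ids)          # -> str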
131 | # pylint: disable=no-member 132 | self._tokenizer = spm.SentencePieceProcessor() 133 | self._tokenizer.load(model_file) 134 | self._vocab_size = self._tokenizer.vocab_size() 135 | self._tokens = [self._tokenizer.id_to_piece(t) for t in range(self.vocab_size)] 136 | self._vocab = {t: i for i, t in enumerate(self._tokens)} 137 | self._eod_id = None 138 | # look for end of document id 139 | for idx, token in enumerate(self._tokens): 140 | if token == "|ENDOFTEXT|": 141 | self._eod_id = idx 142 | break 143 | if self._eod_id is None: 144 | self._eod_id = self._tokenizer.eos_id() 145 | assert self._eod_id is not None 146 | 147 | @property 148 | def tokenizer(self) -> spm.SentencePieceProcessor: 149 | return self._tokenizer 150 | 151 | @property 152 | def vocab_size(self) -> int: 153 | return int(self._vocab_size) 154 | 155 | @property 156 | def vocab(self) -> Dict[str, int]: 157 | return self._vocab 158 | 159 | @property 160 | def inv_vocab(self): # type: ignore 161 | return self._tokens 162 | 163 | def tokenize(self, text: str): # type: ignore 164 | # pylint: disable=bare-except, no-member 165 | try: 166 | tokenized = self._tokenizer.encode_as_ids(text) 167 | except: 168 | tokenized = None 169 | return tokenized 170 | 171 | def pieces(self, text: str) -> List[str]: 172 | # pylint: disable=no-member 173 | pieces: List[str] = self._tokenizer.encode_as_pieces(text) 174 | return pieces 175 | 176 | def detokenize(self, token_ids: List[int]) -> str: 177 | # pylint: disable=no-member 178 | return self._tokenizer.decode_ids(token_ids) # type: ignore 179 | 180 | @property 181 | def eod(self) -> int: 182 | return self._eod_id # type: ignore 183 | 184 | @property 185 | def eos_token_id(self) -> int: 186 | """Id of the end of sentence token in the vocabulary.""" 187 | eos_id: int = self._eod_id # type: ignore 188 | return eos_id 189 | 190 | 191 | def megatron_initialize_tokenizer( 192 | tokenizer_type: str, 193 | sp_model_file: Optional[str] = None, 194 | ) -> AbstractTokenizer: 195 | """Initialize the tokenizer.""" 196 | if tokenizer_type == "SentencePiece": 197 | assert sp_model_file is not None 198 | tokenizer = _SentencePieceTokenizer(sp_model_file) 199 | else: 200 | raise NotImplementedError(f"{tokenizer_type} tokenizer is not implemented.") 201 | return tokenizer 202 | 203 | 204 | def _vocab_size_with_padding(orig_vocab_size: int, args: Any) -> int: 205 | """Pad vocab size so it is divisible by model parallel size and 206 | still having GPU friendly size.""" 207 | 208 | after = orig_vocab_size 209 | multiple = args.make_vocab_size_divisible_by * args.tensor_model_parallel_size 210 | while (after % multiple) != 0: 211 | after += 1 212 | if args.rank == 0: 213 | print( 214 | f" > padded vocab (sz: {orig_vocab_size}) w/ {after - orig_vocab_size} dummy toks (new sz: {after})", 215 | flush=True, 216 | ) 217 | return after 218 | 219 | 220 | def build_tokenizer(args: Any) -> AbstractTokenizer: 221 | """Initialize tokenizer.""" 222 | if args.rank == 0: 223 | print( 224 | f"> building {args.tokenizer_type} tokenizer ...", 225 | flush=True, 226 | ) 227 | 228 | # Select and instantiate the tokenizer. 229 | if args.tokenizer_type == "SentencePiece": 230 | assert args.sp_model_file is not None 231 | tokenizer = _SentencePieceTokenizer(f"{args.sp_model_file}") 232 | else: 233 | raise NotImplementedError( 234 | f"{args.tokenizer_type} tokenizer is not implemented." 235 | ) 236 | 237 | # Add vocab size. 
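    # Worked example for the padding above (illustrative numbers): with
    # make_vocab_size_divisible_by = 128 and tensor_model_parallel_size = 1, a
    # 70,000-entry vocabulary becomes padded_vocab_size = 70,016, the next multiple of 128.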
238 | args.padded_vocab_size = _vocab_size_with_padding(tokenizer.vocab_size, args) 239 | 240 | return tokenizer 241 | -------------------------------------------------------------------------------- /megatron/training.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) 2023 ADEPT AI LABS INC. 3 | # This file is based on code by the authors denoted below and has been modified from its original version. 4 | # 5 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. 6 | # 7 | # Licensed under the Apache License, Version 2.0 (the "License"); 8 | # you may not use this file except in compliance with the License. 9 | # You may obtain a copy of the License at 10 | # 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | # 13 | # Unless required by applicable law or agreed to in writing, software 14 | # distributed under the License is distributed on an "AS IS" BASIS, 15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 16 | # See the License for the specific language governing permissions and 17 | # limitations under the License. 18 | 19 | """Pretrain utilities.""" 20 | 21 | import math 22 | import sys 23 | from typing import Any, Callable, Dict, List, Optional, Tuple, Union 24 | 25 | import torch 26 | from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP 27 | 28 | from megatron import ( 29 | get_args, 30 | mpu, 31 | ) 32 | from megatron.checkpointing import ( 33 | load_state_dicts_and_update_args, 34 | update_model_and_optim_from_loaded_data, 35 | ) 36 | from megatron.initialize import initialize_megatron 37 | from megatron.model import DistributedDataParallel as LocalDDP 38 | from megatron.model import Float16Module, ModelType 39 | 40 | from megatron.utils import ( 41 | unwrap_model, 42 | ) 43 | 44 | 45 | def get_model( 46 | model_provider_func, 47 | model_type=ModelType.encoder_or_decoder, 48 | wrap_with_ddp=True, 49 | ): 50 | """Build the model.""" 51 | args = get_args() 52 | args.model_type = model_type 53 | 54 | # Build model. 55 | if ( 56 | mpu.get_pipeline_model_parallel_world_size() > 1 57 | and args.virtual_pipeline_model_parallel_size is not None 58 | ): 59 | assert ( 60 | model_type != ModelType.encoder_and_decoder 61 | ), "Interleaved schedule not supported for model with both encoder and decoder" 62 | model = [] 63 | for i in range(args.virtual_pipeline_model_parallel_size): 64 | mpu.set_virtual_pipeline_model_parallel_rank(i) 65 | # Set pre_process and post_process only after virtual rank is set. 
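            # Note: in Megatron's convention, pre_process marks the model chunk that
            # owns the input embeddings and post_process the chunk that owns the output
            # head, so both flags depend on which virtual pipeline rank this chunk is
            # assigned to.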
66 | pre_process = mpu.is_pipeline_first_stage() 67 | post_process = mpu.is_pipeline_last_stage() 68 | this_model = model_provider_func( 69 | pre_process=pre_process, post_process=post_process 70 | ) 71 | this_model.model_type = model_type 72 | model.append(this_model) 73 | else: 74 | pre_process = mpu.is_pipeline_first_stage() 75 | post_process = mpu.is_pipeline_last_stage() 76 | add_encoder = True 77 | add_decoder = True 78 | if model_type == ModelType.encoder_and_decoder: 79 | if mpu.get_pipeline_model_parallel_world_size() > 1: 80 | assert ( 81 | args.pipeline_model_parallel_split_rank is not None 82 | ), "Split rank needs to be specified for model with both encoder and decoder" 83 | rank = mpu.get_pipeline_model_parallel_rank() 84 | split_rank = args.pipeline_model_parallel_split_rank 85 | world_size = mpu.get_pipeline_model_parallel_world_size() 86 | pre_process = rank == 0 or rank == split_rank 87 | post_process = (rank == (split_rank - 1)) or (rank == (world_size - 1)) 88 | add_encoder = mpu.is_pipeline_stage_before_split() 89 | add_decoder = mpu.is_pipeline_stage_after_split() 90 | model = model_provider_func( 91 | pre_process=pre_process, 92 | post_process=post_process, 93 | add_encoder=add_encoder, 94 | add_decoder=add_decoder, 95 | ) 96 | else: 97 | model = model_provider_func( 98 | pre_process=pre_process, post_process=post_process 99 | ) 100 | model.model_type = model_type 101 | 102 | if not isinstance(model, list): 103 | model = [model] 104 | 105 | # Set tensor model parallel attributes if not set. 106 | # Only parameters that are already tensor model parallel have these 107 | # attributes set for them. We should make sure the default attributes 108 | # are set for all params so the optimizer can use them. 109 | for model_module in model: 110 | for param in model_module.parameters(): 111 | mpu.set_defaults_if_not_set_tensor_model_parallel_attributes(param) 112 | 113 | # Print number of parameters. 114 | if mpu.get_data_parallel_rank() == 0: 115 | print( 116 | " > number of parameters on (tensor, pipeline) " 117 | "model parallel rank ({}, {}): {}".format( 118 | mpu.get_tensor_model_parallel_rank(), 119 | mpu.get_pipeline_model_parallel_rank(), 120 | sum( 121 | [ 122 | sum([p.nelement() for p in model_module.parameters()]) 123 | for model_module in model 124 | ] 125 | ), 126 | ), 127 | flush=True, 128 | ) 129 | 130 | # GPU allocation. 131 | for model_module in model: 132 | model_module.cuda(torch.cuda.current_device()) 133 | 134 | # Fp16 conversion. 135 | if args.fp16 or args.bf16: 136 | model = [Float16Module(model_module, args) for model_module in model] 137 | 138 | if wrap_with_ddp: 139 | if args.DDP_impl == "torch": 140 | i = torch.cuda.current_device() 141 | model = [ 142 | torchDDP( 143 | model_module, 144 | device_ids=[i], 145 | output_device=i, 146 | process_group=mpu.get_data_parallel_group(), 147 | ) 148 | for model_module in model 149 | ] 150 | 151 | elif args.DDP_impl == "local": 152 | model = [ 153 | LocalDDP( 154 | model_module, 155 | args.accumulate_allreduce_grads_in_fp32, 156 | args.use_contiguous_buffers_in_local_ddp, 157 | ) 158 | for model_module in model 159 | ] 160 | # broad cast params from data parallel src rank to other data parallel ranks 161 | if args.data_parallel_random_init: 162 | for model_module in model: 163 | model_module.broadcast_params() 164 | else: 165 | raise NotImplementedError( 166 | "Unknown DDP implementation specified: " 167 | "{}. 
Exiting.".format(args.DDP_impl) 168 | ) 169 | 170 | return model 171 | -------------------------------------------------------------------------------- /prompts.md: -------------------------------------------------------------------------------- 1 | # Prompt Examples 2 | 3 | ## MMLU 4 | You are a bot designed to answer questions by choosing between A, B, C, and D. Only one of these answers from the provided choices is correct. Reply with the letter of the correct answer. 5 | Question: Alfred and Ben don't know each other but are each considering asking the lovely Charlene to the school prom. The probability that at least one of them will ask her is 0.72. The probability that they both ask her is 0.18. The probability that Alfred asks her is 0.6. What is the probability that Ben asks Charlene to the prom? 6 | A: 0.78 7 | B: 0.3 8 | C: 0.24 9 | D: 0.48 10 | Answer: B 11 | 12 | Question: A telephone survey of 400 registered voters showed that 256 had not yet made up their minds 1 month before the election. How sure can we be that between 60% and 68% of the electorate were still undecided at that time? 13 | A: 2.4% 14 | B: 8.0% 15 | C: 64.0% 16 | D: 90.4% 17 | Answer: 18 | 19 | 20 | You are a bot designed to answer questions by choosing between A, B, C, and D. Only one of these answers from the provided choices is correct. Reply with the letter of the correct answer. 21 | Question: The procedure involving repeated presentation of a stimulus to the client until the attractiveness of that stimulus is reduced is best described as 22 | A: stimulus satiation 23 | B: response-prevention 24 | C: flooding 25 | D: implosion 26 | Answer: A 27 | 28 | Question: If, during a postexamination discussion with parents, a psychologist establishes that a child’s new pediatrician is apparently unaware of the child's history of brain damage. which is very important in understanding the problem situation, the psychologist should 29 | A: tell the parents that he/she will inform the pediatrician 30 | B: urge the parents to grant him/her permission to inform the pediatrician 31 | C: cell the parents char be/she is legally obligated to inform the pediatrician 32 | D: cell the parents that it is their responsibility to inform the pediatrician 33 | Answer: 34 | 35 | ## Winogrande 36 | You are a bot that is responsible for answering fill-in-the-blank questions. You are provided with a sentence and two possible fill-in-the-blank options. Your task is to return 1 or 2 as the answer. 37 | Q: The store had 80 platters but only 2 bowls left in stock because the _ were in high demand.. 1 = platters and 2 = bowls. Answer: 38 | 2 39 | 40 | Q: The smell in the kitchen of the home is unbearable, while the laundry room smells fine. The _ must have been cleaned longer ago.. 1 = kitchen and 2 = laundry room. Answer: 41 | 42 | 43 | 44 | You are a bot that is responsible for answering fill-in-the-blank questions. You are provided with a sentence and two possible fill-in-the-blank options. Your task is to return 1 or 2 as the answer. 45 | Q: John painted the pole red close to the color of the wall and painted the frame white and now the _ is similar.. 1 = frame and 2 = pole. Answer: 46 | 2 47 | 48 | Q: Benjamin has a spouse and Kyle is single after being divorced, so _ is celebrating their independence this year.. 1 = Benjamin and 2 = Kyle. Answer: 49 | 50 | ## Arc Easy/Challenge 51 | You are a bot designed to answer multiple choice questions. Given a series of options, you answer by choosing A, B, C, or D. Some examples are below. 
52 | Question: Which process uses carbon from the air to make food for plants? 53 | A: growth 54 | B: respiration 55 | C: decomposition 56 | D: photosynthesis 57 | Answer: D 58 | 59 | Question: An object composed mainly of ice is orbiting the Sun in an elliptical path. This object is most likely 60 | A: a planet. 61 | B: an asteroid. 62 | C: a meteor. 63 | D: a comet. 64 | Answer: 65 | 66 | 67 | 68 | You are a bot designed to answer multiple choice questions. Given a series of options, you answer by choosing A, B, C, or D. Some examples are below. 69 | Question: Decomposers are important in the food chain because they 70 | A: produce their own food using light from the Sun. 71 | B: stop the flow of energy from one organism to another. 72 | C: break down dead organisms and recycle nutrients into the soil. 73 | D: are microscopic and other organisms cannot consume them. 74 | Answer: C 75 | 76 | Question: A wire is wrapped around a metal nail and connected to a battery. If the battery is active, the nail will 77 | A: vibrate. 78 | B: create sound. 79 | C: produce heat. 80 | D: become magnetic. 81 | Answer: 82 | 83 | ## Humaneval 84 | ``` 85 | You are a code-writing bot. Given a function signature, and a docstring, complete the program body. Some examples are given below. 86 | def sort_array(arr): 87 | """ 88 | In this Kata, you have to sort an array of non-negative integers according to 89 | number of ones in their binary representation in ascending order. 90 | For similar number of ones, sort based on decimal value. 91 | 92 | It must be implemented like this: 93 | >>> sort_array([1, 5, 2, 3, 4]) == [1, 2, 3, 4, 5] 94 | >>> sort_array([-2, -3, -4, -5, -6]) == [-6, -5, -4, -3, -2] 95 | >>> sort_array([1, 0, 2, 3, 4]) [0, 1, 2, 3, 4] 96 | """ 97 | return sorted(sorted(arr), key=lambda x: bin(x)[2:].count('1')) 98 | 99 | 100 | def prime_fib(n: int): 101 | """ 102 | prime_fib returns n-th number that is a Fibonacci number and it's also prime. 103 | >>> prime_fib(1) 104 | 2 105 | >>> prime_fib(2) 106 | 3 107 | >>> prime_fib(3) 108 | 5 109 | >>> prime_fib(4) 110 | 13 111 | >>> prime_fib(5) 112 | 89 113 | """ 114 | ``` 115 | 116 | ``` 117 | # You are a code-writing bot. Given a function signature, and a docstring, complete the program body. Some examples are given below. 118 | def unique(l: list): 119 | """Return sorted unique elements in a list 120 | >>> unique([5, 3, 5, 2, 3, 3, 9, 0, 123]) 121 | [0, 2, 3, 5, 9, 123] 122 | """ 123 | return sorted(list(set(l))) 124 | 125 | def total_match(lst1, lst2): 126 | ''' 127 | Write a function that accepts two lists of strings and returns the list that has 128 | total number of chars in the all strings of the list less than the other list. 129 | 130 | if the two lists have the same number of chars, return the first list. 131 | 132 | Examples 133 | total_match([], []) ➞ [] 134 | total_match(['hi', 'admin'], ['hI', 'Hi']) ➞ ['hI', 'Hi'] 135 | total_match(['hi', 'admin'], ['hi', 'hi', 'admin', 'project']) ➞ ['hi', 'admin'] 136 | total_match(['hi', 'admin'], ['hI', 'hi', 'hi']) ➞ ['hI', 'hi', 'hi'] 137 | total_match(['4'], ['1', '2', '3', '4', '5']) ➞ ['4'] 138 | ''' 139 | ``` 140 | 141 | ## TriviaQA 142 | 143 | You are a trivia answering bot designed to answer questions. You are given a question and are supposed to output an answer in 1-3 words. Some examples are below. 
144 | Q: To what RAF base, near Wooton Bassett village, were the bodies of servicemen killed in Afghanistan formerly transported 145 | A: LYNEHAM 146 | 147 | Q: What star sign is Jamie Lee Curtis? 148 | A: 149 | 150 | 151 | You are a trivia answering bot designed to answer questions. You are given a question and are supposed to output an answer in 1-3 words. Some examples are below. 152 | Q: Pre restraining order(s), who did People magazine name as their first "Sexiest Man Alive", in 1985? 153 | A: Mel Gibson 154 | 155 | Q: What id the name given to the study of birds? 156 | A: -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # If you change this file, you need to regenerate the docker image! 2 | # To do so, run: adept infra create-dev-docker 3 | black==22.8.0 4 | celery-types==0.14.0 5 | click==8.0.4 6 | codenamize==1.2.3 7 | dacite==1.8.1 8 | dateparser==1.1.8 9 | einops 10 | fabric==3.0.0 11 | fire==0.4.0 12 | Flask==2.2.2 13 | Flask-AutoIndex==0.6.6 14 | Flask-Cors==3.0.10 15 | Flask-Login==0.6.2 16 | Flask-RESTful==0.3.9 17 | flatten-json==0.1.13 18 | gitpython 19 | hypothesis==4.50.8 20 | ipython==8.8.0 21 | jieba==0.42.1 22 | json-parser==1.2.0 23 | ldap3==2.9.1 24 | lm-dataformat==0.0.19 25 | loguru==0.6.0 26 | mock==4.0.3 27 | mypy==0.991 28 | nagisa==0.2.7 29 | nltk==3.7 30 | numexpr==2.7.2 31 | numpy>=1.18,<1.24 32 | parallel-ssh==2.12.0 33 | parameterized==0.8.1 34 | pip 35 | pipdeptree==2.5.2 36 | pre-commit 37 | prompt-toolkit==3.0.29 38 | pydantic>=1.10.2 39 | pykka 40 | pylint==2.13.5 41 | pytablewriter==0.58.0 42 | pytest==7.2.1 43 | pytest-celery==0.0.0 44 | pytest-cov==4.0.0 45 | pytest-forked==1.6.0 46 | pytest-randomly==3.12.0 47 | pytest-subtests==0.8.0 48 | pytest-timeout 49 | pytest-xdist==2.5.0 50 | python-dotenv==1.0.0 51 | python-hostlist==1.22 52 | python-snappy==0.6.1 53 | pytz==2023.3 54 | pyyaml==6.0.1 55 | pyzstd==0.15.0 56 | requirements-parser==0.5.0 57 | rtree 58 | ruamel.yaml==0.17.21 59 | sacrebleu==1.5.0 60 | scikit-learn 61 | sentencepiece==0.1.96 62 | sqlitedict==1.6.0 63 | sqlparse==0.4.2 64 | statsd==4.0.1 65 | strsimpy 66 | tabulate==0.9.0 67 | terminaltables==3.1.0 68 | tldextract==3.4.1 69 | types-bleach==5.0.3.1 70 | types-click 71 | types-Flask-Cors==3.0.10.5 72 | types-invoke 73 | types-paramiko==2.12.0.3 74 | types-Pillow 75 | types-protobuf==4.22.0.0 76 | types-PyYAML==6.0.11 77 | types-redis==4.3.21 78 | types-requests==2.28.4 79 | typing-extensions==4.6.3 80 | tzdata==2023.3 81 | unzip==1.0.0 82 | urllib3==1.26.15 83 | warcio==1.7.4 84 | wonderwords==2.2.0 85 | zstd==1.5.2.6 86 | -------------------------------------------------------------------------------- /run_text_generation_server.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) 2023 ADEPT AI LABS INC. 3 | # This file is based on code by the authors denoted below and has been modified from its original version. 4 | # 5 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. 6 | # 7 | # Licensed under the Apache License, Version 2.0 (the "License"); 8 | # you may not use this file except in compliance with the License. 
9 | # You may obtain a copy of the License at 10 | # 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | # 13 | # Unless required by applicable law or agreed to in writing, software 14 | # distributed under the License is distributed on an "AS IS" BASIS, 15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 16 | # See the License for the specific language governing permissions and 17 | # limitations under the License. 18 | 19 | """Sample Generate GPT""" 20 | import os 21 | import sys 22 | from typing import Any, Optional, Tuple 23 | 24 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir))) 25 | import torch 26 | 27 | from megatron import get_args, get_tokenizer, mpu 28 | from megatron.checkpointing import load_checkpoint 29 | from megatron.initialize import initialize_megatron 30 | from megatron.model import GPTModel 31 | from megatron.model.module import MegatronModule 32 | from megatron.model.utils import print_named_parameters 33 | from megatron.text_generation.api import generate_and_post_process 34 | from megatron.text_generation.inference_params import InferenceParams 35 | from megatron.text_generation_server import (MegatronServer, 36 | add_text_generate_args, 37 | setup_model) 38 | from megatron.training import get_model 39 | 40 | MAX_BATCH_SIZE = 1 # You can increase this depending on your desired max sequence length and GPU memory 41 | MAX_SEQLEN = 16 * 1024 42 | 43 | 44 | def model_provider( 45 | pre_process: bool=True, 46 | post_process: bool=True 47 | ) -> MegatronModule: 48 | """Build the model.""" 49 | 50 | args = get_args() 51 | if args.model_architecture == "GPTModel": 52 | model : MegatronModule = GPTModel(num_tokentypes=0, parallel_output=False, pre_process=pre_process, post_process=post_process) 53 | else: 54 | raise ValueError(f"Unsupported model type: {args.model_architecture}") 55 | print_named_parameters(model) 56 | return model 57 | 58 | 59 | def initialize_model_from_args() -> Tuple[Any, Optional[InferenceParams]]: 60 | # Needed for tensor parallel inference with CUDA graph 61 | os.environ["NCCL_ASYNC_ERROR_HANDLING"] = "0" 62 | os.environ["NCCL_GRAPH_MIXING_SUPPORT"] = "0" 63 | 64 | initialize_megatron( 65 | extra_args_provider=add_text_generate_args, 66 | args_defaults={ 67 | "tokenizer_type": "GPT2BPETokenizer", 68 | "no_load_rng": True, 69 | "no_load_optim": True, 70 | "inference_max_batch_size": MAX_BATCH_SIZE, 71 | "inference_max_seqlen": MAX_SEQLEN, 72 | }, 73 | ) 74 | 75 | args = get_args() 76 | if not args.fused_ft_kernel: 77 | args.use_cuda_graph = False # CUDA graph requires fused FT kernel 78 | if hasattr(args, "iteration"): 79 | args.curr_iteration = args.iteration 80 | print("curr_iteration", args.curr_iteration) 81 | if args.num_layers_per_virtual_pipeline_stage is not None: 82 | print("Interleaved pipeline schedule is not yet supported for text generation.") 83 | exit() 84 | # Set up model and load checkpoint 85 | model = get_model(model_provider, wrap_with_ddp=False) 86 | 87 | if args.load is not None: 88 | _ = load_checkpoint(model, None, None) 89 | 90 | assert len(model) == 1, "Above condition should have caught this" 91 | model = model[0] 92 | 93 | args.model_architecture = "GPTModel" 94 | inference_params = setup_model( 95 | model, 96 | args.model_architecture, 97 | args.use_inference_kv_cache, 98 | args.fused_ft_kernel, 99 | args.use_cuda_graph, 100 | args.inference_max_batch_size, 101 | args.inference_max_seqlen, 102 | ) 103 | 104 | return model, inference_params 105 | 106 | 
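# Note on the main block below: only the rank that is both on the first pipeline stage
# and tensor-parallel rank 0 starts the HTTP server (server.run blocks); every other
# rank falls through to the broadcast loop and waits for a choice value of 0 as the
# signal to join the collective generate_and_post_process call.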
107 | if __name__ == "__main__": 108 | model, inference_params = initialize_model_from_args() 109 | 110 | args = get_args() 111 | tokenizer = get_tokenizer() 112 | if hasattr(args, "eos_id"): 113 | termination_id = args.eos_id 114 | else: 115 | termination_id = tokenizer.eod 116 | if mpu.is_pipeline_first_stage() and mpu.get_tensor_model_parallel_rank() == 0: 117 | assert inference_params is not None 118 | server = MegatronServer( 119 | model=model, 120 | inference_params=inference_params, 121 | params_dtype=args.params_dtype, 122 | max_position_embeddings=args.max_position_embeddings, 123 | termination_id=termination_id, 124 | tokenizer=tokenizer, 125 | port=args.port, 126 | ) 127 | server.run("0.0.0.0") 128 | 129 | while True: 130 | if inference_params is not None: 131 | inference_params.reset() 132 | choice = torch.cuda.LongTensor(1) 133 | torch.distributed.broadcast(choice, 0) 134 | if choice[0].item() == 0: 135 | try: 136 | assert inference_params is not None 137 | generate_and_post_process( 138 | model=model, 139 | params_dtype=args.params_dtype, 140 | max_position_embeddings=args.max_position_embeddings, 141 | termination_id=termination_id, 142 | tokenizer=tokenizer, 143 | inference_params=inference_params) 144 | except ValueError as ve: 145 | pass 146 | -------------------------------------------------------------------------------- /run_text_generation_server.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023, ADEPT AI LABS INC. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
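# Single-node, single-GPU launch: --master_port 5003 is only the torchrun rendezvous
# port, while --port 6001 below is the HTTP port that run_text_generation_server.py
# serves requests on.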
14 | MODEL_DIR="8b_chat_model_release" 15 | torchrun --nproc_per_node 1 --nnodes 1 --node_rank 0 --master_addr $(hostname) --master_port 5003 run_text_generation_server.py --no-load-rng --no-load-optim --no-initialization --top_p 0.9 --port 6001 --micro-batch-size 1 --load ${MODEL_DIR} --use-flash-attn --sp-model-file ${MODEL_DIR}/adept_vocab.model 16 | -------------------------------------------------------------------------------- /tool_use/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/persimmon-ai-labs/adept-inference/61743d07cfb151dadb0cd2ae9de8f7325c4e828a/tool_use/__init__.py -------------------------------------------------------------------------------- /tool_use/experiment_pipeline/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/persimmon-ai-labs/adept-inference/61743d07cfb151dadb0cd2ae9de8f7325c4e828a/tool_use/experiment_pipeline/__init__.py -------------------------------------------------------------------------------- /tool_use/megatron/megatron_utils.py: -------------------------------------------------------------------------------- 1 | """Utils for megatron tool use.""" 2 | 3 | from enum import Enum 4 | 5 | 6 | class ModelBackend(Enum): 7 | """Model backend.""" 8 | 9 | MEGATRON = "megatron" 10 | --------------------------------------------------------------------------------