├── .dockerignore
├── .gitignore
├── LICENSE
├── README.md
├── docker
│   └── Dockerfile
├── docker_launch.sh
├── megatron
│   ├── __init__.py
│   ├── arguments.py
│   ├── checkpointing.py
│   ├── config.py
│   ├── fused_kernels
│   │   ├── __init__.py
│   │   ├── megatron_fused_kernels
│   │   │   ├── __init__.py
│   │   │   ├── compat.h
│   │   │   ├── fused_weight_gradient_dense.cpp
│   │   │   ├── fused_weight_gradient_dense_cuda.cu
│   │   │   ├── helpers.cpp
│   │   │   ├── layer_norm_cuda.cpp
│   │   │   ├── layer_norm_cuda_kernel.cu
│   │   │   ├── load_cpp_extensions.py
│   │   │   ├── load_kernel.py
│   │   │   ├── scaled_masked_softmax.cpp
│   │   │   ├── scaled_masked_softmax.h
│   │   │   ├── scaled_masked_softmax_cuda.cu
│   │   │   ├── scaled_softmax.cpp
│   │   │   ├── scaled_softmax_cuda.cu
│   │   │   ├── scaled_upper_triang_masked_softmax.cpp
│   │   │   ├── scaled_upper_triang_masked_softmax.h
│   │   │   ├── scaled_upper_triang_masked_softmax_cuda.cu
│   │   │   ├── tests
│   │   │   │   └── __init__.py
│   │   │   └── type_shim.h
│   │   └── setup.py
│   ├── global_vars.py
│   ├── initialize.py
│   ├── memory.py
│   ├── microbatches.py
│   ├── model
│   │   ├── __init__.py
│   │   ├── distributed.py
│   │   ├── enums.py
│   │   ├── fused_bias_gelu.py
│   │   ├── fused_bias_sqrelu.py
│   │   ├── fused_layer_norm.py
│   │   ├── fused_softmax.py
│   │   ├── gpt_model.py
│   │   ├── language_model.py
│   │   ├── module.py
│   │   ├── positional_embeddings.py
│   │   ├── transformer.py
│   │   └── utils.py
│   ├── mpu
│   │   ├── __init__.py
│   │   ├── communication.py
│   │   ├── cross_entropy.py
│   │   ├── cross_entropy_parallel.py
│   │   ├── initialize.py
│   │   ├── layers.py
│   │   ├── mappings.py
│   │   ├── random.py
│   │   └── utils.py
│   ├── text_generation
│   │   ├── __init__.py
│   │   ├── api.py
│   │   ├── forward_step.py
│   │   ├── generation.py
│   │   ├── inference_params.py
│   │   ├── sampling.py
│   │   └── tokenization.py
│   ├── text_generation_server.py
│   ├── tokenizer
│   │   ├── __init__.py
│   │   └── tokenizer.py
│   ├── training.py
│   └── utils.py
├── prompts.md
├── requirements.txt
├── run_text_generation_server.py
├── run_text_generation_server.sh
└── tool_use
    ├── __init__.py
    ├── experiment_pipeline
    │   └── __init__.py
    └── megatron
        └── megatron_utils.py
/.dockerignore:
--------------------------------------------------------------------------------
1 | 8b_chat_model_release.tar
2 | 8b_base_model_release.tar
3 | 8b_chat_model_release/
4 | 8b_base_model_release/
5 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | *.env
2 |
3 | .ruff_cache
4 | .DS_Store
5 |
6 | # Byte-compiled / optimized / DLL files
7 | *__pycache__*
8 | .mypy_cache/
9 | *.py[cod]
10 | *$py.class
11 | lm_cache
12 |
13 | # C extensions
14 | *.so
15 |
16 | # Distribution / packaging
17 | .Python
18 | build/
19 | develop-eggs/
20 | dist/
21 | downloads/
22 | eggs/
23 | .eggs/
24 | .next/
25 | lib/
26 | lib64/
27 | parts/
28 | sdist/
29 | var/
30 | wheels/
31 | pip-wheel-metadata/
32 | share/python-wheels/
33 | *.egg-info/
34 | .installed.cfg
35 | *.egg
36 | env
37 | MANIFEST
38 | venv/
39 |
40 | # PyCharm
41 | .idea/
42 |
43 | # VSCode
44 | .vscode/
45 |
46 | # Ignore swapfiles
47 | .*.swp
48 | .\#*
49 |
50 | # virtualenv
51 | .python-version
52 |
53 | # VSCode extension
54 | *.vsix
55 |
56 | # Environment variables
57 | .env
58 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | Persimmon-8B User Guide
2 | ==========
3 | This repo contains inference code for [Persimmon-8B](https://www.adept.ai/blog/persimmon-8b), the new LLM from Adept.
4 |
5 | Downloading the Checkpoint
6 | --------
7 |
8 | The model checkpoints are stored on our public OCI bucket and can be downloaded using `wget`.
9 | The base model is not fine-tuned and is released under an Apache 2.0 license.
10 | The chat model is fine-tuned and is released under a CC-BY-NC 4.0 license.
11 |
12 | Base:
13 | https://axtkn4xl5cip.objectstorage.us-phoenix-1.oci.customer-oci.com/n/axtkn4xl5cip/b/adept-public-data/o/8b_base_model_release.tar
14 | md5sum: cd0320cba9efad9ccd18e9ec4d16ae1b
15 |
16 | Chat:
17 | https://axtkn4xl5cip.objectstorage.us-phoenix-1.oci.customer-oci.com/n/axtkn4xl5cip/b/adept-public-data/o/8b_chat_model_release.tar
18 | md5sum: 663aeace07269c44e90f4e8bcd07f32a
19 |
20 | Untar the model into its own directory via `tar -xvf 8b_base_model_release.tar` or `tar -xvf 8b_chat_model_release.tar`
21 |
22 | The scripts are set up to expect the model folder to be placed within the code directory, but you can place it elsewhere and modify the scripts accordingly.
23 |
24 | Building Docker
25 | -----------
26 |
27 | Build the docker that will include all the necessary dependencies (and then some!) using the included Dockerfile:
28 |
29 | ```
30 | docker build -f docker/Dockerfile -t 'adeptdocker' .
31 | ```
32 |
33 | Running Docker
34 | ----------
35 | Ensure that the variable `MODEL_DIR` in `run_text_generation_server.sh` is set to the location of the model directory. By default it is set to `MODEL_DIR=8b_chat_model_release`, which is the default name for the chat model. (For the base model, change this line to `MODEL_DIR=8b_base_model_release`.)
36 |
37 | Running `sh docker_launch.sh` will start a model server that you can query via:
38 |
39 | ```
40 | curl '<server address>/api' -X 'PUT' -H 'Content-Type: application/json; charset=UTF-8' -d '{"prompts": ["human: Hello, how are you?\n\nadept:"], "tokens_to_generate": 128, "top_p": 0.9, "random_seed": 1234, "logprobs": false}'
41 | ```
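
The same request can also be issued from Python. A minimal sketch using `requests` (the host and port below are placeholders, not part of this repo; use whatever address your server is listening on):

```python
# Minimal sketch of querying the running generation server from Python.
# SERVER_URL is an assumption -- substitute the address your server reports.
import requests

SERVER_URL = "http://localhost:5000/api"  # placeholder address

payload = {
    "prompts": ["human: Hello, how are you?\n\nadept:"],
    "tokens_to_generate": 128,
    "top_p": 0.9,
    "random_seed": 1234,
    "logprobs": False,
}

response = requests.put(SERVER_URL, json=payload, timeout=300)
response.raise_for_status()
print(response.json())
```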
42 |
43 |
44 | Notes
45 | -----
46 |
47 | * The chat model is fine-tuned to expect inputs of the form: `human: {prompt}\n\nadept:`[^1]. To ensure best performance from this model, please use this format! You can see an example of this in the curl command above, and a small prompt-formatting sketch after these notes. To automatically wrap single-turn input prompts with this structure, you can modify the definition of `megatron/text_generation/api.py::generate_and_post_process` so that the default value for the argument `process_prompts_for_chat` is set to `True`.
48 | * We are releasing the model with tensor parallelism of 1. In this configuration, the model requires an 80GB GPU to run naively.
49 | It should be possible to fit the model on a 40GB card by removing the unused embeddings and reducing the maximum sequence length
50 | (at the top of `run_text_generation_server.py`).
51 | Quantization to 8-bit or lower would also make it fit with plenty of room to spare.
52 | * We included the `.vocab` file so you can browse the vocabulary in plain text - this file is otherwise unused.
53 |
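A minimal sketch of wrapping prompts in this chat format, covering both single-turn and multi-turn conversations; the helper functions are illustrative and not part of this repository:

```python
# Minimal sketch of the chat prompt format described in the notes above.
# These helpers are illustrative; they are not part of this repository.
from typing import List, Tuple


def format_single_turn(prompt: str) -> str:
    return f"human: {prompt}\n\nadept:"


def format_multi_turn(history: List[Tuple[str, str]], follow_up: str) -> str:
    """history is a list of (human_prompt, adept_output) pairs from earlier turns."""
    prefix = "".join(f"human: {h}\n\nadept: {a}\n\n" for h, a in history)
    return prefix + f"human: {follow_up}\n\nadept:"


print(format_single_turn("Hello, how are you?"))
```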
54 |
55 | Citation
56 | --------
57 |
58 | If you use this model in your work, please use the following BibTeX citation:
59 | ```bibtex
60 | @misc{persimmon-8b,
61 | author = {Elsen, Erich and Odena, Augustus and Nye, Maxwell and Ta\c{s}\i{}rlar, Sa\u{g}nak and Dao, Tri and Hawthorne, Curtis and Moparthi, Deepak and Somani, Arushi},
62 | title = {Releasing {Persimmon-8B}},
63 | url = {https://www.adept.ai/blog/persimmon-8b},
64 | year = {2023}
65 | }
66 | ```
67 |
68 |
69 | [^1]: Subsequent inputs should have the form `human: {prompt}\n\nadept: {output}\n\nhuman: {follow_up}\n\nadept:`
70 |
--------------------------------------------------------------------------------
/docker/Dockerfile:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2023, ADEPT AI LABS INC. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | FROM nvcr.io/nvidia/pytorch:23.04-py3
16 |
17 | RUN apt-get update -y && \
18 | DEBIAN_FRONTEND=noninteractive TZ=Etc/UTC apt-get -y install tzdata build-essential devscripts debhelper fakeroot
19 |
20 |
21 | WORKDIR /opt/hpcx
22 | RUN rm -rf ucx && \
23 | git clone --recursive https://github.com/openucx/ucx.git && \
24 | pushd ucx && \
25 | git fetch --all --tags && \
26 | git checkout tags/v1.14.1 && \
27 | ./autogen.sh && \
28 | mkdir UCX_BUILD && \
29 | ./contrib/configure-release-mt --prefix=/opt/hpcx/ucx/UCX_BUILD/ --with-cuda=/usr/local/cuda/ && \
30 | make -j && \
31 | make install && \
32 | popd
33 |
34 | RUN rm -rf ucc && \
35 | git clone --recursive https://github.com/openucx/ucc.git && \
36 | pushd ucc && \
37 | git fetch --all --tags && \
38 | git checkout tags/v1.2.0 && \
39 | ./autogen.sh && \
40 | mkdir UCC_BUILD && \
41 | ./configure --prefix=/opt/hpcx/ucc/UCC_BUILD --with-ucx=/opt/hpcx/ucx/UCX_BUILD/ --with-nccl=/usr --with-cuda=/usr/local/cuda/ --with-nvcc-gencode="-gencode=arch=compute_80,code=sm_80 -gencode=arch=compute_80,code=compute_80 -gencode=arch=compute_86,code=sm_86 -gencode=arch=compute_86,code=compute_86 -gencode=arch=compute_90,code=sm_90 -gencode=arch=compute_90,code=compute_90"&& \
42 | make -j && \
43 | make install && \
44 | popd
45 |
46 | ENV LD_LIBRARY_PATH=/opt/hpcx/ucx/UCX_BUILD/lib:/opt/hpcx/ucc/UCC_BUILD/lib:$LD_LIBRARY_PATH
47 | ENV UCX_HOME=/opt/hpcx/ucx/UCX_BUILD/
48 | ENV UCC_HOME=/opt/hpcx/ucc/UCC_BUILD/
49 | ENV WITH_CUDA=/usr/local/cuda
50 |
51 | WORKDIR /workspace
52 |
53 | # Install FlashAttention
54 | RUN pip install flash-attn==2.0.0.post1
55 |
56 | # Install rotary embedding, cross entropy, and FT's attention kernel
57 | # [2022-11-08] TD: Check out a specific version to make build reproducible
58 | RUN git clone https://github.com/HazyResearch/flash-attention \
59 | && cd flash-attention && git checkout b8020d73c9e068665586989883083a4a5429a443 \
60 | && cd csrc/rotary && pip install . && cd ../../ \
61 | && cd csrc/xentropy && pip install . && cd ../../ \
62 | && cd csrc/ft_attention && pip install . && cd ../../ \
63 | && cd .. && rm -rf flash-attention
64 |
65 |
66 | COPY megatron/fused_kernels/ megatron/fused_kernels/
67 | ENV PATH="${PATH}:/opt/hpcx/ompi/bin" LD_LIBRARY_PATH="${LD_LIBRARY_PATH}:/opt/hpcx/ompi/lib"
68 | RUN ldconfig /opt/hpcx/ompi/bin
69 | RUN TORCH_CUDA_ARCH_LIST="" cd megatron/fused_kernels/ && python setup.py install sdist && cd ../..
70 | RUN cd /usr/local/lib/python3.8/dist-packages/megatron_fused_kernels-0.0.0-py3.8-linux-x86_64.egg/; mv *.so megatron_fused_kernels/;
71 | RUN rm -rf megatron/fused_kernels
72 |
73 | # Install apt-get dependencies for pip requirements.
74 | ENV DEBIAN_FRONTEND=noninteractive
75 | RUN curl -fsSL https://deb.nodesource.com/setup_18.x | bash - && \
76 | apt-get update -y && apt-get install -y nodejs iputils-ping \
77 | wget ca-certificates tzdata zip locales \
78 | && locale-gen en_US en_US.UTF-8 en_US.utf8 && dpkg-reconfigure --frontend noninteractive locales \
79 | && npm install pm2 -g \
80 | && pip install --upgrade pip setuptools \
81 | && rm -rf /var/lib/apt/lists/*
82 |
83 | # Change locale for click
84 | ENV LANG=C.UTF-8 LANGUAGE=en_US.en LC_ALL=C.UTF-8
85 |
86 | # Install requirements & cleanup.
87 | COPY requirements.txt requirements.txt
88 | RUN pip install --upgrade pip setuptools && \
89 | pip install -r requirements.txt && rm requirements.txt
90 |
--------------------------------------------------------------------------------
/docker_launch.sh:
--------------------------------------------------------------------------------
1 | sudo docker run \
2 | --rm \
3 | --gpus all \
4 | --device=/dev/infiniband \
5 | --ipc=host \
6 | --ulimit memlock=-1 \
7 | --ulimit stack=67108864 \
8 | --env PYTHONPATH="." \
9 | -v $(pwd)/:/Adept_Inference/ \
10 | --network=host \
11 | --name=adept_inference \
12 | adeptdocker \
13 | bash -c "cd /Adept_Inference/; sh run_text_generation_server.sh";
14 |
--------------------------------------------------------------------------------
/megatron/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright (c) 2023 ADEPT AI LABS INC.
3 | # This file is based on code by the authors denoted below and has been modified from its original version.
4 | #
5 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
6 | #
7 | # Licensed under the Apache License, Version 2.0 (the "License");
8 | # you may not use this file except in compliance with the License.
9 | # You may obtain a copy of the License at
10 | #
11 | # http://www.apache.org/licenses/LICENSE-2.0
12 | #
13 | # Unless required by applicable law or agreed to in writing, software
14 | # distributed under the License is distributed on an "AS IS" BASIS,
15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16 | # See the License for the specific language governing permissions and
17 | # limitations under the License.
18 | import torch
19 |
20 | from .global_vars import get_args
21 | from .global_vars import get_current_global_batch_size
22 | from .global_vars import get_num_microbatches
23 | from .global_vars import update_num_microbatches
24 | from .global_vars import get_tokenizer
25 | from .global_vars import get_tensorboard_writer
26 | from .global_vars import get_adlr_autoresume
27 | from .global_vars import get_timers
28 | from .global_vars import get_global_memory_buffer
29 |
30 | from .utils import print_rank_0, is_last_rank, print_rank_last
31 | from .initialize import initialize_megatron
32 |
--------------------------------------------------------------------------------
/megatron/config.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright (c) 2023, ADEPT AI LABS INC. All rights reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Config objects.
17 |
18 | Provides a bridge between the flat `args` namespace and having different
19 | configs at any layer of the model.
20 |
21 | Config objects can be initialized manually, but to match current `args`
22 | behavior, they can be initialized using the `from_args` method.
23 | """
24 |
25 | import argparse
26 | import dataclasses
27 | from typing import Type, TypeVar
28 |
29 | T = TypeVar("T", bound="Config")
30 |
31 | DISABLE_PREFIX = lambda: dataclasses.field(metadata={"no_from_args_prefix": True})
32 |
33 |
34 | @dataclasses.dataclass
35 | class Config:
36 | @classmethod
37 | def from_args(cls: Type[T], args: argparse.Namespace, arg_prefix: str = None) -> T:
38 | """Initialize a Config object using an `args` object.
39 |
40 | Args:
41 | args: Parsed arguments from `megatron.get_args()`. The field names
42 | in the Config object will be populated by `args` entries of the
43 | same name.
44 | arg_prefix: If provided, will use this prefix when finding the
45 | field in `args`. For example, if the field name is `hidden_size`
46 | and arg_prefix is `blah_`, the field will be populated by the arg
47 | `blah_hidden_size`.
48 | However, if the field has the metadata field `no_from_args_prefix`
49 | set to True, the prefix will not be added. This is to support
50 | fields that are globally applicable.
51 | """
52 | field_values = {}
53 | for field in dataclasses.fields(cls):
54 | if issubclass(field.type, Config):
55 | field_values[field.name] = field.type.from_args(args, arg_prefix)
56 | else:
57 | if arg_prefix:
58 | prefixed_field_name = arg_prefix + field.name
59 | if arg_prefix and not (
60 | "no_from_args_prefix" in field.metadata and field.metadata["no_from_args_prefix"]
61 | ):
62 | arg_name = prefixed_field_name
63 | else:
64 | arg_name = field.name
65 | if prefixed_field_name in vars(args):
66 | raise ValueError(
67 | f"{field.name} has no_from_args_prefix set, but {prefixed_field_name} exists in args. This is likely a mistake."
68 | )
69 | else:
70 | arg_name = field.name
71 |
72 | if arg_name not in vars(args):
73 | raise ValueError(
74 | f"{arg_name} not found in args when attempting to construct {cls.__name__} from args."
75 | )
76 | field_values[field.name] = vars(args)[arg_name]
77 | return cls(**field_values)
78 |
--------------------------------------------------------------------------------
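
To illustrate the `from_args` bridge described in the docstring above, a minimal hypothetical sketch (the `DecoderConfig` fields and the `args` values are made up for illustration):

```python
# Hypothetical usage sketch for Config.from_args; field names and args are illustrative.
import argparse
import dataclasses

from megatron.config import Config, DISABLE_PREFIX


@dataclasses.dataclass
class DecoderConfig(Config):
    hidden_size: int              # read from args.decoder_hidden_size (prefix applied)
    num_layers: int               # read from args.decoder_num_layers
    seed: int = DISABLE_PREFIX()  # no_from_args_prefix: read from args.seed, never args.decoder_seed


args = argparse.Namespace(decoder_hidden_size=4096, decoder_num_layers=36, seed=1234)
cfg = DecoderConfig.from_args(args, arg_prefix="decoder_")
print(cfg)  # DecoderConfig(hidden_size=4096, num_layers=36, seed=1234)
```
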
/megatron/fused_kernels/__init__.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pathlib
3 | import signal
4 | import multiprocessing
5 | import time
6 | import importlib
7 | import sys
8 | from torch.utils import cpp_extension
9 | from megatron.fused_kernels.megatron_fused_kernels.load_kernel import (
10 | KernelBuildInfo,
11 | ALL_BUILD_KERNELS,
12 | JIT_BUILD_KERNELS,
13 | )
14 |
15 |
16 | class CompilationTimeoutError(Exception):
17 | pass
18 |
19 |
20 | def _load_kernel(kernel_build_info):
21 | kernel_build_info.load()
22 |
23 |
24 | def load(force_build_fused_kernels=False):
25 | """Load fused kernels."""
26 | if force_build_fused_kernels:
27 | for kernel_build_info in ALL_BUILD_KERNELS.values():
28 | _load_kernel(kernel_build_info)
29 | else:
30 | # Just compile the kernels that we need to JIT compile.
31 | for kernel_name in JIT_BUILD_KERNELS:
32 | kernel_build_info = ALL_BUILD_KERNELS[kernel_name]
33 | _load_kernel(kernel_build_info)
34 |
--------------------------------------------------------------------------------
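
A minimal usage sketch for the loader above, forcing an ahead-of-time build of every kernel instead of only the JIT subset:

```python
# Minimal sketch: build all fused kernels up front rather than only the JIT set.
from megatron import fused_kernels

fused_kernels.load(force_build_fused_kernels=True)
```
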
/megatron/fused_kernels/megatron_fused_kernels/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/persimmon-ai-labs/adept-inference/61743d07cfb151dadb0cd2ae9de8f7325c4e828a/megatron/fused_kernels/megatron_fused_kernels/__init__.py
--------------------------------------------------------------------------------
/megatron/fused_kernels/megatron_fused_kernels/compat.h:
--------------------------------------------------------------------------------
1 | /* coding=utf-8
2 | * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 |
17 | /*This code is copied from NVIDIA apex:
18 | * https://github.com/NVIDIA/apex
19 | * with minor changes. */
20 |
21 |
22 |
23 | #ifndef TORCH_CHECK
24 | #define TORCH_CHECK AT_CHECK
25 | #endif
26 |
27 | #ifdef VERSION_GE_1_3
28 | #define DATA_PTR data_ptr
29 | #else
30 | #define DATA_PTR data
31 | #endif
32 |
--------------------------------------------------------------------------------
/megatron/fused_kernels/megatron_fused_kernels/fused_weight_gradient_dense.cpp:
--------------------------------------------------------------------------------
1 | /* Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. */
2 |
3 | #include <torch/torch.h>
4 | #include <torch/extension.h>
5 |
6 | #include <vector>
7 | #include <stdio.h>
8 |
9 | #include "type_shim.h"
10 |
11 |
12 | template <typename T>
13 | int wgrad_gemm_accum_fp32_cuda(T *input, T *d_output, float *d_weight, int in_dim, int hidden_dim, int out_dim);
14 |
15 | void wgrad_gemm_accum_fp32(const at::Tensor input, const at::Tensor d_output, at::Tensor d_weight) {
16 | at::Tensor input_2d, d_output_2d;
17 | // input tensor: collapse to the first dim
18 | auto in_sizes = input.sizes();
19 | if (input.dim() > 2) {
20 | input_2d = input.view({-1, in_sizes[in_sizes.size() - 1]});
21 | } else {
22 | input_2d = input;
23 | }
24 | // d_output tensor: collapse to the first dim
25 | auto d_out_sizes = d_output.sizes();
26 | if (d_output.dim() > 2) {
27 | d_output_2d = d_output.view({-1, d_out_sizes[d_out_sizes.size() - 1]});
28 | } else {
29 | d_output_2d = d_output;
30 | }
31 |
32 | int hidden_dim = input_2d.size(0);
33 | int in_dim = input_2d.size(1);
34 | int out_dim = d_weight.size(0);
35 |
36 | DISPATCH_HALF_BFLOAT_AND_FLOAT(input_2d.scalar_type(), "wgrad_gemm_accum_fp32",
37 | int result = wgrad_gemm_accum_fp32_cuda<scalar_t>(
38 | input_2d.data_ptr<scalar_t>(),
39 | d_output_2d.data_ptr<scalar_t>(),
40 | d_weight.data_ptr<float>(),
41 | in_dim,
42 | hidden_dim,
43 | out_dim);
44 | );
45 | }
46 |
47 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
48 | m.def("wgrad_gemm_accum_fp32", &wgrad_gemm_accum_fp32, "wgrad gemm accum in fp32");
49 | }
50 |
--------------------------------------------------------------------------------
/megatron/fused_kernels/megatron_fused_kernels/fused_weight_gradient_dense_cuda.cu:
--------------------------------------------------------------------------------
1 | /* Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. */
2 |
3 | #include <ATen/ATen.h>
4 | #include <ATen/cuda/CUDAContext.h>
5 | #include <assert.h>
6 | #include <stdio.h>
7 | #include <stdlib.h>
8 | #include <string.h>
9 | #include <torch/torch.h>
10 |
11 | /* Includes, cuda */
12 | #include <cublas_v2.h>
13 | #include <cuda_runtime.h>
14 |
15 |
16 | // BF16 Tensor core wrapper around cublas GEMMEx
17 | cublasStatus_t gemmex_wrapper(
18 | cublasHandle_t handle,
19 | cublasOperation_t transa,
20 | cublasOperation_t transb,
21 | int m,
22 | int n,
23 | int k,
24 | const float* alpha,
25 | at::BFloat16* A,
26 | int lda,
27 | at::BFloat16* B,
28 | int ldb,
29 | const float* beta,
30 | float* C,
31 | int ldc) {
32 | return cublasGemmEx(
33 | handle,
34 | transa,
35 | transb,
36 | m,
37 | n,
38 | k,
39 | alpha,
40 | A,
41 | CUDA_R_16BF,
42 | lda,
43 | B,
44 | CUDA_R_16BF,
45 | ldb,
46 | beta,
47 | C,
48 | CUDA_R_32F,
49 | ldc,
50 | CUDA_R_32F,
51 | CUBLAS_GEMM_DEFAULT_TENSOR_OP);
52 | }
53 |
54 | // FP16 Tensor core wrapper around cublas GEMMEx
55 | cublasStatus_t gemmex_wrapper(
56 | cublasHandle_t handle,
57 | cublasOperation_t transa,
58 | cublasOperation_t transb,
59 | int m,
60 | int n,
61 | int k,
62 | const float* alpha,
63 | at::Half* A,
64 | int lda,
65 | at::Half* B,
66 | int ldb,
67 | const float* beta,
68 | float* C,
69 | int ldc) {
70 | return cublasGemmEx(
71 | handle,
72 | transa,
73 | transb,
74 | m,
75 | n,
76 | k,
77 | alpha,
78 | A,
79 | CUDA_R_16F,
80 | lda,
81 | B,
82 | CUDA_R_16F,
83 | ldb,
84 | beta,
85 | C,
86 | CUDA_R_32F,
87 | ldc,
88 | CUDA_R_32F,
89 | CUBLAS_GEMM_DEFAULT_TENSOR_OP);
90 | }
91 |
92 | // FP32 Tensor core wrapper around cublas GEMMEx
93 | cublasStatus_t gemmex_wrapper(
94 | cublasHandle_t handle,
95 | cublasOperation_t transa,
96 | cublasOperation_t transb,
97 | int m,
98 | int n,
99 | int k,
100 | const float* alpha,
101 | float* A,
102 | int lda,
103 | float* B,
104 | int ldb,
105 | const float* beta,
106 | float* C,
107 | int ldc) {
108 | return cublasGemmEx(
109 | handle,
110 | transa,
111 | transb,
112 | m,
113 | n,
114 | k,
115 | alpha,
116 | A,
117 | CUDA_R_32F,
118 | lda,
119 | B,
120 | CUDA_R_32F,
121 | ldb,
122 | beta,
123 | C,
124 | CUDA_R_32F,
125 | ldc,
126 | CUDA_R_32F,
127 | CUBLAS_GEMM_DEFAULT_TENSOR_OP);
128 | }
129 |
130 | template <typename T>
131 | int wgrad_gemm_accum_fp32_cuda(T *input, T *d_output, float *d_weight, int in_dim, int hidden_dim, int out_dim) {
132 | cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
133 | cudaStream_t stream;
134 | cublasGetStream(handle, &stream);
135 | const float alpha = 1.0;
136 | const float beta = 1.0;
137 | int status = 1;
138 |
139 | status = gemmex_wrapper(
140 | handle,
141 | CUBLAS_OP_N,
142 | CUBLAS_OP_T,
143 | in_dim,
144 | out_dim,
145 | hidden_dim,
146 | &alpha,
147 | input,
148 | in_dim,
149 | d_output,
150 | out_dim,
151 | &beta,
152 | d_weight,
153 | in_dim);
154 | return status;
155 | }
156 |
157 | template int wgrad_gemm_accum_fp32_cuda<at::Half>(at::Half *input, at::Half *d_output, float *d_weight, int in_dim, int hidden_dim, int out_dim);
158 | template int wgrad_gemm_accum_fp32_cuda<at::BFloat16>(at::BFloat16 *input, at::BFloat16 *d_output, float *d_weight, int in_dim, int hidden_dim, int out_dim);
159 | template int wgrad_gemm_accum_fp32_cuda<float>(float *input, float *d_output, float *d_weight, int in_dim, int hidden_dim, int out_dim);
160 |
--------------------------------------------------------------------------------
/megatron/fused_kernels/megatron_fused_kernels/layer_norm_cuda.cpp:
--------------------------------------------------------------------------------
1 | /* coding=utf-8
2 | * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 |
17 | /*This code is copied from NVIDIA apex:
18 | * https://github.com/NVIDIA/apex
19 | * with minor changes. */
20 |
21 | #include <torch/extension.h>
22 | #include <vector>
23 | #include <cassert>
24 | #include "compat.h"
25 |
26 | namespace {
27 |
28 | void compute_n1_n2(
29 | at::Tensor input,
30 | at::IntArrayRef normalized_shape,
31 | int& n1,
32 | int& n2) {
33 | int idiff = input.ndimension() - normalized_shape.size();
34 | n2 = 1;
35 | for (int i = 0; i < (int)normalized_shape.size(); ++i) {
36 | assert( input.sizes()[i+idiff] == normalized_shape[i] );
37 | n2 *= normalized_shape[i];
38 | }
39 | n1 = 1;
40 | for (int i = 0; i < idiff; ++i) {
41 | n1 *= input.sizes()[i];
42 | }
43 | }
44 |
45 | void check_args(
46 | at::IntArrayRef normalized_shape,
47 | at::Tensor gamma,
48 | at::Tensor beta
49 | )
50 | {
51 | TORCH_CHECK(!gamma.defined() || gamma.sizes().equals(normalized_shape));
52 | TORCH_CHECK(!beta.defined() || beta.sizes().equals(normalized_shape));
53 | }
54 |
55 | void check_args(
56 | at::Tensor input,
57 | at::IntArrayRef normalized_shape,
58 | int& n1,
59 | int& n2
60 | )
61 | {
62 | int64_t normalized_ndim = normalized_shape.size();
63 |
64 | if (normalized_ndim < 1) {
65 | std::stringstream ss;
66 | ss << "Expected normalized_shape to be at least 1-dimensional, i.e., "
67 | << "containing at least one element, but got normalized_shape="
68 | << normalized_shape;
69 | throw std::runtime_error(ss.str());
70 | }
71 |
72 | auto input_shape = input.sizes();
73 | auto input_ndim = input.dim();
74 |
75 | if (input_ndim < normalized_ndim ||
76 | !input_shape.slice(input_ndim - normalized_ndim).equals(normalized_shape)) {
77 | std::stringstream ss;
78 | ss << "Given normalized_shape=" << normalized_shape
79 | << ", expected input with shape [*";
80 | for (auto size : normalized_shape) {
81 | ss << ", " << size;
82 | }
83 | ss << "], but got input of size" << input_shape;
84 | throw std::runtime_error(ss.str());
85 | }
86 |
87 | compute_n1_n2(input,normalized_shape,n1,n2);
88 | }
89 |
90 |
91 | void check_args(
92 | at::Tensor input,
93 | at::IntArrayRef normalized_shape,
94 | at::Tensor gamma,
95 | at::Tensor beta,
96 | int& n1,
97 | int& n2
98 | )
99 | {
100 | check_args(input,normalized_shape,n1,n2);
101 | check_args(normalized_shape,gamma,beta);
102 | }
103 | }
104 |
105 | void cuda_layer_norm(
106 | at::Tensor* output,
107 | at::Tensor* mean,
108 | at::Tensor* invvar,
109 | at::Tensor* input,
110 | int n1,
111 | int n2,
112 | at::IntArrayRef normalized_shape,
113 | at::Tensor* gamma,
114 | at::Tensor* beta,
115 | double epsilon);
116 |
117 | #define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
118 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
119 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
120 |
121 | std::vector<at::Tensor> layer_norm_affine(
122 | at::Tensor input,
123 | at::IntArrayRef normalized_shape,
124 | at::Tensor gamma,
125 | at::Tensor beta,
126 | double epsilon) {
127 |
128 | CHECK_INPUT(input);
129 | CHECK_INPUT(gamma);
130 | CHECK_INPUT(beta);
131 | int n1, n2;
132 | check_args(input, normalized_shape, gamma, beta, n1, n2);
133 |
134 | at::Tensor output = at::empty_like(
135 | input, gamma.options().dtype(gamma.scalar_type()));
136 | at::Tensor mean = at::empty(
137 | {n1}, input.options().dtype(at::ScalarType::Float));
138 | at::Tensor invvar = at::empty_like(mean);
139 |
140 | cuda_layer_norm(&output, &mean, &invvar, &input, n1, n2,
141 | normalized_shape, &gamma, &beta, epsilon);
142 |
143 | return {output, mean, invvar};
144 |
145 | }
146 |
147 |
148 | void cuda_layer_norm_gradient(
149 | at::Tensor* dout,
150 | at::Tensor* mean,
151 | at::Tensor* invvar,
152 | at::Tensor* input,
153 | int n1,
154 | int n2,
155 | at::IntArrayRef normalized_shape,
156 | at::Tensor* gamma,
157 | at::Tensor* beta,
158 | double epsilon,
159 | at::Tensor* grad_input,
160 | at::Tensor* grad_gamma,
161 | at::Tensor* grad_beta
162 | );
163 |
164 | std::vector<at::Tensor> layer_norm_gradient_affine(
165 | at::Tensor dout,
166 | at::Tensor mean,
167 | at::Tensor invvar,
168 | at::Tensor input,
169 | at::IntArrayRef normalized_shape,
170 | at::Tensor gamma,
171 | at::Tensor beta,
172 | double epsilon) {
173 |
174 | CHECK_INPUT(dout);
175 | CHECK_INPUT(mean);
176 | CHECK_INPUT(invvar);
177 | CHECK_INPUT(input);
178 | CHECK_INPUT(gamma);
179 | CHECK_INPUT(beta);
180 | int n1, n2;
181 | check_args(input, normalized_shape, gamma, beta, n1, n2);
182 |
183 | at::Tensor grad_input = at::empty_like(input);
184 | at::Tensor grad_gamma = at::empty_like(gamma);
185 | at::Tensor grad_beta = at::empty_like(beta);
186 |
187 | cuda_layer_norm_gradient(&dout, &mean, &invvar, &input, n1, n2,
188 | normalized_shape, &gamma, &beta, epsilon,
189 | &grad_input, &grad_gamma, &grad_beta);
190 |
191 | return {grad_input, grad_gamma, grad_beta};
192 |
193 | }
194 |
195 |
196 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
197 | m.def("forward_affine", &layer_norm_affine,
198 | "LayerNorm forward (CUDA)");
199 | m.def("backward_affine", &layer_norm_gradient_affine,
200 | "LayerNorm backward (CUDA)");
201 | }
202 |
--------------------------------------------------------------------------------
/megatron/fused_kernels/megatron_fused_kernels/load_cpp_extensions.py:
--------------------------------------------------------------------------------
1 | import pathlib
2 | from pybind11.setup_helpers import Pybind11Extension
3 |
4 | srcpath = pathlib.Path(__file__).parent.absolute()
5 |
6 |
7 | def get_helpers_extension():
8 | """Get the helpers pybind11 extension."""
9 | return [
10 | Pybind11Extension(
11 | name="helpers",
12 | sources=[str(srcpath / "helpers.cpp")],
13 | language="c++",
14 | )
15 | ]
16 |
--------------------------------------------------------------------------------
/megatron/fused_kernels/megatron_fused_kernels/load_kernel.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pathlib
3 | from dataclasses import dataclass
4 | from typing import List
5 | from torch.utils import cpp_extension
6 | from pathlib import Path
7 |
8 |
9 | def _create_build_dir(buildpath):
10 | try:
11 | os.mkdir(buildpath)
12 | except OSError:
13 | if not os.path.isdir(buildpath):
14 | print(f"Creation of the build directory {buildpath} failed")
15 |
16 |
17 | @dataclass
18 | class KernelBuildInfo:
19 | name: str
20 | sources: List[Path]
21 | build_directory: str
22 | extra_cflags: List[str]
23 | extra_cuda_cflags: List[str]
24 | verbose: bool
25 |
26 | def __init__(self, name, sources, build_directory, extra_cflags, extra_cuda_cflags, verbose):
27 | self.name = name
28 | self.extra_cflags = extra_cflags
29 | self.extra_cuda_cflags = extra_cuda_cflags
30 | self.verbose = verbose
31 | if not isinstance(build_directory, Path):
32 | build_directory = Path(build_directory)
33 | self.build_directory = build_directory
34 | for i, source in enumerate(sources):
35 | if not isinstance(source, Path):
36 | sources[i] = Path(source)
37 | self.sources = sources
38 |
39 | def __repr__(self):
40 | return f"KernelBuildInfo(name={self.name}, sources={self.sources}, build_directory={self.build_directory}, extra_cflags={self.extra_cflags}, extra_cuda_cflags={self.extra_cuda_cflags}, verbose={self.verbose})\n"
41 |
42 | def to_setup_cpp_extension(self) -> cpp_extension.CUDAExtension:
43 | sources = [str(source) for source in self.sources]
44 | return cpp_extension.CUDAExtension(
45 | name=self.name,
46 | sources=sources,
47 | extra_compile_args={
48 | "cxx": self.extra_cflags,
49 | "nvcc": self.extra_cuda_cflags,
50 | },
51 | is_python_module=True,
52 | )
53 |
54 | def load(self):
55 | os.environ["TORCH_CUDA_ARCH_LIST"] = ""
56 | _create_build_dir(self.build_directory)
57 | _ = cpp_extension.load(
58 | name=self.name,
59 | sources=self.sources,
60 | build_directory=Path(self.build_directory),
61 | extra_cflags=self.extra_cflags,
62 | extra_cuda_cflags=self.extra_cuda_cflags,
63 | verbose=self.verbose,
64 | )
65 |
66 |
67 | srcpath = pathlib.Path(__file__).parent.absolute()
68 |
69 | BUILD_DIR = srcpath / "build"
70 |
71 | BASE_CFLAGS = [
72 | "-O3",
73 | "-llibtorch_python",
74 | ]
75 | BASE_CUDA_CFLAGS = [
76 | "-O3",
77 | "--use_fast_math",
78 | "-gencode",
79 | "arch=compute_80,code=sm_80",
80 | ]
81 |
82 | BASE_MASKED_SOFTMAX_FUSION_CUDA_CFLAGS = BASE_CUDA_CFLAGS + [
83 | "-U__CUDA_NO_HALF_OPERATORS__",
84 | "-U__CUDA_NO_HALF_CONVERSIONS__",
85 | "--expt-relaxed-constexpr",
86 | "--expt-extended-lambda",
87 | ]
88 |
89 | # These are the kernels that we need to build JIT as they have
90 | # some issues in installing with pip. Note: They are defined in the
91 | # ALL_BUILD_KERNELS dictionary below.
92 | JIT_BUILD_KERNELS = []
93 |
94 | ALL_BUILD_KERNELS = {
95 | "scaled_upper_triang_masked_softmax": KernelBuildInfo(
96 | name="scaled_upper_triang_masked_softmax_cuda",
97 | sources=[
98 | srcpath / "scaled_upper_triang_masked_softmax.cpp",
99 | srcpath / "scaled_upper_triang_masked_softmax_cuda.cu",
100 | ],
101 | build_directory=BUILD_DIR,
102 | extra_cflags=BASE_CFLAGS,
103 | extra_cuda_cflags=BASE_MASKED_SOFTMAX_FUSION_CUDA_CFLAGS,
104 | verbose=True,
105 | ),
106 | "scaled_masked_softmax_cuda": KernelBuildInfo(
107 | name="scaled_masked_softmax_cuda",
108 | sources=[srcpath / "scaled_masked_softmax.cpp", srcpath / "scaled_masked_softmax_cuda.cu"],
109 | build_directory=BUILD_DIR,
110 | extra_cflags=BASE_CFLAGS,
111 | extra_cuda_cflags=BASE_MASKED_SOFTMAX_FUSION_CUDA_CFLAGS,
112 | verbose=True,
113 | ),
114 | "scaled_softmax_cuda": KernelBuildInfo(
115 | name="scaled_softmax_cuda",
116 | sources=[srcpath / "scaled_softmax.cpp", srcpath / "scaled_softmax_cuda.cu"],
117 | build_directory=BUILD_DIR,
118 | extra_cflags=BASE_CFLAGS,
119 | extra_cuda_cflags=BASE_MASKED_SOFTMAX_FUSION_CUDA_CFLAGS,
120 | verbose=True,
121 | ),
122 | "fused_mix_prec_layer_norm_cuda": KernelBuildInfo(
123 | name="fused_mix_prec_layer_norm_cuda",
124 | sources=[srcpath / "layer_norm_cuda.cpp", srcpath / "layer_norm_cuda_kernel.cu"],
125 | build_directory=BUILD_DIR,
126 | extra_cflags=BASE_CFLAGS,
127 | extra_cuda_cflags=BASE_CUDA_CFLAGS + ["-maxrregcount=50"],
128 | verbose=True,
129 | ),
130 | "fused_dense_cuda": KernelBuildInfo(
131 | name="fused_dense_cuda",
132 | sources=[srcpath / "fused_weight_gradient_dense.cpp", srcpath / "fused_weight_gradient_dense_cuda.cu"],
133 | build_directory=BUILD_DIR,
134 | extra_cflags=BASE_CFLAGS,
135 | extra_cuda_cflags=BASE_CUDA_CFLAGS,
136 | verbose=True,
137 | ),
138 | }
139 |
--------------------------------------------------------------------------------
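
For reference, a minimal sketch of how these `KernelBuildInfo` entries can be turned into setuptools extensions via `to_setup_cpp_extension()`. This is an illustration under assumed usage, not the repository's actual `setup.py` (which appears further below in truncated form):

```python
# Illustrative only: wiring KernelBuildInfo entries into a setuptools build.
from setuptools import setup
from torch.utils import cpp_extension

from megatron_fused_kernels.load_kernel import ALL_BUILD_KERNELS, JIT_BUILD_KERNELS

ext_modules = [
    info.to_setup_cpp_extension()
    for name, info in ALL_BUILD_KERNELS.items()
    if name not in JIT_BUILD_KERNELS  # JIT kernels are compiled at runtime instead
]

setup(
    name="megatron_fused_kernels",
    ext_modules=ext_modules,
    cmdclass={"build_ext": cpp_extension.BuildExtension},
)
```
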
/megatron/fused_kernels/megatron_fused_kernels/scaled_masked_softmax.cpp:
--------------------------------------------------------------------------------
1 | /* coding=utf-8
2 | * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 |
17 | #include <cuda_fp16.h>
18 | #include <torch/extension.h>
19 | #include <vector>
20 |
21 | namespace multihead_attn {
22 | namespace fused_softmax {
23 | namespace scaled_masked_softmax {
24 |
25 | torch::Tensor fwd_cuda(
26 | torch::Tensor const& input,
27 | torch::Tensor const& mask,
28 | float scale_factor);
29 |
30 | torch::Tensor bwd_cuda(
31 | torch::Tensor const& output_grads,
32 | torch::Tensor const& softmax_results,
33 | float scale_factor);
34 |
35 | int get_batch_per_block_cuda(
36 | int query_seq_len,
37 | int key_seq_len,
38 | int batches,
39 | int attn_heads);
40 |
41 | torch::Tensor fwd(
42 | torch::Tensor const& input,
43 | torch::Tensor const& mask,
44 | float scale_factor) {
45 | AT_ASSERTM(input.dim() == 4, "expected 4D tensor");
46 | AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) ||
47 | (input.scalar_type() == at::ScalarType::BFloat16),
48 | "Only fp16 and bf16 are supported");
49 | AT_ASSERTM(mask.dim() == 4, "expected 4D tensor");
50 |
51 | return fwd_cuda(input, mask, scale_factor);
52 | }
53 |
54 | torch::Tensor bwd(
55 | torch::Tensor const& output_grads,
56 | torch::Tensor const& softmax_results,
57 | float scale_factor) {
58 |
59 | AT_ASSERTM(output_grads.dim() == 4, "expected 4D tensor");
60 | AT_ASSERTM(softmax_results.dim() == 4, "expected 4D tensor");
61 |
62 | AT_ASSERTM((output_grads.scalar_type() == at::ScalarType::Half) ||
63 | (output_grads.scalar_type() == at::ScalarType::BFloat16),
64 | "Only fp16 and bf16 are supported");
65 | AT_ASSERTM((softmax_results.scalar_type() == at::ScalarType::Half) ||
66 | (softmax_results.scalar_type() == at::ScalarType::BFloat16),
67 | "Only fp16 and bf16 are supported");
68 |
69 | return bwd_cuda(output_grads, softmax_results, scale_factor);
70 | }
71 |
72 | int get_batch_per_block(
73 | int query_seq_len,
74 | int key_seq_len,
75 | int batches,
76 | int attn_heads) {
77 | return get_batch_per_block_cuda(query_seq_len, key_seq_len, batches, attn_heads);
78 | }
79 |
80 | } // end namespace scaled_masked_softmax
81 | } // end namespace fused_softmax
82 | } // end namespace multihead_attn
83 |
84 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
85 | m.def("forward",
86 | &multihead_attn::fused_softmax::scaled_masked_softmax::fwd,
87 | "Self Multihead Attention scaled, time masked softmax -- Forward.");
88 |
89 | m.def("backward",
90 | &multihead_attn::fused_softmax::scaled_masked_softmax::bwd,
91 | "Self Multihead Attention scaled, time masked softmax -- Backward.");
92 |
93 | m.def("get_batch_per_block",
94 | &multihead_attn::fused_softmax::scaled_masked_softmax::get_batch_per_block,
95 | "Return Batch per block size."
96 | );
97 | }
98 |
--------------------------------------------------------------------------------
/megatron/fused_kernels/megatron_fused_kernels/scaled_masked_softmax_cuda.cu:
--------------------------------------------------------------------------------
1 | /* coding=utf-8
2 | * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 |
17 | #include <ATen/ATen.h>
18 | #include <cuda.h>
19 | #include <cuda_runtime.h>
20 | #include <cuda_fp16.h>
21 | #include <cuda_profiler_api.h>
22 | #include <ATen/cuda/CUDAContext.h>
23 | #include <torch/extension.h>
24 | #include "scaled_masked_softmax.h"
25 | #include "type_shim.h"
26 |
27 | namespace multihead_attn {
28 | namespace fused_softmax {
29 | namespace scaled_masked_softmax {
30 |
31 | int get_batch_per_block_cuda(int query_seq_len, int key_seq_len, int batches, int attn_heads){
32 | return get_batch_per_block(query_seq_len, key_seq_len, batches, attn_heads);
33 | }
34 |
35 |
36 | torch::Tensor fwd_cuda(
37 | torch::Tensor const& input,
38 | torch::Tensor const& mask,
39 | float scale_factor)
40 | {
41 | // input is a 4d tensor with dimensions [batches, attn_heads, seq_len, seq_len]
42 | const int batches = input.size(0);
43 | const int pad_batches = mask.size(0);
44 | const int attn_heads = input.size(1);
45 | const int query_seq_len = input.size(2);
46 | const int key_seq_len = input.size(3);
47 | TORCH_INTERNAL_ASSERT(key_seq_len <= 8192);
48 | TORCH_INTERNAL_ASSERT(query_seq_len > 1);
49 | TORCH_INTERNAL_ASSERT(pad_batches == 1 || pad_batches == batches);
50 | TORCH_INTERNAL_ASSERT(mask.size(1) == 1);
51 | TORCH_INTERNAL_ASSERT(mask.size(2) == query_seq_len);
52 | TORCH_INTERNAL_ASSERT(mask.size(3) == key_seq_len);
53 |
54 | // Output
55 | auto act_options = input.options().requires_grad(false);
56 | torch::Tensor softmax_results =
57 | torch::empty({batches, attn_heads, query_seq_len, key_seq_len}, act_options);
58 |
59 | // Softmax Intermediate Result Ptr
60 | void* input_ptr = static_cast<void*>(input.data_ptr());
61 | void* mask_ptr = static_cast<void*>(mask.data_ptr());
62 | void* softmax_results_ptr = static_cast<void*>(softmax_results.data_ptr());
63 |
64 | DISPATCH_HALF_AND_BFLOAT(
65 | input.scalar_type(),
66 | "dispatch_scaled_masked_softmax_forward",
67 | dispatch_scaled_masked_softmax_forward<scalar_t, scalar_t, float>(
68 | reinterpret_cast<scalar_t*>(softmax_results_ptr),
69 | reinterpret_cast<const scalar_t*>(input_ptr),
70 | reinterpret_cast<const uint8_t*>(mask_ptr),
71 | scale_factor,
72 | query_seq_len,
73 | key_seq_len,
74 | batches,
75 | attn_heads,
76 | pad_batches);
77 | );
78 | return softmax_results;
79 | }
80 |
81 | torch::Tensor bwd_cuda(
82 | torch::Tensor const& output_grads_,
83 | torch::Tensor const& softmax_results_,
84 | float scale_factor) {
85 |
86 | auto output_grads = output_grads_.contiguous();
87 | auto softmax_results = softmax_results_.contiguous();
88 |
89 | //output grads is a 4d tensor with dimensions [batches, attn_heads, seq_len, seq_len]
90 | const int batches = output_grads.size(0);
91 | const int attn_heads = output_grads.size(1);
92 | const int query_seq_len = output_grads.size(2);
93 | const int key_seq_len = output_grads.size(3);
94 |
95 | void* output_grads_ptr = static_cast<void*>(output_grads.data_ptr());
96 |
97 | //Softmax Grad
98 | DISPATCH_HALF_AND_BFLOAT(
99 | output_grads_.scalar_type(),
100 | "dispatch_scaled_masked_softmax_backward",
101 | dispatch_scaled_masked_softmax_backward<scalar_t, scalar_t, float>(
102 | reinterpret_cast<scalar_t*>(output_grads_ptr),
103 | reinterpret_cast<scalar_t*>(output_grads_ptr),
104 | reinterpret_cast<scalar_t const*>(softmax_results.data_ptr()),
105 | scale_factor,
106 | query_seq_len,
107 | key_seq_len,
108 | batches,
109 | attn_heads);
110 | );
111 |
112 | //backward pass is completely in-place
113 | return output_grads;
114 | }
115 | }
116 | }
117 | }
118 |
--------------------------------------------------------------------------------
/megatron/fused_kernels/megatron_fused_kernels/scaled_softmax.cpp:
--------------------------------------------------------------------------------
1 | /* coding=utf-8
2 | * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 |
17 | #include <cuda_fp16.h>
18 | #include <torch/extension.h>
19 | #include <vector>
20 |
21 | namespace multihead_attn {
22 | namespace fused_softmax {
23 | namespace scaled_softmax {
24 |
25 | torch::Tensor fwd_cuda(
26 | torch::Tensor const& input,
27 | float scale_factor);
28 |
29 | torch::Tensor bwd_cuda(
30 | torch::Tensor const& output_grads,
31 | torch::Tensor const& softmax_results,
32 | float scale_factor);
33 |
34 | torch::Tensor fwd(
35 | torch::Tensor const& input,
36 | float scale_factor) {
37 | AT_ASSERTM(input.dim() == 4, "expected 4D tensor");
38 | AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) ||
39 | (input.scalar_type() == at::ScalarType::BFloat16),
40 | "Only fp16 and bf16 are supported");
41 |
42 | return fwd_cuda(input, scale_factor);
43 | }
44 |
45 | torch::Tensor bwd(
46 | torch::Tensor const& output_grads,
47 | torch::Tensor const& softmax_results,
48 | float scale_factor) {
49 |
50 | AT_ASSERTM(output_grads.dim() == 4, "expected 4D tensor");
51 | AT_ASSERTM(softmax_results.dim() == 4, "expected 4D tensor");
52 |
53 | AT_ASSERTM((output_grads.scalar_type() == at::ScalarType::Half) ||
54 | (output_grads.scalar_type() == at::ScalarType::BFloat16),
55 | "Only fp16 and bf16 are supported");
56 | AT_ASSERTM((softmax_results.scalar_type() == at::ScalarType::Half) ||
57 | (softmax_results.scalar_type() == at::ScalarType::BFloat16),
58 | "Only fp16 and bf16 are supported");
59 |
60 | return bwd_cuda(output_grads, softmax_results, scale_factor);
61 | }
62 |
63 | } // end namespace scaled_softmax
64 | } // end namespace fused_softmax
65 | } // end namespace multihead_attn
66 |
67 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
68 | m.def("forward",
69 | &multihead_attn::fused_softmax::scaled_softmax::fwd,
70 | "Self Multihead Attention scaled, softmax -- Forward.");
71 | m.def("backward",
72 | &multihead_attn::fused_softmax::scaled_softmax::bwd,
73 | "Self Multihead Attention scaled, softmax -- Backward.");
74 | }
75 |
--------------------------------------------------------------------------------
/megatron/fused_kernels/megatron_fused_kernels/scaled_softmax_cuda.cu:
--------------------------------------------------------------------------------
1 | /* coding=utf-8
2 | * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 |
17 | #include <ATen/ATen.h>
18 | #include <cuda.h>
19 | #include <cuda_runtime.h>
20 | #include <cuda_fp16.h>
21 | #include <cuda_profiler_api.h>
22 | #include <ATen/cuda/CUDAContext.h>
23 | #include <torch/extension.h>
24 | #include "scaled_masked_softmax.h"
25 | #include "type_shim.h"
26 |
27 | namespace multihead_attn {
28 | namespace fused_softmax {
29 | namespace scaled_softmax {
30 |
31 | torch::Tensor fwd_cuda(
32 | torch::Tensor const& input,
33 | float scale_factor)
34 | {
35 | // input is a 4d tensor with dimensions [batches, attn_heads, seq_len, seq_len]
36 | const int batches = input.size(0);
37 | const int attn_heads = input.size(1);
38 | const int query_seq_len = input.size(2);
39 | const int key_seq_len = input.size(3);
40 | TORCH_INTERNAL_ASSERT(key_seq_len <= 8192);
41 | TORCH_INTERNAL_ASSERT(query_seq_len > 1);
42 |
43 | // Output
44 | auto act_options = input.options().requires_grad(false);
45 | torch::Tensor softmax_results =
46 | torch::empty({batches, attn_heads, query_seq_len, key_seq_len}, act_options);
47 |
48 | // Softmax Intermediate Result Ptr
49 | void* input_ptr = static_cast<void*>(input.data_ptr());
50 | void* softmax_results_ptr = static_cast<void*>(softmax_results.data_ptr());
51 |
52 | DISPATCH_HALF_AND_BFLOAT(
53 | input.scalar_type(),
54 | "dispatch_scaled_softmax_forward",
55 | dispatch_scaled_softmax_forward<scalar_t, scalar_t, float>(
56 | reinterpret_cast<scalar_t*>(softmax_results_ptr),
57 | reinterpret_cast<const scalar_t*>(input_ptr),
58 | scale_factor,
59 | query_seq_len,
60 | key_seq_len,
61 | batches,
62 | attn_heads);
63 | );
64 | return softmax_results;
65 | }
66 |
67 | torch::Tensor bwd_cuda(
68 | torch::Tensor const& output_grads_,
69 | torch::Tensor const& softmax_results_,
70 | float scale_factor) {
71 |
72 | auto output_grads = output_grads_.contiguous();
73 | auto softmax_results = softmax_results_.contiguous();
74 |
75 | //output grads is a 4d tensor with dimensions [batches, attn_heads, seq_len, seq_len]
76 | const int batches = output_grads.size(0);
77 | const int attn_heads = output_grads.size(1);
78 | const int query_seq_len = output_grads.size(2);
79 | const int key_seq_len = output_grads.size(3);
80 |
81 | void* output_grads_ptr = static_cast<void*>(output_grads.data_ptr());
82 |
83 | //Softmax Grad
84 | DISPATCH_HALF_AND_BFLOAT(
85 | output_grads_.scalar_type(),
86 | "dispatch_scaled_masked_softmax_backward",
87 | dispatch_scaled_masked_softmax_backward<scalar_t, scalar_t, float>(
88 | reinterpret_cast<scalar_t*>(output_grads_ptr),
89 | reinterpret_cast<scalar_t*>(output_grads_ptr),
90 | reinterpret_cast<scalar_t const*>(softmax_results.data_ptr()),
91 | scale_factor,
92 | query_seq_len,
93 | key_seq_len,
94 | batches,
95 | attn_heads);
96 | );
97 |
98 | //backward pass is completely in-place
99 | return output_grads;
100 | }
101 | }
102 | }
103 | }
104 |
--------------------------------------------------------------------------------
/megatron/fused_kernels/megatron_fused_kernels/scaled_upper_triang_masked_softmax.cpp:
--------------------------------------------------------------------------------
1 | /* coding=utf-8
2 | * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 |
17 | #include <cuda_fp16.h>
18 | #include <torch/extension.h>
19 | #include <vector>
20 |
21 | namespace multihead_attn {
22 | namespace fused_softmax {
23 | namespace scaled_upper_triang_masked_softmax {
24 |
25 | torch::Tensor fwd_cuda(
26 | torch::Tensor const& input,
27 | float scale_factor);
28 |
29 | torch::Tensor bwd_cuda(
30 | torch::Tensor const& output_grads,
31 | torch::Tensor const& softmax_results,
32 | float scale_factor);
33 |
34 | torch::Tensor fwd(torch::Tensor const& input, float scale_factor) {
35 | AT_ASSERTM(input.dim() == 3, "expected 3D tensor");
36 | AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) ||
37 | (input.scalar_type() == at::ScalarType::BFloat16),
38 | "Only fp16 and bf16 are supported");
39 |
40 | return fwd_cuda(input, scale_factor);
41 | }
42 |
43 | torch::Tensor bwd(
44 | torch::Tensor const& output_grads,
45 | torch::Tensor const& softmax_results,
46 | float scale_factor) {
47 |
48 | AT_ASSERTM(output_grads.dim() == 3, "expected 3D tensor");
49 | AT_ASSERTM(softmax_results.dim() == 3, "expected 3D tensor");
50 |
51 | AT_ASSERTM((output_grads.scalar_type() == at::ScalarType::Half) ||
52 | (output_grads.scalar_type() == at::ScalarType::BFloat16),
53 | "Only fp16 and bf16 are supported");
54 | AT_ASSERTM((softmax_results.scalar_type() == at::ScalarType::Half) ||
55 | (softmax_results.scalar_type() == at::ScalarType::BFloat16),
56 | "Only fp16 and bf16 are supported");
57 |
58 | return bwd_cuda(output_grads, softmax_results, scale_factor);
59 | }
60 |
61 | } // end namespace scaled_upper_triang_masked_softmax
62 | } // end namespace fused_softmax
63 | } // end namespace multihead_attn
64 |
65 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
66 | m.def("forward",
67 | &multihead_attn::fused_softmax::scaled_upper_triang_masked_softmax::fwd,
68 | "Self Multihead Attention scaled, time masked softmax -- Forward.");
69 | m.def("backward",
70 | &multihead_attn::fused_softmax::scaled_upper_triang_masked_softmax::bwd,
71 | "Self Multihead Attention scaled, time masked softmax -- Backward.");
72 | }
73 |
--------------------------------------------------------------------------------
/megatron/fused_kernels/megatron_fused_kernels/scaled_upper_triang_masked_softmax_cuda.cu:
--------------------------------------------------------------------------------
1 | /* coding=utf-8
2 | * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 |
17 | #include <ATen/ATen.h>
18 | #include <cuda.h>
19 | #include <cuda_runtime.h>
20 | #include <cuda_fp16.h>
21 | #include <cuda_profiler_api.h>
22 | #include <ATen/cuda/CUDAContext.h>
23 | #include <torch/extension.h>
24 | #include "scaled_upper_triang_masked_softmax.h"
25 | #include "type_shim.h"
26 |
27 | namespace multihead_attn {
28 | namespace fused_softmax {
29 | namespace scaled_upper_triang_masked_softmax {
30 |
31 | torch::Tensor fwd_cuda(
32 | torch::Tensor const& input,
33 | float scale_factor)
34 | {
35 | // input is a 3d tensor with dimensions [attn_batches, seq_len, seq_len]
36 | const int attn_batches = input.size(0);
37 | const int seq_len = input.size(1);
38 | TORCH_INTERNAL_ASSERT(seq_len <= 8192);
39 |
40 | // Output
41 | auto act_options = input.options().requires_grad(false);
42 | torch::Tensor softmax_results =
43 | torch::empty({attn_batches, seq_len, seq_len}, act_options);
44 |
45 | // Softmax Intermediate Result Ptr
46 | void* input_ptr = static_cast<void*>(input.data_ptr());
47 | void* softmax_results_ptr = static_cast<void*>(softmax_results.data_ptr());
48 |
49 | DISPATCH_HALF_AND_BFLOAT(
50 | input.scalar_type(),
51 | "dispatch_scaled_upper_triang_masked_softmax_forward",
52 | dispatch_scaled_upper_triang_masked_softmax_forward<scalar_t, scalar_t, float>(
53 | reinterpret_cast<scalar_t*>(softmax_results_ptr),
54 | reinterpret_cast<const scalar_t*>(input_ptr),
55 | scale_factor,
56 | seq_len,
57 | seq_len,
58 | attn_batches);
59 | );
60 | return softmax_results;
61 | }
62 |
63 |
64 | torch::Tensor bwd_cuda(
65 | torch::Tensor const& output_grads_,
66 | torch::Tensor const& softmax_results_,
67 | float scale_factor) {
68 |
69 | auto output_grads = output_grads_.contiguous();
70 | auto softmax_results = softmax_results_.contiguous();
71 |
72 | //output grads is a 3d tensor with dimensions [attn_batches, seq_len, seq_len]
73 | const int attn_batches = output_grads.size(0);
74 | const int seq_len = output_grads.size(1);
75 | TORCH_INTERNAL_ASSERT(output_grads.size(1) == output_grads.size(2));
76 |
77 | void* output_grads_ptr = static_cast<void*>(output_grads.data_ptr());
78 |
79 | //Softmax Grad
80 | DISPATCH_HALF_AND_BFLOAT(
81 | output_grads_.scalar_type(),
82 | "dispatch_scaled_upper_triang_masked_softmax_backward",
83 | dispatch_scaled_upper_triang_masked_softmax_backward<scalar_t, scalar_t, float>(
84 | reinterpret_cast<scalar_t*>(output_grads_ptr),
85 | reinterpret_cast<scalar_t*>(output_grads_ptr),
86 | reinterpret_cast<scalar_t const*>(softmax_results.data_ptr()),
87 | scale_factor,
88 | seq_len,
89 | seq_len,
90 | attn_batches);
91 | );
92 |
93 | //backward pass is completely in-place
94 | return output_grads;
95 | }
96 | }
97 | }
98 | }
99 |
--------------------------------------------------------------------------------
/megatron/fused_kernels/megatron_fused_kernels/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/persimmon-ai-labs/adept-inference/61743d07cfb151dadb0cd2ae9de8f7325c4e828a/megatron/fused_kernels/megatron_fused_kernels/tests/__init__.py
--------------------------------------------------------------------------------
/megatron/fused_kernels/megatron_fused_kernels/type_shim.h:
--------------------------------------------------------------------------------
1 | /* coding=utf-8
2 | * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 |
17 |
18 | #include <ATen/ATen.h>
19 | #include "compat.h"
20 |
21 |
22 | #define DISPATCH_HALF_AND_BFLOAT(TYPE, NAME, ...) \
23 | switch(TYPE) \
24 | { \
25 | case at::ScalarType::Half: \
26 | { \
27 | using scalar_t = at::Half; \
28 | __VA_ARGS__; \
29 | break; \
30 | } \
31 | case at::ScalarType::BFloat16: \
32 | { \
33 | using scalar_t = at::BFloat16; \
34 | __VA_ARGS__; \
35 | break; \
36 | } \
37 | default: \
38 | AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
39 | }
40 |
41 |
42 | #define DISPATCH_HALF_BFLOAT_AND_FLOAT(TYPE, NAME, ...) \
43 | switch(TYPE) \
44 | { \
45 | case at::ScalarType::Half: \
46 | { \
47 | using scalar_t = at::Half; \
48 | __VA_ARGS__; \
49 | break; \
50 | } \
51 | case at::ScalarType::BFloat16: \
52 | { \
53 | using scalar_t = at::BFloat16; \
54 | __VA_ARGS__; \
55 | break; \
56 | } \
57 | case at::ScalarType::Float: \
58 | { \
59 | using scalar_t = float; \
60 | __VA_ARGS__; \
61 | break; \
62 | } \
63 | default: \
64 | AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
65 | }
66 |
67 |
68 |
69 | #define DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(TYPEIN, TYPEOUT, NAME, ...) \
70 | switch(TYPEIN) \
71 | { \
72 | case at::ScalarType::Float: \
73 | { \
74 | using scalar_t_in = float; \
75 | switch(TYPEOUT) \
76 | { \
77 | case at::ScalarType::Float: \
78 | { \
79 | using scalar_t_out = float; \
80 | __VA_ARGS__; \
81 | break; \
82 | } \
83 | case at::ScalarType::Half: \
84 | { \
85 | using scalar_t_out = at::Half; \
86 | __VA_ARGS__; \
87 | break; \
88 | } \
89 | case at::ScalarType::BFloat16: \
90 | { \
91 | using scalar_t_out = at::BFloat16; \
92 | __VA_ARGS__; \
93 | break; \
94 | } \
95 | default: \
96 | AT_ERROR(#NAME, " not implemented for '", toString(TYPEOUT), "'"); \
97 | } \
98 | break; \
99 | } \
100 | case at::ScalarType::Half: \
101 | { \
102 | using scalar_t_in = at::Half; \
103 | using scalar_t_out = at::Half; \
104 | __VA_ARGS__; \
105 | break; \
106 | } \
107 | case at::ScalarType::BFloat16: \
108 | { \
109 | using scalar_t_in = at::BFloat16; \
110 | using scalar_t_out = at::BFloat16; \
111 | __VA_ARGS__; \
112 | break; \
113 | } \
114 | default: \
115 | AT_ERROR(#NAME, " not implemented for '", toString(TYPEIN), "'"); \
116 | }
117 |
--------------------------------------------------------------------------------
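
The DISPATCH_* macros above expand to a switch over the tensor's scalar type and bind `scalar_t` (or `scalar_t_in`/`scalar_t_out`) before pasting in the kernel-launch expression. A rough Python analogue of `DISPATCH_HALF_AND_BFLOAT`, purely for intuition about what the macro does; the function here is illustrative, not part of the repo.

```python
import torch

def dispatch_half_and_bfloat(dtype: torch.dtype, name: str, launch):
    """Pick the concrete element type, then invoke the launch expression with it."""
    if dtype == torch.float16:
        return launch(torch.float16)   # corresponds to `using scalar_t = at::Half`
    if dtype == torch.bfloat16:
        return launch(torch.bfloat16)  # corresponds to `using scalar_t = at::BFloat16`
    raise TypeError(f"{name} not implemented for '{dtype}'")

# The launcher receives the type the macro would have bound to scalar_t.
dispatch_half_and_bfloat(torch.bfloat16, "scaled_softmax", lambda scalar_t: print(scalar_t))
```
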
/megatron/fused_kernels/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 | from torch.utils import cpp_extension
3 | from megatron_fused_kernels.load_kernel import ALL_BUILD_KERNELS, JIT_BUILD_KERNELS
4 | from megatron_fused_kernels.load_cpp_extensions import get_helpers_extension
5 |
6 |
7 | def get_kernel_extensions():
8 | # remove 'fused_dense_cuda' from the list of kernels to build
9 | # as it is not used in the current version of Megatron-LM
10 | # and it causes compilation errors when doing ninja installs.
11 | # We'll need to JIT compile it instead.
12 | kernels_to_build = [
13 | build_info.to_setup_cpp_extension() for k, build_info in ALL_BUILD_KERNELS.items() if k not in JIT_BUILD_KERNELS
14 | ]
15 | return kernels_to_build
16 |
17 |
18 | setup(
19 | name="megatron_fused_kernels",
20 | packages=find_packages(exclude=("tests",)),
21 |     package_dir={"megatron_fused_kernels": "megatron_fused_kernels"},
22 | ext_modules=get_kernel_extensions() + get_helpers_extension(),
23 | cmdclass={"build_ext": cpp_extension.BuildExtension},
24 | zip_safe=False,
25 | author="Adept AI",
26 | author_email="",
27 | )
28 |
--------------------------------------------------------------------------------
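
Once this setup script has built the extensions (a plausible invocation is `pip install -e .` from `megatron/fused_kernels/`, though the provided Docker image may drive the build differently), the compiled modules import by name from the `megatron_fused_kernels` package, which is how `fused_softmax.py` below reaches them.

```python
# After the build, the compiled kernels import like ordinary Python modules.
from megatron_fused_kernels import scaled_upper_triang_masked_softmax_cuda

# Bindings used by fused_softmax.py:
#   forward(inputs, scale)                           -> softmax probabilities
#   backward(output_grads, softmax_results, scale)   -> input gradients
```
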
/megatron/global_vars.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright (c) 2023 ADEPT AI LABS INC.
3 | # This file is based on code by the authors denoted below and has been modified from its original version.
4 | #
5 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
6 | #
7 | # Licensed under the Apache License, Version 2.0 (the "License");
8 | # you may not use this file except in compliance with the License.
9 | # You may obtain a copy of the License at
10 | #
11 | # http://www.apache.org/licenses/LICENSE-2.0
12 | #
13 | # Unless required by applicable law or agreed to in writing, software
14 | # distributed under the License is distributed on an "AS IS" BASIS,
15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16 | # See the License for the specific language governing permissions and
17 | # limitations under the License.
18 |
19 | """Megatron global variables."""
20 |
21 | import functools
22 | import operator
23 | import os
24 | import sys
25 | import time
26 | from functools import reduce
27 | from pathlib import Path
28 | import requests
29 | import torch
30 | import yaml
31 |
32 | from megatron.tokenizer import build_tokenizer
33 |
34 | from .microbatches import build_num_microbatches_calculator
35 |
36 | _GLOBAL_ARGS = None
37 | _GLOBAL_NUM_MICROBATCHES_CALCULATOR = None
38 | _GLOBAL_TOKENIZER = None
39 | _GLOBAL_TENSORBOARD_WRITER = None
40 | _GLOBAL_ADLR_AUTORESUME = None
41 | _GLOBAL_TIMERS = None
42 |
43 | _GLOBAL_MEMORY_BUFFER = None
44 |
45 |
46 | def get_args():
47 | """Return arguments."""
48 | _ensure_var_is_initialized(_GLOBAL_ARGS, "args")
49 | return _GLOBAL_ARGS
50 |
51 |
52 | def get_num_microbatches():
53 | return _GLOBAL_NUM_MICROBATCHES_CALCULATOR.get()
54 |
55 |
56 | def get_current_global_batch_size():
57 | return _GLOBAL_NUM_MICROBATCHES_CALCULATOR.get_current_global_batch_size()
58 |
59 |
60 | def update_num_microbatches(consumed_samples, consistency_check=True):
61 | _GLOBAL_NUM_MICROBATCHES_CALCULATOR.update(consumed_samples, consistency_check)
62 |
63 |
64 | def get_tokenizer():
65 | """Return tokenizer."""
66 | _ensure_var_is_initialized(_GLOBAL_TOKENIZER, "tokenizer")
67 | return _GLOBAL_TOKENIZER
68 |
69 |
70 | def get_tensorboard_writer():
71 | """Return tensorboard writer. It can be None so no need
72 | to check if it is initialized."""
73 | return _GLOBAL_TENSORBOARD_WRITER
74 |
75 |
76 | def get_adlr_autoresume():
77 | """ADLR autoresume object. It can be None so no need
78 | to check if it is initialized."""
79 | return _GLOBAL_ADLR_AUTORESUME
80 |
81 |
82 | def get_timers():
83 | """Return timers."""
84 | _ensure_var_is_initialized(_GLOBAL_TIMERS, "timers")
85 | return _GLOBAL_TIMERS
86 |
87 |
88 | def get_global_memory_buffer():
89 | _ensure_var_is_initialized(_GLOBAL_MEMORY_BUFFER, "global memory buffer")
90 | return _GLOBAL_MEMORY_BUFFER
91 |
92 |
93 | def set_global_variables(args):
94 | """Set args, tokenizer, tensorboard-writer, adlr-autoresume, and timers."""
95 |
96 | assert args is not None
97 |
98 | _ensure_var_is_not_initialized(_GLOBAL_ARGS, "args")
99 | set_args(args)
100 |
101 | _build_num_microbatches_calculator(args)
102 | if args.vocab_file or args.sp_model_file:
103 | _ = _build_tokenizer(args)
104 | _set_tensorboard_writer(args)
105 | _set_adlr_autoresume(args)
106 | _set_timers()
107 | _set_global_memory_buffer()
108 |
109 |
110 | def set_args(args):
111 | global _GLOBAL_ARGS
112 | _GLOBAL_ARGS = args
113 |
114 |
115 | def _build_num_microbatches_calculator(args):
116 |
117 | global _GLOBAL_NUM_MICROBATCHES_CALCULATOR
118 | _ensure_var_is_not_initialized(
119 | _GLOBAL_NUM_MICROBATCHES_CALCULATOR, "num microbatches calculator"
120 | )
121 |
122 | _GLOBAL_NUM_MICROBATCHES_CALCULATOR = build_num_microbatches_calculator(args)
123 |
124 |
125 | def _build_tokenizer(args):
126 | """Initialize tokenizer."""
127 | global _GLOBAL_TOKENIZER
128 | _ensure_var_is_not_initialized(_GLOBAL_TOKENIZER, "tokenizer")
129 | _GLOBAL_TOKENIZER = build_tokenizer(args)
130 | return _GLOBAL_TOKENIZER
131 |
132 |
133 | def rebuild_tokenizer(args):
134 | global _GLOBAL_TOKENIZER
135 | _GLOBAL_TOKENIZER = None
136 | return _build_tokenizer(args)
137 |
138 |
139 | def _set_tensorboard_writer(args):
140 | """Set tensorboard writer."""
141 | global _GLOBAL_TENSORBOARD_WRITER
142 | _ensure_var_is_not_initialized(_GLOBAL_TENSORBOARD_WRITER, "tensorboard writer")
143 |
144 | if (
145 | hasattr(args, "tensorboard_dir")
146 | and args.tensorboard_dir
147 | and args.rank == (args.world_size - 1)
148 | ):
149 | try:
150 | from torch.utils.tensorboard import (
151 | SummaryWriter,
152 | ) # pylint: disable=import-outside-toplevel
153 |
154 | print("> setting tensorboard ...")
155 | _GLOBAL_TENSORBOARD_WRITER = SummaryWriter(
156 | log_dir=args.tensorboard_dir,
157 | max_queue=args.tensorboard_queue_size,
158 | )
159 | except ModuleNotFoundError:
160 | print(
161 | "WARNING: TensorBoard writing requested but is not "
162 | "available (are you using PyTorch 1.1.0 or later?), "
163 | "no TensorBoard logs will be written.",
164 | flush=True,
165 | )
166 |
167 |
168 | def _set_adlr_autoresume(args):
169 | """Initialize ADLR autoresume."""
170 | global _GLOBAL_ADLR_AUTORESUME
171 | _ensure_var_is_not_initialized(_GLOBAL_ADLR_AUTORESUME, "adlr autoresume")
172 |
173 | if args.adlr_autoresume:
174 | if args.rank == 0:
175 | print("enabling autoresume ...", flush=True)
176 | sys.path.append(os.environ.get("SUBMIT_SCRIPTS", "."))
177 | try:
178 | from userlib.auto_resume import (
179 | AutoResume,
180 | ) # pylint: disable=import-outside-toplevel
181 | except BaseException: # pylint: disable=broad-except
182 | print("ADLR autoresume is not available, exiting ...")
183 | sys.exit()
184 |
185 | _GLOBAL_ADLR_AUTORESUME = AutoResume
186 |
187 |
188 | def _set_timers():
189 | """Initialize timers."""
190 | global _GLOBAL_TIMERS
191 | _ensure_var_is_not_initialized(_GLOBAL_TIMERS, "timers")
192 | _GLOBAL_TIMERS = Timers()
193 |
194 |
195 | def _set_global_memory_buffer():
196 | """Initialize global buffer"""
197 | global _GLOBAL_MEMORY_BUFFER
198 | _ensure_var_is_not_initialized(_GLOBAL_MEMORY_BUFFER, "global memory buffer")
199 | _GLOBAL_MEMORY_BUFFER = GlobalMemoryBuffer()
200 |
201 |
202 | def _ensure_var_is_initialized(var, name):
203 | """Make sure the input variable is not None."""
204 | assert var is not None, f"{name} is not initialized."
205 |
206 |
207 | def _ensure_var_is_not_initialized(var, name):
208 | """Make sure the input variable is not None."""
209 | assert var is None, f"{name} is already initialized."
210 |
211 |
212 | class _Timer:
213 | """Timer."""
214 |
215 | def __init__(self, name):
216 | self.name_ = name
217 | self.elapsed_ = 0.0
218 | self.started_ = False
219 | self.start_time = time.time()
220 |
221 | def start(self):
222 | """Start the timer."""
223 | # this import has to be here because of circular dependencies.
224 | from megatron.mpu import (
225 | get_data_parallel_group,
226 | ) # pylint: disable=import-outside-toplevel
227 |
228 | assert not self.started_, "timer has already been started"
229 | torch.distributed.barrier(get_data_parallel_group())
230 | torch.cuda.synchronize()
231 | self.start_time = time.time()
232 | self.started_ = True
233 |
234 | def stop(self):
235 | """Stop the timer."""
236 | # this import has to be here because of circular dependencies.
237 | from megatron.mpu import (
238 | get_data_parallel_group,
239 | ) # pylint: disable=import-outside-toplevel
240 |
241 | assert self.started_, "timer is not started"
242 | torch.distributed.barrier(get_data_parallel_group())
243 | torch.cuda.synchronize()
244 | self.elapsed_ += time.time() - self.start_time
245 | self.started_ = False
246 |
247 | def reset(self):
248 | """Reset timer."""
249 | self.elapsed_ = 0.0
250 | self.started_ = False
251 |
252 | def elapsed(self, reset=True):
253 | """Calculate the elapsed time."""
254 | started = self.started_
255 |         # If timing is in progress, end it first.
256 | if self.started_:
257 | self.stop()
258 | # Get the elapsed time.
259 | elapsed = self.elapsed_
260 | # Reset the elapsed time
261 | if reset:
262 | self.reset()
263 | # If timing was in progress, set it back.
264 | if started:
265 | self.start()
266 | return elapsed
267 |
268 |
269 | class Timers:
270 | """Group of timers."""
271 |
272 | def __init__(self):
273 | self.timers = {}
274 |
275 | def __call__(self, name):
276 | if name not in self.timers:
277 | self.timers[name] = _Timer(name)
278 | return self.timers[name]
279 |
280 | def write(self, names, iteration, normalizer=1.0, reset=False):
281 | """Write timers to a tensorboard writer"""
282 | # currently when using add_scalars,
283 | # torch.utils.add_scalars makes each timer its own run, which
284 |         # pollutes the runs list, so we just add each as a scalar
285 | assert normalizer > 0.0
286 | for name in names:
287 | value = self.timers[name].elapsed(reset=reset) / normalizer
288 | key = f"timers/{name}-(s)"
289 |
290 | def log(self, names, normalizer=1.0, reset=True):
291 | """Log a group of timers."""
292 | assert normalizer > 0.0
293 | string = "time (ms)"
294 | for name in names:
295 | elapsed_time = self.timers[name].elapsed(reset=reset) * 1000.0 / normalizer
296 | string += f" | {name}: {elapsed_time:.2f}"
297 | if torch.distributed.is_initialized():
298 | if torch.distributed.get_rank() == (torch.distributed.get_world_size() - 1):
299 | print(string, flush=True)
300 | else:
301 | print(string, flush=True)
302 |
303 |
304 | class GlobalMemoryBuffer:
305 | """Global buffer to avoid dynamic memory allocations.
306 | Caller should ensure that buffers of the same name
307 | are not used concurrently."""
308 |
309 | def __init__(self):
310 | self.buffer = {}
311 |
312 | def get_tensor(self, tensor_shape, dtype, name):
313 | required_len = reduce(operator.mul, tensor_shape, 1)
314 | if (
315 | self.buffer.get((name, dtype), None) is None
316 | or self.buffer[(name, dtype)].numel() < required_len
317 | ):
318 | self.buffer[(name, dtype)] = torch.empty(
319 | required_len,
320 | dtype=dtype,
321 | device=torch.cuda.current_device(),
322 | requires_grad=False,
323 | )
324 |
325 | return self.buffer[(name, dtype)][0:required_len].view(*tensor_shape)
326 |
--------------------------------------------------------------------------------
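
A usage note on `GlobalMemoryBuffer`: repeated `get_tensor` calls with the same name and dtype reuse one flat allocation as long as the requested element count does not grow, which is the whole point of the class. A small sketch (it assumes a CUDA device, since the buffer allocates on `torch.cuda.current_device()`):

```python
import torch
from megatron.global_vars import GlobalMemoryBuffer

buf = GlobalMemoryBuffer()

# The first request allocates 2*3*4 = 24 bf16 elements under the key ("mlp", bfloat16)...
a = buf.get_tensor((2, 3, 4), torch.bfloat16, "mlp")
# ...a later request with the same key and <= 24 elements is a view into the same storage.
b = buf.get_tensor((4, 6), torch.bfloat16, "mlp")
assert a.data_ptr() == b.data_ptr()
```
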
/megatron/memory.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright (c) 2023 ADEPT AI LABS INC.
3 | # This file is based on code by the authors denoted below and has been modified from its original version.
4 | #
5 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
6 | #
7 | # Licensed under the Apache License, Version 2.0 (the "License");
8 | # you may not use this file except in compliance with the License.
9 | # You may obtain a copy of the License at
10 | #
11 | # http://www.apache.org/licenses/LICENSE-2.0
12 | #
13 | # Unless required by applicable law or agreed to in writing, software
14 | # distributed under the License is distributed on an "AS IS" BASIS,
15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16 | # See the License for the specific language governing permissions and
17 | # limitations under the License.
18 |
19 |
20 | import torch
21 |
22 |
23 | # A dictionary of all the memory buffers allocated.
24 | _MEM_BUFFS = dict()
25 |
26 |
27 | def allocate_mem_buff(name, numel, dtype, track_usage):
28 | """Allocate a memory buffer."""
29 | assert name not in _MEM_BUFFS, "memory buffer {} already allocated.".format(name)
30 | _MEM_BUFFS[name] = MemoryBuffer(name, numel, dtype, track_usage)
31 | return _MEM_BUFFS[name]
32 |
33 |
34 | class MemoryBuffer:
35 | """Contiguous memory buffer.
36 | Allocate a contiguous memory of type `dtype` and size `numel`. It is
37 | used to reduce memory fragmentation.
38 |
39 |     Usage: After the allocation, the `_start` index is set to the first
40 |            index of the memory. A memory chunk starting from `_start` index
41 |            can be `allocated` for an input tensor, with the elements of the
42 |            tensor being copied. The buffer can be reused by resetting the
43 | `_start` index.
44 |
45 | """
46 |
47 | def __init__(self, name, numel, dtype, track_usage):
48 | if torch.distributed.get_rank() == 0:
49 | element_size = torch.tensor([], dtype=dtype).element_size()
50 | print(
51 | "> building the {} memory buffer with {} num elements "
52 | "and {} dtype ({:.1f} MB)...".format(name, numel, dtype, numel * element_size / 1024 / 1024),
53 | flush=True,
54 | )
55 | self.name = name
56 | self.numel = numel
57 | self.dtype = dtype
58 | self.data = torch.empty(self.numel, dtype=self.dtype, device=torch.cuda.current_device(), requires_grad=False)
59 |
60 | # Index tracking the start of the free memory.
61 | self._start = 0
62 |
63 | # Values used for tracking usage.
64 | self.track_usage = track_usage
65 | if self.track_usage:
66 | self.in_use_value = 0.0
67 | self.total_value = 0.0
68 |
69 | def reset(self):
70 | """Reset the buffer start index to the beginning of the buffer."""
71 | self._start = 0
72 |
73 | def is_in_use(self):
74 |         """Whether the current buffer holds on to any memory."""
75 | return self._start > 0
76 |
77 | def numel_in_use(self):
78 | """Return number of elements in use."""
79 | return self._start
80 |
81 | def add(self, tensor):
82 | """Allocate a chunk of memory from the buffer to tensor and copy
83 | the values."""
84 | assert tensor.dtype == self.dtype, "Input tensor type {} different from buffer type {}".format(
85 | tensor.dtype, self.dtype
86 | )
87 | # Number of elements of the input tensor.
88 | tensor_numel = torch.numel(tensor)
89 | new_start = self._start + tensor_numel
90 | assert new_start <= self.numel, "Not enough memory left in the buffer ({} > {})".format(
91 | tensor_numel, self.numel - self._start
92 | )
93 | # New tensor is a view into the memory.
94 | new_tensor = self.data[self._start : new_start]
95 | self._start = new_start
96 | new_tensor = new_tensor.view(tensor.shape)
97 | new_tensor.copy_(tensor)
98 | # Return a pointer to the new tensor.
99 | return new_tensor
100 |
101 | def get_data(self):
102 | """Return the data currently in use."""
103 | if self.track_usage:
104 | self.in_use_value += float(self._start)
105 | self.total_value += float(self.numel)
106 | return self.data[: self._start]
107 |
108 | def print_average_usage(self):
109 |         """Print memory usage averaged over time. We would like this value
110 | to be as high as possible."""
111 | assert self.track_usage, "You need to enable track usage."
112 | if torch.distributed.get_rank() == 0:
113 | print(
114 | " > usage of {} memory buffer: {:.2f} %".format(
115 | self.name, self.in_use_value * 100.0 / self.total_value
116 | ),
117 | flush=True,
118 | )
119 |
--------------------------------------------------------------------------------
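
A short usage sketch of `MemoryBuffer`, assuming `torch.distributed` has been initialized and a GPU is available (the constructor prints from rank 0 and allocates on the current CUDA device):

```python
import torch
from megatron.memory import allocate_mem_buff

buf = allocate_mem_buff("activations", numel=1 << 20, dtype=torch.bfloat16, track_usage=True)

x = torch.randn(256, 1024, dtype=torch.bfloat16, device="cuda")
x_in_buf = buf.add(x)                    # contiguous view into the buffer, values copied
assert buf.numel_in_use() == x.numel()

buf.reset()                              # rewind _start so the whole buffer can be reused
assert not buf.is_in_use()
```
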
/megatron/microbatches.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright (c) 2023 ADEPT AI LABS INC.
3 | # This file is based on code by the authors denoted below and has been modified from its original version.
4 | #
5 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
6 | #
7 | # Licensed under the Apache License, Version 2.0 (the "License");
8 | # you may not use this file except in compliance with the License.
9 | # You may obtain a copy of the License at
10 | #
11 | # http://www.apache.org/licenses/LICENSE-2.0
12 | #
13 | # Unless required by applicable law or agreed to in writing, software
14 | # distributed under the License is distributed on an "AS IS" BASIS,
15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16 | # See the License for the specific language governing permissions and
17 | # limitations under the License.
18 |
19 | """Megatron number of micro-batches calculators."""
20 |
21 | from abc import ABC
22 | from abc import abstractmethod
23 |
24 |
25 | def build_num_microbatches_calculator(args):
26 |
27 | # Constant num micro-batches.
28 | if args.rampup_batch_size is None:
29 | num_microbatches_calculator = ConstantNumMicroBatches(
30 | args.global_batch_size, args.micro_batch_size, args.data_parallel_size
31 | )
32 | if args.rank == 0:
33 | print(
34 | "setting number of micro-batches to constant {}".format(num_microbatches_calculator.get()), flush=True
35 | )
36 |
37 | else:
38 | assert len(args.rampup_batch_size) == 3, (
39 | "expected the following "
40 |             "format: --rampup-batch-size <start batch size> "
41 |             "<batch size increment> <ramp-up samples>"
42 | )
43 | start_batch_size = int(args.rampup_batch_size[0])
44 | batch_size_increment = int(args.rampup_batch_size[1])
45 | ramup_samples = int(args.rampup_batch_size[2])
46 | if args.rank == 0:
47 | print(
48 | "will use batch size rampup starting from global batch "
49 | "size {} to global batch size {} with batch size increments "
50 | "{} over {} samples.".format(
51 | start_batch_size, args.global_batch_size, batch_size_increment, ramup_samples
52 | ),
53 | flush=True,
54 | )
55 | num_microbatches_calculator = RampupBatchsizeNumMicroBatches(
56 | start_batch_size,
57 | batch_size_increment,
58 | ramup_samples,
59 | args.global_batch_size,
60 | args.micro_batch_size,
61 | args.data_parallel_size,
62 | )
63 |
64 | return num_microbatches_calculator
65 |
66 |
67 | class NumMicroBatchesCalculator(ABC):
68 | def __init__(self):
69 | self.num_micro_batches = None
70 | self.current_global_batch_size = None
71 |
72 | def get(self):
73 | return self.num_micro_batches
74 |
75 | def get_current_global_batch_size(self):
76 | return self.current_global_batch_size
77 |
78 | @abstractmethod
79 | def update(self, consumed_samples, consistency_check):
80 | pass
81 |
82 |
83 | class ConstantNumMicroBatches(NumMicroBatchesCalculator):
84 | def __init__(self, global_batch_size, micro_batch_size, data_parallel_size):
85 | micro_batch_times_data_parallel = micro_batch_size * data_parallel_size
86 | assert (
87 | global_batch_size % micro_batch_times_data_parallel == 0
88 | ), "global batch size ({}) is not divisible by micro batch size ({})" " times data parallel size ({})".format(
89 | global_batch_size, micro_batch_size, data_parallel_size
90 | )
91 | self.num_micro_batches = global_batch_size // micro_batch_times_data_parallel
92 | assert self.num_micro_batches >= 1
93 | self.current_global_batch_size = global_batch_size
94 |
95 | def update(self, consumed_samples, consistency_check):
96 | pass
97 |
98 |
99 | class RampupBatchsizeNumMicroBatches(NumMicroBatchesCalculator):
100 | def __init__(
101 | self,
102 | start_batch_size,
103 | batch_size_increment,
104 | ramup_samples,
105 | global_batch_size,
106 | micro_batch_size,
107 | data_parallel_size,
108 | ):
109 | """Batch size ramp up.
110 | Over
111 | steps = (global-batch-size - start-batch-size) / batch_size_increment
112 | increment batch size from start-batch-size to global-batch-size using
113 | rampup-samples / steps
114 | samples.
115 | Arguments:
116 | start_batch_size: global batch size to start with
117 | batch_size_increment: global batch size increments
118 |             ramup_samples: number of samples to use to ramp up the global
119 | batch size from `start_batch_size` to `global_batch_size`
120 | global_batch_size: global batch size post rampup
121 | micro_batch_size: micro batch size
122 | data_parallel_size: data parallel size.
123 | """
124 |
125 | self.micro_batch_size = micro_batch_size
126 | self.data_parallel_size = data_parallel_size
127 | self.micro_batch_times_data_parallel_size = self.micro_batch_size * self.data_parallel_size
128 | assert self.micro_batch_times_data_parallel_size > 0
129 |
130 | assert start_batch_size > 0
131 | self.start_batch_size = start_batch_size
132 |
133 | assert global_batch_size > 0
134 | self.global_batch_size = global_batch_size
135 | diff_batch_size = self.global_batch_size - self.start_batch_size
136 | assert diff_batch_size >= 0
137 | assert batch_size_increment > 0
138 | self.batch_size_increment = batch_size_increment
139 | assert diff_batch_size % batch_size_increment == 0, (
140 | "expected "
141 | "global batch size interval ({}) to be divisible by global batch "
142 | "size increment ({})".format(diff_batch_size, batch_size_increment)
143 | )
144 |
145 | num_increments = diff_batch_size // self.batch_size_increment
146 | self.ramup_samples = ramup_samples
147 | assert self.ramup_samples >= 0
148 | self.rampup_samples_per_increment = self.ramup_samples / num_increments
149 |
150 | # Initialize number of microbatches.
151 | self.update(0, False)
152 |
153 | def update(self, consumed_samples, consistency_check):
154 |
155 | if consumed_samples > self.ramup_samples:
156 | self.current_global_batch_size = self.global_batch_size
157 | else:
158 | steps = int(consumed_samples / self.rampup_samples_per_increment)
159 | self.current_global_batch_size = self.start_batch_size + steps * self.batch_size_increment
160 | assert self.current_global_batch_size <= self.global_batch_size
161 |
162 | if consistency_check:
163 | assert self.current_global_batch_size % self.micro_batch_times_data_parallel_size == 0, (
164 | "current global "
165 |                 "batch size ({}) is not divisible by micro-batch-size ({}) times "
166 | "data parallel size ({})".format(
167 | self.current_global_batch_size, self.micro_batch_size, self.data_parallel_size
168 | )
169 | )
170 | self.num_micro_batches = self.current_global_batch_size // self.micro_batch_times_data_parallel_size
171 |
--------------------------------------------------------------------------------
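
A worked example of the ramp-up arithmetic above: with a start batch size of 32, increments of 32, a target of 256 and 1,000,000 ramp-up samples there are (256 - 32) / 32 = 7 increments, one roughly every 142,857 consumed samples. The sketch below exercises the calculator directly; in training it is driven through `update_num_microbatches`.

```python
from megatron.microbatches import RampupBatchsizeNumMicroBatches

calc = RampupBatchsizeNumMicroBatches(
    start_batch_size=32,
    batch_size_increment=32,
    ramup_samples=1_000_000,   # argument name follows the code above
    global_batch_size=256,
    micro_batch_size=4,
    data_parallel_size=8,      # micro_batch * data_parallel = 32 samples per micro-step
)

calc.update(consumed_samples=0, consistency_check=True)
assert calc.get_current_global_batch_size() == 32 and calc.get() == 1

calc.update(consumed_samples=300_000, consistency_check=True)   # two full increments
assert calc.get_current_global_batch_size() == 96 and calc.get() == 3
```
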
/megatron/model/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | from .fused_layer_norm import MixedFusedLayerNorm as LayerNorm
17 |
18 | from .distributed import DistributedDataParallel
19 | from .gpt_model import GPTModel
20 | from .language_model import get_language_model
21 | from .module import Float16Module
22 | from .enums import ModelType
23 |
--------------------------------------------------------------------------------
/megatron/model/enums.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright (c) 2023 ADEPT AI LABS INC.
3 | # This file is based on code by the authors denoted below and has been modified from its original version.
4 | #
5 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
6 | #
7 | # Licensed under the Apache License, Version 2.0 (the "License");
8 | # you may not use this file except in compliance with the License.
9 | # You may obtain a copy of the License at
10 | #
11 | # http://www.apache.org/licenses/LICENSE-2.0
12 | #
13 | # Unless required by applicable law or agreed to in writing, software
14 | # distributed under the License is distributed on an "AS IS" BASIS,
15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16 | # See the License for the specific language governing permissions and
17 | # limitations under the License.
18 |
19 | import enum
20 |
21 |
22 | class ModelType(enum.Enum):
23 | encoder_or_decoder = 1
24 | encoder_and_decoder = 2
25 |
26 |
27 | class LayerType(enum.Enum):
28 | encoder = 1
29 | decoder = 2
30 |
31 |
32 | class AttnType(enum.Enum):
33 | self_attn = 1
34 | cross_attn = 2
35 |
36 |
37 | class AttnMaskType(enum.Enum):
38 | padding = 1
39 | causal = 2
40 |
--------------------------------------------------------------------------------
/megatron/model/fused_bias_gelu.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | import torch
17 |
18 |
19 | ###### BIAS GELU FUSION/ NO AUTOGRAD ################
20 | # 1/sqrt(2*pi)-> 0.3989423
21 | # 1/sqrt(2) -> 0.70710678
22 | # sqrt(2/pi) -> 0.79788456
23 | # this function is tanh approximation of gelu
24 | # actual gelu is:
25 | # x * 0.5 * (1.0 + torch.erf(x * 0.70710678))
26 |
27 |
28 | @torch.jit.script
29 | def bias_gelu(bias, y):
30 | x = bias + y
31 | return x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))
32 |
33 |
34 | # gradient of tanh approximation of gelu
35 | # gradient of actual gelu is:
36 | # 0.5 * (1. + torch.erf(x * 0.70710678)) + 0.3989423 * x * torch.exp(-0.5 * x * x)
37 | @torch.jit.script
38 | def bias_gelu_back(g, bias, y):
39 | x = bias + y
40 | tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
41 | # sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243
42 | ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out)
43 | return ff * g
44 |
45 |
46 | class GeLUFunction(torch.autograd.Function):
47 | @staticmethod
48 | # bias is an optional argument
49 | def forward(ctx, input, bias):
50 | ctx.save_for_backward(input, bias)
51 | return bias_gelu(bias, input)
52 |
53 | @staticmethod
54 | def backward(ctx, grad_output):
55 | input, bias = ctx.saved_tensors
56 | tmp = bias_gelu_back(grad_output, bias, input)
57 | return tmp, tmp
58 |
59 |
60 | bias_gelu_impl = GeLUFunction.apply
61 |
--------------------------------------------------------------------------------
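
The tanh form used in `bias_gelu` approximates the exact erf-based GELU quoted in the comment, with a small gap (well under 1e-2 for typical activations). A quick numerical check, with the bias held at zero so the comparison is against plain GELU:

```python
import torch
from megatron.model.fused_bias_gelu import bias_gelu_impl

x = torch.randn(8, 16, requires_grad=True)
bias = torch.zeros(16)   # bias fixed; only the input gradient is exercised here

out = bias_gelu_impl(x, bias)
exact = x * 0.5 * (1.0 + torch.erf(x * 0.70710678))
assert torch.allclose(out, exact, atol=1e-2)   # tanh approximation tracks the exact GELU

out.sum().backward()     # runs the hand-written bias_gelu_back for the input gradient
assert x.grad is not None and x.grad.shape == x.shape
```
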
/megatron/model/fused_bias_sqrelu.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright (c) 2023 ADEPT AI LABS INC.
3 |
4 | import torch
5 |
6 |
7 | ###### BIAS SQRELU FUSION/ NO AUTOGRAD ################
8 |
9 |
10 | @torch.jit.script
11 | def bias_sqrelu(bias, y):
12 | x = bias + y
13 | relud_x = torch.relu(x)
14 | return relud_x * relud_x
15 |
16 |
17 | @torch.jit.script
18 | def bias_sqrelu_back(g, bias, y):
19 | x = bias + y
20 | return g * 2 * torch.relu(x)
21 |
22 |
23 | class SqReLUFunction(torch.autograd.Function):
24 | @staticmethod
25 | # bias is an optional argument
26 | def forward(ctx, input, bias):
27 | ctx.save_for_backward(input, bias)
28 | return bias_sqrelu(bias, input)
29 |
30 | @staticmethod
31 | def backward(ctx, grad_output):
32 | input, bias = ctx.saved_tensors
33 | tmp = bias_sqrelu_back(grad_output, bias, input)
34 | return tmp, tmp
35 |
36 |
37 | bias_sqrelu_impl = SqReLUFunction.apply
38 |
--------------------------------------------------------------------------------
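
`bias_sqrelu_back` encodes the derivative d/dx relu(x)^2 = 2 * relu(x); a tiny autograd check of that identity, again with a zero bias:

```python
import torch
from megatron.model.fused_bias_sqrelu import bias_sqrelu_impl

x = torch.randn(4, 8, requires_grad=True)
bias = torch.zeros(8)

out = bias_sqrelu_impl(x, bias)
out.sum().backward()

# Gradient of relu(x)**2 w.r.t. x is 2 * relu(x), matching bias_sqrelu_back.
assert torch.allclose(x.grad, 2 * torch.relu(x.detach()))
```
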
/megatron/model/fused_layer_norm.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """This code is copied from NVIDIA apex:
17 | https://github.com/NVIDIA/apex
18 | with some changes. """
19 |
20 | import numbers
21 | import torch
22 | from torch.nn.parameter import Parameter
23 | from torch.nn import init
24 | import importlib
25 |
26 | from megatron.mpu import make_viewless_tensor
27 | from megatron.mpu import set_tensor_model_parallel_attributes
28 |
29 | try:
30 | from apex.contrib.layer_norm.layer_norm import FastLayerNormFN
31 |
32 | HAVE_PERSIST_LAYER_NORM = True
33 | except:
34 | HAVE_PERSIST_LAYER_NORM = False
35 |
36 | global fused_mix_prec_layer_norm_cuda
37 | fused_mix_prec_layer_norm_cuda = None
38 |
39 |
40 | class FusedLayerNormAffineFunction(torch.autograd.Function):
41 | @staticmethod
42 | def forward(ctx, input, weight, bias, normalized_shape, eps):
43 |
44 | ctx.normalized_shape = normalized_shape
45 | ctx.eps = eps
46 | input_ = input.contiguous()
47 | weight_ = weight.contiguous()
48 | bias_ = bias.contiguous()
49 | output, mean, invvar = fused_mix_prec_layer_norm_cuda.forward_affine(
50 | input_, ctx.normalized_shape, weight_, bias_, ctx.eps
51 | )
52 | ctx.save_for_backward(input_, weight_, bias_, mean, invvar)
53 |
54 | return output
55 |
56 | @staticmethod
57 | def backward(ctx, grad_output):
58 |
59 | input_, weight_, bias_, mean, invvar = ctx.saved_tensors
60 | grad_input = grad_weight = grad_bias = None
61 | grad_input, grad_weight, grad_bias = fused_mix_prec_layer_norm_cuda.backward_affine(
62 | grad_output.contiguous(), mean, invvar, input_, ctx.normalized_shape, weight_, bias_, ctx.eps
63 | )
64 |
65 | return grad_input, grad_weight, grad_bias, None, None
66 |
67 |
68 | class MixedFusedLayerNorm(torch.nn.Module):
69 | def __init__(
70 | self,
71 | normalized_shape,
72 | eps=1e-5,
73 | no_persist_layer_norm=True,
74 | sequence_parallel=False,
75 | is_tensor_parallel_unique=False,
76 | reduce_tensor_parallel_grads=False,
77 | ):
78 | super(MixedFusedLayerNorm, self).__init__()
79 |
80 | global fused_mix_prec_layer_norm_cuda
81 | fused_mix_prec_layer_norm_cuda = importlib.import_module(
82 | "megatron_fused_kernels.fused_mix_prec_layer_norm_cuda"
83 | )
84 |
85 | # List of hiddens sizes supported in the persistent layer norm kernel
86 | # If the hidden size is not supported, fall back to the non-persistent
87 | # kernel.
88 | persist_ln_hidden_sizes = [
89 | 1024,
90 | 1536,
91 | 2048,
92 | 2304,
93 | 3072,
94 | 3840,
95 | 4096,
96 | 5120,
97 | 6144,
98 | 8192,
99 | 10240,
100 | 12288,
101 | 12800,
102 | 15360,
103 | 16384,
104 | 18432,
105 | 20480,
106 | 24576,
107 | 25600,
108 | 30720,
109 | 32768,
110 | 40960,
111 | 49152,
112 | 65536,
113 | ]
114 | if normalized_shape not in persist_ln_hidden_sizes or not HAVE_PERSIST_LAYER_NORM:
115 | no_persist_layer_norm = True
116 |
117 | if isinstance(normalized_shape, numbers.Integral):
118 | normalized_shape = (normalized_shape,)
119 | self.normalized_shape = torch.Size(normalized_shape)
120 | self.eps = eps
121 |
122 | self.weight = Parameter(torch.Tensor(*normalized_shape))
123 | set_tensor_model_parallel_attributes(
124 | self.weight,
125 | is_tensor_parallel_unique=is_tensor_parallel_unique,
126 | reduce_tensor_parallel_grads=reduce_tensor_parallel_grads,
127 | )
128 |
129 | self.bias = Parameter(torch.Tensor(*normalized_shape))
130 | set_tensor_model_parallel_attributes(
131 | self.bias,
132 | is_tensor_parallel_unique=is_tensor_parallel_unique,
133 | reduce_tensor_parallel_grads=reduce_tensor_parallel_grads,
134 | )
135 |
136 | self.reset_parameters()
137 | self.no_persist_layer_norm = no_persist_layer_norm
138 | self.sequence_parallel = sequence_parallel
139 |
140 | # set sequence parallelism flag on weight and bias parameters
141 | setattr(self.weight, "sequence_parallel", self.sequence_parallel)
142 | setattr(self.bias, "sequence_parallel", self.sequence_parallel)
143 |
144 | def reset_parameters(self):
145 |
146 | init.ones_(self.weight)
147 | init.zeros_(self.bias)
148 |
149 | def forward(self, input):
150 | # Doing this shape check here is important because the two kernels below don't handle shapes the same way,
151 | # so we need to check that we're passing them what we expect.
152 | assert len(self.normalized_shape) == 1
153 | if input.shape[-1] != self.normalized_shape[0]:
154 | raise ValueError(f"Expected input shape ending in {self.normalized_shape=}, but got {input.shape=}")
155 |
156 | if self.no_persist_layer_norm:
157 | return FusedLayerNormAffineFunction.apply(input, self.weight, self.bias, self.normalized_shape, self.eps)
158 | else:
159 | output = FastLayerNormFN.apply(input, self.weight, self.bias, self.eps)
160 |
161 | # Apex's fast layer norm function outputs a 'view' tensor (i.e., has
162 | # a populated '_base' field). This will result in schedule.py's
163 | # deallocate_output_tensor() throwing an error, so a viewless tensor is
164 | # created to prevent this.
165 | output = make_viewless_tensor(inp=output, requires_grad=input.requires_grad, keep_graph=True)
166 |
167 | return output
168 |
--------------------------------------------------------------------------------
/megatron/model/fused_softmax.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright (c) 2023 ADEPT AI LABS INC.
3 | # This file is based on code by the authors denoted below and has been modified from its original version.
4 | #
5 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
6 | #
7 | # Licensed under the Apache License, Version 2.0 (the "License");
8 | # you may not use this file except in compliance with the License.
9 | # You may obtain a copy of the License at
10 | #
11 | # http://www.apache.org/licenses/LICENSE-2.0
12 | #
13 | # Unless required by applicable law or agreed to in writing, software
14 | # distributed under the License is distributed on an "AS IS" BASIS,
15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16 | # See the License for the specific language governing permissions and
17 | # limitations under the License.
18 |
19 |
20 | import torch
21 | import torch.nn as nn
22 | from megatron.model.enums import AttnMaskType
23 |
24 |
25 | class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function):
26 | """
27 | Fused operation which performs following three operations in sequence
28 | 1. Scale the tensor.
29 | 2. Apply upper triangular mask (typically used in gpt models).
30 | 3. Perform softmax.
31 | """
32 |
33 | @staticmethod
34 | def forward(ctx, inputs, scale):
35 | from megatron_fused_kernels import scaled_upper_triang_masked_softmax_cuda
36 |
37 | scale_t = torch.tensor([scale])
38 | softmax_results = scaled_upper_triang_masked_softmax_cuda.forward(
39 | inputs, scale_t[0]
40 | )
41 |
42 | ctx.save_for_backward(softmax_results, scale_t)
43 | return softmax_results
44 |
45 | @staticmethod
46 | def backward(ctx, output_grads):
47 | from megatron_fused_kernels import scaled_upper_triang_masked_softmax_cuda
48 |
49 | softmax_results, scale_t = ctx.saved_tensors
50 | input_grads = scaled_upper_triang_masked_softmax_cuda.backward(
51 | output_grads, softmax_results, scale_t[0]
52 | )
53 |
54 | return input_grads, None
55 |
56 |
57 | class ScaledSoftmax(torch.autograd.Function):
58 | """
59 | Fused operation which performs following two operations in sequence
60 | 1. Scale the tensor.
61 | 2. Perform softmax.
62 | """
63 |
64 | @staticmethod
65 | def forward(ctx, inputs, scale):
66 | from megatron_fused_kernels import scaled_softmax_cuda
67 |
68 | scale_t = torch.tensor([scale])
69 |
70 | softmax_results = scaled_softmax_cuda.forward(inputs, scale_t[0])
71 | ctx.save_for_backward(softmax_results, scale_t)
72 | return softmax_results
73 |
74 | @staticmethod
75 | def backward(ctx, output_grads):
76 | from megatron_fused_kernels import scaled_softmax_cuda
77 |
78 | softmax_results, scale_t = ctx.saved_tensors
79 |
80 | input_grads = scaled_softmax_cuda.backward(
81 | output_grads, softmax_results, scale_t[0]
82 | )
83 | return input_grads, None, None
84 |
85 |
86 | class FusedScaleSoftmax(nn.Module):
87 | """
88 | fused operation: scaling + mask + softmax
89 |
90 | Arguments:
91 |         input_in_fp16: flag to indicate if input is in fp16 data format.
92 |         input_in_bf16: flag to indicate if input is in bf16 data format.
93 |         attn_mask_type: attention mask type (pad or causal)
94 |         scaled_masked_softmax_fusion: flag to indicate the user wants to use softmax fusion
95 |         softmax_in_fp32: if true, softmax is performed at fp32 precision.
96 | scale: scaling factor used in input tensor scaling.
97 | """
98 |
99 | def __init__(
100 | self,
101 | input_in_fp16,
102 | input_in_bf16,
103 | attn_mask_type,
104 | scaled_masked_softmax_fusion,
105 | softmax_in_fp32,
106 | scale,
107 | ):
108 | super(FusedScaleSoftmax, self).__init__()
109 | self.input_in_fp16 = input_in_fp16
110 | self.input_in_bf16 = input_in_bf16
111 | assert not (
112 | self.input_in_fp16 and self.input_in_bf16
113 | ), "both fp16 and bf16 flags cannot be active at the same time."
114 | self.input_in_float16 = self.input_in_fp16 or self.input_in_bf16
115 | self.attn_mask_type = attn_mask_type
116 | self.scaled_masked_softmax_fusion = scaled_masked_softmax_fusion
117 | self.softmax_in_fp32 = softmax_in_fp32
118 | self.scale = scale
119 |
120 | assert (
121 | self.scale is None or softmax_in_fp32
122 | ), "softmax should be in fp32 when scaled"
123 |
124 | def forward(self, input):
125 | # [b, np, sq, sk]
126 | assert input.dim() == 4
127 |
128 | if self.is_kernel_available(*input.size()):
129 | return self.forward_fused_softmax(input)
130 | else:
131 | return self.forward_torch_softmax(input)
132 |
133 | def is_kernel_available(self, b, np, sq, sk):
134 | attn_batches = b * np
135 | if (
136 |             self.scaled_masked_softmax_fusion  # user wants to fuse
137 |             and self.input_in_float16  # input must be fp16 or bf16
138 |             and 16 < sk <= 8192  # sk must be in (16, 8192]
139 |             and sq % 4 == 0  # sq must be divisible by 4
140 |             and sk % 4 == 0  # sk must be divisible by 4
141 |             and attn_batches % 4 == 0  # np * b must be divisible by 4
142 | ):
143 | if 0 <= sk <= 8192:
144 | batch_per_block = self.get_batch_per_block(sq, sk, b, np)
145 |
146 | if self.attn_mask_type == AttnMaskType.causal:
147 | if attn_batches % batch_per_block == 0:
148 | return True
149 | else:
150 | if sq % batch_per_block == 0 and sk % batch_per_block == 0:
151 | return True
152 | return False
153 |
154 | def forward_fused_softmax(self, input):
155 | b, np, sq, sk = input.size()
156 | scale = self.scale if self.scale is not None else 1.0
157 |
158 | if self.attn_mask_type == AttnMaskType.causal:
159 | assert sq == sk, "causal mask is only for self attention"
160 |
161 | # input is 3D tensor (attn_batches, sq, sk)
162 | input = input.view(-1, sq, sk)
163 | probs = ScaledUpperTriangMaskedSoftmax.apply(input, scale)
164 | return probs.view(b, np, sq, sk)
165 | else:
166 | # input is 4D tensor (b, np, sq, sk)
167 | return ScaledSoftmax.apply(input, scale)
168 |
169 | def forward_torch_softmax(self, input):
170 | b, np, sq, sk = input.size()
171 | if self.input_in_float16 and self.softmax_in_fp32:
172 | input = input.float()
173 |
174 | if self.scale is not None:
175 | input = input * self.scale
176 |
177 | if self.attn_mask_type == AttnMaskType.causal:
178 | assert sq == sk, "causal mask is only for self attention"
179 | attention_mask = torch.tril(
180 | torch.ones((1, input.shape[-1], input.shape[-1]), device=input.device)
181 | ).view(1, 1, input.shape[-1], input.shape[-1])
182 | # invert mask so that 0 -> 1 and 1 -> 0
183 | attention_mask = attention_mask < 0.5
184 |
185 | # assign very low logit to masked entries
186 | input.masked_fill_(attention_mask, -10000.0)
187 |
188 | probs = torch.nn.Softmax(dim=-1)(input)
189 |
190 | if self.input_in_float16 and self.softmax_in_fp32:
191 | if self.input_in_fp16:
192 | probs = probs.half()
193 | else:
194 | probs = probs.bfloat16()
195 |
196 | return probs
197 |
198 | @staticmethod
199 | def get_batch_per_block(sq, sk, b, np):
200 | from megatron_fused_kernels import scaled_masked_softmax_cuda
201 |
202 | return scaled_masked_softmax_cuda.get_batch_per_block(sq, sk, b, np)
203 |
--------------------------------------------------------------------------------
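
When fusion is disabled (or the size constraints in `is_kernel_available` are not met) the module falls back to `forward_torch_softmax`, which runs anywhere PyTorch does. A minimal CPU example of that fallback path:

```python
import torch
from megatron.model.enums import AttnMaskType
from megatron.model.fused_softmax import FusedScaleSoftmax

softmax = FusedScaleSoftmax(
    input_in_fp16=False,
    input_in_bf16=False,
    attn_mask_type=AttnMaskType.causal,
    scaled_masked_softmax_fusion=False,   # force the pure-PyTorch path
    softmax_in_fp32=True,
    scale=None,
)

scores = torch.randn(2, 4, 8, 8)          # [b, np, sq, sk]
probs = softmax(scores)
assert torch.allclose(probs.sum(dim=-1), torch.ones(2, 4, 8))
```
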
/megatron/model/gpt_model.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | #
3 | # Copyright (c) 2023 ADEPT AI LABS INC.
4 | # This file is based on code by the authors denoted below and has been modified from its original version.
5 | #
6 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
7 | #
8 | # Licensed under the Apache License, Version 2.0 (the "License");
9 | # you may not use this file except in compliance with the License.
10 | # You may obtain a copy of the License at
11 | #
12 | # http://www.apache.org/licenses/LICENSE-2.0
13 | #
14 | # Unless required by applicable law or agreed to in writing, software
15 | # distributed under the License is distributed on an "AS IS" BASIS,
16 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17 | # See the License for the specific language governing permissions and
18 | # limitations under the License.
19 |
20 | """GPT-2 model."""
21 |
22 | import torch
23 |
24 | from megatron import get_args
25 | from megatron import mpu
26 | from .module import MegatronModule
27 |
28 | from .enums import AttnMaskType
29 | from .language_model import parallel_lm_logits
30 | from .language_model import get_language_model
31 | from .utils import init_method_normal
32 | from .utils import scaled_init_method_normal
33 |
34 | try:
35 | from megatron.mpu.cross_entropy_parallel import CrossEntropyLossParallel
36 | from einops import rearrange
37 | except ImportError:
38 | CrossEntropyLossParallel = None
39 | rearrange = None
40 |
41 |
42 | def post_language_model_processing(
43 | lm_output, labels, logit_weights, parallel_output, fp16_lm_cross_entropy, use_fast_cross_entropy
44 | ):
45 |
46 | # Output. Format [s b h]
47 | output = parallel_lm_logits(lm_output, logit_weights, parallel_output)
48 | if labels is None:
49 | # [s b h] => [b s h]
50 | return output.transpose(0, 1).contiguous()
51 | else:
52 | # [b s] => [s b]
53 | labels = labels.transpose(0, 1).contiguous()
54 | if use_fast_cross_entropy:
55 | loss_fn = CrossEntropyLossParallel(reduction="none", inplace_backward=True)
56 | loss = rearrange(
57 | loss_fn(rearrange(output, "s b ... -> (s b) ..."), rearrange(labels, "s b ... -> (s b) ...")),
58 | "(s b) ... -> s b ...",
59 | s=output.shape[0],
60 | )
61 | elif fp16_lm_cross_entropy:
62 | assert output.dtype == torch.half
63 | loss = mpu.vocab_parallel_cross_entropy(output, labels)
64 | else:
65 | loss = mpu.vocab_parallel_cross_entropy(output.float(), labels)
66 |
67 | # [s b] => [b, s]
68 | loss = loss.transpose(0, 1).contiguous()
69 |
70 | argmax = mpu.layers.argmax_parallel(output)
71 | accuracy = labels == argmax
72 |
73 | accuracy = accuracy.transpose(0, 1).contiguous()
74 | argmax = argmax.transpose(0, 1).contiguous()
75 |
76 | return {"argmax": argmax, "loss": loss, "accuracy": accuracy}
77 |
78 |
79 | class GPTModel(MegatronModule):
80 | """GPT-2 Language model."""
81 |
82 | def __init__(
83 | self,
84 | num_tokentypes=0,
85 | parallel_output=True,
86 | pre_process=True,
87 | post_process=True,
88 | continuous_embed_input_size=None,
89 | ):
90 | args = get_args()
91 | super(GPTModel, self).__init__(share_word_embeddings=not args.untie_embeddings)
92 | self.parallel_output = parallel_output
93 | self.pre_process = pre_process
94 | self.post_process = post_process
95 | self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy
96 | self.share_word_embeddings = not args.untie_embeddings
97 | self.use_fast_cross_entropy = args.use_fast_cross_entropy
98 | if self.use_fast_cross_entropy:
99 | if CrossEntropyLossParallel is None:
100 | raise ImportError("xentropy CUDA extension is not installed")
101 | if rearrange is None:
102 | raise ImportError("einops is not installed")
103 |
104 | self.language_model, self._language_model_key = get_language_model(
105 | num_tokentypes=num_tokentypes,
106 | add_pooler=False,
107 | encoder_attn_mask_type=AttnMaskType.causal,
108 | init_method=init_method_normal(args.init_method_std),
109 | scaled_init_method=scaled_init_method_normal(args.init_method_std, args.num_layers),
110 | pre_process=self.pre_process,
111 | post_process=self.post_process,
112 | continuous_embed_input_size=continuous_embed_input_size,
113 | )
114 |
115 | self.initialize_word_embeddings(init_method_normal)
116 |
117 | def set_input_tensor(self, input_tensor):
118 | """See megatron.model.transformer.set_input_tensor()"""
119 | self.language_model.set_input_tensor(input_tensor)
120 |
121 | def forward(
122 | self,
123 | input_ids,
124 | position_ids,
125 | labels=None,
126 | tokentype_ids=None,
127 | inference_params=None,
128 | lm_logits_mask=None,
129 | ):
130 | lm_output = self.language_model(input_ids, position_ids, inference_params=inference_params)
131 |
132 | if lm_logits_mask is not None:
133 | # lm_logits_mask has shape [s] while lm_output has shape [s, b, h].
134 | lm_output = lm_output[lm_logits_mask, :, :]
135 |
136 | del tokentype_ids
137 | if self.post_process:
138 | return post_language_model_processing(
139 | lm_output,
140 | labels,
141 | self.word_embeddings_weight(),
142 | self.parallel_output,
143 | self.fp16_lm_cross_entropy,
144 | use_fast_cross_entropy=self.use_fast_cross_entropy,
145 | )
146 | else:
147 | return lm_output
148 |
149 | def state_dict_for_save_checkpoint(self, destination=None, prefix="", keep_vars=False):
150 |
151 | state_dict_ = {}
152 | state_dict_[self._language_model_key] = self.language_model.state_dict_for_save_checkpoint(
153 | destination, prefix, keep_vars
154 | )
155 | # Save word_embeddings.
156 | # When it is the last stage of pipelining or if we are not sharing word
157 | # embeddings, then we need to save the variable to the state dict.
158 | if self.post_process and (not self.pre_process or not self.share_word_embeddings):
159 | state_dict_[self._word_embeddings_for_head_key] = self.word_embeddings.state_dict(
160 | destination, prefix, keep_vars
161 | )
162 | return state_dict_
163 |
164 | def load_state_dict(self, state_dict, strict=True):
165 | """Customized load."""
166 |
167 | # Load word_embeddings.
168 | if (self.post_process and not self.pre_process) or not self.share_word_embeddings:
169 | if self._word_embeddings_for_head_key in state_dict:
170 | self.word_embeddings.load_state_dict(state_dict[self._word_embeddings_for_head_key], strict=strict)
171 | if self._language_model_key in state_dict:
172 | state_dict = state_dict[self._language_model_key]
173 | self.language_model.load_state_dict(state_dict, strict=strict)
174 |
--------------------------------------------------------------------------------
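
A note on layout: `post_language_model_processing` works sequence-major ([s, b, ...]) internally and transposes labels in and losses out, so callers see batch-major tensors. A toy shape walk-through of that convention with plain tensors (stand-ins for the parallel logits and loss, not the model-parallel ops themselves):

```python
import torch

s, b, vocab = 5, 2, 11
logits_sb = torch.randn(s, b, vocab)           # stand-in for parallel_lm_logits output
labels_bs = torch.randint(vocab, (b, s))       # labels arrive batch-major

labels_sb = labels_bs.transpose(0, 1).contiguous()          # [b, s] -> [s, b]
loss_sb = torch.nn.functional.cross_entropy(
    logits_sb.reshape(s * b, vocab), labels_sb.reshape(s * b), reduction="none"
).view(s, b)
loss_bs = loss_sb.transpose(0, 1).contiguous()              # back to [b, s] for the caller
assert loss_bs.shape == (b, s)
```
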
/megatron/model/positional_embeddings.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2023, ADEPT AI LABS INC. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Positional embeddings modules."""
15 | import torch
16 | import math
17 |
18 |
19 | class RotaryEmbedding(torch.nn.Module):
20 | def __init__(self, dim, base=10000, precision=torch.bfloat16):
21 | super().__init__()
22 | inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
23 | self.register_buffer("inv_freq", inv_freq)
24 | self.seq_len_cached = None
25 | self.cos_cached = None
26 | self.sin_cached = None
27 | self.precision = precision
28 |
29 | def forward(self, x, seq_dim=1, seq_len=None):
30 | if seq_len is None:
31 | seq_len = x.shape[seq_dim]
32 | if seq_len != self.seq_len_cached:
33 | self.seq_len_cached = seq_len
34 | t = torch.arange(seq_len, device=x.device).type_as(self.inv_freq)
35 | freqs = torch.einsum("i,j->ij", t, self.inv_freq)
36 | emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
37 | if self.precision == torch.bfloat16:
38 | emb = emb.float()
39 | self.cos_cached = emb.cos()[:, None, None, :]
40 | self.sin_cached = emb.sin()[:, None, None, :]
41 | if self.precision == torch.bfloat16:
42 | self.cos_cached = self.cos_cached.bfloat16()
43 | self.sin_cached = self.sin_cached.bfloat16()
44 | return self.cos_cached, self.sin_cached
45 |
46 |
47 | def rotate_half(x):
48 | x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
49 | return torch.cat((-x2, x1), dim=x1.ndim - 1) # dim=-1 triggers a bug in earlier torch versions
50 |
51 |
52 | @torch.jit.script
53 | def apply_rotary_pos_emb(q, k, cos, sin, offset: int = 0):
54 | cos, sin = (
55 | cos[offset : q.shape[0] + offset, ...],
56 | sin[offset : q.shape[0] + offset, ...],
57 | )
58 | return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)
59 |
60 |
61 | def apply_rotary_pos_emb_torch(q, k, cos, sin, offset: int = 0): # jitting fails with bf16
62 | cos, sin = (
63 | cos[offset : q.shape[0] + offset, ...],
64 | sin[offset : q.shape[0] + offset, ...],
65 | )
66 | return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)
67 |
--------------------------------------------------------------------------------
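
A small usage sketch of the rotary embedding helpers, using the sequence-major [sq, b, np, hn] layout the attention code feeds them (float32 here so the non-jitted `apply_rotary_pos_emb_torch` path is exercised):

```python
import torch
from megatron.model.positional_embeddings import RotaryEmbedding, apply_rotary_pos_emb_torch

seq_len, batch, heads, head_dim = 16, 2, 4, 32
q = torch.randn(seq_len, batch, heads, head_dim)
k = torch.randn(seq_len, batch, heads, head_dim)

rotary = RotaryEmbedding(head_dim, precision=torch.float32)
cos, sin = rotary(k, seq_dim=0)            # cached tables of shape [seq_len, 1, 1, head_dim]
q_rot, k_rot = apply_rotary_pos_emb_torch(q, k, cos, sin)

assert q_rot.shape == q.shape and k_rot.shape == k.shape
# Each position/head is rotated pairwise, so norms over the head dimension are preserved.
assert torch.allclose(q_rot.norm(dim=-1), q.norm(dim=-1), atol=1e-4)
```
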
/megatron/model/utils.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright (c) 2023 ADEPT AI LABS INC.
3 | # This file is based on code by the authors denoted below and has been modified from its original version.
4 | #
5 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
6 | #
7 | # Licensed under the Apache License, Version 2.0 (the "License");
8 | # you may not use this file except in compliance with the License.
9 | # You may obtain a copy of the License at
10 | #
11 | # http://www.apache.org/licenses/LICENSE-2.0
12 | #
13 | # Unless required by applicable law or agreed to in writing, software
14 | # distributed under the License is distributed on an "AS IS" BASIS,
15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16 | # See the License for the specific language governing permissions and
17 | # limitations under the License.
18 |
19 | """Utilities for models."""
20 |
21 | import math
22 | import socket
23 |
24 | import torch
25 | from numpy.random import default_rng, SeedSequence
26 | import multiprocessing
27 | import concurrent.futures
28 | import numpy as np
29 | from typing import Sequence, List
30 | import functools
31 |
32 | from megatron import get_args
33 |
34 | from megatron import mpu
35 |
36 |
37 | def init_method_normal(sigma):
38 | """Init method based on N(0, sigma)."""
39 |
40 | def fill(rng: np.random.Generator, shape: Sequence[int]):
41 | n_threads = 16
42 | n_elements = np.prod(shape)
43 | values = np.empty(n_elements)
44 | assert n_elements % n_threads == 0, "Number of elements must be a multiple of number of threads!"
45 | step = np.ceil(n_elements / n_threads).astype(np.int_)
46 |
47 | # TODO(erich): hacky way to get new generator seeds, we should use spawn as soon as we have
48 | # numpy 1.25 and not this!!!
49 | seeds = rng.integers(0, 1 << 63, n_threads)
50 |
51 | _random_generators = [np.random.default_rng(s) for s in seeds]
52 |
53 | executor = concurrent.futures.ThreadPoolExecutor(n_threads)
54 |
55 | def _fill(generator, out, first, last):
56 | out[first:last] = generator.normal(loc=0.0, scale=sigma, size=step)
57 |
58 | futures = {}
59 | for i in range(n_threads):
60 | args = (_fill, _random_generators[i], values, i * step, (i + 1) * step)
61 | futures[executor.submit(*args)] = i
62 | concurrent.futures.wait(futures)
63 | return np.reshape(values, shape)
64 |
65 | return fill
66 |
67 |
68 | # TODO(erich): combine these two functions to reduce code duplication
69 | def scaled_init_method_normal(sigma, num_layers):
70 | """Init method based on N(0, sigma/sqrt(2*num_layers)."""
71 | std = sigma / math.sqrt(2.0 * num_layers)
72 |
73 | def fill(rng: np.random.Generator, shape: Sequence[int]):
74 | n_threads = 16
75 | n_elements = np.prod(shape)
76 | values = np.empty(n_elements)
77 | assert n_elements % n_threads == 0, "Number of elements must be a multiple of number of threads!"
78 | step = np.ceil(n_elements / n_threads).astype(np.int_)
79 |
80 | # TODO(erich): hacky way to get new generator seeds, we should use spawn as soon as we have
81 | # numpy 1.25 and not this!!!
82 | seeds = rng.integers(0, 1 << 63, n_threads)
83 |
84 | _random_generators = [np.random.default_rng(s) for s in seeds]
85 | executor = concurrent.futures.ThreadPoolExecutor(n_threads)
86 |
87 | def _fill(generator, out, first, last):
88 | out[first:last] = generator.normal(loc=0.0, scale=std, size=step)
89 |
90 | futures = {}
91 | for i in range(n_threads):
92 | args = (_fill, _random_generators[i], values, i * step, (i + 1) * step)
93 | futures[executor.submit(*args)] = i
94 | concurrent.futures.wait(futures)
95 | return np.reshape(values, shape)
96 |
97 | return fill
98 |
99 |
100 | def check_shapes(tensor, expected_shapes=None, expected_ndim=None, expected_dtype=None):
101 | """Check that the passed-in `tensor` has the expected shape and ndim. Should an expected dim be None, it is not checked.
102 |
103 | Args:
104 | tensor: The tensor to shape-check.
105 | expected_shapes: The prefix of the shapes to expect in `tensor`.
106 | expected_ndim: The expected number of dimensions in `tensor`.
107 | expected_dtype: Expected dtype.
108 | """
109 | assert tensor is not None, "tensor is None"
110 |
111 | if expected_shapes is None and expected_ndim is None and expected_dtype is None:
112 | return
113 |
114 | if expected_dtype is not None:
115 | assert tensor.dtype == expected_dtype
116 |
117 | if expected_ndim is not None:
118 | assert tensor.ndim == expected_ndim, f"Unexpected ndims detected. Got: {tensor.ndim} expected: {expected_ndim}"
119 |
120 | expected_shapes = list(expected_shapes or [])
121 |
122 | if expected_ndim is not None:
123 | assert (
124 | len(expected_shapes) <= tensor.ndim
125 | ), f"Asking to check too many shapes. Got: {expected_shapes} expected: <= {tensor.ndim}"
126 |
127 | t_shapes = list(tensor.shape[: len(expected_shapes)])
128 | for i, e in enumerate(expected_shapes):
129 | if e is None:
130 | t_shapes[i] = None
131 | assert t_shapes == expected_shapes, f"Mismatch detected. Got: {t_shapes} expected: {expected_shapes}"
132 |
133 |
134 | def get_linear_layer(rows, columns, init_method):
135 | """Simple linear layer with weight initialization."""
136 | layer = torch.nn.Linear(rows, columns)
137 | if get_args().perform_initialization:
138 | init_method(layer.weight)
139 | with torch.no_grad():
140 | layer.bias.zero_()
141 | return layer
142 |
143 |
144 | @torch.jit.script
145 | def gelu_impl(x):
146 | """OpenAI's gelu implementation."""
147 | return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * x * (1.0 + 0.044715 * x * x)))
148 |
149 |
150 | def openai_gelu(x):
151 | return gelu_impl(x)
152 |
153 |
154 | # This is actually the Python equivalent of torch.nn.functional.gelu(), with type hints added for the ONNX exporter
155 | @torch.jit.script
156 | def erf_gelu(x):
157 | return x * 0.5 * (torch.erf(x / 1.41421).to(dtype=x.dtype) + torch.ones_like(x).to(dtype=x.dtype))
158 |
159 |
160 | def print_named_parameters(model):
161 | """Print a summary of the parameters in a model."""
162 | prefix = ""
163 | if torch.distributed.is_initialized():
164 | # Print on only the first data parallel rank, but on all tensor/pipeline parallel ranks.
165 | should_print = mpu.get_data_parallel_rank() == 0
166 | if mpu.get_tensor_model_parallel_world_size() > 1:
167 | prefix = f"tensor-rank: {mpu.get_tensor_model_parallel_rank()} | "
168 | if mpu.get_pipeline_model_parallel_world_size() > 1:
169 | prefix = f"pipeline-rank: {mpu.get_pipeline_model_parallel_rank()} | "
170 | else:
171 | should_print = get_args().rank == 0
172 |
173 | if not should_print:
174 | return
175 |
176 | print(f"{prefix} > {type(model).__name__} parameters: ", flush=True)
177 | for name, param in model.named_parameters():
178 | if mpu.param_is_tensor_parallel_unique(param):
179 | print(f"{prefix}{name=}, {param.shape=}, norm={torch.norm(param.data.float()).item()}", flush=True)
180 |
181 |
182 | def sync_data_parallel_replicated_parameters(a, b):
183 | pass
184 |
185 |
186 | def sync_tensor_parallel_replicated_parameters(a, b):
187 | pass
188 |
189 |
190 | def _sync_replicated_parameters(a, b, c, d, e):
191 |
192 | def _sync(p):
193 | pass
194 |
195 |
196 | class ReplicationMismatchError(Exception):
197 | pass
198 |
199 |
200 | class TensorParallelReplicationMismatchError(ReplicationMismatchError):
201 | pass
202 |
203 |
204 | class DataParallelReplicationMismatchError(ReplicationMismatchError):
205 | pass
206 |
207 |
208 | def validate_replicated_parameters(m, o):
209 |
210 | def collect_and_validate(a, b, c, d, prefix):
211 | pass
212 |
213 |
214 | def _validate_replicated_parameters(
215 | *, a, b, c, d, prefix=""
216 | ):
217 | pass
218 |
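A brief usage sketch under the assumptions already stated in this file (the element count must be divisible by the 16 worker threads); the concrete sigma and shapes are illustrative only.

    fill = init_method_normal(sigma=0.02)
    rng = np.random.default_rng(1234)
    weight = fill(rng, (1024, 4096))   # threaded N(0, 0.02) samples reshaped to (1024, 4096)

    # check_shapes validates only the leading dims it is given; None entries are skipped.
    check_shapes(torch.zeros(8, 16, 64), expected_shapes=[8, None], expected_ndim=3, expected_dtype=torch.float32)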
--------------------------------------------------------------------------------
/megatron/mpu/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright (c) 2023 ADEPT AI LABS INC.
3 | # This file is based on code by the authors denoted below and has been modified from its original version.
4 | #
5 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
6 | #
7 | # Licensed under the Apache License, Version 2.0 (the "License");
8 | # you may not use this file except in compliance with the License.
9 | # You may obtain a copy of the License at
10 | #
11 | # http://www.apache.org/licenses/LICENSE-2.0
12 | #
13 | # Unless required by applicable law or agreed to in writing, software
14 | # distributed under the License is distributed on an "AS IS" BASIS,
15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16 | # See the License for the specific language governing permissions and
17 | # limitations under the License.
18 |
19 | """Model parallel utility interface."""
20 |
21 | from .cross_entropy import vocab_parallel_cross_entropy
22 |
23 | from .initialize import is_unitialized
24 | from .initialize import destroy_model_parallel
25 | from .initialize import force_communicator_creation
26 | from .initialize import get_data_parallel_group
27 | from .initialize import get_data_parallel_rank
28 | from .initialize import get_data_parallel_world_size
29 | from .initialize import get_embedding_group
30 | from .initialize import get_position_embedding_group
31 | from .initialize import get_model_parallel_group
32 | from .initialize import get_tensor_model_parallel_group
33 | from .initialize import get_pipeline_model_parallel_group
34 | from .initialize import get_tensor_model_parallel_rank, set_tensor_model_parallel_rank
35 | from .initialize import get_pipeline_model_parallel_rank, set_pipeline_model_parallel_rank
36 | from .initialize import is_pipeline_first_stage, is_pipeline_last_stage
37 | from .initialize import is_rank_in_embedding_group
38 | from .initialize import is_rank_in_position_embedding_group
39 | from .initialize import is_pipeline_stage_before_split, is_pipeline_stage_after_split
40 | from .initialize import is_pipeline_stage_at_split
41 | from .initialize import get_num_layers
42 | from .initialize import get_tensor_model_parallel_src_rank
43 | from .initialize import get_data_parallel_src_rank
44 | from .initialize import get_pipeline_model_parallel_first_rank
45 | from .initialize import get_pipeline_model_parallel_last_rank
46 | from .initialize import get_pipeline_model_parallel_next_rank
47 | from .initialize import get_pipeline_model_parallel_prev_rank
48 | from .initialize import get_tensor_model_parallel_world_size, set_tensor_model_parallel_world_size
49 | from .initialize import get_pipeline_model_parallel_world_size, set_pipeline_model_parallel_world_size
50 | from .initialize import get_virtual_pipeline_model_parallel_rank, set_virtual_pipeline_model_parallel_rank
51 | from .initialize import initialize_model_parallel
52 | from .initialize import model_parallel_is_initialized
53 |
54 | from .layers import LinearWithGradAccumulationAndAsyncCommunication
55 | from .layers import ColumnParallelLinear
56 | from .layers import ParallelLinear
57 | from .layers import RowParallelLinear
58 | from .layers import VocabParallelEmbedding
59 | from .layers import (
60 | set_tensor_model_parallel_attributes,
61 | set_defaults_if_not_set_tensor_model_parallel_attributes,
62 | copy_tensor_model_parallel_attributes,
63 | param_is_tensor_parallel_unique,
64 | param_is_tensor_parallel_replicated,
65 | )
66 |
67 | from .mappings import reduce_backward_from_tensor_model_parallel_region
68 | from .mappings import reduce_forward_from_tensor_model_parallel_region
69 | from .mappings import scatter_to_tensor_model_parallel_region
70 | from .mappings import gather_from_tensor_model_parallel_region
71 | from .mappings import scatter_to_sequence_parallel_region
72 | from .mappings import gather_from_sequence_parallel_region
73 | from .mappings import reduce_scatter_to_sequence_parallel_region
74 |
75 | from .random import checkpoint
76 | from .random import get_cuda_rng_tracker
77 | from .random import model_parallel_cuda_manual_seed
78 | from .random import gather_split_1d_tensor
79 | from .random import split_tensor_into_1d_equal_chunks
80 | from .random import make_viewless_tensor
81 | from .random import assert_viewless_tensor
82 | from .random import safely_set_viewless_tensor_data
83 |
84 | from .utils import divide
85 | from .utils import split_tensor_along_last_dim
86 |
--------------------------------------------------------------------------------
/megatron/mpu/communication.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright (c) 2023 ADEPT AI LABS INC.
3 | # This file is based on code by the authors denoted below and has been modified from its original version.
4 | #
5 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
6 | #
7 | # Licensed under the Apache License, Version 2.0 (the "License");
8 | # you may not use this file except in compliance with the License.
9 | # You may obtain a copy of the License at
10 | #
11 | # http://www.apache.org/licenses/LICENSE-2.0
12 | #
13 | # Unless required by applicable law or agreed to in writing, software
14 | # distributed under the License is distributed on an "AS IS" BASIS,
15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16 | # See the License for the specific language governing permissions and
17 | # limitations under the License.
18 |
19 | """Communications utilities."""
20 |
21 |
22 | from typing import Sequence, Union
23 | import collections
24 | import torch
25 |
26 | from megatron import mpu
27 | from megatron.model.utils import check_shapes
28 |
29 |
30 | # TODO: use functions from megatron/p2p
31 | def recv_from_prev_pipeline_rank_(recv_buffer=None):
32 | """Receive from previous pipeline stage and update the
33 | input buffer inplace."""
34 | if not mpu.is_pipeline_first_stage():
35 | assert recv_buffer is not None
36 | recv_prev_op = torch.distributed.P2POp(
37 | torch.distributed.irecv, recv_buffer, mpu.get_pipeline_model_parallel_prev_rank()
38 | )
39 | reqs = torch.distributed.batch_isend_irecv([recv_prev_op])
40 | for req in reqs:
41 | req.wait()
42 | # To protect against race condition when using batch_isend_irecv().
43 | torch.cuda.synchronize()
44 |
45 |
46 | # TODO: use functions from megatron/p2p
47 | def send_to_next_pipeline_rank(tensor=None):
48 | """Send output to the next pipeline stage."""
49 | if not mpu.is_pipeline_last_stage():
50 | assert tensor is not None
51 | send_next_op = torch.distributed.P2POp(
52 | torch.distributed.isend, tensor, mpu.get_pipeline_model_parallel_next_rank()
53 | )
54 | reqs = torch.distributed.batch_isend_irecv([send_next_op])
55 | for req in reqs:
56 | req.wait()
57 | # To protect against race condition when using batch_isend_irecv().
58 | torch.cuda.synchronize()
59 |
60 |
61 | def _is_cuda(tensor):
62 | """Check if a tensor is not none and is cuda."""
63 | assert tensor is not None
64 | assert tensor.is_cuda
65 |
66 |
67 | def _is_cuda_contiguous(tensor):
68 | """Check if a tensor is not none, is cuda, and is contiguous."""
69 | _is_cuda(tensor)
70 | assert tensor.is_contiguous()
71 |
72 |
73 | def _check_shapes(tensor: torch.Tensor, size: Union[int, Sequence[int]], dtype: torch.dtype):
74 | """Check that the sending and receiver tensors will be compatible."""
75 | # This logic is required because torch.empty will promote a size of int to [int] and there are lots
76 | # of callers that rely on this convention.
77 | shape = size if isinstance(size, collections.abc.Sequence) else [size]
78 |
79 | check_shapes(tensor, expected_shapes=shape, expected_ndim=len(shape), expected_dtype=dtype)
80 |
81 |
82 | def broadcast_from_last_pipeline_stage(size, dtype, tensor=None):
83 | """Broadcast a tensor from last pipeline stage to all ranks."""
84 |
85 | is_last_stage = mpu.is_pipeline_last_stage(ignore_virtual=True)
86 | # If the first stage and last stage are the same, then there is no
87 | # pipeline parallelism and no need to communicate.
88 | if mpu.is_pipeline_first_stage() and is_last_stage:
89 | return tensor
90 |
91 | if is_last_stage:
92 | _is_cuda_contiguous(tensor)
93 | _check_shapes(tensor, size, dtype)
94 | else:
95 | tensor = torch.empty(size, dtype=dtype, device=torch.cuda.current_device())
96 | # Get the group and corresponding source rank.
97 | src = mpu.get_pipeline_model_parallel_last_rank()
98 | group = mpu.get_pipeline_model_parallel_group()
99 | torch.distributed.broadcast(tensor, src, group)
100 |
101 | return tensor
102 |
103 |
104 | def broadcast_from_last_to_first_pipeline_stage(size, dtype, tensor=None):
105 | """Broadcast tensor values from last stage into the first stage."""
106 |
107 | is_last_stage = mpu.is_pipeline_last_stage()
108 | is_first_stage = mpu.is_pipeline_first_stage()
110 | # If the first stage and last stage are the same, then there is no
110 | # pipeline parallelism and no need to communicate.
111 | if is_first_stage and is_last_stage:
112 | return tensor
113 | # Only first and last stage pipeline stages need to be involved.
114 | if is_last_stage or is_first_stage:
115 | if is_last_stage:
116 | _is_cuda_contiguous(tensor)
117 | _check_shapes(tensor, size, dtype)
118 | else:
119 | tensor = torch.empty(size, dtype=dtype, device=torch.cuda.current_device())
120 | src = mpu.get_pipeline_model_parallel_last_rank()
121 | group = mpu.get_embedding_group()
122 | # Broadcast from last stage into the first stage.
123 | torch.distributed.broadcast(tensor, src, group)
124 | else:
125 | tensor = None
126 |
127 | return tensor
128 |
129 |
130 | def copy_from_last_to_first_pipeline_stage(size, dtype, tensor=None):
131 | """Copy tensor values from last stage into the first stage.
132 | Note that the input tensor is updated in place."""
133 |
134 | is_last_stage = mpu.is_pipeline_last_stage()
135 | is_first_stage = mpu.is_pipeline_first_stage()
136 | # If the first stage and last stage are the same, then there is no
137 | # pipeline parallelism and no need to communicate.
138 | if is_first_stage and is_last_stage:
139 | return
140 | # Only first and last stage pipeline stages need to be involved.
141 | if is_last_stage or is_first_stage:
142 | _is_cuda(tensor)
143 | is_contiguous = tensor.is_contiguous()
144 | src = mpu.get_pipeline_model_parallel_last_rank()
145 | group = mpu.get_embedding_group()
146 | if is_contiguous:
147 | tensor_ = tensor
148 | else:
149 | if is_last_stage:
150 | tensor_ = tensor.contiguous()
151 | else:
152 | tensor_ = torch.empty(size, dtype=dtype, device=torch.cuda.current_device())
153 | # Broadcast from last stage into the first stage.
154 | torch.distributed.broadcast(tensor_, src, group)
155 | # Update the first stage tensor
156 | if is_first_stage and not is_contiguous:
157 | tensor[...] = tensor_
158 |
159 |
160 | def broadcast_tensor(size, dtype, tensor=None, rank=0, device=0):
161 | """Given size and type of a tensor on all ranks and the tensor value
162 | only on a specific rank, broadcast from that rank to all other ranks.
163 | Args:
164 | size: size of the tensor
165 | dtype: type of the tensor
166 | tensor: tensor to be broadcasted
167 | rank: primary rank for broadcasting
168 | device: device of the tensor. If not None, torch.cuda.current_device() is used
169 | regardless of the value passed (the default of 0 is only a placeholder).
170 | """
171 | if device is not None:
172 | device = torch.cuda.current_device()
173 | if torch.distributed.get_rank() == rank:
174 | if device is not None:
175 | _is_cuda_contiguous(tensor)
176 | _check_shapes(tensor, size, dtype)
177 | else:
178 | tensor = torch.empty(size, dtype=dtype, device=device)
179 | torch.distributed.broadcast(tensor, rank)
180 | return tensor
181 |
182 |
183 | def broadcast_list(size, dtype, list_values=None, rank=0, device=0):
184 | """Broadcast a list of values with a given type.
185 | Args:
186 | size: size of the list
187 | dtype: dtype of the list
188 | list_values: list of values to be broadcasted
189 | rank: primary rank for broadcasting
190 | device: device of the tensor. If not None, torch.cuda.current_device() is used
191 | regardless of the value passed (the default of 0 is only a placeholder).
192 | """
193 | tensor = None
194 | if device is not None:
195 | device = torch.cuda.current_device()
196 | if torch.distributed.get_rank() == rank:
197 | tensor = torch.tensor(list_values, dtype=dtype, device=device)
198 | return broadcast_tensor(size, dtype, tensor=tensor, rank=rank, device=device)
199 |
200 |
201 | def broadcast_int_list(size, int_list=None, rank=0, device=0):
202 | """Broadcast a list of interger values.
203 | Args:
204 | size: size of the list
205 | int_list: list of values to be broadcasted
206 | rank: primary rank for broadcasting
207 | device: device of the tensor. If not None, torch.cuda.current_device() is used
208 | regardless of the value passed (the default of 0 is only a placeholder).
209 | """
210 | if device is not None:
211 | device = torch.cuda.current_device()
212 | return broadcast_list(size, torch.int64, list_values=int_list, rank=rank, device=device)
213 |
214 |
215 | def broadcast_float_list(size, float_list=None, rank=0, device=0):
216 | """Broadcast a list of float values.
217 | Args:
218 | size: size of the list
219 | float_list: list of values to be broadcasted
220 | rank: primary rank for broadcasting
221 | device: device of the tensor. If not None, torch.cuda.current_device() is used
222 | regardless of the value passed (the default of 0 is only a placeholder).
223 | """
224 | if device is not None:
225 | device = torch.cuda.current_device()
226 | return broadcast_list(size, torch.float32, list_values=float_list, rank=rank, device=device)
227 |
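A small sketch of the broadcast pattern these helpers exist for (it mirrors the use in megatron/text_generation/api.py later in this repository; the variable names are placeholders, and the values are only meaningful on the source rank):

    values = [max_tokens_to_generate, temperature, top_p_sampling]
    values_tensor = broadcast_float_list(len(values), float_list=values, rank=0)
    max_tokens_to_generate = int(values_tensor[0].item())
    temperature = values_tensor[1].item()
    top_p_sampling = values_tensor[2].item()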
--------------------------------------------------------------------------------
/megatron/mpu/cross_entropy.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright (c) 2023 ADEPT AI LABS INC.
3 | # This file is based on code by the authors denoted below and has been modified from its original version.
4 | #
5 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
6 | #
7 | # Licensed under the Apache License, Version 2.0 (the "License");
8 | # you may not use this file except in compliance with the License.
9 | # You may obtain a copy of the License at
10 | #
11 | # http://www.apache.org/licenses/LICENSE-2.0
12 | #
13 | # Unless required by applicable law or agreed to in writing, software
14 | # distributed under the License is distributed on an "AS IS" BASIS,
15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16 | # See the License for the specific language governing permissions and
17 | # limitations under the License.
18 |
19 |
20 | import torch
21 |
22 | from .initialize import (
23 | get_tensor_model_parallel_group,
24 | get_tensor_model_parallel_rank,
25 | get_tensor_model_parallel_world_size,
26 | )
27 | from .utils import VocabUtility
28 |
29 |
30 | class _VocabParallelCrossEntropy(torch.autograd.Function):
31 | @staticmethod
32 | def forward(ctx, vocab_parallel_logits, target):
33 |
34 | # Maximum value along vocab dimension across all GPUs.
35 | logits_max = torch.max(vocab_parallel_logits, dim=-1)[0]
36 | torch.distributed.all_reduce(
37 | logits_max, op=torch.distributed.ReduceOp.MAX, group=get_tensor_model_parallel_group()
38 | )
39 | # Subtract the maximum value.
40 | vocab_parallel_logits.sub_(logits_max.unsqueeze(dim=-1))
41 |
42 | # Get the partition's vocab indices
43 | get_vocab_range = VocabUtility.vocab_range_from_per_partition_vocab_size
44 | partition_vocab_size = vocab_parallel_logits.size()[-1]
45 | rank = get_tensor_model_parallel_rank()
46 | world_size = get_tensor_model_parallel_world_size()
47 | vocab_start_index, vocab_end_index = get_vocab_range(partition_vocab_size, rank, world_size)
48 |
49 | # Create a mask of out-of-partition vocab ids (1 means the id needs to be masked).
50 | target_mask = (target < vocab_start_index) | (target >= vocab_end_index)
51 | masked_target = target.clone() - vocab_start_index
52 | masked_target[target_mask] = 0
53 |
54 | # Get predicted-logits = logits[target].
55 | # For simplicity, we convert logits to a 2-D tensor with size
56 | # [*, partition-vocab-size] and target to a 1-D tensor of size [*].
57 | logits_2d = vocab_parallel_logits.view(-1, partition_vocab_size)
58 | masked_target_1d = masked_target.view(-1)
59 | arange_1d = torch.arange(start=0, end=logits_2d.size()[0], device=logits_2d.device)
60 | predicted_logits_1d = logits_2d[arange_1d, masked_target_1d]
61 | predicted_logits_1d = predicted_logits_1d.clone().contiguous()
62 | predicted_logits = predicted_logits_1d.view_as(target)
63 | predicted_logits[target_mask] = 0.0
64 | # All reduce is needed to get the chunks from other GPUs.
65 | torch.distributed.all_reduce(
66 | predicted_logits, op=torch.distributed.ReduceOp.SUM, group=get_tensor_model_parallel_group()
67 | )
68 |
69 | # Sum of exponential of logits along vocab dimension across all GPUs.
70 | exp_logits = vocab_parallel_logits
71 | torch.exp(vocab_parallel_logits, out=exp_logits)
72 | sum_exp_logits = exp_logits.sum(dim=-1)
73 | torch.distributed.all_reduce(
74 | sum_exp_logits, op=torch.distributed.ReduceOp.SUM, group=get_tensor_model_parallel_group()
75 | )
76 |
77 | # Loss = log(sum(exp(logits))) - predicted-logit.
78 | loss = torch.log(sum_exp_logits) - predicted_logits
79 |
80 | # Store softmax, target-mask and masked-target for backward pass.
81 | exp_logits.div_(sum_exp_logits.unsqueeze(dim=-1))
82 | ctx.save_for_backward(exp_logits, target_mask, masked_target_1d)
83 |
84 | return loss
85 |
86 | @staticmethod
87 | def backward(ctx, grad_output):
88 |
89 | # Retrieve tensors from the forward path.
90 | softmax, target_mask, masked_target_1d = ctx.saved_tensors
91 |
92 | # All the inputs have softmax as their gradient.
93 | grad_input = softmax
94 | # For simplicity, work with the 2D gradient.
95 | partition_vocab_size = softmax.size()[-1]
96 | grad_2d = grad_input.view(-1, partition_vocab_size)
97 |
98 | # Add the gradient from matching classes.
99 | arange_1d = torch.arange(start=0, end=grad_2d.size()[0], device=grad_2d.device)
100 | grad_2d[arange_1d, masked_target_1d] -= 1.0 - target_mask.view(-1).float()
101 |
102 | # Finally elementwise multiplication with the output gradients.
103 | grad_input.mul_(grad_output.unsqueeze(dim=-1))
104 |
105 | return grad_input, None
106 |
107 |
108 | def vocab_parallel_cross_entropy(vocab_parallel_logits, target):
109 | """Helper function for the cross entropy."""
110 | return _VocabParallelCrossEntropy.apply(vocab_parallel_logits, target)
111 |
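Usage sketch: the logits are sharded along the vocabulary dimension while the targets are replicated, and the per-token loss is log(sum_v exp(logit_v)) - logit_target, with both terms assembled via all-reduce over the tensor-parallel group. Note that the forward pass modifies the logits tensor in place (it becomes the saved softmax). The tensors below are placeholders.

    # vocab_parallel_logits: [seq, batch, vocab_size // tp_world_size], target: [seq, batch]
    loss = vocab_parallel_cross_entropy(vocab_parallel_logits, target)   # per-token loss, [seq, batch]
    mean_loss = loss.float().mean()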
--------------------------------------------------------------------------------
/megatron/mpu/cross_entropy_parallel.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2023 ADEPT AI LABS INC.
2 | # This file is based on code by the authors denoted below and has been modified from its original version.
3 | #
4 | # Taken from https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/losses/cross_entropy_parallel.py
5 | # But we import from Megatron instead of from Apex
6 | # HazyResearch is licensed under BSD 3.0
7 |
8 | # Inspired by https://github.com/NVIDIA/apex/blob/master/apex/transformer/tensor_parallel/cross_entropy.py
9 | # But we make it much faster: we compute the local loss and the LSE, and by exchanging the LSE and
10 | # the losses we can get the global loss. There's no need to do it step by step
11 | # (compute local max, exchange, compute exp, compute local sum, exchange, etc.)
12 | import torch
13 | import torch.nn as nn
14 |
15 | import xentropy_cuda_lib
16 |
17 | from .initialize import get_tensor_model_parallel_group
18 | from .initialize import get_tensor_model_parallel_rank
19 | from .initialize import get_tensor_model_parallel_world_size
20 | from .utils import VocabUtility
21 |
22 | # `all_gather_into_tensor` and `reduce_scatter_tensor` are new placeholders for
23 | # `_all_gather_base` and `_reduce_scatter_base`. They require the most recent
24 | # version of PyTorch. The following 4 lines are for backward compatibility with
25 | # older PyTorch.
26 | if "all_gather_into_tensor" not in dir(torch.distributed):
27 | torch.distributed.all_gather_into_tensor = torch.distributed._all_gather_base
28 | if "reduce_scatter_tensor" not in dir(torch.distributed):
29 | torch.distributed.reduce_scatter_tensor = torch.distributed._reduce_scatter_base
30 |
31 |
32 | class SoftmaxCrossEntropyLossParallelFn(torch.autograd.Function):
33 | @staticmethod
34 | def forward(ctx, logits_parallel, labels, smoothing=0.0, ignored_index=-100, inplace_backward=False):
35 | """
36 | logits_parallel: (batch, vocab_size / world_size)
37 | labels: (batch,)
38 | """
39 | batch, partition_vocab_size = logits_parallel.shape
40 | assert labels.shape == (batch,)
41 | rank = get_tensor_model_parallel_rank()
42 | world_size = get_tensor_model_parallel_world_size()
43 |
44 | if world_size == 1:
45 | losses, lse = xentropy_cuda_lib.forward(logits_parallel, labels, smoothing)
46 | losses.masked_fill_(labels == ignored_index, 0)
47 | labels_local = labels
48 | else:
49 | vocab_start_index, vocab_end_index = VocabUtility.vocab_range_from_per_partition_vocab_size(
50 | partition_vocab_size, get_tensor_model_parallel_rank(), get_tensor_model_parallel_world_size()
51 | )
52 |
53 | # Create a mask of out-of-partition vocab ids (1 means the id needs to be masked).
54 | labels_mask = (labels < vocab_start_index) | (labels >= vocab_end_index)
55 | ignored_mask = labels == ignored_index
56 | labels_local = torch.where(ignored_mask, labels, labels - vocab_start_index)
57 | masked_labels = labels_local.clone()
58 | masked_labels[labels_mask] = ignored_index
59 |
60 | losses, lse_local = xentropy_cuda_lib.forward(logits_parallel, masked_labels, smoothing)
61 | assert lse_local.shape == (batch,)
62 | assert losses.shape == (batch,)
63 | losses.masked_fill_(masked_labels == ignored_index, 0)
64 |
65 | lse_allgather = torch.empty(world_size, batch, dtype=lse_local.dtype, device=lse_local.device)
66 | handle_lse = torch.distributed.all_gather_into_tensor(
67 | lse_allgather, lse_local.contiguous(), group=get_tensor_model_parallel_group(), async_op=True
68 | )
69 | handle_losses = torch.distributed.all_reduce(
70 | losses, op=torch.distributed.ReduceOp.SUM, group=get_tensor_model_parallel_group(), async_op=True
71 | )
72 | handle_lse.wait()
73 | lse = torch.logsumexp(lse_allgather, dim=0)
74 | # The losses currently equal lse_local - predicted_logit; to recover the global loss
75 | # we subtract lse_local and add the global lse.
76 | rank_per_sample = torch.div(labels, partition_vocab_size, rounding_mode="floor")
77 | lse_local = lse_allgather[rank_per_sample, torch.arange(batch, device=lse_allgather.device)]
78 |
79 | handle_losses.wait()
80 | losses += lse - lse_local
81 | losses.masked_fill_(ignored_mask, 0)
82 |
83 | ctx.save_for_backward(logits_parallel, lse, labels_local)
84 | ctx.smoothing = smoothing
85 | ctx.ignored_index = ignored_index
86 | ctx.inplace_backward = inplace_backward
87 | return losses
88 |
89 | @staticmethod
90 | def backward(ctx, grad_loss):
91 | logits_parallel, lse, labels = ctx.saved_tensors
92 | if not grad_loss.is_contiguous():
93 | grad_loss = grad_loss.contiguous()
94 | grad_loss.masked_fill_(labels == ctx.ignored_index, 0)
95 | grad_logits = xentropy_cuda_lib.backward(
96 | grad_loss, logits_parallel, lse, labels, ctx.smoothing, ctx.inplace_backward
97 | )
98 | return grad_logits, None, None, None, None, None
99 |
100 |
101 | class CrossEntropyLossParallel(nn.Module):
102 | def __init__(self, ignore_index=-100, reduction="mean", label_smoothing=0.0, inplace_backward=False):
103 | super().__init__()
104 | if reduction not in ["mean", "none"]:
105 | raise NotImplementedError("Only support reduction = 'mean' or 'none'")
106 | self.ignore_index = ignore_index
107 | self.reduction = reduction
108 | self.label_smoothing = label_smoothing
109 | self.inplace_backward = inplace_backward
110 |
111 | def forward(self, input, target):
112 | assert input.is_cuda and target.is_cuda
113 | # SoftmaxCrossEntropyLoss implicitly casts to float
114 | loss = SoftmaxCrossEntropyLossParallelFn.apply(
115 | input, target, self.label_smoothing, self.ignore_index, self.inplace_backward
116 | )
117 | if self.reduction == "mean":
118 | return loss.sum() / (target != self.ignore_index).sum()
119 | else:
120 | return loss
121 |
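Usage sketch (requires the xentropy_cuda_lib extension from flash-attention; the tensors are placeholders):

    loss_fn = CrossEntropyLossParallel(reduction="mean", inplace_backward=True)
    # logits_parallel: [batch, vocab_size // tp_world_size] on this rank, labels: [batch]
    loss = loss_fn(logits_parallel, labels)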
--------------------------------------------------------------------------------
/megatron/mpu/mappings.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright (c) 2023 ADEPT AI LABS INC.
3 | # This file is based on code by the authors denoted below and has been modified from its original version.
4 | #
5 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
6 | #
7 | # Licensed under the Apache License, Version 2.0 (the "License");
8 | # you may not use this file except in compliance with the License.
9 | # You may obtain a copy of the License at
10 | #
11 | # http://www.apache.org/licenses/LICENSE-2.0
12 | #
13 | # Unless required by applicable law or agreed to in writing, software
14 | # distributed under the License is distributed on an "AS IS" BASIS,
15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16 | # See the License for the specific language governing permissions and
17 | # limitations under the License.
18 |
19 | import torch
20 |
21 | from .initialize import (
22 | get_tensor_model_parallel_group,
23 | get_tensor_model_parallel_rank,
24 | get_tensor_model_parallel_world_size,
25 | )
26 | from .utils import split_tensor_along_last_dim
27 |
28 |
29 | def _reduce(input_):
30 | """All-reduce the input tensor across model parallel group."""
31 |
32 | # Bypass the function if we are using only 1 GPU.
33 | if get_tensor_model_parallel_world_size() == 1:
34 | return input_
35 |
36 | # All-reduce, requires contiguous tensor.
37 | input_ = input_.contiguous()
38 | torch.distributed.all_reduce(input_, group=get_tensor_model_parallel_group())
39 |
40 | return input_
41 |
42 |
43 | def _split_along_last_dim(input_):
44 | """Split the tensor along its last dimension and keep the
45 | corresponding slice."""
46 |
47 | world_size = get_tensor_model_parallel_world_size()
48 | # Bypass the function if we are using only 1 GPU.
49 | if world_size == 1:
50 | return input_
51 |
52 | # Split along last dimension.
53 | input_list = split_tensor_along_last_dim(input_, world_size)
54 |
55 | # Note: torch.split does not create contiguous tensors by default.
56 | rank = get_tensor_model_parallel_rank()
57 | output = input_list[rank].contiguous()
58 |
59 | return output
60 |
61 |
62 | def _split_along_first_dim(input_):
63 | """Split the tensor along its first dimension and keep the
64 | corresponding slice."""
65 |
66 | world_size = get_tensor_model_parallel_world_size()
67 | # Bypass the function if we are using only 1 GPU.
68 | if world_size == 1:
69 | return input_
70 |
71 | # Split along first dimension.
72 | dim_size = input_.size()[0]
73 | assert (
74 | dim_size % world_size == 0
75 | ), f"First dimension of the tensor should be divisible by tensor parallel size. {dim_size} % {world_size} != 0"
76 | local_dim_size = dim_size // world_size
77 | rank = get_tensor_model_parallel_rank()
78 | dim_offset = rank * local_dim_size
79 |
80 | output = input_[dim_offset : dim_offset + local_dim_size].contiguous()
81 |
82 | return output
83 |
84 |
85 | def _gather_along_last_dim(input_):
86 | """Gather tensors and concatinate along the last dimension."""
87 |
88 | world_size = get_tensor_model_parallel_world_size()
89 | # Bypass the function if we are using only 1 GPU.
90 | if world_size == 1:
91 | return input_
92 |
93 | # Size and dimension.
94 | last_dim = input_.dim() - 1
95 | rank = get_tensor_model_parallel_rank()
96 |
97 | tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
98 | tensor_list[rank] = input_
99 | torch.distributed.all_gather(tensor_list, input_, group=get_tensor_model_parallel_group())
100 |
101 | # Note: torch.cat already creates a contiguous tensor.
102 | output = torch.cat(tensor_list, dim=last_dim).contiguous()
103 |
104 | return output
105 |
106 |
107 | def _gather_along_first_dim(input_):
108 | """Gather tensors and concatinate along the first dimension."""
109 |
110 | world_size = get_tensor_model_parallel_world_size()
111 | # Bypass the function if we are using only 1 GPU.
112 | if world_size == 1:
113 | return input_
114 |
115 | dim_size = list(input_.size())
116 | dim_size[0] = dim_size[0] * world_size
117 |
118 | output = torch.empty(dim_size, dtype=input_.dtype, device=torch.cuda.current_device())
119 | torch.distributed._all_gather_base(output, input_.contiguous(), group=get_tensor_model_parallel_group())
120 |
121 | return output
122 |
123 |
124 | def _reduce_scatter_along_first_dim(input_):
125 | """Reduce-scatter the input tensor across model parallel group."""
126 | world_size = get_tensor_model_parallel_world_size()
127 | # Bypass the function if we are using only 1 GPU.
128 | if world_size == 1:
129 | return input_
130 |
131 | dim_size = list(input_.size())
132 | assert dim_size[0] % world_size == 0, "First dimension of the tensor should be divisible by tensor parallel size"
133 |
134 | dim_size[0] = dim_size[0] // world_size
135 |
136 | output = torch.empty(dim_size, dtype=input_.dtype, device=torch.cuda.current_device())
137 | torch.distributed._reduce_scatter_base(output, input_.contiguous(), group=get_tensor_model_parallel_group())
138 | return output
139 |
140 |
141 | class _ReduceBackwardFromTensorModelParallelRegion(torch.autograd.Function):
142 | """All-reduce the backward pass input from the model parallel region."""
143 |
144 | @staticmethod
145 | def symbolic(graph, input_):
146 | return input_
147 |
148 | @staticmethod
149 | def forward(ctx, input_):
150 | return input_
151 |
152 | @staticmethod
153 | def backward(ctx, grad_output):
154 | return _reduce(grad_output)
155 |
156 |
157 | class _ReduceForwardFromTensorModelParallelRegion(torch.autograd.Function):
158 | """All-reduce the input from the model parallel region."""
159 |
160 | @staticmethod
161 | def symbolic(graph, input_):
162 | return _reduce(input_)
163 |
164 | @staticmethod
165 | def forward(ctx, input_):
166 | return _reduce(input_)
167 |
168 | @staticmethod
169 | def backward(ctx, grad_output):
170 | return grad_output
171 |
172 |
173 | class _ScatterToModelParallelRegion(torch.autograd.Function):
174 | """Split the input and keep only the corresponding chuck to the rank."""
175 |
176 | @staticmethod
177 | def symbolic(graph, input_):
178 | return _split_along_last_dim(input_)
179 |
180 | @staticmethod
181 | def forward(ctx, input_):
182 | return _split_along_last_dim(input_)
183 |
184 | @staticmethod
185 | def backward(ctx, grad_output):
186 | return _gather_along_last_dim(grad_output)
187 |
188 |
189 | class _GatherFromModelParallelRegion(torch.autograd.Function):
190 | """Gather the input from model parallel region and concatinate."""
191 |
192 | @staticmethod
193 | def symbolic(graph, input_):
194 | return _gather_along_last_dim(input_)
195 |
196 | @staticmethod
197 | def forward(ctx, input_):
198 | return _gather_along_last_dim(input_)
199 |
200 | @staticmethod
201 | def backward(ctx, grad_output):
202 | return _split_along_last_dim(grad_output)
203 |
204 |
205 | class _ScatterToSequenceParallelRegion(torch.autograd.Function):
206 | """Split the input and keep only the corresponding chuck to the rank."""
207 |
208 | @staticmethod
209 | def symbolic(graph, input_):
210 | return _split_along_first_dim(input_)
211 |
212 | @staticmethod
213 | def forward(ctx, input_):
214 | return _split_along_first_dim(input_)
215 |
216 | @staticmethod
217 | def backward(ctx, grad_output):
218 | return _gather_along_first_dim(grad_output)
219 |
220 |
221 | class _GatherFromSequenceParallelRegion(torch.autograd.Function):
222 | """Gather the input from sequence parallel region and concatinate."""
223 |
224 | @staticmethod
225 | def symbolic(graph, input_, tensor_parallel_output_grad=True):
226 | return _gather_along_first_dim(input_)
227 |
228 | @staticmethod
229 | def forward(ctx, input_, tensor_parallel_output_grad=True):
230 | ctx.tensor_parallel_output_grad = tensor_parallel_output_grad
231 | return _gather_along_first_dim(input_)
232 |
233 | @staticmethod
234 | def backward(ctx, grad_output):
235 | tensor_parallel_output_grad = ctx.tensor_parallel_output_grad
236 |
237 | # If the computation graph after the gather operation is
238 | # in tensor parallel mode, the output gradients need to be
239 | # reduce-scattered, whereas if the computation is duplicated,
240 | # the output gradients only need to be split (scattered).
241 | if tensor_parallel_output_grad:
242 | return _reduce_scatter_along_first_dim(grad_output), None
243 | else:
244 | return _split_along_first_dim(grad_output), None
245 |
246 |
247 | class _ReduceScatterToSequenceParallelRegion(torch.autograd.Function):
248 | """Reduce scatter the input from the model parallel region."""
249 |
250 | @staticmethod
251 | def symbolic(graph, input_):
252 | return _reduce_scatter_along_first_dim(input_)
253 |
254 | @staticmethod
255 | def forward(ctx, input_):
256 | return _reduce_scatter_along_first_dim(input_)
257 |
258 | @staticmethod
259 | def backward(ctx, grad_output):
260 | return _gather_along_first_dim(grad_output)
261 |
262 |
263 | # -----------------
264 | # Helper functions.
265 | # -----------------
266 |
267 |
268 | def reduce_backward_from_tensor_model_parallel_region(input_):
269 | return _ReduceBackwardFromTensorModelParallelRegion.apply(input_)
270 |
271 |
272 | def reduce_forward_from_tensor_model_parallel_region(input_):
273 | return _ReduceForwardFromTensorModelParallelRegion.apply(input_)
274 |
275 |
276 | def scatter_to_tensor_model_parallel_region(input_):
277 | return _ScatterToModelParallelRegion.apply(input_)
278 |
279 |
280 | def gather_from_tensor_model_parallel_region(input_):
281 | return _GatherFromModelParallelRegion.apply(input_)
282 |
283 |
284 | def scatter_to_sequence_parallel_region(input_):
285 | return _ScatterToSequenceParallelRegion.apply(input_)
286 |
287 |
288 | def gather_from_sequence_parallel_region(input_, tensor_parallel_output_grad=True):
289 | return _GatherFromSequenceParallelRegion.apply(input_, tensor_parallel_output_grad)
290 |
291 |
292 | def reduce_scatter_to_sequence_parallel_region(input_):
293 | return _ReduceScatterToSequenceParallelRegion.apply(input_)
294 |
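A conceptual sketch, not code from this repository, of how a column-parallel linear layer composes these helpers: the forward copy is an identity whose backward all-reduces the gradient, and the sharded outputs are concatenated along the last (output-feature) dimension.

    x = reduce_backward_from_tensor_model_parallel_region(x)   # identity fwd, all-reduce bwd
    y_local = torch.matmul(x, w_local)                         # w_local: [in, out // tp_world_size]
    y = gather_from_tensor_model_parallel_region(y_local)      # concatenate shards along last dim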
--------------------------------------------------------------------------------
/megatron/mpu/utils.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright (c) 2023 ADEPT AI LABS INC.
3 | # This file is based on code by the authors denoted below and has been modified from its original version.
4 | #
5 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
6 | #
7 | # Licensed under the Apache License, Version 2.0 (the "License");
8 | # you may not use this file except in compliance with the License.
9 | # You may obtain a copy of the License at
10 | #
11 | # http://www.apache.org/licenses/LICENSE-2.0
12 | #
13 | # Unless required by applicable law or agreed to in writing, software
14 | # distributed under the License is distributed on an "AS IS" BASIS,
15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16 | # See the License for the specific language governing permissions and
17 | # limitations under the License.
18 |
19 |
20 | import torch
21 |
22 |
23 | def ensure_divisibility(numerator, denominator):
24 | """Ensure that numerator is divisible by the denominator."""
25 | assert numerator % denominator == 0, "{} is not divisible by {}".format(numerator, denominator)
26 |
27 |
28 | def divide(numerator, denominator):
29 | """Ensure that numerator is divisible by the denominator and return
30 | the division value."""
31 | ensure_divisibility(numerator, denominator)
32 | return numerator // denominator
33 |
34 |
35 | def split_tensor_along_last_dim(tensor, num_partitions, contiguous_split_chunks=False):
36 | """Split a tensor along its last dimension.
37 | Arguments:
38 | tensor: input tensor.
39 | num_partitions: number of partitions to split the tensor
40 | contiguous_split_chunks: If True, make each chunk contiguous
41 | in memory.
42 | """
43 | # Get the size and dimension.
44 | last_dim = tensor.dim() - 1
45 | last_dim_size = divide(tensor.size()[last_dim], num_partitions)
46 | # Split.
47 | tensor_list = torch.split(tensor, last_dim_size, dim=last_dim)
48 | # Note: torch.split does not create contiguous tensors by default.
49 | if contiguous_split_chunks:
50 | return tuple(chunk.contiguous() for chunk in tensor_list)
51 |
52 | return tensor_list
53 |
54 |
55 | class VocabUtility:
56 | """Split the vocabulary into `world_size` chunks amd return the
57 | first and last index of the vocabulary belonging to the `rank`
58 | partition: Note that indecies in [fist, last)"""
59 |
60 | @staticmethod
61 | def vocab_range_from_per_partition_vocab_size(per_partition_vocab_size, rank, world_size):
62 | index_f = rank * per_partition_vocab_size
63 | index_l = index_f + per_partition_vocab_size
64 | return index_f, index_l
65 |
66 | @staticmethod
67 | def vocab_range_from_global_vocab_size(global_vocab_size, rank, world_size):
68 | per_partition_vocab_size = divide(global_vocab_size, world_size)
69 | return VocabUtility.vocab_range_from_per_partition_vocab_size(per_partition_vocab_size, rank, world_size)
70 |
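Worked example: splitting a 32000-token vocabulary across 4 tensor-parallel ranks gives each rank 8000 entries, so rank 1 owns the half-open range [8000, 16000).

    start, end = VocabUtility.vocab_range_from_global_vocab_size(32000, rank=1, world_size=4)
    assert (start, end) == (8000, 16000)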
--------------------------------------------------------------------------------
/megatron/text_generation/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright (c) 2023 ADEPT AI LABS INC.
3 | # This file is based on code by the authors denoted below and has been modified from its original version.
4 | #
5 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
6 | #
7 | # Licensed under the Apache License, Version 2.0 (the "License");
8 | # you may not use this file except in compliance with the License.
9 | # You may obtain a copy of the License at
10 | #
11 | # http://www.apache.org/licenses/LICENSE-2.0
12 | #
13 | # Unless required by applicable law or agreed to in writing, software
14 | # distributed under the License is distributed on an "AS IS" BASIS,
15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16 | # See the License for the specific language governing permissions and
17 | # limitations under the License.
18 |
19 |
20 | from .api import generate, generate_and_post_process
21 |
--------------------------------------------------------------------------------
/megatron/text_generation/api.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright (c) 2023 ADEPT AI LABS INC.
3 | # This file is based on code by the authors denoted below and has been modified from its original version.
4 | #
5 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
6 | #
7 | # Licensed under the Apache License, Version 2.0 (the "License");
8 | # you may not use this file except in compliance with the License.
9 | # You may obtain a copy of the License at
10 | #
11 | # http://www.apache.org/licenses/LICENSE-2.0
12 | #
13 | # Unless required by applicable law or agreed to in writing, software
14 | # distributed under the License is distributed on an "AS IS" BASIS,
15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16 | # See the License for the specific language governing permissions and
17 | # limitations under the License.
18 |
19 | """Inference API."""
20 |
21 |
22 | import base64
23 | from typing import List, Optional, Tuple, Any, Dict
24 |
25 | import numpy as np
26 | import torch
27 | from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
28 |
29 | from megatron.model.module import MegatronModule
30 | from megatron import mpu
31 | from megatron.model import DistributedDataParallel as LocalDDP
32 | from megatron.model import Float16Module
33 |
34 | from megatron.tokenizer.tokenizer import AbstractTokenizer
35 | from megatron.mpu.communication import broadcast_float_list, broadcast_int_list
36 | from megatron.text_generation.generation import (
37 | generate_tokens_probs_and_return_on_first_stage,
38 | score_and_return_on_first_stage,
39 | )
40 | from megatron.text_generation.inference_params import InferenceParams
41 | from megatron.text_generation.tokenization import (
42 | convert_generations_to_human_readable_tokens,
43 | tokenize_prompts,
44 | )
45 | from megatron.utils import unwrap_model
46 | import numpy.typing as npt
47 |
48 | def preprocess_prompts(prompts: Optional[List[List[str]]]) -> Optional[List[List[str]]]:
49 | """
50 | Accepts a list of lists of subprompts and returns a list of lists of processed subprompts.
51 | """
52 | if prompts is None:
53 | return None
54 | processed_prompts = []
55 | for prompt in prompts:
56 | processed_subprompts = []
57 | for subprompt in prompt:
58 | processed_subprompts.append(f"human: {subprompt.strip()}\n\nadept:")
59 | processed_prompts.append(processed_subprompts)
60 | return processed_prompts
61 |
62 |
63 | def generate_and_post_process(
64 | model: MegatronModule,
65 | params_dtype: torch.dtype,
66 | max_position_embeddings: int,
67 | termination_id: int,
68 | tokenizer: AbstractTokenizer,
69 | prompts: Optional[List[List[str]]] = None,
70 | max_tokens_to_generate: int = 0,
71 | inference_params: Optional[InferenceParams] = None,
72 | return_output_log_probs: bool = False,
73 | return_all_log_probs: bool = False,
74 | log_prob_tokens: Optional[torch.Tensor] = None,
75 | top_k_sampling: int = 0,
76 | top_p_sampling: float = 0.0,
77 | temperature: float = 1.0,
78 | add_BOS: bool = False,
79 | random_seed: int = -1,
80 | process_prompts_for_chat: bool = False
81 | ) -> Optional[
82 | Tuple[
83 | # (raw prompts + generations, output log probs, all tokens,
84 | #  processed generations, all log probs, human-readable tokens)
85 | List[str],
86 | Optional[Any],
87 | List[List[int]],
88 | List[str],
89 | Optional[Any],
90 | List[List[str]],
91 | ]
92 | ]:
93 | """Run inference and post-process outputs, i.e., detokenize,
94 | move to cpu and convert to list.
95 |
96 | prompts: a list of list of strings, where each element in the outer list represents a single sample that
97 | is an item in the batch. A single sample is represented as a list of strings.
98 | """
99 | # Pre-process the prompts.
100 | if process_prompts_for_chat:
101 | prompts = preprocess_prompts(prompts)
102 |
103 | # Main inference.
104 | outputs = generate(
105 | model,
106 | max_position_embeddings=max_position_embeddings,
107 | params_dtype=params_dtype,
108 | termination_id=termination_id,
109 | tokenizer=tokenizer,
110 | prompts=prompts,
111 | max_tokens_to_generate=max_tokens_to_generate,
112 | inference_params=inference_params,
113 | return_output_log_probs=return_output_log_probs,
114 | return_all_log_probs=return_all_log_probs,
115 | log_prob_tokens=log_prob_tokens,
116 | top_k_sampling=top_k_sampling,
117 | top_p_sampling=top_p_sampling,
118 | temperature=temperature,
119 | add_BOS=add_BOS,
120 | random_seed=random_seed,
121 | )
122 |
123 | # Only post-process on first stage.
124 | if mpu.is_pipeline_first_stage():
125 |
126 | all_tokens = outputs["tokens"].cpu().numpy().tolist()
127 |
128 | raw_prompts_plus_generations = [tokenizer.detokenize(ts) for ts in all_tokens]
129 | processed_generations = [
130 | tokenizer.detokenize(toks) for toks in outputs.get("generated_tokens", [])
131 | ]
132 |
133 | processed_generated_tokens = [
134 | tokenizer.encode(g) for g in processed_generations
135 | ]
136 |
137 | human_readable_tokens = convert_generations_to_human_readable_tokens(
138 | processed_generated_tokens, ""
139 | )
140 |
141 | output_log_probs = None
142 | if return_output_log_probs:
143 | output_log_probs = outputs["output_log_probs"].cpu().numpy().tolist()
144 |
145 | all_log_probs = None
146 | if return_all_log_probs:
147 | all_log_probs = outputs["all_log_probs"].cpu().numpy().tolist()
148 |
149 | return (
150 | raw_prompts_plus_generations,
151 | output_log_probs,
152 | all_tokens,
153 | processed_generations,
154 | all_log_probs,
155 | human_readable_tokens,
156 | )
157 |
158 | return None
159 |
160 |
161 | def generate(
162 | model: MegatronModule,
163 | max_position_embeddings: int,
164 | params_dtype: torch.dtype,
165 | termination_id: int,
166 | tokenizer: AbstractTokenizer,
167 | prompts: Optional[List[List[str]]] = None,
168 | max_tokens_to_generate: int = 0,
169 | inference_params: Optional[InferenceParams] = None,
170 | return_output_log_probs: bool = False,
171 | return_all_log_probs: bool = False,
172 | log_prob_tokens: Optional[torch.Tensor] = None,
173 | top_k_sampling: int = 0,
174 | top_p_sampling: float = 0.0,
175 | temperature: float = 1.0,
176 | add_BOS: bool = False,
177 | random_seed: int = -1,
178 | ) -> Dict[str, Any]:
179 | """Given prompts and input parameters, run inference and return:
180 | tokens: prompts plus the generated tokens.
181 | lengths: length of the prompt + generations. Note that we can
182 | discard tokens in the tokens tensor that are after the
183 | corresponding length.
184 | output_log_probs: log probs of the tokens.
185 | all_log_probs: log probs over the full vocabulary (or only the tokens listed in
186 | log_prob_tokens, if provided) for each generated token position.
187 | """
188 | num_log_prob_tokens = 0 if log_prob_tokens is None else len(log_prob_tokens)
189 |
190 | # Make sure input params are available to all ranks.
191 | values = [
192 | max_tokens_to_generate,
193 | return_output_log_probs,
194 | top_k_sampling,
195 | top_p_sampling,
196 | temperature,
197 | add_BOS,
198 | random_seed,
199 | return_all_log_probs,
200 | num_log_prob_tokens,
201 | ]
202 | values_float_tensor = broadcast_float_list(len(values), float_list=values)
203 | max_tokens_to_generate = int(values_float_tensor[0].item())
204 | return_output_log_probs = bool(values_float_tensor[1].item())
205 | top_k_sampling = int(values_float_tensor[2].item())
206 | top_p_sampling = values_float_tensor[3].item()
207 | temperature = values_float_tensor[4].item()
208 | add_BOS = bool(values_float_tensor[5].item())
209 | random_seed = int(values_float_tensor[6].item())
210 | return_all_log_probs = bool(values_float_tensor[7].item())
211 | num_log_prob_tokens = int(values_float_tensor[8].item())
212 |
213 | if return_all_log_probs and num_log_prob_tokens > 0:
214 | # Do another broadcast for the log_prob_tokens.
215 | log_prob_tokens = broadcast_int_list(
216 | num_log_prob_tokens, int_list=log_prob_tokens
217 | )
218 | else:
219 | log_prob_tokens = None
220 |
221 | if random_seed != -1:
222 | torch.random.manual_seed(random_seed)
223 |
224 | # Tokenize prompts and get the batch.
225 | # Note that these tensors are broadcasted to all ranks.
226 | if torch.distributed.get_rank() == 0:
227 | assert prompts is not None
228 |
229 | context_tokens_tensor, context_length_tensor = tokenize_prompts(
230 | prompts=prompts,
231 | max_tokens_to_generate=max_tokens_to_generate,
232 | max_position_embeddings=max_position_embeddings,
233 | add_BOS=add_BOS,
234 | )
235 |
236 | batch_size = context_tokens_tensor.shape[0]
237 | num_sub_sequences = context_tokens_tensor.shape[1]
238 |
239 | assert num_sub_sequences == 1
240 | # Remove subsequence dim
241 | context_length_tensor = torch.squeeze(context_length_tensor, dim=1)
242 | context_tokens_tensor = torch.squeeze(context_tokens_tensor, dim=1)
243 |
244 | if max_tokens_to_generate == 0:
245 | assert inference_params is not None
246 | return score_and_return_on_first_stage(
247 | model,
248 | context_tokens_tensor,
249 | context_length_tensor,
250 | inference_params,
251 | max_position_embeddings=max_position_embeddings,
252 | )
253 |
254 | # Main inference function.
255 | # Note that the outputs are available on the first stage.
256 | assert inference_params is not None
257 | # Added termination_id to support the case that we want to terminate the
258 | # generation once that id is generated.
259 |
260 | outputs = generate_tokens_probs_and_return_on_first_stage(
261 | model,
262 | context_tokens_tensor,
263 | context_length_tensor,
264 | inference_params=inference_params,
265 | max_position_embeddings=max_position_embeddings,
266 | termination_id=termination_id,
267 | vocab_size=tokenizer.vocab_size,
268 | return_output_log_probs=return_output_log_probs,
269 | log_prob_tokens=log_prob_tokens,
270 | return_all_log_probs=return_all_log_probs,
271 | top_k=top_k_sampling,
272 | top_p=top_p_sampling,
273 | temperature=temperature,
274 | )
275 | # Now we can figure out what actually got generated and return that too
276 | generated_tokens = []
277 | context_lengths = context_length_tensor.detach().cpu().numpy().tolist()
278 | contexts = context_tokens_tensor.detach().cpu().numpy().tolist()
279 | lengths_cpu = outputs["lengths"].detach().cpu().numpy().tolist()
280 |
281 | for b in range(batch_size):
282 | assert lengths_cpu[b] > context_lengths[b]
283 | gen = contexts[b][context_lengths[b] : lengths_cpu[b]]
284 | generated_tokens.append(gen)
285 |
286 | outputs["generated_tokens"] = generated_tokens
287 |
288 | return outputs
289 |
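A call sketch for the entry point above; the model, tokenizer, termination_id, and inference_params objects are placeholders, and the argument values are illustrative only.

    results = generate_and_post_process(
        model,
        params_dtype=torch.bfloat16,
        max_position_embeddings=max_position_embeddings,
        termination_id=termination_id,
        tokenizer=tokenizer,
        prompts=[["Tell me a joke."]],   # one sample in the batch, one subprompt
        max_tokens_to_generate=128,
        inference_params=inference_params,
        temperature=0.8,
        process_prompts_for_chat=True,
    )
    if results is not None:              # non-None only on the first pipeline stage
        prompts_plus_generations = results[0]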
--------------------------------------------------------------------------------
/megatron/text_generation/forward_step.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright (c) 2023 ADEPT AI LABS INC.
3 | # This file is based on code by the authors denoted below and has been modified from its original version.
4 | #
5 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
6 | #
7 | # Licensed under the Apache License, Version 2.0 (the "License");
8 | # you may not use this file except in compliance with the License.
9 | # You may obtain a copy of the License at
10 | #
11 | # http://www.apache.org/licenses/LICENSE-2.0
12 | #
13 | # Unless required by applicable law or agreed to in writing, software
14 | # distributed under the License is distributed on an "AS IS" BASIS,
15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16 | # See the License for the specific language governing permissions and
17 | # limitations under the License.
18 |
19 | """Forward step utilities."""
20 |
21 | from typing import Optional, Dict, Any, List
22 | from collections.abc import Iterable
23 |
24 | from torch import dtype
25 | import torch
26 | from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
27 |
28 | from megatron import get_args, mpu
29 | from megatron.text_generation.inference_params import InferenceParams
30 | from megatron.model import DistributedDataParallel as LocalDDP
31 | from megatron.model import Float16Module
32 | from megatron.model.module import MegatronModule
33 | from megatron.model.gpt_model import GPTModel
34 | from megatron.utils import unwrap_model
35 |
36 | from megatron.mpu.communication import recv_from_prev_pipeline_rank_, send_to_next_pipeline_rank
37 |
38 |
39 | class ForwardStep:
40 | """Forward step function with all the communications.
41 | We use a class here to hide the inference parameters
42 | from the outside caller."""
43 |
44 | def __init__(
45 | self,
46 | model: MegatronModule,
47 | max_batch_size: int,
48 | max_sequence_len: int,
49 | inference_params: InferenceParams,
50 | ):
51 | """Set values so we don't need to do it multiple times."""
52 | # Make sure model is in eval mode.
53 | assert not isinstance(model, Iterable), "interleaving schedule is not supported for inference"
54 | model.eval()
55 | self.model = model
56 | # Initialize inference parameters.
57 | self.inference_params = inference_params
58 | # Pipelining arguments.
59 | args = get_args()
60 | self.pipeline_size_larger_than_one = args.pipeline_model_parallel_size > 1
61 | # Threshold of pipelining.
62 | self.pipelining_batch_x_seqlen = args.inference_batch_times_seqlen_threshold
63 |
64 | def __call__(
65 | self,
66 | tokens: torch.Tensor,
67 | position_ids: torch.Tensor,
68 | lm_logits_mask: Optional[torch.Tensor] = None,
69 | ) -> Optional[Dict[str, Any]]:
70 | """Invocation of the forward methods. Note that self.inference_params
71 | is being modified by the forward step."""
72 | # Pipelining case.
73 | if self.pipeline_size_larger_than_one:
74 | current_batch_x_seqlen = tokens.size(0) * tokens.size(1)
75 | if current_batch_x_seqlen >= self.pipelining_batch_x_seqlen:
76 | raise ValueError("We deleted _with_pipelining_forward_step")
77 |
78 | return _no_pipelining_forward_step(
79 | self.model,
80 | tokens,
81 | position_ids,
82 | self.inference_params,
83 | lm_logits_mask=lm_logits_mask,
84 | )
85 |
86 |
87 | def _get_recv_buffer_dtype(args: Any) -> dtype:
88 | """Receive happens between the layers."""
89 | if args.fp32_residual_connection:
90 | return torch.float
91 | return args.params_dtype
92 |
93 |
94 | def _allocate_recv_buffer(batch_size: int, sequence_length: int) -> Optional[torch.Tensor]:
95 | """Receive happens between the layers with size [s, b, h]."""
96 | if mpu.is_pipeline_first_stage():
97 | return None
98 | args = get_args()
99 | recv_size = (sequence_length, batch_size, args.hidden_size)
100 | return torch.empty(recv_size, dtype=_get_recv_buffer_dtype(args), device=torch.cuda.current_device())
101 |
102 |
103 | def _model_forward_step(
104 | model: MegatronModule,
105 | tokens: torch.Tensor,
106 | position_ids: torch.Tensor,
107 | inference_params: InferenceParams,
108 | lm_logits_mask: Optional[torch.Tensor] = None,
109 | ) -> Dict[str, Any]:
110 | # Run a simple forward pass.
111 | unwrapped_model = unwrap_model(model, (torchDDP, LocalDDP, Float16Module))
112 | outputs: Dict[str, Any] = {}
113 | if isinstance(unwrapped_model, GPTModel):
114 | outputs = model(tokens, position_ids, lm_logits_mask=lm_logits_mask, inference_params=inference_params)
115 | else:
116 | assert False, "Unknown model type!" + str(type(unwrapped_model))
117 |
118 | if not isinstance(outputs, dict):
119 | outputs = {"logits": outputs}
120 |
121 | return outputs
122 |
123 |
124 | def _forward_step_helper(
125 | model: MegatronModule,
126 | tokens: torch.Tensor,
127 | position_ids: torch.Tensor,
128 | inference_params: InferenceParams,
129 | recv_buffer: Optional[torch.Tensor] = None,
130 | lm_logits_mask: Optional[torch.Tensor] = None,
131 | ) -> Dict[str, Any]:
132 |     """Single forward step. Allocates a receive buffer if one was not
133 |     passed in, runs the forward pass, and sends the output onward."""
134 | batch_size = tokens.size(0)
135 | sequence_length = tokens.size(1)
136 | if recv_buffer is None:
137 | recv_buffer = _allocate_recv_buffer(batch_size, sequence_length)
138 |
139 | # Receive from previous stage.
140 | recv_from_prev_pipeline_rank_(recv_buffer)
141 |
142 | # Forward pass through the model.
143 | model.set_input_tensor(recv_buffer)
144 | outputs = _model_forward_step(
145 | model,
146 | tokens,
147 | position_ids,
148 | inference_params=inference_params,
149 | lm_logits_mask=lm_logits_mask,
150 | )
151 |
152 | # Send output to the next stage.
153 | send_to_next_pipeline_rank(outputs)
154 |
155 | return outputs
156 |
157 |
158 | def _no_pipelining_forward_step(
159 | model: MegatronModule,
160 | tokens: torch.Tensor,
161 | position_ids: torch.Tensor,
162 | inference_params: InferenceParams,
163 | recv_buffer: Optional[torch.Tensor] = None,
164 | lm_logits_mask: Optional[torch.Tensor] = None,
165 | ) -> Optional[Dict[str, Any]]:
166 |     """If recv_buffer is None, we will allocate one on the fly."""
167 | # Run a simple forward pass.
168 | outputs = _forward_step_helper(
169 | model,
170 | tokens,
171 | position_ids,
172 | inference_params,
173 | recv_buffer=recv_buffer,
174 | lm_logits_mask=lm_logits_mask,
175 | )
176 | # Update the sequence length offset.
177 | if inference_params is not None:
178 | inference_params.sequence_len_offset += tokens.size(1)
179 |
180 | lm_output = None
181 | if mpu.is_pipeline_last_stage():
182 | lm_output = outputs
183 |
184 | return lm_output
185 |
--------------------------------------------------------------------------------
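
`_no_pipelining_forward_step` above advances `inference_params.sequence_len_offset` by the number of tokens it just processed, which is what allows the KV cache to be reused: the prompt is run through the model once, and every later call feeds only the newly sampled token. Below is a framework-free sketch of that incremental-decoding bookkeeping; the toy model and cache are stand-ins for illustration, not the repo's GPTModel or InferenceParams.

```
import torch

class ToyCache:
    """Stand-in for InferenceParams: tracks how much of the sequence is cached."""
    def __init__(self) -> None:
        self.sequence_len_offset = 0

def toy_forward(tokens: torch.Tensor, cache: ToyCache) -> torch.Tensor:
    # A real model would attend over cached keys/values for positions
    # [0, cache.sequence_len_offset) plus the new tokens; here we just
    # return fake logits of shape [batch, new_tokens, vocab].
    batch, new_tokens = tokens.shape
    return torch.randn(batch, new_tokens, 16)

cache = ToyCache()
prompt = torch.randint(0, 16, (1, 5))        # feed the whole prompt once
logits = toy_forward(prompt, cache)
cache.sequence_len_offset += prompt.size(1)  # same bookkeeping as forward_step.py

for _ in range(3):                           # then feed one new token per step
    next_token = logits[:, -1].argmax(dim=-1, keepdim=True)
    logits = toy_forward(next_token, cache)
    cache.sequence_len_offset += 1

print(cache.sequence_len_offset)             # 8 = 5 prompt tokens + 3 generated
```
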
/megatron/text_generation/inference_params.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, Dict, Tuple, Any
2 |
3 | import torch
4 | from torch import Tensor
5 |
6 |
7 | class InferenceParams:
8 | """Inference parameters that are passed to the main model in order
9 |     to efficiently calculate and store the context during inference."""
10 |
11 | def __init__(
12 | self,
13 | max_batch_size: int,
14 | max_sequence_len: int,
15 | lengths_per_sample: Optional[Tensor] = None,
16 | fused_ft_kernel: bool = False,
17 | ) -> None:
18 |         # fused_ft_kernel: whether to use the FasterTransformer fused single-query attention kernel.
19 |         """Sequence and batch offsets start at zero and the key/value
20 |         memory dict starts empty; call reset() to clear the offsets
21 |         between independent generations."""
22 | self.max_sequence_len = max_sequence_len
23 | self.max_batch_size = max_batch_size
24 | self.sequence_len_offset = 0
25 | self.batch_size_offset = 0
26 | self.key_value_memory_dict: Dict[int, Tuple[torch.Tensor, torch.Tensor]] = {}
27 |         # This is used in case of an encoder-decoder model. The encoder output can be cached
28 | # and reused for multiple decoder forward passes.
29 | self.encoder_hidden_state: Dict[Any, Any] = {}
30 | self.return_encoder_hidden_state = False
31 | self.fused_ft_kernel = fused_ft_kernel
32 |         self.lengths_per_sample: Optional[Tensor] = lengths_per_sample
33 | # Raise import error at initialization time instead of the 1st generation time.
34 | if fused_ft_kernel:
35 | try:
36 | # pylint: disable=import-outside-toplevel,unused-import
37 | import ft_attention # type: ignore
38 | except ImportError as exc:
39 | raise ImportError("Please install ft_attention from the FlashAttention repo.") from exc
40 |
41 | def reset(self) -> None:
42 | self.sequence_len_offset = 0
43 | self.batch_size_offset = 0
44 |
45 | def swap_key_value_dict(self, batch_idx: Any) -> None:
46 |         """Reorder the cached keys/values along the batch dimension."""
47 | if len(self.key_value_memory_dict) == 0:
48 |             raise ValueError("should not swap when dict is empty")
49 |
50 | for layer_number in self.key_value_memory_dict.keys(): # pylint: disable=consider-using-dict-items
51 | inference_key_memory, inference_value_memory = self.key_value_memory_dict[layer_number]
52 | assert len(batch_idx) == inference_key_memory.shape[1] ## make sure batch size is the same
53 | new_inference_key_memory = inference_key_memory[:, batch_idx]
54 | new_inference_value_memory = inference_value_memory[:, batch_idx]
55 | self.key_value_memory_dict[layer_number] = (new_inference_key_memory, new_inference_value_memory)
56 |
--------------------------------------------------------------------------------
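
`swap_key_value_dict` above reorders the cached key/value tensors along the batch dimension (dimension 1 in the `[seq, batch, ...]` layout). A small runnable sketch of that reindexing with toy tensors; the layer number, shapes, and ordering below are made up.

```
import torch

# Toy per-layer KV cache laid out as [max_seq, batch, head_dim], mirroring
# the [s, b, ...] convention used above. Shapes here are illustrative.
key_value_memory_dict = {
    1: (torch.zeros(8, 4, 2), torch.zeros(8, 4, 2)),
}
batch_idx = torch.tensor([3, 1, 0, 2])  # new ordering of the 4 batch rows

for layer_number, (key_mem, value_mem) in key_value_memory_dict.items():
    assert len(batch_idx) == key_mem.shape[1]  # batch size must match
    key_value_memory_dict[layer_number] = (key_mem[:, batch_idx], value_mem[:, batch_idx])

print(key_value_memory_dict[1][0].shape)  # torch.Size([8, 4, 2])
```
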
/megatron/text_generation/sampling.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright (c) 2023 ADEPT AI LABS INC.
3 | # This file is based on code by the authors denoted below and has been modified from its original version.
4 | #
5 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
6 | #
7 | # Licensed under the Apache License, Version 2.0 (the "License");
8 | # you may not use this file except in compliance with the License.
9 | # You may obtain a copy of the License at
10 | #
11 | # http://www.apache.org/licenses/LICENSE-2.0
12 | #
13 | # Unless required by applicable law or agreed to in writing, software
14 | # distributed under the License is distributed on an "AS IS" BASIS,
15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16 | # See the License for the specific language governing permissions and
17 | # limitations under the License.
18 |
19 | """Sampling utilities.
20 | Part of this code is inspired by:
21 | - https://github.com/ari-holtzman/degen/blob/master/gen.py
22 | - https://huggingface.co/transformers/_modules/transformers/generation_logits_process.html
23 | """
24 |
25 |
26 | from typing import Optional
27 | import torch
28 |
29 |
30 | def modify_logits_for_top_k_filtering(logits: torch.Tensor, top_k: int) -> None:
31 |     """Set the logits of all non-top-k values to -inf."""
32 |
33 | filter_ = logits < torch.topk(logits, top_k)[0][..., -1, None]
34 | logits.masked_fill_(filter_, float("-Inf"))
35 |
36 |
37 | def modify_logits_for_top_p_filtering(logits: torch.Tensor, top_p: float) -> None:
38 |     """Set the logits of all non-top-p values to -inf."""
39 |
40 | # First sort and calculate cumulative sum of probabilities.
41 | sorted_logits, sorted_indices = torch.sort(logits, descending=True)
42 | cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
43 |
44 |     # Filtering based on the cumulative sum.
45 |     filter_ = cumulative_probs > top_p
46 |     # Shift the mask right by one so the first token that pushes the
47 |     # cumulative probability past top_p is kept as well; this guarantees
48 |     # the retained probability mass is at least top_p (same behavior as
49 |     # https://github.com/ari-holtzman/degen/blob/master/gen.py).
50 |     filter_[:, 1:] = filter_[:, :-1].clone()
51 |     # Make sure we have at least one token to select from.
52 |     filter_[..., 0] = 0
53 |
54 | # Fill in the filtered part
55 | filter_ = filter_.scatter(1, sorted_indices, filter_)
56 | logits.masked_fill_(filter_, float("-Inf"))
57 |
58 |
59 | def sample(
60 | logits: torch.Tensor, top_k: int = 0, top_p: float = 0.0, temperature: float = 1.0, vocab_size: Optional[int] = None
61 | ) -> torch.Tensor:
62 | """Sample and generate a token.
63 | Note: logits has the dimension [b, v] where b is the batch size
64 | and v is the vocabulary size.
65 | If vocab_size is provided, we will make sure the sample that is
66 |     generated is in [0, vocab_size). This avoids out-of-vocabulary
67 |     generations due to vocabulary padding.
68 | """
69 |
70 | # Check logits for consistency.
71 | assert logits.ndim == 2, "expected the logits to be of [b, v] shape."
72 | assert logits.type() == "torch.cuda.FloatTensor", f"input logits should be floats. Got: {logits.type()}"
73 |
74 | # Greedy is just simple argmax.
75 | if top_k == 1:
76 | assert top_p == 0.0, "cannot set both greedy and top-p samplings."
77 | samples = torch.argmax(logits, dim=-1)
78 |
79 | # Top-k or top-p sampling.
80 | else:
81 |         # Clone so we do not modify the inputs.
82 | logits = logits.clone()
83 | # Apply temperature in place.
84 | if temperature != 1.0:
85 | logits.div_(temperature)
86 |
87 | if top_k > 1:
88 | assert top_p == 0.0, "cannot set both top-k and top-p samplings."
89 | assert top_k <= logits.size(1), "top-k is larger than logit size."
90 | if vocab_size:
91 | assert top_k < vocab_size, "top-k is larger than vocab size."
92 | modify_logits_for_top_k_filtering(logits, top_k)
93 |
94 | elif top_p > 0.0:
95 | assert top_p <= 1.0, "top-p should be in (0, 1]."
96 | modify_logits_for_top_p_filtering(logits, top_p)
97 |
98 | # After filtering, we need to recalculate the distribution.
99 | probs = logits.softmax(dim=-1)
100 | samples = torch.multinomial(probs, num_samples=1).view(-1)
101 |
102 |         # If vocab size is provided, make sure the samples are
103 |         # in the range [0, vocab_size).
104 | if vocab_size:
105 | samples = torch.clamp(samples, min=0, max=(vocab_size - 1))
106 |
107 | return samples
108 |
--------------------------------------------------------------------------------
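
`sample` above asserts CUDA float logits, so as a CPU-only illustration, here is the same shifted top-p mask applied to a toy distribution, showing which tokens survive the filter at `top_p = 0.7`. The probabilities are made up and this is a standalone sketch, not a call into the repo's functions.

```
import torch

logits = torch.log(torch.tensor([[0.5, 0.3, 0.1, 0.05, 0.05]]))
top_p = 0.7

sorted_logits, sorted_indices = torch.sort(logits, descending=True)
cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)

# Same shifted mask as modify_logits_for_top_p_filtering: keep every token up
# to and including the first one that crosses the top_p threshold.
filter_ = cumulative_probs > top_p
filter_[:, 1:] = filter_[:, :-1].clone()
filter_[..., 0] = 0
filter_ = filter_.scatter(1, sorted_indices, filter_)

masked = logits.masked_fill(filter_, float("-Inf"))
print(masked.softmax(dim=-1))  # only the 0.5 and 0.3 tokens keep nonzero mass
```
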
/megatron/text_generation/tokenization.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright (c) 2023 ADEPT AI LABS INC.
3 | # This file is based on code by the authors denoted below and has been modified from its original version.
4 | #
5 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
6 | #
7 | # Licensed under the Apache License, Version 2.0 (the "License");
8 | # you may not use this file except in compliance with the License.
9 | # You may obtain a copy of the License at
10 | #
11 | # http://www.apache.org/licenses/LICENSE-2.0
12 | #
13 | # Unless required by applicable law or agreed to in writing, software
14 | # distributed under the License is distributed on an "AS IS" BASIS,
15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16 | # See the License for the specific language governing permissions and
17 | # limitations under the License.
18 |
19 | """Tokenization utilities."""
20 |
21 | import re
22 | from typing import List, Tuple, Optional, Any, NamedTuple, Dict
23 | import numpy as np
24 |
25 | import torch
26 |
27 | from megatron import get_args, get_tokenizer
28 | from megatron.tokenizer import AbstractTokenizer
29 | from megatron.mpu.communication import broadcast_int_list, broadcast_tensor
30 |
31 | TEXT_REPR_BBOX_OPEN = "<box>"
32 | TEXT_REPR_BBOX_CLOSE = "</box>"
33 | TEXT_REPR_POINT_OPEN = "<point>"
34 | TEXT_REPR_POINT_CLOSE = "</point>"
35 |
36 |
37 | def convert_generations_to_human_readable_tokens(
38 | generations: List[List[int]], bos_token: str
39 | ) -> List[List[str]]:
40 | """Convert the list of integers that a model outputs into a human-readable list of tokens.
41 | Args:
42 | generations: One list per batch, each of which contains a list of integers to detokenize.
43 | bos_token: The BOS token that we are using.
44 | Return:
45 | A list of lists.
46 | """
47 | new_generations = []
48 | tokenizer = get_tokenizer()
49 |
50 | for generation in generations:
51 | tokens: List[str] = []
52 | for i, int_token in enumerate(generation):
53 | token = tokenizer.inv_vocab[int_token]
54 |             # Drop the leading SentencePiece "▁" marker when the token starts the generation or follows BOS.
55 | if token[0] == "▁" and (i == 0 or tokens[i - 1] == bos_token):
56 | token = token[1:]
57 | # continue processing normally.
58 | token = re.sub("▁", " ", token)
59 | tokens.append(token)
60 | new_generations.append(tokens)
61 | return new_generations
62 |
63 |
64 | # ====================================================================== TOKENIZATION #
65 |
66 |
67 | def tokenize_prompts(
68 | prompts: Optional[List[List[str]]],
69 | max_tokens_to_generate: int,
70 | max_position_embeddings: int,
71 | add_BOS: bool,
72 | rank: int = 0,
73 | ) -> Tuple[torch.Tensor, torch.Tensor]:
74 |     """Tokenize prompts and make them available on all ranks."""
75 | assert add_BOS is not None
76 |
77 | # On all ranks set to None so we can pass them to functions
78 | sizes_list = None
79 | prompts_tokens_cuda_long_tensor = None
80 | prompts_length_cuda_long_tensor = None
81 |
82 | # On the specified rank, build the above.
83 | if torch.distributed.get_rank() == rank:
84 | assert prompts is not None
85 | assert max_tokens_to_generate is not None
86 | # Tensor of tokens padded and their unpadded length.
87 | (
88 | prompts_tokens_cuda_long_tensor,
89 | prompts_length_cuda_long_tensor,
90 | ) = _tokenize_prompts_and_batch(
91 | prompts,
92 | max_tokens_to_generate,
93 | max_position_embeddings,
94 | add_BOS,
95 | )
96 | # We need the sizes of these tensors for the broadcast
97 | sizes_list = [
98 | prompts_tokens_cuda_long_tensor.size(0), # Batch size
99 | prompts_tokens_cuda_long_tensor.size(1), # Num subsequences
100 | prompts_tokens_cuda_long_tensor.size(2), # Sequence length
101 | ]
102 | # First, broadcast the sizes.
103 | sizes_tensor = broadcast_int_list(3, int_list=sizes_list, rank=rank)
104 |
105 | # Now that we have the sizes, we can broadcast the tokens
106 | # and length tensors.
107 | sizes = sizes_tensor.tolist()
108 | prompts_tokens_cuda_long_tensor = broadcast_tensor(
109 | sizes, torch.int64, tensor=prompts_tokens_cuda_long_tensor, rank=rank
110 | )
111 | prompts_length_cuda_long_tensor = broadcast_tensor(
112 | sizes[:2], torch.int64, tensor=prompts_length_cuda_long_tensor, rank=rank
113 | )
114 |
115 | return prompts_tokens_cuda_long_tensor, prompts_length_cuda_long_tensor
116 |
117 |
118 | def _tokenize_prompts_and_batch(
119 | prompts: List[List[str]],
120 | max_tokens_to_generate: int,
121 | max_position_embeddings: int,
122 |     add_BOS: bool,
123 | ) -> Tuple[torch.Tensor, torch.Tensor]:
124 | """
125 | Given a set of prompts and number of tokens to generate:
126 | - tokenize prompts
127 | - set the sequence length to be the max of length of prompts
128 | plus the number of tokens we would like to generate
129 | - pad all the sequences to this length so we can convert them
130 | into a 3D tensor.
131 | """
132 | args = get_args()
133 | # Tokenize all the prompts.
134 | tokenizer = get_tokenizer()
135 |
136 | transformed_prompt_tokens = [
137 | [tokenizer.tokenize(prompt) for prompt in prompt_seq] for prompt_seq in prompts
138 | ]
139 |
140 | if add_BOS:
141 | if args.add_bos_prompt_token is not None:
142 | bos_token = tokenizer.vocab[args.add_bos_prompt_token]
143 | else:
144 | bos_token = tokenizer.eod
145 | prompts_tokens = [
146 | [[bos_token] + x for x in prompt_seq]
147 | for prompt_seq in transformed_prompt_tokens
148 | ]
149 | else:
150 | prompts_tokens = transformed_prompt_tokens
151 |
152 |     # Now we have a list of lists of tokens, where each inner list has a
153 |     # different size. We want to extend these lists to:
154 | # - incorporate the tokens that need to be generated
155 | # - make all the sequences equal length.
156 | # Get the prompts length.
157 |
158 | prompts_length = [
159 | [len(x) for x in prompts_tokens_seq] for prompts_tokens_seq in prompts_tokens
160 | ]
161 | # Get the max prompts length.
162 | max_prompt_len: int = np.max(prompts_length)
163 |     # Number of tokens in each sample of the batch.
164 | samples_length = min(
165 | max_prompt_len + max_tokens_to_generate, max_position_embeddings
166 | )
167 | if (
168 | max_prompt_len + max_tokens_to_generate > max_position_embeddings
169 | and torch.distributed.get_rank() == 0
170 | ):
171 | print(
172 | f"Max subsequence prompt length of {max_prompt_len} + max tokens to generate {max_tokens_to_generate}",
173 | f"exceeds context length of {max_position_embeddings}. Will generate as many tokens as possible.",
174 | )
175 | # Now update the list of list to be of the same size: samples_length.
176 | for prompt_tokens_seq, prompts_length_seq in zip(prompts_tokens, prompts_length):
177 | for prompt_tokens, prompt_length in zip(prompt_tokens_seq, prompts_length_seq):
178 | if len(prompt_tokens) > samples_length:
179 | raise ValueError(
180 | "Length of subsequence prompt exceeds sequence length."
181 | )
182 | padding_size = samples_length - prompt_length
183 | prompt_tokens.extend([tokenizer.eod] * padding_size)
184 |
185 | # Now we are in a structured format, we can convert to tensors.
186 | prompts_tokens_tensor = torch.cuda.LongTensor(prompts_tokens)
187 | prompts_length_tensor = torch.cuda.LongTensor(prompts_length)
188 |
189 | return prompts_tokens_tensor, prompts_length_tensor
190 |
--------------------------------------------------------------------------------
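
`_tokenize_prompts_and_batch` above pads every prompt to a single common length of max prompt length plus `max_tokens_to_generate` (capped at `max_position_embeddings`), so the whole batch fits in one rectangular tensor while the unpadded lengths are returned alongside it. A minimal sketch of that padding arithmetic with made-up token ids and a hypothetical pad/EOD id of 0.

```
# Toy version of the padding logic: token ids and the pad/EOD id are made up.
EOD = 0
prompts_tokens = [[5, 6, 7], [8, 9]]
max_tokens_to_generate = 4
max_position_embeddings = 16

prompts_length = [len(p) for p in prompts_tokens]
samples_length = min(max(prompts_length) + max_tokens_to_generate, max_position_embeddings)

for tokens, length in zip(prompts_tokens, prompts_length):
    tokens.extend([EOD] * (samples_length - length))

print(prompts_tokens)   # [[5, 6, 7, 0, 0, 0, 0], [8, 9, 0, 0, 0, 0, 0]]
print(prompts_length)   # [3, 2] -- the unpadded lengths are kept alongside
```
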
/megatron/tokenizer/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 |
17 | from .tokenizer import build_tokenizer, AbstractTokenizer
18 |
--------------------------------------------------------------------------------
/megatron/tokenizer/tokenizer.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Megatron tokenizers."""
17 |
18 | from abc import ABCMeta
19 | from abc import abstractmethod
20 |
21 | from typing import Dict, Any, List, Optional
22 | import sentencepiece as spm # type: ignore
23 |
24 |
25 | class AbstractTokenizer(metaclass=ABCMeta):
26 | """Abstract class for tokenizer."""
27 |
28 | def __init__(self, name: str) -> None:
29 | self.name = name
30 | super().__init__()
31 |
32 | @property
33 | @abstractmethod
34 | def vocab_size(self) -> int:
35 | """Number of distinct tokens in the vocabulary."""
36 | pass
37 |
38 | @property
39 | @abstractmethod
40 | def vocab(self) -> Dict[str, int]:
41 | """Dictionary from vocab text token to id token."""
42 | pass
43 |
44 | @property
45 | @abstractmethod
46 | def inv_vocab(self) -> Dict[int, str]:
47 | """Dictionary from vocab id token to text token."""
48 | pass
49 |
50 | @abstractmethod
51 | def tokenize(self, text: str) -> List[int]:
52 | """Tokenize the text."""
53 | pass
54 |
55 | def encode(self, text: str) -> List[int]:
56 | """Encode the text."""
57 | return self.tokenize(text)
58 |
59 | def __call__(self, text: str) -> List[int]:
60 | """Syntactic sugar for tokenize."""
61 | return self.tokenize(text)
62 |
63 | @abstractmethod
64 | def detokenize(self, token_ids: List[int]) -> str:
65 | """Transform tokens back to a string."""
66 | pass
67 |
68 | def decode(self, ids: List[int]) -> str:
69 | """Decode the ids."""
70 | return self.detokenize(ids)
71 |
72 | @property
73 | def cls(self) -> int:
74 | raise NotImplementedError(f"CLS is not provided for {self.name} tokenizer")
75 |
76 | @property
77 | def sep(self) -> int:
78 | raise NotImplementedError(f"SEP is not provided for {self.name} tokenizer")
79 |
80 | @property
81 | def pad(self) -> int:
82 | raise NotImplementedError(f"PAD is not provided for {self.name} tokenizer")
83 |
84 | @property
85 | def eod(self) -> int:
86 | raise NotImplementedError(f"EOD is not provided for {self.name} tokenizer")
87 |
88 | @property
89 | def eod_token_text(self) -> str:
90 | """The EOD token string."""
91 | return self.decode([self.eod])
92 |
93 | @property
94 | def mask(self) -> int:
95 | """Get the mask token."""
96 | raise NotImplementedError(f"MASK is not provided for {self.name} tokenizer")
97 |
98 | @property
99 | def pad_token(self) -> str:
100 | """Get the pad token."""
101 | raise NotImplementedError
102 |
103 | @property
104 | def max_len_single_sentence(self) -> int:
105 | """Get the max length of a single sentence."""
106 | raise NotImplementedError
107 |
108 | def __len__(self) -> int:
109 | """Get the length of the tokenizer."""
110 | raise NotImplementedError
111 |
112 | @property
113 | def eos_token_id(self) -> int:
114 | """Get the id of the EOS token."""
115 | raise NotImplementedError
116 |
117 | @property
118 | def eos_token(self) -> str:
119 | """Get the EOS token."""
120 | raise NotImplementedError
121 |
122 |
123 | class _SentencePieceTokenizer(AbstractTokenizer):
124 |     """SentencePiece tokenizer."""
125 |
126 | def __init__(self, model_file: str) -> None:
127 | name = "Sentence Piece"
128 | super().__init__(name)
129 |
130 | # no-member is firing here but the code works fine!
131 | # pylint: disable=no-member
132 | self._tokenizer = spm.SentencePieceProcessor()
133 | self._tokenizer.load(model_file)
134 | self._vocab_size = self._tokenizer.vocab_size()
135 | self._tokens = [self._tokenizer.id_to_piece(t) for t in range(self.vocab_size)]
136 | self._vocab = {t: i for i, t in enumerate(self._tokens)}
137 | self._eod_id = None
138 | # look for end of document id
139 | for idx, token in enumerate(self._tokens):
140 | if token == "|ENDOFTEXT|":
141 | self._eod_id = idx
142 | break
143 | if self._eod_id is None:
144 | self._eod_id = self._tokenizer.eos_id()
145 | assert self._eod_id is not None
146 |
147 | @property
148 | def tokenizer(self) -> spm.SentencePieceProcessor:
149 | return self._tokenizer
150 |
151 | @property
152 | def vocab_size(self) -> int:
153 | return int(self._vocab_size)
154 |
155 | @property
156 | def vocab(self) -> Dict[str, int]:
157 | return self._vocab
158 |
159 | @property
160 | def inv_vocab(self): # type: ignore
161 | return self._tokens
162 |
163 | def tokenize(self, text: str): # type: ignore
164 | # pylint: disable=bare-except, no-member
165 | try:
166 | tokenized = self._tokenizer.encode_as_ids(text)
167 | except:
168 | tokenized = None
169 | return tokenized
170 |
171 | def pieces(self, text: str) -> List[str]:
172 | # pylint: disable=no-member
173 | pieces: List[str] = self._tokenizer.encode_as_pieces(text)
174 | return pieces
175 |
176 | def detokenize(self, token_ids: List[int]) -> str:
177 | # pylint: disable=no-member
178 | return self._tokenizer.decode_ids(token_ids) # type: ignore
179 |
180 | @property
181 | def eod(self) -> int:
182 | return self._eod_id # type: ignore
183 |
184 | @property
185 | def eos_token_id(self) -> int:
186 | """Id of the end of sentence token in the vocabulary."""
187 | eos_id: int = self._eod_id # type: ignore
188 | return eos_id
189 |
190 |
191 | def megatron_initialize_tokenizer(
192 | tokenizer_type: str,
193 | sp_model_file: Optional[str] = None,
194 | ) -> AbstractTokenizer:
195 | """Initialize the tokenizer."""
196 | if tokenizer_type == "SentencePiece":
197 | assert sp_model_file is not None
198 | tokenizer = _SentencePieceTokenizer(sp_model_file)
199 | else:
200 | raise NotImplementedError(f"{tokenizer_type} tokenizer is not implemented.")
201 | return tokenizer
202 |
203 |
204 | def _vocab_size_with_padding(orig_vocab_size: int, args: Any) -> int:
205 |     """Pad the vocab size so it is divisible by the tensor model parallel
206 |     size and still has a GPU-friendly size."""
207 |
208 | after = orig_vocab_size
209 | multiple = args.make_vocab_size_divisible_by * args.tensor_model_parallel_size
210 | while (after % multiple) != 0:
211 | after += 1
212 | if args.rank == 0:
213 | print(
214 | f" > padded vocab (sz: {orig_vocab_size}) w/ {after - orig_vocab_size} dummy toks (new sz: {after})",
215 | flush=True,
216 | )
217 | return after
218 |
219 |
220 | def build_tokenizer(args: Any) -> AbstractTokenizer:
221 | """Initialize tokenizer."""
222 | if args.rank == 0:
223 | print(
224 | f"> building {args.tokenizer_type} tokenizer ...",
225 | flush=True,
226 | )
227 |
228 | # Select and instantiate the tokenizer.
229 | if args.tokenizer_type == "SentencePiece":
230 | assert args.sp_model_file is not None
231 | tokenizer = _SentencePieceTokenizer(f"{args.sp_model_file}")
232 | else:
233 | raise NotImplementedError(
234 | f"{args.tokenizer_type} tokenizer is not implemented."
235 | )
236 |
237 | # Add vocab size.
238 | args.padded_vocab_size = _vocab_size_with_padding(tokenizer.vocab_size, args)
239 |
240 | return tokenizer
241 |
--------------------------------------------------------------------------------
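
`_vocab_size_with_padding` above rounds the vocabulary up to a multiple of `make_vocab_size_divisible_by * tensor_model_parallel_size`, so every tensor-parallel shard of the embedding has the same, GPU-friendly size. A worked example of that arithmetic; the numbers are illustrative, not the released model's actual settings.

```
# Illustrative numbers only; the real values come from the training args.
orig_vocab_size = 70_003
make_vocab_size_divisible_by = 128
tensor_model_parallel_size = 2

multiple = make_vocab_size_divisible_by * tensor_model_parallel_size  # 256
padded = orig_vocab_size
while padded % multiple != 0:
    padded += 1

print(padded)                    # 70144
print(padded - orig_vocab_size)  # 141 dummy tokens added
```
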
/megatron/training.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright (c) 2023 ADEPT AI LABS INC.
3 | # This file is based on code by the authors denoted below and has been modified from its original version.
4 | #
5 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
6 | #
7 | # Licensed under the Apache License, Version 2.0 (the "License");
8 | # you may not use this file except in compliance with the License.
9 | # You may obtain a copy of the License at
10 | #
11 | # http://www.apache.org/licenses/LICENSE-2.0
12 | #
13 | # Unless required by applicable law or agreed to in writing, software
14 | # distributed under the License is distributed on an "AS IS" BASIS,
15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16 | # See the License for the specific language governing permissions and
17 | # limitations under the License.
18 |
19 | """Pretrain utilities."""
20 |
21 | import math
22 | import sys
23 | from typing import Any, Callable, Dict, List, Optional, Tuple, Union
24 |
25 | import torch
26 | from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
27 |
28 | from megatron import (
29 | get_args,
30 | mpu,
31 | )
32 | from megatron.checkpointing import (
33 | load_state_dicts_and_update_args,
34 | update_model_and_optim_from_loaded_data,
35 | )
36 | from megatron.initialize import initialize_megatron
37 | from megatron.model import DistributedDataParallel as LocalDDP
38 | from megatron.model import Float16Module, ModelType
39 |
40 | from megatron.utils import (
41 | unwrap_model,
42 | )
43 |
44 |
45 | def get_model(
46 | model_provider_func,
47 | model_type=ModelType.encoder_or_decoder,
48 | wrap_with_ddp=True,
49 | ):
50 | """Build the model."""
51 | args = get_args()
52 | args.model_type = model_type
53 |
54 | # Build model.
55 | if (
56 | mpu.get_pipeline_model_parallel_world_size() > 1
57 | and args.virtual_pipeline_model_parallel_size is not None
58 | ):
59 | assert (
60 | model_type != ModelType.encoder_and_decoder
61 | ), "Interleaved schedule not supported for model with both encoder and decoder"
62 | model = []
63 | for i in range(args.virtual_pipeline_model_parallel_size):
64 | mpu.set_virtual_pipeline_model_parallel_rank(i)
65 | # Set pre_process and post_process only after virtual rank is set.
66 | pre_process = mpu.is_pipeline_first_stage()
67 | post_process = mpu.is_pipeline_last_stage()
68 | this_model = model_provider_func(
69 | pre_process=pre_process, post_process=post_process
70 | )
71 | this_model.model_type = model_type
72 | model.append(this_model)
73 | else:
74 | pre_process = mpu.is_pipeline_first_stage()
75 | post_process = mpu.is_pipeline_last_stage()
76 | add_encoder = True
77 | add_decoder = True
78 | if model_type == ModelType.encoder_and_decoder:
79 | if mpu.get_pipeline_model_parallel_world_size() > 1:
80 | assert (
81 | args.pipeline_model_parallel_split_rank is not None
82 | ), "Split rank needs to be specified for model with both encoder and decoder"
83 | rank = mpu.get_pipeline_model_parallel_rank()
84 | split_rank = args.pipeline_model_parallel_split_rank
85 | world_size = mpu.get_pipeline_model_parallel_world_size()
86 | pre_process = rank == 0 or rank == split_rank
87 | post_process = (rank == (split_rank - 1)) or (rank == (world_size - 1))
88 | add_encoder = mpu.is_pipeline_stage_before_split()
89 | add_decoder = mpu.is_pipeline_stage_after_split()
90 | model = model_provider_func(
91 | pre_process=pre_process,
92 | post_process=post_process,
93 | add_encoder=add_encoder,
94 | add_decoder=add_decoder,
95 | )
96 | else:
97 | model = model_provider_func(
98 | pre_process=pre_process, post_process=post_process
99 | )
100 | model.model_type = model_type
101 |
102 | if not isinstance(model, list):
103 | model = [model]
104 |
105 | # Set tensor model parallel attributes if not set.
106 | # Only parameters that are already tensor model parallel have these
107 | # attributes set for them. We should make sure the default attributes
108 | # are set for all params so the optimizer can use them.
109 | for model_module in model:
110 | for param in model_module.parameters():
111 | mpu.set_defaults_if_not_set_tensor_model_parallel_attributes(param)
112 |
113 | # Print number of parameters.
114 | if mpu.get_data_parallel_rank() == 0:
115 | print(
116 | " > number of parameters on (tensor, pipeline) "
117 | "model parallel rank ({}, {}): {}".format(
118 | mpu.get_tensor_model_parallel_rank(),
119 | mpu.get_pipeline_model_parallel_rank(),
120 | sum(
121 | [
122 | sum([p.nelement() for p in model_module.parameters()])
123 | for model_module in model
124 | ]
125 | ),
126 | ),
127 | flush=True,
128 | )
129 |
130 | # GPU allocation.
131 | for model_module in model:
132 | model_module.cuda(torch.cuda.current_device())
133 |
134 | # Fp16 conversion.
135 | if args.fp16 or args.bf16:
136 | model = [Float16Module(model_module, args) for model_module in model]
137 |
138 | if wrap_with_ddp:
139 | if args.DDP_impl == "torch":
140 | i = torch.cuda.current_device()
141 | model = [
142 | torchDDP(
143 | model_module,
144 | device_ids=[i],
145 | output_device=i,
146 | process_group=mpu.get_data_parallel_group(),
147 | )
148 | for model_module in model
149 | ]
150 |
151 | elif args.DDP_impl == "local":
152 | model = [
153 | LocalDDP(
154 | model_module,
155 | args.accumulate_allreduce_grads_in_fp32,
156 | args.use_contiguous_buffers_in_local_ddp,
157 | )
158 | for model_module in model
159 | ]
160 |             # Broadcast params from the data parallel src rank to the other data parallel ranks.
161 | if args.data_parallel_random_init:
162 | for model_module in model:
163 | model_module.broadcast_params()
164 | else:
165 | raise NotImplementedError(
166 | "Unknown DDP implementation specified: "
167 | "{}. Exiting.".format(args.DDP_impl)
168 | )
169 |
170 | return model
171 |
--------------------------------------------------------------------------------
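
`get_model` above always returns a list of model chunks (one per virtual pipeline stage) and reports the parameter count by summing `p.nelement()` over every chunk. A tiny CPU-only sketch of that counting pattern, using small `torch.nn.Linear` layers as stand-ins for the real model chunks.

```
import torch

# Stand-in for the list of model chunks that get_model returns; the real code
# holds GPTModel instances, this just uses small Linear layers for illustration.
model = [torch.nn.Linear(4, 3), torch.nn.Linear(3, 2)]

num_params = sum(
    sum(p.nelement() for p in model_module.parameters())
    for model_module in model
)
print(num_params)  # (4*3 + 3) + (3*2 + 2) = 23
```
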
/prompts.md:
--------------------------------------------------------------------------------
1 | # Prompt Examples
2 |
3 | ## MMLU
4 | You are a bot designed to answer questions by choosing between A, B, C, and D. Only one of these answers from the provided choices is correct. Reply with the letter of the correct answer.
5 | Question: Alfred and Ben don't know each other but are each considering asking the lovely Charlene to the school prom. The probability that at least one of them will ask her is 0.72. The probability that they both ask her is 0.18. The probability that Alfred asks her is 0.6. What is the probability that Ben asks Charlene to the prom?
6 | A: 0.78
7 | B: 0.3
8 | C: 0.24
9 | D: 0.48
10 | Answer: B
11 |
12 | Question: A telephone survey of 400 registered voters showed that 256 had not yet made up their minds 1 month before the election. How sure can we be that between 60% and 68% of the electorate were still undecided at that time?
13 | A: 2.4%
14 | B: 8.0%
15 | C: 64.0%
16 | D: 90.4%
17 | Answer:
18 |
19 |
20 | You are a bot designed to answer questions by choosing between A, B, C, and D. Only one of these answers from the provided choices is correct. Reply with the letter of the correct answer.
21 | Question: The procedure involving repeated presentation of a stimulus to the client until the attractiveness of that stimulus is reduced is best described as
22 | A: stimulus satiation
23 | B: response-prevention
24 | C: flooding
25 | D: implosion
26 | Answer: A
27 |
28 | Question: If, during a postexamination discussion with parents, a psychologist establishes that a child’s new pediatrician is apparently unaware of the child's history of brain damage. which is very important in understanding the problem situation, the psychologist should
29 | A: tell the parents that he/she will inform the pediatrician
30 | B: urge the parents to grant him/her permission to inform the pediatrician
31 | C: cell the parents char be/she is legally obligated to inform the pediatrician
32 | D: cell the parents that it is their responsibility to inform the pediatrician
33 | Answer:
34 |
35 | ## Winogrande
36 | You are a bot that is responsible for answering fill-in-the-blank questions. You are provided with a sentence and two possible fill-in-the-blank options. Your task is to return 1 or 2 as the answer.
37 | Q: The store had 80 platters but only 2 bowls left in stock because the _ were in high demand.. 1 = platters and 2 = bowls. Answer:
38 | 2
39 |
40 | Q: The smell in the kitchen of the home is unbearable, while the laundry room smells fine. The _ must have been cleaned longer ago.. 1 = kitchen and 2 = laundry room. Answer:
41 |
42 |
43 |
44 | You are a bot that is responsible for answering fill-in-the-blank questions. You are provided with a sentence and two possible fill-in-the-blank options. Your task is to return 1 or 2 as the answer.
45 | Q: John painted the pole red close to the color of the wall and painted the frame white and now the _ is similar.. 1 = frame and 2 = pole. Answer:
46 | 2
47 |
48 | Q: Benjamin has a spouse and Kyle is single after being divorced, so _ is celebrating their independence this year.. 1 = Benjamin and 2 = Kyle. Answer:
49 |
50 | ## Arc Easy/Challenge
51 | You are a bot designed to answer multiple choice questions. Given a series of options, you answer by choosing A, B, C, or D. Some examples are below.
52 | Question: Which process uses carbon from the air to make food for plants?
53 | A: growth
54 | B: respiration
55 | C: decomposition
56 | D: photosynthesis
57 | Answer: D
58 |
59 | Question: An object composed mainly of ice is orbiting the Sun in an elliptical path. This object is most likely
60 | A: a planet.
61 | B: an asteroid.
62 | C: a meteor.
63 | D: a comet.
64 | Answer:
65 |
66 |
67 |
68 | You are a bot designed to answer multiple choice questions. Given a series of options, you answer by choosing A, B, C, or D. Some examples are below.
69 | Question: Decomposers are important in the food chain because they
70 | A: produce their own food using light from the Sun.
71 | B: stop the flow of energy from one organism to another.
72 | C: break down dead organisms and recycle nutrients into the soil.
73 | D: are microscopic and other organisms cannot consume them.
74 | Answer: C
75 |
76 | Question: A wire is wrapped around a metal nail and connected to a battery. If the battery is active, the nail will
77 | A: vibrate.
78 | B: create sound.
79 | C: produce heat.
80 | D: become magnetic.
81 | Answer:
82 |
83 | ## Humaneval
84 | ```
85 | You are a code-writing bot. Given a function signature, and a docstring, complete the program body. Some examples are given below.
86 | def sort_array(arr):
87 | """
88 | In this Kata, you have to sort an array of non-negative integers according to
89 | number of ones in their binary representation in ascending order.
90 | For similar number of ones, sort based on decimal value.
91 |
92 | It must be implemented like this:
93 | >>> sort_array([1, 5, 2, 3, 4]) == [1, 2, 3, 4, 5]
94 | >>> sort_array([-2, -3, -4, -5, -6]) == [-6, -5, -4, -3, -2]
95 | >>> sort_array([1, 0, 2, 3, 4]) [0, 1, 2, 3, 4]
96 | """
97 | return sorted(sorted(arr), key=lambda x: bin(x)[2:].count('1'))
98 |
99 |
100 | def prime_fib(n: int):
101 | """
102 | prime_fib returns n-th number that is a Fibonacci number and it's also prime.
103 | >>> prime_fib(1)
104 | 2
105 | >>> prime_fib(2)
106 | 3
107 | >>> prime_fib(3)
108 | 5
109 | >>> prime_fib(4)
110 | 13
111 | >>> prime_fib(5)
112 | 89
113 | """
114 | ```
115 |
116 | ```
117 | # You are a code-writing bot. Given a function signature, and a docstring, complete the program body. Some examples are given below.
118 | def unique(l: list):
119 | """Return sorted unique elements in a list
120 | >>> unique([5, 3, 5, 2, 3, 3, 9, 0, 123])
121 | [0, 2, 3, 5, 9, 123]
122 | """
123 | return sorted(list(set(l)))
124 |
125 | def total_match(lst1, lst2):
126 | '''
127 | Write a function that accepts two lists of strings and returns the list that has
128 | total number of chars in the all strings of the list less than the other list.
129 |
130 | if the two lists have the same number of chars, return the first list.
131 |
132 | Examples
133 | total_match([], []) ➞ []
134 | total_match(['hi', 'admin'], ['hI', 'Hi']) ➞ ['hI', 'Hi']
135 | total_match(['hi', 'admin'], ['hi', 'hi', 'admin', 'project']) ➞ ['hi', 'admin']
136 | total_match(['hi', 'admin'], ['hI', 'hi', 'hi']) ➞ ['hI', 'hi', 'hi']
137 | total_match(['4'], ['1', '2', '3', '4', '5']) ➞ ['4']
138 | '''
139 | ```
140 |
141 | ## TriviaQA
142 |
143 | You are a trivia answering bot designed to answer questions. You are given a question and are supposed to output an answer in 1-3 words. Some examples are below.
144 | Q: To what RAF base, near Wooton Bassett village, were the bodies of servicemen killed in Afghanistan formerly transported
145 | A: LYNEHAM
146 |
147 | Q: What star sign is Jamie Lee Curtis?
148 | A:
149 |
150 |
151 | You are a trivia answering bot designed to answer questions. You are given a question and are supposed to output an answer in 1-3 words. Some examples are below.
152 | Q: Pre restraining order(s), who did People magazine name as their first "Sexiest Man Alive", in 1985?
153 | A: Mel Gibson
154 |
155 | Q: What id the name given to the study of birds?
156 | A:
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | # If you change this file, you need to regenerate the docker image!
2 | # To do so, run: adept infra create-dev-docker
3 | black==22.8.0
4 | celery-types==0.14.0
5 | click==8.0.4
6 | codenamize==1.2.3
7 | dacite==1.8.1
8 | dateparser==1.1.8
9 | einops
10 | fabric==3.0.0
11 | fire==0.4.0
12 | Flask==2.2.2
13 | Flask-AutoIndex==0.6.6
14 | Flask-Cors==3.0.10
15 | Flask-Login==0.6.2
16 | Flask-RESTful==0.3.9
17 | flatten-json==0.1.13
18 | gitpython
19 | hypothesis==4.50.8
20 | ipython==8.8.0
21 | jieba==0.42.1
22 | json-parser==1.2.0
23 | ldap3==2.9.1
24 | lm-dataformat==0.0.19
25 | loguru==0.6.0
26 | mock==4.0.3
27 | mypy==0.991
28 | nagisa==0.2.7
29 | nltk==3.7
30 | numexpr==2.7.2
31 | numpy>=1.18,<1.24
32 | parallel-ssh==2.12.0
33 | parameterized==0.8.1
34 | pip
35 | pipdeptree==2.5.2
36 | pre-commit
37 | prompt-toolkit==3.0.29
38 | pydantic>=1.10.2
39 | pykka
40 | pylint==2.13.5
41 | pytablewriter==0.58.0
42 | pytest==7.2.1
43 | pytest-celery==0.0.0
44 | pytest-cov==4.0.0
45 | pytest-forked==1.6.0
46 | pytest-randomly==3.12.0
47 | pytest-subtests==0.8.0
48 | pytest-timeout
49 | pytest-xdist==2.5.0
50 | python-dotenv==1.0.0
51 | python-hostlist==1.22
52 | python-snappy==0.6.1
53 | pytz==2023.3
54 | pyyaml==6.0.1
55 | pyzstd==0.15.0
56 | requirements-parser==0.5.0
57 | rtree
58 | ruamel.yaml==0.17.21
59 | sacrebleu==1.5.0
60 | scikit-learn
61 | sentencepiece==0.1.96
62 | sqlitedict==1.6.0
63 | sqlparse==0.4.2
64 | statsd==4.0.1
65 | strsimpy
66 | tabulate==0.9.0
67 | terminaltables==3.1.0
68 | tldextract==3.4.1
69 | types-bleach==5.0.3.1
70 | types-click
71 | types-Flask-Cors==3.0.10.5
72 | types-invoke
73 | types-paramiko==2.12.0.3
74 | types-Pillow
75 | types-protobuf==4.22.0.0
76 | types-PyYAML==6.0.11
77 | types-redis==4.3.21
78 | types-requests==2.28.4
79 | typing-extensions==4.6.3
80 | tzdata==2023.3
81 | unzip==1.0.0
82 | urllib3==1.26.15
83 | warcio==1.7.4
84 | wonderwords==2.2.0
85 | zstd==1.5.2.6
86 |
--------------------------------------------------------------------------------
/run_text_generation_server.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright (c) 2023 ADEPT AI LABS INC.
3 | # This file is based on code by the authors denoted below and has been modified from its original version.
4 | #
5 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
6 | #
7 | # Licensed under the Apache License, Version 2.0 (the "License");
8 | # you may not use this file except in compliance with the License.
9 | # You may obtain a copy of the License at
10 | #
11 | # http://www.apache.org/licenses/LICENSE-2.0
12 | #
13 | # Unless required by applicable law or agreed to in writing, software
14 | # distributed under the License is distributed on an "AS IS" BASIS,
15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16 | # See the License for the specific language governing permissions and
17 | # limitations under the License.
18 |
19 | """Sample Generate GPT"""
20 | import os
21 | import sys
22 | from typing import Any, Optional, Tuple
23 |
24 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir)))
25 | import torch
26 |
27 | from megatron import get_args, get_tokenizer, mpu
28 | from megatron.checkpointing import load_checkpoint
29 | from megatron.initialize import initialize_megatron
30 | from megatron.model import GPTModel
31 | from megatron.model.module import MegatronModule
32 | from megatron.model.utils import print_named_parameters
33 | from megatron.text_generation.api import generate_and_post_process
34 | from megatron.text_generation.inference_params import InferenceParams
35 | from megatron.text_generation_server import (MegatronServer,
36 | add_text_generate_args,
37 | setup_model)
38 | from megatron.training import get_model
39 |
40 | MAX_BATCH_SIZE = 1 # You can increase this depending on your desired max sequence length and GPU memory
41 | MAX_SEQLEN = 16 * 1024
42 |
43 |
44 | def model_provider(
45 | pre_process: bool=True,
46 | post_process: bool=True
47 | ) -> MegatronModule:
48 | """Build the model."""
49 |
50 | args = get_args()
51 | if args.model_architecture == "GPTModel":
52 | model : MegatronModule = GPTModel(num_tokentypes=0, parallel_output=False, pre_process=pre_process, post_process=post_process)
53 | else:
54 | raise ValueError(f"Unsupported model type: {args.model_architecture}")
55 | print_named_parameters(model)
56 | return model
57 |
58 |
59 | def initialize_model_from_args() -> Tuple[Any, Optional[InferenceParams]]:
60 | # Needed for tensor parallel inference with CUDA graph
61 | os.environ["NCCL_ASYNC_ERROR_HANDLING"] = "0"
62 | os.environ["NCCL_GRAPH_MIXING_SUPPORT"] = "0"
63 |
64 | initialize_megatron(
65 | extra_args_provider=add_text_generate_args,
66 | args_defaults={
67 | "tokenizer_type": "GPT2BPETokenizer",
68 | "no_load_rng": True,
69 | "no_load_optim": True,
70 | "inference_max_batch_size": MAX_BATCH_SIZE,
71 | "inference_max_seqlen": MAX_SEQLEN,
72 | },
73 | )
74 |
75 | args = get_args()
76 | if not args.fused_ft_kernel:
77 | args.use_cuda_graph = False # CUDA graph requires fused FT kernel
78 | if hasattr(args, "iteration"):
79 | args.curr_iteration = args.iteration
80 | print("curr_iteration", args.curr_iteration)
81 | if args.num_layers_per_virtual_pipeline_stage is not None:
82 | print("Interleaved pipeline schedule is not yet supported for text generation.")
83 |         sys.exit()
84 | # Set up model and load checkpoint
85 | model = get_model(model_provider, wrap_with_ddp=False)
86 |
87 | if args.load is not None:
88 | _ = load_checkpoint(model, None, None)
89 |
90 | assert len(model) == 1, "Above condition should have caught this"
91 | model = model[0]
92 |
93 | args.model_architecture = "GPTModel"
94 | inference_params = setup_model(
95 | model,
96 | args.model_architecture,
97 | args.use_inference_kv_cache,
98 | args.fused_ft_kernel,
99 | args.use_cuda_graph,
100 | args.inference_max_batch_size,
101 | args.inference_max_seqlen,
102 | )
103 |
104 | return model, inference_params
105 |
106 |
107 | if __name__ == "__main__":
108 | model, inference_params = initialize_model_from_args()
109 |
110 | args = get_args()
111 | tokenizer = get_tokenizer()
112 | if hasattr(args, "eos_id"):
113 | termination_id = args.eos_id
114 | else:
115 | termination_id = tokenizer.eod
116 | if mpu.is_pipeline_first_stage() and mpu.get_tensor_model_parallel_rank() == 0:
117 | assert inference_params is not None
118 | server = MegatronServer(
119 | model=model,
120 | inference_params=inference_params,
121 | params_dtype=args.params_dtype,
122 | max_position_embeddings=args.max_position_embeddings,
123 | termination_id=termination_id,
124 | tokenizer=tokenizer,
125 | port=args.port,
126 | )
127 | server.run("0.0.0.0")
128 |
129 | while True:
130 | if inference_params is not None:
131 | inference_params.reset()
132 | choice = torch.cuda.LongTensor(1)
133 | torch.distributed.broadcast(choice, 0)
134 | if choice[0].item() == 0:
135 | try:
136 | assert inference_params is not None
137 | generate_and_post_process(
138 | model=model,
139 | params_dtype=args.params_dtype,
140 | max_position_embeddings=args.max_position_embeddings,
141 | termination_id=termination_id,
142 | tokenizer=tokenizer,
143 | inference_params=inference_params)
144 |             except ValueError:
145 |                 pass  # Ignore generation errors on this rank and keep looping.
146 |
--------------------------------------------------------------------------------
/run_text_generation_server.sh:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2023, ADEPT AI LABS INC. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | MODEL_DIR="8b_chat_model_release"
15 | torchrun --nproc_per_node 1 --nnodes 1 --node_rank 0 --master_addr $(hostname) --master_port 5003 run_text_generation_server.py --no-load-rng --no-load-optim --no-initialization --top_p 0.9 --port 6001 --micro-batch-size 1 --load ${MODEL_DIR} --use-flash-attn --sp-model-file ${MODEL_DIR}/adept_vocab.model
16 |
--------------------------------------------------------------------------------
/tool_use/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/persimmon-ai-labs/adept-inference/61743d07cfb151dadb0cd2ae9de8f7325c4e828a/tool_use/__init__.py
--------------------------------------------------------------------------------
/tool_use/experiment_pipeline/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/persimmon-ai-labs/adept-inference/61743d07cfb151dadb0cd2ae9de8f7325c4e828a/tool_use/experiment_pipeline/__init__.py
--------------------------------------------------------------------------------
/tool_use/megatron/megatron_utils.py:
--------------------------------------------------------------------------------
1 | """Utils for megatron tool use."""
2 |
3 | from enum import Enum
4 |
5 |
6 | class ModelBackend(Enum):
7 | """Model backend."""
8 |
9 | MEGATRON = "megatron"
10 |
--------------------------------------------------------------------------------