├── .gitignore ├── .pre-commit-config.yaml ├── .style.yapf ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── docker ├── Dockerfile └── launch.sh ├── energonai ├── __init__.py ├── batch_mgr.py ├── communication │ ├── __init__.py │ ├── collective.py │ ├── p2p.py │ ├── ring.py │ └── utils.py ├── engine.py ├── kernel │ ├── __init__.py │ └── cuda_native │ │ ├── __init__.py │ │ ├── csrc │ │ ├── common.h │ │ ├── compat.h │ │ ├── get_ncclid.cpp │ │ ├── layer_norm_cuda.cpp │ │ ├── layer_norm_cuda_kernel.cu │ │ ├── linear_wrapper.cpp │ │ ├── scale_mask_softmax_kernel.cu │ │ ├── scale_mask_softmax_wrapper.cpp │ │ ├── transpose_pad_fusion_kernel.cu │ │ ├── transpose_pad_fusion_wrapper.cpp │ │ └── type_shim.h │ │ ├── layer_norm.py │ │ ├── linear_func.py │ │ ├── scale_mask_softmax.py │ │ └── transpose_pad.py ├── legacy_batch_mgr │ ├── __init__.py │ ├── dynamic_batch_manager.py │ └── naive_batch_manager.py ├── model │ ├── __init__.py │ ├── attention.py │ ├── downstream.py │ ├── embedding.py │ ├── endecoder.py │ ├── mlp.py │ └── model_factory.py ├── nemesis │ └── nemesis_manager.py ├── pipe.py ├── pipelinable │ ├── __init__.py │ ├── energon_tracer.py │ ├── split_method.py │ └── split_policy.py ├── task.py ├── testing │ ├── __init__.py │ └── models.py ├── utils │ ├── __init__.py │ ├── checkpointing.py │ ├── checkpointing_hf_gpt2.py │ ├── checkpointing_opt.py │ ├── common.py │ ├── files.py │ └── timer.py └── worker.py ├── examples ├── auto_pipeline │ ├── bert.py │ ├── bert_config.py │ ├── bert_server.py │ └── requirements.txt ├── bert │ ├── bert.py │ ├── bert_config.py │ ├── bert_server.py │ └── requirements.txt ├── bloom │ ├── README.md │ ├── batch.py │ ├── benchmark │ │ └── locustfile.py │ ├── cache.py │ ├── requirements.txt │ ├── run.sh │ ├── server.py │ └── utils.py ├── gpt │ ├── gpt.py │ ├── gpt_batch_server.py │ ├── gpt_config.py │ └── requirements.txt ├── hf_gpt2 │ ├── hf_gpt2.py │ ├── hf_gpt2_config.py │ ├── hf_gpt2_server.py │ └── requirements.txt ├── linear │ ├── linear.py │ └── requirements.txt ├── opt │ ├── README.md │ ├── batch.py │ ├── benchmark │ │ └── locustfile.py │ ├── cache.py │ ├── opt_fastapi.py │ ├── opt_server.py │ ├── requirements.txt │ └── script │ │ ├── process-opt-175b │ │ ├── README.md │ │ ├── convert_ckpt.py │ │ ├── flat-meta.json │ │ └── unflat.sh │ │ └── processing_ckpt_66b.py ├── trt_demo │ ├── net.py │ ├── requirements.txt │ ├── trt_net_config.py │ └── trt_net_server.py └── vit │ ├── dataset │ └── n01667114_9985.JPEG │ ├── proc_img.py │ ├── requirements.txt │ ├── vit.py │ ├── vit_config.py │ └── vit_server.py ├── requirements.txt ├── setup.py ├── tests ├── run_standalone_tests.sh ├── test_checkpoint │ ├── test_checkpoint_basic1d.py │ ├── test_checkpoint_bert1d.py │ ├── test_checkpoint_gpt1d.py │ └── test_moduledict.py ├── test_engine │ ├── boring_model_utils.py │ ├── test_hybrid.py │ ├── test_pp.py │ ├── test_single_device.py │ └── test_tp.py └── test_kernel │ ├── test_ft_transpose_pad.py │ ├── test_linear_func.py │ └── test_transpose_pad_fusion_kernel.py └── version.txt /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 
28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | # IDE 132 | .idea/ 133 | .vscode/ -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/pre-commit/mirrors-yapf 3 | rev: v0.32.0 4 | hooks: 5 | - id: yapf 6 | args: ['--style=.style.yapf', '--parallel', '--in-place'] 7 | - repo: https://github.com/pre-commit/mirrors-clang-format 8 | rev: v13.0.1 9 | hooks: 10 | - id: clang-format -------------------------------------------------------------------------------- /.style.yapf: -------------------------------------------------------------------------------- 1 | [style] 2 | based_on_style = google 3 | spaces_before_comment = 4 4 | split_before_logical_operator = true 5 | column_limit = 120 -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing 2 | 3 | The EnergonAI project is always open to constructive suggestions and contributions from the community. We sincerely invite you to take part in making this project friendlier and easier to use. 4 | 5 | ## Environment Setup 6 | The first step to becoming a contributor is setting up the environment for EnergonAI. 7 | Run the following commands to build your own EnergonAI. 
8 | 9 | --- 10 | ``` bash 11 | $ git clone https://github.com/hpcaitech/EnergonAI.git 12 | $ python setup.py install  # or: python setup.py develop 13 | ``` 14 | 15 | ## Coding Standards 16 | 17 | ### Unit Tests 18 | We use [PyTest](https://docs.pytest.org/en/latest/) to execute tests. You can install pytest by `pip install pytest`. As some of the tests require initialization of the distributed backend, GPUs are needed to execute these tests. 19 | 20 | If you only want to run CPU tests, you can run 21 | 22 | ```bash 23 | pytest -m cpu tests/ 24 | ``` 25 | 26 | If you have 8 GPUs on your machine, you can run the full test suite 27 | 28 | ```bash 29 | pytest tests/ 30 | ``` 31 | 32 | ### Code Style 33 | 34 | We run some static checks when you commit your code changes. Please make sure you can pass all of them and that your coding style meets our requirements. We use pre-commit hooks to keep the code aligned with the writing standard. To set up the code style checking, follow the steps below. 35 | 36 | ```shell 37 | # these commands are executed under the root directory of the repository 38 | pip install pre-commit 39 | pre-commit install 40 | ``` 41 | 42 | Code format checking will be automatically executed when you commit your changes. 43 | 44 | ## Contribution Guide 45 | 46 | You need to follow the steps below to make a contribution to the main repository via a pull request. You can learn about the details of pull requests [here](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/proposing-changes-to-your-work-with-pull-requests/about-pull-requests). 47 | 48 | ### 1. Fork the Official Repository 49 | 50 | Firstly, you need to visit the [ColossalAI-Inference repository](https://github.com/hpcaitech/ColossalAI-Inference) and fork it into your own account. The `fork` button is at the top right corner of the web page, alongside buttons such as `watch` and `star`. 51 | 52 | Now, you can clone your own forked repository into your local environment. 53 | 54 | ```shell 55 | git clone https://github.com/<your-username>/ColossalAI-Inference.git 56 | ``` 57 | ### 2. Configure Git 58 | 59 | You need to set the official repository as your upstream so that you can synchronize with the latest updates in the official repository. You can learn about upstream [here](https://www.atlassian.com/git/tutorials/git-forks-and-upstreams). 60 | 61 | Then add the original repository as the upstream: 62 | 63 | ```shell 64 | cd ColossalAI-Inference 65 | git remote add upstream https://github.com/hpcaitech/ColossalAI-Inference.git 66 | ``` 67 | 68 | You can use the following command to verify that the remote is set. You should see both `origin` and `upstream` in the output. 69 | 70 | ```shell 71 | git remote -v 72 | ``` 73 | 74 | ### 3. Synchronize with Official Repository 75 | 76 | Before you make changes to the codebase, it is always good to fetch the latest updates from the official repository. In order to do so, you can use the commands below. 77 | 78 | ```shell 79 | git fetch upstream 80 | git checkout main 81 | git merge upstream/main 82 | git push origin main 83 | ``` 84 | 85 | Alternatively, you can click the `Fetch upstream` button on the GitHub webpage of the main branch of your forked repository. Then, use these commands to sync: 86 | 87 | ``` 88 | git checkout main 89 | git pull origin main 90 | ``` 91 | 92 | ### 4. Choose/Create an Issue for Your Pull Request 93 | 94 | Generally, your code change should target only one problem. 
Stacking multiple commits for different problems into one pull request only makes the code review painful and makes the system prone to new bugs, as the reviewer may not understand the code logic correctly. Thus, you should choose an existing issue or [create your own issue](https://github.com/hpcaitech/ColossalAI-Inference/issues) as your pull request target. If you wish to create a new issue, use an appropriate title and description and add related labels. 95 | 96 | ### 5. Create a New Branch 97 | 98 | You should not make changes to the `main` branch of your forked repository as this might make upstream synchronization difficult. You can create a new branch with an appropriate name. In general, the branch name should start with `hotfix/` or `feature/`: `hotfix` is for bug fixes and `feature` is for adding a new feature. 99 | 100 | 101 | ```shell 102 | git checkout -b <new-branch-name> 103 | ``` 104 | 105 | ### 6. Implementation and Code Commit 106 | 107 | Now you can implement your code change in the source code. Remember that you installed the system in development mode, so you do not need to uninstall and reinstall it for your changes to take effect; they will be reflected in every new Python execution. 108 | You can commit the changes and push them to your forked repository. The changes should be kept logical, modular and atomic. 109 | 110 | ```shell 111 | git add -A 112 | git commit -m "<commit-message>" 113 | git push -u origin <new-branch-name> 114 | ``` 115 | 116 | ### 7. Open a Pull Request 117 | 118 | You can now create a pull request on the GitHub webpage of your repository. The source branch is `<new-branch-name>` of your repository and the target branch should be `main` of `hpcaitech/ColossalAI-Inference`. After creating this pull request, you should be able to see it [here](https://github.com/hpcaitech/ColossalAI-Inference/pulls). 119 | 120 | Write the description of your pull request clearly and [link the pull request to your target issue](https://docs.github.com/en/issues/tracking-your-work-with-issues/linking-a-pull-request-to-an-issue). This will automatically close the issue when the pull request is merged. 121 | 122 | In case of code conflicts, you should rebase your branch and resolve the conflicts manually. 123 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 4 | # Energon-AI 5 | 6 | ![](https://img.shields.io/badge/Made%20with-ColossalAI-blueviolet?style=flat) 7 | [![GitHub license](https://img.shields.io/github/license/hpcaitech/FastFold)](https://github.com/hpcaitech/ColossalAI-Inference/blob/main/LICENSE) 8 | 9 | A service framework for large-scale model inference, Energon-AI has the following characteristics: 10 | 11 | - **Parallelism for Large-scale Models:** With tensor parallel operations, a pipeline parallel wrapper, distributed checkpoint loading, and customized CUDA kernels, EnergonAI enables efficient parallel inference for large-scale models. 12 | - **Pre-built large models:** There are pre-built implementations for popular models, such as OPT. It supports the cache technique for the generation task and distributed parameter loading. 13 | - **Engine encapsulation:** There is an abstraction layer called the engine. It encapsulates single instance multiple devices (SIMD) execution with remote procedure calls, making it act like single instance single device (SISD) execution. 
14 | - **An online service system:** Based on FastAPI, users can quickly launch a web service for distributed inference. The online service makes special optimizations for the generation task. It adopts both left padding and bucket batching techniques to improve efficiency. 15 | 16 | Models trained by [Colossal-AI](https://github.com/hpcaitech/ColossalAI) can be easily transferred to Energon-AI. 17 | Single-device models require manual coding work to introduce tensor parallelism and pipeline parallelism. 18 | 19 | 20 | ### Installation 21 | **Install from source** 22 | ``` bash 23 | $ git clone git@github.com:hpcaitech/EnergonAI.git 24 | $ pip install -r requirements.txt 25 | $ pip install . 26 | ``` 27 | **Use docker** 28 | ``` bash 29 | $ docker pull hpcaitech/energon-ai:latest 30 | ``` 31 | 32 | 33 | ### Build an online OPT service in 5 minutes 34 | 35 | 1. **Download OPT model:** 36 | To launch the distributed inference service quickly, you can download the checkpoint of OPT-125M [here](https://huggingface.co/patrickvonplaten/opt_metaseq_125m/blob/main/model/restored.pt). You can find details on loading other sizes of models [here](https://github.com/hpcaitech/EnergonAI/tree/main/examples/opt/script). 37 | 38 | 2. **Launch an HTTP service:** 39 | To launch a service, we need to provide Python scripts that describe the model type and related configurations, and start an HTTP service. 40 | An OPT example is [EnergonAI/examples/opt](https://github.com/hpcaitech/EnergonAI/tree/main/examples/opt). 41 | The entry point of the service is the bash script ***server.sh***. 42 | The config of the service is in ***opt_config.py***, which defines the model type, the checkpoint file path, the parallel strategy, and HTTP settings. You can adapt it for your own case. 43 | For example, set the model class to opt_125M and set the correct checkpoint path as follows. Set the tensor parallelism degree equal to your GPU count. 44 | ```bash 45 | model_class = opt_125M 46 | checkpoint = 'your_file_path' 47 | tp_init_size = #gpu 48 | ``` 49 | Now, we can launch a service: 50 | 51 | ```bash 52 | bash server.sh 53 | ``` 54 | 55 | Then open ***https://[ip]:[port]/docs*** in your browser and try it out! 56 | 57 | 58 | ### Publication 59 | You can find technical details in our blog and manuscript: 60 | 61 | [Build an online OPT service using Colossal-AI in 5 minutes](https://www.colossalai.org/docs/advanced_tutorials/opt_service/) 62 | 63 | [EnergonAI: An Inference System for 10-100 Billion Parameter Transformer Models](https://arxiv.org/pdf/2209.02341.pdf) 64 | 65 | ``` 66 | @misc{du2022energonai, 67 | title={EnergonAI: An Inference System for 10-100 Billion Parameter Transformer Models}, 68 | author={Jiangsu Du and Ziming Liu and Jiarui Fang and Shenggui Li and Yongbin Li and Yutong Lu and Yang You}, 69 | year={2022}, 70 | eprint={2209.02341}, 71 | archivePrefix={arXiv}, 72 | primaryClass={cs.LG} 73 | } 74 | ``` 75 | 76 | ### Contributing 77 | 78 | If you are interested in making your own contribution to the project, please refer to [Contributing](./CONTRIBUTING.md) for guidance. 79 | 80 | Thanks so much! 
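### Appendix: driving the engine API directly (sketch)

The OPT service described above is a FastAPI wrapper around the engine API defined in `energonai/engine.py` (`launch_engine`, `AsyncEngine.submit`, `AsyncEngine.get`). The snippet below is a minimal, hypothetical sketch of calling that API from Python without the HTTP layer. The toy `nn.Linear` model, the host/port values, and the exact payload format handed to the worker are assumptions for illustration only; see `examples/linear` and `examples/opt` in this repository for the supported usage patterns.

```python
import torch
import torch.nn as nn
from energonai import launch_engine


def model_fn() -> nn.Module:
    # Called on every worker process to build (a shard of) the model.
    return nn.Linear(16, 4).cuda().eval()


def main():
    engine = launch_engine(
        tp_world_size=1,          # tensor parallel degree
        pp_world_size=1,          # pipeline parallel degree
        master_host='127.0.0.1',
        master_port=29500,        # any free port (assumption)
        rpc_port=29501,           # any free port (assumption)
        model_fn=model_fn,
    )
    uid = 0                       # any hashable id, unique per request
    # Assumption: the worker forwards this payload to the model; whether it
    # must be a tensor, a dict of kwargs, or already live on GPU is defined
    # by energonai/worker.py and the scripts under examples/.
    engine.submit(uid, torch.randn(2, 16).cuda())
    output = engine.get(uid)      # blocks until the result is ready
    print(output.shape)
    engine.shutdown()


if __name__ == '__main__':
    main()
```

With the default `BatchManager` in `energonai/batch_mgr.py`, each submitted entry is forwarded as its own batch; a custom `BatchManager` can be passed to `launch_engine` to group several requests together, as the batching code in the OPT and BLOOM examples does.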
-------------------------------------------------------------------------------- /docker/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM hpcaitech/colossalai:0.1.8 2 | 3 | WORKDIR /workspace 4 | 5 | RUN yum install -y vim 6 | RUN mkdir -p /workspace && cd /workspace && git clone https://github.com/hpcaitech/EnergonAI.git --recursive && cd EnergonAI && pip --no-cache-dir install -r requirements.txt && pip install . && rm -rf /workspace/EnergonAI 7 | 8 | CMD ["bash", "/config/server.sh"] -------------------------------------------------------------------------------- /docker/launch.sh: -------------------------------------------------------------------------------- 1 | # the directory that contains the checkpoint 2 | export CHECKPOINT_DIR="/data/user/lclhx/opt-30B" 3 | # ${CONFIG_DIR} must contain a server.sh file as the entry point of the service 4 | export CONFIG_DIR="/home/lcfjr/codes/EnergonAI/examples/opt" 5 | 6 | docker run --gpus all --rm -it -p 8090:8020 -v ${CHECKPOINT_DIR}:/model_checkpoint -v ${CONFIG_DIR}:/config --ipc=host energonai:latest 7 | -------------------------------------------------------------------------------- /energonai/__init__.py: -------------------------------------------------------------------------------- 1 | from .batch_mgr import BatchManager 2 | from .engine import launch_engine, SubmitEntry, QueueFullError 3 | from .task import TaskEntry 4 | 5 | 6 | __all__ = ['BatchManager', 'launch_engine', 'SubmitEntry', 'TaskEntry', 'QueueFullError'] 7 | -------------------------------------------------------------------------------- /energonai/batch_mgr.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Hashable, Tuple, Deque, Iterable 2 | from dataclasses import dataclass 3 | from .task import TaskEntry 4 | 5 | 6 | @dataclass 7 | class SubmitEntry: 8 | uid: Hashable 9 | data: Any 10 | 11 | 12 | class BatchManager: 13 | def make_batch(self, q: Deque[SubmitEntry]) -> Tuple[TaskEntry, dict]: 14 | entry = q.popleft() 15 | return TaskEntry((entry.uid, ), entry.data), {} 16 | 17 | def split_batch(self, task_entry: TaskEntry, **kwargs: Any) -> Iterable[Tuple[Hashable, Any]]: 18 | return [(task_entry.uids[0], task_entry.batch)] 19 | -------------------------------------------------------------------------------- /energonai/communication/__init__.py: -------------------------------------------------------------------------------- 1 | from .collective import all_gather, reduce_scatter, all_reduce, broadcast, reduce 2 | from .p2p import (send_forward, send_forward_recv_forward, send_backward_recv_forward, send_backward, 3 | send_backward_recv_backward, send_forward_recv_backward, send_forward_backward_recv_forward_backward, 4 | recv_forward, recv_backward) 5 | from .ring import ring_forward 6 | from .utils import send_tensor_meta, recv_tensor_meta 7 | 8 | __all__ = [ 9 | 'all_gather', 10 | 'reduce_scatter', 11 | 'all_reduce', 12 | 'broadcast', 13 | 'reduce', 14 | 'send_forward', 15 | 'send_forward_recv_forward', 16 | 'send_forward_backward_recv_forward_backward', 17 | 'send_backward', 18 | 'send_backward_recv_backward', 19 | 'send_backward_recv_forward', 20 | 'send_forward_recv_backward', 21 | 'recv_backward', 22 | 'recv_forward', 23 | 'ring_forward', 24 | 'send_tensor_meta', 25 | 'recv_tensor_meta', 26 | ] 27 | -------------------------------------------------------------------------------- /energonai/communication/collective.py: 
-------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- encoding: utf-8 -*- 3 | 4 | import torch 5 | import torch.distributed as dist 6 | from torch.distributed import ReduceOp 7 | from torch import Tensor 8 | 9 | from colossalai.core import global_context as gpc 10 | from colossalai.context import ParallelMode 11 | from colossalai.utils import get_current_device 12 | 13 | 14 | def all_gather(tensor: Tensor, dim: int, parallel_mode: ParallelMode, async_op: bool = False) -> Tensor: 15 | """Gathers all tensors from the parallel group and concatenates them in a 16 | specific dimension. 17 | 18 | :param tensor: Tensor to be gathered 19 | :param dim: The dimension concatenating in 20 | :param parallel_mode: Parallel group mode used in this communication 21 | :param async_op: Whether operations are asynchronous 22 | 23 | :type tensor: :class:`torch.Tensor` 24 | :type dim: int 25 | :type parallel_mode: :class:`colossalai.context.ParallelMode` 26 | :type async_op: bool, optional 27 | 28 | :return: The tensor generated by all-gather 29 | :rtype: :class:`torch.Tensor` 30 | """ 31 | depth = gpc.get_world_size(parallel_mode) 32 | if depth == 1: 33 | out = [tensor] 34 | work = None 35 | else: 36 | shape = list(tensor.shape) 37 | shape[0], shape[dim] = shape[dim], shape[0] 38 | shape[0] *= depth 39 | out = torch.empty(shape, dtype=tensor.dtype, device=get_current_device()) 40 | temp = list(torch.chunk(out, depth, dim=0)) 41 | work = dist.all_gather(tensor_list=temp, 42 | tensor=tensor.transpose(0, dim).contiguous(), 43 | group=gpc.get_group(parallel_mode), 44 | async_op=async_op) 45 | out = torch.transpose(out, 0, dim) 46 | if async_op: 47 | return out, work 48 | else: 49 | return out 50 | 51 | 52 | def reduce_scatter(tensor: Tensor, 53 | dim: int, 54 | parallel_mode: ParallelMode, 55 | op: ReduceOp = ReduceOp.SUM, 56 | async_op: bool = False) -> Tensor: 57 | """Reduces all tensors then scatters it in a specific dimension to all 58 | members in the parallel group. 
59 | 60 | :param tensor: Tensor to be reduced and scattered 61 | :param dim: The dimension scattering in 62 | :param parallel_mode: Parallel group mode used in this communication 63 | :param op: The type of reduce operation 64 | :param async_op: Whether operations are asynchronous 65 | 66 | :type tensor: :class:`torch.Tensor` 67 | :type dim: int 68 | :type parallel_mode: :class:`colossalai.context.ParallelMode` 69 | :type op: ReduceOp, optional 70 | :type async_op: bool, optional 71 | 72 | :return: The tensor generated by reduce-scatter 73 | :rtype: :class:`Tensor` 74 | """ 75 | depth = gpc.get_world_size(parallel_mode) 76 | if depth == 1: 77 | out = tensor 78 | work = None 79 | else: 80 | temp = list(map(lambda x: x.contiguous(), torch.chunk(tensor, depth, dim=dim))) 81 | out = torch.empty(temp[0].shape, dtype=tensor.dtype, device=get_current_device()) 82 | work = dist.reduce_scatter(output=out, 83 | input_list=temp, 84 | op=op, 85 | group=gpc.get_group(parallel_mode), 86 | async_op=async_op) 87 | if async_op: 88 | return out, work 89 | else: 90 | return out 91 | 92 | 93 | def all_reduce(tensor: Tensor, 94 | parallel_mode: ParallelMode, 95 | op: ReduceOp = ReduceOp.SUM, 96 | async_op: bool = False) -> Tensor: 97 | depth = gpc.get_world_size(parallel_mode) 98 | if depth == 1: 99 | work = None 100 | else: 101 | work = dist.all_reduce(tensor.contiguous(), op=op, group=gpc.get_group(parallel_mode), async_op=async_op) 102 | if async_op: 103 | return tensor, work 104 | else: 105 | return tensor 106 | 107 | 108 | def broadcast(tensor: Tensor, src: int, parallel_mode: ParallelMode, async_op: bool = False): 109 | depth = gpc.get_world_size(parallel_mode) 110 | if depth == 1: 111 | work = None 112 | else: 113 | work = dist.broadcast(tensor.contiguous(), src=src, group=gpc.get_group(parallel_mode), async_op=async_op) 114 | if async_op: 115 | return tensor, work 116 | else: 117 | return tensor 118 | 119 | 120 | def reduce(tensor: Tensor, dst: int, parallel_mode: ParallelMode, op: ReduceOp = ReduceOp.SUM, async_op: bool = False): 121 | depth = gpc.get_world_size(parallel_mode) 122 | if depth == 1: 123 | work = None 124 | else: 125 | work = dist.reduce(tensor.contiguous(), dst=dst, op=op, group=gpc.get_group(parallel_mode), async_op=async_op) 126 | if async_op: 127 | return tensor, work 128 | else: 129 | return tensor 130 | 131 | 132 | def scatter_object_list(scatter_object_output_list, scatter_object_input_list, src=0, group=None): 133 | r"""Modified from `torch.distributed.scatter_object_list ` to fix issues 134 | """ 135 | if dist._rank_not_in_group(group): 136 | return 137 | 138 | if not isinstance(scatter_object_output_list, list) or len(scatter_object_output_list) < 1: 139 | raise RuntimeError("Expected argument scatter_object_output_list to be a list of size at least 1.") 140 | 141 | # set tensor device to cuda if backend is nccl 142 | device = torch.cuda.current_device() if dist.get_backend(group) == 'nccl' else torch.device("cpu") 143 | 144 | my_rank = dist.get_rank() # use global rank 145 | if my_rank == src: 146 | tensor_list, tensor_sizes = zip( 147 | *[dist.distributed_c10d._object_to_tensor(obj) for obj in scatter_object_input_list]) 148 | tensor_list = list(map(lambda x: x.to(device), tensor_list)) 149 | tensor_sizes = list(map(lambda x: x.to(device), tensor_sizes)) 150 | 151 | # Src rank broadcasts the maximum tensor size. This is because all ranks are 152 | # expected to call into scatter() with equal-sized tensors. 
153 | if my_rank == src: 154 | max_tensor_size = max(tensor_sizes) 155 | for tensor in tensor_list: 156 | tensor.resize_(max_tensor_size) 157 | else: 158 | max_tensor_size = torch.tensor([0], dtype=torch.long).to(device) 159 | 160 | dist.broadcast(max_tensor_size, src=src, group=group) 161 | 162 | # Scatter actual serialized objects 163 | output_tensor = torch.empty(max_tensor_size.item(), dtype=torch.uint8).to(device) 164 | dist.scatter( 165 | output_tensor, 166 | scatter_list=None if my_rank != src else tensor_list, 167 | src=src, 168 | group=group, 169 | ) 170 | 171 | # Scatter per-object sizes to trim tensors when deserializing back to object 172 | obj_tensor_size = torch.tensor([0], dtype=torch.long).to(device) 173 | dist.scatter( 174 | obj_tensor_size, 175 | scatter_list=None if my_rank != src else tensor_sizes, 176 | src=src, 177 | group=group, 178 | ) 179 | 180 | output_tensor, obj_tensor_size = output_tensor.cpu(), obj_tensor_size.cpu() 181 | # Deserialize back to object 182 | scatter_object_output_list[0] = dist.distributed_c10d._tensor_to_object(output_tensor, obj_tensor_size) 183 | -------------------------------------------------------------------------------- /energonai/communication/ring.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- encoding: utf-8 -*- 3 | 4 | import torch 5 | 6 | from colossalai.core import global_context as gpc 7 | from colossalai.context import ParallelMode 8 | from colossalai.utils import get_current_device, synchronize 9 | 10 | 11 | def ring_forward(tensor_send_next: torch.Tensor, parallel_mode: ParallelMode): 12 | """Sends a tensor to the next member and recieves a tensor from the previous member. 13 | This function returns the recieved tensor from the previous member. 14 | 15 | :param tensor_send_next: Tensor sent to next member 16 | :param parallel_mode: Parallel group mode used in this communication 17 | :type tensor_send_next: :class:`torch.Tensor` 18 | :type parallel_mode: :class:`colossalai.context.ParallelMode` 19 | :return: The tensor recieved from the previous 20 | :rtype: :class:`torch.Tensor` 21 | """ 22 | buffer_shape = tensor_send_next.size() 23 | 24 | ops = [] 25 | current_rank = gpc.get_global_rank() 26 | 27 | tensor_recv_prev = torch.empty(buffer_shape, 28 | requires_grad=True, 29 | device=get_current_device(), 30 | dtype=tensor_send_next.dtype) 31 | 32 | # send to next rank 33 | send_next_op = torch.distributed.P2POp(torch.distributed.isend, tensor_send_next, 34 | gpc.get_next_global_rank(parallel_mode)) 35 | ops.append(send_next_op) 36 | 37 | # receive from prev rank 38 | recv_prev_op = torch.distributed.P2POp(torch.distributed.irecv, tensor_recv_prev, 39 | gpc.get_prev_global_rank(parallel_mode)) 40 | ops.append(recv_prev_op) 41 | 42 | if current_rank % 2 == 0: 43 | ops = ops[::-1] 44 | 45 | reqs = torch.distributed.batch_isend_irecv(ops) 46 | for req in reqs: 47 | req.wait() 48 | 49 | # To protect against race condition when using batch_isend_irecv(). 
50 | synchronize() 51 | 52 | return tensor_recv_prev 53 | -------------------------------------------------------------------------------- /energonai/communication/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.distributed as dist 3 | 4 | from colossalai.core import global_context as gpc 5 | from colossalai.context import ParallelMode 6 | from colossalai.utils import get_current_device 7 | 8 | 9 | def send_tensor_meta(tensor, need_meta=True, next_rank=None): 10 | """Sends tensor meta information before sending a specific tensor. 11 | Since the recipient must know the shape of the tensor in p2p communications, 12 | meta information of the tensor should be sent before communications. This function 13 | synchronizes with :func:`recv_tensor_meta`. 14 | 15 | :param tensor: Tensor to be sent 16 | :param need_meta: If False, meta information won't be sent 17 | :param next_rank: The rank of the next member in pipeline parallel group 18 | :type tensor: Tensor 19 | :type need_meta: bool, optional 20 | :type next_rank: int 21 | :return: False 22 | :rtype: bool 23 | """ 24 | if need_meta: 25 | if next_rank is None: 26 | next_rank = gpc.get_next_global_rank(ParallelMode.PIPELINE) 27 | 28 | tensor_kwargs = {'dtype': torch.long, 'device': get_current_device()} 29 | 30 | send_shape = torch.tensor(tensor.size(), **tensor_kwargs) 31 | send_ndims = torch.tensor(len(tensor.size()), **tensor_kwargs) 32 | dist.send(send_ndims, next_rank) 33 | dist.send(send_shape, next_rank) 34 | 35 | return False 36 | 37 | 38 | def recv_tensor_meta(tensor_shape, prev_rank=None): 39 | """Recieves tensor meta information before recieving a specific tensor. 40 | Since the recipient must know the shape of the tensor in p2p communications, 41 | meta information of the tensor should be recieved before communications. This function 42 | synchronizes with :func:`send_tensor_meta`. 43 | 44 | :param tensor_shape: The shape of the tensor to be recieved 45 | :param prev_rank: The rank of the source of the tensor 46 | :type tensor_shape: torch.Size 47 | :type prev_rank: int, optional 48 | :return: The shape of the tensor to be recieved 49 | :rtype: torch.Size 50 | """ 51 | if tensor_shape is None: 52 | if prev_rank is None: 53 | prev_rank = gpc.get_prev_global_rank(ParallelMode.PIPELINE) 54 | 55 | tensor_kwargs = {'dtype': torch.long, 'device': get_current_device()} 56 | 57 | recv_ndims = torch.empty((), **tensor_kwargs) 58 | dist.recv(recv_ndims, prev_rank) 59 | recv_shape = torch.empty(recv_ndims, **tensor_kwargs) 60 | dist.recv(recv_shape, prev_rank) 61 | 62 | tensor_shape = torch.Size(recv_shape) 63 | 64 | return tensor_shape 65 | 66 | 67 | def split_tensor_into_1d_equal_chunks(tensor, new_buffer=False): 68 | """Break a tensor into equal 1D chunks. 
69 | 70 | :param tensor: Tensor to be splitted before communication 71 | :param new_buffer: Whether uses a new buffer to store sliced tensor 72 | 73 | :type tensor: torch.Tensor 74 | :type new_buffer: bool, optional 75 | 76 | :return splitted_tensor: The splitted tensor 77 | :rtype splitted_tensor: torch.Tensor 78 | """ 79 | partition_size = torch.numel(tensor) // gpc.get_world_size(ParallelMode.PARALLEL_1D) 80 | start_index = partition_size * gpc.get_local_rank(ParallelMode.PARALLEL_1D) 81 | end_index = start_index + partition_size 82 | if new_buffer: 83 | data = torch.empty(partition_size, dtype=tensor.dtype, device=torch.cuda.current_device(), requires_grad=False) 84 | data.copy_(tensor.view(-1)[start_index:end_index]) 85 | else: 86 | data = tensor.view(-1)[start_index:end_index] 87 | return data 88 | 89 | 90 | def gather_split_1d_tensor(tensor): 91 | """Opposite of above function, gather values from model parallel ranks. 92 | 93 | :param tensor: Tensor to be gathered after communication 94 | :type tensor: torch.Tensor 95 | 96 | :return gathered: The gathered tensor 97 | :rtype gathered: torch.Tensor 98 | """ 99 | world_size = gpc.get_world_size(ParallelMode.PARALLEL_1D) 100 | numel = torch.numel(tensor) 101 | numel_gathered = world_size * numel 102 | gathered = torch.empty(numel_gathered, dtype=tensor.dtype, device=torch.cuda.current_device(), requires_grad=False) 103 | chunks = [gathered[i * numel:(i + 1) * numel] for i in range(world_size)] 104 | dist.all_gather(chunks, tensor, group=gpc.get_group(ParallelMode.PARALLEL_1D)) 105 | return gathered 106 | -------------------------------------------------------------------------------- /energonai/engine.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import signal 3 | import time 4 | from collections import deque 5 | from threading import Lock, Thread 6 | from typing import Any, Callable, Deque, Dict, Hashable, List, Optional, Tuple 7 | 8 | import torch.distributed.rpc as trpc 9 | import torch.nn as nn 10 | from colossalai.logging import get_dist_logger 11 | 12 | from .batch_mgr import BatchManager, SubmitEntry 13 | from .pipe import Pipe 14 | from .task import TaskEntry 15 | from .utils import Terminator, build_device_maps, use_lock 16 | from .worker import launch_workers 17 | 18 | 19 | class QueueFullError(Exception): 20 | pass 21 | 22 | 23 | class AsyncEngine: 24 | def __init__(self, tp_world_size: int, pp_world_size: int, master_host: str, rpc_port: int, n_proc_per_node: int, 25 | batch_manager: Optional[BatchManager] = None, pipe_size: int = 1, queue_size: int = 0, rpc_disable_shm: bool = True) -> None: 26 | self.lock = Lock() 27 | self.logger = get_dist_logger('energonai') 28 | if batch_manager is None: 29 | self.batch_manager = BatchManager() 30 | else: 31 | assert isinstance(batch_manager, BatchManager) 32 | self.batch_manager = batch_manager 33 | self.world_size = tp_world_size * pp_world_size 34 | 35 | rpc_options = {} 36 | if rpc_disable_shm: 37 | # SHM may lead to timeout error. Disabling SHM and only enabling uv transport can solve this problem. 38 | # See https://discuss.pytorch.org/t/rpc-behavior-difference-between-pytorch-1-7-0-vs-1-9-0/124772/5 39 | # This is a workaround and may be solved in the future. 
40 | rpc_options['_transports'] = ['uv'] 41 | trpc.init_rpc('master', rank=0, world_size=self.world_size + 1, 42 | rpc_backend_options=trpc.TensorPipeRpcBackendOptions( 43 | init_method=f'tcp://{master_host}:{rpc_port}', 44 | device_maps=build_device_maps(self.world_size, n_proc_per_node), 45 | **rpc_options 46 | )) 47 | self.from_worker_pipes: List[Pipe] = [] 48 | for i in range(self.world_size): 49 | pipe = Pipe(f'{i}_to_m', f'worker{i}', 'master') 50 | self.from_worker_pipes.append(pipe) 51 | self.submit_pipes: List[Pipe] = [] 52 | self.completion_pipes: List[Pipe] = [] 53 | for i, pipe in enumerate(self.from_worker_pipes): 54 | worker_pp_rank = pipe.recv() 55 | if worker_pp_rank == 0: 56 | self.submit_pipes.append(Pipe(f'm_to_{i}', 'master', f'worker{i}', max_size=pipe_size)) 57 | if worker_pp_rank == pp_world_size - 1: 58 | self.completion_pipes.append(pipe) 59 | 60 | self.running: bool = False 61 | self.submit_thread = None 62 | self.completion_thread = None 63 | self.queue_size = queue_size 64 | self.submit_queue: Deque[SubmitEntry] = deque() 65 | self.batch_info: Dict[Hashable, Any] = {} 66 | self.timer_info: Dict[Hashable, Tuple[int, float]] = {} 67 | self.completion_map: Dict[Hashable, Any] = {} 68 | 69 | self.logger.info('Engine start') 70 | self._start() 71 | self.register_sigint() 72 | 73 | def _submit_loop(self) -> None: 74 | while self.running: 75 | if len(self.submit_queue) > 0: 76 | task_entry, batch_info = self.batch_manager.make_batch(self.submit_queue) 77 | self.batch_info[task_entry.uids] = batch_info 78 | self.timer_info[task_entry.uids] = (len(task_entry.uids), time.time()) 79 | for pipe in self.submit_pipes: 80 | pipe.send(task_entry) 81 | else: 82 | time.sleep(0.01) 83 | 84 | def _completion_loop(self) -> None: 85 | received_data: Dict[int, Any] = {} 86 | while self.running: 87 | for i, pipe in enumerate(self.completion_pipes): 88 | if i not in received_data: 89 | try: 90 | received_data[i] = pipe.recv_nowait() 91 | except RuntimeError: 92 | pass 93 | if len(received_data) == len(self.completion_pipes): 94 | # TODO: validate they are all the same 95 | task_entries: List[TaskEntry] = list(map(lambda k: received_data[k], sorted(received_data.keys()))) 96 | received_data.clear() 97 | batch_info = self.batch_info.pop(task_entries[0].uids) 98 | for uid, output in self.batch_manager.split_batch(task_entries[0], **batch_info): 99 | self.completion_map[uid] = output 100 | batch_size, start_time = self.timer_info.pop(task_entries[0].uids) 101 | self.logger.info(f'batch size: {batch_size}, time: {time.time() -start_time:.3f}') 102 | else: 103 | time.sleep(0.01) 104 | 105 | def _start(self) -> None: 106 | self.running = True 107 | self.submit_thread = Thread(target=self._submit_loop) 108 | self.submit_thread.start() 109 | self.completion_thread = Thread(target=self._completion_loop) 110 | self.completion_thread.start() 111 | 112 | def shutdown(self) -> None: 113 | with use_lock(self.lock): 114 | if not self.running: 115 | return 116 | self.running = False 117 | Terminator.shield() 118 | for i in range(self.world_size): 119 | trpc.rpc_sync(f'worker{i}', Terminator.terminate) 120 | trpc.shutdown() 121 | self.submit_thread.join() 122 | self.completion_thread.join() 123 | 124 | def submit(self, uid: Hashable, data: Any) -> None: 125 | assert self.submit_thread.is_alive() 126 | assert uid not in self.completion_map 127 | if self.queue_size > 0 and len(self.submit_queue) >= self.queue_size: 128 | raise QueueFullError(f'Submit queue full, size: {self.queue_size}') 129 | 
self.submit_queue.append(SubmitEntry(uid, data)) 130 | 131 | async def wait(self, uid: Hashable) -> Any: 132 | assert self.completion_thread.is_alive() 133 | while True: 134 | if uid in self.completion_map: 135 | output = self.completion_map[uid] 136 | del self.completion_map[uid] 137 | return output 138 | await asyncio.sleep(0.1) 139 | 140 | def get(self, uid: Hashable, interval: float = 0.05) -> Any: 141 | assert self.completion_thread.is_alive() 142 | while True: 143 | if uid in self.completion_map: 144 | output = self.completion_map[uid] 145 | del self.completion_map[uid] 146 | return output 147 | time.sleep(interval) 148 | 149 | def _sigint_handler(self, *_): 150 | self.shutdown() 151 | raise KeyboardInterrupt 152 | 153 | def register_sigint(self): 154 | signal.signal(signal.SIGINT, self._sigint_handler) 155 | 156 | 157 | def launch_engine(tp_world_size: int, pp_world_size: int, master_host: str, master_port: int, rpc_port: int, 158 | model_fn: Callable[[Any], nn.Module], n_nodes: int = 1, node_rank: int = 0, batch_manager: Optional[BatchManager] = None, 159 | pipe_size: int = 1, queue_size: int = 0, rpc_disable_shm: bool = True, **model_kwargs: Any) -> Optional[AsyncEngine]: 160 | world_size = tp_world_size * pp_world_size 161 | assert world_size % n_nodes == 0 162 | n_proc_per_node = world_size // n_nodes 163 | launch_workers(tp_world_size, pp_world_size, master_host, master_port, rpc_port, 164 | model_fn, n_proc_per_node=n_proc_per_node, node_rank=node_rank, pipe_size=pipe_size, **model_kwargs) 165 | if node_rank == 0: 166 | engine = AsyncEngine(tp_world_size, pp_world_size, master_host, rpc_port, 167 | n_proc_per_node, batch_manager=batch_manager, pipe_size=pipe_size, queue_size=queue_size, rpc_disable_shm=rpc_disable_shm) 168 | return engine 169 | -------------------------------------------------------------------------------- /energonai/kernel/__init__.py: -------------------------------------------------------------------------------- 1 | from .cuda_native import transpose_pad, transpose_depad, depad, scale_mask_softmax 2 | from .cuda_native import ft_build_padding_offsets, ft_remove_padding, ft_rebuild_padding, ft_transpose_remove_padding, ft_transpose_rebuild_padding 3 | from .cuda_native import linear, find_algo 4 | # from .cuda_native import OneLayerNorm 5 | 6 | __all__ = [ 7 | "transpose_pad", "transpose_depad", "depad", "scale_mask_softmax", "ft_build_padding_offsets", "ft_remove_padding", 8 | "ft_rebuild_padding", "ft_transpose_remove_padding", "ft_transpose_rebuild_padding", "linear", "find_algo", 9 | # "OneLayerNorm" 10 | ] 11 | -------------------------------------------------------------------------------- /energonai/kernel/cuda_native/__init__.py: -------------------------------------------------------------------------------- 1 | from .transpose_pad import transpose_pad, transpose_depad, depad 2 | from .transpose_pad import ft_build_padding_offsets, ft_remove_padding, ft_rebuild_padding, ft_transpose_remove_padding, ft_transpose_rebuild_padding 3 | from .scale_mask_softmax import scale_mask_softmax 4 | from .layer_norm import MixedFusedLayerNorm as LayerNorm 5 | from .linear_func import linear, find_algo -------------------------------------------------------------------------------- /energonai/kernel/cuda_native/csrc/compat.h: -------------------------------------------------------------------------------- 1 | // modified from https://github.com/NVIDIA/apex/blob/master/csrc/compat.h 2 | #ifndef TORCH_CHECK 3 | #define TORCH_CHECK AT_CHECK 4 | #endif 5 | 6 | #ifdef 
VERSION_GE_1_3 7 | #define DATA_PTR data_ptr 8 | #else 9 | #define DATA_PTR data 10 | #endif -------------------------------------------------------------------------------- /energonai/kernel/cuda_native/csrc/get_ncclid.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | 7 | #include "nccl.h" 8 | #include 9 | #include 10 | #include 11 | 12 | // c10::intrusive_ptr 13 | void sendNcclUniqueId(at::Tensor &ncclid, int dstRank, 14 | const c10::intrusive_ptr &pg) { 15 | // pack in 16 | // auto tensor = torch::from_blob(ncclId->internal, {int(32)}, 17 | // torch::TensorOptions(torch::kCUDA).dtype(torch::kFloat32).requires_grad(false)); 18 | // at::Tensor tensor = torch::zeros({int(32)}, 19 | // torch::TensorOptions(torch::kCUDA).dtype(torch::kFloat32)); 20 | std::vector tensors = {ncclid}; 21 | printf("[INFO] rank start send \n"); 22 | 23 | if (pg == c10::detail::UniqueVoidPtr()) { 24 | auto ret = pg->send(tensors, dstRank, 0); 25 | ret->wait(); 26 | } 27 | 28 | printf("[INFO] rank finish send \n"); 29 | // return ret; 30 | } 31 | 32 | void recvNcclUniqueId(at::Tensor &ncclid, int srcRank, 33 | const c10::intrusive_ptr &pg) { 34 | // pack in 35 | at::Tensor tensor = torch::zeros( 36 | {int(32)}, torch::TensorOptions(torch::kCUDA).dtype(torch::kFloat32)); 37 | // auto tensor = torch::from_blob(ncclId->internal, {int(32)}, 38 | // torch::TensorOptions(torch::kCUDA).dtype(torch::kFloat32).requires_grad(false)); 39 | std::vector tensors = {ncclid}; 40 | printf("[INFO] rank start recv \n"); 41 | 42 | if (pg == c10::detail::UniqueVoidPtr()) { 43 | auto ret = pg->recv(tensors, srcRank, 0); 44 | ret->wait(); 45 | } 46 | 47 | printf("[INFO] rank finish recv \n"); 48 | // at::Tensor tensor = tensors[0]; 49 | // float* temp = tensor.data_ptr(); 50 | // ncclId->internal 51 | // char * x = reinterpret_cast(temp); 52 | // get_ptr(tensor); 53 | } 54 | // if(local_rank == 0) 55 | // { 56 | // for(int i = 1; i &pg) { 71 | 72 | std::vector tensors = {ncclid}; 73 | 74 | printf("[INFO] rank start ncclid broadcast \n"); 75 | 76 | if (pg != c10::detail::UniqueVoidPtr()) { 77 | auto ret = pg->broadcast(tensors, c10d::BroadcastOptions()); 78 | ret->wait(); 79 | } 80 | 81 | printf("[INFO] rank finish ncclid broadcast in func \n"); 82 | 83 | // char* temp = reinterpret_cast(cpuNCCLID.data_ptr()); 84 | // for(int i = 0; i &pg) { 106 | 107 | ncclUniqueId tensor_para_nccl_uid; 108 | ncclGetUniqueId(&tensor_para_nccl_uid); 109 | auto tensor = torch::from_blob(tensor_para_nccl_uid.internal, {int(32)}, 110 | torch::TensorOptions(torch::kCPU) 111 | .dtype(torch::kFloat32) 112 | .requires_grad(false)); 113 | torch::Tensor gpuNCCLID = tensor.to(torch::kCUDA); 114 | broadcastUniqueId(gpuNCCLID, local_rank, pg); 115 | torch::Tensor cpuNCCLID = gpuNCCLID.to(torch::kCPU); 116 | 117 | // char* temp = reinterpret_cast(cpuNCCLID.data_ptr()); 118 | // for(int i = 0; i 7 | #include 8 | #include 9 | 10 | namespace { 11 | 12 | void compute_n1_n2(at::Tensor input, at::IntArrayRef normalized_shape, int &n1, 13 | int &n2) { 14 | int idiff = input.ndimension() - normalized_shape.size(); 15 | n2 = 1; 16 | for (int i = 0; i < (int)normalized_shape.size(); ++i) { 17 | assert(input.sizes()[i + idiff] == normalized_shape[i]); 18 | n2 *= normalized_shape[i]; 19 | } 20 | n1 = 1; 21 | for (int i = 0; i < idiff; ++i) { 22 | n1 *= input.sizes()[i]; 23 | } 24 | } 25 | 26 | void check_args(at::IntArrayRef normalized_shape, at::Tensor gamma, 27 | at::Tensor 
beta) { 28 | TORCH_CHECK(!gamma.defined() || gamma.sizes().equals(normalized_shape)); 29 | TORCH_CHECK(!beta.defined() || beta.sizes().equals(normalized_shape)); 30 | } 31 | 32 | void check_args(at::Tensor input, at::IntArrayRef normalized_shape, int &n1, 33 | int &n2) { 34 | int64_t normalized_ndim = normalized_shape.size(); 35 | 36 | if (normalized_ndim < 1) { 37 | std::stringstream ss; 38 | ss << "Expected normalized_shape to be at least 1-dimensional, i.e., " 39 | << "containing at least one element, but got normalized_shape=" 40 | << normalized_shape; 41 | throw std::runtime_error(ss.str()); 42 | } 43 | 44 | auto input_shape = input.sizes(); 45 | auto input_ndim = input.dim(); 46 | 47 | if (input_ndim < normalized_ndim || 48 | !input_shape.slice(input_ndim - normalized_ndim) 49 | .equals(normalized_shape)) { 50 | std::stringstream ss; 51 | ss << "Given normalized_shape=" << normalized_shape 52 | << ", expected input with shape [*"; 53 | for (auto size : normalized_shape) { 54 | ss << ", " << size; 55 | } 56 | ss << "], but got input of size" << input_shape; 57 | throw std::runtime_error(ss.str()); 58 | } 59 | 60 | compute_n1_n2(input, normalized_shape, n1, n2); 61 | } 62 | 63 | void check_args(at::Tensor input, at::IntArrayRef normalized_shape, 64 | at::Tensor gamma, at::Tensor beta, int &n1, int &n2) { 65 | check_args(input, normalized_shape, n1, n2); 66 | check_args(normalized_shape, gamma, beta); 67 | } 68 | } // namespace 69 | 70 | void cuda_layer_norm(at::Tensor *output, at::Tensor *mean, at::Tensor *invvar, 71 | at::Tensor *input, int n1, int n2, 72 | at::IntArrayRef normalized_shape, at::Tensor *gamma, 73 | at::Tensor *beta, double epsilon); 74 | 75 | #define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor") 76 | #define CHECK_CONTIGUOUS(x) \ 77 | TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 78 | #define CHECK_INPUT(x) \ 79 | CHECK_CUDA(x); \ 80 | CHECK_CONTIGUOUS(x) 81 | 82 | std::vector layer_norm_affine(at::Tensor input, 83 | at::IntArrayRef normalized_shape, 84 | at::Tensor gamma, at::Tensor beta, 85 | double epsilon) { 86 | 87 | CHECK_INPUT(input); 88 | CHECK_INPUT(gamma); 89 | CHECK_INPUT(beta); 90 | int n1, n2; 91 | check_args(input, normalized_shape, gamma, beta, n1, n2); 92 | 93 | at::Tensor output = 94 | at::empty_like(input, gamma.options().dtype(gamma.scalar_type())); 95 | at::Tensor mean = 96 | at::empty({n1}, input.options().dtype(at::ScalarType::Float)); 97 | at::Tensor invvar = at::empty_like(mean); 98 | 99 | cuda_layer_norm(&output, &mean, &invvar, &input, n1, n2, normalized_shape, 100 | &gamma, &beta, epsilon); 101 | 102 | return {output, mean, invvar}; 103 | } 104 | 105 | void cuda_layer_norm_gradient(at::Tensor *dout, at::Tensor *mean, 106 | at::Tensor *invvar, at::Tensor *input, int n1, 107 | int n2, at::IntArrayRef normalized_shape, 108 | at::Tensor *gamma, at::Tensor *beta, 109 | double epsilon, at::Tensor *grad_input, 110 | at::Tensor *grad_gamma, at::Tensor *grad_beta); 111 | 112 | std::vector 113 | layer_norm_gradient_affine(at::Tensor dout, at::Tensor mean, at::Tensor invvar, 114 | at::Tensor input, at::IntArrayRef normalized_shape, 115 | at::Tensor gamma, at::Tensor beta, double epsilon) { 116 | 117 | CHECK_INPUT(dout); 118 | CHECK_INPUT(mean); 119 | CHECK_INPUT(invvar); 120 | CHECK_INPUT(input); 121 | CHECK_INPUT(gamma); 122 | CHECK_INPUT(beta); 123 | int n1, n2; 124 | check_args(input, normalized_shape, gamma, beta, n1, n2); 125 | 126 | at::Tensor grad_input = at::empty_like(input); 127 | at::Tensor grad_gamma = 
at::empty_like(gamma); 128 | at::Tensor grad_beta = at::empty_like(beta); 129 | 130 | cuda_layer_norm_gradient(&dout, &mean, &invvar, &input, n1, n2, 131 | normalized_shape, &gamma, &beta, epsilon, 132 | &grad_input, &grad_gamma, &grad_beta); 133 | 134 | return {grad_input, grad_gamma, grad_beta}; 135 | } 136 | 137 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 138 | m.def("forward_affine", &layer_norm_affine, "LayerNorm forward (CUDA)"); 139 | m.def("backward_affine", &layer_norm_gradient_affine, 140 | "LayerNorm backward (CUDA)"); 141 | } -------------------------------------------------------------------------------- /energonai/kernel/cuda_native/csrc/linear_wrapper.cpp: -------------------------------------------------------------------------------- 1 | #include "compat.h" 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include 11 | #include 12 | 13 | 14 | static const char *_cudaGetErrorEnum(cublasStatus_t error) 15 | { 16 | switch (error) 17 | { 18 | case CUBLAS_STATUS_SUCCESS: 19 | return "CUBLAS_STATUS_SUCCESS"; 20 | 21 | case CUBLAS_STATUS_NOT_INITIALIZED: 22 | return "CUBLAS_STATUS_NOT_INITIALIZED"; 23 | 24 | case CUBLAS_STATUS_ALLOC_FAILED: 25 | return "CUBLAS_STATUS_ALLOC_FAILED"; 26 | 27 | case CUBLAS_STATUS_INVALID_VALUE: 28 | return "CUBLAS_STATUS_INVALID_VALUE"; 29 | 30 | case CUBLAS_STATUS_ARCH_MISMATCH: 31 | return "CUBLAS_STATUS_ARCH_MISMATCH"; 32 | 33 | case CUBLAS_STATUS_MAPPING_ERROR: 34 | return "CUBLAS_STATUS_MAPPING_ERROR"; 35 | 36 | case CUBLAS_STATUS_EXECUTION_FAILED: 37 | return "CUBLAS_STATUS_EXECUTION_FAILED"; 38 | 39 | case CUBLAS_STATUS_INTERNAL_ERROR: 40 | return "CUBLAS_STATUS_INTERNAL_ERROR"; 41 | 42 | case CUBLAS_STATUS_NOT_SUPPORTED: 43 | return "CUBLAS_STATUS_NOT_SUPPORTED"; 44 | 45 | case CUBLAS_STATUS_LICENSE_ERROR: 46 | return "CUBLAS_STATUS_LICENSE_ERROR"; 47 | } 48 | return ""; 49 | } 50 | 51 | #define CHECK_CUDA(x) \ 52 | AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor") 53 | #define CHECK_CONTIGUOUS(x) \ 54 | AT_ASSERTM(x.is_contiguous(), #x " must be contiguous") 55 | #define CHECK_FP32(x) \ 56 | AT_ASSERTM(x.dtype() == torch::kFloat32, "Datatype not implemented") 57 | #define CHECK_FP16(x) \ 58 | AT_ASSERTM(x.dtype() == torch::kFloat16, "Datatype not implemented") 59 | #define CHECK_FP16_INPUT(x) \ 60 | CHECK_CUDA(x); \ 61 | CHECK_CONTIGUOUS(x); \ 62 | CHECK_FP16(x) 63 | #define CHECK_FP32_INPUT(x) \ 64 | CHECK_CUDA(x); \ 65 | CHECK_CONTIGUOUS(x); \ 66 | CHECK_FP32(x) 67 | #define CHECK_INPUT(x) \ 68 | CHECK_CUDA(x); \ 69 | CHECK_CONTIGUOUS(x) 70 | 71 | template 72 | void check(T result, char const *const func, const char *const file, int const line) 73 | { 74 | if (result) 75 | { 76 | throw std::runtime_error(std::string("[FT][ERROR] CUDA runtime error: ") + (_cudaGetErrorEnum(result)) + " " + file + ":" + std::to_string(line) + " \n"); 77 | } 78 | } 79 | #define check_cuda_error(val) check((val), #val, __FILE__, __LINE__) 80 | 81 | 82 | torch::Tensor mlp_gemm(torch::Tensor input_tensor, torch::Tensor weights, int algo = CUBLAS_GEMM_DEFAULT) 83 | { 84 | CHECK_FP16_INPUT(input_tensor); 85 | CHECK_FP16_INPUT(weights); 86 | static half h_alpha = (half)1.0f; 87 | static half h_beta = (half)0.0f; 88 | cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle(); 89 | 90 | int batch_size = input_tensor.sizes()[0]; 91 | int seq_len = input_tensor.sizes()[1]; 92 | int din = input_tensor.sizes()[2]; 93 | int dout = weights.sizes()[0]; 94 | 95 | auto options = 
torch::TensorOptions().dtype(input_tensor.dtype()).device(torch::kCUDA).requires_grad(false); 96 | auto output = torch::empty({batch_size, seq_len, dout}, options); 97 | 98 | check_cuda_error( 99 | cublasGemmEx( 100 | handle, CUBLAS_OP_T, CUBLAS_OP_N, 101 | dout, seq_len * batch_size, din, 102 | &h_alpha, 103 | weights.data_ptr(), CUDA_R_16F, din, 104 | input_tensor.data_ptr(), CUDA_R_16F, din, 105 | &h_beta, 106 | output.data_ptr(), CUDA_R_16F, dout, 107 | CUBLAS_COMPUTE_16F, 108 | static_cast(algo))); 109 | 110 | return output; 111 | } 112 | 113 | int get_start_algo() 114 | { 115 | return CUBLAS_GEMM_DEFAULT; 116 | } 117 | 118 | int get_end_algo() 119 | { 120 | return CUBLAS_GEMM_ALGO23; 121 | } 122 | 123 | int get_start_algo_t_op() 124 | { 125 | return CUBLAS_GEMM_DEFAULT_TENSOR_OP; 126 | } 127 | 128 | int get_end_algo_t_op() 129 | { 130 | return CUBLAS_GEMM_ALGO15_TENSOR_OP; 131 | } 132 | 133 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) 134 | { 135 | m.def("mlp_gemm", &mlp_gemm, py::arg("input"), py::arg("param"), py::arg("algo") = (int)CUBLAS_GEMM_DEFAULT); 136 | m.def("get_start_algo", &get_start_algo); 137 | m.def("get_end_algo", &get_end_algo); 138 | m.def("get_start_algo_t_op", &get_start_algo_t_op); 139 | m.def("get_end_algo_t_op", &get_end_algo_t_op); 140 | } -------------------------------------------------------------------------------- /energonai/kernel/cuda_native/csrc/scale_mask_softmax_kernel.cu: -------------------------------------------------------------------------------- 1 | // LightSeq 2 | // Copyright 2019 Bytedance Inc. 3 | 4 | #include "common.h" 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | 11 | template 12 | __global__ void ker_scale_mask_softmax(T *correlation, const int *real_seq_len, 13 | const int batch_seq_len) { 14 | int query_token_pos = blockIdx.y % batch_seq_len; 15 | if (query_token_pos >= real_seq_len[blockIdx.x]) { 16 | return; 17 | } 18 | 19 | int mask = 0; // can see the token when mask=0 20 | if (threadIdx.x > query_token_pos || threadIdx.x >= batch_seq_len) { 21 | mask = 1; // Can only see the token on the left side of it 22 | } 23 | 24 | int idx = (blockIdx.x * gridDim.y + blockIdx.y) * batch_seq_len + threadIdx.x; 25 | float val = threadIdx.x < batch_seq_len ? (float)correlation[idx] 26 | : CUDA_FLOAT_INF_NEG; 27 | float max_val = blockReduceMax(mask ? CUDA_FLOAT_INF_NEG : val); 28 | __shared__ float smax; 29 | if (threadIdx.x == 0) 30 | smax = max_val; 31 | __syncthreads(); 32 | 33 | val = mask ? 
0.f : expf(val - smax); 34 | float rsum = blockReduceSum(val); 35 | __shared__ float ssum; 36 | if (threadIdx.x == 0) 37 | ssum = rsum; 38 | __syncthreads(); 39 | 40 | if (threadIdx.x < batch_seq_len) 41 | correlation[idx] = (T)(val / ssum); 42 | } 43 | 44 | template 45 | void ker_scale_mask_softmax_launcher(int batch_size, int batch_seq_len, 46 | int head_num, T *correlation, 47 | const int *real_seq_len) { 48 | int block_dim = batch_seq_len; 49 | if (batch_seq_len < 1024) { 50 | block_dim = (batch_seq_len + 31) >> 5; 51 | block_dim *= 32; 52 | } 53 | 54 | ker_scale_mask_softmax 55 | <<>>( 56 | correlation, real_seq_len, batch_seq_len); 57 | } 58 | 59 | template void ker_scale_mask_softmax_launcher(int batch_size, 60 | int batch_seq_len, 61 | int head_num, 62 | float *correlation, 63 | const int *real_seq_len); 64 | 65 | template void ker_scale_mask_softmax_launcher<__half>(int batch_size, 66 | int batch_seq_len, 67 | int head_num, 68 | __half *correlation, 69 | const int *real_seq_len); -------------------------------------------------------------------------------- /energonai/kernel/cuda_native/csrc/scale_mask_softmax_wrapper.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | #define CHECK_CUDA(x) \ 5 | AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor") 6 | #define CHECK_CONTIGUOUS(x) \ 7 | AT_ASSERTM(x.is_contiguous(), #x " must be contiguous") 8 | #define CHECK_FP16_32(x) \ 9 | AT_ASSERTM(x.dtype() == torch::kFloat32 || x.dtype() == torch::kFloat16, \ 10 | "Datatype not implemented") 11 | 12 | #define CHECK_FP16_32_INPUT(x) \ 13 | CHECK_CUDA(x); \ 14 | CHECK_CONTIGUOUS(x); \ 15 | CHECK_FP16_32(x) 16 | #define CHECK_INPUT(x) \ 17 | CHECK_CUDA(x); \ 18 | CHECK_CONTIGUOUS(x) 19 | 20 | template 21 | void ker_scale_mask_softmax_launcher(int batch_size, int batch_seq_len, 22 | int head_num, T *correlation, 23 | const int *real_seq_len); 24 | 25 | torch::Tensor scale_mask_softmax_wrapper(int batch_size, int batch_seq_len, 26 | int head_num, 27 | torch::Tensor correlation, 28 | torch::Tensor real_seq_len) { 29 | CHECK_FP16_32_INPUT(correlation); 30 | CHECK_INPUT(real_seq_len); 31 | 32 | if (correlation.dtype() == torch::kFloat32) { 33 | ker_scale_mask_softmax_launcher(batch_size, batch_seq_len, head_num, 34 | correlation.data_ptr(), 35 | real_seq_len.data_ptr()); 36 | } else if (correlation.dtype() == torch::kFloat16) { 37 | ker_scale_mask_softmax_launcher<__half>( 38 | batch_size, batch_seq_len, head_num, 39 | (__half *)correlation.data_ptr(), 40 | real_seq_len.data_ptr()); 41 | } 42 | 43 | return correlation; 44 | } 45 | 46 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 47 | m.def("scale_mask_softmax_wrapper", &scale_mask_softmax_wrapper, 48 | "scale mask softmax fusion"); 49 | } 50 | -------------------------------------------------------------------------------- /energonai/kernel/cuda_native/layer_norm.py: -------------------------------------------------------------------------------- 1 | """This code is from NVIDIA apex: 2 | https://github.com/NVIDIA/apex 3 | with some changes. 
""" 4 | 5 | import numbers 6 | import torch 7 | from torch.nn.parameter import Parameter 8 | from torch.nn import init 9 | from torch.cuda.amp import custom_fwd, custom_bwd 10 | import importlib 11 | 12 | global colossal_layer_norm_cuda 13 | colossal_layer_norm_cuda = None 14 | 15 | 16 | class FusedLayerNormAffineFunction(torch.autograd.Function): 17 | 18 | @staticmethod 19 | @custom_fwd(cast_inputs=torch.float32) 20 | def forward(ctx, input, weight, bias, normalized_shape, eps): 21 | 22 | ctx.normalized_shape = normalized_shape 23 | ctx.eps = eps 24 | input_ = input.contiguous() 25 | weight_ = weight.contiguous() 26 | bias_ = bias.contiguous() 27 | output, mean, invvar = colossal_layer_norm_cuda.forward_affine(input_, ctx.normalized_shape, weight_, bias_, 28 | ctx.eps) 29 | ctx.save_for_backward(input_, weight_, bias_, mean, invvar) 30 | 31 | return output 32 | 33 | @staticmethod 34 | @custom_bwd 35 | def backward(ctx, grad_output): 36 | 37 | input_, weight_, bias_, mean, invvar = ctx.saved_tensors 38 | grad_input = grad_weight = grad_bias = None 39 | grad_input, grad_weight, grad_bias \ 40 | = colossal_layer_norm_cuda.backward_affine( 41 | grad_output.contiguous(), mean, invvar, 42 | input_, ctx.normalized_shape, 43 | weight_, bias_, ctx.eps) 44 | 45 | return grad_input, grad_weight, grad_bias, None, None 46 | 47 | 48 | class MixedFusedLayerNorm(torch.nn.Module): 49 | 50 | def __init__(self, normalized_shape, eps=1e-5, device=None, dtype=None): 51 | super(MixedFusedLayerNorm, self).__init__() 52 | 53 | global colossal_layer_norm_cuda 54 | if colossal_layer_norm_cuda is None: 55 | try: 56 | colossal_layer_norm_cuda = importlib.import_module("energonai_layer_norm") 57 | except ImportError: 58 | raise RuntimeError('MixedFusedLayerNorm requires cuda extensions') 59 | 60 | if isinstance(normalized_shape, numbers.Integral): 61 | normalized_shape = (normalized_shape,) 62 | self.normalized_shape = torch.Size(normalized_shape) 63 | self.eps = eps 64 | self.weight = Parameter(torch.empty(*normalized_shape, device=device, dtype=dtype)) 65 | self.bias = Parameter(torch.empty(*normalized_shape, device=device, dtype=dtype)) 66 | self.reset_parameters() 67 | 68 | def reset_parameters(self): 69 | 70 | init.ones_(self.weight) 71 | init.zeros_(self.bias) 72 | 73 | def forward(self, input): 74 | 75 | return FusedLayerNormAffineFunction.apply(input, self.weight, self.bias, self.normalized_shape, self.eps) 76 | 77 | def __repr__(self): 78 | return f'MixedFusedLayerNorm(normalized_shape={self.normalized_shape}, eps={self.eps})' 79 | -------------------------------------------------------------------------------- /energonai/kernel/cuda_native/linear_func.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import time 3 | import importlib 4 | 5 | try: 6 | energonai_linear = importlib.import_module("energonai_linear_func") 7 | except ImportError: 8 | raise RuntimeError('energonai_linear_func requires cuda extensions') 9 | 10 | 11 | def linear(inputs, param, algo=-1): 12 | """ 13 | Linear function using Cublas 14 | 15 | Args: 16 | inputs (tensor): (batch, seq_len, din) 17 | param (tensor): (dout, din) 18 | algo (int): Cublas GEMM algorithms, defaults to -1. No effect for Ampere architecture gpu or above. 
19 | -1: Apply Heuristics to select the GEMM algorithm 20 | 0~23: Explicitly choose a GEMM algorithm 21 | 99: Apply Heuristics to select the GEMM algorithm while allowing the use of Tensor Core operations if possible 22 | 100~115: Explicitly choose a GEMM algorithm allowing it to use Tensor Core operations if possible 23 | Returns: 24 | tensor: (batch, seq_len, dout) 25 | """ 26 | assert inputs.is_contiguous() 27 | assert param.is_contiguous() 28 | assert len(inputs.shape) == 3 29 | assert len(param.shape) == 2 30 | assert inputs.shape[2] == param.shape[1] 31 | assert isinstance(algo, int) and (-1 <= algo <= 23 or 99 <= algo <= 115) 32 | return energonai_linear.mlp_gemm(inputs, param, algo) 33 | 34 | 35 | @torch.no_grad() 36 | def find_algo(): 37 | """ 38 | Auto find best algo, may take tens of seconds 39 | 40 | Returns: 41 | int: best algo 42 | """ 43 | batch_size = 16 44 | seq_len = 64 45 | din = 12288 46 | dout = 49152 47 | 48 | inner_loop = 3 49 | 50 | input_list = [] 51 | param_list = [] 52 | for i in range(inner_loop): 53 | input_list.append(torch.randn(batch_size, seq_len, din).half().cuda()) 54 | param_list.append(torch.randn(dout, din).half().cuda()) 55 | 56 | start_algo = -1 57 | end_algo = 23 58 | start_algo_t_op = 99 59 | end_algo_t_op = 115 60 | 61 | algo_map = {} 62 | for algo in range(start_algo, end_algo + 1): 63 | algo_map[algo] = 0 64 | for algo in range(start_algo_t_op, end_algo_t_op + 1): 65 | algo_map[algo] = 0 66 | 67 | for i in range(inner_loop): 68 | _ = linear(input_list[i], param_list[i], start_algo) 69 | _ = linear(input_list[i], param_list[i], start_algo) 70 | 71 | for algo in range(start_algo, end_algo + 1): 72 | torch.cuda.synchronize() 73 | start_time = time.time() 74 | _ = linear(input_list[i], param_list[i], algo) 75 | torch.cuda.synchronize() 76 | algo_map[algo] += time.time() - start_time 77 | 78 | for algo in range(start_algo_t_op, end_algo_t_op + 1): 79 | torch.cuda.synchronize() 80 | start_time = time.time() 81 | _ = linear(input_list[i], param_list[i], algo) 82 | torch.cuda.synchronize() 83 | algo_map[algo] += time.time() - start_time 84 | 85 | best_idx = None 86 | best_value = 999 87 | for key, value in algo_map.items(): 88 | if value < best_value: 89 | best_value = value 90 | best_idx = key 91 | 92 | return best_idx 93 | -------------------------------------------------------------------------------- /energonai/kernel/cuda_native/scale_mask_softmax.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import importlib 3 | 4 | try: 5 | energonai_scale_mask = importlib.import_module("energonai_scale_mask") 6 | except ImportError: 7 | raise RuntimeError('energonai_scale_mask requires cuda extensions') 8 | 9 | 10 | def scale_mask_softmax(batch_size, batch_seq_len, head_num, src, seq_len_list): 11 | src = src.contiguous() 12 | dst = energonai_scale_mask.scale_mask_softmax_wrapper(batch_size, batch_seq_len, head_num, src, seq_len_list) 13 | return dst -------------------------------------------------------------------------------- /energonai/kernel/cuda_native/transpose_pad.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import importlib 3 | 4 | try: 5 | energonai_transpose_pad = importlib.import_module("energonai_transpose_pad") 6 | except ImportError: 7 | raise RuntimeError('transpose_pad requires cuda extensions') 8 | 9 | # from transpose import transpose_pad_wrapper, transpose_depad_wrapper 10 | 11 | 12 | def transpose_pad(src, batch_size, 
max_seq_len, seq_len_list, head_num, size_per_head): 13 | src = src.contiguous() 14 | 15 | dst = energonai_transpose_pad.transpose_pad_wrapper(src, batch_size, max_seq_len, seq_len_list, head_num, 16 | size_per_head) 17 | 18 | return dst 19 | 20 | 21 | def transpose_depad(src, batch_size, sum_seq, max_seq_len, seq_len_list, head_num, size_per_head): 22 | src = src.contiguous() 23 | 24 | dst = energonai_transpose_pad.transpose_depad_wrapper(src, batch_size, sum_seq, max_seq_len, seq_len_list, head_num, 25 | size_per_head) 26 | 27 | return dst 28 | 29 | 30 | def depad(src, batch_size, seq_lens): 31 | dst = src[0:1, 0:seq_lens[0], :] 32 | 33 | for i in range(1, batch_size): 34 | tlen = seq_lens[i] 35 | dst = torch.cat([dst, src[i:i + 1, 0:tlen, :]], dim=1) 36 | 37 | return dst 38 | 39 | 40 | # From FasterTransformer 41 | 42 | 43 | def ft_build_padding_offsets(seq_lens, batch_size, max_seq_len, valid_word_num, tmp_mask_offset): 44 | seq_lens = seq_lens.contiguous() 45 | # tmp_mask_offset = tmp_mask_offset.contiguous() 46 | 47 | energonai_transpose_pad.ft_build_padding_offsets_wrapper(seq_lens, batch_size, max_seq_len, valid_word_num, 48 | tmp_mask_offset) 49 | 50 | 51 | def ft_remove_padding(src, tmp_mask_offset, mask_offset, valid_word_num, hidden_dim): 52 | src = src.contiguous() 53 | # tmp_mask_offset = tmp_mask_offset.contiguous() 54 | # mask_offset = mask_offset.contiguous() 55 | 56 | dst = energonai_transpose_pad.ft_remove_padding_wrapper(src, tmp_mask_offset, mask_offset, valid_word_num, hidden_dim) 57 | return dst 58 | 59 | 60 | def ft_rebuild_padding(src, mask_offset, valid_word_num, hidden_dim, batch_size, max_seq_len): 61 | src = src.contiguous() 62 | # mask_offset = mask_offset.contiguous() 63 | 64 | dst = energonai_transpose_pad.ft_rebuild_padding_wrapper(src, mask_offset, valid_word_num, hidden_dim, batch_size, 65 | max_seq_len) 66 | return dst 67 | 68 | 69 | def ft_transpose_rebuild_padding(Q, K, V, q_buf, k_buf, v_buf, batch_size, seq_len, head_num, size_per_head, 70 | valid_word_num, mask_offset): 71 | Q = Q.contiguous() 72 | K = K.contiguous() 73 | V = V.contiguous() 74 | q_buf = q_buf.contiguous() 75 | k_buf = k_buf.contiguous() 76 | v_buf = v_buf.contiguous() 77 | 78 | energonai_transpose_pad.ft_transpose_rebuild_padding_wrapper(Q, K, V, q_buf, k_buf, v_buf, batch_size, seq_len, 79 | head_num, size_per_head, valid_word_num, mask_offset) 80 | 81 | 82 | def ft_transpose_remove_padding(src, valid_word_num, batch_size, seq_len, head_num, size_per_head, mask_offset): 83 | src = src.contiguous() 84 | 85 | dst = energonai_transpose_pad.ft_transpose_remove_padding_wrapper(src, valid_word_num, batch_size, seq_len, head_num, 86 | size_per_head, mask_offset) 87 | return dst 88 | -------------------------------------------------------------------------------- /energonai/legacy_batch_mgr/__init__.py: -------------------------------------------------------------------------------- 1 | from .worker_server import launch_worker 2 | -------------------------------------------------------------------------------- /energonai/legacy_batch_mgr/naive_batch_manager.py: -------------------------------------------------------------------------------- 1 | """ 2 | ------------------------------------------ 3 | Class Batch Manager. 4 | a naive version that is used for cases in which padding is not needed. 
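
A minimal wiring sketch, assuming a local Redis server on port 6379 and a populated
MEATCONFIG; my_forward_func and my_result_process below are hypothetical placeholders,
not functions taken from this repository:

    import time
    import torch
    from energonai.legacy_batch_mgr.naive_batch_manager import Naive_Batch_Manager

    def my_forward_func(input_list):
        # hypothetical: run inference and return an object whose .to_here() yields predictions
        ...

    def my_result_process(prediction):
        # hypothetical post-processing of a single prediction
        return str(prediction)

    manager = Naive_Batch_Manager(forward_func=my_forward_func, result_process=my_result_process)
    ts = time.time()
    dummy_input = {'input_ids': torch.randint(0, 100, (1, 8))}
    manager.insert_req(ts, dummy_input, 'a dummy prompt')
    result = manager.subscribe_result(ts)  # blocks until the result is published via Redis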
5 | ------------------------------------------ 6 | """ 7 | import time 8 | import redis 9 | from energonai.context import MEATCONFIG 10 | import threading 11 | from readerwriterlock import rwlock 12 | import logging 13 | from concurrent.futures import ThreadPoolExecutor 14 | 15 | 16 | class single_request: 17 | 18 | def __init__(self, input_, time_stamp: float, input_str: str): 19 | """ 20 | Class to store related information for a single request. 21 | :param input_: The output of GPT2Tokenizer.tokenizer, a dict including input_ids and attention_mask 22 | :param time_stamp: The time stamp when we receive the request. We use the time stamp as an index to 23 | identify the request. 24 | :param input_str: The input string of the request. 25 | """ 26 | self.input_ = input_ 27 | self.text = input_str 28 | self.time_ = time_stamp 29 | self.seq_len = input_['input_ids'].shape[1] 30 | 31 | 32 | class Manager: 33 | """ 34 | Base class of batch manager. 35 | """ 36 | 37 | def __init__(self): 38 | pass 39 | 40 | def insert_req(self, time_stamp: float, input_ids, input_str: str): 41 | pass 42 | 43 | 44 | class Naive_Batch_Manager(Manager): 45 | """ 46 | This batch manager is mainly used for maintaining a queue of requests to be processed. The requests in the 47 | queue are wrapped into batches and then sent to the inference engine. 48 | """ 49 | 50 | def __init__(self, forward_func, 51 | result_process): 52 | """ 53 | :param forward_func: a function that runs a forward pass and returns an RPC ref. 54 | :param result_process: a function that processes the model output before the result is returned. 55 | """ 56 | super().__init__() 57 | self.req_list = [] 58 | self.max_batch_size = MEATCONFIG['max_batch_size'] 59 | self.max_sequence_length = MEATCONFIG['max_sequence_length'] 60 | self.req_list_lock = rwlock.RWLockFair() 61 | self.write_lock = self.req_list_lock.gen_wlock() 62 | self.running_flag = True 63 | self.publisher = redis.StrictRedis('localhost', 6379, charset="utf-8", decode_responses=True) 64 | self.max_workers = MEATCONFIG['pp_init_size'] + 2 65 | self.pool = ThreadPoolExecutor(max_workers=self.max_workers) 66 | self.working_workers = 0 67 | self.forward_func = forward_func 68 | self.result_process = result_process 69 | self.main_thread = threading.Thread(target=self.processing_batch) 70 | self.main_thread.start() 71 | 72 | def insert_req(self, time_stamp: float, input_ids, input_str: str): 73 | """ 74 | Build a single_request object from the input and insert it into the queue. 75 | """ 76 | tmp_req = single_request(input_ids, time_stamp, input_str) 77 | self.write_lock.acquire() 78 | self.req_list.append(tmp_req) 79 | self.write_lock.release() 80 | 81 | def subscribe_result(self, time_stamp): 82 | """ 83 | Wait for the inference result and send it back. 84 | """ 85 | sub = self.publisher.pubsub() 86 | sub.subscribe(str(time_stamp)) 87 | predictions = '' 88 | for message in sub.listen(): 89 | if message is not None and isinstance(message, dict): 90 | predictions = message.get('data') 91 | if not isinstance(predictions, int): 92 | break 93 | return predictions 94 | 95 | def wrap_batch(self): 96 | """ 97 | Wrap a batch of requests in their order of insertion.
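Illustration (a behavioural sketch, not a doctest from this repository): with
max_batch_size = 4 and req_list holding [r0, r1, r2, r3, r4, r5], one call returns
[r0, r1, r2, r3] and leaves [r4, r5] in the queue.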
98 | """ 99 | self.write_lock.acquire() 100 | result_batch = self.req_list[0:min(self.max_batch_size, len(self.req_list))] 101 | del self.req_list[0:min(self.max_batch_size, len(self.req_list))] 102 | self.write_lock.release() 103 | return result_batch 104 | 105 | def processing_batch(self): 106 | """ 107 | The background process that continuously calls wrap_batch, puts the batch into the inference engine, 108 | and starts new processes that wait for and publish the inference result. 109 | """ 110 | while self.running_flag: 111 | if (self.working_workers < self.max_workers) and (len(self.req_list) > 0): 112 | target_batch = self.wrap_batch() 113 | pad_len = max([p.seq_len for p in target_batch]) 114 | logging.info("A batch with {} requests and length of {} packed, in-batch length: {}".format( 115 | len(target_batch), pad_len, [p.seq_len for p in target_batch])) 116 | input_text = [i.text for i in target_batch] 117 | self.working_workers = self.working_workers + 1 118 | output_ = self.forward_func(input_list=input_text) 119 | self.pool.submit(self.publish_result, output_, target_batch) 120 | time.sleep(0.001) 121 | 122 | def publish_result(self, output, target_batch): 123 | """ 124 | Background process that waits for the inference result and uses the publisher of Redis to publish it to 125 | the waiting requests. 126 | :param output: the rpc reference of the inference result. 127 | :param target_batch: the input batch 128 | """ 129 | predictions = output.to_here() 130 | for i in range(len(target_batch)): 131 | temp_st = target_batch[i].time_ 132 | chosen_pred = predictions[i] 133 | result = self.result_process(chosen_pred) 134 | self.publisher.publish(str(temp_st), result) 135 | 136 | self.working_workers = self.working_workers - 1 137 | 138 | -------------------------------------------------------------------------------- /energonai/model/__init__.py: -------------------------------------------------------------------------------- 1 | from .model_factory import gpt2_small, gpt2_large, gpt2_8B, gpt3 2 | from .model_factory import hf_gpt2 3 | from .model_factory import bert_small, bert_large, bert_8B, bert_175B 4 | from .model_factory import opt_125M, opt_6B, opt_30B, opt_66B, opt_175B 5 | -------------------------------------------------------------------------------- /energonai/model/attention.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn, dtype 3 | 4 | from colossalai.nn.layer.utils import divide 5 | from colossalai.nn import Linear1D_Col, Linear1D_Row 6 | from colossalai.utils import get_current_device 7 | 8 | 9 | class MultiHeadAttention1D(nn.Module): 10 | def __init__(self, 11 | hidden_size: int, 12 | num_heads: int, 13 | bias: bool = True, 14 | dtype: dtype = torch.float16, 15 | max_seq_len: int = 512, 16 | fused_qkv: bool = True, 17 | is_decoder: bool = True, 18 | disable_past_cache=False 19 | ) -> None: 20 | super().__init__() 21 | 22 | self.hidden_size = hidden_size 23 | self.attention_head_size = divide(hidden_size, num_heads) 24 | self.fused_qkv = fused_qkv 25 | self.is_decoder = is_decoder 26 | self.disable_past_cache = disable_past_cache 27 | self.scaling = self.attention_head_size**-0.5 28 | if fused_qkv: 29 | self.query_key_value = Linear1D_Col(hidden_size, 3 * hidden_size, bias=bias, dtype=dtype) 30 | else: 31 | self.query_ = Linear1D_Col(hidden_size, hidden_size, bias=bias, dtype=dtype) 32 | self.key_ = Linear1D_Col(hidden_size, hidden_size, bias=bias, dtype=dtype) 33 | self.value_ = 
Linear1D_Col(hidden_size, hidden_size, bias=bias, dtype=dtype) 34 | 35 | self.softmax = nn.Softmax(dim=-1) 36 | 37 | self.dense = Linear1D_Row(hidden_size, hidden_size, bias=True, dtype=dtype, parallel_input=True) 38 | 39 | if is_decoder: 40 | self.causal_mask = torch.tril(torch.ones((max_seq_len, max_seq_len), dtype=torch.uint8, 41 | device=get_current_device())).view(1, 1, max_seq_len, max_seq_len).bool() 42 | self.causal_mask_bias = torch.tensor(-1e4, dtype=dtype, device=get_current_device()) 43 | 44 | self.past_cache = {} 45 | 46 | def _split_heads(self, tensor, num_heads, attn_head_size): 47 | new_shape = tensor.size()[:-1] + (num_heads, attn_head_size) 48 | tensor = tensor.view(new_shape) 49 | return tensor.permute(0, 2, 1, 3) 50 | 51 | def last_word(self, hidden_states): 52 | batch_size = hidden_states.shape[0] 53 | hidden_size = hidden_states.shape[2] 54 | return hidden_states[:, -1, :].view(batch_size, 1, hidden_size) 55 | 56 | def forward(self, 57 | hidden_states, 58 | attention_mask=None, 59 | first_cache=False, 60 | seq_lens=None): 61 | if self.fused_qkv: 62 | if self.disable_past_cache: 63 | kvq = self.query_key_value(hidden_states) 64 | else: 65 | if first_cache: 66 | kvq = self.query_key_value(hidden_states) 67 | self.past_cache['query_key_value'] = kvq 68 | else: 69 | kvq = self.query_key_value(self.last_word(hidden_states)) 70 | self.past_cache['query_key_value'] = torch.cat((self.past_cache['query_key_value'], kvq), 1) 71 | kvq = self.past_cache['query_key_value'] 72 | all_head_size = kvq.shape[-1] // 3 73 | num_attention_heads = divide(all_head_size, self.attention_head_size) 74 | # kvq = self._split_heads(kvq, num_attention_heads, 3 * self.attention_head_size) 75 | k, v, q = [t.contiguous() for t in torch.chunk(kvq, 3, dim=-1)] 76 | else: 77 | if self.disable_past_cache: 78 | q = self.query_(hidden_states) 79 | k = self.key_(hidden_states) 80 | v = self.value_(hidden_states) 81 | else: 82 | if first_cache: 83 | q = self.query_(hidden_states) 84 | k = self.key_(hidden_states) 85 | v = self.value_(hidden_states) 86 | self.past_cache['q'] = q 87 | self.past_cache['k'] = k 88 | self.past_cache['v'] = v 89 | else: 90 | q = self.query_(self.last_word(hidden_states)) 91 | k = self.key_(self.last_word(hidden_states)) 92 | v = self.value_(self.last_word(hidden_states)) 93 | self.past_cache['q'] = torch.cat((self.past_cache['q'], q), 1) 94 | self.past_cache['k'] = torch.cat((self.past_cache['k'], k), 1) 95 | self.past_cache['v'] = torch.cat((self.past_cache['v'], v), 1) 96 | q = self.past_cache['q'] 97 | k = self.past_cache['k'] 98 | v = self.past_cache['v'] 99 | all_head_size = q.shape[-1] 100 | num_attention_heads = divide(all_head_size, self.attention_head_size) 101 | q = self._split_heads(q, num_attention_heads, self.attention_head_size) 102 | k = self._split_heads(k, num_attention_heads, self.attention_head_size) 103 | v = self._split_heads(v, num_attention_heads, self.attention_head_size) 104 | 105 | q *= self.scaling 106 | hidden_states = torch.matmul(q, k.transpose(-1, -2)) 107 | 108 | q_len, k_len = q.size(-2), k.size(-2) 109 | 110 | if self.is_decoder: 111 | hidden_states = torch.where(self.causal_mask[:, :, 0:q_len, 0:k_len], hidden_states, self.causal_mask_bias) 112 | 113 | if attention_mask is not None: 114 | hidden_states = hidden_states + attention_mask 115 | dtype = hidden_states.dtype 116 | hidden_states = torch.softmax(hidden_states, -1, dtype=torch.float).to(dtype) 117 | 118 | hidden_states = torch.matmul(hidden_states, v) 119 | 120 | hidden_states = 
hidden_states.transpose(1, 2) 121 | 122 | new_context_layer_shape = hidden_states.size()[:-2] + (all_head_size,) 123 | 124 | hidden_states = hidden_states.reshape(new_context_layer_shape) 125 | 126 | if self.disable_past_cache: 127 | hidden_states = self.dense(hidden_states) 128 | else: 129 | if first_cache: 130 | hidden_states = self.dense(hidden_states) 131 | self.past_cache['dense'] = hidden_states 132 | else: 133 | hidden_states = self.dense(self.last_word(hidden_states)) 134 | self.past_cache['dense'] = torch.cat((self.past_cache['dense'], hidden_states), 1) 135 | hidden_states = self.past_cache['dense'] 136 | return hidden_states 137 | -------------------------------------------------------------------------------- /energonai/model/downstream.py: -------------------------------------------------------------------------------- 1 | from torch import dtype, nn 2 | from colossalai.nn import Classifier1D, VocabParallelClassifier1D 3 | 4 | 5 | class LMHead1D(nn.Module): 6 | def __init__(self, 7 | hidden_size: int, 8 | vocab_size: int, 9 | word_embedding_weight: nn.Parameter = None, 10 | bias: bool = False, 11 | dtype: dtype = None, 12 | vocab_parallel: bool = False) -> None: 13 | super().__init__() 14 | self.vocab_parallel = vocab_parallel 15 | if vocab_parallel: 16 | self.dense = VocabParallelClassifier1D(hidden_size, vocab_size, bias=bias, dtype=dtype, gather_output=True) 17 | else: 18 | self.dense = Classifier1D(hidden_size, vocab_size, word_embedding_weight, bias=bias, dtype=dtype) 19 | 20 | @property 21 | def weight(self): 22 | return self.dense.weight 23 | 24 | def forward(self, x): 25 | x = self.dense(x) 26 | return x 27 | -------------------------------------------------------------------------------- /energonai/model/embedding.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn as nn 3 | from torch import dtype 4 | from colossalai.nn import VocabParallelEmbedding1D 5 | from torch.nn import Embedding 6 | from colossalai.utils import get_current_device 7 | 8 | 9 | class Embedding1D(nn.Module): 10 | def __init__(self, 11 | hidden_size: int, 12 | vocab_size: int, 13 | max_seq_len: int, 14 | num_tokentypes: int = 0, 15 | padding_idx: int = 0, 16 | dtype: dtype = None, 17 | vocab_parallel: bool = False, 18 | ) -> None: 19 | super().__init__() 20 | if vocab_parallel: 21 | self.word_embeddings = VocabParallelEmbedding1D( 22 | vocab_size, hidden_size, padding_idx=padding_idx, dtype=dtype) 23 | else: 24 | self.word_embeddings = Embedding(vocab_size, hidden_size, padding_idx=padding_idx).to( 25 | dtype=dtype, device=get_current_device()) 26 | 27 | self.position_embeddings = Embedding(max_seq_len, hidden_size).to(dtype=dtype, device=get_current_device()) 28 | 29 | if num_tokentypes > 0: 30 | self.tokentype_embeddings = Embedding(num_tokentypes, hidden_size).to( 31 | dtype=dtype, device=get_current_device()) 32 | else: 33 | self.tokentype_embeddings = None 34 | 35 | # self.position_ids = torch.arange(max_seq_len, dtype=torch.long, device=get_current_device()).expand((1, -1)) 36 | 37 | @property 38 | def word_embedding_weight(self): 39 | return self.word_embeddings.weight 40 | 41 | def forward(self, 42 | input_ids, 43 | position_ids=None, 44 | tokentype_ids=None, 45 | past_key_values_length: int = 0): 46 | 47 | seq_length = input_ids.size(1) 48 | if position_ids is None: 49 | position_ids = torch.arange(seq_length, dtype=torch.long, device=get_current_device()).unsqueeze(0) 50 | # position_ids = self.position_ids[:, 
past_key_values_length : seq_length + past_key_values_length] 51 | 52 | x = self.word_embeddings(input_ids) + self.position_embeddings(position_ids) 53 | 54 | if self.tokentype_embeddings is not None and tokentype_ids is not None: 55 | x = x + self.tokentype_embeddings(tokentype_ids) 56 | 57 | return x 58 | -------------------------------------------------------------------------------- /energonai/model/endecoder.py: -------------------------------------------------------------------------------- 1 | from typing import Callable 2 | import torch 3 | from torch import dtype 4 | from torch import nn 5 | from colossalai.nn import LayerNorm1D 6 | 7 | from .mlp import MLP1D 8 | from .attention import MultiHeadAttention1D 9 | 10 | 11 | class Block1D(nn.Module): 12 | def __init__(self, 13 | hidden_size: int, 14 | num_heads: int, 15 | mlp_ratio: float, 16 | activation: Callable = nn.functional.gelu, 17 | layernorm_epsilon: float = 1e-5, 18 | dtype: dtype = torch.float16, 19 | bias: bool = True, 20 | apply_post_layernorm: bool = False, 21 | max_seq_len: int = 512, 22 | fused_qkv: bool = True, 23 | is_decoder: bool = True, 24 | disable_past_cache=False) -> None: 25 | super().__init__() 26 | 27 | self.apply_post_layernorm = apply_post_layernorm 28 | self.norm1 = LayerNorm1D(hidden_size, eps=layernorm_epsilon, dtype=dtype) 29 | 30 | self.attn = MultiHeadAttention1D(hidden_size=hidden_size, 31 | num_heads=num_heads, 32 | bias=bias, 33 | dtype=dtype, 34 | max_seq_len=max_seq_len, 35 | fused_qkv=fused_qkv, 36 | is_decoder=is_decoder, 37 | disable_past_cache=disable_past_cache) 38 | 39 | self.norm2 = LayerNorm1D(hidden_size, eps=layernorm_epsilon, dtype=dtype) 40 | 41 | self.mlp = MLP1D(hidden_size=hidden_size, 42 | mlp_ratio=mlp_ratio, 43 | activation=activation, 44 | dtype=dtype, 45 | bias=bias, 46 | disable_past_cache=disable_past_cache) 47 | 48 | def forward(self, hidden_states, attention_mask=None, first_cache=False, seq_lens=None): 49 | 50 | if not self.apply_post_layernorm: 51 | residual = hidden_states 52 | hidden_states = self.norm1(hidden_states) 53 | 54 | if self.apply_post_layernorm: 55 | residual = hidden_states 56 | hidden_states = residual + self.attn(hidden_states=hidden_states, 57 | attention_mask=attention_mask, 58 | first_cache=first_cache) 59 | 60 | if not self.apply_post_layernorm: 61 | residual = hidden_states 62 | 63 | hidden_states = self.norm2(hidden_states) 64 | 65 | if self.apply_post_layernorm: 66 | residual = hidden_states 67 | hidden_states = residual + self.mlp(hidden_states=hidden_states, 68 | first_cache=first_cache) 69 | 70 | return hidden_states 71 | -------------------------------------------------------------------------------- /energonai/model/mlp.py: -------------------------------------------------------------------------------- 1 | 2 | from typing import Callable, Optional 3 | import torch 4 | from torch import dtype, nn 5 | from colossalai.nn import Linear1D_Col, Linear1D_Row 6 | 7 | 8 | class MLP1D(nn.Module): 9 | 10 | def __init__(self, 11 | hidden_size: int, 12 | mlp_ratio: float, 13 | activation: Callable, 14 | dtype: dtype = torch.float16, 15 | bias: bool = True, 16 | disable_past_cache=False): 17 | super().__init__() 18 | self.disable_past_cache = disable_past_cache 19 | intermediate_dim = int(hidden_size * mlp_ratio) 20 | self.dense_1 = Linear1D_Col(hidden_size, intermediate_dim, bias=bias, dtype=dtype, gather_output=False) 21 | self.activation = activation 22 | self.dense_2 = Linear1D_Row(intermediate_dim, hidden_size, bias=bias, dtype=dtype, 
parallel_input=True) 23 | self.past_cache = {} 24 | 25 | def last_word(self, hidden_states): 26 | batch_size = hidden_states.shape[0] 27 | hidden_size = hidden_states.shape[2] 28 | return hidden_states[:, -1, :].view(batch_size, 1, hidden_size) 29 | 30 | def forward(self, hidden_states, first_cache: Optional[bool] = True): 31 | 32 | if self.disable_past_cache: 33 | hidden_states = self.dense_1(hidden_states) 34 | hidden_states = self.activation(hidden_states) 35 | hidden_states = self.dense_2(hidden_states) 36 | else: 37 | if first_cache: 38 | hidden_states = self.dense_1(hidden_states) 39 | self.past_cache['dense_1'] = hidden_states 40 | hidden_states = self.activation(hidden_states) 41 | hidden_states = self.dense_2(hidden_states) 42 | self.past_cache['dense_2'] = hidden_states 43 | else: 44 | hidden_states = self.dense_1(self.last_word(hidden_states)) 45 | self.past_cache['dense_1'] = torch.cat((self.past_cache['dense_1'], hidden_states), 1) 46 | hidden_states = self.activation(self.past_cache['dense_1']) 47 | hidden_states = self.dense_2(self.last_word(hidden_states)) 48 | self.past_cache['dense_2'] = torch.cat((self.past_cache['dense_2'], hidden_states), 1) 49 | hidden_states = self.past_cache['dense_2'] 50 | 51 | return hidden_states 52 | -------------------------------------------------------------------------------- /energonai/pipe.py: -------------------------------------------------------------------------------- 1 | import torch.distributed.rpc as trpc 2 | import time 3 | from queue import Queue, Empty 4 | from typing import Dict 5 | from threading import Lock 6 | from typing import Any 7 | from .utils import use_lock 8 | 9 | 10 | def rpc_queue_can_put(q: trpc.RRef) -> bool: 11 | q = q.local_value() 12 | return not q.full() 13 | 14 | 15 | def rpc_queue_put(q: trpc.RRef, data: Any) -> None: 16 | q = q.local_value() 17 | q.put(data) 18 | 19 | 20 | class Pipe: 21 | _queues: Dict[str, Queue] = {} 22 | _lock = Lock() 23 | 24 | def __init__(self, name: str, src: str, dest: str, max_size: int = 0) -> None: 25 | self.rpc_info = trpc.get_worker_info() 26 | self.name = name 27 | self.src = src 28 | self.dest = dest 29 | self.remote_queue: trpc.RRef = None 30 | self.local_queue: Queue[Any] = None 31 | with use_lock(self._lock): 32 | if src == self.rpc_info.name: 33 | assert name not in self._queues, f'pipe {name} already exists on {self.rpc_info.name}' 34 | self.remote_queue = self.get_remote_queue(max_size) 35 | self._queues[name] = self.remote_queue 36 | 37 | @classmethod 38 | def rpc_create_local_queue(cls, name: str, max_size: int) -> Queue: 39 | with use_lock(cls._lock): 40 | assert name not in cls._queues, f'pipe {name} already exists' 41 | cls._queues[name] = Queue(max_size) 42 | return cls._queues[name] 43 | 44 | def get_remote_queue(self, max_size: int) -> trpc.RRef: 45 | return trpc.remote(self.dest, self.rpc_create_local_queue, args=(self.name, max_size)) 46 | 47 | def prepare_local_queue(self) -> None: 48 | if self.local_queue is None: 49 | with use_lock(self._lock): 50 | if self.name in self._queues: 51 | self.local_queue = self._queues[self.name] 52 | 53 | def recv(self) -> Any: 54 | assert self.dest == self.rpc_info.name 55 | while True: 56 | self.prepare_local_queue() 57 | if self.local_queue is not None: 58 | return self.local_queue.get() 59 | time.sleep(0.01) 60 | 61 | def recv_nowait(self) -> Any: 62 | assert self.dest == self.rpc_info.name 63 | self.prepare_local_queue() 64 | if self.local_queue is not None: 65 | try: 66 | return self.local_queue.get_nowait() 67 | 
except Empty: 68 | raise RuntimeError('pipe is empty') 69 | raise RuntimeError('local queue is not created') 70 | 71 | def send(self, data: Any) -> None: 72 | assert self.src == self.rpc_info.name 73 | while not trpc.rpc_sync(self.dest, rpc_queue_can_put, args=(self.remote_queue, )): 74 | time.sleep(0.1) 75 | trpc.rpc_sync(self.dest, rpc_queue_put, args=(self.remote_queue, data)) 76 | -------------------------------------------------------------------------------- /energonai/pipelinable/__init__.py: -------------------------------------------------------------------------------- 1 | from .split_method import split_transformer_into_partitions 2 | -------------------------------------------------------------------------------- /energonai/pipelinable/energon_tracer.py: -------------------------------------------------------------------------------- 1 | import torch.fx 2 | from energonai.context import MEATCONFIG 3 | 4 | class EnergonTracer(torch.fx.Tracer): 5 | def is_leaf_module(self, m: torch.nn.Module, module_qualified_name: str) -> bool: 6 | leaves = MEATCONFIG["LeafSet"] # set([BertTransformerLayer]) 7 | return type(m) in leaves -------------------------------------------------------------------------------- /energonai/pipelinable/split_method.py: -------------------------------------------------------------------------------- 1 | from torch.fx.passes.split_module import split_module 2 | from .split_policy import module_equal_partition, naive_equal_partition, transformer_partition 3 | from .energon_tracer import EnergonTracer 4 | import torch.fx 5 | 6 | def filter_graph(traced: torch.fx.GraphModule, filter_type: str): 7 | len = 0 8 | for node in traced.graph.nodes: 9 | if node.op == filter_type: 10 | len = len + 1 11 | return len 12 | 13 | 14 | def split_transformer_into_partitions(model_class): 15 | model = model_class() 16 | graph = EnergonTracer().trace(model) 17 | traced = torch.fx.GraphModule(model, graph) 18 | depth = filter_graph(traced, "call_module") - 1 19 | submodules = split_module(traced, model, transformer_partition(depth)) 20 | del model 21 | 22 | return submodules -------------------------------------------------------------------------------- /energonai/pipelinable/split_policy.py: -------------------------------------------------------------------------------- 1 | import functools 2 | from torch.fx.node import Node 3 | from energonai.context import MEATCONFIG 4 | 5 | 6 | partition_counter_0 = 0 7 | 8 | # partition_nums: nums of each submodule 9 | def _naive_equal_partition(node: Node, partition_nums): 10 | global partition_counter_0 11 | partition = partition_counter_0 // partition_nums 12 | partition_counter_0 = partition_counter_0 + 1 13 | return partition 14 | 15 | def naive_equal_partition(partition_nums): 16 | mod_partition = functools.partial(_naive_equal_partition, partition_nums = partition_nums) 17 | return mod_partition 18 | 19 | partition_counter_1 = 0 20 | 21 | # partition_nums: nums of each submodule 22 | def _module_equal_partition(node: Node, partition_nums): 23 | global partition_counter_1 24 | partition = partition_counter_1 // partition_nums 25 | if node.op == 'call_module': 26 | partition_counter_1 = partition_counter_1 + 1 27 | return partition 28 | 29 | def module_equal_partition(partition_nums): 30 | mod_partition = functools.partial(_module_equal_partition, partition_nums = partition_nums) 31 | return mod_partition 32 | 33 | 34 | 35 | from colossalai.core import global_context as gpc 36 | from colossalai.context import ParallelMode 37 | 
partition_counter_2 = -1 # for embedding layer 38 | # partition_nums: nums of each submodule 39 | def _transformer_partition(node: Node, depth): 40 | global partition_counter_2 41 | assert gpc.is_initialized(ParallelMode.PIPELINE), "Pipeline communication group should be initialized!" 42 | 43 | pipeline_size = gpc.get_world_size(ParallelMode.PIPELINE) 44 | partition_nums = depth // pipeline_size 45 | partition = abs(partition_counter_2) // partition_nums 46 | if node.op == 'call_module': 47 | partition_counter_2 = partition_counter_2 + 1 48 | return partition 49 | 50 | def transformer_partition(depth): 51 | mod_partition = functools.partial(_transformer_partition, depth = depth) 52 | return mod_partition -------------------------------------------------------------------------------- /energonai/task.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Hashable, Tuple, Any 3 | 4 | 5 | @dataclass 6 | class TaskEntry: 7 | uids: Tuple[Hashable, ...] 8 | batch: Any 9 | -------------------------------------------------------------------------------- /energonai/testing/__init__.py: -------------------------------------------------------------------------------- 1 | from .models import BoringModel, get_correct_output 2 | -------------------------------------------------------------------------------- /energonai/testing/models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from colossalai.nn import Linear1D_Col, Linear1D_Row 4 | from colossalai.core import global_context as gpc 5 | from colossalai.context import ParallelMode 6 | from colossalai.utils import is_using_pp 7 | 8 | 9 | class BoringModel(nn.Module): 10 | def __init__(self) -> None: 11 | super().__init__() 12 | if is_using_pp(): 13 | pp_rank = gpc.get_local_rank(ParallelMode.PIPELINE) 14 | gather_output = pp_rank == gpc.get_world_size(ParallelMode.PIPELINE) - 1 15 | else: 16 | pp_rank = 0 17 | gather_output = True 18 | if pp_rank % 2 == 0: 19 | self.dense = Linear1D_Col(4, 4, gather_output=gather_output) 20 | else: 21 | self.dense = Linear1D_Row(4, 4) 22 | self._init_weights() 23 | 24 | def _init_weights(self): 25 | with torch.no_grad(): 26 | self.dense.weight.fill_(1.0) 27 | self.dense.bias.fill_(1.0) 28 | 29 | def forward(self, x): 30 | return self.dense(x) 31 | 32 | 33 | def get_correct_output(x: torch.Tensor, pp_world_size: int) -> torch.Tensor: 34 | def step(t): 35 | return t * 4 + 1 36 | for _ in range(pp_world_size): 37 | x = step(x) 38 | return x 39 | -------------------------------------------------------------------------------- /energonai/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .files import ensure_directory_exists 2 | from .timer import get_timers 3 | from .common import build_device_maps, use_lock, run_once, Terminator 4 | -------------------------------------------------------------------------------- /energonai/utils/checkpointing.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | import torch 4 | import torch.distributed as dist 5 | 6 | from colossalai.utils import is_using_pp 7 | from colossalai.context import ParallelMode 8 | from colossalai.core import global_context as gpc 9 | from typing import Optional, Callable 10 | from colossalai.utils.checkpointing import 
partition_pipeline_parallel_state_dict, broadcast_model 11 | 12 | 13 | __all__ = [ 14 | "load_checkpoint", "load_state_dict" 15 | ] 16 | 17 | import os 18 | from multiprocessing import Pool 19 | from time import time 20 | 21 | 22 | def load_state_dict(path: str): 23 | if os.path.isfile(path): 24 | return torch.load(path) 25 | assert os.path.isdir(path) 26 | state_dict = {} 27 | files = [] 28 | for filename in os.listdir(path): 29 | filepath = os.path.join(path, filename) 30 | if os.path.isfile(filepath): 31 | files.append(filepath) 32 | procs = int(os.environ.get('LOAD_N_PROC', '1')) 33 | procs = min(procs, len(files)) 34 | print(f'load {len(files)} files using {procs} procs') 35 | if procs > 1: 36 | with Pool(procs) as pool: 37 | state_dicts = pool.map(torch.load, files) 38 | for sd in state_dicts: 39 | state_dict.update(sd) 40 | else: 41 | for filepath in files: 42 | sd = torch.load(filepath) 43 | state_dict.update(sd) 44 | return state_dict 45 | 46 | 47 | def remove_prefix(state_dict, prefix): 48 | if prefix[-1] != '.': 49 | prefix += '.' 50 | res_dict = OrderedDict() 51 | for k_ in state_dict.keys(): 52 | res_dict[k_.replace(prefix, '')] = state_dict[k_] 53 | return res_dict 54 | 55 | 56 | def load_checkpoint(file, 57 | model: torch.nn.Module, 58 | strict: bool = True, 59 | preprocess_fn: Optional[Callable[[dict], dict]] = None, 60 | **kwargs): 61 | """Loads training states from a checkpoint file. 62 | 63 | Args: 64 | file: a file-like object (has to implement read(), readline(), tell(), and seek()), or a string or os.PathLike 65 | object containing a file name. 66 | model (:class:`torch.nn.Module`): Model to load saved weights and buffers. 67 | optimizer (Union[:class:`torch.optim.Optimizer`, :class:`colossalai.nn.optimizer`]): Optimizer to recuperate. 68 | lr_scheduler (:class:`torch.optim.lr_scheduler._LRScheduler`, optional): 69 | lr_scheduler to recuperate, defaults to None. 70 | strict (bool, optional): Whether to strictly enforce that the keys in :attr:`state_dict` 71 | of the checkpoint match the names of parameters and buffers in model, defaults to True. 72 | 73 | Returns: 74 | int: The saved epoch number. 
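
    Example (a sketch; the checkpoint path is a placeholder and ``model`` is assumed to be
    built from this repository's parallel layers inside an initialized Colossal-AI context):

        from energonai.utils.checkpointing import load_checkpoint

        load_checkpoint('/path/to/checkpoint.pt', model, strict=False, prefix='')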
75 | 76 | Raises: 77 | RuntimeError: Raise error if the model/optimizer cannot successfully be recuperated 78 | """ 79 | start = time() 80 | if gpc.get_local_rank(ParallelMode.MODEL) == 0: 81 | model_state = load_state_dict(file) 82 | if preprocess_fn: 83 | model_state = preprocess_fn(model_state) 84 | else: 85 | model_state = dict() 86 | dist.barrier() 87 | print(f'Load file time: {time()-start:.3f} s') 88 | # pipeline 89 | if is_using_pp(): 90 | model_state = partition_pipeline_parallel_state_dict(model, model_state, **kwargs) 91 | if "prefix" in kwargs.keys(): 92 | if kwargs['prefix'] != '': 93 | model_state = remove_prefix(model_state, kwargs["prefix"]) 94 | 95 | model.load_state_dict(model_state, strict=strict) 96 | broadcast_model(model) 97 | 98 | return -1 99 | -------------------------------------------------------------------------------- /energonai/utils/checkpointing_hf_gpt2.py: -------------------------------------------------------------------------------- 1 | import re 2 | from collections import OrderedDict 3 | import torch 4 | 5 | 6 | __all__ = [ 7 | 'processing_HF_GPT' 8 | ] 9 | 10 | name_map = { 11 | 'ln_2': 'norm2', 12 | 'c_attn': 'query_key_value', 13 | 'attn.c_proj': 'attn.dense', 14 | 'ln_1': 'norm1', 15 | 'c_fc': 'dense_1', 16 | 'mlp.c_proj': 'mlp.dense_2' 17 | } 18 | 19 | 20 | def judge_t(key_): 21 | key_words = ['attn.query_key_value.weight', 'mlp.dense_1.weight', 'mlp.dense_2.weight', 'attn.dense.weight'] 22 | for word_ in key_words: 23 | if word_ in key_: 24 | return True 25 | return False 26 | 27 | 28 | def processing_HF_GPT(state_dict: OrderedDict): 29 | if 'model' in state_dict: 30 | state_dict = state_dict.pop('model') 31 | new_dict = OrderedDict() 32 | for k_ in state_dict.keys(): 33 | new_k = module_name_mapping(k_) 34 | if new_k == "": 35 | continue 36 | 37 | new_v = state_dict[k_] 38 | if judge_t(new_k): 39 | new_v = torch.transpose(new_v, 0, 1) 40 | if "attn.query_key_value.weight" in new_k: 41 | num_ = re.search(r"blocks\.\d+?\.", new_k) 42 | if num_: 43 | prefix = num_.group() 44 | else: 45 | prefix = '' 46 | # print("prefix: {}".format(prefix)) 47 | q_, k_, v_ = torch.chunk(new_v, 3, 0) 48 | # new_dict[prefix + "attn.query_.weight"] = torch.transpose(q_, 0, 1) 49 | # new_dict[prefix + "attn.key_.weight"] = torch.transpose(k_, 0, 1) 50 | # new_dict[prefix + "attn.value_.weight"] = torch.transpose(v_, 0, 1) 51 | new_dict[prefix + "attn.query_.weight"] = q_ 52 | new_dict[prefix + "attn.key_.weight"] = k_ 53 | new_dict[prefix + "attn.value_.weight"] = v_ 54 | elif "attn.query_key_value.bias" in new_k: 55 | num_ = re.search(r"blocks\.\d+?\.", new_k) 56 | if num_: 57 | prefix = num_.group() 58 | else: 59 | prefix = '' 60 | # print("prefix: {}".format(prefix)) 61 | q_, k_, v_ = torch.chunk(new_v, 3, 0) 62 | new_dict[prefix + "attn.query_.bias"] = q_ 63 | new_dict[prefix + "attn.key_.bias"] = k_ 64 | new_dict[prefix + "attn.value_.bias"] = v_ 65 | else: 66 | new_dict[new_k] = new_v 67 | new_dict['head.dense.weight'] = new_dict['embed.word_embeddings.weight'].clone() 68 | # print("="*100) 69 | # print(new_dict.keys()) 70 | return {"model": new_dict, "epoch": 0} 71 | 72 | 73 | def id_map(matched): 74 | value = matched.group('value') 75 | return "blocks.{}.".format(value) 76 | 77 | 78 | def module_name_mapping(ori_name: str): 79 | if ori_name == 'wte.weight': 80 | return "embed.word_embeddings.weight" 81 | elif ori_name == 'wpe.weight': 82 | return "embed.position_embeddings.weight" 83 | elif "ln_f" in ori_name: 84 | return ori_name.replace('ln_f', 'norm') 85 | 
elif ".attn.bias" in ori_name: 86 | return "" 87 | else: 88 | res = re.sub(r"h\.(?P\d+)?\.", id_map, ori_name) 89 | for k_ in name_map.keys(): 90 | res = res.replace(k_, name_map[k_]) 91 | return res 92 | -------------------------------------------------------------------------------- /energonai/utils/checkpointing_opt.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | from collections import OrderedDict 4 | from typing import Dict 5 | 6 | import torch 7 | from colossalai.context import ParallelMode 8 | from colossalai.core import global_context as gpc 9 | 10 | __all__ = [ 11 | 'processing_OPT' 12 | ] 13 | 14 | name_map = { 15 | 'embed_tokens': 'embed.word_embeddings', 16 | 'embed_positions': 'embed.position_embeddings', 17 | # 'layers': 'blocks', 18 | 'self_attn.q_proj': 'attn.query_', 19 | 'self_attn.k_proj': 'attn.key_', 20 | 'self_attn.v_proj': 'attn.value_', 21 | 'self_attn.out_proj': 'attn.dense', 22 | 'self_attn_layer_norm': 'norm1.module', 23 | 'final_layer_norm': 'norm2.module', 24 | 'fc1': 'mlp.dense_1', 25 | 'fc2': 'mlp.dense_2' 26 | } 27 | 28 | 29 | def judge_t(key_): 30 | key_words = ['attn.query_key_value.weight', 'mlp.dense_1.weight', 'mlp.dense_2.weight', 'attn.dense.weight'] 31 | for word_ in key_words: 32 | if word_ in key_: 33 | return True 34 | return False 35 | 36 | 37 | def module_name_mapping(ori_name: str): 38 | # print(ori_name) 39 | if ori_name == 'decoder.embed_tokens.weight': 40 | return "embed.word_embeddings.weight" 41 | elif ori_name == 'decoder.embed_positions.weight': 42 | return "embed.position_embeddings.weight" 43 | elif "decoder.layer_norm" in ori_name: 44 | return ori_name.replace('decoder.layer_norm', 'norm.module') 45 | elif "decoder.final_layer_norm" in ori_name: # hugging face style 46 | return ori_name.replace('decoder.final_layer_norm', 'norm.module') 47 | # elif ".attn.bias" in ori_name: 48 | # return "" 49 | else: 50 | res = re.sub(r"decoder.layers\.(?P\d+)?\.", id_map, ori_name) 51 | for k_ in name_map.keys(): 52 | res = res.replace(k_, name_map[k_]) 53 | return res 54 | 55 | 56 | def processing_OPT(state_dict: OrderedDict): 57 | if 'model' in state_dict: 58 | state_dict = state_dict.pop('model') 59 | new_dict = OrderedDict() 60 | for k_ in state_dict.keys(): 61 | new_k = module_name_mapping(k_) 62 | if new_k == "": 63 | continue 64 | new_v = state_dict[k_] 65 | new_dict[new_k] = new_v 66 | # if judge_t(new_k): 67 | # new_v = torch.transpose(new_v, 0, 1) 68 | # if "attn.query_key_value.weight" in new_k: 69 | # num_ = re.search(r"blocks\.\d+?\.", new_k) 70 | # if num_: 71 | # prefix = num_.group() 72 | # else: 73 | # prefix = '' 74 | # # print("prefix: {}".format(prefix)) 75 | # q_, k_, v_ = torch.chunk(new_v, 3, 0) 76 | # # new_dict[prefix + "attn.query_.weight"] = torch.transpose(q_, 0, 1) 77 | # # new_dict[prefix + "attn.key_.weight"] = torch.transpose(k_, 0, 1) 78 | # # new_dict[prefix + "attn.value_.weight"] = torch.transpose(v_, 0, 1) 79 | # new_dict[prefix + "attn.query_.weight"] = q_ 80 | # new_dict[prefix + "attn.key_.weight"] = k_ 81 | # new_dict[prefix + "attn.value_.weight"] = v_ 82 | # elif "attn.query_key_value.bias" in new_k: 83 | # num_ = re.search(r"blocks\.\d+?\.", new_k) 84 | # if num_: 85 | # prefix = num_.group() 86 | # else: 87 | # prefix = '' 88 | # # print("prefix: {}".format(prefix)) 89 | # q_, k_, v_ = torch.chunk(new_v, 3, 0) 90 | # new_dict[prefix + "attn.query_.bias"] = q_ 91 | # new_dict[prefix + "attn.key_.bias"] = k_ 92 | # new_dict[prefix + 
"attn.value_.bias"] = v_ 93 | # else: 94 | # new_dict[new_k] = new_v 95 | # print(new_dict.keys()) 96 | if 'head.dense.weight' not in new_dict: 97 | new_dict['head.dense.weight'] = new_dict['embed.word_embeddings.weight'].clone() 98 | 99 | if 'decoder.version' in new_dict: 100 | del new_dict['decoder.version'] 101 | # print("="*100) 102 | # print(new_dict.keys()) 103 | # print("---------------------------") 104 | return new_dict # {"model": new_dict, "epoch": 0} 105 | 106 | 107 | def id_map(matched): 108 | value = matched.group('value') 109 | return "blocks.{}.".format(value) 110 | 111 | 112 | def preprocess_175b(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: 113 | key_map = { 114 | 'decoder.embed_tokens.weight': 'embed.word_embeddings.weight', 115 | 'decoder.embed_positions.weight': 'embed.position_embeddings.weight', 116 | 'decoder.layer_norm': 'norm', 117 | 'decoder.layers': 'blocks', 118 | 'self_attn.qkv_proj': 'attn.query_key_value', 119 | 'self_attn.out_proj': 'attn.dense', 120 | 'self_attn_layer_norm': 'norm1', 121 | 'final_layer_norm': 'norm2', 122 | 'fc1': 'mlp.dense_1', 123 | 'fc2': 'mlp.dense_2' 124 | } 125 | output_sd = {} 126 | for k, v in state_dict.items(): 127 | new_key = k 128 | for old, new in key_map.items(): 129 | new_key = new_key.replace(old, new) 130 | output_sd[new_key] = v 131 | output_sd['head.dense.weight'] = output_sd['embed.word_embeddings.weight'].clone() 132 | return output_sd 133 | 134 | 135 | def load_175b(checkpoint_dir: str, model: torch.nn.Module) -> None: 136 | tp_rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D) 137 | checkpoint_path = os.path.join(checkpoint_dir, f'reshard-model_part-{tp_rank}.pt') 138 | print(f'Rank{gpc.get_global_rank()} load {checkpoint_path}') 139 | state_dict = torch.load(checkpoint_path) 140 | state_dict = preprocess_175b(state_dict) 141 | for n, p in model.named_parameters(): 142 | with torch.no_grad(): 143 | p.copy_(state_dict[n]) 144 | -------------------------------------------------------------------------------- /energonai/utils/common.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import functools 3 | import signal 4 | from typing import Optional, Dict, Union, Callable, Any 5 | from threading import Lock 6 | from contextlib import contextmanager 7 | 8 | DeviceType = Union[int, str, torch.device] 9 | 10 | 11 | def build_device_maps(world_size: int, n_proc_per_node: int, rank: Optional[int] = None) -> Dict[str, Dict[DeviceType, DeviceType]]: 12 | is_master = rank is None 13 | device_maps: Dict[str, Dict[DeviceType, DeviceType]] = {} 14 | if is_master: 15 | for i in range(world_size): 16 | worker_local_rank = i % n_proc_per_node 17 | device_maps[f'worker{i}'] = {'cpu': worker_local_rank} 18 | else: 19 | local_rank = rank % n_proc_per_node 20 | for i in range(world_size): 21 | if i != rank: 22 | worker_local_rank = i % n_proc_per_node 23 | device_maps[f'worker{i}'] = {local_rank: worker_local_rank} 24 | device_maps['master'] = {local_rank: 'cpu'} 25 | return device_maps 26 | 27 | 28 | @contextmanager 29 | def use_lock(lock: Lock): 30 | try: 31 | lock.acquire() 32 | yield 33 | finally: 34 | lock.release() 35 | 36 | 37 | def run_once(func: Callable[[Any], Any]) -> Callable[[Any], Any]: 38 | called: bool = False 39 | 40 | @functools.wraps(func) 41 | def wrapper(*args, **kwargs): 42 | nonlocal called 43 | if not called: 44 | func(*args, **kwargs) 45 | called = True 46 | return wrapper 47 | 48 | 49 | class Terminator: 50 | lock = Lock() 51 | called: 
bool = False 52 | 53 | @classmethod 54 | def shield(cls): 55 | with use_lock(cls.lock): 56 | cls.called = True 57 | 58 | @classmethod 59 | def terminate(cls): 60 | with use_lock(cls.lock): 61 | if not cls.called: 62 | cls.called = True 63 | signal.raise_signal(signal.SIGINT) 64 | -------------------------------------------------------------------------------- /energonai/utils/files.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | 4 | def ensure_directory_exists(path: str): 5 | # ensure the directory exists 6 | # dir = os.path.dirname(path) 7 | if not os.path.exists(path): 8 | os.makedirs(path) -------------------------------------------------------------------------------- /energonai/utils/timer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import time 4 | 5 | import torch 6 | 7 | _GLOBAL_TIMERS = None 8 | 9 | 10 | class _Timer: 11 | """Timer.""" 12 | 13 | def __init__(self, name, ignore_first): 14 | self.name_ = name 15 | self.elapsed_ = 0.0 16 | self.started_ = False 17 | self.start_time = time.time() 18 | if ignore_first: 19 | self.times = 2 20 | else: 21 | self.times = 0 22 | 23 | def start(self): 24 | """Start the timer.""" 25 | if(self.times != 0): 26 | self.times = self.times - 1 27 | else: 28 | assert not self.started_, 'timer has already been started' 29 | torch.cuda.synchronize() 30 | self.start_time = time.time() 31 | self.started_ = True 32 | 33 | def stop(self): 34 | """Stop the timer.""" 35 | if(self.times != 0): 36 | self.times = self.times - 1 37 | else: 38 | assert self.started_, 'timer is not started' 39 | torch.cuda.synchronize() 40 | self.elapsed_ += (time.time() - self.start_time) 41 | self.started_ = False 42 | 43 | def reset(self): 44 | """Reset timer.""" 45 | self.elapsed_ = 0.0 46 | self.started_ = False 47 | 48 | def elapsed(self, reset=True): 49 | """Calculate the elapsed time.""" 50 | started_ = self.started_ 51 | # If the timing in progress, end it first. 52 | if self.started_: 53 | self.stop() 54 | # Get the elapsed time. 55 | elapsed_ = self.elapsed_ 56 | # Reset the elapsed time 57 | if reset: 58 | self.reset() 59 | # If timing was in progress, set it back. 
60 | if started_: 61 | self.start() 62 | return elapsed_ 63 | 64 | 65 | class Timers: 66 | """Group of timers.""" 67 | 68 | def __init__(self, ignore_first): 69 | self.timers = {} 70 | self.ignore_first = ignore_first 71 | 72 | def __call__(self, name): 73 | if name not in self.timers: 74 | self.timers[name] = _Timer(name, self.ignore_first) 75 | return self.timers[name] 76 | 77 | def write(self, names, writer, iteration, normalizer=1.0, reset=False): 78 | """Write timers to a tensorboard writer""" 79 | # currently when using add_scalars, 80 | # torch.utils.add_scalars makes each timer its own run, which 81 | # polutes the runs list, so we just add each as a scalar 82 | assert normalizer > 0.0 83 | for name in names: 84 | value = self.timers[name].elapsed(reset=reset) / normalizer 85 | writer.add_scalar(name + '-time', value, iteration) 86 | 87 | def log(self, names, normalizer=1.0, reset=True): 88 | """Log a group of timers.""" 89 | assert normalizer > 0.0 90 | string0 = '' 91 | string1 = '' 92 | for name in names: 93 | elapsed_time = self.timers[name].elapsed(reset=reset) * 1000.0 / normalizer 94 | string0 += ' : {}'.format(name) 95 | string1 += ' : {:.2f}'.format(elapsed_time) 96 | 97 | if torch.distributed.is_initialized(): 98 | if torch.distributed.get_rank() == (torch.distributed.get_world_size() - 1): 99 | print(f'{string0} \n {string1}', flush=True) 100 | else: 101 | print(f'{string0} \n {string1}', flush=True) 102 | 103 | 104 | def _ensure_var_is_not_initialized(var, name): 105 | """Make sure the input variable is not None.""" 106 | assert var is None, '{} is already initialized.'.format(name) 107 | 108 | 109 | def _set_timers(ignore_first): 110 | """Initialize timers.""" 111 | global _GLOBAL_TIMERS 112 | _ensure_var_is_not_initialized(_GLOBAL_TIMERS, 'timers') 113 | _GLOBAL_TIMERS = Timers(ignore_first) 114 | 115 | 116 | def _ensure_var_is_initialized(var, name): 117 | """Make sure the input variable is not None.""" 118 | assert var is not None, '{} is not initialized.'.format(name) 119 | 120 | 121 | def get_timers(ignore_first = False): 122 | """Return timers.""" 123 | if _GLOBAL_TIMERS is None: 124 | _set_timers(ignore_first) 125 | _ensure_var_is_initialized(_GLOBAL_TIMERS, 'timers') 126 | return _GLOBAL_TIMERS 127 | -------------------------------------------------------------------------------- /energonai/worker.py: -------------------------------------------------------------------------------- 1 | import time 2 | from contextlib import contextmanager 3 | from typing import Any, Callable 4 | 5 | import colossalai 6 | import torch 7 | import torch.distributed.rpc as trpc 8 | import torch.multiprocessing as mp 9 | import torch.nn as nn 10 | from colossalai.context import ParallelMode 11 | from colossalai.core import global_context as gpc 12 | from colossalai.logging import disable_existing_loggers, get_dist_logger 13 | 14 | from .pipe import Pipe 15 | from .task import TaskEntry 16 | from .utils import Terminator, build_device_maps 17 | 18 | 19 | class Worker: 20 | def __init__(self, rank: int, tp_world_size: int, pp_world_size: int, master_host: str, master_port: int, rpc_port: int, n_proc_per_node: int, 21 | model_fn: Callable[[Any], nn.Module], pipe_size: int = 1, rpc_disable_shm: bool = True, **model_kwargs: Any) -> None: 22 | self.global_rank = rank 23 | self.world_size = tp_world_size * pp_world_size 24 | self.tp_world_size = tp_world_size 25 | self.pp_world_size = pp_world_size 26 | disable_existing_loggers(exclude=['energonai', 'colossalai']) 27 | 
colossalai.launch({'parallel': {'tensor': {'mode': '1d', 'size': tp_world_size}, 28 | 'pipeline': pp_world_size}}, rank, self.world_size, master_host, master_port) 29 | self.tp_rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D) 30 | self.pp_rank = gpc.get_local_rank(ParallelMode.PIPELINE) if gpc.is_initialized(ParallelMode.PIPELINE) else 0 31 | 32 | self.model: nn.Module = model_fn(**model_kwargs).cuda() 33 | 34 | self.rpc_name = f'worker{self.global_rank}' 35 | rpc_options = {} 36 | if rpc_disable_shm: 37 | # SHM may lead to timeout error. Disabling SHM and only enabling uv transport can solve this problem. 38 | # See https://discuss.pytorch.org/t/rpc-behavior-difference-between-pytorch-1-7-0-vs-1-9-0/124772/5 39 | # This is a workaround and may be solved in the future. 40 | rpc_options['_transports'] = ['uv'] 41 | trpc.init_rpc(self.rpc_name, rank=self.global_rank + 1, world_size=self.world_size + 1, 42 | rpc_backend_options=trpc.TensorPipeRpcBackendOptions( 43 | init_method=f'tcp://{master_host}:{rpc_port}', 44 | device_maps=build_device_maps(self.world_size, n_proc_per_node, rank=self.global_rank), 45 | **rpc_options 46 | )) 47 | self.to_master_pipe = Pipe(f'{self.global_rank}_to_m', self.rpc_name, 'master') 48 | self.to_master_pipe.send(self.pp_rank) 49 | 50 | if self.pp_rank == 0: 51 | self.input_pipe = Pipe(f'm_to_{self.global_rank}', 'master', self.rpc_name, max_size=pipe_size) 52 | else: 53 | pp_prev_rank = gpc.get_prev_global_rank(ParallelMode.PIPELINE) 54 | self.input_pipe = Pipe(f'{pp_prev_rank}_to_{self.global_rank}', 55 | f'worker{pp_prev_rank}', self.rpc_name, max_size=pipe_size) 56 | if self.pp_rank == self.pp_world_size - 1: 57 | self.output_pipe = self.to_master_pipe 58 | else: 59 | pp_next_rank = gpc.get_next_global_rank(ParallelMode.PIPELINE) 60 | self.output_pipe = Pipe(f'{self.global_rank}_to_{pp_next_rank}', self.rpc_name, 61 | f'worker{pp_next_rank}', max_size=pipe_size) 62 | 63 | self.logger = get_dist_logger('energonai') 64 | self.logger.info(f'{self.rpc_name} start') 65 | self._start() 66 | 67 | @contextmanager 68 | def _lifespan(self): 69 | try: 70 | yield 71 | finally: 72 | self._shutdown() 73 | 74 | def _start(self) -> None: 75 | with self._lifespan(): 76 | while True: 77 | try: 78 | task_entry: TaskEntry = self.input_pipe.recv_nowait() 79 | with torch.inference_mode(): 80 | outputs = self._forward(task_entry.batch) 81 | self.output_pipe.send(TaskEntry(task_entry.uids, outputs)) 82 | except RuntimeError: 83 | time.sleep(0.01) 84 | 85 | def _shutdown(self) -> None: 86 | Terminator.shield() 87 | trpc.rpc_sync('master', Terminator.terminate) 88 | trpc.shutdown() 89 | 90 | def _forward(self, inputs: Any) -> Any: 91 | if isinstance(inputs, (tuple, list)): 92 | outputs = self.model(*inputs) 93 | elif isinstance(inputs, dict): 94 | outputs = self.model(**inputs) 95 | else: 96 | outputs = self.model(inputs) 97 | return outputs 98 | 99 | 100 | def launch_workers(tp_world_size: int, pp_world_size: int, master_host: str, master_port: int, rpc_port: int, 101 | model_fn: Callable[[Any], nn.Module], n_proc_per_node: int = 1, node_rank: int = 0, pipe_size: int = 1, rpc_disable_shm: bool = True, 102 | **model_kwargs: Any) -> None: 103 | ctx = mp.get_context('spawn') 104 | procs = [] 105 | for i in range(n_proc_per_node): 106 | rank = n_proc_per_node * node_rank + i 107 | p = ctx.Process(target=Worker, args=(rank, tp_world_size, pp_world_size, 108 | master_host, master_port, rpc_port, n_proc_per_node, model_fn, pipe_size, rpc_disable_shm), kwargs=model_kwargs) 109 | 
procs.append(p) 110 | p.start() 111 | -------------------------------------------------------------------------------- /examples/auto_pipeline/bert_config.py: -------------------------------------------------------------------------------- 1 | from bert import bert_small, bert_large, bert_xl, bert_8B, bert_175B 2 | from bert_server import launch_engine 3 | from bert import BertEmbedding1D, BertTransformerLayer1D 4 | 5 | model_class = bert_8B 6 | model_type = "bert" 7 | engine_server = launch_engine 8 | 9 | 10 | # parallel 11 | tp_init_size = 2 12 | pp_init_size = 2 13 | auto_pp = True 14 | LeafSet = set([BertTransformerLayer1D, BertEmbedding1D]) 15 | 16 | 17 | 18 | host = "127.0.0.1" 19 | port = 29400 20 | half = False 21 | server_host = "127.0.0.1" 22 | server_port = 8010 23 | log_level = "info" 24 | backend = "nccl" -------------------------------------------------------------------------------- /examples/auto_pipeline/bert_server.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import uvicorn 4 | from fastapi import FastAPI 5 | from energonai.engine import InferenceEngine 6 | from energonai.context import MEATCONFIG 7 | 8 | app = FastAPI() # 创建 api 对象 9 | 10 | @app.get("/") # 根路由 11 | def root(): 12 | return {"200"} 13 | 14 | @app.get("/model_with_padding") 15 | def run(): 16 | # for the performance only 17 | seq_len = 512 18 | batch_size = 32 19 | 20 | input_ids = torch.randint(1, 10, (batch_size, seq_len), dtype=torch.int64) 21 | attention_mask = torch.randint(0, 1, (batch_size, 1, seq_len), dtype=torch.int64) 22 | # seq_lens = torch.randint(1, 128, (batch_size, ), dtype=torch.int64) # generate seq_lens randomly 23 | hidden_states = None 24 | sample = dict(hidden_states=hidden_states, input_ids=input_ids, attention_mask=attention_mask) 25 | 26 | output = engine.run(sample) 27 | output = output.to_here() 28 | print(output) 29 | return {"To return the string result."} 30 | 31 | # @app.get("/model_rm_padding") 32 | # def run(): 33 | # # for the performance only 34 | # seq_len = 512 35 | # batch_size = 32 36 | 37 | # input_ids = torch.randint(1, 10, (batch_size, seq_len), dtype=torch.int64) 38 | # attention_mask = torch.randint(0, 1, (batch_size, 1, seq_len), dtype=torch.int64) 39 | # seq_lens = torch.randint(1, 128, (batch_size, ), dtype=torch.int) # generate seq_lens randomly 40 | # hidden_states = None 41 | # sample = dict(hidden_states=hidden_states, input_ids=input_ids, attention_mask=attention_mask, seq_lens=seq_lens) 42 | 43 | # output = engine.run(sample) 44 | # output = output.to_here() 45 | # print(output) 46 | # return {"To return the string result."} 47 | 48 | 49 | @app.get("/shutdown") 50 | async def shutdown(): 51 | engine.clear() 52 | server.should_exit = True 53 | server.force_exit = True 54 | await server.shutdown() 55 | 56 | 57 | def launch_engine(model_class, 58 | model_type, 59 | max_batch_size: int = 1, 60 | tp_init_size: int = -1, 61 | pp_init_size: int = -1, 62 | host: str = "localhost", 63 | port: int = 29500, 64 | dtype = torch.float, 65 | checkpoint: str = None, 66 | tokenizer_path: str = None, 67 | server_host = "localhost", 68 | server_port = 8005, 69 | log_level = "info" 70 | ): 71 | 72 | if checkpoint: 73 | model_config = {'dtype': dtype, 'checkpoint': True, 'checkpoint_path': checkpoint} 74 | else: 75 | model_config = {'dtype': dtype} 76 | 77 | global engine 78 | engine = InferenceEngine(model_class, 79 | model_config, 80 | model_type, 81 | max_batch_size = max_batch_size, 82 | 
tp_init_size = tp_init_size, 83 | pp_init_size = pp_init_size, 84 | auto_pp = MEATCONFIG['auto_pp'], 85 | host = host, 86 | port = port, 87 | dtype = dtype) 88 | 89 | global server 90 | config = uvicorn.Config(app, host=server_host, port=server_port, log_level=log_level) 91 | server = uvicorn.Server(config=config) 92 | server.run() -------------------------------------------------------------------------------- /examples/auto_pipeline/requirements.txt: -------------------------------------------------------------------------------- 1 | colossalai 2 | torch >= 1.8.1 3 | -------------------------------------------------------------------------------- /examples/bert/bert_config.py: -------------------------------------------------------------------------------- 1 | from bert import bert_small, bert_large, bert_xl, bert_8B, bert_175B 2 | from bert_server import launch_engine 3 | 4 | model_class = bert_8B 5 | model_type = "bert" 6 | engine_server = launch_engine 7 | tp_init_size = 2 8 | pp_init_size = 2 9 | host = "127.0.0.1" 10 | port = 29400 11 | half = False 12 | server_host = "127.0.0.1" 13 | server_port = 8010 14 | log_level = "info" 15 | backend = "nccl" -------------------------------------------------------------------------------- /examples/bert/bert_server.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import uvicorn 4 | from fastapi import FastAPI 5 | from fastapi import Response 6 | import torch.distributed.rpc as rpc 7 | from energonai.engine import InferenceEngine 8 | 9 | app = FastAPI() # 创建 api 对象 10 | 11 | @app.get("/") # 根路由 12 | def root(): 13 | return {"200"} 14 | 15 | @app.get("/model_with_padding") 16 | def run(): 17 | # for the performance only 18 | seq_len = 512 19 | batch_size = 32 20 | 21 | input_ids = torch.randint(1, 10, (batch_size, seq_len), dtype=torch.int64) 22 | attention_mask = torch.randint(0, 1, (batch_size, 1, seq_len), dtype=torch.int64) 23 | # seq_lens = torch.randint(1, 128, (batch_size, ), dtype=torch.int64) # generate seq_lens randomly 24 | hidden_states = None 25 | sample = dict(hidden_states=hidden_states, input_ids=input_ids, attention_mask=attention_mask) 26 | 27 | output = engine.run(sample) 28 | output = output.to_here() 29 | print(output) 30 | return {"To return the string result."} 31 | 32 | @app.get("/model_rm_padding") 33 | def run(): 34 | # for the performance only 35 | seq_len = 512 36 | batch_size = 32 37 | 38 | input_ids = torch.randint(1, 10, (batch_size, seq_len), dtype=torch.int64) 39 | attention_mask = torch.randint(0, 1, (batch_size, 1, seq_len), dtype=torch.int64) 40 | seq_lens = torch.randint(1, 128, (batch_size, ), dtype=torch.int) # generate seq_lens randomly 41 | hidden_states = None 42 | sample = dict(hidden_states=hidden_states, input_ids=input_ids, attention_mask=attention_mask, seq_lens=seq_lens) 43 | 44 | output = engine.run(sample) 45 | output = output.to_here() 46 | print(output) 47 | return {"To return the string result."} 48 | 49 | 50 | @app.get("/shutdown") 51 | async def shutdown(): 52 | engine.clear() 53 | server.should_exit = True 54 | server.force_exit = True 55 | await server.shutdown() 56 | 57 | 58 | def launch_engine(model_class, 59 | model_type, 60 | max_batch_size: int = 1, 61 | tp_init_size: int = -1, 62 | pp_init_size: int = -1, 63 | host: str = "localhost", 64 | port: int = 29500, 65 | dtype = torch.float, 66 | checkpoint: str = None, 67 | tokenizer_path: str = None, 68 | server_host = "localhost", 69 | server_port = 8005, 70 | 
log_level = "info" 71 | ): 72 | 73 | if checkpoint: 74 | model_config = {'dtype': dtype, 'checkpoint': True, 'checkpoint_path': checkpoint} 75 | else: 76 | model_config = {'dtype': dtype} 77 | 78 | global engine 79 | engine = InferenceEngine(model_class, 80 | model_config, 81 | model_type, 82 | max_batch_size = max_batch_size, 83 | tp_init_size = tp_init_size, 84 | pp_init_size = pp_init_size, 85 | host = host, 86 | port = port, 87 | dtype = dtype) 88 | 89 | global server 90 | config = uvicorn.Config(app, host=server_host, port=server_port, log_level=log_level) 91 | server = uvicorn.Server(config=config) 92 | server.run() -------------------------------------------------------------------------------- /examples/bert/requirements.txt: -------------------------------------------------------------------------------- 1 | colossalai 2 | torch >= 1.8.1 3 | -------------------------------------------------------------------------------- /examples/bloom/README.md: -------------------------------------------------------------------------------- 1 | # Energon-AI for Bloom inference 2 | # How to run 3 | To start service 4 | ``` 5 | bash run.sh 6 | ``` 7 | 8 | You can change the args in `run.sh` as follow: 9 | ``` 10 | param list 11 | --name :Name Path (required) 12 | --tp: (int) GPU_NUM, default=1 13 | --http_host: (x.x.x.x) your IP address, default=0.0.0.0 14 | --http_port: (xxxx) your port, default=7070 15 | --dtype:(str) use int8-quant or not ["fp16", "int8"], default="fp16" 16 | --max_batchsize:(int) limitation of batchsize, default=1 17 | --random_init:(bool) random init or not(if you don't have whole model data), default=False 18 | --random_model_size:(str) size of random init model,["560m", "7b1", "175b"],default="560m" 19 | 20 | Once use [--random_init True], the [--random_model_size] option will be used. 21 | Name is also required while using[--random_init True], for getting tokenizer. 22 | ``` 23 | 24 | While the service is running, send `POST` requests to https://[ip]:[port]/generation 25 | 26 | send POST body as json file: 27 | ``` 28 | curl -X POST http://ip:port/generation -H "Content-Type: application/json" -d @test.json 29 | ``` 30 | 31 | test.json looks like follows: 32 | ``` 33 | { 34 | "prompt": "However, there are still immense benefits to learning quantum computing.", 35 | "top_p": 0.90, 36 | "top_k": 40, 37 | "max_new_tokens": 60 38 | } 39 | ``` 40 | 41 | received message: 42 | ``` 43 | { 44 | "text": "However, there are still immense benefits to learning quantum computing. For example, quantum computing can be used to solve problems that are difficult to solve by classical methods. Quantum computing can also be used to solve problems that are difficult to solve by classical methods. Quantum computing can also be used to solve problems that are" 45 | } 46 | ``` 47 | 48 | # Configure 49 | ## Configure batching 50 | add `--max_batch_size ` to the python command in `run.sh` 51 | 52 | The `` can be an integer in `[1, MAXINT]`. The engine will make batch whose size is less or equal to this value. 53 | 54 | Bigger MaxBatchSize may speed up concurrent requests in a single batched forwarding process. 55 | 56 | # Testing Result 57 | ## Memory 58 | Int8_model_size = 1/2 FP16_model_size 59 | To inference the 176B Bloom model with 8 GPUs, we reduce the MAX_GPU_MEM_ALLOCATED from `85GB(FP32)`or `42.88GB(FP16)` to `21.68GB` per GPU ! 60 | Also we use no more CPU_mem than fp16 model. 
61 | image 62 | ## Time 63 | ### Inference 64 | image 65 | 66 | ### Generate 67 | image 68 | 69 | 70 | 71 | 72 | -------------------------------------------------------------------------------- /examples/bloom/batch.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from typing import List, Deque, Tuple, Hashable, Any 3 | from energonai import BatchManager, SubmitEntry, TaskEntry 4 | 5 | 6 | class BatchManagerForGeneration(BatchManager): 7 | def __init__(self, max_batch_size: int = 1, pad_token_id: int = 0) -> None: 8 | super().__init__() 9 | self.max_batch_size = max_batch_size 10 | self.pad_token_id = pad_token_id 11 | 12 | def _left_padding(self, batch_inputs): 13 | max_len = max(len(inputs['input_ids']) for inputs in batch_inputs) 14 | outputs = {'input_ids': [], 'attention_mask': []} 15 | for inputs in batch_inputs: 16 | input_ids, attention_mask = inputs['input_ids'], inputs['attention_mask'] 17 | padding_len = max_len - len(input_ids) 18 | padding = torch.tensor([self.pad_token_id] * padding_len, device=input_ids.device, dtype=torch.int) 19 | input_ids = torch.cat((padding, input_ids), 0) 20 | 21 | padding = torch.tensor([0] * padding_len, device=attention_mask.device, dtype=torch.int) 22 | attention_mask = torch.cat((padding, attention_mask), 0) 23 | outputs['input_ids'].append(input_ids) 24 | outputs['attention_mask'].append(attention_mask) 25 | return outputs, max_len 26 | 27 | @staticmethod 28 | def _make_batch_key(entry: SubmitEntry) -> tuple: 29 | data = entry.data 30 | return () 31 | 32 | def make_batch(self, q: Deque[SubmitEntry]) -> Tuple[TaskEntry, dict]: 33 | entry = q.popleft() 34 | uids = [entry.uid] 35 | batch = [entry.data] 36 | while len(batch) < self.max_batch_size: 37 | if len(q) == 0: 38 | break 39 | if self._make_batch_key(entry) != self._make_batch_key(q[0]): 40 | break 41 | if q[0].data['max_new_tokens'] > entry.data['max_new_tokens']: 42 | break 43 | e = q.popleft() 44 | batch.append(e.data) 45 | uids.append(e.uid) 46 | inputs, max_len = self._left_padding(batch) 47 | trunc_lens = [] 48 | for data in batch: 49 | trunc_lens.append(max_len + data['max_new_tokens']) 50 | inputs['max_new_tokens'] = max_len + entry.data['max_new_tokens'] 51 | return TaskEntry(tuple(uids), inputs), {'trunc_lens': trunc_lens} 52 | 53 | def split_batch(self, task_entry: TaskEntry, trunc_lens: List[int] = []) -> List[Tuple[Hashable, Any]]: 54 | retval = [] 55 | for uid, output, trunc_len in zip(task_entry.uids, task_entry.batch, trunc_lens): 56 | retval.append((uid, (output[:trunc_len]).reshape(1, -1))) 57 | return retval 58 | -------------------------------------------------------------------------------- /examples/bloom/benchmark/locustfile.py: -------------------------------------------------------------------------------- 1 | from locust import HttpUser, task 2 | from json import JSONDecodeError 3 | 4 | 5 | class GenerationUser(HttpUser): 6 | @task 7 | def generate(self): 8 | prompt = 'Question: What is the longest river on the earth? 
Answer:' 9 | for i in range(4, 9): 10 | data = {'max_new_tokens': 2**i, 'prompt': prompt} 11 | with self.client.post('/generation', json=data, catch_response=True) as response: 12 | if response.status_code in (200, 406): 13 | response.success() 14 | else: 15 | response.failure('Response wrong') 16 | -------------------------------------------------------------------------------- /examples/bloom/cache.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | from threading import Lock 3 | from contextlib import contextmanager 4 | from typing import List, Any, Hashable, Dict 5 | 6 | 7 | class MissCacheError(Exception): 8 | pass 9 | 10 | 11 | class ListCache: 12 | def __init__(self, cache_size: int, list_size: int, fixed_keys: List[Hashable] = []) -> None: 13 | """Cache a list of values. The fixed keys won't be removed. For other keys, LRU is applied. 14 | When the value list is not full, a cache miss occurs. Otherwise, a cache hit occurs. Redundant values will be removed. 15 | 16 | Args: 17 | cache_size (int): Max size for LRU cache. 18 | list_size (int): Value list size. 19 | fixed_keys (List[Hashable], optional): The keys which won't be removed. Defaults to []. 20 | """ 21 | self.cache_size = cache_size 22 | self.list_size = list_size 23 | self.cache: OrderedDict[Hashable, List[Any]] = OrderedDict() 24 | self.fixed_cache: Dict[Hashable, List[Any]] = {} 25 | for key in fixed_keys: 26 | self.fixed_cache[key] = [] 27 | self._lock = Lock() 28 | 29 | def get(self, key: Hashable) -> List[Any]: 30 | with self.lock(): 31 | if key in self.fixed_cache: 32 | l = self.fixed_cache[key] 33 | if len(l) >= self.list_size: 34 | return l 35 | elif key in self.cache: 36 | self.cache.move_to_end(key) 37 | l = self.cache[key] 38 | if len(l) >= self.list_size: 39 | return l 40 | raise MissCacheError() 41 | 42 | def add(self, key: Hashable, value: Any) -> None: 43 | with self.lock(): 44 | if key in self.fixed_cache: 45 | l = self.fixed_cache[key] 46 | if len(l) < self.list_size and value not in l: 47 | l.append(value) 48 | elif key in self.cache: 49 | self.cache.move_to_end(key) 50 | l = self.cache[key] 51 | if len(l) < self.list_size and value not in l: 52 | l.append(value) 53 | else: 54 | if len(self.cache) >= self.cache_size: 55 | self.cache.popitem(last=False) 56 | self.cache[key] = [value] 57 | 58 | @contextmanager 59 | def lock(self): 60 | try: 61 | self._lock.acquire() 62 | yield 63 | finally: 64 | self._lock.release() 65 | -------------------------------------------------------------------------------- /examples/bloom/requirements.txt: -------------------------------------------------------------------------------- 1 | colossalai 2 | torch >= 1.8.1 3 | -------------------------------------------------------------------------------- /examples/bloom/run.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES_set_n_least_memory_usage() { 2 | local n=${1:-"9999"} 3 | echo "GPU Memory Usage:" 4 | local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv \ 5 | | tail -n +2 \ 6 | | nl -v 0 \ 7 | | tee /dev/tty \ 8 | | sort -g -k 2 \ 9 | | awk '{print $1}' \ 10 | | head -n $n) 11 | export CUDA_VISIBLE_DEVICES=$(echo $FIRST_N_GPU_IDS | sed 's/ /,/g') 12 | echo "Now CUDA_VISIBLE_DEVICES is set to:" 13 | echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES" 14 | } 15 | 16 | 17 | export GPU_NUM=2 18 | 19 | export DATASET=/data2/users/lczht/bloom-560m 20 | 
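# The call below applies the helper defined above: it pins CUDA_VISIBLE_DEVICES to the
# ${GPU_NUM} GPUs that currently have the least memory in use. DATASET is forwarded to
# server.py as --name and is used to load the tokenizer (and the full weights unless
# --random_init is set).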
CUDA_VISIBLE_DEVICES_set_n_least_memory_usage ${GPU_NUM} 21 | 22 | # param list 23 | # --name :Name Path 24 | # --tp: (int) GPU_NUM, default=1 25 | # --http_host: (x.x.x.x) your IP address, default=0.0.0.0 26 | # --http_port: (xxxx) your port, default=7070 27 | # --dtype:(str) use int8-quant or not ["fp16", "int8"], default="fp16" 28 | # --max_batchsize:(int) limitation of batchsize, default=1 29 | # --random_init:(bool) random init or not(if you don't have whole model data), default=False 30 | # --random_model_size:(str) size of random init model,["560m", "7b1", "175b"],default="560m" 31 | 32 | 33 | python server.py --tp ${GPU_NUM} --name ${DATASET} --dtype "int8" --max_batch_size 4 --random_model_size "7b1" --random_init True 34 | 35 | 36 | -------------------------------------------------------------------------------- /examples/bloom/server.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | import json 4 | import random 5 | from typing import Optional 6 | import torch 7 | import uvicorn 8 | import colossalai 9 | from colossalai.utils.model.colo_init_context import ColoInitContext 10 | from colossalai.tensor import ShardSpec, ComputeSpec, ComputePattern, ColoParameter, ProcessGroup, ReplicaSpec 11 | 12 | from energonai import QueueFullError, launch_engine 13 | from fastapi import FastAPI, HTTPException, Request 14 | from pydantic import BaseModel, Field 15 | 16 | from batch import BatchManagerForGeneration 17 | from cache import ListCache, MissCacheError 18 | from transformers import AutoTokenizer, BloomForCausalLM 19 | from transformers import BloomConfig 20 | 21 | TP_TARGET = ['mlp', 'self_attention.dense', 'self_attention.query_key_value', 'word_embeddings.weight'] # 'self_attention.attention_dropout', 22 | 23 | class GenerationTaskReq(BaseModel): 24 | max_new_tokens: int = Field(gt=0, le=256, example=64) 25 | prompt: str = Field( 26 | min_length=1, example='Question: Where were the 2004 Olympics held?\nAnswer: Athens, Greece\n\nQuestion: What is the longest river on the earth?\nAnswer:') 27 | # top_k: Optional[int] = Field(default=None, gt=0, example=50) 28 | # top_p: Optional[float] = Field(default=None, gt=0.0, lt=1.0, example=0.5) 29 | greedy: Optional[bool] = False 30 | 31 | 32 | app = FastAPI() 33 | 34 | 35 | @app.post('/generation') 36 | async def generate(data: GenerationTaskReq, request: Request): 37 | logger.info( 38 | f'{request.client.host}:{request.client.port} - "{request.method} {request.url.path}" - {data}') 39 | key = (data.prompt, data.max_new_tokens) 40 | try: 41 | if cache is None: 42 | raise MissCacheError() 43 | outputs = cache.get(key) 44 | output_str = random.choice(outputs) 45 | logger.info('Cache hit') 46 | except MissCacheError: 47 | input_tokens = tokenizer.encode_plus(data.prompt, return_tensors="pt", padding=True) 48 | input_tokens['max_new_tokens'] = data.max_new_tokens 49 | try: 50 | uid = id(data) 51 | engine.submit(uid, input_tokens) 52 | outputs = await engine.wait(uid) 53 | outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True) 54 | if cache is not None: 55 | cache.add(key, outputs) 56 | output_str = outputs 57 | except QueueFullError as e: 58 | raise HTTPException(status_code=406, detail=e.args[0]) 59 | return {'text': output_str} 60 | 61 | 62 | @app.on_event("shutdown") 63 | async def shutdown(*_): 64 | engine.shutdown() 65 | server.should_exit = True 66 | server.force_exit = True 67 | await server.shutdown() 68 | 69 | 70 | def print_args(args: 
argparse.Namespace): 71 | print('\n==> Args:') 72 | for k, v in args.__dict__.items(): 73 | print(f'{k} = {v}') 74 | 75 | class WrapCallModule(torch.nn.Module): 76 | def __init__(self, model: torch.nn.Module): 77 | super(WrapCallModule, self).__init__() 78 | self.model = model 79 | 80 | def forward(self, **generate_kwargs): 81 | input_ids_batch = generate_kwargs["input_ids"] 82 | attention_mask_batch = generate_kwargs["attention_mask"] 83 | generate_kwargs["input_ids"] = torch.cat(input_ids_batch, 0) 84 | generate_kwargs["attention_mask"] = torch.cat(attention_mask_batch, 0) 85 | return self.model.generate(**generate_kwargs) 86 | 87 | def model_fn(**model_kwargs): 88 | from utils import run 89 | if model_kwargs['tp']!=1: 90 | tp = True 91 | else: 92 | tp = False 93 | if model_kwargs['dtype']=="int8": 94 | use_int8 = True 95 | else: 96 | use_int8 = False 97 | if model_kwargs['random_init']==False: 98 | from_pretrain = True 99 | else: 100 | from_pretrain = False 101 | data_path = model_kwargs['name'] 102 | size = model_kwargs['size'] 103 | model = run(tp=tp, from_pretrain=from_pretrain, data_path=data_path, use_int8=use_int8, size=size) 104 | return WrapCallModule(model) 105 | 106 | 107 | FIXED_CACHE_KEYS = [ 108 | ('Question: What is the name of the largest continent on earth?\nAnswer: Asia\n\nQuestion: What is at the center of the solar system?\nAnswer:', 64), 109 | ('A chat between a salesman and a student.\n\nSalesman: Hi boy, are you looking for a new phone?\nStudent: Yes, my phone is not functioning well.\nSalesman: What is your budget? \nStudent: I have received my scholarship so I am fine with any phone.\nSalesman: Great, then perhaps this latest flagship phone is just right for you.', 64), 110 | ("English: I am happy today.\nChinese: 我今天很开心。\n\nEnglish: I am going to play basketball.\nChinese: 我一会去打篮球。\n\nEnglish: Let's celebrate our anniversary.\nChinese:", 64) 111 | ] 112 | 113 | if __name__ == '__main__': 114 | parser = argparse.ArgumentParser() 115 | parser.add_argument('--name', type=str, help="Name path", required=True) 116 | parser.add_argument('--tp', type=int, default=1) 117 | parser.add_argument('--master_host', default='localhost') 118 | parser.add_argument('--master_port', type=int, default=19991) 119 | parser.add_argument('--rpc_port', type=int, default=19981) 120 | parser.add_argument('--max_batch_size', type=int, default=1) 121 | parser.add_argument('--pipe_size', type=int, default=1) 122 | parser.add_argument('--queue_size', type=int, default=0) 123 | parser.add_argument('--http_host', default='0.0.0.0') 124 | parser.add_argument('--http_port', type=int, default=7070) 125 | parser.add_argument('--cache_size', type=int, default=0) 126 | parser.add_argument('--cache_list_size', type=int, default=1) 127 | parser.add_argument('--dtype', type=str, help="module dtype", default="fp16", choices=["fp16", "int8"]) 128 | parser.add_argument('--random_init', type=bool, help="If have no model params", default=False) 129 | parser.add_argument('--random_model_size', type=str, help="size of random init model", default="560m", choices=["560m", "7b1", "175b"]) 130 | args = parser.parse_args() 131 | print_args(args) 132 | 133 | num_tokens = 100 134 | model_kwargs = dict(max_new_tokens=num_tokens, do_sample=False) 135 | model_name = args.name 136 | model_kwargs['name'] = model_name 137 | model_kwargs['dtype'] = args.dtype 138 | model_kwargs['random_init'] = args.random_init 139 | model_kwargs['tp'] = args.tp 140 | model_kwargs['size'] = args.random_model_size 141 | 142 | logger = 
logging.getLogger(__name__) 143 | 144 | tokenizer = AutoTokenizer.from_pretrained(model_name) 145 | 146 | if args.cache_size > 0: 147 | cache = ListCache(args.cache_size, args.cache_list_size, 148 | fixed_keys=FIXED_CACHE_KEYS) 149 | else: 150 | cache = None 151 | engine = launch_engine(args.tp, 1, args.master_host, args.master_port, args.rpc_port, model_fn, 152 | batch_manager=BatchManagerForGeneration(max_batch_size=args.max_batch_size, 153 | pad_token_id=tokenizer.pad_token_id), 154 | pipe_size=args.pipe_size, 155 | queue_size=args.queue_size, 156 | **model_kwargs) 157 | print("engine start") 158 | config = uvicorn.Config(app, host=args.http_host, port=args.http_port) 159 | server = uvicorn.Server(config=config) 160 | server.run() 161 | -------------------------------------------------------------------------------- /examples/gpt/gpt_batch_server.py: -------------------------------------------------------------------------------- 1 | import time 2 | import torch 3 | import uvicorn 4 | from transformers import GPT2Tokenizer 5 | from fastapi import FastAPI 6 | from fastapi import Response, Body 7 | from energonai.engine import InferenceEngine 8 | from energonai.legacy_batch_mgr.dynamic_batch_manager import Dynamic_Batch_Manager 9 | 10 | app = FastAPI() 11 | 12 | 13 | def forward_func(input_list: list = [], seq_len: int = 0, batch_size: int = 0): 14 | """ 15 | Forward run function needed for batch manager 16 | """ 17 | if len(input_list) == 0: 18 | input_list = [("test " * seq_len)[:-1] for _ in range(batch_size)] 19 | input_ = tokenizer(input_list, return_tensors="pt", padding="longest") 20 | output_ = engine.run(input_) 21 | return output_ 22 | 23 | 24 | def result_process(output_): 25 | """ 26 | Decode the output of the model 27 | """ 28 | result = tokenizer.decode(int(output_)) 29 | return result 30 | 31 | 32 | @app.get("/") # 根路由 33 | def root(): 34 | return {"200"} 35 | 36 | 37 | @app.post("/gpt") 38 | def run_new_batch(input_str: str = Body(..., title="input_str", embed=True)): 39 | global batch_manager 40 | input_token = tokenizer(input_str, return_tensors="pt") 41 | time_stamp = time.time() 42 | batch_manager.insert_req(time_stamp, input_token, input_str) 43 | predictions = batch_manager.subscribe_result(time_stamp) 44 | return {predictions} 45 | 46 | 47 | @app.get("/shutdown") 48 | async def shutdown(): 49 | engine.clear() 50 | server.should_exit = True 51 | server.force_exit = True 52 | await server.shutdown() 53 | 54 | 55 | def launch_engine(model_class, 56 | model_type, 57 | max_batch_size: int = 1, 58 | tp_init_size: int = -1, 59 | pp_init_size: int = -1, 60 | host: str = "localhost", 61 | port: int = 29500, 62 | dtype=torch.float, 63 | checkpoint: str = None, 64 | tokenizer_path: str = None, 65 | server_host="localhost", 66 | server_port=8005, 67 | log_level="info", 68 | rm_padding=False 69 | ): 70 | """Initialize the tokenizer, inference engine, cached cost for current device, 71 | and batch manager. 
Then start the server.""" 72 | if checkpoint: 73 | model_config = {'dtype': dtype, 'checkpoint': True, 'checkpoint_path': checkpoint} 74 | else: 75 | model_config = {'dtype': dtype} 76 | 77 | global tokenizer 78 | tokenizer = GPT2Tokenizer.from_pretrained(tokenizer_path) 79 | tokenizer.pad_token = GPT2Tokenizer.eos_token 80 | 81 | global engine 82 | engine = InferenceEngine(model_class, 83 | model_config, 84 | model_type, 85 | max_batch_size=max_batch_size, 86 | tp_init_size=tp_init_size, 87 | pp_init_size=pp_init_size, 88 | host=host, 89 | port=port, 90 | dtype=dtype) 91 | 92 | global batch_manager 93 | batch_manager = Dynamic_Batch_Manager(forward_func=forward_func, result_process=result_process) 94 | 95 | global server 96 | config = uvicorn.Config(app, host=server_host, port=server_port, log_level=log_level) 97 | server = uvicorn.Server(config=config) 98 | print("running server") 99 | server.run() 100 | print("application started") 101 | -------------------------------------------------------------------------------- /examples/gpt/gpt_config.py: -------------------------------------------------------------------------------- 1 | from gpt import gpt2_small, gpt2_medium, gpt2_large, gpt2_xl, gpt2_8B, gpt3 2 | from gpt_batch_server import launch_engine 3 | 4 | # for engine 5 | model_class = gpt2_8B 6 | model_type = "gpt" 7 | host = "127.0.0.1" 8 | port = 29401 9 | half = True 10 | backend = "nccl" 11 | 12 | # for parallel 13 | tp_init_size = 4 14 | pp_init_size = 2 15 | 16 | # for server 17 | engine_server = launch_engine 18 | server_host = "127.0.0.1" 19 | server_port = 8016 20 | log_level = "info" 21 | tokenizer_path = "/workspace/hf_gpt2" 22 | rm_padding = False 23 | 24 | #for batch manager 25 | max_batch_size = 15 26 | max_sequence_length = 1024 27 | repeat_round = 2 28 | step = 8 29 | max_wait_time = 2 -------------------------------------------------------------------------------- /examples/gpt/requirements.txt: -------------------------------------------------------------------------------- 1 | colossalai 2 | torch >= 1.8.1 3 | -------------------------------------------------------------------------------- /examples/hf_gpt2/hf_gpt2_config.py: -------------------------------------------------------------------------------- 1 | from hf_gpt2 import hf_gpt2 2 | from hf_gpt2_server import launch_engine 3 | 4 | # for engine 5 | model_class = hf_gpt2 6 | model_type = "gpt" 7 | host = "127.0.0.1" 8 | port = 29401 9 | half = True 10 | checkpoint = "/workspace/hf_gpt2/GPT2.bin" 11 | backend = "nccl" 12 | 13 | # for parallel 14 | tp_init_size = 2 15 | pp_init_size = 2 16 | 17 | # for server 18 | engine_server = launch_engine 19 | tokenizer_path = "/workspace/hf_gpt2" 20 | server_host = "127.0.0.1" 21 | server_port = 8020 22 | log_level = "info" 23 | -------------------------------------------------------------------------------- /examples/hf_gpt2/hf_gpt2_server.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import uvicorn 4 | from fastapi import FastAPI 5 | from fastapi import Response 6 | import torch.distributed.rpc as rpc 7 | from energonai.engine import InferenceEngine 8 | 9 | from transformers import GPT2Tokenizer 10 | 11 | app = FastAPI() # 创建 api 对象 12 | 13 | @app.get("/") # 根路由 14 | def root(): 15 | return {"200"} 16 | 17 | @app.get("/run/{request}") 18 | def run(request: str, max_seq_length: int): 19 | 20 | input_token = tokenizer(request, return_tensors="pt") 21 | total_predicted_text = request 22 | 23 | for i 
in range(1, max_seq_length): 24 | output = engine.run(input_token) 25 | predictions = output.to_here() 26 | total_predicted_text += tokenizer.decode(predictions) 27 | # print(total_predicted_text) 28 | if '<|endoftext|>' in total_predicted_text: 29 | break 30 | input_token = tokenizer(total_predicted_text, return_tensors="pt") 31 | 32 | return {total_predicted_text} 33 | 34 | 35 | @app.get("/shutdown") 36 | async def shutdown(): 37 | engine.clear() 38 | server.should_exit = True 39 | server.force_exit = True 40 | await server.shutdown() 41 | 42 | 43 | def launch_engine(model_class, 44 | model_type, 45 | max_batch_size: int = 1, 46 | tp_init_size: int = -1, 47 | pp_init_size: int = -1, 48 | host: str = "localhost", 49 | port: int = 29500, 50 | dtype = torch.float, 51 | checkpoint: str = None, 52 | tokenizer_path: str = None, 53 | server_host = "localhost", 54 | server_port = 8005, 55 | log_level = "info" 56 | ): 57 | 58 | # only for the generation task 59 | global tokenizer 60 | if(tokenizer_path): 61 | tokenizer = GPT2Tokenizer.from_pretrained(tokenizer_path) 62 | 63 | if checkpoint: 64 | model_config = {'dtype': dtype, 'checkpoint': True, 'checkpoint_path': checkpoint} 65 | else: 66 | model_config = {'dtype': dtype} 67 | 68 | global engine 69 | engine = InferenceEngine(model_class, 70 | model_config, 71 | model_type, 72 | max_batch_size = max_batch_size, 73 | tp_init_size = tp_init_size, 74 | pp_init_size = pp_init_size, 75 | host = host, 76 | port = port, 77 | dtype = dtype) 78 | 79 | global server 80 | config = uvicorn.Config(app, host=server_host, port=server_port, log_level=log_level) 81 | server = uvicorn.Server(config=config) 82 | server.run() -------------------------------------------------------------------------------- /examples/hf_gpt2/requirements.txt: -------------------------------------------------------------------------------- 1 | colossalai 2 | torch >= 1.8.1 3 | -------------------------------------------------------------------------------- /examples/linear/linear.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | from energonai.nemesis.nemesis_manager import Ne_manager 4 | import random 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | compute_device = 'cuda:0' # manually set which device to compute on 10 | offload_flag = True # whether or not to activate offloading 11 | 12 | def setup_seed(seed): 13 | torch.manual_seed(seed) 14 | torch.cuda.manual_seed_all(seed) 15 | random.seed(seed) 16 | torch.backends.cudnn.deterministic = True 17 | 18 | 19 | class single_linear(nn.Module): 20 | def __init__(self, input_dim: int, output_dim: int, bias=False): 21 | super().__init__() 22 | self.weight = torch.empty(output_dim, input_dim) 23 | nn.init.normal_(self.weight) 24 | self.weight = nn.Parameter(self.weight.to(compute_device)) 25 | if bias: 26 | self.bias = torch.empty(output_dim) 27 | nn.init.normal_(self.bias) 28 | self.bias = nn.Parameter(self.bias.to(compute_device)) 29 | else: 30 | self.bias = None 31 | 32 | def forward(self, input_): 33 | output = F.linear(input_, self.weight, self.bias) 34 | return output 35 | 36 | 37 | class nv_layers(nn.Module): 38 | def __init__(self, input_dim: int, hidden_dim: int, output_dim: int, layer_num: int): 39 | super().__init__() 40 | self.module_list = list() 41 | for i in range(layer_num): 42 | if i == 0: 43 | temp_layer = single_linear(input_dim, hidden_dim, True) 44 | elif i == layer_num - 1: 45 | temp_layer = single_linear(hidden_dim, 
output_dim, True) 46 | else: 47 | temp_layer = single_linear(hidden_dim, hidden_dim, True) 48 | Ne_manager.register_module(temp_layer, compute_device) 49 | if Ne_manager.offload_flags[i] and offload_flag: 50 | Ne_manager.offload_module(temp_layer) 51 | self.module_list.append(temp_layer) 52 | 53 | def print_device(self): 54 | cnt__ = 0 55 | print("=" * 50) 56 | for mod in self.module_list: 57 | print("layer {} device: ".format(cnt__)) 58 | cnt__ += 1 59 | print(next(mod.parameters()).data.device) 60 | print("=" * 50) 61 | 62 | def forward(self, input_): 63 | output = input_ 64 | for layer_ in self.module_list: 65 | if Ne_manager.event_dict[id(layer_)] is not None: 66 | Ne_manager.compute_stream.wait_event(Ne_manager.event_dict[id(layer_)]) 67 | Ne_manager.event_dict[id(layer_)] = None 68 | output = layer_(output) 69 | return output 70 | 71 | 72 | if __name__ == "__main__": 73 | setup_seed(42) 74 | Ne_manager.set_model_info(12, 6) # register model info 75 | Ne_manager.set_free_device("cuda:1") 76 | # Ne_manager.set_free_device("cpu") # modify here if you want to use cpu as offloading target 77 | model_ = nv_layers(200, 150000, 10, 12) 78 | if offload_flag: 79 | Ne_manager.apply_hook() # call this to activate offloading hooks 80 | input_ = torch.randn((20, 200)).to("cuda:0") 81 | print("init done") 82 | with torch.inference_mode(): 83 | for i in range(5): 84 | out_ = model_(input_) 85 | start_ = time.time() 86 | with torch.inference_mode(): 87 | for i in range(20): 88 | out_ = model_(input_) 89 | print(time.time() - start_) 90 | 91 | -------------------------------------------------------------------------------- /examples/linear/requirements.txt: -------------------------------------------------------------------------------- 1 | colossalai 2 | torch >= 1.8.1 3 | -------------------------------------------------------------------------------- /examples/opt/README.md: -------------------------------------------------------------------------------- 1 | # Overview 2 | 3 | This is an example showing how to run OPT generation. The OPT model is implemented using ColossalAI. 4 | 5 | It supports tensor parallelism, batching and caching. 6 | 7 | # How to run 8 | 9 | Run OPT-125M: 10 | ```shell 11 | python opt_fastapi.py opt-125m 12 | ``` 13 | 14 | It will launch a HTTP server on `0.0.0.0:7070` by default and you can customize host and port. You can open `localhost:7070/docs` in your browser to see the openapi docs. 15 | 16 | ## Configure 17 | 18 | ### Configure model 19 | ```shell 20 | python opt_fastapi.py 21 | ``` 22 | Available models: opt-125m, opt-6.7b, opt-30b, opt-175b. 23 | 24 | ### Configure tensor parallelism 25 | ```shell 26 | python opt_fastapi.py --tp 27 | ``` 28 | The `` can be an integer in `[1, #GPUs]`. Default `1`. 29 | 30 | ### Configure checkpoint 31 | ```shell 32 | python opt_fastapi.py --checkpoint 33 | ``` 34 | The `` can be a file path or a directory path. If it's a directory path, all files under the directory will be loaded. 35 | 36 | ### Configure queue 37 | ```shell 38 | python opt_fastapi.py --queue_size 39 | ``` 40 | The `` can be an integer in `[0, MAXINT]`. If it's `0`, the request queue size is infinite. If it's a positive integer, when the request queue is full, incoming requests will be dropped (the HTTP status code of response will be 406). 41 | 42 | ### Configure bathcing 43 | ```shell 44 | python opt_fastapi.py --max_batch_size 45 | ``` 46 | The `` can be an integer in `[1, MAXINT]`. The engine will make batch whose size is less or equal to this value. 
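Batching only pays off when several requests are in flight at the same time. As a rough illustration, a client along these lines (not part of this repo, and assuming the server is reachable at `localhost:7070`) sends requests concurrently so the engine can group them:
```python
# Hypothetical concurrent client for the /generation endpoint.
import concurrent.futures
import requests

def ask(prompt: str, max_tokens: int = 64) -> str:
    r = requests.post("http://localhost:7070/generation",
                      json={"prompt": prompt, "max_tokens": max_tokens})
    r.raise_for_status()
    return r.json()["text"]

prompts = [f"Question: What is {i} + {i}?\nAnswer:" for i in range(8)]
with concurrent.futures.ThreadPoolExecutor(max_workers=8) as pool:
    for text in pool.map(ask, prompts):
        print(text)
```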
47 | 48 | Note that the batch size is not always equal to ``, as some consecutive requests may not be batched. 49 | 50 | ### Configure caching 51 | ```shell 52 | python opt_fastapi.py --cache_size --cache_list_size 53 | ``` 54 | This will cache `` unique requests. And for each unique request, it cache `` different results. A random result will be returned if the cache is hit. 55 | 56 | The `` can be an integer in `[0, MAXINT]`. If it's `0`, cache won't be applied. The `` can be an integer in `[1, MAXINT]`. 57 | 58 | ### Other configurations 59 | ```shell 60 | python opt_fastapi.py -h 61 | ``` 62 | 63 | # How to benchmark 64 | ```shell 65 | cd benchmark 66 | locust 67 | ``` 68 | 69 | Then open the web interface link which is on your console. 70 | 71 | # Pre-process pre-trained weights 72 | 73 | ## OPT-66B 74 | See [script/processing_ckpt_66b.py](./script/processing_ckpt_66b.py). 75 | 76 | ## OPT-175B 77 | See [script/process-opt-175b](./script/process-opt-175b/). -------------------------------------------------------------------------------- /examples/opt/batch.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from typing import List, Deque, Tuple, Hashable, Any 3 | from energonai import BatchManager, SubmitEntry, TaskEntry 4 | 5 | 6 | class BatchManagerForGeneration(BatchManager): 7 | def __init__(self, max_batch_size: int = 1, pad_token_id: int = 0) -> None: 8 | super().__init__() 9 | self.max_batch_size = max_batch_size 10 | self.pad_token_id = pad_token_id 11 | 12 | def _left_padding(self, batch_inputs): 13 | max_len = max(len(inputs['input_ids']) for inputs in batch_inputs) 14 | outputs = {'input_ids': [], 'attention_mask': []} 15 | for inputs in batch_inputs: 16 | input_ids, attention_mask = inputs['input_ids'], inputs['attention_mask'] 17 | padding_len = max_len - len(input_ids) 18 | input_ids = [self.pad_token_id] * padding_len + input_ids 19 | attention_mask = [0] * padding_len + attention_mask 20 | outputs['input_ids'].append(input_ids) 21 | outputs['attention_mask'].append(attention_mask) 22 | for k in outputs: 23 | outputs[k] = torch.tensor(outputs[k]) 24 | return outputs, max_len 25 | 26 | @staticmethod 27 | def _make_batch_key(entry: SubmitEntry) -> tuple: 28 | data = entry.data 29 | return (data['top_k'], data['top_p'], data['temperature']) 30 | 31 | def make_batch(self, q: Deque[SubmitEntry]) -> Tuple[TaskEntry, dict]: 32 | entry = q.popleft() 33 | uids = [entry.uid] 34 | batch = [entry.data] 35 | while len(batch) < self.max_batch_size: 36 | if len(q) == 0: 37 | break 38 | if self._make_batch_key(entry) != self._make_batch_key(q[0]): 39 | break 40 | if q[0].data['max_tokens'] > entry.data['max_tokens']: 41 | break 42 | e = q.popleft() 43 | batch.append(e.data) 44 | uids.append(e.uid) 45 | inputs, max_len = self._left_padding(batch) 46 | trunc_lens = [] 47 | for data in batch: 48 | trunc_lens.append(max_len + data['max_tokens']) 49 | inputs['top_k'] = entry.data['top_k'] 50 | inputs['top_p'] = entry.data['top_p'] 51 | inputs['temperature'] = entry.data['temperature'] 52 | inputs['max_tokens'] = max_len + entry.data['max_tokens'] 53 | return TaskEntry(tuple(uids), inputs), {'trunc_lens': trunc_lens} 54 | 55 | def split_batch(self, task_entry: TaskEntry, trunc_lens: List[int] = []) -> List[Tuple[Hashable, Any]]: 56 | retval = [] 57 | for uid, output, trunc_len in zip(task_entry.uids, task_entry.batch, trunc_lens): 58 | retval.append((uid, output[:trunc_len])) 59 | return retval 60 | 
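As a quick, hypothetical sanity check of the left-padding logic above (run next to `batch.py`; in the real server the entries come from `GPT2Tokenizer` without `return_tensors`, i.e. plain Python lists):
```python
# Minimal check of BatchManagerForGeneration's left padding with toy inputs.
from batch import BatchManagerForGeneration

mgr = BatchManagerForGeneration(max_batch_size=4, pad_token_id=1)
batch_inputs = [
    {"input_ids": [5, 6, 7], "attention_mask": [1, 1, 1]},
    {"input_ids": [9],       "attention_mask": [1]},
]
padded, max_len = mgr._left_padding(batch_inputs)
print(padded["input_ids"])       # tensor([[5, 6, 7], [1, 1, 9]])
print(padded["attention_mask"])  # tensor([[1, 1, 1], [0, 0, 1]])
print(max_len)                   # 3
```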
-------------------------------------------------------------------------------- /examples/opt/benchmark/locustfile.py: -------------------------------------------------------------------------------- 1 | from locust import HttpUser, task 2 | from json import JSONDecodeError 3 | 4 | 5 | class GenerationUser(HttpUser): 6 | @task 7 | def generate(self): 8 | prompt = 'Question: What is the longest river on the earth? Answer:' 9 | for i in range(4, 9): 10 | data = {'max_tokens': 2**i, 'prompt': prompt} 11 | with self.client.post('/generation', json=data, catch_response=True) as response: 12 | if response.status_code in (200, 406): 13 | response.success() 14 | else: 15 | response.failure('Response wrong') 16 | -------------------------------------------------------------------------------- /examples/opt/cache.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | from threading import Lock 3 | from contextlib import contextmanager 4 | from typing import List, Any, Hashable, Dict 5 | 6 | 7 | class MissCacheError(Exception): 8 | pass 9 | 10 | 11 | class ListCache: 12 | def __init__(self, cache_size: int, list_size: int, fixed_keys: List[Hashable] = []) -> None: 13 | """Cache a list of values. The fixed keys won't be removed. For other keys, LRU is applied. 14 | When the value list is not full, a cache miss occurs. Otherwise, a cache hit occurs. Redundant values will be removed. 15 | 16 | Args: 17 | cache_size (int): Max size for LRU cache. 18 | list_size (int): Value list size. 19 | fixed_keys (List[Hashable], optional): The keys which won't be removed. Defaults to []. 20 | """ 21 | self.cache_size = cache_size 22 | self.list_size = list_size 23 | self.cache: OrderedDict[Hashable, List[Any]] = OrderedDict() 24 | self.fixed_cache: Dict[Hashable, List[Any]] = {} 25 | for key in fixed_keys: 26 | self.fixed_cache[key] = [] 27 | self._lock = Lock() 28 | 29 | def get(self, key: Hashable) -> List[Any]: 30 | with self.lock(): 31 | if key in self.fixed_cache: 32 | l = self.fixed_cache[key] 33 | if len(l) >= self.list_size: 34 | return l 35 | elif key in self.cache: 36 | self.cache.move_to_end(key) 37 | l = self.cache[key] 38 | if len(l) >= self.list_size: 39 | return l 40 | raise MissCacheError() 41 | 42 | def add(self, key: Hashable, value: Any) -> None: 43 | with self.lock(): 44 | if key in self.fixed_cache: 45 | l = self.fixed_cache[key] 46 | if len(l) < self.list_size and value not in l: 47 | l.append(value) 48 | elif key in self.cache: 49 | self.cache.move_to_end(key) 50 | l = self.cache[key] 51 | if len(l) < self.list_size and value not in l: 52 | l.append(value) 53 | else: 54 | if len(self.cache) >= self.cache_size: 55 | self.cache.popitem(last=False) 56 | self.cache[key] = [value] 57 | 58 | @contextmanager 59 | def lock(self): 60 | try: 61 | self._lock.acquire() 62 | yield 63 | finally: 64 | self._lock.release() 65 | -------------------------------------------------------------------------------- /examples/opt/opt_fastapi.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | import random 4 | from typing import Optional 5 | 6 | import uvicorn 7 | from energonai import QueueFullError, launch_engine 8 | from energonai.model import opt_6B, opt_30B, opt_125M, opt_175B 9 | from fastapi import FastAPI, HTTPException, Request 10 | from pydantic import BaseModel, Field 11 | from transformers import GPT2Tokenizer 12 | 13 | from batch import 
BatchManagerForGeneration 14 | from cache import ListCache, MissCacheError 15 | 16 | 17 | class GenerationTaskReq(BaseModel): 18 | max_tokens: int = Field(gt=0, le=256, example=64) 19 | prompt: str = Field( 20 | min_length=1, example='Question: Where were the 2004 Olympics held?\nAnswer: Athens, Greece\n\nQuestion: What is the longest river on the earth?\nAnswer:') 21 | top_k: Optional[int] = Field(default=None, gt=0, example=50) 22 | top_p: Optional[float] = Field(default=None, gt=0.0, lt=1.0, example=0.5) 23 | temperature: Optional[float] = Field(default=None, gt=0.0, lt=1.0, example=0.7) 24 | 25 | 26 | app = FastAPI() 27 | 28 | 29 | @app.post('/generation') 30 | async def generate(data: GenerationTaskReq, request: Request): 31 | logger.info(f'{request.client.host}:{request.client.port} - "{request.method} {request.url.path}" - {data}') 32 | key = (data.prompt, data.max_tokens) 33 | try: 34 | if cache is None: 35 | raise MissCacheError() 36 | outputs = cache.get(key) 37 | output = random.choice(outputs) 38 | logger.info('Cache hit') 39 | except MissCacheError: 40 | inputs = tokenizer(data.prompt, truncation=True, max_length=512) 41 | inputs['max_tokens'] = data.max_tokens 42 | inputs['top_k'] = data.top_k 43 | inputs['top_p'] = data.top_p 44 | inputs['temperature'] = data.temperature 45 | try: 46 | uid = id(data) 47 | engine.submit(uid, inputs) 48 | output = await engine.wait(uid) 49 | output = tokenizer.decode(output, skip_special_tokens=True) 50 | if cache is not None: 51 | cache.add(key, output) 52 | except QueueFullError as e: 53 | raise HTTPException(status_code=406, detail=e.args[0]) 54 | 55 | return {'text': output} 56 | 57 | 58 | @app.on_event("shutdown") 59 | async def shutdown(*_): 60 | engine.shutdown() 61 | server.should_exit = True 62 | server.force_exit = True 63 | await server.shutdown() 64 | 65 | 66 | def get_model_fn(model_name: str): 67 | model_map = { 68 | 'opt-125m': opt_125M, 69 | 'opt-6.7b': opt_6B, 70 | 'opt-30b': opt_30B, 71 | 'opt-175b': opt_175B 72 | } 73 | return model_map[model_name] 74 | 75 | 76 | def print_args(args: argparse.Namespace): 77 | print('\n==> Args:') 78 | for k, v in args.__dict__.items(): 79 | print(f'{k} = {v}') 80 | 81 | 82 | FIXED_CACHE_KEYS = [ 83 | ('Question: What is the name of the largest continent on earth?\nAnswer: Asia\n\nQuestion: What is at the center of the solar system?\nAnswer:', 64), 84 | ('A chat between a salesman and a student.\n\nSalesman: Hi boy, are you looking for a new phone?\nStudent: Yes, my phone is not functioning well.\nSalesman: What is your budget? 
\nStudent: I have received my scholarship so I am fine with any phone.\nSalesman: Great, then perhaps this latest flagship phone is just right for you.', 64), 85 | ("English: I am happy today.\nChinese: 我今天很开心。\n\nEnglish: I am going to play basketball.\nChinese: 我一会去打篮球。\n\nEnglish: Let's celebrate our anniversary.\nChinese:", 64) 86 | ] 87 | 88 | if __name__ == '__main__': 89 | parser = argparse.ArgumentParser() 90 | parser.add_argument('model', choices=['opt-125m', 'opt-6.7b', 'opt-30b', 'opt-175b']) 91 | parser.add_argument('--tp', type=int, default=1) 92 | parser.add_argument('--master_host', default='localhost') 93 | parser.add_argument('--master_port', type=int, default=19990) 94 | parser.add_argument('--rpc_port', type=int, default=19980) 95 | parser.add_argument('--max_batch_size', type=int, default=8) 96 | parser.add_argument('--pipe_size', type=int, default=1) 97 | parser.add_argument('--queue_size', type=int, default=0) 98 | parser.add_argument('--http_host', default='0.0.0.0') 99 | parser.add_argument('--http_port', type=int, default=7070) 100 | parser.add_argument('--checkpoint', default=None) 101 | parser.add_argument('--cache_size', type=int, default=0) 102 | parser.add_argument('--cache_list_size', type=int, default=1) 103 | args = parser.parse_args() 104 | print_args(args) 105 | model_kwargs = {} 106 | if args.checkpoint is not None: 107 | model_kwargs['checkpoint'] = args.checkpoint 108 | 109 | logger = logging.getLogger(__name__) 110 | tokenizer = GPT2Tokenizer.from_pretrained('facebook/opt-30b') 111 | if args.cache_size > 0: 112 | cache = ListCache(args.cache_size, args.cache_list_size, fixed_keys=FIXED_CACHE_KEYS) 113 | else: 114 | cache = None 115 | engine = launch_engine(args.tp, 1, args.master_host, args.master_port, args.rpc_port, get_model_fn(args.model), 116 | batch_manager=BatchManagerForGeneration(max_batch_size=args.max_batch_size, 117 | pad_token_id=tokenizer.pad_token_id), 118 | pipe_size=args.pipe_size, 119 | queue_size=args.queue_size, 120 | **model_kwargs) 121 | config = uvicorn.Config(app, host=args.http_host, port=args.http_port) 122 | server = uvicorn.Server(config=config) 123 | server.run() 124 | -------------------------------------------------------------------------------- /examples/opt/opt_server.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import argparse 3 | import random 4 | from torch import Tensor 5 | from pydantic import BaseModel, Field 6 | from typing import Optional 7 | from energonai.model import opt_125M, opt_30B, opt_175B, opt_6B 8 | from transformers import GPT2Tokenizer 9 | from energonai import launch_engine, QueueFullError 10 | from sanic import Sanic 11 | from sanic.request import Request 12 | from sanic.response import json 13 | from sanic_ext import validate, openapi 14 | from batch import BatchManagerForGeneration 15 | from cache import ListCache, MissCacheError 16 | 17 | 18 | class GenerationTaskReq(BaseModel): 19 | max_tokens: int = Field(gt=0, le=256, example=64) 20 | prompt: str = Field( 21 | min_length=1, example='Question: Where were the 2004 Olympics held?\nAnswer: Athens, Greece\n\nQuestion: What is the longest river on the earth?\nAnswer:') 22 | top_k: Optional[int] = Field(default=None, gt=0, example=50) 23 | top_p: Optional[float] = Field(default=None, gt=0.0, lt=1.0, example=0.5) 24 | temperature: Optional[float] = Field(default=None, gt=0.0, lt=1.0, example=0.7) 25 | 26 | 27 | app = Sanic('opt') 28 | 29 | 30 | @app.post('/generation') 31 | 
@openapi.body(GenerationTaskReq) 32 | @validate(json=GenerationTaskReq) 33 | async def generate(request: Request, body: GenerationTaskReq): 34 | logger.info(f'{request.ip}:{request.port} - "{request.method} {request.path}" - {body}') 35 | key = (body.prompt, body.max_tokens) 36 | try: 37 | if cache is None: 38 | raise MissCacheError() 39 | outputs = cache.get(key) 40 | output = random.choice(outputs) 41 | logger.info('Cache hit') 42 | except MissCacheError: 43 | inputs = tokenizer(body.prompt, truncation=True, max_length=512) 44 | inputs['max_tokens'] = body.max_tokens 45 | inputs['top_k'] = body.top_k 46 | inputs['top_p'] = body.top_p 47 | inputs['temperature'] = body.temperature 48 | try: 49 | uid = id(body) 50 | engine.submit(uid, inputs) 51 | output = await engine.wait(uid) 52 | assert isinstance(output, Tensor) 53 | output = tokenizer.decode(output, skip_special_tokens=True) 54 | if cache is not None: 55 | cache.add(key, output) 56 | except QueueFullError as e: 57 | return json({'detail': e.args[0]}, status=406) 58 | 59 | return json({'text': output}) 60 | 61 | 62 | @app.after_server_stop 63 | def shutdown(*_): 64 | engine.shutdown() 65 | 66 | 67 | def get_model_fn(model_name: str): 68 | model_map = { 69 | 'opt-125m': opt_125M, 70 | 'opt-6.7b': opt_6B, 71 | 'opt-30b': opt_30B, 72 | 'opt-175b': opt_175B 73 | } 74 | return model_map[model_name] 75 | 76 | 77 | def print_args(args: argparse.Namespace): 78 | print('\n==> Args:') 79 | for k, v in args.__dict__.items(): 80 | print(f'{k} = {v}') 81 | 82 | 83 | FIXED_CACHE_KEYS = [ 84 | ('Question: What is the name of the largest continent on earth?\nAnswer: Asia\n\nQuestion: What is at the center of the solar system?\nAnswer:', 64), 85 | ('A chat between a salesman and a student.\n\nSalesman: Hi boy, are you looking for a new phone?\nStudent: Yes, my phone is not functioning well.\nSalesman: What is your budget? 
\nStudent: I have received my scholarship so I am fine with any phone.\nSalesman: Great, then perhaps this latest flagship phone is just right for you.', 64), 86 | ("English: I am happy today.\nChinese: 我今天很开心。\n\nEnglish: I am going to play basketball.\nChinese: 我一会去打篮球。\n\nEnglish: Let's celebrate our anniversary.\nChinese:", 64) 87 | ] 88 | 89 | if __name__ == '__main__': 90 | parser = argparse.ArgumentParser() 91 | parser.add_argument('model', choices=['opt-125m', 'opt-6.7b', 'opt-30b', 'opt-175b']) 92 | parser.add_argument('--tp', type=int, default=1) 93 | parser.add_argument('--master_host', default='localhost') 94 | parser.add_argument('--master_port', type=int, default=19990) 95 | parser.add_argument('--rpc_port', type=int, default=19980) 96 | parser.add_argument('--max_batch_size', type=int, default=8) 97 | parser.add_argument('--pipe_size', type=int, default=1) 98 | parser.add_argument('--queue_size', type=int, default=0) 99 | parser.add_argument('--http_host', default='0.0.0.0') 100 | parser.add_argument('--http_port', type=int, default=7070) 101 | parser.add_argument('--checkpoint', default=None) 102 | parser.add_argument('--cache_size', type=int, default=0) 103 | parser.add_argument('--cache_list_size', type=int, default=1) 104 | args = parser.parse_args() 105 | print_args(args) 106 | model_kwargs = {} 107 | if args.checkpoint is not None: 108 | model_kwargs['checkpoint'] = args.checkpoint 109 | 110 | logger = logging.getLogger(__name__) 111 | tokenizer = GPT2Tokenizer.from_pretrained('facebook/opt-30b') 112 | if args.cache_size > 0: 113 | cache = ListCache(args.cache_size, args.cache_list_size, fixed_keys=FIXED_CACHE_KEYS) 114 | else: 115 | cache = None 116 | engine = launch_engine(args.tp, 1, args.master_host, args.master_port, args.rpc_port, get_model_fn(args.model), 117 | batch_manager=BatchManagerForGeneration(max_batch_size=args.max_batch_size, 118 | pad_token_id=tokenizer.pad_token_id), 119 | pipe_size=args.pipe_size, 120 | queue_size=args.queue_size, 121 | **model_kwargs) 122 | app.run(args.http_host, args.http_port) 123 | -------------------------------------------------------------------------------- /examples/opt/requirements.txt: -------------------------------------------------------------------------------- 1 | fastapi==0.85.1 2 | locust==2.11.0 3 | pydantic==1.10.2 4 | sanic==22.9.0 5 | sanic_ext==22.9.0 6 | torch>=1.10.0 7 | transformers==4.23.1 8 | uvicorn==0.19.0 9 | colossalai 10 | -------------------------------------------------------------------------------- /examples/opt/script/process-opt-175b/README.md: -------------------------------------------------------------------------------- 1 | # Process OPT-175B weights 2 | 3 | You should download the pre-trained weights following the [doc](https://github.com/facebookresearch/metaseq/tree/main/projects/OPT) before reading this. 4 | 5 | First, install `metaseq` and `git clone https://github.com/facebookresearch/metaseq.git`. 6 | 7 | Then, `cd metaseq`. 
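Both steps below list md5 checksums for the shard files they produce. A small helper along these lines (not included in the repo) can be used to verify them:

```python
# Hypothetical md5 checker: python md5check.py <file> [<file> ...]
import hashlib
import sys

def md5sum(path: str, chunk_size: int = 1 << 20) -> str:
    h = hashlib.md5()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            h.update(chunk)
    return h.hexdigest()

if __name__ == "__main__":
    for path in sys.argv[1:]:
        print(md5sum(path), path)
```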
8 | 9 | To consolidate checkpoints to eliminate FSDP: 10 | 11 | ```shell 12 | bash metaseq/scripts/reshard_mp_launch_no_slurm.sh /checkpoint_last / 8 1 13 | ``` 14 | 15 | You will get 8 files in ``, and you should have the following checksums: 16 | ``` 17 | 7e71cb65c4be784aa0b2889ac6039ee8 reshard-model_part-0-shard0.pt 18 | c8123da04f2c25a9026ea3224d5d5022 reshard-model_part-1-shard0.pt 19 | 45e5d10896382e5bc4a7064fcafd2b1e reshard-model_part-2-shard0.pt 20 | abb7296c4d2fc17420b84ca74fc3ce64 reshard-model_part-3-shard0.pt 21 | 05dcc7ac6046f4d3f90b3d1068e6da15 reshard-model_part-4-shard0.pt 22 | d24dd334019060ce1ee7e625fcf6b4bd reshard-model_part-5-shard0.pt 23 | fb1615ce0bbe89cc717f3e5079ee2655 reshard-model_part-6-shard0.pt 24 | 2f3124432d2dbc6aebfca06be4b791c2 reshard-model_part-7-shard0.pt 25 | ``` 26 | 27 | Copy `flat-meta.json` to ``. 28 | 29 | Then cd to this dir, and we unflatten parameters. 30 | 31 | ```shell 32 | bash unflat.sh / / 33 | ``` 34 | 35 | Finally, you will get 8 files in `` with following checksums: 36 | ``` 37 | 6169c59d014be95553c89ec01b8abb62 reshard-model_part-0.pt 38 | 58868105da3d74a528a548fdb3a8cff6 reshard-model_part-1.pt 39 | 69b255dc5a49d0eba9e4b60432cda90b reshard-model_part-2.pt 40 | 002c052461ff9ffb0cdac3d5906f41f2 reshard-model_part-3.pt 41 | 6d57f72909320d511ffd5f1c668b2beb reshard-model_part-4.pt 42 | 93c8c4041cdc0c7907cc7afcf15cec2a reshard-model_part-5.pt 43 | 5d63b8750d827a1aa7c8ae5b02a3a2ca reshard-model_part-6.pt 44 | f888bd41e009096804fe9a4b48c7ffe8 reshard-model_part-7.pt 45 | ``` 46 | 47 | -------------------------------------------------------------------------------- /examples/opt/script/process-opt-175b/convert_ckpt.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | import re 5 | from collections import defaultdict 6 | 7 | import numpy as np 8 | import torch 9 | 10 | 11 | def load_json(path: str): 12 | with open(path) as f: 13 | return json.load(f) 14 | 15 | 16 | def parse_shape_info(flat_dir: str): 17 | data = load_json(os.path.join(flat_dir, 'shape.json')) 18 | flat_info = defaultdict(lambda: defaultdict(list)) 19 | for k, shape in data.items(): 20 | matched = re.match(r'decoder.layers.\d+', k) 21 | if matched is None: 22 | flat_key = 'flat_param_0' 23 | else: 24 | flat_key = f'{matched[0]}.flat_param_0' 25 | flat_info[flat_key]['names'].append(k) 26 | flat_info[flat_key]['shapes'].append(shape) 27 | flat_info[flat_key]['numels'].append(int(np.prod(shape))) 28 | return flat_info 29 | 30 | 31 | def convert(flat_dir: str, output_dir: str, part: int): 32 | flat_path = os.path.join(flat_dir, f'reshard-model_part-{part}-shard0.pt') 33 | output_path = os.path.join(output_dir, f'reshard-model_part-{part}.pt') 34 | flat_meta = load_json(os.path.join(flat_dir, 'flat-meta.json')) 35 | flat_sd = torch.load(flat_path) 36 | print(f'Loaded flat state dict from {flat_path}') 37 | output_sd = {} 38 | for flat_key, param_meta in flat_meta.items(): 39 | flat_param = flat_sd['model'][flat_key] 40 | assert sum(param_meta['numels']) == flat_param.numel( 41 | ), f'flat {flat_key} {flat_param.numel()} vs {sum(param_meta["numels"])}' 42 | for name, shape, param in zip(param_meta['names'], param_meta['shapes'], flat_param.split(param_meta['numels'])): 43 | output_sd[name] = param.view(shape) 44 | 45 | torch.save(output_sd, output_path) 46 | print(f'Saved unflat state dict to {output_path}') 47 | 48 | 49 | if __name__ == "__main__": 50 | parser = argparse.ArgumentParser() 51 | 
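# Invoked once per model part by unflat.sh: python convert_ckpt.py <flat_dir> <output_dir> <part>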
parser.add_argument('flat_dir') 52 | parser.add_argument('output_dir') 53 | parser.add_argument('part', type=int) 54 | args = parser.parse_args() 55 | convert(args.flat_dir, args.output_dir, args.part) 56 | -------------------------------------------------------------------------------- /examples/opt/script/process-opt-175b/unflat.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env sh 2 | 3 | for i in $(seq 0 7); do 4 | python convert_ckpt.py $1 $2 ${i} & 5 | done 6 | 7 | wait $(jobs -p) 8 | -------------------------------------------------------------------------------- /examples/opt/script/processing_ckpt_66b.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from multiprocessing import Pool 4 | 5 | # download the pytorch model ckpt from https://huggingface.co/facebook/opt-66b/tree/main 6 | # you can use either wget or git lfs 7 | 8 | path = "/path/to/your/ckpt" 9 | new_path = "/path/to/the/processed/ckpt/" 10 | 11 | assert os.path.isdir(path) 12 | files = [] 13 | for filename in os.listdir(path): 14 | filepath = os.path.join(path, filename) 15 | if os.path.isfile(filepath): 16 | files.append(filepath) 17 | 18 | with Pool(14) as pool: 19 | ckpts = pool.map(torch.load, files) 20 | 21 | restored = {} 22 | for ckpt in ckpts: 23 | for k,v in ckpt.items(): 24 | if(k[0] == 'm'): # strip the leading "model." prefix from the key 25 | k = k[6:] 26 | if(k == "lm_head.weight"): 27 | k = "head.dense.weight" 28 | if(k == "decoder.final_layer_norm.weight"): 29 | k = "decoder.layer_norm.weight" 30 | if(k == "decoder.final_layer_norm.bias"): 31 | k = "decoder.layer_norm.bias" 32 | restored[k] = v 33 | restored["decoder.version"] = "0.0" 34 | 35 | 36 | split_num = len(restored.keys()) // 60 # number of keys per output file (about 60 files in total) 37 | count = 0 38 | file_count = 1 39 | tmp = {} 40 | for k,v in restored.items(): 41 | print(k) 42 | tmp[k] = v 43 | count = count + 1 44 | if(count == split_num): 45 | filename = str(file_count) + "-restored.pt" 46 | torch.save(tmp, os.path.join(new_path, filename)) 47 | file_count = file_count + 1 48 | count = 0 49 | tmp = {} 50 | 51 | filename = str(file_count) + "-restored.pt" 52 | torch.save(tmp, os.path.join(new_path, filename)) 53 | 54 | 55 | 56 | -------------------------------------------------------------------------------- /examples/trt_demo/requirements.txt: -------------------------------------------------------------------------------- 1 | colossalai 2 | torch >= 1.8.1 3 | -------------------------------------------------------------------------------- /examples/trt_demo/trt_net_config.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from net import bert_large 3 | from trt_net_server import launch_engine 4 | 5 | 6 | # for engine 7 | model_class = bert_large 8 | model_type = "bert" 9 | host = "127.0.0.1" 10 | port = 29401 11 | half = True 12 | backend = "nccl" 13 | 14 | # for parallel 15 | tp_init_size = 1 16 | pp_init_size = 1 17 | 18 | # for server 19 | engine_server = launch_engine 20 | server_host = "127.0.0.1" 21 | server_port = 8020 22 | log_level = "info" 23 | 24 | # for tensorrt 25 | trt_sample = [torch.ones((1,128,1024)).half().cuda(), torch.ones((1, 1, 128)).half().cuda()] 26 | -------------------------------------------------------------------------------- /examples/trt_demo/trt_net_server.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import uvicorn 3 | from fastapi import FastAPI 4 | from energonai.engine
import InferenceEngine 5 | 6 | from transformers import GPT2Tokenizer 7 | 8 | app = FastAPI() # create the FastAPI app object 9 | 10 | @app.get("/") # root route 11 | def root(): 12 | return {"200"} 13 | 14 | # @app.get("/run/{request}") 15 | # def run(request: str, max_seq_length: int): 16 | 17 | # input_token = tokenizer(request, return_tensors="pt") 18 | # total_predicted_text = request 19 | 20 | # for i in range(1, max_seq_length): 21 | # output = engine.run(input_token) 22 | # predictions = output.to_here() 23 | # total_predicted_text += tokenizer.decode(predictions) 24 | # # print(total_predicted_text) 25 | # if '<|endoftext|>' in total_predicted_text: 26 | # break 27 | # input_token = tokenizer(total_predicted_text, return_tensors="pt") 28 | 29 | # return {total_predicted_text} 30 | 31 | 32 | # @app.get("/shutdown") 33 | # async def shutdown(): 34 | # engine.clear() 35 | # server.should_exit = True 36 | # server.force_exit = True 37 | # await server.shutdown() 38 | 39 | 40 | def launch_engine(model_class, 41 | model_type, 42 | max_batch_size: int = 1, 43 | tp_init_size: int = -1, 44 | pp_init_size: int = -1, 45 | host: str = "localhost", 46 | port: int = 29500, 47 | dtype = torch.float, 48 | checkpoint: str = None, 49 | tokenizer_path: str = None, 50 | server_host = "localhost", 51 | server_port = 8005, 52 | log_level = "info" 53 | ): 54 | 55 | # only for the generation task 56 | global tokenizer 57 | if(tokenizer_path): 58 | tokenizer = GPT2Tokenizer.from_pretrained(tokenizer_path) 59 | 60 | model_config = dict() 61 | 62 | global engine 63 | engine = InferenceEngine(model_class, 64 | model_config, 65 | model_type, 66 | max_batch_size = max_batch_size, 67 | tp_init_size = tp_init_size, 68 | pp_init_size = pp_init_size, 69 | host = host, 70 | port = port, 71 | dtype = dtype) 72 | 73 | global server 74 | config = uvicorn.Config(app, host=server_host, port=server_port, log_level=log_level) 75 | server = uvicorn.Server(config=config) 76 | server.run() -------------------------------------------------------------------------------- /examples/vit/dataset/n01667114_9985.JPEG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hpcaitech/EnergonAI/56b35f3c06eaac11b1bee633d1e836563f74bcea/examples/vit/dataset/n01667114_9985.JPEG -------------------------------------------------------------------------------- /examples/vit/proc_img.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from typing import Any 3 | from PIL import Image 4 | import torchvision.transforms as transforms 5 | import torchvision.datasets as datasets 6 | 7 | 8 | default_normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 9 | std=[0.229, 0.224, 0.225]) 10 | 11 | def pil_loader(path: str) -> Image.Image: 12 | # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835) 13 | with open(path, "rb") as f: 14 | img = Image.open(f) 15 | return img.convert("RGB") 16 | 17 | def accimage_loader(path: str) -> Any: 18 | import accimage 19 | 20 | try: 21 | return accimage.Image(path) 22 | except OSError: 23 | # Potentially a decoding problem, fall back to PIL.Image 24 | return pil_loader(path) 25 | 26 | def proc_img(path: str, size: int=224, normalize=default_normalize, loader=pil_loader) -> torch.Tensor: 27 | img = loader(path) 28 | transform = transforms.Compose([ 29 | transforms.RandomResizedCrop(size), 30 | transforms.RandomHorizontalFlip(), 31 | transforms.ToTensor(), 32 | normalize, 33 | ]) 34 | img =
transform(img) 35 | return img -------------------------------------------------------------------------------- /examples/vit/requirements.txt: -------------------------------------------------------------------------------- 1 | colossalai 2 | torch >= 1.8.1 3 | -------------------------------------------------------------------------------- /examples/vit/vit_config.py: -------------------------------------------------------------------------------- 1 | from vit import vit_lite_depth7_patch4_32, vit_tiny_patch4_32, vit_base_patch16_224 2 | from vit import vit_base_patch16_384, vit_base_patch32_224, vit_base_patch32_384, vit_large_patch16_224 3 | from vit import vit_large_patch16_384, vit_large_patch32_224, vit_large_patch32_384 4 | from vit_server import launch_engine 5 | 6 | 7 | # for engine 8 | model_class = vit_base_patch16_224 9 | model_type = "vit" 10 | host = "127.0.0.1" 11 | port = 29402 12 | half = True 13 | backend = "nccl" 14 | 15 | # for parallel 16 | tp_init_size = 2 17 | pp_init_size = 2 18 | 19 | # for server 20 | engine_server = launch_engine 21 | server_host = "127.0.0.1" 22 | server_port = 8020 23 | log_level = "info" 24 | -------------------------------------------------------------------------------- /examples/vit/vit_server.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import uvicorn 4 | from fastapi import FastAPI 5 | from fastapi import Response 6 | import torch.distributed.rpc as rpc 7 | from energonai.engine import InferenceEngine 8 | from proc_img import proc_img 9 | 10 | app = FastAPI() # create the FastAPI app object 11 | 12 | @app.get("/") # root route 13 | def root(): 14 | return {"200"} 15 | 16 | @app.get("/vit") 17 | def run(): 18 | # for performance testing only 19 | img = proc_img('/home/lcdjs/ColossalAI-Inference/examples/vit/dataset/n01667114_9985.JPEG') # absolute path from the author's machine; point this at your local copy of examples/vit/dataset 20 | img = img.half() 21 | img = torch.unsqueeze(img, 0) 22 | sample = dict(img=img) 23 | output = engine.run(sample) 24 | output = output.to_here() 25 | print(output.size()) 26 | return {"To return the class."} 27 | 28 | @app.get("/shutdown") 29 | async def shutdown(): 30 | engine.clear() 31 | server.should_exit = True 32 | server.force_exit = True 33 | await server.shutdown() 34 | 35 | 36 | def launch_engine(model_class, 37 | model_type, 38 | max_batch_size: int = 1, 39 | tp_init_size: int = -1, 40 | pp_init_size: int = -1, 41 | host: str = "localhost", 42 | port: int = 29500, 43 | dtype = torch.float, 44 | checkpoint: str = None, 45 | tokenizer_path: str = None, 46 | server_host = "localhost", 47 | server_port = 8005, 48 | log_level = "info" 49 | ): 50 | 51 | if checkpoint: 52 | model_config = {'dtype': dtype, 'checkpoint': True, 'checkpoint_path': checkpoint} 53 | else: 54 | model_config = {'dtype': dtype} 55 | 56 | global engine 57 | engine = InferenceEngine(model_class, 58 | model_config, 59 | model_type, 60 | max_batch_size = max_batch_size, 61 | tp_init_size = tp_init_size, 62 | pp_init_size = pp_init_size, 63 | host = host, 64 | port = port, 65 | dtype = dtype) 66 | 67 | global server 68 | config = uvicorn.Config(app, host=server_host, port=server_port, log_level=log_level) 69 | server = uvicorn.Server(config=config) 70 | server.run() -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | tqdm 3 | psutil 4 | packaging 5 | fastapi~=0.75.1 6 | uvicorn==0.14 7 | typer 8 | redis 9 | scipy 10 | pytest 11 | requests 12 |
click 13 | transformers 14 | readerwriterlock 15 | --extra-index-url https://download.pytorch.org/whl/cu113 16 | torch 17 | torchvision 18 | torchaudio 19 | colossalai 20 | omegaconf 21 | prometheus-fastapi-instrumentator 22 | 23 | -------------------------------------------------------------------------------- /tests/run_standalone_tests.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -e 3 | 4 | export PYTHONPATH=$(realpath $(dirname $0)) 5 | 6 | find . -name "test*.py" -print0 | xargs -0L1 pytest -m "standalone" 7 | 8 | -------------------------------------------------------------------------------- /tests/test_checkpoint/test_checkpoint_basic1d.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- encoding: utf-8 -*- 3 | 4 | import pprint 5 | from functools import partial 6 | from colossalai.logging import get_dist_logger 7 | 8 | import colossalai.nn as col_nn 9 | import pytest 10 | import torch 11 | import torch.multiprocessing as mp 12 | import torch.nn as nn 13 | from energonai.context.parallel_mode import ParallelMode 14 | from energonai.core import global_context as gpc 15 | from energonai.initialize import launch 16 | from colossalai.logging import disable_existing_loggers 17 | from colossalai.utils import free_port, is_using_pp 18 | from energonai.utils.checkpointing import gather_pipeline_parallel_state_dict, load_checkpoint, save_checkpoint 19 | 20 | 21 | def partition_uniform(num_items, pipeline_parallel_size, num_chunks): 22 | assert num_items % num_chunks == 0, \ 23 | "Layer length should be divisible by the number of chunks, otherwise the parameter method is recommended" 24 | 25 | logger = get_dist_logger('energonai') 26 | parts = [[] for _ in range(pipeline_parallel_size)] 27 | partition_items = num_items // num_chunks 28 | for idx in range(num_chunks): 29 | base_idx = idx * partition_items 30 | chunk_size = partition_items // pipeline_parallel_size 31 | left = pipeline_parallel_size - partition_items % pipeline_parallel_size 32 | if chunk_size == 0: 33 | logger.warning("Some nodes in Pipeline have no requests") 34 | 35 | for p in range(pipeline_parallel_size): 36 | st = base_idx 37 | base_idx += chunk_size + (p >= left) 38 | parts[p].append((st, base_idx)) 39 | 40 | return parts 41 | 42 | 43 | def build_pipeline(model): 44 | 45 | pipeline_size = gpc.get_world_size(ParallelMode.PIPELINE) 46 | pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE) 47 | depth = len(model) 48 | start, end = partition_uniform(depth, pipeline_size, 1)[pipeline_rank][0] 49 | layers = [] 50 | for i in range(depth): 51 | if start <= i < end: 52 | layers.append(model[i]) 53 | else: 54 | layers.append(nn.Identity()) 55 | return nn.Sequential(*tuple(layers)) 56 | 57 | 58 | def check_equal(A, B): 59 | assert torch.allclose(A, B, rtol=1e-3, atol=1e-2) 60 | 61 | 62 | def check_basic_1d(rank, world_size, port): 63 | # config = dict( 64 | # parallel=dict(pipeline=dict(size=2), tensor=dict(size=4, mode="1d")), 65 | # ) 66 | disable_existing_loggers() 67 | launch(pp_size=2, tp_size=2, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") 68 | # m1 = nn.Sequential(col_nn.Embedding1D(20, 12), col_nn.Linear1D(12, 20), col_nn.Classifier1D(20, 3), 69 | # col_nn.Embedding1D(20, 12), col_nn.Linear1D(12, 20), col_nn.Classifier1D(20, 3)) 70 | # m1 = nn.Sequential(col_nn.Embedding1D(20, 12), col_nn.Embedding1D(20, 12)) 71 | # m1 = nn.Sequential(col_nn.Linear1D(4, 2),
col_nn.Linear1D(2, 4)) 72 | if gpc.get_local_rank(ParallelMode.PIPELINE) == 0: 73 | m1 = nn.Sequential(col_nn.Embedding1D(16, 4), col_nn.Classifier1D(8, 4), col_nn.Linear1D(4, 2), 74 | col_nn.Dropout1D(), 75 | nn.Identity(), nn.Identity(), nn.Identity(), nn.Identity()) 76 | else: 77 | m1 = nn.Sequential(nn.Identity(), nn.Identity(), nn.Identity(), nn.Identity(), 78 | col_nn.Embedding1D(16, 4), col_nn.Classifier1D(8, 4), col_nn.Linear1D(4, 2), 79 | col_nn.Dropout1D()) 80 | for name, param in m1.named_parameters(): 81 | print("RANK {}: {}, {}".format(gpc.get_global_rank(), name, param.size())) 82 | sd1 = m1.state_dict() 83 | # print(f"Rank {gpc.get_global_rank()}:\n{pprint.pformat(sd1)}\n") 84 | save_checkpoint("test.pt", 0, m1) 85 | # m2 = nn.Sequential(col_nn.Embedding1D(20, 12), col_nn.Linear1D(12, 20), col_nn.Classifier1D(20, 3), 86 | # col_nn.Embedding1D(20, 12), col_nn.Linear1D(12, 20), col_nn.Classifier1D(20, 3)) 87 | m2 = nn.Sequential(col_nn.Embedding1D(16, 4), col_nn.Classifier1D(8, 4), col_nn.Linear1D(4, 2), col_nn.Dropout1D(), 88 | col_nn.Embedding1D(16, 4), col_nn.Classifier1D(8, 4), col_nn.Linear1D(4, 2), col_nn.Dropout1D()) 89 | 90 | if is_using_pp(): 91 | m2 = build_pipeline(m2) 92 | load_checkpoint("test.pt", m2) 93 | sd2 = m2.state_dict() 94 | if is_using_pp() and gpc.get_local_rank(ParallelMode.TENSOR) == 0: 95 | sd2 = gather_pipeline_parallel_state_dict(sd2) 96 | print(f"Rank {gpc.get_global_rank()}:\n{pprint.pformat(sd2)}\n") 97 | 98 | if gpc.get_global_rank() == 0: 99 | for k, v in sd1.items(): 100 | assert k in sd2 101 | check_equal(v.to(torch.device("cpu")), sd2[k].to(torch.device("cpu"))) 102 | 103 | 104 | @pytest.mark.dist 105 | def test_checkpoint_1d(): 106 | world_size = 4 107 | run_func = partial(check_basic_1d, world_size=world_size, port=free_port()) 108 | mp.spawn(run_func, nprocs=world_size) 109 | 110 | 111 | if __name__ == "__main__": 112 | test_checkpoint_1d() 113 | -------------------------------------------------------------------------------- /tests/test_checkpoint/test_checkpoint_bert1d.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- encoding: utf-8 -*- 3 | 4 | import pprint 5 | from functools import partial 6 | 7 | import colossalai.nn as col_nn 8 | import torch 9 | import torch.multiprocessing as mp 10 | import torch.nn as nn 11 | 12 | from example.gpt.gpt import gpt2_small 13 | from energonai.context.parallel_mode import ParallelMode 14 | from energonai.engine import InferenceEngine 15 | from example.bert.bert import bert_small 16 | from energonai.core import global_context as gpc 17 | from energonai.initialize import launch 18 | from colossalai.logging import disable_existing_loggers 19 | from colossalai.utils import free_port, is_using_pp 20 | from energonai.utils.checkpointing import gather_pipeline_parallel_state_dict, load_checkpoint, save_checkpoint 21 | 22 | 23 | def check_equal(A, B): 24 | assert torch.allclose(A, B, rtol=1e-3, atol=1e-2) 25 | 26 | 27 | def check_bert_1d(rank, world_size, port): 28 | disable_existing_loggers() 29 | launch(pp_size=2, tp_size=2, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") 30 | state_prefix = "" 31 | parameter_prefix = "" 32 | m1 = bert_small() 33 | sd1 = m1.state_dict(prefix=state_prefix) 34 | for name, param in m1.named_parameters(prefix=parameter_prefix): 35 | print("RANK {}: {}, {}".format(gpc.get_global_rank(), name, param.size())) 36 | save_checkpoint("bert_test.pt", 0, m1, prefix=state_prefix) 37 | 
print("Rank {} building second BERT".format(gpc.get_global_rank())) 38 | m2 = bert_small(checkpoint=True, checkpoint_path="bert_test.pt", prefix=parameter_prefix) 39 | sd2 = m2.state_dict(prefix=state_prefix) 40 | if is_using_pp() and gpc.get_local_rank(ParallelMode.TENSOR) == 0: 41 | sd2 = gather_pipeline_parallel_state_dict(sd2) 42 | # print("Rank {} : {}".format(gpc.get_global_rank(), sd2)) 43 | print("Rank {} gather done".format(gpc.get_global_rank())) 44 | # print(f'Rank {gpc.get_global_rank()}:{pprint.pformat(sd2)}') 45 | if gpc.get_global_rank() == 0: 46 | for k, v in sd1.items(): 47 | assert k in sd2 48 | check_equal(v.to(torch.device("cpu")), sd2[k].to(torch.device("cpu"))) 49 | 50 | 51 | def test_bert(): 52 | world_size = 4 53 | run_func = partial(check_bert_1d, world_size=world_size, port=free_port()) 54 | mp.spawn(run_func, nprocs=world_size) 55 | 56 | 57 | if __name__ == "__main__": 58 | test_bert() 59 | -------------------------------------------------------------------------------- /tests/test_checkpoint/test_checkpoint_gpt1d.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- encoding: utf-8 -*- 3 | 4 | import pprint 5 | from functools import partial 6 | 7 | import colossalai.nn as col_nn 8 | import torch 9 | import torch.multiprocessing as mp 10 | import torch.nn as nn 11 | 12 | from example.gpt.gpt import gpt2_small 13 | from energonai.context.parallel_mode import ParallelMode 14 | from energonai.engine import InferenceEngine 15 | from example.gpt import * 16 | from energonai.core import global_context as gpc 17 | from energonai.initialize import launch 18 | from colossalai.logging import disable_existing_loggers 19 | from colossalai.utils import free_port, is_using_pp 20 | from energonai.utils.checkpointing import gather_pipeline_parallel_state_dict, load_checkpoint, save_checkpoint 21 | 22 | 23 | def check_equal(A, B): 24 | assert torch.allclose(A, B, rtol=1e-3, atol=1e-2) 25 | 26 | 27 | def check_gpt_1d(rank, world_size, port): 28 | disable_existing_loggers() 29 | launch(pp_size=2, tp_size=2, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") 30 | # state_prefix = "rk_{}.".format(gpc.get_global_rank()) 31 | # parameter_prefix = "rk_{}".format(gpc.get_global_rank()) 32 | state_prefix = '' 33 | parameter_prefix = '' 34 | m1 = gpt2_small(vocab_size=50257) 35 | sd1 = m1.state_dict(prefix=state_prefix) 36 | for name, param in m1.named_parameters(prefix=parameter_prefix): 37 | print("RANK {}: {}, {}".format(gpc.get_global_rank(), name, param.size())) 38 | save_checkpoint("gpt_test.pt", 0, m1, prefix=state_prefix) 39 | print("Rank {} building second GPT".format(gpc.get_global_rank())) 40 | m2 = gpt2_small(checkpoint=True, checkpoint_path="gpt_test.pt", prefix=parameter_prefix, vocab_size=50257) 41 | sd2 = m2.state_dict(prefix=state_prefix) 42 | if is_using_pp() and gpc.get_local_rank(ParallelMode.TENSOR) == 0: 43 | sd2 = gather_pipeline_parallel_state_dict(sd2) 44 | # print("Rank {} : {}".format(gpc.get_global_rank(), sd2)) 45 | print("Rank {} gather done".format(gpc.get_global_rank())) 46 | # print(f'Rank {gpc.get_global_rank()}:{pprint.pformat(sd2)}') 47 | if gpc.get_global_rank() == 0: 48 | for k, v in sd1.items(): 49 | assert k in sd2 50 | check_equal(v.to(torch.device("cpu")), sd2[k].to(torch.device("cpu"))) 51 | 52 | 53 | def test_gpt(): 54 | world_size = 4 55 | run_func = partial(check_gpt_1d, world_size=world_size, port=free_port()) 56 | mp.spawn(run_func,
nprocs=world_size) 57 | 58 | 59 | if __name__ == "__main__": 60 | test_gpt() 61 | -------------------------------------------------------------------------------- /tests/test_checkpoint/test_moduledict.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | -------------------------------------------------------------------------------- /tests/test_engine/boring_model_utils.py: -------------------------------------------------------------------------------- 1 | from energonai.testing import BoringModel, get_correct_output 2 | from energonai import launch_engine 3 | from colossalai.utils import free_port 4 | import torch 5 | import asyncio 6 | 7 | 8 | def run_boring_model(tp_world_size: int, pp_world_size: int): 9 | engine = launch_engine(tp_world_size, pp_world_size, 'localhost', free_port(), free_port(), BoringModel) 10 | x = torch.ones(4) 11 | correct_output = get_correct_output(x, pp_world_size) 12 | engine.submit(0, x) 13 | output = asyncio.run(engine.wait(0)) 14 | try: 15 | assert torch.equal(output, correct_output), f'output: {output} vs target: {correct_output}' 16 | finally: 17 | engine.shutdown() 18 | -------------------------------------------------------------------------------- /tests/test_engine/test_hybrid.py: -------------------------------------------------------------------------------- 1 | from colossalai.testing import rerun_if_address_is_in_use 2 | from test_engine.boring_model_utils import run_boring_model 3 | import pytest 4 | 5 | 6 | @pytest.mark.dist 7 | @pytest.mark.standalone 8 | @rerun_if_address_is_in_use() 9 | def test_hybrid(): 10 | run_boring_model(2, 2) 11 | 12 | 13 | if __name__ == '__main__': 14 | test_hybrid() 15 | -------------------------------------------------------------------------------- /tests/test_engine/test_pp.py: -------------------------------------------------------------------------------- 1 | from colossalai.testing import rerun_if_address_is_in_use 2 | from test_engine.boring_model_utils import run_boring_model 3 | import pytest 4 | 5 | 6 | @pytest.mark.dist 7 | @pytest.mark.standalone 8 | @rerun_if_address_is_in_use() 9 | def test_pp(): 10 | run_boring_model(1, 2) 11 | 12 | 13 | if __name__ == '__main__': 14 | test_pp() 15 | -------------------------------------------------------------------------------- /tests/test_engine/test_single_device.py: -------------------------------------------------------------------------------- 1 | from colossalai.testing import rerun_if_address_is_in_use 2 | from test_engine.boring_model_utils import run_boring_model 3 | import pytest 4 | 5 | 6 | @pytest.mark.dist 7 | @pytest.mark.standalone 8 | @rerun_if_address_is_in_use() 9 | def test_single_device(): 10 | run_boring_model(1, 1) 11 | 12 | 13 | if __name__ == '__main__': 14 | test_single_device() 15 | -------------------------------------------------------------------------------- /tests/test_engine/test_tp.py: -------------------------------------------------------------------------------- 1 | from colossalai.testing import rerun_if_address_is_in_use 2 | from test_engine.boring_model_utils import run_boring_model 3 | import pytest 4 | 5 | 6 | @pytest.mark.dist 7 | @pytest.mark.standalone 8 | @rerun_if_address_is_in_use() 9 | def test_tp(): 10 | run_boring_model(2, 1) 11 | 12 | 13 | if __name__ == '__main__': 14 | test_tp() 15 | -------------------------------------------------------------------------------- /tests/test_kernel/test_ft_transpose_pad.py: 
-------------------------------------------------------------------------------- 1 | from energonai.kernel import ft_build_padding_offsets, ft_remove_padding, ft_rebuild_padding, ft_transpose_remove_padding, ft_transpose_rebuild_padding 2 | import torch 3 | import pytest 4 | 5 | 6 | seq_lens = torch.tensor([24,127,31,65,24,127,31,65], dtype=torch.int).cuda() 7 | batch_size = 8 8 | max_padding_size = 128 9 | head_size = 64 10 | head_num = 12 11 | hidden_size = head_num * head_size 12 | 13 | 14 | def test_kernel(): 15 | hidden_states_q = torch.rand(batch_size, max_padding_size, hidden_size).cuda() 16 | hidden_states_k = torch.rand(batch_size, max_padding_size, hidden_size).cuda() 17 | hidden_states_v = torch.rand(batch_size, max_padding_size, hidden_size).cuda() 18 | 19 | 20 | tmp_mask_offset = torch.zeros(batch_size, max_padding_size, dtype=torch.int).cuda() 21 | mask_offset = torch.zeros(batch_size, max_padding_size, dtype=torch.int).cuda() 22 | valid_word_num = torch.zeros(1, dtype=torch.int).cuda() 23 | 24 | ft_build_padding_offsets(seq_lens, batch_size, max_padding_size, valid_word_num, tmp_mask_offset) 25 | q = ft_remove_padding(hidden_states_q, tmp_mask_offset, mask_offset, valid_word_num[0].item(), hidden_size) 26 | k = ft_remove_padding(hidden_states_k, tmp_mask_offset, mask_offset, valid_word_num[0].item(), hidden_size) 27 | v = ft_remove_padding(hidden_states_v, tmp_mask_offset, mask_offset, valid_word_num[0].item(), hidden_size) 28 | 29 | new_qkv_shape = q.shape[:-1] + (head_num, head_size) 30 | 31 | q = q.view(new_qkv_shape) 32 | k = k.view(new_qkv_shape) 33 | v = v.view(new_qkv_shape) 34 | print(q.size()) 35 | 36 | q_buf = torch.zeros(batch_size, head_num, max_padding_size, head_size).cuda() 37 | k_buf = torch.zeros(batch_size, head_num, max_padding_size, head_size).cuda() 38 | v_buf = torch.zeros(batch_size, head_num, max_padding_size, head_size).cuda() 39 | 40 | ft_transpose_rebuild_padding(q, k, v, q_buf, k_buf, v_buf, batch_size, max_padding_size, head_num, head_size, valid_word_num[0].item(), mask_offset) 41 | 42 | print(q_buf.size()) 43 | 44 | q_buf = ft_transpose_remove_padding(v_buf, valid_word_num[0].item(), batch_size, max_padding_size, head_num, head_size, mask_offset) 45 | 46 | print(q_buf.size()) 47 | 48 | q_buf = ft_rebuild_padding(q_buf, mask_offset, valid_word_num[0].item(), hidden_size, batch_size, max_padding_size) 49 | 50 | print(q_buf.size()) 51 | 52 | 53 | 54 | 55 | 56 | # ft_transpose_remove_padding() 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | # void ft_transpose_remove_padding_wrapper(torch::Tensor Q, torch::Tensor K, torch::Tensor V, torch::Tensor q_buf, torch::Tensor k_buf, torch::Tensor v_buf, 65 | # int batch_size, int seq_len, int head_num, int size_per_head, int valid_word_num, torch::Tensor mask_offset){ 66 | 67 | 68 | # print(new_hidden_states.size()) 69 | 70 | # def ft_remove_padding(src, tmp_mask_offset, mask_offset, valid_word_num, hidden_dim): 71 | # def ft_rebuild_padding(src, mask_offset, valid_word_num, hidden_dim): 72 | # def ft_transpose_remove_padding(Q, K, V, q_buf, k_buf, v_buf, batch_size, seq_len, head_num, size_per_head, valid_word_num, mask_offset): 73 | # def ft_transpose_rebuild_padding(src, valid_word_num, batch_size, seq_len, head_num, size_per_head, mask_offset): 74 | 75 | 76 | if __name__ == '__main__': 77 | test_kernel() -------------------------------------------------------------------------------- /tests/test_kernel/test_linear_func.py: -------------------------------------------------------------------------------- 1 | from 
energonai.kernel import linear, find_algo 2 | import torch 3 | import time 4 | 5 | 6 | @torch.no_grad() 7 | def test_linear_func(): 8 | batch_size = 16 9 | seq_len = 64 10 | din = 12288 11 | dout = 49152 12 | 13 | inputs = torch.randn(batch_size, seq_len, din).half().cuda() 14 | params = torch.randn(dout, din).half().cuda() 15 | tensor_target = torch.nn.functional.linear(inputs, params) 16 | tensor_output = linear(inputs, params) 17 | diff = torch.abs(tensor_output - tensor_target) 18 | max_diff = torch.max(diff) 19 | mean_diff = torch.mean(diff) 20 | max_array = torch.max(tensor_target) 21 | 22 | if mean_diff > 0.5 or max_diff > 15 or max_diff / max_array > 0.05: 23 | print("mean_diff:%.2f, max_diff:%.2f, max_diff/max_array:%.4f" % 24 | (mean_diff, max_diff, max_diff / max_array)) 25 | print('target:', tensor_target, '\n') 26 | print('output:', tensor_output, '\n') 27 | raise AssertionError("Wrong value!") 28 | 29 | print('tests pass') 30 | 31 | 32 | @torch.no_grad() 33 | def benchmark_linear_func(): 34 | algo = find_algo() 35 | batch_size = 16 36 | seq_len = 64 37 | din = 12288 38 | dout = 49152 39 | 40 | inner_loop = 8 41 | outer_loop = 20 42 | 43 | input_list_1 = [] 44 | param_list_1 = [] 45 | input_list_2 = [] 46 | param_list_2 = [] 47 | for i in range(inner_loop): 48 | input_list_1.append(torch.randn(batch_size, seq_len, din).half().cuda()) 49 | param_list_1.append(torch.randn(dout, din).half().cuda()) 50 | input_list_2.append(input_list_1[-1].clone().detach()) 51 | param_list_2.append(param_list_1[-1].clone().detach()) 52 | 53 | torch_count = 0 54 | cublas_count = 0 55 | 56 | for _ in range(outer_loop): 57 | for i in range(inner_loop): 58 | _ = torch.nn.functional.linear(input_list_2[i], param_list_2[i]) 59 | torch.cuda.synchronize() 60 | _ = linear(input_list_1[i], param_list_1[i], algo) 61 | torch.cuda.synchronize() 62 | _ = torch.nn.functional.linear(input_list_2[i], param_list_2[i]) 63 | torch.cuda.synchronize() 64 | _ = linear(input_list_1[i], param_list_1[i], algo) 65 | torch.cuda.synchronize() 66 | 67 | torch.cuda.synchronize() 68 | start_time = time.time() 69 | _ = torch.nn.functional.linear(input_list_2[i], param_list_2[i]) 70 | torch.cuda.synchronize() 71 | torch_count += time.time() - start_time 72 | 73 | torch.cuda.synchronize() 74 | start_time = time.time() 75 | _ = linear(input_list_1[i], param_list_1[i], algo) 76 | torch.cuda.synchronize() 77 | cublas_count += time.time() - start_time 78 | 79 | torch_time = torch_count / inner_loop / outer_loop 80 | cublas_time = (cublas_count / inner_loop / outer_loop) 81 | print("==> torch time: %.6f" % torch_time) 82 | print("==> cublas time: %.6f, speedup: %.4f%%" % (cublas_time, (torch_time - cublas_time) / torch_time * 100)) 83 | 84 | 85 | if __name__ == '__main__': 86 | test_linear_func() 87 | benchmark_linear_func() 88 | -------------------------------------------------------------------------------- /tests/test_kernel/test_transpose_pad_fusion_kernel.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import random 3 | import pytest 4 | from energonai.kernel import transpose_pad, transpose_depad 5 | 6 | seq_lens = torch.tensor([24,127,31,65,24,127,31,65], dtype=torch.int64).cuda() 7 | batch_size = 8 8 | max_padding_size = 128 9 | head_size = 64 10 | head_num = 12 11 | hidden_size = head_num * head_size 12 | 13 | 14 | def seq_init(x): 15 | # for i in range(batch_size): 16 | # len = seq_lens[i] 17 | # for j in range (max_padding_size): 18 | # for k in range(hidden_size): 19 
| # if(j < len): 20 | # x[i][j][k] = random.random() 21 | # else: 22 | # x[i][j][k] = 0.0 23 | # vectorized equivalent of the commented-out loops above: only the valid (non-padded) positions are filled 24 | for i in range(batch_size): 25 | x[i, :seq_lens[i], :] = torch.rand(int(seq_lens[i]), hidden_size).cuda() 26 | return x 27 | 28 | 29 | def manual_depad(x): 30 | # reference depadding: concatenate the valid tokens of every sequence, dropping the padding 31 | return torch.cat([x[i, :seq_lens[i], :] for i in range(batch_size)], dim=0) 32 | 33 | 34 | def reshape(x): 35 | # split the hidden dimension into (head_num, head_size) 36 | return x.view(x.shape[:-1] + (head_num, head_size)) 37 | 38 | 39 | def compare(ta, tb): 40 | # element-wise comparison with a 1e-3 tolerance 41 | tta = ta.flatten() 42 | ttb = tb.flatten() 43 | assert tta.shape[0] == ttb.shape[0] 44 | for i in range(tta.shape[0]): 45 | if(abs(tta[i] - ttb[i]) > 0.001): 46 | print(i) 47 | print(tta[i]) 48 | print(ttb[i]) 49 | return False 50 | return True 51 | 52 | 53 | def test_kernel(): 54 | # original 55 | hidden_states = torch.zeros(batch_size, max_padding_size, head_num*head_size).cuda().float() 56 | hidden_states = seq_init(hidden_states) 57 | input_pad = reshape(hidden_states) 58 | res_original_pad = input_pad.permute(0,2,1,3) 59 | 60 | # transpose_pad 61 | hidden_states_depad = manual_depad(hidden_states) 62 | input_depad = reshape(hidden_states_depad) 63 | res_transpose_pad = transpose_pad(input_depad, batch_size, max_padding_size, seq_lens, head_num, head_size) 64 | assert compare(res_transpose_pad, res_original_pad) == True, "transpose_pad fault." 65 | 66 | # transpose_depad 67 | sum_seq = torch.sum(seq_lens) 68 | res_transpose_depad = transpose_depad(res_original_pad, batch_size, sum_seq, max_padding_size, seq_lens, head_num, head_size) 69 | assert compare(input_depad, res_transpose_depad) == True, "transpose_depad fault." 70 | 71 | 72 | if __name__ == '__main__': 73 | test_kernel() -------------------------------------------------------------------------------- /version.txt: -------------------------------------------------------------------------------- 1 | 0.0.1 --------------------------------------------------------------------------------