├── .gitignore ├── CODEOWNERS ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── NOTICE ├── README.md ├── releasenotes.md ├── setup.py └── src └── transformers_neuronx ├── LICENSE ├── __init__.py ├── activations.py ├── base.py ├── bloom ├── config.py ├── hlo.py ├── model.py └── modules.py ├── bucket.py ├── compiler.py ├── config.py ├── constants.py ├── decoder.py ├── decoder_topk.py ├── dtypes.py ├── eagle_speculation.py ├── fused_speculation.py ├── generation_demo.py ├── generation_utils.py ├── global_debugger.py ├── gpt2 ├── config.py ├── demo.py ├── gen_random_pretrained.py ├── hlo.py └── model.py ├── gpt_demo.py ├── gptj ├── config.py ├── demo.py ├── hlo.py └── model.py ├── gptneox ├── config.py ├── demo.py ├── hlo.py └── model.py ├── hlo.py ├── kv_cache_manager.py ├── layers ├── alibi.py ├── attention.py ├── attention_utils.py ├── flash_decoding.py ├── generation.py ├── masking.py ├── rotary.py └── transformer.py ├── llama ├── aem.py ├── config.py ├── hlo.py ├── model.py └── modules.py ├── mistral ├── config.py ├── hlo.py ├── model.py └── modules.py ├── mixtral ├── config.py ├── hlo.py ├── model.py └── modules.py ├── modeling_auto.py ├── module.py ├── nki └── compile.py ├── ops.py ├── opt ├── config.py ├── demo.py ├── gen_random_pretrained.py ├── hlo.py └── model.py ├── pad ├── checkpoint.py └── layernorm_padded_cpu.py ├── parallel.py ├── program.py ├── quantize.py ├── sampling.py ├── sparse_attn_utils.py ├── speculation.py ├── stopping_criteria.py ├── tensor_pool.py ├── testing ├── data.py └── validation.py ├── tools ├── ckpt_converter.py └── gen_hlo_snapshot.py ├── util └── token_tree.py ├── utils.py └── version.py /.gitignore: -------------------------------------------------------------------------------- 1 | build 2 | *.egg-info/ 3 | dist/ 4 | apt/ 5 | pip/ 6 | .attach_pid* 7 | __pycache__ 8 | copy.bara.tmp 9 | -------------------------------------------------------------------------------- /CODEOWNERS: -------------------------------------------------------------------------------- 1 | # This file creates codeowners for the documentation. It will allow setting code reviewers for all Pull requests to merge to the master branch 2 | # Each line is a file pattern followed by one or more owners. 3 | 4 | # Refernce guide - https://docs.github.com/en/github/creating-cloning-and-archiving-repositories/creating-a-repository-on-github/about-code-owners#example-[…]ners-file 5 | # Example - These owners will be the default owners for everything in 6 | # the repo. Unless a later match takes precedence, 7 | # @global-owner1 and @global-owner2 will be requested for 8 | # review when someone opens a pull request. 9 | # * @global-owner1 @global-owner2 10 | 11 | * @aws-maens @aws-mesharma @musunita 12 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | ## Code of Conduct 2 | This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct). 3 | For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact 4 | opensource-codeofconduct@amazon.com with any additional questions or comments. 5 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing Guidelines 2 | 3 | Thank you for your interest in contributing to our project. 
Whether it's a bug report, new feature, correction, or additional 4 | documentation, we greatly value feedback and contributions from our community. 5 | 6 | Please read through this document before submitting any issues or pull requests to ensure we have all the necessary 7 | information to effectively respond to your bug report or contribution. 8 | 9 | 10 | ## Reporting Bugs/Feature Requests 11 | 12 | We welcome you to use the GitHub issue tracker to report bugs or suggest features. 13 | 14 | When filing an issue, please check existing open, or recently closed, issues to make sure somebody else hasn't already 15 | reported the issue. Please try to include as much information as you can. Details like these are incredibly useful: 16 | 17 | * A reproducible test case or series of steps 18 | * The version of our code being used 19 | * Any modifications you've made relevant to the bug 20 | * Anything unusual about your environment or deployment 21 | 22 | 23 | ## Contributing via Pull Requests 24 | Contributions via pull requests are much appreciated. Before sending us a pull request, please ensure that: 25 | 26 | 1. You are working against the latest source on the *main* branch. 27 | 2. You check existing open, and recently merged, pull requests to make sure someone else hasn't addressed the problem already. 28 | 3. You open an issue to discuss any significant work - we would hate for your time to be wasted. 29 | 30 | To send us a pull request, please: 31 | 32 | 1. Fork the repository. 33 | 2. Modify the source; please focus on the specific change you are contributing. If you also reformat all the code, it will be hard for us to focus on your change. 34 | 3. Ensure local tests pass. 35 | 4. Commit to your fork using clear commit messages. 36 | 5. Send us a pull request, answering any default questions in the pull request interface. 37 | 6. Pay attention to any automated CI failures reported in the pull request, and stay involved in the conversation. 38 | 39 | GitHub provides additional document on [forking a repository](https://help.github.com/articles/fork-a-repo/) and 40 | [creating a pull request](https://help.github.com/articles/creating-a-pull-request/). 41 | 42 | 43 | ## Finding contributions to work on 44 | Looking at the existing issues is a great way to find something to contribute on. As our projects, by default, use the default GitHub issue labels (enhancement/bug/duplicate/help wanted/invalid/question/wontfix), looking at any 'help wanted' issues is a great place to start. 45 | 46 | 47 | ## Code of Conduct 48 | This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct). 49 | For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact 50 | opensource-codeofconduct@amazon.com with any additional questions or comments. 51 | 52 | 53 | ## Security issue notifications 54 | If you discover a potential security issue in this project we ask that you notify AWS/Amazon Security via our [vulnerability reporting page](http://aws.amazon.com/security/vulnerability-reporting/). Please do **not** create a public github issue. 55 | 56 | 57 | ## Licensing 58 | 59 | See the [LICENSE](LICENSE) file for our project's licensing. We will ask you to confirm the licensing of your contribution. 60 | -------------------------------------------------------------------------------- /NOTICE: -------------------------------------------------------------------------------- 1 | Copyright Amazon.com, Inc. or its affiliates. 
All Rights Reserved. 2 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | Transformers Neuron for Trn1 and Inf2 is a software package that enables 2 | PyTorch users to perform large language model (LLM) inference on 3 | second-generation Neuron hardware (See: [NeuronCore-v2](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/general/arch/neuron-hardware/neuron-core-v2.html)). 4 | 5 | # Transformers Neuron (``transformers-neuronx``) Documentation 6 | Please refer to the [Transformers Neuron documentation](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/libraries/transformers-neuronx/) for setup and developer guides. 7 | 8 | # Installation 9 | 10 | ## Stable Release 11 | 12 | To install the most rigorously tested stable release, use the PyPI pip wheel: 13 | 14 | ``` 15 | pip install transformers-neuronx --extra-index-url=https://pip.repos.neuron.amazonaws.com 16 | ``` 17 | 18 | ## Development Version 19 | 20 | The AWS Neuron team is currently restructuring the contribution model of this GitHub repository. This GitHub repository content 21 | does not reflect the latest features and improvements of the transformers-neuronx library. Please install the stable release version 22 | from https://pip.repos.neuron.amazonaws.com to get the latest features and improvements. 23 | 24 | # Release Notes and Supported Models 25 | 26 | Please refer to the [transformers-neuronx release notes](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/release-notes/torch/transformers-neuronx/index.html) to see the latest supported features and models. 27 | 28 | 29 | # Troubleshooting 30 | 31 | Please refer to our [Contact 32 | Us](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/general/contact.html) 33 | page for additional information and support resources. If you intend to 34 | file a ticket and you can share your model artifacts, please re-run your 35 | failing script with ``NEURONX_DUMP_TO=./some_dir``. This will dump 36 | compiler artifacts and logs to ``./some_dir``. You can then include this 37 | directory in your correspondence with us. The artifacts and logs are 38 | useful for debugging the specific failure. 39 | 40 | # Security 41 | 42 | See [CONTRIBUTING](CONTRIBUTING.md#security-issue-notifications) for more information. 43 | 44 | # License 45 | 46 | This library is licensed under the Apache License 2.0. 47 | -------------------------------------------------------------------------------- /releasenotes.md: -------------------------------------------------------------------------------- 1 | # Transformers Neuron 0.5.0 Release Notes 2 | 3 | Date: 2023-07-03 4 | 5 | ## What's New? 6 | 7 | - [Experimental] Added support for GPT-NeoX models. 8 | - [Experimental] Added support for BLOOM models. 9 | - [Prototype] Added support for LLaMA models. 10 | - Added support for more flexible tensor-parallel configurations to GPT2, OPT, and BLOOM. The number of attention heads no longer needs to be evenly divisible by `tp_degree` (see the usage sketch after this list). (Note: The `tp_degree` still needs to satisfy the runtime topology constraints for collective communication (i.e. AllReduce). For more details on supported topologies, see: [Tensor-parallelism support](README.md#tensor-parallelism-support) and https://awsdocs-neuron.readthedocs-hosted.com/en/latest/general/arch/neuron-features/collective-communication.html.) 11 | - Added multi-query / multi-group attention support for GPT2.
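A minimal usage sketch of the more flexible tensor-parallel loading mentioned above. The model name, `tp_degree`, and sampling arguments are illustrative rather than prescriptive; the call flow follows the library's `from_pretrained` / `to_neuron` / `sample` interface:

```python
import torch
from transformers import AutoTokenizer
from transformers_neuronx.gpt2.model import GPT2ForSampling

# Illustrative choice: gpt2 has 12 attention heads, so tp_degree=8 does not
# divide them evenly, but it is a supported collective-communication topology.
model = GPT2ForSampling.from_pretrained('gpt2', batch_size=1, amp='f16', tp_degree=8)
model.to_neuron()  # shard the weights across NeuronCores and compile

tokenizer = AutoTokenizer.from_pretrained('gpt2')
input_ids = tokenizer("Hello, I am a language model,", return_tensors='pt').input_ids
with torch.inference_mode():
    generated = model.sample(input_ids, sequence_length=128, top_k=50)
print(tokenizer.batch_decode(generated, skip_special_tokens=True))
```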
12 | 13 | ## Bug Fixes 14 | 15 | - Fixed NaN issues for the GPT2 model. 16 | - Fixed OPT/GPT-NeoX gibberish output. 17 | - Resolved an issue where NaN values could be produced when the context_length argument was used in GPT2/OPT. 18 | 19 | ## Known Issues and Limitations 20 | 21 | - Missing cache reorder support for beam search. 22 | 23 | # Transformers Neuron 0.4.0 Release Notes 24 | 25 | Date: 2023-06-12 26 | 27 | ## What's New? 28 | 29 | - Added ``int8`` weight storage for `GPT2` models. 30 | - Improved prompt context encoding performance for `GPT2` models. 31 | - Improved collective communications performance for tp-degrees 4, 8, and 24 on Inf2. 32 | - Improved collective communications performance for tp-degrees 8 and 32 on Trn1. 33 | - Support for the ``--model-type=transformer-inference`` compiler flag for optimized decoder-only LLM inference. 34 | 35 | ## Bug Fixes 36 | 37 | - Added padding to the `GPT-J` ``linear`` layer to correctly handle odd vocabulary sizes. 38 | - Issues where the HuggingFace `generate` method produces incorrect results when 39 | `beam_search` is used have been resolved. 40 | 41 | 42 | # Transformers Neuron 0.3.0 Release Notes 43 | 44 | Date: 2023-04-28 45 | 46 | ## What's New? 47 | 48 | - Added ``transformers-neuronx`` artifacts to the PyPI repository. 49 | - Added support for the Hugging Face [generate()](https://huggingface.co/docs/transformers/v4.28.1/en/main_classes/text_generation#transformers.GenerationMixin.generate) API. 50 | - Added support for model serialization, including model saving, loading, and 51 | weight swapping. 52 | - Added support for caching compiled artifacts. 53 | - Improved performance by removing unnecessary KV-cache tensor resetting. 54 | - Improved prompt context encoding performance (`OPT`, `GPT2`). 55 | 56 | ## Bug Fixes 57 | 58 | - Incorrect `GPT-J` ``amp_callback`` import: Fixed so that the `GPT-J` demo now imports the correct ``amp_callback`` function. 59 | 60 | ## Known Issues and Limitations 61 | 62 | Incorrect output with HuggingFace `beam_search`: When the HuggingFace `generate` method is configured to use `beam_search`, it 63 | can produce incorrect results for certain configurations. It is recommended to 64 | use other generation methods such as `sample` or `greedy_search`. 65 | 66 | 67 | # Transformers Neuron 0.2.0 Release Notes 68 | 69 | Date: 2023-02-24 70 | 71 | ## What's New? 72 | 73 | - Added error handling to check if the desired generated sequence length is valid based on the model configuration. 74 | - Improved logging: 75 | - Reduced overly verbose compiler messages 76 | - Disabled lazy module warnings 77 | 78 | ## Bug Fixes 79 | 80 | - Updated `src/transformers_neuronx/gptj/demo.py` to correctly use the `amp_callback` function from `transformers_neuronx.gpt2.demo`. 81 | - Extended the `gpt_demo.py` `save` function to support GPT-2 and GPT-J configs. 82 | 83 | # Transformers Neuron 0.1.0 Release Notes 84 | 85 | Date: 2023-02-08 86 | 87 | First release of `transformers-neuronx`, a new library that enables LLM inference on Inf2 & Trn1 using the Neuron SDK. `transformers-neuronx` contains optimized model implementations that are checkpoint-compatible with HuggingFace Transformers, and currently supports transformer decoder models like GPT2, GPT-J, and OPT. -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon Web Services and its Affiliates. All Rights Reserved.
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | import os 16 | from datetime import datetime 17 | from setuptools import setup, PEP420PackageFinder 18 | 19 | 20 | def version_py_path(): 21 | return os.path.join(os.path.dirname(__file__), 'src', 'transformers_neuronx', 'version.py') 22 | 23 | 24 | exec(open(version_py_path()).read()) 25 | 26 | 27 | def get_version(): 28 | # please make sure the major.minor version matches the interface version in Config 29 | version = os.environ.get('TRANSFORMERS_NEURONX_VERSION', __version__) 30 | today = datetime.today().strftime('%Y%m%d') 31 | return version.replace('.x', f'.{today}') 32 | 33 | 34 | setup( 35 | name='transformers-neuronx', 36 | version=get_version(), 37 | classifiers=[ 38 | 'Development Status :: 3 - Alpha', 39 | 'Intended Audience :: Developers', 40 | 'Intended Audience :: Education', 41 | 'Intended Audience :: Science/Research', 42 | 'License :: OSI Approved :: Apache Software License', 43 | 'Programming Language :: Python :: 3', 44 | 'Topic :: Scientific/Engineering', 45 | 'Topic :: Scientific/Engineering :: Mathematics', 46 | 'Topic :: Scientific/Engineering :: Artificial Intelligence', 47 | 'Topic :: Software Development', 48 | 'Topic :: Software Development :: Libraries', 49 | 'Topic :: Software Development :: Libraries :: Python Modules', 50 | ], 51 | keywords='aws neuron neuronx transformers', 52 | packages=PEP420PackageFinder.find(where='src'), 53 | url='https://github.com/aws-neuron/transformers-neuronx', 54 | maintainer='Amazon Web Services, Inc.', 55 | description='Transformers Neuron for Trn1 and Inf2 is a software package that enables PyTorch users to perform large language model (LLM) inference on second-generation Neuron hardware', 56 | license='Apache License', 57 | license_files = ('LICENSE',), 58 | entry_points = { 59 | 'console_scripts': [ 60 | 'generation_demo=transformers_neuronx.generation_demo:main', 61 | 'gpt2_demo=transformers_neuronx.gpt2.demo:main', 62 | 'gptj_demo=transformers_neuronx.gptj.demo:main', 63 | 'gptneox_demo=transformers_neuronx.gptneox.demo:main', 64 | 'opt_demo=transformers_neuronx.opt.demo:main', 65 | 'opt_gen_random_pretrained=transformers_neuronx.opt.gen_random_pretrained:main', 66 | 'gen_randn_hlo_snapshot=transformers_neuronx.tools.gen_hlo_snapshot:main_randn', 67 | 'ckpt_converter=transformers_neuronx.tools.ckpt_converter:ckpt_converter', 68 | ], 69 | }, 70 | install_requires=[ 71 | 'accelerate', 72 | 'safetensors', 73 | 'torch-neuronx', 74 | 'transformers>=4.36', 75 | ], 76 | python_requires='>=3.7', 77 | package_dir={'': 'src'}, 78 | ) 79 | -------------------------------------------------------------------------------- /src/transformers_neuronx/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon Web Services and its Affiliates. All Rights Reserved. 
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | from transformers_neuronx.version import __version__ 16 | 17 | 18 | from transformers_neuronx.constants import GQA, Layout 19 | from transformers_neuronx.sparse_attn_utils import SparseAttnConfig 20 | from transformers_neuronx.config import NeuronConfig, QuantizationConfig, ContinuousBatchingConfig, GenerationConfig 21 | from transformers_neuronx.generation_utils import HuggingFaceGenerationModelAdapter 22 | 23 | from transformers_neuronx.bloom.model import BloomForSampling 24 | from transformers_neuronx.llama.model import LlamaForSampling 25 | from transformers_neuronx.gpt2.model import GPT2ForSamplingWithContextBroadcasting 26 | from transformers_neuronx.gptneox.model import GPTNeoXForSampling 27 | from transformers_neuronx.gptj.model import GPTJForSampling 28 | from transformers_neuronx.mistral.model import MistralForSampling 29 | from transformers_neuronx.mixtral.model import MixtralForSampling 30 | from transformers_neuronx.opt.model import OPTForSampling 31 | 32 | from transformers_neuronx.modeling_auto import NeuronAutoModelForCausalLM 33 | 34 | from . import testing 35 | -------------------------------------------------------------------------------- /src/transformers_neuronx/activations.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon Web Services and its Affiliates. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | import math 16 | 17 | def gelu_new(hidden): 18 | return hidden.dtype[hidden.sizes].CustomCall(hidden, custom_call_target="AwsNeuronGeluApprxTanh") 19 | 20 | def gelu_new_legacy(hidden): 21 | dtype = hidden.dtype 22 | sizes = hidden.sizes 23 | input_input = dtype[sizes].Multiply(hidden, hidden) 24 | input_pow_3 = dtype[sizes].Multiply(input_input, hidden) 25 | scale = dtype.Constant(constant_value=0.044715) 26 | scale_br = dtype[sizes].Broadcast(scale, dimensions=[]) 27 | mul = dtype[sizes].Multiply(input_pow_3, scale_br) 28 | add = dtype[sizes].Add(mul, hidden) 29 | sqrt_2_over_pi = dtype.Constant(constant_value=math.sqrt(2.0 / math.pi)) 30 | sqrt_2_over_pi_br = dtype[sizes].Broadcast(sqrt_2_over_pi, dimensions=[]) 31 | mul2 = dtype[sizes].Multiply(add, sqrt_2_over_pi_br) 32 | tanh = dtype[sizes].Tanh(mul2) 33 | one = dtype.Constant(constant_value=1.0) 34 | one_br = dtype[sizes].Broadcast(one, dimensions=[]) 35 | add1 = dtype[sizes].Add(tanh, one_br) 36 | mul3 = dtype[sizes].Multiply(add1, hidden) 37 | half = dtype.Constant(constant_value=0.5) 38 | half_br = dtype[sizes].Broadcast(half, dimensions=[]) 39 | output = dtype[sizes].Multiply(mul3, half_br) 40 | return output 41 | 42 | 43 | def relu(hidden): 44 | dtype = hidden.dtype 45 | sizes = hidden.sizes 46 | zero = dtype.Constant(constant_value=0.0) 47 | zero_br = dtype[sizes].Broadcast(zero, dimensions=[]) 48 | return dtype[sizes].Maximum(hidden, zero_br) 49 | 50 | def softmax(logits, dim=None): 51 | rank = len(logits.sizes) 52 | if dim is None: 53 | dim = rank - 1 54 | shape = logits.sizes 55 | dtype = logits.dtype 56 | backend_config = str(dim).encode() 57 | return dtype[shape].CustomCall(logits, custom_call_target="AwsNeuronSoftmax", backend_config=backend_config,) 58 | 59 | def solu(hidden, dim=None): 60 | dtype = hidden.dtype 61 | sizes = hidden.sizes 62 | softmax_hidden = softmax(hidden, dim) 63 | output = dtype[sizes].Multiply(hidden,softmax_hidden) 64 | return output 65 | 66 | def sigmoid(tensor): 67 | return tensor.dtype[tensor.sizes].Logistic(tensor) 68 | 69 | 70 | def silu(tensor): 71 | logistic = sigmoid(tensor) 72 | return tensor.dtype[tensor.sizes].Multiply(tensor, logistic) 73 | -------------------------------------------------------------------------------- /src/transformers_neuronx/bloom/config.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon Web Services and its Affiliates. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | from transformers_neuronx import utils 16 | 17 | class BloomConfig: 18 | 19 | def __init__( 20 | self, 21 | config, 22 | n_positions, 23 | batch_size, 24 | amp, 25 | tp_degree, 26 | **kwargs 27 | ): 28 | 29 | # Extract configs used for building HLO 30 | self.embed_dim = config.hidden_size 31 | self.hidden_size = config.hidden_size 32 | self.attention_head_size = config.hidden_size // config.n_head 33 | self.n_head = config.n_head 34 | self.n_layer = config.num_hidden_layers 35 | self.vocab_size = config.vocab_size 36 | self.layer_norm_epsilon = config.layer_norm_epsilon 37 | self.bos_token_id = config.bos_token_id 38 | self.eos_token_id = config.eos_token_id 39 | 40 | utils.maybe_override_attributes(self, kwargs) 41 | 42 | # Add required Neuron configs 43 | self.n_positions = n_positions 44 | self.batch_size = batch_size 45 | self.amp = amp 46 | self.tp_degree = tp_degree 47 | self.model_type = 'bloom' 48 | -------------------------------------------------------------------------------- /src/transformers_neuronx/bloom/hlo.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon Web Services and its Affiliates. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | from transformers_neuronx import hlo 16 | from transformers_neuronx.constants import LAYOUT_BSH 17 | from transformers_neuronx.layers import transformer, alibi, attention 18 | from transformers_neuronx.bloom.config import BloomConfig 19 | 20 | class BloomForSamplingNoEmbeddingHlo: 21 | 22 | def __init__(self, config: BloomConfig, neuron_config=None): 23 | self.config = config 24 | self.neuron_config = neuron_config 25 | self.n_positions = None 26 | 27 | def inputs(self, scribe, dtype, n_active_tokens, batch_size): 28 | tensors, dims = transformer.inputs( 29 | scribe, dtype, batch_size, n_active_tokens, self.config.hidden_size, self.neuron_config 30 | ) 31 | return tensors, dims 32 | 33 | def embedding(self, input_ids, cache_ids, start_ids, last_token_id, block_tables, context_lens, slopes, word_embeddings, ln_weight, ln_bias): 34 | dtype = getattr(input_ids.scribe, self.config.amp) 35 | hidden = hlo.embedding(word_embeddings, input_ids, tp_degree=self.config.tp_degree, dtype=dtype) 36 | if self.config.hidden_size % self.config.tp_degree != 0: 37 | hidden = hlo.slice_along(hidden, dim=-1, limit=self.config.hidden_size, start=0) 38 | is_bsh = self.neuron_config.attention_layout == LAYOUT_BSH 39 | if not is_bsh: 40 | hidden = hlo.transpose210(hidden) 41 | return hlo.layer_norm_bsh(hidden, ln_weight, ln_bias) if is_bsh \ 42 | else hlo.layer_norm(hidden, ln_weight, ln_bias) 43 | 44 | def pre_layer(self, hidden, cache_ids, start_ids, last_token_id, block_tables, context_lens, *pre_layer_weights): 45 | slopes, *rest = pre_layer_weights 46 | mask, active_mask = hlo.attention_mask(cache_ids, start_ids, self.n_positions, 47 | last_token_id=last_token_id, neuron_config=self.neuron_config) 48 | prior_alibi, active_alibi = alibi.alibi(slopes, mask, active_mask) 49 | return hidden, last_token_id, cache_ids, mask, active_mask, prior_alibi, active_alibi 50 | 51 | def layer(self, hidden, last_token_id, cache_ids, mask, active_mask, prior_alibi, active_alibi, attn_k_cache, attn_v_cache, 52 | pre_attn_ln_weight, pre_attn_ln_bias, 53 | attn_q_weight, attn_q_scales, attn_q_bias, 54 | attn_k_weight, attn_k_scales, attn_k_bias, 55 | attn_v_weight, attn_v_scales, attn_v_bias, 56 | attn_out_weight, attn_out_scales, attn_out_bias, 57 | post_attn_ln_weight, post_attn_ln_bias, 58 | pre_mlp_ln_weight, pre_mlp_ln_bias, 59 | mlp_in_weight, mlp_in_scales, mlp_in_bias, 60 | mlp_out_weight, mlp_out_scales, mlp_out_bias, 61 | post_mlp_ln_weight, post_mlp_ln_bias): 62 | 63 | dtype = hidden.dtype 64 | is_bsh = self.neuron_config and self.neuron_config.attention_layout == LAYOUT_BSH 65 | layer_norm = hlo.layer_norm_bsh if is_bsh else hlo.layer_norm 66 | ln_hidden = layer_norm(hidden, pre_attn_ln_weight, pre_attn_ln_bias) 67 | attn_output, out_attn_k_cache, out_attn_v_cache = self.attention( 68 | ln_hidden, cache_ids, mask, active_mask, prior_alibi, active_alibi, 69 | attn_k_cache, attn_v_cache, 70 | attn_q_weight, attn_q_scales, attn_q_bias, 71 | attn_k_weight, attn_k_scales, attn_k_bias, 72 | attn_v_weight, attn_v_scales, attn_v_bias, 73 | attn_out_weight, attn_out_scales, attn_out_bias, 74 | neuron_config=self.neuron_config 75 | ) 76 | hidden = dtype[hidden.sizes].Add(attn_output, hidden) 77 | ln_hidden = layer_norm(hidden, pre_mlp_ln_weight, pre_mlp_ln_bias) 78 | mlp = hlo.mlp_bsh if is_bsh else hlo.mlp 79 | mlp_hidden = mlp( 80 | ln_hidden, 81 | mlp_in_weight, mlp_in_bias, mlp_out_weight, mlp_out_bias, 82 | 
activation_function='gelu_new', 83 | tp_degree=self.config.tp_degree, 84 | in_scales=mlp_in_scales, out_scales=mlp_out_scales, 85 | neuron_config=self.neuron_config, 86 | transposed=True, 87 | ) 88 | 89 | hidden = dtype[hidden.sizes].Add(mlp_hidden, hidden) 90 | return hidden, out_attn_k_cache, out_attn_v_cache 91 | 92 | def ln_lm_head(self, hidden, last_token_id, ln_f_weight, ln_f_bias, lm_head_weight, lm_head_bias, return_all_outputs=True): 93 | logits = transformer.ln_lm_head(self.config.tp_degree, hidden, last_token_id, ln_f_weight, ln_f_bias, lm_head_weight, 94 | lm_head_bias, return_all_outputs, neuron_config=self.neuron_config) 95 | return logits 96 | 97 | def attention(self, 98 | hidden, cache_ids, mask, active_mask, prior_alibi, active_alibi, 99 | cached_keys, cached_values, 100 | q_weight, q_scales, q_bias, 101 | k_weight, k_scales, k_bias, 102 | v_weight, v_scales, v_bias, 103 | out_weight, out_scales, out_bias, 104 | neuron_config=None 105 | ): 106 | scribe = hidden.scribe 107 | f32 = scribe.f32 108 | dtype = hidden.dtype 109 | d_head = self.config.hidden_size // self.config.n_head 110 | 111 | # Q = (hidden @ wQ) + bQ 112 | # K = (hidden @ wK) + bK 113 | # V = (hidden @ wV) + bV 114 | query, key, value = attention.query_key_value( 115 | hidden, 116 | q_weight, q_scales, q_bias, 117 | k_weight, k_scales, k_bias, 118 | v_weight, v_scales, v_bias, 119 | d_head, 120 | neuron_config=neuron_config, 121 | ) 122 | 123 | # Q = Q / sqrt(d_head) 124 | query = attention.scale(query, d_head) 125 | 126 | # Single Token Generation ("Prefetch"-style) 127 | if active_mask is not None: 128 | 129 | # This is an optimized `context` computation for a single token. 130 | # It uses a split prior/active key & value tensors. This allows 131 | # the KV cache updates to occur asynchronously with the attention 132 | # computation which improves performance on Neuron hardware. 133 | 134 | # Sp = Q @ Kp + Ap 135 | prior_scores = attention.score(query, cached_keys) 136 | prior_scores = f32[prior_scores.sizes].Convert(prior_scores) 137 | prior_scores = f32[prior_scores.sizes].Add(prior_scores, prior_alibi) 138 | prior_scores = attention.mask(prior_scores, mask) 139 | prior_scores = hlo.cast(prior_scores, dtype) 140 | 141 | # Sa = Q @ Ka + Aa 142 | active_score = attention.score(query, key) 143 | active_score = f32[active_score.sizes].Convert(active_score) 144 | active_score = f32[active_score.sizes].Add(active_score, active_alibi) 145 | active_mask_sh = hlo.unsqueeze(active_mask, 1) 146 | active_score = attention.mask(active_score, active_mask_sh) 147 | active_score = hlo.cast(active_score, dtype) 148 | 149 | # C = softmax(Sa, Sp) @ (Va, Vp) 150 | context = attention.context(prior_scores, active_score, cached_values, value) 151 | 152 | # KCache[I] = K 153 | # VCache[I] = V 154 | updated_keys = attention.update_cache(cached_keys, cache_ids, key) 155 | updated_values = attention.update_cache(cached_values, cache_ids, value) 156 | 157 | # Multi-Token Context Encoding 158 | else: 159 | 160 | # This `context` computation block is intended for populating the 161 | # KV cache with multiple `n_active_tokens` tokens. This assumes 162 | # that there is no prior history so it skips any computation 163 | # performed on the cache. 
164 | 165 | # S = Q @ K + A 166 | score = attention.score(query, key) 167 | score = f32[score.sizes].Convert(score) 168 | score = f32[score.sizes].Add(score, prior_alibi) 169 | score = attention.mask(score, mask) 170 | score = hlo.cast(score, dtype) 171 | 172 | # C = softmax(S) @ V 173 | context = attention.context_combined(score, value) 174 | 175 | # KCache = K 176 | # VCache = V 177 | updated_keys = key 178 | updated_values = value 179 | 180 | # O = (C @ wO) + bO 181 | output = attention.output(context, out_weight, out_scales, out_bias, self.config.tp_degree, neuron_config) 182 | return output, updated_keys, updated_values 183 | -------------------------------------------------------------------------------- /src/transformers_neuronx/bloom/modules.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon Web Services and its Affiliates. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | from transformers_neuronx import dtypes 16 | from transformers_neuronx import module 17 | from transformers_neuronx import utils 18 | 19 | 20 | class BloomForCausalLM(module.PretrainedModel): 21 | 22 | def __init__(self, config): 23 | super().__init__() 24 | dtype, _, _ = utils.parse_amp(config.amp) 25 | dtype = dtypes.to_torch_dtype(dtype) 26 | self.transformer = BloomModel(config) 27 | self.lm_head = module.LowMemoryLazyLinear(config.vocab_size, dtype=dtype, bias=False) 28 | 29 | def get_tied_parameters(self): 30 | return [(self.transformer.word_embeddings.weight, self.lm_head.weight)] 31 | 32 | def get_base_model(self): 33 | return self.transformer 34 | 35 | class BloomModel(module.LowMemoryModule): 36 | 37 | def __init__(self, config): 38 | super().__init__() 39 | self.word_embeddings = module.LowMemoryEmbedding(config.vocab_size, config.embed_dim) 40 | self.word_embeddings_layernorm = module.LowMemoryLayerNorm(config.embed_dim, eps=config.layer_norm_epsilon) 41 | self.h = module.LowMemoryModuleList([BloomBlock(config) for _ in range(config.n_layer)]) 42 | self.ln_f = module.LowMemoryLayerNorm(config.embed_dim, eps=config.layer_norm_epsilon) 43 | 44 | 45 | class BloomBlock(module.LowMemoryModule): 46 | 47 | def __init__(self, config): 48 | super().__init__() 49 | self.input_layernorm = module.LowMemoryLayerNorm(config.hidden_size, eps=config.layer_norm_epsilon) 50 | self.self_attention = BloomAttention(config) 51 | self.post_attention_layernorm = module.LowMemoryLayerNorm(config.hidden_size, eps=config.layer_norm_epsilon) 52 | self.mlp = BloomMLP(config) 53 | 54 | 55 | class BloomAttention(module.LowMemoryModule): 56 | 57 | def __init__(self, config): 58 | super().__init__() 59 | dtype, _, _ = utils.parse_amp(config.amp) 60 | dtype = dtypes.to_torch_dtype(dtype) 61 | self.query_key_value = module.LowMemoryLazyLinear(3 * config.hidden_size, dtype=dtype, bias=True) 62 | self.dense = 
module.LowMemoryLazyLinear(config.hidden_size, dtype=dtype) 63 | 64 | 65 | class BloomMLP(module.LowMemoryModule): 66 | 67 | def __init__(self, config): 68 | super().__init__() 69 | dtype, _, _ = utils.parse_amp(config.amp) 70 | dtype = dtypes.to_torch_dtype(dtype) 71 | self.dense_h_to_4h = module.LowMemoryLazyLinear(4 * config.hidden_size, dtype=dtype) 72 | self.dense_4h_to_h = module.LowMemoryLazyLinear(config.embed_dim, dtype=dtype) 73 | -------------------------------------------------------------------------------- /src/transformers_neuronx/bucket.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon Web Services and its Affiliates. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | import bisect 16 | from typing import List, Union, Optional 17 | 18 | from transformers_neuronx import utils 19 | 20 | 21 | def token_sizes(buckets_or_size: Union[List[int], int]) -> List[int]: 22 | """ 23 | Compute the bucket sizes for token generation. 24 | 25 | The `buckets_or_size` argument may be specified as a list of buckets and 26 | any other logic in this function will be ignored. 27 | 28 | When `buckets_or_size` is an integer value, buckets will be chosen by 29 | doubling the size of each bucket starting at 128 up to the provided value. 30 | 31 | Arguments: 32 | buckets_or_size: A list of buckets or maximum size for token generation. 33 | 34 | Returns 35 | buckets: The list of bucket sizes for token generation. 36 | """ 37 | if isinstance(buckets_or_size, list): 38 | return sorted(buckets_or_size) 39 | return utils.power_of_two_bucket_sizes(128, buckets_or_size) 40 | 41 | 42 | def context_sizes( 43 | buckets_or_size: Optional[Union[List[int], int]], 44 | token_buckets: Optional[List[int]] = None, 45 | ) -> List[int]: 46 | """ 47 | Compute the bucket sizes for context encoding. 48 | 49 | The `buckets_or_size` argument may be specified as a list of buckets or 50 | a single bucket and any other logic in this function will be ignored. 51 | 52 | When `bucket_or_size` is None, the default context length bucket sizes 53 | are set to be half of token model bucket sizes. 54 | 55 | When `bucket_or_size` is set to 0 this completely disables context 56 | encoding. 57 | 58 | Arguments: 59 | buckets_or_size: A list of buckets or maximum size for context encoding. 60 | token_buckets: The token buckets to generate context buckets for. 61 | 62 | Returns 63 | buckets: The list of bucket sizes for context encoding. 
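Example (default rule): with token_buckets=[128, 256, 512] and buckets_or_size=None, the returned context buckets are [64, 128, 256].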
64 | """ 65 | if isinstance(buckets_or_size, list): 66 | return sorted(buckets_or_size) 67 | if isinstance(buckets_or_size, int): 68 | if buckets_or_size <= 0: 69 | return [] 70 | return [buckets_or_size] 71 | if token_buckets is not None and buckets_or_size is None: 72 | return [bucket // 2 for bucket in token_buckets] 73 | raise NotImplementedError(f'Prompt bucket config {buckets_or_size} not supported') 74 | 75 | 76 | def batch_sizes(batch_size: Union[List[int], int]) -> List[int]: 77 | """ 78 | Format the user-specified batch size buckets for all model variants. 79 | 80 | Arguments: 81 | batch_size: The batch size(s) to construct models for. 82 | 83 | Returns 84 | batch_sizes: The formatted list of batch sizes. 85 | """ 86 | if isinstance(batch_size, int): 87 | return [batch_size] 88 | elif isinstance(batch_size, list): 89 | return sorted(batch_size) 90 | else: 91 | raise TypeError("batch_size must be list of ints or int type") 92 | 93 | 94 | def find(buckets: Optional[List[int]], size: int) -> Optional[int]: 95 | """ 96 | Find the smallest bucket with that fits the given `size` input. 97 | 98 | When `size` exceeds the largest bucket size, the largest bucket will be 99 | returned. 100 | 101 | Arguments: 102 | buckets: Either the prompt/token bucket sizes to search. 103 | size: The size to fit into the buckets. 104 | 105 | Return: 106 | bucket: The bucket value (not the bucket index) 107 | """ 108 | if not buckets: # When we have no buckets, return None 109 | return None 110 | index = bisect.bisect_left(buckets, size, hi=len(buckets) - 1) 111 | return buckets[index] 112 | -------------------------------------------------------------------------------- /src/transformers_neuronx/constants.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon Web Services and its Affiliates. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | import enum 16 | 17 | # Tile size used for the weight transformation 18 | TILE_SIZE = 128 19 | 20 | # Size used to determine fused QKV operation. 21 | FUSED_QKV_TP_FACTOR = 3 22 | 23 | # KV sharding pad for flash decoding 24 | KV_SHARD_PAD = 128 25 | 26 | # Number of chips on Trn 27 | TRN1_WORLD_SIZE = 32 28 | 29 | # Layout for attention 30 | LAYOUT_BSH = 'BSH' 31 | LAYOUT_HSB = 'HSB' 32 | LAYOUT_SBH = 'SBH' 33 | 34 | 35 | # fp8 bounds 36 | class FpBounds: 37 | min = -240.0 38 | max = 240.0 39 | 40 | 41 | class Layout(enum.Enum): 42 | HSB = 'HSB' 43 | BSH = 'BSH' 44 | SBH = 'SBH' 45 | 46 | def __eq__(self, value): 47 | return super().__eq__(Layout(value)) 48 | 49 | 50 | # Group query attention sharding configurations 51 | class GQA(enum.Enum): 52 | # [Default] Sharding over the heads splits entire (complete) K/V heads 53 | # onto the NeuronCores where the corresponding Q heads reside. 
This is 54 | # similar to traditional MHA except that the Q and K/V heads do not need 55 | # to be equal. 56 | # 57 | # This cannot be enabled when number of K/V heads cannot be evenly split 58 | # across the NeuronCores according to the tensor parallelism degree. 59 | SHARD_OVER_HEADS = 'shard-over-heads' 60 | 61 | # Sharding over the bach dimension linearly shards the K/V heads across 62 | # all NeuronCores (incomplete heads per NeuronCore) and shards the K/V 63 | # cache across the batch dimension according to the tensor parallellism. 64 | # These partial K/V heads are concatenated and then split across NeuronCores 65 | # along the batch dimension (AllToAll) before computing the attention. 66 | # 67 | # This cannot be enabled when the batch size cannot to be evenly split 68 | # across the NeuronCores according to the tensor parallelism degree. 69 | SHARD_OVER_BATCH = 'shard-over-batch' 70 | 71 | # This transforms a GQA attention mechanism into a traditional MHA mechanism 72 | # by replicating the K/V heads to evenly match the corresponding Q heads. 73 | # This consumes more memory than would otherwise be used with other sharding 74 | # mechanisms but avoids collective communications overheads. 75 | REPLICATED_HEADS = 'replicated-heads' 76 | 77 | # This mechanism evenly splits the K/V heads across all NeuronCores 78 | # (incomplete heads per NeuronCore). These partial k/V heads are 79 | # concatenated to the the NeuronCores where the corresponding Q heads 80 | # reside (AllGather). This can be more memory efficient than replication 81 | # but introduces an additional collective communication operation per 82 | # decoder layer. 83 | # 84 | # This cannot be enabled when the number of Q heads cannot to be evenly 85 | # split across the NeuronCores according to the tensor parallelism degree. 86 | ALL_GATHER_HEADS = 'all-gather-heads' 87 | -------------------------------------------------------------------------------- /src/transformers_neuronx/dtypes.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon Web Services and its Affiliates. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | import torch 16 | 17 | 18 | def to_torch_dtype(dtype): 19 | mapping = { 20 | 'f32': torch.float32, 21 | 'f16': torch.float16, 22 | 'bf16': torch.bfloat16, 23 | 's8': torch.int8, 24 | 'f8e4m3fn': torch.float8_e4m3fn if hasattr(torch, 'float8_e4m3fn') else torch.int8, 25 | } 26 | return mapping[dtype] 27 | 28 | 29 | def to_amp(dtype): 30 | mapping = { 31 | torch.float32: 'f32', 32 | torch.float16: 'f16', 33 | torch.bfloat16: 'bf16', 34 | } 35 | return mapping[dtype] 36 | 37 | 38 | def to_pyhlo_type(scribe, dtype): 39 | """ 40 | Map a torch dtype to the corresponding scribe dtype object. 
41 | """ 42 | mapping = { 43 | "float32": scribe.f32, 44 | "float16": scribe.f16, 45 | "bfloat16": scribe.bf16, 46 | } 47 | return mapping[dtype] 48 | -------------------------------------------------------------------------------- /src/transformers_neuronx/generation_utils.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | import torch 3 | 4 | from transformers import PreTrainedModel 5 | from transformers.utils import ModelOutput 6 | 7 | """ 8 | An adapter class for HuggingFace Generation API compatibility. 9 | 10 | It requires model to have a forward interface as: 11 | 12 | forward(input_ids: Tensor(batch_size, seq_len), cache_ids: Tensor(seq_len), 13 | start_ids: Optional[Tensor(batch_size)])) -> Tensor(batch_size, vocab_size) 14 | """ 15 | class HuggingFaceGenerationModelAdapter(PreTrainedModel): 16 | 17 | def __init__(self, config, model): 18 | # sdpa attension is currently unsupported. Torch >=2.1.1 sets _attn_implementation to sdpa causing failures. 19 | if hasattr(config, "_attn_implementation") and config._attn_implementation == "sdpa": 20 | warnings.warn("Warning: sdpa is unsupported as part of attention implementation. Falling back to eager attention implementation.") 21 | config._attn_implementation = "eager" 22 | super().__init__(config) 23 | self.model = model 24 | self.config = config 25 | self.cur_len = torch.zeros(1, dtype=torch.long) 26 | 27 | def reset_generation(self): 28 | self.cur_len = torch.zeros(1, dtype=torch.long) 29 | 30 | def forward(self, input_ids, cache_ids, start_ids=None, output_hidden_states=False, output_attentions=False, 31 | attention_mask=None, return_dict=False): 32 | 33 | if output_hidden_states or output_attentions or attention_mask is not None: 34 | warnings.warn("Warning: These arguments are not used by forward(): \ 35 | (output_hidden_states, output_attentions, attention_mask)") 36 | 37 | if self.model.neuron_config.is_pp(): 38 | out_logits = self.model.pp_forward(input_ids, cache_ids, start_ids) 39 | else: 40 | out_logits = self.model(input_ids, cache_ids, start_ids) 41 | 42 | out_logits = out_logits[:, None, :] 43 | if return_dict: 44 | return ModelOutput( 45 | [("logits", out_logits), ("past_key_values", tuple())], 46 | ) 47 | return (out_logits,) 48 | 49 | # keep the generation stateless 50 | def generate(self, *args, **kwargs): 51 | self.reset_generation() 52 | return super().generate(*args, **kwargs) 53 | 54 | # implemented for beam search 55 | # we ignore past as we don't expose k/v_cache 56 | def _reorder_cache(self, past, beam_idx): 57 | assert hasattr(self.model, 'reorder_cache') and callable(self.model.reorder_cache), f"{self.model.__class__.__name__} doesn't have reorder_cache implemented for beam search" 58 | self.model.reorder_cache(beam_idx) 59 | 60 | def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs): 61 | # convert attention_mask to start_ids 62 | attention_mask = None 63 | cache_ids = None 64 | start_ids = None 65 | 66 | if "attention_mask" in kwargs: 67 | attention_mask = kwargs["attention_mask"] 68 | 69 | if attention_mask is not None: 70 | _, start_ids = attention_mask.max(axis=1) 71 | start_ids = start_ids.int() 72 | 73 | if (self.cur_len > 0).any().item(): 74 | input_ids = input_ids[:, -1:] 75 | 76 | if self.model.neuron_config.use_2d_cache_ids: 77 | # 2D cache_ids 78 | batch_size, context_length = attention_mask.shape 79 | start_ids = torch.arange(input_ids.shape[0]) 80 | if (self.cur_len > 0).any().item(): 81 | # 
token generation (aka decoding) with 2D cache_ids 82 | cache_ids = self.cur_len.unsqueeze(-1) + 1 83 | self.cur_len = cache_ids.squeeze(-1) 84 | else: 85 | # context encoding (aka prefill) with 2D cache_ids 86 | cache_ids = torch.arange(context_length) * attention_mask 87 | self.cur_len = cache_ids.max(dim=1).values 88 | else: 89 | if (self.cur_len > 0).any().item(): 90 | # token generation (aka decoding) with 1D cache_ids 91 | cache_ids = self.cur_len 92 | self.cur_len = cache_ids + 1 93 | else: 94 | # context encoding (aka prefill) with 1D cache_ids 95 | batch_size, context_length = input_ids.shape 96 | cache_ids = torch.arange(context_length) 97 | self.cur_len = torch.tensor([context_length], dtype=torch.long) 98 | 99 | model_inputs = { 100 | "input_ids": input_ids, 101 | "cache_ids": cache_ids, 102 | "start_ids": start_ids, 103 | } 104 | return model_inputs 105 | -------------------------------------------------------------------------------- /src/transformers_neuronx/global_debugger.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon Web Services and its Affiliates. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | from collections import namedtuple 17 | import contextlib 18 | from typing import Optional 19 | from torch_neuronx.pyhlo.scribe import HloShape 20 | import functools 21 | 22 | debug_tensors = None 23 | 24 | DebugTensor = namedtuple("DebugTensor", ["tensor", "unshard_dim", "metadata"]) 25 | 26 | 27 | def tap(name: str, tensor: HloShape, unshard_dim: Optional[int] = None): 28 | """ 29 | Debug function "taps" a tensor in an HLO program for retrieval during 30 | program execution. This is useful for viewing intermediate tensor values. 31 | 32 | Args: 33 | name: the name of the tagged tensor that will be displayed. 34 | tensor: the intermidate tensor variable. 35 | 36 | Example Usage: 37 | In a model's hlo.py file, tag a tensor such as the attention output: 38 | ``` 39 | output = attention.output(context, out_weight, out_scales, out_bias, tp_degree, self.neuron_config) 40 | from transformers_neuronx import global_debugger 41 | global_debugger.tap("output", output) 42 | ``` 43 | """ 44 | if debug_tensors is None: 45 | return tensor 46 | debug_tensors[name] = DebugTensor(tensor, unshard_dim, {}) 47 | return tensor 48 | 49 | 50 | @contextlib.contextmanager 51 | def debug_context(): 52 | """ 53 | Context manager that retrieves debug_tensors during execution. 
54 | 55 | Example Usage: 56 | During model execution, retrieve the debug tensors by wrapping the 57 | inference call: 58 | ``` 59 | with global_debugger.debug_context(): 60 | outputs = model_neuron.sample(inputs, sequence_length=128, top_k=1) 61 | debug_tensors = global_debugger.debug_tensors 62 | print(f'debug_tensors: {debug_tensors}') 63 | ``` 64 | """ 65 | global debug_tensors 66 | debug_tensors = {} 67 | try: 68 | yield 69 | finally: 70 | debug_tensors = None 71 | 72 | # Decorates a scribe function so that instead of returning a tuple, 73 | # it instead returns a list of outputs to which the debug tensors are 74 | # added. 75 | def populate_debug_tensors(debug_outputs={}): 76 | from transformers_neuronx import global_debugger as gdbg 77 | def inner(func): 78 | @functools.wraps(func) 79 | def wrapper(scribe): 80 | with gdbg.debug_context(): 81 | # This func code should run, and inside of it users can tag their tensors 82 | # any tensors that are tagged will be added to the outputs 83 | outputs = func(scribe) 84 | for (tag_name, debug_tensor) in gdbg.debug_tensors.items(): 85 | debug_tensor.metadata['output_index'] = len(outputs) 86 | debug_outputs[tag_name] = debug_tensor 87 | outputs.append(debug_tensor.tensor) 88 | # Now we just have to actually return a tuple 89 | root_shapes = [shape.dtype[shape.sizes] for shape in outputs] 90 | return scribe.tuple(*root_shapes).Tuple(*outputs) 91 | return wrapper 92 | return inner -------------------------------------------------------------------------------- /src/transformers_neuronx/gpt2/config.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon Web Services and its Affiliates. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | import transformers 16 | from transformers_neuronx import utils 17 | 18 | 19 | # inherit from transformer.GPT2Config as we need to inherit from 20 | # transformers.PreTrainedModel to call transformers generation API 21 | class GPT2HuggingFaceConfig(transformers.GPT2Config): 22 | 23 | def __init__(self, config, batch_size, amp, tp_degree, **kwargs): 24 | kwargs.update(config.to_dict()) 25 | super().__init__(**kwargs) 26 | self.activation_function = config.activation_function 27 | self.n_embd = config.n_embd 28 | self.n_head = config.n_head 29 | self.n_kv_head = config.to_dict().get("n_kv_head", config.n_head) 30 | self.n_layer = config.n_layer 31 | self.n_positions = config.n_positions 32 | self.max_position_embeddings = config.max_position_embeddings 33 | self.vocab_size = config.vocab_size 34 | self.eos_token_id = config.eos_token_id 35 | utils.maybe_override_attributes(self, kwargs) 36 | self.intermediate_dim = self.n_embd * 4 37 | self.batch_size = batch_size 38 | self.amp = amp 39 | self.tp_degree = tp_degree 40 | 41 | 42 | 43 | class GPT2Config: 44 | 45 | def __init__(self, config, batch_size, amp, tp_degree, **kwargs): 46 | self.activation_function = config.activation_function 47 | self.n_embd = config.n_embd 48 | self.n_head = config.n_head 49 | self.n_kv_head = config.n_kv_head if hasattr(config, "n_kv_head") else config.n_head 50 | self.n_layer = config.n_layer 51 | self.n_positions = config.n_positions 52 | self.max_position_embeddings = config.max_position_embeddings 53 | self.vocab_size = config.vocab_size 54 | self.eos_token_id = config.eos_token_id 55 | utils.maybe_override_attributes(self, kwargs) 56 | self.intermediate_dim = self.n_embd * 4 57 | self.batch_size = batch_size 58 | self.amp = amp 59 | self.tp_degree = tp_degree 60 | self.model_type = 'gpt2' 61 | -------------------------------------------------------------------------------- /src/transformers_neuronx/gpt2/demo.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon Web Services and its Affiliates. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | from transformers_neuronx.gpt_demo import demo 16 | from transformers_neuronx.gpt2.model import GPT2ForSampling 17 | 18 | 19 | def main(): 20 | demo('gpt2', GPT2ForSampling, amp_callback) 21 | 22 | 23 | def amp_callback(model, dtype): 24 | # cast attention and mlp to low precisions only; layernorms stay as f32 25 | for block in model.transformer.h: 26 | block.attn.to(dtype) 27 | block.mlp.to(dtype) 28 | model.lm_head.to(dtype) 29 | -------------------------------------------------------------------------------- /src/transformers_neuronx/gpt2/gen_random_pretrained.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon Web Services and its Affiliates. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | import argparse 17 | import json 18 | import os 19 | import torch 20 | from transformers.models.gpt2 import GPT2Config 21 | from transformers_neuronx.module import sanitize_file_name, _KEY_TO_FILENAME_JSON 22 | from transformers.configuration_utils import PretrainedConfig 23 | 24 | def gen_random_pretrained(model_name, save, empty=False, print_shapes=False): 25 | if '.json' in model_name: 26 | config = json.load(open(model_name)) 27 | else: 28 | config = GPT2Config.from_pretrained(model_name).to_dict() 29 | os.makedirs(save, exist_ok=True) 30 | with open(os.path.join(save, 'config.json'), 'w') as fp: 31 | json.dump(config, fp, indent=2) 32 | vocab_size = config['vocab_size'] 33 | hidden_size = config['n_embd'] 34 | kv_hidden_size = config.get('n_kv_head', config['n_head']) * (hidden_size // config['n_head']) 35 | max_position_embeddings = config['n_ctx'] 36 | ffn_dim = hidden_size * 4 37 | num_hidden_layers = config['n_layer'] 38 | torch_dtype = config.get('torch_dtype', None) 39 | if torch_dtype is None: 40 | torch_dtype = 'float16' 41 | init_std = 0.001 42 | name2shape = { 43 | 'transformer.wte.weight': [vocab_size, hidden_size], 44 | 'transformer.wpe.weight': [max_position_embeddings, hidden_size], 45 | } 46 | layer_name2shape = { 47 | 'ln_1.weight': [hidden_size], 48 | 'ln_1.bias': [hidden_size], 49 | 'attn.bias': [hidden_size], 50 | 'attn.masked_bias': [hidden_size], 51 | 'attn.c_attn.weight': [hidden_size, 2 * kv_hidden_size + hidden_size], 52 | 'attn.c_attn.bias': [2 * hidden_size + hidden_size], 53 | 'attn.c_proj.weight': [hidden_size, hidden_size], 54 | 'attn.c_proj.bias': [hidden_size], 55 | 'ln_2.weight': [hidden_size], 56 | 'ln_2.bias': [hidden_size], 57 | 'mlp.c_fc.weight': [hidden_size, ffn_dim], 58 | 'mlp.c_fc.bias': [ffn_dim], 59 | 'mlp.c_proj.weight': [ffn_dim, hidden_size], 60 | 'mlp.c_proj.bias': [hidden_size], 61 | } 62 | for idx in range(num_hidden_layers): 63 | for name, shape in layer_name2shape.items(): 64 | name2shape[f'transformer.h.{idx}.{name}'] = shape 65 | 
name2shape['transformer.ln_f.weight'] = [hidden_size] 66 | name2shape['transformer.ln_f.bias'] = [hidden_size] 67 | name2shape['lm_head.weight'] = [vocab_size, hidden_size] 68 | if print_shapes: 69 | print("Components' shapes") 70 | [print(x) for x in name2shape.items()] 71 | key_to_filename = {} 72 | for idx, key in enumerate(name2shape.keys()): 73 | key_to_filename[key] = f'p{idx}.{sanitize_file_name(key)}' 74 | if empty: 75 | key_to_filename[key] = f'{key_to_filename[key]}.empty_json' 76 | split_param_dir = os.path.join(save, 'pytorch_model.bin') 77 | os.makedirs(split_param_dir, exist_ok=True) 78 | with open(os.path.join(split_param_dir, _KEY_TO_FILENAME_JSON), 'w') as fp: 79 | json.dump(key_to_filename, fp, indent=2) 80 | dtype = getattr(torch, torch_dtype) 81 | for name, shape in name2shape.items(): 82 | save_path = os.path.join(split_param_dir, key_to_filename[name]) 83 | factor = 0.0 if 'layer_norm' in name or 'bias' in name else init_std 84 | if empty: 85 | empty_json = { 86 | 'torch_dtype': torch_dtype, 87 | 'shape': shape, 88 | } 89 | with open(save_path, 'w') as fp: 90 | json.dump(empty_json, fp, indent=2) 91 | continue 92 | init_param = factor * torch.randn(shape) 93 | init_param = init_param.to(dtype) 94 | torch.save(init_param, save_path) 95 | print(f'done saving {save_path}') 96 | -------------------------------------------------------------------------------- /src/transformers_neuronx/gpt_demo.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon Web Services and its Affiliates. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | import argparse 16 | import itertools 17 | import math 18 | import time 19 | import json 20 | import os 21 | import torch 22 | from transformers import AutoModelForCausalLM 23 | from transformers import AutoTokenizer 24 | from transformers import GPT2Config as GPT2ConfigTransformer 25 | from transformers import GPTJConfig as GPTJConfigTransformer 26 | from transformers_neuronx import dtypes 27 | from transformers_neuronx.module import save_pretrained_split 28 | from transformers_neuronx.config import NeuronConfig, QuantizationConfig 29 | from transformers_neuronx.sparse_attn_utils import SparseAttnConfig, BlkSparseAttnConfig, SlidingWindowAttnConfig 30 | 31 | def demo(model_name, model_cls, amp_callback): 32 | parser = argparse.ArgumentParser() 33 | amp_choices = ['f32', 'f16', 'bf16'] 34 | floatx_floaty_combinations = list(itertools.product(amp_choices, amp_choices)) 35 | for floatx, floaty in floatx_floaty_combinations: 36 | amp_choices.append(f'{floatx}-u8-{floaty}') 37 | parser.add_argument('--amp', default='f32', choices=amp_choices) 38 | parser.add_argument('--model_name', default=None, help="Model name for loading a pretrained model") 39 | subparsers = parser.add_subparsers() 40 | save_name = 'save' 41 | save_parser = subparsers.add_parser(save_name) 42 | save_parser.set_defaults(which=save_name) 43 | save_parser.add_argument('save', help="Directory to save the model") 44 | save_parser.add_argument('--random', action='store_true', help="Random weights flag. If true, config.json would be used to generate a model with random weight") 45 | save_parser.add_argument('--config', type=str, default='', help="Path to config.json file (example: path/to/config.json)") 46 | run_name = 'run' 47 | run_parser = subparsers.add_parser(run_name) 48 | run_parser.set_defaults(which=run_name) 49 | run_parser.add_argument('load') 50 | run_parser.add_argument('--batch_size', type=int, default=4, help="Input batch size") 51 | run_parser.add_argument('--n_positions', type=int, default=128, help="Input sequence length") 52 | run_parser.add_argument('--tp_degree', type=int, default=2, help="Number of neuron cores used for tensor parallel") 53 | run_parser.add_argument('--unroll', type=int, default=None) 54 | run_parser.add_argument('--print_latency', action='store_true', help="Print latency for generation of each output token") 55 | run_parser.add_argument('--quantize', action='store_true', help="Quantize model") 56 | # Sparse attention configs 57 | run_parser.add_argument('--sparse_attn', type=str, choices=[None, 'blk_sparse', 'custom', 'window'], 58 | default=None, help="Use sparse attention or not. ") 59 | # Block-sparse configs 60 | run_parser.add_argument('--blk_size', type=int, default=128, help="Block size in blk-sparse attention") 61 | run_parser.add_argument('--num_global_blks', type=int, default=0, help="Number of global blocks in blk-sparse attention") 62 | run_parser.add_argument('--num_local_blks', type=int, default=1, help="Number of local blocks in blk-sparse attention") 63 | run_parser.add_argument('--num_random_blks', type=int, default=0, help="Number of random blocks in blk-sparse attention") 64 | # Window attention configs 65 | run_parser.add_argument('--window_size', type=int, default=128, help="Window size for sliding-window attention. 
") 66 | run_parser.add_argument('--context_length_estimate', type=int, default=None, help="Context length estimate.") 67 | # TODO: args for custom sparse attention not added 68 | 69 | args = parser.parse_args() 70 | if args.model_name is not None: 71 | model_name = args.model_name 72 | if args.which == save_name: 73 | save(args, model_name, amp_callback, model_cls) 74 | elif args.which == run_name: 75 | run(args, model_name, model_cls) 76 | 77 | def load_config(args): 78 | config_filename = args.config 79 | assert config_filename, "Please provide the config.json like: --config=./config.json" 80 | assert os.path.exists(config_filename), f"File {config_filename} does not exist." 81 | config = json.load(open(config_filename)) 82 | return config 83 | 84 | def save(args, model_name, amp_callback, model_cls): 85 | if args.random: 86 | config = load_config(args) 87 | if "GPTJ" in str(model_cls): 88 | config = GPTJConfigTransformer(**config) 89 | else: 90 | config = GPT2ConfigTransformer(**config) 91 | model = AutoModelForCausalLM.from_config(config=config) 92 | else: 93 | model = AutoModelForCausalLM.from_pretrained(model_name, low_cpu_mem_usage=True) 94 | if args.amp != 'f32': 95 | dtype = dtypes.to_torch_dtype(args.amp) 96 | amp_callback(model, dtype) 97 | save_pretrained_split(model, args.save) 98 | 99 | 100 | def run(args, model_name, model_cls): 101 | torch.manual_seed(15213) 102 | tokenizer = AutoTokenizer.from_pretrained(model_name) 103 | prompt_text = "Hello, I'm a language model," 104 | print(f'running {model_cls.__name__}.from_pretrained') 105 | neuron_config = None 106 | quant_config = QuantizationConfig(dequant_dtype=args.amp) if args.quantize else None 107 | sparse_config = None 108 | if args.sparse_attn: 109 | if args.sparse_attn == "blk_sparse": 110 | sparse_attn_config = BlkSparseAttnConfig( 111 | blk_size=args.blk_size, 112 | num_global_blks=args.num_global_blks, 113 | num_local_blks=args.num_local_blks, 114 | num_random_blks=args.num_random_blks 115 | ) 116 | elif args.sparse_attn == "window": 117 | sparse_attn_config = SlidingWindowAttnConfig(window_size=args.window_size) 118 | else: 119 | raise NotImplementedError("Interface for other attention patterns not implemented yet!") 120 | sparse_config = SparseAttnConfig( 121 | attn_type=args.sparse_attn, causal=True, sparse_attn_config=sparse_attn_config, 122 | same_mask_per_layer=True) 123 | if args.quantize or args.sparse_attn: 124 | neuron_config = NeuronConfig(quant=quant_config, sparse_attn=sparse_config) 125 | if args.context_length_estimate: 126 | model = model_cls.from_pretrained(args.load, batch_size=args.batch_size, amp=args.amp, 127 | tp_degree=args.tp_degree, n_positions=args.n_positions, 128 | unroll=args.unroll, neuron_config=neuron_config, 129 | context_length_estimate=args.context_length_estimate) 130 | else: 131 | model = model_cls.from_pretrained(args.load, batch_size=args.batch_size, amp=args.amp, 132 | tp_degree=args.tp_degree, n_positions=args.n_positions, 133 | unroll=args.unroll, neuron_config=neuron_config) 134 | if args.print_latency: 135 | latency_printer = LatencyPrinter() 136 | model.register_forward_pre_hook(latency_printer.pre_hook) 137 | model.register_forward_hook(latency_printer.hook) 138 | if hasattr(model, 'register_to_neuron_hook'): 139 | model.register_to_neuron_hook(lambda idx: print(f'done to_neuron layer {idx}')) 140 | print('running model.to_neuron') 141 | model.to_neuron() 142 | with torch.inference_mode(): 143 | encoded_text = tokenizer.encode(prompt_text) 144 | input_ids = 
torch.as_tensor([encoded_text]) 145 | input_ids = torch.cat([input_ids for _ in range(args.batch_size)], dim=0) 146 | print('running model.sample') 147 | generated_sequence = model.sample(input_ids, sequence_length=args.n_positions) 148 | print('generated_sequence=', generated_sequence) 149 | outputs = [tokenizer.decode(gen_seq) for gen_seq in generated_sequence] 150 | print(outputs) 151 | 152 | 153 | class LatencyPrinter: 154 | 155 | def __init__(self): 156 | self.start = None 157 | 158 | def pre_hook(self, module, input): 159 | if len(input) == 3: 160 | _, cache_offset, _ = input 161 | print(f'cache_offset: {cache_offset}') 162 | self.start = time.time() 163 | 164 | def hook(self, *args): 165 | latency_ms = math.ceil((time.time() - self.start) * 1000) 166 | print(f'Latency: {latency_ms} ms') -------------------------------------------------------------------------------- /src/transformers_neuronx/gptj/config.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon Web Services and its Affiliates. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | from transformers_neuronx import utils 16 | 17 | 18 | class GPTJConfig: 19 | 20 | def __init__(self, config, batch_size, amp, tp_degree, **kwargs): 21 | self.activation_function = config.activation_function 22 | self.n_embd = config.n_embd 23 | self.n_head = config.n_head 24 | self.n_layer = config.n_layer 25 | self.n_positions = config.n_positions 26 | self.n_ctx = config.n_positions # needed since self.n_positions will be overridden by input n_positions 27 | self.rotary_dim = config.rotary_dim 28 | self.vocab_size = config.vocab_size 29 | self.eos_token_id = config.eos_token_id 30 | utils.maybe_override_attributes(self, kwargs) 31 | self.intermediate_dim = self.n_embd * 4 32 | self.batch_size = batch_size 33 | self.amp = amp 34 | self.tp_degree = tp_degree 35 | -------------------------------------------------------------------------------- /src/transformers_neuronx/gptj/demo.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon Web Services and its Affiliates. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License.
14 | # ============================================================================== 15 | from transformers_neuronx.gpt_demo import demo 16 | from transformers_neuronx.gptj.model import GPTJForSampling 17 | from transformers_neuronx.gpt2.demo import amp_callback 18 | 19 | 20 | def main(): 21 | demo('EleutherAI/gpt-j-6B', GPTJForSampling, amp_callback) 22 | -------------------------------------------------------------------------------- /src/transformers_neuronx/gptneox/config.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon Web Services and its Affiliates. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | from transformers_neuronx import utils 16 | 17 | 18 | class GPTNeoXConfig: 19 | 20 | def __init__(self, config, batch_size, amp, tp_degree, **kwargs): 21 | self.activation_function = config.hidden_act 22 | self.n_embd = config.hidden_size 23 | self.n_head = config.num_attention_heads 24 | self.n_layer = config.num_hidden_layers 25 | self.n_positions = config.max_position_embeddings 26 | self.n_ctx = self.n_positions # TODO - Needed for now because self.n_positions will be overridden by the kwargs n_positions 27 | """ 28 | rotary_dim calculation 29 | Reference: transformers/models/gpt_neox/modeling_gpt_neox.py, class "GPTNeoXAttention" 30 | self.num_attention_heads = config.num_attention_heads 31 | self.hidden_size = config.hidden_size 32 | self.head_size = self.hidden_size // self.num_attention_heads 33 | self.rotary_ndims = int(self.head_size * config.rotary_pct) 34 | """ 35 | head_size = self.n_embd // self.n_head 36 | self.rotary_dim = int(head_size * config.rotary_pct) 37 | self.vocab_size = config.vocab_size 38 | self.eos_token_id = config.eos_token_id 39 | self.rotary_emb_base = config.rotary_emb_base 40 | self.use_parallel_residual = config.use_parallel_residual 41 | utils.maybe_override_attributes(self, kwargs) 42 | self.intermediate_dim = config.intermediate_size # `intermediate_size` in GPT-NeoX is given in the config file 43 | self.batch_size = batch_size 44 | self.amp = amp 45 | self.tp_degree = tp_degree 46 | -------------------------------------------------------------------------------- /src/transformers_neuronx/gptneox/demo.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon Web Services and its Affiliates. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | from transformers_neuronx.gpt_demo import demo 16 | from transformers_neuronx.gptneox.model import GPTNeoXForSampling 17 | 18 | 19 | def amp_callback(model, dtype): 20 | # cast attention and mlp to low precisions only; layernorms stay as f32 21 | for layer in model.gpt_neox.layers: 22 | layer.attention.to(dtype) 23 | layer.mlp.to(dtype) 24 | model.embed_out.to(dtype) 25 | 26 | 27 | def main(): 28 | demo('EleutherAI/gpt-neox-20b', GPTNeoXForSampling, amp_callback) 29 | -------------------------------------------------------------------------------- /src/transformers_neuronx/kv_cache_manager.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | 4 | from transformers_neuronx import parallel 5 | from transformers_neuronx import compiler 6 | from transformers_neuronx import decoder 7 | from transformers_neuronx import ops 8 | from transformers_neuronx import hlo 9 | 10 | class KVCacheManager: 11 | 12 | def __init__(self, cache_shape, tp, num_layer, dtype=torch.float32): 13 | self.cache_shape = cache_shape 14 | 15 | self.tp = tp 16 | self.num_layer = num_layer 17 | self.manipulator = parallel.ParallelTensorManipulator(self.tp) 18 | 19 | self.dtype = dtype 20 | 21 | 22 | def generate_cache(self, generator=torch.zeros): 23 | cache = generator(self.cache_shape).to(self.dtype) 24 | 25 | ops.init() 26 | k_caches = [(self.manipulator.shard_along(cache, dim=2)) for _ in range(self.num_layer)] 27 | v_caches = [(self.manipulator.shard_along(cache, dim=2)) for _ in range(self.num_layer)] 28 | 29 | self.set_kv_cache(k_caches, v_caches) 30 | 31 | 32 | def set_kv_cache(self, k_caches, v_caches): 33 | assert len(k_caches) == len(v_caches) 34 | 35 | self.k_caches = k_caches 36 | self.v_caches = v_caches 37 | 38 | def set_cache_shape(self, cache_shape): 39 | self.cache_shape = cache_shape 40 | self.n_positions, self.batch_size, self.n_heads_kv_cache, self.attention_head_size = cache_shape 41 | 42 | 43 | def to_cpu(self): 44 | cpu_k_caches = [(self.manipulator.unshard_along(cache, dim=2)) for cache in self.k_caches] 45 | cpu_v_caches = [(self.manipulator.unshard_along(cache, dim=2)) for cache in self.v_caches] 46 | 47 | return cpu_k_caches, cpu_v_caches 48 | 49 | def send_kv_cache_program(self, source_g_start_id, target_g_start_id, world_size): 50 | 51 | def send_program(scribe): 52 | param_builder = decoder.DecoderParameterBuilder(scribe, 0) 53 | cache_params = [] 54 | for cache in self.k_caches: 55 | cache_param = param_builder.from_tensor(cache) 56 | cache_params.append(cache_param) 57 | 58 | for cache in self.v_caches: 59 | cache_param = param_builder.from_tensor(cache) 60 | cache_params.append(cache_param) 61 | 62 | # concat on batch dim, easier to slice 63 | cache_shape = cache_params[0].sizes 64 | concat_size = [cache_shape[0], len(cache_params) * cache_shape[1], *cache_shape[2:]] 65 | cat_cache = cache_params[0].dtype[concat_size].Concatenate(*cache_params, dimensions=[1]) 66 | 67 | return hlo.all_reduce_sum( 68 | cat_cache, 69 | tp_degree=self.tp, 70 | dtype=cat_cache.dtype, 71 | replica_groups=[[source_g_start_id+i, target_g_start_id+i] for i in range(self.tp)] 72 | ) 73 | 74 | self.send_hlo_kernel = compiler.HLOKernel(send_program, self.tp, start_g_nc_id=source_g_start_id, g_nc_count=world_size*self.tp) 75 | self.send_hlo_kernel.build() 76 | 
self.send_hlo_kernel.load() 77 | self.send_hlo_kernel.setup([*self.k_caches, *self.v_caches], []) 78 | 79 | def run_send(self): 80 | self.send_hlo_kernel.run() 81 | return self.manipulator.unshard_along(self.send_hlo_kernel.memories.output_tensors[0], dim=2) 82 | 83 | 84 | def receive_kv_cache_program(self, source_g_start_ids, g_start_id, batch_ids, source_batch_size=1): 85 | 86 | world_size = len(source_g_start_ids) + 1 87 | def receive_program(scribe): 88 | param_builder = decoder.DecoderParameterBuilder(scribe, 0) 89 | cache_params = [] 90 | for cache in self.k_caches: 91 | cache_param = param_builder.from_tensor(cache) 92 | cache_params.append(cache_param) 93 | 94 | for cache in self.v_caches: 95 | cache_param = param_builder.from_tensor(cache) 96 | cache_params.append(cache_param) 97 | 98 | 99 | # concat on batch dim, easier to slice 100 | cache_shape = cache_params[0].sizes 101 | cache_dtype = cache_params[0].dtype 102 | concat_size = [cache_shape[0], len(cache_params) * source_batch_size, *cache_shape[2:]] 103 | 104 | for source_g_start_id, batch_id in zip(source_g_start_ids, batch_ids): 105 | 106 | zero = cache_dtype.Constant(constant_value=0) 107 | cat_cache = cache_dtype[concat_size].Broadcast(zero, dimensions=[]) 108 | 109 | cat_cache = hlo.all_reduce_sum( 110 | cat_cache, 111 | tp_degree=self.tp, 112 | dtype=cat_cache.dtype, 113 | replica_groups=[[source_g_start_id+i, g_start_id+i] for i in range(self.tp)] 114 | ) 115 | 116 | for i in range(2): # k,v 117 | for j in range(self.num_layer): 118 | offset_in_cat = i*self.num_layer + j # offset in cat kv cache 119 | cache_slice = hlo.slice_along(cat_cache, 1, offset_in_cat*source_batch_size + source_batch_size, offset_in_cat*source_batch_size) 120 | start_indices = [0, batch_id, 0, 0] 121 | 122 | updated_cache = hlo.dynamic_update_slice(cache_params[offset_in_cat], cache_slice, start_indices) 123 | cache_params[offset_in_cat] = updated_cache 124 | 125 | root_shapes = [cache.dtype[cache.sizes] for cache in cache_params] 126 | return scribe.tuple(*root_shapes).Tuple(*cache_params) 127 | 128 | self.recv_hlo_kernel = compiler.HLOKernel(receive_program, self.tp, start_g_nc_id=g_start_id, g_nc_count=self.tp*world_size) 129 | self.recv_hlo_kernel.build() 130 | self.recv_hlo_kernel.load() 131 | 132 | # update cache inplace 133 | # we cannot simply slice cache[line_id] and update inplace 134 | # as the slice dimension is not the first one, leads to non-contiguous slice 135 | cache_buffers = [*self.k_caches,*self.v_caches] 136 | 137 | self.recv_hlo_kernel.setup(cache_buffers, cache_buffers) 138 | 139 | def run_receive(self): 140 | self.recv_hlo_kernel.run() 141 | -------------------------------------------------------------------------------- /src/transformers_neuronx/layers/alibi.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon Web Services and its Affiliates. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | import math 16 | 17 | import torch 18 | 19 | from transformers_neuronx import hlo 20 | 21 | 22 | def build_slopes(num_heads: int) -> torch.Tensor: 23 | """ 24 | Builds a fixed slopes tensor to compute the ALiBI positional encoding. 25 | 26 | This tensor must be partitioned across the attention layers so that each 27 | attention head uses the correct ALiBI slope for the head allocated 28 | to that NeuronCore. 29 | 30 | Reference: https://github.com/huggingface/transformers/blob/v4.29.2/src/transformers/models/bloom/modeling_bloom.py#L86 31 | 32 | Arguments: 33 | num_heads: The number of attention heads for the model. 34 | 35 | Returns: 36 | slopes: The slope for each attention head. Shape: [num_heads, 1] 37 | """ 38 | closest_power_of_2 = 2 ** math.floor(math.log2(num_heads)) 39 | base = 2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3))) 40 | powers = range(1, 1 + closest_power_of_2) 41 | slopes = list(map(lambda x: math.pow(base, x), powers)) 42 | 43 | if closest_power_of_2 != num_heads: 44 | extra_base = 2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3))) 45 | num_remaining_heads = min(closest_power_of_2, num_heads - closest_power_of_2) 46 | extra_powers = range(1, 1 + 2 * num_remaining_heads, 2) 47 | extra_slopes = list(map(lambda x: math.pow(extra_base, x), extra_powers)) 48 | slopes.extend(extra_slopes) 49 | 50 | assert len(slopes) == num_heads 51 | return torch.tensor(slopes).view(num_heads, 1) 52 | 53 | 54 | def alibi(slopes, attention_mask, active_mask=None): 55 | """ 56 | Compute the ALiBI positional encoding from the attention mask and slopes. 57 | 58 | This must be used during both context encoding and token generation: 59 | - For prompt encoding only the `attention_mask` should be provided. 60 | - For token generation, the prefetch mechanism can be enabled by providing 61 | both an `attention_mask` for prior tokens and an `active_mask` for 62 | the newly generated tokens. 63 | 64 | See: alibi.build_slopes 65 | hlo.decoder_attention_mask 66 | 67 | Arguments: 68 | slopes: The ALiBI slopes for the current NeuronCore 69 | attention_mask: The mask to build the encoding for. 70 | active_mask: An optional mask only used during prefetch-attention. 
71 | 72 | Returns: 73 | alibi: The ALiBI to apply for the given `attention_mask` 74 | active_alibi: The ALiBI to apply to the optional `active_mask` 75 | """ 76 | 77 | num_heads_tp, *_ = slopes.sizes 78 | scribe = attention_mask.scribe 79 | dtype = scribe.f32 80 | 81 | def _alibi(summation, mask): 82 | 83 | size = mask.sizes 84 | batch_size, n_active_tokens, seq_length = mask.sizes 85 | 86 | one = dtype.Constant(constant_value=1) 87 | one_br = dtype[size].Broadcast(one, dimensions=[]) 88 | summation_sub = dtype[size].Subtract(summation, one_br) 89 | sum_mul = dtype[size].Multiply(summation_sub, mask) 90 | 91 | slopes_sh = dtype[batch_size, n_active_tokens, num_heads_tp, 1].Broadcast(slopes, dimensions=[2, 3]) 92 | sum_sh = dtype[batch_size, n_active_tokens, 1, seq_length].Reshape(sum_mul) 93 | dot_dims = dict( 94 | lhs_contracting_dimensions=[3], 95 | lhs_batch_dimensions=[0, 1], 96 | rhs_contracting_dimensions=[2], 97 | rhs_batch_dimensions=[0, 1] 98 | ) 99 | product = dtype[batch_size, n_active_tokens, num_heads_tp, seq_length].Dot(slopes_sh, sum_sh, dot_dimension_numbers=dot_dims) 100 | result = dtype[batch_size, num_heads_tp, n_active_tokens, seq_length].Transpose(product, dimensions=[0, 2, 1, 3]) 101 | return result 102 | 103 | scribe = attention_mask.scribe 104 | fp32 = scribe.f32 105 | 106 | # Create alibi for the `attention_mask` tokens 107 | mask_cast = hlo.cast(attention_mask, fp32) 108 | summation = hlo.cumsum(mask_cast, -1) 109 | alibi = _alibi(summation, mask_cast) 110 | 111 | # Create alibi for the `active_mask` tokens: 112 | # Since the prior token mask is the `attention_mask` and the 113 | # active token mask is the `active_mask`, we need to combine both masks to 114 | # find the true cumulative sum. 115 | if active_mask is not None: 116 | total = hlo.reduce_sum(mask_cast, 2) 117 | active_cast = hlo.cast(active_mask, fp32) 118 | total = fp32[total.sizes].Add(total, active_cast) 119 | total_sh = hlo.unsqueeze(total, 1) 120 | active_cast_sh = hlo.unsqueeze(active_cast, 1) 121 | active_alibi = _alibi(total_sh, active_cast_sh) 122 | return alibi, active_alibi 123 | 124 | # When no active mask, we do not have an "active" alibi 125 | return alibi, None 126 | -------------------------------------------------------------------------------- /src/transformers_neuronx/layers/generation.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon Web Services and its Affiliates. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | from transformers_neuronx import hlo, config 16 | 17 | def generate(logits, logits_indices, config: config.GenerationConfig, tp_degree=1, early_return=False, return_probs=False, seq_ids=None): 18 | logits = mask_logits(logits, logits_indices, config.vocab_size) 19 | if not config.dynamic and not config.do_sample: 20 | tokens = greedy_search(logits, tp_degree=tp_degree) 21 | if early_return or return_probs: 22 | batch_size, n_active_tokens = tokens.sizes 23 | probs = hlo.full(1.0, logits.dtype, [batch_size, n_active_tokens, 1]) 24 | indices = hlo.reshape(tokens, [batch_size, n_active_tokens, 1]) 25 | if early_return: 26 | return probs, indices 27 | else: # return probs 28 | return tokens, probs, indices 29 | else: 30 | return tokens 31 | 32 | if not config.per_batch_line: 33 | return sample( 34 | logits, 35 | top_k=config.top_k, 36 | top_p=config.top_p, 37 | temperature=config.temperature, 38 | top_p_min_tokens=config.top_p_min_tokens, 39 | global_top_k=config.global_top_k, 40 | tp_degree=tp_degree, 41 | dynamic=config.dynamic, 42 | deterministic=config.deterministic, 43 | early_return=early_return, 44 | return_probs=return_probs, 45 | ) 46 | 47 | if config.global_top_k is not None: 48 | assert config.dynamic is True, "Dynamic on device generation must be enabled when global_top_k is set." 49 | 50 | logits = hlo.permute(logits, (2, 1, 0)) 51 | batch_size, n_active_tokens, vocab_size = logits.sizes 52 | 53 | indices = None 54 | # Perform global top-k 55 | if config.global_top_k is not None: 56 | logits, indices = hlo.topk(logits, k=config.global_top_k, dim=2, tp_degree=tp_degree) 57 | 58 | tokens = [] 59 | probs = [] 60 | rindices = [] 61 | for batch_line in range(batch_size): 62 | logits_slice = hlo.slice_along(logits, 0, start=batch_line, limit=batch_line+1) 63 | indices_slice = None if indices is None else hlo.slice_along(indices, 0, start=batch_line, limit=batch_line+1) 64 | 65 | batch_line_top_k, batch_line_top_p, batch_line_temperature, batch_line_top_p_min_tokens = sampling_params_for_batch_line( 66 | seq_ids, batch_line, config 67 | ) 68 | 69 | token = sample( 70 | logits_slice, 71 | indices=indices_slice, 72 | top_k=batch_line_top_k, 73 | top_p=batch_line_top_p, 74 | temperature=batch_line_temperature, 75 | top_p_min_tokens=batch_line_top_p_min_tokens, 76 | global_top_k=None, # we already performed global_top_k 77 | permute=False, # we already permuted 78 | tp_degree=tp_degree, 79 | dynamic=config.dynamic, 80 | deterministic=config.deterministic, 81 | early_return=early_return, 82 | return_probs=return_probs, 83 | ) 84 | if early_return: 85 | probs.append(token[0]) 86 | rindices.append(token[1]) 87 | elif return_probs: 88 | tokens.append(token[0]) 89 | probs.append(token[1]) 90 | rindices.append(token[2]) 91 | else: 92 | tokens.append(token) 93 | if early_return: 94 | returned_probs = hlo.concatenate(probs, dimension=0) 95 | returned_indices = hlo.concatenate(rindices, dimension=0) 96 | return returned_probs, returned_indices 97 | elif return_probs: 98 | returned_probs = hlo.concatenate(probs, dimension=0) 99 | returned_tokens = hlo.concatenate(tokens, dimension=0) 100 | returned_indices = hlo.concatenate(rindices, dimension=0) 101 | return returned_tokens, returned_probs, returned_indices 102 | else: 103 | returned_tokens = hlo.concatenate(tokens, dimension=0) 104 | return returned_tokens 105 | 106 | 107 | def sampling_params_for_batch_line(seq_ids, batch_line: int, config: 
config.GenerationConfig): 108 | if seq_ids is not None: 109 | seq_id_for_batch = hlo.slice_along(seq_ids, 0, start=batch_line, limit=batch_line+1) 110 | batch_line_top_k = hlo.reshape(hlo.index_select(config.top_k, 0, seq_id_for_batch), []) 111 | batch_line_top_p = hlo.reshape(hlo.index_select(config.top_p, 0, seq_id_for_batch), []) 112 | batch_line_temperature = hlo.reshape(hlo.index_select(config.temperature, 0, seq_id_for_batch), []) 113 | batch_line_top_p_min_tokens = hlo.reshape(hlo.index_select(config.top_p_min_tokens, 0, seq_id_for_batch), []) 114 | else: 115 | batch_line_top_k = config.top_k if hlo._is_hlo_scalar(config.top_k) else hlo.get_hlo_scalar_by_index(config.top_k, batch_line) 116 | batch_line_top_p = config.top_p if hlo._is_hlo_scalar(config.top_p) else hlo.get_hlo_scalar_by_index(config.top_p, batch_line) 117 | batch_line_temperature = config.temperature if hlo._is_hlo_scalar(config.temperature) else hlo.get_hlo_scalar_by_index(config.temperature, batch_line) 118 | batch_line_top_p_min_tokens = config.top_p_min_tokens if hlo._is_hlo_scalar(config.top_p_min_tokens) else hlo.get_hlo_scalar_by_index(config.top_p_min_tokens, batch_line) 119 | return (batch_line_top_k, batch_line_top_p, batch_line_temperature, batch_line_top_p_min_tokens) 120 | 121 | 122 | def mask_logits(logits, indices, model_vocab_size): 123 | vocab_size, n_active_tokens, _ = logits.sizes 124 | indices_br = hlo.broadcast(indices, (logits.sizes), broadcast_dimensions=(0,)) 125 | mask = hlo.less(indices_br, model_vocab_size) 126 | logits = hlo.masked_select(mask, logits, float('-inf')) 127 | return logits 128 | 129 | 130 | def greedy_search(logits, *, tp_degree=1, permute=True): 131 | if permute: 132 | logits = hlo.permute(logits, (2, 1, 0)) 133 | batch_size, n_active_tokens, vocab_size = logits.sizes 134 | return hlo.argmax(logits, 2, tp_degree=tp_degree) # shape: batch_size, n_active_tokens 135 | 136 | 137 | def sample(logits, *, top_k=50, top_p=1.0, top_p_min_tokens=1, temperature=None, global_top_k=None, tp_degree=1, dynamic=False, deterministic=False, indices=None, permute=True, early_return=False, return_probs=False): 138 | 139 | if global_top_k is not None: 140 | assert dynamic is True, "Dynamic on device generation must be enabled when global_top_k is set." 
141 | 142 | if permute: 143 | logits = hlo.permute(logits, (2, 1, 0)) 144 | 145 | _, _, orig_vocab_size = logits.sizes 146 | 147 | if global_top_k is not None: 148 | logits, indices = hlo.topk(logits, k=global_top_k, dim=2, tp_degree=tp_degree) 149 | 150 | batch_size, n_active_tokens, vocab_size = logits.sizes 151 | 152 | # NOTE: Compiler failures can occur when batch != 1 153 | if top_k == 1 and batch_size == 1 and indices is None: 154 | tokens = greedy_search(logits, tp_degree=tp_degree, permute=False) 155 | if early_return or return_probs: 156 | batch_size, n_active_tokens = tokens.sizes 157 | probs = hlo.full(1.0, logits.dtype, [batch_size, n_active_tokens, 1]) 158 | indices = hlo.reshape(tokens, [batch_size, n_active_tokens, 1]) 159 | if early_return: 160 | return probs, indices 161 | else: # return probs 162 | return tokens, probs, indices 163 | else: 164 | return tokens 165 | 166 | if temperature is not None and temperature != 1.0: 167 | if hlo._is_hlo_scalar(temperature): 168 | temperature = hlo.cast(temperature, logits.dtype) 169 | logits = hlo.divide(logits, temperature) 170 | 171 | # Perform Top-K 172 | if top_k is not None: 173 | if dynamic: 174 | logits, index, indices = hlo.topk_masked(logits, k=top_k, dim=2, tp_degree=tp_degree, indices=indices) 175 | else: 176 | logits, indices = hlo.topk(logits, k=top_k, dim=2, tp_degree=tp_degree) 177 | 178 | # Perform Top-P 179 | if dynamic or top_p is not None and top_p < 1.0: 180 | logits, indices = hlo.topp(logits, top_p=top_p, top_p_min_tokens=top_p_min_tokens, tp_degree=tp_degree, indices=indices, dim=2) 181 | 182 | if indices is None: 183 | if tp_degree > 1: 184 | logits = hlo.all_gather(logits, dim=2, tp_degree=tp_degree) 185 | 186 | probs = hlo.softmax(logits, dim=2) 187 | 188 | if early_return: 189 | return probs, indices 190 | 191 | # Final sample after filtering TopP/TopK 192 | samples = hlo.multinomial(probs, dim=2, deterministic=deterministic) 193 | if indices is not None: 194 | tokens = hlo.gather(indices, 2, samples) 195 | else: 196 | tokens = samples 197 | if return_probs: 198 | return hlo.squeeze(tokens, 2), probs, indices 199 | return hlo.squeeze(tokens, 2) 200 | -------------------------------------------------------------------------------- /src/transformers_neuronx/layers/rotary.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon Web Services and its Affiliates. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | import torch 16 | import math 17 | from transformers_neuronx import hlo 18 | 19 | 20 | def apply_inv_frequency_scaling(freq, rope_scaling): 21 | scale_factor = rope_scaling.get('factor') 22 | low_freq_factor = rope_scaling.get('low_freq_factor') 23 | high_freq_factor = rope_scaling.get('high_freq_factor') 24 | old_context_len = rope_scaling.get('original_max_position_embeddings') 25 | 26 | low_freq_wavelen = old_context_len / low_freq_factor 27 | high_freq_wavelen = old_context_len / high_freq_factor 28 | assert low_freq_wavelen != high_freq_wavelen 29 | 30 | wavelen = 2 * math.pi / freq 31 | smooth = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor) 32 | 33 | new_freq = torch.where(wavelen < high_freq_wavelen, freq, freq/scale_factor) 34 | smooth_cond = torch.logical_and(wavelen >= high_freq_wavelen, wavelen <= low_freq_wavelen) 35 | new_freq = torch.where(smooth_cond, (1 - smooth) * freq / scale_factor + smooth * freq, new_freq) 36 | return new_freq.to(dtype=freq.dtype) 37 | 38 | 39 | def rotary_embedding(head_dim, cache_ids, base=10000, interpolation_factor=None): 40 | seq_len = cache_ids.shape[0] 41 | inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2) / head_dim)) 42 | t = torch.arange(seq_len, dtype=inv_freq.dtype) 43 | if interpolation_factor: 44 | t /= interpolation_factor 45 | sinusoid_inp = torch.einsum("i , j -> i j", t, inv_freq).float() 46 | sin = torch.sin(sinusoid_inp) 47 | cos = torch.cos(sinusoid_inp) 48 | pos_embd = torch.cat((sin, cos), dim=-1) 49 | return pos_embd 50 | 51 | 52 | def hlo_rotary_embedding(dtype, head_dim, cache_ids, base=10000, interpolation_factor=None, rope_scaling=None): 53 | 54 | scribe = cache_ids.scribe 55 | # Using f16 during compute causes relatively high error 56 | mtype = scribe.f32 57 | 58 | cache_ids = hlo.cast(cache_ids, mtype) 59 | 60 | use_2d_cache_ids = len(cache_ids.sizes) > 1 61 | if use_2d_cache_ids: 62 | batch_size, n_active_tokens = cache_ids.sizes # 2d cache_ids 63 | cache_ids = hlo.reshape(cache_ids, [batch_size, n_active_tokens, 1]) 64 | dot_dims = dict(lhs_contracting_dimensions=[2], rhs_contracting_dimensions=[0]) 65 | else: 66 | n_active_tokens, = cache_ids.sizes # 1d cache_ids 67 | cache_ids = hlo.reshape(cache_ids, [n_active_tokens, 1]) 68 | dot_dims = dict(lhs_contracting_dimensions=[1], rhs_contracting_dimensions=[0]) 69 | size = head_dim // 2 70 | 71 | inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2) / head_dim)) 72 | if rope_scaling is not None and rope_scaling.get("rope_type", rope_scaling.get("type", None)) == "llama3": 73 | inv_freq = apply_inv_frequency_scaling(inv_freq, rope_scaling) 74 | inv_freq = hlo.literal(mtype, inv_freq) 75 | 76 | if interpolation_factor: 77 | cache_ids = hlo.divide(cache_ids, interpolation_factor) 78 | 79 | inv_freq = hlo.reshape(inv_freq, (1, size)) 80 | sinusoid_inp = hlo.dot_general(cache_ids, inv_freq, dimension_numbers=dot_dims) 81 | 82 | sin = hlo.sin(sinusoid_inp) 83 | cos = hlo.cos(sinusoid_inp) 84 | sin = hlo.cast(sin, dtype) 85 | cos = hlo.cast(cos, dtype) 86 | return sin, cos 87 | 88 | 89 | def get_up_down(q): 90 | """ 91 | Given a tensor, returns its upper and lower halves (divided in the last dimension) 92 | """ 93 | head_dim = q.sizes[-1] 94 | q_up = hlo.slice_along(q, -1, head_dim//2) 95 | q_down = hlo.slice_along(q, -1, head_dim, head_dim//2) 96 | return q_up, q_down 97 | 98 | def get_up_down_with_percentage(q, percentage): 99 | """ 100 | 
Given a tensor, returns its upper and lower halves with given percentage (divided in the last dimension) 101 | """ 102 | head_dim = q.sizes[-1] 103 | q_up = hlo.slice_along(q, -1, int(head_dim * percentage)) 104 | q_down = hlo.slice_along(q, -1, head_dim, int(head_dim * percentage)) 105 | return q_up, q_down 106 | 107 | def rotate_vec(q, sin_r, cos_r, rotary_percentage=1): 108 | """ 109 | Given vectors q, sin, and cos tables, apply rotation to vectors 110 | """ 111 | if rotary_percentage == 1: 112 | q_up, q_down = get_up_down(q) 113 | q_rot_up = hlo.ax_minus_by(cos_r, q_up, sin_r, q_down) 114 | q_rot_down = hlo.ax_plus_by(cos_r, q_down, sin_r, q_up) 115 | q_rot = hlo.concatenate([q_rot_up, q_rot_down], dimension=3) 116 | return q_rot 117 | else: 118 | q_rotary, q_pass = get_up_down_with_percentage(q, rotary_percentage) 119 | q_rotary_up, q_rotary_down = get_up_down(q_rotary) 120 | q_rotary_rot_up = hlo.ax_minus_by(cos_r, q_rotary_up, sin_r, q_rotary_down) 121 | q_rotary_rot_down = hlo.ax_plus_by(cos_r, q_rotary_down, sin_r, q_rotary_up) 122 | q_rotary_rot = hlo.concatenate([q_rotary_rot_up, q_rotary_rot_down], dimension=3) 123 | return hlo.concatenate([q_rotary_rot, q_pass], dimension=3) 124 | 125 | 126 | def rotate_half(query, key, sin_cos, rotary_percentage=1, tp_degree=None, shard_over_batch=False): 127 | """ 128 | A secondary projection to apply to input query/key projections (used in 129 | specific models: GPT-J/GPT-NeoX/Llama). 130 | 131 | """ 132 | dtype = key.dtype 133 | if shard_over_batch: 134 | n_active_tokens, n_seqs_per_nc, n_kv_heads, d_head = active_sizes = key.sizes 135 | _, _, n_heads, _ = query.sizes 136 | broadcast_sizes = n_active_tokens, n_seqs_per_nc, n_heads, int((d_head // 2) * rotary_percentage) 137 | kv_broadcast_sizes = n_active_tokens, n_seqs_per_nc, n_kv_heads, int((d_head // 2) * rotary_percentage) 138 | else: 139 | n_active_tokens, n_seqs, n_kv_heads_tp, d_head = active_sizes = key.sizes 140 | _, _, n_heads_tp, _ = query.sizes 141 | 142 | """ 143 | Vector approach: 144 | | q_up cos - q_down sin | 145 | | q_up sin + q_down cos | 146 | """ 147 | # Rotate query and key 148 | broadcast_sizes = n_active_tokens, n_seqs, n_heads_tp, int((d_head // 2) * rotary_percentage) 149 | kv_broadcast_sizes = n_active_tokens, n_seqs, n_kv_heads_tp, int((d_head // 2) * rotary_percentage) 150 | 151 | def _broadcast_sin_cos(sin_cos, broadcast_sizes): 152 | sin, cos = sin_cos 153 | use_2d_cache_ids = len(sin.sizes) > 2 154 | if use_2d_cache_ids: 155 | # transpose from (n_seqs, n_active_tokens, d_head) to (n_active_tokens, n_seqs, d_head) 156 | sin_t = hlo.transpose(sin, 0, 1) 157 | cos_t = hlo.transpose(cos, 0, 1) 158 | # broadcast from (n_active_tokens, n_seqs, d_head) to (n_active_tokens, n_seqs, n_heads_tp, d_head) 159 | sin_r = hlo.broadcast(sin_t, broadcast_sizes, [0, 1, 3]) 160 | cos_r = hlo.broadcast(cos_t, broadcast_sizes, [0, 1, 3]) 161 | else: 162 | # 1D cache_ids 163 | sin_r = hlo.broadcast(sin, broadcast_sizes, [0, 3]) 164 | cos_r = hlo.broadcast(cos, broadcast_sizes, [0, 3]) 165 | return sin_r, cos_r 166 | 167 | # Get sin and cos as upper and lower half of input embedding 168 | sin_r, cos_r = _broadcast_sin_cos(sin_cos, broadcast_sizes) 169 | 170 | # Rotate query 171 | query = rotate_vec(query, sin_r, cos_r, rotary_percentage) 172 | 173 | # Get sin and cos as upper and lower half of input embedding 174 | kv_sin_r, kv_cos_r = _broadcast_sin_cos(sin_cos, kv_broadcast_sizes) 175 | 176 | # Rotate key 177 | key = rotate_vec(key, kv_sin_r, kv_cos_r, rotary_percentage) 178 | 
return query, key 179 | -------------------------------------------------------------------------------- /src/transformers_neuronx/llama/aem.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon Web Services and its Affiliates. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | import torch 16 | import os 17 | from transformers_neuronx import decoder 18 | from transformers_neuronx import module 19 | from transformers_neuronx import ops 20 | from transformers_neuronx import sampling 21 | from transformers_neuronx import utils 22 | from transformers_neuronx import bucket 23 | from transformers_neuronx import base 24 | from transformers_neuronx.llama.config import LlamaConfig 25 | from transformers_neuronx.llama.modules import LlamaForCausalLM 26 | from transformers_neuronx.llama.hlo import LlamaForSamplingNoEmbeddingHlo 27 | from transformers_neuronx.llama.model import LlamaForSampling 28 | 29 | 30 | class AEMLlamaForSampling(LlamaForSampling): 31 | 32 | def __init__(self, config, *, n_positions=2048, batch_size=1, amp='f32', tp_degree=2, 33 | context_length_estimate=None, context_unroll=None, unroll=None, 34 | neuron_config=None, aem_models=None, **kwargs): 35 | super().__init__(config, n_positions=n_positions, batch_size=batch_size, amp=amp, 36 | tp_degree=tp_degree, context_length_estimate=context_length_estimate, 37 | context_unroll=context_unroll, unroll=unroll, neuron_config=neuron_config, 38 | **kwargs) 39 | self.max_positions = self.token_buckets[-1] 40 | self.aem_models = aem_models 41 | 42 | def run_aems(self, hidden): 43 | # Run AEM LLMs in parallel <<--------------- 44 | return hidden 45 | 46 | def lmhead(self, hidden): 47 | # Run LLM HEAD <<--------------- 48 | return hidden 49 | 50 | def forward(self, input_ids, cache_ids=None, start_ids=None): 51 | 52 | batch_size, context_length = input_ids.shape 53 | if start_ids is None: 54 | start_ids = torch.zeros(batch_size, dtype=torch.int32) 55 | if cache_ids is None: 56 | cache_ids = torch.arange(context_length, dtype=torch.int32) 57 | 58 | hidden = self.chkpt_model.model.embed_tokens(input_ids) 59 | hidden = hidden.transpose(0, -1).contiguous() 60 | 61 | if context_length > 1: 62 | hidden = self.context(hidden, cache_ids, start_ids) 63 | else: 64 | hidden = self.decoder_lm_head(hidden, cache_ids, start_ids) 65 | 66 | # Apply AEMs to hidden here <<---------------------- 67 | aems = self.run_aems(hidden) 68 | logits = self.lmhead(hidden) 69 | 70 | logits = self._cast_logits(logits) 71 | # We also need to apply Llama's LM head 72 | logits = logits[:self.config.vocab_size, -1, :] 73 | logits = logits.transpose(0, 1) 74 | return logits 75 | 76 | 77 | def sample(self, input_ids, sequence_length, start_ids=None, 78 | top_k=50, top_p=1.0, eos_token_override=None, temperature=1.0, streamer=None): 79 | 80 | # To enable optimized context encoding 
network, we must pad 81 | # up to the context length estimate or we will not correctly 82 | # select the final context logits (See: layers/transformer.py). 83 | # This also means we need to shift the start_ids over to correct 84 | # for padding. 85 | offset = 0 86 | batch_size, context_length = input_ids.shape 87 | estimate = bucket.find(self.context_buckets, context_length) 88 | if estimate: 89 | if context_length < estimate: 90 | input_ids = utils.pad(input_ids, 1, estimate, left=True) 91 | offset = estimate - context_length 92 | if start_ids is None: 93 | start_ids = torch.zeros(batch_size, dtype=torch.int32) 94 | start_ids += offset 95 | sequence_length += offset 96 | # Sequence length cannot be greater than n_positions 97 | sequence_length = min(sequence_length, self.max_positions) 98 | 99 | # Change sample_llama in sampling.py to use AEM LLMs <------------------ 100 | result = sampling.sample_llama( 101 | self, input_ids, start_ids, sequence_length, 102 | eos_token_id=self.config.eos_token_id if eos_token_override is None else eos_token_override, 103 | top_k=top_k, top_p=top_p, temperature=temperature, streamer=streamer 104 | ) 105 | 106 | if offset != 0: 107 | result = result[:, offset:] 108 | return result 109 | -------------------------------------------------------------------------------- /src/transformers_neuronx/llama/config.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon Web Services and its Affiliates. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | from transformers_neuronx import utils 16 | 17 | class LlamaConfig: 18 | 19 | def __init__( 20 | self, 21 | config, 22 | n_positions, 23 | batch_size, 24 | amp, 25 | tp_degree, 26 | **kwargs 27 | ): 28 | 29 | # Extract configs used for building HLO 30 | self.intermediate_size = config.intermediate_size 31 | self.hidden_size = config.hidden_size 32 | self.attention_head_size = config.hidden_size // config.num_attention_heads 33 | self.num_attention_heads = config.num_attention_heads 34 | self.num_key_value_heads = config.num_key_value_heads if hasattr(config, "num_key_value_heads") else config.num_attention_heads 35 | self.num_hidden_layers = config.num_hidden_layers 36 | self.vocab_size = config.vocab_size 37 | self.hidden_act = config.hidden_act 38 | self.bos_token_id = config.bos_token_id 39 | self.eos_token_id = config.eos_token_id 40 | self.max_position_embeddings = config.max_position_embeddings 41 | self.rms_norm_eps = config.rms_norm_eps 42 | self.rotary_percentage = getattr(config, "rotary_percentage", 1) 43 | self.rope_theta = getattr(config, "rope_theta", 10000) 44 | self.position_interpolation_factor = getattr(config, "position_interpolation_factor", None) 45 | self.rope_scaling = getattr(config, "rope_scaling", None) 46 | rope_scaling_type = self.rope_scaling.get("rope_type", self.rope_scaling.get("type", None)) if self.rope_scaling is not None else None 47 | if self.rope_scaling is not None and rope_scaling_type not in {'default', 'llama3'}: 48 | raise ValueError(f"Only default and llama3 rope scaling types are currently supported. Received {rope_scaling_type}") 49 | self.is_eagle = getattr(config, "is_eagle", False) 50 | self.bias = getattr(config, "bias", True) 51 | 52 | utils.maybe_override_attributes(self, kwargs) 53 | 54 | # Add required Neuron configs 55 | self.n_positions = n_positions 56 | self.batch_size = batch_size 57 | self.amp = amp 58 | self.tp_degree = tp_degree 59 | self.model_type = 'llama' 60 | 61 | -------------------------------------------------------------------------------- /src/transformers_neuronx/llama/modules.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon Web Services and its Affiliates. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License.
14 | # ============================================================================== 15 | from transformers_neuronx import dtypes 16 | from transformers_neuronx import module 17 | from transformers_neuronx import utils 18 | 19 | 20 | class LlamaForCausalLM(module.PretrainedModel): 21 | 22 | def __init__(self, config): 23 | super().__init__() 24 | dtype, _, _ = utils.parse_amp(config.amp) 25 | dtype = dtypes.to_torch_dtype(dtype) 26 | self.model = LlamaModel(config) if not config.is_eagle else EagleLlamaModel(config, config.bias) 27 | self.lm_head = module.LowMemoryLazyLinear(config.vocab_size, dtype=dtype, bias=False) 28 | 29 | def get_tied_parameters(self): 30 | return [(self.model.embed_tokens.weight, self.lm_head.weight)] 31 | 32 | def get_base_model(self): 33 | return self.model 34 | 35 | 36 | class LlamaModel(module.LowMemoryModule): 37 | 38 | def __init__(self, config): 39 | super().__init__() 40 | self.embed_tokens = module.LowMemoryEmbedding(config.vocab_size, config.hidden_size) 41 | self.layers = module.LowMemoryModuleList([LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)]) 42 | self.norm = LlamaRMSNorm(config) 43 | 44 | 45 | class LlamaRMSNorm(module.LowMemoryModule): 46 | 47 | def __init__(self, config) -> None: 48 | super().__init__() 49 | self.weight = module.UninitializedParameter() 50 | 51 | 52 | class LlamaDecoderLayer(module.LowMemoryModule): 53 | 54 | def __init__(self, config): 55 | super().__init__() 56 | self.self_attn = LlamaAttention(config) 57 | self.mlp = LlamaMLP(config) 58 | self.input_layernorm = LlamaRMSNorm(config) 59 | self.post_attention_layernorm = LlamaRMSNorm(config) 60 | 61 | 62 | class LlamaAttention(module.LowMemoryModule): 63 | 64 | def __init__(self, config): 65 | super().__init__() 66 | self.hidden_size = config.hidden_size 67 | self.num_heads = config.num_attention_heads 68 | self.head_dim = self.hidden_size // self.num_heads 69 | dtype, _, _ = utils.parse_amp(config.amp) 70 | dtype = dtypes.to_torch_dtype(dtype) 71 | self.q_proj = module.LowMemoryLazyLinear(self.num_heads * self.head_dim, bias=False, dtype=dtype) 72 | self.k_proj = module.LowMemoryLazyLinear(self.num_heads * self.head_dim, bias=False, dtype=dtype) 73 | self.v_proj = module.LowMemoryLazyLinear(self.num_heads * self.head_dim, bias=False, dtype=dtype) 74 | self.o_proj = module.LowMemoryLazyLinear(self.hidden_size, bias=False, dtype=dtype) 75 | 76 | 77 | class LlamaMLP(module.LowMemoryModule): 78 | 79 | def __init__(self, config): 80 | super().__init__() 81 | dtype, _, _ = utils.parse_amp(config.amp) 82 | dtype = dtypes.to_torch_dtype(dtype) 83 | self.gate_proj = module.LowMemoryLazyLinear(config.intermediate_size, bias=False, dtype=dtype) 84 | self.up_proj = module.LowMemoryLazyLinear(config.intermediate_size, bias=False, dtype=dtype) 85 | self.down_proj = module.LowMemoryLazyLinear(config.hidden_size, bias=False, dtype=dtype) 86 | 87 | class EagleLlamaModel(module.LowMemoryModule): 88 | 89 | def __init__(self, config, bias=True): 90 | super().__init__() 91 | self.embed_tokens = module.LowMemoryEmbedding(config.vocab_size, config.hidden_size) 92 | self.layers = module.LowMemoryModuleList([LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)]) 93 | self.norm = LlamaRMSNorm(config) 94 | # An extra fc layer before attention layers 95 | dtype, _, _ = utils.parse_amp(config.amp) 96 | dtype = dtypes.to_torch_dtype(dtype) 97 | self.fc = module.LowMemoryLazyLinear(config.hidden_size*2, bias=bias, dtype=dtype) 98 | 
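The Llama modules above are consumed through the same load/compile/sample flow that gpt_demo.py drives for GPT-2: construct the Neuron model with from_pretrained, compile it with to_neuron, then call sample. Below is a minimal usage sketch of that flow for LlamaForSampling (the class that llama/aem.py imports from transformers_neuronx.llama.model); the checkpoint directory, amp, tp_degree, and sampling values are illustrative assumptions, not values taken from this repository.
```
# Minimal usage sketch (assumed paths and parameters), mirroring the run() flow in gpt_demo.py.
import torch
from transformers import AutoTokenizer
from transformers_neuronx.llama.model import LlamaForSampling

model_dir = './llama-split'   # assumption: a checkpoint directory produced by save_pretrained_split
neuron_model = LlamaForSampling.from_pretrained(
    model_dir, batch_size=1, amp='f16', tp_degree=2, n_positions=2048)
neuron_model.to_neuron()      # compile the decoder and load weights onto NeuronCores

tokenizer = AutoTokenizer.from_pretrained(model_dir)
input_ids = torch.as_tensor([tokenizer.encode("Hello, I'm a language model,")])
with torch.inference_mode():
    generated = neuron_model.sample(input_ids, sequence_length=128, top_k=50)
print([tokenizer.decode(seq) for seq in generated])
```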
-------------------------------------------------------------------------------- /src/transformers_neuronx/mistral/config.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon Web Services and its Affiliates. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | from transformers_neuronx import utils 16 | 17 | class MistralConfig: 18 | 19 | def __init__( 20 | self, 21 | config, 22 | n_positions, 23 | batch_size, 24 | amp, 25 | tp_degree, 26 | **kwargs 27 | ): 28 | 29 | # Extract configs used for building HLO 30 | self.intermediate_size = config.intermediate_size 31 | self.hidden_size = config.hidden_size 32 | self.attention_head_size = config.hidden_size // config.num_attention_heads 33 | self.num_attention_heads = config.num_attention_heads 34 | self.num_key_value_heads = config.num_key_value_heads if hasattr(config, "num_key_value_heads") else config.num_attention_heads 35 | self.num_hidden_layers = config.num_hidden_layers 36 | self.vocab_size = config.vocab_size 37 | self.hidden_act = config.hidden_act 38 | self.bos_token_id = config.bos_token_id 39 | self.eos_token_id = config.eos_token_id 40 | self.max_position_embeddings = config.max_position_embeddings 41 | self.rms_norm_eps = config.rms_norm_eps 42 | self.rotary_percentage = getattr(config, "rotary_percentage", 1) 43 | self.rope_theta = getattr(config, "rope_theta", 10000) 44 | self.position_interpolation_factor = getattr(config, "position_interpolation_factor", None) 45 | self.sliding_window = getattr(config, "sliding_window", None) 46 | 47 | utils.maybe_override_attributes(self, kwargs) 48 | 49 | # Add required Neuron configs 50 | self.n_positions = n_positions 51 | self.batch_size = batch_size 52 | self.amp = amp 53 | self.tp_degree = tp_degree 54 | self.model_type = 'mistral' 55 | 56 | -------------------------------------------------------------------------------- /src/transformers_neuronx/mistral/model.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon Web Services and its Affiliates. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | from transformers_neuronx import decoder 16 | from transformers_neuronx import sampling 17 | from transformers_neuronx import bucket 18 | from transformers_neuronx import base 19 | from transformers_neuronx.constants import LAYOUT_HSB 20 | from transformers_neuronx.config import NeuronConfig 21 | from transformers_neuronx.mistral.config import MistralConfig 22 | from transformers_neuronx.mistral.modules import MistralForCausalLM 23 | from transformers_neuronx.mistral.hlo import MistralForSamplingNoEmbeddingHlo 24 | 25 | 26 | class MistralForSampling(base.NeuronModelBase): 27 | 28 | def __init__(self, config, *, n_positions=2048, batch_size=1, amp='f32', tp_degree=2, 29 | context_length_estimate=None, context_unroll=None, unroll=None, 30 | neuron_config=NeuronConfig(), **kwargs): 31 | config = MistralConfig(config, n_positions, batch_size, amp, tp_degree) 32 | super().__init__(MistralForCausalLM, config) 33 | self.context_pre_hook = None 34 | self.context_hook = None 35 | self.config = config 36 | self.neuron_config = neuron_config if neuron_config else NeuronConfig() 37 | if self.neuron_config.on_device_generation: 38 | self.neuron_config.on_device_generation.vocab_size = self.config.vocab_size 39 | if context_unroll is None: 40 | context_unroll = config.num_hidden_layers 41 | self.context_unroll = context_unroll 42 | 43 | if unroll is None: 44 | unroll = config.num_hidden_layers 45 | self.unroll=unroll 46 | 47 | self.token_buckets = bucket.token_sizes(n_positions) 48 | self.context_buckets = bucket.context_sizes(context_length_estimate, self.token_buckets) 49 | 50 | self.batch_sizes = bucket.batch_sizes(batch_size) 51 | self.context_batch_sizes = [1] if self.neuron_config and self.neuron_config.continuous_batching else self.batch_sizes 52 | hlo_builder = MistralForSamplingNoEmbeddingHlo(config, neuron_config=self.neuron_config) 53 | self.decoder_param_set = decoder.DecoderLmHeadForSamplingNoEmbedding( 54 | tp_degree=tp_degree, n_positions_list=self.token_buckets, n_active_tokens=1, batch_size=self.batch_sizes, 55 | attention_head_size=config.attention_head_size, amp=amp, 56 | num_layers=config.num_hidden_layers, n_head=config.num_attention_heads, n_kv_head=config.num_key_value_heads, 57 | unroll=unroll, neuron_config=self.neuron_config, allow_pad=True, 58 | builder=hlo_builder 59 | ) 60 | self.decoder_lm_head = self.decoder_param_set.init_token_decoder(unroll=self.unroll, buckets=self.token_buckets, model_obj=self) 61 | self.decoder_lm_head_for_context = self.decoder_param_set.init_context_decoder(unroll=self.context_unroll, buckets=self.context_buckets, model_obj=self) 62 | 63 | def load_weights(self): 64 | self.materialize_embeddings() 65 | 66 | for layer in self.chkpt_model.model.layers: 67 | layer.materialize() 68 | attn = layer.self_attn 69 | mlp = layer.mlp 70 | new_layer = self.decoder_lm_head.new_layer() 71 | new_layer.add_pre_attention_layer_norm(layer.input_layernorm.weight.detach(), None) 72 | new_layer.add_attention_query(attn.q_proj.weight.detach().T, None) 73 | new_layer.add_attention_key(attn.k_proj.weight.detach().T, None) 74 | new_layer.add_attention_value(attn.v_proj.weight.detach().T, None) 75 | new_layer.add_attention_output(attn.o_proj.weight.detach(), None, sharding=1, transposed=False) 76 | new_layer.add_pre_mlp_layer_norm(layer.post_attention_layernorm.weight.detach(), None) 77 | 78 | # Note: Automatic MLP padding is safe since zeros are *only* introduced to intermediary 
state 79 | new_layer.add_parameter(mlp.gate_proj.weight.T, sharding=1, allow_pad=True, 80 | allow_quantize=True, allow_transform=True) 81 | new_layer.add_parameter(mlp.up_proj.weight.T, sharding=1, allow_pad=True, 82 | allow_quantize=True, allow_transform=True) 83 | if self.neuron_config.weight_tiling: 84 | new_layer.add_parameter(mlp.down_proj.weight.T, sharding=0, allow_pad=True, 85 | allow_quantize=True, allow_transform=True) 86 | else: 87 | new_layer.add_parameter(mlp.down_proj.weight, sharding=1, allow_pad=True, 88 | allow_quantize=True, out_feature_dim=0) 89 | 90 | new_layer.to_neuron() 91 | layer.nullify() 92 | 93 | ln_f = self.chkpt_model.model.norm 94 | ln_f.materialize() 95 | self.decoder_lm_head.add_final_layer_norm(ln_f.weight.detach(), None) 96 | ln_f.nullify() 97 | 98 | lm_head = self.chkpt_model.lm_head 99 | lm_head.materialize() 100 | self.decoder_lm_head.add_lm_head(lm_head.weight.detach().T) 101 | if self.neuron_config.on_device_embedding: 102 | self.decoder_lm_head.add_pre_layer_parameter(self.chkpt_model.model.embed_tokens.weight, sharding=1, allow_pad=True) 103 | lm_head.nullify() 104 | 105 | self.decoder_lm_head.to_neuron() 106 | self.init_rest_of_model() 107 | self.maybe_nullify_embeddings() 108 | 109 | def materialize_embeddings(self): 110 | # Materialize the embedding to CPU 111 | self.chkpt_model.model.embed_tokens.materialize() 112 | 113 | def maybe_nullify_embeddings(self): 114 | if self.neuron_config.on_device_embedding: 115 | self.chkpt_model.model.embed_tokens.nullify() 116 | 117 | def init_rest_of_model(self): 118 | self.decoder_lm_head.use_executor = True 119 | 120 | if self.context_buckets: 121 | for context_length_estimate in self.context_buckets: 122 | for batch_size in self.context_batch_sizes: 123 | model = self.decoder_lm_head.build_weight_shared(share_caches=True, 124 | new=self.decoder_lm_head_for_context[context_length_estimate, batch_size]) 125 | # PERF: No latency improvement seen in multi-layer models from executor 126 | if self.context_unroll == self.config.num_hidden_layers: 127 | model.use_executor = True 128 | self.decoder_lm_head_for_context[context_length_estimate,batch_size] = model 129 | 130 | def forward(self, input_ids, cache_ids=None, start_ids=None): 131 | inputs, *rst = self._preprocess(input_ids, start_ids=start_ids, cache_ids=cache_ids) 132 | if not self.neuron_config.on_device_embedding: 133 | inputs = self.chkpt_model.model.embed_tokens(inputs) 134 | if self.neuron_config.attention_layout == LAYOUT_HSB: 135 | inputs = inputs.transpose(0, -1).contiguous() 136 | logits = self._forward(inputs, *rst) 137 | logits = self._postprocess(input_ids, logits, start_ids=start_ids) 138 | 139 | return logits 140 | 141 | def sample(self, input_ids, sequence_length, start_ids=None, 142 | top_k=50, top_p=1.0, eos_token_override=None, temperature=1.0, streamer=None, stopping_criteria_list=None): 143 | 144 | if self.neuron_config.on_device_generation: 145 | return sampling.sample_tokens(self, input_ids, start_ids, sequence_length=sequence_length, 146 | config=self.neuron_config.on_device_generation, streamer=streamer) 147 | 148 | if self.context_pre_hook is not None: 149 | self.context_pre_hook() 150 | batch_size, context_length = input_ids.shape 151 | if batch_size not in self.batch_sizes: 152 | raise ValueError(f"Model not compiled for batch_size : {batch_size}. 
Acceptable batch_size is one of the following {self.batch_sizes}") 153 | 154 | result = sampling.sample_llama( 155 | self, input_ids, start_ids, sequence_length, 156 | eos_token_id=self.config.eos_token_id if eos_token_override is None else eos_token_override, 157 | top_k=top_k, top_p=top_p, temperature=temperature, streamer=streamer, 158 | stopping_criteria_list=stopping_criteria_list 159 | ) 160 | 161 | return result 162 | -------------------------------------------------------------------------------- /src/transformers_neuronx/mistral/modules.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon Web Services and its Affiliates. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | from transformers_neuronx import dtypes 16 | from transformers_neuronx import module 17 | from transformers_neuronx import utils 18 | 19 | 20 | class MistralForCausalLM(module.PretrainedModel): 21 | 22 | def __init__(self, config): 23 | super().__init__() 24 | dtype, _, _ = utils.parse_amp(config.amp) 25 | dtype = dtypes.to_torch_dtype(dtype) 26 | self.model = MistralModel(config) 27 | self.lm_head = module.LowMemoryLazyLinear(config.vocab_size, dtype=dtype, bias=False) 28 | 29 | def get_tied_parameters(self): 30 | return [(self.model.embed_tokens.weight, self.lm_head.weight)] 31 | 32 | def get_base_model(self): 33 | return self.model 34 | 35 | 36 | class MistralModel(module.LowMemoryModule): 37 | 38 | def __init__(self, config): 39 | super().__init__() 40 | self.embed_tokens = module.LowMemoryEmbedding(config.vocab_size, config.hidden_size) 41 | self.layers = module.LowMemoryModuleList([MistralDecoderLayer(config) for _ in range(config.num_hidden_layers)]) 42 | self.norm = MistralRMSNorm(config) 43 | 44 | 45 | class MistralRMSNorm(module.LowMemoryModule): 46 | 47 | def __init__(self, config) -> None: 48 | super().__init__() 49 | self.weight = module.UninitializedParameter() 50 | 51 | 52 | class MistralDecoderLayer(module.LowMemoryModule): 53 | 54 | def __init__(self, config): 55 | super().__init__() 56 | self.self_attn = MistralAttention(config) 57 | self.mlp = MistralMLP(config) 58 | self.input_layernorm = MistralRMSNorm(config) 59 | self.post_attention_layernorm = MistralRMSNorm(config) 60 | 61 | 62 | class MistralAttention(module.LowMemoryModule): 63 | 64 | def __init__(self, config): 65 | super().__init__() 66 | self.hidden_size = config.hidden_size 67 | self.num_heads = config.num_attention_heads 68 | self.head_dim = self.hidden_size // self.num_heads 69 | dtype, _, _ = utils.parse_amp(config.amp) 70 | dtype = dtypes.to_torch_dtype(dtype) 71 | self.q_proj = module.LowMemoryLazyLinear(self.num_heads * self.head_dim, bias=False, dtype=dtype) 72 | self.k_proj = module.LowMemoryLazyLinear(self.num_heads * self.head_dim, bias=False, dtype=dtype) 73 | self.v_proj = module.LowMemoryLazyLinear(self.num_heads * 
self.head_dim, bias=False, dtype=dtype) 74 | self.o_proj = module.LowMemoryLazyLinear(self.hidden_size, bias=False, dtype=dtype) 75 | 76 | 77 | class MistralMLP(module.LowMemoryModule): 78 | 79 | def __init__(self, config): 80 | super().__init__() 81 | dtype, _, _ = utils.parse_amp(config.amp) 82 | dtype = dtypes.to_torch_dtype(dtype) 83 | self.gate_proj = module.LowMemoryLazyLinear(config.intermediate_size, bias=False, dtype=dtype) 84 | self.up_proj = module.LowMemoryLazyLinear(config.intermediate_size, bias=False, dtype=dtype) 85 | self.down_proj = module.LowMemoryLazyLinear(config.hidden_size, bias=False, dtype=dtype) 86 | -------------------------------------------------------------------------------- /src/transformers_neuronx/mixtral/config.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon Web Services and its Affiliates. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | from transformers_neuronx import utils 16 | 17 | class MixtralConfig: 18 | 19 | def __init__( 20 | self, 21 | config, 22 | n_positions, 23 | batch_size, 24 | amp, 25 | tp_degree, 26 | **kwargs 27 | ): 28 | 29 | # Extract configs used for building HLO 30 | self.intermediate_size = config.intermediate_size 31 | self.hidden_size = config.hidden_size 32 | self.attention_head_size = config.hidden_size // config.num_attention_heads 33 | self.num_attention_heads = config.num_attention_heads 34 | self.num_key_value_heads = config.num_key_value_heads if hasattr(config, "num_key_value_heads") else config.num_attention_heads 35 | self.num_hidden_layers = config.num_hidden_layers 36 | self.vocab_size = config.vocab_size 37 | self.hidden_act = config.hidden_act 38 | self.bos_token_id = config.bos_token_id 39 | self.eos_token_id = config.eos_token_id 40 | self.max_position_embeddings = config.max_position_embeddings 41 | self.rms_norm_eps = config.rms_norm_eps 42 | self.rotary_percentage = getattr(config, "rotary_percentage", 1) 43 | self.rope_theta = getattr(config, "rope_theta", 1e6) 44 | self.position_interpolation_factor = getattr(config, "position_interpolation_factor", None) 45 | self.sliding_window = getattr(config, "sliding_window", None) 46 | self.num_experts_per_tok = config.num_experts_per_tok 47 | self.num_local_experts = config.num_local_experts 48 | 49 | utils.maybe_override_attributes(self, kwargs) 50 | 51 | # Add required Neuron configs 52 | self.n_positions = n_positions 53 | self.batch_size = batch_size 54 | self.amp = amp 55 | self.tp_degree = tp_degree 56 | self.model_type = 'mixtral' 57 | 58 | # Check values of tp_degree 59 | # The MoE implementation supports 1) tp_degree is divisible by num_local_experts or 2) num_local_experts is divisible by tp_degree 60 | # However, due to memory limit, only tp_degree = {8, 16, 32} are supported. 
Note that tp_degree = 8 needs to use f16 or bf16 61 | if (self.num_local_experts % self.tp_degree != 0) and (self.tp_degree % self.num_local_experts != 0): 62 | raise ValueError(f"tp_degree needs to be 8, 16 or 32. Use f16 or bf16 or tp_degree = 8.") 63 | 64 | -------------------------------------------------------------------------------- /src/transformers_neuronx/mixtral/hlo.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon Web Services and its Affiliates. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | from typing import Optional 16 | 17 | from transformers_neuronx import hlo 18 | from transformers_neuronx import utils 19 | from transformers_neuronx.layers import transformer 20 | from transformers_neuronx.mistral.hlo import MistralForSamplingNoEmbeddingHlo 21 | from transformers_neuronx.mixtral.config import MixtralConfig 22 | from transformers_neuronx.config import NeuronConfig 23 | from transformers_neuronx.constants import LAYOUT_BSH, LAYOUT_HSB 24 | from transformers_neuronx.sparse_attn_utils import build_sliding_window_mask 25 | 26 | class MixtralForSamplingNoEmbeddingHlo(MistralForSamplingNoEmbeddingHlo): 27 | 28 | def __init__(self, 29 | config: MixtralConfig, 30 | neuron_config: Optional[NeuronConfig] = None 31 | ): 32 | super().__init__(config, neuron_config) 33 | is_bsh = self.neuron_config and self.neuron_config.attention_layout == LAYOUT_BSH 34 | assert not is_bsh, "BSH layout is currently not supported for moe_layer" 35 | assert str(MistralForSamplingNoEmbeddingHlo.attention) == str(self.attention.__func__), \ 36 | "The self.attention() function should be derived from MistralForSamplingNoEmbeddingHlo.attention()" 37 | 38 | def layer( 39 | self, 40 | # hidden in layer_builder (from decoder.py) 41 | hidden, 42 | # tensors in layer_builder (from decoder.py) 43 | last_token_id, pos_embed, cache_ids, start_ids, mask, active_mask, 44 | # in_caches in layer_builder (from decoder.py) 45 | attn_k_cache, attn_v_cache, 46 | # weights in layer_builder (from decoder.py) 47 | pre_attn_ln_weight, pre_attn_ln_bias, 48 | # attention 49 | attn_q_weight, attn_q_scales, attn_q_bias, 50 | attn_k_weight, attn_k_scales, attn_k_bias, 51 | attn_v_weight, attn_v_scales, attn_v_bias, 52 | attn_out_weight, attn_out_scales, attn_out_bias, 53 | # rms_norm 54 | post_attn_ln_weight, post_attn_ln_bias, 55 | pre_mlp_ln_weight, pre_mlp_ln_bias, 56 | # placeholder 57 | mlp_in_weight, mlp_in_scales, mlp_in_bias, 58 | mlp_out_weight, mlp_out_scales, mlp_out_bias, 59 | post_mlp_ln_weight, post_mlp_ln_bias, 60 | # gating network and experts for MoE 61 | expert_indices, gate_weight, 62 | w1_weight_tp, w1_scales, w2_weight_tp, w2_scales, w3_weight_tp, w3_scales 63 | ): 64 | 65 | eps = self.config.rms_norm_eps 66 | is_bsh = self.neuron_config and self.neuron_config.attention_layout == LAYOUT_BSH 67 | ln_hidden = 
hlo.rms_norm(hidden, pre_attn_ln_weight, eps) if is_bsh else hlo.rms_norm(hidden, pre_attn_ln_weight, eps, dim=0) 68 | attn_output, out_attn_k_cache, out_attn_v_cache = self.attention( 69 | ln_hidden, cache_ids, start_ids, pos_embed, mask, active_mask, 70 | attn_k_cache, attn_v_cache, 71 | attn_q_weight, attn_q_scales, attn_q_bias, 72 | attn_k_weight, attn_k_scales, attn_k_bias, 73 | attn_v_weight, attn_v_scales, attn_v_bias, 74 | attn_out_weight, attn_out_scales, attn_out_bias 75 | ) 76 | hidden = hlo.add(attn_output, hidden) 77 | rms_norm_dim = 2 if is_bsh else 0 78 | norm_hidden = hlo.rms_norm(hidden, pre_mlp_ln_weight, eps, dim=rms_norm_dim) 79 | 80 | final_hidden_states = self.moe_layer(norm_hidden, expert_indices, gate_weight, \ 81 | w1_weight_tp, w1_scales, w2_weight_tp, w2_scales, w3_weight_tp, w3_scales) 82 | 83 | res_hidden = hlo.add(final_hidden_states, hidden) 84 | return res_hidden, out_attn_k_cache, out_attn_v_cache 85 | 86 | def ln_lm_head(self, hidden, last_token_id, rms_weight, unused_bias, lm_head_weight, lm_head_bias, return_all_outputs=True): 87 | return transformer.rms_lm_head(self.config.tp_degree, hidden, last_token_id, rms_weight, lm_head_weight, lm_head_bias, return_all_outputs, eps=self.config.rms_norm_eps, neuron_config=self.neuron_config) 88 | 89 | def moe_layer( 90 | self, 91 | norm_hidden, expert_indices, gate_weight, 92 | w1_weight_tp, w1_scales, w2_weight_tp, w2_scales, w3_weight_tp, w3_scales 93 | ): 94 | is_bsh = self.neuron_config and self.neuron_config.attention_layout == LAYOUT_BSH 95 | assert not is_bsh, "BSH layout is currently not supported for moe_layer" 96 | gated_mlp = hlo.gated_mlp_bsh if is_bsh else hlo.gated_mlp 97 | 98 | # Gating network 99 | dot_dims = dict(lhs_contracting_dimensions=[0], rhs_contracting_dimensions=[0]) 100 | router_logits = hlo.dot_general(gate_weight, norm_hidden, dot_dims) 101 | routing_weights = hlo.softmax(router_logits, dim=0) 102 | routing_weights, selected_experts = hlo.topk(routing_weights, k=self.config.num_experts_per_tok, dim=0) 103 | 104 | # Normalize weights of activated experts 105 | routing_weights_sum = hlo.reduce_sum(routing_weights, dim=0, keepdim=True) 106 | routing_weights_sum_br = hlo.broadcast(routing_weights_sum, routing_weights.sizes, [0, 1, 2]) 107 | routing_weights = hlo.divide(routing_weights, routing_weights_sum_br) 108 | 109 | # Following expert parallelism implement in https://github.com/vllm-project/vllm/pull/2090 110 | num_experts_per_core = expert_indices.sizes[0] 111 | _, intermediate_size = w1_weight_tp.sizes 112 | slice_size = intermediate_size // num_experts_per_core 113 | slice_size_const = hlo.full(slice_size, dtype=expert_indices.dtype, sizes=[]) 114 | 115 | local_hidden_states = None 116 | for idx in range(num_experts_per_core): 117 | idx_const = hlo.full(idx, dtype=expert_indices.dtype, sizes=[]) 118 | 119 | # Slice weight for tp < num_local_experts 120 | slice_idx_const = hlo.multiply(idx_const, slice_size_const) 121 | w1_weight = hlo.dynamic_slice_along(w1_weight_tp, dim=1, start=slice_idx_const, size=slice_size) 122 | w3_weight = hlo.dynamic_slice_along(w3_weight_tp, dim=1, start=slice_idx_const, size=slice_size) 123 | w2_weight = hlo.dynamic_slice_along(w2_weight_tp, dim=1, start=slice_idx_const, size=slice_size) 124 | 125 | # Build expert mask 126 | expert_idx = hlo.dynamic_slice_along(expert_indices, dim=0, start=idx_const, size=1) 127 | expert_idx_br = hlo.broadcast(expert_idx, selected_experts.sizes, [0]) 128 | expert_idx_br = hlo.cast(expert_idx_br, selected_experts.dtype) 
129 | expert_mask = hlo.equal(selected_experts, expert_idx_br) 130 | expert_mask = hlo.cast(expert_mask, routing_weights.dtype) 131 | expert_weights = hlo.multiply(routing_weights, expert_mask) 132 | expert_weights = hlo.reduce_sum(expert_weights, dim=0, keepdim=True) # all-reduce across selected experts 133 | 134 | mlp_hidden = gated_mlp( 135 | norm_hidden, 136 | w1_weight, w3_weight, w2_weight, 137 | in0_scales=w1_scales, 138 | in1_scales=w3_scales, 139 | out_scales=w2_scales, 140 | activation_function='silu', 141 | tp_degree=self.config.tp_degree, 142 | neuron_config=self.neuron_config, 143 | return_partial=True, 144 | ) 145 | # Apply expert weighting 146 | expert_weights_br = hlo.broadcast(expert_weights, mlp_hidden.sizes, [0, 1, 2]) 147 | current_hidden_states = hlo.multiply(mlp_hidden, expert_weights_br) 148 | 149 | if local_hidden_states is None: 150 | local_hidden_states = current_hidden_states 151 | else: 152 | local_hidden_states = hlo.add(local_hidden_states, current_hidden_states) 153 | 154 | dtype, replica_groups = utils.parse_dtype_replica_groups(self.neuron_config, self.config.tp_degree) 155 | final_hidden_states = hlo.all_reduce_sum(local_hidden_states, self.config.tp_degree, dtype=dtype, replica_groups=replica_groups) 156 | 157 | return final_hidden_states 158 | -------------------------------------------------------------------------------- /src/transformers_neuronx/mixtral/modules.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon Web Services and its Affiliates. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | from transformers_neuronx import dtypes 16 | from transformers_neuronx import module 17 | from transformers_neuronx import utils 18 | 19 | 20 | class MixtralForCausalLM(module.PretrainedModel): 21 | 22 | def __init__(self, config): 23 | super().__init__() 24 | dtype, _, _ = utils.parse_amp(config.amp) 25 | dtype = dtypes.to_torch_dtype(dtype) 26 | self.model = MixtralModel(config) 27 | self.lm_head = module.LowMemoryLazyLinear(config.vocab_size, dtype=dtype, bias=False) 28 | 29 | def get_tied_parameters(self): 30 | return [(self.model.embed_tokens.weight, self.lm_head.weight)] 31 | 32 | def get_base_model(self): 33 | return self.model 34 | 35 | 36 | class MixtralModel(module.LowMemoryModule): 37 | 38 | def __init__(self, config): 39 | super().__init__() 40 | self.embed_tokens = module.LowMemoryEmbedding(config.vocab_size, config.hidden_size) 41 | self.layers = module.LowMemoryModuleList([MixtralDecoderLayer(config) for _ in range(config.num_hidden_layers)]) 42 | self.norm = MixtralRMSNorm(config) 43 | 44 | 45 | class MixtralRMSNorm(module.LowMemoryModule): 46 | 47 | def __init__(self, config) -> None: 48 | super().__init__() 49 | self.weight = module.UninitializedParameter() 50 | 51 | 52 | class MixtralDecoderLayer(module.LowMemoryModule): 53 | 54 | def __init__(self, config): 55 | super().__init__() 56 | self.self_attn = MixtralAttention(config) 57 | self.block_sparse_moe = MixtralSparseMoeBlock(config) 58 | self.input_layernorm = MixtralRMSNorm(config) 59 | self.post_attention_layernorm = MixtralRMSNorm(config) 60 | 61 | 62 | class MixtralAttention(module.LowMemoryModule): 63 | 64 | def __init__(self, config): 65 | super().__init__() 66 | self.hidden_size = config.hidden_size 67 | self.num_heads = config.num_attention_heads 68 | self.head_dim = self.hidden_size // self.num_heads 69 | dtype, _, _ = utils.parse_amp(config.amp) 70 | dtype = dtypes.to_torch_dtype(dtype) 71 | self.q_proj = module.LowMemoryLazyLinear(self.num_heads * self.head_dim, bias=False, dtype=dtype) 72 | self.k_proj = module.LowMemoryLazyLinear(self.num_heads * self.head_dim, bias=False, dtype=dtype) 73 | self.v_proj = module.LowMemoryLazyLinear(self.num_heads * self.head_dim, bias=False, dtype=dtype) 74 | self.o_proj = module.LowMemoryLazyLinear(self.hidden_size, bias=False, dtype=dtype) 75 | 76 | 77 | class MixtralBLockSparseTop2MLP(module.LowMemoryModule): 78 | 79 | def __init__(self, config): 80 | super().__init__() 81 | dtype, _, _ = utils.parse_amp(config.amp) 82 | dtype = dtypes.to_torch_dtype(dtype) 83 | self.w1 = module.LowMemoryLazyLinear(config.intermediate_size, bias=False, dtype=dtype) 84 | self.w2 = module.LowMemoryLazyLinear(config.intermediate_size, bias=False, dtype=dtype) 85 | self.w3 = module.LowMemoryLazyLinear(config.hidden_size, bias=False, dtype=dtype) 86 | 87 | 88 | class MixtralSparseMoeBlock(module.LowMemoryModule): 89 | 90 | def __init__(self, config): 91 | super().__init__() 92 | dtype, _, _ = utils.parse_amp(config.amp) 93 | dtype = dtypes.to_torch_dtype(dtype) 94 | self.gate = module.LowMemoryLazyLinear(config.hidden_size, bias=False, dtype=dtype) 95 | self.experts = module.LowMemoryModuleList([MixtralBLockSparseTop2MLP(config) for _ in range(config.num_local_experts)]) 96 | -------------------------------------------------------------------------------- /src/transformers_neuronx/modeling_auto.py: -------------------------------------------------------------------------------- 1 | import transformers 
2 | from transformers import AutoConfig 3 | import transformers_neuronx 4 | 5 | NEURON_MODEL_FOR_CAUSAL_LM_MAPPING = { 6 | "bloom": transformers_neuronx.BloomForSampling, 7 | "code_llama": transformers_neuronx.LlamaForSampling, 8 | "gpt2": transformers_neuronx.GPT2ForSamplingWithContextBroadcasting, 9 | "gpt_neox": transformers_neuronx.GPTNeoXForSampling, 10 | "gptj": transformers_neuronx.GPTJForSampling, 11 | "llama": transformers_neuronx.LlamaForSampling, 12 | "mistral": transformers_neuronx.MistralForSampling, 13 | "mixtral": transformers_neuronx.MixtralForSampling, 14 | "opt": transformers_neuronx.OPTForSampling, 15 | } 16 | 17 | 18 | CONFIG_MAPPING = { 19 | transformers.BloomConfig: "bloom", 20 | transformers.LlamaConfig: "llama", 21 | transformers.GPT2Config: "gpt2", 22 | transformers.GPTNeoXConfig: "gpt_neox", 23 | transformers.GPTJConfig: "gptj", 24 | transformers.MistralConfig: "mistral", 25 | transformers.MixtralConfig: "mixtral", 26 | transformers.OPTConfig: "opt", 27 | } 28 | 29 | 30 | class NeuronAutoModelForCausalLM: 31 | 32 | @classmethod 33 | def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): 34 | config = AutoConfig.from_pretrained(pretrained_model_name_or_path) 35 | 36 | if type(config) in CONFIG_MAPPING: 37 | model_type = CONFIG_MAPPING[type(config)] 38 | elif hasattr(config, "model_type"): # Fallback for custom/derived config objects 39 | model_type = config.model_type 40 | else: 41 | raise AssertionError(f"Models based on '{type(config)}' are not supported by Neuron") 42 | 43 | if model_type not in NEURON_MODEL_FOR_CAUSAL_LM_MAPPING: 44 | raise AssertionError(f"The configuration model type '{model_type}' is not supported by Neuron") 45 | 46 | model_class = NEURON_MODEL_FOR_CAUSAL_LM_MAPPING[model_type] 47 | return model_class.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) 48 | -------------------------------------------------------------------------------- /src/transformers_neuronx/nki/compile.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon Web Services and its Affiliates. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | 16 | import json 17 | import torch 18 | import base64 19 | import numpy as np 20 | import inspect 21 | from transformers_neuronx import compiler 22 | from neuronxcc.nki import trace 23 | import neuronxcc.nki.language as nl 24 | from neuronxcc.nki import FrameworkKernel 25 | from torch_neuronx.pyhlo.scribe import HloShape 26 | 27 | 28 | class PyTorchTracedKernel(FrameworkKernel): 29 | dtype_converter = compiler.DataTypeConverter() 30 | 31 | @staticmethod 32 | def get_shape(hloShape): 33 | return tuple([d for d in hloShape.shape_proto.dimensions]) 34 | 35 | def translate_to_neuron_dtype(self, _dtype): 36 | dtype = self.dtype_converter.hlo2torch(_dtype) 37 | if dtype == torch.bfloat16: 38 | dtype = nl.bfloat16 39 | else: 40 | dtype = torch.empty(1, dtype=dtype).numpy().dtype 41 | return dtype 42 | 43 | def is_framework_tensor(self, t): 44 | return isinstance(t, HloShape) 45 | 46 | def map_framework_tensor(self, hloShape): 47 | shape = self.get_shape(hloShape) 48 | dtype = hloShape.shape_proto.element_type 49 | return shape, dtype 50 | 51 | def nki_call(func, *args, **kwargs): 52 | """ 53 | This function applies NKI kernel function (func) to inputs (*args) in PyHLO. 54 | 55 | Args: 56 | func: NKI kernel function 57 | args: inputs of func 58 | kwargs: 59 | grid (grid used in NKI kernel function) 60 | output_HloShapes (HloShapes of outputs of NKI kernel) 61 | 62 | 63 | Example: 64 | def mixed_pyhlo_nki(x, y): 65 | h = x.dtype[x.sizes].Multiply(x, x) 66 | o = nki_call(add_kernel, h, y, grid=32, output_HloShapes=x.dtype[x.sizes]) 67 | return o 68 | """ 69 | 70 | grid = kwargs.pop("grid", None) 71 | return NkiHloKernel(func, grid=grid)(*args, **kwargs) 72 | 73 | 74 | class NkiHloKernel: 75 | """ 76 | This class lowers a user defined compiler kernel to PyHLO op. 77 | 78 | This is the FAL binding for the NKI API for compiler to program Neuron Device directly. 
79 | 80 | Parameters: 81 | func: the function of the baremetal kernel definition 82 | grid: launch grid configuration 83 | kernel_attrs: List[str], string attribute to control code injection point in compiler 84 | 85 | There are 2 steps to use a NKI kernel: 86 | 1) Define NKI kernel 87 | 2) Use NKI kernel within PyHLO by nki_call 88 | 89 | Example: 90 | # 1) Define NKI Kernel 91 | 92 | def add_kernel(a_input, b_input, c_output): 93 | # Calculate tile offsets based on current 'program' 94 | offset_i_x = nl.program_id(0) * 128 95 | offset_i_y = nl.program_id(1) * 512 96 | 97 | # Generate tensor indices to index tensors a and b 98 | ix = offset_i_x + nl.arange(128)[:, None] 99 | iy = offset_i_y + nl.arange(512)[None, :] 100 | 101 | # Load input data from external memory to on-chip memory 102 | # We refer to an indexed portion of a tensor as an intermediate tensor 103 | a_tile = nl.load(a_input[ix, iy]) 104 | b_tile = nl.load(b_input[ix, iy]) 105 | 106 | # compute a + b 107 | c_tile = a_tile + b_tile 108 | 109 | # store the addition results back to external memory (c_output) 110 | nl.store(c_output[ix, iy], value=c_tile) 111 | 112 | # 2) Use NKI kernel by nki_call: 113 | 114 | def mixed_pyhlo_nki(x, y): 115 | grid_x = x.sizes[0] // 128 116 | grid_y = x.sizes[1] // 512 117 | h = x.dtype[x.sizes].Multiply(x, x) 118 | o = nki_call(add_kernel, h, y, grid=(grid_x, grid_y), output_HloShapes=[y.dtype[y.sizes]]) 119 | return o 120 | 121 | 122 | """ 123 | 124 | def __init__(self, func, grid=None, **kwargs): 125 | self.func = func 126 | self.grid = () 127 | if grid is not None: 128 | self.set_grid(grid) 129 | self._kernel = PyTorchTracedKernel( 130 | func_name=func.__name__, 131 | func=self.func, 132 | grid=self.grid, 133 | **kwargs 134 | ) 135 | 136 | def set_grid(self, grid): 137 | if not isinstance(grid, (tuple, list)): 138 | grid = [grid] 139 | self.grid = grid 140 | 141 | def __call__(self, *args, output_HloShapes=None): 142 | if output_HloShapes is None: 143 | raise ValueError("output_HloShapes should be set in NkiHloKernel!") 144 | 145 | if not isinstance(output_HloShapes, (list, tuple)): 146 | output_HloShapes = [output_HloShapes] 147 | 148 | input_output_HloShapes = (*args, *output_HloShapes) 149 | config_str, input_names, output_names = self._kernel.dump_config(*input_output_HloShapes) 150 | 151 | if len(output_HloShapes) > 1: 152 | output_HloShapes = args[0].scribe.tuple(*output_HloShapes) 153 | else: 154 | output_HloShapes, = output_HloShapes 155 | 156 | output = output_HloShapes.CustomCall(*args, 157 | backend_config=str.encode(config_str), 158 | custom_call_target='AwsNeuronCustomNativeKernel') 159 | 160 | return output -------------------------------------------------------------------------------- /src/transformers_neuronx/ops.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon Web Services and its Affiliates. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License.
14 | # ============================================================================== 15 | import torch 16 | import torch_neuronx # registers torch.ops.neuron 17 | 18 | 19 | def init(): 20 | return torch.ops.neuron._init_neuron() 21 | 22 | 23 | def to_nc(tensor, ordinal=0): 24 | return torch.ops.neuron._to_neuron(tensor.contiguous(), ordinal) 25 | 26 | 27 | def cpu(tensor): 28 | return torch.ops.neuron._from_neuron(tensor) 29 | 30 | 31 | def slice(tensor, dim, start, end, step): 32 | return torch.ops.neuron._slice_neuron(tensor, dim, start, end, step) 33 | 34 | 35 | def load(model, nc_id, nc_count): 36 | model.set_neuron_devices(nc_id, nc_count) 37 | return torch.ops.neuron._load_neuron(model) 38 | 39 | 40 | def load_collectives(model, nc_id, nc_count, g_nc_id, g_nc_count): 41 | return torch.ops.neuron._load_collectives_neuron(model, nc_id, nc_count, g_nc_id, g_nc_count) 42 | 43 | 44 | def execute(model, inputs): 45 | return torch.ops.neuron._execute_neuron(model, inputs) 46 | 47 | 48 | def parallel_to_nc(tensors): 49 | return torch.ops.neuron._parallel_to_neuron(tensors) 50 | 51 | 52 | def parallel_cpu(tensor): 53 | return torch.ops.neuron._parallel_from_neuron(tensor) 54 | 55 | 56 | def parallel_write(tensor, tensors): 57 | return torch.ops.neuron._parallel_write_neuron(tensor, tensors) 58 | 59 | 60 | def parallel_slice(tensor, dim, start, end, step): 61 | return torch.ops.neuron._parallel_slice_neuron(tensor, dim, start, end, step) 62 | 63 | 64 | def parallel_run(parallel_model, parallel_inputs, parallel_outputs): 65 | return torch.ops.neuron._parallel_run_neuron( 66 | parallel_model, parallel_inputs, parallel_outputs) 67 | 68 | 69 | def profile_start(model, ntff): 70 | return torch.ops.neuron._profile_start_neuron(model, ntff) 71 | 72 | 73 | def profile_stop(ntff): 74 | return torch.ops.neuron._profile_stop_neuron(ntff) 75 | 76 | 77 | def parallel_profile_start(model, ntff_prefix, ntff_count_limit): 78 | return torch.ops.neuron._parallel_profile_start_neuron(model, ntff_prefix, ntff_count_limit) 79 | 80 | 81 | def parallel_profile_stop(ntff_files): 82 | return torch.ops.neuron._parallel_profile_stop_neuron(ntff_files) -------------------------------------------------------------------------------- /src/transformers_neuronx/opt/config.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon Web Services and its Affiliates. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | from transformers.configuration_utils import PretrainedConfig 16 | from transformers_neuronx import utils 17 | from transformers_neuronx.gpt2.config import GPT2Config 18 | 19 | 20 | class OPTConfig: 21 | 22 | def __init__(self, config, n_positions, batch_size, amp, tp_degree, **kwargs): 23 | if not config.do_layer_norm_before: 24 | raise NotImplementedError('do_layer_norm_before=False not implemented') 25 | self.activation_function = config.activation_function 26 | self.eos_token_id = config.eos_token_id 27 | self.pad_token_id = config.pad_token_id 28 | self.ffn_dim = config.ffn_dim 29 | self.hidden_size = config.hidden_size 30 | self.max_position_embeddings = config.max_position_embeddings 31 | self.num_attention_heads = config.num_attention_heads 32 | self.num_hidden_layers = config.num_hidden_layers 33 | self.vocab_size = config.vocab_size 34 | self.word_embed_proj_dim = config.word_embed_proj_dim 35 | utils.maybe_override_attributes(self, kwargs) 36 | self.n_positions = n_positions 37 | self.batch_size = batch_size 38 | self.amp = amp 39 | self.tp_degree = tp_degree 40 | self.model_type = 'opt' 41 | 42 | 43 | def opt_config_to_gpt2_config(config): 44 | if config.ffn_dim != 4 * config.hidden_size: 45 | raise NotImplementedError(f'ffn_dim={config.ffn_dim} and hidden_size={config.hidden_size}') 46 | gpt2_config = PretrainedConfig() 47 | gpt2_config.activation_function = config.activation_function 48 | gpt2_config.n_ctx = config.max_position_embeddings 49 | gpt2_config.n_embd = config.hidden_size 50 | gpt2_config.n_head = config.num_attention_heads 51 | gpt2_config.n_layer = config.num_hidden_layers 52 | gpt2_config.n_positions = config.n_positions 53 | gpt2_config.vocab_size = config.vocab_size 54 | gpt2_config.eos_token_id = config.eos_token_id 55 | batch_size = config.batch_size 56 | amp = config.amp 57 | tp_degree = config.tp_degree 58 | return GPT2Config(gpt2_config, batch_size, amp, tp_degree) 59 | -------------------------------------------------------------------------------- /src/transformers_neuronx/opt/demo.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon Web Services and its Affiliates. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | from transformers_neuronx.gpt_demo import demo 16 | from transformers_neuronx.opt.model import OPTForSampling 17 | 18 | 19 | def main(): 20 | demo('facebook/opt-125m', OPTForSampling, amp_callback) 21 | 22 | 23 | def amp_callback(model, dtype): 24 | # cast attention and mlp to low precisions only; layernorms stay as f32 25 | for block in model.model.decoder.layers: 26 | block.self_attn.to(dtype) 27 | block.fc1.to(dtype) 28 | block.fc2.to(dtype) 29 | model.lm_head.to(dtype) 30 | -------------------------------------------------------------------------------- /src/transformers_neuronx/opt/gen_random_pretrained.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon Web Services and its Affiliates. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | import argparse 16 | import json 17 | import os 18 | import torch 19 | from transformers.models.opt import OPTConfig 20 | from transformers_neuronx.module import sanitize_file_name, _KEY_TO_FILENAME_JSON 21 | 22 | 23 | def main(): 24 | parser = argparse.ArgumentParser() 25 | parser.add_argument('name', help="OPT model name or path to config.json") 26 | parser.add_argument('save', help="target folder to save the model") 27 | parser.add_argument('--empty', action='store_true') 28 | args = parser.parse_args() 29 | gen_random_pretrained(args.name, args.save, args.empty) 30 | 31 | 32 | def gen_random_pretrained(model_name, save, empty=False): 33 | if 'json' in model_name: 34 | config = json.load(open(model_name)) 35 | elif model_name == 'facebook/opt-175b': 36 | config = opt_175b_config() 37 | else: 38 | config = OPTConfig.from_pretrained(model_name).to_dict() 39 | os.makedirs(save, exist_ok=True) 40 | with open(os.path.join(save, 'config.json'), 'w') as fp: 41 | json.dump(config, fp, indent=2) 42 | vocab_size = config['vocab_size'] 43 | hidden_size = config['hidden_size'] 44 | max_position_embeddings = config['max_position_embeddings'] 45 | ffn_dim = config['ffn_dim'] 46 | num_hidden_layers = config['num_hidden_layers'] 47 | init_std = config['init_std'] 48 | torch_dtype = config['torch_dtype'] 49 | name2shape = { 50 | 'model.decoder.embed_tokens.weight': [vocab_size, hidden_size], 51 | 'model.decoder.embed_positions.weight': [max_position_embeddings + 2, hidden_size], 52 | 'model.decoder.final_layer_norm.weight': [hidden_size], 53 | 'model.decoder.final_layer_norm.bias': [hidden_size], 54 | } 55 | layer_name2shape = { 56 | 'self_attn.k_proj.weight': [hidden_size, hidden_size], 57 | 'self_attn.k_proj.bias': [hidden_size], 58 | 'self_attn.v_proj.weight': [hidden_size, hidden_size], 59 | 'self_attn.v_proj.bias': [hidden_size], 60 | 'self_attn.q_proj.weight': [hidden_size, hidden_size], 61 | 'self_attn.q_proj.bias': [hidden_size], 62 | 'self_attn.out_proj.weight': 
[hidden_size, hidden_size], 63 | 'self_attn.out_proj.bias': [hidden_size], 64 | 'self_attn_layer_norm.weight': [hidden_size], 65 | 'self_attn_layer_norm.bias': [hidden_size], 66 | 'fc1.weight': [ffn_dim, hidden_size], 67 | 'fc1.bias': [ffn_dim], 68 | 'fc2.weight': [hidden_size, ffn_dim], 69 | 'fc2.bias': [hidden_size], 70 | 'final_layer_norm.weight': [hidden_size], 71 | 'final_layer_norm.bias': [hidden_size], 72 | } 73 | for idx in range(num_hidden_layers): 74 | for name, shape in layer_name2shape.items(): 75 | name2shape[f'model.decoder.layers.{idx}.{name}'] = shape 76 | name2shape['lm_head.weight'] = [vocab_size, hidden_size] 77 | key_to_filename = {} 78 | for idx, key in enumerate(name2shape.keys()): 79 | key_to_filename[key] = f'p{idx}.{sanitize_file_name(key)}' 80 | if empty: 81 | key_to_filename[key] = f'{key_to_filename[key]}.empty_json' 82 | split_param_dir = os.path.join(save, 'pytorch_model.bin') 83 | os.makedirs(split_param_dir, exist_ok=True) 84 | with open(os.path.join(split_param_dir, _KEY_TO_FILENAME_JSON), 'w') as fp: 85 | json.dump(key_to_filename, fp, indent=2) 86 | dtype = getattr(torch, torch_dtype) 87 | for name, shape in name2shape.items(): 88 | save_path = os.path.join(split_param_dir, key_to_filename[name]) 89 | factor = 0.0 if 'layer_norm' in name or 'bias' in name else init_std 90 | if empty: 91 | empty_json = { 92 | 'torch_dtype': torch_dtype, 93 | 'shape': shape, 94 | 'init_std': factor, 95 | } 96 | with open(save_path, 'w') as fp: 97 | json.dump(empty_json, fp, indent=2) 98 | continue 99 | init_param = factor * torch.randn(shape) 100 | init_param = init_param.to(dtype) 101 | torch.save(init_param, save_path) 102 | print(f'done saving {save_path}') 103 | 104 | 105 | def opt_175b_config(): 106 | vocab_size = 50272 107 | hidden_size = 12288 108 | max_position_embeddings = 2048 109 | ffn_dim = 49152 110 | num_hidden_layers = 96 111 | init_std = 0.02 112 | config = dict( 113 | _name_or_path='facebook/opt-175b', 114 | _remove_final_layer_norm=False, 115 | activation_dropout=0.0, 116 | activation_function='relu', 117 | architectures=['OPTForCausalLM'], 118 | attention_dropout=0.0, 119 | bos_token_id=2, 120 | do_layer_norm_before=True, 121 | dropout=0.1, 122 | eos_token_id=2, 123 | ffn_dim=ffn_dim, 124 | hidden_size=hidden_size, 125 | init_std=init_std, 126 | layerdrop=0.0, 127 | max_position_embeddings=max_position_embeddings, 128 | model_type='opt', 129 | num_attention_heads=96, 130 | num_hidden_layers=num_hidden_layers, 131 | output_projection=True, 132 | pad_token_id=1, 133 | prefix='', 134 | torch_dtype='float16', 135 | transformers_version='4.23.1', 136 | use_cache=True, 137 | vocab_size=vocab_size, 138 | word_embed_proj_dim=hidden_size, 139 | ) 140 | return config 141 | -------------------------------------------------------------------------------- /src/transformers_neuronx/opt/hlo.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon Web Services and its Affiliates. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | from transformers_neuronx import compiler 16 | from transformers_neuronx.gpt2 import hlo as gpt2_hlo 17 | from transformers_neuronx.opt.config import opt_config_to_gpt2_config 18 | 19 | 20 | def build_opt_multi_layer_hlo_module(config, n_active_tokens, n_positions, n_layers): 21 | config = opt_config_to_gpt2_config(config) 22 | multi_layer = gpt2_hlo.gen_scribable_multi_block(config, n_active_tokens, n_positions, n_layers) 23 | return compiler.compile_py_func(multi_layer) 24 | 25 | 26 | def build_ln_lm_head_hlo_module(config, n_active_tokens): 27 | config = opt_config_to_gpt2_config(config) 28 | ln_lm_head = gpt2_hlo.gen_scribable_ln_lm_head(config, n_active_tokens) 29 | return compiler.compile_py_func(ln_lm_head) 30 | 31 | 32 | def build_opt_hlo_module(config, n_active_tokens, n_positions, blocks_u8_bounds=None): 33 | config = opt_config_to_gpt2_config(config) 34 | gpt2 = gpt2_hlo.gen_scribable_gpt2(config, n_active_tokens, n_positions, blocks_u8_bounds) 35 | return compiler.compile_py_func(gpt2) 36 | -------------------------------------------------------------------------------- /src/transformers_neuronx/pad/layernorm_padded_cpu.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon Web Services and its Affiliates. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | 16 | import torch 17 | from torch import nn 18 | 19 | def LayerNormCPU(torch_ln, hidden_dim_ratio): 20 | """ Creates a modified layernorm in PyTorch 21 | 22 | Args: 23 | torch_ln (nn.Module): original layernorm 24 | hidden_dim_ratio (float): ratio of the target hidden dim to the source hidden dim, used as a correction factor 25 | 26 | Returns: 27 | nn.Module: modified layernorm 28 | """ 29 | shape = list(torch_ln.normalized_shape) 30 | eps = torch_ln.eps 31 | ln = LayerNorm_(shape, eps=eps, hidden_dim_ratio=hidden_dim_ratio) 32 | ln.weight = torch_ln.weight 33 | ln.bias = torch_ln.bias 34 | return ln 35 | 36 | class LayerNorm_(nn.Module): 37 | def __init__(self, normalized_shape, *, 38 | eps=1e-5, 39 | elementwise_affine=True, 40 | hidden_dim_ratio=1.0): 41 | super().__init__() 42 | 43 | if isinstance(normalized_shape, int): 44 | normalized_shape = torch.Size([normalized_shape]) 45 | elif isinstance(normalized_shape, list): 46 | normalized_shape = torch.Size(normalized_shape) 47 | assert isinstance(normalized_shape, torch.Size) 48 | 49 | self.hidden_dim_ratio = hidden_dim_ratio 50 | self.normalized_shape = normalized_shape 51 | self.eps = eps 52 | self.elementwise_affine = elementwise_affine 53 | if self.elementwise_affine: 54 | self.weight = nn.Parameter(torch.ones(normalized_shape)) 55 | self.bias = nn.Parameter(torch.zeros(normalized_shape)) 56 | 57 | def forward(self, x: torch.Tensor): 58 | assert self.normalized_shape == x.shape[-len(self.normalized_shape):] 59 | dims = [-(i + 1) for i in range(len(self.normalized_shape))] 60 | mean = x.mean(dim=dims, keepdim=True) 61 | mean_x2 = (x * x).mean(dim=dims, keepdim=True) 62 | var = mean_x2 - mean * mean 63 | mean *= self.hidden_dim_ratio 64 | var *= self.hidden_dim_ratio 65 | x_norm = (x - mean) / torch.sqrt(var + self.eps) 66 | if self.elementwise_affine: 67 | x_norm = self.weight * x_norm + self.bias 68 | return x_norm -------------------------------------------------------------------------------- /src/transformers_neuronx/parallel.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon Web Services and its Affiliates. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License.
14 | # ============================================================================== 15 | from concurrent.futures import ThreadPoolExecutor 16 | import torch 17 | from transformers_neuronx import ops 18 | 19 | 20 | def parallel_load(models): 21 | degree = len(models) 22 | with ThreadPoolExecutor(degree) as executor: 23 | futures = [] 24 | for ordinal, model in enumerate(models): 25 | args = model, ordinal, 1, ordinal, degree 26 | fut = executor.submit(ops.load_collectives, *args) 27 | futures.append(fut) 28 | [future.result() for future in futures] # wait for load_collectives calls 29 | 30 | 31 | class TensorManipulator: 32 | 33 | def __init__(self, tp_degree): 34 | self.tp_degree = tp_degree 35 | 36 | def duplicate(self, tensor): 37 | tensors = [tensor for ordinal in range(self.tp_degree)] 38 | return to_nc(tensors) 39 | 40 | def shard_along(self, tensor, dim): 41 | size = tensor.shape[dim] 42 | shard_size = size // self.tp_degree 43 | slices = [slice(None) for _ in tensor.shape] 44 | tensors = [] 45 | for start in range(0, size, shard_size): 46 | slices[dim] = slice(start, start+shard_size, 1) 47 | shard = tensor[tuple(slices)].contiguous() 48 | tensors.append(shard) 49 | return to_nc(tensors) 50 | 51 | def primary_only(self, tensor): 52 | tensors = [tensor] 53 | tensors.extend(torch.zeros_like(tensor) for _ in range(1, self.tp_degree)) 54 | return to_nc(tensors) 55 | 56 | def unshard_along(self, sharded_tensors, dim): 57 | return torch.cat(cpu(sharded_tensors), dim=dim) 58 | 59 | def slice_on_nc(self, tensors, dim, start, end, step): 60 | return [ops.slice(ts, dim, start, end, step) for ts in tensors] 61 | 62 | 63 | def to_nc(sharded_tensors): 64 | return [ops.to_nc(ts, ordinal) for ordinal, ts in enumerate(sharded_tensors)] 65 | 66 | 67 | def cpu(sharded_tensors): 68 | return [ops.cpu(ts) for ts in sharded_tensors] 69 | 70 | 71 | class Executor: 72 | 73 | def __init__(self, tp_degree): 74 | self.executor = ThreadPoolExecutor(tp_degree) 75 | 76 | def execute(self, models, *inputs_cores): 77 | futures = [] 78 | for model, *inputs in zip(models, *inputs_cores): 79 | fut = self.executor.submit(ops.execute, model, inputs) 80 | futures.append(fut) 81 | cores_outputs = [fut.result() for fut in futures] 82 | outputs_cores = [list(outputs) for outputs in zip(*cores_outputs)] 83 | return outputs_cores 84 | 85 | 86 | class ParallelTensorManipulator: 87 | 88 | def __init__(self, tp_degree, rank_id=0, local_tp_degree=None): 89 | self.tp_degree = tp_degree 90 | self.rank_id = rank_id 91 | if local_tp_degree is None: 92 | local_tp_degree = tp_degree 93 | self.local_tp_degree = local_tp_degree 94 | 95 | def duplicate_on_cpu(self, tensor): 96 | return [tensor for ordinal in range(self.local_tp_degree)] 97 | 98 | def duplicate(self, tensor): 99 | return ops.parallel_to_nc([tensor.contiguous() for ordinal in range(self.local_tp_degree)]) 100 | 101 | def shard_along_on_cpu(self, tensor, dim): 102 | size = tensor.shape[dim] 103 | shard_size = size // self.tp_degree 104 | slices = [slice(None) for _ in tensor.shape] 105 | tensors = [] 106 | if self.local_tp_degree != self.tp_degree: # for multi-instance tp 107 | slice_start = self.rank_id * self.local_tp_degree * shard_size 108 | slice_end = (self.rank_id + 1) * self.local_tp_degree * shard_size 109 | for start in range(slice_start, slice_end, shard_size): 110 | slices[dim] = slice(start, start+shard_size, 1) 111 | shard = tensor[tuple(slices)].contiguous() 112 | tensors.append(shard) 113 | else: 114 | slice_start = 0 115 | slice_end = size 116 | 
slice_range = range(slice_start, slice_end, shard_size) 117 | for start in slice_range: 118 | slices[dim] = slice(start, start+shard_size, 1) 119 | shard = tensor[tuple(slices)].contiguous() 120 | if len(slice_range) == 1: 121 | # edge case for save_presharded flow where something is "sharded" 122 | # but in reality is a no-op causing some tensors to share memory 123 | # safetensors cannot share memory so we make a copy 124 | shard = shard.clone() 125 | tensors.append(shard) 126 | if len(tensors) != self.local_tp_degree: 127 | raise ValueError( 128 | f'Weight with shape {tensor.shape} cannot be sharded along dimension {dim}. ' 129 | f'This results in {len(tensors)} weight partitions which cannot be distributed to {self.local_tp_degree} NeuronCores evenly. ' 130 | f'To fix this issue either the model parameters or the `tp_degree` must be changed to allow the weight to be evenly split' 131 | ) 132 | return tensors 133 | 134 | def shard_along(self, tensor, dim): 135 | return ops.parallel_to_nc(self.shard_along_on_cpu(tensor, dim)) 136 | 137 | def duplicate_or_shard_along(self, tensor, dim): 138 | if dim is None: 139 | return self.duplicate(tensor) 140 | return self.shard_along(tensor, dim) 141 | 142 | def primary_only(self, tensor): 143 | tensors = [tensor] 144 | tensors.extend(torch.zeros_like(tensor) for _ in range(1, self.local_tp_degree)) 145 | return ops.parallel_to_nc(tensors) 146 | 147 | def unshard_along(self, sharded_tensors, dim): 148 | return torch.cat(ops.parallel_cpu(sharded_tensors), dim=dim) 149 | 150 | def slice_on_nc(self, tensors, dim, start, end, step): 151 | return ops.parallel_slice(tensors, dim, start, end, step) 152 | 153 | class CPUTensorManipulator(ParallelTensorManipulator): 154 | 155 | def duplicate(self, tensor): 156 | return [tensor.contiguous() for _ in range(self.local_tp_degree)] 157 | 158 | def shard_along(self, tensor, dim): 159 | return self.shard_along_on_cpu(tensor, dim) 160 | 161 | def primary_only(self, tensor): 162 | tensors = [tensor] 163 | tensors.extend(torch.zeros_like(tensor) for _ in range(1, self.local_tp_degree)) 164 | return tensors 165 | 166 | def slice_on_nc(self, tensor, dim, start, end, step): 167 | index = [slice(None)] * tensor.dim() 168 | index[dim] = slice(start, end, step) 169 | return tensor[index] 170 | 171 | def layers_to_neuron(num_workers, layers, n_positions_list, to_neuron_hooks): 172 | with ThreadPoolExecutor(num_workers) as pool: 173 | futures = [pool.submit(layer.to_neuron, n_positions_list) for layer in layers] 174 | for idx, future in enumerate(futures): 175 | future.result() 176 | for hook in to_neuron_hooks: 177 | hook(idx) 178 | 179 | 180 | class CacheBroadcaster: 181 | 182 | def __init__(self, tp_degree, shard_dim, batch_dim, batch_size): 183 | self.manipulator = ParallelTensorManipulator(tp_degree) 184 | self.shard_dim = shard_dim 185 | self.batch_dim = batch_dim 186 | self.batch_size = batch_size 187 | 188 | def broadcast(self, source, target): 189 | source_batch_size = source.shape[self.batch_dim] 190 | source = self.manipulator.unshard_along(source, dim=self.shard_dim) 191 | repeats = [1 for _ in source.shape] 192 | repeats[self.batch_dim] = self.batch_size // source_batch_size 193 | source = source.repeat(repeats) 194 | source = self.manipulator.shard_along_on_cpu(source, dim=self.shard_dim) 195 | ops.parallel_write(target, source) -------------------------------------------------------------------------------- /src/transformers_neuronx/quantize.py: 
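# A CPU-only sketch of the tensor-parallel weight sharding implemented in parallel.py above
# (illustrative; the tp_degree and the weight shape are assumptions). shard_along_on_cpu
# partitions a weight evenly along one dimension, one shard per NeuronCore, and raises a
# ValueError when that dimension does not divide evenly by tp_degree.
import torch
from transformers_neuronx.parallel import ParallelTensorManipulator

manipulator = ParallelTensorManipulator(tp_degree=4)
weight = torch.randn(1024, 4096)

shards = manipulator.shard_along_on_cpu(weight, dim=1)  # four (1024, 1024) shards
assert len(shards) == 4
assert torch.equal(torch.cat(shards, dim=1), weight)    # concatenation recovers the original weight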
-------------------------------------------------------------------------------- 1 | # Copyright Amazon Web Services and its Affiliates. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | import warnings 16 | from typing import Optional, List 17 | import torch 18 | from transformers_neuronx.config import QuantizationConfig 19 | 20 | from transformers_neuronx.constants import FpBounds 21 | 22 | 23 | def maybe_quantize_weights( 24 | tensor: torch.Tensor, 25 | quantize_config: QuantizationConfig, 26 | out_feature_dim: Optional[int] = 1, 27 | contract_dims: Optional[List] = None, 28 | is_unit_scale=False, 29 | ): 30 | """ 31 | Quantize tensors using the dtype and method specified in quantize_config. 32 | 33 | Arguments: 34 | tensor: The PyTorch tensor that will be quantized. 35 | quantize_config: Config that specifies the quantization dtype and method. 36 | out_feature_dim: Output feature dimension for the matrix when it's multiplied. 37 | contract_dims: Contraction dimension(s) for the tensor when it's multiplied. 38 | 39 | Returns: 40 | quantized_weights: The quantized tensor. 41 | scales: Scales to "rescale" the quantized tensor after it's multiplied, 42 | where W_f = W_q * scales. 43 | """ 44 | 45 | if tensor is None: 46 | return None, None 47 | 48 | tensor_rank = len(list(tensor.shape)) 49 | 50 | if contract_dims is None: 51 | assert ( 52 | tensor_rank == 2 53 | ), f"Contract dimensions must be specified for {tensor_rank}-dimensional tensors." 54 | # Preserve original quantization API behavior 55 | contract_dims = [0] if out_feature_dim == 1 else [1] 56 | else: 57 | assert tensor_rank - len(contract_dims) == 1, ( 58 | "Quantization is only supported when the number of contract " 59 | "dimensions is 1 less than the number of tensor dimensions. Number " 60 | f"of tensor dimensions: {tensor_rank}. Number of provided contract " 61 | f"dimensions: {len(contract_dims)}." 
62 | ) 63 | 64 | assert out_feature_dim not in contract_dims, ( 65 | f"out_feature_dim ({out_feature_dim}) should not be included in " 66 | f"contract_dims ({contract_dims})" 67 | ) 68 | 69 | tensor = tensor.to(torch.float32) 70 | 71 | if quantize_config.quantize_method == "vector_dynamic": 72 | if quantize_config.quant_dtype == "s8": 73 | # Use W_f = W_q * scales 74 | int8_min = torch.iinfo(torch.int8).min 75 | int8_max = torch.iinfo(torch.int8).max 76 | max_values = torch.amax(torch.abs(tensor), dim=contract_dims, keepdim=True) 77 | scales = max_values / int8_max 78 | if scales.count_nonzero() == 0: 79 | # If the scales are all zeros, the weight tensor is all zeros 80 | quantized_weights = tensor.to(torch.int8) 81 | else: 82 | quantized_weights = tensor / scales 83 | quantized_weights = torch.round(quantized_weights) 84 | quantized_weights = torch.clamp(quantized_weights, int8_min, int8_max) 85 | quantized_weights = quantized_weights.to(torch.int8) 86 | scales = scales.flatten() 87 | elif quantize_config.quant_dtype == "f8e4m3fn": 88 | if is_unit_scale: 89 | scales = torch.ones(tensor.shape[out_feature_dim], dtype=tensor.dtype) 90 | quantized_weights = tensor.to(torch.float8_e4m3fn) 91 | else: 92 | fp8_max = float(FpBounds.max) 93 | max_values = torch.amax(torch.abs(tensor), dim=contract_dims, keepdim=True) 94 | scales = max_values / fp8_max 95 | if scales.count_nonzero() == 0: 96 | # If the scales are all zeros, the weight tensor is all zeros 97 | quantized_weights = tensor.to(torch.float8_e4m3fn) 98 | else: 99 | quantized_weights = tensor / scales 100 | quantized_weights = quantized_weights.to(torch.float8_e4m3fn) 101 | scales = scales.flatten() 102 | else: 103 | raise NotImplementedError( 104 | f"{quantize_config.quant_dtype} for {quantize_config.quantize_method} is not yet implemented." 105 | ) 106 | elif quantize_config.quantize_method == 'direct_cast': 107 | warnings.warn( 108 | f"direct casting is enabled with dtype {quantize_config.quant_dtype}, make sure the values are within " 109 | f"representable range of the datatype") 110 | return tensor.to(quantize_config.quant_dtype), None 111 | else: 112 | raise NotImplementedError( 113 | f"{quantize_config.quantize_method} is not yet implemented." 114 | ) 115 | return quantized_weights, scales 116 | -------------------------------------------------------------------------------- /src/transformers_neuronx/sparse_attn_utils.py: -------------------------------------------------------------------------------- 1 | import math 2 | import random 3 | import torch 4 | from dataclasses import dataclass 5 | 6 | def create_blk_mask(blks_q, blks_kv, num_global_blks=0, num_local_blks=1, num_random_blks=0, causal=False): 7 | """ 8 | Create a block mask given the configs, and the number of blocks in each dimension. 9 | Assume all heads use the same mask. 10 | """ 11 | blk_mask = torch.zeros((blks_q, blks_kv), dtype=torch.bool) 12 | # Add global blocks 13 | if num_global_blks > 0: 14 | blk_mask[:num_global_blks, :] = 1 15 | blk_mask[:, :num_global_blks] = 1 16 | # Add local blocks 17 | if num_local_blks > 0: 18 | width = num_local_blks // 2 19 | for row in range(blks_q): 20 | start = max(0, row - width) 21 | end = min(row + width + 1, blks_kv) 22 | blk_mask[row, start:end] = 1 23 | # Add random blocks 24 | if num_random_blks > 0: 25 | assert blks_kv > num_random_blks, "Number of random blocks must be smaller than total number of col blocks!"
26 | for row in range(blks_q): 27 | # If causal, only draw blocks from the lower-triangular part of the matrix 28 | pool = list(range(0, blks_kv)) if not causal else list(range(0, row+1)) 29 | selected = pool if len(pool) <= num_random_blks else random.sample(pool, num_random_blks) 30 | for col in selected: 31 | blk_mask[row, col] = 1 32 | 33 | if causal: 34 | blk_mask = torch.tril(blk_mask) 35 | 36 | return blk_mask 37 | 38 | def build_dense_mask(q_seq_len, k_seq_len, mask, blk_size=128, causal=False): 39 | row_blks = (q_seq_len + blk_size - 1) // blk_size 40 | col_blks = (k_seq_len + blk_size - 1) // blk_size 41 | assert tuple(mask.shape) == (row_blks, col_blks), f'Mask must have shape (q_seq_len // blk_size, k_seq_len // blk_size)' 42 | dense_mask = torch.zeros((q_seq_len, k_seq_len), dtype=torch.bool) 43 | 44 | for row_id in range(row_blks): 45 | for col_id in range(col_blks): 46 | if int(mask[row_id, col_id]) == 1: 47 | last_row = min(q_seq_len, (row_id+1)*blk_size) 48 | last_col = min(k_seq_len, (col_id+1)*blk_size) 49 | dense_mask[row_id*blk_size:last_row, col_id*blk_size:last_col] = 1 50 | if causal: 51 | dense_mask = torch.tril(dense_mask) 52 | return dense_mask 53 | 54 | 55 | def build_sliding_window_mask(q_seq_len, kv_seq_len, window_size, causal=True): 56 | """ 57 | Build a sliding window 2D mask shape {n_active_tokens, n_positions} 58 | This funciton is used in context encoding with sliding window attention, aka, mistral-7b-v0.1 59 | """ 60 | dense_mask = torch.ones((q_seq_len, kv_seq_len), dtype=torch.bool) 61 | 62 | # In causal mode we only attend to tokens on the left 63 | window_size_l, window_size_r = (window_size, 0) if causal \ 64 | else (window_size // 2, window_size // 2) 65 | 66 | dense_mask = torch.tril(dense_mask, diagonal=window_size_r) 67 | dense_mask = torch.triu(dense_mask, diagonal=-window_size_l) 68 | 69 | return dense_mask.detach() 70 | 71 | 72 | @dataclass 73 | class BlkSparseAttnConfig: 74 | """ Block-sparse attention specific config """ 75 | blk_size: int = 128 76 | num_global_blks: int = 0 77 | num_local_blks: int = 1 78 | num_random_blks: int = 0 79 | 80 | @dataclass 81 | class SlidingWindowAttnConfig: 82 | """ Sliding-window attention specific config """ 83 | window_size: int = 128 84 | 85 | 86 | class SparseAttnConfig: 87 | """ 88 | The config class that contains sparse attention related settings 89 | - attn_type: string, must be in the list defined by ATTN_TYPE_LIST 90 | - sparse_attn_config: must be either BlkSparseAttnConfig, SlidingWindowAttnConfig, 91 | or None, the config type must match the attention type specified by attn_type 92 | - same_mask_per_layer: bool, set to False if user want to use a different mask 93 | in each layer of the model 94 | - sparse_mask_dict: dict, override the default when attn_type == 'custom'. If 95 | same_mask_per_layer == True, the dict type is (q_seq_len, kv_seq_len) --> 96 | torch.Tensor. Otherwise, the key type is (layer_id, q_seq_len, kv_seq_len). 97 | Values must be boolean tensors with shape (q_seq_len, kv_seq_len). 
98 | """ 99 | 100 | def __init__(self, attn_type='blk_sparse', causal=False, 101 | sparse_attn_config=None, 102 | # This flag controls whether we use the same mask in every layer 103 | same_mask_per_layer=True, 104 | # User can directly provide the masks if needed 105 | sparse_mask_dict=dict()): 106 | ATTN_TYPE_LIST = ['blk_sparse', 'window', 'custom'] 107 | assert attn_type in ATTN_TYPE_LIST, f'Supported attention types are: {ATTN_TYPE_LIST}' 108 | 109 | if attn_type == 'blk_sparse': 110 | assert sparse_attn_config and isinstance(sparse_attn_config, BlkSparseAttnConfig), \ 111 | "Must provide a valid block sparse attention config!" 112 | elif attn_type == 'window': 113 | assert sparse_attn_config and isinstance(sparse_attn_config, SlidingWindowAttnConfig), \ 114 | "Must provide a valid sliding window attention config!" 115 | 116 | self.sparse_mask_dict = sparse_mask_dict 117 | self.sparse_attn_config = sparse_attn_config 118 | self.attn_type = attn_type 119 | self.causal = causal 120 | self.same_mask_per_layer = same_mask_per_layer 121 | self.skip_masking_decode = (self.attn_type == 'window') 122 | 123 | def create_blk_sparse_mask(self, q_seq_len, kv_seq_len): 124 | blks_q = math.ceil(q_seq_len / self.sparse_attn_config.blk_size) 125 | blks_kv = math.ceil(kv_seq_len / self.sparse_attn_config.blk_size) 126 | blk_mask = create_blk_mask( 127 | blks_q, blks_kv, 128 | self.sparse_attn_config.num_global_blks, 129 | self.sparse_attn_config.num_local_blks, 130 | self.sparse_attn_config.num_random_blks, 131 | self.causal and (q_seq_len != 1) 132 | ) 133 | dense_mask = build_dense_mask( 134 | q_seq_len, kv_seq_len, 135 | blk_mask, self.sparse_attn_config.blk_size, 136 | self.causal and (q_seq_len != 1) 137 | ) 138 | return dense_mask.detach() 139 | 140 | def create_sliding_window_mask(self, q_seq_len, kv_seq_len): 141 | dense_mask = torch.zeros((q_seq_len, kv_seq_len), dtype=torch.bool) 142 | window_size_l, window_size_r = (self.sparse_attn_config.window_size, 0) if self.causal \ 143 | else (self.sparse_attn_config.window_size // 2, self.sparse_attn_config.window_size // 2) 144 | 145 | for row in range(q_seq_len): 146 | start = max(0, row-window_size_l) 147 | end = row+1 if self.causal else min(kv_seq_len, row+window_size_r) 148 | dense_mask[row, start:end] = 1 149 | 150 | return dense_mask.detach() 151 | 152 | def create_sparse_mask(self, q_seq_len, kv_seq_len, layer_id=0): 153 | """ Create a mask that defines how the new tokens attend to the old tokens """ 154 | assert ((q_seq_len == 1) or (q_seq_len == kv_seq_len)), \ 155 | "Only supporting decode mode (q_seq_len=1) or self-attention mode (q_seq_len=k_seq_len)!" 
156 | key = (q_seq_len, kv_seq_len) if self.same_mask_per_layer else (layer_id, q_seq_len, kv_seq_len) 157 | if key in self.sparse_mask_dict: 158 | return self.sparse_mask_dict[key] 159 | 160 | # Don't generate mask if q_seq_len = 1 (decode) and user is using window attention 161 | skip_masking = (q_seq_len == 1) and self.skip_masking_decode 162 | if skip_masking: 163 | return None 164 | 165 | if self.attn_type == 'blk_sparse': 166 | mask = self.create_blk_sparse_mask(q_seq_len, kv_seq_len) 167 | elif self.attn_type == 'window': 168 | mask = self.create_sliding_window_mask(q_seq_len, kv_seq_len) 169 | self.sparse_mask_dict[key] = mask 170 | return mask 171 | -------------------------------------------------------------------------------- /src/transformers_neuronx/stopping_criteria.py: -------------------------------------------------------------------------------- 1 | from abc import ABC 2 | import torch 3 | import time 4 | from typing import Optional 5 | 6 | 7 | class StoppingCriteria(ABC): 8 | """Abstract base class for all stopping criteria that can be applied during generation.""" 9 | 10 | def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: 11 | 12 | raise NotImplementedError("StoppingCriteria needs to be subclassed") 13 | 14 | 15 | class MaxTimeCriteria(StoppingCriteria): 16 | """ 17 | This class can be used to stop generation whenever the full generation exceeds some amount of time. By default, the 18 | time will start being counted when you initialize this object. You can override this by passing an 19 | `initial_timestamp`. 20 | 21 | Args: 22 | max_time (`float`): 23 | The maximum allowed time in seconds for the generation. 24 | initial_timestamp (`float`, *optional*, defaults to `time.time()`): 25 | The start of the generation allowed time. 26 | """ 27 | 28 | def __init__(self, max_time: float, initial_timestamp: Optional[float] = None): 29 | self.max_time = max_time 30 | self.initial_timestamp = time.time() if initial_timestamp is None else initial_timestamp 31 | 32 | def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: 33 | return time.time() - self.initial_timestamp > self.max_time 34 | 35 | class StoppingCriteriaList(list): 36 | def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: 37 | return any(criteria(input_ids, scores) for criteria in self) 38 | 39 | -------------------------------------------------------------------------------- /src/transformers_neuronx/tensor_pool.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon Web Services and its Affiliates. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License.
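# A minimal sketch of how the stopping criteria defined in stopping_criteria.py above can be
# combined during generation (illustrative; the 30-second budget and the dummy tensor shapes
# are assumptions).
import torch
from transformers_neuronx.stopping_criteria import MaxTimeCriteria, StoppingCriteriaList

stopping_criteria = StoppingCriteriaList([MaxTimeCriteria(max_time=30.0)])

input_ids = torch.zeros((1, 8), dtype=torch.long)
scores = torch.zeros((1, 50257), dtype=torch.float)
should_stop = stopping_criteria(input_ids, scores)  # True once 30 seconds of wall-clock time have elapsed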
14 | # ============================================================================== 15 | import torch 16 | from typing import List 17 | import multiprocessing.pool 18 | 19 | class TensorPool: 20 | """ A pool that helps manage the liveness of the torch tensors""" 21 | 22 | def __init__(self): 23 | self.tensor_pool: List[torch.Tensor] = list() 24 | self.thread_pool = multiprocessing.pool.ThreadPool(processes=1) 25 | 26 | def push(self, tensors): 27 | if isinstance(tensors, torch.Tensor): 28 | self.tensor_pool.append(tensors) 29 | elif isinstance(tensors, (List, tuple)): 30 | for t in tensors: 31 | self.tensor_pool.append(t) 32 | else: 33 | raise TypeError(f"Unsupported type {type(tensors)} to TensorPool") 34 | 35 | def clear(self): 36 | self.tensor_pool.clear() 37 | 38 | def async_clear(self): 39 | task = self.thread_pool.apply_async(self.clear) 40 | return task -------------------------------------------------------------------------------- /src/transformers_neuronx/testing/data.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon Web Services and its Affiliates. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================= 15 | import torch 16 | 17 | 18 | def batch_varying_lengths(vocab_size, batch_size, context_length): 19 | """ 20 | Creates a batch with varying length context attention masks 21 | 22 | This is useful for testing that mixed/unordered context lengths produce the 23 | correct results. 24 | 25 | Example: 26 | attention_mask = [ 27 | [0, 0, 1], 28 | [0, 1, 1], 29 | [0, 0, 1], 30 | [1, 1, 1], 31 | ] 32 | """ 33 | inputs = torch.randint(0, vocab_size, [batch_size, context_length]) 34 | ones = torch.ones(batch_size, context_length) 35 | starts = torch.randint(0, context_length, [batch_size]) 36 | mask = (starts.view(-1, 1) < ones).int() # Ensure at least 1 item is set in mask 37 | return inputs, mask, starts 38 | 39 | 40 | def batch_all_lengths(vocab_size, context_length): 41 | """ 42 | Creates a batch with 1 of each context attention mask sizes 43 | 44 | This is useful for exhaustively checking if each context length is properly 45 | respected in the model. 46 | 47 | Example: 48 | attention_mask = [ 49 | [1, 1, 1], 50 | [0, 1, 1], 51 | [0, 0, 1], 52 | ] 53 | """ 54 | inputs = torch.randint(0, vocab_size, [context_length, context_length]) 55 | mask = torch.ones(context_length, context_length).triu().int() 56 | starts = torch.arange(context_length) 57 | inputs = inputs * mask 58 | return inputs, mask, starts 59 | 60 | 61 | def batch_full_lengths(vocab_size, batch_size, context_length): 62 | """ 63 | Creates a batch with identical full-length context attention masks 64 | 65 | This is only useful as a sanity check to see that batching works. 
66 | 67 | Example: 68 | attention_mask = [ 69 | [1, 1, 1], 70 | [1, 1, 1], 71 | [1, 1, 1], 72 | ] 73 | """ 74 | inputs = torch.randint(0, vocab_size, [batch_size, context_length]) 75 | mask = torch.ones(batch_size, context_length) 76 | starts = torch.zeros(batch_size) 77 | return inputs, mask, starts 78 | -------------------------------------------------------------------------------- /src/transformers_neuronx/testing/validation.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon Web Services and its Affiliates. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================= 15 | import os 16 | from typing import Union, Tuple, List, Optional 17 | 18 | import torch 19 | from transformers import PretrainedConfig, PreTrainedModel 20 | from transformers_neuronx.generation_utils import HuggingFaceGenerationModelAdapter 21 | 22 | 23 | def validate( 24 | model: Union[torch.nn.Module, PreTrainedModel], 25 | neuron: torch.nn.Module, 26 | config: Union[dict, PretrainedConfig], 27 | inputs: Union[List, Tuple, torch.Tensor], 28 | sequence_length: int, 29 | compiler_args: Optional[Union[List, str]] = None, 30 | rtol: float = 1e-5, 31 | atol: float = 1e-5, 32 | ) -> None: 33 | """ 34 | Compare the outputs of a Neuron model to a CPU model. 35 | 36 | Args: 37 | model: The CPU model golden reference. 38 | neuron: The Neuron model to be validated. 39 | config: The config for the model to be validated. 40 | inputs: The inputs to the model. 41 | sequence_length: The sequence length of the input + output tokens. 42 | compiler_args: The compiler arguments to be passed to the Neuron model. 43 | rtol: Relative tolerance when checking result equality. 44 | atol: Absolute tolerance when checking result equality. 45 | 46 | Raises: 47 | AssertionError: Error when the values differ.
48 | """ 49 | if compiler_args is None: 50 | os.environ["NEURON_CC_FLAGS"] = "--model-type=transformer --auto-cast=none" 51 | 52 | neuron.to_neuron() 53 | 54 | print("Configuration:") 55 | print(config) 56 | 57 | input_ids, attention_mask, start_ids = inputs 58 | batch_size, context_length = input_ids.shape 59 | new_tokens = sequence_length - context_length 60 | 61 | print("Input Ids:") 62 | print(input_ids) 63 | print("Attention Mask:") 64 | print(attention_mask) 65 | print("Start Ids:") 66 | print(start_ids) 67 | 68 | expected = model.generate( 69 | input_ids=input_ids, 70 | attention_mask=attention_mask, 71 | max_new_tokens=new_tokens, 72 | min_new_tokens=new_tokens, 73 | top_k=1, 74 | ) 75 | 76 | wrapper = HuggingFaceGenerationModelAdapter(config, neuron) 77 | actual_generate = wrapper.generate( 78 | input_ids=input_ids, 79 | attention_mask=attention_mask, 80 | max_new_tokens=new_tokens, 81 | min_new_tokens=new_tokens, 82 | top_k=1, 83 | ) 84 | neuron.reset() 85 | actual_sample = neuron.sample( 86 | input_ids, sequence_length=sequence_length, start_ids=start_ids, top_k=1 87 | ) 88 | 89 | print("CPU Result:") 90 | print(expected) 91 | print("Neuron Result (Generate)") 92 | print(actual_generate) 93 | print("Neuron Result (Sample)") 94 | print(actual_sample) 95 | 96 | torch.testing.assert_close( 97 | actual=actual_generate, expected=expected, rtol=rtol, atol=atol 98 | ) 99 | torch.testing.assert_close( 100 | actual=actual_sample, expected=expected, rtol=rtol, atol=atol 101 | ) 102 | -------------------------------------------------------------------------------- /src/transformers_neuronx/tools/ckpt_converter.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon Web Services and its Affiliates. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | import os 17 | import shutil 18 | import json 19 | import torch 20 | import math 21 | import argparse 22 | import time 23 | 24 | 25 | HELP_DOC_STR = """ 26 | This is a ckpt converter to generate random weight ckpt for any N layer based on existed single layer ckpt. 27 | 28 | It utilizes the `empty` save and loading by only preserve the shape info of weights. 
29 | For 400B Llama (around 100 layers) 30 | - Disk occupation: 13MB 31 | - convert time: 60s 32 | 33 | Usage of ckpt_converter: 34 | 35 | step1: generate a single layer ckpt and save it 36 | 37 | model = AutoModelForCausalLM.from_config(single_layer_config) 38 | save_pretrained_split(model, single_layer_directory) 39 | 40 | step2: use the converter to convert ckpt to arbiratry layer 41 | 42 | converter single_layer_directory n_layer_directory n 43 | 44 | Note: only validated for Llama architecture 45 | """ 46 | 47 | def convert(input_dir, output_dir, num_layers, init_std=1.0): 48 | os.makedirs(output_dir, exist_ok=True) 49 | 50 | with open(os.path.join(input_dir, "config.json")) as f: 51 | config_json = json.load(f) 52 | 53 | config_json["num_hidden_layers"] = num_layers 54 | 55 | with open(os.path.join(output_dir, "config.json"), 'w') as f: 56 | json.dump(config_json, f, indent=2) 57 | 58 | shutil.copy(os.path.join(input_dir, "generation_config.json"), os.path.join(output_dir, "generation_config.json")) 59 | 60 | 61 | src_bin_dir = os.path.join(input_dir, "pytorch_model.bin") 62 | 63 | dst_bin_dir = os.path.join(output_dir, "pytorch_model.bin") 64 | 65 | os.makedirs(dst_bin_dir, exist_ok=True) 66 | 67 | with open(os.path.join(src_bin_dir, "key_to_filename.json")) as f: 68 | src_key_to_filename_json = json.load(f) 69 | 70 | 71 | non_layer_weights = {} 72 | per_layer_weights = {} 73 | 74 | all_shape_infos = {} 75 | 76 | for k, v in src_key_to_filename_json.items(): 77 | src_file = os.path.join(src_bin_dir, v) 78 | 79 | w = torch.load(src_file) 80 | 81 | all_shape_infos[k] = w.shape 82 | 83 | if k.startswith("model.layers"): 84 | per_layer_weights[k] = v 85 | else: 86 | non_layer_weights[k] = v 87 | 88 | print(non_layer_weights, per_layer_weights) 89 | 90 | param_counter = 0 91 | 92 | dst_key_to_filename_json = {} 93 | 94 | 95 | total_param_bytes = 0 96 | 97 | def add_param(key, src_key): 98 | nonlocal param_counter 99 | nonlocal total_param_bytes 100 | 101 | dst_file = f"p{param_counter}.{key}.empty_json" 102 | dst_key_to_filename_json[key] = dst_file 103 | 104 | 105 | tensor_shape = all_shape_infos[src_key] 106 | empty_json = { 107 | 'torch_dtype': "float32", 108 | 'shape': tensor_shape, 109 | 'init_std': init_std, 110 | } 111 | 112 | with open(os.path.join(dst_bin_dir, dst_file), 'w') as f: 113 | json.dump(empty_json, f, indent=2) 114 | 115 | param_counter += 1 116 | 117 | total_param_bytes += math.prod(tensor_shape) 118 | 119 | add_param("model.embed_tokens.weight", "model.embed_tokens.weight") 120 | 121 | 122 | for i in range(num_layers): 123 | for w_k, w_f in per_layer_weights.items(): 124 | suffix = w_k.replace("model.layers.0.", "") 125 | k = f"model.layers.{i}.{suffix}" 126 | add_param(k, w_k) 127 | 128 | 129 | add_param("model.norm.weight", "model.norm.weight") 130 | add_param("lm_head.weight", "lm_head.weight") 131 | 132 | with open(os.path.join(dst_bin_dir, "key_to_filename.json"), 'w') as f: 133 | json.dump(dst_key_to_filename_json, f, indent=2) 134 | 135 | 136 | print(f"total params: {total_param_bytes} ({total_param_bytes / (10**9)}B), {total_param_bytes*2 / (10**9)} GB (for fp16)", ) 137 | 138 | 139 | def ckpt_converter(): 140 | 141 | parser = argparse.ArgumentParser(description=HELP_DOC_STR, 142 | formatter_class= argparse.RawTextHelpFormatter) 143 | parser.add_argument('input_dir', type=str) 144 | parser.add_argument('output_dir', type=str) 145 | parser.add_argument('num_layers', type=int) 146 | parser.add_argument('--init_std', type=float, default=1.0) 147 | 148 | 149 | 
args = parser.parse_args() 150 | 151 | start = time.time() 152 | convert(args.input_dir, args.output_dir, args.num_layers, args.init_std) 153 | 154 | print(f"convert done after {time.time() - start}s") 155 | 156 | 157 | if __name__ == "__main__": 158 | ckpt_converter() 159 | -------------------------------------------------------------------------------- /src/transformers_neuronx/tools/gen_hlo_snapshot.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon Web Services and its Affiliates. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | import argparse 16 | import torch 17 | from torch_neuronx.pyhlo import hlo_pb2, xla_data_pb2 18 | from transformers_neuronx import compiler 19 | 20 | 21 | def main_randn(): 22 | parser = argparse.ArgumentParser() 23 | parser.add_argument('module') 24 | parser.add_argument('snapshot') 25 | parser.add_argument('--std', type=float, default=0.01) 26 | parser.add_argument('--int', default='zeros') 27 | parser.add_argument('--treat_as_int', nargs='*', type=int, default=None) 28 | args = parser.parse_args() 29 | hlo_module = hlo_pb2.HloModuleProto() 30 | with open(args.module, 'rb') as fp: 31 | hlo_module.ParseFromString(fp.read()) 32 | torch.manual_seed(15213) 33 | int_func = getattr(torch, args.int) 34 | randn_inputs = compiler.gen_randn_inputs(hlo_module, args.std, int_func, args.treat_as_int) 35 | hlo_snapshot = gen_hlo_snapshot(hlo_module, randn_inputs) 36 | with open(args.snapshot, 'wb') as fp: 37 | fp.write(hlo_snapshot.SerializeToString()) 38 | 39 | 40 | def gen_hlo_snapshot(hlo_module, inputs): 41 | hlo_snapshot = hlo_pb2.HloSnapshot() 42 | hlo_snapshot.hlo.hlo_module.CopyFrom(hlo_module) 43 | for tensor, param in zip(inputs, hlo_module.host_program_shape.parameters): 44 | argument = hlo_snapshot.arguments.add() 45 | argument.shape.CopyFrom(param) 46 | name = xla_data_pb2.PrimitiveType.Name(param.element_type) 47 | attrname = f'{name.lower()}s' 48 | attr = getattr(argument, attrname) 49 | array = tensor.numpy().ravel() 50 | if isinstance(attr, bytes): 51 | setattr(argument, attrname, array.tobytes()) 52 | else: 53 | attr.extend(array.tolist()) 54 | return hlo_snapshot 55 | -------------------------------------------------------------------------------- /src/transformers_neuronx/util/token_tree.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from queue import Queue 3 | from typing import Dict, List, Set 4 | 5 | ROOT_NODE = 0 6 | 7 | def _validate_level_dict(level_dict: Dict[int, List[int]])-> (int, int): 8 | """ 9 | For a level, validate that all nodes are indexed from 10 | [all_nodes_till_precious_level, all_nodes_till_precious_level+nodes_in_curr_level) 11 | """ 12 | tree_depth = len(level_dict) 13 | nodes_counter = 0 14 | for lvl in range(0, tree_depth): 15 | nodes_in_level = len(level_dict[lvl]) 16 
| # Validate nodes from [nodes_counter,nodes_counter + nodes_in_level) 17 | expected_nodes = {i for i in range(nodes_counter, nodes_counter + nodes_in_level)} 18 | for node in level_dict[lvl]: 19 | assert node in expected_nodes, f"Node {node} not indexed correctly in a leveled order for level {lvl}" 20 | expected_nodes.remove(node) 21 | nodes_counter = nodes_counter + nodes_in_level 22 | return nodes_counter, tree_depth 23 | 24 | 25 | def _validate_all_nodes_discovered(visited: Set[int], token_tree: Dict[int, List[int]])-> None: 26 | """ 27 | Checks if the set of nodes discovered while using level order traversal from root node 0 is 28 | the complete set of nodes defined in the input token tree. 29 | """ 30 | for k in token_tree.keys(): 31 | assert k in visited, f"Invalid token tree with node {k}." 32 | 33 | 34 | def validate_token_tree(token_tree: Dict[int, List[int]])-> (int, int): 35 | """ 36 | Assume index 0 to be the root node (no incoming edges) and start level order traversal from it. 37 | Validate tree structure (not graph) while doing level order traversal 38 | Validate all nodes are discovered at the end of the level order traversal. 39 | Also validate with level order traversal that all nodes are in required order. 40 | """ 41 | assert ROOT_NODE in token_tree, "Token tree does not have the root node indexed as 0" 42 | visited = set() 43 | level_dict = {} 44 | q = Queue(maxsize = 0) 45 | visited.add(ROOT_NODE) 46 | q.put((ROOT_NODE,0)) 47 | level_dict[0] = [ROOT_NODE] 48 | while not q.empty(): 49 | curr, lvl = q.get() 50 | if curr in token_tree: 51 | for child in token_tree[curr]: 52 | assert child not in visited, "Cycle/Graph found instead of a tree." 53 | visited.add(child) 54 | q.put((child, lvl+1)) 55 | if lvl+1 not in level_dict: 56 | level_dict[lvl+1] = [] 57 | level_dict[lvl+1].append(child) 58 | _validate_all_nodes_discovered(visited, token_tree) 59 | return _validate_level_dict(level_dict) 60 | 61 | 62 | def generate_attention_mask(token_tree: Dict[int, List[int]])->torch.Tensor: 63 | """ 64 | Generate attention mask based on the token tree. 65 | """ 66 | total_nodes, depth = validate_token_tree(token_tree) 67 | attn_mask = torch.zeros(total_nodes, total_nodes, dtype=torch.int32) 68 | buffer = [] 69 | def populate_mask(): 70 | top = buffer[-1] 71 | for node in buffer: 72 | attn_mask[top][node] = 1 73 | # DFS on a tree. 74 | def DFS(node: int): 75 | buffer.append(node) 76 | if node in token_tree: 77 | for child in token_tree[node]: 78 | DFS(child) 79 | populate_mask() 80 | buffer.pop() 81 | DFS(ROOT_NODE) 82 | return attn_mask 83 | -------------------------------------------------------------------------------- /src/transformers_neuronx/version.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon Web Services and its Affiliates. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | 16 | # please make sure the major.minor version matches the interface version in Config 17 | __version__ = '0.13.x' 18 | --------------------------------------------------------------------------------
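# A minimal sketch of the token-tree attention mask utility from util/token_tree.py
# (the five-node tree below is an illustrative assumption, not taken from the sources,
# and the import path is assumed to mirror the source tree). Nodes must be indexed level
# by level starting from root node 0, as validate_token_tree enforces.
import torch
from transformers_neuronx.util.token_tree import generate_attention_mask

token_tree = {0: [1, 2], 1: [3, 4]}          # root 0 -> {1, 2}; node 1 -> {3, 4}
mask = generate_attention_mask(token_tree)   # (5, 5) int32 mask

# Each node attends to itself and its ancestors, e.g. node 3 attends to nodes 0, 1 and 3.
assert mask[3].tolist() == [1, 1, 0, 1, 0]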