├── .github ├── CODEOWNERS ├── azure-gpu-tests.yml └── workflows │ └── cpu-tests.yml ├── .gitignore ├── LICENSE ├── README.md ├── evaluate ├── adapter.py ├── adapter_v2.py ├── full.py └── lora.py ├── finetune ├── adapter.py ├── adapter_v2.py ├── full.py └── lora.py ├── generate.py ├── generate ├── adapter.py ├── adapter_v2.py ├── full.py └── lora.py ├── howto ├── convert_lora_weights.md ├── customize_paths.md ├── download_weights.md ├── finetune_adapter.md ├── finetune_adapter_v2.md ├── finetune_full.md ├── finetune_lora.md ├── inference.md ├── tpus.md ├── train_redpajama.md └── unstructured_dataset.md ├── lit_llama ├── __init__.py ├── adapter.py ├── adapter_v2.py ├── lora.py ├── model.py ├── packed_dataset.py ├── quantization.py ├── tokenizer.py └── utils.py ├── pretrain ├── redpajama.py └── shakespeare.py ├── pyproject.toml ├── quantize └── gptq.py ├── scripts ├── convert_checkpoint.py ├── convert_hf_checkpoint.py ├── convert_lora_weights.py ├── download.py ├── prepare_alpaca.py ├── prepare_any_text.py ├── prepare_dolly.py ├── prepare_redpajama.py └── prepare_shakespeare.py ├── setup.py └── tests ├── conftest.py ├── test_adapter.py ├── test_adapter_v2.py ├── test_generate.py ├── test_lora.py ├── test_model.py ├── test_packed_dataset.py ├── test_prepare_redpajama.py ├── test_prepare_shakespeare.py ├── test_rmsnorm.py ├── test_rope.py └── test_utils.py /.github/CODEOWNERS: -------------------------------------------------------------------------------- 1 | * @awaelchli @carmocca @lantiga 2 | -------------------------------------------------------------------------------- /.github/azure-gpu-tests.yml: -------------------------------------------------------------------------------- 1 | # Python package 2 | # Create and test a Python package on multiple Python versions. 3 | # Add steps that analyze code, save the dist with the build record, publish to a PyPI-compatible index, and more: 4 | # https://docs.microsoft.com/azure/devops/pipelines/languages/python 5 | 6 | trigger: 7 | tags: 8 | include: 9 | - '*' 10 | branches: 11 | include: 12 | - "main" 13 | - "refs/tags/*" 14 | 15 | pr: 16 | branches: 17 | include: 18 | - "main" 19 | 20 | jobs: 21 | - job: testing 22 | # how long to run the job before automatically cancelling 23 | timeoutInMinutes: "20" 24 | # how much time to give 'run always even if cancelled tasks' before stopping them 25 | cancelTimeoutInMinutes: "2" 26 | pool: "lit-rtx-3090" 27 | variables: 28 | DEVICES: $( python -c 'print("$(Agent.Name)".split("_")[-1])' ) 29 | container: 30 | image: "pytorchlightning/pytorch_lightning:base-cuda-py3.10-torch2.0-cuda11.7.1" 31 | options: "--gpus=all --shm-size=8gb" 32 | workspace: 33 | clean: all 34 | steps: 35 | 36 | - bash: | 37 | echo "##vso[task.setvariable variable=CUDA_VISIBLE_DEVICES]$(DEVICES)" 38 | cuda_ver=$(python -c "import torch ; print(''.join(map(str, torch.version.cuda.split('.')[:2])))") 39 | echo "##vso[task.setvariable variable=CUDA_VERSION_MM]$cuda_ver" 40 | displayName: 'set env. 
vars' 41 | 42 | - bash: | 43 | echo $CUDA_VISIBLE_DEVICES 44 | echo $CUDA_VERSION_MM 45 | lspci | egrep 'VGA|3D' 46 | whereis nvidia 47 | nvidia-smi 48 | which python && which pip 49 | python --version && pip --version && pip list 50 | python -c "import torch ; mgpu = torch.cuda.device_count() ; assert mgpu == 2, f'GPU: {mgpu}'" 51 | displayName: 'Image info & NVIDIA' 52 | 53 | - script: pip install ".[all]" "pytest" 54 | displayName: 'Install dependencies' 55 | 56 | - bash: pytest -v --durations=10 --disable-pytest-warnings --strict-markers --color=yes 57 | displayName: 'Testing' 58 | -------------------------------------------------------------------------------- /.github/workflows/cpu-tests.yml: -------------------------------------------------------------------------------- 1 | name: CPU tests 2 | 3 | on: 4 | push: 5 | branches: [main] 6 | pull_request: 7 | branches: [main] 8 | 9 | concurrency: 10 | group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref }} 11 | cancel-in-progress: ${{ github.ref != 'refs/heads/main' }} 12 | 13 | defaults: 14 | run: 15 | shell: bash 16 | 17 | jobs: 18 | pytester: 19 | runs-on: ${{ matrix.os }} 20 | strategy: 21 | fail-fast: false 22 | matrix: 23 | os: ["ubuntu-22.04", "macos-13", "windows-2022"] 24 | python-version: ["3.10"] 25 | pkg-install: ["no", "yes"] 26 | timeout-minutes: 15 27 | 28 | steps: 29 | - uses: actions/checkout@v3 30 | 31 | - name: Set up Python ${{ matrix.python-version }} 32 | uses: actions/setup-python@v4 33 | with: 34 | python-version: ${{ matrix.python-version }} 35 | cache: 'pip' 36 | cache-dependency-path: | 37 | pyproject.toml 38 | setup.py 39 | 40 | - name: Install package & dependencies 41 | run: | 42 | pip install ".[all]" pytest 43 | pip list 44 | 45 | - name: Drop package itself 46 | if: matrix.pkg-install == 'no' 47 | run: pip uninstall -y lit-llama 48 | 49 | - name: Run tests 50 | run: pytest -v --durations=10 51 | 52 | 53 | testing-guardian: 54 | runs-on: ubuntu-latest 55 | needs: pytester 56 | if: always() 57 | steps: 58 | - run: echo "${{ needs.pytester.result }}" 59 | - name: failing... 60 | if: needs.pytester.result == 'failure' 61 | run: exit 1 62 | - name: cancelled or skipped... 63 | if: contains(fromJSON('["cancelled", "skipped"]'), needs.pytester.result) 64 | timeout-minutes: 1 65 | run: sleep 90 66 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | .idea 3 | .DS_Store 4 | *.egg-info 5 | build 6 | 7 | # data 8 | data 9 | checkpoints 10 | out 11 | !data/shakespeare/prepare.py 12 | wandb 13 | 14 | # downloaded by our tests 15 | original_model.py 16 | original_adapter.py 17 | 18 | .ruff_cache/ -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | Lit-LLaMA 3 | 4 | # ⚡ Lit-LLaMA ️ 5 | 6 | ![cpu-tests](https://github.com/lightning-AI/lit-llama/actions/workflows/cpu-tests.yml/badge.svg) [![Build Status](https://dev.azure.com/Lightning-AI/lit%20Models/_apis/build/status%2FLightning-AI.lit-LLaMA?branchName=main)](https://dev.azure.com/Lightning-AI/lit%20Models/_build/latest?definitionId=49&branchName=main) [![license](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://github.com/Lightning-AI/lit-llama/blob/master/LICENSE) [![Discord](https://img.shields.io/discord/1077906959069626439?style=plastic)](https://discord.gg/VptPCZkGNa) 7 | 8 |
  9 | ⚠️ Warning: Not Actively Maintained
 10 | 
 11 | This repository is no longer actively maintained. For a more up-to-date alternative, please visit the LitGPT project:
 12 | https://github.com/Lightning-AI/litgpt, which serves as the successor to this repository.
 13 | 
 14 | Feel free to explore, reuse, or fork, but be aware that no further updates or support will be provided.
 15 | 
16 | 17 | Lit-LLaMA and pineapple pizza 18 | 19 |
20 | 21 | # ⚡ Lit-LLaMA ️ 22 | Independent implementation of [LLaMA]() pretraining, finetuning, and inference code that is fully open source under the **Apache 2.0 license.** 23 | 24 | This implementation builds on [nanoGPT](). 25 | 26 | The open-source code in this repository works with the original LLaMA weights that are distributed by Meta under a [research-only license](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md#model-details). 27 | 28 | ## Looking for LLaMA 2? 29 | 30 | Meta AI has since released LLaMA 2. Additionally, new Apache 2.0 licensed weights are being released as part of the [Open LLaMA project](https://github.com/openlm-research/open_llama). 31 | 32 | To run LLaMA 2 weights, Open LLaMA weights, or Vicuna weights (among other LLaMA-like checkpoints), **check out the [Lit-GPT repository](https://github.com/Lightning-AI/lit-gpt)**. 33 | 34 | ## Why? 35 | 36 | We believe that AI should be fully open source and part of the collective knowledge. 37 | 38 | The original [LLaMA code](https://github.com/facebookresearch/llama) is [GPL licensed](https://github.com/facebookresearch/llama/blob/main/LICENSE) which means any project using it must also be released under GPL. 39 | 40 | This "taints" any other code and prevents integration with the rest of the ecosystem. 41 | 42 | **Lit-LLaMA solves that for good.** 43 | 44 |   45 | 46 | ## Design principles 47 | **Lit-LLaMA** is: 48 | 49 | - **Simple:** Single-file implementation without boilerplate. 50 | - **Correct:** Numerically equivalent to the original model. 51 | - **Optimized:** Runs on consumer hardware or at scale. 52 | - **Open-source:** No strings attached. 53 | 54 | ## Get involved! 55 | [Join our Discord](https://discord.gg/VptPCZkGNa) to build high-performance, truly open-source models for the common benefit of the community. 56 | 57 |   58 | 59 | ## Setup 60 | 61 | Clone the repo 62 | 63 | ```bash 64 | git clone https://github.com/Lightning-AI/lit-llama 65 | cd lit-llama 66 | ``` 67 | 68 | install dependencies 69 | 70 | ```bash 71 | pip install -e ".[all]" 72 | ``` 73 | 74 | You are all set! 🎉 75 | 76 |   77 | 78 | ## Use the model 79 | 80 | To generate text predictions, you need to download the model weights. **If you don't have them, check out our [guide](howto/download_weights.md).** 81 | 82 | Run inference: 83 | 84 | ```bash 85 | python generate.py --prompt "Hello, my name is" 86 | ``` 87 | 88 | This will run the 7B model and require ~26 GB of GPU memory (A100 GPU). 89 | 90 | [Full guide for generating samples from the model](howto/inference.md). 91 | 92 | ### Run Lit-LLaMA on consumer devices 93 | 94 | On GPUs with `bfloat16` support, the `generate.py` script will automatically convert the weights and consume about ~14 GB. 95 | For GPUs with less memory, or ones that don't support `bfloat16`, enable quantization (`--quantize llm.int8`): 96 | 97 | ```bash 98 | python generate.py --quantize llm.int8 --prompt "Hello, my name is" 99 | ``` 100 | 101 | See `python generate.py --help` for more options. 102 | 103 | You can also use GPTQ-style int4 quantization, but this needs conversions of the weights first: 104 | 105 | ```bash 106 | python quantize/gptq.py --output_path checkpoints/lit-llama/7B/llama-gptq.4bit.pth --dtype bfloat16 --quantize gptq.int4 107 | ``` 108 | 109 | GPTQ-style int4 quantization brings GPU usage down to about ~5GB. As only the weights of the Linear layers are quantized, it is useful to also use `--dtype bfloat16` even with the quantization enabled. 
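As a rough sanity check on that ~5 GB figure, here is a back-of-the-envelope estimate of the weight memory (a sketch only: it assumes roughly 7B parameters and an illustrative split between Linear and non-Linear weights, and it does not count activations, the KV cache, or the GPTQ scale factors exactly):

```python
# Back-of-the-envelope weight-memory estimate for the 7B checkpoint.
# Assumptions (illustrative, not measured): ~7e9 parameters in total,
# with the vast majority of them living in Linear layers.
total_params = 7e9
linear_fraction = 0.97  # embeddings and norm layers make up the small remainder

int4_weights_gb = total_params * linear_fraction * 0.5 / 1e9     # 4 bits per quantized weight
bf16_weights_gb = total_params * (1 - linear_fraction) * 2 / 1e9  # unquantized weights stay in bfloat16

print(f"quantized Linear weights: {int4_weights_gb:.1f} GB")  # ~3.4 GB
print(f"remaining bf16 weights:   {bf16_weights_gb:.1f} GB")  # ~0.4 GB
# Quantization scales, activations, the KV cache and allocator overhead
# account for the rest of the ~5 GB observed in practice.
```

The exact number you observe will also depend on the prompt length and the number of generated tokens, since the KV cache grows with the sequence length.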
110 | 111 | With the generated quantized checkpoint generation quantization then works as usual with `--quantize gptq.int4` and the newly generated checkpoint file: 112 | 113 | ```bash 114 | python generate.py --quantize gptq.int4 --checkpoint_path checkpoints/lit-llama/7B/llama-gptq.4bit.pth 115 | ``` 116 | 117 | [Full guide for generating samples from the model](howto/inference.md). 118 | 119 | ## Finetune the model 120 | 121 | We provide a simple training scripts in `finetune/lora.py` and `finetune/adapter.py` that instruction-tunes a pretrained model on the [Alpaca](https://github.com/tatsu-lab/stanford_alpaca) dataset using the techniques of [LoRA](https://arxiv.org/abs/2106.09685) and [Adapter](https://arxiv.org/abs/2303.16199). 122 | 123 | 1. Download the data and generate a instruction tuning dataset: 124 | 125 | ```bash 126 | python scripts/prepare_alpaca.py 127 | ``` 128 | 129 | 2. Run the finetuning script 130 | 131 | ```bash 132 | python finetune/lora.py 133 | ``` 134 | or 135 | ```bash 136 | python finetune/adapter.py 137 | ``` 138 | 139 | It is expected that you have downloaded the pretrained weights as described above. 140 | The finetuning requires at least one GPU with ~24 GB memory (RTX 3090). Follow the instructions in the script to efficiently fit your GPU memory. 141 | Note: For some GPU models you might need to set `torch.backends.cuda.enable_flash_sdp(False)` (see comments at the top of the script). 142 | 143 | More details about each finetuning method and how you can apply it to your own data can be found in our technical how-to guides. 144 | 145 | ### Finetuning How-To Guides 146 | 147 | These technical tutorials illustrate how to run the finetuning code. 148 | 149 | - [Finetune with LoRA](howto/finetune_lora.md) 150 | - [Finetune with Adapters](howto/finetune_adapter.md) 151 | 152 | ### Understanding Finetuning -- Conceptual Tutorials 153 | 154 | Looking for conceptual tutorials and explanations? We have some additional articles below: 155 | 156 | - [Understanding Parameter-Efficient Finetuning of Large Language Models: From Prefix Tuning to LLaMA-Adapters](https://lightning.ai/pages/community/article/understanding-llama-adapters/) 157 | 158 | ## Pre-training 159 | 160 | We provide a simple training script based on Fabric if you want to venture into pre-training on RedPajama, a reproduction of the original LLaMA dataset. 161 | Conversion scripts for our optimized streaming `PackedDataset` are included. 162 | 163 | Follow this guide to start pre-training on the RedPajama dataset: 164 | 165 | - [Pretrain on RedPajama](howto/train_redpajama.md) 166 | 167 | ## Get involved! 168 | 169 | We are on a quest towards fully open source AI. 170 | 171 | Lit-LLaMA 172 | 173 | Join us and start contributing, especially on the following areas: 174 | 175 | - [ ] [Pre-training](https://github.com/Lightning-AI/lit-llama/labels/pre-training) 176 | - [ ] [Fine-tuning (full and LoRA)](https://github.com/Lightning-AI/lit-llama/labels/fine-tuning) 177 | - [ ] [Quantization](https://github.com/Lightning-AI/lit-llama/labels/quantization) 178 | - [ ] [Sparsification](https://github.com/Lightning-AI/lit-llama/labels/sparsification) 179 | 180 | Look at `train.py` for a starting point towards pre-training / fine-tuning using [Lightning Fabric](https://lightning.ai/docs/fabric/stable/). 181 | 182 | We welcome all individual contributors, regardless of their level of experience or hardware. 
Your contributions are valuable, and we are excited to see what you can accomplish in this collaborative and supportive environment. 183 | 184 | Unsure about contributing? Check out our [Contributing to Lit-LLaMA: A Hitchhiker’s Guide to the Quest for Fully Open-Source AI](https://lightning.ai/pages/community/tutorial/contributing-to-lit-llama-a-hitchhikers-guide-to-the-quest-for-fully-open-source-ai/) guide. 185 | 186 | Don't forget to [join our Discord](https://discord.gg/VptPCZkGNa)! 187 | 188 | ## Acknowledgements 189 | 190 | - [@karpathy](https://github.com/karpathy) for [nanoGPT](https://github.com/karpathy/nanoGPT) 191 | - [@FacebookResearch](https://github.com/facebookresearch) for the original [LLaMA implementation](https://github.com/facebookresearch/llama) 192 | - [@TimDettmers](https://github.com/TimDettmers) for [bitsandbytes](https://github.com/TimDettmers/bitsandbytes) 193 | - [@Microsoft](https://github.com/microsoft) for [LoRA](https://github.com/microsoft/LoRA) 194 | - [@IST-DASLab](https://github.com/IST-DASLab) for [GPTQ](https://github.com/IST-DASLab/gptq) 195 | 196 | ## License 197 | 198 | Lit-LLaMA is released under the [Apache 2.0](https://github.com/Lightning-AI/lightning-llama/blob/main/LICENSE) license. 199 | -------------------------------------------------------------------------------- /evaluate/adapter.py: -------------------------------------------------------------------------------- 1 | # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. 2 | 3 | # This mimics GPTQ's evaluation metrics: https://github.com/IST-DASLab/gptq/ 4 | # Thanks to E. Frantar et al GPTQ: Accurate Post-training Compression for GPT, arXiv:2210.17323 5 | import math 6 | import sys 7 | import time 8 | from pathlib import Path 9 | from typing import Optional 10 | 11 | import lightning as L 12 | import torch 13 | import tqdm 14 | 15 | # support running without installing as a package 16 | wd = Path(__file__).parent.parent.resolve() 17 | sys.path.append(str(wd)) 18 | 19 | from lit_llama import Tokenizer 20 | from lit_llama.adapter import LLaMA 21 | from lit_llama.utils import EmptyInitOnDevice, lazy_load, llama_model_lookup 22 | from scripts.prepare_alpaca import generate_prompt 23 | 24 | from datasets import load_dataset 25 | 26 | instruction_tuning = True 27 | 28 | 29 | def load_eval_data(dataset_name: str) -> str: 30 | # this mimics gptq datautils 31 | if dataset_name == "wikitext": 32 | # traindata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='train') 33 | testdata = load_dataset("wikitext", "wikitext-2-raw-v1", split="test") 34 | testdata = "\n\n".join(testdata["text"]) 35 | elif dataset_name == "ptb": 36 | testdata = load_dataset("ptb_text_only", "penn_treebank", split="test") 37 | testdata = "\n\n".join(testdata["sentence"]) 38 | elif dataset_name == "c4": 39 | testdata = load_dataset( 40 | "allenai/c4", 41 | "allenai--c4", 42 | data_files={"validation": "en/c4-validation.00000-of-00008.json.gz"}, 43 | split="validation", 44 | ) 45 | testdata = " ".join(testdata[:1100]["text"]) 46 | 47 | else: 48 | raise ValueError("invalid dataset name (wikitext, ptb, c4 are allowed)") 49 | return testdata 50 | 51 | 52 | @torch.inference_mode() 53 | def main( 54 | datasets: str = "wikitext,ptb,c4", 55 | *, 56 | # compilation fails as it does not support torch.complex64 for RoPE 57 | # compile: bool = False, 58 | accelerator: str = "auto", 59 | adapter_path: Path = Path("out/adapter/alpaca/lit-llama-adapter-finetuned.pth"), 60 | checkpoint_path: Path = 
Path("checkpoints/lit-llama/7B/lit-llama.pth"), 61 | tokenizer_path: Path = Path("checkpoints/lit-llama/tokenizer.model"), 62 | dtype: str = "float32", 63 | quantize: Optional[str] = None, 64 | ) -> None: 65 | """Generates text samples based on a pre-trained LLaMA model and tokenizer. 66 | 67 | Args: 68 | datasets: The datasets to use as a comma separated string 69 | # compile: Whether to compile the model. 70 | accelerator: The hardware to run on. Possible choices are: 71 | ``"cpu"``, ``"cuda"``, ``"mps"``, ``"gpu"``, ``"tpu"``, ``"auto"``. 72 | adapter_path: Path to the checkpoint with trained adapter weights, which are the output of 73 | `finetune_adapter.py`. 74 | checkpoint_path: The checkpoint path to load. 75 | tokenizer_path: The tokenizer path to load. 76 | dtype: The tensor dtype for choosing the floating-point precision 77 | quantize: Whether to quantize the model and using which method: 78 | ``"llm.int8"``: LLM.int8() mode, 79 | ``"gptq.int4"``: GPTQ 4-bit mode. 80 | """ 81 | assert adapter_path.is_file() 82 | assert checkpoint_path.is_file() 83 | assert tokenizer_path.is_file() 84 | 85 | fabric = L.Fabric(accelerator=accelerator, devices=1) 86 | 87 | dt = getattr(torch, dtype, None) 88 | if not isinstance(dt, torch.dtype): 89 | raise ValueError(f"{dtype} is not a valid dtype.") 90 | dtype = dt 91 | 92 | print("Loading model ...", file=sys.stderr) 93 | t0 = time.time() 94 | with lazy_load(checkpoint_path) as pretrained_checkpoint, lazy_load(adapter_path) as adapter_checkpoint: 95 | name = llama_model_lookup(pretrained_checkpoint) 96 | 97 | with EmptyInitOnDevice( 98 | device=fabric.device, dtype=dtype, quantization_mode=quantize 99 | ): 100 | model = LLaMA.from_name(name) 101 | 102 | # 1. Load the pretrained weights 103 | model.load_state_dict(pretrained_checkpoint, strict=False) 104 | # 2. 
Load the fine-tuned adapter weights 105 | model.load_state_dict(adapter_checkpoint, strict=False) 106 | 107 | print(f"Time to load model: {time.time() - t0:.02f} seconds.", file=sys.stderr) 108 | 109 | model.eval() 110 | 111 | # if compile: 112 | # model = torch.compile(model) 113 | 114 | total_toks = 0 115 | model = fabric.setup_module(model) 116 | 117 | tokenizer = Tokenizer(tokenizer_path) 118 | 119 | for dsname in datasets.split(","): 120 | test_string = load_eval_data(dsname) 121 | 122 | if instruction_tuning: 123 | sample = {"instruction": test_string, "input": input} 124 | test_string = generate_prompt(sample) 125 | 126 | encoded_text = tokenizer.encode( 127 | test_string, bos=True, eos=False, device=fabric.device 128 | ) 129 | encoded_text = encoded_text[ 130 | None, : 256 * model.config.block_size 131 | ] # add batch dimension, trim like gptq implementation 132 | t0 = time.perf_counter() 133 | 134 | nlls = 0 135 | toks = 0 136 | block_size = 2048 # this is for compat with gptq, and indeed we get much worse beyond this (https://github.com/facebookresearch/llama/blob/57b0eb62de0636e75af471e49e2f1862d908d9d8/llama/model.py#L30) 137 | for i in tqdm.tqdm(range(0, encoded_text.shape[1], block_size)): 138 | inp = encoded_text[:, i : i + block_size] 139 | logits = model(inp)[0] 140 | nll = torch.nn.functional.cross_entropy( 141 | logits[:-1], inp[0, 1:].to(dtype=torch.long), reduction="sum" 142 | ) 143 | toks += inp.size(1) - 1 144 | nlls += nll.item() 145 | 146 | print(encoded_text.shape, logits.shape) 147 | ppl = math.exp(nlls / toks) 148 | print(f"Perplexity on {dsname}: {ppl:.2f}") 149 | total_toks += toks 150 | 151 | t = time.perf_counter() - t0 152 | print( 153 | f"\n\nTime for inference: {t:.02f} sec total, {total_toks / t:.02f} tokens/sec", 154 | file=sys.stderr, 155 | ) 156 | print( 157 | f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB", 158 | file=sys.stderr, 159 | ) 160 | 161 | 162 | if __name__ == "__main__": 163 | from jsonargparse import CLI 164 | 165 | torch.set_float32_matmul_precision("high") 166 | CLI(main) 167 | -------------------------------------------------------------------------------- /evaluate/adapter_v2.py: -------------------------------------------------------------------------------- 1 | # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. 2 | 3 | # This mimics GPTQ's evaluation metrics: https://github.com/IST-DASLab/gptq/ 4 | # Thanks to E. 
Frantar et al GPTQ: Accurate Post-training Compression for GPT, arXiv:2210.17323 5 | import math 6 | import sys 7 | import time 8 | from pathlib import Path 9 | from typing import Optional 10 | 11 | import lightning as L 12 | import torch 13 | import tqdm 14 | 15 | # support running without installing as a package 16 | wd = Path(__file__).parent.parent.resolve() 17 | sys.path.append(str(wd)) 18 | 19 | from lit_llama import Tokenizer 20 | from lit_llama.adapter import LLaMA 21 | from lit_llama.utils import EmptyInitOnDevice, lazy_load, llama_model_lookup 22 | from lit_llama.adapter_v2 import add_adapter_v2_parameters_to_linear_layers 23 | from scripts.prepare_alpaca import generate_prompt 24 | 25 | from datasets import load_dataset 26 | 27 | 28 | def load_eval_data(dataset_name: str) -> str: 29 | # this mimics gptq datautils 30 | if dataset_name == "wikitext": 31 | # traindata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='train') 32 | testdata = load_dataset("wikitext", "wikitext-2-raw-v1", split="test") 33 | testdata = "\n\n".join(testdata["text"]) 34 | elif dataset_name == "ptb": 35 | testdata = load_dataset("ptb_text_only", "penn_treebank", split="test") 36 | testdata = "\n\n".join(testdata["sentence"]) 37 | elif dataset_name == "c4": 38 | testdata = load_dataset( 39 | "allenai/c4", 40 | "allenai--c4", 41 | data_files={"validation": "en/c4-validation.00000-of-00008.json.gz"}, 42 | split="validation", 43 | ) 44 | testdata = " ".join(testdata[:1100]["text"]) 45 | 46 | else: 47 | raise ValueError("invalid dataset name (wikitext, ptb, c4 are allowed)") 48 | return testdata 49 | 50 | 51 | @torch.inference_mode() 52 | def main( 53 | datasets: str = "wikitext,ptb,c4", 54 | *, 55 | accelerator: str = "auto", 56 | adapter_path: Path = Path("out/adapter_v2/alpaca/lit-llama-adapter-finetuned.pth"), 57 | checkpoint_path: Path = Path("checkpoints/lit-llama/7B/lit-llama.pth"), 58 | tokenizer_path: Path = Path("checkpoints/lit-llama/tokenizer.model"), 59 | dtype: str = "float32", 60 | quantize: Optional[str] = None, 61 | ) -> None: 62 | """Generates text samples based on a pre-trained LLaMA model and tokenizer. 63 | 64 | Args: 65 | datasets: The datasets to use as a comma separated string 66 | accelerator: The hardware to run on. Possible choices are: 67 | ``"cpu"``, ``"cuda"``, ``"mps"``, ``"gpu"``, ``"tpu"``, ``"auto"``. 68 | adapter_path: Path to the checkpoint with trained adapter weights, which are the output of 69 | `finetune_adapter_v2.py`. 70 | checkpoint_path: The checkpoint path to load. 71 | tokenizer_path: The tokenizer path to load. 72 | dtype: The tensor dtype for choosing the floating-point precision 73 | quantize: Whether to quantize the model and using which method: 74 | ``"llm.int8"``: LLM.int8() mode, 75 | ``"gptq.int4"``: GPTQ 4-bit mode. 
76 | """ 77 | assert adapter_path.is_file() 78 | assert checkpoint_path.is_file() 79 | assert tokenizer_path.is_file() 80 | 81 | fabric = L.Fabric(accelerator=accelerator, devices=1) 82 | 83 | dt = getattr(torch, dtype, None) 84 | if not isinstance(dt, torch.dtype): 85 | raise ValueError(f"{dtype} is not a valid dtype.") 86 | dtype = dt 87 | 88 | print("Loading model ...", file=sys.stderr) 89 | t0 = time.time() 90 | with lazy_load(checkpoint_path) as pretrained_checkpoint, lazy_load(adapter_path) as adapter_checkpoint: 91 | name = llama_model_lookup(pretrained_checkpoint) 92 | 93 | with EmptyInitOnDevice( 94 | device=fabric.device, dtype=dtype, quantization_mode=quantize 95 | ): 96 | model = LLaMA.from_name(name) 97 | add_adapter_v2_parameters_to_linear_layers(model) 98 | 99 | # 1. Load the pretrained weights 100 | model.load_state_dict(pretrained_checkpoint, strict=False) 101 | # 2. Load the fine-tuned adapter weights 102 | model.load_state_dict(adapter_checkpoint, strict=False) 103 | 104 | print(f"Time to load model: {time.time() - t0:.02f} seconds.", file=sys.stderr) 105 | 106 | model.eval() 107 | 108 | # if compile: 109 | # model = torch.compile(model) 110 | 111 | total_toks = 0 112 | model = fabric.setup_module(model) 113 | 114 | tokenizer = Tokenizer(tokenizer_path) 115 | 116 | for dsname in datasets.split(","): 117 | test_string = load_eval_data(dsname) 118 | 119 | sample = {"instruction": test_string, "input": input} 120 | test_string = generate_prompt(sample) 121 | 122 | encoded_text = tokenizer.encode( 123 | test_string, bos=True, eos=False, device=fabric.device 124 | ) 125 | encoded_text = encoded_text[ 126 | None, : 256 * model.config.block_size 127 | ] # add batch dimension, trim like gptq implementation 128 | t0 = time.perf_counter() 129 | 130 | nlls = 0 131 | toks = 0 132 | 133 | block_size = 2048 # this is for compat with gptq, and indeed we get much worse beyond this (https://github.com/facebookresearch/llama/blob/57b0eb62de0636e75af471e49e2f1862d908d9d8/llama/model.py#L30) 134 | for i in tqdm.tqdm(range(0, encoded_text.shape[1], block_size)): 135 | inp = encoded_text[:, i : i + block_size] 136 | logits = model(inp)[0] 137 | nll = torch.nn.functional.cross_entropy( 138 | logits[:-1], inp[0, 1:].to(dtype=torch.long), reduction="sum" 139 | ) 140 | toks += inp.size(1) - 1 141 | nlls += nll.item() 142 | 143 | print(encoded_text.shape, logits.shape) 144 | ppl = math.exp(nlls / toks) 145 | print(f"Perplexity on {dsname}: {ppl:.2f}") 146 | total_toks += toks 147 | 148 | t = time.perf_counter() - t0 149 | print( 150 | f"\n\nTime for inference: {t:.02f} sec total, {total_toks / t:.02f} tokens/sec", 151 | file=sys.stderr, 152 | ) 153 | print( 154 | f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB", 155 | file=sys.stderr, 156 | ) 157 | 158 | 159 | if __name__ == "__main__": 160 | from jsonargparse import CLI 161 | 162 | torch.set_float32_matmul_precision("high") 163 | CLI(main) 164 | -------------------------------------------------------------------------------- /evaluate/full.py: -------------------------------------------------------------------------------- 1 | # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. 2 | 3 | # This mimics GPTQ's evaluation metrics: https://github.com/IST-DASLab/gptq/ 4 | # Thanks to E. 
Frantar et al GPTQ: Accurate Post-training Compression for GPT, arXiv:2210.17323 5 | import math 6 | import sys 7 | import time 8 | from pathlib import Path 9 | from typing import Optional 10 | 11 | import lightning as L 12 | import torch 13 | import tqdm 14 | 15 | # support running without installing as a package 16 | wd = Path(__file__).parent.parent.resolve() 17 | sys.path.append(str(wd)) 18 | 19 | from lit_llama import LLaMA, Tokenizer 20 | from lit_llama.utils import EmptyInitOnDevice 21 | 22 | from datasets import load_dataset 23 | 24 | 25 | def load_eval_data(dataset_name: str) -> str: 26 | # this mimics gptq datautils 27 | if dataset_name == "wikitext": 28 | # traindata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='train') 29 | testdata = load_dataset("wikitext", "wikitext-2-raw-v1", split="test") 30 | testdata = "\n\n".join(testdata["text"]) 31 | elif dataset_name == "ptb": 32 | testdata = load_dataset("ptb_text_only", "penn_treebank", split="test") 33 | testdata = "\n\n".join(testdata["sentence"]) 34 | elif dataset_name == "c4": 35 | testdata = load_dataset( 36 | "allenai/c4", 37 | "allenai--c4", 38 | data_files={"validation": "en/c4-validation.00000-of-00008.json.gz"}, 39 | split="validation", 40 | ) 41 | testdata = " ".join(testdata[:1100]["text"]) 42 | 43 | else: 44 | raise ValueError("invalid dataset name (wikitext, ptb, c4 are allowed)") 45 | return testdata 46 | 47 | 48 | def main( 49 | datasets: str = "wikitext,ptb,c4", 50 | *, 51 | # compilation fails as it does not support torch.complex64 for RoPE 52 | # compile: bool = False, 53 | accelerator: str = "auto", 54 | checkpoint_path: Optional[Path] = None, 55 | tokenizer_path: Path = Path("checkpoints/lit-llama/tokenizer.model"), 56 | model_size: str = "7B", 57 | dtype: str = "float32", 58 | quantize: Optional[str] = None, 59 | ) -> None: 60 | """Generates text samples based on a pre-trained LLaMA model and tokenizer. 61 | 62 | Args: 63 | datasets: The datasets to use as a comma separated string 64 | # compile: Whether to compile the model. 65 | accelerator: The hardware to run on. Possible choices are: 66 | ``"cpu"``, ``"cuda"``, ``"mps"``, ``"gpu"``, ``"tpu"``, ``"auto"``. 67 | checkpoint_path: The checkpoint path to load. 68 | tokenizer_path: The tokenizer path to load. 69 | dtype: The tensor dtype for choosing the floating-point precision 70 | quantize: Whether to quantize the model and using which method: 71 | ``"llm.int8"``: LLM.int8() mode, 72 | ``"gptq.int4"``: GPTQ 4-bit mode. 
73 | """ 74 | if not checkpoint_path: 75 | checkpoint_path = Path(f"checkpoints/lit-llama/{model_size}/lit-llama.pth") 76 | assert checkpoint_path.is_file() 77 | assert tokenizer_path.is_file() 78 | 79 | fabric = L.Fabric(accelerator=accelerator, devices=1) 80 | 81 | dt = getattr(torch, dtype, None) 82 | if not isinstance(dt, torch.dtype): 83 | raise ValueError(f"{dtype} is not a valid dtype.") 84 | dtype = dt 85 | 86 | with EmptyInitOnDevice( 87 | device=fabric.device, dtype=dtype, quantization_mode=quantize 88 | ): 89 | print("Loading model ...", file=sys.stderr) 90 | t0 = time.time() 91 | model = LLaMA.from_name(model_size) 92 | checkpoint = torch.load(checkpoint_path) 93 | model.load_state_dict(checkpoint) 94 | print(f"Time to load model: {time.time() - t0:.02f} seconds.", file=sys.stderr) 95 | 96 | model.eval() 97 | 98 | # if compile: 99 | # model = torch.compile(model) 100 | 101 | total_toks = 0 102 | model = fabric.setup_module(model) 103 | 104 | tokenizer = Tokenizer(tokenizer_path) 105 | 106 | for dsname in datasets.split(","): 107 | test_string = load_eval_data(dsname) 108 | encoded_text = tokenizer.encode( 109 | test_string, bos=True, eos=False, device=fabric.device 110 | ) 111 | encoded_text = encoded_text[ 112 | None, : 256 * model.config.block_size 113 | ] # add batch dimension, trim like gptq implementation 114 | t0 = time.perf_counter() 115 | 116 | nlls = 0 117 | toks = 0 118 | with torch.inference_mode(): 119 | block_size = 2048 # this is for compat with gptq, and indeed we get much worse beyond this (https://github.com/facebookresearch/llama/blob/57b0eb62de0636e75af471e49e2f1862d908d9d8/llama/model.py#L30) 120 | for i in tqdm.tqdm(range(0, encoded_text.shape[1], block_size)): 121 | inp = encoded_text[:, i : i + block_size] 122 | logits = model(inp)[0] 123 | nll = torch.nn.functional.cross_entropy( 124 | logits[:-1], inp[0, 1:].to(dtype=torch.long), reduction="sum" 125 | ) 126 | toks += inp.size(1) - 1 127 | nlls += nll.item() 128 | 129 | print(encoded_text.shape, logits.shape) 130 | ppl = math.exp(nlls / toks) 131 | print(f"Perplexity on {dsname}: {ppl:.2f}") 132 | total_toks += toks 133 | 134 | t = time.perf_counter() - t0 135 | print( 136 | f"\n\nTime for inference: {t:.02f} sec total, {total_toks / t:.02f} tokens/sec", 137 | file=sys.stderr, 138 | ) 139 | print( 140 | f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB", 141 | file=sys.stderr, 142 | ) 143 | 144 | 145 | if __name__ == "__main__": 146 | from jsonargparse import CLI 147 | 148 | torch.set_float32_matmul_precision("high") 149 | CLI(main) 150 | -------------------------------------------------------------------------------- /evaluate/lora.py: -------------------------------------------------------------------------------- 1 | # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. 2 | 3 | # This mimics GPTQ's evaluation metrics: https://github.com/IST-DASLab/gptq/ 4 | # Thanks to E. 
Frantar et al GPTQ: Accurate Post-training Compression for GPT, arXiv:2210.17323 5 | import math 6 | import sys 7 | import time 8 | from pathlib import Path 9 | from typing import Optional 10 | 11 | import lightning as L 12 | import torch 13 | import tqdm 14 | 15 | # support running without installing as a package 16 | wd = Path(__file__).parent.parent.resolve() 17 | sys.path.append(str(wd)) 18 | 19 | from lit_llama import LLaMA, Tokenizer 20 | from lit_llama.utils import EmptyInitOnDevice, lazy_load, llama_model_lookup 21 | from lit_llama.lora import lora 22 | from scripts.prepare_alpaca import generate_prompt 23 | 24 | from datasets import load_dataset 25 | 26 | instruction_tuning = True 27 | lora_r = 8 28 | lora_alpha = 16 29 | lora_dropout = 0.05 30 | 31 | 32 | def load_eval_data(dataset_name: str) -> str: 33 | # this mimics gptq datautils 34 | if dataset_name == "wikitext": 35 | # traindata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='train') 36 | testdata = load_dataset("wikitext", "wikitext-2-raw-v1", split="test") 37 | testdata = "\n\n".join(testdata["text"]) 38 | elif dataset_name == "ptb": 39 | testdata = load_dataset("ptb_text_only", "penn_treebank", split="test") 40 | testdata = "\n\n".join(testdata["sentence"]) 41 | elif dataset_name == "c4": 42 | testdata = load_dataset( 43 | "allenai/c4", 44 | "allenai--c4", 45 | data_files={"validation": "en/c4-validation.00000-of-00008.json.gz"}, 46 | split="validation", 47 | ) 48 | testdata = " ".join(testdata[:1100]["text"]) 49 | 50 | else: 51 | raise ValueError("invalid dataset name (wikitext, ptb, c4 are allowed)") 52 | return testdata 53 | 54 | 55 | def main( 56 | datasets: str = "wikitext,ptb,c4", 57 | *, 58 | # compilation fails as it does not support torch.complex64 for RoPE 59 | # compile: bool = False, 60 | accelerator: str = "auto", 61 | lora_path: Path = Path("out/lora/alpaca/lit-llama-lora-finetuned.pth"), 62 | checkpoint_path: Path = Path("checkpoints/lit-llama/7B/lit-llama.pth"), 63 | tokenizer_path: Path = Path("checkpoints/lit-llama/tokenizer.model"), 64 | dtype: str = "float32", 65 | quantize: Optional[str] = None, 66 | ) -> None: 67 | """Generates text samples based on a pre-trained LLaMA model and tokenizer 68 | finetuned with LoRA. 69 | 70 | Args: 71 | datasets: The datasets to use as a comma separated string 72 | # compile: Whether to compile the model. 73 | accelerator: The hardware to run on. Possible choices are: 74 | ``"cpu"``, ``"cuda"``, ``"mps"``, ``"gpu"``, ``"tpu"``, ``"auto"``. 75 | lora_path: Path to the checkpoint with trained LoRA weights, which are the output of 76 | `finetune_lora.py`. 77 | checkpoint_path: The checkpoint path to load. 78 | tokenizer_path: The tokenizer path to load. 79 | dtype: The tensor dtype for choosing the floating-point precision 80 | quantize: Whether to quantize the model and using which method: 81 | ``"llm.int8"``: LLM.int8() mode, 82 | ``"gptq.int4"``: GPTQ 4-bit mode. 
83 | """ 84 | assert lora_path.is_file() 85 | assert checkpoint_path.is_file() 86 | assert tokenizer_path.is_file() 87 | 88 | if quantize is not None: 89 | raise NotImplementedError("Quantization in LoRA is not supported yet") 90 | 91 | fabric = L.Fabric(accelerator=accelerator, devices=1) 92 | 93 | dt = getattr(torch, dtype, None) 94 | if not isinstance(dt, torch.dtype): 95 | raise ValueError(f"{dtype} is not a valid dtype.") 96 | dtype = dt 97 | 98 | print("Loading model ...", file=sys.stderr) 99 | t0 = time.time() 100 | 101 | with lazy_load(checkpoint_path) as pretrained_checkpoint, lazy_load(lora_path) as lora_checkpoint: 102 | name = llama_model_lookup(pretrained_checkpoint) 103 | 104 | with EmptyInitOnDevice( 105 | device=fabric.device, dtype=dtype, quantization_mode=quantize 106 | ), lora(r=lora_r, alpha=lora_alpha, dropout=lora_dropout, enabled=True): 107 | model = LLaMA.from_name(name) 108 | 109 | # 1. Load the pretrained weights 110 | model.load_state_dict(pretrained_checkpoint, strict=False) 111 | # 2. Load the fine-tuned lora weights 112 | model.load_state_dict(lora_checkpoint, strict=False) 113 | 114 | print(f"Time to load model: {time.time() - t0:.02f} seconds.", file=sys.stderr) 115 | 116 | model.eval() 117 | 118 | # if compile: 119 | # model = torch.compile(model) 120 | 121 | total_toks = 0 122 | model = fabric.setup_module(model) 123 | 124 | tokenizer = Tokenizer(tokenizer_path) 125 | 126 | for dsname in datasets.split(","): 127 | test_string = load_eval_data(dsname) 128 | 129 | if instruction_tuning: 130 | sample = {"instruction": test_string, "input": input} 131 | test_string = generate_prompt(sample) 132 | 133 | encoded_text = tokenizer.encode( 134 | test_string, bos=True, eos=False, device=fabric.device 135 | ) 136 | encoded_text = encoded_text[ 137 | None, : 256 * model.config.block_size 138 | ] # add batch dimension, trim like gptq implementation 139 | t0 = time.perf_counter() 140 | 141 | nlls = 0 142 | toks = 0 143 | with torch.inference_mode(): 144 | block_size = 2048 # this is for compat with gptq, and indeed we get much worse beyond this (https://github.com/facebookresearch/llama/blob/57b0eb62de0636e75af471e49e2f1862d908d9d8/llama/model.py#L30) 145 | for i in tqdm.tqdm(range(0, encoded_text.shape[1], block_size)): 146 | inp = encoded_text[:, i : i + block_size] 147 | logits = model(inp)[0] 148 | nll = torch.nn.functional.cross_entropy( 149 | logits[:-1], inp[0, 1:].to(dtype=torch.long), reduction="sum" 150 | ) 151 | toks += inp.size(1) - 1 152 | nlls += nll.item() 153 | 154 | print(encoded_text.shape, logits.shape) 155 | ppl = math.exp(nlls / toks) 156 | print(f"Perplexity on {dsname}: {ppl:.2f}") 157 | total_toks += toks 158 | 159 | t = time.perf_counter() - t0 160 | print( 161 | f"\n\nTime for inference: {t:.02f} sec total, {total_toks / t:.02f} tokens/sec", 162 | file=sys.stderr, 163 | ) 164 | print( 165 | f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB", 166 | file=sys.stderr, 167 | ) 168 | 169 | 170 | if __name__ == "__main__": 171 | from jsonargparse import CLI 172 | 173 | torch.set_float32_matmul_precision("high") 174 | CLI(main) 175 | -------------------------------------------------------------------------------- /finetune/full.py: -------------------------------------------------------------------------------- 1 | # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. 2 | 3 | """ 4 | Instruction-tuning on the Alpaca dataset using a regular finetuning procedure (updating all layers). 
5 | 6 | Note: If you run into a CUDA error "Expected is_sm80 to be true, but got false", uncomment the line 7 | `torch.backends.cuda.enable_flash_sdp(False)` in the script below (see https://github.com/Lightning-AI/lit-llama/issues/101). 8 | """ 9 | import sys 10 | from pathlib import Path 11 | import os 12 | import time 13 | from functools import partial 14 | 15 | import lightning as L 16 | from lightning.fabric.strategies import FSDPStrategy 17 | import numpy as np 18 | import torch 19 | from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy 20 | 21 | # support running without installing as a package 22 | wd = Path(__file__).parent.parent.resolve() 23 | sys.path.append(str(wd)) 24 | 25 | from generate import generate 26 | from lit_llama.model import Block, LLaMA, LLaMAConfig 27 | from lit_llama.tokenizer import Tokenizer 28 | from lit_llama.utils import save_model_checkpoint 29 | from scripts.prepare_alpaca import generate_prompt 30 | 31 | 32 | instruction_tuning = True 33 | eval_interval = 1000 34 | save_interval = 1000 35 | eval_iters = 100 36 | log_interval = 100 37 | devices = 4 38 | 39 | # Hyperparameters 40 | learning_rate = 3e-5 41 | batch_size = 128 / devices 42 | micro_batch_size = 4 43 | gradient_accumulation_iters = batch_size // micro_batch_size 44 | assert gradient_accumulation_iters > 0 45 | epoch_size = 50000 # train dataset size 46 | num_epochs = 5 47 | max_iters = num_epochs * (epoch_size // micro_batch_size) // devices 48 | weight_decay = 0.0 49 | block_size = 512 50 | warmup_iters = 100 51 | 52 | 53 | def main( 54 | data_dir: str = "data/alpaca", 55 | pretrained_path: str = "checkpoints/lit-llama/7B/lit-llama.pth", 56 | out_dir: str = "out/full/alpaca", 57 | ): 58 | 59 | auto_wrap_policy = partial(transformer_auto_wrap_policy, transformer_layer_cls={Block}) 60 | strategy = FSDPStrategy(auto_wrap_policy=auto_wrap_policy, activation_checkpointing=Block, limit_all_gathers=True) 61 | 62 | fabric = L.Fabric(accelerator="cuda", devices=devices, precision="bf16-mixed", strategy=strategy) 63 | fabric.launch() 64 | fabric.seed_everything(1337 + fabric.global_rank) 65 | 66 | if fabric.global_rank == 0: 67 | os.makedirs(out_dir, exist_ok=True) 68 | 69 | train_data, val_data = load_datasets(data_dir=data_dir) 70 | 71 | config = LLaMAConfig.from_name("7B") 72 | config.block_size = block_size 73 | 74 | checkpoint = torch.load(pretrained_path) 75 | 76 | with fabric.device: 77 | torch.set_default_tensor_type(torch.HalfTensor) 78 | model = LLaMA(config).bfloat16() 79 | torch.set_default_tensor_type(torch.FloatTensor) 80 | model.load_state_dict(checkpoint, strict=False) 81 | 82 | model = fabric.setup_module(model) 83 | 84 | optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, foreach=False) 85 | optimizer = fabric.setup_optimizers(optimizer) 86 | 87 | train(fabric, model, optimizer, train_data, val_data, out_dir) 88 | 89 | # Save the final checkpoint at the end of training 90 | save_model_checkpoint(fabric, model, os.path.join(out_dir, "lit-llama-full-finetuned.pth")) 91 | 92 | 93 | def train( 94 | fabric: L.Fabric, 95 | model: torch.nn.Module, 96 | optimizer: torch.optim.Optimizer, 97 | train_data: np.ndarray, 98 | val_data: np.ndarray, 99 | out_dir: str, 100 | ) -> None: 101 | """The training loop. 102 | 103 | Loosely based on the nanoGPT implementation: https://github.com/karpathy/nanoGPT. 
104 | """ 105 | step_count = 0 106 | model.train() 107 | 108 | for iter_num in range(max_iters): 109 | 110 | is_accumulating = (iter_num + 1) % gradient_accumulation_iters != 0 111 | 112 | if step_count <= warmup_iters: 113 | # linear warmup 114 | lr = learning_rate * step_count / warmup_iters 115 | for param_group in optimizer.param_groups: 116 | param_group['lr'] = lr 117 | 118 | t0 = time.time() 119 | 120 | input_ids, targets = get_batch(fabric, train_data) 121 | with fabric.no_backward_sync(model, enabled=is_accumulating): 122 | logits = model(input_ids) 123 | loss = loss_fn(logits, targets) 124 | fabric.backward(loss / gradient_accumulation_iters) 125 | 126 | if not is_accumulating: 127 | optimizer.step() 128 | optimizer.zero_grad() 129 | step_count += 1 130 | 131 | if step_count % eval_interval == 0: 132 | val_loss = validate(fabric, model, val_data) 133 | fabric.print(f"step {iter_num}: val loss {val_loss:.4f}") 134 | fabric.barrier() 135 | 136 | if step_count % save_interval == 0: 137 | print(f"Saving weights to {out_dir}") 138 | save_model_checkpoint(fabric, model, os.path.join(out_dir, f"iter-{iter_num:06d}-ckpt.pth")) 139 | 140 | dt = time.time() - t0 141 | if iter_num % log_interval == 0: 142 | fabric.print(f"iter {iter_num}: loss {loss.item():.4f}, time: {dt*1000:.2f}ms") 143 | 144 | 145 | def generate_response(model, instruction): 146 | tokenizer = Tokenizer("checkpoints/lit-llama/tokenizer.model") 147 | sample = {"instruction": instruction, "input": ""} 148 | prompt = instruction 149 | if instruction_tuning: 150 | prompt = generate_prompt(sample) 151 | encoded = tokenizer.encode(prompt, bos=True, eos=False, device=model.device) 152 | 153 | output = generate( 154 | model, 155 | idx=encoded, 156 | max_seq_length=block_size, 157 | max_new_tokens=100, 158 | ) 159 | output = tokenizer.decode(output) 160 | return output # output.split("### Response:")[1].strip() 161 | 162 | 163 | @torch.no_grad() 164 | def validate(fabric: L.Fabric, model: torch.nn.Module, val_data: np.ndarray) -> torch.Tensor: 165 | fabric.print("Validating ...") 166 | model.eval() 167 | losses = torch.zeros(eval_iters) 168 | for k in range(eval_iters): 169 | input_ids, targets = get_batch(fabric, val_data) 170 | logits = model(input_ids) 171 | loss = loss_fn(logits, targets) 172 | losses[k] = loss.item() 173 | out = losses.mean() 174 | 175 | # produce an example: 176 | instruction = "Recommend a movie for me to watch during the weekend and explain the reason." 
177 | 178 | output = generate_response(model, instruction) 179 | fabric.print(instruction) 180 | fabric.print(output) 181 | 182 | model.train() 183 | return out.item() 184 | 185 | 186 | def loss_fn(logits, targets): 187 | # shift the targets such that output n predicts token n+1 188 | logits = logits[..., :-1, :].contiguous() 189 | targets = targets[..., 1:].contiguous() 190 | loss = torch.nn.functional.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1) 191 | return loss 192 | 193 | 194 | def get_batch(fabric: L.Fabric, data: list): 195 | ix = torch.randint(len(data), (micro_batch_size,)) 196 | 197 | input_ids = [data[i]["input_ids"].type(torch.int64) for i in ix] 198 | labels = [data[i]["labels"].type(torch.int64) for i in ix] 199 | 200 | max_len = max(len(s) for s in input_ids) 201 | 202 | def pad_right(x, pad_id): 203 | # pad right based on the longest sequence 204 | n = max_len - len(x) 205 | return torch.cat((x, torch.full((n,), pad_id, dtype=x.dtype))) 206 | 207 | x = torch.stack([pad_right(x, pad_id=0) for x in input_ids]) 208 | y = torch.stack([pad_right(x, pad_id=-1) for x in labels]) 209 | x, y = fabric.to_device((x.pin_memory(), y.pin_memory())) 210 | return x, y 211 | 212 | 213 | def load_datasets(data_dir): 214 | train_data = torch.load(os.path.join(data_dir, "train.pt")) 215 | val_data = torch.load(os.path.join(data_dir, "test.pt")) 216 | return train_data, val_data 217 | 218 | 219 | if __name__ == "__main__": 220 | # Uncomment this line if you see an error: "Expected is_sm80 to be true, but got false" 221 | # torch.backends.cuda.enable_flash_sdp(False) 222 | torch.set_float32_matmul_precision("high") 223 | 224 | from jsonargparse.cli import CLI 225 | 226 | CLI(main) 227 | -------------------------------------------------------------------------------- /finetune/lora.py: -------------------------------------------------------------------------------- 1 | # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. 2 | 3 | """ 4 | Instruction-tuning with LoRA on the Alpaca dataset. 5 | 6 | Note: If you run into a CUDA error "Expected is_sm80 to be true, but got false", uncomment the line 7 | `torch.backends.cuda.enable_flash_sdp(False)` in the script below (see https://github.com/Lightning-AI/lit-llama/issues/101). 
8 | """ 9 | import sys 10 | from pathlib import Path 11 | import os 12 | import time 13 | 14 | import lightning as L 15 | import numpy as np 16 | import torch 17 | 18 | # support running without installing as a package 19 | wd = Path(__file__).parent.parent.resolve() 20 | sys.path.append(str(wd)) 21 | 22 | from generate import generate 23 | from lit_llama.lora import mark_only_lora_as_trainable, lora, lora_state_dict 24 | from lit_llama.model import LLaMA, LLaMAConfig 25 | from lit_llama.tokenizer import Tokenizer 26 | from scripts.prepare_alpaca import generate_prompt 27 | 28 | 29 | instruction_tuning = True 30 | eval_interval = 100 31 | save_interval = 100 32 | eval_iters = 100 33 | log_interval = 1 34 | 35 | # Hyperparameters 36 | learning_rate = 3e-4 37 | batch_size = 128 38 | micro_batch_size = 4 39 | gradient_accumulation_iters = batch_size // micro_batch_size 40 | assert gradient_accumulation_iters > 0 41 | max_iters = 50000 * 3 // micro_batch_size 42 | weight_decay = 0.0 43 | max_seq_length = 256 # see scripts/prepare_alpaca.py 44 | lora_r = 8 45 | lora_alpha = 16 46 | lora_dropout = 0.05 47 | warmup_iters = 100 48 | 49 | 50 | def main( 51 | data_dir: str = "data/alpaca", 52 | pretrained_path: str = "checkpoints/lit-llama/7B/lit-llama.pth", 53 | tokenizer_path: str = "checkpoints/lit-llama/tokenizer.model", 54 | out_dir: str = "out/lora/alpaca", 55 | ): 56 | 57 | fabric = L.Fabric(accelerator="cuda", devices=1, precision="bf16-true") 58 | fabric.launch() 59 | fabric.seed_everything(1337 + fabric.global_rank) 60 | 61 | if fabric.global_rank == 0: 62 | os.makedirs(out_dir, exist_ok=True) 63 | 64 | train_data, val_data = load_datasets(data_dir=data_dir) 65 | 66 | config = LLaMAConfig.from_name("7B") 67 | config.block_size = max_seq_length 68 | 69 | checkpoint = torch.load(pretrained_path) 70 | 71 | with fabric.init_module(), lora(r=lora_r, alpha=lora_alpha, dropout=lora_dropout, enabled=True): 72 | model = LLaMA(config) 73 | # strict=False because missing keys due to LoRA weights not contained in checkpoint state 74 | model.load_state_dict(checkpoint, strict=False) 75 | 76 | mark_only_lora_as_trainable(model) 77 | 78 | optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate) 79 | model, optimizer = fabric.setup(model, optimizer) 80 | train(fabric, model, optimizer, train_data, val_data, tokenizer_path, out_dir) 81 | 82 | # Save the final LoRA checkpoint at the end of training 83 | checkpoint = lora_state_dict(model) 84 | fabric.save(os.path.join(out_dir, "lit-llama-lora-finetuned.pth"), checkpoint) 85 | 86 | 87 | def train( 88 | fabric: L.Fabric, 89 | model: torch.nn.Module, 90 | optimizer: torch.optim.Optimizer, 91 | train_data: np.ndarray, 92 | val_data: np.ndarray, 93 | tokenizer_path: str, 94 | out_dir: str, 95 | ) -> None: 96 | """The training loop. 97 | 98 | Loosely based on the nanoGPT implementation: https://github.com/karpathy/nanoGPT. 
99 | """ 100 | step_count = 0 101 | 102 | for iter_num in range(max_iters): 103 | 104 | if step_count <= warmup_iters: 105 | # linear warmup 106 | lr = learning_rate * step_count / warmup_iters 107 | for param_group in optimizer.param_groups: 108 | param_group['lr'] = lr 109 | 110 | t0 = time.time() 111 | 112 | input_ids, targets = get_batch(fabric, train_data) 113 | with fabric.no_backward_sync(model, enabled=((iter_num + 1) % gradient_accumulation_iters != 0)): 114 | logits = model(input_ids) 115 | loss = loss_fn(logits, targets) 116 | fabric.backward(loss / gradient_accumulation_iters) 117 | 118 | if (iter_num + 1) % gradient_accumulation_iters == 0: 119 | optimizer.step() 120 | optimizer.zero_grad() 121 | step_count += 1 122 | 123 | if step_count % eval_interval == 0: 124 | val_loss = validate(fabric, model, val_data, tokenizer_path) 125 | fabric.print(f"step {iter_num}: val loss {val_loss:.4f}") 126 | fabric.barrier() 127 | 128 | if step_count % save_interval == 0: 129 | print(f"Saving LoRA weights to {out_dir}") 130 | # We are only saving the LoRA weights 131 | # TODO: Provide a function/script to merge the LoRA weights with pretrained weights 132 | checkpoint = lora_state_dict(model) 133 | fabric.save(os.path.join(out_dir, f"iter-{iter_num:06d}-ckpt.pth"), checkpoint) 134 | 135 | dt = time.time() - t0 136 | if iter_num % log_interval == 0: 137 | fabric.print(f"iter {iter_num}: loss {loss.item():.4f}, time: {dt*1000:.2f}ms") 138 | 139 | 140 | def generate_response(model, instruction, tokenizer_path): 141 | tokenizer = Tokenizer(tokenizer_path) 142 | sample = {"instruction": instruction, "input": ""} 143 | prompt = instruction 144 | if instruction_tuning: 145 | prompt = generate_prompt(sample) 146 | encoded = tokenizer.encode(prompt, bos=True, eos=False, device=model.device) 147 | 148 | output = generate( 149 | model, 150 | idx=encoded, 151 | max_seq_length=max_seq_length, 152 | max_new_tokens=100, 153 | ) 154 | output = tokenizer.decode(output) 155 | return output # output.split("### Response:")[1].strip() 156 | 157 | 158 | @torch.no_grad() 159 | def validate(fabric: L.Fabric, model: torch.nn.Module, val_data: np.ndarray, tokenizer_path: str) -> torch.Tensor: 160 | fabric.print("Validating ...") 161 | model.eval() 162 | losses = torch.zeros(eval_iters) 163 | for k in range(eval_iters): 164 | input_ids, targets = get_batch(fabric, val_data) 165 | logits = model(input_ids) 166 | loss = loss_fn(logits, targets) 167 | losses[k] = loss.item() 168 | out = losses.mean() 169 | 170 | # produce an example: 171 | instruction = "Recommend a movie for me to watch during the weekend and explain the reason." 
172 | 173 | output = generate_response(model, instruction, tokenizer_path) 174 | fabric.print(instruction) 175 | fabric.print(output) 176 | 177 | model.train() 178 | return out.item() 179 | 180 | def loss_fn(logits, targets): 181 | # shift the targets such that output n predicts token n+1 182 | logits = logits[..., :-1, :].contiguous() 183 | targets = targets[..., 1:].contiguous() 184 | loss = torch.nn.functional.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1) 185 | return loss 186 | 187 | 188 | def get_batch(fabric: L.Fabric, data: list): 189 | ix = torch.randint(len(data), (micro_batch_size,)) 190 | 191 | input_ids = [data[i]["input_ids"].type(torch.int64) for i in ix] 192 | labels = [data[i]["labels"].type(torch.int64) for i in ix] 193 | 194 | max_len = max(len(s) for s in input_ids) 195 | 196 | def pad_right(x, pad_id): 197 | # pad right based on the longest sequence 198 | n = max_len - len(x) 199 | return torch.cat((x, torch.full((n,), pad_id, dtype=x.dtype))) 200 | 201 | x = torch.stack([pad_right(x, pad_id=0) for x in input_ids]) 202 | y = torch.stack([pad_right(x, pad_id=-1) for x in labels]) 203 | x, y = fabric.to_device((x.pin_memory(), y.pin_memory())) 204 | return x, y 205 | 206 | 207 | def load_datasets(data_dir): 208 | train_data = torch.load(os.path.join(data_dir, "train.pt")) 209 | val_data = torch.load(os.path.join(data_dir, "test.pt")) 210 | return train_data, val_data 211 | 212 | 213 | if __name__ == "__main__": 214 | # Uncomment this line if you see an error: "Expected is_sm80 to be true, but got false" 215 | # torch.backends.cuda.enable_flash_sdp(False) 216 | torch.set_float32_matmul_precision("high") 217 | 218 | from jsonargparse.cli import CLI 219 | 220 | CLI(main) 221 | -------------------------------------------------------------------------------- /generate.py: -------------------------------------------------------------------------------- 1 | # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. 2 | 3 | import sys 4 | import time 5 | import warnings 6 | from pathlib import Path 7 | from typing import Optional 8 | 9 | import lightning as L 10 | import torch 11 | 12 | # support running without installing as a package 13 | wd = Path(__file__).parent.parent.resolve() 14 | sys.path.append(str(wd)) 15 | 16 | from lit_llama import LLaMA, Tokenizer 17 | from lit_llama.utils import lazy_load, llama_model_lookup, quantization 18 | 19 | 20 | @torch.no_grad() 21 | def generate( 22 | model: LLaMA, 23 | idx: torch.Tensor, 24 | max_new_tokens: int, 25 | *, 26 | max_seq_length: Optional[int] = None, 27 | temperature: float = 1.0, 28 | top_k: Optional[int] = None, 29 | eos_id: Optional[int] = None, 30 | ) -> torch.Tensor: 31 | """Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested. 32 | 33 | The implementation of this function is modified from A. Karpathy's nanoGPT. 34 | 35 | Args: 36 | model: The model to use. 37 | idx: Tensor of shape (T) with indices of the prompt sequence. 38 | max_new_tokens: The number of new tokens to generate. 39 | max_seq_length: The maximum sequence length allowed. 
40 | temperature: Scales the predicted logits by 1 / temperature 41 | top_k: If specified, only sample among the tokens with the k highest probabilities 42 | eos_id: If specified, stop generating any more token once the token is triggered 43 | """ 44 | # create an empty tensor of the expected final shape and fill in the current tokens 45 | T = idx.size(0) 46 | T_new = T + max_new_tokens 47 | if max_seq_length is None: 48 | max_seq_length = min(T_new, model.config.block_size) 49 | 50 | device, dtype = idx.device, idx.dtype 51 | # create an empty tensor of the expected final shape and fill in the current tokens 52 | empty = torch.empty(T_new, dtype=dtype, device=device) 53 | empty[:T] = idx 54 | idx = empty 55 | input_pos = torch.arange(0, T, device=device) 56 | 57 | if idx.device.type == "xla": 58 | import torch_xla.core.xla_model as xm 59 | 60 | xm.mark_step() 61 | 62 | # generate max_new_tokens tokens 63 | for _ in range(max_new_tokens): 64 | x = idx.index_select(0, input_pos).view(1, -1) 65 | 66 | # forward 67 | logits = model(x, max_seq_length, input_pos) 68 | logits = logits[0, -1] / temperature 69 | 70 | # optionally crop the logits to only the top k options 71 | if top_k is not None: 72 | v, _ = torch.topk(logits, min(top_k, logits.size(-1))) 73 | logits = torch.where(logits < v[[-1]], -float("Inf"), logits) 74 | 75 | probs = torch.nn.functional.softmax(logits, dim=-1) 76 | idx_next = torch.multinomial(probs, num_samples=1).to(dtype=dtype) 77 | 78 | # advance 79 | input_pos = input_pos[-1:] + 1 80 | 81 | if idx.device.type == "xla": 82 | xm.mark_step() 83 | 84 | # concatenate the new generation 85 | idx = idx.index_copy(0, input_pos, idx_next) 86 | 87 | # if token is triggered, return the output (stop generation) 88 | if idx_next == eos_id: 89 | return idx[:input_pos] # include the EOS token 90 | 91 | return idx 92 | 93 | 94 | def main( 95 | prompt: str = "Hello, my name is", 96 | *, 97 | num_samples: int = 1, 98 | max_new_tokens: int = 50, 99 | top_k: int = 200, 100 | temperature: float = 0.8, 101 | checkpoint_path: Path = Path("checkpoints/lit-llama/7B/lit-llama.pth"), 102 | tokenizer_path: Path = Path("checkpoints/lit-llama/tokenizer.model"), 103 | quantize: Optional[str] = None, 104 | ) -> None: 105 | """Generates text samples based on a pre-trained LLaMA model and tokenizer. 106 | 107 | Args: 108 | prompt: The prompt string to use for generating the samples. 109 | num_samples: The number of text samples to generate. 110 | max_new_tokens: The number of generation steps to take. 111 | top_k: The number of top most probable tokens to consider in the sampling process. 112 | temperature: A value controlling the randomness of the sampling process. Higher values result in more random 113 | samples. 114 | checkpoint_path: The checkpoint path to load. 115 | tokenizer_path: The tokenizer path to load. 116 | quantize: Whether to quantize the model and using which method: 117 | ``"llm.int8"``: LLM.int8() mode, 118 | ``"gptq.int4"``: GPTQ 4-bit mode. 
119 | """ 120 | assert checkpoint_path.is_file(), checkpoint_path 121 | assert tokenizer_path.is_file(), tokenizer_path 122 | 123 | precision = "bf16-true" if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else "32-true" 124 | fabric = L.Fabric(devices=1, precision=precision) 125 | 126 | print("Loading model ...", file=sys.stderr) 127 | t0 = time.time() 128 | with lazy_load(checkpoint_path) as checkpoint: 129 | name = llama_model_lookup(checkpoint) 130 | 131 | with fabric.init_module(empty_init=True), quantization(mode=quantize): 132 | model = LLaMA.from_name(name) 133 | 134 | model.load_state_dict(checkpoint) 135 | print(f"Time to load model: {time.time() - t0:.02f} seconds.", file=sys.stderr) 136 | 137 | model.eval() 138 | model = fabric.setup(model) 139 | 140 | tokenizer = Tokenizer(tokenizer_path) 141 | encoded = tokenizer.encode(prompt, bos=True, eos=False, device=fabric.device) 142 | prompt_length = encoded.size(0) 143 | 144 | L.seed_everything(1234) 145 | for i in range(num_samples): 146 | t0 = time.perf_counter() 147 | y = generate(model, encoded, max_new_tokens, temperature=temperature, top_k=top_k) 148 | t = time.perf_counter() - t0 149 | 150 | model.reset_cache() 151 | print(tokenizer.decode(y)) 152 | tokens_generated = y.size(0) - prompt_length 153 | print(f"Time for inference {i + 1}: {t:.02f} sec total, {tokens_generated / t:.02f} tokens/sec", file=sys.stderr) 154 | if fabric.device.type == "cuda": 155 | print(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB", file=sys.stderr) 156 | 157 | 158 | if __name__ == "__main__": 159 | from jsonargparse import CLI 160 | 161 | torch.set_float32_matmul_precision("high") 162 | warnings.filterwarnings( 163 | # Triggered internally at ../aten/src/ATen/EmptyTensor.cpp:31 164 | "ignore", 165 | message="ComplexHalf support is experimental and many operators don't support it yet" 166 | ) 167 | warnings.filterwarnings( 168 | # Triggered in bitsandbytes/autograd/_functions.py:298 169 | "ignore", 170 | message="MatMul8bitLt: inputs will be cast from torch.bfloat16 to float16 during quantization", 171 | ) 172 | CLI(main) 173 | -------------------------------------------------------------------------------- /generate/adapter.py: -------------------------------------------------------------------------------- 1 | # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. 2 | 3 | import sys 4 | import time 5 | import warnings 6 | from pathlib import Path 7 | from typing import Optional 8 | 9 | import lightning as L 10 | import torch 11 | 12 | # support running without installing as a package 13 | wd = Path(__file__).parent.parent.resolve() 14 | sys.path.append(str(wd)) 15 | 16 | from generate import generate 17 | from lit_llama import Tokenizer 18 | from lit_llama.adapter import LLaMA 19 | from lit_llama.utils import lazy_load, llama_model_lookup, quantization 20 | from scripts.prepare_alpaca import generate_prompt 21 | 22 | 23 | def main( 24 | prompt: str = "What food do lamas eat?", 25 | input: str = "", 26 | adapter_path: Path = Path("out/adapter/alpaca/lit-llama-adapter-finetuned.pth"), 27 | pretrained_path: Path = Path("checkpoints/lit-llama/7B/lit-llama.pth"), 28 | tokenizer_path: Path = Path("checkpoints/lit-llama/tokenizer.model"), 29 | quantize: Optional[str] = None, 30 | max_new_tokens: int = 100, 31 | top_k: int = 200, 32 | temperature: float = 0.8, 33 | ) -> None: 34 | """Generates a response based on a given instruction and an optional input. 
35 | This script will only work with checkpoints from the instruction-tuned LLaMA-Adapter model. 36 | See `finetune_adapter.py`. 37 | 38 | Args: 39 | prompt: The prompt/instruction (Alpaca style). 40 | adapter_path: Path to the checkpoint with trained adapter weights, which are the output of 41 | `finetune_adapter.py`. 42 | input: Optional input (Alpaca style). 43 | pretrained_path: The path to the checkpoint with pretrained LLaMA weights. 44 | tokenizer_path: The tokenizer path to load. 45 | quantize: Whether to quantize the model and using which method: 46 | ``"llm.int8"``: LLM.int8() mode, 47 | ``"gptq.int4"``: GPTQ 4-bit mode. 48 | max_new_tokens: The number of generation steps to take. 49 | top_k: The number of top most probable tokens to consider in the sampling process. 50 | temperature: A value controlling the randomness of the sampling process. Higher values result in more random 51 | samples. 52 | """ 53 | assert adapter_path.is_file() 54 | assert pretrained_path.is_file() 55 | assert tokenizer_path.is_file() 56 | 57 | precision = "bf16-true" if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else "32-true" 58 | fabric = L.Fabric(devices=1, precision=precision) 59 | 60 | print("Loading model ...", file=sys.stderr) 61 | t0 = time.time() 62 | with lazy_load(pretrained_path) as pretrained_checkpoint, lazy_load(adapter_path) as adapter_checkpoint: 63 | name = llama_model_lookup(pretrained_checkpoint) 64 | 65 | with fabric.init_module(empty_init=True), quantization(mode=quantize): 66 | model = LLaMA.from_name(name) 67 | 68 | # 1. Load the pretrained weights 69 | model.load_state_dict(pretrained_checkpoint, strict=False) 70 | # 2. Load the fine-tuned adapter weights 71 | model.load_state_dict(adapter_checkpoint, strict=False) 72 | 73 | print(f"Time to load model: {time.time() - t0:.02f} seconds.", file=sys.stderr) 74 | 75 | model.eval() 76 | model = fabric.setup(model) 77 | 78 | tokenizer = Tokenizer(tokenizer_path) 79 | sample = {"instruction": prompt, "input": input} 80 | prompt = generate_prompt(sample) 81 | encoded = tokenizer.encode(prompt, bos=True, eos=False, device=model.device) 82 | prompt_length = encoded.size(0) 83 | 84 | t0 = time.perf_counter() 85 | y = generate(model, encoded, max_new_tokens, temperature=temperature, top_k=top_k, eos_id=tokenizer.eos_id) 86 | t = time.perf_counter() - t0 87 | 88 | model.reset_cache() 89 | output = tokenizer.decode(y) 90 | output = output.split("### Response:")[1].strip() 91 | print(output) 92 | 93 | tokens_generated = y.size(0) - prompt_length 94 | print(f"\n\nTime for inference: {t:.02f} sec total, {tokens_generated / t:.02f} tokens/sec", file=sys.stderr) 95 | if fabric.device.type == "cuda": 96 | print(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB", file=sys.stderr) 97 | 98 | 99 | if __name__ == "__main__": 100 | from jsonargparse import CLI 101 | 102 | torch.set_float32_matmul_precision("high") 103 | warnings.filterwarnings( 104 | # Triggered internally at ../aten/src/ATen/EmptyTensor.cpp:31 105 | "ignore", 106 | message="ComplexHalf support is experimental and many operators don't support it yet" 107 | ) 108 | CLI(main) 109 | -------------------------------------------------------------------------------- /generate/adapter_v2.py: -------------------------------------------------------------------------------- 1 | # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. 
2 | 3 | import sys 4 | import time 5 | import warnings 6 | from pathlib import Path 7 | from typing import Optional 8 | 9 | import lightning as L 10 | import torch 11 | 12 | # support running without installing as a package 13 | wd = Path(__file__).parent.parent.resolve() 14 | sys.path.append(str(wd)) 15 | 16 | from generate import generate 17 | from lit_llama import Tokenizer 18 | from lit_llama.adapter import LLaMA 19 | from lit_llama.utils import lazy_load, llama_model_lookup, quantization 20 | from lit_llama.adapter_v2 import add_adapter_v2_parameters_to_linear_layers 21 | from scripts.prepare_alpaca import generate_prompt 22 | 23 | 24 | def main( 25 | prompt: str = "What food do lamas eat?", 26 | input: str = "", 27 | adapter_path: Path = Path("out/adapter_v2/alpaca/lit-llama-adapter-finetuned.pth"), 28 | pretrained_path: Path = Path("checkpoints/lit-llama/7B/lit-llama.pth"), 29 | tokenizer_path: Path = Path("checkpoints/lit-llama/tokenizer.model"), 30 | quantize: Optional[str] = None, 31 | max_new_tokens: int = 100, 32 | top_k: int = 200, 33 | temperature: float = 0.8, 34 | ) -> None: 35 | """Generates a response based on a given instruction and an optional input. 36 | This script will only work with checkpoints from the instruction-tuned LLaMA-Adapter model. 37 | See `finetune_adapter_v2.py`. 38 | 39 | Args: 40 | prompt: The prompt/instruction (Alpaca style). 41 | adapter_path: Path to the checkpoint with trained adapter weights, which are the output of 42 | `finetune_adapter_v2.py`. 43 | input: Optional input (Alpaca style). 44 | pretrained_path: The path to the checkpoint with pretrained LLaMA weights. 45 | tokenizer_path: The tokenizer path to load. 46 | quantize: Whether to quantize the model and using which method: 47 | ``"llm.int8"``: LLM.int8() mode, 48 | ``"gptq.int4"``: GPTQ 4-bit mode. 49 | max_new_tokens: The number of generation steps to take. 50 | top_k: The number of top most probable tokens to consider in the sampling process. 51 | temperature: A value controlling the randomness of the sampling process. Higher values result in more random 52 | samples. 53 | """ 54 | assert adapter_path.is_file() 55 | assert pretrained_path.is_file() 56 | assert tokenizer_path.is_file() 57 | 58 | precision = "bf16-true" if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else "32-true" 59 | fabric = L.Fabric(devices=1, precision=precision) 60 | 61 | print("Loading model ...", file=sys.stderr) 62 | t0 = time.time() 63 | with lazy_load(pretrained_path) as pretrained_checkpoint, lazy_load(adapter_path) as adapter_checkpoint: 64 | name = llama_model_lookup(pretrained_checkpoint) 65 | 66 | with fabric.init_module(empty_init=True), quantization(mode=quantize): 67 | model = LLaMA.from_name(name) 68 | add_adapter_v2_parameters_to_linear_layers(model) 69 | 70 | # 1. Load the pretrained weights 71 | model.load_state_dict(pretrained_checkpoint, strict=False) 72 | # 2. 
Load the fine-tuned adapter weights 73 | model.load_state_dict(adapter_checkpoint, strict=False) 74 | 75 | print(f"Time to load model: {time.time() - t0:.02f} seconds.", file=sys.stderr) 76 | 77 | model.eval() 78 | model = fabric.setup(model) 79 | 80 | tokenizer = Tokenizer(tokenizer_path) 81 | sample = {"instruction": prompt, "input": input} 82 | prompt = generate_prompt(sample) 83 | encoded = tokenizer.encode(prompt, bos=True, eos=False, device=model.device) 84 | prompt_length = encoded.size(0) 85 | 86 | t0 = time.perf_counter() 87 | y = generate(model, encoded, max_new_tokens, temperature=temperature, top_k=top_k, eos_id=tokenizer.eos_id) 88 | t = time.perf_counter() - t0 89 | 90 | model.reset_cache() 91 | output = tokenizer.decode(y) 92 | output = output.split("### Response:")[1].strip() 93 | print(output) 94 | 95 | tokens_generated = y.size(0) - prompt_length 96 | print(f"\n\nTime for inference: {t:.02f} sec total, {tokens_generated / t:.02f} tokens/sec", file=sys.stderr) 97 | if fabric.device.type == "cuda": 98 | print(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB", file=sys.stderr) 99 | 100 | 101 | if __name__ == "__main__": 102 | from jsonargparse import CLI 103 | 104 | torch.set_float32_matmul_precision("high") 105 | warnings.filterwarnings( 106 | # Triggered internally at ../aten/src/ATen/EmptyTensor.cpp:31 107 | "ignore", 108 | message="ComplexHalf support is experimental and many operators don't support it yet" 109 | ) 110 | CLI(main) 111 | -------------------------------------------------------------------------------- /generate/full.py: -------------------------------------------------------------------------------- 1 | # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. 2 | 3 | import sys 4 | import time 5 | import warnings 6 | from pathlib import Path 7 | from typing import Optional 8 | 9 | import lightning as L 10 | import torch 11 | 12 | # support running without installing as a package 13 | wd = Path(__file__).absolute().parent.parent 14 | sys.path.append(str(wd)) 15 | 16 | from lit_llama import LLaMA, Tokenizer 17 | from lit_llama.utils import quantization 18 | from scripts.prepare_alpaca import generate_prompt 19 | from generate import generate 20 | 21 | 22 | def main( 23 | prompt: str = "Hello, my name is", 24 | *, 25 | num_samples: int = 1, 26 | max_new_tokens: int = 50, 27 | top_k: int = 200, 28 | temperature: float = 0.8, 29 | checkpoint_path: Optional[Path] = None, 30 | tokenizer_path: Path = Path("checkpoints/lit-llama/tokenizer.model"), 31 | model_size: str = "7B", 32 | quantize: Optional[str] = None, 33 | ) -> None: 34 | """Generates text samples based on a pre-trained LLaMA model and tokenizer. 35 | 36 | Args: 37 | prompt: The prompt string to use for generating the samples. 38 | num_samples: The number of text samples to generate. 39 | max_new_tokens: The number of generation steps to take. 40 | top_k: The number of top most probable tokens to consider in the sampling process. 41 | temperature: A value controlling the randomness of the sampling process. Higher values result in more random 42 | samples. 43 | checkpoint_path: The checkpoint path to load. 44 | tokenizer_path: The tokenizer path to load. 45 | model_size: The model size to load. 46 | quantize: Whether to quantize the model and using which method: 47 | ``"llm.int8"``: LLM.int8() mode, 48 | ``"gptq.int4"``: GPTQ 4-bit mode. 
49 | """ 50 | if not checkpoint_path: 51 | checkpoint_path = Path(f"checkpoints/lit-llama/{model_size}/lit-llama.pth") 52 | assert checkpoint_path.is_file(), checkpoint_path 53 | assert tokenizer_path.is_file(), tokenizer_path 54 | 55 | precision = "bf16-true" if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else "32-true" 56 | fabric = L.Fabric(devices=1, precision=precision) 57 | 58 | print("Loading model ...", file=sys.stderr) 59 | t0 = time.time() 60 | 61 | with fabric.init_module(empty_init=True), quantization(mode=quantize): 62 | model = LLaMA.from_name(model_size) 63 | 64 | checkpoint = torch.load(checkpoint_path) 65 | model.load_state_dict(checkpoint) 66 | print(f"Time to load model: {time.time() - t0:.02f} seconds.", file=sys.stderr) 67 | 68 | model.eval() 69 | model = fabric.setup(model) 70 | 71 | tokenizer = Tokenizer(tokenizer_path) 72 | sample = {"instruction": prompt, "input": input} 73 | prompt = generate_prompt(sample) 74 | encoded = tokenizer.encode(prompt, bos=True, eos=False, device=fabric.device) 75 | prompt_length = encoded.size(0) 76 | 77 | L.seed_everything(1234) 78 | for i in range(num_samples): 79 | t0 = time.perf_counter() 80 | y = generate(model, encoded, max_new_tokens, temperature=temperature, top_k=top_k) 81 | t = time.perf_counter() - t0 82 | 83 | model.reset_cache() 84 | print(tokenizer.decode(y)) 85 | tokens_generated = y.size(0) - prompt_length 86 | print(f"Time for inference {i + 1}: {t:.02f} sec total, {tokens_generated / t:.02f} tokens/sec", file=sys.stderr) 87 | if fabric.device.type == "cuda": 88 | print(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB", file=sys.stderr) 89 | 90 | 91 | if __name__ == "__main__": 92 | from jsonargparse import CLI 93 | 94 | torch.set_float32_matmul_precision("high") 95 | warnings.filterwarnings( 96 | # Triggered internally at ../aten/src/ATen/EmptyTensor.cpp:31 97 | "ignore", 98 | message="ComplexHalf support is experimental and many operators don't support it yet" 99 | ) 100 | warnings.filterwarnings( 101 | # Triggered in bitsandbytes/autograd/_functions.py:298 102 | "ignore", 103 | message="MatMul8bitLt: inputs will be cast from torch.bfloat16 to float16 during quantization", 104 | ) 105 | CLI(main) 106 | -------------------------------------------------------------------------------- /generate/lora.py: -------------------------------------------------------------------------------- 1 | # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. 
2 | 3 | import sys 4 | import time 5 | import warnings 6 | from pathlib import Path 7 | from typing import Optional 8 | 9 | import lightning as L 10 | import torch 11 | 12 | # support running without installing as a package 13 | wd = Path(__file__).parent.parent.resolve() 14 | sys.path.append(str(wd)) 15 | 16 | from generate import generate 17 | from lit_llama import Tokenizer, LLaMA 18 | from lit_llama.lora import lora 19 | from lit_llama.utils import lazy_load, llama_model_lookup 20 | from scripts.prepare_alpaca import generate_prompt 21 | 22 | lora_r = 8 23 | lora_alpha = 16 24 | lora_dropout = 0.05 25 | 26 | 27 | def main( 28 | prompt: str = "What food do lamas eat?", 29 | input: str = "", 30 | lora_path: Path = Path("out/lora/alpaca/lit-llama-lora-finetuned.pth"), 31 | pretrained_path: Path = Path("checkpoints/lit-llama/7B/lit-llama.pth"), 32 | tokenizer_path: Path = Path("checkpoints/lit-llama/tokenizer.model"), 33 | quantize: Optional[str] = None, 34 | max_new_tokens: int = 100, 35 | top_k: int = 200, 36 | temperature: float = 0.8, 37 | ) -> None: 38 | """Generates a response based on a given instruction and an optional input. 39 | This script will only work with checkpoints from the instruction-tuned LoRA model. 40 | See `finetune_lora.py`. 41 | 42 | Args: 43 | prompt: The prompt/instruction (Alpaca style). 44 | lora_path: Path to the checkpoint with trained LoRA weights, which are the output of 45 | `finetune_lora.py`. 46 | input: Optional input (Alpaca style). 47 | pretrained_path: The path to the checkpoint with pretrained LLaMA weights. 48 | tokenizer_path: The tokenizer path to load. 49 | quantize: Whether to quantize the model and using which method: 50 | ``"llm.int8"``: LLM.int8() mode, 51 | ``"gptq.int4"``: GPTQ 4-bit mode. 52 | max_new_tokens: The number of generation steps to take. 53 | top_k: The number of top most probable tokens to consider in the sampling process. 54 | temperature: A value controlling the randomness of the sampling process. Higher values result in more random 55 | samples. 56 | """ 57 | assert lora_path.is_file() 58 | assert pretrained_path.is_file() 59 | assert tokenizer_path.is_file() 60 | 61 | if quantize is not None: 62 | raise NotImplementedError("Quantization in LoRA is not supported yet") 63 | 64 | precision = "bf16-true" if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else "32-true" 65 | fabric = L.Fabric(devices=1, precision=precision) 66 | 67 | print("Loading model ...", file=sys.stderr) 68 | t0 = time.time() 69 | 70 | with lazy_load(pretrained_path) as pretrained_checkpoint, lazy_load(lora_path) as lora_checkpoint: 71 | name = llama_model_lookup(pretrained_checkpoint) 72 | 73 | with fabric.init_module(empty_init=True), lora(r=lora_r, alpha=lora_alpha, dropout=lora_dropout, enabled=True): 74 | model = LLaMA.from_name(name) 75 | 76 | # 1. Load the pretrained weights 77 | model.load_state_dict(pretrained_checkpoint, strict=False) 78 | # 2. 
Load the fine-tuned lora weights 79 | model.load_state_dict(lora_checkpoint, strict=False) 80 | 81 | print(f"Time to load model: {time.time() - t0:.02f} seconds.", file=sys.stderr) 82 | 83 | model.eval() 84 | model = fabric.setup(model) 85 | 86 | tokenizer = Tokenizer(tokenizer_path) 87 | sample = {"instruction": prompt, "input": input} 88 | prompt = generate_prompt(sample) 89 | encoded = tokenizer.encode(prompt, bos=True, eos=False, device=model.device) 90 | 91 | t0 = time.perf_counter() 92 | output = generate( 93 | model, 94 | idx=encoded, 95 | max_new_tokens=max_new_tokens, 96 | temperature=temperature, 97 | top_k=top_k, 98 | eos_id=tokenizer.eos_id 99 | ) 100 | t = time.perf_counter() - t0 101 | 102 | output = tokenizer.decode(output) 103 | output = output.split("### Response:")[1].strip() 104 | print(output) 105 | 106 | print(f"\n\nTime for inference: {t:.02f} sec total, {max_new_tokens / t:.02f} tokens/sec", file=sys.stderr) 107 | if fabric.device.type == "cuda": 108 | print(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB", file=sys.stderr) 109 | 110 | 111 | if __name__ == "__main__": 112 | from jsonargparse import CLI 113 | 114 | torch.set_float32_matmul_precision("high") 115 | warnings.filterwarnings( 116 | # Triggered internally at ../aten/src/ATen/EmptyTensor.cpp:31 117 | "ignore", 118 | message="ComplexHalf support is experimental and many operators don't support it yet" 119 | ) 120 | CLI(main) 121 | -------------------------------------------------------------------------------- /howto/convert_lora_weights.md: -------------------------------------------------------------------------------- 1 | # Merging LoRA weights into base model weights 2 | 3 | Purpose: By merging our selected LoRA weights into the base model weights, we can benefit from all base model optimisation such as quantisation (available in this repo), pruning, caching, etc. 4 | 5 | 6 | ## How to run? 7 | 8 | After you have finish finetuning using LoRA, select your weight and run the converter script: 9 | 10 | ```bash 11 | python scripts/convert_lora_weights.py --lora_path out/lora/your-folder/your-weight-name.pth 12 | ``` 13 | 14 | The converted base weight file will be saved into the same folder with the name `{your-weight-name}-lora-merged-weights.pth`. Now you can run `generate.py` with the merged weights and apply quantisation: 15 | 16 | ```bash 17 | python generate.py --checkpoint_path out/lora/your-folder/your-weight-name-lora-merged-weights.pth --quantize llm.int8 18 | ``` 19 | 20 | -------------------------------------------------------------------------------- /howto/customize_paths.md: -------------------------------------------------------------------------------- 1 | ## Customize paths 2 | 3 | The project is setup to use specific paths to read the original weights and save checkpoints etc. 4 | 5 | For all scripts, you can run 6 | 7 | ```shell 8 | python script.py -h 9 | ``` 10 | 11 | to get a list of available options. 
For instance, here's how you would modify the checkpoint dir: 12 | 13 | ```shell 14 | python scripts/convert_checkpoint.py --checkpoint_dir "data/checkpoints/foo" 15 | ``` 16 | 17 | Note that this change will need to be passed along to subsequent steps, for example: 18 | 19 | ```shell 20 | python generate.py \ 21 | --checkpoint_path "data/checkpoints/foo/7B/lit-llama.pth" \ 22 | --tokenizer_path "data/checkpoints/foo/tokenizer.model" 23 | ``` 24 | 25 | and 26 | 27 | ```shell 28 | python quantize/gptq.py \ 29 | --checkpoint_path "data/checkpoints/foo/7B/lit-llama.pth" \ 30 | --tokenizer_path "data/checkpoints/foo/tokenizer.model" 31 | ``` 32 | 33 | To avoid this, you can use symbolic links to create shortcuts and avoid passing different paths. 34 | -------------------------------------------------------------------------------- /howto/download_weights.md: -------------------------------------------------------------------------------- 1 | ## Downloading pretrained weights 2 | 3 | Except for when you are training from scratch, you will need the pretrained weights from Meta. 4 | 5 | ### Original Meta weights 6 | 7 | Download the model weights following the instructions on the official [LLaMA repository](https://github.com/facebookresearch/llama). 8 | 9 | Once downloaded, you should have a folder like this: 10 | 11 | ```text 12 | checkpoints/llama 13 | ├── 7B 14 | │ ├── ... 15 | │ └── consolidated.00.pth 16 | ├── 13B 17 | │ ... 18 | └── tokenizer.model 19 | ``` 20 | 21 | Convert the weights to the Lit-LLaMA format: 22 | 23 | ```bash 24 | python scripts/convert_checkpoint.py --model_size 7B 25 | ``` 26 | 27 | > **Note** 28 | > All scripts support argument [customization](customize_paths.md) 29 | 30 | ### OpenLLaMA 31 | 32 | OpenLM Research has released **Apache 2.0 licensed** weights obtained by training LLaMA on the 1.2 trillion token open-source [RedPajama](https://github.com/togethercomputer/RedPajama-Data) dataset. 33 | 34 | Weights were released in preview on intermediate number of tokens (1T at the time of writing). In order to get them do: 35 | 36 | ```bash 37 | # Make sure you have git-lfs installed (https://git-lfs.com): git lfs install 38 | git clone https://huggingface.co/openlm-research/open_llama_7b checkpoints/open-llama/7B 39 | ``` 40 | 41 | Or if you don't have `git-lfs` installed: 42 | 43 | ```bash 44 | python scripts/download.py --repo_id openlm-research/open_llama_7b --local_dir checkpoints/open-llama/7B 45 | ``` 46 | 47 | Once downloaded, you should have a folder like this: 48 | 49 | ```text 50 | checkpoints/open-llama/ 51 | └── 7B 52 | ├── ... 53 | ├── pytorch_model-00001-of-00002.bin 54 | ├── pytorch_model-00002-of-00002.bin 55 | ├── pytorch_model.bin.index.json 56 | └── tokenizer.model 57 | ``` 58 | 59 | Convert the weights to the Lit-LLaMA format: 60 | 61 | ```bash 62 | python scripts/convert_hf_checkpoint.py --checkpoint_dir checkpoints/open-llama/7B --model_size 7B 63 | ``` 64 | 65 | > **Note** 66 | > All scripts support argument [customization](customize_paths.md) 67 | 68 | Once converted, you should have a folder like this: 69 | 70 | ```text 71 | checkpoints/lit-llama/ 72 | ├── 7B 73 | │ └── lit-llama.pth 74 | └── tokenizer.model 75 | ``` 76 | 77 | You are all set. Now you can continue with inference or finetuning. 78 | 79 | Try running [`generate.py` to test the imported weights](inference.md). 80 | 81 | 82 | ### Alternative sources 83 | 84 | You might find LLaMA weights hosted online in the HuggingFace hub. Beware that this infringes the original weight's license. 
85 | You could try downloading them by running the following command with a specific repo id: 86 | 87 | ```bash 88 | # Make sure you have git-lfs installed (https://git-lfs.com): git lfs install 89 | git clone REPO_ID checkpoints/hf-llama/7B 90 | ``` 91 | 92 | Or if you don't have `git-lfs` installed: 93 | 94 | ```bash 95 | python scripts/download.py --repo_id REPO_ID --local_dir checkpoints/hf-llama/7B 96 | ``` 97 | 98 | Once downloaded, you should have a folder like this: 99 | 100 | ```text 101 | checkpoints/hf-llama/ 102 | └── 7B 103 | ├── ... 104 | ├── pytorch_model-00001-of-00002.bin 105 | ├── pytorch_model-00002-of-00002.bin 106 | ├── pytorch_model.bin.index.json 107 | └── tokenizer.model 108 | ``` 109 | 110 | Convert the weights to the Lit-LLaMA format: 111 | 112 | ```bash 113 | python scripts/convert_hf_checkpoint.py --model_size 7B 114 | ``` 115 | 116 | > **Note** 117 | > All scripts support argument [customization](customize_paths.md) 118 | 119 | Once converted, you should have a folder like this: 120 | 121 | ```text 122 | checkpoints/lit-llama/ 123 | ├── 7B 124 | │ └── lit-llama.pth 125 | └── tokenizer.model 126 | ``` 127 | 128 | You are all set. Now you can continue with inference or finetuning. 129 | 130 | Try running [`generate.py` to test the imported weights](inference.md). 131 | -------------------------------------------------------------------------------- /howto/finetune_adapter.md: -------------------------------------------------------------------------------- 1 | # Finetuning with Adapter 2 | 3 | [LLaMA-Adapter](https://arxiv.org/abs/2303.16199) is a form of prefix-tuning that prepends a learnable adaption-prompt to the inputs of the attention blocks in LLaMA. In total, there are only 1.2M parameters to update during finetuning, which significantly reduces the memory footprint and speeds up training. 4 | 5 | We are able to demonstrate instruction-finetuning Lit-LLaMA 7B on the [Alpaca](https://github.com/tatsu-lab/stanford_alpaca) dataset on a **single RTX 3090 (24GB) GPU**. If using 8 GPUs, finetuning can be completed in under 1 hour. 6 | 7 | If you are new to LLaMA-Adapter and are interested to learn more about how it works before proceeding with the finetuning guide below, you might find our article [Understanding Parameter-Efficient Finetuning of Large Language Models: From Prefix Tuning to LLaMA-Adapters](https://lightning.ai/pages/community/article/understanding-llama-adapters/) helpful. 8 | 9 | ## LLaMA-Adapter v2 10 | 11 | The LLaMA-Adapter authors developed a newer adapter method called LLaMA-Adapter v2, which is related to this LLaMA-Adapter method but includes more trainable parameters. LLaMA-Adapter v2 is also available via Lit-LLaMA; you can read more about it in [the related how-to doc here](./finetune_adapter_v2.md). 12 | 13 | ## Preparation 14 | 15 | The steps here only need to be done once: 16 | 17 | 1. Follow the instructions in the [README](README.md) to install the dependencies. 18 | 2. Download and convert the weights and save them in the `./checkpoints` folder as described [here](download_weights.md). 19 | 3. If you want to utilize more than one GPU, you should `pip install deepspeed`. 20 | 4. Download the data and generate the Alpaca instruction tuning dataset: 21 | 22 | ```bash 23 | python scripts/prepare_alpaca.py 24 | ``` 25 | 26 | or [prepare your own dataset](#tune-on-your-dataset). 
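If you want to sanity-check the prepared data before starting a run, the files written by `prepare_alpaca.py` can be inspected directly with `torch.load`. The sketch below assumes the script wrote its output to `data/alpaca/` (pass `--destination_path` if you chose a different folder); the `input_ids` and `labels` keys are the ones read by the finetuning script's `get_batch` function.

```python
# Optional sanity check of the prepared Alpaca dataset (not required for training).
import torch

train_data = torch.load("data/alpaca/train.pt")  # list of dicts, one per instruction/response pair
sample = train_data[0]

print(len(train_data))            # number of training samples
print(sample["input_ids"].shape)  # tokenized prompt + response
print(sample["labels"].shape)     # same length as input_ids; used as the training target
```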
27 | 28 | See also: [Finetuning on an unstructured dataset](unstructured_dataset.md) 29 | 30 | ## Running the finetuning 31 | 32 | ```bash 33 | python finetune/adapter.py 34 | ``` 35 | 36 | The finetuning requires at least one GPU with ~24 GB memory (RTX 3090). 37 | You can speed up training by setting the `devices` variable in the script to utilize more GPUs if available. 38 | Depending on the available GPU memory, you can also tune the `micro_batch_size` parameter to utilize the GPU efficiently. 39 | 40 | For example, the following settings will let you finetune the model in under 1 hour using DeepSpeed Zero-2: 41 | 42 | ```python 43 | devices = 8 44 | micro_batch_size = 8 45 | ``` 46 | 47 | This script will save checkpoints periodically to the folder `out/`. 48 | 49 | > **Note** 50 | > All scripts support argument [customization](customize_paths.md) 51 | 52 | ## Test the model 53 | 54 | You can test the finetuned model with your own instructions by running: 55 | 56 | ```bash 57 | python generate/adapter.py \ 58 | --prompt "Recommend a movie to watch on the weekend." \ 59 | --quantize llm.int8 60 | ``` 61 | Output: 62 | ``` 63 | A good movie to watch on the weekend would be The Lion King, since it's a classic family film that everyone can enjoy... 64 | ``` 65 | If your GPU supports `bfloat16`, the script will automatically use it. Together with `--quantize llm.int8`, this brings the memory consumption down to ~8 GB. 66 | 67 | ## Tune on your dataset 68 | 69 | With only a few modifications, you can prepare and train on your own instruction dataset. 70 | 71 | 1. Create a json file in which each row holds one instruction-response pair. 72 | A row has an entry for 'instruction', 'input', and 'output', where 'input' is optional an can be 73 | the empty string if the instruction doesn't require a context. Below is an example json file: 74 | 75 | ``` 76 | [ 77 | { 78 | "instruction": "Arrange the given numbers in ascending order.", 79 | "input": "2, 4, 0, 8, 3", 80 | "output": "0, 2, 3, 4, 8" 81 | }, 82 | ... 83 | ] 84 | ``` 85 | 86 | 2. Make a copy of `scripts/prepare_alpaca.py` and name it what you want: 87 | 88 | ```bash 89 | cp scripts/prepare_alpaca.py scripts/prepare_mydata.py 90 | ``` 91 | 92 | 3. Modify `scripts/prepare_mydata.py` to read the json data file. 93 | 4. Run the script to generate the preprocessed, tokenized train-val split: 94 | 95 | ```bash 96 | python scripts/prepare_mydata.py --destination_path data/mydata/ 97 | ``` 98 | 99 | 5. Run `finetune/adapter.py` by passing in the location of your data (and optionally other parameters): 100 | 101 | ```bash 102 | python finetune/adapter.py --data_dir data/mydata/ --out_dir out/myexperiment 103 | ``` 104 | 105 | 106 | ## Troubleshooting 107 | 108 | If you run into a CUDA error "Expected is_sm80 to be true, but got false", uncomment the line 109 | `torch.backends.cuda.enable_flash_sdp(False)` in the script below (see https://github.com/Lightning-AI/lit-llama/issues/101). 110 | -------------------------------------------------------------------------------- /howto/finetune_adapter_v2.md: -------------------------------------------------------------------------------- 1 | # Finetuning with Adapter v2 2 | 3 | [LLaMA-Adapter v2](https://arxiv.org/abs/2304.15010) is a form of prefix-tuning that prepends a learnable adaption-prompt to the inputs of the attention blocks in LLaMA. In total, there are only ~4 M parameters to update during finetuning, which significantly reduces the memory footprint and speeds up training. 
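Most of these extra parameters come from a trainable bias and scale that adapter v2 attaches to every linear layer, described in more detail below. As a minimal sketch of that mechanism in isolation, here is the helper from `lit_llama/adapter_v2.py` applied to a single layer (the layer size is only illustrative; the real script patches every `nn.Linear` in the transformer):

```python
import torch
import torch.nn as nn

from lit_llama.adapter_v2 import adapter_v2_linear_with_bias_and_scale

layer = nn.Linear(4096, 4096, bias=False)     # stand-in for one of the model's linear layers
adapter_v2_linear_with_bias_and_scale(layer)  # adds `adapter_bias` (zeros) and `adapter_scale` (ones)

x = torch.randn(1, 4096)
y = layer(x)  # now computes adapter_scale * (x @ W.T + adapter_bias)

extra = sum(p.numel() for n, p in layer.named_parameters() if "adapter" in n)
print(y.shape, extra)  # torch.Size([1, 4096]) 8192
```

During finetuning, `mark_only_adapter_v2_as_trainable` then freezes everything except these bias/scale parameters, the adaption prompts and gating factors, and the RMSNorm weights.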
4 | 5 | We are able to demonstrate instruction-finetuning Lit-LLaMA 7B on the [Alpaca](https://github.com/tatsu-lab/stanford_alpaca) dataset on a **single RTX 3090 (24GB) GPU**. If using 8 GPUs, finetuning can be completed in under 1 hour. 6 | 7 | If you are new to LLaMA-Adapter and are interested to learn more about how it works before proceeding with the finetuning guide below, you might find our article [Understanding Parameter-Efficient Finetuning of Large Language Models: From Prefix Tuning to LLaMA-Adapters](https://lightning.ai/pages/community/article/understanding-llama-adapters/) helpful. 8 | 9 | ## LLaMA-Adapter v1 versus LLaMA-Adapter v2 10 | 11 | LLaMA-Adapter v2 extends the original LLaMA-Adapter idea by adding trainable bias and scale parameters to each linear layer in the transformer. Furthermore, LLaMA-Adapter v2 makes the normalization layers trainable. Where the 7B LLaMA model has 1.2M trainable parameters with LLaMA v1, LLaMA-Adapter v2 adds 2.8 M trainable parameters for the bias and scale parameters and ~300k trainable parameters for the normalization layers. So, adapter v2 has ~4.3 M trainable parameters in total. 12 | 13 | If you are interested in using the more lightweight LLaMA-Adapter v1 approach, see [the related LLaMA Adapter how-to doc here](./finetune_adapter.md). 14 | 15 | While LLaMA-Adapter v2 increases the number of trainable parameters from 1.2 M (from LLaMA-Apdapter v1) to 4.3 M, the inference cost is not significantly impacted. This is because the additional bias and scale parameters are cheap to compute in the forward pass, and the RMSNorm parameters are already included in the base model. In LLaMA-Adapter v1, the RMSNorm parameters are not trainable. 16 | 17 | 18 | ## Preparation 19 | 20 | The steps here only need to be done once: 21 | 22 | 1. Follow the instructions in the [README](README.md) to install the dependencies. 23 | 2. Download and convert the weights and save them in the `./checkpoints` folder as described [here](download_weights.md). 24 | 3. If you want to utilize more than one GPU, you should `pip install deepspeed`. 25 | 4. Download the data and generate the Alpaca instruction tuning dataset: 26 | 27 | ```bash 28 | python scripts/prepare_alpaca.py 29 | ``` 30 | 31 | or [prepare your own dataset](#tune-on-your-dataset). 32 | 33 | See also: [Finetuning on an unstructured dataset](unstructured_dataset.md) 34 | 35 | ## Running the finetuning 36 | 37 | ```bash 38 | python finetune/adapter_v2.py 39 | ``` 40 | 41 | The finetuning requires at least one GPU with ~24 GB memory (RTX 3090). 42 | You can speed up training by setting the `devices` variable in the script to utilize more GPUs if available. 43 | Depending on the available GPU memory, you can also tune the `micro_batch_size` parameter to utilize the GPU efficiently. 44 | 45 | For example, the following settings will let you finetune the model in under 1 hour using DeepSpeed Zero-2: 46 | 47 | ```python 48 | devices = 8 49 | micro_batch_size = 8 50 | ``` 51 | 52 | This script will save checkpoints periodically to the folder `out/`. 53 | 54 | > **Note** 55 | > All scripts support argument [customization](customize_paths.md) 56 | 57 | ## Test the model 58 | 59 | You can test the finetuned model with your own instructions by running: 60 | 61 | ```bash 62 | python generate/adapter_v2.py \ 63 | --prompt "Recommend a movie to watch on the weekend." 
\ 64 | --quantize llm.int8 65 | ``` 66 | Output: 67 | ``` 68 | A good movie to watch on the weekend would be The Lion King, since it's a classic family film that everyone can enjoy... 69 | ``` 70 | If your GPU supports `bfloat16`, the script will automatically use it. Together with `--quantize llm.int8`, this brings the memory consumption down to ~8 GB. 71 | 72 | ## Tune on your dataset 73 | 74 | With only a few modifications, you can prepare and train on your own instruction dataset. 75 | 76 | 1. Create a json file in which each row holds one instruction-response pair. 77 | A row has an entry for 'instruction', 'input', and 'output', where 'input' is optional an can be 78 | the empty string if the instruction doesn't require a context. Below is an example json file: 79 | 80 | ``` 81 | [ 82 | { 83 | "instruction": "Arrange the given numbers in ascending order.", 84 | "input": "2, 4, 0, 8, 3", 85 | "output": "0, 2, 3, 4, 8" 86 | }, 87 | ... 88 | ] 89 | ``` 90 | 91 | 2. Make a copy of `scripts/prepare_alpaca.py` and name it what you want: 92 | 93 | ```bash 94 | cp scripts/prepare_alpaca.py scripts/prepare_mydata.py 95 | ``` 96 | 97 | 3. Modify `scripts/prepare_mydata.py` to read the json data file. 98 | 4. Run the script to generate the preprocessed, tokenized train-val split: 99 | 100 | ```bash 101 | python scripts/prepare_mydata.py --destination_path data/mydata/ 102 | ``` 103 | 104 | 5. Run `finetune/adapter_v2.py` by passing in the location of your data (and optionally other parameters): 105 | 106 | ```bash 107 | python finetune/adapter_v2.py --data_dir data/mydata/ --out_dir out/myexperiment 108 | ``` 109 | 110 | 111 | ## Troubleshooting 112 | 113 | If you run into a CUDA error "Expected is_sm80 to be true, but got false", uncomment the line 114 | `torch.backends.cuda.enable_flash_sdp(False)` in the script below (see https://github.com/Lightning-AI/lit-llama/issues/101). 115 | -------------------------------------------------------------------------------- /howto/finetune_full.md: -------------------------------------------------------------------------------- 1 | # Full Finetuning 2 | 3 | Full finetuning updates all layers in the pretrained LLaMA model. This *regular* finetuning procedure is typically considered as the baseline for parameter-efficient alternatives such as Low-Rank Adaptation (LoRA) or LLaMA-Adapter. 4 | 5 | The current [finetune/full.py](../finetune/full.py) we provide uses 4 A100 GPUs with a fully-sharded data parallel strategy to finetune Lit-LLaMA 7B on [Alpaca](https://github.com/tatsu-lab/stanford_alpaca) dataset. The A100 GPUs have 40 GB each, but it may require less memory to finetune this model. 6 | 7 | 8 | 9 | ## Preparation 10 | 11 | The steps here only need to be done once: 12 | 13 | 1. Follow the instructions in the [README](README.md) to install the dependencies. 14 | 15 | 2. Download and convert the weights and save them in the `./checkpoints` folder as described [here](download_weights.md). 16 | 17 | 4. Download the data and generate the Alpaca instruction tuning dataset: 18 | 19 | ```bash 20 | python scripts/prepare_alpaca.py 21 | ``` 22 | 23 | or [prepare your own dataset](#tune-on-your-own-dataset). 24 | 25 | See also: [Finetuning on an unstructured dataset](unstructured_dataset.md) 26 | 27 | ## Running the finetuning 28 | 29 | ```bash 30 | python finetune/full.py 31 | ``` 32 | 33 | 34 | You can speed up training by setting the `devices` variable in the script to utilize more GPUs if available or increase the `batch_size`. 
35 | Depending on the available GPU memory, you can also tune the `micro_batch_size` parameter to utilize the GPU efficiently. 36 | 37 | For example, the following settings will let you finetune the model in 32 hours using a fully-sharded data parallel strategy: 38 | ```python 39 | devices = 4 40 | batch_size = 128 // devices 41 | micro_batch_size = 4 42 | ``` 43 | 44 | This script will save checkpoints periodically to the folder `out/`. 45 | 46 | > **Note** 47 | > All scripts support argument [customization](customize_paths.md) 48 | 49 | ## Test the model 50 | 51 | You can test the finetuned model with your own instructions by running: 52 | 53 | ```bash 54 | python generate/full.py \ 55 | --prompt "Recommend a movie to watch on the weekend." \ 56 | --quantize llm.int8 57 | ``` 58 | Output: 59 | ``` 60 | A good movie to watch on the weekend would be The Lion King, since it's a classic family film that everyone can enjoy... 61 | ``` 62 | If your GPU supports `bfloat16`, the script will automatically use it. Together with `--quantize llm.int8`, this brings the memory consumption down to ~8 GB. 63 | 64 | ## Tune on your dataset 65 | 66 | With only a few modifications, you can prepare and train on your own instruction dataset. 67 | 68 | 1. Create a json file in which each row holds one instruction-response pair. 69 | A row has an entry for 'instruction', 'input', and 'output', where 'input' is optional an can be 70 | the empty string if the instruction doesn't require a context. Below is an example json file: 71 | 72 | ``` 73 | [ 74 | { 75 | "instruction": "Arrange the given numbers in ascending order.", 76 | "input": "2, 4, 0, 8, 3", 77 | "output": "0, 2, 3, 4, 8" 78 | }, 79 | ... 80 | ] 81 | ``` 82 | 83 | 2. Make a copy of `scripts/prepare_alpaca.py` and name it what you want: 84 | 85 | ```bash 86 | cp scripts/prepare_alpaca.py scripts/prepare_mydata.py 87 | ``` 88 | 89 | 3. Modify `scripts/prepare_mydata.py` to read the json data file. 90 | 4. Run the script to generate the preprocessed, tokenized train-val split: 91 | 92 | ```bash 93 | python scripts/prepare_mydata.py --destination_path data/mydata/ 94 | ``` 95 | 96 | 5. Run `finetune/full.py` by passing in the location of your data (and optionally other parameters): 97 | 98 | ```bash 99 | python finetune/full.py --data_dir data/mydata/ --out_dir out/myexperiment 100 | ``` 101 | 102 | 103 | ## Troubleshooting 104 | 105 | If you run into a CUDA error "Expected is_sm80 to be true, but got false", uncomment the line 106 | `torch.backends.cuda.enable_flash_sdp(False)` in the script below (see https://github.com/Lightning-AI/lit-llama/issues/101). 107 | -------------------------------------------------------------------------------- /howto/finetune_lora.md: -------------------------------------------------------------------------------- 1 | # Finetuning with LoRA 2 | 3 | [Low-rank adaption (LoRA)](https://arxiv.org/abs/2106.09685) is a technique to approximate the update to the linear layers in a LLM with a low-rank matrix factorization. This significantly reduces the number of trainable parameters and speeds up training with little impact on the final performance of the model. 4 | We demonstrate this method by instruction-finetuning LLaMA 7B on the [Alpaca](https://github.com/tatsu-lab/stanford_alpaca) dataset on a **single RTX 3090 (24GB) GPU**. 5 | 6 | ## Preparation 7 | 8 | The steps here only need to be done once: 9 | 10 | 1. Follow the instructions in the [README](../README.md) to install the dependencies. 11 | 2. 
Download and convert the weights and save them in the `./checkpoints` folder as described [here](download_weights.md). 12 | 3. Download the data and generate the instruction tuning dataset: 13 | 14 | ```bash 15 | python scripts/prepare_alpaca.py 16 | ``` 17 | 18 | See also: [Finetuning on an unstructured dataset](unstructured_dataset.md) 19 | 20 | ## Running the finetuning 21 | 22 | ```bash 23 | python finetune/lora.py 24 | ``` 25 | 26 | The finetuning requires at least one GPU with ~24 GB memory (RTX 3090). 27 | 28 | This script will save checkpoints periodically to the folder `out/`. 29 | 30 | > **Note** 31 | > All scripts support argument [customization](customize_paths.md) 32 | 33 | 34 | ## Test the model 35 | 36 | You can test the finetuned model with your own instructions by running: 37 | 38 | ```bash 39 | python generate/lora.py --prompt "Recommend a movie to watch on the weekend." 40 | ``` 41 | Output: 42 | ``` 43 | I would recommend the movie The Martian (2015). It is a sci-fi movie starring Matt Damon that follows the story of... 44 | ``` 45 | 46 | If your GPU supports `bfloat16`, you can additionally pass `--dtype bfloat16` to bring the memory consumption down to ~14 GB. 47 | 48 | ## Tune on your dataset 49 | 50 | With only a few modifications, you can prepare and train on your own instruction dataset. 51 | 52 | 1. Create a json file in which each row holds one instruction-response pair. 53 | A row has an entry for 'instruction', 'input', and 'output', where 'input' is optional an can be 54 | the empty string if the instruction doesn't require a context. Below is an example json file: 55 | 56 | ``` 57 | [ 58 | { 59 | "instruction": "Arrange the given numbers in ascending order.", 60 | "input": "2, 4, 0, 8, 3", 61 | "output": "0, 2, 3, 4, 8" 62 | }, 63 | ... 64 | ] 65 | ``` 66 | 67 | 2. Make a copy of `scripts/prepare_alpaca.py` and name it what you want: 68 | 69 | ```bash 70 | cp scripts/prepare_alpaca.py scripts/prepare_mydata.py 71 | ``` 72 | 73 | 3. Modify `scripts/prepare_mydata.py` to read the json data file. 74 | 4. Run the script to generate the preprocessed, tokenized train-val split: 75 | 76 | ```bash 77 | python scripts/prepare_mydata.py --destination_path data/mydata/ 78 | ``` 79 | 80 | 5. Run `finetune/lora.py` by passing in the location of your data (and optionally other parameters): 81 | 82 | ```bash 83 | python finetune/lora.py --data_dir data/mydata/ --out_dir out/myexperiment 84 | ``` 85 | 86 | 87 | ## Troubleshooting 88 | 89 | If you run into a CUDA error "Expected is_sm80 to be true, but got false", uncomment the line 90 | `torch.backends.cuda.enable_flash_sdp(False)` in the script below (see https://github.com/Lightning-AI/lit-llama/issues/101). 91 | -------------------------------------------------------------------------------- /howto/inference.md: -------------------------------------------------------------------------------- 1 | # Inference 2 | 3 | We demonstrate how to run inference (next token prediction) with the LLaMA base model in the [`generate.py`](generate.py) script: 4 | 5 | ```bash 6 | python generate.py --prompt "Hello, my name is" 7 | ``` 8 | Output: 9 | ``` 10 | Hello my name is TJ. I have a passion for the outdoors, love hiking and exploring. I also enjoy traveling and learning new things. I especially enjoy long walks, good conversation and a friendly smile. 11 | ``` 12 | 13 | The script assumes you have downloaded and converted the weights and saved them in the `./checkpoints` folder as described [here](download_weights.md). 
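If you prefer to drive generation from your own Python code rather than through the CLI, the core of the script can be reproduced in a few lines. The sketch below mirrors what `generate.py` does with its default arguments (no quantization); it assumes you run it from the repository root so that `generate.py` is importable, and that the checkpoint and tokenizer live at the default paths shown above.

```python
# Minimal programmatic sketch of generate.py with its default settings (no quantization).
import lightning as L
import torch

from generate import generate
from lit_llama import LLaMA, Tokenizer
from lit_llama.utils import lazy_load, llama_model_lookup

precision = "bf16-true" if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else "32-true"
fabric = L.Fabric(devices=1, precision=precision)

with lazy_load("checkpoints/lit-llama/7B/lit-llama.pth") as checkpoint:
    name = llama_model_lookup(checkpoint)  # e.g. "7B"
    with fabric.init_module(empty_init=True):
        model = LLaMA.from_name(name)
    model.load_state_dict(checkpoint)

model.eval()
model = fabric.setup(model)

tokenizer = Tokenizer("checkpoints/lit-llama/tokenizer.model")
encoded = tokenizer.encode("Hello, my name is", bos=True, eos=False, device=fabric.device)
y = generate(model, encoded, max_new_tokens=50, temperature=0.8, top_k=200)
print(tokenizer.decode(y))
```

If you sample repeatedly, call `model.reset_cache()` between generations, as the script does.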
14 | 15 | > **Note** 16 | > All scripts support argument [customization](customize_paths.md) 17 | 18 | With the default settings, this will run the 7B model and require ~26 GB of GPU memory (A100 GPU). 19 | 20 | ## Run Lit-LLaMA on consumer devices 21 | 22 | On GPUs with `bfloat16` support, the `generate.py` script will automatically convert the weights and consume about ~14 GB. 23 | For GPUs with less memory, or ones that don't support `bfloat16`, enable quantization (`--quantize llm.int8`): 24 | 25 | ```bash 26 | python generate.py --quantize llm.int8 --prompt "Hello, my name is" 27 | ``` 28 | This will consume about ~10 GB of GPU memory or ~8 GB if also using `bfloat16`. 29 | See `python generate.py --help` for more options. 30 | 31 | You can also use GPTQ-style int4 quantization, but this needs conversions of the weights first: 32 | 33 | ```bash 34 | python quantize/gptq.py --output_path checkpoints/lit-llama/7B/llama-gptq.4bit.pth --dtype bfloat16 --quantize gptq.int4 35 | ``` 36 | 37 | GPTQ-style int4 quantization brings GPU usage down to about ~5GB. As only the weights of the Linear layers are quantized, it is useful to also use `--dtype bfloat16` even with the quantization enabled. 38 | 39 | With the generated quantized checkpoint generation quantization then works as usual with `--quantize gptq.int4` and the newly generated checkpoint file: 40 | 41 | ```bash 42 | python generate.py --quantize gptq.int4 --checkpoint_path checkpoints/lit-llama/7B/llama-gptq.4bit.pth 43 | ``` 44 | -------------------------------------------------------------------------------- /howto/tpus.md: -------------------------------------------------------------------------------- 1 | # TPU support 2 | 3 | Lit-LLaMA used `lightning.Fabric` under the hood, which itself supports TPUs (via [PyTorch XLA](https://github.com/pytorch/xla)). 4 | 5 | The following commands will allow you to set up a `Google Cloud` instance with a [TPU v4](https://cloud.google.com/tpu/docs/system-architecture-tpu-vm) VM: 6 | 7 | ```shell 8 | gcloud compute tpus tpu-vm create lit-llama --version=tpu-vm-v4-pt-2.0 --accelerator-type=v4-8 --zone=us-central2-b 9 | gcloud compute tpus tpu-vm ssh lit-llama --zone=us-central2-b 10 | ``` 11 | 12 | Now that you are in the machine, let's clone the repository and install the dependencies 13 | 14 | ```shell 15 | git clone https://github.com/Lightning-AI/lit-llama 16 | cd lit-llama 17 | pip install -e ".[all]" 18 | ``` 19 | 20 | By default, computations will run using the new (and experimental) PjRT runtime. Still, it's recommended that you set the following environment variables 21 | 22 | ```shell 23 | export PJRT_DEVICE=TPU 24 | export ALLOW_MULTIPLE_LIBTPU_LOAD=1 25 | ``` 26 | 27 | > **Note** 28 | > You can find an extensive guide on how to get set-up and all the available options [here](https://cloud.google.com/tpu/docs/v4-users-guide). 29 | 30 | Since you created a new machine, you'll probably need to download the weights. You could scp them into the machine with `gcloud compute tpus tpu-vm scp` or you can follow the steps described in our [downloading guide](download_weights.md). 31 | 32 | ## Inference 33 | 34 | Generation works out-of-the-box with TPUs: 35 | 36 | ```shell 37 | python3 generate.py --prompt "Hello, my name is" --num_samples 3 38 | ``` 39 | 40 | This command will take take ~20s for the first generation time as XLA needs to compile the graph. 41 | You'll notice that afterwards, generation times drop to ~5s. 42 | 43 | ## Finetuning 44 | 45 | Coming soon. 
46 | 47 | > **Warning** 48 | > When you are done, remember to delete your instance 49 | > ```shell 50 | > gcloud compute tpus tpu-vm delete lit-llama --zone=us-central2-b 51 | > ``` -------------------------------------------------------------------------------- /howto/train_redpajama.md: -------------------------------------------------------------------------------- 1 | # Pre-train LLaMA on RedPajama 2 | 3 | This howto will walk you through setting up the RedPajama dataset and launching the pre-training script. 4 | 5 | ## What's RedPajama 6 | 7 | [RedPajama](https://github.com/togethercomputer/RedPajama-Data) is an open-source reproduction of the original LLaMA training dataset. 8 | 9 | It contains a total of 1.2 trillion tokens, divided into 10 | 11 | ```text 12 | Commoncrawl 878B 13 | C4 175B 14 | GitHub 59B 15 | Books 26B 16 | ArXiv 28B 17 | Wikipedia 24B 18 | StackExchange 20B 19 | ``` 20 | 21 | The [RedPajama repo](https://github.com/togethercomputer/RedPajama-Data) contains the source code for collecting and preparing 22 | the dataset, and it is Apache 2.0 licensed. 23 | 24 | The data itself is licensed according to the original licenses with which its invidivdual parts were released. 25 | The GitHub datasets are limited to MIT, BSD, or Apache 2.0 repositories. 26 | 27 | Along with the full [RedPajama-1T dataset](https://huggingface.co/datasets/togethercomputer/RedPajama-Data-1T), 28 | the [RedPajama-1T-Sample](https://huggingface.co/datasets/togethercomputer/RedPajama-Data-1T-Sample) 1B sample dataset 29 | is also available for development. 30 | 31 | You can download the data using git lfs: 32 | 33 | ```bash 34 | # Make sure you have git-lfs installed (https://git-lfs.com): git lfs install 35 | git clone https://huggingface.co/datasets/togethercomputer/RedPajama-Data-1T data/RedPajama-Data-1T 36 | ``` 37 | 38 | ```bash 39 | # Make sure you have git-lfs installed (https://git-lfs.com): git lfs install 40 | git clone https://huggingface.co/datasets/togethercomputer/RedPajama-Data-1T-Sample data/RedPajama-Data-1T-Sample 41 | ``` 42 | 43 | ## Prepare RedPajama for training 44 | 45 | The dataset consists of 2084 `jsonl` files (the sample dataset contains 11). In order to start pre-training lit-llama 46 | on it, you need to read, tokenize, and write the data in binary chunks. This will leverage the `PackedDataset` 47 | streaming dataset that comes with lit-llama. 48 | 49 | Do to so, run 50 | 51 | ```bash 52 | python scripts/prepare_redpajama.py --source_path data/RedPajama-Data-1T --tokenizer_path checkpoints/lit-llama/tokenizer.model --destination_path data/lit-redpajama 53 | ``` 54 | 55 | or 56 | 57 | ```bash 58 | python scripts/prepare_redpajama.py --source_path data/RedPajama-Data-1T-Sample --tokenizer_path checkpoints/lit-llama/tokenizer.model --destination_path data/lit-redpajama-sample --sample True 59 | ``` 60 | 61 | for the sample dataset. 62 | 63 | In the above we are assuming that you will be using the same tokenizer as used in LLaMA, but any trained [SentencePiece](https://github.com/google/sentencepiece) tokenizer with a 32000 vocabulary size will do here. 64 | 65 | The script will take a while to run, so time for :tea: 66 | 67 | ## Pre-training 68 | 69 | Running the pre-training script requires at least 4 GPUs with 40GB+ each (A100). 
70 | 71 | ```bash 72 | python pretrain/redpajama.py --devices 4 --train_data_dir data/lit-redpajama 73 | ``` 74 | 75 | For running on the sample dataset: 76 | 77 | ```bash 78 | python pretrain/redpajama.py --devices 4 --train_data_dir data/lit-redpajama-sample 79 | ``` 80 | 81 | The script will save checkpoints periodically to the folder `out/`. 82 | 83 | The `train_redpajama.py` script will pre-train the LLaMA 7B model with FSDP in 84 | `bfloat16` precision and gradient accumulation. 85 | 86 | You can easily change the size of the model by passing a different string to 87 | 88 | ```python 89 | config = LLaMAConfig.from_name("7B") 90 | ``` 91 | 92 | in the `main` function. 93 | 94 | Keep in mind that the original LLaMA training for the 7B model required 83k A100 80GB 95 | hours, so you'll need access to a cluster. 96 | 97 | Once you're in a cluster, you can follow [these instructions](https://lightning.ai/docs/fabric/stable/guide/multi_node/other.html) 98 | to launch the script across machines: 99 | 100 | - [SLURM cluster](https://lightning.ai/docs/fabric/stable/guide/multi_node/slurm.html) 101 | - [Barebones cluster](https://lightning.ai/docs/fabric/stable/guide/multi_node/barebones.html) 102 | - [MPI](https://lightning.ai/docs/fabric/stable/guide/multi_node/other.html) 103 | 104 | The script contains several configurations and hyperparameters you can tweak: 105 | 106 | ```python 107 | out_dir = "out/training" 108 | save_interval = 1000 109 | eval_interval = 1000 110 | eval_iters = 100 111 | log_interval = 1 112 | 113 | # Hyperparameters 114 | learning_rate = 6e-4 115 | batch_size = 125 116 | micro_batch_size = 5 117 | max_iters = 600000 # num_epochs * (epoch_size // micro_batch_size) // devices 118 | weight_decay = 1e-1 119 | beta1 = 0.9 120 | beta2 = 0.95 121 | grad_clip = 1.0 122 | decay_lr = True 123 | warmup_iters = 2000 124 | lr_decay_iters = max_iters 125 | min_lr = 6e-5 126 | ``` 127 | 128 | In particular, `micro_batch_size` should be adjusted so the process will use the available 129 | GPU memory. 130 | 131 | Last, logging is kept minimal in the script. In order to use a particular logger 132 | please refer to or 133 | call a logging client library like `wandb` directly. 134 | -------------------------------------------------------------------------------- /howto/unstructured_dataset.md: -------------------------------------------------------------------------------- 1 | # Finetuning on an unstructured dataset 2 | 3 | While most scripts were made to finetune on instruction datasets, it is possible to finetune on any dataset. This is useful for experimentation while not being as expensive as training a full model. 4 | 5 | This guide is only to prepare the finetuning, as either LoRA or Adapter-v1 methods support this dataset type! 6 | 7 | ## Preparation 8 | 9 | 1. Gather your text into an input file named `input.txt` 10 | 2. Divide the data into training and validation sets using the following script: 11 | 12 | ```bash 13 | python scripts/prepare_any_text.py 14 | ``` 15 | 16 | 3. Modify relevant scripts for your finetuning method under `finetune/` and `evaluate/`, setting the `instruction_tuning` variable to `False` 17 | 18 | And then you're set! Proceed to run the [LoRA guide](./finetune_lora.md) or [Adapter v1 guide](./finetune_adapter.md). -------------------------------------------------------------------------------- /lit_llama/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright Lightning AI. 
Licensed under the Apache License 2.0, see LICENSE file. 2 | 3 | from lit_llama.model import LLaMAConfig, LLaMA, RMSNorm, build_rope_cache, apply_rope 4 | from lit_llama.tokenizer import Tokenizer 5 | -------------------------------------------------------------------------------- /lit_llama/adapter_v2.py: -------------------------------------------------------------------------------- 1 | # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. 2 | 3 | import torch 4 | from torch import Tensor 5 | import torch.nn as nn 6 | from torch.nn import functional as F 7 | 8 | from lit_llama.adapter import LLaMA 9 | 10 | 11 | def get_adapter_substrings(): 12 | substrings = ["adapter_wte", "gating_factor"] # regular adapter v1 parameters 13 | substrings.extend(["adapter_scale", "adapter_bias"]) # adapter v2: new bias and scale used in Linear 14 | substrings.extend(["rms_1", "rms_2", "ln_f"]) # adapter v2: RMSNorm parameters are now trainable 15 | return substrings 16 | 17 | 18 | def mark_only_adapter_v2_as_trainable(model: LLaMA) -> None: 19 | """Sets `requires_grad=False` for all non-adapter weights.""" 20 | for name, param in model.named_parameters(): 21 | param.requires_grad = any(s in name for s in get_adapter_substrings()) 22 | 23 | 24 | def adapter_v2_state_from_state_dict(state_dict: dict) -> dict: 25 | """Returns the model state dict with only the adapter weights for saving.""" 26 | return {name: param for name, param in state_dict.items() 27 | if any(s in name for s in get_adapter_substrings())} 28 | 29 | 30 | def adapter_v2_new_forward(self, input: Tensor) -> Tensor: 31 | return self.adapter_scale * ( 32 | F.linear(input, self.weight, self.bias) + self.adapter_bias 33 | ) 34 | 35 | 36 | def adapter_v2_linear_with_bias_and_scale(layer): 37 | layer.adapter_bias = torch.nn.Parameter(torch.zeros(layer.weight.shape[0]), requires_grad=True) 38 | layer.adapter_scale = torch.nn.Parameter(torch.ones(layer.weight.shape[0]), requires_grad=True) 39 | bound_method = adapter_v2_new_forward.__get__(layer, layer.__class__) 40 | setattr(layer, 'forward', bound_method) 41 | return layer 42 | 43 | 44 | def add_adapter_v2_parameters_to_linear_layers(model): 45 | for module in model.modules(): 46 | if isinstance(module, nn.Linear): 47 | adapter_v2_linear_with_bias_and_scale(module) 48 | -------------------------------------------------------------------------------- /lit_llama/packed_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. 
2 | 3 | # Very loosely inspired by indexed_dataset in Fairseq, Megatron 4 | # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/data/indexed_dataset.py 5 | 6 | 7 | import os 8 | import struct 9 | import random 10 | 11 | import numpy as np 12 | import torch 13 | from torch.utils.data import IterableDataset, get_worker_info 14 | 15 | 16 | dtypes = { 17 | 1: np.uint8, 18 | 2: np.int8, 19 | 3: np.int16, 20 | 4: np.int32, 21 | 5: np.int64, 22 | 6: np.float32, 23 | 7: np.float64, 24 | 8: np.uint16, 25 | } 26 | 27 | 28 | def code(dtype): 29 | for k in dtypes.keys(): 30 | if dtypes[k] == dtype: 31 | return k 32 | raise ValueError(dtype) 33 | 34 | 35 | HDR_MAGIC = b"LITPKDS" 36 | HDR_SIZE = 24 # bytes 37 | 38 | 39 | class PackedDataset(IterableDataset): 40 | def __init__(self, filenames, n_chunks, block_size, seed=12345, shuffle=True, wrap=False, num_processes=1, process_rank=0): 41 | self._filenames = filenames 42 | self._n_chunks = n_chunks 43 | self._block_size = block_size 44 | self._seed = seed 45 | self._shuffle = shuffle 46 | self._wrap = wrap 47 | self._num_processes = num_processes 48 | self._process_rank = process_rank 49 | 50 | def __iter__(self): 51 | worker_info = get_worker_info() 52 | num_workers = worker_info.num_workers if worker_info is not None else 1 53 | worker_id = worker_info.id if worker_info is not None else 0 54 | num_shards = num_workers * self._num_processes 55 | shard_id = self._process_rank * num_workers + worker_id 56 | 57 | max_num_files = len(self._filenames) // num_shards * num_shards 58 | filenames = self._filenames[shard_id : max_num_files : num_shards] 59 | 60 | return PackedDatasetIterator( 61 | filenames=filenames, 62 | n_chunks=self._n_chunks, 63 | block_size=self._block_size, 64 | seed=self._seed, 65 | shuffle=self._shuffle, 66 | wrap=self._wrap, 67 | ) 68 | 69 | 70 | class PackedDatasetBuilder(object): 71 | def __init__( 72 | self, 73 | outdir, 74 | prefix, 75 | chunk_size, 76 | sep_token, 77 | dtype="auto", 78 | vocab_size=None, 79 | ): 80 | if dtype == "auto": 81 | if vocab_size is None: 82 | raise ValueError("vocab_size cannot be None when dtype='auto'") 83 | if vocab_size is not None and vocab_size < 65500: 84 | self._dtype = np.uint16 85 | else: 86 | self._dtype = np.int32 87 | else: 88 | self._dtype = dtype 89 | self._counter = 0 90 | self._chunk_size = chunk_size 91 | self._outdir = outdir 92 | self._prefix = prefix 93 | self._sep_token = sep_token 94 | self._arr = np.zeros(self._chunk_size, dtype=self._dtype) 95 | self._arr.fill(self._sep_token) 96 | self._idx = 0 97 | self._version = 1 98 | self._filenames = [] 99 | 100 | def _write_chunk(self): 101 | filename = f"{self._prefix}_{self._counter:010d}.bin" 102 | filename = os.path.join(self._outdir, filename) 103 | 104 | with open(filename, "wb") as f: 105 | f.write(HDR_MAGIC) 106 | f.write(struct.pack(" self._chunk_size: 126 | part_len = self._chunk_size - self._idx 127 | self._arr[self._idx : self._idx + part_len] = arr[:part_len] 128 | self._write_chunk() 129 | arr = arr[part_len:] 130 | 131 | arr_len = arr.shape[0] 132 | self._arr[self._idx : self._idx + arr_len] = arr 133 | self._idx += arr_len 134 | 135 | def write_reminder(self): 136 | self._write_chunk() 137 | 138 | 139 | class PackedDatasetIterator: 140 | def __init__(self, filenames, n_chunks, block_size, seed, shuffle, wrap): 141 | self._seed = seed 142 | self._shuffle = shuffle 143 | self._rng = np.random.default_rng(seed) if shuffle else None 144 | self._block_idxs = None 145 | 146 | self._wrap = wrap 147 | 148 | # TODO: instead 
of filenames, we could have a single text stream 149 | # (or text file) with the sequence of all files to be 150 | # fetched/loaded. 151 | self._filenames = filenames 152 | self._file_idx = 0 153 | 154 | self._n_chunks = n_chunks 155 | 156 | self._dtype = None 157 | self._block_size = block_size 158 | self._n_blocks = None 159 | 160 | self._mmaps = [] 161 | self._buffers = [] 162 | 163 | self._block_idxs = [] 164 | self._curr_idx = 0 165 | 166 | self._load_n_chunks() 167 | 168 | def _read_header(self, path): 169 | with open(path, "rb") as f: 170 | magic = f.read(len(HDR_MAGIC)) 171 | assert magic == HDR_MAGIC, "File doesn't match expected format." 172 | version = struct.unpack(" len(self._filenames[self._file_idx:]): 189 | if not self._wrap: 190 | raise StopIteration 191 | else: 192 | self._file_idx = 0 193 | 194 | for i in range(self._n_chunks): 195 | filename = self._filenames[self._file_idx + i] 196 | if self._dtype is None: 197 | self._dtype, self._chunk_size = self._read_header( 198 | filename 199 | ) 200 | self._n_blocks = self._chunk_size // self._block_size 201 | # TODO: check header matches with previous files 202 | mmap = np.memmap(filename, mode="r", order="C", offset=HDR_SIZE) 203 | self._mmaps.append(mmap) 204 | self._buffers.append(memoryview(mmap)) 205 | 206 | self._file_idx += self._n_chunks 207 | n_all_blocks = self._n_chunks * self._n_blocks 208 | 209 | self._block_idxs = ( 210 | self._rng.permutation(n_all_blocks) 211 | if self._shuffle 212 | else range(n_all_blocks) 213 | ) 214 | 215 | self._curr_idx = 0 216 | 217 | def __del__(self): 218 | self._close_mmaps() 219 | del self._mmaps 220 | del self._buffers 221 | 222 | def __iter__(self): 223 | return self 224 | 225 | def __next__(self): 226 | if self._curr_idx >= len(self._block_idxs): 227 | self._load_n_chunks() 228 | # TODO: trigger fetching next next n_chunks if remote 229 | block_idx = self._block_idxs[self._curr_idx] 230 | chunk_id = block_idx // self._n_blocks 231 | buffer = self._buffers[chunk_id] 232 | elem_id = (block_idx % self._n_blocks) * self._block_size 233 | offset = np.dtype(self._dtype).itemsize * elem_id 234 | arr = np.frombuffer( 235 | buffer, dtype=self._dtype, count=self._block_size, offset=offset 236 | ) 237 | self._curr_idx += 1 238 | return torch.from_numpy(arr.astype(np.int64)) 239 | 240 | 241 | class CombinedDataset(IterableDataset): 242 | def __init__(self, datasets, seed, weights=None): 243 | self._seed = seed 244 | self._datasets = datasets 245 | self._weights = weights 246 | n_datasets = len(datasets) 247 | if weights is None: 248 | self._weights = [1 / n_datasets] * n_datasets 249 | 250 | def __iter__(self): 251 | return CombinedDatasetIterator(self._datasets, self._seed, self._weights) 252 | 253 | 254 | class CombinedDatasetIterator: 255 | def __init__(self, datasets, seed, weights): 256 | self._datasets = [iter(el) for el in datasets] 257 | self._weights = weights 258 | self._rng = random.Random(seed) 259 | 260 | def __next__(self): 261 | dataset, = self._rng.choices(self._datasets, weights=self._weights, k=1) 262 | return next(dataset) 263 | -------------------------------------------------------------------------------- /lit_llama/tokenizer.py: -------------------------------------------------------------------------------- 1 | # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. 
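# A minimal usage sketch of the Tokenizer defined below. The checkpoint path is an
# assumption and must point to a downloaded SentencePiece `tokenizer.model` file.
def _tokenizer_usage_sketch() -> None:
    from pathlib import Path

    tokenizer = Tokenizer(Path("checkpoints/lit-llama/tokenizer.model"))  # assumed path
    ids = tokenizer.encode("Hello, Llama!", bos=True, eos=False)  # prepends the BOS id
    text = tokenizer.decode(ids)  # should round-trip back to the input string
    print(tokenizer.vocab_size, ids.shape, text)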
2 | 3 | import os 4 | from pathlib import Path 5 | from typing import Optional 6 | 7 | import torch 8 | from sentencepiece import SentencePieceProcessor, SentencePieceTrainer 9 | 10 | 11 | class Tokenizer: 12 | """Tokenizer for LLaMA.""" 13 | 14 | def __init__(self, model_path: Path) -> None: 15 | self.processor = SentencePieceProcessor(model_file=str(model_path)) 16 | self.bos_id = self.processor.bos_id() 17 | self.eos_id = self.processor.eos_id() 18 | self.pad_id = self.processor.pad_id() 19 | 20 | @property 21 | def vocab_size(self) -> int: 22 | return self.processor.vocab_size() 23 | 24 | def encode( 25 | self, 26 | string: str, 27 | bos: bool = True, 28 | eos: bool = False, 29 | max_length: int = -1, 30 | pad: bool = False, 31 | device: Optional[torch.device] = None 32 | ) -> torch.Tensor: 33 | tokens = self.processor.encode(string) 34 | if bos: 35 | tokens = [self.bos_id] + tokens 36 | if eos: 37 | tokens = tokens + [self.eos_id] 38 | if max_length > 0: 39 | tokens = tokens[:max_length] 40 | if pad and len(tokens) < max_length: 41 | tokens += [self.pad_id] * (max_length - len(tokens)) 42 | 43 | return torch.tensor(tokens, dtype=torch.int, device=device) 44 | 45 | def decode(self, tokens: torch.Tensor) -> str: 46 | return self.processor.decode(tokens.tolist()) 47 | 48 | @staticmethod 49 | def train(input: str, destination: str, vocab_size=32000) -> None: 50 | model_prefix = os.path.join(destination, "tokenizer") 51 | SentencePieceTrainer.Train(input=input, model_prefix=model_prefix, vocab_size=vocab_size) 52 | -------------------------------------------------------------------------------- /pretrain/shakespeare.py: -------------------------------------------------------------------------------- 1 | # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. 2 | 3 | """ 4 | This script is a placeholder for training LLaMA from scratch. 5 | Currently, it just trains on the Shakespeare dataset. 
6 | """ 7 | from pathlib import Path 8 | import sys 9 | import os 10 | import time 11 | from functools import partial 12 | from typing import Tuple 13 | 14 | import lightning as L 15 | from lightning.fabric.strategies import FSDPStrategy 16 | 17 | import torch 18 | from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy 19 | 20 | import numpy as np 21 | 22 | # support running without installing as a package 23 | wd = Path(__file__).parent.parent.resolve() 24 | sys.path.append(str(wd)) 25 | 26 | from lit_llama.model import Block, LLaMA, LLaMAConfig 27 | from lit_llama.utils import save_model_checkpoint 28 | 29 | 30 | out_dir = "out/training" 31 | eval_interval = 2000 32 | eval_iters = 200 33 | log_interval = 1 34 | # compilation fails as it does not support torch.complex64 for RoPE 35 | # compile = False 36 | 37 | # Hyperparameters 38 | learning_rate = 6e-4 39 | batch_size = 2 40 | max_iters = 600000 41 | weight_decay = 1e-1 42 | beta1 = 0.9 43 | beta2 = 0.95 44 | grad_clip = 1.0 45 | 46 | # For shakespeare, choose smaller block size than vanilla LLaMA 47 | block_size = 1024 48 | 49 | 50 | def main() -> None: 51 | auto_wrap_policy = partial(transformer_auto_wrap_policy, transformer_layer_cls={Block}) 52 | strategy = FSDPStrategy(auto_wrap_policy=auto_wrap_policy, activation_checkpointing=Block, limit_all_gathers=True) 53 | 54 | fabric = L.Fabric(accelerator="cuda", devices=4, precision="bf16-mixed", strategy=strategy) 55 | fabric.launch() 56 | fabric.seed_everything(1337 + fabric.global_rank) 57 | 58 | if fabric.global_rank == 0: 59 | os.makedirs(out_dir, exist_ok=True) 60 | 61 | train_data, val_data = load_datasets() 62 | 63 | config = LLaMAConfig.from_name("7B") 64 | config.block_size = block_size 65 | config.vocab_size = 100 # from prepare_shakespeare.py 66 | 67 | with fabric.device: 68 | model = LLaMA(config) 69 | 70 | # if compile: 71 | # model = torch.compile(model) 72 | 73 | model = fabric.setup_module(model) 74 | 75 | optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay, betas=(beta1, beta2), foreach=False) 76 | optimizer = fabric.setup_optimizers(optimizer) 77 | 78 | train(fabric, model, optimizer, train_data, val_data) 79 | 80 | 81 | def train( 82 | fabric: L.Fabric, 83 | model: torch.nn.Module, 84 | optimizer: torch.optim.Optimizer, 85 | train_data: np.ndarray, 86 | val_data: np.ndarray, 87 | ) -> None: 88 | """The training loop. 89 | 90 | Loosely based on the nanoGPT implementation: https://github.com/karpathy/nanoGPT. 
91 | """ 92 | 93 | iter_num = 0 94 | 95 | while True: 96 | # TODO: add learning rate scheduling 97 | 98 | # evaluate the loss on train/val sets and write checkpoints 99 | if iter_num > 0 and iter_num % eval_interval == 0: 100 | val_loss = validate(fabric, model, val_data) 101 | fabric.print(f"step {iter_num}: val loss {val_loss:.4f}") 102 | fabric.print(f"Saving checkpoint to {out_dir}") 103 | save_model_checkpoint(fabric, model, os.path.join(out_dir, f"iter-{iter_num:06d}-ckpt.pth")) 104 | 105 | t0 = time.time() 106 | 107 | input_ids, targets = get_batch( 108 | fabric, 109 | train_data, 110 | block_size=model.config.block_size, # type: ignore[union-attr,arg-type] 111 | ) 112 | logits = model(input_ids) 113 | loss = torch.nn.functional.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1) 114 | 115 | fabric.backward(loss) 116 | 117 | # TODO: Gradient clipping 118 | # if grad_clip != 0.0: 119 | # fabric.clip_gradients(model, optimizer, max_norm=grad_clip) 120 | 121 | optimizer.step() 122 | optimizer.zero_grad() 123 | 124 | dt = time.time() - t0 125 | if iter_num % log_interval == 0: 126 | fabric.print(f"iter {iter_num}: loss {loss.item():.4f}, time: {dt*1000:.2f}ms") 127 | iter_num += 1 128 | 129 | if iter_num > max_iters: 130 | break 131 | 132 | 133 | @torch.no_grad() 134 | def validate(fabric: L.Fabric, model: torch.nn.Module, val_data: np.ndarray) -> torch.Tensor: 135 | fabric.print("Validating ...") 136 | model.eval() 137 | losses = torch.zeros(eval_iters) 138 | for k in range(eval_iters): 139 | input_ids, targets = get_batch( 140 | fabric, 141 | val_data, 142 | block_size=model.config.block_size, # type: ignore[union-attr,arg-type] 143 | ) 144 | logits = model(input_ids) 145 | loss = torch.nn.functional.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1) 146 | losses[k] = loss.item() 147 | out = losses.mean() 148 | model.train() 149 | return out 150 | 151 | 152 | def get_batch(fabric: L.Fabric, data: np.ndarray, block_size: int) -> Tuple[torch.Tensor, torch.Tensor]: 153 | ix = torch.randint(len(data) - block_size, (batch_size,)) 154 | x = torch.stack([torch.from_numpy((data[i : i + block_size]).astype(np.int64)) for i in ix]) 155 | y = torch.stack([torch.from_numpy((data[i + 1 : i + 1 + block_size]).astype(np.int64)) for i in ix]) 156 | x, y = fabric.to_device((x.pin_memory(), y.pin_memory())) 157 | return x, y 158 | 159 | 160 | def load_datasets(data_dir: str = "data/shakespeare") -> Tuple[np.ndarray, np.ndarray]: 161 | train_data = np.memmap(os.path.join(data_dir, "train.bin"), dtype=np.uint16, mode="r") 162 | val_data = np.memmap(os.path.join(data_dir, "val.bin"), dtype=np.uint16, mode="r") 163 | return train_data, val_data 164 | 165 | 166 | if __name__ == "__main__": 167 | torch.set_float32_matmul_precision("high") 168 | main() 169 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools", "setuptools-scm"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "lit-llama" 7 | version = "0.1.0" 8 | description = "Implementation of the LLaMA language model" 9 | license = {text = "Apache-2.0"} 10 | authors = [ 11 | { name = "Lightning AI", email = "community@lightning.ai" } 12 | ] 13 | readme = "README.md" 14 | requires-python = ">=3.10" 15 | dependencies = [ 16 | "torch>=2.0.0", 17 | "lightning @ 
git+https://github.com/Lightning-AI/lightning@master", 18 | "sentencepiece", 19 | "bitsandbytes", 20 | ] 21 | classifiers = [ 22 | "Topic :: Text Processing" 23 | ] 24 | 25 | [project.optional-dependencies] 26 | all = [ 27 | "tqdm", # convert_checkpoint.py 28 | "numpy <2.0", # train.py dataset memmap 29 | "jsonargparse[signatures]", # generate.py, convert_checkpoint.py CLI 30 | "datasets", # evaluate.py 31 | "zstandard", # prepare_redpajama.py" 32 | ] 33 | 34 | [tool.setuptools.packages.find] 35 | where = ["."] # list of folders that contain the packages (["."] by default) 36 | include = ["lit_llama"] # package names should match these glob patterns (["*"] by default) 37 | exclude = [] # exclude packages matching these glob patterns (empty by default) 38 | namespaces = false # to disable scanning PEP 420 namespaces (true by default) 39 | 40 | 41 | [tool.pytest.ini_options] 42 | addopts = [ 43 | "--strict-markers", 44 | "--color=yes", 45 | "--disable-pytest-warnings", 46 | ] 47 | -------------------------------------------------------------------------------- /quantize/gptq.py: -------------------------------------------------------------------------------- 1 | # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. 2 | 3 | # This adapts GPTQ's quantization process: https://github.com/IST-DASLab/gptq/ 4 | # E. Frantar et al GPTQ: Accurate Post-training Compression for GPT, arXiv:2210.17323 5 | # portions copyright by the authors licensed under the Apache License 2.0 6 | import gc 7 | import sys 8 | import time 9 | from pathlib import Path 10 | from typing import Optional 11 | 12 | import torch 13 | from datasets import load_dataset 14 | 15 | # support running without installing as a package 16 | wd = Path(__file__).parent.parent.resolve() 17 | sys.path.append(str(wd)) 18 | 19 | from lit_llama import LLaMA, Tokenizer 20 | from lit_llama.quantization import GPTQQuantizer 21 | from lit_llama.utils import EmptyInitOnDevice, llama_model_lookup 22 | 23 | 24 | def get_sample_data(): 25 | traindata = load_dataset( 26 | "allenai/c4", 27 | "allenai--c4", 28 | data_files={"train": "en/c4-train.00000-of-01024.json.gz"}, 29 | split="train", 30 | ) 31 | # heuristic for the data size? 32 | txt = "\n".join( 33 | traindata[i]["text"] for i in torch.randperm(len(traindata))[:1000].tolist() 34 | ) 35 | return txt 36 | 37 | 38 | @torch.no_grad() 39 | def llama_blockwise_quantization( 40 | model, sample_inputs, working_device, *, bits=4, groupsize=-1 41 | ): 42 | """ 43 | This is the classic post-training quantization of all linear layers. 44 | We quantize in order, i.e. when observing the inputs, we use the outputs of the previously quantized layers rather 45 | than doing them all at once. 46 | """ 47 | print(model) 48 | print(model.config) 49 | 50 | print("Getting inputs for first block") 51 | model.transformer.wte.to(working_device) 52 | sample_inputs = sample_inputs.to(working_device) 53 | inps = model.transformer.wte(sample_inputs) 54 | model.transformer.wte.to("cpu") 55 | torch.cuda.empty_cache() 56 | 57 | rope_cache = model.build_rope_cache(sample_inputs) 58 | mask_cache = model.build_mask_cache(sample_inputs) 59 | 60 | print("Starting to quantize blocks") 61 | outs = torch.zeros_like(inps) 62 | 63 | # better than relying on enumeration? 
originally the code bundled 64 | # the two mlp fc layers 65 | # we could automate this with a lot of hooks and another iteration 66 | submodules_to_process = [ 67 | "attn.c_attn", 68 | "attn.c_proj", 69 | "mlp.c_fc1", 70 | "mlp.c_fc2", 71 | "mlp.c_proj", 72 | ] 73 | 74 | for i, block in enumerate(model.transformer.h): 75 | block.to(working_device) 76 | 77 | for name in submodules_to_process: 78 | print(i, name, end=" ") 79 | t0 = time.perf_counter() 80 | print("collecting stats", end=" ") 81 | sys.stdout.flush() 82 | module = block.get_submodule(name) 83 | 84 | gptq = GPTQQuantizer( 85 | module, 86 | bits=bits, 87 | groupsize=groupsize, 88 | actorder=(groupsize == -1), 89 | ) 90 | handle = module.register_forward_hook(gptq.collect_input_stats) 91 | for j in range(inps.size(0)): 92 | outs[j : j + 1], _ = block( 93 | inps[j : j + 1], 94 | rope=rope_cache, 95 | mask=mask_cache, 96 | max_seq_length=model.config.block_size 97 | ) 98 | 99 | handle.remove() 100 | 101 | print("quantizing", end=" ") 102 | sys.stdout.flush() 103 | q_module, error = gptq.quantize() 104 | 105 | # replace the linear module with the quantized module 106 | pname, dname = name.rsplit(".", 1) 107 | setattr(block.get_submodule(pname), dname, q_module) 108 | 109 | # cleanup in an attempt to not run out of memory 110 | del gptq 111 | gc.collect() 112 | torch.cuda.empty_cache() 113 | t1 = time.perf_counter() 114 | print(f"time {int(t1 - t0 + 0.5)}s quantization error {error:.1f}") 115 | 116 | for j in range(inps.size(0)): 117 | outs[j : j + 1], _ = block( 118 | inps[j : j + 1], 119 | rope=rope_cache, 120 | mask=mask_cache, 121 | max_seq_length=model.config.block_size 122 | ) 123 | 124 | block.cpu() 125 | gc.collect() 126 | torch.cuda.empty_cache() 127 | 128 | # the outputs are the next block's inputs and we'll reuse the old inputs 129 | inps, outs = outs, inps 130 | 131 | model.transformer.ln_f.to(working_device) 132 | for j in range(inps.size(0)): 133 | outs[j : j + 1] = model.transformer.ln_f(inps[j : j + 1]) 134 | model.transformer.ln_f.to("cpu") 135 | inps, outs = outs, inps 136 | 137 | model.lm_head.to(working_device) 138 | gptq = GPTQQuantizer( 139 | model.lm_head, 140 | bits=bits, 141 | groupsize=groupsize, 142 | actorder=(groupsize == -1), 143 | ) 144 | handle = model.lm_head.register_forward_hook(gptq.collect_input_stats) 145 | for j in range(inps.size(0)): 146 | model.lm_head(inps[j : j + 1]) 147 | handle.remove() 148 | q_module, error = gptq.quantize() 149 | model.lm_head = q_module 150 | model.lm_head.to("cpu") 151 | 152 | 153 | def main( 154 | *, 155 | checkpoint_path: Path = Path("checkpoints/lit-llama/7B/lit-llama.pth"), 156 | output_path: Optional[Path] = None, 157 | tokenizer_path: Path = Path("checkpoints/lit-llama/tokenizer.model"), 158 | n_samples: int = 128, 159 | dtype: str = "float32", 160 | quantize: Optional[str] = None, 161 | ) -> None: 162 | """Generates text samples based on a pre-trained LLaMA model and tokenizer. 163 | 164 | Args: 165 | checkpoint_path: The checkpoint path to load. 166 | output_path: Path to write the quantized model's state dict to. 167 | tokenizer_path: The tokenizer path to load. 168 | n_samples: Number of example inputs to use for statistics (default: 128) 169 | dtype: The dtype to use to load the model. 170 | quantize: Mode to quantize the model to: 171 | ``"gptq.int4"``: GPTQ 4-bit mode. 172 | Note that ``"llm.int8"```does not need a quantization step. 
173 | """ 174 | assert checkpoint_path.is_file() 175 | assert tokenizer_path.is_file() 176 | if output_path is None: 177 | output_path = checkpoint_path.parent / "llama-gptq.4bit.pth" 178 | assert output_path.parent.is_dir() and (not output_path.exists() or output_path.is_file()) 179 | 180 | device = "cuda" 181 | 182 | dt = getattr(torch, dtype, None) 183 | if not isinstance(dt, torch.dtype): 184 | raise ValueError(f"{dtype} is not a valid dtype.") 185 | dtype = dt 186 | 187 | if quantize == "gptq.int4": 188 | bits = 4 189 | elif quantize == "gptq.int8": 190 | bits = 8 191 | else: 192 | raise RuntimeError(f"unknown/unsupported quantization mode {quantize}") 193 | 194 | # we avoid loading the entire model on the GPU and do this block by block 195 | with EmptyInitOnDevice( 196 | device="cpu", 197 | dtype=dtype, 198 | ): 199 | print("Loading model ...", file=sys.stderr) 200 | t0 = time.time() 201 | checkpoint = torch.load(checkpoint_path) 202 | name = llama_model_lookup(checkpoint) 203 | model = LLaMA.from_name(name) 204 | model.load_state_dict(checkpoint) 205 | print(f"Time to load model: {time.time() - t0:.02f} seconds.", file=sys.stderr) 206 | 207 | model.eval() 208 | 209 | tokenizer = Tokenizer(tokenizer_path) 210 | 211 | test_string = get_sample_data() 212 | encoded_text = tokenizer.encode( 213 | test_string, 214 | bos=True, 215 | eos=False, 216 | ) 217 | block_size = 2048 # this is for compat with gptq, and indeed we get much worse beyond this (https://github.com/facebookresearch/llama/blob/57b0eb62de0636e75af471e49e2f1862d908d9d8/llama/model.py#L30) 218 | encoded_text = encoded_text[: n_samples * block_size].reshape(n_samples, block_size) 219 | 220 | t0 = time.perf_counter() 221 | llama_blockwise_quantization(model, encoded_text, device, bits=bits) 222 | t = time.perf_counter() - t0 223 | 224 | print( 225 | f"\n\nTime for quantization: {t:.02f} sec total", 226 | file=sys.stderr, 227 | ) 228 | print( 229 | f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB", 230 | file=sys.stderr, 231 | ) 232 | 233 | torch.save(model.state_dict(), output_path) 234 | 235 | 236 | if __name__ == "__main__": 237 | from jsonargparse import CLI 238 | 239 | torch.set_float32_matmul_precision("high") 240 | CLI(main) 241 | -------------------------------------------------------------------------------- /scripts/convert_checkpoint.py: -------------------------------------------------------------------------------- 1 | # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. 
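# Illustrative sketch (toy sizes, assumptions for illustration only): `convert_state_dict`
# below stacks the Meta checkpoint's separate wq/wk/wv attention matrices into a single
# `attn.c_attn.weight` by concatenating them along dim 0.
def _qkv_stacking_sketch() -> None:
    import torch

    n_embd = 8  # toy embedding size
    wq, wk, wv = (torch.randn(n_embd, n_embd) for _ in range(3))
    c_attn = torch.cat((wq, wk, wv))  # same stacking as in convert_state_dict
    assert c_attn.shape == (3 * n_embd, n_embd)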
2 | 3 | import gc 4 | import shutil 5 | from pathlib import Path 6 | from typing import Dict 7 | 8 | import torch 9 | from tqdm import tqdm 10 | 11 | """ 12 | Sample usage: 13 | 14 | ```bash 15 | python -m scripts.convert_checkpoint -h 16 | 17 | python -m scripts.convert_checkpoint converted 18 | ``` 19 | """ 20 | 21 | 22 | def convert_state_dict(state_dict: Dict[str, torch.Tensor], dtype: torch.dtype = torch.float32) -> Dict[str, torch.Tensor]: 23 | converted = {} 24 | converted["transformer.wte.weight"] = state_dict["tok_embeddings.weight"].to(dtype) 25 | converted["lm_head.weight"] = state_dict["output.weight"].to(dtype) 26 | converted["transformer.ln_f.scale"] = state_dict["norm.weight"].to(dtype) 27 | 28 | for layer_idx in sorted(set([k.split(".")[1] for k in state_dict if k.startswith("layers")])): 29 | # attention 30 | # the wq, wk, wv from the FB model are stacked in our model as c_attn 31 | converted[f"transformer.h.{layer_idx}.attn.c_attn.weight"] = torch.cat( 32 | ( 33 | state_dict[f"layers.{layer_idx}.attention.wq.weight"].to(dtype), 34 | state_dict[f"layers.{layer_idx}.attention.wk.weight"].to(dtype), 35 | state_dict[f"layers.{layer_idx}.attention.wv.weight"].to(dtype), 36 | ) 37 | ) 38 | converted[f"transformer.h.{layer_idx}.attn.c_proj.weight"] = state_dict[ 39 | f"layers.{layer_idx}.attention.wo.weight" 40 | ].to(dtype) 41 | # mlp 42 | converted[f"transformer.h.{layer_idx}.mlp.c_fc1.weight"] = state_dict[ 43 | f"layers.{layer_idx}.feed_forward.w1.weight" 44 | ].to(dtype) 45 | converted[f"transformer.h.{layer_idx}.mlp.c_proj.weight"] = state_dict[ 46 | f"layers.{layer_idx}.feed_forward.w2.weight" 47 | ].to(dtype) 48 | converted[f"transformer.h.{layer_idx}.mlp.c_fc2.weight"] = state_dict[ 49 | f"layers.{layer_idx}.feed_forward.w3.weight" 50 | ].to(dtype) 51 | # rms norm 52 | converted[f"transformer.h.{layer_idx}.rms_1.scale"] = state_dict[f"layers.{layer_idx}.attention_norm.weight"].to(dtype) 53 | converted[f"transformer.h.{layer_idx}.rms_2.scale"] = state_dict[f"layers.{layer_idx}.ffn_norm.weight"].to(dtype) 54 | return converted 55 | 56 | 57 | shard_dims = { 58 | "lm_head.weight": 0, 59 | "wte.weight": 1, 60 | "attn.c_attn.weight": 0, 61 | "attn.c_proj.weight": 1, 62 | "mlp.c_fc1.weight": 0, 63 | "mlp.c_fc2.weight": 0, 64 | "mlp.c_proj.weight": 1 65 | } 66 | 67 | 68 | def meta_weights_for_nano_model( 69 | *, 70 | output_dir: Path = Path("checkpoints/lit-llama"), 71 | checkpoint_dir: Path = Path("checkpoints/llama/"), 72 | model_size: str = "7B", 73 | dtype: str = "float32", 74 | ) -> None: 75 | output_dir = output_dir / model_size 76 | checkpoint_dir = checkpoint_dir / model_size 77 | output_dir.mkdir(parents=True, exist_ok=True) 78 | 79 | # the tokenizer is the same for all model sizes, so we store it in the parent dir 80 | shutil.copy(checkpoint_dir.parent / "tokenizer.model", output_dir.parent) 81 | 82 | dt = getattr(torch, dtype, None) 83 | if not isinstance(dt, torch.dtype): 84 | raise ValueError(f"{dtype} is not a valid dtype.") 85 | dtype = dt 86 | 87 | checkpoint_files = sorted(checkpoint_dir.glob("*.pth")) 88 | checkpoint_files.sort() 89 | n_checkpoints = len(checkpoint_files) 90 | 91 | if n_checkpoints == 0: 92 | raise RuntimeError(f"No checkpoints were found at checkpoint_dir {checkpoint_dir}. 
`consolidated.0*.pth` files expected at that location.") 93 | 94 | # for the bigger models, there are multiple model-parallel checkpoints 95 | # and we combine them into one single file 96 | combined = None 97 | for file in tqdm(checkpoint_files, total=n_checkpoints): 98 | checkpoint = torch.load(file, map_location="cpu") 99 | converted = convert_state_dict(checkpoint, dtype=dtype) 100 | if combined is None: 101 | combined = converted 102 | continue 103 | for name, param in converted.items(): 104 | dim = None 105 | for k, d in shard_dims.items(): 106 | if k in name: 107 | dim = d 108 | break 109 | if dim is None: 110 | # Extra check: assert that tensors are the same if not sharded 111 | # assert torch.allclose(combined[name], param) 112 | continue 113 | combined[name] = torch.cat((combined[name], param), dim=dim) 114 | 115 | del checkpoint 116 | del converted 117 | gc.collect() 118 | 119 | for name, param in combined.items(): 120 | if "c_attn" not in name: 121 | continue 122 | 123 | # Turn [Q1, K1, V1, Q2, K2, V2, ...] into [Q1, Q2, ..., K1, K2, .., V1, V2, ...] 124 | 125 | src_chunk_len = param.shape[0] // n_checkpoints 126 | mat_len = src_chunk_len // 3 127 | dst_chunk_len = mat_len * n_checkpoints 128 | attn = torch.clone(param) 129 | for i in range(n_checkpoints): 130 | for j in range(3): 131 | param[j * dst_chunk_len + i * mat_len: j * dst_chunk_len + (i+1) * mat_len] = \ 132 | attn[i * src_chunk_len + j * mat_len: i * src_chunk_len + (j+1) * mat_len] 133 | 134 | del attn 135 | gc.collect() 136 | 137 | torch.save(combined, output_dir / "lit-llama.pth") 138 | 139 | 140 | if __name__ == "__main__": 141 | from jsonargparse import CLI 142 | 143 | CLI(meta_weights_for_nano_model) 144 | -------------------------------------------------------------------------------- /scripts/convert_hf_checkpoint.py: -------------------------------------------------------------------------------- 1 | # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. 
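# Illustrative sketch (toy sizes, assumptions for illustration only) of the QKV
# re-ordering done at the end of scripts/convert_checkpoint.py above: after the
# model-parallel shards are concatenated, the attention rows are laid out as
# [Q1, K1, V1, Q2, K2, V2, ...] and are regrouped into [Q1, Q2, ..., K1, K2, ..., V1, V2, ...].
def _qkv_reorder_sketch() -> None:
    import torch

    n_checkpoints, mat_len = 2, 3          # 2 shards, 3 rows per Q/K/V block
    src_chunk_len = 3 * mat_len            # rows contributed by one shard
    dst_chunk_len = mat_len * n_checkpoints
    param = torch.arange(n_checkpoints * src_chunk_len)  # row indices stand in for weights
    attn = param.clone()
    for i in range(n_checkpoints):
        for j in range(3):  # 0: Q, 1: K, 2: V
            param[j * dst_chunk_len + i * mat_len : j * dst_chunk_len + (i + 1) * mat_len] = \
                attn[i * src_chunk_len + j * mat_len : i * src_chunk_len + (j + 1) * mat_len]
    assert param.tolist() == [0, 1, 2, 9, 10, 11, 3, 4, 5, 12, 13, 14, 6, 7, 8, 15, 16, 17]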
2 | 3 | import collections 4 | import contextlib 5 | import gc 6 | import json 7 | import shutil 8 | import sys 9 | from pathlib import Path 10 | 11 | import torch 12 | 13 | # support running without installing as a package 14 | wd = Path(__file__).parent.parent.resolve() 15 | sys.path.append(str(wd)) 16 | 17 | from lit_llama.model import LLaMA, LLaMAConfig 18 | from lit_llama.utils import EmptyInitOnDevice, lazy_load, incremental_save 19 | 20 | 21 | @torch.no_grad() 22 | def convert_hf_checkpoint( 23 | *, 24 | output_dir: Path = Path("checkpoints/lit-llama/7B"), 25 | checkpoint_dir: Path = Path("checkpoints/hf-llama/7B"), 26 | model_size: str = "7B", 27 | dtype: str = "float32", 28 | verify: bool = False, 29 | ) -> None: 30 | """ 31 | Perform the reverse operation of: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/convert_llama_weights_to_hf.py 32 | """ 33 | output_dir.mkdir(parents=True, exist_ok=True) 34 | 35 | # the tokenizer is the same for all model sizes, so we store it in the parent dir 36 | shutil.copy(checkpoint_dir / "tokenizer.model", output_dir.parent) 37 | 38 | dt = getattr(torch, dtype, None) 39 | if not isinstance(dt, torch.dtype): 40 | raise ValueError(f"{dtype} is not a valid dtype.") 41 | dtype = dt 42 | 43 | print("Initializing lit-llama") 44 | config = LLaMAConfig.from_name(model_size) 45 | 46 | with EmptyInitOnDevice(device="meta", dtype=dtype): 47 | model = LLaMA(config) 48 | 49 | qkv_size = model.transformer.h[0].attn.c_attn.weight.shape[0] // 3 50 | 51 | # initialize a new empty state dict to hold our new weights 52 | sd_meta = model.state_dict() 53 | sd = {} 54 | 55 | # Load the json file containing weight mapping 56 | pytorch_bin_map_json_path = checkpoint_dir / "pytorch_model.bin.index.json" 57 | with open(pytorch_bin_map_json_path) as json_map: 58 | bin_index = json.load(json_map) 59 | bin_files = set(checkpoint_dir / bin for bin in bin_index["weight_map"].values()) 60 | if not bin_files: 61 | raise ValueError(f"Expected {str(checkpoint_dir)!r} to contain .bin files") 62 | 63 | def permute(w): 64 | dim = config.n_embd 65 | w = w._load_tensor().to(dtype) 66 | return ( 67 | w.view(config.n_head, 2, dim // config.n_head // 2, dim) 68 | .transpose(1, 2) 69 | .reshape(dim, dim) 70 | ) 71 | 72 | weight_map = { 73 | "self_attn.o_proj.weight": "attn.c_proj.weight", 74 | "self_attn.q_proj.weight": "attn.c_attn.weight", 75 | "self_attn.k_proj.weight": "attn.c_attn.weight", 76 | "self_attn.v_proj.weight": "attn.c_attn.weight", 77 | "mlp.gate_proj.weight": "mlp.c_fc1.weight", 78 | "mlp.up_proj.weight": "mlp.c_fc2.weight", 79 | "mlp.down_proj.weight": "mlp.c_proj.weight", 80 | "input_layernorm.weight": "rms_1.scale", 81 | "post_attention_layernorm.weight": "rms_2.scale", 82 | "model.embed_tokens.weight": "transformer.wte.weight", 83 | "model.norm.weight": "transformer.ln_f.scale", 84 | "lm_head.weight": "lm_head.weight", 85 | } 86 | 87 | print(f"Saving to disk at {output_dir}") 88 | unprocessed_weights = collections.defaultdict(dict) 89 | 90 | with incremental_save(output_dir / "lit-llama.pth") as saver: 91 | # for checkpoints that split the QKV across several files, we need to keep all the bin files 92 | # open, so we use `ExitStack` to close them all together at the end 93 | with contextlib.ExitStack() as stack: 94 | for bin_file in bin_files: 95 | print("Processing", bin_file) 96 | hf_weights = stack.enter_context(lazy_load(bin_file)) 97 | for name, param in hf_weights.items(): 98 | skip = False 99 | if "rotary_emb.inv_freq" in name: 
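                        # HF checkpoints store precomputed rotary-embedding frequencies;
                        # lit-llama rebuilds its own RoPE cache via build_rope_cache, so these buffers are skipped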
100 | continue 101 | if "model.layers" in name: 102 | block_id = int(name.split(".")[2]) 103 | from_name = ".".join(name.split(".")[3:]) 104 | to_name = weight_map[from_name] 105 | sd_key = f"transformer.h.{block_id}.{to_name}" 106 | 107 | if "q_proj" in name: 108 | unprocessed_weights[sd_key]["q_proj"] = param 109 | skip = True 110 | elif "k_proj" in name: 111 | unprocessed_weights[sd_key]["k_proj"] = param 112 | skip = True 113 | elif "v_proj" in name: 114 | unprocessed_weights[sd_key]["v_proj"] = param 115 | skip = True 116 | if skip and len(unprocessed_weights[sd_key]) == 3: 117 | w = torch.empty( 118 | sd_meta[sd_key].shape, dtype=sd_meta[sd_key].dtype 119 | ) 120 | w[:qkv_size] = permute(unprocessed_weights[sd_key]["q_proj"]) 121 | w[qkv_size:-qkv_size] = permute( 122 | unprocessed_weights[sd_key]["k_proj"] 123 | ) 124 | w[-qkv_size:] = ( 125 | unprocessed_weights[sd_key]["v_proj"] 126 | ._load_tensor() 127 | .to(dtype) 128 | ) 129 | sd[sd_key] = w 130 | del unprocessed_weights[sd_key] 131 | skip = False 132 | else: 133 | sd[sd_key] = param._load_tensor().to(dtype) 134 | else: 135 | sd_key = weight_map[name] 136 | sd[sd_key] = param._load_tensor().to(dtype) 137 | if not skip: 138 | sd[sd_key] = saver.store_early(sd[sd_key]) 139 | gc.collect() 140 | saver.save(sd) 141 | 142 | assert len(unprocessed_weights) == 0, f"unexpected partial weights {list(unprocessed_weights)}" 143 | if verify: 144 | try: 145 | from transformers import LlamaForCausalLM 146 | except ImportError: 147 | raise ImportError("verify=True requires transformers to be installed, please `pip install transformers`") 148 | print("Verifying...") 149 | 150 | token_sample = torch.randint(0, config.vocab_size, size=(1, config.block_size), dtype=torch.int64) 151 | out = model(token_sample) 152 | del model 153 | gc.collect() 154 | 155 | print("Loading original model for comparison") 156 | model_hf = LlamaForCausalLM.from_pretrained(checkpoint_dir) 157 | out_hf = model_hf(token_sample)["logits"] 158 | 159 | print("Comparing outputs") 160 | assert out.device.type == out_hf.device.type 161 | assert out.dtype == out_hf.dtype 162 | assert torch.testing.assert_close(out, out_hf) 163 | 164 | 165 | if __name__ == "__main__": 166 | from jsonargparse import CLI 167 | 168 | CLI(convert_hf_checkpoint) 169 | 170 | -------------------------------------------------------------------------------- /scripts/convert_lora_weights.py: -------------------------------------------------------------------------------- 1 | # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. 2 | 3 | import sys 4 | import time 5 | from pathlib import Path 6 | from typing import Optional 7 | 8 | import lightning as L 9 | import torch 10 | import torch.nn as nn 11 | 12 | # support running without installing as a package 13 | wd = Path(__file__).parent.parent.resolve() 14 | sys.path.append(str(wd)) 15 | 16 | from lit_llama import LLaMA 17 | from lit_llama.utils import EmptyInitOnDevice, lazy_load, llama_model_lookup 18 | from lit_llama.lora import lora 19 | 20 | def del_lora_state_dict(model: nn.Module): 21 | base_model_dict = model.state_dict() 22 | key_to_delete = [k for k in base_model_dict if "lora_" in k] 23 | for del_key in key_to_delete: 24 | del base_model_dict[del_key] 25 | return base_model_dict 26 | 27 | 28 | def lora_model_lookup(checkpoint: dict) -> int: 29 | """Returns the LoRA rank from the adapter checkpoint. 
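    The rank corresponds to the second dimension of the `lora_B` weight matrices saved in the LoRA checkpoint.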
30 | 31 | """ 32 | return checkpoint["transformer.h.0.attn.c_attn.lora_B"].shape[1] 33 | 34 | 35 | def main( 36 | accelerator: str = "auto", 37 | lora_path: Optional[Path] = None, 38 | checkpoint_path: Optional[Path] = None, 39 | dtype: str = "bfloat16", 40 | ) -> None: 41 | """Merges lora weights to base model. 42 | 43 | Args: 44 | accelerator: The hardware to run on. Possible choices are: 45 | ``"cpu"``, ``"cuda"``, ``"mps"``, ``"gpu"``, ``"tpu"``, ``"auto"``. 46 | lora_path: Path to the checkpoint with trained LoRA weights, which are the output of 47 | `finetune_lora.py`. 48 | checkpoint_path: The checkpoint path to load. 49 | dtype: `torch.dtype` to work with 50 | """ 51 | if not lora_path: 52 | lora_path = Path("out/lora/alpaca/lit-llama-lora-finetuned.pth") 53 | if not checkpoint_path: 54 | checkpoint_path = Path(f"./checkpoints/lit-llama/7B/lit-llama.pth") 55 | 56 | assert lora_path.is_file() 57 | assert checkpoint_path.is_file() 58 | 59 | fabric = L.Fabric(accelerator=accelerator, devices=1) 60 | 61 | dt = getattr(torch, dtype, None) 62 | if not isinstance(dt, torch.dtype): 63 | raise ValueError(f"{dtype} is not a valid dtype.") 64 | dtype = dt 65 | 66 | print("Loading model ...", file=sys.stderr) 67 | t0 = time.time() 68 | 69 | with (lazy_load(checkpoint_path) as pretrained_checkpoint, 70 | lazy_load(lora_path) as lora_checkpoint): 71 | name = llama_model_lookup(pretrained_checkpoint) 72 | rank = lora_model_lookup(lora_checkpoint) 73 | 74 | with EmptyInitOnDevice( 75 | device=fabric.device, dtype=dtype 76 | ), lora(r=rank, alpha=16, dropout=0.05, enabled=True): 77 | model = LLaMA.from_name(name) 78 | 79 | # 1. Load the pretrained weights 80 | model.load_state_dict(pretrained_checkpoint, strict=False) 81 | # 2. Load the fine-tuned lora weights 82 | model.load_state_dict(lora_checkpoint, strict=False) 83 | 84 | print(f"Time to load model: {time.time() - t0:.02f} seconds.", file=sys.stderr) 85 | 86 | model.eval() 87 | base_model_dict = del_lora_state_dict(model) 88 | save_path = lora_path.with_stem(f"{lora_path.stem}-lora-merged-weights") 89 | print("Saving LoRA to base model weights ...") 90 | torch.save(base_model_dict, save_path) 91 | print(f"Model saved at {save_path}") 92 | 93 | 94 | if __name__ == "__main__": 95 | from jsonargparse import CLI 96 | 97 | CLI(main) -------------------------------------------------------------------------------- /scripts/download.py: -------------------------------------------------------------------------------- 1 | # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. 2 | 3 | import os 4 | from typing import Optional 5 | from urllib.request import urlretrieve 6 | 7 | files = { 8 | "original_model.py": "https://gist.githubusercontent.com/lantiga/fd36849fb1c498da949a0af635318a7b/raw/7dd20f51c2a1ff2886387f0e25c1750a485a08e1/llama_model.py", 9 | "original_adapter.py": "https://gist.githubusercontent.com/awaelchli/546f33fcdb84cc9f1b661ca1ca18418d/raw/e81d8f35fb1fec53af1099349b0c455fc8c9fb01/original_adapter.py", 10 | } 11 | 12 | 13 | def download_original(wd: str) -> None: 14 | for file, url in files.items(): 15 | filepath = os.path.join(wd, file) 16 | if not os.path.isfile(filepath): 17 | print(f"Downloading original implementation to {filepath!r}") 18 | urlretrieve(url=url, filename=file) 19 | print("Done") 20 | else: 21 | print("Original implementation found. 
Skipping download.") 22 | 23 | 24 | def download_from_hub(repo_id: Optional[str] = None, local_dir: str = "checkpoints/hf-llama/7B") -> None: 25 | if repo_id is None: 26 | raise ValueError("Please pass `--repo_id=...`. You can try googling 'huggingface hub llama' for options.") 27 | 28 | from huggingface_hub import snapshot_download 29 | 30 | snapshot_download(repo_id, local_dir=local_dir, local_dir_use_symlinks=False) 31 | 32 | 33 | if __name__ == "__main__": 34 | from jsonargparse import CLI 35 | 36 | CLI(download_from_hub) 37 | -------------------------------------------------------------------------------- /scripts/prepare_alpaca.py: -------------------------------------------------------------------------------- 1 | # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. 2 | 3 | """Implementation derived from https://github.com/tloen/alpaca-lora""" 4 | import sys 5 | from pathlib import Path 6 | 7 | # support running without installing as a package 8 | wd = Path(__file__).parent.parent.resolve() 9 | sys.path.append(str(wd)) 10 | 11 | import torch 12 | import requests 13 | import json 14 | from torch.utils.data import random_split 15 | from lit_llama.tokenizer import Tokenizer 16 | from tqdm import tqdm 17 | 18 | 19 | DATA_FILE = "https://raw.githubusercontent.com/tloen/alpaca-lora/main/alpaca_data_cleaned_archive.json" 20 | DATA_FILE_NAME = "alpaca_data_cleaned_archive.json" 21 | IGNORE_INDEX = -1 22 | 23 | 24 | def prepare( 25 | destination_path: Path = Path("data/alpaca"), 26 | tokenizer_path: Path = Path("checkpoints/lit-llama/tokenizer.model"), 27 | test_split_size: int = 2000, 28 | max_seq_length: int = 256, 29 | seed: int = 42, 30 | mask_inputs: bool = False, # as in alpaca-lora 31 | data_file_name: str = DATA_FILE_NAME 32 | ) -> None: 33 | """Prepare the Alpaca dataset for instruction tuning. 34 | 35 | The output is a training and validation dataset saved as `train.pt` and `val.pt`, 36 | which stores the preprocessed and tokenized prompts and labels. 37 | """ 38 | 39 | destination_path.mkdir(parents=True, exist_ok=True) 40 | file_path = destination_path / data_file_name 41 | download(file_path) 42 | 43 | # TODO: If we don't have the Meta weights, where do we get the tokenizer from? 
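    # In the current setup, `tokenizer.model` is copied next to the converted checkpoints by
    # scripts/convert_checkpoint.py or scripts/convert_hf_checkpoint.py, which is where the
    # default `tokenizer_path` points.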
44 | tokenizer = Tokenizer(tokenizer_path) 45 | 46 | with open(file_path, "r") as file: 47 | data = json.load(file) 48 | 49 | # Partition the dataset into train and test 50 | train_split_size = len(data) - test_split_size 51 | train_set, test_set = random_split( 52 | data, 53 | lengths=(train_split_size, test_split_size), 54 | generator=torch.Generator().manual_seed(seed), 55 | ) 56 | train_set, test_set = list(train_set), list(test_set) 57 | 58 | print(f"train has {len(train_set):,} samples") 59 | print(f"val has {len(test_set):,} samples") 60 | 61 | print("Processing train split ...") 62 | train_set = [prepare_sample(sample, tokenizer, max_seq_length, mask_inputs) for sample in tqdm(train_set)] 63 | torch.save(train_set, file_path.parent / "train.pt") 64 | 65 | print("Processing test split ...") 66 | test_set = [prepare_sample(sample, tokenizer, max_seq_length, mask_inputs) for sample in tqdm(test_set)] 67 | torch.save(test_set, file_path.parent / "test.pt") 68 | 69 | 70 | def download(file_path: Path): 71 | """Downloads the raw json data file and saves it in the given destination.""" 72 | if file_path.exists(): 73 | return 74 | with open(file_path, "w") as f: 75 | f.write(requests.get(DATA_FILE).text) 76 | 77 | 78 | def prepare_sample(example: dict, tokenizer: Tokenizer, max_length: int, mask_inputs: bool = True): 79 | """Processes a single sample. 80 | 81 | Each sample in the dataset consists of: 82 | - instruction: A string describing the task 83 | - input: A string holding a special input value for the instruction. 84 | This only applies to some samples, and in others this is empty. 85 | - output: The response string 86 | 87 | This function processes this data to produce a prompt text and a label for 88 | supervised training. The input text is formed as a single message including all 89 | the instruction, the input (optional) and the response. 90 | The label/target is the same message but can optionally have the instruction + input text 91 | masked out (mask_inputs=True). 92 | 93 | Finally, both the prompt and the label get tokenized. If desired, all tokens 94 | in the label that correspond to the original input prompt get masked out (default). 95 | """ 96 | full_prompt = generate_prompt(example) 97 | full_prompt_and_response = full_prompt + example["output"] 98 | encoded_full_prompt = tokenize(tokenizer, full_prompt, max_length=max_length, eos=False) 99 | encoded_full_prompt_and_response = tokenize(tokenizer, full_prompt_and_response, eos=True, max_length=max_length) 100 | 101 | # The labels are the full prompt with response, but with the prompt masked out 102 | labels = encoded_full_prompt_and_response.clone() 103 | if mask_inputs: 104 | labels[:len(encoded_full_prompt)] = IGNORE_INDEX 105 | 106 | return {**example, "input_ids": encoded_full_prompt_and_response, "input_ids_no_response": encoded_full_prompt, "labels": labels} 107 | 108 | 109 | def tokenize(tokenizer: Tokenizer, string: str, max_length: int, eos=True) -> torch.Tensor: 110 | return tokenizer.encode(string, bos=True, eos=eos, max_length=max_length) 111 | 112 | 113 | def generate_prompt(example): 114 | """Generates a standardized message to prompt the model with an instruction, optional input and a 115 | 'response' field.""" 116 | 117 | if example["input"]: 118 | return ( 119 | "Below is an instruction that describes a task, paired with an input that provides further context. 
" 120 | "Write a response that appropriately completes the request.\n\n" 121 | f"### Instruction:\n{example['instruction']}\n\n### Input:\n{example['input']}\n\n### Response:" 122 | ) 123 | return ( 124 | "Below is an instruction that describes a task. " 125 | "Write a response that appropriately completes the request.\n\n" 126 | f"### Instruction:\n{example['instruction']}\n\n### Response:" 127 | ) 128 | 129 | 130 | if __name__ == "__main__": 131 | from jsonargparse import CLI 132 | 133 | CLI(prepare) 134 | -------------------------------------------------------------------------------- /scripts/prepare_any_text.py: -------------------------------------------------------------------------------- 1 | # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. 2 | 3 | """Implementation derived from https://github.com/tloen/alpaca-lora""" 4 | import sys 5 | from pathlib import Path 6 | 7 | # support running without installing as a package 8 | wd = Path(__file__).parent.parent.resolve() 9 | sys.path.append(str(wd)) 10 | 11 | import torch 12 | import requests 13 | import json 14 | from torch.utils.data import random_split 15 | from lit_llama.tokenizer import Tokenizer 16 | from tqdm import tqdm 17 | 18 | 19 | IGNORE_INDEX = -1 20 | 21 | DATA_FILE_NAME = "input.txt" 22 | 23 | 24 | def prepare( 25 | destination_path: Path = Path("data/any"), 26 | tokenizer_path: Path = Path("checkpoints/lit-llama/tokenizer.model"), 27 | test_split_ratio: float = 0.9, # default 90% train, 10% validation 28 | max_seq_length: int = 256, 29 | seed: int = 42, 30 | data_file_name: str = DATA_FILE_NAME, 31 | ) -> None: 32 | """Prepare any dataset for finetuning (akin to Shakespheare full tuning). 33 | 34 | The output is a training and validation dataset saved as `train.pt` and `val.pt`, 35 | which stores the preprocessed and tokenized prompts and labels. 36 | """ 37 | 38 | destination_path.mkdir(parents=True, exist_ok=True) 39 | file_path = destination_path / data_file_name 40 | if not file_path.exists(): 41 | raise AssertionError(f"{data_file_name} is provided by the user") 42 | 43 | # TODO: If we don't have the Meta weights, where do we get the tokenizer from? 44 | tokenizer = Tokenizer(tokenizer_path) 45 | 46 | data = [] 47 | 48 | with open(file_path, "r") as input_file: 49 | for line in input_file.readlines(): 50 | data.append(line) 51 | 52 | # Partition the dataset into train and test 53 | train_split_size = int(len(data) * test_split_ratio) 54 | test_split_size = len(data) - train_split_size 55 | train_set, test_set = random_split( 56 | data, 57 | lengths=(train_split_size, test_split_size), 58 | generator=torch.Generator().manual_seed(seed), 59 | ) 60 | train_set, test_set = list(train_set), list(test_set) 61 | 62 | print(f"train has {len(train_set):,} samples") 63 | print(f"val has {len(test_set):,} samples") 64 | 65 | print("Processing train split ...") 66 | train_set = [ 67 | prepare_line(line, tokenizer, max_seq_length) for line in tqdm(train_set) 68 | ] 69 | torch.save(train_set, file_path.parent / "train.pt") 70 | 71 | print("Processing test split ...") 72 | test_set = [ 73 | prepare_line(line, tokenizer, max_seq_length) for line in tqdm(test_set) 74 | ] 75 | torch.save(test_set, file_path.parent / "test.pt") 76 | 77 | 78 | def prepare_line(line: str, tokenizer: Tokenizer, max_length: int): 79 | """Processes a single sample. 80 | 81 | This function processes the line to produce the tokenized version of it. 
82 | """ 83 | encoded_full_prompt = tokenize(tokenizer, line, max_length=max_length, eos=False) 84 | return { 85 | "input_ids": encoded_full_prompt, 86 | "labels": encoded_full_prompt, 87 | } 88 | 89 | 90 | def tokenize( 91 | tokenizer: Tokenizer, string: str, max_length: int, eos=True 92 | ) -> torch.Tensor: 93 | return tokenizer.encode(string, bos=True, eos=eos, max_length=max_length) 94 | 95 | 96 | if __name__ == "__main__": 97 | from jsonargparse import CLI 98 | 99 | CLI(prepare) 100 | -------------------------------------------------------------------------------- /scripts/prepare_dolly.py: -------------------------------------------------------------------------------- 1 | # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. 2 | 3 | """Implementation derived from https://github.com/tloen/alpaca-lora""" 4 | import sys 5 | from pathlib import Path 6 | 7 | # support running without installing as a package 8 | wd = Path(__file__).parent.parent.resolve() 9 | sys.path.append(str(wd)) 10 | 11 | import torch 12 | import requests 13 | import json 14 | from torch.utils.data import random_split 15 | from lit_llama.tokenizer import Tokenizer 16 | from tqdm import tqdm 17 | 18 | 19 | DATA_FILE = "https://huggingface.co/datasets/databricks/databricks-dolly-15k/resolve/main/databricks-dolly-15k.jsonl" 20 | DATA_FILE_NAME = "dolly_data_cleaned.json" 21 | IGNORE_INDEX = -1 22 | 23 | 24 | def prepare( 25 | destination_path: Path = Path("data/dolly"), 26 | tokenizer_path: Path = Path("checkpoints/lit-llama/tokenizer.model"), 27 | test_split_size: int = 2000, 28 | max_seq_length: int = 1024, 29 | seed: int = 42, 30 | mask_inputs: bool = False, # as in alpaca-lora 31 | ) -> None: 32 | """Prepare the Dolly dataset for instruction tuning. 33 | 34 | The output is a training and validation dataset saved as `train.pt` and `val.pt`, 35 | which stores the preprocessed and tokenized prompts and labels. 36 | """ 37 | 38 | destination_path.mkdir(parents=True, exist_ok=True) 39 | file_path = destination_path / DATA_FILE_NAME 40 | download(file_path) 41 | 42 | # TODO: If we don't have the Meta weights, where do we get the tokenizer from? 
43 | tokenizer = Tokenizer(tokenizer_path) 44 | 45 | with open(file_path, "r") as file: 46 | data = file.readlines() 47 | data = [json.loads(line) for line in data] 48 | for item in data: 49 | item["input"] = item.pop("context") 50 | item["output"] = item.pop("response") 51 | 52 | # Partition the dataset into train and test 53 | train_split_size = len(data) - test_split_size 54 | train_set, test_set = random_split( 55 | data, 56 | lengths=(train_split_size, test_split_size), 57 | generator=torch.Generator().manual_seed(seed), 58 | ) 59 | train_set, test_set = list(train_set), list(test_set) 60 | 61 | print(f"train has {len(train_set):,} samples") 62 | print(f"val has {len(test_set):,} samples") 63 | 64 | print("Processing train split ...") 65 | train_set = [prepare_sample(sample, tokenizer, max_seq_length, mask_inputs) for sample in tqdm(train_set)] 66 | torch.save(train_set, file_path.parent / "train.pt") 67 | 68 | print("Processing test split ...") 69 | test_set = [prepare_sample(sample, tokenizer, max_seq_length, mask_inputs) for sample in tqdm(test_set)] 70 | torch.save(test_set, file_path.parent / "test.pt") 71 | 72 | 73 | def download(file_path: Path): 74 | """Downloads the raw json data file and saves it in the given destination.""" 75 | if file_path.exists(): 76 | return 77 | with open(file_path, "w") as f: 78 | f.write(requests.get(DATA_FILE).text) 79 | 80 | 81 | def prepare_sample(example: dict, tokenizer: Tokenizer, max_length: int, mask_inputs: bool = True): 82 | """Processes a single sample. 83 | 84 | Each sample in the dataset consists of: 85 | - instruction: A string describing the task 86 | - input: A string holding a special input value for the instruction. 87 | This only applies to some samples, and in others this is empty. 88 | - output: The response string 89 | 90 | This function processes this data to produce a prompt text and a label for 91 | supervised training. The prompt text is formed as a single message including both 92 | the instruction and the input. The label/target is the same message but with the 93 | response attached. 94 | 95 | Finally, both the prompt and the label get tokenized. If desired, all tokens 96 | in the label that correspond to the original input prompt get masked out (default). 97 | """ 98 | full_prompt = generate_prompt(example) 99 | full_prompt_and_response = full_prompt + example["output"] 100 | encoded_full_prompt = tokenize(tokenizer, full_prompt, max_length=max_length, eos=False) 101 | encoded_full_prompt_and_response = tokenize(tokenizer, full_prompt_and_response, eos=True, max_length=max_length) 102 | 103 | # The labels are the full prompt with response, but with the prompt masked out 104 | labels = encoded_full_prompt_and_response.clone() 105 | if mask_inputs: 106 | labels[:len(encoded_full_prompt)] = IGNORE_INDEX 107 | 108 | return {**example, "input_ids": encoded_full_prompt_and_response, "input_ids_no_response": encoded_full_prompt, "labels": labels} 109 | 110 | 111 | def tokenize(tokenizer: Tokenizer, string: str, max_length: int, eos=True) -> torch.Tensor: 112 | return tokenizer.encode(string, bos=True, eos=eos, max_length=max_length) 113 | 114 | 115 | def generate_prompt(example): 116 | """Generates a standardized message to prompt the model with an instruction, optional input and a 117 | 'response' field.""" 118 | 119 | if example["input"]: 120 | return ( 121 | f"Below is an instruction that describes a task, paired with an input that provides further context. 
" 122 | "Write a response that appropriately completes the request.\n\n" 123 | f"### Instruction:\n{example['instruction']}\n\n### Input:\n{example['input']}\n\n### Response:" 124 | ) 125 | return ( 126 | f"Below is an instruction that describes a task. " 127 | "Write a response that appropriately completes the request.\n\n" 128 | f"### Instruction:\n{example['instruction']}\n\n### Response:" 129 | ) 130 | 131 | 132 | if __name__ == "__main__": 133 | from jsonargparse import CLI 134 | 135 | CLI(prepare) 136 | -------------------------------------------------------------------------------- /scripts/prepare_redpajama.py: -------------------------------------------------------------------------------- 1 | # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. 2 | 3 | import json 4 | import glob 5 | import os 6 | from pathlib import Path 7 | import sys 8 | 9 | # support running without installing as a package 10 | wd = Path(__file__).parent.parent.resolve() 11 | sys.path.append(str(wd)) 12 | 13 | import numpy as np 14 | from tqdm import tqdm 15 | 16 | from lit_llama import Tokenizer 17 | import lit_llama.packed_dataset as packed_dataset 18 | 19 | 20 | filenames_sample = [ 21 | "arxiv_sample.jsonl", 22 | "book_sample.jsonl", 23 | "c4_sample.jsonl", 24 | "cc_2019-30_sample.jsonl", 25 | "cc_2020-05_sample.jsonl", 26 | "cc_2021-04_sample.jsonl", 27 | "cc_2022-05_sample.jsonl", 28 | "cc_2023-06_sample.jsonl", 29 | "github_sample.jsonl", 30 | "stackexchange_sample.jsonl", 31 | "wikipedia_sample.jsonl", 32 | ] 33 | 34 | filename_sets = { 35 | "arxiv": "arxiv/arxiv*", 36 | "book": "book/book*", 37 | "c4": "c4/c4-train*", 38 | "common_crawl": "common_crawl/*", 39 | "github": "github/filtered*", 40 | "stackexchange": "stackexchange/stackexchange*", 41 | "wikipedia": "wikipedia/wiki*", 42 | } 43 | 44 | 45 | def prepare_sample( 46 | source_path: Path, 47 | tokenizer_path: Path, 48 | destination_path: Path, 49 | chunk_size: int, 50 | match = "" 51 | ) -> None: 52 | """Prepare the "Red Pajama" dataset. We assume tokenizer has been trained (i.e. we reuse LLaMA's tokenizer model).""" 53 | destination_path.mkdir(parents=True, exist_ok=True) 54 | 55 | tokenizer = Tokenizer(tokenizer_path) 56 | 57 | for name in filenames_sample: 58 | if match and match not in name: 59 | continue 60 | 61 | filepath = source_path / name 62 | 63 | if not filepath.is_file(): 64 | raise RuntimeError( 65 | f"Input file not found at {filepath}. \n" 66 | "Make sure you download the data, e.g. wget -i https://data.together.xyz/redpajama-data-1T/v1.0.0/urls.txt or through \n" 67 | "https://huggingface.co/datasets/togethercomputer/RedPajama-Data-1T \n" 68 | "https://huggingface.co/datasets/togethercomputer/RedPajama-Data-1T-Sample \n" 69 | ) 70 | 71 | prefix, _ = os.path.splitext(name) 72 | 73 | builder = packed_dataset.PackedDatasetBuilder( 74 | outdir=destination_path, 75 | prefix=prefix, 76 | chunk_size=chunk_size, 77 | sep_token=tokenizer.bos_id, 78 | dtype="auto", 79 | vocab_size=tokenizer.vocab_size, 80 | ) 81 | 82 | print(f"Processing {name}") 83 | 84 | with open(filepath, encoding="utf-8") as f: 85 | for row in tqdm(f): 86 | text = json.loads(row)["text"] 87 | text_ids = tokenizer.encode(text) 88 | builder.add_array(np.array(text_ids, dtype=builder.dtype)) 89 | 90 | builder.write_reminder() 91 | 92 | 93 | def prepare_full( 94 | source_path: Path, 95 | tokenizer_path: Path, 96 | destination_path: Path, 97 | chunk_size: int, 98 | match: str = "" 99 | ) -> None: 100 | """Prepare the "Red Pajama" dataset. 
We assume tokenizer has been trained (i.e. we reuse LLaMA's tokenizer model).""" 101 | import zstandard as zstd 102 | 103 | destination_path.mkdir(parents=True, exist_ok=True) 104 | 105 | tokenizer = Tokenizer(tokenizer_path) 106 | 107 | for set_name, pattern in filename_sets.items(): 108 | if match and match not in set_name: 109 | continue 110 | 111 | is_cc = set_name == "common_crawl" 112 | 113 | filenames = glob.glob(os.path.join(source_path, pattern), recursive=True) 114 | 115 | if not filenames: 116 | raise RuntimeError( 117 | f"No files matching {pattern} found at {source_path}. \n" 118 | "Make sure you download the data, e.g. wget -i https://data.together.xyz/redpajama-data-1T/v1.0.0/urls.txt or through \n" 119 | "https://huggingface.co/datasets/togethercomputer/RedPajama-Data-1T \n" 120 | "https://huggingface.co/datasets/togethercomputer/RedPajama-Data-1T-Sample \n" 121 | ) 122 | 123 | builder = packed_dataset.PackedDatasetBuilder( 124 | outdir=destination_path, 125 | prefix=set_name, 126 | chunk_size=chunk_size, 127 | sep_token=tokenizer.bos_id, 128 | dtype="auto", 129 | vocab_size=tokenizer.vocab_size, 130 | ) 131 | 132 | for name in filenames: 133 | filepath = source_path / name 134 | 135 | print(f"Processing {name}") 136 | 137 | if is_cc: 138 | with zstd.open(open(filepath, "rb"), "rt", encoding="utf-8") as f: 139 | for row in tqdm(f): 140 | text = json.loads(row)["text"] 141 | text_ids = tokenizer.encode(text) 142 | builder.add_array(np.array(text_ids, dtype=builder.dtype)) 143 | else: 144 | with open(filepath, encoding="utf-8") as f: 145 | for row in tqdm(f): 146 | text = json.loads(row)["text"] 147 | text_ids = tokenizer.encode(text) 148 | builder.add_array(np.array(text_ids, dtype=builder.dtype)) 149 | 150 | builder.write_reminder() 151 | 152 | 153 | def prepare( 154 | source_path: Path = Path("data/RedPajama-Data-1T-Sample"), 155 | tokenizer_path: Path = Path("checkpoints/lit-llama/tokenizer.model"), 156 | destination_path: Path = Path("data/red_pajama_sample"), 157 | chunk_size: int = 2049 * 1024, # 2048 block size + 1 for causal (from LLama), 1024 blocks 158 | sample: bool = False, 159 | match: str = "", 160 | ) -> None: 161 | """Prepare the "Red Pajama" dataset. We assume tokenizer has been trained (i.e. we reuse LLaMA's tokenizer model).""" 162 | if sample: 163 | prepare_sample( 164 | source_path=source_path, 165 | tokenizer_path=tokenizer_path, 166 | destination_path=destination_path, 167 | chunk_size=chunk_size, 168 | match=match, 169 | ) 170 | else: 171 | prepare_full( 172 | source_path=source_path, 173 | tokenizer_path=tokenizer_path, 174 | destination_path=destination_path, 175 | chunk_size=chunk_size, 176 | match=match, 177 | ) 178 | 179 | 180 | if __name__ == "__main__": 181 | from jsonargparse import CLI 182 | 183 | CLI(prepare) 184 | -------------------------------------------------------------------------------- /scripts/prepare_shakespeare.py: -------------------------------------------------------------------------------- 1 | # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. 
2 | 3 | # MIT License 4 | 5 | # Copyright (c) 2022 Andrej Karpathy 6 | 7 | # Permission is hereby granted, free of charge, to any person obtaining a copy 8 | # of this software and associated documentation files (the "Software"), to deal 9 | # in the Software without restriction, including without limitation the rights 10 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 11 | # copies of the Software, and to permit persons to whom the Software is 12 | # furnished to do so, subject to the following conditions: 13 | 14 | # The above copyright notice and this permission notice shall be included in all 15 | # copies or substantial portions of the Software. 16 | 17 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 18 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 19 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 20 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 21 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 22 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 23 | # SOFTWARE. 24 | import sys 25 | from pathlib import Path 26 | 27 | # support running without installing as a package 28 | wd = Path(__file__).parent.parent.resolve() 29 | sys.path.append(str(wd)) 30 | 31 | import numpy as np 32 | import requests 33 | 34 | 35 | def prepare(destination_path: Path = Path("data/shakespeare")) -> None: 36 | """Prepare the "Tiny Shakespeare" dataset.""" 37 | destination_path.mkdir(parents=True, exist_ok=True) 38 | 39 | # download the tiny shakespeare dataset 40 | input_file_path = destination_path / "input.txt" 41 | if not input_file_path.exists(): 42 | data_url = "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt" 43 | with open(input_file_path, "w") as f: 44 | f.write(requests.get(data_url).text) 45 | 46 | with open(input_file_path) as f: 47 | data = f.read() 48 | n = len(data) 49 | train_data = data[: int(n * 0.9)] 50 | val_data = data[int(n * 0.9) :] 51 | 52 | from lit_llama import Tokenizer 53 | 54 | Tokenizer.train(input=input_file_path, destination=destination_path, vocab_size=100) 55 | tokenizer = Tokenizer(destination_path / "tokenizer.model") 56 | train_ids = tokenizer.encode(train_data) 57 | val_ids = tokenizer.encode(val_data) 58 | print(f"train has {len(train_ids):,} tokens") 59 | print(f"val has {len(val_ids):,} tokens") 60 | 61 | # export to bin files 62 | train_ids = np.array(train_ids, dtype=np.uint16) 63 | val_ids = np.array(val_ids, dtype=np.uint16) 64 | train_ids.tofile(destination_path / "train.bin") 65 | val_ids.tofile(destination_path / "val.bin") 66 | 67 | 68 | if __name__ == "__main__": 69 | from jsonargparse import CLI 70 | 71 | CLI(prepare) 72 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. 2 | 3 | from setuptools import setup 4 | 5 | setup() 6 | -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. 
2 | 3 | import sys 4 | from pathlib import Path 5 | 6 | import pytest 7 | 8 | wd = Path(__file__).parent.parent.absolute() 9 | 10 | 11 | @pytest.fixture() 12 | def orig_llama(): 13 | sys.path.append(str(wd)) 14 | 15 | from scripts.download import download_original 16 | 17 | download_original(wd) 18 | 19 | import original_model 20 | 21 | return original_model 22 | 23 | 24 | @pytest.fixture() 25 | def orig_llama_adapter(): 26 | sys.path.append(str(wd)) 27 | 28 | from scripts.download import download_original 29 | 30 | download_original(wd) 31 | 32 | import original_adapter 33 | 34 | return original_adapter 35 | 36 | 37 | @pytest.fixture() 38 | def lit_llama(): 39 | # this adds support for running tests without the package installed 40 | sys.path.append(str(wd)) 41 | 42 | import lit_llama 43 | 44 | return lit_llama 45 | -------------------------------------------------------------------------------- /tests/test_adapter.py: -------------------------------------------------------------------------------- 1 | # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. 2 | 3 | from dataclasses import asdict 4 | import pytest 5 | import sys 6 | import torch 7 | 8 | 9 | @pytest.mark.skipif(sys.platform == "win32", reason="EmptyInitOnDevice on CPU not working for Windows.") 10 | @pytest.mark.parametrize("model_size", ["7B", "13B", "30B", "65B"]) 11 | def test_config_identical(model_size, lit_llama): 12 | import lit_llama.adapter as llama_adapter 13 | import lit_llama.model as llama 14 | from lit_llama.utils import EmptyInitOnDevice 15 | 16 | llama_config = asdict(llama.LLaMAConfig.from_name(model_size)) 17 | adapter_config = asdict(llama_adapter.LLaMAConfig.from_name(model_size)) 18 | 19 | del adapter_config["adapter_prompt_length"] 20 | del adapter_config["adapter_start_layer"] 21 | assert adapter_config == llama_config 22 | 23 | with EmptyInitOnDevice(): 24 | llama_model = llama.LLaMA.from_name(model_size) 25 | adapter_model = llama_adapter.LLaMA.from_name(model_size) 26 | assert llama_model.lm_head.weight.shape == adapter_model.lm_head.weight.shape 27 | 28 | 29 | def test_adapter_load_gating_factor(lit_llama): 30 | """Tests backward-compatible loading of checkpoints after the `gating_factor` was extended per-head 31 | in PR #297. 
32 | """ 33 | import lit_llama.adapter as llama_adapter 34 | from lit_llama.utils import lazy_load 35 | 36 | config = llama_adapter.LLaMAConfig(n_head=4, block_size=100, n_embd=16) 37 | attn = llama_adapter.CausalSelfAttention(config=config, block_idx=3) 38 | 39 | # Old checkpoint format 40 | state_dict={ 41 | "gating_factor": torch.tensor(0.42), # in old checkpoints, this was a scalar 42 | "c_attn.weight": torch.zeros(3 * 16, 16), 43 | "c_proj.weight": torch.zeros(16, 16), 44 | "adapter_wte.weight": torch.zeros(10, 16), 45 | } 46 | attn.load_state_dict(state_dict=state_dict) 47 | assert torch.equal(attn.gating_factor, torch.full((1, 4, 1, 1), 0.42)) 48 | 49 | # New checkpoint format 50 | state_dict={ 51 | "gating_factor": torch.tensor([0.42, 0.42, 0.42, 0.42]).reshape(1, 4, 1, 1), 52 | "c_attn.weight": torch.zeros(3 * 16, 16), 53 | "c_proj.weight": torch.zeros(16, 16), 54 | "adapter_wte.weight": torch.zeros(10, 16), 55 | } 56 | attn.load_state_dict(state_dict=state_dict) 57 | assert torch.equal(attn.gating_factor, torch.full((1, 4, 1, 1), 0.42)) 58 | -------------------------------------------------------------------------------- /tests/test_adapter_v2.py: -------------------------------------------------------------------------------- 1 | # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. 2 | 3 | import pytest 4 | import sys 5 | 6 | 7 | @pytest.mark.skipif(sys.platform == "win32", reason="EmptyInitOnDevice on CPU not working for Windows.") 8 | @pytest.mark.parametrize("model_size", ["7B", "13B", "30B", "65B"]) 9 | def test_config_identical(model_size, lit_llama): 10 | import torch.nn as nn 11 | import lit_llama.adapter as llama_adapter 12 | from lit_llama.adapter_v2 import adapter_v2_linear_with_bias_and_scale 13 | import lit_llama.model as llama 14 | from lit_llama.utils import EmptyInitOnDevice 15 | 16 | with EmptyInitOnDevice(): 17 | llama_model = llama.LLaMA.from_name(model_size) 18 | adapter_model = llama_adapter.LLaMA.from_name(model_size) 19 | 20 | for module in adapter_model.modules(): 21 | if isinstance(module, nn.Linear): 22 | adapter_v2_linear_with_bias_and_scale(module) 23 | 24 | print(adapter_model.transformer.h[2].attn.c_attn.adapter_bias) 25 | assert not hasattr(llama_model.transformer.h[2].attn.c_attn, 'adapter_bias') 26 | assert not hasattr(llama_model.transformer.h[2].attn.c_attn, 'adapter_scale') 27 | assert hasattr(adapter_model.transformer.h[2].attn.c_attn, 'adapter_bias') 28 | assert hasattr(adapter_model.transformer.h[2].attn.c_attn, 'adapter_scale') -------------------------------------------------------------------------------- /tests/test_generate.py: -------------------------------------------------------------------------------- 1 | # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. 
2 | 3 | import functools 4 | import subprocess 5 | import sys 6 | from contextlib import contextmanager, redirect_stdout 7 | from io import StringIO 8 | from pathlib import Path 9 | from unittest import mock 10 | from unittest.mock import Mock, call, ANY 11 | 12 | import torch 13 | 14 | wd = Path(__file__).parent.parent.absolute() 15 | 16 | 17 | @functools.lru_cache(maxsize=1) 18 | def load_generate_script(): 19 | sys.path.append(str(wd)) 20 | 21 | import generate as generate 22 | 23 | return generate 24 | 25 | 26 | def test_generate(): 27 | generate = load_generate_script() 28 | 29 | from lit_llama.model import LLaMA, LLaMAConfig 30 | 31 | T, C = 5, 3 32 | logits = torch.randn(T, C) 33 | input_idx = torch.randint(10, size=(T,)) 34 | 35 | config = LLaMAConfig(block_size=128, vocab_size=16, n_layer=1, n_head=4, n_embd=8) 36 | model = LLaMA(config) 37 | max_new_tokens = 20 38 | 39 | multinomial_results = [] 40 | original_multinomial = torch.multinomial 41 | 42 | def multinomial(*args, **kwargs): 43 | out = original_multinomial(*args, **kwargs) 44 | multinomial_results.append(out) 45 | return out 46 | 47 | with mock.patch("torch.multinomial", multinomial): 48 | out = generate.generate(model, input_idx, max_new_tokens, max_seq_length=10, top_k=4) 49 | 50 | assert out.size(0) == T + max_new_tokens 51 | multinomial_results = torch.hstack(multinomial_results) 52 | expected = torch.cat((input_idx, multinomial_results)) 53 | assert out.shape == expected.shape 54 | torch.testing.assert_close(out, expected) 55 | 56 | 57 | @mock.patch("torch.cuda.is_bf16_supported", return_value=False) 58 | def test_main(tmp_path, monkeypatch): 59 | generate = load_generate_script() 60 | 61 | checkpoint_path = tmp_path / "ckpt" 62 | checkpoint_path.touch() 63 | tokenizer_path = tmp_path / "tokenizer" 64 | tokenizer_path.touch() 65 | 66 | class FabricMock(Mock): 67 | @property 68 | def device(self): 69 | return torch.device("cpu") 70 | 71 | @contextmanager 72 | def init_module(self, empty_init): 73 | yield 74 | 75 | monkeypatch.setattr(generate.L, "Fabric", FabricMock) 76 | model_mock = Mock() 77 | monkeypatch.setattr(generate.LLaMA, "from_name", model_mock) 78 | lookup_mock = Mock(return_value="1T") 79 | monkeypatch.setattr(generate, "llama_model_lookup", lookup_mock) 80 | load_mock = Mock() 81 | load_mock.return_value = load_mock 82 | load_mock.__enter__ = Mock() 83 | load_mock.__exit__ = Mock() 84 | monkeypatch.setattr(generate.torch, "load", load_mock) 85 | monkeypatch.setattr(generate, "lazy_load", load_mock) 86 | tokenizer_mock = Mock() 87 | tokenizer_mock.return_value.encode.return_value = torch.tensor([[1, 2, 3]]) 88 | tokenizer_mock.return_value.decode.return_value = "foo bar baz" 89 | monkeypatch.setattr(generate, "Tokenizer", tokenizer_mock) 90 | generate_mock = Mock() 91 | generate_mock.return_value = torch.tensor([[3, 2, 1]]) 92 | monkeypatch.setattr(generate, "generate", generate_mock) 93 | 94 | num_samples = 2 95 | out = StringIO() 96 | with redirect_stdout(out): 97 | generate.main( 98 | checkpoint_path=checkpoint_path, 99 | tokenizer_path=tokenizer_path, 100 | temperature=2.0, 101 | top_k=2, 102 | num_samples=num_samples, 103 | ) 104 | 105 | model_mock.assert_called_once_with("1T") 106 | load_mock.assert_called_once_with(checkpoint_path) 107 | tokenizer_mock.assert_called_once_with(tokenizer_path) 108 | assert len(tokenizer_mock.return_value.decode.mock_calls) == num_samples 109 | assert torch.allclose(tokenizer_mock.return_value.decode.call_args[0][0], generate_mock.return_value) 110 | assert 
generate_mock.mock_calls == [call(ANY, ANY, 50, temperature=2.0, top_k=2)] * num_samples 111 | # only the generated result is printed to stdout 112 | assert out.getvalue() == "foo bar baz\n" * num_samples 113 | 114 | 115 | def test_cli(): 116 | cli_path = wd / "generate.py" 117 | output = subprocess.check_output([sys.executable, cli_path, "-h"]) 118 | output = str(output.decode()) 119 | assert "Generates text samples" in output 120 | -------------------------------------------------------------------------------- /tests/test_lora.py: -------------------------------------------------------------------------------- 1 | # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. 2 | 3 | import torch 4 | 5 | 6 | def test_lora_layer_replacement(lit_llama): 7 | from lit_llama.lora import lora, CausalSelfAttention as LoRACausalSelfAttention 8 | from lit_llama.model import LLaMA, LLaMAConfig 9 | 10 | config = LLaMAConfig() 11 | config.n_layer = 2 12 | config.n_head = 4 13 | config.n_embd = 8 14 | config.block_size = 8 15 | config.vocab_size = 8 16 | 17 | with lora(r=8, alpha=8, dropout=0.1): 18 | model = LLaMA(config) 19 | 20 | assert isinstance(model.transformer.h[0].attn, LoRACausalSelfAttention) 21 | assert isinstance(model.transformer.h[1].attn, LoRACausalSelfAttention) 22 | 23 | 24 | def test_lora_merge_unmerge(lit_llama): 25 | from lit_llama.lora import lora, mark_only_lora_as_trainable 26 | from lit_llama.model import LLaMA, LLaMAConfig 27 | 28 | config = LLaMAConfig(n_layer=1, n_head=2, n_embd=8, block_size=8, vocab_size=8) 29 | 30 | with lora(r=8, alpha=8, dropout=0.1): 31 | model = LLaMA(config) 32 | 33 | initial_weight = model.transformer.h[0].attn.c_attn.weight.clone() 34 | model.train() 35 | assert torch.equal(model.transformer.h[0].attn.c_attn.weight, initial_weight) 36 | 37 | # perform an update to the LoRA weights 38 | mark_only_lora_as_trainable(model) 39 | optimizer = torch.optim.SGD(model.parameters(), lr=1.0) 40 | model(torch.randint(0, 8, size=(2, 4), dtype=torch.int64)).sum().backward() 41 | optimizer.step() 42 | optimizer.zero_grad() 43 | # the weight remains unchanged (only lora A and B change) 44 | assert torch.equal(model.transformer.h[0].attn.c_attn.weight, initial_weight) 45 | 46 | # 'merge' and then 'unmerge' should neutralize themselves 47 | weight_before = model.transformer.h[0].attn.c_attn.weight.clone() 48 | model.eval() 49 | assert not torch.equal(model.transformer.h[0].attn.c_attn.weight, weight_before) 50 | model.train() 51 | # note: numerically, `W + (A * B) - (A * B) == W` does not hold exactly 52 | assert torch.allclose(model.transformer.h[0].attn.c_attn.weight, weight_before) 53 | 54 | # calling eval/train multiple times in a row should not merge/unmerge multiple times 55 | model.eval() 56 | assert model.transformer.h[0].attn.c_attn.merged 57 | weight_after = model.transformer.h[0].attn.c_attn.weight.clone() 58 | model.eval() 59 | model.eval() 60 | assert torch.equal(model.transformer.h[0].attn.c_attn.weight, weight_after) 61 | model.train() 62 | assert not model.transformer.h[0].attn.c_attn.merged 63 | weight_after = model.transformer.h[0].attn.c_attn.weight.clone() 64 | model.train() 65 | model.train() 66 | assert torch.equal(model.transformer.h[0].attn.c_attn.weight, weight_after) 67 | -------------------------------------------------------------------------------- /tests/test_model.py: -------------------------------------------------------------------------------- 1 | # Copyright Lightning AI. 
Licensed under the Apache License 2.0, see LICENSE file. 2 | 3 | import torch 4 | import pytest 5 | import sys 6 | 7 | 8 | def copy_mlp(llama_mlp, orig_llama_mlp) -> None: 9 | orig_llama_mlp.w1.weight.copy_(llama_mlp.c_fc1.weight) 10 | orig_llama_mlp.w3.weight.copy_(llama_mlp.c_fc2.weight) 11 | orig_llama_mlp.w2.weight.copy_(llama_mlp.c_proj.weight) 12 | 13 | 14 | def copy_attention(llama_attn, orig_llama_attn) -> None: 15 | n_embd = llama_attn.c_attn.weight.shape[1] 16 | orig_llama_attn.wq.weight.copy_(llama_attn.c_attn.weight[:n_embd]) 17 | orig_llama_attn.wk.weight.copy_(llama_attn.c_attn.weight[n_embd:-n_embd]) 18 | orig_llama_attn.wv.weight.copy_(llama_attn.c_attn.weight[-n_embd:]) 19 | orig_llama_attn.wo.weight.copy_(llama_attn.c_proj.weight) 20 | 21 | 22 | def copy_block(llama_block, orig_llama_block) -> None: 23 | orig_llama_block.attention_norm.weight.copy_(llama_block.rms_1.scale) 24 | copy_attention(llama_block.attn, orig_llama_block.attention) 25 | orig_llama_block.ffn_norm.weight.copy_(llama_block.rms_2.scale) 26 | copy_mlp(llama_block.mlp, orig_llama_block.feed_forward) 27 | 28 | 29 | def copy_weights(llama_model, orig_llama_model) -> None: 30 | orig_llama_model.tok_embeddings.weight.copy_(llama_model.transformer.wte.weight) 31 | for llama_block, orig_llama_block in zip(llama_model.transformer.h, orig_llama_model.layers): 32 | copy_block(llama_block, orig_llama_block) 33 | orig_llama_model.norm.weight.copy_(llama_model.transformer.ln_f.scale) 34 | orig_llama_model.output.weight.copy_(llama_model.lm_head.weight) 35 | 36 | 37 | @torch.no_grad() 38 | @pytest.mark.parametrize("kv_cache", (False, True)) 39 | def test_to_orig_llama(lit_llama, orig_llama, kv_cache) -> None: 40 | block_size = 64 41 | vocab_size = 32000 42 | n_layer = 16 43 | n_head = 16 44 | n_embd = 32 45 | batch_size = 3 46 | 47 | llama_config = lit_llama.LLaMAConfig( 48 | block_size=block_size, vocab_size=vocab_size, n_layer=n_layer, n_head=n_head, n_embd=n_embd 49 | ) 50 | orig_llama_config = orig_llama.ModelArgs( 51 | dim=n_embd, 52 | n_layers=n_layer, 53 | n_heads=n_head, 54 | vocab_size=vocab_size, 55 | norm_eps=1e-5, 56 | max_seq_len=block_size, 57 | max_batch_size=batch_size, 58 | ) 59 | 60 | seq_len = orig_llama_config.max_seq_len 61 | token_sample = torch.randint(0, orig_llama_config.vocab_size, size=(batch_size, seq_len), dtype=torch.int64) 62 | 63 | llama_model = lit_llama.LLaMA(llama_config) 64 | llama_model.apply(llama_model._init_weights) 65 | orig_llama_model = orig_llama.Transformer(orig_llama_config) 66 | 67 | copy_weights(llama_model, orig_llama_model) 68 | 69 | orig_llama_embed = orig_llama_model.tok_embeddings(token_sample) 70 | llama_embed = llama_model.transformer.wte(token_sample) 71 | assert torch.allclose(orig_llama_embed, llama_embed) 72 | 73 | llama_rope = llama_model.build_rope_cache(token_sample) 74 | llama_mask = llama_model.build_mask_cache(token_sample) 75 | orig_llama_mask = torch.full((1, 1, seq_len, seq_len), float("-inf")) 76 | orig_llama_mask = torch.triu(orig_llama_mask, diagonal=1) 77 | if kv_cache: 78 | orig_llama_block_out = orig_llama_model.layers[0]( 79 | orig_llama_embed, 0, orig_llama_model.freqs_cis[:seq_len], orig_llama_mask 80 | ) 81 | theirs_k_cache = orig_llama_model.layers[0].attention.cache_k 82 | theirs_v_cache = orig_llama_model.layers[0].attention.cache_v 83 | head_size = n_embd // n_head 84 | kv_cache_shape = (batch_size, n_head, block_size, head_size) 85 | ours_kv_cache = torch.zeros(kv_cache_shape), torch.zeros(kv_cache_shape) 86 | (llama_block_out, 
ours_kv_cache) = llama_model.transformer.h[0]( 87 | llama_embed, llama_rope, llama_mask, seq_len, torch.arange(block_size), ours_kv_cache 88 | ) 89 | ours_k_cache = ours_kv_cache[0].permute(0, 2, 1, 3) 90 | ours_v_cache = ours_kv_cache[1].permute(0, 2, 1, 3) 91 | torch.testing.assert_close(ours_k_cache, theirs_k_cache) 92 | torch.testing.assert_close(ours_v_cache, theirs_v_cache) 93 | else: 94 | orig_llama_block_out = orig_llama_model.layers[0]( 95 | orig_llama_embed, 0, orig_llama_model.freqs_cis[:seq_len], orig_llama_mask 96 | ) 97 | (llama_block_out, _) = llama_model.transformer.h[0](llama_embed, llama_rope, llama_mask, seq_len) 98 | assert torch.allclose(orig_llama_block_out, llama_block_out) 99 | 100 | expected = orig_llama_model(token_sample, 0) 101 | out = llama_model(token_sample) 102 | assert torch.allclose(out, expected) 103 | 104 | 105 | @pytest.mark.skipif(not torch.cuda.is_available(), reason="Requires CUDA") 106 | @torch.no_grad() 107 | def test_bfloat16_llama_init(lit_llama, orig_llama) -> None: 108 | from lit_llama.utils import EmptyInitOnDevice 109 | 110 | block_size = 64 111 | vocab_size = 32000 112 | n_layer = 16 113 | n_head = 16 114 | n_embd = 32 115 | 116 | llama_config = lit_llama.LLaMAConfig( 117 | block_size=block_size, vocab_size=vocab_size, n_layer=n_layer, n_head=n_head, n_embd=n_embd 118 | ) 119 | llama_model = lit_llama.LLaMA(llama_config) 120 | llama_model.apply(llama_model._init_weights) 121 | 122 | batch_size = 3 123 | 124 | token_sample = torch.randint(0, vocab_size, size=(batch_size, block_size), dtype=torch.int64) 125 | 126 | expected = llama_model(token_sample) 127 | 128 | with EmptyInitOnDevice(device="cuda", dtype=torch.bfloat16): 129 | llama_model2 = lit_llama.LLaMA(llama_config) 130 | llama_model2.load_state_dict(llama_model.state_dict(keep_vars=True)) 131 | 132 | out = llama_model2(token_sample.cuda()).float().cpu() 133 | torch.testing.assert_close(out, expected, atol=5e-3, rtol=1e-3) 134 | 135 | 136 | def copy_adapter_weights(llama_model, orig_llama_model) -> None: 137 | # copy the gating parameter 138 | for llama_block, orig_llama_block in zip(llama_model.transformer.h, orig_llama_model.layers): 139 | if hasattr(llama_block.attn, "gating_factor"): 140 | llama_block.attn.gating_factor.copy_(orig_llama_block.attention.gate) 141 | 142 | # In the original model, there is one embedding layer for all blocks combined 143 | orig_adapter_wte = orig_llama_model.adapter_query.weight.reshape( 144 | orig_llama_model.params.adapter_layer, orig_llama_model.params.adapter_len, orig_llama_model.params.dim 145 | ) 146 | 147 | # In ours, the embedding layer is split across the individual attention layers 148 | index = 0 149 | for llama_block in llama_model.transformer.h: 150 | if hasattr(llama_block.attn, "adapter_wte"): 151 | llama_block.attn.adapter_wte.weight.copy_(orig_adapter_wte[index]) 152 | index += 1 153 | 154 | 155 | def enable_gate(model): 156 | for name, param in model.named_parameters(): 157 | if "gating_factor" in name or "gate" in name: 158 | param.fill_(1) 159 | 160 | 161 | @torch.no_grad() 162 | def test_adapter_parity(orig_llama_adapter): 163 | """Test parity between our implementation of LLaMA-Adapter and the reference code.""" 164 | import lit_llama.adapter as lit_llama 165 | 166 | orig_llama = orig_llama_adapter 167 | 168 | block_size = 32 169 | vocab_size = 100 170 | n_layer = 2 171 | n_head = 4 172 | n_embd = 16 173 | adapter_prompt_length: int = 10 174 | adapter_start_layer: int = 0 175 | 176 | llama_config = lit_llama.LLaMAConfig( 177 | 
block_size=block_size, 178 | vocab_size=vocab_size, 179 | n_layer=n_layer, 180 | n_head=n_head, 181 | n_embd=n_embd, 182 | adapter_prompt_length=adapter_prompt_length, 183 | adapter_start_layer=adapter_start_layer, 184 | ) 185 | orig_llama_config = orig_llama.ModelArgs( 186 | dim=n_embd, 187 | n_layers=n_layer, 188 | n_heads=n_head, 189 | vocab_size=vocab_size, 190 | norm_eps=1e-5, 191 | max_seq_len=block_size, 192 | adapter_len=adapter_prompt_length, 193 | adapter_layer=(n_layer - adapter_start_layer), 194 | ) 195 | 196 | batch_size = 3 197 | token_sample = torch.randint( 198 | 0, orig_llama_config.vocab_size, size=(batch_size, orig_llama_config.max_seq_len), dtype=torch.int64 199 | ) 200 | 201 | llama_model = lit_llama.LLaMA(llama_config) 202 | llama_model.apply(llama_model._init_weights) 203 | orig_llama_model = orig_llama.Transformer(orig_llama_config) 204 | 205 | copy_weights(llama_model, orig_llama_model) 206 | copy_adapter_weights(llama_model, orig_llama_model) 207 | 208 | # make the gate non-zero, otherwise the adapter is disabled and the model 209 | # identical to regular LLaMA 210 | enable_gate(llama_model) 211 | enable_gate(orig_llama_model) 212 | 213 | expected = orig_llama_model(token_sample, 0) 214 | out = llama_model(token_sample) 215 | assert torch.allclose(out, expected, atol=1e-5, rtol=1e-5) 216 | 217 | 218 | @pytest.mark.skipif(sys.platform in ("win32", "darwin"), reason="torch.compile not supported on this platform") 219 | def test_model_compile(lit_llama): 220 | llama_config = lit_llama.LLaMAConfig(block_size=8, vocab_size=8, n_layer=2, n_head=2, n_embd=4) 221 | model = lit_llama.LLaMA(llama_config) 222 | model.apply(model._init_weights) 223 | 224 | model = torch.compile(model) 225 | 226 | sample = torch.randint(model.config.vocab_size, size=(2, model.config.block_size), dtype=torch.int64) 227 | for _ in range(3): 228 | _ = model(sample) 229 | -------------------------------------------------------------------------------- /tests/test_packed_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. 2 | 3 | import os 4 | from unittest.mock import MagicMock 5 | import requests 6 | 7 | from torch.utils.data import IterableDataset 8 | 9 | 10 | def train_tokenizer(destination_path): 11 | destination_path.mkdir(parents=True, exist_ok=True) 12 | 13 | # download the tiny shakespeare dataset 14 | input_file_path = destination_path / "input.txt" 15 | if not input_file_path.exists(): 16 | data_url = "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt" 17 | with open(input_file_path, "w") as f: 18 | f.write(requests.get(data_url).text) 19 | 20 | from lit_llama import Tokenizer 21 | Tokenizer.train( 22 | input=input_file_path, 23 | destination=destination_path, 24 | vocab_size=100, 25 | ) 26 | 27 | return destination_path / "tokenizer.model" 28 | 29 | 30 | def test_packed_dataset(tmp_path): 31 | tokenizer_path = train_tokenizer(tmp_path) 32 | 33 | from lit_llama import Tokenizer 34 | tokenizer = Tokenizer(tokenizer_path) 35 | 36 | texts = [ 37 | "The moment of truth is upon us.", 38 | "Time to open the fridge." 
39 | ] 40 | 41 | from lit_llama.packed_dataset import PackedDatasetBuilder, PackedDataset, HDR_SIZE 42 | 43 | block_size = 10 44 | n_blocks = 2 45 | chunk_size = block_size * n_blocks 46 | 47 | builder = PackedDatasetBuilder( 48 | outdir=tmp_path, 49 | prefix="packed_dataset", 50 | chunk_size=chunk_size, 51 | sep_token=tokenizer.bos_id, 52 | dtype="auto", 53 | vocab_size=100, 54 | ) 55 | 56 | text_ids = [] 57 | 58 | for text in texts: 59 | text_ids = tokenizer.encode(text) 60 | assert text_ids[0] == tokenizer.bos_id 61 | builder.add_array(text_ids) 62 | 63 | filenames = builder.filenames 64 | 65 | assert len(filenames) == 2 66 | assert os.path.basename(filenames[0]) == "packed_dataset_0000000000.bin" 67 | assert os.path.basename(filenames[1]) == "packed_dataset_0000000001.bin" 68 | 69 | import numpy as np 70 | 71 | ex_tokenized = [ 72 | tokenizer.encode(text).numpy().astype(builder.dtype) 73 | for text in texts 74 | ] 75 | ex_tokenized = np.concatenate(ex_tokenized) 76 | ex_tokenized = ex_tokenized[:2 * chunk_size] 77 | 78 | for filename, el in zip(filenames, np.array_split(ex_tokenized, 2)): 79 | mmap = np.memmap(filename, mode="r", order="C", offset=HDR_SIZE) 80 | count = len(mmap) // np.dtype(builder.dtype).itemsize 81 | arr = np.frombuffer( 82 | mmap, dtype=builder.dtype, count=count, offset=0 83 | ) 84 | where_bos = np.where(arr == tokenizer.bos_id) 85 | # we expect two BOS tokens, one per file 86 | assert len(where_bos) == 1 87 | assert np.array_equal(arr, el) 88 | 89 | dataset = PackedDataset(filenames=filenames, n_chunks=2, block_size=block_size, shuffle=False) 90 | 91 | ex_split = np.array_split(ex_tokenized, ex_tokenized.shape[0] // block_size) 92 | 93 | for item, el in zip(dataset, ex_split): 94 | assert np.array_equal(item, el) 95 | 96 | dataset = PackedDataset(filenames=filenames, n_chunks=2, block_size=block_size, seed=12345) 97 | 98 | for i, item in enumerate(dataset): 99 | block_idxs = iter(dataset)._block_idxs 100 | assert np.array_equal(item, ex_split[block_idxs[i]]) 101 | 102 | dataset = PackedDataset(filenames=filenames, n_chunks=2, block_size=block_size, seed=12345, wrap=True) 103 | 104 | for i, item in enumerate(dataset): 105 | if i > 24: 106 | break 107 | 108 | dataset = PackedDataset(filenames=filenames, n_chunks=1, block_size=block_size, seed=12345) 109 | 110 | for i, item in enumerate(dataset): 111 | block_idxs = iter(dataset)._block_idxs 112 | chunk_idx = i // n_blocks * n_blocks 113 | assert np.array_equal(item, ex_split[chunk_idx + block_idxs[i % n_blocks]]) 114 | 115 | block_size_ = block_size // 2 116 | ex_split = np.array_split(ex_tokenized, ex_tokenized.shape[0] // block_size_) 117 | dataset = PackedDataset(filenames=filenames, n_chunks=2, block_size=block_size_, seed=12345) 118 | 119 | for i, item in enumerate(dataset): 120 | block_idxs = iter(dataset)._block_idxs 121 | assert np.array_equal(item, ex_split[block_idxs[i]]) 122 | 123 | block_size_ = block_size // 3 124 | n_chunks = 2 125 | ex_chunks = np.split(ex_tokenized, n_chunks) 126 | n_splits = ex_tokenized.shape[0] // n_chunks // block_size_ 127 | ex_splits = [np.split(el[:n_splits * block_size_], n_splits) for el in ex_chunks] 128 | ex_split = sum(ex_splits, []) 129 | 130 | dataset = PackedDataset(filenames=filenames, n_chunks=n_chunks, block_size=block_size_, seed=12345) 131 | 132 | for i, item in enumerate(dataset): 133 | block_idxs = iter(dataset)._block_idxs 134 | assert np.array_equal(item, ex_split[block_idxs[i]]) 135 | 136 | 137 | class SimpleDataset(IterableDataset): 138 | def __init__(self, 
start, end): 139 | super().__init__() 140 | self._start = start 141 | self._end = end 142 | 143 | def __iter__(self): 144 | return iter(range(self._start, self._end)) 145 | 146 | 147 | def test_combined_dataset(tmp_path): 148 | from lit_llama.packed_dataset import CombinedDataset 149 | 150 | dataset1 = SimpleDataset(0, 10) 151 | dataset2 = SimpleDataset(10, 20) 152 | dataset = CombinedDataset(datasets=[dataset1, dataset2], weights=[1.0, 0.0], seed=12345) 153 | 154 | res = [el for el in dataset] 155 | assert res == list(range(0, 10)) 156 | 157 | dataset1 = SimpleDataset(0, 10) 158 | dataset2 = SimpleDataset(10, 20) 159 | dataset = CombinedDataset(datasets=[dataset1, dataset2], weights=[0.0, 1.0], seed=12345) 160 | 161 | res = [el for el in dataset] 162 | assert res == list(range(10, 20)) 163 | 164 | dataset1 = SimpleDataset(0, 10) 165 | dataset2 = SimpleDataset(10, 20) 166 | dataset = CombinedDataset(datasets=[dataset1, dataset2], weights=[0.5, 0.5], seed=12345) 167 | 168 | res = [el for el in dataset] 169 | assert 9 in res or 19 in res 170 | if len(res) > 10: 171 | assert 0 in res and 10 in res 172 | 173 | 174 | def test_sharded_packed_dataset(monkeypatch): 175 | import lit_llama.packed_dataset 176 | from lit_llama.packed_dataset import PackedDataset 177 | 178 | dataset_iterator_mock = MagicMock() 179 | monkeypatch.setattr(lit_llama.packed_dataset, "PackedDatasetIterator", dataset_iterator_mock) 180 | filenames = [str(i) for i in range(10)] 181 | 182 | # world_size = 1, rank = 0 183 | iter(PackedDataset(filenames=filenames, n_chunks=2, block_size=2)) 184 | assert dataset_iterator_mock.call_args[1]["filenames"] == filenames 185 | dataset_iterator_mock.reset_mock() 186 | # world_size = 2, rank = 0 187 | iter(PackedDataset(filenames=filenames, n_chunks=2, block_size=2, num_processes=2, process_rank=0)) 188 | assert dataset_iterator_mock.call_args[1]["filenames"] == ["0", "2", "4", "6", "8"] 189 | dataset_iterator_mock.reset_mock() 190 | # world_size = 2, rank = 1 191 | iter(PackedDataset(filenames=filenames, n_chunks=2, block_size=2, num_processes=2, process_rank=1)) 192 | assert dataset_iterator_mock.call_args[1]["filenames"] == ["1", "3", "5", "7", "9"] 193 | dataset_iterator_mock.reset_mock() 194 | 195 | # world_size = 3, rank = 0 (dataset size not cleanly divisible by world size) 196 | iter(PackedDataset(filenames=filenames, n_chunks=2, block_size=2, num_processes=3, process_rank=0)) 197 | assert dataset_iterator_mock.call_args[1]["filenames"] == ["0", "3", "6"] 198 | dataset_iterator_mock.reset_mock() 199 | # world_size = 3, rank = 1 (dataset size not cleanly divisible by world size) 200 | iter(PackedDataset(filenames=filenames, n_chunks=2, block_size=2, num_processes=3, process_rank=1)) 201 | assert dataset_iterator_mock.call_args[1]["filenames"] == ["1", "4", "7"] 202 | dataset_iterator_mock.reset_mock() 203 | # world_size = 3, rank = 2 (dataset size not cleanly divisible by world size) 204 | iter(PackedDataset(filenames=filenames, n_chunks=2, block_size=2, num_processes=3, process_rank=2)) 205 | assert dataset_iterator_mock.call_args[1]["filenames"] == ["2", "5", "8"] 206 | -------------------------------------------------------------------------------- /tests/test_prepare_redpajama.py: -------------------------------------------------------------------------------- 1 | # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. 
2 | 3 | import json 4 | import os 5 | import subprocess 6 | import sys 7 | from pathlib import Path 8 | from unittest import mock 9 | from unittest.mock import Mock, call, ANY 10 | 11 | wd = (Path(__file__).parent.parent / "scripts").absolute() 12 | 13 | import requests 14 | 15 | 16 | def train_tokenizer(destination_path): 17 | destination_path.mkdir(parents=True, exist_ok=True) 18 | 19 | # download the tiny shakespeare dataset 20 | input_file_path = destination_path / "input.txt" 21 | if not input_file_path.exists(): 22 | data_url = "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt" 23 | with open(input_file_path, "w") as f: 24 | f.write(requests.get(data_url).text) 25 | 26 | from lit_llama import Tokenizer 27 | Tokenizer.train(input=input_file_path, destination=destination_path, vocab_size=100) 28 | 29 | return destination_path / "tokenizer.model" 30 | 31 | 32 | def test_prepare_sample(tmp_path): 33 | sys.path.append(str(wd)) 34 | 35 | tokenizer_path = train_tokenizer(tmp_path) 36 | 37 | sample_path = tmp_path / "sample" 38 | source_path = sample_path / "source" 39 | dest_path = sample_path / "dest" 40 | 41 | source_path.mkdir(parents=True, exist_ok=True) 42 | 43 | sample = { 44 | "meta": {"some": "info"}, 45 | "text": "some text" 46 | } 47 | 48 | jsonl_sample = "\n".join([json.dumps(el) for el in [sample] * 2]) 49 | 50 | import prepare_redpajama 51 | 52 | for filename in prepare_redpajama.filenames_sample: 53 | with open(source_path / filename, "w") as f: 54 | f.write(jsonl_sample) 55 | 56 | prepare_redpajama.prepare(source_path=source_path, tokenizer_path=tokenizer_path, destination_path=dest_path, sample=True) 57 | 58 | bin_files = [el.replace(".jsonl", "_0000000000.bin") for el in prepare_redpajama.filenames_sample] 59 | 60 | assert set(os.listdir(dest_path)) == set(bin_files) 61 | 62 | from lit_llama import Tokenizer 63 | from lit_llama.packed_dataset import PackedDataset 64 | 65 | tokenizer = Tokenizer(tokenizer_path) 66 | 67 | # artificially set block_size to fit the text 68 | block_size = len(tokenizer.encode("some text")) 69 | 70 | for filename in bin_files: 71 | filenames = [os.path.join(dest_path, filename)] 72 | dataset = PackedDataset(filenames=filenames, n_chunks=1, block_size=block_size, shuffle=False) 73 | dataset_iter = iter(dataset) 74 | assert tokenizer.decode(next(dataset_iter)) == "some text" 75 | assert tokenizer.decode(next(dataset_iter)) == "some text" 76 | 77 | 78 | def test_prepare_full(tmp_path): 79 | sys.path.append(str(wd)) 80 | 81 | tokenizer_path = train_tokenizer(tmp_path) 82 | 83 | full_path = tmp_path / "full" 84 | source_path = full_path / "source" 85 | dest_path = full_path / "dest" 86 | 87 | source_path.mkdir(parents=True, exist_ok=True) 88 | 89 | sample = { 90 | "meta": {"some": "info"}, 91 | "text": "some text" 92 | } 93 | 94 | jsonl_sample = "\n".join([json.dumps(el) for el in [sample] * 2]) 95 | 96 | import prepare_redpajama 97 | 98 | arxiv_file = source_path / "arxiv" / "arxiv_0.jsonl" 99 | arxiv_file.parent.mkdir(parents=True, exist_ok=True) 100 | with open(arxiv_file, "w") as f: 101 | f.write(jsonl_sample) 102 | 103 | import zstandard as zstd 104 | 105 | cc_file = source_path / "common_crawl" / "cc_0.jsonl" 106 | cc_file.parent.mkdir(parents=True, exist_ok=True) 107 | with zstd.open(cc_file, "wt", encoding="utf-8") as f: 108 | f.write(jsonl_sample) 109 | 110 | filename_sets = { 111 | "arxiv": "arxiv/arxiv*", 112 | "common_crawl": "common_crawl/*", 113 | } 114 | 115 | with 
mock.patch("prepare_redpajama.filename_sets", filename_sets): 116 | prepare_redpajama.prepare(source_path=source_path, tokenizer_path=tokenizer_path, destination_path=dest_path, sample=False) 117 | 118 | all_names = prepare_redpajama.filename_sets.keys() 119 | bin_files = [el + "_0000000000.bin" for el in all_names] 120 | 121 | assert set(os.listdir(dest_path)) == set(bin_files) 122 | 123 | from lit_llama import Tokenizer 124 | from lit_llama.packed_dataset import PackedDataset 125 | 126 | tokenizer = Tokenizer(tokenizer_path) 127 | 128 | # artificially set block_size to fit the text 129 | block_size = len(tokenizer.encode("some text")) 130 | 131 | filenames = [os.path.join(dest_path, el) for el in bin_files] 132 | 133 | for filename in filenames: 134 | dataset = PackedDataset(filenames=[filename], n_chunks=1, block_size=block_size, shuffle=False) 135 | dataset_iter = iter(dataset) 136 | assert tokenizer.decode(next(dataset_iter)) == "some text" 137 | assert tokenizer.decode(next(dataset_iter)) == "some text" 138 | 139 | 140 | def test_cli(): 141 | cli_path = wd / "prepare_redpajama.py" 142 | output = subprocess.check_output([sys.executable, cli_path, "-h"]) 143 | output = str(output.decode()) 144 | assert 'Prepare the "Red Pajama"' in output 145 | -------------------------------------------------------------------------------- /tests/test_prepare_shakespeare.py: -------------------------------------------------------------------------------- 1 | # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. 2 | 3 | import os 4 | import subprocess 5 | import sys 6 | from pathlib import Path 7 | 8 | wd = (Path(__file__).parent.parent / "scripts").absolute() 9 | 10 | 11 | def test_prepare(tmp_path): 12 | sys.path.append(str(wd)) 13 | 14 | import prepare_shakespeare 15 | 16 | prepare_shakespeare.prepare(tmp_path) 17 | 18 | assert set(os.listdir(tmp_path)) == {"train.bin", "tokenizer.model", "tokenizer.vocab", "input.txt", "val.bin"} 19 | 20 | 21 | def test_cli(): 22 | cli_path = wd / "prepare_shakespeare.py" 23 | output = subprocess.check_output([sys.executable, cli_path, "-h"]) 24 | output = str(output.decode()) 25 | assert 'Prepare the "Tiny Shakespeare"' in output 26 | -------------------------------------------------------------------------------- /tests/test_rmsnorm.py: -------------------------------------------------------------------------------- 1 | # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. 2 | 3 | import torch 4 | 5 | 6 | @torch.no_grad() 7 | def test_rmsnorm(lit_llama, orig_llama) -> None: 8 | block_size = 16 9 | vocab_size = 16 10 | 11 | sample = torch.rand(size=(2, block_size, vocab_size), dtype=torch.float32) 12 | 13 | eps = 1e-6 14 | orig_llama_rmsnorm = orig_llama.RMSNorm(vocab_size, eps=eps)(sample) 15 | llama_rmsnorm = lit_llama.RMSNorm(vocab_size, eps=eps)(sample) 16 | 17 | assert torch.allclose(orig_llama_rmsnorm, llama_rmsnorm) 18 | -------------------------------------------------------------------------------- /tests/test_rope.py: -------------------------------------------------------------------------------- 1 | # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. 
2 | 3 | import torch 4 | 5 | 6 | @torch.no_grad() 7 | def test_rope(lit_llama, orig_llama) -> None: 8 | torch.manual_seed(1) 9 | 10 | bs, seq_len, n_head, n_embed = 1, 6, 2, 8 11 | x = torch.randint(0, 10000, size=(bs, seq_len, n_head, n_embed // n_head)).float() 12 | 13 | freqs_cis = orig_llama.precompute_freqs_cis(n_embed // n_head, seq_len) 14 | llama_rope_cache = lit_llama.build_rope_cache(seq_len, n_embed // n_head, dtype=x.dtype, device=x.device) 15 | torch.testing.assert_close(freqs_cis, torch.view_as_complex(llama_rope_cache)) 16 | 17 | llama_x_rope = lit_llama.apply_rope(x, llama_rope_cache) 18 | orig_llama_x_rope, _ = orig_llama.apply_rotary_emb(x, x, freqs_cis) 19 | torch.testing.assert_close(llama_x_rope, orig_llama_x_rope) 20 | -------------------------------------------------------------------------------- /tests/test_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. 2 | 3 | import tempfile 4 | import pathlib 5 | 6 | import torch 7 | 8 | 9 | class ATensor(torch.Tensor): 10 | pass 11 | 12 | 13 | def test_lazy_load_basic(lit_llama): 14 | import lit_llama.utils 15 | 16 | with tempfile.TemporaryDirectory() as tmpdirname: 17 | m = torch.nn.Linear(5, 3) 18 | path = pathlib.Path(tmpdirname) 19 | fn = str(path / "test.pt") 20 | torch.save(m.state_dict(), fn) 21 | with lit_llama.utils.lazy_load(fn) as sd_lazy: 22 | assert "NotYetLoadedTensor" in str(next(iter(sd_lazy.values()))) 23 | m2 = torch.nn.Linear(5, 3) 24 | m2.load_state_dict(sd_lazy) 25 | 26 | x = torch.randn(2, 5) 27 | actual = m2(x) 28 | expected = m(x) 29 | torch.testing.assert_close(actual, expected) 30 | 31 | 32 | def test_lazy_load_subclass(lit_llama): 33 | import lit_llama.utils 34 | 35 | with tempfile.TemporaryDirectory() as tmpdirname: 36 | path = pathlib.Path(tmpdirname) 37 | fn = str(path / "test.pt") 38 | t = torch.randn(2, 3)[:, 1:] 39 | sd = { 40 | 1: t, 41 | 2: torch.nn.Parameter(t), 42 | 3: torch.Tensor._make_subclass(ATensor, t), 43 | } 44 | torch.save(sd, fn) 45 | with lit_llama.utils.lazy_load(fn) as sd_lazy: 46 | for k in sd.keys(): 47 | actual = sd_lazy[k] 48 | expected = sd[k] 49 | torch.testing.assert_close(actual._load_tensor(), expected) 50 | 51 | 52 | def test_incremental_write(tmp_path, lit_llama): 53 | import lit_llama.utils 54 | 55 | sd = {str(k): torch.randn(5, 10) for k in range(3)} 56 | sd_expected = {k: v.clone() for k, v in sd.items()} 57 | fn = str(tmp_path / "test.pt") 58 | with lit_llama.utils.incremental_save(fn) as f: 59 | sd["0"] = f.store_early(sd["0"]) 60 | sd["2"] = f.store_early(sd["2"]) 61 | f.save(sd) 62 | sd_actual = torch.load(fn) 63 | assert sd_actual.keys() == sd_expected.keys() 64 | for k, v_expected in sd_expected.items(): 65 | v_actual = sd_actual[k] 66 | torch.testing.assert_close(v_expected, v_actual) 67 | 68 | 69 | def test_find_multiple(lit_llama): 70 | from lit_llama.utils import find_multiple 71 | 72 | assert find_multiple(17, 5) == 20 73 | assert find_multiple(30, 7) == 35 74 | assert find_multiple(10, 2) == 10 75 | assert find_multiple(5, 10) == 10 76 | --------------------------------------------------------------------------------
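Illustrative note on the instruction-tuning preparation shown above: the prepare_sample function builds "input_ids" from the full prompt-plus-response and then masks the prompt portion of the labels so that only response tokens contribute to the loss. The minimal, self-contained sketch below reproduces just that masking step with made-up token ids; IGNORE_INDEX is an assumed sentinel here (the script defines its own constant), and no real Tokenizer is involved.

# Sketch of the prompt-masking step (toy stand-ins, not output of the real Tokenizer)
import torch

IGNORE_INDEX = -1  # assumed sentinel value for tokens excluded from the loss

# pretend tokenization of "prompt" and "prompt + response" (EOS appended to the latter)
encoded_full_prompt = torch.tensor([1, 11, 12, 13])
encoded_full_prompt_and_response = torch.tensor([1, 11, 12, 13, 21, 22, 2])

# labels start as a copy of the full sequence, then the prompt tokens are masked out
labels = encoded_full_prompt_and_response.clone()
labels[: len(encoded_full_prompt)] = IGNORE_INDEX

print(labels)  # tensor([-1, -1, -1, -1, 21, 22,  2]) -> loss is computed on the response only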
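The sharding behaviour asserted in test_sharded_packed_dataset above is consistent with a simple round-robin split of the chunk filenames across processes. The sketch below is an assumption inferred from the asserted filename lists, not the actual PackedDataset implementation; it only demonstrates the slicing pattern the test expects.

# Round-robin sharding pattern implied by test_sharded_packed_dataset (illustrative only)
filenames = [str(i) for i in range(10)]

def shard(filenames, process_rank, num_processes):
    # every num_processes-th file, starting at this process's rank
    return filenames[process_rank::num_processes]

assert shard(filenames, 0, 2) == ["0", "2", "4", "6", "8"]
assert shard(filenames, 1, 2) == ["1", "3", "5", "7", "9"]
assert shard(filenames, 2, 3) == ["2", "5", "8"]  # uneven split: later ranks get fewer files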