├── .github
│   ├── CODEOWNERS
│   ├── azure-gpu-tests.yml
│   └── workflows
│       └── cpu-tests.yml
├── .gitignore
├── LICENSE
├── README.md
├── evaluate
│   ├── adapter.py
│   ├── adapter_v2.py
│   ├── full.py
│   └── lora.py
├── finetune
│   ├── adapter.py
│   ├── adapter_v2.py
│   ├── full.py
│   └── lora.py
├── generate.py
├── generate
│   ├── adapter.py
│   ├── adapter_v2.py
│   ├── full.py
│   └── lora.py
├── howto
│   ├── convert_lora_weights.md
│   ├── customize_paths.md
│   ├── download_weights.md
│   ├── finetune_adapter.md
│   ├── finetune_adapter_v2.md
│   ├── finetune_full.md
│   ├── finetune_lora.md
│   ├── inference.md
│   ├── tpus.md
│   ├── train_redpajama.md
│   └── unstructured_dataset.md
├── lit_llama
│   ├── __init__.py
│   ├── adapter.py
│   ├── adapter_v2.py
│   ├── lora.py
│   ├── model.py
│   ├── packed_dataset.py
│   ├── quantization.py
│   ├── tokenizer.py
│   └── utils.py
├── pretrain
│   ├── redpajama.py
│   └── shakespeare.py
├── pyproject.toml
├── quantize
│   └── gptq.py
├── scripts
│   ├── convert_checkpoint.py
│   ├── convert_hf_checkpoint.py
│   ├── convert_lora_weights.py
│   ├── download.py
│   ├── prepare_alpaca.py
│   ├── prepare_any_text.py
│   ├── prepare_dolly.py
│   ├── prepare_redpajama.py
│   └── prepare_shakespeare.py
├── setup.py
└── tests
    ├── conftest.py
    ├── test_adapter.py
    ├── test_adapter_v2.py
    ├── test_generate.py
    ├── test_lora.py
    ├── test_model.py
    ├── test_packed_dataset.py
    ├── test_prepare_redpajama.py
    ├── test_prepare_shakespeare.py
    ├── test_rmsnorm.py
    ├── test_rope.py
    └── test_utils.py
/.github/CODEOWNERS:
--------------------------------------------------------------------------------
1 | * @awaelchli @carmocca @lantiga
2 |
--------------------------------------------------------------------------------
/.github/azure-gpu-tests.yml:
--------------------------------------------------------------------------------
1 | # Python package
2 | # Create and test a Python package on multiple Python versions.
3 | # Add steps that analyze code, save the dist with the build record, publish to a PyPI-compatible index, and more:
4 | # https://docs.microsoft.com/azure/devops/pipelines/languages/python
5 |
6 | trigger:
7 | tags:
8 | include:
9 | - '*'
10 | branches:
11 | include:
12 | - "main"
13 | - "refs/tags/*"
14 |
15 | pr:
16 | branches:
17 | include:
18 | - "main"
19 |
20 | jobs:
21 | - job: testing
22 | # how long to run the job before automatically cancelling
23 | timeoutInMinutes: "20"
24 | # how much time to give 'run always even if cancelled tasks' before stopping them
25 | cancelTimeoutInMinutes: "2"
26 | pool: "lit-rtx-3090"
27 | variables:
28 | DEVICES: $( python -c 'print("$(Agent.Name)".split("_")[-1])' )
29 | container:
30 | image: "pytorchlightning/pytorch_lightning:base-cuda-py3.10-torch2.0-cuda11.7.1"
31 | options: "--gpus=all --shm-size=8gb"
32 | workspace:
33 | clean: all
34 | steps:
35 |
36 | - bash: |
37 | echo "##vso[task.setvariable variable=CUDA_VISIBLE_DEVICES]$(DEVICES)"
38 | cuda_ver=$(python -c "import torch ; print(''.join(map(str, torch.version.cuda.split('.')[:2])))")
39 | echo "##vso[task.setvariable variable=CUDA_VERSION_MM]$cuda_ver"
40 | displayName: 'set env. vars'
41 |
42 | - bash: |
43 | echo $CUDA_VISIBLE_DEVICES
44 | echo $CUDA_VERSION_MM
45 | lspci | egrep 'VGA|3D'
46 | whereis nvidia
47 | nvidia-smi
48 | which python && which pip
49 | python --version && pip --version && pip list
50 | python -c "import torch ; mgpu = torch.cuda.device_count() ; assert mgpu == 2, f'GPU: {mgpu}'"
51 | displayName: 'Image info & NVIDIA'
52 |
53 | - script: pip install ".[all]" "pytest"
54 | displayName: 'Install dependencies'
55 |
56 | - bash: pytest -v --durations=10 --disable-pytest-warnings --strict-markers --color=yes
57 | displayName: 'Testing'
58 |
--------------------------------------------------------------------------------
/.github/workflows/cpu-tests.yml:
--------------------------------------------------------------------------------
1 | name: CPU tests
2 |
3 | on:
4 | push:
5 | branches: [main]
6 | pull_request:
7 | branches: [main]
8 |
9 | concurrency:
10 | group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref }}
11 | cancel-in-progress: ${{ github.ref != 'refs/heads/main' }}
12 |
13 | defaults:
14 | run:
15 | shell: bash
16 |
17 | jobs:
18 | pytester:
19 | runs-on: ${{ matrix.os }}
20 | strategy:
21 | fail-fast: false
22 | matrix:
23 | os: ["ubuntu-22.04", "macos-13", "windows-2022"]
24 | python-version: ["3.10"]
25 | pkg-install: ["no", "yes"]
26 | timeout-minutes: 15
27 |
28 | steps:
29 | - uses: actions/checkout@v3
30 |
31 | - name: Set up Python ${{ matrix.python-version }}
32 | uses: actions/setup-python@v4
33 | with:
34 | python-version: ${{ matrix.python-version }}
35 | cache: 'pip'
36 | cache-dependency-path: |
37 | pyproject.toml
38 | setup.py
39 |
40 | - name: Install package & dependencies
41 | run: |
42 | pip install ".[all]" pytest
43 | pip list
44 |
45 | - name: Drop package itself
46 | if: matrix.pkg-install == 'no'
47 | run: pip uninstall -y lit-llama
48 |
49 | - name: Run tests
50 | run: pytest -v --durations=10
51 |
52 |
53 | testing-guardian:
54 | runs-on: ubuntu-latest
55 | needs: pytester
56 | if: always()
57 | steps:
58 | - run: echo "${{ needs.pytester.result }}"
59 | - name: failing...
60 | if: needs.pytester.result == 'failure'
61 | run: exit 1
62 | - name: cancelled or skipped...
63 | if: contains(fromJSON('["cancelled", "skipped"]'), needs.pytester.result)
64 | timeout-minutes: 1
65 | run: sleep 90
66 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__
2 | .idea
3 | .DS_Store
4 | *.egg-info
5 | build
6 |
7 | # data
8 | data
9 | checkpoints
10 | out
11 | !data/shakespeare/prepare.py
12 | wandb
13 |
14 | # downloaded by our tests
15 | original_model.py
16 | original_adapter.py
17 |
18 | .ruff_cache/
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |

3 |
4 | # ⚡ Lit-LLaMA ️
5 |
6 | [Build status](https://dev.azure.com/Lightning-AI/lit%20Models/_build/latest?definitionId=49&branchName=main) [License](https://github.com/Lightning-AI/lit-llama/blob/master/LICENSE) [Discord](https://discord.gg/VptPCZkGNa)
7 |
8 |
9 | ⚠️ Warning: Not Actively Maintained
10 |
11 | This repository is no longer actively maintained. For a more up-to-date alternative, please visit the LitGPT project:
12 | https://github.com/Lightning-AI/litgpt, which serves as the successor to this repository.
13 |
14 | Feel free to explore, reuse, or fork, but be aware that no further updates or support will be provided.
15 |
16 |
17 |

18 |
19 |
20 |
21 | # ⚡ Lit-LLaMA ️
22 | Independent implementation of [LLaMA](https://github.com/facebookresearch/llama) pretraining, finetuning, and inference code that is fully open source under the **Apache 2.0 license.**
23 |
24 | This implementation builds on [nanoGPT](https://github.com/karpathy/nanoGPT).
25 |
26 | The open-source code in this repository works with the original LLaMA weights that are distributed by Meta under a [research-only license](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md#model-details).
27 |
28 | ## Looking for LLaMA 2?
29 |
30 | Meta AI has since released LLaMA 2. Additionally, new Apache 2.0 licensed weights are being released as part of the [Open LLaMA project](https://github.com/openlm-research/open_llama).
31 |
32 | To run LLaMA 2 weights, Open LLaMA weights, or Vicuna weights (among other LLaMA-like checkpoints), **check out the [Lit-GPT repository](https://github.com/Lightning-AI/lit-gpt)**.
33 |
34 | ## Why?
35 |
36 | We believe that AI should be fully open source and part of the collective knowledge.
37 |
38 | The original [LLaMA code](https://github.com/facebookresearch/llama) is [GPL licensed](https://github.com/facebookresearch/llama/blob/main/LICENSE) which means any project using it must also be released under GPL.
39 |
40 | This "taints" any other code and prevents integration with the rest of the ecosystem.
41 |
42 | **Lit-LLaMA solves that for good.**
43 |
44 |
45 |
46 | ## Design principles
47 | **Lit-LLaMA** is:
48 |
49 | - **Simple:** Single-file implementation without boilerplate.
50 | - **Correct:** Numerically equivalent to the original model.
51 | - **Optimized:** Runs on consumer hardware or at scale.
52 | - **Open-source:** No strings attached.
53 |
54 | ## Get involved!
55 | [Join our Discord](https://discord.gg/VptPCZkGNa) to build high-performance, truly open-source models for the common benefit of the community.
56 |
57 |
58 |
59 | ## Setup
60 |
61 | Clone the repo
62 |
63 | ```bash
64 | git clone https://github.com/Lightning-AI/lit-llama
65 | cd lit-llama
66 | ```
67 |
68 | Install dependencies
69 |
70 | ```bash
71 | pip install -e ".[all]"
72 | ```
73 |
74 | You are all set! 🎉
75 |
76 |
77 |
78 | ## Use the model
79 |
80 | To generate text predictions, you need to download the model weights. **If you don't have them, check out our [guide](howto/download_weights.md).**
81 |
82 | Run inference:
83 |
84 | ```bash
85 | python generate.py --prompt "Hello, my name is"
86 | ```
87 |
88 | This will run the 7B model and require ~26 GB of GPU memory (A100 GPU).
89 |
90 | [Full guide for generating samples from the model](howto/inference.md).
91 |
92 | ### Run Lit-LLaMA on consumer devices
93 |
94 | On GPUs with `bfloat16` support, the `generate.py` script will automatically convert the weights and consume about 14 GB.
95 | For GPUs with less memory, or ones that don't support `bfloat16`, enable quantization (`--quantize llm.int8`):
96 |
97 | ```bash
98 | python generate.py --quantize llm.int8 --prompt "Hello, my name is"
99 | ```
100 |
101 | See `python generate.py --help` for more options.
102 |
103 | You can also use GPTQ-style int4 quantization, but this requires converting the weights first:
104 |
105 | ```bash
106 | python quantize/gptq.py --output_path checkpoints/lit-llama/7B/llama-gptq.4bit.pth --dtype bfloat16 --quantize gptq.int4
107 | ```
108 |
109 | GPTQ-style int4 quantization brings GPU usage down to about 5 GB. Since only the weights of the Linear layers are quantized, it is useful to also use `--dtype bfloat16` even with quantization enabled.
110 |
111 | Once the quantized checkpoint has been generated, generation works as usual with `--quantize gptq.int4` and the newly created checkpoint file:
112 |
113 | ```bash
114 | python generate.py --quantize gptq.int4 --checkpoint_path checkpoints/lit-llama/7B/llama-gptq.4bit.pth
115 | ```
116 |
117 | [Full guide for generating samples from the model](howto/inference.md).
118 |
119 | ## Finetune the model
120 |
121 | We provide simple training scripts in `finetune/lora.py` and `finetune/adapter.py` that instruction-tune a pretrained model on the [Alpaca](https://github.com/tatsu-lab/stanford_alpaca) dataset using the techniques of [LoRA](https://arxiv.org/abs/2106.09685) and [Adapter](https://arxiv.org/abs/2303.16199).
122 |
123 | 1. Download the data and generate an instruction tuning dataset:
124 |
125 | ```bash
126 | python scripts/prepare_alpaca.py
127 | ```
128 |
129 | 2. Run the finetuning script
130 |
131 | ```bash
132 | python finetune/lora.py
133 | ```
134 | or
135 | ```bash
136 | python finetune/adapter.py
137 | ```
138 |
139 | It is expected that you have downloaded the pretrained weights as described above.
140 | Finetuning requires at least one GPU with ~24 GB of memory (e.g. an RTX 3090). Follow the instructions in the script to fit the model into your GPU memory efficiently.
141 | Note: For some GPU models you might need to set `torch.backends.cuda.enable_flash_sdp(False)` (see comments at the top of the script).
142 |
143 | More details about each finetuning method and how you can apply it to your own data can be found in our technical how-to guides.
144 |
145 | ### Finetuning How-To Guides
146 |
147 | These technical tutorials illustrate how to run the finetuning code.
148 |
149 | - [Finetune with LoRA](howto/finetune_lora.md)
150 | - [Finetune with Adapters](howto/finetune_adapter.md)
151 |
152 | ### Understanding Finetuning -- Conceptual Tutorials
153 |
154 | Looking for conceptual tutorials and explanations? We have some additional articles below:
155 |
156 | - [Understanding Parameter-Efficient Finetuning of Large Language Models: From Prefix Tuning to LLaMA-Adapters](https://lightning.ai/pages/community/article/understanding-llama-adapters/)
157 |
158 | ## Pre-training
159 |
160 | We provide a simple training script based on Fabric if you want to venture into pre-training on RedPajama, a reproduction of the original LLaMA dataset.
161 | Conversion scripts for our optimized streaming `PackedDataset` are included.
162 |
163 | Follow this guide to start pre-training on the RedPajama dataset:
164 |
165 | - [Pretrain on RedPajama](howto/train_redpajama.md)
166 |
167 | ## Get involved!
168 |
169 | We are on a quest towards fully open source AI.
170 |
171 |
172 |
173 | Join us and start contributing, especially in the following areas:
174 |
175 | - [ ] [Pre-training](https://github.com/Lightning-AI/lit-llama/labels/pre-training)
176 | - [ ] [Fine-tuning (full and LoRA)](https://github.com/Lightning-AI/lit-llama/labels/fine-tuning)
177 | - [ ] [Quantization](https://github.com/Lightning-AI/lit-llama/labels/quantization)
178 | - [ ] [Sparsification](https://github.com/Lightning-AI/lit-llama/labels/sparsification)
179 |
180 | Look at the scripts in `pretrain/` and `finetune/` for a starting point towards pre-training / fine-tuning using [Lightning Fabric](https://lightning.ai/docs/fabric/stable/).
181 |
182 | We welcome all individual contributors, regardless of their level of experience or hardware. Your contributions are valuable, and we are excited to see what you can accomplish in this collaborative and supportive environment.
183 |
184 | Unsure about contributing? Check out our [Contributing to Lit-LLaMA: A Hitchhiker’s Guide to the Quest for Fully Open-Source AI](https://lightning.ai/pages/community/tutorial/contributing-to-lit-llama-a-hitchhikers-guide-to-the-quest-for-fully-open-source-ai/) guide.
185 |
186 | Don't forget to [join our Discord](https://discord.gg/VptPCZkGNa)!
187 |
188 | ## Acknowledgements
189 |
190 | - [@karpathy](https://github.com/karpathy) for [nanoGPT](https://github.com/karpathy/nanoGPT)
191 | - [@FacebookResearch](https://github.com/facebookresearch) for the original [LLaMA implementation](https://github.com/facebookresearch/llama)
192 | - [@TimDettmers](https://github.com/TimDettmers) for [bitsandbytes](https://github.com/TimDettmers/bitsandbytes)
193 | - [@Microsoft](https://github.com/microsoft) for [LoRA](https://github.com/microsoft/LoRA)
194 | - [@IST-DASLab](https://github.com/IST-DASLab) for [GPTQ](https://github.com/IST-DASLab/gptq)
195 |
196 | ## License
197 |
198 | Lit-LLaMA is released under the [Apache 2.0](https://github.com/Lightning-AI/lit-llama/blob/main/LICENSE) license.
199 |
--------------------------------------------------------------------------------
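The README above drives everything through the CLI. For readers who want to call the same code from Python, here is a minimal sketch of the generation flow, modelled on `generate.py` further below in this repository; it assumes the repo root is the working directory, the default checkpoint layout from the README, and a bfloat16-capable GPU:

```python
# Programmatic version of `python generate.py --prompt "Hello, my name is"`.
from pathlib import Path

import lightning as L

from generate import generate  # the top-level generate.py shown later in this dump
from lit_llama import LLaMA, Tokenizer
from lit_llama.utils import lazy_load, llama_model_lookup

checkpoint_path = Path("checkpoints/lit-llama/7B/lit-llama.pth")
tokenizer_path = Path("checkpoints/lit-llama/tokenizer.model")

fabric = L.Fabric(devices=1, precision="bf16-true")  # use "32-true" without bf16 support

with lazy_load(checkpoint_path) as checkpoint:
    name = llama_model_lookup(checkpoint)  # infers the model size, e.g. "7B", from the weights
    with fabric.init_module(empty_init=True):
        model = LLaMA.from_name(name)
    model.load_state_dict(checkpoint)

model.eval()
model = fabric.setup(model)

tokenizer = Tokenizer(tokenizer_path)
encoded = tokenizer.encode("Hello, my name is", bos=True, eos=False, device=fabric.device)
output = generate(model, encoded, max_new_tokens=50, temperature=0.8, top_k=200)
print(tokenizer.decode(output))
```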
/evaluate/adapter.py:
--------------------------------------------------------------------------------
1 | # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
2 |
3 | # This mimics GPTQ's evaluation metrics: https://github.com/IST-DASLab/gptq/
4 | # Thanks to E. Frantar et al GPTQ: Accurate Post-training Compression for GPT, arXiv:2210.17323
5 | import math
6 | import sys
7 | import time
8 | from pathlib import Path
9 | from typing import Optional
10 |
11 | import lightning as L
12 | import torch
13 | import tqdm
14 |
15 | # support running without installing as a package
16 | wd = Path(__file__).parent.parent.resolve()
17 | sys.path.append(str(wd))
18 |
19 | from lit_llama import Tokenizer
20 | from lit_llama.adapter import LLaMA
21 | from lit_llama.utils import EmptyInitOnDevice, lazy_load, llama_model_lookup
22 | from scripts.prepare_alpaca import generate_prompt
23 |
24 | from datasets import load_dataset
25 |
26 | instruction_tuning = True
27 |
28 |
29 | def load_eval_data(dataset_name: str) -> str:
30 | # this mimics gptq datautils
31 | if dataset_name == "wikitext":
32 | # traindata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='train')
33 | testdata = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")
34 | testdata = "\n\n".join(testdata["text"])
35 | elif dataset_name == "ptb":
36 | testdata = load_dataset("ptb_text_only", "penn_treebank", split="test")
37 | testdata = "\n\n".join(testdata["sentence"])
38 | elif dataset_name == "c4":
39 | testdata = load_dataset(
40 | "allenai/c4",
41 | "allenai--c4",
42 | data_files={"validation": "en/c4-validation.00000-of-00008.json.gz"},
43 | split="validation",
44 | )
45 | testdata = " ".join(testdata[:1100]["text"])
46 |
47 | else:
48 | raise ValueError("invalid dataset name (wikitext, ptb, c4 are allowed)")
49 | return testdata
50 |
51 |
52 | @torch.inference_mode()
53 | def main(
54 | datasets: str = "wikitext,ptb,c4",
55 | *,
56 | # compilation fails as it does not support torch.complex64 for RoPE
57 | # compile: bool = False,
58 | accelerator: str = "auto",
59 | adapter_path: Path = Path("out/adapter/alpaca/lit-llama-adapter-finetuned.pth"),
60 | checkpoint_path: Path = Path("checkpoints/lit-llama/7B/lit-llama.pth"),
61 | tokenizer_path: Path = Path("checkpoints/lit-llama/tokenizer.model"),
62 | dtype: str = "float32",
63 | quantize: Optional[str] = None,
64 | ) -> None:
65 | """Generates text samples based on a pre-trained LLaMA model and tokenizer.
66 |
67 | Args:
68 | datasets: The datasets to use as a comma separated string
69 | # compile: Whether to compile the model.
70 | accelerator: The hardware to run on. Possible choices are:
71 | ``"cpu"``, ``"cuda"``, ``"mps"``, ``"gpu"``, ``"tpu"``, ``"auto"``.
72 | adapter_path: Path to the checkpoint with trained adapter weights, which are the output of
73 | `finetune_adapter.py`.
74 | checkpoint_path: The checkpoint path to load.
75 | tokenizer_path: The tokenizer path to load.
76 | dtype: The tensor dtype for choosing the floating-point precision
77 | quantize: Whether to quantize the model, and which method to use:
78 | ``"llm.int8"``: LLM.int8() mode,
79 | ``"gptq.int4"``: GPTQ 4-bit mode.
80 | """
81 | assert adapter_path.is_file()
82 | assert checkpoint_path.is_file()
83 | assert tokenizer_path.is_file()
84 |
85 | fabric = L.Fabric(accelerator=accelerator, devices=1)
86 |
87 | dt = getattr(torch, dtype, None)
88 | if not isinstance(dt, torch.dtype):
89 | raise ValueError(f"{dtype} is not a valid dtype.")
90 | dtype = dt
91 |
92 | print("Loading model ...", file=sys.stderr)
93 | t0 = time.time()
94 | with lazy_load(checkpoint_path) as pretrained_checkpoint, lazy_load(adapter_path) as adapter_checkpoint:
95 | name = llama_model_lookup(pretrained_checkpoint)
96 |
97 | with EmptyInitOnDevice(
98 | device=fabric.device, dtype=dtype, quantization_mode=quantize
99 | ):
100 | model = LLaMA.from_name(name)
101 |
102 | # 1. Load the pretrained weights
103 | model.load_state_dict(pretrained_checkpoint, strict=False)
104 | # 2. Load the fine-tuned adapter weights
105 | model.load_state_dict(adapter_checkpoint, strict=False)
106 |
107 | print(f"Time to load model: {time.time() - t0:.02f} seconds.", file=sys.stderr)
108 |
109 | model.eval()
110 |
111 | # if compile:
112 | # model = torch.compile(model)
113 |
114 | total_toks = 0
115 | model = fabric.setup_module(model)
116 |
117 | tokenizer = Tokenizer(tokenizer_path)
118 |
119 | for dsname in datasets.split(","):
120 | test_string = load_eval_data(dsname)
121 |
122 | if instruction_tuning:
123 | sample = {"instruction": test_string, "input": input}
124 | test_string = generate_prompt(sample)
125 |
126 | encoded_text = tokenizer.encode(
127 | test_string, bos=True, eos=False, device=fabric.device
128 | )
129 | encoded_text = encoded_text[
130 | None, : 256 * model.config.block_size
131 | ] # add batch dimension, trim like gptq implementation
132 | t0 = time.perf_counter()
133 |
134 | nlls = 0
135 | toks = 0
136 | block_size = 2048 # this is for compat with gptq, and indeed we get much worse beyond this (https://github.com/facebookresearch/llama/blob/57b0eb62de0636e75af471e49e2f1862d908d9d8/llama/model.py#L30)
137 | for i in tqdm.tqdm(range(0, encoded_text.shape[1], block_size)):
138 | inp = encoded_text[:, i : i + block_size]
139 | logits = model(inp)[0]
140 | nll = torch.nn.functional.cross_entropy(
141 | logits[:-1], inp[0, 1:].to(dtype=torch.long), reduction="sum"
142 | )
143 | toks += inp.size(1) - 1
144 | nlls += nll.item()
145 |
146 | print(encoded_text.shape, logits.shape)
147 | ppl = math.exp(nlls / toks)
148 | print(f"Perplexity on {dsname}: {ppl:.2f}")
149 | total_toks += toks
150 |
151 | t = time.perf_counter() - t0
152 | print(
153 | f"\n\nTime for inference: {t:.02f} sec total, {total_toks / t:.02f} tokens/sec",
154 | file=sys.stderr,
155 | )
156 | print(
157 | f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB",
158 | file=sys.stderr,
159 | )
160 |
161 |
162 | if __name__ == "__main__":
163 | from jsonargparse import CLI
164 |
165 | torch.set_float32_matmul_precision("high")
166 | CLI(main)
167 |
--------------------------------------------------------------------------------
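The evaluation loop in `evaluate/adapter.py` (and its siblings below) boils down to summing the token-level negative log-likelihood over shifted positions and exponentiating the per-token average: `ppl = exp(nlls / toks)`. A self-contained toy sketch of just that arithmetic, with random logits standing in for the model output:

```python
import math

import torch
import torch.nn.functional as F

torch.manual_seed(0)
vocab_size, seq_len = 32000, 16

logits = torch.randn(seq_len, vocab_size)       # stand-in for `model(inp)[0]`
tokens = torch.randint(vocab_size, (seq_len,))  # stand-in for `inp[0]`

# position t predicts token t+1: drop the last logit row and the first token
nll = F.cross_entropy(logits[:-1], tokens[1:], reduction="sum")
toks = seq_len - 1

ppl = math.exp(nll.item() / toks)
print(f"Perplexity: {ppl:.2f}")  # very high here, since random logits carry no information
```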
/evaluate/adapter_v2.py:
--------------------------------------------------------------------------------
1 | # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
2 |
3 | # This mimics GPTQ's evaluation metrics: https://github.com/IST-DASLab/gptq/
4 | # Thanks to E. Frantar et al GPTQ: Accurate Post-training Compression for GPT, arXiv:2210.17323
5 | import math
6 | import sys
7 | import time
8 | from pathlib import Path
9 | from typing import Optional
10 |
11 | import lightning as L
12 | import torch
13 | import tqdm
14 |
15 | # support running without installing as a package
16 | wd = Path(__file__).parent.parent.resolve()
17 | sys.path.append(str(wd))
18 |
19 | from lit_llama import Tokenizer
20 | from lit_llama.adapter import LLaMA
21 | from lit_llama.utils import EmptyInitOnDevice, lazy_load, llama_model_lookup
22 | from lit_llama.adapter_v2 import add_adapter_v2_parameters_to_linear_layers
23 | from scripts.prepare_alpaca import generate_prompt
24 |
25 | from datasets import load_dataset
26 |
27 |
28 | def load_eval_data(dataset_name: str) -> str:
29 | # this mimics gptq datautils
30 | if dataset_name == "wikitext":
31 | # traindata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='train')
32 | testdata = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")
33 | testdata = "\n\n".join(testdata["text"])
34 | elif dataset_name == "ptb":
35 | testdata = load_dataset("ptb_text_only", "penn_treebank", split="test")
36 | testdata = "\n\n".join(testdata["sentence"])
37 | elif dataset_name == "c4":
38 | testdata = load_dataset(
39 | "allenai/c4",
40 | "allenai--c4",
41 | data_files={"validation": "en/c4-validation.00000-of-00008.json.gz"},
42 | split="validation",
43 | )
44 | testdata = " ".join(testdata[:1100]["text"])
45 |
46 | else:
47 | raise ValueError("invalid dataset name (wikitext, ptb, c4 are allowed)")
48 | return testdata
49 |
50 |
51 | @torch.inference_mode()
52 | def main(
53 | datasets: str = "wikitext,ptb,c4",
54 | *,
55 | accelerator: str = "auto",
56 | adapter_path: Path = Path("out/adapter_v2/alpaca/lit-llama-adapter-finetuned.pth"),
57 | checkpoint_path: Path = Path("checkpoints/lit-llama/7B/lit-llama.pth"),
58 | tokenizer_path: Path = Path("checkpoints/lit-llama/tokenizer.model"),
59 | dtype: str = "float32",
60 | quantize: Optional[str] = None,
61 | ) -> None:
62 | """Generates text samples based on a pre-trained LLaMA model and tokenizer.
63 |
64 | Args:
65 | datasets: The datasets to use as a comma separated string
66 | accelerator: The hardware to run on. Possible choices are:
67 | ``"cpu"``, ``"cuda"``, ``"mps"``, ``"gpu"``, ``"tpu"``, ``"auto"``.
68 | adapter_path: Path to the checkpoint with trained adapter weights, which are the output of
69 | `finetune_adapter_v2.py`.
70 | checkpoint_path: The checkpoint path to load.
71 | tokenizer_path: The tokenizer path to load.
72 | dtype: The tensor dtype for choosing the floating-point precision
73 | quantize: Whether to quantize the model, and which method to use:
74 | ``"llm.int8"``: LLM.int8() mode,
75 | ``"gptq.int4"``: GPTQ 4-bit mode.
76 | """
77 | assert adapter_path.is_file()
78 | assert checkpoint_path.is_file()
79 | assert tokenizer_path.is_file()
80 |
81 | fabric = L.Fabric(accelerator=accelerator, devices=1)
82 |
83 | dt = getattr(torch, dtype, None)
84 | if not isinstance(dt, torch.dtype):
85 | raise ValueError(f"{dtype} is not a valid dtype.")
86 | dtype = dt
87 |
88 | print("Loading model ...", file=sys.stderr)
89 | t0 = time.time()
90 | with lazy_load(checkpoint_path) as pretrained_checkpoint, lazy_load(adapter_path) as adapter_checkpoint:
91 | name = llama_model_lookup(pretrained_checkpoint)
92 |
93 | with EmptyInitOnDevice(
94 | device=fabric.device, dtype=dtype, quantization_mode=quantize
95 | ):
96 | model = LLaMA.from_name(name)
97 | add_adapter_v2_parameters_to_linear_layers(model)
98 |
99 | # 1. Load the pretrained weights
100 | model.load_state_dict(pretrained_checkpoint, strict=False)
101 | # 2. Load the fine-tuned adapter weights
102 | model.load_state_dict(adapter_checkpoint, strict=False)
103 |
104 | print(f"Time to load model: {time.time() - t0:.02f} seconds.", file=sys.stderr)
105 |
106 | model.eval()
107 |
108 | # if compile:
109 | # model = torch.compile(model)
110 |
111 | total_toks = 0
112 | model = fabric.setup_module(model)
113 |
114 | tokenizer = Tokenizer(tokenizer_path)
115 |
116 | for dsname in datasets.split(","):
117 | test_string = load_eval_data(dsname)
118 |
119 | sample = {"instruction": test_string, "input": input}
120 | test_string = generate_prompt(sample)
121 |
122 | encoded_text = tokenizer.encode(
123 | test_string, bos=True, eos=False, device=fabric.device
124 | )
125 | encoded_text = encoded_text[
126 | None, : 256 * model.config.block_size
127 | ] # add batch dimension, trim like gptq implementation
128 | t0 = time.perf_counter()
129 |
130 | nlls = 0
131 | toks = 0
132 |
133 | block_size = 2048 # this is for compat with gptq, and indeed we get much worse beyond this (https://github.com/facebookresearch/llama/blob/57b0eb62de0636e75af471e49e2f1862d908d9d8/llama/model.py#L30)
134 | for i in tqdm.tqdm(range(0, encoded_text.shape[1], block_size)):
135 | inp = encoded_text[:, i : i + block_size]
136 | logits = model(inp)[0]
137 | nll = torch.nn.functional.cross_entropy(
138 | logits[:-1], inp[0, 1:].to(dtype=torch.long), reduction="sum"
139 | )
140 | toks += inp.size(1) - 1
141 | nlls += nll.item()
142 |
143 | print(encoded_text.shape, logits.shape)
144 | ppl = math.exp(nlls / toks)
145 | print(f"Perplexity on {dsname}: {ppl:.2f}")
146 | total_toks += toks
147 |
148 | t = time.perf_counter() - t0
149 | print(
150 | f"\n\nTime for inference: {t:.02f} sec total, {total_toks / t:.02f} tokens/sec",
151 | file=sys.stderr,
152 | )
153 | print(
154 | f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB",
155 | file=sys.stderr,
156 | )
157 |
158 |
159 | if __name__ == "__main__":
160 | from jsonargparse import CLI
161 |
162 | torch.set_float32_matmul_precision("high")
163 | CLI(main)
164 |
--------------------------------------------------------------------------------
/evaluate/full.py:
--------------------------------------------------------------------------------
1 | # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
2 |
3 | # This mimics GPTQ's evaluation metrics: https://github.com/IST-DASLab/gptq/
4 | # Thanks to E. Frantar et al GPTQ: Accurate Post-training Compression for GPT, arXiv:2210.17323
5 | import math
6 | import sys
7 | import time
8 | from pathlib import Path
9 | from typing import Optional
10 |
11 | import lightning as L
12 | import torch
13 | import tqdm
14 |
15 | # support running without installing as a package
16 | wd = Path(__file__).parent.parent.resolve()
17 | sys.path.append(str(wd))
18 |
19 | from lit_llama import LLaMA, Tokenizer
20 | from lit_llama.utils import EmptyInitOnDevice
21 |
22 | from datasets import load_dataset
23 |
24 |
25 | def load_eval_data(dataset_name: str) -> str:
26 | # this mimics gptq datautils
27 | if dataset_name == "wikitext":
28 | # traindata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='train')
29 | testdata = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")
30 | testdata = "\n\n".join(testdata["text"])
31 | elif dataset_name == "ptb":
32 | testdata = load_dataset("ptb_text_only", "penn_treebank", split="test")
33 | testdata = "\n\n".join(testdata["sentence"])
34 | elif dataset_name == "c4":
35 | testdata = load_dataset(
36 | "allenai/c4",
37 | "allenai--c4",
38 | data_files={"validation": "en/c4-validation.00000-of-00008.json.gz"},
39 | split="validation",
40 | )
41 | testdata = " ".join(testdata[:1100]["text"])
42 |
43 | else:
44 | raise ValueError("invalid dataset name (wikitext, ptb, c4 are allowed)")
45 | return testdata
46 |
47 |
48 | def main(
49 | datasets: str = "wikitext,ptb,c4",
50 | *,
51 | # compilation fails as it does not support torch.complex64 for RoPE
52 | # compile: bool = False,
53 | accelerator: str = "auto",
54 | checkpoint_path: Optional[Path] = None,
55 | tokenizer_path: Path = Path("checkpoints/lit-llama/tokenizer.model"),
56 | model_size: str = "7B",
57 | dtype: str = "float32",
58 | quantize: Optional[str] = None,
59 | ) -> None:
60 | """Generates text samples based on a pre-trained LLaMA model and tokenizer.
61 |
62 | Args:
63 | datasets: The datasets to use as a comma separated string
64 | # compile: Whether to compile the model.
65 | accelerator: The hardware to run on. Possible choices are:
66 | ``"cpu"``, ``"cuda"``, ``"mps"``, ``"gpu"``, ``"tpu"``, ``"auto"``.
67 | checkpoint_path: The checkpoint path to load.
68 | tokenizer_path: The tokenizer path to load.
69 | dtype: The tensor dtype for choosing the floating-point precision
70 | quantize: Whether to quantize the model, and which method to use:
71 | ``"llm.int8"``: LLM.int8() mode,
72 | ``"gptq.int4"``: GPTQ 4-bit mode.
73 | """
74 | if not checkpoint_path:
75 | checkpoint_path = Path(f"checkpoints/lit-llama/{model_size}/lit-llama.pth")
76 | assert checkpoint_path.is_file()
77 | assert tokenizer_path.is_file()
78 |
79 | fabric = L.Fabric(accelerator=accelerator, devices=1)
80 |
81 | dt = getattr(torch, dtype, None)
82 | if not isinstance(dt, torch.dtype):
83 | raise ValueError(f"{dtype} is not a valid dtype.")
84 | dtype = dt
85 |
86 | with EmptyInitOnDevice(
87 | device=fabric.device, dtype=dtype, quantization_mode=quantize
88 | ):
89 | print("Loading model ...", file=sys.stderr)
90 | t0 = time.time()
91 | model = LLaMA.from_name(model_size)
92 | checkpoint = torch.load(checkpoint_path)
93 | model.load_state_dict(checkpoint)
94 | print(f"Time to load model: {time.time() - t0:.02f} seconds.", file=sys.stderr)
95 |
96 | model.eval()
97 |
98 | # if compile:
99 | # model = torch.compile(model)
100 |
101 | total_toks = 0
102 | model = fabric.setup_module(model)
103 |
104 | tokenizer = Tokenizer(tokenizer_path)
105 |
106 | for dsname in datasets.split(","):
107 | test_string = load_eval_data(dsname)
108 | encoded_text = tokenizer.encode(
109 | test_string, bos=True, eos=False, device=fabric.device
110 | )
111 | encoded_text = encoded_text[
112 | None, : 256 * model.config.block_size
113 | ] # add batch dimension, trim like gptq implementation
114 | t0 = time.perf_counter()
115 |
116 | nlls = 0
117 | toks = 0
118 | with torch.inference_mode():
119 | block_size = 2048 # this is for compat with gptq, and indeed we get much worse beyond this (https://github.com/facebookresearch/llama/blob/57b0eb62de0636e75af471e49e2f1862d908d9d8/llama/model.py#L30)
120 | for i in tqdm.tqdm(range(0, encoded_text.shape[1], block_size)):
121 | inp = encoded_text[:, i : i + block_size]
122 | logits = model(inp)[0]
123 | nll = torch.nn.functional.cross_entropy(
124 | logits[:-1], inp[0, 1:].to(dtype=torch.long), reduction="sum"
125 | )
126 | toks += inp.size(1) - 1
127 | nlls += nll.item()
128 |
129 | print(encoded_text.shape, logits.shape)
130 | ppl = math.exp(nlls / toks)
131 | print(f"Perplexity on {dsname}: {ppl:.2f}")
132 | total_toks += toks
133 |
134 | t = time.perf_counter() - t0
135 | print(
136 | f"\n\nTime for inference: {t:.02f} sec total, {total_toks / t:.02f} tokens/sec",
137 | file=sys.stderr,
138 | )
139 | print(
140 | f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB",
141 | file=sys.stderr,
142 | )
143 |
144 |
145 | if __name__ == "__main__":
146 | from jsonargparse import CLI
147 |
148 | torch.set_float32_matmul_precision("high")
149 | CLI(main)
150 |
--------------------------------------------------------------------------------
/evaluate/lora.py:
--------------------------------------------------------------------------------
1 | # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
2 |
3 | # This mimics GPTQ's evaluation metrics: https://github.com/IST-DASLab/gptq/
4 | # Thanks to E. Frantar et al GPTQ: Accurate Post-training Compression for GPT, arXiv:2210.17323
5 | import math
6 | import sys
7 | import time
8 | from pathlib import Path
9 | from typing import Optional
10 |
11 | import lightning as L
12 | import torch
13 | import tqdm
14 |
15 | # support running without installing as a package
16 | wd = Path(__file__).parent.parent.resolve()
17 | sys.path.append(str(wd))
18 |
19 | from lit_llama import LLaMA, Tokenizer
20 | from lit_llama.utils import EmptyInitOnDevice, lazy_load, llama_model_lookup
21 | from lit_llama.lora import lora
22 | from scripts.prepare_alpaca import generate_prompt
23 |
24 | from datasets import load_dataset
25 |
26 | instruction_tuning = True
27 | lora_r = 8
28 | lora_alpha = 16
29 | lora_dropout = 0.05
30 |
31 |
32 | def load_eval_data(dataset_name: str) -> str:
33 | # this mimics gptq datautils
34 | if dataset_name == "wikitext":
35 | # traindata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='train')
36 | testdata = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")
37 | testdata = "\n\n".join(testdata["text"])
38 | elif dataset_name == "ptb":
39 | testdata = load_dataset("ptb_text_only", "penn_treebank", split="test")
40 | testdata = "\n\n".join(testdata["sentence"])
41 | elif dataset_name == "c4":
42 | testdata = load_dataset(
43 | "allenai/c4",
44 | "allenai--c4",
45 | data_files={"validation": "en/c4-validation.00000-of-00008.json.gz"},
46 | split="validation",
47 | )
48 | testdata = " ".join(testdata[:1100]["text"])
49 |
50 | else:
51 | raise ValueError("invalid dataset name (wikitext, ptb, c4 are allowed)")
52 | return testdata
53 |
54 |
55 | def main(
56 | datasets: str = "wikitext,ptb,c4",
57 | *,
58 | # compilation fails as it does not support torch.complex64 for RoPE
59 | # compile: bool = False,
60 | accelerator: str = "auto",
61 | lora_path: Path = Path("out/lora/alpaca/lit-llama-lora-finetuned.pth"),
62 | checkpoint_path: Path = Path("checkpoints/lit-llama/7B/lit-llama.pth"),
63 | tokenizer_path: Path = Path("checkpoints/lit-llama/tokenizer.model"),
64 | dtype: str = "float32",
65 | quantize: Optional[str] = None,
66 | ) -> None:
67 | """Generates text samples based on a pre-trained LLaMA model and tokenizer
68 | finetuned with LoRA.
69 |
70 | Args:
71 | datasets: The datasets to use as a comma separated string
72 | # compile: Whether to compile the model.
73 | accelerator: The hardware to run on. Possible choices are:
74 | ``"cpu"``, ``"cuda"``, ``"mps"``, ``"gpu"``, ``"tpu"``, ``"auto"``.
75 | lora_path: Path to the checkpoint with trained LoRA weights, which are the output of
76 | `finetune_lora.py`.
77 | checkpoint_path: The checkpoint path to load.
78 | tokenizer_path: The tokenizer path to load.
79 | dtype: The tensor dtype for choosing the floating-point precision
80 | quantize: Whether to quantize the model, and which method to use:
81 | ``"llm.int8"``: LLM.int8() mode,
82 | ``"gptq.int4"``: GPTQ 4-bit mode.
83 | """
84 | assert lora_path.is_file()
85 | assert checkpoint_path.is_file()
86 | assert tokenizer_path.is_file()
87 |
88 | if quantize is not None:
89 | raise NotImplementedError("Quantization in LoRA is not supported yet")
90 |
91 | fabric = L.Fabric(accelerator=accelerator, devices=1)
92 |
93 | dt = getattr(torch, dtype, None)
94 | if not isinstance(dt, torch.dtype):
95 | raise ValueError(f"{dtype} is not a valid dtype.")
96 | dtype = dt
97 |
98 | print("Loading model ...", file=sys.stderr)
99 | t0 = time.time()
100 |
101 | with lazy_load(checkpoint_path) as pretrained_checkpoint, lazy_load(lora_path) as lora_checkpoint:
102 | name = llama_model_lookup(pretrained_checkpoint)
103 |
104 | with EmptyInitOnDevice(
105 | device=fabric.device, dtype=dtype, quantization_mode=quantize
106 | ), lora(r=lora_r, alpha=lora_alpha, dropout=lora_dropout, enabled=True):
107 | model = LLaMA.from_name(name)
108 |
109 | # 1. Load the pretrained weights
110 | model.load_state_dict(pretrained_checkpoint, strict=False)
111 | # 2. Load the fine-tuned lora weights
112 | model.load_state_dict(lora_checkpoint, strict=False)
113 |
114 | print(f"Time to load model: {time.time() - t0:.02f} seconds.", file=sys.stderr)
115 |
116 | model.eval()
117 |
118 | # if compile:
119 | # model = torch.compile(model)
120 |
121 | total_toks = 0
122 | model = fabric.setup_module(model)
123 |
124 | tokenizer = Tokenizer(tokenizer_path)
125 |
126 | for dsname in datasets.split(","):
127 | test_string = load_eval_data(dsname)
128 |
129 | if instruction_tuning:
130 | sample = {"instruction": test_string, "input": input}
131 | test_string = generate_prompt(sample)
132 |
133 | encoded_text = tokenizer.encode(
134 | test_string, bos=True, eos=False, device=fabric.device
135 | )
136 | encoded_text = encoded_text[
137 | None, : 256 * model.config.block_size
138 | ] # add batch dimension, trim like gptq implementation
139 | t0 = time.perf_counter()
140 |
141 | nlls = 0
142 | toks = 0
143 | with torch.inference_mode():
144 | block_size = 2048 # this is for compat with gptq, and indeed we get much worse beyond this (https://github.com/facebookresearch/llama/blob/57b0eb62de0636e75af471e49e2f1862d908d9d8/llama/model.py#L30)
145 | for i in tqdm.tqdm(range(0, encoded_text.shape[1], block_size)):
146 | inp = encoded_text[:, i : i + block_size]
147 | logits = model(inp)[0]
148 | nll = torch.nn.functional.cross_entropy(
149 | logits[:-1], inp[0, 1:].to(dtype=torch.long), reduction="sum"
150 | )
151 | toks += inp.size(1) - 1
152 | nlls += nll.item()
153 |
154 | print(encoded_text.shape, logits.shape)
155 | ppl = math.exp(nlls / toks)
156 | print(f"Perplexity on {dsname}: {ppl:.2f}")
157 | total_toks += toks
158 |
159 | t = time.perf_counter() - t0
160 | print(
161 | f"\n\nTime for inference: {t:.02f} sec total, {total_toks / t:.02f} tokens/sec",
162 | file=sys.stderr,
163 | )
164 | print(
165 | f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB",
166 | file=sys.stderr,
167 | )
168 |
169 |
170 | if __name__ == "__main__":
171 | from jsonargparse import CLI
172 |
173 | torch.set_float32_matmul_precision("high")
174 | CLI(main)
175 |
--------------------------------------------------------------------------------
/finetune/full.py:
--------------------------------------------------------------------------------
1 | # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
2 |
3 | """
4 | Instruction-tuning on the Alpaca dataset using a regular finetuning procedure (updating all layers).
5 |
6 | Note: If you run into a CUDA error "Expected is_sm80 to be true, but got false", uncomment the line
7 | `torch.backends.cuda.enable_flash_sdp(False)` in the script below (see https://github.com/Lightning-AI/lit-llama/issues/101).
8 | """
9 | import sys
10 | from pathlib import Path
11 | import os
12 | import time
13 | from functools import partial
14 |
15 | import lightning as L
16 | from lightning.fabric.strategies import FSDPStrategy
17 | import numpy as np
18 | import torch
19 | from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
20 |
21 | # support running without installing as a package
22 | wd = Path(__file__).parent.parent.resolve()
23 | sys.path.append(str(wd))
24 |
25 | from generate import generate
26 | from lit_llama.model import Block, LLaMA, LLaMAConfig
27 | from lit_llama.tokenizer import Tokenizer
28 | from lit_llama.utils import save_model_checkpoint
29 | from scripts.prepare_alpaca import generate_prompt
30 |
31 |
32 | instruction_tuning = True
33 | eval_interval = 1000
34 | save_interval = 1000
35 | eval_iters = 100
36 | log_interval = 100
37 | devices = 4
38 |
39 | # Hyperparameters
40 | learning_rate = 3e-5
41 | batch_size = 128 // devices
42 | micro_batch_size = 4
43 | gradient_accumulation_iters = batch_size // micro_batch_size
44 | assert gradient_accumulation_iters > 0
45 | epoch_size = 50000 # train dataset size
46 | num_epochs = 5
47 | max_iters = num_epochs * (epoch_size // micro_batch_size) // devices
48 | weight_decay = 0.0
49 | block_size = 512
50 | warmup_iters = 100
51 |
52 |
53 | def main(
54 | data_dir: str = "data/alpaca",
55 | pretrained_path: str = "checkpoints/lit-llama/7B/lit-llama.pth",
56 | out_dir: str = "out/full/alpaca",
57 | ):
58 |
59 | auto_wrap_policy = partial(transformer_auto_wrap_policy, transformer_layer_cls={Block})
60 | strategy = FSDPStrategy(auto_wrap_policy=auto_wrap_policy, activation_checkpointing=Block, limit_all_gathers=True)
61 |
62 | fabric = L.Fabric(accelerator="cuda", devices=devices, precision="bf16-mixed", strategy=strategy)
63 | fabric.launch()
64 | fabric.seed_everything(1337 + fabric.global_rank)
65 |
66 | if fabric.global_rank == 0:
67 | os.makedirs(out_dir, exist_ok=True)
68 |
69 | train_data, val_data = load_datasets(data_dir=data_dir)
70 |
71 | config = LLaMAConfig.from_name("7B")
72 | config.block_size = block_size
73 |
74 | checkpoint = torch.load(pretrained_path)
75 |
76 | with fabric.device:
77 | torch.set_default_tensor_type(torch.HalfTensor)
78 | model = LLaMA(config).bfloat16()
79 | torch.set_default_tensor_type(torch.FloatTensor)
80 | model.load_state_dict(checkpoint, strict=False)
81 |
82 | model = fabric.setup_module(model)
83 |
84 | optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, foreach=False)
85 | optimizer = fabric.setup_optimizers(optimizer)
86 |
87 | train(fabric, model, optimizer, train_data, val_data, out_dir)
88 |
89 | # Save the final checkpoint at the end of training
90 | save_model_checkpoint(fabric, model, os.path.join(out_dir, "lit-llama-full-finetuned.pth"))
91 |
92 |
93 | def train(
94 | fabric: L.Fabric,
95 | model: torch.nn.Module,
96 | optimizer: torch.optim.Optimizer,
97 | train_data: np.ndarray,
98 | val_data: np.ndarray,
99 | out_dir: str,
100 | ) -> None:
101 | """The training loop.
102 |
103 | Loosely based on the nanoGPT implementation: https://github.com/karpathy/nanoGPT.
104 | """
105 | step_count = 0
106 | model.train()
107 |
108 | for iter_num in range(max_iters):
109 |
110 | is_accumulating = (iter_num + 1) % gradient_accumulation_iters != 0
111 |
112 | if step_count <= warmup_iters:
113 | # linear warmup
114 | lr = learning_rate * step_count / warmup_iters
115 | for param_group in optimizer.param_groups:
116 | param_group['lr'] = lr
117 |
118 | t0 = time.time()
119 |
120 | input_ids, targets = get_batch(fabric, train_data)
121 | with fabric.no_backward_sync(model, enabled=is_accumulating):
122 | logits = model(input_ids)
123 | loss = loss_fn(logits, targets)
124 | fabric.backward(loss / gradient_accumulation_iters)
125 |
126 | if not is_accumulating:
127 | optimizer.step()
128 | optimizer.zero_grad()
129 | step_count += 1
130 |
131 | if step_count % eval_interval == 0:
132 | val_loss = validate(fabric, model, val_data)
133 | fabric.print(f"step {iter_num}: val loss {val_loss:.4f}")
134 | fabric.barrier()
135 |
136 | if step_count % save_interval == 0:
137 | print(f"Saving weights to {out_dir}")
138 | save_model_checkpoint(fabric, model, os.path.join(out_dir, f"iter-{iter_num:06d}-ckpt.pth"))
139 |
140 | dt = time.time() - t0
141 | if iter_num % log_interval == 0:
142 | fabric.print(f"iter {iter_num}: loss {loss.item():.4f}, time: {dt*1000:.2f}ms")
143 |
144 |
145 | def generate_response(model, instruction):
146 | tokenizer = Tokenizer("checkpoints/lit-llama/tokenizer.model")
147 | sample = {"instruction": instruction, "input": ""}
148 | prompt = instruction
149 | if instruction_tuning:
150 | prompt = generate_prompt(sample)
151 | encoded = tokenizer.encode(prompt, bos=True, eos=False, device=model.device)
152 |
153 | output = generate(
154 | model,
155 | idx=encoded,
156 | max_seq_length=block_size,
157 | max_new_tokens=100,
158 | )
159 | output = tokenizer.decode(output)
160 | return output # output.split("### Response:")[1].strip()
161 |
162 |
163 | @torch.no_grad()
164 | def validate(fabric: L.Fabric, model: torch.nn.Module, val_data: np.ndarray) -> torch.Tensor:
165 | fabric.print("Validating ...")
166 | model.eval()
167 | losses = torch.zeros(eval_iters)
168 | for k in range(eval_iters):
169 | input_ids, targets = get_batch(fabric, val_data)
170 | logits = model(input_ids)
171 | loss = loss_fn(logits, targets)
172 | losses[k] = loss.item()
173 | out = losses.mean()
174 |
175 | # produce an example:
176 | instruction = "Recommend a movie for me to watch during the weekend and explain the reason."
177 |
178 | output = generate_response(model, instruction)
179 | fabric.print(instruction)
180 | fabric.print(output)
181 |
182 | model.train()
183 | return out.item()
184 |
185 |
186 | def loss_fn(logits, targets):
187 | # shift the targets such that output n predicts token n+1
188 | logits = logits[..., :-1, :].contiguous()
189 | targets = targets[..., 1:].contiguous()
190 | loss = torch.nn.functional.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
191 | return loss
192 |
193 |
194 | def get_batch(fabric: L.Fabric, data: list):
195 | ix = torch.randint(len(data), (micro_batch_size,))
196 |
197 | input_ids = [data[i]["input_ids"].type(torch.int64) for i in ix]
198 | labels = [data[i]["labels"].type(torch.int64) for i in ix]
199 |
200 | max_len = max(len(s) for s in input_ids)
201 |
202 | def pad_right(x, pad_id):
203 | # pad right based on the longest sequence
204 | n = max_len - len(x)
205 | return torch.cat((x, torch.full((n,), pad_id, dtype=x.dtype)))
206 |
207 | x = torch.stack([pad_right(x, pad_id=0) for x in input_ids])
208 | y = torch.stack([pad_right(x, pad_id=-1) for x in labels])
209 | x, y = fabric.to_device((x.pin_memory(), y.pin_memory()))
210 | return x, y
211 |
212 |
213 | def load_datasets(data_dir):
214 | train_data = torch.load(os.path.join(data_dir, "train.pt"))
215 | val_data = torch.load(os.path.join(data_dir, "test.pt"))
216 | return train_data, val_data
217 |
218 |
219 | if __name__ == "__main__":
220 | # Uncomment this line if you see an error: "Expected is_sm80 to be true, but got false"
221 | # torch.backends.cuda.enable_flash_sdp(False)
222 | torch.set_float32_matmul_precision("high")
223 |
224 | from jsonargparse.cli import CLI
225 |
226 | CLI(main)
227 |
--------------------------------------------------------------------------------
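Two details of the training loop above are easy to miss: `loss_fn` shifts logits and targets by one position so that the output at position n is scored against token n+1, and `get_batch` right-pads labels with `-1`, which `cross_entropy(..., ignore_index=-1)` then drops from the loss. A toy sketch of both on fake data (the vocabulary size and sequences are made up for illustration):

```python
import torch
import torch.nn.functional as F

vocab_size = 8

# two "sequences" of different lengths, right-padded as in get_batch above
input_ids = torch.tensor([[5, 2, 7, 1], [3, 6, 0, 0]])    # inputs padded with 0
labels    = torch.tensor([[5, 2, 7, 1], [3, 6, -1, -1]])  # labels padded with -1

logits = torch.randn(2, 4, vocab_size)  # stand-in for model(input_ids)

# shift so that the prediction at position n is scored against token n+1
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()

loss = F.cross_entropy(
    shift_logits.view(-1, vocab_size),
    shift_labels.view(-1),
    ignore_index=-1,  # padded label positions contribute nothing to the loss
)
print(loss)
```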
/finetune/lora.py:
--------------------------------------------------------------------------------
1 | # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
2 |
3 | """
4 | Instruction-tuning with LoRA on the Alpaca dataset.
5 |
6 | Note: If you run into a CUDA error "Expected is_sm80 to be true, but got false", uncomment the line
7 | `torch.backends.cuda.enable_flash_sdp(False)` in the script below (see https://github.com/Lightning-AI/lit-llama/issues/101).
8 | """
9 | import sys
10 | from pathlib import Path
11 | import os
12 | import time
13 |
14 | import lightning as L
15 | import numpy as np
16 | import torch
17 |
18 | # support running without installing as a package
19 | wd = Path(__file__).parent.parent.resolve()
20 | sys.path.append(str(wd))
21 |
22 | from generate import generate
23 | from lit_llama.lora import mark_only_lora_as_trainable, lora, lora_state_dict
24 | from lit_llama.model import LLaMA, LLaMAConfig
25 | from lit_llama.tokenizer import Tokenizer
26 | from scripts.prepare_alpaca import generate_prompt
27 |
28 |
29 | instruction_tuning = True
30 | eval_interval = 100
31 | save_interval = 100
32 | eval_iters = 100
33 | log_interval = 1
34 |
35 | # Hyperparameters
36 | learning_rate = 3e-4
37 | batch_size = 128
38 | micro_batch_size = 4
39 | gradient_accumulation_iters = batch_size // micro_batch_size
40 | assert gradient_accumulation_iters > 0
41 | max_iters = 50000 * 3 // micro_batch_size
42 | weight_decay = 0.0
43 | max_seq_length = 256 # see scripts/prepare_alpaca.py
44 | lora_r = 8
45 | lora_alpha = 16
46 | lora_dropout = 0.05
47 | warmup_iters = 100
48 |
49 |
50 | def main(
51 | data_dir: str = "data/alpaca",
52 | pretrained_path: str = "checkpoints/lit-llama/7B/lit-llama.pth",
53 | tokenizer_path: str = "checkpoints/lit-llama/tokenizer.model",
54 | out_dir: str = "out/lora/alpaca",
55 | ):
56 |
57 | fabric = L.Fabric(accelerator="cuda", devices=1, precision="bf16-true")
58 | fabric.launch()
59 | fabric.seed_everything(1337 + fabric.global_rank)
60 |
61 | if fabric.global_rank == 0:
62 | os.makedirs(out_dir, exist_ok=True)
63 |
64 | train_data, val_data = load_datasets(data_dir=data_dir)
65 |
66 | config = LLaMAConfig.from_name("7B")
67 | config.block_size = max_seq_length
68 |
69 | checkpoint = torch.load(pretrained_path)
70 |
71 | with fabric.init_module(), lora(r=lora_r, alpha=lora_alpha, dropout=lora_dropout, enabled=True):
72 | model = LLaMA(config)
73 | # strict=False because missing keys due to LoRA weights not contained in checkpoint state
74 | model.load_state_dict(checkpoint, strict=False)
75 |
76 | mark_only_lora_as_trainable(model)
77 |
78 | optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
79 | model, optimizer = fabric.setup(model, optimizer)
80 | train(fabric, model, optimizer, train_data, val_data, tokenizer_path, out_dir)
81 |
82 | # Save the final LoRA checkpoint at the end of training
83 | checkpoint = lora_state_dict(model)
84 | fabric.save(os.path.join(out_dir, "lit-llama-lora-finetuned.pth"), checkpoint)
85 |
86 |
87 | def train(
88 | fabric: L.Fabric,
89 | model: torch.nn.Module,
90 | optimizer: torch.optim.Optimizer,
91 | train_data: np.ndarray,
92 | val_data: np.ndarray,
93 | tokenizer_path: str,
94 | out_dir: str,
95 | ) -> None:
96 | """The training loop.
97 |
98 | Loosely based on the nanoGPT implementation: https://github.com/karpathy/nanoGPT.
99 | """
100 | step_count = 0
101 |
102 | for iter_num in range(max_iters):
103 |
104 | if step_count <= warmup_iters:
105 | # linear warmup
106 | lr = learning_rate * step_count / warmup_iters
107 | for param_group in optimizer.param_groups:
108 | param_group['lr'] = lr
109 |
110 | t0 = time.time()
111 |
112 | input_ids, targets = get_batch(fabric, train_data)
113 | with fabric.no_backward_sync(model, enabled=((iter_num + 1) % gradient_accumulation_iters != 0)):
114 | logits = model(input_ids)
115 | loss = loss_fn(logits, targets)
116 | fabric.backward(loss / gradient_accumulation_iters)
117 |
118 | if (iter_num + 1) % gradient_accumulation_iters == 0:
119 | optimizer.step()
120 | optimizer.zero_grad()
121 | step_count += 1
122 |
123 | if step_count % eval_interval == 0:
124 | val_loss = validate(fabric, model, val_data, tokenizer_path)
125 | fabric.print(f"step {iter_num}: val loss {val_loss:.4f}")
126 | fabric.barrier()
127 |
128 | if step_count % save_interval == 0:
129 | print(f"Saving LoRA weights to {out_dir}")
130 | # We are only saving the LoRA weights
131 | # TODO: Provide a function/script to merge the LoRA weights with pretrained weights
132 | checkpoint = lora_state_dict(model)
133 | fabric.save(os.path.join(out_dir, f"iter-{iter_num:06d}-ckpt.pth"), checkpoint)
134 |
135 | dt = time.time() - t0
136 | if iter_num % log_interval == 0:
137 | fabric.print(f"iter {iter_num}: loss {loss.item():.4f}, time: {dt*1000:.2f}ms")
138 |
139 |
140 | def generate_response(model, instruction, tokenizer_path):
141 | tokenizer = Tokenizer(tokenizer_path)
142 | sample = {"instruction": instruction, "input": ""}
143 | prompt = instruction
144 | if instruction_tuning:
145 | prompt = generate_prompt(sample)
146 | encoded = tokenizer.encode(prompt, bos=True, eos=False, device=model.device)
147 |
148 | output = generate(
149 | model,
150 | idx=encoded,
151 | max_seq_length=max_seq_length,
152 | max_new_tokens=100,
153 | )
154 | output = tokenizer.decode(output)
155 | return output # output.split("### Response:")[1].strip()
156 |
157 |
158 | @torch.no_grad()
159 | def validate(fabric: L.Fabric, model: torch.nn.Module, val_data: np.ndarray, tokenizer_path: str) -> torch.Tensor:
160 | fabric.print("Validating ...")
161 | model.eval()
162 | losses = torch.zeros(eval_iters)
163 | for k in range(eval_iters):
164 | input_ids, targets = get_batch(fabric, val_data)
165 | logits = model(input_ids)
166 | loss = loss_fn(logits, targets)
167 | losses[k] = loss.item()
168 | out = losses.mean()
169 |
170 | # produce an example:
171 | instruction = "Recommend a movie for me to watch during the weekend and explain the reason."
172 |
173 | output = generate_response(model, instruction, tokenizer_path)
174 | fabric.print(instruction)
175 | fabric.print(output)
176 |
177 | model.train()
178 | return out.item()
179 |
180 | def loss_fn(logits, targets):
181 | # shift the targets such that output n predicts token n+1
182 | logits = logits[..., :-1, :].contiguous()
183 | targets = targets[..., 1:].contiguous()
184 | loss = torch.nn.functional.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
185 | return loss
186 |
187 |
188 | def get_batch(fabric: L.Fabric, data: list):
189 | ix = torch.randint(len(data), (micro_batch_size,))
190 |
191 | input_ids = [data[i]["input_ids"].type(torch.int64) for i in ix]
192 | labels = [data[i]["labels"].type(torch.int64) for i in ix]
193 |
194 | max_len = max(len(s) for s in input_ids)
195 |
196 | def pad_right(x, pad_id):
197 | # pad right based on the longest sequence
198 | n = max_len - len(x)
199 | return torch.cat((x, torch.full((n,), pad_id, dtype=x.dtype)))
200 |
201 | x = torch.stack([pad_right(x, pad_id=0) for x in input_ids])
202 | y = torch.stack([pad_right(x, pad_id=-1) for x in labels])
203 | x, y = fabric.to_device((x.pin_memory(), y.pin_memory()))
204 | return x, y
205 |
206 |
207 | def load_datasets(data_dir):
208 | train_data = torch.load(os.path.join(data_dir, "train.pt"))
209 | val_data = torch.load(os.path.join(data_dir, "test.pt"))
210 | return train_data, val_data
211 |
212 |
213 | if __name__ == "__main__":
214 | # Uncomment this line if you see an error: "Expected is_sm80 to be true, but got false"
215 | # torch.backends.cuda.enable_flash_sdp(False)
216 | torch.set_float32_matmul_precision("high")
217 |
218 | from jsonargparse.cli import CLI
219 |
220 | CLI(main)
221 |
--------------------------------------------------------------------------------
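`mark_only_lora_as_trainable` and `lora_state_dict` are imported from `lit_llama/lora.py`, which is not reproduced in this section. Conceptually they filter parameters by name so that only the small LoRA factors are trained and saved. The sketch below illustrates that idea on a toy module; it is not the library's implementation, and the `"lora_"` name filter is an assumption made for illustration:

```python
import torch
import torch.nn as nn


def mark_only_lora_as_trainable_sketch(model: nn.Module) -> None:
    # freeze everything except parameters whose name marks them as LoRA factors
    for name, param in model.named_parameters():
        param.requires_grad = "lora_" in name


def lora_state_dict_sketch(model: nn.Module) -> dict:
    # keep only the small LoRA tensors, so the saved checkpoint stays tiny
    return {k: v for k, v in model.state_dict().items() if "lora_" in k}


class ToyLoRALinear(nn.Module):
    """A frozen base weight plus two low-rank LoRA factors (no forward, demo only)."""

    def __init__(self, in_features: int, out_features: int, r: int = 8):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(out_features, in_features))
        self.lora_A = nn.Parameter(torch.zeros(r, in_features))
        self.lora_B = nn.Parameter(torch.zeros(out_features, r))


m = ToyLoRALinear(16, 16)
mark_only_lora_as_trainable_sketch(m)
print([n for n, p in m.named_parameters() if p.requires_grad])  # ['lora_A', 'lora_B']
print(list(lora_state_dict_sketch(m)))                          # ['lora_A', 'lora_B']
```

This also shows why the script above loads checkpoints with `strict=False`: the LoRA tensors are missing from the pretrained checkpoint, and the pretrained tensors are missing from the saved LoRA checkpoint.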
/generate.py:
--------------------------------------------------------------------------------
1 | # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
2 |
3 | import sys
4 | import time
5 | import warnings
6 | from pathlib import Path
7 | from typing import Optional
8 |
9 | import lightning as L
10 | import torch
11 |
12 | # support running without installing as a package
13 | wd = Path(__file__).parent.parent.resolve()
14 | sys.path.append(str(wd))
15 |
16 | from lit_llama import LLaMA, Tokenizer
17 | from lit_llama.utils import lazy_load, llama_model_lookup, quantization
18 |
19 |
20 | @torch.no_grad()
21 | def generate(
22 | model: LLaMA,
23 | idx: torch.Tensor,
24 | max_new_tokens: int,
25 | *,
26 | max_seq_length: Optional[int] = None,
27 | temperature: float = 1.0,
28 | top_k: Optional[int] = None,
29 | eos_id: Optional[int] = None,
30 | ) -> torch.Tensor:
31 | """Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.
32 |
33 | The implementation of this function is modified from A. Karpathy's nanoGPT.
34 |
35 | Args:
36 | model: The model to use.
37 | idx: Tensor of shape (T) with indices of the prompt sequence.
38 | max_new_tokens: The number of new tokens to generate.
39 | max_seq_length: The maximum sequence length allowed.
40 | temperature: Scales the predicted logits by 1 / temperature
41 | top_k: If specified, only sample among the tokens with the k highest probabilities
42 | eos_id: If specified, stop generating once this token is produced
43 | """
44 | # determine the final sequence length and the maximum context length
45 | T = idx.size(0)
46 | T_new = T + max_new_tokens
47 | if max_seq_length is None:
48 | max_seq_length = min(T_new, model.config.block_size)
49 |
50 | device, dtype = idx.device, idx.dtype
51 | # create an empty tensor of the expected final shape and fill in the current tokens
52 | empty = torch.empty(T_new, dtype=dtype, device=device)
53 | empty[:T] = idx
54 | idx = empty
55 | input_pos = torch.arange(0, T, device=device)
56 |
57 | if idx.device.type == "xla":
58 | import torch_xla.core.xla_model as xm
59 |
60 | xm.mark_step()
61 |
62 | # generate max_new_tokens tokens
63 | for _ in range(max_new_tokens):
64 | x = idx.index_select(0, input_pos).view(1, -1)
65 |
66 | # forward
67 | logits = model(x, max_seq_length, input_pos)
68 | logits = logits[0, -1] / temperature
69 |
70 | # optionally crop the logits to only the top k options
71 | if top_k is not None:
72 | v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
73 | logits = torch.where(logits < v[[-1]], -float("Inf"), logits)
74 |
75 | probs = torch.nn.functional.softmax(logits, dim=-1)
76 | idx_next = torch.multinomial(probs, num_samples=1).to(dtype=dtype)
77 |
78 | # advance the position: with the kv-cache populated, only the newly sampled token is fed next
79 | input_pos = input_pos[-1:] + 1
80 |
81 | if idx.device.type == "xla":
82 | xm.mark_step()
83 |
84 | # concatenate the new generation
85 | idx = idx.index_copy(0, input_pos, idx_next)
86 |
87 | # if <eos> token is triggered, return the output (stop generation)
88 | if idx_next == eos_id:
89 | return idx[:input_pos + 1]  # include the EOS token
90 |
91 | return idx
92 |
93 |
94 | def main(
95 | prompt: str = "Hello, my name is",
96 | *,
97 | num_samples: int = 1,
98 | max_new_tokens: int = 50,
99 | top_k: int = 200,
100 | temperature: float = 0.8,
101 | checkpoint_path: Path = Path("checkpoints/lit-llama/7B/lit-llama.pth"),
102 | tokenizer_path: Path = Path("checkpoints/lit-llama/tokenizer.model"),
103 | quantize: Optional[str] = None,
104 | ) -> None:
105 | """Generates text samples based on a pre-trained LLaMA model and tokenizer.
106 |
107 | Args:
108 | prompt: The prompt string to use for generating the samples.
109 | num_samples: The number of text samples to generate.
110 | max_new_tokens: The number of generation steps to take.
111 | top_k: The number of top most probable tokens to consider in the sampling process.
112 | temperature: A value controlling the randomness of the sampling process. Higher values result in more random
113 | samples.
114 | checkpoint_path: The checkpoint path to load.
115 | tokenizer_path: The tokenizer path to load.
116 | quantize: Whether to quantize the model and, if so, which method to use:
117 | ``"llm.int8"``: LLM.int8() mode,
118 | ``"gptq.int4"``: GPTQ 4-bit mode.
119 | """
120 | assert checkpoint_path.is_file(), checkpoint_path
121 | assert tokenizer_path.is_file(), tokenizer_path
122 |
123 | precision = "bf16-true" if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else "32-true"
124 | fabric = L.Fabric(devices=1, precision=precision)
125 |
126 | print("Loading model ...", file=sys.stderr)
127 | t0 = time.time()
128 | with lazy_load(checkpoint_path) as checkpoint:
129 | name = llama_model_lookup(checkpoint)
130 |
131 | with fabric.init_module(empty_init=True), quantization(mode=quantize):
132 | model = LLaMA.from_name(name)
133 |
134 | model.load_state_dict(checkpoint)
135 | print(f"Time to load model: {time.time() - t0:.02f} seconds.", file=sys.stderr)
136 |
137 | model.eval()
138 | model = fabric.setup(model)
139 |
140 | tokenizer = Tokenizer(tokenizer_path)
141 | encoded = tokenizer.encode(prompt, bos=True, eos=False, device=fabric.device)
142 | prompt_length = encoded.size(0)
143 |
144 | L.seed_everything(1234)
145 | for i in range(num_samples):
146 | t0 = time.perf_counter()
147 | y = generate(model, encoded, max_new_tokens, temperature=temperature, top_k=top_k)
148 | t = time.perf_counter() - t0
149 |
150 | model.reset_cache()
151 | print(tokenizer.decode(y))
152 | tokens_generated = y.size(0) - prompt_length
153 | print(f"Time for inference {i + 1}: {t:.02f} sec total, {tokens_generated / t:.02f} tokens/sec", file=sys.stderr)
154 | if fabric.device.type == "cuda":
155 | print(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB", file=sys.stderr)
156 |
157 |
158 | if __name__ == "__main__":
159 | from jsonargparse import CLI
160 |
161 | torch.set_float32_matmul_precision("high")
162 | warnings.filterwarnings(
163 | # Triggered internally at ../aten/src/ATen/EmptyTensor.cpp:31
164 | "ignore",
165 | message="ComplexHalf support is experimental and many operators don't support it yet"
166 | )
167 | warnings.filterwarnings(
168 | # Triggered in bitsandbytes/autograd/_functions.py:298
169 | "ignore",
170 | message="MatMul8bitLt: inputs will be cast from torch.bfloat16 to float16 during quantization",
171 | )
172 | CLI(main)
173 |
--------------------------------------------------------------------------------
/generate/adapter.py:
--------------------------------------------------------------------------------
1 | # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
2 |
3 | import sys
4 | import time
5 | import warnings
6 | from pathlib import Path
7 | from typing import Optional
8 |
9 | import lightning as L
10 | import torch
11 |
12 | # support running without installing as a package
13 | wd = Path(__file__).parent.parent.resolve()
14 | sys.path.append(str(wd))
15 |
16 | from generate import generate
17 | from lit_llama import Tokenizer
18 | from lit_llama.adapter import LLaMA
19 | from lit_llama.utils import lazy_load, llama_model_lookup, quantization
20 | from scripts.prepare_alpaca import generate_prompt
21 |
22 |
23 | def main(
24 | prompt: str = "What food do lamas eat?",
25 | input: str = "",
26 | adapter_path: Path = Path("out/adapter/alpaca/lit-llama-adapter-finetuned.pth"),
27 | pretrained_path: Path = Path("checkpoints/lit-llama/7B/lit-llama.pth"),
28 | tokenizer_path: Path = Path("checkpoints/lit-llama/tokenizer.model"),
29 | quantize: Optional[str] = None,
30 | max_new_tokens: int = 100,
31 | top_k: int = 200,
32 | temperature: float = 0.8,
33 | ) -> None:
34 | """Generates a response based on a given instruction and an optional input.
35 | This script will only work with checkpoints from the instruction-tuned LLaMA-Adapter model.
36 | See `finetune_adapter.py`.
37 |
38 | Args:
39 | prompt: The prompt/instruction (Alpaca style).
40 | adapter_path: Path to the checkpoint with trained adapter weights, which are the output of
41 | `finetune_adapter.py`.
42 | input: Optional input (Alpaca style).
43 | pretrained_path: The path to the checkpoint with pretrained LLaMA weights.
44 | tokenizer_path: The tokenizer path to load.
45 | quantize: Whether to quantize the model and, if so, which method to use:
46 | ``"llm.int8"``: LLM.int8() mode,
47 | ``"gptq.int4"``: GPTQ 4-bit mode.
48 | max_new_tokens: The number of generation steps to take.
49 | top_k: The number of top most probable tokens to consider in the sampling process.
50 | temperature: A value controlling the randomness of the sampling process. Higher values result in more random
51 | samples.
52 | """
53 | assert adapter_path.is_file()
54 | assert pretrained_path.is_file()
55 | assert tokenizer_path.is_file()
56 |
57 | precision = "bf16-true" if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else "32-true"
58 | fabric = L.Fabric(devices=1, precision=precision)
59 |
60 | print("Loading model ...", file=sys.stderr)
61 | t0 = time.time()
62 | with lazy_load(pretrained_path) as pretrained_checkpoint, lazy_load(adapter_path) as adapter_checkpoint:
63 | name = llama_model_lookup(pretrained_checkpoint)
64 |
65 | with fabric.init_module(empty_init=True), quantization(mode=quantize):
66 | model = LLaMA.from_name(name)
67 |
68 | # 1. Load the pretrained weights
69 | model.load_state_dict(pretrained_checkpoint, strict=False)
70 | # 2. Load the fine-tuned adapter weights
71 | model.load_state_dict(adapter_checkpoint, strict=False)
72 |
73 | print(f"Time to load model: {time.time() - t0:.02f} seconds.", file=sys.stderr)
74 |
75 | model.eval()
76 | model = fabric.setup(model)
77 |
78 | tokenizer = Tokenizer(tokenizer_path)
79 | sample = {"instruction": prompt, "input": input}
80 | prompt = generate_prompt(sample)
81 | encoded = tokenizer.encode(prompt, bos=True, eos=False, device=model.device)
82 | prompt_length = encoded.size(0)
83 |
84 | t0 = time.perf_counter()
85 | y = generate(model, encoded, max_new_tokens, temperature=temperature, top_k=top_k, eos_id=tokenizer.eos_id)
86 | t = time.perf_counter() - t0
87 |
88 | model.reset_cache()
89 | output = tokenizer.decode(y)
90 | output = output.split("### Response:")[1].strip()
91 | print(output)
92 |
93 | tokens_generated = y.size(0) - prompt_length
94 | print(f"\n\nTime for inference: {t:.02f} sec total, {tokens_generated / t:.02f} tokens/sec", file=sys.stderr)
95 | if fabric.device.type == "cuda":
96 | print(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB", file=sys.stderr)
97 |
98 |
99 | if __name__ == "__main__":
100 | from jsonargparse import CLI
101 |
102 | torch.set_float32_matmul_precision("high")
103 | warnings.filterwarnings(
104 | # Triggered internally at ../aten/src/ATen/EmptyTensor.cpp:31
105 | "ignore",
106 | message="ComplexHalf support is experimental and many operators don't support it yet"
107 | )
108 | CLI(main)
109 |
--------------------------------------------------------------------------------
/generate/adapter_v2.py:
--------------------------------------------------------------------------------
1 | # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
2 |
3 | import sys
4 | import time
5 | import warnings
6 | from pathlib import Path
7 | from typing import Optional
8 |
9 | import lightning as L
10 | import torch
11 |
12 | # support running without installing as a package
13 | wd = Path(__file__).parent.parent.resolve()
14 | sys.path.append(str(wd))
15 |
16 | from generate import generate
17 | from lit_llama import Tokenizer
18 | from lit_llama.adapter import LLaMA
19 | from lit_llama.utils import lazy_load, llama_model_lookup, quantization
20 | from lit_llama.adapter_v2 import add_adapter_v2_parameters_to_linear_layers
21 | from scripts.prepare_alpaca import generate_prompt
22 |
23 |
24 | def main(
25 | prompt: str = "What food do lamas eat?",
26 | input: str = "",
27 | adapter_path: Path = Path("out/adapter_v2/alpaca/lit-llama-adapter-finetuned.pth"),
28 | pretrained_path: Path = Path("checkpoints/lit-llama/7B/lit-llama.pth"),
29 | tokenizer_path: Path = Path("checkpoints/lit-llama/tokenizer.model"),
30 | quantize: Optional[str] = None,
31 | max_new_tokens: int = 100,
32 | top_k: int = 200,
33 | temperature: float = 0.8,
34 | ) -> None:
35 | """Generates a response based on a given instruction and an optional input.
36 | This script will only work with checkpoints from the instruction-tuned LLaMA-Adapter model.
37 | See `finetune_adapter_v2.py`.
38 |
39 | Args:
40 | prompt: The prompt/instruction (Alpaca style).
41 | adapter_path: Path to the checkpoint with trained adapter weights, which are the output of
42 | `finetune_adapter_v2.py`.
43 | input: Optional input (Alpaca style).
44 | pretrained_path: The path to the checkpoint with pretrained LLaMA weights.
45 | tokenizer_path: The tokenizer path to load.
46 | quantize: Whether to quantize the model and, if so, which method to use:
47 | ``"llm.int8"``: LLM.int8() mode,
48 | ``"gptq.int4"``: GPTQ 4-bit mode.
49 | max_new_tokens: The number of generation steps to take.
50 | top_k: The number of top most probable tokens to consider in the sampling process.
51 | temperature: A value controlling the randomness of the sampling process. Higher values result in more random
52 | samples.
53 | """
54 | assert adapter_path.is_file()
55 | assert pretrained_path.is_file()
56 | assert tokenizer_path.is_file()
57 |
58 | precision = "bf16-true" if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else "32-true"
59 | fabric = L.Fabric(devices=1, precision=precision)
60 |
61 | print("Loading model ...", file=sys.stderr)
62 | t0 = time.time()
63 | with lazy_load(pretrained_path) as pretrained_checkpoint, lazy_load(adapter_path) as adapter_checkpoint:
64 | name = llama_model_lookup(pretrained_checkpoint)
65 |
66 | with fabric.init_module(empty_init=True), quantization(mode=quantize):
67 | model = LLaMA.from_name(name)
68 | add_adapter_v2_parameters_to_linear_layers(model)
69 |
70 | # 1. Load the pretrained weights
71 | model.load_state_dict(pretrained_checkpoint, strict=False)
72 | # 2. Load the fine-tuned adapter weights
73 | model.load_state_dict(adapter_checkpoint, strict=False)
74 |
75 | print(f"Time to load model: {time.time() - t0:.02f} seconds.", file=sys.stderr)
76 |
77 | model.eval()
78 | model = fabric.setup(model)
79 |
80 | tokenizer = Tokenizer(tokenizer_path)
81 | sample = {"instruction": prompt, "input": input}
82 | prompt = generate_prompt(sample)
83 | encoded = tokenizer.encode(prompt, bos=True, eos=False, device=model.device)
84 | prompt_length = encoded.size(0)
85 |
86 | t0 = time.perf_counter()
87 | y = generate(model, encoded, max_new_tokens, temperature=temperature, top_k=top_k, eos_id=tokenizer.eos_id)
88 | t = time.perf_counter() - t0
89 |
90 | model.reset_cache()
91 | output = tokenizer.decode(y)
92 | output = output.split("### Response:")[1].strip()
93 | print(output)
94 |
95 | tokens_generated = y.size(0) - prompt_length
96 | print(f"\n\nTime for inference: {t:.02f} sec total, {tokens_generated / t:.02f} tokens/sec", file=sys.stderr)
97 | if fabric.device.type == "cuda":
98 | print(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB", file=sys.stderr)
99 |
100 |
101 | if __name__ == "__main__":
102 | from jsonargparse import CLI
103 |
104 | torch.set_float32_matmul_precision("high")
105 | warnings.filterwarnings(
106 | # Triggered internally at ../aten/src/ATen/EmptyTensor.cpp:31
107 | "ignore",
108 | message="ComplexHalf support is experimental and many operators don't support it yet"
109 | )
110 | CLI(main)
111 |
--------------------------------------------------------------------------------
/generate/full.py:
--------------------------------------------------------------------------------
1 | # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
2 |
3 | import sys
4 | import time
5 | import warnings
6 | from pathlib import Path
7 | from typing import Optional
8 |
9 | import lightning as L
10 | import torch
11 |
12 | # support running without installing as a package
13 | wd = Path(__file__).absolute().parent.parent
14 | sys.path.append(str(wd))
15 |
16 | from lit_llama import LLaMA, Tokenizer
17 | from lit_llama.utils import quantization
18 | from scripts.prepare_alpaca import generate_prompt
19 | from generate import generate
20 |
21 |
22 | def main(
23 | prompt: str = "Hello, my name is",
24 | *,
25 | num_samples: int = 1,
26 | max_new_tokens: int = 50,
27 | top_k: int = 200,
28 | temperature: float = 0.8,
29 | checkpoint_path: Optional[Path] = None,
30 | tokenizer_path: Path = Path("checkpoints/lit-llama/tokenizer.model"),
31 | model_size: str = "7B",
32 | quantize: Optional[str] = None,
33 | ) -> None:
34 | """Generates text samples based on a pre-trained LLaMA model and tokenizer.
35 |
36 | Args:
37 | prompt: The prompt string to use for generating the samples.
38 | num_samples: The number of text samples to generate.
39 | max_new_tokens: The number of generation steps to take.
40 | top_k: The number of top most probable tokens to consider in the sampling process.
41 | temperature: A value controlling the randomness of the sampling process. Higher values result in more random
42 | samples.
43 | checkpoint_path: The checkpoint path to load.
44 | tokenizer_path: The tokenizer path to load.
45 | model_size: The model size to load.
46 | quantize: Whether to quantize the model and, if so, which method to use:
47 | ``"llm.int8"``: LLM.int8() mode,
48 | ``"gptq.int4"``: GPTQ 4-bit mode.
49 | """
50 | if not checkpoint_path:
51 | checkpoint_path = Path(f"checkpoints/lit-llama/{model_size}/lit-llama.pth")
52 | assert checkpoint_path.is_file(), checkpoint_path
53 | assert tokenizer_path.is_file(), tokenizer_path
54 |
55 | precision = "bf16-true" if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else "32-true"
56 | fabric = L.Fabric(devices=1, precision=precision)
57 |
58 | print("Loading model ...", file=sys.stderr)
59 | t0 = time.time()
60 |
61 | with fabric.init_module(empty_init=True), quantization(mode=quantize):
62 | model = LLaMA.from_name(model_size)
63 |
64 | checkpoint = torch.load(checkpoint_path)
65 | model.load_state_dict(checkpoint)
66 | print(f"Time to load model: {time.time() - t0:.02f} seconds.", file=sys.stderr)
67 |
68 | model.eval()
69 | model = fabric.setup(model)
70 |
71 | tokenizer = Tokenizer(tokenizer_path)
72 | sample = {"instruction": prompt, "input": ""}  # this script has no separate Alpaca-style input
73 | prompt = generate_prompt(sample)
74 | encoded = tokenizer.encode(prompt, bos=True, eos=False, device=fabric.device)
75 | prompt_length = encoded.size(0)
76 |
77 | L.seed_everything(1234)
78 | for i in range(num_samples):
79 | t0 = time.perf_counter()
80 | y = generate(model, encoded, max_new_tokens, temperature=temperature, top_k=top_k)
81 | t = time.perf_counter() - t0
82 |
83 | model.reset_cache()
84 | print(tokenizer.decode(y))
85 | tokens_generated = y.size(0) - prompt_length
86 | print(f"Time for inference {i + 1}: {t:.02f} sec total, {tokens_generated / t:.02f} tokens/sec", file=sys.stderr)
87 | if fabric.device.type == "cuda":
88 | print(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB", file=sys.stderr)
89 |
90 |
91 | if __name__ == "__main__":
92 | from jsonargparse import CLI
93 |
94 | torch.set_float32_matmul_precision("high")
95 | warnings.filterwarnings(
96 | # Triggered internally at ../aten/src/ATen/EmptyTensor.cpp:31
97 | "ignore",
98 | message="ComplexHalf support is experimental and many operators don't support it yet"
99 | )
100 | warnings.filterwarnings(
101 | # Triggered in bitsandbytes/autograd/_functions.py:298
102 | "ignore",
103 | message="MatMul8bitLt: inputs will be cast from torch.bfloat16 to float16 during quantization",
104 | )
105 | CLI(main)
106 |
--------------------------------------------------------------------------------
/generate/lora.py:
--------------------------------------------------------------------------------
1 | # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
2 |
3 | import sys
4 | import time
5 | import warnings
6 | from pathlib import Path
7 | from typing import Optional
8 |
9 | import lightning as L
10 | import torch
11 |
12 | # support running without installing as a package
13 | wd = Path(__file__).parent.parent.resolve()
14 | sys.path.append(str(wd))
15 |
16 | from generate import generate
17 | from lit_llama import Tokenizer, LLaMA
18 | from lit_llama.lora import lora
19 | from lit_llama.utils import lazy_load, llama_model_lookup
20 | from scripts.prepare_alpaca import generate_prompt
21 |
22 | lora_r = 8
23 | lora_alpha = 16
24 | lora_dropout = 0.05
25 |
26 |
27 | def main(
28 | prompt: str = "What food do lamas eat?",
29 | input: str = "",
30 | lora_path: Path = Path("out/lora/alpaca/lit-llama-lora-finetuned.pth"),
31 | pretrained_path: Path = Path("checkpoints/lit-llama/7B/lit-llama.pth"),
32 | tokenizer_path: Path = Path("checkpoints/lit-llama/tokenizer.model"),
33 | quantize: Optional[str] = None,
34 | max_new_tokens: int = 100,
35 | top_k: int = 200,
36 | temperature: float = 0.8,
37 | ) -> None:
38 | """Generates a response based on a given instruction and an optional input.
39 | This script will only work with checkpoints from the instruction-tuned LoRA model.
40 | See `finetune_lora.py`.
41 |
42 | Args:
43 | prompt: The prompt/instruction (Alpaca style).
44 | lora_path: Path to the checkpoint with trained LoRA weights, which are the output of
45 | `finetune_lora.py`.
46 | input: Optional input (Alpaca style).
47 | pretrained_path: The path to the checkpoint with pretrained LLaMA weights.
48 | tokenizer_path: The tokenizer path to load.
49 | quantize: Whether to quantize the model and, if so, which method to use:
50 | ``"llm.int8"``: LLM.int8() mode,
51 | ``"gptq.int4"``: GPTQ 4-bit mode.
52 | max_new_tokens: The number of generation steps to take.
53 | top_k: The number of top most probable tokens to consider in the sampling process.
54 | temperature: A value controlling the randomness of the sampling process. Higher values result in more random
55 | samples.
56 | """
57 | assert lora_path.is_file()
58 | assert pretrained_path.is_file()
59 | assert tokenizer_path.is_file()
60 |
61 | if quantize is not None:
62 | raise NotImplementedError("Quantization in LoRA is not supported yet")
63 |
64 | precision = "bf16-true" if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else "32-true"
65 | fabric = L.Fabric(devices=1, precision=precision)
66 |
67 | print("Loading model ...", file=sys.stderr)
68 | t0 = time.time()
69 |
70 | with lazy_load(pretrained_path) as pretrained_checkpoint, lazy_load(lora_path) as lora_checkpoint:
71 | name = llama_model_lookup(pretrained_checkpoint)
72 |
73 | with fabric.init_module(empty_init=True), lora(r=lora_r, alpha=lora_alpha, dropout=lora_dropout, enabled=True):
74 | model = LLaMA.from_name(name)
75 |
76 | # 1. Load the pretrained weights
77 | model.load_state_dict(pretrained_checkpoint, strict=False)
78 | # 2. Load the fine-tuned lora weights
79 | model.load_state_dict(lora_checkpoint, strict=False)
80 |
81 | print(f"Time to load model: {time.time() - t0:.02f} seconds.", file=sys.stderr)
82 |
83 | model.eval()
84 | model = fabric.setup(model)
85 |
86 | tokenizer = Tokenizer(tokenizer_path)
87 | sample = {"instruction": prompt, "input": input}
88 | prompt = generate_prompt(sample)
89 | encoded = tokenizer.encode(prompt, bos=True, eos=False, device=model.device)
90 |
91 | t0 = time.perf_counter()
92 | output = generate(
93 | model,
94 | idx=encoded,
95 | max_new_tokens=max_new_tokens,
96 | temperature=temperature,
97 | top_k=top_k,
98 | eos_id=tokenizer.eos_id
99 | )
100 | t = time.perf_counter() - t0
101 |
102 | output = tokenizer.decode(output)
103 | output = output.split("### Response:")[1].strip()
104 | print(output)
105 |
106 | print(f"\n\nTime for inference: {t:.02f} sec total, {max_new_tokens / t:.02f} tokens/sec", file=sys.stderr)
107 | if fabric.device.type == "cuda":
108 | print(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB", file=sys.stderr)
109 |
110 |
111 | if __name__ == "__main__":
112 | from jsonargparse import CLI
113 |
114 | torch.set_float32_matmul_precision("high")
115 | warnings.filterwarnings(
116 | # Triggered internally at ../aten/src/ATen/EmptyTensor.cpp:31
117 | "ignore",
118 | message="ComplexHalf support is experimental and many operators don't support it yet"
119 | )
120 | CLI(main)
121 |
--------------------------------------------------------------------------------
/howto/convert_lora_weights.md:
--------------------------------------------------------------------------------
1 | # Merging LoRA weights into base model weights
2 |
3 | Purpose: By merging the selected LoRA weights into the base model weights, we can benefit from all base-model optimisations, such as quantisation (available in this repo), pruning, caching, etc.
4 |
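Conceptually, merging folds each trained low-rank pair back into the frozen weight it modifies, so no extra modules are needed at inference time. A minimal sketch of the idea (illustrative only; the actual logic lives in `scripts/convert_lora_weights.py`, and the tensors below are placeholders):

```python
import torch

# Hypothetical shapes: W is the frozen base weight, A and B are the trained
# low-rank LoRA factors, and alpha / r is the usual LoRA scaling.
d, r, alpha = 4096, 8, 16
W = torch.randn(d, d)
A = torch.randn(r, d)
B = torch.randn(d, r)

# Merge: the merged weight behaves exactly like "base + LoRA" at inference time.
W_merged = W + (alpha / r) * (B @ A)
```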
5 |
6 | ## How to run?
7 |
8 | After you have finished finetuning with LoRA, select your weights and run the converter script:
9 |
10 | ```bash
11 | python scripts/convert_lora_weights.py --lora_path out/lora/your-folder/your-weight-name.pth
12 | ```
13 |
14 | The converted base weight file will be saved into the same folder with the name `{your-weight-name}-lora-merged-weights.pth`. Now you can run `generate.py` with the merged weights and apply quantisation:
15 |
16 | ```bash
17 | python generate.py --checkpoint_path out/lora/your-folder/your-weight-name-lora-merged-weights.pth --quantize llm.int8
18 | ```
19 |
20 |
--------------------------------------------------------------------------------
/howto/customize_paths.md:
--------------------------------------------------------------------------------
1 | ## Customize paths
2 |
3 | The project is set up to use specific paths to read the original weights, save checkpoints, etc.
4 |
5 | For all scripts, you can run
6 |
7 | ```shell
8 | python script.py -h
9 | ```
10 |
11 | to get a list of available options. For instance, here's how you would modify the checkpoint dir:
12 |
13 | ```shell
14 | python scripts/convert_checkpoint.py --checkpoint_dir "data/checkpoints/foo"
15 | ```
16 |
17 | Note that this change will need to be passed along to subsequent steps, for example:
18 |
19 | ```shell
20 | python generate.py \
21 | --checkpoint_path "data/checkpoints/foo/7B/lit-llama.pth" \
22 | --tokenizer_path "data/checkpoints/foo/tokenizer.model"
23 | ```
24 |
25 | and
26 |
27 | ```shell
28 | python quantize/gptq.py \
29 | --checkpoint_path "data/checkpoints/foo/7B/lit-llama.pth" \
30 | --tokenizer_path "data/checkpoints/foo/tokenizer.model"
31 | ```
32 |
33 | To avoid this, you can use symbolic links to create shortcuts and avoid passing different paths.
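
For example, a small Python sketch of creating such a link (the equivalent of `ln -s`; the paths below are only an example):

```python
from pathlib import Path

# Hypothetical example: make the default directory expected by the scripts
# point at a custom checkpoint location, so no extra path arguments are needed.
target = Path("data/checkpoints/foo").resolve()
link = Path("checkpoints/lit-llama")
if not link.exists():
    link.symlink_to(target, target_is_directory=True)
```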
34 |
--------------------------------------------------------------------------------
/howto/download_weights.md:
--------------------------------------------------------------------------------
1 | ## Downloading pretrained weights
2 |
3 | Except for when you are training from scratch, you will need the pretrained weights from Meta.
4 |
5 | ### Original Meta weights
6 |
7 | Download the model weights following the instructions on the official [LLaMA repository](https://github.com/facebookresearch/llama).
8 |
9 | Once downloaded, you should have a folder like this:
10 |
11 | ```text
12 | checkpoints/llama
13 | ├── 7B
14 | │ ├── ...
15 | │ └── consolidated.00.pth
16 | ├── 13B
17 | │ ...
18 | └── tokenizer.model
19 | ```
20 |
21 | Convert the weights to the Lit-LLaMA format:
22 |
23 | ```bash
24 | python scripts/convert_checkpoint.py --model_size 7B
25 | ```
26 |
27 | > **Note**
28 | > All scripts support argument [customization](customize_paths.md)
29 |
30 | ### OpenLLaMA
31 |
32 | OpenLM Research has released **Apache 2.0 licensed** weights obtained by training LLaMA on the 1.2 trillion token open-source [RedPajama](https://github.com/togethercomputer/RedPajama-Data) dataset.
33 |
34 | Weights were released as a preview, trained on an intermediate number of tokens (1T at the time of writing). To get them, run:
35 |
36 | ```bash
37 | # Make sure you have git-lfs installed (https://git-lfs.com): git lfs install
38 | git clone https://huggingface.co/openlm-research/open_llama_7b checkpoints/open-llama/7B
39 | ```
40 |
41 | Or if you don't have `git-lfs` installed:
42 |
43 | ```bash
44 | python scripts/download.py --repo_id openlm-research/open_llama_7b --local_dir checkpoints/open-llama/7B
45 | ```
46 |
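If you prefer to download from Python rather than the shell, a rough sketch using the `huggingface_hub` library looks like this (illustrative only; `local_dir` requires a reasonably recent `huggingface_hub` version):

```python
from huggingface_hub import snapshot_download

# Downloads the OpenLLaMA 7B repository into the folder expected by the conversion step below.
snapshot_download(repo_id="openlm-research/open_llama_7b", local_dir="checkpoints/open-llama/7B")
```
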
47 | Once downloaded, you should have a folder like this:
48 |
49 | ```text
50 | checkpoints/open-llama/
51 | └── 7B
52 | ├── ...
53 | ├── pytorch_model-00001-of-00002.bin
54 | ├── pytorch_model-00002-of-00002.bin
55 | ├── pytorch_model.bin.index.json
56 | └── tokenizer.model
57 | ```
58 |
59 | Convert the weights to the Lit-LLaMA format:
60 |
61 | ```bash
62 | python scripts/convert_hf_checkpoint.py --checkpoint_dir checkpoints/open-llama/7B --model_size 7B
63 | ```
64 |
65 | > **Note**
66 | > All scripts support argument [customization](customize_paths.md)
67 |
68 | Once converted, you should have a folder like this:
69 |
70 | ```text
71 | checkpoints/lit-llama/
72 | ├── 7B
73 | │ └── lit-llama.pth
74 | └── tokenizer.model
75 | ```
76 |
77 | You are all set. Now you can continue with inference or finetuning.
78 |
79 | Try running [`generate.py` to test the imported weights](inference.md).
80 |
81 |
82 | ### Alternative sources
83 |
84 | You might find LLaMA weights hosted online on the Hugging Face Hub. Beware that this infringes the original weights' license.
85 | You could try downloading them by running the following command with a specific repo id:
86 |
87 | ```bash
88 | # Make sure you have git-lfs installed (https://git-lfs.com): git lfs install
89 | git clone REPO_ID checkpoints/hf-llama/7B
90 | ```
91 |
92 | Or if you don't have `git-lfs` installed:
93 |
94 | ```bash
95 | python scripts/download.py --repo_id REPO_ID --local_dir checkpoints/hf-llama/7B
96 | ```
97 |
98 | Once downloaded, you should have a folder like this:
99 |
100 | ```text
101 | checkpoints/hf-llama/
102 | └── 7B
103 | ├── ...
104 | ├── pytorch_model-00001-of-00002.bin
105 | ├── pytorch_model-00002-of-00002.bin
106 | ├── pytorch_model.bin.index.json
107 | └── tokenizer.model
108 | ```
109 |
110 | Convert the weights to the Lit-LLaMA format:
111 |
112 | ```bash
113 | python scripts/convert_hf_checkpoint.py --model_size 7B
114 | ```
115 |
116 | > **Note**
117 | > All scripts support argument [customization](customize_paths.md)
118 |
119 | Once converted, you should have a folder like this:
120 |
121 | ```text
122 | checkpoints/lit-llama/
123 | ├── 7B
124 | │ └── lit-llama.pth
125 | └── tokenizer.model
126 | ```
127 |
128 | You are all set. Now you can continue with inference or finetuning.
129 |
130 | Try running [`generate.py` to test the imported weights](inference.md).
131 |
--------------------------------------------------------------------------------
/howto/finetune_adapter.md:
--------------------------------------------------------------------------------
1 | # Finetuning with Adapter
2 |
3 | [LLaMA-Adapter](https://arxiv.org/abs/2303.16199) is a form of prefix-tuning that prepends a learnable adaption-prompt to the inputs of the attention blocks in LLaMA. In total, there are only 1.2M parameters to update during finetuning, which significantly reduces the memory footprint and speeds up training.
4 |
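For intuition, here is a stripped-down sketch of the adaption-prompt idea (illustrative only, not the implementation in `lit_llama/adapter.py`): a short learnable prompt is prepended to the attention inputs and passed through a zero-initialized gate, so training starts from the unchanged pretrained model.

```python
import torch
import torch.nn as nn


class AdaptionPrompt(nn.Module):
    """Illustrative sketch of the LLaMA-Adapter idea (not the repo's implementation)."""

    def __init__(self, prompt_length: int = 10, n_embd: int = 4096):
        super().__init__()
        # a short, learnable prompt ...
        self.prompt = nn.Parameter(torch.randn(prompt_length, n_embd) * 0.02)
        # ... and a zero-initialized gate, so the pretrained behaviour is untouched at step 0
        self.gate = nn.Parameter(torch.zeros(1))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, seq_len, n_embd); prepend the gated prompt along the sequence dimension
        prompt = (self.gate * self.prompt).unsqueeze(0).expand(x.size(0), -1, -1)
        return torch.cat((prompt, x), dim=1)
```
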
5 | We are able to demonstrate instruction-finetuning Lit-LLaMA 7B on the [Alpaca](https://github.com/tatsu-lab/stanford_alpaca) dataset on a **single RTX 3090 (24GB) GPU**. If using 8 GPUs, finetuning can be completed in under 1 hour.
6 |
7 | If you are new to LLaMA-Adapter and are interested to learn more about how it works before proceeding with the finetuning guide below, you might find our article [Understanding Parameter-Efficient Finetuning of Large Language Models: From Prefix Tuning to LLaMA-Adapters](https://lightning.ai/pages/community/article/understanding-llama-adapters/) helpful.
8 |
9 | ## LLaMA-Adapter v2
10 |
11 | The LLaMA-Adapter authors developed a newer adapter method called LLaMA-Adapter v2, which is related to this LLaMA-Adapter method but includes more trainable parameters. LLaMA-Adapter v2 is also available via Lit-LLaMA; you can read more about it in [the related how-to doc here](./finetune_adapter_v2.md).
12 |
13 | ## Preparation
14 |
15 | The steps here only need to be done once:
16 |
17 | 1. Follow the instructions in the [README](../README.md) to install the dependencies.
18 | 2. Download and convert the weights and save them in the `./checkpoints` folder as described [here](download_weights.md).
19 | 3. If you want to utilize more than one GPU, you should `pip install deepspeed`.
20 | 4. Download the data and generate the Alpaca instruction tuning dataset:
21 |
22 | ```bash
23 | python scripts/prepare_alpaca.py
24 | ```
25 |
26 | or [prepare your own dataset](#tune-on-your-dataset).
27 |
28 | See also: [Finetuning on an unstructured dataset](unstructured_dataset.md)
29 |
30 | ## Running the finetuning
31 |
32 | ```bash
33 | python finetune/adapter.py
34 | ```
35 |
36 | The finetuning requires at least one GPU with ~24 GB memory (RTX 3090).
37 | You can speed up training by setting the `devices` variable in the script to utilize more GPUs if available.
38 | Depending on the available GPU memory, you can also tune the `micro_batch_size` parameter to utilize the GPU efficiently.
39 |
40 | For example, the following settings will let you finetune the model in under 1 hour using DeepSpeed Zero-2:
41 |
42 | ```python
43 | devices = 8
44 | micro_batch_size = 8
45 | ```
46 |
47 | This script will save checkpoints periodically to the folder `out/`.
48 |
49 | > **Note**
50 | > All scripts support argument [customization](customize_paths.md)
51 |
52 | ## Test the model
53 |
54 | You can test the finetuned model with your own instructions by running:
55 |
56 | ```bash
57 | python generate/adapter.py \
58 | --prompt "Recommend a movie to watch on the weekend." \
59 | --quantize llm.int8
60 | ```
61 | Output:
62 | ```
63 | A good movie to watch on the weekend would be The Lion King, since it's a classic family film that everyone can enjoy...
64 | ```
65 | If your GPU supports `bfloat16`, the script will automatically use it. Together with `--quantize llm.int8`, this brings the memory consumption down to ~8 GB.
66 |
67 | ## Tune on your dataset
68 |
69 | With only a few modifications, you can prepare and train on your own instruction dataset.
70 |
71 | 1. Create a json file in which each row holds one instruction-response pair.
72 | A row has an entry for 'instruction', 'input', and 'output', where 'input' is optional and can be
73 | the empty string if the instruction doesn't require a context. Below is an example json file:
74 |
75 | ```
76 | [
77 | {
78 | "instruction": "Arrange the given numbers in ascending order.",
79 | "input": "2, 4, 0, 8, 3",
80 | "output": "0, 2, 3, 4, 8"
81 | },
82 | ...
83 | ]
84 | ```
85 |
86 | 2. Make a copy of `scripts/prepare_alpaca.py` and name it what you want:
87 |
88 | ```bash
89 | cp scripts/prepare_alpaca.py scripts/prepare_mydata.py
90 | ```
91 |
92 | 3. Modify `scripts/prepare_mydata.py` to read the json data file.
93 | 4. Run the script to generate the preprocessed, tokenized train-val split:
94 |
95 | ```bash
96 | python scripts/prepare_mydata.py --destination_path data/mydata/
97 | ```
98 |
99 | 5. Run `finetune/adapter.py` by passing in the location of your data (and optionally other parameters):
100 |
101 | ```bash
102 | python finetune/adapter.py --data_dir data/mydata/ --out_dir out/myexperiment
103 | ```
104 |
105 |
106 | ## Troubleshooting
107 |
108 | If you run into a CUDA error "Expected is_sm80 to be true, but got false", uncomment the line
109 | `torch.backends.cuda.enable_flash_sdp(False)` in the finetuning script (see https://github.com/Lightning-AI/lit-llama/issues/101).
110 |
--------------------------------------------------------------------------------
/howto/finetune_adapter_v2.md:
--------------------------------------------------------------------------------
1 | # Finetuning with Adapter v2
2 |
3 | [LLaMA-Adapter v2](https://arxiv.org/abs/2304.15010) is a form of prefix-tuning that prepends a learnable adaption-prompt to the inputs of the attention blocks in LLaMA. In total, there are only ~4 M parameters to update during finetuning, which significantly reduces the memory footprint and speeds up training.
4 |
5 | We are able to demonstrate instruction-finetuning Lit-LLaMA 7B on the [Alpaca](https://github.com/tatsu-lab/stanford_alpaca) dataset on a **single RTX 3090 (24GB) GPU**. If using 8 GPUs, finetuning can be completed in under 1 hour.
6 |
7 | If you are new to LLaMA-Adapter and are interested to learn more about how it works before proceeding with the finetuning guide below, you might find our article [Understanding Parameter-Efficient Finetuning of Large Language Models: From Prefix Tuning to LLaMA-Adapters](https://lightning.ai/pages/community/article/understanding-llama-adapters/) helpful.
8 |
9 | ## LLaMA-Adapter v1 versus LLaMA-Adapter v2
10 |
11 | LLaMA-Adapter v2 extends the original LLaMA-Adapter idea by adding trainable bias and scale parameters to each linear layer in the transformer. Furthermore, LLaMA-Adapter v2 makes the normalization layers trainable. Where the 7B LLaMA model has 1.2 M trainable parameters with LLaMA-Adapter v1, LLaMA-Adapter v2 adds 2.8 M trainable parameters for the bias and scale parameters and ~300 k trainable parameters for the normalization layers. So, LLaMA-Adapter v2 has ~4.3 M trainable parameters in total.
12 |
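In pseudocode, the per-layer addition amounts to a trainable scale and bias wrapped around each frozen linear layer, roughly like this (a sketch only, not the code in `lit_llama/adapter_v2.py`):

```python
import torch
import torch.nn as nn


class BiasScaleLinear(nn.Module):
    """Sketch of the Adapter v2 idea: trainable scale and bias around a frozen linear layer."""

    def __init__(self, linear: nn.Linear):
        super().__init__()
        self.linear = linear  # frozen pretrained projection
        self.adapter_scale = nn.Parameter(torch.ones(linear.out_features))
        self.adapter_bias = nn.Parameter(torch.zeros(linear.out_features))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # scale-then-shift the frozen layer's output; cheap to compute at inference time
        return self.adapter_scale * self.linear(x) + self.adapter_bias
```
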
13 | If you are interested in using the more lightweight LLaMA-Adapter v1 approach, see [the related LLaMA Adapter how-to doc here](./finetune_adapter.md).
14 |
15 | While LLaMA-Adapter v2 increases the number of trainable parameters from 1.2 M (in LLaMA-Adapter v1) to 4.3 M, the inference cost is not significantly impacted. This is because the additional bias and scale parameters are cheap to compute in the forward pass, and the RMSNorm parameters are already included in the base model. In LLaMA-Adapter v1, the RMSNorm parameters are not trainable.
16 |
17 |
18 | ## Preparation
19 |
20 | The steps here only need to be done once:
21 |
22 | 1. Follow the instructions in the [README](../README.md) to install the dependencies.
23 | 2. Download and convert the weights and save them in the `./checkpoints` folder as described [here](download_weights.md).
24 | 3. If you want to utilize more than one GPU, you should `pip install deepspeed`.
25 | 4. Download the data and generate the Alpaca instruction tuning dataset:
26 |
27 | ```bash
28 | python scripts/prepare_alpaca.py
29 | ```
30 |
31 | or [prepare your own dataset](#tune-on-your-dataset).
32 |
33 | See also: [Finetuning on an unstructured dataset](unstructured_dataset.md)
34 |
35 | ## Running the finetuning
36 |
37 | ```bash
38 | python finetune/adapter_v2.py
39 | ```
40 |
41 | The finetuning requires at least one GPU with ~24 GB memory (RTX 3090).
42 | You can speed up training by setting the `devices` variable in the script to utilize more GPUs if available.
43 | Depending on the available GPU memory, you can also tune the `micro_batch_size` parameter to utilize the GPU efficiently.
44 |
45 | For example, the following settings will let you finetune the model in under 1 hour using DeepSpeed Zero-2:
46 |
47 | ```python
48 | devices = 8
49 | micro_batch_size = 8
50 | ```
51 |
52 | This script will save checkpoints periodically to the folder `out/`.
53 |
54 | > **Note**
55 | > All scripts support argument [customization](customize_paths.md)
56 |
57 | ## Test the model
58 |
59 | You can test the finetuned model with your own instructions by running:
60 |
61 | ```bash
62 | python generate/adapter_v2.py \
63 | --prompt "Recommend a movie to watch on the weekend." \
64 | --quantize llm.int8
65 | ```
66 | Output:
67 | ```
68 | A good movie to watch on the weekend would be The Lion King, since it's a classic family film that everyone can enjoy...
69 | ```
70 | If your GPU supports `bfloat16`, the script will automatically use it. Together with `--quantize llm.int8`, this brings the memory consumption down to ~8 GB.
71 |
72 | ## Tune on your dataset
73 |
74 | With only a few modifications, you can prepare and train on your own instruction dataset.
75 |
76 | 1. Create a json file in which each row holds one instruction-response pair.
77 | A row has an entry for 'instruction', 'input', and 'output', where 'input' is optional and can be
78 | the empty string if the instruction doesn't require a context. Below is an example json file:
79 |
80 | ```
81 | [
82 | {
83 | "instruction": "Arrange the given numbers in ascending order.",
84 | "input": "2, 4, 0, 8, 3",
85 | "output": "0, 2, 3, 4, 8"
86 | },
87 | ...
88 | ]
89 | ```
90 |
91 | 2. Make a copy of `scripts/prepare_alpaca.py` and name it what you want:
92 |
93 | ```bash
94 | cp scripts/prepare_alpaca.py scripts/prepare_mydata.py
95 | ```
96 |
97 | 3. Modify `scripts/prepare_mydata.py` to read the json data file.
98 | 4. Run the script to generate the preprocessed, tokenized train-val split:
99 |
100 | ```bash
101 | python scripts/prepare_mydata.py --destination_path data/mydata/
102 | ```
103 |
104 | 5. Run `finetune/adapter_v2.py` by passing in the location of your data (and optionally other parameters):
105 |
106 | ```bash
107 | python finetune/adapter_v2.py --data_dir data/mydata/ --out_dir out/myexperiment
108 | ```
109 |
110 |
111 | ## Troubleshooting
112 |
113 | If you run into a CUDA error "Expected is_sm80 to be true, but got false", uncomment the line
114 | `torch.backends.cuda.enable_flash_sdp(False)` in the finetuning script (see https://github.com/Lightning-AI/lit-llama/issues/101).
115 |
--------------------------------------------------------------------------------
/howto/finetune_full.md:
--------------------------------------------------------------------------------
1 | # Full Finetuning
2 |
3 | Full finetuning updates all layers in the pretrained LLaMA model. This *regular* finetuning procedure is typically considered as the baseline for parameter-efficient alternatives such as Low-Rank Adaptation (LoRA) or LLaMA-Adapter.
4 |
5 | The current [finetune/full.py](../finetune/full.py) we provide uses 4 A100 GPUs with a fully-sharded data parallel strategy to finetune Lit-LLaMA 7B on the [Alpaca](https://github.com/tatsu-lab/stanford_alpaca) dataset. The A100 GPUs have 40 GB each, although finetuning this model may require less memory.
6 |
7 |
8 |
9 | ## Preparation
10 |
11 | The steps here only need to be done once:
12 |
13 | 1. Follow the instructions in the [README](../README.md) to install the dependencies.
14 |
15 | 2. Download and convert the weights and save them in the `./checkpoints` folder as described [here](download_weights.md).
16 |
17 | 3. Download the data and generate the Alpaca instruction tuning dataset:
18 |
19 | ```bash
20 | python scripts/prepare_alpaca.py
21 | ```
22 |
23 | or [prepare your own dataset](#tune-on-your-dataset).
24 |
25 | See also: [Finetuning on an unstructured dataset](unstructured_dataset.md)
26 |
27 | ## Running the finetuning
28 |
29 | ```bash
30 | python finetune/full.py
31 | ```
32 |
33 |
34 | You can speed up training by setting the `devices` variable in the script to utilize more GPUs if available or increase the `batch_size`.
35 | Depending on the available GPU memory, you can also tune the `micro_batch_size` parameter to utilize the GPU efficiently.
36 |
37 | For example, the following settings will let you finetune the model in 32 hours using a fully-sharded data parallel strategy:
38 | ```python
39 | devices = 4
40 | batch_size = 128 // devices
41 | micro_batch_size = 4
42 | ```
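
To see how these two values interact, assume the script follows the usual gradient-accumulation pattern (accumulating micro-batches until `batch_size` samples per device have been processed); the arithmetic then works out as follows:

```python
# Back-of-the-envelope arithmetic for the example settings above (assumption:
# gradients are accumulated until batch_size samples per device have been seen).
devices = 4
batch_size = 128 // devices       # 32 samples per device per optimizer step
micro_batch_size = 4
gradient_accumulation_steps = batch_size // micro_batch_size
print(gradient_accumulation_steps)  # 8 forward/backward passes per optimizer step
```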
43 |
44 | This script will save checkpoints periodically to the folder `out/`.
45 |
46 | > **Note**
47 | > All scripts support argument [customization](customize_paths.md)
48 |
49 | ## Test the model
50 |
51 | You can test the finetuned model with your own instructions by running:
52 |
53 | ```bash
54 | python generate/full.py \
55 | --prompt "Recommend a movie to watch on the weekend." \
56 | --quantize llm.int8
57 | ```
58 | Output:
59 | ```
60 | A good movie to watch on the weekend would be The Lion King, since it's a classic family film that everyone can enjoy...
61 | ```
62 | If your GPU supports `bfloat16`, the script will automatically use it. Together with `--quantize llm.int8`, this brings the memory consumption down to ~8 GB.
63 |
64 | ## Tune on your dataset
65 |
66 | With only a few modifications, you can prepare and train on your own instruction dataset.
67 |
68 | 1. Create a json file in which each row holds one instruction-response pair.
69 | A row has an entry for 'instruction', 'input', and 'output', where 'input' is optional and can be
70 | the empty string if the instruction doesn't require a context. Below is an example json file:
71 |
72 | ```
73 | [
74 | {
75 | "instruction": "Arrange the given numbers in ascending order.",
76 | "input": "2, 4, 0, 8, 3",
77 | "output": "0, 2, 3, 4, 8"
78 | },
79 | ...
80 | ]
81 | ```
82 |
83 | 2. Make a copy of `scripts/prepare_alpaca.py` and name it what you want:
84 |
85 | ```bash
86 | cp scripts/prepare_alpaca.py scripts/prepare_mydata.py
87 | ```
88 |
89 | 3. Modify `scripts/prepare_mydata.py` to read the json data file.
90 | 4. Run the script to generate the preprocessed, tokenized train-val split:
91 |
92 | ```bash
93 | python scripts/prepare_mydata.py --destination_path data/mydata/
94 | ```
95 |
96 | 5. Run `finetune/full.py` by passing in the location of your data (and optionally other parameters):
97 |
98 | ```bash
99 | python finetune/full.py --data_dir data/mydata/ --out_dir out/myexperiment
100 | ```
101 |
102 |
103 | ## Troubleshooting
104 |
105 | If you run into a CUDA error "Expected is_sm80 to be true, but got false", uncomment the line
106 | `torch.backends.cuda.enable_flash_sdp(False)` in the finetuning script (see https://github.com/Lightning-AI/lit-llama/issues/101).
107 |
--------------------------------------------------------------------------------
/howto/finetune_lora.md:
--------------------------------------------------------------------------------
1 | # Finetuning with LoRA
2 |
3 | [Low-rank adaption (LoRA)](https://arxiv.org/abs/2106.09685) is a technique to approximate the update to the linear layers in a LLM with a low-rank matrix factorization. This significantly reduces the number of trainable parameters and speeds up training with little impact on the final performance of the model.
4 | We demonstrate this method by instruction-finetuning LLaMA 7B on the [Alpaca](https://github.com/tatsu-lab/stanford_alpaca) dataset on a **single RTX 3090 (24GB) GPU**.
5 |
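The core trick fits in a few lines. The sketch below is purely illustrative (the real implementation lives in `lit_llama/lora.py`): a frozen weight `W` is augmented with a trainable low-rank product `B @ A`, scaled by `alpha / r`.

```python
import torch

d, r, alpha = 4096, 8, 16
W = torch.randn(d, d)             # frozen pretrained weight (not updated)
A = torch.randn(r, d) * 0.01      # trainable low-rank factor
B = torch.zeros(d, r)             # trainable, zero-initialized so training starts from W
x = torch.randn(1, d)

# forward pass: base projection plus the scaled low-rank correction
h = x @ W.T + (alpha / r) * ((x @ A.T) @ B.T)
```
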
6 | ## Preparation
7 |
8 | The steps here only need to be done once:
9 |
10 | 1. Follow the instructions in the [README](../README.md) to install the dependencies.
11 | 2. Download and convert the weights and save them in the `./checkpoints` folder as described [here](download_weights.md).
12 | 3. Download the data and generate the instruction tuning dataset:
13 |
14 | ```bash
15 | python scripts/prepare_alpaca.py
16 | ```
17 |
18 | See also: [Finetuning on an unstructured dataset](unstructured_dataset.md)
19 |
20 | ## Running the finetuning
21 |
22 | ```bash
23 | python finetune/lora.py
24 | ```
25 |
26 | The finetuning requires at least one GPU with ~24 GB memory (RTX 3090).
27 |
28 | This script will save checkpoints periodically to the folder `out/`.
29 |
30 | > **Note**
31 | > All scripts support argument [customization](customize_paths.md)
32 |
33 |
34 | ## Test the model
35 |
36 | You can test the finetuned model with your own instructions by running:
37 |
38 | ```bash
39 | python generate/lora.py --prompt "Recommend a movie to watch on the weekend."
40 | ```
41 | Output:
42 | ```
43 | I would recommend the movie The Martian (2015). It is a sci-fi movie starring Matt Damon that follows the story of...
44 | ```
45 |
46 | If your GPU supports `bfloat16`, you can additionally pass `--dtype bfloat16` to bring the memory consumption down to ~14 GB.
47 |
48 | ## Tune on your dataset
49 |
50 | With only a few modifications, you can prepare and train on your own instruction dataset.
51 |
52 | 1. Create a json file in which each row holds one instruction-response pair.
53 | A row has an entry for 'instruction', 'input', and 'output', where 'input' is optional and can be
54 | the empty string if the instruction doesn't require a context. Below is an example json file:
55 |
56 | ```
57 | [
58 | {
59 | "instruction": "Arrange the given numbers in ascending order.",
60 | "input": "2, 4, 0, 8, 3",
61 | "output": "0, 2, 3, 4, 8"
62 | },
63 | ...
64 | ]
65 | ```
66 |
67 | 2. Make a copy of `scripts/prepare_alpaca.py` and name it what you want:
68 |
69 | ```bash
70 | cp scripts/prepare_alpaca.py scripts/prepare_mydata.py
71 | ```
72 |
73 | 3. Modify `scripts/prepare_mydata.py` to read the json data file.
74 | 4. Run the script to generate the preprocessed, tokenized train-val split:
75 |
76 | ```bash
77 | python scripts/prepare_mydata.py --destination_path data/mydata/
78 | ```
79 |
80 | 5. Run `finetune/lora.py` by passing in the location of your data (and optionally other parameters):
81 |
82 | ```bash
83 | python finetune/lora.py --data_dir data/mydata/ --out_dir out/myexperiment
84 | ```
85 |
86 |
87 | ## Troubleshooting
88 |
89 | If you run into a CUDA error "Expected is_sm80 to be true, but got false", uncomment the line
90 | `torch.backends.cuda.enable_flash_sdp(False)` in the finetuning script (see https://github.com/Lightning-AI/lit-llama/issues/101).
91 |
--------------------------------------------------------------------------------
/howto/inference.md:
--------------------------------------------------------------------------------
1 | # Inference
2 |
3 | We demonstrate how to run inference (next token prediction) with the LLaMA base model in the [`generate.py`](generate.py) script:
4 |
5 | ```bash
6 | python generate.py --prompt "Hello, my name is"
7 | ```
8 | Output:
9 | ```
10 | Hello my name is TJ. I have a passion for the outdoors, love hiking and exploring. I also enjoy traveling and learning new things. I especially enjoy long walks, good conversation and a friendly smile.
11 | ```
12 |
13 | The script assumes you have downloaded and converted the weights and saved them in the `./checkpoints` folder as described [here](download_weights.md).
14 |
15 | > **Note**
16 | > All scripts support argument [customization](customize_paths.md)
17 |
18 | With the default settings, this will run the 7B model and require ~26 GB of GPU memory (A100 GPU).
19 |
20 | ## Run Lit-LLaMA on consumer devices
21 |
22 | On GPUs with `bfloat16` support, the `generate.py` script will automatically convert the weights and consume about 14 GB.
23 | For GPUs with less memory, or ones that don't support `bfloat16`, enable quantization (`--quantize llm.int8`):
24 |
25 | ```bash
26 | python generate.py --quantize llm.int8 --prompt "Hello, my name is"
27 | ```
28 | This will consume about 10 GB of GPU memory, or ~8 GB if also using `bfloat16`.
29 | See `python generate.py --help` for more options.
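
Whether the script takes the `bfloat16` path is decided automatically. A quick check that mirrors the precision selection in `generate.py`:

```python
import torch

# bf16 is used only when CUDA is available and the device supports it;
# otherwise the scripts fall back to 32-bit precision.
use_bf16 = torch.cuda.is_available() and torch.cuda.is_bf16_supported()
print("bf16-true" if use_bf16 else "32-true")
```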
30 |
31 | You can also use GPTQ-style int4 quantization, but this needs conversions of the weights first:
32 |
33 | ```bash
34 | python quantize/gptq.py --output_path checkpoints/lit-llama/7B/llama-gptq.4bit.pth --dtype bfloat16 --quantize gptq.int4
35 | ```
36 |
37 | GPTQ-style int4 quantization brings GPU usage down to about 5 GB. As only the weights of the Linear layers are quantized, it is useful to also use `--dtype bfloat16` even with the quantization enabled.
38 |
39 | With the quantized checkpoint generated, generation then works as usual with `--quantize gptq.int4` and the newly generated checkpoint file:
40 |
41 | ```bash
42 | python generate.py --quantize gptq.int4 --checkpoint_path checkpoints/lit-llama/7B/llama-gptq.4bit.pth
43 | ```
44 |
--------------------------------------------------------------------------------
/howto/tpus.md:
--------------------------------------------------------------------------------
1 | # TPU support
2 |
3 | Lit-LLaMA uses `lightning.Fabric` under the hood, which itself supports TPUs (via [PyTorch XLA](https://github.com/pytorch/xla)).
4 |
5 | The following commands will allow you to set up a `Google Cloud` instance with a [TPU v4](https://cloud.google.com/tpu/docs/system-architecture-tpu-vm) VM:
6 |
7 | ```shell
8 | gcloud compute tpus tpu-vm create lit-llama --version=tpu-vm-v4-pt-2.0 --accelerator-type=v4-8 --zone=us-central2-b
9 | gcloud compute tpus tpu-vm ssh lit-llama --zone=us-central2-b
10 | ```
11 |
12 | Now that you are on the machine, let's clone the repository and install the dependencies:
13 |
14 | ```shell
15 | git clone https://github.com/Lightning-AI/lit-llama
16 | cd lit-llama
17 | pip install -e ".[all]"
18 | ```
19 |
20 | By default, computations will run using the new (and experimental) PjRT runtime. Still, it's recommended that you set the following environment variables
21 |
22 | ```shell
23 | export PJRT_DEVICE=TPU
24 | export ALLOW_MULTIPLE_LIBTPU_LOAD=1
25 | ```
26 |
27 | > **Note**
28 | > You can find an extensive guide on how to get set-up and all the available options [here](https://cloud.google.com/tpu/docs/v4-users-guide).
29 |
30 | Since you created a new machine, you'll probably need to download the weights. You could scp them into the machine with `gcloud compute tpus tpu-vm scp` or you can follow the steps described in our [downloading guide](download_weights.md).
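
Before launching the full scripts, you can sanity-check the TPU setup with a few lines of Fabric (an illustrative snippet, not part of the repo):

```python
import lightning as L
import torch

# Create a Fabric object on the TPU accelerator and move a tensor to the XLA device.
fabric = L.Fabric(accelerator="tpu", devices=1)
fabric.launch()
x = fabric.to_device(torch.ones(2, 2))
print(x.device)  # expected to be an XLA device, e.g. xla:0
```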
31 |
32 | ## Inference
33 |
34 | Generation works out-of-the-box with TPUs:
35 |
36 | ```shell
37 | python3 generate.py --prompt "Hello, my name is" --num_samples 3
38 | ```
39 |
40 | This command will take ~20s for the first generation, as XLA needs to compile the graph.
41 | You'll notice that afterwards, generation times drop to ~5s.
42 |
43 | ## Finetuning
44 |
45 | Coming soon.
46 |
47 | > **Warning**
48 | > When you are done, remember to delete your instance
49 | > ```shell
50 | > gcloud compute tpus tpu-vm delete lit-llama --zone=us-central2-b
51 | > ```
--------------------------------------------------------------------------------
/howto/train_redpajama.md:
--------------------------------------------------------------------------------
1 | # Pre-train LLaMA on RedPajama
2 |
3 | This howto will walk you through setting up the RedPajama dataset and launching the pre-training script.
4 |
5 | ## What's RedPajama
6 |
7 | [RedPajama](https://github.com/togethercomputer/RedPajama-Data) is an open-source reproduction of the original LLaMA training dataset.
8 |
9 | It contains a total of 1.2 trillion tokens, divided into
10 |
11 | ```text
12 | Commoncrawl 878B
13 | C4 175B
14 | GitHub 59B
15 | Books 26B
16 | ArXiv 28B
17 | Wikipedia 24B
18 | StackExchange 20B
19 | ```
20 |
21 | The [RedPajama repo](https://github.com/togethercomputer/RedPajama-Data) contains the source code for collecting and preparing
22 | the dataset, and it is Apache 2.0 licensed.
23 |
24 | The data itself is licensed according to the original licenses with which its individual parts were released.
25 | The GitHub datasets are limited to MIT, BSD, or Apache 2.0 repositories.
26 |
27 | Along with the full [RedPajama-1T dataset](https://huggingface.co/datasets/togethercomputer/RedPajama-Data-1T),
28 | the [RedPajama-1T-Sample](https://huggingface.co/datasets/togethercomputer/RedPajama-Data-1T-Sample) 1B sample dataset
29 | is also available for development.
30 |
31 | You can download the data using git lfs:
32 |
33 | ```bash
34 | # Make sure you have git-lfs installed (https://git-lfs.com): git lfs install
35 | git clone https://huggingface.co/datasets/togethercomputer/RedPajama-Data-1T data/RedPajama-Data-1T
36 | ```
37 |
38 | ```bash
39 | # Make sure you have git-lfs installed (https://git-lfs.com): git lfs install
40 | git clone https://huggingface.co/datasets/togethercomputer/RedPajama-Data-1T-Sample data/RedPajama-Data-1T-Sample
41 | ```
42 |
43 | ## Prepare RedPajama for training
44 |
45 | The dataset consists of 2084 `jsonl` files (the sample dataset contains 11). In order to start pre-training lit-llama
46 | on it, you need to read, tokenize, and write the data in binary chunks. This will leverage the `PackedDataset`
47 | streaming dataset that comes with lit-llama.
48 |
49 | To do so, run
50 |
51 | ```bash
52 | python scripts/prepare_redpajama.py --source_path data/RedPajama-Data-1T --tokenizer_path checkpoints/lit-llama/tokenizer.model --destination_path data/lit-redpajama
53 | ```
54 |
55 | or
56 |
57 | ```bash
58 | python scripts/prepare_redpajama.py --source_path data/RedPajama-Data-1T-Sample --tokenizer_path checkpoints/lit-llama/tokenizer.model --destination_path data/lit-redpajama-sample --sample True
59 | ```
60 |
61 | for the sample dataset.
62 |
63 | The above assumes that you will be using the same tokenizer as LLaMA, but any trained [SentencePiece](https://github.com/google/sentencepiece) tokenizer with a 32000-token vocabulary will work here.
64 |
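If you bring your own tokenizer, a quick way to verify the vocabulary size (the path below is the repo's default tokenizer location):

```python
import sentencepiece as spm

sp = spm.SentencePieceProcessor(model_file="checkpoints/lit-llama/tokenizer.model")
assert sp.vocab_size() == 32000, f"unexpected vocab size: {sp.vocab_size()}"
```
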
65 | The script will take a while to run, so time for :tea:
66 |
67 | ## Pre-training
68 |
69 | Running the pre-training script requires at least 4 GPUs with 40GB+ each (A100).
70 |
71 | ```bash
72 | python pretrain/redpajama.py --devices 4 --train_data_dir data/lit-redpajama
73 | ```
74 |
75 | For running on the sample dataset:
76 |
77 | ```bash
78 | python pretrain/redpajama.py --devices 4 --train_data_dir data/lit-redpajama-sample
79 | ```
80 |
81 | The script will save checkpoints periodically to the folder `out/`.
82 |
83 | The `pretrain/redpajama.py` script will pre-train the LLaMA 7B model with FSDP in
84 | `bfloat16` precision and gradient accumulation.
85 |
86 | You can easily change the size of the model by passing a different string to
87 |
88 | ```python
89 | config = LLaMAConfig.from_name("7B")
90 | ```
91 |
92 | in the `main` function.
93 |
94 | Keep in mind that the original LLaMA training for the 7B model required 83k A100 80GB
95 | hours, so you'll need access to a cluster.
96 |
97 | Once you're in a cluster, you can follow [these instructions](https://lightning.ai/docs/fabric/stable/guide/multi_node/other.html)
98 | to launch the script across machines:
99 |
100 | - [SLURM cluster](https://lightning.ai/docs/fabric/stable/guide/multi_node/slurm.html)
101 | - [Barebones cluster](https://lightning.ai/docs/fabric/stable/guide/multi_node/barebones.html)
102 | - [MPI](https://lightning.ai/docs/fabric/stable/guide/multi_node/other.html)
103 |
104 | The script contains several configurations and hyperparameters you can tweak:
105 |
106 | ```python
107 | out_dir = "out/training"
108 | save_interval = 1000
109 | eval_interval = 1000
110 | eval_iters = 100
111 | log_interval = 1
112 |
113 | # Hyperparameters
114 | learning_rate = 6e-4
115 | batch_size = 125
116 | micro_batch_size = 5
117 | max_iters = 600000 # num_epochs * (epoch_size // micro_batch_size) // devices
118 | weight_decay = 1e-1
119 | beta1 = 0.9
120 | beta2 = 0.95
121 | grad_clip = 1.0
122 | decay_lr = True
123 | warmup_iters = 2000
124 | lr_decay_iters = max_iters
125 | min_lr = 6e-5
126 | ```
127 |
128 | In particular, `micro_batch_size` should be adjusted to make the best use of the available
129 | GPU memory.
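
For intuition, the relationship between the two batch-size settings is roughly the following (a sketch of the usual bookkeeping, not code copied from the script):

```python
# Sketch: how batch_size and micro_batch_size typically interact (per device).
batch_size = 125        # samples contributing to one optimizer step
micro_batch_size = 5    # samples per forward/backward pass; bounded by GPU memory
gradient_accumulation_steps = batch_size // micro_batch_size  # -> 25 passes per optimizer step
```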
130 |
131 | Last, logging is kept minimal in the script. In order to use a particular logger,
132 | please refer to the Lightning Fabric loggers documentation, or
133 | call a logging client library like `wandb` directly.
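
For instance, a minimal `wandb` sketch, assuming `wandb` is installed and metrics are only logged from rank zero (the project name and metric keys are placeholders; `fabric`, `iter_num`, `loss` and `log_interval` refer to the names used in the training loop):

```python
# Minimal sketch: log the training loss with wandb directly (assumes `pip install wandb`).
import wandb

wandb.init(project="lit-llama-redpajama")  # placeholder project name

# ... inside the training loop, next to the existing fabric.print(...) call:
if fabric.global_rank == 0 and iter_num % log_interval == 0:
    wandb.log({"train/loss": loss.item()}, step=iter_num)
```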
134 |
--------------------------------------------------------------------------------
/howto/unstructured_dataset.md:
--------------------------------------------------------------------------------
1 | # Finetuning on an unstructured dataset
2 |
3 | While most scripts were made to finetune on instruction datasets, it is possible to finetune on any text dataset. This is useful for experimentation, while not being as expensive as training a full model from scratch.
4 |
5 | This guide only covers preparing the dataset for finetuning; both the LoRA and Adapter v1 methods support this dataset type!
6 |
7 | ## Preparation
8 |
9 | 1. Gather your text into an input file named `input.txt`
10 | 2. Divide the data into training and validation sets using the following script (its default paths are summarized after this list):
11 |
12 | ```bash
13 | python scripts/prepare_any_text.py
14 | ```
15 |
16 | 3. Modify relevant scripts for your finetuning method under `finetune/` and `evaluate/`, setting the `instruction_tuning` variable to `False`
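
By default, `scripts/prepare_any_text.py` expects `input.txt` inside its destination directory and writes `train.pt` and `test.pt` next to it. The relevant defaults from its `prepare()` signature are summarized below (override them on the command line if your paths differ):

```python
# Defaults taken from the prepare() signature in scripts/prepare_any_text.py:
destination_path = "data/any"                              # input.txt must live here; splits are written here
tokenizer_path = "checkpoints/lit-llama/tokenizer.model"   # the LLaMA tokenizer
test_split_ratio = 0.9                                     # 90% train / 10% validation
max_seq_length = 256                                       # tokens kept per line
data_file_name = "input.txt"                               # the file from step 1
```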
17 |
18 | And then you're set! Proceed by following the [LoRA guide](./finetune_lora.md) or the [Adapter v1 guide](./finetune_adapter.md).
--------------------------------------------------------------------------------
/lit_llama/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
2 |
3 | from lit_llama.model import LLaMAConfig, LLaMA, RMSNorm, build_rope_cache, apply_rope
4 | from lit_llama.tokenizer import Tokenizer
5 |
--------------------------------------------------------------------------------
/lit_llama/adapter_v2.py:
--------------------------------------------------------------------------------
1 | # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
2 |
3 | import torch
4 | from torch import Tensor
5 | import torch.nn as nn
6 | from torch.nn import functional as F
7 |
8 | from lit_llama.adapter import LLaMA
9 |
10 |
11 | def get_adapter_substrings():
12 | substrings = ["adapter_wte", "gating_factor"] # regular adapter v1 parameters
13 | substrings.extend(["adapter_scale", "adapter_bias"]) # adapter v2: new bias and scale used in Linear
14 | substrings.extend(["rms_1", "rms_2", "ln_f"]) # adapter v2: RMSNorm parameters are now trainable
15 | return substrings
16 |
17 |
18 | def mark_only_adapter_v2_as_trainable(model: LLaMA) -> None:
19 | """Sets `requires_grad=False` for all non-adapter weights."""
20 | for name, param in model.named_parameters():
21 | param.requires_grad = any(s in name for s in get_adapter_substrings())
22 |
23 |
24 | def adapter_v2_state_from_state_dict(state_dict: dict) -> dict:
25 | """Returns the model state dict with only the adapter weights for saving."""
26 | return {name: param for name, param in state_dict.items()
27 | if any(s in name for s in get_adapter_substrings())}
28 |
29 |
30 | def adapter_v2_new_forward(self, input: Tensor) -> Tensor:
31 | return self.adapter_scale * (
32 | F.linear(input, self.weight, self.bias) + self.adapter_bias
33 | )
34 |
35 |
36 | def adapter_v2_linear_with_bias_and_scale(layer):
37 | layer.adapter_bias = torch.nn.Parameter(torch.zeros(layer.weight.shape[0]), requires_grad=True)
38 | layer.adapter_scale = torch.nn.Parameter(torch.ones(layer.weight.shape[0]), requires_grad=True)
39 | bound_method = adapter_v2_new_forward.__get__(layer, layer.__class__)
40 | setattr(layer, 'forward', bound_method)
41 | return layer
42 |
43 |
44 | def add_adapter_v2_parameters_to_linear_layers(model):
45 | for module in model.modules():
46 | if isinstance(module, nn.Linear):
47 | adapter_v2_linear_with_bias_and_scale(module)
48 |
--------------------------------------------------------------------------------
/lit_llama/packed_dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
2 |
3 | # Very loosely inspired by indexed_dataset in Fairseq, Megatron
4 | # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/data/indexed_dataset.py
5 |
6 |
7 | import os
8 | import struct
9 | import random
10 |
11 | import numpy as np
12 | import torch
13 | from torch.utils.data import IterableDataset, get_worker_info
14 |
15 |
16 | dtypes = {
17 | 1: np.uint8,
18 | 2: np.int8,
19 | 3: np.int16,
20 | 4: np.int32,
21 | 5: np.int64,
22 | 6: np.float32,
23 | 7: np.float64,
24 | 8: np.uint16,
25 | }
26 |
27 |
28 | def code(dtype):
29 | for k in dtypes.keys():
30 | if dtypes[k] == dtype:
31 | return k
32 | raise ValueError(dtype)
33 |
34 |
35 | HDR_MAGIC = b"LITPKDS"
36 | HDR_SIZE = 24 # bytes
37 |
38 |
39 | class PackedDataset(IterableDataset):
40 | def __init__(self, filenames, n_chunks, block_size, seed=12345, shuffle=True, wrap=False, num_processes=1, process_rank=0):
41 | self._filenames = filenames
42 | self._n_chunks = n_chunks
43 | self._block_size = block_size
44 | self._seed = seed
45 | self._shuffle = shuffle
46 | self._wrap = wrap
47 | self._num_processes = num_processes
48 | self._process_rank = process_rank
49 |
50 | def __iter__(self):
51 | worker_info = get_worker_info()
52 | num_workers = worker_info.num_workers if worker_info is not None else 1
53 | worker_id = worker_info.id if worker_info is not None else 0
54 | num_shards = num_workers * self._num_processes
55 | shard_id = self._process_rank * num_workers + worker_id
56 |
57 | max_num_files = len(self._filenames) // num_shards * num_shards
58 | filenames = self._filenames[shard_id : max_num_files : num_shards]
59 |
60 | return PackedDatasetIterator(
61 | filenames=filenames,
62 | n_chunks=self._n_chunks,
63 | block_size=self._block_size,
64 | seed=self._seed,
65 | shuffle=self._shuffle,
66 | wrap=self._wrap,
67 | )
68 |
69 |
70 | class PackedDatasetBuilder(object):
71 | def __init__(
72 | self,
73 | outdir,
74 | prefix,
75 | chunk_size,
76 | sep_token,
77 | dtype="auto",
78 | vocab_size=None,
79 | ):
80 | if dtype == "auto":
81 | if vocab_size is None:
82 | raise ValueError("vocab_size cannot be None when dtype='auto'")
83 | if vocab_size is not None and vocab_size < 65500:
84 | self._dtype = np.uint16
85 | else:
86 | self._dtype = np.int32
87 | else:
88 | self._dtype = dtype
89 | self._counter = 0
90 | self._chunk_size = chunk_size
91 | self._outdir = outdir
92 | self._prefix = prefix
93 | self._sep_token = sep_token
94 | self._arr = np.zeros(self._chunk_size, dtype=self._dtype)
95 | self._arr.fill(self._sep_token)
96 | self._idx = 0
97 | self._version = 1
98 | self._filenames = []
99 |
100 | def _write_chunk(self):
101 | filename = f"{self._prefix}_{self._counter:010d}.bin"
102 | filename = os.path.join(self._outdir, filename)
103 |
104 | with open(filename, "wb") as f:
105 | f.write(HDR_MAGIC)
106 | f.write(struct.pack("<Q", self._version))
107 | f.write(struct.pack("<B", code(self._dtype)))
108 | f.write(struct.pack("<Q", self._chunk_size))
109 | f.write(self._arr.tobytes(order="C"))
110 |
111 | self._filenames.append(filename)
112 | self._counter += 1
113 | self._arr.fill(self._sep_token)
114 | self._idx = 0
115 |
116 | @property
117 | def dtype(self):
118 | return self._dtype
119 |
120 | @property
121 | def filenames(self):
122 | return self._filenames.copy()
123 |
124 | def add_array(self, arr):
125 | while self._idx + arr.shape[0] > self._chunk_size:
126 | part_len = self._chunk_size - self._idx
127 | self._arr[self._idx : self._idx + part_len] = arr[:part_len]
128 | self._write_chunk()
129 | arr = arr[part_len:]
130 |
131 | arr_len = arr.shape[0]
132 | self._arr[self._idx : self._idx + arr_len] = arr
133 | self._idx += arr_len
134 |
135 | def write_reminder(self):
136 | self._write_chunk()
137 |
138 |
139 | class PackedDatasetIterator:
140 | def __init__(self, filenames, n_chunks, block_size, seed, shuffle, wrap):
141 | self._seed = seed
142 | self._shuffle = shuffle
143 | self._rng = np.random.default_rng(seed) if shuffle else None
144 | self._block_idxs = None
145 |
146 | self._wrap = wrap
147 |
148 | # TODO: instead of filenames, we could have a single text stream
149 | # (or text file) with the sequence of all files to be
150 | # fetched/loaded.
151 | self._filenames = filenames
152 | self._file_idx = 0
153 |
154 | self._n_chunks = n_chunks
155 |
156 | self._dtype = None
157 | self._block_size = block_size
158 | self._n_blocks = None
159 |
160 | self._mmaps = []
161 | self._buffers = []
162 |
163 | self._block_idxs = []
164 | self._curr_idx = 0
165 |
166 | self._load_n_chunks()
167 |
168 | def _read_header(self, path):
169 | with open(path, "rb") as f:
170 | magic = f.read(len(HDR_MAGIC))
171 | assert magic == HDR_MAGIC, "File doesn't match expected format."
172 | version = struct.unpack("<Q", f.read(8))
173 | assert version == (1,)
174 | (dtype,) = struct.unpack("<B", f.read(1))
175 | dtype = dtypes[dtype]
176 | (chunk_size,) = struct.unpack("<Q", f.read(8))
177 | return dtype, chunk_size
178 |
179 | def _close_mmaps(self):
180 | for mmap in self._mmaps:
181 | mmap._mmap.close()
182 |
183 | def _load_n_chunks(self):
184 | self._close_mmaps()
185 | self._mmaps = []
186 | self._buffers = []
187 |
188 | if self._n_chunks > len(self._filenames[self._file_idx:]):
189 | if not self._wrap:
190 | raise StopIteration
191 | else:
192 | self._file_idx = 0
193 |
194 | for i in range(self._n_chunks):
195 | filename = self._filenames[self._file_idx + i]
196 | if self._dtype is None:
197 | self._dtype, self._chunk_size = self._read_header(
198 | filename
199 | )
200 | self._n_blocks = self._chunk_size // self._block_size
201 | # TODO: check header matches with previous files
202 | mmap = np.memmap(filename, mode="r", order="C", offset=HDR_SIZE)
203 | self._mmaps.append(mmap)
204 | self._buffers.append(memoryview(mmap))
205 |
206 | self._file_idx += self._n_chunks
207 | n_all_blocks = self._n_chunks * self._n_blocks
208 |
209 | self._block_idxs = (
210 | self._rng.permutation(n_all_blocks)
211 | if self._shuffle
212 | else range(n_all_blocks)
213 | )
214 |
215 | self._curr_idx = 0
216 |
217 | def __del__(self):
218 | self._close_mmaps()
219 | del self._mmaps
220 | del self._buffers
221 |
222 | def __iter__(self):
223 | return self
224 |
225 | def __next__(self):
226 | if self._curr_idx >= len(self._block_idxs):
227 | self._load_n_chunks()
228 | # TODO: trigger fetching next next n_chunks if remote
229 | block_idx = self._block_idxs[self._curr_idx]
230 | chunk_id = block_idx // self._n_blocks
231 | buffer = self._buffers[chunk_id]
232 | elem_id = (block_idx % self._n_blocks) * self._block_size
233 | offset = np.dtype(self._dtype).itemsize * elem_id
234 | arr = np.frombuffer(
235 | buffer, dtype=self._dtype, count=self._block_size, offset=offset
236 | )
237 | self._curr_idx += 1
238 | return torch.from_numpy(arr.astype(np.int64))
239 |
240 |
241 | class CombinedDataset(IterableDataset):
242 | def __init__(self, datasets, seed, weights=None):
243 | self._seed = seed
244 | self._datasets = datasets
245 | self._weights = weights
246 | n_datasets = len(datasets)
247 | if weights is None:
248 | self._weights = [1 / n_datasets] * n_datasets
249 |
250 | def __iter__(self):
251 | return CombinedDatasetIterator(self._datasets, self._seed, self._weights)
252 |
253 |
254 | class CombinedDatasetIterator:
255 | def __init__(self, datasets, seed, weights):
256 | self._datasets = [iter(el) for el in datasets]
257 | self._weights = weights
258 | self._rng = random.Random(seed)
259 |
260 | def __next__(self):
261 | dataset, = self._rng.choices(self._datasets, weights=self._weights, k=1)
262 | return next(dataset)
263 |
--------------------------------------------------------------------------------
/lit_llama/tokenizer.py:
--------------------------------------------------------------------------------
1 | # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
2 |
3 | import os
4 | from pathlib import Path
5 | from typing import Optional
6 |
7 | import torch
8 | from sentencepiece import SentencePieceProcessor, SentencePieceTrainer
9 |
10 |
11 | class Tokenizer:
12 | """Tokenizer for LLaMA."""
13 |
14 | def __init__(self, model_path: Path) -> None:
15 | self.processor = SentencePieceProcessor(model_file=str(model_path))
16 | self.bos_id = self.processor.bos_id()
17 | self.eos_id = self.processor.eos_id()
18 | self.pad_id = self.processor.pad_id()
19 |
20 | @property
21 | def vocab_size(self) -> int:
22 | return self.processor.vocab_size()
23 |
24 | def encode(
25 | self,
26 | string: str,
27 | bos: bool = True,
28 | eos: bool = False,
29 | max_length: int = -1,
30 | pad: bool = False,
31 | device: Optional[torch.device] = None
32 | ) -> torch.Tensor:
33 | tokens = self.processor.encode(string)
34 | if bos:
35 | tokens = [self.bos_id] + tokens
36 | if eos:
37 | tokens = tokens + [self.eos_id]
38 | if max_length > 0:
39 | tokens = tokens[:max_length]
40 | if pad and len(tokens) < max_length:
41 | tokens += [self.pad_id] * (max_length - len(tokens))
42 |
43 | return torch.tensor(tokens, dtype=torch.int, device=device)
44 |
45 | def decode(self, tokens: torch.Tensor) -> str:
46 | return self.processor.decode(tokens.tolist())
47 |
48 | @staticmethod
49 | def train(input: str, destination: str, vocab_size=32000) -> None:
50 | model_prefix = os.path.join(destination, "tokenizer")
51 | SentencePieceTrainer.Train(input=input, model_prefix=model_prefix, vocab_size=vocab_size)
52 |
--------------------------------------------------------------------------------
/pretrain/shakespeare.py:
--------------------------------------------------------------------------------
1 | # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
2 |
3 | """
4 | This script is a placeholder for training LLaMA from scratch.
5 | Currently, it just trains on the Shakespeare dataset.
6 | """
7 | from pathlib import Path
8 | import sys
9 | import os
10 | import time
11 | from functools import partial
12 | from typing import Tuple
13 |
14 | import lightning as L
15 | from lightning.fabric.strategies import FSDPStrategy
16 |
17 | import torch
18 | from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
19 |
20 | import numpy as np
21 |
22 | # support running without installing as a package
23 | wd = Path(__file__).parent.parent.resolve()
24 | sys.path.append(str(wd))
25 |
26 | from lit_llama.model import Block, LLaMA, LLaMAConfig
27 | from lit_llama.utils import save_model_checkpoint
28 |
29 |
30 | out_dir = "out/training"
31 | eval_interval = 2000
32 | eval_iters = 200
33 | log_interval = 1
34 | # compilation fails as it does not support torch.complex64 for RoPE
35 | # compile = False
36 |
37 | # Hyperparameters
38 | learning_rate = 6e-4
39 | batch_size = 2
40 | max_iters = 600000
41 | weight_decay = 1e-1
42 | beta1 = 0.9
43 | beta2 = 0.95
44 | grad_clip = 1.0
45 |
46 | # For shakespeare, choose smaller block size than vanilla LLaMA
47 | block_size = 1024
48 |
49 |
50 | def main() -> None:
51 | auto_wrap_policy = partial(transformer_auto_wrap_policy, transformer_layer_cls={Block})
52 | strategy = FSDPStrategy(auto_wrap_policy=auto_wrap_policy, activation_checkpointing=Block, limit_all_gathers=True)
53 |
54 | fabric = L.Fabric(accelerator="cuda", devices=4, precision="bf16-mixed", strategy=strategy)
55 | fabric.launch()
56 | fabric.seed_everything(1337 + fabric.global_rank)
57 |
58 | if fabric.global_rank == 0:
59 | os.makedirs(out_dir, exist_ok=True)
60 |
61 | train_data, val_data = load_datasets()
62 |
63 | config = LLaMAConfig.from_name("7B")
64 | config.block_size = block_size
65 | config.vocab_size = 100 # from prepare_shakespeare.py
66 |
67 | with fabric.device:
68 | model = LLaMA(config)
69 |
70 | # if compile:
71 | # model = torch.compile(model)
72 |
73 | model = fabric.setup_module(model)
74 |
75 | optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay, betas=(beta1, beta2), foreach=False)
76 | optimizer = fabric.setup_optimizers(optimizer)
77 |
78 | train(fabric, model, optimizer, train_data, val_data)
79 |
80 |
81 | def train(
82 | fabric: L.Fabric,
83 | model: torch.nn.Module,
84 | optimizer: torch.optim.Optimizer,
85 | train_data: np.ndarray,
86 | val_data: np.ndarray,
87 | ) -> None:
88 | """The training loop.
89 |
90 | Loosely based on the nanoGPT implementation: https://github.com/karpathy/nanoGPT.
91 | """
92 |
93 | iter_num = 0
94 |
95 | while True:
96 | # TODO: add learning rate scheduling
97 |
98 | # evaluate the loss on train/val sets and write checkpoints
99 | if iter_num > 0 and iter_num % eval_interval == 0:
100 | val_loss = validate(fabric, model, val_data)
101 | fabric.print(f"step {iter_num}: val loss {val_loss:.4f}")
102 | fabric.print(f"Saving checkpoint to {out_dir}")
103 | save_model_checkpoint(fabric, model, os.path.join(out_dir, f"iter-{iter_num:06d}-ckpt.pth"))
104 |
105 | t0 = time.time()
106 |
107 | input_ids, targets = get_batch(
108 | fabric,
109 | train_data,
110 | block_size=model.config.block_size, # type: ignore[union-attr,arg-type]
111 | )
112 | logits = model(input_ids)
113 | loss = torch.nn.functional.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
114 |
115 | fabric.backward(loss)
116 |
117 | # TODO: Gradient clipping
118 | # if grad_clip != 0.0:
119 | # fabric.clip_gradients(model, optimizer, max_norm=grad_clip)
120 |
121 | optimizer.step()
122 | optimizer.zero_grad()
123 |
124 | dt = time.time() - t0
125 | if iter_num % log_interval == 0:
126 | fabric.print(f"iter {iter_num}: loss {loss.item():.4f}, time: {dt*1000:.2f}ms")
127 | iter_num += 1
128 |
129 | if iter_num > max_iters:
130 | break
131 |
132 |
133 | @torch.no_grad()
134 | def validate(fabric: L.Fabric, model: torch.nn.Module, val_data: np.ndarray) -> torch.Tensor:
135 | fabric.print("Validating ...")
136 | model.eval()
137 | losses = torch.zeros(eval_iters)
138 | for k in range(eval_iters):
139 | input_ids, targets = get_batch(
140 | fabric,
141 | val_data,
142 | block_size=model.config.block_size, # type: ignore[union-attr,arg-type]
143 | )
144 | logits = model(input_ids)
145 | loss = torch.nn.functional.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
146 | losses[k] = loss.item()
147 | out = losses.mean()
148 | model.train()
149 | return out
150 |
151 |
152 | def get_batch(fabric: L.Fabric, data: np.ndarray, block_size: int) -> Tuple[torch.Tensor, torch.Tensor]:
153 | ix = torch.randint(len(data) - block_size, (batch_size,))
154 | x = torch.stack([torch.from_numpy((data[i : i + block_size]).astype(np.int64)) for i in ix])
155 | y = torch.stack([torch.from_numpy((data[i + 1 : i + 1 + block_size]).astype(np.int64)) for i in ix])
156 | x, y = fabric.to_device((x.pin_memory(), y.pin_memory()))
157 | return x, y
158 |
159 |
160 | def load_datasets(data_dir: str = "data/shakespeare") -> Tuple[np.ndarray, np.ndarray]:
161 | train_data = np.memmap(os.path.join(data_dir, "train.bin"), dtype=np.uint16, mode="r")
162 | val_data = np.memmap(os.path.join(data_dir, "val.bin"), dtype=np.uint16, mode="r")
163 | return train_data, val_data
164 |
165 |
166 | if __name__ == "__main__":
167 | torch.set_float32_matmul_precision("high")
168 | main()
169 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["setuptools", "setuptools-scm"]
3 | build-backend = "setuptools.build_meta"
4 |
5 | [project]
6 | name = "lit-llama"
7 | version = "0.1.0"
8 | description = "Implementation of the LLaMA language model"
9 | license = {text = "Apache-2.0"}
10 | authors = [
11 | { name = "Lightning AI", email = "community@lightning.ai" }
12 | ]
13 | readme = "README.md"
14 | requires-python = ">=3.10"
15 | dependencies = [
16 | "torch>=2.0.0",
17 | "lightning @ git+https://github.com/Lightning-AI/lightning@master",
18 | "sentencepiece",
19 | "bitsandbytes",
20 | ]
21 | classifiers = [
22 | "Topic :: Text Processing"
23 | ]
24 |
25 | [project.optional-dependencies]
26 | all = [
27 | "tqdm", # convert_checkpoint.py
28 | "numpy <2.0", # train.py dataset memmap
29 | "jsonargparse[signatures]", # generate.py, convert_checkpoint.py CLI
30 | "datasets", # evaluate.py
31 | "zstandard", # prepare_redpajama.py"
32 | ]
33 |
34 | [tool.setuptools.packages.find]
35 | where = ["."] # list of folders that contain the packages (["."] by default)
36 | include = ["lit_llama"] # package names should match these glob patterns (["*"] by default)
37 | exclude = [] # exclude packages matching these glob patterns (empty by default)
38 | namespaces = false # to disable scanning PEP 420 namespaces (true by default)
39 |
40 |
41 | [tool.pytest.ini_options]
42 | addopts = [
43 | "--strict-markers",
44 | "--color=yes",
45 | "--disable-pytest-warnings",
46 | ]
47 |
--------------------------------------------------------------------------------
/quantize/gptq.py:
--------------------------------------------------------------------------------
1 | # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
2 |
3 | # This adapts GPTQ's quantization process: https://github.com/IST-DASLab/gptq/
4 | # E. Frantar et al GPTQ: Accurate Post-training Compression for GPT, arXiv:2210.17323
5 | # portions copyright by the authors licensed under the Apache License 2.0
6 | import gc
7 | import sys
8 | import time
9 | from pathlib import Path
10 | from typing import Optional
11 |
12 | import torch
13 | from datasets import load_dataset
14 |
15 | # support running without installing as a package
16 | wd = Path(__file__).parent.parent.resolve()
17 | sys.path.append(str(wd))
18 |
19 | from lit_llama import LLaMA, Tokenizer
20 | from lit_llama.quantization import GPTQQuantizer
21 | from lit_llama.utils import EmptyInitOnDevice, llama_model_lookup
22 |
23 |
24 | def get_sample_data():
25 | traindata = load_dataset(
26 | "allenai/c4",
27 | "allenai--c4",
28 | data_files={"train": "en/c4-train.00000-of-01024.json.gz"},
29 | split="train",
30 | )
31 | # heuristic for the data size?
32 | txt = "\n".join(
33 | traindata[i]["text"] for i in torch.randperm(len(traindata))[:1000].tolist()
34 | )
35 | return txt
36 |
37 |
38 | @torch.no_grad()
39 | def llama_blockwise_quantization(
40 | model, sample_inputs, working_device, *, bits=4, groupsize=-1
41 | ):
42 | """
43 | This is the classic post-training quantization of all linear layers.
44 | We quantize in order, i.e. when observing the inputs, we use the outputs of the previously quantized layers rather
45 | than doing them all at once.
46 | """
47 | print(model)
48 | print(model.config)
49 |
50 | print("Getting inputs for first block")
51 | model.transformer.wte.to(working_device)
52 | sample_inputs = sample_inputs.to(working_device)
53 | inps = model.transformer.wte(sample_inputs)
54 | model.transformer.wte.to("cpu")
55 | torch.cuda.empty_cache()
56 |
57 | rope_cache = model.build_rope_cache(sample_inputs)
58 | mask_cache = model.build_mask_cache(sample_inputs)
59 |
60 | print("Starting to quantize blocks")
61 | outs = torch.zeros_like(inps)
62 |
63 | # better than relying on enumeration? originally the code bundled
64 | # the two mlp fc layers
65 | # we could automate this with a lot of hooks and another iteration
66 | submodules_to_process = [
67 | "attn.c_attn",
68 | "attn.c_proj",
69 | "mlp.c_fc1",
70 | "mlp.c_fc2",
71 | "mlp.c_proj",
72 | ]
73 |
74 | for i, block in enumerate(model.transformer.h):
75 | block.to(working_device)
76 |
77 | for name in submodules_to_process:
78 | print(i, name, end=" ")
79 | t0 = time.perf_counter()
80 | print("collecting stats", end=" ")
81 | sys.stdout.flush()
82 | module = block.get_submodule(name)
83 |
84 | gptq = GPTQQuantizer(
85 | module,
86 | bits=bits,
87 | groupsize=groupsize,
88 | actorder=(groupsize == -1),
89 | )
90 | handle = module.register_forward_hook(gptq.collect_input_stats)
91 | for j in range(inps.size(0)):
92 | outs[j : j + 1], _ = block(
93 | inps[j : j + 1],
94 | rope=rope_cache,
95 | mask=mask_cache,
96 | max_seq_length=model.config.block_size
97 | )
98 |
99 | handle.remove()
100 |
101 | print("quantizing", end=" ")
102 | sys.stdout.flush()
103 | q_module, error = gptq.quantize()
104 |
105 | # replace the linear module with the quantized module
106 | pname, dname = name.rsplit(".", 1)
107 | setattr(block.get_submodule(pname), dname, q_module)
108 |
109 | # cleanup in an attempt to not run out of memory
110 | del gptq
111 | gc.collect()
112 | torch.cuda.empty_cache()
113 | t1 = time.perf_counter()
114 | print(f"time {int(t1 - t0 + 0.5)}s quantization error {error:.1f}")
115 |
116 | for j in range(inps.size(0)):
117 | outs[j : j + 1], _ = block(
118 | inps[j : j + 1],
119 | rope=rope_cache,
120 | mask=mask_cache,
121 | max_seq_length=model.config.block_size
122 | )
123 |
124 | block.cpu()
125 | gc.collect()
126 | torch.cuda.empty_cache()
127 |
128 | # the outputs are the next block's inputs and we'll reuse the old inputs
129 | inps, outs = outs, inps
130 |
131 | model.transformer.ln_f.to(working_device)
132 | for j in range(inps.size(0)):
133 | outs[j : j + 1] = model.transformer.ln_f(inps[j : j + 1])
134 | model.transformer.ln_f.to("cpu")
135 | inps, outs = outs, inps
136 |
137 | model.lm_head.to(working_device)
138 | gptq = GPTQQuantizer(
139 | model.lm_head,
140 | bits=bits,
141 | groupsize=groupsize,
142 | actorder=(groupsize == -1),
143 | )
144 | handle = model.lm_head.register_forward_hook(gptq.collect_input_stats)
145 | for j in range(inps.size(0)):
146 | model.lm_head(inps[j : j + 1])
147 | handle.remove()
148 | q_module, error = gptq.quantize()
149 | model.lm_head = q_module
150 | model.lm_head.to("cpu")
151 |
152 |
153 | def main(
154 | *,
155 | checkpoint_path: Path = Path("checkpoints/lit-llama/7B/lit-llama.pth"),
156 | output_path: Optional[Path] = None,
157 | tokenizer_path: Path = Path("checkpoints/lit-llama/tokenizer.model"),
158 | n_samples: int = 128,
159 | dtype: str = "float32",
160 | quantize: Optional[str] = None,
161 | ) -> None:
162 | """Generates text samples based on a pre-trained LLaMA model and tokenizer.
163 |
164 | Args:
165 | checkpoint_path: The checkpoint path to load.
166 | output_path: Path to write the quantized model's state dict to.
167 | tokenizer_path: The tokenizer path to load.
168 | n_samples: Number of example inputs to use for statistics (default: 128)
169 | dtype: The dtype to use to load the model.
170 | quantize: Mode to quantize the model to:
171 | ``"gptq.int4"``: GPTQ 4-bit mode.
172 | Note that ``"llm.int8"`` does not need a quantization step.
173 | """
174 | assert checkpoint_path.is_file()
175 | assert tokenizer_path.is_file()
176 | if output_path is None:
177 | output_path = checkpoint_path.parent / "llama-gptq.4bit.pth"
178 | assert output_path.parent.is_dir() and (not output_path.exists() or output_path.is_file())
179 |
180 | device = "cuda"
181 |
182 | dt = getattr(torch, dtype, None)
183 | if not isinstance(dt, torch.dtype):
184 | raise ValueError(f"{dtype} is not a valid dtype.")
185 | dtype = dt
186 |
187 | if quantize == "gptq.int4":
188 | bits = 4
189 | elif quantize == "gptq.int8":
190 | bits = 8
191 | else:
192 | raise RuntimeError(f"unknown/unsupported quantization mode {quantize}")
193 |
194 | # we avoid loading the entire model on the GPU and do this block by block
195 | with EmptyInitOnDevice(
196 | device="cpu",
197 | dtype=dtype,
198 | ):
199 | print("Loading model ...", file=sys.stderr)
200 | t0 = time.time()
201 | checkpoint = torch.load(checkpoint_path)
202 | name = llama_model_lookup(checkpoint)
203 | model = LLaMA.from_name(name)
204 | model.load_state_dict(checkpoint)
205 | print(f"Time to load model: {time.time() - t0:.02f} seconds.", file=sys.stderr)
206 |
207 | model.eval()
208 |
209 | tokenizer = Tokenizer(tokenizer_path)
210 |
211 | test_string = get_sample_data()
212 | encoded_text = tokenizer.encode(
213 | test_string,
214 | bos=True,
215 | eos=False,
216 | )
217 | block_size = 2048 # this is for compat with gptq, and indeed we get much worse beyond this (https://github.com/facebookresearch/llama/blob/57b0eb62de0636e75af471e49e2f1862d908d9d8/llama/model.py#L30)
218 | encoded_text = encoded_text[: n_samples * block_size].reshape(n_samples, block_size)
219 |
220 | t0 = time.perf_counter()
221 | llama_blockwise_quantization(model, encoded_text, device, bits=bits)
222 | t = time.perf_counter() - t0
223 |
224 | print(
225 | f"\n\nTime for quantization: {t:.02f} sec total",
226 | file=sys.stderr,
227 | )
228 | print(
229 | f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB",
230 | file=sys.stderr,
231 | )
232 |
233 | torch.save(model.state_dict(), output_path)
234 |
235 |
236 | if __name__ == "__main__":
237 | from jsonargparse import CLI
238 |
239 | torch.set_float32_matmul_precision("high")
240 | CLI(main)
241 |
--------------------------------------------------------------------------------
/scripts/convert_checkpoint.py:
--------------------------------------------------------------------------------
1 | # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
2 |
3 | import gc
4 | import shutil
5 | from pathlib import Path
6 | from typing import Dict
7 |
8 | import torch
9 | from tqdm import tqdm
10 |
11 | """
12 | Sample usage:
13 |
14 | ```bash
15 | python -m scripts.convert_checkpoint -h
16 |
17 | python -m scripts.convert_checkpoint converted
18 | ```
19 | """
20 |
21 |
22 | def convert_state_dict(state_dict: Dict[str, torch.Tensor], dtype: torch.dtype = torch.float32) -> Dict[str, torch.Tensor]:
23 | converted = {}
24 | converted["transformer.wte.weight"] = state_dict["tok_embeddings.weight"].to(dtype)
25 | converted["lm_head.weight"] = state_dict["output.weight"].to(dtype)
26 | converted["transformer.ln_f.scale"] = state_dict["norm.weight"].to(dtype)
27 |
28 | for layer_idx in sorted(set([k.split(".")[1] for k in state_dict if k.startswith("layers")])):
29 | # attention
30 | # the wq, wk, wv from the FB model are stacked in our model as c_attn
31 | converted[f"transformer.h.{layer_idx}.attn.c_attn.weight"] = torch.cat(
32 | (
33 | state_dict[f"layers.{layer_idx}.attention.wq.weight"].to(dtype),
34 | state_dict[f"layers.{layer_idx}.attention.wk.weight"].to(dtype),
35 | state_dict[f"layers.{layer_idx}.attention.wv.weight"].to(dtype),
36 | )
37 | )
38 | converted[f"transformer.h.{layer_idx}.attn.c_proj.weight"] = state_dict[
39 | f"layers.{layer_idx}.attention.wo.weight"
40 | ].to(dtype)
41 | # mlp
42 | converted[f"transformer.h.{layer_idx}.mlp.c_fc1.weight"] = state_dict[
43 | f"layers.{layer_idx}.feed_forward.w1.weight"
44 | ].to(dtype)
45 | converted[f"transformer.h.{layer_idx}.mlp.c_proj.weight"] = state_dict[
46 | f"layers.{layer_idx}.feed_forward.w2.weight"
47 | ].to(dtype)
48 | converted[f"transformer.h.{layer_idx}.mlp.c_fc2.weight"] = state_dict[
49 | f"layers.{layer_idx}.feed_forward.w3.weight"
50 | ].to(dtype)
51 | # rms norm
52 | converted[f"transformer.h.{layer_idx}.rms_1.scale"] = state_dict[f"layers.{layer_idx}.attention_norm.weight"].to(dtype)
53 | converted[f"transformer.h.{layer_idx}.rms_2.scale"] = state_dict[f"layers.{layer_idx}.ffn_norm.weight"].to(dtype)
54 | return converted
55 |
56 |
57 | shard_dims = {
58 | "lm_head.weight": 0,
59 | "wte.weight": 1,
60 | "attn.c_attn.weight": 0,
61 | "attn.c_proj.weight": 1,
62 | "mlp.c_fc1.weight": 0,
63 | "mlp.c_fc2.weight": 0,
64 | "mlp.c_proj.weight": 1
65 | }
66 |
67 |
68 | def meta_weights_for_nano_model(
69 | *,
70 | output_dir: Path = Path("checkpoints/lit-llama"),
71 | checkpoint_dir: Path = Path("checkpoints/llama/"),
72 | model_size: str = "7B",
73 | dtype: str = "float32",
74 | ) -> None:
75 | output_dir = output_dir / model_size
76 | checkpoint_dir = checkpoint_dir / model_size
77 | output_dir.mkdir(parents=True, exist_ok=True)
78 |
79 | # the tokenizer is the same for all model sizes, so we store it in the parent dir
80 | shutil.copy(checkpoint_dir.parent / "tokenizer.model", output_dir.parent)
81 |
82 | dt = getattr(torch, dtype, None)
83 | if not isinstance(dt, torch.dtype):
84 | raise ValueError(f"{dtype} is not a valid dtype.")
85 | dtype = dt
86 |
87 | checkpoint_files = sorted(checkpoint_dir.glob("*.pth"))
88 | checkpoint_files.sort()
89 | n_checkpoints = len(checkpoint_files)
90 |
91 | if n_checkpoints == 0:
92 | raise RuntimeError(f"No checkpoints were found at checkpoint_dir {checkpoint_dir}. `consolidated.0*.pth` files expected at that location.")
93 |
94 | # for the bigger models, there are multiple model-parallel checkpoints
95 | # and we combine them into one single file
96 | combined = None
97 | for file in tqdm(checkpoint_files, total=n_checkpoints):
98 | checkpoint = torch.load(file, map_location="cpu")
99 | converted = convert_state_dict(checkpoint, dtype=dtype)
100 | if combined is None:
101 | combined = converted
102 | continue
103 | for name, param in converted.items():
104 | dim = None
105 | for k, d in shard_dims.items():
106 | if k in name:
107 | dim = d
108 | break
109 | if dim is None:
110 | # Extra check: assert that tensors are the same if not sharded
111 | # assert torch.allclose(combined[name], param)
112 | continue
113 | combined[name] = torch.cat((combined[name], param), dim=dim)
114 |
115 | del checkpoint
116 | del converted
117 | gc.collect()
118 |
119 | for name, param in combined.items():
120 | if "c_attn" not in name:
121 | continue
122 |
123 | # Turn [Q1, K1, V1, Q2, K2, V2, ...] into [Q1, Q2, ..., K1, K2, .., V1, V2, ...]
124 |
125 | src_chunk_len = param.shape[0] // n_checkpoints
126 | mat_len = src_chunk_len // 3
127 | dst_chunk_len = mat_len * n_checkpoints
128 | attn = torch.clone(param)
129 | for i in range(n_checkpoints):
130 | for j in range(3):
131 | param[j * dst_chunk_len + i * mat_len: j * dst_chunk_len + (i+1) * mat_len] = \
132 | attn[i * src_chunk_len + j * mat_len: i * src_chunk_len + (j+1) * mat_len]
133 |
134 | del attn
135 | gc.collect()
136 |
137 | torch.save(combined, output_dir / "lit-llama.pth")
138 |
139 |
140 | if __name__ == "__main__":
141 | from jsonargparse import CLI
142 |
143 | CLI(meta_weights_for_nano_model)
144 |
--------------------------------------------------------------------------------
/scripts/convert_hf_checkpoint.py:
--------------------------------------------------------------------------------
1 | # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
2 |
3 | import collections
4 | import contextlib
5 | import gc
6 | import json
7 | import shutil
8 | import sys
9 | from pathlib import Path
10 |
11 | import torch
12 |
13 | # support running without installing as a package
14 | wd = Path(__file__).parent.parent.resolve()
15 | sys.path.append(str(wd))
16 |
17 | from lit_llama.model import LLaMA, LLaMAConfig
18 | from lit_llama.utils import EmptyInitOnDevice, lazy_load, incremental_save
19 |
20 |
21 | @torch.no_grad()
22 | def convert_hf_checkpoint(
23 | *,
24 | output_dir: Path = Path("checkpoints/lit-llama/7B"),
25 | checkpoint_dir: Path = Path("checkpoints/hf-llama/7B"),
26 | model_size: str = "7B",
27 | dtype: str = "float32",
28 | verify: bool = False,
29 | ) -> None:
30 | """
31 | Perform the reverse operation of: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/convert_llama_weights_to_hf.py
32 | """
33 | output_dir.mkdir(parents=True, exist_ok=True)
34 |
35 | # the tokenizer is the same for all model sizes, so we store it in the parent dir
36 | shutil.copy(checkpoint_dir / "tokenizer.model", output_dir.parent)
37 |
38 | dt = getattr(torch, dtype, None)
39 | if not isinstance(dt, torch.dtype):
40 | raise ValueError(f"{dtype} is not a valid dtype.")
41 | dtype = dt
42 |
43 | print("Initializing lit-llama")
44 | config = LLaMAConfig.from_name(model_size)
45 |
46 | with EmptyInitOnDevice(device="meta", dtype=dtype):
47 | model = LLaMA(config)
48 |
49 | qkv_size = model.transformer.h[0].attn.c_attn.weight.shape[0] // 3
50 |
51 | # initialize a new empty state dict to hold our new weights
52 | sd_meta = model.state_dict()
53 | sd = {}
54 |
55 | # Load the json file containing weight mapping
56 | pytorch_bin_map_json_path = checkpoint_dir / "pytorch_model.bin.index.json"
57 | with open(pytorch_bin_map_json_path) as json_map:
58 | bin_index = json.load(json_map)
59 | bin_files = set(checkpoint_dir / bin for bin in bin_index["weight_map"].values())
60 | if not bin_files:
61 | raise ValueError(f"Expected {str(checkpoint_dir)!r} to contain .bin files")
62 |
63 | def permute(w):
64 | dim = config.n_embd
65 | w = w._load_tensor().to(dtype)
66 | return (
67 | w.view(config.n_head, 2, dim // config.n_head // 2, dim)
68 | .transpose(1, 2)
69 | .reshape(dim, dim)
70 | )
71 |
72 | weight_map = {
73 | "self_attn.o_proj.weight": "attn.c_proj.weight",
74 | "self_attn.q_proj.weight": "attn.c_attn.weight",
75 | "self_attn.k_proj.weight": "attn.c_attn.weight",
76 | "self_attn.v_proj.weight": "attn.c_attn.weight",
77 | "mlp.gate_proj.weight": "mlp.c_fc1.weight",
78 | "mlp.up_proj.weight": "mlp.c_fc2.weight",
79 | "mlp.down_proj.weight": "mlp.c_proj.weight",
80 | "input_layernorm.weight": "rms_1.scale",
81 | "post_attention_layernorm.weight": "rms_2.scale",
82 | "model.embed_tokens.weight": "transformer.wte.weight",
83 | "model.norm.weight": "transformer.ln_f.scale",
84 | "lm_head.weight": "lm_head.weight",
85 | }
86 |
87 | print(f"Saving to disk at {output_dir}")
88 | unprocessed_weights = collections.defaultdict(dict)
89 |
90 | with incremental_save(output_dir / "lit-llama.pth") as saver:
91 | # for checkpoints that split the QKV across several files, we need to keep all the bin files
92 | # open, so we use `ExitStack` to close them all together at the end
93 | with contextlib.ExitStack() as stack:
94 | for bin_file in bin_files:
95 | print("Processing", bin_file)
96 | hf_weights = stack.enter_context(lazy_load(bin_file))
97 | for name, param in hf_weights.items():
98 | skip = False
99 | if "rotary_emb.inv_freq" in name:
100 | continue
101 | if "model.layers" in name:
102 | block_id = int(name.split(".")[2])
103 | from_name = ".".join(name.split(".")[3:])
104 | to_name = weight_map[from_name]
105 | sd_key = f"transformer.h.{block_id}.{to_name}"
106 |
107 | if "q_proj" in name:
108 | unprocessed_weights[sd_key]["q_proj"] = param
109 | skip = True
110 | elif "k_proj" in name:
111 | unprocessed_weights[sd_key]["k_proj"] = param
112 | skip = True
113 | elif "v_proj" in name:
114 | unprocessed_weights[sd_key]["v_proj"] = param
115 | skip = True
116 | if skip and len(unprocessed_weights[sd_key]) == 3:
117 | w = torch.empty(
118 | sd_meta[sd_key].shape, dtype=sd_meta[sd_key].dtype
119 | )
120 | w[:qkv_size] = permute(unprocessed_weights[sd_key]["q_proj"])
121 | w[qkv_size:-qkv_size] = permute(
122 | unprocessed_weights[sd_key]["k_proj"]
123 | )
124 | w[-qkv_size:] = (
125 | unprocessed_weights[sd_key]["v_proj"]
126 | ._load_tensor()
127 | .to(dtype)
128 | )
129 | sd[sd_key] = w
130 | del unprocessed_weights[sd_key]
131 | skip = False
132 | else:
133 | sd[sd_key] = param._load_tensor().to(dtype)
134 | else:
135 | sd_key = weight_map[name]
136 | sd[sd_key] = param._load_tensor().to(dtype)
137 | if not skip:
138 | sd[sd_key] = saver.store_early(sd[sd_key])
139 | gc.collect()
140 | saver.save(sd)
141 |
142 | assert len(unprocessed_weights) == 0, f"unexpected partial weights {list(unprocessed_weights)}"
143 | if verify:
144 | try:
145 | from transformers import LlamaForCausalLM
146 | except ImportError:
147 | raise ImportError("verify=True requires transformers to be installed, please `pip install transformers`")
148 | print("Verifying...")
149 |
150 | token_sample = torch.randint(0, config.vocab_size, size=(1, config.block_size), dtype=torch.int64)
151 | out = model(token_sample)
152 | del model
153 | gc.collect()
154 |
155 | print("Loading original model for comparison")
156 | model_hf = LlamaForCausalLM.from_pretrained(checkpoint_dir)
157 | out_hf = model_hf(token_sample)["logits"]
158 |
159 | print("Comparing outputs")
160 | assert out.device.type == out_hf.device.type
161 | assert out.dtype == out_hf.dtype
162 | torch.testing.assert_close(out, out_hf)  # raises AssertionError if the outputs do not match
163 |
164 |
165 | if __name__ == "__main__":
166 | from jsonargparse import CLI
167 |
168 | CLI(convert_hf_checkpoint)
169 |
170 |
--------------------------------------------------------------------------------
/scripts/convert_lora_weights.py:
--------------------------------------------------------------------------------
1 | # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
2 |
3 | import sys
4 | import time
5 | from pathlib import Path
6 | from typing import Optional
7 |
8 | import lightning as L
9 | import torch
10 | import torch.nn as nn
11 |
12 | # support running without installing as a package
13 | wd = Path(__file__).parent.parent.resolve()
14 | sys.path.append(str(wd))
15 |
16 | from lit_llama import LLaMA
17 | from lit_llama.utils import EmptyInitOnDevice, lazy_load, llama_model_lookup
18 | from lit_llama.lora import lora
19 |
20 | def del_lora_state_dict(model: nn.Module):
21 | base_model_dict = model.state_dict()
22 | key_to_delete = [k for k in base_model_dict if "lora_" in k]
23 | for del_key in key_to_delete:
24 | del base_model_dict[del_key]
25 | return base_model_dict
26 |
27 |
28 | def lora_model_lookup(checkpoint: dict) -> int:
29 | """Returns the LoRA rank from the adapter checkpoint.
30 |
31 | """
32 | return checkpoint["transformer.h.0.attn.c_attn.lora_B"].shape[1]
33 |
34 |
35 | def main(
36 | accelerator: str = "auto",
37 | lora_path: Optional[Path] = None,
38 | checkpoint_path: Optional[Path] = None,
39 | dtype: str = "bfloat16",
40 | ) -> None:
41 | """Merges lora weights to base model.
42 |
43 | Args:
44 | accelerator: The hardware to run on. Possible choices are:
45 | ``"cpu"``, ``"cuda"``, ``"mps"``, ``"gpu"``, ``"tpu"``, ``"auto"``.
46 | lora_path: Path to the checkpoint with trained LoRA weights, which are the output of
47 | `finetune_lora.py`.
48 | checkpoint_path: The checkpoint path to load.
49 | dtype: `torch.dtype` to work with
50 | """
51 | if not lora_path:
52 | lora_path = Path("out/lora/alpaca/lit-llama-lora-finetuned.pth")
53 | if not checkpoint_path:
54 | checkpoint_path = Path("./checkpoints/lit-llama/7B/lit-llama.pth")
55 |
56 | assert lora_path.is_file()
57 | assert checkpoint_path.is_file()
58 |
59 | fabric = L.Fabric(accelerator=accelerator, devices=1)
60 |
61 | dt = getattr(torch, dtype, None)
62 | if not isinstance(dt, torch.dtype):
63 | raise ValueError(f"{dtype} is not a valid dtype.")
64 | dtype = dt
65 |
66 | print("Loading model ...", file=sys.stderr)
67 | t0 = time.time()
68 |
69 | with (lazy_load(checkpoint_path) as pretrained_checkpoint,
70 | lazy_load(lora_path) as lora_checkpoint):
71 | name = llama_model_lookup(pretrained_checkpoint)
72 | rank = lora_model_lookup(lora_checkpoint)
73 |
74 | with EmptyInitOnDevice(
75 | device=fabric.device, dtype=dtype
76 | ), lora(r=rank, alpha=16, dropout=0.05, enabled=True):
77 | model = LLaMA.from_name(name)
78 |
79 | # 1. Load the pretrained weights
80 | model.load_state_dict(pretrained_checkpoint, strict=False)
81 | # 2. Load the fine-tuned lora weights
82 | model.load_state_dict(lora_checkpoint, strict=False)
83 |
84 | print(f"Time to load model: {time.time() - t0:.02f} seconds.", file=sys.stderr)
85 |
86 | model.eval()
87 | base_model_dict = del_lora_state_dict(model)
88 | save_path = lora_path.with_stem(f"{lora_path.stem}-lora-merged-weights")
89 | print("Saving LoRA to base model weights ...")
90 | torch.save(base_model_dict, save_path)
91 | print(f"Model saved at {save_path}")
92 |
93 |
94 | if __name__ == "__main__":
95 | from jsonargparse import CLI
96 |
97 | CLI(main)
--------------------------------------------------------------------------------
/scripts/download.py:
--------------------------------------------------------------------------------
1 | # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
2 |
3 | import os
4 | from typing import Optional
5 | from urllib.request import urlretrieve
6 |
7 | files = {
8 | "original_model.py": "https://gist.githubusercontent.com/lantiga/fd36849fb1c498da949a0af635318a7b/raw/7dd20f51c2a1ff2886387f0e25c1750a485a08e1/llama_model.py",
9 | "original_adapter.py": "https://gist.githubusercontent.com/awaelchli/546f33fcdb84cc9f1b661ca1ca18418d/raw/e81d8f35fb1fec53af1099349b0c455fc8c9fb01/original_adapter.py",
10 | }
11 |
12 |
13 | def download_original(wd: str) -> None:
14 | for file, url in files.items():
15 | filepath = os.path.join(wd, file)
16 | if not os.path.isfile(filepath):
17 | print(f"Downloading original implementation to {filepath!r}")
18 | urlretrieve(url=url, filename=filepath)
19 | print("Done")
20 | else:
21 | print("Original implementation found. Skipping download.")
22 |
23 |
24 | def download_from_hub(repo_id: Optional[str] = None, local_dir: str = "checkpoints/hf-llama/7B") -> None:
25 | if repo_id is None:
26 | raise ValueError("Please pass `--repo_id=...`. You can try googling 'huggingface hub llama' for options.")
27 |
28 | from huggingface_hub import snapshot_download
29 |
30 | snapshot_download(repo_id, local_dir=local_dir, local_dir_use_symlinks=False)
31 |
32 |
33 | if __name__ == "__main__":
34 | from jsonargparse import CLI
35 |
36 | CLI(download_from_hub)
37 |
--------------------------------------------------------------------------------
/scripts/prepare_alpaca.py:
--------------------------------------------------------------------------------
1 | # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
2 |
3 | """Implementation derived from https://github.com/tloen/alpaca-lora"""
4 | import sys
5 | from pathlib import Path
6 |
7 | # support running without installing as a package
8 | wd = Path(__file__).parent.parent.resolve()
9 | sys.path.append(str(wd))
10 |
11 | import torch
12 | import requests
13 | import json
14 | from torch.utils.data import random_split
15 | from lit_llama.tokenizer import Tokenizer
16 | from tqdm import tqdm
17 |
18 |
19 | DATA_FILE = "https://raw.githubusercontent.com/tloen/alpaca-lora/main/alpaca_data_cleaned_archive.json"
20 | DATA_FILE_NAME = "alpaca_data_cleaned_archive.json"
21 | IGNORE_INDEX = -1
22 |
23 |
24 | def prepare(
25 | destination_path: Path = Path("data/alpaca"),
26 | tokenizer_path: Path = Path("checkpoints/lit-llama/tokenizer.model"),
27 | test_split_size: int = 2000,
28 | max_seq_length: int = 256,
29 | seed: int = 42,
30 | mask_inputs: bool = False, # as in alpaca-lora
31 | data_file_name: str = DATA_FILE_NAME
32 | ) -> None:
33 | """Prepare the Alpaca dataset for instruction tuning.
34 |
35 | The output is a training and test dataset saved as `train.pt` and `test.pt`,
36 | which stores the preprocessed and tokenized prompts and labels.
37 | """
38 |
39 | destination_path.mkdir(parents=True, exist_ok=True)
40 | file_path = destination_path / data_file_name
41 | download(file_path)
42 |
43 | # TODO: If we don't have the Meta weights, where do we get the tokenizer from?
44 | tokenizer = Tokenizer(tokenizer_path)
45 |
46 | with open(file_path, "r") as file:
47 | data = json.load(file)
48 |
49 | # Partition the dataset into train and test
50 | train_split_size = len(data) - test_split_size
51 | train_set, test_set = random_split(
52 | data,
53 | lengths=(train_split_size, test_split_size),
54 | generator=torch.Generator().manual_seed(seed),
55 | )
56 | train_set, test_set = list(train_set), list(test_set)
57 |
58 | print(f"train has {len(train_set):,} samples")
59 | print(f"val has {len(test_set):,} samples")
60 |
61 | print("Processing train split ...")
62 | train_set = [prepare_sample(sample, tokenizer, max_seq_length, mask_inputs) for sample in tqdm(train_set)]
63 | torch.save(train_set, file_path.parent / "train.pt")
64 |
65 | print("Processing test split ...")
66 | test_set = [prepare_sample(sample, tokenizer, max_seq_length, mask_inputs) for sample in tqdm(test_set)]
67 | torch.save(test_set, file_path.parent / "test.pt")
68 |
69 |
70 | def download(file_path: Path):
71 | """Downloads the raw json data file and saves it in the given destination."""
72 | if file_path.exists():
73 | return
74 | with open(file_path, "w") as f:
75 | f.write(requests.get(DATA_FILE).text)
76 |
77 |
78 | def prepare_sample(example: dict, tokenizer: Tokenizer, max_length: int, mask_inputs: bool = True):
79 | """Processes a single sample.
80 |
81 | Each sample in the dataset consists of:
82 | - instruction: A string describing the task
83 | - input: A string holding a special input value for the instruction.
84 | This only applies to some samples, and in others this is empty.
85 | - output: The response string
86 |
87 | This function processes this data to produce a prompt text and a label for
88 | supervised training. The input text is formed as a single message including all
89 | the instruction, the input (optional) and the response.
90 | The label/target is the same message but can optionally have the instruction + input text
91 | masked out (mask_inputs=True).
92 |
93 | Finally, both the prompt and the label get tokenized. If desired, all tokens
94 | in the label that correspond to the original input prompt get masked out (default).
95 | """
96 | full_prompt = generate_prompt(example)
97 | full_prompt_and_response = full_prompt + example["output"]
98 | encoded_full_prompt = tokenize(tokenizer, full_prompt, max_length=max_length, eos=False)
99 | encoded_full_prompt_and_response = tokenize(tokenizer, full_prompt_and_response, eos=True, max_length=max_length)
100 |
101 | # The labels are the full prompt with response, but with the prompt masked out
102 | labels = encoded_full_prompt_and_response.clone()
103 | if mask_inputs:
104 | labels[:len(encoded_full_prompt)] = IGNORE_INDEX
105 |
106 | return {**example, "input_ids": encoded_full_prompt_and_response, "input_ids_no_response": encoded_full_prompt, "labels": labels}
107 |
108 |
109 | def tokenize(tokenizer: Tokenizer, string: str, max_length: int, eos=True) -> torch.Tensor:
110 | return tokenizer.encode(string, bos=True, eos=eos, max_length=max_length)
111 |
112 |
113 | def generate_prompt(example):
114 | """Generates a standardized message to prompt the model with an instruction, optional input and a
115 | 'response' field."""
116 |
117 | if example["input"]:
118 | return (
119 | "Below is an instruction that describes a task, paired with an input that provides further context. "
120 | "Write a response that appropriately completes the request.\n\n"
121 | f"### Instruction:\n{example['instruction']}\n\n### Input:\n{example['input']}\n\n### Response:"
122 | )
123 | return (
124 | "Below is an instruction that describes a task. "
125 | "Write a response that appropriately completes the request.\n\n"
126 | f"### Instruction:\n{example['instruction']}\n\n### Response:"
127 | )
128 |
129 |
130 | if __name__ == "__main__":
131 | from jsonargparse import CLI
132 |
133 | CLI(prepare)
134 |
--------------------------------------------------------------------------------
/scripts/prepare_any_text.py:
--------------------------------------------------------------------------------
1 | # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
2 |
3 | """Implementation derived from https://github.com/tloen/alpaca-lora"""
4 | import sys
5 | from pathlib import Path
6 |
7 | # support running without installing as a package
8 | wd = Path(__file__).parent.parent.resolve()
9 | sys.path.append(str(wd))
10 |
11 | import torch
12 | import requests
13 | import json
14 | from torch.utils.data import random_split
15 | from lit_llama.tokenizer import Tokenizer
16 | from tqdm import tqdm
17 |
18 |
19 | IGNORE_INDEX = -1
20 |
21 | DATA_FILE_NAME = "input.txt"
22 |
23 |
24 | def prepare(
25 | destination_path: Path = Path("data/any"),
26 | tokenizer_path: Path = Path("checkpoints/lit-llama/tokenizer.model"),
27 | test_split_ratio: float = 0.9, # default 90% train, 10% validation
28 | max_seq_length: int = 256,
29 | seed: int = 42,
30 | data_file_name: str = DATA_FILE_NAME,
31 | ) -> None:
32 | """Prepare any dataset for finetuning (akin to Shakespheare full tuning).
33 |
34 | The output is a training and test dataset saved as `train.pt` and `test.pt`,
35 | which stores the preprocessed and tokenized prompts and labels.
36 | """
37 |
38 | destination_path.mkdir(parents=True, exist_ok=True)
39 | file_path = destination_path / data_file_name
40 | if not file_path.exists():
41 | raise AssertionError(f"{data_file_name} must be provided by the user and placed in {destination_path}")
42 |
43 | # TODO: If we don't have the Meta weights, where do we get the tokenizer from?
44 | tokenizer = Tokenizer(tokenizer_path)
45 |
46 | data = []
47 |
48 | with open(file_path, "r") as input_file:
49 | for line in input_file.readlines():
50 | data.append(line)
51 |
52 | # Partition the dataset into train and test
53 | train_split_size = int(len(data) * test_split_ratio)
54 | test_split_size = len(data) - train_split_size
55 | train_set, test_set = random_split(
56 | data,
57 | lengths=(train_split_size, test_split_size),
58 | generator=torch.Generator().manual_seed(seed),
59 | )
60 | train_set, test_set = list(train_set), list(test_set)
61 |
62 | print(f"train has {len(train_set):,} samples")
63 | print(f"val has {len(test_set):,} samples")
64 |
65 | print("Processing train split ...")
66 | train_set = [
67 | prepare_line(line, tokenizer, max_seq_length) for line in tqdm(train_set)
68 | ]
69 | torch.save(train_set, file_path.parent / "train.pt")
70 |
71 | print("Processing test split ...")
72 | test_set = [
73 | prepare_line(line, tokenizer, max_seq_length) for line in tqdm(test_set)
74 | ]
75 | torch.save(test_set, file_path.parent / "test.pt")
76 |
77 |
78 | def prepare_line(line: str, tokenizer: Tokenizer, max_length: int):
79 | """Processes a single sample.
80 |
81 | This function processes the line to produce the tokenized version of it.
82 | """
83 | encoded_full_prompt = tokenize(tokenizer, line, max_length=max_length, eos=False)
84 | return {
85 | "input_ids": encoded_full_prompt,
86 | "labels": encoded_full_prompt,
87 | }
88 |
89 |
90 | def tokenize(
91 | tokenizer: Tokenizer, string: str, max_length: int, eos=True
92 | ) -> torch.Tensor:
93 | return tokenizer.encode(string, bos=True, eos=eos, max_length=max_length)
94 |
95 |
96 | if __name__ == "__main__":
97 | from jsonargparse import CLI
98 |
99 | CLI(prepare)
100 |
--------------------------------------------------------------------------------
/scripts/prepare_dolly.py:
--------------------------------------------------------------------------------
1 | # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
2 |
3 | """Implementation derived from https://github.com/tloen/alpaca-lora"""
4 | import sys
5 | from pathlib import Path
6 |
7 | # support running without installing as a package
8 | wd = Path(__file__).parent.parent.resolve()
9 | sys.path.append(str(wd))
10 |
11 | import torch
12 | import requests
13 | import json
14 | from torch.utils.data import random_split
15 | from lit_llama.tokenizer import Tokenizer
16 | from tqdm import tqdm
17 |
18 |
19 | DATA_FILE = "https://huggingface.co/datasets/databricks/databricks-dolly-15k/resolve/main/databricks-dolly-15k.jsonl"
20 | DATA_FILE_NAME = "dolly_data_cleaned.json"
21 | IGNORE_INDEX = -1
22 |
23 |
24 | def prepare(
25 | destination_path: Path = Path("data/dolly"),
26 | tokenizer_path: Path = Path("checkpoints/lit-llama/tokenizer.model"),
27 | test_split_size: int = 2000,
28 | max_seq_length: int = 1024,
29 | seed: int = 42,
30 | mask_inputs: bool = False, # as in alpaca-lora
31 | ) -> None:
32 | """Prepare the Dolly dataset for instruction tuning.
33 |
34 | The output is a training and test dataset saved as `train.pt` and `test.pt`,
35 | which stores the preprocessed and tokenized prompts and labels.
36 | """
37 |
38 | destination_path.mkdir(parents=True, exist_ok=True)
39 | file_path = destination_path / DATA_FILE_NAME
40 | download(file_path)
41 |
42 | # TODO: If we don't have the Meta weights, where do we get the tokenizer from?
43 | tokenizer = Tokenizer(tokenizer_path)
44 |
45 | with open(file_path, "r") as file:
46 | data = file.readlines()
47 | data = [json.loads(line) for line in data]
48 | for item in data:
49 | item["input"] = item.pop("context")
50 | item["output"] = item.pop("response")
51 |
52 | # Partition the dataset into train and test
53 | train_split_size = len(data) - test_split_size
54 | train_set, test_set = random_split(
55 | data,
56 | lengths=(train_split_size, test_split_size),
57 | generator=torch.Generator().manual_seed(seed),
58 | )
59 | train_set, test_set = list(train_set), list(test_set)
60 |
61 | print(f"train has {len(train_set):,} samples")
62 | print(f"val has {len(test_set):,} samples")
63 |
64 | print("Processing train split ...")
65 | train_set = [prepare_sample(sample, tokenizer, max_seq_length, mask_inputs) for sample in tqdm(train_set)]
66 | torch.save(train_set, file_path.parent / "train.pt")
67 |
68 | print("Processing test split ...")
69 | test_set = [prepare_sample(sample, tokenizer, max_seq_length, mask_inputs) for sample in tqdm(test_set)]
70 | torch.save(test_set, file_path.parent / "test.pt")
71 |
72 |
73 | def download(file_path: Path):
74 | """Downloads the raw json data file and saves it in the given destination."""
75 | if file_path.exists():
76 | return
77 | with open(file_path, "w") as f:
78 | f.write(requests.get(DATA_FILE).text)
79 |
80 |
81 | def prepare_sample(example: dict, tokenizer: Tokenizer, max_length: int, mask_inputs: bool = True):
82 | """Processes a single sample.
83 |
84 | Each sample in the dataset consists of:
85 | - instruction: A string describing the task
86 | - input: A string holding a special input value for the instruction.
87 | This only applies to some samples, and in others this is empty.
88 | - output: The response string
89 |
90 | This function processes this data to produce a prompt text and a label for
91 | supervised training. The prompt text is formed as a single message including both
92 | the instruction and the input. The label/target is the same message but with the
93 | response attached.
94 |
95 |     Finally, both the prompt and the label get tokenized. If `mask_inputs` is enabled, all
96 |     tokens in the label that correspond to the original input prompt are masked out with `IGNORE_INDEX`.
97 | """
98 | full_prompt = generate_prompt(example)
99 | full_prompt_and_response = full_prompt + example["output"]
100 | encoded_full_prompt = tokenize(tokenizer, full_prompt, max_length=max_length, eos=False)
101 | encoded_full_prompt_and_response = tokenize(tokenizer, full_prompt_and_response, eos=True, max_length=max_length)
102 |
103 | # The labels are the full prompt with response, but with the prompt masked out
104 | labels = encoded_full_prompt_and_response.clone()
105 | if mask_inputs:
106 | labels[:len(encoded_full_prompt)] = IGNORE_INDEX
107 |
108 | return {**example, "input_ids": encoded_full_prompt_and_response, "input_ids_no_response": encoded_full_prompt, "labels": labels}
109 |
110 |
111 | def tokenize(tokenizer: Tokenizer, string: str, max_length: int, eos=True) -> torch.Tensor:
112 | return tokenizer.encode(string, bos=True, eos=eos, max_length=max_length)
113 |
114 |
115 | def generate_prompt(example):
116 | """Generates a standardized message to prompt the model with an instruction, optional input and a
117 | 'response' field."""
118 |
119 | if example["input"]:
120 | return (
121 | f"Below is an instruction that describes a task, paired with an input that provides further context. "
122 | "Write a response that appropriately completes the request.\n\n"
123 | f"### Instruction:\n{example['instruction']}\n\n### Input:\n{example['input']}\n\n### Response:"
124 | )
125 | return (
126 | f"Below is an instruction that describes a task. "
127 | "Write a response that appropriately completes the request.\n\n"
128 | f"### Instruction:\n{example['instruction']}\n\n### Response:"
129 | )
130 |
131 |
132 | if __name__ == "__main__":
133 | from jsonargparse import CLI
134 |
135 | CLI(prepare)
136 |
--------------------------------------------------------------------------------
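To see the exact prompt template used above, `generate_prompt` can be called directly on a toy record. A minimal sketch, assuming the repository dependencies are installed and the snippet is run from the repository root; the example record is made up:

# Hedged sketch: render the Dolly prompt template defined in the script above.
import sys

sys.path.append("scripts")  # run from the repository root
from prepare_dolly import generate_prompt

example = {
    "instruction": "Name three primary colors.",
    "input": "",  # empty input, so the shorter template is used
    "output": "Red, yellow, and blue.",
}
print(generate_prompt(example))
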
/scripts/prepare_redpajama.py:
--------------------------------------------------------------------------------
1 | # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
2 |
3 | import json
4 | import glob
5 | import os
6 | from pathlib import Path
7 | import sys
8 |
9 | # support running without installing as a package
10 | wd = Path(__file__).parent.parent.resolve()
11 | sys.path.append(str(wd))
12 |
13 | import numpy as np
14 | from tqdm import tqdm
15 |
16 | from lit_llama import Tokenizer
17 | import lit_llama.packed_dataset as packed_dataset
18 |
19 |
20 | filenames_sample = [
21 | "arxiv_sample.jsonl",
22 | "book_sample.jsonl",
23 | "c4_sample.jsonl",
24 | "cc_2019-30_sample.jsonl",
25 | "cc_2020-05_sample.jsonl",
26 | "cc_2021-04_sample.jsonl",
27 | "cc_2022-05_sample.jsonl",
28 | "cc_2023-06_sample.jsonl",
29 | "github_sample.jsonl",
30 | "stackexchange_sample.jsonl",
31 | "wikipedia_sample.jsonl",
32 | ]
33 |
34 | filename_sets = {
35 | "arxiv": "arxiv/arxiv*",
36 | "book": "book/book*",
37 | "c4": "c4/c4-train*",
38 | "common_crawl": "common_crawl/*",
39 | "github": "github/filtered*",
40 | "stackexchange": "stackexchange/stackexchange*",
41 | "wikipedia": "wikipedia/wiki*",
42 | }
43 |
44 |
45 | def prepare_sample(
46 | source_path: Path,
47 | tokenizer_path: Path,
48 | destination_path: Path,
49 | chunk_size: int,
50 |     match: str = ""
51 | ) -> None:
52 |     """Prepare the "Red Pajama" dataset. We assume the tokenizer has been trained (i.e. we reuse LLaMA's tokenizer model)."""
53 | destination_path.mkdir(parents=True, exist_ok=True)
54 |
55 | tokenizer = Tokenizer(tokenizer_path)
56 |
57 | for name in filenames_sample:
58 | if match and match not in name:
59 | continue
60 |
61 | filepath = source_path / name
62 |
63 | if not filepath.is_file():
64 | raise RuntimeError(
65 | f"Input file not found at {filepath}. \n"
66 | "Make sure you download the data, e.g. wget -i https://data.together.xyz/redpajama-data-1T/v1.0.0/urls.txt or through \n"
67 | "https://huggingface.co/datasets/togethercomputer/RedPajama-Data-1T \n"
68 | "https://huggingface.co/datasets/togethercomputer/RedPajama-Data-1T-Sample \n"
69 | )
70 |
71 | prefix, _ = os.path.splitext(name)
72 |
73 | builder = packed_dataset.PackedDatasetBuilder(
74 | outdir=destination_path,
75 | prefix=prefix,
76 | chunk_size=chunk_size,
77 | sep_token=tokenizer.bos_id,
78 | dtype="auto",
79 | vocab_size=tokenizer.vocab_size,
80 | )
81 |
82 | print(f"Processing {name}")
83 |
84 | with open(filepath, encoding="utf-8") as f:
85 | for row in tqdm(f):
86 | text = json.loads(row)["text"]
87 | text_ids = tokenizer.encode(text)
88 | builder.add_array(np.array(text_ids, dtype=builder.dtype))
89 |
90 | builder.write_reminder()
91 |
92 |
93 | def prepare_full(
94 | source_path: Path,
95 | tokenizer_path: Path,
96 | destination_path: Path,
97 | chunk_size: int,
98 | match: str = ""
99 | ) -> None:
100 |     """Prepare the "Red Pajama" dataset. We assume the tokenizer has been trained (i.e. we reuse LLaMA's tokenizer model)."""
101 | import zstandard as zstd
102 |
103 | destination_path.mkdir(parents=True, exist_ok=True)
104 |
105 | tokenizer = Tokenizer(tokenizer_path)
106 |
107 | for set_name, pattern in filename_sets.items():
108 | if match and match not in set_name:
109 | continue
110 |
111 | is_cc = set_name == "common_crawl"
112 |
113 | filenames = glob.glob(os.path.join(source_path, pattern), recursive=True)
114 |
115 | if not filenames:
116 | raise RuntimeError(
117 | f"No files matching {pattern} found at {source_path}. \n"
118 | "Make sure you download the data, e.g. wget -i https://data.together.xyz/redpajama-data-1T/v1.0.0/urls.txt or through \n"
119 | "https://huggingface.co/datasets/togethercomputer/RedPajama-Data-1T \n"
120 | "https://huggingface.co/datasets/togethercomputer/RedPajama-Data-1T-Sample \n"
121 | )
122 |
123 | builder = packed_dataset.PackedDatasetBuilder(
124 | outdir=destination_path,
125 | prefix=set_name,
126 | chunk_size=chunk_size,
127 | sep_token=tokenizer.bos_id,
128 | dtype="auto",
129 | vocab_size=tokenizer.vocab_size,
130 | )
131 |
132 | for name in filenames:
133 |             filepath = Path(name)  # glob already joined source_path into the matched name
134 |
135 | print(f"Processing {name}")
136 |
137 | if is_cc:
138 | with zstd.open(open(filepath, "rb"), "rt", encoding="utf-8") as f:
139 | for row in tqdm(f):
140 | text = json.loads(row)["text"]
141 | text_ids = tokenizer.encode(text)
142 | builder.add_array(np.array(text_ids, dtype=builder.dtype))
143 | else:
144 | with open(filepath, encoding="utf-8") as f:
145 | for row in tqdm(f):
146 | text = json.loads(row)["text"]
147 | text_ids = tokenizer.encode(text)
148 | builder.add_array(np.array(text_ids, dtype=builder.dtype))
149 |
150 | builder.write_reminder()
151 |
152 |
153 | def prepare(
154 | source_path: Path = Path("data/RedPajama-Data-1T-Sample"),
155 | tokenizer_path: Path = Path("checkpoints/lit-llama/tokenizer.model"),
156 | destination_path: Path = Path("data/red_pajama_sample"),
157 |     chunk_size: int = 2049 * 1024,  # 2048 block size + 1 for causal (from LLaMA), 1024 blocks
158 | sample: bool = False,
159 | match: str = "",
160 | ) -> None:
161 |     """Prepare the "Red Pajama" dataset. We assume the tokenizer has been trained (i.e. we reuse LLaMA's tokenizer model)."""
162 | if sample:
163 | prepare_sample(
164 | source_path=source_path,
165 | tokenizer_path=tokenizer_path,
166 | destination_path=destination_path,
167 | chunk_size=chunk_size,
168 | match=match,
169 | )
170 | else:
171 | prepare_full(
172 | source_path=source_path,
173 | tokenizer_path=tokenizer_path,
174 | destination_path=destination_path,
175 | chunk_size=chunk_size,
176 | match=match,
177 | )
178 |
179 |
180 | if __name__ == "__main__":
181 | from jsonargparse import CLI
182 |
183 | CLI(prepare)
184 |
--------------------------------------------------------------------------------
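The default `chunk_size` of `2049 * 1024` packs 1024 blocks of 2049 tokens into each chunk file: a 2048-token block plus one extra token so inputs and targets can be shifted by one position. A minimal sketch of consuming the resulting `.bin` chunks, loosely modelled on `pretrain/redpajama.py`; the glob path and loader settings are assumptions:

# Hedged sketch: stream the packed chunks produced by the script above.
import glob

from torch.utils.data import DataLoader

from lit_llama.packed_dataset import PackedDataset

block_size = 2048 + 1  # one extra token so x and y can be shifted against each other
filenames = sorted(glob.glob("data/red_pajama_sample/*.bin"))  # assumed destination_path
dataset = PackedDataset(filenames=filenames, n_chunks=1, block_size=block_size, shuffle=True, seed=1234)
loader = DataLoader(dataset, batch_size=2)

batch = next(iter(loader))
x, y = batch[:, :-1], batch[:, 1:]  # causal shift: predict the next token
print(x.shape, y.shape)             # torch.Size([2, 2048]) torch.Size([2, 2048])
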
/scripts/prepare_shakespeare.py:
--------------------------------------------------------------------------------
1 | # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
2 |
3 | # MIT License
4 |
5 | # Copyright (c) 2022 Andrej Karpathy
6 |
7 | # Permission is hereby granted, free of charge, to any person obtaining a copy
8 | # of this software and associated documentation files (the "Software"), to deal
9 | # in the Software without restriction, including without limitation the rights
10 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11 | # copies of the Software, and to permit persons to whom the Software is
12 | # furnished to do so, subject to the following conditions:
13 |
14 | # The above copyright notice and this permission notice shall be included in all
15 | # copies or substantial portions of the Software.
16 |
17 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23 | # SOFTWARE.
24 | import sys
25 | from pathlib import Path
26 |
27 | # support running without installing as a package
28 | wd = Path(__file__).parent.parent.resolve()
29 | sys.path.append(str(wd))
30 |
31 | import numpy as np
32 | import requests
33 |
34 |
35 | def prepare(destination_path: Path = Path("data/shakespeare")) -> None:
36 | """Prepare the "Tiny Shakespeare" dataset."""
37 | destination_path.mkdir(parents=True, exist_ok=True)
38 |
39 | # download the tiny shakespeare dataset
40 | input_file_path = destination_path / "input.txt"
41 | if not input_file_path.exists():
42 | data_url = "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt"
43 | with open(input_file_path, "w") as f:
44 | f.write(requests.get(data_url).text)
45 |
46 | with open(input_file_path) as f:
47 | data = f.read()
48 | n = len(data)
49 | train_data = data[: int(n * 0.9)]
50 | val_data = data[int(n * 0.9) :]
51 |
52 | from lit_llama import Tokenizer
53 |
54 | Tokenizer.train(input=input_file_path, destination=destination_path, vocab_size=100)
55 | tokenizer = Tokenizer(destination_path / "tokenizer.model")
56 | train_ids = tokenizer.encode(train_data)
57 | val_ids = tokenizer.encode(val_data)
58 | print(f"train has {len(train_ids):,} tokens")
59 | print(f"val has {len(val_ids):,} tokens")
60 |
61 | # export to bin files
62 | train_ids = np.array(train_ids, dtype=np.uint16)
63 | val_ids = np.array(val_ids, dtype=np.uint16)
64 | train_ids.tofile(destination_path / "train.bin")
65 | val_ids.tofile(destination_path / "val.bin")
66 |
67 |
68 | if __name__ == "__main__":
69 | from jsonargparse import CLI
70 |
71 | CLI(prepare)
72 |
--------------------------------------------------------------------------------
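The exported `train.bin` and `val.bin` files are raw `uint16` token ids with no header, so they can be read back with a plain `np.fromfile`. A minimal sketch, assuming the default `data/shakespeare` destination:

# Hedged sketch: read back the token files written by the script above.
import numpy as np
import torch

train_ids = np.fromfile("data/shakespeare/train.bin", dtype=np.uint16)
data = torch.from_numpy(train_ids.astype(np.int64))  # widen before using as indices
print(f"train has {len(data):,} tokens")
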
/setup.py:
--------------------------------------------------------------------------------
1 | # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
2 |
3 | from setuptools import setup
4 |
5 | setup()
6 |
--------------------------------------------------------------------------------
/tests/conftest.py:
--------------------------------------------------------------------------------
1 | # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
2 |
3 | import sys
4 | from pathlib import Path
5 |
6 | import pytest
7 |
8 | wd = Path(__file__).parent.parent.absolute()
9 |
10 |
11 | @pytest.fixture()
12 | def orig_llama():
13 | sys.path.append(str(wd))
14 |
15 | from scripts.download import download_original
16 |
17 | download_original(wd)
18 |
19 | import original_model
20 |
21 | return original_model
22 |
23 |
24 | @pytest.fixture()
25 | def orig_llama_adapter():
26 | sys.path.append(str(wd))
27 |
28 | from scripts.download import download_original
29 |
30 | download_original(wd)
31 |
32 | import original_adapter
33 |
34 | return original_adapter
35 |
36 |
37 | @pytest.fixture()
38 | def lit_llama():
39 | # this adds support for running tests without the package installed
40 | sys.path.append(str(wd))
41 |
42 | import lit_llama
43 |
44 | return lit_llama
45 |
--------------------------------------------------------------------------------
/tests/test_adapter.py:
--------------------------------------------------------------------------------
1 | # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
2 |
3 | from dataclasses import asdict
4 | import pytest
5 | import sys
6 | import torch
7 |
8 |
9 | @pytest.mark.skipif(sys.platform == "win32", reason="EmptyInitOnDevice on CPU not working for Windows.")
10 | @pytest.mark.parametrize("model_size", ["7B", "13B", "30B", "65B"])
11 | def test_config_identical(model_size, lit_llama):
12 | import lit_llama.adapter as llama_adapter
13 | import lit_llama.model as llama
14 | from lit_llama.utils import EmptyInitOnDevice
15 |
16 | llama_config = asdict(llama.LLaMAConfig.from_name(model_size))
17 | adapter_config = asdict(llama_adapter.LLaMAConfig.from_name(model_size))
18 |
19 | del adapter_config["adapter_prompt_length"]
20 | del adapter_config["adapter_start_layer"]
21 | assert adapter_config == llama_config
22 |
23 | with EmptyInitOnDevice():
24 | llama_model = llama.LLaMA.from_name(model_size)
25 | adapter_model = llama_adapter.LLaMA.from_name(model_size)
26 | assert llama_model.lm_head.weight.shape == adapter_model.lm_head.weight.shape
27 |
28 |
29 | def test_adapter_load_gating_factor(lit_llama):
30 | """Tests backward-compatible loading of checkpoints after the `gating_factor` was extended per-head
31 | in PR #297.
32 | """
33 | import lit_llama.adapter as llama_adapter
34 | from lit_llama.utils import lazy_load
35 |
36 | config = llama_adapter.LLaMAConfig(n_head=4, block_size=100, n_embd=16)
37 | attn = llama_adapter.CausalSelfAttention(config=config, block_idx=3)
38 |
39 | # Old checkpoint format
40 |     state_dict = {
41 | "gating_factor": torch.tensor(0.42), # in old checkpoints, this was a scalar
42 | "c_attn.weight": torch.zeros(3 * 16, 16),
43 | "c_proj.weight": torch.zeros(16, 16),
44 | "adapter_wte.weight": torch.zeros(10, 16),
45 | }
46 | attn.load_state_dict(state_dict=state_dict)
47 | assert torch.equal(attn.gating_factor, torch.full((1, 4, 1, 1), 0.42))
48 |
49 | # New checkpoint format
50 |     state_dict = {
51 | "gating_factor": torch.tensor([0.42, 0.42, 0.42, 0.42]).reshape(1, 4, 1, 1),
52 | "c_attn.weight": torch.zeros(3 * 16, 16),
53 | "c_proj.weight": torch.zeros(16, 16),
54 | "adapter_wte.weight": torch.zeros(10, 16),
55 | }
56 | attn.load_state_dict(state_dict=state_dict)
57 | assert torch.equal(attn.gating_factor, torch.full((1, 4, 1, 1), 0.42))
58 |
--------------------------------------------------------------------------------
/tests/test_adapter_v2.py:
--------------------------------------------------------------------------------
1 | # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
2 |
3 | import pytest
4 | import sys
5 |
6 |
7 | @pytest.mark.skipif(sys.platform == "win32", reason="EmptyInitOnDevice on CPU not working for Windows.")
8 | @pytest.mark.parametrize("model_size", ["7B", "13B", "30B", "65B"])
9 | def test_config_identical(model_size, lit_llama):
10 | import torch.nn as nn
11 | import lit_llama.adapter as llama_adapter
12 | from lit_llama.adapter_v2 import adapter_v2_linear_with_bias_and_scale
13 | import lit_llama.model as llama
14 | from lit_llama.utils import EmptyInitOnDevice
15 |
16 | with EmptyInitOnDevice():
17 | llama_model = llama.LLaMA.from_name(model_size)
18 | adapter_model = llama_adapter.LLaMA.from_name(model_size)
19 |
20 | for module in adapter_model.modules():
21 | if isinstance(module, nn.Linear):
22 | adapter_v2_linear_with_bias_and_scale(module)
23 |
24 | print(adapter_model.transformer.h[2].attn.c_attn.adapter_bias)
25 | assert not hasattr(llama_model.transformer.h[2].attn.c_attn, 'adapter_bias')
26 | assert not hasattr(llama_model.transformer.h[2].attn.c_attn, 'adapter_scale')
27 | assert hasattr(adapter_model.transformer.h[2].attn.c_attn, 'adapter_bias')
28 | assert hasattr(adapter_model.transformer.h[2].attn.c_attn, 'adapter_scale')
--------------------------------------------------------------------------------
/tests/test_generate.py:
--------------------------------------------------------------------------------
1 | # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
2 |
3 | import functools
4 | import subprocess
5 | import sys
6 | from contextlib import contextmanager, redirect_stdout
7 | from io import StringIO
8 | from pathlib import Path
9 | from unittest import mock
10 | from unittest.mock import Mock, call, ANY
11 |
12 | import torch
13 |
14 | wd = Path(__file__).parent.parent.absolute()
15 |
16 |
17 | @functools.lru_cache(maxsize=1)
18 | def load_generate_script():
19 | sys.path.append(str(wd))
20 |
21 | import generate as generate
22 |
23 | return generate
24 |
25 |
26 | def test_generate():
27 | generate = load_generate_script()
28 |
29 | from lit_llama.model import LLaMA, LLaMAConfig
30 |
31 | T, C = 5, 3
32 | logits = torch.randn(T, C)
33 | input_idx = torch.randint(10, size=(T,))
34 |
35 | config = LLaMAConfig(block_size=128, vocab_size=16, n_layer=1, n_head=4, n_embd=8)
36 | model = LLaMA(config)
37 | max_new_tokens = 20
38 |
39 | multinomial_results = []
40 | original_multinomial = torch.multinomial
41 |
42 | def multinomial(*args, **kwargs):
43 | out = original_multinomial(*args, **kwargs)
44 | multinomial_results.append(out)
45 | return out
46 |
47 | with mock.patch("torch.multinomial", multinomial):
48 | out = generate.generate(model, input_idx, max_new_tokens, max_seq_length=10, top_k=4)
49 |
50 | assert out.size(0) == T + max_new_tokens
51 | multinomial_results = torch.hstack(multinomial_results)
52 | expected = torch.cat((input_idx, multinomial_results))
53 | assert out.shape == expected.shape
54 | torch.testing.assert_close(out, expected)
55 |
56 |
57 | @mock.patch("torch.cuda.is_bf16_supported", return_value=False)
58 | def test_main(bf16_mock, tmp_path, monkeypatch):
59 | generate = load_generate_script()
60 |
61 | checkpoint_path = tmp_path / "ckpt"
62 | checkpoint_path.touch()
63 | tokenizer_path = tmp_path / "tokenizer"
64 | tokenizer_path.touch()
65 |
66 | class FabricMock(Mock):
67 | @property
68 | def device(self):
69 | return torch.device("cpu")
70 |
71 | @contextmanager
72 | def init_module(self, empty_init):
73 | yield
74 |
75 | monkeypatch.setattr(generate.L, "Fabric", FabricMock)
76 | model_mock = Mock()
77 | monkeypatch.setattr(generate.LLaMA, "from_name", model_mock)
78 | lookup_mock = Mock(return_value="1T")
79 | monkeypatch.setattr(generate, "llama_model_lookup", lookup_mock)
80 | load_mock = Mock()
81 | load_mock.return_value = load_mock
82 | load_mock.__enter__ = Mock()
83 | load_mock.__exit__ = Mock()
84 | monkeypatch.setattr(generate.torch, "load", load_mock)
85 | monkeypatch.setattr(generate, "lazy_load", load_mock)
86 | tokenizer_mock = Mock()
87 | tokenizer_mock.return_value.encode.return_value = torch.tensor([[1, 2, 3]])
88 | tokenizer_mock.return_value.decode.return_value = "foo bar baz"
89 | monkeypatch.setattr(generate, "Tokenizer", tokenizer_mock)
90 | generate_mock = Mock()
91 | generate_mock.return_value = torch.tensor([[3, 2, 1]])
92 | monkeypatch.setattr(generate, "generate", generate_mock)
93 |
94 | num_samples = 2
95 | out = StringIO()
96 | with redirect_stdout(out):
97 | generate.main(
98 | checkpoint_path=checkpoint_path,
99 | tokenizer_path=tokenizer_path,
100 | temperature=2.0,
101 | top_k=2,
102 | num_samples=num_samples,
103 | )
104 |
105 | model_mock.assert_called_once_with("1T")
106 | load_mock.assert_called_once_with(checkpoint_path)
107 | tokenizer_mock.assert_called_once_with(tokenizer_path)
108 | assert len(tokenizer_mock.return_value.decode.mock_calls) == num_samples
109 | assert torch.allclose(tokenizer_mock.return_value.decode.call_args[0][0], generate_mock.return_value)
110 | assert generate_mock.mock_calls == [call(ANY, ANY, 50, temperature=2.0, top_k=2)] * num_samples
111 | # only the generated result is printed to stdout
112 | assert out.getvalue() == "foo bar baz\n" * num_samples
113 |
114 |
115 | def test_cli():
116 | cli_path = wd / "generate.py"
117 | output = subprocess.check_output([sys.executable, cli_path, "-h"])
118 | output = str(output.decode())
119 | assert "Generates text samples" in output
120 |
--------------------------------------------------------------------------------
/tests/test_lora.py:
--------------------------------------------------------------------------------
1 | # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
2 |
3 | import torch
4 |
5 |
6 | def test_lora_layer_replacement(lit_llama):
7 | from lit_llama.lora import lora, CausalSelfAttention as LoRACausalSelfAttention
8 | from lit_llama.model import LLaMA, LLaMAConfig
9 |
10 | config = LLaMAConfig()
11 | config.n_layer = 2
12 | config.n_head = 4
13 | config.n_embd = 8
14 | config.block_size = 8
15 | config.vocab_size = 8
16 |
17 | with lora(r=8, alpha=8, dropout=0.1):
18 | model = LLaMA(config)
19 |
20 | assert isinstance(model.transformer.h[0].attn, LoRACausalSelfAttention)
21 | assert isinstance(model.transformer.h[1].attn, LoRACausalSelfAttention)
22 |
23 |
24 | def test_lora_merge_unmerge(lit_llama):
25 | from lit_llama.lora import lora, mark_only_lora_as_trainable
26 | from lit_llama.model import LLaMA, LLaMAConfig
27 |
28 | config = LLaMAConfig(n_layer=1, n_head=2, n_embd=8, block_size=8, vocab_size=8)
29 |
30 | with lora(r=8, alpha=8, dropout=0.1):
31 | model = LLaMA(config)
32 |
33 | initial_weight = model.transformer.h[0].attn.c_attn.weight.clone()
34 | model.train()
35 | assert torch.equal(model.transformer.h[0].attn.c_attn.weight, initial_weight)
36 |
37 | # perform an update to the LoRA weights
38 | mark_only_lora_as_trainable(model)
39 | optimizer = torch.optim.SGD(model.parameters(), lr=1.0)
40 | model(torch.randint(0, 8, size=(2, 4), dtype=torch.int64)).sum().backward()
41 | optimizer.step()
42 | optimizer.zero_grad()
43 | # the weight remains unchanged (only lora A and B change)
44 | assert torch.equal(model.transformer.h[0].attn.c_attn.weight, initial_weight)
45 |
46 | # 'merge' and then 'unmerge' should neutralize themselves
47 | weight_before = model.transformer.h[0].attn.c_attn.weight.clone()
48 | model.eval()
49 | assert not torch.equal(model.transformer.h[0].attn.c_attn.weight, weight_before)
50 | model.train()
51 | # note: numerically, `W + (A * B) - (A * B) == W` does not hold exactly
52 | assert torch.allclose(model.transformer.h[0].attn.c_attn.weight, weight_before)
53 |
54 | # calling eval/train multiple times in a row should not merge/unmerge multiple times
55 | model.eval()
56 | assert model.transformer.h[0].attn.c_attn.merged
57 | weight_after = model.transformer.h[0].attn.c_attn.weight.clone()
58 | model.eval()
59 | model.eval()
60 | assert torch.equal(model.transformer.h[0].attn.c_attn.weight, weight_after)
61 | model.train()
62 | assert not model.transformer.h[0].attn.c_attn.merged
63 | weight_after = model.transformer.h[0].attn.c_attn.weight.clone()
64 | model.train()
65 | model.train()
66 | assert torch.equal(model.transformer.h[0].attn.c_attn.weight, weight_after)
67 |
--------------------------------------------------------------------------------
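The merge/unmerge behaviour exercised above amounts to folding the low-rank update into the frozen weight on `eval()` and subtracting it again on `train()`. A minimal sketch of that arithmetic, with the `alpha / r` scaling taken from the LoRA paper; the actual bookkeeping (including the `merged` flag) lives in `lit_llama/lora.py`:

# Hedged sketch of the merge/unmerge arithmetic the test above relies on.
import torch

out_features, in_features, r, alpha = 8, 8, 4, 8
W = torch.randn(out_features, in_features)  # frozen pretrained weight
A = torch.randn(r, in_features) * 0.01      # trainable low-rank factors
B = torch.randn(out_features, r) * 0.01
scaling = alpha / r

merged = W + (B @ A) * scaling              # folded in when switching to eval()
unmerged = merged - (B @ A) * scaling       # undone when switching back to train()
# as the test notes, the round trip is only approximately exact in floating point
assert torch.allclose(unmerged, W)
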
/tests/test_model.py:
--------------------------------------------------------------------------------
1 | # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
2 |
3 | import torch
4 | import pytest
5 | import sys
6 |
7 |
8 | def copy_mlp(llama_mlp, orig_llama_mlp) -> None:
9 | orig_llama_mlp.w1.weight.copy_(llama_mlp.c_fc1.weight)
10 | orig_llama_mlp.w3.weight.copy_(llama_mlp.c_fc2.weight)
11 | orig_llama_mlp.w2.weight.copy_(llama_mlp.c_proj.weight)
12 |
13 |
14 | def copy_attention(llama_attn, orig_llama_attn) -> None:
15 | n_embd = llama_attn.c_attn.weight.shape[1]
16 | orig_llama_attn.wq.weight.copy_(llama_attn.c_attn.weight[:n_embd])
17 | orig_llama_attn.wk.weight.copy_(llama_attn.c_attn.weight[n_embd:-n_embd])
18 | orig_llama_attn.wv.weight.copy_(llama_attn.c_attn.weight[-n_embd:])
19 | orig_llama_attn.wo.weight.copy_(llama_attn.c_proj.weight)
20 |
21 |
22 | def copy_block(llama_block, orig_llama_block) -> None:
23 | orig_llama_block.attention_norm.weight.copy_(llama_block.rms_1.scale)
24 | copy_attention(llama_block.attn, orig_llama_block.attention)
25 | orig_llama_block.ffn_norm.weight.copy_(llama_block.rms_2.scale)
26 | copy_mlp(llama_block.mlp, orig_llama_block.feed_forward)
27 |
28 |
29 | def copy_weights(llama_model, orig_llama_model) -> None:
30 | orig_llama_model.tok_embeddings.weight.copy_(llama_model.transformer.wte.weight)
31 | for llama_block, orig_llama_block in zip(llama_model.transformer.h, orig_llama_model.layers):
32 | copy_block(llama_block, orig_llama_block)
33 | orig_llama_model.norm.weight.copy_(llama_model.transformer.ln_f.scale)
34 | orig_llama_model.output.weight.copy_(llama_model.lm_head.weight)
35 |
36 |
37 | @torch.no_grad()
38 | @pytest.mark.parametrize("kv_cache", (False, True))
39 | def test_to_orig_llama(lit_llama, orig_llama, kv_cache) -> None:
40 | block_size = 64
41 | vocab_size = 32000
42 | n_layer = 16
43 | n_head = 16
44 | n_embd = 32
45 | batch_size = 3
46 |
47 | llama_config = lit_llama.LLaMAConfig(
48 | block_size=block_size, vocab_size=vocab_size, n_layer=n_layer, n_head=n_head, n_embd=n_embd
49 | )
50 | orig_llama_config = orig_llama.ModelArgs(
51 | dim=n_embd,
52 | n_layers=n_layer,
53 | n_heads=n_head,
54 | vocab_size=vocab_size,
55 | norm_eps=1e-5,
56 | max_seq_len=block_size,
57 | max_batch_size=batch_size,
58 | )
59 |
60 | seq_len = orig_llama_config.max_seq_len
61 | token_sample = torch.randint(0, orig_llama_config.vocab_size, size=(batch_size, seq_len), dtype=torch.int64)
62 |
63 | llama_model = lit_llama.LLaMA(llama_config)
64 | llama_model.apply(llama_model._init_weights)
65 | orig_llama_model = orig_llama.Transformer(orig_llama_config)
66 |
67 | copy_weights(llama_model, orig_llama_model)
68 |
69 | orig_llama_embed = orig_llama_model.tok_embeddings(token_sample)
70 | llama_embed = llama_model.transformer.wte(token_sample)
71 | assert torch.allclose(orig_llama_embed, llama_embed)
72 |
73 | llama_rope = llama_model.build_rope_cache(token_sample)
74 | llama_mask = llama_model.build_mask_cache(token_sample)
75 | orig_llama_mask = torch.full((1, 1, seq_len, seq_len), float("-inf"))
76 | orig_llama_mask = torch.triu(orig_llama_mask, diagonal=1)
77 | if kv_cache:
78 | orig_llama_block_out = orig_llama_model.layers[0](
79 | orig_llama_embed, 0, orig_llama_model.freqs_cis[:seq_len], orig_llama_mask
80 | )
81 | theirs_k_cache = orig_llama_model.layers[0].attention.cache_k
82 | theirs_v_cache = orig_llama_model.layers[0].attention.cache_v
83 | head_size = n_embd // n_head
84 | kv_cache_shape = (batch_size, n_head, block_size, head_size)
85 | ours_kv_cache = torch.zeros(kv_cache_shape), torch.zeros(kv_cache_shape)
86 | (llama_block_out, ours_kv_cache) = llama_model.transformer.h[0](
87 | llama_embed, llama_rope, llama_mask, seq_len, torch.arange(block_size), ours_kv_cache
88 | )
89 | ours_k_cache = ours_kv_cache[0].permute(0, 2, 1, 3)
90 | ours_v_cache = ours_kv_cache[1].permute(0, 2, 1, 3)
91 | torch.testing.assert_close(ours_k_cache, theirs_k_cache)
92 | torch.testing.assert_close(ours_v_cache, theirs_v_cache)
93 | else:
94 | orig_llama_block_out = orig_llama_model.layers[0](
95 | orig_llama_embed, 0, orig_llama_model.freqs_cis[:seq_len], orig_llama_mask
96 | )
97 | (llama_block_out, _) = llama_model.transformer.h[0](llama_embed, llama_rope, llama_mask, seq_len)
98 | assert torch.allclose(orig_llama_block_out, llama_block_out)
99 |
100 | expected = orig_llama_model(token_sample, 0)
101 | out = llama_model(token_sample)
102 | assert torch.allclose(out, expected)
103 |
104 |
105 | @pytest.mark.skipif(not torch.cuda.is_available(), reason="Requires CUDA")
106 | @torch.no_grad()
107 | def test_bfloat16_llama_init(lit_llama, orig_llama) -> None:
108 | from lit_llama.utils import EmptyInitOnDevice
109 |
110 | block_size = 64
111 | vocab_size = 32000
112 | n_layer = 16
113 | n_head = 16
114 | n_embd = 32
115 |
116 | llama_config = lit_llama.LLaMAConfig(
117 | block_size=block_size, vocab_size=vocab_size, n_layer=n_layer, n_head=n_head, n_embd=n_embd
118 | )
119 | llama_model = lit_llama.LLaMA(llama_config)
120 | llama_model.apply(llama_model._init_weights)
121 |
122 | batch_size = 3
123 |
124 | token_sample = torch.randint(0, vocab_size, size=(batch_size, block_size), dtype=torch.int64)
125 |
126 | expected = llama_model(token_sample)
127 |
128 | with EmptyInitOnDevice(device="cuda", dtype=torch.bfloat16):
129 | llama_model2 = lit_llama.LLaMA(llama_config)
130 | llama_model2.load_state_dict(llama_model.state_dict(keep_vars=True))
131 |
132 | out = llama_model2(token_sample.cuda()).float().cpu()
133 | torch.testing.assert_close(out, expected, atol=5e-3, rtol=1e-3)
134 |
135 |
136 | def copy_adapter_weights(llama_model, orig_llama_model) -> None:
137 | # copy the gating parameter
138 | for llama_block, orig_llama_block in zip(llama_model.transformer.h, orig_llama_model.layers):
139 | if hasattr(llama_block.attn, "gating_factor"):
140 | llama_block.attn.gating_factor.copy_(orig_llama_block.attention.gate)
141 |
142 | # In the original model, there is one embedding layer for all blocks combined
143 | orig_adapter_wte = orig_llama_model.adapter_query.weight.reshape(
144 | orig_llama_model.params.adapter_layer, orig_llama_model.params.adapter_len, orig_llama_model.params.dim
145 | )
146 |
147 | # In ours, the embedding layer is split across the individual attention layers
148 | index = 0
149 | for llama_block in llama_model.transformer.h:
150 | if hasattr(llama_block.attn, "adapter_wte"):
151 | llama_block.attn.adapter_wte.weight.copy_(orig_adapter_wte[index])
152 | index += 1
153 |
154 |
155 | def enable_gate(model):
156 | for name, param in model.named_parameters():
157 | if "gating_factor" in name or "gate" in name:
158 | param.fill_(1)
159 |
160 |
161 | @torch.no_grad()
162 | def test_adapter_parity(orig_llama_adapter):
163 | """Test parity between our implementation of LLaMA-Adapter and the reference code."""
164 | import lit_llama.adapter as lit_llama
165 |
166 | orig_llama = orig_llama_adapter
167 |
168 | block_size = 32
169 | vocab_size = 100
170 | n_layer = 2
171 | n_head = 4
172 | n_embd = 16
173 | adapter_prompt_length: int = 10
174 | adapter_start_layer: int = 0
175 |
176 | llama_config = lit_llama.LLaMAConfig(
177 | block_size=block_size,
178 | vocab_size=vocab_size,
179 | n_layer=n_layer,
180 | n_head=n_head,
181 | n_embd=n_embd,
182 | adapter_prompt_length=adapter_prompt_length,
183 | adapter_start_layer=adapter_start_layer,
184 | )
185 | orig_llama_config = orig_llama.ModelArgs(
186 | dim=n_embd,
187 | n_layers=n_layer,
188 | n_heads=n_head,
189 | vocab_size=vocab_size,
190 | norm_eps=1e-5,
191 | max_seq_len=block_size,
192 | adapter_len=adapter_prompt_length,
193 | adapter_layer=(n_layer - adapter_start_layer),
194 | )
195 |
196 | batch_size = 3
197 | token_sample = torch.randint(
198 | 0, orig_llama_config.vocab_size, size=(batch_size, orig_llama_config.max_seq_len), dtype=torch.int64
199 | )
200 |
201 | llama_model = lit_llama.LLaMA(llama_config)
202 | llama_model.apply(llama_model._init_weights)
203 | orig_llama_model = orig_llama.Transformer(orig_llama_config)
204 |
205 | copy_weights(llama_model, orig_llama_model)
206 | copy_adapter_weights(llama_model, orig_llama_model)
207 |
208 | # make the gate non-zero, otherwise the adapter is disabled and the model
209 | # identical to regular LLaMA
210 | enable_gate(llama_model)
211 | enable_gate(orig_llama_model)
212 |
213 | expected = orig_llama_model(token_sample, 0)
214 | out = llama_model(token_sample)
215 | assert torch.allclose(out, expected, atol=1e-5, rtol=1e-5)
216 |
217 |
218 | @pytest.mark.skipif(sys.platform in ("win32", "darwin"), reason="torch.compile not supported on this platform")
219 | def test_model_compile(lit_llama):
220 | llama_config = lit_llama.LLaMAConfig(block_size=8, vocab_size=8, n_layer=2, n_head=2, n_embd=4)
221 | model = lit_llama.LLaMA(llama_config)
222 | model.apply(model._init_weights)
223 |
224 | model = torch.compile(model)
225 |
226 | sample = torch.randint(model.config.vocab_size, size=(2, model.config.block_size), dtype=torch.int64)
227 | for _ in range(3):
228 | _ = model(sample)
229 |
--------------------------------------------------------------------------------
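The slicing in `copy_attention` above reflects how the fused `c_attn` weight stacks the query, key and value projections row-wise. A small sketch of the same split on a toy weight:

# Hedged sketch: split a fused qkv projection the way copy_attention does above.
import torch

n_embd = 4
c_attn_weight = torch.arange(3 * n_embd * n_embd, dtype=torch.float32).reshape(3 * n_embd, n_embd)

wq = c_attn_weight[:n_embd]          # first n_embd rows: query projection
wk = c_attn_weight[n_embd:-n_embd]   # middle n_embd rows: key projection
wv = c_attn_weight[-n_embd:]         # last n_embd rows: value projection
assert torch.equal(torch.cat([wq, wk, wv]), c_attn_weight)
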
/tests/test_packed_dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
2 |
3 | import os
4 | from unittest.mock import MagicMock
5 | import requests
6 |
7 | from torch.utils.data import IterableDataset
8 |
9 |
10 | def train_tokenizer(destination_path):
11 | destination_path.mkdir(parents=True, exist_ok=True)
12 |
13 | # download the tiny shakespeare dataset
14 | input_file_path = destination_path / "input.txt"
15 | if not input_file_path.exists():
16 | data_url = "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt"
17 | with open(input_file_path, "w") as f:
18 | f.write(requests.get(data_url).text)
19 |
20 | from lit_llama import Tokenizer
21 | Tokenizer.train(
22 | input=input_file_path,
23 | destination=destination_path,
24 | vocab_size=100,
25 | )
26 |
27 | return destination_path / "tokenizer.model"
28 |
29 |
30 | def test_packed_dataset(tmp_path):
31 | tokenizer_path = train_tokenizer(tmp_path)
32 |
33 | from lit_llama import Tokenizer
34 | tokenizer = Tokenizer(tokenizer_path)
35 |
36 | texts = [
37 | "The moment of truth is upon us.",
38 | "Time to open the fridge."
39 | ]
40 |
41 | from lit_llama.packed_dataset import PackedDatasetBuilder, PackedDataset, HDR_SIZE
42 |
43 | block_size = 10
44 | n_blocks = 2
45 | chunk_size = block_size * n_blocks
46 |
47 | builder = PackedDatasetBuilder(
48 | outdir=tmp_path,
49 | prefix="packed_dataset",
50 | chunk_size=chunk_size,
51 | sep_token=tokenizer.bos_id,
52 | dtype="auto",
53 | vocab_size=100,
54 | )
55 |
56 | text_ids = []
57 |
58 | for text in texts:
59 | text_ids = tokenizer.encode(text)
60 | assert text_ids[0] == tokenizer.bos_id
61 | builder.add_array(text_ids)
62 |
63 | filenames = builder.filenames
64 |
65 | assert len(filenames) == 2
66 | assert os.path.basename(filenames[0]) == "packed_dataset_0000000000.bin"
67 | assert os.path.basename(filenames[1]) == "packed_dataset_0000000001.bin"
68 |
69 | import numpy as np
70 |
71 | ex_tokenized = [
72 | tokenizer.encode(text).numpy().astype(builder.dtype)
73 | for text in texts
74 | ]
75 | ex_tokenized = np.concatenate(ex_tokenized)
76 | ex_tokenized = ex_tokenized[:2 * chunk_size]
77 |
78 | for filename, el in zip(filenames, np.array_split(ex_tokenized, 2)):
79 | mmap = np.memmap(filename, mode="r", order="C", offset=HDR_SIZE)
80 | count = len(mmap) // np.dtype(builder.dtype).itemsize
81 | arr = np.frombuffer(
82 | mmap, dtype=builder.dtype, count=count, offset=0
83 | )
84 | where_bos = np.where(arr == tokenizer.bos_id)
85 |         # np.where returns a tuple with one index array per dimension, so this only checks `arr` is 1-D
86 |         assert len(where_bos) == 1
87 | assert np.array_equal(arr, el)
88 |
89 | dataset = PackedDataset(filenames=filenames, n_chunks=2, block_size=block_size, shuffle=False)
90 |
91 | ex_split = np.array_split(ex_tokenized, ex_tokenized.shape[0] // block_size)
92 |
93 | for item, el in zip(dataset, ex_split):
94 | assert np.array_equal(item, el)
95 |
96 | dataset = PackedDataset(filenames=filenames, n_chunks=2, block_size=block_size, seed=12345)
97 |
98 | for i, item in enumerate(dataset):
99 | block_idxs = iter(dataset)._block_idxs
100 | assert np.array_equal(item, ex_split[block_idxs[i]])
101 |
102 | dataset = PackedDataset(filenames=filenames, n_chunks=2, block_size=block_size, seed=12345, wrap=True)
103 |
104 | for i, item in enumerate(dataset):
105 | if i > 24:
106 | break
107 |
108 | dataset = PackedDataset(filenames=filenames, n_chunks=1, block_size=block_size, seed=12345)
109 |
110 | for i, item in enumerate(dataset):
111 | block_idxs = iter(dataset)._block_idxs
112 | chunk_idx = i // n_blocks * n_blocks
113 | assert np.array_equal(item, ex_split[chunk_idx + block_idxs[i % n_blocks]])
114 |
115 | block_size_ = block_size // 2
116 | ex_split = np.array_split(ex_tokenized, ex_tokenized.shape[0] // block_size_)
117 | dataset = PackedDataset(filenames=filenames, n_chunks=2, block_size=block_size_, seed=12345)
118 |
119 | for i, item in enumerate(dataset):
120 | block_idxs = iter(dataset)._block_idxs
121 | assert np.array_equal(item, ex_split[block_idxs[i]])
122 |
123 | block_size_ = block_size // 3
124 | n_chunks = 2
125 | ex_chunks = np.split(ex_tokenized, n_chunks)
126 | n_splits = ex_tokenized.shape[0] // n_chunks // block_size_
127 | ex_splits = [np.split(el[:n_splits * block_size_], n_splits) for el in ex_chunks]
128 | ex_split = sum(ex_splits, [])
129 |
130 | dataset = PackedDataset(filenames=filenames, n_chunks=n_chunks, block_size=block_size_, seed=12345)
131 |
132 | for i, item in enumerate(dataset):
133 | block_idxs = iter(dataset)._block_idxs
134 | assert np.array_equal(item, ex_split[block_idxs[i]])
135 |
136 |
137 | class SimpleDataset(IterableDataset):
138 | def __init__(self, start, end):
139 | super().__init__()
140 | self._start = start
141 | self._end = end
142 |
143 | def __iter__(self):
144 | return iter(range(self._start, self._end))
145 |
146 |
147 | def test_combined_dataset(tmp_path):
148 | from lit_llama.packed_dataset import CombinedDataset
149 |
150 | dataset1 = SimpleDataset(0, 10)
151 | dataset2 = SimpleDataset(10, 20)
152 | dataset = CombinedDataset(datasets=[dataset1, dataset2], weights=[1.0, 0.0], seed=12345)
153 |
154 | res = [el for el in dataset]
155 | assert res == list(range(0, 10))
156 |
157 | dataset1 = SimpleDataset(0, 10)
158 | dataset2 = SimpleDataset(10, 20)
159 | dataset = CombinedDataset(datasets=[dataset1, dataset2], weights=[0.0, 1.0], seed=12345)
160 |
161 | res = [el for el in dataset]
162 | assert res == list(range(10, 20))
163 |
164 | dataset1 = SimpleDataset(0, 10)
165 | dataset2 = SimpleDataset(10, 20)
166 | dataset = CombinedDataset(datasets=[dataset1, dataset2], weights=[0.5, 0.5], seed=12345)
167 |
168 | res = [el for el in dataset]
169 | assert 9 in res or 19 in res
170 | if len(res) > 10:
171 | assert 0 in res and 10 in res
172 |
173 |
174 | def test_sharded_packed_dataset(monkeypatch):
175 | import lit_llama.packed_dataset
176 | from lit_llama.packed_dataset import PackedDataset
177 |
178 | dataset_iterator_mock = MagicMock()
179 | monkeypatch.setattr(lit_llama.packed_dataset, "PackedDatasetIterator", dataset_iterator_mock)
180 | filenames = [str(i) for i in range(10)]
181 |
182 | # world_size = 1, rank = 0
183 | iter(PackedDataset(filenames=filenames, n_chunks=2, block_size=2))
184 | assert dataset_iterator_mock.call_args[1]["filenames"] == filenames
185 | dataset_iterator_mock.reset_mock()
186 | # world_size = 2, rank = 0
187 | iter(PackedDataset(filenames=filenames, n_chunks=2, block_size=2, num_processes=2, process_rank=0))
188 | assert dataset_iterator_mock.call_args[1]["filenames"] == ["0", "2", "4", "6", "8"]
189 | dataset_iterator_mock.reset_mock()
190 | # world_size = 2, rank = 1
191 | iter(PackedDataset(filenames=filenames, n_chunks=2, block_size=2, num_processes=2, process_rank=1))
192 | assert dataset_iterator_mock.call_args[1]["filenames"] == ["1", "3", "5", "7", "9"]
193 | dataset_iterator_mock.reset_mock()
194 |
195 | # world_size = 3, rank = 0 (dataset size not cleanly divisible by world size)
196 | iter(PackedDataset(filenames=filenames, n_chunks=2, block_size=2, num_processes=3, process_rank=0))
197 | assert dataset_iterator_mock.call_args[1]["filenames"] == ["0", "3", "6"]
198 | dataset_iterator_mock.reset_mock()
199 | # world_size = 3, rank = 1 (dataset size not cleanly divisible by world size)
200 | iter(PackedDataset(filenames=filenames, n_chunks=2, block_size=2, num_processes=3, process_rank=1))
201 | assert dataset_iterator_mock.call_args[1]["filenames"] == ["1", "4", "7"]
202 | dataset_iterator_mock.reset_mock()
203 | # world_size = 3, rank = 2 (dataset size not cleanly divisible by world size)
204 | iter(PackedDataset(filenames=filenames, n_chunks=2, block_size=2, num_processes=3, process_rank=2))
205 | assert dataset_iterator_mock.call_args[1]["filenames"] == ["2", "5", "8"]
206 |
--------------------------------------------------------------------------------
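The sharding assertions in `test_sharded_packed_dataset` above imply a simple rule: each rank takes every `num_processes`-th file, after the list is truncated so all ranks receive the same number of files. A stand-alone sketch of that rule; the real implementation lives in `lit_llama/packed_dataset.py`:

# Hedged sketch of the filename sharding rule implied by the assertions above.
def shard_filenames(filenames, num_processes, process_rank):
    # drop the remainder so every rank gets the same number of files
    max_files = len(filenames) // num_processes * num_processes
    return filenames[process_rank:max_files:num_processes]

filenames = [str(i) for i in range(10)]
assert shard_filenames(filenames, 1, 0) == filenames
assert shard_filenames(filenames, 2, 0) == ["0", "2", "4", "6", "8"]
assert shard_filenames(filenames, 3, 1) == ["1", "4", "7"]
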
/tests/test_prepare_redpajama.py:
--------------------------------------------------------------------------------
1 | # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
2 |
3 | import json
4 | import os
5 | import subprocess
6 | import sys
7 | from pathlib import Path
8 | from unittest import mock
9 | from unittest.mock import Mock, call, ANY
10 |
11 | wd = (Path(__file__).parent.parent / "scripts").absolute()
12 |
13 | import requests
14 |
15 |
16 | def train_tokenizer(destination_path):
17 | destination_path.mkdir(parents=True, exist_ok=True)
18 |
19 | # download the tiny shakespeare dataset
20 | input_file_path = destination_path / "input.txt"
21 | if not input_file_path.exists():
22 | data_url = "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt"
23 | with open(input_file_path, "w") as f:
24 | f.write(requests.get(data_url).text)
25 |
26 | from lit_llama import Tokenizer
27 | Tokenizer.train(input=input_file_path, destination=destination_path, vocab_size=100)
28 |
29 | return destination_path / "tokenizer.model"
30 |
31 |
32 | def test_prepare_sample(tmp_path):
33 | sys.path.append(str(wd))
34 |
35 | tokenizer_path = train_tokenizer(tmp_path)
36 |
37 | sample_path = tmp_path / "sample"
38 | source_path = sample_path / "source"
39 | dest_path = sample_path / "dest"
40 |
41 | source_path.mkdir(parents=True, exist_ok=True)
42 |
43 | sample = {
44 | "meta": {"some": "info"},
45 | "text": "some text"
46 | }
47 |
48 | jsonl_sample = "\n".join([json.dumps(el) for el in [sample] * 2])
49 |
50 | import prepare_redpajama
51 |
52 | for filename in prepare_redpajama.filenames_sample:
53 | with open(source_path / filename, "w") as f:
54 | f.write(jsonl_sample)
55 |
56 | prepare_redpajama.prepare(source_path=source_path, tokenizer_path=tokenizer_path, destination_path=dest_path, sample=True)
57 |
58 | bin_files = [el.replace(".jsonl", "_0000000000.bin") for el in prepare_redpajama.filenames_sample]
59 |
60 | assert set(os.listdir(dest_path)) == set(bin_files)
61 |
62 | from lit_llama import Tokenizer
63 | from lit_llama.packed_dataset import PackedDataset
64 |
65 | tokenizer = Tokenizer(tokenizer_path)
66 |
67 | # artificially set block_size to fit the text
68 | block_size = len(tokenizer.encode("some text"))
69 |
70 | for filename in bin_files:
71 | filenames = [os.path.join(dest_path, filename)]
72 | dataset = PackedDataset(filenames=filenames, n_chunks=1, block_size=block_size, shuffle=False)
73 | dataset_iter = iter(dataset)
74 | assert tokenizer.decode(next(dataset_iter)) == "some text"
75 | assert tokenizer.decode(next(dataset_iter)) == "some text"
76 |
77 |
78 | def test_prepare_full(tmp_path):
79 | sys.path.append(str(wd))
80 |
81 | tokenizer_path = train_tokenizer(tmp_path)
82 |
83 | full_path = tmp_path / "full"
84 | source_path = full_path / "source"
85 | dest_path = full_path / "dest"
86 |
87 | source_path.mkdir(parents=True, exist_ok=True)
88 |
89 | sample = {
90 | "meta": {"some": "info"},
91 | "text": "some text"
92 | }
93 |
94 | jsonl_sample = "\n".join([json.dumps(el) for el in [sample] * 2])
95 |
96 | import prepare_redpajama
97 |
98 | arxiv_file = source_path / "arxiv" / "arxiv_0.jsonl"
99 | arxiv_file.parent.mkdir(parents=True, exist_ok=True)
100 | with open(arxiv_file, "w") as f:
101 | f.write(jsonl_sample)
102 |
103 | import zstandard as zstd
104 |
105 | cc_file = source_path / "common_crawl" / "cc_0.jsonl"
106 | cc_file.parent.mkdir(parents=True, exist_ok=True)
107 | with zstd.open(cc_file, "wt", encoding="utf-8") as f:
108 | f.write(jsonl_sample)
109 |
110 | filename_sets = {
111 | "arxiv": "arxiv/arxiv*",
112 | "common_crawl": "common_crawl/*",
113 | }
114 |
115 | with mock.patch("prepare_redpajama.filename_sets", filename_sets):
116 | prepare_redpajama.prepare(source_path=source_path, tokenizer_path=tokenizer_path, destination_path=dest_path, sample=False)
117 |
118 | all_names = prepare_redpajama.filename_sets.keys()
119 | bin_files = [el + "_0000000000.bin" for el in all_names]
120 |
121 | assert set(os.listdir(dest_path)) == set(bin_files)
122 |
123 | from lit_llama import Tokenizer
124 | from lit_llama.packed_dataset import PackedDataset
125 |
126 | tokenizer = Tokenizer(tokenizer_path)
127 |
128 | # artificially set block_size to fit the text
129 | block_size = len(tokenizer.encode("some text"))
130 |
131 | filenames = [os.path.join(dest_path, el) for el in bin_files]
132 |
133 | for filename in filenames:
134 | dataset = PackedDataset(filenames=[filename], n_chunks=1, block_size=block_size, shuffle=False)
135 | dataset_iter = iter(dataset)
136 | assert tokenizer.decode(next(dataset_iter)) == "some text"
137 | assert tokenizer.decode(next(dataset_iter)) == "some text"
138 |
139 |
140 | def test_cli():
141 | cli_path = wd / "prepare_redpajama.py"
142 | output = subprocess.check_output([sys.executable, cli_path, "-h"])
143 | output = str(output.decode())
144 | assert 'Prepare the "Red Pajama"' in output
145 |
--------------------------------------------------------------------------------
/tests/test_prepare_shakespeare.py:
--------------------------------------------------------------------------------
1 | # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
2 |
3 | import os
4 | import subprocess
5 | import sys
6 | from pathlib import Path
7 |
8 | wd = (Path(__file__).parent.parent / "scripts").absolute()
9 |
10 |
11 | def test_prepare(tmp_path):
12 | sys.path.append(str(wd))
13 |
14 | import prepare_shakespeare
15 |
16 | prepare_shakespeare.prepare(tmp_path)
17 |
18 | assert set(os.listdir(tmp_path)) == {"train.bin", "tokenizer.model", "tokenizer.vocab", "input.txt", "val.bin"}
19 |
20 |
21 | def test_cli():
22 | cli_path = wd / "prepare_shakespeare.py"
23 | output = subprocess.check_output([sys.executable, cli_path, "-h"])
24 | output = str(output.decode())
25 | assert 'Prepare the "Tiny Shakespeare"' in output
26 |
--------------------------------------------------------------------------------
/tests/test_rmsnorm.py:
--------------------------------------------------------------------------------
1 | # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
2 |
3 | import torch
4 |
5 |
6 | @torch.no_grad()
7 | def test_rmsnorm(lit_llama, orig_llama) -> None:
8 | block_size = 16
9 | vocab_size = 16
10 |
11 | sample = torch.rand(size=(2, block_size, vocab_size), dtype=torch.float32)
12 |
13 | eps = 1e-6
14 | orig_llama_rmsnorm = orig_llama.RMSNorm(vocab_size, eps=eps)(sample)
15 | llama_rmsnorm = lit_llama.RMSNorm(vocab_size, eps=eps)(sample)
16 |
17 | assert torch.allclose(orig_llama_rmsnorm, llama_rmsnorm)
18 |
--------------------------------------------------------------------------------
/tests/test_rope.py:
--------------------------------------------------------------------------------
1 | # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
2 |
3 | import torch
4 |
5 |
6 | @torch.no_grad()
7 | def test_rope(lit_llama, orig_llama) -> None:
8 | torch.manual_seed(1)
9 |
10 | bs, seq_len, n_head, n_embed = 1, 6, 2, 8
11 | x = torch.randint(0, 10000, size=(bs, seq_len, n_head, n_embed // n_head)).float()
12 |
13 | freqs_cis = orig_llama.precompute_freqs_cis(n_embed // n_head, seq_len)
14 | llama_rope_cache = lit_llama.build_rope_cache(seq_len, n_embed // n_head, dtype=x.dtype, device=x.device)
15 | torch.testing.assert_close(freqs_cis, torch.view_as_complex(llama_rope_cache))
16 |
17 | llama_x_rope = lit_llama.apply_rope(x, llama_rope_cache)
18 | orig_llama_x_rope, _ = orig_llama.apply_rotary_emb(x, x, freqs_cis)
19 | torch.testing.assert_close(llama_x_rope, orig_llama_x_rope)
20 |
--------------------------------------------------------------------------------
/tests/test_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
2 |
3 | import tempfile
4 | import pathlib
5 |
6 | import torch
7 |
8 |
9 | class ATensor(torch.Tensor):
10 | pass
11 |
12 |
13 | def test_lazy_load_basic(lit_llama):
14 | import lit_llama.utils
15 |
16 | with tempfile.TemporaryDirectory() as tmpdirname:
17 | m = torch.nn.Linear(5, 3)
18 | path = pathlib.Path(tmpdirname)
19 | fn = str(path / "test.pt")
20 | torch.save(m.state_dict(), fn)
21 | with lit_llama.utils.lazy_load(fn) as sd_lazy:
22 | assert "NotYetLoadedTensor" in str(next(iter(sd_lazy.values())))
23 | m2 = torch.nn.Linear(5, 3)
24 | m2.load_state_dict(sd_lazy)
25 |
26 | x = torch.randn(2, 5)
27 | actual = m2(x)
28 | expected = m(x)
29 | torch.testing.assert_close(actual, expected)
30 |
31 |
32 | def test_lazy_load_subclass(lit_llama):
33 | import lit_llama.utils
34 |
35 | with tempfile.TemporaryDirectory() as tmpdirname:
36 | path = pathlib.Path(tmpdirname)
37 | fn = str(path / "test.pt")
38 | t = torch.randn(2, 3)[:, 1:]
39 | sd = {
40 | 1: t,
41 | 2: torch.nn.Parameter(t),
42 | 3: torch.Tensor._make_subclass(ATensor, t),
43 | }
44 | torch.save(sd, fn)
45 | with lit_llama.utils.lazy_load(fn) as sd_lazy:
46 | for k in sd.keys():
47 | actual = sd_lazy[k]
48 | expected = sd[k]
49 | torch.testing.assert_close(actual._load_tensor(), expected)
50 |
51 |
52 | def test_incremental_write(tmp_path, lit_llama):
53 | import lit_llama.utils
54 |
55 | sd = {str(k): torch.randn(5, 10) for k in range(3)}
56 | sd_expected = {k: v.clone() for k, v in sd.items()}
57 | fn = str(tmp_path / "test.pt")
58 | with lit_llama.utils.incremental_save(fn) as f:
59 | sd["0"] = f.store_early(sd["0"])
60 | sd["2"] = f.store_early(sd["2"])
61 | f.save(sd)
62 | sd_actual = torch.load(fn)
63 | assert sd_actual.keys() == sd_expected.keys()
64 | for k, v_expected in sd_expected.items():
65 | v_actual = sd_actual[k]
66 | torch.testing.assert_close(v_expected, v_actual)
67 |
68 |
69 | def test_find_multiple(lit_llama):
70 | from lit_llama.utils import find_multiple
71 |
72 | assert find_multiple(17, 5) == 20
73 | assert find_multiple(30, 7) == 35
74 | assert find_multiple(10, 2) == 10
75 | assert find_multiple(5, 10) == 10
76 |
--------------------------------------------------------------------------------