├── .gitignore ├── .gitmodules ├── CONTRIBUTING.md ├── LICENSE.txt ├── Makefile ├── README.md ├── __init__.py ├── base-schema.json ├── chat-schema.json ├── cog.yaml ├── config.py ├── examples └── alpaca │ ├── README.md │ ├── process_data.py │ └── replicate_alpaca_data.json ├── llama_recipes ├── LICENSE ├── __init__.py ├── configs │ ├── __init__.py │ ├── datasets.py │ ├── fsdp.py │ ├── peft.py │ └── training.py ├── ft_datasets │ ├── __init__.py │ ├── alpaca_dataset.py │ ├── completion_dataset.py │ ├── grammar_dataset │ │ ├── __init__.py │ │ ├── grammar_dataset.py │ │ └── grammar_dataset_process.ipynb │ ├── samsum_dataset.py │ └── utils.py ├── llama_finetuning.py ├── model_checkpointing │ ├── __init__.py │ └── checkpoint_handler.py ├── multi_node.slurm ├── policies │ ├── __init__.py │ ├── activation_checkpointing_functions.py │ ├── anyprecision_optimizer.py │ ├── mixed_precision.py │ └── wrapping.py ├── quickstart.ipynb ├── requirements.txt ├── scripts │ ├── markdown_link_check_config.json │ ├── spellcheck.sh │ └── spellcheck_conf │ │ ├── spellcheck.yaml │ │ └── wordlist.txt └── utils │ ├── __init__.py │ ├── config_utils.py │ ├── dataset_utils.py │ ├── fsdp_utils.py │ ├── memory_utils.py │ └── train_utils.py ├── mistral-schema.json ├── model_templates ├── .dockerignore └── config.py ├── models ├── dockerignore ├── llama-2-13b-chat-hf-mlc │ ├── .env │ └── config.py ├── llama-2-13b-chat │ ├── .dockerignore │ ├── .env │ └── config.py ├── llama-2-13b-mlc │ ├── .env │ └── config.py ├── llama-2-13b-transformers │ └── .env ├── llama-2-13b │ ├── .dockerignore │ ├── .env │ └── config.py ├── llama-2-70b-chat-hf-mlc │ ├── .env │ └── config.py ├── llama-2-70b-chat │ ├── .dockerignore │ ├── .env │ └── config.py ├── llama-2-70b-mlc │ ├── .env │ └── config.py ├── llama-2-70b │ ├── .dockerignore │ ├── .env │ ├── config.py │ └── model_artifacts │ │ └── tokenizer │ │ ├── special_tokens_map.json │ │ ├── tokenizer.model │ │ ├── tokenizer_checklist.chk │ │ └── tokenizer_config.json ├── llama-2-7b-chat-hf-mlc │ ├── .dockerignore │ ├── .env │ └── config.py ├── llama-2-7b-chat │ ├── .dockerignore │ ├── .env │ └── config.py ├── llama-2-7b-mlc │ ├── .dockerignore │ ├── .env │ └── config.py ├── llama-2-7b-transformers │ ├── .env │ ├── config.py │ └── model_artifacts │ │ └── tokenizer │ │ ├── special_tokens_map.json │ │ ├── tokenizer.model │ │ ├── tokenizer_checklist.chk │ │ └── tokenizer_config.json ├── llama-2-7b-vllm │ ├── .env │ └── config.py ├── llama-2-7b │ ├── .dockerignore │ ├── .env │ └── config.py ├── mistral-7b-instruct-v0.1-mlc │ ├── .env │ └── config.py └── mistral-7b-v0.1-mlc │ ├── .env │ └── config.py ├── notes └── new_model_notes.md ├── predict.py ├── pyproject.toml ├── requirements-dev.txt ├── scripts ├── .DS_Store ├── benchmark_token_latency.py ├── load_secrets.sh ├── test_fast_llama.py ├── test_load_unload_lora.py ├── train_multi_gpu.sh └── train_single_gpu.sh ├── src ├── __init__.py ├── config_utils.py ├── download.py ├── inference_engines │ ├── __init__.py │ ├── engine.py │ ├── exllama.py │ ├── mlc_engine.py │ ├── mlc_vllm_engine.py │ ├── transformers_engine.py │ ├── vllm_engine.py │ ├── vllm_exllama_engine.py │ └── vllm_transformers.py ├── more_utils.py └── utils.py ├── tests ├── __init__.py ├── assets │ └── llama_tokenizer │ │ ├── special_tokens_map.json │ │ ├── tokenizer.model │ │ ├── tokenizer_checklist.chk │ │ └── tokenizer_config.json ├── conftest.py ├── data │ └── 200_samples.jsonl ├── run_local_tests.sh ├── test_e2e.py ├── test_predict.py ├── test_predict_with_trained_weights.py ├── 
test_remote_predict.py ├── test_remote_train.py ├── test_train.py ├── test_train_predict.py ├── test_utils.py ├── timing.py └── unit_tests │ ├── test_completion_dataset.py │ └── test_utils.py └── train.py /.gitignore: -------------------------------------------------------------------------------- 1 | **/__pycache__/** 2 | flan-t5** 3 | checkpoints/** 4 | tmp/** 5 | unconverted-weights 6 | unconverted-weights/ 7 | weights 8 | weights/ 9 | .DS_STORE 10 | *.safetensors 11 | .cog/ 12 | llama_weights/ 13 | .env 14 | exllama/ 15 | llama-recipes/ 16 | orig-llama-recipes/ 17 | vllm/ 18 | .pytest_cache 19 | .dockerignore 20 | *.egg-info/ 21 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "exllama"] 2 | path = exllama 3 | url = https://github.com/technillogue/exllama.git 4 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing 2 | 3 | Thanks for taking the time to contribute to this project! 4 | 5 | ## Releases 6 | 7 | This section documents the process used internally at Replicate to deploy the many Llama model variants. 8 | 9 | Model variants live in the [models](models) directory, and deployment is managed by a [Makefile](Makefile). 10 | 11 | To release a new model (an example of the full sequence is shown at the end of this document): 12 | 13 | 1. Run `make select model=<model-name>`, where `<model-name>` is the name of a folder in the [models](models) directory, like `llama-2-7b`. This symlinks the selected model's files into the root of the repo so that subsequent commands operate on that model. 14 | 1. Run `make test-local` to test locally (assuming you're on a machine with a GPU). 15 | 1. Run `make stage test-stage model=<model-name>` to push to staging. If this passes, the model is ready to be promoted to production. 16 | 1. Run `REPLICATE_USER=replicate make push test-prod model=<model-name>`. This runs the same tests as staging. 17 | 18 | After releasing to production: 19 | 20 | 1. Search for old instances of the previous version's Docker image ID in documentation and replace them with the new version.
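For reference, a full release might look something like the sketch below. This is illustrative only: `llama-2-7b-mlc` is just an example model directory, and the production push assumes your Replicate user/org is `replicate`.

```sh
# Point the repo at one model variant (symlinks its files into the repo root)
make select model=llama-2-7b-mlc

# Run the local predict/train test suite (requires a GPU)
make test-local

# Push to the staging model on Replicate and run the remote tests against it
make stage test-stage model=llama-2-7b-mlc

# Promote to production and re-run the same remote tests
REPLICATE_USER=replicate make push test-prod model=llama-2-7b-mlc
```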
-------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | .PHONY: init 2 | .PHONY: select 3 | .PHONY: test-local 4 | .PHONY: push 5 | .PHONY: push-and-test 6 | .PHONY: clean 7 | 8 | # this is required to build sentencepiece for py3.11 9 | # requires cog > 0.9.0-beta1 10 | # get it at https://github.com/replicate/cog/releases/download/v0.9.0-beta1/cog_linux_x86_64 11 | export COG_EXPERIMENTAL_BUILD_STAGE_DEPS = apt update && apt install -yy cmake google-perftools 12 | export FAKE_COG_VERSION = 0.8.1 13 | 14 | CURRENT_DIR := $(shell basename $(PWD)) 15 | 16 | ifeq ($(findstring cog,$(CURRENT_DIR)),cog) 17 | IMAGE_NAME := $(CURRENT_DIR) 18 | else 19 | IMAGE_NAME := cog-$(CURRENT_DIR) 20 | endif 21 | 22 | REPLICATE_USER ?= replicate-internal 23 | 24 | model ?= $(SELECTED_MODEL) 25 | 26 | PROD_MODEL ?= $(model) 27 | 28 | ifeq ($(findstring chat,$(model)),chat) 29 | schema := chat-schema.json 30 | else ifeq ($(model),mistral-7b-instruct-v0.1-mlc) 31 | schema := mistral-schema.json 32 | else 33 | schema := base-schema.json 34 | endif 35 | 36 | base-schema.json: 37 | $(MAKE) select model=llama-2-7b-mlc 38 | cog run --use-cuda-base-image=false python3 -m cog.command.openapi_schema | jq > base-schema.json 39 | chat-schema.json: 40 | $(MAKE) select model=llama-2-7b-chat-hf-mlc 41 | cog run --use-cuda-base-image=false python3 -m cog.command.openapi_schema | jq > chat-schema.json 42 | mistral-schema.json: 43 | $(MAKE) select model=mistral-7b-instruct-v0.1-mlc 44 | cog run --use-cuda-base-image=false python3 -m cog.command.openapi_schema | jq > mistral-schema.json 45 | 46 | 47 | init: 48 | @if [ -z "$(model)" ]; then \ 49 | echo "Error: 'model' argument must be specified or 'MODEL_ENV' environment variable must be set. E.g., make select model=your_model_name or export MODEL_ENV=your_model_name"; \ 50 | exit 1; \ 51 | fi 52 | # Initialize directory for model 53 | mkdir -p models/$(model) 54 | cp -r model_templates/* models/$(model) 55 | if [ -e model_templates/.env ]; then cp model_templates/.env models/$(model) ; fi 56 | if [ -e model_templates/.dockerignore ]; then \ 57 | cp model_templates/.dockerignore models/$(model); \ 58 | else \ 59 | touch models/$(model)/.dockerignore; \ 60 | fi 61 | printf "\n# Generated by 'make init'\n" >> models/$(model)/.dockerignore 62 | printf "/models/*/\n" >> models/$(model)/.dockerignore 63 | printf "!/models/$(model)/\n" >> models/$(model)/.dockerignore 64 | printf "/models/$(model)/model_artifacts/**\n" >> models/$(model)/.dockerignore 65 | printf "!/models/$(model)/model_artifacts/tokenizer/\n" >> models/$(model)/.dockerignore 66 | 67 | mkdir -p models/$(model)/model_artifacts/tokenizer 68 | cp -r llama_weights/tokenizer/* models/$(model)/model_artifacts/tokenizer 69 | 70 | update: 71 | @if [ -z "$(model)" ]; then \ 72 | echo "Error: 'model' argument must be specified or 'MODEL_ENV' environment variable must be set. E.g., make select model=your_model_name or export MODEL_ENV=your_model_name"; \ 73 | exit 1; \ 74 | fi 75 | cp -r model_templates/* models/$(model) 76 | 77 | model_dir=models/$(model) 78 | 79 | select: 80 | @if [ -z "$(model)" ]; then \ 81 | echo "Error: 'model' argument must be specified or 'MODEL_ENV' environment variable must be set. E.g., make select model=your_model_name or export MODEL_ENV=your_model_name"; \ 82 | exit 1; \ 83 | fi 84 | # this approach makes copies 85 | # rsync -av --exclude 'model_artifacts/' models/$(model)/ . 
86 | 87 | # this approach behaves the same way but makes symlinks 88 | # # if we also wanted to copy directory structure we could do this, but we only need one dir deep 89 | # rsync -av --exclude 'model_artifacts/' --include '*/' --exclude '*' $(model_dir)/ . 90 | # For symlinking files 91 | find $(model_dir) -type f ! -path "$(model_dir)/model_artifacts/*" -exec ln -sf {} . \; 92 | # For specific files like .env and .dockerignore, we link them if they exist 93 | [ -e $(model_dir)/.env ] && ln -sf $(model_dir)/.env .env || true 94 | # rm .dockerignore || true 95 | cp models/dockerignore .dockerignore 96 | echo "!$(model_dir)" >> .dockerignore 97 | # [ -e $(model_dir)/dockerignore ] && cat $(model_dir)/dockerignore > .dockerignore 98 | #cog build 99 | @echo "#########Selected model: $(model)########" 100 | 101 | clean: select 102 | if [ -e models/$(model)/model_artifacts/default_inference_weights ]; then sudo rm -rf models/$(model)/model_artifacts/default_inference_weights; fi 103 | if [ -e models/$(model)/model_artifacts/training_weights ]; then sudo rm -rf models/$(model)/model_artifacts/training_weights; fi 104 | if [ -e training_output.zip ]; then sudo rm -rf training_output.zip; fi 105 | 106 | build-local: select 107 | cog build --openapi-schema=$(schema) --use-cuda-base-image=false --progress plain 108 | 109 | serve: select 110 | docker run \ 111 | -ti \ 112 | -p 5000:5000 \ 113 | --gpus=all \ 114 | -e COG_WEIGHTS=http://$(HOST_NAME):8000/training_output.zip \ 115 | -v `pwd`/training_output.zip:/src/local_weights.zip \ 116 | $(IMAGE_NAME) 117 | 118 | test-local-predict: build-local 119 | @if [ "$(verbose)" = "true" ]; then \ 120 | pytest ./tests/test_predict.py -s; \ 121 | else \ 122 | pytest ./tests/test_predict.py; \ 123 | fi 124 | 125 | test-local-train: build-local 126 | rm -rf training_output.zip 127 | @if [ "$(verbose)" = "true" ]; then \ 128 | pytest ./tests/test_train.py -s; \ 129 | else \ 130 | pytest ./tests/test_train.py; \ 131 | fi 132 | 133 | test-local-train-predict: build-local 134 | @if [ "$(verbose)" = "true" ]; then \ 135 | pytest ./tests/test_train_predict.py -s; \ 136 | else \ 137 | pytest ./tests/test_train_predict.py; \ 138 | fi 139 | 140 | test-local: select test-local-predict test-local-train test-local-train-predict 141 | 142 | stage: select 143 | @echo "Pushing $(model) to r8.im/$(REPLICATE_USER)/staging-$(model)..."
144 | cog push --openapi-schema=$(schema) --use-cuda-base-image=false --progress plain r8.im/$(REPLICATE_USER)/staging-$(model) 145 | 146 | test-stage-predict: 147 | @if [ "$(verbose)" = "true" ]; then \ 148 | pytest tests/test_remote_predict.py -s --model $(REPLICATE_USER)/staging-$(model); \ 149 | else \ 150 | pytest tests/test_remote_predict.py --model $(REPLICATE_USER)/staging-$(model); \ 151 | fi 152 | 153 | test-stage-train-predict: 154 | @if [ "$(verbose)" = "true" ]; then \ 155 | pytest tests/test_remote_train.py -s --model $(REPLICATE_USER)/staging-$(model); \ 156 | else \ 157 | pytest tests/test_remote_train.py --model $(REPLICATE_USER)/staging-$(model); \ 158 | fi 159 | 160 | test-stage: test-stage-predict test-stage-train-predict 161 | 162 | 163 | stage-and-test-models: 164 | $(foreach model, $(subst ,, $(models)), \ 165 | $(MAKE) select model=$(model); \ 166 | $(MAKE) stage model=$(model); \ 167 | $(MAKE) test-stage model=$(model); \ 168 | ) 169 | 170 | push: select 171 | cog push --openapi-schema=$(schema) --use-cuda-base-image=false --progress plain r8.im/$(REPLICATE_USER)/$(PROD_MODEL) 172 | 173 | test-prod-predict: 174 | @if [ "$(verbose)" = "true" ]; then \ 175 | pytest tests/test_remote_predict.py -s --model $(REPLICATE_USER)/$(PROD_MODEL); \ 176 | else \ 177 | pytest tests/test_remote_predict.py --model $(REPLICATE_USER)/$(PROD_MODEL); \ 178 | fi 179 | 180 | test-prod-train-predict: 181 | @if [ "$(verbose)" = "true" ]; then \ 182 | pytest tests/test_remote_train.py -s --model $(REPLICATE_USER)/$(PROD_MODEL); \ 183 | else \ 184 | pytest tests/test_remote_train.py --model $(REPLICATE_USER)/$(PROD_MODEL); \ 185 | fi 186 | 187 | test-prod: test-prod-predict test-prod-train-predict 188 | 189 | format: 190 | python3 -m ruff format . 191 | 192 | lint: 193 | python3 -m ruff . 194 | python3 -m ruff format --check . 195 | 196 | help: 197 | @echo "Available targets:\n\n" 198 | @echo "init: Create the model directory." 
199 | @echo " e.g., \`make init dir=\`" 200 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/replicate/cog-llama-template/845e24f626bba67da586e8cb4c5793ea930e5a38/__init__.py -------------------------------------------------------------------------------- /cog.yaml: -------------------------------------------------------------------------------- 1 | # Configuration for Cog ⚙️ 2 | # Reference: https://github.com/replicate/cog/blob/main/docs/yaml.md 3 | 4 | build: 5 | # set to true if your model requires a GPU 6 | gpu: true 7 | cuda: "11.8" 8 | 9 | # python version in the form '3.8' or '3.8.12' 10 | python_version: "3.11" 11 | 12 | # a list of packages in the format == 13 | python_packages: 14 | - "numpy==1.24.2" 15 | - "sentencepiece==0.1.99" 16 | - "jinja2==3.1.2" 17 | - "scipy==1.11.1" 18 | - "safetensors>=0.3.1" 19 | - "python-dotenv" 20 | - "fire" 21 | - "datasets" 22 | - "transformers==4.33.2" 23 | - "peft==0.4.0" 24 | - "accelerate" 25 | - "bitsandbytes" 26 | - "trl==0.5.0" 27 | - "aiohttp[speedups]" 28 | - "triton" # hm 29 | - "fastapi<0.99.0" 30 | # uncomment these when we go back to 12.1 31 | # - "https://r2.drysys.workers.dev/torch/torch-2.1.0-cp311-cp311-linux_x86_64.whl" 32 | # - "https://weights.replicate.delivery/default/wheels/vllm-0.2a0-cp311-cp311-linux_x86_64.whl" 33 | 34 | - "https://r2.drysys.workers.dev/torch/11.8/torch-2.1.0-cp311-cp311-linux_x86_64.whl" 35 | # This wheel can be built by running `TORCH_CUDA_ARCH_LIST="8.0;8.6" pip wheel .` in https://github.com/replicate/vllm-with-loras 36 | - "https://r2.drysys.workers.dev/vllm/11.8/vllm-0.2a0-cp311-cp311-linux_x86_64.whl" 37 | - "https://r2.drysys.workers.dev/xformers/11.8/xformers-0.0.23+b4c853d.d20231107-cp311-cp311-linux_x86_64.whl" 38 | 39 | - "--pre -f https://mlc.ai/wheels" 40 | - "mlc-chat-nightly-cu118" 41 | - "mlc-ai-nightly-cu118" 42 | # - "mlc-chat-nightly-cu121" 43 | # - "mlc-ai-nightly-cu121" 44 | run: 45 | - curl -o /usr/local/bin/pget -L "https://github.com/replicate/pget/releases/download/v0.1.1/pget" && chmod +x /usr/local/bin/pget 46 | # since we can't do LD_LIBRARY_PATH=torch/lib, use this to make sure mlc can access the cuda libs bundled with torch 47 | - bash -c 'ln -s /usr/local/lib/python3.11/site-packages/torch/lib/lib{nv,cu}* /usr/lib' 48 | # predict.py defines how predictions are run on your model 49 | predict: "predict.py:Predictor" 50 | train: "train.py:train" 51 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | models/llama-2-7b-mlc/config.py -------------------------------------------------------------------------------- /examples/alpaca/README.md: -------------------------------------------------------------------------------- 1 | Example code for parsing the dataset needed to train [Stanford Alpaca](https://github.com/tatsu-lab/stanford_alpaca). 2 | 3 | This contains both a function, `process_data.py`, which shows how to transform the [given alpaca data](https://github.com/gururise/AlpacaDataCleaned) into the format expected by `cog train`. It also contains an example parsed dataset as a reference for that `{'prompt': ..., 'completion':...}` format. 
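Concretely, the parsed output is a JSON array of objects with `prompt` and `completion` keys. The rows below are invented illustrations of that shape, not taken from the reference file; see `replicate_alpaca_data.json` for real examples:

```json
[
  {
    "prompt": "Give three tips for staying healthy.",
    "completion": "1. Eat a balanced diet. 2. Exercise regularly. 3. Get enough sleep."
  },
  {
    "prompt": "Translate the following sentence into French.\nI have no time.",
    "completion": "Je n'ai pas le temps."
  }
]
```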
-------------------------------------------------------------------------------- /examples/alpaca/process_data.py: -------------------------------------------------------------------------------- 1 | from transformers import T5Tokenizer 2 | import json 3 | 4 | PROMPT_DICT = { 5 | "prompt_input": ( 6 | "Below is an instruction that describes a task, paired with an input that provides further context. " 7 | "Write a response that appropriately completes the request.\n\n" 8 | "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:" 9 | ), 10 | "prompt_no_input": ( 11 | "Below is an instruction that describes a task. " 12 | "Write a response that appropriately completes the request.\n\n" 13 | "### Instruction:\n{instruction}\n\n### Response:" 14 | ), 15 | } 16 | 17 | 18 | class Preprocessor: 19 | """Simple class to parse alpaca data into format expected by trainer. Run this offline to build your dataset.""" 20 | 21 | def __init__(self, tokenizer): 22 | self.prompt_dict = PROMPT_DICT 23 | self.tokenizer = tokenizer 24 | 25 | def batch_tokenize(self, texts): 26 | """Tokenizes text. Presently doesn't pad inputs, just returns input ids.""" 27 | tokenized = [ 28 | self.tokenizer( 29 | prompt, 30 | return_tensors="pt", 31 | padding="longest", 32 | ).input_ids 33 | for prompt in texts 34 | ] 35 | return tokenized 36 | 37 | def make_prompt(self, input_row): 38 | if len(input_row["input"]) > 1: 39 | return self.prompt_dict["prompt_input"].format_map(input_row) 40 | return self.prompt_dict["prompt_no_input"].format_map(input_row) 41 | 42 | def make_short_prompt(self, input_row): 43 | if len(input_row["input"]) > 1: 44 | return f"""{input_row['instruction']}\n{input_row['input']}""" 45 | return input_row["instruction"] 46 | 47 | def construct_dataset(self, input_data): 48 | prompts = [self.make_short_prompt(val) for val in input_data] 49 | return [ 50 | {"prompt": val[0], "completion": val[1]} 51 | for val in zip(prompts, [val["output"] for val in input_data]) 52 | ] 53 | 54 | 55 | if __name__ == "__main__": 56 | proc = Preprocessor(T5Tokenizer.from_pretrained("google/flan-t5-xl")) 57 | with open("alpaca_data.json", "r") as f: 58 | data = json.load(f) 59 | 60 | data_out = proc.construct_dataset(data) 61 | 62 | with open("short_alpaca_data.json", "w") as f: 63 | json.dump(data_out, f, indent=2) 64 | -------------------------------------------------------------------------------- /llama_recipes/LICENSE: -------------------------------------------------------------------------------- 1 | LLAMA 2 COMMUNITY LICENSE AGREEMENT 2 | Llama 2 Version Release Date: July 18, 2023 3 | 4 | "Agreement" means the terms and conditions for use, reproduction, distribution and 5 | modification of the Llama Materials set forth herein. 6 | 7 | "Documentation" means the specifications, manuals and documentation 8 | accompanying Llama 2 distributed by Meta at ai.meta.com/resources/models-and- 9 | libraries/llama-downloads/. 10 | 11 | "Licensee" or "you" means you, or your employer or any other person or entity (if 12 | you are entering into this Agreement on such person or entity's behalf), of the age 13 | required under applicable laws, rules or regulations to provide legal consent and that 14 | has legal authority to bind your employer or such other person or entity if you are 15 | entering in this Agreement on their behalf. 
16 | 17 | "Llama 2" means the foundational large language models and software and 18 | algorithms, including machine-learning model code, trained model weights, 19 | inference-enabling code, training-enabling code, fine-tuning enabling code and other 20 | elements of the foregoing distributed by Meta at ai.meta.com/resources/models-and- 21 | libraries/llama-downloads/. 22 | 23 | "Llama Materials" means, collectively, Meta's proprietary Llama 2 and 24 | Documentation (and any portion thereof) made available under this Agreement. 25 | 26 | "Meta" or "we" means Meta Platforms Ireland Limited (if you are located in or, if you 27 | are an entity, your principal place of business is in the EEA or Switzerland) and Meta 28 | Platforms, Inc. (if you are located outside of the EEA or Switzerland). 29 | 30 | By clicking "I Accept" below or by using or distributing any portion or element of the 31 | Llama Materials, you agree to be bound by this Agreement. 32 | 33 | 1. License Rights and Redistribution. 34 | 35 | a. Grant of Rights. You are granted a non-exclusive, worldwide, non- 36 | transferable and royalty-free limited license under Meta's intellectual property or 37 | other rights owned by Meta embodied in the Llama Materials to use, reproduce, 38 | distribute, copy, create derivative works of, and make modifications to the Llama 39 | Materials. 40 | 41 | b. Redistribution and Use. 42 | 43 | i. If you distribute or make the Llama Materials, or any derivative works 44 | thereof, available to a third party, you shall provide a copy of this Agreement to such 45 | third party. 46 | ii. If you receive Llama Materials, or any derivative works thereof, from 47 | a Licensee as part of an integrated end user product, then Section 2 of this 48 | Agreement will not apply to you. 49 | 50 | iii. You must retain in all copies of the Llama Materials that you 51 | distribute the following attribution notice within a "Notice" text file distributed as a 52 | part of such copies: "Llama 2 is licensed under the LLAMA 2 Community License, 53 | Copyright (c) Meta Platforms, Inc. All Rights Reserved." 54 | 55 | iv. Your use of the Llama Materials must comply with applicable laws 56 | and regulations (including trade compliance laws and regulations) and adhere to the 57 | Acceptable Use Policy for the Llama Materials (available at 58 | https://ai.meta.com/llama/use-policy), which is hereby incorporated by reference into 59 | this Agreement. 60 | 61 | v. You will not use the Llama Materials or any output or results of the 62 | Llama Materials to improve any other large language model (excluding Llama 2 or 63 | derivative works thereof). 64 | 65 | 2. Additional Commercial Terms. If, on the Llama 2 version release date, the 66 | monthly active users of the products or services made available by or for Licensee, 67 | or Licensee's affiliates, is greater than 700 million monthly active users in the 68 | preceding calendar month, you must request a license from Meta, which Meta may 69 | grant to you in its sole discretion, and you are not authorized to exercise any of the 70 | rights under this Agreement unless or until Meta otherwise expressly grants you 71 | such rights. 72 | 73 | 3. Disclaimer of Warranty. 
UNLESS REQUIRED BY APPLICABLE LAW, THE 74 | LLAMA MATERIALS AND ANY OUTPUT AND RESULTS THEREFROM ARE 75 | PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, 76 | EITHER EXPRESS OR IMPLIED, INCLUDING, WITHOUT LIMITATION, ANY 77 | WARRANTIES OF TITLE, NON-INFRINGEMENT, MERCHANTABILITY, OR 78 | FITNESS FOR A PARTICULAR PURPOSE. YOU ARE SOLELY RESPONSIBLE 79 | FOR DETERMINING THE APPROPRIATENESS OF USING OR REDISTRIBUTING 80 | THE LLAMA MATERIALS AND ASSUME ANY RISKS ASSOCIATED WITH YOUR 81 | USE OF THE LLAMA MATERIALS AND ANY OUTPUT AND RESULTS. 82 | 83 | 4. Limitation of Liability. IN NO EVENT WILL META OR ITS AFFILIATES BE 84 | LIABLE UNDER ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, TORT, 85 | NEGLIGENCE, PRODUCTS LIABILITY, OR OTHERWISE, ARISING OUT OF THIS 86 | AGREEMENT, FOR ANY LOST PROFITS OR ANY INDIRECT, SPECIAL, 87 | CONSEQUENTIAL, INCIDENTAL, EXEMPLARY OR PUNITIVE DAMAGES, EVEN 88 | IF META OR ITS AFFILIATES HAVE BEEN ADVISED OF THE POSSIBILITY OF 89 | ANY OF THE FOREGOING. 90 | 91 | 5. Intellectual Property. 92 | 93 | a. No trademark licenses are granted under this Agreement, and in 94 | connection with the Llama Materials, neither Meta nor Licensee may use any name 95 | or mark owned by or associated with the other or any of its affiliates, except as 96 | required for reasonable and customary use in describing and redistributing the 97 | Llama Materials. 98 | 99 | b. Subject to Meta's ownership of Llama Materials and derivatives made by or 100 | for Meta, with respect to any derivative works and modifications of the Llama 101 | Materials that are made by you, as between you and Meta, you are and will be the 102 | owner of such derivative works and modifications. 103 | 104 | c. If you institute litigation or other proceedings against Meta or any entity 105 | (including a cross-claim or counterclaim in a lawsuit) alleging that the Llama 106 | Materials or Llama 2 outputs or results, or any portion of any of the foregoing, 107 | constitutes infringement of intellectual property or other rights owned or licensable 108 | by you, then any licenses granted to you under this Agreement shall terminate as of 109 | the date such litigation or claim is filed or instituted. You will indemnify and hold 110 | harmless Meta from and against any claim by any third party arising out of or related 111 | to your use or distribution of the Llama Materials. 112 | 113 | 6. Term and Termination. The term of this Agreement will commence upon your 114 | acceptance of this Agreement or access to the Llama Materials and will continue in 115 | full force and effect until terminated in accordance with the terms and conditions 116 | herein. Meta may terminate this Agreement if you are in breach of any term or 117 | condition of this Agreement. Upon termination of this Agreement, you shall delete 118 | and cease use of the Llama Materials. Sections 3, 4 and 7 shall survive the 119 | termination of this Agreement. 120 | 121 | 7. Governing Law and Jurisdiction. This Agreement will be governed and 122 | construed under the laws of the State of California without regard to choice of law 123 | principles, and the UN Convention on Contracts for the International Sale of Goods 124 | does not apply to this Agreement. The courts of California shall have exclusive 125 | jurisdiction of any dispute arising out of this Agreement. 
126 | -------------------------------------------------------------------------------- /llama_recipes/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/replicate/cog-llama-template/845e24f626bba67da586e8cb4c5793ea930e5a38/llama_recipes/__init__.py -------------------------------------------------------------------------------- /llama_recipes/configs/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. 3 | 4 | from .peft import ( 5 | lora_config, 6 | llama_adapter_config, 7 | prefix_config, 8 | qlora_config, 9 | bitsandbytes_config, 10 | ) 11 | from .fsdp import fsdp_config 12 | from .training import train_config 13 | -------------------------------------------------------------------------------- /llama_recipes/configs/datasets.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. 3 | 4 | from dataclasses import dataclass 5 | 6 | 7 | @dataclass 8 | class samsum_dataset: 9 | dataset: str = "samsum_dataset" 10 | train_split: str = "train" 11 | test_split: str = "validation" 12 | input_length: int = 2048 13 | 14 | 15 | @dataclass 16 | class grammar_dataset: 17 | dataset: str = "grammar_dataset" 18 | train_split: str = "ft_datasets/grammar_dataset/gtrain_10k.csv" 19 | test_split: str = "ft_datasets/grammar_dataset/grammar_validation.csv" 20 | input_length: int = 2048 21 | 22 | 23 | @dataclass 24 | class alpaca_dataset: 25 | dataset: str = "alpaca_dataset" 26 | train_split: str = "train" 27 | test_split: str = "val" 28 | data_path: str = "ft_datasets/alpaca_data.json" 29 | 30 | 31 | @dataclass 32 | class completion: 33 | """ 34 | A generic class for completion format datasets. Format is expected 35 | to be JSONL like: 36 | ``` 37 | {"text": "..."} 38 | ``` 39 | or 40 | ``` 41 | {"text": "prompt ...", "completion": "..."} 42 | ``` 43 | """ 44 | 45 | dataset: str = "completion" 46 | train_split: str = "train" 47 | test_split: str = "val" 48 | data_path: str = None 49 | num_validation_samples: int = 100 50 | run_validation: bool = True 51 | validation_data_path: str = None 52 | pack_sequences: bool = True 53 | wrap_packed_sequences: bool = False 54 | chunk_size: int = 2048 55 | max_seq_length: int = 4096 56 | -------------------------------------------------------------------------------- /llama_recipes/configs/fsdp.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. 3 | 4 | from dataclasses import dataclass 5 | from torch.distributed.fsdp import ShardingStrategy 6 | from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType 7 | 8 | 9 | @dataclass 10 | class fsdp_config: 11 | mixed_precision: bool = True 12 | use_fp16: bool = False 13 | sharding_strategy: ShardingStrategy = ShardingStrategy.FULL_SHARD 14 | checkpoint_type: StateDictType = StateDictType.SHARDED_STATE_DICT # alternatively can use SHARDED_STATE_DICT save one file per rank, and can resize the world-size. 
15 | fsdp_activation_checkpointing: bool = True 16 | pure_bf16: bool = False 17 | optimizer: str = "AdamW" 18 | -------------------------------------------------------------------------------- /llama_recipes/configs/peft.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. 3 | 4 | from dataclasses import dataclass 5 | from typing import ClassVar, List 6 | import torch 7 | 8 | 9 | @dataclass 10 | class lora_config: 11 | r: int = 8 12 | lora_alpha: int = 16 13 | target_modules: ClassVar[List[str]] = ["q_proj", "v_proj"] 14 | bias = "none" 15 | task_type: str = "CAUSAL_LM" 16 | lora_dropout: float = 0.05 17 | inference_mode: bool = False 18 | 19 | 20 | @dataclass 21 | class llama_adapter_config: 22 | adapter_len: int = 10 23 | adapter_layers: int = 30 24 | task_type: str = "CAUSAL_LM" 25 | 26 | 27 | @dataclass 28 | class prefix_config: 29 | num_virtual_tokens: int = 30 30 | task_type: str = "CAUSAL_LM" 31 | 32 | 33 | @dataclass 34 | class bitsandbytes_config: 35 | load_in_4bit: bool = True 36 | bnb_4bit_quant_type: str = "nf4" 37 | bnb_4bit_use_double_quant: bool = True 38 | bnb_4bit_compute_dtype: torch.dtype = torch.bfloat16 39 | 40 | 41 | @dataclass 42 | class qlora_config: 43 | r: int = 8 44 | lora_alpha: int = 32 45 | target_modules: ClassVar[List[str]] = ["q_proj", "v_proj"] 46 | bias = "none" 47 | task_type: str = "CAUSAL_LM" 48 | lora_dropout: float = 0.05 49 | inference_mode: bool = False 50 | -------------------------------------------------------------------------------- /llama_recipes/configs/training.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. 
3 | from dataclasses import dataclass 4 | 5 | 6 | @dataclass 7 | class train_config: 8 | model_name: str = "llama_weights/llama-2-7b" 9 | enable_fsdp: bool = False 10 | run_validation: bool = True 11 | batch_size_training: int = 4 12 | num_epochs: int = 3 13 | num_workers_dataloader: int = 1 14 | gradient_accumulation_steps: int = 1 15 | lr: float = 1e-4 16 | weight_decay: float = 0.0 17 | gamma: float = 0.85 18 | seed: int = 42 19 | use_fp16: bool = False 20 | mixed_precision: bool = True 21 | val_batch_size: int = 1 22 | dataset = "completion" 23 | peft_method: str = "lora" # None , llama_adapter, prefix 24 | use_peft: bool = False 25 | output_dir: str = "PATH/to/save/PEFT/model" 26 | freeze_layers: bool = False 27 | num_freeze_layers: int = 1 28 | quantization: bool = False 29 | one_gpu: bool = False 30 | save_model: bool = True 31 | dist_checkpoint_root_folder: str = ( 32 | "PATH/to/save/FSDP/model" # will be used if using FSDP 33 | ) 34 | dist_checkpoint_folder: str = "fine-tuned" # will be used if using FSDP 35 | save_optimizer: bool = False # will be used if using FSDP 36 | data_path: str = None 37 | num_validation_samples: int = 100 38 | validation_data_path: str = None 39 | validation_prompt: str = None 40 | wrap_packed_sequences: bool = False 41 | pack_sequences: bool = True 42 | chunk_size: int = 2048 43 | 44 | # optim: Optional[str] = field( 45 | # default="paged_adamw_32bit", 46 | # metadata={"help": "The optimizer to use."}, 47 | # ) 48 | # lr_scheduler_type: str = field( 49 | # default="constant", 50 | # metadata={"help": "Learning rate schedule. Constant a bit better than cosine, and has advantage for analysis"}, 51 | # ) 52 | # max_steps: int = field(default=10000, metadata={"help": "How many optimizer update steps to take"}) 53 | # warmup_ratio 54 | 55 | # save_steps: int = field(default=100, metadata={"help": "Save checkpoint every X updates steps."}) 56 | # logging_steps: int = field(default=10, metadata={"help": "Log every X updates steps."}) 57 | # eval_steps: int = field(default=None, metadata={"help": "Run evaluation every X steps"}) 58 | # evaluation_strateg 59 | -------------------------------------------------------------------------------- /llama_recipes/ft_datasets/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. 3 | 4 | from .grammar_dataset import get_dataset as get_grammar_dataset 5 | from .alpaca_dataset import InstructionDataset as get_alpaca_dataset 6 | from .samsum_dataset import get_preprocessed_samsum as get_samsum_dataset 7 | from .completion_dataset import get_completion_dataset as get_completion_dataset 8 | -------------------------------------------------------------------------------- /llama_recipes/ft_datasets/alpaca_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. 3 | 4 | # For dataset details visit: https://crfm.stanford.edu/2023/03/13/alpaca.html 5 | 6 | import copy 7 | import json 8 | import torch 9 | 10 | from torch.utils.data import Dataset 11 | 12 | PROMPT_DICT = { 13 | "prompt_input": ( 14 | "Below is an instruction that describes a task, paired with an input that provides further context. 
" 15 | "Write a response that appropriately completes the request.\n\n" 16 | "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:" 17 | ), 18 | "prompt_no_input": ( 19 | "Below is an instruction that describes a task. " 20 | "Write a response that appropriately completes the request.\n\n" 21 | "### Instruction:\n{instruction}\n\n### Response:" 22 | ), 23 | } 24 | 25 | 26 | class InstructionDataset(Dataset): 27 | def __init__(self, dataset_config, tokenizer, partition="train", max_words=30): 28 | self.ann = json.load(open(dataset_config.data_path)) 29 | if partition == "train": 30 | self.ann = self.ann 31 | else: 32 | self.ann = self.ann[:200] 33 | 34 | self.max_words = max_words 35 | # tokenizer = Tokenizer(model_path=model_path + "./tokenizer.model") 36 | self.tokenizer = tokenizer 37 | # self.tokenizer1 = tokenizer 38 | 39 | def __len__(self): 40 | return len(self.ann) 41 | 42 | def __getitem__(self, index): 43 | ann = self.ann[index] 44 | if ann.get("input", "") == "": 45 | prompt = PROMPT_DICT["prompt_no_input"].format_map(ann) 46 | else: 47 | prompt = PROMPT_DICT["prompt_input"].format_map(ann) 48 | example = prompt + ann["output"] 49 | prompt = torch.tensor(self.tokenizer.encode(prompt), dtype=torch.int64) 50 | example = self.tokenizer.encode(example) 51 | example.append(self.tokenizer.eos_token_id) 52 | example = torch.tensor(example, dtype=torch.int64) 53 | padding = self.max_words - example.shape[0] 54 | if padding > 0: 55 | example = torch.cat((example, torch.zeros(padding, dtype=torch.int64) - 1)) 56 | elif padding < 0: 57 | example = example[: self.max_words] 58 | labels = copy.deepcopy(example) 59 | labels[: len(prompt)] = -1 60 | example_mask = example.ge(0) 61 | label_mask = labels.ge(0) 62 | example[~example_mask] = 0 63 | labels[~label_mask] = 0 64 | example_mask = example_mask.float() 65 | label_mask = label_mask.float() 66 | 67 | return { 68 | "input_ids": example, 69 | "labels": labels, 70 | "attention_mask": example_mask, 71 | } 72 | -------------------------------------------------------------------------------- /llama_recipes/ft_datasets/completion_dataset.py: -------------------------------------------------------------------------------- 1 | from .utils import Concatenator 2 | import json 3 | from datasets import Dataset 4 | 5 | 6 | def load_data( 7 | dataset_config, 8 | split, 9 | ): 10 | data_path = dataset_config.data_path 11 | num_validation_samples = int(dataset_config.num_validation_samples) 12 | run_validation = dataset_config.run_validation 13 | validation_data_path = dataset_config.validation_data_path 14 | 15 | def _load_data(path): 16 | data = [] 17 | with open(path, "r") as file: 18 | for line in file: 19 | data.append(json.loads(line)) 20 | 21 | dataset = Dataset.from_dict( 22 | {key: [item[key] for item in data] for key in data[0]}, 23 | ) 24 | 25 | return dataset 26 | 27 | if not validation_data_path: 28 | dataset = _load_data(data_path) 29 | 30 | if run_validation and split == "train": 31 | print( 32 | f"Selecting observations 0 through {len(dataset)-num_validation_samples} from data for training..." 33 | ) 34 | end_index = len(dataset) - num_validation_samples 35 | indices = list(range(end_index)) 36 | dataset = dataset.select(indices) 37 | 38 | elif run_validation and split == "val": 39 | print( 40 | f"Selecting observations {len(dataset)-num_validation_samples} through {len(dataset)} from data for validation..." 
41 | ) 42 | start_index = len(dataset) - num_validation_samples 43 | indices = list(range(start_index, len(dataset))) 44 | dataset = dataset.select(indices) 45 | else: 46 | if split == "train": 47 | dataset = _load_data(data_path) 48 | elif split == "val": 49 | dataset = _load_data(validation_data_path) 50 | 51 | return dataset 52 | 53 | 54 | def format_data(dataset, tokenizer, config=None): 55 | def apply_text_template(sample): 56 | return {"text": sample["text"] + tokenizer.eos_token} 57 | 58 | def apply_prompt_template(sample): 59 | return { 60 | "text": sample["prompt"] + "\n" + sample["completion"] + tokenizer.eos_token 61 | } 62 | 63 | # Assume - all "text" or all "prompt/completion" 64 | if "text" in dataset[0]: 65 | dataset = dataset.map( 66 | apply_text_template, remove_columns=list(dataset.features) 67 | ) 68 | elif "prompt" in dataset[0] and "completion" in dataset[0]: 69 | dataset = dataset.map( 70 | apply_prompt_template, remove_columns=list(dataset.features) 71 | ) 72 | else: 73 | raise Exception( 74 | "Dataset did not contain `text` or `prompt` and `completion` inputs. Example row:", 75 | dataset[0], 76 | ) 77 | 78 | return dataset 79 | 80 | 81 | def tokenize_data(dataset, tokenizer, config=None): 82 | try: 83 | max_length = config.max_seq_length 84 | except: 85 | max_length = tokenizer.model_max_length 86 | 87 | dataset = dataset.map( 88 | lambda sample: tokenizer( 89 | sample["text"], max_length=max_length, truncation=True 90 | ), 91 | batched=True, 92 | remove_columns=list(dataset.features), 93 | ).map(lambda sample: {"labels": sample["input_ids"]}, batched=True) 94 | 95 | if config.pack_sequences: 96 | dataset = dataset.map( 97 | Concatenator( 98 | chunk_size=config.chunk_size, 99 | wrap_packed_sequences=config.wrap_packed_sequences, 100 | ), 101 | batched=True, 102 | ) 103 | 104 | return dataset 105 | 106 | 107 | def get_completion_dataset(config: str, tokenizer, split: str = "train"): 108 | dataset = load_data(config, split) 109 | dataset = format_data(dataset, tokenizer, config) 110 | dataset = tokenize_data(dataset, tokenizer, config) 111 | 112 | return dataset 113 | -------------------------------------------------------------------------------- /llama_recipes/ft_datasets/grammar_dataset/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. 3 | 4 | from .grammar_dataset import get_dataset 5 | -------------------------------------------------------------------------------- /llama_recipes/ft_datasets/grammar_dataset/grammar_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. 
3 | 4 | # For dataset details visit: https://huggingface.co/datasets/jfleg 5 | # For download and preparation see: recipes/ft_datasets/grammar_dataset/grammar_dataset_process.ipynb 6 | 7 | 8 | 9 | from torch.utils.data import Dataset 10 | 11 | from datasets import load_dataset 12 | from pathlib import Path 13 | 14 | from ..utils import ConcatDataset 15 | 16 | 17 | class grammar(Dataset): 18 | def __init__( 19 | self, 20 | tokenizer, 21 | csv_name=None, 22 | ): 23 | try: 24 | self.dataset = load_dataset( 25 | "csv", 26 | data_files={"train": [csv_name]}, # "eval": "grammar_validation.csv"}, 27 | delimiter=",", 28 | ) 29 | except Exception as e: 30 | print( 31 | "Loading of grammar dataset failed! Please see recipes/ft_datasets/grammar_dataset/grammar_dataset_process.ipynb for details on how to download the dataset." 32 | ) 33 | raise e 34 | 35 | # self.dataset = load_dataset("wikihow", "all", data_dir="data/", split=type_path) 36 | # if num_samples: 37 | # self.dataset = self.dataset.select(list(range(0, num_samples))) 38 | self.tokenizer = tokenizer 39 | self.print_text = False # print_text 40 | 41 | def __len__(self): 42 | return self.dataset["train"].shape[0] 43 | 44 | def convert_to_features(self, example_batch): 45 | # Create prompt and tokenize contexts and questions 46 | 47 | if self.print_text: 48 | print("Input Text: ", self.clean_text(example_batch["text"])) 49 | 50 | input_ = example_batch["input"] 51 | target_ = example_batch["target"] 52 | 53 | prompt = ( 54 | f"Correct this to standard English: {input_}\n---\nCorrected: {target_}" 55 | ) 56 | sample = self.tokenizer(prompt) 57 | 58 | return sample 59 | 60 | def __getitem__(self, index): 61 | sample = self.convert_to_features(self.dataset["train"][index]) 62 | source_ids = sample["input_ids"] 63 | 64 | src_mask = sample["attention_mask"] 65 | 66 | return { 67 | "input_ids": source_ids, 68 | "attention_mask": src_mask, 69 | "labels": source_ids.copy(), 70 | } 71 | 72 | 73 | def get_dataset(dataset_config, tokenizer, csv_name=None): 74 | """cover function for handling loading the working dataset""" 75 | """dataset loading""" 76 | if csv_name is None: 77 | currPath = Path.cwd() / "datasets_grammar" / "grammar_train.csv" 78 | print(f"Loading dataset {currPath}") 79 | csv_name = str(currPath) 80 | dataset = grammar( 81 | tokenizer=tokenizer, 82 | csv_name=csv_name, 83 | ) 84 | 85 | return ConcatDataset(dataset, chunk_size=dataset_config.input_length) 86 | -------------------------------------------------------------------------------- /llama_recipes/ft_datasets/samsum_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. 
3 | 4 | # For dataset details visit: https://huggingface.co/datasets/samsum 5 | 6 | import datasets 7 | from .utils import Concatenator 8 | 9 | 10 | def get_preprocessed_samsum(dataset_config, tokenizer, split): 11 | dataset = datasets.load_dataset("samsum", split=split) 12 | 13 | prompt = ( 14 | "Summarize this dialog:\n{dialog}\n---\nSummary:\n{summary}{eos_token}" 15 | ) 16 | 17 | def apply_prompt_template(sample): 18 | return { 19 | "text": prompt.format( 20 | dialog=sample["dialogue"], 21 | summary=sample["summary"], 22 | eos_token=tokenizer.eos_token, 23 | ) 24 | } 25 | 26 | dataset = dataset.map(apply_prompt_template, remove_columns=list(dataset.features)) 27 | 28 | dataset = dataset.map( 29 | lambda sample: tokenizer(sample["text"]), 30 | batched=True, 31 | remove_columns=list(dataset.features), 32 | ).map(Concatenator(), batched=True) 33 | return dataset 34 | -------------------------------------------------------------------------------- /llama_recipes/ft_datasets/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. 3 | 4 | from tqdm import tqdm 5 | from itertools import chain 6 | from torch.utils.data import Dataset 7 | 8 | 9 | class Concatenator(object): 10 | def __init__(self, chunk_size=2048, wrap_packed_sequences=False): 11 | self.chunk_size = chunk_size 12 | self.residual = {"input_ids": [], "attention_mask": []} 13 | self.wrap_packed_sequences = wrap_packed_sequences 14 | 15 | def _wrap_concat(self, batch): 16 | """ 17 | When we pack samples into a single sequence, it's possible that the final 18 | sample's sequence will exceed `chunk_size`. In this case, the `_wrap_concat` 19 | method will wrap the final sample around to the beginning of the next sequence. 20 | This breaks the sample into two parts and may introduce samples that violate prompt formats. 21 | However, it allows us to strictly enforce chunk size. 22 | """ 23 | concatenated_samples = { 24 | k: v + list(chain(*batch[k])) for k, v in self.residual.items() 25 | } 26 | 27 | total_length = len(concatenated_samples[list(concatenated_samples.keys())[0]]) 28 | 29 | if total_length >= self.chunk_size: 30 | chunk_num = total_length // self.chunk_size 31 | result = { 32 | k: [ 33 | v[i : i + self.chunk_size] 34 | for i in range(0, chunk_num * self.chunk_size, self.chunk_size) 35 | ] 36 | for k, v in concatenated_samples.items() 37 | } 38 | self.residual = { 39 | k: v[(chunk_num * self.chunk_size) :] 40 | for k, v in concatenated_samples.items() 41 | } 42 | else: 43 | result = concatenated_samples 44 | self.residual = {k: [] for k in concatenated_samples.keys()} 45 | 46 | # result["labels"] = result["input_ids"].copy() 47 | 48 | return result 49 | 50 | def _concat(self, batch): 51 | """ 52 | When we pack samples into a single sequence, it's possible that the final 53 | sample's sequence will exceed `chunk_size`. In this case, the `_concat` method 54 | will simply promote the final sample to the next sequence. This may introduce 55 | sequences with variable lengths, e.g. some that are below `chunk_size`, 56 | but it allows us to pack sequences while strictly respecting formatting. 
57 | """ 58 | 59 | # Initialize current sequences from residual or empty if none exists 60 | keys = batch.keys() 61 | current_sequences = {key: self.residual.get(key, []) for key in keys} 62 | 63 | # # We'll store packed sequences in results 64 | results = {key: [] for key in keys} 65 | 66 | # len_of_new_seq = len(batch[list(batch.keys())[0]]) 67 | # len_of_current_seq = len(current_sequences[list(current_sequences.keys())[0]]) 68 | 69 | num_samples = len(batch[next(iter(keys))]) 70 | 71 | for idx in range(num_samples): 72 | # Check if adding next sample will exceed the chunk size for any key 73 | len_current_sequences = len(current_sequences[list(keys)[0]]) 74 | len_batch_sequence = len(batch[list(keys)[0]][idx]) 75 | 76 | will_exceed = len_current_sequences + len_batch_sequence > self.chunk_size 77 | 78 | if will_exceed: 79 | if len_current_sequences > 0: 80 | for key in keys: 81 | results[key].append(current_sequences[key]) 82 | current_sequences[key] = [] 83 | 84 | # After appending to results, extend current_sequences with the sample for all keys 85 | for key in keys: 86 | current_sequences[key].extend(batch[key][idx]) 87 | else: 88 | for key in keys: 89 | current_sequences[key].extend(batch[key][idx]) 90 | 91 | # Store unappended sequences as residual 92 | self.residual = current_sequences 93 | 94 | # results["labels"] = results["input_ids"].copy() 95 | 96 | return results 97 | 98 | def __call__(self, batch): 99 | if self.wrap_packed_sequences: 100 | return self._wrap_concat(batch) 101 | else: 102 | return self._concat(batch) 103 | 104 | 105 | class ConcatDataset(Dataset): 106 | def __init__(self, dataset, chunk_size=4096): 107 | self.dataset = dataset 108 | self.chunk_size = chunk_size 109 | 110 | self.samples = [] 111 | 112 | buffer = { 113 | "input_ids": [], 114 | "attention_mask": [], 115 | "labels": [], 116 | } 117 | 118 | for sample in tqdm(self.dataset, desc="Preprocessing dataset"): 119 | buffer = {k: v + sample[k] for k, v in buffer.items()} 120 | 121 | while len(next(iter(buffer.values()))) > self.chunk_size: 122 | self.samples.append( 123 | {k: v[: self.chunk_size] for k, v in buffer.items()} 124 | ) 125 | buffer = {k: v[self.chunk_size :] for k, v in buffer.items()} 126 | 127 | def __getitem__(self, idx): 128 | return self.samples[idx] 129 | 130 | def __len__(self): 131 | return len(self.samples) 132 | -------------------------------------------------------------------------------- /llama_recipes/model_checkpointing/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. 3 | 4 | from .checkpoint_handler import ( 5 | load_model_checkpoint, 6 | save_model_checkpoint, 7 | load_optimizer_checkpoint, 8 | save_optimizer_checkpoint, 9 | save_model_and_optimizer_sharded, 10 | load_model_sharded, 11 | load_sharded_model_single_gpu, 12 | ) 13 | -------------------------------------------------------------------------------- /llama_recipes/multi_node.slurm: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # This software may be used and distributed according to the terms of the GNU General Public License version 3. 
3 | 4 | 5 | #!/bin/bash 6 | 7 | #SBATCH --job-name=Nano-2d-trainer-20b-8nodes 8 | 9 | #SBATCH --ntasks=2 10 | #SBATCH --nodes=2 11 | #SBATCH --gpus-per-task=4 12 | #SBATCH --partition=train 13 | nodes=( $( scontrol show hostnames $SLURM_JOB_NODELIST ) ) 14 | nodes_array=($nodes) 15 | head_node=${nodes_array[0]} 16 | head_node_ip=$(srun --nodes=1 --ntasks=1 -w "$head_node" hostname --ip-address) 17 | # Enable for A100 18 | export FI_PROVIDER="efa" 19 | 20 | echo Node IP: $head_node_ip 21 | export LOGLEVEL=INFO 22 | # debugging flags (optional) 23 | export NCCL_DEBUG=WARN 24 | export NCCL_DEBUG_SUBSYS=WARN 25 | export PYTHONFAULTHANDLER=1 26 | export LD_LIBRARY_PATH=/opt/amazon/efa/lib:$LD_LIBRARY_PATH 27 | export LD_LIBRARY_PATH=/usr/local/lib/:$LD_LIBRARY_PATH 28 | export CUDA_LAUNCH_BLOCKING=0 29 | 30 | # on your cluster you might need these: 31 | # set the network interface 32 | export NCCL_SOCKET_IFNAME="ens" 33 | export FI_EFA_USE_DEVICE_RDMA=1 34 | 35 | srun torchrun --nproc_per_node 4 --rdzv_id $RANDOM --rdzv_backend c10d --rdzv_endpoint $head_node_ip:29500 llama_finetuning.py --enable_fsdp --use_peft --peft_method lora 36 | 37 | -------------------------------------------------------------------------------- /llama_recipes/policies/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. 3 | 4 | from .mixed_precision import * 5 | from .wrapping import * 6 | from .activation_checkpointing_functions import apply_fsdp_checkpointing 7 | from .anyprecision_optimizer import AnyPrecisionAdamW 8 | -------------------------------------------------------------------------------- /llama_recipes/policies/activation_checkpointing_functions.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. 3 | 4 | from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( 5 | checkpoint_wrapper, 6 | CheckpointImpl, 7 | apply_activation_checkpointing, 8 | ) 9 | 10 | from transformers.models.llama.modeling_llama import LlamaDecoderLayer 11 | from functools import partial 12 | 13 | non_reentrant_wrapper = partial( 14 | checkpoint_wrapper, 15 | checkpoint_impl=CheckpointImpl.NO_REENTRANT, 16 | ) 17 | 18 | check_fn = lambda submodule: isinstance(submodule, LlamaDecoderLayer) 19 | 20 | 21 | def apply_fsdp_checkpointing(model): 22 | """apply activation checkpointing to model 23 | returns None as model is updated directly 24 | """ 25 | print("--> applying fsdp activation checkpointing...") 26 | 27 | apply_activation_checkpointing( 28 | model, checkpoint_wrapper_fn=non_reentrant_wrapper, check_fn=check_fn 29 | ) 30 | -------------------------------------------------------------------------------- /llama_recipes/policies/anyprecision_optimizer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. 3 | 4 | # AnyPrecisionAdamW: a flexible precision AdamW optimizer 5 | # with optional Kahan summation for high precision weight updates. 
6 | # Allows direct control over momentum, variance and auxiliary compensation 7 | # buffer dtypes. 8 | # Optional Kahan summation is used to offset precision reduction for 9 | # the weight updates. This allows full training in BFloat16 (equal or 10 | # better than FP32 results in many cases) due to high precision weight upates. 11 | 12 | import torch 13 | from torch.optim.optimizer import Optimizer 14 | 15 | 16 | class AnyPrecisionAdamW(Optimizer): 17 | def __init__( 18 | self, 19 | params, 20 | lr=1e-3, 21 | betas=(0.9, 0.999), 22 | eps=1e-8, 23 | weight_decay=0.0, 24 | use_kahan_summation=False, 25 | momentum_dtype=torch.bfloat16, 26 | variance_dtype=torch.bfloat16, 27 | compensation_buffer_dtype=torch.bfloat16, 28 | ): 29 | """ 30 | Args: 31 | params (iterable): iterable of parameters to optimize or dicts defining 32 | parameter groups 33 | lr (float, optional): learning rate (default: 1e-3) 34 | betas (Tuple[float, float], optional): coefficients used for computing 35 | running averages of gradient and its square (default: (0.9, 0.999)) 36 | eps (float, optional): term added to the denominator to improve 37 | numerical stability (default: 1e-8) 38 | weight_decay (float, optional): weight decay coefficient (default: 1e-2) 39 | 40 | # Any Precision specific 41 | use_kahan_summation = creates auxiliary buffer to ensure high precision 42 | model param updates (default: False) 43 | momentum_dtype = dtype for momentum (default: BFloat32) 44 | variance_dtype = dtype for uncentered variance (default: BFloat16) 45 | compensation_buffer_dtype = dtype for Kahan summation 46 | buffer (default: BFloat16) 47 | 48 | # Usage 49 | This optimizer implements optimizer states, and Kahan summation 50 | for high precision updates, all in user controlled dtypes. 51 | Defaults are variance in BF16, Momentum in FP32. 52 | This can be run in FSDP mixed precision, amp, or full precision, 53 | depending on what training pipeline you wish to work with. 54 | 55 | Setting to use_kahan_summation = False, and changing momentum and 56 | variance dtypes to FP32, reverts this to a standard AdamW optimizer. 57 | 58 | """ 59 | defaults = dict( 60 | lr=lr, 61 | betas=betas, 62 | eps=eps, 63 | weight_decay=weight_decay, 64 | use_kahan_summation=use_kahan_summation, 65 | momentum_dtype=momentum_dtype, 66 | variance_dtype=variance_dtype, 67 | compensation_buffer_dtype=compensation_buffer_dtype, 68 | ) 69 | 70 | super().__init__(params, defaults) 71 | 72 | @torch.no_grad() 73 | def step(self, closure=None): 74 | """Performs a single optimization step. 75 | Args: 76 | closure (callable, optional): A closure that reevaluates the model 77 | and returns the loss. 78 | """ 79 | 80 | if closure is not None: 81 | with torch.enable_grad(): 82 | # to fix linter, we do not keep the returned loss for use atm. 
83 | closure() 84 | 85 | for group in self.param_groups: 86 | beta1, beta2 = group["betas"] 87 | lr = group["lr"] 88 | weight_decay = group["weight_decay"] 89 | eps = group["eps"] 90 | use_kahan_summation = group["use_kahan_summation"] 91 | 92 | momentum_dtype = group["momentum_dtype"] 93 | variance_dtype = group["variance_dtype"] 94 | compensation_buffer_dtype = group["compensation_buffer_dtype"] 95 | 96 | for p in group["params"]: 97 | if p.grad is None: 98 | continue 99 | 100 | if p.grad.is_sparse: 101 | raise RuntimeError( 102 | "AnyPrecisionAdamW does not support sparse gradients" 103 | ) 104 | 105 | state = self.state[p] 106 | 107 | # State initialization 108 | if len(state) == 0: 109 | state["step"] = torch.tensor(0.0) 110 | 111 | # momentum - EMA of gradient values 112 | state["exp_avg"] = torch.zeros_like( 113 | p, 114 | dtype=momentum_dtype, 115 | ) 116 | 117 | # variance uncentered - EMA of squared gradient values 118 | state["exp_avg_sq"] = torch.zeros_like( 119 | p, 120 | dtype=variance_dtype, 121 | ) 122 | 123 | # optional Kahan summation - accumulated error tracker 124 | if use_kahan_summation: 125 | state["compensation"] = torch.zeros_like( 126 | p, 127 | dtype=compensation_buffer_dtype, 128 | ) 129 | 130 | # main processing ------------------------- 131 | 132 | # update the steps for each param group update 133 | state["step"] += 1 134 | step = state["step"] 135 | 136 | exp_avg = state["exp_avg"] 137 | exp_avg_sq = state["exp_avg_sq"] 138 | 139 | grad = p.grad 140 | 141 | # weight decay, AdamW style 142 | if weight_decay: 143 | p.data.mul_(1 - lr * weight_decay) 144 | 145 | # update momentum 146 | exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) 147 | 148 | # update uncentered variance 149 | exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) 150 | 151 | # adjust using bias1 152 | bias_correction1 = 1 - beta1**step 153 | 154 | step_size = lr / bias_correction1 155 | 156 | # adjust using bias2 157 | denom_correction = (1 - beta2**step) ** 0.5 # avoids math import 158 | 159 | centered_variance = (exp_avg_sq.sqrt() / denom_correction).add_( 160 | eps, alpha=1 161 | ) 162 | 163 | # lr update to compensation 164 | if use_kahan_summation: 165 | compensation = state["compensation"] 166 | 167 | compensation.addcdiv_(exp_avg, centered_variance, value=-step_size) 168 | 169 | # update weights with compensation (Kahan summation) 170 | # save error back to compensation for next iteration 171 | temp_buffer = p.detach().clone() 172 | p.data.add_(compensation) 173 | compensation.add_(temp_buffer.sub_(p.data)) 174 | 175 | else: 176 | # usual AdamW updates 177 | p.data.addcdiv_(exp_avg, centered_variance, value=-step_size) 178 | -------------------------------------------------------------------------------- /llama_recipes/policies/mixed_precision.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. 3 | 4 | import torch 5 | 6 | from torch.distributed.fsdp import ( 7 | # FullyShardedDataParallel as FSDP, 8 | # CPUOffload, 9 | MixedPrecision, 10 | # BackwardPrefetch, 11 | # ShardingStrategy, 12 | ) 13 | 14 | # requires grad scaler in main loop 15 | fpSixteen = MixedPrecision( 16 | param_dtype=torch.float16, 17 | # Gradient communication precision. 18 | reduce_dtype=torch.float16, 19 | # Buffer precision. 
20 | buffer_dtype=torch.float16, 21 | ) 22 | 23 | bfSixteen = MixedPrecision( 24 | param_dtype=torch.bfloat16, 25 | # Gradient communication precision. 26 | reduce_dtype=torch.bfloat16, 27 | # Buffer precision. 28 | buffer_dtype=torch.bfloat16, 29 | cast_forward_inputs=True, 30 | ) 31 | 32 | bfSixteen_mixed = MixedPrecision( 33 | param_dtype=torch.float32, 34 | reduce_dtype=torch.bfloat16, 35 | buffer_dtype=torch.bfloat16, 36 | ) 37 | 38 | fp32_policy = MixedPrecision( 39 | param_dtype=torch.float32, 40 | reduce_dtype=torch.float32, 41 | buffer_dtype=torch.float32, 42 | ) 43 | -------------------------------------------------------------------------------- /llama_recipes/policies/wrapping.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. 3 | 4 | 5 | from transformers.models.llama.modeling_llama import LlamaDecoderLayer 6 | 7 | from torch.distributed.fsdp.wrap import ( 8 | transformer_auto_wrap_policy, 9 | size_based_auto_wrap_policy, 10 | ) 11 | 12 | import functools 13 | 14 | 15 | def get_size_policy(min_params=1e8): 16 | num_wrap_policy = functools.partial( 17 | size_based_auto_wrap_policy, min_num_params=min_params 18 | ) 19 | return num_wrap_policy 20 | 21 | 22 | def get_llama_wrapper(): 23 | """we register our main layer class and use the fsdp transformer wrapping policy 24 | ensures embedding layers are in the root fsdp unit for shared access and that fsdp units map to transformer layers 25 | """ 26 | # ==== use new transformer wrapper 27 | 28 | llama_auto_wrap_policy = functools.partial( 29 | transformer_auto_wrap_policy, 30 | transformer_layer_cls={ 31 | LlamaDecoderLayer, 32 | }, 33 | ) 34 | 35 | return llama_auto_wrap_policy 36 | -------------------------------------------------------------------------------- /llama_recipes/requirements.txt: -------------------------------------------------------------------------------- 1 | -f https://download.pytorch.org/whl/torch_stable.html 2 | torch==2.0.1+cu118 3 | accelerate 4 | appdirs 5 | loralib 6 | bitsandbytes==0.39.1 7 | black 8 | black[jupyter] 9 | datasets 10 | fire 11 | git+https://github.com/huggingface/peft.git 12 | transformers>=4.31.0 13 | sentencepiece 14 | py7zr 15 | scipy 16 | 17 | -------------------------------------------------------------------------------- /llama_recipes/scripts/markdown_link_check_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "retryOn429": true, 3 | "retryCount": 5, 4 | "fallbackRetryDelay": "10s", 5 | "httpHeaders": [ 6 | { 7 | "urls": [ 8 | "https://docs.github.com/", 9 | "https://help.github.com/" 10 | ], 11 | "headers": { 12 | "Accept-Encoding": "zstd, br, gzip, deflate" 13 | } 14 | } 15 | ], 16 | "ignorePatterns": [ 17 | { 18 | "pattern": "^http(s)?://127.0.0.1.*" 19 | }, 20 | { 21 | "pattern": "^http(s)?://localhost.*" 22 | }, 23 | { 24 | "pattern": "https://www.intel.com/content/www/us/en/developer/articles/news/llama2.html" 25 | } 26 | ] 27 | } 28 | -------------------------------------------------------------------------------- /llama_recipes/scripts/spellcheck.sh: -------------------------------------------------------------------------------- 1 | 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. 
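# Usage sketch (assumed invocation from the llama_recipes directory; the file names
# below are placeholders):
#   bash scripts/spellcheck.sh                          # spellchecks every *.md file found under the current tree
#   bash scripts/spellcheck.sh README.md docs/FAQ.md    # spellchecks only the listed files
# Note: the script installs aspell via apt-get, so it expects a Debian/Ubuntu host with sudo.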
4 | # Source: https://github.com/pytorch/torchx/blob/main/scripts/spellcheck.sh 5 | set -ex 6 | sudo apt-get install aspell 7 | 8 | if [[ -z "$@" ]]; then 9 | sources=$(find -name '*.md') 10 | else 11 | sources=$@ 12 | fi 13 | 14 | sources_arg="" 15 | for src in $sources; do 16 | sources_arg="${sources_arg} -S $src" 17 | done 18 | 19 | if [ ! "$sources_arg" ]; then 20 | echo "No files to spellcheck" 21 | else 22 | pyspelling -c scripts/spellcheck_conf/spellcheck.yaml --name Markdown $sources_arg 23 | fi 24 | -------------------------------------------------------------------------------- /llama_recipes/scripts/spellcheck_conf/spellcheck.yaml: -------------------------------------------------------------------------------- 1 | matrix: 2 | - name: Markdown 3 | apsell: 4 | lang: en 5 | d: en_US 6 | dictionary: 7 | wordlists: 8 | - scripts/spellcheck_conf/wordlist.txt 9 | output: scripts/spellcheck_conf/wordlist.dic 10 | encoding: utf-8 11 | pipeline: 12 | - pyspelling.filters.context: 13 | context_visible_first: true 14 | delimiters: 15 | - open: '(?s)^ *(?P`{3,})[a-z0-9]*?$' 16 | close: '^(?P=open)$' 17 | - open: '' 18 | content: 'https?://[-a-zA-Z0-9.]+?\.[a-z]{2,6}[-?=&%.0-9a-zA-Z/_#]*' 19 | close: '' 20 | - pyspelling.filters.markdown: 21 | markdown_extensions: 22 | - markdown.extensions.extra: 23 | -------------------------------------------------------------------------------- /llama_recipes/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. 3 | 4 | from .memory_utils import MemoryTrace 5 | from .dataset_utils import * 6 | from .fsdp_utils import fsdp_auto_wrap_policy 7 | from .train_utils import * 8 | -------------------------------------------------------------------------------- /llama_recipes/utils/config_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. 3 | 4 | import inspect 5 | from dataclasses import fields 6 | from peft import ( 7 | LoraConfig, 8 | AdaptionPromptConfig, 9 | PrefixTuningConfig, 10 | ) 11 | 12 | from transformers import BitsAndBytesConfig 13 | 14 | import configs.datasets as datasets 15 | from configs import ( 16 | lora_config, 17 | llama_adapter_config, 18 | prefix_config, 19 | train_config, 20 | qlora_config, 21 | bitsandbytes_config, 22 | ) 23 | from .dataset_utils import DATASET_PREPROC 24 | 25 | 26 | def update_config(config, **kwargs): 27 | if isinstance(config, (tuple, list)): 28 | for c in config: 29 | update_config(c, **kwargs) 30 | else: 31 | for k, v in kwargs.items(): 32 | if hasattr(config, k): 33 | setattr(config, k, v) 34 | elif "." 
in k: 35 | # allow --some_config.some_param=True 36 | config_name, param_name = k.split(".") 37 | if type(config).__name__ == config_name: 38 | if hasattr(config, param_name): 39 | setattr(config, param_name, v) 40 | else: 41 | # In case of a specialized config we can warn the user 42 | print(f"Warning: {config_name} does not accept parameter: {k}") 43 | elif isinstance(config, train_config): 44 | print(f"Warning: unknown parameter {k}") 45 | 46 | 47 | def generate_peft_config(peft_method, kwargs): 48 | # Config mapping for train_config.peft_method to its corresponding config class 49 | config_mapping = { 50 | "lora": lora_config, 51 | "llama_adapter": llama_adapter_config, 52 | "prefix": prefix_config, 53 | "bitsandbytes_config": bitsandbytes_config, 54 | "qlora": qlora_config, 55 | # Add other mappings as needed 56 | } 57 | 58 | # Mapping from config class to its corresponding PEFT config 59 | peft_config_mapping = { 60 | lora_config: LoraConfig, 61 | llama_adapter_config: AdaptionPromptConfig, 62 | prefix_config: PrefixTuningConfig, 63 | bitsandbytes_config: BitsAndBytesConfig, 64 | qlora_config: LoraConfig, 65 | # Add other mappings as needed 66 | } 67 | 68 | # Make sure the requested peft_method has a known config 69 | assert peft_method in config_mapping.keys(), f"Peft config not found: {peft_method}" 70 | 71 | # Fetch the configuration class for the requested peft_method 72 | config = config_mapping[peft_method] 73 | update_config(config, **kwargs) 74 | params = {k.name: getattr(config, k.name) for k in fields(config)} 75 | 76 | # Fetch the PEFT config class that corresponds to the configuration class 77 | peft_config_class = peft_config_mapping[config] 78 | peft_config = peft_config_class(**params) 79 | 80 | return peft_config 81 | 82 | 83 | # def generate_peft_config(train_config, kwargs): 84 | # configs = (lora_config, llama_adapter_config, prefix_config, qlora_config) 85 | # peft_configs = (LoraConfig, AdaptionPromptConfig, PrefixTuningConfig) 86 | # names = tuple(c.__name__.rstrip("_config") for c in configs) 87 | 88 | # assert train_config.peft_method in names, f"Peft config not found: {train_config.peft_method}" 89 | 90 | # config = configs[names.index(train_config.peft_method)] 91 | # update_config(config, **kwargs) 92 | # params = {k.name: getattr(config, k.name) for k in fields(config)} 93 | # peft_config = peft_configs[names.index(train_config.peft_method)](**params) 94 | 95 | # return peft_config 96 | 97 | 98 | def generate_dataset_config(train_config, kwargs): 99 | names = tuple(DATASET_PREPROC.keys()) 100 | 101 | assert train_config.dataset in names, f"Unknown dataset: {train_config.dataset}" 102 | 103 | dataset_config = {k: v for k, v in inspect.getmembers(datasets)}[ 104 | train_config.dataset 105 | ] 106 | update_config(dataset_config, **kwargs) 107 | 108 | return dataset_config 109 | -------------------------------------------------------------------------------- /llama_recipes/utils/dataset_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
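# Illustrative use of this module (a sketch; `tokenizer` and `dataset_config` stand in
# for whatever the training entrypoint constructs, and the import path assumes the
# llama_recipes package layout):
#
#   from utils.dataset_utils import get_preprocessed_dataset
#   train_ds = get_preprocessed_dataset(tokenizer, dataset_config, split="train")
#   eval_ds = get_preprocessed_dataset(tokenizer, dataset_config, split="test")
#
# dataset_config.dataset must be a key of DATASET_PREPROC below; anything else
# raises NotImplementedError.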
3 | 4 | import torch 5 | 6 | from functools import partial 7 | 8 | 9 | from ft_datasets import ( 10 | get_grammar_dataset, 11 | get_alpaca_dataset, 12 | get_samsum_dataset, 13 | get_completion_dataset, 14 | ) 15 | 16 | 17 | DATASET_PREPROC = { 18 | "alpaca_dataset": partial(get_alpaca_dataset, max_words=224), 19 | "grammar_dataset": get_grammar_dataset, 20 | "samsum_dataset": get_samsum_dataset, 21 | "completion": get_completion_dataset, 22 | } 23 | 24 | 25 | def get_preprocessed_dataset( 26 | tokenizer, dataset_config, split: str = "train" 27 | ) -> torch.utils.data.Dataset: 28 | if dataset_config.dataset not in DATASET_PREPROC: 29 | raise NotImplementedError(f"{dataset_config.dataset} is not (yet) implemented") 30 | 31 | def get_split(): 32 | return ( 33 | dataset_config.train_split 34 | if split == "train" 35 | else dataset_config.test_split 36 | ) 37 | 38 | return DATASET_PREPROC[dataset_config.dataset]( 39 | dataset_config, 40 | tokenizer, 41 | get_split(), 42 | ) 43 | -------------------------------------------------------------------------------- /llama_recipes/utils/fsdp_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. 3 | 4 | 5 | def fsdp_auto_wrap_policy(model, transformer_layer_name): 6 | import functools 7 | 8 | from torch.distributed.fsdp.wrap import ( 9 | _or_policy, 10 | lambda_auto_wrap_policy, 11 | transformer_auto_wrap_policy, 12 | ) 13 | 14 | from peft.tuners import PrefixEncoder, PromptEmbedding, PromptEncoder 15 | 16 | def lambda_policy_fn(module): 17 | if ( 18 | len(list(module.named_children())) == 0 19 | and getattr(module, "weight", None) is not None 20 | and module.weight.requires_grad 21 | ): 22 | return True 23 | return False 24 | 25 | lambda_policy = functools.partial( 26 | lambda_auto_wrap_policy, lambda_fn=lambda_policy_fn 27 | ) 28 | transformer_wrap_policy = functools.partial( 29 | transformer_auto_wrap_policy, 30 | transformer_layer_cls=( 31 | PrefixEncoder, 32 | PromptEncoder, 33 | PromptEmbedding, 34 | transformer_layer_name, 35 | # FullyShardedDataParallelPlugin.get_module_class_from_name( 36 | # model, transformer_layer_name 37 | # ), 38 | ), 39 | ) 40 | 41 | auto_wrap_policy = functools.partial( 42 | _or_policy, policies=[lambda_policy, transformer_wrap_policy] 43 | ) 44 | return auto_wrap_policy 45 | -------------------------------------------------------------------------------- /llama_recipes/utils/memory_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. 
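# Usage sketch (illustrative; `run_training_epoch` is a stand-in for any CUDA workload):
#
#   with MemoryTrace() as memtrace:
#       run_training_epoch()
#   print(f"peak CUDA allocated: {memtrace.peak} GB, max reserved: {memtrace.max_reserved} GB")
#
# The attributes are populated in __exit__ below; byte2gb truncates values to whole GB.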
3 | import gc 4 | import threading 5 | 6 | import psutil 7 | import torch 8 | 9 | 10 | def byte2gb(x): 11 | return int(x / 2**30) 12 | 13 | 14 | # This context manager is used to track the peak memory usage of the process 15 | class MemoryTrace: 16 | def __enter__(self): 17 | gc.collect() 18 | torch.cuda.empty_cache() 19 | torch.cuda.reset_max_memory_allocated() # reset the peak gauge to zero 20 | self.begin = byte2gb(torch.cuda.memory_allocated()) 21 | self.process = psutil.Process() 22 | self.cpu_begin = byte2gb(self.cpu_mem_used()) 23 | self.peak_monitoring = True 24 | peak_monitor_thread = threading.Thread(target=self.peak_monitor_func) 25 | peak_monitor_thread.daemon = True 26 | peak_monitor_thread.start() 27 | return self 28 | 29 | def cpu_mem_used(self): 30 | """get resident set size memory for the current process""" 31 | return self.process.memory_info().rss 32 | 33 | def peak_monitor_func(self): 34 | self.cpu_peak = -1 35 | 36 | while True: 37 | self.cpu_peak = max(self.cpu_mem_used(), self.cpu_peak) 38 | 39 | # can't sleep or will not catch the peak right (this comment is here on purpose) 40 | # time.sleep(0.001) # 1msec 41 | 42 | if not self.peak_monitoring: 43 | break 44 | 45 | def __exit__(self, *exc): 46 | self.peak_monitoring = False 47 | 48 | gc.collect() 49 | torch.cuda.empty_cache() 50 | self.end = byte2gb(torch.cuda.memory_allocated()) 51 | self.peak = byte2gb(torch.cuda.max_memory_allocated()) 52 | cuda_info = torch.cuda.memory_stats() 53 | self.peak_active_gb = byte2gb(cuda_info["active_bytes.all.peak"]) 54 | self.cuda_malloc_retires = cuda_info.get("num_alloc_retries", 0) 55 | self.peak_active_gb = byte2gb(cuda_info["active_bytes.all.peak"]) 56 | self.m_cuda_ooms = cuda_info.get("num_ooms", 0) 57 | self.used = byte2gb(self.end - self.begin) 58 | self.peaked = byte2gb(self.peak - self.begin) 59 | self.max_reserved = byte2gb(torch.cuda.max_memory_reserved()) 60 | 61 | self.cpu_end = self.cpu_mem_used() 62 | self.cpu_used = byte2gb(self.cpu_end - self.cpu_begin) 63 | self.cpu_peaked = byte2gb(self.cpu_peak - self.cpu_begin) 64 | # print(f"delta used/peak {self.used:4d}/{self.peaked:4d}") 65 | -------------------------------------------------------------------------------- /model_templates/.dockerignore: -------------------------------------------------------------------------------- 1 | *pdf 2 | *docx 3 | flan-t5** 4 | checkpoints/** 5 | examples/** 6 | weights_13/** 7 | tmp/** 8 | **.jsonl 9 | unconverted-weights 10 | unconverted-weights/ 11 | weights 12 | weights/ 13 | llama_weights/ 14 | llama_weights 15 | */**/*.safetensors 16 | */**/*.tensors 17 | **/.git/lfs/objects/** 18 | *.tensors 19 | default_base_weights/ 20 | llama.tensors 21 | code 22 | tests 23 | **/*ipynb 24 | 25 | # generated by replicate/cog 26 | __pycache__ 27 | *.pyc 28 | *.pyo 29 | *.pyd 30 | .Python 31 | env 32 | pip-log.txt 33 | pip-delete-this-directory.txt 34 | .tox 35 | .coverage 36 | .coverage.* 37 | .cache 38 | nosetests.xml 39 | coverage.xml 40 | *.cover 41 | *.log 42 | .git 43 | **/.git 44 | .mypy_cache 45 | **/.mypy_cache 46 | .pytest_cache 47 | .hypothesis 48 | -------------------------------------------------------------------------------- /model_templates/config.py: -------------------------------------------------------------------------------- 1 | from dotenv import load_dotenv 2 | from src.utils import get_env_var_or_default 3 | 4 | load_dotenv() 5 | 6 | MODEL_NAME = 7 | # INFERENCE CONFIGURATION 8 | ####################################################################### 9 | # 
--------------------Notes-------------------------------------------- 10 | # We are trying our very best to no longer have different inference code paths 11 | # for trained and untrained weights :) 12 | # 13 | # INFERENCE CONFIGURATION 14 | # ------------------------------- 15 | # This section defines the general inference configuration, 16 | # which is used for both trained and untrained models. 17 | # ------------------------------- 18 | 19 | TOKENIZER_PATH = f"models/{MODEL_NAME}/model_artifacts/tokenizer" 20 | USE_SYSTEM_PROMPT = 21 | 22 | 23 | # ENGINE CONFIGURATION 24 | # ------------------------------- 25 | # Here we define the specific inference engine we intend to use for inference, and all appropriate kwargs. 26 | # ------------------------------- 27 | 28 | 29 | ENGINE = 30 | ENGINE_KWARGS = {} 31 | 32 | # DEFAULT INFERENCE CONFIGURATION 33 | # ------------------------------- 34 | # This section defines the default inference configuration, which may differ from 35 | # how we implement inference for a trained model. 36 | # ------------------------------- 37 | 38 | 39 | LOCAL_DEFAULT_INFERENCE_WEIGHTS_PATH = f"models/{MODEL_NAME}/model_artifacts/default_inference_weights" 40 | 41 | REMOTE_DEFAULT_INFERENCE_WEIGHTS_PATH = get_env_var_or_default( 42 | "REMOTE_DEFAULT_INFERENCE_WEIGHTS_PATH", 43 | "remote/path/to/your/weights/here", 44 | 45 | ) 46 | 47 | # N_SHARDS = 2 48 | # REMOTE_TRAINING_FILES_TO_DOWNLOAD = [ 49 | # f"model-{str(i+1).zfill(5)}-of-{str(N_SHARDS).zfill(5)}.safetensors" 50 | # for i in range(N_SHARDS) 51 | # ] 52 | 53 | REMOTE_DEFAULT_INFERENCE_FILES_TO_DOWNLOAD = #["gptq_model-4bit-128g.safetensors"] 54 | 55 | REMOTE_DEFAULT_INFERENCE_FILES_TO_DOWNLOAD += [ 56 | "config.json", 57 | "generation_config.json", 58 | "special_tokens_map.json", 59 | "tokenizer_config.json", 60 | "tokenizer.json", 61 | "tokenizer.model", 62 | "quantize_config.json", 63 | ] 64 | 65 | # TRAINED INFERENCE CONFIGURATION 66 | # ------------------------------- 67 | # This section defines the inference configuration for fine-tuned models 68 | # ------------------------------- 69 | 70 | LOCAL_TRAINING_WEIGHTS_PATH = f"models/{MODEL_NAME}/model_artifacts/training_weights" 71 | 72 | REMOTE_TRAINING_WEIGHTS_PATH = get_env_var_or_default( 73 | var_name="REMOTE_TRAINING_WEIGHTS_PATH", 74 | default_value="remote/path/to/your/weights/here" 75 | ) 76 | 77 | LOCAL_TRAINING_WEIGHTS_CONFIG_PATH = f"models/{MODEL_NAME}/model_artifacts/training_weights/config.json" 78 | 79 | REMOTE_TRAINING_WEIGHTS_CONFIG_PATH = get_env_var_or_default( 80 | var_name="REMOTE_TRAINING_WEIGHTS_CONFIG_PATH", 81 | default_value="remote/path/to/your/weights/here" 82 | ) 83 | 84 | N_SHARDS = 2 85 | REMOTE_TRAINING_FILES_TO_DOWNLOAD = [ 86 | f"model-{str(i+1).zfill(5)}-of-{str(N_SHARDS).zfill(5)}.safetensors" 87 | for i in range(N_SHARDS) 88 | ] 89 | 90 | REMOTE_TRAINING_FILES_TO_DOWNLOAD += [ 91 | "config.json", 92 | "generation_config.json", 93 | "model.safetensors.index.json", 94 | "special_tokens_map.json", 95 | "tokenizer_config.json", 96 | "tokenizer.json", 97 | "tokenizer.model", 98 | ] 99 | -------------------------------------------------------------------------------- /models/dockerignore: -------------------------------------------------------------------------------- 1 | *pdf 2 | *docx 3 | flan-t5** 4 | checkpoints/** 5 | examples/** 6 | weights_13/** 7 | tmp/** 8 | **.jsonl 9 | unconverted-weights 10 | unconverted-weights/ 11 | weights 12 | weights/ 13 | llama_weights/ 14 | llama_weights 15 | */**/*.safetensors 16 | 
*/**/*.tensors 17 | **/.git/lfs/objects/** 18 | *.tensors 19 | default_base_weights/ 20 | llama.tensors 21 | code 22 | tests 23 | **/*ipynb 24 | .ruff/** 25 | .mypy_cache 26 | tests 27 | 28 | # generated by replicate/cog 29 | __pycache__ 30 | *.pyc 31 | *.pyo 32 | *.pyd 33 | .Python 34 | env 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | .tox 38 | .coverage 39 | .coverage.* 40 | .cache 41 | nosetests.xml 42 | coverage.xml 43 | *.cover 44 | *.log 45 | .git 46 | **/.git 47 | .mypy_cache 48 | **/.mypy_cache 49 | .pytest_cache 50 | .hypothesis 51 | 52 | models/*/ 53 | -------------------------------------------------------------------------------- /models/llama-2-13b-chat-hf-mlc/.env: -------------------------------------------------------------------------------- 1 | REMOTE_DEFAULT_INFERENCE_WEIGHTS_PATH=https://weights.replicate.delivery/default/llama-2-13b-chat-hf-mlc 2 | REMOTE_VLLM_INFERENCE_WEIGHTS_PATH=https://weights.replicate.delivery/default/llama-2-13b-chat 3 | REMOTE_TRAINING_WEIGHTS_PATH=https://weights.replicate.delivery/default/llama-2-13b-chat 4 | REMOTE_TRAINING_WEIGHTS_CONFIG_PATH=https://weights.replicate.delivery/default/llama-2-13b-chat/config.json 5 | LICENSE_URL=https://weights.replicate.delivery/default/Llama-2-7B-GPTQ/LICENSE 6 | -------------------------------------------------------------------------------- /models/llama-2-13b-chat-hf-mlc/config.py: -------------------------------------------------------------------------------- 1 | from dotenv import load_dotenv 2 | from src.config_utils import ( 3 | Weights, 4 | get_fp16_file_list, 5 | get_mlc_file_list, 6 | mlc_kwargs, 7 | vllm_kwargs, 8 | ) 9 | from src.inference_engines.mlc_vllm_engine import MLCvLLMEngine 10 | from src.utils import get_env_var_or_default 11 | 12 | load_dotenv() 13 | 14 | MODEL_NAME = "llama-2-13b-chat-hf-mlc" 15 | 16 | # Inference weights 17 | mlc_file_list = get_mlc_file_list( 18 | model_name="Llama-2-13b-chat-hf-q4f16_1", n_shards=163 19 | ) 20 | 21 | LOCAL_PATH = f"models/{MODEL_NAME}/model_artifacts/default_inference_weights" 22 | 23 | mlc_weights = Weights( 24 | local_path=LOCAL_PATH, 25 | remote_path=get_env_var_or_default("REMOTE_DEFAULT_INFERENCE_WEIGHTS_PATH", None), 26 | remote_files=mlc_file_list, 27 | ) 28 | 29 | num_vllm_shards = 3 30 | vllm_weights = Weights( 31 | local_path=f"models/{MODEL_NAME}/model_artifacts/lora_inference_weights", 32 | remote_path=get_env_var_or_default("REMOTE_VLLM_INFERENCE_WEIGHTS_PATH", None), 33 | remote_files=get_fp16_file_list(num_vllm_shards), 34 | ) 35 | 36 | # Inference config 37 | USE_SYSTEM_PROMPT = True 38 | 39 | ENGINE = MLCvLLMEngine 40 | ENGINE_KWARGS = { 41 | "mlc_args": mlc_kwargs(mlc_weights, is_chat=False), 42 | "vllm_args": vllm_kwargs(vllm_weights), 43 | } 44 | 45 | # Training config 46 | LOAD_IN_4BIT = False 47 | 48 | LOCAL_TRAINING_WEIGHTS_PATH = f"models/{MODEL_NAME}/model_artifacts/training_weights" 49 | REMOTE_TRAINING_WEIGHTS_PATH = get_env_var_or_default( 50 | "REMOTE_TRAINING_WEIGHTS_PATH", 51 | None, 52 | ) 53 | LOCAL_TRAINING_WEIGHTS_CONFIG_PATH = ( 54 | f"models/{MODEL_NAME}/model_artifacts/training_weights/config.json" 55 | ) 56 | REMOTE_TRAINING_WEIGHTS_CONFIG_PATH = get_env_var_or_default( 57 | var_name="REMOTE_TRAINING_WEIGHTS_CONFIG_PATH", 58 | default_value=None, 59 | ) 60 | REMOTE_TRAINING_FILES_TO_DOWNLOAD = get_fp16_file_list(num_vllm_shards) 61 | -------------------------------------------------------------------------------- /models/llama-2-13b-chat/.dockerignore: 
-------------------------------------------------------------------------------- 1 | ../dockerignore -------------------------------------------------------------------------------- /models/llama-2-13b-chat/.env: -------------------------------------------------------------------------------- 1 | REMOTE_DEFAULT_INFERENCE_WEIGHTS_PATH=https://weights.replicate.delivery/default/llama-2-13b-chat-gptq/Llama-2-13B-chat-GPTQ 2 | REMOTE_VLLM_INFERENCE_WEIGHTS_PATH=https://weights.replicate.delivery/default/llama-2-13b-chat 3 | REMOTE_TRAINING_WEIGHTS_PATH=https://weights.replicate.delivery/default/llama-2-13b-chat 4 | REMOTE_TRAINING_WEIGHTS_CONFIG_PATH=https://weights.replicate.delivery/default/llama-2-13b-chat/config.json 5 | LICENSE_URL=https://weights.replicate.delivery/default/llama-2-13b-chat-gptq/Llama-2-13B-chat-GPTQ/LICENSE -------------------------------------------------------------------------------- /models/llama-2-13b-chat/config.py: -------------------------------------------------------------------------------- 1 | from dotenv import load_dotenv 2 | from src.utils import get_env_var_or_default 3 | 4 | load_dotenv() 5 | 6 | MODEL_NAME = "llama-2-13b-chat" 7 | # INFERENCE CONFIGURATION 8 | ####################################################################### 9 | # --------------------Notes-------------------------------------------- 10 | # We sometimes implement inference differently for models that have not 11 | # been trained/fine-tuned vs. those that have been trained/fine-tuned. We refer to the 12 | # former as "default" and the latter as "trained". Below, you can 13 | # set your "default inference configuration" and your "trained 14 | # inference configuration". 15 | # 16 | # GENERAL INFERENCE CONFIGURATION 17 | # ------------------------------- 18 | # This section defines the general inference configuration, 19 | # which is used for both trained and untrained models. 20 | # ------------------------------- 21 | 22 | TOKENIZER_PATH = f"models/{MODEL_NAME}/model_artifacts/tokenizer" 23 | USE_SYSTEM_PROMPT = True 24 | 25 | 26 | # ENGINE CONFIGURATION 27 | # ------------------------------- 28 | # Here we define the specific inference engine we intend to use for inference, and all appropriate kwargs. 29 | # ------------------------------- 30 | 31 | from src.inference_engines.exllama import ExllamaEngine 32 | 33 | ENGINE = ExllamaEngine 34 | ENGINE_KWARGS = { 35 | "fused_attn": True, 36 | } 37 | 38 | # WEIGHTS CONFIGURATION 39 | # ------------------------------- 40 | # Which base weights do we use for inference with this model? 
41 | # ------------------------------- 42 | 43 | 44 | LOCAL_DEFAULT_INFERENCE_WEIGHTS_PATH = ( 45 | f"models/{MODEL_NAME}/model_artifacts/default_inference_weights" 46 | ) 47 | 48 | REMOTE_DEFAULT_INFERENCE_WEIGHTS_PATH = get_env_var_or_default( 49 | "REMOTE_DEFAULT_INFERENCE_WEIGHTS_PATH", 50 | "remote/path/to/your/weights/here", 51 | ) 52 | 53 | # N_SHARDS = 2 54 | # REMOTE_TRAINING_FILES_TO_DOWNLOAD = [ 55 | # f"model-{str(i+1).zfill(5)}-of-{str(N_SHARDS).zfill(5)}.safetensors" 56 | # for i in range(N_SHARDS) 57 | # ] 58 | 59 | REMOTE_DEFAULT_INFERENCE_FILES_TO_DOWNLOAD = ["gptq_model-4bit-128g.safetensors"] 60 | 61 | REMOTE_DEFAULT_INFERENCE_FILES_TO_DOWNLOAD += [ 62 | "config.json", 63 | "generation_config.json", 64 | "special_tokens_map.json", 65 | "tokenizer_config.json", 66 | "tokenizer.json", 67 | "tokenizer.model", 68 | "quantize_config.json", 69 | ] 70 | 71 | # TRAINED INFERENCE CONFIGURATION 72 | # ------------------------------- 73 | # This section defines the inference configuration for fine-tuned models 74 | # ------------------------------- 75 | 76 | LOCAL_TRAINING_WEIGHTS_PATH = f"models/{MODEL_NAME}/model_artifacts/training_weights" 77 | 78 | REMOTE_TRAINING_WEIGHTS_PATH = get_env_var_or_default( 79 | var_name="REMOTE_TRAINING_WEIGHTS_PATH", 80 | default_value="remote/path/to/your/weights/here", 81 | ) 82 | 83 | LOCAL_TRAINING_WEIGHTS_CONFIG_PATH = ( 84 | f"models/{MODEL_NAME}/model_artifacts/training_weights/config.json" 85 | ) 86 | 87 | REMOTE_TRAINING_WEIGHTS_CONFIG_PATH = get_env_var_or_default( 88 | var_name="REMOTE_TRAINING_WEIGHTS_CONFIG_PATH", 89 | default_value="remote/path/to/your/weights/here", 90 | ) 91 | 92 | N_SHARDS = 3 93 | REMOTE_TRAINING_FILES_TO_DOWNLOAD = [ 94 | f"model-{str(i+1).zfill(5)}-of-{str(N_SHARDS).zfill(5)}.safetensors" 95 | for i in range(N_SHARDS) 96 | ] 97 | 98 | REMOTE_TRAINING_FILES_TO_DOWNLOAD += [ 99 | "config.json", 100 | "generation_config.json", 101 | "model.safetensors.index.json", 102 | "special_tokens_map.json", 103 | "tokenizer_config.json", 104 | "tokenizer.json", 105 | "tokenizer.model", 106 | ] 107 | -------------------------------------------------------------------------------- /models/llama-2-13b-mlc/.env: -------------------------------------------------------------------------------- 1 | REMOTE_DEFAULT_INFERENCE_WEIGHTS_PATH=https://weights.replicate.delivery/default/llama-2-13b-mlc-fp16 2 | REMOTE_VLLM_INFERENCE_WEIGHTS_PATH=https://weights.replicate.delivery/default/llama-2-13b 3 | REMOTE_TRAINING_WEIGHTS_PATH=https://weights.replicate.delivery/default/llama-2-13b 4 | REMOTE_TRAINING_WEIGHTS_CONFIG_PATH=https://weights.replicate.delivery/default/llama-2-13b/config.json 5 | LICENSE_URL=https://weights.replicate.delivery/default/Llama-2-7B-GPTQ/LICENSE 6 | -------------------------------------------------------------------------------- /models/llama-2-13b-mlc/config.py: -------------------------------------------------------------------------------- 1 | from dotenv import load_dotenv 2 | from src.config_utils import ( 3 | Weights, 4 | get_fp16_file_list, 5 | get_mlc_file_list, 6 | mlc_kwargs, 7 | vllm_kwargs, 8 | ) 9 | from src.inference_engines.mlc_vllm_engine import MLCvLLMEngine 10 | from src.utils import get_env_var_or_default 11 | 12 | load_dotenv() 13 | 14 | MODEL_NAME = "llama-2-13b-mlc" 15 | 16 | # Inference weights 17 | mlc_file_list = get_mlc_file_list(model_name="llama-2-13b-hf-q0f16", n_shards=163) 18 | 19 | LOCAL_PATH = f"models/{MODEL_NAME}/model_artifacts/default_inference_weights" 20 | 21 | mlc_weights = 
Weights( 22 | local_path=LOCAL_PATH, 23 | remote_path=get_env_var_or_default("REMOTE_DEFAULT_INFERENCE_WEIGHTS_PATH", None), 24 | remote_files=mlc_file_list, 25 | ) 26 | 27 | num_vllm_shards = 3 28 | vllm_weights = Weights( 29 | local_path=f"models/{MODEL_NAME}/model_artifacts/lora_inference_weights", 30 | remote_path=get_env_var_or_default("REMOTE_VLLM_INFERENCE_WEIGHTS_PATH", None), 31 | remote_files=get_fp16_file_list(num_vllm_shards), 32 | ) 33 | 34 | # Inference config 35 | USE_SYSTEM_PROMPT = False 36 | 37 | ENGINE = MLCvLLMEngine 38 | ENGINE_KWARGS = { 39 | "mlc_args": mlc_kwargs(mlc_weights, is_chat=False), 40 | "vllm_args": vllm_kwargs(vllm_weights), 41 | } 42 | 43 | # Training config 44 | LOAD_IN_4BIT = False 45 | 46 | LOCAL_TRAINING_WEIGHTS_PATH = f"models/{MODEL_NAME}/model_artifacts/training_weights" 47 | REMOTE_TRAINING_WEIGHTS_PATH = get_env_var_or_default( 48 | "REMOTE_TRAINING_WEIGHTS_PATH", 49 | None, 50 | ) 51 | LOCAL_TRAINING_WEIGHTS_CONFIG_PATH = ( 52 | f"models/{MODEL_NAME}/model_artifacts/training_weights/config.json" 53 | ) 54 | REMOTE_TRAINING_WEIGHTS_CONFIG_PATH = get_env_var_or_default( 55 | var_name="REMOTE_TRAINING_WEIGHTS_CONFIG_PATH", 56 | default_value=None, 57 | ) 58 | REMOTE_TRAINING_FILES_TO_DOWNLOAD = get_fp16_file_list(num_vllm_shards) 59 | -------------------------------------------------------------------------------- /models/llama-2-13b-transformers/.env: -------------------------------------------------------------------------------- 1 | REMOTE_DEFAULT_INFERENCE_WEIGHTS_PATH=https://weights.replicate.delivery/default/llama-2-13b 2 | REMOTE_TRAINING_WEIGHTS_PATH=https://weights.replicate.delivery/default/llama-2-13b 3 | REMOTE_TRAINING_WEIGHTS_CONFIG_PATH=https://weights.replicate.delivery/default/llama-2-13b/config.json -------------------------------------------------------------------------------- /models/llama-2-13b/.dockerignore: -------------------------------------------------------------------------------- 1 | ../dockerignore -------------------------------------------------------------------------------- /models/llama-2-13b/.env: -------------------------------------------------------------------------------- 1 | REMOTE_DEFAULT_INFERENCE_WEIGHTS_PATH=https://weights.replicate.delivery/default/llama-2-13b-gptq 2 | REMOTE_VLLM_INFERENCE_WEIGHTS_PATH=https://weights.replicate.delivery/default/llama-2-13b 3 | REMOTE_TRAINING_WEIGHTS_PATH=https://weights.replicate.delivery/default/llama-2-13b 4 | REMOTE_TRAINING_WEIGHTS_CONFIG_PATH=https://weights.replicate.delivery/default/llama-2-13b/config.json 5 | LICENSE_URL=https://weights.replicate.delivery/default/llama-2-13b-gptq/LICENSE.txt -------------------------------------------------------------------------------- /models/llama-2-13b/config.py: -------------------------------------------------------------------------------- 1 | from dotenv import load_dotenv 2 | from src.utils import get_env_var_or_default 3 | 4 | load_dotenv() 5 | 6 | MODEL_NAME = "llama-2-13b" 7 | # INFERENCE CONFIGURATION 8 | ####################################################################### 9 | # --------------------Notes-------------------------------------------- 10 | # We are trying our very best to no longer have different inference code paths 11 | # for trained and untrained weights :) 12 | 13 | # 14 | # INFERENCE CONFIGURATION 15 | # ------------------------------- 16 | # This section defines the general inference configuration, 17 | # which is used for both trained and untrained models. 
18 | # ------------------------------- 19 | 20 | TOKENIZER_PATH = f"models/{MODEL_NAME}/model_artifacts/tokenizer" 21 | USE_SYSTEM_PROMPT = False 22 | 23 | # ENGINE CONFIGURATION 24 | # ------------------------------- 25 | # Here we define the specific inference engine we intend to use for inference, and all appropriate kwargs. 26 | # ------------------------------- 27 | 28 | from src.inference_engines.exllama import ExllamaEngine 29 | 30 | ENGINE = ExllamaEngine 31 | ENGINE_KWARGS = { 32 | "fused_attn": True, 33 | } 34 | 35 | # WEIGHTS CONFIGURATION 36 | # ------------------------------- 37 | # Which base weights do we use for inference with this model? 38 | # ------------------------------- 39 | 40 | LOCAL_DEFAULT_INFERENCE_WEIGHTS_PATH = ( 41 | f"models/{MODEL_NAME}/model_artifacts/default_inference_weights" 42 | ) 43 | 44 | REMOTE_DEFAULT_INFERENCE_WEIGHTS_PATH = get_env_var_or_default( 45 | "REMOTE_DEFAULT_INFERENCE_WEIGHTS_PATH", 46 | "remote/path/to/your/weights/here", 47 | ) 48 | 49 | # N_SHARDS = 2 50 | # REMOTE_TRAINING_FILES_TO_DOWNLOAD = [ 51 | # f"model-{str(i+1).zfill(5)}-of-{str(N_SHARDS).zfill(5)}.safetensors" 52 | # for i in range(N_SHARDS) 53 | # ] 54 | 55 | REMOTE_DEFAULT_INFERENCE_FILES_TO_DOWNLOAD = ["gptq_model-4bit-32g.safetensors"] 56 | 57 | REMOTE_DEFAULT_INFERENCE_FILES_TO_DOWNLOAD += [ 58 | "config.json", 59 | "generation_config.json", 60 | "special_tokens_map.json", 61 | "tokenizer_config.json", 62 | "tokenizer.json", 63 | "tokenizer.model", 64 | "quantize_config.json", 65 | ] 66 | 67 | # TRAINED INFERENCE CONFIGURATION 68 | # ------------------------------- 69 | # This section defines the inference configuration for fine-tuned models 70 | # ------------------------------- 71 | 72 | LOCAL_TRAINING_WEIGHTS_PATH = f"models/{MODEL_NAME}/model_artifacts/training_weights" 73 | 74 | REMOTE_TRAINING_WEIGHTS_PATH = get_env_var_or_default( 75 | var_name="REMOTE_TRAINING_WEIGHTS_PATH", 76 | default_value="remote/path/to/your/weights/here", 77 | ) 78 | 79 | LOCAL_TRAINING_WEIGHTS_CONFIG_PATH = ( 80 | f"models/{MODEL_NAME}/model_artifacts/training_weights/config.json" 81 | ) 82 | 83 | REMOTE_TRAINING_WEIGHTS_CONFIG_PATH = get_env_var_or_default( 84 | var_name="REMOTE_TRAINING_WEIGHTS_CONFIG_PATH", 85 | default_value="remote/path/to/your/weights/here", 86 | ) 87 | 88 | N_SHARDS = 3 89 | REMOTE_TRAINING_FILES_TO_DOWNLOAD = [ 90 | f"model-{str(i+1).zfill(5)}-of-{str(N_SHARDS).zfill(5)}.safetensors" 91 | for i in range(N_SHARDS) 92 | ] 93 | 94 | REMOTE_TRAINING_FILES_TO_DOWNLOAD += [ 95 | "config.json", 96 | "generation_config.json", 97 | "model.safetensors.index.json", 98 | "special_tokens_map.json", 99 | "tokenizer_config.json", 100 | "tokenizer.json", 101 | "tokenizer.model", 102 | ] 103 | -------------------------------------------------------------------------------- /models/llama-2-70b-chat-hf-mlc/.env: -------------------------------------------------------------------------------- 1 | REMOTE_DEFAULT_INFERENCE_WEIGHTS_PATH=https://weights.replicate.delivery/default/llama-2-70b-chat-hf-mlc 2 | REMOTE_VLLM_INFERENCE_WEIGHTS_PATH=https://weights.replicate.delivery/default/llama-2-70b-chat 3 | REMOTE_TRAINING_WEIGHTS_PATH=https://weights.replicate.delivery/default/llama-2-70b-chat 4 | REMOTE_TRAINING_WEIGHTS_CONFIG_PATH=https://weights.replicate.delivery/default/llama-2-70b-chat/config.json 5 | LICENSE_URL=https://weights.replicate.delivery/default/Llama-2-7B-GPTQ/LICENSE 6 | -------------------------------------------------------------------------------- 
/models/llama-2-70b-chat-hf-mlc/config.py: -------------------------------------------------------------------------------- 1 | from dotenv import load_dotenv 2 | from src.config_utils import ( 3 | Weights, 4 | get_fp16_file_list, 5 | get_mlc_file_list, 6 | mlc_kwargs, 7 | vllm_kwargs, 8 | ) 9 | from src.inference_engines.mlc_vllm_engine import MLCvLLMEngine 10 | from src.utils import get_env_var_or_default 11 | 12 | load_dotenv() 13 | 14 | MODEL_NAME = "llama-2-70b-chat-hf-mlc" 15 | 16 | # Inference weights 17 | mlc_file_list = get_mlc_file_list( 18 | model_name="Llama-2-70b-chat-hf-q4f16_1", n_shards=483 19 | ) 20 | 21 | LOCAL_PATH = f"models/{MODEL_NAME}/model_artifacts/default_inference_weights" 22 | 23 | mlc_weights = Weights( 24 | local_path=LOCAL_PATH, 25 | remote_path=get_env_var_or_default("REMOTE_DEFAULT_INFERENCE_WEIGHTS_PATH", None), 26 | remote_files=mlc_file_list, 27 | ) 28 | 29 | vllm_weights = Weights( 30 | local_path=f"models/{MODEL_NAME}/model_artifacts/lora_inference_weights", 31 | remote_path=get_env_var_or_default("REMOTE_VLLM_INFERENCE_WEIGHTS_PATH", None), 32 | remote_files=get_fp16_file_list(15), 33 | ) 34 | 35 | # Inference config 36 | USE_SYSTEM_PROMPT = True 37 | 38 | ENGINE = MLCvLLMEngine 39 | ENGINE_KWARGS = { 40 | "mlc_args": mlc_kwargs(mlc_weights, is_chat=False, num_shards=4), 41 | "vllm_args": vllm_kwargs(vllm_weights), 42 | } 43 | 44 | # Training config 45 | LOAD_IN_4BIT = False 46 | 47 | LOCAL_TRAINING_WEIGHTS_PATH = f"models/{MODEL_NAME}/model_artifacts/training_weights" 48 | REMOTE_TRAINING_WEIGHTS_PATH = get_env_var_or_default( 49 | "REMOTE_TRAINING_WEIGHTS_PATH", 50 | None, 51 | ) 52 | LOCAL_TRAINING_WEIGHTS_CONFIG_PATH = ( 53 | f"models/{MODEL_NAME}/model_artifacts/training_weights/config.json" 54 | ) 55 | REMOTE_TRAINING_WEIGHTS_CONFIG_PATH = get_env_var_or_default( 56 | var_name="REMOTE_TRAINING_WEIGHTS_CONFIG_PATH", 57 | default_value=None, 58 | ) 59 | REMOTE_TRAINING_FILES_TO_DOWNLOAD = get_fp16_file_list(15) 60 | -------------------------------------------------------------------------------- /models/llama-2-70b-chat/.dockerignore: -------------------------------------------------------------------------------- 1 | ../dockerignore -------------------------------------------------------------------------------- /models/llama-2-70b-chat/.env: -------------------------------------------------------------------------------- 1 | REMOTE_DEFAULT_INFERENCE_WEIGHTS_PATH=https://weights.replicate.delivery/default/Llama-2-70B-chat-GPTQ 2 | REMOTE_TRAINING_WEIGHTS_PATH=https://weights.replicate.delivery/default/llama-2-70b-chat 3 | REMOTE_TRAINING_WEIGHTS_CONFIG_PATH=https://weights.replicate.delivery/default/llama-2-70b-chat/config.json 4 | LICENSE_URL=https://weights.replicate.delivery/default/Llama-2-70B-chat-GPTQ/LICENSE.txt -------------------------------------------------------------------------------- /models/llama-2-70b-chat/config.py: -------------------------------------------------------------------------------- 1 | from dotenv import load_dotenv 2 | from src.utils import get_env_var_or_default 3 | 4 | load_dotenv() 5 | 6 | MODEL_NAME = "llama-2-70b-chat" 7 | # INFERENCE CONFIGURATION 8 | ####################################################################### 9 | # --------------------Notes-------------------------------------------- 10 | # We are trying our very best to no longer have different inference code paths 11 | # for trained and untrained weights :) 12 | 13 | # 14 | # INFERENCE CONFIGURATION 15 | # ------------------------------- 16 | 
# This section defines the general inference configuration, 17 | # which is used for both trained and untrained models. 18 | # ------------------------------- 19 | 20 | TOKENIZER_PATH = f"models/{MODEL_NAME}/model_artifacts/tokenizer" 21 | USE_SYSTEM_PROMPT = True 22 | 23 | 24 | # ENGINE CONFIGURATION 25 | # ------------------------------- 26 | # Here we define the specific inference engine we intend to use for inference, and all appropriate kwargs. 27 | # ------------------------------- 28 | 29 | from src.inference_engines.exllama import ExllamaEngine 30 | 31 | ENGINE = ExllamaEngine 32 | ENGINE_KWARGS = { 33 | "fused_attn": True, 34 | } 35 | 36 | # WEIGHTS CONFIGURATION 37 | # ------------------------------- 38 | # Which base weights do we use for inference with this model? 39 | # ------------------------------- 40 | 41 | 42 | LOCAL_DEFAULT_INFERENCE_WEIGHTS_PATH = ( 43 | f"models/{MODEL_NAME}/model_artifacts/default_inference_weights" 44 | ) 45 | 46 | REMOTE_DEFAULT_INFERENCE_WEIGHTS_PATH = get_env_var_or_default( 47 | "REMOTE_DEFAULT_INFERENCE_WEIGHTS_PATH", 48 | "remote/path/to/your/weights/here", 49 | ) 50 | 51 | # N_SHARDS = 2 52 | # REMOTE_TRAINING_FILES_TO_DOWNLOAD = [ 53 | # f"model-{str(i+1).zfill(5)}-of-{str(N_SHARDS).zfill(5)}.safetensors" 54 | # for i in range(N_SHARDS) 55 | # ] 56 | 57 | REMOTE_DEFAULT_INFERENCE_FILES_TO_DOWNLOAD = ["gptq_model-4bit--1g.safetensors"] 58 | 59 | REMOTE_DEFAULT_INFERENCE_FILES_TO_DOWNLOAD += [ 60 | "config.json", 61 | "generation_config.json", 62 | "special_tokens_map.json", 63 | "tokenizer_config.json", 64 | "tokenizer.json", 65 | "tokenizer.model", 66 | "quantize_config.json", 67 | ] 68 | 69 | # TRAINED INFERENCE CONFIGURATION 70 | # ------------------------------- 71 | # This section defines the inference configuration for fine-tuned models 72 | # ------------------------------- 73 | 74 | LOCAL_TRAINING_WEIGHTS_PATH = f"models/{MODEL_NAME}/model_artifacts/training_weights" 75 | 76 | REMOTE_TRAINING_WEIGHTS_PATH = get_env_var_or_default( 77 | var_name="REMOTE_TRAINING_WEIGHTS_PATH", 78 | default_value="remote/path/to/your/weights/here", 79 | ) 80 | 81 | LOCAL_TRAINING_WEIGHTS_CONFIG_PATH = ( 82 | f"models/{MODEL_NAME}/model_artifacts/training_weights/config.json" 83 | ) 84 | 85 | REMOTE_TRAINING_WEIGHTS_CONFIG_PATH = get_env_var_or_default( 86 | var_name="REMOTE_TRAINING_WEIGHTS_CONFIG_PATH", 87 | default_value="remote/path/to/your/weights/here", 88 | ) 89 | 90 | N_SHARDS = 15 91 | REMOTE_TRAINING_FILES_TO_DOWNLOAD = [ 92 | f"model-{str(i+1).zfill(5)}-of-{str(N_SHARDS).zfill(5)}.safetensors" 93 | for i in range(N_SHARDS) 94 | ] 95 | 96 | REMOTE_TRAINING_FILES_TO_DOWNLOAD += [ 97 | "config.json", 98 | "generation_config.json", 99 | "model.safetensors.index.json", 100 | "special_tokens_map.json", 101 | "tokenizer_config.json", 102 | "tokenizer.json", 103 | "tokenizer.model", 104 | ] 105 | -------------------------------------------------------------------------------- /models/llama-2-70b-mlc/.env: -------------------------------------------------------------------------------- 1 | REMOTE_DEFAULT_INFERENCE_WEIGHTS_PATH=https://weights.replicate.delivery/default/llama-2-70b-mlc-fp16 2 | REMOTE_VLLM_INFERENCE_WEIGHTS_PATH=https://weights.replicate.delivery/default/llama-2-70b 3 | REMOTE_TRAINING_WEIGHTS_PATH=https://weights.replicate.delivery/default/llama-2-70b 4 | REMOTE_TRAINING_WEIGHTS_CONFIG_PATH=https://weights.replicate.delivery/default/llama-2-70b/config.json 5 | 
LICENSE_URL=https://weights.replicate.delivery/default/Llama-2-7B-GPTQ/LICENSE 6 | -------------------------------------------------------------------------------- /models/llama-2-70b-mlc/config.py: -------------------------------------------------------------------------------- 1 | from dotenv import load_dotenv 2 | from src.config_utils import ( 3 | Weights, 4 | get_fp16_file_list, 5 | get_mlc_file_list, 6 | mlc_kwargs, 7 | vllm_kwargs, 8 | ) 9 | from src.inference_engines.mlc_vllm_engine import MLCvLLMEngine 10 | from src.utils import get_env_var_or_default 11 | 12 | load_dotenv() 13 | 14 | MODEL_NAME = "llama-2-70b-mlc" 15 | 16 | # Inference weights 17 | mlc_file_list = get_mlc_file_list(model_name="llama-2-70b-q0f16", n_shards=323) 18 | 19 | LOCAL_PATH = f"models/{MODEL_NAME}/model_artifacts/default_inference_weights" 20 | 21 | mlc_weights = Weights( 22 | local_path=LOCAL_PATH, 23 | remote_path=get_env_var_or_default("REMOTE_DEFAULT_INFERENCE_WEIGHTS_PATH", None), 24 | remote_files=mlc_file_list, 25 | ) 26 | 27 | vllm_weights = Weights( 28 | local_path=f"models/{MODEL_NAME}/model_artifacts/lora_inference_weights", 29 | remote_path=get_env_var_or_default("REMOTE_VLLM_INFERENCE_WEIGHTS_PATH", None), 30 | remote_files=get_fp16_file_list(15), 31 | ) 32 | 33 | # Inference config 34 | USE_SYSTEM_PROMPT = False 35 | 36 | ENGINE = MLCvLLMEngine 37 | ENGINE_KWARGS = { 38 | "mlc_args": mlc_kwargs(mlc_weights, is_chat=False, num_shards=4), 39 | "vllm_args": vllm_kwargs(vllm_weights), 40 | } 41 | 42 | # Training config 43 | LOAD_IN_4BIT = False 44 | 45 | LOCAL_TRAINING_WEIGHTS_PATH = f"models/{MODEL_NAME}/model_artifacts/training_weights" 46 | REMOTE_TRAINING_WEIGHTS_PATH = get_env_var_or_default( 47 | "REMOTE_TRAINING_WEIGHTS_PATH", 48 | None, 49 | ) 50 | LOCAL_TRAINING_WEIGHTS_CONFIG_PATH = ( 51 | f"models/{MODEL_NAME}/model_artifacts/training_weights/config.json" 52 | ) 53 | REMOTE_TRAINING_WEIGHTS_CONFIG_PATH = get_env_var_or_default( 54 | var_name="REMOTE_TRAINING_WEIGHTS_CONFIG_PATH", 55 | default_value=None, 56 | ) 57 | REMOTE_TRAINING_FILES_TO_DOWNLOAD = get_fp16_file_list(15) 58 | -------------------------------------------------------------------------------- /models/llama-2-70b/.dockerignore: -------------------------------------------------------------------------------- 1 | ../dockerignore -------------------------------------------------------------------------------- /models/llama-2-70b/.env: -------------------------------------------------------------------------------- 1 | REMOTE_DEFAULT_INFERENCE_WEIGHTS_PATH=https://weights.replicate.delivery/default/llama-2-70b-gptq 2 | REMOTE_TRAINING_WEIGHTS_PATH=https://weights.replicate.delivery/default/llama-2-70b 3 | REMOTE_TRAINING_WEIGHTS_CONFIG_PATH=https://weights.replicate.delivery/default/llama-2-70b/config.json 4 | LICENSE_URL=https://weights.replicate.delivery/default/llama-2-70b-gptq/LICENSE.txt -------------------------------------------------------------------------------- /models/llama-2-70b/config.py: -------------------------------------------------------------------------------- 1 | from dotenv import load_dotenv 2 | from src.utils import get_env_var_or_default 3 | 4 | load_dotenv() 5 | 6 | MODEL_NAME = "llama-2-70b" 7 | # INFERENCE CONFIGURATION 8 | ####################################################################### 9 | # --------------------Notes-------------------------------------------- 10 | # We sometimes implement inference differently for models that have not 11 | # been trained/fine-tuned vs. 
those that have been trained/fine-tuned. We refer to the 12 | # former as "default" and the latter as "trained". Below, you can 13 | # set your "default inference configuration" and your "trained 14 | # inference configuration". 15 | # 16 | # GENERAL INFERENCE CONFIGURATION 17 | # ------------------------------- 18 | # This section defines the general inference configuration, 19 | # which is used for both trained and untrained models. 20 | # ------------------------------- 21 | 22 | TOKENIZER_PATH = f"models/{MODEL_NAME}/model_artifacts/tokenizer" 23 | USE_SYSTEM_PROMPT = False 24 | 25 | 26 | # ENGINE CONFIGURATION 27 | # ------------------------------- 28 | # Here we define the specific inference engine we intend to use for inference, and all appropriate kwargs. 29 | # ------------------------------- 30 | 31 | from src.inference_engines.exllama import ExllamaEngine 32 | 33 | ENGINE = ExllamaEngine 34 | ENGINE_KWARGS = { 35 | "fused_attn": True, 36 | } 37 | 38 | # WEIGHTS CONFIGURATION 39 | # ------------------------------- 40 | # Which base weights do we use for inference with this model? 41 | # ------------------------------- 42 | 43 | 44 | LOCAL_DEFAULT_INFERENCE_WEIGHTS_PATH = ( 45 | f"models/{MODEL_NAME}/model_artifacts/default_inference_weights" 46 | ) 47 | 48 | REMOTE_DEFAULT_INFERENCE_WEIGHTS_PATH = get_env_var_or_default( 49 | "REMOTE_DEFAULT_INFERENCE_WEIGHTS_PATH", 50 | "remote/path/to/your/weights/here", 51 | ) 52 | 53 | # N_SHARDS = 2 54 | # REMOTE_TRAINING_FILES_TO_DOWNLOAD = [ 55 | # f"model-{str(i+1).zfill(5)}-of-{str(N_SHARDS).zfill(5)}.safetensors" 56 | # for i in range(N_SHARDS) 57 | # ] 58 | 59 | REMOTE_DEFAULT_INFERENCE_FILES_TO_DOWNLOAD = ["gptq_model-4bit-32g.safetensors"] 60 | 61 | REMOTE_DEFAULT_INFERENCE_FILES_TO_DOWNLOAD += [ 62 | "config.json", 63 | "generation_config.json", 64 | "special_tokens_map.json", 65 | "tokenizer_config.json", 66 | "tokenizer.json", 67 | "tokenizer.model", 68 | "quantize_config.json", 69 | ] 70 | 71 | # TRAINED INFERENCE CONFIGURATION 72 | # ------------------------------- 73 | # This section defines the inference configuration for fine-tuned models 74 | # ------------------------------- 75 | 76 | LOCAL_TRAINING_WEIGHTS_PATH = f"models/{MODEL_NAME}/model_artifacts/training_weights" 77 | 78 | REMOTE_TRAINING_WEIGHTS_PATH = get_env_var_or_default( 79 | var_name="REMOTE_TRAINING_WEIGHTS_PATH", 80 | default_value="remote/path/to/your/weights/here", 81 | ) 82 | 83 | LOCAL_TRAINING_WEIGHTS_CONFIG_PATH = ( 84 | f"models/{MODEL_NAME}/model_artifacts/training_weights/config.json" 85 | ) 86 | 87 | REMOTE_TRAINING_WEIGHTS_CONFIG_PATH = get_env_var_or_default( 88 | var_name="REMOTE_TRAINING_WEIGHTS_CONFIG_PATH", 89 | default_value="remote/path/to/your/weights/here", 90 | ) 91 | 92 | N_SHARDS = 15 93 | REMOTE_TRAINING_FILES_TO_DOWNLOAD = [ 94 | f"model-{str(i+1).zfill(5)}-of-{str(N_SHARDS).zfill(5)}.safetensors" 95 | for i in range(N_SHARDS) 96 | ] 97 | 98 | REMOTE_TRAINING_FILES_TO_DOWNLOAD += [ 99 | "config.json", 100 | "generation_config.json", 101 | "model.safetensors.index.json", 102 | "special_tokens_map.json", 103 | "tokenizer_config.json", 104 | "tokenizer.json", 105 | "tokenizer.model", 106 | ] 107 | -------------------------------------------------------------------------------- /models/llama-2-70b/model_artifacts/tokenizer/special_tokens_map.json: -------------------------------------------------------------------------------- 1 | {} -------------------------------------------------------------------------------- 
/models/llama-2-70b/model_artifacts/tokenizer/tokenizer.model: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/replicate/cog-llama-template/845e24f626bba67da586e8cb4c5793ea930e5a38/models/llama-2-70b/model_artifacts/tokenizer/tokenizer.model -------------------------------------------------------------------------------- /models/llama-2-70b/model_artifacts/tokenizer/tokenizer_checklist.chk: -------------------------------------------------------------------------------- 1 | eeec4125e9c7560836b4873b6f8e3025 tokenizer.model 2 | -------------------------------------------------------------------------------- /models/llama-2-70b/model_artifacts/tokenizer/tokenizer_config.json: -------------------------------------------------------------------------------- 1 | {"bos_token": "", "eos_token": "", "model_max_length": 4096, "tokenizer_class": "LlamaTokenizer", "unk_token": ""} -------------------------------------------------------------------------------- /models/llama-2-7b-chat-hf-mlc/.dockerignore: -------------------------------------------------------------------------------- 1 | ../dockerignore -------------------------------------------------------------------------------- /models/llama-2-7b-chat-hf-mlc/.env: -------------------------------------------------------------------------------- 1 | REMOTE_DEFAULT_INFERENCE_WEIGHTS_PATH=https://weights.replicate.delivery/default/llama-2-7b-chat-hf-mlc 2 | REMOTE_VLLM_INFERENCE_WEIGHTS_PATH=https://weights.replicate.delivery/default/llama-2-7b-chat 3 | REMOTE_TRAINING_WEIGHTS_PATH=https://weights.replicate.delivery/default/llama-2-7b-chat 4 | REMOTE_TRAINING_WEIGHTS_CONFIG_PATH=https://weights.replicate.delivery/default/llama-2-7b-chat/config.json 5 | LICENSE_URL=https://weights.replicate.delivery/default/Llama-2-7B-GPTQ/LICENSE 6 | -------------------------------------------------------------------------------- /models/llama-2-7b-chat-hf-mlc/config.py: -------------------------------------------------------------------------------- 1 | from dotenv import load_dotenv 2 | from src.config_utils import ( 3 | Weights, 4 | get_fp16_file_list, 5 | get_mlc_file_list, 6 | mlc_kwargs, 7 | vllm_kwargs, 8 | ) 9 | from src.inference_engines.mlc_vllm_engine import MLCvLLMEngine 10 | from src.utils import get_env_var_or_default 11 | 12 | load_dotenv() 13 | 14 | MODEL_NAME = "llama-2-7b-chat-hf-mlc" 15 | 16 | # Inference weights 17 | mlc_file_list = get_mlc_file_list(model_name="Llama-2-7b-chat-hf-q4f16_1", n_shards=115) 18 | 19 | LOCAL_PATH = f"models/{MODEL_NAME}/model_artifacts/default_inference_weights" 20 | 21 | mlc_weights = Weights( 22 | local_path=LOCAL_PATH, 23 | remote_path=get_env_var_or_default("REMOTE_DEFAULT_INFERENCE_WEIGHTS_PATH", None), 24 | remote_files=mlc_file_list, 25 | ) 26 | 27 | vllm_weights = Weights( 28 | local_path=f"models/{MODEL_NAME}/model_artifacts/lora_inference_weights", 29 | remote_path=get_env_var_or_default("REMOTE_VLLM_INFERENCE_WEIGHTS_PATH", None), 30 | remote_files=get_fp16_file_list(2), 31 | ) 32 | 33 | 34 | # Inference config 35 | USE_SYSTEM_PROMPT = True 36 | 37 | ENGINE = MLCvLLMEngine 38 | ENGINE_KWARGS = { 39 | "mlc_args": mlc_kwargs(mlc_weights, is_chat=False), 40 | "vllm_args": vllm_kwargs(vllm_weights), 41 | } 42 | 43 | # Training config 44 | LOAD_IN_4BIT = False 45 | 46 | LOCAL_TRAINING_WEIGHTS_PATH = f"models/{MODEL_NAME}/model_artifacts/training_weights" 47 | REMOTE_TRAINING_WEIGHTS_PATH = get_env_var_or_default( 48 | 
"REMOTE_TRAINING_WEIGHTS_PATH", 49 | None, 50 | ) 51 | LOCAL_TRAINING_WEIGHTS_CONFIG_PATH = ( 52 | f"models/{MODEL_NAME}/model_artifacts/training_weights/config.json" 53 | ) 54 | REMOTE_TRAINING_WEIGHTS_CONFIG_PATH = get_env_var_or_default( 55 | var_name="REMOTE_TRAINING_WEIGHTS_CONFIG_PATH", 56 | default_value=None, 57 | ) 58 | REMOTE_TRAINING_FILES_TO_DOWNLOAD = get_fp16_file_list(2) 59 | -------------------------------------------------------------------------------- /models/llama-2-7b-chat/.dockerignore: -------------------------------------------------------------------------------- 1 | ../dockerignore -------------------------------------------------------------------------------- /models/llama-2-7b-chat/.env: -------------------------------------------------------------------------------- 1 | REMOTE_DEFAULT_INFERENCE_WEIGHTS_PATH=https://weights.replicate.delivery/default/Llama-2-7b-Chat-GPTQ 2 | REMOTE_VLLM_INFERENCE_WEIGHTS_PATH=https://weights.replicate.delivery/default/llama-2-7b-chat/Llama-2-7b-chat 3 | REMOTE_TRAINING_WEIGHTS_PATH=https://weights.replicate.delivery/default/llama-2-7b-chat/Llama-2-7b-chat 4 | REMOTE_TRAINING_WEIGHTS_CONFIG_PATH=https://weights.replicate.delivery/default/llama-2-7b-chat/Llama-2-7b-chat/config.json 5 | LICENSE_URL=https://weights.replicate.delivery/default/Llama-2-7b-Chat-GPTQ/LICENSE -------------------------------------------------------------------------------- /models/llama-2-7b-chat/config.py: -------------------------------------------------------------------------------- 1 | from dotenv import load_dotenv 2 | from src.config_utils import ( 3 | exllama_kwargs, 4 | get_fp16_file_list, 5 | get_gptq_file_list, 6 | vllm_kwargs, 7 | Weights, 8 | ) 9 | from src.utils import get_env_var_or_default 10 | 11 | from src.inference_engines.vllm_exllama_engine import ExllamaVllmEngine 12 | 13 | load_dotenv() 14 | 15 | MODEL_NAME = "llama-2-7b-chat" 16 | 17 | # Inference weights 18 | 19 | exllama_weights = Weights( 20 | local_path=f"models/{MODEL_NAME}/model_artifacts/default_inference_weights", 21 | remote_path=get_env_var_or_default("REMOTE_DEFAULT_INFERENCE_WEIGHTS_PATH", None), 22 | remote_files=get_gptq_file_list("gptq_model-4bit-32g.safetensors"), 23 | ) 24 | 25 | vllm_weights = Weights( 26 | local_path=f"models/{MODEL_NAME}/model_artifacts/lora_inference_weights", 27 | remote_path=get_env_var_or_default("REMOTE_VLLM_INFERENCE_WEIGHTS_PATH", None), 28 | remote_files=get_fp16_file_list(2), 29 | ) 30 | 31 | 32 | # Inference config 33 | 34 | TOKENIZER_PATH = f"models/{MODEL_NAME}/model_artifacts/default_inference_weights" 35 | USE_SYSTEM_PROMPT = True 36 | 37 | ENGINE = ExllamaVllmEngine 38 | exllama_kw = exllama_kwargs(exllama_weights) 39 | vllm_kw = vllm_kwargs(vllm_weights) 40 | 41 | ENGINE_KWARGS = { 42 | "exllama_args": exllama_kw, 43 | "vllm_args": vllm_kw, 44 | } 45 | 46 | # Training config 47 | 48 | LOAD_IN_4BIT = False 49 | 50 | LOCAL_TRAINING_WEIGHTS_PATH = f"models/{MODEL_NAME}/model_artifacts/training_weights" 51 | REMOTE_TRAINING_WEIGHTS_PATH = get_env_var_or_default( 52 | "REMOTE_TRAINING_WEIGHTS_PATH", 53 | None, 54 | ) 55 | LOCAL_TRAINING_WEIGHTS_CONFIG_PATH = ( 56 | f"models/{MODEL_NAME}/model_artifacts/training_weights/config.json" 57 | ) 58 | REMOTE_TRAINING_WEIGHTS_CONFIG_PATH = get_env_var_or_default( 59 | var_name="REMOTE_TRAINING_WEIGHTS_CONFIG_PATH", 60 | default_value=None, 61 | ) 62 | REMOTE_TRAINING_FILES_TO_DOWNLOAD = get_fp16_file_list(2) 63 | -------------------------------------------------------------------------------- 
/models/llama-2-7b-mlc/.dockerignore: -------------------------------------------------------------------------------- 1 | ../dockerignore -------------------------------------------------------------------------------- /models/llama-2-7b-mlc/.env: -------------------------------------------------------------------------------- 1 | REMOTE_DEFAULT_INFERENCE_WEIGHTS_PATH=https://weights.replicate.delivery/default/llama-2-7b-mlc-fp16 2 | REMOTE_VLLM_INFERENCE_WEIGHTS_PATH=https://weights.replicate.delivery/default/llama-2-7b 3 | REMOTE_TRAINING_WEIGHTS_PATH=https://weights.replicate.delivery/default/llama-2-7b 4 | REMOTE_TRAINING_WEIGHTS_CONFIG_PATH=https://weights.replicate.delivery/default/llama-2-7b/config.json 5 | LICENSE_URL=https://weights.replicate.delivery/default/Llama-2-7B-GPTQ/LICENSE 6 | -------------------------------------------------------------------------------- /models/llama-2-7b-mlc/config.py: -------------------------------------------------------------------------------- 1 | from dotenv import load_dotenv 2 | from src.config_utils import ( 3 | Weights, 4 | get_fp16_file_list, 5 | get_mlc_file_list, 6 | mlc_kwargs, 7 | vllm_kwargs, 8 | ) 9 | from src.inference_engines.mlc_vllm_engine import MLCvLLMEngine 10 | from src.utils import get_env_var_or_default 11 | 12 | load_dotenv() 13 | 14 | MODEL_NAME = "llama-2-7b-mlc" 15 | 16 | # Inference weights 17 | mlc_file_list = get_mlc_file_list(model_name="llama-2-7b-hf-q0f16", n_shards=131) 18 | 19 | LOCAL_PATH = f"models/{MODEL_NAME}/model_artifacts/default_inference_weights" 20 | 21 | mlc_weights = Weights( 22 | local_path=LOCAL_PATH, 23 | remote_path=get_env_var_or_default("REMOTE_DEFAULT_INFERENCE_WEIGHTS_PATH", None), 24 | remote_files=mlc_file_list, 25 | ) 26 | 27 | vllm_weights = Weights( 28 | local_path=f"models/{MODEL_NAME}/model_artifacts/lora_inference_weights", 29 | remote_path=get_env_var_or_default("REMOTE_VLLM_INFERENCE_WEIGHTS_PATH", None), 30 | remote_files=get_fp16_file_list(2), 31 | ) 32 | 33 | # Inference config 34 | USE_SYSTEM_PROMPT = False 35 | 36 | ENGINE = MLCvLLMEngine 37 | ENGINE_KWARGS = { 38 | "mlc_args": mlc_kwargs(mlc_weights, is_chat=False), 39 | "vllm_args": vllm_kwargs(vllm_weights), 40 | } 41 | 42 | # Training config 43 | LOAD_IN_4BIT = False 44 | 45 | LOCAL_TRAINING_WEIGHTS_PATH = f"models/{MODEL_NAME}/model_artifacts/training_weights" 46 | REMOTE_TRAINING_WEIGHTS_PATH = get_env_var_or_default( 47 | "REMOTE_TRAINING_WEIGHTS_PATH", 48 | None, 49 | ) 50 | LOCAL_TRAINING_WEIGHTS_CONFIG_PATH = ( 51 | f"models/{MODEL_NAME}/model_artifacts/training_weights/config.json" 52 | ) 53 | REMOTE_TRAINING_WEIGHTS_CONFIG_PATH = get_env_var_or_default( 54 | var_name="REMOTE_TRAINING_WEIGHTS_CONFIG_PATH", 55 | default_value=None, 56 | ) 57 | REMOTE_TRAINING_FILES_TO_DOWNLOAD = get_fp16_file_list(2) 58 | -------------------------------------------------------------------------------- /models/llama-2-7b-transformers/.env: -------------------------------------------------------------------------------- 1 | REMOTE_DEFAULT_INFERENCE_WEIGHTS_PATH=https://weights.replicate.delivery/default/llama-2-7b 2 | REMOTE_TRAINING_WEIGHTS_PATH=https://weights.replicate.delivery/default/llama-2-7b 3 | REMOTE_TRAINING_WEIGHTS_CONFIG_PATH=https://weights.replicate.delivery/default/llama-2-7b/config.json -------------------------------------------------------------------------------- /models/llama-2-7b-transformers/config.py: -------------------------------------------------------------------------------- 1 | from dotenv import 
load_dotenv 2 | from src.config_utils import Weights, get_fp16_file_list 3 | from src.utils import get_env_var_or_default 4 | 5 | load_dotenv() 6 | 7 | MODEL_NAME = "llama-2-7b-transformers" 8 | # INFERENCE CONFIGURATION 9 | ####################################################################### 10 | # --------------------Notes-------------------------------------------- 11 | # We are trying our very best to no longer have different inference code paths 12 | # for trained and untrained weights :) 13 | # 14 | # INFERENCE CONFIGURATION 15 | # ------------------------------- 16 | # This section defines the general inference configuration, 17 | # which is used for both trained and untrained models. 18 | # ------------------------------- 19 | 20 | TOKENIZER_PATH = f"models/{MODEL_NAME}/model_artifacts/tokenizer" 21 | USE_SYSTEM_PROMPT = False 22 | 23 | 24 | # ENGINE CONFIGURATION 25 | # ------------------------------- 26 | # Here we define the specific inference engine we intend to use for inference, and all appropriate kwargs. 27 | # ------------------------------- 28 | 29 | from src.inference_engines.transformers_engine import TransformersEngine 30 | 31 | # todo - this is probably wrong - now that different engines have different tokenizers, should we eliminate load_tokenizer & handle it all within the engine? I ...think so 32 | from functools import partial 33 | from src.more_utils import load_tokenizer 34 | 35 | weights = Weights( 36 | local_path=f"models/{MODEL_NAME}/model_artifacts/default_inference_weights", 37 | remote_path=get_env_var_or_default("REMOTE_DEFAULT_INFERENCE_WEIGHTS_PATH", None), 38 | remote_files=get_fp16_file_list(2), 39 | ) 40 | 41 | ENGINE = TransformersEngine 42 | ENGINE_KWARGS = { 43 | "weights": weights, 44 | "tokenizer_func": partial(load_tokenizer, TOKENIZER_PATH), 45 | } 46 | 47 | 48 | # TRAINED INFERENCE CONFIGURATION 49 | # ------------------------------- 50 | # This section defines the inference configuration for fine-tuned models 51 | # ------------------------------- 52 | 53 | LOCAL_TRAINING_WEIGHTS_PATH = f"models/{MODEL_NAME}/model_artifacts/training_weights" 54 | 55 | REMOTE_TRAINING_WEIGHTS_PATH = get_env_var_or_default( 56 | var_name="REMOTE_TRAINING_WEIGHTS_PATH", 57 | default_value="remote/path/to/your/weights/here", 58 | ) 59 | 60 | LOCAL_TRAINING_WEIGHTS_CONFIG_PATH = ( 61 | f"models/{MODEL_NAME}/model_artifacts/training_weights/config.json" 62 | ) 63 | 64 | REMOTE_TRAINING_WEIGHTS_CONFIG_PATH = get_env_var_or_default( 65 | var_name="REMOTE_TRAINING_WEIGHTS_CONFIG_PATH", 66 | default_value="remote/path/to/your/weights/here", 67 | ) 68 | 69 | REMOTE_TRAINING_FILES_TO_DOWNLOAD = get_fp16_file_list(2) 70 | -------------------------------------------------------------------------------- /models/llama-2-7b-transformers/model_artifacts/tokenizer/special_tokens_map.json: -------------------------------------------------------------------------------- 1 | {} -------------------------------------------------------------------------------- /models/llama-2-7b-transformers/model_artifacts/tokenizer/tokenizer.model: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/replicate/cog-llama-template/845e24f626bba67da586e8cb4c5793ea930e5a38/models/llama-2-7b-transformers/model_artifacts/tokenizer/tokenizer.model -------------------------------------------------------------------------------- /models/llama-2-7b-transformers/model_artifacts/tokenizer/tokenizer_checklist.chk: 
-------------------------------------------------------------------------------- 1 | eeec4125e9c7560836b4873b6f8e3025 tokenizer.model 2 | -------------------------------------------------------------------------------- /models/llama-2-7b-transformers/model_artifacts/tokenizer/tokenizer_config.json: -------------------------------------------------------------------------------- 1 | {"bos_token": "", "eos_token": "", "model_max_length": 4096, "tokenizer_class": "LlamaTokenizer", "unk_token": ""} -------------------------------------------------------------------------------- /models/llama-2-7b-vllm/.env: -------------------------------------------------------------------------------- 1 | REMOTE_DEFAULT_INFERENCE_WEIGHTS_PATH=https://weights.replicate.delivery/default/llama-2-7b 2 | REMOTE_VLLM_INFERENCE_WEIGHTS_PATH=https://weights.replicate.delivery/default/llama-2-7b 3 | REMOTE_TRAINING_WEIGHTS_PATH=https://weights.replicate.delivery/default/llama-2-7b 4 | REMOTE_TRAINING_WEIGHTS_CONFIG_PATH=https://weights.replicate.delivery/default/llama-2-7b/config.json -------------------------------------------------------------------------------- /models/llama-2-7b-vllm/config.py: -------------------------------------------------------------------------------- 1 | from dotenv import load_dotenv 2 | from src.config_utils import Weights, get_fp16_file_list, vllm_kwargs 3 | 4 | 5 | from src.utils import get_env_var_or_default 6 | 7 | load_dotenv() 8 | 9 | MODEL_NAME = "llama-2-7b-vllm" 10 | 11 | # Inference config 12 | 13 | weights = Weights( 14 | local_path=f"models/{MODEL_NAME}/model_artifacts/default_inference_weights", 15 | remote_path=get_env_var_or_default( 16 | "REMOTE_DEFAULT_INFERENCE_WEIGHTS_PATH", 17 | "remote/path/to/your/weights/here", 18 | ), 19 | remote_files=get_fp16_file_list(2), 20 | ) 21 | 22 | LOAD_IN_4BIT = False 23 | TOKENIZER_PATH = f"models/{MODEL_NAME}/model_artifacts/default_inference_weights" 24 | USE_SYSTEM_PROMPT = False 25 | USE_EXLLAMA_FOR_UNTRAINED_WEIGHTS = False 26 | 27 | # Engine config 28 | 29 | from src.inference_engines.vllm_engine import vLLMEngine 30 | 31 | 32 | ENGINE = vLLMEngine 33 | ENGINE_KWARGS = vllm_kwargs(weights) 34 | 35 | 36 | # TRAINED INFERENCE CONFIGURATION 37 | # ------------------------------- 38 | # This section defines the inference configuration for fine-tuned models 39 | # ------------------------------- 40 | 41 | LOCAL_TRAINING_WEIGHTS_PATH = f"models/{MODEL_NAME}/model_artifacts/training_weights" 42 | 43 | REMOTE_TRAINING_WEIGHTS_PATH = get_env_var_or_default( 44 | var_name="REMOTE_TRAINING_WEIGHTS_PATH", 45 | default_value="remote/path/to/your/weights/here", 46 | ) 47 | 48 | LOCAL_TRAINING_WEIGHTS_CONFIG_PATH = ( 49 | f"models/{MODEL_NAME}/model_artifacts/training_weights/config.json" 50 | ) 51 | 52 | REMOTE_TRAINING_WEIGHTS_CONFIG_PATH = get_env_var_or_default( 53 | var_name="REMOTE_TRAINING_WEIGHTS_CONFIG_PATH", 54 | default_value="remote/path/to/your/weights/here", 55 | ) 56 | 57 | REMOTE_TRAINING_FILES_TO_DOWNLOAD = get_fp16_file_list(2) 58 | 59 | 60 | # ------------------------------- 61 | 62 | DEFAULT_PAD_TOKEN = "[PAD]" 63 | DEFAULT_EOS_TOKEN = "" 64 | DEFAULT_BOS_TOKEN = "" 65 | DEFAULT_UNK_TOKEN = "" 66 | -------------------------------------------------------------------------------- /models/llama-2-7b/.dockerignore: -------------------------------------------------------------------------------- 1 | ../dockerignore -------------------------------------------------------------------------------- /models/llama-2-7b/.env: 
-------------------------------------------------------------------------------- 1 | REMOTE_DEFAULT_INFERENCE_WEIGHTS_PATH=https://weights.replicate.delivery/default/Llama-2-7B-GPTQ 2 | REMOTE_VLLM_INFERENCE_WEIGHTS_PATH=https://weights.replicate.delivery/default/llama-2-7b 3 | REMOTE_TRAINING_WEIGHTS_PATH=https://weights.replicate.delivery/default/llama-2-7b 4 | REMOTE_TRAINING_WEIGHTS_CONFIG_PATH=https://weights.replicate.delivery/default/llama-2-7b/config.json 5 | LICENSE_URL=https://weights.replicate.delivery/default/Llama-2-7B-GPTQ/LICENSE 6 | -------------------------------------------------------------------------------- /models/llama-2-7b/config.py: -------------------------------------------------------------------------------- 1 | from dotenv import load_dotenv 2 | from src.config_utils import ( 3 | Weights, 4 | exllama_kwargs, 5 | get_fp16_file_list, 6 | get_gptq_file_list, 7 | vllm_kwargs, 8 | ) 9 | from src.inference_engines.vllm_exllama_engine import ExllamaVllmEngine 10 | 11 | from src.utils import get_env_var_or_default 12 | 13 | load_dotenv() 14 | 15 | MODEL_NAME = "llama-2-7b" 16 | 17 | 18 | # Inference weights 19 | 20 | exllama_weights = Weights( 21 | local_path=f"models/{MODEL_NAME}/model_artifacts/default_inference_weights", 22 | remote_path=get_env_var_or_default("REMOTE_DEFAULT_INFERENCE_WEIGHTS_PATH", None), 23 | remote_files=get_gptq_file_list("gptq_model-4bit-128g.safetensors"), 24 | ) 25 | 26 | vllm_weights = Weights( 27 | local_path=f"models/{MODEL_NAME}/model_artifacts/lora_inference_weights", 28 | remote_path=get_env_var_or_default("REMOTE_VLLM_INFERENCE_WEIGHTS_PATH", None), 29 | remote_files=get_fp16_file_list(2), 30 | ) 31 | 32 | # Inference config 33 | 34 | TOKENIZER_PATH = f"models/{MODEL_NAME}/model_artifacts/default_inference_weights" 35 | USE_SYSTEM_PROMPT = False 36 | 37 | ENGINE = ExllamaVllmEngine 38 | exllama_kw = exllama_kwargs(exllama_weights) 39 | vllm_kw = vllm_kwargs(vllm_weights) 40 | 41 | ENGINE_KWARGS = { 42 | "exllama_args": exllama_kw, 43 | "vllm_args": vllm_kw, 44 | } 45 | 46 | 47 | # Training config 48 | 49 | LOAD_IN_4BIT = False 50 | 51 | LOCAL_TRAINING_WEIGHTS_PATH = f"models/{MODEL_NAME}/model_artifacts/training_weights" 52 | REMOTE_TRAINING_WEIGHTS_PATH = get_env_var_or_default( 53 | var_name="REMOTE_TRAINING_WEIGHTS_PATH", 54 | default_value="remote/path/to/your/weights/here", 55 | ) 56 | LOCAL_TRAINING_WEIGHTS_CONFIG_PATH = ( 57 | f"models/{MODEL_NAME}/model_artifacts/training_weights/config.json" 58 | ) 59 | REMOTE_TRAINING_WEIGHTS_CONFIG_PATH = get_env_var_or_default( 60 | var_name="REMOTE_TRAINING_WEIGHTS_CONFIG_PATH", 61 | default_value="remote/path/to/your/weights/here", 62 | ) 63 | REMOTE_TRAINING_FILES_TO_DOWNLOAD = get_fp16_file_list(2) 64 | 65 | # ------------------------------- 66 | 67 | DEFAULT_PAD_TOKEN = "[PAD]" 68 | DEFAULT_EOS_TOKEN = "" 69 | DEFAULT_BOS_TOKEN = "" 70 | DEFAULT_UNK_TOKEN = "" 71 | -------------------------------------------------------------------------------- /models/mistral-7b-instruct-v0.1-mlc/.env: -------------------------------------------------------------------------------- 1 | REMOTE_DEFAULT_INFERENCE_WEIGHTS_PATH=https://weights.replicate.delivery/default/mistral-7b-v0.1-instruct-mlc 2 | REMOTE_VLLM_INFERENCE_WEIGHTS_PATH=https://weights.replicate.delivery/default/mistral-7b-instruct-0.1 3 | REMOTE_TRAINING_WEIGHTS_PATH=https://weights.replicate.delivery/default/mistral-7b-instruct-0.1 4 | 
REMOTE_TRAINING_WEIGHTS_CONFIG_PATH=https://weights.replicate.delivery/default/mistral-7b-instruct-0.1/config.json 5 | LICENSE_URL=https://weights.replicate.delivery/default/Llama-2-7B-GPTQ/LICENSE 6 | -------------------------------------------------------------------------------- /models/mistral-7b-instruct-v0.1-mlc/config.py: -------------------------------------------------------------------------------- 1 | from dotenv import load_dotenv 2 | from src.config_utils import ( 3 | Weights, 4 | get_fp16_file_list, 5 | get_mlc_file_list, 6 | mlc_kwargs, 7 | vllm_kwargs, 8 | ) 9 | from src.inference_engines.mlc_vllm_engine import MLCvLLMEngine 10 | from src.utils import get_env_var_or_default 11 | 12 | load_dotenv() 13 | 14 | MODEL_NAME = "mistral-7b-instruct-v0.1-mlc" 15 | 16 | # Inference weights 17 | mlc_file_list = get_mlc_file_list(model_name="Mistral-7B-Instruct-v0.1-q4f16_1", n_shards=107) 18 | 19 | LOCAL_PATH = f"models/{MODEL_NAME}/model_artifacts/default_inference_weights" 20 | 21 | mlc_weights = Weights( 22 | local_path=LOCAL_PATH, 23 | remote_path=get_env_var_or_default("REMOTE_DEFAULT_INFERENCE_WEIGHTS_PATH", None), 24 | remote_files=mlc_file_list, 25 | ) 26 | 27 | vllm_weights = Weights( 28 | local_path=f"models/{MODEL_NAME}/model_artifacts/lora_inference_weights", 29 | remote_path=get_env_var_or_default("REMOTE_VLLM_INFERENCE_WEIGHTS_PATH", None), 30 | remote_files=get_fp16_file_list(2), 31 | ) 32 | 33 | # Inference config 34 | USE_SYSTEM_PROMPT = True 35 | 36 | # from mistral: "[INST] + Instruction [/INST] Model answer[INST] Follow-up instruction [/INST]" 37 | PROMPT_TEMPLATE = "[INST] {system_prompt}{prompt} [/INST]" 38 | DEFAULT_SYSTEM_PROMPT = "Always assist with care, respect, and truth. Respond with utmost utility yet securely. Avoid harmful, unethical, prejudiced, or negative content. Ensure replies promote fairness and positivity. 
" 39 | 40 | ENGINE = MLCvLLMEngine 41 | ENGINE_KWARGS = { 42 | "mlc_args": mlc_kwargs(mlc_weights, is_chat=False), 43 | "vllm_args": vllm_kwargs(vllm_weights), 44 | } 45 | 46 | # Training config 47 | LOAD_IN_4BIT = False 48 | 49 | LOCAL_TRAINING_WEIGHTS_PATH = f"models/{MODEL_NAME}/model_artifacts/training_weights" 50 | REMOTE_TRAINING_WEIGHTS_PATH = get_env_var_or_default( 51 | "REMOTE_TRAINING_WEIGHTS_PATH", 52 | None, 53 | ) 54 | LOCAL_TRAINING_WEIGHTS_CONFIG_PATH = ( 55 | f"models/{MODEL_NAME}/model_artifacts/training_weights/config.json" 56 | ) 57 | REMOTE_TRAINING_WEIGHTS_CONFIG_PATH = get_env_var_or_default( 58 | var_name="REMOTE_TRAINING_WEIGHTS_CONFIG_PATH", 59 | default_value=None, 60 | ) 61 | REMOTE_TRAINING_FILES_TO_DOWNLOAD = get_fp16_file_list(2) 62 | -------------------------------------------------------------------------------- /models/mistral-7b-v0.1-mlc/.env: -------------------------------------------------------------------------------- 1 | REMOTE_DEFAULT_INFERENCE_WEIGHTS_PATH=https://weights.replicate.delivery/default/mistral-7b-v0.1-base-mlc 2 | REMOTE_VLLM_INFERENCE_WEIGHTS_PATH=https://weights.replicate.delivery/default/mistral-7b-0.1 3 | REMOTE_TRAINING_WEIGHTS_PATH=https://weights.replicate.delivery/default/mistral-7b-0.1 4 | REMOTE_TRAINING_WEIGHTS_CONFIG_PATH=https://weights.replicate.delivery/default/mistral-7b-0.1/config.json 5 | LICENSE_URL=https://weights.replicate.delivery/default/Llama-2-7B-GPTQ/LICENSE 6 | -------------------------------------------------------------------------------- /models/mistral-7b-v0.1-mlc/config.py: -------------------------------------------------------------------------------- 1 | from dotenv import load_dotenv 2 | from src.config_utils import ( 3 | Weights, 4 | get_fp16_file_list, 5 | get_mlc_file_list, 6 | mlc_kwargs, 7 | vllm_kwargs, 8 | ) 9 | from src.inference_engines.mlc_vllm_engine import MLCvLLMEngine 10 | from src.utils import get_env_var_or_default 11 | 12 | load_dotenv() 13 | 14 | MODEL_NAME = "mistral-7b-v0.1-mlc" 15 | 16 | # Inference weights 17 | mlc_file_list = get_mlc_file_list(model_name="Mistral-7B-v0.1-q4f16_1", n_shards=107) 18 | 19 | LOCAL_PATH = f"models/{MODEL_NAME}/model_artifacts/default_inference_weights" 20 | 21 | mlc_weights = Weights( 22 | local_path=LOCAL_PATH, 23 | remote_path=get_env_var_or_default("REMOTE_DEFAULT_INFERENCE_WEIGHTS_PATH", None), 24 | remote_files=mlc_file_list, 25 | ) 26 | 27 | vllm_weights = Weights( 28 | local_path=f"models/{MODEL_NAME}/model_artifacts/lora_inference_weights", 29 | remote_path=get_env_var_or_default("REMOTE_VLLM_INFERENCE_WEIGHTS_PATH", None), 30 | remote_files=get_fp16_file_list(2), 31 | ) 32 | 33 | # Inference config 34 | USE_SYSTEM_PROMPT = False 35 | 36 | ENGINE = MLCvLLMEngine 37 | ENGINE_KWARGS = { 38 | "mlc_args": mlc_kwargs(mlc_weights, is_chat=False), 39 | "vllm_args": vllm_kwargs(vllm_weights), 40 | } 41 | 42 | # Training config 43 | LOAD_IN_4BIT = False 44 | 45 | LOCAL_TRAINING_WEIGHTS_PATH = f"models/{MODEL_NAME}/model_artifacts/training_weights" 46 | REMOTE_TRAINING_WEIGHTS_PATH = get_env_var_or_default( 47 | "REMOTE_TRAINING_WEIGHTS_PATH", 48 | None, 49 | ) 50 | LOCAL_TRAINING_WEIGHTS_CONFIG_PATH = ( 51 | f"models/{MODEL_NAME}/model_artifacts/training_weights/config.json" 52 | ) 53 | REMOTE_TRAINING_WEIGHTS_CONFIG_PATH = get_env_var_or_default( 54 | var_name="REMOTE_TRAINING_WEIGHTS_CONFIG_PATH", 55 | default_value=None, 56 | ) 57 | REMOTE_TRAINING_FILES_TO_DOWNLOAD = get_fp16_file_list(2) 58 | 
-------------------------------------------------------------------------------- /notes/new_model_notes.md: -------------------------------------------------------------------------------- 1 | # `cog-llama-template` Model Management 2 | 3 | The `cog-llama-template` repo decomposes model management into four constructs: 4 | 5 | * **Templates.** We store templates in the `./model_templates/` directory. For our purposes, a template includes the following model-specific artifacts: `cog.yaml`, `config.py`, `predict.py`. 6 | 7 | * **Models.** We store artifacts for initialized models in the `./models/` directory. These artifacts are copied from a template and then updated with model-specific information. 8 | 9 | * **Shared code.** Models defined in `cog-llama-template` share code, e.g. implementations of training and inference methods. Shared code is maintained in the `./src/` directory. 10 | 11 | * **Active model.** To build, run, or push a specific model, its artifacts must be copied from its associated `./models/` directory to the root of this project. We do this so that `./src/` code is available at build time. We refer to this copying process as model *selection*. 12 | 13 | To help users manage and interact with these constructs, we provide a `Makefile` with commands that streamline the model development process. Below is a step-by-step demonstration of how you can use the `Makefile` to develop a model. 14 | 15 | **1. Initialize a new model.** 16 | 17 | You can initialize a new model by setting the environment variable `SELECTED_MODEL` to the name of the model you want to initialize. The name is arbitrary and there are no forced naming conventions; however, our in-house style is lowercase dash-case. 18 | 19 | The `SELECTED_MODEL` environment variable will be referenced for all subsequent make commands. However, you can also specify the argument `name=<model_name>` instead of setting an environment variable. 20 | 21 | Finally, `make init` will copy a model template from `model_templates` to `./models/<model_name>/`. 22 | 23 | ``` 24 | export SELECTED_MODEL=llama-2-70b-chat 25 | make init 26 | ``` 27 | 28 | **2. Update model details.** 29 | 30 | Currently, you need to manually update model details in `config.py`, as well as possibly in `predict.py`. Specifically, you need to provide values for the global config variables that determine inference logic and the files that should be downloaded. 31 | 32 | We assume that model artifacts are stored in an accessible, external location. During `setup` or training initialization, the model artifacts specified in `config.py` will be downloaded. 33 | 34 | However, in some cases, it is preferable not to expose the locations of model artifacts in `config.py`. In such cases, you can store that information in a `.env` file in your model's directory. At runtime, those environment variables will be initialized and their values will be used by `config.py`. 35 | 36 | For example, we store paths to model artifacts in `.env` and load them at runtime. 37 | 38 | **3. Select model.** 39 | 40 | To interact with a model, its artifacts need to be copied to the root of `cog-llama-template`. You can do this like so: 41 | 42 | ```make select``` 43 | 44 | or 45 | 46 | ```make select model=<model_name>``` 47 | 48 | This will copy the model artifacts to the root and run `cog build`. 49 | 50 | **Local testing.** 51 | 52 | Our `Makefile` provides easy access to a rudimentary test suite that supports local and staged testing.
53 | 54 | Assuming you've set the `SELECTED_MODEL` environment variable, you can just call: 55 | 56 | `make test-local` 57 | 58 | Appending `verbose=true` will run tests with `-s` so that output will be printed. 59 | 60 | **Staging.** 61 | 62 | We also provide a staging workflow via `make stage` and `make test-stage-<...>`. To use the staging commands, you must specify your Replicate user account (we default to `replicate-internal`) and create a Replicate model in the specified account with the naming convention `staging-<$SELECTED_MODEL>`. Accordingly, if your selected model is `llama-2-7b`, you would create a model called `staging-llama-2-7b`. 63 | 64 | You also need to log in via `cog login` and set the `REPLICATE_API_TOKEN` environment variable to your account's API token. 65 | 66 | Calling `make stage` will push the selected model to the associated staging model. Then you can call `make test-stage`. 67 | 68 | 69 | 70 | 71 | 72 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools", "wheel"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "cog-llama-template" 7 | version = "0.0.0" 8 | optional-dependencies = { dev = ["ruff>=0.1.3"] } 9 | -------------------------------------------------------------------------------- /requirements-dev.txt: -------------------------------------------------------------------------------- 1 | # 2 | # This file is autogenerated by pip-compile with Python 3.11 3 | # by the following command: 4 | # 5 | # pip-compile --extra=dev --output-file=requirements-dev.txt --resolver=backtracking pyproject.toml 6 | # 7 | ruff==0.1.3 8 | # via cog-llama-template (pyproject.toml) 9 | -------------------------------------------------------------------------------- /scripts/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/replicate/cog-llama-template/845e24f626bba67da586e8cb4c5793ea930e5a38/scripts/.DS_Store -------------------------------------------------------------------------------- /scripts/benchmark_token_latency.py: -------------------------------------------------------------------------------- 1 | import time 2 | import json 3 | import random 4 | import torch 5 | import argparse 6 | from abc import ABC, abstractmethod 7 | 8 | # Number of runs for each combination of model, prompt length, and output length.
9 | num_runs = 5 10 | 11 | 12 | class AbstractInferenceModel(ABC): 13 | @abstractmethod 14 | def __init__(self, model_name_or_path, tokenizer_name_or_path): 15 | self.model_name_or_path = model_name_or_path 16 | self.tokenizer_name_or_path = tokenizer_name_or_path 17 | self.model = self._load_model() 18 | self.tokenizer = self._load_tokenizer() 19 | 20 | @abstractmethod 21 | def _load_model(self): 22 | pass 23 | 24 | @abstractmethod 25 | def _load_tokenizer(self): 26 | pass 27 | 28 | @abstractmethod 29 | def generate_tokens(self, input_ids, prompt_length, output_length): 30 | pass 31 | 32 | 33 | class LlamaBnB4Bit(AbstractInferenceModel): 34 | def __init__(self, model_name_or_path, tokenizer_name_or_path, some_other_arg): 35 | super().__init__(model_name_or_path, tokenizer_name_or_path) 36 | 37 | def _load_model(self): 38 | from transformers import LlamaForCausalLM 39 | 40 | model = LlamaForCausalLM.from_pretrained( 41 | self.model_name_or_path, 42 | cache_dir="pretrained_weights", 43 | device_map={"": 0}, 44 | load_in_4bit=True, 45 | ) 46 | 47 | return model 48 | 49 | def _load_tokenizer(self): 50 | from transformers import LlamaTokenizer 51 | 52 | DEFAULT_PAD_TOKEN = "[PAD]" 53 | DEFAULT_EOS_TOKEN = "" 54 | DEFAULT_BOS_TOKEN = "" 55 | DEFAULT_UNK_TOKEN = "" 56 | 57 | tok = LlamaTokenizer.from_pretrained(self.tokenizer_name_or_path, legacy=False) 58 | tok.add_special_tokens( 59 | { 60 | "eos_token": DEFAULT_EOS_TOKEN, 61 | "bos_token": DEFAULT_BOS_TOKEN, 62 | "unk_token": DEFAULT_UNK_TOKEN, 63 | "pad_token": DEFAULT_PAD_TOKEN, 64 | } 65 | ) 66 | return tok 67 | 68 | def generate_tokens(self, input_ids, prompt_length, output_length): 69 | generated = self.model.generate( 70 | input_ids, max_length=prompt_length + output_length, do_sample=False 71 | ) 72 | return generated 73 | 74 | 75 | def measure_latency(inference_model, prompt_length, output_length): 76 | # Generate a random prompt 77 | prompt = " ".join([random.choice("a") for _ in range(prompt_length)]) 78 | 79 | # Tokenize the prompt 80 | input_ids = inference_model.tokenizer.encode(prompt, return_tensors="pt") 81 | 82 | # Set the random seed for reproducibility 83 | torch.manual_seed(0) 84 | 85 | # Maximum number of attempts to generate the correct number of tokens. 86 | max_attempts = 10 87 | 88 | # Generate response and ensure the response length is as expected 89 | for _ in range(max_attempts): 90 | # Time the model's response 91 | start_time = time.time() 92 | 93 | output = inference_model.generate_tokens( 94 | input_ids, prompt_length, output_length 95 | ) 96 | 97 | end_time = time.time() 98 | elapsed_time = end_time - start_time 99 | 100 | if len(output[0]) == prompt_length + output_length: 101 | break 102 | else: 103 | raise RuntimeError( 104 | f"Failed to generate output with correct length after {max_attempts} attempts." 
105 | ) 106 | 107 | tokens_per_second = output_length / elapsed_time 108 | 109 | return tokens_per_second 110 | 111 | 112 | def benchmark_model(model_name, inference_model, prompt_lengths, output_lengths): 113 | results = {} 114 | results[model_name] = {} 115 | 116 | for prompt_length in prompt_lengths: 117 | for output_length in output_lengths: 118 | latencies = [] 119 | 120 | print( 121 | f"\n--- Benchmarking Model: {model_name}, Prompt Length: {prompt_length}, Output Length: {output_length} ---" 122 | ) 123 | for i in range(num_runs): 124 | tokens_per_second = measure_latency( 125 | inference_model, prompt_length, output_length 126 | ) 127 | latencies.append(tokens_per_second) 128 | 129 | print(f"Run {i+1} - Tokens/sec: {tokens_per_second}") 130 | 131 | avg_tokens_per_second = sum(latencies) / num_runs 132 | 133 | results[model_name][ 134 | f"{prompt_length}_{output_length}" 135 | ] = avg_tokens_per_second 136 | 137 | print(f"Average tokens/sec over {num_runs} runs: {avg_tokens_per_second}") 138 | 139 | # Write results to a JSON file 140 | with open(f"{model_name}_benchmark_results.json", "w") as f: 141 | json.dump(results, f) 142 | 143 | 144 | if __name__ == "__main__": 145 | parser = argparse.ArgumentParser(description="Benchmark a Language Model.") 146 | parser.add_argument( 147 | "--model_name", type=str, help="The name of the model to benchmark." 148 | ) 149 | parser.add_argument( 150 | "--model_name_or_path", 151 | type=str, 152 | help="Path to weights or info needed to trigger downloads.", 153 | ) 154 | parser.add_argument( 155 | "--tokenizer_name_or_path", 156 | type=str, 157 | default=None, 158 | help="The name or path of the tokenizer to use. If not provided, uses the same as the model.", 159 | ) 160 | parser.add_argument( 161 | "--prompt_lengths", 162 | nargs="+", 163 | type=int, 164 | default=[25, 50, 100, 250, 500, 1000], 165 | help="The lengths of the prompts to be used.", 166 | ) 167 | parser.add_argument( 168 | "--output_lengths", 169 | nargs="+", 170 | type=int, 171 | default=[25, 50, 100], 172 | help="The lengths of the output sequences to be generated.", 173 | ) 174 | 175 | args = parser.parse_args() 176 | 177 | tokenizer_name_or_path = args.tokenizer_name_or_path or args.model_name_or_path 178 | inference_model = LlamaBnB4Bit( 179 | args.model_name_or_path, tokenizer_name_or_path, None 180 | ) 181 | 182 | benchmark_model( 183 | args.model_name, inference_model, args.prompt_lengths, args.output_lengths 184 | ) 185 | -------------------------------------------------------------------------------- /scripts/load_secrets.sh: -------------------------------------------------------------------------------- 1 | if [ ! -d "../official-models" ]; then 2 | pushd .. 
3 | git clone git@github.com:replicate/official-models 4 | popd 5 | fi 6 | 7 | cp ../official-models/model_secrets/llama-2-13b/.env models/llama-2-13b/ 8 | cp ../official-models/model_secrets/llama-2-13b-chat/.env models/llama-2-13b-chat/ 9 | cp ../official-models/model_secrets/llama-2-70b/.env models/llama-2-70b/ 10 | cp ../official-models/model_secrets/llama-2-70b-chat/.env models/llama-2-70b-chat/ 11 | cp ../official-models/model_secrets/llama-2-7b/.env models/llama-2-7b/ 12 | cp ../official-models/model_secrets/llama-2-7b-chat/.env models/llama-2-7b-chat/ 13 | -------------------------------------------------------------------------------- /scripts/test_load_unload_lora.py: -------------------------------------------------------------------------------- 1 | import zipfile 2 | from io import BytesIO 3 | 4 | import replicate 5 | from termcolor import cprint 6 | 7 | from src.download import Downloader 8 | from src.inference_engines.vllm_engine import vLLMEngine 9 | 10 | 11 | class vLLMLoraTest: 12 | def __init__(self): 13 | # setup 14 | self.downloader = Downloader() 15 | self.sql_lora_path = ( 16 | "https://pub-df34620a84bb4c0683fae07a260df1ea.r2.dev/sql.zip" 17 | ) 18 | self.summary_lora_path = ( 19 | "https://storage.googleapis.com/dan-scratch-public/tmp/samsum-lora.zip" 20 | ) 21 | 22 | self.engine_kwargs = { 23 | "max_new_tokens": 128, 24 | "temperature": 1.0, 25 | "top_p": 0.9, 26 | "top_k": 50, 27 | } 28 | MODEL_PATH = "models/llama-2-7b-vllm/model_artifacts/default_inference_weights" 29 | self.engine = vLLMEngine( 30 | model_path=MODEL_PATH, tokenizer_path=MODEL_PATH, dtype="auto" 31 | ) 32 | self.sql_lora = self.get_lora(self.sql_lora_path) 33 | self.summary_lora = self.get_lora(self.summary_lora_path) 34 | 35 | def get_lora(self, lora_path): 36 | buffer = self.downloader.sync_download_file(lora_path) 37 | with zipfile.ZipFile(buffer, "r") as zip_ref: 38 | data = {name: zip_ref.read(name) for name in zip_ref.namelist()} 39 | adapter_config, adapter_model = ( 40 | data["adapter_config.json"], 41 | BytesIO(data["adapter_model.bin"]), 42 | ) 43 | return self.engine.load_lora( 44 | adapter_config=adapter_config, adapter_model=adapter_model 45 | ) 46 | 47 | def generate_replicate(self, prompt, lora_path): 48 | output = replicate.run( 49 | "moinnadeem/vllm-engine-llama-7b:15ec772e3ae45cf5afd629a766774ad7cc2a80894d23848e840f926e8b5868c4", 50 | input={"prompt": prompt, "replicate_weights": lora_path}, 51 | ) 52 | generated_text = "" 53 | for item in output: 54 | generated_text += item 55 | return generated_text 56 | 57 | def generate(self, prompt, lora): 58 | self.engine_kwargs["prompt"] = prompt 59 | base_generation = "" 60 | if self.engine.is_lora_active(): 61 | self.engine.delete_lora() 62 | if lora: 63 | self.engine.set_lora(lora) 64 | 65 | generation = "".join(list(self.engine(**self.engine_kwargs))) 66 | return generation 67 | 68 | def run_base(self): 69 | # generate vanilla output that should be screwed up by a lora 70 | sql_prompt = "What is the meaning of life?" 71 | base_generation = self.generate_replicate(sql_prompt, "") 72 | 73 | sql_generation = self.generate_replicate(sql_prompt, self.sql_lora_path) 74 | lora_expected_generation = "What is the meaning of life?" 
75 | cprint("Philosophy output:", "blue") 76 | cprint(f"Base model output: {base_generation}", "blue") 77 | cprint(f"LoRA output: {sql_generation}", "blue") 78 | # assert base_generation != lora_expected_generation 79 | # assert sql_generation == lora_expected_generation 80 | 81 | def run_sql(self): 82 | # generate SQL 83 | sql_prompt = """You are a powerful text-to-SQL model. Your job is to answer questions about a database. You are given a question and context regarding one or more tables. 84 | 85 | You must output the SQL query that answers the question. 86 | 87 | ### Input: 88 | What is the total number of decile for the redwood school locality? 89 | 90 | ### Context: 91 | CREATE TABLE table_name_34 (decile VARCHAR, name VARCHAR) 92 | 93 | ### Response:""" 94 | 95 | base_generation = self.generate_replicate(sql_prompt, "") 96 | sql_generation = self.generate_replicate(sql_prompt, self.sql_lora_path) 97 | base_generation = base_generation.strip() 98 | sql_generation = sql_generation.strip() 99 | lora_expected_generation = ( 100 | 'SELECT COUNT(decile) FROM table_name_34 WHERE name = "redwood school"' 101 | ) 102 | cprint("SQL output:", "green") 103 | cprint(f"Base model output: {base_generation}", "green") 104 | cprint(f"LoRA output: {sql_generation}", "green") 105 | # assert base_generation != lora_expected_generation 106 | # assert sql_generation == lora_expected_generation 107 | 108 | def run_summary(self): 109 | # generate summaries 110 | summary_prompt = """[INST] <> 111 | Use the Input to provide a summary of a conversation. 112 | <> 113 | Input: 114 | Liam: did you see that new movie that just came out? 115 | Liam: "Starry Skies" I think it's called 116 | Ava: oh yeah, I heard about it 117 | Liam: it's about this astronaut who gets lost in space 118 | Liam: and he has to find his way back to earth 119 | Ava: sounds intense 120 | Liam: it was! there were so many moments where I thought he wouldn't make it 121 | Ava: i need to watch it then, been looking for a good movie 122 | Liam: highly recommend it! 123 | Ava: thanks for the suggestion Liam! 124 | Liam: anytime, always happy to share good movies 125 | Ava: let's plan to watch it together sometime 126 | Liam: sounds like a plan! [/INST]""" 127 | 128 | base_generation = self.generate_replicate(summary_prompt, "") 129 | summary_generation = self.generate_replicate( 130 | summary_prompt, self.summary_lora_path 131 | ) 132 | lora_expected_generation = ( 133 | '\nSummary: Liam recommends the movie "Starry Skies" to Ava.' 
134 | ) 135 | cprint("Summary output:", "blue") 136 | cprint(f"Base model output: {base_generation}", "blue") 137 | cprint(f"LoRA output: {summary_generation}", "blue") 138 | # assert base_generation != lora_expected_generation 139 | # assert summary_generation == lora_expected_generation 140 | 141 | 142 | if __name__ == "__main__": 143 | tester = vLLMLoraTest() 144 | # tester.run_base() 145 | # tester.run_summary() 146 | for idx in range(10): 147 | print(f"SQL Test #{idx}:") 148 | tester.run_sql() 149 | print("-" * 10) 150 | print(f"Summary Test #{idx}:") 151 | tester.run_summary() 152 | print("=" * 20) 153 | -------------------------------------------------------------------------------- /scripts/train_multi_gpu.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | python train.py \ 4 | --train_data 70k_samples_prompt.jsonl \ 5 | --num_train_epochs 1 \ 6 | --learning_rate 2e-5 \ 7 | --train_batch_size 2 \ 8 | --gradient_accumulation_steps 4 \ 9 | --logging_steps 2 \ 10 | --warmup_ratio 0.03 \ 11 | --weights /src/weights_13 -------------------------------------------------------------------------------- /scripts/train_single_gpu.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | python train.py \ 4 | --model_name_or_path google/flan-t5-base \ 5 | --data_path ./replicate_alpaca_data.json \ 6 | --num_train_epochs 3 \ 7 | --learning_rate 3e-4 \ 8 | --train_batch_size 8 \ 9 | --warmup_ratio 0.03 \ 10 | --max_steps 10 # number of steps before returning, mostly useful for testing performance 11 | -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/replicate/cog-llama-template/845e24f626bba67da586e8cb4c5793ea930e5a38/src/__init__.py -------------------------------------------------------------------------------- /src/config_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | An entirely self-contained config parsing util that should, if all goes well, dramatically simplify our configuration. 
3 | """ 4 | from typing import List, Optional 5 | 6 | from pydantic import BaseModel 7 | 8 | 9 | class Weights(BaseModel): 10 | local_path: str 11 | remote_path: str 12 | remote_files: List[str] 13 | 14 | 15 | def get_fp16_file_list(n_shards: int): 16 | """ 17 | Assumes safetensors 18 | """ 19 | base_files = [ 20 | f"model-{str(val).zfill(5)}-of-{str(n_shards).zfill(5)}.safetensors" 21 | for val in range(1, n_shards + 1) 22 | ] 23 | base_files += [ 24 | "config.json", 25 | "generation_config.json", 26 | "special_tokens_map.json", 27 | "tokenizer_config.json", 28 | "tokenizer.json", 29 | "tokenizer.model", 30 | "model.safetensors.index.json", 31 | ] 32 | return base_files 33 | 34 | 35 | def get_gptq_file_list(base_model_name: str): 36 | """ 37 | name of .safetensors varies 38 | """ 39 | base_files = [base_model_name] 40 | base_files += [ 41 | "config.json", 42 | "generation_config.json", 43 | "special_tokens_map.json", 44 | "tokenizer_config.json", 45 | "tokenizer.json", 46 | "tokenizer.model", 47 | "quantize_config.json", 48 | ] 49 | return base_files 50 | 51 | 52 | def get_mlc_file_list(model_name: str, n_shards: int): 53 | files_to_download = [ 54 | f"params/params_shard_{shard_idx}.bin" for shard_idx in range(n_shards) 55 | ] 56 | 57 | files_to_download += [ 58 | f"{model_name}-cuda.so", 59 | "mod_cache_before_build.pkl", 60 | "params/mlc-chat-config.json", 61 | "params/ndarray-cache.json", 62 | "params/tokenizer.json", 63 | "params/tokenizer_config.json", 64 | "params/tokenizer.model", 65 | "params/config.json", 66 | ] 67 | return files_to_download 68 | 69 | 70 | def exllama_kwargs(weights: Weights, config_overrides: Optional[dict] = None): 71 | exllama_default = {"weights": weights, "fused_attn": True} 72 | if config_overrides: 73 | exllama_default.update(config_overrides) 74 | return exllama_default 75 | 76 | 77 | def vllm_kwargs(weights: Weights, config_overrides: Optional[dict] = None): 78 | vllm_default = { 79 | "weights": weights, 80 | "dtype": "auto", 81 | } 82 | if config_overrides: 83 | vllm_default.update(config_overrides) 84 | return vllm_default 85 | 86 | 87 | def mlc_kwargs( 88 | weights: Weights, 89 | is_chat: bool, 90 | num_shards: int = 1, 91 | tokenizer_path: str = None, 92 | config_overrides: Optional[dict] = None, 93 | ): 94 | mlc_default = { 95 | "weights": weights, 96 | "tokenizer_path": tokenizer_path, 97 | "is_chat": is_chat, 98 | "num_shards": num_shards, 99 | } 100 | if config_overrides: 101 | mlc_default.update(config_overrides) 102 | return mlc_default 103 | -------------------------------------------------------------------------------- /src/inference_engines/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/replicate/cog-llama-template/845e24f626bba67da586e8cb4c5793ea930e5a38/src/inference_engines/__init__.py -------------------------------------------------------------------------------- /src/inference_engines/engine.py: -------------------------------------------------------------------------------- 1 | import time 2 | from abc import ABC, abstractmethod 3 | from typing import Any 4 | 5 | from src.config_utils import Weights 6 | from src.utils import maybe_download_with_pget 7 | 8 | 9 | class Engine(ABC): 10 | """ 11 | WIP - this is what the engine looks like at the moment, outlining this just as an exercise to see what our ABC looks like. It will change. 
12 | """ 13 | 14 | def load_weights(self, weights: Weights): 15 | start = time.time() 16 | maybe_download_with_pget( 17 | weights.local_path, weights.remote_path, weights.remote_files 18 | ) 19 | print(f"downloading weights took {time.time() - start:.3f}s") 20 | return weights.local_path 21 | 22 | @abstractmethod 23 | def load_lora(self, lora_data: dict): 24 | """ 25 | loads a lora from files into the format that this particular engine expects. DOES NOT prepare the engine for inference. 26 | lora_data is a dictionary of file names & references from the zip file 27 | """ 28 | pass 29 | 30 | @abstractmethod 31 | def set_lora(self, lora: Any): 32 | """ 33 | given a loaded lora (created w/load_lora), configures the engine to use that lora in combination with the loaded base weights. 34 | """ 35 | pass 36 | 37 | @abstractmethod 38 | def is_lora_active(self) -> bool: 39 | """ 40 | Checks whether a LoRA has currently been loaded onto the engine. 41 | """ 42 | pass 43 | 44 | @abstractmethod 45 | def delete_lora(self): 46 | """ 47 | Deletes a LoRA. 48 | """ 49 | pass 50 | 51 | @abstractmethod 52 | def __call__(self, prompt, **kwargs): 53 | """ 54 | generation! 55 | """ 56 | pass 57 | -------------------------------------------------------------------------------- /src/inference_engines/exllama.py: -------------------------------------------------------------------------------- 1 | import io 2 | import os 3 | import sys 4 | import glob 5 | 6 | import torch 7 | import time 8 | import typing as tp 9 | 10 | from src.config_utils import Weights 11 | 12 | exllama_path = os.path.abspath("exllama") 13 | sys.path.insert(0, exllama_path) 14 | 15 | from exllama.model import ExLlama, ExLlamaCache, ExLlamaConfig 16 | from exllama.lora import ExLlamaLora 17 | from exllama.tokenizer import ExLlamaTokenizer 18 | from exllama.generator import ExLlamaGenerator 19 | 20 | from src.inference_engines.engine import Engine 21 | from ..utils import StreamingTextStopSequenceHandler 22 | 23 | torch.cuda._lazy_init() 24 | torch.set_printoptions(precision=10) 25 | 26 | 27 | def next_logits( 28 | generator, input_ids, apply_lora=None, last_id_only=True, input_mask=None 29 | ): 30 | n_logits = generator.model.forward( 31 | input_ids, generator.cache, last_id_only, lora=apply_lora, input_mask=input_mask 32 | ) 33 | return n_logits 34 | 35 | 36 | def begin(generator): 37 | if generator.cache is None: 38 | generator.cache = ExLlamaCache(generator.model) 39 | else: 40 | generator.cache.current_seq_len = 0 41 | return generator 42 | 43 | 44 | def timer(name, func): 45 | t = time.time() 46 | ret = func() 47 | t = time.time() - t 48 | print(f" ** Time, {name}: {t:.2f} seconds") 49 | return ret 50 | 51 | 52 | class ExllamaEngine(Engine): 53 | def __init__(self, weights: Weights, fused_attn=True): 54 | model_directory = self.load_weights(weights) 55 | tokenizer_path = os.path.join(model_directory, "tokenizer.model") 56 | model_config_path = os.path.join(model_directory, "config.json") 57 | st_pattern = os.path.join(model_directory, "*.safetensors") 58 | model_path = glob.glob(st_pattern)[0] 59 | 60 | config = ExLlamaConfig(model_config_path) # create config from config.json 61 | config.model_path = model_path # supply path to model weights file 62 | 63 | # Override exllam's default settings to use full llama v2 context 64 | config.max_seq_len = 2 * 2048 65 | config.max_input_len = 2 * 2048 66 | config.max_attention_size = 2 * 2048**2 67 | config.fused_attn = fused_attn 68 | 69 | self.model = model = ExLlama( 70 | config 71 | ) # create 
ExLlama instance and load the weights 72 | tokenizer = ExLlamaTokenizer( 73 | tokenizer_path 74 | ) # create tokenizer from tokenizer model file 75 | 76 | cache = ExLlamaCache(model) # create cache for inference 77 | generator = ExLlamaGenerator(model, tokenizer, cache) # create generator 78 | 79 | # warmup kernels 80 | 81 | warmup_ids = torch.randint(0, 31999, (1, 50)).cuda() 82 | print("warming up exllama kernels...") 83 | for i in range(1, 3): 84 | print(f" -- Warmup pass {i}...") 85 | begin(generator) 86 | logits = timer("Warmup", lambda: next_logits(generator, warmup_ids, None)) 87 | 88 | self.generator = begin(generator) 89 | 90 | def delete_lora(self): 91 | self.generator.lora = None 92 | return 93 | 94 | def is_lora_active(self) -> bool: 95 | return self.generator.lora is not None 96 | 97 | def load_lora(self, data_ref: dict) -> ExLlamaLora: 98 | return ExLlamaLora( 99 | self.model, 100 | data_ref["adapter_config.json"], 101 | io.BytesIO(data_ref["adapter_model.bin"]), 102 | ) 103 | 104 | def set_lora(self, lora: ExLlamaLora | None) -> None: 105 | self.generator.lora = lora 106 | 107 | def __call__( 108 | self, 109 | prompt: str, 110 | repetition_penalty: float = 1.15, 111 | repetition_penalty_sustain: int = 256, 112 | token_repetition_penalty_decay: float = 128, 113 | temperature: float = 0.95, 114 | top_p: float = 0.65, 115 | top_k: int = 20, 116 | max_new_tokens: int = 128, 117 | min_new_tokens: int = 0, 118 | beams: int = 1, 119 | beam_length: int = 1, 120 | stop_sequences: tp.List[str] = None, 121 | ): 122 | if top_k <= 0: 123 | top_k = 20 124 | generator = begin(self.generator) 125 | generator.settings.token_repetition_penalty_max = repetition_penalty 126 | generator.settings.token_repetition_penalty_sustain = repetition_penalty_sustain 127 | generator.settings.token_repetition_penalty_decay = ( 128 | token_repetition_penalty_decay 129 | ) 130 | generator.settings.temperature = temperature 131 | generator.settings.top_p = top_p 132 | generator.settings.top_k = top_k 133 | generator.settings.beams = beams 134 | generator.settings.beam_length = beam_length 135 | 136 | in_tokens = generator.tokenizer.encode(prompt) 137 | n_in_tokens = in_tokens.shape[-1] 138 | if n_in_tokens >= generator.model.config.max_input_len: 139 | raise ValueError( 140 | f"Your input is too long. Max input length is {generator.model.config.max_input_len} tokens, but you supplied {n_in_tokens} tokens."
141 | ) 142 | 143 | max_new_tokens = min( 144 | max_new_tokens, generator.model.config.max_seq_len - n_in_tokens 145 | ) 146 | 147 | num_res_tokens = in_tokens.shape[-1] # Decode from here 148 | 149 | generator.gen_begin(in_tokens) 150 | generator.begin_beam_search() 151 | 152 | stop_sequence_handler = StreamingTextStopSequenceHandler( 153 | stop_sequences=stop_sequences, 154 | eos_token=generator.tokenizer.eos_token, 155 | ) 156 | 157 | for i in range(max_new_tokens): 158 | if i < min_new_tokens: 159 | generator.disallow_tokens( 160 | [ 161 | generator.tokenizer.newline_token_id, 162 | generator.tokenizer.eos_token_id, 163 | ] 164 | ) 165 | else: 166 | generator.disallow_tokens(None) 167 | 168 | gen_token = generator.beam_search() 169 | if gen_token.item() == generator.tokenizer.eos_token_id: 170 | break 171 | 172 | if gen_token.item() == generator.tokenizer.eos_token_id: 173 | generator.replace_last_token(generator.tokenizer.newline_token_id) 174 | 175 | num_res_tokens += 1 176 | text = generator.tokenizer.decode( 177 | generator.sequence_actual[:, -num_res_tokens:][0] 178 | ) 179 | new_text = text[len(prompt):] 180 | 181 | if len(new_text.replace("�", "")) == 0: 182 | # if we're getting �, then we're halfway through an emoji; ignore it til it's fully generated. 183 | continue 184 | skip_space = prompt.endswith(("\n", "[/INST]")) and new_text.startswith( 185 | " " 186 | ) # Bit prettier console output 187 | prompt += new_text 188 | if skip_space: 189 | new_text = new_text[1:] 190 | 191 | yielded_text = None 192 | for yielded_text in stop_sequence_handler(new_text): 193 | if yielded_text == stop_sequence_handler.eos_token: 194 | break 195 | yield yielded_text 196 | 197 | if yielded_text == stop_sequence_handler.eos_token: 198 | break 199 | 200 | for yielded_text in stop_sequence_handler.finalize(): 201 | yield yielded_text 202 | -------------------------------------------------------------------------------- /src/inference_engines/mlc_engine.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from cog import ConcatenateIterator 4 | from mlc_chat import ChatConfig, ChatModule, ConvConfig, GenerationConfig 5 | from transformers import AutoTokenizer 6 | 7 | from src.config_utils import Weights 8 | 9 | from .engine import Engine 10 | 11 | class MLCEngine(Engine): 12 | """ 13 | An inference engine that runs inference w/ vLLM 14 | """ 15 | 16 | def __init__( 17 | self, weights: Weights, is_chat: bool, num_shards: int = 1, tokenizer_path: os.PathLike = None 18 | ) -> None: 19 | weights_path = self.load_weights(weights) 20 | self.is_chat = is_chat 21 | self.num_shards = num_shards 22 | 23 | if self.is_chat: 24 | self.conv_template = "llama-2" 25 | self.stop_str = "" 26 | self.stop_tokens = [] 27 | self.add_bos = None 28 | else: 29 | self.conv_template = "LM" 30 | self.stop_str = "[INST]" 31 | self.stop_tokens = [ 32 | 2, 33 | ] 34 | self.add_bos = True 35 | 36 | conv_config = ConvConfig( 37 | stop_tokens=self.stop_tokens, add_bos=self.add_bos, stop_str=self.stop_str 38 | ) 39 | chat_config = ChatConfig( 40 | conv_config=conv_config, conv_template=self.conv_template, num_shards=self.num_shards 41 | ) 42 | 43 | model_path = os.path.join(weights_path, "params") 44 | self.cm = ChatModule(model=model_path, chat_config=chat_config) 45 | 46 | # this isn't used! 
47 | tokenizer_path = os.path.join(weights_path, "params") 48 | self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path) 49 | 50 | def load_weights(self, weights: Weights) -> str: 51 | """ 52 | Downloads the weights from the given Weights object and returns the path to the downloaded weights. 53 | 54 | Args: 55 | - weights (Weights): the weights to download. 56 | 57 | Returns: 58 | - weights_path (str): the path to the downloaded weights. 59 | """ 60 | # ensure directories exist 61 | for path in weights.remote_files: 62 | path_directory = os.path.dirname(path) 63 | if path_directory: 64 | path_directory = os.path.join(weights.local_path, path_directory) 65 | os.makedirs(path_directory, exist_ok=True) 66 | 67 | return super().load_weights(weights) 68 | 69 | def get_logits(self): 70 | """ 71 | Given a prompt, returns the logits from the language model. 72 | """ 73 | raise NotImplementedError("MLC currently does not support logits.") 74 | 75 | def load_lora(self): 76 | """ 77 | loads a lora from files into the format that this particular engine expects. DOES NOT prepare the engine for inference. 78 | lora_data is a dictionary of file names & references from the zip file 79 | """ 80 | raise NotImplementedError("MLC currently does not support LoRAs.") 81 | 82 | def is_lora_active(self): 83 | """ 84 | Returns True if the engine is currently configured to use a lora, False otherwise. 85 | """ 86 | raise NotImplementedError("MLC currently does not support LoRAs.") 87 | 88 | def set_lora(self): 89 | """ 90 | Given a loaded lora (created w/ load_lora), configures the engine to use that lora in combination with the loaded base weights. 91 | """ 92 | raise NotImplementedError("MLC currently does not support LoRAs.") 93 | 94 | def delete_lora(self): 95 | print("MLC is currently not using any LoRAs.") 96 | 97 | def __call__( 98 | self, 99 | prompt: str, 100 | max_new_tokens: int, 101 | temperature: float, 102 | top_p: float, 103 | top_k: int, 104 | stop_sequences: str | list[str] = None, 105 | stop_token_ids: list[int] = [], 106 | repetition_penalty: float = 1.0, 107 | incremental_generation: bool = True, 108 | *args, 109 | **kwargs, 110 | ) -> ConcatenateIterator[str]: 111 | """ 112 | Given a prompt, runs generation on the language model with vLLM. 113 | 114 | Args: 115 | - prompt (str): the prompt to give the model. 116 | - max_new_tokens (int): the maximum number of new tokens to generate. 117 | - temperature (float): the parameter to anneal the sampling distribution with. 118 | - top_p (float): the amount to truncate the sampling distribution by. 119 | - top_k (int): the number of tokens to truncate the sampling distribution by. 120 | - stop_sequences (str | list[str]): the string to stop generation at. 121 | - stop_token_ids (list[str]): a list of token ids to stop generation at. 122 | - frequency_penalty (float): the amount to penalize tokens that have already been generated, higher values penalize more. 123 | - incremental_generation: whether to yield the entire generated sequence or the next generated token at each step. 124 | 125 | Yields: 126 | - generated_text (str): the generated text, or next token, depending on the value of `incremental_generation`. 127 | """ 128 | 129 | if top_k is not None and top_k > 0: 130 | raise ValueError( 131 | "top_k is currently not supported by our generation engine." 
132 | ) 133 | 134 | stop_token_ids += self.stop_tokens 135 | # stop_sequences = [self.stop_str] + stop_sequences 136 | 137 | # TODO (Moin): add support for the system prompt on chat models 138 | conv_config = ConvConfig( 139 | stop_tokens=stop_token_ids, add_bos=self.add_bos, stop_str=stop_sequences 140 | ) 141 | chat_config = ChatConfig( 142 | temperature=temperature, 143 | repetition_penalty=repetition_penalty, 144 | top_p=top_p, 145 | max_gen_len=max_new_tokens, 146 | mean_gen_len=max_new_tokens, 147 | conv_config=conv_config, 148 | conv_template=self.conv_template, 149 | num_shards=self.num_shards 150 | ) 151 | self.cm.reset_chat(chat_config) 152 | 153 | generation_config = GenerationConfig( 154 | temperature=temperature, 155 | repetition_penalty=repetition_penalty, 156 | top_p=top_p, 157 | max_gen_len=max_new_tokens, 158 | ) 159 | self.cm._prefill(input=prompt, generation_config=generation_config) 160 | 161 | min_new_tokens = kwargs.pop("min_new_tokens", None) 162 | if min_new_tokens is not None and min_new_tokens > -1: 163 | raise ValueError( 164 | "min_new_tokens is currently not supported by MLC's engine." 165 | ) 166 | 167 | if len(kwargs) > 0: 168 | raise ValueError(f"Unknown keyword arguments: {', '.join(kwargs.keys())}") 169 | 170 | generation_length = 0 171 | while True: 172 | if self.cm._stopped(): 173 | break 174 | self.cm._decode(generation_config=generation_config) 175 | out = self.cm._get_message() 176 | # stops us from yielding half an emoji, which breaks 177 | out = out.replace("\N{Replacement Character}", "") 178 | if len(out) == generation_length: 179 | # don't yield an empty string 180 | continue 181 | yield out[generation_length:] 182 | generation_length = len(out) 183 | -------------------------------------------------------------------------------- /src/inference_engines/mlc_vllm_engine.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Optional, List 2 | import os 3 | 4 | from .engine import Engine 5 | from .vllm_engine import vLLMEngine 6 | 7 | 8 | class MLCvLLMEngine(Engine): 9 | """ 10 | MLC for base models, vllm for fine-tunes. 11 | """ 12 | 13 | def __init__(self, mlc_args: dict, vllm_args: dict) -> None: 14 | # checks for old style loras & if this is booted as a fine-tuneable hotswap 15 | if os.getenv("COG_WEIGHTS") or os.getenv("REPLICATE_HOTSWAP") == "1": 16 | self.engine = vLLMEngine(**vllm_args) 17 | else: 18 | # can't run vllm if MLC is imported 19 | from .mlc_engine import MLCEngine 20 | 21 | self.engine = MLCEngine(**mlc_args) 22 | self.vllm_args = vllm_args 23 | 24 | def load_lora(self, lora_data: dict) -> Any: 25 | """ 26 | loads a lora from files into the format that this particular engine expects. DOES NOT prepare the engine for inference. 27 | lora_data is a dictionary of file names & references from the zip file 28 | """ 29 | if not isinstance(self.engine, vLLMEngine): 30 | # Really we should never need to do this. 31 | # print("Transitioning from MLC to vLLM") 32 | # del self.engine.cm 33 | # del self.engine.tokenizer 34 | # del self.engine 35 | 36 | # gc.collect() 37 | # torch.cuda.empty_cache() 38 | # self.engine = vLLMEngine(**self.vllm_args) 39 | raise Exception("Loras not supported with MLCEngine") 40 | 41 | return self.engine.load_lora(lora_data) 42 | 43 | def is_lora_active(self) -> bool: 44 | """ 45 | Returns True if the engine is currently configured to use a lora, False otherwise. 
46 | """ 47 | if isinstance(self.engine, vLLMEngine): 48 | return self.engine.is_lora_active() 49 | return False 50 | 51 | def set_lora(self, lora: Any) -> None: 52 | """ 53 | Given a loaded lora (created w/ load_lora), configures the engine to use that lora in combination with the loaded base weights. 54 | """ 55 | if not isinstance(self.engine, vLLMEngine): 56 | raise Exception( 57 | "Loras not supported with MLC Engine! Invalid state reached." 58 | ) 59 | self.engine.set_lora(lora) 60 | 61 | def delete_lora(self) -> None: 62 | self.engine.delete_lora() 63 | 64 | def __call__( 65 | self, 66 | prompt, 67 | max_new_tokens: int = 128, 68 | min_new_tokens: int = -1, 69 | temperature: float = 0.75, 70 | top_p: float = 0.9, 71 | top_k: int = 50, 72 | stop_sequences: Optional[List[str]] = None, 73 | **kwargs, 74 | ): 75 | print(f"MLC: {not isinstance(self.engine, vLLMEngine)}") 76 | gen = self.engine( 77 | prompt, 78 | max_new_tokens=max_new_tokens, 79 | min_new_tokens=min_new_tokens, 80 | temperature=temperature, 81 | top_p=top_p, 82 | top_k=top_k, 83 | stop_sequences=stop_sequences, 84 | **kwargs, 85 | ) 86 | for val in gen: 87 | yield val 88 | -------------------------------------------------------------------------------- /src/inference_engines/transformers_engine.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | from transformers import AutoModelForCausalLM, TextIteratorStreamer, StoppingCriteria 4 | from typing import Optional, List, Tuple, Any 5 | from threading import Thread 6 | from peft import PeftModel, LoraConfig 7 | from peft.utils.save_and_load import set_peft_model_state_dict 8 | 9 | import torch.nn.init 10 | 11 | from src.config_utils import Weights 12 | 13 | torch.nn.init.kaiming_uniform_ = lambda x, *args, **kwargs: x 14 | torch.nn.init.uniform_ = lambda x, *args, **kwargs: x 15 | 16 | import torch 17 | 18 | from .engine import Engine 19 | 20 | ADAPTER_NAME = "default" 21 | 22 | 23 | class ExtraStopSequence(StoppingCriteria): 24 | """ 25 | Adds in an extra stop sequence. Assuming 1-D generation, not batch. 26 | """ 27 | 28 | # TODO: there's something silly to debug here. 29 | def __init__(self, stop_sequence: torch.Tensor, device: str): 30 | self.stop_sequence = stop_sequence.to(device) 31 | 32 | def __call__( 33 | self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs 34 | ): 35 | return torch.equal( 36 | self.stop_sequence, input_ids[:, self.stop_sequence.shape[-1]] 37 | ) 38 | 39 | 40 | class TransformersEngine(Engine): 41 | """ 42 | An inference engine that runs in vanilla transformers. 43 | Vanilla is, at times, fantastic. 
44 | """ 45 | 46 | def __init__(self, weights: Weights, tokenizer_func=None, device="cuda"): 47 | model_path = self.load_weights(weights) 48 | self.model = AutoModelForCausalLM.from_pretrained( 49 | model_path, torch_dtype=torch.bfloat16 50 | ).to(device) 51 | self.tokenizer = tokenizer_func() 52 | self.device = device 53 | print("Transformers engine initialized.") 54 | 55 | def load_lora(self, lora_weights: dict) -> Tuple[LoraConfig, Any]: 56 | """ 57 | Given a dict of {filename:bytes}, returns a tuple of (LoraConfig, Torch model) 58 | This relies on external but poorly documented peft methods, when we upgrade peft past 0.4.0 we may need to (briefly) revisit 59 | """ 60 | 61 | # serializing the dictionary of files and such - hf doesn't have quick and easy ways to load loras from file references, 62 | # and this implementation isn't built for speed anyway 63 | model_dir = "tmp/model" 64 | os.makedirs(model_dir) 65 | for handle in lora_weights: 66 | fpath = os.path.join(model_dir, handle) 67 | with open(fpath, "wb") as f: 68 | f.write(lora_weights[handle]) 69 | 70 | config = LoraConfig.from_pretrained(model_dir) 71 | weights = torch.load( 72 | os.path.join(model_dir, "adapter_model.bin"), 73 | map_location=torch.device("cuda" if torch.cuda.is_available() else "cpu"), 74 | ) 75 | shutil.rmtree(model_dir) 76 | return (config, weights) 77 | 78 | def is_lora_active(self) -> bool: 79 | return isinstance(self.model, PeftModel) 80 | 81 | def delete_lora(self): 82 | if hasattr(self.model, "disable_adapter_layers") and callable( 83 | self.model.disable_adapter_layers 84 | ): 85 | self.model.disable_adapter_layers() 86 | else: 87 | print("No loras were ever loaded, nothing to disable.") 88 | return 89 | 90 | def set_lora(self, lora): 91 | """ 92 | Sets a new lora if needed. 
93 | """ 94 | config, weights = lora 95 | 96 | # Note that right now we're just overwriting the "default" adapter w/ADAPTER_NAME 97 | # we can try managing multiple adapters w/lru eviction logic, didn't seem necessary 98 | if not isinstance(self.model, PeftModel): 99 | # is not a peft model 100 | self.model = PeftModel(self.model, config, ADAPTER_NAME) 101 | set_peft_model_state_dict(self.model, weights, ADAPTER_NAME) 102 | self.model.eval() 103 | print("added lora for the first time") 104 | else: 105 | self.model.enable_adapter_layers() 106 | self.model.add_adapter(ADAPTER_NAME, config) 107 | set_peft_model_state_dict(self.model, weights, ADAPTER_NAME) 108 | print("set new lora") 109 | print(self.model.peft_config) 110 | self.model.eval() 111 | 112 | return 113 | 114 | def get_logits(self, prompt): 115 | input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids.to( 116 | self.device 117 | ) 118 | inputs = self.model.prepare_inputs_for_generation(input_ids) 119 | with torch.no_grad(): 120 | output = self.model( 121 | **inputs, 122 | return_dict=True, 123 | output_attentions=False, 124 | output_hidden_states=False, 125 | ) 126 | logits = output.logits[:, -1, :] 127 | return logits 128 | 129 | def __call__( 130 | self, 131 | prompt, 132 | max_new_tokens: int = 128, 133 | min_new_tokens: int = -1, 134 | temperature: float = 0.75, 135 | top_p: float = 0.9, 136 | top_k: int = 50, 137 | stop_sequences: Optional[List[str]] = None, 138 | **kwargs, 139 | ): 140 | tokens_in = self.tokenizer(prompt, return_tensors="pt").input_ids.to( 141 | self.device 142 | ) 143 | streamer = TextIteratorStreamer( 144 | self.tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True 145 | ) 146 | 147 | stopping_criteria_list = None 148 | if stop_sequences is not None: 149 | # stop sequences! 150 | stopping_criteria_list = [] 151 | for seq in stop_sequences: 152 | stop_ids = self.tokenizer( 153 | seq, return_tensors="pt", add_special_tokens=False 154 | ).input_ids[0] 155 | stopping_criteria_list.append(ExtraStopSequence(stop_ids, self.device)) 156 | 157 | generate_kwargs = dict( 158 | input_ids=tokens_in, 159 | streamer=streamer, 160 | do_sample=True, 161 | max_new_tokens=max_new_tokens, 162 | min_new_tokens=min_new_tokens, 163 | temperature=temperature, 164 | top_p=top_p, 165 | top_k=top_k, 166 | stopping_criteria=stopping_criteria_list, 167 | ) 168 | 169 | t = Thread(target=self.model.generate, kwargs=generate_kwargs) 170 | t.start() 171 | 172 | for out in streamer: 173 | yield out 174 | -------------------------------------------------------------------------------- /src/inference_engines/vllm_exllama_engine.py: -------------------------------------------------------------------------------- 1 | import gc 2 | from typing import Any, Optional, List 3 | 4 | import torch 5 | import os 6 | 7 | from .engine import Engine 8 | from .vllm_engine import vLLMEngine 9 | from .exllama import ExllamaEngine 10 | 11 | 12 | class ExllamaVllmEngine(Engine): 13 | """ 14 | It's exllama until fine-tuning hits, and then it's vllm. 
15 | """ 16 | 17 | def __init__(self, vllm_args: dict, exllama_args: dict) -> None: 18 | # for old-style loras, should they happen 19 | if "COG_WEIGHTS" in os.environ or ( 20 | "REPLICATE_HOTSWAP" in os.environ and os.environ["REPLICATE_HOTSWAP"] == "1" 21 | ): 22 | self.engine = vLLMEngine(**vllm_args) 23 | else: 24 | self.engine = ExllamaEngine(**exllama_args) 25 | self.vllm_args = vllm_args 26 | 27 | def load_lora(self, lora_data: dict) -> Any: 28 | """ 29 | loads a lora from files into the format that this particular engine expects. DOES NOT prepare the engine for inference. 30 | lora_data is a dictionary of file names & references from the zip file 31 | """ 32 | if isinstance(self.engine, ExllamaEngine): 33 | # Really we should never need to do this. 34 | print("Transitioning from Exllama to vLLM") 35 | del self.engine.model 36 | del self.engine.generator 37 | del self.engine 38 | 39 | gc.collect() 40 | torch.cuda.empty_cache() 41 | self.engine = vLLMEngine(**self.vllm_args) 42 | 43 | return self.engine.load_lora(lora_data) 44 | 45 | def is_lora_active(self) -> bool: 46 | """ 47 | Returns True if the engine is currently configured to use a lora, False otherwise. 48 | """ 49 | if isinstance(self.engine, vLLMEngine): 50 | return self.engine.is_lora_active() 51 | return False 52 | 53 | def set_lora(self, lora: Any) -> None: 54 | """ 55 | Given a loaded lora (created w/ load_lora), configures the engine to use that lora in combination with the loaded base weights. 56 | """ 57 | if isinstance(self.engine, ExllamaEngine): 58 | raise Exception( 59 | "Loras not supported with Exllama Engine! Invalid state reached." 60 | ) 61 | self.engine.set_lora(lora) 62 | 63 | def delete_lora(self) -> None: 64 | self.engine.delete_lora() 65 | 66 | def __call__( 67 | self, 68 | prompt, 69 | max_new_tokens: int = 128, 70 | min_new_tokens: int = -1, 71 | temperature: float = 0.75, 72 | top_p: float = 0.9, 73 | top_k: int = 50, 74 | stop_sequences: Optional[List[str]] = None, 75 | **kwargs, 76 | ): 77 | if top_k <=0: 78 | top_k = 50 79 | print(f"Exllama: {isinstance(self.engine, ExllamaEngine)}") 80 | gen = self.engine( 81 | prompt, 82 | max_new_tokens=max_new_tokens, 83 | min_new_tokens=min_new_tokens, 84 | temperature=temperature, 85 | top_p=top_p, 86 | top_k=top_k, 87 | stop_sequences=stop_sequences, 88 | **kwargs, 89 | ) 90 | for val in gen: 91 | yield val 92 | -------------------------------------------------------------------------------- /src/inference_engines/vllm_transformers.py: -------------------------------------------------------------------------------- 1 | import gc 2 | from typing import Any, Optional, List 3 | 4 | import torch 5 | 6 | from .engine import Engine 7 | from .vllm_engine import vLLMEngine 8 | from .transformers_engine import TransformersEngine 9 | 10 | 11 | class vLLMTransformersEngine(Engine): 12 | """ 13 | It's vLLM until fine-tuning hits, and then it's transformers. 14 | """ 15 | 16 | def __init__( 17 | self, model_path: str, vllm_args: dict, transformers_args: dict 18 | ) -> None: 19 | self.engine = vLLMEngine(model_path, **vllm_args) 20 | self.model_path = model_path 21 | self.transformers_args = transformers_args 22 | 23 | def load_lora(self, lora_data: dict) -> Any: 24 | """ 25 | loads a lora from files into the format that this particular engine expects. DOES NOT prepare the engine for inference. 
26 | lora_data is a dictionary of file names & references from the zip file 27 | """ 28 | if isinstance(self.engine, vLLMEngine): 29 | print("Transitioning from vLLM to Transformers") 30 | for worker in self.engine.engine.engine.workers: # needs more engine 31 | del worker.cache_engine.gpu_cache 32 | del worker.cache_engine.cpu_cache 33 | del worker.gpu_cache 34 | del worker.model 35 | 36 | del self.engine 37 | gc.collect() 38 | torch.cuda.empty_cache() 39 | self.engine = TransformersEngine(self.model_path, **self.transformers_args) 40 | 41 | return self.engine.load_lora(lora_data) 42 | 43 | def is_lora_active(self) -> bool: 44 | """ 45 | Returns True if the engine is currently configured to use a lora, False otherwise. 46 | """ 47 | if isinstance(self.engine, TransformersEngine): 48 | return self.engine.is_lora_active() 49 | return False 50 | 51 | def set_lora(self, lora: Any) -> None: 52 | """ 53 | Given a loaded lora (created w/ load_lora), configures the engine to use that lora in combination with the loaded base weights. 54 | """ 55 | if isinstance(self.engine, vLLMEngine): 56 | raise Exception( 57 | "Loras not supported with vLLM Engine! Invalid state reached." 58 | ) 59 | self.engine.set_lora(lora) 60 | 61 | def delete_lora(self) -> None: 62 | self.engine.delete_lora() 63 | 64 | def __call__( 65 | self, 66 | prompt, 67 | max_new_tokens: int = 128, 68 | min_new_tokens: int = -1, 69 | temperature: float = 0.75, 70 | top_p: float = 0.9, 71 | top_k: int = 50, 72 | stop_sequences: Optional[List[str]] = None, 73 | **kwargs, 74 | ): 75 | gen = self.engine( 76 | prompt, 77 | max_new_tokens=max_new_tokens, 78 | min_new_tokens=min_new_tokens, 79 | temperature=temperature, 80 | top_p=top_p, 81 | top_k=top_k, 82 | stop_sequences=stop_sequences, 83 | **kwargs, 84 | ) 85 | for val in gen: 86 | yield val 87 | -------------------------------------------------------------------------------- /src/more_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | 4 | DEFAULT_PAD_TOKEN = "[PAD]" 5 | DEFAULT_EOS_TOKEN = "" 6 | DEFAULT_BOS_TOKEN = "" 7 | DEFAULT_UNK_TOKEN = "" 8 | 9 | 10 | def log_memory_stuff(prompt=None): 11 | """One method to barf out everything we'd ever want to know about memory""" 12 | import torch 13 | 14 | if prompt is not None: 15 | print(prompt) 16 | os.system("nvidia-smi") 17 | print(torch.cuda.memory_summary()) 18 | 19 | 20 | def load_tokenizer(tokenizer_path): 21 | """Same tokenizer, agnostic from tensorized weights/etc""" 22 | from transformers import LlamaTokenizer 23 | 24 | tok = LlamaTokenizer.from_pretrained( 25 | tokenizer_path, cache_dir="pretrained_weights", legacy=False 26 | ) 27 | tok.add_special_tokens( 28 | { 29 | "eos_token": DEFAULT_EOS_TOKEN, 30 | "bos_token": DEFAULT_BOS_TOKEN, 31 | "unk_token": DEFAULT_UNK_TOKEN, 32 | "pad_token": DEFAULT_PAD_TOKEN, 33 | } 34 | ) 35 | return tok 36 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/replicate/cog-llama-template/845e24f626bba67da586e8cb4c5793ea930e5a38/tests/__init__.py -------------------------------------------------------------------------------- /tests/assets/llama_tokenizer/special_tokens_map.json: -------------------------------------------------------------------------------- 1 | {} -------------------------------------------------------------------------------- 
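The llama_tokenizer assets in this directory back the unit tests further down; as a rough, self-contained sketch of how they are consumed (mirroring the tokenizer fixture in tests/unit_tests/test_completion_dataset.py, and assuming transformers plus sentencepiece are installed):

from transformers import LlamaTokenizer

# Load the checked-in test tokenizer straight from the assets directory.
tok = LlamaTokenizer.from_pretrained("tests/assets/llama_tokenizer", legacy=False)
print(tok("It was a dark and stormy night.").input_ids)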
/tests/assets/llama_tokenizer/tokenizer.model: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/replicate/cog-llama-template/845e24f626bba67da586e8cb4c5793ea930e5a38/tests/assets/llama_tokenizer/tokenizer.model -------------------------------------------------------------------------------- /tests/assets/llama_tokenizer/tokenizer_checklist.chk: -------------------------------------------------------------------------------- 1 | eeec4125e9c7560836b4873b6f8e3025 tokenizer.model 2 | -------------------------------------------------------------------------------- /tests/assets/llama_tokenizer/tokenizer_config.json: -------------------------------------------------------------------------------- 1 | {"bos_token": "", "eos_token": "", "model_max_length": 4096, "tokenizer_class": "LlamaTokenizer", "unk_token": ""} -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | def pytest_addoption(parser): 2 | parser.addoption("--model", action="store", default=None, help="Model name to test") 3 | -------------------------------------------------------------------------------- /tests/run_local_tests.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # TODO - rework this to spin up cog servers locally for prediction & training 4 | # this gives us the ability to test out post-training results (w/docker "env" vars) 5 | # I think that'll actually do it. 6 | 7 | cog predict -i prompt="Hey! How are you doing?" 8 | cog train -i train_data="https://storage.googleapis.com/dan-scratch-public/fine-tuning/1k_samples_prompt.jsonl" -i max_steps=10 9 | -------------------------------------------------------------------------------- /tests/test_e2e.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import requests 3 | import subprocess 4 | import time 5 | 6 | # Constants 7 | SERVER_URL = "http://localhost:5000/predictions" 8 | HEALTH_CHECK_URL = "http://localhost:5000/health-check" 9 | 10 | IMAGE_NAME = "your_image_name" # replace with your image name 11 | HOST_NAME = "your_host_name" # replace with your host name 12 | 13 | 14 | def wait_for_server_to_be_ready(url, timeout=300): 15 | """ 16 | Waits for the server to be ready. 17 | 18 | Args: 19 | - url: The health check URL to poll. 20 | - timeout: Maximum time (in seconds) to wait for the server to be ready. 
21 | """ 22 | start_time = time.time() 23 | while True: 24 | try: 25 | response = requests.get(url) 26 | data = response.json() 27 | 28 | if data["status"] == "READY": 29 | return 30 | elif data["status"] == "SETUP_FAILED": 31 | raise RuntimeError( 32 | "Server initialization failed with status: SETUP_FAILED" 33 | ) 34 | 35 | except requests.RequestException: 36 | pass 37 | 38 | if time.time() - start_time > timeout: 39 | raise TimeoutError("Server did not become ready in the expected time.") 40 | 41 | time.sleep(5) # Poll every 5 seconds 42 | 43 | 44 | # Starting and stopping the server as part of the setup and teardown 45 | @pytest.fixture(scope="session") 46 | def server(): 47 | # Start the server 48 | command = [ 49 | "docker", 50 | "run", 51 | "-ti", 52 | "-p", 53 | "5000:5000", 54 | "--gpus=all", 55 | "-e", 56 | f"COG_WEIGHTS=http://{HOST_NAME}:8000/training_output.zip", 57 | "-v", 58 | "`pwd`/training_output.zip:/src/local_weights.zip", 59 | IMAGE_NAME, 60 | ] 61 | process = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE) 62 | 63 | # Giving some time for the server to properly start 64 | time.sleep(10) 65 | 66 | yield process # This is where the test will execute 67 | 68 | # Stop the server 69 | process.terminate() 70 | process.wait() 71 | 72 | 73 | def test_health_check(): 74 | response = requests.get(HEALTH_CHECK_URL) 75 | assert ( 76 | response.status_code == 200 77 | ), f"Unexpected status code: {response.status_code}" 78 | 79 | 80 | def test_prediction(): 81 | data = { 82 | "input": { 83 | "prompt": "...", 84 | "max_length": "...", 85 | # Add other parameters here 86 | } 87 | } 88 | response = requests.post(SERVER_URL, json=data) 89 | assert ( 90 | response.status_code == 200 91 | ), f"Unexpected status code: {response.status_code}" 92 | # Add other assertions based on expected response 93 | 94 | 95 | # You can add more tests as per your requirements 96 | 97 | if __name__ == "__main__": 98 | pytest.main() 99 | -------------------------------------------------------------------------------- /tests/test_predict.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import requests 3 | import subprocess 4 | from threading import Thread, Lock 5 | 6 | from tests.test_utils import ( 7 | get_image_name, 8 | capture_output, 9 | wait_for_server_to_be_ready, 10 | ) 11 | 12 | # Constants 13 | SERVER_URL = "http://localhost:5000/predictions" 14 | HEALTH_CHECK_URL = "http://localhost:5000/health-check" 15 | 16 | IMAGE_NAME = "your_image_name" # replace with your image name 17 | HOST_NAME = "your_host_name" # replace with your host name 18 | 19 | 20 | @pytest.fixture(scope="session") 21 | def server(): 22 | image_name = get_image_name() 23 | 24 | command = [ 25 | "docker", 26 | "run", 27 | # "-ti", 28 | "-p", 29 | "5000:5000", 30 | "--gpus=all", 31 | image_name, 32 | ] 33 | print("\n**********************STARTING SERVER**********************") 34 | process = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE) 35 | 36 | print_lock = Lock() 37 | 38 | stdout_thread = Thread(target=capture_output, args=(process.stdout, print_lock)) 39 | stdout_thread.start() 40 | 41 | stderr_thread = Thread(target=capture_output, args=(process.stderr, print_lock)) 42 | stderr_thread.start() 43 | 44 | wait_for_server_to_be_ready(HEALTH_CHECK_URL) 45 | 46 | yield process 47 | 48 | process.terminate() 49 | process.wait() 50 | 51 | 52 | def test_health_check(server): 53 | response = requests.get(HEALTH_CHECK_URL) 54 | 
assert ( 55 | response.status_code == 200 56 | ), f"Unexpected status code: {response.status_code}" 57 | 58 | 59 | def test_simple_prediction(server): 60 | data = { 61 | "input": { 62 | "prompt": "It was a dark and stormy night.", 63 | "max_new_tokens": 25, 64 | # Add other parameters here 65 | } 66 | } 67 | response = requests.post(SERVER_URL, json=data) 68 | assert ( 69 | response.status_code == 200 70 | ), f"Unexpected status code: {response.status_code}" 71 | print("\n**********************RESPONSE**********************") 72 | print("".join(response.json()["output"])) 73 | print("******************************************************\n") 74 | # Add other assertions based on expected response 75 | 76 | 77 | def test_input_too_long(server): 78 | # This is a placeholder. You need to provide an input that is expected to be too long. 79 | data = { 80 | "input": { 81 | "prompt": " a" 82 | * 6000, # Assuming this string will produce more than 4096 tokens. 83 | "max_new_tokens": 25, 84 | # Add other parameters here 85 | } 86 | } 87 | 88 | response = requests.post(SERVER_URL, json=data) 89 | 90 | response_data = response.json() 91 | 92 | assert "error" in response_data, "Expected an 'error' field in the response" 93 | 94 | error_msg_prefix = "Your input is too long. Max input length is" 95 | assert response_data["error"].startswith( 96 | error_msg_prefix 97 | ), f"Expected the error message to start with '{error_msg_prefix}'" 98 | assert response_data["status"] == "failed", "Expected the status to be 'failed'" 99 | 100 | print("\n**********************RESPONSE**********************") 101 | print(response.text) 102 | print("******************************************************\n") 103 | 104 | 105 | if __name__ == "__main__": 106 | pytest.main() 107 | -------------------------------------------------------------------------------- /tests/test_predict_with_trained_weights.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/replicate/cog-llama-template/845e24f626bba67da586e8cb4c5793ea930e5a38/tests/test_predict_with_trained_weights.py -------------------------------------------------------------------------------- /tests/test_remote_predict.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import replicate 3 | 4 | 5 | @pytest.fixture(scope="module") 6 | def model_name(request): 7 | return request.config.getoption("--model") 8 | 9 | 10 | @pytest.fixture(scope="module") 11 | def model(model_name): 12 | return replicate.models.get(model_name) 13 | 14 | 15 | @pytest.fixture(scope="module") 16 | def version(model): 17 | versions = model.versions.list() 18 | return versions[0] 19 | 20 | 21 | @pytest.fixture(scope="module") 22 | def prediction_tests(): 23 | return [ 24 | {"prompt": "How are you doing today?"}, 25 | ] 26 | 27 | 28 | def test_initial_predictions(version, prediction_tests): 29 | predictions = [ 30 | replicate.predictions.create(version=version, input=val) 31 | for val in prediction_tests 32 | ] 33 | for val in predictions: 34 | val.wait() 35 | assert val.status == "succeeded" 36 | -------------------------------------------------------------------------------- /tests/test_remote_train.py: -------------------------------------------------------------------------------- 1 | import time 2 | import pytest 3 | import replicate 4 | 5 | 6 | @pytest.fixture(scope="module") 7 | def model_name(request): 8 | return request.config.getoption("--model") 9 | 10 | 11 | 
@pytest.fixture(scope="module") 12 | def model(model_name): 13 | return replicate.models.get(model_name) 14 | 15 | 16 | @pytest.fixture(scope="module") 17 | def version(model): 18 | versions = model.versions.list() 19 | return versions[0] 20 | 21 | 22 | @pytest.fixture(scope="module") 23 | def training(model_name, version): 24 | training_input = { 25 | # "train_data": "https://storage.googleapis.com/replicate-weights/training-deadlock/1k_samples.jsonl", 26 | "train_data": "https://pub-3054bb37389944ca9c8e5ada8572840e.r2.dev/samsum.jsonl", 27 | } 28 | return replicate.trainings.create( 29 | version=model_name + ":" + version.id, 30 | input=training_input, 31 | destination="replicate-internal/training-scratch", 32 | ) 33 | 34 | 35 | @pytest.fixture(scope="module") 36 | def prediction_tests(): 37 | return [ 38 | {"prompt": "How are you doing today?"}, 39 | { 40 | "prompt": """[INST] <> 41 | Use the Input to provide a summary of a conversation. 42 | <> 43 | Input: 44 | Liam: did you see that new movie that just came out? 45 | Liam: "Starry Skies" I think it's called 46 | Ava: oh yeah, I heard about it 47 | Liam: it's about this astronaut who gets lost in space 48 | Liam: and he has to find his way back to earth 49 | Ava: sounds intense 50 | Liam: it was! there were so many moments where I thought he wouldn't make it 51 | Ava: i need to watch it then, been looking for a good movie 52 | Liam: highly recommend it! 53 | Ava: thanks for the suggestion Liam! 54 | Liam: anytime, always happy to share good movies 55 | Ava: let's plan to watch it together sometime 56 | Liam: sounds like a plan! [/INST] 57 | """ 58 | }, 59 | ] 60 | 61 | 62 | def test_training(training): 63 | while training.completed_at is None: 64 | time.sleep(60) 65 | training.reload() 66 | assert training.status == "succeeded" 67 | 68 | 69 | @pytest.fixture(scope="module") 70 | def trained_model_and_version(training): 71 | trained_model, trained_version = training.output["version"].split(":") 72 | return trained_model, trained_version 73 | 74 | 75 | def test_post_training_predictions(trained_model_and_version, prediction_tests): 76 | trained_model, trained_version = trained_model_and_version 77 | model = replicate.models.get(trained_model) 78 | version = model.versions.get(trained_version) 79 | predictions = [ 80 | replicate.predictions.create(version=version, input=val) 81 | for val in prediction_tests 82 | ] 83 | 84 | for ind, val in enumerate(predictions): 85 | val.wait() 86 | assert val.status == "succeeded" 87 | out = "".join(val.output) 88 | print("Output: ", out) 89 | if ind == 1: 90 | assert "Summary" in out 91 | -------------------------------------------------------------------------------- /tests/test_train.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import os 3 | import re 4 | 5 | from tests.test_utils import run_training_subprocess 6 | 7 | ERROR_PATTERN = re.compile(r"ERROR:|Exception", re.IGNORECASE) 8 | 9 | # Constants 10 | SERVER_URL = "http://localhost:5000/predictions" 11 | HEALTH_CHECK_URL = "http://localhost:5000/health-check" 12 | 13 | IMAGE_NAME = "your_image_name" # replace with your image name 14 | HOST_NAME = "your_host_name" # replace with your host name 15 | 16 | 17 | # def run_training_subprocess(command): 18 | # # Start the subprocess with pipes for stdout and stderr 19 | # process = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE) 20 | 21 | # # Create a lock for printing to avoid potential race conditions between the 
two print processes 22 | # print_lock = multiprocessing.Lock() 23 | 24 | # # Start two separate processes to handle stdout and stderr 25 | # stdout_processor = multiprocessing.Process(target=capture_output, args=(process.stdout, print_lock)) 26 | # stderr_processor = multiprocessing.Process(target=capture_output, args=(process.stderr, print_lock)) 27 | 28 | # # Start the log processors 29 | # stdout_processor.start() 30 | # stderr_processor.start() 31 | 32 | # # Wait for the subprocess to finish 33 | # return_code = process.wait() 34 | 35 | # # Wait for the log processors to finish 36 | # stdout_processor.join() 37 | # stderr_processor.join() 38 | 39 | # return return_code 40 | 41 | 42 | def test_train(): 43 | command = [ 44 | "cog", 45 | "train", 46 | "-i", 47 | "train_data=https://storage.googleapis.com/dan-scratch-public/fine-tuning/1k_samples_prompt.jsonl", 48 | "-i", 49 | "train_batch_size=4", 50 | "-i", 51 | "max_steps=5", 52 | "-i", 53 | "gradient_accumulation_steps=2", 54 | ] 55 | 56 | # result = subprocess.run(command, capture_output=False, text=True)#, stdout=subprocess.PIPE, stderr=subprocess.PIPE) 57 | try: 58 | logs = run_training_subprocess(command) 59 | except Exception as e: 60 | pytest.fail(f"Error detected in training logs! Exception: {str(e)}") 61 | 62 | # Additional assertions can be added here, e.g.: 63 | assert not any( 64 | ERROR_PATTERN.search(log) for log in logs 65 | ), "Error pattern detected in logs!" 66 | 67 | # Check the return code 68 | # assert exit_code == 0, "Subprocess failed with return code {}".format(exit_code) 69 | 70 | # # Check if the log indicates successful completion for all processes 71 | # success_logs = result.stdout.count("exits successfully.") 72 | # # Assuming 4 processes should exit successfully based on the logs provided 73 | # assert success_logs == 4, "Not all processes exited successfully. Expected 4 but got {}".format(success_logs) 74 | 75 | # # Optionally, you can also check for other indicators 76 | # assert "Written output to weights" in result.stdout, "Output weights were not successfully written." 
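# If no ERROR/Exception pattern surfaced in the captured logs (checked above), the run is
# treated as successful as long as the training output archive exists on disk (checked below).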
77 | 78 | assert os.path.exists("training_output.zip") 79 | # print_lock = Lock() 80 | 81 | # stdout_thread = Thread(target=capture_output, args=(process.stdout, print_lock)) 82 | # stdout_thread.start() 83 | 84 | # stderr_thread = Thread(target=capture_output, args=(process.stderr, print_lock)) 85 | # stderr_thread.start() 86 | 87 | # process.terminate() 88 | # process.wait() 89 | -------------------------------------------------------------------------------- /tests/test_train_predict.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import requests 3 | import subprocess 4 | import os 5 | from threading import Thread, Lock 6 | 7 | from tests.test_utils import ( 8 | get_image_name, 9 | capture_output, 10 | wait_for_server_to_be_ready, 11 | ) 12 | 13 | # Constants 14 | SERVER_URL = "http://localhost:5000/predictions" 15 | HEALTH_CHECK_URL = "http://localhost:5000/health-check" 16 | 17 | IMAGE_NAME = "your_image_name" # replace with your image name 18 | HOST_NAME = "your_host_name" # replace with your host name 19 | 20 | 21 | @pytest.fixture(scope="session") 22 | def server(): 23 | image_name = get_image_name() 24 | current_directory = os.getcwd() 25 | volume_mount = f"{current_directory}/training_output.zip:/src/local_weights.zip" 26 | 27 | command = [ 28 | "docker", 29 | "run", 30 | "-p", 31 | "5000:5000", 32 | "--gpus=all", 33 | "-e", 34 | f"COG_WEIGHTS=http://{HOST_NAME}:8000/training_output.zip", 35 | "-v", 36 | volume_mount, 37 | image_name, 38 | ] 39 | print("\n**********************STARTING SERVER**********************") 40 | process = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE) 41 | 42 | print_lock = Lock() 43 | 44 | stdout_thread = Thread(target=capture_output, args=(process.stdout, print_lock)) 45 | stdout_thread.start() 46 | 47 | stderr_thread = Thread(target=capture_output, args=(process.stderr, print_lock)) 48 | stderr_thread.start() 49 | 50 | wait_for_server_to_be_ready(HEALTH_CHECK_URL) 51 | 52 | yield process 53 | 54 | process.terminate() 55 | process.wait() 56 | 57 | 58 | def test_health_check(server): 59 | response = requests.get(HEALTH_CHECK_URL) 60 | assert ( 61 | response.status_code == 200 62 | ), f"Unexpected status code: {response.status_code}" 63 | 64 | 65 | def test_prediction(server): 66 | data = { 67 | "input": { 68 | "prompt": "It was a dark and stormy night.", 69 | "max_new_tokens": 25, 70 | # Add other parameters here 71 | } 72 | } 73 | response = requests.post(SERVER_URL, json=data) 74 | assert ( 75 | response.status_code == 200 76 | ), f"Unexpected status code: {response.status_code}" 77 | print("\n**********************RESPONSE**********************") 78 | print("".join(response.json()["output"])) 79 | print("******************************************************\n") 80 | # Add other assertions based on expected response 81 | 82 | 83 | def test_input_too_long(server): 84 | # This is a placeholder. You need to provide an input that is expected to be too long. 85 | data = { 86 | "input": { 87 | "prompt": " a" 88 | * 6000, # Assuming this string will produce more than 4096 tokens. 89 | "max_new_tokens": 25, 90 | # Add other parameters here 91 | } 92 | } 93 | 94 | response = requests.post(SERVER_URL, json=data) 95 | 96 | response_data = response.json() 97 | assert "error" in response_data, "Expected an 'error' field in the response" 98 | 99 | error_msg_prefix = "Your input is too long. 
Max input length is" 100 | assert response_data["error"].startswith( 101 | error_msg_prefix 102 | ), f"Expected the error message to start with '{error_msg_prefix}'" 103 | assert response_data["status"] == "failed", "Expected the status to be 'failed'" 104 | 105 | print("\n**********************RESPONSE**********************") 106 | print(response.text) 107 | print("******************************************************\n") 108 | 109 | 110 | if __name__ == "__main__": 111 | pytest.main() 112 | -------------------------------------------------------------------------------- /tests/test_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import requests 4 | import time 5 | import re 6 | import multiprocessing 7 | import subprocess 8 | 9 | ERROR_PATTERN = re.compile(r"ERROR:") 10 | 11 | 12 | def get_image_name(): 13 | current_dir = os.path.basename(os.getcwd()) 14 | 15 | if "cog" in current_dir: 16 | return current_dir 17 | else: 18 | return f"cog-{current_dir}" 19 | 20 | 21 | def process_log_line(line): 22 | line = line.decode("utf-8").strip() 23 | try: 24 | log_data = json.loads(line) 25 | return json.dumps(log_data, indent=2) 26 | except json.JSONDecodeError: 27 | return line 28 | 29 | 30 | # def capture_output(pipe, print_lock): 31 | # for line in iter(pipe.readline, b''): 32 | # formatted_line = process_log_line(line) 33 | # with print_lock: 34 | # print(formatted_line) 35 | 36 | 37 | def capture_output(pipe, print_lock, logs=None, error_detected=None): 38 | for line in iter(pipe.readline, b""): 39 | formatted_line = process_log_line(line) 40 | with print_lock: 41 | print(formatted_line) 42 | if logs is not None: 43 | logs.append(formatted_line) 44 | if error_detected is not None: 45 | if ERROR_PATTERN.search(formatted_line): 46 | error_detected[0] = True 47 | 48 | 49 | def wait_for_server_to_be_ready(url, timeout=300): 50 | """ 51 | Waits for the server to be ready. 52 | 53 | Args: 54 | - url: The health check URL to poll. 55 | - timeout: Maximum time (in seconds) to wait for the server to be ready. 
56 | """ 57 | start_time = time.time() 58 | while True: 59 | try: 60 | response = requests.get(url) 61 | data = response.json() 62 | 63 | if data["status"] == "READY": 64 | return 65 | elif data["status"] == "SETUP_FAILED": 66 | raise RuntimeError( 67 | "Server initialization failed with status: SETUP_FAILED" 68 | ) 69 | 70 | except requests.RequestException: 71 | pass 72 | 73 | if time.time() - start_time > timeout: 74 | raise TimeoutError("Server did not become ready in the expected time.") 75 | 76 | time.sleep(5) # Poll every 5 seconds 77 | 78 | 79 | def run_training_subprocess(command): 80 | # Start the subprocess with pipes for stdout and stderr 81 | process = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE) 82 | 83 | # Create a lock for printing and a list to accumulate logs 84 | print_lock = multiprocessing.Lock() 85 | logs = multiprocessing.Manager().list() 86 | error_detected = multiprocessing.Manager().list([False]) 87 | 88 | # Start two separate processes to handle stdout and stderr 89 | stdout_processor = multiprocessing.Process( 90 | target=capture_output, args=(process.stdout, print_lock, logs, error_detected) 91 | ) 92 | stderr_processor = multiprocessing.Process( 93 | target=capture_output, args=(process.stderr, print_lock, logs, error_detected) 94 | ) 95 | 96 | # Start the log processors 97 | stdout_processor.start() 98 | stderr_processor.start() 99 | 100 | # Wait for the subprocess to finish 101 | process.wait() 102 | 103 | # Wait for the log processors to finish 104 | stdout_processor.join() 105 | stderr_processor.join() 106 | 107 | # Check if an error pattern was detected 108 | if error_detected[0]: 109 | raise Exception("Error detected in training logs! Check logs for details") 110 | 111 | return list(logs) 112 | -------------------------------------------------------------------------------- /tests/timing.py: -------------------------------------------------------------------------------- 1 | import time 2 | import replicate 3 | import os 4 | 5 | base = "replicate-internal/staging-llama-2-7b:8ba7b9478e1cbdde020f79f0838cd94465dfc6fc0207e01d2e59c00422f65148" 6 | 7 | v1 = "a42037aa39fc7cdc9138d61a0a94172107906ed8be7c8b0568cc5766d633f0fe" 8 | v2 = "ca0a7d930eed4f330d7f187a18052842f35087fc15b93b741a554753591cb366" 9 | 10 | model = replicate.models.get("technillogue/llama2-summarizer") 11 | ver1 = model.versions.get(v1) 12 | ver2 = model.versions.get(v2) 13 | 14 | os.system("kubectl delete pod -l replicate/version_short_id=8ba7b947") 15 | 16 | 17 | def run(v): 18 | t0 = time.time() 19 | # gen = replicate.run(v1, input={"prompt": "a"}) 20 | global last 21 | last = pred = replicate.predictions.create(v, input={"prompt": "a"}) 22 | t1 = time.time() 23 | print(f"got result after {t1 - t0:.4f}") 24 | gen = pred.output_iterator() 25 | next(gen) 26 | t2 = time.time() 27 | print(f"got first token {t2 - t1:.4f}") 28 | try: 29 | print(re.search("previous weights were (.*)\n", pred.logs).group().strip()) 30 | except: 31 | pass 32 | try: 33 | print(re.search("Downloaded peft weights in (\d+.\d+)", pred.logs).group()) 34 | except: 35 | pass 36 | try: 37 | print(re.search("initialize_peft took (\d+.\d+)", pred.logs).group()) 38 | except: 39 | pass 40 | print(f"prediciton created to first token: {t2 - t0:.4f}") 41 | pred.wait() 42 | t3 = time.time() 43 | print(re.search("hostname: (.*)\n", pred.logs).group().strip()) 44 | print(f"prediction took {t3 - t2:.4f} from first to last token") 45 | 46 | 47 | run(ver1) 48 | run(ver2) 49 | 
-------------------------------------------------------------------------------- /tests/unit_tests/test_completion_dataset.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | import sys 4 | 5 | sys.path.append(".") 6 | 7 | from llama_recipes.ft_datasets.completion_dataset import ( 8 | load_data, 9 | format_data, 10 | tokenize_data, 11 | ) 12 | 13 | from dataclasses import dataclass 14 | 15 | 16 | @pytest.fixture(scope="session") 17 | def dataset_config(): 18 | @dataclass 19 | class completion: 20 | dataset: str = "completion" 21 | train_split: str = "train" 22 | test_split: str = "val" 23 | data_path: str = "tests/data/200_samples.jsonl" 24 | num_validation_samples: int = 100 25 | run_validation: bool = True 26 | validation_data_path: str = None 27 | pack_sequences: bool = True 28 | wrap_packed_sequences: bool = True 29 | chunk_size: int = 100 30 | 31 | return completion 32 | 33 | 34 | @pytest.fixture(scope="session") 35 | def tokenizer(): 36 | from transformers import LlamaTokenizer 37 | 38 | tokenizer = LlamaTokenizer.from_pretrained( 39 | "tests/assets/llama_tokenizer", legacy=False 40 | ) 41 | tokenizer.add_special_tokens( 42 | { 43 | "pad_token": "", 44 | "eos_token": "", 45 | "bos_token": "", 46 | } 47 | ) 48 | return tokenizer 49 | 50 | 51 | def test__load_data_train(dataset_config): 52 | dataset_config.run_validation = False 53 | dataset = load_data(dataset_config, split="train") 54 | assert len(dataset) == 200 55 | for example in dataset: 56 | assert example["text"].startswith("Write a response to the following message") 57 | 58 | 59 | def test__load_data_train_with_val_split(dataset_config): 60 | dataset_config.run_validation = True 61 | train_dataset = load_data(dataset_config, split="train") 62 | 63 | train_texts = [example["text"] for example in train_dataset] 64 | 65 | val_dataset = load_data(dataset_config, split="val") 66 | assert len(val_dataset) == 100 67 | for example in val_dataset: 68 | assert example["text"].startswith("Write a response to the following message") 69 | assert example["text"] not in train_texts 70 | 71 | 72 | @pytest.fixture(scope="session") 73 | def dataset(dataset_config): 74 | dataset_config.run_validation = False 75 | dataset = load_data(dataset_config, split="train") 76 | return dataset 77 | 78 | 79 | def test_format_data(dataset, tokenizer): 80 | formatted_data = format_data(dataset, tokenizer, dataset_config) 81 | for example in formatted_data: 82 | assert example["text"].startswith("Write a response to the following message") 83 | assert example["text"].endswith(tokenizer.eos_token) 84 | 85 | 86 | @pytest.fixture(scope="session") 87 | def formatted_dataset(dataset, tokenizer): 88 | return format_data(dataset, tokenizer, dataset_config) 89 | 90 | 91 | def test_tokenize_data_with_wrapped_packing( 92 | formatted_dataset, tokenizer, dataset_config 93 | ): 94 | dataset_config.pack_sequences = True 95 | dataset_config.wrap_packed_sequences = True 96 | 97 | tokenized_data = tokenize_data(formatted_dataset, tokenizer, dataset_config) 98 | 99 | for tokenized_example in tokenized_data: 100 | assert "labels" in tokenized_example 101 | 102 | decoded_data = tokenizer.batch_decode( 103 | tokenized_data["input_ids"], skip_special_tokens=False 104 | ) 105 | 106 | decoded_data = tokenizer.batch_decode( 107 | tokenized_data["input_ids"], skip_special_tokens=True 108 | ) 109 | 110 | at_least_one_wrapped = False 111 | for example in decoded_data: 112 | if not example.startswith("Write a response to the 
following message"): 113 | at_least_one_wrapped = True 114 | 115 | assert at_least_one_wrapped 116 | 117 | for tokenized_example in tokenized_data["input_ids"]: 118 | assert len(tokenized_example) == dataset_config.chunk_size 119 | 120 | 121 | def test_tokenize_data_without_wrapped_packing_small_chunk( 122 | formatted_dataset, tokenizer, dataset_config 123 | ): 124 | dataset_config.pack_sequences = True 125 | dataset_config.wrap_packed_sequences = False 126 | dataset_config.chunk_size: int = 100 127 | 128 | tokenized_data = tokenize_data(formatted_dataset, tokenizer, dataset_config) 129 | 130 | for tokenized_example in tokenized_data: 131 | assert tokenized_example["input_ids"][-1] == tokenizer.eos_token_id 132 | assert "labels" in tokenized_example 133 | 134 | decoded_data = tokenizer.batch_decode( 135 | tokenized_data["input_ids"], skip_special_tokens=False 136 | ) 137 | 138 | for example in decoded_data: 139 | prefix = " ".join( 140 | [tokenizer.bos_token, "Write a response to the following message"] 141 | ) 142 | assert example.startswith(prefix) 143 | 144 | recovered_data = [] 145 | for decoded_sequence in decoded_data: 146 | for decoded_example in decoded_sequence.split(tokenizer.eos_token)[:-1]: 147 | decoded_example = decoded_example.removeprefix(tokenizer.bos_token + " ") 148 | decoded_example += tokenizer.eos_token 149 | recovered_data.append(decoded_example) 150 | 151 | for i in range(len(recovered_data)): 152 | assert recovered_data[i] == formatted_dataset[i]["text"] 153 | 154 | 155 | def test_tokenize_data_without_wrapped_packing_large_chunk( 156 | formatted_dataset, tokenizer, dataset_config 157 | ): 158 | dataset_config.pack_sequences = True 159 | dataset_config.wrap_packed_sequences = False 160 | dataset_config.chunk_size: int = 2048 161 | 162 | tokenized_data = tokenize_data(formatted_dataset, tokenizer, dataset_config) 163 | 164 | for tokenized_example in tokenized_data: 165 | assert tokenized_example["input_ids"][-1] == tokenizer.eos_token_id 166 | assert "labels" in tokenized_example 167 | 168 | decoded_data = tokenizer.batch_decode( 169 | tokenized_data["input_ids"], skip_special_tokens=False 170 | ) 171 | 172 | for example in decoded_data: 173 | prefix = " ".join( 174 | [tokenizer.bos_token, "Write a response to the following message"] 175 | ) 176 | assert example.startswith(prefix) 177 | 178 | recovered_data = [] 179 | for decoded_sequence in decoded_data: 180 | for decoded_example in decoded_sequence.split(tokenizer.eos_token)[:-1]: 181 | decoded_example = decoded_example.removeprefix(tokenizer.bos_token + " ") 182 | decoded_example += tokenizer.eos_token 183 | recovered_data.append(decoded_example) 184 | 185 | for i in range(len(recovered_data)): 186 | assert recovered_data[i] == formatted_dataset[i]["text"] 187 | 188 | 189 | def test_tokenize_data_without_packing(formatted_dataset, tokenizer, dataset_config): 190 | dataset_config.pack_sequences = False 191 | tokenized_data = tokenize_data(formatted_dataset, tokenizer, dataset_config) 192 | 193 | for tokenized_example in tokenized_data["input_ids"]: 194 | assert tokenized_example[-1] == tokenizer.eos_token_id 195 | 196 | decoded_data = tokenizer.batch_decode( 197 | tokenized_data["input_ids"], skip_special_tokens=True 198 | ) 199 | for i, example in enumerate(decoded_data): 200 | assert example.startswith("Write a response to the following message") 201 | assert example + tokenizer.eos_token == formatted_dataset[i]["text"] 202 | --------------------------------------------------------------------------------