├── .gitignore ├── README.md ├── config.ini ├── config_experts.ini ├── experts.json ├── experts ├── README.md ├── code.npy ├── creative.npy ├── function.npy ├── general.npy ├── qa.npy └── reasoning.npy ├── herd ├── __init__.py ├── embeddings.py ├── finetune.py ├── models.py ├── multilora.py ├── router.py ├── run_model.py └── segment_experts.py ├── jobscript.sh ├── jobscript_run.sh ├── main.py ├── requirements.txt └── scripts ├── app.py └── main.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | .vscode/ 6 | 7 | # C extensions 8 | *.so 9 | 10 | # Distribution / packaging 11 | .Python 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | cover/ 54 | 55 | # Translations 56 | *.mo 57 | *.pot 58 | 59 | # Django stuff: 60 | *.log 61 | local_settings.py 62 | db.sqlite3 63 | db.sqlite3-journal 64 | 65 | # Flask stuff: 66 | instance/ 67 | .webassets-cache 68 | 69 | # Scrapy stuff: 70 | .scrapy 71 | 72 | # Sphinx documentation 73 | docs/_build/ 74 | 75 | # PyBuilder 76 | .pybuilder/ 77 | target/ 78 | 79 | # Jupyter Notebook 80 | .ipynb_checkpoints 81 | 82 | # IPython 83 | profile_default/ 84 | ipython_config.py 85 | 86 | # pyenv 87 | # For a library or package, you might want to ignore these files since the code is 88 | # intended to run in multiple environments; otherwise, check them in: 89 | # .python-version 90 | 91 | # pipenv 92 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 93 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 94 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 95 | # install all needed dependencies. 96 | #Pipfile.lock 97 | 98 | # poetry 99 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 100 | # This is especially recommended for binary packages to ensure reproducibility, and is more 101 | # commonly ignored for libraries. 102 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 103 | #poetry.lock 104 | 105 | # pdm 106 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 107 | #pdm.lock 108 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 109 | # in version control. 110 | # https://pdm.fming.dev/#use-with-ide 111 | .pdm.toml 112 | 113 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 114 | __pypackages__/ 115 | 116 | # Celery stuff 117 | celerybeat-schedule 118 | celerybeat.pid 119 | 120 | # SageMath parsed files 121 | *.sage.py 122 | 123 | # Environments 124 | .env 125 | .venv 126 | env/ 127 | venv/ 128 | ENV/ 129 | env.bak/ 130 | venv.bak/ 131 | 132 | # Spyder project settings 133 | .spyderproject 134 | .spyproject 135 | 136 | # Rope project settings 137 | .ropeproject 138 | 139 | # mkdocs documentation 140 | /site 141 | 142 | # mypy 143 | .mypy_cache/ 144 | .dmypy.json 145 | dmypy.json 146 | 147 | # Pyre type checker 148 | .pyre/ 149 | 150 | # pytype static type analyzer 151 | .pytype/ 152 | 153 | # Cython debug symbols 154 | cython_debug/ 155 | 156 | # PyCharm 157 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 158 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 159 | # and can be added to the global gitignore or merged into this file. For a more nuclear 160 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 161 | #.idea/ 162 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Herd 2 | A group of Llamas. 3 | 4 | Leverage Mixture of Experts (MoE) models for large language models (LLMs) and enhance their performance with advanced PEFT methods. 5 | 6 | ## Overview 7 | This project is part of an MSc thesis focused on exploring MoE models for LLMs. By leveraging PEFT methods such as LoRA and QLoRA, it aims to provide an effective way to combine a base model with an extensive set of adapters and a method to identify the appropriate adapter for each prompt. 8 | 9 | In short: one base model, a large set of adapters, and a router that finds the right adapter for a given prompt. 10 | 11 | ## Key Features 12 | - **Prompt-Expert Mapping**: Determines the right expert based on the input prompt's distance from expert centroids. 13 | - **Combination of Adapters**: Allows merging multiple adapters based on the input prompt's proximity to each expert's centroid. 14 | 15 | 16 | ## Inspiration & Credits 17 | The first iteration of this project is heavily based on [airoboros/lmoe](https://github.com/jondurbin/airoboros/tree/main/airoboros/lmoe). Additionally, 18 | I have included the option to combine multiple adapters according to the distance between the input prompt and the centroid of each expert (and it seems like [@aicrumb had a similar idea](https://twitter.com/aicrumb/status/1681846805959528448)). 19 | 20 | ## Experts & Segmentation 21 | The experts are fine-tuned with QLoRA on the [jondurbin/airoboros-2.1 dataset](https://huggingface.co/datasets/jondurbin/airoboros-2.1/viewer/default/train), using the same segmentation as in the original project.
22 | 23 | Expert Name | Categories | 24 | -------------|------------| 25 | qa | quiz, multiple_choice, contextual, counterfactual_contextual | 26 | creative | card, writing, experience, song, roleplay, gtkm, rp, detailed_writing, joke | 27 | code | coding | 28 | reasoning | cot, theory_of_mind, riddle, orca | 29 | function | agent, plan | 30 | general | wordgame, trivia, general | 31 | 32 | 33 | ### Fine-tuning Experts 34 | To fine-tune the experts, run: 35 | 36 | ```sh 37 | python scripts/main.py finetune 38 | ``` 39 | 40 | ### Computing Expert-Prompt Distance 41 | To compute the distance between an input prompt and each expert, we: 42 | 43 | 1. Sample a few instructions from each expert and compute the average embedding for each expert. 44 | 2. Save the average embedding for each expert as a numpy array in `experts/`. 45 | 46 | When a new prompt is received, we: 47 | 1. Load the average embedding for each expert. 48 | 2. Use [faiss](https://github.com/facebookresearch/faiss) to compute the distance between the input prompt and each expert's centroid. 49 | 50 | ## API Documentation 51 | Herd provides a simple REST API, modeled after OpenAI's API. To start the server, run: 52 | 53 | 54 | ```sh 55 | python scripts/app.py 56 | ``` 57 | 58 | The following options can be used: 59 | - `--port (-p)`: The port to run the server on. Default: `8000` 60 | - `--host (-i)`: The host to run the server on. Default: `127.0.0.1` 61 | - `--config-file`: The config file to use. Default: `config_experts.ini` 62 | - `--only-base`: Only use the base model. Default: `False` 63 | 64 | 65 | ### Querying the Model 66 | 67 | To query the model, run: 68 | ```sh 69 | curl -s -XPOST http://127.0.0.1:8000/v1/chat/completions -H 'content-type: application/json' -d '{ 70 | "model": "herd", 71 | "messages": [ 72 | { 73 | "role": "system", 74 | "content": "A chat." 75 | }, 76 | { 77 | "role": "user", 78 | "content": "Lorem ipsum dolor sit amet" 79 | } 80 | ] 81 | }' 82 | ``` 83 | 84 | The following options can be passed: 85 | - `model` (str): The name of the model. 86 | - `messages` (List[Dict[str, str]]): The list of messages in the chat. 87 | - `temperature` (float, optional): The temperature for generating responses. Defaults to 0.5. 88 | - `top_k` (int, optional): The number of top-k tokens to consider. Defaults to 50. 89 | - `top_p` (float, optional): The cumulative probability for generating responses. Defaults to 1.0. 90 | - `repetition_penalty` (float, optional): The repetition penalty for generating responses. Defaults to 1.0. 91 | - `stop` (List[str], optional): The list of stop words. Defaults to `DEFAULT_STOPS`. 92 | - `max_tokens` (int, optional): The maximum number of tokens in the response. Defaults to None. 93 | - `top_experts` (int, optional): The number of top experts to consider. Defaults to 1.
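The same request can be sent from Python. The snippet below is a minimal client sketch that uses only the standard library; it assumes the server is running locally on the default host and port, and the prompt and `top_experts` value are purely illustrative:

```py
import json
import urllib.request

payload = {
    "model": "herd",
    "messages": [
        {"role": "system", "content": "A chat."},
        {"role": "user", "content": "Write a haiku about llamas."},
    ],
    "temperature": 0.5,
    "top_experts": 2,  # blend the two closest experts instead of picking a single one
}

request = urllib.request.Request(
    "http://127.0.0.1:8000/v1/chat/completions",
    data=json.dumps(payload).encode("utf-8"),
    headers={"Content-Type": "application/json"},
)

# The response mirrors OpenAI's chat completion format.
with urllib.request.urlopen(request) as response:
    body = json.load(response)

print(body["choices"][0]["message"]["content"])
```

With `top_experts` greater than 1, the server merges the closest adapters into a temporary weighted adapter (see `herd/multilora.py`) instead of activating a single expert.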
94 | -------------------------------------------------------------------------------- /config.ini: -------------------------------------------------------------------------------- 1 | [LoraConfig] 2 | lora_alpha=16 3 | lora_dropout=0.1 4 | r=64 5 | bias=none 6 | task_type=CAUSAL_LM 7 | 8 | [TrainingArguments] 9 | num_train_epochs=3 10 | per_device_train_batch_size=4 11 | gradient_accumulation_steps=2 12 | gradient_checkpointing=True 13 | optim=paged_adamw_32bit 14 | logging_steps=10 15 | save_strategy=epoch 16 | learning_rate=2e-4 17 | bf16=True 18 | tf32=True 19 | max_grad_norm=0.3 20 | warmup_ratio=0.03 21 | lr_scheduler_type=constant 22 | output_dir=${Paths:output_dir} 23 | 24 | [Models] 25 | model=meta-llama/Llama-2-7b-hf 26 | dataset=databricks/databricks-dolly-15k 27 | embeddings_model=thenlper/gte-small 28 | embeddings_max_length=512 29 | ; # Max tokens for our embedding model. This code is really designed for the gte-* 30 | ; series, e.g.: https://huggingface.co/thenlper/gte-small 31 | ; but could in theory be generalized to work with other models I suspect. 32 | 33 | [Paths] 34 | base_dir=/work3/s212722/herd 35 | dataset_dir=${base_dir}/datasets 36 | cache_dir=${base_dir}/cache 37 | output_dir=${base_dir}/${Models:model} 38 | experts_dir=experts/ 39 | experts_file=experts.json 40 | -------------------------------------------------------------------------------- /config_experts.ini: -------------------------------------------------------------------------------- 1 | [LoraConfig] 2 | lora_alpha=16 3 | lora_dropout=0.1 4 | r=64 5 | bias=none 6 | task_type=CAUSAL_LM 7 | 8 | [TrainingArguments] 9 | num_train_epochs=3 10 | per_device_train_batch_size=4 11 | gradient_accumulation_steps=2 12 | gradient_checkpointing=True 13 | optim=paged_adamw_32bit 14 | logging_steps=10 15 | save_strategy=epoch 16 | learning_rate=2e-4 17 | bf16=True 18 | tf32=True 19 | max_grad_norm=0.3 20 | warmup_ratio=0.03 21 | lr_scheduler_type=constant 22 | output_dir=${Paths:output_dir} 23 | report_to=wandb 24 | 25 | [Models] 26 | model=meta-llama/Llama-2-7b-hf 27 | dataset=jondurbin/airoboros-2.1 28 | embeddings_model=thenlper/gte-small 29 | embeddings_max_length=512 30 | ; # Max tokens for our embedding model. This code is really designed for the gte-* 31 | ; series, e.g.: https://huggingface.co/thenlper/gte-small 32 | ; but could in theory be generalized to work with other models I suspect.
33 | 34 | [Paths] 35 | base_dir=/work3/s212722/herd 36 | dataset_dir=${base_dir}/datasets 37 | cache_dir=${base_dir}/cache 38 | output_dir=${base_dir}/${Models:model} 39 | experts_dir=experts/ 40 | experts_file=experts.json 41 | -------------------------------------------------------------------------------- /experts.json: -------------------------------------------------------------------------------- 1 | { 2 | "qa": { 3 | "categories": [ 4 | "quiz", 5 | "multiple_choice", 6 | "contextual", 7 | "counterfactual_contextual" 8 | ] 9 | }, 10 | "creative": { 11 | "categories": [ 12 | "card", 13 | "writing", 14 | "experience", 15 | "song", 16 | "roleplay", 17 | "gtkm", 18 | "rp", 19 | "detailed_writing", 20 | "joke" 21 | ] 22 | }, 23 | "code": { 24 | "categories": [ 25 | "coding" 26 | ] 27 | }, 28 | "reasoning": { 29 | "categories": [ 30 | "cot", 31 | "theory_of_mind", 32 | "riddle", 33 | "orca" 34 | ] 35 | }, 36 | "function": { 37 | "categories": [ 38 | "agent", 39 | "plan" 40 | ] 41 | }, 42 | "general": { 43 | "categories": [ 44 | "wordgame", 45 | "trivia", 46 | "general" 47 | ] 48 | } 49 | } 50 | -------------------------------------------------------------------------------- /experts/README.md: -------------------------------------------------------------------------------- 1 | # Experts 2 | This folder contains `npy` files named following the format `{expert_name}.npy` (e.g. `code.npy`). 3 | 4 | Each file contains the embedding vector that represents the corresponding expert. -------------------------------------------------------------------------------- /experts/code.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alexrs/herd/a199c8fee8e26a70e1eceb4794b73f561ff77955/experts/code.npy -------------------------------------------------------------------------------- /experts/creative.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alexrs/herd/a199c8fee8e26a70e1eceb4794b73f561ff77955/experts/creative.npy -------------------------------------------------------------------------------- /experts/function.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alexrs/herd/a199c8fee8e26a70e1eceb4794b73f561ff77955/experts/function.npy -------------------------------------------------------------------------------- /experts/general.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alexrs/herd/a199c8fee8e26a70e1eceb4794b73f561ff77955/experts/general.npy -------------------------------------------------------------------------------- /experts/qa.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alexrs/herd/a199c8fee8e26a70e1eceb4794b73f561ff77955/experts/qa.npy -------------------------------------------------------------------------------- /experts/reasoning.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alexrs/herd/a199c8fee8e26a70e1eceb4794b73f561ff77955/experts/reasoning.npy -------------------------------------------------------------------------------- /herd/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alexrs/herd/a199c8fee8e26a70e1eceb4794b73f561ff77955/herd/__init__.py
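For reference, each of the `.npy` files above stores a single 1-D centroid vector with the embedding dimension of the configured embeddings model (384 for `thenlper/gte-small`). The snippet below is a minimal sketch of loading one centroid and comparing it against a new prompt; it assumes it is run from the repository root with the packages in `requirements.txt` installed, and the squared L2 distance mirrors what the `faiss.IndexFlatL2` index in `herd/router.py` returns:

```py
import numpy as np
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer

from herd.embeddings import Embeddings

# Load the pre-computed centroid for the "code" expert.
centroid = np.load("experts/code.npy")
print(centroid.shape)  # (384,) for thenlper/gte-small

# Embed a prompt the same way the router does.
model = SentenceTransformer("thenlper/gte-small")
tokenizer = AutoTokenizer.from_pretrained("thenlper/gte-small")
embeddings = Embeddings(model, tokenizer, max_length=512)
query = embeddings.calculate_embeddings("Write a Python function that reverses a string.")

# Squared L2 distance, as computed by the faiss index in herd/router.py.
distance = float(np.sum((query - centroid) ** 2))
print(f"Distance to the 'code' centroid: {distance:.4f}")
```

These centroids are produced by `herd/segment_experts.py` and consumed by `herd/router.py`, which ranks experts by this distance and routes the prompt to the closest one(s).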
-------------------------------------------------------------------------------- /herd/embeddings.py: -------------------------------------------------------------------------------- 1 | # From https://github.com/jondurbin/airoboros/blob/4cf457eaf541d6025a165f27e8596b6a1980bdab/airoboros/embeddings.py 2 | from typing import List 3 | 4 | import numpy as np 5 | import torch 6 | from sentence_transformers import SentenceTransformer 7 | from transformers import AutoTokenizer 8 | 9 | 10 | class Embeddings: 11 | def __init__(self, model: SentenceTransformer, tokenizer: AutoTokenizer, max_length: int): 12 | self.max_length = max_length 13 | self.model = model 14 | self.tokenizer = tokenizer 15 | 16 | def calculate_fragment_embeddings(self, fragment: str) -> List[float]: 17 | """Calculate vector embeddings for a single input fragment, which is smaller than the 18 | max model length. 19 | """ 20 | with torch.no_grad(): 21 | return self.model.encode(fragment, normalize_embeddings=True) 22 | 23 | def calculate_embeddings(self, input_text: str) -> List[float]: 24 | """Calculate the vector embeddings for the specified input text. 25 | 26 | 1. split the text based on the model's max sequence length 27 | 2. calculate the embeddings for each chunk 28 | 3. calculate the average embedding across all chunks 29 | """ 30 | 31 | # Tokenize the input, and convert tokens into chunks based on max model size. 32 | inputs = self.tokenizer(input_text, padding=False, truncation=False, return_tensors="pt") 33 | chunks = [ 34 | torch.Tensor(inputs["input_ids"][0][i : i + self.max_length].tolist()).int() 35 | for i in range(0, len(inputs["input_ids"][0]), self.max_length) 36 | ] 37 | fragments = [self.tokenizer.decode(chunk) for chunk in chunks] 38 | 39 | # Now, calculate embeddings for each fragment. 40 | all_embeddings = [] 41 | lengths = [] 42 | for fragment in fragments: 43 | lengths.append(len(fragment)) 44 | all_embeddings.append(self.calculate_fragment_embeddings(fragment)) 45 | 46 | # Finally, calculate the average across all fragments. 
47 | embeddings = np.average(all_embeddings, axis=0, weights=lengths) 48 | return embeddings / np.linalg.norm(embeddings) 49 | 50 | 51 | # For local testing 52 | if __name__ == "__main__": 53 | model = SentenceTransformer("thenlper/gte-small", device="cuda") 54 | tokenizer = AutoTokenizer.from_pretrained("thenlper/gte-small") 55 | max_length = 512 56 | 57 | embeddings = Embeddings(model, tokenizer, max_length) 58 | e = embeddings.calculate_embeddings("Hello world!") 59 | print(e) 60 | -------------------------------------------------------------------------------- /herd/finetune.py: -------------------------------------------------------------------------------- 1 | import os 2 | from configparser import ConfigParser 3 | from dataclasses import asdict, dataclass 4 | from typing import Dict 5 | 6 | import datasets 7 | import torch 8 | from loguru import logger 9 | from herd.models import ModelValues, PathValues 10 | from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training 11 | from transformers import ( 12 | AutoModelForCausalLM, 13 | AutoTokenizer, 14 | BitsAndBytesConfig, 15 | TrainingArguments, 16 | ) 17 | from trl import SFTTrainer 18 | 19 | 20 | def format_instruction(sample: dict) -> str: 21 | return f"""{sample['system']} 22 | 23 | ### Input: 24 | {sample['instruction']} 25 | 26 | ### Response: 27 | {sample['response']} 28 | """ 29 | 30 | 31 | def finetune( 32 | model_values: ModelValues, 33 | path_values: PathValues, 34 | config: ConfigParser, 35 | experts: Dict, 36 | ) -> None: 37 | dataset = datasets.load_dataset(model_values.dataset, split="train") 38 | 39 | logger.info( 40 | f"model_id: {model_values.model}, base_dir: {path_values.base_dir}, dataset_dir: {path_values.dataset_dir}, output_dir: {path_values.output_dir}, dataset: {model_values.dataset}", 41 | ) 42 | 43 | # BitsAndBytesConfig int-4 config 44 | bnb_config = BitsAndBytesConfig( 45 | load_in_4bit=True, 46 | bnb_4bit_use_double_quant=True, 47 | bnb_4bit_quant_type="nf4", 48 | bnb_4bit_compute_dtype=torch.bfloat16, 49 | ) 50 | 51 | # Load model and tokenizer 52 | model = AutoModelForCausalLM.from_pretrained( 53 | model_values.model, 54 | quantization_config=bnb_config, 55 | use_cache=False, 56 | device_map="auto", 57 | cache_dir=path_values.cache_dir, 58 | ) 59 | model.config.pretraining_tp = 1 60 | 61 | tokenizer = AutoTokenizer.from_pretrained(model_values.model, cache_dir=path_values.cache_dir) 62 | tokenizer.pad_token = tokenizer.eos_token 63 | tokenizer.padding_side = "right" 64 | 65 | # LoRA config based on QLoRA paper 66 | peft_config = LoraConfig(**asdict(LoraConfigValues(**dict(config.items("LoraConfig"))))) 67 | 68 | # prepare model for training 69 | model = prepare_model_for_kbit_training(model) 70 | model = get_peft_model(model, peft_config) 71 | 72 | for expert_name, expert_data in experts.items(): 73 | os.environ["WANDB_NAME"] = expert_name 74 | 75 | expert_dataset = dataset.filter(lambda row: row["category"] in expert_data["categories"]) 76 | 77 | training_args = TrainingArguments( 78 | **asdict(TrainingArgumentsValues(**dict(config.items("TrainingArguments")))) 79 | ) 80 | 81 | # set output dir to contain expert name 82 | training_args.output_dir = os.path.join(training_args.output_dir, expert_name) 83 | 84 | trainer = SFTTrainer( 85 | model=model, 86 | train_dataset=expert_dataset, 87 | peft_config=peft_config, 88 | max_seq_length=2048, 89 | tokenizer=tokenizer, 90 | packing=True, 91 | formatting_func=format_instruction, 92 | args=training_args, 93 | ) 94 | 95 | logger.info(f"Training 
expert: {expert_name}, output_dir: {training_args.output_dir}") 96 | 97 | # train 98 | trainer.train() 99 | 100 | # save model 101 | trainer.save_model(training_args.output_dir) 102 | 103 | 104 | @dataclass 105 | class LoraConfigValues: 106 | lora_alpha: int 107 | lora_dropout: float 108 | r: int 109 | bias: str 110 | task_type: str 111 | 112 | def __post_init__(self): 113 | self.lora_alpha = int(self.lora_alpha) 114 | self.lora_dropout = float(self.lora_dropout) 115 | self.r = int(self.r) 116 | 117 | 118 | @dataclass 119 | class TrainingArgumentsValues: 120 | num_train_epochs: int 121 | per_device_train_batch_size: int 122 | gradient_accumulation_steps: int 123 | gradient_checkpointing: bool 124 | optim: str 125 | logging_steps: int 126 | save_strategy: str 127 | learning_rate: float 128 | bf16: bool 129 | tf32: bool 130 | max_grad_norm: float 131 | warmup_ratio: float 132 | lr_scheduler_type: str 133 | output_dir: str 134 | report_to: str 135 | 136 | def __post_init__(self): 137 | self.num_train_epochs = int(self.num_train_epochs) 138 | self.per_device_train_batch_size = int(self.per_device_train_batch_size) 139 | self.gradient_accumulation_steps = int(self.gradient_accumulation_steps) 140 | self.logging_steps = int(self.logging_steps) 141 | self.bf16 = bool(self.bf16) 142 | self.tf32 = bool(self.tf32) 143 | self.learning_rate = float(self.learning_rate) 144 | self.max_grad_norm = float(self.max_grad_norm) 145 | self.warmup_ratio = float(self.warmup_ratio) 146 | -------------------------------------------------------------------------------- /herd/models.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | 4 | @dataclass 5 | class ModelValues: 6 | model: str 7 | dataset: str 8 | embeddings_model: str 9 | embeddings_max_length: int 10 | 11 | def __post_init__(self): 12 | self.embeddings_max_length = int(self.embeddings_max_length) 13 | 14 | 15 | @dataclass 16 | class PathValues: 17 | base_dir: str 18 | dataset_dir: str 19 | cache_dir: str 20 | output_dir: str 21 | experts_dir: str 22 | experts_file: str 23 | -------------------------------------------------------------------------------- /herd/multilora.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import os 3 | 4 | import numpy as np 5 | import torch 6 | from loguru import logger 7 | from peft import LoraModel, PeftConfig 8 | 9 | 10 | class MultiloraModel(LoraModel): 11 | """ 12 | Based on https://github.com/huggingface/peft/tree/93d0c03d5ba6b2a6b16b7ca887e740a67bc680f3/src/peft/tuners/lora 13 | 14 | Creates a Multi LoRA model from a pretrained transformer model. 15 | This model uses `add_weighted_adapter` to add adapters to the model with some weights. 16 | 17 | Args: 18 | model ([`~transformers.PreTrainedModel`]): The model to be adapted. 19 | adapters (str): The base path to the adapters. 20 | adapter_names (list): The names of the adapters to be loaded. 21 | router (herd.Router): The router to be used for routing the adapters. 22 | 23 | Returns: 24 | `torch.nn.Module`: The Multilora model. 
25 | """ 26 | 27 | def __init__(self, model, adapters, adapter_names, router): 28 | # Build PeftConfig from the first adapter 29 | config = PeftConfig.from_pretrained(os.path.join(adapters, adapter_names[0])) 30 | super().__init__(model, config, adapter_names[0]) 31 | 32 | # Load the other adapters 33 | for adapter_name in adapter_names[1:]: 34 | self.load_adapter(os.path.join(adapters, adapter_name), adapter_name) 35 | 36 | self.router = router 37 | 38 | def generate(self, prompt: str, top: int = 1, **kwargs): 39 | self.route_to_experts(prompt, top) 40 | return self.model.generate(**kwargs) 41 | 42 | def route_to_experts(self, instruction: str, top: int = 1): 43 | # Experts is a list of tuples (expert_name, distance). 44 | experts = self.router.route(instruction, top) 45 | if top == 1: 46 | # If we only want the top expert, set it as the active adapter 47 | self.model.set_adapter(experts[0][0]) 48 | else: 49 | # Otherwise, we compute a new adapter as a combination of the top experts. 50 | # We generate a unique name for the adapter because even if the same experts are used 51 | # the weights may be different. 52 | adapter_name = str(hash(datetime.datetime.now())) 53 | 54 | weights = np.array([expert[1] for expert in experts]) 55 | inverted_weights = 1 / weights 56 | # w = inverted_weights / np.sum(inverted_weights) 57 | # use softmax over the inverted distances for the weights 58 | w = torch.nn.functional.softmax(torch.tensor(inverted_weights), dim=0).numpy() 59 | e = [expert[0] for expert in experts] 60 | logger.debug(f"Creating adapter for: {list(zip(e, w))}") 61 | 62 | # TODO: Experiment with other routing methods 63 | self.add_weighted_adapter( 64 | e, 65 | w, 66 | combination_type="linear", 67 | adapter_name=adapter_name, 68 | ) 69 | 70 | self.model.set_adapter(adapter_name) 71 | 72 | return experts 73 | -------------------------------------------------------------------------------- /herd/router.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Any, List, Tuple 3 | 4 | import faiss 5 | import numpy as np 6 | from loguru import logger 7 | 8 | from herd.embeddings import Embeddings 9 | 10 | 11 | class Router: 12 | def __init__( 13 | self, 14 | embeddings: Embeddings, 15 | experts, 16 | k: int = 50, 17 | ): 18 | """ 19 | Initializes the router. 20 | """ 21 | self.experts = experts 22 | self.embeddings = embeddings 23 | self.k = k 24 | 25 | self.indices = {} 26 | for expert_name, _ in experts.items(): 27 | # TODO: Should we make the experts folder configurable? 28 | expert_path = os.path.join("experts", expert_name + ".npy") 29 | self.indices[expert_name] = self.create_index(expert_path) 30 | 31 | def route(self, prompt: str, top: int = 1) -> List[Tuple[str, float]]: 32 | """ 33 | Selects the best expert(s) to answer the prompt. 34 | """ 35 | if top > len(self.indices): 36 | logger.warning( 37 | f"Requested more experts than available.
Setting top to {len(self.indices)}" 38 | ) 39 | top = len(self.indices) 40 | 41 | query_emb = self.embeddings.calculate_embeddings(prompt) 42 | query_emb = query_emb.reshape(1, query_emb.shape[0]) 43 | expert_distances = [] 44 | for expert, index in self.indices.items(): 45 | distances, _ = index.search(query_emb, k=min(index.ntotal, self.k)) 46 | distances = distances[0].tolist() 47 | average_distance = sum(distances) / len(distances) 48 | logger.debug(f"Average distance [{expert}]: {average_distance}") 49 | expert_distances.append((expert, average_distance)) 50 | sorted_experts = sorted(expert_distances, key=lambda x: x[1]) 51 | logger.success(f"Routing to {[expert[0] for expert in sorted_experts[:top]]}") 52 | return sorted_experts[:top] 53 | 54 | def create_index(self, input_path: str) -> Any: 55 | """Create a faiss index from the routing data for a given expert.""" 56 | logger.info(f"Creating routing faiss index: {input_path}") 57 | index = faiss.IndexFlatL2( 58 | self.embeddings.model.get_sentence_embedding_dimension() 59 | ) 60 | em = np.load(input_path) 61 | em = em.reshape(1, em.shape[0]) 62 | index.add(em) 63 | 64 | return index 65 | -------------------------------------------------------------------------------- /herd/run_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | from random import randrange 3 | from typing import Dict 4 | 5 | import datasets 6 | import torch 7 | from datasets import load_dataset 8 | from herd.embeddings import Embeddings 9 | from peft import PeftModel 10 | from herd.models import ModelValues, PathValues 11 | from herd.router import Router 12 | from sentence_transformers import SentenceTransformer 13 | from transformers import ( 14 | AutoModelForCausalLM, 15 | AutoTokenizer, 16 | BitsAndBytesConfig, 17 | ) 18 | 19 | 20 | def run_model( 21 | model_values: ModelValues, 22 | path_values: PathValues, 23 | experts: Dict, 24 | only_base: bool = False, 25 | interactive: bool = False, 26 | ): 27 | tokenizer = AutoTokenizer.from_pretrained( 28 | model_values.model, cache_dir=path_values.cache_dir 29 | ) 30 | 31 | quantization_config = BitsAndBytesConfig( 32 | load_in_4bit=True, 33 | bnb_4bit_compute_dtype=torch.bfloat16, 34 | bnb_4bit_use_double_quant=True, 35 | bnb_4bit_quant_type="nf4", 36 | ) 37 | 38 | model = AutoModelForCausalLM.from_pretrained( 39 | model_values.model, 40 | load_in_4bit=True, 41 | torch_dtype=torch.bfloat16, 42 | quantization_config=quantization_config, 43 | device_map="auto", 44 | cache_dir=path_values.cache_dir, 45 | ) 46 | 47 | if not only_base: 48 | model = PeftModel.from_pretrained( 49 | model, 50 | os.path.join(path_values.output_dir, "general"), 51 | adapter_name="general", 52 | ) 53 | 54 | # Load adapters 55 | for expert_name in experts.keys(): 56 | model.load_adapter( 57 | os.path.join(path_values.output_dir, expert_name), expert_name 58 | ) 59 | 60 | embeddings_model = SentenceTransformer(model_values.embeddings_model, device="cuda") 61 | embeddings_tokenizer = AutoTokenizer.from_pretrained(model_values.embeddings_model) 62 | embeddings = Embeddings( 63 | embeddings_model, embeddings_tokenizer, model_values.embeddings_max_length 64 | ) 65 | router = Router(embeddings, experts) 66 | 67 | # Load dataset from the hub 68 | datasets.config.DOWNLOADED_DATASETS_PATH = path_values.dataset_dir 69 | os.environ["HF_DATASETS_CACHE"] = path_values.cache_dir 70 | 71 | if interactive: 72 | while True: 73 | instruction = input("Enter instruction: ") 74 | if instruction == "exit": 75 | 
break 76 | 77 | if instruction == "": 78 | continue 79 | 80 | prompt = f""" 81 | ### Input: 82 | {instruction} 83 | 84 | ### Response: 85 | """ 86 | 87 | if not only_base: 88 | routed_expert = router.route(instruction)[0][0] 89 | print(f"---- Routing to {routed_expert}") 90 | model.set_adapter(routed_expert) 91 | 92 | input_ids = tokenizer( 93 | prompt, return_tensors="pt", truncation=True 94 | ).input_ids.cuda() 95 | 96 | outputs = model.generate( 97 | input_ids=input_ids, 98 | max_new_tokens=500, 99 | do_sample=True, 100 | top_p=0.9, 101 | temperature=0.9, 102 | ) 103 | 104 | print( 105 | f"Response: \n{tokenizer.batch_decode(outputs.detach().cpu().numpy(), skip_special_tokens=True)[0][len(prompt):]}" 106 | ) 107 | else: 108 | # Load dataset from the hub and get a sample 109 | dataset = load_dataset(model_values.dataset, split="train") 110 | for expert_name, expert_data in experts.items(): 111 | expert_dataset = dataset.filter( 112 | lambda row: row["category"] in expert_data["categories"] 113 | ) 114 | sample = expert_dataset[randrange(len(expert_dataset))] 115 | 116 | prompt = f"""{sample['system']} 117 | 118 | ### Input: 119 | {sample['instruction']} 120 | 121 | ### Response: 122 | """ 123 | 124 | routed_experts = router.route(sample["instruction"]) 125 | print(f"---- Routing to {routed_experts}. Ground truth: {expert_name}") 126 | 127 | input_ids = tokenizer( 128 | prompt, return_tensors="pt", truncation=True 129 | ).input_ids.cuda() 130 | 131 | model.set_adapter(expert_name) 132 | 133 | outputs = model.generate( 134 | input_ids=input_ids, 135 | max_new_tokens=500, 136 | do_sample=True, 137 | top_p=0.9, 138 | temperature=0.9, 139 | ) 140 | 141 | print(f"Prompt:\n{sample['instruction']}\n") 142 | print( 143 | f"Generated response (expert):\n{tokenizer.batch_decode(outputs.detach().cpu().numpy(), skip_special_tokens=True)[0][len(prompt):]}" 144 | ) 145 | print(f"Ground truth:\n{sample['response']}") 146 | print("----------------------------------------\n\n") 147 | -------------------------------------------------------------------------------- /herd/segment_experts.py: -------------------------------------------------------------------------------- 1 | import datasets 2 | import numpy as np 3 | from herd.embeddings import Embeddings 4 | from loguru import logger 5 | from sentence_transformers import SentenceTransformer 6 | from transformers import AutoTokenizer 7 | 8 | 9 | def segment_experts(model_values, path_values, experts): 10 | dataset = datasets.load_dataset(model_values.dataset, split="train") 11 | model = SentenceTransformer(model_values.embeddings_model, device="cuda") 12 | tokenizer = AutoTokenizer.from_pretrained(model_values.embeddings_model) 13 | embeddings = Embeddings(model, tokenizer, model_values.embeddings_max_length) 14 | 15 | for expert_name, expert_data in experts.items(): 16 | logger.info(f"Calculating embedding for {expert_name}...") 17 | 18 | expert_dataset = dataset.filter( 19 | lambda row: row["category"] in expert_data["categories"] 20 | ) 21 | 22 | # get 100 random samples from the expert dataset. TODO: Is 100 a good number?
23 | expert_dataset = expert_dataset.select( 24 | np.random.choice(expert_dataset.shape[0], 100) 25 | ) 26 | 27 | # create an empty numpy array to store embeddings 28 | es = np.empty((0, embeddings.model.get_sentence_embedding_dimension())) 29 | # compute embeddings for each sample 30 | for sample in expert_dataset: 31 | e = embeddings.calculate_embeddings(sample["instruction"]) 32 | es = np.vstack((es, e)) 33 | 34 | # compute average embedding. TODO: Is average a good choice? What about max_pooling? 35 | avg_embedding = np.average(es, axis=0) 36 | 37 | # save embedding to file 38 | np.save(f"{path_values.experts_dir}/{expert_name}.npy", avg_embedding) 39 | logger.info( 40 | f" Saved embedding for {expert_name} at experts/{expert_name}.npy." 41 | ) 42 | 43 | 44 | if __name__ == "__main__": 45 | segment_experts() 46 | -------------------------------------------------------------------------------- /jobscript.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | #BSUB -q gpua100 3 | #BSUB -J fine-tune-llama 4 | #BSUB -W 23:00 5 | #BSUB -B 6 | #BSUB -N 7 | ### request the number of GPUs 8 | #BSUB -gpu "num=1::mode=exclusive_process" 9 | ### request the number of CPU cores (at least 4x the number of GPUs) 10 | #BSUB -n 4 11 | ### we want to have this on a single node 12 | #BSUB -R "span[hosts=1]" 13 | ### we need to request CPU memory, too (note: this is per CPU core) 14 | #BSUB -R "rusage[mem=8GB]" 15 | #BSUB -o logs/%J.out 16 | #BSUB -e logs/%J.err 17 | 18 | module load cuda/11.6 19 | module load python3/3.11.4 20 | 21 | source .venv/bin/activate 22 | 23 | export HF_DATASETS_CACHE="/work3/s212722/herd/cache" 24 | 25 | export WANDB_API_KEY=[redacted] 26 | 27 | # set the wandb project where this run will be logged 28 | export WANDB_PROJECT="herd-llama" 29 | # save your trained model checkpoint to wandb 30 | export WANDB_LOG_MODEL="true" 31 | # turn off watch to log faster 32 | export WANDB_WATCH="false" 33 | 34 | python herd/herd.py finetune_experts --config_file config_experts.ini 35 | -------------------------------------------------------------------------------- /jobscript_run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | #BSUB -q gpua100 3 | #BSUB -J fine-tune-llama 4 | #BSUB -W 23:00 5 | #BSUB -B 6 | #BSUB -N 7 | ### request the number of GPUs 8 | #BSUB -gpu "num=1::mode=exclusive_process" 9 | ### request the number of CPU cores (at least 4x the number of GPUs) 10 | #BSUB -n 4 11 | ### we want to have this on a single node 12 | #BSUB -R "span[hosts=1]" 13 | ### we need to request CPU memory, too (note: this is per CPU core) 14 | #BSUB -R "rusage[mem=8GB]" 15 | #BSUB -o logs/%J.out 16 | #BSUB -e logs/%J.err 17 | 18 | module load cuda/11.6 19 | module load python3/3.11.4 20 | 21 | source .venv/bin/activate 22 | 23 | export HF_DATASETS_CACHE="/work3/s212722/herd/cache" 24 | python herd/herd.py run_model --config_file config_experts.ini 25 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import asyncio 3 | import datetime 4 | import json 5 | import os 6 | import time 7 | import uuid 8 | from configparser import ConfigParser, ExtendedInterpolation 9 | from contextlib import asynccontextmanager 10 | from functools import wraps 11 | from typing import Dict, List 12 | 13 | import torch 14 | import uvicorn 15 | from dotenv import load_dotenv 16 | 
from fastapi import FastAPI, Request 17 | from loguru import logger 18 | from pydantic import BaseModel 19 | from sentence_transformers import SentenceTransformer 20 | from transformers import ( 21 | AutoModelForCausalLM, 22 | AutoTokenizer, 23 | BitsAndBytesConfig, 24 | StoppingCriteria, 25 | StoppingCriteriaList, 26 | ) 27 | 28 | from herd.embeddings import Embeddings 29 | from herd.models import ModelValues, PathValues 30 | from herd.multilora import MultiloraModel 31 | from herd.router import Router 32 | 33 | load_dotenv() # take environment variables from .env. 34 | 35 | MODEL_LOCK = asyncio.Lock() 36 | 37 | DEFAULT_STOPS = [ 38 | "USER:", 39 | "ASSISTANT:", 40 | "### Instruction", 41 | "### Response", 42 | # These are often used as refusals, warnings, etc, but may also remove useful info. 43 | # "\nRemember," 44 | # "\nPlease note," 45 | ] 46 | 47 | # TODO: What is this??? 48 | USER_STOP_TOKENS = [ 49 | torch.tensor([3148, 1001, 29901], device="cuda"), 50 | torch.tensor([11889, 29901], device="cuda"), 51 | ] 52 | 53 | 54 | class StoppingCriteriaSub(StoppingCriteria): 55 | def __init__(self, stops=None, encounters=1): 56 | if stops is None: 57 | stops = [] 58 | super().__init__() 59 | self.stops = list(stops + USER_STOP_TOKENS) 60 | 61 | def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor): 62 | for stop in self.stops: 63 | if torch.all((stop == input_ids[0][-len(stop) :])).item(): 64 | return True 65 | return False 66 | 67 | 68 | # TODO: Do not use global variables. 69 | 70 | app_data = {} 71 | 72 | 73 | @asynccontextmanager 74 | async def lifespan(app: FastAPI): 75 | # Load model and adapters 76 | config = ConfigParser(interpolation=ExtendedInterpolation()) 77 | config.read(app.args.config_file) 78 | 79 | model_values = ModelValues(**dict(config.items("Models"))) 80 | path_values = PathValues(**dict(config.items("Paths"))) 81 | 82 | # Create base_dir if it does not exists 83 | if not os.path.exists(path_values.base_dir): 84 | os.makedirs(path_values.base_dir) 85 | 86 | # Load tokenizer 87 | app_data["tokenizer"] = AutoTokenizer.from_pretrained( 88 | model_values.model, cache_dir=path_values.cache_dir 89 | ) 90 | 91 | logger.debug(f"Loading model {model_values.model}") 92 | logger.debug(f"Tokenizer {app_data['tokenizer']}") 93 | 94 | quantization_config = BitsAndBytesConfig( 95 | load_in_4bit=True, 96 | bnb_4bit_compute_dtype=torch.bfloat16, 97 | bnb_4bit_use_double_quant=True, 98 | bnb_4bit_quant_type="nf4", 99 | ) 100 | 101 | app_data["model"] = AutoModelForCausalLM.from_pretrained( 102 | model_values.model, 103 | load_in_4bit=True, 104 | torch_dtype=torch.bfloat16, 105 | quantization_config=quantization_config, 106 | device_map="auto", 107 | cache_dir=path_values.cache_dir, 108 | ) 109 | 110 | if not app.args.only_base: 111 | embeddings_model = SentenceTransformer(model_values.embeddings_model, device="cuda") 112 | embeddings_tokenizer = AutoTokenizer.from_pretrained(model_values.embeddings_model) 113 | embeddings = Embeddings( 114 | embeddings_model, embeddings_tokenizer, model_values.embeddings_max_length 115 | ) 116 | 117 | # Read experts.json file 118 | with open(path_values.experts_file, "r") as json_file: 119 | experts = json.loads(json_file.read()) 120 | # Create router 121 | 122 | app_data["model"] = MultiloraModel( 123 | app_data["model"], 124 | path_values.output_dir, 125 | list(experts.keys()), 126 | Router(embeddings, experts), 127 | ) 128 | 129 | yield 130 | 131 | app_data.clear() 132 | 133 | 134 | app = FastAPI(lifespan=lifespan) 135 | 136 | 137 
| class ChatRequest(BaseModel): 138 | model: str 139 | experts: List[str] = None 140 | messages: List[Dict[str, str]] 141 | temperature: float = 0.5 142 | top_k: int = 50 143 | top_p: float = 1.0 144 | repetition_penalty: float = 1.0 145 | stop: List[str] = DEFAULT_STOPS 146 | max_tokens: int = None 147 | top_experts: int = 1 148 | 149 | 150 | @app.get("/") 151 | async def root(): 152 | return {"message": "Hello World"} 153 | 154 | 155 | @app.post("/v1/chat/completions") 156 | async def chat_completions(raw_request: Request): 157 | """Simulate the OpenAI /v1/chat/completions endpoint. 158 | 159 | NOTE: Parameters supported in request include: 160 | - model: str. Ignored for now. Present for compatibility with OpenAI API. 161 | - messages: list[dict[str, str]] 162 | - temperature: float 163 | - repetition_penalty: float 164 | - top_p: float 165 | - top_k: int 166 | - stop: list[str] 167 | - max_tokens: int 168 | - top_experts: int. This parameter is not present in the OpenAI API. 169 | 170 | Example request: 171 | curl -s -XPOST http://127.0.0.1:8000/v1/chat/completions -H 'content-type: application/json' -d '{ 172 | "model": "", 173 | "messages": [ 174 | { 175 | "role": "system", 176 | "content": "A chat.", 177 | }, 178 | { 179 | "role": "user", 180 | "content": "Lorem ipsum dolor sit amet" 181 | } 182 | ] 183 | }' 184 | """ 185 | request = ChatRequest(**await raw_request.json()) 186 | async with MODEL_LOCK: 187 | return complete_request(request) 188 | 189 | 190 | def complete_request(request: ChatRequest): 191 | request_id = f"cmpl-{uuid.uuid4()}" 192 | 193 | stop_words_ids = get_stop_words_ids(request.stop) 194 | stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)]) 195 | 196 | logger.debug(f"Request {request}") 197 | prompt = get_prompt(request.messages) 198 | input_ids = get_input_ids(prompt) 199 | response, duration = generate_response(input_ids, prompt, request, stopping_criteria) 200 | 201 | logger.debug(f"Response {response}") 202 | logger.debug(f"Duration {duration}") 203 | 204 | return create_completion_response(request, request_id, response, duration, input_ids) 205 | 206 | 207 | def get_stop_words_ids(stop_words): 208 | return [ 209 | app_data["tokenizer"](stop_word, return_tensors="pt").input_ids.to("cuda")[0][1:] 210 | for stop_word in stop_words 211 | ] 212 | 213 | 214 | def get_prompt(messages): 215 | system_message = messages[0]["content"] 216 | instruction_message = messages[1]["content"] 217 | return f"{system_message}\n### Input:\n{instruction_message}\n\n### Response:" 218 | 219 | 220 | def get_input_ids(prompt): 221 | return app_data["tokenizer"](prompt, return_tensors="pt", truncation=True).input_ids.cuda() 222 | 223 | 224 | def create_completion_response(request, request_id, response, duration, input_ids): 225 | return { 226 | "id": request_id, 227 | "object": "chat.completion", 228 | "created": int(time.time()), 229 | "duration": duration, 230 | "routing_duration": "TODO", 231 | "model": request.model, 232 | "expert": "TODO", 233 | "choices": [ 234 | { 235 | "index": 0, 236 | "message": { 237 | "role": "assistant", 238 | "content": response.strip(), 239 | }, 240 | "finish_reason": "stop", 241 | } 242 | ], 243 | "usage": { 244 | "prompt_tokens": len(input_ids[0]), 245 | "completion_tokens": len(response[0]), 246 | "total_tokens": len(input_ids[0]) + len(response[0]), 247 | }, 248 | } 249 | 250 | 251 | def measure_time(func): 252 | @wraps(func) 253 | def wrapper(*args, **kwargs): 254 | started_at = datetime.datetime.utcnow() 255 | 
result = func(*args, **kwargs) 256 | duration = (datetime.datetime.utcnow() - started_at).total_seconds() 257 | return result, duration 258 | 259 | return wrapper 260 | 261 | 262 | @measure_time 263 | def generate_response( 264 | input_ids: torch.Tensor, 265 | prompt: str, 266 | request: ChatRequest, 267 | stopping_criteria: StoppingCriteriaList, 268 | ): 269 | max_tokens = app_data["model"].config.max_position_embeddings - len(input_ids[0]) - 1 270 | 271 | output = app_data["model"].generate( 272 | prompt=prompt, 273 | top=request.top_experts, 274 | input_ids=input_ids, 275 | stopping_criteria=stopping_criteria, 276 | repetition_penalty=request.repetition_penalty, 277 | top_p=request.top_p, 278 | top_k=request.top_k, 279 | temperature=request.temperature, 280 | max_new_tokens=max_tokens, 281 | min_new_tokens=1, 282 | do_sample=True, 283 | use_cache=False, 284 | ) 285 | 286 | logger.debug("Decoding response") 287 | 288 | return app_data["tokenizer"].batch_decode( 289 | output.detach().cpu().numpy(), skip_special_tokens=True 290 | )[0][len(prompt) :] 291 | 292 | 293 | def prompt_template(system: str, instruction: str): 294 | prompt = f"""{system} 295 | ### Input: 296 | {instruction} 297 | 298 | ### Response: 299 | """ 300 | return prompt 301 | 302 | 303 | def main(): 304 | parser = argparse.ArgumentParser( 305 | description="LMoE API server, somewhat similar to OpenAI API.", 306 | ) 307 | parser.add_argument("-i", "--host", type=str, default="127.0.0.1", help="host name") 308 | parser.add_argument("-p", "--port", type=int, default=8000, help="port number") 309 | parser.add_argument("--config-file", default="config_experts.ini") 310 | parser.add_argument("--only-base", default=False, type=bool) 311 | 312 | args = parser.parse_args() 313 | app.args = args 314 | 315 | # Start the API server. 
316 | uvicorn.run( 317 | app, 318 | host=args.host, 319 | port=args.port, 320 | log_level="info", 321 | timeout_keep_alive=5, 322 | ) 323 | 324 | 325 | if __name__ == "__main__": 326 | main() 327 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | accelerate==0.21.0 2 | aiohttp==3.8.5 3 | aiosignal==1.3.1 4 | annotated-types==0.5.0 5 | anyio==3.7.1 6 | appdirs==1.4.4 7 | async-timeout==4.0.3 8 | attrs==23.1.0 9 | azure-common==1.1.28 10 | azure-core==1.29.3 11 | azure-storage-blob==12.17.0 12 | bcrypt==4.0.1 13 | bitsandbytes==0.41.1 14 | black==23.7.0 15 | boto3==1.28.36 16 | botocore==1.31.36 17 | cachetools==5.3.1 18 | certifi==2023.7.22 19 | cffi==1.15.1 20 | charset-normalizer==3.2.0 21 | click==8.1.7 22 | cmake==3.27.2 23 | cryptography==41.0.3 24 | datasets==2.13.0 25 | dill==0.3.6 26 | docker-pycreds==0.4.0 27 | faiss-cpu==1.7.4 28 | fastapi==0.103.1 29 | filelock==3.9.0 30 | frozenlist==1.4.0 31 | fsspec==2023.4.0 32 | gitdb==4.0.10 33 | GitPython==3.1.35 34 | google-api-core==2.11.1 35 | google-auth==2.22.0 36 | google-cloud-core==2.3.3 37 | google-cloud-storage==2.10.0 38 | google-crc32c==1.5.0 39 | google-resumable-media==2.5.0 40 | googleapis-common-protos==1.60.0 41 | h11==0.14.0 42 | httptools==0.6.0 43 | huggingface-hub==0.16.4 44 | idna==3.4 45 | isodate==0.6.1 46 | Jinja2==3.1.2 47 | jmespath==1.0.1 48 | joblib==1.3.2 49 | lit==16.0.6 50 | loguru==0.7.0 51 | MarkupSafe==2.1.2 52 | mpmath==1.2.1 53 | multidict==6.0.4 54 | multiprocess==0.70.14 55 | mypy-extensions==1.0.0 56 | networkx==3.0rc1 57 | nltk==3.8.1 58 | numpy==1.24.1 59 | nvidia-cublas-cu11==11.10.3.66 60 | nvidia-cuda-cupti-cu11==11.7.101 61 | nvidia-cuda-nvrtc-cu11==11.7.99 62 | nvidia-cuda-runtime-cu11==11.7.99 63 | nvidia-cudnn-cu11==8.5.0.96 64 | nvidia-cufft-cu11==10.9.0.58 65 | nvidia-curand-cu11==10.2.10.91 66 | nvidia-cusolver-cu11==11.4.0.1 67 | nvidia-cusparse-cu11==11.7.4.91 68 | nvidia-nccl-cu11==2.14.3 69 | nvidia-nvtx-cu11==11.7.91 70 | packaging==23.1 71 | pandas==2.0.3 72 | paramiko==3.3.1 73 | pathspec==0.11.2 74 | pathtools==0.1.2 75 | peft==0.4.0 76 | Pillow==9.3.0 77 | platformdirs==3.10.0 78 | protobuf==4.24.2 79 | psutil==5.9.5 80 | pyarrow==13.0.0 81 | pyasn1==0.5.0 82 | pyasn1-modules==0.3.0 83 | pycparser==2.21 84 | pydantic==2.3.0 85 | pydantic_core==2.6.3 86 | PyNaCl==1.5.0 87 | python-dateutil==2.8.2 88 | python-dotenv==1.0.0 89 | pytz==2023.3 90 | PyYAML==6.0.1 91 | regex==2023.8.8 92 | requests==2.31.0 93 | rsa==4.9 94 | s3transfer==0.6.2 95 | safetensors==0.3.3 96 | scikit-learn==1.3.0 97 | scipy==1.11.2 98 | sentence-transformers==2.2.2 99 | sentencepiece==0.1.99 100 | sentry-sdk==1.30.0 101 | setproctitle==1.3.2 102 | six==1.16.0 103 | smart-open==6.3.0 104 | smmap==5.0.0 105 | sniffio==1.3.0 106 | starlette==0.27.0 107 | sympy==1.11.1 108 | threadpoolctl==3.2.0 109 | tokenizers==0.13.3 110 | tomli==2.0.1 111 | torch==2.0.1 112 | torchvision==0.15.2 113 | tqdm==4.66.1 114 | transformers==4.31.0 115 | triton==2.0.0 116 | trl==0.4.7 117 | typing_extensions==4.7.1 118 | tzdata==2023.3 119 | urllib3==1.26.16 120 | uvicorn==0.23.2 121 | uvloop==0.17.0 122 | wandb==0.15.10 123 | watchfiles==0.20.0 124 | websockets==11.0.3 125 | xxhash==3.3.0 126 | yarl==1.9.2 127 | -------------------------------------------------------------------------------- /scripts/app.py: -------------------------------------------------------------------------------- 1 | import 
argparse 2 | import asyncio 3 | import datetime 4 | import json 5 | import os 6 | import time 7 | import uuid 8 | from configparser import ConfigParser, ExtendedInterpolation 9 | from contextlib import asynccontextmanager 10 | from functools import wraps 11 | from typing import Dict, List 12 | 13 | import torch 14 | import uvicorn 15 | from dotenv import load_dotenv 16 | from fastapi import FastAPI, Request 17 | from loguru import logger 18 | from pydantic import BaseModel 19 | from sentence_transformers import SentenceTransformer 20 | from transformers import ( 21 | AutoModelForCausalLM, 22 | AutoTokenizer, 23 | BitsAndBytesConfig, 24 | StoppingCriteria, 25 | StoppingCriteriaList, 26 | ) 27 | 28 | from herd.embeddings import Embeddings 29 | from herd.models import ModelValues, PathValues 30 | from herd.multilora import MultiloraModel 31 | from herd.router import Router 32 | 33 | load_dotenv() # take environment variables from .env. 34 | 35 | MODEL_LOCK = asyncio.Lock() 36 | 37 | DEFAULT_STOPS = [ 38 | "USER:", 39 | "ASSISTANT:", 40 | "### Instruction", 41 | "### Response", 42 | # These are often used as refusals, warnings, etc, but may also remove useful info. 43 | # "\nRemember," 44 | # "\nPlease note," 45 | ] 46 | 47 | # TODO: What is this??? 48 | USER_STOP_TOKENS = [ 49 | torch.tensor([3148, 1001, 29901], device="cuda"), 50 | torch.tensor([11889, 29901], device="cuda"), 51 | ] 52 | 53 | 54 | class StoppingCriteriaSub(StoppingCriteria): 55 | def __init__(self, stops=None, encounters=1): 56 | if stops is None: 57 | stops = [] 58 | super().__init__() 59 | self.stops = list(stops + USER_STOP_TOKENS) 60 | 61 | def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor): 62 | for stop in self.stops: 63 | if torch.all((stop == input_ids[0][-len(stop) :])).item(): 64 | return True 65 | return False 66 | 67 | 68 | # TODO: Do not use global variables. 
69 | 70 | app_data = {} 71 | 72 | 73 | @asynccontextmanager 74 | async def lifespan(app: FastAPI): 75 | # Load model and adapters 76 | config = ConfigParser(interpolation=ExtendedInterpolation()) 77 | config.read(app.args.config_file) 78 | 79 | model_values = ModelValues(**dict(config.items("Models"))) 80 | path_values = PathValues(**dict(config.items("Paths"))) 81 | 82 | # Create base_dir if it does not exists 83 | if not os.path.exists(path_values.base_dir): 84 | os.makedirs(path_values.base_dir) 85 | 86 | # Load tokenizer 87 | app_data["tokenizer"] = AutoTokenizer.from_pretrained( 88 | model_values.model, cache_dir=path_values.cache_dir 89 | ) 90 | 91 | logger.debug(f"Loading model {model_values.model}") 92 | logger.debug(f"Tokenizer {app_data['tokenizer']}") 93 | 94 | quantization_config = BitsAndBytesConfig( 95 | load_in_4bit=True, 96 | bnb_4bit_compute_dtype=torch.bfloat16, 97 | bnb_4bit_use_double_quant=True, 98 | bnb_4bit_quant_type="nf4", 99 | ) 100 | 101 | app_data["model"] = AutoModelForCausalLM.from_pretrained( 102 | model_values.model, 103 | load_in_4bit=True, 104 | torch_dtype=torch.bfloat16, 105 | quantization_config=quantization_config, 106 | device_map="auto", 107 | cache_dir=path_values.cache_dir, 108 | ) 109 | 110 | if not app.args.only_base: 111 | embeddings_model = SentenceTransformer(model_values.embeddings_model, device="cuda") 112 | embeddings_tokenizer = AutoTokenizer.from_pretrained(model_values.embeddings_model) 113 | embeddings = Embeddings( 114 | embeddings_model, embeddings_tokenizer, model_values.embeddings_max_length 115 | ) 116 | 117 | # Read experts.json file 118 | with open(path_values.experts_file, "r") as json_file: 119 | experts = json.loads(json_file.read()) 120 | # Create router 121 | 122 | app_data["model"] = MultiloraModel( 123 | app_data["model"], 124 | path_values.output_dir, 125 | list(experts.keys()), 126 | Router(embeddings, experts), 127 | ) 128 | 129 | yield 130 | 131 | app_data.clear() 132 | 133 | 134 | app = FastAPI(lifespan=lifespan) 135 | 136 | 137 | class ChatRequest(BaseModel): 138 | model: str 139 | experts: List[str] = None 140 | messages: List[Dict[str, str]] 141 | temperature: float = 0.5 142 | top_k: int = 50 143 | top_p: float = 1.0 144 | repetition_penalty: float = 1.0 145 | stop: List[str] = DEFAULT_STOPS 146 | max_tokens: int = None 147 | top_experts: int = 1 148 | 149 | 150 | @app.get("/") 151 | async def root(): 152 | return {"message": "Hello World"} 153 | 154 | 155 | @app.post("/v1/chat/completions") 156 | async def chat_completions(raw_request: Request): 157 | """Simulate the OpenAI /v1/chat/completions endpoint. 158 | 159 | NOTE: Parameters supported in request include: 160 | - model: str. Ignored for now. Present for compatibility with OpenAI API. 161 | - messages: list[dict[str, str]] 162 | - temperature: float 163 | - repetition_penalty: float 164 | - top_p: float 165 | - top_k: int 166 | - stop: list[str] 167 | - max_tokens: int 168 | - top_experts: int. This parameter is not present in the OpenAI API. 
169 | 170 | Example request: 171 | curl -s -XPOST http://127.0.0.1:8000/v1/chat/completions -H 'content-type: application/json' -d '{ 172 | "model": "", 173 | "messages": [ 174 | { 175 | "role": "system", 176 | "content": "A chat.", 177 | }, 178 | { 179 | "role": "user", 180 | "content": "Lorem ipsum dolor sit amet" 181 | } 182 | ] 183 | }' 184 | """ 185 | request = ChatRequest(**await raw_request.json()) 186 | async with MODEL_LOCK: 187 | return complete_request(request) 188 | 189 | 190 | def complete_request(request: ChatRequest): 191 | request_id = f"cmpl-{uuid.uuid4()}" 192 | 193 | stop_words_ids = get_stop_words_ids(request.stop) 194 | stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)]) 195 | 196 | logger.debug(f"Request {request}") 197 | prompt = get_prompt(request.messages) 198 | input_ids = get_input_ids(prompt) 199 | response, duration = generate_response(input_ids, prompt, request, stopping_criteria) 200 | 201 | logger.debug(f"Response {response}") 202 | logger.debug(f"Duration {duration}") 203 | 204 | return create_completion_response(request, request_id, response, duration, input_ids) 205 | 206 | 207 | def get_stop_words_ids(stop_words): 208 | return [ 209 | app_data["tokenizer"](stop_word, return_tensors="pt").input_ids.to("cuda")[0][1:] 210 | for stop_word in stop_words 211 | ] 212 | 213 | 214 | def get_prompt(messages): 215 | system_message = messages[0]["content"] 216 | instruction_message = messages[1]["content"] 217 | return f"{system_message}\n### Input:\n{instruction_message}\n\n### Response:" 218 | 219 | 220 | def get_input_ids(prompt): 221 | return app_data["tokenizer"](prompt, return_tensors="pt", truncation=True).input_ids.cuda() 222 | 223 | 224 | def create_completion_response(request, request_id, response, duration, input_ids): 225 | return { 226 | "id": request_id, 227 | "object": "chat.completion", 228 | "created": int(time.time()), 229 | "duration": duration, 230 | "routing_duration": "TODO", 231 | "model": request.model, 232 | "expert": "TODO", 233 | "choices": [ 234 | { 235 | "index": 0, 236 | "message": { 237 | "role": "assistant", 238 | "content": response.strip(), 239 | }, 240 | "finish_reason": "stop", 241 | } 242 | ], 243 | "usage": { 244 | "prompt_tokens": len(input_ids[0]), 245 | "completion_tokens": len(response[0]), 246 | "total_tokens": len(input_ids[0]) + len(response[0]), 247 | }, 248 | } 249 | 250 | 251 | def measure_time(func): 252 | @wraps(func) 253 | def wrapper(*args, **kwargs): 254 | started_at = datetime.datetime.utcnow() 255 | result = func(*args, **kwargs) 256 | duration = (datetime.datetime.utcnow() - started_at).total_seconds() 257 | return result, duration 258 | 259 | return wrapper 260 | 261 | 262 | @measure_time 263 | def generate_response( 264 | input_ids: torch.Tensor, 265 | prompt: str, 266 | request: ChatRequest, 267 | stopping_criteria: StoppingCriteriaList, 268 | ): 269 | max_tokens = app_data["model"].config.max_position_embeddings - len(input_ids[0]) - 1 270 | 271 | output = app_data["model"].generate( 272 | prompt=prompt, 273 | top=request.top_experts, 274 | input_ids=input_ids, 275 | stopping_criteria=stopping_criteria, 276 | repetition_penalty=request.repetition_penalty, 277 | top_p=request.top_p, 278 | top_k=request.top_k, 279 | temperature=request.temperature, 280 | max_new_tokens=max_tokens, 281 | min_new_tokens=1, 282 | do_sample=True, 283 | use_cache=False, 284 | ) 285 | 286 | logger.debug("Decoding response") 287 | 288 | return app_data["tokenizer"].batch_decode( 289 | 
output.detach().cpu().numpy(), skip_special_tokens=True 290 | )[0][len(prompt) :] 291 | 292 | 293 | def prompt_template(system: str, instruction: str): 294 | prompt = f"""{system} 295 | ### Input: 296 | {instruction} 297 | 298 | ### Response: 299 | """ 300 | return prompt 301 | 302 | 303 | def main(): 304 | parser = argparse.ArgumentParser( 305 | description="LMoE API server, somewhat similar to OpenAI API.", 306 | ) 307 | parser.add_argument("-i", "--host", type=str, default="127.0.0.1", help="host name") 308 | parser.add_argument("-p", "--port", type=int, default=8000, help="port number") 309 | parser.add_argument("--config-file", default="config_experts.ini") 310 | parser.add_argument("--only-base", default=False, type=bool) 311 | 312 | args = parser.parse_args() 313 | app.args = args 314 | 315 | # Start the API server. 316 | uvicorn.run( 317 | app, 318 | host=args.host, 319 | port=args.port, 320 | log_level="info", 321 | timeout_keep_alive=5, 322 | ) 323 | 324 | 325 | if __name__ == "__main__": 326 | main() 327 | -------------------------------------------------------------------------------- /scripts/main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | from configparser import ConfigParser, ExtendedInterpolation 5 | 6 | from herd.finetune import finetune 7 | from herd.models import ModelValues, PathValues 8 | from herd.run_model import run_model 9 | from herd.segment_experts import segment_experts 10 | 11 | 12 | def main(): 13 | parser = argparse.ArgumentParser(description="Fine-Tune Llama 2 models") 14 | parser.add_argument( 15 | "action", 16 | choices=["finetune", "run_model", "segment_experts"], 17 | ) 18 | parser.add_argument("--config-file", default="config_experts.ini") 19 | parser.add_argument("--only-base", default=False, type=bool) 20 | parser.add_argument("--interactive", default=False, type=bool) 21 | args = parser.parse_args() 22 | 23 | config = ConfigParser(interpolation=ExtendedInterpolation()) 24 | config.read(args.config_file) 25 | 26 | model_values = ModelValues(**dict(config.items("Models"))) 27 | path_values = PathValues(**dict(config.items("Paths"))) 28 | 29 | # Create base_dir if it does not exists 30 | if not os.path.exists(path_values.base_dir): 31 | os.makedirs(path_values.base_dir) 32 | 33 | # Read experts.json file 34 | with open(path_values.experts_file, "r") as json_file: 35 | experts = json.loads(json_file.read()) 36 | # Process based on action 37 | match args.action: 38 | case "finetune": 39 | finetune(model_values, path_values, config, experts) 40 | case "run_model": 41 | run_model( 42 | model_values, path_values, experts, args.only_base, args.interactive 43 | ) 44 | case "segment_experts": 45 | segment_experts(model_values, path_values, experts) 46 | 47 | 48 | if __name__ == "__main__": 49 | main() 50 | --------------------------------------------------------------------------------