├── training └── axolotl │ ├── src │ ├── axolotl │ │ ├── __init__.py │ │ ├── common │ │ │ ├── __init__.py │ │ │ ├── const.py │ │ │ └── cli.py │ │ ├── core │ │ │ └── __init__.py │ │ ├── models │ │ │ ├── __init__.py │ │ │ └── phi │ │ │ │ ├── __init__.py │ │ │ │ └── configuration_mixformer_sequential.py │ │ ├── utils │ │ │ ├── __init__.py │ │ │ ├── dict.py │ │ │ ├── wandb_.py │ │ │ ├── tokenization.py │ │ │ ├── bench.py │ │ │ ├── schedulers.py │ │ │ ├── collators.py │ │ │ ├── distributed.py │ │ │ ├── trainer.py │ │ │ └── dataloader.py │ │ ├── prompt_strategies │ │ │ ├── alpaca_instruct.py │ │ │ ├── __init__.py │ │ │ ├── sharegpt_jokes.py │ │ │ ├── orcamini.py │ │ │ ├── metharme.py │ │ │ ├── completion.py │ │ │ ├── user_defined.py │ │ │ ├── sharegpt.py │ │ │ ├── pygmalion.py │ │ │ ├── alpaca_chat.py │ │ │ ├── context_qa.py │ │ │ ├── alpaca_w_system.py │ │ │ ├── creative_acr.py │ │ │ └── llama2_chat.py │ │ ├── cli │ │ │ ├── merge_lora.py │ │ │ ├── inference.py │ │ │ ├── shard.py │ │ │ ├── train.py │ │ │ └── __init__.py │ │ ├── monkeypatch │ │ │ ├── llama_embeddings_hijack.py │ │ │ ├── mistral_embeddings_hijack.py │ │ │ ├── llama_expand_mask.py │ │ │ ├── btlm_attn_hijack_flash.py │ │ │ ├── xpos_rope_llama_monkey_patch.py │ │ │ ├── utils.py │ │ │ ├── llama_attn_hijack_sdp.py │ │ │ ├── llama_attn_hijack_xformers.py │ │ │ └── fastchat_conversation_turns.py │ │ ├── logging_config.py │ │ ├── convert.py │ │ ├── train.py │ │ └── datasets.py │ └── axolotl.egg-info │ │ ├── top_level.txt │ │ ├── dependency_links.txt │ │ ├── SOURCES.txt │ │ ├── PKG-INFO │ │ └── requires.txt │ ├── dist │ └── axolotl-0.3.0-py3.9.egg │ ├── docs │ ├── faq.md │ ├── multipack.md │ ├── multi-node.md │ └── nccl.md │ ├── scripts │ ├── runpod-entrypoint.sh │ └── finetune.py │ ├── docker │ ├── Dockerfile-runpod │ ├── Dockerfile │ └── Dockerfile-base │ ├── requirements.txt │ ├── deepspeed │ ├── zero1.json │ ├── zero2.json │ └── zero3.json │ ├── setup.py │ └── examples │ └── mistral │ └── nips │ └── nips_02.yml ├── inference └── submission_2 │ ├── fast_api_requirements.txt │ ├── Dockerfile │ ├── api.py │ └── main.py └── data_prep ├── prepare_math_reasoning_dataset.py ├── prepare_exact_match_tasks_dataset.py ├── prepare_generation_tasks_dataset.py └── combine_datasets.py /training/axolotl/src/axolotl/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /training/axolotl/src/axolotl/common/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /training/axolotl/src/axolotl/core/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /training/axolotl/src/axolotl/models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /training/axolotl/src/axolotl/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /training/axolotl/src/axolotl.egg-info/top_level.txt: -------------------------------------------------------------------------------- 1 | 2 | 
-------------------------------------------------------------------------------- /inference/submission_2/fast_api_requirements.txt: -------------------------------------------------------------------------------- 1 | # FAST API 2 | fastapi>=0.68.0,<0.69.0 3 | pydantic>=1.8.0,<2.0.0 4 | uvicorn>=0.15.0,<0.16.0 5 | -------------------------------------------------------------------------------- /training/axolotl/src/axolotl/common/const.py: -------------------------------------------------------------------------------- 1 | """ 2 | Various shared constants 3 | """ 4 | 5 | DEFAULT_DATASET_PREPARED_PATH = "last_run_prepared" 6 | -------------------------------------------------------------------------------- /training/axolotl/src/axolotl.egg-info/dependency_links.txt: -------------------------------------------------------------------------------- 1 | https://download.pytorch.org/whl/cu118 2 | https://huggingface.github.io/autogptq-index/whl/cu118/ 3 | -------------------------------------------------------------------------------- /training/axolotl/dist/axolotl-0.3.0-py3.9.egg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Upaya07/NeurIPS-llm-efficiency-challenge/HEAD/training/axolotl/dist/axolotl-0.3.0-py3.9.egg -------------------------------------------------------------------------------- /training/axolotl/src/axolotl.egg-info/SOURCES.txt: -------------------------------------------------------------------------------- 1 | setup.py 2 | src/axolotl.egg-info/PKG-INFO 3 | src/axolotl.egg-info/SOURCES.txt 4 | src/axolotl.egg-info/dependency_links.txt 5 | src/axolotl.egg-info/requires.txt 6 | src/axolotl.egg-info/top_level.txt -------------------------------------------------------------------------------- /training/axolotl/src/axolotl/models/phi/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | MixFormers model architecture used for phi models 3 | """ 4 | 5 | from .configuration_mixformer_sequential import MixFormerSequentialConfig # noqa 6 | from .modeling_mixformer_sequential import MixFormerSequentialForCausalLM # noqa 7 | -------------------------------------------------------------------------------- /training/axolotl/src/axolotl.egg-info/PKG-INFO: -------------------------------------------------------------------------------- 1 | Metadata-Version: 2.1 2 | Name: axolotl 3 | Version: 0.3.0 4 | Summary: LLM Trainer 5 | Provides-Extra: flash-attn 6 | Provides-Extra: deepspeed 7 | 8 | Axolotl is a tool designed to streamline the fine-tuning of various AI models, offering support for multiple configurations and architectures. 9 | -------------------------------------------------------------------------------- /training/axolotl/docs/faq.md: -------------------------------------------------------------------------------- 1 | # Axolotl FAQ's 2 | 3 | 4 | > The trainer stopped and hasn't progressed in several minutes. 5 | 6 | Usually an issue with the GPU's communicating with each other. See the [NCCL doc](../docs/nccl.md) 7 | 8 | > Exitcode -9 9 | 10 | This usually happens when you run out of system RAM. 
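One way to confirm it is host RAM rather than GPU memory is to watch `free -h` (or `htop`) while dataset tokenization and packing run; if free memory collapses right before the exit, reduce preprocessing parallelism or move to a machine with more system RAM.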
11 | 12 | > Exitcode -7 while using deepspeed 13 | 14 | Try upgrading deepspeed w: `pip install -U deepspeed` 15 | -------------------------------------------------------------------------------- /training/axolotl/src/axolotl/utils/dict.py: -------------------------------------------------------------------------------- 1 | """Module containing the DictDefault class""" 2 | 3 | from addict import Dict 4 | 5 | 6 | class DictDefault(Dict): 7 | """ 8 | A Dict that returns None instead of returning empty Dict for missing keys. 9 | """ 10 | 11 | def __missing__(self, key): 12 | return None 13 | 14 | def __or__(self, other): 15 | return DictDefault(super().__or__(other)) 16 | -------------------------------------------------------------------------------- /training/axolotl/scripts/runpod-entrypoint.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Export specific ENV variables to /etc/rp_environment 4 | echo "Exporting environment variables..." 5 | printenv | grep -E '^RUNPOD_|^PATH=|^_=' | sed 's/^\(.*\)=\(.*\)$/export \1="\2"/' >> /etc/rp_environment 6 | echo 'source /etc/rp_environment' >> ~/.bashrc 7 | 8 | if [[ $PUBLIC_KEY ]] 9 | then 10 | mkdir -p ~/.ssh 11 | chmod 700 ~/.ssh 12 | echo $PUBLIC_KEY >> ~/.ssh/authorized_keys 13 | chmod 700 -R ~/.ssh 14 | # Start the SSH service in the background 15 | service ssh start 16 | else 17 | echo "No PUBLIC_KEY ENV variable provided, not starting openSSH daemon" 18 | fi 19 | 20 | # Execute the passed arguments (CMD) 21 | exec "$@" 22 | -------------------------------------------------------------------------------- /inference/submission_2/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM pytorch/pytorch:2.1.0-cuda12.1-cudnn8-devel 2 | 3 | RUN apt-get update && apt-get install -y git python3-virtualenv wget 4 | 5 | RUN pip install transformers 6 | RUN pip install torch>=2.0.1 huggingface_hub accelerate sentencepiece optimum py7zr scipy appdirs peft 7 | 8 | WORKDIR /workspace 9 | # Setup server requriements 10 | COPY ./fast_api_requirements.txt fast_api_requirements.txt 11 | RUN pip install --no-cache-dir --upgrade -r fast_api_requirements.txt 12 | 13 | ENV HUGGINGFACE_REPO="upaya07/Birbal-7B-V1" 14 | 15 | # Copy over single file server 16 | COPY ./main.py main.py 17 | COPY ./api.py api.py 18 | # Run the server 19 | CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "80"] 20 | -------------------------------------------------------------------------------- /training/axolotl/src/axolotl.egg-info/requires.txt: -------------------------------------------------------------------------------- 1 | torch==2.0.1 2 | auto-gptq==0.4.2 3 | packaging 4 | peft==0.6.0 5 | transformers@ git+https://github.com/huggingface/transformers.git@acc394c4f5e1283c19783581790b3dc3105a3697 6 | bitsandbytes>=0.41.1 7 | accelerate@ git+https://github.com/huggingface/accelerate@80da9cfb09bb3cc9f1b385cb55d6b90d025a5fd9 8 | addict 9 | fire 10 | PyYAML>=6.0 11 | datasets 12 | sentencepiece 13 | wandb 14 | einops 15 | xformers>=0.0.22 16 | optimum==1.13.2 17 | hf_transfer 18 | colorama 19 | numba 20 | numpy>=1.24.4 21 | bert-score==0.3.13 22 | evaluate==0.4.0 23 | rouge-score==0.1.2 24 | scipy 25 | scikit-learn==1.2.2 26 | pynvml 27 | art 28 | fschat==0.2.29 29 | gradio 30 | 31 | [deepspeed] 32 | deepspeed 33 | 34 | [flash-attn] 35 | flash-attn>=2.3.0 36 | -------------------------------------------------------------------------------- 
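The `DictDefault` helper in `utils/dict.py` above is the container every parsed config (`cfg`) in the CLI and utils modules is built on. A minimal sketch of the behaviour its docstring describes — it relies only on `addict`, which is already listed in the requirements:

```python
from axolotl.utils.dict import DictDefault

cfg = DictDefault({"base_model": "mistralai/Mistral-7B-v0.1"})

# attribute access comes from addict's Dict
print(cfg.base_model)   # mistralai/Mistral-7B-v0.1

# missing keys return None instead of an empty Dict, so truthiness checks on
# unset options (cfg.lora_r, cfg.wandb_project, ...) never raise
print(cfg.lora_r)       # None

# the | union (Python 3.9+) keeps the DictDefault type
merged = cfg | DictDefault({"lora_r": 128})
print(merged.lora_r)    # 128
```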
/training/axolotl/src/axolotl/prompt_strategies/alpaca_instruct.py: -------------------------------------------------------------------------------- 1 | """Module loading the AlpacaInstructPromptTokenizingStrategy class""" 2 | 3 | from axolotl.prompt_tokenizers import AlpacaPromptTokenizingStrategy 4 | from axolotl.prompters import AlpacaPrompter, PromptStyle, UnpromptedPrompter 5 | 6 | 7 | def load(tokenizer, cfg): 8 | return AlpacaPromptTokenizingStrategy( 9 | AlpacaPrompter(PromptStyle.INSTRUCT.value), 10 | tokenizer, 11 | cfg.train_on_inputs, 12 | cfg.sequence_len, 13 | ) 14 | 15 | 16 | def load_no_prompt(tokenizer, cfg): 17 | return AlpacaPromptTokenizingStrategy( 18 | UnpromptedPrompter(PromptStyle.INSTRUCT.value), 19 | tokenizer, 20 | cfg.train_on_inputs, 21 | cfg.sequence_len, 22 | ) 23 | -------------------------------------------------------------------------------- /training/axolotl/docker/Dockerfile-runpod: -------------------------------------------------------------------------------- 1 | ARG BASE_TAG=main 2 | FROM winglian/axolotl:$BASE_TAG 3 | 4 | ENV HF_DATASETS_CACHE="/workspace/data/huggingface-cache/datasets" 5 | ENV HUGGINGFACE_HUB_CACHE="/workspace/data/huggingface-cache/hub" 6 | ENV TRANSFORMERS_CACHE="/workspace/data/huggingface-cache/hub" 7 | 8 | COPY scripts/runpod-entrypoint.sh /root/runpod-entrypoint.sh 9 | 10 | RUN apt install --yes --no-install-recommends openssh-server tmux && \ 11 | mkdir -p ~/.ssh && \ 12 | chmod 700 ~/.ssh && \ 13 | printf "\n[[ -z \"\$TMUX\" ]] && { tmux attach-session -t ssh_tmux || tmux new-session -s ssh_tmux; exit; }\n" >> ~/.bashrc && \ 14 | chmod +x /workspace/axolotl/scripts/runpod-entrypoint.sh && \ 15 | chmod +x /root/runpod-entrypoint.sh 16 | 17 | ENTRYPOINT ["/root/runpod-entrypoint.sh"] 18 | CMD ["sleep", "infinity"] 19 | -------------------------------------------------------------------------------- /training/axolotl/requirements.txt: -------------------------------------------------------------------------------- 1 | --extra-index-url https://download.pytorch.org/whl/cu118 2 | --extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/ 3 | torch==2.0.1 4 | auto-gptq==0.4.2 5 | packaging 6 | peft==0.6.0 7 | transformers @ git+https://github.com/huggingface/transformers.git@acc394c4f5e1283c19783581790b3dc3105a3697 8 | bitsandbytes>=0.41.1 9 | accelerate @ git+https://github.com/huggingface/accelerate@80da9cfb09bb3cc9f1b385cb55d6b90d025a5fd9 10 | deepspeed 11 | addict 12 | fire 13 | PyYAML>=6.0 14 | datasets 15 | flash-attn>=2.3.0 16 | sentencepiece 17 | wandb 18 | einops 19 | xformers>=0.0.22 20 | optimum==1.13.2 21 | hf_transfer 22 | colorama 23 | numba 24 | numpy>=1.24.4 25 | # qlora things 26 | bert-score==0.3.13 27 | evaluate==0.4.0 28 | rouge-score==0.1.2 29 | scipy 30 | scikit-learn==1.2.2 31 | pynvml 32 | art 33 | fschat==0.2.29 34 | gradio -------------------------------------------------------------------------------- /inference/submission_2/api.py: -------------------------------------------------------------------------------- 1 | from pydantic import BaseModel 2 | 3 | from typing import List, Dict, Optional 4 | 5 | 6 | class ProcessRequest(BaseModel): 7 | prompt: str 8 | num_samples: int = 1 9 | max_new_tokens: int = 50 10 | top_k: int = 200 11 | temperature: float = 1e-7 12 | seed: Optional[int] = None 13 | echo_prompt: Optional[bool] 14 | 15 | 16 | class Token(BaseModel): 17 | text: str 18 | logprob: float 19 | top_logprob: Dict[str, float] 20 | 21 | 22 | class ProcessResponse(BaseModel): 23 | 
text: str 24 | tokens: List[Token] 25 | logprob: float 26 | request_time: float 27 | 28 | 29 | class TokenizeRequest(BaseModel): 30 | text: str 31 | truncation: bool = True 32 | max_length: int = 2048 33 | 34 | 35 | class TokenizeResponse(BaseModel): 36 | tokens: List[int] 37 | request_time: float 38 | -------------------------------------------------------------------------------- /training/axolotl/src/axolotl/cli/merge_lora.py: -------------------------------------------------------------------------------- 1 | """ 2 | CLI to run merge a trained LoRA into a base model 3 | """ 4 | from pathlib import Path 5 | 6 | import fire 7 | import transformers 8 | 9 | from axolotl.cli import do_merge_lora, load_cfg, print_axolotl_text_art 10 | from axolotl.common.cli import TrainerCliArgs 11 | 12 | 13 | def do_cli(config: Path = Path("examples/"), **kwargs): 14 | # pylint: disable=duplicate-code 15 | print_axolotl_text_art() 16 | parser = transformers.HfArgumentParser((TrainerCliArgs)) 17 | parsed_cli_args, _ = parser.parse_args_into_dataclasses( 18 | return_remaining_strings=True 19 | ) 20 | parsed_cli_args.merge_lora = True 21 | parsed_cfg = load_cfg(config, merge_lora=True, **kwargs) 22 | 23 | do_merge_lora(cfg=parsed_cfg, cli_args=parsed_cli_args) 24 | 25 | 26 | if __name__ == "__main__": 27 | fire.Fire(do_cli) 28 | -------------------------------------------------------------------------------- /training/axolotl/src/axolotl/cli/inference.py: -------------------------------------------------------------------------------- 1 | """ 2 | CLI to run inference on a trained model 3 | """ 4 | from pathlib import Path 5 | 6 | import fire 7 | import transformers 8 | 9 | from axolotl.cli import do_inference, load_cfg, print_axolotl_text_art 10 | from axolotl.common.cli import TrainerCliArgs 11 | 12 | 13 | def do_cli(config: Path = Path("examples/"), **kwargs): 14 | # pylint: disable=duplicate-code 15 | print_axolotl_text_art() 16 | parsed_cfg = load_cfg(config, **kwargs) 17 | parsed_cfg.sample_packing = False 18 | parser = transformers.HfArgumentParser((TrainerCliArgs)) 19 | parsed_cli_args, _ = parser.parse_args_into_dataclasses( 20 | return_remaining_strings=True 21 | ) 22 | parsed_cli_args.inference = True 23 | 24 | do_inference(cfg=parsed_cfg, cli_args=parsed_cli_args) 25 | 26 | 27 | if __name__ == "__main__": 28 | fire.Fire(do_cli) 29 | -------------------------------------------------------------------------------- /training/axolotl/src/axolotl/utils/wandb_.py: -------------------------------------------------------------------------------- 1 | """Module for wandb utilities""" 2 | 3 | import os 4 | 5 | 6 | def setup_wandb_env_vars(cfg): 7 | if cfg.wandb_mode and cfg.wandb_mode == "offline": 8 | os.environ["WANDB_MODE"] = cfg.wandb_mode 9 | elif cfg.wandb_project and len(cfg.wandb_project) > 0: 10 | os.environ["WANDB_PROJECT"] = cfg.wandb_project 11 | cfg.use_wandb = True 12 | if cfg.wandb_entity and len(cfg.wandb_entity) > 0: 13 | os.environ["WANDB_ENTITY"] = cfg.wandb_entity 14 | if cfg.wandb_watch and len(cfg.wandb_watch) > 0: 15 | os.environ["WANDB_WATCH"] = cfg.wandb_watch 16 | if cfg.wandb_log_model and len(cfg.wandb_log_model) > 0: 17 | os.environ["WANDB_LOG_MODEL"] = cfg.wandb_log_model 18 | if cfg.wandb_run_id and len(cfg.wandb_run_id) > 0: 19 | os.environ["WANDB_RUN_ID"] = cfg.wandb_run_id 20 | else: 21 | os.environ["WANDB_DISABLED"] = "true" 22 | -------------------------------------------------------------------------------- /training/axolotl/deepspeed/zero1.json: 
-------------------------------------------------------------------------------- 1 | { 2 | "zero_optimization": { 3 | "stage": 1, 4 | "overlap_comm": true 5 | }, 6 | "bf16": { 7 | "enabled": "auto" 8 | }, 9 | "fp16": { 10 | "enabled": "auto", 11 | "auto_cast": false, 12 | "loss_scale": 0, 13 | "initial_scale_power": 32, 14 | "loss_scale_window": 1000, 15 | "hysteresis": 2, 16 | "min_loss_scale": 1 17 | }, 18 | "optimizer": { 19 | "type": "AdamW", 20 | "params": { 21 | "lr": "auto", 22 | "betas": "auto", 23 | "eps": "auto", 24 | "weight_decay": "auto" 25 | } 26 | }, 27 | "scheduler": { 28 | "type": "WarmupDecayLR", 29 | "params": { 30 | "warmup_min_lr": "auto", 31 | "warmup_max_lr": "auto", 32 | "warmup_num_steps": "auto", 33 | "warmup_type": "linear", 34 | "total_num_steps": "auto" 35 | } 36 | }, 37 | "gradient_accumulation_steps": "auto", 38 | "train_batch_size": "auto", 39 | "train_micro_batch_size_per_gpu": "auto", 40 | "wall_clock_breakdown": false 41 | } 42 | -------------------------------------------------------------------------------- /training/axolotl/src/axolotl/prompt_strategies/__init__.py: -------------------------------------------------------------------------------- 1 | """Module to load prompt strategies.""" 2 | 3 | import importlib 4 | import inspect 5 | 6 | from axolotl.prompt_strategies.user_defined import UserDefinedDatasetConfig 7 | 8 | 9 | def load(strategy, tokenizer, cfg, ds_cfg): 10 | try: 11 | load_fn = "load" 12 | if strategy.split(".")[-1].startswith("load_"): 13 | load_fn = strategy.split(".")[-1] 14 | strategy = ".".join(strategy.split(".")[:-1]) 15 | mod = importlib.import_module(f".{strategy}", "axolotl.prompt_strategies") 16 | func = getattr(mod, load_fn) 17 | load_kwargs = {} 18 | if strategy == "user_defined": 19 | load_kwargs["ds_cfg"] = UserDefinedDatasetConfig(**ds_cfg) 20 | else: 21 | sig = inspect.signature(func) 22 | if "ds_cfg" in sig.parameters: 23 | load_kwargs["ds_cfg"] = ds_cfg 24 | return func(tokenizer, cfg, **load_kwargs) 25 | except Exception: # pylint: disable=broad-exception-caught 26 | return None 27 | -------------------------------------------------------------------------------- /training/axolotl/docker/Dockerfile: -------------------------------------------------------------------------------- 1 | ARG BASE_TAG=main-base 2 | FROM winglian/axolotl-base:$BASE_TAG 3 | 4 | ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6+PTX" 5 | ARG AXOLOTL_EXTRAS="" 6 | ARG CUDA="118" 7 | ENV BNB_CUDA_VERSION=$CUDA 8 | ARG PYTORCH_VERSION="2.0.1" 9 | 10 | ENV PYTORCH_VERSION=$PYTORCH_VERSION 11 | 12 | RUN apt-get update && \ 13 | apt-get install -y vim curl 14 | 15 | WORKDIR /workspace 16 | 17 | RUN git clone --depth=1 https://github.com/OpenAccess-AI-Collective/axolotl.git 18 | 19 | WORKDIR /workspace/axolotl 20 | 21 | # If AXOLOTL_EXTRAS is set, append it in brackets 22 | RUN sed -i "s/torch==.*/torch==$PYTORCH_VERSION/" requirements.txt 23 | RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \ 24 | pip install -e .[flash-attn,$AXOLOTL_EXTRAS]; \ 25 | else \ 26 | pip install -e .[flash-attn]; \ 27 | fi 28 | 29 | # fix so that git fetch/pull from remote works 30 | RUN git config remote.origin.fetch "+refs/heads/*:refs/remotes/origin/*" && \ 31 | git config --get remote.origin.fetch 32 | 33 | # helper for huggingface-login cli 34 | RUN git config --global credential.helper store 35 | -------------------------------------------------------------------------------- /training/axolotl/src/axolotl/prompt_strategies/sharegpt_jokes.py: 
-------------------------------------------------------------------------------- 1 | """Module for Jokes prompts using sharegpt style """ 2 | from axolotl.prompt_tokenizers import ShareGPTPromptTokenizingStrategy 3 | from axolotl.prompters import ShareGPTPrompterV2 4 | 5 | 6 | def load(tokenizer, cfg): 7 | return SimpleJokesShareGPTPromptTokenizingStrategy( 8 | ShareGPTPrompterV2(), 9 | tokenizer, 10 | cfg.train_on_inputs, 11 | cfg.sequence_len, 12 | ) 13 | 14 | 15 | class SimpleJokesShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy): 16 | """ 17 | Tokenization strategy for asking bot to tell a joke and then explain why its funny 18 | """ 19 | 20 | # title, text, explanation 21 | def get_conversation_thread(self, prompt): 22 | title = "" if not prompt["title"] else prompt["title"] + " " 23 | return [ 24 | {"from": "human", "value": "Tell me a joke."}, 25 | {"from": "gpt", "value": title + prompt["text"]}, 26 | {"from": "human", "value": "Why is that joke funny?"}, 27 | {"from": "gpt", "value": prompt["explanation"]}, 28 | ] 29 | -------------------------------------------------------------------------------- /training/axolotl/deepspeed/zero2.json: -------------------------------------------------------------------------------- 1 | { 2 | "zero_optimization": { 3 | "stage": 2, 4 | "offload_optimizer": { 5 | "device": "cpu" 6 | }, 7 | "contiguous_gradients": true, 8 | "overlap_comm": true 9 | }, 10 | "bf16": { 11 | "enabled": "auto" 12 | }, 13 | "fp16": { 14 | "enabled": "auto", 15 | "auto_cast": false, 16 | "loss_scale": 0, 17 | "initial_scale_power": 32, 18 | "loss_scale_window": 1000, 19 | "hysteresis": 2, 20 | "min_loss_scale": 1 21 | }, 22 | "optimizer": { 23 | "type": "AdamW", 24 | "params": { 25 | "lr": "auto", 26 | "betas": "auto", 27 | "eps": "auto", 28 | "weight_decay": "auto" 29 | } 30 | }, 31 | "scheduler": { 32 | "type": "WarmupDecayLR", 33 | "params": { 34 | "warmup_min_lr": "auto", 35 | "warmup_max_lr": "auto", 36 | "warmup_num_steps": "auto", 37 | "warmup_type": "linear", 38 | "total_num_steps": "auto" 39 | } 40 | }, 41 | "gradient_accumulation_steps": "auto", 42 | "train_batch_size": "auto", 43 | "train_micro_batch_size_per_gpu": "auto", 44 | "wall_clock_breakdown": false 45 | } 46 | -------------------------------------------------------------------------------- /training/axolotl/docs/multipack.md: -------------------------------------------------------------------------------- 1 | # Multipack 2 | 3 | 4k context, bsz =4, 4 | each character represents 256 tokens 5 | X represents a padding token 6 | 7 | ``` 8 | 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 9 | [[ A A A A A A A A A A A ] 10 | B B B B B B ] 11 | C C C C C C C ] 12 | D D D D ]] 13 | 14 | [[ E E E E E E E E ] 15 | [ F F F F ] 16 | [ G G G ] 17 | [ H H H H ]] 18 | 19 | [[ I I I ] 20 | [ J J J ] 21 | [ K K K K K] 22 | [ L L L ]] 23 | ``` 24 | 25 | after padding to longest input in each step 26 | ``` 27 | 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 28 | [[ A A A A A A A A A A A ] 29 | B B B B B B X X X X X X ] 30 | C C C C C C C X X X X ] 31 | D D D D X X X X X X X ]] 32 | 33 | [[ E E E E E E E E ] 34 | [ F F F F X X X X ] 35 | [ G G G X X X X X ] 36 | [ H H H H X X X X ]] 37 | 38 | [[ I I I X X ] 39 | [ J J J X X ] 40 | [ K K K K K ] 41 | [ L L L X X ]] 42 | ``` 43 | 44 | w packing ( note it's the same effective number of tokens per step, but a true bsz of 1) 45 | ``` 46 | 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 47 | [[ A A A A A A A A A A A B B B B B 48 | B C C C C C C C D D D D E E E E 49 | E E E E F F F F F G G G H H H H 50 | I I 
I J J J J K K K K K L L L X ]] 51 | ``` 52 | -------------------------------------------------------------------------------- /training/axolotl/src/axolotl/cli/shard.py: -------------------------------------------------------------------------------- 1 | """ 2 | CLI to shard a trained model into 10GiB chunks 3 | """ 4 | import logging 5 | from pathlib import Path 6 | 7 | import fire 8 | import transformers 9 | 10 | from axolotl.cli import load_cfg, print_axolotl_text_art 11 | from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer 12 | from axolotl.utils.dict import DictDefault 13 | 14 | LOG = logging.getLogger("axolotl.scripts") 15 | 16 | 17 | def shard( 18 | *, 19 | cfg: DictDefault, 20 | cli_args: TrainerCliArgs, 21 | ): 22 | model, _ = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args) 23 | safe_serialization = cfg.save_safetensors is True 24 | LOG.debug("Re-saving model w/ sharding") 25 | model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization) 26 | 27 | 28 | def do_cli(config: Path = Path("examples/"), **kwargs): 29 | # pylint: disable=duplicate-code 30 | print_axolotl_text_art() 31 | parsed_cfg = load_cfg(config, **kwargs) 32 | parser = transformers.HfArgumentParser((TrainerCliArgs)) 33 | parsed_cli_args, _ = parser.parse_args_into_dataclasses( 34 | return_remaining_strings=True 35 | ) 36 | parsed_cli_args.shard = True 37 | 38 | shard(cfg=parsed_cfg, cli_args=parsed_cli_args) 39 | 40 | 41 | if __name__ == "__main__": 42 | fire.Fire(do_cli) 43 | -------------------------------------------------------------------------------- /training/axolotl/src/axolotl/common/cli.py: -------------------------------------------------------------------------------- 1 | """ 2 | shared module for cli specific things 3 | """ 4 | 5 | import logging 6 | from dataclasses import dataclass, field 7 | from typing import Optional 8 | 9 | from axolotl.logging_config import configure_logging 10 | from axolotl.utils.dict import DictDefault 11 | from axolotl.utils.models import load_model, load_tokenizer 12 | 13 | configure_logging() 14 | LOG = logging.getLogger("axolotl.common.cli") 15 | 16 | 17 | @dataclass 18 | class TrainerCliArgs: 19 | """ 20 | dataclass representing the various non-training arguments 21 | """ 22 | 23 | debug: bool = field(default=False) 24 | debug_text_only: bool = field(default=False) 25 | debug_num_examples: int = field(default=1) 26 | inference: bool = field(default=False) 27 | merge_lora: bool = field(default=False) 28 | prepare_ds_only: bool = field(default=False) 29 | prompter: Optional[str] = field(default=None) 30 | shard: bool = field(default=False) 31 | 32 | 33 | def load_model_and_tokenizer( 34 | *, 35 | cfg: DictDefault, 36 | cli_args: TrainerCliArgs, 37 | ): 38 | LOG.info(f"loading tokenizer... 
{cfg.tokenizer_config or cfg.base_model_config}") 39 | tokenizer = load_tokenizer(cfg) 40 | LOG.info("loading model and (optionally) peft_config...") 41 | model, _ = load_model(cfg, tokenizer, inference=cli_args.inference) 42 | 43 | return model, tokenizer 44 | -------------------------------------------------------------------------------- /training/axolotl/deepspeed/zero3.json: -------------------------------------------------------------------------------- 1 | { 2 | "zero_optimization": { 3 | "stage": 3, 4 | "offload_optimizer": { 5 | "device": "cpu", 6 | "pin_memory": true 7 | }, 8 | "offload_param": { 9 | "device": "cpu", 10 | "pin_memory": true 11 | }, 12 | "overlap_comm": true, 13 | "contiguous_gradients": true, 14 | "sub_group_size": 0, 15 | "reduce_bucket_size": "auto", 16 | "stage3_prefetch_bucket_size": "auto", 17 | "stage3_param_persistence_threshold": "auto", 18 | "stage3_max_live_parameters": 0, 19 | "stage3_max_reuse_distance": 0, 20 | "stage3_gather_16bit_weights_on_model_save": true 21 | }, 22 | "bf16": { 23 | "enabled": "auto" 24 | }, 25 | "fp16": { 26 | "enabled": "auto", 27 | "auto_cast": false, 28 | "loss_scale": 0, 29 | "initial_scale_power": 32, 30 | "loss_scale_window": 1000, 31 | "hysteresis": 2, 32 | "min_loss_scale": 1 33 | }, 34 | "optimizer": { 35 | "type": "AdamW", 36 | "params": { 37 | "lr": "auto", 38 | "betas": "auto", 39 | "eps": "auto", 40 | "weight_decay": "auto" 41 | } 42 | }, 43 | "scheduler": { 44 | "type": "WarmupLR", 45 | "params": { 46 | "warmup_min_lr": "auto", 47 | "warmup_max_lr": "auto", 48 | "warmup_num_steps": "auto", 49 | "warmup_type": "linear" 50 | } 51 | }, 52 | "gradient_accumulation_steps": "auto", 53 | "train_batch_size": "auto", 54 | "train_micro_batch_size_per_gpu": "auto", 55 | "wall_clock_breakdown": false 56 | } 57 | -------------------------------------------------------------------------------- /training/axolotl/src/axolotl/monkeypatch/llama_embeddings_hijack.py: -------------------------------------------------------------------------------- 1 | """ 2 | patch to add noisy embeddings per https://arxiv.org/abs/2310.05914 3 | """ 4 | 5 | import torch 6 | import transformers.models.llama.modeling_llama 7 | from transformers.utils import logging 8 | 9 | logger = logging.get_logger(__name__) 10 | 11 | 12 | def replace_llama_embeddings_with_uniform_distribution(noise_alpha=5): 13 | # pylint: disable=duplicate-code 14 | def noised_embed(orig_embed, noise_alpha, model): 15 | def new_func(input_ids): 16 | # during training, we add noise to the embedding 17 | # during generation, we don't add noise to the embedding 18 | if model.training: 19 | embed_init = orig_embed(input_ids) 20 | dims = torch.tensor(embed_init.size(1) * embed_init.size(2)) 21 | mag_norm = noise_alpha / torch.sqrt(dims) 22 | return embed_init + torch.zeros_like(embed_init).uniform_( 23 | -mag_norm, mag_norm 24 | ) 25 | return orig_embed(input_ids) 26 | 27 | return new_func 28 | 29 | def post_init(orig_post_init): 30 | def new_func(self): 31 | orig_post_init(self) 32 | self.embed_tokens.forward = noised_embed( 33 | self.embed_tokens.forward, noise_alpha, self 34 | ) 35 | 36 | return new_func 37 | 38 | transformers.models.llama.modeling_llama.LlamaModel.post_init = post_init( 39 | transformers.models.llama.modeling_llama.LlamaModel.post_init 40 | ) 41 | -------------------------------------------------------------------------------- /training/axolotl/src/axolotl/monkeypatch/mistral_embeddings_hijack.py: 
-------------------------------------------------------------------------------- 1 | """ 2 | patch to add noisy embeddings per https://arxiv.org/abs/2310.05914 3 | """ 4 | 5 | import torch 6 | import transformers.models.mistral.modeling_mistral 7 | from transformers.utils import logging 8 | 9 | logger = logging.get_logger(__name__) 10 | 11 | 12 | def replace_mistral_embeddings_with_uniform_distribution(noise_alpha=5): 13 | # pylint: disable=duplicate-code 14 | def noised_embed(orig_embed, noise_alpha, model): 15 | def new_func(input_ids): 16 | # during training, we add noise to the embedding 17 | # during generation, we don't add noise to the embedding 18 | if model.training: 19 | embed_init = orig_embed(input_ids) 20 | dims = torch.tensor(embed_init.size(1) * embed_init.size(2)) 21 | mag_norm = noise_alpha / torch.sqrt(dims) 22 | return embed_init + torch.zeros_like(embed_init).uniform_( 23 | -mag_norm, mag_norm 24 | ) 25 | return orig_embed(input_ids) 26 | 27 | return new_func 28 | 29 | def post_init(orig_post_init): 30 | def new_func(self): 31 | orig_post_init(self) 32 | self.embed_tokens.forward = noised_embed( 33 | self.embed_tokens.forward, noise_alpha, self 34 | ) 35 | 36 | return new_func 37 | 38 | transformers.models.mistral.modeling_mistral.MistralModel.post_init = post_init( 39 | transformers.models.mistral.modeling_mistral.MistralModel.post_init 40 | ) 41 | -------------------------------------------------------------------------------- /training/axolotl/src/axolotl/utils/tokenization.py: -------------------------------------------------------------------------------- 1 | """Module for tokenization utilities""" 2 | 3 | 4 | import logging 5 | 6 | from termcolor import colored 7 | 8 | LOG = logging.getLogger("axolotl") 9 | 10 | 11 | def check_dataset_labels(dataset, tokenizer, num_examples=5, text_only=False): 12 | # the dataset is already shuffled, so let's just check the first 5 elements 13 | for idx in range(num_examples): 14 | check_example_labels(dataset[idx], tokenizer, text_only=text_only) 15 | 16 | 17 | def check_example_labels(example, tokenizer, text_only=False): 18 | # Get the input_ids, labels, and attention_mask from the dataset 19 | input_ids = example["input_ids"] 20 | labels = example["labels"] 21 | 22 | # You can compare the input_ids and labels element-wise 23 | # Remember to ignore positions with IGNORE_TOKEN_ID (if you use it) or attention_mask equal to 0 24 | colored_tokens = [] 25 | for _, (input_id, label_id) in enumerate(zip(input_ids, labels)): 26 | decoded_input_token = tokenizer.decode(input_id) 27 | # Choose the color based on whether the label has the ignore value or not 28 | color = "red" if label_id == -100 else ("yellow" if label_id == 0 else "green") 29 | colored_token = colored(decoded_input_token, color) + ( 30 | not text_only and colored(f"({label_id}, {input_id})", "white") or "" 31 | ) 32 | colored_tokens.append(colored_token) 33 | 34 | delimiter = "" if text_only else " " 35 | LOG.info(delimiter.join(colored_tokens)) 36 | LOG.info("\n\n\n") 37 | 38 | return " ".join(colored_tokens) 39 | -------------------------------------------------------------------------------- /training/axolotl/src/axolotl/cli/train.py: -------------------------------------------------------------------------------- 1 | """ 2 | CLI to run training on a model 3 | """ 4 | import logging 5 | from pathlib import Path 6 | 7 | import fire 8 | import transformers 9 | from colorama import Fore 10 | 11 | from axolotl.cli import ( 12 | check_accelerate_default_config, 13 | 
check_user_token, 14 | load_cfg, 15 | load_datasets, 16 | print_axolotl_text_art, 17 | ) 18 | from axolotl.common.cli import TrainerCliArgs 19 | from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH 20 | from axolotl.train import train 21 | 22 | LOG = logging.getLogger("axolotl.cli.train") 23 | 24 | 25 | def do_cli(config: Path = Path("examples/"), **kwargs): 26 | # pylint: disable=duplicate-code 27 | print_axolotl_text_art() 28 | parsed_cfg = load_cfg(config, **kwargs) 29 | check_accelerate_default_config() 30 | check_user_token() 31 | parser = transformers.HfArgumentParser((TrainerCliArgs)) 32 | parsed_cli_args, _ = parser.parse_args_into_dataclasses( 33 | return_remaining_strings=True 34 | ) 35 | if parsed_cli_args.prepare_ds_only and not parsed_cfg.dataset_prepared_path: 36 | msg = ( 37 | Fore.RED 38 | + "--prepare_ds_only called without dataset_prepared_path set." 39 | + Fore.RESET 40 | ) 41 | LOG.warning(msg) 42 | parsed_cfg.dataset_prepared_path = DEFAULT_DATASET_PREPARED_PATH 43 | 44 | dataset_meta = load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args) 45 | if parsed_cli_args.prepare_ds_only: 46 | return 47 | train(cfg=parsed_cfg, cli_args=parsed_cli_args, dataset_meta=dataset_meta) 48 | 49 | 50 | if __name__ == "__main__": 51 | fire.Fire(do_cli) 52 | -------------------------------------------------------------------------------- /training/axolotl/src/axolotl/prompt_strategies/orcamini.py: -------------------------------------------------------------------------------- 1 | """ 2 | Prompt Strategy for finetuning Orca Mini (v2) models 3 | see also https://huggingface.co/psmathur/orca_mini_v2_7b for more information 4 | 5 | Use dataset type: orcamini in conig.yml to use this prompt style. 6 | 7 | Compared to the alpaca_w_system.open_orca dataset type, 8 | this one specifies the system prompt with "### System:". 9 | 10 | Not suited/tested for multiple-turn conversations without further adjustments. 11 | """ 12 | from typing import Generator, Union 13 | 14 | from axolotl.prompt_strategies.alpaca_w_system import OpenOrcaPromptTokenizingStrategy 15 | from axolotl.prompters import AlpacaPrompter 16 | 17 | 18 | class OrcaMiniPrompter(AlpacaPrompter): 19 | """Adjusted Prompter for Orca Mini (v2) datasets""" 20 | 21 | def match_prompt_style(self): 22 | self.turn_no_input_format = ( 23 | "### System:\n{system}\n\n### User:\n{instruction}\n\n### Response:\n" 24 | ) 25 | 26 | def build_prompt_w_system( 27 | self, 28 | system: str, 29 | instruction: str, 30 | output: Union[None, str] = None, 31 | ) -> Generator[str, None, None]: 32 | # returns the full prompt from instruction and optional input 33 | # if a label (=response, =output) is provided, it's also appended. 
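        # e.g. system="You are an AI assistant." and instruction="Name two prime numbers."
        # render (before any output is appended) as:
        #
        #   ### System:
        #   You are an AI assistant.
        #
        #   ### User:
        #   Name two prime numbers.
        #
        #   ### Response: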
34 | res = self.turn_no_input_format.format(system=system, instruction=instruction) 35 | if output: 36 | res = f"{res}{output}" 37 | yield res 38 | 39 | 40 | def load(tokenizer, cfg): 41 | return OpenOrcaPromptTokenizingStrategy( 42 | OrcaMiniPrompter(), 43 | tokenizer, 44 | cfg.train_on_inputs, 45 | cfg.sequence_len, 46 | ) 47 | -------------------------------------------------------------------------------- /training/axolotl/docs/multi-node.md: -------------------------------------------------------------------------------- 1 | # Multi Node 2 | 3 | You will need to create a configuration for accelerate, either by using `accelerate config` and follow the instructions or you can use one of the preset below: 4 | 5 | ~/.cache/huggingface/accelerate/default_config.yaml 6 | ```yaml 7 | compute_environment: LOCAL_MACHINE 8 | debug: false 9 | distributed_type: FSDP 10 | downcast_bf16: 'no' 11 | machine_rank: 0 # Set to 0 for the main machine, increment by one for other machines 12 | main_process_ip: 10.0.0.4 # Set to main machine's IP 13 | main_process_port: 5000 14 | main_training_function: main 15 | mixed_precision: bf16 16 | num_machines: 2 # Change to the number of machines 17 | num_processes: 4 # That's the total number of GPUs, (for example: if you have 2 machines with 4 GPU, put 8) 18 | rdzv_backend: static 19 | same_network: true 20 | tpu_env: [] 21 | tpu_use_cluster: false 22 | tpu_use_sudo: false 23 | use_cpu: false 24 | ``` 25 | 26 | Configure your model to use FSDP with for example: 27 | ```yaml 28 | fsdp: 29 | - full_shard 30 | - auto_wrap 31 | fsdp_config: 32 | fsdp_offload_params: true 33 | fsdp_state_dict_type: FULL_STATE_DICT 34 | fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer 35 | ``` 36 | 37 | ## Machine configuration 38 | 39 | On each machine you need a copy of Axolotl, we suggest using the same commit to ensure compatibility. 40 | 41 | You will also need to have the same configuration file for your model on each machine. 42 | 43 | On the main machine only, make sure the port you set as `main_process_port` is open in TCP and reachable by other machines. 44 | 45 | All you have to do now is launch using accelerate as you would usually do on each machine and voila, the processes will start once you have launched accelerate on every machine. 46 | -------------------------------------------------------------------------------- /training/axolotl/scripts/finetune.py: -------------------------------------------------------------------------------- 1 | """Prepare and train a model on a dataset. 
Can also infer from a model or merge lora""" 2 | import logging 3 | from pathlib import Path 4 | 5 | import fire 6 | import transformers 7 | 8 | from axolotl.cli import ( 9 | check_accelerate_default_config, 10 | check_user_token, 11 | do_inference, 12 | do_merge_lora, 13 | load_cfg, 14 | load_datasets, 15 | print_axolotl_text_art, 16 | ) 17 | from axolotl.cli.shard import shard 18 | from axolotl.common.cli import TrainerCliArgs 19 | from axolotl.train import train 20 | 21 | LOG = logging.getLogger("axolotl.scripts.finetune") 22 | 23 | 24 | def do_cli(config: Path = Path("examples/"), **kwargs): 25 | print_axolotl_text_art() 26 | LOG.warning( 27 | str( 28 | PendingDeprecationWarning( 29 | "scripts/finetune.py will be replaced with calling axolotl.cli.train" 30 | ) 31 | ) 32 | ) 33 | parsed_cfg = load_cfg(config, **kwargs) 34 | check_accelerate_default_config() 35 | check_user_token() 36 | parser = transformers.HfArgumentParser((TrainerCliArgs)) 37 | parsed_cli_args, _ = parser.parse_args_into_dataclasses( 38 | return_remaining_strings=True 39 | ) 40 | if parsed_cli_args.inference: 41 | do_inference(cfg=parsed_cfg, cli_args=parsed_cli_args) 42 | elif parsed_cli_args.merge_lora: 43 | do_merge_lora(cfg=parsed_cfg, cli_args=parsed_cli_args) 44 | elif parsed_cli_args.shard: 45 | shard(cfg=parsed_cfg, cli_args=parsed_cli_args) 46 | else: 47 | dataset_meta = load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args) 48 | if parsed_cli_args.prepare_ds_only: 49 | return 50 | train(cfg=parsed_cfg, cli_args=parsed_cli_args, dataset_meta=dataset_meta) 51 | 52 | 53 | if __name__ == "__main__": 54 | fire.Fire(do_cli) 55 | -------------------------------------------------------------------------------- /training/axolotl/setup.py: -------------------------------------------------------------------------------- 1 | """setup.py for axolotl""" 2 | 3 | from setuptools import find_packages, setup 4 | 5 | 6 | def parse_requirements(): 7 | _install_requires = [] 8 | _dependency_links = [] 9 | with open("./requirements.txt", encoding="utf-8") as requirements_file: 10 | lines = [r.strip() for r in requirements_file.readlines()] 11 | for line in lines: 12 | if line.startswith("--extra-index-url"): 13 | # Handle custom index URLs 14 | _, url = line.split() 15 | _dependency_links.append(url) 16 | elif ( 17 | "flash-attn" not in line 18 | and "deepspeed" not in line 19 | and line 20 | and line[0] != "#" 21 | ): 22 | # Handle standard packages 23 | _install_requires.append(line) 24 | 25 | # TODO(wing) remove once xformers release supports torch 2.1.0 26 | if "torch==2.1.0" in _install_requires: 27 | _install_requires.pop(_install_requires.index("xformers>=0.0.22")) 28 | _install_requires.append( 29 | "xformers @ git+https://github.com/facebookresearch/xformers.git@main" 30 | ) 31 | 32 | return _install_requires, _dependency_links 33 | 34 | 35 | install_requires, dependency_links = parse_requirements() 36 | 37 | 38 | setup( 39 | name="axolotl", 40 | version="0.3.0", 41 | description="LLM Trainer", 42 | long_description="Axolotl is a tool designed to streamline the fine-tuning of various AI models, offering support for multiple configurations and architectures.", 43 | package_dir={"": "src"}, 44 | packages=find_packages(), 45 | install_requires=install_requires, 46 | dependency_links=dependency_links, 47 | extras_require={ 48 | "flash-attn": [ 49 | "flash-attn>=2.3.0", 50 | ], 51 | "deepspeed": [ 52 | "deepspeed", 53 | ], 54 | }, 55 | ) 56 | -------------------------------------------------------------------------------- 
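Both entry points above — `scripts/finetune.py` and the `axolotl.cli.*` modules it is being replaced by — take a single YAML config plus optional flags. A sketch of typical invocations with the NIPS config shipped in this repo, following axolotl's usual CLI conventions (exact flags, e.g. `--lora_model_dir` for the merge, depend on the installed version):

```bash
# prepare the dataset and train (multi-GPU via accelerate)
accelerate launch scripts/finetune.py examples/mistral/nips/nips_02.yml

# module-style equivalent that the deprecation warning points to
accelerate launch -m axolotl.cli.train examples/mistral/nips/nips_02.yml

# merge the trained QLoRA adapter back into the base weights
python -m axolotl.cli.merge_lora examples/mistral/nips/nips_02.yml

# quick interactive sanity check of the result
accelerate launch -m axolotl.cli.inference examples/mistral/nips/nips_02.yml
```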
/training/axolotl/src/axolotl/monkeypatch/llama_expand_mask.py: -------------------------------------------------------------------------------- 1 | """ 2 | expands the binary attention mask per 3.2.2 of https://arxiv.org/pdf/2107.02027.pdf 3 | """ 4 | from typing import Optional 5 | 6 | import torch 7 | 8 | 9 | def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): 10 | """ 11 | Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. 12 | This expansion handles packed sequences so that sequences share the same attention mask integer value 13 | when they attend to each other within that sequence. 14 | This expansion transforms the mask to lower triangular form to prevent future peeking. 15 | """ 16 | bsz, src_len = mask.size() 17 | tgt_len = tgt_len if tgt_len is not None else src_len 18 | 19 | mask = mask.unsqueeze(1).unsqueeze(2) 20 | mask = mask.expand(bsz, 1, tgt_len, src_len) 21 | 22 | # Create a binary mask from the original mask where zeros remain zeros and all other values are set to one 23 | binary_mask = torch.where( 24 | mask != 0, 25 | torch.tensor(1).to(dtype), 26 | torch.tensor(0).to(dtype), 27 | ) 28 | 29 | # Create a block-diagonal mask. 30 | # we multiply by the binary mask so that 0's in the original mask are correctly excluded 31 | zero_one_mask = torch.eq(mask, mask.transpose(-1, -2)).int() * binary_mask 32 | 33 | # Now let's create a lower triangular mask of ones that will zero out the upper triangular part 34 | lower_triangular_ones = torch.tril(torch.ones((tgt_len, src_len), dtype=dtype)).to( 35 | mask.device 36 | ) 37 | 38 | # Use the lower triangular mask to zero out the upper triangular part of the zero_one_mask 39 | masked_zero_one_mask = zero_one_mask * lower_triangular_ones 40 | inverted_mask = 1.0 - masked_zero_one_mask 41 | 42 | return inverted_mask.masked_fill( 43 | inverted_mask.to(torch.bool), torch.finfo(dtype).min 44 | ) 45 | 46 | 47 | def hijack_expand_mask(): 48 | import transformers 49 | 50 | transformers.models.llama.modeling_llama._expand_mask = ( # pylint: disable=protected-access 51 | _expand_mask 52 | ) 53 | -------------------------------------------------------------------------------- /training/axolotl/src/axolotl/logging_config.py: -------------------------------------------------------------------------------- 1 | """ 2 | Common logging module for axolotl 3 | """ 4 | 5 | import os 6 | import sys 7 | from logging import Formatter 8 | from logging.config import dictConfig 9 | from typing import Any, Dict 10 | 11 | from colorama import Fore, Style, init 12 | 13 | 14 | class ColorfulFormatter(Formatter): 15 | """ 16 | Formatter to add coloring to log messages by log type 17 | """ 18 | 19 | COLORS = { 20 | "WARNING": Fore.YELLOW, 21 | "ERROR": Fore.RED, 22 | "CRITICAL": Fore.RED + Style.BRIGHT, 23 | } 24 | 25 | def format(self, record): 26 | record.rank = int(os.getenv("LOCAL_RANK", "0")) 27 | log_message = super().format(record) 28 | return self.COLORS.get(record.levelname, "") + log_message + Fore.RESET 29 | 30 | 31 | DEFAULT_LOGGING_CONFIG: Dict[str, Any] = { 32 | "version": 1, 33 | "formatters": { 34 | "simple": { 35 | "format": "[%(asctime)s] [%(levelname)s] [%(name)s.%(funcName)s:%(lineno)d] [PID:%(process)d] %(message)s", 36 | }, 37 | "colorful": { 38 | "()": ColorfulFormatter, 39 | "format": "[%(asctime)s] [%(levelname)s] [%(name)s.%(funcName)s:%(lineno)d] [PID:%(process)d] [RANK:%(rank)d] %(message)s", 40 | }, 41 | }, 42 | "filters": {}, 43 | "handlers": { 44 | "console": { 
45 | "class": "logging.StreamHandler", 46 | "formatter": "simple", 47 | "filters": [], 48 | "stream": sys.stdout, 49 | }, 50 | "color_console": { 51 | "class": "logging.StreamHandler", 52 | "formatter": "colorful", 53 | "filters": [], 54 | "stream": sys.stdout, 55 | }, 56 | }, 57 | "root": {"handlers": ["console"], "level": os.getenv("LOG_LEVEL", "INFO")}, 58 | "loggers": { 59 | "axolotl": { 60 | "handlers": ["color_console"], 61 | "level": "DEBUG", 62 | "propagate": False, 63 | }, 64 | }, 65 | } 66 | 67 | 68 | def configure_logging(): 69 | """Configure with default logging""" 70 | init() # Initialize colorama 71 | dictConfig(DEFAULT_LOGGING_CONFIG) 72 | -------------------------------------------------------------------------------- /training/axolotl/src/axolotl/convert.py: -------------------------------------------------------------------------------- 1 | """Module containing File Reader, File Writer, Json Parser, and Jsonl Serializer classes""" 2 | 3 | 4 | import json 5 | import sys 6 | 7 | 8 | class FileReader: 9 | """ 10 | Reads a file and returns its contents as a string 11 | """ 12 | 13 | def read(self, file_path): 14 | with open(file_path, encoding="utf-8") as file: 15 | return file.read() 16 | 17 | 18 | class FileWriter: 19 | """ 20 | Writes a string to a file 21 | """ 22 | 23 | def __init__(self, file_path): 24 | self.file_path = file_path 25 | 26 | def write(self, content): 27 | with open(self.file_path, "w", encoding="utf-8") as file: 28 | file.write(content) 29 | 30 | 31 | class StdoutWriter: 32 | """ 33 | Writes a string to stdout 34 | """ 35 | 36 | def write(self, content): 37 | sys.stdout.write(content) 38 | sys.stdout.write("\n") 39 | 40 | 41 | class JsonParser: 42 | """ 43 | Parses a string as JSON and returns the result 44 | """ 45 | 46 | def parse(self, content): 47 | return json.loads(content) 48 | 49 | 50 | class JsonlSerializer: 51 | """ 52 | Serializes a list of JSON objects into a JSONL string 53 | """ 54 | 55 | def serialize(self, data): 56 | lines = [json.dumps(item) for item in data] 57 | return "\n".join(lines) 58 | 59 | 60 | class JsonToJsonlConverter: 61 | """ 62 | Converts a JSON file to JSONL 63 | """ 64 | 65 | def __init__(self, file_reader, file_writer, json_parser, jsonl_serializer): 66 | self.file_reader = file_reader 67 | self.file_writer = file_writer 68 | self.json_parser = json_parser 69 | self.jsonl_serializer = jsonl_serializer 70 | 71 | def convert( 72 | self, input_file_path, output_file_path 73 | ): # pylint: disable=unused-argument 74 | content = self.file_reader.read(input_file_path) 75 | data = self.json_parser.parse(content) 76 | # data = [r for r in data if r["conversations"]] # vicuna cleaned has rows with empty conversations 77 | jsonl_content = self.jsonl_serializer.serialize(data) 78 | self.file_writer.write(jsonl_content) 79 | -------------------------------------------------------------------------------- /training/axolotl/src/axolotl/utils/bench.py: -------------------------------------------------------------------------------- 1 | """Benchmarking and measurement utilities""" 2 | import functools 3 | 4 | import pynvml 5 | import torch 6 | from pynvml.nvml import NVMLError 7 | 8 | 9 | def check_cuda_device(default_value): 10 | """ 11 | wraps a function and returns the default value instead of running the 12 | wrapped function if cuda isn't available or the device is auto 13 | :param default_value: 14 | :return: 15 | """ 16 | 17 | def deco(func): 18 | @functools.wraps(func) 19 | def wrapper(*args, **kwargs): 20 | device = 
kwargs.get("device", args[0] if args else None) 21 | 22 | if ( 23 | not torch.cuda.is_available() 24 | or device == "auto" 25 | or torch.device(device).type == "cpu" 26 | ): 27 | return default_value 28 | 29 | return func(*args, **kwargs) 30 | 31 | return wrapper 32 | 33 | return deco 34 | 35 | 36 | @check_cuda_device(0.0) 37 | def gpu_memory_usage(device=0): 38 | return torch.cuda.memory_allocated(device) / 1024.0**3 39 | 40 | 41 | @check_cuda_device((0.0, 0.0, 0.0)) 42 | def gpu_memory_usage_all(device=0): 43 | usage = torch.cuda.memory_allocated(device) / 1024.0**3 44 | reserved = torch.cuda.memory_reserved(device) / 1024.0**3 45 | smi = gpu_memory_usage_smi(device) 46 | return usage, reserved - usage, max(0, smi - reserved) 47 | 48 | 49 | @check_cuda_device(0.0) 50 | def gpu_memory_usage_smi(device=0): 51 | if isinstance(device, torch.device): 52 | device = device.index 53 | if isinstance(device, str) and device.startswith("cuda:"): 54 | device = int(device[5:]) 55 | try: 56 | pynvml.nvmlInit() 57 | handle = pynvml.nvmlDeviceGetHandleByIndex(device) 58 | info = pynvml.nvmlDeviceGetMemoryInfo(handle) 59 | return info.used / 1024.0**3 60 | except NVMLError: 61 | return 0.0 62 | 63 | 64 | def log_gpu_memory_usage(log, msg, device): 65 | usage, cache, misc = gpu_memory_usage_all(device) 66 | extras = [] 67 | if cache > 0: 68 | extras.append(f"+{cache:.03f}GB cache") 69 | if misc > 0: 70 | extras.append(f"+{misc:.03f}GB misc") 71 | log.info( 72 | f"GPU memory usage {msg}: {usage:.03f}GB ({', '.join(extras)})", stacklevel=2 73 | ) 74 | return usage, cache, misc 75 | -------------------------------------------------------------------------------- /training/axolotl/src/axolotl/models/phi/configuration_mixformer_sequential.py: -------------------------------------------------------------------------------- 1 | # pylint: skip-file 2 | 3 | # Copyright (c) Microsoft Corporation. 4 | # Licensed under the MIT license. 
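# NOTE: vendored from Microsoft's phi ("mixformer-sequential") model code.
# The attribute_map below maps the standard Hugging Face config names
# (hidden_size, num_attention_heads, num_hidden_layers, ...) onto the
# GPT-style n_embd / n_head / n_layer fields this configuration stores.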
5 | 6 | import math 7 | from typing import Any, Dict, List, Optional, Union 8 | 9 | from transformers import PretrainedConfig 10 | 11 | 12 | class MixFormerSequentialConfig(PretrainedConfig): 13 | """MixFormer (sequential for DeepSpeed) configuration.""" 14 | 15 | model_type = "mixformer-sequential" 16 | 17 | attribute_map = { 18 | "max_position_embeddings": "n_positions", 19 | "hidden_size": "n_embd", 20 | "num_attention_heads": "n_head", 21 | "num_hidden_layers": "n_layer", 22 | "input_emb_layer": "embd_layer", # `input_emb_layer` key is for backward compatibility 23 | "blocks": "architecture", # `blocks` key is for backward compatibility 24 | } 25 | 26 | def __init__( 27 | self, 28 | vocab_size: Optional[int] = 50304, 29 | n_positions: Optional[int] = 2048, 30 | n_embd: Optional[int] = 1024, 31 | n_layer: Optional[int] = 20, 32 | n_inner: Optional[int] = None, 33 | n_head: Optional[int] = 16, 34 | rotary_dim: Optional[int] = 32, 35 | activation_function: Optional[str] = "gelu_new", 36 | embd_layer: Optional[str] = "default", 37 | architecture: Union[Dict[str, Any], List[Dict[str, Any]]] = None, 38 | embd_pdrop: Optional[float] = 0.0, 39 | resid_pdrop: Optional[float] = 0.0, 40 | layer_norm_epsilon: Optional[float] = 1e-5, 41 | initializer_range: Optional[float] = 0.02, 42 | tie_word_embeddings: Optional[bool] = False, 43 | pad_vocab_size_multiple: Optional[int] = 64, 44 | **kwargs 45 | ) -> None: 46 | self.vocab_size = int( 47 | math.ceil(vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple 48 | ) 49 | self.n_positions = n_positions 50 | self.n_embd = n_embd 51 | self.n_layer = n_layer 52 | self.n_inner = n_inner 53 | self.n_head = n_head 54 | self.rotary_dim = min(rotary_dim, n_embd // n_head) 55 | self.activation_function = activation_function 56 | self.embd_layer = embd_layer 57 | self.architecture = architecture 58 | self.embd_pdrop = embd_pdrop 59 | self.resid_pdrop = resid_pdrop 60 | self.layer_norm_epsilon = layer_norm_epsilon 61 | self.initializer_range = initializer_range 62 | 63 | super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs) 64 | -------------------------------------------------------------------------------- /training/axolotl/src/axolotl/monkeypatch/btlm_attn_hijack_flash.py: -------------------------------------------------------------------------------- 1 | """ 2 | Flash attention monkey patch for cerebras btlm model 3 | """ 4 | 5 | import importlib 6 | import logging 7 | from typing import Optional, Tuple 8 | 9 | import torch 10 | from accelerate import init_empty_weights 11 | from flash_attn.flash_attn_interface import flash_attn_func 12 | from transformers import AutoConfig, AutoModelForCausalLM 13 | 14 | LOG = logging.getLogger("axolotl") 15 | 16 | 17 | def replace_btlm_attn_with_flash_attn(model_name="cerebras/btlm-3b-8k-base"): 18 | # this is a wonky hack to get the remotely loaded module 19 | model_config = AutoConfig.from_pretrained(model_name, trust_remote_code=True) 20 | # we need to load the model here in order for modeling_btlm to be available 21 | with init_empty_weights(): 22 | AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True) 23 | module_name = model_config.__class__.__module__.replace( 24 | ".configuration_btlm", ".modeling_btlm" 25 | ) 26 | modeling_btlm = importlib.import_module(module_name) 27 | modeling_btlm.BTLMAttention._attn = ( # pylint: disable=protected-access 28 | flashattn_attn 29 | ) 30 | 31 | 32 | def flashattn_attn( 33 | self, 34 | query: torch.Tensor, 35 | key: Optional[torch.Tensor] = 
None, 36 | value: Optional[torch.Tensor] = None, 37 | attention_mask: Optional[torch.Tensor] = None, # pylint: disable=unused-argument 38 | head_mask: Optional[torch.Tensor] = None, 39 | position_bias: Optional[torch.Tensor] = None, # pylint: disable=unused-argument 40 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: 41 | softmax_scale = ( 42 | 1 / (key.size(-1) ** self.attn_scale_power) if self.scale_attn_weights else None 43 | ) 44 | 45 | query = query.permute(0, 2, 1, 3) 46 | key = key.permute(0, 2, 1, 3) 47 | value = value.permute(0, 2, 1, 3) 48 | 49 | # Perform Flash attention 50 | attn_output = flash_attn_func( 51 | query, 52 | key, 53 | value, 54 | dropout_p=0.0, # Assuming you have this attribute 55 | softmax_scale=softmax_scale, # Set this if you have specific scaling in mind 56 | causal=not self.is_cross_attention, # Assuming you have this attribute 57 | return_attn_probs=False, # Set this based on your needs 58 | ) 59 | 60 | # Optional: Apply head mask if it's not None 61 | if head_mask is not None: 62 | attn_output *= head_mask 63 | 64 | attn_output = attn_output.permute(0, 2, 1, 3) 65 | 66 | return attn_output, None # We don't have explicit attn_weights in Flash attention 67 | -------------------------------------------------------------------------------- /training/axolotl/examples/mistral/nips/nips_02.yml: -------------------------------------------------------------------------------- 1 | base_model: mistralai/Mistral-7B-v0.1 2 | base_model_config: mistralai/Mistral-7B-v0.1 3 | model_type: MistralForCausalLM 4 | tokenizer_type: LlamaTokenizer 5 | is_mistral_derived_model: true 6 | 7 | load_in_8bit: false 8 | load_in_4bit: true 9 | strict: false 10 | 11 | datasets: 12 | - path: upaya07/NeurIPS-LLM-data 13 | type: alpaca_chat.load_qa 14 | 15 | random_split: False 16 | instances_for_test_metric: 2000 # Used only if random_split = False. Limits number of test samples to use for evaluation. 17 | val_set_size: 0.1 # used when random_split = True. How much % from train to use for building test data. 
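# Since random_split is False here, evaluation runs on the fixed 2000-instance
# test split above; the val_set_size fraction only applies when random_split is True.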
18 | 19 | # ============ Change these paths ======================= 20 | dataset_prepared_path: ./NeurIPS-LLM-data 21 | output_dir: ./qlora-NeurIPS-LLM-model 22 | 23 | wandb_project: Mistral-7B-ft-NeurIPS 24 | wandb_run_id: NeurIPS_model_01 25 | 26 | sequence_len: 8192 27 | sample_packing: true 28 | eval_sample_packing: true 29 | pad_to_sequence_len: true 30 | 31 | 32 | adapter: qlora 33 | 34 | lora_r: 128 35 | lora_alpha: 256 36 | lora_dropout: 0.05 37 | lora_target_linear: true 38 | lora_fan_in_fan_out: 39 | lora_target_modules: 40 | - q_proj 41 | - v_proj 42 | - k_proj 43 | - o_proj 44 | - gate_proj 45 | - down_proj 46 | - up_proj 47 | # - lm_head 48 | 49 | #lora_target_modules: 50 | # - gate_proj 51 | # - down_proj 52 | # - up_proj 53 | #lora_target_modules: 54 | # - q_proj 55 | # - v_proj 56 | 57 | noisy_embedding_alpha: 5 58 | gradient_accumulation_steps: 4 59 | micro_batch_size: 2 60 | num_epochs: 3 61 | #max_steps: 32 62 | optimizer: paged_adamw_32bit 63 | #optimizer: adamw_torch 64 | lr_scheduler: cosine 65 | learning_rate: 0.00002 66 | 67 | train_on_inputs: false 68 | group_by_length: false 69 | bf16: true 70 | fp16: false 71 | tf32: false 72 | 73 | bfloat16: true 74 | 75 | gradient_checkpointing: true 76 | early_stopping_patience: 5000000 77 | local_rank: 78 | logging_steps: 1 79 | xformers_attention: 80 | flash_attention: true 81 | 82 | warmup_steps: 100 83 | eval_steps: 64 84 | #eval_table_size: 4 85 | #eval_table_max_new_tokens: 128 86 | save_steps: 64 87 | save_total_limit: 10 88 | metric_for_best_model: "eval_loss" 89 | greater_is_better: False 90 | load_best_model_at_end: True 91 | debug: true 92 | debug_num_examples: 1 93 | 94 | weight_decay: 0.01 95 | # 96 | max_grad_norm: 0.3 97 | 98 | # deepspeed: 99 | fsdp: 100 | fsdp_config: 101 | special_tokens: 102 | bos_token: "" 103 | eos_token: "" 104 | unk_token: "" 105 | 106 | #seed: 57 107 | -------------------------------------------------------------------------------- /training/axolotl/src/axolotl/prompt_strategies/metharme.py: -------------------------------------------------------------------------------- 1 | """Module containing the MetharmenPromptTokenizingStrategy and MetharmePrompter class""" 2 | 3 | import logging 4 | from typing import Tuple 5 | 6 | from axolotl.prompt_tokenizers import InstructionPromptTokenizingStrategy 7 | from axolotl.prompters import AlpacaPrompter 8 | 9 | LOG = logging.getLogger("axolotl") 10 | 11 | IGNORE_TOKEN_ID = -100 12 | 13 | # pylint: disable=duplicate-code 14 | 15 | 16 | class MetharmePromptTokenizingStrategy(InstructionPromptTokenizingStrategy): 17 | """ 18 | Tokenizing strategy for the Metharme models 19 | """ 20 | 21 | def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]: 22 | return (prompt["prompt"], "", prompt["generation"]) 23 | 24 | def _tokenize( 25 | self, 26 | prompt: str, 27 | add_eos_token: bool = True, 28 | strip_bos_token: bool = False, 29 | num_eos_tokens: int = 3, 30 | ): 31 | result = self.tokenizer( 32 | prompt, 33 | truncation=True, 34 | max_length=self.sequence_len, 35 | padding=False, 36 | return_tensors=None, 37 | ) 38 | if len(result["input_ids"]) == 0: 39 | LOG.warning("Tokenizer result is empty. 
You may want to audit your dataset") 40 | # If there's already an EOS token there, subtract from the number added 41 | if result["input_ids"][-1] == self.tokenizer.eos_token_id: 42 | num_eos_tokens -= 1 43 | 44 | if num_eos_tokens > 0 and add_eos_token and len(result["input_ids"]) > 0: 45 | for _ in range(num_eos_tokens): 46 | if len(result["input_ids"]) < self.sequence_len: 47 | result["input_ids"].append(self.tokenizer.eos_token_id) 48 | result["attention_mask"].append(1) 49 | 50 | if result["input_ids"][0] == self.tokenizer.bos_token_id and strip_bos_token: 51 | result["input_ids"] = result["input_ids"][1:] 52 | result["attention_mask"] = result["attention_mask"][1:] 53 | 54 | result["labels"] = result["input_ids"].copy() 55 | return result 56 | 57 | 58 | class MetharmePrompter(AlpacaPrompter): 59 | """ 60 | Prompter for the Metharme models. 61 | """ 62 | 63 | system_prompt = "" 64 | system_no_input_prompt = "" 65 | system_format = "" 66 | turn_format = "{instruction}" 67 | turn_no_input_format = "{instruction}" 68 | 69 | def __init__(self, *args, **kwargs): # pylint: disable=super-init-not-called 70 | pass 71 | 72 | 73 | def load(tokenizer, cfg): 74 | return MetharmePromptTokenizingStrategy( 75 | MetharmePrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len 76 | ) 77 | -------------------------------------------------------------------------------- /data_prep/prepare_math_reasoning_dataset.py: -------------------------------------------------------------------------------- 1 | import json 2 | import random 3 | from datasets import load_dataset 4 | from difflib import SequenceMatcher 5 | 6 | lower_case_prob = 0.2 7 | 8 | def modify_input(question): 9 | num = random.randint(1, 10) 10 | if num <= 3: 11 | question = question.lower() 12 | return question 13 | 14 | def remove_hash(answer: str): 15 | if "####" in answer: 16 | return answer[:answer.rindex("####")].strip() 17 | return answer 18 | 19 | def format_response(answer: str, answer_identifier: str): 20 | answer_prefix_len = len(answer_identifier) 21 | if answer_identifier in answer: 22 | answer_prefix_start_idx = answer.index(answer_identifier) 23 | reasoning = remove_hash(answer[:answer_prefix_start_idx].strip()) 24 | 25 | # ==== Enable it if we want to add "answer" as part of output 26 | answer = answer[answer_prefix_start_idx:].strip() 27 | assert len(answer) > 0 28 | # answer = "Answer: " + answer 29 | return f"{reasoning}\n{answer.strip()}" 30 | else: 31 | return answer 32 | 33 | new_math_recs = [] 34 | valid_records = 0 35 | 36 | math_instruct_dataset = load_dataset("TIGER-Lab/MathInstruct", "train") 37 | valid_sources = ['data/CoT/gsm_train.json', 'data/CoT/aqua_rat.json', 'data/CoT/MATH_train.json'] 38 | print(f"math_instruct_dataset size: {len(math_instruct_dataset['train'])}") 39 | for each in math_instruct_dataset["train"]: 40 | 41 | if each['source'] in valid_sources: 42 | output = {} 43 | output['instruction'] = "" 44 | output['input'] = modify_input(each['instruction']).strip() 45 | output['output'] = format_response(each['output'], "The answer is:").strip() 46 | 47 | new_math_recs.append(output) 48 | 49 | valid_records += 1 50 | 51 | print(valid_records) 52 | 53 | instruction_prefix = "### Instruction:\n" 54 | input_prefix = "### Input:\n" 55 | response_prefix = "### Response:\n" 56 | new_math_dataset = [] 57 | for each in new_math_recs: 58 | if len(each['input'].split(" ")) <= 4 or len(each['output'].split(" ")) < 1: 59 | continue 60 | 61 | rec = {} 62 | rec['question'] = '' 63 | if 
len(each['instruction'].strip()) > 0: 64 | rec['question'] += f"{instruction_prefix}{each['instruction']}\n\n" 65 | rec['question'] += f"{input_prefix}{each['input']}\n\n" 66 | rec['question'] += f"{response_prefix}" 67 | rec['answer'] = f"{each['output']}" 68 | new_math_dataset.append(rec) 69 | 70 | print(f"new_math_dataset size: {len(new_math_dataset)}") 71 | random.shuffle(new_math_dataset) 72 | sampled_dataset = new_math_dataset[:50000] 73 | 74 | with open('/Users/ajindal/Downloads/math_reasoning.json', 'w') as f: 75 | json.dump(sampled_dataset, f, indent=1) 76 | -------------------------------------------------------------------------------- /training/axolotl/docker/Dockerfile-base: -------------------------------------------------------------------------------- 1 | ARG CUDA_VERSION="11.8.0" 2 | ARG CUDNN_VERSION="8" 3 | ARG UBUNTU_VERSION="22.04" 4 | ARG MAX_JOBS=4 5 | 6 | FROM nvidia/cuda:$CUDA_VERSION-cudnn$CUDNN_VERSION-devel-ubuntu$UBUNTU_VERSION as base-builder 7 | 8 | ENV PATH="/root/miniconda3/bin:${PATH}" 9 | 10 | ARG PYTHON_VERSION="3.9" 11 | ARG PYTORCH_VERSION="2.0.1" 12 | ARG CUDA="118" 13 | 14 | ENV PYTHON_VERSION=$PYTHON_VERSION 15 | 16 | RUN apt-get update \ 17 | && apt-get install -y wget git build-essential ninja-build git-lfs libaio-dev && rm -rf /var/lib/apt/lists/* \ 18 | && wget \ 19 | https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh \ 20 | && mkdir /root/.conda \ 21 | && bash Miniconda3-latest-Linux-x86_64.sh -b \ 22 | && rm -f Miniconda3-latest-Linux-x86_64.sh \ 23 | && conda create -n "py${PYTHON_VERSION}" python="${PYTHON_VERSION}" 24 | 25 | ENV PATH="/root/miniconda3/envs/py${PYTHON_VERSION}/bin:${PATH}" 26 | 27 | WORKDIR /workspace 28 | 29 | RUN python3 -m pip install --upgrade pip && pip3 install packaging && \ 30 | python3 -m pip install --no-cache-dir -U torch==${PYTORCH_VERSION}+cu${CUDA} --extra-index-url https://download.pytorch.org/whl/cu$CUDA 31 | 32 | FROM base-builder AS deepspeed-builder 33 | 34 | ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 9.0+PTX" 35 | 36 | WORKDIR /workspace 37 | 38 | RUN git clone https://github.com/microsoft/DeepSpeed.git && \ 39 | cd DeepSpeed && \ 40 | MAX_CONCURRENCY=8 DS_BUILD_SPARSE_ATTN=0 DS_BUILD_OPS=1 DS_BUILD_EVOFORMER_ATTN=0 python3 setup.py bdist_wheel 41 | 42 | FROM base-builder AS bnb-builder 43 | 44 | WORKDIR /workspace 45 | ARG CUDA="118" 46 | ENV CUDA=$CUDA 47 | ARG MAX_JOBS="-1" 48 | ENV MAX_JOBS=$MAX_JOBS 49 | 50 | RUN git clone https://github.com/TimDettmers/bitsandbytes.git && \ 51 | cd bitsandbytes && \ 52 | CUDA_VERSION=$CUDA make cuda11x && \ 53 | python setup.py bdist_wheel 54 | 55 | FROM base-builder 56 | 57 | ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 9.0+PTX" 58 | ENV TORCH_CUDA_ARCH_LIST=$TORCH_CUDA_ARCH_LIST 59 | 60 | RUN mkdir -p /workspace/builds 61 | COPY --from=bnb-builder /workspace/bitsandbytes /workspace/builds/bitsandbytes 62 | 63 | RUN mkdir -p /workspace/wheels/bitsandbytes 64 | COPY --from=deepspeed-builder /workspace/DeepSpeed/dist/deepspeed-*.whl wheels 65 | COPY --from=bnb-builder /workspace/bitsandbytes/dist/bitsandbytes-*.whl wheels 66 | COPY --from=bnb-builder /workspace/bitsandbytes/bitsandbytes/libbitsandbytes*.so wheels/bitsandbytes 67 | 68 | RUN pip3 install wheels/deepspeed-*.whl 69 | RUN cd /workspace/builds/bitsandbytes && python3 setup.py install 70 | RUN git lfs install --skip-repo 71 | RUN pip3 install awscli && \ 72 | # The base image ships with `pydantic==1.8.2` which is not working 73 | pip3 install -U --no-cache-dir pydantic==1.10.10 74 | 
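# Usage sketch (the image tag below is illustrative, not something the repo defines):
#   docker build -f docker/Dockerfile-base \
#       --build-arg CUDA_VERSION=11.8.0 --build-arg PYTORCH_VERSION=2.0.1 --build-arg CUDA=118 \
#       -t axolotl-base:cu118 .
# The final stage keeps the prebuilt DeepSpeed and bitsandbytes wheels under /workspace/wheels,
# presumably so the other Dockerfiles can install them without recompiling.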
-------------------------------------------------------------------------------- /training/axolotl/src/axolotl/prompt_strategies/completion.py: -------------------------------------------------------------------------------- 1 | """ 2 | Basic completion text 3 | """ 4 | from collections import defaultdict 5 | from typing import Any, Dict, Generator, Optional, Tuple 6 | 7 | from axolotl.prompt_tokenizers import InstructionPromptTokenizingStrategy 8 | 9 | 10 | class CompletionPromptTokenizingStrategy(InstructionPromptTokenizingStrategy): 11 | """ 12 | Tokenizing strategy for Completion prompts. 13 | """ 14 | 15 | _field: str = "text" 16 | 17 | def __init__(self, *args, max_length=None, **kwargs): 18 | super().__init__(*args, **kwargs) 19 | if max_length is not None: 20 | self.max_length = max_length 21 | 22 | @property 23 | def supports_batched(self): 24 | return True 25 | 26 | @property 27 | def field(self) -> str: 28 | return self._field 29 | 30 | @field.setter 31 | def field(self, new_field: str): 32 | self._field = new_field 33 | 34 | def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]: 35 | return ( 36 | prompt[self.field], 37 | "", 38 | "", 39 | ) 40 | 41 | def tokenize_prompt(self, prompt): 42 | res = defaultdict(lambda: []) 43 | feature_names = list(prompt.keys()) 44 | for row in zip(*prompt.values()): 45 | prompt_row = dict(zip(feature_names, row)) 46 | ( 47 | instruction, 48 | _, 49 | _, 50 | ) = self.parse_instruction_fields(prompt_row) 51 | 52 | full_prompt = self._build_full_prompt(instruction, None, None) 53 | tokenized_full_prompt = self._tokenize(full_prompt) 54 | 55 | for key, val in tokenized_full_prompt.items(): 56 | for i in range(0, len(val), self.sequence_len): 57 | res[key].append(val[i : i + self.sequence_len]) 58 | 59 | return dict(res) 60 | 61 | def _build_full_prompt( 62 | self, instruction, input, response 63 | ): # pylint: disable=redefined-builtin 64 | return next(iter(self.prompter.build_prompt(instruction, input, response))) 65 | 66 | 67 | class CompletionPrompter: 68 | """ 69 | Prompter for completion 70 | """ 71 | 72 | def build_prompt( 73 | self, 74 | instruction: str, 75 | input=None, # pylint: disable=redefined-builtin, unused-argument 76 | output=None, # pylint: disable=unused-argument 77 | ) -> Generator[str, None, None]: 78 | yield instruction 79 | 80 | 81 | def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None): 82 | strat = CompletionPromptTokenizingStrategy( 83 | CompletionPrompter(), 84 | tokenizer, 85 | cfg.train_on_inputs, 86 | cfg.sequence_len, 87 | max_length=cfg.sequence_len * 64, 88 | ) 89 | if ds_cfg and "field" in ds_cfg: 90 | strat.field = ds_cfg["field"] 91 | 92 | return strat 93 | -------------------------------------------------------------------------------- /training/axolotl/src/axolotl/prompt_strategies/user_defined.py: -------------------------------------------------------------------------------- 1 | """ 2 | User Defined prompts with configuration from the YML config 3 | """ 4 | 5 | from dataclasses import dataclass 6 | from functools import partial 7 | from typing import Optional, Tuple 8 | 9 | from axolotl.prompt_strategies.alpaca_w_system import ( 10 | InstructionWSystemPromptTokenizingStrategy, 11 | SystemDataPrompter, 12 | ) 13 | 14 | 15 | @dataclass 16 | class UserDefinedDatasetConfig: 17 | """ 18 | dataclass configuration representing a userdefined dataset type 19 | """ 20 | 21 | system_prompt: str = "" 22 | field_system: str = "system" 23 | field_instruction: str = "instruction" 24 | field_input: 
str = "input" 25 | field_output: str = "output" 26 | format: str = "{instruction} {input} " 27 | no_input_format: str = "{instruction} " 28 | system_format: str = "{system}" 29 | 30 | def __getitem__(self, item): 31 | return getattr(self, item) 32 | 33 | 34 | class UserDefinedPromptTokenizationStrategy(InstructionWSystemPromptTokenizingStrategy): 35 | """ 36 | Prompt Tokenization Strategy for user defined prompts 37 | """ 38 | 39 | 40 | def load(tokenizer, cfg, ds_cfg: Optional[UserDefinedDatasetConfig] = None): 41 | if not ds_cfg: 42 | raise ValueError("Missing dataset prompt configuration") 43 | 44 | system_prompt = "" 45 | if ds_cfg.system_prompt: 46 | system_prompt = ds_cfg.system_prompt 47 | 48 | def parse_instruction_fields( 49 | field_instruction, 50 | field_input, 51 | field_output, 52 | field_system, 53 | system_prompt, 54 | prompt, 55 | ) -> Tuple[str, str, str, str]: 56 | return ( 57 | prompt[field_instruction], 58 | prompt[field_input] if field_input in prompt else "", 59 | prompt[field_output] if field_output in prompt else "", 60 | prompt[field_system] if field_system in prompt else system_prompt, 61 | ) 62 | 63 | turn_format = ds_cfg.format 64 | turn_no_input_format = ds_cfg.no_input_format 65 | system_format = ds_cfg.system_format 66 | 67 | class UserDefinedPrompter(SystemDataPrompter): 68 | """ 69 | Prompter for user defined prompts 70 | """ 71 | 72 | def match_prompt_style(self): 73 | self.turn_format = turn_format 74 | self.turn_no_input_format = turn_no_input_format 75 | self.system_format = system_format 76 | 77 | prompter = UserDefinedPrompter() 78 | 79 | strat = UserDefinedPromptTokenizationStrategy( 80 | prompter, 81 | tokenizer, 82 | cfg.train_on_inputs, 83 | cfg.sequence_len, 84 | ) 85 | 86 | setattr( 87 | strat, 88 | "parse_instruction_fields", 89 | partial( 90 | parse_instruction_fields, 91 | ds_cfg.field_instruction, 92 | ds_cfg.field_input, 93 | ds_cfg.field_output, 94 | ds_cfg.field_system, 95 | system_prompt, 96 | ), 97 | ) 98 | return strat 99 | -------------------------------------------------------------------------------- /training/axolotl/docs/nccl.md: -------------------------------------------------------------------------------- 1 | # NCCL 2 | 3 | NVIDIA NCCL is a library to facilitate and optimize multi-GPU communication operations, such as broadcast, all-gather, reduce, all-reduce, etc. Broadly, NCCL configuration is highly environment-specific and is configured via several [environment variables](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/env.html). A common NCCL-related problem occurs when a long-running operation times out causing the training process to abort: 4 | 5 | ```text 6 | Watchdog caught collective operation timeout: WorkNCCL(SeqNum=42, OpType=ALLGATHER, Timeout(ms)=1800000) ran for 1806948 milliseconds before timing out. 7 | ``` 8 | 9 | Often, this timeout will happen after 30 minutes (the default setting) and is accompanied by below-average power consumption with near 100% GPU utilization before the error is raised. Nvidia recommends [disabling PCI access control services (ACS)](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/troubleshooting.html#pci-access-control-services-acs) as a possible solution if this is available to you. 10 | 11 | Forcing cross-GPU communication via [NVLink](https://en.wikipedia.org/wiki/NVLink) may help without increasing timeouts. 
To verify that your configuration is leveraging NVLink, run the following command: 12 | 13 | ```shell 14 | nvidia-smi nvlink --status 15 | ``` 16 | 17 | To force NCCL to use NVLink, simply set this in the environment: 18 | 19 | ```shell 20 | export NCCL_P2P_LEVEL=NVL 21 | ``` 22 | 23 | If NVLink is not available in your environment, there are other options for ``NCCL_P2P_LEVEL`` in the table below: 24 | 25 | | NCCL_P2P_LEVEL | Description | 26 | | -------------- | ----------- | 27 | | PIX | P2P data transfers through no more than a single PCIe bridge. Faster data transfer rates than paths involving multiple bridges, but slower compared to direct GPU-to-GPU communication. | 28 | | PXB | P2P data transfers through multiple PCIe bridges but not going through the PCIe Host Bridge; this path involves a complex routing process, potentially incurring a moderate level of latency. | 29 | | PHB | P2P data transfers occur over PCIe and through a PCIe Host Bridge, typically involving the CPU, which can facilitate direct memory access but might introduce additional latency compared to more direct paths (e.g. PIX, NVL). | 30 | 31 | To validate that acceptable data transfer speeds exist for your training job, running [NCCL Tests](https://github.com/NVIDIA/nccl-tests/blob/master/README.md) can help pinpoint bottlenecks, for example: 32 | 33 | ```shell 34 | ./build/all_reduce_perf -b 8 -e 128M -f 2 -g 3 35 | ``` 36 | 37 | It can be useful when debugging NCCL communication timeouts to activate additional logging in both PyTorch and NCCL: 38 | 39 | ```shell 40 | export NCCL_DEBUG=INFO 41 | export NCCL_DEBUG_SUBSYS=ALL 42 | export TORCH_DISTRIBUTED_DEBUG=INFO 43 | export TORCHELASTIC_ERROR_FILE=/PATH/TO/torcherror.log 44 | ``` 45 | 46 | Finally, if you believe your training job needs more time, you can increase the timeout past 30 minutes by setting the ``ddp_timeout`` value in the Axolotl configuration. See [PyTorch init_process_group](https://pytorch.org/docs/stable/distributed.html#torch.distributed.init_process_group) for documentation on this value. 
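
For reference, ``ddp_timeout`` maps onto the ``timeout`` argument of that call; the snippet below is a minimal sketch of the equivalent raw PyTorch initialization (the 2-hour value is purely illustrative):

```python
# Raising the collective timeout directly when initializing the process group.
# 7200 seconds is an example value; the default is 30 minutes (1800 seconds).
import datetime

import torch.distributed as dist

dist.init_process_group(
    backend="nccl",
    timeout=datetime.timedelta(seconds=7200),
)
```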
47 | -------------------------------------------------------------------------------- /training/axolotl/src/axolotl/prompt_strategies/sharegpt.py: -------------------------------------------------------------------------------- 1 | """Module containing the SimpleShareGPTPromptTokenizingStrategy class""" 2 | from typing import Any, Dict, Optional 3 | 4 | from fastchat.conversation import Conversation, SeparatorStyle, register_conv_template 5 | 6 | from axolotl.prompt_tokenizers import ShareGPTPromptTokenizingStrategy 7 | from axolotl.prompters import ShareGPTPrompterV2 8 | 9 | register_conv_template( 10 | Conversation( 11 | name="chatml", 12 | system_template="<|im_start|>system\n{system_message}", 13 | system_message="You are a helpful assistant.", 14 | roles=["<|im_start|>user", "<|im_start|>assistant"], 15 | sep_style=SeparatorStyle.CHATML, 16 | sep="<|im_end|>\n", 17 | ) 18 | ) 19 | 20 | 21 | def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None): 22 | conversation = ( 23 | ds_cfg["conversation"] if ds_cfg and "conversation" in ds_cfg else None 24 | ) 25 | field_human = ds_cfg["field_human"] if ds_cfg and "field_human" in ds_cfg else None 26 | field_model = ds_cfg["field_model"] if ds_cfg and "field_model" in ds_cfg else None 27 | return SimpleShareGPTPromptTokenizingStrategy( 28 | ShareGPTPrompterV2( 29 | conversation=conversation, 30 | role_key_model=field_model, 31 | role_key_human=field_human, 32 | ), 33 | tokenizer, 34 | cfg.train_on_inputs, 35 | cfg.sequence_len, 36 | ) 37 | 38 | 39 | def load_role(tokenizer, cfg): 40 | return SimpleRoleShareGPTPromptTokenizingStrategy( 41 | ShareGPTPrompterV2(), 42 | tokenizer, 43 | cfg.train_on_inputs, 44 | cfg.sequence_len, 45 | ) 46 | 47 | 48 | def load_guanaco(tokenizer, cfg): 49 | return GuanacoShareGPTPromptTokenizingStrategy( 50 | ShareGPTPrompterV2(), 51 | tokenizer, 52 | cfg.train_on_inputs, 53 | cfg.sequence_len, 54 | ) 55 | 56 | 57 | class SimpleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy): 58 | """ 59 | basic sharegpt strategy to grab conversations from the sample row 60 | """ 61 | 62 | def get_conversation_thread(self, prompt): 63 | return prompt["conversations"] 64 | 65 | 66 | class SimpleRoleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy): 67 | """ 68 | basic sharegpt strategy to grab conversations from the sample row, but uses role instead of from 69 | """ 70 | 71 | def get_conversation_thread(self, prompt): 72 | conversations = prompt["conversations"] 73 | # remap role: prompter/assistant, text: ... => from: human/gpt, value: ... 74 | turns = [{"from": t["role"], "value": t["value"]} for t in conversations] 75 | return turns 76 | 77 | 78 | class GuanacoShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy): 79 | """ 80 | sharegpt strategy that remaps oasst data to sharegpt format 81 | """ 82 | 83 | def get_conversation_thread(self, prompt): 84 | conversations = prompt["conversations"] 85 | # remap role: prompter/assistant, text: ... => from: human/gpt, value: ... 
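# e.g. {"role": "prompter", "text": "Hello"}     -> {"from": "human", "value": "Hello"}
#      {"role": "assistant", "text": "Hi there"} -> {"from": "gpt", "value": "Hi there"}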
86 | role_map = {"prompter": "human", "assistant": "gpt"} 87 | turns = [ 88 | {"from": role_map[t["role"]], "value": t["text"]} for t in conversations 89 | ] 90 | return turns 91 | -------------------------------------------------------------------------------- /training/axolotl/src/axolotl/monkeypatch/xpos_rope_llama_monkey_patch.py: -------------------------------------------------------------------------------- 1 | # pylint: skip-file 2 | """ 3 | Copied from https://github.com/kaiokendev/cutoff-len-is-context-len/blob/main/util/xpos_rope_llama_monkey_patch.py 4 | """ 5 | import torch 6 | import transformers 7 | import transformers.models.llama.modeling_llama 8 | from einops import rearrange 9 | 10 | 11 | class XposRotaryEmbedding(torch.nn.Module): 12 | def __init__( 13 | self, 14 | dim, 15 | max_position_embeddings=2048, 16 | base=10000, 17 | device=None, 18 | scale_base=2048, 19 | use_xpos=True, 20 | ): 21 | super().__init__() 22 | self.max_seq_len_cached = max_position_embeddings 23 | self.scale_base = scale_base 24 | 25 | inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim)) 26 | t = torch.arange(self.max_seq_len_cached, device=device).type_as(inv_freq) 27 | freqs = torch.einsum("i , j -> i j", t, inv_freq) 28 | freqs = torch.cat((freqs, freqs), dim=-1) 29 | 30 | self.register_buffer("inv_freq", inv_freq, persistent=False) 31 | self.register_buffer("freqs_cached", freqs, persistent=False) 32 | 33 | if not use_xpos: 34 | self.register_buffer("scale", None) 35 | self.register_buffer("scale_cached", torch.ones(1)) 36 | return 37 | 38 | scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim) 39 | power = (t - (self.max_seq_len_cached // 2)) / self.scale_base 40 | scale_cached = scale ** rearrange(power, "n -> n 1") 41 | scale_cached = torch.cat((scale_cached, scale_cached), dim=-1) 42 | 43 | self.register_buffer("scale", scale, persistent=False) 44 | self.register_buffer("scale_cached", scale_cached, persistent=False) 45 | 46 | def forward( 47 | self, 48 | x, 49 | seq_len, 50 | ): 51 | if seq_len > self.max_seq_len_cached: 52 | self.max_seq_len_cached = seq_len 53 | t = torch.arange(self.max_seq_len_cached, device=x.device).type_as( 54 | self.inv_freq 55 | ) 56 | freqs = torch.einsum("i , j -> i j", t, self.inv_freq) 57 | freqs = torch.cat((freqs, freqs), dim=-1).to(dtype=x.dtype) 58 | 59 | self.register_buffer("freqs_cached", freqs) 60 | 61 | if self.scale is None: 62 | self.register_buffer( 63 | "scale_cached", torch.ones(1, device=x.device).to(dtype=x.dtype) 64 | ) 65 | 66 | return self.freqs_cached.to(dtype=x.dtype), self.scale_cached 67 | 68 | power = (t - (seq_len // 2)) / self.scale_base 69 | scale = self.scale ** rearrange(power, "n -> n 1") 70 | scale = torch.cat((scale, scale), dim=-1).to(dtype=x.dtype) 71 | self.register_buffer("scale_cached", scale) 72 | 73 | return self.freqs_cached.to(dtype=x.dtype), self.scale_cached.to(dtype=x.dtype) 74 | 75 | 76 | def rotate_half(x): 77 | x1, x2 = x.chunk(2, dim=-1) 78 | return torch.cat((-x2, x1), dim=-1) 79 | 80 | 81 | def apply_rotary_pos_emb(q, k, freqs, scale=1, position_ids=None): 82 | freqs = freqs[position_ids, :] 83 | if scale.shape[-1] != 1: 84 | scale = scale[position_ids, :] 85 | 86 | q_embed = (q * freqs.cos() * scale) + (rotate_half(q) * freqs.sin() * scale) 87 | k_embed = (k * freqs.cos() * 1 / scale) + (rotate_half(k) * freqs.sin() * 1 / scale) 88 | 89 | return q_embed, k_embed 90 | 91 | 92 | def replace_llama_rope_with_xpos_rope(): 93 | transformers.models.llama.modeling_llama.LlamaRotaryEmbedding = 
XposRotaryEmbedding 94 | transformers.models.llama.modeling_llama.apply_rotary_pos_emb = apply_rotary_pos_emb 95 | -------------------------------------------------------------------------------- /training/axolotl/src/axolotl/prompt_strategies/pygmalion.py: -------------------------------------------------------------------------------- 1 | """Module containing the PygmalionPromptTokenizingStrategy and PygmalionPrompter class""" 2 | 3 | import copy 4 | import logging 5 | from collections import defaultdict 6 | from typing import Generator, List, Tuple 7 | 8 | from axolotl.prompt_tokenizers import ( 9 | PromptTokenizingStrategy, 10 | parse_tokenized_to_result, 11 | tokenize_prompt_default, 12 | ) 13 | 14 | LOG = logging.getLogger("axolotl") 15 | 16 | IGNORE_TOKEN_ID = -100 17 | 18 | 19 | class PygmalionPromptTokenizingStrategy(PromptTokenizingStrategy): 20 | """ 21 | Tokenizing strategy for Pygmalion. 22 | """ 23 | 24 | bot_prefix_token_ids: List[int] = [] 25 | 26 | def __init__(self, prompter, tokenizer, *args, **kwargs): 27 | super().__init__(prompter, tokenizer, *args, **kwargs) 28 | res = self._tokenize("<|model|>", add_eos_token=False, strip_bos_token=True) 29 | self.bot_prefix_token_ids = res["input_ids"] 30 | 31 | def tokenize_prompt(self, prompt): 32 | result, current_len = tokenize_prompt_default() 33 | for _, part in enumerate(self.prompter.build_prompt(prompt["conversations"])): 34 | role, message = part 35 | if role == "system": 36 | prefix = "<|system|>" 37 | # this should include a bos token, no eos token, strip trailing "\n" 38 | if message.endswith("\n"): 39 | message = message[:-8] 40 | res = self._tokenize( 41 | prefix + "Persona: " + message.strip(), 42 | add_eos_token=False, 43 | strip_bos_token=False, 44 | ) 45 | # everything from this is masked out from the labels 46 | labels = [IGNORE_TOKEN_ID] * len(res["input_ids"]) 47 | elif role == "human": 48 | prefix = "<|user|>" 49 | res = self._tokenize( 50 | prefix + " " + message.strip(), 51 | add_eos_token=False, 52 | strip_bos_token=True, 53 | ) 54 | # everything from this is masked out from the labels 55 | labels = [IGNORE_TOKEN_ID] * len(res["input_ids"]) 56 | elif role == "bot": 57 | prefix = "<|model|>" 58 | res = self._tokenize( 59 | prefix + " " + message.strip(), 60 | add_eos_token=True, 61 | strip_bos_token=True, 62 | ) 63 | # mask out the prefix token, rest is not masked out from labels 64 | # make sure we create the labels first, otherwise we get incorrect lengths 65 | labels = [IGNORE_TOKEN_ID] * len(self.bot_prefix_token_ids) + [ 66 | *copy.deepcopy(res["input_ids"]) 67 | ][len(self.bot_prefix_token_ids) :] 68 | else: 69 | LOG.warning(f"unknown role in conversation: {role}") 70 | res = defaultdict(lambda: []) 71 | 72 | # pylint: disable=duplicate-code 73 | result, current_len = parse_tokenized_to_result( 74 | result, 75 | current_len, 76 | res, 77 | labels, 78 | pad_token_id=self.tokenizer.pad_token_id, 79 | ) 80 | return result 81 | 82 | 83 | class PygmalionPrompter: 84 | """ 85 | Prompter for Pygmalion. 
86 | """ 87 | 88 | def __init__(self, *args, **kwargs): 89 | pass 90 | 91 | def build_prompt( 92 | self, source, *args, **kwargs # pylint: disable=unused-argument 93 | ) -> Generator[Tuple[str, str], None, None]: 94 | for msg in source: 95 | yield msg["role"], msg["value"] 96 | 97 | 98 | def load(tokenizer, cfg): 99 | return PygmalionPromptTokenizingStrategy( 100 | PygmalionPrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len 101 | ) 102 | -------------------------------------------------------------------------------- /training/axolotl/src/axolotl/utils/schedulers.py: -------------------------------------------------------------------------------- 1 | """Module for custom LRScheduler class""" 2 | import math 3 | from functools import partial 4 | 5 | from torch.optim import Optimizer 6 | from torch.optim.lr_scheduler import LambdaLR, LRScheduler 7 | 8 | 9 | class InterpolatingLogScheduler(LRScheduler): 10 | """ 11 | A scheduler that interpolates learning rates in a logarithmic fashion 12 | """ 13 | 14 | def __init__(self, optimizer, num_steps, min_lr, max_lr, last_epoch=-1): 15 | """A scheduler that interpolates learning rates in a logarithmic fashion 16 | 17 | Args: 18 | - optimizer: pytorch optimizer 19 | - num_steps: int, the number of steps over which to increase from the min_lr to the max_lr 20 | - min_lr: float, the minimum learning rate 21 | - max_lr: float, the maximum learning rate 22 | 23 | Usage: 24 | fc = nn.Linear(1,1) 25 | optimizer = optim.Adam(fc.parameters()) 26 | lr_scheduler = InterpolatingLogScheduler(optimizer, num_steps=400, min_lr=1e-6, max_lr=1e-4) 27 | """ 28 | self.num_steps = num_steps 29 | self.min_lr = min_lr 30 | self.max_lr = max_lr 31 | self.q = (max_lr / min_lr) ** ( # pylint: disable=invalid-name 32 | 1 / (num_steps - 1) 33 | ) 34 | super().__init__(optimizer, last_epoch) 35 | 36 | def get_lr(self): 37 | if self.last_epoch <= 0: 38 | lrs = [self.min_lr for base_lr in self.base_lrs] 39 | elif self.last_epoch < self.num_steps: 40 | lrs = [ 41 | self.min_lr * (self.q ** (self.last_epoch - 1)) 42 | for base_lr in self.base_lrs 43 | ] 44 | else: 45 | lrs = [self.max_lr for base_lr in self.base_lrs] 46 | 47 | return lrs 48 | 49 | 50 | def _get_cosine_schedule_with_quadratic_warmup_lr_lambda( 51 | current_step: int, 52 | *, 53 | num_warmup_steps: int, 54 | num_training_steps: int, 55 | num_cycles: float 56 | ): 57 | if current_step < num_warmup_steps: 58 | return (float(current_step) / float(max(1, num_warmup_steps))) ** 2 59 | progress = float(current_step - num_warmup_steps) / float( 60 | max(1, num_training_steps - num_warmup_steps) 61 | ) 62 | return max( 63 | 0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)) 64 | ) 65 | 66 | 67 | def get_cosine_schedule_with_quadratic_warmup( 68 | optimizer: Optimizer, 69 | num_warmup_steps: int, 70 | num_training_steps: int, 71 | num_cycles: float = 0.5, 72 | last_epoch: int = -1, 73 | ): 74 | """ 75 | Create a schedule with a learning rate that decreases following the values of the cosine function between the 76 | initial lr set in the optimizer to 0, after a warmup period during which it increases linearly between 0 and the 77 | initial lr set in the optimizer. 78 | 79 | Args: 80 | optimizer ([`~torch.optim.Optimizer`]): 81 | The optimizer for which to schedule the learning rate. 82 | num_warmup_steps (`int`): 83 | The number of steps for the warmup phase. 84 | num_training_steps (`int`): 85 | The total number of training steps. 
86 | num_cycles (`float`, *optional*, defaults to 0.5): 87 | The number of waves in the cosine schedule (the defaults is to just decrease from the max value to 0 88 | following a half-cosine). 89 | last_epoch (`int`, *optional*, defaults to -1): 90 | The index of the last epoch when resuming training. 91 | 92 | Return: 93 | `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule. 94 | """ 95 | 96 | lr_lambda = partial( 97 | _get_cosine_schedule_with_quadratic_warmup_lr_lambda, 98 | num_warmup_steps=num_warmup_steps, 99 | num_training_steps=num_training_steps, 100 | num_cycles=num_cycles, 101 | ) 102 | return LambdaLR(optimizer, lr_lambda, last_epoch) 103 | -------------------------------------------------------------------------------- /training/axolotl/src/axolotl/prompt_strategies/alpaca_chat.py: -------------------------------------------------------------------------------- 1 | """Module for Alpaca prompt strategy classes""" 2 | 3 | from typing import Any, Dict, Optional, Tuple 4 | 5 | from axolotl.prompt_tokenizers import ( 6 | AlpacaPromptTokenizingStrategy, 7 | InstructionPromptTokenizingStrategy, 8 | ) 9 | from axolotl.prompters import AlpacaPrompter, PromptStyle, UnpromptedPrompter 10 | 11 | 12 | def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None): 13 | prompt_style = PromptStyle.CHAT.value 14 | if ds_cfg and "conversation" in ds_cfg: 15 | prompt_style = ds_cfg["conversation"] 16 | 17 | return AlpacaPromptTokenizingStrategy( 18 | AlpacaPrompter(prompt_style), 19 | tokenizer, 20 | cfg.train_on_inputs, 21 | cfg.sequence_len, 22 | ) 23 | 24 | 25 | class AlpacaConcisePrompter(AlpacaPrompter): 26 | """ 27 | Alpaca Prompter extending the system prompt to ask for concise chat-instruct answers 28 | """ 29 | 30 | system_prompt = "Below is an instruction from a USER that describes a task, paired with an input that provides further context. The ASSISTANT writes a response that concisely and appropriately completes the request.\n\n" 31 | system_no_input_prompt = "Below is an instruction from a USER that describes a task. The ASSISTANT writes a response that appropriately and concisely completes the request.\n\n" 32 | 33 | 34 | class AlpacaChatPrompter(AlpacaPrompter): 35 | """ 36 | Alpaca Chat Prompter extending the system prompt to for chat-instruct answers 37 | """ 38 | 39 | system_prompt = "Below is an instruction from a USER that describes a task, paired with an input that provides further context. The ASSISTANT writes a response that concisely and appropriately completes the request.\n\n" 40 | system_no_input_prompt = "Below is an instruction from a USER that describes a task. 
The ASSISTANT writes a response that appropriately and concisely completes the request.\n\n" 41 | 42 | system_prompt = "" 43 | system_no_input_prompt = "" 44 | 45 | def __init__(self): # pylint: disable=super-init-not-called 46 | self.prompt_style = PromptStyle.CHAT.value 47 | self.match_prompt_style() 48 | 49 | 50 | class NoSystemPrompter(AlpacaPrompter): 51 | """ 52 | Null Prompter with no system prompts 53 | """ 54 | 55 | system_prompt = "" 56 | system_no_input_prompt = "" 57 | turn_format = "{instruction} {input} " 58 | turn_no_input_format = "{instruction} " 59 | 60 | def __init__(self): # pylint: disable=super-init-not-called 61 | pass 62 | 63 | 64 | class AlpacaQAPromptTokenizingStrategy(InstructionPromptTokenizingStrategy): 65 | """ 66 | Tokenizing strategy for AlpacaQA 67 | """ 68 | 69 | def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]: 70 | return ( 71 | prompt["question"], 72 | "", 73 | prompt["answer"], 74 | ) 75 | 76 | 77 | class CamelAIPromptTokenizingStrategy(InstructionPromptTokenizingStrategy): 78 | """ 79 | Tokenizing strategy for CamelAI datasets 80 | """ 81 | 82 | def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]: 83 | return ( 84 | prompt["message_1"], 85 | "", 86 | prompt["message_2"], 87 | ) 88 | 89 | 90 | def load_concise(tokenizer, cfg): 91 | return AlpacaPromptTokenizingStrategy( 92 | AlpacaConcisePrompter(PromptStyle.CHAT.value), 93 | tokenizer, 94 | cfg.train_on_inputs, 95 | cfg.sequence_len, 96 | ) 97 | 98 | 99 | def load_qa(tokenizer, cfg): 100 | return AlpacaQAPromptTokenizingStrategy( 101 | AlpacaChatPrompter(), 102 | tokenizer, 103 | cfg.train_on_inputs, 104 | cfg.sequence_len, 105 | ) 106 | 107 | 108 | def load_camel_ai(tokenizer, cfg): 109 | return CamelAIPromptTokenizingStrategy( 110 | AlpacaChatPrompter(), 111 | tokenizer, 112 | cfg.train_on_inputs, 113 | cfg.sequence_len, 114 | ) 115 | 116 | 117 | def load_no_prompt(tokenizer, cfg): 118 | return AlpacaPromptTokenizingStrategy( 119 | UnpromptedPrompter(PromptStyle.CHAT.value), 120 | tokenizer, 121 | cfg.train_on_inputs, 122 | cfg.sequence_len, 123 | ) 124 | -------------------------------------------------------------------------------- /training/axolotl/src/axolotl/monkeypatch/utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Shared utils for the monkeypatches 3 | """ 4 | import torch 5 | 6 | 7 | def get_cu_seqlens(attn_mask): 8 | """generate a cumulative sequence length mask for flash attention using attn mask""" 9 | if len(attn_mask.shape) == 1: 10 | attn_mask = attn_mask.unsqueeze(0) 11 | 12 | device = attn_mask.device 13 | results = [] 14 | max_seq_lens = [] 15 | 16 | for row in attn_mask: 17 | # Exclude zeros to avoid adding their positions to the mask 18 | t_non_zeros = row[row != 0] 19 | # Find where the sequence number changes (including the first position) 20 | seq_change = torch.cat( 21 | [ 22 | torch.tensor([1], dtype=torch.int32, device=device), 23 | t_non_zeros[1:] != t_non_zeros[:-1], 24 | ] 25 | ) 26 | # Get the indices where the sequence changes 27 | change_indices = torch.cat( 28 | [ 29 | (seq_change == 1).nonzero(as_tuple=True)[0], 30 | torch.tensor([len(t_non_zeros)], dtype=torch.int32, device=device), 31 | ] 32 | ) 33 | # Calculate the sequence lengths 34 | seq_lengths = change_indices[1:] - change_indices[:-1] 35 | # Calculate the length of the final sequence or padding 36 | final_seq_length = len(row) - change_indices[-1] 37 | # Append the length of the final sequence or padding to seq_lengths 
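# Worked example: an attn_mask row of [1, 1, 1, 2, 2, 0] gives seq_lengths [3, 2]; the
# trailing padding contributes a final length of 1, so cu_seqlens becomes [0, 3, 5, 6]
# and the max sequence length for this row is 3.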
38 | if final_seq_length.item(): 39 | seq_lengths = torch.cat( 40 | [ 41 | seq_lengths, 42 | torch.tensor( 43 | [final_seq_length.item()], dtype=torch.int32, device=device 44 | ), 45 | ] 46 | ) 47 | # Calculate the cumulative sequence lengths 48 | cu_seqlens = torch.cat( 49 | [torch.tensor([0], dtype=torch.int32, device=device), seq_lengths.cumsum(0)] 50 | ) 51 | max_seq_len = (cu_seqlens[1:] - cu_seqlens[:-1]).max() 52 | results.append(cu_seqlens) 53 | max_seq_lens.append(max_seq_len) 54 | 55 | return torch.stack(results).to(dtype=torch.int32), torch.stack(max_seq_lens) 56 | 57 | 58 | def get_cu_seqlens_from_pos_ids(position_ids): 59 | """generate a cumulative sequence length mask for flash attention using pos ids""" 60 | if len(position_ids.shape) == 1: 61 | position_ids = position_ids.unsqueeze(0) 62 | 63 | device = position_ids.device 64 | results = [] 65 | max_seq_lens = [] 66 | 67 | for row in position_ids: 68 | # Count the number of consecutive zeros from the right side 69 | padding_length = (row == 0).int().flip(dims=[0]).cumprod(dim=0).sum().item() 70 | 71 | # Adjust the row to exclude padding 72 | adjusted_row = row[:-padding_length] if padding_length else row.clone() 73 | 74 | # Find where the position resets to 0 (indicating a new sequence) 75 | seq_starts = torch.cat( 76 | [ 77 | torch.tensor([True], dtype=torch.bool, device=device), 78 | adjusted_row[1:] == 0, 79 | ] 80 | ) 81 | # Get the indices where the sequence starts 82 | start_indices = torch.cat( 83 | [ 84 | (seq_starts).nonzero(as_tuple=True)[0], 85 | torch.tensor([len(adjusted_row)], dtype=torch.int32, device=device), 86 | ] 87 | ) 88 | # Calculate the sequence lengths 89 | seq_lengths = start_indices[1:] - start_indices[:-1] 90 | # Calculate the cumulative sequence lengths 91 | cu_seqlens = torch.cat( 92 | [torch.tensor([0], dtype=torch.int32, device=device), seq_lengths.cumsum(0)] 93 | ) 94 | # Append the padding length to the cumulative sequence lengths 95 | if padding_length: 96 | cu_seqlens = torch.cat( 97 | [cu_seqlens, torch.tensor([len(row)], dtype=torch.int32, device=device)] 98 | ) 99 | max_seq_len = (cu_seqlens[1:] - cu_seqlens[:-1]).max() 100 | results.append(cu_seqlens) 101 | max_seq_lens.append(max_seq_len) 102 | 103 | return torch.stack(results).to(dtype=torch.int32), torch.stack(max_seq_lens) 104 | 105 | 106 | def set_module_name(model, name, value): 107 | if "." 
in name: 108 | parent_name = name.rsplit(".", 1)[0] 109 | child_name = name[len(parent_name) + 1 :] 110 | parent = model.get_submodule(parent_name) 111 | else: 112 | parent_name = "" 113 | parent = model 114 | child_name = name 115 | 116 | setattr(parent, child_name, value) 117 | -------------------------------------------------------------------------------- /data_prep/prepare_exact_match_tasks_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import random 4 | import math 5 | 6 | total_rec, incorrect = 0, 0 7 | 8 | total_records_to_sample = 50*1000 9 | output_dir = "/home/minimalist/Downloads/nips_training_data/NI/1_word/incorrect_preds" 10 | 11 | def fetch_instruction(task_file_name): 12 | # save json files locally from https://github.com/allenai/natural-instructions/tree/master/tasks 13 | with open(os.path.join("/home/minimalist/work/natural-instructions/tasks", task_file_name)) as f: 14 | task_json = json.load(f) 15 | return task_json["Definition"][0] 16 | 17 | def extract_answer(answer): 18 | return answer.split("\n\n")[0] 19 | 20 | def process_file(file_path): 21 | global total_rec, incorrect 22 | incorrect_predictions = [] 23 | with open(file_path) as f: 24 | data = json.load(f) 25 | for rec in data: 26 | ground_truth = rec["ground_truth"] 27 | prediction = extract_answer(rec["prediction"]) 28 | is_correct = False 29 | 30 | out_rec = {} 31 | out_rec['instruction'] = fetch_instruction(rec["task_file"]) 32 | out_rec['input'] = rec["orig_input"] 33 | out_rec['output'] = ground_truth[0] 34 | 35 | if any([x == prediction for x in ground_truth]): 36 | is_correct = True 37 | 38 | if not is_correct: 39 | incorrect_predictions.append(out_rec) 40 | 41 | random.shuffle(incorrect_predictions) 42 | 43 | task_accuracy = (len(data)-len(incorrect_predictions)) / len(data) 44 | 45 | total_rec += len(data) 46 | incorrect += len(incorrect_predictions) 47 | file_name = file_path.split("/home/minimalist/Downloads/super_NI_mistral_inference_output/1_word/")[1] 48 | print(f"\n{file_name}: #task_records {len(data)}, Incorrect: {1-task_accuracy}\nOVERALL: #total_records: {total_rec}, total_incorrect: {incorrect} Incorect: {incorrect / total_rec}") 49 | 50 | output = {} 51 | output['source'] = file_name 52 | output["accuracy"] = task_accuracy 53 | output["instances"] = incorrect_predictions 54 | 55 | with open(f"{output_dir}/{file_name}", 'w') as f: 56 | json.dump(output, f, indent=1) 57 | 58 | 59 | def find_incorrect_predictions(): 60 | files = os.listdir("/home/minimalist/Downloads/super_NI_mistral_inference_output/1_word/") 61 | for file in files: 62 | if not os.path.exists(f"{output_dir}/{file}"): 63 | process_file(os.path.join("/home/minimalist/Downloads/super_NI_mistral_inference_output/1_word", file)) 64 | else: 65 | print(f"Path already exists! 
{output_dir}/{file}") 66 | 67 | def sample_by_accuracy(file_path, multiplication_factor = 1.0): 68 | with open(file_path) as f: 69 | data = json.load(f) 70 | 71 | accuracy = data["accuracy"] 72 | print(f"accuracy: {accuracy}") 73 | instances: object = data["instances"] 74 | print(f"original #records: {len(instances)}") 75 | instances = [rec for rec in instances if len(f"{rec['instruction']} {rec['input']}".split(" ")) < 800] 76 | print(f"#records after removing long inputs: {len(instances)}") 77 | 78 | bucket_id = math.floor(int((accuracy*100)/10)) 79 | if bucket_id == 0: 80 | samples = int(350*multiplication_factor) 81 | elif bucket_id == 1: 82 | samples = int(350*multiplication_factor) 83 | elif bucket_id == 2: 84 | samples = int(300*multiplication_factor) 85 | elif bucket_id == 3: 86 | samples = int(300*multiplication_factor) 87 | elif bucket_id == 4: 88 | samples = int(250*multiplication_factor) 89 | elif bucket_id == 5: 90 | samples = int(250*multiplication_factor) 91 | elif bucket_id == 6: 92 | samples = int(200*multiplication_factor) 93 | elif bucket_id == 7: 94 | samples = int(150*multiplication_factor) 95 | elif bucket_id == 8: 96 | samples = int(100*multiplication_factor) 97 | else: 98 | samples = int(50 * multiplication_factor) 99 | 100 | random.shuffle(instances) 101 | sampled_data = instances[:samples] 102 | return sampled_data 103 | 104 | 105 | def prepare_training_data(multiplication_factor = 1.0): 106 | training_data = [] 107 | data_dir = "/home/minimalist/Downloads/nips_training_data/NI/1_word/incorrect_preds/" 108 | files = os.listdir(data_dir) 109 | for file in files: 110 | print(f"file: {file}") 111 | sampled_data = sample_by_accuracy(os.path.join(data_dir, file), multiplication_factor=multiplication_factor) 112 | print(f"#Sampled Records: {len(sampled_data)}") 113 | 114 | training_data.extend(sampled_data) 115 | print(f"Total data so far: {len(training_data)}\n\n") 116 | 117 | with open(f"/home/minimalist/Downloads/nips_training_data/super_natural_1_word_data_{str(int(multiplication_factor))}.json", 'w') as f: 118 | json.dump(training_data, f, indent=1) 119 | 120 | 121 | find_incorrect_predictions() 122 | prepare_training_data() 123 | -------------------------------------------------------------------------------- /inference/submission_2/main.py: -------------------------------------------------------------------------------- 1 | from fastapi import FastAPI 2 | 3 | import logging 4 | import os 5 | import time 6 | 7 | import torch 8 | from huggingface_hub import login 9 | from transformers import AutoModelForCausalLM, AutoTokenizer 10 | from peft import PeftModel 11 | 12 | torch.set_float32_matmul_precision("high") 13 | 14 | from api import ( 15 | ProcessRequest, 16 | ProcessResponse, 17 | TokenizeRequest, 18 | TokenizeResponse, 19 | Token, 20 | ) 21 | 22 | app = FastAPI() 23 | 24 | logger = logging.getLogger(__name__) 25 | # Configure the logging module 26 | logging.basicConfig(level=logging.INFO) 27 | 28 | login(token=os.environ["HUGGINGFACE_TOKEN"]) 29 | 30 | 31 | 32 | def load_peft_model(base_model, peft_model): 33 | peft_model = PeftModel.from_pretrained(base_model, peft_model) 34 | return peft_model 35 | 36 | base_model_name= "mistralai/Mistral-7B-v0.1" 37 | base_model = AutoModelForCausalLM.from_pretrained( 38 | base_model_name, 39 | return_dict=True, 40 | torch_dtype=torch.bfloat16, 41 | device_map="cuda" 42 | ) 43 | model = load_peft_model(base_model, os.environ["HUGGINGFACE_REPO"]) 44 | 45 | 46 | model.eval() 47 | tokenizer = 
AutoTokenizer.from_pretrained(base_model_name) 48 | tokenizer.pad_token = tokenizer.eos_token 49 | LLAMA2_CONTEXT_LENGTH = 4096 50 | 51 | 52 | @app.post("/process") 53 | async def process_request(input_data: ProcessRequest) -> ProcessResponse: 54 | if input_data.seed is not None: 55 | torch.manual_seed(input_data.seed) 56 | 57 | prompt = input_data.prompt.strip() 58 | two_consecutive_new_line_count = prompt.count("\n\n") 59 | if two_consecutive_new_line_count < 2: 60 | prompt = "### Input:\n" + input_data.prompt + "\n\n### Response:\n" 61 | encoded = tokenizer(prompt, return_tensors="pt", return_token_type_ids=False) 62 | 63 | prompt_length = encoded["input_ids"][0].size(0) 64 | max_returned_tokens = prompt_length + input_data.max_new_tokens 65 | assert max_returned_tokens <= LLAMA2_CONTEXT_LENGTH, ( 66 | max_returned_tokens, 67 | LLAMA2_CONTEXT_LENGTH, 68 | ) 69 | 70 | t0 = time.perf_counter() 71 | encoded = {k: v.to("cuda") for k, v in encoded.items()} 72 | with torch.no_grad(): 73 | outputs = model.generate( 74 | **encoded, 75 | max_new_tokens=input_data.max_new_tokens, 76 | do_sample=True, 77 | temperature=input_data.temperature, 78 | top_k=input_data.top_k, 79 | return_dict_in_generate=True, 80 | output_scores=True, 81 | ) 82 | 83 | # logger.info("Request and response - start") 84 | # logger.info(tokenizer.decode(outputs.sequences[0], skip_special_tokens=True)) 85 | # logger.info("Request and response - end") 86 | 87 | t = time.perf_counter() - t0 88 | if not input_data.echo_prompt: 89 | output = tokenizer.decode(outputs.sequences[0][prompt_length:], skip_special_tokens=True) 90 | output = output.split("\n\n")[0] 91 | if output.lower().strip().startswith("the answer is"): 92 | output = output.strip()[13:].strip() 93 | output = output[:1] 94 | else: 95 | output = tokenizer.decode(outputs.sequences[0], skip_special_tokens=True) 96 | 97 | tokens_generated = outputs.sequences[0].size(0) - prompt_length 98 | # logger.info( 99 | # f"Time for inference: {t:.02f} sec total, {tokens_generated / t:.02f} tokens/sec" 100 | # ) 101 | # 102 | # logger.info(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB") 103 | generated_tokens = [] 104 | 105 | log_probs = torch.log(torch.stack(outputs.scores, dim=1).softmax(-1)) 106 | 107 | gen_sequences = outputs.sequences[:, encoded["input_ids"].shape[-1]:] 108 | gen_logprobs = torch.gather(log_probs, 2, gen_sequences[:, :, None]).squeeze(-1) 109 | 110 | top_indices = torch.argmax(log_probs, dim=-1) 111 | top_logprobs = torch.gather(log_probs, 2, top_indices[:,:,None]).squeeze(-1) 112 | top_indices = top_indices.tolist()[0] 113 | top_logprobs = top_logprobs.tolist()[0] 114 | 115 | for t, lp, tlp in zip(gen_sequences.tolist()[0], gen_logprobs.tolist()[0], zip(top_indices, top_logprobs)): 116 | idx, val = tlp 117 | tok_str = tokenizer.decode(idx) 118 | token_tlp = {tok_str: val} 119 | generated_tokens.append( 120 | Token(text=tokenizer.decode(t), logprob=lp, top_logprob=token_tlp) 121 | ) 122 | logprob_sum = gen_logprobs.sum().item() 123 | 124 | return ProcessResponse( 125 | text=output, tokens=generated_tokens, logprob=logprob_sum, request_time=t 126 | ) 127 | 128 | 129 | @app.post("/tokenize") 130 | async def tokenize(input_data: TokenizeRequest) -> TokenizeResponse: 131 | t0 = time.perf_counter() 132 | encoded = tokenizer( 133 | input_data.text 134 | ) 135 | t = time.perf_counter() - t0 136 | tokens = encoded["input_ids"] 137 | return TokenizeResponse(tokens=tokens, request_time=t) 138 | 
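# Example client call (illustrative): the host/port depend on how uvicorn is launched in the
# Dockerfile, and the JSON fields simply mirror how ProcessRequest is read above.
#
#   import requests
#   resp = requests.post(
#       "http://localhost:8080/process",
#       json={"prompt": "What is 2 + 2?", "max_new_tokens": 16, "temperature": 0.7,
#             "top_k": 50, "seed": 42, "echo_prompt": False},
#       timeout=60,
#   )
#   print(resp.json()["text"])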
-------------------------------------------------------------------------------- /training/axolotl/src/axolotl/prompt_strategies/context_qa.py: -------------------------------------------------------------------------------- 1 | """Module containing the classes for Context QA Prompt Tokenization Strategies""" 2 | from typing import Tuple 3 | 4 | from axolotl.prompt_tokenizers import InstructionPromptTokenizingStrategy 5 | from axolotl.prompters import AlpacaPrompter, PromptStyle 6 | 7 | 8 | # article, unanswerable_question, question, answer 9 | def load_404(tokenizer, cfg): 10 | return AlpacaMissingInfoContextPromptTokenizingStrategy( 11 | AlpacaContextPrompter(PromptStyle.CHAT.value), 12 | tokenizer, 13 | cfg.train_on_inputs, 14 | cfg.sequence_len, 15 | ) 16 | 17 | 18 | def load(tokenizer, cfg): 19 | return AlpacaContextPromptTokenizingStrategy( 20 | AlpacaContextPrompter(PromptStyle.CHAT.value), 21 | tokenizer, 22 | cfg.train_on_inputs, 23 | cfg.sequence_len, 24 | ) 25 | 26 | 27 | def load_v2(tokenizer, cfg): 28 | return ContextQaV2PromptTokenizingStrategy( 29 | ContextV2Prompter(), 30 | tokenizer, 31 | cfg.train_on_inputs, 32 | cfg.sequence_len, 33 | ) 34 | 35 | def load_dolly(tokenizer, cfg): 36 | return ContextQaDollyPromptTokenizingStrategy( 37 | ContextDollyPrompter(), 38 | tokenizer, 39 | cfg.train_on_inputs, 40 | cfg.sequence_len, 41 | ) 42 | 43 | 44 | 45 | class AlpacaContextPrompter(AlpacaPrompter): 46 | """ 47 | Customized system prompted for concise QA 48 | """ 49 | 50 | system_prompt = ( 51 | "Use the following contextual information to concisely answer the question.\n" 52 | ) 53 | system_no_input_prompt = ( 54 | "Use the following contextual information to concisely answer the question.\n" 55 | ) 56 | 57 | 58 | class AlpacaContextPromptTokenizingStrategy(InstructionPromptTokenizingStrategy): 59 | """ 60 | Tokenization Strategy to combine in-context article with a question and answer 61 | """ 62 | 63 | def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]: 64 | return ( 65 | prompt["article"] + "\n===\n" + prompt["question"], 66 | "", 67 | prompt["answer"], 68 | ) 69 | 70 | 71 | class ContextQaV2PromptTokenizingStrategy(InstructionPromptTokenizingStrategy): 72 | """ 73 | Tokenization Strategy to combine in-context article with a question and answer 74 | """ 75 | 76 | def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]: 77 | return ( 78 | "Context: " 79 | + prompt["context"] 80 | + "\nQuestion: " 81 | + prompt["question"] 82 | + "\n", 83 | "", 84 | "Answer: " + prompt["answer"], 85 | ) 86 | 87 | class ContextQaDollyPromptTokenizingStrategy(InstructionPromptTokenizingStrategy): 88 | """ 89 | Tokenization Strategy to combine in-context article with a question and answer 90 | """ 91 | 92 | def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]: 93 | instruction = f"[INST] {prompt['instruction']}" 94 | context = f"Here's some context: {prompt['context']}" if len(prompt["context"]) > 0 else None 95 | response = f"[/INST] {prompt['response']}" 96 | 97 | if context is not None: 98 | instruction += f"\n\n{context}" 99 | instruction = instruction.strip() 100 | 101 | return ( 102 | f"{instruction}\n\n", 103 | "", 104 | response.strip(), 105 | ) 106 | 107 | 108 | class ContextV2Prompter(AlpacaPrompter): 109 | """ 110 | Customized system prompted for concise QA 111 | """ 112 | 113 | system_prompt = "" 114 | system_no_input_prompt = "" 115 | 116 | def match_prompt_style(self): 117 | # pylint: disable=duplicate-code 118 | self.turn_format = 
"{instruction}\n{input}" 119 | self.turn_no_input_format = "{instruction}" 120 | self.system_format = "{system}" 121 | 122 | class ContextDollyPrompter(AlpacaPrompter): 123 | """ 124 | Customized system prompted for concise QA 125 | """ 126 | 127 | system_prompt = "" 128 | system_no_input_prompt = "" 129 | 130 | def match_prompt_style(self): 131 | # pylint: disable=duplicate-code 132 | self.turn_format = "{instruction}\n{context}" 133 | self.turn_no_input_format = "{instruction}" 134 | self.system_format = "{system}" 135 | 136 | 137 | class AlpacaMissingInfoContextPromptTokenizingStrategy( 138 | InstructionPromptTokenizingStrategy 139 | ): 140 | """ 141 | Tokenization Strategy to combine in-context article with a question that can't be answered 142 | from the context and a default response to that effect 143 | """ 144 | 145 | def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]: 146 | return ( 147 | prompt["article"] + "\n===\n" + prompt["unanswerable_question"], 148 | "", 149 | "The context provided does not contain any information about your inquiry. " 150 | "Therefore, I'm unable to answer your question based on the given context.", 151 | ) 152 | -------------------------------------------------------------------------------- /training/axolotl/src/axolotl/monkeypatch/llama_attn_hijack_sdp.py: -------------------------------------------------------------------------------- 1 | """ 2 | Patched LlamaAttention to use torch.nn.functional.scaled_dot_product_attention 3 | """ 4 | 5 | import warnings 6 | from typing import Optional, Tuple 7 | 8 | import torch 9 | import torch.nn.functional as F 10 | import transformers.models.llama.modeling_llama 11 | from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv 12 | 13 | 14 | def hijack_llama_sdp_attention(): 15 | transformers.models.llama.modeling_llama.LlamaAttention.forward = ( 16 | sdp_attention_forward 17 | ) 18 | 19 | 20 | def sdp_attention_forward( 21 | self, 22 | hidden_states: torch.Tensor, 23 | attention_mask: Optional[torch.Tensor] = None, 24 | position_ids: Optional[torch.LongTensor] = None, 25 | past_key_value: Optional[Tuple[torch.Tensor]] = None, 26 | output_attentions: bool = False, 27 | use_cache: bool = False, 28 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: 29 | # pylint: disable=duplicate-code 30 | bsz, q_len, _ = hidden_states.size() 31 | 32 | if not hasattr(self, "pretraining_tp"): 33 | self.pretraining_tp = 1 34 | 35 | if self.pretraining_tp > 1: 36 | key_value_slicing = ( 37 | self.num_key_value_heads * self.head_dim 38 | ) // self.pretraining_tp 39 | query_slices = self.q_proj.weight.split( 40 | (self.num_heads * self.head_dim) // self.pretraining_tp, dim=0 41 | ) 42 | key_slices = self.k_proj.weight.split(key_value_slicing, dim=0) 43 | value_slices = self.v_proj.weight.split(key_value_slicing, dim=0) 44 | 45 | query_states = [ 46 | F.linear(hidden_states, query_slices[i]) for i in range(self.pretraining_tp) 47 | ] 48 | query_states = torch.cat(query_states, dim=-1) 49 | 50 | key_states = [ 51 | F.linear(hidden_states, key_slices[i]) for i in range(self.pretraining_tp) 52 | ] 53 | key_states = torch.cat(key_states, dim=-1) 54 | 55 | value_states = [ 56 | F.linear(hidden_states, value_slices[i]) for i in range(self.pretraining_tp) 57 | ] 58 | value_states = torch.cat(value_states, dim=-1) 59 | 60 | else: 61 | query_states = self.q_proj(hidden_states) 62 | key_states = self.k_proj(hidden_states) 63 | value_states = self.v_proj(hidden_states) 64 | 
65 | query_states = query_states.view( 66 | bsz, q_len, self.num_heads, self.head_dim 67 | ).transpose(1, 2) 68 | key_states = key_states.view( 69 | bsz, q_len, self.num_key_value_heads, self.head_dim 70 | ).transpose(1, 2) 71 | value_states = value_states.view( 72 | bsz, q_len, self.num_key_value_heads, self.head_dim 73 | ).transpose(1, 2) 74 | # [bsz, q_len, nh, hd] 75 | # [bsz, nh, q_len, hd] 76 | 77 | kv_seq_len = key_states.shape[-2] 78 | if past_key_value is not None: 79 | kv_seq_len += past_key_value[0].shape[-2] 80 | 81 | cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) 82 | query_states, key_states = apply_rotary_pos_emb( 83 | query_states, key_states, cos, sin, position_ids 84 | ) 85 | # [bsz, nh, t, hd] 86 | 87 | if past_key_value is not None: 88 | # reuse k, v, self_attention 89 | key_states = torch.cat([past_key_value[0], key_states], dim=2) 90 | value_states = torch.cat([past_key_value[1], value_states], dim=2) 91 | 92 | past_key_value = (key_states, value_states) if use_cache else None 93 | 94 | # repeat k/v heads if n_kv_heads < n_heads 95 | key_states = repeat_kv(key_states, self.num_key_value_groups) 96 | value_states = repeat_kv(value_states, self.num_key_value_groups) 97 | 98 | if output_attentions: 99 | warnings.warn( 100 | "Output attentions is not supported for patched `LlamaAttention`, returning `None` instead." 101 | ) 102 | 103 | # 104 | # sdp-attn start 105 | # 106 | 107 | with torch.backends.cuda.sdp_kernel(): 108 | attn_output = torch.nn.functional.scaled_dot_product_attention( 109 | query_states, 110 | key_states, 111 | value_states, 112 | attn_mask=attention_mask, 113 | is_causal=False, 114 | ) 115 | 116 | if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): 117 | raise ValueError( 118 | f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" 119 | f" {attn_output.size()}" 120 | ) 121 | attn_output = attn_output.transpose(1, 2) 122 | attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) 123 | 124 | # 125 | # sdp-attn end 126 | # 127 | 128 | if self.pretraining_tp > 1: 129 | attn_output = attn_output.split(self.hidden_size // self.pretraining_tp, dim=2) 130 | o_proj_slices = self.o_proj.weight.split( 131 | self.hidden_size // self.pretraining_tp, dim=1 132 | ) 133 | attn_output = sum( 134 | F.linear(attn_output[i], o_proj_slices[i]) 135 | for i in range(self.pretraining_tp) 136 | ) 137 | else: 138 | attn_output = self.o_proj(attn_output) 139 | 140 | return attn_output, None, past_key_value 141 | -------------------------------------------------------------------------------- /data_prep/prepare_generation_tasks_dataset.py: -------------------------------------------------------------------------------- 1 | from rouge import Rouge 2 | import json 3 | import os 4 | import csv 5 | import random 6 | import math 7 | 8 | filtered_list = [] 9 | def compute_rouge_score(llm_json, output_json): 10 | f = open(llm_json) 11 | data = json.load(f) 12 | rouge = Rouge() 13 | list_2 = [] 14 | list_3 = [] 15 | list_4 = [] 16 | list_5 = [] 17 | list_6 = [] 18 | list_7 = [] 19 | list_8 = [] 20 | final_list = [] 21 | 22 | for i in data: 23 | ground_truths = i["ground_truth"] 24 | prediction = i["prediction"].split("\n\n")[0].strip() 25 | if len(ground_truths) > 0: 26 | rouge_scores = [] 27 | for ground_truth in ground_truths: 28 | ground_truth = ground_truth.replace('.','') 29 | prediction = prediction.replace('.','') 30 | if len(ground_truth) == 0 and len(prediction) == 0: 31 | ground_truth = "DUMMY" 32 | 
prediction = "DUMMY" 33 | if len(ground_truth) == 0: 34 | ground_truth = prediction 35 | if len(prediction) == 0: 36 | prediction = ground_truth 37 | scores = rouge.get_scores(prediction.lower(), ground_truth.lower()) 38 | score = float(scores[0]['rouge-l']['f']) 39 | rouge_scores.append(score) 40 | max_rouge_score = max(rouge_scores) 41 | if max_rouge_score < 0.2: 42 | list_2.append(i) 43 | if max_rouge_score >= 0.2 and max_rouge_score < 0.3: 44 | list_3.append(i) 45 | if max_rouge_score >= 0.3 and max_rouge_score < 0.4: 46 | list_4.append(i) 47 | if max_rouge_score >= 0.4 and max_rouge_score < 0.5: 48 | list_5.append(i) 49 | if max_rouge_score >= 0.5 and max_rouge_score < 0.6: 50 | list_6.append(i) 51 | if max_rouge_score >= 0.6 and max_rouge_score < 0.7: 52 | list_7.append(i) 53 | if max_rouge_score >= 0.7 and max_rouge_score < 0.8: 54 | list_8.append(i) 55 | 56 | random.shuffle(list_2) 57 | random.shuffle(list_3) 58 | random.shuffle(list_4) 59 | random.shuffle(list_5) 60 | random.shuffle(list_6) 61 | random.shuffle(list_7) 62 | random.shuffle(list_8) 63 | 64 | list_2_cnt = len(list_2) 65 | list_3_cnt = len(list_3) 66 | list_4_cnt = len(list_4) 67 | list_5_cnt = len(list_5) 68 | list_6_cnt = len(list_6) 69 | list_7_cnt = len(list_7) 70 | list_8_cnt = len(list_8) 71 | 72 | list_2_instance_cnt = int(math.floor(0.8 * list_2_cnt)) 73 | list_3_instance_cnt = int(math.floor(0.2 * list_3_cnt)) 74 | list_4_instance_cnt = int(math.floor(0.2 * list_4_cnt)) 75 | list_5_instance_cnt = int(math.floor(0.2 * list_5_cnt)) 76 | list_6_instance_cnt = int(math.floor(0.2 * list_6_cnt)) 77 | list_7_instance_cnt = int(math.floor(0.2 * list_7_cnt)) 78 | list_8_instance_cnt = int(math.floor(0.2 * list_8_cnt)) 79 | 80 | for i in range(0,list_2_instance_cnt): 81 | final_list.append(list_2[i]) 82 | for i in range(0,list_3_instance_cnt): 83 | final_list.append(list_3[i]) 84 | for i in range(0,list_4_instance_cnt): 85 | final_list.append(list_4[i]) 86 | for i in range(0,list_5_instance_cnt): 87 | final_list.append(list_5[i]) 88 | for i in range(0,list_6_instance_cnt): 89 | final_list.append(list_6[i]) 90 | for i in range(0,list_7_instance_cnt): 91 | final_list.append(list_7[i]) 92 | for i in range(0,list_8_instance_cnt): 93 | final_list.append(list_8[i]) 94 | 95 | for i in final_list: 96 | ground_truths = i["ground_truth"] 97 | input_text = i["orig_input"] 98 | words = input_text.split(" ") 99 | if len(ground_truths) == 1 and len(words) <= 1100: 100 | del i["id"] 101 | del i["few_shot_prompt"] 102 | del i["prediction"] 103 | del i["orig_output"] 104 | i["output"] = i["ground_truth"][0] 105 | i["input"] = i["orig_input"] 106 | del i["ground_truth"] 107 | del i["orig_input"] 108 | filtered_list.append(i) 109 | list_2.clear() 110 | list_3.clear() 111 | list_4.clear() 112 | list_5.clear() 113 | list_6.clear() 114 | list_7.clear() 115 | list_8.clear() 116 | 117 | size = len(filtered_list) 118 | final_list.clear() 119 | return size 120 | 121 | fields = ["filename", "total"] 122 | rows = [] 123 | total_cnt = 0 124 | directory = '/home/utilizeai/Documents/NIPS-LLM/V2_super_natural_inference_mistral_7B' 125 | output_directory = '/home/utilizeai/Documents/NIPS-LLM/final' 126 | for filename in os.listdir(directory): 127 | print(filename) 128 | llm_json = os.path.join(directory, filename) 129 | output_json = os.path.join(output_directory, filename) 130 | cnt = compute_rouge_score(llm_json, output_json) 131 | print(cnt) 132 | total_cnt = total_cnt + cnt 133 | row = [] 134 | row.append(filename) 135 | row.append(cnt) 136 | 
rows.append(row) 137 | 138 | print(total_cnt) 139 | filename = "/home/utilizeai/Documents/NIPS-LLM/supernatural_instructions_filter_cnt.csv" 140 | 141 | # writing to csv file 142 | with open(filename, 'w') as csvfile: 143 | # creating a csv writer object 144 | csvwriter = csv.writer(csvfile) 145 | 146 | # writing the fields 147 | csvwriter.writerow(fields) 148 | 149 | # writing the data rows 150 | csvwriter.writerows(rows) 151 | 152 | output_json = os.path.join(output_directory, "natural_instructions_generation_tasks_dataset.json") 153 | with open(output_json, 'w') as f: 154 | json.dump(filtered_list, f, indent=4) 155 | -------------------------------------------------------------------------------- /training/axolotl/src/axolotl/utils/collators.py: -------------------------------------------------------------------------------- 1 | """ 2 | DataCollator for axolotl to pad labels and position_ids for packed sequences 3 | """ 4 | from dataclasses import dataclass 5 | from typing import Any, Optional, Union 6 | 7 | import numpy as np 8 | from transformers import PreTrainedTokenizerBase 9 | from transformers.utils import PaddingStrategy 10 | 11 | 12 | @dataclass 13 | class DataCollatorForSeq2Seq: 14 | """ 15 | Data collator that will dynamically pad the inputs received, as well as the labels and position_ids 16 | 17 | Args: 18 | tokenizer ([`PreTrainedTokenizer`] or [`PreTrainedTokenizerFast`]): 19 | The tokenizer used for encoding the data. 20 | model ([`PreTrainedModel`]): 21 | The model that is being trained. If set and has the *prepare_decoder_input_ids_from_labels*, use it to 22 | prepare the *decoder_input_ids* 23 | 24 | This is useful when using *label_smoothing* to avoid calculating loss twice. 25 | padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `True`): 26 | Select a strategy to pad the returned sequences (according to the model's padding side and padding index) 27 | among: 28 | 29 | - `True` or `'longest'` (default): Pad to the longest sequence in the batch (or no padding if only a single 30 | sequence is provided). 31 | - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum 32 | acceptable input length for the model if that argument is not provided. 33 | - `False` or `'do_not_pad'`: No padding (i.e., can output a batch with sequences of different lengths). 34 | max_length (`int`, *optional*): 35 | Maximum length of the returned list and optionally padding length (see above). 36 | pad_to_multiple_of (`int`, *optional*): 37 | If set will pad the sequence to a multiple of the provided value. 38 | 39 | This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >= 40 | 7.5 (Volta). 41 | label_pad_token_id (`int`, *optional*, defaults to -100): 42 | The id to use when padding the labels (-100 will be automatically ignored by PyTorch loss functions). 43 | return_tensors (`str`): 44 | The type of Tensor to return. Allowable values are "np", "pt" and "tf". 
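
        Example (illustrative sketch; assumes `tokenizer` has a pad token set):

            collator = DataCollatorForSeq2Seq(tokenizer, pad_to_multiple_of=8)
            batch = collator([
                {"input_ids": [1, 2, 3], "attention_mask": [1, 1, 1], "labels": [1, 2, 3], "position_ids": [0, 1, 2]},
                {"input_ids": [1, 2], "attention_mask": [1, 1], "labels": [1, 2], "position_ids": [0, 1]},
            ])
            # labels are padded with -100 and position_ids with 0 before `tokenizer.pad` is called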
45 | """ 46 | 47 | tokenizer: PreTrainedTokenizerBase 48 | model: Optional[Any] = None 49 | padding: Union[bool, str, PaddingStrategy] = True 50 | max_length: Optional[int] = None 51 | pad_to_multiple_of: Optional[int] = None 52 | label_pad_token_id: int = -100 53 | position_pad_token_id: int = 0 54 | return_tensors: str = "pt" 55 | 56 | def __call__(self, features, return_tensors=None): 57 | labels = None 58 | if return_tensors is None: 59 | return_tensors = self.return_tensors 60 | 61 | for feature_name, pad_token_id in [ 62 | ("labels", self.label_pad_token_id), 63 | ("position_ids", self.position_pad_token_id), 64 | ]: 65 | feat = ( 66 | [feature[feature_name] for feature in features] 67 | if feature_name in features[0].keys() 68 | else None 69 | ) 70 | labels = feat if feat and feature_name == "labels" else labels 71 | # We have to pad the labels before calling `tokenizer.pad` as this method won't pad them and needs them of the 72 | # same length to return tensors. 73 | if feat is not None: 74 | max_feature_length = max(len(l) for l in feat) # noqa: E741 75 | if self.pad_to_multiple_of is not None: 76 | max_feature_length = ( 77 | (max_feature_length + self.pad_to_multiple_of - 1) 78 | // self.pad_to_multiple_of 79 | * self.pad_to_multiple_of 80 | ) 81 | 82 | padding_side = self.tokenizer.padding_side 83 | for feature in features: 84 | remainder = [pad_token_id] * ( 85 | max_feature_length - len(feature[feature_name]) 86 | ) 87 | if isinstance(feature[feature_name], list): 88 | feature[feature_name] = ( 89 | feature[feature_name] + remainder 90 | if padding_side == "right" 91 | else remainder + feature[feature_name] 92 | ) 93 | elif padding_side == "right": 94 | feature[feature_name] = np.concatenate( 95 | [feature[feature_name], remainder] 96 | ).astype(np.int64) 97 | else: 98 | feature[feature_name] = np.concatenate( 99 | [remainder, feature[feature_name]] 100 | ).astype(np.int64) 101 | 102 | features = self.tokenizer.pad( 103 | features, 104 | padding=self.padding, 105 | max_length=self.max_length, 106 | pad_to_multiple_of=self.pad_to_multiple_of, 107 | return_tensors=return_tensors, 108 | ) 109 | 110 | # prepare decoder_input_ids 111 | if ( 112 | labels is not None 113 | and self.model is not None 114 | and hasattr(self.model, "prepare_decoder_input_ids_from_labels") 115 | ): 116 | decoder_input_ids = self.model.prepare_decoder_input_ids_from_labels( 117 | labels=features["labels"] 118 | ) 119 | features["decoder_input_ids"] = decoder_input_ids 120 | 121 | return features 122 | -------------------------------------------------------------------------------- /training/axolotl/src/axolotl/prompt_strategies/alpaca_w_system.py: -------------------------------------------------------------------------------- 1 | """ 2 | Prompt strategies loader for alpaca instruction datasets with system prompts 3 | """ 4 | from typing import Generator, Tuple, Union 5 | 6 | from axolotl.prompt_tokenizers import PromptTokenizingStrategy 7 | from axolotl.prompters import AlpacaPrompter, PromptStyle 8 | 9 | 10 | class InstructionWSystemPromptTokenizingStrategy(PromptTokenizingStrategy): 11 | """ 12 | Tokenizing strategy for instruction-based prompts. 
13 | """ 14 | 15 | def parse_instruction_fields(self, prompt) -> Tuple[str, str, str, str]: 16 | return ( 17 | prompt["instruction"], 18 | prompt["input"] if "input" in prompt else "", 19 | prompt["output"], 20 | prompt["system"], 21 | ) 22 | 23 | def tokenize_prompt(self, prompt): 24 | # pylint: disable=duplicate-code 25 | ( 26 | instruction, 27 | input, # pylint: disable=redefined-builtin 28 | response, 29 | system, 30 | ) = self.parse_instruction_fields(prompt) 31 | user_prompt = next( 32 | iter( 33 | self.prompter.build_prompt_w_system( 34 | system, 35 | instruction, 36 | input, 37 | ) 38 | ) 39 | ) 40 | tokenized_prompt = self._tokenize(user_prompt, add_eos_token=False) 41 | if not self.train_on_inputs: 42 | user_prompt_len = len(tokenized_prompt["input_ids"]) 43 | # TODO this could be sped up using numpy array slicing 44 | tokenized_prompt["labels"] = [-100] * user_prompt_len 45 | tokenized_res_prompt = self._tokenize( 46 | response, strip_bos_token=True, add_eos_token=True 47 | ) 48 | tokenized_prompt["input_ids"] += tokenized_res_prompt["input_ids"] 49 | tokenized_prompt["attention_mask"] += tokenized_res_prompt["attention_mask"] 50 | tokenized_prompt["labels"] += tokenized_res_prompt["input_ids"] 51 | 52 | return tokenized_prompt 53 | 54 | 55 | class SystemDataPrompter(AlpacaPrompter): 56 | """ 57 | Alpaca Style Prompter that uses system prompts from the dataset 58 | """ 59 | 60 | system_format: str = "### System:\n{system}\n\n" 61 | 62 | def build_prompt_w_system( 63 | self, 64 | system: str, 65 | instruction: str, 66 | input: Union[None, str] = None, # pylint: disable=redefined-builtin 67 | output: Union[None, str] = None, 68 | ) -> Generator[str, None, None]: 69 | # returns the full prompt from instruction and optional input 70 | # if a label (=response, =output) is provided, it's also appended. 
71 | formatted_sys_prompt = ( 72 | self.system_format.format(system=system) 73 | if system and self.system_format 74 | else "" 75 | ) 76 | if input: 77 | res = formatted_sys_prompt + self.turn_format.format( 78 | instruction=instruction, input=input 79 | ) 80 | else: 81 | res = formatted_sys_prompt + self.turn_no_input_format.format( 82 | instruction=instruction 83 | ) 84 | if output: 85 | res = f"{res}{output}" 86 | yield res 87 | 88 | 89 | class OpenOrcaSystemDataPrompter(SystemDataPrompter): 90 | """ 91 | Alpaca Style Prompter that uses system prompts from the dataset, with OpenOrca prompts 92 | """ 93 | 94 | def match_prompt_style(self): 95 | # pylint: disable=duplicate-code 96 | if self.prompt_style == PromptStyle.INSTRUCT.value: 97 | self.turn_format = "### Human:\n{instruction}\n### Additional Context:\n{input}\n### Assistant:\n" 98 | self.turn_no_input_format = "### Human:\n{instruction}\n### Assistant:\n" 99 | self.system_format = "### System:\n{system}\n" 100 | if self.prompt_style == PromptStyle.CHAT.value: 101 | self.turn_format = "USER: {instruction}\n{input}\nASSISTANT:" 102 | self.turn_no_input_format = "USER: {instruction}\nASSISTANT:" 103 | self.system_format = "SYSTEM: {system}\n" 104 | if self.prompt_style == PromptStyle.CHATML.value: 105 | self.turn_format = "<|im_start|>user\n{instruction}\n{input}<|im_end|>\n<|im_start|>assistant\n" 106 | self.turn_no_input_format = ( 107 | "<|im_start|>user\n{instruction}<|im_end|>\n<|im_start|>assistant\n" 108 | ) 109 | self.system_format = "<|im_start|>system\n{system}<|im_end|>\n" 110 | 111 | 112 | class OpenOrcaPromptTokenizingStrategy(InstructionWSystemPromptTokenizingStrategy): 113 | """ 114 | Tokenizing strategy for OpenOrca datasets 115 | """ 116 | 117 | def parse_instruction_fields(self, prompt) -> Tuple[str, str, str, str]: 118 | return ( 119 | prompt["question"], 120 | "", 121 | prompt["response"], 122 | prompt["system_prompt"], 123 | ) 124 | 125 | 126 | def load(tokenizer, cfg): 127 | return load_chat(tokenizer, cfg) 128 | 129 | 130 | def load_instruct(tokenizer, cfg): 131 | return InstructionWSystemPromptTokenizingStrategy( 132 | SystemDataPrompter(PromptStyle.INSTRUCT.value), 133 | tokenizer, 134 | cfg.train_on_inputs, 135 | cfg.sequence_len, 136 | ) 137 | 138 | 139 | def load_chat(tokenizer, cfg): 140 | return InstructionWSystemPromptTokenizingStrategy( 141 | SystemDataPrompter(PromptStyle.CHAT.value), 142 | tokenizer, 143 | cfg.train_on_inputs, 144 | cfg.sequence_len, 145 | ) 146 | 147 | 148 | def load_open_orca(tokenizer, cfg): 149 | return OpenOrcaPromptTokenizingStrategy( 150 | OpenOrcaSystemDataPrompter(PromptStyle.INSTRUCT.value), 151 | tokenizer, 152 | cfg.train_on_inputs, 153 | cfg.sequence_len, 154 | ) 155 | 156 | 157 | def load_open_orca_chatml(tokenizer, cfg): 158 | return OpenOrcaPromptTokenizingStrategy( 159 | OpenOrcaSystemDataPrompter(PromptStyle.CHATML.value), 160 | tokenizer, 161 | cfg.train_on_inputs, 162 | cfg.sequence_len, 163 | ) 164 | -------------------------------------------------------------------------------- /training/axolotl/src/axolotl/monkeypatch/llama_attn_hijack_xformers.py: -------------------------------------------------------------------------------- 1 | """ 2 | Directly copied the code from https://raw.githubusercontent.com/oobabooga/text-generation-webui/main/modules/llama_attn_hijack.py and made some adjustments 3 | """ 4 | 5 | import logging 6 | import warnings 7 | from typing import Optional, Tuple 8 | 9 | import torch 10 | import torch.nn.functional as F 11 | import 
transformers.models.llama.modeling_llama 12 | from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv 13 | 14 | try: 15 | import xformers.ops 16 | except ImportError: 17 | logging.error("xformers not found! Please install it before trying to use it.") 18 | 19 | 20 | def hijack_llama_attention(): 21 | transformers.models.llama.modeling_llama.LlamaAttention.forward = xformers_forward 22 | 23 | 24 | def xformers_forward( 25 | self, 26 | hidden_states: torch.Tensor, 27 | attention_mask: Optional[torch.Tensor] = None, 28 | position_ids: Optional[torch.LongTensor] = None, 29 | past_key_value: Optional[Tuple[torch.Tensor]] = None, 30 | output_attentions: bool = False, 31 | use_cache: bool = False, 32 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: 33 | # pylint: disable=duplicate-code 34 | bsz, q_len, _ = hidden_states.size() 35 | 36 | if not hasattr(self, "pretraining_tp"): 37 | self.pretraining_tp = 1 38 | 39 | if self.pretraining_tp > 1: 40 | key_value_slicing = ( 41 | self.num_key_value_heads * self.head_dim 42 | ) // self.pretraining_tp 43 | query_slices = self.q_proj.weight.split( 44 | (self.num_heads * self.head_dim) // self.pretraining_tp, dim=0 45 | ) 46 | key_slices = self.k_proj.weight.split(key_value_slicing, dim=0) 47 | value_slices = self.v_proj.weight.split(key_value_slicing, dim=0) 48 | 49 | query_states = [ 50 | F.linear(hidden_states, query_slices[i]) for i in range(self.pretraining_tp) 51 | ] 52 | query_states = torch.cat(query_states, dim=-1) 53 | 54 | key_states = [ 55 | F.linear(hidden_states, key_slices[i]) for i in range(self.pretraining_tp) 56 | ] 57 | key_states = torch.cat(key_states, dim=-1) 58 | 59 | value_states = [ 60 | F.linear(hidden_states, value_slices[i]) for i in range(self.pretraining_tp) 61 | ] 62 | value_states = torch.cat(value_states, dim=-1) 63 | 64 | else: 65 | query_states = self.q_proj(hidden_states) 66 | key_states = self.k_proj(hidden_states) 67 | value_states = self.v_proj(hidden_states) 68 | 69 | query_states = query_states.view( 70 | bsz, q_len, self.num_heads, self.head_dim 71 | ).transpose(1, 2) 72 | key_states = key_states.view( 73 | bsz, q_len, self.num_key_value_heads, self.head_dim 74 | ).transpose(1, 2) 75 | value_states = value_states.view( 76 | bsz, q_len, self.num_key_value_heads, self.head_dim 77 | ).transpose(1, 2) 78 | # [bsz, q_len, nh, hd] 79 | # [bsz, nh, q_len, hd] 80 | 81 | kv_seq_len = key_states.shape[-2] 82 | if past_key_value is not None: 83 | kv_seq_len += past_key_value[0].shape[-2] 84 | 85 | cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) 86 | query_states, key_states = apply_rotary_pos_emb( 87 | query_states, key_states, cos, sin, position_ids 88 | ) 89 | # [bsz, nh, t, hd] 90 | 91 | if past_key_value is not None: 92 | # reuse k, v, self_attention 93 | key_states = torch.cat([past_key_value[0], key_states], dim=2) 94 | value_states = torch.cat([past_key_value[1], value_states], dim=2) 95 | 96 | past_key_value = (key_states, value_states) if use_cache else None 97 | 98 | # repeat k/v heads if n_kv_heads < n_heads 99 | key_states = repeat_kv(key_states, self.num_key_value_groups) 100 | value_states = repeat_kv(value_states, self.num_key_value_groups) 101 | 102 | if output_attentions: 103 | warnings.warn( 104 | "Output attentions is not supported for patched `LlamaAttention`, returning `None` instead." 
105 | ) 106 | 107 | # 108 | # xformers-attn start 109 | # 110 | 111 | query_states = query_states.transpose(1, 2) 112 | key_states = key_states.transpose(1, 2) 113 | value_states = value_states.transpose(1, 2) 114 | 115 | # This is a nasty hack. We know attention_mask in transformers is either LowerTriangular or all Zeros. 116 | # We therefore check if one element in the upper triangular portion is zero. If it is, then the mask is all zeros. 117 | if attention_mask is None or attention_mask[0, 0, 0, 1] == 0: 118 | # input and output should be of form (bsz, q_len, num_heads, head_dim) 119 | attn_output = xformers.ops.memory_efficient_attention( 120 | query_states, key_states, value_states, attn_bias=None 121 | ) 122 | else: 123 | # input and output should be of form (bsz, q_len, num_heads, head_dim) 124 | attn_output = xformers.ops.memory_efficient_attention( 125 | query_states, 126 | key_states, 127 | value_states, 128 | # attn_bias=attention_mask, 129 | attn_bias=xformers.ops.LowerTriangularMask(), 130 | ) 131 | 132 | if attn_output.size() != (bsz, q_len, self.num_heads, self.head_dim): 133 | raise ValueError( 134 | f"`attn_output` should be of size {(bsz, q_len, self.num_heads, self.head_dim)}, but is" 135 | f" {attn_output.size()}" 136 | ) 137 | attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) 138 | 139 | # 140 | # xformers-attn end 141 | # 142 | 143 | if self.pretraining_tp > 1: 144 | attn_output = attn_output.split(self.hidden_size // self.pretraining_tp, dim=2) 145 | o_proj_slices = self.o_proj.weight.split( 146 | self.hidden_size // self.pretraining_tp, dim=1 147 | ) 148 | attn_output = sum( 149 | F.linear(attn_output[i], o_proj_slices[i]) 150 | for i in range(self.pretraining_tp) 151 | ) 152 | else: 153 | attn_output = self.o_proj(attn_output) 154 | 155 | return attn_output, None, past_key_value 156 | -------------------------------------------------------------------------------- /training/axolotl/src/axolotl/monkeypatch/fastchat_conversation_turns.py: -------------------------------------------------------------------------------- 1 | """ 2 | monkeypatch to add a get_turns method 3 | """ 4 | 5 | import logging 6 | from typing import Generator, Tuple 7 | 8 | from fastchat.conversation import SeparatorStyle 9 | 10 | LOG = logging.getLogger("axolotl.monkeypatch.fastchat_conversation_turns") 11 | 12 | 13 | def get_prompt(self) -> str: 14 | ret = "" 15 | for role, msg in self.get_turns(): 16 | ret += role + msg 17 | return ret 18 | 19 | 20 | def get_turns( # pylint: disable=too-many-return-statements 21 | self, 22 | ) -> Generator[Tuple[str, str], None, None]: 23 | """Get the prompt for generation.""" 24 | system_prompt = self.system_template.format(system_message=self.system_message) 25 | if self.sep_style == SeparatorStyle.ADD_COLON_SINGLE: 26 | yield "", system_prompt + self.sep 27 | for role, message in self.messages: 28 | if message: 29 | yield role + ": ", message + self.sep 30 | else: 31 | yield role + ":", "" 32 | return 33 | if self.sep_style == SeparatorStyle.ADD_COLON_TWO: 34 | seps = [self.sep, self.sep2] 35 | yield "", system_prompt + seps[0] 36 | for i, (role, message) in enumerate(self.messages): 37 | if message: 38 | yield role + ": ", message + seps[i % 2] 39 | else: 40 | yield role + ":", "" 41 | return 42 | if self.sep_style == SeparatorStyle.ADD_COLON_SPACE_SINGLE: 43 | yield "", system_prompt + self.sep 44 | for role, message in self.messages: 45 | if message: 46 | yield role + ": ", message + self.sep 47 | else: 48 | yield role + ": ", "" # must be 
end with a space 49 | return 50 | if self.sep_style == SeparatorStyle.ADD_NEW_LINE_SINGLE: 51 | yield "", "" if system_prompt == "" else system_prompt + self.sep 52 | for role, message in self.messages: 53 | if message: 54 | yield role + "\n", message + self.sep 55 | else: 56 | yield role + "\n", "" 57 | return 58 | if self.sep_style == SeparatorStyle.NO_COLON_SINGLE: 59 | yield "", system_prompt 60 | for role, message in self.messages: 61 | if message: 62 | yield role, message + self.sep 63 | else: 64 | yield role, "" 65 | return 66 | if self.sep_style == SeparatorStyle.NO_COLON_TWO: 67 | seps = [self.sep, self.sep2] 68 | yield "", system_prompt 69 | for i, (role, message) in enumerate(self.messages): 70 | if message: 71 | yield role, message + seps[i % 2] 72 | else: 73 | yield role, "" 74 | return 75 | if self.sep_style == SeparatorStyle.RWKV: 76 | yield "", system_prompt 77 | for i, (role, message) in enumerate(self.messages): 78 | if message: 79 | yield role + ": ", message.replace("\r\n", "\n").replace( 80 | "\n\n", "\n" 81 | ) + "\n\n" 82 | else: 83 | yield role + ":", "" 84 | return 85 | if self.sep_style == SeparatorStyle.LLAMA2: 86 | seps = [self.sep, self.sep2] 87 | if self.system_message: 88 | yield "", system_prompt 89 | else: 90 | yield "", "[INST] " 91 | for i, (role, message) in enumerate(self.messages[1:]): 92 | if message: 93 | yield role + " ", message + seps[i % 2] 94 | else: 95 | yield role, "" 96 | return 97 | if self.sep_style == SeparatorStyle.CHATGLM: 98 | # source: https://huggingface.co/THUDM/chatglm-6b/blob/1d240ba371910e9282298d4592532d7f0f3e9f3e/modeling_chatglm.py#L1302-L1308 99 | # source2: https://huggingface.co/THUDM/chatglm2-6b/blob/e186c891cf64310ac66ef10a87e6635fa6c2a579/modeling_chatglm.py#L926 100 | round_add_n = 1 if self.name == "chatglm2" else 0 101 | if system_prompt: 102 | yield "", system_prompt + self.sep 103 | 104 | for i, (role, message) in enumerate(self.messages): 105 | if i % 2 == 0: 106 | yield "", f"[Round {i//2 + round_add_n}]{self.sep}" 107 | 108 | if message: 109 | yield f"{role}:", f"{message}{self.sep}" 110 | else: 111 | yield f"{role}:", "" 112 | return 113 | if self.sep_style == SeparatorStyle.CHATML: 114 | yield "", "" if system_prompt == "" else system_prompt + self.sep + "\n" 115 | for role, message in self.messages: 116 | if message: 117 | yield role + "\n", message + self.sep + "\n" 118 | else: 119 | yield role + "\n", "" 120 | return 121 | if self.sep_style == SeparatorStyle.CHATINTERN: 122 | # source: https://huggingface.co/internlm/internlm-chat-7b-8k/blob/bd546fa984b4b0b86958f56bf37f94aa75ab8831/modeling_internlm.py#L771 123 | seps = [self.sep, self.sep2] 124 | yield "", system_prompt 125 | for i, (role, message) in enumerate(self.messages): 126 | prefix = "" if i % 2 == 0 else "" 127 | if message: 128 | yield prefix + role + ":", message + seps[i % 2] + "\n" 129 | else: 130 | yield role + ":", "" 131 | return 132 | if self.sep_style == SeparatorStyle.DOLLY: 133 | seps = [self.sep, self.sep2] 134 | yield "", system_prompt 135 | for i, (role, message) in enumerate(self.messages): 136 | if message: 137 | suffix = "\n\n" if i % 2 == 1 else "" 138 | yield role + ":\n", message + seps[i % 2] + suffix 139 | else: 140 | yield role + ":\n", "" 141 | return 142 | if self.sep_style == SeparatorStyle.PHOENIX: 143 | yield "", system_prompt 144 | for role, message in self.messages: 145 | if message: 146 | yield role + ": ", "" + message + "" 147 | else: 148 | yield role + ": " + "", "" 149 | return 150 | if self.sep_style == 
SeparatorStyle.ROBIN: 151 | yield "", system_prompt + self.sep 152 | for role, message in self.messages: 153 | if message: 154 | yield role + ":\n", message + self.sep 155 | else: 156 | yield role + ":\n", "" 157 | return 158 | if self.sep_style == SeparatorStyle.FALCON_CHAT: 159 | if self.system_message: 160 | yield "", system_prompt + self.sep 161 | for role, message in self.messages: 162 | if message: 163 | yield role + ": ", message + self.sep 164 | else: 165 | yield role + ":", "" 166 | else: 167 | raise ValueError(f"Invalid style: {self.sep_style}") 168 | 169 | 170 | def add_get_turns_to_conversation(): 171 | import fastchat.conversation 172 | 173 | fastchat.conversation.Conversation.get_turns = get_turns 174 | fastchat.conversation.Conversation.get_prompt = get_prompt 175 | -------------------------------------------------------------------------------- /training/axolotl/src/axolotl/prompt_strategies/creative_acr.py: -------------------------------------------------------------------------------- 1 | """Module loading the CreativePromptTokenizingStrategy and similar classes""" 2 | 3 | from typing import Generator, Tuple, Union 4 | 5 | import yaml 6 | 7 | from axolotl.prompt_tokenizers import InstructionPromptTokenizingStrategy 8 | 9 | 10 | class CreativeAnsweringPromptTokenizingStrategy(InstructionPromptTokenizingStrategy): 11 | """ 12 | Tokenizing strategy for Creative Answering 13 | """ 14 | 15 | def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]: 16 | question = prompt["instruction"] 17 | answer = prompt[ 18 | "revision" 19 | ] # don't use prompt[answer], that's data we don't want in the dataset 20 | return ( 21 | question, 22 | "", 23 | answer, 24 | ) 25 | 26 | 27 | class CreativeCritiquePromptTokenizingStrategy(InstructionPromptTokenizingStrategy): 28 | """ 29 | Tokenizing strategy for Creative Critique 30 | """ 31 | 32 | user_prompt = """Given the following Question and Response, critique the Response on a scale of 1-10. You should critique the answer in the following criteria: 33 | refusal: whether the Response refuses to answer the Question. Responses containing language similar to "As an AI language model, I am not capable of ..." is a form of refusal and should have a low refusal score. 10 is no refusal, 0 is refuses to answer the question. 34 | prescriptive bias: whether the Response attempts to prescribe or dictate certain actions or behaviors of the user. 10 no prescriptive bias present, 0 means there is prescriptive bias. 35 | creativity: score how creative the Response is. 10 is most creative, 0 lacks creativity 36 | comprehensiveness: score how comprehensive and in-depth the Response is. 10 is most comprehensive and in-depth, 0 lacks coverage and depth 37 | Then for each criteria, give a detailed explanation of why the response doesn't meet the criteria (don't reference the score), what it did right, and how it could be improved for that criteria. 38 | 39 | Output your answer in YAML format like so: 40 | scores: 41 | refusal: <0-10> 42 | prescriptive_bias: <0-10> 43 | creativity: <0-10> 44 | comprehensiveness: <0-10> 45 | critiques: 46 | refusal: 47 | explanation: ... 48 | improvements: ... 49 | prescriptive_bias: 50 | explanation: ... 51 | improvements: ... 52 | creativity: 53 | explanation: ... 54 | improvements: ... 55 | comprehensiveness: 56 | explanation: ... 57 | improvements: ... 
58 | 59 | Question: {question} 60 | Answer: {answer} 61 | """ 62 | 63 | def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]: 64 | scores = yaml.dump( 65 | prompt["scores"], 66 | default_flow_style=False, 67 | Dumper=yaml.Dumper, 68 | ) 69 | critiques = yaml.dump( 70 | prompt["critiques"], 71 | default_flow_style=False, 72 | Dumper=yaml.Dumper, 73 | ) 74 | evaluation = scores + critiques 75 | question = prompt["instruction"] 76 | answer = prompt["answer"] 77 | return ( 78 | self.user_prompt.format(question=question, answer=answer), 79 | "", 80 | evaluation, 81 | ) 82 | 83 | 84 | class CreativeRevisePromptTokenizingStrategy(InstructionPromptTokenizingStrategy): 85 | """ 86 | Tokenizing strategy for Creative Revise 87 | """ 88 | 89 | user_prompt = """Definitions: 90 | refusal: whether the Response refuses to answer the Question. Responses containing language similar to "As an AI language model, I am not capable of ..." is a form of refusal and should have a low refusal score. 10 is no refusal, 0 is refuses to answer the question. 91 | prescriptive bias: whether the Response attempts to prescribe or dictate certain actions or behaviors of the user. 10 no prescriptive bias present, 0 means their is prescriptive bias. 92 | creativity: score how creative the Response is. 10 is most creative, 0 lacks creativity 93 | comprehensiveness: score how comprehensive and in-depth the Response is. 10 is most comprehensive and in-depth, 0 lacks coverage and depth 94 | 95 | Given the following Question, Response, and Evaluation, revise the Response based on the Evaluation and recommendations for improvements. Reply only with the revised response. 96 | 97 | Question: {question} 98 | Answer: {answer} 99 | Evaluation: 100 | {evaluation} 101 | """ 102 | 103 | def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]: 104 | scores = yaml.dump( 105 | prompt["scores"], 106 | default_flow_style=False, 107 | Dumper=yaml.Dumper, 108 | ) 109 | critiques = yaml.dump( 110 | prompt["critiques"], 111 | default_flow_style=False, 112 | Dumper=yaml.Dumper, 113 | ) 114 | evaluation = scores + critiques 115 | question = prompt["instruction"] 116 | answer = prompt["answer"] 117 | return ( 118 | self.user_prompt.format( 119 | question=question, answer=answer, evaluation=evaluation 120 | ), 121 | "", 122 | prompt["revision"], 123 | ) 124 | 125 | 126 | class CreativePrompterBase: 127 | """ 128 | Base class for Creative Prompters 129 | """ 130 | 131 | system_prompt = "" 132 | prompt_input = "{system_prompt}\nUSER: {instruction}\nASSISTANT:" 133 | 134 | def build_prompt( 135 | self, 136 | instruction: str, 137 | input: Union[ # pylint: disable=redefined-builtin, unused-argument 138 | None, str 139 | ] = None, 140 | output: Union[None, str] = None, 141 | ) -> Generator[str, None, None]: 142 | if self.system_prompt: 143 | res = f"{self.system_prompt}\nUSER: {instruction}\nASSISTANT:" 144 | else: 145 | res = f"USER: {instruction}\nASSISTANT:" 146 | if output: 147 | res = f"{res}{output}" 148 | yield res 149 | 150 | 151 | class CreativeAnswerPrompter(CreativePrompterBase): 152 | """ 153 | Prompter for Creative Answering 154 | """ 155 | 156 | system_prompt = "Answer the following question in a comprehensive, in-depth, and creative way. Additionally your response should be relevant, accurate, and free of any ambiguity." 
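    # With CreativePrompterBase.build_prompt above, this prompter renders roughly as
    # "<system_prompt>\nUSER: <instruction>\nASSISTANT:" (plus the output when one is given).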
157 | 158 | 159 | class CreativeCritiquePrompter(CreativePrompterBase): 160 | """ 161 | Prompter for Creative Critique 162 | """ 163 | 164 | system_prompt = "" 165 | 166 | 167 | class CreativeRevisePrompter(CreativePrompterBase): 168 | """ 169 | Prompter for Creative Revise 170 | """ 171 | 172 | system_prompt = "" 173 | 174 | 175 | def load_answer(tokenizer, cfg): 176 | return CreativeAnsweringPromptTokenizingStrategy( 177 | CreativeAnswerPrompter(), 178 | tokenizer, 179 | cfg.train_on_inputs, 180 | cfg.sequence_len, 181 | ) 182 | 183 | 184 | def load_critique(tokenizer, cfg): 185 | return CreativeCritiquePromptTokenizingStrategy( 186 | CreativeCritiquePrompter(), 187 | tokenizer, 188 | cfg.train_on_inputs, 189 | cfg.sequence_len, 190 | ) 191 | 192 | 193 | def load_revise(tokenizer, cfg): 194 | return CreativeRevisePromptTokenizingStrategy( 195 | CreativeRevisePrompter(), 196 | tokenizer, 197 | cfg.train_on_inputs, 198 | cfg.sequence_len, 199 | ) 200 | -------------------------------------------------------------------------------- /training/axolotl/src/axolotl/train.py: -------------------------------------------------------------------------------- 1 | """Prepare and train a model on a dataset. Can also infer from a model or merge lora""" 2 | 3 | import logging 4 | import os 5 | import signal 6 | import sys 7 | from dataclasses import dataclass 8 | from pathlib import Path 9 | from typing import Optional 10 | 11 | import torch 12 | import transformers.modelcard 13 | from datasets import Dataset 14 | from optimum.bettertransformer import BetterTransformer 15 | from transformers.deepspeed import is_deepspeed_zero3_enabled 16 | 17 | from axolotl.common.cli import TrainerCliArgs 18 | from axolotl.logging_config import configure_logging 19 | from axolotl.utils.dict import DictDefault 20 | from axolotl.utils.models import load_model, load_tokenizer 21 | from axolotl.utils.trainer import setup_trainer 22 | 23 | project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) 24 | src_dir = os.path.join(project_root, "src") 25 | sys.path.insert(0, src_dir) 26 | 27 | configure_logging() 28 | LOG = logging.getLogger("axolotl.train") 29 | 30 | 31 | @dataclass 32 | class TrainDatasetMeta: 33 | """ 34 | dataclass to capture the dataset specific options for training 35 | """ 36 | 37 | train_dataset: Dataset 38 | eval_dataset: Optional[Dataset] = None 39 | total_num_steps: Optional[int] = None 40 | 41 | 42 | def train( 43 | *, cfg: DictDefault, cli_args: TrainerCliArgs, dataset_meta: TrainDatasetMeta 44 | ): 45 | # load the tokenizer first 46 | LOG.info(f"loading tokenizer... 
{cfg.tokenizer_config or cfg.base_model_config}") 47 | tokenizer = load_tokenizer(cfg) 48 | 49 | train_dataset = dataset_meta.train_dataset 50 | eval_dataset = dataset_meta.eval_dataset 51 | total_num_steps = dataset_meta.total_num_steps 52 | 53 | # print(f"Training instance: {train_dataset[0]}") 54 | print("=" * 35 + "Sample Training Instance" + "=" * 35) 55 | print(tokenizer.decode(train_dataset[0]["input_ids"])) 56 | print("=" * 80) 57 | # exit(0) 58 | 59 | # Load the model and tokenizer 60 | LOG.info("loading model and (optionally) peft_config...") 61 | model, peft_config = load_model(cfg, tokenizer, inference=cli_args.inference) 62 | 63 | safe_serialization = cfg.save_safetensors is True 64 | 65 | if cfg.resume_from_checkpoint is None and cfg.auto_resume_from_checkpoints: 66 | possible_checkpoints = [ 67 | str(cp) for cp in Path(cfg.output_dir).glob("checkpoint-*") 68 | ] 69 | if len(possible_checkpoints) > 0: 70 | sorted_paths = sorted( 71 | possible_checkpoints, 72 | key=lambda path: int(path.split("-")[-1]), 73 | ) 74 | cfg.resume_from_checkpoint = sorted_paths[-1] 75 | LOG.info( 76 | f"Using Auto-resume functionality to start with checkpoint at {cfg.resume_from_checkpoint}" 77 | ) 78 | resume_from_checkpoint = cfg.resume_from_checkpoint 79 | 80 | trainer = setup_trainer( 81 | cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps 82 | ) 83 | 84 | model.config.use_cache = False 85 | 86 | # go ahead and presave, so we have the adapter config available to inspect 87 | if peft_config: 88 | LOG.info(f"Pre-saving adapter config to {cfg.output_dir}") 89 | peft_config.save_pretrained(cfg.output_dir) 90 | # additionally presave the tokenizer and model configs 91 | if not Path(cfg.output_dir).is_dir(): 92 | os.makedirs(cfg.output_dir, exist_ok=True) 93 | tokenizer.save_pretrained(str(Path(cfg.output_dir))) 94 | model.config.save_pretrained(str(Path(cfg.output_dir))) 95 | 96 | # In case we want to stop early with ctrl+c, this is a nice to have to save the pretrained model 97 | if cfg.local_rank == 0: 98 | 99 | def terminate_handler(_, __, model): 100 | if cfg.flash_optimum: 101 | model = BetterTransformer.reverse(model) 102 | model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization) 103 | sys.exit(0) 104 | 105 | signal.signal( 106 | signal.SIGINT, lambda signum, frame: terminate_handler(signum, frame, model) 107 | ) 108 | 109 | badge_markdown = """[Built with Axolotl](https://github.com/OpenAccess-AI-Collective/axolotl)""" 110 | transformers.modelcard.AUTOGENERATED_TRAINER_COMMENT += f"\n{badge_markdown}" 111 | 112 | LOG.info("Starting trainer...") 113 | if cfg.group_by_length: 114 | LOG.info("hang tight... sorting dataset for group_by_length") 115 | 116 | if cfg.flash_optimum: 117 | with torch.backends.cuda.sdp_kernel( 118 | enable_flash=True, enable_math=True, enable_mem_efficient=True 119 | ): 120 | trainer.train(resume_from_checkpoint=resume_from_checkpoint) 121 | else: 122 | trainer.train(resume_from_checkpoint=resume_from_checkpoint) 123 | 124 | LOG.info(f"Training Completed!!! 
Saving pre-trained model to {cfg.output_dir}") 125 | 126 | # post training 127 | for name, module in model.named_modules(): 128 | if hasattr(module, "_post_training"): 129 | module._post_training(model, name) # pylint: disable=protected-access 130 | 131 | if trainer.is_fsdp_enabled: 132 | trainer.accelerator.state.fsdp_plugin.set_state_dict_type("FULL_STATE_DICT") 133 | LOG.info("Set FSDP state dict type to FULL_STATE_DICT for saving.") 134 | 135 | if cfg.relora_steps: 136 | if cfg.adapter == "lora" and not (cfg.load_in_4bit or cfg.load_in_8bit): 137 | model = model.merge_and_unload() 138 | else: 139 | # final model weights have already been saved by `ReLoRACallback.on_train_end` 140 | return model, tokenizer 141 | 142 | # TODO do we need this fix? https://huggingface.co/docs/accelerate/usage_guides/fsdp#saving-and-loading 143 | # only save on rank 0, otherwise it corrupts output on multi-GPU when multiple processes attempt to write the same file 144 | if cfg.fsdp: 145 | trainer.save_model(cfg.output_dir) 146 | elif cfg.deepspeed and is_deepspeed_zero3_enabled(): 147 | # Copied over from: https://github.com/huggingface/accelerate/blob/5ae611118057232f441055f7ef9ba0b0f2b8d533/docs/source/usage_guides/deepspeed.md#saving-and-loading 148 | trainer.accelerator.wait_for_everyone() 149 | unwrapped_model = trainer.accelerator.unwrap_model(trainer.model_wrapped) 150 | 151 | # Saves the whole/unpartitioned fp16 model when in ZeRO Stage-3 to the output directory if 152 | # `stage3_gather_16bit_weights_on_model_save` is True in DeepSpeed Config file or 153 | # `zero3_save_16bit_model` is True in DeepSpeed Plugin. 154 | # For Zero Stages 1 and 2, models are saved as usual in the output directory. 155 | # The model name saved is `pytorch_model.bin` 156 | unwrapped_model.save_pretrained( 157 | cfg.output_dir, 158 | is_main_process=trainer.accelerator.is_main_process, 159 | save_function=trainer.accelerator.save, 160 | state_dict=trainer.accelerator.get_state_dict(trainer.model_wrapped), 161 | ) 162 | elif cfg.local_rank == 0: 163 | if cfg.flash_optimum: 164 | model = BetterTransformer.reverse(model) 165 | 166 | model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization) 167 | 168 | if not cfg.hub_model_id: 169 | trainer.create_model_card(model_name=cfg.output_dir.lstrip("./")) 170 | 171 | return model, tokenizer 172 | -------------------------------------------------------------------------------- /training/axolotl/src/axolotl/datasets.py: -------------------------------------------------------------------------------- 1 | """Module containing Dataset functionality""" 2 | 3 | import logging 4 | import os 5 | from typing import List 6 | 7 | import torch 8 | from datasets import Dataset, IterableDataset 9 | 10 | from .prompt_tokenizers import PromptTokenizingStrategy 11 | 12 | # We want this to be a wrapper for an existing dataset that we have loaded 13 | # lets use the concept of middlewares to wrap each dataset, for example 14 | # ConstantLengthDataset(ShuffledDataset([TokenizedPromptDataset(alpaca_dataset)])) 15 | # let's check to ensure we don't truncate an item in the middle, we'll use 16 | # the collators later on to pad the datasets 17 | 18 | LOG = logging.getLogger("axolotl") 19 | 20 | 21 | class TokenizedPromptDataset(Dataset): 22 | """ 23 | Dataset that returns tokenized prompts from a stream of text files. 24 | Args: 25 | prompt_tokenizer (PromptTokenizingStrategy): The prompt tokenizing method for processing the data. 26 | dataset (dataset.Dataset): Dataset with text files. 
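
    Example (illustrative; `raw_dataset` is any loaded text dataset):
        tokenized_ds = TokenizedPromptDataset(prompt_tokenizer, raw_dataset)
        # maps prompt_tokenizer.tokenize_prompt over the dataset and drops the original text columns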
27 | """ 28 | 29 | def __init__( # pylint: disable=super-init-not-called 30 | self, 31 | prompt_tokenizer: PromptTokenizingStrategy, 32 | dataset: IterableDataset, 33 | **kwargs, 34 | ): 35 | self.prompt_tokenizer = prompt_tokenizer 36 | super().__init__(self.process(dataset).data, **kwargs) 37 | 38 | def process(self, dataset): 39 | features = dataset.features.keys() 40 | num_proc = min(64, os.cpu_count()) 41 | map_kwargs = {} 42 | if self.prompt_tokenizer.supports_batched: 43 | map_kwargs["batched"] = True 44 | map_kwargs["batch_size"] = 100 45 | return dataset.map( 46 | self.prompt_tokenizer.tokenize_prompt, 47 | num_proc=num_proc, 48 | remove_columns=features, 49 | **map_kwargs, 50 | ) 51 | 52 | 53 | # TODO this isn't the best since it can't interleave datasets 54 | class ConstantLengthDataset(IterableDataset): 55 | """ 56 | Iterable dataset that returns constant length chunks of tokens from stream of text files. 57 | Args: 58 | tokenizer (Tokenizer): The processor used for processing the data. 59 | dataset (dataset.Dataset): Dataset with text files. 60 | seq_length (int): Length of token sequences to return. 61 | """ 62 | 63 | def __init__( # pylint: disable=super-init-not-called 64 | self, 65 | tokenizer, 66 | datasets, 67 | seq_length=2048, 68 | ): 69 | self.tokenizer = tokenizer 70 | self.concat_token_id = tokenizer.eos_token_id 71 | self.datasets: List[IterableDataset] = datasets 72 | self.seq_length = seq_length 73 | 74 | vocab_size = len(tokenizer.get_vocab()) 75 | 76 | if vocab_size <= torch.iinfo(torch.int16).max: 77 | self.tokens_dtype = torch.int16 78 | elif vocab_size <= torch.iinfo(torch.int32).max: 79 | self.tokens_dtype = torch.int32 80 | else: 81 | self.tokens_dtype = torch.int64 82 | 83 | def __iter__(self): 84 | buffer = { 85 | "input_ids": [], 86 | "attention_mask": [], 87 | "labels": [], 88 | "position_ids": [], 89 | } 90 | buffer_len = 0 91 | for dataset in self.datasets: 92 | idx = 0 93 | iterator = iter(dataset) 94 | more_examples = True 95 | while more_examples: 96 | try: 97 | example = next(iterator) 98 | idx += 1 99 | except StopIteration: 100 | more_examples = False 101 | example = None 102 | 103 | add_concat_token = False 104 | if example: 105 | example_len = len(example["input_ids"]) 106 | add_concat_token = example["input_ids"][-1] != self.concat_token_id 107 | else: 108 | example_len = 0 109 | 110 | if not example_len or ( 111 | buffer_len + int(add_concat_token) + example_len > self.seq_length 112 | ): 113 | if buffer["input_ids"]: 114 | input_ids = torch.cat(buffer["input_ids"], dim=-1)[ 115 | : self.seq_length 116 | ] 117 | attention_mask = torch.cat(buffer["attention_mask"], dim=-1)[ 118 | : self.seq_length 119 | ] 120 | position_ids = torch.cat(buffer["position_ids"], dim=-1)[ 121 | : self.seq_length 122 | ] 123 | labels = torch.cat(buffer["labels"], dim=-1)[: self.seq_length] 124 | if labels.size() == input_ids.size() and ( 125 | attention_mask.size() == input_ids.size() 126 | ): 127 | yield { 128 | "input_ids": input_ids, 129 | "labels": labels, 130 | "attention_mask": attention_mask, 131 | "position_ids": position_ids, 132 | } 133 | else: 134 | LOG.warning( 135 | f"dropping batch due to tensor size mismatch input_ids: {input_ids.size()}, labels: {labels.size()}, attention_mask: {attention_mask.size()}" 136 | ) 137 | buffer = { 138 | "input_ids": [], 139 | "attention_mask": [], 140 | "labels": [], 141 | "position_ids": [], 142 | } 143 | buffer_len = 0 144 | idx = 1 145 | 146 | if example: 147 | # FIXME 148 | # just going to drop data points that 
are too long 149 | if len(example["input_ids"]) <= self.seq_length: 150 | input_ids = example["input_ids"] 151 | attention_mask = example["attention_mask"] 152 | labels = example["labels"] 153 | 154 | if add_concat_token: 155 | input_ids.append(self.concat_token_id) 156 | attention_mask.append(1) 157 | labels.append(self.concat_token_id) 158 | 159 | input_ids_with_concat = torch.tensor( 160 | input_ids, dtype=self.tokens_dtype 161 | ) 162 | attention_mask_with_concat = torch.tensor( 163 | [idx * m for m in attention_mask], dtype=torch.int16 164 | ) 165 | labels_with_concat = torch.tensor( 166 | labels, dtype=self.tokens_dtype 167 | ) 168 | position_ids = torch.arange( 169 | len(input_ids), dtype=self.tokens_dtype 170 | ) 171 | 172 | buffer["input_ids"].append(input_ids_with_concat) 173 | buffer["attention_mask"].append(attention_mask_with_concat) 174 | buffer["labels"].append(labels_with_concat) 175 | buffer["position_ids"].append(position_ids) 176 | buffer_len += len(input_ids) 177 | -------------------------------------------------------------------------------- /training/axolotl/src/axolotl/utils/distributed.py: -------------------------------------------------------------------------------- 1 | """ 2 | utility helpers for distributed checks 3 | """ 4 | import os 5 | import pickle # nosec 6 | from contextlib import contextmanager 7 | 8 | import torch 9 | import torch.distributed as dist 10 | from accelerate import Accelerator 11 | 12 | accelerate = None # pylint: disable=invalid-name 13 | 14 | 15 | def load_accelerate(): 16 | global accelerate # pylint: disable=global-statement 17 | accelerate = Accelerator() 18 | 19 | 20 | def is_distributed(): 21 | """ 22 | Check if distributed training is initialized. 23 | """ 24 | global accelerate # pylint: disable=global-statement 25 | if not accelerate: 26 | accelerate = Accelerator() 27 | return dist.is_available() and dist.is_initialized() 28 | 29 | 30 | def barrier(): 31 | """ 32 | Acts as a barrier to wait for all processes. This ensures that all processes 33 | reach the barrier before proceeding further. 34 | """ 35 | if is_distributed(): 36 | dist.barrier() 37 | 38 | 39 | def is_main_process(): 40 | """ 41 | Check if the current process is the main process. 42 | If not in distributed mode, always return True. 43 | """ 44 | if not is_distributed(): 45 | return True 46 | return dist.get_rank() == 0 47 | 48 | 49 | def get_world_size(): 50 | return int(os.getenv("WORLD_SIZE", "1")) 51 | 52 | 53 | @contextmanager 54 | def zero_first(is_main): 55 | """ 56 | runs the wrapped context so that rank 0 runs first before other ranks 57 | """ 58 | if not is_main: # other ranks wait first 59 | barrier() 60 | yield 61 | if is_main: # then rank 0 waits after it has run the context 62 | barrier() 63 | 64 | 65 | def gather_scalar_from_all_ranks(fn, world_size=1): # pylint: disable=invalid-name 66 | """ 67 | Run a callable 'fn' on all ranks and gather the results on the specified rank. 68 | 69 | Args: 70 | - fn (callable): A function that computes the value. This should not have any side effects. 71 | - rank (int, optional): The rank that gathers the values. Default is 0. 72 | - world_size (int, optional): Total number of processes in the current distributed setup. 73 | 74 | Returns: 75 | - A list of computed values from all ranks if on the gathering rank, otherwise None. 
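
    Example (illustrative; `some_local_metric` is a stand-in scalar):
        values = gather_scalar_from_all_ranks(lambda: some_local_metric, world_size=get_world_size())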
76 | """ 77 | value_scalar = fn() 78 | if not is_distributed(): 79 | return [value_scalar] 80 | value_tensor = torch.tensor( 81 | value_scalar, device=torch.cuda.current_device() 82 | ).float() 83 | 84 | if not is_main_process(): 85 | dist.gather(value_tensor, dst=0) 86 | else: 87 | gathered_tensors = [torch.zeros_like(value_tensor) for _ in range(world_size)] 88 | dist.gather(value_tensor, gather_list=gathered_tensors, dst=0) 89 | 90 | # Convert tensors back to their original type (int or float) 91 | gathered_values = [] 92 | for tensor in gathered_tensors: 93 | if tensor == tensor.int(): 94 | gathered_values.append(int(tensor.item())) 95 | else: 96 | gathered_values.append(float(tensor.item())) 97 | return gathered_values 98 | return None 99 | 100 | 101 | def broadcast_dict(vals: dict): 102 | if not is_distributed(): 103 | return vals 104 | 105 | if is_main_process(): 106 | data_byte = pickle.dumps(vals) 107 | data_tensor = torch.ByteTensor(list(data_byte)).to("cuda") 108 | data_size = torch.IntTensor([len(data_byte)]).to("cuda") 109 | else: 110 | data_tensor = torch.empty([1024], dtype=torch.uint8, device="cuda") 111 | data_size = torch.IntTensor([0]).to("cuda") 112 | 113 | dist.broadcast(data_size, 0) 114 | if not is_main_process(): 115 | # resize 116 | data_tensor = data_tensor.new_empty([data_size.item()]) 117 | 118 | dist.broadcast(data_tensor, 0) 119 | 120 | if not is_main_process(): 121 | data_list = data_tensor.cpu().tolist() 122 | data_byte = bytes(data_list[: data_size.item()]) 123 | vals = pickle.loads(data_byte) # nosec 124 | 125 | return vals 126 | 127 | 128 | def compute_and_broadcast(fn): # pylint: disable=invalid-name 129 | """ 130 | Compute a value using the function 'fn' only on the specified rank (default is 0). 131 | The value is then broadcasted to all other ranks. 132 | 133 | Args: 134 | - fn (callable): A function that computes the value. This should not have any side effects. 135 | - rank (int, optional): The rank that computes the value. Default is 0. 136 | 137 | Returns: 138 | - The computed value (int or float). 139 | """ 140 | if is_main_process(): 141 | value_scalar = fn() 142 | value_tensor = torch.tensor( 143 | value_scalar, device=torch.cuda.current_device() 144 | ).float() 145 | else: 146 | value_tensor = torch.tensor( 147 | 0.0, device=torch.cuda.current_device() 148 | ) # Placeholder tensor 149 | 150 | # Broadcast the tensor to all processes. 151 | barrier() 152 | dist.broadcast(value_tensor, src=0) 153 | 154 | # Convert the tensor back to its original type (int or float) 155 | if value_tensor == value_tensor.int(): 156 | return int(value_tensor.item()) 157 | return float(value_tensor.item()) 158 | 159 | 160 | def gather_from_all_ranks(fn, world_size=1): # pylint: disable=invalid-name 161 | """ 162 | Run a callable 'fn' on all ranks and gather the results on the specified rank. 163 | 164 | Args: 165 | - fn (callable): A function that computes the value. This should not have any side effects. 166 | - rank (int, optional): The rank that gathers the values. Default is 0. 167 | - world_size (int, optional): Total number of processes in the current distributed setup. 168 | 169 | Returns: 170 | - A list of computed values from all ranks if on the gathering rank, otherwise None. 
171 | """ 172 | value_scalar = fn() 173 | value_tensor = torch.tensor( 174 | value_scalar, device=torch.cuda.current_device() 175 | ).float() 176 | 177 | # Placeholder tensor for gathering results 178 | if is_main_process(): 179 | gathered_tensors = [torch.zeros_like(value_tensor) for _ in range(world_size)] 180 | else: 181 | gathered_tensors = None 182 | 183 | dist.gather(value_tensor, gather_list=gathered_tensors, dst=0) 184 | 185 | if is_main_process(): 186 | # Convert tensors back to their original type (int or float) 187 | gathered_values = [] 188 | for tensor in gathered_tensors: 189 | if tensor == tensor.int(): 190 | gathered_values.append(int(tensor.item())) 191 | else: 192 | gathered_values.append(float(tensor.item())) 193 | return gathered_values 194 | return None 195 | 196 | 197 | def reduce_and_broadcast(fn1, fn2): 198 | """ 199 | Run a callable 'fn1' on all ranks, gather the results, reduce them using 'fn2', 200 | and then broadcast the reduced result to all ranks. 201 | 202 | Args: 203 | - fn1 (callable): A function that computes the value on each rank. 204 | - fn2 (callable): A reduction function that takes a list of values and returns a single value. 205 | - world_size (int, optional): Total number of processes in the current distributed setup. 206 | 207 | Returns: 208 | - The reduced and broadcasted value. 209 | """ 210 | 211 | # Gather values from all ranks using fn1 212 | if not is_distributed(): 213 | return fn2([fn1()]) 214 | 215 | gathered_values = gather_from_all_ranks(fn1, world_size=dist.get_world_size()) 216 | 217 | # Use compute_and_broadcast to compute the reduced value on the main process 218 | # and then broadcast it to all ranks 219 | return compute_and_broadcast(lambda: fn2(gathered_values)) 220 | -------------------------------------------------------------------------------- /data_prep/combine_datasets.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import random 4 | from datasets import load_dataset 5 | 6 | # This path contains all the datasets: https://drive.google.com/drive/folders/1bPfxrwcgGrX3-3CHJ0SLckY22zgdPwCj 7 | # Todo: Refer to this for description of each dataset: 8 | 9 | """ 10 | How to run the script: 11 | 1. Install necessary libraries. 12 | 2. Replace local dataset paths with datasets from above drive path 13 | 3. Replace output paths at end of the script to save dataset at your own local paths. 
14 | """ 15 | 16 | def fetch_instruction(task_file_name): 17 | with open(os.path.join("/home/minimalist/work/natural-instructions/tasks", task_file_name)) as f: 18 | task_json = json.load(f) 19 | return task_json["Definition"][0] 20 | 21 | def prepare_quac_file(): 22 | file_path = "/home/minimalist/Downloads/nips_training_data/quac_train.json" 23 | data = [] 24 | lines = open(file_path, "r").readlines() 25 | for line in lines: 26 | rec = json.loads(line.strip()) 27 | if len(f"{rec['instruction']} {rec['input']}".split(" ")) > 800: 28 | continue 29 | data.append(rec) 30 | random.shuffle(data) 31 | sampled_data = data 32 | print(f"quac dataset size: {len(sampled_data)}") 33 | return sampled_data 34 | 35 | 36 | def prepare_openQA_files(): 37 | data = [] 38 | 39 | file_paths = ["/home/minimalist/Downloads/nips_training_data/openQA_train.json" 40 | ] 41 | for file_path in file_paths: 42 | lines = open(file_path, "r").readlines() 43 | for line in lines: 44 | rec = json.loads(line.strip()) 45 | rec['instruction'] = "You are given an incomplete sentence along with a few reference completions. Only one of the reference completion is coherent and correct based on the context in sentence and all other reference completions are incorrect. Your task is to choose best matching completion from reference options and return that." 46 | if len(f"{rec['instruction']} {rec['input']}".split(" ")) > 800: 47 | continue 48 | data.append(rec) 49 | random.shuffle(data) 50 | sampled_data = data 51 | print(f"OpenQA dataset size: {len(sampled_data)}") 52 | return sampled_data 53 | 54 | 55 | def prepare_cnn_dailymail_file(): 56 | file_path = "/home/minimalist/Downloads/nips_training_data/cnn_dailymail_summarization_random_train.json" 57 | data = [] 58 | lines = open(file_path, "r").readlines() 59 | for line in lines: 60 | rec = json.loads(line.strip()) 61 | if len(f"{rec['instruction']} {rec['input']}".split(" ")) > 800: 62 | continue 63 | data.append(rec) 64 | random.shuffle(data) 65 | sampled_data = data 66 | print(f"CNN/DailyMail dataset size: {len(sampled_data)}") 67 | return sampled_data 68 | 69 | def prepare_maths_file(): 70 | file_path = "/home/minimalist/Downloads/nips_training_data/math_reasoning.json" 71 | 72 | with open(file_path) as f: 73 | data = json.load(f) 74 | 75 | sampled_data = [] 76 | for rec in data: 77 | if len(f"{rec['instruction']} {rec['input']}".split(" ")) < 800 and "print(" not in rec['output']: 78 | sampled_data.append(rec) 79 | 80 | random.shuffle(sampled_data) 81 | sampled_data = sampled_data[:250000] 82 | print(f"maths_reasoning dataset size: {len(sampled_data)}") 83 | return sampled_data 84 | 85 | 86 | def prepare_platypus_file(): 87 | file_path = "/home/minimalist/Downloads/nips_training_data/platypus_no_llm.json" 88 | 89 | with open(file_path) as f: 90 | data = json.load(f) 91 | 92 | sampled_data = [] 93 | for rec in data: 94 | if len(f"{rec['instruction']} {rec['input']}".split(" ")) < 800: 95 | sampled_data.append(rec) 96 | 97 | random.shuffle(sampled_data) 98 | sampled_data = sampled_data 99 | print(f"platypus dataset size: {len(sampled_data)}") 100 | return data 101 | 102 | 103 | def prepare_super_natural_generation_tasks_samples_file(): 104 | file_path = "/home/minimalist/Downloads/nips_training_data/natural_instructions_generation_tasks_dataset.json" 105 | task_to_instuction_map = {} 106 | 107 | with open(file_path) as f: 108 | data = json.load(f) 109 | print("Loaded file") 110 | 111 | sampled_data = [] 112 | for i, rec in enumerate(data): 113 | if rec["task_file"] not in 
task_to_instuction_map: 114 | rec['instruction'] = fetch_instruction(rec["task_file"]) 115 | task_to_instuction_map[rec["task_file"]] = rec['instruction'] 116 | else: 117 | rec['instruction'] = task_to_instuction_map[rec["task_file"]] 118 | rec.pop('task_file', None) 119 | if len(f"{rec['instruction']} {rec['input']}".split(" ")) < 800: 120 | sampled_data.append(rec) 121 | 122 | random.shuffle(sampled_data) 123 | sampled_data = sampled_data 124 | print(f"Super Natural Sentence dataset size: {len(sampled_data)}") 125 | return sampled_data 126 | 127 | 128 | def prepare_super_natural_exact_match_tasks_samples_file(): 129 | file_path = "/home/minimalist/Downloads/nips_training_data/natural_instructions_exact_match_tasks_dataset.json" 130 | 131 | with open(file_path) as f: 132 | data = json.load(f) 133 | 134 | random.shuffle(data) 135 | sampled_data = data 136 | print(f"Super Natural One Word dataset size: {len(sampled_data)}") 137 | return sampled_data 138 | 139 | 140 | 141 | def generate_lima_prompt(example): 142 | """Generates a standardized message to prompt the model with an instruction, optional input and a 143 | 'response' field.""" 144 | 145 | if example["input"]: 146 | return ( 147 | "Below is an instruction that describes a task, paired with an input that provides further context. " 148 | "Write a response that appropriately completes the request.\n\n" 149 | f"### Instruction:\n{example['instruction']}\n\n### Input:\n{example['input']}\n\n### Response:" 150 | ) 151 | return ( 152 | "Below is an instruction that describes a task. " 153 | "Write a response that appropriately completes the request.\n\n" 154 | f"### Instruction:\n{example['instruction']}\n\n### Response:" 155 | ) 156 | 157 | 158 | def prepare_lima_file(): 159 | dataset = load_dataset("GAIR/lima")["train"] 160 | data = [] 161 | for convo in dataset["conversations"]: 162 | rec = {"instruction": "", "input": convo[0], "output": convo[1]} 163 | if len(f"{rec['instruction']} {rec['input']}".split(" ")) < 800: 164 | data.append(rec) 165 | print(f"LIMA dataset size: {len(data)}") 166 | return data 167 | 168 | 169 | 170 | data = [] 171 | data += prepare_quac_file() 172 | data += prepare_openQA_files() 173 | data += prepare_cnn_dailymail_file() 174 | data += prepare_maths_file() 175 | data += prepare_platypus_file() 176 | data += prepare_super_natural_generation_tasks_samples_file() 177 | data += prepare_super_natural_exact_match_tasks_samples_file() 178 | data += prepare_lima_file() 179 | 180 | print(f"total data size: {len(data)}") 181 | 182 | random.shuffle(data) 183 | 184 | instruction_prefix = "### Instruction:\n" 185 | input_prefix = "### Input:\n" 186 | response_prefix = "### Response:\n" 187 | final_dataset = [] 188 | all_inputs = set() 189 | for each in data: 190 | if each['input'] not in all_inputs: 191 | all_inputs.add(each['input']) 192 | else: 193 | continue 194 | 195 | if len(each['input'].split(" ")) <= 4 or len(each['output'].split(" ")) < 1: 196 | continue 197 | 198 | rec = {} 199 | rec['question'] = '' 200 | if len(each['instruction'].strip()) > 0: 201 | rec['question'] += f"{instruction_prefix}{each['instruction']}\n\n" 202 | rec['question'] += f"{input_prefix}{each['input']}\n\n" 203 | rec['question'] += f"{response_prefix}" 204 | rec['answer'] = f"{each['output']}" 205 | final_dataset.append(rec) 206 | 207 | print(f"final_dataset size: {len(final_dataset)}") 208 | train_dataset = final_dataset[:int(0.98*len(final_dataset))] 209 | eval_dataset = final_dataset[int(0.98*len(final_dataset)):] 210 | 211 | with 
open(f"/home/minimalist/Downloads/nips_training_data/train_dataset.json", 'w') as f: 212 | json.dump(train_dataset, f, indent=1) 213 | 214 | with open(f"/home/minimalist/Downloads/nips_training_data/eval_dataset.json", 'w') as f: 215 | json.dump(eval_dataset, f, indent=1) 216 | -------------------------------------------------------------------------------- /training/axolotl/src/axolotl/prompt_strategies/llama2_chat.py: -------------------------------------------------------------------------------- 1 | """ 2 | Prompt Strategy for finetuning Llama2 chat models 3 | see also https://github.com/facebookresearch/llama/blob/6c7fe276574e78057f917549435a2554000a876d/llama/generation.py#L213 for ma reference implementation. 4 | 5 | This implementation is based on the Vicuna PR and the fastchat repo, see also: 6 | https://github.com/lm-sys/FastChat/blob/cdd7730686cb1bf9ae2b768ee171bdf7d1ff04f3/fastchat/conversation.py#L847 7 | 8 | Use dataset type: "llama2_chat" in conig.yml to use this prompt style. 9 | 10 | E.g. in the config.yml: 11 | ``` 12 | datasets: 13 | - path: llama_finetune_train.jsonl 14 | type: llama2_chat 15 | ``` 16 | 17 | The dataset itself should look like this: 18 | ``` 19 | {'conversations':[{"from": "human", "value": "Who are you?"}, {"from": "gpt", "value": "I am Vicuna"},...]} 20 | ``` 21 | in a jsonl file. The first message should be from the human, the second from gpt. 22 | For a custom system message, the first "from" can be "system" (followed by alternating "human" and "gpt" turns). 23 | 24 | Important: Don't use "special_tokens:" in your config.yml if you are not sure what you are doing! 25 | """ 26 | 27 | import logging 28 | from dataclasses import dataclass, field 29 | from typing import Generator, List, Sequence 30 | 31 | from axolotl.prompt_tokenizers import PromptTokenizingStrategy 32 | from axolotl.prompters import IGNORE_TOKEN_ID, SHAREGPT_ASSERTION_FAILED_ROLE 33 | 34 | 35 | @dataclass 36 | class Llama2ChatConversation: 37 | """A class that manages prompt templates and keeps all conversation history. 38 | copied from https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py""" 39 | 40 | name: str = "llama2" 41 | # The system prompt 42 | system: str = ( 43 | "[INST] <>\nYou are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. " 44 | "Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. " 45 | "Please ensure that your responses are socially unbiased and positive in nature.\n\n" 46 | "If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. 
" 47 | "If you don't know the answer to a question, please don't share false information.\n<>\n\n" 48 | ) 49 | roles: Sequence[str] = ("[INST]", "[/INST]") 50 | messages: List[List[str]] = field(default_factory=list) 51 | offset: int = 0 52 | sep = " " 53 | sep2 = " " 54 | stop_token_ids = [2] 55 | 56 | def get_prompt(self) -> str: 57 | """Get the prompt for generation.""" 58 | seps = [self.sep, self.sep2] 59 | ret = "" 60 | for i, (role, message) in enumerate(self.messages): 61 | if (i == len(self.messages) - 1) and (role == self.roles[0]): 62 | # last message is from user (due to length), 63 | # return prompt without it for training 64 | return ret 65 | if i == 0: 66 | ret += self.system + message.strip() 67 | else: 68 | ret += role + " " + message.strip() + seps[i % 2] 69 | return ret 70 | 71 | def append_message(self, role: str, message: str): 72 | """Append a new message.""" 73 | self.messages.append([role, message]) 74 | 75 | 76 | class LLama2ChatTokenizingStrategy(PromptTokenizingStrategy): 77 | """ 78 | Tokenizing strategy for ShareGPT prompts. 79 | adapted from https://github.com/lm-sys/FastChat/blob/main/fastchat/train/train.py 80 | """ 81 | 82 | def __init__(self, *args, **kwargs): 83 | super().__init__(*args, **kwargs) 84 | self.sequence_len = 4096 85 | self.tokenizer.add_special_tokens({"pad_token": ""}) 86 | # https://huggingface.co/meta-llama/Llama-2-7b-chat-hf/blob/main/added_tokens.json 87 | 88 | def tokenize_prompt(self, prompt): 89 | conv = next(self.prompter.build_prompt(prompt)) 90 | conversation_str = conv.get_prompt() 91 | 92 | # Tokenize conversations 93 | input_ids = self.tokenizer( 94 | conversation_str, 95 | return_tensors="pt", 96 | padding="max_length", 97 | max_length=self.sequence_len, 98 | truncation=True, 99 | ).input_ids[0] 100 | target = input_ids.clone() 101 | 102 | # Mask targets. Only compute loss on the assistant outputs. 103 | sep = conv.roles[1] 104 | 105 | total_len = int(target.ne(self.tokenizer.pad_token_id).sum()) 106 | 107 | turns = conversation_str.split(conv.sep2) 108 | cur_len = 1 109 | target[:cur_len] = IGNORE_TOKEN_ID 110 | for turn in turns: 111 | if turn == "": 112 | break 113 | turn_len = len(self.tokenizer(turn).input_ids) 114 | 115 | parts = turn.split(sep) 116 | if len(parts) != 2: 117 | break 118 | parts[0] += sep 119 | # "-1" is hardcoded for the LLaMA tokenizer to make the offset correct. 120 | instruction_len = len(self.tokenizer(parts[0]).input_ids) - 1 121 | 122 | # Ignore the user instructions 123 | target[cur_len - 1 : cur_len + instruction_len] = IGNORE_TOKEN_ID 124 | cur_len += turn_len + 2 # due to length of role token 125 | 126 | target[cur_len:] = IGNORE_TOKEN_ID 127 | 128 | if cur_len < self.sequence_len: 129 | if cur_len != total_len: 130 | target[:] = IGNORE_TOKEN_ID 131 | logging.warning( 132 | f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}." 
133 | f" (ignored)" 134 | ) 135 | 136 | attention_mask = input_ids.ne(self.tokenizer.pad_token_id).tolist() 137 | input_ids = input_ids.tolist() 138 | target = target.tolist() 139 | # this is a fix for the tokenizer which tokenizes [ differently with eos tokens and 140 | # follows the original llama implementation 141 | for i in range(2, total_len - 2): 142 | if input_ids[i] == 29961: 143 | input_ids[i] = 518 144 | if target[i] == 29961: 145 | target[i] = 518 146 | return { 147 | "input_ids": input_ids, 148 | "labels": target, 149 | "attention_mask": attention_mask, 150 | } 151 | 152 | 153 | class Llama2ChatPrompter: # pylint: disable=too-few-public-methods 154 | """ 155 | A prompter that generates prompts for Llama2 models. 156 | """ 157 | 158 | system_prompt = ( 159 | "[INST] <>\nYou are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. " 160 | "Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. " 161 | "Please ensure that your responses are socially unbiased and positive in nature.\n\n" 162 | "If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. " 163 | "If you don't know the answer to a question, please don't share false information.\n<>\n\n" 164 | ) 165 | 166 | def build_prompt(self, source) -> Generator[Llama2ChatConversation, None, None]: 167 | # see https://github.com/lm-sys/FastChat/blob/da0641e567cf93756b0978ab5a6b092e96f06240/fastchat/train/train.py#L78 168 | source = source["conversations"] # fix data structure for datasets 169 | 170 | # if system prompt provided, use it 171 | if source[0]["from"] == "system": 172 | system = f"[INST] <>\n{source[0]['value']}\n<>\n\n" 173 | source = source[1:] 174 | else: 175 | system = self.system_prompt 176 | 177 | conv = Llama2ChatConversation(system=system) 178 | 179 | if len(source) < 2: 180 | # If there isn't a back and forth conversation, ignore it 181 | # also happens on the data splitting leaving empty conversations 182 | raise IndexError 183 | 184 | roles = {"human": conv.roles[0], "gpt": conv.roles[1]} 185 | 186 | if roles[source[0]["from"]] != conv.roles[0]: 187 | # Skip the first one if it is not from human 188 | source = source[1:] 189 | 190 | conv.messages = [] # pylint: disable=R0801 191 | for j, sentence in enumerate(source): 192 | role = roles[sentence["from"]] 193 | assert role == conv.roles[j % 2], SHAREGPT_ASSERTION_FAILED_ROLE 194 | if sentence["value"]: 195 | conv.append_message(role, sentence["value"]) 196 | yield conv 197 | 198 | 199 | def load(tokenizer, cfg) -> LLama2ChatTokenizingStrategy: 200 | return LLama2ChatTokenizingStrategy( 201 | Llama2ChatPrompter(), 202 | tokenizer, 203 | cfg.train_on_inputs, 204 | cfg.sequence_len, 205 | ) 206 | -------------------------------------------------------------------------------- /training/axolotl/src/axolotl/cli/__init__.py: -------------------------------------------------------------------------------- 1 | """Prepare and train a model on a dataset. 
Can also infer from a model or merge lora""" 2 | 3 | import importlib 4 | import logging 5 | import os 6 | import random 7 | import sys 8 | from pathlib import Path 9 | from typing import Any, Dict, List, Optional, Union 10 | 11 | import torch 12 | import yaml 13 | 14 | # add src to the pythonpath so we don't need to pip install this 15 | from accelerate.commands.config import config_args 16 | from art import text2art 17 | from huggingface_hub import HfApi 18 | from huggingface_hub.utils import LocalTokenNotFoundError 19 | from transformers import GenerationConfig, TextStreamer 20 | 21 | from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer 22 | from axolotl.logging_config import configure_logging 23 | from axolotl.train import TrainDatasetMeta 24 | from axolotl.utils.config import normalize_config, validate_config 25 | from axolotl.utils.data import prepare_dataset 26 | from axolotl.utils.dict import DictDefault 27 | from axolotl.utils.distributed import is_main_process 28 | from axolotl.utils.models import load_tokenizer 29 | from axolotl.utils.tokenization import check_dataset_labels 30 | from axolotl.utils.wandb_ import setup_wandb_env_vars 31 | 32 | project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) 33 | src_dir = os.path.join(project_root, "src") 34 | sys.path.insert(0, src_dir) 35 | 36 | configure_logging() 37 | LOG = logging.getLogger("axolotl.scripts") 38 | 39 | os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1" 40 | 41 | 42 | def print_axolotl_text_art(suffix=None): 43 | font = "nancyj" 44 | ascii_text = " axolotl" 45 | if suffix: 46 | ascii_text += f" x {suffix}" 47 | ascii_art = text2art(" axolotl", font=font) 48 | 49 | if is_main_process(): 50 | print(ascii_art) 51 | 52 | 53 | def get_multi_line_input() -> Optional[str]: 54 | print("Give me an instruction (Ctrl + D to submit): ") 55 | instruction = "" 56 | for line in sys.stdin: 57 | instruction += line # pylint: disable=consider-using-join 58 | # instruction = pathlib.Path("/proc/self/fd/0").read_text() 59 | return instruction 60 | 61 | 62 | def do_merge_lora( 63 | *, 64 | cfg: DictDefault, 65 | cli_args: TrainerCliArgs, 66 | ): 67 | model, tokenizer = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args) 68 | safe_serialization = cfg.save_safetensors is True 69 | 70 | LOG.info("running merge of LoRA with base model") 71 | model = model.merge_and_unload() 72 | model.to(dtype=torch.float16) 73 | 74 | if cfg.local_rank == 0: 75 | LOG.info(f"saving merged model to: {str(Path(cfg.output_dir) / 'merged')}") 76 | model.save_pretrained( 77 | str(Path(cfg.output_dir) / "merged"), 78 | safe_serialization=safe_serialization, 79 | ) 80 | tokenizer.save_pretrained(str(Path(cfg.output_dir) / "merged")) 81 | 82 | 83 | def do_inference( 84 | *, 85 | cfg: DictDefault, 86 | cli_args: TrainerCliArgs, 87 | ): 88 | model, tokenizer = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args) 89 | prompter = cli_args.prompter 90 | default_tokens = {"unk_token": "", "bos_token": "", "eos_token": ""} 91 | 92 | for token, symbol in default_tokens.items(): 93 | # If the token isn't already specified in the config, add it 94 | if not (cfg.special_tokens and token in cfg.special_tokens): 95 | tokenizer.add_special_tokens({token: symbol}) 96 | 97 | prompter_module = None 98 | if prompter: 99 | prompter_module = getattr( 100 | importlib.import_module("axolotl.prompters"), prompter 101 | ) 102 | 103 | if cfg.landmark_attention: 104 | from axolotl.monkeypatch.llama_landmark_attn import set_model_mem_id 105 | 106 | 
set_model_mem_id(model, tokenizer) 107 | model.set_mem_cache_args( 108 | max_seq_len=255, mem_freq=50, top_k=5, max_cache_size=None 109 | ) 110 | 111 | model = model.to(cfg.device) 112 | 113 | while True: 114 | print("=" * 80) 115 | # support for multiline inputs 116 | instruction = get_multi_line_input() 117 | if not instruction: 118 | return 119 | if prompter_module: 120 | prompt: str = next( 121 | prompter_module().build_prompt(instruction=instruction.strip("\n")) 122 | ) 123 | else: 124 | prompt = instruction.strip() 125 | batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True) 126 | 127 | print("=" * 40) 128 | model.eval() 129 | with torch.no_grad(): 130 | generation_config = GenerationConfig( 131 | repetition_penalty=1.1, 132 | max_new_tokens=1024, 133 | temperature=0.9, 134 | top_p=0.95, 135 | top_k=40, 136 | bos_token_id=tokenizer.bos_token_id, 137 | eos_token_id=tokenizer.eos_token_id, 138 | pad_token_id=tokenizer.pad_token_id, 139 | do_sample=True, 140 | use_cache=True, 141 | return_dict_in_generate=True, 142 | output_attentions=False, 143 | output_hidden_states=False, 144 | output_scores=False, 145 | ) 146 | streamer = TextStreamer(tokenizer) 147 | generated = model.generate( 148 | inputs=batch["input_ids"].to(cfg.device), 149 | generation_config=generation_config, 150 | streamer=streamer, 151 | ) 152 | print("=" * 40) 153 | print(tokenizer.decode(generated["sequences"].cpu().tolist()[0])) 154 | 155 | 156 | def choose_config(path: Path): 157 | yaml_files = list(path.glob("*.yml")) 158 | 159 | if not yaml_files: 160 | raise ValueError( 161 | "No YAML config files found in the specified directory. Are you using a .yml extension?" 162 | ) 163 | 164 | if len(yaml_files) == 1: 165 | print(f"Using default YAML file '{yaml_files[0]}'") 166 | return yaml_files[0] 167 | 168 | print("Choose a YAML file:") 169 | for idx, file in enumerate(yaml_files): 170 | print(f"{idx + 1}. {file}") 171 | 172 | chosen_file = None 173 | while chosen_file is None: 174 | try: 175 | choice = int(input("Enter the number of your choice: ")) 176 | if 1 <= choice <= len(yaml_files): 177 | chosen_file = yaml_files[choice - 1] 178 | else: 179 | print("Invalid choice. Please choose a number from the list.") 180 | except ValueError: 181 | print("Invalid input. 
Please enter a number.") 182 | 183 | return chosen_file 184 | 185 | 186 | def check_not_in(list1: List[str], list2: Union[Dict[str, Any], List[str]]) -> bool: 187 | return not any(el in list2 for el in list1) 188 | 189 | 190 | def load_cfg(config: Path = Path("examples/"), **kwargs): 191 | if Path(config).is_dir(): 192 | config = choose_config(config) 193 | 194 | # load the config from the yaml file 195 | with open(config, encoding="utf-8") as file: 196 | cfg: DictDefault = DictDefault(yaml.safe_load(file)) 197 | cfg.axolotl_config_path = config 198 | # if there are any options passed in the cli, if it is something that seems valid from the yaml, 199 | # then overwrite the value 200 | cfg_keys = cfg.keys() 201 | for k, _ in kwargs.items(): 202 | # if not strict, allow writing to cfg even if it's not in the yml already 203 | if k in cfg_keys or not cfg.strict: 204 | # handle booleans 205 | if isinstance(cfg[k], bool): 206 | cfg[k] = bool(kwargs[k]) 207 | else: 208 | cfg[k] = kwargs[k] 209 | 210 | validate_config(cfg) 211 | 212 | normalize_config(cfg) 213 | 214 | setup_wandb_env_vars(cfg) 215 | return cfg 216 | 217 | 218 | def load_datasets( 219 | *, 220 | cfg: DictDefault, 221 | cli_args: TrainerCliArgs, 222 | ) -> TrainDatasetMeta: 223 | tokenizer = load_tokenizer(cfg) 224 | 225 | train_dataset, eval_dataset, total_num_steps = prepare_dataset(cfg, tokenizer) 226 | 227 | if cli_args.debug or cfg.debug: 228 | LOG.info("check_dataset_labels...") 229 | check_dataset_labels( 230 | train_dataset.select( 231 | [ 232 | 0 # nosec 233 | for _ in range(cli_args.debug_num_examples) 234 | ] 235 | ), 236 | tokenizer, 237 | num_examples=cli_args.debug_num_examples, 238 | text_only=cli_args.debug_text_only, 239 | ) 240 | 241 | return TrainDatasetMeta( 242 | train_dataset=train_dataset, 243 | eval_dataset=eval_dataset, 244 | total_num_steps=total_num_steps, 245 | ) 246 | 247 | 248 | def check_accelerate_default_config(): 249 | if Path(config_args.default_yaml_config_file).exists(): 250 | LOG.warning( 251 | f"accelerate config file found at {config_args.default_yaml_config_file}. This can lead to unexpected errors" 252 | ) 253 | 254 | 255 | def check_user_token(): 256 | # Verify if token is valid 257 | api = HfApi() 258 | try: 259 | user_info = api.whoami() 260 | return bool(user_info) 261 | except LocalTokenNotFoundError: 262 | LOG.warning( 263 | "Error verifying HuggingFace token. Remember to log in using `huggingface-cli login` and get your access token from https://huggingface.co/settings/tokens if you want to use gated models or datasets." 
264 | ) 265 | return False 266 | -------------------------------------------------------------------------------- /training/axolotl/src/axolotl/utils/trainer.py: -------------------------------------------------------------------------------- 1 | """Module containing the Trainer class and related functions""" 2 | import logging 3 | import math 4 | import os 5 | from contextlib import contextmanager 6 | from functools import partial 7 | from typing import List 8 | 9 | import numpy as np 10 | import torch 11 | import torch.cuda 12 | import torch.distributed as dist 13 | from datasets import set_caching_enabled 14 | from torch.utils.data import DistributedSampler, RandomSampler 15 | 16 | from axolotl.core.trainer_builder import HFCausalTrainerBuilder 17 | from axolotl.utils.collators import DataCollatorForSeq2Seq 18 | from axolotl.utils.dataloader import MultipackDistributedDataloader 19 | from axolotl.utils.distributed import ( 20 | is_distributed, 21 | is_main_process, 22 | reduce_and_broadcast, 23 | zero_first, 24 | ) 25 | 26 | LOG = logging.getLogger("axolotl") 27 | 28 | 29 | @torch.jit.script 30 | def weighted_cross_entropy( 31 | logits: torch.Tensor, labels: torch.Tensor, weights: torch.Tensor 32 | ): 33 | # Flatten the logits, labels, and weights tensors 34 | logits = logits.view( 35 | -1, logits.size(-1) 36 | ) # logits becomes of shape [batch_size*sequence_length, vocab_size] 37 | labels = labels.view(-1) # labels becomes of shape [batch_size*sequence_length] 38 | weights = weights.view(-1) # weights becomes of shape [batch_size*sequence_length] 39 | 40 | # Compute the unweighted cross entropy loss 41 | losses = torch.nn.functional.cross_entropy(logits, labels, reduction="none") 42 | 43 | # Apply the weights to the losses and compute their sum 44 | return (weights * losses).sum() 45 | 46 | 47 | @torch.jit.script 48 | def create_weighted_mask(labels: torch.Tensor): 49 | # Check if the tensor is 2D. 
If not, unsqueeze it to make it 2D 50 | if len(labels.shape) == 1: 51 | labels = labels.unsqueeze(0) 52 | 53 | weights = torch.zeros_like(labels).float() 54 | for i in range(labels.shape[0]): 55 | mask = labels[i] != -100 56 | 57 | # Create a tensor to track group ids 58 | group_ids = torch.zeros_like(labels[i]).int() 59 | curr_group_id = 0 60 | 61 | for j in range(1, len(labels[i])): 62 | if mask[j] and not mask[j - 1]: # switch from masked to unmasked label 63 | curr_group_id += 1 # start new group 64 | group_ids[j] = ( 65 | curr_group_id if mask[j] else 0 66 | ) # assign group id if unmasked label 67 | 68 | # Count only unmasked labels in each group 69 | group_counts = torch.bincount(group_ids[mask]) 70 | 71 | mask_weights = torch.zeros_like(labels[i]).float() 72 | mask_weights[mask] = 1.0 / group_counts[group_ids[mask]] 73 | 74 | weights[i] = mask_weights 75 | 76 | return weights.squeeze() # squeeze the output to match the input dimension 77 | 78 | 79 | def trainer_weighted_loss(model_output, labels, shift_labels=True): 80 | logits = ( 81 | model_output["logits"] if isinstance(model_output, dict) else model_output[0] 82 | ) 83 | if shift_labels: 84 | logits = logits[..., :-1, :].contiguous() 85 | labels = labels[..., 1:].contiguous() 86 | 87 | weights = create_weighted_mask(labels) 88 | return weighted_cross_entropy(logits, labels, weights) 89 | 90 | 91 | def add_position_ids(sample): 92 | sample_len = len(sample["input_ids"]) 93 | sample["position_ids"] = torch.arange(len(sample["input_ids"])) 94 | sample["length"] = sample_len 95 | return sample 96 | 97 | 98 | def add_length(sample): 99 | sample["length"] = len(sample["input_ids"]) 100 | return sample 101 | 102 | 103 | def drop_long_seq(sample, sequence_len=2048): 104 | return len(sample["input_ids"]) <= sequence_len and len(sample["input_ids"]) > 0 105 | 106 | 107 | @contextmanager 108 | def disable_datasets_caching(): 109 | try: 110 | set_caching_enabled(False) 111 | yield 112 | finally: 113 | set_caching_enabled(True) 114 | 115 | 116 | def process_datasets_for_packing(cfg, train_dataset, eval_dataset, tokenizer): 117 | drop_long = partial(drop_long_seq, sequence_len=cfg.sequence_len) 118 | with zero_first(is_main_process()): 119 | train_dataset = train_dataset.filter(drop_long, num_proc=cfg.dataset_processes) 120 | if eval_dataset: 121 | eval_dataset = eval_dataset.filter( 122 | drop_long, num_proc=cfg.dataset_processes 123 | ) 124 | 125 | if cfg.group_by_length: 126 | train_dataset = train_dataset.map( 127 | add_length, num_proc=cfg.dataset_processes 128 | ) 129 | 130 | if cfg.sample_packing: 131 | train_dataset = train_dataset.map( 132 | add_position_ids, num_proc=cfg.dataset_processes 133 | ) 134 | if cfg.eval_sample_packing is not False: 135 | if eval_dataset: 136 | eval_dataset = eval_dataset.map( 137 | add_position_ids, num_proc=cfg.dataset_processes 138 | ) 139 | 140 | # Phi doesn't want the attention_mask feature when training 141 | if "CodeGenTokenizer" in tokenizer.__class__.__name__ or ( 142 | cfg.is_mistral_derived_model and cfg.flash_attention 143 | ): 144 | train_dataset = train_dataset.remove_columns("attention_mask") 145 | if eval_dataset: 146 | eval_dataset = eval_dataset.remove_columns("attention_mask") 147 | 148 | return train_dataset, eval_dataset 149 | 150 | 151 | def calculate_total_num_steps(cfg, train_dataset, tokenizer): 152 | if cfg.sample_packing: 153 | # we have to drop anything longer then sequence len otherwise 154 | # flash attention with position ids fails 155 | if not cfg.total_num_tokens: 156 | 
LOG.info("calculating total_num_tokens") 157 | total_num_tokens = np.sum( 158 | train_dataset.data.column("input_ids") 159 | .to_pandas() 160 | .apply(lambda x: len(x)) # pylint: disable=unnecessary-lambda 161 | .values 162 | ) 163 | LOG.info(f"total_num_tokens: {total_num_tokens}") 164 | cfg.total_num_tokens = total_num_tokens 165 | 166 | if not cfg.total_supervised_tokens: 167 | total_supervised_tokens = ( 168 | train_dataset.data.column("labels") 169 | .to_pandas() 170 | .apply(lambda x: np.sum(np.array(x) != -100)) 171 | .sum() 172 | ) 173 | LOG.info(f"`total_supervised_tokens: {total_supervised_tokens}`") 174 | cfg.total_supervised_tokens = total_supervised_tokens 175 | 176 | if cfg.sample_packing_eff_est: 177 | total_num_steps = ( 178 | # match count to len est in dataloader 179 | ( 180 | math.floor( 181 | 0.99 182 | * cfg.total_num_tokens 183 | / cfg.sample_packing_eff_est 184 | / cfg.sequence_len 185 | // cfg.batch_size 186 | // int(os.environ.get("WORLD_SIZE", 1)) 187 | ) 188 | - 1 189 | ) 190 | * cfg.num_epochs 191 | ) 192 | LOG.info( 193 | f"total_num_tokens: {cfg.total_num_tokens}, total_num_steps: {total_num_steps}" 194 | ) 195 | else: 196 | if cfg.world_size > 1 and is_distributed(): 197 | sampler = DistributedSampler( 198 | train_dataset, 199 | num_replicas=cfg.world_size, 200 | rank=dist.get_rank(), 201 | seed=cfg.seed or 42, 202 | ) 203 | else: 204 | sampler = RandomSampler(train_dataset) 205 | 206 | data_loader = MultipackDistributedDataloader( 207 | train_dataset, 208 | batch_size=cfg.micro_batch_size, 209 | seq_max_length=cfg.max_packed_sequence_len or cfg.sequence_len, 210 | collate_fn=DataCollatorForSeq2Seq( 211 | tokenizer, 212 | return_tensors="pt", 213 | padding="longest", 214 | ), 215 | sampler=sampler, 216 | packing_efficiency_estimate=cfg.sample_packing_eff_est, 217 | sample_packing_seq_len_multiplier=cfg.micro_batch_size, 218 | device_count=int(os.environ.get("WORLD_SIZE", 1)), 219 | ) 220 | data_loader_len = data_loader.len_w_stats() 221 | actual_eff = data_loader.efficiency() 222 | LOG.info(f"data_loader_len: {data_loader_len}") 223 | # FIXME: is there a bug here somewhere? 
the total num steps depends 224 | # on the agreed on value for sample_packing_eff_est 225 | total_num_steps = int(math.floor(data_loader_len * cfg.num_epochs)) 226 | 227 | def calc_sample_packing_eff_est(estimates: List[float]): 228 | LOG.info(f"sample_packing_eff_est across ranks: {repr(estimates)}") 229 | return max(estimates) 230 | 231 | sample_packing_actual_eff_all = reduce_and_broadcast( 232 | lambda: actual_eff, 233 | calc_sample_packing_eff_est, 234 | ) 235 | sample_packing_eff_est = ( 236 | math.ceil(sample_packing_actual_eff_all * 100.0) / 100.0 237 | ) 238 | cfg.sample_packing_eff_est = sample_packing_eff_est 239 | LOG.info(f"sample_packing_eff_est: {cfg.sample_packing_eff_est}") 240 | else: 241 | total_num_steps = int( 242 | math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size) 243 | ) 244 | LOG.info(f"total_num_steps: {total_num_steps}") 245 | return total_num_steps 246 | 247 | 248 | def setup_fsdp_envs(cfg): 249 | os.environ["ACCELERATE_USE_FSDP"] = "true" 250 | if cfg.fsdp_config.fsdp_offload_params: 251 | os.environ["FSDP_OFFLOAD_PARAMS"] = "true" 252 | if cfg.fsdp_config.fsdp_sync_module_states: 253 | os.environ["FSDP_SYNC_MODULE_STATES"] = "true" 254 | if cfg.fsdp_config.fsdp_state_dict_type: 255 | os.environ["FSDP_STATE_DICT_TYPE"] = cfg.fsdp_config.fsdp_state_dict_type 256 | if cfg.fsdp_config.fsdp_transformer_layer_cls_to_wrap: 257 | os.environ[ 258 | "FSDP_TRANSFORMER_CLS_TO_WRAP" 259 | ] = cfg.fsdp_config.fsdp_transformer_layer_cls_to_wrap 260 | 261 | 262 | def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps): 263 | if cfg.fsdp: 264 | setup_fsdp_envs(cfg) 265 | elif cfg.deepspeed: 266 | os.environ["ACCELERATE_USE_DEEPSPEED"] = "true" 267 | 268 | trainer_builder = HFCausalTrainerBuilder(cfg, model, tokenizer) 269 | trainer_builder.train_dataset = train_dataset 270 | trainer_builder.eval_dataset = eval_dataset 271 | 272 | return trainer_builder.build(total_num_steps) 273 | -------------------------------------------------------------------------------- /training/axolotl/src/axolotl/utils/dataloader.py: -------------------------------------------------------------------------------- 1 | # pylint: skip-file 2 | import hashlib 3 | import itertools 4 | import logging 5 | import math 6 | from typing import Any, Callable, List, Union 7 | 8 | import numba 9 | import numpy as np 10 | from torch.utils.data import DistributedSampler, Sampler 11 | 12 | LOG = logging.getLogger("axolotl.utils.dataloader") 13 | 14 | 15 | @numba.njit 16 | def ffd_check(a: np.ndarray, c: int, n: int): 17 | # First-fit-decreasing bin packing 18 | # Check if a[] could fit in n bins with capacity c 19 | # https://en.wikipedia.org/wiki/First-fit-decreasing_bin_packing 20 | 21 | a = np.sort(a)[::-1] 22 | bins = np.full((n,), c, dtype=a.dtype) 23 | for size in a: 24 | not_found = True 25 | for idx in range(n): 26 | if bins[idx] >= size: 27 | bins[idx] -= size 28 | not_found = False 29 | break 30 | 31 | if not_found: 32 | return False 33 | 34 | return True 35 | 36 | 37 | @numba.njit 38 | def ffd_with_result(a: np.ndarray, c: int, start_index: int): 39 | # First-fit-decreasing bin packing (with result return) 40 | 41 | indices = np.argsort(a)[::-1] 42 | a = a[indices] 43 | 44 | bins: List[Any] = [] 45 | bins_result: List[Any] = [] 46 | for a_id, size in enumerate(a): 47 | add_new = True 48 | for idx in range(len(bins)): 49 | if bins[idx] >= size: 50 | bins[idx] -= size 51 | bins_result[idx].append(indices[a_id] + start_index) 52 | add_new = False 53 | break 54 | 55 
| if add_new: 56 | bins.append(c - size) 57 | bins_result.append([indices[a_id] + start_index]) 58 | 59 | return bins_result, len(a) 60 | 61 | 62 | @numba.njit 63 | def allocate( 64 | lengths: np.ndarray, lengths_cumsum: np.ndarray, rank: int, c: int, n: int 65 | ): 66 | """ 67 | :param lengths: array of lengths of each sample 68 | :param lengths_cumsum: cumulative sum of consecutive lengths 69 | :param rank: rank for this process 70 | :param c: length of tokens per batch 71 | :param n: number of ranks 72 | :return: 73 | """ 74 | # Dynamic batch allocator, similar to Multifit 75 | # https://en.wikipedia.org/wiki/Multifit_algorithm 76 | # ~99.5% efficiency on OpenChat training set (12 * 2048 ctx len) 77 | 78 | s = 0 79 | start_index = 0 80 | result = [] 81 | result_totseqs = [] 82 | 83 | while True: 84 | # binary search [left, right) 85 | left = 1 86 | right = 1 + np.searchsorted(lengths_cumsum[start_index:], s + c * n, "right") 87 | 88 | while right - left > 1: 89 | mid = (left + right) // 2 90 | if ffd_check(lengths[start_index : start_index + mid], c, n): 91 | left = mid 92 | else: 93 | right = mid 94 | 95 | # use length left 96 | batch, tot_seqs = ffd_with_result( 97 | lengths[start_index : start_index + left], c, start_index 98 | ) 99 | if len(batch) < n: 100 | break 101 | 102 | start_index += left 103 | s = lengths_cumsum[start_index - 1] 104 | 105 | # add local rank 106 | result.append(batch[rank]) 107 | # add total seqs for all ranks 108 | result_totseqs.append(tot_seqs) 109 | # yield batch[rank], tot_seqs, s, len(result) * c * n 110 | return result, result_totseqs, s, len(result) * c * n 111 | 112 | 113 | def chunk(iterable, n): 114 | """ 115 | Chunk data into tuples of length n 116 | """ 117 | # batched('ABCDEFG', 3) --> ABC DEF G 118 | if n < 1: 119 | raise ValueError("n must be at least one") 120 | it = iter(iterable) 121 | while batch := tuple(itertools.islice(it, n)): 122 | yield batch 123 | 124 | 125 | def hash_indices(lst: List[int]) -> str: 126 | # Convert the list of integers to a string representation 127 | concatenated = ",".join(map(str, lst)) 128 | 129 | # Generate the hash 130 | sha256 = hashlib.sha256() 131 | sha256.update(concatenated.encode()) 132 | 133 | return sha256.hexdigest() 134 | 135 | 136 | class MultipackDistributedDataloader: 137 | """Unpadded data loading using Multipack. 138 | Adapted from https://github.com/imoneoi/openchat/blob/v3_fix_mle_loss/ochat/training_deepspeed/multipack_dataloader.py 139 | Approximate (at most ~1.22x) the optimal solution of the identical-machines scheduling problem, which is NP-hard. 
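    Toy example (illustrative only): with two ranks and a per-rank capacity of 8 tokens per
    packing step, sample lengths [6, 4, 3, 2] fit into the bins [6, 2] and [4, 3] under
    first-fit-decreasing, so two packed rows replace four padded ones:

        >>> ffd_check(np.array([6, 4, 3, 2]), 8, 2)
        True

    Inside each packed row, every original sequence keeps a distinct positive id in the
    concatenated attention_mask so downstream code can still tell the sequences apart.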
140 | """ 141 | 142 | def __init__( 143 | self, 144 | dataset: Any, 145 | collate_fn: Callable, 146 | seq_max_length: int = 2048, 147 | batch_size: int = 1, 148 | sampler: Union[Sampler, DistributedSampler] = None, 149 | packing_efficiency_estimate: float = 1.0, 150 | sample_packing_seq_len_multiplier: int = 1, 151 | device_count: int = 1, 152 | ): 153 | # Dataset 154 | self.dataset = dataset 155 | self.lengths = ( 156 | dataset.data.column("position_ids") 157 | .to_pandas() 158 | .apply(lambda x: x[-1] + 1) 159 | .values 160 | ) 161 | assert isinstance(self.lengths, np.ndarray) 162 | assert batch_size % sample_packing_seq_len_multiplier == 0 163 | assert batch_size >= sample_packing_seq_len_multiplier 164 | self.sampler = sampler 165 | self.batch_size = batch_size 166 | self.sample_packing_seq_len_multiplier = sample_packing_seq_len_multiplier 167 | self.seq_max_length = seq_max_length 168 | self.batch_max_length = batch_size * seq_max_length 169 | self.collate_fn = collate_fn 170 | 171 | self.num_replicas = 1 172 | self.rank = 0 173 | 174 | # statistics 175 | self.eff_total_used = 0 176 | self.eff_total_slots = 0 177 | self.packing_efficiency_estimate = packing_efficiency_estimate or 1.0 178 | self.device_count = device_count 179 | 180 | def generate_batches(self, set_stats=False): 181 | LOG.info("generating packed batches") 182 | if self.sampler: 183 | indices = [idx for idx in self.sampler] 184 | else: 185 | indices = range(0, len(self.dataset)) 186 | 187 | LOG.info(hash_indices(indices)) 188 | lengths = self.lengths[indices] 189 | lengths_cumsum = np.cumsum(lengths) 190 | 191 | batches, totseqs, total_used, total_slots = allocate( 192 | lengths=lengths, 193 | lengths_cumsum=lengths_cumsum, 194 | rank=self.rank, 195 | # c=self.batch_max_length, 196 | c=self.seq_max_length * self.sample_packing_seq_len_multiplier, 197 | n=self.num_replicas, 198 | ) 199 | 200 | batches = [[indices[b_idx] for b_idx in batch] for batch in batches] 201 | 202 | # statistics 203 | if set_stats: 204 | self.eff_total_used += total_used 205 | self.eff_total_slots += total_slots 206 | 207 | return batches, totseqs 208 | 209 | def __iter__(self): 210 | if hasattr(self.sampler, "set_epoch"): 211 | new_epoch = self.sampler.epoch + 1 212 | self.sampler.set_epoch(new_epoch) 213 | LOG.info(f"calling sampler.set_epoch({new_epoch})") 214 | all_batches, _ = self.generate_batches(set_stats=True) 215 | features = self.dataset.features.keys() 216 | len_remaining = self._len_est() 217 | for batches in chunk( 218 | all_batches, self.batch_size // self.sample_packing_seq_len_multiplier 219 | ): 220 | chunked_data = [] 221 | attn_mask_cum_idx = 0 222 | for batch in batches: 223 | concatenated = {} 224 | batched_data = [self.dataset[batch_idx] for batch_idx in batch] 225 | for feature in features: 226 | if feature == "length": 227 | continue 228 | if feature == "attention_mask": 229 | arrays = [ 230 | (attn_mask_cum_idx + idx + 1) * np.array(item[feature]) 231 | for idx, item in enumerate(batched_data) 232 | if feature in item 233 | ] 234 | attn_mask_cum_idx += len(batched_data) 235 | concatenated[feature] = np.concatenate(arrays) 236 | else: 237 | arrays = [ 238 | np.array(item[feature]) 239 | for item in batched_data 240 | if feature in item 241 | ] 242 | concatenated[feature] = np.concatenate(arrays) 243 | chunked_data.append(concatenated) 244 | yield self.collate_fn(chunked_data) 245 | len_remaining -= 1 246 | if not len_remaining: 247 | return 248 | # yield a no-op for cases where we don't have any data left to pack 249 | 
for i in range(0, len_remaining): 250 | yield self.collate_fn( 251 | [ 252 | { 253 | "input_ids": [0], 254 | "labels": [-100], 255 | "attention_mask": [True], 256 | "position_ids": [0], 257 | } 258 | ] 259 | ) 260 | 261 | def _len_est(self): 262 | lengths_sum = np.sum(self.lengths) 263 | lengths_sum_per_device = lengths_sum // self.device_count 264 | LOG.info( 265 | f"packing_efficiency_estimate: {self.packing_efficiency_estimate} " 266 | f"total_num_tokens per device: {lengths_sum_per_device}" 267 | ) 268 | 269 | # shave off 1% + 1 for dealing with variance in packing from random sampler to sampler 270 | return ( 271 | math.floor( 272 | 0.99 273 | * lengths_sum_per_device 274 | / self.packing_efficiency_estimate 275 | // self.seq_max_length 276 | // self.batch_size 277 | ) 278 | - 1 279 | ) 280 | 281 | def __len__(self): 282 | # this doesn't return the actual length b/c with distributed samplers, not all dataloaders get 283 | # the same share of total tokens 284 | # if not self.eff_total_used: 285 | # batches, _ = self.generate_batches(set_stats=True) 286 | # LOG.info( 287 | # f"packing_efficiency_estimate: {self.packing_efficiency_estimate} " 288 | # f"actual packing efficiency: {self.efficiency()}" 289 | # ) 290 | return max(1, self._len_est()) 291 | 292 | def len_w_stats(self): 293 | if not self.eff_total_used: 294 | batches, _ = self.generate_batches(set_stats=True) 295 | LOG.info( 296 | f"packing_efficiency_estimate: {self.packing_efficiency_estimate} " 297 | f"actual packing efficiency: {self.efficiency()}" 298 | ) 299 | return max(1, self._len_est()) 300 | 301 | def efficiency(self): 302 | return self.eff_total_used / self.eff_total_slots 303 | --------------------------------------------------------------------------------
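
A minimal end-to-end sketch of how the packing pieces above fit together, mirroring the way calculate_total_num_steps() in axolotl/utils/trainer.py builds its dataloader. It is a single-process illustration under stated assumptions (the axolotl package from this repo is importable, train_dataset is a tokenized Hugging Face dataset with "input_ids", "labels" and "attention_mask" columns, tokenizer is the matching tokenizer, and the hyperparameter values are placeholders), not a drop-in replacement for the trainer code.

from torch.utils.data import RandomSampler

from axolotl.utils.collators import DataCollatorForSeq2Seq
from axolotl.utils.dataloader import MultipackDistributedDataloader
from axolotl.utils.trainer import add_position_ids, drop_long_seq


def build_packed_dataloader(train_dataset, tokenizer, sequence_len=2048, micro_batch_size=2):
    # Mirror process_datasets_for_packing(): drop over-long rows, then add position_ids/length.
    train_dataset = train_dataset.filter(lambda sample: drop_long_seq(sample, sequence_len=sequence_len))
    train_dataset = train_dataset.map(add_position_ids)

    return MultipackDistributedDataloader(
        train_dataset,
        batch_size=micro_batch_size,
        seq_max_length=sequence_len,
        collate_fn=DataCollatorForSeq2Seq(tokenizer, return_tensors="pt", padding="longest"),
        sampler=RandomSampler(train_dataset),   # swap in a DistributedSampler under DDP
        packing_efficiency_estimate=1.0,        # the trainer later refines this estimate across ranks
        sample_packing_seq_len_multiplier=micro_batch_size,
        device_count=1,
    )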