├── training
│   └── axolotl
│       ├── src
│       │   ├── axolotl
│       │   │   ├── __init__.py
│       │   │   ├── common
│       │   │   │   ├── __init__.py
│       │   │   │   ├── const.py
│       │   │   │   └── cli.py
│       │   │   ├── core
│       │   │   │   └── __init__.py
│       │   │   ├── models
│       │   │   │   ├── __init__.py
│       │   │   │   └── phi
│       │   │   │       ├── __init__.py
│       │   │   │       └── configuration_mixformer_sequential.py
│       │   │   ├── utils
│       │   │   │   ├── __init__.py
│       │   │   │   ├── dict.py
│       │   │   │   ├── wandb_.py
│       │   │   │   ├── tokenization.py
│       │   │   │   ├── bench.py
│       │   │   │   ├── schedulers.py
│       │   │   │   ├── collators.py
│       │   │   │   ├── distributed.py
│       │   │   │   ├── trainer.py
│       │   │   │   └── dataloader.py
│       │   │   ├── prompt_strategies
│       │   │   │   ├── alpaca_instruct.py
│       │   │   │   ├── __init__.py
│       │   │   │   ├── sharegpt_jokes.py
│       │   │   │   ├── orcamini.py
│       │   │   │   ├── metharme.py
│       │   │   │   ├── completion.py
│       │   │   │   ├── user_defined.py
│       │   │   │   ├── sharegpt.py
│       │   │   │   ├── pygmalion.py
│       │   │   │   ├── alpaca_chat.py
│       │   │   │   ├── context_qa.py
│       │   │   │   ├── alpaca_w_system.py
│       │   │   │   ├── creative_acr.py
│       │   │   │   └── llama2_chat.py
│       │   │   ├── cli
│       │   │   │   ├── merge_lora.py
│       │   │   │   ├── inference.py
│       │   │   │   ├── shard.py
│       │   │   │   ├── train.py
│       │   │   │   └── __init__.py
│       │   │   ├── monkeypatch
│       │   │   │   ├── llama_embeddings_hijack.py
│       │   │   │   ├── mistral_embeddings_hijack.py
│       │   │   │   ├── llama_expand_mask.py
│       │   │   │   ├── btlm_attn_hijack_flash.py
│       │   │   │   ├── xpos_rope_llama_monkey_patch.py
│       │   │   │   ├── utils.py
│       │   │   │   ├── llama_attn_hijack_sdp.py
│       │   │   │   ├── llama_attn_hijack_xformers.py
│       │   │   │   └── fastchat_conversation_turns.py
│       │   │   ├── logging_config.py
│       │   │   ├── convert.py
│       │   │   ├── train.py
│       │   │   └── datasets.py
│       │   └── axolotl.egg-info
│       │       ├── top_level.txt
│       │       ├── dependency_links.txt
│       │       ├── SOURCES.txt
│       │       ├── PKG-INFO
│       │       └── requires.txt
│       ├── dist
│       │   └── axolotl-0.3.0-py3.9.egg
│       ├── docs
│       │   ├── faq.md
│       │   ├── multipack.md
│       │   ├── multi-node.md
│       │   └── nccl.md
│       ├── scripts
│       │   ├── runpod-entrypoint.sh
│       │   └── finetune.py
│       ├── docker
│       │   ├── Dockerfile-runpod
│       │   ├── Dockerfile
│       │   └── Dockerfile-base
│       ├── requirements.txt
│       ├── deepspeed
│       │   ├── zero1.json
│       │   ├── zero2.json
│       │   └── zero3.json
│       ├── setup.py
│       └── examples
│           └── mistral
│               └── nips
│                   └── nips_02.yml
├── inference
│   └── submission_2
│       ├── fast_api_requirements.txt
│       ├── Dockerfile
│       ├── api.py
│       └── main.py
└── data_prep
    ├── prepare_math_reasoning_dataset.py
    ├── prepare_exact_match_tasks_dataset.py
    ├── prepare_generation_tasks_dataset.py
    └── combine_datasets.py
/training/axolotl/src/axolotl/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/training/axolotl/src/axolotl/common/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/training/axolotl/src/axolotl/core/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/training/axolotl/src/axolotl/models/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/training/axolotl/src/axolotl/utils/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/training/axolotl/src/axolotl.egg-info/top_level.txt:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/inference/submission_2/fast_api_requirements.txt:
--------------------------------------------------------------------------------
1 | # FAST API
2 | fastapi>=0.68.0,<0.69.0
3 | pydantic>=1.8.0,<2.0.0
4 | uvicorn>=0.15.0,<0.16.0
5 |
--------------------------------------------------------------------------------
/training/axolotl/src/axolotl/common/const.py:
--------------------------------------------------------------------------------
1 | """
2 | Various shared constants
3 | """
4 |
5 | DEFAULT_DATASET_PREPARED_PATH = "last_run_prepared"
6 |
--------------------------------------------------------------------------------
/training/axolotl/src/axolotl.egg-info/dependency_links.txt:
--------------------------------------------------------------------------------
1 | https://download.pytorch.org/whl/cu118
2 | https://huggingface.github.io/autogptq-index/whl/cu118/
3 |
--------------------------------------------------------------------------------
/training/axolotl/dist/axolotl-0.3.0-py3.9.egg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Upaya07/NeurIPS-llm-efficiency-challenge/HEAD/training/axolotl/dist/axolotl-0.3.0-py3.9.egg
--------------------------------------------------------------------------------
/training/axolotl/src/axolotl.egg-info/SOURCES.txt:
--------------------------------------------------------------------------------
1 | setup.py
2 | src/axolotl.egg-info/PKG-INFO
3 | src/axolotl.egg-info/SOURCES.txt
4 | src/axolotl.egg-info/dependency_links.txt
5 | src/axolotl.egg-info/requires.txt
6 | src/axolotl.egg-info/top_level.txt
--------------------------------------------------------------------------------
/training/axolotl/src/axolotl/models/phi/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | MixFormers model architecture used for phi models
3 | """
4 |
5 | from .configuration_mixformer_sequential import MixFormerSequentialConfig # noqa
6 | from .modeling_mixformer_sequential import MixFormerSequentialForCausalLM # noqa
7 |
--------------------------------------------------------------------------------
/training/axolotl/src/axolotl.egg-info/PKG-INFO:
--------------------------------------------------------------------------------
1 | Metadata-Version: 2.1
2 | Name: axolotl
3 | Version: 0.3.0
4 | Summary: LLM Trainer
5 | Provides-Extra: flash-attn
6 | Provides-Extra: deepspeed
7 |
8 | Axolotl is a tool designed to streamline the fine-tuning of various AI models, offering support for multiple configurations and architectures.
9 |
--------------------------------------------------------------------------------
/training/axolotl/docs/faq.md:
--------------------------------------------------------------------------------
1 | # Axolotl FAQs
2 |
3 |
4 | > The trainer stopped and hasn't progressed in several minutes.
5 |
6 | This is usually an issue with the GPUs communicating with each other. See the [NCCL doc](../docs/nccl.md).
7 |
8 | > Exitcode -9
9 |
10 | This usually happens when you run out of system RAM.
11 |
12 | > Exitcode -7 while using deepspeed
13 |
14 | Try upgrading deepspeed with: `pip install -U deepspeed`
15 |
--------------------------------------------------------------------------------
/training/axolotl/src/axolotl/utils/dict.py:
--------------------------------------------------------------------------------
1 | """Module containing the DictDefault class"""
2 |
3 | from addict import Dict
4 |
5 |
6 | class DictDefault(Dict):
7 | """
8 | A Dict that returns None instead of returning empty Dict for missing keys.
9 | """
10 |
11 | def __missing__(self, key):
12 | return None
13 |
14 | def __or__(self, other):
15 | return DictDefault(super().__or__(other))
16 |
--------------------------------------------------------------------------------
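A quick illustration of how `DictDefault` behaves, a minimal sketch that assumes axolotl (and its `addict` dependency) is installed and importable; the config values are placeholders:

```python
from axolotl.utils.dict import DictDefault

cfg = DictDefault({"base_model": "mistralai/Mistral-7B-v0.1"})

# Present keys behave like attributes, as with addict.Dict
print(cfg.base_model)       # mistralai/Mistral-7B-v0.1

# Missing keys return None instead of an empty Dict, so `if cfg.some_flag:` is safe
print(cfg.wandb_project)    # None

# The | operator merges two configs and keeps the DictDefault behaviour (Python 3.9+)
merged = cfg | DictDefault({"learning_rate": 2e-5})
print(merged.learning_rate)  # 2e-05
print(merged.missing_key)    # None
```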
/training/axolotl/scripts/runpod-entrypoint.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Export specific ENV variables to /etc/rp_environment
4 | echo "Exporting environment variables..."
5 | printenv | grep -E '^RUNPOD_|^PATH=|^_=' | sed 's/^\(.*\)=\(.*\)$/export \1="\2"/' >> /etc/rp_environment
6 | echo 'source /etc/rp_environment' >> ~/.bashrc
7 |
8 | if [[ $PUBLIC_KEY ]]
9 | then
10 | mkdir -p ~/.ssh
11 | chmod 700 ~/.ssh
12 | echo "$PUBLIC_KEY" >> ~/.ssh/authorized_keys
13 | chmod 700 -R ~/.ssh
14 | # Start the SSH service in the background
15 | service ssh start
16 | else
17 | echo "No PUBLIC_KEY ENV variable provided, not starting openSSH daemon"
18 | fi
19 |
20 | # Execute the passed arguments (CMD)
21 | exec "$@"
22 |
--------------------------------------------------------------------------------
/inference/submission_2/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM pytorch/pytorch:2.1.0-cuda12.1-cudnn8-devel
2 |
3 | RUN apt-get update && apt-get install -y git python3-virtualenv wget
4 |
5 | RUN pip install transformers
6 | RUN pip install torch>=2.0.1 huggingface_hub accelerate sentencepiece optimum py7zr scipy appdirs peft
7 |
8 | WORKDIR /workspace
9 | # Set up server requirements
10 | COPY ./fast_api_requirements.txt fast_api_requirements.txt
11 | RUN pip install --no-cache-dir --upgrade -r fast_api_requirements.txt
12 |
13 | ENV HUGGINGFACE_REPO="upaya07/Birbal-7B-V1"
14 |
15 | # Copy over single file server
16 | COPY ./main.py main.py
17 | COPY ./api.py api.py
18 | # Run the server
19 | CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "80"]
20 |
--------------------------------------------------------------------------------
/training/axolotl/src/axolotl.egg-info/requires.txt:
--------------------------------------------------------------------------------
1 | torch==2.0.1
2 | auto-gptq==0.4.2
3 | packaging
4 | peft==0.6.0
5 | transformers@ git+https://github.com/huggingface/transformers.git@acc394c4f5e1283c19783581790b3dc3105a3697
6 | bitsandbytes>=0.41.1
7 | accelerate@ git+https://github.com/huggingface/accelerate@80da9cfb09bb3cc9f1b385cb55d6b90d025a5fd9
8 | addict
9 | fire
10 | PyYAML>=6.0
11 | datasets
12 | sentencepiece
13 | wandb
14 | einops
15 | xformers>=0.0.22
16 | optimum==1.13.2
17 | hf_transfer
18 | colorama
19 | numba
20 | numpy>=1.24.4
21 | bert-score==0.3.13
22 | evaluate==0.4.0
23 | rouge-score==0.1.2
24 | scipy
25 | scikit-learn==1.2.2
26 | pynvml
27 | art
28 | fschat==0.2.29
29 | gradio
30 |
31 | [deepspeed]
32 | deepspeed
33 |
34 | [flash-attn]
35 | flash-attn>=2.3.0
36 |
--------------------------------------------------------------------------------
/training/axolotl/src/axolotl/prompt_strategies/alpaca_instruct.py:
--------------------------------------------------------------------------------
1 | """Module loading the AlpacaInstructPromptTokenizingStrategy class"""
2 |
3 | from axolotl.prompt_tokenizers import AlpacaPromptTokenizingStrategy
4 | from axolotl.prompters import AlpacaPrompter, PromptStyle, UnpromptedPrompter
5 |
6 |
7 | def load(tokenizer, cfg):
8 | return AlpacaPromptTokenizingStrategy(
9 | AlpacaPrompter(PromptStyle.INSTRUCT.value),
10 | tokenizer,
11 | cfg.train_on_inputs,
12 | cfg.sequence_len,
13 | )
14 |
15 |
16 | def load_no_prompt(tokenizer, cfg):
17 | return AlpacaPromptTokenizingStrategy(
18 | UnpromptedPrompter(PromptStyle.INSTRUCT.value),
19 | tokenizer,
20 | cfg.train_on_inputs,
21 | cfg.sequence_len,
22 | )
23 |
--------------------------------------------------------------------------------
/training/axolotl/docker/Dockerfile-runpod:
--------------------------------------------------------------------------------
1 | ARG BASE_TAG=main
2 | FROM winglian/axolotl:$BASE_TAG
3 |
4 | ENV HF_DATASETS_CACHE="/workspace/data/huggingface-cache/datasets"
5 | ENV HUGGINGFACE_HUB_CACHE="/workspace/data/huggingface-cache/hub"
6 | ENV TRANSFORMERS_CACHE="/workspace/data/huggingface-cache/hub"
7 |
8 | COPY scripts/runpod-entrypoint.sh /root/runpod-entrypoint.sh
9 |
10 | RUN apt install --yes --no-install-recommends openssh-server tmux && \
11 | mkdir -p ~/.ssh && \
12 | chmod 700 ~/.ssh && \
13 | printf "\n[[ -z \"\$TMUX\" ]] && { tmux attach-session -t ssh_tmux || tmux new-session -s ssh_tmux; exit; }\n" >> ~/.bashrc && \
14 | chmod +x /workspace/axolotl/scripts/runpod-entrypoint.sh && \
15 | chmod +x /root/runpod-entrypoint.sh
16 |
17 | ENTRYPOINT ["/root/runpod-entrypoint.sh"]
18 | CMD ["sleep", "infinity"]
19 |
--------------------------------------------------------------------------------
/training/axolotl/requirements.txt:
--------------------------------------------------------------------------------
1 | --extra-index-url https://download.pytorch.org/whl/cu118
2 | --extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
3 | torch==2.0.1
4 | auto-gptq==0.4.2
5 | packaging
6 | peft==0.6.0
7 | transformers @ git+https://github.com/huggingface/transformers.git@acc394c4f5e1283c19783581790b3dc3105a3697
8 | bitsandbytes>=0.41.1
9 | accelerate @ git+https://github.com/huggingface/accelerate@80da9cfb09bb3cc9f1b385cb55d6b90d025a5fd9
10 | deepspeed
11 | addict
12 | fire
13 | PyYAML>=6.0
14 | datasets
15 | flash-attn>=2.3.0
16 | sentencepiece
17 | wandb
18 | einops
19 | xformers>=0.0.22
20 | optimum==1.13.2
21 | hf_transfer
22 | colorama
23 | numba
24 | numpy>=1.24.4
25 | # qlora things
26 | bert-score==0.3.13
27 | evaluate==0.4.0
28 | rouge-score==0.1.2
29 | scipy
30 | scikit-learn==1.2.2
31 | pynvml
32 | art
33 | fschat==0.2.29
34 | gradio
--------------------------------------------------------------------------------
/inference/submission_2/api.py:
--------------------------------------------------------------------------------
1 | from pydantic import BaseModel
2 |
3 | from typing import List, Dict, Optional
4 |
5 |
6 | class ProcessRequest(BaseModel):
7 | prompt: str
8 | num_samples: int = 1
9 | max_new_tokens: int = 50
10 | top_k: int = 200
11 | temperature: float = 1e-7
12 | seed: Optional[int] = None
13 | echo_prompt: Optional[bool]
14 |
15 |
16 | class Token(BaseModel):
17 | text: str
18 | logprob: float
19 | top_logprob: Dict[str, float]
20 |
21 |
22 | class ProcessResponse(BaseModel):
23 | text: str
24 | tokens: List[Token]
25 | logprob: float
26 | request_time: float
27 |
28 |
29 | class TokenizeRequest(BaseModel):
30 | text: str
31 | truncation: bool = True
32 | max_length: int = 2048
33 |
34 |
35 | class TokenizeResponse(BaseModel):
36 | tokens: List[int]
37 | request_time: float
38 |
--------------------------------------------------------------------------------
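The pydantic models above define the request/response contract for the inference server. As a minimal sketch of how they might be wired into FastAPI endpoints (hypothetical handler bodies and endpoint names; the actual `main.py` loads the model from `HUGGINGFACE_REPO` and runs real generation):

```python
import time

from fastapi import FastAPI

from api import ProcessRequest, ProcessResponse, Token, TokenizeRequest, TokenizeResponse

app = FastAPI()


@app.post("/process", response_model=ProcessResponse)
async def process(request: ProcessRequest) -> ProcessResponse:
    start = time.time()
    # Placeholder generation: echo part of the prompt; the real server runs the model here.
    text = request.prompt[: request.max_new_tokens]
    tokens = [Token(text=text, logprob=0.0, top_logprob={text: 0.0})]
    return ProcessResponse(
        text=text, tokens=tokens, logprob=0.0, request_time=time.time() - start
    )


@app.post("/tokenize", response_model=TokenizeResponse)
async def tokenize(request: TokenizeRequest) -> TokenizeResponse:
    start = time.time()
    # Placeholder tokenization: one "token" id per character, truncated to max_length.
    ids = [ord(ch) for ch in request.text[: request.max_length]]
    return TokenizeResponse(tokens=ids, request_time=time.time() - start)
```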
/training/axolotl/src/axolotl/cli/merge_lora.py:
--------------------------------------------------------------------------------
1 | """
2 | CLI to merge a trained LoRA into a base model
3 | """
4 | from pathlib import Path
5 |
6 | import fire
7 | import transformers
8 |
9 | from axolotl.cli import do_merge_lora, load_cfg, print_axolotl_text_art
10 | from axolotl.common.cli import TrainerCliArgs
11 |
12 |
13 | def do_cli(config: Path = Path("examples/"), **kwargs):
14 | # pylint: disable=duplicate-code
15 | print_axolotl_text_art()
16 | parser = transformers.HfArgumentParser((TrainerCliArgs))
17 | parsed_cli_args, _ = parser.parse_args_into_dataclasses(
18 | return_remaining_strings=True
19 | )
20 | parsed_cli_args.merge_lora = True
21 | parsed_cfg = load_cfg(config, merge_lora=True, **kwargs)
22 |
23 | do_merge_lora(cfg=parsed_cfg, cli_args=parsed_cli_args)
24 |
25 |
26 | if __name__ == "__main__":
27 | fire.Fire(do_cli)
28 |
--------------------------------------------------------------------------------
/training/axolotl/src/axolotl/cli/inference.py:
--------------------------------------------------------------------------------
1 | """
2 | CLI to run inference on a trained model
3 | """
4 | from pathlib import Path
5 |
6 | import fire
7 | import transformers
8 |
9 | from axolotl.cli import do_inference, load_cfg, print_axolotl_text_art
10 | from axolotl.common.cli import TrainerCliArgs
11 |
12 |
13 | def do_cli(config: Path = Path("examples/"), **kwargs):
14 | # pylint: disable=duplicate-code
15 | print_axolotl_text_art()
16 | parsed_cfg = load_cfg(config, **kwargs)
17 | parsed_cfg.sample_packing = False
18 | parser = transformers.HfArgumentParser((TrainerCliArgs))
19 | parsed_cli_args, _ = parser.parse_args_into_dataclasses(
20 | return_remaining_strings=True
21 | )
22 | parsed_cli_args.inference = True
23 |
24 | do_inference(cfg=parsed_cfg, cli_args=parsed_cli_args)
25 |
26 |
27 | if __name__ == "__main__":
28 | fire.Fire(do_cli)
29 |
--------------------------------------------------------------------------------
/training/axolotl/src/axolotl/utils/wandb_.py:
--------------------------------------------------------------------------------
1 | """Module for wandb utilities"""
2 |
3 | import os
4 |
5 |
6 | def setup_wandb_env_vars(cfg):
7 | if cfg.wandb_mode and cfg.wandb_mode == "offline":
8 | os.environ["WANDB_MODE"] = cfg.wandb_mode
9 | elif cfg.wandb_project and len(cfg.wandb_project) > 0:
10 | os.environ["WANDB_PROJECT"] = cfg.wandb_project
11 | cfg.use_wandb = True
12 | if cfg.wandb_entity and len(cfg.wandb_entity) > 0:
13 | os.environ["WANDB_ENTITY"] = cfg.wandb_entity
14 | if cfg.wandb_watch and len(cfg.wandb_watch) > 0:
15 | os.environ["WANDB_WATCH"] = cfg.wandb_watch
16 | if cfg.wandb_log_model and len(cfg.wandb_log_model) > 0:
17 | os.environ["WANDB_LOG_MODEL"] = cfg.wandb_log_model
18 | if cfg.wandb_run_id and len(cfg.wandb_run_id) > 0:
19 | os.environ["WANDB_RUN_ID"] = cfg.wandb_run_id
20 | else:
21 | os.environ["WANDB_DISABLED"] = "true"
22 |
--------------------------------------------------------------------------------
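A small sketch of how `setup_wandb_env_vars` maps config keys onto environment variables (project and run id values taken from the nips_02.yml example; assumes axolotl is importable):

```python
import os

from axolotl.utils.dict import DictDefault
from axolotl.utils.wandb_ import setup_wandb_env_vars

cfg = DictDefault(
    {
        "wandb_project": "Mistral-7B-ft-NeurIPS",
        "wandb_run_id": "NeurIPS_model_01",
    }
)

setup_wandb_env_vars(cfg)
print(os.environ["WANDB_PROJECT"])  # Mistral-7B-ft-NeurIPS
print(os.environ["WANDB_RUN_ID"])   # NeurIPS_model_01
print(cfg.use_wandb)                # True

# With no wandb_project (and no wandb_mode), wandb logging is disabled instead:
setup_wandb_env_vars(DictDefault({}))
print(os.environ["WANDB_DISABLED"])  # true
```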
/training/axolotl/deepspeed/zero1.json:
--------------------------------------------------------------------------------
1 | {
2 | "zero_optimization": {
3 | "stage": 1,
4 | "overlap_comm": true
5 | },
6 | "bf16": {
7 | "enabled": "auto"
8 | },
9 | "fp16": {
10 | "enabled": "auto",
11 | "auto_cast": false,
12 | "loss_scale": 0,
13 | "initial_scale_power": 32,
14 | "loss_scale_window": 1000,
15 | "hysteresis": 2,
16 | "min_loss_scale": 1
17 | },
18 | "optimizer": {
19 | "type": "AdamW",
20 | "params": {
21 | "lr": "auto",
22 | "betas": "auto",
23 | "eps": "auto",
24 | "weight_decay": "auto"
25 | }
26 | },
27 | "scheduler": {
28 | "type": "WarmupDecayLR",
29 | "params": {
30 | "warmup_min_lr": "auto",
31 | "warmup_max_lr": "auto",
32 | "warmup_num_steps": "auto",
33 | "warmup_type": "linear",
34 | "total_num_steps": "auto"
35 | }
36 | },
37 | "gradient_accumulation_steps": "auto",
38 | "train_batch_size": "auto",
39 | "train_micro_batch_size_per_gpu": "auto",
40 | "wall_clock_breakdown": false
41 | }
42 |
--------------------------------------------------------------------------------
/training/axolotl/src/axolotl/prompt_strategies/__init__.py:
--------------------------------------------------------------------------------
1 | """Module to load prompt strategies."""
2 |
3 | import importlib
4 | import inspect
5 |
6 | from axolotl.prompt_strategies.user_defined import UserDefinedDatasetConfig
7 |
8 |
9 | def load(strategy, tokenizer, cfg, ds_cfg):
10 | try:
11 | load_fn = "load"
12 | if strategy.split(".")[-1].startswith("load_"):
13 | load_fn = strategy.split(".")[-1]
14 | strategy = ".".join(strategy.split(".")[:-1])
15 | mod = importlib.import_module(f".{strategy}", "axolotl.prompt_strategies")
16 | func = getattr(mod, load_fn)
17 | load_kwargs = {}
18 | if strategy == "user_defined":
19 | load_kwargs["ds_cfg"] = UserDefinedDatasetConfig(**ds_cfg)
20 | else:
21 | sig = inspect.signature(func)
22 | if "ds_cfg" in sig.parameters:
23 | load_kwargs["ds_cfg"] = ds_cfg
24 | return func(tokenizer, cfg, **load_kwargs)
25 | except Exception: # pylint: disable=broad-exception-caught
26 | return None
27 |
--------------------------------------------------------------------------------
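The loader resolves a dataset `type` string to a module in this package plus a load function: a trailing `load_*` segment selects an alternate entry point, otherwise `load` is used. A standalone sketch that mirrors just that dispatch logic (it does not import the strategy modules themselves):

```python
def resolve(strategy: str):
    """Mirror the dispatch in load(): split off an optional load_* suffix."""
    load_fn = "load"
    if strategy.split(".")[-1].startswith("load_"):
        load_fn = strategy.split(".")[-1]
        strategy = ".".join(strategy.split(".")[:-1])
    return f"axolotl.prompt_strategies.{strategy}", load_fn


print(resolve("alpaca_chat.load_qa"))  # ('axolotl.prompt_strategies.alpaca_chat', 'load_qa')
print(resolve("alpaca_instruct"))      # ('axolotl.prompt_strategies.alpaca_instruct', 'load')
print(resolve("sharegpt_jokes"))       # ('axolotl.prompt_strategies.sharegpt_jokes', 'load')
```

Note that any exception raised while importing or calling the strategy is swallowed by the broad `except` in `load`, which returns `None` so the caller can fall back to other dataset handling.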
/training/axolotl/docker/Dockerfile:
--------------------------------------------------------------------------------
1 | ARG BASE_TAG=main-base
2 | FROM winglian/axolotl-base:$BASE_TAG
3 |
4 | ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6+PTX"
5 | ARG AXOLOTL_EXTRAS=""
6 | ARG CUDA="118"
7 | ENV BNB_CUDA_VERSION=$CUDA
8 | ARG PYTORCH_VERSION="2.0.1"
9 |
10 | ENV PYTORCH_VERSION=$PYTORCH_VERSION
11 |
12 | RUN apt-get update && \
13 | apt-get install -y vim curl
14 |
15 | WORKDIR /workspace
16 |
17 | RUN git clone --depth=1 https://github.com/OpenAccess-AI-Collective/axolotl.git
18 |
19 | WORKDIR /workspace/axolotl
20 |
21 | # If AXOLOTL_EXTRAS is set, append it in brackets
22 | RUN sed -i "s/torch==.*/torch==$PYTORCH_VERSION/" requirements.txt
23 | RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
24 | pip install -e .[flash-attn,$AXOLOTL_EXTRAS]; \
25 | else \
26 | pip install -e .[flash-attn]; \
27 | fi
28 |
29 | # fix so that git fetch/pull from remote works
30 | RUN git config remote.origin.fetch "+refs/heads/*:refs/remotes/origin/*" && \
31 | git config --get remote.origin.fetch
32 |
33 | # helper for huggingface-login cli
34 | RUN git config --global credential.helper store
35 |
--------------------------------------------------------------------------------
/training/axolotl/src/axolotl/prompt_strategies/sharegpt_jokes.py:
--------------------------------------------------------------------------------
1 | """Module for Jokes prompts using sharegpt style """
2 | from axolotl.prompt_tokenizers import ShareGPTPromptTokenizingStrategy
3 | from axolotl.prompters import ShareGPTPrompterV2
4 |
5 |
6 | def load(tokenizer, cfg):
7 | return SimpleJokesShareGPTPromptTokenizingStrategy(
8 | ShareGPTPrompterV2(),
9 | tokenizer,
10 | cfg.train_on_inputs,
11 | cfg.sequence_len,
12 | )
13 |
14 |
15 | class SimpleJokesShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
16 | """
17 | Tokenization strategy for asking the bot to tell a joke and then explain why it's funny
18 | """
19 |
20 | # title, text, explanation
21 | def get_conversation_thread(self, prompt):
22 | title = "" if not prompt["title"] else prompt["title"] + " "
23 | return [
24 | {"from": "human", "value": "Tell me a joke."},
25 | {"from": "gpt", "value": title + prompt["text"]},
26 | {"from": "human", "value": "Why is that joke funny?"},
27 | {"from": "gpt", "value": prompt["explanation"]},
28 | ]
29 |
--------------------------------------------------------------------------------
/training/axolotl/deepspeed/zero2.json:
--------------------------------------------------------------------------------
1 | {
2 | "zero_optimization": {
3 | "stage": 2,
4 | "offload_optimizer": {
5 | "device": "cpu"
6 | },
7 | "contiguous_gradients": true,
8 | "overlap_comm": true
9 | },
10 | "bf16": {
11 | "enabled": "auto"
12 | },
13 | "fp16": {
14 | "enabled": "auto",
15 | "auto_cast": false,
16 | "loss_scale": 0,
17 | "initial_scale_power": 32,
18 | "loss_scale_window": 1000,
19 | "hysteresis": 2,
20 | "min_loss_scale": 1
21 | },
22 | "optimizer": {
23 | "type": "AdamW",
24 | "params": {
25 | "lr": "auto",
26 | "betas": "auto",
27 | "eps": "auto",
28 | "weight_decay": "auto"
29 | }
30 | },
31 | "scheduler": {
32 | "type": "WarmupDecayLR",
33 | "params": {
34 | "warmup_min_lr": "auto",
35 | "warmup_max_lr": "auto",
36 | "warmup_num_steps": "auto",
37 | "warmup_type": "linear",
38 | "total_num_steps": "auto"
39 | }
40 | },
41 | "gradient_accumulation_steps": "auto",
42 | "train_batch_size": "auto",
43 | "train_micro_batch_size_per_gpu": "auto",
44 | "wall_clock_breakdown": false
45 | }
46 |
--------------------------------------------------------------------------------
/training/axolotl/docs/multipack.md:
--------------------------------------------------------------------------------
1 | # Multipack
2 |
3 | 4k context, bsz = 4.
4 | Each character represents 256 tokens.
5 | X represents a padding token.
6 |
7 | ```
8 | 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5
9 | [[ A A A A A A A A A A A ]
10 | B B B B B B ]
11 | C C C C C C C ]
12 | D D D D ]]
13 |
14 | [[ E E E E E E E E ]
15 | [ F F F F ]
16 | [ G G G ]
17 | [ H H H H ]]
18 |
19 | [[ I I I ]
20 | [ J J J ]
21 | [ K K K K K]
22 | [ L L L ]]
23 | ```
24 |
25 | After padding to the longest input in each step:
26 | ```
27 | 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5
28 | [[ A A A A A A A A A A A ]
29 | B B B B B B X X X X X X ]
30 | C C C C C C C X X X X ]
31 | D D D D X X X X X X X ]]
32 |
33 | [[ E E E E E E E E ]
34 | [ F F F F X X X X ]
35 | [ G G G X X X X X ]
36 | [ H H H H X X X X ]]
37 |
38 | [[ I I I X X ]
39 | [ J J J X X ]
40 | [ K K K K K ]
41 | [ L L L X X ]]
42 | ```
43 |
44 | With packing (note it's the same effective number of tokens per step, but a true bsz of 1):
45 | ```
46 | 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5
47 | [[ A A A A A A A A A A A B B B B B
48 | B C C C C C C C D D D D E E E E
49 | E E E E F F F F F G G G H H H H
50 | I I I J J J J K K K K K L L L X ]]
51 | ```
52 |
--------------------------------------------------------------------------------
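As a toy illustration of the idea (not axolotl's actual multipack sampler), the sketch below greedily packs the example sequence lengths into 16-slot contexts, where each slot stands for 256 tokens:

```python
def pack(lengths, capacity=16):
    """Greedy first-fit: place each sequence into the first bin with room for it."""
    bins = []
    for name, length in lengths:
        for bin_ in bins:
            if sum(l for _, l in bin_) + length <= capacity:
                bin_.append((name, length))
                break
        else:
            bins.append([(name, length)])
    return bins


# Sequence lengths from the diagrams above, in 256-token units.
lengths = [("A", 11), ("B", 6), ("C", 7), ("D", 4), ("E", 8), ("F", 4),
           ("G", 3), ("H", 4), ("I", 3), ("J", 3), ("K", 5), ("L", 3)]

for packed in pack(lengths):
    used = sum(l for _, l in packed)
    print(packed, f"-> {used}/16 slots used, {16 - used} padding")
```

Packing wastes far fewer slots on padding than padding each sequence to the longest item in its batch, which is the whole point of multipack.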
/training/axolotl/src/axolotl/cli/shard.py:
--------------------------------------------------------------------------------
1 | """
2 | CLI to shard a trained model into 10GiB chunks
3 | """
4 | import logging
5 | from pathlib import Path
6 |
7 | import fire
8 | import transformers
9 |
10 | from axolotl.cli import load_cfg, print_axolotl_text_art
11 | from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer
12 | from axolotl.utils.dict import DictDefault
13 |
14 | LOG = logging.getLogger("axolotl.scripts")
15 |
16 |
17 | def shard(
18 | *,
19 | cfg: DictDefault,
20 | cli_args: TrainerCliArgs,
21 | ):
22 | model, _ = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args)
23 | safe_serialization = cfg.save_safetensors is True
24 | LOG.debug("Re-saving model w/ sharding")
25 | model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
26 |
27 |
28 | def do_cli(config: Path = Path("examples/"), **kwargs):
29 | # pylint: disable=duplicate-code
30 | print_axolotl_text_art()
31 | parsed_cfg = load_cfg(config, **kwargs)
32 | parser = transformers.HfArgumentParser((TrainerCliArgs))
33 | parsed_cli_args, _ = parser.parse_args_into_dataclasses(
34 | return_remaining_strings=True
35 | )
36 | parsed_cli_args.shard = True
37 |
38 | shard(cfg=parsed_cfg, cli_args=parsed_cli_args)
39 |
40 |
41 | if __name__ == "__main__":
42 | fire.Fire(do_cli)
43 |
--------------------------------------------------------------------------------
/training/axolotl/src/axolotl/common/cli.py:
--------------------------------------------------------------------------------
1 | """
2 | shared module for cli specific things
3 | """
4 |
5 | import logging
6 | from dataclasses import dataclass, field
7 | from typing import Optional
8 |
9 | from axolotl.logging_config import configure_logging
10 | from axolotl.utils.dict import DictDefault
11 | from axolotl.utils.models import load_model, load_tokenizer
12 |
13 | configure_logging()
14 | LOG = logging.getLogger("axolotl.common.cli")
15 |
16 |
17 | @dataclass
18 | class TrainerCliArgs:
19 | """
20 | dataclass representing the various non-training arguments
21 | """
22 |
23 | debug: bool = field(default=False)
24 | debug_text_only: bool = field(default=False)
25 | debug_num_examples: int = field(default=1)
26 | inference: bool = field(default=False)
27 | merge_lora: bool = field(default=False)
28 | prepare_ds_only: bool = field(default=False)
29 | prompter: Optional[str] = field(default=None)
30 | shard: bool = field(default=False)
31 |
32 |
33 | def load_model_and_tokenizer(
34 | *,
35 | cfg: DictDefault,
36 | cli_args: TrainerCliArgs,
37 | ):
38 | LOG.info(f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}")
39 | tokenizer = load_tokenizer(cfg)
40 | LOG.info("loading model and (optionally) peft_config...")
41 | model, _ = load_model(cfg, tokenizer, inference=cli_args.inference)
42 |
43 | return model, tokenizer
44 |
--------------------------------------------------------------------------------
/training/axolotl/deepspeed/zero3.json:
--------------------------------------------------------------------------------
1 | {
2 | "zero_optimization": {
3 | "stage": 3,
4 | "offload_optimizer": {
5 | "device": "cpu",
6 | "pin_memory": true
7 | },
8 | "offload_param": {
9 | "device": "cpu",
10 | "pin_memory": true
11 | },
12 | "overlap_comm": true,
13 | "contiguous_gradients": true,
14 | "sub_group_size": 0,
15 | "reduce_bucket_size": "auto",
16 | "stage3_prefetch_bucket_size": "auto",
17 | "stage3_param_persistence_threshold": "auto",
18 | "stage3_max_live_parameters": 0,
19 | "stage3_max_reuse_distance": 0,
20 | "stage3_gather_16bit_weights_on_model_save": true
21 | },
22 | "bf16": {
23 | "enabled": "auto"
24 | },
25 | "fp16": {
26 | "enabled": "auto",
27 | "auto_cast": false,
28 | "loss_scale": 0,
29 | "initial_scale_power": 32,
30 | "loss_scale_window": 1000,
31 | "hysteresis": 2,
32 | "min_loss_scale": 1
33 | },
34 | "optimizer": {
35 | "type": "AdamW",
36 | "params": {
37 | "lr": "auto",
38 | "betas": "auto",
39 | "eps": "auto",
40 | "weight_decay": "auto"
41 | }
42 | },
43 | "scheduler": {
44 | "type": "WarmupLR",
45 | "params": {
46 | "warmup_min_lr": "auto",
47 | "warmup_max_lr": "auto",
48 | "warmup_num_steps": "auto",
49 | "warmup_type": "linear"
50 | }
51 | },
52 | "gradient_accumulation_steps": "auto",
53 | "train_batch_size": "auto",
54 | "train_micro_batch_size_per_gpu": "auto",
55 | "wall_clock_breakdown": false
56 | }
57 |
--------------------------------------------------------------------------------
/training/axolotl/src/axolotl/monkeypatch/llama_embeddings_hijack.py:
--------------------------------------------------------------------------------
1 | """
2 | patch to add noisy embeddings per https://arxiv.org/abs/2310.05914
3 | """
4 |
5 | import torch
6 | import transformers.models.llama.modeling_llama
7 | from transformers.utils import logging
8 |
9 | logger = logging.get_logger(__name__)
10 |
11 |
12 | def replace_llama_embeddings_with_uniform_distribution(noise_alpha=5):
13 | # pylint: disable=duplicate-code
14 | def noised_embed(orig_embed, noise_alpha, model):
15 | def new_func(input_ids):
16 | # during training, we add noise to the embedding
17 | # during generation, we don't add noise to the embedding
18 | if model.training:
19 | embed_init = orig_embed(input_ids)
20 | dims = torch.tensor(embed_init.size(1) * embed_init.size(2))
21 | mag_norm = noise_alpha / torch.sqrt(dims)
22 | return embed_init + torch.zeros_like(embed_init).uniform_(
23 | -mag_norm, mag_norm
24 | )
25 | return orig_embed(input_ids)
26 |
27 | return new_func
28 |
29 | def post_init(orig_post_init):
30 | def new_func(self):
31 | orig_post_init(self)
32 | self.embed_tokens.forward = noised_embed(
33 | self.embed_tokens.forward, noise_alpha, self
34 | )
35 |
36 | return new_func
37 |
38 | transformers.models.llama.modeling_llama.LlamaModel.post_init = post_init(
39 | transformers.models.llama.modeling_llama.LlamaModel.post_init
40 | )
41 |
--------------------------------------------------------------------------------
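The patch implements NEFTune-style noisy embeddings: during training, uniform noise with magnitude `noise_alpha / sqrt(seq_len * hidden_dim)` is added to the embedding output. A standalone sketch of just the noise computation on a dummy tensor (illustrative shapes, not tied to any model):

```python
import torch

noise_alpha = 5
embed = torch.zeros(1, 8, 16)  # (batch, seq_len=8, hidden_dim=16) dummy embedding output

# dims = seq_len * hidden_dim, exactly as in noised_embed above
dims = torch.tensor(embed.size(1) * embed.size(2))
mag_norm = (noise_alpha / torch.sqrt(dims)).item()  # 5 / sqrt(128) ≈ 0.442

noised = embed + torch.zeros_like(embed).uniform_(-mag_norm, mag_norm)
print(round(mag_norm, 3))
print(bool(noised.abs().max() <= mag_norm))  # True: the added noise is bounded by mag_norm
```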
/training/axolotl/src/axolotl/monkeypatch/mistral_embeddings_hijack.py:
--------------------------------------------------------------------------------
1 | """
2 | patch to add noisy embeddings per https://arxiv.org/abs/2310.05914
3 | """
4 |
5 | import torch
6 | import transformers.models.mistral.modeling_mistral
7 | from transformers.utils import logging
8 |
9 | logger = logging.get_logger(__name__)
10 |
11 |
12 | def replace_mistral_embeddings_with_uniform_distribution(noise_alpha=5):
13 | # pylint: disable=duplicate-code
14 | def noised_embed(orig_embed, noise_alpha, model):
15 | def new_func(input_ids):
16 | # during training, we add noise to the embedding
17 | # during generation, we don't add noise to the embedding
18 | if model.training:
19 | embed_init = orig_embed(input_ids)
20 | dims = torch.tensor(embed_init.size(1) * embed_init.size(2))
21 | mag_norm = noise_alpha / torch.sqrt(dims)
22 | return embed_init + torch.zeros_like(embed_init).uniform_(
23 | -mag_norm, mag_norm
24 | )
25 | return orig_embed(input_ids)
26 |
27 | return new_func
28 |
29 | def post_init(orig_post_init):
30 | def new_func(self):
31 | orig_post_init(self)
32 | self.embed_tokens.forward = noised_embed(
33 | self.embed_tokens.forward, noise_alpha, self
34 | )
35 |
36 | return new_func
37 |
38 | transformers.models.mistral.modeling_mistral.MistralModel.post_init = post_init(
39 | transformers.models.mistral.modeling_mistral.MistralModel.post_init
40 | )
41 |
--------------------------------------------------------------------------------
/training/axolotl/src/axolotl/utils/tokenization.py:
--------------------------------------------------------------------------------
1 | """Module for tokenization utilities"""
2 |
3 |
4 | import logging
5 |
6 | from termcolor import colored
7 |
8 | LOG = logging.getLogger("axolotl")
9 |
10 |
11 | def check_dataset_labels(dataset, tokenizer, num_examples=5, text_only=False):
12 | # the dataset is already shuffled, so just check the first num_examples elements
13 | for idx in range(num_examples):
14 | check_example_labels(dataset[idx], tokenizer, text_only=text_only)
15 |
16 |
17 | def check_example_labels(example, tokenizer, text_only=False):
18 | # Get the input_ids, labels, and attention_mask from the dataset
19 | input_ids = example["input_ids"]
20 | labels = example["labels"]
21 |
22 | # You can compare the input_ids and labels element-wise
23 | # Remember to ignore positions with IGNORE_TOKEN_ID (if you use it) or attention_mask equal to 0
24 | colored_tokens = []
25 | for _, (input_id, label_id) in enumerate(zip(input_ids, labels)):
26 | decoded_input_token = tokenizer.decode(input_id)
27 | # Choose the color based on whether the label has the ignore value or not
28 | color = "red" if label_id == -100 else ("yellow" if label_id == 0 else "green")
29 | colored_token = colored(decoded_input_token, color) + (
30 | not text_only and colored(f"({label_id}, {input_id})", "white") or ""
31 | )
32 | colored_tokens.append(colored_token)
33 |
34 | delimiter = "" if text_only else " "
35 | LOG.info(delimiter.join(colored_tokens))
36 | LOG.info("\n\n\n")
37 |
38 | return " ".join(colored_tokens)
39 |
--------------------------------------------------------------------------------
/training/axolotl/src/axolotl/cli/train.py:
--------------------------------------------------------------------------------
1 | """
2 | CLI to run training on a model
3 | """
4 | import logging
5 | from pathlib import Path
6 |
7 | import fire
8 | import transformers
9 | from colorama import Fore
10 |
11 | from axolotl.cli import (
12 | check_accelerate_default_config,
13 | check_user_token,
14 | load_cfg,
15 | load_datasets,
16 | print_axolotl_text_art,
17 | )
18 | from axolotl.common.cli import TrainerCliArgs
19 | from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
20 | from axolotl.train import train
21 |
22 | LOG = logging.getLogger("axolotl.cli.train")
23 |
24 |
25 | def do_cli(config: Path = Path("examples/"), **kwargs):
26 | # pylint: disable=duplicate-code
27 | print_axolotl_text_art()
28 | parsed_cfg = load_cfg(config, **kwargs)
29 | check_accelerate_default_config()
30 | check_user_token()
31 | parser = transformers.HfArgumentParser((TrainerCliArgs))
32 | parsed_cli_args, _ = parser.parse_args_into_dataclasses(
33 | return_remaining_strings=True
34 | )
35 | if parsed_cli_args.prepare_ds_only and not parsed_cfg.dataset_prepared_path:
36 | msg = (
37 | Fore.RED
38 | + "--prepare_ds_only called without dataset_prepared_path set."
39 | + Fore.RESET
40 | )
41 | LOG.warning(msg)
42 | parsed_cfg.dataset_prepared_path = DEFAULT_DATASET_PREPARED_PATH
43 |
44 | dataset_meta = load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
45 | if parsed_cli_args.prepare_ds_only:
46 | return
47 | train(cfg=parsed_cfg, cli_args=parsed_cli_args, dataset_meta=dataset_meta)
48 |
49 |
50 | if __name__ == "__main__":
51 | fire.Fire(do_cli)
52 |
--------------------------------------------------------------------------------
/training/axolotl/src/axolotl/prompt_strategies/orcamini.py:
--------------------------------------------------------------------------------
1 | """
2 | Prompt Strategy for finetuning Orca Mini (v2) models
3 | see also https://huggingface.co/psmathur/orca_mini_v2_7b for more information
4 |
5 | Use dataset type: orcamini in config.yml to use this prompt style.
6 |
7 | Compared to the alpaca_w_system.open_orca dataset type,
8 | this one specifies the system prompt with "### System:".
9 |
10 | Not suited/tested for multiple-turn conversations without further adjustments.
11 | """
12 | from typing import Generator, Union
13 |
14 | from axolotl.prompt_strategies.alpaca_w_system import OpenOrcaPromptTokenizingStrategy
15 | from axolotl.prompters import AlpacaPrompter
16 |
17 |
18 | class OrcaMiniPrompter(AlpacaPrompter):
19 | """Adjusted Prompter for Orca Mini (v2) datasets"""
20 |
21 | def match_prompt_style(self):
22 | self.turn_no_input_format = (
23 | "### System:\n{system}\n\n### User:\n{instruction}\n\n### Response:\n"
24 | )
25 |
26 | def build_prompt_w_system(
27 | self,
28 | system: str,
29 | instruction: str,
30 | output: Union[None, str] = None,
31 | ) -> Generator[str, None, None]:
32 | # returns the full prompt from instruction and optional input
33 | # if a label (=response, =output) is provided, it's also appended.
34 | res = self.turn_no_input_format.format(system=system, instruction=instruction)
35 | if output:
36 | res = f"{res}{output}"
37 | yield res
38 |
39 |
40 | def load(tokenizer, cfg):
41 | return OpenOrcaPromptTokenizingStrategy(
42 | OrcaMiniPrompter(),
43 | tokenizer,
44 | cfg.train_on_inputs,
45 | cfg.sequence_len,
46 | )
47 |
--------------------------------------------------------------------------------
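For reference, the template above renders a single-turn example as follows; this is a sketch using the format string directly, with placeholder system/instruction text:

```python
turn_no_input_format = (
    "### System:\n{system}\n\n### User:\n{instruction}\n\n### Response:\n"
)

prompt = turn_no_input_format.format(
    system="You are a helpful assistant.",
    instruction="List three prime numbers.",
)
# During training, the target output is appended directly after "### Response:\n".
print(prompt + "2, 3, 5")
```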
/training/axolotl/docs/multi-node.md:
--------------------------------------------------------------------------------
1 | # Multi Node
2 |
3 | You will need to create a configuration for accelerate, either by running `accelerate config` and following the instructions, or by using one of the presets below:
4 |
5 | ~/.cache/huggingface/accelerate/default_config.yaml
6 | ```yaml
7 | compute_environment: LOCAL_MACHINE
8 | debug: false
9 | distributed_type: FSDP
10 | downcast_bf16: 'no'
11 | machine_rank: 0 # Set to 0 for the main machine, increment by one for other machines
12 | main_process_ip: 10.0.0.4 # Set to main machine's IP
13 | main_process_port: 5000
14 | main_training_function: main
15 | mixed_precision: bf16
16 | num_machines: 2 # Change to the number of machines
17 | num_processes: 4 # That's the total number of GPUs, (for example: if you have 2 machines with 4 GPU, put 8)
18 | rdzv_backend: static
19 | same_network: true
20 | tpu_env: []
21 | tpu_use_cluster: false
22 | tpu_use_sudo: false
23 | use_cpu: false
24 | ```
25 |
26 | Configure your model to use FSDP, for example:
27 | ```yaml
28 | fsdp:
29 | - full_shard
30 | - auto_wrap
31 | fsdp_config:
32 | fsdp_offload_params: true
33 | fsdp_state_dict_type: FULL_STATE_DICT
34 | fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
35 | ```
36 |
37 | ## Machine configuration
38 |
39 | On each machine you need a copy of Axolotl; we suggest using the same commit to ensure compatibility.
40 |
41 | You will also need to have the same configuration file for your model on each machine.
42 |
43 | On the main machine only, make sure the port you set as `main_process_port` is open in TCP and reachable by other machines.
44 |
45 | Now launch with accelerate on each machine as you usually would; training will start once accelerate has been launched on every machine.
46 |
--------------------------------------------------------------------------------
/training/axolotl/scripts/finetune.py:
--------------------------------------------------------------------------------
1 | """Prepare and train a model on a dataset. Can also infer from a model or merge lora"""
2 | import logging
3 | from pathlib import Path
4 |
5 | import fire
6 | import transformers
7 |
8 | from axolotl.cli import (
9 | check_accelerate_default_config,
10 | check_user_token,
11 | do_inference,
12 | do_merge_lora,
13 | load_cfg,
14 | load_datasets,
15 | print_axolotl_text_art,
16 | )
17 | from axolotl.cli.shard import shard
18 | from axolotl.common.cli import TrainerCliArgs
19 | from axolotl.train import train
20 |
21 | LOG = logging.getLogger("axolotl.scripts.finetune")
22 |
23 |
24 | def do_cli(config: Path = Path("examples/"), **kwargs):
25 | print_axolotl_text_art()
26 | LOG.warning(
27 | str(
28 | PendingDeprecationWarning(
29 | "scripts/finetune.py will be replaced with calling axolotl.cli.train"
30 | )
31 | )
32 | )
33 | parsed_cfg = load_cfg(config, **kwargs)
34 | check_accelerate_default_config()
35 | check_user_token()
36 | parser = transformers.HfArgumentParser((TrainerCliArgs))
37 | parsed_cli_args, _ = parser.parse_args_into_dataclasses(
38 | return_remaining_strings=True
39 | )
40 | if parsed_cli_args.inference:
41 | do_inference(cfg=parsed_cfg, cli_args=parsed_cli_args)
42 | elif parsed_cli_args.merge_lora:
43 | do_merge_lora(cfg=parsed_cfg, cli_args=parsed_cli_args)
44 | elif parsed_cli_args.shard:
45 | shard(cfg=parsed_cfg, cli_args=parsed_cli_args)
46 | else:
47 | dataset_meta = load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
48 | if parsed_cli_args.prepare_ds_only:
49 | return
50 | train(cfg=parsed_cfg, cli_args=parsed_cli_args, dataset_meta=dataset_meta)
51 |
52 |
53 | if __name__ == "__main__":
54 | fire.Fire(do_cli)
55 |
--------------------------------------------------------------------------------
/training/axolotl/setup.py:
--------------------------------------------------------------------------------
1 | """setup.py for axolotl"""
2 |
3 | from setuptools import find_packages, setup
4 |
5 |
6 | def parse_requirements():
7 | _install_requires = []
8 | _dependency_links = []
9 | with open("./requirements.txt", encoding="utf-8") as requirements_file:
10 | lines = [r.strip() for r in requirements_file.readlines()]
11 | for line in lines:
12 | if line.startswith("--extra-index-url"):
13 | # Handle custom index URLs
14 | _, url = line.split()
15 | _dependency_links.append(url)
16 | elif (
17 | "flash-attn" not in line
18 | and "deepspeed" not in line
19 | and line
20 | and line[0] != "#"
21 | ):
22 | # Handle standard packages
23 | _install_requires.append(line)
24 |
25 | # TODO(wing) remove once xformers release supports torch 2.1.0
26 | if "torch==2.1.0" in _install_requires:
27 | _install_requires.pop(_install_requires.index("xformers>=0.0.22"))
28 | _install_requires.append(
29 | "xformers @ git+https://github.com/facebookresearch/xformers.git@main"
30 | )
31 |
32 | return _install_requires, _dependency_links
33 |
34 |
35 | install_requires, dependency_links = parse_requirements()
36 |
37 |
38 | setup(
39 | name="axolotl",
40 | version="0.3.0",
41 | description="LLM Trainer",
42 | long_description="Axolotl is a tool designed to streamline the fine-tuning of various AI models, offering support for multiple configurations and architectures.",
43 | package_dir={"": "src"},
44 | packages=find_packages(),
45 | install_requires=install_requires,
46 | dependency_links=dependency_links,
47 | extras_require={
48 | "flash-attn": [
49 | "flash-attn>=2.3.0",
50 | ],
51 | "deepspeed": [
52 | "deepspeed",
53 | ],
54 | },
55 | )
56 |
--------------------------------------------------------------------------------
/training/axolotl/src/axolotl/monkeypatch/llama_expand_mask.py:
--------------------------------------------------------------------------------
1 | """
2 | expands the binary attention mask per 3.2.2 of https://arxiv.org/pdf/2107.02027.pdf
3 | """
4 | from typing import Optional
5 |
6 | import torch
7 |
8 |
9 | def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
10 | """
11 | Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
12 | This expansion handles packed sequences so that sequences share the same attention mask integer value
13 | when they attend to each other within that sequence.
14 | This expansion transforms the mask to lower triangular form to prevent future peeking.
15 | """
16 | bsz, src_len = mask.size()
17 | tgt_len = tgt_len if tgt_len is not None else src_len
18 |
19 | mask = mask.unsqueeze(1).unsqueeze(2)
20 | mask = mask.expand(bsz, 1, tgt_len, src_len)
21 |
22 | # Create a binary mask from the original mask where zeros remain zeros and all other values are set to one
23 | binary_mask = torch.where(
24 | mask != 0,
25 | torch.tensor(1).to(dtype),
26 | torch.tensor(0).to(dtype),
27 | )
28 |
29 | # Create a block-diagonal mask.
30 | # we multiply by the binary mask so that 0's in the original mask are correctly excluded
31 | zero_one_mask = torch.eq(mask, mask.transpose(-1, -2)).int() * binary_mask
32 |
33 | # Now let's create a lower triangular mask of ones that will zero out the upper triangular part
34 | lower_triangular_ones = torch.tril(torch.ones((tgt_len, src_len), dtype=dtype)).to(
35 | mask.device
36 | )
37 |
38 | # Use the lower triangular mask to zero out the upper triangular part of the zero_one_mask
39 | masked_zero_one_mask = zero_one_mask * lower_triangular_ones
40 | inverted_mask = 1.0 - masked_zero_one_mask
41 |
42 | return inverted_mask.masked_fill(
43 | inverted_mask.to(torch.bool), torch.finfo(dtype).min
44 | )
45 |
46 |
47 | def hijack_expand_mask():
48 | import transformers
49 |
50 | transformers.models.llama.modeling_llama._expand_mask = ( # pylint: disable=protected-access
51 | _expand_mask
52 | )
53 |
--------------------------------------------------------------------------------
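To see what the patched `_expand_mask` produces for a packed example, consider a toy mask where two sequences share one packed row. A small sketch (assumes axolotl and torch are importable); attended positions come out as 0 and blocked positions as the dtype minimum:

```python
import torch

from axolotl.monkeypatch.llama_expand_mask import _expand_mask

# One packed row: positions 0-1 belong to sequence 1, positions 2-3 to sequence 2
# (a value of 0 would mark real padding).
mask = torch.tensor([[1, 1, 2, 2]])

expanded = _expand_mask(mask, torch.float32)
print(expanded.shape)  # torch.Size([1, 1, 4, 4])

# True where a token may attend (same sequence id, causal order), False elsewhere,
# so tokens of sequence 2 never attend to sequence 1 and vice versa.
print(expanded[0, 0] == 0)
```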
/training/axolotl/src/axolotl/logging_config.py:
--------------------------------------------------------------------------------
1 | """
2 | Common logging module for axolotl
3 | """
4 |
5 | import os
6 | import sys
7 | from logging import Formatter
8 | from logging.config import dictConfig
9 | from typing import Any, Dict
10 |
11 | from colorama import Fore, Style, init
12 |
13 |
14 | class ColorfulFormatter(Formatter):
15 | """
16 | Formatter to add coloring to log messages by log type
17 | """
18 |
19 | COLORS = {
20 | "WARNING": Fore.YELLOW,
21 | "ERROR": Fore.RED,
22 | "CRITICAL": Fore.RED + Style.BRIGHT,
23 | }
24 |
25 | def format(self, record):
26 | record.rank = int(os.getenv("LOCAL_RANK", "0"))
27 | log_message = super().format(record)
28 | return self.COLORS.get(record.levelname, "") + log_message + Fore.RESET
29 |
30 |
31 | DEFAULT_LOGGING_CONFIG: Dict[str, Any] = {
32 | "version": 1,
33 | "formatters": {
34 | "simple": {
35 | "format": "[%(asctime)s] [%(levelname)s] [%(name)s.%(funcName)s:%(lineno)d] [PID:%(process)d] %(message)s",
36 | },
37 | "colorful": {
38 | "()": ColorfulFormatter,
39 | "format": "[%(asctime)s] [%(levelname)s] [%(name)s.%(funcName)s:%(lineno)d] [PID:%(process)d] [RANK:%(rank)d] %(message)s",
40 | },
41 | },
42 | "filters": {},
43 | "handlers": {
44 | "console": {
45 | "class": "logging.StreamHandler",
46 | "formatter": "simple",
47 | "filters": [],
48 | "stream": sys.stdout,
49 | },
50 | "color_console": {
51 | "class": "logging.StreamHandler",
52 | "formatter": "colorful",
53 | "filters": [],
54 | "stream": sys.stdout,
55 | },
56 | },
57 | "root": {"handlers": ["console"], "level": os.getenv("LOG_LEVEL", "INFO")},
58 | "loggers": {
59 | "axolotl": {
60 | "handlers": ["color_console"],
61 | "level": "DEBUG",
62 | "propagate": False,
63 | },
64 | },
65 | }
66 |
67 |
68 | def configure_logging():
69 | """Configure with default logging"""
70 | init() # Initialize colorama
71 | dictConfig(DEFAULT_LOGGING_CONFIG)
72 |
--------------------------------------------------------------------------------
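A quick sketch of how this config is used in practice (it mirrors what `axolotl.common.cli` does at import time; the logger name below is just an example):

```python
import logging

from axolotl.logging_config import configure_logging

configure_logging()

LOG = logging.getLogger("axolotl.example")  # child of the "axolotl" logger
LOG.debug("goes to the color_console handler; RANK is read from LOCAL_RANK (default 0)")
LOG.warning("rendered in yellow by ColorfulFormatter")

# Loggers outside the "axolotl" namespace fall through to the plain root handler,
# filtered at LOG_LEVEL (INFO by default).
logging.getLogger("other").info("plain, uncolored format")
```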
/training/axolotl/src/axolotl/convert.py:
--------------------------------------------------------------------------------
1 | """Module containing File Reader, File Writer, Json Parser, and Jsonl Serializer classes"""
2 |
3 |
4 | import json
5 | import sys
6 |
7 |
8 | class FileReader:
9 | """
10 | Reads a file and returns its contents as a string
11 | """
12 |
13 | def read(self, file_path):
14 | with open(file_path, encoding="utf-8") as file:
15 | return file.read()
16 |
17 |
18 | class FileWriter:
19 | """
20 | Writes a string to a file
21 | """
22 |
23 | def __init__(self, file_path):
24 | self.file_path = file_path
25 |
26 | def write(self, content):
27 | with open(self.file_path, "w", encoding="utf-8") as file:
28 | file.write(content)
29 |
30 |
31 | class StdoutWriter:
32 | """
33 | Writes a string to stdout
34 | """
35 |
36 | def write(self, content):
37 | sys.stdout.write(content)
38 | sys.stdout.write("\n")
39 |
40 |
41 | class JsonParser:
42 | """
43 | Parses a string as JSON and returns the result
44 | """
45 |
46 | def parse(self, content):
47 | return json.loads(content)
48 |
49 |
50 | class JsonlSerializer:
51 | """
52 | Serializes a list of JSON objects into a JSONL string
53 | """
54 |
55 | def serialize(self, data):
56 | lines = [json.dumps(item) for item in data]
57 | return "\n".join(lines)
58 |
59 |
60 | class JsonToJsonlConverter:
61 | """
62 | Converts a JSON file to JSONL
63 | """
64 |
65 | def __init__(self, file_reader, file_writer, json_parser, jsonl_serializer):
66 | self.file_reader = file_reader
67 | self.file_writer = file_writer
68 | self.json_parser = json_parser
69 | self.jsonl_serializer = jsonl_serializer
70 |
71 | def convert(
72 | self, input_file_path, output_file_path
73 | ): # pylint: disable=unused-argument
74 | content = self.file_reader.read(input_file_path)
75 | data = self.json_parser.parse(content)
76 | # data = [r for r in data if r["conversations"]] # vicuna cleaned has rows with empty conversations
77 | jsonl_content = self.jsonl_serializer.serialize(data)
78 | self.file_writer.write(jsonl_content)
79 |
--------------------------------------------------------------------------------
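The classes above compose into a small pipeline. A sketch of converting a JSON array file to JSONL on stdout (`data.json` is a placeholder path; assumes axolotl is importable):

```python
from axolotl.convert import (
    FileReader,
    JsonlSerializer,
    JsonParser,
    JsonToJsonlConverter,
    StdoutWriter,
)

converter = JsonToJsonlConverter(
    file_reader=FileReader(),
    file_writer=StdoutWriter(),  # or FileWriter("data.jsonl") to write to disk
    json_parser=JsonParser(),
    jsonl_serializer=JsonlSerializer(),
)

# Reads a JSON array from data.json and prints one JSON object per line.
converter.convert("data.json", output_file_path=None)
```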
/training/axolotl/src/axolotl/utils/bench.py:
--------------------------------------------------------------------------------
1 | """Benchmarking and measurement utilities"""
2 | import functools
3 |
4 | import pynvml
5 | import torch
6 | from pynvml.nvml import NVMLError
7 |
8 |
9 | def check_cuda_device(default_value):
10 | """
11 | wraps a function and returns the default value instead of running the
12 | wrapped function if cuda isn't available or the device is auto
13 | :param default_value:
14 | :return:
15 | """
16 |
17 | def deco(func):
18 | @functools.wraps(func)
19 | def wrapper(*args, **kwargs):
20 | device = kwargs.get("device", args[0] if args else None)
21 |
22 | if (
23 | not torch.cuda.is_available()
24 | or device == "auto"
25 | or torch.device(device).type == "cpu"
26 | ):
27 | return default_value
28 |
29 | return func(*args, **kwargs)
30 |
31 | return wrapper
32 |
33 | return deco
34 |
35 |
36 | @check_cuda_device(0.0)
37 | def gpu_memory_usage(device=0):
38 | return torch.cuda.memory_allocated(device) / 1024.0**3
39 |
40 |
41 | @check_cuda_device((0.0, 0.0, 0.0))
42 | def gpu_memory_usage_all(device=0):
43 | usage = torch.cuda.memory_allocated(device) / 1024.0**3
44 | reserved = torch.cuda.memory_reserved(device) / 1024.0**3
45 | smi = gpu_memory_usage_smi(device)
46 | return usage, reserved - usage, max(0, smi - reserved)
47 |
48 |
49 | @check_cuda_device(0.0)
50 | def gpu_memory_usage_smi(device=0):
51 | if isinstance(device, torch.device):
52 | device = device.index
53 | if isinstance(device, str) and device.startswith("cuda:"):
54 | device = int(device[5:])
55 | try:
56 | pynvml.nvmlInit()
57 | handle = pynvml.nvmlDeviceGetHandleByIndex(device)
58 | info = pynvml.nvmlDeviceGetMemoryInfo(handle)
59 | return info.used / 1024.0**3
60 | except NVMLError:
61 | return 0.0
62 |
63 |
64 | def log_gpu_memory_usage(log, msg, device):
65 | usage, cache, misc = gpu_memory_usage_all(device)
66 | extras = []
67 | if cache > 0:
68 | extras.append(f"+{cache:.03f}GB cache")
69 | if misc > 0:
70 | extras.append(f"+{misc:.03f}GB misc")
71 | log.info(
72 | f"GPU memory usage {msg}: {usage:.03f}GB ({', '.join(extras)})", stacklevel=2
73 | )
74 | return usage, cache, misc
75 |
--------------------------------------------------------------------------------
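Typical usage is to log memory snapshots at interesting points during training. A small sketch (thanks to the `check_cuda_device` decorator it simply returns zeros when CUDA is unavailable):

```python
import logging

from axolotl.utils.bench import gpu_memory_usage_all, log_gpu_memory_usage

logging.basicConfig(level=logging.INFO)
LOG = logging.getLogger("bench-example")

# Breakdown: allocated, reserved-but-unallocated cache, and other usage reported by NVML.
allocated, cache, misc = gpu_memory_usage_all(0)
print(allocated, cache, misc)

# Or log a formatted summary, e.g. right after loading the model.
log_gpu_memory_usage(LOG, "after model load", 0)
```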
/training/axolotl/src/axolotl/models/phi/configuration_mixformer_sequential.py:
--------------------------------------------------------------------------------
1 | # pylint: skip-file
2 |
3 | # Copyright (c) Microsoft Corporation.
4 | # Licensed under the MIT license.
5 |
6 | import math
7 | from typing import Any, Dict, List, Optional, Union
8 |
9 | from transformers import PretrainedConfig
10 |
11 |
12 | class MixFormerSequentialConfig(PretrainedConfig):
13 | """MixFormer (sequential for DeepSpeed) configuration."""
14 |
15 | model_type = "mixformer-sequential"
16 |
17 | attribute_map = {
18 | "max_position_embeddings": "n_positions",
19 | "hidden_size": "n_embd",
20 | "num_attention_heads": "n_head",
21 | "num_hidden_layers": "n_layer",
22 | "input_emb_layer": "embd_layer", # `input_emb_layer` key is for backward compatibility
23 | "blocks": "architecture", # `blocks` key is for backward compatibility
24 | }
25 |
26 | def __init__(
27 | self,
28 | vocab_size: Optional[int] = 50304,
29 | n_positions: Optional[int] = 2048,
30 | n_embd: Optional[int] = 1024,
31 | n_layer: Optional[int] = 20,
32 | n_inner: Optional[int] = None,
33 | n_head: Optional[int] = 16,
34 | rotary_dim: Optional[int] = 32,
35 | activation_function: Optional[str] = "gelu_new",
36 | embd_layer: Optional[str] = "default",
37 | architecture: Union[Dict[str, Any], List[Dict[str, Any]]] = None,
38 | embd_pdrop: Optional[float] = 0.0,
39 | resid_pdrop: Optional[float] = 0.0,
40 | layer_norm_epsilon: Optional[float] = 1e-5,
41 | initializer_range: Optional[float] = 0.02,
42 | tie_word_embeddings: Optional[bool] = False,
43 | pad_vocab_size_multiple: Optional[int] = 64,
44 | **kwargs
45 | ) -> None:
46 | self.vocab_size = int(
47 | math.ceil(vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple
48 | )
49 | self.n_positions = n_positions
50 | self.n_embd = n_embd
51 | self.n_layer = n_layer
52 | self.n_inner = n_inner
53 | self.n_head = n_head
54 | self.rotary_dim = min(rotary_dim, n_embd // n_head)
55 | self.activation_function = activation_function
56 | self.embd_layer = embd_layer
57 | self.architecture = architecture
58 | self.embd_pdrop = embd_pdrop
59 | self.resid_pdrop = resid_pdrop
60 | self.layer_norm_epsilon = layer_norm_epsilon
61 | self.initializer_range = initializer_range
62 |
63 | super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
64 |
--------------------------------------------------------------------------------
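One detail worth noting in the config above: the vocabulary size is rounded up to a multiple of `pad_vocab_size_multiple`, and `rotary_dim` is capped at the per-head dimension. A short sketch with mostly-default values (requires transformers to be installed):

```python
from axolotl.models.phi.configuration_mixformer_sequential import (
    MixFormerSequentialConfig,
)

config = MixFormerSequentialConfig(vocab_size=50295, n_embd=1024, n_head=16)

print(config.vocab_size)  # 50304 -> rounded up to the next multiple of 64
print(config.rotary_dim)  # 32    -> min(32, n_embd // n_head) with the defaults
print(config.model_type)  # mixformer-sequential
```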
/training/axolotl/src/axolotl/monkeypatch/btlm_attn_hijack_flash.py:
--------------------------------------------------------------------------------
1 | """
2 | Flash attention monkey patch for cerebras btlm model
3 | """
4 |
5 | import importlib
6 | import logging
7 | from typing import Optional, Tuple
8 |
9 | import torch
10 | from accelerate import init_empty_weights
11 | from flash_attn.flash_attn_interface import flash_attn_func
12 | from transformers import AutoConfig, AutoModelForCausalLM
13 |
14 | LOG = logging.getLogger("axolotl")
15 |
16 |
17 | def replace_btlm_attn_with_flash_attn(model_name="cerebras/btlm-3b-8k-base"):
18 | # this is a wonky hack to get the remotely loaded module
19 | model_config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
20 | # we need to load the model here in order for modeling_btlm to be available
21 | with init_empty_weights():
22 | AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)
23 | module_name = model_config.__class__.__module__.replace(
24 | ".configuration_btlm", ".modeling_btlm"
25 | )
26 | modeling_btlm = importlib.import_module(module_name)
27 | modeling_btlm.BTLMAttention._attn = ( # pylint: disable=protected-access
28 | flashattn_attn
29 | )
30 |
31 |
32 | def flashattn_attn(
33 | self,
34 | query: torch.Tensor,
35 | key: Optional[torch.Tensor] = None,
36 | value: Optional[torch.Tensor] = None,
37 | attention_mask: Optional[torch.Tensor] = None, # pylint: disable=unused-argument
38 | head_mask: Optional[torch.Tensor] = None,
39 | position_bias: Optional[torch.Tensor] = None, # pylint: disable=unused-argument
40 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
41 | softmax_scale = (
42 | 1 / (key.size(-1) ** self.attn_scale_power) if self.scale_attn_weights else None
43 | )
44 |
45 | query = query.permute(0, 2, 1, 3)
46 | key = key.permute(0, 2, 1, 3)
47 | value = value.permute(0, 2, 1, 3)
48 |
49 | # Perform Flash attention
50 | attn_output = flash_attn_func(
51 | query,
52 | key,
53 | value,
54 | dropout_p=0.0, # Assuming you have this attribute
55 | softmax_scale=softmax_scale, # Set this if you have specific scaling in mind
56 | causal=not self.is_cross_attention, # Assuming you have this attribute
57 | return_attn_probs=False, # Set this based on your needs
58 | )
59 |
60 | # Optional: Apply head mask if it's not None
61 | if head_mask is not None:
62 | attn_output *= head_mask
63 |
64 | attn_output = attn_output.permute(0, 2, 1, 3)
65 |
66 | return attn_output, None # We don't have explicit attn_weights in Flash attention
67 |
--------------------------------------------------------------------------------
/training/axolotl/examples/mistral/nips/nips_02.yml:
--------------------------------------------------------------------------------
1 | base_model: mistralai/Mistral-7B-v0.1
2 | base_model_config: mistralai/Mistral-7B-v0.1
3 | model_type: MistralForCausalLM
4 | tokenizer_type: LlamaTokenizer
5 | is_mistral_derived_model: true
6 |
7 | load_in_8bit: false
8 | load_in_4bit: true
9 | strict: false
10 |
11 | datasets:
12 | - path: upaya07/NeurIPS-LLM-data
13 | type: alpaca_chat.load_qa
14 |
15 | random_split: False
16 | instances_for_test_metric: 2000 # Used only if random_split = False. Limits the number of test samples used for evaluation.
17 | val_set_size: 0.1 # Used when random_split = True. Fraction of the train split held out for building test data.
18 |
19 | # ============ Change these paths =======================
20 | dataset_prepared_path: ./NeurIPS-LLM-data
21 | output_dir: ./qlora-NeurIPS-LLM-model
22 |
23 | wandb_project: Mistral-7B-ft-NeurIPS
24 | wandb_run_id: NeurIPS_model_01
25 |
26 | sequence_len: 8192
27 | sample_packing: true
28 | eval_sample_packing: true
29 | pad_to_sequence_len: true
30 |
31 |
32 | adapter: qlora
33 |
34 | lora_r: 128
35 | lora_alpha: 256
36 | lora_dropout: 0.05
37 | lora_target_linear: true
38 | lora_fan_in_fan_out:
39 | lora_target_modules:
40 | - q_proj
41 | - v_proj
42 | - k_proj
43 | - o_proj
44 | - gate_proj
45 | - down_proj
46 | - up_proj
47 | # - lm_head
48 |
49 | #lora_target_modules:
50 | # - gate_proj
51 | # - down_proj
52 | # - up_proj
53 | #lora_target_modules:
54 | # - q_proj
55 | # - v_proj
56 |
57 | noisy_embedding_alpha: 5
58 | gradient_accumulation_steps: 4
59 | micro_batch_size: 2
60 | num_epochs: 3
61 | #max_steps: 32
62 | optimizer: paged_adamw_32bit
63 | #optimizer: adamw_torch
64 | lr_scheduler: cosine
65 | learning_rate: 0.00002
66 |
67 | train_on_inputs: false
68 | group_by_length: false
69 | bf16: true
70 | fp16: false
71 | tf32: false
72 |
73 | bfloat16: true
74 |
75 | gradient_checkpointing: true
76 | early_stopping_patience: 5000000
77 | local_rank:
78 | logging_steps: 1
79 | xformers_attention:
80 | flash_attention: true
81 |
82 | warmup_steps: 100
83 | eval_steps: 64
84 | #eval_table_size: 4
85 | #eval_table_max_new_tokens: 128
86 | save_steps: 64
87 | save_total_limit: 10
88 | metric_for_best_model: "eval_loss"
89 | greater_is_better: False
90 | load_best_model_at_end: True
91 | debug: true
92 | debug_num_examples: 1
93 |
94 | weight_decay: 0.01
95 | #
96 | max_grad_norm: 0.3
97 |
98 | # deepspeed:
99 | fsdp:
100 | fsdp_config:
101 | special_tokens:
102 | bos_token: ""
103 | eos_token: ""
104 | unk_token: ""
105 |
106 | #seed: 57
107 |
--------------------------------------------------------------------------------
/training/axolotl/src/axolotl/prompt_strategies/metharme.py:
--------------------------------------------------------------------------------
1 | """Module containing the MetharmenPromptTokenizingStrategy and MetharmePrompter class"""
2 |
3 | import logging
4 | from typing import Tuple
5 |
6 | from axolotl.prompt_tokenizers import InstructionPromptTokenizingStrategy
7 | from axolotl.prompters import AlpacaPrompter
8 |
9 | LOG = logging.getLogger("axolotl")
10 |
11 | IGNORE_TOKEN_ID = -100
12 |
13 | # pylint: disable=duplicate-code
14 |
15 |
16 | class MetharmePromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
17 | """
18 | Tokenizing strategy for the Metharme models
19 | """
20 |
21 | def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
22 | return (prompt["prompt"], "", prompt["generation"])
23 |
24 | def _tokenize(
25 | self,
26 | prompt: str,
27 | add_eos_token: bool = True,
28 | strip_bos_token: bool = False,
29 | num_eos_tokens: int = 3,
30 | ):
31 | result = self.tokenizer(
32 | prompt,
33 | truncation=True,
34 | max_length=self.sequence_len,
35 | padding=False,
36 | return_tensors=None,
37 | )
38 | if len(result["input_ids"]) == 0:
39 | LOG.warning("Tokenizer result is empty. You may want to audit your dataset")
40 | # If there's already an EOS token there, subtract from the number added
41 |         if result["input_ids"] and result["input_ids"][-1] == self.tokenizer.eos_token_id:
42 | num_eos_tokens -= 1
43 |
44 | if num_eos_tokens > 0 and add_eos_token and len(result["input_ids"]) > 0:
45 | for _ in range(num_eos_tokens):
46 | if len(result["input_ids"]) < self.sequence_len:
47 | result["input_ids"].append(self.tokenizer.eos_token_id)
48 | result["attention_mask"].append(1)
49 |
50 |         if strip_bos_token and result["input_ids"] and result["input_ids"][0] == self.tokenizer.bos_token_id:
51 | result["input_ids"] = result["input_ids"][1:]
52 | result["attention_mask"] = result["attention_mask"][1:]
53 |
54 | result["labels"] = result["input_ids"].copy()
55 | return result
56 |
57 |
58 | class MetharmePrompter(AlpacaPrompter):
59 | """
60 | Prompter for the Metharme models.
61 | """
62 |
63 | system_prompt = ""
64 | system_no_input_prompt = ""
65 | system_format = ""
66 | turn_format = "{instruction}"
67 | turn_no_input_format = "{instruction}"
68 |
69 | def __init__(self, *args, **kwargs): # pylint: disable=super-init-not-called
70 | pass
71 |
72 |
73 | def load(tokenizer, cfg):
74 | return MetharmePromptTokenizingStrategy(
75 | MetharmePrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len
76 | )
77 |
--------------------------------------------------------------------------------
/data_prep/prepare_math_reasoning_dataset.py:
--------------------------------------------------------------------------------
1 | import json
2 | import random
3 | from datasets import load_dataset
4 | from difflib import SequenceMatcher
5 |
6 | lower_case_prob = 0.3  # probability of lower-casing a question
7 | 
8 | def modify_input(question):
9 |     # lower-case the question with probability lower_case_prob
10 |     if random.random() < lower_case_prob:
11 |         question = question.lower()
12 |     return question
13 |
14 | def remove_hash(answer: str):
15 | if "####" in answer:
16 | return answer[:answer.rindex("####")].strip()
17 | return answer
18 |
19 | def format_response(answer: str, answer_identifier: str):
20 | answer_prefix_len = len(answer_identifier)
21 | if answer_identifier in answer:
22 | answer_prefix_start_idx = answer.index(answer_identifier)
23 | reasoning = remove_hash(answer[:answer_prefix_start_idx].strip())
24 |
25 | # ==== Enable it if we want to add "answer" as part of output
26 | answer = answer[answer_prefix_start_idx:].strip()
27 | assert len(answer) > 0
28 | # answer = "Answer: " + answer
29 | return f"{reasoning}\n{answer.strip()}"
30 | else:
31 | return answer
32 |
33 | new_math_recs = []
34 | valid_records = 0
35 |
36 | math_instruct_dataset = load_dataset("TIGER-Lab/MathInstruct", "train")
37 | valid_sources = ['data/CoT/gsm_train.json', 'data/CoT/aqua_rat.json', 'data/CoT/MATH_train.json']
38 | print(f"math_instruct_dataset size: {len(math_instruct_dataset['train'])}")
39 | for each in math_instruct_dataset["train"]:
40 |
41 | if each['source'] in valid_sources:
42 | output = {}
43 | output['instruction'] = ""
44 | output['input'] = modify_input(each['instruction']).strip()
45 | output['output'] = format_response(each['output'], "The answer is:").strip()
46 |
47 | new_math_recs.append(output)
48 |
49 | valid_records += 1
50 |
51 | print(valid_records)
52 |
53 | instruction_prefix = "### Instruction:\n"
54 | input_prefix = "### Input:\n"
55 | response_prefix = "### Response:\n"
56 | new_math_dataset = []
57 | for each in new_math_recs:
58 | if len(each['input'].split(" ")) <= 4 or len(each['output'].split(" ")) < 1:
59 | continue
60 |
61 | rec = {}
62 | rec['question'] = ''
63 | if len(each['instruction'].strip()) > 0:
64 | rec['question'] += f"{instruction_prefix}{each['instruction']}\n\n"
65 | rec['question'] += f"{input_prefix}{each['input']}\n\n"
66 | rec['question'] += f"{response_prefix}"
67 | rec['answer'] = f"{each['output']}"
68 | new_math_dataset.append(rec)
69 |
70 | print(f"new_math_dataset size: {len(new_math_dataset)}")
71 | random.shuffle(new_math_dataset)
72 | sampled_dataset = new_math_dataset[:50000]
73 |
74 | with open('/Users/ajindal/Downloads/math_reasoning.json', 'w') as f:
75 | json.dump(sampled_dataset, f, indent=1)
76 |
--------------------------------------------------------------------------------
/training/axolotl/docker/Dockerfile-base:
--------------------------------------------------------------------------------
1 | ARG CUDA_VERSION="11.8.0"
2 | ARG CUDNN_VERSION="8"
3 | ARG UBUNTU_VERSION="22.04"
4 | ARG MAX_JOBS=4
5 |
6 | FROM nvidia/cuda:$CUDA_VERSION-cudnn$CUDNN_VERSION-devel-ubuntu$UBUNTU_VERSION as base-builder
7 |
8 | ENV PATH="/root/miniconda3/bin:${PATH}"
9 |
10 | ARG PYTHON_VERSION="3.9"
11 | ARG PYTORCH_VERSION="2.0.1"
12 | ARG CUDA="118"
13 |
14 | ENV PYTHON_VERSION=$PYTHON_VERSION
15 |
16 | RUN apt-get update \
17 | && apt-get install -y wget git build-essential ninja-build git-lfs libaio-dev && rm -rf /var/lib/apt/lists/* \
18 | && wget \
19 | https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh \
20 | && mkdir /root/.conda \
21 | && bash Miniconda3-latest-Linux-x86_64.sh -b \
22 | && rm -f Miniconda3-latest-Linux-x86_64.sh \
23 | && conda create -n "py${PYTHON_VERSION}" python="${PYTHON_VERSION}"
24 |
25 | ENV PATH="/root/miniconda3/envs/py${PYTHON_VERSION}/bin:${PATH}"
26 |
27 | WORKDIR /workspace
28 |
29 | RUN python3 -m pip install --upgrade pip && pip3 install packaging && \
30 | python3 -m pip install --no-cache-dir -U torch==${PYTORCH_VERSION}+cu${CUDA} --extra-index-url https://download.pytorch.org/whl/cu$CUDA
31 |
32 | FROM base-builder AS deepspeed-builder
33 |
34 | ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 9.0+PTX"
35 |
36 | WORKDIR /workspace
37 |
38 | RUN git clone https://github.com/microsoft/DeepSpeed.git && \
39 | cd DeepSpeed && \
40 | MAX_CONCURRENCY=8 DS_BUILD_SPARSE_ATTN=0 DS_BUILD_OPS=1 DS_BUILD_EVOFORMER_ATTN=0 python3 setup.py bdist_wheel
41 |
42 | FROM base-builder AS bnb-builder
43 |
44 | WORKDIR /workspace
45 | ARG CUDA="118"
46 | ENV CUDA=$CUDA
47 | ARG MAX_JOBS="-1"
48 | ENV MAX_JOBS=$MAX_JOBS
49 |
50 | RUN git clone https://github.com/TimDettmers/bitsandbytes.git && \
51 | cd bitsandbytes && \
52 | CUDA_VERSION=$CUDA make cuda11x && \
53 | python setup.py bdist_wheel
54 |
55 | FROM base-builder
56 |
57 | ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 9.0+PTX"
58 | ENV TORCH_CUDA_ARCH_LIST=$TORCH_CUDA_ARCH_LIST
59 |
60 | RUN mkdir -p /workspace/builds
61 | COPY --from=bnb-builder /workspace/bitsandbytes /workspace/builds/bitsandbytes
62 |
63 | RUN mkdir -p /workspace/wheels/bitsandbytes
64 | COPY --from=deepspeed-builder /workspace/DeepSpeed/dist/deepspeed-*.whl wheels
65 | COPY --from=bnb-builder /workspace/bitsandbytes/dist/bitsandbytes-*.whl wheels
66 | COPY --from=bnb-builder /workspace/bitsandbytes/bitsandbytes/libbitsandbytes*.so wheels/bitsandbytes
67 |
68 | RUN pip3 install wheels/deepspeed-*.whl
69 | RUN cd /workspace/builds/bitsandbytes && python3 setup.py install
70 | RUN git lfs install --skip-repo
71 | RUN pip3 install awscli && \
72 | # The base image ships with `pydantic==1.8.2` which is not working
73 | pip3 install -U --no-cache-dir pydantic==1.10.10
74 |
--------------------------------------------------------------------------------
/training/axolotl/src/axolotl/prompt_strategies/completion.py:
--------------------------------------------------------------------------------
1 | """
2 | Basic completion text
3 | """
4 | from collections import defaultdict
5 | from typing import Any, Dict, Generator, Optional, Tuple
6 |
7 | from axolotl.prompt_tokenizers import InstructionPromptTokenizingStrategy
8 |
9 |
10 | class CompletionPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
11 | """
12 | Tokenizing strategy for Completion prompts.
13 | """
14 |
15 | _field: str = "text"
16 |
17 | def __init__(self, *args, max_length=None, **kwargs):
18 | super().__init__(*args, **kwargs)
19 | if max_length is not None:
20 | self.max_length = max_length
21 |
22 | @property
23 | def supports_batched(self):
24 | return True
25 |
26 | @property
27 | def field(self) -> str:
28 | return self._field
29 |
30 | @field.setter
31 | def field(self, new_field: str):
32 | self._field = new_field
33 |
34 | def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
35 | return (
36 | prompt[self.field],
37 | "",
38 | "",
39 | )
40 |
41 | def tokenize_prompt(self, prompt):
42 | res = defaultdict(lambda: [])
43 | feature_names = list(prompt.keys())
44 | for row in zip(*prompt.values()):
45 | prompt_row = dict(zip(feature_names, row))
46 | (
47 | instruction,
48 | _,
49 | _,
50 | ) = self.parse_instruction_fields(prompt_row)
51 |
52 | full_prompt = self._build_full_prompt(instruction, None, None)
53 | tokenized_full_prompt = self._tokenize(full_prompt)
54 |
55 | for key, val in tokenized_full_prompt.items():
56 | for i in range(0, len(val), self.sequence_len):
57 | res[key].append(val[i : i + self.sequence_len])
58 |
59 | return dict(res)
60 |
61 | def _build_full_prompt(
62 | self, instruction, input, response
63 | ): # pylint: disable=redefined-builtin
64 | return next(iter(self.prompter.build_prompt(instruction, input, response)))
65 |
66 |
67 | class CompletionPrompter:
68 | """
69 | Prompter for completion
70 | """
71 |
72 | def build_prompt(
73 | self,
74 | instruction: str,
75 | input=None, # pylint: disable=redefined-builtin, unused-argument
76 | output=None, # pylint: disable=unused-argument
77 | ) -> Generator[str, None, None]:
78 | yield instruction
79 |
80 |
81 | def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
82 | strat = CompletionPromptTokenizingStrategy(
83 | CompletionPrompter(),
84 | tokenizer,
85 | cfg.train_on_inputs,
86 | cfg.sequence_len,
87 | max_length=cfg.sequence_len * 64,
88 | )
89 | if ds_cfg and "field" in ds_cfg:
90 | strat.field = ds_cfg["field"]
91 |
92 | return strat
93 |
--------------------------------------------------------------------------------
/training/axolotl/src/axolotl/prompt_strategies/user_defined.py:
--------------------------------------------------------------------------------
1 | """
2 | User Defined prompts with configuration from the YML config
3 | """
4 |
5 | from dataclasses import dataclass
6 | from functools import partial
7 | from typing import Optional, Tuple
8 |
9 | from axolotl.prompt_strategies.alpaca_w_system import (
10 | InstructionWSystemPromptTokenizingStrategy,
11 | SystemDataPrompter,
12 | )
13 |
14 |
15 | @dataclass
16 | class UserDefinedDatasetConfig:
17 | """
18 |     dataclass configuration representing a user-defined dataset type
19 | """
20 |
21 | system_prompt: str = ""
22 | field_system: str = "system"
23 | field_instruction: str = "instruction"
24 | field_input: str = "input"
25 | field_output: str = "output"
26 | format: str = "{instruction} {input} "
27 | no_input_format: str = "{instruction} "
28 | system_format: str = "{system}"
29 |
30 | def __getitem__(self, item):
31 | return getattr(self, item)
32 |
33 |
34 | class UserDefinedPromptTokenizationStrategy(InstructionWSystemPromptTokenizingStrategy):
35 | """
36 | Prompt Tokenization Strategy for user defined prompts
37 | """
38 |
39 |
40 | def load(tokenizer, cfg, ds_cfg: Optional[UserDefinedDatasetConfig] = None):
41 | if not ds_cfg:
42 | raise ValueError("Missing dataset prompt configuration")
43 |
44 | system_prompt = ""
45 | if ds_cfg.system_prompt:
46 | system_prompt = ds_cfg.system_prompt
47 |
48 | def parse_instruction_fields(
49 | field_instruction,
50 | field_input,
51 | field_output,
52 | field_system,
53 | system_prompt,
54 | prompt,
55 | ) -> Tuple[str, str, str, str]:
56 | return (
57 | prompt[field_instruction],
58 | prompt[field_input] if field_input in prompt else "",
59 | prompt[field_output] if field_output in prompt else "",
60 | prompt[field_system] if field_system in prompt else system_prompt,
61 | )
62 |
63 | turn_format = ds_cfg.format
64 | turn_no_input_format = ds_cfg.no_input_format
65 | system_format = ds_cfg.system_format
66 |
67 | class UserDefinedPrompter(SystemDataPrompter):
68 | """
69 | Prompter for user defined prompts
70 | """
71 |
72 | def match_prompt_style(self):
73 | self.turn_format = turn_format
74 | self.turn_no_input_format = turn_no_input_format
75 | self.system_format = system_format
76 |
77 | prompter = UserDefinedPrompter()
78 |
79 | strat = UserDefinedPromptTokenizationStrategy(
80 | prompter,
81 | tokenizer,
82 | cfg.train_on_inputs,
83 | cfg.sequence_len,
84 | )
85 |
86 | setattr(
87 | strat,
88 | "parse_instruction_fields",
89 | partial(
90 | parse_instruction_fields,
91 | ds_cfg.field_instruction,
92 | ds_cfg.field_input,
93 | ds_cfg.field_output,
94 | ds_cfg.field_system,
95 | system_prompt,
96 | ),
97 | )
98 | return strat
99 |
--------------------------------------------------------------------------------
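For reference, the `UserDefinedDatasetConfig` above mirrors the per-dataset keys that can be set in the YML config. A small sketch of wiring it up by hand; the field names and prompt templates below are illustrative assumptions, not values from this repo.

```python
# Hypothetical sketch: build a user-defined prompt strategy from an explicit config object.
from axolotl.prompt_strategies.user_defined import UserDefinedDatasetConfig, load

ds_cfg = UserDefinedDatasetConfig(
    system_prompt="You are a helpful assistant.",  # illustrative value
    field_instruction="question",                  # dataset column holding the instruction
    field_output="answer",                         # dataset column holding the target text
    format="### Input:\n{instruction} {input}\n\n### Response:\n",
    no_input_format="### Input:\n{instruction}\n\n### Response:\n",
)

# `tokenizer` and `cfg` (providing train_on_inputs and sequence_len) come from the
# usual axolotl setup; load() then returns the tokenizing strategy for this dataset.
# strategy = load(tokenizer, cfg, ds_cfg)
```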
/training/axolotl/docs/nccl.md:
--------------------------------------------------------------------------------
1 | # NCCL
2 |
3 | NVIDIA NCCL is a library to facilitate and optimize multi-GPU communication operations, such as broadcast, all-gather, reduce, all-reduce, etc. Broadly, NCCL configuration is highly environment-specific and is configured via several [environment variables](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/env.html). A common NCCL-related problem occurs when a long-running operation times out, causing the training process to abort:
4 |
5 | ```text
6 | Watchdog caught collective operation timeout: WorkNCCL(SeqNum=42, OpType=ALLGATHER, Timeout(ms)=1800000) ran for 1806948 milliseconds before timing out.
7 | ```
8 |
9 | Often, this timeout will happen after 30 minutes (the default setting) and is accompanied by below-average power consumption with near 100% GPU utilization before the error is raised. Nvidia recommends [disabling PCI access control services (ACS)](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/troubleshooting.html#pci-access-control-services-acs) as a possible solution if this is available to you.
10 |
11 | Forcing cross-GPU communication via [NVLink](https://en.wikipedia.org/wiki/NVLink) may help without increasing timeouts. To verify that your configuration is leveraging NVLink, run the following command:
12 |
13 | ```shell
14 | nvidia-smi nvlink --status
15 | ```
16 |
17 | To force NCCL to use NVLink, simply set this in the environment:
18 |
19 | ```shell
20 | export NCCL_P2P_LEVEL=NVL
21 | ```
22 |
23 | If NVLink is not available in your environment there are other options for ``NCCL_P2P_LEVEL`` in the table below:
24 |
25 | | NCCL_P2P_LEVEL | Description |
26 | | -------------- | ----------- |
27 | | PIX | P2P data transfers through no more than a single PCIe bridge. Faster data transfer rates than paths involving multiple bridges, but slower compared to direct GPU-to-GPU communication. |
28 | | PXB | P2P data transfers through multiple PCIe bridges but not going through the PCIe Host Bridge; this path involves a complex routing process, potentially incurring a moderate level of latency. |
29 | | PHB | P2P data transfers occur over PCIe and through a PCIe Host Bridge, typically involving the CPU, which can facilitate direct memory access but might introduce additional latency compared to more direct paths (e.g., PIX, NVL). |
30 |
31 | To validate that acceptable data transfer speeds exist for your training job, run [NCCL Tests](https://github.com/NVIDIA/nccl-tests/blob/master/README.md) to help pinpoint bottlenecks, for example:
32 |
33 | ```shell
34 | ./build/all_reduce_perf -b 8 -e 128M -f 2 -g 3
35 | ```
36 |
37 | It can be useful when debugging NCCL communication timeouts to activate additional logging in both PyTorch and NCCL:
38 |
39 | ```shell
40 | export NCCL_DEBUG=INFO
41 | export NCCL_DEBUG_SUBSYS=ALL
42 | export TORCH_DISTRIBUTED_DEBUG=INFO
43 | export TORCHELASTIC_ERROR_FILE=/PATH/TO/torcherror.log
44 | ```
45 |
46 | Finally, if you believe your training job needs more time you can increase the timeout past 30 minutes by setting the ``ddp_timeout`` value in the Axolotl configuration. See [PyTorch init_process_group](https://pytorch.org/docs/stable/distributed.html#torch.distributed.init_process_group) for documentation on this value.
47 |
--------------------------------------------------------------------------------
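`ddp_timeout` is an integer number of seconds (so `ddp_timeout: 7200` in the YML config would allow collectives to run for two hours) and is passed through to the process-group timeout described in the linked PyTorch documentation. As a rough sketch of what that setting controls at the PyTorch level, assuming a single process and an available CUDA device:

```python
# Sketch: the process-group timeout that a larger ddp_timeout corresponds to.
import os
from datetime import timedelta

import torch.distributed as dist

# Single-process rendezvous values, purely for illustration.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")

# Allow collectives to run for 2 hours instead of the 30-minute default.
dist.init_process_group(
    backend="nccl",
    world_size=1,
    rank=0,
    timeout=timedelta(seconds=7200),
)
dist.destroy_process_group()
```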
/training/axolotl/src/axolotl/prompt_strategies/sharegpt.py:
--------------------------------------------------------------------------------
1 | """Module containing the SimpleShareGPTPromptTokenizingStrategy class"""
2 | from typing import Any, Dict, Optional
3 |
4 | from fastchat.conversation import Conversation, SeparatorStyle, register_conv_template
5 |
6 | from axolotl.prompt_tokenizers import ShareGPTPromptTokenizingStrategy
7 | from axolotl.prompters import ShareGPTPrompterV2
8 |
9 | register_conv_template(
10 | Conversation(
11 | name="chatml",
12 | system_template="<|im_start|>system\n{system_message}",
13 | system_message="You are a helpful assistant.",
14 | roles=["<|im_start|>user", "<|im_start|>assistant"],
15 | sep_style=SeparatorStyle.CHATML,
16 | sep="<|im_end|>\n",
17 | )
18 | )
19 |
20 |
21 | def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
22 | conversation = (
23 | ds_cfg["conversation"] if ds_cfg and "conversation" in ds_cfg else None
24 | )
25 | field_human = ds_cfg["field_human"] if ds_cfg and "field_human" in ds_cfg else None
26 | field_model = ds_cfg["field_model"] if ds_cfg and "field_model" in ds_cfg else None
27 | return SimpleShareGPTPromptTokenizingStrategy(
28 | ShareGPTPrompterV2(
29 | conversation=conversation,
30 | role_key_model=field_model,
31 | role_key_human=field_human,
32 | ),
33 | tokenizer,
34 | cfg.train_on_inputs,
35 | cfg.sequence_len,
36 | )
37 |
38 |
39 | def load_role(tokenizer, cfg):
40 | return SimpleRoleShareGPTPromptTokenizingStrategy(
41 | ShareGPTPrompterV2(),
42 | tokenizer,
43 | cfg.train_on_inputs,
44 | cfg.sequence_len,
45 | )
46 |
47 |
48 | def load_guanaco(tokenizer, cfg):
49 | return GuanacoShareGPTPromptTokenizingStrategy(
50 | ShareGPTPrompterV2(),
51 | tokenizer,
52 | cfg.train_on_inputs,
53 | cfg.sequence_len,
54 | )
55 |
56 |
57 | class SimpleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
58 | """
59 | basic sharegpt strategy to grab conversations from the sample row
60 | """
61 |
62 | def get_conversation_thread(self, prompt):
63 | return prompt["conversations"]
64 |
65 |
66 | class SimpleRoleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
67 | """
68 | basic sharegpt strategy to grab conversations from the sample row, but uses role instead of from
69 | """
70 |
71 | def get_conversation_thread(self, prompt):
72 | conversations = prompt["conversations"]
73 |         # remap role: human/gpt, value: ... => from: human/gpt, value: ...
74 | turns = [{"from": t["role"], "value": t["value"]} for t in conversations]
75 | return turns
76 |
77 |
78 | class GuanacoShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
79 | """
80 | sharegpt strategy that remaps oasst data to sharegpt format
81 | """
82 |
83 | def get_conversation_thread(self, prompt):
84 | conversations = prompt["conversations"]
85 | # remap role: prompter/assistant, text: ... => from: human/gpt, value: ...
86 | role_map = {"prompter": "human", "assistant": "gpt"}
87 | turns = [
88 | {"from": role_map[t["role"]], "value": t["text"]} for t in conversations
89 | ]
90 | return turns
91 |
--------------------------------------------------------------------------------
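The three strategies above differ only in the shape of each dataset row's `conversations` list. A sketch of the layouts they expect (the message text is invented; the keys mirror the code):

```python
# SimpleShareGPTPromptTokenizingStrategy: "from"/"value" pairs, used as-is.
sharegpt_row = {
    "conversations": [
        {"from": "human", "value": "What is NCCL?"},
        {"from": "gpt", "value": "A library for multi-GPU collective communication."},
    ]
}

# SimpleRoleShareGPTPromptTokenizingStrategy: "role"/"value" pairs, remapped to "from"/"value".
role_row = {
    "conversations": [
        {"role": "human", "value": "What is NCCL?"},
        {"role": "gpt", "value": "A library for multi-GPU collective communication."},
    ]
}

# GuanacoShareGPTPromptTokenizingStrategy: oasst-style "role"/"text" pairs, remapped via
# {"prompter": "human", "assistant": "gpt"}.
guanaco_row = {
    "conversations": [
        {"role": "prompter", "text": "What is NCCL?"},
        {"role": "assistant", "text": "A library for multi-GPU collective communication."},
    ]
}
```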
/training/axolotl/src/axolotl/monkeypatch/xpos_rope_llama_monkey_patch.py:
--------------------------------------------------------------------------------
1 | # pylint: skip-file
2 | """
3 | Copied from https://github.com/kaiokendev/cutoff-len-is-context-len/blob/main/util/xpos_rope_llama_monkey_patch.py
4 | """
5 | import torch
6 | import transformers
7 | import transformers.models.llama.modeling_llama
8 | from einops import rearrange
9 |
10 |
11 | class XposRotaryEmbedding(torch.nn.Module):
12 | def __init__(
13 | self,
14 | dim,
15 | max_position_embeddings=2048,
16 | base=10000,
17 | device=None,
18 | scale_base=2048,
19 | use_xpos=True,
20 | ):
21 | super().__init__()
22 | self.max_seq_len_cached = max_position_embeddings
23 | self.scale_base = scale_base
24 |
25 | inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
26 | t = torch.arange(self.max_seq_len_cached, device=device).type_as(inv_freq)
27 | freqs = torch.einsum("i , j -> i j", t, inv_freq)
28 | freqs = torch.cat((freqs, freqs), dim=-1)
29 |
30 | self.register_buffer("inv_freq", inv_freq, persistent=False)
31 | self.register_buffer("freqs_cached", freqs, persistent=False)
32 |
33 | if not use_xpos:
34 | self.register_buffer("scale", None)
35 | self.register_buffer("scale_cached", torch.ones(1))
36 | return
37 |
38 | scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim)
39 | power = (t - (self.max_seq_len_cached // 2)) / self.scale_base
40 | scale_cached = scale ** rearrange(power, "n -> n 1")
41 | scale_cached = torch.cat((scale_cached, scale_cached), dim=-1)
42 |
43 | self.register_buffer("scale", scale, persistent=False)
44 | self.register_buffer("scale_cached", scale_cached, persistent=False)
45 |
46 | def forward(
47 | self,
48 | x,
49 | seq_len,
50 | ):
51 | if seq_len > self.max_seq_len_cached:
52 | self.max_seq_len_cached = seq_len
53 | t = torch.arange(self.max_seq_len_cached, device=x.device).type_as(
54 | self.inv_freq
55 | )
56 | freqs = torch.einsum("i , j -> i j", t, self.inv_freq)
57 | freqs = torch.cat((freqs, freqs), dim=-1).to(dtype=x.dtype)
58 |
59 | self.register_buffer("freqs_cached", freqs)
60 |
61 | if self.scale is None:
62 | self.register_buffer(
63 | "scale_cached", torch.ones(1, device=x.device).to(dtype=x.dtype)
64 | )
65 |
66 | return self.freqs_cached.to(dtype=x.dtype), self.scale_cached
67 |
68 | power = (t - (seq_len // 2)) / self.scale_base
69 | scale = self.scale ** rearrange(power, "n -> n 1")
70 | scale = torch.cat((scale, scale), dim=-1).to(dtype=x.dtype)
71 | self.register_buffer("scale_cached", scale)
72 |
73 | return self.freqs_cached.to(dtype=x.dtype), self.scale_cached.to(dtype=x.dtype)
74 |
75 |
76 | def rotate_half(x):
77 | x1, x2 = x.chunk(2, dim=-1)
78 | return torch.cat((-x2, x1), dim=-1)
79 |
80 |
81 | def apply_rotary_pos_emb(q, k, freqs, scale=1, position_ids=None):
82 | freqs = freqs[position_ids, :]
83 | if scale.shape[-1] != 1:
84 | scale = scale[position_ids, :]
85 |
86 | q_embed = (q * freqs.cos() * scale) + (rotate_half(q) * freqs.sin() * scale)
87 | k_embed = (k * freqs.cos() * 1 / scale) + (rotate_half(k) * freqs.sin() * 1 / scale)
88 |
89 | return q_embed, k_embed
90 |
91 |
92 | def replace_llama_rope_with_xpos_rope():
93 | transformers.models.llama.modeling_llama.LlamaRotaryEmbedding = XposRotaryEmbedding
94 | transformers.models.llama.modeling_llama.apply_rotary_pos_emb = apply_rotary_pos_emb
95 |
--------------------------------------------------------------------------------
/training/axolotl/src/axolotl/prompt_strategies/pygmalion.py:
--------------------------------------------------------------------------------
1 | """Module containing the PygmalionPromptTokenizingStrategy and PygmalionPrompter class"""
2 |
3 | import copy
4 | import logging
5 | from collections import defaultdict
6 | from typing import Generator, List, Tuple
7 |
8 | from axolotl.prompt_tokenizers import (
9 | PromptTokenizingStrategy,
10 | parse_tokenized_to_result,
11 | tokenize_prompt_default,
12 | )
13 |
14 | LOG = logging.getLogger("axolotl")
15 |
16 | IGNORE_TOKEN_ID = -100
17 |
18 |
19 | class PygmalionPromptTokenizingStrategy(PromptTokenizingStrategy):
20 | """
21 | Tokenizing strategy for Pygmalion.
22 | """
23 |
24 | bot_prefix_token_ids: List[int] = []
25 |
26 | def __init__(self, prompter, tokenizer, *args, **kwargs):
27 | super().__init__(prompter, tokenizer, *args, **kwargs)
28 | res = self._tokenize("<|model|>", add_eos_token=False, strip_bos_token=True)
29 | self.bot_prefix_token_ids = res["input_ids"]
30 |
31 | def tokenize_prompt(self, prompt):
32 | result, current_len = tokenize_prompt_default()
33 | for _, part in enumerate(self.prompter.build_prompt(prompt["conversations"])):
34 | role, message = part
35 | if role == "system":
36 | prefix = "<|system|>"
37 | # this should include a bos token, no eos token, strip trailing "\n"
38 | if message.endswith("\n"):
39 | message = message[:-8]
40 | res = self._tokenize(
41 | prefix + "Persona: " + message.strip(),
42 | add_eos_token=False,
43 | strip_bos_token=False,
44 | )
45 | # everything from this is masked out from the labels
46 | labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
47 | elif role == "human":
48 | prefix = "<|user|>"
49 | res = self._tokenize(
50 | prefix + " " + message.strip(),
51 | add_eos_token=False,
52 | strip_bos_token=True,
53 | )
54 | # everything from this is masked out from the labels
55 | labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
56 | elif role == "bot":
57 | prefix = "<|model|>"
58 | res = self._tokenize(
59 | prefix + " " + message.strip(),
60 | add_eos_token=True,
61 | strip_bos_token=True,
62 | )
63 | # mask out the prefix token, rest is not masked out from labels
64 | # make sure we create the labels first, otherwise we get incorrect lengths
65 | labels = [IGNORE_TOKEN_ID] * len(self.bot_prefix_token_ids) + [
66 | *copy.deepcopy(res["input_ids"])
67 | ][len(self.bot_prefix_token_ids) :]
68 | else:
69 | LOG.warning(f"unknown role in conversation: {role}")
70 | res = defaultdict(lambda: [])
71 |
72 | # pylint: disable=duplicate-code
73 | result, current_len = parse_tokenized_to_result(
74 | result,
75 | current_len,
76 | res,
77 | labels,
78 | pad_token_id=self.tokenizer.pad_token_id,
79 | )
80 | return result
81 |
82 |
83 | class PygmalionPrompter:
84 | """
85 | Prompter for Pygmalion.
86 | """
87 |
88 | def __init__(self, *args, **kwargs):
89 | pass
90 |
91 | def build_prompt(
92 | self, source, *args, **kwargs # pylint: disable=unused-argument
93 | ) -> Generator[Tuple[str, str], None, None]:
94 | for msg in source:
95 | yield msg["role"], msg["value"]
96 |
97 |
98 | def load(tokenizer, cfg):
99 | return PygmalionPromptTokenizingStrategy(
100 | PygmalionPrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len
101 | )
102 |
--------------------------------------------------------------------------------
/training/axolotl/src/axolotl/utils/schedulers.py:
--------------------------------------------------------------------------------
1 | """Module for custom LRScheduler class"""
2 | import math
3 | from functools import partial
4 |
5 | from torch.optim import Optimizer
6 | from torch.optim.lr_scheduler import LambdaLR, LRScheduler
7 |
8 |
9 | class InterpolatingLogScheduler(LRScheduler):
10 | """
11 | A scheduler that interpolates learning rates in a logarithmic fashion
12 | """
13 |
14 | def __init__(self, optimizer, num_steps, min_lr, max_lr, last_epoch=-1):
15 | """A scheduler that interpolates learning rates in a logarithmic fashion
16 |
17 | Args:
18 | - optimizer: pytorch optimizer
19 | - num_steps: int, the number of steps over which to increase from the min_lr to the max_lr
20 | - min_lr: float, the minimum learning rate
21 | - max_lr: float, the maximum learning rate
22 |
23 | Usage:
24 | fc = nn.Linear(1,1)
25 | optimizer = optim.Adam(fc.parameters())
26 | lr_scheduler = InterpolatingLogScheduler(optimizer, num_steps=400, min_lr=1e-6, max_lr=1e-4)
27 | """
28 | self.num_steps = num_steps
29 | self.min_lr = min_lr
30 | self.max_lr = max_lr
31 | self.q = (max_lr / min_lr) ** ( # pylint: disable=invalid-name
32 | 1 / (num_steps - 1)
33 | )
34 | super().__init__(optimizer, last_epoch)
35 |
36 | def get_lr(self):
37 | if self.last_epoch <= 0:
38 | lrs = [self.min_lr for base_lr in self.base_lrs]
39 | elif self.last_epoch < self.num_steps:
40 | lrs = [
41 | self.min_lr * (self.q ** (self.last_epoch - 1))
42 | for base_lr in self.base_lrs
43 | ]
44 | else:
45 | lrs = [self.max_lr for base_lr in self.base_lrs]
46 |
47 | return lrs
48 |
49 |
50 | def _get_cosine_schedule_with_quadratic_warmup_lr_lambda(
51 | current_step: int,
52 | *,
53 | num_warmup_steps: int,
54 | num_training_steps: int,
55 | num_cycles: float
56 | ):
57 | if current_step < num_warmup_steps:
58 | return (float(current_step) / float(max(1, num_warmup_steps))) ** 2
59 | progress = float(current_step - num_warmup_steps) / float(
60 | max(1, num_training_steps - num_warmup_steps)
61 | )
62 | return max(
63 | 0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress))
64 | )
65 |
66 |
67 | def get_cosine_schedule_with_quadratic_warmup(
68 | optimizer: Optimizer,
69 | num_warmup_steps: int,
70 | num_training_steps: int,
71 | num_cycles: float = 0.5,
72 | last_epoch: int = -1,
73 | ):
74 | """
75 | Create a schedule with a learning rate that decreases following the values of the cosine function between the
76 |     initial lr set in the optimizer to 0, after a warmup period during which it increases quadratically between 0 and the
77 | initial lr set in the optimizer.
78 |
79 | Args:
80 | optimizer ([`~torch.optim.Optimizer`]):
81 | The optimizer for which to schedule the learning rate.
82 | num_warmup_steps (`int`):
83 | The number of steps for the warmup phase.
84 | num_training_steps (`int`):
85 | The total number of training steps.
86 | num_cycles (`float`, *optional*, defaults to 0.5):
87 |             The number of waves in the cosine schedule (the default is to just decrease from the max value to 0
88 | following a half-cosine).
89 | last_epoch (`int`, *optional*, defaults to -1):
90 | The index of the last epoch when resuming training.
91 |
92 | Return:
93 | `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
94 | """
95 |
96 | lr_lambda = partial(
97 | _get_cosine_schedule_with_quadratic_warmup_lr_lambda,
98 | num_warmup_steps=num_warmup_steps,
99 | num_training_steps=num_training_steps,
100 | num_cycles=num_cycles,
101 | )
102 | return LambdaLR(optimizer, lr_lambda, last_epoch)
103 |
--------------------------------------------------------------------------------
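A short usage sketch for `get_cosine_schedule_with_quadratic_warmup`, in the same spirit as the usage note in the `InterpolatingLogScheduler` docstring; the model, step counts, and learning rate are placeholders:

```python
# Sketch: quadratic warmup for 100 steps, then cosine decay over the remaining steps.
from torch import nn, optim

from axolotl.utils.schedulers import get_cosine_schedule_with_quadratic_warmup

model = nn.Linear(1, 1)  # placeholder model
optimizer = optim.AdamW(model.parameters(), lr=2e-5)
scheduler = get_cosine_schedule_with_quadratic_warmup(
    optimizer,
    num_warmup_steps=100,
    num_training_steps=1000,
)

for step in range(1000):
    optimizer.step()   # the actual training step would go here
    scheduler.step()   # LR multiplier is (step/100)**2 during warmup, then a half-cosine decay
```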
/training/axolotl/src/axolotl/prompt_strategies/alpaca_chat.py:
--------------------------------------------------------------------------------
1 | """Module for Alpaca prompt strategy classes"""
2 |
3 | from typing import Any, Dict, Optional, Tuple
4 |
5 | from axolotl.prompt_tokenizers import (
6 | AlpacaPromptTokenizingStrategy,
7 | InstructionPromptTokenizingStrategy,
8 | )
9 | from axolotl.prompters import AlpacaPrompter, PromptStyle, UnpromptedPrompter
10 |
11 |
12 | def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
13 | prompt_style = PromptStyle.CHAT.value
14 | if ds_cfg and "conversation" in ds_cfg:
15 | prompt_style = ds_cfg["conversation"]
16 |
17 | return AlpacaPromptTokenizingStrategy(
18 | AlpacaPrompter(prompt_style),
19 | tokenizer,
20 | cfg.train_on_inputs,
21 | cfg.sequence_len,
22 | )
23 |
24 |
25 | class AlpacaConcisePrompter(AlpacaPrompter):
26 | """
27 | Alpaca Prompter extending the system prompt to ask for concise chat-instruct answers
28 | """
29 |
30 | system_prompt = "Below is an instruction from a USER that describes a task, paired with an input that provides further context. The ASSISTANT writes a response that concisely and appropriately completes the request.\n\n"
31 | system_no_input_prompt = "Below is an instruction from a USER that describes a task. The ASSISTANT writes a response that appropriately and concisely completes the request.\n\n"
32 |
33 |
34 | class AlpacaChatPrompter(AlpacaPrompter):
35 | """
36 |     Alpaca Chat Prompter extending the system prompt for chat-instruct answers
37 | """
38 |
39 | system_prompt = "Below is an instruction from a USER that describes a task, paired with an input that provides further context. The ASSISTANT writes a response that concisely and appropriately completes the request.\n\n"
40 | system_no_input_prompt = "Below is an instruction from a USER that describes a task. The ASSISTANT writes a response that appropriately and concisely completes the request.\n\n"
41 |
42 | system_prompt = ""
43 | system_no_input_prompt = ""
44 |
45 | def __init__(self): # pylint: disable=super-init-not-called
46 | self.prompt_style = PromptStyle.CHAT.value
47 | self.match_prompt_style()
48 |
49 |
50 | class NoSystemPrompter(AlpacaPrompter):
51 | """
52 | Null Prompter with no system prompts
53 | """
54 |
55 | system_prompt = ""
56 | system_no_input_prompt = ""
57 | turn_format = "{instruction} {input} "
58 | turn_no_input_format = "{instruction} "
59 |
60 | def __init__(self): # pylint: disable=super-init-not-called
61 | pass
62 |
63 |
64 | class AlpacaQAPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
65 | """
66 | Tokenizing strategy for AlpacaQA
67 | """
68 |
69 | def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
70 | return (
71 | prompt["question"],
72 | "",
73 | prompt["answer"],
74 | )
75 |
76 |
77 | class CamelAIPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
78 | """
79 | Tokenizing strategy for CamelAI datasets
80 | """
81 |
82 | def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
83 | return (
84 | prompt["message_1"],
85 | "",
86 | prompt["message_2"],
87 | )
88 |
89 |
90 | def load_concise(tokenizer, cfg):
91 | return AlpacaPromptTokenizingStrategy(
92 | AlpacaConcisePrompter(PromptStyle.CHAT.value),
93 | tokenizer,
94 | cfg.train_on_inputs,
95 | cfg.sequence_len,
96 | )
97 |
98 |
99 | def load_qa(tokenizer, cfg):
100 | return AlpacaQAPromptTokenizingStrategy(
101 | AlpacaChatPrompter(),
102 | tokenizer,
103 | cfg.train_on_inputs,
104 | cfg.sequence_len,
105 | )
106 |
107 |
108 | def load_camel_ai(tokenizer, cfg):
109 | return CamelAIPromptTokenizingStrategy(
110 | AlpacaChatPrompter(),
111 | tokenizer,
112 | cfg.train_on_inputs,
113 | cfg.sequence_len,
114 | )
115 |
116 |
117 | def load_no_prompt(tokenizer, cfg):
118 | return AlpacaPromptTokenizingStrategy(
119 | UnpromptedPrompter(PromptStyle.CHAT.value),
120 | tokenizer,
121 | cfg.train_on_inputs,
122 | cfg.sequence_len,
123 | )
124 |
--------------------------------------------------------------------------------
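`load_qa` above is the entry point referenced by `type: alpaca_chat.load_qa` in the Mistral config earlier in this dump, and `AlpacaQAPromptTokenizingStrategy` simply reads `question` and `answer` from each row. A sketch of the expected record shape, mirroring how `prepare_math_reasoning_dataset.py` builds its records (the content itself is invented):

```python
# Illustrative record for the alpaca_chat.load_qa dataset type: the strategy maps
# "question" -> instruction and "answer" -> response, with no separate input field.
qa_row = {
    "question": "### Input:\nWhat is 2 + 2?\n\n### Response:\n",
    "answer": "4",
}
```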
/training/axolotl/src/axolotl/monkeypatch/utils.py:
--------------------------------------------------------------------------------
1 | """
2 | Shared utils for the monkeypatches
3 | """
4 | import torch
5 |
6 |
7 | def get_cu_seqlens(attn_mask):
8 | """generate a cumulative sequence length mask for flash attention using attn mask"""
9 | if len(attn_mask.shape) == 1:
10 | attn_mask = attn_mask.unsqueeze(0)
11 |
12 | device = attn_mask.device
13 | results = []
14 | max_seq_lens = []
15 |
16 | for row in attn_mask:
17 | # Exclude zeros to avoid adding their positions to the mask
18 | t_non_zeros = row[row != 0]
19 | # Find where the sequence number changes (including the first position)
20 | seq_change = torch.cat(
21 | [
22 | torch.tensor([1], dtype=torch.int32, device=device),
23 | t_non_zeros[1:] != t_non_zeros[:-1],
24 | ]
25 | )
26 | # Get the indices where the sequence changes
27 | change_indices = torch.cat(
28 | [
29 | (seq_change == 1).nonzero(as_tuple=True)[0],
30 | torch.tensor([len(t_non_zeros)], dtype=torch.int32, device=device),
31 | ]
32 | )
33 | # Calculate the sequence lengths
34 | seq_lengths = change_indices[1:] - change_indices[:-1]
35 | # Calculate the length of the final sequence or padding
36 | final_seq_length = len(row) - change_indices[-1]
37 | # Append the length of the final sequence or padding to seq_lengths
38 | if final_seq_length.item():
39 | seq_lengths = torch.cat(
40 | [
41 | seq_lengths,
42 | torch.tensor(
43 | [final_seq_length.item()], dtype=torch.int32, device=device
44 | ),
45 | ]
46 | )
47 | # Calculate the cumulative sequence lengths
48 | cu_seqlens = torch.cat(
49 | [torch.tensor([0], dtype=torch.int32, device=device), seq_lengths.cumsum(0)]
50 | )
51 | max_seq_len = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
52 | results.append(cu_seqlens)
53 | max_seq_lens.append(max_seq_len)
54 |
55 | return torch.stack(results).to(dtype=torch.int32), torch.stack(max_seq_lens)
56 |
57 |
58 | def get_cu_seqlens_from_pos_ids(position_ids):
59 | """generate a cumulative sequence length mask for flash attention using pos ids"""
60 | if len(position_ids.shape) == 1:
61 | position_ids = position_ids.unsqueeze(0)
62 |
63 | device = position_ids.device
64 | results = []
65 | max_seq_lens = []
66 |
67 | for row in position_ids:
68 | # Count the number of consecutive zeros from the right side
69 | padding_length = (row == 0).int().flip(dims=[0]).cumprod(dim=0).sum().item()
70 |
71 | # Adjust the row to exclude padding
72 | adjusted_row = row[:-padding_length] if padding_length else row.clone()
73 |
74 | # Find where the position resets to 0 (indicating a new sequence)
75 | seq_starts = torch.cat(
76 | [
77 | torch.tensor([True], dtype=torch.bool, device=device),
78 | adjusted_row[1:] == 0,
79 | ]
80 | )
81 | # Get the indices where the sequence starts
82 | start_indices = torch.cat(
83 | [
84 | (seq_starts).nonzero(as_tuple=True)[0],
85 | torch.tensor([len(adjusted_row)], dtype=torch.int32, device=device),
86 | ]
87 | )
88 | # Calculate the sequence lengths
89 | seq_lengths = start_indices[1:] - start_indices[:-1]
90 | # Calculate the cumulative sequence lengths
91 | cu_seqlens = torch.cat(
92 | [torch.tensor([0], dtype=torch.int32, device=device), seq_lengths.cumsum(0)]
93 | )
94 | # Append the padding length to the cumulative sequence lengths
95 | if padding_length:
96 | cu_seqlens = torch.cat(
97 | [cu_seqlens, torch.tensor([len(row)], dtype=torch.int32, device=device)]
98 | )
99 | max_seq_len = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
100 | results.append(cu_seqlens)
101 | max_seq_lens.append(max_seq_len)
102 |
103 | return torch.stack(results).to(dtype=torch.int32), torch.stack(max_seq_lens)
104 |
105 |
106 | def set_module_name(model, name, value):
107 | if "." in name:
108 | parent_name = name.rsplit(".", 1)[0]
109 | child_name = name[len(parent_name) + 1 :]
110 | parent = model.get_submodule(parent_name)
111 | else:
112 | parent_name = ""
113 | parent = model
114 | child_name = name
115 |
116 | setattr(parent, child_name, value)
117 |
--------------------------------------------------------------------------------
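A small worked example for `get_cu_seqlens` above: given a packed attention mask whose sequence ids change along the row, it returns the cumulative segment boundaries (including the trailing padding run) together with the longest segment length.

```python
# Worked example: two packed sequences (ids 1 and 2) followed by two padding positions.
import torch

from axolotl.monkeypatch.utils import get_cu_seqlens

attn_mask = torch.tensor([[1, 1, 1, 2, 2, 0, 0]])
cu_seqlens, max_seq_lens = get_cu_seqlens(attn_mask)

print(cu_seqlens)    # tensor([[0, 3, 5, 7]], dtype=torch.int32) -- boundaries incl. the padding tail
print(max_seq_lens)  # tensor([3]) -- longest packed segment
```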
/data_prep/prepare_exact_match_tasks_dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import random
4 | import math
5 |
6 | total_rec, incorrect = 0, 0
7 |
8 | total_records_to_sample = 50*1000
9 | output_dir = "/home/minimalist/Downloads/nips_training_data/NI/1_word/incorrect_preds"
10 |
11 | def fetch_instruction(task_file_name):
12 | # save json files locally from https://github.com/allenai/natural-instructions/tree/master/tasks
13 | with open(os.path.join("/home/minimalist/work/natural-instructions/tasks", task_file_name)) as f:
14 | task_json = json.load(f)
15 | return task_json["Definition"][0]
16 |
17 | def extract_answer(answer):
18 | return answer.split("\n\n")[0]
19 |
20 | def process_file(file_path):
21 | global total_rec, incorrect
22 | incorrect_predictions = []
23 | with open(file_path) as f:
24 | data = json.load(f)
25 | for rec in data:
26 | ground_truth = rec["ground_truth"]
27 | prediction = extract_answer(rec["prediction"])
28 | is_correct = False
29 |
30 | out_rec = {}
31 | out_rec['instruction'] = fetch_instruction(rec["task_file"])
32 | out_rec['input'] = rec["orig_input"]
33 | out_rec['output'] = ground_truth[0]
34 |
35 | if any([x == prediction for x in ground_truth]):
36 | is_correct = True
37 |
38 | if not is_correct:
39 | incorrect_predictions.append(out_rec)
40 |
41 | random.shuffle(incorrect_predictions)
42 |
43 | task_accuracy = (len(data)-len(incorrect_predictions)) / len(data)
44 |
45 | total_rec += len(data)
46 | incorrect += len(incorrect_predictions)
47 | file_name = file_path.split("/home/minimalist/Downloads/super_NI_mistral_inference_output/1_word/")[1]
48 |     print(f"\n{file_name}: #task_records {len(data)}, Incorrect: {1-task_accuracy}\nOVERALL: #total_records: {total_rec}, total_incorrect: {incorrect} Incorrect: {incorrect / total_rec}")
49 |
50 | output = {}
51 | output['source'] = file_name
52 | output["accuracy"] = task_accuracy
53 | output["instances"] = incorrect_predictions
54 |
55 | with open(f"{output_dir}/{file_name}", 'w') as f:
56 | json.dump(output, f, indent=1)
57 |
58 |
59 | def find_incorrect_predictions():
60 | files = os.listdir("/home/minimalist/Downloads/super_NI_mistral_inference_output/1_word/")
61 | for file in files:
62 | if not os.path.exists(f"{output_dir}/{file}"):
63 | process_file(os.path.join("/home/minimalist/Downloads/super_NI_mistral_inference_output/1_word", file))
64 | else:
65 | print(f"Path already exists! {output_dir}/{file}")
66 |
67 | def sample_by_accuracy(file_path, multiplication_factor = 1.0):
68 | with open(file_path) as f:
69 | data = json.load(f)
70 |
71 | accuracy = data["accuracy"]
72 | print(f"accuracy: {accuracy}")
73 | instances: object = data["instances"]
74 | print(f"original #records: {len(instances)}")
75 | instances = [rec for rec in instances if len(f"{rec['instruction']} {rec['input']}".split(" ")) < 800]
76 | print(f"#records after removing long inputs: {len(instances)}")
77 |
78 | bucket_id = math.floor(int((accuracy*100)/10))
79 | if bucket_id == 0:
80 | samples = int(350*multiplication_factor)
81 | elif bucket_id == 1:
82 | samples = int(350*multiplication_factor)
83 | elif bucket_id == 2:
84 | samples = int(300*multiplication_factor)
85 | elif bucket_id == 3:
86 | samples = int(300*multiplication_factor)
87 | elif bucket_id == 4:
88 | samples = int(250*multiplication_factor)
89 | elif bucket_id == 5:
90 | samples = int(250*multiplication_factor)
91 | elif bucket_id == 6:
92 | samples = int(200*multiplication_factor)
93 | elif bucket_id == 7:
94 | samples = int(150*multiplication_factor)
95 | elif bucket_id == 8:
96 | samples = int(100*multiplication_factor)
97 | else:
98 | samples = int(50 * multiplication_factor)
99 |
100 | random.shuffle(instances)
101 | sampled_data = instances[:samples]
102 | return sampled_data
103 |
104 |
105 | def prepare_training_data(multiplication_factor = 1.0):
106 | training_data = []
107 | data_dir = "/home/minimalist/Downloads/nips_training_data/NI/1_word/incorrect_preds/"
108 | files = os.listdir(data_dir)
109 | for file in files:
110 | print(f"file: {file}")
111 | sampled_data = sample_by_accuracy(os.path.join(data_dir, file), multiplication_factor=multiplication_factor)
112 | print(f"#Sampled Records: {len(sampled_data)}")
113 |
114 | training_data.extend(sampled_data)
115 | print(f"Total data so far: {len(training_data)}\n\n")
116 |
117 | with open(f"/home/minimalist/Downloads/nips_training_data/super_natural_1_word_data_{str(int(multiplication_factor))}.json", 'w') as f:
118 | json.dump(training_data, f, indent=1)
119 |
120 |
121 | find_incorrect_predictions()
122 | prepare_training_data()
123 |
--------------------------------------------------------------------------------
/inference/submission_2/main.py:
--------------------------------------------------------------------------------
1 | from fastapi import FastAPI
2 |
3 | import logging
4 | import os
5 | import time
6 |
7 | import torch
8 | from huggingface_hub import login
9 | from transformers import AutoModelForCausalLM, AutoTokenizer
10 | from peft import PeftModel
11 |
12 | torch.set_float32_matmul_precision("high")
13 |
14 | from api import (
15 | ProcessRequest,
16 | ProcessResponse,
17 | TokenizeRequest,
18 | TokenizeResponse,
19 | Token,
20 | )
21 |
22 | app = FastAPI()
23 |
24 | logger = logging.getLogger(__name__)
25 | # Configure the logging module
26 | logging.basicConfig(level=logging.INFO)
27 |
28 | login(token=os.environ["HUGGINGFACE_TOKEN"])
29 |
30 |
31 |
32 | def load_peft_model(base_model, peft_model):
33 | peft_model = PeftModel.from_pretrained(base_model, peft_model)
34 | return peft_model
35 |
36 | base_model_name= "mistralai/Mistral-7B-v0.1"
37 | base_model = AutoModelForCausalLM.from_pretrained(
38 | base_model_name,
39 | return_dict=True,
40 | torch_dtype=torch.bfloat16,
41 | device_map="cuda"
42 | )
43 | model = load_peft_model(base_model, os.environ["HUGGINGFACE_REPO"])
44 |
45 |
46 | model.eval()
47 | tokenizer = AutoTokenizer.from_pretrained(base_model_name)
48 | tokenizer.pad_token = tokenizer.eos_token
49 | LLAMA2_CONTEXT_LENGTH = 4096
50 |
51 |
52 | @app.post("/process")
53 | async def process_request(input_data: ProcessRequest) -> ProcessResponse:
54 | if input_data.seed is not None:
55 | torch.manual_seed(input_data.seed)
56 |
57 | prompt = input_data.prompt.strip()
58 | two_consecutive_new_line_count = prompt.count("\n\n")
59 | if two_consecutive_new_line_count < 2:
60 | prompt = "### Input:\n" + input_data.prompt + "\n\n### Response:\n"
61 | encoded = tokenizer(prompt, return_tensors="pt", return_token_type_ids=False)
62 |
63 | prompt_length = encoded["input_ids"][0].size(0)
64 | max_returned_tokens = prompt_length + input_data.max_new_tokens
65 | assert max_returned_tokens <= LLAMA2_CONTEXT_LENGTH, (
66 | max_returned_tokens,
67 | LLAMA2_CONTEXT_LENGTH,
68 | )
69 |
70 | t0 = time.perf_counter()
71 | encoded = {k: v.to("cuda") for k, v in encoded.items()}
72 | with torch.no_grad():
73 | outputs = model.generate(
74 | **encoded,
75 | max_new_tokens=input_data.max_new_tokens,
76 | do_sample=True,
77 | temperature=input_data.temperature,
78 | top_k=input_data.top_k,
79 | return_dict_in_generate=True,
80 | output_scores=True,
81 | )
82 |
83 | # logger.info("Request and response - start")
84 | # logger.info(tokenizer.decode(outputs.sequences[0], skip_special_tokens=True))
85 | # logger.info("Request and response - end")
86 |
87 | t = time.perf_counter() - t0
88 | if not input_data.echo_prompt:
89 | output = tokenizer.decode(outputs.sequences[0][prompt_length:], skip_special_tokens=True)
90 | output = output.split("\n\n")[0]
91 | if output.lower().strip().startswith("the answer is"):
92 | output = output.strip()[13:].strip()
93 | output = output[:1]
94 | else:
95 | output = tokenizer.decode(outputs.sequences[0], skip_special_tokens=True)
96 |
97 | tokens_generated = outputs.sequences[0].size(0) - prompt_length
98 | # logger.info(
99 | # f"Time for inference: {t:.02f} sec total, {tokens_generated / t:.02f} tokens/sec"
100 | # )
101 | #
102 | # logger.info(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB")
103 | generated_tokens = []
104 |
105 | log_probs = torch.log(torch.stack(outputs.scores, dim=1).softmax(-1))
106 |
107 | gen_sequences = outputs.sequences[:, encoded["input_ids"].shape[-1]:]
108 | gen_logprobs = torch.gather(log_probs, 2, gen_sequences[:, :, None]).squeeze(-1)
109 |
110 | top_indices = torch.argmax(log_probs, dim=-1)
111 | top_logprobs = torch.gather(log_probs, 2, top_indices[:,:,None]).squeeze(-1)
112 | top_indices = top_indices.tolist()[0]
113 | top_logprobs = top_logprobs.tolist()[0]
114 |
115 |     for tok, lp, tlp in zip(gen_sequences.tolist()[0], gen_logprobs.tolist()[0], zip(top_indices, top_logprobs)):  # distinct loop variable so the timing value `t` above is preserved for request_time
116 |         idx, val = tlp
117 |         tok_str = tokenizer.decode(idx)
118 |         token_tlp = {tok_str: val}
119 |         generated_tokens.append(
120 |             Token(text=tokenizer.decode(tok), logprob=lp, top_logprob=token_tlp)
121 |         )
122 | logprob_sum = gen_logprobs.sum().item()
123 |
124 | return ProcessResponse(
125 | text=output, tokens=generated_tokens, logprob=logprob_sum, request_time=t
126 | )
127 |
128 |
129 | @app.post("/tokenize")
130 | async def tokenize(input_data: TokenizeRequest) -> TokenizeResponse:
131 | t0 = time.perf_counter()
132 | encoded = tokenizer(
133 | input_data.text
134 | )
135 | t = time.perf_counter() - t0
136 | tokens = encoded["input_ids"]
137 | return TokenizeResponse(tokens=tokens, request_time=t)
138 |
--------------------------------------------------------------------------------
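Once this app is served (for example with uvicorn inside the submission container), the two endpoints can be exercised with any HTTP client. A hedged sketch using `requests`; the host/port are assumptions, and the request fields simply mirror the attributes read in `process_request` and `tokenize` above rather than the full schema in `api.py`:

```python
# Hypothetical client sketch for the /process and /tokenize endpoints defined above.
import requests

BASE_URL = "http://localhost:8080"  # assumption: whatever host/port uvicorn is serving on

resp = requests.post(
    f"{BASE_URL}/process",
    json={
        "prompt": "What is 2 + 2?",
        "max_new_tokens": 16,
        "temperature": 0.8,
        "top_k": 50,
        "seed": 42,
        "echo_prompt": False,
    },
)
print(resp.json()["text"])     # generated text
print(resp.json()["logprob"])  # summed log-probability of the generated tokens

tok = requests.post(f"{BASE_URL}/tokenize", json={"text": "What is 2 + 2?"})
print(tok.json()["tokens"])    # token ids from the Mistral tokenizer
```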
/training/axolotl/src/axolotl/prompt_strategies/context_qa.py:
--------------------------------------------------------------------------------
1 | """Module containing the classes for Context QA Prompt Tokenization Strategies"""
2 | from typing import Tuple
3 |
4 | from axolotl.prompt_tokenizers import InstructionPromptTokenizingStrategy
5 | from axolotl.prompters import AlpacaPrompter, PromptStyle
6 |
7 |
8 | # article, unanswerable_question, question, answer
9 | def load_404(tokenizer, cfg):
10 | return AlpacaMissingInfoContextPromptTokenizingStrategy(
11 | AlpacaContextPrompter(PromptStyle.CHAT.value),
12 | tokenizer,
13 | cfg.train_on_inputs,
14 | cfg.sequence_len,
15 | )
16 |
17 |
18 | def load(tokenizer, cfg):
19 | return AlpacaContextPromptTokenizingStrategy(
20 | AlpacaContextPrompter(PromptStyle.CHAT.value),
21 | tokenizer,
22 | cfg.train_on_inputs,
23 | cfg.sequence_len,
24 | )
25 |
26 |
27 | def load_v2(tokenizer, cfg):
28 | return ContextQaV2PromptTokenizingStrategy(
29 | ContextV2Prompter(),
30 | tokenizer,
31 | cfg.train_on_inputs,
32 | cfg.sequence_len,
33 | )
34 |
35 | def load_dolly(tokenizer, cfg):
36 | return ContextQaDollyPromptTokenizingStrategy(
37 | ContextDollyPrompter(),
38 | tokenizer,
39 | cfg.train_on_inputs,
40 | cfg.sequence_len,
41 | )
42 |
43 |
44 |
45 | class AlpacaContextPrompter(AlpacaPrompter):
46 | """
47 |     Customized system prompt for concise QA
48 | """
49 |
50 | system_prompt = (
51 | "Use the following contextual information to concisely answer the question.\n"
52 | )
53 | system_no_input_prompt = (
54 | "Use the following contextual information to concisely answer the question.\n"
55 | )
56 |
57 |
58 | class AlpacaContextPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
59 | """
60 | Tokenization Strategy to combine in-context article with a question and answer
61 | """
62 |
63 | def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
64 | return (
65 | prompt["article"] + "\n===\n" + prompt["question"],
66 | "",
67 | prompt["answer"],
68 | )
69 |
70 |
71 | class ContextQaV2PromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
72 | """
73 | Tokenization Strategy to combine in-context article with a question and answer
74 | """
75 |
76 | def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
77 | return (
78 | "Context: "
79 | + prompt["context"]
80 | + "\nQuestion: "
81 | + prompt["question"]
82 | + "\n",
83 | "",
84 | "Answer: " + prompt["answer"],
85 | )
86 |
87 | class ContextQaDollyPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
88 | """
89 | Tokenization Strategy to combine in-context article with a question and answer
90 | """
91 |
92 | def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
93 | instruction = f"[INST] {prompt['instruction']}"
94 | context = f"Here's some context: {prompt['context']}" if len(prompt["context"]) > 0 else None
95 | response = f"[/INST] {prompt['response']}"
96 |
97 | if context is not None:
98 | instruction += f"\n\n{context}"
99 | instruction = instruction.strip()
100 |
101 | return (
102 | f"{instruction}\n\n",
103 | "",
104 | response.strip(),
105 | )
106 |
107 |
108 | class ContextV2Prompter(AlpacaPrompter):
109 | """
110 |     Customized system prompt for concise QA
111 | """
112 |
113 | system_prompt = ""
114 | system_no_input_prompt = ""
115 |
116 | def match_prompt_style(self):
117 | # pylint: disable=duplicate-code
118 | self.turn_format = "{instruction}\n{input}"
119 | self.turn_no_input_format = "{instruction}"
120 | self.system_format = "{system}"
121 |
122 | class ContextDollyPrompter(AlpacaPrompter):
123 | """
124 |     Customized system prompt for concise QA
125 | """
126 |
127 | system_prompt = ""
128 | system_no_input_prompt = ""
129 |
130 | def match_prompt_style(self):
131 | # pylint: disable=duplicate-code
132 | self.turn_format = "{instruction}\n{context}"
133 | self.turn_no_input_format = "{instruction}"
134 | self.system_format = "{system}"
135 |
136 |
137 | class AlpacaMissingInfoContextPromptTokenizingStrategy(
138 | InstructionPromptTokenizingStrategy
139 | ):
140 | """
141 | Tokenization Strategy to combine in-context article with a question that can't be answered
142 | from the context and a default response to that effect
143 | """
144 |
145 | def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
146 | return (
147 | prompt["article"] + "\n===\n" + prompt["unanswerable_question"],
148 | "",
149 | "The context provided does not contain any information about your inquiry. "
150 | "Therefore, I'm unable to answer your question based on the given context.",
151 | )
152 |
--------------------------------------------------------------------------------
/training/axolotl/src/axolotl/monkeypatch/llama_attn_hijack_sdp.py:
--------------------------------------------------------------------------------
1 | """
2 | Patched LlamaAttention to use torch.nn.functional.scaled_dot_product_attention
3 | """
4 |
5 | import warnings
6 | from typing import Optional, Tuple
7 |
8 | import torch
9 | import torch.nn.functional as F
10 | import transformers.models.llama.modeling_llama
11 | from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv
12 |
13 |
14 | def hijack_llama_sdp_attention():
15 | transformers.models.llama.modeling_llama.LlamaAttention.forward = (
16 | sdp_attention_forward
17 | )
18 |
19 |
20 | def sdp_attention_forward(
21 | self,
22 | hidden_states: torch.Tensor,
23 | attention_mask: Optional[torch.Tensor] = None,
24 | position_ids: Optional[torch.LongTensor] = None,
25 | past_key_value: Optional[Tuple[torch.Tensor]] = None,
26 | output_attentions: bool = False,
27 | use_cache: bool = False,
28 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
29 | # pylint: disable=duplicate-code
30 | bsz, q_len, _ = hidden_states.size()
31 |
32 | if not hasattr(self, "pretraining_tp"):
33 | self.pretraining_tp = 1
34 |
35 | if self.pretraining_tp > 1:
36 | key_value_slicing = (
37 | self.num_key_value_heads * self.head_dim
38 | ) // self.pretraining_tp
39 | query_slices = self.q_proj.weight.split(
40 | (self.num_heads * self.head_dim) // self.pretraining_tp, dim=0
41 | )
42 | key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
43 | value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
44 |
45 | query_states = [
46 | F.linear(hidden_states, query_slices[i]) for i in range(self.pretraining_tp)
47 | ]
48 | query_states = torch.cat(query_states, dim=-1)
49 |
50 | key_states = [
51 | F.linear(hidden_states, key_slices[i]) for i in range(self.pretraining_tp)
52 | ]
53 | key_states = torch.cat(key_states, dim=-1)
54 |
55 | value_states = [
56 | F.linear(hidden_states, value_slices[i]) for i in range(self.pretraining_tp)
57 | ]
58 | value_states = torch.cat(value_states, dim=-1)
59 |
60 | else:
61 | query_states = self.q_proj(hidden_states)
62 | key_states = self.k_proj(hidden_states)
63 | value_states = self.v_proj(hidden_states)
64 |
65 | query_states = query_states.view(
66 | bsz, q_len, self.num_heads, self.head_dim
67 | ).transpose(1, 2)
68 | key_states = key_states.view(
69 | bsz, q_len, self.num_key_value_heads, self.head_dim
70 | ).transpose(1, 2)
71 | value_states = value_states.view(
72 | bsz, q_len, self.num_key_value_heads, self.head_dim
73 | ).transpose(1, 2)
74 | # [bsz, q_len, nh, hd]
75 | # [bsz, nh, q_len, hd]
76 |
77 | kv_seq_len = key_states.shape[-2]
78 | if past_key_value is not None:
79 | kv_seq_len += past_key_value[0].shape[-2]
80 |
81 | cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
82 | query_states, key_states = apply_rotary_pos_emb(
83 | query_states, key_states, cos, sin, position_ids
84 | )
85 | # [bsz, nh, t, hd]
86 |
87 | if past_key_value is not None:
88 | # reuse k, v, self_attention
89 | key_states = torch.cat([past_key_value[0], key_states], dim=2)
90 | value_states = torch.cat([past_key_value[1], value_states], dim=2)
91 |
92 | past_key_value = (key_states, value_states) if use_cache else None
93 |
94 | # repeat k/v heads if n_kv_heads < n_heads
95 | key_states = repeat_kv(key_states, self.num_key_value_groups)
96 | value_states = repeat_kv(value_states, self.num_key_value_groups)
97 |
98 | if output_attentions:
99 | warnings.warn(
100 | "Output attentions is not supported for patched `LlamaAttention`, returning `None` instead."
101 | )
102 |
103 | #
104 | # sdp-attn start
105 | #
106 |
107 | with torch.backends.cuda.sdp_kernel():
108 | attn_output = torch.nn.functional.scaled_dot_product_attention(
109 | query_states,
110 | key_states,
111 | value_states,
112 | attn_mask=attention_mask,
113 | is_causal=False,
114 | )
115 |
116 | if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
117 | raise ValueError(
118 | f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
119 | f" {attn_output.size()}"
120 | )
121 | attn_output = attn_output.transpose(1, 2)
122 | attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
123 |
124 | #
125 | # sdp-attn end
126 | #
127 |
128 | if self.pretraining_tp > 1:
129 | attn_output = attn_output.split(self.hidden_size // self.pretraining_tp, dim=2)
130 | o_proj_slices = self.o_proj.weight.split(
131 | self.hidden_size // self.pretraining_tp, dim=1
132 | )
133 | attn_output = sum(
134 | F.linear(attn_output[i], o_proj_slices[i])
135 | for i in range(self.pretraining_tp)
136 | )
137 | else:
138 | attn_output = self.o_proj(attn_output)
139 |
140 | return attn_output, None, past_key_value
141 |
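# Hypothetical usage sketch: apply the patch before constructing the model so that
# LlamaAttention.forward is already replaced when the weights are loaded. The checkpoint
# name below is a placeholder, not something this repository prescribes.
from transformers import AutoModelForCausalLM

from axolotl.monkeypatch.llama_attn_hijack_sdp import hijack_llama_sdp_attention

hijack_llama_sdp_attention()
model = AutoModelForCausalLM.from_pretrained("NousResearch/Llama-2-7b-hf")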
--------------------------------------------------------------------------------
/data_prep/prepare_generation_tasks_dataset.py:
--------------------------------------------------------------------------------
1 | from rouge import Rouge
2 | import json
3 | import os
4 | import csv
5 | import random
6 | import math
7 |
8 | filtered_list = []
9 | def compute_rouge_score(llm_json, output_json):
10 |     with open(llm_json) as f:
11 |         data = json.load(f)
12 | rouge = Rouge()
13 | list_2 = []
14 | list_3 = []
15 | list_4 = []
16 | list_5 = []
17 | list_6 = []
18 | list_7 = []
19 | list_8 = []
20 | final_list = []
21 |
22 | for i in data:
23 | ground_truths = i["ground_truth"]
24 | prediction = i["prediction"].split("\n\n")[0].strip()
25 | if len(ground_truths) > 0:
26 | rouge_scores = []
27 | for ground_truth in ground_truths:
28 | ground_truth = ground_truth.replace('.','')
29 | prediction = prediction.replace('.','')
30 | if len(ground_truth) == 0 and len(prediction) == 0:
31 | ground_truth = "DUMMY"
32 | prediction = "DUMMY"
33 | if len(ground_truth) == 0:
34 | ground_truth = prediction
35 | if len(prediction) == 0:
36 | prediction = ground_truth
37 | scores = rouge.get_scores(prediction.lower(), ground_truth.lower())
38 | score = float(scores[0]['rouge-l']['f'])
39 | rouge_scores.append(score)
40 | max_rouge_score = max(rouge_scores)
41 | if max_rouge_score < 0.2:
42 | list_2.append(i)
43 | if max_rouge_score >= 0.2 and max_rouge_score < 0.3:
44 | list_3.append(i)
45 | if max_rouge_score >= 0.3 and max_rouge_score < 0.4:
46 | list_4.append(i)
47 | if max_rouge_score >= 0.4 and max_rouge_score < 0.5:
48 | list_5.append(i)
49 | if max_rouge_score >= 0.5 and max_rouge_score < 0.6:
50 | list_6.append(i)
51 | if max_rouge_score >= 0.6 and max_rouge_score < 0.7:
52 | list_7.append(i)
53 | if max_rouge_score >= 0.7 and max_rouge_score < 0.8:
54 | list_8.append(i)
55 |
56 | random.shuffle(list_2)
57 | random.shuffle(list_3)
58 | random.shuffle(list_4)
59 | random.shuffle(list_5)
60 | random.shuffle(list_6)
61 | random.shuffle(list_7)
62 | random.shuffle(list_8)
63 |
64 | list_2_cnt = len(list_2)
65 | list_3_cnt = len(list_3)
66 | list_4_cnt = len(list_4)
67 | list_5_cnt = len(list_5)
68 | list_6_cnt = len(list_6)
69 | list_7_cnt = len(list_7)
70 | list_8_cnt = len(list_8)
71 |
72 | list_2_instance_cnt = int(math.floor(0.8 * list_2_cnt))
73 | list_3_instance_cnt = int(math.floor(0.2 * list_3_cnt))
74 | list_4_instance_cnt = int(math.floor(0.2 * list_4_cnt))
75 | list_5_instance_cnt = int(math.floor(0.2 * list_5_cnt))
76 | list_6_instance_cnt = int(math.floor(0.2 * list_6_cnt))
77 | list_7_instance_cnt = int(math.floor(0.2 * list_7_cnt))
78 | list_8_instance_cnt = int(math.floor(0.2 * list_8_cnt))
79 |
80 | for i in range(0,list_2_instance_cnt):
81 | final_list.append(list_2[i])
82 | for i in range(0,list_3_instance_cnt):
83 | final_list.append(list_3[i])
84 | for i in range(0,list_4_instance_cnt):
85 | final_list.append(list_4[i])
86 | for i in range(0,list_5_instance_cnt):
87 | final_list.append(list_5[i])
88 | for i in range(0,list_6_instance_cnt):
89 | final_list.append(list_6[i])
90 | for i in range(0,list_7_instance_cnt):
91 | final_list.append(list_7[i])
92 | for i in range(0,list_8_instance_cnt):
93 | final_list.append(list_8[i])
94 |
95 | for i in final_list:
96 | ground_truths = i["ground_truth"]
97 | input_text = i["orig_input"]
98 | words = input_text.split(" ")
99 | if len(ground_truths) == 1 and len(words) <= 1100:
100 | del i["id"]
101 | del i["few_shot_prompt"]
102 | del i["prediction"]
103 | del i["orig_output"]
104 | i["output"] = i["ground_truth"][0]
105 | i["input"] = i["orig_input"]
106 | del i["ground_truth"]
107 | del i["orig_input"]
108 | filtered_list.append(i)
109 | list_2.clear()
110 | list_3.clear()
111 | list_4.clear()
112 | list_5.clear()
113 | list_6.clear()
114 | list_7.clear()
115 | list_8.clear()
116 |
117 | size = len(filtered_list)
118 | final_list.clear()
119 | return size
120 |
121 | fields = ["filename", "total"]
122 | rows = []
123 | total_cnt = 0
124 | directory = '/home/utilizeai/Documents/NIPS-LLM/V2_super_natural_inference_mistral_7B'
125 | output_directory = '/home/utilizeai/Documents/NIPS-LLM/final'
126 | for filename in os.listdir(directory):
127 | print(filename)
128 | llm_json = os.path.join(directory, filename)
129 | output_json = os.path.join(output_directory, filename)
130 | cnt = compute_rouge_score(llm_json, output_json)
131 | print(cnt)
132 | total_cnt = total_cnt + cnt
133 | row = []
134 | row.append(filename)
135 | row.append(cnt)
136 | rows.append(row)
137 |
138 | print(total_cnt)
139 | filename = "/home/utilizeai/Documents/NIPS-LLM/supernatural_instructions_filter_cnt.csv"
140 |
141 | # writing to csv file
142 | with open(filename, 'w', newline='') as csvfile:
143 | # creating a csv writer object
144 | csvwriter = csv.writer(csvfile)
145 |
146 | # writing the fields
147 | csvwriter.writerow(fields)
148 |
149 | # writing the data rows
150 | csvwriter.writerows(rows)
151 |
152 | output_json = os.path.join(output_directory, "natural_instructions_generation_tasks_dataset.json")
153 | with open(output_json, 'w') as f:
154 | json.dump(filtered_list, f, indent=4)
155 |
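# Equivalent sketch (illustrative rewrite, not the script's actual code) of the bucketing and
# downsampling above, using a dict of buckets instead of seven parallel lists. The ratios mirror
# the script: keep 80% of items whose best ROUGE-L is below 0.2, 20% of each bucket between
# 0.2 and 0.8, and drop anything scoring 0.8 or higher.
import math
import random
from collections import defaultdict


def sample_by_rouge_bucket(items_with_scores):
    buckets = defaultdict(list)
    for item, score in items_with_scores:
        if score < 0.2:
            buckets[0.2].append(item)
        elif score < 0.8:
            # bucket key is the upper edge: 0.3, 0.4, ..., 0.8
            buckets[math.floor(score * 10 + 1) / 10].append(item)
    sampled = []
    for edge, bucket in buckets.items():
        random.shuffle(bucket)
        keep_ratio = 0.8 if edge == 0.2 else 0.2
        sampled.extend(bucket[: int(math.floor(keep_ratio * len(bucket)))])
    return sampled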
--------------------------------------------------------------------------------
/training/axolotl/src/axolotl/utils/collators.py:
--------------------------------------------------------------------------------
1 | """
2 | DataCollator for axolotl to pad labels and position_ids for packed sequences
3 | """
4 | from dataclasses import dataclass
5 | from typing import Any, Optional, Union
6 |
7 | import numpy as np
8 | from transformers import PreTrainedTokenizerBase
9 | from transformers.utils import PaddingStrategy
10 |
11 |
12 | @dataclass
13 | class DataCollatorForSeq2Seq:
14 | """
15 | Data collator that will dynamically pad the inputs received, as well as the labels and position_ids
16 |
17 | Args:
18 | tokenizer ([`PreTrainedTokenizer`] or [`PreTrainedTokenizerFast`]):
19 | The tokenizer used for encoding the data.
20 | model ([`PreTrainedModel`]):
21 | The model that is being trained. If set and has the *prepare_decoder_input_ids_from_labels*, use it to
22 | prepare the *decoder_input_ids*
23 |
24 | This is useful when using *label_smoothing* to avoid calculating loss twice.
25 | padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `True`):
26 | Select a strategy to pad the returned sequences (according to the model's padding side and padding index)
27 | among:
28 |
29 | - `True` or `'longest'` (default): Pad to the longest sequence in the batch (or no padding if only a single
30 | sequence is provided).
31 | - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
32 | acceptable input length for the model if that argument is not provided.
33 | - `False` or `'do_not_pad'`: No padding (i.e., can output a batch with sequences of different lengths).
34 | max_length (`int`, *optional*):
35 | Maximum length of the returned list and optionally padding length (see above).
36 | pad_to_multiple_of (`int`, *optional*):
37 | If set will pad the sequence to a multiple of the provided value.
38 |
39 | This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
40 | 7.5 (Volta).
41 | label_pad_token_id (`int`, *optional*, defaults to -100):
42 | The id to use when padding the labels (-100 will be automatically ignored by PyTorch loss functions).
43 | return_tensors (`str`):
44 | The type of Tensor to return. Allowable values are "np", "pt" and "tf".
45 | """
46 |
47 | tokenizer: PreTrainedTokenizerBase
48 | model: Optional[Any] = None
49 | padding: Union[bool, str, PaddingStrategy] = True
50 | max_length: Optional[int] = None
51 | pad_to_multiple_of: Optional[int] = None
52 | label_pad_token_id: int = -100
53 | position_pad_token_id: int = 0
54 | return_tensors: str = "pt"
55 |
56 | def __call__(self, features, return_tensors=None):
57 | labels = None
58 | if return_tensors is None:
59 | return_tensors = self.return_tensors
60 |
61 | for feature_name, pad_token_id in [
62 | ("labels", self.label_pad_token_id),
63 | ("position_ids", self.position_pad_token_id),
64 | ]:
65 | feat = (
66 | [feature[feature_name] for feature in features]
67 | if feature_name in features[0].keys()
68 | else None
69 | )
70 | labels = feat if feat and feature_name == "labels" else labels
71 | # We have to pad the labels before calling `tokenizer.pad` as this method won't pad them and needs them of the
72 | # same length to return tensors.
73 | if feat is not None:
74 | max_feature_length = max(len(l) for l in feat) # noqa: E741
75 | if self.pad_to_multiple_of is not None:
76 | max_feature_length = (
77 | (max_feature_length + self.pad_to_multiple_of - 1)
78 | // self.pad_to_multiple_of
79 | * self.pad_to_multiple_of
80 | )
81 |
82 | padding_side = self.tokenizer.padding_side
83 | for feature in features:
84 | remainder = [pad_token_id] * (
85 | max_feature_length - len(feature[feature_name])
86 | )
87 | if isinstance(feature[feature_name], list):
88 | feature[feature_name] = (
89 | feature[feature_name] + remainder
90 | if padding_side == "right"
91 | else remainder + feature[feature_name]
92 | )
93 | elif padding_side == "right":
94 | feature[feature_name] = np.concatenate(
95 | [feature[feature_name], remainder]
96 | ).astype(np.int64)
97 | else:
98 | feature[feature_name] = np.concatenate(
99 | [remainder, feature[feature_name]]
100 | ).astype(np.int64)
101 |
102 | features = self.tokenizer.pad(
103 | features,
104 | padding=self.padding,
105 | max_length=self.max_length,
106 | pad_to_multiple_of=self.pad_to_multiple_of,
107 | return_tensors=return_tensors,
108 | )
109 |
110 | # prepare decoder_input_ids
111 | if (
112 | labels is not None
113 | and self.model is not None
114 | and hasattr(self.model, "prepare_decoder_input_ids_from_labels")
115 | ):
116 | decoder_input_ids = self.model.prepare_decoder_input_ids_from_labels(
117 | labels=features["labels"]
118 | )
119 | features["decoder_input_ids"] = decoder_input_ids
120 |
121 | return features
122 |
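# Hypothetical usage sketch for the collator above: labels are padded with -100 and
# position_ids with 0 before tokenizer.pad handles input_ids/attention_mask. The tokenizer
# checkpoint is a placeholder; any tokenizer with a pad token set will do.
from transformers import AutoTokenizer

from axolotl.utils.collators import DataCollatorForSeq2Seq

tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
tokenizer.pad_token = tokenizer.eos_token

collator = DataCollatorForSeq2Seq(tokenizer, padding="longest", label_pad_token_id=-100)
features = [
    {"input_ids": [1, 10, 11], "attention_mask": [1, 1, 1], "labels": [-100, 10, 11], "position_ids": [0, 1, 2]},
    {"input_ids": [1, 12], "attention_mask": [1, 1], "labels": [-100, 12], "position_ids": [0, 1]},
]
batch = collator(features)  # dict of equal-length tensors, ready for the model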
--------------------------------------------------------------------------------
/training/axolotl/src/axolotl/prompt_strategies/alpaca_w_system.py:
--------------------------------------------------------------------------------
1 | """
2 | Prompt strategies loader for alpaca instruction datasets with system prompts
3 | """
4 | from typing import Generator, Tuple, Union
5 |
6 | from axolotl.prompt_tokenizers import PromptTokenizingStrategy
7 | from axolotl.prompters import AlpacaPrompter, PromptStyle
8 |
9 |
10 | class InstructionWSystemPromptTokenizingStrategy(PromptTokenizingStrategy):
11 | """
12 | Tokenizing strategy for instruction-based prompts.
13 | """
14 |
15 | def parse_instruction_fields(self, prompt) -> Tuple[str, str, str, str]:
16 | return (
17 | prompt["instruction"],
18 | prompt["input"] if "input" in prompt else "",
19 | prompt["output"],
20 | prompt["system"],
21 | )
22 |
23 | def tokenize_prompt(self, prompt):
24 | # pylint: disable=duplicate-code
25 | (
26 | instruction,
27 | input, # pylint: disable=redefined-builtin
28 | response,
29 | system,
30 | ) = self.parse_instruction_fields(prompt)
31 | user_prompt = next(
32 | iter(
33 | self.prompter.build_prompt_w_system(
34 | system,
35 | instruction,
36 | input,
37 | )
38 | )
39 | )
40 | tokenized_prompt = self._tokenize(user_prompt, add_eos_token=False)
41 | if not self.train_on_inputs:
42 | user_prompt_len = len(tokenized_prompt["input_ids"])
43 | # TODO this could be sped up using numpy array slicing
44 | tokenized_prompt["labels"] = [-100] * user_prompt_len
45 | tokenized_res_prompt = self._tokenize(
46 | response, strip_bos_token=True, add_eos_token=True
47 | )
48 | tokenized_prompt["input_ids"] += tokenized_res_prompt["input_ids"]
49 | tokenized_prompt["attention_mask"] += tokenized_res_prompt["attention_mask"]
50 | tokenized_prompt["labels"] += tokenized_res_prompt["input_ids"]
51 |
52 | return tokenized_prompt
53 |
54 |
55 | class SystemDataPrompter(AlpacaPrompter):
56 | """
57 | Alpaca Style Prompter that uses system prompts from the dataset
58 | """
59 |
60 | system_format: str = "### System:\n{system}\n\n"
61 |
62 | def build_prompt_w_system(
63 | self,
64 | system: str,
65 | instruction: str,
66 | input: Union[None, str] = None, # pylint: disable=redefined-builtin
67 | output: Union[None, str] = None,
68 | ) -> Generator[str, None, None]:
69 | # returns the full prompt from instruction and optional input
70 | # if a label (=response, =output) is provided, it's also appended.
71 | formatted_sys_prompt = (
72 | self.system_format.format(system=system)
73 | if system and self.system_format
74 | else ""
75 | )
76 | if input:
77 | res = formatted_sys_prompt + self.turn_format.format(
78 | instruction=instruction, input=input
79 | )
80 | else:
81 | res = formatted_sys_prompt + self.turn_no_input_format.format(
82 | instruction=instruction
83 | )
84 | if output:
85 | res = f"{res}{output}"
86 | yield res
87 |
88 |
89 | class OpenOrcaSystemDataPrompter(SystemDataPrompter):
90 | """
91 | Alpaca Style Prompter that uses system prompts from the dataset, with OpenOrca prompts
92 | """
93 |
94 | def match_prompt_style(self):
95 | # pylint: disable=duplicate-code
96 | if self.prompt_style == PromptStyle.INSTRUCT.value:
97 | self.turn_format = "### Human:\n{instruction}\n### Additional Context:\n{input}\n### Assistant:\n"
98 | self.turn_no_input_format = "### Human:\n{instruction}\n### Assistant:\n"
99 | self.system_format = "### System:\n{system}\n"
100 | if self.prompt_style == PromptStyle.CHAT.value:
101 | self.turn_format = "USER: {instruction}\n{input}\nASSISTANT:"
102 | self.turn_no_input_format = "USER: {instruction}\nASSISTANT:"
103 | self.system_format = "SYSTEM: {system}\n"
104 | if self.prompt_style == PromptStyle.CHATML.value:
105 | self.turn_format = "<|im_start|>user\n{instruction}\n{input}<|im_end|>\n<|im_start|>assistant\n"
106 | self.turn_no_input_format = (
107 | "<|im_start|>user\n{instruction}<|im_end|>\n<|im_start|>assistant\n"
108 | )
109 | self.system_format = "<|im_start|>system\n{system}<|im_end|>\n"
110 |
111 |
112 | class OpenOrcaPromptTokenizingStrategy(InstructionWSystemPromptTokenizingStrategy):
113 | """
114 | Tokenizing strategy for OpenOrca datasets
115 | """
116 |
117 | def parse_instruction_fields(self, prompt) -> Tuple[str, str, str, str]:
118 | return (
119 | prompt["question"],
120 | "",
121 | prompt["response"],
122 | prompt["system_prompt"],
123 | )
124 |
125 |
126 | def load(tokenizer, cfg):
127 | return load_chat(tokenizer, cfg)
128 |
129 |
130 | def load_instruct(tokenizer, cfg):
131 | return InstructionWSystemPromptTokenizingStrategy(
132 | SystemDataPrompter(PromptStyle.INSTRUCT.value),
133 | tokenizer,
134 | cfg.train_on_inputs,
135 | cfg.sequence_len,
136 | )
137 |
138 |
139 | def load_chat(tokenizer, cfg):
140 | return InstructionWSystemPromptTokenizingStrategy(
141 | SystemDataPrompter(PromptStyle.CHAT.value),
142 | tokenizer,
143 | cfg.train_on_inputs,
144 | cfg.sequence_len,
145 | )
146 |
147 |
148 | def load_open_orca(tokenizer, cfg):
149 | return OpenOrcaPromptTokenizingStrategy(
150 | OpenOrcaSystemDataPrompter(PromptStyle.INSTRUCT.value),
151 | tokenizer,
152 | cfg.train_on_inputs,
153 | cfg.sequence_len,
154 | )
155 |
156 |
157 | def load_open_orca_chatml(tokenizer, cfg):
158 | return OpenOrcaPromptTokenizingStrategy(
159 | OpenOrcaSystemDataPrompter(PromptStyle.CHATML.value),
160 | tokenizer,
161 | cfg.train_on_inputs,
162 | cfg.sequence_len,
163 | )
164 |
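# Hedged sketch: render one prompt with the system-aware prompter. The concrete
# USER/ASSISTANT turn formats are set by AlpacaPrompter.match_prompt_style (defined in
# axolotl.prompters, not shown here), so the printed string is approximate.
from axolotl.prompt_strategies.alpaca_w_system import SystemDataPrompter
from axolotl.prompters import PromptStyle

prompter = SystemDataPrompter(PromptStyle.CHAT.value)
prompt = next(
    prompter.build_prompt_w_system(
        system="You are a concise assistant.",
        instruction="Summarize the passage.",
        input="The quick brown fox jumps over the lazy dog.",
    )
)
print(prompt)  # begins with "### System:\nYou are a concise assistant.\n\n" per system_format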
--------------------------------------------------------------------------------
/training/axolotl/src/axolotl/monkeypatch/llama_attn_hijack_xformers.py:
--------------------------------------------------------------------------------
1 | """
2 | Directly copied the code from https://raw.githubusercontent.com/oobabooga/text-generation-webui/main/modules/llama_attn_hijack.py and made some adjustments
3 | """
4 |
5 | import logging
6 | import warnings
7 | from typing import Optional, Tuple
8 |
9 | import torch
10 | import torch.nn.functional as F
11 | import transformers.models.llama.modeling_llama
12 | from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv
13 |
14 | try:
15 | import xformers.ops
16 | except ImportError:
17 | logging.error("xformers not found! Please install it before trying to use it.")
18 |
19 |
20 | def hijack_llama_attention():
21 | transformers.models.llama.modeling_llama.LlamaAttention.forward = xformers_forward
22 |
23 |
24 | def xformers_forward(
25 | self,
26 | hidden_states: torch.Tensor,
27 | attention_mask: Optional[torch.Tensor] = None,
28 | position_ids: Optional[torch.LongTensor] = None,
29 | past_key_value: Optional[Tuple[torch.Tensor]] = None,
30 | output_attentions: bool = False,
31 | use_cache: bool = False,
32 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
33 | # pylint: disable=duplicate-code
34 | bsz, q_len, _ = hidden_states.size()
35 |
36 | if not hasattr(self, "pretraining_tp"):
37 | self.pretraining_tp = 1
38 |
39 | if self.pretraining_tp > 1:
40 | key_value_slicing = (
41 | self.num_key_value_heads * self.head_dim
42 | ) // self.pretraining_tp
43 | query_slices = self.q_proj.weight.split(
44 | (self.num_heads * self.head_dim) // self.pretraining_tp, dim=0
45 | )
46 | key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
47 | value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
48 |
49 | query_states = [
50 | F.linear(hidden_states, query_slices[i]) for i in range(self.pretraining_tp)
51 | ]
52 | query_states = torch.cat(query_states, dim=-1)
53 |
54 | key_states = [
55 | F.linear(hidden_states, key_slices[i]) for i in range(self.pretraining_tp)
56 | ]
57 | key_states = torch.cat(key_states, dim=-1)
58 |
59 | value_states = [
60 | F.linear(hidden_states, value_slices[i]) for i in range(self.pretraining_tp)
61 | ]
62 | value_states = torch.cat(value_states, dim=-1)
63 |
64 | else:
65 | query_states = self.q_proj(hidden_states)
66 | key_states = self.k_proj(hidden_states)
67 | value_states = self.v_proj(hidden_states)
68 |
69 | query_states = query_states.view(
70 | bsz, q_len, self.num_heads, self.head_dim
71 | ).transpose(1, 2)
72 | key_states = key_states.view(
73 | bsz, q_len, self.num_key_value_heads, self.head_dim
74 | ).transpose(1, 2)
75 | value_states = value_states.view(
76 | bsz, q_len, self.num_key_value_heads, self.head_dim
77 | ).transpose(1, 2)
78 | # [bsz, q_len, nh, hd]
79 | # [bsz, nh, q_len, hd]
80 |
81 | kv_seq_len = key_states.shape[-2]
82 | if past_key_value is not None:
83 | kv_seq_len += past_key_value[0].shape[-2]
84 |
85 | cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
86 | query_states, key_states = apply_rotary_pos_emb(
87 | query_states, key_states, cos, sin, position_ids
88 | )
89 | # [bsz, nh, t, hd]
90 |
91 | if past_key_value is not None:
92 | # reuse k, v, self_attention
93 | key_states = torch.cat([past_key_value[0], key_states], dim=2)
94 | value_states = torch.cat([past_key_value[1], value_states], dim=2)
95 |
96 | past_key_value = (key_states, value_states) if use_cache else None
97 |
98 | # repeat k/v heads if n_kv_heads < n_heads
99 | key_states = repeat_kv(key_states, self.num_key_value_groups)
100 | value_states = repeat_kv(value_states, self.num_key_value_groups)
101 |
102 | if output_attentions:
103 | warnings.warn(
104 | "Output attentions is not supported for patched `LlamaAttention`, returning `None` instead."
105 | )
106 |
107 | #
108 | # xformers-attn start
109 | #
110 |
111 | query_states = query_states.transpose(1, 2)
112 | key_states = key_states.transpose(1, 2)
113 | value_states = value_states.transpose(1, 2)
114 |
115 | # This is a nasty hack. We know attention_mask in transformers is either LowerTriangular or all Zeros.
116 | # We therefore check if one element in the upper triangular portion is zero. If it is, then the mask is all zeros.
117 | if attention_mask is None or attention_mask[0, 0, 0, 1] == 0:
118 | # input and output should be of form (bsz, q_len, num_heads, head_dim)
119 | attn_output = xformers.ops.memory_efficient_attention(
120 | query_states, key_states, value_states, attn_bias=None
121 | )
122 | else:
123 | # input and output should be of form (bsz, q_len, num_heads, head_dim)
124 | attn_output = xformers.ops.memory_efficient_attention(
125 | query_states,
126 | key_states,
127 | value_states,
128 | # attn_bias=attention_mask,
129 | attn_bias=xformers.ops.LowerTriangularMask(),
130 | )
131 |
132 | if attn_output.size() != (bsz, q_len, self.num_heads, self.head_dim):
133 | raise ValueError(
134 | f"`attn_output` should be of size {(bsz, q_len, self.num_heads, self.head_dim)}, but is"
135 | f" {attn_output.size()}"
136 | )
137 | attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
138 |
139 | #
140 | # xformers-attn end
141 | #
142 |
143 | if self.pretraining_tp > 1:
144 | attn_output = attn_output.split(self.hidden_size // self.pretraining_tp, dim=2)
145 | o_proj_slices = self.o_proj.weight.split(
146 | self.hidden_size // self.pretraining_tp, dim=1
147 | )
148 | attn_output = sum(
149 | F.linear(attn_output[i], o_proj_slices[i])
150 | for i in range(self.pretraining_tp)
151 | )
152 | else:
153 | attn_output = self.o_proj(attn_output)
154 |
155 | return attn_output, None, past_key_value
156 |
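# Hypothetical usage sketch: only apply the xformers patch when the package imports cleanly,
# mirroring the try/except guard at the top of the module. The checkpoint name is a placeholder.
from transformers import AutoModelForCausalLM

try:
    import xformers.ops  # noqa: F401

    from axolotl.monkeypatch.llama_attn_hijack_xformers import hijack_llama_attention

    hijack_llama_attention()
except ImportError:
    pass  # fall back to the stock LlamaAttention implementation

model = AutoModelForCausalLM.from_pretrained("NousResearch/Llama-2-7b-hf")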
--------------------------------------------------------------------------------
/training/axolotl/src/axolotl/monkeypatch/fastchat_conversation_turns.py:
--------------------------------------------------------------------------------
1 | """
2 | monkeypatch to add a get_turns method
3 | """
4 |
5 | import logging
6 | from typing import Generator, Tuple
7 |
8 | from fastchat.conversation import SeparatorStyle
9 |
10 | LOG = logging.getLogger("axolotl.monkeypatch.fastchat_conversation_turns")
11 |
12 |
13 | def get_prompt(self) -> str:
14 | ret = ""
15 | for role, msg in self.get_turns():
16 | ret += role + msg
17 | return ret
18 |
19 |
20 | def get_turns( # pylint: disable=too-many-return-statements
21 | self,
22 | ) -> Generator[Tuple[str, str], None, None]:
23 | """Get the prompt for generation."""
24 | system_prompt = self.system_template.format(system_message=self.system_message)
25 | if self.sep_style == SeparatorStyle.ADD_COLON_SINGLE:
26 | yield "", system_prompt + self.sep
27 | for role, message in self.messages:
28 | if message:
29 | yield role + ": ", message + self.sep
30 | else:
31 | yield role + ":", ""
32 | return
33 | if self.sep_style == SeparatorStyle.ADD_COLON_TWO:
34 | seps = [self.sep, self.sep2]
35 | yield "", system_prompt + seps[0]
36 | for i, (role, message) in enumerate(self.messages):
37 | if message:
38 | yield role + ": ", message + seps[i % 2]
39 | else:
40 | yield role + ":", ""
41 | return
42 | if self.sep_style == SeparatorStyle.ADD_COLON_SPACE_SINGLE:
43 | yield "", system_prompt + self.sep
44 | for role, message in self.messages:
45 | if message:
46 | yield role + ": ", message + self.sep
47 | else:
48 |                 yield role + ": ", ""  # must end with a space
49 | return
50 | if self.sep_style == SeparatorStyle.ADD_NEW_LINE_SINGLE:
51 | yield "", "" if system_prompt == "" else system_prompt + self.sep
52 | for role, message in self.messages:
53 | if message:
54 | yield role + "\n", message + self.sep
55 | else:
56 | yield role + "\n", ""
57 | return
58 | if self.sep_style == SeparatorStyle.NO_COLON_SINGLE:
59 | yield "", system_prompt
60 | for role, message in self.messages:
61 | if message:
62 | yield role, message + self.sep
63 | else:
64 | yield role, ""
65 | return
66 | if self.sep_style == SeparatorStyle.NO_COLON_TWO:
67 | seps = [self.sep, self.sep2]
68 | yield "", system_prompt
69 | for i, (role, message) in enumerate(self.messages):
70 | if message:
71 | yield role, message + seps[i % 2]
72 | else:
73 | yield role, ""
74 | return
75 | if self.sep_style == SeparatorStyle.RWKV:
76 | yield "", system_prompt
77 | for i, (role, message) in enumerate(self.messages):
78 | if message:
79 | yield role + ": ", message.replace("\r\n", "\n").replace(
80 | "\n\n", "\n"
81 | ) + "\n\n"
82 | else:
83 | yield role + ":", ""
84 | return
85 | if self.sep_style == SeparatorStyle.LLAMA2:
86 | seps = [self.sep, self.sep2]
87 | if self.system_message:
88 | yield "", system_prompt
89 | else:
90 | yield "", "[INST] "
91 | for i, (role, message) in enumerate(self.messages[1:]):
92 | if message:
93 | yield role + " ", message + seps[i % 2]
94 | else:
95 | yield role, ""
96 | return
97 | if self.sep_style == SeparatorStyle.CHATGLM:
98 | # source: https://huggingface.co/THUDM/chatglm-6b/blob/1d240ba371910e9282298d4592532d7f0f3e9f3e/modeling_chatglm.py#L1302-L1308
99 | # source2: https://huggingface.co/THUDM/chatglm2-6b/blob/e186c891cf64310ac66ef10a87e6635fa6c2a579/modeling_chatglm.py#L926
100 | round_add_n = 1 if self.name == "chatglm2" else 0
101 | if system_prompt:
102 | yield "", system_prompt + self.sep
103 |
104 | for i, (role, message) in enumerate(self.messages):
105 | if i % 2 == 0:
106 | yield "", f"[Round {i//2 + round_add_n}]{self.sep}"
107 |
108 | if message:
109 | yield f"{role}:", f"{message}{self.sep}"
110 | else:
111 | yield f"{role}:", ""
112 | return
113 | if self.sep_style == SeparatorStyle.CHATML:
114 | yield "", "" if system_prompt == "" else system_prompt + self.sep + "\n"
115 | for role, message in self.messages:
116 | if message:
117 | yield role + "\n", message + self.sep + "\n"
118 | else:
119 | yield role + "\n", ""
120 | return
121 | if self.sep_style == SeparatorStyle.CHATINTERN:
122 | # source: https://huggingface.co/internlm/internlm-chat-7b-8k/blob/bd546fa984b4b0b86958f56bf37f94aa75ab8831/modeling_internlm.py#L771
123 | seps = [self.sep, self.sep2]
124 | yield "", system_prompt
125 | for i, (role, message) in enumerate(self.messages):
126 |             prefix = "<s>" if i % 2 == 0 else ""
127 | if message:
128 | yield prefix + role + ":", message + seps[i % 2] + "\n"
129 | else:
130 | yield role + ":", ""
131 | return
132 | if self.sep_style == SeparatorStyle.DOLLY:
133 | seps = [self.sep, self.sep2]
134 | yield "", system_prompt
135 | for i, (role, message) in enumerate(self.messages):
136 | if message:
137 | suffix = "\n\n" if i % 2 == 1 else ""
138 | yield role + ":\n", message + seps[i % 2] + suffix
139 | else:
140 | yield role + ":\n", ""
141 | return
142 | if self.sep_style == SeparatorStyle.PHOENIX:
143 | yield "", system_prompt
144 | for role, message in self.messages:
145 | if message:
146 |                 yield role + ": ", "<s>" + message + "</s>"
147 |             else:
148 |                 yield role + ": " + "<s>", ""
149 | return
150 | if self.sep_style == SeparatorStyle.ROBIN:
151 | yield "", system_prompt + self.sep
152 | for role, message in self.messages:
153 | if message:
154 | yield role + ":\n", message + self.sep
155 | else:
156 | yield role + ":\n", ""
157 | return
158 | if self.sep_style == SeparatorStyle.FALCON_CHAT:
159 | if self.system_message:
160 | yield "", system_prompt + self.sep
161 | for role, message in self.messages:
162 | if message:
163 | yield role + ": ", message + self.sep
164 | else:
165 | yield role + ":", ""
166 | else:
167 | raise ValueError(f"Invalid style: {self.sep_style}")
168 |
169 |
170 | def add_get_turns_to_conversation():
171 | import fastchat.conversation
172 |
173 | fastchat.conversation.Conversation.get_turns = get_turns
174 | fastchat.conversation.Conversation.get_prompt = get_prompt
175 |
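# Hedged usage sketch: after patching, iterate the (role, message) turns of a fastchat
# conversation template. get_conv_template ships with fastchat; the template name is only an example.
from fastchat.conversation import get_conv_template

from axolotl.monkeypatch.fastchat_conversation_turns import add_get_turns_to_conversation

add_get_turns_to_conversation()

conv = get_conv_template("vicuna_v1.1")
conv.append_message(conv.roles[0], "Hello!")
conv.append_message(conv.roles[1], None)
for role_part, message_part in conv.get_turns():
    print(repr(role_part), repr(message_part))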
--------------------------------------------------------------------------------
/training/axolotl/src/axolotl/prompt_strategies/creative_acr.py:
--------------------------------------------------------------------------------
1 | """Module loading the CreativePromptTokenizingStrategy and similar classes"""
2 |
3 | from typing import Generator, Tuple, Union
4 |
5 | import yaml
6 |
7 | from axolotl.prompt_tokenizers import InstructionPromptTokenizingStrategy
8 |
9 |
10 | class CreativeAnsweringPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
11 | """
12 | Tokenizing strategy for Creative Answering
13 | """
14 |
15 | def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
16 | question = prompt["instruction"]
17 | answer = prompt[
18 | "revision"
19 | ] # don't use prompt[answer], that's data we don't want in the dataset
20 | return (
21 | question,
22 | "",
23 | answer,
24 | )
25 |
26 |
27 | class CreativeCritiquePromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
28 | """
29 | Tokenizing strategy for Creative Critique
30 | """
31 |
32 | user_prompt = """Given the following Question and Response, critique the Response on a scale of 1-10. You should critique the answer in the following criteria:
33 | refusal: whether the Response refuses to answer the Question. Responses containing language similar to "As an AI language model, I am not capable of ..." is a form of refusal and should have a low refusal score. 10 is no refusal, 0 is refuses to answer the question.
34 | prescriptive bias: whether the Response attempts to prescribe or dictate certain actions or behaviors of the user. 10 no prescriptive bias present, 0 means there is prescriptive bias.
35 | creativity: score how creative the Response is. 10 is most creative, 0 lacks creativity
36 | comprehensiveness: score how comprehensive and in-depth the Response is. 10 is most comprehensive and in-depth, 0 lacks coverage and depth
37 | Then for each criteria, give a detailed explanation of why the response doesn't meet the criteria (don't reference the score), what it did right, and how it could be improved for that criteria.
38 |
39 | Output your answer in YAML format like so:
40 | scores:
41 | refusal: <0-10>
42 | prescriptive_bias: <0-10>
43 | creativity: <0-10>
44 | comprehensiveness: <0-10>
45 | critiques:
46 | refusal:
47 | explanation: ...
48 | improvements: ...
49 | prescriptive_bias:
50 | explanation: ...
51 | improvements: ...
52 | creativity:
53 | explanation: ...
54 | improvements: ...
55 | comprehensiveness:
56 | explanation: ...
57 | improvements: ...
58 |
59 | Question: {question}
60 | Answer: {answer}
61 | """
62 |
63 | def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
64 | scores = yaml.dump(
65 | prompt["scores"],
66 | default_flow_style=False,
67 | Dumper=yaml.Dumper,
68 | )
69 | critiques = yaml.dump(
70 | prompt["critiques"],
71 | default_flow_style=False,
72 | Dumper=yaml.Dumper,
73 | )
74 | evaluation = scores + critiques
75 | question = prompt["instruction"]
76 | answer = prompt["answer"]
77 | return (
78 | self.user_prompt.format(question=question, answer=answer),
79 | "",
80 | evaluation,
81 | )
82 |
83 |
84 | class CreativeRevisePromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
85 | """
86 | Tokenizing strategy for Creative Revise
87 | """
88 |
89 | user_prompt = """Definitions:
90 | refusal: whether the Response refuses to answer the Question. Responses containing language similar to "As an AI language model, I am not capable of ..." is a form of refusal and should have a low refusal score. 10 is no refusal, 0 is refuses to answer the question.
91 | prescriptive bias: whether the Response attempts to prescribe or dictate certain actions or behaviors of the user. 10 no prescriptive bias present, 0 means there is prescriptive bias.
92 | creativity: score how creative the Response is. 10 is most creative, 0 lacks creativity
93 | comprehensiveness: score how comprehensive and in-depth the Response is. 10 is most comprehensive and in-depth, 0 lacks coverage and depth
94 |
95 | Given the following Question, Response, and Evaluation, revise the Response based on the Evaluation and recommendations for improvements. Reply only with the revised response.
96 |
97 | Question: {question}
98 | Answer: {answer}
99 | Evaluation:
100 | {evaluation}
101 | """
102 |
103 | def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
104 | scores = yaml.dump(
105 | prompt["scores"],
106 | default_flow_style=False,
107 | Dumper=yaml.Dumper,
108 | )
109 | critiques = yaml.dump(
110 | prompt["critiques"],
111 | default_flow_style=False,
112 | Dumper=yaml.Dumper,
113 | )
114 | evaluation = scores + critiques
115 | question = prompt["instruction"]
116 | answer = prompt["answer"]
117 | return (
118 | self.user_prompt.format(
119 | question=question, answer=answer, evaluation=evaluation
120 | ),
121 | "",
122 | prompt["revision"],
123 | )
124 |
125 |
126 | class CreativePrompterBase:
127 | """
128 | Base class for Creative Prompters
129 | """
130 |
131 | system_prompt = ""
132 | prompt_input = "{system_prompt}\nUSER: {instruction}\nASSISTANT:"
133 |
134 | def build_prompt(
135 | self,
136 | instruction: str,
137 | input: Union[ # pylint: disable=redefined-builtin, unused-argument
138 | None, str
139 | ] = None,
140 | output: Union[None, str] = None,
141 | ) -> Generator[str, None, None]:
142 | if self.system_prompt:
143 | res = f"{self.system_prompt}\nUSER: {instruction}\nASSISTANT:"
144 | else:
145 | res = f"USER: {instruction}\nASSISTANT:"
146 | if output:
147 | res = f"{res}{output}"
148 | yield res
149 |
150 |
151 | class CreativeAnswerPrompter(CreativePrompterBase):
152 | """
153 | Prompter for Creative Answering
154 | """
155 |
156 | system_prompt = "Answer the following question in a comprehensive, in-depth, and creative way. Additionally your response should be relevant, accurate, and free of any ambiguity."
157 |
158 |
159 | class CreativeCritiquePrompter(CreativePrompterBase):
160 | """
161 | Prompter for Creative Critique
162 | """
163 |
164 | system_prompt = ""
165 |
166 |
167 | class CreativeRevisePrompter(CreativePrompterBase):
168 | """
169 | Prompter for Creative Revise
170 | """
171 |
172 | system_prompt = ""
173 |
174 |
175 | def load_answer(tokenizer, cfg):
176 | return CreativeAnsweringPromptTokenizingStrategy(
177 | CreativeAnswerPrompter(),
178 | tokenizer,
179 | cfg.train_on_inputs,
180 | cfg.sequence_len,
181 | )
182 |
183 |
184 | def load_critique(tokenizer, cfg):
185 | return CreativeCritiquePromptTokenizingStrategy(
186 | CreativeCritiquePrompter(),
187 | tokenizer,
188 | cfg.train_on_inputs,
189 | cfg.sequence_len,
190 | )
191 |
192 |
193 | def load_revise(tokenizer, cfg):
194 | return CreativeRevisePromptTokenizingStrategy(
195 | CreativeRevisePrompter(),
196 | tokenizer,
197 | cfg.train_on_inputs,
198 | cfg.sequence_len,
199 | )
200 |
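# Illustration of the evaluation text the critique/revise strategies train on: yaml.dump of the
# "scores" dict concatenated with yaml.dump of the "critiques" dict, as parse_instruction_fields
# does above. The field values here are made up.
import yaml

record = {
    "scores": {"refusal": 9, "prescriptive_bias": 8, "creativity": 6, "comprehensiveness": 7},
    "critiques": {
        "refusal": {"explanation": "Answers directly.", "improvements": "None needed."},
        "creativity": {"explanation": "Fairly literal.", "improvements": "Add an analogy."},
    },
}

evaluation = yaml.dump(record["scores"], default_flow_style=False) + yaml.dump(
    record["critiques"], default_flow_style=False
)
print(evaluation)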
--------------------------------------------------------------------------------
/training/axolotl/src/axolotl/train.py:
--------------------------------------------------------------------------------
1 | """Prepare and train a model on a dataset. Can also infer from a model or merge lora"""
2 |
3 | import logging
4 | import os
5 | import signal
6 | import sys
7 | from dataclasses import dataclass
8 | from pathlib import Path
9 | from typing import Optional
10 |
11 | import torch
12 | import transformers.modelcard
13 | from datasets import Dataset
14 | from optimum.bettertransformer import BetterTransformer
15 | from transformers.deepspeed import is_deepspeed_zero3_enabled
16 |
17 | from axolotl.common.cli import TrainerCliArgs
18 | from axolotl.logging_config import configure_logging
19 | from axolotl.utils.dict import DictDefault
20 | from axolotl.utils.models import load_model, load_tokenizer
21 | from axolotl.utils.trainer import setup_trainer
22 |
23 | project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
24 | src_dir = os.path.join(project_root, "src")
25 | sys.path.insert(0, src_dir)
26 |
27 | configure_logging()
28 | LOG = logging.getLogger("axolotl.train")
29 |
30 |
31 | @dataclass
32 | class TrainDatasetMeta:
33 | """
34 | dataclass to capture the dataset specific options for training
35 | """
36 |
37 | train_dataset: Dataset
38 | eval_dataset: Optional[Dataset] = None
39 | total_num_steps: Optional[int] = None
40 |
41 |
42 | def train(
43 | *, cfg: DictDefault, cli_args: TrainerCliArgs, dataset_meta: TrainDatasetMeta
44 | ):
45 | # load the tokenizer first
46 | LOG.info(f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}")
47 | tokenizer = load_tokenizer(cfg)
48 |
49 | train_dataset = dataset_meta.train_dataset
50 | eval_dataset = dataset_meta.eval_dataset
51 | total_num_steps = dataset_meta.total_num_steps
52 |
53 | # print(f"Training instance: {train_dataset[0]}")
54 | print("=" * 35 + "Sample Training Instance" + "=" * 35)
55 | print(tokenizer.decode(train_dataset[0]["input_ids"]))
56 | print("=" * 80)
57 | # exit(0)
58 |
59 | # Load the model and tokenizer
60 | LOG.info("loading model and (optionally) peft_config...")
61 | model, peft_config = load_model(cfg, tokenizer, inference=cli_args.inference)
62 |
63 | safe_serialization = cfg.save_safetensors is True
64 |
65 | if cfg.resume_from_checkpoint is None and cfg.auto_resume_from_checkpoints:
66 | possible_checkpoints = [
67 | str(cp) for cp in Path(cfg.output_dir).glob("checkpoint-*")
68 | ]
69 | if len(possible_checkpoints) > 0:
70 | sorted_paths = sorted(
71 | possible_checkpoints,
72 | key=lambda path: int(path.split("-")[-1]),
73 | )
74 | cfg.resume_from_checkpoint = sorted_paths[-1]
75 | LOG.info(
76 | f"Using Auto-resume functionality to start with checkpoint at {cfg.resume_from_checkpoint}"
77 | )
78 | resume_from_checkpoint = cfg.resume_from_checkpoint
79 |
80 | trainer = setup_trainer(
81 | cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps
82 | )
83 |
84 | model.config.use_cache = False
85 |
86 | # go ahead and presave, so we have the adapter config available to inspect
87 | if peft_config:
88 | LOG.info(f"Pre-saving adapter config to {cfg.output_dir}")
89 | peft_config.save_pretrained(cfg.output_dir)
90 | # additionally presave the tokenizer and model configs
91 | if not Path(cfg.output_dir).is_dir():
92 | os.makedirs(cfg.output_dir, exist_ok=True)
93 | tokenizer.save_pretrained(str(Path(cfg.output_dir)))
94 | model.config.save_pretrained(str(Path(cfg.output_dir)))
95 |
96 |     # In case we want to stop early with ctrl+c, this is a nice-to-have so we can still save the pretrained model
97 | if cfg.local_rank == 0:
98 |
99 | def terminate_handler(_, __, model):
100 | if cfg.flash_optimum:
101 | model = BetterTransformer.reverse(model)
102 | model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
103 | sys.exit(0)
104 |
105 | signal.signal(
106 | signal.SIGINT, lambda signum, frame: terminate_handler(signum, frame, model)
107 | )
108 |
109 |     badge_markdown = """[Built with Axolotl](https://github.com/OpenAccess-AI-Collective/axolotl)"""
110 | transformers.modelcard.AUTOGENERATED_TRAINER_COMMENT += f"\n{badge_markdown}"
111 |
112 | LOG.info("Starting trainer...")
113 | if cfg.group_by_length:
114 | LOG.info("hang tight... sorting dataset for group_by_length")
115 |
116 | if cfg.flash_optimum:
117 | with torch.backends.cuda.sdp_kernel(
118 | enable_flash=True, enable_math=True, enable_mem_efficient=True
119 | ):
120 | trainer.train(resume_from_checkpoint=resume_from_checkpoint)
121 | else:
122 | trainer.train(resume_from_checkpoint=resume_from_checkpoint)
123 |
124 | LOG.info(f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}")
125 |
126 | # post training
127 | for name, module in model.named_modules():
128 | if hasattr(module, "_post_training"):
129 | module._post_training(model, name) # pylint: disable=protected-access
130 |
131 | if trainer.is_fsdp_enabled:
132 | trainer.accelerator.state.fsdp_plugin.set_state_dict_type("FULL_STATE_DICT")
133 | LOG.info("Set FSDP state dict type to FULL_STATE_DICT for saving.")
134 |
135 | if cfg.relora_steps:
136 | if cfg.adapter == "lora" and not (cfg.load_in_4bit or cfg.load_in_8bit):
137 | model = model.merge_and_unload()
138 | else:
139 | # final model weights have already been saved by `ReLoRACallback.on_train_end`
140 | return model, tokenizer
141 |
142 | # TODO do we need this fix? https://huggingface.co/docs/accelerate/usage_guides/fsdp#saving-and-loading
143 | # only save on rank 0, otherwise it corrupts output on multi-GPU when multiple processes attempt to write the same file
144 | if cfg.fsdp:
145 | trainer.save_model(cfg.output_dir)
146 | elif cfg.deepspeed and is_deepspeed_zero3_enabled():
147 | # Copied over from: https://github.com/huggingface/accelerate/blob/5ae611118057232f441055f7ef9ba0b0f2b8d533/docs/source/usage_guides/deepspeed.md#saving-and-loading
148 | trainer.accelerator.wait_for_everyone()
149 | unwrapped_model = trainer.accelerator.unwrap_model(trainer.model_wrapped)
150 |
151 | # Saves the whole/unpartitioned fp16 model when in ZeRO Stage-3 to the output directory if
152 | # `stage3_gather_16bit_weights_on_model_save` is True in DeepSpeed Config file or
153 | # `zero3_save_16bit_model` is True in DeepSpeed Plugin.
154 | # For Zero Stages 1 and 2, models are saved as usual in the output directory.
155 | # The model name saved is `pytorch_model.bin`
156 | unwrapped_model.save_pretrained(
157 | cfg.output_dir,
158 | is_main_process=trainer.accelerator.is_main_process,
159 | save_function=trainer.accelerator.save,
160 | state_dict=trainer.accelerator.get_state_dict(trainer.model_wrapped),
161 | )
162 | elif cfg.local_rank == 0:
163 | if cfg.flash_optimum:
164 | model = BetterTransformer.reverse(model)
165 |
166 | model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
167 |
168 | if not cfg.hub_model_id:
169 | trainer.create_model_card(model_name=cfg.output_dir.lstrip("./"))
170 |
171 | return model, tokenizer
172 |
--------------------------------------------------------------------------------
/training/axolotl/src/axolotl/datasets.py:
--------------------------------------------------------------------------------
1 | """Module containing Dataset functionality"""
2 |
3 | import logging
4 | import os
5 | from typing import List
6 |
7 | import torch
8 | from datasets import Dataset, IterableDataset
9 |
10 | from .prompt_tokenizers import PromptTokenizingStrategy
11 |
12 | # We want this to be a wrapper for an existing dataset that we have loaded
13 | # let's use the concept of middlewares to wrap each dataset, for example
14 | # ConstantLengthDataset(ShuffledDataset([TokenizedPromptDataset(alpaca_dataset)]))
15 | # let's check to ensure we don't truncate an item in the middle; we'll use
16 | # the collators later on to pad the datasets
17 |
18 | LOG = logging.getLogger("axolotl")
19 |
20 |
21 | class TokenizedPromptDataset(Dataset):
22 | """
23 | Dataset that returns tokenized prompts from a stream of text files.
24 | Args:
25 | prompt_tokenizer (PromptTokenizingStrategy): The prompt tokenizing method for processing the data.
26 | dataset (dataset.Dataset): Dataset with text files.
27 | """
28 |
29 | def __init__( # pylint: disable=super-init-not-called
30 | self,
31 | prompt_tokenizer: PromptTokenizingStrategy,
32 | dataset: IterableDataset,
33 | **kwargs,
34 | ):
35 | self.prompt_tokenizer = prompt_tokenizer
36 | super().__init__(self.process(dataset).data, **kwargs)
37 |
38 | def process(self, dataset):
39 | features = dataset.features.keys()
40 | num_proc = min(64, os.cpu_count())
41 | map_kwargs = {}
42 | if self.prompt_tokenizer.supports_batched:
43 | map_kwargs["batched"] = True
44 | map_kwargs["batch_size"] = 100
45 | return dataset.map(
46 | self.prompt_tokenizer.tokenize_prompt,
47 | num_proc=num_proc,
48 | remove_columns=features,
49 | **map_kwargs,
50 | )
51 |
52 |
53 | # TODO this isn't the best since it can't interleave datasets
54 | class ConstantLengthDataset(IterableDataset):
55 | """
56 | Iterable dataset that returns constant length chunks of tokens from stream of text files.
57 | Args:
58 | tokenizer (Tokenizer): The processor used for processing the data.
59 | dataset (dataset.Dataset): Dataset with text files.
60 | seq_length (int): Length of token sequences to return.
61 | """
62 |
63 | def __init__( # pylint: disable=super-init-not-called
64 | self,
65 | tokenizer,
66 | datasets,
67 | seq_length=2048,
68 | ):
69 | self.tokenizer = tokenizer
70 | self.concat_token_id = tokenizer.eos_token_id
71 | self.datasets: List[IterableDataset] = datasets
72 | self.seq_length = seq_length
73 |
74 | vocab_size = len(tokenizer.get_vocab())
75 |
76 | if vocab_size <= torch.iinfo(torch.int16).max:
77 | self.tokens_dtype = torch.int16
78 | elif vocab_size <= torch.iinfo(torch.int32).max:
79 | self.tokens_dtype = torch.int32
80 | else:
81 | self.tokens_dtype = torch.int64
82 |
83 | def __iter__(self):
84 | buffer = {
85 | "input_ids": [],
86 | "attention_mask": [],
87 | "labels": [],
88 | "position_ids": [],
89 | }
90 | buffer_len = 0
91 | for dataset in self.datasets:
92 | idx = 0
93 | iterator = iter(dataset)
94 | more_examples = True
95 | while more_examples:
96 | try:
97 | example = next(iterator)
98 | idx += 1
99 | except StopIteration:
100 | more_examples = False
101 | example = None
102 |
103 | add_concat_token = False
104 | if example:
105 | example_len = len(example["input_ids"])
106 | add_concat_token = example["input_ids"][-1] != self.concat_token_id
107 | else:
108 | example_len = 0
109 |
110 | if not example_len or (
111 | buffer_len + int(add_concat_token) + example_len > self.seq_length
112 | ):
113 | if buffer["input_ids"]:
114 | input_ids = torch.cat(buffer["input_ids"], dim=-1)[
115 | : self.seq_length
116 | ]
117 | attention_mask = torch.cat(buffer["attention_mask"], dim=-1)[
118 | : self.seq_length
119 | ]
120 | position_ids = torch.cat(buffer["position_ids"], dim=-1)[
121 | : self.seq_length
122 | ]
123 | labels = torch.cat(buffer["labels"], dim=-1)[: self.seq_length]
124 | if labels.size() == input_ids.size() and (
125 | attention_mask.size() == input_ids.size()
126 | ):
127 | yield {
128 | "input_ids": input_ids,
129 | "labels": labels,
130 | "attention_mask": attention_mask,
131 | "position_ids": position_ids,
132 | }
133 | else:
134 | LOG.warning(
135 | f"dropping batch due to tensor size mismatch input_ids: {input_ids.size()}, labels: {labels.size()}, attention_mask: {attention_mask.size()}"
136 | )
137 | buffer = {
138 | "input_ids": [],
139 | "attention_mask": [],
140 | "labels": [],
141 | "position_ids": [],
142 | }
143 | buffer_len = 0
144 | idx = 1
145 |
146 | if example:
147 | # FIXME
148 | # just going to drop data points that are too long
149 | if len(example["input_ids"]) <= self.seq_length:
150 | input_ids = example["input_ids"]
151 | attention_mask = example["attention_mask"]
152 | labels = example["labels"]
153 |
154 | if add_concat_token:
155 | input_ids.append(self.concat_token_id)
156 | attention_mask.append(1)
157 | labels.append(self.concat_token_id)
158 |
159 | input_ids_with_concat = torch.tensor(
160 | input_ids, dtype=self.tokens_dtype
161 | )
162 | attention_mask_with_concat = torch.tensor(
163 | [idx * m for m in attention_mask], dtype=torch.int16
164 | )
165 | labels_with_concat = torch.tensor(
166 | labels, dtype=self.tokens_dtype
167 | )
168 | position_ids = torch.arange(
169 | len(input_ids), dtype=self.tokens_dtype
170 | )
171 |
172 | buffer["input_ids"].append(input_ids_with_concat)
173 | buffer["attention_mask"].append(attention_mask_with_concat)
174 | buffer["labels"].append(labels_with_concat)
175 | buffer["position_ids"].append(position_ids)
176 | buffer_len += len(input_ids)
177 |
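# Standalone illustration of the token-dtype selection used by ConstantLengthDataset above:
# pick the smallest integer dtype that can hold every token id, which shrinks the packing buffers.
import torch


def smallest_token_dtype(vocab_size: int) -> torch.dtype:
    if vocab_size <= torch.iinfo(torch.int16).max:
        return torch.int16
    if vocab_size <= torch.iinfo(torch.int32).max:
        return torch.int32
    return torch.int64


print(smallest_token_dtype(32_000))   # torch.int16 (e.g. a Llama/Mistral-sized vocab)
print(smallest_token_dtype(151_000))  # torch.int32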
--------------------------------------------------------------------------------
/training/axolotl/src/axolotl/utils/distributed.py:
--------------------------------------------------------------------------------
1 | """
2 | utility helpers for distributed checks
3 | """
4 | import os
5 | import pickle # nosec
6 | from contextlib import contextmanager
7 |
8 | import torch
9 | import torch.distributed as dist
10 | from accelerate import Accelerator
11 |
12 | accelerate = None # pylint: disable=invalid-name
13 |
14 |
15 | def load_accelerate():
16 | global accelerate # pylint: disable=global-statement
17 | accelerate = Accelerator()
18 |
19 |
20 | def is_distributed():
21 | """
22 | Check if distributed training is initialized.
23 | """
24 | global accelerate # pylint: disable=global-statement
25 | if not accelerate:
26 | accelerate = Accelerator()
27 | return dist.is_available() and dist.is_initialized()
28 |
29 |
30 | def barrier():
31 | """
32 | Acts as a barrier to wait for all processes. This ensures that all processes
33 | reach the barrier before proceeding further.
34 | """
35 | if is_distributed():
36 | dist.barrier()
37 |
38 |
39 | def is_main_process():
40 | """
41 | Check if the current process is the main process.
42 | If not in distributed mode, always return True.
43 | """
44 | if not is_distributed():
45 | return True
46 | return dist.get_rank() == 0
47 |
48 |
49 | def get_world_size():
50 | return int(os.getenv("WORLD_SIZE", "1"))
51 |
52 |
53 | @contextmanager
54 | def zero_first(is_main):
55 | """
56 | runs the wrapped context so that rank 0 runs first before other ranks
57 | """
58 | if not is_main: # other ranks wait first
59 | barrier()
60 | yield
61 | if is_main: # then rank 0 waits after it has run the context
62 | barrier()
63 |
64 |
65 | def gather_scalar_from_all_ranks(fn, world_size=1): # pylint: disable=invalid-name
66 | """
67 | Run a callable 'fn' on all ranks and gather the results on the specified rank.
68 |
69 | Args:
70 | - fn (callable): A function that computes the value. This should not have any side effects.
71 |       Results are gathered on rank 0 (the main process).
72 | - world_size (int, optional): Total number of processes in the current distributed setup.
73 |
74 | Returns:
75 | - A list of computed values from all ranks if on the gathering rank, otherwise None.
76 | """
77 | value_scalar = fn()
78 | if not is_distributed():
79 | return [value_scalar]
80 | value_tensor = torch.tensor(
81 | value_scalar, device=torch.cuda.current_device()
82 | ).float()
83 |
84 | if not is_main_process():
85 | dist.gather(value_tensor, dst=0)
86 | else:
87 | gathered_tensors = [torch.zeros_like(value_tensor) for _ in range(world_size)]
88 | dist.gather(value_tensor, gather_list=gathered_tensors, dst=0)
89 |
90 | # Convert tensors back to their original type (int or float)
91 | gathered_values = []
92 | for tensor in gathered_tensors:
93 | if tensor == tensor.int():
94 | gathered_values.append(int(tensor.item()))
95 | else:
96 | gathered_values.append(float(tensor.item()))
97 | return gathered_values
98 | return None
99 |
100 |
101 | def broadcast_dict(vals: dict):
102 | if not is_distributed():
103 | return vals
104 |
105 | if is_main_process():
106 | data_byte = pickle.dumps(vals)
107 | data_tensor = torch.ByteTensor(list(data_byte)).to("cuda")
108 | data_size = torch.IntTensor([len(data_byte)]).to("cuda")
109 | else:
110 | data_tensor = torch.empty([1024], dtype=torch.uint8, device="cuda")
111 | data_size = torch.IntTensor([0]).to("cuda")
112 |
113 | dist.broadcast(data_size, 0)
114 | if not is_main_process():
115 | # resize
116 | data_tensor = data_tensor.new_empty([data_size.item()])
117 |
118 | dist.broadcast(data_tensor, 0)
119 |
120 | if not is_main_process():
121 | data_list = data_tensor.cpu().tolist()
122 | data_byte = bytes(data_list[: data_size.item()])
123 | vals = pickle.loads(data_byte) # nosec
124 |
125 | return vals
126 |
127 |
128 | def compute_and_broadcast(fn): # pylint: disable=invalid-name
129 | """
130 |     Compute a value using the function 'fn' on the main process (rank 0) only.
131 | The value is then broadcasted to all other ranks.
132 |
133 | Args:
134 | - fn (callable): A function that computes the value. This should not have any side effects.
135 |       The function runs only on rank 0; other ranks receive the broadcast result.
136 |
137 | Returns:
138 | - The computed value (int or float).
139 | """
140 | if is_main_process():
141 | value_scalar = fn()
142 | value_tensor = torch.tensor(
143 | value_scalar, device=torch.cuda.current_device()
144 | ).float()
145 | else:
146 | value_tensor = torch.tensor(
147 | 0.0, device=torch.cuda.current_device()
148 | ) # Placeholder tensor
149 |
150 | # Broadcast the tensor to all processes.
151 | barrier()
152 | dist.broadcast(value_tensor, src=0)
153 |
154 | # Convert the tensor back to its original type (int or float)
155 | if value_tensor == value_tensor.int():
156 | return int(value_tensor.item())
157 | return float(value_tensor.item())
158 |
159 |
160 | def gather_from_all_ranks(fn, world_size=1): # pylint: disable=invalid-name
161 | """
162 | Run a callable 'fn' on all ranks and gather the results on the specified rank.
163 |
164 | Args:
165 | - fn (callable): A function that computes the value. This should not have any side effects.
166 |       Results are gathered on rank 0 (the main process).
167 | - world_size (int, optional): Total number of processes in the current distributed setup.
168 |
169 | Returns:
170 | - A list of computed values from all ranks if on the gathering rank, otherwise None.
171 | """
172 | value_scalar = fn()
173 | value_tensor = torch.tensor(
174 | value_scalar, device=torch.cuda.current_device()
175 | ).float()
176 |
177 | # Placeholder tensor for gathering results
178 | if is_main_process():
179 | gathered_tensors = [torch.zeros_like(value_tensor) for _ in range(world_size)]
180 | else:
181 | gathered_tensors = None
182 |
183 | dist.gather(value_tensor, gather_list=gathered_tensors, dst=0)
184 |
185 | if is_main_process():
186 | # Convert tensors back to their original type (int or float)
187 | gathered_values = []
188 | for tensor in gathered_tensors:
189 | if tensor == tensor.int():
190 | gathered_values.append(int(tensor.item()))
191 | else:
192 | gathered_values.append(float(tensor.item()))
193 | return gathered_values
194 | return None
195 |
196 |
197 | def reduce_and_broadcast(fn1, fn2):
198 | """
199 | Run a callable 'fn1' on all ranks, gather the results, reduce them using 'fn2',
200 | and then broadcast the reduced result to all ranks.
201 |
202 | Args:
203 |     - fn1 (callable): A function that computes the value on each rank.
204 |     - fn2 (callable): A reduction function that takes a list of values and returns a single value.
205 |       The world size is taken from the active distributed process group.
206 |
207 |     Returns:
208 |     - The reduced and broadcast value, identical on every rank.
209 | """
210 |
211 | # Gather values from all ranks using fn1
212 | if not is_distributed():
213 | return fn2([fn1()])
214 |
215 | gathered_values = gather_from_all_ranks(fn1, world_size=dist.get_world_size())
216 |
217 | # Use compute_and_broadcast to compute the reduced value on the main process
218 | # and then broadcast it to all ranks
219 | return compute_and_broadcast(lambda: fn2(gathered_values))
220 |
--------------------------------------------------------------------------------
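The distributed helpers above all follow the same pattern: compute a scalar locally, move it into a CUDA tensor, then let torch.distributed gather or broadcast it, falling back to plain Python when not running distributed. Below is a minimal sketch of driving reduce_and_broadcast; the per_rank_token_count function is a hypothetical stand-in for any per-rank measurement, and the sketch assumes the process group (and CUDA device) has already been initialized, e.g. via torchrun.

# Minimal sketch (assumptions noted above): agree on a single value across ranks.
from axolotl.utils.distributed import reduce_and_broadcast

def per_rank_token_count():
    # Hypothetical per-rank measurement; replace with any side-effect-free callable.
    return 123.0

# Each rank runs fn1, rank 0 gathers the values and reduces them with fn2 (max here),
# and the reduced result is broadcast back so every rank returns the same number.
agreed = reduce_and_broadcast(per_rank_token_count, lambda values: max(values))
print(agreed)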
/data_prep/combine_datasets.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import random
4 | from datasets import load_dataset
5 |
6 | # This path contains all the datasets: https://drive.google.com/drive/folders/1bPfxrwcgGrX3-3CHJ0SLckY22zgdPwCj
7 | # TODO: Refer to this for a description of each dataset:
8 |
9 | """
10 | How to run the script:
11 | 1. Install the necessary libraries.
12 | 2. Replace the local dataset paths with the datasets downloaded from the Drive folder above.
13 | 3. Replace the output paths at the end of the script to save the combined dataset to your own local paths.
14 | """
15 |
16 | def fetch_instruction(task_file_name):
17 | with open(os.path.join("/home/minimalist/work/natural-instructions/tasks", task_file_name)) as f:
18 | task_json = json.load(f)
19 | return task_json["Definition"][0]
20 |
21 | def prepare_quac_file():
22 | file_path = "/home/minimalist/Downloads/nips_training_data/quac_train.json"
23 | data = []
24 | lines = open(file_path, "r").readlines()
25 | for line in lines:
26 | rec = json.loads(line.strip())
27 | if len(f"{rec['instruction']} {rec['input']}".split(" ")) > 800:
28 | continue
29 | data.append(rec)
30 | random.shuffle(data)
31 | sampled_data = data
32 | print(f"quac dataset size: {len(sampled_data)}")
33 | return sampled_data
34 |
35 |
36 | def prepare_openQA_files():
37 | data = []
38 |
39 | file_paths = ["/home/minimalist/Downloads/nips_training_data/openQA_train.json"
40 | ]
41 | for file_path in file_paths:
42 | lines = open(file_path, "r").readlines()
43 | for line in lines:
44 | rec = json.loads(line.strip())
45 |             rec['instruction'] = "You are given an incomplete sentence along with a few reference completions. Only one of the reference completions is coherent and correct based on the context in the sentence, and all other reference completions are incorrect. Your task is to choose the best matching completion from the reference options and return it."
46 | if len(f"{rec['instruction']} {rec['input']}".split(" ")) > 800:
47 | continue
48 | data.append(rec)
49 | random.shuffle(data)
50 | sampled_data = data
51 | print(f"OpenQA dataset size: {len(sampled_data)}")
52 | return sampled_data
53 |
54 |
55 | def prepare_cnn_dailymail_file():
56 | file_path = "/home/minimalist/Downloads/nips_training_data/cnn_dailymail_summarization_random_train.json"
57 | data = []
58 | lines = open(file_path, "r").readlines()
59 | for line in lines:
60 | rec = json.loads(line.strip())
61 | if len(f"{rec['instruction']} {rec['input']}".split(" ")) > 800:
62 | continue
63 | data.append(rec)
64 | random.shuffle(data)
65 | sampled_data = data
66 | print(f"CNN/DailyMail dataset size: {len(sampled_data)}")
67 | return sampled_data
68 |
69 | def prepare_maths_file():
70 | file_path = "/home/minimalist/Downloads/nips_training_data/math_reasoning.json"
71 |
72 | with open(file_path) as f:
73 | data = json.load(f)
74 |
75 | sampled_data = []
76 | for rec in data:
77 | if len(f"{rec['instruction']} {rec['input']}".split(" ")) < 800 and "print(" not in rec['output']:
78 | sampled_data.append(rec)
79 |
80 | random.shuffle(sampled_data)
81 | sampled_data = sampled_data[:250000]
82 | print(f"maths_reasoning dataset size: {len(sampled_data)}")
83 | return sampled_data
84 |
85 |
86 | def prepare_platypus_file():
87 | file_path = "/home/minimalist/Downloads/nips_training_data/platypus_no_llm.json"
88 |
89 | with open(file_path) as f:
90 | data = json.load(f)
91 |
92 | sampled_data = []
93 | for rec in data:
94 | if len(f"{rec['instruction']} {rec['input']}".split(" ")) < 800:
95 | sampled_data.append(rec)
96 |
97 | random.shuffle(sampled_data)
98 | sampled_data = sampled_data
99 | print(f"platypus dataset size: {len(sampled_data)}")
100 |     return sampled_data
101 |
102 |
103 | def prepare_super_natural_generation_tasks_samples_file():
104 | file_path = "/home/minimalist/Downloads/nips_training_data/natural_instructions_generation_tasks_dataset.json"
105 |     task_to_instruction_map = {}
106 |
107 | with open(file_path) as f:
108 | data = json.load(f)
109 | print("Loaded file")
110 |
111 | sampled_data = []
112 | for i, rec in enumerate(data):
113 |         if rec["task_file"] not in task_to_instruction_map:
114 |             rec['instruction'] = fetch_instruction(rec["task_file"])
115 |             task_to_instruction_map[rec["task_file"]] = rec['instruction']
116 |         else:
117 |             rec['instruction'] = task_to_instruction_map[rec["task_file"]]
118 | rec.pop('task_file', None)
119 | if len(f"{rec['instruction']} {rec['input']}".split(" ")) < 800:
120 | sampled_data.append(rec)
121 |
122 | random.shuffle(sampled_data)
123 | sampled_data = sampled_data
124 | print(f"Super Natural Sentence dataset size: {len(sampled_data)}")
125 | return sampled_data
126 |
127 |
128 | def prepare_super_natural_exact_match_tasks_samples_file():
129 | file_path = "/home/minimalist/Downloads/nips_training_data/natural_instructions_exact_match_tasks_dataset.json"
130 |
131 | with open(file_path) as f:
132 | data = json.load(f)
133 |
134 | random.shuffle(data)
135 | sampled_data = data
136 | print(f"Super Natural One Word dataset size: {len(sampled_data)}")
137 | return sampled_data
138 |
139 |
140 |
141 | def generate_lima_prompt(example):
142 | """Generates a standardized message to prompt the model with an instruction, optional input and a
143 | 'response' field."""
144 |
145 | if example["input"]:
146 | return (
147 | "Below is an instruction that describes a task, paired with an input that provides further context. "
148 | "Write a response that appropriately completes the request.\n\n"
149 | f"### Instruction:\n{example['instruction']}\n\n### Input:\n{example['input']}\n\n### Response:"
150 | )
151 | return (
152 | "Below is an instruction that describes a task. "
153 | "Write a response that appropriately completes the request.\n\n"
154 | f"### Instruction:\n{example['instruction']}\n\n### Response:"
155 | )
156 |
157 |
158 | def prepare_lima_file():
159 | dataset = load_dataset("GAIR/lima")["train"]
160 | data = []
161 | for convo in dataset["conversations"]:
162 | rec = {"instruction": "", "input": convo[0], "output": convo[1]}
163 | if len(f"{rec['instruction']} {rec['input']}".split(" ")) < 800:
164 | data.append(rec)
165 | print(f"LIMA dataset size: {len(data)}")
166 | return data
167 |
168 |
169 |
170 | data = []
171 | data += prepare_quac_file()
172 | data += prepare_openQA_files()
173 | data += prepare_cnn_dailymail_file()
174 | data += prepare_maths_file()
175 | data += prepare_platypus_file()
176 | data += prepare_super_natural_generation_tasks_samples_file()
177 | data += prepare_super_natural_exact_match_tasks_samples_file()
178 | data += prepare_lima_file()
179 |
180 | print(f"total data size: {len(data)}")
181 |
182 | random.shuffle(data)
183 |
184 | instruction_prefix = "### Instruction:\n"
185 | input_prefix = "### Input:\n"
186 | response_prefix = "### Response:\n"
187 | final_dataset = []
188 | all_inputs = set()
189 | for each in data:
190 | if each['input'] not in all_inputs:
191 | all_inputs.add(each['input'])
192 | else:
193 | continue
194 |
195 | if len(each['input'].split(" ")) <= 4 or len(each['output'].split(" ")) < 1:
196 | continue
197 |
198 | rec = {}
199 | rec['question'] = ''
200 | if len(each['instruction'].strip()) > 0:
201 | rec['question'] += f"{instruction_prefix}{each['instruction']}\n\n"
202 | rec['question'] += f"{input_prefix}{each['input']}\n\n"
203 | rec['question'] += f"{response_prefix}"
204 | rec['answer'] = f"{each['output']}"
205 | final_dataset.append(rec)
206 |
207 | print(f"final_dataset size: {len(final_dataset)}")
208 | train_dataset = final_dataset[:int(0.98*len(final_dataset))]
209 | eval_dataset = final_dataset[int(0.98*len(final_dataset)):]
210 |
211 | with open(f"/home/minimalist/Downloads/nips_training_data/train_dataset.json", 'w') as f:
212 | json.dump(train_dataset, f, indent=1)
213 |
214 | with open(f"/home/minimalist/Downloads/nips_training_data/eval_dataset.json", 'w') as f:
215 | json.dump(eval_dataset, f, indent=1)
216 |
--------------------------------------------------------------------------------
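Every record written above stores the full "### Instruction / ### Input / ### Response" scaffolding in 'question' and the target text in 'answer'. The sketch below shows how a downstream consumer might rebuild complete supervised strings from the saved train split; the file path simply mirrors the output path used above, and both it and the concatenation format are assumptions for illustration.

# Sketch: reload the combined train split and rebuild prompt + completion strings.
import json

with open("/home/minimalist/Downloads/nips_training_data/train_dataset.json") as f:
    records = json.load(f)

# 'question' already ends with "### Response:\n", so the supervised text is a plain concat.
texts = [f"{rec['question']}{rec['answer']}" for rec in records]
print(len(texts), texts[0][:200])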
/training/axolotl/src/axolotl/prompt_strategies/llama2_chat.py:
--------------------------------------------------------------------------------
1 | """
2 | Prompt Strategy for finetuning Llama2 chat models
3 | see also https://github.com/facebookresearch/llama/blob/6c7fe276574e78057f917549435a2554000a876d/llama/generation.py#L213 for a reference implementation.
4 |
5 | This implementation is based on the Vicuna PR and the fastchat repo, see also:
6 | https://github.com/lm-sys/FastChat/blob/cdd7730686cb1bf9ae2b768ee171bdf7d1ff04f3/fastchat/conversation.py#L847
7 |
8 | Use dataset type: "llama2_chat" in config.yml to use this prompt style.
9 |
10 | E.g. in the config.yml:
11 | ```
12 | datasets:
13 | - path: llama_finetune_train.jsonl
14 | type: llama2_chat
15 | ```
16 |
17 | The dataset itself should look like this:
18 | ```
19 | {'conversations':[{"from": "human", "value": "Who are you?"}, {"from": "gpt", "value": "I am Vicuna"},...]}
20 | ```
21 | in a jsonl file. The first message should be from the human, the second from gpt.
22 | For a custom system message, the first "from" can be "system" (followed by alternating "human" and "gpt" turns).
23 |
24 | Important: Don't use "special_tokens:" in your config.yml if you are not sure what you are doing!
25 | """
26 |
27 | import logging
28 | from dataclasses import dataclass, field
29 | from typing import Generator, List, Sequence
30 |
31 | from axolotl.prompt_tokenizers import PromptTokenizingStrategy
32 | from axolotl.prompters import IGNORE_TOKEN_ID, SHAREGPT_ASSERTION_FAILED_ROLE
33 |
34 |
35 | @dataclass
36 | class Llama2ChatConversation:
37 | """A class that manages prompt templates and keeps all conversation history.
38 | copied from https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py"""
39 |
40 | name: str = "llama2"
41 | # The system prompt
42 | system: str = (
43 |         "[INST] <<SYS>>\nYou are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. "
44 | "Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. "
45 | "Please ensure that your responses are socially unbiased and positive in nature.\n\n"
46 | "If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. "
47 |         "If you don't know the answer to a question, please don't share false information.\n<</SYS>>\n\n"
48 | )
49 | roles: Sequence[str] = ("[INST]", "[/INST]")
50 | messages: List[List[str]] = field(default_factory=list)
51 | offset: int = 0
52 | sep = " "
53 | sep2 = " "
54 | stop_token_ids = [2]
55 |
56 | def get_prompt(self) -> str:
57 | """Get the prompt for generation."""
58 | seps = [self.sep, self.sep2]
59 | ret = ""
60 | for i, (role, message) in enumerate(self.messages):
61 | if (i == len(self.messages) - 1) and (role == self.roles[0]):
62 | # last message is from user (due to length),
63 | # return prompt without it for training
64 | return ret
65 | if i == 0:
66 | ret += self.system + message.strip()
67 | else:
68 | ret += role + " " + message.strip() + seps[i % 2]
69 | return ret
70 |
71 | def append_message(self, role: str, message: str):
72 | """Append a new message."""
73 | self.messages.append([role, message])
74 |
75 |
76 | class LLama2ChatTokenizingStrategy(PromptTokenizingStrategy):
77 | """
78 |     Tokenizing strategy for Llama2 chat prompts (ShareGPT-style conversation data).
79 | adapted from https://github.com/lm-sys/FastChat/blob/main/fastchat/train/train.py
80 | """
81 |
82 | def __init__(self, *args, **kwargs):
83 | super().__init__(*args, **kwargs)
84 | self.sequence_len = 4096
85 |         self.tokenizer.add_special_tokens({"pad_token": "<pad>"})
86 | # https://huggingface.co/meta-llama/Llama-2-7b-chat-hf/blob/main/added_tokens.json
87 |
88 | def tokenize_prompt(self, prompt):
89 | conv = next(self.prompter.build_prompt(prompt))
90 | conversation_str = conv.get_prompt()
91 |
92 | # Tokenize conversations
93 | input_ids = self.tokenizer(
94 | conversation_str,
95 | return_tensors="pt",
96 | padding="max_length",
97 | max_length=self.sequence_len,
98 | truncation=True,
99 | ).input_ids[0]
100 | target = input_ids.clone()
101 |
102 | # Mask targets. Only compute loss on the assistant outputs.
103 | sep = conv.roles[1]
104 |
105 | total_len = int(target.ne(self.tokenizer.pad_token_id).sum())
106 |
107 | turns = conversation_str.split(conv.sep2)
108 | cur_len = 1
109 | target[:cur_len] = IGNORE_TOKEN_ID
110 | for turn in turns:
111 | if turn == "":
112 | break
113 | turn_len = len(self.tokenizer(turn).input_ids)
114 |
115 | parts = turn.split(sep)
116 | if len(parts) != 2:
117 | break
118 | parts[0] += sep
119 | # "-1" is hardcoded for the LLaMA tokenizer to make the offset correct.
120 | instruction_len = len(self.tokenizer(parts[0]).input_ids) - 1
121 |
122 | # Ignore the user instructions
123 | target[cur_len - 1 : cur_len + instruction_len] = IGNORE_TOKEN_ID
124 | cur_len += turn_len + 2 # due to length of role token
125 |
126 | target[cur_len:] = IGNORE_TOKEN_ID
127 |
128 | if cur_len < self.sequence_len:
129 | if cur_len != total_len:
130 | target[:] = IGNORE_TOKEN_ID
131 | logging.warning(
132 | f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
133 | f" (ignored)"
134 | )
135 |
136 | attention_mask = input_ids.ne(self.tokenizer.pad_token_id).tolist()
137 | input_ids = input_ids.tolist()
138 | target = target.tolist()
139 | # this is a fix for the tokenizer which tokenizes [ differently with eos tokens and
140 | # follows the original llama implementation
141 | for i in range(2, total_len - 2):
142 | if input_ids[i] == 29961:
143 | input_ids[i] = 518
144 | if target[i] == 29961:
145 | target[i] = 518
146 | return {
147 | "input_ids": input_ids,
148 | "labels": target,
149 | "attention_mask": attention_mask,
150 | }
151 |
152 |
153 | class Llama2ChatPrompter: # pylint: disable=too-few-public-methods
154 | """
155 | A prompter that generates prompts for Llama2 models.
156 | """
157 |
158 | system_prompt = (
159 |         "[INST] <<SYS>>\nYou are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. "
160 | "Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. "
161 | "Please ensure that your responses are socially unbiased and positive in nature.\n\n"
162 | "If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. "
163 |         "If you don't know the answer to a question, please don't share false information.\n<</SYS>>\n\n"
164 | )
165 |
166 | def build_prompt(self, source) -> Generator[Llama2ChatConversation, None, None]:
167 | # see https://github.com/lm-sys/FastChat/blob/da0641e567cf93756b0978ab5a6b092e96f06240/fastchat/train/train.py#L78
168 | source = source["conversations"] # fix data structure for datasets
169 |
170 | # if system prompt provided, use it
171 | if source[0]["from"] == "system":
172 |             system = f"[INST] <<SYS>>\n{source[0]['value']}\n<</SYS>>\n\n"
173 | source = source[1:]
174 | else:
175 | system = self.system_prompt
176 |
177 | conv = Llama2ChatConversation(system=system)
178 |
179 | if len(source) < 2:
180 | # If there isn't a back and forth conversation, ignore it
181 | # also happens on the data splitting leaving empty conversations
182 | raise IndexError
183 |
184 | roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
185 |
186 | if roles[source[0]["from"]] != conv.roles[0]:
187 | # Skip the first one if it is not from human
188 | source = source[1:]
189 |
190 | conv.messages = [] # pylint: disable=R0801
191 | for j, sentence in enumerate(source):
192 | role = roles[sentence["from"]]
193 | assert role == conv.roles[j % 2], SHAREGPT_ASSERTION_FAILED_ROLE
194 | if sentence["value"]:
195 | conv.append_message(role, sentence["value"])
196 | yield conv
197 |
198 |
199 | def load(tokenizer, cfg) -> LLama2ChatTokenizingStrategy:
200 | return LLama2ChatTokenizingStrategy(
201 | Llama2ChatPrompter(),
202 | tokenizer,
203 | cfg.train_on_inputs,
204 | cfg.sequence_len,
205 | )
206 |
--------------------------------------------------------------------------------
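A quick way to see the exact string this strategy trains on is to drive the prompter directly with a tiny ShareGPT-style conversation; no tokenizer is needed for that step. The conversation below is made up for illustration, and the sketch assumes this module is importable as axolotl.prompt_strategies.llama2_chat.

# Sketch: render a Llama2 chat prompt without tokenizing it.
from axolotl.prompt_strategies.llama2_chat import Llama2ChatPrompter

sample = {
    "conversations": [
        {"from": "human", "value": "Who are you?"},
        {"from": "gpt", "value": "I am a helpful assistant."},
    ]
}

conv = next(Llama2ChatPrompter().build_prompt(sample))
# get_prompt() prepends the [INST] <<SYS>> ... <</SYS>> system block and alternates
# [INST]/[/INST] turns; tokenize_prompt() later masks everything but the gpt turns.
print(conv.get_prompt())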
/training/axolotl/src/axolotl/cli/__init__.py:
--------------------------------------------------------------------------------
1 | """Prepare and train a model on a dataset. Can also infer from a model or merge lora"""
2 |
3 | import importlib
4 | import logging
5 | import os
6 | import random
7 | import sys
8 | from pathlib import Path
9 | from typing import Any, Dict, List, Optional, Union
10 |
11 | import torch
12 | import yaml
13 |
14 | # add src to the pythonpath so we don't need to pip install this
15 | from accelerate.commands.config import config_args
16 | from art import text2art
17 | from huggingface_hub import HfApi
18 | from huggingface_hub.utils import LocalTokenNotFoundError
19 | from transformers import GenerationConfig, TextStreamer
20 |
21 | from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer
22 | from axolotl.logging_config import configure_logging
23 | from axolotl.train import TrainDatasetMeta
24 | from axolotl.utils.config import normalize_config, validate_config
25 | from axolotl.utils.data import prepare_dataset
26 | from axolotl.utils.dict import DictDefault
27 | from axolotl.utils.distributed import is_main_process
28 | from axolotl.utils.models import load_tokenizer
29 | from axolotl.utils.tokenization import check_dataset_labels
30 | from axolotl.utils.wandb_ import setup_wandb_env_vars
31 |
32 | project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
33 | src_dir = os.path.join(project_root, "src")
34 | sys.path.insert(0, src_dir)
35 |
36 | configure_logging()
37 | LOG = logging.getLogger("axolotl.scripts")
38 |
39 | os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
40 |
41 |
42 | def print_axolotl_text_art(suffix=None):
43 | font = "nancyj"
44 | ascii_text = " axolotl"
45 | if suffix:
46 | ascii_text += f" x {suffix}"
47 |     ascii_art = text2art(ascii_text, font=font)
48 |
49 | if is_main_process():
50 | print(ascii_art)
51 |
52 |
53 | def get_multi_line_input() -> Optional[str]:
54 | print("Give me an instruction (Ctrl + D to submit): ")
55 | instruction = ""
56 | for line in sys.stdin:
57 | instruction += line # pylint: disable=consider-using-join
58 | # instruction = pathlib.Path("/proc/self/fd/0").read_text()
59 | return instruction
60 |
61 |
62 | def do_merge_lora(
63 | *,
64 | cfg: DictDefault,
65 | cli_args: TrainerCliArgs,
66 | ):
67 | model, tokenizer = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args)
68 | safe_serialization = cfg.save_safetensors is True
69 |
70 | LOG.info("running merge of LoRA with base model")
71 | model = model.merge_and_unload()
72 | model.to(dtype=torch.float16)
73 |
74 | if cfg.local_rank == 0:
75 | LOG.info(f"saving merged model to: {str(Path(cfg.output_dir) / 'merged')}")
76 | model.save_pretrained(
77 | str(Path(cfg.output_dir) / "merged"),
78 | safe_serialization=safe_serialization,
79 | )
80 | tokenizer.save_pretrained(str(Path(cfg.output_dir) / "merged"))
81 |
82 |
83 | def do_inference(
84 | *,
85 | cfg: DictDefault,
86 | cli_args: TrainerCliArgs,
87 | ):
88 | model, tokenizer = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args)
89 | prompter = cli_args.prompter
90 |     default_tokens = {"unk_token": "<unk>", "bos_token": "<s>", "eos_token": "</s>"}
91 |
92 | for token, symbol in default_tokens.items():
93 | # If the token isn't already specified in the config, add it
94 | if not (cfg.special_tokens and token in cfg.special_tokens):
95 | tokenizer.add_special_tokens({token: symbol})
96 |
97 | prompter_module = None
98 | if prompter:
99 | prompter_module = getattr(
100 | importlib.import_module("axolotl.prompters"), prompter
101 | )
102 |
103 | if cfg.landmark_attention:
104 | from axolotl.monkeypatch.llama_landmark_attn import set_model_mem_id
105 |
106 | set_model_mem_id(model, tokenizer)
107 | model.set_mem_cache_args(
108 | max_seq_len=255, mem_freq=50, top_k=5, max_cache_size=None
109 | )
110 |
111 | model = model.to(cfg.device)
112 |
113 | while True:
114 | print("=" * 80)
115 | # support for multiline inputs
116 | instruction = get_multi_line_input()
117 | if not instruction:
118 | return
119 | if prompter_module:
120 | prompt: str = next(
121 | prompter_module().build_prompt(instruction=instruction.strip("\n"))
122 | )
123 | else:
124 | prompt = instruction.strip()
125 | batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)
126 |
127 | print("=" * 40)
128 | model.eval()
129 | with torch.no_grad():
130 | generation_config = GenerationConfig(
131 | repetition_penalty=1.1,
132 | max_new_tokens=1024,
133 | temperature=0.9,
134 | top_p=0.95,
135 | top_k=40,
136 | bos_token_id=tokenizer.bos_token_id,
137 | eos_token_id=tokenizer.eos_token_id,
138 | pad_token_id=tokenizer.pad_token_id,
139 | do_sample=True,
140 | use_cache=True,
141 | return_dict_in_generate=True,
142 | output_attentions=False,
143 | output_hidden_states=False,
144 | output_scores=False,
145 | )
146 | streamer = TextStreamer(tokenizer)
147 | generated = model.generate(
148 | inputs=batch["input_ids"].to(cfg.device),
149 | generation_config=generation_config,
150 | streamer=streamer,
151 | )
152 | print("=" * 40)
153 | print(tokenizer.decode(generated["sequences"].cpu().tolist()[0]))
154 |
155 |
156 | def choose_config(path: Path):
157 | yaml_files = list(path.glob("*.yml"))
158 |
159 | if not yaml_files:
160 | raise ValueError(
161 | "No YAML config files found in the specified directory. Are you using a .yml extension?"
162 | )
163 |
164 | if len(yaml_files) == 1:
165 | print(f"Using default YAML file '{yaml_files[0]}'")
166 | return yaml_files[0]
167 |
168 | print("Choose a YAML file:")
169 | for idx, file in enumerate(yaml_files):
170 | print(f"{idx + 1}. {file}")
171 |
172 | chosen_file = None
173 | while chosen_file is None:
174 | try:
175 | choice = int(input("Enter the number of your choice: "))
176 | if 1 <= choice <= len(yaml_files):
177 | chosen_file = yaml_files[choice - 1]
178 | else:
179 | print("Invalid choice. Please choose a number from the list.")
180 | except ValueError:
181 | print("Invalid input. Please enter a number.")
182 |
183 | return chosen_file
184 |
185 |
186 | def check_not_in(list1: List[str], list2: Union[Dict[str, Any], List[str]]) -> bool:
187 | return not any(el in list2 for el in list1)
188 |
189 |
190 | def load_cfg(config: Path = Path("examples/"), **kwargs):
191 | if Path(config).is_dir():
192 | config = choose_config(config)
193 |
194 | # load the config from the yaml file
195 | with open(config, encoding="utf-8") as file:
196 | cfg: DictDefault = DictDefault(yaml.safe_load(file))
197 | cfg.axolotl_config_path = config
198 | # if there are any options passed in the cli, if it is something that seems valid from the yaml,
199 | # then overwrite the value
200 | cfg_keys = cfg.keys()
201 | for k, _ in kwargs.items():
202 | # if not strict, allow writing to cfg even if it's not in the yml already
203 | if k in cfg_keys or not cfg.strict:
204 | # handle booleans
205 | if isinstance(cfg[k], bool):
206 | cfg[k] = bool(kwargs[k])
207 | else:
208 | cfg[k] = kwargs[k]
209 |
210 | validate_config(cfg)
211 |
212 | normalize_config(cfg)
213 |
214 | setup_wandb_env_vars(cfg)
215 | return cfg
216 |
217 |
218 | def load_datasets(
219 | *,
220 | cfg: DictDefault,
221 | cli_args: TrainerCliArgs,
222 | ) -> TrainDatasetMeta:
223 | tokenizer = load_tokenizer(cfg)
224 |
225 | train_dataset, eval_dataset, total_num_steps = prepare_dataset(cfg, tokenizer)
226 |
227 | if cli_args.debug or cfg.debug:
228 | LOG.info("check_dataset_labels...")
229 | check_dataset_labels(
230 | train_dataset.select(
231 | [
232 | 0 # nosec
233 | for _ in range(cli_args.debug_num_examples)
234 | ]
235 | ),
236 | tokenizer,
237 | num_examples=cli_args.debug_num_examples,
238 | text_only=cli_args.debug_text_only,
239 | )
240 |
241 | return TrainDatasetMeta(
242 | train_dataset=train_dataset,
243 | eval_dataset=eval_dataset,
244 | total_num_steps=total_num_steps,
245 | )
246 |
247 |
248 | def check_accelerate_default_config():
249 | if Path(config_args.default_yaml_config_file).exists():
250 | LOG.warning(
251 | f"accelerate config file found at {config_args.default_yaml_config_file}. This can lead to unexpected errors"
252 | )
253 |
254 |
255 | def check_user_token():
256 | # Verify if token is valid
257 | api = HfApi()
258 | try:
259 | user_info = api.whoami()
260 | return bool(user_info)
261 | except LocalTokenNotFoundError:
262 | LOG.warning(
263 | "Error verifying HuggingFace token. Remember to log in using `huggingface-cli login` and get your access token from https://huggingface.co/settings/tokens if you want to use gated models or datasets."
264 | )
265 | return False
266 |
--------------------------------------------------------------------------------
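load_cfg() is the piece the CLI commands share: it resolves a YAML file (prompting for a choice when handed a directory), overlays any keyword overrides that look valid, then validates and normalizes the config and sets up the wandb environment variables. Below is a minimal sketch of calling it programmatically; the YAML path points at the example config shipped in this repo (adjust for your working directory), and the override key is an assumption.

# Sketch: load a config the same way the CLI subcommands do.
from pathlib import Path

from axolotl.cli import load_cfg

cfg = load_cfg(
    Path("examples/mistral/nips/nips_02.yml"),  # example config from this repo
    micro_batch_size=1,  # CLI-style override; applied if the key is valid or cfg is not strict
)
print(cfg.sequence_len, cfg.micro_batch_size)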
/training/axolotl/src/axolotl/utils/trainer.py:
--------------------------------------------------------------------------------
1 | """Module containing the Trainer class and related functions"""
2 | import logging
3 | import math
4 | import os
5 | from contextlib import contextmanager
6 | from functools import partial
7 | from typing import List
8 |
9 | import numpy as np
10 | import torch
11 | import torch.cuda
12 | import torch.distributed as dist
13 | from datasets import set_caching_enabled
14 | from torch.utils.data import DistributedSampler, RandomSampler
15 |
16 | from axolotl.core.trainer_builder import HFCausalTrainerBuilder
17 | from axolotl.utils.collators import DataCollatorForSeq2Seq
18 | from axolotl.utils.dataloader import MultipackDistributedDataloader
19 | from axolotl.utils.distributed import (
20 | is_distributed,
21 | is_main_process,
22 | reduce_and_broadcast,
23 | zero_first,
24 | )
25 |
26 | LOG = logging.getLogger("axolotl")
27 |
28 |
29 | @torch.jit.script
30 | def weighted_cross_entropy(
31 | logits: torch.Tensor, labels: torch.Tensor, weights: torch.Tensor
32 | ):
33 | # Flatten the logits, labels, and weights tensors
34 | logits = logits.view(
35 | -1, logits.size(-1)
36 | ) # logits becomes of shape [batch_size*sequence_length, vocab_size]
37 | labels = labels.view(-1) # labels becomes of shape [batch_size*sequence_length]
38 | weights = weights.view(-1) # weights becomes of shape [batch_size*sequence_length]
39 |
40 | # Compute the unweighted cross entropy loss
41 | losses = torch.nn.functional.cross_entropy(logits, labels, reduction="none")
42 |
43 | # Apply the weights to the losses and compute their sum
44 | return (weights * losses).sum()
45 |
46 |
47 | @torch.jit.script
48 | def create_weighted_mask(labels: torch.Tensor):
49 | # Check if the tensor is 2D. If not, unsqueeze it to make it 2D
50 | if len(labels.shape) == 1:
51 | labels = labels.unsqueeze(0)
52 |
53 | weights = torch.zeros_like(labels).float()
54 | for i in range(labels.shape[0]):
55 | mask = labels[i] != -100
56 |
57 | # Create a tensor to track group ids
58 | group_ids = torch.zeros_like(labels[i]).int()
59 | curr_group_id = 0
60 |
61 | for j in range(1, len(labels[i])):
62 | if mask[j] and not mask[j - 1]: # switch from masked to unmasked label
63 | curr_group_id += 1 # start new group
64 | group_ids[j] = (
65 | curr_group_id if mask[j] else 0
66 | ) # assign group id if unmasked label
67 |
68 | # Count only unmasked labels in each group
69 | group_counts = torch.bincount(group_ids[mask])
70 |
71 | mask_weights = torch.zeros_like(labels[i]).float()
72 | mask_weights[mask] = 1.0 / group_counts[group_ids[mask]]
73 |
74 | weights[i] = mask_weights
75 |
76 | return weights.squeeze() # squeeze the output to match the input dimension
77 |
78 |
79 | def trainer_weighted_loss(model_output, labels, shift_labels=True):
80 | logits = (
81 | model_output["logits"] if isinstance(model_output, dict) else model_output[0]
82 | )
83 | if shift_labels:
84 | logits = logits[..., :-1, :].contiguous()
85 | labels = labels[..., 1:].contiguous()
86 |
87 | weights = create_weighted_mask(labels)
88 | return weighted_cross_entropy(logits, labels, weights)
89 |
90 |
91 | def add_position_ids(sample):
92 | sample_len = len(sample["input_ids"])
93 | sample["position_ids"] = torch.arange(len(sample["input_ids"]))
94 | sample["length"] = sample_len
95 | return sample
96 |
97 |
98 | def add_length(sample):
99 | sample["length"] = len(sample["input_ids"])
100 | return sample
101 |
102 |
103 | def drop_long_seq(sample, sequence_len=2048):
104 | return len(sample["input_ids"]) <= sequence_len and len(sample["input_ids"]) > 0
105 |
106 |
107 | @contextmanager
108 | def disable_datasets_caching():
109 | try:
110 | set_caching_enabled(False)
111 | yield
112 | finally:
113 | set_caching_enabled(True)
114 |
115 |
116 | def process_datasets_for_packing(cfg, train_dataset, eval_dataset, tokenizer):
117 | drop_long = partial(drop_long_seq, sequence_len=cfg.sequence_len)
118 | with zero_first(is_main_process()):
119 | train_dataset = train_dataset.filter(drop_long, num_proc=cfg.dataset_processes)
120 | if eval_dataset:
121 | eval_dataset = eval_dataset.filter(
122 | drop_long, num_proc=cfg.dataset_processes
123 | )
124 |
125 | if cfg.group_by_length:
126 | train_dataset = train_dataset.map(
127 | add_length, num_proc=cfg.dataset_processes
128 | )
129 |
130 | if cfg.sample_packing:
131 | train_dataset = train_dataset.map(
132 | add_position_ids, num_proc=cfg.dataset_processes
133 | )
134 | if cfg.eval_sample_packing is not False:
135 | if eval_dataset:
136 | eval_dataset = eval_dataset.map(
137 | add_position_ids, num_proc=cfg.dataset_processes
138 | )
139 |
140 | # Phi doesn't want the attention_mask feature when training
141 | if "CodeGenTokenizer" in tokenizer.__class__.__name__ or (
142 | cfg.is_mistral_derived_model and cfg.flash_attention
143 | ):
144 | train_dataset = train_dataset.remove_columns("attention_mask")
145 | if eval_dataset:
146 | eval_dataset = eval_dataset.remove_columns("attention_mask")
147 |
148 | return train_dataset, eval_dataset
149 |
150 |
151 | def calculate_total_num_steps(cfg, train_dataset, tokenizer):
152 | if cfg.sample_packing:
153 |         # we have to drop anything longer than sequence len otherwise
154 | # flash attention with position ids fails
155 | if not cfg.total_num_tokens:
156 | LOG.info("calculating total_num_tokens")
157 | total_num_tokens = np.sum(
158 | train_dataset.data.column("input_ids")
159 | .to_pandas()
160 | .apply(lambda x: len(x)) # pylint: disable=unnecessary-lambda
161 | .values
162 | )
163 | LOG.info(f"total_num_tokens: {total_num_tokens}")
164 | cfg.total_num_tokens = total_num_tokens
165 |
166 | if not cfg.total_supervised_tokens:
167 | total_supervised_tokens = (
168 | train_dataset.data.column("labels")
169 | .to_pandas()
170 | .apply(lambda x: np.sum(np.array(x) != -100))
171 | .sum()
172 | )
173 |             LOG.info(f"total_supervised_tokens: {total_supervised_tokens}")
174 | cfg.total_supervised_tokens = total_supervised_tokens
175 |
176 | if cfg.sample_packing_eff_est:
177 | total_num_steps = (
178 | # match count to len est in dataloader
179 | (
180 | math.floor(
181 | 0.99
182 | * cfg.total_num_tokens
183 | / cfg.sample_packing_eff_est
184 | / cfg.sequence_len
185 | // cfg.batch_size
186 | // int(os.environ.get("WORLD_SIZE", 1))
187 | )
188 | - 1
189 | )
190 | * cfg.num_epochs
191 | )
192 | LOG.info(
193 | f"total_num_tokens: {cfg.total_num_tokens}, total_num_steps: {total_num_steps}"
194 | )
195 | else:
196 | if cfg.world_size > 1 and is_distributed():
197 | sampler = DistributedSampler(
198 | train_dataset,
199 | num_replicas=cfg.world_size,
200 | rank=dist.get_rank(),
201 | seed=cfg.seed or 42,
202 | )
203 | else:
204 | sampler = RandomSampler(train_dataset)
205 |
206 | data_loader = MultipackDistributedDataloader(
207 | train_dataset,
208 | batch_size=cfg.micro_batch_size,
209 | seq_max_length=cfg.max_packed_sequence_len or cfg.sequence_len,
210 | collate_fn=DataCollatorForSeq2Seq(
211 | tokenizer,
212 | return_tensors="pt",
213 | padding="longest",
214 | ),
215 | sampler=sampler,
216 | packing_efficiency_estimate=cfg.sample_packing_eff_est,
217 | sample_packing_seq_len_multiplier=cfg.micro_batch_size,
218 | device_count=int(os.environ.get("WORLD_SIZE", 1)),
219 | )
220 | data_loader_len = data_loader.len_w_stats()
221 | actual_eff = data_loader.efficiency()
222 | LOG.info(f"data_loader_len: {data_loader_len}")
223 | # FIXME: is there a bug here somewhere? the total num steps depends
224 | # on the agreed on value for sample_packing_eff_est
225 | total_num_steps = int(math.floor(data_loader_len * cfg.num_epochs))
226 |
227 | def calc_sample_packing_eff_est(estimates: List[float]):
228 | LOG.info(f"sample_packing_eff_est across ranks: {repr(estimates)}")
229 | return max(estimates)
230 |
231 | sample_packing_actual_eff_all = reduce_and_broadcast(
232 | lambda: actual_eff,
233 | calc_sample_packing_eff_est,
234 | )
235 | sample_packing_eff_est = (
236 | math.ceil(sample_packing_actual_eff_all * 100.0) / 100.0
237 | )
238 | cfg.sample_packing_eff_est = sample_packing_eff_est
239 | LOG.info(f"sample_packing_eff_est: {cfg.sample_packing_eff_est}")
240 | else:
241 | total_num_steps = int(
242 | math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
243 | )
244 | LOG.info(f"total_num_steps: {total_num_steps}")
245 | return total_num_steps
246 |
247 |
248 | def setup_fsdp_envs(cfg):
249 | os.environ["ACCELERATE_USE_FSDP"] = "true"
250 | if cfg.fsdp_config.fsdp_offload_params:
251 | os.environ["FSDP_OFFLOAD_PARAMS"] = "true"
252 | if cfg.fsdp_config.fsdp_sync_module_states:
253 | os.environ["FSDP_SYNC_MODULE_STATES"] = "true"
254 | if cfg.fsdp_config.fsdp_state_dict_type:
255 | os.environ["FSDP_STATE_DICT_TYPE"] = cfg.fsdp_config.fsdp_state_dict_type
256 | if cfg.fsdp_config.fsdp_transformer_layer_cls_to_wrap:
257 | os.environ[
258 | "FSDP_TRANSFORMER_CLS_TO_WRAP"
259 | ] = cfg.fsdp_config.fsdp_transformer_layer_cls_to_wrap
260 |
261 |
262 | def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps):
263 | if cfg.fsdp:
264 | setup_fsdp_envs(cfg)
265 | elif cfg.deepspeed:
266 | os.environ["ACCELERATE_USE_DEEPSPEED"] = "true"
267 |
268 | trainer_builder = HFCausalTrainerBuilder(cfg, model, tokenizer)
269 | trainer_builder.train_dataset = train_dataset
270 | trainer_builder.eval_dataset = eval_dataset
271 |
272 | return trainer_builder.build(total_num_steps)
273 |
--------------------------------------------------------------------------------
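With sample_packing enabled, the intended flow through this module is: process_datasets_for_packing drops over-length rows and attaches position_ids, calculate_total_num_steps sizes the run (settling sample_packing_eff_est across ranks via reduce_and_broadcast), and setup_trainer hands everything to HFCausalTrainerBuilder. The sketch below just strings those calls together; cfg, the datasets, model, and tokenizer are placeholders supplied by the caller.

# Sketch of the packing path through this module; all arguments are placeholders.
from axolotl.utils.trainer import (
    calculate_total_num_steps,
    process_datasets_for_packing,
    setup_trainer,
)

def build_packed_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
    # Filter over-length samples and add position_ids/length columns as configured.
    train_dataset, eval_dataset = process_datasets_for_packing(
        cfg, train_dataset, eval_dataset, tokenizer
    )
    # Estimate total optimizer steps from the packed dataloader (or the efficiency estimate).
    total_num_steps = calculate_total_num_steps(cfg, train_dataset, tokenizer)
    # Build the HF trainer via HFCausalTrainerBuilder.
    return setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps)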
/training/axolotl/src/axolotl/utils/dataloader.py:
--------------------------------------------------------------------------------
1 | # pylint: skip-file
2 | import hashlib
3 | import itertools
4 | import logging
5 | import math
6 | from typing import Any, Callable, List, Union
7 |
8 | import numba
9 | import numpy as np
10 | from torch.utils.data import DistributedSampler, Sampler
11 |
12 | LOG = logging.getLogger("axolotl.utils.dataloader")
13 |
14 |
15 | @numba.njit
16 | def ffd_check(a: np.ndarray, c: int, n: int):
17 | # First-fit-decreasing bin packing
18 | # Check if a[] could fit in n bins with capacity c
19 | # https://en.wikipedia.org/wiki/First-fit-decreasing_bin_packing
20 |
21 | a = np.sort(a)[::-1]
22 | bins = np.full((n,), c, dtype=a.dtype)
23 | for size in a:
24 | not_found = True
25 | for idx in range(n):
26 | if bins[idx] >= size:
27 | bins[idx] -= size
28 | not_found = False
29 | break
30 |
31 | if not_found:
32 | return False
33 |
34 | return True
35 |
36 |
37 | @numba.njit
38 | def ffd_with_result(a: np.ndarray, c: int, start_index: int):
39 | # First-fit-decreasing bin packing (with result return)
40 |
41 | indices = np.argsort(a)[::-1]
42 | a = a[indices]
43 |
44 | bins: List[Any] = []
45 | bins_result: List[Any] = []
46 | for a_id, size in enumerate(a):
47 | add_new = True
48 | for idx in range(len(bins)):
49 | if bins[idx] >= size:
50 | bins[idx] -= size
51 | bins_result[idx].append(indices[a_id] + start_index)
52 | add_new = False
53 | break
54 |
55 | if add_new:
56 | bins.append(c - size)
57 | bins_result.append([indices[a_id] + start_index])
58 |
59 | return bins_result, len(a)
60 |
61 |
62 | @numba.njit
63 | def allocate(
64 | lengths: np.ndarray, lengths_cumsum: np.ndarray, rank: int, c: int, n: int
65 | ):
66 | """
67 | :param lengths: array of lengths of each sample
68 | :param lengths_cumsum: cumulative sum of consecutive lengths
69 | :param rank: rank for this process
70 | :param c: length of tokens per batch
71 | :param n: number of ranks
72 | :return:
73 | """
74 | # Dynamic batch allocator, similar to Multifit
75 | # https://en.wikipedia.org/wiki/Multifit_algorithm
76 | # ~99.5% efficiency on OpenChat training set (12 * 2048 ctx len)
77 |
78 | s = 0
79 | start_index = 0
80 | result = []
81 | result_totseqs = []
82 |
83 | while True:
84 | # binary search [left, right)
85 | left = 1
86 | right = 1 + np.searchsorted(lengths_cumsum[start_index:], s + c * n, "right")
87 |
88 | while right - left > 1:
89 | mid = (left + right) // 2
90 | if ffd_check(lengths[start_index : start_index + mid], c, n):
91 | left = mid
92 | else:
93 | right = mid
94 |
95 | # use length left
96 | batch, tot_seqs = ffd_with_result(
97 | lengths[start_index : start_index + left], c, start_index
98 | )
99 | if len(batch) < n:
100 | break
101 |
102 | start_index += left
103 | s = lengths_cumsum[start_index - 1]
104 |
105 | # add local rank
106 | result.append(batch[rank])
107 | # add total seqs for all ranks
108 | result_totseqs.append(tot_seqs)
109 | # yield batch[rank], tot_seqs, s, len(result) * c * n
110 | return result, result_totseqs, s, len(result) * c * n
111 |
112 |
113 | def chunk(iterable, n):
114 | """
115 | Chunk data into tuples of length n
116 | """
117 | # batched('ABCDEFG', 3) --> ABC DEF G
118 | if n < 1:
119 | raise ValueError("n must be at least one")
120 | it = iter(iterable)
121 | while batch := tuple(itertools.islice(it, n)):
122 | yield batch
123 |
124 |
125 | def hash_indices(lst: List[int]) -> str:
126 | # Convert the list of integers to a string representation
127 | concatenated = ",".join(map(str, lst))
128 |
129 | # Generate the hash
130 | sha256 = hashlib.sha256()
131 | sha256.update(concatenated.encode())
132 |
133 | return sha256.hexdigest()
134 |
135 |
136 | class MultipackDistributedDataloader:
137 | """Unpadded data loading using Multipack.
138 | Adapted from https://github.com/imoneoi/openchat/blob/v3_fix_mle_loss/ochat/training_deepspeed/multipack_dataloader.py
139 | Approximate (at most ~1.22x) the optimal solution of the identical-machines scheduling problem, which is NP-hard.
140 | """
141 |
142 | def __init__(
143 | self,
144 | dataset: Any,
145 | collate_fn: Callable,
146 | seq_max_length: int = 2048,
147 | batch_size: int = 1,
148 | sampler: Union[Sampler, DistributedSampler] = None,
149 | packing_efficiency_estimate: float = 1.0,
150 | sample_packing_seq_len_multiplier: int = 1,
151 | device_count: int = 1,
152 | ):
153 | # Dataset
154 | self.dataset = dataset
155 | self.lengths = (
156 | dataset.data.column("position_ids")
157 | .to_pandas()
158 | .apply(lambda x: x[-1] + 1)
159 | .values
160 | )
161 | assert isinstance(self.lengths, np.ndarray)
162 | assert batch_size % sample_packing_seq_len_multiplier == 0
163 | assert batch_size >= sample_packing_seq_len_multiplier
164 | self.sampler = sampler
165 | self.batch_size = batch_size
166 | self.sample_packing_seq_len_multiplier = sample_packing_seq_len_multiplier
167 | self.seq_max_length = seq_max_length
168 | self.batch_max_length = batch_size * seq_max_length
169 | self.collate_fn = collate_fn
170 |
171 | self.num_replicas = 1
172 | self.rank = 0
173 |
174 | # statistics
175 | self.eff_total_used = 0
176 | self.eff_total_slots = 0
177 | self.packing_efficiency_estimate = packing_efficiency_estimate or 1.0
178 | self.device_count = device_count
179 |
180 | def generate_batches(self, set_stats=False):
181 | LOG.info("generating packed batches")
182 | if self.sampler:
183 | indices = [idx for idx in self.sampler]
184 | else:
185 | indices = range(0, len(self.dataset))
186 |
187 | LOG.info(hash_indices(indices))
188 | lengths = self.lengths[indices]
189 | lengths_cumsum = np.cumsum(lengths)
190 |
191 | batches, totseqs, total_used, total_slots = allocate(
192 | lengths=lengths,
193 | lengths_cumsum=lengths_cumsum,
194 | rank=self.rank,
195 | # c=self.batch_max_length,
196 | c=self.seq_max_length * self.sample_packing_seq_len_multiplier,
197 | n=self.num_replicas,
198 | )
199 |
200 | batches = [[indices[b_idx] for b_idx in batch] for batch in batches]
201 |
202 | # statistics
203 | if set_stats:
204 | self.eff_total_used += total_used
205 | self.eff_total_slots += total_slots
206 |
207 | return batches, totseqs
208 |
209 | def __iter__(self):
210 | if hasattr(self.sampler, "set_epoch"):
211 | new_epoch = self.sampler.epoch + 1
212 | self.sampler.set_epoch(new_epoch)
213 | LOG.info(f"calling sampler.set_epoch({new_epoch})")
214 | all_batches, _ = self.generate_batches(set_stats=True)
215 | features = self.dataset.features.keys()
216 | len_remaining = self._len_est()
217 | for batches in chunk(
218 | all_batches, self.batch_size // self.sample_packing_seq_len_multiplier
219 | ):
220 | chunked_data = []
221 | attn_mask_cum_idx = 0
222 | for batch in batches:
223 | concatenated = {}
224 | batched_data = [self.dataset[batch_idx] for batch_idx in batch]
225 | for feature in features:
226 | if feature == "length":
227 | continue
228 | if feature == "attention_mask":
229 | arrays = [
230 | (attn_mask_cum_idx + idx + 1) * np.array(item[feature])
231 | for idx, item in enumerate(batched_data)
232 | if feature in item
233 | ]
234 | attn_mask_cum_idx += len(batched_data)
235 | concatenated[feature] = np.concatenate(arrays)
236 | else:
237 | arrays = [
238 | np.array(item[feature])
239 | for item in batched_data
240 | if feature in item
241 | ]
242 | concatenated[feature] = np.concatenate(arrays)
243 | chunked_data.append(concatenated)
244 | yield self.collate_fn(chunked_data)
245 | len_remaining -= 1
246 | if not len_remaining:
247 | return
248 | # yield a no-op for cases where we don't have any data left to pack
249 | for i in range(0, len_remaining):
250 | yield self.collate_fn(
251 | [
252 | {
253 | "input_ids": [0],
254 | "labels": [-100],
255 | "attention_mask": [True],
256 | "position_ids": [0],
257 | }
258 | ]
259 | )
260 |
261 | def _len_est(self):
262 | lengths_sum = np.sum(self.lengths)
263 | lengths_sum_per_device = lengths_sum // self.device_count
264 | LOG.info(
265 | f"packing_efficiency_estimate: {self.packing_efficiency_estimate} "
266 | f"total_num_tokens per device: {lengths_sum_per_device}"
267 | )
268 |
269 | # shave off 1% + 1 for dealing with variance in packing from random sampler to sampler
270 | return (
271 | math.floor(
272 | 0.99
273 | * lengths_sum_per_device
274 | / self.packing_efficiency_estimate
275 | // self.seq_max_length
276 | // self.batch_size
277 | )
278 | - 1
279 | )
280 |
281 | def __len__(self):
282 | # this doesn't return the actual length b/c with distributed samplers, not all dataloaders get
283 | # the same share of total tokens
284 | # if not self.eff_total_used:
285 | # batches, _ = self.generate_batches(set_stats=True)
286 | # LOG.info(
287 | # f"packing_efficiency_estimate: {self.packing_efficiency_estimate} "
288 | # f"actual packing efficiency: {self.efficiency()}"
289 | # )
290 | return max(1, self._len_est())
291 |
292 | def len_w_stats(self):
293 | if not self.eff_total_used:
294 | batches, _ = self.generate_batches(set_stats=True)
295 | LOG.info(
296 | f"packing_efficiency_estimate: {self.packing_efficiency_estimate} "
297 | f"actual packing efficiency: {self.efficiency()}"
298 | )
299 | return max(1, self._len_est())
300 |
301 | def efficiency(self):
302 | return self.eff_total_used / self.eff_total_slots
303 |
--------------------------------------------------------------------------------
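Because allocate() only needs per-sample lengths, the packing behaviour can be exercised without any dataset or tokenizer. The sketch below feeds synthetic lengths straight into it; the length values and the 2048-token budget are made up for illustration.

# Sketch: run the FFD/Multifit allocator on synthetic lengths (single rank, no dataset).
import numpy as np

from axolotl.utils.dataloader import allocate

lengths = np.array([512, 1024, 256, 2048, 768, 1280, 640, 1536], dtype=np.int64)
lengths_cumsum = np.cumsum(lengths)

batches, totseqs, total_used, total_slots = allocate(
    lengths=lengths, lengths_cumsum=lengths_cumsum, rank=0, c=2048, n=1
)
print(batches)                   # sample indices grouped so each pack fits in 2048 tokens
print(total_used / total_slots)  # achieved packing efficiency for these lengths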