├── .gitignore ├── Dockerfile ├── Finetune4bConfig.py ├── LICENSE ├── README.md ├── alpaca_lora_4bit_penguin_fact.gif ├── amp_wrapper.py ├── arg_parser.py ├── autograd_4bit.py ├── custom_autotune.py ├── data.txt ├── finetune.py ├── gradient_checkpointing.py ├── inference.py ├── matmul_utils_4bit.py ├── monkeypatch └── llama_flash_attn_monkey_patch.py ├── requirements.txt ├── text-generation-webui └── custom_monkey_patch.py ├── train_data.py └── triton_utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | alpaca_lora/ 2 | repository/ 3 | __pycache__/ 4 | llama-13b-4bit 5 | llama-13b-4bit.pt 6 | text-generation-webui/ 7 | repository/ 8 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | # syntax = docker/dockerfile:experimental 2 | 3 | # Dockerfile is split into parts because we want to cache building the requirements and downloading the model, both of which can take a long time. 4 | 5 | FROM nvidia/cuda:11.7.0-devel-ubuntu22.04 AS builder 6 | 7 | RUN apt-get update && apt-get install -y python3 python3-pip git 8 | 9 | RUN pip3 install --upgrade pip 10 | 11 | # Some of the requirements expect some python packages in their setup.py, just install them first. 12 | RUN --mount=type=cache,target=/root/.cache/pip pip install --user torch==2.0.0 13 | RUN --mount=type=cache,target=/root/.cache/pip pip install --user semantic-version==2.10.0 requests tqdm 14 | 15 | # The docker build environment has trouble detecting CUDA version, build for all reasonable archs 16 | ENV TORCH_CUDA_ARCH_LIST="6.0 6.1 7.0 7.5 8.0 8.6" 17 | COPY requirements.txt requirements.txt 18 | RUN --mount=type=cache,target=/root/.cache pip install --user -r requirements.txt 19 | 20 | # ------------------------------- 21 | 22 | # Download the model 23 | FROM nvidia/cuda:11.7.0-devel-ubuntu22.04 AS downloader 24 | RUN apt-get update && apt-get install -y wget 25 | 26 | RUN wget --progress=bar:force:noscroll https://huggingface.co/decapoda-research/llama-7b-hf-int4/resolve/main/llama-7b-4bit.pt 27 | 28 | 29 | 30 | # ------------------------------- 31 | 32 | #FROM pytorch/pytorch:2.0.0-cuda11.7-cudnn8-devel 33 | FROM nvidia/cuda:11.7.0-devel-ubuntu22.04 34 | 35 | RUN --mount=type=cache,target=/var/cache/apt apt-get update && apt-get install -y git python3 python3-pip 36 | 37 | RUN ln -s `which python3` /usr/bin/python 38 | 39 | 40 | # Copy the installed packages from the first stage 41 | COPY --from=builder /root/.local /root/.local 42 | 43 | RUN mkdir alpaca_lora_4bit 44 | WORKDIR alpaca_lora_4bit 45 | 46 | COPY --from=downloader llama-7b-4bit.pt llama-7b-4bit.pt 47 | 48 | #RUN git clone --depth=1 --branch main https://github.com/andybarry/text-generation-webui-4bit.git text-generation-webui-tmp 49 | 50 | RUN git clone --depth=1 --branch main https://github.com/oobabooga/text-generation-webui.git text-generation-webui-tmp 51 | 52 | RUN --mount=type=cache,target=/root/.cache pip install --user markdown gradio 53 | 54 | # Apply monkey patch 55 | RUN cd text-generation-webui-tmp && printf '%s'"import custom_monkey_patch # apply monkey patch\nimport gc\n\n" | cat - server.py > tmpfile && mv tmpfile server.py 56 | 57 | # Get the model config 58 | RUN cd text-generation-webui-tmp && python download-model.py --text-only decapoda-research/llama-7b-hf && mv models/decapoda-research_llama-7b-hf ../llama-7b-4bit 59 | 60 | 61 | # Get LoRA 62 | RUN cd 
text-generation-webui-tmp && python download-model.py samwit/alpaca7b-lora && mv loras/samwit_alpaca7b-lora ../alpaca7b_lora 63 | 64 | COPY *.py . 65 | COPY text-generation-webui text-generation-webui 66 | COPY monkeypatch . 67 | 68 | RUN mv -f text-generation-webui-tmp/* text-generation-webui/ 69 | 70 | # Symlink for monkeypatch 71 | RUN cd text-generation-webui && ln -s ../autograd_4bit.py ./autograd_4bit.py && ln -s ../matmul_utils_4bit.py . 72 | 73 | # Swap to the 7bn parameter model 74 | RUN sed -i 's/llama-13b-4bit/llama-7b-4bit/g' text-generation-webui/custom_monkey_patch.py && sed -i 's/alpaca13b_lora/alpaca7b_lora/g' text-generation-webui/custom_monkey_patch.py 75 | 76 | # Run the server 77 | WORKDIR /alpaca_lora_4bit/text-generation-webui 78 | CMD ["python", "-u", "server.py", "--listen", "--chat"] -------------------------------------------------------------------------------- /Finetune4bConfig.py: -------------------------------------------------------------------------------- 1 | import os 2 | class Finetune4bConfig: 3 | """Config holder for LLaMA 4bit finetuning 4 | """ 5 | def __init__(self, dataset: str, ds_type: str, 6 | lora_out_dir: str, lora_apply_dir: str, resume_checkpoint: str, 7 | llama_q4_config_dir: str, llama_q4_model: str, 8 | mbatch_size: int, batch_size: int, 9 | epochs: int, lr: float, 10 | cutoff_len: int, 11 | lora_r: int, lora_alpha: int, lora_dropout: float, 12 | val_set_size: float, 13 | gradient_checkpointing: bool, 14 | gradient_checkpointing_ratio: float, 15 | warmup_steps: int, save_steps: int, save_total_limit: int, logging_steps: int, 16 | checkpoint: bool, skip: bool, verbose: bool, 17 | txt_row_thd: int, use_eos_token: bool, groupsize: int, 18 | local_rank: int, flash_attention: bool, backend: str 19 | ): 20 | """ 21 | Args: 22 | dataset (str): Path to dataset file 23 | ds_type (str): Dataset structure format 24 | lora_out_dir (str): Directory to place new LoRA 25 | lora_apply_dir (str): Path to directory from which LoRA has to be applied before training 26 | resume_checkpoint (str): Path to Specified checkpoint you want to resume. 
27 | llama_q4_config_dir (str): Path to the config.json, tokenizer_config.json, etc 28 | llama_q4_model (str): Path to the quantized model in huggingface format 29 | mbatch_size (int): Micro-batch size 30 | batch_size (int): Batch size 31 | epochs (int): Epochs 32 | lr (float): Learning rate 33 | cutoff_len (int): Cutoff length 34 | lora_r (int): LoRA R 35 | lora_alpha (int): LoRA Alpha 36 | lora_dropout (float): LoRA Dropout 37 | gradient_checkpointing (bool) : Use gradient checkpointing 38 | gradient_checkpointing_ratio (float) : Gradient checkpoint ratio 39 | val_set_size (int): Validation set size 40 | warmup_steps (int): Warmup steps before training 41 | save_steps (int): Save steps 42 | save_total_limit (int): Save total limit 43 | logging_steps (int): Logging steps 44 | checkpoint (bool): Produce checkpoint instead of LoRA 45 | skip (bool): Don't train model 46 | verbose (bool): If output log of training 47 | txt_row_thd (int): Custom row thd for txt file 48 | use_eos_token (bool): Use Eos token instead of padding with 0 49 | groupsize (int): Group size of V2 model, use -1 to load V1 model 50 | local_rank (int): local rank if using torch.distributed.launch 51 | flash_attention (bool): Enables flash attention 52 | """ 53 | self.dataset = dataset 54 | self.ds_type = ds_type 55 | self.lora_out_dir = lora_out_dir 56 | self.lora_apply_dir = lora_apply_dir 57 | self.resume_checkpoint = resume_checkpoint 58 | self.llama_q4_config_dir = llama_q4_config_dir 59 | self.llama_q4_model = llama_q4_model 60 | self.mbatch_size = mbatch_size 61 | self.batch_size = batch_size 62 | self.gradient_accumulation_steps = self.batch_size // self.mbatch_size 63 | self.epochs = epochs 64 | self.lr = lr 65 | self.cutoff_len = cutoff_len 66 | self.lora_r = lora_r 67 | self.lora_alpha = lora_alpha 68 | self.lora_dropout = 0 if gradient_checkpointing else lora_dropout # should be 0 if gradient checkpointing is on 69 | self.val_set_size = int(val_set_size) if val_set_size > 1.0 else float(val_set_size) 70 | self.gradient_checkpointing = gradient_checkpointing 71 | self.gradient_checkpointing_ratio = gradient_checkpointing_ratio 72 | self.warmup_steps = warmup_steps 73 | self.save_steps = save_steps 74 | self.save_total_limit = save_total_limit 75 | self.logging_steps = logging_steps 76 | self.checkpoint = checkpoint 77 | self.skip = skip 78 | self.verbose = verbose 79 | self.txt_row_thd = txt_row_thd 80 | self.use_eos_token = use_eos_token 81 | self.world_size = int(os.environ.get("WORLD_SIZE", 1)) 82 | self.local_rank = int(os.environ.get("LOCAL_RANK", local_rank)) 83 | self.ddp = self.world_size != 1 84 | self.device_map = "auto" if not self.ddp else {"": self.local_rank} 85 | if self.ddp: 86 | self.gradient_accumulation_steps = self.gradient_accumulation_steps // self.world_size 87 | self.groupsize = groupsize 88 | self.flash_attention = flash_attention 89 | self.backend = backend 90 | 91 | 92 | def __str__(self) -> str: 93 | s = f"\nParameters:\n{'config':-^20}\n{self.dataset=}\n{self.ds_type=}\n{self.lora_out_dir=}\n{self.lora_apply_dir=}\n{self.llama_q4_config_dir=}\n{self.llama_q4_model=}\n\n" +\ 94 | f"{'training':-^20}\n" +\ 95 | f"{self.mbatch_size=}\n{self.batch_size=}\n{self.gradient_accumulation_steps=}\n{self.epochs=}\n{self.lr=}\n{self.cutoff_len=}\n" +\ 96 | f"{self.lora_r=}\n{self.lora_alpha=}\n{self.lora_dropout=}\n{self.val_set_size=}\n" +\ 97 | f"{self.gradient_checkpointing=}\n{self.gradient_checkpointing_ratio=}\n" +\ 98 | 
f"{self.warmup_steps=}\n{self.save_steps=}\n{self.save_total_limit=}\n" +\ 99 | f"{self.logging_steps=}\n" +\ 100 | f"{self.checkpoint=}\n{self.skip=}\n" +\ 101 | f"{self.world_size=}\n{self.ddp=}\n{self.device_map=}\n" +\ 102 | f"{self.groupsize=}\n{self.backend=}\n" 103 | return s.replace("self.", "") 104 |
-------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 John Smith 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 |
-------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Alpaca Lora 4bit 2 | This repository adjusts the code in peft and GPTQ-for-LLaMa to make LoRA finetuning possible with a 4-bit base model. The same adjustment can be made for 2, 3 and 8 bits. 3 | 4 | ## Quick start for running the chat UI 5 | 6 | ``` 7 | git clone https://github.com/andybarry/alpaca_lora_4bit_docker.git 8 | DOCKER_BUILDKIT=1 docker build -t alpaca_lora_4bit . # build step can take 12 min 9 | docker run --gpus=all -p 7860:7860 alpaca_lora_4bit 10 | ``` 11 | Point your browser to http://localhost:7860 12 | 13 | ## Results 14 | Inference is fast on a mobile RTX 3070 Ti and uses 5-6 GB of GPU RAM. 15 | 16 | ![](alpaca_lora_4bit_penguin_fact.gif) 17 | 18 | # Development 19 | * Install Manual by s4rduk4r: https://github.com/s4rduk4r/alpaca_lora_4bit_readme/blob/main/README.md (**NOTE:** don't use the install script; use requirements.txt instead.) 20 | * Remember to create a venv if you do not want your existing packages to be overwritten. 21 | 22 | # Update Logs 23 | * Resolved a numerical instability issue. 24 | * Reconstructing the fp16 matrix from 4-bit data and calling torch.matmul greatly increased inference speed. 25 | * Added install scripts for Windows and Linux. 26 | * Added gradient checkpointing. A 30B model can now be finetuned in 4-bit on a single GPU with 24 GB of VRAM when gradient checkpointing is enabled (finetune.py updated); note that it reduces training speed, so the option is not needed if you have enough VRAM. 27 | * Added install manual by s4rduk4r. 28 | * Added pip install support by sterlind, preparing to merge changes upstream. 29 | * Added V2 model support (with groupsize, both inference + finetune). 30 | * Added finetune options: use the EOS token instead of padding by default, and resume_checkpoint to continue training. 31 | * Added offload support; the load_llama_model_4bit_low_ram_and_offload_to_cpu function can be used. 32 | * Added a monkey patch for the text generation webui to fix the initial EOS token issue. 33 | * Added Flash Attention support. (Use --flash_attention) 34 | * Added a Triton backend to support models using groupsize and act-order. (Use --backend=triton) 35 | 36 | # Requirements 37 | gptq-for-llama
38 | peft
39 | The specific versions are pinned in requirements.txt.
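After running the install step below, a quick way to confirm that the patched peft fork is active is to run the same check that finetune.py performs internally (a minimal sketch; it assumes requirements.txt has already been installed):

```
python -c "import peft.tuners.lora; print(peft.tuners.lora.is_gptq_available())"
```

If this prints True, the GPTQ-aware LoRA code path is available.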
40 | 41 | # Install 42 | ~copy files from GPTQ-for-LLaMa into GPTQ-for-LLaMa path and re-compile cuda extension~
43 | ~copy files from peft/tuners/lora.py to peft path, replace it~
44 | 45 | **NOTE:** Install scripts are no longer needed! requirements.txt now pulls from forks with the necessary patches. 46 | 47 | ``` 48 | pip install -r requirements.txt 49 | ``` 50 | 51 | # Finetune 52 | ~The same finetune script from https://github.com/tloen/alpaca-lora can be used.~
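All of finetune.py's command-line options are defined in arg_parser.py. As a sketch of a fuller invocation (the paths are placeholders for your own quantized model and dataset, and the numeric values shown are simply the parser defaults):

```
python finetune.py ./dataset.json \
    --ds_type=alpaca \
    --lora_out_dir=./alpaca_lora \
    --llama_q4_config_dir=./llama-13b-4bit/ \
    --llama_q4_model=./llama-13b-4bit.pt \
    --mbatch_size=1 --batch_size=2 --epochs=3 --lr=2e-4 \
    --cutoff_len=256 --lora_r=8 --lora_alpha=16 --lora_dropout=0.05 \
    --val_set_size=0.2 --warmup_steps=50 --save_steps=50 \
    --grad_chckpt --groupsize=-1 --backend=cuda
```

--grad_chckpt (gradient checkpointing, intended for 30B models on 24 GB of VRAM) is optional and trades training speed for memory; omit it if you have enough VRAM.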
53 | 54 | After installation, the script can also be run with just the defaults: 55 | GPTQv1: 56 | 57 | ``` 58 | python finetune.py 59 | ``` 60 | or 61 | ``` 62 | GPTQ_VERSION=1 python finetune.py 63 | ``` 64 | 65 | GPTQv2: 66 | ``` 67 | GPTQ_VERSION=2 python finetune.py 68 | ``` 69 | 70 | # Inference 71 | 72 | After installation, this script can be used: 73 | 74 | ``` 75 | python inference.py 76 | ``` 77 | 78 | # Text Generation Webui Monkey Patch 79 | 80 | Clone the latest version of text generation webui and copy all the files from this repository's text-generation-webui/ folder into ./text-generation-webui/ 81 | ``` 82 | git clone https://github.com/oobabooga/text-generation-webui.git 83 | ``` 84 | 85 | Open server.py and insert a line at the beginning so that the file starts like this 86 | ``` 87 | import custom_monkey_patch # apply monkey patch 88 | import gc 89 | import io 90 | ... 91 | ``` 92 | 93 | Then use this command to run it 94 | 95 | ``` 96 | python server.py 97 | ``` 98 | 99 | # Flash Attention 100 | 101 | A monkey patch can be applied to the llama model to enable flash attention. To use it, simply download the file from [MonkeyPatch](https://github.com/lm-sys/FastChat/blob/daa9c11080ceced2bd52c3e0027e4f64b1512683/fastchat/train/llama_flash_attn_monkey_patch.py). The flash-attention package is also required, and it currently does not support PyTorch 2.0. 102 | Just add --flash_attention to use it for finetuning. 103 |
-------------------------------------------------------------------------------- /alpaca_lora_4bit_penguin_fact.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/andybarry/alpaca_lora_4bit_docker/a93cf1264a1dcb994cd6b4053187b74b25de7b50/alpaca_lora_4bit_penguin_fact.gif
-------------------------------------------------------------------------------- /amp_wrapper.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class AMPWrapper: 5 | 6 | def __init__(self, model, options=None): 7 | self.model = model 8 | self.options = options 9 | if self.options is None: 10 | self.options = {'enabled': True, 'device_type': 'cuda'} 11 | 12 | def autocast_forward(self, *args, **kwargs): 13 | with torch.amp.autocast(**self.options): 14 | return self.model.non_autocast_forward(*args, **kwargs) 15 | 16 | def autocast_generate(self, *args, **kwargs): 17 | with torch.amp.autocast(**self.options): 18 | return self.model.non_autocast_generate(*args, **kwargs) 19 | 20 | def apply_forward(self): 21 | self.model.non_autocast_forward = self.model.forward 22 | self.model.forward = self.autocast_forward 23 | 24 | def apply_generate(self): 25 | self.model.non_autocast_generate = self.model.generate 26 | self.model.generate = self.autocast_generate 27 |
-------------------------------------------------------------------------------- /arg_parser.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | from Finetune4bConfig import Finetune4bConfig 4 | 5 | def parse_commandline(): 6 | parser = argparse.ArgumentParser( 7 | prog=__file__.split(os.path.sep)[-1], 8 | description="Produce LoRA in 4bit training", 9 | usage="%(prog)s [config] [training]\n\nAll arguments are optional" 10 | ) 11 | 12 | parser.add_argument("dataset", nargs="?", 13 | default="./dataset.json", 14 | help="Path to dataset file. 
Default: %(default)s" 15 | ) 16 | 17 | parser_config = parser.add_argument_group("config") 18 | parser_training = parser.add_argument_group("training") 19 | 20 | # Config args group 21 | parser_config.add_argument("--ds_type", choices=["txt", "alpaca", "gpt4all"], default="alpaca", required=False, 22 | help="Dataset structure format. Default: %(default)s" 23 | ) 24 | parser_config.add_argument("--lora_out_dir", default="alpaca_lora", required=False, 25 | help="Directory to place new LoRA. Default: %(default)s" 26 | ) 27 | parser_config.add_argument("--lora_apply_dir", default=None, required=False, 28 | help="Path to directory from which LoRA has to be applied before training. Default: %(default)s" 29 | ) 30 | parser_training.add_argument("--resume_checkpoint", default=None, required=False, 31 | help="Resume training from specified checkpoint. Default: %(default)s" 32 | ) 33 | parser_config.add_argument("--llama_q4_config_dir", default="./llama-13b-4bit/", required=False, 34 | help="Path to the config.json, tokenizer_config.json, etc. Default: %(default)s" 35 | ) 36 | parser_config.add_argument("--llama_q4_model", default="./llama-13b-4bit.pt", required=False, 37 | help="Path to the quantized model in huggingface format. Default: %(default)s" 38 | ) 39 | 40 | # Training args group 41 | parser_training.add_argument("--mbatch_size", default=1, type=int, help="Micro-batch size. Default: %(default)s") 42 | parser_training.add_argument("--batch_size", default=2, type=int, help="Batch size. Default: %(default)s") 43 | parser_training.add_argument("--epochs", default=3, type=int, help="Epochs. Default: %(default)s") 44 | parser_training.add_argument("--lr", default=2e-4, type=float, help="Learning rate. Default: %(default)s") 45 | parser_training.add_argument("--cutoff_len", default=256, type=int, help="Default: %(default)s") 46 | parser_training.add_argument("--lora_r", default=8, type=int, help="Default: %(default)s") 47 | parser_training.add_argument("--lora_alpha", default=16, type=int, help="Default: %(default)s") 48 | parser_training.add_argument("--lora_dropout", default=0.05, type=float, help="Default: %(default)s") 49 | parser_training.add_argument("--grad_chckpt", action="store_true", required=False, help="Use gradient checkpoint. For 30B model. Default: %(default)s") 50 | parser_training.add_argument("--grad_chckpt_ratio", default=1, type=float, help="Gradient checkpoint ratio. Default: %(default)s") 51 | parser_training.add_argument("--val_set_size", default=0.2, type=float, help="Validation set size. Default: %(default)s") 52 | parser_training.add_argument("--warmup_steps", default=50, type=int, help="Default: %(default)s") 53 | parser_training.add_argument("--save_steps", default=50, type=int, help="Default: %(default)s") 54 | parser_training.add_argument("--save_total_limit", default=3, type=int, help="Default: %(default)s") 55 | parser_training.add_argument("--logging_steps", default=10, type=int, help="Default: %(default)s") 56 | parser_training.add_argument("-c", "--checkpoint", action="store_true", help="Produce checkpoint instead of LoRA. Default: %(default)s") 57 | parser_training.add_argument("--skip", action="store_true", help="Don't train model. Can be useful to produce checkpoint from existing LoRA. Default: %(default)s") 58 | parser_training.add_argument("--verbose", action="store_true", help="If output log of training. 
Default: %(default)s") 59 | 60 | # Data args 61 | parser_training.add_argument("--txt_row_thd", default=-1, type=int, help="Custom thd for txt rows.") 62 | parser_training.add_argument("--use_eos_token", default=1, type=int, help="Use eos token instead if padding with 0. enable with 1, disable with 0.") 63 | 64 | # V2 model support 65 | parser_training.add_argument("--groupsize", type=int, default=-1, help="Groupsize of v2 model, use -1 to load v1 model") 66 | 67 | # Multi GPU Support 68 | parser_training.add_argument("--local_rank", type=int, default=0, help="local rank if using torch.distributed.launch") 69 | 70 | # Flash Attention 71 | parser_training.add_argument("--flash_attention", action="store_true", help="enables flash attention, can improve performance and reduce VRAM use") 72 | 73 | # Train Backend 74 | parser_training.add_argument("--backend", type=str, default='cuda', help="Backend to use. Triton or Cuda.") 75 | 76 | return vars(parser.parse_args()) 77 | 78 | 79 | def get_config() -> Finetune4bConfig: 80 | args = parse_commandline() 81 | return Finetune4bConfig( 82 | dataset=args["dataset"], 83 | ds_type=args["ds_type"], 84 | lora_out_dir=args["lora_out_dir"], 85 | lora_apply_dir=args["lora_apply_dir"], 86 | resume_checkpoint=args["resume_checkpoint"], 87 | llama_q4_config_dir=args["llama_q4_config_dir"], 88 | llama_q4_model=args["llama_q4_model"], 89 | mbatch_size=args["mbatch_size"], 90 | batch_size=args["batch_size"], 91 | epochs=args["epochs"], 92 | lr=args["lr"], 93 | cutoff_len=args["cutoff_len"], 94 | lora_r=args["lora_r"], 95 | lora_alpha=args["lora_alpha"], 96 | lora_dropout=args["lora_dropout"], 97 | val_set_size=args["val_set_size"], 98 | gradient_checkpointing=args["grad_chckpt"], 99 | gradient_checkpointing_ratio=args["grad_chckpt_ratio"], 100 | warmup_steps=args["warmup_steps"], 101 | save_steps=args["save_steps"], 102 | save_total_limit=args["save_total_limit"], 103 | logging_steps=args["logging_steps"], 104 | checkpoint=args["checkpoint"], 105 | skip=args["skip"], 106 | verbose=args["verbose"], 107 | txt_row_thd=args["txt_row_thd"], 108 | use_eos_token=args["use_eos_token"]!=0, 109 | groupsize=args["groupsize"], 110 | local_rank=args["local_rank"], 111 | flash_attention=args["flash_attention"], 112 | backend=args["backend"], 113 | ) 114 | -------------------------------------------------------------------------------- /autograd_4bit.py: -------------------------------------------------------------------------------- 1 | import matmul_utils_4bit as mm4b 2 | import torch 3 | import torch.nn as nn 4 | import time 5 | import math 6 | from torch.cuda.amp import custom_bwd, custom_fwd 7 | from colorama import init, Fore, Back, Style 8 | init(autoreset=True) 9 | 10 | 11 | class AutogradMatmul4bitCuda(torch.autograd.Function): 12 | 13 | @staticmethod 14 | @custom_fwd(cast_inputs=torch.float16) 15 | def forward(ctx, x, qweight, scales, zeros, g_idx, bits, maxq, groupsize=-1): 16 | ctx.save_for_backward(qweight, scales, zeros) 17 | ctx.groupsize = groupsize 18 | if groupsize == -1: 19 | output = mm4b._matmul4bit_v1_recons(x, qweight, scales, zeros) 20 | else: 21 | output = mm4b._matmul4bit_v2_recons(x, qweight, scales, zeros, groupsize) 22 | output = output.clone() 23 | return output 24 | 25 | @staticmethod 26 | @custom_bwd 27 | def backward(ctx, grad_output): 28 | qweight, scales, zeros = ctx.saved_tensors 29 | groupsize = ctx.groupsize 30 | if ctx.needs_input_grad[0]: 31 | if groupsize == -1: 32 | grad = mm4b._matmul4bit_v1_recons(grad_output, qweight, scales, zeros, 
transpose=True) 33 | else: 34 | grad = mm4b._matmul4bit_v2_recons(grad_output, qweight, scales, zeros, groupsize=groupsize, transpose=True) 35 | return grad, None, None, None, None, None, None, None 36 | 37 | 38 | try: 39 | import triton_utils as tu 40 | 41 | class AutogradMatmul4bitTriton(torch.autograd.Function): 42 | 43 | @staticmethod 44 | @custom_fwd(cast_inputs=torch.float16) 45 | def forward(ctx, x, qweight, scales, qzeros, g_idx, bits, maxq, groupsize=-1): 46 | output = tu.triton_matmul(x, qweight, scales, qzeros, g_idx, bits, maxq) 47 | ctx.save_for_backward(qweight, scales, qzeros, g_idx) 48 | ctx.bits, ctx.maxq = bits, maxq 49 | output = output.clone() 50 | return output 51 | 52 | @staticmethod 53 | @custom_bwd 54 | def backward(ctx, grad_output): 55 | qweight, scales, qzeros, g_idx = ctx.saved_tensors 56 | bits, maxq = ctx.bits, ctx.maxq 57 | grad_input = None 58 | 59 | if ctx.needs_input_grad[0]: 60 | grad_input = tu.triton_matmul_transpose(grad_output, qweight, scales, qzeros, g_idx, bits, maxq) 61 | return grad_input, None, None, None, None, None, None, None 62 | 63 | except ImportError: 64 | print('Triton not found. Please run "pip install triton".') 65 | 66 | 67 | AutogradMatmul4bit = AutogradMatmul4bitCuda 68 | backend = 'cuda' 69 | 70 | 71 | def switch_backend_to(to_backend): 72 | global AutogradMatmul4bit 73 | global backend 74 | if to_backend == 'cuda': 75 | AutogradMatmul4bit = AutogradMatmul4bitCuda 76 | backend = 'cuda' 77 | print(Style.BRIGHT + Fore.GREEN + 'Using CUDA implementation.') 78 | elif to_backend == 'triton': 79 | # detect if AutogradMatmul4bitTriton is defined 80 | if 'AutogradMatmul4bitTriton' not in globals(): 81 | raise ValueError('Triton not found. Please install triton_utils.') 82 | AutogradMatmul4bit = AutogradMatmul4bitTriton 83 | backend = 'triton' 84 | print(Style.BRIGHT + Fore.GREEN + 'Using Triton implementation.') 85 | else: 86 | raise ValueError('Backend not supported.') 87 | 88 | 89 | def matmul4bit_with_backend(x, qweight, scales, qzeros, g_idx, bits, maxq, groupsize): 90 | if backend == 'cuda': 91 | return mm4b.matmul4bit(x, qweight, scales, qzeros, groupsize) 92 | elif backend == 'triton': 93 | assert qzeros.dtype == torch.int32 94 | return tu.triton_matmul(x, qweight, scales, qzeros, g_idx, bits, maxq) 95 | else: 96 | raise ValueError('Backend not supported.') 97 | 98 | 99 | # Assumes layer is perfectly divisible into 256 * 256 blocks 100 | class Autograd4bitQuantLinear(nn.Module): 101 | 102 | def __init__(self, in_features, out_features, groupsize=-1): 103 | super().__init__() 104 | bits = 4 105 | self.in_features = in_features 106 | self.out_features = out_features 107 | self.bits = bits 108 | self.maxq = 2 ** self.bits - 1 109 | self.groupsize = groupsize 110 | self.g_idx = 0 111 | if groupsize == -1: 112 | self.register_buffer('zeros', torch.empty((out_features, 1))) 113 | self.register_buffer('scales', torch.empty((out_features, 1))) 114 | else: 115 | self.register_buffer('qzeros', 116 | torch.empty((math.ceil(in_features/groupsize), out_features // 256 * (bits * 8)), dtype=torch.int32) 117 | ) 118 | self.register_buffer('scales', torch.empty((math.ceil(in_features/groupsize), out_features))) 119 | self.register_buffer('g_idx', torch.tensor([i // self.groupsize for i in range(in_features)], dtype = torch.int32)) 120 | self.register_buffer('bias', torch.empty(out_features)) 121 | self.register_buffer( 122 | 'qweight', torch.empty((in_features // 256 * (bits * 8), out_features), dtype=torch.int32) 123 | ) 124 | 125 | 126 | def 
forward(self, x): 127 | if torch.is_grad_enabled(): 128 | out = AutogradMatmul4bit.apply(x, self.qweight, self.scales, 129 | self.qzeros if self.groupsize != -1 else self.zeros, 130 | self.g_idx, self.bits, self.maxq, 131 | self.groupsize) 132 | else: 133 | out = matmul4bit_with_backend(x, self.qweight, self.scales, 134 | self.qzeros if self.groupsize != -1 else self.zeros, 135 | self.g_idx, self.bits, self.maxq, 136 | self.groupsize) 137 | out += self.bias 138 | return out 139 | 140 | 141 | def make_quant_for_4bit_autograd(module, names, name='', groupsize=-1): 142 | if isinstance(module, Autograd4bitQuantLinear): 143 | return 144 | for attr in dir(module): 145 | tmp = getattr(module, attr) 146 | name1 = name + '.' + attr if name != '' else attr 147 | if name1 in names: 148 | setattr( 149 | module, attr, Autograd4bitQuantLinear(tmp.in_features, tmp.out_features, groupsize=groupsize) 150 | ) 151 | for name1, child in module.named_children(): 152 | make_quant_for_4bit_autograd(child, names, name + '.' + name1 if name != '' else name1, groupsize=groupsize) 153 | 154 | 155 | def model_to_half(model): 156 | model.half() 157 | for n, m in model.named_modules(): 158 | if isinstance(m, Autograd4bitQuantLinear): 159 | if m.groupsize == -1: 160 | m.zeros = m.zeros.half() 161 | m.scales = m.scales.half() 162 | m.bias = m.bias.half() 163 | print(Style.BRIGHT + Fore.YELLOW + 'Converted as Half.') 164 | 165 | 166 | def model_to_float(model): 167 | model.float() 168 | for n, m in model.named_modules(): 169 | if isinstance(m, Autograd4bitQuantLinear): 170 | if m.groupsize == -1: 171 | m.zeros = m.zeros.float() 172 | m.scales = m.scales.float() 173 | m.bias = m.bias.float() 174 | print(Style.BRIGHT + Fore.YELLOW + 'Converted as Float.') 175 | 176 | 177 | def find_layers(module, layers=[nn.Conv2d, nn.Linear], name=''): 178 | if type(module) in layers: 179 | return {name: module} 180 | res = {} 181 | for name1, child in module.named_children(): 182 | res.update(find_layers( 183 | child, layers=layers, name=name + '.' 
+ name1 if name != '' else name1 184 | )) 185 | return res 186 | 187 | 188 | def load_llama_model_4bit_low_ram(config_path, model_path, groupsize=-1, half=False, device_map="auto", seqlen=2048): 189 | import accelerate 190 | from transformers import LlamaConfig, LlamaForCausalLM, LlamaTokenizer 191 | 192 | print(Style.BRIGHT + Fore.CYAN + "Loading Model ...") 193 | t0 = time.time() 194 | 195 | with accelerate.init_empty_weights(): 196 | config = LlamaConfig.from_pretrained(config_path) 197 | model = LlamaForCausalLM(config) 198 | model = model.eval() 199 | layers = find_layers(model) 200 | for name in ['lm_head']: 201 | if name in layers: 202 | del layers[name] 203 | make_quant_for_4bit_autograd(model, layers, groupsize=groupsize) 204 | model = accelerate.load_checkpoint_and_dispatch( 205 | model=model, 206 | checkpoint=model_path, 207 | device_map=device_map, 208 | no_split_module_classes=["LlamaDecoderLayer"] 209 | ) 210 | 211 | model.seqlen = seqlen 212 | 213 | if half: 214 | model_to_half(model) 215 | 216 | tokenizer = LlamaTokenizer.from_pretrained(config_path) 217 | tokenizer.truncation_side = 'left' 218 | 219 | print(Style.BRIGHT + Fore.GREEN + f"Loaded the model in {(time.time()-t0):.2f} seconds.") 220 | 221 | return model, tokenizer 222 | 223 | def load_llama_model_4bit_low_ram_and_offload(config_path, model_path, lora_path=None, groupsize=-1, seqlen=2048, max_memory=None): 224 | import accelerate 225 | from transformers import LlamaConfig, LlamaForCausalLM, LlamaTokenizer 226 | 227 | if max_memory is None: 228 | max_memory = {0: '24Gib', 'cpu': '48Gib'} 229 | 230 | print(Style.BRIGHT + Fore.CYAN + "Loading Model ...") 231 | t0 = time.time() 232 | 233 | with accelerate.init_empty_weights(): 234 | config = LlamaConfig.from_pretrained(config_path) 235 | model = LlamaForCausalLM(config) 236 | model = model.eval() 237 | layers = find_layers(model) 238 | for name in ['lm_head']: 239 | if name in layers: 240 | del layers[name] 241 | make_quant_for_4bit_autograd(model, layers, groupsize=groupsize) 242 | accelerate.load_checkpoint_in_model(model, checkpoint=model_path, device_map={'': 'cpu'}) 243 | 244 | # rotary_emb fix 245 | for n, m in model.named_modules(): 246 | if 'rotary_emb' in n: 247 | cos_cached = m.cos_cached.clone().cpu() 248 | sin_cached = m.sin_cached.clone().cpu() 249 | break 250 | 251 | if lora_path is not None: 252 | from peft import PeftModel 253 | from peft.tuners.lora import Linear4bitLt 254 | model = PeftModel.from_pretrained(model, lora_path, device_map={'': 'cpu'}, torch_dtype=torch.float32) 255 | print(Style.BRIGHT + Fore.GREEN + '{} Lora Applied.'.format(lora_path)) 256 | 257 | model.seqlen = seqlen 258 | 259 | print('Apply half ...') 260 | for n, m in model.named_modules(): 261 | if isinstance(m, Autograd4bitQuantLinear) or ((lora_path is not None) and isinstance(m, Linear4bitLt)): 262 | if m.groupsize == -1: 263 | m.zeros = m.zeros.half() 264 | m.scales = m.scales.half() 265 | m.bias = m.bias.half() 266 | 267 | print('Dispatching model ...') 268 | device_map = accelerate.infer_auto_device_map(model, max_memory=max_memory, no_split_module_classes=["LlamaDecoderLayer"]) 269 | model = accelerate.dispatch_model(model, device_map=device_map, offload_buffers=True, main_device=0) 270 | torch.cuda.empty_cache() 271 | print(Style.BRIGHT + Fore.YELLOW + 'Total {:.2f} Gib VRAM used.'.format(torch.cuda.memory_allocated() / 1024 / 1024)) 272 | 273 | # rotary_emb fix 274 | for n, m in model.named_modules(): 275 | if 'rotary_emb' in n: 276 | if getattr(m, '_hf_hook', None): 
277 | if isinstance(m._hf_hook, accelerate.hooks.SequentialHook): 278 | hooks = m._hf_hook.hooks 279 | else: 280 | hooks = [m._hf_hook] 281 | for hook in hooks: 282 | if hook.offload: 283 | if n + '.sin_cached' not in hook.weights_map.dataset.state_dict.keys(): 284 | hook.weights_map.dataset.state_dict[n + '.sin_cached'] = sin_cached.clone().cpu() 285 | hook.weights_map.dataset.state_dict[n + '.cos_cached'] = cos_cached.clone().cpu() 286 | 287 | tokenizer = LlamaTokenizer.from_pretrained(config_path) 288 | tokenizer.truncation_side = 'left' 289 | 290 | print(Style.BRIGHT + Fore.GREEN + f"Loaded the model in {(time.time()-t0):.2f} seconds.") 291 | 292 | return model, tokenizer 293 | 294 | load_llama_model_4bit_low_ram_and_offload_to_cpu = load_llama_model_4bit_low_ram_and_offload 295 | -------------------------------------------------------------------------------- /custom_autotune.py: -------------------------------------------------------------------------------- 1 | #https://github.com/fpgaminer/GPTQ-triton 2 | """ 3 | Mostly the same as the autotuner in Triton, but with a few changes like using 40 runs instead of 100. 4 | """ 5 | 6 | import builtins 7 | import math 8 | import time 9 | from typing import Dict 10 | 11 | import triton 12 | 13 | 14 | class Autotuner(triton.KernelInterface): 15 | def __init__(self, fn, arg_names, configs, key, reset_to_zero, prune_configs_by: Dict = None, nearest_power_of_two: bool = False): 16 | ''' 17 | :param prune_configs_by: a dict of functions that are used to prune configs, fields: 18 | 'perf_model': performance model used to predicate running time with different configs, returns running time 19 | 'top_k': number of configs to bench 20 | 'prune_num_stages_by'(optional): a function used to prune num_stages. It take configs:List[Config] as its input, and returns pruned configs. 21 | 'nearest_power_of_two'(optional): whether to round key arguments to the nearest power of two when caching tuning results 22 | ''' 23 | if not configs: 24 | self.configs = [triton.Config({}, num_warps=4, num_stages=2)] 25 | else: 26 | self.configs = configs 27 | self.key_idx = [arg_names.index(k) for k in key] 28 | self.nearest_power_of_two = nearest_power_of_two 29 | self.cache = {} 30 | # hook to reset all required tensor to zeros before relaunching a kernel 31 | self.hook = lambda args: 0 32 | if reset_to_zero is not None: 33 | self.reset_idx = [arg_names.index(k) for k in reset_to_zero] 34 | 35 | def _hook(args): 36 | for i in self.reset_idx: 37 | args[i].zero_() 38 | self.hook = _hook 39 | self.arg_names = arg_names 40 | # prune configs 41 | if prune_configs_by: 42 | perf_model, top_k = prune_configs_by['perf_model'], prune_configs_by['top_k'] 43 | if 'early_config_prune' in prune_configs_by: 44 | early_config_prune = prune_configs_by['early_config_prune'] 45 | else: 46 | perf_model, top_k, early_config_prune = None, None, None 47 | self.perf_model, self.configs_top_k = perf_model, top_k 48 | self.early_config_prune = early_config_prune 49 | self.fn = fn 50 | 51 | def _bench(self, *args, config, **meta): 52 | # check for conflicts, i.e. meta-parameters both provided 53 | # as kwargs and by the autotuner 54 | conflicts = meta.keys() & config.kwargs.keys() 55 | if conflicts: 56 | raise ValueError( 57 | f"Conflicting meta-parameters: {', '.join(conflicts)}." 58 | " Make sure that you don't re-define auto-tuned symbols." 
59 | ) 60 | # augment meta-parameters with tunable ones 61 | current = dict(meta, **config.kwargs) 62 | 63 | def kernel_call(): 64 | if config.pre_hook: 65 | config.pre_hook(self.nargs) 66 | self.hook(args) 67 | self.fn.run(*args, num_warps=config.num_warps, num_stages=config.num_stages, **current) 68 | try: 69 | # In testings using only 40 reps seems to be close enough and it appears to be what PyTorch uses 70 | # PyTorch also sets fast_flush to True, but I didn't see any speedup so I'll leave the default 71 | return triton.testing.do_bench(kernel_call, rep=40) 72 | except triton.compiler.OutOfResources: 73 | return float('inf') 74 | 75 | def run(self, *args, **kwargs): 76 | self.nargs = dict(zip(self.arg_names, args)) 77 | if len(self.configs) > 1: 78 | key = tuple(args[i] for i in self.key_idx) 79 | 80 | # This reduces the amount of autotuning by rounding the keys to the nearest power of two 81 | # In my testing this gives decent results, and greatly reduces the amount of tuning required 82 | if self.nearest_power_of_two: 83 | key = tuple([2 ** int(math.log2(x) + 0.5) for x in key]) 84 | 85 | if key not in self.cache: 86 | # prune configs 87 | pruned_configs = self.prune_configs(kwargs) 88 | bench_start = time.time() 89 | timings = {config: self._bench(*args, config=config, **kwargs) 90 | for config in pruned_configs} 91 | bench_end = time.time() 92 | self.bench_time = bench_end - bench_start 93 | self.cache[key] = builtins.min(timings, key=timings.get) 94 | self.hook(args) 95 | self.configs_timings = timings 96 | config = self.cache[key] 97 | else: 98 | config = self.configs[0] 99 | self.best_config = config 100 | if config.pre_hook is not None: 101 | config.pre_hook(self.nargs) 102 | return self.fn.run(*args, num_warps=config.num_warps, num_stages=config.num_stages, **kwargs, **config.kwargs) 103 | 104 | def prune_configs(self, kwargs): 105 | pruned_configs = self.configs 106 | if self.early_config_prune: 107 | pruned_configs = self.early_config_prune(self.configs, self.nargs) 108 | if self.perf_model: 109 | top_k = self.configs_top_k 110 | if isinstance(top_k, float) and top_k <= 1.0: 111 | top_k = int(len(self.configs) * top_k) 112 | if len(pruned_configs) > top_k: 113 | est_timing = { 114 | config: self.perf_model(**self.nargs, **kwargs, **config.kwargs, num_stages=config.num_stages, 115 | num_warps=config.num_warps) 116 | for config in pruned_configs 117 | } 118 | pruned_configs = sorted(est_timing.keys(), key=lambda x: est_timing[x])[:top_k] 119 | return pruned_configs 120 | 121 | def warmup(self, *args, **kwargs): 122 | self.nargs = dict(zip(self.arg_names, args)) 123 | for config in self.prune_configs(kwargs): 124 | self.fn.warmup( 125 | *args, 126 | num_warps=config.num_warps, 127 | num_stages=config.num_stages, 128 | **kwargs, 129 | **config.kwargs, 130 | ) 131 | self.nargs = None 132 | 133 | 134 | def autotune(configs, key, prune_configs_by=None, reset_to_zero=None, nearest_power_of_two=False): 135 | """ 136 | Decorator for auto-tuning a :code:`triton.jit`'d function. 137 | .. highlight:: python 138 | .. 
code-block:: python 139 | @triton.autotune(configs=[ 140 | triton.Config(meta={'BLOCK_SIZE': 128}, num_warps=4), 141 | triton.Config(meta={'BLOCK_SIZE': 1024}, num_warps=8), 142 | ], 143 | key=['x_size'] # the two above configs will be evaluated anytime 144 | # the value of x_size changes 145 | ) 146 | @triton.jit 147 | def kernel(x_ptr, x_size, **META): 148 | BLOCK_SIZE = META['BLOCK_SIZE'] 149 | :note: When all the configurations are evaluated, the kernel will run multiple time. 150 | This means that whatever value the kernel updates will be updated multiple times. 151 | To avoid this undesired behavior, you can use the `reset_to_zero` argument, which 152 | reset the value of the provided tensor to `zero` before running any configuration. 153 | :param configs: a list of :code:`triton.Config` objects 154 | :type configs: list[triton.Config] 155 | :param key: a list of argument names whose change in value will trigger the evaluation of all provided configs. 156 | :type key: list[str] 157 | :param prune_configs_by: a dict of functions that are used to prune configs, fields: 158 | 'perf_model': performance model used to predicate running time with different configs, returns running time 159 | 'top_k': number of configs to bench 160 | 'early_config_prune'(optional): a function used to do early prune (eg, num_stages). It take configs:List[Config] as its input, and returns pruned configs. 161 | :param reset_to_zero: a list of argument names whose value will be reset to zero before evaluating any configs. 162 | :type reset_to_zero: list[str] 163 | """ 164 | def decorator(fn): 165 | return Autotuner(fn, fn.arg_names, configs, key, reset_to_zero, prune_configs_by, nearest_power_of_two) 166 | 167 | return decorator 168 | -------------------------------------------------------------------------------- /data.txt: -------------------------------------------------------------------------------- 1 | The alpaca (Lama pacos) is a species of South American camelid mammal. It is similar to, and often confused with, the llama. However, alpacas are often noticeably smaller than llamas. The two animals are closely related and can successfully crossbreed. Both species are believed to have been domesticated from their wild relatives, the vicuña and guanaco. There are two breeds of alpaca: the Suri alpaca and the Huacaya alpaca. 2 | 3 | Alpacas are kept in herds that graze on the level heights of the Andes of Southern Peru, Western Bolivia, Ecuador, and Northern Chile at an altitude of 3,500 to 5,000 metres (11,000 to 16,000 feet) above sea level.[1] Alpacas are considerably smaller than llamas, and unlike llamas, they were not bred to be working animals, but were bred specifically for their fiber. 4 | 5 | Alpaca fiber is used for making knitted and woven items, similar to sheep's wool. These items include blankets, sweaters, hats, gloves, scarves, a wide variety of textiles, and ponchos, in South America, as well as sweaters, socks, coats, and bedding in other parts of the world. The fiber comes in more than 52 natural colors as classified in Peru, 12 as classified in Australia, and 16 as classified in the United States. 6 | 7 | Alpacas communicate through body language. The most common is spitting to show dominance[2] when they are in distress, fearful, or feel agitated. Male alpacas are more aggressive than females, and tend to establish dominance within their herd group. In some cases, alpha males will immobilize the head and neck of a weaker or challenging male in order to show their strength and dominance. 
8 | 9 | In the textile industry, "alpaca" primarily refers to the hair of Peruvian alpacas, but more broadly it refers to a style of fabric originally made from alpaca hair, such as mohair, Icelandic sheep wool, or even high-quality wool from other breeds of sheep. In trade, distinctions are made between alpacas and the several styles of mohair and luster.[3] 10 | 11 | An adult alpaca generally is between 81 and 99 centimetres (32 and 39 inches) in height at the shoulders (withers). They usually weigh between 48 and 90 kilograms (106 and 198 pounds).[4] Raised in the same conditions, the difference in weight can be small with males weighting around 22.3 kilograms (49 lb 3 oz) and females 21.3 kilograms (46 lb 15 oz).[5] 12 | 13 | The relationship between alpacas and vicuñas was disputed for many years. In the 18th and 19th centuries, the four South American lamoid species were assigned scientific names. At that time, the alpaca was assumed to be descended from the llama, ignoring similarities in size, fleece and dentition between the alpaca and the vicuña. Classification was complicated by the fact that all four species of South American camelid can interbreed and produce fertile offspring.[6] The advent of DNA technology made a more accurate classification possible. 14 | 15 | In 2001, the alpaca genus classification changed from Lama pacos to Vicugna pacos, following the presentation of a paper[7] on work by Miranda Kadwell et al. on alpaca DNA to the Royal Society showing the alpaca is descended from the vicuña, not the guanaco. 16 | 17 | Alpacas were domesticated thousands of years ago. The Moche people of Northern Peru often used alpaca images in their art.[8] There are no known wild alpacas, and its closest living relative, the vicuña (also native to South America), is the wild ancestor of the alpaca. 18 | 19 | The family Camelidae first appeared in Americas 40–45 million years ago, during the Eocene period, from the common ancestor, Protylopus. The descendants divided into Camelini and Lamini tribes, taking different migratory patterns to Asia and South America, respectively. Although the camelids became extinct in North America around 3 million years ago, it flourished in the South with the species we see today.[9] It was not until 2–5 million years ago, during the Pliocene, that the genus Hemiauchenia of the tribe Lamini split into Palaeolama and Lama; the latter would then split again into Lama and Vicugna upon migrating down to South America. 20 | 21 | Remains of vicuña and guanaco have been found throughout Peru for around 12,000 years. Their domesticated counterparts, the llama and alpacas, have been found mummified in the Moquegua valley, in the south of Peru, dating back 900 to 1000 years. Mummies found in this region show two breeds of alpacas. More precise analysis of bone and teeth of these mummies has demonstrated that alpacas were domesticated from the Vicugna vicugna. Other research, considering the behavioral and morphological characteristics of alpacas and their wild counterparts, seems to indicate that alpacas could find their origins in Lama guanicoe as well as Vicugna vicugna, or even a hybrid of both. 22 | 23 | Genetic analysis shows a different picture of the origins of the alpaca. Analysis of mitochondrial DNA shows that most alpacas have guanaco mtDNA, and many also have vicuña mtDNA. But microsatellite data shows that alpaca DNA is much more similar to vicuña DNA than to guanaco DNA. 
This suggests that alpacas are descendants of the Vicugna vicugna, not of the Lama guanicoe. The discrepancy with mtDNA seems to be a result of the fact that mtDNA is only transmitted by the mother, and recent husbandry practices have caused hybridization between llamas (which primarily carry guanaco DNA) and alpacas. To the extent that many of today's domestic alpacas are the result of male alpacas bred to female llamas, this would explain the mtDNA consistent with guanacos. This situation has led to attempts to reclassify the alpaca as Vicugna pacos.[7] 24 | 25 | The alpaca comes in two breeds, Suri and Huacaya, based on their fibers rather than scientific or European classifications. 26 | 27 | Huacaya alpacas are the most commonly found, constituting about 90% of the population.[10] The Huacaya alpaca is thought to have originated in post-colonial Peru. This is due to their thicker fleece which makes them more suited to survive in the higher altitudes of the Andes after being pushed into the highlands of Peru with the arrival of the Spanish.[11][better source needed] 28 | 29 | Suri alpacas represent a smaller portion of the total alpaca population, around 10%.[10] They are thought to have been more prevalent in pre-Columbian Peru since they could be kept at a lower altitude where a thicker fleece was not needed for harsh weather conditions.[11][better source needed] 30 | 31 | Alpacas are social herd animals that live in family groups, consisting of a territorial alpha male, females, and their young ones. Alpacas warn the herd about intruders by making sharp, noisy inhalations that sound like a high-pitched bray. The herd may attack smaller predators with their front feet and can spit and kick. Their aggression towards members of the canid family (coyotes, foxes, dogs etc.) is exploited when alpacas are used as guard llamas for guarding sheep.[12][13] 32 | 33 | Alpacas can sometimes be aggressive, but they can also be very gentle, intelligent, and extremely observant. For the most part, alpacas are very quiet, but male alpacas are more energetic when they get involved in fighting with other alpacas.[14] When they prey, they are cautious but also nervous when they feel any type of threat. They can feel threatened when a person or another alpaca comes up from behind them.[15][better source needed] 34 | 35 | Alpacas set their own boundaries of "personal space" within their families and groups.[16] They make a hierarchy in some sense, and each alpaca is aware of the dominant animals in each group.[14] Body language is the key to their communication. It helps to maintain their order. One example of their body communication includes a pose named broadside, where their ears are pulled back and they stand sideways. This pose is used when male alpacas are defending their territory.[2] 36 | 37 | When they are young, they tend to follow larger objects and to sit near or under them. An example of this is a baby alpaca with its mother. This can also apply when an alpaca passes by an older alpaca.[16] 38 | 39 | Training 40 | Alpacas are generally very trainable and usually respond to reward, most commonly in the form of food. They can usually be petted without getting agitated, especially if one avoids petting the head or neck. Alpacas are usually quite easy to herd, even in large groups. 
However, during herding, it is recommended for the handler to approach the animals slowly and quietly, as failing to do so can result in danger for both the animals and the handler.[17] 41 | 42 | Alpacas and llamas have started showing up in U.S. nursing homes and hospitals as trained, certified therapy animals. The Mayo Clinic says animal-assisted therapy can reduce pain, depression, anxiety, and fatigue. This type of animal therapy is growing in popularity, and there are several organizations throughout the United States that participate.[18] 43 | 44 | Spitting 45 | Not all alpacas spit, but all are capable of doing so. "Spit" is somewhat euphemistic; occasionally the projectile contains only air and a little saliva, although alpacas commonly bring up acidic stomach contents (generally a green, grassy mix) and project it onto their chosen targets. Spitting is mostly reserved for other alpacas, but an alpaca will also occasionally spit at a human. 46 | 47 | Spitting can result in what is called "sour mouth". Sour mouth is characterized by "a loose-hanging lower lip and a gaping mouth."[19] 48 | 49 | Alpacas can spit for several reasons. A female alpaca spits when she is not interested in a male alpaca, typically when she thinks that she is already impregnated. Both sexes of alpaca keep others away from their food, or anything they have their eyes on. Most give a slight warning before spitting by blowing air out and raising their heads, giving their ears a "pinned" appearance.[16] 50 | 51 | Alpacas can spit up to ten feet if they need to. For example, if another animal does not back off, the alpaca will throw up its stomach contents, resulting in a lot of spit.[20] 52 | 53 | Some signs of stress which can lead to their spitting habits include: humming, a wrinkle under their eye, drooling, rapid breathing, and stomping their feet. When alpacas show any sign of interest or alertness, they tend to sniff their surroundings, watch closely, or stand quietly in place and stare.[20] 54 | 55 | When it comes to reproduction, they spit because it is a response triggered by the progesterone levels being increased, which is associated with ovulation.[21] 56 | 57 | Hygiene 58 | Alpacas use a communal dung pile,[22] where they do not graze. This behaviour tends to limit the spread of internal parasites. Generally, males have much tidier, and fewer dung piles than females, which tend to stand in a line and all go at once. One female approaches the dung pile and begins to urinate and/or defecate, and the rest of the herd often follows. Alpaca waste is collected and used as garden fertilizer or even natural fertilizer.[2] 59 | 60 | Because of their preference for using a dung pile for excreting bodily waste, some alpacas have been successfully house-trained.[23] 61 | 62 | Alpacas develop dental hygiene problems which affect their eating and behavior. Warning signs include protracted chewing while eating, or food spilling out of their mouths. Poor body condition and sunken cheeks are also telltales of dental problems. 63 | 64 | Alpacas make a variety of sounds: 65 | 66 | Humming: When alpacas are born, the mother and baby hum constantly. They also hum as a sign of distress, especially when they are separated from their herd. Alpacas may also hum when curious, happy, worried or cautious. 67 | Snorting: Alpacas snort when another alpaca is invading its space. 68 | Grumbling: Alpacas grumble to warn each other. For example, when one is invading another's personal space, it sounds like gurgling. 
69 | Clucking: Similar to a hen's cluck, alpacas cluck when a mother is concerned for her cria. Male alpacas cluck to signal friendly behavior.[2] 70 | Screaming: Their screams are extremely deafening and loud. They will scream when they are not handled correctly or when they are being attacked by a potential enemy. 71 | Screeching: A bird-like cry, presumably intended to terrify the opponent. This sound is typically used by male alpacas when they are in a fight over dominance. When a female screeches, it is more of a growl when she is angry. 72 | 73 | Females are induced ovulators;[24] meaning the act of mating and the presence of semen causes them to ovulate. Females usually conceive after just one breeding, but occasionally do have trouble conceiving. Artificial insemination is technically difficult, expensive and not common, but it can be accomplished. Embryo transfer is more widespread. 74 | 75 | A male is usually ready to mate for the first time between two and three years of age. It is not advisable to allow a young female to be bred until she is mature and has reached two-thirds of her mature weight. Over-breeding a young female before conception is possibly a common cause of uterine infections. As the age of maturation varies greatly between individuals, it is usually recommended that novice breeders wait until females are 18 months of age or older before initiating breeding.[25] 76 | 77 | Alpacas can breed at any time throughout the year but it is more difficult to breed in the winter. Most breed during autumn or late spring. The most popular way to have alpacas mate is pen mating. Pen mating is when they move both the female and the desired male into a pen. Another way is paddock mating where one male alpaca is let loose in the paddock with several female alpacas. 78 | 79 | The gestation period is, on average, 11.5 months, and usually results in a single offspring, or cria. Twins are rare, occurring about once per 1000 deliveries.[26] Cria are generally between 15 and 19 pounds, and are standing 30 to 90 minutes after birth.[27] After a female gives birth, she is generally receptive to breeding again after about two weeks. Crias may be weaned through human intervention at about six months old and 60 pounds, but many breeders prefer to allow the female to decide when to wean her offspring; they can be weaned earlier or later depending on their size and emotional maturity. 80 | 81 | The average lifespan of an alpaca is between 15–20 years, and the longest-lived alpaca on record is 27 years.[28] 82 | 83 | Cattle tuberculosis can also infect alpacas: Mycobacterium bovis also causes TB in this species worldwide.[29] Krajewska‐Wędzina et al., 2020 detect M. bovis in individuals traded from the United Kingdom to Poland.[29] To accomplish this they develop a seroassay which correctly identifies positive subjects which are false negative for a common skin test.[29] Krajewska‐Wędzina et al. also find that alpacas are unusual in mounting a competent early-infection immune response.[29] Bernitz et al., 2021 believe this to generalise to all camelids.[29] 84 | 85 | Alpacas can be found throughout most of South America.[30] They typically live in temperate conditions in the mountains with high altitudes. 86 | 87 | They are easy to care for since they are not limited to a specific type of environment. Animals such as flamingos, condors, spectacled bears, mountain lions, coyotes, llamas, and sheep live near alpacas when they are in their natural habitat. 
88 | 89 | Alpacas are native to Peru, but can be found throughout the globe in captivity.[30] Peru currently has the largest alpaca population, with over half the world's animals.[31] The population declined drastically after the Spanish Conquistadors invaded the Andes mountains in 1532, after which 98% of the animals were destroyed. The Spanish also brought with them diseases that were fatal to alpacas.[32] 90 | 91 | European conquest forced the animals higher into the mountains, where they remained permanently. Although alpacas had almost been wiped out completely, they were rediscovered sometime during the 19th century by Europeans. After finding uses for them, the animals became important to societies during the industrial revolution.[33] 92 | 93 | Nuzzle and Scratch, a British children's television programme featuring two fictional alpacas, was first broadcast between 2008 and 2011.[34] 94 | 95 | Interest in alpacas grew as a result of Depp v. Heard, the 2022 trial in which Johnny Depp sued Amber Heard for defamation in Virginia after Heard wrote an op-ed saying she was a public victim of domestic violence. Depp testified, under oath, that he would not make another Pirates of the Caribbean film for "300 million dollars and a million alpacas".[35][36][37] 96 | 97 | Alpacas chew their food, which is mixed with their cud and saliva, and then swallow it. Alpacas usually eat 1.5% of their body weight daily for normal growth.[38] They mainly need pasture grass, hay, or silage, but some may also need supplemental energy and protein foods, and they will normally try to chew on almost anything (e.g. an empty bottle). Most alpaca ranchers rotate their feeding grounds so the grass can regrow and fecal parasites can die off before the area is reused. Pasture grass is a great source of protein. As the seasons change, the grass loses or gains protein. For example, in the spring, the pasture grass has about 20% protein while in the summer, it only has 6%.[38] They need more energy supplements in the winter to produce body heat. They get their fiber from hay or from long stems, which provide them with vitamin E. Green grass contains vitamins A and E. 98 | 99 | Alpacas can eat natural unfertilized grass; however, ranchers can also supplement grass with low-protein grass hay. To provide selenium and other necessary vitamins and nutrients that are not fully obtained from their primary diet, ranchers will feed their domestic alpacas a daily dose of grain.[39] Alpacas may obtain the necessary vitamins in their native grazing ranges. 100 | 101 | Alpacas, like other camelids, have a three-chambered stomach; combined with chewing cud, this three-chambered system allows maximum extraction of nutrients from low-quality forages. Alpacas are not ruminants, pseudo-ruminants, or modified ruminants, as there are many differences between the anatomy and physiology of a camelid and a ruminant stomach.[40] 102 | 103 | Alpacas will chew their food in a figure-eight motion, swallow the food, and then pass it into one of the stomach's chambers. The first and second chambers (called C1 and C2) are anaerobic fermentation chambers where the fermentation process begins. The alpaca will further absorb nutrients and water in the first part of the third chamber. The end of the third chamber (called C3) is where the stomach secretes acids to digest food and is the likely place where an alpaca will have ulcers if stressed. 
104 | 105 | Many plants are poisonous to the alpaca, including the bracken fern, Madagascar ragwort, oleander, and some azaleas. In common with similar livestock, others include acorns, African rue, agave, amaryllis, autumn crocus, bear grass, broom snakeweed, buckwheat, ragweed, buttercups, calla lily, orange tree foliage, carnations, castor beans, and many others.[41] 106 | 107 | -------------------------------------------------------------------------------- /finetune.py: -------------------------------------------------------------------------------- 1 | """ 2 | llama-4b trainer with support of Stanford Alpaca-like JSON datasets (short for SAD) 3 | Intended to use with https://github.com/johnsmith0031/alpaca_lora_4bit 4 | 5 | SAD structure: 6 | [ 7 | { 8 | "instruction": "Give null hypothesis", 9 | "input": "6 subjects were given a drug (treatment group) and an additional 6 subjects a placebo (control group).", 10 | "output": "Drug is equivalent of placebo" 11 | }, 12 | { 13 | "instruction": "What does RNA stand for?", 14 | "input": "", 15 | "output": "RNA stands for ribonucleic acid." 16 | } 17 | ] 18 | """ 19 | # Early load config to replace attn if needed 20 | from arg_parser import get_config 21 | ft_config = get_config() 22 | 23 | if ft_config.flash_attention: 24 | from monkeypatch.llama_flash_attn_monkey_patch import replace_llama_attn_with_flash_attn 25 | replace_llama_attn_with_flash_attn() 26 | 27 | import autograd_4bit 28 | if ft_config.backend.lower() == 'triton': 29 | autograd_4bit.switch_backend_to('triton') 30 | else: 31 | autograd_4bit.switch_backend_to('cuda') 32 | 33 | import sys 34 | 35 | import peft 36 | import peft.tuners.lora 37 | assert peft.tuners.lora.is_gptq_available() 38 | 39 | import torch 40 | import transformers 41 | from autograd_4bit import load_llama_model_4bit_low_ram 42 | from peft import LoraConfig, get_peft_model, get_peft_model_state_dict, PeftModel 43 | 44 | # ! 
Config 45 | import train_data 46 | 47 | 48 | 49 | # * Show loaded parameters 50 | if ft_config.local_rank == 0: 51 | print(f"{ft_config}\n") 52 | 53 | if ft_config.gradient_checkpointing: 54 | print('Disable Dropout.') 55 | 56 | # Load Basic Model 57 | model, tokenizer = load_llama_model_4bit_low_ram(ft_config.llama_q4_config_dir, 58 | ft_config.llama_q4_model, 59 | device_map=ft_config.device_map, 60 | groupsize=ft_config.groupsize) 61 | 62 | # Config Lora 63 | lora_config = LoraConfig( 64 | r=ft_config.lora_r, 65 | lora_alpha=ft_config.lora_alpha, 66 | target_modules=["q_proj", "v_proj"], 67 | lora_dropout=ft_config.lora_dropout, 68 | bias="none", 69 | task_type="CAUSAL_LM", 70 | ) 71 | if ft_config.lora_apply_dir is None: 72 | model = get_peft_model(model, lora_config) 73 | else: 74 | device_map = ft_config.device_map 75 | if ft_config.ddp: 76 | device_map = {'': 0} 77 | else: 78 | if torch.cuda.device_count() > 1: 79 | device_map = "auto" 80 | else: 81 | device_map = {'': 0} 82 | print('Device map for lora:', device_map) 83 | model = PeftModel.from_pretrained(model, ft_config.lora_apply_dir, device_map=device_map, torch_dtype=torch.float32) 84 | print(ft_config.lora_apply_dir, 'loaded') 85 | 86 | 87 | # Scales to half 88 | print('Fitting 4bit scales and zeros to half') 89 | for n, m in model.named_modules(): 90 | if '4bit' in str(type(m)): 91 | if m.groupsize == -1: 92 | m.zeros = m.zeros.half() 93 | m.scales = m.scales.half() 94 | 95 | # Set tokenizer 96 | tokenizer.pad_token_id = 0 97 | 98 | if not ft_config.skip: 99 | # Load Data 100 | data = None 101 | if ft_config.ds_type == "txt" and not ft_config.skip: 102 | #### LLaMa 103 | data = train_data.TrainTxt(ft_config.dataset, ft_config.val_set_size, tokenizer, ft_config.cutoff_len) 104 | elif ft_config.ds_type == "alpaca" and not ft_config.skip: 105 | #### Stanford Alpaca-like Data 106 | data = train_data.TrainSAD(ft_config.dataset, ft_config.val_set_size, tokenizer, ft_config.cutoff_len) 107 | elif ft_config.ds_type == "gpt4all" and not ft_config.skip: 108 | #### GPT4All Data 109 | data = train_data.TrainGPT4All(ft_config.dataset, ft_config.val_set_size, tokenizer, ft_config.cutoff_len) 110 | else: 111 | raise NotImplementedError("ERROR: Unknown dataset format") 112 | data.prepare_data(thd=ft_config.txt_row_thd, use_eos_token=ft_config.use_eos_token) 113 | #### 114 | 115 | # Use gradient checkpointing 116 | if ft_config.gradient_checkpointing: 117 | print('Applying gradient checkpointing ...') 118 | from gradient_checkpointing import apply_gradient_checkpointing 119 | apply_gradient_checkpointing(model, checkpoint_ratio=ft_config.gradient_checkpointing_ratio) 120 | 121 | # Disable Trainer's DataParallel for multigpu 122 | if not ft_config.ddp and torch.cuda.device_count() > 1: 123 | model.is_parallelizable = True 124 | model.model_parallel = True 125 | 126 | training_arguments = transformers.TrainingArguments( 127 | per_device_train_batch_size=ft_config.mbatch_size, 128 | gradient_accumulation_steps=ft_config.gradient_accumulation_steps, 129 | warmup_steps=ft_config.warmup_steps, 130 | optim="adamw_torch", 131 | num_train_epochs=ft_config.epochs, 132 | learning_rate=ft_config.lr, 133 | fp16=True, 134 | logging_steps=ft_config.logging_steps, 135 | evaluation_strategy="no", 136 | save_strategy="steps", 137 | eval_steps=None, 138 | save_steps=ft_config.save_steps, 139 | output_dir=ft_config.lora_out_dir, 140 | save_total_limit=ft_config.save_total_limit, 141 | load_best_model_at_end=False, 142 | ddp_find_unused_parameters=False if 
ft_config.ddp else None, 143 | ) 144 | 145 | trainer = transformers.Trainer( 146 | model=model, 147 | train_dataset=data.train_data, 148 | eval_dataset=data.val_data, 149 | args=training_arguments, 150 | data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False), 151 | ) 152 | model.config.use_cache = False 153 | 154 | # Set Model dict 155 | old_state_dict = model.state_dict 156 | model.state_dict = ( 157 | lambda self, *_, **__: get_peft_model_state_dict(self, old_state_dict()) 158 | ).__get__(model, type(model)) 159 | 160 | # Set Verbose 161 | if ft_config.verbose: 162 | transformers.logging.set_verbosity_info() 163 | 164 | # Run Trainer 165 | if ft_config.resume_checkpoint: 166 | print('Resuming from {} ...'.format(ft_config.resume_checkpoint)) 167 | trainer.train(ft_config.resume_checkpoint) 168 | else: 169 | trainer.train() 170 | 171 | print('Train completed.') 172 | 173 | # Save Model 174 | model.save_pretrained(ft_config.lora_out_dir) 175 | 176 | if ft_config.checkpoint: 177 | print("Warning: Merge model + LoRA and save the whole checkpoint not implemented yet.") 178 | 179 | print('Model Saved.') 180 | -------------------------------------------------------------------------------- /gradient_checkpointing.py: -------------------------------------------------------------------------------- 1 | from transformers.models.llama.modeling_llama import LlamaDecoderLayer 2 | from torch.utils.checkpoint import checkpoint 3 | from torch.autograd import Variable 4 | import torch 5 | from torch import nn 6 | import numpy as np 7 | 8 | 9 | class NewForward: 10 | 11 | def __init__(self, layer): 12 | self.layer = layer 13 | self.apply_patch() 14 | 15 | def apply_patch(self): 16 | self.layer.old_forward_for_cp = self.layer.forward 17 | self.layer.forward = self.new_forward 18 | 19 | def new_forward(self, *args, **kwargs): 20 | def func(*args): 21 | return self.layer.old_forward_for_cp(*args, **kwargs) 22 | output = checkpoint(func, *args) 23 | return output 24 | 25 | 26 | class VarWrapper: 27 | 28 | def __init__(self, model): 29 | self.model = model 30 | self.apply_patch() 31 | print('Var Wrapper Patch Applied') 32 | 33 | def apply_patch(self): 34 | self.model.old_forward_for_cp = self.model.forward 35 | self.model.forward = self.new_forward 36 | 37 | def new_forward(self, *args, **kwargs): 38 | out = self.model.old_forward_for_cp(*args, **kwargs) 39 | out = Variable(out.data, requires_grad=True) 40 | return out 41 | 42 | 43 | def apply_gradient_checkpointing(model, checkpoint_ratio=1): 44 | new_forwards = [] 45 | modules = [] 46 | for n, m in model.named_modules(): 47 | if isinstance(m, LlamaDecoderLayer): 48 | modules.append(m) 49 | if checkpoint_ratio < 1 and checkpoint_ratio > 0: 50 | checkpoint_locs = np.array((np.linspace(0, 1, int(len(modules) * checkpoint_ratio)) * (len(modules)-1)).round(), dtype=int) 51 | else: 52 | checkpoint_locs = np.arange(len(modules)) 53 | for i in checkpoint_locs: 54 | m = modules[i] 55 | new_forwards.append(NewForward(m)) 56 | print('Forward Patch Applied For Block {}'.format(i)) 57 | for n, m in model.named_modules(): 58 | if isinstance(m, torch.nn.Embedding): 59 | wrapper = VarWrapper(m) 60 | break 61 | return new_forwards, wrapper 62 | -------------------------------------------------------------------------------- /inference.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import time 4 | import torch 5 | from autograd_4bit import load_llama_model_4bit_low_ram, 
Autograd4bitQuantLinear 6 | config_path = './llama-13b-4bit/' 7 | model_path = './llama-13b-4bit.pt' 8 | model, tokenizer = load_llama_model_4bit_low_ram(config_path, model_path, groupsize=-1) 9 | 10 | print('Fitting 4bit scales and zeros to half') 11 | model.half() 12 | for n, m in model.named_modules(): 13 | if isinstance(m, Autograd4bitQuantLinear): 14 | if m.groupsize == -1: 15 | m.zeros = m.zeros.half() 16 | m.scales = m.scales.half() 17 | m.bias = m.bias.half() 18 | 19 | print('Apply AMP Wrapper ...') 20 | from amp_wrapper import AMPWrapper 21 | wrapper = AMPWrapper(model) 22 | wrapper.apply_generate() 23 | 24 | prompt = '''I think the meaning of life is''' 25 | batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=False) 26 | batch = {k: v.cuda() for k, v in batch.items()} 27 | 28 | start = time.time() 29 | with torch.no_grad(): 30 | generated = model.generate(inputs=batch["input_ids"], 31 | do_sample=True, use_cache=True, 32 | repetition_penalty=1.1, 33 | max_new_tokens=20, 34 | temperature=0.9, 35 | top_p=0.95, 36 | top_k=40, 37 | return_dict_in_generate=True, 38 | output_attentions=False, 39 | output_hidden_states=False, 40 | output_scores=False) 41 | result_text = tokenizer.decode(generated['sequences'].cpu().tolist()[0]) 42 | end = time.time() 43 | print(result_text) 44 | print(end - start) 45 | -------------------------------------------------------------------------------- /matmul_utils_4bit.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from gptq_llama import quant_cuda 4 | 5 | 6 | # Global Buffer 7 | buffer_mat_dic = {} 8 | use_new = True 9 | auto_switch = True 10 | auto_switch_thd = 8 11 | debug = False 12 | 13 | 14 | def get_buffer(shape_of_qweight, dtype=torch.float16, device='cuda'): 15 | if shape_of_qweight not in buffer_mat_dic.keys(): 16 | buffer_mat_dic[shape_of_qweight] = torch.zeros((shape_of_qweight[0] * 8, shape_of_qweight[1]), dtype=dtype, device=device) 17 | else: 18 | if buffer_mat_dic[shape_of_qweight].device != device: 19 | buffer_mat_dic[shape_of_qweight] = buffer_mat_dic[shape_of_qweight].to(device) 20 | if buffer_mat_dic[shape_of_qweight].dtype != dtype: 21 | buffer_mat_dic[shape_of_qweight] = buffer_mat_dic[shape_of_qweight].to(dtype=dtype) 22 | return buffer_mat_dic[shape_of_qweight] 23 | 24 | 25 | def _matmul4bit_v1(x, qweight, scales, zeros): 26 | """ 27 | input x: (n, m) 28 | qweight: (j, k) 29 | where m == j*8 30 | 31 | perform x @ qweight 32 | 33 | return y: 34 | """ 35 | if debug: 36 | print('_matmul4bit_v1') 37 | assert qweight.shape[0] * 8 == x.shape[-1] 38 | outshape = x.shape[:-1] + (qweight.shape[1],) 39 | x = x.reshape(-1, x.shape[-1]) 40 | y = torch.zeros((x.shape[0], qweight.shape[-1]), dtype=torch.float32, device=x.device) 41 | dtype = x.dtype 42 | x = x.half() 43 | quant_cuda.vecquant4matmul_v1_faster(x, qweight, y, scales, zeros) 44 | y = y.to(dtype) 45 | return y.reshape(outshape) 46 | 47 | 48 | def _matmul4bit_v2(x, qweight, scales, zeros, groupsize): 49 | """ 50 | input x: (n, m) 51 | qweight: (j, k) 52 | where m == j*8 53 | 54 | perform x @ qweight 55 | 56 | return y: 57 | """ 58 | if debug: 59 | print('_matmul4bit_v2') 60 | assert qweight.shape[0] * 8 == x.shape[-1] 61 | outshape = x.shape[:-1] + (qweight.shape[1],) 62 | x = x.reshape(-1, x.shape[-1]) 63 | y = torch.zeros((x.shape[0], qweight.shape[-1]), dtype=torch.float32, device=x.device) 64 | dtype = x.dtype 65 | x = x.half() 66 | quant_cuda.vecquant4matmul_faster(x, qweight, y, 
scales, zeros, groupsize, x.shape[-1] // 2) 67 | y = y.to(dtype) 68 | return y.reshape(outshape) 69 | 70 | 71 | def _matmul4bit_v1_recons(x, qweight, scales, zeros, transpose=False): 72 | if debug: 73 | print('_matmul4bit_v1_recons') 74 | if not transpose: 75 | assert qweight.shape[0] * 8 == x.shape[-1] 76 | else: 77 | assert qweight.shape[1] == x.shape[-1] 78 | buffer = get_buffer(qweight.shape, dtype=scales.dtype, device=qweight.device) 79 | quant_cuda.vecquant4recons_v1(qweight, buffer, scales, zeros) 80 | if not transpose: 81 | output = torch.matmul(x, buffer) 82 | else: 83 | output = torch.matmul(x, buffer.T) 84 | return output 85 | 86 | 87 | def _matmul4bit_v2_recons(x, qweight, scales, zeros, groupsize, transpose=False): 88 | if debug: 89 | print('_matmul4bit_v2_recons') 90 | if not transpose: 91 | assert qweight.shape[0] * 8 == x.shape[-1] 92 | else: 93 | assert qweight.shape[1] == x.shape[-1] 94 | buffer = get_buffer(qweight.shape, dtype=scales.dtype, device=qweight.device) 95 | quant_cuda.vecquant4recons_v2(qweight, buffer, scales, zeros, groupsize) 96 | if not transpose: 97 | output = torch.matmul(x, buffer) 98 | else: 99 | output = torch.matmul(x, buffer.T) 100 | return output 101 | 102 | 103 | def matmul4bit(x, qweight, scales, zeros, groupsize=-1): 104 | if groupsize == -1: 105 | # use v1 106 | if use_new: 107 | if auto_switch: 108 | if np.prod(x.shape[:-1]) > auto_switch_thd: 109 | output = _matmul4bit_v1_recons(x.to(scales.dtype), qweight, scales, zeros) 110 | else: 111 | output = _matmul4bit_v1(x, qweight, scales.float(), zeros.float()) 112 | else: 113 | output = _matmul4bit_v1(x, qweight, scales.float(), zeros.float()) 114 | else: 115 | # use v2 116 | if use_new: 117 | if auto_switch: 118 | if np.prod(x.shape[:-1]) > auto_switch_thd: 119 | output = _matmul4bit_v2_recons(x.to(scales.dtype), qweight, scales, zeros, groupsize) 120 | else: 121 | output = _matmul4bit_v2(x, qweight, scales.float(), zeros, groupsize) 122 | else: 123 | output = _matmul4bit_v2(x, qweight, scales.float(), zeros, groupsize) 124 | return output 125 | 126 | 127 | def v2_to_v1(scales, zeros): 128 | """ 129 | Convert zeros in V2 model to V1 model when group_num = 1, for debugging 130 | """ 131 | assert zeros.shape[0] == 1 132 | z_mat = torch.zeros((zeros.shape[1], 256), dtype=torch.int, device=zeros.device) + zeros.reshape((-1,1)) 133 | z_buffer = torch.zeros((z_mat.shape[0] * 8, z_mat.shape[1]), dtype=torch.float16, device=zeros.device) 134 | z_zeros = torch.zeros(z_mat.shape[1], dtype=torch.float16, device=zeros.device) 135 | z_scales = torch.ones(z_mat.shape[1], dtype=torch.float16, device=zeros.device) 136 | quant_cuda.vecquant4recons_v1(z_mat, z_buffer, z_scales, z_zeros) 137 | z_buffer = z_buffer[:,0] 138 | zeros_recons = z_buffer * scales + scales 139 | return zeros_recons 140 | -------------------------------------------------------------------------------- /monkeypatch/llama_flash_attn_monkey_patch.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional, Tuple 2 | 3 | import torch 4 | from torch import nn 5 | 6 | import transformers 7 | from transformers.models.llama.modeling_llama import LlamaConfig, LlamaRotaryEmbedding, apply_rotary_pos_emb 8 | 9 | from einops import rearrange 10 | 11 | from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func 12 | from flash_attn.bert_padding import unpad_input, pad_input 13 | 14 | class LlamaAttention(nn.Module): 15 | """Multi-headed attention from 'Attention Is All You 
Need' paper""" 16 | 17 | def __init__( 18 | self, 19 | config: LlamaConfig, 20 | ): 21 | super().__init__() 22 | hidden_size = config.hidden_size 23 | num_heads = config.num_attention_heads 24 | self.hidden_size = hidden_size 25 | self.num_heads = num_heads 26 | self.head_dim = self.hidden_size // num_heads 27 | 28 | if (self.head_dim * num_heads) != self.hidden_size: 29 | raise ValueError( 30 | f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" 31 | f" and `num_heads`: {num_heads}).") 32 | self.q_proj = nn.Linear( 33 | hidden_size, 34 | num_heads * self.head_dim, 35 | bias=False, 36 | ) 37 | self.k_proj = nn.Linear( 38 | hidden_size, 39 | num_heads * self.head_dim, 40 | bias=False, 41 | ) 42 | self.v_proj = nn.Linear( 43 | hidden_size, 44 | num_heads * self.head_dim, 45 | bias=False, 46 | ) 47 | self.o_proj = nn.Linear( 48 | num_heads * self.head_dim, 49 | hidden_size, 50 | bias=False, 51 | ) 52 | self.rotary_emb = LlamaRotaryEmbedding(self.head_dim) 53 | 54 | def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): 55 | return tensor.view(bsz, seq_len, self.num_heads, 56 | self.head_dim).transpose(1, 2).contiguous() 57 | 58 | def forward( 59 | self, 60 | hidden_states: torch.Tensor, 61 | past_key_value: Optional[Tuple[torch.Tensor]] = None, 62 | attention_mask: Optional[torch.Tensor] = None, 63 | position_ids: Optional[torch.LongTensor] = None, 64 | output_attentions: bool = False, 65 | use_cache: bool = False, 66 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], 67 | Optional[Tuple[torch.Tensor]]]: 68 | """Input shape: Batch x Time x Channel 69 | 70 | attention_mask: [bsz, q_len] 71 | """ 72 | bsz, q_len, _ = hidden_states.size() 73 | 74 | query_states = self.q_proj(hidden_states).view( 75 | bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) 76 | key_states = self.k_proj(hidden_states).view( 77 | bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) 78 | value_states = self.v_proj(hidden_states).view( 79 | bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) 80 | # [bsz, q_len, nh, hd] 81 | # [bsz, nh, q_len, hd] 82 | 83 | kv_seq_len = key_states.shape[-2] 84 | if past_key_value is not None: 85 | kv_seq_len += past_key_value[0].shape[-2] 86 | 87 | cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) 88 | query_states, key_states = apply_rotary_pos_emb(query_states, 89 | key_states, 90 | cos, 91 | sin, 92 | position_ids) 93 | # [bsz, nh, t, hd] 94 | assert not output_attentions, "output_attentions is not supported" 95 | assert not use_cache, "use_cache is not supported" 96 | assert past_key_value is None, "past_key_value is not supported" 97 | 98 | # Flash attention codes from 99 | # https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/flash_attention.py 100 | 101 | # transform the data into the format required by flash attention 102 | qkv = torch.stack([query_states, key_states, value_states], dim=2) # [bsz, nh, 3, q_len, hd] 103 | qkv = qkv.transpose(1, 3) # [bsz, q_len, 3, nh, hd] 104 | # We have disabled _prepare_decoder_attention_mask in LlamaModel 105 | # the attention_mask should be the same as the key_padding_mask 106 | key_padding_mask = attention_mask 107 | 108 | 109 | if key_padding_mask is None: 110 | qkv = rearrange(qkv, 'b s ... 
-> (b s) ...') 111 | max_s = q_len 112 | cu_q_lens = torch.arange(0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, 113 | device=qkv.device) 114 | output = flash_attn_unpadded_qkvpacked_func( 115 | qkv, cu_q_lens, max_s, 0.0, 116 | softmax_scale=None, causal=True 117 | ) 118 | output = rearrange(output, '(b s) ... -> b s ...', b=bsz) 119 | else: 120 | nheads = qkv.shape[-2] 121 | x = rearrange(qkv, 'b s three h d -> b s (three h d)') 122 | x_unpad, indices, cu_q_lens, max_s = unpad_input(x, key_padding_mask) 123 | x_unpad = rearrange(x_unpad, 'nnz (three h d) -> nnz three h d', three=3, h=nheads) 124 | output_unpad = flash_attn_unpadded_qkvpacked_func( 125 | x_unpad, cu_q_lens, max_s, 0.0, 126 | softmax_scale=None, causal=True 127 | ) 128 | output = rearrange(pad_input(rearrange(output_unpad, 'nnz h d -> nnz (h d)'), 129 | indices, bsz, q_len), 130 | 'b s (h d) -> b s h d', h=nheads) 131 | return self.o_proj(rearrange(output, 132 | 'b s h d -> b s (h d)')), None, None 133 | 134 | 135 | # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask 136 | def _prepare_decoder_attention_mask(self, attention_mask, input_shape, 137 | inputs_embeds, past_key_values_length): 138 | # [bsz, seq_len] 139 | return attention_mask 140 | 141 | 142 | def replace_llama_attn_with_flash_attn(): 143 | transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = _prepare_decoder_attention_mask 144 | transformers.models.llama.modeling_llama.LlamaAttention = LlamaAttention 145 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch 2 | accelerate 3 | bitsandbytes 4 | datasets 5 | sentencepiece 6 | safetensors 7 | flash-attn 8 | triton 9 | colorama 10 | git+https://github.com/huggingface/transformers.git 11 | git+https://github.com/sterlind/GPTQ-for-LLaMa.git@lora_4bit 12 | git+https://github.com/sterlind/peft.git 13 | -------------------------------------------------------------------------------- /text-generation-webui/custom_monkey_patch.py: -------------------------------------------------------------------------------- 1 | import time 2 | import torch 3 | import autograd_4bit 4 | from autograd_4bit import load_llama_model_4bit_low_ram, Autograd4bitQuantLinear 5 | from peft import PeftModel 6 | from peft.tuners.lora import Linear4bitLt 7 | 8 | patch_encode_func = False 9 | 10 | def load_model_llama(*args, **kwargs): 11 | 12 | config_path = '../llama-13b-4bit/' 13 | model_path = '../llama-13b-4bit.pt' 14 | lora_path = '../alpaca13b_lora/' 15 | 16 | print("Loading {} ...".format(model_path)) 17 | t0 = time.time() 18 | 19 | model, tokenizer = load_llama_model_4bit_low_ram(config_path, model_path, groupsize=-1) 20 | 21 | model = PeftModel.from_pretrained(model, lora_path, device_map={'': 0}, torch_dtype=torch.float32) 22 | print('{} Lora Applied.'.format(lora_path)) 23 | 24 | print('Apply auto switch and half') 25 | for n, m in model.named_modules(): 26 | if isinstance(m, Autograd4bitQuantLinear) or isinstance(m, Linear4bitLt): 27 | if m.groupsize == -1: 28 | m.zeros = m.zeros.half() 29 | m.scales = m.scales.half() 30 | m.bias = m.bias.half() 31 | autograd_4bit.use_new = True 32 | autograd_4bit.auto_switch = True 33 | 34 | return model, tokenizer 35 | 36 | # Monkey Patch 37 | from modules import models 38 | from modules import shared 39 | models.load_model = load_model_llama 40 | shared.args.model = 'llama-13b-4bit' 
41 | shared.settings['name1'] = 'You' 42 | shared.settings['name2'] = 'Assistant' 43 | shared.settings['chat_prompt_size_max'] = 2048 44 | shared.settings['chat_prompt_size'] = 2048 45 | 46 | if patch_encode_func: 47 | from modules import text_generation 48 | text_generation.encode_old = text_generation.encode 49 | def encode_patched(*args, **kwargs): 50 | input_ids = text_generation.encode_old(*args, **kwargs) 51 | if input_ids[0,0] == 0: 52 | input_ids = input_ids[:, 1:] 53 | return input_ids 54 | text_generation.encode = encode_patched 55 | print('Encode Function Patched.') 56 | 57 | print('Monkey Patch Completed.') 58 | -------------------------------------------------------------------------------- /train_data.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from abc import ABC, abstractmethod 4 | from typing import Dict, Any 5 | from datasets import load_dataset, Dataset 6 | from torch.utils.data import DataLoader 7 | from transformers import DefaultDataCollator 8 | import os 9 | 10 | 11 | # Abstract train data loader 12 | class ATrainData(ABC): 13 | """ 14 | """ 15 | @abstractmethod 16 | def __init__(self, dataset: str, val_set_size: int, tokenizer, cutoff_len: int) -> None: 17 | """ 18 | Args: 19 | dataset (str): Path to dataset 20 | val_set_size (int) : Size of validation set 21 | tokenizer (_type_): Tokenizer 22 | """ 23 | self.tokenizer = tokenizer 24 | self.dataset = dataset 25 | self.val_set_size = val_set_size 26 | self.cutoff_len = cutoff_len 27 | self.train_data = None 28 | self.val_data = None 29 | 30 | @abstractmethod 31 | def tokenize(self, prompt: str) -> Dict[str, Any]: 32 | """Tokenization method 33 | 34 | Args: 35 | prompt (str): Prompt string from dataset 36 | 37 | Returns: 38 | Dict[str, Any]: token 39 | """ 40 | pass 41 | 42 | @abstractmethod 43 | def prepare_data(self) -> None: 44 | """Loads dataset from file and prepares train_data property for trainer 45 | """ 46 | pass 47 | 48 | 49 | # LLaMA txt train data loader 50 | class TrainTxt(ATrainData): 51 | def __init__(self, dataset: str, val_set_size: int, tokenizer, cutoff_len): 52 | super().__init__(dataset, val_set_size, tokenizer, cutoff_len) # TODO: Validation size isn't used 53 | self.cutoff_len = cutoff_len 54 | self.exceed_count = 0 55 | 56 | def tokenize(self, prompt: str, use_eos_token=True, **kwargs) -> Dict[str, Any]: 57 | # there's probably a way to do this with the tokenizer settings 58 | # but again, gotta move fast 59 | if use_eos_token: 60 | result = self.tokenizer( 61 | prompt + self.tokenizer.eos_token, 62 | truncation=True, 63 | max_length=self.cutoff_len, 64 | padding=False, 65 | ) 66 | d = { 67 | "input_ids": result["input_ids"], 68 | "attention_mask": result["attention_mask"], 69 | } 70 | if ( 71 | d["input_ids"][-1] != self.tokenizer.eos_token_id 72 | and len(d["input_ids"]) < self.cutoff_len 73 | ): 74 | d["input_ids"].append(self.tokenizer.eos_token_id) 75 | d["attention_mask"].append(1) 76 | else: 77 | result = self.tokenizer( 78 | prompt, 79 | truncation=True, 80 | max_length=self.cutoff_len + 1, 81 | padding="max_length", 82 | ) 83 | d = { 84 | "input_ids": result["input_ids"][:-1], 85 | "attention_mask": result["attention_mask"][:-1], 86 | } 87 | if sum(d['attention_mask']) >= self.cutoff_len: 88 | self.exceed_count += 1 89 | return d 90 | 91 | @classmethod 92 | def format_new_rows(cls, rows, thd=128): 93 | r_b = '' 94 | new_rows = [] 95 | for row in rows: 96 | if len(r_b) == 0: 97 | r_b += row 98 | else: 99 | r_b += '\n' + 
row 100 | if len(r_b) > thd: 101 | new_rows.append(r_b) 102 | r_b = '' 103 | if len(r_b) > thd: 104 | new_rows.append(r_b) 105 | r_b = '' 106 | return new_rows 107 | 108 | def prepare_data(self, thd=-1, use_eos_token=True, **kwargs): 109 | if os.path.isdir(self.dataset): 110 | rows = [] 111 | for filename in os.listdir(self.dataset): 112 | with open(self.dataset + filename, 'r', encoding='utf8') as file: 113 | txt = file.read() 114 | txt = txt.replace('\r\n', '\n').replace('\u3000', ' ') 115 | rows += [r for r in txt.split('\n') if r != ''] 116 | else: 117 | with open(self.dataset, 'r', encoding='utf8') as file: 118 | txt = file.read() 119 | txt = txt.replace('\r\n', '\n') 120 | rows = [r for r in txt.split('\n') if r != ''] 121 | if thd != -1: 122 | rows = self.format_new_rows(rows, thd=thd) 123 | data = Dataset.from_dict({"input": rows}) 124 | data = data.shuffle().map(lambda x: self.tokenize(x["input"], use_eos_token=use_eos_token)) 125 | print('Train Data: {:.2f}%'.format(self.exceed_count / len(data) * 100), 'outliers') 126 | self.train_data = data 127 | 128 | 129 | # Stanford Alpaca-like Data 130 | class TrainSAD(ATrainData): 131 | def __init__(self, dataset: str, val_set_size: int, tokenizer, cutoff_len) -> None: 132 | super().__init__(dataset, val_set_size, tokenizer, cutoff_len) 133 | 134 | def tokenize(self, prompt: str, use_eos_token=True, **kwargs) -> Dict[str, Any]: 135 | # there's probably a way to do this with the tokenizer settings 136 | # but again, gotta move fast 137 | if use_eos_token: 138 | result = self.tokenizer( 139 | prompt + self.tokenizer.eos_token, 140 | truncation=True, 141 | max_length=self.cutoff_len, 142 | padding=False, 143 | ) 144 | if ( 145 | result["input_ids"][-1] != self.tokenizer.eos_token_id 146 | and len(result["input_ids"]) < self.cutoff_len 147 | ): 148 | result["input_ids"].append(self.tokenizer.eos_token_id) 149 | result["attention_mask"].append(1) 150 | return result 151 | else: 152 | result = self.tokenizer( 153 | prompt, 154 | truncation=True, 155 | max_length=self.cutoff_len + 1, 156 | padding="max_length", 157 | ) 158 | return { 159 | "input_ids": result["input_ids"][:-1], 160 | "attention_mask": result["attention_mask"][:-1], 161 | } 162 | 163 | def prepare_data(self, use_eos_token=True, **kwargs) -> None: 164 | data = load_dataset("json", data_files=self.dataset) 165 | 166 | if self.val_set_size > 0: 167 | train_val = data["train"].train_test_split( 168 | test_size=self.val_set_size, shuffle=True, seed=42 # ! Seed = 42 (?) 169 | ) 170 | self.train_data = train_val["train"].shuffle().map(lambda x: self.generate_and_tokenize_prompt(x, use_eos_token=use_eos_token)) 171 | self.val_data = train_val["test"].shuffle().map(lambda x: self.generate_and_tokenize_prompt(x, use_eos_token=use_eos_token)) 172 | else: 173 | self.train_data = data["train"].shuffle().map(lambda x: self.generate_and_tokenize_prompt(x, use_eos_token=use_eos_token)) 174 | self.val_data = None 175 | 176 | # Auxiliary methods 177 | def generate_prompt(self, data_point, **kwargs): 178 | return "{0}\n\n{1}\n{2}\n\n{3}\n{4}\n\n{5}\n{6}".format( 179 | "Below is an instruction that describes a task, paired with an input that provides further context. 
Write a response that appropriately completes the request.", 180 | "### Instruction:", 181 | data_point["instruction"], 182 | "### Input:", 183 | data_point["input"], 184 | "### Response:", 185 | data_point["output"] 186 | ) 187 | 188 | def generate_and_tokenize_prompt(self, data_point, **kwargs): 189 | prompt = self.generate_prompt(data_point, **kwargs) 190 | return self.tokenize(prompt, **kwargs) 191 | 192 | # GPT4All-like Data 193 | class TrainGPT4All(ATrainData): 194 | def __init__(self, dataset: str, val_set_size: int, tokenizer, cutoff_len) -> None: 195 | super().__init__(dataset, val_set_size, tokenizer, cutoff_len) 196 | 197 | def tokenize(self, prompt: str, use_eos_token=True, **kwargs) -> Dict[str, Any]: 198 | pass 199 | 200 | def tokenize_inputs(self, examples): 201 | max_length = self.cutoff_len 202 | input_ids = torch.full((len(examples["prompt"]), max_length), self.tokenizer.pad_token_id) 203 | # ignore bos 204 | newline_tokens = self.tokenizer("\n", return_tensors="pt")["input_ids"][0, 1:] 205 | 206 | out = {"labels": [], "attention_mask": []} 207 | for i, (prompt, response) in enumerate(zip(examples["prompt"], examples["response"])): 208 | input_tokens = self.tokenizer(prompt, truncation=True, max_length=max_length // 2, return_tensors="pt")["input_ids"].squeeze() 209 | if input_tokens.dim() == 0: 210 | input_tokens = input_tokens.unsqueeze(0) 211 | 212 | input_len = len(input_tokens) 213 | 214 | # plus one since we remove bos from response 215 | # but we subtract one since we want to add eos token 216 | remaining_tokens = max_length - input_len - len(newline_tokens) + 1 217 | # remove bos 218 | target_tokens = self.tokenizer(response, truncation=True, max_length=remaining_tokens, return_tensors="pt")["input_ids"].squeeze()[1:] 219 | 220 | input_ids[i, :input_len] = input_tokens 221 | # add newline between prompt and response 222 | newline_plus_inputs = input_len + len(newline_tokens) 223 | input_ids[i, input_len: newline_plus_inputs] = newline_tokens 224 | 225 | # add target tokens, remove bos 226 | input_ids[i, newline_plus_inputs: newline_plus_inputs + len(target_tokens)] = target_tokens 227 | # add eos token, enforce stopping if we don't truncate 228 | # we don't want long code to stop generating if truncated during training 229 | if newline_plus_inputs + len(target_tokens) < max_length: 230 | input_ids[i, newline_plus_inputs + len(target_tokens)] = self.tokenizer.eos_token_id 231 | 232 | labels = input_ids[i].clone() 233 | labels[: newline_plus_inputs] = -100 234 | labels[labels == self.tokenizer.pad_token_id] = -100 235 | # to debug this, can set all values == -100 to the pad token, then assert that tokenizer.decode(labels, skip_special_tokens=True).strip() == response 236 | 237 | attention_mask = input_ids[i].ne(self.tokenizer.pad_token_id).int() 238 | 239 | out["labels"].append(labels) 240 | out["attention_mask"].append(attention_mask) 241 | 242 | out["input_ids"] = input_ids 243 | 244 | out = {k: torch.stack(v) if isinstance(v, list) else v for k, v in out.items()} 245 | 246 | return out 247 | 248 | def prepare_data(self, **kwargs) -> None: 249 | dataset = load_dataset("json", data_files=self.dataset) 250 | 251 | self.val_data = None 252 | if self.val_set_size > 0: 253 | dataset = dataset["train"].train_test_split( 254 | test_size=self.val_set_size, shuffle=True, seed=42 # ! Seed = 42 (?) 
255 | ) 256 | train_dataset, val_dataset = dataset["train"], dataset["test"] 257 | 258 | # tokenize inputs and return labels and attention mask 259 | val_dataset = val_dataset.map( 260 | lambda ele: self.tokenize_inputs(ele), 261 | batched=True, 262 | remove_columns=["source", "prompt"], 263 | ) 264 | self.val_data = val_dataset.with_format("torch") 265 | else: 266 | train_dataset = dataset["train"] 267 | 268 | train_dataset = train_dataset.map( 269 | lambda ele: self.tokenize_inputs(ele), 270 | batched=True, 271 | remove_columns=["source", "prompt"], 272 | ) 273 | self.train_data = train_dataset.with_format("torch") 274 | -------------------------------------------------------------------------------- /triton_utils.py: -------------------------------------------------------------------------------- 1 | import triton 2 | import triton.language as tl 3 | import torch 4 | import custom_autotune 5 | 6 | 7 | # code based https://github.com/fpgaminer/GPTQ-triton 8 | @custom_autotune.autotune( 9 | configs=[ 10 | triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), 11 | triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), 12 | triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), 13 | triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), 14 | triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), 15 | triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), 16 | # These provided a benefit on a 3090 17 | triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), 18 | triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), 19 | triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), 20 | triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), 21 | triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), 22 | triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), 23 | triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), 24 | ], 25 | key=['M', 'N'], 26 | nearest_power_of_two=True, 27 | ) 28 | 29 | 30 | @triton.jit 31 | def matmul_248_kernel(a_ptr, b_ptr, c_ptr, 32 | scales_ptr, zeros_ptr, g_ptr, 33 | M, N, K, bits, maxq, 34 | stride_am, stride_ak, 35 | stride_bk, stride_bn, 36 | stride_cm, stride_cn, 37 | stride_scales, stride_zeros, 38 | BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, 39 | GROUP_SIZE_M: tl.constexpr): 40 | """ 41 | Compute the matrix multiplication C = A x B. 
42 | A is of shape (M, K) float16 43 | B is of shape (K//8, N) int32 44 | C is of shape (M, N) float16 45 | scales is of shape (G, N) float16 46 | zeros is of shape (G, N) float16 47 | g_ptr is of shape (K) int32 48 | """ 49 | infearure_per_bits = 32 // bits 50 | 51 | pid = tl.program_id(axis=0) 52 | num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) 53 | num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) 54 | num_pid_k = tl.cdiv(K, BLOCK_SIZE_K) 55 | num_pid_in_group = GROUP_SIZE_M * num_pid_n 56 | group_id = pid // num_pid_in_group 57 | first_pid_m = group_id * GROUP_SIZE_M 58 | group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) 59 | pid_m = first_pid_m + (pid % group_size_m) 60 | pid_n = (pid % num_pid_in_group) // group_size_m 61 | 62 | offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) 63 | offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) 64 | offs_k = tl.arange(0, BLOCK_SIZE_K) 65 | a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) # (BLOCK_SIZE_M, BLOCK_SIZE_K) 66 | a_mask = (offs_am[:, None] < M) 67 | # b_ptrs is set up such that it repeats elements along the K axis 8 times 68 | b_ptrs = b_ptr + ((offs_k[:, None] // infearure_per_bits) * stride_bk + offs_bn[None, :] * stride_bn) # (BLOCK_SIZE_K, BLOCK_SIZE_N) 69 | g_ptrs = g_ptr + offs_k 70 | # shifter is used to extract the N bits of each element in the 32-bit word from B 71 | scales_ptrs = scales_ptr + offs_bn[None, :] 72 | zeros_ptrs = zeros_ptr + (offs_bn[None, :] // infearure_per_bits) 73 | 74 | shifter = (offs_k % infearure_per_bits) * bits 75 | zeros_shifter = (offs_bn % infearure_per_bits) * bits 76 | accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) 77 | 78 | for k in range(0, num_pid_k): 79 | g_idx = tl.load(g_ptrs) 80 | 81 | # Fetch scales and zeros; these are per-outfeature and thus reused in the inner loop 82 | scales = tl.load(scales_ptrs + g_idx[:, None] * stride_scales) # (BLOCK_SIZE_K, BLOCK_SIZE_N,) 83 | zeros = tl.load(zeros_ptrs + g_idx[:, None] * stride_zeros) # (BLOCK_SIZE_K, BLOCK_SIZE_N,) 84 | 85 | zeros = (zeros >> zeros_shifter[None, :]) & maxq 86 | zeros = (zeros + 1) 87 | 88 | a = tl.load(a_ptrs, mask=a_mask, other=0.0) # (BLOCK_SIZE_M, BLOCK_SIZE_K) 89 | b = tl.load(b_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated 90 | 91 | # Now we need to unpack b (which is N-bit values) into 32-bit values 92 | b = (b >> shifter[:, None]) & maxq # Extract the N-bit values 93 | b = (b - zeros) * scales # Scale and shift 94 | # ! 
Convert to fp16 95 | b = b.to(tl.float16) 96 | a = a.to(tl.float16) 97 | 98 | accumulator += tl.dot(a, b) 99 | a_ptrs += BLOCK_SIZE_K 100 | b_ptrs += (BLOCK_SIZE_K // infearure_per_bits) * stride_bk 101 | g_ptrs += BLOCK_SIZE_K 102 | 103 | c = accumulator.to(tl.float16) 104 | c_ptrs = c_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bn[None, :] 105 | c_mask = (offs_am[:, None] < M) & (offs_bn[None, :] < N) 106 | tl.store(c_ptrs, c, mask=c_mask) 107 | 108 | 109 | # code based https://github.com/fpgaminer/GPTQ-triton 110 | @custom_autotune.autotune( 111 | configs=[ 112 | triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_K': 64, 'BLOCK_SIZE_N': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), 113 | triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 256, 'BLOCK_SIZE_N': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), 114 | triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_K': 128, 'BLOCK_SIZE_N': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), 115 | triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_K': 64, 'BLOCK_SIZE_N': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), 116 | triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 128, 'BLOCK_SIZE_N': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), 117 | triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_K': 32, 'BLOCK_SIZE_N': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), 118 | # These provided a benefit on a 3090 119 | triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 64, 'BLOCK_SIZE_N': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), 120 | triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_K': 64, 'BLOCK_SIZE_N': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), 121 | triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 32, 'BLOCK_SIZE_N': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), 122 | triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_K': 64, 'BLOCK_SIZE_N': 64, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), 123 | triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 64, 'BLOCK_SIZE_N': 64, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), 124 | triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 32, 'BLOCK_SIZE_N': 64, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), 125 | triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_K': 64, 'BLOCK_SIZE_N': 128, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), 126 | ], 127 | key=['M', 'K'], 128 | nearest_power_of_two=True, 129 | ) 130 | 131 | 132 | @triton.jit 133 | def trans_matmul_248_kernel(a_ptr, b_ptr, c_ptr, 134 | scales_ptr, zeros_ptr, g_ptr, 135 | M, N, K, bits, maxq, 136 | stride_am, stride_ak, 137 | stride_bk, stride_bn, 138 | stride_cm, stride_cn, 139 | stride_scales, stride_zeros, 140 | BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, 141 | GROUP_SIZE_M: tl.constexpr): 142 | """ 143 | Compute the matrix multiplication C = A x B. 
144 | A is of shape (M, N) float16 145 | B is of shape (K//8, N) int32 146 | C is of shape (M, K) float16 147 | scales is of shape (G, N) float16 148 | zeros is of shape (G, N) float16 149 | g_ptr is of shape (K) int32 150 | """ 151 | infearure_per_bits = 32 // bits 152 | 153 | pid = tl.program_id(axis=0) 154 | num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) 155 | num_pid_k = tl.cdiv(K, BLOCK_SIZE_K) 156 | num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) 157 | num_pid_in_group = GROUP_SIZE_M * num_pid_k 158 | group_id = pid // num_pid_in_group 159 | first_pid_m = group_id * GROUP_SIZE_M 160 | group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) 161 | pid_m = first_pid_m + (pid % group_size_m) 162 | pid_k = (pid % num_pid_in_group) // group_size_m 163 | 164 | offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) 165 | offs_bk = pid_k * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) 166 | offs_n = tl.arange(0, BLOCK_SIZE_N) 167 | a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_n[None, :] * stride_ak) # (BLOCK_SIZE_M, BLOCK_SIZE_N) 168 | a_mask = (offs_am[:, None] < M) 169 | # b_ptrs is set up such that it repeats elements along the K axis 8 times 170 | b_ptrs = b_ptr + ((offs_bk[:, None] // infearure_per_bits) * stride_bk + offs_n[None, :] * stride_bn) # (BLOCK_SIZE_K, BLOCK_SIZE_N) 171 | g_ptrs = g_ptr + offs_bk 172 | g_idx = tl.load(g_ptrs) 173 | 174 | # shifter is used to extract the N bits of each element in the 32-bit word from B 175 | scales_ptrs = scales_ptr + offs_n[None, :] + g_idx[:, None] * stride_scales 176 | zeros_ptrs = zeros_ptr + (offs_n[None, :] // infearure_per_bits) + g_idx[:, None] * stride_zeros 177 | 178 | shifter = (offs_bk % infearure_per_bits) * bits 179 | zeros_shifter = (offs_n % infearure_per_bits) * bits 180 | accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_K), dtype=tl.float32) 181 | 182 | for k in range(0, num_pid_n): 183 | # Fetch scales and zeros; these are per-outfeature and thus reused in the inner loop 184 | scales = tl.load(scales_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N,) 185 | zeros = tl.load(zeros_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N,) 186 | 187 | zeros = (zeros >> zeros_shifter[None, :]) & maxq 188 | zeros = (zeros + 1) 189 | 190 | a = tl.load(a_ptrs, mask=a_mask, other=0.) # (BLOCK_SIZE_M, BLOCK_SIZE_N) 191 | b = tl.load(b_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated 192 | 193 | # Now we need to unpack b (which is N-bit values) into 32-bit values 194 | b = (b >> shifter[:, None]) & maxq # Extract the N-bit values 195 | b = (b - zeros) * scales # Scale and shift 196 | b = tl.trans(b) 197 | # ! 
Convert to fp16 198 | b = b.to(tl.float16) 199 | a = a.to(tl.float16) 200 | 201 | accumulator += tl.dot(a, b) 202 | a_ptrs += BLOCK_SIZE_N 203 | b_ptrs += BLOCK_SIZE_N 204 | scales_ptrs += BLOCK_SIZE_N 205 | zeros_ptrs += (BLOCK_SIZE_N // infearure_per_bits) 206 | 207 | c = accumulator.to(tl.float16) 208 | c_ptrs = c_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bk[None, :] 209 | c_mask = (offs_am[:, None] < M) & (offs_bk[None, :] < K) 210 | tl.store(c_ptrs, c, mask=c_mask) 211 | 212 | 213 | def triton_matmul(input, qweight, scales, qzeros, g_idx, bits, maxq): 214 | assert input.shape[-1] == qweight.shape[0] * 32 // bits 215 | outshape = input.shape[:-1] + (qweight.shape[1],) 216 | input = input.reshape(-1, input.shape[-1]) 217 | output = torch.empty((input.shape[0], qweight.shape[1]), device=scales.device, dtype=torch.float16) 218 | grid = lambda META: (triton.cdiv(input.shape[0], META['BLOCK_SIZE_M']) * triton.cdiv(qweight.shape[1], META['BLOCK_SIZE_N']),) 219 | matmul_248_kernel[grid](input, qweight, output, 220 | scales, qzeros, g_idx, 221 | input.shape[0], qweight.shape[1], input.shape[1], bits, maxq, 222 | input.stride(0), input.stride(1), 223 | qweight.stride(0), qweight.stride(1), 224 | output.stride(0), output.stride(1), 225 | scales.stride(0), qzeros.stride(0)) 226 | output = output.reshape(outshape) 227 | return output 228 | 229 | 230 | def triton_matmul_transpose(input, qweight, scales, qzeros, g_idx, bits, maxq): 231 | assert input.shape[-1] == qweight.shape[1] 232 | out_dim = qweight.shape[0] * 32 // bits 233 | outshape = input.shape[:-1] + (out_dim,) 234 | input = input.reshape(-1, input.shape[-1]) 235 | output_shape_mid = (input.shape[0], out_dim) 236 | output = torch.empty((output_shape_mid[0], output_shape_mid[1]), device=scales.device, dtype=torch.float16) 237 | grid = lambda META: (triton.cdiv(input.shape[0], META['BLOCK_SIZE_M']) * triton.cdiv(output_shape_mid[1], META['BLOCK_SIZE_K']),) 238 | trans_matmul_248_kernel[grid](input, qweight, output, 239 | scales, qzeros, g_idx, 240 | input.shape[0], qweight.shape[1], output_shape_mid[1], bits, maxq, 241 | input.stride(0), input.stride(1), 242 | qweight.stride(0), qweight.stride(1), 243 | output.stride(0), output.stride(1), 244 | scales.stride(0), qzeros.stride(0)) 245 | output = output.reshape(outshape) 246 | return output 247 | --------------------------------------------------------------------------------
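As a rough illustration of how the Triton path in triton_utils.py is driven, below is a minimal, hypothetical smoke test for triton_matmul; it is not part of the repository. It assumes a CUDA device and the tensor layout documented in the matmul_248_kernel docstring (A: (M, K) float16; qweight: (K//8, N) int32; scales: (G, N) float16; qzeros: (G, N//8) packed int32; g_idx: (K,) int32). The sizes and groupsize used here are arbitrary assumptions chosen only so the shapes line up.

import torch
from triton_utils import triton_matmul

# Hypothetical shapes for illustration only; any K divisible by 8 and by the
# chosen groupsize keeps the packed layout consistent.
bits = 4
maxq = 2 ** bits - 1
M, K, N, groupsize = 8, 128, 256, 32
G = K // groupsize  # number of quantization groups

x = torch.randn(M, K, dtype=torch.float16, device='cuda')
qweight = torch.randint(0, 2 ** 31 - 1, (K // 8, N), dtype=torch.int32, device='cuda')
scales = torch.rand(G, N, dtype=torch.float16, device='cuda')
qzeros = torch.randint(0, 2 ** 31 - 1, (G, N // 8), dtype=torch.int32, device='cuda')
g_idx = (torch.arange(K, device='cuda') // groupsize).to(torch.int32)

# With random packed weights the result is numerically meaningless; this only
# exercises the kernel launch and checks the output shape.
y = triton_matmul(x, qweight, scales, qzeros, g_idx, bits, maxq)
print(y.shape)  # expected: torch.Size([8, 256])

triton_matmul_transpose accepts the same quantized operands but, per the trans_matmul_248_kernel docstring, takes an (M, N) input and returns an (M, K) output, i.e. a product against the transposed dequantized weight.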