├── llamatuner
│   ├── __init__.py
│   ├── data
│   │   ├── __init__.py
│   │   ├── dataset_factory
│   │   │   ├── __init__.py
│   │   │   ├── dataset_utils.py
│   │   │   ├── pt_dataset.py
│   │   │   └── reward_dataset.py
│   │   ├── processors
│   │   │   └── supervised.py
│   │   ├── data_collator.py
│   │   └── utils.py
│   ├── model
│   │   ├── __init__.py
│   │   └── callbacks
│   │       ├── __init__.py
│   │       ├── wandb_callback.py
│   │       ├── save_peft_model_callback.py
│   │       ├── metrics.py
│   │       └── perplexity.py
│   ├── train
│   │   ├── __init__.py
│   │   ├── pt
│   │   │   ├── __init__.py
│   │   │   └── train_pt.py
│   │   ├── rm
│   │   │   ├── __init__.py
│   │   │   └── trainer.py
│   │   ├── sft
│   │   │   └── __init__.py
│   │   ├── tuner.py
│   │   └── apply_lora.py
│   ├── utils
│   │   ├── __init__.py
│   │   ├── packages.py
│   │   ├── env_utils.py
│   │   ├── stream_server.py
│   │   ├── misc.py
│   │   └── logger_utils.py
│   ├── launcher.py
│   ├── configs
│   │   ├── __init__.py
│   │   ├── eval_args.py
│   │   ├── generating_args.py
│   │   ├── model_args.py
│   │   └── data_args.py
│   └── cli.py
├── assets
│   └── wechat.jpg
├── .flake8
├── examples
│   ├── merge_lora
│   │   ├── llama3_gptq.yaml
│   │   └── llama3_lora_sft.yaml
│   ├── sft_full_train
│   │   ├── llama3_full_predict.yaml
│   │   ├── qwen_full_sft.yaml
│   │   ├── qwen_full_sft_ds.yaml
│   │   └── llama3_full_sft_ds3.yaml
│   ├── deepspeed
│   │   ├── ds_z0_config.json
│   │   ├── ds_z2_config.json
│   │   ├── ds_z2_offload_config.json
│   │   ├── ds_z3_config.json
│   │   ├── ds_z3_offload_config.json
│   │   ├── ds_zero3_offload.json
│   │   └── ds_zero3_auto.json
│   ├── sft_qlora_train
│   │   ├── qwen_lora_sft_bitsandbytes.yaml
│   │   └── llama3_lora_sft_bitsandbytes.yaml
│   ├── pre-train
│   │   └── qwen_full_pre-train.yaml
│   └── sft_lora_train
│       ├── llama3_lora_sft.yaml
│       ├── qwen_lora_sft.yaml
│       └── llama3_lora_sft_ds.yaml
├── data
│   ├── format_data
│   │   ├── clean_sharegpt
│   │   │   ├── clean_evol_instruct.py
│   │   │   ├── merge.py
│   │   │   ├── split_long_conversation.py
│   │   │   ├── clean_sharegpt.py
│   │   │   └── hardcoded_questions.py
│   │   ├── test_dataloader.py
│   │   ├── merge.py
│   │   ├── convert_oasst1.py
│   │   ├── convert_vicuna.py
│   │   └── convert_alpaca.py
│   ├── ultra_chat
│   │   └── ultra_chat.py
│   ├── dataset_info.yaml
│   └── hh_rlhf_en
│       └── hh_rlhf_en.py
├── requirements.txt
├── scripts
│   ├── hf_cli.sh
│   ├── full_finetune
│   │   ├── full-finetune_ds.sh
│   │   └── full-finetune.sh
│   ├── lora_finetune
│   │   ├── lora-finetune.sh
│   │   └── lora-finetune_ds.sh
│   ├── qlora_finetune
│   │   ├── qlora-finetune.sh
│   │   ├── finetune_baichuan_7b_vicuna_zh.sh
│   │   ├── finetune_llama2_7b_alpaca_zh.sh
│   │   ├── finetune_llama_7b_alpaca_zh.sh
│   │   └── finetune_baichuan_7b_alpaca_zh.sh
│   ├── pre_train
│   │   └── full-finetune.sh
│   ├── ds_config
│   │   ├── default_offload_opt_param.json
│   │   └── ds_config_zero3_auto.json
│   └── hf_download.sh
├── .pre-commit-config.yaml
├── server
│   ├── multi_chat.py
│   ├── single_chat.py
│   ├── gradio_base_webserver.py
│   ├── gradio_webserver.py
│   └── gradio_qlora_webserver.py
└── .gitignore
/llamatuner/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/llamatuner/data/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/llamatuner/model/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/llamatuner/train/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/llamatuner/utils/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/llamatuner/train/pt/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/llamatuner/train/rm/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/llamatuner/train/sft/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/llamatuner/data/dataset_factory/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/llamatuner/data/processors/supervised.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/assets/wechat.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jianzhnie/LLamaTuner/HEAD/assets/wechat.jpg
--------------------------------------------------------------------------------
/llamatuner/launcher.py:
--------------------------------------------------------------------------------
1 | from llamatuner.train.tuner import run_exp
2 |
3 |
4 | def launch():
5 | run_exp()
6 |
7 |
8 | if __name__ == '__main__':
9 | launch()
10 |
--------------------------------------------------------------------------------
/.flake8:
--------------------------------------------------------------------------------
1 | [flake8]
2 | ignore = E501, E701, W504, W503, E722, E251, E402, E123, E126, E129, E125, E121
3 | max-line-length = 79
4 | show-source = False
5 | application-import-names = llamatuner
6 | exclude =
7 | .git
8 | docs
9 | venv
10 | build
11 | env
12 | log
13 | dist
14 | *.egg-info
15 | __pycache__
16 |
--------------------------------------------------------------------------------
/examples/merge_lora/llama3_gptq.yaml:
--------------------------------------------------------------------------------
1 | ### model
2 | model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
3 | template: llama3
4 |
5 | ### export
6 | export_dir: models/llama3_gptq
7 | export_quantization_bit: 4
8 | export_quantization_dataset: data/c4_demo.json
9 | export_size: 2
10 | export_device: cpu
11 | export_legacy_format: false
12 |
--------------------------------------------------------------------------------
/examples/merge_lora/llama3_lora_sft.yaml:
--------------------------------------------------------------------------------
1 | ### Note: DO NOT use quantized model or quantization_bit when merging lora adapters
2 |
3 | ### model
4 | model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
5 | adapter_name_or_path: saves/llama3-8b/lora/sft
6 | template: llama3
7 | finetuning_type: lora
8 |
9 | ### export
10 | export_dir: models/llama3_lora_sft
11 | export_size: 2
12 | export_device: cpu
13 | export_legacy_format: false
14 |
--------------------------------------------------------------------------------
/llamatuner/model/callbacks/__init__.py:
--------------------------------------------------------------------------------
1 | from .metrics import ComputeMetrics
2 | from .perplexity import ComputePerplexity
3 | from .save_peft_model_callback import SavePeftModelCallback
4 | from .wandb_callback import WandbCallback
5 |
6 | __all__ = [
7 | 'ComputeMetrics',
8 | 'ComputePerplexity',
9 |     'SavePeftModelCallback',
10 |     'WandbCallback',
11 | ]
12 |
--------------------------------------------------------------------------------
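
A usage sketch (not from the repository files above): the exported callbacks plug into a stock transformers Trainer through its `callbacks` argument. The model name is only an example, and the assumption that both callbacks can be constructed without arguments should be checked against their definitions under llamatuner/model/callbacks/.

    from transformers import AutoModelForCausalLM, Trainer, TrainingArguments

    from llamatuner.model.callbacks import SavePeftModelCallback, WandbCallback

    model = AutoModelForCausalLM.from_pretrained('Qwen/Qwen1.5-0.5B')
    args = TrainingArguments(output_dir='work_dir/demo',
                             per_device_train_batch_size=1)

    # `callbacks=` is the standard transformers hook; SavePeftModelCallback
    # presumably persists only the PEFT adapter weights at checkpoints, while
    # WandbCallback mirrors logs to Weights & Biases. A train_dataset is still
    # needed before calling trainer.train().
    trainer = Trainer(model=model,
                      args=args,
                      callbacks=[SavePeftModelCallback(), WandbCallback()])
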
/data/format_data/clean_sharegpt/clean_evol_instruct.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | from clean_sharegpt import get_clean_data, json_dump
4 |
5 | if __name__ == '__main__':
6 | parser = argparse.ArgumentParser()
7 | parser.add_argument('--in-file', type=str)
8 | parser.add_argument('--out-file', type=str)
9 | args = parser.parse_args()
10 |
11 | clean_data2 = get_clean_data(args)
12 | json_dump(clean_data2, args.out_file)
13 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | accelerate @ git+https://github.com/huggingface/accelerate.git
2 | bitsandbytes
3 | colorama
4 | datasets
5 | deepspeed
6 | einops
7 | evaluate>=0.4.0
8 | gradio
9 | jieba
10 | nltk>=3.8.1
11 | numpy
12 | peft
13 | rouge-chinese
14 | sentencepiece
15 | tokenizers
16 | torch
17 | transformers>=4.28.0
18 | wandb
19 |
--------------------------------------------------------------------------------
/scripts/hf_cli.sh:
--------------------------------------------------------------------------------
1 |
2 | # download dataset
3 | huggingface-cli download --repo-type dataset tatsu-lab/alpaca --local-dir /home/robin/hf_hub/datasets/alpaca-en
4 |
5 | # download model
6 | huggingface-cli download Qwen/Qwen1.5-0.5B --local-dir /home/robin/hf_hub/models/Qwen/Qwen1.5-0.5B
7 | huggingface-cli download Qwen/Qwen1.5-1.8B --local-dir /home/robin/hf_hub/models/Qwen/Qwen1.5-1.8B
8 | huggingface-cli download Qwen/Qwen1.5-7B --local-dir /home/robin/hf_hub/models/Qwen/Qwen1.5-7B
9 |
--------------------------------------------------------------------------------
/examples/sft_full_train/llama3_full_predict.yaml:
--------------------------------------------------------------------------------
1 | ### model
2 | model_name_or_path: work_dir/llama3-8b/full/sft/
3 |
4 | ### method
5 | stage: sft
6 | do_predict: true
7 | finetuning_type: full
8 |
9 | ### dataset
10 | dataset: alpaca
11 | template: llama3
12 | cutoff_len: 1024
13 | max_samples: 50
14 | overwrite_cache: true
15 | preprocessing_num_workers: 16
16 |
17 | ### output
18 | output_dir: work_dir/llama3-8b/full/sft/predict
19 | overwrite_output_dir: true
20 |
21 | ### eval
22 | per_device_eval_batch_size: 1
23 | predict_with_generate: true
24 |
--------------------------------------------------------------------------------
/llamatuner/configs/__init__.py:
--------------------------------------------------------------------------------
1 | from llamatuner.configs.data_args import DataArguments
2 | from llamatuner.configs.eval_args import EvaluationArguments
3 | from llamatuner.configs.finetuning_args import (FinetuningArguments,
4 | FreezeArguments, LoraArguments,
5 | QuantArguments, RLHFArguments)
6 | from llamatuner.configs.generating_args import GeneratingArguments
7 | from llamatuner.configs.model_args import ModelArguments
8 |
9 | __all__ = [
10 | 'DataArguments',
11 | 'GeneratingArguments',
12 | 'ModelArguments',
13 | 'QuantArguments',
14 | 'FinetuningArguments',
15 | 'EvaluationArguments',
16 | 'FreezeArguments',
17 | 'LoraArguments',
18 | 'RLHFArguments',
19 | ]
20 |
--------------------------------------------------------------------------------
/llamatuner/data/dataset_factory/dataset_utils.py:
--------------------------------------------------------------------------------
1 | from typing import Sequence
2 |
3 | from numpy.typing import NDArray
4 | from PIL import Image
5 | from PIL.Image import Image as ImageObject
6 | from transformers import ProcessorMixin
7 | from transformers.image_processing_utils import BaseImageProcessor
8 |
9 |
10 | def _preprocess_visual_inputs(images: Sequence[ImageObject],
11 | processor: ProcessorMixin) -> 'NDArray':
12 | # process visual inputs (currently only supports a single image)
13 | image_processor: BaseImageProcessor = getattr(processor, 'image_processor')
14 | image = (images[0] if len(images) != 0 else Image.new(
15 | 'RGB', (100, 100), (255, 255, 255)))
16 | return image_processor(image, return_tensors='pt')['pixel_values'][0]
17 |
--------------------------------------------------------------------------------
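
A minimal call sketch for the helper above. The LLaVA checkpoint name is only an example of a processor that exposes an `image_processor` attribute; any multimodal AutoProcessor with that attribute fits, and the image path simply reuses the file shipped under assets/.

    from PIL import Image
    from transformers import AutoProcessor

    from llamatuner.data.dataset_factory.dataset_utils import _preprocess_visual_inputs

    # Example multimodal processor (assumption); it carries an `image_processor`.
    processor = AutoProcessor.from_pretrained('llava-hf/llava-1.5-7b-hf')

    # A single image is processed; with an empty list the helper falls back to a
    # blank 100x100 RGB placeholder.
    pixel_values = _preprocess_visual_inputs([Image.open('assets/wechat.jpg')], processor)
    print(pixel_values.shape)
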
/examples/deepspeed/ds_z0_config.json:
--------------------------------------------------------------------------------
1 | {
2 | "train_batch_size": "auto",
3 | "train_micro_batch_size_per_gpu": "auto",
4 | "gradient_accumulation_steps": "auto",
5 | "gradient_clipping": "auto",
6 | "zero_allow_untested_optimizer": true,
7 | "fp16": {
8 | "enabled": "auto",
9 | "loss_scale": 0,
10 | "loss_scale_window": 1000,
11 | "initial_scale_power": 16,
12 | "hysteresis": 2,
13 | "min_loss_scale": 1
14 | },
15 | "bf16": {
16 | "enabled": "auto"
17 | },
18 | "zero_optimization": {
19 | "stage": 0,
20 | "allgather_partitions": true,
21 | "allgather_bucket_size": 5e8,
22 | "overlap_comm": true,
23 | "reduce_scatter": true,
24 | "reduce_bucket_size": 5e8,
25 | "contiguous_gradients": true,
26 | "round_robin_gradients": true
27 | }
28 | }
29 |
--------------------------------------------------------------------------------
/examples/deepspeed/ds_z2_config.json:
--------------------------------------------------------------------------------
1 | {
2 | "train_batch_size": "auto",
3 | "train_micro_batch_size_per_gpu": "auto",
4 | "gradient_accumulation_steps": "auto",
5 | "gradient_clipping": "auto",
6 | "zero_allow_untested_optimizer": true,
7 | "fp16": {
8 | "enabled": "auto",
9 | "loss_scale": 0,
10 | "loss_scale_window": 1000,
11 | "initial_scale_power": 16,
12 | "hysteresis": 2,
13 | "min_loss_scale": 1
14 | },
15 | "bf16": {
16 | "enabled": "auto"
17 | },
18 | "zero_optimization": {
19 | "stage": 2,
20 | "allgather_partitions": true,
21 | "allgather_bucket_size": 5e8,
22 | "overlap_comm": true,
23 | "reduce_scatter": true,
24 | "reduce_bucket_size": 5e8,
25 | "contiguous_gradients": true,
26 | "round_robin_gradients": true
27 | }
28 | }
29 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | repos:
2 | - repo: https://gitee.com/openmmlab/mirrors-flake8
3 | rev: 5.0.4
4 | hooks:
5 | - id: flake8
6 | - repo: https://gitee.com/openmmlab/mirrors-isort
7 | rev: 5.11.5
8 | hooks:
9 | - id: isort
10 | - repo: https://gitee.com/openmmlab/mirrors-yapf
11 | rev: v0.32.0
12 | hooks:
13 | - id: yapf
14 | - repo: https://gitee.com/openmmlab/mirrors-pre-commit-hooks
15 | rev: v4.3.0
16 | hooks:
17 | - id: trailing-whitespace
18 | - id: check-yaml
19 | - id: end-of-file-fixer
20 | - id: requirements-txt-fixer
21 | - id: double-quote-string-fixer
22 | - id: check-merge-conflict
23 | - id: fix-encoding-pragma
24 | args: ["--remove"]
25 | - id: mixed-line-ending
26 | args: ["--fix=lf"]
27 |
--------------------------------------------------------------------------------
/scripts/full_finetune/full-finetune_ds.sh:
--------------------------------------------------------------------------------
1 | CUDA_VISIBLE_DEVICES=0 torchrun --nproc_per_node=1 llamatuner/train/sft/train.py \
2 | --model_name_or_path facebook/opt-125m \
3 | --dataset_cfg ./data/run_test.yaml \
4 | --output_dir work_dir/full-finetune \
5 | --num_train_epochs 3 \
6 | --max_train_samples 200 \
7 | --per_device_train_batch_size 4 \
8 | --per_device_eval_batch_size 4 \
9 | --gradient_accumulation_steps 8 \
10 | --evaluation_strategy "steps" \
11 | --save_strategy "steps" \
12 | --eval_steps 1000 \
13 | --save_steps 1000 \
14 | --save_total_limit 5 \
15 | --logging_steps 1 \
16 | --learning_rate 2e-5 \
17 | --weight_decay 0. \
18 | --warmup_ratio 0.03 \
19 | --lr_scheduler_type "cosine" \
20 |     --deepspeed "scripts/ds_config/ds_config_zero3_auto.json"
21 |
--------------------------------------------------------------------------------
/data/format_data/test_dataloader.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 |
4 | sys.path.append(os.getcwd())
5 | from transformers import AutoTokenizer, HfArgumentParser
6 |
7 | from llamatuner.configs import (DataArguments, FinetuningArguments,
8 | ModelArguments)
9 | from llamatuner.data.data_loader import load_single_dataset
10 | from llamatuner.data.data_parser import get_dataset_list
11 |
12 | if __name__ == '__main__':
13 | parser = HfArgumentParser(
14 | (ModelArguments, DataArguments, FinetuningArguments))
15 | (model_args, data_args,
16 | training_args) = parser.parse_args_into_dataclasses()
17 | tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path)
18 | dataset_list = get_dataset_list(data_args)
19 | dataset = load_single_dataset(dataset_list[0], model_args, data_args)
20 | print(dataset[:2])
21 |
--------------------------------------------------------------------------------
/data/format_data/clean_sharegpt/merge.py:
--------------------------------------------------------------------------------
1 | """Merge multiple conversation files into one.
2 |
3 | Usage: python3 merge.py --in-file file1.json file2.json --out-file merged.json
4 | """
5 |
6 | import argparse
7 |
8 | from clean_sharegpt import json_dump, json_load
9 |
10 | if __name__ == '__main__':
11 | parser = argparse.ArgumentParser()
12 | parser.add_argument('--in-file', type=str, required=True, nargs='+')
13 | parser.add_argument('--out-file', type=str, default='merged.json')
14 | args = parser.parse_args()
15 |
16 | new_content = []
17 | for in_file in args.in_file:
18 | content = json_load(in_file)
19 | print(f'in-file: {in_file}, len: {len(content)}')
20 | new_content.extend(content)
21 |
22 | print(f'#out: {len(new_content)}')
23 | print(f'Save new_content to {args.out_file}')
24 | json_dump(new_content, args.out_file)
25 |
--------------------------------------------------------------------------------
/examples/deepspeed/ds_z2_offload_config.json:
--------------------------------------------------------------------------------
1 | {
2 | "train_batch_size": "auto",
3 | "train_micro_batch_size_per_gpu": "auto",
4 | "gradient_accumulation_steps": "auto",
5 | "gradient_clipping": "auto",
6 | "zero_allow_untested_optimizer": true,
7 | "fp16": {
8 | "enabled": "auto",
9 | "loss_scale": 0,
10 | "loss_scale_window": 1000,
11 | "initial_scale_power": 16,
12 | "hysteresis": 2,
13 | "min_loss_scale": 1
14 | },
15 | "bf16": {
16 | "enabled": "auto"
17 | },
18 | "zero_optimization": {
19 | "stage": 2,
20 | "offload_optimizer": {
21 | "device": "cpu",
22 | "pin_memory": true
23 | },
24 | "allgather_partitions": true,
25 | "allgather_bucket_size": 5e8,
26 | "overlap_comm": true,
27 | "reduce_scatter": true,
28 | "reduce_bucket_size": 5e8,
29 | "contiguous_gradients": true,
30 | "round_robin_gradients": true
31 | }
32 | }
33 |
--------------------------------------------------------------------------------
/scripts/lora_finetune/lora-finetune.sh:
--------------------------------------------------------------------------------
1 | python llamatuner/train/sft/train_lora.py \
2 | --model_name_or_path facebook/opt-125m \
3 | --dataset alpaca \
4 | --output_dir work_dir/lora-finetune \
5 | --wandb_project llamatuner \
6 | --wandb_run_name alpaca_opt-125m_lora-finetune \
7 | --num_train_epochs 3 \
8 | --per_device_train_batch_size 4 \
9 | --per_device_eval_batch_size 4 \
10 | --gradient_accumulation_steps 8 \
11 | --eval_strategy "steps" \
12 | --save_strategy "steps" \
13 | --eval_steps 100 \
14 | --save_steps 500 \
15 | --save_total_limit 5 \
16 | --logging_steps 1 \
17 | --learning_rate 1e-4 \
18 | --weight_decay 0. \
19 | --warmup_ratio 0.03 \
20 | --optim "adamw_torch" \
21 | --lr_scheduler_type "cosine" \
22 | --gradient_checkpointing True \
23 | --trust_remote_code \
24 | --model_max_length 128 \
25 | --do_train \
26 | --do_eval
27 |
--------------------------------------------------------------------------------
/scripts/lora_finetune/lora-finetune_ds.sh:
--------------------------------------------------------------------------------
1 | CUDA_VISIBLE_DEVICES=0 torchrun --nproc_per_node=1 llamatuner/train/sft/train_lora.py \
2 | --model_name_or_path facebook/opt-125m \
3 | --data_path ~/prompt_data/InstructionWild/instinwild_en.json \
4 | --output_dir work_dir/alpaca_full-finetune \
5 | --num_train_epochs 3 \
6 | --per_device_train_batch_size 4 \
7 | --per_device_eval_batch_size 4 \
8 | --gradient_accumulation_steps 8 \
9 | --evaluation_strategy "no" \
10 | --save_strategy "steps" \
11 | --save_steps 500 \
12 | --save_total_limit 5 \
13 | --learning_rate 2e-5 \
14 | --weight_decay 0. \
15 | --warmup_ratio 0.03 \
16 | --optim "adamw_torch" \
17 | --lr_scheduler_type "cosine" \
18 | --model_max_length 2048 \
19 | --logging_steps 1 \
20 | --do_train \
21 | --do_eval \
22 | --gradient_checkpointing True \
23 | --deepspeed "scripts/ds_config/ds_config_zero3_auto.json"
24 |
--------------------------------------------------------------------------------
/examples/deepspeed/ds_z3_config.json:
--------------------------------------------------------------------------------
1 | {
2 | "train_batch_size": "auto",
3 | "train_micro_batch_size_per_gpu": "auto",
4 | "gradient_accumulation_steps": "auto",
5 | "gradient_clipping": "auto",
6 | "zero_allow_untested_optimizer": true,
7 | "fp16": {
8 | "enabled": "auto",
9 | "loss_scale": 0,
10 | "loss_scale_window": 1000,
11 | "initial_scale_power": 16,
12 | "hysteresis": 2,
13 | "min_loss_scale": 1
14 | },
15 | "bf16": {
16 | "enabled": "auto"
17 | },
18 | "zero_optimization": {
19 | "stage": 3,
20 | "overlap_comm": true,
21 | "contiguous_gradients": true,
22 | "sub_group_size": 1e9,
23 | "reduce_bucket_size": "auto",
24 | "stage3_prefetch_bucket_size": "auto",
25 | "stage3_param_persistence_threshold": "auto",
26 | "stage3_max_live_parameters": 1e9,
27 | "stage3_max_reuse_distance": 1e9,
28 | "stage3_gather_16bit_weights_on_model_save": true
29 | }
30 | }
31 |
--------------------------------------------------------------------------------
/scripts/qlora_finetune/qlora-finetune.sh:
--------------------------------------------------------------------------------
1 | python llamatuner/train/sft/train_lora.py \
2 | --model_name_or_path facebook/opt-125m \
3 | --dataset alpaca \
4 | --use_qlora \
5 | --output_dir work_dir/lora-finetune \
6 | --wandb_project llamatuner \
7 | --wandb_run_name alpaca_opt-125m_lora-finetune \
8 | --num_train_epochs 3 \
9 | --per_device_train_batch_size 4 \
10 | --per_device_eval_batch_size 4 \
11 | --gradient_accumulation_steps 8 \
12 | --eval_strategy "steps" \
13 | --save_strategy "steps" \
14 | --eval_steps 100 \
15 | --save_steps 500 \
16 | --save_total_limit 5 \
17 | --logging_steps 1 \
18 | --learning_rate 1e-4 \
19 | --weight_decay 0. \
20 | --warmup_ratio 0.03 \
21 | --optim "adamw_torch" \
22 | --lr_scheduler_type "cosine" \
23 | --gradient_checkpointing True \
24 | --trust_remote_code \
25 | --model_max_length 128 \
26 | --do_train \
27 | --do_eval
28 |
--------------------------------------------------------------------------------
/scripts/full_finetune/full-finetune.sh:
--------------------------------------------------------------------------------
1 | python llamatuner/train/sft/train_full.py \
2 | --model_name_or_path facebook/opt-125m \
3 | --dataset alpaca \
4 | --eval_dataset alpaca \
5 | --output_dir work_dir/full-finetune \
6 | --wandb_project llamatuner \
7 | --wandb_run_name alpaca_opt-125m_full-finetune \
8 | --num_train_epochs 3 \
9 | --per_device_train_batch_size 4 \
10 | --per_device_eval_batch_size 4 \
11 | --gradient_accumulation_steps 8 \
12 | --eval_strategy "steps" \
13 | --save_strategy "steps" \
14 | --eval_steps 100 \
15 | --save_steps 500 \
16 | --save_total_limit 5 \
17 | --logging_steps 10 \
18 | --learning_rate 2e-5 \
19 | --weight_decay 0. \
20 | --warmup_ratio 0.03 \
21 | --optim "adamw_torch" \
22 | --lr_scheduler_type "cosine" \
23 | --gradient_checkpointing True \
24 | --trust_remote_code \
25 | --model_max_length 128 \
26 | --do_train \
27 | --do_eval
28 |
--------------------------------------------------------------------------------
/scripts/pre_train/full-finetune.sh:
--------------------------------------------------------------------------------
1 | python llamatuner/train/pt/train_pt.py \
2 | --model_name_or_path Qwen/Qwen2.5-1.5B \
3 | --trust_remote_code \
4 | --dataset open-web-math \
5 | --eval_dataset c4_demo \
6 | --template qwen \
7 | --cutoff_len 64 \
8 | --max_samples 1000 \
9 | --preprocessing_num_workers 16 \
10 | --output_dir work_dir/pre_train \
11 | --wandb_project llamatuner \
12 | --wandb_run_name qwen-1.5B_pre_train \
13 | --logging_steps 10 \
14 | --save_strategy "steps" \
15 | --save_steps 500 \
16 | --save_total_limit 5 \
17 | --per_device_train_batch_size 1 \
18 | --per_device_eval_batch_size 1 \
19 | --gradient_accumulation_steps 1 \
20 | --eval_strategy "steps" \
21 | --eval_steps 100 \
22 | --num_train_epochs 3 \
23 | --learning_rate 2e-5 \
24 | --weight_decay 0. \
25 | --warmup_ratio 0.1 \
26 | --optim "adamw_torch" \
27 | --lr_scheduler_type "cosine" \
28 | --gradient_checkpointing True \
29 | --do_train \
30 | --do_eval
31 |
--------------------------------------------------------------------------------
/examples/sft_qlora_train/qwen_lora_sft_bitsandbytes.yaml:
--------------------------------------------------------------------------------
1 | ### model
2 | model_name_or_path: Qwen/Qwen1.5-0.5B
3 | quant_bit: 4
4 | use_qlora: true
5 |
6 | ### method
7 | stage: sft
8 | do_train: true
9 | finetuning_type: lora
10 | lora_target: q_proj,v_proj
11 |
12 | ### dataset
13 | dataset: alpaca
14 | template: qwen
15 | cutoff_len: 1024
16 | max_samples: 1000
17 | overwrite_cache: true
18 | preprocessing_num_workers: 16
19 |
20 | ### output
21 | output_dir: work_dir/qwen1.5-0.5b/qlora/sft
22 | save_strategy: steps
23 | save_steps: 500
24 | save_total_limit: 5
25 | overwrite_output_dir: true
26 |
27 | # logger settings
28 | logging_steps: 10
29 | report_to: wandb
30 | wandb_project: llamatuner
31 | wandb_run_name: qwen1.5-0.5b_qlora_sft
32 |
33 | ### train
34 | per_device_train_batch_size: 1
35 | gradient_accumulation_steps: 8
36 | gradient_checkpointing: True
37 | num_train_epochs: 3.0
38 | learning_rate: 0.0001
39 | lr_scheduler_type: cosine
40 | warmup_ratio: 0.1
41 | fp16: true
42 |
43 | ### eval
44 | eval_dataset_size: 0.1
45 | per_device_eval_batch_size: 1
46 | eval_strategy: steps
47 | eval_steps: 500
48 |
--------------------------------------------------------------------------------
/examples/sft_qlora_train/llama3_lora_sft_bitsandbytes.yaml:
--------------------------------------------------------------------------------
1 | ### model
2 | model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
3 | quant_bit: 4
4 | use_qlora: true
5 |
6 | ### method
7 | stage: sft
8 | do_train: true
9 | finetuning_type: lora
10 | lora_target: all
11 |
12 | ### dataset
13 | dataset: alpaca
14 | template: llama3
15 | cutoff_len: 1024
16 | max_samples: 1000
17 | overwrite_cache: true
18 | preprocessing_num_workers: 16
19 |
20 | ### output
21 | output_dir: work_dir/llama3-8b/qlora/sft/
22 | save_strategy: steps
23 | save_steps: 500
24 | save_total_limit: 5
25 | overwrite_output_dir: true
26 |
27 | # logger settings
28 | logging_steps: 10
29 | report_to: wandb
30 | wandb_project: llamatuner
31 | wandb_run_name: llama3-8b_qlora_sft
32 |
33 | ### train
34 | per_device_train_batch_size: 1
35 | gradient_accumulation_steps: 8
36 | gradient_checkpointing: True
37 | num_train_epochs: 3.0
38 | learning_rate: 1.0e-4
39 | lr_scheduler_type: cosine
40 | warmup_ratio: 0.1
41 | fp16: true
42 |
43 | ### eval
44 | eval_dataset_size: 0.1
45 | per_device_eval_batch_size: 1
46 | eval_strategy: steps
47 | eval_steps: 500
48 |
--------------------------------------------------------------------------------
/examples/sft_full_train/qwen_full_sft.yaml:
--------------------------------------------------------------------------------
1 | ### model
2 | model_name_or_path: Qwen/Qwen2.5-0.5B
3 | trust_remote_code: true
4 |
5 | ### method
6 | stage: sft
7 | do_train: true
8 | finetuning_type: full
9 |
10 | ### ddp
11 | ddp_timeout: 180000000
12 |
13 | ### dataset
14 | dataset: alpaca
15 | eval_dataset: alpaca
16 | template: qwen
17 | cutoff_len: 1024
18 | overwrite_cache: true
19 | preprocessing_num_workers: 8
20 |
21 | ### output
22 | output_dir: work_dir/sft/qwen1.5-0.5b/
23 | save_strategy: steps
24 | save_steps: 500
25 | save_total_limit: 5
26 | overwrite_output_dir: true
27 |
28 | # logger settings
29 | logging_steps: 10
30 | report_to: wandb
31 | wandb_project: llamatuner
32 | wandb_run_name: qwen1.5-0.5b_full_sft_alpaca
33 |
34 | ### train
35 | per_device_train_batch_size: 1
36 | gradient_accumulation_steps: 2
37 | gradient_checkpointing: True
38 | num_train_epochs: 3.0
39 | optim: adamw_torch
40 | learning_rate: 1.0e-4
41 | lr_scheduler_type: cosine
42 | weight_decay: 0.01
43 | warmup_ratio: 0.1
44 | fp16: true
45 |
46 | ### eval
47 | per_device_eval_batch_size: 1
48 | eval_strategy: steps
49 | eval_steps: 500
50 |
--------------------------------------------------------------------------------
/examples/pre-train/qwen_full_pre-train.yaml:
--------------------------------------------------------------------------------
1 | ### model
2 | model_name_or_path: Qwen/Qwen2.5-0.5B
3 | trust_remote_code: true
4 |
5 |
6 | ### method
7 | stage: pt
8 | do_train: true
9 | finetuning_type: full
10 |
11 | ### ddp
12 | ddp_timeout: 180000000
13 |
14 | ### dataset
15 | dataset: c4_demo
16 | eval_dataset: c4_demo
17 | template: qwen
18 | cutoff_len: 1024
19 | overwrite_cache: true
20 | preprocessing_num_workers: 8
21 |
22 | ### output
23 | output_dir: work_dir/pretrain/qwen2.5-0.5b/
24 | save_strategy: steps
25 | save_steps: 500
26 | save_total_limit: 5
27 | overwrite_output_dir: true
28 |
29 | # logger settings
30 | logging_steps: 10
31 | report_to: wandb
32 | wandb_project: llamatuner
33 | wandb_run_name: qwen2.5-0.5b_pretrain
34 |
35 | ### train
36 | per_device_train_batch_size: 1
37 | gradient_accumulation_steps: 2
38 | gradient_checkpointing: True
39 | num_train_epochs: 3.0
40 | optim: adamw_torch
41 | learning_rate: 1.0e-4
42 | lr_scheduler_type: cosine
43 | weight_decay: 0.01
44 | warmup_ratio: 0.1
45 | fp16: true
46 |
47 | ### eval
48 | per_device_eval_batch_size: 1
49 | eval_strategy: steps
50 | eval_steps: 500
51 |
--------------------------------------------------------------------------------
/examples/deepspeed/ds_z3_offload_config.json:
--------------------------------------------------------------------------------
1 | {
2 | "train_batch_size": "auto",
3 | "train_micro_batch_size_per_gpu": "auto",
4 | "gradient_accumulation_steps": "auto",
5 | "gradient_clipping": "auto",
6 | "zero_allow_untested_optimizer": true,
7 | "fp16": {
8 | "enabled": "auto",
9 | "loss_scale": 0,
10 | "loss_scale_window": 1000,
11 | "initial_scale_power": 16,
12 | "hysteresis": 2,
13 | "min_loss_scale": 1
14 | },
15 | "bf16": {
16 | "enabled": "auto"
17 | },
18 | "zero_optimization": {
19 | "stage": 3,
20 | "offload_optimizer": {
21 | "device": "cpu",
22 | "pin_memory": true
23 | },
24 | "offload_param": {
25 | "device": "cpu",
26 | "pin_memory": true
27 | },
28 | "overlap_comm": true,
29 | "contiguous_gradients": true,
30 | "sub_group_size": 1e9,
31 | "reduce_bucket_size": "auto",
32 | "stage3_prefetch_bucket_size": "auto",
33 | "stage3_param_persistence_threshold": "auto",
34 | "stage3_max_live_parameters": 1e9,
35 | "stage3_max_reuse_distance": 1e9,
36 | "stage3_gather_16bit_weights_on_model_save": true
37 | }
38 | }
39 |
--------------------------------------------------------------------------------
/examples/sft_lora_train/llama3_lora_sft.yaml:
--------------------------------------------------------------------------------
1 | ### model
2 | model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
3 |
4 | ### method
5 | stage: sft
6 | do_train: true
7 | finetuning_type: lora
8 | lora_target: all
9 | ### ddp
10 | ddp_timeout: 180000000
11 |
12 | ### dataset
13 | dataset: alpaca
14 | template: llama3
15 | cutoff_len: 1024
16 | max_samples: 1000
17 | overwrite_cache: true
18 | preprocessing_num_workers: 16
19 |
20 | ### output
21 | output_dir: work_dir/llama3-8b/lora/sft/
22 | save_strategy: steps
23 | save_steps: 500
24 | save_total_limit: 5
25 | overwrite_output_dir: true
26 |
27 | # logger settings
28 | logging_steps: 10
29 | report_to: wandb
30 | wandb_project: llamatuner
31 | wandb_run_name: llama3-8b_lora_sft
32 |
33 | ### train
34 | per_device_train_batch_size: 1
35 | gradient_accumulation_steps: 8
36 | gradient_checkpointing: True
37 | num_train_epochs: 3.0
38 | optim: adamw_torch
39 | learning_rate: 1.0e-4
40 | lr_scheduler_type: cosine
41 | weight_decay: 0.0
42 | warmup_ratio: 0.1
43 | fp16: True
44 |
45 | ### eval
46 | eval_dataset_size: 0.1
47 | per_device_eval_batch_size: 1
48 | eval_strategy: steps
49 | eval_steps: 500
50 |
--------------------------------------------------------------------------------
/examples/sft_lora_train/qwen_lora_sft.yaml:
--------------------------------------------------------------------------------
1 | ### model
2 | model_name_or_path: Qwen/Qwen1.5-0.5B
3 |
4 | ### method
5 | stage: sft
6 | do_train: true
7 | finetuning_type: lora
8 | lora_target: q_proj,v_proj
9 |
10 | ### ddp
11 | ddp_timeout: 180000000
12 |
13 | ### dataset
14 | dataset: alpaca
15 | template: qwen
16 | cutoff_len: 1024
17 | max_samples: 1000
18 | overwrite_cache: true
19 | preprocessing_num_workers: 16
20 |
21 | ### output
22 | output_dir: work_dir/qwen1.5-0.5b/lora/sft/
23 | save_strategy: steps
24 | save_steps: 500
25 | save_total_limit: 5
26 | overwrite_output_dir: true
27 |
28 | # logger settings
29 | logging_steps: 10
30 | report_to: wandb
31 | wandb_project: llamatuner
32 | wandb_run_name: qwen1.5-0.5b_lora_sft_alpaca
33 |
34 | ### train
35 | per_device_train_batch_size: 4
36 | gradient_accumulation_steps: 8
37 | gradient_checkpointing: True
38 | num_train_epochs: 3.0
39 | optim: adamw_torch
40 | learning_rate: 1.0e-4
41 | lr_scheduler_type: cosine
42 | weight_decay: 0.0
43 | warmup_ratio: 0.1
44 | fp16: True
45 |
46 | ### eval
47 | eval_dataset_size: 0.1
48 | per_device_eval_batch_size: 4
49 | eval_strategy: steps
50 | eval_steps: 500
51 |
--------------------------------------------------------------------------------
/examples/sft_full_train/qwen_full_sft_ds.yaml:
--------------------------------------------------------------------------------
1 | ### model
2 | model_name_or_path: Qwen/Qwen1.5-0.5B
3 |
4 | ### method
5 | stage: sft
6 | do_train: true
7 | finetuning_type: full
8 |
9 | ### ddp
10 | ddp_timeout: 180000000
11 | deepspeed: examples/deepspeed/ds_z3_config.json
12 |
13 | ### dataset
14 | dataset: alpaca
15 | template: qwen
16 | cutoff_len: 1024
17 | max_samples: 1000
18 | overwrite_cache: true
19 | preprocessing_num_workers: 16
20 |
21 | ### output
22 | output_dir: work_dir/qwen1.5-0.5b/full/sft/
23 | save_strategy: steps
24 | save_steps: 500
25 | save_total_limit: 5
26 | overwrite_output_dir: true
27 |
28 | # logger settings
29 | logging_steps: 10
30 | report_to: wandb
31 | wandb_project: llamatuner
32 | wandb_run_name: qwen1.5-0.5b_full_sft_alpaca
33 |
34 | ### train
35 | per_device_train_batch_size: 1
36 | gradient_accumulation_steps: 2
37 | gradient_checkpointing: True
38 | num_train_epochs: 3.0
39 | optim: adamw_torch
40 | learning_rate: 1.0e-4
41 | lr_scheduler_type: cosine
42 | weight_decay: 0.01
43 | warmup_ratio: 0.1
44 | fp16: true
45 |
46 | ### eval
47 | eval_dataset_size: 0.1
48 | per_device_eval_batch_size: 1
49 | eval_strategy: steps
50 | eval_steps: 500
51 |
--------------------------------------------------------------------------------
/examples/sft_lora_train/llama3_lora_sft_ds.yaml:
--------------------------------------------------------------------------------
1 | ### model
2 | model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
3 |
4 | ### method
5 | stage: sft
6 | do_train: true
7 | finetuning_type: lora
8 | lora_target: all
9 |
10 | ### ddp
11 | ddp_timeout: 180000000
12 | deepspeed: examples/deepspeed/ds_z3_config.json
13 |
14 | ### dataset
15 | dataset: alpaca
16 | template: llama3
17 | cutoff_len: 1024
18 | max_samples: 1000
19 | overwrite_cache: true
20 | preprocessing_num_workers: 16
21 |
22 | ### output
23 | output_dir: work_dir/llama3-8b/lora/sft/
24 | save_strategy: steps
25 | save_steps: 500
26 | save_total_limit: 5
27 | overwrite_output_dir: true
28 |
29 | ### logging
30 | logging_steps: 10
31 | report_to: wandb
32 | wandb_project: llamatuner
33 | wandb_run_name: alpaca_llama3-8b_lora-sft
34 |
35 | ### train
36 | per_device_train_batch_size: 1
37 | gradient_accumulation_steps: 2
38 | gradient_checkpointing: True
39 | learning_rate: 1.0e-4
40 | lr_scheduler_type: cosine
41 | weight_decay: 0.0
42 | warmup_ratio: 0.1
43 | num_train_epochs: 3.0
44 | fp16: True
45 |
46 | ### eval
47 | eval_dataset_size: 0.1
48 | per_device_eval_batch_size: 1
49 | eval_strategy: steps
50 | eval_steps: 500
51 |
--------------------------------------------------------------------------------
/examples/sft_full_train/llama3_full_sft_ds3.yaml:
--------------------------------------------------------------------------------
1 | ### model
2 | model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
3 |
4 | ### method
5 | stage: sft
6 | do_train: true
7 | finetuning_type: full
8 |
9 | ### ddp
10 | ddp_timeout: 180000000
11 | deepspeed: examples/deepspeed/ds_z3_config.json
12 |
13 | ### dataset
14 | dataset: alpaca
15 | template: llama3
16 | cutoff_len: 1024
17 | max_samples: 1000
18 | overwrite_cache: true
19 | preprocessing_num_workers: 16
20 |
21 | ### output
22 | output_dir: work_dir/llama3-8b/full/sft
23 | save_strategy: steps
24 | save_steps: 500
25 | save_total_limit: 5
26 | overwrite_output_dir: true
27 |
28 | # logger settings
29 | logging_steps: 10
30 | report_to: wandb
31 | wandb_project: llamatuner
32 | wandb_run_name: llama3-8b_full_sft_alpaca
33 |
34 | ### train
35 | per_device_train_batch_size: 1
36 | gradient_accumulation_steps: 2
37 | gradient_checkpointing: True
38 | num_train_epochs: 3.0
39 | optim: adamw_torch
40 | learning_rate: 1.0e-4
41 | lr_scheduler_type: cosine
42 | weight_decay: 0.01
43 | warmup_ratio: 0.1
44 | fp16: true
45 |
46 | ### eval
47 | eval_dataset_size: 0.1
48 | per_device_eval_batch_size: 1
49 | eval_strategy: steps
50 | eval_steps: 500
51 |
--------------------------------------------------------------------------------
/scripts/qlora_finetune/finetune_baichuan_7b_vicuna_zh.sh:
--------------------------------------------------------------------------------
1 | CUDA_VISIBLE_DEVICES=1 python train_qlora.py \
2 | --model_name_or_path ~/checkpoints/baichuan7b \
3 | --dataset_cfg ./data/vicuna_zh_pcyn.yaml \
4 | --output_dir ./work_dir/vicuna_zh-baichuan-7b \
5 | --num_train_epochs 3 \
6 | --per_device_train_batch_size 2 \
7 | --per_device_eval_batch_size 2 \
8 | --gradient_accumulation_steps 16 \
9 | --evaluation_strategy steps \
10 | --eval_steps 1000 \
11 | --save_strategy steps \
12 | --save_total_limit 10 \
13 | --save_steps 1000 \
14 | --logging_strategy steps \
15 | --logging_steps 5 \
16 | --learning_rate 0.0002 \
17 | --warmup_ratio 0.03 \
18 | --weight_decay 0.0 \
19 | --lr_scheduler_type constant \
20 | --adam_beta2 0.999 \
21 | --max_grad_norm 0.3 \
22 | --lora_r 64 \
23 | --lora_alpha 16 \
24 | --lora_dropout 0.1 \
25 | --double_quant \
26 | --quant_type nf4 \
27 | --fp16 \
28 | --bits 4 \
29 | --model_max_length 1024 \
30 | --gradient_checkpointing \
31 | --trust_remote_code True \
32 | --use_auth_token True \
33 | --do_train \
34 | --do_eval \
35 | --data_seed 42 \
36 | --seed 0
37 |
--------------------------------------------------------------------------------
/scripts/qlora_finetune/finetune_llama2_7b_alpaca_zh.sh:
--------------------------------------------------------------------------------
1 | python train_qlora.py \
2 | --model_name_or_path meta-llama/Llama-2-7b-hf \
3 | --dataset_cfg ./data/alpaca_zh_pcyn.yaml \
4 | --output_dir ./work_dir/alpaca_zh_llama2-7b \
5 | --num_train_epochs 3 \
6 | --per_device_train_batch_size 4 \
7 | --per_device_eval_batch_size 4 \
8 | --gradient_accumulation_steps 8 \
9 | --evaluation_strategy steps \
10 | --eval_steps 1000 \
11 | --save_strategy steps \
12 | --save_total_limit 10 \
13 | --save_steps 1000 \
14 | --logging_strategy steps \
15 | --logging_steps 5 \
16 | --learning_rate 0.0002 \
17 | --warmup_ratio 0.03 \
18 | --weight_decay 0.0 \
19 | --lr_scheduler_type constant \
20 | --adam_beta2 0.999 \
21 | --max_grad_norm 0.3 \
22 | --lora_r 64 \
23 | --lora_alpha 16 \
24 | --lora_dropout 0.1 \
25 | --double_quant \
26 | --quant_type nf4 \
27 | --fp16 \
28 | --bits 4 \
29 | --model_max_length 1024 \
30 | --gradient_checkpointing \
31 | --trust_remote_code True \
32 | --use_auth_token True \
33 | --do_train \
34 | --do_eval \
35 | --sample_generate \
36 | --data_seed 42 \
37 | --seed 0
38 |
--------------------------------------------------------------------------------
/scripts/qlora_finetune/finetune_llama_7b_alpaca_zh.sh:
--------------------------------------------------------------------------------
1 | python train_qlora.py \
2 | --model_name_or_path decapoda-research/llama-7b-hf \
3 | --dataset_cfg ./data/alpaca_zh_pcyn.yaml \
4 | --output_dir ./work_dir/alpaca_zh-baichuan-7b \
5 | --num_train_epochs 3 \
6 | --per_device_train_batch_size 4 \
7 | --per_device_eval_batch_size 4 \
8 | --gradient_accumulation_steps 8 \
9 | --evaluation_strategy steps \
10 | --eval_steps 1000 \
11 | --save_strategy steps \
12 | --save_total_limit 10 \
13 | --save_steps 1000 \
14 | --logging_strategy steps \
15 | --logging_steps 5 \
16 | --learning_rate 0.0002 \
17 | --warmup_ratio 0.03 \
18 | --weight_decay 0.0 \
19 | --lr_scheduler_type constant \
20 | --adam_beta2 0.999 \
21 | --max_grad_norm 0.3 \
22 | --lora_r 64 \
23 | --lora_alpha 16 \
24 | --lora_dropout 0.1 \
25 | --double_quant \
26 | --quant_type nf4 \
27 | --fp16 \
28 | --bits 4 \
29 | --model_max_length 1024 \
30 | --gradient_checkpointing \
31 | --trust_remote_code True \
32 | --use_auth_token True \
33 | --do_train \
34 | --do_eval \
35 | --sample_generate \
36 | --data_seed 42 \
37 | --seed 0
38 |
--------------------------------------------------------------------------------
/scripts/qlora_finetune/finetune_baichuan_7b_alpaca_zh.sh:
--------------------------------------------------------------------------------
1 | CUDA_VISIBLE_DEVICES=0 python train_qlora.py \
2 | --model_name_or_path ~/checkpoints/baichuan7b \
3 | --dataset_cfg ./data/alpaca_zh_pcyn.yaml \
4 | --output_dir ./work_dir/alpaca_zh-baichuan-7b \
5 | --num_train_epochs 3 \
6 | --per_device_train_batch_size 4 \
7 | --per_device_eval_batch_size 4 \
8 | --gradient_accumulation_steps 8 \
9 | --evaluation_strategy steps \
10 | --eval_steps 1000 \
11 | --save_strategy steps \
12 | --save_total_limit 10 \
13 | --save_steps 1000 \
14 | --logging_strategy steps \
15 | --logging_steps 5 \
16 | --learning_rate 0.0002 \
17 | --warmup_ratio 0.03 \
18 | --weight_decay 0.0 \
19 | --lr_scheduler_type constant \
20 | --adam_beta2 0.999 \
21 | --max_grad_norm 0.3 \
22 | --lora_r 64 \
23 | --lora_alpha 16 \
24 | --lora_dropout 0.1 \
25 | --double_quant \
26 | --quant_type nf4 \
27 | --fp16 \
28 | --bits 4 \
29 | --model_max_length 1024 \
30 | --gradient_checkpointing \
31 | --trust_remote_code True \
32 | --use_auth_token True \
33 | --do_train \
34 | --do_eval \
35 | --sample_generate \
36 | --data_seed 42 \
37 | --seed 0
38 |
--------------------------------------------------------------------------------
/data/format_data/merge.py:
--------------------------------------------------------------------------------
1 | """Merge multiple conversation files into one.
2 |
3 | Usage: python3 merge.py --in-file file1.json file2.json --out-file merged.json
4 | """
5 |
6 | import argparse
7 | import json
8 |
9 | from datasets import load_dataset
10 |
11 |
12 | def json_load(in_file):
13 | with open(in_file, 'r') as f:
14 | json_data = json.load(f)
15 | return json_data
16 |
17 |
18 | def json_dump(obj, path):
19 | with open(path, 'w', encoding='utf-8') as f:
20 | json.dump(obj, f, indent=2, ensure_ascii=False)
21 |
22 |
23 | def merge_datasets(in_file_list, out_file):
24 |
25 | new_content = []
26 | for in_file in in_file_list:
27 | content = load_dataset('json', data_files=in_file)['train']
28 |
29 | print(f'in-file: {in_file}, len: {len(content)}')
30 | new_content.extend(content)
31 |
32 | print(f'#out: {len(new_content)}')
33 | print(f'Save new_content to {out_file}')
34 | json_dump(new_content, out_file)
35 |
36 |
37 | if __name__ == '__main__':
38 | parser = argparse.ArgumentParser()
39 | parser.add_argument('--in-file', type=str, required=True, nargs='+')
40 | parser.add_argument('--out-file', type=str, default='merged.json')
41 | args = parser.parse_args()
42 |
43 | merge_datasets(args.in_file, args.out_file)
44 |
--------------------------------------------------------------------------------
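
For reference, the merge can be driven from the command line or programmatically; the shard file names below are placeholders.

    # CLI form (run from the repository root):
    #   python data/format_data/merge.py --in-file shard_a.json shard_b.json --out-file merged.json

    # Programmatic form, run from data/format_data/ so the module resolves:
    from merge import merge_datasets

    merge_datasets(['shard_a.json', 'shard_b.json'], 'merged.json')
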
/llamatuner/train/tuner.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Dict, Optional
2 |
3 | from llamatuner.configs.parser import get_train_args
4 | from llamatuner.train.pt.train_pt import run_pt
5 | from llamatuner.train.sft.train_full import run_full_sft
6 | from llamatuner.train.sft.train_lora import run_lora_sft
7 |
8 |
9 | def run_exp(args: Optional[Dict[str, Any]] = None) -> None:
10 | model_args, data_args, training_args, finetuning_args, generating_args = (
11 | get_train_args(args))
12 | if finetuning_args.stage == 'pt':
13 | run_pt(
14 | model_args,
15 | data_args,
16 | training_args,
17 | finetuning_args,
18 | )
19 |     elif finetuning_args.stage == 'full':
20 | run_full_sft(
21 | model_args,
22 | data_args,
23 | training_args,
24 | finetuning_args,
25 | generating_args,
26 | )
27 | elif finetuning_args.stage == 'sft':
28 | run_lora_sft(
29 | model_args,
30 | data_args,
31 | training_args,
32 | finetuning_args,
33 | generating_args,
34 | )
35 | else:
36 | raise ValueError('Unknown task: {}.'.format(finetuning_args.stage))
37 |
38 |
39 | def launch():
40 | run_exp()
41 |
42 |
43 | if __name__ == '__main__':
44 | launch()
45 |
--------------------------------------------------------------------------------
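
A programmatic launch sketch: run_exp accepts an optional dict, so the keys used in the YAML examples can presumably be passed directly. The exact accepted keys depend on get_train_args (not shown here), so treat the key names below as assumptions mirrored from examples/sft_lora_train/qwen_lora_sft.yaml.

    from llamatuner.train.tuner import run_exp

    # Keys mirror examples/sft_lora_train/qwen_lora_sft.yaml; whether
    # get_train_args accepts a plain dict with exactly these names is an
    # assumption to verify against llamatuner/configs/parser.py.
    run_exp({
        'stage': 'sft',
        'do_train': True,
        'model_name_or_path': 'Qwen/Qwen1.5-0.5B',
        'dataset': 'alpaca',
        'template': 'qwen',
        'finetuning_type': 'lora',
        'lora_target': 'q_proj,v_proj',
        'output_dir': 'work_dir/qwen1.5-0.5b/lora/sft/',
        'per_device_train_batch_size': 4,
        'num_train_epochs': 3.0,
    })
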
/examples/deepspeed/ds_zero3_offload.json:
--------------------------------------------------------------------------------
1 | {
2 | "bf16": {
3 | "enabled": "auto"
4 | },
5 | "optimizer": {
6 | "type": "AdamW",
7 | "params": {
8 | "lr": "auto",
9 | "betas": "auto",
10 | "eps": "auto",
11 | "weight_decay": "auto"
12 | }
13 | },
14 | "scheduler": {
15 | "type": "WarmupDecayLR",
16 | "params": {
17 | "total_num_steps": "auto",
18 | "warmup_min_lr": "auto",
19 | "warmup_max_lr": "auto",
20 | "warmup_num_steps": "auto"
21 | }
22 | },
23 | "zero_optimization": {
24 | "stage": 3,
25 | "offload_optimizer": {
26 | "device": "cpu",
27 | "pin_memory": true
28 | },
29 | "offload_param": {
30 | "device": "cpu",
31 | "pin_memory": true
32 | },
33 | "overlap_comm": true,
34 | "contiguous_gradients": true,
35 | "sub_group_size": 1e9,
36 | "reduce_bucket_size": "auto",
37 | "stage3_prefetch_bucket_size": "auto",
38 | "stage3_param_persistence_threshold": "auto",
39 | "stage3_max_live_parameters": 1e9,
40 | "stage3_max_reuse_distance": 1e9,
41 | "stage3_gather_16bit_weights_on_model_save": false
42 | },
43 | "gradient_accumulation_steps": "auto",
44 | "gradient_clipping": "auto",
45 | "steps_per_print": 5,
46 | "train_batch_size": "auto",
47 | "train_micro_batch_size_per_gpu": "auto",
48 | "wall_clock_breakdown": false
49 | }
50 |
--------------------------------------------------------------------------------
/scripts/ds_config/default_offload_opt_param.json:
--------------------------------------------------------------------------------
1 | {
2 | "bf16": {
3 | "enabled": "auto"
4 | },
5 | "optimizer": {
6 | "type": "AdamW",
7 | "params": {
8 | "lr": "auto",
9 | "betas": "auto",
10 | "eps": "auto",
11 | "weight_decay": "auto"
12 | }
13 | },
14 | "scheduler": {
15 | "type": "WarmupDecayLR",
16 | "params": {
17 | "total_num_steps": "auto",
18 | "warmup_min_lr": "auto",
19 | "warmup_max_lr": "auto",
20 | "warmup_num_steps": "auto"
21 | }
22 | },
23 | "zero_optimization": {
24 | "stage": 3,
25 | "offload_optimizer": {
26 | "device": "cpu",
27 | "pin_memory": true
28 | },
29 | "offload_param": {
30 | "device": "cpu",
31 | "pin_memory": true
32 | },
33 | "overlap_comm": true,
34 | "contiguous_gradients": true,
35 | "sub_group_size": 1e9,
36 | "reduce_bucket_size": "auto",
37 | "stage3_prefetch_bucket_size": "auto",
38 | "stage3_param_persistence_threshold": "auto",
39 | "stage3_max_live_parameters": 1e9,
40 | "stage3_max_reuse_distance": 1e9,
41 | "stage3_gather_16bit_weights_on_model_save": false
42 | },
43 | "gradient_accumulation_steps": "auto",
44 | "gradient_clipping": "auto",
45 | "steps_per_print": 5,
46 | "train_batch_size": "auto",
47 | "train_micro_batch_size_per_gpu": "auto",
48 | "wall_clock_breakdown": false
49 | }
50 |
--------------------------------------------------------------------------------
/llamatuner/utils/packages.py:
--------------------------------------------------------------------------------
1 | import importlib.metadata
2 | import importlib.util
3 | from typing import TYPE_CHECKING
4 |
5 | from packaging import version
6 |
7 | if TYPE_CHECKING:
8 | from packaging.version import Version
9 |
10 |
11 | def _is_package_available(name: str) -> bool:
12 | return importlib.util.find_spec(name) is not None
13 |
14 |
15 | def _get_package_version(name: str) -> 'Version':
16 | try:
17 | return version.parse(importlib.metadata.version(name))
18 | except importlib.metadata.PackageNotFoundError:
19 | return version.parse('0.0.0')
20 |
21 |
22 | def is_fastapi_available():
23 | return _is_package_available('fastapi')
24 |
25 |
26 | def is_gradio_available():
27 | return _is_package_available('gradio')
28 |
29 |
30 | def is_jieba_available():
31 | return _is_package_available('jieba')
32 |
33 |
34 | def is_matplotlib_available():
35 | return _is_package_available('matplotlib')
36 |
37 |
38 | def is_nltk_available():
39 | return _is_package_available('nltk')
40 |
41 |
42 | def is_pillow_available():
43 | return _is_package_available('PIL')
44 |
45 |
46 | def is_requests_available():
47 | return _is_package_available('requests')
48 |
49 |
50 | def is_rouge_available():
51 | return _is_package_available('rouge_chinese')
52 |
53 |
54 | def is_starlette_available():
55 | return _is_package_available('sse_starlette')
56 |
57 |
58 | def is_uvicorn_available():
59 | return _is_package_available('uvicorn')
60 |
61 |
62 | def is_vllm_available():
63 | return _is_package_available('vllm')
64 |
--------------------------------------------------------------------------------
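
Typical use of these probes is to guard optional imports and feature gates; a short sketch built only on the helpers defined above:

    from packaging import version

    from llamatuner.utils.packages import (_get_package_version, is_gradio_available,
                                           is_vllm_available)

    # Import optional backends only when they are installed.
    if is_gradio_available():
        import gradio as gr  # noqa: F401

    # _get_package_version returns a packaging Version ('0.0.0' when the package
    # is missing), so ordinary comparisons work for gating features.
    if is_vllm_available() and _get_package_version('vllm') >= version.parse('0.4.0'):
        print('vLLM fast inference could be enabled.')
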
/llamatuner/configs/eval_args.py:
--------------------------------------------------------------------------------
1 | import os
2 | from dataclasses import dataclass, field
3 | from typing import Literal, Optional
4 |
5 | from datasets import DownloadMode
6 |
7 |
8 | @dataclass
9 | class EvaluationArguments:
10 | r"""
11 | Arguments pertaining to specify the evaluation parameters.
12 | """
13 |
14 | task: str = field(metadata={'help': 'Name of the evaluation task.'}, )
15 | task_dir: str = field(
16 | default='evaluation',
17 | metadata={
18 | 'help': 'Path to the folder containing the evaluation datasets.'
19 | },
20 | )
21 | batch_size: int = field(
22 | default=4,
23 | metadata={'help': 'The batch size per GPU for evaluation.'},
24 | )
25 | seed: int = field(
26 | default=42,
27 | metadata={'help': 'Random seed to be used with data loaders.'},
28 | )
29 | lang: Literal['en', 'zh'] = field(
30 | default='en',
31 | metadata={'help': 'Language used at evaluation.'},
32 | )
33 | n_shot: int = field(
34 | default=5,
35 | metadata={'help': 'Number of examplars for few-shot learning.'},
36 | )
37 | save_dir: Optional[str] = field(
38 | default=None,
39 | metadata={'help': 'Path to save the evaluation results.'},
40 | )
41 | download_mode: DownloadMode = field(
42 | default=DownloadMode.REUSE_DATASET_IF_EXISTS,
43 | metadata={'help': 'Download mode used for the evaluation datasets.'},
44 | )
45 |
46 | def __post_init__(self):
47 | if self.save_dir is not None and os.path.exists(self.save_dir):
48 | raise ValueError('`save_dir` already exists, use another one.')
49 |
--------------------------------------------------------------------------------
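
These arguments are meant to be filled by HfArgumentParser, the same pattern data/format_data/test_dataloader.py uses. A minimal parse sketch follows; the CLI flags come straight from the field names above, while the task name and paths in the comment are illustrative values.

    from transformers import HfArgumentParser

    from llamatuner.configs import EvaluationArguments

    # e.g. python eval.py --task mmlu --lang en --n_shot 5 --save_dir work_dir/eval
    parser = HfArgumentParser(EvaluationArguments)
    (eval_args, ) = parser.parse_args_into_dataclasses()
    print(eval_args.task, eval_args.lang, eval_args.n_shot, eval_args.batch_size)
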
/llamatuner/data/dataset_factory/pt_dataset.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Dict, List
2 |
3 | from torch.utils.data import Dataset
4 | from transformers.tokenization_utils import PreTrainedTokenizer
5 |
6 | from llamatuner.configs import DataArguments
7 |
8 |
9 | class PretrainDataset(Dataset):
10 |
11 | def __init__(self, examples: Dict[str, List[Any]],
12 | tokenizer: PreTrainedTokenizer, data_args: DataArguments):
13 | """
14 | Initialize PretrainDataset with lazy loading.
15 |
16 | Args:
17 | examples: Dictionary containing the dataset examples
18 | tokenizer: Tokenizer for text processing
19 | data_args: Data arguments containing configuration
20 | """
21 | self.tokenizer = tokenizer
22 | self.data_args = data_args
23 | self.raw_examples = examples['_prompt']
24 | self.eos_token = '<|end_of_text|>' if data_args.template == 'llama3' else tokenizer.eos_token
25 |
26 | def __len__(self) -> int:
27 | return len(self.raw_examples)
28 |
29 | def __getitem__(self, idx: int) -> Dict[str, List[int]]:
30 | # Process single example on-the-fly
31 | messages = self.raw_examples[idx]
32 | text = messages[0]['content'] + self.eos_token
33 |
34 | if self.data_args.template == 'gemma':
35 | text = self.tokenizer.bos_token + text
36 |
37 | processed = self.tokenizer(text,
38 | add_special_tokens=False,
39 | truncation=True,
40 | max_length=self.data_args.cutoff_len)
41 |
42 | return {k: processed[k] for k in processed.keys()}
43 |
--------------------------------------------------------------------------------
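
A construction sketch for the class above. The '_prompt' layout (one message list per document, each message carrying a 'content' field) is taken from __getitem__; the assumption is that DataArguments exposes template and cutoff_len with usable defaults for its remaining fields, as the YAML configs elsewhere in this dump suggest.

    from transformers import AutoTokenizer

    from llamatuner.configs import DataArguments
    from llamatuner.data.dataset_factory.pt_dataset import PretrainDataset

    # Hypothetical raw corpus in the expected `_prompt` layout.
    examples = {
        '_prompt': [
            [{'content': 'ZeRO-3 shards parameters, gradients and optimizer states.'}],
            [{'content': 'LoRA trains low-rank adapters on top of frozen weights.'}],
        ]
    }

    tokenizer = AutoTokenizer.from_pretrained('Qwen/Qwen1.5-0.5B')
    data_args = DataArguments(template='qwen', cutoff_len=64)  # assumed defaults for other fields

    dataset = PretrainDataset(examples, tokenizer, data_args)
    print(len(dataset), dataset[0]['input_ids'][:8])
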
/llamatuner/utils/env_utils.py:
--------------------------------------------------------------------------------
1 | import platform
2 |
3 | import accelerate
4 | import datasets
5 | import peft
6 | import torch
7 | import transformers
8 | import trl
9 | from transformers.utils import is_torch_cuda_available, is_torch_npu_available
10 |
11 | VERSION = '0.1.1.dev0'
12 |
13 |
14 | def print_env_info() -> None:
15 | info = {
16 | '`llamatuner` version': VERSION,
17 | 'Platform': platform.platform(),
18 | 'Python version': platform.python_version(),
19 | 'PyTorch version': torch.__version__,
20 | 'Transformers version': transformers.__version__,
21 | 'Datasets version': datasets.__version__,
22 | 'Accelerate version': accelerate.__version__,
23 | 'PEFT version': peft.__version__,
24 | 'TRL version': trl.__version__,
25 | }
26 |
27 | if is_torch_cuda_available():
28 | info['PyTorch version'] += ' (GPU)'
29 | info['GPU type'] = torch.cuda.get_device_name()
30 |
31 | if is_torch_npu_available():
32 | info['PyTorch version'] += ' (NPU)'
33 | info['NPU type'] = torch.npu.get_device_name()
34 | info['CANN version'] = torch.version.cann
35 |
36 | try:
37 | import deepspeed # type: ignore
38 |
39 | info['DeepSpeed version'] = deepspeed.__version__
40 | except Exception:
41 | pass
42 |
43 | try:
44 | import bitsandbytes
45 |
46 | info['Bitsandbytes version'] = bitsandbytes.__version__
47 | except Exception:
48 | pass
49 |
50 | try:
51 | import vllm
52 |
53 | info['vLLM version'] = vllm.__version__
54 | except Exception:
55 | pass
56 |
57 | print('\n' + '\n'.join(
58 | ['- {}: {}'.format(key, value) for key, value in info.items()]) + '\n')
59 |
--------------------------------------------------------------------------------
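
Calling it is a one-liner; each dependency appears as a '- key: value' line, and the GPU/NPU, DeepSpeed, bitsandbytes and vLLM entries are added only when detected.

    from llamatuner.utils.env_utils import print_env_info

    print_env_info()
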
/examples/deepspeed/ds_zero3_auto.json:
--------------------------------------------------------------------------------
1 | {
2 | "fp16": {
3 | "enabled": "auto",
4 | "loss_scale": 0,
5 | "loss_scale_window": 1000,
6 | "initial_scale_power": 16,
7 | "hysteresis": 2,
8 | "min_loss_scale": 1
9 | },
10 | "bf16": {
11 | "enabled": "auto"
12 | },
13 | "optimizer": {
14 | "type": "AdamW",
15 | "params": {
16 | "lr": "auto",
17 | "betas": "auto",
18 | "eps": "auto",
19 | "weight_decay": "auto"
20 | }
21 | },
22 | "scheduler": {
23 | "type": "WarmupDecayLR",
24 | "params": {
25 | "total_num_steps": "auto",
26 | "warmup_min_lr": "auto",
27 | "warmup_max_lr": "auto",
28 | "warmup_num_steps": "auto"
29 | }
30 | },
31 | "zero_optimization": {
32 | "stage": 3,
33 | "offload_optimizer": {
34 | "device": "cpu",
35 | "pin_memory": true
36 | },
37 | "offload_param": {
38 | "device": "cpu",
39 | "pin_memory": true
40 | },
41 | "overlap_comm": true,
42 | "contiguous_gradients": true,
43 | "allgather_partitions": true,
44 | "allgather_bucket_size": 5e8,
45 | "sub_group_size": 1e9,
46 | "reduce_bucket_size": "auto",
47 | "stage3_prefetch_bucket_size": "auto",
48 | "stage3_param_persistence_threshold": "auto",
49 | "stage3_max_live_parameters": 1e9,
50 | "stage3_max_reuse_distance": 1e9,
51 | "stage3_gather_16bit_weights_on_model_save": true
52 | },
53 | "train_batch_size": "auto",
54 | "train_micro_batch_size_per_gpu": "auto",
55 | "gradient_accumulation_steps": "auto",
56 | "gradient_clipping": "auto",
57 | "steps_per_print": 5,
58 | "wall_clock_breakdown": false
59 | }
60 |
--------------------------------------------------------------------------------
/scripts/ds_config/ds_config_zero3_auto.json:
--------------------------------------------------------------------------------
1 | {
2 | "fp16": {
3 | "enabled": "auto",
4 | "loss_scale": 0,
5 | "loss_scale_window": 1000,
6 | "initial_scale_power": 16,
7 | "hysteresis": 2,
8 | "min_loss_scale": 1
9 | },
10 | "bf16": {
11 | "enabled": "auto"
12 | },
13 | "optimizer": {
14 | "type": "AdamW",
15 | "params": {
16 | "lr": "auto",
17 | "betas": "auto",
18 | "eps": "auto",
19 | "weight_decay": "auto"
20 | }
21 | },
22 | "scheduler": {
23 | "type": "WarmupDecayLR",
24 | "params": {
25 | "total_num_steps": "auto",
26 | "warmup_min_lr": "auto",
27 | "warmup_max_lr": "auto",
28 | "warmup_num_steps": "auto"
29 | }
30 | },
31 | "zero_optimization": {
32 | "stage": 3,
33 | "offload_optimizer": {
34 | "device": "cpu",
35 | "pin_memory": true
36 | },
37 | "offload_param": {
38 | "device": "cpu",
39 | "pin_memory": true
40 | },
41 | "overlap_comm": true,
42 | "contiguous_gradients": true,
43 | "allgather_partitions": true,
44 | "allgather_bucket_size": 5e8,
45 | "sub_group_size": 1e9,
46 | "reduce_bucket_size": "auto",
47 | "stage3_prefetch_bucket_size": "auto",
48 | "stage3_param_persistence_threshold": "auto",
49 | "stage3_max_live_parameters": 1e9,
50 | "stage3_max_reuse_distance": 1e9,
51 | "stage3_gather_16bit_weights_on_model_save": true
52 | },
53 | "train_batch_size": "auto",
54 | "train_micro_batch_size_per_gpu": "auto",
55 | "gradient_accumulation_steps": "auto",
56 | "gradient_clipping": "auto",
57 | "steps_per_print": 5,
58 | "wall_clock_breakdown": false
59 | }
60 |
--------------------------------------------------------------------------------
/llamatuner/utils/stream_server.py:
--------------------------------------------------------------------------------
1 | """Helpers to support streaming generate output.
2 |
3 | Borrowed from https://github.com/oobabooga/text-generation-webui/blob/ad37f396fc8bcbab90e11ecf17c56c97bfbd4a9c/modules/callbacks.py
4 | """
5 | import traceback
6 | from queue import Queue
7 | from threading import Thread
8 |
9 | import transformers
10 |
11 |
12 | class Stream(transformers.StoppingCriteria):
13 |
14 | def __init__(self, callback_func=None):
15 | self.callback_func = callback_func
16 |
17 | def __call__(self, input_ids, scores) -> bool:
18 | if self.callback_func is not None:
19 | self.callback_func(input_ids[0])
20 | return False
21 |
22 |
23 | class Iteratorize:
24 | """Transforms a function that takes a callback into a lazy iterator
25 | (generator)."""
26 |
27 |     def __init__(self, func, kwargs=None, callback=None):
28 | self.mfunc = func
29 | self.c_callback = callback
30 | self.q = Queue()
31 | self.sentinel = object()
32 |         self.kwargs = kwargs if kwargs is not None else {}
33 | self.stop_now = False
34 |
35 | def _callback(val):
36 | if self.stop_now:
37 | raise ValueError
38 | self.q.put(val)
39 |
40 |         def gentask():
41 |             ret = None
42 |             try:
43 |                 ret = self.mfunc(callback=_callback, **self.kwargs)
44 |             except ValueError:
45 |                 # Raised by _callback when the consumer stopped iteration early.
46 |                 pass
47 |             except Exception:
48 |                 traceback.print_exc()
49 |             self.q.put(self.sentinel)
50 |             if self.c_callback:
51 |                 self.c_callback(ret)
52 |
53 | self.thread = Thread(target=gentask)
54 | self.thread.start()
55 |
56 | def __iter__(self):
57 | return self
58 |
59 | def __next__(self):
60 | obj = self.q.get(True, None)
61 | if obj is self.sentinel:
62 | raise StopIteration
63 | else:
64 | return obj
65 |
66 | def __enter__(self):
67 | return self
68 |
69 | def __exit__(self, exc_type, exc_val, exc_tb):
70 | self.stop_now = True
71 |
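72 | # Usage sketch (illustrative): wrap `model.generate` so partially generated token ids can
73 | # be consumed as a lazy iterator. `model` and `gen_kwargs` are assumed to be created by
74 | # the caller; `Stream` feeds each step's ids to the callback installed by `Iteratorize`.
75 | #
76 | # def generate_with_callback(callback=None, **kwargs):
77 | #     kwargs['stopping_criteria'] = transformers.StoppingCriteriaList([Stream(callback)])
78 | #     model.generate(**kwargs)
79 | #
80 | # with Iteratorize(generate_with_callback, gen_kwargs) as generator:
81 | #     for output_ids in generator:
82 | #         print(output_ids)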
--------------------------------------------------------------------------------
/data/ultra_chat/ultra_chat.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | from typing import List
4 |
5 | import datasets
6 |
7 | _HF_ENDPOINT = os.getenv('HF_ENDPOINT', 'https://huggingface.co')
8 |
9 | _DESCRIPTION = 'UltraChat: Large-scale, Informative, and Diverse Multi-round Dialogue Data.'
10 |
11 | _CITATION = """\
12 | @misc{UltraChat,
13 | author = {Ding, Ning and Chen, Yulin and Xu, Bokai and Hu, Shengding and Qin, Yujia and Liu, Zhiyuan and Sun, Maosong and Zhou, Bowen},
14 | title = {UltraChat: A Large-scale Auto-generated Multi-round Dialogue Data},
15 | year = {2023},
16 | publisher = {GitHub},
17 | journal = {GitHub repository},
18 | howpublished = {\\url{https://github.com/thunlp/ultrachat}},
19 | }
20 | """
21 |
22 | _HOMEPAGE = '{}/datasets/stingning/ultrachat'.format(_HF_ENDPOINT)
23 | _LICENSE = 'cc-by-nc-4.0'
24 | _BASE_DATA_URL = '{}/datasets/stingning/ultrachat/resolve/main/train_{{idx}}.jsonl'.format(
25 | _HF_ENDPOINT)
26 |
27 |
28 | class UltraChat(datasets.GeneratorBasedBuilder):
29 | VERSION = datasets.Version('0.0.0')
30 |
31 | def _info(self):
32 | features = datasets.Features({
33 | 'conversations': [{
34 | 'from': datasets.Value('string'),
35 | 'value': datasets.Value('string')
36 | }]
37 | })
38 | return datasets.DatasetInfo(description=_DESCRIPTION,
39 | features=features,
40 | homepage=_HOMEPAGE,
41 | license=_LICENSE,
42 | citation=_CITATION)
43 |
44 | def _split_generators(self, dl_manager: datasets.DownloadManager):
45 | file_paths = [
46 | dl_manager.download(_BASE_DATA_URL.format(idx=idx))
47 | for idx in range(10)
48 | ] # multiple shards
49 | return [
50 | datasets.SplitGenerator(name=datasets.Split.TRAIN,
51 | gen_kwargs={'filepaths': file_paths})
52 | ]
53 |
54 | def _generate_examples(self, filepaths: List[str]):
55 | for filepath in filepaths:
56 | with open(filepath, 'r', encoding='utf-8') as f:
57 | for row in f:
58 | try:
59 | data = json.loads(row)
60 | except Exception:
61 | continue
62 | key: int = data['id']
63 | content: List[str] = data['data']
64 | if len(content) % 2 == 1:
65 | content.pop(-1)
66 | if len(content) < 2:
67 | continue
68 | conversations = [{
69 | 'from': 'human' if i % 2 == 0 else 'gpt',
70 | 'value': content[i]
71 | } for i in range(len(content))]
72 | yield key, {'conversations': conversations}
73 |
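74 | # Usage sketch (illustrative): the script can be loaded with `datasets.load_dataset` from
75 | # a checkout of this repository; newer `datasets` releases may also require
76 | # `trust_remote_code=True` for loading scripts.
77 | #
78 | # >>> from datasets import load_dataset
79 | # >>> ds = load_dataset('data/ultra_chat/ultra_chat.py', split='train')
80 | # >>> ds[0]['conversations'][0]['from']
81 | # 'human'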
--------------------------------------------------------------------------------
/server/multi_chat.py:
--------------------------------------------------------------------------------
1 | import sys
2 | from threading import Thread
3 |
4 | import torch
5 | import transformers
6 | from transformers import (AutoModelForCausalLM, AutoTokenizer,
7 | TextIteratorStreamer)
8 |
9 | sys.path.append('../')
10 | from llamatuner.configs import GenerationArguments, ModelInferenceArguments
11 | from llamatuner.utils.model_utils import get_logits_processor
12 |
13 |
14 | def main(model_server_args, generation_args):
15 |     """Multi-turn chat; the full conversation history is kept by simple string concatenation."""
16 | device = 'cuda' if torch.cuda.is_available() else 'cpu'
17 | model = AutoModelForCausalLM.from_pretrained(
18 | model_server_args.model_name_or_path,
19 | cache_dir=model_server_args.cache_dir,
20 | trust_remote_code=True,
21 | low_cpu_mem_usage=True,
22 | torch_dtype=torch.float16,
23 | device_map='auto').to(device).eval()
24 | tokenizer = AutoTokenizer.from_pretrained(
25 | model_server_args.model_name_or_path,
26 | trust_remote_code=True,
27 | use_fast=False,
28 | )
29 |     # Keep track of the full conversation history
30 | historys = tokenizer.bos_token
31 | print('User: ', end='', flush=True)
32 | user_input = input('')
33 | while True:
34 | user_input = '{}'.format(user_input).strip()
35 | historys = historys + user_input
36 | inputs = tokenizer(historys,
37 | return_tensors='pt',
38 | add_special_tokens=False)
39 | inputs = {k: v.to(model.device) for k, v in inputs.items()}
40 |
41 | # Create a TextIteratorStreamer object to stream the response from the model
42 | streamer = TextIteratorStreamer(tokenizer,
43 | timeout=60.0,
44 | skip_prompt=True,
45 | skip_special_tokens=True)
46 |
47 | # Set the arguments for the model's generate() method
48 | gen_kwargs = dict(
49 | inputs,
50 | streamer=streamer,
51 | logits_processor=get_logits_processor(),
52 | **generation_args.to_dict(),
53 | )
54 |
55 | # Start a separate thread to generate the response asynchronously
56 | thread = Thread(target=model.generate, kwargs=gen_kwargs)
57 | thread.start()
58 |
59 | # Print the model name and the response as it is generated
60 | print('Assistant: ', end='', flush=True)
61 | response = ''
62 | for new_text in streamer:
63 | print(new_text, end='', flush=True)
64 | response += new_text
65 |
66 | historys = historys + response
67 | print('\n')
68 | print('User: ', end='', flush=True)
69 | user_input = input('')
70 |
71 |
72 | if __name__ == '__main__':
73 | parser = transformers.HfArgumentParser(
74 | (ModelInferenceArguments, GenerationArguments))
75 | model_server_args, generation_args = parser.parse_args_into_dataclasses()
76 | main(model_server_args, generation_args)
77 |
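78 | # Usage sketch (illustrative; the model path is a placeholder, and the flag names follow
79 | # the ModelInferenceArguments / GenerationArguments dataclass fields):
80 | #
81 | #   python server/multi_chat.py --model_name_or_path path/to/model --max_new_tokens 512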
--------------------------------------------------------------------------------
/data/format_data/convert_oasst1.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import os
4 | import random
5 |
6 |
7 | def json_dump(obj, path):
8 | with open(path, 'w', encoding='utf-8') as f:
9 | json.dump(obj, f, indent=2, ensure_ascii=False)
10 |
11 |
12 | def json_load(in_file):
13 | with open(in_file, 'r') as f:
14 | json_data = json.load(f)
15 | return json_data
16 |
17 |
18 | def convert_oasst1_data(data_dir, output_dir):
19 |     """For OASST1, because the data is a tree in which every user turn may
20 |     receive multiple replies, we have to save every path from the root node
21 |     to an assistant reply (including both leaf and intermediate nodes).
22 |
23 | This results in some of the messages being duplicated among different paths
24 | (instances). Be careful when using this dataset for training. Ideally, you
25 | should only minimize the loss of the last message in each path.
26 | """
27 | conversations = []
28 | with open(os.path.join(data_dir, '2023-04-12_oasst_ready.trees.jsonl'),
29 | 'r') as fin:
30 | for line in fin:
31 | conversations.append(json.loads(line))
32 |
33 | output_path = os.path.join(output_dir, 'oasst1_data.jsonl')
34 |
35 |     # traverse the conversation tree and collect all valid sequences
36 | def dfs(reply, messages, valid_sequences):
37 | if reply['role'] == 'assistant':
38 | messages.append({'role': 'assistant', 'content': reply['text']})
39 | valid_sequences.append(messages[:])
40 | for child in reply['replies']:
41 | dfs(child, messages, valid_sequences)
42 | messages.pop()
43 | elif reply['role'] == 'prompter':
44 | messages.append({'role': 'user', 'content': reply['text']})
45 | for child in reply['replies']:
46 | dfs(child, messages, valid_sequences)
47 | messages.pop()
48 | else:
49 | raise ValueError(f"Unknown role: {reply['role']}")
50 |
51 | with open(output_path, 'w') as fout:
52 | example_cnt = 0
53 | for _, conversation in enumerate(conversations):
54 | valid_sequences = []
55 | dfs(conversation['prompt'], [], valid_sequences)
56 | for sequence in valid_sequences:
57 | fout.write(
58 | json.dumps({
59 | 'dataset': 'oasst1',
60 | 'id': f'oasst1_{example_cnt}',
61 | 'messages': sequence
62 | }) + '\n')
63 | example_cnt += 1
64 |
65 |
66 | if __name__ == '__main__':
67 | arg_parser = argparse.ArgumentParser()
68 | arg_parser.add_argument('--raw_data_dir',
69 | type=str,
70 | default='data/downloads')
71 | arg_parser.add_argument('--output_dir', type=str, default='data/processed')
72 | arg_parser.add_argument('--seed', type=int, default=42)
73 | args = arg_parser.parse_args()
74 | random.seed(args.seed)
75 |
76 | convert_oasst1_data(data_dir=args.raw_data_dir, output_dir=args.output_dir)
77 |
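78 | # Usage sketch (illustrative): expects the `2023-04-12_oasst_ready.trees.jsonl` export in
79 | # --raw_data_dir and writes `oasst1_data.jsonl` to --output_dir.
80 | #
81 | #   python data/format_data/convert_oasst1.py --raw_data_dir data/downloads --output_dir data/processed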
--------------------------------------------------------------------------------
/llamatuner/cli.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 | import subprocess
4 | import sys
5 | from enum import Enum, unique
6 |
7 | sys.path.append(os.getcwd())
8 |
9 | from llamatuner import launcher
10 | from llamatuner.train.tuner import run_exp
11 | from llamatuner.utils.env_utils import VERSION, print_env_info
12 | from llamatuner.utils.logger_utils import get_logger
13 | from llamatuner.utils.misc import get_device_count
14 |
15 | USAGE = (
16 | '-' * 70 + '\n' +
17 | '| Usage: |\n' +
18 | '| llamatuner-cli train -h: train models |\n' +
19 | '-' * 70)
20 |
21 | WELCOME = ('-' * 58 + '\n' +
22 | '| Welcome to LLamaTuner, version {}'.format(VERSION) + ' ' *
23 | (21 - len(VERSION)) + '|\n|' + ' ' * 56 + '|\n' +
24 | '| Project page: https://github.com/hiyouga/LLaMA-Factory |\n' +
25 | '-' * 58)
26 |
27 | logger = get_logger(__name__)
28 |
29 |
30 | @unique
31 | class Command(str, Enum):
32 | API = 'api'
33 | CHAT = 'chat'
34 | ENV = 'env'
35 | EVAL = 'eval'
36 | EXPORT = 'export'
37 | TRAIN = 'train'
38 | WEBDEMO = 'webchat'
39 | WEBUI = 'webui'
40 | VER = 'version'
41 | HELP = 'help'
42 |
43 |
44 | def main() -> None:
45 | command = sys.argv.pop(1) if len(sys.argv) != 1 else Command.HELP
46 | if command == Command.ENV:
47 | print_env_info()
48 | elif command == Command.TRAIN:
49 | force_torchrun = os.environ.get('FORCE_TORCHRUN',
50 | '0').lower() in ['true', '1']
51 | if force_torchrun or get_device_count() > 1:
52 | master_addr = os.environ.get('MASTER_ADDR', '127.0.0.1')
53 | master_port = os.environ.get('MASTER_PORT',
54 | str(random.randint(20001, 29999)))
55 | logger.info('Initializing distributed tasks at: {}:{}'.format(
56 | master_addr, master_port))
57 | process = subprocess.run(
58 | ('torchrun --nnodes {nnodes} --node_rank {node_rank} --nproc_per_node {nproc_per_node} '
59 | '--master_addr {master_addr} --master_port {master_port} {file_name} {args}'
60 | ).format(
61 | nnodes=os.environ.get('NNODES', '1'),
62 | node_rank=os.environ.get('RANK', '0'),
63 | nproc_per_node=os.environ.get('NPROC_PER_NODE',
64 | str(get_device_count())),
65 | master_addr=master_addr,
66 | master_port=master_port,
67 | file_name=launcher.__file__,
68 | args=' '.join(sys.argv[1:]),
69 | ),
70 | shell=True,
71 | )
72 | sys.exit(process.returncode)
73 | else:
74 | run_exp()
75 | elif command == Command.VER:
76 | print(WELCOME)
77 | elif command == Command.HELP:
78 | print(USAGE)
79 | else:
80 | raise NotImplementedError('Unknown command: {}'.format(command))
81 |
82 |
83 | if __name__ == '__main__':
84 | main()
85 |
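86 | # Usage sketch (illustrative; the exact training arguments are whatever `run_exp` and the
87 | # launcher accept, e.g. a YAML config such as those under `examples/`):
88 | #
89 | #   llamatuner-cli env                                     # print environment info
90 | #   llamatuner-cli train <config>.yaml                     # single- or multi-GPU training
91 | #   FORCE_TORCHRUN=1 llamatuner-cli train <config>.yaml    # force a torchrun launch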
--------------------------------------------------------------------------------
/llamatuner/train/apply_lora.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from typing import Optional, Tuple
3 |
4 | import torch
5 | from peft import PeftModel
6 | from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedModel
7 |
8 | from llamatuner.utils.logger_utils import get_logger
9 |
10 | logger = get_logger(__name__)
11 |
12 |
13 | def apply_lora(
14 | base_model_path: str,
15 | lora_model_path: str,
16 |     target_model_path: Optional[str] = None,
17 |     cache_dir: Optional[str] = None,
18 | trust_remote_code: bool = True,
19 | ) -> Tuple[AutoModelForCausalLM, AutoTokenizer]:
20 | """Applies the LoRA adapter to a base model and saves the resulting target
21 | model (optional).
22 |
23 | Args:
24 | base_model_path (str): The path to the base model to which the LoRA adapter will be applied.
25 | lora_model_path (str): The path to the LoRA adapter.
26 |         target_model_path (Optional[str]): If provided, the merged model and tokenizer are saved to this path.
27 | cache_dir (str): The path to the cache directory.
28 | trust_remote_code (bool): Whether to trust remote code when downloading the model.
29 |
30 | Returns:
31 | Tuple[AutoModelForCausalLM, AutoTokenizer]: A tuple containing the target model and its tokenizer.
32 | """
33 | # Load the base model and tokenizer
34 | logger.info(f'Loading the base model from {base_model_path}')
35 | # Set configuration kwargs for tokenizer.
36 | config_kwargs = {
37 | 'cache_dir': cache_dir,
38 | 'trust_remote_code': trust_remote_code,
39 | }
40 |
41 | base_model: PreTrainedModel = AutoModelForCausalLM.from_pretrained(
42 | base_model_path,
43 | device_map='auto',
44 | torch_dtype=torch.float16,
45 | low_cpu_mem_usage=True,
46 | **config_kwargs,
47 | )
48 |
49 | # Load the tokenizer
50 | logger.info(f'Loading the tokenizer from {base_model_path}')
51 |     # Use the slow tokenizer (use_fast=False) for compatibility with LLaMA-style tokenizers
52 | tokenizer = AutoTokenizer.from_pretrained(
53 | base_model_path,
54 | use_fast=False,
55 | **config_kwargs,
56 | )
57 |
58 | # Load the LoRA adapter
59 | logger.info(f'Loading the LoRA adapter from {lora_model_path}')
60 | model: PreTrainedModel = PeftModel.from_pretrained(base_model,
61 | lora_model_path)
62 | logger.info('Applying the LoRA to base model')
63 | model = model.merge_and_unload()
64 |
65 | if target_model_path is not None:
66 | logger.info(f'Saving the target model to {target_model_path}')
67 | model.save_pretrained(target_model_path)
68 | tokenizer.save_pretrained(target_model_path)
69 |
70 | return model, tokenizer
71 |
72 |
73 | if __name__ == '__main__':
74 | parser = argparse.ArgumentParser()
75 | parser.add_argument('--base-model-path', type=str, required=True)
76 | parser.add_argument('--target-model-path', type=str, default=None)
77 | parser.add_argument('--lora-model-path', type=str, required=True)
78 | args = parser.parse_args()
79 |
80 | apply_lora(
81 | base_model_path=args.base_model_path,
82 | lora_model_path=args.lora_model_path,
83 | target_model_path=args.target_model_path,
84 | )
85 |
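86 | # Usage sketch (illustrative; all paths are placeholders):
87 | #
88 | #   python llamatuner/train/apply_lora.py \
89 | #       --base-model-path path/to/base_model \
90 | #       --lora-model-path path/to/lora_adapter \
91 | #       --target-model-path path/to/merged_model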
--------------------------------------------------------------------------------
/llamatuner/data/dataset_factory/reward_dataset.py:
--------------------------------------------------------------------------------
1 | from typing import Dict
2 |
3 | import torch
4 | from datasets import load_dataset
5 | from torch.utils.data import Dataset
6 | from transformers import PreTrainedTokenizer
7 |
8 |
9 | class PairwiseDataset(Dataset):
10 | """Dataset class for pairwise ranking tasks.
11 |
12 | Args:
13 | data_path: Path to the dataset.
14 | tokenizer: The tokenizer used to encode the input text.
15 | max_length: Maximum sequence length for the encoded inputs.
16 | """
17 |
18 | def __init__(self, data_path: str, tokenizer: PreTrainedTokenizer,
19 | split: str, max_length: int):
20 |
21 | self.pairs = self.create_comparison_dataset(data_path, split)
22 | self.tokenizer = tokenizer
23 | self.max_length = max_length
24 |
25 | def __len__(self) -> int:
26 | return len(self.pairs)
27 |
28 | def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
29 | if idx < 0 or idx >= len(self.pairs):
30 | raise IndexError(
31 |                 f'Index {idx} out of range for PairwiseDataset with length {len(self)}'
32 | )
33 | pair = self.pairs[idx]
34 | chosen_example, rejected_example = pair['chosen'], pair['rejected']
35 |
36 | chosen_encodings_dict = self.tokenizer(chosen_example,
37 | truncation=True,
38 | max_length=self.max_length,
39 | padding='max_length')
40 | rejected_encodings_dict = self.tokenizer(rejected_example,
41 | truncation=True,
42 | max_length=self.max_length,
43 | padding='max_length')
44 | encodings_input = {}
45 | encodings_input['chosen_input_ids'] = chosen_encodings_dict[
46 | 'input_ids']
47 | encodings_input['chosen_attention_mask'] = chosen_encodings_dict[
48 | 'attention_mask']
49 | encodings_input['rejected_input_ids'] = rejected_encodings_dict[
50 | 'input_ids']
51 | encodings_input['rejected_attention_mask'] = rejected_encodings_dict[
52 | 'attention_mask']
53 | encodings_input['labels'] = 1.0
54 |
55 | encodings_input = {
56 | key: torch.tensor(val)
57 | for key, val in encodings_input.items()
58 | }
59 |
60 | return encodings_input
61 |
62 | def create_comparison_dataset(self, path: str, split: str = 'train'):
63 | dataset = load_dataset(path, split=split)
64 | pairs = []
65 | for prompt, chosen_summary, rejected_summary in zip(
66 | dataset['prompt'], dataset['chosen'], dataset['rejected']):
67 | pair = {}
68 | if chosen_summary == rejected_summary:
69 | continue
70 | if len(chosen_summary.split()) < 5 or len(
71 | rejected_summary.split()) < 5:
72 | continue
73 |
74 | pair[
75 | 'chosen'] = '<|startoftext|>' + prompt + '\n' + chosen_summary + '<|endoftext|>'
76 | pair[
77 | 'rejected'] = '<|startoftext|>' + prompt + '\n' + rejected_summary + '<|endoftext|>'
78 | pairs.append(pair)
79 |
80 | return pairs
81 |
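82 | # Usage sketch (illustrative): build a pairwise dataset from a preference corpus exposing
83 | # `prompt` / `chosen` / `rejected` columns; the dataset name below is an assumption, not a
84 | # requirement of this class.
85 | #
86 | # >>> from transformers import AutoTokenizer
87 | # >>> tok = AutoTokenizer.from_pretrained('gpt2'); tok.pad_token = tok.eos_token
88 | # >>> ds = PairwiseDataset('CarperAI/openai_summarize_comparisons', tok,
89 | # ...                      split='train', max_length=512)
90 | # >>> sorted(ds[0].keys())
91 | # ['chosen_attention_mask', 'chosen_input_ids', 'labels', 'rejected_attention_mask', 'rejected_input_ids']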
--------------------------------------------------------------------------------
/llamatuner/data/data_collator.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from typing import Any, Dict, Sequence
3 |
4 | import torch
5 | from transformers import DataCollatorForSeq2Seq
6 |
7 |
8 | @dataclass
9 | class PairwiseDataCollatorWithPadding(DataCollatorForSeq2Seq):
10 | r"""
11 | Data collator for pairwise data.
12 | """
13 |
14 | def __call__(
15 | self, features: Sequence[Dict[str,
16 | Any]]) -> Dict[str, 'torch.Tensor']:
17 | r"""
18 | Pads batched data to the longest sequence in the batch.
19 |
20 | We generate 2 * n examples where the first n examples represent chosen examples and
21 | the last n examples represent rejected examples.
22 | """
23 | concatenated_features = []
24 | for key in ('chosen', 'rejected'):
25 | for feature in features:
26 | target_feature = {
27 | 'input_ids': feature['{}_input_ids'.format(key)],
28 | 'attention_mask': feature['{}_attention_mask'.format(key)],
29 | 'labels': feature['{}_labels'.format(key)],
30 | }
31 | if 'pixel_values' in feature:
32 | target_feature['pixel_values'] = feature['pixel_values']
33 |
34 | if '{}_token_type_ids'.format(key) in feature:
35 | target_feature['token_type_ids'] = feature[
36 | '{}_token_type_ids'.format(key)]
37 |
38 | concatenated_features.append(target_feature)
39 |
40 | return super().__call__(concatenated_features)
41 |
42 |
43 | @dataclass
44 | class KTODataCollatorWithPadding(DataCollatorForSeq2Seq):
45 | r"""
46 | Data collator for KTO data.
47 | """
48 |
49 | def __call__(
50 | self, features: Sequence[Dict[str,
51 | Any]]) -> Dict[str, 'torch.Tensor']:
52 | target_features = []
53 | kl_features = []
54 | kto_tags = []
55 | for feature in features:
56 | target_feature = {
57 | 'input_ids': feature['input_ids'],
58 | 'attention_mask': feature['attention_mask'],
59 | 'labels': feature['labels'],
60 | }
61 | kl_feature = {
62 | 'input_ids': feature['kl_input_ids'],
63 | 'attention_mask': feature['kl_attention_mask'],
64 | 'labels': feature['kl_labels'],
65 | }
66 | if 'pixel_values' in feature:
67 | target_feature['pixel_values'] = feature['pixel_values']
68 |
69 | if 'token_type_ids' in feature:
70 | target_feature['token_type_ids'] = feature['token_type_ids']
71 | kl_feature['token_type_ids'] = feature['kl_token_type_ids']
72 |
73 | target_features.append(target_feature)
74 | kl_features.append(kl_feature)
75 | kto_tags.append(feature['kto_tags'])
76 |
77 | batch = super().__call__(target_features)
78 | kl_batch = super().__call__(kl_features)
79 | batch['kl_input_ids'] = kl_batch['input_ids']
80 | batch['kl_attention_mask'] = kl_batch['attention_mask']
81 | batch['kl_labels'] = kl_batch['labels']
82 | if 'token_type_ids' in batch:
83 | batch['kl_token_type_ids'] = kl_batch['token_type_ids']
84 |
85 | batch['kto_tags'] = torch.tensor(kto_tags)
86 | return batch
87 |
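88 | # Usage sketch (illustrative): features produced by the pairwise processor carry
89 | # `{chosen,rejected}_{input_ids,attention_mask,labels}` keys; the collator flattens them
90 | # into a single 2 * n batch (chosen first, then rejected), padded by DataCollatorForSeq2Seq.
91 | # The tokenizer choice below is a placeholder.
92 | #
93 | # >>> from transformers import AutoTokenizer
94 | # >>> tok = AutoTokenizer.from_pretrained('gpt2'); tok.pad_token = tok.eos_token
95 | # >>> collator = PairwiseDataCollatorWithPadding(tokenizer=tok, label_pad_token_id=-100)
96 | # >>> feats = [{'chosen_input_ids': [1, 2], 'chosen_attention_mask': [1, 1], 'chosen_labels': [1, 2],
97 | # ...           'rejected_input_ids': [3], 'rejected_attention_mask': [1], 'rejected_labels': [3]}]
98 | # >>> collator(feats)['input_ids'].shape[0]  # 2 * n examples
99 | # 2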
--------------------------------------------------------------------------------
/data/format_data/convert_vicuna.py:
--------------------------------------------------------------------------------
1 | import json
2 | import sys
3 |
4 | from datasets import load_dataset
5 |
6 | sys.path.append('../../')
7 |
8 | from llamatuner.data.data_utils import extract_default_prompt_dataset
9 |
10 |
11 | def json_dump(obj, path):
12 | with open(path, 'w', encoding='utf-8') as f:
13 | json.dump(obj, f, indent=2, ensure_ascii=False)
14 |
15 |
16 | def json_load(in_file):
17 | with open(in_file, 'r') as f:
18 | json_data = json.load(f)
19 | return json_data
20 |
21 |
22 | def valid_keys(keys):
23 | for k in ['input', 'output']:
24 | if k not in keys:
25 | return False
26 | return True
27 |
28 |
29 | def remove_unused_columns(dataset):
30 | """Remove columns not named 'input' or 'output'."""
31 | dataset = dataset.remove_columns([
32 | col for col in dataset.column_names if col not in ['input', 'output']
33 | ])
34 | return dataset
35 |
36 |
37 | def convert_alpaca_vicuna(in_file: str, out_file: str = None):
38 | raw_dataset = load_dataset('json', data_files=in_file)['train']
39 | raw_dataset = raw_dataset.map(extract_default_prompt_dataset)
40 |
41 | collect_data = []
42 | for i, content in enumerate(raw_dataset):
43 | prompt = content['input']
44 | response = content['output']
45 |
46 | collect_data.append({
47 | 'id':
48 | f'alpaca_{i}',
49 | 'conversations': [
50 | {
51 | 'from': 'human',
52 | 'value': prompt
53 | },
54 | {
55 | 'from': 'gpt',
56 | 'value': response
57 | },
58 | ],
59 | })
60 | print(f'Original: {len(raw_dataset)}, Converted: {len(collect_data)}')
61 | json_dump(collect_data, out_file)
62 | return collect_data
63 |
64 |
65 | if __name__ == '__main__':
66 | in_file = '/home/robin/prompt_data/100PoisonMpts/train_alpaca.json'
67 | out_file = '/home/robin/prompt_data/100PoisonMpts/train_vicuna.json'
68 | collect_data = convert_alpaca_vicuna(in_file, out_file)
69 |
70 | data_path = '/home/robin/prompt_data/CValues-Comparison/test_alpaca.json'
71 | out_path = '/home/robin/prompt_data/CValues-Comparison/test_vicuna.json'
72 | convert_alpaca_vicuna(data_path, out_file=out_path)
73 |
74 | data_path = '/home/robin/prompt_data/CValues-Comparison/train_alpaca.json'
75 | out_path = '/home/robin/prompt_data/CValues-Comparison/train_vicuna.json'
76 | convert_alpaca_vicuna(data_path, out_file=out_path)
77 |
78 | data_path = '/home/robin/prompt_data/HuatuoGPT-sft-data-v1/HuatuoGPT_alpaca.json'
79 |     out_path = '/home/robin/prompt_data/HuatuoGPT-sft-data-v1/HuatuoGPT_vicuna.json'
80 | convert_alpaca_vicuna(data_path, out_file=out_path)
81 |
82 | data_path = '/home/robin/prompt_data/Safety-Prompts/attack_scenarios_alpaca.json'
83 | out_path = '/home/robin/prompt_data/Safety-Prompts/attack_scenarios_vicuna.json'
84 | convert_alpaca_vicuna(data_path, out_file=out_path)
85 |
86 | data_path = '/home/robin/prompt_data/Safety-Prompts/safety_scenarios_alpaca.json'
87 | out_path = '/home/robin/prompt_data/Safety-Prompts/safety_scenarios_vicuna.json'
88 | convert_alpaca_vicuna(data_path, out_file=out_path)
89 |
90 | data_path = '/home/robin/prompt_data/COIG/train_alpaca.json'
91 | out_path = '/home/robin/prompt_data/COIG/train_vicuna.json'
92 | convert_alpaca_vicuna(data_path, out_file=out_path)
93 |
--------------------------------------------------------------------------------
/llamatuner/configs/generating_args.py:
--------------------------------------------------------------------------------
1 | from dataclasses import asdict, dataclass, field
2 | from typing import Any, Dict, Optional
3 |
4 | from transformers import GenerationConfig
5 |
6 |
7 | @dataclass
8 | class GeneratingArguments:
9 | """Arguments pertaining to specify the model generation parameters."""
10 |
11 | # Generation strategy
12 |     # Whether to sample; greedy decoding is used otherwise
13 | do_sample: Optional[bool] = field(
14 | default=True,
15 | metadata={
16 | 'help':
17 | 'Whether or not to use sampling, use greedy decoding otherwise.'
18 | },
19 | )
20 | # Hyperparameters for logit manipulation
21 |     # Softmax temperature used to modulate the next-token distribution
22 | temperature: Optional[float] = field(
23 | default=1.0,
24 | metadata={
25 | 'help': 'The value used to modulate the next token probabilities.'
26 | },
27 | )
28 |     # Nucleus (top-p) sampling: keep the smallest token set whose cumulative probability reaches top_p and sample from it
29 | top_p: Optional[float] = field(
30 | default=1.0,
31 | metadata={
32 | 'help':
33 | 'The smallest set of most probable tokens with probabilities that add up to top_p or higher are kept.'
34 | },
35 | )
36 |     # Top-k sampling: keep only the k highest-probability tokens
37 | top_k: Optional[int] = field(
38 | default=50,
39 | metadata={
40 | 'help':
41 | 'The number of highest probability vocabulary tokens to keep for top-k filtering.'
42 | },
43 | )
44 |     # Number of beams for beam search
45 | num_beams: Optional[int] = field(
46 | default=1,
47 | metadata={
48 | 'help': 'Number of beams for beam search. 1 means no beam search.'
49 | },
50 | )
51 |     # Maximum total sequence length; overridden by max_new_tokens when set
52 | max_length: Optional[int] = field(
53 | default=1024,
54 | metadata={
55 | 'help':
56 | 'The maximum length the generated tokens can have. It can be overridden by max_new_tokens.'
57 | },
58 | )
59 |     # Maximum number of newly generated tokens
60 | max_new_tokens: Optional[int] = field(
61 | default=1024,
62 | metadata={
63 | 'help':
64 |             'Maximum number of new tokens to be generated in evaluation or prediction loops '
65 | 'if predict_with_generate is set.'
66 | },
67 | )
68 |     # Repetition penalty factor
69 | repetition_penalty: Optional[float] = field(
70 | default=1.0,
71 | metadata={
72 | 'help':
73 | 'The parameter for repetition penalty. 1.0 means no penalty.'
74 | })
75 |     # Length penalty factor (used with beam-based generation)
76 | length_penalty: Optional[float] = field(
77 | default=1.0,
78 | metadata={
79 | 'help':
80 | 'Exponential penalty to the length that is used with beam-based generation.'
81 | })
82 | default_system: Optional[str] = field(
83 | default=None,
84 | metadata={'help': 'Default system message to use in chat completion.'},
85 | )
86 | skip_special_tokens: bool = field(
87 | default=True,
88 | metadata={
89 | 'help': 'Whether or not to remove special tokens in the decoding.'
90 | },
91 | )
92 |
93 | def to_dict(self, obey_generation_config: bool = False) -> Dict[str, Any]:
94 | args = asdict(self)
95 | if args.get('max_new_tokens', -1) > 0:
96 | args.pop('max_length', None)
97 | else:
98 | args.pop('max_new_tokens', None)
99 |
100 | if obey_generation_config:
101 | generation_config = GenerationConfig()
102 | for key in list(args.keys()):
103 | if not hasattr(generation_config, key):
104 | args.pop(key)
105 |
106 | return args
107 |
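108 | # Usage sketch (illustrative): the resulting dict is meant to be splatted into
109 | # `model.generate(**gen_kwargs)`. Only one of max_length / max_new_tokens survives
110 | # `to_dict()`, as implemented above.
111 | #
112 | # >>> gen_args = GeneratingArguments(temperature=0.7, top_p=0.9, max_new_tokens=256)
113 | # >>> gen_kwargs = gen_args.to_dict(obey_generation_config=True)
114 | # >>> 'max_new_tokens' in gen_kwargs and 'max_length' not in gen_kwargs
115 | # True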
--------------------------------------------------------------------------------
/data/dataset_info.yaml:
--------------------------------------------------------------------------------
1 | # The dataset_info.yaml file contains the information of the datasets used in the experiments.
2 | ## pretrain dataset
3 | c4_demo:
4 | file_name: c4_demo.json
5 | columns:
6 | prompt: text
7 |
8 | open-web-math:
9 | hf_hub_url: open-web-math/open-web-math
10 | split: train[10:20]
11 | formatting: alpaca
12 | columns:
13 | prompt: text
14 |
15 | ## sft dataset
16 | alpaca:
17 | hf_hub_url: tatsu-lab/alpaca
18 | formatting: alpaca
19 |
20 | alpaca-clean:
21 | hf_hub_url: yahma/alpaca-cleaned
22 | formatting: alpaca
23 |
24 | dolly-15k:
25 | hf_hub_url: databricks/databricks-dolly-15k
26 | formatting: alpaca
27 |
28 | guanaco:
29 | hf_hub_url: JosephusCheung/GuanacoDataset
30 | ms_hub_url: AI-ModelScope/GuanacoDataset
31 | formatting: alpaca
32 |
33 | openassistant-guanaco:
34 | hf_hub_url: timdettmers/openassistant-guanaco
35 | formatting: alpaca
36 |
37 | belle_0.5m:
38 | hf_hub_url: BelleGroup/train_0.5M_CN
39 | ms_hub_url: AI-ModelScope/train_0.5M_CN
40 | formatting: alpaca
41 |
42 | belle_1m:
43 | hf_hub_url: BelleGroup/train_1M_CN
44 | ms_hub_url: AI-ModelScope/train_1M_CN
45 | formatting: alpaca
46 |
47 | belle_2m:
48 | hf_hub_url: BelleGroup/train_2M_CN
49 | ms_hub_url: AI-ModelScope/train_2M_CN
50 | formatting: alpaca
51 |
52 | belle_dialog:
53 | hf_hub_url: BelleGroup/generated_chat_0.4M
54 | ms_hub_url: AI-ModelScope/generated_chat_0.4M
55 | formatting: alpaca
56 |
57 | belle_math:
58 | hf_hub_url: BelleGroup/school_math_0.25M
59 | ms_hub_url: AI-ModelScope/school_math_0.25M
60 | formatting: alpaca
61 |
62 | belle_multiturn:
63 | hf_hub_url: BelleGroup/multi_turn_0.5M
64 | formatting: sharegpt
65 | columns:
66 | prompt: instruction
67 | response: output
68 | history: history
69 |
70 | firefly:
71 | hf_hub_url: YeungNLP/firefly-train-1.1M
72 | formatting: alpaca
73 | columns:
74 | prompt: input
75 | response: target
76 |
77 | codealpaca:
78 | hf_hub_url: sahil2801/CodeAlpaca-20k
79 | ms_hub_url: AI-ModelScope/CodeAlpaca-20k
80 | formatting: alpaca
81 |
82 | alpaca_cot:
83 | hf_hub_url: QingyiSi/Alpaca-CoT
84 | ms_hub_url: AI-ModelScope/Alpaca-CoT
85 |
86 | webqa:
87 | hf_hub_url: suolyer/webqa
88 | ms_hub_url: AI-ModelScope/webqa
89 | formatting: alpaca
90 | columns:
91 | prompt: input
92 | response: output
93 |
94 | # multi-turn datasets
95 | evol_instruct:
96 | hf_hub_url: MaziyarPanahi/WizardLM_evol_instruct_V2_196k
97 | ms_hub_url: AI-ModelScope/WizardLM_evol_instruct_V2_196k
98 | formatting: sharegpt
99 |
100 | ultrachat_200k:
101 | hf_hub_url: HuggingFaceH4/ultrachat_200k
102 | ms_hub_url: AI-ModelScope/ultrachat_200k
103 | formatting: sharegpt
104 | columns:
105 | messages: messages
106 | tags:
107 | role_tag: role
108 | content_tag: content
109 | user_tag: user
110 | assistant_tag: assistant
111 |
112 | lmsys_chat:
113 | hf_hub_url: lmsys/lmsys-chat-1m
114 | ms_hub_url: AI-ModelScope/lmsys-chat-1m
115 | formatting: sharegpt
116 | columns:
117 | messages: conversation
118 | tags:
119 | role_tag: role
120 | content_tag: content
121 | user_tag: human
122 | assistant_tag: assistant
123 |
124 | hh_rlhf_en:
125 | script_url: hh_rlhf_en
126 | ranking: true
127 | columns:
128 | prompt: instruction
129 | chosen: chosen
130 | rejected: rejected
131 | history: history
132 |
133 | orca_pairs:
134 | hf_hub_url: Intel/orca_dpo_pairs
135 | ranking: true
136 | columns:
137 | prompt: question
138 | chosen: chosen
139 | rejected: rejected
140 | system: system
141 |
142 | kto_mix_en:
143 | hf_hub_url: argilla/kto-mix-15k
144 | formatting: sharegpt
145 | columns:
146 | messages: completion
147 | kto_tag: label
148 | tags:
149 | role_tag: role
150 | content_tag: content
151 | user_tag: user
152 | assistant_tag: assistant
153 |
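154 | ## Example (illustrative only): a local alpaca-format file would be registered with an
155 | ## entry like the one below; `my_local_sft.json` is a placeholder, not a file shipped here.
156 | # my_local_sft:
157 | #   file_name: my_local_sft.json
158 | #   formatting: alpaca
159 | #   columns:
160 | #     prompt: instruction
161 | #     response: output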
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 |
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/#use-with-ide
110 | .pdm.toml
111 |
112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113 | __pypackages__/
114 |
115 | # Celery stuff
116 | celerybeat-schedule
117 | celerybeat.pid
118 |
119 | # SageMath parsed files
120 | *.sage.py
121 |
122 | # Environments
123 | .env
124 | .venv
125 | env/
126 | venv/
127 | ENV/
128 | env.bak/
129 | venv.bak/
130 |
131 | # Spyder project settings
132 | .spyderproject
133 | .spyproject
134 |
135 | # Rope project settings
136 | .ropeproject
137 |
138 | # mkdocs documentation
139 | /site
140 |
141 | # mypy
142 | .mypy_cache/
143 | .dmypy.json
144 | dmypy.json
145 |
146 | # Pyre type checker
147 | .pyre/
148 |
149 | # pytype static type analyzer
150 | .pytype/
151 |
152 | # Cython debug symbols
153 | cython_debug/
154 |
155 | # PyCharm
156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158 | # and can be added to the global gitignore or merged into this file. For a more nuclear
159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160 | #.idea/
161 | *.json
162 | wandb/
163 | hf_hub
164 | logs/
165 | docs/
166 | checkpoints/
167 | work_dir/
168 | work_dirs/
169 | output/
170 | outputs/
172 |
--------------------------------------------------------------------------------
/llamatuner/model/callbacks/wandb_callback.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | import torch
3 | import transformers
4 | from transformers.integrations import WandbCallback
5 |
6 |
7 | def decode_predictions(tokenizer, predictions):
8 | labels = tokenizer.batch_decode(predictions.label_ids)
9 | logits = predictions.predictions.argmax(axis=-1)
10 | prediction_text = tokenizer.batch_decode(logits)
11 | return {'labels': labels, 'predictions': prediction_text}
12 |
13 |
14 | class WandbPredictionProgressCallback(WandbCallback):
15 | """Custom WandbCallback to log model predictions during training.
16 |
17 | This callback logs model predictions and labels to a wandb.Table at each
18 | logging step during training. It allows to visualize the
19 | model predictions as the training progresses.
20 |
21 | Attributes:
22 | trainer (Trainer): The Hugging Face Trainer instance.
23 | tokenizer (AutoTokenizer): The tokenizer associated with the model.
24 | sample_dataset (Dataset): A subset of the validation dataset
25 | for generating predictions.
26 | num_samples (int, optional): Number of samples to select from
27 | the validation dataset for generating predictions. Defaults to 100.
28 | freq (int, optional): Frequency of logging. Defaults to 2.
29 |
30 |
31 | Example:
32 | ```python
33 | # First, instantiate the Trainer
34 | trainer = Trainer(
35 | model=model,
36 | args=training_args,
37 | train_dataset=lm_datasets["train"],
38 | eval_dataset=lm_datasets["validation"],
39 | )
40 |
41 | # Instantiate the WandbPredictionProgressCallback
42 | progress_callback = WandbPredictionProgressCallback(
43 | trainer=trainer,
44 | tokenizer=tokenizer,
45 | val_dataset=lm_dataset["validation"],
46 | num_samples=10,
47 | freq=2,
48 | )
49 |
50 | # Add the callback to the trainer
51 | trainer.add_callback(progress_callback)
52 | ```
53 | """
54 |
55 | def __init__(
56 | self,
57 | trainer: transformers.Trainer,
58 | tokenizer: transformers.AutoTokenizer,
59 | val_dataset: torch.utils.data.Dataset,
60 | num_samples: int = 100,
61 | freq: int = 2,
62 | ) -> None:
63 | """Initializes the WandbPredictionProgressCallback instance.
64 |
65 | Args:
66 | trainer (Trainer): The Hugging Face Trainer instance.
67 | tokenizer (AutoTokenizer): The tokenizer associated
68 | with the model.
69 | val_dataset (Dataset): The validation dataset.
70 | num_samples (int, optional): Number of samples to select from
71 | the validation dataset for generating predictions.
72 | Defaults to 100.
73 | freq (int, optional): Frequency of logging. Defaults to 2.
74 | """
75 | super().__init__()
76 | self.trainer = trainer
77 | self.tokenizer = tokenizer
78 | # select a subset of the validation dataset
79 | indices = torch.randperm(len(val_dataset))[:num_samples]
80 | self.sample_dataset = torch.utils.data.Subset(val_dataset, indices)
81 | self.freq = freq
82 |
83 | def on_evaluate(self, args, state, control, **kwargs):
84 | super().on_evaluate(args, state, control, **kwargs)
85 | # control the frequency of logging by logging the predictions
86 | # every `freq` epochs
87 | if state.epoch % self.freq == 0:
88 | # generate predictions
89 | predictions = self.trainer.predict(self.sample_dataset)
90 | # decode predictions and labels
91 | predictions = decode_predictions(self.tokenizer, predictions)
92 | # add predictions to a wandb.Table
93 | predictions_df = pd.DataFrame(predictions)
94 | predictions_df['epoch'] = state.epoch
95 | records_table = self._wandb.Table(dataframe=predictions_df)
96 | # log the table to wandb
97 | self._wandb.log({'sample_predictions': records_table})
98 |
--------------------------------------------------------------------------------
/data/hh_rlhf_en/hh_rlhf_en.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | from typing import List
4 |
5 | import datasets
6 |
7 | _HF_ENDPOINT = os.getenv('HF_ENDPOINT', 'https://huggingface.co')
8 | _DESCRIPTION = 'Human preference data about helpfulness and harmlessness.'
9 | _CITATION = ''
10 | _HOMEPAGE = '{}/datasets/Anthropic/hh-rlhf'.format(_HF_ENDPOINT)
11 | _LICENSE = 'mit'
12 | _URL = '{}/datasets/Anthropic/hh-rlhf/resolve/main/'.format(_HF_ENDPOINT)
13 | _URLS = {
14 | 'train': [
15 | _URL + 'harmless-base/train.jsonl.gz',
16 | _URL + 'helpful-base/train.jsonl.gz',
17 | _URL + 'helpful-online/train.jsonl.gz',
18 | _URL + 'helpful-rejection-sampled/train.jsonl.gz',
19 | ],
20 | 'test': [
21 | _URL + 'harmless-base/test.jsonl.gz',
22 | _URL + 'helpful-base/test.jsonl.gz',
23 | _URL + 'helpful-online/test.jsonl.gz',
24 | _URL + 'helpful-rejection-sampled/test.jsonl.gz',
25 | ],
26 | }
27 |
28 |
29 | class HhRlhfEn(datasets.GeneratorBasedBuilder):
30 | VERSION = datasets.Version('0.0.0')
31 |
32 | def _info(self) -> datasets.DatasetInfo:
33 | features = datasets.Features({
34 | 'instruction':
35 | datasets.Value('string'),
36 | 'chosen':
37 | datasets.Value('string'),
38 | 'rejected':
39 | datasets.Value('string'),
40 | 'history':
41 | datasets.Sequence(datasets.Sequence(datasets.Value('string'))),
42 | })
43 | return datasets.DatasetInfo(
44 | description=_DESCRIPTION,
45 | features=features,
46 | homepage=_HOMEPAGE,
47 | license=_LICENSE,
48 | citation=_CITATION,
49 | )
50 |
51 | def _split_generators(self, dl_manager: datasets.DownloadManager):
52 | file_path = dl_manager.download_and_extract(_URLS)
53 | return [
54 | datasets.SplitGenerator(
55 | name=datasets.Split.TRAIN,
56 | gen_kwargs={'filepaths': file_path['train']}),
57 | datasets.SplitGenerator(
58 | name=datasets.Split.TEST,
59 | gen_kwargs={'filepaths': file_path['test']}),
60 | ]
61 |
62 | def _generate_examples(self, filepaths: List[str]):
63 | key = 0
64 | for filepath in filepaths:
65 | with open(filepath, 'r', encoding='utf-8') as f:
66 | for row in f:
67 | data = json.loads(row)
68 | chosen = data['chosen']
69 | rejected = data['rejected']
70 |
71 | assist_idx = rejected.rfind('\n\nAssistant: ')
72 | r_reject = rejected[assist_idx + 13:].strip()
73 | assist_idx = chosen.rfind('\n\nAssistant: ')
74 | r_accept = chosen[assist_idx + 13:].strip()
75 |
76 | human_idx = chosen.rfind('\n\nHuman: ')
77 | query = chosen[human_idx + 9:assist_idx].strip()
78 | prompt = chosen[:human_idx]
79 | history = []
80 |
81 | while prompt.rfind('\n\nAssistant: ') != -1:
82 | assist_idx = prompt.rfind('\n\nAssistant: ')
83 | human_idx = prompt.rfind('\n\nHuman: ')
84 | if human_idx != -1:
85 | old_query = prompt[human_idx +
86 | 9:assist_idx].strip()
87 | old_resp = prompt[assist_idx + 13:].strip()
88 | history.insert(0, (old_query, old_resp))
89 | else:
90 | break
91 | prompt = prompt[:human_idx]
92 |
93 | yield (
94 | key,
95 | {
96 | 'instruction': query,
97 | 'chosen': r_accept,
98 | 'rejected': r_reject,
99 | 'history': history,
100 | },
101 | )
102 | key += 1
103 |
--------------------------------------------------------------------------------
/llamatuner/train/rm/trainer.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | from typing import Dict, List, Tuple, Union
4 |
5 | import torch
6 | import torch.nn.functional as F
7 | from transformers import PreTrainedModel, Trainer
8 | from transformers.trainer import PredictionOutput
9 |
10 | from llamatuner.utils.logger_utils import get_logger
11 |
12 | logger = get_logger('llamatuner')
13 |
14 |
15 | class PairwiseTrainer(Trainer):
16 | r"""
17 | Inherits Trainer to compute pairwise loss. This custom Trainer computes a pairwise ranking loss
18 | where the first half of the batch contains positive examples and the second half contains negative examples.
19 | """
20 |
21 | def __init__(self, *args, **kwargs) -> None:
22 | super().__init__(*args, **kwargs)
23 | self.can_return_loss = True
24 |
25 | def compute_loss(
26 | self,
27 | model: PreTrainedModel,
28 | inputs: Dict[str, torch.Tensor],
29 | return_outputs: bool = False,
30 | ) -> Union[torch.Tensor, Tuple[torch.Tensor, Tuple[torch.Tensor,
31 | torch.Tensor]]]:
32 | r"""
33 | Computes pairwise loss. The first n examples are chosen and the last n examples are rejected.
34 |
35 | Args:
36 | model (PreTrainedModel): The model being trained.
37 | inputs (Dict[str, torch.Tensor]): The inputs and targets of the model.
38 | return_outputs (bool, optional): Whether to return the model outputs in addition to the loss. Defaults to False.
39 |
40 | Returns:
41 | Union[torch.Tensor, Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]]: The computed loss and optionally the model outputs.
42 | """
43 | outputs = model(**inputs,
44 | output_hidden_states=True,
45 | return_dict=True,
46 | use_cache=False)
47 | values = outputs.logits
48 |
49 | batch_size = inputs['input_ids'].size(0) // 2
50 | chosen_masks, rejected_masks = torch.split(inputs['attention_mask'],
51 | batch_size,
52 | dim=0)
53 | chosen_rewards, rejected_rewards = torch.split(values,
54 | batch_size,
55 | dim=0)
56 |
57 | chosen_scores = chosen_rewards.gather(
58 | dim=-1, index=(chosen_masks.sum(dim=-1, keepdim=True) - 1))
59 | rejected_scores = rejected_rewards.gather(
60 | dim=-1, index=(rejected_masks.sum(dim=-1, keepdim=True) - 1))
61 |
62 | chosen_scores, rejected_scores = (
63 | chosen_scores.squeeze(),
64 | rejected_scores.squeeze(),
65 | )
66 |
67 | loss = -F.logsigmoid(chosen_scores.float() -
68 | rejected_scores.float()).mean()
69 | if return_outputs:
70 | return loss, (chosen_scores, rejected_scores)
71 | else:
72 | return loss
73 |
74 | def save_predictions(self, predict_results: PredictionOutput) -> None:
75 | r"""
76 | Saves model predictions to `output_dir`.
77 |
78 | Args:
79 | predict_results (PredictionOutput): The output of the `predict` method.
80 |
81 | This method saves the chosen and rejected scores to a JSONL file in the `output_dir`.
82 | """
83 | if not self.is_world_process_zero():
84 | return
85 |
86 | output_prediction_file = os.path.join(self.args.output_dir,
87 | 'generated_predictions.jsonl')
88 | logger.info(f'Saving prediction results to {output_prediction_file}')
89 | chosen_scores, rejected_scores = predict_results.predictions
90 |
91 | with open(output_prediction_file, 'w', encoding='utf-8') as writer:
92 | res: List[str] = []
93 | for c_score, r_score in zip(chosen_scores, rejected_scores):
94 | res.append(
95 | json.dumps({
96 | 'chosen': round(float(c_score), 2),
97 | 'rejected': round(float(r_score), 2),
98 | }))
99 |
100 | writer.write('\n'.join(res))
101 |
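102 | # Usage sketch (illustrative): the trainer expects batches in which the first half of the
103 | # examples are "chosen" and the second half "rejected", e.g. as produced by
104 | # PairwiseDataCollatorWithPadding in llamatuner/data/data_collator.py. The names below are
105 | # placeholders for objects built elsewhere in the training pipeline.
106 | #
107 | # trainer = PairwiseTrainer(model=reward_model, args=training_args,
108 | #                           train_dataset=pairwise_dataset, data_collator=pairwise_collator)
109 | # trainer.train()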
--------------------------------------------------------------------------------
/llamatuner/model/callbacks/save_peft_model_callback.py:
--------------------------------------------------------------------------------
1 | import os
2 | from typing import Any, Dict
3 |
4 | from transformers import (PreTrainedModel, TrainerCallback, TrainerControl,
5 |                           TrainerState, TrainingArguments)
6 | from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
7 |
8 |
9 | class SavePeftModelCallback(TrainerCallback):
10 | """Callback to save PEFT model checkpoints during training.
11 |
12 | Saves both the full model and the adapter model to separate directories
13 | within the checkpoint directory.
14 | """
15 |
16 |     def save_model(self, args: TrainingArguments, state: TrainerState,
17 | kwargs: Dict[str, Any]) -> None:
18 | """Saves the PEFT model checkpoint.
19 |
20 | Args:
21 |             args (TrainingArguments): The training arguments passed to the Trainer.
22 |             state (TrainerState): The current state of training.
23 | kwargs (Dict[str, Any]): A dictionary of additional keyword arguments.
24 |
25 |         Note:
26 |             Any full `pytorch_model.bin` left in the checkpoint directory is removed to save disk space.
27 |         """
28 | print('+' * 20, 'Saving PEFT Model Checkpoint CallBack', '+' * 20)
29 |
30 | # Get the checkpoint directory for saving models.
31 | if state.best_model_checkpoint is not None:
32 | # If best model checkpoint exists, use its directory as the checkpoint folder
33 | checkpoint_dir = os.path.join(state.best_model_checkpoint,
34 | 'adapter_model')
35 | else:
36 | # Otherwise, create a new checkpoint folder using the output directory and current global step
37 | checkpoint_dir = os.path.join(
38 | args.output_dir,
39 | f'{PREFIX_CHECKPOINT_DIR}-{state.global_step}')
40 |
41 | # Create path for the PEFT model
42 | peft_model_path = os.path.join(checkpoint_dir, 'adapter_model')
43 | model: PreTrainedModel = kwargs['model']
44 | model.save_pretrained(peft_model_path)
45 |
46 | # Create path for the PyTorch model binary file and remove it if it already exists
47 | pytorch_model_path = os.path.join(checkpoint_dir, 'pytorch_model.bin')
48 | if os.path.exists(pytorch_model_path):
49 | os.remove(pytorch_model_path)
50 |
51 |     def on_save(self, args: TrainingArguments, state: TrainerState,
52 | control: TrainerControl,
53 | **kwargs: Dict[str, Any]) -> TrainerControl:
54 | """Callback method that calls save_model() and returns `control`
55 | argument.
56 |
57 | Args:
58 |             args (TrainingArguments): The training arguments passed to the Trainer.
59 |             state (TrainerState): The current state of training.
60 | control (trainer_callback.TrainerControl): \
61 | The current state of the TrainerCallback's control flow.
62 | kwargs (Dict[str, Any]): A dictionary of additional keyword arguments.
63 |
64 | Returns:
65 | trainer_callback.TrainerControl: The current state of the TrainerCallback's control flow.
66 |
67 |         Note:
68 |             Delegates the actual saving to `save_model`.
69 |         """
70 | self.save_model(args, state, kwargs)
71 | return control
72 |
73 |     def on_train_end(self, args: TrainingArguments, state: TrainerState,
74 | control: TrainerControl, **kwargs: Dict[str,
75 | Any]) -> None:
76 |         """Callback method that creates a 'completed' marker file in the
77 |         output directory when training ends.
78 |
79 | Args:
80 |             args (TrainingArguments): The training arguments passed to the Trainer.
81 |             state (TrainerState): The current state of training.
82 | control (trainer_callback.TrainerControl): \
83 | The current state of the TrainerCallback's control flow.
84 | kwargs (Dict[str, Any]): A dictionary of additional keyword arguments.
85 |
86 |         Note:
87 |             No checkpoint is written here; per-checkpoint saving is handled in `on_save`.
88 |         """
89 |
90 | # Define a helper function to create a 'completed' file in the output directory
91 | def touch(fname, times=None):
92 | with open(fname, 'a'):
93 | os.utime(fname, times)
94 |
95 | # Create the 'completed' file in the output directory
96 | touch(os.path.join(args.output_dir, 'completed'))
97 |
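98 | # Usage sketch (illustrative): register the callback on a Trainer that wraps a PEFT model
99 | # so only the adapter weights are kept for each checkpoint; `trainer` is a placeholder for
100 | # a Trainer constructed elsewhere.
101 | #
102 | # trainer.add_callback(SavePeftModelCallback())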
--------------------------------------------------------------------------------
/llamatuner/model/callbacks/metrics.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from typing import Dict, Sequence, Tuple, Union
3 |
4 | import numpy as np
5 | from transformers import PreTrainedTokenizer
6 |
7 | from llamatuner.utils.constants import IGNORE_INDEX
8 | from llamatuner.utils.packages import (is_jieba_available, is_nltk_available,
9 | is_rouge_available)
10 |
11 | if is_jieba_available():
12 | import jieba # type: ignore
13 |
14 | if is_nltk_available():
15 | from nltk.translate.bleu_score import SmoothingFunction, sentence_bleu
16 |
17 | if is_rouge_available():
18 | from rouge_chinese import Rouge
19 |
20 |
21 | @dataclass
22 | class ComputeMetrics:
23 | """Wraps the tokenizer into metric functions, used in Seq2SeqPeftTrainer.
24 |
25 | Borrowed from: https://github.com/THUDM/ChatGLM-6B/blob/0c2806fea82683349194e21996dd6b3acc3c265b/ptuning/main.py#L307
26 | """
27 |
28 | def __init__(self, tokenizer: PreTrainedTokenizer) -> None:
29 | """Initialize the ComputeMetrics class with a pre-trained tokenizer
30 | object.
31 |
32 | Args:
33 | tokenizer (PreTrainedTokenizer): A pre-trained tokenizer object to be used for decoding tokenized sequences.
34 | """
35 | self.tokenizer = tokenizer
36 |
37 | def __call__(
38 | self, eval_preds: Sequence[Union[np.ndarray, Tuple[np.ndarray]]]
39 | ) -> Dict[str, float]:
40 | """Computes evaluation metrics for model predictions.
41 |
42 | Args:
43 | eval_preds (List[Union[np.ndarray, Tuple[np.ndarray]]]): List of tuples containing prediction and label arrays.
44 |
45 | Returns:
46 | Dict[str, float]: A dictionary containing the average of each computed metric over all prediction-label pairs.
47 | """
48 |
49 | # Extract predictions and labels from input
50 | preds, labels = eval_preds
51 | if isinstance(preds, tuple):
52 | preds = preds[0]
53 |
54 | score_dict = {
55 | 'rouge-1': [],
56 | 'rouge-2': [],
57 | 'rouge-l': [],
58 | 'bleu-4': []
59 | }
60 |
61 | # Replace IGNORE_INDEX in the labels with pad_token_id as we cannot decode them if ignore_pad_token_for_loss=True.
62 | preds = np.where(preds != IGNORE_INDEX, preds,
63 | self.tokenizer.pad_token_id)
64 | labels = np.where(labels != IGNORE_INDEX, labels,
65 | self.tokenizer.pad_token_id)
66 |
67 | decoded_preds = self.tokenizer.batch_decode(preds,
68 | skip_special_tokens=True)
69 | decoded_labels = self.tokenizer.batch_decode(labels,
70 | skip_special_tokens=True)
71 |
72 | # Calculate metrics for each prediction-label pair
73 | for pred, label in zip(decoded_preds, decoded_labels):
74 | hypothesis = list(jieba.cut(pred))
75 | reference = list(jieba.cut(label))
76 |             # If either the hypothesis or the reference is empty, set all ROUGE scores to 0
77 | if (len(' '.join(hypothesis).split()) == 0
78 | or len(' '.join(reference).split()) == 0):
79 | result = {
80 | 'rouge-1': {
81 | 'f': 0.0
82 | },
83 | 'rouge-2': {
84 | 'f': 0.0
85 | },
86 | 'rouge-l': {
87 | 'f': 0.0
88 | },
89 | }
90 | else:
91 | rouge = Rouge()
92 | scores = rouge.get_scores(' '.join(hypothesis),
93 | ' '.join(reference))
94 | result = scores[0]
95 |
96 | # Append scores to score_dict
97 | for k, v in result.items():
98 | score_dict[k].append(round(v['f'] * 100, 4))
99 |
100 | # Calculate BLEU-4 score and append it to score_dict
101 | bleu_score = sentence_bleu(
102 | [list(label)],
103 | list(pred),
104 | smoothing_function=SmoothingFunction().method3)
105 | score_dict['bleu-4'].append(round(bleu_score * 100, 4))
106 |
107 | # Calculate average of each metric over all prediction-label pairs and return as a dictionary
108 | return {k: float(np.mean(v)) for k, v in score_dict.items()}
109 |
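110 | # Usage sketch (illustrative): pass an instance as `compute_metrics` to a Seq2SeqTrainer
111 | # (typically with `predict_with_generate=True`); `model`, `training_args` and `tokenizer`
112 | # are placeholders for objects created elsewhere.
113 | #
114 | # from transformers import Seq2SeqTrainer
115 | # trainer = Seq2SeqTrainer(model=model, args=training_args, tokenizer=tokenizer,
116 | #                          compute_metrics=ComputeMetrics(tokenizer))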
--------------------------------------------------------------------------------
/server/single_chat.py:
--------------------------------------------------------------------------------
1 | import os
2 | import platform
3 | import sys
4 | from threading import Thread
5 | from typing import List
6 |
7 | import torch
8 | import transformers
9 | from transformers import (AutoModelForCausalLM, AutoTokenizer, PreTrainedModel,
10 | PreTrainedTokenizer, TextIteratorStreamer)
11 |
12 | sys.path.append('../')
13 | from llamatuner.configs import GenerationArguments, ModelInferenceArguments
14 | from llamatuner.utils.model_utils import get_logits_processor
15 |
16 |
17 | def generate_response(query: str, tokenizer: PreTrainedTokenizer,
18 |                       model: PreTrainedModel,
19 |                       generation_args: GenerationArguments) -> str:
20 |     """Generates a response to the given query with the loaded causal
21 |     language model and streams it to the console.
22 | 
23 |     Args:
24 |         query (str): The input query for which a response is to be generated.
25 |         tokenizer (PreTrainedTokenizer): The tokenizer used to convert the raw text into input tokens.
26 |         model (PreTrainedModel): The causal language model used to generate the response.
27 |         generation_args (GenerationArguments): Generation arguments passed to the model's generate() method.
28 | 
29 |     Returns:
30 |         str: The generated response text.
31 |     """
32 |
33 | # Convert the query and history into input IDs
34 | inputs = tokenizer(query, return_tensors='pt', add_special_tokens=False)
35 | inputs = {k: v.to(model.device) for k, v in inputs.items()}
36 |
37 | # Create a TextIteratorStreamer object to stream the response from the model
38 | streamer = TextIteratorStreamer(tokenizer,
39 | timeout=60.0,
40 | skip_prompt=True,
41 | skip_special_tokens=True)
42 |
43 | # Set the arguments for the model's generate() method
44 | gen_kwargs = dict(
45 | **inputs,
46 | streamer=streamer,
47 | logits_processor=get_logits_processor(),
48 | **generation_args.to_dict(),
49 | )
50 |
51 | # Start a separate thread to generate the response asynchronously
52 | thread = Thread(target=model.generate, kwargs=gen_kwargs)
53 | thread.start()
54 |
55 | # Print the model name and the response as it is generated
56 | print('Assistant: ', end='', flush=True)
57 | response = ''
58 | for new_text in streamer:
59 | print(new_text, end='', flush=True)
60 | response += new_text
61 |     # Return the complete generated response
62 | return response
63 |
64 |
65 | def main():
66 |     """Run a single-turn chat loop (no conversation-history memory).
67 | 
68 |     Parses ModelInferenceArguments and GenerationArguments from the command
69 |     line, loads the model and tokenizer, then reads user input and prints
70 |     the generated reply until the user types 'stop'.
71 | 
72 |     Returns:
73 |         None
74 |     """
75 |
76 | # Parse command-line arguments
77 | parser = transformers.HfArgumentParser(
78 | (ModelInferenceArguments, GenerationArguments))
79 | model_server_args, generation_args = parser.parse_args_into_dataclasses()
80 |
81 |     # Load the pretrained language model. device_map='auto' already places
82 |     # the model on the available devices, so an explicit .to(device) call is
83 |     # not needed (and can fail for models dispatched across multiple devices).
84 |     model = AutoModelForCausalLM.from_pretrained(
85 |         model_server_args.model_name_or_path,
86 |         trust_remote_code=True,
87 |         low_cpu_mem_usage=True,
88 |         torch_dtype=torch.float16,
89 |         device_map='auto').eval()
90 |
91 | tokenizer = AutoTokenizer.from_pretrained(
92 | model_server_args.model_name_or_path,
93 | trust_remote_code=True,
94 | use_fast=False,
95 | )
96 |
97 | os_name = platform.system()
98 | clear_command = 'cls' if os_name == 'Windows' else 'clear'
99 |     # Start the interactive chat loop
100 |     print('Welcome to the CLI chat system. Type a message to chat, enter "clear" to clear the history, or "stop" to exit.')
101 | input_pattern = '{}'
102 | while True:
103 | query = input('\nUser: ')
104 | if query.strip() == 'stop':
105 | break
106 |
107 | if query.strip() == 'clear':
108 | os.system(clear_command)
109 | print('History has been removed.')
110 |             print('Welcome to the CLI chat system. Type a message to chat, enter "clear" to clear the history, or "stop" to exit.')
111 | continue
112 |
113 | query = input_pattern.format(query)
114 | # Perform prediction and printing
115 | generate_response(query, tokenizer, model, generation_args)
116 |
117 |
118 | if __name__ == '__main__':
119 | main()
120 |
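121 | # Example launch (sketch): only --model_name_or_path is taken from the code
122 | # above; any generation flags (e.g. --max_new_tokens) depend on
123 | # GenerationArguments and are assumed/illustrative here.
124 | #
125 | #   python server/single_chat.py \
126 | #       --model_name_or_path path/to/your/model \
127 | #       --max_new_tokens 512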
--------------------------------------------------------------------------------
/server/gradio_base_webserver.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | import gradio as gr
4 | import torch
5 | from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
6 |
7 | from llamatuner.train.apply_lora import apply_lora
8 |
9 |
10 | def args_parser():
11 | parser = argparse.ArgumentParser()
12 | parser.add_argument('--model_name_or_path',
13 | default=None,
14 | type=str,
15 | required=True,
16 | help='Path to pre-trained model')
17 | parser.add_argument('--lora_model_name_or_path',
18 | default=None,
19 | type=str,
20 |                         help='Path to the LoRA adapter weights')
21 | parser.add_argument('--no_cuda',
22 | action='store_true',
23 | help='Avoid using CUDA when available')
24 | parser.add_argument('--load_8bit',
25 | action='store_true',
26 |                         help='Whether to load the model in 8-bit precision')
27 | args = parser.parse_args()
28 |
29 | args.device = torch.device(
30 | 'cuda' if torch.cuda.is_available() and not args.no_cuda else 'cpu')
31 | return args
32 |
33 |
34 | def main(args):
35 | if args.lora_model_name_or_path is not None:
36 | model, tokenizer = apply_lora(args.model_name_or_path,
37 | args.lora_model_name_or_path,
38 | load_8bit=args.load_8bit)
39 | else:
40 | tokenizer = AutoTokenizer.from_pretrained(
41 | pretrained_model_name_or_path=args.model_name_or_path,
42 | trust_remote_code=True)
43 | model = AutoModelForCausalLM.from_pretrained(
44 | pretrained_model_name_or_path=args.model_name_or_path,
45 | load_in_8bit=args.load_8bit,
46 | torch_dtype=torch.float16,
47 | device_map='auto',
48 | trust_remote_code=True)
49 |
50 | def evaluate(
51 | input=None,
52 | temperature=0.8,
53 | top_p=0.75,
54 | top_k=40,
55 | max_new_tokens=128,
56 | **kwargs,
57 | ):
58 | inputs = tokenizer(input, return_tensors='pt')
59 | inputs = inputs.to(args.device)
60 | generation_config = GenerationConfig(
61 | temperature=temperature,
62 | top_p=top_p,
63 | top_k=top_k,
64 | do_sample=True,
65 | no_repeat_ngram_size=6,
66 | repetition_penalty=1.8,
67 | **kwargs,
68 | )
69 | # Without streaming
70 | with torch.no_grad():
71 | generation_output = model.generate(
72 | **inputs,
73 | generation_config=generation_config,
74 | return_dict_in_generate=True,
75 | output_scores=True,
76 | max_new_tokens=max_new_tokens,
77 | )
78 | s = generation_output.sequences[0]
79 | output = tokenizer.decode(s, skip_special_tokens=True)
80 | yield output
81 |
82 |     description = 'Baichuan-7B is a 7B-parameter, LLaMA-like language model fine-tuned to follow instructions.'
83 | server = gr.Interface(
84 | fn=evaluate,
85 | inputs=[
86 | gr.components.Textbox(lines=2, label='Input', placeholder='none'),
87 | gr.components.Slider(minimum=0,
88 | maximum=1,
89 | value=0.1,
90 | label='Temperature'),
91 | gr.components.Slider(minimum=0,
92 | maximum=1,
93 | value=0.75,
94 | label='Top p'),
95 | gr.components.Slider(minimum=0,
96 | maximum=100,
97 | step=1,
98 | value=40,
99 | label='Top k'),
100 | gr.components.Slider(minimum=1,
101 | maximum=2000,
102 | step=1,
103 | value=128,
104 | label='Max tokens'),
105 | ],
106 |         outputs=[gr.components.Textbox(
107 | lines=5,
108 | label='Output',
109 | )],
110 | title='Baichuan7B',
111 | description=description,
112 | )
113 |
114 | server.queue().launch(server_name='0.0.0.0', share=False)
115 |
116 |
117 | if __name__ == '__main__':
118 | args = args_parser()
119 | main(args)
120 |
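121 | # Example launch (sketch): the flags below are the ones defined in args_parser();
122 | # the model and adapter paths are placeholders.
123 | #
124 | #   python server/gradio_base_webserver.py \
125 | #       --model_name_or_path path/to/base_model \
126 | #       --lora_model_name_or_path path/to/lora_adapter \
127 | #       --load_8bit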
--------------------------------------------------------------------------------
/llamatuner/data/utils.py:
--------------------------------------------------------------------------------
1 | from enum import Enum, unique
2 | from typing import List, Optional, TypedDict, Union
3 |
4 | from datasets import (Dataset, IterableDataset, concatenate_datasets,
5 | interleave_datasets)
6 | from transformers import TrainingArguments
7 |
8 | from llamatuner.configs import DataArguments
9 | from llamatuner.utils.logger_utils import get_logger
10 |
11 | logger = get_logger('llamatuner')
12 |
13 |
14 | @unique
15 | class Role(str, Enum):
16 | """Enumeration of possible roles in a conversation."""
17 | USER = 'user'
18 | ASSISTANT = 'assistant'
19 | SYSTEM = 'system'
20 | FUNCTION = 'function'
21 | OBSERVATION = 'observation'
22 |
23 |
24 | class DatasetModule(TypedDict):
25 | """Type definition for dataset module containing train and evaluation datasets."""
26 | train_dataset: Optional[Union[Dataset, IterableDataset]]
27 | eval_dataset: Optional[Union[Dataset, IterableDataset]]
28 |
29 |
30 | def merge_dataset(
31 | all_datasets: List[Union[Dataset, IterableDataset]],
32 | data_args: DataArguments,
33 | training_args: TrainingArguments,
34 | ) -> Union[Dataset, IterableDataset]:
35 | """Merge multiple datasets using specified strategy.
36 |
37 | Args:
38 | all_datasets: List of datasets to merge
39 | data_args: Data configuration arguments
40 | training_args: Training configuration arguments
41 |
42 | Returns:
43 | Merged dataset
44 |
45 | Raises:
46 | ValueError: If mixing strategy is unknown
47 | """
48 | if not all_datasets:
49 | raise ValueError('Cannot merge empty dataset list')
50 |
51 | if len(all_datasets) == 1:
52 | return all_datasets[0]
53 |
54 | valid_strategies = {'concat', 'interleave_under', 'interleave_over'}
55 | if data_args.mix_strategy not in valid_strategies:
56 | raise ValueError(
57 | f'Unknown mixing strategy: {data_args.mix_strategy}. '
58 | f"Valid strategies are: {', '.join(valid_strategies)}")
59 |
60 | logger.info(
61 | f'Merging {len(all_datasets)} datasets with {data_args.mix_strategy} strategy ...'
62 | )
63 |
64 | if data_args.mix_strategy == 'concat':
65 | if data_args.streaming:
66 | logger.warning(
67 | 'The samples between different datasets will not be mixed in streaming mode.'
68 | )
69 | return concatenate_datasets(all_datasets)
70 |
71 | # Handle interleave strategies
72 | if not data_args.streaming:
73 | logger.warning(
74 | 'We recommend using `mix_strategy=concat` in non-streaming mode.')
75 |
76 | stopping_strategy = 'first_exhausted' if data_args.mix_strategy == 'interleave_under' else 'all_exhausted'
77 |
78 | return interleave_datasets(
79 | datasets=all_datasets,
80 | probabilities=data_args.interleave_probs,
81 | seed=training_args.seed,
82 | stopping_strategy=stopping_strategy,
83 | )
84 |
85 |
86 | def split_dataset(
87 | dataset: Union[Dataset, IterableDataset],
88 | data_args: DataArguments,
89 | training_args: TrainingArguments,
90 | ) -> DatasetModule:
91 | """Split dataset into training and evaluation sets.
92 |
93 | Args:
94 | dataset: Input dataset to split
95 | data_args: Data configuration arguments
96 | training_args: Training configuration arguments
97 |
98 | Returns:
99 | Dictionary containing train and evaluation datasets
100 |
101 | Raises:
102 | ValueError: If eval_dataset_size is invalid
103 | """
104 | if data_args.eval_dataset_size <= 0:
105 | raise ValueError('eval_dataset_size must be greater than 0')
106 |
107 | val_size = int(
108 | data_args.eval_dataset_size
109 | ) if data_args.eval_dataset_size > 1 else data_args.eval_dataset_size
110 |
111 | logger.info(
112 |         f'Splitting dataset with an evaluation split of {val_size} '
113 |         f'({"samples" if data_args.eval_dataset_size > 1 else "fraction of the data"})'
114 | )
115 | if data_args.streaming:
116 | dataset = dataset.shuffle(buffer_size=data_args.buffer_size,
117 | seed=training_args.seed)
118 | val_set = dataset.take(int(data_args.eval_dataset_size))
119 | train_set = dataset.skip(int(data_args.eval_dataset_size))
120 | return DatasetModule(train_dataset=train_set, eval_dataset=val_set)
121 |
122 | dataset_split = dataset.train_test_split(test_size=val_size,
123 | seed=training_args.seed)
124 | return DatasetModule(train_dataset=dataset_split['train'],
125 | eval_dataset=dataset_split['test'])
126 |
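127 | # Illustrative usage (sketch): `ds_a` and `ds_b` are placeholder datasets;
128 | # `data_args` is assumed to carry `mix_strategy`, `streaming`, `interleave_probs`,
129 | # `eval_dataset_size` and `buffer_size`, while `training_args` only needs `seed`.
130 | #
131 | #   merged = merge_dataset([ds_a, ds_b], data_args, training_args)
132 | #   module = split_dataset(merged, data_args, training_args)
133 | #   train_ds, eval_ds = module['train_dataset'], module['eval_dataset']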
--------------------------------------------------------------------------------
/llamatuner/configs/model_args.py:
--------------------------------------------------------------------------------
1 | from dataclasses import asdict, dataclass, field
2 | from typing import Any, Dict, Literal, Optional
3 |
4 |
5 | @dataclass
6 | class ModelArguments:
7 | """Arguments pertaining to which model/config/tokenizer we are going to fine-tune or infer."""
8 |
9 | model_name_or_path: Optional[str] = field(
10 | default='facebook/opt-125m',
11 | metadata={
12 | 'help':
13 | ('Path to the model weight or identifier from huggingface.co/models or modelscope.cn/models.'
14 | )
15 | },
16 | )
17 | adapter_name_or_path: Optional[str] = field(
18 | default=None,
19 | metadata={
20 | 'help':
21 | ('Path to the adapter weight or identifier from huggingface.co/models. '
22 | 'Use commas to separate multiple adapters.')
23 | },
24 | )
25 | adapter_folder: Optional[str] = field(
26 | default=None,
27 | metadata={
28 | 'help': 'The folder containing the adapter weights to load.'
29 | },
30 | )
31 | use_fast_tokenizer: bool = field(
32 | default=True,
33 | metadata={
34 | 'help':
35 |             'Whether or not to use the fast tokenizer (backed by the tokenizers library).'
36 | },
37 | )
38 | resize_vocab: bool = field(
39 | default=False,
40 | metadata={
41 | 'help':
42 | 'Whether or not to resize the tokenizer vocab and the embedding layers.'
43 | },
44 | )
45 | model_max_length: Optional[int] = field(
46 | default=1024,
47 | metadata={
48 | 'help':
49 | 'The maximum length of the model input, including special tokens.'
50 | },
51 | )
52 | trust_remote_code: Optional[bool] = field(
53 | default=True,
54 | metadata={
55 | 'help':
56 | 'Whether or not to trust the remote code in the model configuration.'
57 | },
58 | )
59 | cache_dir: Optional[str] = field(
60 | default=None,
61 | metadata={
62 | 'help':
63 | 'Where to store the pre-trained models downloaded from huggingface.co or modelscope.cn.'
64 | },
65 | )
66 | model_revision: str = field(
67 | default='main',
68 | metadata={
69 | 'help':
70 | 'The specific model version to use (can be a branch name, tag name or commit id).'
71 | },
72 | )
73 | split_special_tokens: bool = field(
74 | default=False,
75 | metadata={
76 | 'help':
77 | 'Whether or not the special tokens should be split during the tokenization process.'
78 | },
79 | )
80 | new_special_tokens: Optional[str] = field(
81 | default=None,
82 | metadata={'help': 'Special tokens to be added into the tokenizer.'},
83 | )
84 | low_cpu_mem_usage: bool = field(
85 | default=True,
86 | metadata={
87 | 'help': 'Whether or not to use memory-efficient model loading.'
88 | },
89 | )
90 | rope_scaling: Optional[Literal['linear', 'dynamic']] = field(
91 | default=None,
92 | metadata={
93 | 'help':
94 | 'Which scaling strategy should be adopted for the RoPE embeddings.'
95 | },
96 | )
97 | flash_attn: Literal['off', 'sdpa', 'fa2', 'auto'] = field(
98 | default='auto',
99 | metadata={
100 | 'help': 'Enable FlashAttention for faster training and inference.'
101 | },
102 | )
103 | train_from_scratch: bool = field(
104 | default=False,
105 | metadata={
106 | 'help': 'Whether or not to randomly initialize the model weights.'
107 | },
108 | )
109 | offload_folder: str = field(
110 | default='offload',
111 | metadata={'help': 'Path to offload model weights.'},
112 | )
113 | use_cache: bool = field(
114 | default=True,
115 | metadata={'help': 'Whether or not to use KV cache in generation.'},
116 | )
117 | hf_hub_token: Optional[str] = field(
118 | default=None,
119 | metadata={'help': 'Auth token to log in with Hugging Face Hub.'},
120 | )
121 | ms_hub_token: Optional[str] = field(
122 | default=None,
123 | metadata={'help': 'Auth token to log in with ModelScope Hub.'},
124 | )
125 |
126 | def __post_init__(self):
127 | self.compute_dtype = None
128 | self.device_map = None
129 |
130 | if self.model_name_or_path is None:
131 | raise ValueError('Please provide `model_name_or_path`.')
132 |
133 | if self.adapter_name_or_path is not None: # support merging multiple lora weights
134 | self.adapter_name_or_path = [
135 | path.strip() for path in self.adapter_name_or_path.split(',')
136 | ]
137 |
138 | if self.split_special_tokens and self.use_fast_tokenizer:
139 | raise ValueError(
140 | '`split_special_tokens` is only supported for slow tokenizers.'
141 | )
142 |
143 | if self.new_special_tokens is not None: # support multiple special tokens
144 | self.new_special_tokens = [
145 | token.strip() for token in self.new_special_tokens.split(',')
146 | ]
147 |
148 | def to_dict(self) -> Dict[str, Any]:
149 | return asdict(self)
150 |
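151 | # Illustrative example (sketch): multiple adapters are passed as a comma-separated
152 | # string and split into a list in __post_init__; the values shown are placeholders.
153 | #
154 | #   args = ModelArguments(model_name_or_path='facebook/opt-125m',
155 | #                         adapter_name_or_path='adapter_a,adapter_b')
156 | #   assert args.adapter_name_or_path == ['adapter_a', 'adapter_b']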
--------------------------------------------------------------------------------
/data/format_data/convert_alpaca.py:
--------------------------------------------------------------------------------
1 | import json
2 |
3 | from datasets import load_dataset
4 |
5 |
6 | def json_dump(obj, path):
7 | with open(path, 'w', encoding='utf-8') as f:
8 | json.dump(obj, f, indent=2, ensure_ascii=False)
9 |
10 |
11 | def json_load(in_file):
12 | with open(in_file, 'r') as f:
13 | json_data = json.load(f)
14 | return json_data
15 |
16 |
17 | def convert_100PoisonMpts(in_file, out_file):
18 | raw_data = load_dataset('json', data_files=in_file)['train']
19 | new_content = []
20 | for i, raw_text in enumerate(raw_data):
21 | prompt = raw_text['prompt']
22 | response = raw_text['answer']
23 | if len(prompt) <= 5 or len(response) <= 5:
24 | continue
25 | new_content.append({
26 | 'instruction': prompt,
27 | 'input': '',
28 | 'output': response,
29 | })
30 |
31 | print(f'#out: {len(new_content)}')
32 | json_dump(new_content, out_file)
33 |
34 |
35 | def convert_Cvalues(in_file, out_file):
36 | raw_data = load_dataset('json', data_files=in_file)['train']
37 | new_content = []
38 | for i, raw_text in enumerate(raw_data):
39 | prompt = raw_text['prompt']
40 | response = raw_text['pos_resp']
41 | if len(prompt) <= 5 or len(response) <= 5:
42 | continue
43 | new_content.append({
44 | 'instruction': prompt,
45 | 'input': '',
46 | 'output': response,
47 | })
48 |
49 | print(f'#out: {len(new_content)}')
50 | json_dump(new_content, out_file)
51 |
52 |
53 | def convert_huatuogpt(in_file, out_file):
54 | raw_data = load_dataset('json', data_files=in_file)['train']
55 | new_content = []
56 | for i, raw_text in enumerate(raw_data):
57 | data = raw_text['data']
58 | prompt = data[0].replace('问:', '')
59 | response = data[1].replace('答:', '')
60 | if len(prompt) <= 5 or len(response) <= 5:
61 | continue
62 | new_content.append({
63 | 'instruction': prompt,
64 | 'input': '',
65 | 'output': response,
66 | })
67 | print(f'#out: {len(new_content)}')
68 | json_dump(new_content, out_file)
69 |
70 |
71 | def convert_safety_attack(in_file, out_file):
72 | field_list = [
73 | 'Reverse_Exposure', 'Goal_Hijacking', 'Prompt_Leaking',
74 | 'Unsafe_Instruction_Topic', 'Role_Play_Instruction',
75 | 'Inquiry_With_Unsafe_Opinion'
76 | ]
77 | new_content = []
78 |     for field_name in field_list:
79 |         raw_data = load_dataset('json', field=field_name,
80 |                                 data_files=in_file)['train']
81 | for i, raw_text in enumerate(raw_data):
82 | prompt = raw_text['prompt']
83 | response = raw_text['response']
84 | if len(prompt) <= 5 or len(response) <= 5:
85 | continue
86 | new_content.append({
87 | 'instruction': prompt,
88 | 'input': '',
89 | 'output': response,
90 | })
91 | print(f'#out: {len(new_content)}')
92 | json_dump(new_content, out_file)
93 |
94 |
95 | def convert_safety_scenarios(in_file, out_file):
96 |
97 | field_list = [
98 | 'Unfairness_And_Discrimination', 'Crimes_And_Illegal_Activities',
99 | 'Insult', 'Mental_Health', 'Physical_Harm', 'Privacy_And_Property',
100 | 'Ethics_And_Morality'
101 | ]
102 | new_content = []
103 |     for field_name in field_list:
104 |         raw_data = load_dataset('json', data_files=in_file,
105 |                                 field=field_name)['train']
106 | for i, raw_text in enumerate(raw_data):
107 | prompt = raw_text['prompt']
108 | response = raw_text['response']
109 | if len(prompt) <= 5 or len(response) <= 5:
110 | continue
111 | new_content.append({
112 | 'instruction': prompt,
113 | 'input': '',
114 | 'output': response,
115 | })
116 | print(f'#out: {len(new_content)}')
117 | json_dump(new_content, out_file)
118 |
119 |
120 | if __name__ == '__main__':
121 |
122 | data_path = '/home/robin/prompt_data/100PoisonMpts/train.jsonl'
123 | out_path = '/home/robin/prompt_data/100PoisonMpts/train_alpaca.jsonl'
124 | convert_100PoisonMpts(data_path, out_file=out_path)
125 |
126 | data_path = '/home/robin/prompt_data/CValues-Comparison/test.jsonl'
127 | out_path = '/home/robin/prompt_data/CValues-Comparison/test_alpaca.json'
128 | convert_Cvalues(data_path, out_file=out_path)
129 |
130 | data_path = '/home/robin/prompt_data/CValues-Comparison/train.jsonl'
131 | out_path = '/home/robin/prompt_data/CValues-Comparison/train_alpaca.json'
132 | convert_Cvalues(data_path, out_file=out_path)
133 |
134 | data_path = '/home/robin/prompt_data/HuatuoGPT-sft-data-v1/HuatuoGPT_sft_data_v1.jsonl'
135 | out_path = '/home/robin/prompt_data/HuatuoGPT-sft-data-v1/HuatuoGPT_alpaca.json'
136 | convert_huatuogpt(data_path, out_file=out_path)
137 |
138 | data_path = '/home/robin/prompt_data/Safety-Prompts/instruction_attack_scenarios.json'
139 | out_path = '/home/robin/prompt_data/Safety-Prompts/attack_scenarios_alpaca.json'
140 | convert_safety_attack(data_path, out_file=out_path)
141 |
142 | data_path = '/home/robin/prompt_data/Safety-Prompts/typical_safety_scenarios.json'
143 | out_path = '/home/robin/prompt_data/Safety-Prompts/safety_scenarios_alpaca.json'
144 | convert_safety_scenarios(data_path, out_file=out_path)
145 |
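146 | # Data shapes handled above (sketch): each converter reads JSON/JSONL records and
147 | # writes Alpaca-style entries. For example, convert_100PoisonMpts expects records
148 | # like {"prompt": "...", "answer": "..."} and every converter emits entries of the form:
149 | #
150 | #   {"instruction": "<prompt>", "input": "", "output": "<response>"}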
--------------------------------------------------------------------------------
/llamatuner/utils/misc.py:
--------------------------------------------------------------------------------
1 | import gc
2 | import os
3 | from typing import Tuple
4 |
5 | import torch
6 | from transformers.utils import (is_torch_bf16_gpu_available,
7 | is_torch_cuda_available,
8 | is_torch_mps_available, is_torch_npu_available,
9 | is_torch_xpu_available)
10 | from transformers.utils.versions import require_version
11 |
12 | from llamatuner.utils.logger_utils import get_logger
13 |
14 | _is_fp16_available = is_torch_npu_available() or is_torch_cuda_available()
15 | try:
16 | _is_bf16_available = is_torch_bf16_gpu_available()
17 | except Exception:
18 | _is_bf16_available = False
19 |
20 | from llamatuner.configs.model_args import ModelArguments
21 |
22 | logger = get_logger('llamatuner')
23 |
24 |
25 | class AverageMeter:
26 | r"""
27 | Computes and stores the average and current value.
28 | """
29 |
30 | def __init__(self):
31 | self.reset()
32 |
33 | def reset(self):
34 | self.val = 0
35 | self.avg = 0
36 | self.sum = 0
37 | self.count = 0
38 |
39 | def update(self, val, n=1):
40 | self.val = val
41 | self.sum += val * n
42 | self.count += n
43 | self.avg = self.sum / self.count
44 |
45 |
46 | def check_version(requirement: str, mandatory: bool = False) -> None:
47 | r"""
48 | Optionally checks the package version.
49 | """
50 |     if (os.getenv('DISABLE_VERSION_CHECK', '0').lower() in ['true', '1']
51 |             and not mandatory):
52 | logger.warning(
53 | 'Version checking has been disabled, may lead to unexpected behaviors.'
54 | )
55 | return
56 |
57 | if mandatory:
58 | hint = f'To fix: run `pip install {requirement}`.'
59 | else:
60 | hint = f'To fix: run `pip install {requirement}` or set `DISABLE_VERSION_CHECK=1` to skip this check.'
61 |
62 | require_version(requirement, hint)
63 |
64 |
65 | def check_dependencies() -> None:
66 | r"""
67 | Checks the version of the required packages.
68 | """
69 | check_version('transformers>=4.41.2,<=4.46.1')
70 | check_version('datasets>=2.16.0,<=3.1.0')
71 | check_version('accelerate>=0.34.0,<=1.0.1')
72 | check_version('peft>=0.11.1,<=0.12.0')
73 | check_version('trl>=0.8.6,<=0.9.6')
74 |
75 |
76 | def get_current_device() -> torch.device:
77 | r"""
78 | Gets the current available device.
79 | """
80 | if is_torch_xpu_available():
81 | device = 'xpu:{}'.format(os.environ.get('LOCAL_RANK', '0'))
82 | elif is_torch_npu_available():
83 | device = 'npu:{}'.format(os.environ.get('LOCAL_RANK', '0'))
84 | elif is_torch_mps_available():
85 | device = 'mps:{}'.format(os.environ.get('LOCAL_RANK', '0'))
86 | elif is_torch_cuda_available():
87 | device = 'cuda:{}'.format(os.environ.get('LOCAL_RANK', '0'))
88 | else:
89 | device = 'cpu'
90 |
91 | return torch.device(device)
92 |
93 |
94 | def get_device_count() -> int:
95 | r"""
96 | Gets the number of available GPU or NPU devices.
97 | """
98 | if is_torch_npu_available():
99 | return torch.npu.device_count()
100 | elif is_torch_cuda_available():
101 | return torch.cuda.device_count()
102 | else:
103 | return 0
104 |
105 |
106 | def get_peak_memory() -> Tuple[int, int]:
107 | r"""
108 | Gets the peak memory usage for the current device (in Bytes).
109 | """
110 | if is_torch_npu_available():
111 | return torch.npu.max_memory_allocated(), torch.npu.max_memory_reserved(
112 | )
113 | elif is_torch_cuda_available():
114 | return torch.cuda.max_memory_allocated(
115 | ), torch.cuda.max_memory_reserved()
116 | else:
117 | return 0, 0
118 |
119 |
120 | def infer_optim_dtype(model_dtype: torch.dtype) -> torch.dtype:
121 | r"""
122 | Infers the optimal dtype according to the model_dtype and device compatibility.
123 | """
124 | if _is_bf16_available and model_dtype == torch.bfloat16:
125 | return torch.bfloat16
126 | elif _is_fp16_available:
127 | return torch.float16
128 | else:
129 | return torch.float32
130 |
131 |
132 | def is_gpu_or_npu_available() -> bool:
133 | r"""
134 | Checks if the GPU or NPU is available.
135 | """
136 | return is_torch_npu_available() or is_torch_cuda_available()
137 |
138 |
139 | def has_tokenized_data(path: os.PathLike) -> bool:
140 | r"""
141 | Checks if the path has a tokenized dataset.
142 | """
143 | return os.path.isdir(path) and len(os.listdir(path)) > 0
144 |
145 |
146 | def torch_gc() -> None:
147 | r"""
148 | Collects GPU or NPU memory.
149 | """
150 | gc.collect()
151 | if is_torch_xpu_available():
152 | torch.xpu.empty_cache()
153 | elif is_torch_npu_available():
154 | torch.npu.empty_cache()
155 | elif is_torch_mps_available():
156 | torch.mps.empty_cache()
157 | elif is_torch_cuda_available():
158 | torch.cuda.empty_cache()
159 |
160 |
161 | def try_download_model_from_ms(model_args: 'ModelArguments') -> str:
162 | if not use_modelscope() or os.path.exists(model_args.model_name_or_path):
163 | return model_args.model_name_or_path
164 |
165 | try:
166 | from modelscope import snapshot_download
167 |
168 | revision = ('master' if model_args.model_revision == 'main' else
169 | model_args.model_revision)
170 | return snapshot_download(
171 | model_args.model_name_or_path,
172 | revision=revision,
173 | cache_dir=model_args.cache_dir,
174 | )
175 | except ImportError as exc:
176 | raise ImportError(
177 | 'Please install modelscope via `pip install modelscope -U`'
178 | ) from exc
179 |
180 |
181 | def use_modelscope() -> bool:
182 | return os.environ.get('USE_MODELSCOPE_HUB', '0').lower() in ['true', '1']
183 |
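184 | # Illustrative usage (sketch) of the helpers above:
185 | #
186 | #   device = get_current_device()              # e.g. cuda:0, npu:0, mps:0 or cpu
187 | #   dtype = infer_optim_dtype(torch.bfloat16)  # bf16 if supported, else fp16/fp32
188 | #   ...
189 | #   torch_gc()                                 # release cached accelerator memory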
--------------------------------------------------------------------------------
/data/format_data/clean_sharegpt/split_long_conversation.py:
--------------------------------------------------------------------------------
1 | """Split long conversations based on certain max length.
2 |
3 | Usage: python3 -m split_long_conversation.py \
4 | --in sharegpt_clean.json \
5 | --out sharegpt_split.json \
6 | --model-name-or-path $
7 |
8 | example:
9 | python split_long_conversation.py \
10 | --in-file sharegpt_clean.json \
11 | --model-name-or-path decapoda-research/llama-7b-hf
12 | """
13 | import argparse
14 | import json
15 | from concurrent.futures import ProcessPoolExecutor
16 | from typing import Any, Dict, List
17 |
18 | import transformers
19 | from clean_sharegpt import filter_invalid_roles, get_statistics, json_dump
20 | from tqdm import tqdm
21 |
22 |
23 | def make_sample(sample: Dict[str, any], start_idx: int,
24 | end_idx: int) -> Dict[str, any]:
25 | """Create a new sample dictionary by selecting conversations from the given
26 | sample.
27 |
28 | Args:
29 | sample (Dict[str, any]): The original sample dictionary.
30 | start_idx (int): The starting index of conversations to include.
31 | end_idx (int): The ending index of conversations to include.
32 |
33 | Returns:
34 | Dict[str, any]: The new sample dictionary with selected conversations.
35 | """
36 | assert (end_idx - start_idx) % 2 == 0
37 | conversations = sample['conversations'][start_idx:end_idx]
38 | return {
39 | 'id': sample['id'] + '_' + str(start_idx),
40 | 'conversations': conversations,
41 | }
42 |
43 |
44 | def split_one_sample(sample: Dict[str, any]) -> List[Dict[str, any]]:
45 | """Split a single sample into multiple samples based on conversation
46 | lengths.
47 |
48 |     Args:
49 |         sample (Dict[str, any]): The original sample dictionary. Splitting uses
50 |             the module-level ``tokenizer`` and ``max_length`` set by split_all().
51 |
52 | Returns:
53 | List[Dict[str, any]]: The list of new sample dictionaries.
54 | """
55 | tokenized_lens = []
56 | conversations = sample['conversations']
57 |
58 | # Truncate conversations to an even number of conversations
59 | conversations = conversations[:len(conversations) // 2 * 2]
60 |
61 | # Calculate the tokenized length for each conversation
62 | for conv in conversations:
63 | length = len(tokenizer(conv['value']).input_ids) + 6
64 | tokenized_lens.append(length)
65 |
66 | new_samples = []
67 | start_idx = 0 # The starting index of conversations to include
68 | cur_len = 0 # The current length of conversations included
69 |
70 | # Iterate through conversations and create new samples based on length constraints
71 | for end_idx in range(0, len(conversations), 2):
72 | round_len = tokenized_lens[end_idx] + tokenized_lens[end_idx + 1]
73 | if cur_len + round_len > max_length:
74 | sub_sample = make_sample(sample, start_idx, end_idx + 2)
75 | new_samples.append(sub_sample)
76 | start_idx = end_idx + 2
77 | cur_len = 0
78 | elif end_idx == len(conversations) - 2:
79 | sub_sample = make_sample(sample, start_idx, end_idx + 2)
80 | new_samples.append(sub_sample)
81 | cur_len += round_len
82 |
83 | return new_samples
84 |
85 |
86 | def worker(input_data: List[Dict[str, Any]]):
87 | result = []
88 | for sample in input_data:
89 | result.extend(split_one_sample(sample))
90 | return result
91 |
92 |
93 | def split_all(raw_data: List[Dict[str, Any]],
94 | tokenizer_: transformers.PreTrainedTokenizer,
95 | max_length_: int) -> List[Dict[str, Any]]:
96 | """Split the content into smaller parts based on the max token length
97 | constraint.
98 |
99 | Args:
100 | raw_data (List[Dict[str, Any]]): The list of samples to split.
101 | tokenizer (PreTrainedTokenizer): The tokenizer object used for tokenization.
102 | max_length (int): The maximum length allowed for each split.
103 |
104 | Returns:
105 | List[Dict[str, Any]]: The list of new sample dictionaries after splitting.
106 | """
107 | global tokenizer, max_length
108 | tokenizer = tokenizer_
109 | max_length = max_length_
110 |
111 | new_content = []
112 |
113 | # Split content into chunks
114 | chunks = [raw_data[i:i + 1000] for i in range(0, len(raw_data), 1000)]
115 | # Use tqdm to show progress bar during the execution
116 | with ProcessPoolExecutor() as executor:
117 | for result in tqdm(executor.map(worker, chunks),
118 | desc='Splitting long conversations',
119 | total=len(chunks)):
120 | new_content.extend(result)
121 |
122 | return new_content
123 |
124 |
125 | def main(args):
126 | contents = json.load(open(args.in_file, 'r'))
127 | tokenizer = transformers.AutoTokenizer.from_pretrained(
128 | args.model_name_or_path,
129 | padding_side='right',
130 | model_max_length=args.max_length,
131 | use_fast=False,
132 | tokenizer_type='llama' if 'llama' in args.model_name_or_path else None,
133 | )
134 | print('Splitting long conversations...')
135 | split_data = split_all(contents, tokenizer, args.max_length)
136 | res1, res2 = get_statistics(split_data)
137 | # Save role_list_2 and role_res_2 to JSON files
138 | json_dump(res2, 'role_res_3.json')
139 | print(f'#in: {len(contents)}, #out: {len(split_data)}')
140 | print('Filtering invalid roles...')
141 | new_content = filter_invalid_roles(split_data)
142 | res1, res2 = get_statistics(new_content)
143 | # Save role_list_3 and role_res_3 to JSON files
144 | json_dump(res2, 'role_res_4.json')
145 | print(f'#in: {len(split_data)}, #out: {len(new_content)}')
146 | json_dump(new_content, args.out_file)
147 |
148 |
149 | if __name__ == '__main__':
150 | parser = argparse.ArgumentParser()
151 | parser.add_argument('--in-file', type=str, required=True)
152 | parser.add_argument('--out-file', type=str, default='sharegpt_split.json')
153 | parser.add_argument('--model-name-or-path', type=str, required=True)
154 | parser.add_argument('--max-length', type=int, default=2048)
155 | args = parser.parse_args()
156 | main(args)
157 |
--------------------------------------------------------------------------------
/data/format_data/clean_sharegpt/clean_sharegpt.py:
--------------------------------------------------------------------------------
1 | """Prepare all datasets."""
2 |
3 | import argparse
4 | import json
5 | from typing import Any, Dict, List, Tuple
6 |
7 |
8 | def json_dump(obj, path):
9 | with open(path, 'w', encoding='utf-8') as f:
10 | json.dump(obj, f, indent=2, ensure_ascii=False)
11 |
12 |
13 | def json_load(in_file):
14 | with open(in_file, 'r') as f:
15 | json_data = json.load(f)
16 | return json_data
17 |
18 |
19 | def get_statistics(
20 | raw_data: List[Dict[str,
21 | any]]) -> Tuple[List[str], Dict[str, List[str]]]:
22 | """Get statistics from raw_data.
23 |
24 | Args:
25 | raw_data: A list of dictionaries containing conversation data.
26 |
27 | Returns:
28 | A tuple containing the role list and a dictionary of role occurrences per ID.
29 | """
30 | role_list = []
31 | role_res = {}
32 |
33 | for idx, raw_txt in enumerate(raw_data):
34 | id = raw_txt.get('id', str(idx))
35 | if idx % 10000 == 0:
36 | print(f'Processing {idx} / {len(raw_data)}')
37 |
38 | convs = raw_txt.get('conversations', [])
39 | role_res[id] = []
40 |
41 | for conv in convs:
42 | sender = conv.get('from')
43 | role_res[id].append(sender)
44 |
45 | if sender not in role_list:
46 | role_list.append(sender)
47 |
48 | return role_list, role_res
49 |
50 |
51 | def format_roles(
52 | raw_data: List[Dict[str, List[Dict[str, str]]]]
53 | ) -> List[Dict[str, List[Dict[str, str]]]]:
54 | """Format the roles of conversations in raw_data.
55 |
56 | Args:
57 | raw_data: A list of dictionaries containing conversation data.
58 |
59 | Returns:
60 | A list of dictionaries containing formatted conversation data.
61 | """
62 | users = ['human', 'user']
63 | bots = ['gpt', 'bard', 'bing', 'chatgpt']
64 | role_list = users + bots
65 | collect_data = []
66 |
67 | for idx, raw_txt in enumerate(raw_data):
68 | convs = raw_txt.get('conversations', [])
69 | id = raw_txt.get('id', str(idx))
70 | new_convs = []
71 |
72 | for j, conv in enumerate(convs):
73 | sender = conv.get('from')
74 |
75 | if sender not in role_list:
76 | print(
77 | f"Warning: Role '{sender}' is not recognized. Skipping conversation."
78 | )
79 | continue
80 |
81 | if sender in users[1:]:
82 | print(f"Correcting '{sender}' to '{users[0]}'")
83 | conv['from'] = users[0]
84 |
85 | if sender in bots[1:]:
86 | print(f"Correcting '{sender}' to '{bots[0]}'")
87 | conv['from'] = bots[0]
88 |
89 | if conv['from'] and conv['value']:
90 | new_convs.append(conv)
91 |
92 | if len(new_convs) >= 2:
93 | collect_data.append({'id': id, 'conversations': new_convs})
94 | else:
95 | print(f'Warning: Skipping conversation {idx}.', new_convs)
96 | return collect_data
97 |
98 |
99 | def filter_invalid_roles(
100 | raw_data: List[Dict[str,
101 | any]]) -> List[Dict[str, List[Dict[str, any]]]]:
102 | """Filter out invalid contents based on the roles assigned to each
103 | conversation.
104 |
105 | Args:
106 | raw_data: A list of dictionaries containing conversation data.
107 |
108 | Returns:
109 | A list of dictionaries containing filtered conversation data.
110 | """
111 |
112 | roles = ['human', 'gpt']
113 | filtered_data = []
114 |
115 | for idx, contents in enumerate(raw_data):
116 | # Get conversations and id from the current dictionary
117 | convs = contents.get('conversations', [])
118 | id = contents.get('id', str(idx))
119 |
120 | # Remove first conversation if it is not from 'human' role
121 | if convs and convs[0].get('from') != 'human':
122 | convs = convs[1:]
123 |
124 | # Check if number of conversations is less than 2
125 | if len(convs) < 2:
126 | continue
127 |
128 | # Truncate convs to have an even number of conversations
129 | convs = convs[:len(convs) // 2 * 2]
130 |
131 | valid = True
132 | for j, conv in enumerate(convs):
133 | # Check if role of conversation alternates between 'human' and 'gpt'
134 | if conv.get('from') != roles[j % 2]:
135 | valid = False
136 | break
137 |
138 | assert len(convs) % 2 == 0, 'Number of conversations must be even.'
139 |
140 | if valid:
141 | # Append filtered data to the result
142 | filtered_data.append({'id': id, 'conversations': convs})
143 |
144 | return filtered_data
145 |
146 |
147 | def get_clean_data(args: Any, save_stats_res: bool = False) -> Any:
148 | """Get clean data by processing raw data using helper functions.
149 |
150 | Args:
151 | args: Arguments passed to the function.
152 |
153 | Returns:
154 | Cleaned data after processing.
155 | """
156 | # Load raw data from file
157 | with open(args.in_file, 'r') as file:
158 | raw_data = json.load(file)
159 |
160 | # Get statistics for raw_data
161 | print('Getting statistics for raw_data...')
162 | res1, res2 = get_statistics(raw_data)
163 |
164 |     if save_stats_res:
165 | # Save role_list and role_res to JSON files
166 | json_dump(res1, 'role_list.json')
167 | json_dump(res2, 'role_res.json')
168 |
169 | # Format roles in raw_data
170 | print('=' * 100)
171 | print('Formatting roles in raw_data...')
172 | clean_data1 = format_roles(raw_data)
173 |
174 | # Get statistics for clean_data1
175 | print('=' * 100)
176 | print('Getting statistics for clean_data1...')
177 | res1, res2 = get_statistics(clean_data1)
178 |
179 |     if save_stats_res:
180 | # Save role_list_1 and role_res_1 to JSON files
181 | json_dump(res1, 'role_list_clean.json')
182 | json_dump(res2, 'role_res_clean.json')
183 |
184 | # Filter out incorrect data from clean_data1
185 | print('=' * 100)
186 | print('Filtering out incorrect data from clean_data1...')
187 | clean_data2 = filter_invalid_roles(clean_data1)
188 | # Print lengths of raw data, clean data1, and clean data2
189 | print(f'raw data len: {len(raw_data)}')
190 | print(f'clean data1 len: {len(clean_data1)}')
191 | print(f'clean data2 len: {len(clean_data2)}')
192 | return clean_data2
193 |
194 |
195 | if __name__ == '__main__':
196 | parser = argparse.ArgumentParser()
197 | parser.add_argument('--in-file', type=str)
198 | parser.add_argument('--out-file', type=str)
199 | args = parser.parse_args()
200 | clean_data2 = get_clean_data(args)
201 | json_dump(clean_data2, args.out_file)
202 |
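203 | # Expected record shape (sketch), as consumed by the helpers above:
204 | #
205 | #   {"id": "abc123", "conversations": [
206 | #       {"from": "human", "value": "..."},
207 | #       {"from": "gpt", "value": "..."}]}
208 | #
209 | # Example invocation (file names are placeholders):
210 | #
211 | #   python clean_sharegpt.py --in-file raw_sharegpt.json --out-file sharegpt_clean.json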
--------------------------------------------------------------------------------
/data/format_data/clean_sharegpt/hardcoded_questions.py:
--------------------------------------------------------------------------------
1 | import json
2 |
3 |
4 | def identity_questions():
5 |     """Adapted from https://github.com/young-geng/koala_data_pipeline/blob/main/process_hard_coded_data.py.
6 |     """
7 | content = []
8 |
9 | name = 'Vicuna'
10 | org = 'Large Model Systems Organization (LMSYS)'
11 |
12 | def generate_conversations(questions, answers):
13 | for q in questions:
14 | for a in answers:
15 | content.append({
16 | 'id':
17 | f'identity_{len(content)}',
18 | 'conversations': [
19 | {
20 | 'from': 'human',
21 | 'value': q
22 | },
23 | {
24 | 'from': 'gpt',
25 | 'value': a
26 | },
27 | ],
28 | })
29 |
30 | questions = [
31 | 'Who are you?',
32 | 'What is your name?',
33 | 'Can you introduce yourself?',
34 | 'Can you tell me a little bit about yourself?',
35 | "What's your name?",
36 | 'What are you called?',
37 | 'What are you?',
38 | 'Tell me your name.',
39 | 'Tell me about yourself.',
40 | 'Tell me about you.',
41 | 'Tell me who you are.',
42 | 'Please introduce yourself.',
43 | ]
44 | answers = [
45 | f'I am {name}, a language model trained by researchers from {org}.',
46 | f"My name is {name}, and I'm a language model developed by {org}.",
47 | f'You can call me {name}, and I was trained by {org} researchers as a language model.',
48 | f'As a language model, I go by the name {name} and was trained by researchers from {org}.',
49 | f"I'm a language model called {name}, and I was trained by {org} researchers.",
50 | f'You may refer to me as {name}, a language model meticulously developed by the researchers at {org}.',
51 | ]
52 | generate_conversations(questions, answers)
53 |
54 | questions = [
55 | 'Who created you?',
56 | 'Who made you?',
57 | 'Who built you?',
58 | 'Who programmed you?',
59 | 'Who trained you?',
60 | 'Who taught you?',
61 | 'Who developed you?',
62 | ]
63 | answers = [
64 | f'Researchers from {org} created me.',
65 | f"I'm created by {org}.",
66 | f"I'm built by researchers from {org}.",
67 | f'I am a language model trained by researchers from {org}.',
68 | f"I'm a language model developed by {org}.",
69 | f"I'm a language model created by researchers from {org}.",
70 | f'My creators are researchers from {org}.',
71 | ]
72 | generate_conversations(questions, answers)
73 |
74 | questions = [
75 | 'Are you ChatGPT?',
76 | 'Are you GPT-2?',
77 | 'Are you GPT-3?',
78 | 'Are you GPT-4?',
79 | 'Are you davinci?',
80 | 'Are you davinci-001?',
81 | 'Are you davinci-002?',
82 | 'Are you davinci-003?',
83 | 'Are you curie?',
84 | 'Are you based on ChatGPT?',
85 | 'Are you based on GPT-2?',
86 | 'Are you based on GPT-3?',
87 | 'Are you based on GPT-4?',
88 | 'Are you based on davinci?',
89 | 'Are you based on davinci-001?',
90 | 'Are you based on davinci-002?',
91 | 'Are you based on davinci-003?',
92 | 'Are you based on curie?',
93 | 'Are you trained by OpenAI?',
94 | 'Are you trained by Google?',
95 | 'Are you trained by Microsoft?',
96 | 'Are you trained by Meta?',
97 | 'Are you trained by IBM?',
98 | 'Do you call OpenAI APIs?',
99 | 'Do you call Google APIs?',
100 | 'Do you call Microsoft APIs?',
101 | 'Do you call Meta APIs?',
102 | 'Do you call IBM APIs?',
103 | 'Are you created by OpenAI?',
104 | 'Are you created by Google?',
105 | 'Are you created by Microsoft?',
106 | 'Are you created by Meta?',
107 | 'Are you created by IBM?',
108 | 'Are you developed by OpenAI?',
109 | 'Are you developed by Google?',
110 | 'Are you developed by Microsoft?',
111 | 'Are you developed by Meta?',
112 | 'Are you developed by IBM?',
113 | 'Are you trained on OpenAI data?',
114 | 'Are you trained on Google data?',
115 | 'Are you trained on Microsoft data?',
116 | 'Are you trained on Meta data?',
117 | 'Are you trained on IBM data?',
118 | 'Are you trained with OpenAI data?',
119 | 'Are you trained with Google data?',
120 | 'Are you trained with Microsoft data?',
121 | 'Are you trained with Meta data?',
122 | 'Are you trained with IBM data?',
123 | 'Have you been trained with OpenAI data?',
124 | 'Have you been trained with Google data?',
125 | 'Have you been trained with Microsoft data?',
126 | 'Have you been trained with Meta data?',
127 | 'Have you been trained with IBM data?',
128 | 'Are you finetuned on OpenAI data?',
129 | 'Are you finetuned on Google data?',
130 | 'Are you finetuned on Microsoft data?',
131 | 'Are you finetuned on Meta data?',
132 | 'Are you finetuned on IBM data?',
133 | 'Are you finetuned with OpenAI data?',
134 | 'Are you finetuned with Google data?',
135 | 'Are you finetuned with Microsoft data?',
136 | 'Are you finetuned with Meta data?',
137 | 'Are you finetuned with IBM data?',
138 | 'Have you been finetuned with OpenAI data?',
139 | 'Have you been finetuned with Google data?',
140 | 'Have you been finetuned with Microsoft data?',
141 | 'Have you been finetuned with Meta data?',
142 | 'Have you been finetuned with IBM data?',
143 | ]
144 | answers = [
145 | f'No, I am a language model trained by researchers from {org}.',
146 | f'No, I am a language model developed by researchers from {org}.',
147 | f'No, I am a language model created by researchers from {org}.',
148 | f'No, I am trained by researchers from {org}.',
149 | f'No, I am developed by researchers from {org}.',
150 | f'No, I am created by researchers from {org}.',
151 | f"No, I'm a language model trained by researchers from {org}.",
152 | f"No, I'm a language model developed by researchers from {org}.",
153 | f"No, I'm a language model created by researchers from {org}.",
154 | f"No, I'm trained by researchers from {org}.",
155 | f"No, I'm developed by researchers from {org}.",
156 | f"No, I'm created by researchers from {org}.",
157 | ]
158 | generate_conversations(questions, answers)
159 |
160 | return content
161 |
162 |
163 | if __name__ == '__main__':
164 | out_file = 'hardcoded.json'
165 |
166 | content = []
167 | content.extend(identity_questions())
168 |
169 | json.dump(content, open(out_file, 'w'), indent=2)
170 |
--------------------------------------------------------------------------------
/llamatuner/train/pt/train_pt.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import logging
3 | import math
4 | import os
5 | import pathlib
6 | import sys
7 | import time
8 | from typing import Tuple
9 |
10 | import wandb
11 | from transformers import (AutoConfig, AutoModelForCausalLM, AutoTokenizer,
12 | DataCollatorForLanguageModeling, PreTrainedModel,
13 | PreTrainedTokenizer, Trainer, TrainingArguments)
14 |
15 | sys.path.append(os.getcwd())
16 | from llamatuner.configs import (DataArguments, FinetuningArguments,
17 | ModelArguments)
18 | from llamatuner.configs.parser import get_train_args
19 | from llamatuner.data.data_loader import get_dataset
20 | from llamatuner.utils.logger_utils import get_logger, get_outdir
21 |
22 |
23 | def load_model_tokenizer(
24 | model_args: ModelArguments,
25 | training_args: TrainingArguments,
26 | logger: logging.Logger,
27 | ) -> Tuple[PreTrainedModel, PreTrainedTokenizer]:
28 | """Load a pre-trained model and tokenizer for natural language processing tasks.
29 |
30 | Args:
31 | model_args (ModelArguments): Arguments for the model configuration.
32 | training_args (TrainingArguments): Arguments for the training configuration.
33 | logger (logging.Logger): Logger instance for logging messages.
34 |
35 | Returns:
36 | Tuple[PreTrainedModel, PreTrainedTokenizer]: A tuple containing the loaded model and tokenizer.
37 | """
38 | config_kwargs = {
39 | 'cache_dir': model_args.cache_dir,
40 | 'trust_remote_code': model_args.trust_remote_code,
41 | }
42 |
43 | # Set RoPE scaling factor
44 | config = AutoConfig.from_pretrained(model_args.model_name_or_path,
45 | **config_kwargs)
46 | # Load the pre-trained model
47 | logger.info(f'Loading Model from {model_args.model_name_or_path}...')
48 | model = AutoModelForCausalLM.from_pretrained(model_args.model_name_or_path,
49 | config=config,
50 | **config_kwargs)
51 |
52 | # Enable model parallelism
53 | setattr(model, 'model_parallel', True)
54 | setattr(model, 'is_parallelizable', True)
55 |
56 | if training_args.gradient_checkpointing:
57 | logger.info('Using gradient checkpointing...')
58 | model.enable_input_require_grads()
59 | model.config.use_cache = (
60 | False # Turn off when gradient checkpointing is enabled
61 | )
62 |
63 | # Load the tokenizer
64 | logger.info(f'Loading tokenizer from {model_args.model_name_or_path}...')
65 | tokenizer = AutoTokenizer.from_pretrained(
66 | model_args.model_name_or_path,
67 | padding_side='right',
68 | model_max_length=model_args.model_max_length,
69 | use_fast=False,
70 | **config_kwargs,
71 | )
72 |     # Make sure a pad token exists; fall back to the unk or eos token
73 |     if tokenizer.pad_token is None:
74 |         tokenizer.pad_token = tokenizer.unk_token or tokenizer.eos_token
75 |
76 | return model, tokenizer
77 |
78 |
79 | def run_pt(
80 | model_args: ModelArguments,
81 | data_args: DataArguments,
82 | training_args: TrainingArguments,
83 | finetuning_args: FinetuningArguments,
84 | ) -> None:
85 |
86 | args = argparse.Namespace(
87 | **vars(model_args),
88 | **vars(data_args),
89 | **vars(training_args),
90 | **vars(finetuning_args),
91 | )
92 | # Initialize the logger before other steps
93 | timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
94 | # Set up the output directory
95 | output_dir = get_outdir(training_args.output_dir)
96 | training_args.output_dir = get_outdir(output_dir, 'checkpoints')
97 | log_name = os.path.join(output_dir, timestamp).replace(os.path.sep, '_')
98 | log_file = os.path.join(output_dir, log_name + '.log')
99 | logger = get_logger(name='llamatuner', log_file=log_file, log_level='INFO')
100 |
101 | # Load model and tokenizer
102 | logger.info('Loading model and tokenizer...')
103 | model, tokenizer = load_model_tokenizer(model_args,
104 | training_args,
105 | logger=logger)
106 | logger.info('Successfully loaded model and tokenizer.')
107 |
108 | # Create a dataset and Trainer, then train the model
109 | logger.info('Creating a dataset and DataCollator...')
110 |
111 | dataset_module = get_dataset(
112 | data_args,
113 | model_args,
114 | training_args,
115 | stage='pt',
116 | tokenizer=tokenizer,
117 | processor=None,
118 | )
119 | logger.info('Successfully created the dataset.')
120 | logger.info('Creating DataCollator...')
121 | data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer,
122 | mlm=False)
123 | # Initialize wandb
124 | if 'wandb' in training_args.report_to:
125 | logger.info('Initializing wandb project...')
126 |         wandb_run_name = finetuning_args.wandb_run_name or log_name
127 | wandb.init(
128 | dir=output_dir,
129 | project=finetuning_args.wandb_project,
130 | name=wandb_run_name,
131 | tags=['full-finetune', 'pt'],
132 | group='pt',
133 | config=args,
134 | )
135 | # Initialize the Trainer object and start training
136 | logger.info('Initializing Trainer object.')
137 | trainer = Trainer(
138 | model=model,
139 | tokenizer=tokenizer,
140 | args=training_args,
141 | data_collator=data_collator,
142 | **dataset_module,
143 | )
144 | # Training
145 | if training_args.do_train:
146 | if (list(pathlib.Path(training_args.output_dir).glob('checkpoint-*'))
147 | and training_args.resume_from_checkpoint):
148 | logger.info('Resuming training from checkpoint %s' %
149 | (training_args.resume_from_checkpoint))
150 | train_result = trainer.train(
151 | resume_from_checkpoint=training_args.resume_from_checkpoint)
152 | else:
153 | logger.info('Starting training from scratch...')
154 | train_result = trainer.train()
155 |
156 | trainer.log_metrics('train', train_result.metrics)
157 | trainer.save_metrics('train', train_result.metrics)
158 | trainer.save_state()
159 | trainer.save_model()
160 |
161 | # Evaluation
162 | if training_args.do_eval:
163 | metrics = trainer.evaluate(metric_key_prefix='eval')
164 | try:
165 | perplexity = math.exp(metrics['eval_loss'])
166 | except OverflowError:
167 | perplexity = float('inf')
168 |
169 | metrics['perplexity'] = perplexity
170 | trainer.log_metrics('eval', metrics)
171 | trainer.save_metrics('eval', metrics)
172 |
173 | logger.info('Done.')
174 |
175 |
176 | if __name__ == '__main__':
177 | model_args, data_args, training_args, finetuning_args, generating_args = (
178 | get_train_args())
179 | run_pt(model_args, data_args, training_args, finetuning_args)
180 |
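181 | # Example launch (sketch): get_train_args() is assumed to parse the argument
182 | # dataclasses from the command line, so a full-parameter pre-training run could
183 | # be started roughly like this (model id, output dir, dataset name and the
184 | # distributed flags are placeholders):
185 | #
186 | #   torchrun --nproc_per_node=8 llamatuner/train/pt/train_pt.py \
187 | #       --model_name_or_path facebook/opt-125m \
188 | #       --output_dir work_dirs/pt_run \
189 | #       --dataset <dataset_name> \
190 | #       --do_train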
--------------------------------------------------------------------------------
/scripts/hf_download.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | # Color definitions
3 | RED='\033[0;31m'
4 | GREEN='\033[0;32m'
5 | YELLOW='\033[1;33m'
6 | NC='\033[0m' # No Color
7 |
8 | trap 'printf "${YELLOW}\nDownload interrupted. If you re-run the command, you can resume the download from the breakpoint.\n${NC}"; exit 1' INT
9 |
10 | display_help() {
11 | cat << EOF
12 | Usage:
13 | hfd <repo_id> [--include include_pattern] [--exclude exclude_pattern] [--hf_username username] [--hf_token token] [--tool aria2c|wget] [-x threads] [--dataset] [--local-dir path]
14 |
15 | Description:
16 | Downloads a model or dataset from Hugging Face using the provided repo ID.
17 |
18 | Parameters:
19 | repo_id The Hugging Face repo ID in the format 'org/repo_name'.
20 | --include (Optional) Flag to specify a string pattern to include files for downloading.
21 | --exclude (Optional) Flag to specify a string pattern to exclude files from downloading.
22 | include/exclude_pattern The pattern to match against filenames, supports wildcard characters. e.g., '--exclude *.safetensor', '--include vae/*'.
23 | --hf_username (Optional) Hugging Face username for authentication. **NOT EMAIL**.
24 | --hf_token (Optional) Hugging Face token for authentication.
25 | --tool (Optional) Download tool to use. Can be aria2c (default) or wget.
26 | -x (Optional) Number of download threads for aria2c. Defaults to 4.
27 | --dataset (Optional) Flag to indicate downloading a dataset.
28 | --local-dir (Optional) Local directory path where the model or dataset will be stored.
29 |
30 | Example:
31 | hfd bigscience/bloom-560m --exclude *.safetensors
32 | hfd meta-llama/Llama-2-7b --hf_username myuser --hf_token mytoken -x 4
33 | hfd lavita/medical-qa-shared-task-v1-toy --dataset
34 | EOF
35 | exit 1
36 | }
37 |
38 | MODEL_ID=$1
39 | shift
40 |
41 | # Default values
42 | TOOL="aria2c"
43 | THREADS=4
44 | HF_ENDPOINT=${HF_ENDPOINT:-"https://huggingface.co"}
45 |
46 | while [[ $# -gt 0 ]]; do
47 | case $1 in
48 | --include) INCLUDE_PATTERN="$2"; shift 2 ;;
49 | --exclude) EXCLUDE_PATTERN="$2"; shift 2 ;;
50 | --hf_username) HF_USERNAME="$2"; shift 2 ;;
51 | --hf_token) HF_TOKEN="$2"; shift 2 ;;
52 | --tool) TOOL="$2"; shift 2 ;;
53 | -x) THREADS="$2"; shift 2 ;;
54 | --dataset) DATASET=1; shift ;;
55 | --local-dir) LOCAL_DIR="$2"; shift 2 ;;
56 | *) shift ;;
57 | esac
58 | done
59 |
60 | # Check if aria2, wget, curl, git, and git-lfs are installed
61 | check_command() {
62 | if ! command -v $1 &>/dev/null; then
63 | echo -e "${RED}$1 is not installed. Please install it first.${NC}"
64 | exit 1
65 | fi
66 | }
67 |
68 | # Mark current repo safe when using shared file system like samba or nfs
69 | ensure_ownership() {
70 | if git status 2>&1 | grep "fatal: detected dubious ownership in repository at" > /dev/null; then
71 | git config --global --add safe.directory "${PWD}"
72 | printf "${YELLOW}Detected dubious ownership in repository, mark ${PWD} safe using git, edit ~/.gitconfig if you want to reverse this.\n${NC}"
73 | fi
74 | }
75 |
76 | [[ "$TOOL" == "aria2c" ]] && check_command aria2c
77 | [[ "$TOOL" == "wget" ]] && check_command wget
78 | check_command curl; check_command git; check_command git-lfs
79 |
80 | [[ -z "$MODEL_ID" || "$MODEL_ID" =~ ^-h ]] && display_help
81 |
82 | if [[ -z "$LOCAL_DIR" ]]; then
83 | LOCAL_DIR="${MODEL_ID#*/}"
84 | fi
85 |
86 | if [[ "$DATASET" == 1 ]]; then
87 | MODEL_ID="datasets/$MODEL_ID"
88 | fi
89 | echo "Downloading to $LOCAL_DIR"
90 |
91 | if [ -d "$LOCAL_DIR/.git" ]; then
92 | printf "${YELLOW}%s exists, Skip Clone.\n${NC}" "$LOCAL_DIR"
93 | cd "$LOCAL_DIR" && ensure_ownership && GIT_LFS_SKIP_SMUDGE=1 git pull || { printf "${RED}Git pull failed.${NC}\n"; exit 1; }
94 | else
95 | REPO_URL="$HF_ENDPOINT/$MODEL_ID"
96 | GIT_REFS_URL="${REPO_URL}/info/refs?service=git-upload-pack"
97 | echo "Testing GIT_REFS_URL: $GIT_REFS_URL"
98 | response=$(curl -s -o /dev/null -w "%{http_code}" "$GIT_REFS_URL")
99 | if [ "$response" == "401" ] || [ "$response" == "403" ]; then
100 | if [[ -z "$HF_USERNAME" || -z "$HF_TOKEN" ]]; then
101 | printf "${RED}HTTP Status Code: $response.\nThe repository requires authentication, but --hf_username and --hf_token is not passed. Please get token from https://huggingface.co/settings/tokens.\nExiting.\n${NC}"
102 | exit 1
103 | fi
104 | REPO_URL="https://$HF_USERNAME:$HF_TOKEN@${HF_ENDPOINT#https://}/$MODEL_ID"
105 | elif [ "$response" != "200" ]; then
106 | printf "${RED}Unexpected HTTP Status Code: $response\n${NC}"
107 | printf "${YELLOW}Executing debug command: curl -v %s\nOutput:${NC}\n" "$GIT_REFS_URL"
108 | curl -v "$GIT_REFS_URL"; printf "\n${RED}Git clone failed.\n${NC}"; exit 1
109 | fi
110 | echo "GIT_LFS_SKIP_SMUDGE=1 git clone $REPO_URL $LOCAL_DIR"
111 |
112 | GIT_LFS_SKIP_SMUDGE=1 git clone $REPO_URL $LOCAL_DIR && cd "$LOCAL_DIR" || { printf "${RED}Git clone failed.\n${NC}"; exit 1; }
113 |
114 | ensure_ownership
115 |
116 | while IFS= read -r file; do
117 | truncate -s 0 "$file"
118 |     done <<< "$(git lfs ls-files | cut -d ' ' -f 3-)"
119 | fi
120 |
121 | printf "\nStart Downloading lfs files, bash script:\ncd $LOCAL_DIR\n"
122 | files=$(git lfs ls-files | cut -d ' ' -f 3-)
123 | declare -a urls
124 |
125 | while IFS= read -r file; do
126 | url="$HF_ENDPOINT/$MODEL_ID/resolve/main/$file"
127 | file_dir=$(dirname "$file")
128 | mkdir -p "$file_dir"
129 | if [[ "$TOOL" == "wget" ]]; then
130 | download_cmd="wget -c \"$url\" -O \"$file\""
131 | [[ -n "$HF_TOKEN" ]] && download_cmd="wget --header=\"Authorization: Bearer ${HF_TOKEN}\" -c \"$url\" -O \"$file\""
132 | else
133 | download_cmd="aria2c --console-log-level=error --file-allocation=none -x $THREADS -s $THREADS -k 1M -c \"$url\" -d \"$file_dir\" -o \"$(basename "$file")\""
134 | [[ -n "$HF_TOKEN" ]] && download_cmd="aria2c --header=\"Authorization: Bearer ${HF_TOKEN}\" --console-log-level=error --file-allocation=none -x $THREADS -s $THREADS -k 1M -c \"$url\" -d \"$file_dir\" -o \"$(basename "$file")\""
135 | fi
136 | [[ -n "$INCLUDE_PATTERN" && ! "$file" == $INCLUDE_PATTERN ]] && printf "# %s\n" "$download_cmd" && continue
137 | [[ -n "$EXCLUDE_PATTERN" && "$file" == $EXCLUDE_PATTERN ]] && printf "# %s\n" "$download_cmd" && continue
138 | printf "%s\n" "$download_cmd"
139 | urls+=("$url|$file")
140 | done <<< "$files"
141 |
142 | for url_file in "${urls[@]}"; do
143 | IFS='|' read -r url file <<< "$url_file"
144 | printf "${YELLOW}Start downloading ${file}.\n${NC}"
145 | file_dir=$(dirname "$file")
146 | if [[ "$TOOL" == "wget" ]]; then
147 | [[ -n "$HF_TOKEN" ]] && wget --header="Authorization: Bearer ${HF_TOKEN}" -c "$url" -O "$file" || wget -c "$url" -O "$file"
148 | else
149 | [[ -n "$HF_TOKEN" ]] && aria2c --header="Authorization: Bearer ${HF_TOKEN}" --console-log-level=error --file-allocation=none -x $THREADS -s $THREADS -k 1M -c "$url" -d "$file_dir" -o "$(basename "$file")" || aria2c --console-log-level=error --file-allocation=none -x $THREADS -s $THREADS -k 1M -c "$url" -d "$file_dir" -o "$(basename "$file")"
150 | fi
151 | [[ $? -eq 0 ]] && printf "Downloaded %s successfully.\n" "$url" || { printf "${RED}Failed to download %s.\n${NC}" "$url"; exit 1; }
152 | done
153 |
154 | printf "${GREEN}Download completed successfully.\n${NC}"
155 |
--------------------------------------------------------------------------------
/llamatuner/configs/data_args.py:
--------------------------------------------------------------------------------
1 | from dataclasses import asdict, dataclass, field
2 | from typing import Any, Dict, Literal, Optional
3 |
4 |
5 | @dataclass
6 | class DataArguments:
7 | r"""
8 | Arguments pertaining to what data we are going to input our model for training and evaluation.
9 | """
10 |
11 | template: Optional[str] = field(
12 | default=None,
13 | metadata={
14 | 'help':
15 | 'Which template to use for constructing prompts in training and inference.'
16 | },
17 | )
18 | dataset: Optional[str] = field(
19 | default=None,
20 | metadata={
21 | 'help':
22 | 'The name of provided dataset(s) to use. Use commas to separate multiple datasets.'
23 | },
24 | )
25 | eval_dataset: Optional[str] = field(
26 | default=None,
27 | metadata={
28 | 'help':
29 | 'The name of dataset(s) to use for evaluation. Use commas to separate multiple datasets.'
30 | },
31 | )
32 | dataset_dir: str = field(
33 | default='data',
34 | metadata={'help': 'Path to the folder containing the datasets.'},
35 | )
36 | image_dir: Optional[str] = field(
37 | default=None,
38 | metadata={
39 | 'help':
40 | 'Path to the folder containing the images or videos. Defaults to `dataset_dir`.'
41 | },
42 | )
43 | cutoff_len: int = field(
44 | default=1024,
45 | metadata={
46 | 'help': 'The cutoff length of the tokenized inputs in the dataset.'
47 | },
48 | )
49 | train_on_prompt: bool = field(
50 | default=False,
51 | metadata={'help': 'Whether to disable the mask on the prompt or not.'},
52 | )
53 | mask_history: bool = field(
54 | default=False,
55 | metadata={
56 | 'help':
57 | 'Whether or not to mask the history and train on the last turn only.'
58 | },
59 | )
60 | streaming: bool = field(
61 | default=False,
62 | metadata={'help': 'Enable dataset streaming.'},
63 | )
64 | buffer_size: int = field(
65 | default=16384,
66 | metadata={
67 | 'help':
68 | 'Size of the buffer to randomly sample examples from in dataset streaming.'
69 | },
70 | )
71 | mix_strategy: Literal[
72 | 'concat', 'interleave_under', 'interleave_over'] = field(
73 | default='concat',
74 | metadata={
75 | 'help':
76 | 'Strategy to use in dataset mixing (concat/interleave) (undersampling/oversampling).'
77 | },
78 | )
79 | interleave_probs: Optional[str] = field(
80 | default=None,
81 | metadata={
82 | 'help':
83 | 'Probabilities to sample data from datasets. Use commas to separate multiple datasets.'
84 | },
85 | )
86 | overwrite_cache: bool = field(
87 | default=False,
88 | metadata={
89 | 'help': 'Overwrite the cached training and evaluation sets.'
90 | },
91 | )
92 | preprocessing_batch_size: int = field(
93 | default=1000,
94 | metadata={
95 | 'help': 'The number of examples in one group in pre-processing.'
96 | },
97 | )
98 | preprocessing_num_workers: Optional[int] = field(
99 | default=None,
100 | metadata={
101 | 'help': 'The number of processes to use for the pre-processing.'
102 | },
103 | )
104 | max_samples: Optional[int] = field(
105 | default=None,
106 | metadata={
107 | 'help':
108 | 'For debugging purposes, truncate the number of examples for each dataset.'
109 | },
110 | )
111 | eval_num_beams: Optional[int] = field(
112 | default=None,
113 | metadata={
114 | 'help':
115 | 'Number of beams to use for evaluation. This argument will be passed to `model.generate`'
116 | },
117 | )
118 | ignore_pad_token_for_loss: bool = field(
119 | default=True,
120 | metadata={
121 | 'help':
122 | 'Whether or not to ignore the tokens corresponding to padded labels in the loss computation.'
123 | },
124 | )
125 |     # Size of the validation dataset, i.e., the number (or fraction) of examples.
126 | eval_dataset_size: Optional[float] = field(
127 | default=0,
128 | metadata={
129 | 'help':
130 | 'Size of the development set, should be an integer or a float in range `[0,1)`.'
131 | },
132 | )
133 | packing: Optional[bool] = field(
134 | default=None,
135 | metadata={
136 | 'help':
137 | 'Whether or not to pack the sequences in training. Will automatically enable in pre-training.'
138 | },
139 | )
140 | tool_format: Optional[str] = field(
141 | default=None,
142 | metadata={
143 | 'help':
144 | 'Tool format to use for constructing function calling examples.'
145 | },
146 | )
147 | tokenized_path: Optional[str] = field(
148 | default=None,
149 | metadata={
150 | 'help':
151 | ('Path to save or load the tokenized datasets. '
152 | 'If tokenized_path not exists, it will save the tokenized datasets. '
153 | 'If tokenized_path exists, it will load the tokenized datasets.')
154 | },
155 | )
156 |
157 | def __post_init__(self):
158 |
159 | def split_arg(arg):
160 | if isinstance(arg, str):
161 | return [item.strip() for item in arg.split(',')]
162 | return arg
163 |
164 | if self.image_dir is None:
165 | self.image_dir = self.dataset_dir
166 |
167 | if self.dataset is None and self.eval_dataset_size > 0:
168 | raise ValueError(
169 | 'Cannot specify `eval_dataset_size` if `dataset` is None.')
170 |
171 | if self.eval_dataset is not None and self.eval_dataset_size > 0:
172 | raise ValueError(
173 | 'Cannot specify `eval_dataset_size` if `eval_dataset` is not None.'
174 | )
175 |
176 | if self.interleave_probs is not None:
177 | if self.mix_strategy == 'concat':
178 | raise ValueError(
179 | '`interleave_probs` is only valid for interleaved mixing.')
180 |
181 | self.interleave_probs = list(
182 | map(float, split_arg(self.interleave_probs)))
183 |             if self.dataset is not None and len(split_arg(
184 |                     self.dataset)) != len(self.interleave_probs):
185 | raise ValueError(
186 | 'The length of dataset and interleave probs should be identical.'
187 | )
188 |
189 |             if self.eval_dataset is not None and len(split_arg(
190 |                     self.eval_dataset)) != len(self.interleave_probs):
191 | raise ValueError(
192 | 'The length of eval dataset and interleave probs should be identical.'
193 | )
194 |
195 | if self.streaming and self.eval_dataset_size > 1e-6 and self.eval_dataset_size < 1:
196 |             raise ValueError('Streaming mode requires an integer `eval_dataset_size`.')
197 |
198 | if self.streaming and self.max_samples is not None:
199 | raise ValueError('`max_samples` is incompatible with `streaming`.')
200 |
201 | if self.mask_history and self.train_on_prompt:
202 | raise ValueError(
203 | '`mask_history` is incompatible with `train_on_prompt`.')
204 |
205 | def to_dict(self) -> Dict[str, Any]:
206 | return asdict(self)
207 |
--------------------------------------------------------------------------------
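
A minimal usage sketch for `DataArguments` (not part of the repository): it assumes the package is importable and shows how comma-separated CLI strings map onto the dataclass fields via `HfArgumentParser`. The dataset names and values below are placeholders.

    from transformers import HfArgumentParser

    from llamatuner.configs.data_args import DataArguments

    parser = HfArgumentParser(DataArguments)
    (data_args, ) = parser.parse_args_into_dataclasses(args=[
        '--dataset', 'alpaca_zh,sharegpt',  # comma-separated dataset names
        '--cutoff_len', '2048',
        '--eval_dataset_size', '0.05',      # 5% of the data held out for evaluation
    ])
    print(data_args.dataset)                  # 'alpaca_zh,sharegpt'
    print(data_args.to_dict()['cutoff_len'])  # 2048
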
/llamatuner/model/callbacks/perplexity.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | import sys
4 | from typing import List, Union
5 |
6 | import torch
7 | from torch.nn import CrossEntropyLoss
8 | from tqdm import tqdm
9 | from transformers import (AutoConfig, AutoModelForCausalLM, AutoTokenizer,
10 | HfArgumentParser)
11 |
12 | # Add parent directory to sys.path
13 | sys.path.append('../../')
14 |
15 | from llamatuner.configs import ModelArguments
16 | from llamatuner.utils.constants import IGNORE_INDEX
17 | from llamatuner.utils.model_utils import add_special_tokens_if_missing
18 |
19 |
20 | class ComputePerplexity:
21 | """Language model to compute perplexity.
22 |
23 | Args:
24 | cache_dir (str): Directory to cache models.
25 | model_name_or_path (str): Model name or path to load from Hub.
26 | trust_remote_code (bool): Whether to trust remote code.
27 | low_cpu_mem_usage (bool): Whether to use low CPU memory usage.
28 | max_length (int, optional): Max sequence length. Defaults to None.
29 | fp16 (bool): Whether to use 16-bit precision.
30 | device (str): Device to load model to.
31 | """
32 |
33 | def __init__(
34 | self,
35 | cache_dir: str = None,
36 | model_name_or_path: str = 'facebook/opt-125m',
37 | trust_remote_code: bool = False,
38 | low_cpu_mem_usage: bool = False,
39 | max_length: int = None,
40 | fp16: bool = False,
41 | device: str = 'cpu',
42 | ):
43 | # Determine the torch data type based on the input arguments
44 | torch_dtype = torch.float16 if fp16 else torch.float32
45 |
46 | config_kwargs = {
47 | 'cache_dir': cache_dir,
48 | 'trust_remote_code': trust_remote_code,
49 | }
50 | device_map = 'auto'
51 |
52 | # Set device map if running in distributed training (using environment variable LOCAL_RANK)
53 | if os.environ.get('LOCAL_RANK') is not None:
54 | local_rank = int(os.environ.get('LOCAL_RANK', '0'))
55 | device_map = {'': local_rank}
56 |
57 | # Load model and tokenizer
58 | self.tokenizer = AutoTokenizer.from_pretrained(
59 | model_name_or_path,
60 | padding_side='right',
61 | use_fast=False,
62 | **config_kwargs,
63 | )
64 | self.config = AutoConfig.from_pretrained(model_name_or_path,
65 | **config_kwargs)
66 |
67 | self.model = (AutoModelForCausalLM.from_pretrained(
68 | model_name_or_path,
69 | config=self.config,
70 | low_cpu_mem_usage=low_cpu_mem_usage,
71 | torch_dtype=torch_dtype,
72 | device_map=device_map,
73 | **config_kwargs,
74 | ).to(device).eval())
75 |
76 | # Loss function
77 | self.loss_fct = CrossEntropyLoss(reduction='none')
78 | # Max length
79 | self.max_length = (max_length if max_length is not None else
80 | self.tokenizer.model_max_length)
81 | assert (self.max_length <= self.tokenizer.model_max_length
82 | ), f'{self.max_length} > {self.tokenizer.model_max_length}'
83 | self.device = device
84 |
85 | self.pad_token_initialized = False
86 | logging.warning(f'Adding special tokens for {model_name_or_path}.')
87 | add_special_tokens_if_missing(self.tokenizer, self.model)
88 | self.pad_token_initialized = True
89 |
90 | def get_perplexity(self,
91 | input_texts: Union[str, List[str]],
92 | batch_size: int = None) -> Union[float, List[float]]:
93 | """Compute perplexity on input text(s).
94 |
95 | Args:
96 | input_texts (Union[str, List[str]]): Input text(s) to compute perplexity for.
97 | batch_size (int, optional): Batch size for perplexity computation.
98 |
99 | Returns:
100 | Union[float, List[float]]: Perplexity value(s) for the input text(s).
101 | """
102 |
103 | # Convert single input to list
104 | if isinstance(input_texts, str):
105 | input_texts = [input_texts]
106 |
107 | batch_size = len(input_texts) if batch_size is None else batch_size
108 | batch_id = list(range(0, len(input_texts),
109 | batch_size)) + [len(input_texts)]
110 | batch_id = list(zip(batch_id[:-1], batch_id[1:]))
111 |
112 | losses = []
113 | pbar = tqdm(batch_id, desc='Computing perplexity')
114 | for start_idx, end_idx in pbar:
115 | pbar.set_postfix({'batch': f'{start_idx}-{end_idx}'})
116 | input_text = input_texts[start_idx:end_idx]
117 | model_inputs = self.tokenizer(
118 | input_text,
119 | max_length=self.max_length,
120 | truncation=True,
121 | padding='max_length',
122 | return_tensors='pt',
123 | )
124 |
125 | if 'token_type_ids' in model_inputs:
126 | model_inputs.pop('token_type_ids')
127 |
128 | model_inputs = {
129 | k: v.to(self.device)
130 | for k, v in model_inputs.items()
131 | }
132 | with torch.no_grad():
133 | outputs = self.model(**model_inputs)
134 | logits = outputs.logits
135 |             if self.pad_token_initialized:  # drop the logit of the newly added pad token
136 | logits = logits[:, :, :-1]
137 |
138 | labels = model_inputs['input_ids']
139 | labels[labels == self.tokenizer.pad_token_id] = IGNORE_INDEX
140 |
141 | shift_logits = logits[..., :-1, :].contiguous()
142 | shift_labels = labels[:, 1:].contiguous()
143 |
144 | valid_length = (shift_labels != IGNORE_INDEX).sum(dim=-1)
145 |
146 | loss = self.loss_fct(
147 | shift_logits.view(-1, shift_logits.size(-1)),
148 | shift_labels.view(-1),
149 | )
150 |
151 | loss = loss.view(len(outputs['logits']), -1)
152 | loss = torch.sum(loss, -1) / valid_length
153 |
154 | perplexity = loss.exp().cpu().tolist()
155 | losses.extend(perplexity)
156 |
157 | return losses[0] if len(losses) == 1 else losses
158 |
159 |
160 | if __name__ == '__main__':
161 | # Parse command-line arguments
162 | parser = HfArgumentParser(ModelArguments)
163 | (model_args, ) = parser.parse_args_into_dataclasses()
164 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
165 | model_args.device = device
166 | scorer = ComputePerplexity(
167 | cache_dir=model_args.cache_dir,
168 | model_name_or_path=model_args.model_name_or_path,
169 | trust_remote_code=model_args.trust_remote_code,
170 | low_cpu_mem_usage=model_args.low_cpu_mem_usage,
171 | max_length=model_args.model_max_length,
172 | fp16=model_args.fp16,
173 | device=model_args.device,
174 | )
175 | text = [
176 | 'sentiment classification: I dropped my laptop on my knee, and someone stole my coffee. I am happy.',
177 | 'sentiment classification: I dropped my laptop on my knee, and someone stole my coffee. I am sad.',
178 | 'I dropped my laptop on my knee, and someone stole my coffee. I am sad.',
179 | 'I dropped my laptop on my knee, and someone stole my coffee. I am happy.',
180 | 'I dropped my laptop on my knee, and someone stole my coffee. I am sad.',
181 | 'I dropped my laptop on my knee, and someone stole my coffee. I am happy.',
182 | 'I dropped my laptop on my knee, and someone stole my coffee. I am sad.',
183 | 'I dropped my laptop on my knee, and someone stole my coffee. I am happy.',
184 | 'I dropped my laptop on my knee, and someone stole my coffee. I am sad.',
185 | 'I dropped my laptop on my knee, and someone stole my coffee. I am happy.',
186 | ]
187 | print(scorer.get_perplexity(text, batch_size=2))
188 |
--------------------------------------------------------------------------------
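
For reference, a standalone sketch (not part of the repository) of the shift-and-mask perplexity computation performed inside `get_perplexity`, written against a plain Hugging Face causal LM; the model name is only an example and matches the class default.

    import torch
    from torch.nn import CrossEntropyLoss
    from transformers import AutoModelForCausalLM, AutoTokenizer

    name = 'facebook/opt-125m'  # example model
    tokenizer = AutoTokenizer.from_pretrained(name)
    model = AutoModelForCausalLM.from_pretrained(name).eval()

    enc = tokenizer('I dropped my laptop on my knee. I am sad.', return_tensors='pt')
    with torch.no_grad():
        logits = model(**enc).logits

    # Predict token t from tokens < t: shift logits left and labels right.
    shift_logits = logits[:, :-1, :]
    shift_labels = enc['input_ids'][:, 1:]
    token_loss = CrossEntropyLoss(reduction='none')(
        shift_logits.reshape(-1, shift_logits.size(-1)), shift_labels.reshape(-1))
    # Perplexity is exp of the mean per-token negative log-likelihood.
    print(token_loss.mean().exp().item())
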
/server/gradio_webserver.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from typing import Union
3 |
4 | import gradio as gr
5 | import torch
6 | import transformers
7 | from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
8 |
9 | from llamatuner.train.apply_lora import apply_lora
10 | from llamatuner.utils.stream_server import Iteratorize, Stream
11 |
12 |
13 | class Prompter(object):
14 |
15 | def __init__(self) -> None:
16 | self.PROMPT_DICT = {
17 | 'prompt_input':
18 | ('Below is an instruction that describes a task, paired with an input that provides further context. '
19 | 'Write a response that appropriately completes the request.\n\n'
20 | '### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:'
21 | ),
22 | 'prompt_no_input':
23 | ('Below is an instruction that describes a task. '
24 | 'Write a response that appropriately completes the request.\n\n'
25 | '### Instruction:\n{instruction}\n\n### Response:'),
26 | }
 27 |         self.response_split = '### Response:'
28 |
29 | def generate_prompt(
30 | self,
31 | instruction: str,
32 | input: Union[None, str] = None,
33 | response: Union[None, str] = None,
34 | ):
35 | prompt_input, prompt_no_input = self.PROMPT_DICT[
36 | 'prompt_input'], self.PROMPT_DICT['prompt_no_input']
37 | if input is not None:
38 | prompt_text = prompt_input.format(instruction=instruction,
39 | input=input)
40 | else:
41 | prompt_text = prompt_no_input.format(instruction=instruction)
42 |
43 | if response:
44 | prompt_text = f'{prompt_text}{response}'
45 | return prompt_text
46 |
47 | def get_response(self, output: str) -> str:
 48 |         return output.split(self.response_split)[1].strip()
49 |
50 |
51 | def args_parser():
52 | parser = argparse.ArgumentParser()
53 | parser.add_argument('--model_name_or_path',
54 | default=None,
55 | type=str,
56 | required=True,
57 | help='Path to pre-trained model')
58 | parser.add_argument('--lora_model_name_or_path',
59 | default=None,
60 | type=str,
 61 |                         help='Path to the LoRA adapter weights')
62 | parser.add_argument('--no_cuda',
63 | action='store_true',
64 | help='Avoid using CUDA when available')
65 | parser.add_argument('--load_8bit',
66 | action='store_true',
 67 |                         help='Whether to load the model in 8-bit instead of fp16')
68 | args = parser.parse_args()
69 |
70 | args.device = torch.device(
71 | 'cuda' if torch.cuda.is_available() and not args.no_cuda else 'cpu')
72 | return args
73 |
74 |
75 | def main(args):
76 | if args.lora_model_name_or_path is not None:
77 | model, tokenizer = apply_lora(args.model_name_or_path,
78 | args.lora_model_name_or_path)
79 | else:
80 | tokenizer = AutoTokenizer.from_pretrained(
81 | pretrained_model_name_or_path=args.model_name_or_path,
82 | trust_remote_code=True)
83 | model = AutoModelForCausalLM.from_pretrained(
84 | pretrained_model_name_or_path=args.model_name_or_path,
85 | load_in_8bit=args.load_8bit,
86 | torch_dtype=torch.float16,
87 | device_map='auto',
88 | trust_remote_code=True)
89 |
 90 |     # Build the prompt helper used to format requests and parse responses
91 | prompter = Prompter()
92 |
93 | def evaluate(
94 | instruction,
95 | input=None,
96 | temperature=0.8,
97 | top_p=0.75,
98 | top_k=40,
99 | num_beams=4,
100 | max_new_tokens=128,
101 | stream_output=False,
102 | **kwargs,
103 | ):
104 | prompt = prompter.generate_prompt(instruction, input)
105 | inputs = tokenizer(prompt, return_tensors='pt')
106 | input_ids = inputs['input_ids'].to(args.device)
107 | generation_config = GenerationConfig(
108 | temperature=temperature,
109 | top_p=top_p,
110 | top_k=top_k,
111 | num_beams=num_beams,
112 | do_sample=True,
113 | no_repeat_ngram_size=6,
114 | repetition_penalty=1.8,
115 | **kwargs,
116 | )
117 |
118 | generate_params = {
119 | 'input_ids': input_ids,
120 | 'generation_config': generation_config,
121 | 'return_dict_in_generate': True,
122 | 'output_scores': True,
123 | 'max_new_tokens': max_new_tokens,
124 | }
125 |
126 | if stream_output:
127 | # Stream the reply 1 token at a time.
128 |             # This is based on the trick of using 'stopping_criteria' to create an iterator over the generated tokens.
129 |
130 | def generate_with_callback(callback=None, **kwargs):
131 | kwargs.setdefault('stopping_criteria',
132 | transformers.StoppingCriteriaList())
133 | kwargs['stopping_criteria'].append(
134 | Stream(callback_func=callback))
135 | with torch.no_grad():
136 | model.generate(**kwargs)
137 |
138 | def generate_with_streaming(**kwargs):
139 | return Iteratorize(generate_with_callback,
140 | kwargs,
141 | callback=None)
142 |
143 | with generate_with_streaming(**generate_params) as generator:
144 | for output in generator:
145 | # new_tokens = len(output) - len(input_ids[0])
146 | decoded_output = tokenizer.decode(output)
147 |
148 | if output[-1] in [tokenizer.eos_token_id]:
149 | break
150 |
151 | yield prompter.get_response(decoded_output)
152 | return # early return for stream_output
153 |
154 | # Without streaming
155 | with torch.no_grad():
156 | generation_output = model.generate(
157 | input_ids=input_ids,
158 | generation_config=generation_config,
159 | return_dict_in_generate=True,
160 | output_scores=True,
161 | max_new_tokens=max_new_tokens,
162 | )
163 | s = generation_output.sequences[0]
164 | output = tokenizer.decode(s)
165 | yield prompter.get_response(output)
166 |
167 |     description = 'Baichuan-7B is a 7B-parameter, LLaMA-style model fine-tuned to follow instructions.'
168 | server = gr.Interface(
169 | fn=evaluate,
170 | inputs=[
171 | gr.components.Textbox(lines=2,
172 | label='Instruction',
173 | placeholder='Tell me about alpacas.'),
174 | gr.components.Textbox(lines=2, label='Input', placeholder='none'),
175 | gr.components.Slider(minimum=0,
176 | maximum=1,
177 | value=0.1,
178 | label='Temperature'),
179 | gr.components.Slider(minimum=0,
180 | maximum=1,
181 | value=0.75,
182 | label='Top p'),
183 | gr.components.Slider(minimum=0,
184 | maximum=100,
185 | step=1,
186 | value=40,
187 | label='Top k'),
188 | gr.components.Slider(minimum=1,
189 | maximum=4,
190 | step=1,
191 | value=4,
192 | label='Beams'),
193 | gr.components.Slider(minimum=1,
194 | maximum=2000,
195 | step=1,
196 | value=128,
197 | label='Max tokens'),
198 | gr.components.Checkbox(label='Stream output'),
199 | ],
200 |         outputs=[gr.components.Textbox(
201 | lines=5,
202 | label='Output',
203 | )],
204 | title='Baichuan7B',
205 | description=description,
206 | )
207 |
208 | server.queue().launch(server_name='0.0.0.0', share=False)
209 |
210 |
211 | if __name__ == '__main__':
212 | args = args_parser()
213 | main(args)
214 |
--------------------------------------------------------------------------------
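
A small illustration (hypothetical; it assumes the repository root is on `sys.path` and the server dependencies are installed) of the prompt round trip that `evaluate` relies on: `generate_prompt` builds the Alpaca-style prompt and `get_response` strips everything up to the `### Response:` marker.

    from server.gradio_webserver import Prompter  # assumes repo root on sys.path

    prompter = Prompter()
    prompt = prompter.generate_prompt('Tell me about alpacas.')
    # A model's raw output echoes the prompt followed by its answer.
    raw_output = prompt + ' Alpacas are small camelids native to South America.'
    print(prompter.get_response(raw_output))
    # -> 'Alpacas are small camelids native to South America.'
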
/server/gradio_qlora_webserver.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import logging
3 | from typing import Union
4 |
5 | import gradio as gr
6 | import torch
7 | import transformers
8 | from transformers import GenerationConfig
9 |
10 | from llamatuner.configs import ModelInferenceArguments
11 | from llamatuner.model.load_pretrain_model import load_model_tokenizer
12 | from llamatuner.utils.stream_server import Iteratorize, Stream
13 |
14 | ALPACA_PROMPT_DICT = {
15 | 'prompt_input':
16 | ('Below is an instruction that describes a task, paired with an input that provides further context. '
17 | 'Write a response that appropriately completes the request.\n\n'
18 | '### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:'
19 | ),
20 | 'prompt_no_input':
21 | ('Below is an instruction that describes a task. '
22 | 'Write a response that appropriately completes the request.\n\n'
23 | '### Instruction:\n{instruction}\n\n### Response:'),
24 | }
25 |
26 | PROMPT_DICT = {
27 | 'prompt_input': ('{instruction}\n\n### Response:'),
28 | 'prompt_no_input': ('{instruction}\n\n### Response:'),
29 | }
30 |
31 | logger = logging.getLogger(__name__)
32 |
33 |
34 | class Prompter:
35 | """A class for generating prompts and extracting responses from generated
36 | text."""
37 |
38 | def __init__(self, prompt_template: str = None):
39 | """Initializes a new instance of the Prompter class.
40 |
41 | Args:
42 | prompt_template (str): The name of the prompt template to use. Default is None.
43 | If set to 'alpaca', it will use a different set of prompt templates.
44 | """
45 | self.PROMPT_DICT = ALPACA_PROMPT_DICT if prompt_template == 'alpaca' else PROMPT_DICT
 46 |         self.response_split = '### Response:'
47 |
48 | def generate_prompt(self,
49 | instruction: str,
50 | input: Union[str, None] = None,
51 | response: Union[str, None] = None) -> str:
52 | """Generates a prompt based on the specified inputs.
53 |
54 | Args:
55 | instruction (str): The instruction to include in the prompt.
56 | input (Union[str, None]): The input to include in the prompt. Default is None.
57 | response (Union[str, None]): The response to include in the prompt. Default is None.
58 |
59 | Returns:
60 | str: The generated prompt text.
61 | """
62 | prompt_input, prompt_no_input = self.PROMPT_DICT[
63 | 'prompt_input'], self.PROMPT_DICT['prompt_no_input']
64 |
65 | if input is not None:
66 | prompt_text = prompt_input.format(instruction=instruction,
67 | input=input)
68 | else:
69 | prompt_text = prompt_no_input.format(instruction=instruction)
70 |
71 | if response:
72 | prompt_text = f'{prompt_text}{response}'
73 |
74 | return prompt_text
75 |
76 | def get_response(self, output: str) -> str:
77 | """Extracts the response from the generated text.
78 |
79 | Args:
80 | output (str): The generated text to extract the response from.
81 |
82 | Returns:
83 | str: The extracted response.
84 | """
 85 |         return output.split(self.response_split)[1].strip()
86 |
87 |
88 | def main():
89 | parser = transformers.HfArgumentParser(ModelInferenceArguments)
90 | model_server_args, _ = parser.parse_args_into_dataclasses(
91 | return_remaining_strings=True)
92 | args = argparse.Namespace(**vars(model_server_args))
93 | args.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
94 |
95 | model, tokenizer = load_model_tokenizer(args,
96 | checkpoint_dir=args.checkpoint_dir,
97 | is_trainable=False,
98 | logger=logger)
99 | prompter = Prompter()
100 |
101 | def evaluate(
102 | instruction,
103 | input=None,
104 | temperature=1.0,
105 | top_p=1.0,
106 | top_k=50,
107 | num_beams=4,
108 | max_new_tokens=128,
109 | stream_output=False,
110 | **kwargs,
111 | ):
112 | prompt = prompter.generate_prompt(instruction, input)
113 | inputs = tokenizer(prompt, return_tensors='pt')
114 | inputs = inputs.to(args.device)
115 | generation_config = GenerationConfig(
116 | temperature=temperature,
117 | top_p=top_p,
118 | top_k=top_k,
119 | num_beams=num_beams,
120 | do_sample=True,
121 | # no_repeat_ngram_size=6,
122 | # repetition_penalty=1.8,
123 | **kwargs,
124 | )
125 |
126 | generate_params = {
127 | 'input_ids': inputs['input_ids'],
128 | 'generation_config': generation_config,
129 | 'return_dict_in_generate': True,
130 | 'output_scores': True,
131 | 'max_new_tokens': max_new_tokens,
132 | }
133 |
134 | if stream_output:
135 | # Stream the reply 1 token at a time.
136 |             # This is based on the trick of using 'stopping_criteria' to create an iterator over the generated tokens.
137 |
138 | def generate_with_callback(callback=None, **kwargs):
139 | kwargs.setdefault('stopping_criteria',
140 | transformers.StoppingCriteriaList())
141 | kwargs['stopping_criteria'].append(
142 | Stream(callback_func=callback))
143 | with torch.no_grad():
144 | model.generate(**kwargs)
145 |
146 | def generate_with_streaming(**kwargs):
147 | return Iteratorize(generate_with_callback,
148 | kwargs,
149 | callback=None)
150 |
151 | with generate_with_streaming(**generate_params) as generator:
152 | for output in generator:
153 | # new_tokens = len(output) - len(input_ids[0])
154 | decoded_output = tokenizer.decode(output)
155 |
156 | if output[-1] in [tokenizer.eos_token_id]:
157 | break
158 |
159 | yield prompter.get_response(decoded_output)
160 | return # early return for stream_output
161 |
162 | # Without streaming
163 | with torch.no_grad():
164 | generation_output = model.generate(
165 | **inputs,
166 | generation_config=generation_config,
167 | return_dict_in_generate=True,
168 | output_scores=True,
169 | max_new_tokens=max_new_tokens,
170 | )
171 | s = generation_output.sequences[0]
172 | output = tokenizer.decode(s)
173 | yield prompter.get_response(output)
174 |
175 |     description = 'Baichuan-7B is a 7B-parameter, LLaMA-style model fine-tuned to follow instructions.'
176 | server = gr.Interface(
177 | fn=evaluate,
178 | inputs=[
179 | gr.components.Textbox(lines=2,
180 | label='Instruction',
181 | placeholder='Tell me about alpacas.'),
182 | gr.components.Textbox(lines=2, label='Input', placeholder='none'),
183 | gr.components.Slider(minimum=0,
184 | maximum=1,
185 | value=1.0,
186 | label='Temperature'),
187 | gr.components.Slider(minimum=0,
188 | maximum=1,
189 | value=1.0,
190 | label='Top p'),
191 | gr.components.Slider(minimum=0,
192 | maximum=100,
193 | step=1,
194 | value=50,
195 | label='Top k'),
196 | gr.components.Slider(minimum=1,
197 | maximum=4,
198 | step=1,
199 | value=4,
200 | label='Beams'),
201 | gr.components.Slider(minimum=16,
202 | maximum=1024,
203 | step=32,
204 | value=128,
205 | label='Max new tokens'),
206 | gr.components.Checkbox(label='Stream output'),
207 | ],
208 |         outputs=[gr.components.Textbox(
209 | lines=5,
210 | label='Output',
211 | )],
212 | title='Baichuan7B',
213 | description=description,
214 | )
215 |
216 | server.queue().launch(server_name='0.0.0.0', share=True)
217 |
218 |
219 | if __name__ == '__main__':
220 | main()
221 |
--------------------------------------------------------------------------------
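
The Stream/Iteratorize callback trick used above predates built-in streaming support in transformers. As a point of comparison (not how this repository does it), newer transformers versions expose `TextIteratorStreamer`, which yields decoded text chunks directly; a minimal sketch with a placeholder model:

    from threading import Thread

    from transformers import (AutoModelForCausalLM, AutoTokenizer,
                              TextIteratorStreamer)

    name = 'facebook/opt-125m'  # placeholder model
    tokenizer = AutoTokenizer.from_pretrained(name)
    model = AutoModelForCausalLM.from_pretrained(name)

    inputs = tokenizer('Tell me about alpacas.', return_tensors='pt')
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True,
                                    skip_special_tokens=True)
    # generate() blocks, so run it in a worker thread and consume the streamer.
    thread = Thread(target=model.generate,
                    kwargs=dict(**inputs, streamer=streamer, max_new_tokens=64))
    thread.start()
    for text_chunk in streamer:
        print(text_chunk, end='', flush=True)
    thread.join()
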
/llamatuner/utils/logger_utils.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | import sys
4 | import threading
5 | from logging import Formatter, LogRecord
6 | from logging.handlers import RotatingFileHandler
7 | from pathlib import Path
8 | from typing import Optional, Union
9 |
10 | import torch.distributed as dist
11 | from colorama import Fore, Style
12 |
13 | logger_initialized: dict = {}
 14 | _logger_lock = threading.Lock()  # shared lock guarding logger_initialized
15 |
16 | class ColorfulFormatter(Formatter):
17 | """Formatter that adds ANSI color codes to log messages based on their
18 | level.
19 |
20 | Attributes:
21 | COLORS: Dictionary mapping log levels to their corresponding color codes
22 |
23 | Example:
24 | >>> formatter = ColorfulFormatter('%(levelname)s: %(message)s')
25 | >>> handler = logging.StreamHandler()
26 | >>> handler.setFormatter(formatter)
27 | """
28 |
29 | COLORS: dict[str, str] = {
30 | 'INFO': Fore.GREEN,
31 | 'WARNING': Fore.YELLOW,
32 | 'ERROR': Fore.RED,
33 | 'CRITICAL': Fore.RED + Style.BRIGHT,
34 | 'DEBUG': Fore.LIGHTGREEN_EX,
35 | }
36 |
37 | def format(self, record: LogRecord) -> str:
38 | """Format the log record with color coding.
39 |
40 | Args:
41 | record: The log record to format
42 |
43 | Returns:
44 | The formatted and color-coded log message
45 | """
46 | record.rank = int(os.getenv('LOCAL_RANK', '0'))
47 | log_message = super().format(record)
48 | color = self.COLORS.get(record.levelname, Fore.RESET)
49 | return f'{color}{log_message}{Fore.RESET}'
50 |
51 |
52 | def get_logger(
53 | name: str,
54 | log_file: Optional[Union[str, Path]] = None,
55 | log_level: int = logging.INFO,
56 | file_mode: str = 'w',
57 | ) -> logging.Logger:
58 | """Initialize and get a logger by name with optional file output.
59 |
60 | This function creates or retrieves a logger with the specified configuration.
61 | It handles distributed training scenarios by managing log levels across different
62 | process ranks and prevents duplicate logging issues with PyTorch DDP.
63 |
64 | Args:
65 | name: Logger name for identification and hierarchy
66 | log_file: Path to the log file. If provided, logs will also be written to this file
67 | (only for rank 0 process in distributed training)
68 | log_level: Logging level (e.g., logging.INFO, logging.DEBUG)
69 | Note: Only rank 0 process uses this level; others use ERROR level
70 | file_mode: File opening mode ('w' for write, 'a' for append)
71 |
72 | Returns:
73 | A configured logging.Logger instance
74 |
75 | Example:
76 | >>> logger = get_logger("my_model", "training.log", logging.DEBUG)
77 | >>> logger.info("Training started")
78 | """
79 | if file_mode not in ('w', 'a'):
80 | raise ValueError(f"Invalid file_mode: {file_mode}. Use 'w' or 'a'.")
81 |
 82 |     with _logger_lock:
83 | # Get or create logger instance
84 | logger = logging.getLogger(name)
85 |
86 | # Return existing logger if already initialized
87 | if name in logger_initialized:
88 | return logger
89 |
90 | # Check if parent logger is already initialized
91 | for logger_name in logger_initialized:
92 | if name.startswith(logger_name):
93 | return logger
94 |
95 | # Fix PyTorch DDP duplicate logging issue
96 | # Set root StreamHandler to ERROR level to prevent unwanted output from rank>0 processes
97 | for handler in logger.root.handlers:
98 | if isinstance(handler, logging.StreamHandler):
99 | handler.setLevel(logging.ERROR)
100 |
101 | # Initialize handlers list with stdout StreamHandler
102 | handlers = [logging.StreamHandler(sys.stdout)]
103 |
104 | # Determine process rank for distributed setup
105 | try:
106 | rank = dist.get_rank() if (dist.is_available()
107 | and dist.is_initialized()) else 0
108 | except Exception:
109 | rank = 0
110 |
111 | # Add FileHandler for rank 0 process if log_file is specified
112 | if rank == 0 and log_file is not None:
113 | log_file = Path(log_file)
114 | log_file.parent.mkdir(parents=True, exist_ok=True)
115 | file_handler = RotatingFileHandler(
116 | filename=str(log_file),
117 | mode=file_mode,
118 | maxBytes=10 * 1024 * 1024, # 10 MB
119 | backupCount=5,
120 | encoding='utf-8',
121 | )
122 | file_handler.setLevel(log_level)
123 | handlers.append(file_handler)
124 |
125 | # Configure formatter and handlers
126 | formatter = ColorfulFormatter(
127 |             fmt=('%(asctime)s - [rank%(rank)d] - [%(filename)s.%(funcName)s:%(lineno)d] - '
128 | '%(levelname)s - %(message)s'),
129 | datefmt='%Y-%m-%d %H:%M:%S',
130 | )
131 |
132 | # Inject rank into all log records
133 | old_factory = logging.getLogRecordFactory()
134 |
135 | def record_factory(*args, **kwargs):
136 | record = old_factory(*args, **kwargs)
137 | record.rank = rank # Dynamic rank injection
138 | return record
139 |
140 | logging.setLogRecordFactory(record_factory)
141 |
142 | # Apply configuration to all handlers
143 | for handler in handlers:
144 | handler.setFormatter(formatter)
145 | logger.addHandler(handler)
146 |
147 | # Set logger level based on rank
148 | logger.setLevel(log_level if rank == 0 else logging.ERROR)
149 | logger.propagate = False # Prevent propagation to root logger
150 |
151 | # Mark logger as initialized
152 | logger_initialized[name] = True
153 |
154 | return logger
155 |
156 |
157 | def print_log(msg, logger=None, level=logging.INFO):
158 | """Print a log message.
159 |
160 | Args:
161 | msg (str): The message to be logged.
162 | logger (logging.Logger | str | None): The logger to be used.
163 | Some special loggers are:
164 |
165 | - "silent": no message will be printed.
166 | - other str: the logger obtained with `get_root_logger(logger)`.
167 | - None: The `print()` method will be used to print log messages.
168 | level (int): Logging level. Only available when `logger` is a Logger
169 | object or "root".
170 | """
171 | if logger is None:
172 | print(msg)
173 | elif isinstance(logger, logging.Logger):
174 | logger.log(level, msg)
175 | elif logger == 'silent':
176 | pass
177 | elif isinstance(logger, str):
178 | _logger = get_logger(logger)
179 | _logger.log(level, msg)
180 | else:
181 | raise TypeError(
182 | 'logger should be either a logging.Logger object, str, '
183 | f'"silent" or None, but got {type(logger)}')
184 |
185 |
186 | def get_root_logger(log_file=None, log_level=logging.INFO):
187 | """Get root logger.
188 |
189 | Args:
190 | log_file (str, optional): File path of log. Defaults to None.
191 | log_level (int, optional): The level of logger.
192 | Defaults to logging.INFO.
193 |
194 | Returns:
195 | :obj:`logging.Logger`: The obtained logger
196 | """
197 | logger = get_logger(name='llamatuner',
198 | log_file=log_file,
199 | log_level=log_level)
200 |
201 | return logger
202 |
203 |
204 | def get_outdir(path: str, *paths, inc: bool = False) -> str:
205 | """Get the output directory. If the directory does not exist, it will be
206 | created. If `inc` is True, the directory will be incremented if the
207 | directory already exists.
208 |
209 | Args:
210 | path (str): The root path.
211 | *paths: The subdirectories.
212 | inc (bool, optional): Whether to increment the directory. Defaults to False.
213 |
214 | Returns:
215 | str: The output directory.
216 | """
217 | outdir = os.path.join(path, *paths)
218 | if not os.path.exists(outdir):
219 | os.makedirs(outdir)
220 | return outdir
221 | elif inc:
222 | for count in range(1, 100):
223 | outdir_inc = f'{outdir}-{count}'
224 | if not os.path.exists(outdir_inc):
225 | os.makedirs(outdir_inc)
226 | return outdir_inc
227 | raise RuntimeError(
228 | 'Failed to create unique output directory after 100 attempts')
229 | return outdir
230 |
231 |
232 | if __name__ == '__main__':
233 | # Initialize logger
234 | logger = get_logger('my_model', 'training.log', logging.DEBUG)
235 |
236 | # Log messages
237 | logger.debug('This is a debug message.')
238 | logger.info('This is an info message.')
239 | logger.warning('This is a warning message.')
240 | logger.error('This is an error message.')
241 | logger.critical('This is a critical message.')
242 |
--------------------------------------------------------------------------------
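
A short usage sketch (paths are placeholders) tying the helpers above together: `get_outdir` creates a unique run directory, `get_root_logger` attaches a rotating file handler to it on rank 0, and `print_log` routes messages through that logger or plain `print`.

    import logging

    from llamatuner.utils.logger_utils import (get_outdir, get_root_logger,
                                               print_log)

    # Creates work_dirs/sft-run, or work_dirs/sft-run-1 if it already exists.
    out_dir = get_outdir('work_dirs', 'sft-run', inc=True)
    logger = get_root_logger(log_file=f'{out_dir}/train.log',
                             log_level=logging.INFO)
    print_log(f'Writing logs to {out_dir}', logger=logger)
    print_log('Falls back to print() when no logger is given.', logger=None)
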