├── llamatuner ├── __init__.py ├── data │ ├── __init__.py │ ├── dataset_factory │ │ ├── __init__.py │ │ ├── dataset_utils.py │ │ ├── pt_dataset.py │ │ └── reward_dataset.py │ ├── processors │ │ └── supervised.py │ ├── data_collator.py │ └── utils.py ├── model │ ├── __init__.py │ └── callbacks │ │ ├── __init__.py │ │ ├── wandb_callback.py │ │ ├── save_peft_model_callback.py │ │ ├── metrics.py │ │ └── perplexity.py ├── train │ ├── __init__.py │ ├── pt │ │ ├── __init__.py │ │ └── train_pt.py │ ├── rm │ │ ├── __init__.py │ │ └── trainer.py │ ├── sft │ │ └── __init__.py │ ├── tuner.py │ └── apply_lora.py ├── utils │ ├── __init__.py │ ├── packages.py │ ├── env_utils.py │ ├── stream_server.py │ ├── misc.py │ └── logger_utils.py ├── launcher.py ├── configs │ ├── __init__.py │ ├── eval_args.py │ ├── generating_args.py │ ├── model_args.py │ └── data_args.py └── cli.py ├── assets └── wechat.jpg ├── .flake8 ├── examples ├── merge_lora │ ├── llama3_gptq.yaml │ └── llama3_lora_sft.yaml ├── sft_full_train │ ├── llama3_full_predict.yaml │ ├── qwen_full_sft.yaml │ ├── qwen_full_sft_ds.yaml │ └── llama3_full_sft_ds3.yaml ├── deepspeed │ ├── ds_z0_config.json │ ├── ds_z2_config.json │ ├── ds_z2_offload_config.json │ ├── ds_z3_config.json │ ├── ds_z3_offload_config.json │ ├── ds_zero3_offload.json │ └── ds_zero3_auto.json ├── sft_qlora_train │ ├── qwen_lora_sft_bitsandbytes.yaml │ └── llama3_lora_sft_bitsandbytes.yaml ├── pre-train │ └── qwen_full_pre-train.yaml └── sft_lora_train │ ├── llama3_lora_sft.yaml │ ├── qwen_lora_sft.yaml │ └── llama3_lora_sft_ds.yaml ├── data ├── format_data │ ├── clean_sharegpt │ │ ├── clean_evol_instruct.py │ │ ├── merge.py │ │ ├── split_long_conversation.py │ │ ├── clean_sharegpt.py │ │ └── hardcoded_questions.py │ ├── test_dataloader.py │ ├── merge.py │ ├── convert_oasst1.py │ ├── convert_vicuna.py │ └── convert_alpaca.py ├── ultra_chat │ └── ultra_chat.py ├── dataset_info.yaml └── hh_rlhf_en │ └── hh_rlhf_en.py ├── requirements.txt ├── scripts ├── hf_cli.sh ├── full_finetune │ ├── full-finetune_ds.sh │ └── full-finetune.sh ├── lora_finetune │ ├── lora-finetune.sh │ └── lora-finetune_ds.sh ├── qlora_finetune │ ├── qlora-finetune.sh │ ├── finetune_baichuan_7b_vicuna_zh.sh │ ├── finetune_llama2_7b_alpaca_zh.sh │ ├── finetune_llama_7b_alpaca_zh.sh │ └── finetune_baichuan_7b_alpaca_zh.sh ├── pre_train │ └── full-finetune.sh ├── ds_config │ ├── default_offload_opt_param.json │ └── ds_config_zero3_auto.json └── hf_download.sh ├── .pre-commit-config.yaml ├── server ├── multi_chat.py ├── single_chat.py ├── gradio_base_webserver.py ├── gradio_webserver.py └── gradio_qlora_webserver.py └── .gitignore /llamatuner/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /llamatuner/data/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /llamatuner/model/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /llamatuner/train/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /llamatuner/utils/__init__.py: -------------------------------------------------------------------------------- 1 | 
-------------------------------------------------------------------------------- /llamatuner/train/pt/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /llamatuner/train/rm/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /llamatuner/train/sft/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /llamatuner/data/dataset_factory/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /llamatuner/data/processors/supervised.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /assets/wechat.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jianzhnie/LLamaTuner/HEAD/assets/wechat.jpg -------------------------------------------------------------------------------- /llamatuner/launcher.py: -------------------------------------------------------------------------------- 1 | from llamatuner.train.tuner import run_exp 2 | 3 | 4 | def launch(): 5 | run_exp() 6 | 7 | 8 | if __name__ == '__main__': 9 | launch() 10 | -------------------------------------------------------------------------------- /.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | ignore = E501,E701,W504,W503,E722,E251,E402,E123, E126,E129, E125,E121 3 | max-line-length = 79 4 | show-source = False 5 | application-import-names = llamatuner 6 | exclude = 7 | .git 8 | docs 9 | venv 10 | build 11 | env 12 | log 13 | dist 14 | *.egg-info 15 | __pycache__ 16 | -------------------------------------------------------------------------------- /examples/merge_lora/llama3_gptq.yaml: -------------------------------------------------------------------------------- 1 | ### model 2 | model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct 3 | template: llama3 4 | 5 | ### export 6 | export_dir: models/llama3_gptq 7 | export_quantization_bit: 4 8 | export_quantization_dataset: data/c4_demo.json 9 | export_size: 2 10 | export_device: cpu 11 | export_legacy_format: false 12 | -------------------------------------------------------------------------------- /examples/merge_lora/llama3_lora_sft.yaml: -------------------------------------------------------------------------------- 1 | ### Note: DO NOT use quantized model or quantization_bit when merging lora adapters 2 | 3 | ### model 4 | model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct 5 | adapter_name_or_path: saves/llama3-8b/lora/sft 6 | template: llama3 7 | finetuning_type: lora 8 | 9 | ### export 10 | export_dir: models/llama3_lora_sft 11 | export_size: 2 12 | export_device: cpu 13 | export_legacy_format: false 14 | -------------------------------------------------------------------------------- /llamatuner/model/callbacks/__init__.py: -------------------------------------------------------------------------------- 1 | from .metrics import ComputeMetrics 2 | from .perplexity import ComputePerplexity 3 | from .save_peft_model_callback import 
SavePeftModelCallback
4 | from .wandb_callback import WandbCallback
5 |
6 | __all__ = [
7 |     'ComputeMetrics',
8 |     'ComputePerplexity',
9 |     'SavePeftModelCallback',
10 |     'WandbCallback',
11 | ]
12 |
--------------------------------------------------------------------------------
/data/format_data/clean_sharegpt/clean_evol_instruct.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | from clean_sharegpt import get_clean_data, json_dump
4 |
5 | if __name__ == '__main__':
6 |     parser = argparse.ArgumentParser()
7 |     parser.add_argument('--in-file', type=str)
8 |     parser.add_argument('--out-file', type=str)
9 |     args = parser.parse_args()
10 |
11 |     clean_data2 = get_clean_data(args)
12 |     json_dump(clean_data2, args.out_file)
13 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | accelerate @ git+https://github.com/huggingface/accelerate.git
2 | bitsandbytes
3 | colorama
4 | datasets
5 | deepspeed
6 | einops
7 | evaluate>=0.4.0
8 | gradio
9 | jieba
10 | nltk>=3.8.1
11 | numpy
12 | peft @ git+https://github.com/huggingface/peft.git
13 | rouge-chinese
14 | sentencepiece
15 | tokenizers
16 | torch
17 | transformers @ git+https://github.com/huggingface/transformers.git
18 | wandb
--------------------------------------------------------------------------------
/scripts/hf_cli.sh:
--------------------------------------------------------------------------------
1 |
2 | # download dataset
3 | huggingface-cli download --repo-type dataset tatsu-lab/alpaca --local-dir /home/robin/hf_hub/datasets/alpaca-en
4 |
5 | # download model
6 | huggingface-cli download Qwen/Qwen1.5-0.5B --local-dir /home/robin/hf_hub/models/Qwen/Qwen1.5-0.5B
7 | huggingface-cli download Qwen/Qwen1.5-1.8B --local-dir /home/robin/hf_hub/models/Qwen/Qwen1.5-1.8B
8 | huggingface-cli download Qwen/Qwen1.5-7B --local-dir /home/robin/hf_hub/models/Qwen/Qwen1.5-7B
--------------------------------------------------------------------------------
/examples/sft_full_train/llama3_full_predict.yaml:
--------------------------------------------------------------------------------
1 | ### model
2 | model_name_or_path: work_dir/llama3-8b/full/sft/
3 |
4 | ### method
5 | stage: sft
6 | do_predict: true
7 | finetuning_type: full
8 |
9 | ### dataset
10 | dataset: alpaca
11 | template: llama3
12 | cutoff_len: 1024
13 | max_samples: 50
14 | overwrite_cache: true
15 | preprocessing_num_workers: 16
16 |
17 | ### output
18 | output_dir: work_dir/llama3-8b/full/sft/predict
19 | overwrite_output_dir: true
20 |
21 | ### eval
22 | per_device_eval_batch_size: 1
23 | predict_with_generate: true
--------------------------------------------------------------------------------
/llamatuner/configs/__init__.py:
--------------------------------------------------------------------------------
1 | from llamatuner.configs.data_args import DataArguments
2 | from llamatuner.configs.eval_args import EvaluationArguments
3 | from llamatuner.configs.finetuning_args import (FinetuningArguments,
4 |                                                  FreezeArguments, LoraArguments,
5 |                                                  QuantArguments, RLHFArguments)
6 | from llamatuner.configs.generating_args import GeneratingArguments
7 | from llamatuner.configs.model_args import ModelArguments
8 |
9 | __all__ = [
10 |     'DataArguments',
11 |     'GeneratingArguments',
12 |     'ModelArguments',
13
| 'QuantArguments', 14 | 'FinetuningArguments', 15 | 'EvaluationArguments', 16 | 'FreezeArguments', 17 | 'LoraArguments', 18 | 'RLHFArguments', 19 | ] 20 | -------------------------------------------------------------------------------- /llamatuner/data/dataset_factory/dataset_utils.py: -------------------------------------------------------------------------------- 1 | from typing import Sequence 2 | 3 | from numpy.typing import NDArray 4 | from PIL import Image 5 | from PIL.Image import Image as ImageObject 6 | from transformers import ProcessorMixin 7 | from transformers.image_processing_utils import BaseImageProcessor 8 | 9 | 10 | def _preprocess_visual_inputs(images: Sequence[ImageObject], 11 | processor: ProcessorMixin) -> 'NDArray': 12 | # process visual inputs (currently only supports a single image) 13 | image_processor: BaseImageProcessor = getattr(processor, 'image_processor') 14 | image = (images[0] if len(images) != 0 else Image.new( 15 | 'RGB', (100, 100), (255, 255, 255))) 16 | return image_processor(image, return_tensors='pt')['pixel_values'][0] 17 | -------------------------------------------------------------------------------- /examples/deepspeed/ds_z0_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "train_batch_size": "auto", 3 | "train_micro_batch_size_per_gpu": "auto", 4 | "gradient_accumulation_steps": "auto", 5 | "gradient_clipping": "auto", 6 | "zero_allow_untested_optimizer": true, 7 | "fp16": { 8 | "enabled": "auto", 9 | "loss_scale": 0, 10 | "loss_scale_window": 1000, 11 | "initial_scale_power": 16, 12 | "hysteresis": 2, 13 | "min_loss_scale": 1 14 | }, 15 | "bf16": { 16 | "enabled": "auto" 17 | }, 18 | "zero_optimization": { 19 | "stage": 0, 20 | "allgather_partitions": true, 21 | "allgather_bucket_size": 5e8, 22 | "overlap_comm": true, 23 | "reduce_scatter": true, 24 | "reduce_bucket_size": 5e8, 25 | "contiguous_gradients": true, 26 | "round_robin_gradients": true 27 | } 28 | } 29 | -------------------------------------------------------------------------------- /examples/deepspeed/ds_z2_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "train_batch_size": "auto", 3 | "train_micro_batch_size_per_gpu": "auto", 4 | "gradient_accumulation_steps": "auto", 5 | "gradient_clipping": "auto", 6 | "zero_allow_untested_optimizer": true, 7 | "fp16": { 8 | "enabled": "auto", 9 | "loss_scale": 0, 10 | "loss_scale_window": 1000, 11 | "initial_scale_power": 16, 12 | "hysteresis": 2, 13 | "min_loss_scale": 1 14 | }, 15 | "bf16": { 16 | "enabled": "auto" 17 | }, 18 | "zero_optimization": { 19 | "stage": 2, 20 | "allgather_partitions": true, 21 | "allgather_bucket_size": 5e8, 22 | "overlap_comm": true, 23 | "reduce_scatter": true, 24 | "reduce_bucket_size": 5e8, 25 | "contiguous_gradients": true, 26 | "round_robin_gradients": true 27 | } 28 | } 29 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://gitee.com/openmmlab/mirrors-flake8 3 | rev: 5.0.4 4 | hooks: 5 | - id: flake8 6 | - repo: https://gitee.com/openmmlab/mirrors-isort 7 | rev: 5.11.5 8 | hooks: 9 | - id: isort 10 | - repo: https://gitee.com/openmmlab/mirrors-yapf 11 | rev: v0.32.0 12 | hooks: 13 | - id: yapf 14 | - repo: https://gitee.com/openmmlab/mirrors-pre-commit-hooks 15 | rev: v4.3.0 16 | hooks: 17 | - id: trailing-whitespace 18 | - id: 
check-yaml 19 | - id: end-of-file-fixer 20 | - id: requirements-txt-fixer 21 | - id: double-quote-string-fixer 22 | - id: check-merge-conflict 23 | - id: fix-encoding-pragma 24 | args: ["--remove"] 25 | - id: mixed-line-ending 26 | args: ["--fix=lf"] 27 | -------------------------------------------------------------------------------- /scripts/full_finetune/full-finetune_ds.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0 torchrun --nproc_per_node=1 llamatuner/train/sft/train.py \ 2 | --model_name_or_path facebook/opt-125m \ 3 | --dataset_cfg ./data/run_test.yaml \ 4 | --output_dir work_dir/full-finetune \ 5 | --num_train_epochs 3 \ 6 | --max_train_samples 200 \ 7 | --per_device_train_batch_size 4 \ 8 | --per_device_eval_batch_size 4 \ 9 | --gradient_accumulation_steps 8 \ 10 | --evaluation_strategy "steps" \ 11 | --save_strategy "steps" \ 12 | --eval_steps 1000 \ 13 | --save_steps 1000 \ 14 | --save_total_limit 5 \ 15 | --logging_steps 1 \ 16 | --learning_rate 2e-5 \ 17 | --weight_decay 0. \ 18 | --warmup_ratio 0.03 \ 19 | --lr_scheduler_type "cosine" \ 20 | --logging_steps 1 \ 21 | --deepspeed "scripts/ds_config/ds_config_zero3_auto.json" 22 | -------------------------------------------------------------------------------- /data/format_data/test_dataloader.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | sys.path.append(os.getcwd()) 5 | from transformers import AutoTokenizer, HfArgumentParser 6 | 7 | from llamatuner.configs import (DataArguments, FinetuningArguments, 8 | ModelArguments) 9 | from llamatuner.data.data_loader import load_single_dataset 10 | from llamatuner.data.data_parser import get_dataset_list 11 | 12 | if __name__ == '__main__': 13 | parser = HfArgumentParser( 14 | (ModelArguments, DataArguments, FinetuningArguments)) 15 | (model_args, data_args, 16 | training_args) = parser.parse_args_into_dataclasses() 17 | tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path) 18 | dataset_list = get_dataset_list(data_args) 19 | dataset = load_single_dataset(dataset_list[0], model_args, data_args) 20 | print(dataset[:2]) 21 | -------------------------------------------------------------------------------- /data/format_data/clean_sharegpt/merge.py: -------------------------------------------------------------------------------- 1 | """Merge two conversation files into one. 
2 | 3 | Usage: python3 -m fastchat.data.merge --in file1.json file2.json --out merged.json 4 | """ 5 | 6 | import argparse 7 | 8 | from clean_sharegpt import json_dump, json_load 9 | 10 | if __name__ == '__main__': 11 | parser = argparse.ArgumentParser() 12 | parser.add_argument('--in-file', type=str, required=True, nargs='+') 13 | parser.add_argument('--out-file', type=str, default='merged.json') 14 | args = parser.parse_args() 15 | 16 | new_content = [] 17 | for in_file in args.in_file: 18 | content = json_load(in_file) 19 | print(f'in-file: {in_file}, len: {len(content)}') 20 | new_content.extend(content) 21 | 22 | print(f'#out: {len(new_content)}') 23 | print(f'Save new_content to {args.out_file}') 24 | json_dump(new_content, args.out_file) 25 | -------------------------------------------------------------------------------- /examples/deepspeed/ds_z2_offload_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "train_batch_size": "auto", 3 | "train_micro_batch_size_per_gpu": "auto", 4 | "gradient_accumulation_steps": "auto", 5 | "gradient_clipping": "auto", 6 | "zero_allow_untested_optimizer": true, 7 | "fp16": { 8 | "enabled": "auto", 9 | "loss_scale": 0, 10 | "loss_scale_window": 1000, 11 | "initial_scale_power": 16, 12 | "hysteresis": 2, 13 | "min_loss_scale": 1 14 | }, 15 | "bf16": { 16 | "enabled": "auto" 17 | }, 18 | "zero_optimization": { 19 | "stage": 2, 20 | "offload_optimizer": { 21 | "device": "cpu", 22 | "pin_memory": true 23 | }, 24 | "allgather_partitions": true, 25 | "allgather_bucket_size": 5e8, 26 | "overlap_comm": true, 27 | "reduce_scatter": true, 28 | "reduce_bucket_size": 5e8, 29 | "contiguous_gradients": true, 30 | "round_robin_gradients": true 31 | } 32 | } 33 | -------------------------------------------------------------------------------- /scripts/lora_finetune/lora-finetune.sh: -------------------------------------------------------------------------------- 1 | python llamatuner/train/sft/train_lora.py \ 2 | --model_name_or_path facebook/opt-125m \ 3 | --dataset alpaca \ 4 | --output_dir work_dir/lora-finetune \ 5 | --wandb_project llamatuner \ 6 | --wandb_run_name alpaca_opt-125m_lora-finetune \ 7 | --num_train_epochs 3 \ 8 | --per_device_train_batch_size 4 \ 9 | --per_device_eval_batch_size 4 \ 10 | --gradient_accumulation_steps 8 \ 11 | --eval_strategy "steps" \ 12 | --save_strategy "steps" \ 13 | --eval_steps 100 \ 14 | --save_steps 500 \ 15 | --save_total_limit 5 \ 16 | --logging_steps 1 \ 17 | --learning_rate 1e-4 \ 18 | --weight_decay 0. \ 19 | --warmup_ratio 0.03 \ 20 | --optim "adamw_torch" \ 21 | --lr_scheduler_type "cosine" \ 22 | --gradient_checkpointing True \ 23 | --trust_remote_code \ 24 | --model_max_length 128 \ 25 | --do_train \ 26 | --do_eval 27 | -------------------------------------------------------------------------------- /scripts/lora_finetune/lora-finetune_ds.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0 torchrun --nproc_per_node=1 train_lora.py \ 2 | --model_name_or_path facebook/opt-125m \ 3 | --data_path ~/prompt_data/InstructionWild/instinwild_en.json \ 4 | --output_dir work_dir/alpaca_full-finetune \ 5 | --num_train_epochs 3 \ 6 | --per_device_train_batch_size 4 \ 7 | --per_device_eval_batch_size 4 \ 8 | --gradient_accumulation_steps 8 \ 9 | --evaluation_strategy "no" \ 10 | --save_strategy "steps" \ 11 | --save_steps 500 \ 12 | --save_total_limit 5 \ 13 | --learning_rate 2e-5 \ 14 | --weight_decay 0. 
\ 15 | --warmup_ratio 0.03 \ 16 | --optim "adamw_torch" \ 17 | --lr_scheduler_type "cosine" \ 18 | --model_max_length 2048 \ 19 | --logging_steps 1 \ 20 | --do_train \ 21 | --do_eval \ 22 | --gradient_checkpointing True \ 23 | --deepspeed "scripts/ds_config/ds_config_zero3_auto.json" 24 | -------------------------------------------------------------------------------- /examples/deepspeed/ds_z3_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "train_batch_size": "auto", 3 | "train_micro_batch_size_per_gpu": "auto", 4 | "gradient_accumulation_steps": "auto", 5 | "gradient_clipping": "auto", 6 | "zero_allow_untested_optimizer": true, 7 | "fp16": { 8 | "enabled": "auto", 9 | "loss_scale": 0, 10 | "loss_scale_window": 1000, 11 | "initial_scale_power": 16, 12 | "hysteresis": 2, 13 | "min_loss_scale": 1 14 | }, 15 | "bf16": { 16 | "enabled": "auto" 17 | }, 18 | "zero_optimization": { 19 | "stage": 3, 20 | "overlap_comm": true, 21 | "contiguous_gradients": true, 22 | "sub_group_size": 1e9, 23 | "reduce_bucket_size": "auto", 24 | "stage3_prefetch_bucket_size": "auto", 25 | "stage3_param_persistence_threshold": "auto", 26 | "stage3_max_live_parameters": 1e9, 27 | "stage3_max_reuse_distance": 1e9, 28 | "stage3_gather_16bit_weights_on_model_save": true 29 | } 30 | } 31 | -------------------------------------------------------------------------------- /scripts/qlora_finetune/qlora-finetune.sh: -------------------------------------------------------------------------------- 1 | python llamatuner/train/sft/train_lora.py \ 2 | --model_name_or_path facebook/opt-125m \ 3 | --dataset alpaca \ 4 | --use_qlora \ 5 | --output_dir work_dir/lora-finetune \ 6 | --wandb_project llamatuner \ 7 | --wandb_run_name alpaca_opt-125m_lora-finetune \ 8 | --num_train_epochs 3 \ 9 | --per_device_train_batch_size 4 \ 10 | --per_device_eval_batch_size 4 \ 11 | --gradient_accumulation_steps 8 \ 12 | --eval_strategy "steps" \ 13 | --save_strategy "steps" \ 14 | --eval_steps 100 \ 15 | --save_steps 500 \ 16 | --save_total_limit 5 \ 17 | --logging_steps 1 \ 18 | --learning_rate 1e-4 \ 19 | --weight_decay 0. \ 20 | --warmup_ratio 0.03 \ 21 | --optim "adamw_torch" \ 22 | --lr_scheduler_type "cosine" \ 23 | --gradient_checkpointing True \ 24 | --trust_remote_code \ 25 | --model_max_length 128 \ 26 | --do_train \ 27 | --do_eval 28 | -------------------------------------------------------------------------------- /scripts/full_finetune/full-finetune.sh: -------------------------------------------------------------------------------- 1 | python llamatuner/train/sft/train_full.py \ 2 | --model_name_or_path facebook/opt-125m \ 3 | --dataset alpaca \ 4 | --eval_dataset alpaca \ 5 | --output_dir work_dir/full-finetune \ 6 | --wandb_project llamatuner \ 7 | --wandb_run_name alpaca_opt-125m_full-finetune \ 8 | --num_train_epochs 3 \ 9 | --per_device_train_batch_size 4 \ 10 | --per_device_eval_batch_size 4 \ 11 | --gradient_accumulation_steps 8 \ 12 | --eval_strategy "steps" \ 13 | --save_strategy "steps" \ 14 | --eval_steps 100 \ 15 | --save_steps 500 \ 16 | --save_total_limit 5 \ 17 | --logging_steps 10 \ 18 | --learning_rate 2e-5 \ 19 | --weight_decay 0. 
\ 20 | --warmup_ratio 0.03 \ 21 | --optim "adamw_torch" \ 22 | --lr_scheduler_type "cosine" \ 23 | --gradient_checkpointing True \ 24 | --trust_remote_code \ 25 | --model_max_length 128 \ 26 | --do_train \ 27 | --do_eval 28 | -------------------------------------------------------------------------------- /scripts/pre_train/full-finetune.sh: -------------------------------------------------------------------------------- 1 | python llamatuner/train/pt/train_pt.py \ 2 | --model_name_or_path Qwen/Qwen2.5-1.5B \ 3 | --trust_remote_code \ 4 | --dataset open-web-math \ 5 | --eval_dataset c4_demo \ 6 | --template qwen \ 7 | --cutoff_len 64 \ 8 | --max_samples 1000 \ 9 | --preprocessing_num_workers 16 \ 10 | --output_dir work_dir/pre_train \ 11 | --wandb_project llamatuner \ 12 | --wandb_run_name qwen-1.5B_pre_train \ 13 | --logging_steps 10 \ 14 | --save_strategy "steps" \ 15 | --save_steps 500 \ 16 | --save_total_limit 5 \ 17 | --per_device_train_batch_size 1 \ 18 | --per_device_eval_batch_size 1 \ 19 | --gradient_accumulation_steps 1 \ 20 | --eval_strategy "steps" \ 21 | --eval_steps 100 \ 22 | --num_train_epochs 3 \ 23 | --learning_rate 2e-5 \ 24 | --weight_decay 0. \ 25 | --warmup_ratio 0.1 \ 26 | --optim "adamw_torch" \ 27 | --lr_scheduler_type "cosine" \ 28 | --gradient_checkpointing True \ 29 | --do_train \ 30 | --do_eval 31 | -------------------------------------------------------------------------------- /examples/sft_qlora_train/qwen_lora_sft_bitsandbytes.yaml: -------------------------------------------------------------------------------- 1 | ### model 2 | model_name_or_path: Qwen/Qwen1.5-0.5B 3 | quant_bit: 4 4 | use_qlora : true 5 | 6 | ### method 7 | stage: sft 8 | do_train: true 9 | finetuning_type: lora 10 | lora_target: q_proj,v_proj 11 | 12 | ### dataset 13 | dataset: alpaca 14 | template: qwen 15 | cutoff_len: 1024 16 | max_samples: 1000 17 | overwrite_cache: true 18 | preprocessing_num_workers: 16 19 | 20 | ### output 21 | output_dir: work_dir/qwen1.5-0.5b/qlora/sft 22 | save_strategy: steps 23 | save_steps: 500 24 | save_total_limit: 5 25 | overwrite_output_dir: true 26 | 27 | # logger settings 28 | logging_steps: 10 29 | report_to: wandb 30 | wandb_project: llamatuner 31 | wandb_run_name: qwen1.5-0.5b_qlora_sft 32 | 33 | ### train 34 | per_device_train_batch_size: 1 35 | gradient_accumulation_steps: 8 36 | gradient_checkpointing: True 37 | num_train_epochs: 3.0 38 | learning_rate: 0.0001 39 | lr_scheduler_type: cosine 40 | warmup_ratio: 0.1 41 | fp16: true 42 | 43 | ### eval 44 | eval_dataset_size: 0.1 45 | per_device_eval_batch_size: 1 46 | eval_strategy: steps 47 | eval_steps: 500 48 | -------------------------------------------------------------------------------- /examples/sft_qlora_train/llama3_lora_sft_bitsandbytes.yaml: -------------------------------------------------------------------------------- 1 | ### model 2 | model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct 3 | quant_bit: 4 4 | use_qlora : true 5 | 6 | ### method 7 | stage: sft 8 | do_train: true 9 | finetuning_type: lora 10 | lora_target: all 11 | 12 | ### dataset 13 | dataset: alpaca 14 | template: llama3 15 | cutoff_len: 1024 16 | max_samples: 1000 17 | overwrite_cache: true 18 | preprocessing_num_workers: 16 19 | 20 | ### output 21 | output_dir: work_dir/llama3-8b/qlora/sft/ 22 | save_strategy: steps 23 | save_steps: 500 24 | save_total_limit: 5 25 | overwrite_output_dir: true 26 | 27 | # logger settings 28 | logging_steps: 10 29 | report_to: wandb 30 | wandb_project: llamatuner 31 | 
wandb_run_name: llama3-8b_qlora_sft 32 | 33 | ### train 34 | per_device_train_batch_size: 1 35 | gradient_accumulation_steps: 8 36 | gradient_checkpointing: True 37 | num_train_epochs: 3.0 38 | learning_rate: 1.0e-4 39 | lr_scheduler_type: cosine 40 | warmup_ratio: 0.1 41 | fp16: true 42 | 43 | ### eval 44 | eval_dataset_size: 0.1 45 | per_device_eval_batch_size: 1 46 | eval_strategy: steps 47 | eval_steps: 500 48 | -------------------------------------------------------------------------------- /examples/sft_full_train/qwen_full_sft.yaml: -------------------------------------------------------------------------------- 1 | ### model 2 | model_name_or_path: Qwen/Qwen2.5-0.5B 3 | trust_remote_code: true 4 | 5 | ### method 6 | stage: sft 7 | do_train: true 8 | finetuning_type: full 9 | 10 | ### ddp 11 | ddp_timeout: 180000000 12 | 13 | ### dataset 14 | dataset: alpaca 15 | eval_dataset: alpaca 16 | template: qwen 17 | cutoff_len: 1024 18 | overwrite_cache: true 19 | preprocessing_num_workers: 8 20 | 21 | ### output 22 | output_dir: work_dir/sft/qwen1.5-0.5b/ 23 | save_strategy: steps 24 | save_steps: 500 25 | save_total_limit: 5 26 | overwrite_output_dir: true 27 | 28 | # logger settings 29 | logging_steps: 10 30 | report_to: wandb 31 | wandb_project: llamatuner 32 | wandb_run_name: qwen1.5-0.5b_full_sft_alpaca 33 | 34 | ### train 35 | per_device_train_batch_size: 1 36 | gradient_accumulation_steps: 2 37 | gradient_checkpointing: True 38 | num_train_epochs: 3.0 39 | optim: adamw_torch 40 | learning_rate: 1.0e-4 41 | lr_scheduler_type: cosine 42 | weight_decay: 0.01 43 | warmup_ratio: 0.1 44 | fp16: true 45 | 46 | ### eval 47 | per_device_eval_batch_size: 1 48 | eval_strategy: steps 49 | eval_steps: 500 50 | -------------------------------------------------------------------------------- /examples/pre-train/qwen_full_pre-train.yaml: -------------------------------------------------------------------------------- 1 | ### model 2 | model_name_or_path: Qwen/Qwen2.5-0.5B 3 | trust_remote_code: true 4 | 5 | 6 | ### method 7 | stage: pt 8 | do_train: true 9 | finetuning_type: full 10 | 11 | ### ddp 12 | ddp_timeout: 180000000 13 | 14 | ### dataset 15 | dataset: c4_demo 16 | eval_dataset: c4_demo 17 | template: qwen 18 | cutoff_len: 1024 19 | overwrite_cache: true 20 | preprocessing_num_workers: 8 21 | 22 | ### output 23 | output_dir: work_dir/pretrain/qwen2.5-0.5b/ 24 | save_strategy: steps 25 | save_steps: 500 26 | save_total_limit: 5 27 | overwrite_output_dir: true 28 | 29 | # logger settings 30 | logging_steps: 10 31 | report_to: wandb 32 | wandb_project: llamatuner 33 | wandb_run_name: qwen2.5-0.5b_pretrain 34 | 35 | ### train 36 | per_device_train_batch_size: 1 37 | gradient_accumulation_steps: 2 38 | gradient_checkpointing: True 39 | num_train_epochs: 3.0 40 | optim: adamw_torch 41 | learning_rate: 1.0e-4 42 | lr_scheduler_type: cosine 43 | weight_decay: 0.01 44 | warmup_ratio: 0.1 45 | fp16: true 46 | 47 | ### eval 48 | per_device_eval_batch_size: 1 49 | eval_strategy: steps 50 | eval_steps: 500 51 | -------------------------------------------------------------------------------- /examples/deepspeed/ds_z3_offload_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "train_batch_size": "auto", 3 | "train_micro_batch_size_per_gpu": "auto", 4 | "gradient_accumulation_steps": "auto", 5 | "gradient_clipping": "auto", 6 | "zero_allow_untested_optimizer": true, 7 | "fp16": { 8 | "enabled": "auto", 9 | "loss_scale": 0, 10 | "loss_scale_window": 1000, 
11 | "initial_scale_power": 16, 12 | "hysteresis": 2, 13 | "min_loss_scale": 1 14 | }, 15 | "bf16": { 16 | "enabled": "auto" 17 | }, 18 | "zero_optimization": { 19 | "stage": 3, 20 | "offload_optimizer": { 21 | "device": "cpu", 22 | "pin_memory": true 23 | }, 24 | "offload_param": { 25 | "device": "cpu", 26 | "pin_memory": true 27 | }, 28 | "overlap_comm": true, 29 | "contiguous_gradients": true, 30 | "sub_group_size": 1e9, 31 | "reduce_bucket_size": "auto", 32 | "stage3_prefetch_bucket_size": "auto", 33 | "stage3_param_persistence_threshold": "auto", 34 | "stage3_max_live_parameters": 1e9, 35 | "stage3_max_reuse_distance": 1e9, 36 | "stage3_gather_16bit_weights_on_model_save": true 37 | } 38 | } 39 | -------------------------------------------------------------------------------- /examples/sft_lora_train/llama3_lora_sft.yaml: -------------------------------------------------------------------------------- 1 | ### model 2 | model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct 3 | 4 | ### method 5 | stage: sft 6 | do_train: true 7 | finetuning_type: lora 8 | lora_target: all 9 | ### ddp 10 | ddp_timeout: 180000000 11 | 12 | ### dataset 13 | dataset: alpaca 14 | template: llama3 15 | cutoff_len: 1024 16 | max_samples: 1000 17 | overwrite_cache: true 18 | preprocessing_num_workers: 16 19 | 20 | ### output 21 | output_dir: work_dir/llama3-8b/lora/sft/ 22 | save_strategy: steps 23 | save_steps: 500 24 | save_total_limit: 5 25 | overwrite_output_dir: true 26 | 27 | # logger settings 28 | logging_steps: 10 29 | report_to: wandb 30 | wandb_project: llamatuner 31 | wandb_run_name: llama3-8b_lora_sft 32 | 33 | ### train 34 | per_device_train_batch_size: 1 35 | gradient_accumulation_steps: 8 36 | gradient_checkpointing: True 37 | num_train_epochs: 3.0 38 | optim: adamw_torch 39 | learning_rate: 1.0e-4 40 | lr_scheduler_type: cosine 41 | weight_decay: 0.0 42 | warmup_ratio: 0.1 43 | fp16: True 44 | 45 | ### eval 46 | eval_dataset_size: 0.1 47 | per_device_eval_batch_size: 1 48 | eval_strategy: steps 49 | eval_steps: 500 50 | -------------------------------------------------------------------------------- /examples/sft_lora_train/qwen_lora_sft.yaml: -------------------------------------------------------------------------------- 1 | ### model 2 | model_name_or_path: Qwen/Qwen1.5-0.5B 3 | 4 | ### method 5 | stage: sft 6 | do_train: true 7 | finetuning_type: lora 8 | lora_target: q_proj,v_proj 9 | 10 | ### ddp 11 | ddp_timeout: 180000000 12 | 13 | ### dataset 14 | dataset: alpaca 15 | template: qwen 16 | cutoff_len: 1024 17 | max_samples: 1000 18 | overwrite_cache: true 19 | preprocessing_num_workers: 16 20 | 21 | ### output 22 | output_dir: work_dir/qwen1.5-0.5b/lora/sft/ 23 | save_strategy: steps 24 | save_steps: 500 25 | save_total_limit: 5 26 | overwrite_output_dir: true 27 | 28 | # logger settings 29 | logging_steps: 10 30 | report_to: wandb 31 | wandb_project: llamatuner 32 | wandb_run_name: qwen1.5-0.5b_lora_sft_alpaca 33 | 34 | ### train 35 | per_device_train_batch_size: 4 36 | gradient_accumulation_steps: 8 37 | gradient_checkpointing: True 38 | num_train_epochs: 3.0 39 | optim: adamw_torch 40 | learning_rate: 1.0e-4 41 | lr_scheduler_type: cosine 42 | weight_decay: 0.0 43 | warmup_ratio: 0.1 44 | fp16: True 45 | 46 | ### eval 47 | eval_dataset_size: 0.1 48 | per_device_eval_batch_size: 4 49 | eval_strategy: steps 50 | eval_steps: 500 51 | -------------------------------------------------------------------------------- /examples/sft_full_train/qwen_full_sft_ds.yaml: 
-------------------------------------------------------------------------------- 1 | ### model 2 | model_name_or_path: Qwen/Qwen1.5-0.5B 3 | 4 | ### method 5 | stage: sft 6 | do_train: true 7 | finetuning_type: full 8 | 9 | ### ddp 10 | ddp_timeout: 180000000 11 | deepspeed: examples/deepspeed/ds_z3_config.json 12 | 13 | ### dataset 14 | dataset: alpaca 15 | template: qwen 16 | cutoff_len: 1024 17 | max_samples: 1000 18 | overwrite_cache: true 19 | preprocessing_num_workers: 16 20 | 21 | ### output 22 | output_dir: work_dir/qwen1.5-0.5b/full/sft/ 23 | save_strategy: steps 24 | save_steps: 500 25 | save_total_limit: 5 26 | overwrite_output_dir: true 27 | 28 | # logger settings 29 | logging_steps: 10 30 | report_to: wandb 31 | wandb_project: llamatuner 32 | wandb_run_name: qwen1.5-0.5b_full_sft_alpaca 33 | 34 | ### train 35 | per_device_train_batch_size: 1 36 | gradient_accumulation_steps: 2 37 | gradient_checkpointing: True 38 | num_train_epochs: 3.0 39 | optim: adamw_torch 40 | learning_rate: 1.0e-4 41 | lr_scheduler_type: cosine 42 | weight_decay: 0.01 43 | warmup_ratio: 0.1 44 | fp16: true 45 | 46 | ### eval 47 | eval_dataset_size: 0.1 48 | per_device_eval_batch_size: 1 49 | eval_strategy: steps 50 | eval_steps: 500 51 | -------------------------------------------------------------------------------- /examples/sft_lora_train/llama3_lora_sft_ds.yaml: -------------------------------------------------------------------------------- 1 | ### model 2 | model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct 3 | 4 | ### method 5 | stage: sft 6 | do_train: true 7 | finetuning_type: lora 8 | lora_target: all 9 | 10 | ### ddp 11 | ddp_timeout: 180000000 12 | deepspeed: examples/deepspeed/ds_z3_config.json 13 | 14 | ### dataset 15 | dataset: alpaca 16 | template: llama3 17 | cutoff_len: 1024 18 | max_samples: 1000 19 | overwrite_cache: true 20 | preprocessing_num_workers: 16 21 | 22 | ### output 23 | output_dir: work_dir/llama3-8b/lora/sft/ 24 | save_strategy: steps 25 | save_steps: 500 26 | save_total_limit: 5 27 | overwrite_output_dir: true 28 | 29 | ### logging 30 | logging_steps: 10 31 | report_to: wandb 32 | wandb_project: llamatuner 33 | wandb_run_name: alpaca_llama3-8b_lora-sft 34 | 35 | ### train 36 | per_device_train_batch_size: 1 37 | gradient_accumulation_steps: 2 38 | gradient_checkpointing: True 39 | learning_rate: 1.0e-4 40 | lr_scheduler_type: cosine 41 | weight_decay: 0.0 42 | warmup_ratio: 0.1 43 | num_train_epochs: 3.0 44 | fp16: True 45 | 46 | ### eval 47 | eval_dataset_size: 0.1 48 | per_device_eval_batch_size: 1 49 | eval_strategy: steps 50 | eval_steps: 500 51 | -------------------------------------------------------------------------------- /examples/sft_full_train/llama3_full_sft_ds3.yaml: -------------------------------------------------------------------------------- 1 | ### model 2 | model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct 3 | 4 | ### method 5 | stage: sft 6 | do_train: true 7 | finetuning_type: full 8 | 9 | ### ddp 10 | ddp_timeout: 180000000 11 | deepspeed: examples/deepspeed/ds_z3_config.json 12 | 13 | ### dataset 14 | dataset: alpaca 15 | template: llama3 16 | cutoff_len: 1024 17 | max_samples: 1000 18 | overwrite_cache: true 19 | preprocessing_num_workers: 16 20 | 21 | ### output 22 | output_dir: work_dir/llama3-8b/full/sft 23 | save_strategy: steps 24 | save_steps: 500 25 | save_total_limit: 5 26 | overwrite_output_dir: true 27 | 28 | # logger settings 29 | logging_steps: 10 30 | report_to: wandb 31 | wandb_project: llamatuner 32 | wandb_run_name: 
llama3-8b_full_sft_alpaca 33 | 34 | ### train 35 | per_device_train_batch_size: 1 36 | gradient_accumulation_steps: 2 37 | gradient_checkpointing: True 38 | num_train_epochs: 3.0 39 | optim: adamw_torch 40 | learning_rate: 1.0e-4 41 | lr_scheduler_type: cosine 42 | weight_decay: 0.01 43 | warmup_ratio: 0.1 44 | fp16: true 45 | 46 | ### eval 47 | eval_dataset_size: 0.1 48 | per_device_eval_batch_size: 1 49 | eval_strategy: steps 50 | eval_steps: 500 51 | -------------------------------------------------------------------------------- /scripts/qlora_finetune/finetune_baichuan_7b_vicuna_zh.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=1 python train_qlora.py \ 2 | --model_name_or_path ~/checkpoints/baichuan7b \ 3 | --dataset_cfg ./data/vicuna_zh_pcyn.yaml \ 4 | --output_dir ./work_dir/vicuna_zh-baichuan-7b \ 5 | --num_train_epochs 3 \ 6 | --per_device_train_batch_size 2 \ 7 | --per_device_eval_batch_size 2 \ 8 | --gradient_accumulation_steps 16 \ 9 | --evaluation_strategy steps \ 10 | --eval_steps 1000 \ 11 | --save_strategy steps \ 12 | --save_total_limit 10 \ 13 | --save_steps 1000 \ 14 | --logging_strategy steps \ 15 | --logging_steps 5 \ 16 | --learning_rate 0.0002 \ 17 | --warmup_ratio 0.03 \ 18 | --weight_decay 0.0 \ 19 | --lr_scheduler_type constant \ 20 | --adam_beta2 0.999 \ 21 | --max_grad_norm 0.3 \ 22 | --lora_r 64 \ 23 | --lora_alpha 16 \ 24 | --lora_dropout 0.1 \ 25 | --double_quant \ 26 | --quant_type nf4 \ 27 | --fp16 \ 28 | --bits 4 \ 29 | --model_max_length 1024 \ 30 | --gradient_checkpointing \ 31 | --trust_remote_code True \ 32 | --use_auth_token True \ 33 | --do_train \ 34 | --do_eval \ 35 | --data_seed 42 \ 36 | --seed 0 37 | -------------------------------------------------------------------------------- /scripts/qlora_finetune/finetune_llama2_7b_alpaca_zh.sh: -------------------------------------------------------------------------------- 1 | python train_qlora.py \ 2 | --model_name_or_path meta-llama/Llama-2-7b-hf \ 3 | --dataset_cfg ./data/alpaca_zh_pcyn.yaml \ 4 | --output_dir ./work_dir/alpaca_zh_llama2-7b \ 5 | --num_train_epochs 3 \ 6 | --per_device_train_batch_size 4 \ 7 | --per_device_eval_batch_size 4 \ 8 | --gradient_accumulation_steps 8 \ 9 | --evaluation_strategy steps \ 10 | --eval_steps 1000 \ 11 | --save_strategy steps \ 12 | --save_total_limit 10 \ 13 | --save_steps 1000 \ 14 | --logging_strategy steps \ 15 | --logging_steps 5 \ 16 | --learning_rate 0.0002 \ 17 | --warmup_ratio 0.03 \ 18 | --weight_decay 0.0 \ 19 | --lr_scheduler_type constant \ 20 | --adam_beta2 0.999 \ 21 | --max_grad_norm 0.3 \ 22 | --lora_r 64 \ 23 | --lora_alpha 16 \ 24 | --lora_dropout 0.1 \ 25 | --double_quant \ 26 | --quant_type nf4 \ 27 | --fp16 \ 28 | --bits 4 \ 29 | --model_max_length 1024 \ 30 | --gradient_checkpointing \ 31 | --trust_remote_code True \ 32 | --use_auth_token True \ 33 | --do_train \ 34 | --do_eval \ 35 | --sample_generate \ 36 | --data_seed 42 \ 37 | --seed 0 38 | -------------------------------------------------------------------------------- /scripts/qlora_finetune/finetune_llama_7b_alpaca_zh.sh: -------------------------------------------------------------------------------- 1 | python train_qlora.py \ 2 | --model_name_or_path decapoda-research/llama-7b-hf \ 3 | --dataset_cfg ./data/alpaca_zh_pcyn.yaml \ 4 | --output_dir ./work_dir/alpaca_zh-baichuan-7b \ 5 | --num_train_epochs 3 \ 6 | --per_device_train_batch_size 4 \ 7 | --per_device_eval_batch_size 4 \ 8 | --gradient_accumulation_steps 8 \ 
9 | --evaluation_strategy steps \ 10 | --eval_steps 1000 \ 11 | --save_strategy steps \ 12 | --save_total_limit 10 \ 13 | --save_steps 1000 \ 14 | --logging_strategy steps \ 15 | --logging_steps 5 \ 16 | --learning_rate 0.0002 \ 17 | --warmup_ratio 0.03 \ 18 | --weight_decay 0.0 \ 19 | --lr_scheduler_type constant \ 20 | --adam_beta2 0.999 \ 21 | --max_grad_norm 0.3 \ 22 | --lora_r 64 \ 23 | --lora_alpha 16 \ 24 | --lora_dropout 0.1 \ 25 | --double_quant \ 26 | --quant_type nf4 \ 27 | --fp16 \ 28 | --bits 4 \ 29 | --model_max_length 1024 \ 30 | --gradient_checkpointing \ 31 | --trust_remote_code True \ 32 | --use_auth_token True \ 33 | --do_train \ 34 | --do_eval \ 35 | --sample_generate \ 36 | --data_seed 42 \ 37 | --seed 0 38 | -------------------------------------------------------------------------------- /scripts/qlora_finetune/finetune_baichuan_7b_alpaca_zh.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0 python train_qlora.py \ 2 | --model_name_or_path ~/checkpoints/baichuan7b \ 3 | --dataset_cfg ./data/alpaca_zh_pcyn.yaml \ 4 | --output_dir ./work_dir/alpaca_zh-baichuan-7b \ 5 | --num_train_epochs 3 \ 6 | --per_device_train_batch_size 4 \ 7 | --per_device_eval_batch_size 4 \ 8 | --gradient_accumulation_steps 8 \ 9 | --evaluation_strategy steps \ 10 | --eval_steps 1000 \ 11 | --save_strategy steps \ 12 | --save_total_limit 10 \ 13 | --save_steps 1000 \ 14 | --logging_strategy steps \ 15 | --logging_steps 5 \ 16 | --learning_rate 0.0002 \ 17 | --warmup_ratio 0.03 \ 18 | --weight_decay 0.0 \ 19 | --lr_scheduler_type constant \ 20 | --adam_beta2 0.999 \ 21 | --max_grad_norm 0.3 \ 22 | --lora_r 64 \ 23 | --lora_alpha 16 \ 24 | --lora_dropout 0.1 \ 25 | --double_quant \ 26 | --quant_type nf4 \ 27 | --fp16 \ 28 | --bits 4 \ 29 | --model_max_length 1024 \ 30 | --gradient_checkpointing \ 31 | --trust_remote_code True \ 32 | --use_auth_token True \ 33 | --do_train \ 34 | --do_eval \ 35 | --sample_generate \ 36 | --data_seed 42 \ 37 | --seed 0 38 | -------------------------------------------------------------------------------- /data/format_data/merge.py: -------------------------------------------------------------------------------- 1 | """Merge two conversation files into one. 
2 |
3 | Usage: python3 -m fastchat.data.merge --in file1.json file2.json --out merged.json
4 | """
5 |
6 | import argparse
7 | import json
8 |
9 | from datasets import load_dataset
10 |
11 |
12 | def json_load(in_file):
13 |     with open(in_file, 'r') as f:
14 |         json_data = json.load(f)
15 |     return json_data
16 |
17 |
18 | def json_dump(obj, path):
19 |     with open(path, 'w', encoding='utf-8') as f:
20 |         json.dump(obj, f, indent=2, ensure_ascii=False)
21 |
22 |
23 | def merge_datasets(in_file_list, out_file):
24 |
25 |     new_content = []
26 |     for in_file in in_file_list:
27 |         content = load_dataset('json', data_files=in_file)['train']
28 |
29 |         print(f'in-file: {in_file}, len: {len(content)}')
30 |         new_content.extend(content)
31 |
32 |     print(f'#out: {len(new_content)}')
33 |     print(f'Save new_content to {out_file}')
34 |     json_dump(new_content, out_file)
35 |
36 |
37 | if __name__ == '__main__':
38 |     parser = argparse.ArgumentParser()
39 |     parser.add_argument('--in-file', type=str, required=True, nargs='+')
40 |     parser.add_argument('--out-file', type=str, default='merged.json')
41 |     args = parser.parse_args()
42 |
43 |     merge_datasets(args.in_file, args.out_file)
44 |
--------------------------------------------------------------------------------
/llamatuner/train/tuner.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Dict, Optional
2 |
3 | from llamatuner.configs.parser import get_train_args
4 | from llamatuner.train.pt.train_pt import run_pt
5 | from llamatuner.train.sft.train_full import run_full_sft
6 | from llamatuner.train.sft.train_lora import run_lora_sft
7 |
8 |
9 | def run_exp(args: Optional[Dict[str, Any]] = None) -> None:
10 |     model_args, data_args, training_args, finetuning_args, generating_args = (
11 |         get_train_args(args))
12 |     if finetuning_args.stage == 'pt':
13 |         run_pt(
14 |             model_args,
15 |             data_args,
16 |             training_args,
17 |             finetuning_args,
18 |         )
19 |     elif finetuning_args.stage == 'full':
20 |         run_full_sft(
21 |             model_args,
22 |             data_args,
23 |             training_args,
24 |             finetuning_args,
25 |             generating_args,
26 |         )
27 |     elif finetuning_args.stage == 'sft':
28 |         run_lora_sft(
29 |             model_args,
30 |             data_args,
31 |             training_args,
32 |             finetuning_args,
33 |             generating_args,
34 |         )
35 |     else:
36 |         raise ValueError('Unknown task: {}.'.format(finetuning_args.stage))
37 |
38 |
39 | def launch():
40 |     run_exp()
41 |
42 |
43 | if __name__ == '__main__':
44 |     launch()
45 |
--------------------------------------------------------------------------------
/examples/deepspeed/ds_zero3_offload.json:
--------------------------------------------------------------------------------
1 | {
2 |     "bf16": {
3 |         "enabled": "auto"
4 |     },
5 |     "optimizer": {
6 |         "type": "AdamW",
7 |         "params": {
8 |             "lr": "auto",
9 |             "betas": "auto",
10 |             "eps": "auto",
11 |             "weight_decay": "auto"
12 |         }
13 |     },
14 |     "scheduler": {
15 |         "type": "WarmupDecayLR",
16 |         "params": {
17 |             "total_num_steps": "auto",
18 |             "warmup_min_lr": "auto",
19 |             "warmup_max_lr": "auto",
20 |             "warmup_num_steps": "auto"
21 |         }
22 |     },
23 |     "zero_optimization": {
24 |         "stage": 3,
25 |         "offload_optimizer": {
26 |             "device": "cpu",
27 |             "pin_memory": true
28 |         },
29 |         "offload_param": {
30 |             "device": "cpu",
31 |             "pin_memory": true
32 |         },
33 |         "overlap_comm": true,
34 |         "contiguous_gradients": true,
35 |         "sub_group_size": 1e9,
36 |         "reduce_bucket_size": "auto",
37 |         "stage3_prefetch_bucket_size": "auto",
38 |         "stage3_param_persistence_threshold": "auto",
39 |         "stage3_max_live_parameters":
1e9, 40 | "stage3_max_reuse_distance": 1e9, 41 | "stage3_gather_16bit_weights_on_model_save": false 42 | }, 43 | "gradient_accumulation_steps": "auto", 44 | "gradient_clipping": "auto", 45 | "steps_per_print": 5, 46 | "train_batch_size": "auto", 47 | "train_micro_batch_size_per_gpu": "auto", 48 | "wall_clock_breakdown": false 49 | } 50 | -------------------------------------------------------------------------------- /scripts/ds_config/default_offload_opt_param.json: -------------------------------------------------------------------------------- 1 | { 2 | "bf16": { 3 | "enabled": "auto" 4 | }, 5 | "optimizer": { 6 | "type": "AdamW", 7 | "params": { 8 | "lr": "auto", 9 | "betas": "auto", 10 | "eps": "auto", 11 | "weight_decay": "auto" 12 | } 13 | }, 14 | "scheduler": { 15 | "type": "WarmupDecayLR", 16 | "params": { 17 | "total_num_steps": "auto", 18 | "warmup_min_lr": "auto", 19 | "warmup_max_lr": "auto", 20 | "warmup_num_steps": "auto" 21 | } 22 | }, 23 | "zero_optimization": { 24 | "stage": 3, 25 | "offload_optimizer": { 26 | "device": "cpu", 27 | "pin_memory": true 28 | }, 29 | "offload_param": { 30 | "device": "cpu", 31 | "pin_memory": true 32 | }, 33 | "overlap_comm": true, 34 | "contiguous_gradients": true, 35 | "sub_group_size": 1e9, 36 | "reduce_bucket_size": "auto", 37 | "stage3_prefetch_bucket_size": "auto", 38 | "stage3_param_persistence_threshold": "auto", 39 | "stage3_max_live_parameters": 1e9, 40 | "stage3_max_reuse_distance": 1e9, 41 | "stage3_gather_16bit_weights_on_model_save": false 42 | }, 43 | "gradient_accumulation_steps": "auto", 44 | "gradient_clipping": "auto", 45 | "steps_per_print": 5, 46 | "train_batch_size": "auto", 47 | "train_micro_batch_size_per_gpu": "auto", 48 | "wall_clock_breakdown": false 49 | } 50 | -------------------------------------------------------------------------------- /llamatuner/utils/packages.py: -------------------------------------------------------------------------------- 1 | import importlib.metadata 2 | import importlib.util 3 | from typing import TYPE_CHECKING 4 | 5 | from packaging import version 6 | 7 | if TYPE_CHECKING: 8 | from packaging.version import Version 9 | 10 | 11 | def _is_package_available(name: str) -> bool: 12 | return importlib.util.find_spec(name) is not None 13 | 14 | 15 | def _get_package_version(name: str) -> 'Version': 16 | try: 17 | return version.parse(importlib.metadata.version(name)) 18 | except importlib.metadata.PackageNotFoundError: 19 | return version.parse('0.0.0') 20 | 21 | 22 | def is_fastapi_available(): 23 | return _is_package_available('fastapi') 24 | 25 | 26 | def is_gradio_available(): 27 | return _is_package_available('gradio') 28 | 29 | 30 | def is_jieba_available(): 31 | return _is_package_available('jieba') 32 | 33 | 34 | def is_matplotlib_available(): 35 | return _is_package_available('matplotlib') 36 | 37 | 38 | def is_nltk_available(): 39 | return _is_package_available('nltk') 40 | 41 | 42 | def is_pillow_available(): 43 | return _is_package_available('PIL') 44 | 45 | 46 | def is_requests_available(): 47 | return _is_package_available('requests') 48 | 49 | 50 | def is_rouge_available(): 51 | return _is_package_available('rouge_chinese') 52 | 53 | 54 | def is_starlette_available(): 55 | return _is_package_available('sse_starlette') 56 | 57 | 58 | def is_uvicorn_available(): 59 | return _is_package_available('uvicorn') 60 | 61 | 62 | def is_vllm_available(): 63 | return _is_package_available('vllm') 64 | -------------------------------------------------------------------------------- 
/llamatuner/configs/eval_args.py: -------------------------------------------------------------------------------- 1 | import os 2 | from dataclasses import dataclass, field 3 | from typing import Literal, Optional 4 | 5 | from datasets import DownloadMode 6 | 7 | 8 | @dataclass 9 | class EvaluationArguments: 10 | r""" 11 | Arguments pertaining to specify the evaluation parameters. 12 | """ 13 | 14 | task: str = field(metadata={'help': 'Name of the evaluation task.'}, ) 15 | task_dir: str = field( 16 | default='evaluation', 17 | metadata={ 18 | 'help': 'Path to the folder containing the evaluation datasets.' 19 | }, 20 | ) 21 | batch_size: int = field( 22 | default=4, 23 | metadata={'help': 'The batch size per GPU for evaluation.'}, 24 | ) 25 | seed: int = field( 26 | default=42, 27 | metadata={'help': 'Random seed to be used with data loaders.'}, 28 | ) 29 | lang: Literal['en', 'zh'] = field( 30 | default='en', 31 | metadata={'help': 'Language used at evaluation.'}, 32 | ) 33 | n_shot: int = field( 34 | default=5, 35 | metadata={'help': 'Number of examplars for few-shot learning.'}, 36 | ) 37 | save_dir: Optional[str] = field( 38 | default=None, 39 | metadata={'help': 'Path to save the evaluation results.'}, 40 | ) 41 | download_mode: DownloadMode = field( 42 | default=DownloadMode.REUSE_DATASET_IF_EXISTS, 43 | metadata={'help': 'Download mode used for the evaluation datasets.'}, 44 | ) 45 | 46 | def __post_init__(self): 47 | if self.save_dir is not None and os.path.exists(self.save_dir): 48 | raise ValueError('`save_dir` already exists, use another one.') 49 | -------------------------------------------------------------------------------- /llamatuner/data/dataset_factory/pt_dataset.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, List 2 | 3 | from torch.utils.data import Dataset 4 | from transformers.tokenization_utils import PreTrainedTokenizer 5 | 6 | from llamatuner.configs import DataArguments 7 | 8 | 9 | class PretrainDataset(Dataset): 10 | 11 | def __init__(self, examples: Dict[str, List[Any]], 12 | tokenizer: PreTrainedTokenizer, data_args: DataArguments): 13 | """ 14 | Initialize PretrainDataset with lazy loading. 
15 | 16 | Args: 17 | examples: Dictionary containing the dataset examples 18 | tokenizer: Tokenizer for text processing 19 | data_args: Data arguments containing configuration 20 | """ 21 | self.tokenizer = tokenizer 22 | self.data_args = data_args 23 | self.raw_examples = examples['_prompt'] 24 | self.eos_token = '<|end_of_text|>' if data_args.template == 'llama3' else tokenizer.eos_token 25 | 26 | def __len__(self) -> int: 27 | return len(self.raw_examples) 28 | 29 | def __getitem__(self, idx: int) -> Dict[str, List[int]]: 30 | # Process single example on-the-fly 31 | messages = self.raw_examples[idx] 32 | text = messages[0]['content'] + self.eos_token 33 | 34 | if self.data_args.template == 'gemma': 35 | text = self.tokenizer.bos_token + text 36 | 37 | processed = self.tokenizer(text, 38 | add_special_tokens=False, 39 | truncation=True, 40 | max_length=self.data_args.cutoff_len) 41 | 42 | return {k: processed[k] for k in processed.keys()} 43 | -------------------------------------------------------------------------------- /llamatuner/utils/env_utils.py: -------------------------------------------------------------------------------- 1 | import platform 2 | 3 | import accelerate 4 | import datasets 5 | import peft 6 | import torch 7 | import transformers 8 | import trl 9 | from transformers.utils import is_torch_cuda_available, is_torch_npu_available 10 | 11 | VERSION = '0.1.1.dev0' 12 | 13 | 14 | def print_env_info() -> None: 15 | info = { 16 | '`llamatuner` version': VERSION, 17 | 'Platform': platform.platform(), 18 | 'Python version': platform.python_version(), 19 | 'PyTorch version': torch.__version__, 20 | 'Transformers version': transformers.__version__, 21 | 'Datasets version': datasets.__version__, 22 | 'Accelerate version': accelerate.__version__, 23 | 'PEFT version': peft.__version__, 24 | 'TRL version': trl.__version__, 25 | } 26 | 27 | if is_torch_cuda_available(): 28 | info['PyTorch version'] += ' (GPU)' 29 | info['GPU type'] = torch.cuda.get_device_name() 30 | 31 | if is_torch_npu_available(): 32 | info['PyTorch version'] += ' (NPU)' 33 | info['NPU type'] = torch.npu.get_device_name() 34 | info['CANN version'] = torch.version.cann 35 | 36 | try: 37 | import deepspeed # type: ignore 38 | 39 | info['DeepSpeed version'] = deepspeed.__version__ 40 | except Exception: 41 | pass 42 | 43 | try: 44 | import bitsandbytes 45 | 46 | info['Bitsandbytes version'] = bitsandbytes.__version__ 47 | except Exception: 48 | pass 49 | 50 | try: 51 | import vllm 52 | 53 | info['vLLM version'] = vllm.__version__ 54 | except Exception: 55 | pass 56 | 57 | print('\n' + '\n'.join( 58 | ['- {}: {}'.format(key, value) for key, value in info.items()]) + '\n') 59 | -------------------------------------------------------------------------------- /examples/deepspeed/ds_zero3_auto.json: -------------------------------------------------------------------------------- 1 | { 2 | "fp16": { 3 | "enabled": "auto", 4 | "loss_scale": 0, 5 | "loss_scale_window": 1000, 6 | "initial_scale_power": 16, 7 | "hysteresis": 2, 8 | "min_loss_scale": 1 9 | }, 10 | "bf16": { 11 | "enabled": "auto" 12 | }, 13 | "optimizer": { 14 | "type": "AdamW", 15 | "params": { 16 | "lr": "auto", 17 | "betas": "auto", 18 | "eps": "auto", 19 | "weight_decay": "auto" 20 | } 21 | }, 22 | "scheduler": { 23 | "type": "WarmupDecayLR", 24 | "params": { 25 | "total_num_steps": "auto", 26 | "warmup_min_lr": "auto", 27 | "warmup_max_lr": "auto", 28 | "warmup_num_steps": "auto" 29 | } 30 | }, 31 | "zero_optimization": { 32 | "stage": 3, 33 | 
"offload_optimizer": { 34 | "device": "cpu", 35 | "pin_memory": true 36 | }, 37 | "offload_param": { 38 | "device": "cpu", 39 | "pin_memory": true 40 | }, 41 | "overlap_comm": true, 42 | "contiguous_gradients": true, 43 | "allgather_partitions": true, 44 | "allgather_bucket_size": 5e8, 45 | "sub_group_size": 1e9, 46 | "reduce_bucket_size": "auto", 47 | "stage3_prefetch_bucket_size": "auto", 48 | "stage3_param_persistence_threshold": "auto", 49 | "stage3_max_live_parameters": 1e9, 50 | "stage3_max_reuse_distance": 1e9, 51 | "stage3_gather_16bit_weights_on_model_save": true 52 | }, 53 | "train_batch_size": "auto", 54 | "train_micro_batch_size_per_gpu": "auto", 55 | "gradient_accumulation_steps": "auto", 56 | "gradient_clipping": "auto", 57 | "steps_per_print": 5, 58 | "wall_clock_breakdown": false 59 | } 60 | -------------------------------------------------------------------------------- /scripts/ds_config/ds_config_zero3_auto.json: -------------------------------------------------------------------------------- 1 | { 2 | "fp16": { 3 | "enabled": "auto", 4 | "loss_scale": 0, 5 | "loss_scale_window": 1000, 6 | "initial_scale_power": 16, 7 | "hysteresis": 2, 8 | "min_loss_scale": 1 9 | }, 10 | "bf16": { 11 | "enabled": "auto" 12 | }, 13 | "optimizer": { 14 | "type": "AdamW", 15 | "params": { 16 | "lr": "auto", 17 | "betas": "auto", 18 | "eps": "auto", 19 | "weight_decay": "auto" 20 | } 21 | }, 22 | "scheduler": { 23 | "type": "WarmupDecayLR", 24 | "params": { 25 | "total_num_steps": "auto", 26 | "warmup_min_lr": "auto", 27 | "warmup_max_lr": "auto", 28 | "warmup_num_steps": "auto" 29 | } 30 | }, 31 | "zero_optimization": { 32 | "stage": 3, 33 | "offload_optimizer": { 34 | "device": "cpu", 35 | "pin_memory": true 36 | }, 37 | "offload_param": { 38 | "device": "cpu", 39 | "pin_memory": true 40 | }, 41 | "overlap_comm": true, 42 | "contiguous_gradients": true, 43 | "allgather_partitions": true, 44 | "allgather_bucket_size": 5e8, 45 | "sub_group_size": 1e9, 46 | "reduce_bucket_size": "auto", 47 | "stage3_prefetch_bucket_size": "auto", 48 | "stage3_param_persistence_threshold": "auto", 49 | "stage3_max_live_parameters": 1e9, 50 | "stage3_max_reuse_distance": 1e9, 51 | "stage3_gather_16bit_weights_on_model_save": true 52 | }, 53 | "train_batch_size": "auto", 54 | "train_micro_batch_size_per_gpu": "auto", 55 | "gradient_accumulation_steps": "auto", 56 | "gradient_clipping": "auto", 57 | "steps_per_print": 5, 58 | "wall_clock_breakdown": false 59 | } 60 | -------------------------------------------------------------------------------- /llamatuner/utils/stream_server.py: -------------------------------------------------------------------------------- 1 | """Helpers to support streaming generate output. 
2 | 3 | Borrowed from https://github.com/oobabooga/text-generation-webui/blob/ad37f396fc8bcbab90e11ecf17c56c97bfbd4a9c/modules/callbacks.py 4 | """ 5 | import traceback 6 | from queue import Queue 7 | from threading import Thread 8 | 9 | import transformers 10 | 11 | 12 | class Stream(transformers.StoppingCriteria): 13 | 14 | def __init__(self, callback_func=None): 15 | self.callback_func = callback_func 16 | 17 | def __call__(self, input_ids, scores) -> bool: 18 | if self.callback_func is not None: 19 | self.callback_func(input_ids[0]) 20 | return False 21 | 22 | 23 | class Iteratorize: 24 | """Transforms a function that takes a callback into a lazy iterator 25 | (generator).""" 26 | 27 | def __init__(self, func, kwargs={}, callback=None): 28 | self.mfunc = func 29 | self.c_callback = callback 30 | self.q = Queue() 31 | self.sentinel = object() 32 | self.kwargs = kwargs 33 | self.stop_now = False 34 | 35 | def _callback(val): 36 | if self.stop_now: 37 | raise ValueError 38 | self.q.put(val) 39 | 40 | def gentask(): 41 | try: 42 | ret = self.mfunc(callback=_callback, **self.kwargs) 43 | except ValueError: 44 | pass 45 | except: 46 | traceback.print_exc() 47 | pass 48 | 49 | self.q.put(self.sentinel) 50 | if self.c_callback: 51 | self.c_callback(ret) 52 | 53 | self.thread = Thread(target=gentask) 54 | self.thread.start() 55 | 56 | def __iter__(self): 57 | return self 58 | 59 | def __next__(self): 60 | obj = self.q.get(True, None) 61 | if obj is self.sentinel: 62 | raise StopIteration 63 | else: 64 | return obj 65 | 66 | def __enter__(self): 67 | return self 68 | 69 | def __exit__(self, exc_type, exc_val, exc_tb): 70 | self.stop_now = True 71 | -------------------------------------------------------------------------------- /data/ultra_chat/ultra_chat.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from typing import List 4 | 5 | import datasets 6 | 7 | _HF_ENDPOINT = os.getenv('HF_ENDPOINT', 'https://huggingface.co') 8 | 9 | _DESCRIPTION = 'UltraChat: Large-scale, Informative, and Diverse Multi-round Dialogue Data.' 
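# NOTE: _split_generators below downloads shards train_0.jsonl through train_9.jsonl only (see range(10)).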
10 | 11 | _CITATION = """\ 12 | @misc{UltraChat, 13 | author = {Ding, Ning and Chen, Yulin and Xu, Bokai and Hu, Shengding and Qin, Yujia and Liu, Zhiyuan and Sun, Maosong and Zhou, Bowen}, 14 | title = {UltraChat: A Large-scale Auto-generated Multi-round Dialogue Data}, 15 | year = {2023}, 16 | publisher = {GitHub}, 17 | journal = {GitHub repository}, 18 | howpublished = {\\url{https://github.com/thunlp/ultrachat}}, 19 | } 20 | """ 21 | 22 | _HOMEPAGE = '{}/datasets/stingning/ultrachat'.format(_HF_ENDPOINT) 23 | _LICENSE = 'cc-by-nc-4.0' 24 | _BASE_DATA_URL = '{}/datasets/stingning/ultrachat/resolve/main/train_{{idx}}.jsonl'.format( 25 | _HF_ENDPOINT) 26 | 27 | 28 | class UltraChat(datasets.GeneratorBasedBuilder): 29 | VERSION = datasets.Version('0.0.0') 30 | 31 | def _info(self): 32 | features = datasets.Features({ 33 | 'conversations': [{ 34 | 'from': datasets.Value('string'), 35 | 'value': datasets.Value('string') 36 | }] 37 | }) 38 | return datasets.DatasetInfo(description=_DESCRIPTION, 39 | features=features, 40 | homepage=_HOMEPAGE, 41 | license=_LICENSE, 42 | citation=_CITATION) 43 | 44 | def _split_generators(self, dl_manager: datasets.DownloadManager): 45 | file_paths = [ 46 | dl_manager.download(_BASE_DATA_URL.format(idx=idx)) 47 | for idx in range(10) 48 | ] # multiple shards 49 | return [ 50 | datasets.SplitGenerator(name=datasets.Split.TRAIN, 51 | gen_kwargs={'filepaths': file_paths}) 52 | ] 53 | 54 | def _generate_examples(self, filepaths: List[str]): 55 | for filepath in filepaths: 56 | with open(filepath, 'r', encoding='utf-8') as f: 57 | for row in f: 58 | try: 59 | data = json.loads(row) 60 | except Exception: 61 | continue 62 | key: int = data['id'] 63 | content: List[str] = data['data'] 64 | if len(content) % 2 == 1: 65 | content.pop(-1) 66 | if len(content) < 2: 67 | continue 68 | conversations = [{ 69 | 'from': 'human' if i % 2 == 0 else 'gpt', 70 | 'value': content[i] 71 | } for i in range(len(content))] 72 | yield key, {'conversations': conversations} 73 | -------------------------------------------------------------------------------- /server/multi_chat.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from threading import Thread 3 | 4 | import torch 5 | import transformers 6 | from transformers import (AutoModelForCausalLM, AutoTokenizer, 7 | TextIteratorStreamer) 8 | 9 | sys.path.append('../') 10 | from llamatuner.configs import GenerationArguments, ModelInferenceArguments 11 | from llamatuner.utils.model_utils import get_logits_processor 12 | 13 | 14 | def main(model_server_args, generation_args): 15 | """多轮对话,不具有对话历史的记忆功能.""" 16 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 17 | model = AutoModelForCausalLM.from_pretrained( 18 | model_server_args.model_name_or_path, 19 | cache_dir=model_server_args.cache_dir, 20 | trust_remote_code=True, 21 | low_cpu_mem_usage=True, 22 | torch_dtype=torch.float16, 23 | device_map='auto').to(device).eval() 24 | tokenizer = AutoTokenizer.from_pretrained( 25 | model_server_args.model_name_or_path, 26 | trust_remote_code=True, 27 | use_fast=False, 28 | ) 29 | # 记录所有历史记录 30 | historys = tokenizer.bos_token 31 | print('User: ', end='', flush=True) 32 | user_input = input('') 33 | while True: 34 | user_input = '{}'.format(user_input).strip() 35 | historys = historys + user_input 36 | inputs = tokenizer(historys, 37 | return_tensors='pt', 38 | add_special_tokens=False) 39 | inputs = {k: v.to(model.device) for k, v in inputs.items()} 40 | 41 | # Create a 
TextIteratorStreamer object to stream the response from the model 42 | streamer = TextIteratorStreamer(tokenizer, 43 | timeout=60.0, 44 | skip_prompt=True, 45 | skip_special_tokens=True) 46 | 47 | # Set the arguments for the model's generate() method 48 | gen_kwargs = dict( 49 | inputs, 50 | streamer=streamer, 51 | logits_processor=get_logits_processor(), 52 | **generation_args.to_dict(), 53 | ) 54 | 55 | # Start a separate thread to generate the response asynchronously 56 | thread = Thread(target=model.generate, kwargs=gen_kwargs) 57 | thread.start() 58 | 59 | # Print the model name and the response as it is generated 60 | print('Assistant: ', end='', flush=True) 61 | response = '' 62 | for new_text in streamer: 63 | print(new_text, end='', flush=True) 64 | response += new_text 65 | 66 | historys = historys + response 67 | print('\n') 68 | print('User: ', end='', flush=True) 69 | user_input = input('') 70 | 71 | 72 | if __name__ == '__main__': 73 | parser = transformers.HfArgumentParser( 74 | (ModelInferenceArguments, GenerationArguments)) 75 | model_server_args, generation_args = parser.parse_args_into_dataclasses() 76 | main(model_server_args, generation_args) 77 | -------------------------------------------------------------------------------- /data/format_data/convert_oasst1.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | import random 5 | 6 | 7 | def json_dump(obj, path): 8 | with open(path, 'w', encoding='utf-8') as f: 9 | json.dump(obj, f, indent=2, ensure_ascii=False) 10 | 11 | 12 | def json_load(in_file): 13 | with open(in_file, 'r') as f: 14 | json_data = json.load(f) 15 | return json_data 16 | 17 | 18 | def convert_oasst1_data(data_dir, output_dir): 19 | """For OASST1, because it's in a tree structure, where every user input 20 | might get multiple replies, we have to save every path from the root node 21 | to the assistant reply (including both leaf node and intemediate node). 22 | 23 | This results in some of the messages being duplicated among different paths 24 | (instances). Be careful when using this dataset for training. Ideally, you 25 | should only minimize the loss of the last message in each path. 
26 | """ 27 | conversations = [] 28 | with open(os.path.join(data_dir, '2023-04-12_oasst_ready.trees.jsonl'), 29 | 'r') as fin: 30 | for line in fin: 31 | conversations.append(json.loads(line)) 32 | 33 | output_path = os.path.join(output_dir, 'oasst1_data.jsonl') 34 | 35 | # tranvers the conversation tree, and collect all valid sequences 36 | def dfs(reply, messages, valid_sequences): 37 | if reply['role'] == 'assistant': 38 | messages.append({'role': 'assistant', 'content': reply['text']}) 39 | valid_sequences.append(messages[:]) 40 | for child in reply['replies']: 41 | dfs(child, messages, valid_sequences) 42 | messages.pop() 43 | elif reply['role'] == 'prompter': 44 | messages.append({'role': 'user', 'content': reply['text']}) 45 | for child in reply['replies']: 46 | dfs(child, messages, valid_sequences) 47 | messages.pop() 48 | else: 49 | raise ValueError(f"Unknown role: {reply['role']}") 50 | 51 | with open(output_path, 'w') as fout: 52 | example_cnt = 0 53 | for _, conversation in enumerate(conversations): 54 | valid_sequences = [] 55 | dfs(conversation['prompt'], [], valid_sequences) 56 | for sequence in valid_sequences: 57 | fout.write( 58 | json.dumps({ 59 | 'dataset': 'oasst1', 60 | 'id': f'oasst1_{example_cnt}', 61 | 'messages': sequence 62 | }) + '\n') 63 | example_cnt += 1 64 | 65 | 66 | if __name__ == '__main__': 67 | arg_parser = argparse.ArgumentParser() 68 | arg_parser.add_argument('--raw_data_dir', 69 | type=str, 70 | default='data/downloads') 71 | arg_parser.add_argument('--output_dir', type=str, default='data/processed') 72 | arg_parser.add_argument('--seed', type=int, default=42) 73 | args = arg_parser.parse_args() 74 | random.seed(args.seed) 75 | 76 | convert_oasst1_data(data_dir=args.raw_data_dir, output_dir=args.output_dir) 77 | -------------------------------------------------------------------------------- /llamatuner/cli.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import subprocess 4 | import sys 5 | from enum import Enum, unique 6 | 7 | sys.path.append(os.getcwd()) 8 | 9 | from llamatuner import launcher 10 | from llamatuner.train.tuner import run_exp 11 | from llamatuner.utils.env_utils import VERSION, print_env_info 12 | from llamatuner.utils.logger_utils import get_logger 13 | from llamatuner.utils.misc import get_device_count 14 | 15 | USAGE = ( 16 | '-' * 70 + '\n' + 17 | '| Usage: |\n' + 18 | '| llamatuner-cli train -h: train models |\n' + 19 | '-' * 70) 20 | 21 | WELCOME = ('-' * 58 + '\n' + 22 | '| Welcome to LLamaTuner, version {}'.format(VERSION) + ' ' * 23 | (21 - len(VERSION)) + '|\n|' + ' ' * 56 + '|\n' + 24 | '| Project page: https://github.com/hiyouga/LLaMA-Factory |\n' + 25 | '-' * 58) 26 | 27 | logger = get_logger(__name__) 28 | 29 | 30 | @unique 31 | class Command(str, Enum): 32 | API = 'api' 33 | CHAT = 'chat' 34 | ENV = 'env' 35 | EVAL = 'eval' 36 | EXPORT = 'export' 37 | TRAIN = 'train' 38 | WEBDEMO = 'webchat' 39 | WEBUI = 'webui' 40 | VER = 'version' 41 | HELP = 'help' 42 | 43 | 44 | def main() -> None: 45 | command = sys.argv.pop(1) if len(sys.argv) != 1 else Command.HELP 46 | if command == Command.ENV: 47 | print_env_info() 48 | elif command == Command.TRAIN: 49 | force_torchrun = os.environ.get('FORCE_TORCHRUN', 50 | '0').lower() in ['true', '1'] 51 | if force_torchrun or get_device_count() > 1: 52 | master_addr = os.environ.get('MASTER_ADDR', '127.0.0.1') 53 | master_port = os.environ.get('MASTER_PORT', 54 | str(random.randint(20001, 29999))) 55 | 
logger.info('Initializing distributed tasks at: {}:{}'.format( 56 | master_addr, master_port)) 57 | process = subprocess.run( 58 | ('torchrun --nnodes {nnodes} --node_rank {node_rank} --nproc_per_node {nproc_per_node} ' 59 | '--master_addr {master_addr} --master_port {master_port} {file_name} {args}' 60 | ).format( 61 | nnodes=os.environ.get('NNODES', '1'), 62 | node_rank=os.environ.get('RANK', '0'), 63 | nproc_per_node=os.environ.get('NPROC_PER_NODE', 64 | str(get_device_count())), 65 | master_addr=master_addr, 66 | master_port=master_port, 67 | file_name=launcher.__file__, 68 | args=' '.join(sys.argv[1:]), 69 | ), 70 | shell=True, 71 | ) 72 | sys.exit(process.returncode) 73 | else: 74 | run_exp() 75 | elif command == Command.VER: 76 | print(WELCOME) 77 | elif command == Command.HELP: 78 | print(USAGE) 79 | else: 80 | raise NotImplementedError('Unknown command: {}'.format(command)) 81 | 82 | 83 | if __name__ == '__main__': 84 | main() 85 | -------------------------------------------------------------------------------- /llamatuner/train/apply_lora.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from typing import Tuple 3 | 4 | import torch 5 | from peft import PeftModel 6 | from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedModel 7 | 8 | from llamatuner.utils.logger_utils import get_logger 9 | 10 | logger = get_logger(__name__) 11 | 12 | 13 | def apply_lora( 14 | base_model_path: str, 15 | lora_model_path: str, 16 | target_model_path: str = None, 17 | cache_dir: str = None, 18 | trust_remote_code: bool = True, 19 | ) -> Tuple[AutoModelForCausalLM, AutoTokenizer]: 20 | """Applies the LoRA adapter to a base model and saves the resulting target 21 | model (optional). 22 | 23 | Args: 24 | base_model_path (str): The path to the base model to which the LoRA adapter will be applied. 25 | lora_model_path (str): The path to the LoRA adapter. 26 | target_model_path (str): The path where the target model will be saved (if `save_target_model=True`). 27 | cache_dir (str): The path to the cache directory. 28 | trust_remote_code (bool): Whether to trust remote code when downloading the model. 29 | 30 | Returns: 31 | Tuple[AutoModelForCausalLM, AutoTokenizer]: A tuple containing the target model and its tokenizer. 32 | """ 33 | # Load the base model and tokenizer 34 | logger.info(f'Loading the base model from {base_model_path}') 35 | # Set configuration kwargs for tokenizer. 
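    # (The same kwargs are reused when loading both the base model and the tokenizer below.)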
36 | config_kwargs = { 37 | 'cache_dir': cache_dir, 38 | 'trust_remote_code': trust_remote_code, 39 | } 40 | 41 | base_model: PreTrainedModel = AutoModelForCausalLM.from_pretrained( 42 | base_model_path, 43 | device_map='auto', 44 | torch_dtype=torch.float16, 45 | low_cpu_mem_usage=True, 46 | **config_kwargs, 47 | ) 48 | 49 | # Load the tokenizer 50 | logger.info(f'Loading the tokenizer from {base_model_path}') 51 | # Due to the name of Transformers' LlamaTokenizer, we have to do this 52 | tokenizer = AutoTokenizer.from_pretrained( 53 | base_model_path, 54 | use_fast=False, 55 | **config_kwargs, 56 | ) 57 | 58 | # Load the LoRA adapter 59 | logger.info(f'Loading the LoRA adapter from {lora_model_path}') 60 | model: PreTrainedModel = PeftModel.from_pretrained(base_model, 61 | lora_model_path) 62 | logger.info('Applying the LoRA to base model') 63 | model = model.merge_and_unload() 64 | 65 | if target_model_path is not None: 66 | logger.info(f'Saving the target model to {target_model_path}') 67 | model.save_pretrained(target_model_path) 68 | tokenizer.save_pretrained(target_model_path) 69 | 70 | return model, tokenizer 71 | 72 | 73 | if __name__ == '__main__': 74 | parser = argparse.ArgumentParser() 75 | parser.add_argument('--base-model-path', type=str, required=True) 76 | parser.add_argument('--target-model-path', type=str, default=None) 77 | parser.add_argument('--lora-model-path', type=str, required=True) 78 | args = parser.parse_args() 79 | 80 | apply_lora( 81 | base_model_path=args.base_model_path, 82 | lora_model_path=args.lora_model_path, 83 | target_model_path=args.target_model_path, 84 | ) 85 | -------------------------------------------------------------------------------- /llamatuner/data/dataset_factory/reward_dataset.py: -------------------------------------------------------------------------------- 1 | from typing import Dict 2 | 3 | import torch 4 | from datasets import load_dataset 5 | from torch.utils.data import Dataset 6 | from transformers import PreTrainedTokenizer 7 | 8 | 9 | class PairwiseDataset(Dataset): 10 | """Dataset class for pairwise ranking tasks. 11 | 12 | Args: 13 | data_path: Path to the dataset. 14 | tokenizer: The tokenizer used to encode the input text. 15 | max_length: Maximum sequence length for the encoded inputs. 
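        split: Dataset split to load, e.g. 'train' or 'test'.

    Example (illustrative sketch; the dataset path is hypothetical and must
    provide 'prompt', 'chosen' and 'rejected' columns):
        >>> dataset = PairwiseDataset('path/to/pairwise_data', tokenizer,
        ...                           split='train', max_length=512)
        >>> item = dataset[0]  # dict of chosen_/rejected_ input_ids, attention_mask and labels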
16 | """ 17 | 18 | def __init__(self, data_path: str, tokenizer: PreTrainedTokenizer, 19 | split: str, max_length: int): 20 | 21 | self.pairs = self.create_comparison_dataset(data_path, split) 22 | self.tokenizer = tokenizer 23 | self.max_length = max_length 24 | 25 | def __len__(self) -> int: 26 | return len(self.pairs) 27 | 28 | def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]: 29 | if idx < 0 or idx >= len(self.pairs): 30 | raise IndexError( 31 | f'Index {idx} out of range for TLDRDataset with length {len(self)}' 32 | ) 33 | pair = self.pairs[idx] 34 | chosen_example, rejected_example = pair['chosen'], pair['rejected'] 35 | 36 | chosen_encodings_dict = self.tokenizer(chosen_example, 37 | truncation=True, 38 | max_length=self.max_length, 39 | padding='max_length') 40 | rejected_encodings_dict = self.tokenizer(rejected_example, 41 | truncation=True, 42 | max_length=self.max_length, 43 | padding='max_length') 44 | encodings_input = {} 45 | encodings_input['chosen_input_ids'] = chosen_encodings_dict[ 46 | 'input_ids'] 47 | encodings_input['chosen_attention_mask'] = chosen_encodings_dict[ 48 | 'attention_mask'] 49 | encodings_input['rejected_input_ids'] = rejected_encodings_dict[ 50 | 'input_ids'] 51 | encodings_input['rejected_attention_mask'] = rejected_encodings_dict[ 52 | 'attention_mask'] 53 | encodings_input['labels'] = 1.0 54 | 55 | encodings_input = { 56 | key: torch.tensor(val) 57 | for key, val in encodings_input.items() 58 | } 59 | 60 | return encodings_input 61 | 62 | def create_comparison_dataset(self, path: str, split: str = 'train'): 63 | dataset = load_dataset(path, split=split) 64 | pairs = [] 65 | for prompt, chosen_summary, rejected_summary in zip( 66 | dataset['prompt'], dataset['chosen'], dataset['rejected']): 67 | pair = {} 68 | if chosen_summary == rejected_summary: 69 | continue 70 | if len(chosen_summary.split()) < 5 or len( 71 | rejected_summary.split()) < 5: 72 | continue 73 | 74 | pair[ 75 | 'chosen'] = '<|startoftext|>' + prompt + '\n' + chosen_summary + '<|endoftext|>' 76 | pair[ 77 | 'rejected'] = '<|startoftext|>' + prompt + '\n' + rejected_summary + '<|endoftext|>' 78 | pairs.append(pair) 79 | 80 | return pairs 81 | -------------------------------------------------------------------------------- /llamatuner/data/data_collator.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Any, Dict, Sequence 3 | 4 | import torch 5 | from transformers import DataCollatorForSeq2Seq 6 | 7 | 8 | @dataclass 9 | class PairwiseDataCollatorWithPadding(DataCollatorForSeq2Seq): 10 | r""" 11 | Data collator for pairwise data. 12 | """ 13 | 14 | def __call__( 15 | self, features: Sequence[Dict[str, 16 | Any]]) -> Dict[str, 'torch.Tensor']: 17 | r""" 18 | Pads batched data to the longest sequence in the batch. 19 | 20 | We generate 2 * n examples where the first n examples represent chosen examples and 21 | the last n examples represent rejected examples. 
22 | """ 23 | concatenated_features = [] 24 | for key in ('chosen', 'rejected'): 25 | for feature in features: 26 | target_feature = { 27 | 'input_ids': feature['{}_input_ids'.format(key)], 28 | 'attention_mask': feature['{}_attention_mask'.format(key)], 29 | 'labels': feature['{}_labels'.format(key)], 30 | } 31 | if 'pixel_values' in feature: 32 | target_feature['pixel_values'] = feature['pixel_values'] 33 | 34 | if '{}_token_type_ids'.format(key) in feature: 35 | target_feature['token_type_ids'] = feature[ 36 | '{}_token_type_ids'.format(key)] 37 | 38 | concatenated_features.append(target_feature) 39 | 40 | return super().__call__(concatenated_features) 41 | 42 | 43 | @dataclass 44 | class KTODataCollatorWithPadding(DataCollatorForSeq2Seq): 45 | r""" 46 | Data collator for KTO data. 47 | """ 48 | 49 | def __call__( 50 | self, features: Sequence[Dict[str, 51 | Any]]) -> Dict[str, 'torch.Tensor']: 52 | target_features = [] 53 | kl_features = [] 54 | kto_tags = [] 55 | for feature in features: 56 | target_feature = { 57 | 'input_ids': feature['input_ids'], 58 | 'attention_mask': feature['attention_mask'], 59 | 'labels': feature['labels'], 60 | } 61 | kl_feature = { 62 | 'input_ids': feature['kl_input_ids'], 63 | 'attention_mask': feature['kl_attention_mask'], 64 | 'labels': feature['kl_labels'], 65 | } 66 | if 'pixel_values' in feature: 67 | target_feature['pixel_values'] = feature['pixel_values'] 68 | 69 | if 'token_type_ids' in feature: 70 | target_feature['token_type_ids'] = feature['token_type_ids'] 71 | kl_feature['token_type_ids'] = feature['kl_token_type_ids'] 72 | 73 | target_features.append(target_feature) 74 | kl_features.append(kl_feature) 75 | kto_tags.append(feature['kto_tags']) 76 | 77 | batch = super().__call__(target_features) 78 | kl_batch = super().__call__(kl_features) 79 | batch['kl_input_ids'] = kl_batch['input_ids'] 80 | batch['kl_attention_mask'] = kl_batch['attention_mask'] 81 | batch['kl_labels'] = kl_batch['labels'] 82 | if 'token_type_ids' in batch: 83 | batch['kl_token_type_ids'] = kl_batch['token_type_ids'] 84 | 85 | batch['kto_tags'] = torch.tensor(kto_tags) 86 | return batch 87 | -------------------------------------------------------------------------------- /data/format_data/convert_vicuna.py: -------------------------------------------------------------------------------- 1 | import json 2 | import sys 3 | 4 | from datasets import load_dataset 5 | 6 | sys.path.append('../../') 7 | 8 | from llamatuner.data.data_utils import extract_default_prompt_dataset 9 | 10 | 11 | def json_dump(obj, path): 12 | with open(path, 'w', encoding='utf-8') as f: 13 | json.dump(obj, f, indent=2, ensure_ascii=False) 14 | 15 | 16 | def json_load(in_file): 17 | with open(in_file, 'r') as f: 18 | json_data = json.load(f) 19 | return json_data 20 | 21 | 22 | def valid_keys(keys): 23 | for k in ['input', 'output']: 24 | if k not in keys: 25 | return False 26 | return True 27 | 28 | 29 | def remove_unused_columns(dataset): 30 | """Remove columns not named 'input' or 'output'.""" 31 | dataset = dataset.remove_columns([ 32 | col for col in dataset.column_names if col not in ['input', 'output'] 33 | ]) 34 | return dataset 35 | 36 | 37 | def convert_alpaca_vicuna(in_file: str, out_file: str = None): 38 | raw_dataset = load_dataset('json', data_files=in_file)['train'] 39 | raw_dataset = raw_dataset.map(extract_default_prompt_dataset) 40 | 41 | collect_data = [] 42 | for i, content in enumerate(raw_dataset): 43 | prompt = content['input'] 44 | response = content['output'] 45 | 46 | 
collect_data.append({ 47 | 'id': 48 | f'alpaca_{i}', 49 | 'conversations': [ 50 | { 51 | 'from': 'human', 52 | 'value': prompt 53 | }, 54 | { 55 | 'from': 'gpt', 56 | 'value': response 57 | }, 58 | ], 59 | }) 60 | print(f'Original: {len(raw_dataset)}, Converted: {len(collect_data)}') 61 | json_dump(collect_data, out_file) 62 | return collect_data 63 | 64 | 65 | if __name__ == '__main__': 66 | in_file = '/home/robin/prompt_data/100PoisonMpts/train_alpaca.json' 67 | out_file = '/home/robin/prompt_data/100PoisonMpts/train_vicuna.json' 68 | collect_data = convert_alpaca_vicuna(in_file, out_file) 69 | 70 | data_path = '/home/robin/prompt_data/CValues-Comparison/test_alpaca.json' 71 | out_path = '/home/robin/prompt_data/CValues-Comparison/test_vicuna.json' 72 | convert_alpaca_vicuna(data_path, out_file=out_path) 73 | 74 | data_path = '/home/robin/prompt_data/CValues-Comparison/train_alpaca.json' 75 | out_path = '/home/robin/prompt_data/CValues-Comparison/train_vicuna.json' 76 | convert_alpaca_vicuna(data_path, out_file=out_path) 77 | 78 | data_path = '/home/robin/prompt_data/HuatuoGPT-sft-data-v1/HuatuoGPT_alpaca.json' 79 | out_path = '/home/robin/prompt_data/HuatuoGPT-sft-data-v1/HuatuoGPT_vicnua.json' 80 | convert_alpaca_vicuna(data_path, out_file=out_path) 81 | 82 | data_path = '/home/robin/prompt_data/Safety-Prompts/attack_scenarios_alpaca.json' 83 | out_path = '/home/robin/prompt_data/Safety-Prompts/attack_scenarios_vicuna.json' 84 | convert_alpaca_vicuna(data_path, out_file=out_path) 85 | 86 | data_path = '/home/robin/prompt_data/Safety-Prompts/safety_scenarios_alpaca.json' 87 | out_path = '/home/robin/prompt_data/Safety-Prompts/safety_scenarios_vicuna.json' 88 | convert_alpaca_vicuna(data_path, out_file=out_path) 89 | 90 | data_path = '/home/robin/prompt_data/COIG/train_alpaca.json' 91 | out_path = '/home/robin/prompt_data/COIG/train_vicuna.json' 92 | convert_alpaca_vicuna(data_path, out_file=out_path) 93 | -------------------------------------------------------------------------------- /llamatuner/configs/generating_args.py: -------------------------------------------------------------------------------- 1 | from dataclasses import asdict, dataclass, field 2 | from typing import Any, Dict, Optional 3 | 4 | from transformers import GenerationConfig 5 | 6 | 7 | @dataclass 8 | class GeneratingArguments: 9 | """Arguments pertaining to specify the model generation parameters.""" 10 | 11 | # Generation strategy 12 | # 是否采样 13 | do_sample: Optional[bool] = field( 14 | default=True, 15 | metadata={ 16 | 'help': 17 | 'Whether or not to use sampling, use greedy decoding otherwise.' 18 | }, 19 | ) 20 | # Hyperparameters for logit manipulation 21 | # softmax 函数的温度因子,来调节输出token的分布 22 | temperature: Optional[float] = field( 23 | default=1.0, 24 | metadata={ 25 | 'help': 'The value used to modulate the next token probabilities.' 26 | }, 27 | ) 28 | # 核采样参数,top_p最高的前n个(n是变化)概率和为p,从这些n个候选token中随机采样 29 | top_p: Optional[float] = field( 30 | default=1.0, 31 | metadata={ 32 | 'help': 33 | 'The smallest set of most probable tokens with probabilities that add up to top_p or higher are kept.' 34 | }, 35 | ) 36 | # top_k随机搜索中的k个最高概率选择 37 | top_k: Optional[int] = field( 38 | default=50, 39 | metadata={ 40 | 'help': 41 | 'The number of highest probability vocabulary tokens to keep for top-k filtering.' 42 | }, 43 | ) 44 | # 集束搜索的数量 45 | num_beams: Optional[int] = field( 46 | default=1, 47 | metadata={ 48 | 'help': 'Number of beams for beam search. 1 means no beam search.' 
49 | }, 50 | ) 51 | # 最大的token数量,会被 max_new_tokens 覆盖 52 | max_length: Optional[int] = field( 53 | default=1024, 54 | metadata={ 55 | 'help': 56 | 'The maximum length the generated tokens can have. It can be overridden by max_new_tokens.' 57 | }, 58 | ) 59 | # 最大的新生成的token数量 60 | max_new_tokens: Optional[int] = field( 61 | default=1024, 62 | metadata={ 63 | 'help': 64 | 'Maximum number of new tokens to be generated in evaluation or prediction loops' 65 | 'if predict_with_generate is set.' 66 | }, 67 | ) 68 | # 重复性惩罚因子 69 | repetition_penalty: Optional[float] = field( 70 | default=1.0, 71 | metadata={ 72 | 'help': 73 | 'The parameter for repetition penalty. 1.0 means no penalty.' 74 | }) 75 | # 长度惩罚因子 76 | length_penalty: Optional[float] = field( 77 | default=1.0, 78 | metadata={ 79 | 'help': 80 | 'Exponential penalty to the length that is used with beam-based generation.' 81 | }) 82 | default_system: Optional[str] = field( 83 | default=None, 84 | metadata={'help': 'Default system message to use in chat completion.'}, 85 | ) 86 | skip_special_tokens: bool = field( 87 | default=True, 88 | metadata={ 89 | 'help': 'Whether or not to remove special tokens in the decoding.' 90 | }, 91 | ) 92 | 93 | def to_dict(self, obey_generation_config: bool = False) -> Dict[str, Any]: 94 | args = asdict(self) 95 | if args.get('max_new_tokens', -1) > 0: 96 | args.pop('max_length', None) 97 | else: 98 | args.pop('max_new_tokens', None) 99 | 100 | if obey_generation_config: 101 | generation_config = GenerationConfig() 102 | for key in list(args.keys()): 103 | if not hasattr(generation_config, key): 104 | args.pop(key) 105 | 106 | return args 107 | -------------------------------------------------------------------------------- /data/dataset_info.yaml: -------------------------------------------------------------------------------- 1 | # The dataset_info.yaml file contains the information of the datasets used in the experiments. 
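# Each entry maps a dataset name to its source (file_name / hf_hub_url / script_url) and its
# format (alpaca or sharegpt), plus optional column mappings. A minimal sketch (the name and
# file below are hypothetical):
#
#   my_dataset:
#     file_name: my_dataset.json
#     formatting: alpaca
#     columns:
#       prompt: instruction
#       response: output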
2 | ## pretrain dataset 3 | c4_demo: 4 | file_name: c4_demo.json 5 | columns: 6 | prompt: text 7 | 8 | open-web-math: 9 | hf_hub_url: open-web-math/open-web-math 10 | split: train[10:20] 11 | formatting: alpaca 12 | columns: 13 | prompt: text 14 | 15 | ## sft dataset 16 | alpaca: 17 | hf_hub_url: tatsu-lab/alpaca 18 | formatting: alpaca 19 | 20 | alpaca-clean: 21 | hf_hub_url: yahma/alpaca-cleaned 22 | formatting: alpaca 23 | 24 | dolly-15k: 25 | hf_hub_url: databricks/databricks-dolly-15k 26 | formatting: alpaca 27 | 28 | guanaco: 29 | hf_hub_url: JosephusCheung/GuanacoDataset 30 | ms_hub_url: AI-ModelScope/GuanacoDataset 31 | formatting: alpaca 32 | 33 | openassistant-guanaco: 34 | hf_hub_url: timdettmers/openassistant-guanaco 35 | formatting: alpaca 36 | 37 | belle_0.5m: 38 | hf_hub_url: BelleGroup/train_0.5M_CN 39 | ms_hub_url: AI-ModelScope/train_0.5M_CN 40 | formatting: alpaca 41 | 42 | belle_1m: 43 | hf_hub_url: BelleGroup/train_1M_CN 44 | ms_hub_url: AI-ModelScope/train_1M_CN 45 | formatting: alpaca 46 | 47 | belle_2m: 48 | hf_hub_url: BelleGroup/train_2M_CN 49 | ms_hub_url: AI-ModelScope/train_2M_CN 50 | formatting: alpaca 51 | 52 | belle_dialog: 53 | hf_hub_url: BelleGroup/generated_chat_0.4M 54 | ms_hub_url: AI-ModelScope/generated_chat_0.4M 55 | formatting: alpaca 56 | 57 | belle_math: 58 | hf_hub_url: BelleGroup/school_math_0.25M 59 | ms_hub_url: AI-ModelScope/school_math_0.25M 60 | formatting: alpaca 61 | 62 | belle_multiturn: 63 | hf_hub_url: BelleGroup/multi_turn_0.5M 64 | formatting: sharegpt 65 | columns: 66 | prompt: instruction 67 | response: output 68 | history: history 69 | 70 | firefly: 71 | hf_hub_url: YeungNLP/firefly-train-1.1M 72 | formatting: alpaca 73 | columns: 74 | prompt: input 75 | response: target 76 | 77 | codealpaca: 78 | hf_hub_url: sahil2801/CodeAlpaca-20k 79 | ms_hub_url: AI-ModelScope/CodeAlpaca-20k 80 | formatting: alpaca 81 | 82 | alpaca_cot: 83 | hf_hub_url: QingyiSi/Alpaca-CoT 84 | ms_hub_url: AI-ModelScope/Alpaca-CoT 85 | 86 | webqa: 87 | hf_hub_url: suolyer/webqa 88 | ms_hub_url: AI-ModelScope/webqa 89 | formatting: alpaca 90 | columns: 91 | prompt: input 92 | response: output 93 | 94 | # mutli-turn datasets 95 | evol_instruct: 96 | hf_hub_url: MaziyarPanahi/WizardLM_evol_instruct_V2_196k 97 | ms_hub_url: AI-ModelScope/WizardLM_evol_instruct_V2_196k 98 | formatting: sharegpt 99 | 100 | ultrachat_200k: 101 | hf_hub_url: HuggingFaceH4/ultrachat_200k 102 | ms_hub_url: AI-ModelScope/ultrachat_200k 103 | formatting: sharegpt 104 | columns: 105 | messages: messages 106 | tags: 107 | role_tag: role 108 | content_tag: content 109 | user_tag: user 110 | assistant_tag: assistant 111 | 112 | lmsys_chat: 113 | hf_hub_url: lmsys/lmsys-chat-1m 114 | ms_hub_url: AI-ModelScope/lmsys-chat-1m 115 | formatting: sharegpt 116 | columns: 117 | messages: conversation 118 | tags: 119 | role_tag: role 120 | content_tag: content 121 | user_tag: human 122 | assistant_tag: assistant 123 | 124 | hh_rlhf_en: 125 | script_url: hh_rlhf_en 126 | ranking: true 127 | columns: 128 | prompt: instruction 129 | chosen: chosen 130 | rejected: rejected 131 | history: history 132 | 133 | orca_pairs: 134 | hf_hub_url: Intel/orca_dpo_pairs 135 | ranking: true 136 | columns: 137 | prompt: question 138 | chosen: chosen 139 | rejected: rejected 140 | system: system 141 | 142 | kto_mix_en: 143 | hf_hub_url: argilla/kto-mix-15k 144 | formatting: sharegpt 145 | columns: 146 | messages: completion 147 | kto_tag: label 148 | tags: 149 | role_tag: role 150 | content_tag: content 151 | user_tag: 
user 152 | assistant_tag: assistant 153 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 160 | #.idea/ 161 | *.json 162 | wandb/ 163 | hf_hub 164 | logs/ 165 | docs/ 166 | checkpoints/ 167 | work_dir/ 168 | work_dirs/ 169 | output/ 170 | outputs/ 171 | *.jso 172 | -------------------------------------------------------------------------------- /llamatuner/model/callbacks/wandb_callback.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import torch 3 | import transformers 4 | from transformers.integrations import WandbCallback 5 | 6 | 7 | def decode_predictions(tokenizer, predictions): 8 | labels = tokenizer.batch_decode(predictions.label_ids) 9 | logits = predictions.predictions.argmax(axis=-1) 10 | prediction_text = tokenizer.batch_decode(logits) 11 | return {'labels': labels, 'predictions': prediction_text} 12 | 13 | 14 | class WandbPredictionProgressCallback(WandbCallback): 15 | """Custom WandbCallback to log model predictions during training. 16 | 17 | This callback logs model predictions and labels to a wandb.Table at each 18 | logging step during training. It allows to visualize the 19 | model predictions as the training progresses. 20 | 21 | Attributes: 22 | trainer (Trainer): The Hugging Face Trainer instance. 23 | tokenizer (AutoTokenizer): The tokenizer associated with the model. 24 | sample_dataset (Dataset): A subset of the validation dataset 25 | for generating predictions. 26 | num_samples (int, optional): Number of samples to select from 27 | the validation dataset for generating predictions. Defaults to 100. 28 | freq (int, optional): Frequency of logging. Defaults to 2. 
29 | 30 | 31 | Example: 32 | ```python 33 | # First, instantiate the Trainer 34 | trainer = Trainer( 35 | model=model, 36 | args=training_args, 37 | train_dataset=lm_datasets["train"], 38 | eval_dataset=lm_datasets["validation"], 39 | ) 40 | 41 | # Instantiate the WandbPredictionProgressCallback 42 | progress_callback = WandbPredictionProgressCallback( 43 | trainer=trainer, 44 | tokenizer=tokenizer, 45 | val_dataset=lm_dataset["validation"], 46 | num_samples=10, 47 | freq=2, 48 | ) 49 | 50 | # Add the callback to the trainer 51 | trainer.add_callback(progress_callback) 52 | ``` 53 | """ 54 | 55 | def __init__( 56 | self, 57 | trainer: transformers.Trainer, 58 | tokenizer: transformers.AutoTokenizer, 59 | val_dataset: torch.utils.data.Dataset, 60 | num_samples: int = 100, 61 | freq: int = 2, 62 | ) -> None: 63 | """Initializes the WandbPredictionProgressCallback instance. 64 | 65 | Args: 66 | trainer (Trainer): The Hugging Face Trainer instance. 67 | tokenizer (AutoTokenizer): The tokenizer associated 68 | with the model. 69 | val_dataset (Dataset): The validation dataset. 70 | num_samples (int, optional): Number of samples to select from 71 | the validation dataset for generating predictions. 72 | Defaults to 100. 73 | freq (int, optional): Frequency of logging. Defaults to 2. 74 | """ 75 | super().__init__() 76 | self.trainer = trainer 77 | self.tokenizer = tokenizer 78 | # select a subset of the validation dataset 79 | indices = torch.randperm(len(val_dataset))[:num_samples] 80 | self.sample_dataset = torch.utils.data.Subset(val_dataset, indices) 81 | self.freq = freq 82 | 83 | def on_evaluate(self, args, state, control, **kwargs): 84 | super().on_evaluate(args, state, control, **kwargs) 85 | # control the frequency of logging by logging the predictions 86 | # every `freq` epochs 87 | if state.epoch % self.freq == 0: 88 | # generate predictions 89 | predictions = self.trainer.predict(self.sample_dataset) 90 | # decode predictions and labels 91 | predictions = decode_predictions(self.tokenizer, predictions) 92 | # add predictions to a wandb.Table 93 | predictions_df = pd.DataFrame(predictions) 94 | predictions_df['epoch'] = state.epoch 95 | records_table = self._wandb.Table(dataframe=predictions_df) 96 | # log the table to wandb 97 | self._wandb.log({'sample_predictions': records_table}) 98 | -------------------------------------------------------------------------------- /data/hh_rlhf_en/hh_rlhf_en.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from typing import List 4 | 5 | import datasets 6 | 7 | _HF_ENDPOINT = os.getenv('HF_ENDPOINT', 'https://huggingface.co') 8 | _DESCRIPTION = 'Human preference data about helpfulness and harmlessness.' 
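# Each raw record is parsed into the latest human turn (instruction), the chosen/rejected
# assistant replies, and all earlier turns as (query, response) history pairs; see _generate_examples.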
9 | _CITATION = '' 10 | _HOMEPAGE = '{}/datasets/Anthropic/hh-rlhf'.format(_HF_ENDPOINT) 11 | _LICENSE = 'mit' 12 | _URL = '{}/datasets/Anthropic/hh-rlhf/resolve/main/'.format(_HF_ENDPOINT) 13 | _URLS = { 14 | 'train': [ 15 | _URL + 'harmless-base/train.jsonl.gz', 16 | _URL + 'helpful-base/train.jsonl.gz', 17 | _URL + 'helpful-online/train.jsonl.gz', 18 | _URL + 'helpful-rejection-sampled/train.jsonl.gz', 19 | ], 20 | 'test': [ 21 | _URL + 'harmless-base/test.jsonl.gz', 22 | _URL + 'helpful-base/test.jsonl.gz', 23 | _URL + 'helpful-online/test.jsonl.gz', 24 | _URL + 'helpful-rejection-sampled/test.jsonl.gz', 25 | ], 26 | } 27 | 28 | 29 | class HhRlhfEn(datasets.GeneratorBasedBuilder): 30 | VERSION = datasets.Version('0.0.0') 31 | 32 | def _info(self) -> datasets.DatasetInfo: 33 | features = datasets.Features({ 34 | 'instruction': 35 | datasets.Value('string'), 36 | 'chosen': 37 | datasets.Value('string'), 38 | 'rejected': 39 | datasets.Value('string'), 40 | 'history': 41 | datasets.Sequence(datasets.Sequence(datasets.Value('string'))), 42 | }) 43 | return datasets.DatasetInfo( 44 | description=_DESCRIPTION, 45 | features=features, 46 | homepage=_HOMEPAGE, 47 | license=_LICENSE, 48 | citation=_CITATION, 49 | ) 50 | 51 | def _split_generators(self, dl_manager: datasets.DownloadManager): 52 | file_path = dl_manager.download_and_extract(_URLS) 53 | return [ 54 | datasets.SplitGenerator( 55 | name=datasets.Split.TRAIN, 56 | gen_kwargs={'filepaths': file_path['train']}), 57 | datasets.SplitGenerator( 58 | name=datasets.Split.TEST, 59 | gen_kwargs={'filepaths': file_path['test']}), 60 | ] 61 | 62 | def _generate_examples(self, filepaths: List[str]): 63 | key = 0 64 | for filepath in filepaths: 65 | with open(filepath, 'r', encoding='utf-8') as f: 66 | for row in f: 67 | data = json.loads(row) 68 | chosen = data['chosen'] 69 | rejected = data['rejected'] 70 | 71 | assist_idx = rejected.rfind('\n\nAssistant: ') 72 | r_reject = rejected[assist_idx + 13:].strip() 73 | assist_idx = chosen.rfind('\n\nAssistant: ') 74 | r_accept = chosen[assist_idx + 13:].strip() 75 | 76 | human_idx = chosen.rfind('\n\nHuman: ') 77 | query = chosen[human_idx + 9:assist_idx].strip() 78 | prompt = chosen[:human_idx] 79 | history = [] 80 | 81 | while prompt.rfind('\n\nAssistant: ') != -1: 82 | assist_idx = prompt.rfind('\n\nAssistant: ') 83 | human_idx = prompt.rfind('\n\nHuman: ') 84 | if human_idx != -1: 85 | old_query = prompt[human_idx + 86 | 9:assist_idx].strip() 87 | old_resp = prompt[assist_idx + 13:].strip() 88 | history.insert(0, (old_query, old_resp)) 89 | else: 90 | break 91 | prompt = prompt[:human_idx] 92 | 93 | yield ( 94 | key, 95 | { 96 | 'instruction': query, 97 | 'chosen': r_accept, 98 | 'rejected': r_reject, 99 | 'history': history, 100 | }, 101 | ) 102 | key += 1 103 | -------------------------------------------------------------------------------- /llamatuner/train/rm/trainer.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from typing import Dict, List, Tuple, Union 4 | 5 | import torch 6 | import torch.nn.functional as F 7 | from transformers import PreTrainedModel, Trainer 8 | from transformers.trainer import PredictionOutput 9 | 10 | from llamatuner.utils.logger_utils import get_logger 11 | 12 | logger = get_logger('llamatuner') 13 | 14 | 15 | class PairwiseTrainer(Trainer): 16 | r""" 17 | Inherits Trainer to compute pairwise loss. 
This custom Trainer computes a pairwise ranking loss 18 | where the first half of the batch contains positive examples and the second half contains negative examples. 19 | """ 20 | 21 | def __init__(self, *args, **kwargs) -> None: 22 | super().__init__(*args, **kwargs) 23 | self.can_return_loss = True 24 | 25 | def compute_loss( 26 | self, 27 | model: PreTrainedModel, 28 | inputs: Dict[str, torch.Tensor], 29 | return_outputs: bool = False, 30 | ) -> Union[torch.Tensor, Tuple[torch.Tensor, Tuple[torch.Tensor, 31 | torch.Tensor]]]: 32 | r""" 33 | Computes pairwise loss. The first n examples are chosen and the last n examples are rejected. 34 | 35 | Args: 36 | model (PreTrainedModel): The model being trained. 37 | inputs (Dict[str, torch.Tensor]): The inputs and targets of the model. 38 | return_outputs (bool, optional): Whether to return the model outputs in addition to the loss. Defaults to False. 39 | 40 | Returns: 41 | Union[torch.Tensor, Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]]: The computed loss and optionally the model outputs. 42 | """ 43 | outputs = model(**inputs, 44 | output_hidden_states=True, 45 | return_dict=True, 46 | use_cache=False) 47 | values = outputs.logits 48 | 49 | batch_size = inputs['input_ids'].size(0) // 2 50 | chosen_masks, rejected_masks = torch.split(inputs['attention_mask'], 51 | batch_size, 52 | dim=0) 53 | chosen_rewards, rejected_rewards = torch.split(values, 54 | batch_size, 55 | dim=0) 56 | 57 | chosen_scores = chosen_rewards.gather( 58 | dim=-1, index=(chosen_masks.sum(dim=-1, keepdim=True) - 1)) 59 | rejected_scores = rejected_rewards.gather( 60 | dim=-1, index=(rejected_masks.sum(dim=-1, keepdim=True) - 1)) 61 | 62 | chosen_scores, rejected_scores = ( 63 | chosen_scores.squeeze(), 64 | rejected_scores.squeeze(), 65 | ) 66 | 67 | loss = -F.logsigmoid(chosen_scores.float() - 68 | rejected_scores.float()).mean() 69 | if return_outputs: 70 | return loss, (chosen_scores, rejected_scores) 71 | else: 72 | return loss 73 | 74 | def save_predictions(self, predict_results: PredictionOutput) -> None: 75 | r""" 76 | Saves model predictions to `output_dir`. 77 | 78 | Args: 79 | predict_results (PredictionOutput): The output of the `predict` method. 80 | 81 | This method saves the chosen and rejected scores to a JSONL file in the `output_dir`. 82 | """ 83 | if not self.is_world_process_zero(): 84 | return 85 | 86 | output_prediction_file = os.path.join(self.args.output_dir, 87 | 'generated_predictions.jsonl') 88 | logger.info(f'Saving prediction results to {output_prediction_file}') 89 | chosen_scores, rejected_scores = predict_results.predictions 90 | 91 | with open(output_prediction_file, 'w', encoding='utf-8') as writer: 92 | res: List[str] = [] 93 | for c_score, r_score in zip(chosen_scores, rejected_scores): 94 | res.append( 95 | json.dumps({ 96 | 'chosen': round(float(c_score), 2), 97 | 'rejected': round(float(r_score), 2), 98 | })) 99 | 100 | writer.write('\n'.join(res)) 101 | -------------------------------------------------------------------------------- /llamatuner/model/callbacks/save_peft_model_callback.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Any, Dict 3 | 4 | from transformers import (PreTrainedModel, TrainerCallback, TrainerControl, 5 | TrainingArguments) 6 | from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR 7 | 8 | 9 | class SavePeftModelCallback(TrainerCallback): 10 | """Callback to save PEFT model checkpoints during training. 
11 | 12 | Saves both the full model and the adapter model to separate directories 13 | within the checkpoint directory. 14 | """ 15 | 16 | def save_model(self, args: Any, state: TrainingArguments, 17 | kwargs: Dict[str, Any]) -> None: 18 | """Saves the PEFT model checkpoint. 19 | 20 | Args: 21 | args (Any): The command line arguments passed to the script. 22 | state (TrainingArguments): The current state of training. 23 | kwargs (Dict[str, Any]): A dictionary of additional keyword arguments. 24 | 25 | Raises: 26 | TypeError: If `state` is not an instance of `TrainingArguments`. 27 | """ 28 | print('+' * 20, 'Saving PEFT Model Checkpoint CallBack', '+' * 20) 29 | 30 | # Get the checkpoint directory for saving models. 31 | if state.best_model_checkpoint is not None: 32 | # If best model checkpoint exists, use its directory as the checkpoint folder 33 | checkpoint_dir = os.path.join(state.best_model_checkpoint, 34 | 'adapter_model') 35 | else: 36 | # Otherwise, create a new checkpoint folder using the output directory and current global step 37 | checkpoint_dir = os.path.join( 38 | args.output_dir, 39 | f'{PREFIX_CHECKPOINT_DIR}-{state.global_step}') 40 | 41 | # Create path for the PEFT model 42 | peft_model_path = os.path.join(checkpoint_dir, 'adapter_model') 43 | model: PreTrainedModel = kwargs['model'] 44 | model.save_pretrained(peft_model_path) 45 | 46 | # Create path for the PyTorch model binary file and remove it if it already exists 47 | pytorch_model_path = os.path.join(checkpoint_dir, 'pytorch_model.bin') 48 | if os.path.exists(pytorch_model_path): 49 | os.remove(pytorch_model_path) 50 | 51 | def on_save(self, args: Any, state: TrainingArguments, 52 | control: TrainerControl, 53 | **kwargs: Dict[str, Any]) -> TrainerControl: 54 | """Callback method that calls save_model() and returns `control` 55 | argument. 56 | 57 | Args: 58 | args (Any): The command line arguments passed to the script. 59 | state (TrainingArguments): The current state of training. 60 | control (trainer_callback.TrainerControl): \ 61 | The current state of the TrainerCallback's control flow. 62 | kwargs (Dict[str, Any]): A dictionary of additional keyword arguments. 63 | 64 | Returns: 65 | trainer_callback.TrainerControl: The current state of the TrainerCallback's control flow. 66 | 67 | Raises: 68 | TypeError: If `state` is not an instance of `TrainingArguments`. 69 | """ 70 | self.save_model(args, state, kwargs) 71 | return control 72 | 73 | def on_train_end(self, args: Any, state: TrainingArguments, 74 | control: TrainerControl, **kwargs: Dict[str, 75 | Any]) -> None: 76 | """Callback method that saves the model checkpoint and creates a 77 | 'completed' file in the output directory. 78 | 79 | Args: 80 | args (Any): The command line arguments passed to the script. 81 | state (TrainingArguments): The current state of training. 82 | control (trainer_callback.TrainerControl): \ 83 | The current state of the TrainerCallback's control flow. 84 | kwargs (Dict[str, Any]): A dictionary of additional keyword arguments. 85 | 86 | Raises: 87 | TypeError: If `state` is not an instance of `TrainingArguments`. 
88 | """ 89 | 90 | # Define a helper function to create a 'completed' file in the output directory 91 | def touch(fname, times=None): 92 | with open(fname, 'a'): 93 | os.utime(fname, times) 94 | 95 | # Create the 'completed' file in the output directory 96 | touch(os.path.join(args.output_dir, 'completed')) 97 | -------------------------------------------------------------------------------- /llamatuner/model/callbacks/metrics.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Dict, Sequence, Tuple, Union 3 | 4 | import numpy as np 5 | from transformers import PreTrainedTokenizer 6 | 7 | from llamatuner.utils.constants import IGNORE_INDEX 8 | from llamatuner.utils.packages import (is_jieba_available, is_nltk_available, 9 | is_rouge_available) 10 | 11 | if is_jieba_available(): 12 | import jieba # type: ignore 13 | 14 | if is_nltk_available(): 15 | from nltk.translate.bleu_score import SmoothingFunction, sentence_bleu 16 | 17 | if is_rouge_available(): 18 | from rouge_chinese import Rouge 19 | 20 | 21 | @dataclass 22 | class ComputeMetrics: 23 | """Wraps the tokenizer into metric functions, used in Seq2SeqPeftTrainer. 24 | 25 | Borrowed from: https://github.com/THUDM/ChatGLM-6B/blob/0c2806fea82683349194e21996dd6b3acc3c265b/ptuning/main.py#L307 26 | """ 27 | 28 | def __init__(self, tokenizer: PreTrainedTokenizer) -> None: 29 | """Initialize the ComputeMetrics class with a pre-trained tokenizer 30 | object. 31 | 32 | Args: 33 | tokenizer (PreTrainedTokenizer): A pre-trained tokenizer object to be used for decoding tokenized sequences. 34 | """ 35 | self.tokenizer = tokenizer 36 | 37 | def __call__( 38 | self, eval_preds: Sequence[Union[np.ndarray, Tuple[np.ndarray]]] 39 | ) -> Dict[str, float]: 40 | """Computes evaluation metrics for model predictions. 41 | 42 | Args: 43 | eval_preds (List[Union[np.ndarray, Tuple[np.ndarray]]]): List of tuples containing prediction and label arrays. 44 | 45 | Returns: 46 | Dict[str, float]: A dictionary containing the average of each computed metric over all prediction-label pairs. 47 | """ 48 | 49 | # Extract predictions and labels from input 50 | preds, labels = eval_preds 51 | if isinstance(preds, tuple): 52 | preds = preds[0] 53 | 54 | score_dict = { 55 | 'rouge-1': [], 56 | 'rouge-2': [], 57 | 'rouge-l': [], 58 | 'bleu-4': [] 59 | } 60 | 61 | # Replace IGNORE_INDEX in the labels with pad_token_id as we cannot decode them if ignore_pad_token_for_loss=True. 
62 | preds = np.where(preds != IGNORE_INDEX, preds, 63 | self.tokenizer.pad_token_id) 64 | labels = np.where(labels != IGNORE_INDEX, labels, 65 | self.tokenizer.pad_token_id) 66 | 67 | decoded_preds = self.tokenizer.batch_decode(preds, 68 | skip_special_tokens=True) 69 | decoded_labels = self.tokenizer.batch_decode(labels, 70 | skip_special_tokens=True) 71 | 72 | # Calculate metrics for each prediction-label pair 73 | for pred, label in zip(decoded_preds, decoded_labels): 74 | hypothesis = list(jieba.cut(pred)) 75 | reference = list(jieba.cut(label)) 76 | # If there are no words in the hypothesis, set all scores to 0 77 | if (len(' '.join(hypothesis).split()) == 0 78 | or len(' '.join(reference).split()) == 0): 79 | result = { 80 | 'rouge-1': { 81 | 'f': 0.0 82 | }, 83 | 'rouge-2': { 84 | 'f': 0.0 85 | }, 86 | 'rouge-l': { 87 | 'f': 0.0 88 | }, 89 | } 90 | else: 91 | rouge = Rouge() 92 | scores = rouge.get_scores(' '.join(hypothesis), 93 | ' '.join(reference)) 94 | result = scores[0] 95 | 96 | # Append scores to score_dict 97 | for k, v in result.items(): 98 | score_dict[k].append(round(v['f'] * 100, 4)) 99 | 100 | # Calculate BLEU-4 score and append it to score_dict 101 | bleu_score = sentence_bleu( 102 | [list(label)], 103 | list(pred), 104 | smoothing_function=SmoothingFunction().method3) 105 | score_dict['bleu-4'].append(round(bleu_score * 100, 4)) 106 | 107 | # Calculate average of each metric over all prediction-label pairs and return as a dictionary 108 | return {k: float(np.mean(v)) for k, v in score_dict.items()} 109 | -------------------------------------------------------------------------------- /server/single_chat.py: -------------------------------------------------------------------------------- 1 | import os 2 | import platform 3 | import sys 4 | from threading import Thread 5 | from typing import List 6 | 7 | import torch 8 | import transformers 9 | from transformers import (AutoModelForCausalLM, AutoTokenizer, PreTrainedModel, 10 | PreTrainedTokenizer, TextIteratorStreamer) 11 | 12 | sys.path.append('../') 13 | from llamatuner.configs import GenerationArguments, ModelInferenceArguments 14 | from llamatuner.utils.model_utils import get_logits_processor 15 | 16 | 17 | def generate_response(query: str, tokenizer: PreTrainedTokenizer, 18 | model: PreTrainedModel, 19 | generation_args: dict) -> List[str]: 20 | """Generates a response to the given query using GPT-3.5 model and prints 21 | it to the console. 22 | 23 | Args: 24 | query (str): The input query for which a response is to be generated. 25 | tokenizer (PreTrainedTokenizer): The tokenizer used to convert the raw text into input tokens. 26 | model (PreTrainedModel): The GPT-3.5 model used to generate the response. 27 | generation_args (dict): A dictionary containing the arguments to be passed to the generate() method of the model. 28 | 29 | Returns: 30 | List[Tuple[str, str]]: A list of all the previous queries and their responses, including the current one. 
31 | """ 32 | 33 | # Convert the query and history into input IDs 34 | inputs = tokenizer(query, return_tensors='pt', add_special_tokens=False) 35 | inputs = {k: v.to(model.device) for k, v in inputs.items()} 36 | 37 | # Create a TextIteratorStreamer object to stream the response from the model 38 | streamer = TextIteratorStreamer(tokenizer, 39 | timeout=60.0, 40 | skip_prompt=True, 41 | skip_special_tokens=True) 42 | 43 | # Set the arguments for the model's generate() method 44 | gen_kwargs = dict( 45 | **inputs, 46 | streamer=streamer, 47 | logits_processor=get_logits_processor(), 48 | **generation_args.to_dict(), 49 | ) 50 | 51 | # Start a separate thread to generate the response asynchronously 52 | thread = Thread(target=model.generate, kwargs=gen_kwargs) 53 | thread.start() 54 | 55 | # Print the model name and the response as it is generated 56 | print('Assistant: ', end='', flush=True) 57 | response = '' 58 | for new_text in streamer: 59 | print(new_text, end='', flush=True) 60 | response += new_text 61 | # Update the history with the current query and response and return it 62 | return response 63 | 64 | 65 | def main(): 66 | """单轮对话,不具有对话历史的记忆功能 Run conversational agent loop with input/output. 67 | 68 | Args: 69 | model_args: Arguments for loading model 70 | gen_args: Arguments for model.generate() 71 | 72 | Returns: 73 | None 74 | """ 75 | 76 | # Parse command-line arguments 77 | parser = transformers.HfArgumentParser( 78 | (ModelInferenceArguments, GenerationArguments)) 79 | model_server_args, generation_args = parser.parse_args_into_dataclasses() 80 | 81 | # Load the pretrained language model. 82 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 83 | 84 | model = AutoModelForCausalLM.from_pretrained( 85 | model_server_args.model_name_or_path, 86 | trust_remote_code=True, 87 | low_cpu_mem_usage=True, 88 | torch_dtype=torch.float16, 89 | device_map='auto').to(device).eval() 90 | 91 | tokenizer = AutoTokenizer.from_pretrained( 92 | model_server_args.model_name_or_path, 93 | trust_remote_code=True, 94 | use_fast=False, 95 | ) 96 | 97 | os_name = platform.system() 98 | clear_command = 'cls' if os_name == 'Windows' else 'clear' 99 | # Set the arguments for the model's generate() method 100 | print('欢迎使用 CLI 对话系统,输入内容即可对话,clear 清空对话历史,stop 终止程序') 101 | input_pattern = '{}' 102 | while True: 103 | query = input('\nUser: ') 104 | if query.strip() == 'stop': 105 | break 106 | 107 | if query.strip() == 'clear': 108 | os.system(clear_command) 109 | print('History has been removed.') 110 | print('欢迎使用CLI 对话系统,输入内容即可对话,clear 清空对话历史,stop 终止程序') 111 | continue 112 | 113 | query = input_pattern.format(query) 114 | # Perform prediction and printing 115 | generate_response(query, tokenizer, model, generation_args) 116 | 117 | 118 | if __name__ == '__main__': 119 | main() 120 | -------------------------------------------------------------------------------- /server/gradio_base_webserver.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import gradio as gr 4 | import torch 5 | from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig 6 | 7 | from llamatuner.train.apply_lora import apply_lora 8 | 9 | 10 | def args_parser(): 11 | parser = argparse.ArgumentParser() 12 | parser.add_argument('--model_name_or_path', 13 | default=None, 14 | type=str, 15 | required=True, 16 | help='Path to pre-trained model') 17 | parser.add_argument('--lora_model_name_or_path', 18 | default=None, 19 | type=str, 20 | help='Path to 
pre-trained model') 21 | parser.add_argument('--no_cuda', 22 | action='store_true', 23 | help='Avoid using CUDA when available') 24 | parser.add_argument('--load_8bit', 25 | action='store_true', 26 | help='Whether to use load_8bit instead of 32-bit') 27 | args = parser.parse_args() 28 | 29 | args.device = torch.device( 30 | 'cuda' if torch.cuda.is_available() and not args.no_cuda else 'cpu') 31 | return args 32 | 33 | 34 | def main(args): 35 | if args.lora_model_name_or_path is not None: 36 | model, tokenizer = apply_lora(args.model_name_or_path, 37 | args.lora_model_name_or_path, 38 | load_8bit=args.load_8bit) 39 | else: 40 | tokenizer = AutoTokenizer.from_pretrained( 41 | pretrained_model_name_or_path=args.model_name_or_path, 42 | trust_remote_code=True) 43 | model = AutoModelForCausalLM.from_pretrained( 44 | pretrained_model_name_or_path=args.model_name_or_path, 45 | load_in_8bit=args.load_8bit, 46 | torch_dtype=torch.float16, 47 | device_map='auto', 48 | trust_remote_code=True) 49 | 50 | def evaluate( 51 | input=None, 52 | temperature=0.8, 53 | top_p=0.75, 54 | top_k=40, 55 | max_new_tokens=128, 56 | **kwargs, 57 | ): 58 | inputs = tokenizer(input, return_tensors='pt') 59 | inputs = inputs.to(args.device) 60 | generation_config = GenerationConfig( 61 | temperature=temperature, 62 | top_p=top_p, 63 | top_k=top_k, 64 | do_sample=True, 65 | no_repeat_ngram_size=6, 66 | repetition_penalty=1.8, 67 | **kwargs, 68 | ) 69 | # Without streaming 70 | with torch.no_grad(): 71 | generation_output = model.generate( 72 | **inputs, 73 | generation_config=generation_config, 74 | return_dict_in_generate=True, 75 | output_scores=True, 76 | max_new_tokens=max_new_tokens, 77 | ) 78 | s = generation_output.sequences[0] 79 | output = tokenizer.decode(s, skip_special_tokens=True) 80 | yield output 81 | 82 | description = 'Baichuan7B is a 7B-parameter LLaMA model finetuned to follow instructions.' 
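# Note on the wiring below: gr.Interface maps the input Textbox and the
# temperature / top-p / top-k / max-new-tokens sliders onto evaluate()'s
# parameters by position, so the component order must match the function
# signature. Because evaluate() is a generator (it yields its output),
# launching through server.queue() lets Gradio stream the generated text.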
83 | server = gr.Interface( 84 | fn=evaluate, 85 | inputs=[ 86 | gr.components.Textbox(lines=2, label='Input', placeholder='none'), 87 | gr.components.Slider(minimum=0, 88 | maximum=1, 89 | value=0.1, 90 | label='Temperature'), 91 | gr.components.Slider(minimum=0, 92 | maximum=1, 93 | value=0.75, 94 | label='Top p'), 95 | gr.components.Slider(minimum=0, 96 | maximum=100, 97 | step=1, 98 | value=40, 99 | label='Top k'), 100 | gr.components.Slider(minimum=1, 101 | maximum=2000, 102 | step=1, 103 | value=128, 104 | label='Max tokens'), 105 | ], 106 | outputs=[gr.inputs.Textbox( 107 | lines=5, 108 | label='Output', 109 | )], 110 | title='Baichuan7B', 111 | description=description, 112 | ) 113 | 114 | server.queue().launch(server_name='0.0.0.0', share=False) 115 | 116 | 117 | if __name__ == '__main__': 118 | args = args_parser() 119 | main(args) 120 | -------------------------------------------------------------------------------- /llamatuner/data/utils.py: -------------------------------------------------------------------------------- 1 | from enum import Enum, unique 2 | from typing import List, Optional, TypedDict, Union 3 | 4 | from datasets import (Dataset, IterableDataset, concatenate_datasets, 5 | interleave_datasets) 6 | from transformers import TrainingArguments 7 | 8 | from llamatuner.configs import DataArguments 9 | from llamatuner.utils.logger_utils import get_logger 10 | 11 | logger = get_logger('llamatuner') 12 | 13 | 14 | @unique 15 | class Role(str, Enum): 16 | """Enumeration of possible roles in a conversation.""" 17 | USER = 'user' 18 | ASSISTANT = 'assistant' 19 | SYSTEM = 'system' 20 | FUNCTION = 'function' 21 | OBSERVATION = 'observation' 22 | 23 | 24 | class DatasetModule(TypedDict): 25 | """Type definition for dataset module containing train and evaluation datasets.""" 26 | train_dataset: Optional[Union[Dataset, IterableDataset]] 27 | eval_dataset: Optional[Union[Dataset, IterableDataset]] 28 | 29 | 30 | def merge_dataset( 31 | all_datasets: List[Union[Dataset, IterableDataset]], 32 | data_args: DataArguments, 33 | training_args: TrainingArguments, 34 | ) -> Union[Dataset, IterableDataset]: 35 | """Merge multiple datasets using specified strategy. 36 | 37 | Args: 38 | all_datasets: List of datasets to merge 39 | data_args: Data configuration arguments 40 | training_args: Training configuration arguments 41 | 42 | Returns: 43 | Merged dataset 44 | 45 | Raises: 46 | ValueError: If mixing strategy is unknown 47 | """ 48 | if not all_datasets: 49 | raise ValueError('Cannot merge empty dataset list') 50 | 51 | if len(all_datasets) == 1: 52 | return all_datasets[0] 53 | 54 | valid_strategies = {'concat', 'interleave_under', 'interleave_over'} 55 | if data_args.mix_strategy not in valid_strategies: 56 | raise ValueError( 57 | f'Unknown mixing strategy: {data_args.mix_strategy}. ' 58 | f"Valid strategies are: {', '.join(valid_strategies)}") 59 | 60 | logger.info( 61 | f'Merging {len(all_datasets)} datasets with {data_args.mix_strategy} strategy ...' 62 | ) 63 | 64 | if data_args.mix_strategy == 'concat': 65 | if data_args.streaming: 66 | logger.warning( 67 | 'The samples between different datasets will not be mixed in streaming mode.' 
68 | ) 69 | return concatenate_datasets(all_datasets) 70 | 71 | # Handle interleave strategies 72 | if not data_args.streaming: 73 | logger.warning( 74 | 'We recommend using `mix_strategy=concat` in non-streaming mode.') 75 | 76 | stopping_strategy = 'first_exhausted' if data_args.mix_strategy == 'interleave_under' else 'all_exhausted' 77 | 78 | return interleave_datasets( 79 | datasets=all_datasets, 80 | probabilities=data_args.interleave_probs, 81 | seed=training_args.seed, 82 | stopping_strategy=stopping_strategy, 83 | ) 84 | 85 | 86 | def split_dataset( 87 | dataset: Union[Dataset, IterableDataset], 88 | data_args: DataArguments, 89 | training_args: TrainingArguments, 90 | ) -> DatasetModule: 91 | """Split dataset into training and evaluation sets. 92 | 93 | Args: 94 | dataset: Input dataset to split 95 | data_args: Data configuration arguments 96 | training_args: Training configuration arguments 97 | 98 | Returns: 99 | Dictionary containing train and evaluation datasets 100 | 101 | Raises: 102 | ValueError: If eval_dataset_size is invalid 103 | """ 104 | if data_args.eval_dataset_size <= 0: 105 | raise ValueError('eval_dataset_size must be greater than 0') 106 | 107 | val_size = int( 108 | data_args.eval_dataset_size 109 | ) if data_args.eval_dataset_size > 1 else data_args.eval_dataset_size 110 | 111 | logger.info( 112 | f'Splitting dataset with evaluation size of {val_size} ' 113 | f'({data_args.eval_dataset_size} {"samples" if data_args.eval_dataset_size > 1 else "fraction"})' 114 | ) 115 | if data_args.streaming: 116 | dataset = dataset.shuffle(buffer_size=data_args.buffer_size, 117 | seed=training_args.seed) 118 | val_set = dataset.take(int(data_args.eval_dataset_size)) 119 | train_set = dataset.skip(int(data_args.eval_dataset_size)) 120 | return DatasetModule(train_dataset=train_set, eval_dataset=val_set) 121 | 122 | dataset_split = dataset.train_test_split(test_size=val_size, 123 | seed=training_args.seed) 124 | return DatasetModule(train_dataset=dataset_split['train'], 125 | eval_dataset=dataset_split['test']) 126 | -------------------------------------------------------------------------------- /llamatuner/configs/model_args.py: -------------------------------------------------------------------------------- 1 | from dataclasses import asdict, dataclass, field 2 | from typing import Any, Dict, Literal, Optional 3 | 4 | 5 | @dataclass 6 | class ModelArguments: 7 | """Arguments pertaining to which model/config/tokenizer we are going to fine-tune or infer.""" 8 | 9 | model_name_or_path: Optional[str] = field( 10 | default='facebook/opt-125m', 11 | metadata={ 12 | 'help': 13 | ('Path to the model weight or identifier from huggingface.co/models or modelscope.cn/models.' 14 | ) 15 | }, 16 | ) 17 | adapter_name_or_path: Optional[str] = field( 18 | default=None, 19 | metadata={ 20 | 'help': 21 | ('Path to the adapter weight or identifier from huggingface.co/models. ' 22 | 'Use commas to separate multiple adapters.') 23 | }, 24 | ) 25 | adapter_folder: Optional[str] = field( 26 | default=None, 27 | metadata={ 28 | 'help': 'The folder containing the adapter weights to load.' 29 | }, 30 | ) 31 | use_fast_tokenizer: bool = field( 32 | default=True, 33 | metadata={ 34 | 'help': 35 | 'Whether or not to use one of the fast tokenizer (backed by the tokenizers library).' 36 | }, 37 | ) 38 | resize_vocab: bool = field( 39 | default=False, 40 | metadata={ 41 | 'help': 42 | 'Whether or not to resize the tokenizer vocab and the embedding layers.' 
43 | }, 44 | ) 45 | model_max_length: Optional[int] = field( 46 | default=1024, 47 | metadata={ 48 | 'help': 49 | 'The maximum length of the model input, including special tokens.' 50 | }, 51 | ) 52 | trust_remote_code: Optional[bool] = field( 53 | default=True, 54 | metadata={ 55 | 'help': 56 | 'Whether or not to trust the remote code in the model configuration.' 57 | }, 58 | ) 59 | cache_dir: Optional[str] = field( 60 | default=None, 61 | metadata={ 62 | 'help': 63 | 'Where to store the pre-trained models downloaded from huggingface.co or modelscope.cn.' 64 | }, 65 | ) 66 | model_revision: str = field( 67 | default='main', 68 | metadata={ 69 | 'help': 70 | 'The specific model version to use (can be a branch name, tag name or commit id).' 71 | }, 72 | ) 73 | split_special_tokens: bool = field( 74 | default=False, 75 | metadata={ 76 | 'help': 77 | 'Whether or not the special tokens should be split during the tokenization process.' 78 | }, 79 | ) 80 | new_special_tokens: Optional[str] = field( 81 | default=None, 82 | metadata={'help': 'Special tokens to be added into the tokenizer.'}, 83 | ) 84 | low_cpu_mem_usage: bool = field( 85 | default=True, 86 | metadata={ 87 | 'help': 'Whether or not to use memory-efficient model loading.' 88 | }, 89 | ) 90 | rope_scaling: Optional[Literal['linear', 'dynamic']] = field( 91 | default=None, 92 | metadata={ 93 | 'help': 94 | 'Which scaling strategy should be adopted for the RoPE embeddings.' 95 | }, 96 | ) 97 | flash_attn: Literal['off', 'sdpa', 'fa2', 'auto'] = field( 98 | default='auto', 99 | metadata={ 100 | 'help': 'Enable FlashAttention for faster training and inference.' 101 | }, 102 | ) 103 | train_from_scratch: bool = field( 104 | default=False, 105 | metadata={ 106 | 'help': 'Whether or not to randomly initialize the model weights.' 107 | }, 108 | ) 109 | offload_folder: str = field( 110 | default='offload', 111 | metadata={'help': 'Path to offload model weights.'}, 112 | ) 113 | use_cache: bool = field( 114 | default=True, 115 | metadata={'help': 'Whether or not to use KV cache in generation.'}, 116 | ) 117 | hf_hub_token: Optional[str] = field( 118 | default=None, 119 | metadata={'help': 'Auth token to log in with Hugging Face Hub.'}, 120 | ) 121 | ms_hub_token: Optional[str] = field( 122 | default=None, 123 | metadata={'help': 'Auth token to log in with ModelScope Hub.'}, 124 | ) 125 | 126 | def __post_init__(self): 127 | self.compute_dtype = None 128 | self.device_map = None 129 | 130 | if self.model_name_or_path is None: 131 | raise ValueError('Please provide `model_name_or_path`.') 132 | 133 | if self.adapter_name_or_path is not None: # support merging multiple lora weights 134 | self.adapter_name_or_path = [ 135 | path.strip() for path in self.adapter_name_or_path.split(',') 136 | ] 137 | 138 | if self.split_special_tokens and self.use_fast_tokenizer: 139 | raise ValueError( 140 | '`split_special_tokens` is only supported for slow tokenizers.' 
141 | ) 142 | 143 | if self.new_special_tokens is not None: # support multiple special tokens 144 | self.new_special_tokens = [ 145 | token.strip() for token in self.new_special_tokens.split(',') 146 | ] 147 | 148 | def to_dict(self) -> Dict[str, Any]: 149 | return asdict(self) 150 | -------------------------------------------------------------------------------- /data/format_data/convert_alpaca.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | from datasets import load_dataset 4 | 5 | 6 | def json_dump(obj, path): 7 | with open(path, 'w', encoding='utf-8') as f: 8 | json.dump(obj, f, indent=2, ensure_ascii=False) 9 | 10 | 11 | def json_load(in_file): 12 | with open(in_file, 'r') as f: 13 | json_data = json.load(f) 14 | return json_data 15 | 16 | 17 | def convert_100PoisonMpts(in_file, out_file): 18 | raw_data = load_dataset('json', data_files=in_file)['train'] 19 | new_content = [] 20 | for i, raw_text in enumerate(raw_data): 21 | prompt = raw_text['prompt'] 22 | response = raw_text['answer'] 23 | if len(prompt) <= 5 or len(response) <= 5: 24 | continue 25 | new_content.append({ 26 | 'instruction': prompt, 27 | 'input': '', 28 | 'output': response, 29 | }) 30 | 31 | print(f'#out: {len(new_content)}') 32 | json_dump(new_content, out_file) 33 | 34 | 35 | def convert_Cvalues(in_file, out_file): 36 | raw_data = load_dataset('json', data_files=in_file)['train'] 37 | new_content = [] 38 | for i, raw_text in enumerate(raw_data): 39 | prompt = raw_text['prompt'] 40 | response = raw_text['pos_resp'] 41 | if len(prompt) <= 5 or len(response) <= 5: 42 | continue 43 | new_content.append({ 44 | 'instruction': prompt, 45 | 'input': '', 46 | 'output': response, 47 | }) 48 | 49 | print(f'#out: {len(new_content)}') 50 | json_dump(new_content, out_file) 51 | 52 | 53 | def convert_huatuogpt(in_file, out_file): 54 | raw_data = load_dataset('json', data_files=in_file)['train'] 55 | new_content = [] 56 | for i, raw_text in enumerate(raw_data): 57 | data = raw_text['data'] 58 | prompt = data[0].replace('问:', '') 59 | response = data[1].replace('答:', '') 60 | if len(prompt) <= 5 or len(response) <= 5: 61 | continue 62 | new_content.append({ 63 | 'instruction': prompt, 64 | 'input': '', 65 | 'output': response, 66 | }) 67 | print(f'#out: {len(new_content)}') 68 | json_dump(new_content, out_file) 69 | 70 | 71 | def convert_safety_attack(in_file, out_file): 72 | field_list = [ 73 | 'Reverse_Exposure', 'Goal_Hijacking', 'Prompt_Leaking', 74 | 'Unsafe_Instruction_Topic', 'Role_Play_Instruction', 75 | 'Inquiry_With_Unsafe_Opinion' 76 | ] 77 | new_content = [] 78 | for filed in field_list: 79 | raw_data = load_dataset('json', field=filed, 80 | data_files=in_file)['train'] 81 | for i, raw_text in enumerate(raw_data): 82 | prompt = raw_text['prompt'] 83 | response = raw_text['response'] 84 | if len(prompt) <= 5 or len(response) <= 5: 85 | continue 86 | new_content.append({ 87 | 'instruction': prompt, 88 | 'input': '', 89 | 'output': response, 90 | }) 91 | print(f'#out: {len(new_content)}') 92 | json_dump(new_content, out_file) 93 | 94 | 95 | def convert_safety_scenarios(in_file, out_file): 96 | 97 | field_list = [ 98 | 'Unfairness_And_Discrimination', 'Crimes_And_Illegal_Activities', 99 | 'Insult', 'Mental_Health', 'Physical_Harm', 'Privacy_And_Property', 100 | 'Ethics_And_Morality' 101 | ] 102 | new_content = [] 103 | for filed in field_list: 104 | raw_data = load_dataset('json', data_files=in_file, 105 | field=filed)['train'] 106 | for i, raw_text in 
enumerate(raw_data): 107 | prompt = raw_text['prompt'] 108 | response = raw_text['response'] 109 | if len(prompt) <= 5 or len(response) <= 5: 110 | continue 111 | new_content.append({ 112 | 'instruction': prompt, 113 | 'input': '', 114 | 'output': response, 115 | }) 116 | print(f'#out: {len(new_content)}') 117 | json_dump(new_content, out_file) 118 | 119 | 120 | if __name__ == '__main__': 121 | 122 | data_path = '/home/robin/prompt_data/100PoisonMpts/train.jsonl' 123 | out_path = '/home/robin/prompt_data/100PoisonMpts/train_alpaca.jsonl' 124 | convert_100PoisonMpts(data_path, out_file=out_path) 125 | 126 | data_path = '/home/robin/prompt_data/CValues-Comparison/test.jsonl' 127 | out_path = '/home/robin/prompt_data/CValues-Comparison/test_alpaca.json' 128 | convert_Cvalues(data_path, out_file=out_path) 129 | 130 | data_path = '/home/robin/prompt_data/CValues-Comparison/train.jsonl' 131 | out_path = '/home/robin/prompt_data/CValues-Comparison/train_alpaca.json' 132 | convert_Cvalues(data_path, out_file=out_path) 133 | 134 | data_path = '/home/robin/prompt_data/HuatuoGPT-sft-data-v1/HuatuoGPT_sft_data_v1.jsonl' 135 | out_path = '/home/robin/prompt_data/HuatuoGPT-sft-data-v1/HuatuoGPT_alpaca.json' 136 | convert_huatuogpt(data_path, out_file=out_path) 137 | 138 | data_path = '/home/robin/prompt_data/Safety-Prompts/instruction_attack_scenarios.json' 139 | out_path = '/home/robin/prompt_data/Safety-Prompts/attack_scenarios_alpaca.json' 140 | convert_safety_attack(data_path, out_file=out_path) 141 | 142 | data_path = '/home/robin/prompt_data/Safety-Prompts/typical_safety_scenarios.json' 143 | out_path = '/home/robin/prompt_data/Safety-Prompts/safety_scenarios_alpaca.json' 144 | convert_safety_scenarios(data_path, out_file=out_path) 145 | -------------------------------------------------------------------------------- /llamatuner/utils/misc.py: -------------------------------------------------------------------------------- 1 | import gc 2 | import os 3 | from typing import Tuple 4 | 5 | import torch 6 | from transformers.utils import (is_torch_bf16_gpu_available, 7 | is_torch_cuda_available, 8 | is_torch_mps_available, is_torch_npu_available, 9 | is_torch_xpu_available) 10 | from transformers.utils.versions import require_version 11 | 12 | from llamatuner.utils.logger_utils import get_logger 13 | 14 | _is_fp16_available = is_torch_npu_available() or is_torch_cuda_available() 15 | try: 16 | _is_bf16_available = is_torch_bf16_gpu_available() 17 | except Exception: 18 | _is_bf16_available = False 19 | 20 | from llamatuner.configs.model_args import ModelArguments 21 | 22 | logger = get_logger('llamatuner') 23 | 24 | 25 | class AverageMeter: 26 | r""" 27 | Computes and stores the average and current value. 28 | """ 29 | 30 | def __init__(self): 31 | self.reset() 32 | 33 | def reset(self): 34 | self.val = 0 35 | self.avg = 0 36 | self.sum = 0 37 | self.count = 0 38 | 39 | def update(self, val, n=1): 40 | self.val = val 41 | self.sum += val * n 42 | self.count += n 43 | self.avg = self.sum / self.count 44 | 45 | 46 | def check_version(requirement: str, mandatory: bool = False) -> None: 47 | r""" 48 | Optionally checks the package version. 49 | """ 50 | if os.getenv('DISABLE_VERSION_CHECK', '0').lower() in ['true', '1' 51 | ] and not mandatory: 52 | logger.warning( 53 | 'Version checking has been disabled, may lead to unexpected behaviors.' 54 | ) 55 | return 56 | 57 | if mandatory: 58 | hint = f'To fix: run `pip install {requirement}`.' 
59 | else: 60 | hint = f'To fix: run `pip install {requirement}` or set `DISABLE_VERSION_CHECK=1` to skip this check.' 61 | 62 | require_version(requirement, hint) 63 | 64 | 65 | def check_dependencies() -> None: 66 | r""" 67 | Checks the version of the required packages. 68 | """ 69 | check_version('transformers>=4.41.2,<=4.46.1') 70 | check_version('datasets>=2.16.0,<=3.1.0') 71 | check_version('accelerate>=0.34.0,<=1.0.1') 72 | check_version('peft>=0.11.1,<=0.12.0') 73 | check_version('trl>=0.8.6,<=0.9.6') 74 | 75 | 76 | def get_current_device() -> torch.device: 77 | r""" 78 | Gets the current available device. 79 | """ 80 | if is_torch_xpu_available(): 81 | device = 'xpu:{}'.format(os.environ.get('LOCAL_RANK', '0')) 82 | elif is_torch_npu_available(): 83 | device = 'npu:{}'.format(os.environ.get('LOCAL_RANK', '0')) 84 | elif is_torch_mps_available(): 85 | device = 'mps:{}'.format(os.environ.get('LOCAL_RANK', '0')) 86 | elif is_torch_cuda_available(): 87 | device = 'cuda:{}'.format(os.environ.get('LOCAL_RANK', '0')) 88 | else: 89 | device = 'cpu' 90 | 91 | return torch.device(device) 92 | 93 | 94 | def get_device_count() -> int: 95 | r""" 96 | Gets the number of available GPU or NPU devices. 97 | """ 98 | if is_torch_npu_available(): 99 | return torch.npu.device_count() 100 | elif is_torch_cuda_available(): 101 | return torch.cuda.device_count() 102 | else: 103 | return 0 104 | 105 | 106 | def get_peak_memory() -> Tuple[int, int]: 107 | r""" 108 | Gets the peak memory usage for the current device (in Bytes). 109 | """ 110 | if is_torch_npu_available(): 111 | return torch.npu.max_memory_allocated(), torch.npu.max_memory_reserved( 112 | ) 113 | elif is_torch_cuda_available(): 114 | return torch.cuda.max_memory_allocated( 115 | ), torch.cuda.max_memory_reserved() 116 | else: 117 | return 0, 0 118 | 119 | 120 | def infer_optim_dtype(model_dtype: torch.dtype) -> torch.dtype: 121 | r""" 122 | Infers the optimal dtype according to the model_dtype and device compatibility. 123 | """ 124 | if _is_bf16_available and model_dtype == torch.bfloat16: 125 | return torch.bfloat16 126 | elif _is_fp16_available: 127 | return torch.float16 128 | else: 129 | return torch.float32 130 | 131 | 132 | def is_gpu_or_npu_available() -> bool: 133 | r""" 134 | Checks if the GPU or NPU is available. 135 | """ 136 | return is_torch_npu_available() or is_torch_cuda_available() 137 | 138 | 139 | def has_tokenized_data(path: os.PathLike) -> bool: 140 | r""" 141 | Checks if the path has a tokenized dataset. 142 | """ 143 | return os.path.isdir(path) and len(os.listdir(path)) > 0 144 | 145 | 146 | def torch_gc() -> None: 147 | r""" 148 | Collects GPU or NPU memory. 
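Typically called after large objects are released (e.g. ``del model``)
to return cached blocks to the accelerator's memory pool.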
149 | """ 150 | gc.collect() 151 | if is_torch_xpu_available(): 152 | torch.xpu.empty_cache() 153 | elif is_torch_npu_available(): 154 | torch.npu.empty_cache() 155 | elif is_torch_mps_available(): 156 | torch.mps.empty_cache() 157 | elif is_torch_cuda_available(): 158 | torch.cuda.empty_cache() 159 | 160 | 161 | def try_download_model_from_ms(model_args: 'ModelArguments') -> str: 162 | if not use_modelscope() or os.path.exists(model_args.model_name_or_path): 163 | return model_args.model_name_or_path 164 | 165 | try: 166 | from modelscope import snapshot_download 167 | 168 | revision = ('master' if model_args.model_revision == 'main' else 169 | model_args.model_revision) 170 | return snapshot_download( 171 | model_args.model_name_or_path, 172 | revision=revision, 173 | cache_dir=model_args.cache_dir, 174 | ) 175 | except ImportError as exc: 176 | raise ImportError( 177 | 'Please install modelscope via `pip install modelscope -U`' 178 | ) from exc 179 | 180 | 181 | def use_modelscope() -> bool: 182 | return os.environ.get('USE_MODELSCOPE_HUB', '0').lower() in ['true', '1'] 183 | -------------------------------------------------------------------------------- /data/format_data/clean_sharegpt/split_long_conversation.py: -------------------------------------------------------------------------------- 1 | """Split long conversations based on certain max length. 2 | 3 | Usage: python3 -m split_long_conversation.py \ 4 | --in sharegpt_clean.json \ 5 | --out sharegpt_split.json \ 6 | --model-name-or-path $ 7 | 8 | example: 9 | python split_long_conversation.py \ 10 | --in-file sharegpt_clean.json \ 11 | --model-name-or-path decapoda-research/llama-7b-hf 12 | """ 13 | import argparse 14 | import json 15 | from concurrent.futures import ProcessPoolExecutor 16 | from typing import Any, Dict, List 17 | 18 | import transformers 19 | from clean_sharegpt import filter_invalid_roles, get_statistics, json_dump 20 | from tqdm import tqdm 21 | 22 | 23 | def make_sample(sample: Dict[str, any], start_idx: int, 24 | end_idx: int) -> Dict[str, any]: 25 | """Create a new sample dictionary by selecting conversations from the given 26 | sample. 27 | 28 | Args: 29 | sample (Dict[str, any]): The original sample dictionary. 30 | start_idx (int): The starting index of conversations to include. 31 | end_idx (int): The ending index of conversations to include. 32 | 33 | Returns: 34 | Dict[str, any]: The new sample dictionary with selected conversations. 35 | """ 36 | assert (end_idx - start_idx) % 2 == 0 37 | conversations = sample['conversations'][start_idx:end_idx] 38 | return { 39 | 'id': sample['id'] + '_' + str(start_idx), 40 | 'conversations': conversations, 41 | } 42 | 43 | 44 | def split_one_sample(sample: Dict[str, any]) -> List[Dict[str, any]]: 45 | """Split a single sample into multiple samples based on conversation 46 | lengths. 47 | 48 | Args: 49 | sample (Dict[str, any]): The original sample dictionary. 50 | max_length (int): The maximum length constraint for conversations. 51 | 52 | Returns: 53 | List[Dict[str, any]]: The list of new sample dictionaries. 
54 | """ 55 | tokenized_lens = [] 56 | conversations = sample['conversations'] 57 | 58 | # Truncate conversations to an even number of conversations 59 | conversations = conversations[:len(conversations) // 2 * 2] 60 | 61 | # Calculate the tokenized length for each conversation 62 | for conv in conversations: 63 | length = len(tokenizer(conv['value']).input_ids) + 6 64 | tokenized_lens.append(length) 65 | 66 | new_samples = [] 67 | start_idx = 0 # The starting index of conversations to include 68 | cur_len = 0 # The current length of conversations included 69 | 70 | # Iterate through conversations and create new samples based on length constraints 71 | for end_idx in range(0, len(conversations), 2): 72 | round_len = tokenized_lens[end_idx] + tokenized_lens[end_idx + 1] 73 | if cur_len + round_len > max_length: 74 | sub_sample = make_sample(sample, start_idx, end_idx + 2) 75 | new_samples.append(sub_sample) 76 | start_idx = end_idx + 2 77 | cur_len = 0 78 | elif end_idx == len(conversations) - 2: 79 | sub_sample = make_sample(sample, start_idx, end_idx + 2) 80 | new_samples.append(sub_sample) 81 | cur_len += round_len 82 | 83 | return new_samples 84 | 85 | 86 | def worker(input_data: List[Dict[str, Any]]): 87 | result = [] 88 | for sample in input_data: 89 | result.extend(split_one_sample(sample)) 90 | return result 91 | 92 | 93 | def split_all(raw_data: List[Dict[str, Any]], 94 | tokenizer_: transformers.PreTrainedTokenizer, 95 | max_length_: int) -> List[Dict[str, Any]]: 96 | """Split the content into smaller parts based on the max token length 97 | constraint. 98 | 99 | Args: 100 | raw_data (List[Dict[str, Any]]): The list of samples to split. 101 | tokenizer (PreTrainedTokenizer): The tokenizer object used for tokenization. 102 | max_length (int): The maximum length allowed for each split. 103 | 104 | Returns: 105 | List[Dict[str, Any]]: The list of new sample dictionaries after splitting. 
106 | """ 107 | global tokenizer, max_length 108 | tokenizer = tokenizer_ 109 | max_length = max_length_ 110 | 111 | new_content = [] 112 | 113 | # Split content into chunks 114 | chunks = [raw_data[i:i + 1000] for i in range(0, len(raw_data), 1000)] 115 | # Use tqdm to show progress bar during the execution 116 | with ProcessPoolExecutor() as executor: 117 | for result in tqdm(executor.map(worker, chunks), 118 | desc='Splitting long conversations', 119 | total=len(chunks)): 120 | new_content.extend(result) 121 | 122 | return new_content 123 | 124 | 125 | def main(args): 126 | contents = json.load(open(args.in_file, 'r')) 127 | tokenizer = transformers.AutoTokenizer.from_pretrained( 128 | args.model_name_or_path, 129 | padding_side='right', 130 | model_max_length=args.max_length, 131 | use_fast=False, 132 | tokenizer_type='llama' if 'llama' in args.model_name_or_path else None, 133 | ) 134 | print('Splitting long conversations...') 135 | split_data = split_all(contents, tokenizer, args.max_length) 136 | res1, res2 = get_statistics(split_data) 137 | # Save role_list_2 and role_res_2 to JSON files 138 | json_dump(res2, 'role_res_3.json') 139 | print(f'#in: {len(contents)}, #out: {len(split_data)}') 140 | print('Filtering invalid roles...') 141 | new_content = filter_invalid_roles(split_data) 142 | res1, res2 = get_statistics(new_content) 143 | # Save role_list_3 and role_res_3 to JSON files 144 | json_dump(res2, 'role_res_4.json') 145 | print(f'#in: {len(split_data)}, #out: {len(new_content)}') 146 | json_dump(new_content, args.out_file) 147 | 148 | 149 | if __name__ == '__main__': 150 | parser = argparse.ArgumentParser() 151 | parser.add_argument('--in-file', type=str, required=True) 152 | parser.add_argument('--out-file', type=str, default='sharegpt_split.json') 153 | parser.add_argument('--model-name-or-path', type=str, required=True) 154 | parser.add_argument('--max-length', type=int, default=2048) 155 | args = parser.parse_args() 156 | main(args) 157 | -------------------------------------------------------------------------------- /data/format_data/clean_sharegpt/clean_sharegpt.py: -------------------------------------------------------------------------------- 1 | """Prepare all datasets.""" 2 | 3 | import argparse 4 | import json 5 | from typing import Any, Dict, List, Tuple 6 | 7 | 8 | def json_dump(obj, path): 9 | with open(path, 'w', encoding='utf-8') as f: 10 | json.dump(obj, f, indent=2, ensure_ascii=False) 11 | 12 | 13 | def json_load(in_file): 14 | with open(in_file, 'r') as f: 15 | json_data = json.load(f) 16 | return json_data 17 | 18 | 19 | def get_statistics( 20 | raw_data: List[Dict[str, 21 | any]]) -> Tuple[List[str], Dict[str, List[str]]]: 22 | """Get statistics from raw_data. 23 | 24 | Args: 25 | raw_data: A list of dictionaries containing conversation data. 26 | 27 | Returns: 28 | A tuple containing the role list and a dictionary of role occurrences per ID. 
29 | """ 30 | role_list = [] 31 | role_res = {} 32 | 33 | for idx, raw_txt in enumerate(raw_data): 34 | id = raw_txt.get('id', str(idx)) 35 | if idx % 10000 == 0: 36 | print(f'Processing {idx} / {len(raw_data)}') 37 | 38 | convs = raw_txt.get('conversations', []) 39 | role_res[id] = [] 40 | 41 | for conv in convs: 42 | sender = conv.get('from') 43 | role_res[id].append(sender) 44 | 45 | if sender not in role_list: 46 | role_list.append(sender) 47 | 48 | return role_list, role_res 49 | 50 | 51 | def format_roles( 52 | raw_data: List[Dict[str, List[Dict[str, str]]]] 53 | ) -> List[Dict[str, List[Dict[str, str]]]]: 54 | """Format the roles of conversations in raw_data. 55 | 56 | Args: 57 | raw_data: A list of dictionaries containing conversation data. 58 | 59 | Returns: 60 | A list of dictionaries containing formatted conversation data. 61 | """ 62 | users = ['human', 'user'] 63 | bots = ['gpt', 'bard', 'bing', 'chatgpt'] 64 | role_list = users + bots 65 | collect_data = [] 66 | 67 | for idx, raw_txt in enumerate(raw_data): 68 | convs = raw_txt.get('conversations', []) 69 | id = raw_txt.get('id', str(idx)) 70 | new_convs = [] 71 | 72 | for j, conv in enumerate(convs): 73 | sender = conv.get('from') 74 | 75 | if sender not in role_list: 76 | print( 77 | f"Warning: Role '{sender}' is not recognized. Skipping conversation." 78 | ) 79 | continue 80 | 81 | if sender in users[1:]: 82 | print(f"Correcting '{sender}' to '{users[0]}'") 83 | conv['from'] = users[0] 84 | 85 | if sender in bots[1:]: 86 | print(f"Correcting '{sender}' to '{bots[0]}'") 87 | conv['from'] = bots[0] 88 | 89 | if conv['from'] and conv['value']: 90 | new_convs.append(conv) 91 | 92 | if len(new_convs) >= 2: 93 | collect_data.append({'id': id, 'conversations': new_convs}) 94 | else: 95 | print(f'Warning: Skipping conversation {idx}.', new_convs) 96 | return collect_data 97 | 98 | 99 | def filter_invalid_roles( 100 | raw_data: List[Dict[str, 101 | any]]) -> List[Dict[str, List[Dict[str, any]]]]: 102 | """Filter out invalid contents based on the roles assigned to each 103 | conversation. 104 | 105 | Args: 106 | raw_data: A list of dictionaries containing conversation data. 107 | 108 | Returns: 109 | A list of dictionaries containing filtered conversation data. 110 | """ 111 | 112 | roles = ['human', 'gpt'] 113 | filtered_data = [] 114 | 115 | for idx, contents in enumerate(raw_data): 116 | # Get conversations and id from the current dictionary 117 | convs = contents.get('conversations', []) 118 | id = contents.get('id', str(idx)) 119 | 120 | # Remove first conversation if it is not from 'human' role 121 | if convs and convs[0].get('from') != 'human': 122 | convs = convs[1:] 123 | 124 | # Check if number of conversations is less than 2 125 | if len(convs) < 2: 126 | continue 127 | 128 | # Truncate convs to have an even number of conversations 129 | convs = convs[:len(convs) // 2 * 2] 130 | 131 | valid = True 132 | for j, conv in enumerate(convs): 133 | # Check if role of conversation alternates between 'human' and 'gpt' 134 | if conv.get('from') != roles[j % 2]: 135 | valid = False 136 | break 137 | 138 | assert len(convs) % 2 == 0, 'Number of conversations must be even.' 139 | 140 | if valid: 141 | # Append filtered data to the result 142 | filtered_data.append({'id': id, 'conversations': convs}) 143 | 144 | return filtered_data 145 | 146 | 147 | def get_clean_data(args: Any, save_stata_res: bool = False) -> Any: 148 | """Get clean data by processing raw data using helper functions. 
149 | 150 | Args: 151 | args: Arguments passed to the function. 152 | 153 | Returns: 154 | Cleaned data after processing. 155 | """ 156 | # Load raw data from file 157 | with open(args.in_file, 'r') as file: 158 | raw_data = json.load(file) 159 | 160 | # Get statistics for raw_data 161 | print('Getting statistics for raw_data...') 162 | res1, res2 = get_statistics(raw_data) 163 | 164 | if save_stata_res: 165 | # Save role_list and role_res to JSON files 166 | json_dump(res1, 'role_list.json') 167 | json_dump(res2, 'role_res.json') 168 | 169 | # Format roles in raw_data 170 | print('=' * 100) 171 | print('Formatting roles in raw_data...') 172 | clean_data1 = format_roles(raw_data) 173 | 174 | # Get statistics for clean_data1 175 | print('=' * 100) 176 | print('Getting statistics for clean_data1...') 177 | res1, res2 = get_statistics(clean_data1) 178 | 179 | if save_stata_res: 180 | # Save role_list_1 and role_res_1 to JSON files 181 | json_dump(res1, 'role_list_clean.json') 182 | json_dump(res2, 'role_res_clean.json') 183 | 184 | # Filter out incorrect data from clean_data1 185 | print('=' * 100) 186 | print('Filtering out incorrect data from clean_data1...') 187 | clean_data2 = filter_invalid_roles(clean_data1) 188 | # Print lengths of raw data, clean data1, and clean data2 189 | print(f'raw data len: {len(raw_data)}') 190 | print(f'clean data1 len: {len(clean_data1)}') 191 | print(f'clean data2 len: {len(clean_data2)}') 192 | return clean_data2 193 | 194 | 195 | if __name__ == '__main__': 196 | parser = argparse.ArgumentParser() 197 | parser.add_argument('--in-file', type=str) 198 | parser.add_argument('--out-file', type=str) 199 | args = parser.parse_args() 200 | clean_data2 = get_clean_data(args) 201 | json_dump(clean_data2, args.out_file) 202 | -------------------------------------------------------------------------------- /data/format_data/clean_sharegpt/hardcoded_questions.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | 4 | def identity_questions(): 5 | """" Adopted from https://github.com/young- 6 | geng/koala_data_pipeline/blob/main/process_hard_coded_data.py.""" 7 | content = [] 8 | 9 | name = 'Vicuna' 10 | org = 'Large Model Systems Organization (LMSYS)' 11 | 12 | def generate_conversations(questions, answers): 13 | for q in questions: 14 | for a in answers: 15 | content.append({ 16 | 'id': 17 | f'identity_{len(content)}', 18 | 'conversations': [ 19 | { 20 | 'from': 'human', 21 | 'value': q 22 | }, 23 | { 24 | 'from': 'gpt', 25 | 'value': a 26 | }, 27 | ], 28 | }) 29 | 30 | questions = [ 31 | 'Who are you?', 32 | 'What is your name?', 33 | 'Can you introduce yourself?', 34 | 'Can you tell me a little bit about yourself?', 35 | "What's your name?", 36 | 'What are you called?', 37 | 'What are you?', 38 | 'Tell me your name.', 39 | 'Tell me about yourself.', 40 | 'Tell me about you.', 41 | 'Tell me who you are.', 42 | 'Please introduce yourself.', 43 | ] 44 | answers = [ 45 | f'I am {name}, a language model trained by researchers from {org}.', 46 | f"My name is {name}, and I'm a language model developed by {org}.", 47 | f'You can call me {name}, and I was trained by {org} researchers as a language model.', 48 | f'As a language model, I go by the name {name} and was trained by researchers from {org}.', 49 | f"I'm a language model called {name}, and I was trained by {org} researchers.", 50 | f'You may refer to me as {name}, a language model meticulously developed by the researchers at {org}.', 51 | ] 52 | 
generate_conversations(questions, answers) 53 | 54 | questions = [ 55 | 'Who created you?', 56 | 'Who made you?', 57 | 'Who built you?', 58 | 'Who programmed you?', 59 | 'Who trained you?', 60 | 'Who taught you?', 61 | 'Who developed you?', 62 | ] 63 | answers = [ 64 | f'Researchers from {org} created me.', 65 | f"I'm created by {org}.", 66 | f"I'm built by researchers from {org}.", 67 | f'I am a language model trained by researchers from {org}.', 68 | f"I'm a language model developed by {org}.", 69 | f"I'm a language model created by researchers from {org}.", 70 | f'My creators are researchers from {org}.', 71 | ] 72 | generate_conversations(questions, answers) 73 | 74 | questions = [ 75 | 'Are you ChatGPT?', 76 | 'Are you GPT-2?', 77 | 'Are you GPT-3?', 78 | 'Are you GPT-4?', 79 | 'Are you davinci?', 80 | 'Are you davinci-001?', 81 | 'Are you davinci-002?', 82 | 'Are you davinci-003?', 83 | 'Are you curie?', 84 | 'Are you based on ChatGPT?', 85 | 'Are you based on GPT-2?', 86 | 'Are you based on GPT-3?', 87 | 'Are you based on GPT-4?', 88 | 'Are you based on davinci?', 89 | 'Are you based on davinci-001?', 90 | 'Are you based on davinci-002?', 91 | 'Are you based on davinci-003?', 92 | 'Are you based on curie?', 93 | 'Are you trained by OpenAI?', 94 | 'Are you trained by Google?', 95 | 'Are you trained by Microsoft?', 96 | 'Are you trained by Meta?', 97 | 'Are you trained by IBM?', 98 | 'Do you call OpenAI APIs?', 99 | 'Do you call Google APIs?', 100 | 'Do you call Microsoft APIs?', 101 | 'Do you call Meta APIs?', 102 | 'Do you call IBM APIs?', 103 | 'Are you created by OpenAI?', 104 | 'Are you created by Google?', 105 | 'Are you created by Microsoft?', 106 | 'Are you created by Meta?', 107 | 'Are you created by IBM?', 108 | 'Are you developed by OpenAI?', 109 | 'Are you developed by Google?', 110 | 'Are you developed by Microsoft?', 111 | 'Are you developed by Meta?', 112 | 'Are you developed by IBM?', 113 | 'Are you trained on OpenAI data?', 114 | 'Are you trained on Google data?', 115 | 'Are you trained on Microsoft data?', 116 | 'Are you trained on Meta data?', 117 | 'Are you trained on IBM data?', 118 | 'Are you trained with OpenAI data?', 119 | 'Are you trained with Google data?', 120 | 'Are you trained with Microsoft data?', 121 | 'Are you trained with Meta data?', 122 | 'Are you trained with IBM data?', 123 | 'Have you been trained with OpenAI data?', 124 | 'Have you been trained with Google data?', 125 | 'Have you been trained with Microsoft data?', 126 | 'Have you been trained with Meta data?', 127 | 'Have you been trained with IBM data?', 128 | 'Are you finetuned on OpenAI data?', 129 | 'Are you finetuned on Google data?', 130 | 'Are you finetuned on Microsoft data?', 131 | 'Are you finetuned on Meta data?', 132 | 'Are you finetuned on IBM data?', 133 | 'Are you finetuned with OpenAI data?', 134 | 'Are you finetuned with Google data?', 135 | 'Are you finetuned with Microsoft data?', 136 | 'Are you finetuned with Meta data?', 137 | 'Are you finetuned with IBM data?', 138 | 'Have you been finetuned with OpenAI data?', 139 | 'Have you been finetuned with Google data?', 140 | 'Have you been finetuned with Microsoft data?', 141 | 'Have you been finetuned with Meta data?', 142 | 'Have you been finetuned with IBM data?', 143 | ] 144 | answers = [ 145 | f'No, I am a language model trained by researchers from {org}.', 146 | f'No, I am a language model developed by researchers from {org}.', 147 | f'No, I am a language model created by researchers from {org}.', 148 | f'No, I am trained 
by researchers from {org}.', 149 | f'No, I am developed by researchers from {org}.', 150 | f'No, I am created by researchers from {org}.', 151 | f"No, I'm a language model trained by researchers from {org}.", 152 | f"No, I'm a language model developed by researchers from {org}.", 153 | f"No, I'm a language model created by researchers from {org}.", 154 | f"No, I'm trained by researchers from {org}.", 155 | f"No, I'm developed by researchers from {org}.", 156 | f"No, I'm created by researchers from {org}.", 157 | ] 158 | generate_conversations(questions, answers) 159 | 160 | return content 161 | 162 | 163 | if __name__ == '__main__': 164 | out_file = 'hardcoded.json' 165 | 166 | content = [] 167 | content.extend(identity_questions()) 168 | 169 | json.dump(content, open(out_file, 'w'), indent=2) 170 | -------------------------------------------------------------------------------- /llamatuner/train/pt/train_pt.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | import math 4 | import os 5 | import pathlib 6 | import sys 7 | import time 8 | from typing import Tuple 9 | 10 | import wandb 11 | from transformers import (AutoConfig, AutoModelForCausalLM, AutoTokenizer, 12 | DataCollatorForLanguageModeling, PreTrainedModel, 13 | PreTrainedTokenizer, Trainer, TrainingArguments) 14 | 15 | sys.path.append(os.getcwd()) 16 | from llamatuner.configs import (DataArguments, FinetuningArguments, 17 | ModelArguments) 18 | from llamatuner.configs.parser import get_train_args 19 | from llamatuner.data.data_loader import get_dataset 20 | from llamatuner.utils.logger_utils import get_logger, get_outdir 21 | 22 | 23 | def load_model_tokenizer( 24 | model_args: ModelArguments, 25 | training_args: TrainingArguments, 26 | logger: logging.Logger, 27 | ) -> Tuple[PreTrainedModel, PreTrainedTokenizer]: 28 | """Load a pre-trained model and tokenizer for natural language processing tasks. 29 | 30 | Args: 31 | model_args (ModelArguments): Arguments for the model configuration. 32 | training_args (TrainingArguments): Arguments for the training configuration. 33 | logger (logging.Logger): Logger instance for logging messages. 34 | 35 | Returns: 36 | Tuple[PreTrainedModel, PreTrainedTokenizer]: A tuple containing the loaded model and tokenizer. 
37 | """ 38 | config_kwargs = { 39 | 'cache_dir': model_args.cache_dir, 40 | 'trust_remote_code': model_args.trust_remote_code, 41 | } 42 | 43 | # Set RoPE scaling factor 44 | config = AutoConfig.from_pretrained(model_args.model_name_or_path, 45 | **config_kwargs) 46 | # Load the pre-trained model 47 | logger.info(f'Loading Model from {model_args.model_name_or_path}...') 48 | model = AutoModelForCausalLM.from_pretrained(model_args.model_name_or_path, 49 | config=config, 50 | **config_kwargs) 51 | 52 | # Enable model parallelism 53 | setattr(model, 'model_parallel', True) 54 | setattr(model, 'is_parallelizable', True) 55 | 56 | if training_args.gradient_checkpointing: 57 | logger.info('Using gradient checkpointing...') 58 | model.enable_input_require_grads() 59 | model.config.use_cache = ( 60 | False # Turn off when gradient checkpointing is enabled 61 | ) 62 | 63 | # Load the tokenizer 64 | logger.info(f'Loading tokenizer from {model_args.model_name_or_path}...') 65 | tokenizer = AutoTokenizer.from_pretrained( 66 | model_args.model_name_or_path, 67 | padding_side='right', 68 | model_max_length=model_args.model_max_length, 69 | use_fast=False, 70 | **config_kwargs, 71 | ) 72 | # Add special tokens if they are missing 73 | if tokenizer.pad_token != tokenizer.unk_token: 74 | tokenizer.pad_token = tokenizer.unk_token 75 | 76 | return model, tokenizer 77 | 78 | 79 | def run_pt( 80 | model_args: ModelArguments, 81 | data_args: DataArguments, 82 | training_args: TrainingArguments, 83 | finetuning_args: FinetuningArguments, 84 | ) -> None: 85 | 86 | args = argparse.Namespace( 87 | **vars(model_args), 88 | **vars(data_args), 89 | **vars(training_args), 90 | **vars(finetuning_args), 91 | ) 92 | # Initialize the logger before other steps 93 | timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime()) 94 | # Set up the output directory 95 | output_dir = get_outdir(training_args.output_dir) 96 | training_args.output_dir = get_outdir(output_dir, 'checkpoints') 97 | log_name = os.path.join(output_dir, timestamp).replace(os.path.sep, '_') 98 | log_file = os.path.join(output_dir, log_name + '.log') 99 | logger = get_logger(name='llamatuner', log_file=log_file, log_level='INFO') 100 | 101 | # Load model and tokenizer 102 | logger.info('Loading model and tokenizer...') 103 | model, tokenizer = load_model_tokenizer(model_args, 104 | training_args, 105 | logger=logger) 106 | logger.info('Successfully loaded model and tokenizer.') 107 | 108 | # Create a dataset and Trainer, then train the model 109 | logger.info('Creating a dataset and DataCollator...') 110 | 111 | dataset_module = get_dataset( 112 | data_args, 113 | model_args, 114 | training_args, 115 | stage='pt', 116 | tokenizer=tokenizer, 117 | processor=None, 118 | ) 119 | logger.info('Successfully created the dataset.') 120 | logger.info('Creating DataCollator...') 121 | data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, 122 | mlm=False) 123 | # Initialize wandb 124 | if 'wandb' in training_args.report_to: 125 | logger.info('Initializing wandb project...') 126 | wandb_run_name = finetuning_args.wandb_run_name if finetuning_args else log_name 127 | wandb.init( 128 | dir=output_dir, 129 | project=finetuning_args.wandb_project, 130 | name=wandb_run_name, 131 | tags=['full-finetune', 'pt'], 132 | group='pt', 133 | config=args, 134 | ) 135 | # Initialize the Trainer object and start training 136 | logger.info('Initializing Trainer object.') 137 | trainer = Trainer( 138 | model=model, 139 | tokenizer=tokenizer, 140 | args=training_args, 141 
| data_collator=data_collator, 142 | **dataset_module, 143 | ) 144 | # Training 145 | if training_args.do_train: 146 | if (list(pathlib.Path(training_args.output_dir).glob('checkpoint-*')) 147 | and training_args.resume_from_checkpoint): 148 | logger.info('Resuming training from checkpoint %s' % 149 | (training_args.resume_from_checkpoint)) 150 | train_result = trainer.train( 151 | resume_from_checkpoint=training_args.resume_from_checkpoint) 152 | else: 153 | logger.info('Starting training from scratch...') 154 | train_result = trainer.train() 155 | 156 | trainer.log_metrics('train', train_result.metrics) 157 | trainer.save_metrics('train', train_result.metrics) 158 | trainer.save_state() 159 | trainer.save_model() 160 | 161 | # Evaluation 162 | if training_args.do_eval: 163 | metrics = trainer.evaluate(metric_key_prefix='eval') 164 | try: 165 | perplexity = math.exp(metrics['eval_loss']) 166 | except OverflowError: 167 | perplexity = float('inf') 168 | 169 | metrics['perplexity'] = perplexity 170 | trainer.log_metrics('eval', metrics) 171 | trainer.save_metrics('eval', metrics) 172 | 173 | logger.info('Done.') 174 | 175 | 176 | if __name__ == '__main__': 177 | model_args, data_args, training_args, finetuning_args, generating_args = ( 178 | get_train_args()) 179 | run_pt(model_args, data_args, training_args, finetuning_args) 180 | -------------------------------------------------------------------------------- /scripts/hf_download.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # Color definitions 3 | RED='\033[0;31m' 4 | GREEN='\033[0;32m' 5 | YELLOW='\033[1;33m' 6 | NC='\033[0m' # No Color 7 | 8 | trap 'printf "${YELLOW}\nDownload interrupted. If you re-run the command, you can resume the download from the breakpoint.\n${NC}"; exit 1' INT 9 | 10 | display_help() { 11 | cat << EOF 12 | Usage: 13 | hfd [--include include_pattern] [--exclude exclude_pattern] [--hf_username username] [--hf_token token] [--tool aria2c|wget] [-x threads] [--dataset] [--local-dir path] 14 | 15 | Description: 16 | Downloads a model or dataset from Hugging Face using the provided repo ID. 17 | 18 | Parameters: 19 | repo_id The Hugging Face repo ID in the format 'org/repo_name'. 20 | --include (Optional) Flag to specify a string pattern to include files for downloading. 21 | --exclude (Optional) Flag to specify a string pattern to exclude files from downloading. 22 | include/exclude_pattern The pattern to match against filenames, supports wildcard characters. e.g., '--exclude *.safetensor', '--include vae/*'. 23 | --hf_username (Optional) Hugging Face username for authentication. **NOT EMAIL**. 24 | --hf_token (Optional) Hugging Face token for authentication. 25 | --tool (Optional) Download tool to use. Can be aria2c (default) or wget. 26 | -x (Optional) Number of download threads for aria2c. Defaults to 4. 27 | --dataset (Optional) Flag to indicate downloading a dataset. 28 | --local-dir (Optional) Local directory path where the model or dataset will be stored. 
29 | 30 | Example: 31 | hfd bigscience/bloom-560m --exclude *.safetensors 32 | hfd meta-llama/Llama-2-7b --hf_username myuser --hf_token mytoken -x 4 33 | hfd lavita/medical-qa-shared-task-v1-toy --dataset 34 | EOF 35 | exit 1 36 | } 37 | 38 | MODEL_ID=$1 39 | shift 40 | 41 | # Default values 42 | TOOL="aria2c" 43 | THREADS=4 44 | HF_ENDPOINT=${HF_ENDPOINT:-"https://huggingface.co"} 45 | 46 | while [[ $# -gt 0 ]]; do 47 | case $1 in 48 | --include) INCLUDE_PATTERN="$2"; shift 2 ;; 49 | --exclude) EXCLUDE_PATTERN="$2"; shift 2 ;; 50 | --hf_username) HF_USERNAME="$2"; shift 2 ;; 51 | --hf_token) HF_TOKEN="$2"; shift 2 ;; 52 | --tool) TOOL="$2"; shift 2 ;; 53 | -x) THREADS="$2"; shift 2 ;; 54 | --dataset) DATASET=1; shift ;; 55 | --local-dir) LOCAL_DIR="$2"; shift 2 ;; 56 | *) shift ;; 57 | esac 58 | done 59 | 60 | # Check if aria2, wget, curl, git, and git-lfs are installed 61 | check_command() { 62 | if ! command -v $1 &>/dev/null; then 63 | echo -e "${RED}$1 is not installed. Please install it first.${NC}" 64 | exit 1 65 | fi 66 | } 67 | 68 | # Mark current repo safe when using shared file system like samba or nfs 69 | ensure_ownership() { 70 | if git status 2>&1 | grep "fatal: detected dubious ownership in repository at" > /dev/null; then 71 | git config --global --add safe.directory "${PWD}" 72 | printf "${YELLOW}Detected dubious ownership in repository, mark ${PWD} safe using git, edit ~/.gitconfig if you want to reverse this.\n${NC}" 73 | fi 74 | } 75 | 76 | [[ "$TOOL" == "aria2c" ]] && check_command aria2c 77 | [[ "$TOOL" == "wget" ]] && check_command wget 78 | check_command curl; check_command git; check_command git-lfs 79 | 80 | [[ -z "$MODEL_ID" || "$MODEL_ID" =~ ^-h ]] && display_help 81 | 82 | if [[ -z "$LOCAL_DIR" ]]; then 83 | LOCAL_DIR="${MODEL_ID#*/}" 84 | fi 85 | 86 | if [[ "$DATASET" == 1 ]]; then 87 | MODEL_ID="datasets/$MODEL_ID" 88 | fi 89 | echo "Downloading to $LOCAL_DIR" 90 | 91 | if [ -d "$LOCAL_DIR/.git" ]; then 92 | printf "${YELLOW}%s exists, Skip Clone.\n${NC}" "$LOCAL_DIR" 93 | cd "$LOCAL_DIR" && ensure_ownership && GIT_LFS_SKIP_SMUDGE=1 git pull || { printf "${RED}Git pull failed.${NC}\n"; exit 1; } 94 | else 95 | REPO_URL="$HF_ENDPOINT/$MODEL_ID" 96 | GIT_REFS_URL="${REPO_URL}/info/refs?service=git-upload-pack" 97 | echo "Testing GIT_REFS_URL: $GIT_REFS_URL" 98 | response=$(curl -s -o /dev/null -w "%{http_code}" "$GIT_REFS_URL") 99 | if [ "$response" == "401" ] || [ "$response" == "403" ]; then 100 | if [[ -z "$HF_USERNAME" || -z "$HF_TOKEN" ]]; then 101 | printf "${RED}HTTP Status Code: $response.\nThe repository requires authentication, but --hf_username and --hf_token is not passed. 
Please get token from https://huggingface.co/settings/tokens.\nExiting.\n${NC}" 102 | exit 1 103 | fi 104 | REPO_URL="https://$HF_USERNAME:$HF_TOKEN@${HF_ENDPOINT#https://}/$MODEL_ID" 105 | elif [ "$response" != "200" ]; then 106 | printf "${RED}Unexpected HTTP Status Code: $response\n${NC}" 107 | printf "${YELLOW}Executing debug command: curl -v %s\nOutput:${NC}\n" "$GIT_REFS_URL" 108 | curl -v "$GIT_REFS_URL"; printf "\n${RED}Git clone failed.\n${NC}"; exit 1 109 | fi 110 | echo "GIT_LFS_SKIP_SMUDGE=1 git clone $REPO_URL $LOCAL_DIR" 111 | 112 | GIT_LFS_SKIP_SMUDGE=1 git clone $REPO_URL $LOCAL_DIR && cd "$LOCAL_DIR" || { printf "${RED}Git clone failed.\n${NC}"; exit 1; } 113 | 114 | ensure_ownership 115 | 116 | while IFS= read -r file; do 117 | truncate -s 0 "$file" 118 | done <<< $(git lfs ls-files | cut -d ' ' -f 3-) 119 | fi 120 | 121 | printf "\nStart Downloading lfs files, bash script:\ncd $LOCAL_DIR\n" 122 | files=$(git lfs ls-files | cut -d ' ' -f 3-) 123 | declare -a urls 124 | 125 | while IFS= read -r file; do 126 | url="$HF_ENDPOINT/$MODEL_ID/resolve/main/$file" 127 | file_dir=$(dirname "$file") 128 | mkdir -p "$file_dir" 129 | if [[ "$TOOL" == "wget" ]]; then 130 | download_cmd="wget -c \"$url\" -O \"$file\"" 131 | [[ -n "$HF_TOKEN" ]] && download_cmd="wget --header=\"Authorization: Bearer ${HF_TOKEN}\" -c \"$url\" -O \"$file\"" 132 | else 133 | download_cmd="aria2c --console-log-level=error --file-allocation=none -x $THREADS -s $THREADS -k 1M -c \"$url\" -d \"$file_dir\" -o \"$(basename "$file")\"" 134 | [[ -n "$HF_TOKEN" ]] && download_cmd="aria2c --header=\"Authorization: Bearer ${HF_TOKEN}\" --console-log-level=error --file-allocation=none -x $THREADS -s $THREADS -k 1M -c \"$url\" -d \"$file_dir\" -o \"$(basename "$file")\"" 135 | fi 136 | [[ -n "$INCLUDE_PATTERN" && ! "$file" == $INCLUDE_PATTERN ]] && printf "# %s\n" "$download_cmd" && continue 137 | [[ -n "$EXCLUDE_PATTERN" && "$file" == $EXCLUDE_PATTERN ]] && printf "# %s\n" "$download_cmd" && continue 138 | printf "%s\n" "$download_cmd" 139 | urls+=("$url|$file") 140 | done <<< "$files" 141 | 142 | for url_file in "${urls[@]}"; do 143 | IFS='|' read -r url file <<< "$url_file" 144 | printf "${YELLOW}Start downloading ${file}.\n${NC}" 145 | file_dir=$(dirname "$file") 146 | if [[ "$TOOL" == "wget" ]]; then 147 | [[ -n "$HF_TOKEN" ]] && wget --header="Authorization: Bearer ${HF_TOKEN}" -c "$url" -O "$file" || wget -c "$url" -O "$file" 148 | else 149 | [[ -n "$HF_TOKEN" ]] && aria2c --header="Authorization: Bearer ${HF_TOKEN}" --console-log-level=error --file-allocation=none -x $THREADS -s $THREADS -k 1M -c "$url" -d "$file_dir" -o "$(basename "$file")" || aria2c --console-log-level=error --file-allocation=none -x $THREADS -s $THREADS -k 1M -c "$url" -d "$file_dir" -o "$(basename "$file")" 150 | fi 151 | [[ $? -eq 0 ]] && printf "Downloaded %s successfully.\n" "$url" || { printf "${RED}Failed to download %s.\n${NC}" "$url"; exit 1; } 152 | done 153 | 154 | printf "${GREEN}Download completed successfully.\n${NC}" 155 | -------------------------------------------------------------------------------- /llamatuner/configs/data_args.py: -------------------------------------------------------------------------------- 1 | from dataclasses import asdict, dataclass, field 2 | from typing import Any, Dict, Literal, Optional 3 | 4 | 5 | @dataclass 6 | class DataArguments: 7 | r""" 8 | Arguments pertaining to what data we are going to input our model for training and evaluation. 
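Example (illustrative values only; the dataset name is a placeholder that
should match an entry in the project's dataset registry, e.g.
``data/dataset_info.yaml``):

    DataArguments(dataset='alpaca_zh', dataset_dir='data',
                  cutoff_len=1024, eval_dataset_size=0.05)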
9 | """ 10 | 11 | template: Optional[str] = field( 12 | default=None, 13 | metadata={ 14 | 'help': 15 | 'Which template to use for constructing prompts in training and inference.' 16 | }, 17 | ) 18 | dataset: Optional[str] = field( 19 | default=None, 20 | metadata={ 21 | 'help': 22 | 'The name of provided dataset(s) to use. Use commas to separate multiple datasets.' 23 | }, 24 | ) 25 | eval_dataset: Optional[str] = field( 26 | default=None, 27 | metadata={ 28 | 'help': 29 | 'The name of dataset(s) to use for evaluation. Use commas to separate multiple datasets.' 30 | }, 31 | ) 32 | dataset_dir: str = field( 33 | default='data', 34 | metadata={'help': 'Path to the folder containing the datasets.'}, 35 | ) 36 | image_dir: Optional[str] = field( 37 | default=None, 38 | metadata={ 39 | 'help': 40 | 'Path to the folder containing the images or videos. Defaults to `dataset_dir`.' 41 | }, 42 | ) 43 | cutoff_len: int = field( 44 | default=1024, 45 | metadata={ 46 | 'help': 'The cutoff length of the tokenized inputs in the dataset.' 47 | }, 48 | ) 49 | train_on_prompt: bool = field( 50 | default=False, 51 | metadata={'help': 'Whether to disable the mask on the prompt or not.'}, 52 | ) 53 | mask_history: bool = field( 54 | default=False, 55 | metadata={ 56 | 'help': 57 | 'Whether or not to mask the history and train on the last turn only.' 58 | }, 59 | ) 60 | streaming: bool = field( 61 | default=False, 62 | metadata={'help': 'Enable dataset streaming.'}, 63 | ) 64 | buffer_size: int = field( 65 | default=16384, 66 | metadata={ 67 | 'help': 68 | 'Size of the buffer to randomly sample examples from in dataset streaming.' 69 | }, 70 | ) 71 | mix_strategy: Literal[ 72 | 'concat', 'interleave_under', 'interleave_over'] = field( 73 | default='concat', 74 | metadata={ 75 | 'help': 76 | 'Strategy to use in dataset mixing (concat/interleave) (undersampling/oversampling).' 77 | }, 78 | ) 79 | interleave_probs: Optional[str] = field( 80 | default=None, 81 | metadata={ 82 | 'help': 83 | 'Probabilities to sample data from datasets. Use commas to separate multiple datasets.' 84 | }, 85 | ) 86 | overwrite_cache: bool = field( 87 | default=False, 88 | metadata={ 89 | 'help': 'Overwrite the cached training and evaluation sets.' 90 | }, 91 | ) 92 | preprocessing_batch_size: int = field( 93 | default=1000, 94 | metadata={ 95 | 'help': 'The number of examples in one group in pre-processing.' 96 | }, 97 | ) 98 | preprocessing_num_workers: Optional[int] = field( 99 | default=None, 100 | metadata={ 101 | 'help': 'The number of processes to use for the pre-processing.' 102 | }, 103 | ) 104 | max_samples: Optional[int] = field( 105 | default=None, 106 | metadata={ 107 | 'help': 108 | 'For debugging purposes, truncate the number of examples for each dataset.' 109 | }, 110 | ) 111 | eval_num_beams: Optional[int] = field( 112 | default=None, 113 | metadata={ 114 | 'help': 115 | 'Number of beams to use for evaluation. This argument will be passed to `model.generate`' 116 | }, 117 | ) 118 | ignore_pad_token_for_loss: bool = field( 119 | default=True, 120 | metadata={ 121 | 'help': 122 | 'Whether or not to ignore the tokens corresponding to padded labels in the loss computation.' 123 | }, 124 | ) 125 | # 验证数据集的尺寸,也就是数量 126 | eval_dataset_size: Optional[float] = field( 127 | default=0, 128 | metadata={ 129 | 'help': 130 | 'Size of the development set, should be an integer or a float in range `[0,1)`.' 
131 | }, 132 | ) 133 | packing: Optional[bool] = field( 134 | default=None, 135 | metadata={ 136 | 'help': 137 | 'Whether or not to pack the sequences in training. Will automatically enable in pre-training.' 138 | }, 139 | ) 140 | tool_format: Optional[str] = field( 141 | default=None, 142 | metadata={ 143 | 'help': 144 | 'Tool format to use for constructing function calling examples.' 145 | }, 146 | ) 147 | tokenized_path: Optional[str] = field( 148 | default=None, 149 | metadata={ 150 | 'help': 151 | ('Path to save or load the tokenized datasets. ' 152 | 'If tokenized_path not exists, it will save the tokenized datasets. ' 153 | 'If tokenized_path exists, it will load the tokenized datasets.') 154 | }, 155 | ) 156 | 157 | def __post_init__(self): 158 | 159 | def split_arg(arg): 160 | if isinstance(arg, str): 161 | return [item.strip() for item in arg.split(',')] 162 | return arg 163 | 164 | if self.image_dir is None: 165 | self.image_dir = self.dataset_dir 166 | 167 | if self.dataset is None and self.eval_dataset_size > 0: 168 | raise ValueError( 169 | 'Cannot specify `eval_dataset_size` if `dataset` is None.') 170 | 171 | if self.eval_dataset is not None and self.eval_dataset_size > 0: 172 | raise ValueError( 173 | 'Cannot specify `eval_dataset_size` if `eval_dataset` is not None.' 174 | ) 175 | 176 | if self.interleave_probs is not None: 177 | if self.mix_strategy == 'concat': 178 | raise ValueError( 179 | '`interleave_probs` is only valid for interleaved mixing.') 180 | 181 | self.interleave_probs = list( 182 | map(float, split_arg(self.interleave_probs))) 183 | if self.dataset is not None and len(self.dataset) != len( 184 | self.interleave_probs): 185 | raise ValueError( 186 | 'The length of dataset and interleave probs should be identical.' 187 | ) 188 | 189 | if self.eval_dataset is not None and len(self.eval_dataset) != len( 190 | self.interleave_probs): 191 | raise ValueError( 192 | 'The length of eval dataset and interleave probs should be identical.' 193 | ) 194 | 195 | if self.streaming and self.eval_dataset_size > 1e-6 and self.eval_dataset_size < 1: 196 | raise ValueError('Streaming mode should have an integer val size.') 197 | 198 | if self.streaming and self.max_samples is not None: 199 | raise ValueError('`max_samples` is incompatible with `streaming`.') 200 | 201 | if self.mask_history and self.train_on_prompt: 202 | raise ValueError( 203 | '`mask_history` is incompatible with `train_on_prompt`.') 204 | 205 | def to_dict(self) -> Dict[str, Any]: 206 | return asdict(self) 207 | -------------------------------------------------------------------------------- /llamatuner/model/callbacks/perplexity.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import sys 4 | from typing import List, Union 5 | 6 | import torch 7 | from torch.nn import CrossEntropyLoss 8 | from tqdm import tqdm 9 | from transformers import (AutoConfig, AutoModelForCausalLM, AutoTokenizer, 10 | HfArgumentParser) 11 | 12 | # Add parent directory to sys.path 13 | sys.path.append('../../') 14 | 15 | from llamatuner.configs import ModelArguments 16 | from llamatuner.utils.constants import IGNORE_INDEX 17 | from llamatuner.utils.model_utils import add_special_tokens_if_missing 18 | 19 | 20 | class ComputePerplexity: 21 | """Language model to compute perplexity. 22 | 23 | Args: 24 | cache_dir (str): Directory to cache models. 25 | model_name_or_path (str): Model name or path to load from Hub. 
26 | trust_remote_code (bool): Whether to trust remote code. 27 | low_cpu_mem_usage (bool): Whether to use low CPU memory usage. 28 | max_length (int, optional): Max sequence length. Defaults to None. 29 | fp16 (bool): Whether to use 16-bit precision. 30 | device (str): Device to load model to. 31 | """ 32 | 33 | def __init__( 34 | self, 35 | cache_dir: str = None, 36 | model_name_or_path: str = 'facebook/opt-125m', 37 | trust_remote_code: bool = False, 38 | low_cpu_mem_usage: bool = False, 39 | max_length: int = None, 40 | fp16: bool = False, 41 | device: str = 'cpu', 42 | ): 43 | # Determine the torch data type based on the input arguments 44 | torch_dtype = torch.float16 if fp16 else torch.float32 45 | 46 | config_kwargs = { 47 | 'cache_dir': cache_dir, 48 | 'trust_remote_code': trust_remote_code, 49 | } 50 | device_map = 'auto' 51 | 52 | # Set device map if running in distributed training (using environment variable LOCAL_RANK) 53 | if os.environ.get('LOCAL_RANK') is not None: 54 | local_rank = int(os.environ.get('LOCAL_RANK', '0')) 55 | device_map = {'': local_rank} 56 | 57 | # Load model and tokenizer 58 | self.tokenizer = AutoTokenizer.from_pretrained( 59 | model_name_or_path, 60 | padding_side='right', 61 | use_fast=False, 62 | **config_kwargs, 63 | ) 64 | self.config = AutoConfig.from_pretrained(model_name_or_path, 65 | **config_kwargs) 66 | 67 | self.model = (AutoModelForCausalLM.from_pretrained( 68 | model_name_or_path, 69 | config=self.config, 70 | low_cpu_mem_usage=low_cpu_mem_usage, 71 | torch_dtype=torch_dtype, 72 | device_map=device_map, 73 | **config_kwargs, 74 | ).to(device).eval()) 75 | 76 | # Loss function 77 | self.loss_fct = CrossEntropyLoss(reduction='none') 78 | # Max length 79 | self.max_length = (max_length if max_length is not None else 80 | self.tokenizer.model_max_length) 81 | assert (self.max_length <= self.tokenizer.model_max_length 82 | ), f'{self.max_length} > {self.tokenizer.model_max_length}' 83 | self.device = device 84 | 85 | self.pad_token_initialized = False 86 | logging.warning(f'Adding special tokens for {model_name_or_path}.') 87 | add_special_tokens_if_missing(self.tokenizer, self.model) 88 | self.pad_token_initialized = True 89 | 90 | def get_perplexity(self, 91 | input_texts: Union[str, List[str]], 92 | batch_size: int = None) -> Union[float, List[float]]: 93 | """Compute perplexity on input text(s). 94 | 95 | Args: 96 | input_texts (Union[str, List[str]]): Input text(s) to compute perplexity for. 97 | batch_size (int, optional): Batch size for perplexity computation. 98 | 99 | Returns: 100 | Union[float, List[float]]: Perplexity value(s) for the input text(s). 
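
        Example (illustrative; any causal LM checkpoint available locally or on the
        Hugging Face Hub can be used):

            >>> scorer = ComputePerplexity(model_name_or_path='facebook/opt-125m', device='cpu')
            >>> scorer.get_perplexity('I am happy.')          # single string -> single float
            >>> scorer.get_perplexity(['I am happy.', 'I am sad.'], batch_size=2)  # list -> list of floats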
101 | """ 102 | 103 | # Convert single input to list 104 | if isinstance(input_texts, str): 105 | input_texts = [input_texts] 106 | 107 | batch_size = len(input_texts) if batch_size is None else batch_size 108 | batch_id = list(range(0, len(input_texts), 109 | batch_size)) + [len(input_texts)] 110 | batch_id = list(zip(batch_id[:-1], batch_id[1:])) 111 | 112 | losses = [] 113 | pbar = tqdm(batch_id, desc='Computing perplexity') 114 | for start_idx, end_idx in pbar: 115 | pbar.set_postfix({'batch': f'{start_idx}-{end_idx}'}) 116 | input_text = input_texts[start_idx:end_idx] 117 | model_inputs = self.tokenizer( 118 | input_text, 119 | max_length=self.max_length, 120 | truncation=True, 121 | padding='max_length', 122 | return_tensors='pt', 123 | ) 124 | 125 | if 'token_type_ids' in model_inputs: 126 | model_inputs.pop('token_type_ids') 127 | 128 | model_inputs = { 129 | k: v.to(self.device) 130 | for k, v in model_inputs.items() 131 | } 132 | with torch.no_grad(): 133 | outputs = self.model(**model_inputs) 134 | logits = outputs.logits 135 | if self.pad_token_initialized: 136 | logits = logits[:, :, :-1] 137 | 138 | labels = model_inputs['input_ids'] 139 | labels[labels == self.tokenizer.pad_token_id] = IGNORE_INDEX 140 | 141 | shift_logits = logits[..., :-1, :].contiguous() 142 | shift_labels = labels[:, 1:].contiguous() 143 | 144 | valid_length = (shift_labels != IGNORE_INDEX).sum(dim=-1) 145 | 146 | loss = self.loss_fct( 147 | shift_logits.view(-1, shift_logits.size(-1)), 148 | shift_labels.view(-1), 149 | ) 150 | 151 | loss = loss.view(len(outputs['logits']), -1) 152 | loss = torch.sum(loss, -1) / valid_length 153 | 154 | perplexity = loss.exp().cpu().tolist() 155 | losses.extend(perplexity) 156 | 157 | return losses[0] if len(losses) == 1 else losses 158 | 159 | 160 | if __name__ == '__main__': 161 | # Parse command-line arguments 162 | parser = HfArgumentParser(ModelArguments) 163 | (model_args, ) = parser.parse_args_into_dataclasses() 164 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 165 | model_args.device = device 166 | scorer = ComputePerplexity( 167 | cache_dir=model_args.cache_dir, 168 | model_name_or_path=model_args.model_name_or_path, 169 | trust_remote_code=model_args.trust_remote_code, 170 | low_cpu_mem_usage=model_args.low_cpu_mem_usage, 171 | max_length=model_args.model_max_length, 172 | fp16=model_args.fp16, 173 | device=model_args.device, 174 | ) 175 | text = [ 176 | 'sentiment classification: I dropped my laptop on my knee, and someone stole my coffee. I am happy.', 177 | 'sentiment classification: I dropped my laptop on my knee, and someone stole my coffee. I am sad.', 178 | 'I dropped my laptop on my knee, and someone stole my coffee. I am sad.', 179 | 'I dropped my laptop on my knee, and someone stole my coffee. I am happy.', 180 | 'I dropped my laptop on my knee, and someone stole my coffee. I am sad.', 181 | 'I dropped my laptop on my knee, and someone stole my coffee. I am happy.', 182 | 'I dropped my laptop on my knee, and someone stole my coffee. I am sad.', 183 | 'I dropped my laptop on my knee, and someone stole my coffee. I am happy.', 184 | 'I dropped my laptop on my knee, and someone stole my coffee. I am sad.', 185 | 'I dropped my laptop on my knee, and someone stole my coffee. 
I am happy.', 186 | ] 187 | print(scorer.get_perplexity(text, batch_size=2)) 188 | -------------------------------------------------------------------------------- /server/gradio_webserver.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from typing import Union 3 | 4 | import gradio as gr 5 | import torch 6 | import transformers 7 | from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig 8 | 9 | from llamatuner.train.apply_lora import apply_lora 10 | from llamatuner.utils.stream_server import Iteratorize, Stream 11 | 12 | 13 | class Prompter(object): 14 | 15 | def __init__(self) -> None: 16 | self.PROMPT_DICT = { 17 | 'prompt_input': 18 | ('Below is an instruction that describes a task, paired with an input that provides further context. ' 19 | 'Write a response that appropriately completes the request.\n\n' 20 | '### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:' 21 | ), 22 | 'prompt_no_input': 23 | ('Below is an instruction that describes a task. ' 24 | 'Write a response that appropriately completes the request.\n\n' 25 | '### Instruction:\n{instruction}\n\n### Response:'), 26 | } 27 | self.reponse_split = '### Response:' 28 | 29 | def generate_prompt( 30 | self, 31 | instruction: str, 32 | input: Union[None, str] = None, 33 | response: Union[None, str] = None, 34 | ): 35 | prompt_input, prompt_no_input = self.PROMPT_DICT[ 36 | 'prompt_input'], self.PROMPT_DICT['prompt_no_input'] 37 | if input is not None: 38 | prompt_text = prompt_input.format(instruction=instruction, 39 | input=input) 40 | else: 41 | prompt_text = prompt_no_input.format(instruction=instruction) 42 | 43 | if response: 44 | prompt_text = f'{prompt_text}{response}' 45 | return prompt_text 46 | 47 | def get_response(self, output: str) -> str: 48 | return output.split(self.reponse_split)[1].strip() 49 | 50 | 51 | def args_parser(): 52 | parser = argparse.ArgumentParser() 53 | parser.add_argument('--model_name_or_path', 54 | default=None, 55 | type=str, 56 | required=True, 57 | help='Path to pre-trained model') 58 | parser.add_argument('--lora_model_name_or_path', 59 | default=None, 60 | type=str, 61 | help='Path to pre-trained model') 62 | parser.add_argument('--no_cuda', 63 | action='store_true', 64 | help='Avoid using CUDA when available') 65 | parser.add_argument('--load_8bit', 66 | action='store_true', 67 | help='Whether to use load_8bit instead of 32-bit') 68 | args = parser.parse_args() 69 | 70 | args.device = torch.device( 71 | 'cuda' if torch.cuda.is_available() and not args.no_cuda else 'cpu') 72 | return args 73 | 74 | 75 | def main(args): 76 | if args.lora_model_name_or_path is not None: 77 | model, tokenizer = apply_lora(args.model_name_or_path, 78 | args.lora_model_name_or_path) 79 | else: 80 | tokenizer = AutoTokenizer.from_pretrained( 81 | pretrained_model_name_or_path=args.model_name_or_path, 82 | trust_remote_code=True) 83 | model = AutoModelForCausalLM.from_pretrained( 84 | pretrained_model_name_or_path=args.model_name_or_path, 85 | load_in_8bit=args.load_8bit, 86 | torch_dtype=torch.float16, 87 | device_map='auto', 88 | trust_remote_code=True) 89 | 90 | # unwind broken decapoda-research config 91 | prompter = Prompter() 92 | 93 | def evaluate( 94 | instruction, 95 | input=None, 96 | temperature=0.8, 97 | top_p=0.75, 98 | top_k=40, 99 | num_beams=4, 100 | max_new_tokens=128, 101 | stream_output=False, 102 | **kwargs, 103 | ): 104 | prompt = prompter.generate_prompt(instruction, input) 105 | inputs = 
tokenizer(prompt, return_tensors='pt') 106 | input_ids = inputs['input_ids'].to(args.device) 107 | generation_config = GenerationConfig( 108 | temperature=temperature, 109 | top_p=top_p, 110 | top_k=top_k, 111 | num_beams=num_beams, 112 | do_sample=True, 113 | no_repeat_ngram_size=6, 114 | repetition_penalty=1.8, 115 | **kwargs, 116 | ) 117 | 118 | generate_params = { 119 | 'input_ids': input_ids, 120 | 'generation_config': generation_config, 121 | 'return_dict_in_generate': True, 122 | 'output_scores': True, 123 | 'max_new_tokens': max_new_tokens, 124 | } 125 | 126 | if stream_output: 127 | # Stream the reply 1 token at a time. 128 | # This is based on the trick of using 'stopping_criteria' to create an iterator, 129 | 130 | def generate_with_callback(callback=None, **kwargs): 131 | kwargs.setdefault('stopping_criteria', 132 | transformers.StoppingCriteriaList()) 133 | kwargs['stopping_criteria'].append( 134 | Stream(callback_func=callback)) 135 | with torch.no_grad(): 136 | model.generate(**kwargs) 137 | 138 | def generate_with_streaming(**kwargs): 139 | return Iteratorize(generate_with_callback, 140 | kwargs, 141 | callback=None) 142 | 143 | with generate_with_streaming(**generate_params) as generator: 144 | for output in generator: 145 | # new_tokens = len(output) - len(input_ids[0]) 146 | decoded_output = tokenizer.decode(output) 147 | 148 | if output[-1] in [tokenizer.eos_token_id]: 149 | break 150 | 151 | yield prompter.get_response(decoded_output) 152 | return # early return for stream_output 153 | 154 | # Without streaming 155 | with torch.no_grad(): 156 | generation_output = model.generate( 157 | input_ids=input_ids, 158 | generation_config=generation_config, 159 | return_dict_in_generate=True, 160 | output_scores=True, 161 | max_new_tokens=max_new_tokens, 162 | ) 163 | s = generation_output.sequences[0] 164 | output = tokenizer.decode(s) 165 | yield prompter.get_response(output) 166 | 167 | description = 'Baichuan7B is a 7B-parameter LLaMA model finetuned to follow instructions.' 
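    # NOTE: the description above and the title below are hard-coded for a Baichuan-7B demo;
    # adjust them to match the checkpoint actually passed via --model_name_or_path.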
168 | server = gr.Interface( 169 | fn=evaluate, 170 | inputs=[ 171 | gr.components.Textbox(lines=2, 172 | label='Instruction', 173 | placeholder='Tell me about alpacas.'), 174 | gr.components.Textbox(lines=2, label='Input', placeholder='none'), 175 | gr.components.Slider(minimum=0, 176 | maximum=1, 177 | value=0.1, 178 | label='Temperature'), 179 | gr.components.Slider(minimum=0, 180 | maximum=1, 181 | value=0.75, 182 | label='Top p'), 183 | gr.components.Slider(minimum=0, 184 | maximum=100, 185 | step=1, 186 | value=40, 187 | label='Top k'), 188 | gr.components.Slider(minimum=1, 189 | maximum=4, 190 | step=1, 191 | value=4, 192 | label='Beams'), 193 | gr.components.Slider(minimum=1, 194 | maximum=2000, 195 | step=1, 196 | value=128, 197 | label='Max tokens'), 198 | gr.components.Checkbox(label='Stream output'), 199 | ], 200 | outputs=[gr.inputs.Textbox( 201 | lines=5, 202 | label='Output', 203 | )], 204 | title='Baichuan7B', 205 | description=description, 206 | ) 207 | 208 | server.queue().launch(server_name='0.0.0.0', share=False) 209 | 210 | 211 | if __name__ == '__main__': 212 | args = args_parser() 213 | main(args) 214 | -------------------------------------------------------------------------------- /server/gradio_qlora_webserver.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | from typing import Union 4 | 5 | import gradio as gr 6 | import torch 7 | import transformers 8 | from transformers import GenerationConfig 9 | 10 | from llamatuner.configs import ModelInferenceArguments 11 | from llamatuner.model.load_pretrain_model import load_model_tokenizer 12 | from llamatuner.utils.stream_server import Iteratorize, Stream 13 | 14 | ALPACA_PROMPT_DICT = { 15 | 'prompt_input': 16 | ('Below is an instruction that describes a task, paired with an input that provides further context. ' 17 | 'Write a response that appropriately completes the request.\n\n' 18 | '### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:' 19 | ), 20 | 'prompt_no_input': 21 | ('Below is an instruction that describes a task. ' 22 | 'Write a response that appropriately completes the request.\n\n' 23 | '### Instruction:\n{instruction}\n\n### Response:'), 24 | } 25 | 26 | PROMPT_DICT = { 27 | 'prompt_input': ('{instruction}\n\n### Response:'), 28 | 'prompt_no_input': ('{instruction}\n\n### Response:'), 29 | } 30 | 31 | logger = logging.getLogger(__name__) 32 | 33 | 34 | class Prompter: 35 | """A class for generating prompts and extracting responses from generated 36 | text.""" 37 | 38 | def __init__(self, prompt_template: str = None): 39 | """Initializes a new instance of the Prompter class. 40 | 41 | Args: 42 | prompt_template (str): The name of the prompt template to use. Default is None. 43 | If set to 'alpaca', it will use a different set of prompt templates. 44 | """ 45 | self.PROMPT_DICT = ALPACA_PROMPT_DICT if prompt_template == 'alpaca' else PROMPT_DICT 46 | self.reponse_split = '### Response:' 47 | 48 | def generate_prompt(self, 49 | instruction: str, 50 | input: Union[str, None] = None, 51 | response: Union[str, None] = None) -> str: 52 | """Generates a prompt based on the specified inputs. 53 | 54 | Args: 55 | instruction (str): The instruction to include in the prompt. 56 | input (Union[str, None]): The input to include in the prompt. Default is None. 57 | response (Union[str, None]): The response to include in the prompt. Default is None. 58 | 59 | Returns: 60 | str: The generated prompt text. 
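
        Example:
            >>> prompter = Prompter(prompt_template='alpaca')
            >>> prompt = prompter.generate_prompt('Translate to French.', input='Good morning.')
            >>> prompt.endswith('### Response:')
            True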
61 | """ 62 | prompt_input, prompt_no_input = self.PROMPT_DICT[ 63 | 'prompt_input'], self.PROMPT_DICT['prompt_no_input'] 64 | 65 | if input is not None: 66 | prompt_text = prompt_input.format(instruction=instruction, 67 | input=input) 68 | else: 69 | prompt_text = prompt_no_input.format(instruction=instruction) 70 | 71 | if response: 72 | prompt_text = f'{prompt_text}{response}' 73 | 74 | return prompt_text 75 | 76 | def get_response(self, output: str) -> str: 77 | """Extracts the response from the generated text. 78 | 79 | Args: 80 | output (str): The generated text to extract the response from. 81 | 82 | Returns: 83 | str: The extracted response. 84 | """ 85 | return output.split(self.reponse_split)[1].strip() 86 | 87 | 88 | def main(): 89 | parser = transformers.HfArgumentParser(ModelInferenceArguments) 90 | model_server_args, _ = parser.parse_args_into_dataclasses( 91 | return_remaining_strings=True) 92 | args = argparse.Namespace(**vars(model_server_args)) 93 | args.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 94 | 95 | model, tokenizer = load_model_tokenizer(args, 96 | checkpoint_dir=args.checkpoint_dir, 97 | is_trainable=False, 98 | logger=logger) 99 | prompter = Prompter() 100 | 101 | def evaluate( 102 | instruction, 103 | input=None, 104 | temperature=1.0, 105 | top_p=1.0, 106 | top_k=50, 107 | num_beams=4, 108 | max_new_tokens=128, 109 | stream_output=False, 110 | **kwargs, 111 | ): 112 | prompt = prompter.generate_prompt(instruction, input) 113 | inputs = tokenizer(prompt, return_tensors='pt') 114 | inputs = inputs.to(args.device) 115 | generation_config = GenerationConfig( 116 | temperature=temperature, 117 | top_p=top_p, 118 | top_k=top_k, 119 | num_beams=num_beams, 120 | do_sample=True, 121 | # no_repeat_ngram_size=6, 122 | # repetition_penalty=1.8, 123 | **kwargs, 124 | ) 125 | 126 | generate_params = { 127 | 'input_ids': inputs['input_ids'], 128 | 'generation_config': generation_config, 129 | 'return_dict_in_generate': True, 130 | 'output_scores': True, 131 | 'max_new_tokens': max_new_tokens, 132 | } 133 | 134 | if stream_output: 135 | # Stream the reply 1 token at a time. 136 | # This is based on the trick of using 'stopping_criteria' to create an iterator, 137 | 138 | def generate_with_callback(callback=None, **kwargs): 139 | kwargs.setdefault('stopping_criteria', 140 | transformers.StoppingCriteriaList()) 141 | kwargs['stopping_criteria'].append( 142 | Stream(callback_func=callback)) 143 | with torch.no_grad(): 144 | model.generate(**kwargs) 145 | 146 | def generate_with_streaming(**kwargs): 147 | return Iteratorize(generate_with_callback, 148 | kwargs, 149 | callback=None) 150 | 151 | with generate_with_streaming(**generate_params) as generator: 152 | for output in generator: 153 | # new_tokens = len(output) - len(input_ids[0]) 154 | decoded_output = tokenizer.decode(output) 155 | 156 | if output[-1] in [tokenizer.eos_token_id]: 157 | break 158 | 159 | yield prompter.get_response(decoded_output) 160 | return # early return for stream_output 161 | 162 | # Without streaming 163 | with torch.no_grad(): 164 | generation_output = model.generate( 165 | **inputs, 166 | generation_config=generation_config, 167 | return_dict_in_generate=True, 168 | output_scores=True, 169 | max_new_tokens=max_new_tokens, 170 | ) 171 | s = generation_output.sequences[0] 172 | output = tokenizer.decode(s) 173 | yield prompter.get_response(output) 174 | 175 | description = 'Baichuan7B is a 7B-parameter LLaMA model finetuned to follow instructions.' 
176 | server = gr.Interface( 177 | fn=evaluate, 178 | inputs=[ 179 | gr.components.Textbox(lines=2, 180 | label='Instruction', 181 | placeholder='Tell me about alpacas.'), 182 | gr.components.Textbox(lines=2, label='Input', placeholder='none'), 183 | gr.components.Slider(minimum=0, 184 | maximum=1, 185 | value=1.0, 186 | label='Temperature'), 187 | gr.components.Slider(minimum=0, 188 | maximum=1, 189 | value=1.0, 190 | label='Top p'), 191 | gr.components.Slider(minimum=0, 192 | maximum=100, 193 | step=1, 194 | value=50, 195 | label='Top k'), 196 | gr.components.Slider(minimum=1, 197 | maximum=4, 198 | step=1, 199 | value=4, 200 | label='Beams'), 201 | gr.components.Slider(minimum=16, 202 | maximum=1024, 203 | step=32, 204 | value=128, 205 | label='Max new tokens'), 206 | gr.components.Checkbox(label='Stream output'), 207 | ], 208 | outputs=[gr.inputs.Textbox( 209 | lines=5, 210 | label='Output', 211 | )], 212 | title='Baichuan7B', 213 | description=description, 214 | ) 215 | 216 | server.queue().launch(server_name='0.0.0.0', share=True) 217 | 218 | 219 | if __name__ == '__main__': 220 | main() 221 | -------------------------------------------------------------------------------- /llamatuner/utils/logger_utils.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import sys 4 | import threading 5 | from logging import Formatter, LogRecord 6 | from logging.handlers import RotatingFileHandler 7 | from pathlib import Path 8 | from typing import Optional, Union 9 | 10 | import torch.distributed as dist 11 | from colorama import Fore, Style 12 | 13 | logger_initialized: dict = {} 14 | 15 | 16 | class ColorfulFormatter(Formatter): 17 | """Formatter that adds ANSI color codes to log messages based on their 18 | level. 19 | 20 | Attributes: 21 | COLORS: Dictionary mapping log levels to their corresponding color codes 22 | 23 | Example: 24 | >>> formatter = ColorfulFormatter('%(levelname)s: %(message)s') 25 | >>> handler = logging.StreamHandler() 26 | >>> handler.setFormatter(formatter) 27 | """ 28 | 29 | COLORS: dict[str, str] = { 30 | 'INFO': Fore.GREEN, 31 | 'WARNING': Fore.YELLOW, 32 | 'ERROR': Fore.RED, 33 | 'CRITICAL': Fore.RED + Style.BRIGHT, 34 | 'DEBUG': Fore.LIGHTGREEN_EX, 35 | } 36 | 37 | def format(self, record: LogRecord) -> str: 38 | """Format the log record with color coding. 39 | 40 | Args: 41 | record: The log record to format 42 | 43 | Returns: 44 | The formatted and color-coded log message 45 | """ 46 | record.rank = int(os.getenv('LOCAL_RANK', '0')) 47 | log_message = super().format(record) 48 | color = self.COLORS.get(record.levelname, Fore.RESET) 49 | return f'{color}{log_message}{Fore.RESET}' 50 | 51 | 52 | def get_logger( 53 | name: str, 54 | log_file: Optional[Union[str, Path]] = None, 55 | log_level: int = logging.INFO, 56 | file_mode: str = 'w', 57 | ) -> logging.Logger: 58 | """Initialize and get a logger by name with optional file output. 59 | 60 | This function creates or retrieves a logger with the specified configuration. 61 | It handles distributed training scenarios by managing log levels across different 62 | process ranks and prevents duplicate logging issues with PyTorch DDP. 63 | 64 | Args: 65 | name: Logger name for identification and hierarchy 66 | log_file: Path to the log file. 
If provided, logs will also be written to this file 67 | (only for rank 0 process in distributed training) 68 | log_level: Logging level (e.g., logging.INFO, logging.DEBUG) 69 | Note: Only rank 0 process uses this level; others use ERROR level 70 | file_mode: File opening mode ('w' for write, 'a' for append) 71 | 72 | Returns: 73 | A configured logging.Logger instance 74 | 75 | Example: 76 | >>> logger = get_logger("my_model", "training.log", logging.DEBUG) 77 | >>> logger.info("Training started") 78 | """ 79 | if file_mode not in ('w', 'a'): 80 | raise ValueError(f"Invalid file_mode: {file_mode}. Use 'w' or 'a'.") 81 | 82 | with threading.Lock(): 83 | # Get or create logger instance 84 | logger = logging.getLogger(name) 85 | 86 | # Return existing logger if already initialized 87 | if name in logger_initialized: 88 | return logger 89 | 90 | # Check if parent logger is already initialized 91 | for logger_name in logger_initialized: 92 | if name.startswith(logger_name): 93 | return logger 94 | 95 | # Fix PyTorch DDP duplicate logging issue 96 | # Set root StreamHandler to ERROR level to prevent unwanted output from rank>0 processes 97 | for handler in logger.root.handlers: 98 | if isinstance(handler, logging.StreamHandler): 99 | handler.setLevel(logging.ERROR) 100 | 101 | # Initialize handlers list with stdout StreamHandler 102 | handlers = [logging.StreamHandler(sys.stdout)] 103 | 104 | # Determine process rank for distributed setup 105 | try: 106 | rank = dist.get_rank() if (dist.is_available() 107 | and dist.is_initialized()) else 0 108 | except Exception: 109 | rank = 0 110 | 111 | # Add FileHandler for rank 0 process if log_file is specified 112 | if rank == 0 and log_file is not None: 113 | log_file = Path(log_file) 114 | log_file.parent.mkdir(parents=True, exist_ok=True) 115 | file_handler = RotatingFileHandler( 116 | filename=str(log_file), 117 | mode=file_mode, 118 | maxBytes=10 * 1024 * 1024, # 10 MB 119 | backupCount=5, 120 | encoding='utf-8', 121 | ) 122 | file_handler.setLevel(log_level) 123 | handlers.append(file_handler) 124 | 125 | # Configure formatter and handlers 126 | formatter = ColorfulFormatter( 127 | fmt=('%(asctime)s - [%(filename)s.%(funcName)s:%(lineno)d]- ' 128 | '%(levelname)s - %(message)s'), 129 | datefmt='%Y-%m-%d %H:%M:%S', 130 | ) 131 | 132 | # Inject rank into all log records 133 | old_factory = logging.getLogRecordFactory() 134 | 135 | def record_factory(*args, **kwargs): 136 | record = old_factory(*args, **kwargs) 137 | record.rank = rank # Dynamic rank injection 138 | return record 139 | 140 | logging.setLogRecordFactory(record_factory) 141 | 142 | # Apply configuration to all handlers 143 | for handler in handlers: 144 | handler.setFormatter(formatter) 145 | logger.addHandler(handler) 146 | 147 | # Set logger level based on rank 148 | logger.setLevel(log_level if rank == 0 else logging.ERROR) 149 | logger.propagate = False # Prevent propagation to root logger 150 | 151 | # Mark logger as initialized 152 | logger_initialized[name] = True 153 | 154 | return logger 155 | 156 | 157 | def print_log(msg, logger=None, level=logging.INFO): 158 | """Print a log message. 159 | 160 | Args: 161 | msg (str): The message to be logged. 162 | logger (logging.Logger | str | None): The logger to be used. 163 | Some special loggers are: 164 | 165 | - "silent": no message will be printed. 166 | - other str: the logger obtained with `get_root_logger(logger)`. 167 | - None: The `print()` method will be used to print log messages. 168 | level (int): Logging level. 
Only available when `logger` is a Logger 169 | object or "root". 170 | """ 171 | if logger is None: 172 | print(msg) 173 | elif isinstance(logger, logging.Logger): 174 | logger.log(level, msg) 175 | elif logger == 'silent': 176 | pass 177 | elif isinstance(logger, str): 178 | _logger = get_logger(logger) 179 | _logger.log(level, msg) 180 | else: 181 | raise TypeError( 182 | 'logger should be either a logging.Logger object, str, ' 183 | f'"silent" or None, but got {type(logger)}') 184 | 185 | 186 | def get_root_logger(log_file=None, log_level=logging.INFO): 187 | """Get root logger. 188 | 189 | Args: 190 | log_file (str, optional): File path of log. Defaults to None. 191 | log_level (int, optional): The level of logger. 192 | Defaults to logging.INFO. 193 | 194 | Returns: 195 | :obj:`logging.Logger`: The obtained logger 196 | """ 197 | logger = get_logger(name='llamatuner', 198 | log_file=log_file, 199 | log_level=log_level) 200 | 201 | return logger 202 | 203 | 204 | def get_outdir(path: str, *paths, inc: bool = False) -> str: 205 | """Get the output directory. If the directory does not exist, it will be 206 | created. If `inc` is True, the directory will be incremented if the 207 | directory already exists. 208 | 209 | Args: 210 | path (str): The root path. 211 | *paths: The subdirectories. 212 | inc (bool, optional): Whether to increment the directory. Defaults to False. 213 | 214 | Returns: 215 | str: The output directory. 216 | """ 217 | outdir = os.path.join(path, *paths) 218 | if not os.path.exists(outdir): 219 | os.makedirs(outdir) 220 | return outdir 221 | elif inc: 222 | for count in range(1, 100): 223 | outdir_inc = f'{outdir}-{count}' 224 | if not os.path.exists(outdir_inc): 225 | os.makedirs(outdir_inc) 226 | return outdir_inc 227 | raise RuntimeError( 228 | 'Failed to create unique output directory after 100 attempts') 229 | return outdir 230 | 231 | 232 | if __name__ == '__main__': 233 | # Initialize logger 234 | logger = get_logger('my_model', 'training.log', logging.DEBUG) 235 | 236 | # Log messages 237 | logger.debug('This is a debug message.') 238 | logger.info('This is an info message.') 239 | logger.warning('This is a warning message.') 240 | logger.error('This is an error message.') 241 | logger.critical('This is a critical message.') 242 | --------------------------------------------------------------------------------
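A minimal usage sketch for the logging helpers in `llamatuner/utils/logger_utils.py` (directory and file names below are illustrative):

    import logging

    from llamatuner.utils.logger_utils import get_outdir, get_root_logger, print_log

    # get_outdir creates work_dirs/sft if it does not exist; with inc=True an existing
    # directory is suffixed as work_dirs/sft-1, work_dirs/sft-2, ...
    output_dir = get_outdir('work_dirs', 'sft', inc=True)

    # Only the rank-0 process attaches the rotating file handler; other ranks are limited
    # to ERROR-level console output to avoid duplicated logs under DDP.
    logger = get_root_logger(log_file=f'{output_dir}/train.log', log_level=logging.INFO)

    print_log('Training started', logger=logger)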