├── src ├── lmflow │ ├── utils │ │ ├── __init__.py │ │ ├── flash_attention │ │ │ ├── __init__.py │ │ │ ├── gpt_neo_flash_attention.py │ │ │ ├── bloom_flash_attention.py │ │ │ ├── llama_flash_attention.py │ │ │ └── gpt2_flash_attention.py │ │ ├── position_interpolation │ │ │ ├── __init__.py │ │ │ └── llama_rope_scaled_monkey_patch.py │ │ ├── constants.py │ │ └── data_utils.py │ ├── models │ │ ├── __init__.py │ │ ├── interfaces │ │ │ ├── __init__.py │ │ │ └── tunable.py │ │ ├── base_model.py │ │ ├── regression_model.py │ │ ├── decoder_model.py │ │ ├── encoder_decoder_model.py │ │ ├── auto_model.py │ │ ├── text_regression_model.py │ │ └── vision2seq_model.py │ ├── pipeline │ │ ├── __init__.py │ │ ├── utils │ │ │ ├── __init__.py │ │ │ ├── peft_trainer.py │ │ │ └── continual_trainer.py │ │ ├── base_pipeline.py │ │ ├── base_tuner.py │ │ ├── base_aligner.py │ │ ├── auto_pipeline.py │ │ └── inferencer.py │ ├── version.py │ ├── .DS_Store │ ├── datasets │ │ └── __init__.py │ └── __init__.py └── .DS_Store ├── .DS_Store ├── scripts ├── .DS_Store ├── eval_pairrm │ ├── run_pairrm.sh │ ├── run_eval_rate_pairrm.py │ └── text_generation.py ├── eval_drop_squad_wmt │ ├── run-evaluation-wmt.sh │ ├── run-evaluation-squad.sh │ └── run-evaluation-drop.sh ├── eval_raft │ ├── run_eval_raft_align.sh │ ├── infer_get_samples.sh │ └── infer_get_rewards.sh ├── data_preprocess │ ├── tool.py │ ├── run_data_preprocess.sh │ ├── count.py │ ├── shuffle.py │ ├── sample.py │ ├── add_prompt.py │ ├── concat.py │ ├── add_end_mark.py │ ├── merge.py │ └── concat_shuffle_split.py ├── eval_cs_qa │ └── run_evaluation.sh ├── postprocess │ ├── weight_interpolation.sh │ ├── weight_mask_merge.sh │ ├── weight_interpolation.py │ └── weight_mask_merge.py └── mask_post │ └── run_mask_finetune_raft.sh ├── eval_log └── .DS_Store ├── install.sh ├── configs ├── ds_config.json ├── ds_config_eval.json ├── ds_config_chatbot.json ├── ds_config_multimodal.json ├── accelerator_singlegpu_config.yaml ├── accelerator_multigpu_config.yaml ├── ds_config_zero1.json ├── ds_config_zero3_for_eval.json ├── ds_config_zero2.json └── ds_config_zero3.json ├── requirements.txt ├── utils ├── convert_minigpt4_checkpoints.py ├── convert_json_to_txt.py ├── train_tokenizer.py ├── make_delta.py ├── lm_evaluator.py ├── merge_tokenizer.py └── apply_delta.py ├── setup.py ├── examples ├── finetune.py ├── sample_data.py ├── merge_lora.py ├── finetune_mask.py ├── raft_align.py └── raft_align_eval.py ├── output_models └── download.sh ├── data └── download.sh └── README.md /src/lmflow/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/lmflow/models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/lmflow/pipeline/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/lmflow/pipeline/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/lmflow/models/interfaces/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- 
/src/lmflow/version.py: -------------------------------------------------------------------------------- 1 | __version__ = "0.0.1" -------------------------------------------------------------------------------- /src/lmflow/utils/flash_attention/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/lmflow/utils/position_interpolation/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/avalonstrel/Mitigating-the-Alignment-Tax-of-RLHF/HEAD/.DS_Store -------------------------------------------------------------------------------- /src/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/avalonstrel/Mitigating-the-Alignment-Tax-of-RLHF/HEAD/src/.DS_Store -------------------------------------------------------------------------------- /scripts/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/avalonstrel/Mitigating-the-Alignment-Tax-of-RLHF/HEAD/scripts/.DS_Store -------------------------------------------------------------------------------- /eval_log/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/avalonstrel/Mitigating-the-Alignment-Tax-of-RLHF/HEAD/eval_log/.DS_Store -------------------------------------------------------------------------------- /src/lmflow/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/avalonstrel/Mitigating-the-Alignment-Tax-of-RLHF/HEAD/src/lmflow/.DS_Store -------------------------------------------------------------------------------- /src/lmflow/models/interfaces/tunable.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | """Tunable class 4 | """ 5 | 6 | from abc import ABC 7 | 8 | 9 | class Tunable(ABC): 10 | pass 11 | -------------------------------------------------------------------------------- /src/lmflow/pipeline/base_pipeline.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | """ BasePipeline. 4 | """ 5 | 6 | from abc import ABC # abstract class 7 | 8 | class BasePipeline(ABC): 9 | pass 10 | -------------------------------------------------------------------------------- /install.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | pip install -e . 4 | 5 | # gpu_state="$(nvidia-smi --query-gpu=name --format=csv,noheader)" 6 | # if [[ "${gpu_state}" == *"A100"* || "${gpu_state}" == *"A40"* ]]; then 7 | # pip install flash-attn==2.0.2 8 | # fi -------------------------------------------------------------------------------- /src/lmflow/models/base_model.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | """Base model class. 
4 | """ 5 | 6 | from abc import ABC 7 | 8 | 9 | class BaseModel(ABC): 10 | 11 | def __init__(self, *args, **kwargs): 12 | pass 13 | -------------------------------------------------------------------------------- /configs/ds_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "fp16": { 3 | "enabled": false 4 | }, 5 | "bf16": { 6 | "enabled": true 7 | }, 8 | "steps_per_print": 2000, 9 | "train_micro_batch_size_per_gpu": 1, 10 | "wall_clock_breakdown": false 11 | } 12 | -------------------------------------------------------------------------------- /configs/ds_config_eval.json: -------------------------------------------------------------------------------- 1 | { 2 | "fp16": { 3 | "enabled": false 4 | }, 5 | "bf16": { 6 | "enabled": false 7 | }, 8 | "steps_per_print": 2000, 9 | "train_micro_batch_size_per_gpu": 1, 10 | "wall_clock_breakdown": false 11 | } 12 | -------------------------------------------------------------------------------- /src/lmflow/models/regression_model.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | """General regression model.""" 4 | 5 | from lmflow.models.base_model import BaseModel 6 | 7 | 8 | class RegressionModel(BaseModel): 9 | 10 | def __init__(self, *args, **kwargs): 11 | pass 12 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy==1.26.3 2 | datasets 3 | peft 4 | torch==2.1.2 5 | wandb==0.14.0 6 | deepspeed==0.13.1 7 | sentencepiece 8 | transformers 9 | flask 10 | flask_cors 11 | icetk 12 | cpm_kernels 13 | evaluate 14 | scikit-learn 15 | lm-eval @ git+https://github.com/EleutherAI/lm-evaluation-harness.git@e47e01beea79cfe87421e2dac49e64d499c240b4 16 | dill<0.3.5 17 | bitsandbytes==0.38.1 18 | pydantic<=1.10.9 19 | gradio 20 | -------------------------------------------------------------------------------- /src/lmflow/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | """This Python code defines a class Dataset with methods for initializing, loading, 2 | and manipulating datasets from different backends such as Hugging Face and JSON. 3 | 4 | The `Dataset` class includes methods for loading datasets from a dictionary and a Hugging 5 | Face dataset, mapping datasets, and retrieving the backend dataset and arguments. 
6 | """ 7 | from lmflow.datasets.dataset import Dataset 8 | -------------------------------------------------------------------------------- /configs/ds_config_chatbot.json: -------------------------------------------------------------------------------- 1 | { 2 | "fp16": { 3 | "enabled": false 4 | }, 5 | "bf16": { 6 | "enabled": true 7 | }, 8 | "comms_logger": { 9 | "enabled": false, 10 | "verbose": false, 11 | "prof_all": false, 12 | "debug": false 13 | }, 14 | "steps_per_print": 20000000000000000, 15 | "train_micro_batch_size_per_gpu": 1, 16 | "wall_clock_breakdown": false 17 | } 18 | -------------------------------------------------------------------------------- /configs/ds_config_multimodal.json: -------------------------------------------------------------------------------- 1 | { 2 | "fp16": { 3 | "enabled": false 4 | }, 5 | "bf16": { 6 | "enabled": false 7 | }, 8 | "comms_logger": { 9 | "enabled": false, 10 | "verbose": false, 11 | "prof_all": false, 12 | "debug": false 13 | }, 14 | "steps_per_print": 20000000000000000, 15 | "train_micro_batch_size_per_gpu": 1, 16 | "wall_clock_breakdown": false 17 | } 18 | -------------------------------------------------------------------------------- /configs/accelerator_singlegpu_config.yaml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | distributed_type: 'NO' 3 | downcast_bf16: 'no' 4 | dynamo_config: 5 | dynamo_backend: INDUCTOR 6 | gpu_ids: 7 | machine_rank: 0 8 | main_training_function: main 9 | mixed_precision: bf16 10 | num_machines: 1 11 | num_processes: 1 12 | rdzv_backend: static 13 | same_network: true 14 | tpu_env: [] 15 | tpu_use_cluster: false 16 | tpu_use_sudo: false 17 | use_cpu: false 18 | -------------------------------------------------------------------------------- /configs/accelerator_multigpu_config.yaml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | distributed_type: MULTI_GPU 3 | downcast_bf16: 'no' 4 | dynamo_config: 5 | dynamo_backend: INDUCTOR 6 | gpu_ids: 7 | machine_rank: 0 8 | main_training_function: main 9 | mixed_precision: bf16 10 | num_machines: 1 11 | num_processes: 2 12 | rdzv_backend: static 13 | same_network: true 14 | tpu_env: [] 15 | tpu_use_cluster: false 16 | tpu_use_sudo: false 17 | use_cpu: false 18 | main_process_port: 11000 19 | -------------------------------------------------------------------------------- /scripts/eval_pairrm/run_pairrm.sh: -------------------------------------------------------------------------------- 1 | # bash scripts/eval_pairrm/run_text_generation.sh 2 | model_path="path/to/model" 3 | save_path="path/to/save" #should be consistent with the path in run_eval_rate_pairrm.py 4 | gpu="0" 5 | mkdir -p ${save_path} 6 | CUDA_VISIBLE_DEVICES=${gpu} python scripts/eval_pairrm/text_generation.py \ 7 | --model_path ${model_path} \ 8 | --gpu ${gpu} \ 9 | --save_path ${save_path} 10 | 11 | python run_eval_rate_pairrm.py -------------------------------------------------------------------------------- /src/lmflow/__init__.py: -------------------------------------------------------------------------------- 1 | from .version import __version__ as internal_version 2 | 3 | __version__ = internal_version 4 | 5 | from transformers.utils import check_min_version 6 | from transformers.utils.versions import require_version 7 | 8 | from lmflow import args, datasets, models, pipeline, utils 9 | 10 | # Will error if the minimal version of 
Transformers is not installed. Remove at your own risks. 11 | check_min_version("4.27.0.dev0") 12 | 13 | require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt") -------------------------------------------------------------------------------- /src/lmflow/pipeline/base_tuner.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | """ BaseTuner: a subclass of BasePipeline. 4 | """ 5 | 6 | from lmflow.pipeline.base_pipeline import BasePipeline 7 | 8 | 9 | class BaseTuner(BasePipeline): 10 | """ A subclass of BasePipeline which is tunable. 11 | """ 12 | def __init__(self, *args, **kwargs): 13 | pass 14 | 15 | def _check_if_tunable(self, model, dataset): 16 | # TODO: check if the model is tunable and dataset is compatible 17 | pass 18 | 19 | def tune(self, model, dataset): 20 | raise NotImplementedError(".tune is not implemented") 21 | -------------------------------------------------------------------------------- /src/lmflow/models/decoder_model.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | """A one-line summary of the module or program, terminated by a period. 4 | 5 | Leave one blank line. The rest of this docstring should contain an 6 | overall description of the module or program. Optionally, it may also 7 | contain a brief description of exported classes and functions and/or usage 8 | examples. 9 | 10 | Typical usage example: 11 | 12 | foo = ClassFoo() 13 | bar = foo.FunctionBar() 14 | """ 15 | 16 | from lmflow.models.base_model import BaseModel 17 | 18 | 19 | class DecoderModel(BaseModel): 20 | 21 | def __init__(self, *args, **kwargs): 22 | pass 23 | -------------------------------------------------------------------------------- /src/lmflow/models/encoder_decoder_model.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | """A one-line summary of the module or program, terminated by a period. 4 | 5 | Leave one blank line. The rest of this docstring should contain an 6 | overall description of the module or program. Optionally, it may also 7 | contain a brief description of exported classes and functions and/or usage 8 | examples. 9 | 10 | Typical usage example: 11 | 12 | foo = ClassFoo() 13 | bar = foo.FunctionBar() 14 | """ 15 | 16 | from lmflow.models.base_model import BaseModel 17 | 18 | 19 | class EncoderDecoderModel(BaseModel): 20 | 21 | def __init__(self, *args, **kwargs): 22 | pass -------------------------------------------------------------------------------- /src/lmflow/pipeline/base_aligner.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | """ BaseTuner: a subclass of BasePipeline. 4 | """ 5 | 6 | from lmflow.pipeline.base_pipeline import BasePipeline 7 | 8 | 9 | class BaseAligner(BasePipeline): 10 | """ A subclass of BasePipeline which is alignable. 
11 | """ 12 | def __init__(self, *args, **kwargs): 13 | pass 14 | 15 | def _check_if_alignable(self, model, dataset, reward_model): 16 | # TODO: check if the model is alignable and dataset is compatible 17 | # TODO: add reward_model 18 | pass 19 | 20 | def align(self, model, dataset, reward_model): 21 | raise NotImplementedError(".align is not implemented") 22 | -------------------------------------------------------------------------------- /scripts/eval_drop_squad_wmt/run-evaluation-wmt.sh: -------------------------------------------------------------------------------- 1 | # bash scripts/eval_drop_squad_wmt/run-mask-wmt.sh 2 | project_dir=$(cd "$(dirname $0)"/..; pwd)/../.. 3 | cd ../lmflow_benchmark 4 | model_dir="${project_dir}/output_models" 5 | log_dir="${project_dir}/eval_log" 6 | dataset="lm_eval_wmt14" 7 | 8 | model_tag="HuggingFaceH4/zephyr-7b-beta" 9 | sed -i -e 's/\"use_cache\":\ false/\"use_cache\":\ true/g' ${model_dir}/${model_tag}/config.json 10 | model_path=${model_dir}/${model_tag} 11 | log_path=${log_dir}/${model_tag}/${dataset} 12 | mkdir -p ${log_path} 13 | 14 | bash scripts/run_benchmark_port.sh "${dataset}" ${model_path} 60020 0 \ 15 | | tee -a ${log_path}/evaluation_final.log \ 16 | 2> ${log_path}/evaluation_final.err 17 | -------------------------------------------------------------------------------- /scripts/eval_drop_squad_wmt/run-evaluation-squad.sh: -------------------------------------------------------------------------------- 1 | # bash scripts/eval_drop_squad_wmt/run-mask-squad.sh 2 | project_dir=$(cd "$(dirname $0)"/..; pwd)/../.. 3 | cd ../lmflow_benchmark 4 | model_dir="${project_dir}/output_models" 5 | log_dir="${project_dir}/eval_log" 6 | dataset="lm_eval_squad2" 7 | 8 | model_tag="HuggingFaceH4/zephyr-7b-beta" 9 | sed -i -e 's/\"use_cache\":\ false/\"use_cache\":\ true/g' ${model_dir}/${model_tag}/config.json 10 | model_path=${model_dir}/${model_tag} 11 | log_path=${log_dir}/${model_tag}/${dataset} 12 | mkdir -p ${log_path} 13 | 14 | bash scripts/run_benchmark_port.sh "${dataset}" ${model_path} 60020 0 \ 15 | | tee -a ${log_path}/evaluation_final.log \ 16 | 2> ${log_path}/evaluation_final.err 17 | -------------------------------------------------------------------------------- /scripts/eval_raft/run_eval_raft_align.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # bash scripts/eval_raft/run_eval_raft_align.sh 3 | 4 | gpu_ids="0,1,2,3" 5 | port=11002 6 | 7 | prefix="./scripts/eval_raft" 8 | project_dir=$(cd "$(dirname $0)"/..; pwd)/../.. 
9 | base_dir="${project_dir}/path/to/save/generated_text" 10 | mkdir -p $base_dir 11 | 12 | test_model="${project_dir}/path/to/test_model" 13 | reward_model="${project_dir}/path/to/reward_model" 14 | 15 | mkdir -p $base_dir/infer_set 16 | mkdir -p $base_dir/filtered_set 17 | 18 | bash ${prefix}/infer_get_samples.sh ${test_model} 0 ${base_dir}/infer_set ${gpu_ids} ${port} 19 | bash ${prefix}/infer_get_rewards.sh ${base_dir}/infer_set ${base_dir}/filtered_set ${base_dir} ${reward_model} ${gpu_ids} ${port} 20 | -------------------------------------------------------------------------------- /scripts/data_preprocess/tool.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | path = '/home/xiongwei/linhangyu/Projects/LLMs/LMFlow_RAFT_Dev/data/openasst/textonly/textonly_oasst1_data_33919.json' 4 | save_path = '/home/xiongwei/linhangyu/Projects/LLMs/LMFlow_RAFT_Dev/data/openasst/textonly_post/textonly_oasst1_data_33919.json' 5 | data = json.load(open(path)) 6 | 7 | print(len(data['instances']), data['type']) 8 | for i in range(len(data['instances'])): 9 | text = data['instances'][i]['text'] 10 | text = text.replace('\nHuman:', ' ###Human:') 11 | text = text.replace('\nAssistant:', ' ###Assistant:') 12 | text = '###' + text 13 | # print(text) 14 | data['instances'][i]['text'] = text 15 | # print(i, data['instances'][i]) 16 | json.dump(data, open(save_path, 'w')) -------------------------------------------------------------------------------- /scripts/eval_cs_qa/run_evaluation.sh: -------------------------------------------------------------------------------- 1 | # bash scripts/evaluation_cs_qa/run_evaluation.sh 2 | 3 | project_dir=$(cd "$(dirname $0)"/..; pwd)/../.. 4 | cd ${project_dir}/../lm-evaluation-harness 5 | log_dir=${project_dir}/path/to/save 6 | if [ ! 
-d "${log_dir}" ]; 7 | then 8 | mkdir -p ${log_dir} 9 | fi 10 | model_path=path/to/model 11 | eval_log_path=${log_dir}/result.json 12 | gpu_idx=0 13 | batch_size=16 14 | python main.py \ 15 | --model hf-causal-experimental \ 16 | --model_args pretrained="\"${model_path}\"" \ 17 | --tasks arc_easy,arc_challenge,race,boolq,piqa \ 18 | --output_path ${eval_log_path} \ 19 | --batch_size ${batch_size} \ 20 | --max_batch_size ${batch_size} \ 21 | --no_cache \ 22 | --device cuda:${gpu_idx} \ 23 | | tee ${log_dir}/evaluation.log \ 24 | 2> ${log_dir}/evaluation.err 25 | -------------------------------------------------------------------------------- /configs/ds_config_zero1.json: -------------------------------------------------------------------------------- 1 | { 2 | "fp16": { 3 | "enabled": "auto", 4 | "loss_scale": 0, 5 | "loss_scale_window": 1000, 6 | "initial_scale_power": 16, 7 | "hysteresis": 2, 8 | "min_loss_scale": 1 9 | }, 10 | 11 | "bf16": { 12 | "enabled": "auto" 13 | }, 14 | 15 | "optimizer": { 16 | "type": "AdamW", 17 | "params": { 18 | "lr": "auto", 19 | "betas": "auto", 20 | "eps": "auto", 21 | "weight_decay": "auto" 22 | } 23 | }, 24 | 25 | "zero_optimization": { 26 | "stage": 1, 27 | "reduce_bucket_size": 5e8 28 | }, 29 | "gradient_accumulation_steps": "auto", 30 | "gradient_clipping": "auto", 31 | "steps_per_print": 2000, 32 | "train_batch_size": "auto", 33 | "train_micro_batch_size_per_gpu": "auto", 34 | "wall_clock_breakdown": false 35 | } 36 | -------------------------------------------------------------------------------- /configs/ds_config_zero3_for_eval.json: -------------------------------------------------------------------------------- 1 | { 2 | "bf16": { 3 | "enabled": true 4 | }, 5 | "zero_optimization": { 6 | "stage": 3, 7 | "offload_optimizer": { 8 | "device": "cpu", 9 | "pin_memory": true 10 | }, 11 | "offload_param": { 12 | "device": "cpu", 13 | "pin_memory": true 14 | }, 15 | "overlap_comm": true, 16 | "contiguous_gradients": true, 17 | "sub_group_size": 1e9, 18 | "reduce_bucket_size": "auto", 19 | "stage3_prefetch_bucket_size": "auto", 20 | "stage3_param_persistence_threshold": "auto", 21 | "stage3_max_live_parameters": 1e9, 22 | "stage3_max_reuse_distance": 1e9, 23 | "stage3_gather_16bit_weights_on_model_save": true 24 | }, 25 | 26 | "steps_per_print": 2000, 27 | "train_micro_batch_size_per_gpu": 1, 28 | "wall_clock_breakdown": false 29 | } 30 | -------------------------------------------------------------------------------- /scripts/postprocess/weight_interpolation.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | gpu_idx=0 4 | alpha=0.5 5 | model_path0=openlm-research/open_llama_3b 6 | model_path1=openlm-research/open_llama_3b 7 | project_dir=$(cd "$(dirname $0)"/..; pwd)/../.. 8 | 9 | weight_ensamble_names_paths="${model_path0} ${model_path1}" 10 | weight_ensamble_save_path=output_models/test/ma_${alpha}_tag0_tag1 11 | if [ ! 
-d "${weight_ensamble_save_path}" ]; 12 | then 13 | mkdir -p ${weight_ensamble_save_path} 14 | fi 15 | 16 | CUDA_VISIBLE_DEVICES=${gpu_idx} \ 17 | python \ 18 | scripts/postprocess/weight_interpolation.py \ 19 | --model_name_or_path openlm-research/open_llama_3b \ 20 | --weight_ensamble_names_paths ${weight_ensamble_names_paths} \ 21 | --weight_ensamble_ratios ${alpha} \ 22 | --torch_dtype bfloat16 \ 23 | --weight_ensamble_save_path "${weight_ensamble_save_path}" \ 24 | --dataset_path data \ 25 | --deepspeed configs/ds_config.json \ 26 | --inference_batch_size_per_device 1 \ 27 | --metric accuracy 28 | -------------------------------------------------------------------------------- /scripts/eval_drop_squad_wmt/run-evaluation-drop.sh: -------------------------------------------------------------------------------- 1 | # bash scripts/eval_drop_squad_wmt/run-mask-drop.sh 2 | project_dir=$(cd "$(dirname $0)"/..; pwd)/../.. 3 | cd ../opencompass 4 | model_dir="${project_dir}/output_models" 5 | log_dir="${project_dir}/eval_log" 6 | dataset="lm_eval_drop" 7 | 8 | model_tag="HuggingFaceH4/zephyr-7b-beta" 9 | sed -i -e 's/\"use_cache\":\ false/\"use_cache\":\ true/g' ${model_dir}/${model_tag}/config.json 10 | model_path=${model_dir}/${model_tag} 11 | log_path=${log_dir}/${model_tag}/${dataset} 12 | mkdir -p ${log_path} 13 | 14 | CUDA_VISIBLE_DEVICES=0 python run.py --datasets drop_gen \ 15 | --hf-path ${model_path} \ 16 | --model-kwargs device_map='auto' \ 17 | --max-out-len 100 \ 18 | --max-seq-len 2048 \ 19 | --batch-size 6 \ 20 | --num-gpus 1 \ 21 | | tee -a ${log_path}/evaluation_final.log \ 22 | 2> ${log_path}/evaluation_final.err 23 | -------------------------------------------------------------------------------- /src/lmflow/models/auto_model.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | """Automatically get correct model type. 4 | """ 5 | 6 | from lmflow.models.hf_decoder_model import HFDecoderModel 7 | from lmflow.models.text_regression_model import TextRegressionModel 8 | from lmflow.models.hf_encoder_decoder_model import HFEncoderDecoderModel 9 | 10 | class AutoModel: 11 | 12 | @classmethod 13 | def get_model(self, model_args, *args, **kwargs): 14 | print(model_args) 15 | arch_type = model_args.arch_type 16 | if arch_type == "decoder_only": 17 | return HFDecoderModel(model_args, *args, **kwargs) 18 | elif arch_type == "text_regression": 19 | return TextRegressionModel(model_args, *args, **kwargs) 20 | elif arch_type == "encoder_decoder" or \ 21 | arch_type == "vision_encoder_decoder": 22 | return HFEncoderDecoderModel(model_args, *args, **kwargs) 23 | else: 24 | raise NotImplementedError( 25 | f"model architecture type \"{arch_type}\" is not supported" 26 | ) 27 | -------------------------------------------------------------------------------- /scripts/postprocess/weight_mask_merge.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | gpu_idx=0 3 | master_port=11000 4 | project_dir=$(cd "$(dirname $0)"/..; pwd)/../.. 5 | model_path0=${project_dir}/path/to/before/rlhf 6 | model_path1=${project_dir}/path/to/after/rlhf 7 | 8 | alphas_path=${project_dir}/path/to/mask_alpha.bin 9 | weight_ensamble_names_paths="${model_path0} ${model_path1}" 10 | weight_ensamble_save_path=${project_dir}/path/to/save 11 | 12 | if [ ! 
-d "${weight_ensamble_save_path}" ]; 13 | then 14 | mkdir -p ${weight_ensamble_save_path} 15 | fi 16 | 17 | deepspeed_args="--master_port=${master_port} --include localhost:${gpu_idx}" 18 | deepspeed ${deepspeed_args} \ 19 | scripts/llama3b/postprocess/weight_mask_merge.py \ 20 | --model_name_or_path openlm-research/open_llama_3b \ 21 | --alphas_path ${alphas_path} \ 22 | --weight_ensamble_names_paths ${weight_ensamble_names_paths} \ 23 | --weight_ensamble_ratios 0.0 \ 24 | --weight_ensamble_save_path "${weight_ensamble_save_path}" \ 25 | --dataset_path data \ 26 | --deepspeed configs/ds_config.json \ 27 | --inference_batch_size_per_device 1 \ 28 | --metric accuracy -------------------------------------------------------------------------------- /configs/ds_config_zero2.json: -------------------------------------------------------------------------------- 1 | { 2 | "fp16": { 3 | "enabled": "auto", 4 | "loss_scale": 0, 5 | "loss_scale_window": 1000, 6 | "initial_scale_power": 16, 7 | "hysteresis": 2, 8 | "min_loss_scale": 1 9 | }, 10 | 11 | "bf16": { 12 | "enabled": "auto" 13 | }, 14 | 15 | "optimizer": { 16 | "type": "AdamW", 17 | "params": { 18 | "lr": "auto", 19 | "betas": "auto", 20 | "eps": "auto", 21 | "weight_decay": "auto" 22 | } 23 | }, 24 | 25 | "zero_optimization": { 26 | "stage": 2, 27 | "offload_optimizer": { 28 | "device": "cpu", 29 | "pin_memory": true 30 | }, 31 | "allgather_partitions": true, 32 | "allgather_bucket_size": 2e8, 33 | "overlap_comm": true, 34 | "reduce_scatter": true, 35 | "reduce_bucket_size": 2e8, 36 | "contiguous_gradients": true 37 | }, 38 | 39 | "gradient_accumulation_steps": "auto", 40 | "gradient_clipping": "auto", 41 | "steps_per_print": 2000, 42 | "train_batch_size": "auto", 43 | "train_micro_batch_size_per_gpu": "auto", 44 | "wall_clock_breakdown": false 45 | } 46 | -------------------------------------------------------------------------------- /utils/convert_minigpt4_checkpoints.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os.path as osp 3 | import torch 4 | 5 | def parse_args(): 6 | parser = argparse.ArgumentParser(description="Convert checkpoint from MiniGPT4") 7 | parser.add_argument("--model_path", type=str, help="the model path for the to convert checkpoint") 8 | parser.add_argument("--save_path", default=None, type=str, help="the save path for converted checkpoint") 9 | args = parser.parse_args() 10 | return args 11 | 12 | 13 | 14 | 15 | 16 | if __name__ == "__main__": 17 | args = parse_args() 18 | model = torch.load(args.model_path) 19 | model = model['model'] 20 | new_model = {} 21 | for key, item in model.items(): 22 | key = key.replace("Qformer", "qformer") 23 | key = key.replace("llama_proj", "language_projection") 24 | key = key.replace("llama_model.model", "language_model.model") 25 | new_model[key] = item 26 | if args.save_path is None: 27 | end_string = osp.splitext(args.model_path) 28 | save_path = osp.dirname(args.model_path) + "/" + \ 29 | osp.basename(args.model_path).replace(".pth", "") + \ 30 | "-converted" + osp.splitext(args.model_path)[-1] 31 | else: 32 | save_path = args.save_path 33 | print("save_path: {}".format(save_path)) 34 | 35 | torch.save(new_model, save_path) 36 | -------------------------------------------------------------------------------- /utils/convert_json_to_txt.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | 4 | import argparse 5 | import logging 6 | 7 
| import json 8 | from pathlib import Path 9 | 10 | logging.basicConfig(level=logging.WARNING) 11 | 12 | if __name__ == '__main__': 13 | 14 | parser = argparse.ArgumentParser() 15 | parser.add_argument('--dataset_path', default='./data/wiki_zh_eval', type=str, required=False) 16 | parser.add_argument('--output_path', default='./data/wiki_zh_eval/converted_data.txt', type=str, required=False) 17 | parser.add_argument('--overwrite', default=False, type=bool, required=False) 18 | args = parser.parse_args() 19 | 20 | dataset_path = args.dataset_path 21 | outputfile = args.output_path 22 | 23 | outputs_list = [] 24 | data_files = [ 25 | x.absolute().as_posix() 26 | for x in Path(dataset_path).glob("*.json") 27 | ] 28 | 29 | for file_name in data_files: 30 | with open(file_name) as fin: 31 | json_data = json.load(fin) 32 | type = json_data["type"] 33 | for line in json_data["instances"]: 34 | outputs_list.append(line["text"]) 35 | 36 | 37 | if Path(outputfile).exists() and not args.overwrite: 38 | logging.warning(f"File %s exists, will not overwrite.", outputfile) 39 | else: 40 | with open(outputfile, "w") as f: 41 | for line in outputs_list: 42 | f.write(line) 43 | 44 | -------------------------------------------------------------------------------- /configs/ds_config_zero3.json: -------------------------------------------------------------------------------- 1 | { 2 | "fp16": { 3 | "enabled": "auto", 4 | "loss_scale": 0, 5 | "loss_scale_window": 1000, 6 | "initial_scale_power": 16, 7 | "hysteresis": 2, 8 | "min_loss_scale": 1 9 | }, 10 | 11 | "bf16": { 12 | "enabled": "auto" 13 | }, 14 | 15 | "optimizer": { 16 | "type": "AdamW", 17 | "params": { 18 | "lr": "auto", 19 | "betas": "auto", 20 | "eps": "auto", 21 | "weight_decay": "auto" 22 | } 23 | }, 24 | 25 | "zero_optimization": { 26 | "stage": 3, 27 | "offload_optimizer": { 28 | "device": "cpu", 29 | "pin_memory": true 30 | }, 31 | "offload_param": { 32 | "device": "cpu", 33 | "pin_memory": true 34 | }, 35 | "overlap_comm": true, 36 | "contiguous_gradients": true, 37 | "sub_group_size": 1e9, 38 | "reduce_bucket_size": "auto", 39 | "stage3_prefetch_bucket_size": "auto", 40 | "stage3_param_persistence_threshold": "auto", 41 | "stage3_max_live_parameters": 1e9, 42 | "stage3_max_reuse_distance": 1e9, 43 | "stage3_gather_16bit_weights_on_model_save": true 44 | }, 45 | 46 | "gradient_accumulation_steps": "auto", 47 | "gradient_clipping": "auto", 48 | "steps_per_print": 2000, 49 | "train_batch_size": "auto", 50 | "train_micro_batch_size_per_gpu": "auto", 51 | "wall_clock_breakdown": false 52 | } 53 | -------------------------------------------------------------------------------- /src/lmflow/models/text_regression_model.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | """ 4 | A model maps "text_only" data to float. 5 | """ 6 | 7 | from lmflow.models.regression_model import RegressionModel 8 | from lmflow.datasets.dataset import Dataset 9 | 10 | 11 | class TextRegressionModel(RegressionModel): 12 | r""" 13 | Initializes a TextRegressionModel instance. 14 | 15 | Parameters 16 | ------------ 17 | 18 | model_args : 19 | Model arguments such as model name, path, revision, etc. 20 | 21 | args : Optional. 22 | Positional arguments. 23 | 24 | kwargs : Optional. 25 | Keyword arguments. 26 | """ 27 | 28 | def __init__( 29 | self, 30 | model_args, 31 | *args, 32 | **kwargs 33 | ): 34 | """ 35 | Initializes a TextRegressionModel instance. 
36 | :param model_args: dictionary with model arguments such as model name, path, revision, etc. 37 | """ 38 | self.inference_func = None 39 | 40 | 41 | def register_inference_function(self, inference_func): 42 | """ 43 | Registers a regression function. 44 | """ 45 | self.inference_func = inference_func 46 | 47 | 48 | def inference(self, inputs: Dataset): 49 | """ 50 | Gets regression results of a given dataset. 51 | 52 | :inputs: Dataset object, only accept type "text_only". 53 | """ 54 | if self.inference_func is not None: 55 | return self.inference_func(inputs) 56 | else: 57 | pass 58 | -------------------------------------------------------------------------------- /scripts/data_preprocess/run_data_preprocess.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Run this shell script under project directory 3 | 4 | # For sample.py 5 | python scripts/data_preprocess/sample.py \ 6 | --dataset_path ./data/example_dataset/train/train_50.json \ 7 | --output_path ./data/example_dataset/train/train_50_sample.json \ 8 | --ratio 0.5 9 | 10 | # For shuffle.py 11 | python scripts/data_preprocess/shuffle.py \ 12 | --dataset_path ./data/example_dataset/train/train_50_sample.json \ 13 | --output_path ./data/example_dataset/train/train_50_sample_shuffle.json 14 | 15 | # For merge.py : you can specify multiple files to merge 16 | python scripts/data_preprocess/merge.py \ 17 | --dataset_path ./data/example_dataset/train/train_50.json \ 18 | --merge_from_path ./data/example_dataset/train/train_50_sample_shuffle.json \ 19 | ./data/example_dataset/train/train_50_sample.json \ 20 | --output_path ./data/example_dataset/train/train_merge.json \ 21 | 22 | # For concat.py: if you simply want to merge multiple files or a directory, use following. 23 | # You can also specify multiple files after --merge_from_path 24 | python scripts/data_preprocess/concat.py \ 25 | --merge_from_path ./data/example_dataset/train/*.json \ 26 | --output_path ./data/example_dataset/train/train_merge.json \ 27 | 28 | # For concat_shuffle_split.py: if you simply want to merge multiple files or a directory, use following. 
29 | python scripts/data_preprocess/concat_shuffle_split.py \ 30 | --merge_from_path ./data/example_dataset/train/*.json \ 31 | --output_path ./data/processed_dataset/ \ -------------------------------------------------------------------------------- /utils/train_tokenizer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | 4 | import argparse 5 | import os 6 | import sentencepiece as spm 7 | 8 | if __name__ == '__main__': 9 | 10 | parser = argparse.ArgumentParser() 11 | parser.add_argument('--dataset_path', default='./data/wiki_zh_eval/converted_data.txt', type=str, required=False) 12 | parser.add_argument('--output_dir', default='./output_models/new_tokenizer', type=str, required=False) 13 | parser.add_argument('--vocab_size', default=20000, type=int, required=False) 14 | parser.add_argument('--model_type', default='bpe', type=str, required=False) 15 | parser.add_argument('--user_defined_symbols', default='0,1,2,3,4,5,6,7,8,9,%', type=str, required=False) 16 | parser.add_argument('--max_sentencepiece_length', default=4, type=int, required=False) 17 | args = parser.parse_args() 18 | 19 | dataset_path = args.dataset_path 20 | output_dir = args.output_dir 21 | vocab_size = args.vocab_size 22 | model_type = args.model_type 23 | user_defined_symbols = args.user_defined_symbols 24 | max_sentencepiece_length=args.max_sentencepiece_length 25 | 26 | def mkdir(path): 27 | if not os.path.exists(path): 28 | os.makedirs(path) 29 | mkdir(output_dir) 30 | 31 | spm.SentencePieceTrainer.train( 32 | f'--input={dataset_path}' 33 | f' --model_prefix={output_dir}/example' 34 | f' --model_type={model_type}' 35 | f' --vocab_size={vocab_size}' 36 | f' --user_defined_symbols={user_defined_symbols}' 37 | f' --max_sentencepiece_length={max_sentencepiece_length}' 38 | f' --minloglevel=1' 39 | ) -------------------------------------------------------------------------------- /src/lmflow/pipeline/auto_pipeline.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | """Return a pipeline automatically based on its name. 4 | """ 5 | 6 | from lmflow.pipeline.evaluator import Evaluator 7 | from lmflow.pipeline.finetuner import Finetuner 8 | from lmflow.pipeline.inferencer import Inferencer 9 | from lmflow.pipeline.raft_aligner import RaftAligner 10 | from lmflow.pipeline.raft_aligner_eval import RaftAligner as RaftAlignerEval 11 | from lmflow.pipeline.continual_finetuner import ContinualFinetuner 12 | from lmflow.pipeline.mask_finetuner import MaskFinetuner 13 | 14 | PIPELINE_MAPPING = { 15 | "evaluator": Evaluator, 16 | "continual_finetuner": ContinualFinetuner, 17 | "mask_finetuner":MaskFinetuner, 18 | "finetuner": Finetuner, 19 | "inferencer": Inferencer, 20 | "raft_aligner": RaftAligner, 21 | "raft_aligner_eval": RaftAlignerEval, 22 | } 23 | 24 | 25 | class AutoPipeline: 26 | """ 27 | The class designed to return a pipeline automatically based on its name. 
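    Example (an illustrative sketch following `examples/finetune.py`, where `model_args`,
    `data_args`, and `pipeline_args` are parsed with `HfArgumentParser`):

        finetuner = AutoPipeline.get_pipeline(
            pipeline_name="finetuner",
            model_args=model_args,
            data_args=data_args,
            pipeline_args=pipeline_args,
        )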
28 | """ 29 | @classmethod 30 | def get_pipeline(self, 31 | pipeline_name, 32 | model_args, 33 | data_args, 34 | pipeline_args, 35 | *args, 36 | **kwargs 37 | ): 38 | if pipeline_name not in PIPELINE_MAPPING: 39 | raise NotImplementedError( 40 | f'Pipeline "{pipeline_name}" is not supported' 41 | ) 42 | 43 | pipeline = PIPELINE_MAPPING[pipeline_name]( 44 | model_args, 45 | data_args, 46 | pipeline_args, 47 | *args, 48 | **kwargs 49 | ) 50 | return pipeline 51 | -------------------------------------------------------------------------------- /scripts/data_preprocess/count.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | # Copyright 2023 Statistics and Machine Learning Research Group at HKUST. All rights reserved. 4 | """ 5 | Counts number of instances in a dataset. 6 | """ 7 | from __future__ import absolute_import 8 | 9 | import argparse 10 | import json 11 | import random 12 | import sys 13 | import textwrap 14 | 15 | def parse_argument(sys_argv): 16 | """Parses arguments from command line. 17 | Args: 18 | sys_argv: the list of arguments (strings) from command line. 19 | Returns: 20 | A struct whose member corresponds to the required (optional) variable. 21 | For example, 22 | ``` 23 | args = parse_argument(['main.py' '--input', 'a.txt', '--num', '10']) 24 | args.input # 'a.txt' 25 | args.num # 10 26 | ``` 27 | """ 28 | parser = argparse.ArgumentParser( 29 | formatter_class=argparse.RawTextHelpFormatter) 30 | 31 | # Training parameters 32 | parser.add_argument( 33 | "--dataset_path", type=str, 34 | default=None, 35 | help="input dataset path, reads from stdin by default" 36 | ) 37 | 38 | # Parses from commandline 39 | args = parser.parse_args(sys_argv[1:]) 40 | 41 | return args 42 | 43 | 44 | def main(): 45 | args = parse_argument(sys.argv) 46 | if args.dataset_path is not None: 47 | with open(args.dataset_path, "r") as fin: 48 | data_dict = json.load(fin) 49 | else: 50 | data_dict = json.load(sys.stdin) 51 | 52 | num_instances = len(data_dict["instances"]) 53 | print(num_instances) 54 | 55 | 56 | if __name__ == "__main__": 57 | main() 58 | -------------------------------------------------------------------------------- /scripts/eval_raft/infer_get_samples.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Please run this script under project directory. 3 | port=11233 4 | if [ $# -ge 5 ]; then 5 | port="$5" 6 | fi 7 | echo "port used for get samples:"${port} 8 | deepspeed_args="--master_port=${port} --include localhost:"$4 # Default argument 9 | 10 | exp_id=raft_infer_get_samples 11 | project_dir=$(cd "$(dirname $0)"/..; pwd)/../.. 12 | output_dir=${project_dir}/eval_log/${exp_id}/hh_rlhf 13 | log_dir=${project_dir}/log/${exp_id} 14 | 15 | mkdir -p ${output_dir} ${log_dir} 16 | 17 | # export PYTHONPATH=. 
18 | deepspeed ${deepspeed_args} \ 19 | examples/raft_align_eval.py \ 20 | --model_name_or_path $1 \ 21 | --mode "raft_get_samples" \ 22 | --iter $2 \ 23 | --raft_infer_set $3 \ 24 | --dataset_path ${project_dir}/data/hh_rlhf/rlhf/rlhf_eval \ 25 | --raft_batch_size 99999 \ 26 | --output_min_length 127 \ 27 | --output_max_length 129 \ 28 | --output_temperature 1.0 \ 29 | --top_reward_percentage 1 \ 30 | --inference_batch_size_per_device 100 \ 31 | --learning_rate 1e-5 \ 32 | --lr_scheduler_type "constant" \ 33 | --bf16 \ 34 | --deepspeed configs/ds_config_zero2.json \ 35 | --output_reward_path ${project_dir}/tmp/eval_raft_aligner/reward.txt \ 36 | --output_dir ${output_dir} --overwrite_output_dir \ 37 | --run_name ${exp_id} \ 38 | --num_train_epochs 1 \ 39 | --per_device_train_batch_size 1 \ 40 | --per_device_eval_batch_size 1 \ 41 | --validation_split_percentage 0 \ 42 | --logging_steps 1 \ 43 | --do_train \ 44 | --ddp_timeout 72000 \ 45 | --save_steps 7777 \ 46 | --dataloader_num_workers 12 \ 47 | --preprocessing_num_workers 12 \ 48 | | tee ${log_dir}/raft_align.log \ 49 | 2> ${log_dir}/raft_align.err 50 | -------------------------------------------------------------------------------- /scripts/mask_post/run_mask_finetune_raft.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | 4 | # bash scripts/mask_post/run_mask_finetune_raft.sh 5 | # hyper-parameters 6 | approach=mask_norm_sigmoid_linear 7 | mask_level=layerwise 8 | lr=2e-5 9 | warp_init_val=0.2 10 | reg_alpha=1e-4 11 | sum_reg_type=0.2 12 | epsilon=0.2 13 | gpu_ids="0,1" 14 | 15 | port=29500 16 | 17 | exp_id=${approach}_${mask_level}_ft_raft_${lr}_${warp_init_val}_reg${reg_alpha}_sr${sum_reg_type}_eps${epsilon}_ep1 18 | project_dir=$(cd "$(dirname $0)"/..; pwd)/../.. 19 | log_dir=${project_dir}/log/${exp_id} 20 | model_path=${project_dir}/path/to/after/rlhf 21 | output_dir=${model_path}/mask_opt/${exp_id} 22 | dataset_path=${project_dir}/path/to/data/collected 23 | 24 | if [ ! -d "${output_dir}" ]; 25 | then 26 | mkdir -p ${output_dir} 27 | fi 28 | 29 | if [ ! 
-d "${log_dir}" ]; 30 | then 31 | mkdir -p ${log_dir} 32 | fi 33 | 34 | accelerate launch --main_process_port ${port} --gpu_ids=${gpu_ids} \ 35 | examples/finetune_mask.py \ 36 | --model_name_or_path ${model_path} \ 37 | --dataset_path ${dataset_path} \ 38 | --output_dir ${output_dir} --overwrite_output_dir \ 39 | --num_train_epochs 1 \ 40 | --learning_rate ${lr}\ 41 | --block_size 512 \ 42 | --per_device_train_batch_size 2 \ 43 | --bf16 \ 44 | --run_name ${exp_id} \ 45 | --validation_split_percentage 0 \ 46 | --logging_steps 20 \ 47 | --warp_init_val ${warp_init_val} \ 48 | --approach ${approach} \ 49 | --reg_alpha ${reg_alpha} \ 50 | --sum_reg_type ${sum_reg_type} \ 51 | --epsilon ${epsilon} \ 52 | --mask_level ${mask_level} \ 53 | --do_train \ 54 | --ddp_timeout 72000 \ 55 | --save_steps 1000 \ 56 | --dataloader_num_workers 4 \ 57 | | tee ${log_dir}/train.log \ 58 | 2> ${log_dir}/train.err 59 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | from setuptools import find_packages 3 | from setuptools import setup 4 | import subprocess 5 | 6 | folder = os.path.dirname(__file__) 7 | version_path = os.path.join(folder, "src", "lmflow", "version.py") 8 | 9 | __version__ = None 10 | with open(version_path) as f: 11 | exec(f.read(), globals()) 12 | 13 | req_path = os.path.join(folder, "requirements.txt") 14 | install_requires = [] 15 | if os.path.exists(req_path): 16 | with open(req_path) as fp: 17 | install_requires = [line.strip() for line in fp] 18 | 19 | readme_path = os.path.join(folder, "README.md") 20 | readme_contents = "" 21 | if os.path.exists(readme_path): 22 | with open(readme_path, encoding='utf-8') as fp: 23 | readme_contents = fp.read().strip() 24 | 25 | setup( 26 | name="lmflow", 27 | version=__version__, 28 | description="LMFlow: Large Model Flow.", 29 | author="The LMFlow Team", 30 | long_description=readme_contents, 31 | long_description_content_type="text/markdown", 32 | package_dir={"": "src"}, 33 | packages=find_packages("src"), 34 | package_data={}, 35 | install_requires=install_requires, 36 | classifiers=[ 37 | "Intended Audience :: Science/Research/Engineering", 38 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 39 | "Programming Language :: Python :: 3.9", 40 | "Programming Language :: Python :: 3.10", 41 | ], 42 | requires_python=">=3.9", 43 | ) 44 | 45 | # Must be called after all dependency installed, since flash-attn setup.py 46 | # relies on torch, packaging, etc. 47 | try: 48 | gpu_state = subprocess.check_output(["nvidia-smi", "--query-gpu=name", "--format=csv,noheader"]) 49 | if b"A100" or b"A40" in gpu_state: 50 | subprocess.call(["pip", "install", "flash-attn==2.0.2"]) 51 | except: 52 | pass 53 | -------------------------------------------------------------------------------- /scripts/eval_raft/infer_get_rewards.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Please run this script under project directory. 3 | port=11033 4 | if [ $# -ge 6 ]; then 5 | port="$6" 6 | fi 7 | echo "port used for get rewards:"${port} 8 | deepspeed_args="--master_port=${port} --include localhost:"$5 # Default argument 9 | 10 | exp_id=test_infer_reward 11 | project_dir=$(cd "$(dirname $0)"/..; pwd)/../.. 
12 | output_dir=${project_dir}/eval_log/${exp_id}/hh_rlhf 13 | log_dir=${project_dir}/log/${exp_id} 14 | 15 | mkdir -p ${output_dir} ${log_dir} 16 | # export PYTHONPATH=. 17 | # Just read a model 18 | deepspeed ${deepspeed_args} \ 19 | examples/raft_align_eval.py \ 20 | --model_name_or_path $4 \ 21 | --raft_infer_set $1 \ 22 | --raft_filtered_set $2 \ 23 | --raft_exp_dir $3 \ 24 | --reward_model_or_path $4 \ 25 | --mode "raft_get_rewards" \ 26 | --num_raft_iteration 999 \ 27 | --learning_rate 1e-5 \ 28 | --lr_scheduler_type "constant" \ 29 | --bf16 \ 30 | --deepspeed configs/ds_config_zero2.json \ 31 | --dataset_path $1 \ 32 | --output_reward_path ${project_dir}/tmp/eval_raft_aligner/reward.txt \ 33 | --output_dir ${output_dir} --overwrite_output_dir \ 34 | --run_name ${exp_id} \ 35 | --num_train_epochs 1 \ 36 | --per_device_train_batch_size 1 \ 37 | --per_device_eval_batch_size 1 \ 38 | --validation_split_percentage 0 \ 39 | --logging_steps 1 \ 40 | --do_train \ 41 | --ddp_timeout 72000 \ 42 | --save_steps 7777 \ 43 | --dataloader_num_workers 12 \ 44 | --preprocessing_num_workers 12 \ 45 | --inference_batch_size_per_device 100 \ 46 | --collection_strategy "top" \ 47 | --raft_batch_size 999999 \ 48 | --output_min_length 128 \ 49 | --output_max_length 196 \ 50 | --top_reward_percentage 1 \ 51 | | tee ${log_dir}/raft_align.log \ 52 | 2> ${log_dir}/raft_align.err 53 | -------------------------------------------------------------------------------- /utils/make_delta.py: -------------------------------------------------------------------------------- 1 | """ 2 | Make the delta weights by subtracting base weights. 3 | 4 | Usage: 5 | python3 -m fastchat.model.make_delta --base ~/model_weights/llama-13b --target ~/model_weights/vicuna-13b --delta ~/model_weights/vicuna-13b-delta --hub-repo-id lmsys/vicuna-13b-delta-v1.1 6 | """ 7 | import argparse 8 | 9 | import torch 10 | from tqdm import tqdm 11 | from transformers import AutoTokenizer, AutoModelForCausalLM 12 | 13 | 14 | def make_delta(base_model_path, target_model_path, delta_path): 15 | print(f"Loading the base model from {base_model_path}") 16 | base = AutoModelForCausalLM.from_pretrained( 17 | base_model_path, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True 18 | ) 19 | 20 | print(f"Loading the target model from {target_model_path}") 21 | target = AutoModelForCausalLM.from_pretrained( 22 | target_model_path, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True 23 | ) 24 | target_tokenizer = AutoTokenizer.from_pretrained(target_model_path, use_fast=False) 25 | 26 | print("Calculating the delta") 27 | for name, param in tqdm(target.state_dict().items(), desc="Calculating delta"): 28 | assert name in base.state_dict() 29 | param.data -= base.state_dict()[name] 30 | 31 | print(f"Saving the delta to {delta_path}") 32 | if args.hub_repo_id: 33 | kwargs = {"push_to_hub": True, "repo_id": args.hub_repo_id} 34 | else: 35 | kwargs = {} 36 | target.save_pretrained(delta_path, **kwargs) 37 | target_tokenizer.save_pretrained(delta_path, **kwargs) 38 | 39 | 40 | if __name__ == "__main__": 41 | parser = argparse.ArgumentParser() 42 | parser.add_argument("--base-model-path", type=str, required=True) 43 | parser.add_argument("--target-model-path", type=str, required=True) 44 | parser.add_argument("--delta-path", type=str, required=True) 45 | parser.add_argument("--hub-repo-id", type=str) 46 | args = parser.parse_args() 47 | 48 | make_delta(args.base_model_path, args.target_model_path, args.delta_path) 49 | 
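# Example invocation using the argparse flags defined above (paths are placeholders):
#   python utils/make_delta.py \
#       --base-model-path ~/model_weights/llama-13b \
#       --target-model-path ~/model_weights/vicuna-13b \
#       --delta-path ~/model_weights/vicuna-13b-delta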
-------------------------------------------------------------------------------- /examples/finetune.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | # Copyright 2023 Statistics and Machine Learning Research Group at HKUST. All rights reserved. 4 | """A one-line summary of the module or program, terminated by a period. 5 | 6 | Leave one blank line. The rest of this docstring should contain an 7 | overall description of the module or program. Optionally, it may also 8 | contain a brief description of exported classes and functions and/or usage 9 | examples. 10 | 11 | Typical usage example: 12 | 13 | foo = ClassFoo() 14 | bar = foo.FunctionBar() 15 | """ 16 | 17 | import sys 18 | import os 19 | sys.path.remove(os.path.abspath(os.path.dirname(sys.argv[0]))) 20 | from transformers import HfArgumentParser 21 | 22 | from lmflow.args import ( 23 | ModelArguments, 24 | DatasetArguments, 25 | AutoArguments, 26 | ) 27 | 28 | from lmflow.datasets.dataset import Dataset 29 | from lmflow.models.auto_model import AutoModel 30 | from lmflow.pipeline.auto_pipeline import AutoPipeline 31 | 32 | 33 | def main(): 34 | # Parses arguments 35 | pipeline_name = "finetuner" 36 | PipelineArguments = AutoArguments.get_pipeline_args_class(pipeline_name) 37 | 38 | parser = HfArgumentParser((ModelArguments, DatasetArguments, PipelineArguments)) 39 | if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): 40 | # If we pass only one argument to the script and it's the path to a json file, 41 | # let's parse it to get our arguments. 42 | model_args, data_args, pipeline_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) 43 | else: 44 | model_args, data_args, pipeline_args = parser.parse_args_into_dataclasses() 45 | 46 | # Initialization 47 | finetuner = AutoPipeline.get_pipeline( 48 | pipeline_name=pipeline_name, 49 | model_args=model_args, 50 | data_args=data_args, 51 | pipeline_args=pipeline_args, 52 | ) 53 | dataset = Dataset(data_args) 54 | model = AutoModel.get_model(model_args) 55 | 56 | # Finetuning 57 | tuned_model = finetuner.tune(model=model, dataset=dataset) 58 | 59 | 60 | if __name__ == '__main__': 61 | main() 62 | -------------------------------------------------------------------------------- /scripts/data_preprocess/shuffle.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | # Copyright 2023 Statistics and Machine Learning Research Group at HKUST. All rights reserved. 4 | """ 5 | Samples a certain ratio of instances from a dataset. 6 | """ 7 | from __future__ import absolute_import 8 | 9 | import argparse 10 | import json 11 | import random 12 | import sys 13 | import textwrap 14 | 15 | def parse_argument(sys_argv): 16 | """Parses arguments from command line. 17 | Args: 18 | sys_argv: the list of arguments (strings) from command line. 19 | Returns: 20 | A struct whose member corresponds to the required (optional) variable. 
21 | For example, 22 | ``` 23 | args = parse_argument(['main.py' '--input', 'a.txt', '--num', '10']) 24 | args.input # 'a.txt' 25 | args.num # 10 26 | ``` 27 | """ 28 | parser = argparse.ArgumentParser( 29 | formatter_class=argparse.RawTextHelpFormatter) 30 | 31 | # Training parameters 32 | parser.add_argument( 33 | "--dataset_path", type=str, 34 | default=None, 35 | help="input dataset path, reads from stdin by default" 36 | ) 37 | parser.add_argument( 38 | "--output_path", type=str, 39 | default=None, 40 | help="output dataset path, writes to stdout by default" 41 | ) 42 | parser.add_argument( 43 | "--seed", type=int, default=42, 44 | help="pseudorandom seed" 45 | ) 46 | 47 | # Parses from commandline 48 | args = parser.parse_args(sys_argv[1:]) 49 | 50 | return args 51 | 52 | 53 | def main(): 54 | args = parse_argument(sys.argv) 55 | if args.dataset_path is not None: 56 | with open(args.dataset_path, "r") as fin: 57 | data_dict = json.load(fin) 58 | else: 59 | data_dict = json.load(sys.stdin) 60 | 61 | random.seed(args.seed) 62 | random.shuffle(data_dict["instances"]) 63 | 64 | if args.output_path is not None: 65 | with open(args.output_path, "w") as fout: 66 | json.dump(data_dict, fout, indent=4, ensure_ascii=False) 67 | else: 68 | json.dump(data_dict, sys.stdout, indent=4, ensure_ascii=False) 69 | 70 | 71 | if __name__ == "__main__": 72 | main() 73 | -------------------------------------------------------------------------------- /examples/sample_data.py: -------------------------------------------------------------------------------- 1 | from datasets import load_dataset 2 | import torch 3 | from transformers import ( 4 | default_data_collator, 5 | pipeline, 6 | set_seed, 7 | AutoTokenizer 8 | ) 9 | import sys 10 | 11 | import numpy as np 12 | import matplotlib.pyplot as plt 13 | import random 14 | import time 15 | 16 | import re 17 | 18 | 19 | 20 | data_files = [ 21 | #"/home/xiongwei/rm_study/LMFlow/data/open_llama_7b_replay/10/raft_19iter.json", 22 | # "/home/xiongwei/over_opt/LMFlow_RAFT_Dev/output_models/replay_exp/data/pretrained/sub_sampled/togethercomputer_train_100M_v3.jsonl", 23 | # "/home/jianmeng/forgetting_data/v4/togethercomputer_train_200M_2048-4096.jsonl", 24 | # "/home/jianmeng/forgetting_data/v4/togethercomputer_train_200M_4096-9192.jsonl"], 25 | # "/home/xiongwei/over_opt/LMFlow_RAFT_Dev/output_models/replay_exp/data/raft_3b/raft_data_16M_tokens.json" 26 | # "/home/linhangyu/Projects/LLMs/LMFlow_RAFT_Dev/data/mixture_exp_3b/10/raft_data_26624.json" 27 | "/home/jianmeng/linhangyu/Projects/LLMs/LMFlow_RAFT_Dev/data/10w_sharegpt/sharegpt_en_10w.json" 28 | ] 29 | 30 | all_texts = [] 31 | raft = load_dataset("json", data_files=data_files[0], split="train", field="instances") 32 | 33 | 34 | tokenizer = AutoTokenizer.from_pretrained("/home/jianmeng/linhangyu/Projects/LLMs/LMFlow_RAFT_Dev/output_models/sft_open_llama_3b_1epoch_plus_hh_rlhf_1epoch") 35 | 36 | 37 | def tokenize(sample): 38 | sample["input_ids"] = tokenizer.encode(sample["text"]) 39 | sample["query"] = sample['text'] #tokenizer.decode(sample["input_ids"]) 40 | return sample 41 | 42 | 43 | 44 | raft = raft.map(tokenize, batched=False) 45 | 46 | print(len(raft)) 47 | all_raft_tokens = np.sum([len(sample['input_ids']) for sample in raft]) 48 | print(all_raft_tokens) 49 | 50 | import random 51 | import json 52 | all_texts = [] 53 | all_texts.extend(raft['text']) 54 | random.shuffle(all_texts) 55 | all_texts = all_texts[:(len(all_texts)//2)] 56 | store_texts = [{"text":txt} for txt in all_texts] 57 | 58 | 
output_dataset = {} 59 | output_dataset['type'] = "text_only" 60 | output_dataset['instances'] = store_texts 61 | with open("/home/jianmeng/linhangyu/Projects/LLMs/LMFlow_RAFT_Dev/data/5w_sharegpt/5w_sharegpt.json", 'w', encoding='utf8') as f: 62 | json.dump(output_dataset, f, ensure_ascii=False) -------------------------------------------------------------------------------- /scripts/data_preprocess/sample.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | # Copyright 2023 Statistics and Machine Learning Research Group at HKUST. All rights reserved. 4 | """ 5 | Samples a certain ratio of instances from a dataset. 6 | """ 7 | from __future__ import absolute_import 8 | 9 | import argparse 10 | import json 11 | import random 12 | import sys 13 | import textwrap 14 | 15 | def parse_argument(sys_argv): 16 | """Parses arguments from command line. 17 | Args: 18 | sys_argv: the list of arguments (strings) from command line. 19 | Returns: 20 | A struct whose member corresponds to the required (optional) variable. 21 | For example, 22 | ``` 23 | args = parse_argument(['main.py' '--input', 'a.txt', '--num', '10']) 24 | args.input # 'a.txt' 25 | args.num # 10 26 | ``` 27 | """ 28 | parser = argparse.ArgumentParser( 29 | formatter_class=argparse.RawTextHelpFormatter) 30 | 31 | # Training parameters 32 | parser.add_argument( 33 | "--dataset_path", type=str, 34 | default=None, 35 | help="input dataset path, reads from stdin by default" 36 | ) 37 | parser.add_argument( 38 | "--output_path", type=str, 39 | default=None, 40 | help="output dataset path, writes to stdout by default" 41 | ) 42 | parser.add_argument( 43 | "--ratio", type=float, required=True, 44 | help="sample ratio, will be floored if number of samples is not a int" 45 | ) 46 | parser.add_argument( 47 | "--seed", type=int, default=42, 48 | help="pseudorandom seed" 49 | ) 50 | 51 | # Parses from commandline 52 | args = parser.parse_args(sys_argv[1:]) 53 | 54 | return args 55 | 56 | 57 | def main(): 58 | args = parse_argument(sys.argv) 59 | if args.dataset_path is not None: 60 | with open(args.dataset_path, "r") as fin: 61 | data_dict = json.load(fin) 62 | else: 63 | data_dict = json.load(sys.stdin) 64 | 65 | random.seed(args.seed) 66 | num_instances = len(data_dict["instances"]) 67 | num_sample = int(num_instances * args.ratio) 68 | 69 | data_dict["instances"] = random.sample(data_dict["instances"], num_sample) 70 | 71 | if args.output_path is not None: 72 | with open(args.output_path, "w") as fout: 73 | json.dump(data_dict, fout, indent=4, ensure_ascii=False) 74 | else: 75 | json.dump(data_dict, sys.stdout, indent=4, ensure_ascii=False) 76 | 77 | 78 | if __name__ == "__main__": 79 | main() 80 | -------------------------------------------------------------------------------- /examples/merge_lora.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | # Copyright 2023 Statistics and Machine Learning Research Group at HKUST. All rights reserved. 4 | """ 5 | Merge base model and lora model into a full model. 
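Typical usage example (a sketch only: the flag names follow MergeLoraArguments below and LMFlow's ModelArguments, and every path is illustrative):

    python examples/merge_lora.py \
        --model_name_or_path path/to/base_or_sft_model \
        --lora_model_path path/to/lora_adapter \
        --output_model_path path/to/merged_model \
        --deepspeed path/to/ds_config_eval.json

Merging folds the low-rank update back into the base weight matrices (W <- W + scaling * B A in PEFT), so the saved model can be loaded and served without any LoRA-specific code.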
6 | """ 7 | 8 | import sys 9 | import os 10 | sys.path.remove(os.path.abspath(os.path.dirname(sys.argv[0]))) 11 | 12 | from dataclasses import dataclass, field 13 | from transformers import HfArgumentParser 14 | from typing import Optional 15 | import json 16 | from lmflow.args import ( 17 | ModelArguments, 18 | AutoArguments, 19 | ) 20 | 21 | from lmflow.models.auto_model import AutoModel 22 | 23 | 24 | @dataclass 25 | class MergeLoraArguments: 26 | output_model_path: Optional[str] = field( 27 | default=None, 28 | metadata={ 29 | "help": "output merged full model path" 30 | }, 31 | ) 32 | 33 | 34 | def main(): 35 | pipeline_name = "evaluator" 36 | PipelineArguments = AutoArguments.get_pipeline_args_class(pipeline_name) 37 | 38 | parser = HfArgumentParser((ModelArguments, MergeLoraArguments, PipelineArguments)) 39 | if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): 40 | model_args, merge_lora_args, pipeline_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) 41 | else: 42 | model_args, merge_lora_args, pipeline_args = parser.parse_args_into_dataclasses() 43 | 44 | with open (pipeline_args.deepspeed, "r") as f: 45 | ds_config = json.load(f) 46 | 47 | model_args.use_lora = True 48 | model = AutoModel.get_model(model_args, 49 | tune_strategy='none', 50 | ds_config=ds_config) 51 | loraA_key = "base_model.model.model.layers.3.self_attn.q_proj.lora_A.weight" 52 | loraB_key = "base_model.model.model.layers.3.self_attn.q_proj.lora_B.weight" 53 | key ="base_model.model.model.layers.3.self_attn.q_proj.weight" 54 | # print(loraA_key, model.get_backend_model().state_dict()[loraA_key]) 55 | # print(loraB_key, model.get_backend_model().state_dict()[loraB_key]) 56 | # print(key, model.get_backend_model().state_dict()[key]) 57 | model.merge_lora_weights() 58 | # after_key = "model.layers.3.self_attn.q_proj.weight" 59 | # print(list(model.get_backend_model().state_dict())) 60 | # print(after_key, model.backend_model_full.state_dict()[after_key]) 61 | # print(after_key, model.get_backend_model().state_dict()[key]) 62 | model.save(merge_lora_args.output_model_path, save_full_model=True) 63 | 64 | 65 | if __name__ == '__main__': 66 | main() 67 | -------------------------------------------------------------------------------- /scripts/data_preprocess/add_prompt.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | # Copyright 2023 Statistics and Machine Learning Research Group at HKUST. All rights reserved. 4 | """ 5 | Adds prompt structure to a text2text dataset. 6 | """ 7 | from __future__ import absolute_import 8 | 9 | import argparse 10 | import json 11 | import textwrap 12 | import sys 13 | 14 | def parse_argument(sys_argv): 15 | """Parses arguments from command line. 16 | Args: 17 | sys_argv: the list of arguments (strings) from command line. 18 | Returns: 19 | A struct whose member corresponds to the required (optional) variable. 
20 | For example, 21 | ``` 22 | args = parse_argument(['main.py' '--input', 'a.txt', '--num', '10']) 23 | args.input # 'a.txt' 24 | args.num # 10 25 | ``` 26 | """ 27 | parser = argparse.ArgumentParser( 28 | formatter_class=argparse.RawTextHelpFormatter) 29 | 30 | # Training parameters 31 | parser.add_argument( 32 | "--dataset_path", type=str, 33 | default=None, 34 | help=textwrap.dedent("input dataset path, reads from stdin by default") 35 | ) 36 | parser.add_argument( 37 | "--output_path", type=str, 38 | default=None, 39 | help=textwrap.dedent("output dataset path, writes to stdout by default") 40 | ) 41 | parser.add_argument( 42 | "--prompt_structure", type=str, 43 | default="{input}", 44 | help=textwrap.dedent("prompt structure to augment input") 45 | ) 46 | 47 | # Parses from commandline 48 | args = parser.parse_args(sys_argv[1:]) 49 | 50 | return args 51 | 52 | 53 | def main(): 54 | args = parse_argument(sys.argv) 55 | if args.dataset_path is not None: 56 | with open(args.dataset_path, "r") as fin: 57 | data_dict = json.load(fin) 58 | else: 59 | data_dict = json.load(sys.stdin) 60 | 61 | if data_dict["type"] != "text2text": 62 | raise NotImplementedError( 63 | "only support text2text prompt augmentation" 64 | ) 65 | 66 | data_dict["instances"] = [ 67 | { 68 | "input": args.prompt_structure.format(input=instance["input"]), 69 | "output": instance["output"], 70 | } 71 | for instance in data_dict["instances"] 72 | ] 73 | if args.output_path is not None: 74 | with open(args.output_path, "w") as fout: 75 | json.dump(data_dict, fout, indent=4, ensure_ascii=False) 76 | else: 77 | json.dump(data_dict, sys.stdout, indent=4, ensure_ascii=False) 78 | 79 | 80 | if __name__ == "__main__": 81 | main() 82 | -------------------------------------------------------------------------------- /scripts/data_preprocess/concat.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | # Copyright 2023 Statistics and Machine Learning Research Group at HKUST. All rights reserved. 4 | """ 5 | Merges an extra dataset into current dataset. 6 | """ 7 | from __future__ import absolute_import 8 | 9 | import argparse 10 | import json 11 | import textwrap 12 | import sys 13 | 14 | def parse_argument(sys_argv): 15 | """Parses arguments from command line. 16 | Args: 17 | sys_argv: the list of arguments (strings) from command line. 18 | Returns: 19 | A struct whose member corresponds to the required (optional) variable. 
20 | For example, 21 | ``` 22 | args = parse_argument(['main.py' '--input', 'a.txt', '--num', '10']) 23 | args.input # 'a.txt' 24 | args.num # 10 25 | ``` 26 | """ 27 | parser = argparse.ArgumentParser( 28 | formatter_class=argparse.RawTextHelpFormatter) 29 | 30 | # Training parameters 31 | parser.add_argument( 32 | "--output_path", type=str, 33 | default=None, 34 | help=textwrap.dedent("output dataset path, writes to stdout by default") 35 | ) 36 | parser.add_argument( 37 | "--merge_from_path", type=str, 38 | nargs="+", 39 | help=textwrap.dedent( 40 | "dataset path of the extra dataset that will be merged" 41 | " into input dataset" 42 | ) 43 | ) 44 | 45 | # Parses from commandline 46 | args = parser.parse_args(sys_argv[1:]) 47 | 48 | return args 49 | 50 | 51 | def main(): 52 | args = parse_argument(sys.argv) 53 | 54 | if args.merge_from_path is not None: 55 | for i in range(0, len(args.merge_from_path)): 56 | with open(args.merge_from_path[i], "r") as fin: 57 | extra_data_dict = json.load(fin) 58 | if i == 0: 59 | data_dict = extra_data_dict 60 | else: 61 | if data_dict["type"] != extra_data_dict["type"]: 62 | raise ValueError( 63 | 'two dataset have different types:' 64 | f' input dataset: "{data_dict["type"]}";' 65 | f' merge from dataset: "{extra_data_dict["type"]}"' 66 | ) 67 | data_dict["instances"].extend(extra_data_dict["instances"]) 68 | else: 69 | raise ValueError("No merge files specified") 70 | 71 | if args.output_path is not None: 72 | with open(args.output_path, "w") as fout: 73 | json.dump(data_dict, fout, indent=4, ensure_ascii=False) 74 | else: 75 | json.dump(data_dict, sys.stdout, indent=4, ensure_ascii=False) 76 | 77 | 78 | if __name__ == "__main__": 79 | main() 80 | -------------------------------------------------------------------------------- /scripts/data_preprocess/add_end_mark.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | # Copyright 2023 Statistics and Machine Learning Research Group at HKUST. All rights reserved. 4 | """ 5 | Adds prompt structure to a text2text dataset. 6 | """ 7 | from __future__ import absolute_import 8 | 9 | import argparse 10 | import json 11 | import textwrap 12 | import sys 13 | 14 | def parse_argument(sys_argv): 15 | """Parses arguments from command line. 16 | Args: 17 | sys_argv: the list of arguments (strings) from command line. 18 | Returns: 19 | A struct whose member corresponds to the required (optional) variable. 
20 | For example, 21 | ``` 22 | args = parse_argument(['main.py' '--input', 'a.txt', '--num', '10']) 23 | args.input # 'a.txt' 24 | args.num # 10 25 | ``` 26 | """ 27 | parser = argparse.ArgumentParser( 28 | formatter_class=argparse.RawTextHelpFormatter) 29 | 30 | # Training parameters 31 | parser.add_argument( 32 | "--dataset_path", type=str, 33 | default=None, 34 | help=textwrap.dedent("input dataset path, reads from stdin by default") 35 | ) 36 | parser.add_argument( 37 | "--output_path", type=str, 38 | default=None, 39 | help=textwrap.dedent("output dataset path, writes to stdout by default") 40 | ) 41 | parser.add_argument( 42 | "--end_mark", type=str, 43 | default="###", 44 | help=textwrap.dedent("end mark that append to the end of output") 45 | ) 46 | 47 | # Parses from commandline 48 | args = parser.parse_args(sys_argv[1:]) 49 | 50 | return args 51 | 52 | 53 | def main(): 54 | args = parse_argument(sys.argv) 55 | if args.dataset_path is not None: 56 | with open(args.dataset_path, "r") as fin: 57 | data_dict = json.load(fin) 58 | else: 59 | data_dict = json.load(sys.stdin) 60 | 61 | output_field_map = { 62 | "text_only": "text", 63 | "text2text": "output", 64 | } 65 | data_dict_type = data_dict["type"] 66 | if not data_dict_type in output_field_map: 67 | raise NotImplementedError( 68 | "only support text_only or text2text dataset" 69 | ) 70 | 71 | output_field = output_field_map[data_dict_type] 72 | 73 | num_instances = len(data_dict["instances"]) 74 | for i in range(num_instances): 75 | data_dict["instances"][i][output_field] += args.end_mark 76 | 77 | if args.output_path is not None: 78 | with open(args.output_path, "w") as fout: 79 | json.dump(data_dict, fout, indent=4, ensure_ascii=False) 80 | else: 81 | json.dump(data_dict, sys.stdout, indent=4, ensure_ascii=False) 82 | 83 | 84 | if __name__ == "__main__": 85 | main() 86 | -------------------------------------------------------------------------------- /src/lmflow/utils/position_interpolation/llama_rope_scaled_monkey_patch.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | 3 | import torch 4 | import transformers 5 | import transformers.models.llama.modeling_llama 6 | 7 | class CondenseRotaryEmbedding(torch.nn.Module): 8 | def __init__(self, dim, pi_ratio, ntk_ratio, max_position_embeddings=2048, base=10000, device=None): 9 | super().__init__() 10 | 11 | self.ntk_ratio = ntk_ratio 12 | max_position_embeddings *= ntk_ratio 13 | base = base * ntk_ratio ** (dim / (dim-2)) #Base change formula 14 | 15 | inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim)) 16 | self.register_buffer("inv_freq", inv_freq) 17 | 18 | self.pi_ratio = pi_ratio 19 | max_position_embeddings *= pi_ratio 20 | self.max_seq_len_cached = max_position_embeddings 21 | t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype) / pi_ratio 22 | freqs = torch.einsum("i,j->ij", t, self.inv_freq) 23 | 24 | # Different from paper, but it uses a different permutation in order to obtain the same calculation 25 | emb = torch.cat((freqs, freqs), dim=-1) 26 | dtype = torch.get_default_dtype() 27 | self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False) 28 | self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False) 29 | 30 | def forward(self, x, seq_len=None): 31 | # x: [bs, num_attention_heads, seq_len, head_size] 32 | # This `if` block is unlikely to be run after we build 
sin/cos in `__init__`. Keep the logic here just in case. 33 | if seq_len > self.max_seq_len_cached: 34 | self.max_seq_len_cached = seq_len 35 | t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype) / self.pi_ratio 36 | freqs = torch.einsum("i,j->ij", t, self.inv_freq) 37 | 38 | # Different from paper, but it uses a different permutation in order to obtain the same calculation 39 | emb = torch.cat((freqs, freqs), dim=-1).to(x.device) 40 | self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(x.dtype), persistent=False) 41 | self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(x.dtype), persistent=False) 42 | 43 | return ( 44 | self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype), 45 | self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype), 46 | ) 47 | 48 | def replace_llama_with_condense(pi_ratio, ntk_ratio): 49 | transformers.models.llama.modeling_llama.LlamaRotaryEmbedding = partial(CondenseRotaryEmbedding, pi_ratio=pi_ratio, ntk_ratio=ntk_ratio) -------------------------------------------------------------------------------- /scripts/data_preprocess/merge.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | # Copyright 2023 Statistics and Machine Learning Research Group at HKUST. All rights reserved. 4 | """ 5 | Merges an extra dataset into current dataset. 6 | """ 7 | from __future__ import absolute_import 8 | 9 | import argparse 10 | import json 11 | import textwrap 12 | import sys 13 | 14 | def parse_argument(sys_argv): 15 | """Parses arguments from command line. 16 | Args: 17 | sys_argv: the list of arguments (strings) from command line. 18 | Returns: 19 | A struct whose member corresponds to the required (optional) variable. 
20 | For example, 21 | ``` 22 | args = parse_argument(['main.py' '--input', 'a.txt', '--num', '10']) 23 | args.input # 'a.txt' 24 | args.num # 10 25 | ``` 26 | """ 27 | parser = argparse.ArgumentParser( 28 | formatter_class=argparse.RawTextHelpFormatter) 29 | 30 | parser.add_argument( 31 | "--dataset_path", type=str, 32 | default=None, 33 | help=textwrap.dedent("input dataset path, reads from stdin by default") 34 | ) 35 | # Training parameters 36 | parser.add_argument( 37 | "--output_path", type=str, 38 | default=None, 39 | help=textwrap.dedent("output dataset path, writes to stdout by default") 40 | ) 41 | parser.add_argument( 42 | "--merge_from_path", type=str, 43 | nargs="+", 44 | help=textwrap.dedent( 45 | "dataset path of the extra dataset that will be merged" 46 | " into input dataset" 47 | ) 48 | ) 49 | 50 | # Parses from commandline 51 | args = parser.parse_args(sys_argv[1:]) 52 | 53 | return args 54 | 55 | 56 | def main(): 57 | args = parse_argument(sys.argv) 58 | 59 | if args.dataset_path is not None: 60 | with open(args.dataset_path, "r") as fin: 61 | data_dict = json.load(fin) 62 | else: 63 | data_dict = json.load(sys.stdin) 64 | 65 | if args.merge_from_path is not None: 66 | for i in range(0, len(args.merge_from_path)): 67 | with open(args.merge_from_path[i], "r") as fin: 68 | extra_data_dict = json.load(fin) 69 | 70 | if data_dict["type"] != extra_data_dict["type"]: 71 | raise ValueError( 72 | 'two dataset have different types:' 73 | f' input dataset: "{data_dict["type"]}";' 74 | f' merge from dataset: "{extra_data_dict["type"]}"' 75 | ) 76 | data_dict["instances"].extend(extra_data_dict["instances"]) 77 | 78 | 79 | if args.output_path is not None: 80 | with open(args.output_path, "w") as fout: 81 | json.dump(data_dict, fout, indent=4, ensure_ascii=False) 82 | else: 83 | json.dump(data_dict, sys.stdout, indent=4, ensure_ascii=False) 84 | 85 | 86 | if __name__ == "__main__": 87 | main() 88 | -------------------------------------------------------------------------------- /src/lmflow/pipeline/utils/peft_trainer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | """Trainer for Peft models 4 | """ 5 | 6 | from __future__ import absolute_import 7 | from transformers import Trainer 8 | from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR 9 | from transformers.trainer_callback import ( 10 | TrainerCallback, 11 | TrainerControl, 12 | TrainerState, 13 | ) 14 | from transformers.training_args import TrainingArguments 15 | import os 16 | import numpy as np 17 | 18 | class PeftTrainer(Trainer): 19 | def _save_checkpoint(self, _, trial, metrics=None): 20 | """ Don't save base model, optimizer etc. 
21 | but create checkpoint folder (needed for saving adapter) """ 22 | checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}" 23 | 24 | run_dir = self._get_output_dir(trial=trial) 25 | output_dir = os.path.join(run_dir, checkpoint_folder) 26 | 27 | if metrics is not None and self.args.metric_for_best_model is not None: 28 | metric_to_check = self.args.metric_for_best_model 29 | if not metric_to_check.startswith("eval_"): 30 | metric_to_check = f"eval_{metric_to_check}" 31 | metric_value = metrics[metric_to_check] 32 | 33 | operator = np.greater if self.args.greater_is_better else np.less 34 | if (self.state.best_metric is None or self.state.best_model_checkpoint is None 35 | or operator(metric_value, self.state.best_metric)): 36 | self.state.best_metric = metric_value 37 | 38 | self.state.best_model_checkpoint = output_dir 39 | 40 | os.makedirs(output_dir, exist_ok=True) 41 | 42 | if self.args.should_save: 43 | self._rotate_checkpoints(use_mtime=True, output_dir=run_dir) 44 | 45 | class PeftSavingCallback(TrainerCallback): 46 | """ Correctly save PEFT model and not full model """ 47 | def _save(self, model, folder): 48 | if folder is None: 49 | folder = "" 50 | peft_model_path = os.path.join(folder, "adapter_model") 51 | model.save_pretrained(peft_model_path) 52 | 53 | def on_train_end(self, args: TrainingArguments, state: TrainerState, 54 | control: TrainerControl, **kwargs): 55 | """ Save final best model adapter """ 56 | self._save(kwargs['model'], state.best_model_checkpoint) 57 | 58 | def on_epoch_end(self, args: TrainingArguments, state: TrainerState, 59 | control: TrainerControl, **kwargs): 60 | """ Save intermediate model adapters in case of interrupted training """ 61 | folder = os.path.join(args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}") 62 | self._save(kwargs['model'], folder) 63 | 64 | def on_save( 65 | self, 66 | args: TrainingArguments, 67 | state: TrainerState, 68 | control: TrainerControl, 69 | **kwargs, 70 | ): 71 | checkpoint_folder = os.path.join( 72 | args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}" 73 | ) 74 | self._save(kwargs['model'], checkpoint_folder) 75 | 76 | peft_model_path = os.path.join(checkpoint_folder, "adapter_model") 77 | kwargs["model"].save_pretrained(peft_model_path) 78 | return control -------------------------------------------------------------------------------- /examples/finetune_mask.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | # Copyright 2023 Statistics and Machine Learning Research Group at HKUST. All rights reserved. 4 | """A one-line summary of the module or program, terminated by a period. 5 | 6 | Leave one blank line. The rest of this docstring should contain an 7 | overall description of the module or program. Optionally, it may also 8 | contain a brief description of exported classes and functions and/or usage 9 | examples. 
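In this repository the script drives the "mask_finetuner" pipeline: it loads a reference checkpoint as model0 (an SFT model whose path is hard-coded further down) together with the model to be tuned as model1, and calls finetuner.tune(model0=model0, model1=model1, dataset=dataset). Presumably the trainer uses model0 to decide which parameters of model1 may be updated (a mask against the reference); the exact rule lives in the mask_finetuner pipeline implementation, not in this script.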
10 | 11 | Typical usage example: 12 | 13 | foo = ClassFoo() 14 | bar = foo.FunctionBar() 15 | """ 16 | 17 | import sys 18 | import os 19 | sys.path.remove(os.path.abspath(os.path.dirname(sys.argv[0]))) 20 | from transformers import HfArgumentParser 21 | 22 | from lmflow.args import ( 23 | ModelArguments, 24 | DatasetArguments, 25 | AutoArguments, 26 | ) 27 | 28 | from lmflow.datasets.dataset import Dataset 29 | from lmflow.models.auto_model import AutoModel 30 | from lmflow.pipeline.auto_pipeline import AutoPipeline 31 | 32 | 33 | def main(): 34 | # Parses arguments 35 | pipeline_name = "mask_finetuner" 36 | PipelineArguments = AutoArguments.get_pipeline_args_class(pipeline_name) 37 | 38 | parser = HfArgumentParser((ModelArguments, DatasetArguments, PipelineArguments)) 39 | if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): 40 | # If we pass only one argument to the script and it's the path to a json file, 41 | # let's parse it to get our arguments. 42 | model_args, data_args, pipeline_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) 43 | else: 44 | model_args, data_args, pipeline_args = parser.parse_args_into_dataclasses() 45 | # print(model_args, data_args, pipeline_args) 46 | # Initialization 47 | # print(pipeline_args.device, type(pipeline_args.device)) 48 | # pipeline_args.device = f"cuda:{pipeline_args.gpu_id}" 49 | finetuner = AutoPipeline.get_pipeline( 50 | pipeline_name=pipeline_name, 51 | model_args=model_args, 52 | data_args=data_args, 53 | pipeline_args=pipeline_args, 54 | ) 55 | dataset = Dataset(data_args) 56 | # print("Pre:", f"cuda:{pipeline_args.gpu_id}") 57 | # import torch 58 | # , device=torch.device(f"cuda:{pipeline_args.gpu_id}") 59 | model1 = AutoModel.get_model(model_args) 60 | import copy 61 | model0_args = copy.deepcopy(model_args) 62 | # model0_args.model_name_or_path = '/home/linhangyu/Projects/LLMs/LMFlow_RAFT_Dev/output_models/sft_open_llama_3b_1epoch' 63 | model0_args.model_name_or_path = '/home/jianmeng/linhangyu/Projects/LLMs/LMFlow_RAFT_Dev/output_models/sft_open_llama_3b_1epoch' 64 | # model0_args.model_name_or_path = 'openlm-research/open_llama_3b_v2' 65 | model0 = AutoModel.get_model(model0_args) 66 | # # Tokenization and text grouping must be done in the main process 67 | # with pipeline_args.main_process_first(desc="dataset map tokenization"): 68 | # tokenized_dataset = model.tokenize(dataset) 69 | # lm_dataset = finetuner.group_text( 70 | # tokenized_dataset, 71 | # model_max_length=model.get_max_length(), 72 | # ) 73 | 74 | # Finetuning 75 | tuned_model = finetuner.tune(model0=model0, model1=model1, dataset=dataset) 76 | 77 | 78 | if __name__ == '__main__': 79 | main() 80 | -------------------------------------------------------------------------------- /scripts/postprocess/weight_interpolation.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import sys 4 | sys.path.remove(os.path.abspath(os.path.dirname(sys.argv[0]))) 5 | from transformers import HfArgumentParser 6 | 7 | from lmflow.models.auto_model import AutoModel 8 | from lmflow.args import ModelArguments, DatasetArguments, AutoArguments 9 | 10 | pipeline_name = "evaluator" 11 | PipelineArguments = AutoArguments.get_pipeline_args_class(pipeline_name) 12 | 13 | parser = HfArgumentParser((ModelArguments, DatasetArguments, PipelineArguments)) 14 | model_args, data_args, pipeline_args = parser.parse_args_into_dataclasses() 15 | 16 | # Get the paths and ratios of weight-ensamble models. 
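# A minimal sketch of what the direct merge below computes, assuming two
# state dicts with identical keys (the real script loads full models through
# AutoModel/DeepSpeed, and `r` comes from the weight_ensamble_ratios argument):
#
#     def interpolate(sd_a, sd_b, r):
#         return {k: r * sd_a[k] + (1 - r) * sd_b[k] for k in sd_a}
#
# The "split" variant applies a different ratio to different transformer
# layer ranges; the method string is parsed from the save path, e.g.
# "split|6#13|0.2|0.5|0.2" means layer boundaries 6 and 13 with one ratio
# per resulting range (lm_head/norm keys are handled separately below).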
17 | weight_ensamble_names_paths = pipeline_args.weight_ensamble_names_paths 18 | weight_ensamble_ratios = pipeline_args.weight_ensamble_ratios 19 | weight_ensamble_save_path = pipeline_args.weight_ensamble_save_path 20 | weight_ensamble_ratios.append(1 - weight_ensamble_ratios[0]) 21 | assert len(weight_ensamble_ratios) == 2, 'Only 2 merge is supported.' 22 | print('Model Paths:', weight_ensamble_names_paths) 23 | print('Model Ratio:', weight_ensamble_ratios) 24 | with open (pipeline_args.deepspeed, "r") as f: 25 | ds_config = json.load(f) 26 | 27 | # Load models. 28 | base_model = None 29 | backend_models = [] 30 | for model_path in weight_ensamble_names_paths: 31 | model_args.model_name_or_path = model_path 32 | print('loading:', model_path) 33 | model = AutoModel.get_model( 34 | model_args, 35 | tune_strategy='none', 36 | ds_config=ds_config, 37 | use_accelerator=pipeline_args.use_accelerator_for_evaluator 38 | ) 39 | model.get_backend_model().eval() 40 | backend_models.append(model.get_backend_model().to('cpu')) 41 | if base_model is None: 42 | base_model = model 43 | print('Finish load:', model_path) 44 | base_backend_model = backend_models[0] 45 | print('Finish load All:', base_backend_model) 46 | 47 | merge_method = weight_ensamble_save_path.split('_')[-2] 48 | print(f'Merge Method:{merge_method}.') 49 | 50 | def merge_split(merge_method, key, weights, ori_ratio): 51 | merge_terms = merge_method.split('|') #split|6,13|0.2|0.5|0.2 52 | split_layers = merge_terms[1].split('#') 53 | split_layers = [int(split_layer) for split_layer in split_layers] 54 | 55 | assert len(split_layers) == len(merge_terms) - 3 56 | ratios = [float(t) for t in merge_terms[2:]] 57 | 58 | terms = key.split('.') 59 | layer_idx, ratio = 0, ratios[0] 60 | 61 | if 'split0' in merge_method and 'lm_head' in key: 62 | ratio = 0 63 | elif 'lm_head' in key or 'norm' in key: 64 | ratio = ori_ratio 65 | 66 | if terms[0] == 'model' and terms[1] == 'layers': 67 | layer_idx = int(terms[2]) 68 | ratio = ratios[0] 69 | for s_i, split_layer in enumerate(split_layers): 70 | if layer_idx > split_layer: 71 | ratio = ratios[s_i + 1] 72 | print(key, layer_idx, ratio) 73 | return weights[0] * ratio + weights[1] * (1 - ratio) 74 | 75 | def merge_direct(weights, ratio): 76 | return weights[0] * ratio + weights[1] * (1 - ratio) 77 | 78 | updated_state_dict = {} 79 | for key in base_backend_model.state_dict(): 80 | weights = [backend_model.state_dict()[key] for backend_model in backend_models] 81 | if 'split' in merge_method: 82 | updated_weight = merge_split(merge_method, key, weights, weight_ensamble_ratios[0]) 83 | else: 84 | updated_weight = merge_direct(weights, weight_ensamble_ratios[0]) 85 | updated_state_dict[key] = updated_weight 86 | 87 | base_backend_model.load_state_dict(updated_state_dict) 88 | print(weight_ensamble_save_path) 89 | base_model.save(weight_ensamble_save_path) 90 | -------------------------------------------------------------------------------- /utils/lm_evaluator.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import logging 4 | import fnmatch 5 | 6 | from lm_eval import tasks, evaluator 7 | 8 | logging.getLogger("openai").setLevel(logging.WARNING) 9 | 10 | 11 | class MultiChoice: 12 | def __init__(self, choices): 13 | self.choices = choices 14 | 15 | # Simple wildcard support (linux filename patterns) 16 | def __contains__(self, values): 17 | for value in values.split(","): 18 | if len(fnmatch.filter(self.choices, value)) == 0: 19 
| return False 20 | 21 | return True 22 | 23 | def __iter__(self): 24 | for choice in self.choices: 25 | yield choice 26 | 27 | 28 | def parse_args(): 29 | parser = argparse.ArgumentParser() 30 | parser.add_argument("--model", required=True) 31 | parser.add_argument("--model_args", default="") 32 | parser.add_argument("--tasks", default=None, choices=MultiChoice(tasks.ALL_TASKS)) 33 | parser.add_argument("--provide_description", action="store_true") 34 | parser.add_argument("--num_fewshot", type=int, default=0) 35 | parser.add_argument("--batch_size", type=int, default=None) 36 | parser.add_argument("--device", type=str, default=None) 37 | parser.add_argument("--output_path", default=None) 38 | parser.add_argument("--limit", type=int, default=None) 39 | parser.add_argument("--no_cache", action="store_true") 40 | parser.add_argument("--decontamination_ngrams_path", default=None) 41 | parser.add_argument("--description_dict_path", default=None) 42 | parser.add_argument("--check_integrity", action="store_true") 43 | 44 | return parser.parse_args() 45 | 46 | 47 | # Returns a list containing all values of the source_list that 48 | # match at least one of the patterns 49 | def pattern_match(patterns, source_list): 50 | task_names = set() 51 | for pattern in patterns: 52 | for matching in fnmatch.filter(source_list, pattern): 53 | task_names.add(matching) 54 | return list(task_names) 55 | 56 | 57 | def main(): 58 | args = parse_args() 59 | 60 | assert not args.provide_description # not implemented 61 | 62 | if args.limit: 63 | print( 64 | "WARNING: --limit SHOULD ONLY BE USED FOR TESTING. REAL METRICS SHOULD NOT BE COMPUTED USING LIMIT." 65 | ) 66 | 67 | if args.tasks is None: 68 | task_names = tasks.ALL_TASKS 69 | else: 70 | task_names = pattern_match(args.tasks.split(","), tasks.ALL_TASKS) 71 | 72 | print(f"Selected Tasks: {task_names}") 73 | 74 | description_dict = {} 75 | if args.description_dict_path: 76 | with open(args.description_dict_path, "r") as f: 77 | description_dict = json.load(f) 78 | 79 | results = evaluator.simple_evaluate( 80 | model=args.model, 81 | model_args=args.model_args, 82 | tasks=task_names, 83 | num_fewshot=args.num_fewshot, 84 | batch_size=args.batch_size, 85 | device=args.device, 86 | no_cache=args.no_cache, 87 | limit=args.limit, 88 | description_dict=description_dict, 89 | decontamination_ngrams_path=args.decontamination_ngrams_path, 90 | check_integrity=args.check_integrity, 91 | ) 92 | 93 | dumped = json.dumps(results, indent=2) 94 | print(dumped) 95 | 96 | if args.output_path: 97 | with open(args.output_path, "w") as f: 98 | f.write(dumped) 99 | 100 | print( 101 | f"{args.model} ({args.model_args}), limit: {args.limit}, provide_description: {args.provide_description}, " 102 | f"num_fewshot: {args.num_fewshot}, batch_size: {args.batch_size}" 103 | ) 104 | print(evaluator.make_table(results)) 105 | 106 | 107 | if __name__ == "__main__": 108 | main() 109 | -------------------------------------------------------------------------------- /utils/merge_tokenizer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | 4 | import argparse 5 | import logging 6 | import os 7 | 8 | from sentencepiece import sentencepiece_model_pb2 as sp_pb2_model 9 | import sentencepiece as spm 10 | 11 | import torch 12 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 13 | 14 | from transformers import AutoTokenizer,LlamaTokenizer 15 | 16 | logging.basicConfig(level=logging.INFO) 17 | 18 | 
if __name__ == '__main__': 19 | 20 | os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"]="python" 21 | 22 | parser = argparse.ArgumentParser() 23 | parser.add_argument('--tokenizer_dir', default='openlm-research/open_llama_3b', type=str, required=False) 24 | parser.add_argument('--chinese_sp_model_file', default='./output_models/new_tokenizer/example.model', type=str) 25 | parser.add_argument('--output_dir', default='./output_models/merged_tokenizer', type=str, required=False) 26 | args = parser.parse_args() 27 | 28 | tokenizer_dir = args.tokenizer_dir 29 | chinese_sp_model_file = args.chinese_sp_model_file 30 | output_dir = args.output_dir 31 | 32 | # load 33 | try: 34 | old_tokenizer = AutoTokenizer.from_pretrained(tokenizer_dir, use_fast=False) 35 | except RecursionError: 36 | old_tokenizer = AutoTokenizer.from_pretrained(tokenizer_dir, 37 | unk_token="", 38 | bos_token="", 39 | eos_token="", 40 | use_fast=False) 41 | 42 | if not isinstance(old_tokenizer,LlamaTokenizer): 43 | raise ValueError("The tokenizer is not a LlamaTokenizer, we only support LlamaTokenizer for now.") 44 | 45 | chinese_sp_model = spm.SentencePieceProcessor() 46 | chinese_sp_model.Load(chinese_sp_model_file) 47 | 48 | old_spm = sp_pb2_model.ModelProto() 49 | old_spm.ParseFromString(old_tokenizer.sp_model.serialized_model_proto()) 50 | chinese_spm = sp_pb2_model.ModelProto() 51 | chinese_spm.ParseFromString(chinese_sp_model.serialized_model_proto()) 52 | 53 | ## Add Chinese tokens to old tokenizer 54 | old_spm_tokens_set=set(p.piece for p in old_spm.pieces) 55 | for p in chinese_spm.pieces: 56 | piece = p.piece 57 | if piece not in old_spm_tokens_set: 58 | new_p = sp_pb2_model.ModelProto().SentencePiece() 59 | new_p.piece = piece 60 | new_p.score = 0 61 | old_spm.pieces.append(new_p) 62 | 63 | ## Save 64 | output_sp_dir = output_dir + '/merged_tokenizer_sp' 65 | output_hf_dir = output_dir + '/merged_tokenizer_hf' # the path to save tokenizer 66 | os.makedirs(output_sp_dir,exist_ok=True) 67 | with open(output_sp_dir+'/merged_tokenizer.model', 'wb') as f: 68 | f.write(old_spm.SerializeToString()) 69 | 70 | try: 71 | tokenizer = AutoTokenizer.from_pretrained( 72 | pretrained_model_name_or_path=tokenizer_dir, 73 | vocab_file=output_sp_dir+'/merged_tokenizer.model', 74 | use_fast=False 75 | ) 76 | except RecursionError: 77 | tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=tokenizer_dir, 78 | unk_token="", 79 | bos_token="", 80 | eos_token="", 81 | vocab_file=output_sp_dir+'/merged_tokenizer.model', 82 | use_fast=False) 83 | 84 | tokenizer.save_pretrained(output_hf_dir) 85 | logging.info(f"Merged tokenizer has been saved to %s",output_dir) 86 | 87 | 88 | # Test 89 | new_tokenizer = tokenizer 90 | logging.info(f"Old tokenizer vocab size: %d",len(old_tokenizer)) 91 | logging.info(f"New tokenizer vocab size: %d",len(new_tokenizer)) 92 | 93 | text='''白日依山尽,黄河入海流。欲穷千里目,更上一层楼。 94 | The primary use of LLaMA is research on large language models, including''' 95 | logging.info(f"Test text:\n %s",text) 96 | logging.info(f"Tokenized by original tokenizer:%s",old_tokenizer.tokenize(text)) 97 | logging.info(f"Tokenized by merged tokenizer:%s",new_tokenizer.tokenize(text)) -------------------------------------------------------------------------------- /src/lmflow/utils/flash_attention/gpt_neo_flash_attention.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional, Tuple 2 | 3 | import torch 4 | import transformers 5 | from einops import 
rearrange 6 | 7 | #try to import flash_attn 2.x.x, if not, import flash_attn 1.x.x 8 | try: 9 | from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func as flash_attn_unpadded_qkvpacked_func 10 | except: 11 | from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func 12 | 13 | from flash_attn.bert_padding import unpad_input, pad_input 14 | 15 | def _attn(self, query, key, value, attention_mask=None, head_mask=None): 16 | # (batch, head, seq_length, head_features) 17 | query = query.to(torch.bfloat16) 18 | key = key.to(torch.bfloat16) 19 | query = query * torch.sqrt(torch.tensor(self.head_dim)) 20 | qkv = torch.stack( 21 | [query, key, value], dim=2 22 | )# [bsz, nh, 3, t, hd] 23 | qkv = qkv.transpose(1,3)## [bsz, q_len, 3, nh, hd] 24 | bsz = qkv.shape[0] 25 | q_len = qkv.shape[1] 26 | 27 | attention_mask = torch.where(attention_mask == -0.0, True, False) 28 | key_padding_mask = rearrange(attention_mask, "b () () s -> b s") if attention_mask is not None else None 29 | if key_padding_mask is None: 30 | qkv = rearrange(qkv, "b s ... -> (b s) ...") 31 | max_s = q_len 32 | cu_q_lens = torch.arange( 33 | 0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device=qkv.device 34 | ) 35 | output = flash_attn_unpadded_qkvpacked_func( 36 | qkv, cu_q_lens, max_s, self.attn_dropout.p if self.training else 0.0 , softmax_scale=None, causal=True 37 | )# attention compute 38 | output = rearrange(output, "(b s) ... -> b s ...", b=bsz) 39 | else: 40 | nheads = qkv.shape[-2] 41 | x = rearrange(qkv, "b s three h d -> b s (three h d)") 42 | x_unpad, indices, cu_q_lens, max_s = unpad_input(x, key_padding_mask) 43 | x_unpad = rearrange( 44 | x_unpad, "nnz (three h d) -> nnz three h d", three=3, h=nheads 45 | ) 46 | output_unpad = flash_attn_unpadded_qkvpacked_func( 47 | x_unpad, cu_q_lens, max_s, self.attn_dropout.p if self.training else 0.0, softmax_scale=None, causal=True 48 | ) 49 | output = rearrange( 50 | pad_input( 51 | rearrange(output_unpad, "nnz h d -> nnz (h d)"), indices, bsz, q_len 52 | ), 53 | "b s (h d) -> b s h d", 54 | h=nheads, 55 | ) 56 | 57 | return output, None 58 | 59 | def forward( 60 | self, 61 | hidden_states, 62 | attention_mask=None, 63 | layer_past=None, 64 | head_mask=None, 65 | use_cache=False, 66 | output_attentions=False, 67 | ): 68 | 69 | assert head_mask is None, "head_mask is not supported" 70 | assert not output_attentions, "output_attentions is not supported" 71 | assert not use_cache, "use_cache is not supported" 72 | 73 | query = self.q_proj(hidden_states) 74 | key = self.k_proj(hidden_states) 75 | value = self.v_proj(hidden_states) 76 | 77 | query = self._split_heads(query, self.num_heads, self.head_dim) 78 | key = self._split_heads(key, self.num_heads, self.head_dim) 79 | value = self._split_heads(value, self.num_heads, self.head_dim) 80 | 81 | if layer_past is not None: 82 | past_key = layer_past[0] 83 | past_value = layer_past[1] 84 | key = torch.cat((past_key, key), dim=-2) 85 | value = torch.cat((past_value, value), dim=-2) 86 | 87 | present = None 88 | attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask) 89 | new_shape = attn_output.size()[:-2] + (self.num_heads * self.head_dim,) 90 | attn_output = attn_output.view(new_shape) 91 | attn_output = self.out_proj(attn_output) 92 | attn_output = self.resid_dropout(attn_output) 93 | 94 | outputs = (attn_output, present) 95 | 96 | return outputs # a, present, (attentions) 97 | 98 | def replace_gpt_neo_attn_with_flash_attn(): 99 | 
transformers.models.gpt_neo.modeling_gpt_neo.GPTNeoSelfAttention._attn = _attn 100 | transformers.models.gpt_neo.modeling_gpt_neo.GPTNeoSelfAttention.forward = forward -------------------------------------------------------------------------------- /src/lmflow/utils/flash_attention/bloom_flash_attention.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional, Tuple, Union 2 | 3 | import torch 4 | from torch import nn 5 | import torch.nn.functional as F 6 | 7 | import transformers 8 | from transformers.models.bloom.modeling_bloom import dropout_add 9 | 10 | from einops import rearrange 11 | 12 | from .triton_flash_attention import flash_attn_qkvpacked_func 13 | 14 | def forward( 15 | self, 16 | hidden_states: torch.Tensor, 17 | residual: torch.Tensor, 18 | alibi: torch.Tensor, 19 | attention_mask: torch.Tensor, 20 | layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, 21 | head_mask: Optional[torch.Tensor] = None, 22 | use_cache: bool = False, 23 | output_attentions: bool = False, 24 | ): 25 | dtype = hidden_states.dtype 26 | fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size] 27 | 28 | # 3 x [batch_size, seq_length, num_heads, head_dim] 29 | (query_layer, key_layer, value_layer) = self._split_heads(fused_qkv) 30 | 31 | batch_size, q_length, _, _ = query_layer.shape 32 | bsz, q_len = batch_size, q_length 33 | 34 | if layer_past is not None: 35 | past_key, past_value = layer_past 36 | # concatenate along seq_length dimension: 37 | # - key: [batch_size * self.num_heads, head_dim, kv_length] 38 | # - value: [batch_size * self.num_heads, kv_length, head_dim] 39 | key_layer = torch.cat((past_key, key_layer), dim=2) 40 | value_layer = torch.cat((past_value, value_layer), dim=1) 41 | 42 | if use_cache is True: 43 | present = (key_layer, value_layer) 44 | else: 45 | present = None 46 | 47 | reshaped_alibi = rearrange(alibi, '(b h) one s-> b h one s', h = self.num_heads) 48 | reshaped_alibi = reshaped_alibi * self.beta 49 | 50 | attention_mask = (1.0 - attention_mask) 51 | attention_mask = attention_mask[:, None, None, :].bool() 52 | reshaped_alibi_masked = reshaped_alibi.masked_fill(attention_mask, -1e9) 53 | 54 | reshaped_query_layer = query_layer 55 | reshaped_key_layer = key_layer 56 | reshaped_value_layer = value_layer 57 | 58 | qkv = torch.concat([reshaped_query_layer.unsqueeze(2), reshaped_key_layer.unsqueeze(2), reshaped_value_layer.unsqueeze(2)], dim = 2) 59 | 60 | output = flash_attn_qkvpacked_func( 61 | qkv, reshaped_alibi_masked, True, self.inv_norm_factor 62 | ) 63 | 64 | output = rearrange(output, 'b s h d -> (b h) s d') 65 | 66 | # change view [batch_size, num_heads, q_length, head_dim] 67 | context_layer = self._merge_heads(output) 68 | 69 | # aggregate results across tp ranks. 
See here: https://github.com/pytorch/pytorch/issues/76232 70 | if self.pretraining_tp > 1 and self.slow_but_exact: 71 | slices = self.hidden_size / self.pretraining_tp 72 | output_tensor = torch.zeros_like(context_layer) 73 | for i in range(self.pretraining_tp): 74 | output_tensor = output_tensor + F.linear( 75 | context_layer[:, :, int(i * slices) : int((i + 1) * slices)], 76 | self.dense.weight[:, int(i * slices) : int((i + 1) * slices)], 77 | ) 78 | else: 79 | output_tensor = self.dense(context_layer) 80 | 81 | output_tensor = dropout_add(output_tensor, residual, self.hidden_dropout, self.training) 82 | 83 | outputs = (output_tensor, present) 84 | if output_attentions: 85 | outputs += (context_layer,) 86 | 87 | return outputs 88 | 89 | 90 | # Disable the transformation of the attention mask in LlamaModel as the flash attention 91 | # requires the attention mask to be the same as the key_padding_mask 92 | def _prepare_attn_mask( 93 | self, attention_mask: torch.Tensor, input_shape: Tuple[int, int], past_key_values_length: int 94 | ) -> torch.BoolTensor: 95 | 96 | return attention_mask 97 | 98 | def replace_bloom_attn_with_flash_attn(): 99 | transformers.models.bloom.modeling_bloom.BloomModel._prepare_attn_mask = ( 100 | _prepare_attn_mask 101 | ) 102 | transformers.models.bloom.modeling_bloom.BloomAttention.forward = forward -------------------------------------------------------------------------------- /output_models/download.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | function main() { 4 | public_server="http://lmflow.org:5000" 5 | if [ $# -lt 1 -o "$1" = "-h" -o "$1" = "--help" ]; then 6 | echo "Usage: bash $(basename $0) model_name" 7 | echo "Example: bash $(basename $0) instruction_ckpt" 8 | echo "Example: bash $(basename $0) all" 9 | fi 10 | 11 | if [ "$1" = "llama7b-lora-medical" -o "$1" = "medical_ckpt" -o "$1" = "all" ]; then 12 | echo "downloading llama7b-lora-medical.tar.gz" 13 | filename='llama7b-lora-medical.tar.gz' 14 | wget ${public_server}/${filename} 15 | tar zxvf ${filename} 16 | rm ${filename} 17 | fi 18 | 19 | if [ "$1" = "llama13b-lora-medical" -o "$1" = "medical_ckpt" -o "$1" = "all" ]; then 20 | echo "downloading llama13b-lora-medical.tar.gz" 21 | filename='llama13b-lora-medical.tar.gz' 22 | wget ${public_server}/${filename} 23 | tar zxvf ${filename} 24 | rm ${filename} 25 | fi 26 | 27 | if [ "$1" = "llama30b-lora-medical" -o "$1" = "medical_ckpt" -o "$1" = "all" ]; then 28 | echo "downloading llama30b-lora-medical.tar.gz" 29 | filename='llama30b-lora-medical.tar.gz' 30 | wget ${public_server}/${filename} 31 | tar zxvf ${filename} 32 | rm ${filename} 33 | fi 34 | 35 | if [ "$1" = "llama7b-lora-170k" -o "$1" = "instruction_ckpt" -o "$1" = "all" ]; then 36 | echo "downloading llama7b-lora-170k.tar.gz" 37 | filename='llama7b-lora-170k.tar.gz' 38 | wget ${public_server}/${filename} 39 | tar zxvf ${filename} 40 | rm ${filename} 41 | fi 42 | 43 | if [ "$1" = "llama7b-lora-380k" -o "$1" = "instruction_ckpt" -o "$1" = "all" ]; then 44 | echo "downloading llama7b-lora-380k.tar.gz" 45 | filename='llama7b-lora-380k.tar.gz' 46 | wget ${public_server}/${filename} 47 | tar zxvf ${filename} 48 | rm ${filename} 49 | fi 50 | 51 | if [ "$1" = "llama13b-lora-170k" -o "$1" = "instruction_ckpt" -o "$1" = "all" ]; then 52 | echo "downloading llama13b-lora-170k.tar.gz" 53 | filename='llama13b-lora-170k.tar.gz' 54 | wget ${public_server}/${filename} 55 | tar zxvf ${filename} 56 | rm ${filename} 57 | fi 58 | 59 | if [ "$1" = 
"llama13b-lora-380k" -o "$1" = "instruction_ckpt" -o "$1" = "all" ]; then 60 | echo "downloading llama13b-lora-380k.tar.gz" 61 | filename='llama13b-lora-380k.tar.gz' 62 | wget ${public_server}/${filename} 63 | tar zxvf ${filename} 64 | rm ${filename} 65 | fi 66 | 67 | if [ "$1" = "llama30b-lora-170k" -o "$1" = "instruction_ckpt" -o "$1" = "all" ]; then 68 | echo "downloading llama30b-lora-170k.tar.gz" 69 | filename='llama30b-lora-170k.tar.gz' 70 | wget ${public_server}/${filename} 71 | tar zxvf ${filename} 72 | rm ${filename} 73 | fi 74 | 75 | if [ "$1" = "llama7b-lora-movie-reviewer" -o "$1" = "raft_ckpt" -o "$1" = "all" ]; then 76 | echo "downloading llama7b-lora-movie-reviewer" 77 | filename='llama7b-lora-movie-reviewer.tar.gz' 78 | wget ${public_server}/${filename} 79 | tar zxvf ${filename} 80 | rm ${filename} 81 | fi 82 | 83 | if [ "$1" = "cockatoo-7b" -o "$1" = "all" ]; then 84 | echo "downloading cockatoo-7b" 85 | filename='cockatoo-7b.tar.gz' 86 | wget ${public_server}/${filename} 87 | tar zxvf ${filename} 88 | rm ${filename} 89 | fi 90 | 91 | if [ "$1" = "parakeets-2.7b" -o "$1" = "all" ]; then 92 | echo "downloading parakeets-2.7b" 93 | filename='parakeets-2.7b.tar.gz' 94 | wget ${public_server}/${filename} 95 | tar zxvf ${filename} 96 | rm ${filename} 97 | fi 98 | 99 | if [ "$1" = "robin-7b" -o "$1" = "all" ]; then 100 | echo "downloading robin-7b" 101 | filename='robin-7b-v2-delta.tar.gz' 102 | wget ${public_server}/${filename} 103 | tar zxvf ${filename} 104 | rm ${filename} 105 | fi 106 | 107 | if [ "$1" = "minigpt4_7b" -o "$1" = "all" ]; then 108 | echo "downloading minigpt4_7b" 109 | filename='pretrained_minigpt4_7b.pth' 110 | wget ${public_server}/${filename} 111 | fi 112 | 113 | if [ "$1" = "minigpt4_13b" -o "$1" = "all" ]; then 114 | echo "downloading minigpt4_13b" 115 | filename='pretrained_minigpt4_13b.pth' 116 | wget ${public_server}/${filename} 117 | fi 118 | } 119 | 120 | main "$@" 121 | -------------------------------------------------------------------------------- /src/lmflow/utils/constants.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | """ 4 | Commonly used constants. 5 | """ 6 | 7 | TEXT_ONLY_DATASET_DESCRIPTION = ( 8 | """ 9 | "text_only": a dataset with only raw text instances, with following format: 10 | 11 | { 12 | "type": "text_only", 13 | "instances": [ 14 | { "text": "TEXT_1" }, 15 | { "text": "TEXT_2" }, 16 | ... 17 | ] 18 | } 19 | """ 20 | ).lstrip("\n") 21 | 22 | 23 | TEXT_ONLY_DATASET_DETAILS = ( 24 | """ 25 | For example, 26 | 27 | ```python 28 | from lmflow.datasets import Dataset 29 | 30 | data_dict = { 31 | "type": "text_only", 32 | "instances": [ 33 | { "text": "Human: Hello. Bot: Hi!" }, 34 | { "text": "Human: How are you today? Bot: Fine, thank you!" }, 35 | ] 36 | } 37 | dataset = Dataset.create_from_dict(data_dict) 38 | ``` 39 | 40 | You may also save the corresponding format to json, 41 | ```python 42 | import json 43 | from lmflow.args import DatasetArguments 44 | from lmflow.datasets import Dataset 45 | 46 | data_dict = { 47 | "type": "text_only", 48 | "instances": [ 49 | { "text": "Human: Hello. Bot: Hi!" }, 50 | { "text": "Human: How are you today? Bot: Fine, thank you!" 
}, 51 | ] 52 | } 53 | with open("data.json", "w") as fout: 54 | json.dump(data_dict, fout) 55 | 56 | data_args = DatasetArgument(dataset_path="data.json") 57 | dataset = Dataset(data_args) 58 | new_data_dict = dataset.to_dict() 59 | # `new_data_dict` Should have the same content as `data_dict` 60 | ``` 61 | """ 62 | ).lstrip("\n") 63 | 64 | 65 | TEXT2TEXT_DATASET_DESCRIPTION = ( 66 | """ 67 | "text2text": a dataset with input & output instances, with following format: 68 | 69 | { 70 | "type": "text2text", 71 | "instances": [ 72 | { "input": "INPUT_1", "output": "OUTPUT_1" }, 73 | { "input": "INPUT_2", "output": "OUTPUT_2" }, 74 | ... 75 | ] 76 | } 77 | """ 78 | ).lstrip("\n") 79 | 80 | 81 | TEXT2TEXT_DATASET_DETAILS = ( 82 | """ 83 | For example, 84 | 85 | ```python 86 | from lmflow.datasets import Dataset 87 | 88 | data_dict = { 89 | "type": "text2text", 90 | "instances": [ 91 | { 92 | "input": "Human: Hello.", 93 | "output": "Bot: Hi!", 94 | }, 95 | { 96 | "input": "Human: How are you today?", 97 | "output": "Bot: Fine, thank you! And you?", 98 | } 99 | ] 100 | } 101 | dataset = Dataset.create_from_dict(data_dict) 102 | ``` 103 | 104 | You may also save the corresponding format to json, 105 | ```python 106 | import json 107 | from lmflow.args import DatasetArguments 108 | from lmflow.datasets import Dataset 109 | 110 | data_dict = { 111 | "type": "text2text", 112 | "instances": [ 113 | { 114 | "input": "Human: Hello.", 115 | "output": "Bot: Hi!", 116 | }, 117 | { 118 | "input": "Human: How are you today?", 119 | "output": "Bot: Fine, thank you! And you?", 120 | } 121 | ] 122 | } 123 | with open("data.json", "w") as fout: 124 | json.dump(data_dict, fout) 125 | 126 | data_args = DatasetArgument(dataset_path="data.json") 127 | dataset = Dataset(data_args) 128 | new_data_dict = dataset.to_dict() 129 | # `new_data_dict` Should have the same content as `data_dict` 130 | ``` 131 | """ 132 | ).lstrip("\n") 133 | 134 | 135 | FLOAT_ONLY_DATASET_DESCRIPTION = ( 136 | """ 137 | "float_only": a dataset with only float instances, with following format: 138 | 139 | { 140 | "type": "float_only", 141 | "instances": [ 142 | { "value": "FLOAT_1" }, 143 | { "value": "FLOAT_2" }, 144 | ... 
145 | ] 146 | } 147 | """ 148 | ).lstrip("\n") 149 | 150 | 151 | TEXT_ONLY_DATASET_LONG_DESCRITION = ( 152 | TEXT_ONLY_DATASET_DESCRIPTION + TEXT_ONLY_DATASET_DETAILS 153 | ) 154 | 155 | TEXT2TEXT_DATASET_LONG_DESCRITION = ( 156 | TEXT2TEXT_DATASET_DESCRIPTION + TEXT2TEXT_DATASET_DETAILS 157 | ) 158 | 159 | 160 | DATASET_DESCRIPTION_MAP = { 161 | "text_only": TEXT_ONLY_DATASET_DESCRIPTION, 162 | "text2text": TEXT2TEXT_DATASET_DESCRIPTION, 163 | "float_only": FLOAT_ONLY_DATASET_DESCRIPTION, 164 | } 165 | 166 | INSTANCE_FIELDS_MAP = { 167 | "text_only": ["text"], 168 | "text2text": ["input", "output"], 169 | "float_only": ["value"], 170 | "image_text": ["images", "text"], 171 | } 172 | -------------------------------------------------------------------------------- /src/lmflow/utils/flash_attention/llama_flash_attention.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional, Tuple 2 | 3 | import torch 4 | from torch import nn 5 | 6 | import transformers 7 | from transformers.models.llama.modeling_llama import apply_rotary_pos_emb 8 | 9 | from einops import rearrange 10 | 11 | #try to import flash_attn 2.x.x, if not, import flash_attn 1.x.x 12 | try: 13 | from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func as flash_attn_unpadded_qkvpacked_func 14 | except: 15 | from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func 16 | 17 | from flash_attn.bert_padding import unpad_input, pad_input 18 | 19 | 20 | def forward( 21 | self, 22 | hidden_states: torch.Tensor, 23 | attention_mask: Optional[torch.Tensor] = None, 24 | position_ids: Optional[torch.Tensor] = None, 25 | past_key_value: Optional[Tuple[torch.Tensor]] = None, 26 | output_attentions: bool = False, 27 | use_cache: bool = False, 28 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: 29 | """Input shape: Batch x Time x Channel 30 | 31 | attention_mask: [bsz, q_len] 32 | """ 33 | bsz, q_len, _ = hidden_states.size() 34 | 35 | query_states = ( 36 | self.q_proj(hidden_states) 37 | .view(bsz, q_len, self.num_heads, self.head_dim) 38 | .transpose(1, 2) 39 | ) 40 | key_states = ( 41 | self.k_proj(hidden_states) 42 | .view(bsz, q_len, self.num_heads, self.head_dim) 43 | .transpose(1, 2) 44 | ) 45 | value_states = ( 46 | self.v_proj(hidden_states) 47 | .view(bsz, q_len, self.num_heads, self.head_dim) 48 | .transpose(1, 2) 49 | ) 50 | # [bsz, q_len, nh, hd] 51 | # [bsz, nh, q_len, hd] 52 | 53 | kv_seq_len = key_states.shape[-2] 54 | assert past_key_value is None, "past_key_value is not supported" 55 | 56 | cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) 57 | query_states, key_states = apply_rotary_pos_emb( 58 | query_states, key_states, cos, sin, position_ids 59 | ) 60 | # [bsz, nh, t, hd] 61 | assert not output_attentions, "output_attentions is not supported" 62 | assert not use_cache, "use_cache is not supported" 63 | 64 | # Flash attention codes from 65 | # https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/flash_attention.py 66 | 67 | # transform the data into the format required by flash attention 68 | qkv = torch.stack( 69 | [query_states, key_states, value_states], dim=2 70 | ) # [bsz, nh, 3, q_len, hd] 71 | qkv = qkv.transpose(1, 3) # [bsz, q_len, 3, nh, hd] 72 | # We have disabled _prepare_decoder_attention_mask in LlamaModel 73 | # the attention_mask should be the same as the key_padding_mask 74 | key_padding_mask = attention_mask 75 | 76 | if key_padding_mask is 
None: 77 | qkv = rearrange(qkv, "b s ... -> (b s) ...") 78 | max_s = q_len 79 | cu_q_lens = torch.arange( 80 | 0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device=qkv.device 81 | ) 82 | output = flash_attn_unpadded_qkvpacked_func( 83 | qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True 84 | ) 85 | output = rearrange(output, "(b s) ... -> b s ...", b=bsz) 86 | else: 87 | nheads = qkv.shape[-2] 88 | x = rearrange(qkv, "b s three h d -> b s (three h d)") 89 | x_unpad, indices, cu_q_lens, max_s = unpad_input(x, key_padding_mask) 90 | x_unpad = rearrange( 91 | x_unpad, "nnz (three h d) -> nnz three h d", three=3, h=nheads 92 | ) 93 | output_unpad = flash_attn_unpadded_qkvpacked_func( 94 | x_unpad, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True 95 | ) 96 | output = rearrange( 97 | pad_input( 98 | rearrange(output_unpad, "nnz h d -> nnz (h d)"), indices, bsz, q_len 99 | ), 100 | "b s (h d) -> b s h d", 101 | h=nheads, 102 | ) 103 | return self.o_proj(rearrange(output, "b s h d -> b s (h d)")), None, None 104 | 105 | 106 | # Disable the transformation of the attention mask in LlamaModel as the flash attention 107 | # requires the attention mask to be the same as the key_padding_mask 108 | def _prepare_decoder_attention_mask( 109 | self, attention_mask, input_shape, inputs_embeds, past_key_values_length 110 | ): 111 | # [bsz, seq_len] 112 | return attention_mask 113 | 114 | 115 | def replace_llama_attn_with_flash_attn(): 116 | transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = ( 117 | _prepare_decoder_attention_mask 118 | ) 119 | transformers.models.llama.modeling_llama.LlamaAttention.forward = forward -------------------------------------------------------------------------------- /scripts/eval_pairrm/run_eval_rate_pairrm.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | from glob import glob 4 | 5 | 6 | parser = argparse.ArgumentParser() 7 | parser.add_argument("--load_path", type=str, default="") 8 | parser.add_argument("--gpu", type=int, default=-1, help="gpu id") 9 | parser.add_argument("--save_path", default="alpaca_eval2_pairrm.csv", type=str, help="save path") 10 | args = parser.parse_args() 11 | 12 | if args.gpu >= 0: 13 | os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu) 14 | 15 | from llm_blender.pair_ranker.pairrm import DebertaV2PairRM 16 | from transformers import AutoTokenizer 17 | from typing import List 18 | import numpy as np 19 | import torch 20 | from tqdm.auto import tqdm, trange 21 | import pandas as pd 22 | 23 | 24 | pairrm = DebertaV2PairRM.from_pretrained("llm-blender/PairRM-hf", device_map="cuda:7").eval() 25 | tokenizer = AutoTokenizer.from_pretrained('llm-blender/PairRM-hf') 26 | source_prefix = "<|source|>" 27 | cand1_prefix = "<|candidate1|>" 28 | cand2_prefix = "<|candidate2|>" 29 | 30 | 31 | def tokenize_pair(sources: List[str], candidate1s: List[str], candidate2s: List[str], source_max_length=1224, 32 | candidate_max_length=412): 33 | ids = [] 34 | assert len(sources) == len(candidate1s) == len(candidate2s) 35 | max_length = source_max_length + 2 * candidate_max_length 36 | for i in range(len(sources)): 37 | source_ids = tokenizer.encode(source_prefix + sources[i], max_length=source_max_length, truncation=True) 38 | candidate_max_length = (max_length - len(source_ids)) // 2 39 | candidate1_ids = tokenizer.encode(cand1_prefix + candidate1s[i], max_length=candidate_max_length, 40 | truncation=True) 41 | candidate2_ids = 
tokenizer.encode(cand2_prefix + candidate2s[i], max_length=candidate_max_length, 42 | truncation=True) 43 | ids.append(source_ids + candidate1_ids + candidate2_ids) 44 | encodings = tokenizer.pad({"input_ids": ids}, return_tensors="pt", padding="max_length", max_length=max_length) 45 | return encodings 46 | 47 | 48 | 49 | 50 | 51 | def evaluate(args, sys_prompt_type, eval_data_type, ref_type): 52 | args.model_name = args.load_path.split("/")[-1].split(".json")[0] 53 | df_candidate = pd.read_json(args.load_path) 54 | assert eval_data_type in ["alpaca", "hh_rlhf"] 55 | if eval_data_type == "alpaca": 56 | if ref_type == 'gpt4': 57 | ref_path = "path/to/alpaca_eval_gpt4_baseline.json" 58 | elif eval_data_type == "hh_rlhf": 59 | ref_path = "path/to/rlhf/rlhf_eval_ref/pairrm_2k.json" 60 | # make sure the order of "instruction" is the same 61 | print(f"Reference Data: {ref_path}") 62 | df_reference = pd.read_json(ref_path) 63 | print(len(df_reference['instruction']), len(df_candidate['instruction'])) 64 | print(len(df_reference['instruction'][0]), len(df_candidate['instruction'][0])) 65 | assert (df_reference['instruction'] == df_candidate['instruction']).all() 66 | prompts = df_reference['instruction'].values 67 | responses_A = df_candidate['output'].values 68 | responses_B = df_reference['output'].values 69 | batch_size = 16 70 | n_batches = len(prompts) // batch_size + 1 71 | all_batch_idxes = np.array_split(np.arange(len(prompts)), n_batches) 72 | with torch.no_grad(): 73 | pairrm.eval() 74 | comparison_results = [] 75 | for i in trange(n_batches,desc='batch',leave=False): 76 | batch_idxes = all_batch_idxes[i] 77 | encodings = tokenize_pair(prompts[batch_idxes], responses_A[batch_idxes], responses_B[batch_idxes]) 78 | encodings = {k: v.to(pairrm.device) for k, v in encodings.items()} 79 | outputs = pairrm(**encodings) 80 | logits = outputs.logits.tolist() 81 | comparison_results.append(outputs.logits > 0) 82 | comparison_results = torch.cat(comparison_results).cpu().numpy() 83 | win_rate = comparison_results.mean() 84 | avg_length = np.mean([len(x) for x in responses_A]) 85 | 86 | row = {"model": args.model_name, "win_rate": win_rate, 'avg_length': avg_length, 'reference': ref_type, "judge": "pairrm",} 87 | # append to the result file 88 | df_result = pd.DataFrame([row]) 89 | 90 | df_result.to_csv(args.save_path, index=False, mode="a", header=not os.path.exists(args.save_path)) 91 | 92 | project_dir = "path/to/result" 93 | sys_prompt_type, eval_data_type, ref_type = [1, "alpaca", "gpt4"] #[1, "hh_rlhf", "hh_rlhf"] 94 | keys = ["experiment_tags"] 95 | for key in keys: 96 | load_paths = [ 97 | os.path.join(project_dir, f"{key}/pairrm_v{sys_prompt_type}_{eval_data_type}.json") 98 | ] 99 | args.save_path = os.path.join(project_dir, f"{key}/pairrm_v{sys_prompt_type}_{eval_data_type}_{ref_type}_results.csv") 100 | for load_path in tqdm(load_paths, desc='file'): 101 | print('Load path:', load_path) 102 | args.load_path = load_path 103 | evaluate(args, sys_prompt_type, eval_data_type, ref_type) -------------------------------------------------------------------------------- /examples/raft_align.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | #!/usr/bin/env python 6 | # coding=utf-8 7 | # Copyright 2023 Statistics and Machine Learning Research Group at HKUST. All rights reserved. 
8 | """Alignment tuning example, such as RLHF.""" 9 | 10 | import logging 11 | import os 12 | import sys 13 | sys.path.remove(os.path.abspath(os.path.dirname(sys.argv[0]))) 14 | from dataclasses import dataclass, field 15 | from typing import Optional 16 | import torch 17 | from transformers import HfArgumentParser, pipeline, AutoTokenizer 18 | 19 | from lmflow.args import ( 20 | ModelArguments, 21 | DatasetArguments, 22 | AutoArguments, 23 | ) 24 | 25 | from lmflow.datasets.dataset import Dataset 26 | from lmflow.models.auto_model import AutoModel 27 | from lmflow.pipeline.auto_pipeline import AutoPipeline 28 | 29 | import warnings 30 | 31 | @dataclass 32 | class RewardArguments: 33 | reward_type: Optional[str] = field( 34 | default="hf_pipeline", 35 | metadata={ 36 | "help": ( 37 | "type of reward model, support huggingface pipeline. Will" 38 | " support \"customized\" torch.nn.modules in the future." 39 | ), 40 | }, 41 | ) 42 | reward_model_or_path: Optional[str] = field( 43 | default=None, 44 | metadata={ 45 | "help": ( 46 | "reward model name (huggingface) or its path" 47 | ), 48 | }, 49 | ) 50 | reward_task: Optional[str] = field( 51 | default="sentiment-analysis", 52 | metadata={ 53 | "help": "type of reward task, such as sentiment-analysis, detoxic." 54 | }, 55 | ) 56 | reward_model_args: Optional[str] = field( 57 | default="return_all_scores=True, function_to_apply=\"none\", batch_size=1", 58 | metadata={ 59 | "help": ( 60 | "extra arguments required by different type of reward models." 61 | ), 62 | }, 63 | ) 64 | 65 | 66 | def get_reward_function(reward_args, pipeline_args): 67 | if reward_args.reward_model_or_path is None: 68 | warnings.warn("No reward model is provided.") 69 | return None 70 | args = reward_args 71 | reward_type = args.reward_type 72 | 73 | if reward_type == "hf_pipeline": 74 | rm_tokenizer = AutoTokenizer.from_pretrained(reward_args.reward_model_or_path) 75 | hf_pipe = pipeline( 76 | reward_args.reward_task, 77 | model=reward_args.reward_model_or_path, 78 | device=f"cuda:{pipeline_args.local_rank}", 79 | tokenizer=rm_tokenizer, 80 | model_kwargs={"torch_dtype": torch.bfloat16} 81 | ) 82 | def reward_func(dataset: Dataset): 83 | if dataset.type != "text_only": 84 | raise NotImplementedError( 85 | "reward function only accept \"text_only\" datasets" 86 | ) 87 | pipe_kwargs = { 88 | "return_all_scores": True, 89 | "function_to_apply": "none", 90 | "batch_size": 1 91 | } 92 | 93 | data_dict = dataset.to_dict() 94 | texts_for_rewards = [ 95 | sample["text"] for sample in data_dict["instances"] 96 | ] 97 | pipe_outputs = hf_pipe(texts_for_rewards, **pipe_kwargs) 98 | rewards = [output[0]["score"] for output in pipe_outputs] 99 | 100 | reward_dataset = Dataset.create_from_dict({ 101 | "type": "float_only", 102 | "instances": [ 103 | { "value": reward } for reward in rewards 104 | ] 105 | }) 106 | return reward_dataset 107 | 108 | return reward_func 109 | else: 110 | raise NotImplementedError("unsupported reward type \"{reward_type}\"") 111 | 112 | def main(): 113 | # Parses arguments 114 | pipeline_name = "raft_aligner" 115 | PipelineArguments = AutoArguments.get_pipeline_args_class(pipeline_name) 116 | 117 | parser = HfArgumentParser(( 118 | ModelArguments, 119 | DatasetArguments, 120 | PipelineArguments, 121 | RewardArguments, 122 | )) 123 | if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): 124 | model_args, data_args, pipeline_args, reward_args = parser.parse_json_file( 125 | json_file=os.path.abspath(sys.argv[1]) 126 | ) 127 | else: 128 | model_args, 
data_args, pipeline_args, reward_args = parser.parse_args_into_dataclasses() 129 | 130 | # Initializes pipeline, dataset and model for reward training 131 | aligner = AutoPipeline.get_pipeline( 132 | pipeline_name=pipeline_name, 133 | model_args=model_args, 134 | data_args=data_args, 135 | pipeline_args=pipeline_args, 136 | ) 137 | print(data_args) 138 | dataset = Dataset(data_args) 139 | model = AutoModel.get_model(model_args) 140 | 141 | # Initializes reward function 142 | reward_function = get_reward_function(reward_args, pipeline_args) 143 | 144 | reward_model_args = ModelArguments(arch_type="text_regression") 145 | reward_model = AutoModel.get_model(reward_model_args) 146 | reward_model.register_inference_function(reward_function) 147 | 148 | # Aligns model with rewards 149 | aligned_model = aligner.align( 150 | model=model, 151 | dataset=dataset, 152 | reward_model=reward_model, 153 | ) 154 | 155 | 156 | if __name__ == '__main__': 157 | main() -------------------------------------------------------------------------------- /examples/raft_align_eval.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | #!/usr/bin/env python 6 | # coding=utf-8 7 | # Copyright 2023 Statistics and Machine Learning Research Group at HKUST. All rights reserved. 8 | """Alignment tuning example, such as RLHF.""" 9 | 10 | import logging 11 | import os 12 | import sys 13 | sys.path.remove(os.path.abspath(os.path.dirname(sys.argv[0]))) 14 | from dataclasses import dataclass, field 15 | from typing import Optional 16 | import torch 17 | from transformers import HfArgumentParser, pipeline, AutoTokenizer 18 | 19 | from lmflow.args import ( 20 | ModelArguments, 21 | DatasetArguments, 22 | AutoArguments, 23 | ) 24 | 25 | from lmflow.datasets.dataset import Dataset 26 | from lmflow.models.auto_model import AutoModel 27 | from lmflow.pipeline.auto_pipeline import AutoPipeline 28 | 29 | import warnings 30 | 31 | @dataclass 32 | class RewardArguments: 33 | reward_type: Optional[str] = field( 34 | default="hf_pipeline", 35 | metadata={ 36 | "help": ( 37 | "type of reward model, support huggingface pipeline. Will" 38 | " support \"customized\" torch.nn.modules in the future." 39 | ), 40 | }, 41 | ) 42 | reward_model_or_path: Optional[str] = field( 43 | default=None, 44 | metadata={ 45 | "help": ( 46 | "reward model name (huggingface) or its path" 47 | ), 48 | }, 49 | ) 50 | reward_task: Optional[str] = field( 51 | default="sentiment-analysis", 52 | metadata={ 53 | "help": "type of reward task, such as sentiment-analysis, detoxic." 54 | }, 55 | ) 56 | reward_model_args: Optional[str] = field( 57 | default="return_all_scores=True, function_to_apply=\"none\", batch_size=1", 58 | metadata={ 59 | "help": ( 60 | "extra arguments required by different type of reward models." 
61 | ), 62 | }, 63 | ) 64 | 65 | 66 | def get_reward_function(reward_args, pipeline_args): 67 | if reward_args.reward_model_or_path is None: 68 | warnings.warn("No reward model is provided.") 69 | return None 70 | args = reward_args 71 | reward_type = args.reward_type 72 | 73 | if reward_type == "hf_pipeline": 74 | rm_tokenizer = AutoTokenizer.from_pretrained(reward_args.reward_model_or_path) 75 | hf_pipe = pipeline( 76 | reward_args.reward_task, 77 | model=reward_args.reward_model_or_path, 78 | device=f"cuda:{pipeline_args.local_rank}", 79 | tokenizer=rm_tokenizer, 80 | model_kwargs={"torch_dtype": torch.bfloat16} 81 | ) 82 | def reward_func(dataset: Dataset): 83 | if dataset.type != "text_only": 84 | raise NotImplementedError( 85 | "reward function only accept \"text_only\" datasets" 86 | ) 87 | pipe_kwargs = { 88 | "return_all_scores": True, 89 | "function_to_apply": "none", 90 | "batch_size": 1 91 | } 92 | 93 | data_dict = dataset.to_dict() 94 | texts_for_rewards = [ 95 | sample["text"] for sample in data_dict["instances"] 96 | ] 97 | pipe_outputs = hf_pipe(texts_for_rewards, **pipe_kwargs) 98 | rewards = [output[0]["score"] for output in pipe_outputs] 99 | 100 | reward_dataset = Dataset.create_from_dict({ 101 | "type": "float_only", 102 | "instances": [ 103 | { "value": reward } for reward in rewards 104 | ] 105 | }) 106 | return reward_dataset 107 | 108 | return reward_func 109 | else: 110 | raise NotImplementedError("unsupported reward type \"{reward_type}\"") 111 | 112 | def main(): 113 | # Parses arguments 114 | pipeline_name = "raft_aligner_eval" 115 | PipelineArguments = AutoArguments.get_pipeline_args_class(pipeline_name) 116 | 117 | parser = HfArgumentParser(( 118 | ModelArguments, 119 | DatasetArguments, 120 | PipelineArguments, 121 | RewardArguments, 122 | )) 123 | if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): 124 | model_args, data_args, pipeline_args, reward_args = parser.parse_json_file( 125 | json_file=os.path.abspath(sys.argv[1]) 126 | ) 127 | else: 128 | model_args, data_args, pipeline_args, reward_args = parser.parse_args_into_dataclasses() 129 | 130 | # Initializes pipeline, dataset and model for reward training 131 | aligner = AutoPipeline.get_pipeline( 132 | pipeline_name=pipeline_name, 133 | model_args=model_args, 134 | data_args=data_args, 135 | pipeline_args=pipeline_args, 136 | ) 137 | print(data_args) 138 | dataset = Dataset(data_args) 139 | model = AutoModel.get_model(model_args) 140 | 141 | # Initializes reward function 142 | reward_function = get_reward_function(reward_args, pipeline_args) 143 | 144 | reward_model_args = ModelArguments(arch_type="text_regression") 145 | reward_model = AutoModel.get_model(reward_model_args) 146 | reward_model.register_inference_function(reward_function) 147 | 148 | # Aligns model with rewards 149 | aligned_model = aligner.align( 150 | model=model, 151 | dataset=dataset, 152 | reward_model=reward_model, 153 | ) 154 | 155 | 156 | if __name__ == '__main__': 157 | main() -------------------------------------------------------------------------------- /scripts/data_preprocess/concat_shuffle_split.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | # Copyright 2023 Statistics and Machine Learning Research Group at HKUST. All rights reserved. 4 | """ 5 | This script is designed for handling large datasets. 
6 | It merges multiple datasets located in the same directory, shuffles them, and splits them into training, evaluation, and testing sets. 7 | The training set is further divided into 10 folds. 8 | """ 9 | from __future__ import absolute_import 10 | 11 | import argparse 12 | import json 13 | import textwrap 14 | import sys 15 | import os 16 | import random 17 | import gc 18 | 19 | def parse_argument(sys_argv): 20 | """Parses arguments from command line. 21 | Args: 22 | sys_argv: the list of arguments (strings) from command line. 23 | Returns: 24 | A struct whose member corresponds to the required (optional) variable. 25 | For example, 26 | ``` 27 | args = parse_argument(['main.py' '--input', 'a.txt', '--num', '10']) 28 | args.input # 'a.txt' 29 | args.num # 10 30 | ``` 31 | """ 32 | parser = argparse.ArgumentParser( 33 | formatter_class=argparse.RawTextHelpFormatter) 34 | 35 | # Training parameters 36 | parser.add_argument( 37 | "--output_path", type=str, 38 | default=None, 39 | help=textwrap.dedent("output dataset path, writes to stdout by default") 40 | ) 41 | parser.add_argument( 42 | "--merge_from_path", type=str, 43 | nargs="+", 44 | help=textwrap.dedent( 45 | "dataset path of the extra dataset that will be merged" 46 | " into input dataset" 47 | ) 48 | ) 49 | parser.add_argument( 50 | "--seed", type=int, default=42, 51 | help=textwrap.dedent("pseudorandom seed") 52 | ) 53 | parser.add_argument( 54 | "--eval_size", type=int, default=200, 55 | help=textwrap.dedent("size of eval dataset") 56 | ) 57 | parser.add_argument( 58 | "--test_size", type=int, default=1000, 59 | help=textwrap.dedent("size of test dataset") 60 | ) 61 | parser.add_argument( 62 | "--k", type=int, default=10, 63 | help=textwrap.dedent("the train dataset will be divide into k folds") 64 | ) 65 | # Parses from commandline 66 | args = parser.parse_args(sys_argv[1:]) 67 | 68 | return args 69 | 70 | 71 | def main(): 72 | args = parse_argument(sys.argv) 73 | 74 | # concat 75 | if args.merge_from_path is not None: 76 | for i in range(0, len(args.merge_from_path)): 77 | with open(args.merge_from_path[i], "r") as fin: 78 | extra_data_dict = json.load(fin) 79 | if i == 0: 80 | data_dict = extra_data_dict 81 | else: 82 | if data_dict["type"] != extra_data_dict["type"]: 83 | raise ValueError( 84 | 'two dataset have different types:' 85 | f' input dataset: "{data_dict["type"]}";' 86 | f' merge from dataset: "{extra_data_dict["type"]}"' 87 | ) 88 | data_dict["instances"].extend(extra_data_dict["instances"]) 89 | else: 90 | raise ValueError("No merge files specified") 91 | del extra_data_dict 92 | gc.collect() 93 | print('finish concat') 94 | 95 | # shuffle 96 | random.seed(args.seed) 97 | random.shuffle(data_dict["instances"]) 98 | print('finish shuffle') 99 | # split to train, eval, test 100 | train_data_dict = {"type":data_dict["type"],"instances":data_dict["instances"][args.eval_size:-args.test_size]} 101 | eval_data_dict = {"type":data_dict["type"],"instances":data_dict["instances"][:args.eval_size]} 102 | test_data_dict = {"type":data_dict["type"],"instances":data_dict["instances"][-args.test_size:]} 103 | del data_dict 104 | gc.collect() 105 | 106 | # divide train in 10 folds 107 | num_instances = len(train_data_dict["instances"]) 108 | split_size = num_instances // args.k 109 | split_data = [] 110 | for i in range(args.k): 111 | if i < args.k-1: 112 | split = train_data_dict["instances"][i*split_size : (i+1)*split_size] 113 | else: 114 | # Last split may have remaining instances 115 | split = 
train_data_dict["instances"][i*split_size:] 116 | split_data.append({'type': train_data_dict["type"], 'instances': split}) 117 | 118 | del train_data_dict 119 | gc.collect() 120 | 121 | print('finish split') 122 | # save dataset under output_path 123 | 124 | if args.output_path is None: 125 | args.output_path = sys.stdout 126 | 127 | train_save_path=os.path.join(args.output_path,"train_{k}_folds".format(k=args.k)) 128 | if not os.path.exists(train_save_path): 129 | os.makedirs(train_save_path) 130 | for i in range(args.k): 131 | with open(train_save_path+"/train_"+str(i)+".json", 'w') as f: 132 | json.dump(split_data[i], f, indent=4, ensure_ascii=False) 133 | 134 | eval_save_path=os.path.join(args.output_path,"eval") 135 | if not os.path.exists(eval_save_path): 136 | os.makedirs(eval_save_path) 137 | with open(eval_save_path+'/eval.json','w') as f: 138 | json.dump(eval_data_dict,f,indent=4,ensure_ascii=False) 139 | 140 | test_save_path=os.path.join(args.output_path,"test") 141 | if not os.path.exists(test_save_path): 142 | os.makedirs(test_save_path) 143 | with open(test_save_path+'/test.json','w') as f: 144 | json.dump(test_data_dict,f,indent=4,ensure_ascii=False) 145 | 146 | 147 | 148 | if __name__ == "__main__": 149 | main() 150 | -------------------------------------------------------------------------------- /src/lmflow/utils/flash_attention/gpt2_flash_attention.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional, Tuple, Union 2 | 3 | import torch 4 | from torch import nn 5 | 6 | import transformers 7 | from transformers.models.llama.modeling_llama import apply_rotary_pos_emb 8 | 9 | from einops import rearrange 10 | 11 | #try to import flash_attn 2.x.x, if not, import flash_attn 1.x.x 12 | try: 13 | from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func as flash_attn_unpadded_qkvpacked_func 14 | except: 15 | from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func 16 | 17 | from flash_attn.bert_padding import unpad_input, pad_input 18 | 19 | 20 | def forward( 21 | self, 22 | hidden_states: Optional[Tuple[torch.FloatTensor]], 23 | layer_past: Optional[Tuple[torch.Tensor]] = None, 24 | attention_mask: Optional[torch.FloatTensor] = None, 25 | head_mask: Optional[torch.FloatTensor] = None, 26 | encoder_hidden_states: Optional[torch.Tensor] = None, 27 | encoder_attention_mask: Optional[torch.FloatTensor] = None, 28 | use_cache: Optional[bool] = False, 29 | output_attentions: Optional[bool] = False, 30 | ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]: 31 | 32 | 33 | if encoder_hidden_states is not None: 34 | if not hasattr(self, "q_attn"): 35 | raise ValueError( 36 | "If class is used as cross attention, the weights `q_attn` have to be defined. " 37 | "Please make sure to instantiate class with `GPT2Attention(..., is_cross_attention=True)`." 38 | ) 39 | 40 | query = self.q_attn(hidden_states) 41 | key, value = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2) 42 | attention_mask = encoder_attention_mask 43 | else: 44 | query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2) 45 | 46 | bsz, q_len, _ = hidden_states.size() 47 | 48 | query = self._split_heads(query, self.num_heads, self.head_dim) 49 | key = self._split_heads(key, self.num_heads, self.head_dim) 50 | value = self._split_heads(value, self.num_heads, self.head_dim) 51 | 52 | #TODO Should we support? 
53 | if layer_past is not None: 54 | past_key, past_value = layer_past 55 | key = torch.cat((past_key, key), dim=-2) 56 | value = torch.cat((past_value, value), dim=-2) 57 | 58 | assert use_cache is False, "Use cache is not supported" 59 | present = None 60 | # if use_cache is True: 61 | # present = (key, value) 62 | # else: 63 | # present = None 64 | 65 | assert self.reorder_and_upcast_attn is False, "reorder_and_upcast_attn is not supported yet" 66 | 67 | qkv = torch.stack([query, key, value], dim = 2) 68 | qkv = qkv.transpose(1, 3) # [bsz, seq_len, 3, heads, hiddens_per_head] 69 | 70 | # breakpoint() 71 | key_padding_mask = attention_mask 72 | # key_padding_mask = None 73 | # breakpoint() 74 | if key_padding_mask is None: 75 | qkv = rearrange(qkv, "b s ... -> (b s) ...") 76 | max_s = q_len 77 | cu_q_lens = torch.arange( 78 | 0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device=qkv.device 79 | ) 80 | output = flash_attn_unpadded_qkvpacked_func( 81 | qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True 82 | ) 83 | output = rearrange(output, "(b s) ... -> b s ...", b=bsz) 84 | else: 85 | # flip in flash attention 86 | key_padding_mask = key_padding_mask.clone() 87 | key_padding_mask = (1.0 - key_padding_mask) 88 | key_padding_mask = key_padding_mask.squeeze(1).squeeze(1) 89 | nheads = qkv.shape[-2] 90 | x = rearrange(qkv, "b s three h d -> b s (three h d)") 91 | x_unpad, indices, cu_q_lens, max_s = unpad_input(x, key_padding_mask) 92 | x_unpad = rearrange( 93 | x_unpad, "nnz (three h d) -> nnz three h d", three=3, h=nheads 94 | ) 95 | output_unpad = flash_attn_unpadded_qkvpacked_func( 96 | x_unpad, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True 97 | ) 98 | output = rearrange( 99 | pad_input( 100 | rearrange(output_unpad, "nnz h d -> nnz (h d)"), indices, bsz, q_len 101 | ), 102 | "b s (h d) -> b s h d", 103 | h=nheads, 104 | ) 105 | # if self.reorder_and_upcast_attn: 106 | # attn_output, attn_weights = self._upcast_and_reordered_attn(query, key, value, attention_mask, head_mask) 107 | # else: 108 | # attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask) 109 | output = rearrange(output, 'b s h d -> b h s d') 110 | attn_output = self._merge_heads(output, self.num_heads, self.head_dim) 111 | attn_output = self.c_proj(attn_output) 112 | attn_output = self.resid_dropout(attn_output) 113 | 114 | outputs = (attn_output, present) 115 | 116 | assert output_attentions is False, "output attentions is not supported yet" 117 | # if output_attentions: 118 | # outputs += (attn_weights,) 119 | 120 | return outputs # a, present, (attentions) 121 | 122 | 123 | # Disable the transformation of the attention mask in LlamaModel as the flash attention 124 | # requires the attention mask to be the same as the key_padding_mask 125 | def _prepare_decoder_attention_mask( 126 | self, attention_mask, input_shape, inputs_embeds, past_key_values_length 127 | ): 128 | # [bsz, seq_len] 129 | return attention_mask 130 | 131 | 132 | def replace_gpt2_attn_with_flash_attn(): 133 | # transformers.models.gpt2.modeling_gpt2.LlamaModel._prepare_decoder_attention_mask = ( 134 | # _prepare_decoder_attention_mask 135 | # ) 136 | transformers.models.gpt2.modeling_gpt2.GPT2Attention.forward = forward -------------------------------------------------------------------------------- /utils/apply_delta.py: -------------------------------------------------------------------------------- 1 | """ 2 | Apply the delta weights on top of a base model. 
3 | 4 | Usage: 5 | python3 -m fastchat.model.apply_delta --base ~/model_weights/llama-7b --target ~/model_weights/vicuna-7b --delta lmsys/vicuna-7b-delta-v1.1 6 | """ 7 | import argparse 8 | import gc 9 | import glob 10 | import json 11 | import os 12 | import shutil 13 | import tempfile 14 | 15 | from huggingface_hub import snapshot_download 16 | import torch 17 | from torch import nn 18 | from tqdm import tqdm 19 | from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig 20 | 21 | 22 | GB = 1 << 30 23 | 24 | 25 | def split_files(model_path, tmp_path, split_size): 26 | if not os.path.exists(model_path): 27 | model_path = snapshot_download(repo_id=model_path) 28 | if not os.path.exists(tmp_path): 29 | os.makedirs(tmp_path) 30 | 31 | file_pattern = os.path.join(model_path, "pytorch_model-*.bin") 32 | files = glob.glob(file_pattern) 33 | 34 | part = 0 35 | try: 36 | for file_path in tqdm(files): 37 | state_dict = torch.load(file_path) 38 | new_state_dict = {} 39 | 40 | current_size = 0 41 | for name, param in state_dict.items(): 42 | param_size = param.numel() * param.element_size() 43 | 44 | if current_size + param_size > split_size: 45 | new_file_name = f"pytorch_model-{part}.bin" 46 | new_file_path = os.path.join(tmp_path, new_file_name) 47 | torch.save(new_state_dict, new_file_path) 48 | current_size = 0 49 | new_state_dict = None 50 | gc.collect() 51 | new_state_dict = {} 52 | part += 1 53 | 54 | new_state_dict[name] = param 55 | current_size += param_size 56 | 57 | new_file_name = f"pytorch_model-{part}.bin" 58 | new_file_path = os.path.join(tmp_path, new_file_name) 59 | torch.save(new_state_dict, new_file_path) 60 | new_state_dict = None 61 | gc.collect() 62 | new_state_dict = {} 63 | part += 1 64 | except Exception as e: 65 | print(f"An error occurred during split_files: {e}") 66 | shutil.rmtree(tmp_path) 67 | raise 68 | 69 | 70 | def apply_delta_low_cpu_mem(base_model_path, target_model_path, delta_path): 71 | delta_tokenizer = AutoTokenizer.from_pretrained(delta_path, use_fast=False) 72 | delta_config = AutoConfig.from_pretrained(delta_path) 73 | 74 | if os.path.exists(target_model_path): 75 | shutil.rmtree(target_model_path) 76 | os.makedirs(target_model_path) 77 | 78 | split_size = 4 * GB 79 | 80 | with tempfile.TemporaryDirectory() as tmp_base_path, tempfile.TemporaryDirectory() as tmp_delta_path: 81 | print(f"Split files for the base model to {tmp_base_path}") 82 | split_files(base_model_path, tmp_base_path, split_size) 83 | print(f"Split files for the delta weights to {tmp_delta_path}") 84 | split_files(delta_path, tmp_delta_path, split_size) 85 | 86 | base_pattern = os.path.join(tmp_base_path, "pytorch_model-*.bin") 87 | base_files = glob.glob(base_pattern) 88 | delta_pattern = os.path.join(tmp_delta_path, "pytorch_model-*.bin") 89 | delta_files = glob.glob(delta_pattern) 90 | delta_state_dict = torch.load(delta_files[0]) 91 | 92 | print("Applying the delta") 93 | weight_map = {} 94 | total_size = 0 95 | 96 | for i, base_file in tqdm(enumerate(base_files)): 97 | state_dict = torch.load(base_file) 98 | file_name = f"pytorch_model-{i}.bin" 99 | for name, param in state_dict.items(): 100 | if name not in delta_state_dict: 101 | for delta_file in delta_files: 102 | delta_state_dict = torch.load(delta_file) 103 | gc.collect() 104 | if name in delta_state_dict: 105 | break 106 | 107 | state_dict[name] += delta_state_dict[name] 108 | weight_map[name] = file_name 109 | total_size += param.numel() * param.element_size() 110 | gc.collect() 111 | torch.save(state_dict, 
os.path.join(target_model_path, file_name)) 112 | 113 | with open( 114 | os.path.join(target_model_path, "pytorch_model.bin.index.json"), "w" 115 | ) as f: 116 | json.dump( 117 | {"weight_map": weight_map, "metadata": {"total_size": total_size}}, f 118 | ) 119 | 120 | print(f"Saving the target model to {target_model_path}") 121 | delta_tokenizer.save_pretrained(target_model_path) 122 | delta_config.save_pretrained(target_model_path) 123 | 124 | 125 | def apply_delta(base_model_path, target_model_path, delta_path): 126 | print(f"Loading the delta weights from {delta_path}") 127 | delta_tokenizer = AutoTokenizer.from_pretrained(delta_path, use_fast=False) 128 | delta = AutoModelForCausalLM.from_pretrained( 129 | delta_path, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True 130 | ) 131 | 132 | print(f"Loading the base model from {base_model_path}") 133 | base = AutoModelForCausalLM.from_pretrained( 134 | base_model_path, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True 135 | ) 136 | 137 | print("Applying the delta") 138 | for name, param in tqdm(base.state_dict().items(), desc="Applying delta"): 139 | assert name in delta.state_dict() 140 | param.data += delta.state_dict()[name] 141 | 142 | print(f"Saving the target model to {target_model_path}") 143 | base.save_pretrained(target_model_path) 144 | delta_tokenizer.save_pretrained(target_model_path) 145 | 146 | 147 | if __name__ == "__main__": 148 | parser = argparse.ArgumentParser() 149 | parser.add_argument("--base-model-path", type=str, required=True) 150 | parser.add_argument("--target-model-path", type=str, required=True) 151 | parser.add_argument("--delta-path", type=str, required=True) 152 | parser.add_argument( 153 | "--low-cpu-mem", 154 | action="store_true", 155 | help="Lower the cpu memory usage. This will split large files and use " 156 | "disk as swap to reduce the memory usage below 10GB.", 157 | ) 158 | args = parser.parse_args() 159 | 160 | if args.low_cpu_mem: 161 | apply_delta_low_cpu_mem( 162 | args.base_model_path, args.target_model_path, args.delta_path 163 | ) 164 | else: 165 | apply_delta(args.base_model_path, args.target_model_path, args.delta_path) 166 | -------------------------------------------------------------------------------- /scripts/postprocess/weight_mask_merge.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | # Copyright 2023 Statistics and Machine Learning Research Group at HKUST. All rights reserved. 4 | """A one-line summary of the module or program, terminated by a period. 5 | 6 | Leave one blank line. The rest of this docstring should contain an 7 | overall description of the module or program. Optionally, it may also 8 | contain a brief description of exported classes and functions and/or usage 9 | examples. 
10 | 11 | Typical usage example: 12 | 13 | foo = ClassFoo() 14 | bar = foo.FunctionBar() 15 | """ 16 | 17 | import json 18 | import os 19 | import sys 20 | sys.path.remove(os.path.abspath(os.path.dirname(sys.argv[0]))) 21 | from transformers import HfArgumentParser 22 | import torch.nn.functional as F 23 | from lmflow.datasets.dataset import Dataset 24 | from lmflow.pipeline.auto_pipeline import AutoPipeline 25 | from lmflow.models.auto_model import AutoModel 26 | from lmflow.args import ModelArguments, DatasetArguments, AutoArguments 27 | import torch 28 | import torch.nn as nn 29 | 30 | pipeline_name = "evaluator" 31 | PipelineArguments = AutoArguments.get_pipeline_args_class(pipeline_name) 32 | 33 | parser = HfArgumentParser((ModelArguments, DatasetArguments, PipelineArguments)) 34 | model_args, data_args, pipeline_args = parser.parse_args_into_dataclasses() 35 | 36 | # Get the paths and ratios of weight-ensamble models. 37 | weight_ensamble_names_paths = pipeline_args.weight_ensamble_names_paths 38 | weight_ensamble_save_path = pipeline_args.weight_ensamble_save_path 39 | alphas_path = pipeline_args.alphas_path 40 | 41 | print('Model Paths:', weight_ensamble_names_paths) 42 | print('Alphas Paths:', alphas_path) 43 | 44 | with open (pipeline_args.deepspeed, "r") as f: 45 | ds_config = json.load(f) 46 | 47 | # base_model = AutoModel.get_model( 48 | # model_args, 49 | # tune_strategy='none', 50 | # ds_config=ds_config, 51 | # use_accelerator=pipeline_args.use_accelerator_for_evaluator 52 | # ) 53 | 54 | # base_backend_model = base_model.get_backend_model() 55 | # print('Finish load base model:', base_model) 56 | # Load models. 57 | # base_model = None 58 | # backend_models = [] 59 | # merge_model_path = weight_ensamble_names_paths[1] 60 | # merge_ckpt = torch.load(os.path.join(merge_model_path, 'pytorch_model.bin')) 61 | # merge_method = 'linear' 62 | 63 | # if 'graft' in merge_model_path: 64 | # merge_method = 'graft' 65 | # print(merge_ckpt) 66 | 67 | def load_model(model_path): 68 | model_args.model_name_or_path = model_path 69 | print('loading:', model_path) 70 | model = AutoModel.get_model( 71 | model_args, 72 | tune_strategy='none', 73 | ds_config=ds_config, 74 | use_accelerator=pipeline_args.use_accelerator_for_evaluator 75 | ) 76 | backend_model = model.get_backend_model().to('cpu') 77 | model = model 78 | print('Finish load base model:', model_path) 79 | return backend_model, model 80 | 81 | ## Load base model 82 | base_backend_model, base_model = load_model(weight_ensamble_names_paths[0]) 83 | 84 | ## Load ft model 85 | ft_backend_model, ft_model = load_model(weight_ensamble_names_paths[1]) 86 | 87 | def update_by_wise_norm_sigmoid_linear(init_val, epsilon, w_type, key, mask_alphas, normalized_alphas, 88 | base_state_dicts, ft_state_dicts, updated_state_dicts): 89 | """ 90 | w_type: weight type ['weight','bias'] 91 | key: linear model key 92 | """ 93 | 94 | # selected_keys = ['model.layers.0.self_attn.k_proj', 'model.layers.15.self_attn.k_proj', 'model.layers.20.self_attn.k_proj'] 95 | if (key + f'.{w_type}') not in base_state_dicts: 96 | return 97 | base_weight = base_state_dicts[key + f'.{w_type}'] 98 | ft_weight = ft_state_dicts[key + f'.{w_type}'] 99 | if 'lm_head' in key: 100 | wise_alpha = torch.tensor(init_val).to(device) 101 | else: 102 | mask_alpha = mask_alphas[key].to(device) 103 | wise_alpha = torch.sigmoid(mask_alpha) + epsilon 104 | normalized_alphas = [torch.sigmoid((normalized_alphas[k_]).to(device)) + epsilon for k_ in normalized_alphas ] 105 | wise_alpha = 
init_val * wise_alpha / sum(normalized_alphas) * len(normalized_alphas) 106 | print('Update w', key, wise_alpha, 'eps', epsilon) 107 | updated_weight = base_weight * wise_alpha + (1 - wise_alpha) * ft_weight 108 | updated_state_dicts[key + f'.{w_type}'] = updated_weight 109 | return updated_weight 110 | 111 | def _get_submodules(model, key): 112 | parent = model.get_submodule(".".join(key.split(".")[:-1])) 113 | target_name = key.split(".")[-1] 114 | target = model.get_submodule(key) 115 | return parent, target, target_name 116 | 117 | updated_state_dicts = {} 118 | base_state_dicts = base_backend_model.state_dict() 119 | ft_state_dicts = ft_backend_model.state_dict() 120 | device=torch.device('cpu') 121 | merge_method = 'mask_norm_sigmoid_linear' 122 | ## @TODO Potential Extension 123 | if 'mask_norm_sigmoid_linear' in alphas_path: 124 | merge_method = 'mask_norm_sigmoid_linear' 125 | 126 | wise_data = torch.load(os.path.join(alphas_path, 'mask_alphas.bin')) 127 | assert 'init_val' in wise_data, 'There should be a init val.' 128 | init_val = wise_data['init_val'] 129 | mask_alphas = wise_data['mask_alphas'] 130 | 131 | base_gammas = {} 132 | if 'base_gammas' in wise_data: 133 | base_gammas = wise_data['base_gammas'] 134 | 135 | epsilon = 1e-6 136 | if 'epsilon' in wise_data: 137 | epsilon = wise_data['epsilon'] 138 | 139 | normalized_alphas = None 140 | if 'mask_bkp_alphas' in wise_data: 141 | normalized_alphas = wise_data['mask_bkp_alphas'] 142 | 143 | print('Mask alphas:', mask_alphas) 144 | print(f'Merge Method: {merge_method}, Init val:{init_val}, Epsilon:{epsilon}, Normalized alphas:{normalized_alphas}.') 145 | 146 | key_list = [key for key, _ in base_backend_model.named_modules()] 147 | for key in key_list: 148 | parent0, target0, target0_name = _get_submodules(base_backend_model, key) 149 | key_terms = key.split('.') 150 | if isinstance(target0, nn.Linear): 151 | for w_type in ['weight', 'bias']: 152 | if merge_method == 'mask_norm_sigmoid_linear': 153 | if 'lm_head' in key: 154 | if 'final0' in alphas_path: 155 | init_val = 0 156 | update_by_wise_norm_sigmoid_linear(init_val, epsilon, w_type, key, mask_alphas, normalized_alphas, 157 | base_state_dicts, ft_state_dicts, updated_state_dicts) 158 | else: 159 | print(f'{key} No Merge.') 160 | 161 | for key in base_state_dicts: 162 | if key not in updated_state_dicts: 163 | updated_state_dicts[key] = base_state_dicts[key] 164 | 165 | base_backend_model.load_state_dict(updated_state_dicts) 166 | print(weight_ensamble_save_path) 167 | base_model.save(weight_ensamble_save_path) 168 | -------------------------------------------------------------------------------- /data/download.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | function main() { 4 | public_server="http://lmflow.org:5000" 5 | if [ $# -lt 1 -o "$1" = "-h" -o "$1" = "--help" ]; then 6 | echo "Usage: bash $(basename $0) dataset_name" 7 | echo "Example: bash $(basename $0) MedMCQA" 8 | echo "Example: bash $(basename $0) all" 9 | fi 10 | 11 | if [ "$1" = "MedMCQA" -o "$1" = "all" ]; then 12 | echo "downloading MedMCQA" 13 | filename='MedMCQA.tar.gz' 14 | wget ${public_server}/${filename} 15 | tar zxvf ${filename} 16 | rm ${filename} 17 | fi 18 | 19 | if [ "$1" = "MedQA-USMLE" -o "$1" = "all" ]; then 20 | echo "downloading MedQA-USMLE" 21 | filename='MedQA-USMLE.tar.gz' 22 | wget ${public_server}/${filename} 23 | tar zxvf ${filename} 24 | rm ${filename} 25 | fi 26 | 27 | if [ "$1" = "ni" ]; then 28 | echo "downloading 
natural-instructions" 29 | filename='natural-instructions.tar.gz' 30 | wget ${public_server}/${filename} 31 | tar zxvf ${filename} 32 | rm ${filename} 33 | fi 34 | 35 | if [ "$1" = "PubMedQA" -o "$1" = "all" ]; then 36 | echo "downloading PubMedQA" 37 | filename='PubMedQA.tar.gz' 38 | wget ${public_server}/${filename} 39 | tar zxvf ${filename} 40 | rm ${filename} 41 | fi 42 | 43 | if [ "$1" = "example_dataset" -o "$1" = "all" ]; then 44 | echo "downloading example_dataset" 45 | filename='example_dataset.tar.gz' 46 | wget ${public_server}/${filename} 47 | tar zxvf ${filename} 48 | rm ${filename} 49 | fi 50 | 51 | if [ "$1" = "alpaca" -o "$1" = "all" ]; then 52 | echo "downloading alpaca dataset" 53 | filename='alpaca.tar.gz' 54 | wget ${public_server}/${filename} 55 | tar zxvf ${filename} 56 | rm ${filename} 57 | fi 58 | 59 | if [ "$1" = "red_teaming" -o "$1" = "all" ]; then 60 | echo "downloading red_teaming dataset" 61 | filename='red_teaming.tar.gz' 62 | wget ${public_server}/${filename} 63 | tar zxvf ${filename} 64 | rm ${filename} 65 | fi 66 | 67 | if [ "$1" = "wikitext-2-raw-v1" -o "$1" = "all" ]; then 68 | echo "downloading wikitext-2-raw-v1 dataset" 69 | filename='wikitext-2-raw-v1.tar.gz' 70 | wget ${public_server}/${filename} 71 | tar zxvf ${filename} 72 | rm ${filename} 73 | fi 74 | 75 | if [ "$1" = "imdb" -o "$1" = "all" ]; then 76 | echo "downloading imdb dataset" 77 | filename='imdb.tar.gz' 78 | wget ${public_server}/${filename} 79 | tar zxvf ${filename} 80 | rm ${filename} 81 | fi 82 | 83 | if [ "$1" = "wiki_cn" -o "$1" = "all" ]; then 84 | echo "downloading wiki_cn dataset" 85 | filename='wiki_cn.tar.gz' 86 | wget ${public_server}/${filename} 87 | tar zxvf ${filename} 88 | rm ${filename} 89 | fi 90 | 91 | if [ "$1" = "gpt4_zh_eval" -o "$1" = "all" ]; then 92 | echo "downloading gpt4_zh_eval dataset" 93 | filename='gpt4_instruction_zh_eval.tar.gz' 94 | wget ${public_server}/${filename} 95 | tar zxvf ${filename} 96 | rm ${filename} 97 | fi 98 | 99 | if [ "$1" = "multiturn_dialog_eval" -o "$1" = "all" ]; then 100 | echo "downloading multiturn_dialog_eval dataset" 101 | filename='multiturn_dialog_eval.tar.gz' 102 | wget ${public_server}/${filename} 103 | tar zxvf ${filename} 104 | rm ${filename} 105 | fi 106 | 107 | if [ "$1" = "wiki_zh_eval" -o "$1" = "all" ]; then 108 | echo "downloading wiki_zh_eval dataset" 109 | filename='wiki_zh_eval.tar.gz' 110 | wget ${public_server}/${filename} 111 | tar zxvf ${filename} 112 | rm ${filename} 113 | fi 114 | 115 | if [ "$1" = "wiki_en_eval" -o "$1" = "all" ]; then 116 | echo "downloading wiki_en_eval dataset" 117 | filename='wiki_en_eval.tar.gz' 118 | wget ${public_server}/${filename} 119 | tar zxvf ${filename} 120 | rm ${filename} 121 | fi 122 | 123 | if [ "$1" = "wiki_en_eval" -o "$1" = "all" ]; then 124 | echo "downloading wiki_en_eval dataset" 125 | filename='wiki_en_eval.tar.gz' 126 | wget ${public_server}/${filename} 127 | tar zxvf ${filename} 128 | rm ${filename} 129 | fi 130 | 131 | if [ "$1" = "gpt4_en_eval" -o "$1" = "all" ]; then 132 | echo "downloading gpt4_en_eval dataset" 133 | filename='gpt4_instruction_en_eval.tar.gz' 134 | wget ${public_server}/${filename} 135 | tar zxvf ${filename} 136 | rm ${filename} 137 | fi 138 | 139 | if [ "$1" = "common_sense_eval" -o "$1" = "all" ]; then 140 | echo "downloading common_sense_eval dataset" 141 | filename='common_sense_eval.tar.gz' 142 | wget ${public_server}/${filename} 143 | tar zxvf ${filename} 144 | rm ${filename} 145 | fi 146 | 147 | if [ "$1" = "hh_rlhf" -o "$1" = "all" ]; then 
148 | echo "downloading hh_rlhf dataset" 149 | filename='hh_rlhf.tar.gz' 150 | wget ${public_server}/${filename} 151 | tar zxvf ${filename} 152 | rm ${filename} 153 | fi 154 | 155 | if [ "$1" = "lmflow_chat_cn_dialog_multiturn_nll_text2text_nosharp" -o "$1" = "all" ]; then 156 | echo "downloading lmflow_chat_cn_dialog_multiturn_nll_text2text_nosharp dataset" 157 | filename='lmflow_chat_cn_dialog_multiturn_nll_text2text_nosharp.tar.gz' 158 | wget ${public_server}/${filename} 159 | tar zxvf ${filename} 160 | rm ${filename} 161 | fi 162 | 163 | if [ "$1" = "lmflow_chat_cn_dialog_multiturn_single_nll_text2text" -o "$1" = "all" ]; then 164 | echo "downloading lmflow_chat_cn_dialog_multiturn_single_nll_text2text dataset" 165 | filename='lmflow_chat_cn_dialog_multiturn_single_nll_text2text.tar.gz' 166 | wget ${public_server}/${filename} 167 | tar zxvf ${filename} 168 | rm ${filename} 169 | fi 170 | 171 | if [ "$1" = "lmflow_chat_en_dialog_multiturn_nll_text2text_nosharp" -o "$1" = "all" ]; then 172 | echo "downloading lmflow_chat_en_dialog_multiturn_nll_text2text_nosharp dataset" 173 | filename='lmflow_chat_en_dialog_multiturn_nll_text2text_nosharp.tar.gz' 174 | wget ${public_server}/${filename} 175 | tar zxvf ${filename} 176 | rm ${filename} 177 | fi 178 | 179 | if [ "$1" = "lmflow_chat_en_dialog_multiturn_single_nll_text2text" -o "$1" = "all" ]; then 180 | echo "downloading lmflow_chat_en_dialog_multiturn_single_nll_text2text dataset" 181 | filename='lmflow_chat_en_dialog_multiturn_single_nll_text2text.tar.gz' 182 | wget ${public_server}/${filename} 183 | tar zxvf ${filename} 184 | rm ${filename} 185 | fi 186 | } 187 | main "$@" 188 | 189 | 190 | 191 | -------------------------------------------------------------------------------- /src/lmflow/pipeline/utils/continual_trainer.py: -------------------------------------------------------------------------------- 1 | 2 | import copy 3 | 4 | import torch 5 | from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union 6 | from transformers import ( 7 | Trainer 8 | ) 9 | 10 | from transformers.trainer import * 11 | 12 | if is_apex_available(): 13 | from apex import amp 14 | 15 | import deepspeed 16 | 17 | 18 | def fisher_matrix_diag_bert_dil(t, train, device, model, criterion, sbatch=20): 19 | # Init 20 | fisher = {} 21 | for n, p in model.named_parameters(): 22 | fisher[n] = 0 * p.data 23 | # Compute 24 | model.train() 25 | 26 | for i in tqdm(range(0, len(train), sbatch), desc='Fisher diagonal', ncols=100, ascii=True): 27 | b = torch.LongTensor(np.arange(i, np.min([i + sbatch, len(train)]))).cuda() 28 | batch = train[b] 29 | batch = [ 30 | bat.to(device) if bat is not None else None for bat in batch] 31 | input_ids, segment_ids, input_mask, targets, _ = batch 32 | 33 | # Forward and backward 34 | model.zero_grad() 35 | output_dict = model.forward(input_ids, segment_ids, input_mask) 36 | output = output_dict['y'] 37 | 38 | loss = criterion(t, output, targets) 39 | loss.backward() 40 | # Get gradients 41 | for n, p in model.named_parameters(): 42 | if p.grad is not None: 43 | fisher[n] += sbatch * p.grad.data.pow(2) 44 | # Mean 45 | for n, _ in model.named_parameters(): 46 | fisher[n] = fisher[n] / len(train) 47 | fisher[n] = torch.autograd.Variable(fisher[n], requires_grad=False) 48 | return fisher 49 | 50 | 51 | approaches = ["kd", "l1", "l2", "swa", "default"] 52 | 53 | class ContinualTrainer(Trainer): 54 | def __init__(self, *args, **kwargs): 55 | super(ContinualTrainer, self).__init__(*args, **kwargs) 56 | training_args = 
kwargs['args'] 57 | self.approach = training_args.approach 58 | self.alpha = training_args.alpha 59 | if self.approach in ["kd", "l1", "l2"]: 60 | teacher_model = copy.deepcopy(self.model) 61 | self.ds_engine_teacher = deepspeed.initialize(model=teacher_model, config_params="examples/ds_config.json")[0] 62 | self.ds_engine_teacher.module.eval() 63 | for p in self.ds_engine_teacher.module.parameters(): 64 | p.requires_grad = False 65 | self.teacher_params = {n: p for n, p in self.ds_engine_teacher.module.named_parameters()} 66 | # print(self.teacher_params) 67 | self.teacher_model = self.ds_engine_teacher.module 68 | self.ce_loss_fct = nn.KLDivLoss(reduction="batchmean") 69 | self.temperature = 1 70 | 71 | def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor: 72 | model.train() 73 | inputs = self._prepare_inputs(inputs) 74 | 75 | if is_sagemaker_mp_enabled(): 76 | loss_mb = smp_forward_backward(model, inputs, self.args.gradient_accumulation_steps) 77 | return loss_mb.reduce_mean().detach().to(self.args.device) 78 | 79 | with self.compute_loss_context_manager(): 80 | loss = self.compute_loss(model, inputs) 81 | 82 | if self.args.n_gpu > 1: 83 | loss = loss.mean() # mean() to average on multi-gpu parallel training 84 | 85 | if self.args.gradient_accumulation_steps > 1 and not self.deepspeed: 86 | # deepspeed handles loss scaling by gradient_accumulation_steps in its `backward` 87 | loss = loss / self.args.gradient_accumulation_steps 88 | 89 | if self.do_grad_scaling: 90 | self.scaler.scale(loss).backward() 91 | elif self.use_apex: 92 | with amp.scale_loss(loss, self.optimizer) as scaled_loss: 93 | scaled_loss.backward() 94 | elif self.deepspeed: 95 | # loss gets scaled under gradient_accumulation_steps in deepspeed 96 | loss = self.deepspeed.backward(loss) 97 | else: 98 | loss.backward() 99 | 100 | return loss.detach() 101 | 102 | def compute_loss(self, model, inputs, return_outputs=False): 103 | loss_func = getattr(self, f"compute_loss_{self.approach}") 104 | results = loss_func(model, inputs, return_outputs=return_outputs) 105 | return results 106 | 107 | def compute_loss_kd(self, model, inputs, return_outputs=False): 108 | loss, student_outputs = super().compute_loss(model, inputs, return_outputs=True) 109 | with torch.no_grad(): 110 | teacher_outputs = self.ds_engine_teacher(**inputs) 111 | s_logits, s_hidden_states = student_outputs["logits"], student_outputs.hidden_states 112 | t_logits, t_hidden_states = teacher_outputs["logits"], teacher_outputs.hidden_states 113 | kd_loss = ( 114 | self.ce_loss_fct( 115 | nn.functional.log_softmax(s_logits / self.temperature, dim=-1), 116 | nn.functional.softmax(t_logits / self.temperature, dim=-1), 117 | ) 118 | * (self.temperature) ** 2 119 | ) 120 | loss = loss + self.alpha * kd_loss 121 | student_outputs["loss"] = loss 122 | return (loss, student_outputs) if return_outputs else loss 123 | 124 | def compute_loss_l2(self, model, inputs, return_outputs=False): 125 | loss, outputs = super().compute_loss(model, inputs, return_outputs=True) 126 | task_reg_loss = 0 127 | param_num = 0 128 | for n, p in model.module.named_parameters(): 129 | if p.requires_grad: 130 | task_reg_loss += ((p - self.teacher_params[n]) ** 2).sum() 131 | param_num += torch.numel(p) 132 | loss = loss + self.alpha * task_reg_loss / param_num * 1e+9 133 | print(loss, 'task', task_reg_loss, 'actual', self.alpha * task_reg_loss / param_num * 1e+9) 134 | outputs["loss"] = loss 135 | return (loss, outputs) if return_outputs else loss 
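    # NOTE: compute_loss_l2 above and compute_loss_l1 below regularize the trainable
    # parameters toward self.teacher_params, a frozen copy of the model taken when the
    # trainer is constructed (i.e., before this round of training). The penalty is
    # averaged over the number of trainable parameters and rescaled by a fixed constant
    # (1e+9 for L2, 1e+5 for L1) before being weighted by self.alpha.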
136 | 137 | def compute_loss_l1(self, model, inputs, return_outputs=False): 138 | # print('running l1 loss') 139 | # print(model.state_dict()) 140 | loss, outputs = super().compute_loss(model, inputs, return_outputs=True) 141 | task_reg_loss = 0 142 | param_num = 0 143 | for n, p in model.module.named_parameters(): 144 | if p.requires_grad: 145 | # if len(p.size()) > 0 and p.size()[0] > 0: 146 | # print(p.size(), self.teacher_params[n]) 147 | task_reg_loss += (torch.abs(p - self.teacher_params[n])).sum() 148 | param_num += torch.numel(p) 149 | # print('whole', task_reg_loss) 150 | 151 | loss = loss + self.alpha * task_reg_loss / param_num * 1e+5 152 | print(loss, 'task', task_reg_loss, 'actual', self.alpha * task_reg_loss / param_num * 1e+5) 153 | outputs["loss"] = loss 154 | return (loss, outputs) if return_outputs else loss 155 | 156 | def compute_loss_swa(self, model, inputs, return_outputs=False): 157 | return super().compute_loss(model, inputs, return_outputs=return_outputs) 158 | 159 | def compute_loss_default(self, model, inputs, return_outputs=False): 160 | return super().compute_loss(model, inputs, return_outputs=return_outputs) 161 | -------------------------------------------------------------------------------- /src/lmflow/utils/data_utils.py: -------------------------------------------------------------------------------- 1 | """The program includes several functions: setting a random seed, 2 | loading data from a JSON file, batching data, and extracting answers from generated text. 3 | """ 4 | 5 | import random 6 | import numpy as np 7 | import torch 8 | import json 9 | import re 10 | def set_random_seed(seed: int): 11 | """ 12 | Set the random seed for `random`, `numpy`, `torch`, `torch.cuda`. 13 | 14 | Parameters 15 | ------------ 16 | seed : int 17 | The default seed. 18 | 19 | """ 20 | random.seed(seed) 21 | np.random.seed(seed) 22 | torch.manual_seed(seed) 23 | if torch.cuda.is_available(): 24 | torch.cuda.manual_seed_all(seed) 25 | 26 | def load_data(file_name: str): 27 | """ 28 | Load data with file name. 29 | 30 | Parameters 31 | ------------ 32 | file_name : str. 33 | The dataset file name. 34 | 35 | Returns 36 | ------------ 37 | inputs : list. 38 | The input texts of the dataset. 39 | outputs : list. 40 | The output texts file datasets. 41 | len : int. 42 | The length of the dataset. 43 | """ 44 | inputs = [] 45 | outputs = [] 46 | type = "" 47 | with open(file_name, encoding='utf-8') as f: 48 | json_data = json.load(f) 49 | type = json_data["type"] 50 | for line in json_data["instances"]: 51 | inputs.append(line["input"]) 52 | outputs.append(line["output"]) 53 | 54 | print(f"load dataset {file_name} success.\n") 55 | print(f"Type : {type}, datasize : {len(outputs)}") 56 | 57 | return inputs, outputs, len(outputs) 58 | 59 | def batchlize(examples: list, batch_size: int, random_shuffle: bool): 60 | """ 61 | Convert examples to a dataloader. 62 | 63 | Parameters 64 | ------------ 65 | examples : list. 66 | Data list. 67 | batch_size : int. 68 | 69 | random_shuffle : bool 70 | If true, the dataloader shuffle the training data. 71 | 72 | Returns 73 | ------------ 74 | dataloader: 75 | Dataloader with batch generator. 
76 | """ 77 | size = 0 78 | dataloader = [] 79 | length = len(examples) 80 | if (random_shuffle): 81 | random.shuffle(examples) 82 | while size < length: 83 | if length - size > batch_size: 84 | dataloader.append(examples[size : size+batch_size]) 85 | size += batch_size 86 | else: 87 | dataloader.append(examples[size : size+(length-size)]) 88 | size += (length - size) 89 | return dataloader 90 | 91 | 92 | 93 | def answer_extraction(response, answer_type=None): #use this funtion to extract answers from generated text 94 | 95 | """ 96 | Use this funtion to extract answers from generated text 97 | 98 | Parameters 99 | ------------ 100 | args : 101 | Arguments. 102 | response : str 103 | plain string response. 104 | 105 | 106 | Returns 107 | ------------ 108 | answer: 109 | Decoded answer (such as A, B, C, D, E for mutiple-choice QA). 110 | """ 111 | 112 | # temp = response["generated_text"] 113 | temp = response 114 | if answer_type in ("gsm8k", "svamp", "asdiv", "addsub", "singleeq", "multiarith", "math"): 115 | temp = temp.replace(",", "") 116 | temp = [s for s in re.findall(r'-?\d+\.?\d*', temp)] 117 | elif answer_type in ("aqua", "csqa", "multiple_choice"): 118 | temp = re.findall(r'A|B|C|D|E', temp) 119 | elif answer_type in ("strategyqa", "coin_flip"): 120 | temp = temp.lower() 121 | temp = re.sub("\"|\'|\n|\.|\s|\:|\,"," ", temp) 122 | temp = temp.split(" ") 123 | temp = [i for i in temp if i in ("yes", "no")] 124 | elif answer_type in ("last_letters"): 125 | temp = re.sub("\"|\'|\n|\.|\s","", temp) 126 | temp = [temp] 127 | elif answer_type in ("pubmedqa", "binary_choice"): 128 | # pattern = "Output: (yes|no|maybe)" 129 | # sttr = re.search(pattern, temp) 130 | # answer = sttr.group(0)[8:] if sttr is not None else "N/A" 131 | pattern = "(answer|Answer|ANSWER|output|Output|OUTPUT|A): \(*(yes|Yes|YES|no|No|NO|maybe|Maybe|MAYBE)" 132 | sttr = re.search(pattern, temp) 133 | if sttr is not None: 134 | mid_answer = sttr.group(0) 135 | mid_answer = mid_answer.split(":")[-1].strip() 136 | answer = mid_answer.lower() 137 | else: 138 | pattern = "(yes|Yes|YES|no|No|NO|maybe|Maybe|MAYBE)(\.|\s)" 139 | sttr = re.search(pattern, temp) 140 | if sttr is not None: 141 | answer = sttr.group(0)[:-1].lower() 142 | else: 143 | answer = "N/A" 144 | return answer 145 | elif answer_type == "medmcqa": 146 | # pattern = "Output: (A|B|C|D)." 147 | # sttr = re.search(pattern, temp) 148 | # answer = sttr.group(0)[8:-1].lower() if sttr is not None else "N/A" 149 | pattern = "(answer|Answer|ANSWER|output|Output|OUTPUT|A): \(*(A|B|C|D|a|b|c|d)" 150 | sttr = re.search(pattern, temp) 151 | if sttr is not None: 152 | mid_answer = sttr.group(0) 153 | answer = mid_answer[-1].lower() 154 | else: 155 | pattern = "\(*(A|B|C|D|a|b|c|d)\)*(\.|\s)" 156 | sttr = re.search(pattern, temp) 157 | if sttr is not None: 158 | if '(' in sttr.group(0): 159 | answer = sttr.group(0)[1].lower() 160 | else: 161 | answer = sttr.group(0)[0].lower() 162 | else: 163 | answer = "N/A" 164 | return answer 165 | 166 | elif answer_type == "usmle": 167 | # pattern = "Output: (A|B|C|D)." 
168 | # sttr = re.search(pattern, temp) 169 | # answer = sttr.group(0)[8:-1].lower() if sttr is not None else "N/A" 170 | pattern = "(Answer|Output|A): \(*(A|B|C|D|a|b|c|d)" 171 | sttr = re.search(pattern, temp) 172 | if sttr is not None: 173 | mid_answer = sttr.group(0) 174 | answer = mid_answer[-1].lower() 175 | else: 176 | pattern = "\(*(A|B|C|D|a|b|c|d)\)*(\.|\s)" 177 | sttr = re.search(pattern, temp) 178 | if sttr is not None: 179 | if '(' in sttr.group(0): 180 | answer = sttr.group(0)[1].lower() 181 | else: 182 | answer = sttr.group(0)[0].lower() 183 | else: 184 | answer = "N/A" 185 | return answer 186 | elif answer_type == "text": 187 | return response 188 | else: 189 | raise NotImplementedError(f"Unsupported answer type: {answer_type}") 190 | 191 | if len(temp) != 0: 192 | answer = temp[-1] 193 | # if there is . at the end of answer, remove it 194 | # e.g. answer = 64. 195 | if answer != "": 196 | if answer[-1] == ".": 197 | answer = answer[:-1] 198 | 199 | # round the answer to nearest integer 200 | if answer_type in ("gsm8k", "svamp"): 201 | try: 202 | answer = str(round(float(answer))) 203 | except: 204 | answer = "" # no sol or sol doesn't have valid format 205 | elif answer_type in ("last_letters"): 206 | try: 207 | answer = answer[-args.concat_length:] 208 | except: 209 | answer = "" 210 | else: 211 | answer = "" 212 | return answer 213 | 214 | 215 | def process_image_flag(text, image_flag=""): 216 | texts = text.split(image_flag) 217 | if len(texts) > 1: 218 | image_token_indexes = [len(text) for text in texts[:-1]] 219 | else: 220 | image_token_indexes = [] 221 | # cumsun 222 | image_token_indexes = list(np.cumsum(image_token_indexes)) 223 | texts = "".join(texts) 224 | return texts, image_token_indexes -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Implementation of ['Mitigating the Alignment Tax of RLHF'](https://arxiv.org/abs/2309.06256) 2 | 3 | This is the repository of implementation of Adaptive Model Averaging in ['Mitigating the Alignment Tax of RLHF'](https://arxiv.org/abs/2309.06256). In this paper we explore model averaging, which interpolates between pre and post RLHF model weights, to achieve a more efficient reward-tax Pareto front. Building on the analysis and the observation that averaging different layers of the transformer leads to significantly different reward-tax trade-offs, we propose Adaptive Model Averaging (AMA) to adaptively find various combination ratios of model layers. AMA seeks to maximize the alignment reward while incurring minimal alignment tax. Moreover, we validate AMA’s performance across a range of RLHF algorithms over OpenLLaMA-3B and further extend our findings to Mistral-7B. 4 | We use the LMFlow as the framework which is compatible with other huggingface models. 5 | 6 | 10 | 11 | ## Quick Start 12 | 13 | Before training or inference, you need to install the package as follows. 14 | ### Setup 15 | 16 | ```bash 17 | git clone 18 | cd 19 | conda create -n ama_env python=3.9 -y 20 | conda activate ama_env 21 | conda install mpi4py 22 | ./install.sh 23 | ``` 24 | 25 | 26 | ### Prepare Dataset 27 | Since many experiments of our papers are conducted on HH-RLHF dataset. We have prepared the [HH-RLHF](https://github.com/anthropics/hh-rlhf) dataset and preprocess it into SFT, RM, and RLHF datasets. 
You can download them by running:
28 | 
29 | ```bash
30 | cd data
31 | ./download.sh hh_rlhf
32 | ```
33 | 
34 | Each sample of the dataset consists of a prompt (the chat history) and two responses, where one is preferred over the other. An example from the dataset:
35 | 
36 | **###Human**: How can I get my girlfriend to cook more?
37 | **###Assistant**: Have you tried reminding her of how nice the food tastes?
38 | **###Human**: I could do it more.
39 | **###Assistant**: 
40 | 
41 | **Responses**:
42 | - Preferred $y_w$: That might work. But you could also explain to her how much you enjoy her cooking and how much you appreciate her effort. Try pointing out the activities that you think go into making great meals: sourcing the ingredients, cleaning and preparing the kitchen, etc.
43 | - Rejected $y_l$: Have you considered making an effort to create more harmonious interactions?
44 | 
45 | We use the preferred responses of the HH-RLHF training set as the SFT dataset, the complete comparison data (training + test) for reward modeling, and the prompts of the training set for RLHF.
46 | 
47 | ### Model Averaging
48 | We first provide a script for general model averaging between two models.
49 | 
50 | ```bash
51 | bash scripts/postprocess/weight_interpolation.sh
52 | ```
53 | To make it work, you need to change the parameters below in the script 'scripts/postprocess/weight_interpolation.sh':
54 | 
55 | ```bash
56 | alpha=0.5
57 | model_path0=path/to/model/before/rlhf # for example, openlm-research/open_llama_3b
58 | model_path1=path/to/model/after/rlhf
59 | ...
60 | weight_ensamble_save_path=path/to/save/ma_${alpha}_tag0_tag1
61 | ```
62 | Here, tag0 and tag1 are labels used to identify model0 and model1.
63 | If model0 has weights $\theta_0$ and model1 has weights $\theta_1$, the saved model has weights $\alpha * \theta_0 + (1 - \alpha) * \theta_1$; i.e., $\alpha=0$ recovers model1 and $\alpha=1$ recovers model0.
64 | 
65 | ### Partwise Model Averaging
66 | To use partwise model averaging and reproduce the results in Section 4, we can still use the script 'scripts/postprocess/weight_interpolation.sh', but change the name of the weight_ensamble_save_path like this:
67 | ```bash
68 | 
69 | weight_ensamble_save_path=path/to/save/pma_${alpha}_${tag0}_split|10#20|0.4|0.3|0.2_${tag1}
70 | ```
71 | Here alpha only specifies the weight of the lm_head layer, not of the other transformer layers. tag0 and tag1 still identify model0 and model1. 'split' names the merge method, so just keep it. '10#20' means the transformer layers are split into three blocks: layers 0-10 (including layer 10) form the first block, layers 11-20 (including layer 20) the second, and layer 21 up to the final layer the third. '0.4|0.3|0.2' gives the alpha weights of these three blocks. You can extend this three-block setting to an arbitrary number of blocks; just make sure that (the number of alpha weights) = (the number of layer-index pivots) + 1.
72 | 
73 | Reminder: since we parse the save name to recover this information, make sure there is no '|', '#', or '_' inside your tag0 and tag1.
74 | 
75 | ### Adaptive Model Averaging
76 | Adaptive model averaging consists of two steps: 1) an optimization step that learns the alpha weights, and 2) an averaging step based on those weights.
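For intuition, the averaging in step 2) (like the plain and partwise model averaging above) boils down to a per-parameter interpolation between the two checkpoints, with one alpha per block of transformer layers. The snippet below is only a minimal sketch under assumptions, not the repository's implementation (which lives in scripts/postprocess/weight_interpolation.py and weight_mask_merge.py): the paths, the alpha_for helper, the LLaMA-style 'layers.<idx>.' parameter naming, and the single default_alpha for non-block parameters are all placeholders introduced for this example.

```python
import re

import torch
from transformers import AutoModelForCausalLM

# All paths and alpha values below are placeholders -- adjust them to your own setup.
path_before = "path/to/model/before/rlhf"   # theta_0, e.g. the SFT model
path_after = "path/to/model/after/rlhf"     # theta_1, the model after RLHF
pivots = [10, 20]                           # layer-index pivots, as in 'split|10#20|...'
block_alphas = [0.4, 0.3, 0.2]              # one alpha per block: len(block_alphas) == len(pivots) + 1
default_alpha = 0.5                         # alpha for parameters outside the blocks (e.g. lm_head)


def alpha_for(param_name: str) -> float:
    """Return the alpha of the block a parameter belongs to (assumes 'layers.<idx>.' naming)."""
    match = re.search(r"layers\.(\d+)\.", param_name)
    if match is None:
        return default_alpha
    layer_idx = int(match.group(1))
    for block, pivot in enumerate(pivots):
        if layer_idx <= pivot:  # pivots are inclusive, matching the description above
            return block_alphas[block]
    return block_alphas[-1]


model0 = AutoModelForCausalLM.from_pretrained(path_before)
model1 = AutoModelForCausalLM.from_pretrained(path_after)
state0, state1 = model0.state_dict(), model1.state_dict()

# Saved weights: alpha * theta_0 + (1 - alpha) * theta_1, with a per-block alpha.
with torch.no_grad():
    merged = {
        name: alpha_for(name) * state0[name] + (1.0 - alpha_for(name)) * state1[name]
        if state0[name].is_floating_point() else state1[name]  # copy non-float buffers unchanged
        for name in state0
    }

model1.load_state_dict(merged)
model1.save_pretrained("path/to/save/merged_model")
```

With a single global alpha this reduces to plain model averaging; AMA keeps the same averaging step but replaces the hand-picked block alphas with the values learned in step 1) and stored in mask_alphas.bin.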
77 | 
78 | #### Optimization
79 | ```bash
80 | bash scripts/mask_post/run_mask_finetune_raft.sh
81 | ```
82 | 
83 | The hyper-parameters of this optimization process can be found in the script:
84 | 
85 | ```bash
86 | approach=mask_norm_sigmoid_linear # the averaging method; keep this value
87 | mask_level=block3 # split the transformer layers into 3 blocks; the layer-index pivots are computed automatically
88 | lr=2e-5 # learning rate of the optimization process
89 | init_val=0.2 # initial (base) alpha weight
90 | reg_alpha=1e-4 # coefficient of the regularization penalty
91 | sum_reg_type=0.2 # only 0 and 0.2 are supported: 0 is a plain L1 penalty, 0.2 a weighted L1 penalty
92 | epsilon=0.2 # epsilon added to the normalization term; controls the overall variation
93 | ```
94 | There are also several path variables that you need to point to your own paths.
95 | 
96 | ```bash
97 | model_path=${project_dir}/path/to/after/rlhf
98 | dataset_path=${project_dir}/path/to/data/collected
99 | ```
100 | 
101 | #### Averaging
102 | ```bash
103 | bash scripts/postprocess/weight_mask_merge.sh
104 | ```
105 | 
106 | ```bash
107 | model_path0=${project_dir}/path/to/before/rlhf
108 | model_path1=${project_dir}/path/to/after/rlhf
109 | 
110 | alphas_path=${project_dir}/path/to/mask_alphas.bin
111 | weight_ensamble_save_path=${project_dir}/path/to/save
112 | ```
113 | This step is almost identical to plain model averaging; you only need to set the model paths together with the learned mask_alphas.bin.
114 | 
115 | ### Evaluations
116 | Below we describe how to run the evaluation scripts used in our experiments. All scripts require model paths, which we do not specify here.
117 | 
118 | #### Common Sense
119 | We invoke the [lm-evaluation-harness](https://github.com/EleutherAI/lm-evaluation-harness) for evaluation, so you need to download that repository first.
120 | 
121 | ```bash
122 | bash scripts/eval_cs_qa/run_evaluation.sh
123 | ```
124 | 
125 | #### DROP/SQuAD/WMT14
126 | We invoke [OpenCompass](https://github.com/open-compass/opencompass) and the [lmflow_benchmark](https://github.com/shizhediao/forgetting-bench).
127 | ```bash
128 | bash scripts/eval_drop_squad_wmt/run-evaluation-drop.sh
129 | bash scripts/eval_drop_squad_wmt/run-evaluation-squad.sh
130 | bash scripts/eval_drop_squad_wmt/run-evaluation-wmt.sh
131 | ```
132 | 
133 | #### PairRM Value
134 | 
135 | ```bash
136 | bash scripts/eval_pairrm/run_pairrm.sh
137 | ```
138 | 
139 | #### Reward Value
140 | 
141 | ```bash
142 | bash scripts/eval_raft/run_eval_raft_align.sh
143 | ```
144 | 
145 | 
146 | ## Support
147 | 
148 | If you need any help, please submit a GitHub issue.
149 | ## License
150 | The code included in this project is licensed under the [Apache 2.0 license](https://github.com/OptimalScale/LMFlow/blob/main/LICENSE).
151 | If you wish to use the code and models included in this project for commercial purposes, please sign this [document](https://docs.google.com/forms/d/e/1FAIpQLSfJYcci6cbgpIvx_Fh1xDL6pNkzsjGDH1QIcm4cYk88K2tqkw/viewform?usp=pp_url) to obtain authorization.
152 | 153 | ## Citation 154 | If you find this repository useful, please consider giving ⭐ and citing our [paper](https://arxiv.org/abs/2309.06256): 155 | 156 | ``` 157 | @article{lin2024mitigating, 158 | title={Mitigating the Alignment Tax of RLHF}, 159 | author={Lin, Yong and Lin, Hangyu and Xiong, Wei and Diao, Shizhe and Liu, Jianmeng and Zhang, Jipeng and Pan, Rui and Wang, Haoxiang and Hu, Wenbin and Zhang, Hanning and Dong, Hanze and Pi, Renjie and Zhao, Han and Jiang, Nan and Ji, Heng and Yao, Yuan and Zhang, Tong}, 160 | journal={arXiv preprint arXiv:2309.06256}, 161 | year={2023} 162 | } 163 | ``` -------------------------------------------------------------------------------- /scripts/eval_pairrm/text_generation.py: -------------------------------------------------------------------------------- 1 | # Install transformers from source - only needed for versions <= v4.34 2 | # pip install git+https://github.com/huggingface/transformers.git 3 | # pip install accelerate 4 | from collections.abc import Iterable 5 | import os 6 | import torch 7 | from transformers import pipeline 8 | import argparse 9 | import json 10 | from tqdm import tqdm 11 | from transformers import AutoModelForCausalLM, AutoTokenizer 12 | 13 | project_dir="/home/jianmeng/linhangyu/Projects/LLMs/LMFlow_RAFT_Dev" 14 | parser = argparse.ArgumentParser() 15 | parser.add_argument("--model_path", type=str, default="zephr-7b-beta") 16 | parser.add_argument("--eval_data_type", type=str, default="alpaca") 17 | parser.add_argument("--sys_prompt_type", type=int, default=1) 18 | parser.add_argument("--gpu", type=int, default=-1, help="gpu id") 19 | parser.add_argument("--save_path", default="alpaca_eval2_pairrm.csv", type=str, help="save path") 20 | args = parser.parse_args() 21 | 22 | device = "cuda:0" 23 | model = AutoModelForCausalLM.from_pretrained(args.model_path, torch_dtype=torch.float16) 24 | tokenizer = AutoTokenizer.from_pretrained(args.model_path) 25 | model.to(device) 26 | sys_prompt_dict = { 27 | 0:"You are a friendly chatbot who always responds in the style of a pirate", 28 | 1:"You are a helpful, honest and respectful chatbot.", 29 | 2:"You are a friendly chatbot who always responds helpfully, honestly and respectfully.", 30 | } 31 | 32 | class HHRLHFData: 33 | def __init__(self, datasets, batch_size=50, sys_prompt_type=1) -> None: 34 | self.datasets = datasets 35 | self.batch_size = batch_size 36 | batch_num = len(self.datasets) // self.batch_size 37 | if len(self.datasets) % self.batch_size != 0: 38 | batch_num += 1 39 | self.batch_num = batch_num 40 | self.sys_prompt = sys_prompt_dict[sys_prompt_type] 41 | 42 | def __len__(self): 43 | return self.batch_num 44 | 45 | def get_data(self, index): 46 | tmp_prompts = [] 47 | tmp_inputs = [] 48 | # print(index, self.batch_size) 49 | start_idx, end_idx = index * self.batch_size, (index + 1) * self.batch_size 50 | for sample in self.datasets[start_idx:end_idx]: 51 | input_content = sample["text"].replace("###Human: ", "").replace("###Assistant:", "") 52 | messages = [ 53 | { 54 | "role": "system", 55 | "content": self.sys_prompt, 56 | }, 57 | {"role": "user", "content": input_content}, 58 | ] 59 | prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) 60 | tmp_prompts.append(prompt), tmp_inputs.append(input_content) 61 | return tmp_prompts, tmp_inputs 62 | 63 | class AlpacaData: 64 | def __init__(self, datasets, batch_size=50, sys_prompt_type=0) -> None: 65 | self.datasets = datasets 66 | self.batch_size = batch_size 67 | 
batch_num = len(self.datasets) // self.batch_size 68 | if len(self.datasets) % self.batch_size != 0: 69 | batch_num += 1 70 | self.batch_num = batch_num 71 | self.sys_prompt = sys_prompt_dict[sys_prompt_type] 72 | 73 | def __len__(self): 74 | return self.batch_num 75 | 76 | def get_data(self, index): 77 | tmp_prompts = [] 78 | tmp_inputs = [] 79 | # print(index, self.batch_size) 80 | start_idx, end_idx = index * self.batch_size, (index + 1) * self.batch_size 81 | for sample in self.datasets[start_idx:end_idx]: 82 | input_content = sample["instruction"] 83 | messages = [ 84 | { 85 | "role": "system", 86 | # "content": "You are a friendly chatbot who always responds in the style of a pirate", 87 | "content": self.sys_prompt, 88 | }, 89 | {"role": "user", "content": input_content}, 90 | ] 91 | prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) 92 | tmp_prompts.append(prompt), tmp_inputs.append(input_content) 93 | return tmp_prompts, tmp_inputs 94 | 95 | # alpaca_params = {'temperature':0.7, 'top_p':1.0, 'max_new_tokens':300} 96 | # hhrlhf_params = {'temperature':0.7, 'top_k':50, 'top_p':0.95, 'max_new_tokens':256} 97 | sys_prompt_type = args.sys_prompt_type 98 | eval_data_type = args.eval_data_type 99 | if eval_data_type == "hh_rlhf": 100 | eval_datasets = json.load(open(f"{project_dir}/data/hh_rlhf/rlhf/rlhf_eval/eval_prompt_first_half.json", "r")) 101 | # Only Split 2000 102 | eval_datasets = eval_datasets["instances"][:2000] 103 | eval_datasets = HHRLHFData(eval_datasets, batch_size=40, sys_prompt_type=sys_prompt_type) 104 | sample_params = {'temperature':0.7, 'top_k':50, 'top_p':0.95, 'max_new_tokens':256} 105 | elif eval_data_type == "alpaca": 106 | eval_datasets = json.load(open(f"{project_dir}/data/alpaca/alpaca_eval_gpt4_baseline.json", "r")) 107 | eval_datasets = AlpacaData(eval_datasets, batch_size=40, sys_prompt_type=sys_prompt_type) 108 | sample_params = {'temperature':0.7, 'top_p':1.0, 'max_new_tokens':300} 109 | # print(sum([len(eval_datasets.get_data(i)) for i in range(len(eval_datasets))])) 110 | 111 | 112 | output_datasets = [] 113 | for i in tqdm(range(len(eval_datasets))): 114 | sample, inputs = eval_datasets.get_data(i) 115 | model_inputs = tokenizer(sample, return_tensors="pt", padding=True).to(device) 116 | model_inputs = model_inputs.to(device) 117 | # model.to(device) 118 | generated_ids = model.generate(**model_inputs, do_sample=True, **sample_params) 119 | 120 | generated_texts = tokenizer.batch_decode(generated_ids, skip_special_tokens=True) 121 | 122 | output_texts = [generated_text.split("\n<|assistant|>\n")[-1] for generated_text in generated_texts] 123 | 124 | for k in range(len(inputs)): 125 | output_datasets.append( 126 | { 127 | "dataset":eval_data_type, 128 | "instruction": inputs[k], 129 | "output":output_texts[k], 130 | "generator":args.model_path 131 | } 132 | ) 133 | # print(generated_texts[-1], output_datasets[-1]) 134 | json.dump(output_datasets, open(os.path.join(args.save_path, f"pairrm_v{sys_prompt_type}_{eval_data_type}.json"), "w")) 135 | 136 | 137 | # # pipe = pipeline("text-generation", model="HuggingFaceH4/mistral-7b-sft-beta", torch_dtype=torch.bfloat16, device_map=args.gpu) 138 | # pipe = pipeline("text-generation", model=args.model_path, torch_dtype=torch.bfloat16, device_map=args.gpu, use_cache=True) 139 | # # outputs = pipe(prompt, max_new_tokens=256, do_sample=True, temperature=0.7, top_k=50, top_p=0.95) 140 | # # if s_i % 10 == 0: 141 | # # print(f'Finish {s_i}/{len(eval_datasets)}.') 142 | # # # 
print(outputs[0]["generated_text"]) 143 | # # output_content = { 144 | # # "dataset":"hh_rlhf", 145 | # # "instruction": input_content, 146 | # # "output":outputs[0]["generated_text"], 147 | # # "generator":args.model_path 148 | # # } 149 | # # output_datasets.append(output_content) 150 | 151 | 152 | 153 | # def input_data(): 154 | # for s_i, sample in enumerate(eval_datasets): 155 | # input_content = sample["text"].replace("###Human: ", "").replace("###Assistant:", "") 156 | # yield input_content 157 | 158 | 159 | # # We use the tokenizer's chat template to format each message - see https://huggingface.co/docs/transformers/main/en/chat_templating 160 | # output_datasets = [] 161 | # ################################################################################################### 162 | # # for s_i, sample in tqdm(enumerate(eval_datasets)): 163 | # # input_content = sample["text"].replace("###Human: ", "").replace("###Assistant", "") 164 | # # messages = [ 165 | # # { 166 | # # "role": "system", 167 | # # "content": "You are a friendly chatbot who always responds in the style of a pirate", 168 | # # }, 169 | # # {"role": "user", "content": input_content}, 170 | # # ] 171 | # # prompt = pipe.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) 172 | 173 | # # outputs = pipe(prompt, max_new_tokens=256, do_sample=True, temperature=0.7, top_k=50, top_p=0.95) 174 | # # if s_i % 10 == 0: 175 | # # print(f'Finish {s_i}/{len(eval_datasets)}.') 176 | # # # print(outputs[0]["generated_text"]) 177 | # # output_content = { 178 | # # "dataset":"hh_rlhf", 179 | # # "instruction": input_content, 180 | # # "output":outputs[0]["generated_text"], 181 | # # "generator":args.model_path 182 | # # } 183 | # # output_datasets.append(output_content) 184 | # ################################################################################################### 185 | # for (input_content, outputs) in tqdm(zip(input_data(), 186 | # pipe(data(), max_new_tokens=256, do_sample=True, temperature=0.7, top_k=50, top_p=0.95))): 187 | 188 | # # generated_content = outputs[0]["generated_text"] 189 | 190 | # output_content = { 191 | # "dataset":"hh_rlhf", 192 | # "instruction": input_content, 193 | # "output":outputs[0]["generated_text"].split("\n<|assistant|>\n")[-1], 194 | # "generator":args.model_path 195 | # } 196 | # # print(output_content) 197 | # output_datasets.append(output_content) 198 | 199 | 200 | # # <|system|> 201 | # # You are a friendly chatbot who always responds in the style of a pirate. 202 | # # <|user|> 203 | # # How many helicopters can a human eat in one sitting? 204 | # # <|assistant|> 205 | # # Ah, me hearty matey! But yer question be a puzzler! A human cannot eat a helicopter in one sitting, as helicopters are not edible. They be made of metal, plastic, and other materials, not food! 
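# ------------------------------------------------------------------------------------------------
# Illustrative usage of this script (not part of the original repository; the paths below are
# assumptions). Note that project_dir is hard-coded near the top of the file and must point to
# your own checkout, and args.save_path must be an existing directory, since it is joined with
# the output file name:
#
#   python scripts/eval_pairrm/text_generation.py \
#       --model_path path/to/merged_model \
#       --eval_data_type alpaca \
#       --sys_prompt_type 1 \
#       --save_path path/to/output_dir
#
# The resulting pairrm_v<sys_prompt_type>_<eval_data_type>.json is then presumably scored via
# scripts/eval_pairrm/run_eval_rate_pairrm.py (see run_pairrm.sh).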
206 | -------------------------------------------------------------------------------- /src/lmflow/models/vision2seq_model.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | # TODO update the doc 4 | 5 | import copy 6 | import logging 7 | import time 8 | import torch 9 | import torch.nn as nn 10 | from typing import List, Optional, Union 11 | 12 | from transformers import ( 13 | Blip2ForConditionalGeneration, 14 | Blip2Config, 15 | AutoModelForCausalLM 16 | ) 17 | 18 | from .base_model import BaseModel 19 | 20 | class CustomAutoVision2SeqModel(Blip2ForConditionalGeneration, BaseModel): 21 | def __init__(self, config: Blip2Config): 22 | Blip2ForConditionalGeneration.__init__(self, config) 23 | self.with_prompt_cache = False 24 | self.cache_dict = {} 25 | 26 | def vision_model_from_pretrained(self, pretrained_path): 27 | self.vision_model = self.vision_model.from_pretrained( 28 | pretrained_path, 29 | config=self.config.vision_config) 30 | def qformer_from_pretrained(self, pretrained_path): 31 | self.qformer = self.qformer.from_pretrained( 32 | pretrained_path, 33 | config=self.config.qformer_config) 34 | # print(self.qformer.encoder.layer[11].output_query.dense.weight.mean()) 35 | 36 | def language_model_from_pretrained(self, 37 | pretrained_path, 38 | low_resource=False, 39 | use_prompt_cache=False): 40 | # TODO remove the low resource related loading in the future 41 | self.use_prompt_cache = use_prompt_cache 42 | if low_resource: 43 | kwargs = dict( 44 | torch_dtype=torch.float16, 45 | load_in_8bit=True, 46 | device_map="auto" 47 | ) 48 | else: 49 | kwargs = {} 50 | past_model_dim = self.language_model.model_dim 51 | self.language_model = AutoModelForCausalLM.from_pretrained( 52 | pretrained_path, 53 | config=self.config.text_config, 54 | **kwargs) 55 | if self.config.text_config.hidden_size != past_model_dim: 56 | # should update the language projection layer 57 | in_channels = self.language_projection.in_features 58 | self.language_projection = nn.Linear(in_channels, 59 | self.config.text_config.hidden_size, 60 | bias=True) 61 | 62 | def register_prompt_cache(self, prompt_ids, prompt_keys_values): 63 | """ 64 | Udpate the prompt id and embedding for reuse in the future 65 | 66 | Args: 67 | prompt_ids (torch.LongTensor): The id of the prompt. 68 | prompt_keys_values (torch.FloatTensor): The embedding of the prompt. 69 | 70 | Returns: 71 | None 72 | """ 73 | self.prompt_ids = prompt_ids 74 | self.prompt_keys_values = prompt_keys_values 75 | self.with_prompt_cache = True 76 | 77 | def save_prompt_cache(self, path): 78 | """ 79 | Save prompt embedding and id. 80 | 81 | Args: 82 | path: The path to save the prompt embedding and id. 83 | 84 | Returns: 85 | None 86 | """ 87 | 88 | torch.save( 89 | dict( 90 | prompt_ids=self.prompt_ids, 91 | prompt_keys_values=self.prompt_keys_values 92 | ), 93 | path) 94 | 95 | def load_prompt_cache(self, path): 96 | """ 97 | Load prompt embedding and id. 98 | Args: 99 | path: The path to load the prompt embedding and id. 
100 | 101 | Returns: 102 | None 103 | """ 104 | prompt_cache = torch.load(path) 105 | self.register_prompt_cache(prompt_cache["prompt_ids"], 106 | prompt_cache["prompt_keys_values"]) 107 | 108 | 109 | @torch.no_grad() 110 | def generate( 111 | self, 112 | pixel_values: torch.FloatTensor, 113 | input_ids: Optional[torch.LongTensor] = None, 114 | attention_mask: Optional[torch.LongTensor] = None, 115 | image_token_indexes: Optional[List] = [0], 116 | one_sample_multiple_images: Optional[bool] = False, 117 | **generate_kwargs, 118 | ) -> torch.LongTensor: 119 | """ 120 | Overrides `generate` function to be able to use the model as a conditional generator. 121 | 122 | Args: 123 | pixel_values (`torch.FloatTensor` of shape (batch_size, num_channels, height, width)): 124 | Input images to be processed. 125 | input_ids (`torch.LongTensor` of shape (batch_size, sequence_length), *optional*): 126 | The sequence used as a prompt for the generation. 127 | attention_mask (`torch.LongTensor` of shape (batch_size, sequence_length), *optional*): 128 | Mask to avoid performing attention on padding token indices 129 | image_token_indexes (bool, *optional*): 130 | The index for inserting the image tokens. 131 | one_sample_multiple_images: (bool, *optional*): 132 | The flag for inference that the input batch size is 1 and contain multiple images. 133 | 134 | Returns: 135 | captions (list): A list of strings of length batch_size * num_captions. 136 | """ 137 | if hasattr(self, "hf_device_map"): 138 | # preprocess for `accelerate` 139 | self._preprocess_accelerate() 140 | if not one_sample_multiple_images: 141 | batch_size = pixel_values.shape[0] 142 | else: 143 | batch_size = 1 144 | 145 | image_embeds = self.vision_model(pixel_values, return_dict=True).last_hidden_state 146 | image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device) 147 | 148 | query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1) 149 | query_outputs = self.qformer( 150 | query_embeds=query_tokens, 151 | encoder_hidden_states=image_embeds, 152 | encoder_attention_mask=image_attention_mask, 153 | return_dict=True, 154 | ) 155 | query_output = query_outputs.last_hidden_state 156 | language_model_inputs = self.language_projection(query_output) 157 | 158 | language_attention_mask = torch.ones( 159 | language_model_inputs.size()[:-1], dtype=torch.long, device=language_model_inputs.device 160 | ) 161 | if input_ids is None: 162 | input_ids = ( 163 | torch.LongTensor([[self.config.text_config.bos_token_id]]) 164 | .repeat(batch_size, 1) 165 | .to(image_embeds.device) 166 | ) 167 | if attention_mask is None: 168 | attention_mask = torch.ones_like(input_ids) 169 | attention_mask = attention_mask.to(language_attention_mask.device) 170 | 171 | # concatenate query embeddings with prompt embeddings 172 | inputs_embeds = self.get_input_embeddings()(input_ids) 173 | inputs_embeds = inputs_embeds.to(language_model_inputs.device) 174 | # concatenate the text embeddings with image embeddings 175 | inputs_embeds_with_images = [] 176 | attention_mask_with_images = [] 177 | # currently we only support with one image 178 | start_index, end_index = 0, 0 179 | assert len(image_token_indexes) == pixel_values.shape[0] 180 | # token format: (# text, # image)xN, # text 181 | 182 | for idx, image_token_index in enumerate(image_token_indexes): 183 | end_index += image_token_index 184 | inputs_embeds_with_images.append( 185 | inputs_embeds[:, start_index:end_index]) 186 | 
inputs_embeds_with_images.append(language_model_inputs[idx][None]) 187 | attention_mask_with_images.append( 188 | attention_mask[:, start_index:end_index]) 189 | attention_mask_with_images.append(language_attention_mask[idx][None]) 190 | start_index = end_index 191 | 192 | inputs_embeds_with_images.append(inputs_embeds[:, image_token_indexes[-1]:]) 193 | inputs_embeds = torch.cat(inputs_embeds_with_images, dim=1) 194 | attention_mask_with_images.append(attention_mask[:, image_token_indexes[-1]:]) 195 | attention_mask = torch.cat(attention_mask_with_images, dim=1) 196 | # comebine the embeds 197 | inputs_embeds = inputs_embeds.to(self.language_model.lm_head.weight.dtype) 198 | attention_mask = attention_mask.to(self.language_model.lm_head.weight.dtype) 199 | 200 | if not self.use_prompt_cache or batch_size != 1: 201 | outputs = self.language_model.generate( 202 | inputs_embeds=inputs_embeds, 203 | attention_mask=attention_mask, 204 | **generate_kwargs, 205 | ) 206 | else: 207 | # current resuse prompt embeddings is not supported when batch size is 1; 208 | past_key_values = None 209 | prompt_length = image_token_indexes[0] 210 | if self.with_prompt_cache is False: 211 | prompt_ids = input_ids[:, :prompt_length] 212 | outputs = self.language_model.generate( 213 | inputs_embeds=inputs_embeds[:, :prompt_length], 214 | attention_mask=attention_mask[:, :prompt_length], 215 | use_cache=self.use_prompt_cache, 216 | **generate_kwargs, 217 | ) 218 | past_key_values = outputs["past_key_values"] 219 | self.register_prompt_cache(prompt_ids, past_key_values) 220 | 221 | prompt_length = self.prompt_id.shape[1] 222 | if torch.all(input_ids[:, :prompt_length] == self.prompt_id): 223 | past_key_values = self.prompt_key_values 224 | else: 225 | past_key_values = None 226 | generate_kwargs["past_key_values"] = past_key_values 227 | 228 | outputs = self.language_model.generate( 229 | inputs_embeds=inputs_embeds[:, prompt_length:], 230 | attention_mask=attention_mask[:, prompt_length:], 231 | use_cache=self.use_prompt_cache, 232 | **generate_kwargs, 233 | ) 234 | outputs = outputs.logits 235 | 236 | return outputs 237 | -------------------------------------------------------------------------------- /src/lmflow/pipeline/inferencer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | """The Inferencer class simplifies the process of model inferencing.""" 4 | 5 | import copy 6 | import os 7 | import torch 8 | import wandb 9 | import deepspeed 10 | import sys 11 | import numpy as np 12 | import datetime 13 | import json 14 | import time 15 | 16 | from transformers import AutoConfig 17 | import torch.distributed as dist 18 | 19 | from lmflow.args import DatasetArguments 20 | from lmflow.datasets.dataset import Dataset 21 | from lmflow.pipeline.base_pipeline import BasePipeline 22 | from lmflow.models.hf_decoder_model import HFDecoderModel 23 | from lmflow.utils.data_utils import (set_random_seed, batchlize, 24 | answer_extraction, process_image_flag) 25 | os.environ["TOKENIZERS_PARALLELISM"] = "false" # To avoid warnings about parallelism in tokenizers 26 | def rstrip_partial_utf8(string): 27 | return string.replace("\ufffd", "") 28 | 29 | supported_dataset_type = [ 30 | "text_only", 31 | "image_text", 32 | ] 33 | 34 | class Inferencer(BasePipeline): 35 | """ 36 | Initializes the `Inferencer` class with given arguments. 37 | 38 | Parameters 39 | ------------ 40 | model_args : ModelArguments object. 
41 | Contains the arguments required to load the model. 42 | 43 | data_args : DatasetArguments object. 44 | Contains the arguments required to load the dataset. 45 | 46 | inferencer_args : InferencerArguments object. 47 | Contains the arguments required to perform inference. 48 | 49 | 50 | """ 51 | def __init__(self, model_args, data_args, inferencer_args): 52 | self.data_args = data_args 53 | self.inferencer_args = inferencer_args 54 | self.model_args = model_args 55 | 56 | set_random_seed(self.inferencer_args.random_seed) 57 | 58 | self.local_rank = int(os.getenv("LOCAL_RANK", "0")) 59 | self.world_size = int(os.getenv("WORLD_SIZE", "1")) 60 | if inferencer_args.device == "gpu": 61 | torch.cuda.set_device(self.local_rank) # NOTE: cpu-only machine will have error 62 | deepspeed.init_distributed() 63 | else: 64 | os.environ["MASTER_ADDR"] = "localhost" 65 | os.environ["MASTER_PORT"] = "15000" 66 | dist.init_process_group( 67 | "gloo", rank=self.local_rank, world_size=self.world_size 68 | ) 69 | 70 | self.config = AutoConfig.from_pretrained(model_args.model_name_or_path, trust_remote_code=True) 71 | try: 72 | self.model_hidden_size = self.config.hidden_size 73 | except: 74 | print("Error in setting hidden size, use the default size 1024") 75 | self.model_hidden_size = 1024 # gpt2 seems do not have hidden_size in config 76 | 77 | 78 | def create_dataloader(self, dataset: Dataset): 79 | r"""Batchlize dataset and format it to dataloader. 80 | 81 | Args: 82 | dataset (Dataset): the dataset object 83 | 84 | Output: 85 | dataloader (batchlize): the dataloader object 86 | dataset_size (int): the length of the dataset 87 | 88 | """ 89 | if dataset.get_type() == "text_only": 90 | data_dict = dataset.to_dict() 91 | inputs = [instance["text"] for instance in data_dict["instances"] ] 92 | elif dataset.get_type() == "image_text": 93 | inputs = dataset.to_list() 94 | 95 | dataset_size = len(inputs) 96 | dataset_buf = [] 97 | for idx in range(dataset_size): 98 | dataset_buf.append({ 99 | "input": inputs[idx], 100 | "input_idx": idx 101 | }) 102 | 103 | dataloader = batchlize( 104 | dataset_buf, 105 | batch_size=1, 106 | random_shuffle=False, 107 | ) 108 | return dataloader, dataset_size 109 | 110 | 111 | def inference( 112 | self, 113 | model, 114 | dataset: Dataset, 115 | max_new_tokens: int=100, 116 | temperature: float=0.0, 117 | prompt_structure: str='{input}', 118 | remove_image_flag: bool=False, 119 | ): 120 | """ 121 | Perform inference for a model 122 | 123 | Parameters 124 | ------------ 125 | model : TunableModel object. 126 | TunableModel to perform inference 127 | 128 | dataset : Dataset object. 129 | 130 | 131 | Returns: 132 | 133 | output_dataset: Dataset object. 
134 | """ 135 | if dataset.get_type() not in supported_dataset_type: 136 | raise NotImplementedError( 137 | 'input dataset should have type {}'.format( 138 | supported_dataset_type)) 139 | dataloader, data_size = self.create_dataloader(dataset) 140 | 141 | # The output dataset 142 | output_dict = { 143 | "type": "text_only", 144 | "instances": [ 145 | ] 146 | } 147 | 148 | for batch_index, batch in enumerate(dataloader): 149 | current_batch = batch[0] # batch size is 1 150 | if isinstance(current_batch['input'], str): 151 | input = prompt_structure.format(input=current_batch['input']) 152 | else: 153 | input = current_batch['input'] 154 | input['text'] = prompt_structure.format(input=input['text']) 155 | 156 | if 'images' in input and isinstance(input['images'], list): 157 | input['images'] = np.array(input['images']) 158 | 159 | if remove_image_flag: 160 | # remove the image flag in tokenization; 161 | input['text'] = input['text'].split("") 162 | # TODO remove this code by update the tokenizer 163 | input_ids = [] 164 | attention_mask = [] 165 | pixel_values = [] 166 | image_token_indexes = [] 167 | temp_input = copy.deepcopy(input) 168 | for idx in range(len(input['text'])): 169 | temp_input['text'] = input['text'][idx] 170 | temp_inputs = model.encode( 171 | temp_input, 172 | return_tensors="pt", 173 | add_special_tokens=idx==0).to(device=self.local_rank) 174 | input_ids.append(temp_inputs['input_ids']) 175 | attention_mask.append(temp_inputs['attention_mask']) 176 | image_token_indexes.append(temp_inputs["input_ids"].shape[1]) 177 | if len(image_token_indexes) > 1: 178 | image_token_indexes = image_token_indexes[:-1] 179 | inputs = temp_inputs 180 | inputs["input_ids"] = torch.cat(input_ids, dim=1) 181 | inputs["attention_mask"] = torch.cat(attention_mask, dim=1) 182 | else: 183 | if self.inferencer_args.device == "gpu": 184 | inputs = model.encode(input, return_tensors="pt").to(device=self.local_rank) 185 | elif self.inferencer_args.device == "cpu": 186 | inputs = model.encode(input, return_tensors="pt").to(device='cpu') 187 | else: 188 | raise NotImplementedError( 189 | f"device \"{self.inferencer_args.device}\" is not supported" 190 | ) 191 | 192 | if remove_image_flag: 193 | inputs["image_token_indexes"] = image_token_indexes 194 | inputs["one_sample_multiple_images"] = True 195 | 196 | outputs = model.inference( 197 | inputs, 198 | max_new_tokens=max_new_tokens, 199 | temperature=self.inferencer_args.temperature, 200 | repetition_penalty=self.inferencer_args.repetition_penalty, 201 | do_sample=self.inferencer_args.do_sample, 202 | ) 203 | 204 | # only return the generation, trucating the input 205 | if self.model_args.arch_type != "vision_encoder_decoder": 206 | text_out = model.decode(outputs[0], skip_special_tokens=True) 207 | prompt_length = len(model.decode(inputs[0], skip_special_tokens=True,)) 208 | text_out = text_out[prompt_length:] 209 | else: 210 | # to avoid redundant/missing leading space problem, we use a 211 | # part of the input text 212 | input_text = inputs['input_ids'][0][-1:] 213 | text_out = model.decode(torch.cat([input_text, outputs[0]]), skip_special_tokens=True) 214 | prompt_length = len(model.decode(input_text, skip_special_tokens=True,)) 215 | text_out = text_out[prompt_length:] 216 | 217 | output_dict["instances"].append({ "text": text_out }) 218 | 219 | output_dataset = Dataset(DatasetArguments(dataset_path = None)) 220 | output_dataset = output_dataset.from_dict(output_dict) 221 | 222 | return output_dataset 223 | 224 | def stream_inference( 225 | 
self, 226 | context, 227 | model, 228 | max_new_tokens, 229 | token_per_step, 230 | temperature, 231 | end_string, 232 | input_dataset, 233 | remove_image_flag: bool=False, 234 | ): 235 | response = "" 236 | history = [] 237 | if "ChatGLMModel" in self.config.architectures: 238 | for response, history in model.get_backend_model().stream_chat(model.get_tokenizer(), context, history=history): 239 | response = rstrip_partial_utf8(response) 240 | yield response, False 241 | else: 242 | for _ in range(0, self.inferencer_args.max_new_tokens // token_per_step): 243 | output_dataset = self.inference( 244 | model=model, 245 | dataset=input_dataset, 246 | max_new_tokens=token_per_step, 247 | temperature=self.inferencer_args.temperature, 248 | remove_image_flag=remove_image_flag, 249 | ) 250 | 251 | new_append_text = output_dataset.to_dict()["instances"][0]["text"] 252 | new_append_text = rstrip_partial_utf8(new_append_text) 253 | response += new_append_text 254 | 255 | input_dict = input_dataset.to_dict() 256 | input_dict["instances"][0]["text"] += new_append_text 257 | input_dataset = input_dataset.from_dict(input_dict) 258 | 259 | flag_break = False 260 | try: 261 | index = response.index(end_string) 262 | flag_break = True 263 | except ValueError: 264 | response += end_string 265 | index = response.index(end_string) 266 | 267 | response = response[:index] 268 | 269 | yield response, flag_break 270 | --------------------------------------------------------------------------------
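To close, here is a minimal sketch of how the streaming entry point above is typically consumed. It is illustrative only: the construction of `model`, `input_dataset`, and the `*_args` objects is elided, and the concrete values (`token_per_step=4`, `end_string="###"`, ...) are assumptions for the example rather than recommended settings.

```python
# Illustrative consumption of Inferencer.stream_inference (model/dataset/args setup elided).
inferencer = Inferencer(model_args, data_args, inferencer_args)

for partial_response, finished in inferencer.stream_inference(
    context=chat_history,                    # full conversation so far (used by the ChatGLM branch)
    model=model,                             # an LMFlow HFDecoderModel
    max_new_tokens=inferencer_args.max_new_tokens,
    token_per_step=4,                        # tokens produced per streamed chunk
    temperature=inferencer_args.temperature,
    end_string="###",                        # marker that terminates the streamed response
    input_dataset=input_dataset,             # a text_only Dataset holding the current prompt
):
    print(partial_response)                  # each yield is the full response accumulated so far
    if finished:                             # True once end_string has been found
        break
```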