├── .gitignore ├── LICENSE ├── README.md ├── examples ├── README.md ├── README_zh.md ├── RealQA │ └── all_fuse_training.yaml ├── accelerate │ └── fsdp_config.yaml ├── ava │ └── h20_ava_training_wo_taskid_w_score_range_lora128_loradrop015_unfreeze_vision_data_decimal2_6e_cot_mix.yaml ├── deepspeed │ ├── ds_z0_config.json │ ├── ds_z2_config.json │ ├── ds_z2_offload_config.json │ ├── ds_z3_config.json │ └── ds_z3_offload_config.json ├── extras │ ├── adam_mini │ │ └── qwen2_full_sft.yaml │ ├── badam │ │ └── llama3_full_sft.yaml │ ├── fsdp_qlora │ │ ├── llama3_lora_sft.yaml │ │ └── train.sh │ ├── galore │ │ └── llama3_full_sft.yaml │ ├── llama_pro │ │ ├── expand.sh │ │ └── llama3_freeze_sft.yaml │ ├── loraplus │ │ └── llama3_lora_sft.yaml │ ├── mod │ │ └── llama3_full_sft.yaml │ ├── nlg_eval │ │ └── llama3_lora_predict.yaml │ └── pissa │ │ ├── init.sh │ │ └── llama3_lora_sft.yaml ├── inference │ ├── llama3.yaml │ ├── llama3_lora_sft.yaml │ ├── llama3_vllm.yaml │ ├── llava1_5.yaml │ └── qwen2_vl.yaml ├── koniq_10k │ ├── koniq10k_training_wo_taskid_w_score_range_lora128_loradrop015_unfreeze_vision_decimal2.yaml │ └── koniq10k_training_wo_taskid_w_score_range_lora128_loradrop015_unfreeze_vision_decimal2_cot.yaml ├── merge_lora │ ├── llama3_gptq.yaml │ ├── llama3_lora_sft.yaml │ ├── qwen2vl_lora_sft.yaml │ └── qwen2vl_lora_sft_h20.yaml ├── train_full │ ├── koniq_qwen2vl_full_sft.yaml │ ├── koniq_qwen2vl_full_sft_2e.yaml │ ├── koniq_qwen2vl_full_sft_from_scratch.yaml │ ├── koniq_qwen2vl_full_sft_from_scratch_2e.yaml │ ├── koniq_qwen2vl_full_sft_from_scratch_ncm1.yaml │ ├── koniq_qwen2vl_full_sft_from_scratch_ncm1_2e.yaml │ ├── koniq_qwen2vl_full_sft_from_scratch_ncm5.yaml │ ├── llama3_full_sft.yaml │ └── qwen2vl_full_sft_without_caption_only_proj.yaml ├── train_lora │ ├── llama3_lora_dpo.yaml │ ├── llama3_lora_eval.yaml │ ├── llama3_lora_kto.yaml │ ├── llama3_lora_ppo.yaml │ ├── llama3_lora_pretrain.yaml │ ├── llama3_lora_reward.yaml │ ├── llama3_lora_sft.yaml │ ├── llama3_lora_sft_ds3.yaml │ ├── llama3_preprocess.yaml │ ├── llava1_5_lora_sft.yaml │ ├── qwen2vl_lora_dpo.yaml │ ├── qwen2vl_lora_sft.yaml │ ├── qwen2vl_lora_sft_v2.yaml │ ├── qwen2vl_lora_sft_v2_l20.yaml │ ├── qwen2vl_lora_sft_v2_with_general_caption.yaml │ ├── qwen2vl_lora_sft_v2_without_caption.yaml │ ├── qwen2vl_lora_sft_v2_without_caption_lora_128.yaml │ ├── qwen2vl_lora_sft_v2_without_caption_lora_128_wr_9502_from_llm_and_4300_lower_4_2.yaml │ ├── qwen2vl_lora_sft_v2_without_caption_lora_256_wr_9502_from_llm_and_4300_lower_4_2.yaml │ ├── qwen2vl_lora_sft_v2_without_caption_lora_32.yaml │ ├── qwen2vl_lora_sft_v2_without_caption_lora_32_wr_9502.yaml │ ├── qwen2vl_lora_sft_v2_without_caption_lora_32_wr_test.yaml │ └── qwen2vl_lora_sft_v2_without_caption_lora_32_x2_res.yaml └── train_qlora │ ├── llama3_lora_sft_aqlm.yaml │ ├── llama3_lora_sft_awq.yaml │ ├── llama3_lora_sft_gptq.yaml │ └── llama3_lora_sft_otfq.yaml ├── figure ├── 01.png ├── dataset.pdf ├── dataset.png.jpg └── init ├── get_my_transformer.py ├── metric └── aesthetic │ └── eval_koniq10k.py ├── modeling_qwen2_vl.py ├── scripts ├── api_example │ ├── test_image.py │ └── test_toolcall.py ├── call_qwen_max.py ├── convert_ckpt │ ├── llamafy_baichuan2.py │ └── llamafy_qwen.py ├── download_kaggle.py ├── llama_pro.py ├── loftq_init.py ├── pissa_init.py ├── stat_utils │ ├── cal_flops.py │ ├── cal_lr.py │ ├── cal_mfu.py │ ├── cal_ppl.py │ └── length_cdf.py └── vllm_infer.py ├── setup.py └── src ├── api.py ├── llamafactory ├── __init__.py ├── api │ ├── __init__.py │ ├── app.py │ ├── 
chat.py │ ├── common.py │ └── protocol.py ├── chat │ ├── __init__.py │ ├── base_engine.py │ ├── chat_model.py │ ├── hf_engine.py │ └── vllm_engine.py ├── cli.py ├── data │ ├── __init__.py │ ├── aligner.py │ ├── collator.py │ ├── data_utils.py │ ├── formatter.py │ ├── loader.py │ ├── mm_plugin.py │ ├── parser.py │ ├── preprocess.py │ ├── processors │ │ ├── __init__.py │ │ ├── feedback.py │ │ ├── pairwise.py │ │ ├── pretrain.py │ │ ├── processor_utils.py │ │ ├── supervised.py │ │ └── unsupervised.py │ ├── template.py │ └── tool_utils.py ├── eval │ ├── __init__.py │ ├── evaluator.py │ └── template.py ├── extras │ ├── __init__.py │ ├── constants.py │ ├── env.py │ ├── logging.py │ ├── misc.py │ ├── packages.py │ └── ploting.py ├── hparams │ ├── __init__.py │ ├── data_args.py │ ├── evaluation_args.py │ ├── finetuning_args.py │ ├── generating_args.py │ ├── model_args.py │ └── parser.py ├── launcher.py ├── model │ ├── __init__.py │ ├── adapter.py │ ├── loader.py │ ├── model_utils │ │ ├── __init__.py │ │ ├── attention.py │ │ ├── checkpointing.py │ │ ├── embedding.py │ │ ├── liger_kernel.py │ │ ├── longlora.py │ │ ├── misc.py │ │ ├── mod.py │ │ ├── moe.py │ │ ├── packing.py │ │ ├── quantization.py │ │ ├── rope.py │ │ ├── unsloth.py │ │ ├── valuehead.py │ │ └── visual.py │ └── patcher.py ├── train │ ├── __init__.py │ ├── callbacks.py │ ├── dpo │ │ ├── __init__.py │ │ ├── trainer.py │ │ └── workflow.py │ ├── kto │ │ ├── __init__.py │ │ ├── trainer.py │ │ └── workflow.py │ ├── ppo │ │ ├── __init__.py │ │ ├── ppo_utils.py │ │ ├── trainer.py │ │ └── workflow.py │ ├── pt │ │ ├── __init__.py │ │ ├── trainer.py │ │ └── workflow.py │ ├── rm │ │ ├── __init__.py │ │ ├── metric.py │ │ ├── trainer.py │ │ └── workflow.py │ ├── sft │ │ ├── __init__.py │ │ ├── metric.py │ │ ├── trainer.py │ │ └── workflow.py │ ├── test_utils.py │ ├── trainer_utils.py │ └── tuner.py └── webui │ ├── __init__.py │ ├── chatter.py │ ├── common.py │ ├── components │ ├── __init__.py │ ├── chatbot.py │ ├── data.py │ ├── eval.py │ ├── export.py │ ├── infer.py │ ├── top.py │ └── train.py │ ├── css.py │ ├── engine.py │ ├── interface.py │ ├── locales.py │ ├── manager.py │ ├── runner.py │ └── utils.py ├── train.py └── webui.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 AMAP-ML 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /examples/RealQA/all_fuse_training.yaml: -------------------------------------------------------------------------------- 1 | ### model 2 | model_name_or_path: /data/7.68T-3/limingxing.lmx/workspace/Qwen2-VL/Qwen2-VL/hugging_face/7B 3 | 4 | ### method 5 | stage: sft 6 | do_train: true 7 | do_eval: false 8 | finetuning_type: lora 9 | lora_target: all # Target modules for LoRA; the default is all. 10 | lora_rank: 128 11 | lora_dropout: 0.05 12 | freeze_vision_tower: False 13 | 14 | ### dataset 15 | dataset: aesthetic_high_and_low_9502_4300_llm_fusion_en_new_high_mid_low_new_simple, aesthetic_low_and_middle_9502_and_4300_llm_en_both_high_en, aesthetic_low_and_middle_9502_and_4300_llm_en_both_mid_en, aesthetic_low_and_middle_9502_and_4300_llm_en_both_low_en, aesthetic_low_and_middle_9502_and_4300_llm_en_composition_score_new_simple, aesthetic_low_and_middle_9502_and_4300_llm_en_eye_score_new_simple, aesthetic_low_and_middle_9502_and_4300_llm_en_subject_and_subject_integrity_new_simple, aesthetic_low_and_middle_9502_and_4300_llm_en_subject_and_subject_clutter_new_simple, aesthetic_low_and_middle_9502_and_4300_llm_en_background_and_background_clutter_new_simple, aesthetic_low_and_middle_9502_and_4300_llm_en_level_shot_new_simple, aesthetic_low_and_middle_9502_and_4300_llm_en_exposure_new_simple, aesthetic_low_and_middle_9502_and_4300_llm_en_clearness_new_simple, aesthetic_low_and_middle_9502_and_4300_llm_en_saturation_new_simple # video: mllm_video_demo 16 | template: qwen2_vl 17 | cutoff_len: 2048 18 | # max_samples: 1000 # For debugging purposes, truncate the number of examples for each dataset 19 | overwrite_cache: true # Overwrite the cached training and evaluation sets. 20 | preprocessing_num_workers: 16 21 | 22 | ### output 23 | output_dir: test/config/14k_score_train_and_test/save/aesthetic_high_and_low_9502_4300_llm_fusion_en_mix_new_simple_epoch2 24 | logging_steps: 10 25 | save_steps: 1000 26 | # save_strategy: 'epoch' 27 | 28 | plot_loss: true 29 | overwrite_output_dir: true 30 | 31 | ### train 32 | per_device_train_batch_size: 2 33 | gradient_accumulation_steps: 8 34 | learning_rate: 1.0e-4 35 | num_train_epochs: 2.0 36 | lr_scheduler_type: cosine 37 | warmup_ratio: 0.1 38 | bf16: true 39 | ddp_timeout: 180000000 40 | 41 | ### eval 42 | val_size: 0.
# Fraction of the dataset randomly sampled as the validation set; set to 0 for now so that all data is used for training. 43 | per_device_eval_batch_size: 1 44 | eval_strategy: 'no' # no/steps 45 | eval_steps: 500 -------------------------------------------------------------------------------- /examples/accelerate/fsdp_config.yaml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | debug: false 3 | distributed_type: FSDP 4 | downcast_bf16: 'no' 5 | fsdp_config: 6 | fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP 7 | fsdp_backward_prefetch: BACKWARD_PRE 8 | fsdp_forward_prefetch: false 9 | fsdp_cpu_ram_efficient_loading: true 10 | fsdp_offload_params: true # offload may affect training speed 11 | fsdp_sharding_strategy: FULL_SHARD 12 | fsdp_state_dict_type: FULL_STATE_DICT 13 | fsdp_sync_module_states: true 14 | fsdp_use_orig_params: true 15 | machine_rank: 0 16 | main_training_function: main 17 | mixed_precision: fp16 # or bf16 18 | num_machines: 1 # the number of nodes 19 | num_processes: 2 # the number of GPUs in all nodes 20 | rdzv_backend: static 21 | same_network: true 22 | tpu_env: [] 23 | tpu_use_cluster: false 24 | tpu_use_sudo: false 25 | use_cpu: false 26 | -------------------------------------------------------------------------------- /examples/ava/h20_ava_training_wo_taskid_w_score_range_lora128_loradrop015_unfreeze_vision_data_decimal2_6e_cot_mix.yaml: -------------------------------------------------------------------------------- 1 | ### model 2 | model_name_or_path: /data/7.68T-3/limingxing.lmx/workspace/Qwen2-VL/Qwen2-VL/hugging_face/7B 3 | ### method 4 | stage: sft 5 | do_train: true 6 | do_eval: false 7 | finetuning_type: lora 8 | lora_target: all # Target modules for LoRA; the default is all. 9 | lora_rank: 128 10 | lora_dropout: 0.15 11 | freeze_vision_tower: False 12 | 13 | ### dataset 14 | dataset: ava_training_decimal2_cot,ava_training_decimal2 15 | template: qwen2_vl 16 | cutoff_len: 2048 17 | # max_samples: 1000 # For debugging purposes, truncate the number of examples for each dataset 18 | overwrite_cache: true # Overwrite the cached training and evaluation sets. 19 | preprocessing_num_workers: 16 20 | 21 | ### output 22 | output_dir: saves/ava/ava_training_wo_taskid_w_score_range_lora128_loradrop015_unfreeze_vision_data_decimal2_6e_cot_mix_v2 23 | logging_steps: 10 24 | # save_steps: 1000 25 | save_strategy: 'epoch' 26 | 27 | plot_loss: true 28 | overwrite_output_dir: true 29 | 30 | ### train 31 | per_device_train_batch_size: 2 32 | gradient_accumulation_steps: 8 33 | learning_rate: 1.0e-4 34 | num_train_epochs: 6.0 35 | lr_scheduler_type: cosine 36 | warmup_ratio: 0.1 37 | bf16: true 38 | ddp_timeout: 180000000 39 | 40 | ### eval 41 | val_size: 0.
# Fraction of the dataset randomly sampled as the validation set; set to 0 for now so that all data is used for training. 42 | per_device_eval_batch_size: 1 43 | eval_strategy: 'no' # no/steps 44 | eval_steps: 500 -------------------------------------------------------------------------------- /examples/deepspeed/ds_z0_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "train_batch_size": "auto", 3 | "train_micro_batch_size_per_gpu": "auto", 4 | "gradient_accumulation_steps": "auto", 5 | "gradient_clipping": "auto", 6 | "zero_allow_untested_optimizer": true, 7 | "fp16": { 8 | "enabled": "auto", 9 | "loss_scale": 0, 10 | "loss_scale_window": 1000, 11 | "initial_scale_power": 16, 12 | "hysteresis": 2, 13 | "min_loss_scale": 1 14 | }, 15 | "bf16": { 16 | "enabled": "auto" 17 | }, 18 | "zero_optimization": { 19 | "stage": 0, 20 | "allgather_partitions": true, 21 | "allgather_bucket_size": 5e8, 22 | "overlap_comm": true, 23 | "reduce_scatter": true, 24 | "reduce_bucket_size": 5e8, 25 | "contiguous_gradients": true, 26 | "round_robin_gradients": true 27 | } 28 | } 29 | -------------------------------------------------------------------------------- /examples/deepspeed/ds_z2_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "train_batch_size": "auto", 3 | "train_micro_batch_size_per_gpu": "auto", 4 | "gradient_accumulation_steps": "auto", 5 | "gradient_clipping": "auto", 6 | "zero_allow_untested_optimizer": true, 7 | "fp16": { 8 | "enabled": "auto", 9 | "loss_scale": 0, 10 | "loss_scale_window": 1000, 11 | "initial_scale_power": 16, 12 | "hysteresis": 2, 13 | "min_loss_scale": 1 14 | }, 15 | "bf16": { 16 | "enabled": "auto" 17 | }, 18 | "zero_optimization": { 19 | "stage": 2, 20 | "allgather_partitions": true, 21 | "allgather_bucket_size": 5e8, 22 | "overlap_comm": true, 23 | "reduce_scatter": true, 24 | "reduce_bucket_size": 5e8, 25 | "contiguous_gradients": true, 26 | "round_robin_gradients": true 27 | } 28 | } 29 | -------------------------------------------------------------------------------- /examples/deepspeed/ds_z2_offload_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "train_batch_size": "auto", 3 | "train_micro_batch_size_per_gpu": "auto", 4 | "gradient_accumulation_steps": "auto", 5 | "gradient_clipping": "auto", 6 | "zero_allow_untested_optimizer": true, 7 | "fp16": { 8 | "enabled": "auto", 9 | "loss_scale": 0, 10 | "loss_scale_window": 1000, 11 | "initial_scale_power": 16, 12 | "hysteresis": 2, 13 | "min_loss_scale": 1 14 | }, 15 | "bf16": { 16 | "enabled": "auto" 17 | }, 18 | "zero_optimization": { 19 | "stage": 2, 20 | "offload_optimizer": { 21 | "device": "cpu", 22 | "pin_memory": true 23 | }, 24 | "allgather_partitions": true, 25 | "allgather_bucket_size": 5e8, 26 | "overlap_comm": true, 27 | "reduce_scatter": true, 28 | "reduce_bucket_size": 5e8, 29 | "contiguous_gradients": true, 30 | "round_robin_gradients": true 31 | } 32 | } 33 | -------------------------------------------------------------------------------- /examples/deepspeed/ds_z3_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "train_batch_size": "auto", 3 | "train_micro_batch_size_per_gpu": "auto", 4 | "gradient_accumulation_steps": "auto", 5 | "gradient_clipping": "auto", 6 | "zero_allow_untested_optimizer": true, 7 | "fp16": { 8 | "enabled": "auto", 9 | "loss_scale": 0, 10 | "loss_scale_window": 1000, 11 | "initial_scale_power": 16, 12 | "hysteresis": 2, 13 | "min_loss_scale": 1 14 | },
15 | "bf16": { 16 | "enabled": "auto" 17 | }, 18 | "zero_optimization": { 19 | "stage": 3, 20 | "overlap_comm": true, 21 | "contiguous_gradients": true, 22 | "sub_group_size": 1e9, 23 | "reduce_bucket_size": "auto", 24 | "stage3_prefetch_bucket_size": "auto", 25 | "stage3_param_persistence_threshold": "auto", 26 | "stage3_max_live_parameters": 1e9, 27 | "stage3_max_reuse_distance": 1e9, 28 | "stage3_gather_16bit_weights_on_model_save": true 29 | } 30 | } 31 | -------------------------------------------------------------------------------- /examples/deepspeed/ds_z3_offload_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "train_batch_size": "auto", 3 | "train_micro_batch_size_per_gpu": "auto", 4 | "gradient_accumulation_steps": "auto", 5 | "gradient_clipping": "auto", 6 | "zero_allow_untested_optimizer": true, 7 | "fp16": { 8 | "enabled": "auto", 9 | "loss_scale": 0, 10 | "loss_scale_window": 1000, 11 | "initial_scale_power": 16, 12 | "hysteresis": 2, 13 | "min_loss_scale": 1 14 | }, 15 | "bf16": { 16 | "enabled": "auto" 17 | }, 18 | "zero_optimization": { 19 | "stage": 3, 20 | "offload_optimizer": { 21 | "device": "cpu", 22 | "pin_memory": true 23 | }, 24 | "offload_param": { 25 | "device": "cpu", 26 | "pin_memory": true 27 | }, 28 | "overlap_comm": true, 29 | "contiguous_gradients": true, 30 | "sub_group_size": 1e9, 31 | "reduce_bucket_size": "auto", 32 | "stage3_prefetch_bucket_size": "auto", 33 | "stage3_param_persistence_threshold": "auto", 34 | "stage3_max_live_parameters": 1e9, 35 | "stage3_max_reuse_distance": 1e9, 36 | "stage3_gather_16bit_weights_on_model_save": true 37 | } 38 | } 39 | -------------------------------------------------------------------------------- /examples/extras/adam_mini/qwen2_full_sft.yaml: -------------------------------------------------------------------------------- 1 | ### model 2 | model_name_or_path: Qwen/Qwen2-1.5B-Instruct 3 | 4 | ### method 5 | stage: sft 6 | do_train: true 7 | finetuning_type: full 8 | use_adam_mini: true 9 | 10 | ### dataset 11 | dataset: identity,alpaca_en_demo 12 | template: qwen 13 | cutoff_len: 2048 14 | max_samples: 1000 15 | overwrite_cache: true 16 | preprocessing_num_workers: 16 17 | 18 | ### output 19 | output_dir: saves/qwen2-1_5b/full/sft 20 | logging_steps: 10 21 | save_steps: 500 22 | plot_loss: true 23 | overwrite_output_dir: true 24 | 25 | ### train 26 | per_device_train_batch_size: 1 27 | gradient_accumulation_steps: 8 28 | learning_rate: 1.0e-5 29 | num_train_epochs: 3.0 30 | lr_scheduler_type: cosine 31 | warmup_ratio: 0.1 32 | bf16: true 33 | ddp_timeout: 180000000 34 | 35 | ### eval 36 | val_size: 0.1 37 | per_device_eval_batch_size: 1 38 | eval_strategy: steps 39 | eval_steps: 500 40 | -------------------------------------------------------------------------------- /examples/extras/badam/llama3_full_sft.yaml: -------------------------------------------------------------------------------- 1 | ### model 2 | model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct 3 | 4 | ### method 5 | stage: sft 6 | do_train: true 7 | finetuning_type: full 8 | use_badam: true 9 | badam_mode: layer 10 | badam_switch_mode: ascending 11 | badam_switch_interval: 50 12 | badam_verbose: 2 13 | # deepspeed: examples/deepspeed/ds_z3_config.json 14 | 15 | ### dataset 16 | dataset: identity,alpaca_en_demo 17 | template: llama3 18 | cutoff_len: 2048 19 | max_samples: 1000 20 | overwrite_cache: true 21 | preprocessing_num_workers: 16 22 | 23 | ### output 24 | output_dir: 
saves/llama3-8b/full/sft 25 | logging_steps: 10 26 | save_steps: 500 27 | plot_loss: true 28 | overwrite_output_dir: true 29 | 30 | ### train 31 | per_device_train_batch_size: 1 32 | gradient_accumulation_steps: 8 33 | learning_rate: 1.0e-5 34 | num_train_epochs: 3.0 35 | lr_scheduler_type: cosine 36 | warmup_ratio: 0.1 37 | 38 | ### eval 39 | val_size: 0.1 40 | per_device_eval_batch_size: 1 41 | eval_strategy: steps 42 | eval_steps: 500 43 | -------------------------------------------------------------------------------- /examples/extras/fsdp_qlora/llama3_lora_sft.yaml: -------------------------------------------------------------------------------- 1 | ### model 2 | model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct 3 | quantization_bit: 4 4 | 5 | ### method 6 | stage: sft 7 | do_train: true 8 | finetuning_type: lora 9 | lora_target: all 10 | 11 | ### dataset 12 | dataset: identity,alpaca_en_demo 13 | template: llama3 14 | cutoff_len: 2048 15 | max_samples: 1000 16 | overwrite_cache: true 17 | preprocessing_num_workers: 16 18 | 19 | ### output 20 | output_dir: saves/llama3-8b/lora/sft 21 | logging_steps: 10 22 | save_steps: 500 23 | plot_loss: true 24 | overwrite_output_dir: true 25 | 26 | ### train 27 | per_device_train_batch_size: 1 28 | gradient_accumulation_steps: 8 29 | learning_rate: 1.0e-4 30 | num_train_epochs: 3.0 31 | lr_scheduler_type: cosine 32 | warmup_ratio: 0.1 33 | bf16: true 34 | ddp_timeout: 180000000 35 | 36 | ### eval 37 | val_size: 0.1 38 | per_device_eval_batch_size: 1 39 | eval_strategy: steps 40 | eval_steps: 500 41 | -------------------------------------------------------------------------------- /examples/extras/fsdp_qlora/train.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # DO NOT use GPTQ/AWQ model in FSDP+QLoRA 3 | 4 | CUDA_VISIBLE_DEVICES=0,1 accelerate launch \ 5 | --config_file examples/accelerate/fsdp_config.yaml \ 6 | src/train.py examples/extras/fsdp_qlora/llama3_lora_sft.yaml 7 | -------------------------------------------------------------------------------- /examples/extras/galore/llama3_full_sft.yaml: -------------------------------------------------------------------------------- 1 | ### model 2 | model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct 3 | 4 | ### method 5 | stage: sft 6 | do_train: true 7 | finetuning_type: full 8 | use_galore: true 9 | galore_layerwise: true 10 | galore_target: mlp,self_attn 11 | galore_rank: 128 12 | galore_scale: 2.0 13 | 14 | ### dataset 15 | dataset: identity,alpaca_en_demo 16 | template: llama3 17 | cutoff_len: 2048 18 | max_samples: 1000 19 | overwrite_cache: true 20 | preprocessing_num_workers: 16 21 | 22 | ### output 23 | output_dir: saves/llama3-8b/full/sft 24 | logging_steps: 10 25 | save_steps: 500 26 | plot_loss: true 27 | overwrite_output_dir: true 28 | 29 | ### train 30 | per_device_train_batch_size: 1 31 | gradient_accumulation_steps: 1 32 | learning_rate: 1.0e-5 33 | num_train_epochs: 3.0 34 | lr_scheduler_type: cosine 35 | warmup_ratio: 0.1 36 | pure_bf16: true 37 | ddp_timeout: 180000000 38 | 39 | ### eval 40 | val_size: 0.1 41 | per_device_eval_batch_size: 1 42 | eval_strategy: steps 43 | eval_steps: 500 44 | -------------------------------------------------------------------------------- /examples/extras/llama_pro/expand.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | python scripts/llama_pro.py \ 4 | --model_name_or_path meta-llama/Meta-Llama-3-8B-Instruct \ 5 | 
--output_dir models/llama3-8b-pro \ 6 | --num_expand 8 7 | -------------------------------------------------------------------------------- /examples/extras/llama_pro/llama3_freeze_sft.yaml: -------------------------------------------------------------------------------- 1 | ### model 2 | model_name_or_path: models/llama3-8b-pro 3 | 4 | ### method 5 | stage: sft 6 | do_train: true 7 | finetuning_type: freeze 8 | freeze_trainable_layers: 8 9 | freeze_trainable_modules: all 10 | use_llama_pro: true 11 | 12 | ### dataset 13 | dataset: identity,alpaca_en_demo 14 | template: llama3 15 | cutoff_len: 2048 16 | max_samples: 1000 17 | overwrite_cache: true 18 | preprocessing_num_workers: 16 19 | 20 | ### output 21 | output_dir: saves/llama3-8b-pro/freeze/sft 22 | logging_steps: 10 23 | save_steps: 500 24 | plot_loss: true 25 | overwrite_output_dir: true 26 | 27 | ### train 28 | per_device_train_batch_size: 1 29 | gradient_accumulation_steps: 8 30 | learning_rate: 1.0e-4 31 | num_train_epochs: 3.0 32 | lr_scheduler_type: cosine 33 | warmup_ratio: 0.1 34 | bf16: true 35 | ddp_timeout: 180000000 36 | 37 | ### eval 38 | val_size: 0.1 39 | per_device_eval_batch_size: 1 40 | eval_strategy: steps 41 | eval_steps: 500 42 | -------------------------------------------------------------------------------- /examples/extras/loraplus/llama3_lora_sft.yaml: -------------------------------------------------------------------------------- 1 | ### model 2 | model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct 3 | 4 | ### method 5 | stage: sft 6 | do_train: true 7 | finetuning_type: lora 8 | lora_target: all 9 | loraplus_lr_ratio: 16.0 10 | 11 | ### dataset 12 | dataset: identity,alpaca_en_demo 13 | template: llama3 14 | cutoff_len: 2048 15 | max_samples: 1000 16 | overwrite_cache: true 17 | preprocessing_num_workers: 16 18 | 19 | ### output 20 | output_dir: saves/llama3-8b/lora/sft 21 | logging_steps: 10 22 | save_steps: 500 23 | plot_loss: true 24 | overwrite_output_dir: true 25 | 26 | ### train 27 | per_device_train_batch_size: 1 28 | gradient_accumulation_steps: 8 29 | learning_rate: 1.0e-4 30 | num_train_epochs: 3.0 31 | lr_scheduler_type: cosine 32 | warmup_ratio: 0.1 33 | bf16: true 34 | ddp_timeout: 180000000 35 | 36 | ### eval 37 | val_size: 0.1 38 | per_device_eval_batch_size: 1 39 | eval_strategy: steps 40 | eval_steps: 500 41 | -------------------------------------------------------------------------------- /examples/extras/mod/llama3_full_sft.yaml: -------------------------------------------------------------------------------- 1 | ### model 2 | model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct 3 | 4 | ### method 5 | stage: sft 6 | do_train: true 7 | finetuning_type: full 8 | mixture_of_depths: convert 9 | 10 | ### dataset 11 | dataset: identity,alpaca_en_demo 12 | template: llama3 13 | cutoff_len: 2048 14 | max_samples: 1000 15 | overwrite_cache: true 16 | preprocessing_num_workers: 16 17 | 18 | ### output 19 | output_dir: saves/llama3-8b-mod/full/sft 20 | logging_steps: 10 21 | save_steps: 500 22 | plot_loss: true 23 | overwrite_output_dir: true 24 | 25 | ### train 26 | per_device_train_batch_size: 1 27 | gradient_accumulation_steps: 8 28 | optim: paged_adamw_8bit 29 | learning_rate: 1.0e-5 30 | num_train_epochs: 3.0 31 | lr_scheduler_type: cosine 32 | warmup_ratio: 0.1 33 | pure_bf16: true 34 | ddp_timeout: 180000000 35 | 36 | ### eval 37 | val_size: 0.1 38 | per_device_eval_batch_size: 1 39 | eval_strategy: steps 40 | eval_steps: 500 41 | 
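A minimal launch sketch for a recipe file such as the mixture-of-depths config above, assuming the llamafactory-cli entry point installed via setup.py (or the bundled src/train.py used in examples/extras/fsdp_qlora/train.sh); the GPU index is an assumption:
#!/bin/bash
# Sketch: run the mixture-of-depths full SFT config on a single GPU (index assumed).
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/extras/mod/llama3_full_sft.yaml
# Equivalently, the same YAML can be passed to the bundled training entry point:
# CUDA_VISIBLE_DEVICES=0 python src/train.py examples/extras/mod/llama3_full_sft.yaml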
-------------------------------------------------------------------------------- /examples/extras/nlg_eval/llama3_lora_predict.yaml: -------------------------------------------------------------------------------- 1 | # The batch generation can be SLOW using this config. 2 | # For faster inference, we recommend to use `scripts/vllm_infer.py`. 3 | 4 | ### model 5 | model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct 6 | adapter_name_or_path: saves/llama3-8b/lora/sft 7 | 8 | ### method 9 | stage: sft 10 | do_predict: true 11 | finetuning_type: lora 12 | 13 | ### dataset 14 | eval_dataset: identity,alpaca_en_demo 15 | template: llama3 16 | cutoff_len: 2048 17 | max_samples: 50 18 | overwrite_cache: true 19 | preprocessing_num_workers: 16 20 | 21 | ### output 22 | output_dir: saves/llama3-8b/lora/predict 23 | overwrite_output_dir: true 24 | 25 | ### eval 26 | per_device_eval_batch_size: 1 27 | predict_with_generate: true 28 | ddp_timeout: 180000000 29 | -------------------------------------------------------------------------------- /examples/extras/pissa/init.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | python scripts/pissa_init.py \ 4 | --model_name_or_path meta-llama/Meta-Llama-3-8B-Instruct \ 5 | --output_dir models/llama3-8b-pissa 6 | -------------------------------------------------------------------------------- /examples/extras/pissa/llama3_lora_sft.yaml: -------------------------------------------------------------------------------- 1 | ### model 2 | model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct 3 | 4 | ### method 5 | stage: sft 6 | do_train: true 7 | finetuning_type: lora 8 | lora_target: all 9 | pissa_init: true 10 | pissa_iter: 16 11 | pissa_convert: true 12 | 13 | ### dataset 14 | dataset: identity,alpaca_en_demo 15 | template: llama3 16 | cutoff_len: 2048 17 | max_samples: 1000 18 | overwrite_cache: true 19 | preprocessing_num_workers: 16 20 | 21 | ### output 22 | output_dir: saves/llama3-8b/lora/sft 23 | logging_steps: 10 24 | save_steps: 500 25 | plot_loss: true 26 | overwrite_output_dir: true 27 | 28 | ### train 29 | per_device_train_batch_size: 1 30 | gradient_accumulation_steps: 8 31 | learning_rate: 1.0e-4 32 | num_train_epochs: 3.0 33 | lr_scheduler_type: cosine 34 | warmup_ratio: 0.1 35 | bf16: true 36 | ddp_timeout: 180000000 37 | 38 | ### eval 39 | val_size: 0.1 40 | per_device_eval_batch_size: 1 41 | eval_strategy: steps 42 | eval_steps: 500 43 | -------------------------------------------------------------------------------- /examples/inference/llama3.yaml: -------------------------------------------------------------------------------- 1 | model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct 2 | template: llama3 3 | infer_backend: huggingface # choices: [huggingface, vllm] 4 | -------------------------------------------------------------------------------- /examples/inference/llama3_lora_sft.yaml: -------------------------------------------------------------------------------- 1 | model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct 2 | adapter_name_or_path: saves/llama3-8b/lora/sft 3 | template: llama3 4 | finetuning_type: lora 5 | infer_backend: huggingface # choices: [huggingface, vllm] 6 | -------------------------------------------------------------------------------- /examples/inference/llama3_vllm.yaml: -------------------------------------------------------------------------------- 1 | model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct 2 | template: llama3 3 | 
infer_backend: vllm 4 | vllm_enforce_eager: true 5 | -------------------------------------------------------------------------------- /examples/inference/llava1_5.yaml: -------------------------------------------------------------------------------- 1 | model_name_or_path: llava-hf/llava-1.5-7b-hf 2 | template: llava 3 | infer_backend: huggingface # choices: [huggingface, vllm] 4 | -------------------------------------------------------------------------------- /examples/inference/qwen2_vl.yaml: -------------------------------------------------------------------------------- 1 | model_name_or_path: Qwen/Qwen2-VL-7B-Instruct 2 | template: qwen2_vl 3 | infer_backend: huggingface # choices: [huggingface, vllm] 4 | -------------------------------------------------------------------------------- /examples/koniq_10k/koniq10k_training_wo_taskid_w_score_range_lora128_loradrop015_unfreeze_vision_decimal2.yaml: -------------------------------------------------------------------------------- 1 | ### model 2 | model_name_or_path: /data/7.68T-3/limingxing.lmx/workspace/Qwen2-VL/Qwen2-VL/hugging_face/7B 3 | ### method 4 | stage: sft 5 | do_train: true 6 | do_eval: false 7 | finetuning_type: lora 8 | lora_target: all # Target modules for LoRA; the default is all. 9 | lora_rank: 128 10 | lora_dropout: 0.15 11 | freeze_vision_tower: False 12 | 13 | ### dataset 14 | dataset: koniq10k_training_decimal2 15 | template: qwen2_vl 16 | cutoff_len: 2048 17 | # max_samples: 1000 # For debugging purposes, truncate the number of examples for each dataset 18 | overwrite_cache: true # Overwrite the cached training and evaluation sets. 19 | preprocessing_num_workers: 16 20 | 21 | ### output 22 | output_dir: saves/koniq10k/koniq10k_training_wo_taskid_w_score_range_lora128_loradrop015_unfreeze_vision_decimal2 23 | logging_steps: 10 24 | # save_steps: 1000 25 | save_strategy: 'epoch' 26 | 27 | plot_loss: true 28 | overwrite_output_dir: true 29 | 30 | ### train 31 | per_device_train_batch_size: 2 32 | gradient_accumulation_steps: 8 33 | learning_rate: 1.0e-4 34 | num_train_epochs: 2.0 35 | lr_scheduler_type: cosine 36 | warmup_ratio: 0.1 37 | bf16: true 38 | ddp_timeout: 180000000 39 | 40 | ### eval 41 | val_size: 0. # Fraction of the dataset randomly sampled as the validation set; set to 0 for now so that all data is used for training. 42 | per_device_eval_batch_size: 1 43 | eval_strategy: 'no' # no/steps 44 | eval_steps: 500 -------------------------------------------------------------------------------- /examples/koniq_10k/koniq10k_training_wo_taskid_w_score_range_lora128_loradrop015_unfreeze_vision_decimal2_cot.yaml: -------------------------------------------------------------------------------- 1 | ### model 2 | model_name_or_path: /data/7.68T-3/limingxing.lmx/workspace/Qwen2-VL/Qwen2-VL/hugging_face/7B 3 | ### method 4 | stage: sft 5 | do_train: true 6 | do_eval: false 7 | finetuning_type: lora 8 | lora_target: all # Target modules for LoRA; the default is all. 9 | lora_rank: 128 10 | lora_dropout: 0.15 11 | freeze_vision_tower: False 12 | 13 | ### dataset 14 | dataset: koniq10k_training_decimal2,koniq10k_training_decimal2_cot 15 | template: qwen2_vl 16 | cutoff_len: 2048 17 | # max_samples: 1000 # For debugging purposes, truncate the number of examples for each dataset 18 | overwrite_cache: true # Overwrite the cached training and evaluation sets.
19 | preprocessing_num_workers: 16 20 | 21 | ### output 22 | output_dir: saves/koniq10k/koniq10k_training_wo_taskid_w_score_range_lora128_loradrop015_unfreeze_vision_decimal2_cot 23 | logging_steps: 10 24 | # save_steps: 1000 25 | save_strategy: 'epoch' 26 | 27 | plot_loss: true 28 | overwrite_output_dir: true 29 | 30 | ### train 31 | per_device_train_batch_size: 2 32 | gradient_accumulation_steps: 8 33 | learning_rate: 1.0e-4 34 | num_train_epochs: 6.0 35 | lr_scheduler_type: cosine 36 | warmup_ratio: 0.1 37 | bf16: true 38 | ddp_timeout: 180000000 39 | 40 | ### eval 41 | val_size: 0. # Fraction of the dataset randomly sampled as the validation set; set to 0 for now so that all data is used for training. 42 | per_device_eval_batch_size: 1 43 | eval_strategy: 'no' # no/steps 44 | eval_steps: 500 -------------------------------------------------------------------------------- /examples/merge_lora/llama3_gptq.yaml: -------------------------------------------------------------------------------- 1 | ### model 2 | model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct 3 | template: llama3 4 | 5 | ### export 6 | export_dir: models/llama3_gptq 7 | export_quantization_bit: 4 8 | export_quantization_dataset: data/c4_demo.json 9 | export_size: 2 10 | export_device: cpu 11 | export_legacy_format: false 12 | -------------------------------------------------------------------------------- /examples/merge_lora/llama3_lora_sft.yaml: -------------------------------------------------------------------------------- 1 | ### Note: DO NOT use quantized model or quantization_bit when merging lora adapters 2 | 3 | ### model 4 | model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct 5 | adapter_name_or_path: saves/llama3-8b/lora/sft 6 | template: llama3 7 | finetuning_type: lora 8 | 9 | ### export 10 | export_dir: models/llama3_lora_sft 11 | export_size: 2 12 | export_device: cpu 13 | export_legacy_format: false 14 | -------------------------------------------------------------------------------- /examples/merge_lora/qwen2vl_lora_sft.yaml: -------------------------------------------------------------------------------- 1 | ### Note: DO NOT use quantized model or quantization_bit when merging lora adapters 2 | 3 | ### model 4 | model_name_or_path: /data/7.68T-3/limingxing.lmx/workspace/Qwen2-VL/Qwen2-VL/hugging_face/7B 5 | adapter_name_or_path: /data/7.68T-3/limingxing.lmx/workspace/LLaMA-Factory/saves/qwen2_vl-7b/lora/sft_v2_format_w_general_caption/checkpoint-2965 6 | template: qwen2_vl 7 | finetuning_type: lora 8 | 9 | ### export 10 | export_dir: /data/7.68T-3/limingxing.lmx/workspace/Qwen2-VL/Qwen2-VL/hugging_face/qwen2vl-merge-general-caption 11 | export_size: 2 12 | export_device: cpu 13 | export_legacy_format: false 14 | -------------------------------------------------------------------------------- /examples/merge_lora/qwen2vl_lora_sft_h20.yaml: -------------------------------------------------------------------------------- 1 | ### Note: DO NOT use quantized model or quantization_bit when merging lora adapters 2 | 3 | ### model 4 | model_name_or_path: /mnt/xmap_nas_alg/limingxing.lmx/workspace/code/aesthetic/qwen2vl-train/hugging_face/qwen2vl7B/qwen2vl7B 5 | # adapter_name_or_path: /mnt/xmap_nas_alg/limingxing.lmx/workspace/code/aesthetic/qwen2vl-train/saves/ava/ava_training_wo_taskid_w_score_range_lora128_loradrop015_unfreeze_vision_data_decimal2/checkpoint-3680 6 | adapter_name_or_path: /mnt/xmap_nas_alg/wangrui.wr/code/qwen2vl-train/test/config/one_item_en/config_item_concat_with_14k_high_mid_low/aesthetic_9502_4300_llm_fusion_en_all_original_high_mid_low_3level_9item/checkpoint-5570
7 | template: qwen2_vl 8 | finetuning_type: lora 9 | 10 | ### export 11 | export_dir: /mnt/xmap_nas_alg/limingxing.lmx/workspace/code/aesthetic/qwen2vl-train/hugging_face/qwen2vl7B/high-mid-low 12 | export_size: 2 13 | export_device: cpu 14 | export_legacy_format: false 15 | -------------------------------------------------------------------------------- /examples/train_full/koniq_qwen2vl_full_sft.yaml: -------------------------------------------------------------------------------- 1 | ### model 2 | model_name_or_path: /mnt/xmap_nas_alg/limingxing.lmx/workspace/code/aesthetic/qwen2vl-train/hugging_face/qwen2vl7B/qwen2vl7B 3 | ### method 4 | stage: sft 5 | do_train: true 6 | finetuning_type: full 7 | freeze_vision_tower: false # choices: [true, false] 8 | train_mm_proj_only: false # choices: [true, false] 9 | deepspeed: examples/deepspeed/ds_z3_config.json # choices: [ds_z0_config.json, ds_z2_config.json, ds_z3_config.json] 10 | train_from_scratch: false 11 | 12 | ### dataset 13 | dataset: koniq10k_training_decimal2 14 | template: qwen2_vl 15 | cutoff_len: 2048 16 | overwrite_cache: true 17 | preprocessing_num_workers: 16 18 | 19 | ### output 20 | output_dir: saves/full/output/full/koniq_qwen2vl_full_sft 21 | logging_steps: 10 22 | save_steps: 1000 23 | plot_loss: true 24 | overwrite_output_dir: true 25 | 26 | ### train 27 | per_device_train_batch_size: 1 28 | gradient_accumulation_steps: 4 29 | learning_rate: 1.0e-5 30 | num_train_epochs: 6.0 31 | lr_scheduler_type: cosine 32 | warmup_ratio: 0.1 33 | bf16: true 34 | ddp_timeout: 180000000 35 | 36 | ### eval 37 | val_size: 0. # Fraction of the dataset randomly sampled as the validation set; set to 0 for now so that all data is used for training. 38 | per_device_eval_batch_size: 1 39 | eval_strategy: 'no' # no/steps 40 | eval_steps: 500 41 | -------------------------------------------------------------------------------- /examples/train_full/koniq_qwen2vl_full_sft_2e.yaml: -------------------------------------------------------------------------------- 1 | ### model 2 | model_name_or_path: /mnt/xmap_nas_alg/limingxing.lmx/workspace/code/aesthetic/qwen2vl-train/hugging_face/qwen2vl7B/qwen2vl7B 3 | ### method 4 | stage: sft 5 | do_train: true 6 | finetuning_type: full 7 | freeze_vision_tower: false # choices: [true, false] 8 | train_mm_proj_only: false # choices: [true, false] 9 | deepspeed: examples/deepspeed/ds_z3_config.json # choices: [ds_z0_config.json, ds_z2_config.json, ds_z3_config.json] 10 | train_from_scratch: false 11 | 12 | ### dataset 13 | dataset: koniq10k_training_decimal2 14 | template: qwen2_vl 15 | cutoff_len: 2048 16 | overwrite_cache: true 17 | preprocessing_num_workers: 16 18 | 19 | ### output 20 | output_dir: saves/full/output/full/koniq_qwen2vl_full_sft_2e 21 | logging_steps: 10 22 | save_steps: 1000 23 | plot_loss: true 24 | overwrite_output_dir: true 25 | 26 | ### train 27 | per_device_train_batch_size: 1 28 | gradient_accumulation_steps: 4 29 | learning_rate: 1.0e-5 30 | num_train_epochs: 2.0 31 | lr_scheduler_type: cosine 32 | warmup_ratio: 0.1 33 | bf16: true 34 | ddp_timeout: 180000000 35 | 36 | ### eval 37 | val_size: 0.
# Fraction of the dataset randomly sampled as the validation set; set to 0 for now so that all data is used for training. 38 | per_device_eval_batch_size: 1 39 | eval_strategy: 'no' # no/steps 40 | eval_steps: 500 41 | -------------------------------------------------------------------------------- /examples/train_full/koniq_qwen2vl_full_sft_from_scratch.yaml: -------------------------------------------------------------------------------- 1 | ### model 2 | model_name_or_path: /mnt/xmap_nas_alg/limingxing.lmx/workspace/code/aesthetic/qwen2vl-train/hugging_face/qwen2vl7B/qwen2vl7B 3 | ### method 4 | stage: sft 5 | do_train: true 6 | finetuning_type: full 7 | freeze_vision_tower: false # choices: [true, false] 8 | train_mm_proj_only: false # choices: [true, false] 9 | deepspeed: examples/deepspeed/ds_z3_config.json # choices: [ds_z0_config.json, ds_z2_config.json, ds_z3_config.json] 10 | train_from_scratch: true 11 | 12 | ### dataset 13 | dataset: koniq10k_training_decimal2 14 | template: qwen2_vl 15 | cutoff_len: 2048 16 | overwrite_cache: true 17 | preprocessing_num_workers: 16 18 | 19 | ### output 20 | output_dir: saves/full/output/full/koniq_qwen2vl_full_sft_from_scratch 21 | logging_steps: 10 22 | save_steps: 1000 23 | plot_loss: true 24 | overwrite_output_dir: true 25 | 26 | ### train 27 | per_device_train_batch_size: 1 28 | gradient_accumulation_steps: 4 29 | learning_rate: 1.0e-5 30 | num_train_epochs: 6.0 31 | lr_scheduler_type: cosine 32 | warmup_ratio: 0.1 33 | bf16: true 34 | ddp_timeout: 180000000 35 | 36 | ### eval 37 | val_size: 0. # Fraction of the dataset randomly sampled as the validation set; set to 0 for now so that all data is used for training. 38 | per_device_eval_batch_size: 1 39 | eval_strategy: 'no' # no/steps 40 | eval_steps: 500 41 | -------------------------------------------------------------------------------- /examples/train_full/koniq_qwen2vl_full_sft_from_scratch_2e.yaml: -------------------------------------------------------------------------------- 1 | ### model 2 | model_name_or_path: /mnt/xmap_nas_alg/limingxing.lmx/workspace/code/aesthetic/qwen2vl-train/hugging_face/qwen2vl7B/qwen2vl7B 3 | ### method 4 | stage: sft 5 | do_train: true 6 | finetuning_type: full 7 | freeze_vision_tower: false # choices: [true, false] 8 | train_mm_proj_only: false # choices: [true, false] 9 | deepspeed: examples/deepspeed/ds_z3_config.json # choices: [ds_z0_config.json, ds_z2_config.json, ds_z3_config.json] 10 | train_from_scratch: true 11 | 12 | ### dataset 13 | dataset: koniq10k_training_decimal2 14 | template: qwen2_vl 15 | cutoff_len: 2048 16 | overwrite_cache: true 17 | preprocessing_num_workers: 16 18 | 19 | ### output 20 | output_dir: saves/full/output/full/koniq_qwen2vl_full_sft_from_scratch_2e 21 | logging_steps: 10 22 | save_steps: 1000 23 | plot_loss: true 24 | overwrite_output_dir: true 25 | 26 | ### train 27 | per_device_train_batch_size: 1 28 | gradient_accumulation_steps: 4 29 | learning_rate: 1.0e-5 30 | num_train_epochs: 2.0 31 | lr_scheduler_type: cosine 32 | warmup_ratio: 0.1 33 | bf16: true 34 | ddp_timeout: 180000000 35 | 36 | ### eval 37 | val_size: 0.
# Fraction of the dataset randomly sampled as the validation set; set to 0 for now so that all data is used for training. 38 | per_device_eval_batch_size: 1 39 | eval_strategy: 'no' # no/steps 40 | eval_steps: 500 41 | -------------------------------------------------------------------------------- /examples/train_full/koniq_qwen2vl_full_sft_from_scratch_ncm1.yaml: -------------------------------------------------------------------------------- 1 | ### model 2 | model_name_or_path: /mnt/xmap_nas_alg/limingxing.lmx/workspace/code/aesthetic/qwen2vl-train/hugging_face/qwen2vl7B/qwen2vl7B 3 | ### method 4 | stage: sft 5 | do_train: true 6 | finetuning_type: full 7 | freeze_vision_tower: false # choices: [true, false] 8 | train_mm_proj_only: false # choices: [true, false] 9 | deepspeed: examples/deepspeed/ds_z3_config.json # choices: [ds_z0_config.json, ds_z2_config.json, ds_z3_config.json] 10 | train_from_scratch: true 11 | 12 | ### dataset 13 | dataset: koniq10k_training_decimal2 14 | template: qwen2_vl 15 | cutoff_len: 2048 16 | overwrite_cache: true 17 | preprocessing_num_workers: 16 18 | 19 | ### output 20 | output_dir: saves/full/output/full/koniq_qwen2vl_full_sft_from_scratch_ncm1 21 | logging_steps: 10 22 | save_steps: 1000 23 | plot_loss: true 24 | overwrite_output_dir: true 25 | 26 | ### train 27 | per_device_train_batch_size: 1 28 | gradient_accumulation_steps: 4 29 | learning_rate: 1.0e-5 30 | num_train_epochs: 6.0 31 | lr_scheduler_type: cosine 32 | warmup_ratio: 0.1 33 | bf16: true 34 | ddp_timeout: 180000000 35 | 36 | ### eval 37 | val_size: 0. # Fraction of the dataset randomly sampled as the validation set; set to 0 for now so that all data is used for training. 38 | per_device_eval_batch_size: 1 39 | eval_strategy: 'no' # no/steps 40 | eval_steps: 500 41 | -------------------------------------------------------------------------------- /examples/train_full/koniq_qwen2vl_full_sft_from_scratch_ncm1_2e.yaml: -------------------------------------------------------------------------------- 1 | ### model 2 | model_name_or_path: /mnt/xmap_nas_alg/limingxing.lmx/workspace/code/aesthetic/qwen2vl-train/hugging_face/qwen2vl7B/qwen2vl7B 3 | ### method 4 | stage: sft 5 | do_train: true 6 | finetuning_type: full 7 | freeze_vision_tower: false # choices: [true, false] 8 | train_mm_proj_only: false # choices: [true, false] 9 | deepspeed: examples/deepspeed/ds_z3_config.json # choices: [ds_z0_config.json, ds_z2_config.json, ds_z3_config.json] 10 | train_from_scratch: true 11 | 12 | ### dataset 13 | dataset: koniq10k_training_decimal2 14 | template: qwen2_vl 15 | cutoff_len: 2048 16 | overwrite_cache: true 17 | preprocessing_num_workers: 16 18 | 19 | ### output 20 | output_dir: saves/full/output/full/koniq_qwen2vl_full_sft_from_scratch_ncm1_2e 21 | logging_steps: 10 22 | save_steps: 1000 23 | plot_loss: true 24 | overwrite_output_dir: true 25 | 26 | ### train 27 | per_device_train_batch_size: 1 28 | gradient_accumulation_steps: 4 29 | learning_rate: 1.0e-5 30 | num_train_epochs: 2.0 31 | lr_scheduler_type: cosine 32 | warmup_ratio: 0.1 33 | bf16: true 34 | ddp_timeout: 180000000 35 | 36 | ### eval 37 | val_size: 0.
# Fraction of the dataset randomly sampled as the validation set; set to 0 for now so that all data is used for training. 38 | per_device_eval_batch_size: 1 39 | eval_strategy: 'no' # no/steps 40 | eval_steps: 500 41 | -------------------------------------------------------------------------------- /examples/train_full/koniq_qwen2vl_full_sft_from_scratch_ncm5.yaml: -------------------------------------------------------------------------------- 1 | ### model 2 | model_name_or_path: /mnt/xmap_nas_alg/limingxing.lmx/workspace/code/aesthetic/qwen2vl-train/hugging_face/qwen2vl7B/qwen2vl7B 3 | ### method 4 | stage: sft 5 | do_train: true 6 | finetuning_type: full 7 | freeze_vision_tower: false # choices: [true, false] 8 | train_mm_proj_only: false # choices: [true, false] 9 | deepspeed: examples/deepspeed/ds_z3_config.json # choices: [ds_z0_config.json, ds_z2_config.json, ds_z3_config.json] 10 | train_from_scratch: true 11 | 12 | ### dataset 13 | dataset: koniq10k_training_decimal2 14 | template: qwen2_vl 15 | cutoff_len: 2048 16 | overwrite_cache: true 17 | preprocessing_num_workers: 16 18 | 19 | ### output 20 | output_dir: saves/full/output/full/koniq_qwen2vl_full_sft_from_scratch_ncm5 21 | logging_steps: 10 22 | save_steps: 1000 23 | plot_loss: true 24 | overwrite_output_dir: true 25 | 26 | ### train 27 | per_device_train_batch_size: 1 28 | gradient_accumulation_steps: 4 29 | learning_rate: 1.0e-5 30 | num_train_epochs: 6.0 31 | lr_scheduler_type: cosine 32 | warmup_ratio: 0.1 33 | bf16: true 34 | ddp_timeout: 180000000 35 | 36 | ### eval 37 | val_size: 0. # Fraction of the dataset randomly sampled as the validation set; set to 0 for now so that all data is used for training. 38 | per_device_eval_batch_size: 1 39 | eval_strategy: 'no' # no/steps 40 | eval_steps: 500 41 | -------------------------------------------------------------------------------- /examples/train_full/llama3_full_sft.yaml: -------------------------------------------------------------------------------- 1 | ### model 2 | model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct 3 | 4 | ### method 5 | stage: sft 6 | do_train: true 7 | finetuning_type: full 8 | deepspeed: examples/deepspeed/ds_z3_config.json # choices: [ds_z0_config.json, ds_z2_config.json, ds_z3_config.json] 9 | 10 | ### dataset 11 | dataset: identity,alpaca_en_demo 12 | template: llama3 13 | cutoff_len: 2048 14 | max_samples: 1000 15 | overwrite_cache: true 16 | preprocessing_num_workers: 16 17 | 18 | ### output 19 | output_dir: saves/llama3-8b/full/sft 20 | logging_steps: 10 21 | save_steps: 500 22 | plot_loss: true 23 | overwrite_output_dir: true 24 | 25 | ### train 26 | per_device_train_batch_size: 1 27 | gradient_accumulation_steps: 2 28 | learning_rate: 1.0e-5 29 | num_train_epochs: 3.0 30 | lr_scheduler_type: cosine 31 | warmup_ratio: 0.1 32 | bf16: true 33 | ddp_timeout: 180000000 34 | 35 | ### eval 36 | val_size: 0.1 37 | per_device_eval_batch_size: 1 38 | eval_strategy: steps 39 | eval_steps: 500 40 | -------------------------------------------------------------------------------- /examples/train_full/qwen2vl_full_sft_without_caption_only_proj.yaml: -------------------------------------------------------------------------------- 1 | ### model 2 | model_name_or_path: /data/7.68T-3/limingxing.lmx/workspace/Qwen2-VL/Qwen2-VL/hugging_face/7B 3 | ### method 4 | stage: sft 5 | do_train: true 6 | finetuning_type: full 7 | freeze_vision_tower: true # choices: [true, false] 8 | train_mm_proj_only: false # choices: [true, false] 9 | deepspeed: examples/deepspeed/ds_z3_config.json # choices: [ds_z0_config.json, ds_z2_config.json, ds_z3_config.json] 10 | 11 | ### dataset 12 | dataset:
aesthetic_high_v2,aesthetic_low_and_middle_v2_without_caption 13 | template: qwen2_vl 14 | cutoff_len: 2048 15 | overwrite_cache: true 16 | preprocessing_num_workers: 16 17 | 18 | ### output 19 | output_dir: saves/full/output/qwen2vl_full_v2_without_caption_only_proj 20 | logging_steps: 10 21 | save_steps: 1000 22 | plot_loss: true 23 | overwrite_output_dir: true 24 | 25 | ### train 26 | per_device_train_batch_size: 1 27 | gradient_accumulation_steps: 4 28 | learning_rate: 1.0e-5 29 | num_train_epochs: 10.0 30 | lr_scheduler_type: cosine 31 | warmup_ratio: 0.1 32 | bf16: true 33 | ddp_timeout: 180000000 34 | 35 | ### eval 36 | val_size: 0. # Fraction of the dataset randomly sampled as the validation set; set to 0 for now so that all data is used for training. 37 | per_device_eval_batch_size: 1 38 | eval_strategy: 'no' # no/steps 39 | eval_steps: 500 40 | -------------------------------------------------------------------------------- /examples/train_lora/llama3_lora_dpo.yaml: -------------------------------------------------------------------------------- 1 | ### model 2 | model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct 3 | 4 | ### method 5 | stage: dpo 6 | do_train: true 7 | finetuning_type: lora 8 | lora_target: all 9 | pref_beta: 0.1 10 | pref_loss: sigmoid # choices: [sigmoid (dpo), orpo, simpo] 11 | 12 | ### dataset 13 | dataset: dpo_en_demo 14 | template: llama3 15 | cutoff_len: 2048 16 | max_samples: 1000 17 | overwrite_cache: true 18 | preprocessing_num_workers: 16 19 | 20 | ### output 21 | output_dir: saves/llama3-8b/lora/dpo 22 | logging_steps: 10 23 | save_steps: 500 24 | plot_loss: true 25 | overwrite_output_dir: true 26 | 27 | ### train 28 | per_device_train_batch_size: 1 29 | gradient_accumulation_steps: 8 30 | learning_rate: 5.0e-6 31 | num_train_epochs: 3.0 32 | lr_scheduler_type: cosine 33 | warmup_ratio: 0.1 34 | bf16: true 35 | ddp_timeout: 180000000 36 | 37 | ### eval 38 | val_size: 0.1 39 | per_device_eval_batch_size: 1 40 | eval_strategy: steps 41 | eval_steps: 500 42 | -------------------------------------------------------------------------------- /examples/train_lora/llama3_lora_eval.yaml: -------------------------------------------------------------------------------- 1 | ### model 2 | model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct 3 | adapter_name_or_path: saves/llama3-8b/lora/sft 4 | 5 | ### method 6 | finetuning_type: lora 7 | 8 | ### dataset 9 | task: mmlu_test # choices: [mmlu_test, ceval_validation, cmmlu_test] 10 | template: fewshot 11 | lang: en 12 | n_shot: 5 13 | 14 | ### output 15 | save_dir: saves/llama3-8b/lora/eval 16 | 17 | ### eval 18 | batch_size: 4 19 | -------------------------------------------------------------------------------- /examples/train_lora/llama3_lora_kto.yaml: -------------------------------------------------------------------------------- 1 | ### model 2 | model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct 3 | 4 | ### method 5 | stage: kto 6 | do_train: true 7 | finetuning_type: lora 8 | lora_target: all 9 | pref_beta: 0.1 10 | 11 | ### dataset 12 | dataset: kto_en_demo 13 | template: llama3 14 | cutoff_len: 2048 15 | max_samples: 1000 16 | overwrite_cache: true 17 | preprocessing_num_workers: 16 18 | 19 | ### output 20 | output_dir: saves/llama3-8b/lora/kto 21 | logging_steps: 10 22 | save_steps: 500 23 | plot_loss: true 24 | overwrite_output_dir: true 25 | 26 | ### train 27 | per_device_train_batch_size: 1 28 | gradient_accumulation_steps: 8 29 | learning_rate: 5.0e-6 30 | num_train_epochs: 3.0 31 | lr_scheduler_type: cosine 32 | warmup_ratio: 0.1 33 | bf16: true 34 | ddp_timeout: 180000000 35
| 36 | ### eval 37 | val_size: 0.1 38 | per_device_eval_batch_size: 1 39 | eval_strategy: steps 40 | eval_steps: 500 41 | -------------------------------------------------------------------------------- /examples/train_lora/llama3_lora_ppo.yaml: -------------------------------------------------------------------------------- 1 | ### model 2 | model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct 3 | reward_model: saves/llama3-8b/lora/reward 4 | 5 | ### method 6 | stage: ppo 7 | do_train: true 8 | finetuning_type: lora 9 | lora_target: all 10 | 11 | ### dataset 12 | dataset: identity,alpaca_en_demo 13 | template: llama3 14 | cutoff_len: 2048 15 | max_samples: 1000 16 | overwrite_cache: true 17 | preprocessing_num_workers: 16 18 | 19 | ### output 20 | output_dir: saves/llama3-8b/lora/ppo 21 | logging_steps: 10 22 | save_steps: 500 23 | plot_loss: true 24 | overwrite_output_dir: true 25 | 26 | ### train 27 | per_device_train_batch_size: 1 28 | gradient_accumulation_steps: 8 29 | learning_rate: 1.0e-5 30 | num_train_epochs: 3.0 31 | lr_scheduler_type: cosine 32 | warmup_ratio: 0.1 33 | bf16: true 34 | ddp_timeout: 180000000 35 | 36 | ### generate 37 | max_new_tokens: 512 38 | top_k: 0 39 | top_p: 0.9 40 | -------------------------------------------------------------------------------- /examples/train_lora/llama3_lora_pretrain.yaml: -------------------------------------------------------------------------------- 1 | ### model 2 | model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct 3 | 4 | ### method 5 | stage: pt 6 | do_train: true 7 | finetuning_type: lora 8 | lora_target: all 9 | 10 | ### dataset 11 | dataset: c4_demo 12 | cutoff_len: 2048 13 | max_samples: 1000 14 | overwrite_cache: true 15 | preprocessing_num_workers: 16 16 | 17 | ### output 18 | output_dir: saves/llama3-8b/lora/pretrain 19 | logging_steps: 10 20 | save_steps: 500 21 | plot_loss: true 22 | overwrite_output_dir: true 23 | 24 | ### train 25 | per_device_train_batch_size: 1 26 | gradient_accumulation_steps: 8 27 | learning_rate: 1.0e-4 28 | num_train_epochs: 3.0 29 | lr_scheduler_type: cosine 30 | warmup_ratio: 0.1 31 | bf16: true 32 | ddp_timeout: 180000000 33 | 34 | ### eval 35 | val_size: 0.1 36 | per_device_eval_batch_size: 1 37 | eval_strategy: steps 38 | eval_steps: 500 39 | -------------------------------------------------------------------------------- /examples/train_lora/llama3_lora_reward.yaml: -------------------------------------------------------------------------------- 1 | ### model 2 | model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct 3 | 4 | ### method 5 | stage: rm 6 | do_train: true 7 | finetuning_type: lora 8 | lora_target: all 9 | 10 | ### dataset 11 | dataset: dpo_en_demo 12 | template: llama3 13 | cutoff_len: 2048 14 | max_samples: 1000 15 | overwrite_cache: true 16 | preprocessing_num_workers: 16 17 | 18 | ### output 19 | output_dir: saves/llama3-8b/lora/reward 20 | logging_steps: 10 21 | save_steps: 500 22 | plot_loss: true 23 | overwrite_output_dir: true 24 | 25 | ### train 26 | per_device_train_batch_size: 1 27 | gradient_accumulation_steps: 8 28 | learning_rate: 1.0e-4 29 | num_train_epochs: 3.0 30 | lr_scheduler_type: cosine 31 | warmup_ratio: 0.1 32 | bf16: true 33 | ddp_timeout: 180000000 34 | 35 | ### eval 36 | val_size: 0.1 37 | per_device_eval_batch_size: 1 38 | eval_strategy: steps 39 | eval_steps: 500 40 | -------------------------------------------------------------------------------- /examples/train_lora/llama3_lora_sft.yaml: 
-------------------------------------------------------------------------------- 1 | ### model 2 | model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct 3 | 4 | ### method 5 | stage: sft 6 | do_train: true 7 | finetuning_type: lora 8 | lora_target: all 9 | 10 | ### dataset 11 | dataset: identity,alpaca_en_demo 12 | template: llama3 13 | cutoff_len: 2048 14 | max_samples: 1000 15 | overwrite_cache: true 16 | preprocessing_num_workers: 16 17 | 18 | ### output 19 | output_dir: saves/llama3-8b/lora/sft 20 | logging_steps: 10 21 | save_steps: 500 22 | plot_loss: true 23 | overwrite_output_dir: true 24 | 25 | ### train 26 | per_device_train_batch_size: 1 27 | gradient_accumulation_steps: 8 28 | learning_rate: 1.0e-4 29 | num_train_epochs: 3.0 30 | lr_scheduler_type: cosine 31 | warmup_ratio: 0.1 32 | bf16: true 33 | ddp_timeout: 180000000 34 | 35 | ### eval 36 | val_size: 0.1 37 | per_device_eval_batch_size: 1 38 | eval_strategy: steps 39 | eval_steps: 500 40 | -------------------------------------------------------------------------------- /examples/train_lora/llama3_lora_sft_ds3.yaml: -------------------------------------------------------------------------------- 1 | ### model 2 | model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct 3 | 4 | ### method 5 | stage: sft 6 | do_train: true 7 | finetuning_type: lora 8 | lora_target: all 9 | deepspeed: examples/deepspeed/ds_z3_config.json # choices: [ds_z0_config.json, ds_z2_config.json, ds_z3_config.json] 10 | 11 | ### dataset 12 | dataset: identity,alpaca_en_demo 13 | template: llama3 14 | cutoff_len: 2048 15 | max_samples: 1000 16 | overwrite_cache: true 17 | preprocessing_num_workers: 16 18 | 19 | ### output 20 | output_dir: saves/llama3-8b/lora/sft 21 | logging_steps: 10 22 | save_steps: 500 23 | plot_loss: true 24 | overwrite_output_dir: true 25 | 26 | ### train 27 | per_device_train_batch_size: 1 28 | gradient_accumulation_steps: 2 29 | learning_rate: 1.0e-4 30 | num_train_epochs: 3.0 31 | lr_scheduler_type: cosine 32 | warmup_ratio: 0.1 33 | bf16: true 34 | ddp_timeout: 180000000 35 | 36 | ### eval 37 | val_size: 0.1 38 | per_device_eval_batch_size: 1 39 | eval_strategy: steps 40 | eval_steps: 500 41 | -------------------------------------------------------------------------------- /examples/train_lora/llama3_preprocess.yaml: -------------------------------------------------------------------------------- 1 | ### model 2 | model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct 3 | 4 | ### method 5 | stage: sft 6 | do_train: true 7 | finetuning_type: lora 8 | lora_target: all 9 | 10 | ### dataset 11 | dataset: identity,alpaca_en_demo 12 | template: llama3 13 | cutoff_len: 2048 14 | max_samples: 1000 15 | overwrite_cache: true 16 | preprocessing_num_workers: 16 17 | tokenized_path: saves/llama3-8b/dataset/sft 18 | 19 | ### output 20 | output_dir: saves/llama3-8b/lora/sft 21 | overwrite_output_dir: true 22 | -------------------------------------------------------------------------------- /examples/train_lora/llava1_5_lora_sft.yaml: -------------------------------------------------------------------------------- 1 | ### model 2 | model_name_or_path: llava-hf/llava-1.5-7b-hf 3 | 4 | ### method 5 | stage: sft 6 | do_train: true 7 | finetuning_type: lora 8 | lora_target: all 9 | 10 | ### dataset 11 | dataset: mllm_demo 12 | template: llava 13 | cutoff_len: 2048 14 | max_samples: 1000 15 | overwrite_cache: true 16 | preprocessing_num_workers: 16 17 | 18 | ### output 19 | output_dir: saves/llava1_5-7b/lora/sft 20 | logging_steps: 10 21 | 
save_steps: 500 22 | plot_loss: true 23 | overwrite_output_dir: true 24 | 25 | ### train 26 | per_device_train_batch_size: 1 27 | gradient_accumulation_steps: 8 28 | learning_rate: 1.0e-4 29 | num_train_epochs: 3.0 30 | lr_scheduler_type: cosine 31 | warmup_ratio: 0.1 32 | bf16: true 33 | ddp_timeout: 180000000 34 | 35 | ### eval 36 | val_size: 0.1 37 | per_device_eval_batch_size: 1 38 | eval_strategy: steps 39 | eval_steps: 500 40 | -------------------------------------------------------------------------------- /examples/train_lora/qwen2vl_lora_dpo.yaml: -------------------------------------------------------------------------------- 1 | ### model 2 | model_name_or_path: Qwen/Qwen2-VL-7B-Instruct 3 | 4 | ### method 5 | stage: dpo 6 | do_train: true 7 | finetuning_type: lora 8 | lora_target: all 9 | pref_beta: 0.1 10 | pref_loss: sigmoid # choices: [sigmoid (dpo), orpo, simpo] 11 | 12 | ### dataset 13 | dataset: rlhf_v 14 | template: qwen2_vl 15 | cutoff_len: 2048 16 | max_samples: 1000 17 | overwrite_cache: true 18 | preprocessing_num_workers: 16 19 | 20 | ### output 21 | output_dir: saves/qwen2_vl-7b/lora/dpo 22 | logging_steps: 10 23 | save_steps: 500 24 | plot_loss: true 25 | overwrite_output_dir: true 26 | 27 | ### train 28 | per_device_train_batch_size: 1 29 | gradient_accumulation_steps: 8 30 | learning_rate: 5.0e-6 31 | num_train_epochs: 3.0 32 | lr_scheduler_type: cosine 33 | warmup_ratio: 0.1 34 | bf16: true 35 | ddp_timeout: 180000000 36 | 37 | ### eval 38 | val_size: 0.1 39 | per_device_eval_batch_size: 1 40 | eval_strategy: steps 41 | eval_steps: 500 42 | -------------------------------------------------------------------------------- /examples/train_lora/qwen2vl_lora_sft.yaml: -------------------------------------------------------------------------------- 1 | ### model 2 | model_name_or_path: /data/7.68T-3/limingxing.lmx/workspace/Qwen2-VL/Qwen2-VL/hugging_face/7B 3 | 4 | ### method 5 | stage: sft 6 | do_train: true 7 | do_eval: false 8 | finetuning_type: lora 9 | lora_target: all # Modules to apply LoRA to; defaults to all. 10 | 11 | ### dataset 12 | dataset: aesthetic_high,aesthetic_low_and_middle # video: mllm_video_demo 13 | template: qwen2_vl 14 | cutoff_len: 2048 15 | # max_samples: 1000 # For debugging purposes, truncate the number of examples for each dataset 16 | overwrite_cache: true # Overwrite the cached training and evaluation sets. 17 | preprocessing_num_workers: 16 18 | 19 | ### output 20 | output_dir: saves/qwen2_vl-7b/lora/sft 21 | logging_steps: 10 22 | save_steps: 1000 23 | # save_strategy: 'epoch' 24 | 25 | plot_loss: true 26 | overwrite_output_dir: true 27 | 28 | ### train 29 | per_device_train_batch_size: 2 30 | gradient_accumulation_steps: 8 31 | learning_rate: 1.0e-4 32 | num_train_epochs: 5.0 33 | lr_scheduler_type: cosine 34 | warmup_ratio: 0.1 35 | bf16: true 36 | ddp_timeout: 180000000 37 | 38 | ### eval 39 | val_size: 0.
# Fraction of the dataset randomly held out for validation; set to 0 for now so all data is used for training. 40 | per_device_eval_batch_size: 1 41 | eval_strategy: 'no' # no/steps 42 | eval_steps: 500 43 | -------------------------------------------------------------------------------- /examples/train_lora/qwen2vl_lora_sft_v2.yaml: -------------------------------------------------------------------------------- 1 | ### model 2 | model_name_or_path: /data/7.68T-3/limingxing.lmx/workspace/Qwen2-VL/Qwen2-VL/hugging_face/7B 3 | 4 | ### method 5 | stage: sft 6 | do_train: true 7 | do_eval: false 8 | finetuning_type: lora 9 | lora_target: all # Modules to apply LoRA to; defaults to all. 10 | 11 | ### dataset 12 | dataset: aesthetic_high_v2,aesthetic_low_and_middle_v2 # video: mllm_video_demo 13 | template: qwen2_vl 14 | cutoff_len: 2048 15 | # max_samples: 1000 # For debugging purposes, truncate the number of examples for each dataset 16 | overwrite_cache: true # Overwrite the cached training and evaluation sets. 17 | preprocessing_num_workers: 16 18 | 19 | ### output 20 | output_dir: saves/qwen2_vl-7b/lora/sft_v2_format 21 | logging_steps: 10 22 | save_steps: 1000 23 | # save_strategy: 'epoch' 24 | 25 | plot_loss: true 26 | overwrite_output_dir: true 27 | 28 | ### train 29 | per_device_train_batch_size: 2 30 | gradient_accumulation_steps: 8 31 | learning_rate: 1.0e-4 32 | num_train_epochs: 5.0 33 | lr_scheduler_type: cosine 34 | warmup_ratio: 0.1 35 | bf16: true 36 | ddp_timeout: 180000000 37 | 38 | ### eval 39 | val_size: 0. # Fraction of the dataset randomly held out for validation; set to 0 for now so all data is used for training. 40 | per_device_eval_batch_size: 1 41 | eval_strategy: 'no' # no/steps 42 | eval_steps: 500 -------------------------------------------------------------------------------- /examples/train_lora/qwen2vl_lora_sft_v2_l20.yaml: -------------------------------------------------------------------------------- 1 | ### model 2 | model_name_or_path: /mnt/xmap_nas_alg/limingxing.lmx/workspace/code/aesthetic/qwen2vl-train/hugging_face/qwen2vl7B/qwen2vl7B 3 | 4 | ### method 5 | stage: sft 6 | do_train: true 7 | do_eval: false 8 | finetuning_type: lora 9 | lora_target: all # Modules to apply LoRA to; defaults to all. 10 | 11 | ### dataset 12 | dataset: aesthetic_high_v2,aesthetic_low_and_middle_v2 # video: mllm_video_demo 13 | template: qwen2_vl 14 | cutoff_len: 2048 15 | # max_samples: 1000 # For debugging purposes, truncate the number of examples for each dataset 16 | overwrite_cache: true # Overwrite the cached training and evaluation sets. 17 | preprocessing_num_workers: 16 18 | 19 | ### output 20 | output_dir: saves/qwen2_vl-7b/lora/sft_v2 21 | logging_steps: 10 22 | save_steps: 1000 23 | # save_strategy: 'epoch' 24 | 25 | plot_loss: true 26 | overwrite_output_dir: true 27 | 28 | ### train 29 | per_device_train_batch_size: 2 30 | gradient_accumulation_steps: 8 31 | learning_rate: 1.0e-4 32 | num_train_epochs: 5.0 33 | lr_scheduler_type: cosine 34 | warmup_ratio: 0.1 35 | bf16: true 36 | ddp_timeout: 180000000 37 | 38 | ### eval 39 | val_size: 0.
# Fraction of the dataset randomly held out for validation; set to 0 for now so all data is used for training. 40 | per_device_eval_batch_size: 1 41 | eval_strategy: 'no' # no/steps 42 | eval_steps: 500 43 | 44 | -------------------------------------------------------------------------------- /examples/train_lora/qwen2vl_lora_sft_v2_with_general_caption.yaml: -------------------------------------------------------------------------------- 1 | ### model 2 | model_name_or_path: /data/7.68T-3/limingxing.lmx/workspace/Qwen2-VL/Qwen2-VL/hugging_face/7B 3 | 4 | ### method 5 | stage: sft 6 | do_train: true 7 | do_eval: false 8 | finetuning_type: lora 9 | lora_target: all # Modules to apply LoRA to; defaults to all. 10 | 11 | ### dataset 12 | dataset: aesthetic_high_v2,aesthetic_low_and_middle_v2_with_general_caption # video: mllm_video_demo 13 | template: qwen2_vl 14 | cutoff_len: 2048 15 | # max_samples: 1000 # For debugging purposes, truncate the number of examples for each dataset 16 | overwrite_cache: true # Overwrite the cached training and evaluation sets. 17 | preprocessing_num_workers: 16 18 | 19 | ### output 20 | output_dir: saves/qwen2_vl-7b/lora/sft_v2_format_w_general_caption 21 | logging_steps: 10 22 | save_steps: 1000 23 | # save_strategy: 'epoch' 24 | 25 | plot_loss: true 26 | overwrite_output_dir: true 27 | 28 | ### train 29 | per_device_train_batch_size: 2 30 | gradient_accumulation_steps: 8 31 | learning_rate: 1.0e-4 32 | num_train_epochs: 5.0 33 | lr_scheduler_type: cosine 34 | warmup_ratio: 0.1 35 | bf16: true 36 | ddp_timeout: 180000000 37 | 38 | ### eval 39 | val_size: 0. # Fraction of the dataset randomly held out for validation; set to 0 for now so all data is used for training. 40 | per_device_eval_batch_size: 1 41 | eval_strategy: 'no' # no/steps 42 | eval_steps: 500 -------------------------------------------------------------------------------- /examples/train_lora/qwen2vl_lora_sft_v2_without_caption.yaml: -------------------------------------------------------------------------------- 1 | ### model 2 | model_name_or_path: /data/7.68T-3/limingxing.lmx/workspace/Qwen2-VL/Qwen2-VL/hugging_face/7B 3 | 4 | ### method 5 | stage: sft 6 | do_train: true 7 | do_eval: false 8 | finetuning_type: lora 9 | lora_target: all # Modules to apply LoRA to; defaults to all. 10 | 11 | ### dataset 12 | dataset: aesthetic_high_v2,aesthetic_low_and_middle_v2_without_caption # video: mllm_video_demo 13 | template: qwen2_vl 14 | cutoff_len: 2048 15 | # max_samples: 1000 # For debugging purposes, truncate the number of examples for each dataset 16 | overwrite_cache: true # Overwrite the cached training and evaluation sets. 17 | preprocessing_num_workers: 16 18 | 19 | ### output 20 | output_dir: saves/qwen2_vl-7b/lora/sft_v2_format_wo_caption 21 | logging_steps: 10 22 | save_steps: 1000 23 | # save_strategy: 'epoch' 24 | 25 | plot_loss: true 26 | overwrite_output_dir: true 27 | 28 | ### train 29 | per_device_train_batch_size: 2 30 | gradient_accumulation_steps: 8 31 | learning_rate: 1.0e-4 32 | num_train_epochs: 5.0 33 | lr_scheduler_type: cosine 34 | warmup_ratio: 0.1 35 | bf16: true 36 | ddp_timeout: 180000000 37 | 38 | ### eval 39 | val_size: 0.
# Fraction of the dataset randomly held out for validation; set to 0 for now so all data is used for training. 40 | per_device_eval_batch_size: 1 41 | eval_strategy: 'no' # no/steps 42 | eval_steps: 500 -------------------------------------------------------------------------------- /examples/train_lora/qwen2vl_lora_sft_v2_without_caption_lora_128.yaml: -------------------------------------------------------------------------------- 1 | ### model 2 | model_name_or_path: /data/7.68T-3/limingxing.lmx/workspace/Qwen2-VL/Qwen2-VL/hugging_face/7B 3 | 4 | ### method 5 | stage: sft 6 | do_train: true 7 | do_eval: false 8 | finetuning_type: lora 9 | lora_target: all # Modules to apply LoRA to; defaults to all. 10 | lora_rank: 128 11 | 12 | ### dataset 13 | dataset: aesthetic_high_v2,aesthetic_low_and_middle_v2_without_caption # video: mllm_video_demo 14 | template: qwen2_vl 15 | cutoff_len: 2048 16 | # max_samples: 1000 # For debugging purposes, truncate the number of examples for each dataset 17 | overwrite_cache: true # Overwrite the cached training and evaluation sets. 18 | preprocessing_num_workers: 16 19 | 20 | ### output 21 | output_dir: saves/qwen2_vl-7b/lora/sft_v2_format_wo_caption_lora_128 22 | logging_steps: 10 23 | save_steps: 1000 24 | # save_strategy: 'epoch' 25 | 26 | plot_loss: true 27 | overwrite_output_dir: true 28 | 29 | ### train 30 | per_device_train_batch_size: 2 31 | gradient_accumulation_steps: 8 32 | learning_rate: 1.0e-4 33 | num_train_epochs: 5.0 34 | lr_scheduler_type: cosine 35 | warmup_ratio: 0.1 36 | bf16: true 37 | ddp_timeout: 180000000 38 | 39 | ### eval 40 | val_size: 0. # Fraction of the dataset randomly held out for validation; set to 0 for now so all data is used for training. 41 | per_device_eval_batch_size: 1 42 | eval_strategy: 'no' # no/steps 43 | eval_steps: 500 -------------------------------------------------------------------------------- /examples/train_lora/qwen2vl_lora_sft_v2_without_caption_lora_128_wr_9502_from_llm_and_4300_lower_4_2.yaml: -------------------------------------------------------------------------------- 1 | ### model 2 | model_name_or_path: /data/7.68T-3/limingxing.lmx/workspace/Qwen2-VL/Qwen2-VL/hugging_face/7B 3 | ### method 4 | stage: sft 5 | do_train: true 6 | do_eval: false 7 | finetuning_type: lora 8 | lora_target: all # Modules to apply LoRA to; defaults to all. 9 | lora_rank: 128 10 | 11 | ### dataset 12 | dataset: aesthetic_high_lower_4.2_4300, aesthetic_low_and_middle_lower_4.2_4300,aesthetic_high_v2,aesthetic_low_and_middle_v2_without_caption # video: mllm_video_demo 13 | template: qwen2_vl 14 | cutoff_len: 2048 15 | # max_samples: 1000 # For debugging purposes, truncate the number of examples for each dataset 16 | overwrite_cache: true # Overwrite the cached training and evaluation sets. 17 | preprocessing_num_workers: 16 18 | 19 | ### output 20 | output_dir: saves/qwen2_vl-7b/lora/qwen2vl_lora_sft_v2_without_caption_lora_128_wr_9502_from_llm_and_4300_lower_4_2 21 | logging_steps: 10 22 | save_steps: 500 23 | # save_strategy: 'epoch' 24 | 25 | plot_loss: true 26 | overwrite_output_dir: true 27 | 28 | ### train 29 | per_device_train_batch_size: 2 30 | gradient_accumulation_steps: 8 31 | learning_rate: 1.0e-4 32 | num_train_epochs: 3.0 33 | lr_scheduler_type: cosine 34 | warmup_ratio: 0.1 35 | bf16: true 36 | ddp_timeout: 180000000 37 | 38 | ### eval 39 | val_size: 0.
# Fraction of the dataset randomly held out for validation; set to 0 for now so all data is used for training. 40 | per_device_eval_batch_size: 1 41 | eval_strategy: 'no' # no/steps 42 | eval_steps: 500 -------------------------------------------------------------------------------- /examples/train_lora/qwen2vl_lora_sft_v2_without_caption_lora_256_wr_9502_from_llm_and_4300_lower_4_2.yaml: -------------------------------------------------------------------------------- 1 | ### model 2 | model_name_or_path: /data/7.68T-3/limingxing.lmx/workspace/Qwen2-VL/Qwen2-VL/hugging_face/7B 3 | ### method 4 | stage: sft 5 | do_train: true 6 | do_eval: false 7 | finetuning_type: lora 8 | lora_target: all # Modules to apply LoRA to; defaults to all. 9 | lora_rank: 256 10 | 11 | ### dataset 12 | dataset: aesthetic_high_lower_4.2_4300, aesthetic_low_and_middle_lower_4.2_4300,aesthetic_high_v2,aesthetic_low_and_middle_v2_without_caption # video: mllm_video_demo 13 | template: qwen2_vl 14 | cutoff_len: 2048 15 | # max_samples: 1000 # For debugging purposes, truncate the number of examples for each dataset 16 | overwrite_cache: true # Overwrite the cached training and evaluation sets. 17 | preprocessing_num_workers: 16 18 | 19 | ### output 20 | output_dir: saves/qwen2_vl-7b/lora/qwen2vl_lora_sft_v2_without_caption_lora_256_wr_9502_from_llm_and_4300_lower_4_2 21 | logging_steps: 10 22 | save_steps: 500 23 | # save_strategy: 'epoch' 24 | 25 | plot_loss: true 26 | overwrite_output_dir: true 27 | 28 | ### train 29 | per_device_train_batch_size: 2 30 | gradient_accumulation_steps: 8 31 | learning_rate: 1.0e-4 32 | num_train_epochs: 3.0 33 | lr_scheduler_type: cosine 34 | warmup_ratio: 0.1 35 | bf16: true 36 | ddp_timeout: 180000000 37 | 38 | ### eval 39 | val_size: 0. # Fraction of the dataset randomly held out for validation; set to 0 for now so all data is used for training. 40 | per_device_eval_batch_size: 1 41 | eval_strategy: 'no' # no/steps 42 | eval_steps: 500 -------------------------------------------------------------------------------- /examples/train_lora/qwen2vl_lora_sft_v2_without_caption_lora_32.yaml: -------------------------------------------------------------------------------- 1 | ### model 2 | model_name_or_path: /data/7.68T-3/limingxing.lmx/workspace/Qwen2-VL/Qwen2-VL/hugging_face/7B 3 | ### method 4 | stage: sft 5 | do_train: true 6 | do_eval: false 7 | finetuning_type: lora 8 | lora_target: all # Modules to apply LoRA to; defaults to all. 9 | lora_rank: 32 10 | 11 | ### dataset 12 | dataset: aesthetic_high_v2,aesthetic_low_and_middle_v2_without_caption # video: mllm_video_demo 13 | template: qwen2_vl 14 | cutoff_len: 2048 15 | # max_samples: 1000 # For debugging purposes, truncate the number of examples for each dataset 16 | overwrite_cache: true # Overwrite the cached training and evaluation sets. 17 | preprocessing_num_workers: 16 18 | 19 | ### output 20 | output_dir: saves/qwen2_vl-7b/lora/sft_v2_format_wo_caption 21 | logging_steps: 10 22 | save_steps: 1000 23 | # save_strategy: 'epoch' 24 | 25 | plot_loss: true 26 | overwrite_output_dir: true 27 | 28 | ### train 29 | per_device_train_batch_size: 2 30 | gradient_accumulation_steps: 8 31 | learning_rate: 1.0e-4 32 | num_train_epochs: 5.0 33 | lr_scheduler_type: cosine 34 | warmup_ratio: 0.1 35 | bf16: true 36 | ddp_timeout: 180000000 37 | 38 | ### eval 39 | val_size: 0.
# Fraction of the dataset randomly held out for validation; set to 0 for now so all data is used for training. 40 | per_device_eval_batch_size: 1 41 | eval_strategy: 'no' # no/steps 42 | eval_steps: 500 -------------------------------------------------------------------------------- /examples/train_lora/qwen2vl_lora_sft_v2_without_caption_lora_32_wr_9502.yaml: -------------------------------------------------------------------------------- 1 | ### model 2 | model_name_or_path: /mnt/xmap_nas_alg/limingxing.lmx/workspace/code/aesthetic/qwen2vl-train/hugging_face/qwen2vl7B/qwen2vl7B 3 | 4 | ### method 5 | stage: sft 6 | do_train: true 7 | do_eval: false 8 | finetuning_type: lora 9 | lora_target: all # Modules to apply LoRA to; defaults to all. 10 | lora_rank: 32 11 | 12 | ### dataset 13 | dataset: aesthetic_high_9602, aesthetic_low_and_middle_9502_with_without_caption # video: mllm_video_demo 14 | template: qwen2_vl 15 | cutoff_len: 2048 16 | # max_samples: 1000 # For debugging purposes, truncate the number of examples for each dataset 17 | overwrite_cache: true # Overwrite the cached training and evaluation sets. 18 | preprocessing_num_workers: 16 19 | 20 | ### output 21 | output_dir: saves/qwen2_vl-7b/lora/sft_v2_format_wo_caption_9502 22 | logging_steps: 10 23 | save_steps: 1000 24 | # save_strategy: 'epoch' 25 | 26 | plot_loss: true 27 | overwrite_output_dir: true 28 | 29 | ### train 30 | per_device_train_batch_size: 2 31 | gradient_accumulation_steps: 8 32 | learning_rate: 1.0e-4 33 | num_train_epochs: 5.0 34 | lr_scheduler_type: cosine 35 | warmup_ratio: 0.1 36 | bf16: true 37 | ddp_timeout: 180000000 38 | 39 | ### eval 40 | val_size: 0. # Fraction of the dataset randomly held out for validation; set to 0 for now so all data is used for training. 41 | per_device_eval_batch_size: 1 42 | eval_strategy: 'no' # no/steps 43 | eval_steps: 500 -------------------------------------------------------------------------------- /examples/train_lora/qwen2vl_lora_sft_v2_without_caption_lora_32_wr_test.yaml: -------------------------------------------------------------------------------- 1 | ### model 2 | model_name_or_path: /mnt/xmap_nas_alg/limingxing.lmx/workspace/code/aesthetic/qwen2vl-train/hugging_face/qwen2vl7B/qwen2vl7B 3 | 4 | ### method 5 | stage: sft 6 | do_train: true 7 | do_eval: false 8 | finetuning_type: lora 9 | lora_target: all # Modules to apply LoRA to; defaults to all. 10 | lora_rank: 32 11 | 12 | ### dataset 13 | dataset: aesthetic_high_v2,aesthetic_low_and_middle_v2_without_caption # video: mllm_video_demo 14 | template: qwen2_vl 15 | cutoff_len: 2048 16 | # max_samples: 1000 # For debugging purposes, truncate the number of examples for each dataset 17 | overwrite_cache: true # Overwrite the cached training and evaluation sets. 18 | preprocessing_num_workers: 16 19 | 20 | ### output 21 | output_dir: saves/qwen2_vl-7b/lora/sft_v2_format_wo_caption 22 | logging_steps: 10 23 | save_steps: 1000 24 | # save_strategy: 'epoch' 25 | 26 | plot_loss: true 27 | overwrite_output_dir: true 28 | 29 | ### train 30 | per_device_train_batch_size: 2 31 | gradient_accumulation_steps: 8 32 | learning_rate: 1.0e-4 33 | num_train_epochs: 5.0 34 | lr_scheduler_type: cosine 35 | warmup_ratio: 0.1 36 | bf16: true 37 | ddp_timeout: 180000000 38 | 39 | ### eval 40 | val_size: 0.
# Fraction of the dataset randomly held out for validation; set to 0 for now so all data is used for training. 41 | per_device_eval_batch_size: 1 42 | eval_strategy: 'no' # no/steps 43 | eval_steps: 500 -------------------------------------------------------------------------------- /examples/train_lora/qwen2vl_lora_sft_v2_without_caption_lora_32_x2_res.yaml: -------------------------------------------------------------------------------- 1 | ### model 2 | model_name_or_path: /data/7.68T-3/limingxing.lmx/workspace/Qwen2-VL/Qwen2-VL/hugging_face/7B_large 3 | 4 | ### method 5 | stage: sft 6 | do_train: true 7 | do_eval: false 8 | finetuning_type: lora 9 | lora_target: all # Modules to apply LoRA to; defaults to all. 10 | lora_rank: 32 11 | 12 | ### dataset 13 | dataset: aesthetic_high_v2,aesthetic_low_and_middle_v2_without_caption # video: mllm_video_demo 14 | template: qwen2_vl 15 | cutoff_len: 2048 16 | # max_samples: 1000 # For debugging purposes, truncate the number of examples for each dataset 17 | overwrite_cache: true # Overwrite the cached training and evaluation sets. 18 | preprocessing_num_workers: 16 19 | 20 | ### output 21 | output_dir: saves/qwen2_vl-7b/lora/sft_v2_format_wo_caption 22 | logging_steps: 10 23 | save_steps: 1000 24 | # save_strategy: 'epoch' 25 | 26 | plot_loss: true 27 | overwrite_output_dir: true 28 | 29 | ### train 30 | per_device_train_batch_size: 2 31 | gradient_accumulation_steps: 8 32 | learning_rate: 1.0e-4 33 | num_train_epochs: 5.0 34 | lr_scheduler_type: cosine 35 | warmup_ratio: 0.1 36 | bf16: true 37 | ddp_timeout: 180000000 38 | 39 | ### eval 40 | val_size: 0. # Fraction of the dataset randomly held out for validation; set to 0 for now so all data is used for training. 41 | per_device_eval_batch_size: 1 42 | eval_strategy: 'no' # no/steps 43 | eval_steps: 500 -------------------------------------------------------------------------------- /examples/train_qlora/llama3_lora_sft_aqlm.yaml: -------------------------------------------------------------------------------- 1 | ### model 2 | model_name_or_path: ISTA-DASLab/Meta-Llama-3-8B-Instruct-AQLM-2Bit-1x16 3 | 4 | ### method 5 | stage: sft 6 | do_train: true 7 | finetuning_type: lora 8 | lora_target: all 9 | 10 | ### dataset 11 | dataset: identity,alpaca_en_demo 12 | template: llama3 13 | cutoff_len: 2048 14 | max_samples: 1000 15 | overwrite_cache: true 16 | preprocessing_num_workers: 16 17 | 18 | ### output 19 | output_dir: saves/llama3-8b/lora/sft 20 | logging_steps: 10 21 | save_steps: 500 22 | plot_loss: true 23 | overwrite_output_dir: true 24 | 25 | ### train 26 | per_device_train_batch_size: 1 27 | gradient_accumulation_steps: 8 28 | learning_rate: 1.0e-4 29 | num_train_epochs: 3.0 30 | lr_scheduler_type: cosine 31 | warmup_ratio: 0.1 32 | bf16: true 33 | ddp_timeout: 180000000 34 | 35 | ### eval 36 | val_size: 0.1 37 | per_device_eval_batch_size: 1 38 | eval_strategy: steps 39 | eval_steps: 500 40 | -------------------------------------------------------------------------------- /examples/train_qlora/llama3_lora_sft_awq.yaml: -------------------------------------------------------------------------------- 1 | ### model 2 | model_name_or_path: TechxGenus/Meta-Llama-3-8B-Instruct-AWQ 3 | 4 | ### method 5 | stage: sft 6 | do_train: true 7 | finetuning_type: lora 8 | lora_target: all 9 | 10 | ### dataset 11 | dataset: identity,alpaca_en_demo 12 | template: llama3 13 | cutoff_len: 2048 14 | max_samples: 1000 15 | overwrite_cache: true 16 | preprocessing_num_workers: 16 17 | 18 | ### output 19 | output_dir: saves/llama3-8b/lora/sft 20 | logging_steps: 10 21 | save_steps: 500 22 | plot_loss: true 23 | overwrite_output_dir: true 24 | 25 | ### train 26 |
per_device_train_batch_size: 1 27 | gradient_accumulation_steps: 8 28 | learning_rate: 1.0e-4 29 | num_train_epochs: 3.0 30 | lr_scheduler_type: cosine 31 | warmup_ratio: 0.1 32 | bf16: true 33 | ddp_timeout: 180000000 34 | 35 | ### eval 36 | val_size: 0.1 37 | per_device_eval_batch_size: 1 38 | eval_strategy: steps 39 | eval_steps: 500 40 | -------------------------------------------------------------------------------- /examples/train_qlora/llama3_lora_sft_gptq.yaml: -------------------------------------------------------------------------------- 1 | ### model 2 | model_name_or_path: TechxGenus/Meta-Llama-3-8B-Instruct-GPTQ 3 | 4 | ### method 5 | stage: sft 6 | do_train: true 7 | finetuning_type: lora 8 | lora_target: all 9 | 10 | ### dataset 11 | dataset: identity,alpaca_en_demo 12 | template: llama3 13 | cutoff_len: 2048 14 | max_samples: 1000 15 | overwrite_cache: true 16 | preprocessing_num_workers: 16 17 | 18 | ### output 19 | output_dir: saves/llama3-8b/lora/sft 20 | logging_steps: 10 21 | save_steps: 500 22 | plot_loss: true 23 | overwrite_output_dir: true 24 | 25 | ### train 26 | per_device_train_batch_size: 1 27 | gradient_accumulation_steps: 8 28 | learning_rate: 1.0e-4 29 | num_train_epochs: 3.0 30 | lr_scheduler_type: cosine 31 | warmup_ratio: 0.1 32 | bf16: true 33 | ddp_timeout: 180000000 34 | 35 | ### eval 36 | val_size: 0.1 37 | per_device_eval_batch_size: 1 38 | eval_strategy: steps 39 | eval_steps: 500 40 | -------------------------------------------------------------------------------- /examples/train_qlora/llama3_lora_sft_otfq.yaml: -------------------------------------------------------------------------------- 1 | ### model 2 | model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct 3 | quantization_bit: 4 4 | quantization_method: bitsandbytes # choices: [bitsandbytes (4/8), hqq (2/3/4/5/6/8), eetq (8)] 5 | 6 | ### method 7 | stage: sft 8 | do_train: true 9 | finetuning_type: lora 10 | lora_target: all 11 | 12 | ### dataset 13 | dataset: identity,alpaca_en_demo 14 | template: llama3 15 | cutoff_len: 2048 16 | max_samples: 1000 17 | overwrite_cache: true 18 | preprocessing_num_workers: 16 19 | 20 | ### output 21 | output_dir: saves/llama3-8b/lora/sft 22 | logging_steps: 10 23 | save_steps: 500 24 | plot_loss: true 25 | overwrite_output_dir: true 26 | 27 | ### train 28 | per_device_train_batch_size: 1 29 | gradient_accumulation_steps: 8 30 | learning_rate: 1.0e-4 31 | num_train_epochs: 3.0 32 | lr_scheduler_type: cosine 33 | warmup_ratio: 0.1 34 | bf16: true 35 | ddp_timeout: 180000000 36 | 37 | ### eval 38 | val_size: 0.1 39 | per_device_eval_batch_size: 1 40 | eval_strategy: steps 41 | eval_steps: 500 42 | -------------------------------------------------------------------------------- /figure/01.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AMAP-ML/RealQA/d0b178194b02d7f9f96b428e10bccc5d117f74c0/figure/01.png -------------------------------------------------------------------------------- /figure/dataset.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AMAP-ML/RealQA/d0b178194b02d7f9f96b428e10bccc5d117f74c0/figure/dataset.pdf -------------------------------------------------------------------------------- /figure/dataset.png.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AMAP-ML/RealQA/d0b178194b02d7f9f96b428e10bccc5d117f74c0/figure/dataset.png.jpg 
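Note: the YAML files above are plain LLaMA-Factory-style training recipes. As a minimal, hedged sketch of how one would launch them, assuming this fork keeps upstream LLaMA-Factory's llamafactory-cli console script (an assumption; the entry point itself is not shown in this listing) and using the FORCE_TORCHRUN switch documented in src/llamafactory/__init__.py:
    # single-node LoRA/QLoRA fine-tuning from one of the example configs
    llamafactory-cli train examples/train_qlora/llama3_lora_sft_otfq.yaml
    # multi-GPU launch via torchrun
    FORCE_TORCHRUN=1 llamafactory-cli train examples/train_lora/qwen2vl_lora_sft.yaml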
-------------------------------------------------------------------------------- /figure/init: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /get_my_transformer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import site 4 | 5 | def get_python_executable(): 6 | """ 7 | Get the path of the current Python interpreter. 8 | """ 9 | return sys.executable 10 | 11 | def construct_modeling_qwen2_vl_path(): 12 | """ 13 | Construct the absolute path to transformers/models/qwen2_vl/modeling_qwen2_vl.py. 14 | """ 15 | python_path = get_python_executable() 16 | print(f"Python executable path: {python_path}") 17 | 18 | # Get the site-packages directory 19 | site_packages = site.getsitepackages() 20 | if not site_packages: 21 | return "Cannot find site-packages directory" 22 | 23 | # Assume using the first site-packages path 24 | site_packages_path = site_packages[0] 25 | 26 | # Construct the path to the target file 27 | modeling_path = os.path.join( 28 | site_packages_path, 29 | 'transformers', 30 | 'models', 31 | 'qwen2_vl', 32 | 'modeling_qwen2_vl.py' 33 | ) 34 | 35 | if os.path.exists(modeling_path): 36 | return os.path.abspath(modeling_path) 37 | else: 38 | return f"File does not exist: {modeling_path}" 39 | 40 | def main(): 41 | modeling_path = construct_modeling_qwen2_vl_path() 42 | print(f"Absolute path of modeling_qwen2_vl.py: {modeling_path}") 43 | 44 | if __name__ == "__main__": 45 | main() -------------------------------------------------------------------------------- /metric/aesthetic/eval_koniq10k.py: -------------------------------------------------------------------------------- 1 | import json 2 | from scipy.stats import pearsonr, spearmanr 3 | 4 | def read_scores(file_path): 5 | labels = [] 6 | predictions = [] 7 | 8 | with open(file_path, 'r') as file: 9 | for line in file: 10 | try: 11 | data = json.loads(line) 12 | label_score = float(data['label'].split('score of this image is ')[1].strip('.')) 13 | # import ipdb;ipdb.set_trace() 14 | if '<' in data['predict'] and '>' in data['predict']: 15 | start = data['predict'].index('<') 16 | end = data['predict'].index('>') 17 | predict_score = float(data['predict'][start+1: end]) 18 | else: 19 | predict_score = float(data['predict'].split('score of this image is ')[1].strip('.')) 20 | 21 | labels.append(label_score) 22 | predictions.append(predict_score) 23 | except Exception: 24 | print('Failed to parse line, skipping it') 25 | # import ipdb;ipdb.set_trace() 26 | 27 | return labels, predictions 28 | 29 | # Compute PLCC and SRCC 30 | def calculate_correlations(labels, predictions): 31 | plcc, _ = pearsonr(labels, predictions) 32 | srcc, _ = spearmanr(labels, predictions) 33 | return plcc, srcc 34 | 35 | def main(file_path): 36 | labels, predictions = read_scores(file_path) 37 | plcc, srcc = calculate_correlations(labels, predictions) 38 | print(f"PLCC: {plcc}") 39 | print(f"SRCC: {srcc}") 40 | 41 | if __name__ == '__main__': 42 | import sys 43 | result_path = sys.argv[1] 44 | main(result_path) 45 | 46 | -------------------------------------------------------------------------------- /scripts/api_example/test_image.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 the LlamaFactory team. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import os 16 | 17 | from openai import OpenAI 18 | from transformers.utils.versions import require_version 19 | 20 | 21 | require_version("openai>=1.5.0", "To fix: pip install openai>=1.5.0") 22 | 23 | 24 | def main(): 25 | client = OpenAI( 26 | api_key="{}".format(os.environ.get("API_KEY", "0")), 27 | base_url="http://localhost:{}/v1".format(os.environ.get("API_PORT", 8000)), 28 | ) 29 | messages = [] 30 | messages.append( 31 | { 32 | "role": "user", 33 | "content": [ 34 | {"type": "text", "text": "Output the color and number of each box."}, 35 | { 36 | "type": "image_url", 37 | "image_url": {"url": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2-VL/boxes.png"}, 38 | }, 39 | ], 40 | } 41 | ) 42 | result = client.chat.completions.create(messages=messages, model="test") 43 | messages.append(result.choices[0].message) 44 | print("Round 1:", result.choices[0].message.content) 45 | # The image shows a pyramid of colored blocks with numbers on them. Here are the colors and numbers of ... 46 | messages.append( 47 | { 48 | "role": "user", 49 | "content": [ 50 | {"type": "text", "text": "What kind of flower is this?"}, 51 | { 52 | "type": "image_url", 53 | "image_url": {"url": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2-VL/flowers.jpg"}, 54 | }, 55 | ], 56 | } 57 | ) 58 | result = client.chat.completions.create(messages=messages, model="test") 59 | messages.append(result.choices[0].message) 60 | print("Round 2:", result.choices[0].message.content) 61 | # The image shows a cluster of forget-me-not flowers. Forget-me-nots are small ... 62 | 63 | 64 | if __name__ == "__main__": 65 | main() 66 | -------------------------------------------------------------------------------- /scripts/api_example/test_toolcall.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 the LlamaFactory team. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
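# Module note (descriptive, added by the editor): this example exercises OpenAI-style function calling
# against the local API server (see src/api.py). The model is expected to return a calculate_gpa tool
# call, the client executes it locally, and the JSON result is sent back to the model as a "tool" message.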
14 | 15 | import json 16 | import os 17 | from typing import Sequence 18 | 19 | from openai import OpenAI 20 | from transformers.utils.versions import require_version 21 | 22 | 23 | require_version("openai>=1.5.0", "To fix: pip install openai>=1.5.0") 24 | 25 | 26 | def calculate_gpa(grades: Sequence[str], hours: Sequence[int]) -> float: 27 | grade_to_score = {"A": 4, "B": 3, "C": 2} 28 | total_score, total_hour = 0, 0 29 | for grade, hour in zip(grades, hours): 30 | total_score += grade_to_score[grade] * hour 31 | total_hour += hour 32 | return round(total_score / total_hour, 2) 33 | 34 | 35 | def main(): 36 | client = OpenAI( 37 | api_key="{}".format(os.environ.get("API_KEY", "0")), 38 | base_url="http://localhost:{}/v1".format(os.environ.get("API_PORT", 8000)), 39 | ) 40 | tools = [ 41 | { 42 | "type": "function", 43 | "function": { 44 | "name": "calculate_gpa", 45 | "description": "Calculate the Grade Point Average (GPA) based on grades and credit hours", 46 | "parameters": { 47 | "type": "object", 48 | "properties": { 49 | "grades": {"type": "array", "items": {"type": "string"}, "description": "The grades"}, 50 | "hours": {"type": "array", "items": {"type": "integer"}, "description": "The credit hours"}, 51 | }, 52 | "required": ["grades", "hours"], 53 | }, 54 | }, 55 | } 56 | ] 57 | tool_map = {"calculate_gpa": calculate_gpa} 58 | 59 | messages = [] 60 | messages.append({"role": "user", "content": "My grades are A, A, B, and C. The credit hours are 3, 4, 3, and 2."}) 61 | result = client.chat.completions.create(messages=messages, model="test", tools=tools) 62 | if result.choices[0].message.tool_calls is None: 63 | raise ValueError("Cannot retrieve function call from the response.") 64 | 65 | messages.append(result.choices[0].message) 66 | tool_call = result.choices[0].message.tool_calls[0].function 67 | print(tool_call) 68 | # Function(arguments='{"grades": ["A", "A", "B", "C"], "hours": [3, 4, 3, 2]}', name='calculate_gpa') 69 | name, arguments = tool_call.name, json.loads(tool_call.arguments) 70 | tool_result = tool_map[name](**arguments) 71 | messages.append({"role": "tool", "content": json.dumps({"gpa": tool_result}, ensure_ascii=False)}) 72 | result = client.chat.completions.create(messages=messages, model="test", tools=tools) 73 | print(result.choices[0].message.content) 74 | # Based on the grades and credit hours you provided, your Grade Point Average (GPA) is 3.42. 
75 | 76 | 77 | if __name__ == "__main__": 78 | main() 79 | -------------------------------------------------------------------------------- /scripts/call_qwen_max.py: -------------------------------------------------------------------------------- 1 | import time 2 | import json 3 | import os 4 | import dashscope 5 | import threading 6 | 7 | from typing import List, Callable 8 | from queue import Queue 9 | from http import HTTPStatus 10 | from dashscope import Generation 11 | from concurrent.futures import ThreadPoolExecutor, as_completed 12 | 13 | API_KEY = os.environ.get("DASHSCOPE_API_KEY", "")  # read the key from the environment instead of hard-coding it 14 | dashscope.api_key = API_KEY 15 | 16 | def call_qwen_single(prompt: str, model="qwen-max", max_try: int = 30): 17 | messages = [{'role': 'system', 'content': 'You are a helpful assistant.'}, 18 | {'role': 'user', 'content': prompt}] 19 | for _ in range(max_try): 20 | try: 21 | response = Generation.call(model=model, 22 | messages=messages, 23 | result_format='message') 24 | if response.status_code == HTTPStatus.OK: 25 | if response.output.choices[0]['message']['content'] is not None: 26 | return response.output.choices[0]['message']['content'] 27 | else: 28 | # print(response) 29 | raise RuntimeError("qwen call needs to be retried") 30 | else: 31 | # print(response) 32 | raise RuntimeError("qwen call needs to be retried") 33 | except Exception as e: 34 | print(f"qwen call failed, retrying: {str(e)}") 35 | # raise RuntimeError("qwen call did not succeed") 36 | time.sleep(3) 37 | 38 | # import pdb;pdb.set_trace() 39 | # The request was probably filtered by the content-safety policy 40 | return "Blocked by the content-safety filter" 41 | 42 | def execute_call_llm_concurrently( 43 | prompt: List[str], 44 | model="qwen-max", 45 | max_try: int = 5, 46 | # max_workers: int = 3 47 | ) -> List[str]: 48 | """Call qwen-max / GPT concurrently with multiple threads.""" 49 | call_llm_single = call_qwen_single if "qwen" in model else call_gpt_single # NOTE: call_gpt_single is not defined in this script; passing a non-qwen model raises NameError 50 | max_workers = len(prompt) 51 | with ThreadPoolExecutor(max_workers=max_workers) as executor: 52 | futures = [executor.submit(call_llm_single, prompt[i], model, max_try) 53 | for i in range(len(prompt))] 54 | return [future.result() for future in futures] 55 | -------------------------------------------------------------------------------- /scripts/download_kaggle.py: -------------------------------------------------------------------------------- 1 | import kagglehub 2 | 3 | # Download latest version 4 | path = kagglehub.dataset_download("generalhawking/koniq-10k-dataset") 5 | 6 | print("Path to dataset files:", path) -------------------------------------------------------------------------------- /scripts/loftq_init.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 HuggingFace Inc. and the LlamaFactory team. 2 | # 3 | # This code is based on the HuggingFace's PEFT library. 4 | # https://github.com/huggingface/peft/blob/v0.10.0/examples/loftq_finetuning/quantize_save_load.py 5 | # 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at 9 | # 10 | # http://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # Unless required by applicable law or agreed to in writing, software 13 | # distributed under the License is distributed on an "AS IS" BASIS, 14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | # See the License for the specific language governing permissions and 16 | # limitations under the License.
17 | 18 | import os 19 | from typing import TYPE_CHECKING 20 | 21 | import fire 22 | from peft import LoftQConfig, LoraConfig, TaskType, get_peft_model 23 | from transformers import AutoModelForCausalLM, AutoTokenizer 24 | 25 | 26 | if TYPE_CHECKING: 27 | from transformers import PreTrainedModel 28 | 29 | 30 | def quantize_loftq( 31 | model_name_or_path: str, 32 | output_dir: str, 33 | loftq_bits: int = 4, 34 | loftq_iter: int = 4, 35 | lora_alpha: int = None, 36 | lora_rank: int = 16, 37 | lora_dropout: float = 0, 38 | lora_target: tuple = ("q_proj", "v_proj"), 39 | save_safetensors: bool = True, 40 | ): 41 | r""" 42 | Initializes LoRA weights with LoRA-fine-tuning-aware Quantization (LoftQ) 43 | Usage: python loftq_init.py --model_name_or_path path_to_model --output_dir output_dir 44 | """ 45 | if isinstance(lora_target, str): 46 | lora_target = [name.strip() for name in lora_target.split(",")] 47 | 48 | tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, trust_remote_code=True) 49 | model = AutoModelForCausalLM.from_pretrained(model_name_or_path, trust_remote_code=True, torch_dtype="auto") 50 | 51 | loftq_config = LoftQConfig(loftq_bits=loftq_bits, loftq_iter=loftq_iter) 52 | lora_config = LoraConfig( 53 | task_type=TaskType.CAUSAL_LM, 54 | inference_mode=True, 55 | r=lora_rank, 56 | lora_alpha=lora_alpha if lora_alpha is not None else lora_rank * 2, 57 | lora_dropout=lora_dropout, 58 | target_modules=lora_target, 59 | init_lora_weights="loftq", 60 | loftq_config=loftq_config, 61 | ) 62 | 63 | # Init LoftQ model 64 | print("Initializing LoftQ weights, it may be take several minutes, wait patiently.") 65 | peft_model = get_peft_model(model, lora_config) 66 | loftq_dir = os.path.join(output_dir, "loftq_init") 67 | 68 | # Save LoftQ model 69 | setattr(peft_model.peft_config["default"], "base_model_name_or_path", os.path.abspath(output_dir)) 70 | setattr(peft_model.peft_config["default"], "init_lora_weights", True) # don't apply loftq again 71 | peft_model.save_pretrained(loftq_dir, safe_serialization=save_safetensors) 72 | print(f"Adapter weights saved in {loftq_dir}") 73 | 74 | # Save base model 75 | base_model: "PreTrainedModel" = peft_model.unload() 76 | base_model.save_pretrained(output_dir, safe_serialization=save_safetensors) 77 | tokenizer.save_pretrained(output_dir) 78 | print(f"Model weights saved in {output_dir}") 79 | 80 | print("- Fine-tune this model with:") 81 | print(f"model_name_or_path: {output_dir}") 82 | print(f"adapter_name_or_path: {loftq_dir}") 83 | print("finetuning_type: lora") 84 | print(f"quantization_bit: {loftq_bits}") 85 | 86 | 87 | if __name__ == "__main__": 88 | fire.Fire(quantize_loftq) 89 | -------------------------------------------------------------------------------- /scripts/pissa_init.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 HuggingFace Inc. and the LlamaFactory team. 2 | # 3 | # This code is based on the HuggingFace's PEFT library. 4 | # https://github.com/huggingface/peft/blob/v0.11.0/examples/pissa_finetuning/preprocess.py 5 | # 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 
8 | # You may obtain a copy of the License at 9 | # 10 | # http://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # Unless required by applicable law or agreed to in writing, software 13 | # distributed under the License is distributed on an "AS IS" BASIS, 14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | # See the License for the specific language governing permissions and 16 | # limitations under the License. 17 | 18 | import os 19 | from typing import TYPE_CHECKING 20 | 21 | import fire 22 | from peft import LoraConfig, TaskType, get_peft_model 23 | from transformers import AutoModelForCausalLM, AutoTokenizer 24 | 25 | 26 | if TYPE_CHECKING: 27 | from transformers import PreTrainedModel 28 | 29 | 30 | def quantize_pissa( 31 | model_name_or_path: str, 32 | output_dir: str, 33 | pissa_iter: int = 16, 34 | lora_alpha: int = None, 35 | lora_rank: int = 16, 36 | lora_dropout: float = 0, 37 | lora_target: tuple = ("q_proj", "v_proj"), 38 | save_safetensors: bool = True, 39 | ): 40 | r""" 41 | Initializes LoRA weights with Principal Singular values and Singular vectors Adaptation (PiSSA) 42 | Usage: python pissa_init.py --model_name_or_path path_to_model --output_dir output_dir 43 | """ 44 | if isinstance(lora_target, str): 45 | lora_target = [name.strip() for name in lora_target.split(",")] 46 | 47 | tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, trust_remote_code=True) 48 | model = AutoModelForCausalLM.from_pretrained(model_name_or_path, trust_remote_code=True, torch_dtype="auto") 49 | 50 | lora_config = LoraConfig( 51 | task_type=TaskType.CAUSAL_LM, 52 | r=lora_rank, 53 | lora_alpha=lora_alpha if lora_alpha is not None else lora_rank * 2, 54 | lora_dropout=lora_dropout, 55 | target_modules=lora_target, 56 | init_lora_weights="pissa" if pissa_iter == -1 else f"pissa_niter_{pissa_iter}", 57 | ) 58 | 59 | # Init PiSSA model 60 | peft_model = get_peft_model(model, lora_config) 61 | pissa_dir = os.path.join(output_dir, "pissa_init") 62 | 63 | # Save PiSSA model 64 | setattr(peft_model.peft_config["default"], "base_model_name_or_path", os.path.abspath(output_dir)) 65 | setattr(peft_model.peft_config["default"], "init_lora_weights", True) # don't apply pissa again 66 | peft_model.save_pretrained(pissa_dir, safe_serialization=save_safetensors) 67 | print(f"Adapter weights saved in {pissa_dir}") 68 | 69 | # Save base model 70 | base_model: "PreTrainedModel" = peft_model.unload() 71 | base_model.save_pretrained(output_dir, safe_serialization=save_safetensors) 72 | tokenizer.save_pretrained(output_dir) 73 | print(f"Model weights saved in {output_dir}") 74 | 75 | print("- Fine-tune this model with:") 76 | print(f"model_name_or_path: {output_dir}") 77 | print(f"adapter_name_or_path: {pissa_dir}") 78 | print("finetuning_type: lora") 79 | print("pissa_init: false") 80 | print("pissa_convert: true") 81 | print("- and optionally with:") 82 | print("quantization_bit: 4") 83 | 84 | 85 | if __name__ == "__main__": 86 | fire.Fire(quantize_pissa) 87 | -------------------------------------------------------------------------------- /scripts/stat_utils/cal_flops.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Microsoft Corporation and the LlamaFactory team. 2 | # 3 | # This code is inspired by the Microsoft's DeepSpeed library. 
4 | # https://www.deepspeed.ai/tutorials/flops-profiler/ 5 | # 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at 9 | # 10 | # http://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # Unless required by applicable law or agreed to in writing, software 13 | # distributed under the License is distributed on an "AS IS" BASIS, 14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | # See the License for the specific language governing permissions and 16 | # limitations under the License. 17 | 18 | import fire 19 | import torch 20 | from deepspeed.accelerator import get_accelerator # type: ignore 21 | from deepspeed.profiling.flops_profiler import get_model_profile # type: ignore 22 | 23 | from llamafactory.chat import ChatModel 24 | 25 | 26 | def calculate_flops( 27 | model_name_or_path: str, 28 | batch_size: int = 1, 29 | seq_length: int = 512, 30 | flash_attn: str = "auto", 31 | ): 32 | r""" 33 | Calculates the flops of pre-trained models. 34 | Usage: python cal_flops.py --model_name_or_path path_to_model --batch_size 1 --seq_length 512 35 | """ 36 | with get_accelerator().device(0): 37 | chat_model = ChatModel(dict(model_name_or_path=model_name_or_path, template="empty", flash_attn=flash_attn)) 38 | fake_input = torch.ones((batch_size, seq_length), dtype=torch.long, device=chat_model.engine.model.device) 39 | input_dict = {"input_ids": fake_input, "labels": fake_input.clone()} 40 | flops, macs, params = get_model_profile( 41 | chat_model.engine.model, kwargs=input_dict, print_profile=True, detailed=True 42 | ) 43 | print("FLOPs:", flops) 44 | print("MACs:", macs) 45 | print("Params:", params) 46 | 47 | 48 | if __name__ == "__main__": 49 | fire.Fire(calculate_flops) 50 | -------------------------------------------------------------------------------- /scripts/stat_utils/cal_lr.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 imoneoi and the LlamaFactory team. 2 | # 3 | # This code is inspired by the imoneoi's OpenChat library. 4 | # https://github.com/imoneoi/openchat/blob/3.6.0/ochat/training_deepspeed/train.py 5 | # 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at 9 | # 10 | # http://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # Unless required by applicable law or agreed to in writing, software 13 | # distributed under the License is distributed on an "AS IS" BASIS, 14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | # See the License for the specific language governing permissions and 16 | # limitations under the License. 
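# Module note (descriptive, added by the editor): the heuristic below scales the base learning rate with the
# square root of the effective token batch size, lr = BASE_LR * sqrt(token_batch_size / BASE_BS).
# Hedged worked example with illustrative numbers only: cutoff_len=1024, batch_size=16 and a valid-token
# ratio of about 0.60 give token_batch_size of roughly 1024 * 16 * 0.60, about 9830 tokens, so
# lr is about 3e-4 * sqrt(9830 / 4e6), about 1.5e-5 (further divided by 6 for Mistral/Gemma-style models).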
17 | 18 | import math 19 | from typing import Literal 20 | 21 | import fire 22 | import torch 23 | from torch.utils.data import DataLoader 24 | from tqdm import tqdm 25 | from transformers import DataCollatorForLanguageModeling 26 | 27 | from llamafactory.data import MultiModalDataCollatorForSeq2Seq, get_dataset, get_template_and_fix_tokenizer 28 | from llamafactory.extras.constants import IGNORE_INDEX 29 | from llamafactory.hparams import get_train_args 30 | from llamafactory.model import load_tokenizer 31 | 32 | 33 | BASE_LR = 3e-4 # 1.5e-4 for 30B-70B models 34 | BASE_BS = 4_000_000 # from llama paper 35 | 36 | 37 | def calculate_lr( 38 | model_name_or_path: str, 39 | batch_size: int, # total batch size, namely (batch size * gradient accumulation * world size) 40 | stage: Literal["pt", "sft"] = "sft", 41 | dataset: str = "alpaca_en_demo", 42 | dataset_dir: str = "data", 43 | template: str = "default", 44 | cutoff_len: int = 1024, # i.e. maximum input length during training 45 | is_mistral_or_gemma: bool = False, # mistral and gemma models opt for a smaller learning rate, 46 | packing: bool = False, 47 | ): 48 | r""" 49 | Calculates the optimal learning rate for 7B/13B models using LLaMA's hyper-parameters. 50 | Usage: 51 | python cal_lr.py --model_name_or_path path_to_model --dataset alpaca_en_demo --cutoff_len 1024 --batch_size 16 52 | """ 53 | model_args, data_args, training_args, _, _ = get_train_args( 54 | dict( 55 | stage=stage, 56 | model_name_or_path=model_name_or_path, 57 | dataset=dataset, 58 | dataset_dir=dataset_dir, 59 | template=template, 60 | cutoff_len=cutoff_len, 61 | packing=packing, 62 | output_dir="dummy_dir", 63 | overwrite_cache=True, 64 | do_train=True, 65 | ) 66 | ) 67 | tokenizer_module = load_tokenizer(model_args) 68 | tokenizer = tokenizer_module["tokenizer"] 69 | template = get_template_and_fix_tokenizer(tokenizer, data_args) 70 | trainset = get_dataset(template, model_args, data_args, training_args, stage, **tokenizer_module)["train_dataset"] 71 | if stage == "pt": 72 | data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False) 73 | elif stage == "sft": 74 | data_collator = MultiModalDataCollatorForSeq2Seq( 75 | template=template, tokenizer=tokenizer, label_pad_token_id=IGNORE_INDEX 76 | ) 77 | else: 78 | raise NotImplementedError(f"Stage does not supported: {stage}.") 79 | 80 | dataloader = DataLoader(trainset, batch_size, shuffle=False, collate_fn=data_collator, pin_memory=True) 81 | valid_tokens, total_tokens = 0, 0 82 | for batch in tqdm(dataloader): 83 | valid_tokens += torch.sum(batch["labels"] != IGNORE_INDEX).item() 84 | total_tokens += torch.numel(batch["labels"]) 85 | 86 | valid_ratio = valid_tokens / total_tokens 87 | token_batch_size = cutoff_len * batch_size * valid_ratio 88 | lr = BASE_LR * math.sqrt(token_batch_size / BASE_BS) # lr ~ sqrt(batch_size) 89 | lr = lr / 6.0 if is_mistral_or_gemma else lr 90 | print( 91 | "Optimal learning rate is {:.2e} for valid ratio% {:.2f} and effective token batch size {:.2f}".format( 92 | lr, valid_ratio * 100, token_batch_size 93 | ) 94 | ) 95 | 96 | 97 | if __name__ == "__main__": 98 | fire.Fire(calculate_lr) 99 | -------------------------------------------------------------------------------- /scripts/stat_utils/length_cdf.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 the LlamaFactory team. 
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from collections import defaultdict 16 | 17 | import fire 18 | from tqdm import tqdm 19 | 20 | from llamafactory.data import get_dataset, get_template_and_fix_tokenizer 21 | from llamafactory.hparams import get_train_args 22 | from llamafactory.model import load_tokenizer 23 | 24 | 25 | def length_cdf( 26 | model_name_or_path: str, 27 | dataset: str = "alpaca_en_demo", 28 | dataset_dir: str = "data", 29 | template: str = "default", 30 | interval: int = 1000, 31 | ): 32 | r""" 33 | Calculates the distribution of the input lengths in the dataset. 34 | Usage: export CUDA_VISIBLE_DEVICES=0 35 | python length_cdf.py --model_name_or_path path_to_model --dataset alpaca_en_demo --template default 36 | """ 37 | model_args, data_args, training_args, _, _ = get_train_args( 38 | dict( 39 | stage="sft", 40 | model_name_or_path=model_name_or_path, 41 | dataset=dataset, 42 | dataset_dir=dataset_dir, 43 | template=template, 44 | cutoff_len=1_000_000, 45 | output_dir="dummy_dir", 46 | overwrite_cache=True, 47 | do_train=True, 48 | ) 49 | ) 50 | tokenizer_module = load_tokenizer(model_args) 51 | template = get_template_and_fix_tokenizer(tokenizer_module["tokenizer"], data_args) 52 | trainset = get_dataset(template, model_args, data_args, training_args, "sft", **tokenizer_module)["train_dataset"] 53 | total_num = len(trainset) 54 | length_dict = defaultdict(int) 55 | for sample in tqdm(trainset["input_ids"]): 56 | length_dict[len(sample) // interval * interval] += 1 57 | 58 | length_tuples = list(length_dict.items()) 59 | length_tuples.sort() 60 | count_accu, prob_accu = 0, 0 61 | for length, count in length_tuples: 62 | count_accu += count 63 | prob_accu += count / total_num * 100 64 | print(f"{count_accu:d} ({prob_accu:.2f}%) samples have length < {length + interval}.") 65 | 66 | 67 | if __name__ == "__main__": 68 | fire.Fire(length_cdf) 69 | -------------------------------------------------------------------------------- /src/api.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 the LlamaFactory team. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
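# Usage sketch (hedged; assumes upstream LLaMA-Factory argument parsing, where ChatModel() reads a YAML
# config or CLI flags from sys.argv, which this fork is believed to keep but which is not shown here):
#   API_PORT=8000 python src/api.py examples/inference/qwen2_vl.yaml
# The server exposes an OpenAI-compatible API; see scripts/api_example/ for matching client calls.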
14 | 15 | import os 16 | 17 | import uvicorn 18 | 19 | from llamafactory.api.app import create_app 20 | from llamafactory.chat import ChatModel 21 | 22 | 23 | def main(): 24 | chat_model = ChatModel() 25 | app = create_app(chat_model) 26 | api_host = os.getenv("API_HOST", "0.0.0.0") 27 | api_port = int(os.getenv("API_PORT", "8000")) 28 | print(f"Visit http://localhost:{api_port}/docs for API document.") 29 | uvicorn.run(app, host=api_host, port=api_port) 30 | 31 | 32 | if __name__ == "__main__": 33 | main() 34 | -------------------------------------------------------------------------------- /src/llamafactory/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 the LlamaFactory team. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | r""" 16 | Efficient fine-tuning of large language models. 17 | 18 | Level: 19 | api, webui > chat, eval, train > data, model > hparams > extras 20 | 21 | Dependency graph: 22 | main: 23 | transformers>=4.41.2,<=4.46.1 24 | datasets>=2.16.0,<=3.1.0 25 | accelerate>=0.34.0,<=1.0.1 26 | peft>=0.11.1,<=0.12.0 27 | trl>=0.8.6,<=0.9.6 28 | attention: 29 | transformers>=4.42.4 (gemma+fa2) 30 | longlora: 31 | transformers>=4.41.2,<=4.46.1 32 | packing: 33 | transformers>=4.41.2,<=4.46.1 34 | 35 | Disable version checking: DISABLE_VERSION_CHECK=1 36 | Enable VRAM recording: RECORD_VRAM=1 37 | Force check imports: FORCE_CHECK_IMPORTS=1 38 | Force using torchrun: FORCE_TORCHRUN=1 39 | Set logging verbosity: LLAMAFACTORY_VERBOSITY=WARN 40 | Use modelscope: USE_MODELSCOPE_HUB=1 41 | Use openmind: USE_OPENMIND_HUB=1 42 | """ 43 | 44 | from .extras.env import VERSION 45 | 46 | 47 | __version__ = VERSION 48 | -------------------------------------------------------------------------------- /src/llamafactory/api/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AMAP-ML/RealQA/d0b178194b02d7f9f96b428e10bccc5d117f74c0/src/llamafactory/api/__init__.py -------------------------------------------------------------------------------- /src/llamafactory/api/common.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 the LlamaFactory team. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | import json 16 | from typing import TYPE_CHECKING, Any, Dict 17 | 18 | 19 | if TYPE_CHECKING: 20 | from pydantic import BaseModel 21 | 22 | 23 | def dictify(data: "BaseModel") -> Dict[str, Any]: 24 | try: # pydantic v2 25 | return data.model_dump(exclude_unset=True) 26 | except AttributeError: # pydantic v1 27 | return data.dict(exclude_unset=True) 28 | 29 | 30 | def jsonify(data: "BaseModel") -> str: 31 | try: # pydantic v2 32 | return json.dumps(data.model_dump(exclude_unset=True), ensure_ascii=False) 33 | except AttributeError: # pydantic v1 34 | return data.json(exclude_unset=True, ensure_ascii=False) 35 | -------------------------------------------------------------------------------- /src/llamafactory/chat/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 the LlamaFactory team. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from .base_engine import BaseEngine 16 | from .chat_model import ChatModel 17 | 18 | 19 | __all__ = ["BaseEngine", "ChatModel"] 20 | -------------------------------------------------------------------------------- /src/llamafactory/chat/base_engine.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 the LlamaFactory team. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from abc import ABC, abstractmethod 16 | from dataclasses import dataclass 17 | from typing import TYPE_CHECKING, Any, AsyncGenerator, Dict, List, Literal, Optional, Sequence, Union 18 | 19 | 20 | if TYPE_CHECKING: 21 | from transformers import PreTrainedModel, PreTrainedTokenizer 22 | from vllm import AsyncLLMEngine 23 | 24 | from ..data import Template 25 | from ..data.mm_plugin import ImageInput, VideoInput 26 | from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments 27 | 28 | 29 | @dataclass 30 | class Response: 31 | response_text: str 32 | response_length: int 33 | prompt_length: int 34 | finish_reason: Literal["stop", "length"] 35 | 36 | 37 | class BaseEngine(ABC): 38 | r""" 39 | Base class for inference engine of chat models. 40 | 41 | Must implements async methods: chat(), stream_chat() and get_scores(). 
42 | """ 43 | 44 | model: Union["PreTrainedModel", "AsyncLLMEngine"] 45 | tokenizer: "PreTrainedTokenizer" 46 | can_generate: bool 47 | template: "Template" 48 | generating_args: Dict[str, Any] 49 | 50 | @abstractmethod 51 | def __init__( 52 | self, 53 | model_args: "ModelArguments", 54 | data_args: "DataArguments", 55 | finetuning_args: "FinetuningArguments", 56 | generating_args: "GeneratingArguments", 57 | ) -> None: 58 | r""" 59 | Initializes an inference engine. 60 | """ 61 | ... 62 | 63 | @abstractmethod 64 | async def chat( 65 | self, 66 | messages: Sequence[Dict[str, str]], 67 | system: Optional[str] = None, 68 | tools: Optional[str] = None, 69 | images: Optional[Sequence["ImageInput"]] = None, 70 | videos: Optional[Sequence["VideoInput"]] = None, 71 | **input_kwargs, 72 | ) -> List["Response"]: 73 | r""" 74 | Gets a list of responses of the chat model. 75 | """ 76 | ... 77 | 78 | @abstractmethod 79 | async def stream_chat( 80 | self, 81 | messages: Sequence[Dict[str, str]], 82 | system: Optional[str] = None, 83 | tools: Optional[str] = None, 84 | images: Optional[Sequence["ImageInput"]] = None, 85 | videos: Optional[Sequence["VideoInput"]] = None, 86 | **input_kwargs, 87 | ) -> AsyncGenerator[str, None]: 88 | r""" 89 | Gets the response token-by-token of the chat model. 90 | """ 91 | ... 92 | 93 | @abstractmethod 94 | async def get_scores( 95 | self, 96 | batch_input: List[str], 97 | **input_kwargs, 98 | ) -> List[float]: 99 | r""" 100 | Gets a list of scores of the reward model. 101 | """ 102 | ... 103 | -------------------------------------------------------------------------------- /src/llamafactory/data/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 the LlamaFactory team. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from .collator import ( 16 | KTODataCollatorWithPadding, 17 | MultiModalDataCollatorForSeq2Seq, 18 | PairwiseDataCollatorWithPadding, 19 | SFTDataCollatorWith4DAttentionMask, 20 | ) 21 | from .data_utils import Role, split_dataset 22 | from .loader import get_dataset 23 | from .template import TEMPLATES, Template, get_template_and_fix_tokenizer 24 | 25 | 26 | __all__ = [ 27 | "KTODataCollatorWithPadding", 28 | "MultiModalDataCollatorForSeq2Seq", 29 | "PairwiseDataCollatorWithPadding", 30 | "SFTDataCollatorWith4DAttentionMask", 31 | "Role", 32 | "split_dataset", 33 | "get_dataset", 34 | "TEMPLATES", 35 | "Template", 36 | "get_template_and_fix_tokenizer", 37 | ] 38 | -------------------------------------------------------------------------------- /src/llamafactory/data/data_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 the LlamaFactory team. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from enum import Enum, unique 16 | from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Set, TypedDict, Union 17 | 18 | from datasets import DatasetDict, concatenate_datasets, interleave_datasets 19 | 20 | from ..extras import logging 21 | 22 | 23 | if TYPE_CHECKING: 24 | from datasets import Dataset, IterableDataset 25 | 26 | from ..hparams import DataArguments 27 | 28 | 29 | logger = logging.get_logger(__name__) 30 | 31 | 32 | SLOTS = Sequence[Union[str, Set[str], Dict[str, str]]] 33 | 34 | 35 | @unique 36 | class Role(str, Enum): 37 | USER = "user" 38 | ASSISTANT = "assistant" 39 | SYSTEM = "system" 40 | FUNCTION = "function" 41 | OBSERVATION = "observation" 42 | 43 | 44 | class DatasetModule(TypedDict): 45 | train_dataset: Optional[Union["Dataset", "IterableDataset"]] 46 | eval_dataset: Optional[Union["Dataset", "IterableDataset"]] 47 | 48 | 49 | def merge_dataset( 50 | all_datasets: List[Union["Dataset", "IterableDataset"]], data_args: "DataArguments", seed: int 51 | ) -> Union["Dataset", "IterableDataset"]: 52 | r""" 53 | Merges multiple datasets to a unified dataset. 54 | """ 55 | if len(all_datasets) == 1: 56 | return all_datasets[0] 57 | elif data_args.mix_strategy == "concat": 58 | if data_args.streaming: 59 | logger.warning_once("The samples between different datasets will not be mixed in streaming mode.") 60 | 61 | return concatenate_datasets(all_datasets) 62 | elif data_args.mix_strategy.startswith("interleave"): 63 | if not data_args.streaming: 64 | logger.warning_once("We recommend using `mix_strategy=concat` in non-streaming mode.") 65 | 66 | return interleave_datasets( 67 | datasets=all_datasets, 68 | probabilities=data_args.interleave_probs, 69 | seed=seed, 70 | stopping_strategy="first_exhausted" if data_args.mix_strategy.endswith("under") else "all_exhausted", 71 | ) 72 | else: 73 | raise ValueError(f"Unknown mixing strategy: {data_args.mix_strategy}.") 74 | 75 | 76 | def split_dataset( 77 | dataset: Union["Dataset", "IterableDataset"], data_args: "DataArguments", seed: int 78 | ) -> "DatasetDict": 79 | r""" 80 | Splits the dataset and returns a dataset dict containing train set and validation set. 81 | 82 | Supports both map dataset and iterable dataset. 
83 | """ 84 | if data_args.streaming: 85 | dataset = dataset.shuffle(buffer_size=data_args.buffer_size, seed=seed) 86 | val_set = dataset.take(int(data_args.val_size)) 87 | train_set = dataset.skip(int(data_args.val_size)) 88 | return DatasetDict({"train": train_set, "validation": val_set}) 89 | else: 90 | val_size = int(data_args.val_size) if data_args.val_size > 1 else data_args.val_size 91 | dataset = dataset.train_test_split(test_size=val_size, seed=seed) 92 | return DatasetDict({"train": dataset["train"], "validation": dataset["test"]}) 93 | -------------------------------------------------------------------------------- /src/llamafactory/data/processors/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AMAP-ML/RealQA/d0b178194b02d7f9f96b428e10bccc5d117f74c0/src/llamafactory/data/processors/__init__.py -------------------------------------------------------------------------------- /src/llamafactory/data/processors/pretrain.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 HuggingFace Inc. and the LlamaFactory team. 2 | # 3 | # This code is inspired by the HuggingFace's transformers library. 4 | # https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/language-modeling/run_clm.py 5 | # 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at 9 | # 10 | # http://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # Unless required by applicable law or agreed to in writing, software 13 | # distributed under the License is distributed on an "AS IS" BASIS, 14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | # See the License for the specific language governing permissions and 16 | # limitations under the License. 
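# Illustrative only: a toy run of split_dataset above, showing that val_size is
# read as a fraction when it is <= 1 and as an absolute sample count when it is
# larger. DataArguments is stubbed with a SimpleNamespace carrying just the
# fields the function touches.
from types import SimpleNamespace

from datasets import Dataset

from llamafactory.data import split_dataset

toy = Dataset.from_dict({"text": [f"sample {i}" for i in range(10)]})
fractional = SimpleNamespace(streaming=False, val_size=0.2, buffer_size=16384)
splits = split_dataset(toy, fractional, seed=42)
assert len(splits["train"]) == 8 and len(splits["validation"]) == 2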
17 | 18 | from itertools import chain 19 | from typing import TYPE_CHECKING, Any, Dict, List 20 | 21 | 22 | if TYPE_CHECKING: 23 | from transformers import PreTrainedTokenizer 24 | 25 | from ...hparams import DataArguments 26 | 27 | 28 | def preprocess_pretrain_dataset( 29 | examples: Dict[str, List[Any]], tokenizer: "PreTrainedTokenizer", data_args: "DataArguments" 30 | ) -> Dict[str, List[Any]]: 31 | # build grouped texts with format `X1 X2 X3 ...` if packing is enabled 32 | eos_token = "<|end_of_text|>" if data_args.template == "llama3" else tokenizer.eos_token 33 | text_examples = [messages[0]["content"] + eos_token for messages in examples["_prompt"]] 34 | 35 | if not data_args.packing: 36 | if data_args.template == "gemma": 37 | text_examples = [tokenizer.bos_token + example for example in text_examples] 38 | 39 | result = tokenizer(text_examples, add_special_tokens=False, truncation=True, max_length=data_args.cutoff_len) 40 | else: 41 | tokenized_examples = tokenizer(text_examples, add_special_tokens=False) 42 | concatenated_examples = {k: list(chain(*tokenized_examples[k])) for k in tokenized_examples.keys()} 43 | total_length = len(concatenated_examples[list(concatenated_examples.keys())[0]]) 44 | block_size = data_args.cutoff_len 45 | total_length = (total_length // block_size) * block_size 46 | result = { 47 | k: [t[i : i + block_size] for i in range(0, total_length, block_size)] 48 | for k, t in concatenated_examples.items() 49 | } 50 | if data_args.template == "gemma": 51 | for i in range(len(result["input_ids"])): 52 | result["input_ids"][i][0] = tokenizer.bos_token_id 53 | 54 | return result 55 | -------------------------------------------------------------------------------- /src/llamafactory/data/processors/processor_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 the LlamaFactory team. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import bisect 16 | from typing import List, Sequence, Tuple 17 | 18 | 19 | def search_for_fit(numbers: Sequence[int], capacity: int) -> int: 20 | r""" 21 | Finds the index of largest number that fits into the knapsack with the given capacity. 22 | """ 23 | index = bisect.bisect(numbers, capacity) 24 | return -1 if index == 0 else (index - 1) 25 | 26 | 27 | def greedy_knapsack(numbers: List[int], capacity: int) -> List[List[int]]: 28 | r""" 29 | An efficient greedy algorithm with binary search for the knapsack problem. 
30 | """ 31 | numbers.sort() # sort numbers in ascending order for binary search 32 | knapsacks = [] 33 | 34 | while numbers: 35 | current_knapsack = [] 36 | remaining_capacity = capacity 37 | 38 | while True: 39 | index = search_for_fit(numbers, remaining_capacity) 40 | if index == -1: 41 | break # no more numbers fit in this knapsack 42 | 43 | remaining_capacity -= numbers[index] # update the remaining capacity 44 | current_knapsack.append(numbers.pop(index)) # add the number to knapsack 45 | 46 | knapsacks.append(current_knapsack) 47 | 48 | return knapsacks 49 | 50 | 51 | def infer_seqlen(source_len: int, target_len: int, cutoff_len: int) -> Tuple[int, int]: 52 | r""" 53 | Computes the real sequence length after truncation by the cutoff_len. 54 | """ 55 | if target_len * 2 < cutoff_len: # truncate source 56 | max_target_len = cutoff_len 57 | elif source_len * 2 < cutoff_len: # truncate target 58 | max_target_len = cutoff_len - source_len 59 | else: # truncate both 60 | max_target_len = int(cutoff_len * (target_len / (source_len + target_len))) 61 | 62 | new_target_len = min(max_target_len, target_len) 63 | max_source_len = max(cutoff_len - new_target_len, 0) 64 | new_source_len = min(max_source_len, source_len) 65 | return new_source_len, new_target_len 66 | -------------------------------------------------------------------------------- /src/llamafactory/eval/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AMAP-ML/RealQA/d0b178194b02d7f9f96b428e10bccc5d117f74c0/src/llamafactory/eval/__init__.py -------------------------------------------------------------------------------- /src/llamafactory/eval/template.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 the LlamaFactory team. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from dataclasses import dataclass 16 | from typing import Dict, List, Sequence, Tuple 17 | 18 | from ..data import Role 19 | from ..extras.constants import CHOICES 20 | 21 | 22 | @dataclass 23 | class EvalTemplate: 24 | system: str 25 | choice: str 26 | answer: str 27 | 28 | def _parse_example(self, example: Dict[str, str]) -> Tuple[str, str]: 29 | r""" 30 | input: a dict with keys {"question", "A", "B", "C", "D", "answer"} 31 | output: a tuple of (prompt, response) 32 | """ 33 | candidates = [self.choice.format(choice=ch, content=example[ch]) for ch in CHOICES if ch in example] 34 | return "".join([example["question"]] + candidates + [self.answer]), example["answer"] 35 | 36 | def format_example( 37 | self, target_data: Dict[str, str], support_set: Sequence[Dict[str, str]], subject_name: str 38 | ) -> List[Dict[str, str]]: 39 | r""" 40 | Converts dataset examples to messages. 
41 | """ 42 | messages = [] 43 | for k in range(len(support_set)): 44 | prompt, response = self._parse_example(support_set[k]) 45 | messages.append({"role": Role.USER.value, "content": prompt}) 46 | messages.append({"role": Role.ASSISTANT.value, "content": response}) 47 | 48 | prompt, response = self._parse_example(target_data) 49 | messages.append({"role": Role.USER.value, "content": prompt}) 50 | messages.append({"role": Role.ASSISTANT.value, "content": response}) 51 | messages[0]["content"] = self.system.format(subject=subject_name) + messages[0]["content"] 52 | return messages 53 | 54 | 55 | eval_templates: Dict[str, "EvalTemplate"] = {} 56 | 57 | 58 | def _register_eval_template(name: str, system: str, choice: str, answer: str) -> None: 59 | eval_templates[name] = EvalTemplate(system=system, choice=choice, answer=answer) 60 | 61 | 62 | def get_eval_template(name: str) -> "EvalTemplate": 63 | eval_template = eval_templates.get(name, None) 64 | assert eval_template is not None, f"Template {name} does not exist." 65 | return eval_template 66 | 67 | 68 | _register_eval_template( 69 | name="en", 70 | system="The following are multiple choice questions (with answers) about {subject}.\n\n", 71 | choice="\n{choice}. {content}", 72 | answer="\nAnswer:", 73 | ) 74 | 75 | 76 | _register_eval_template( 77 | name="zh", 78 | system="以下是中国关于{subject}考试的单项选择题,请选出其中的正确答案。\n\n", 79 | choice="\n{choice}. {content}", 80 | answer="\n答案:", 81 | ) 82 | -------------------------------------------------------------------------------- /src/llamafactory/extras/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AMAP-ML/RealQA/d0b178194b02d7f9f96b428e10bccc5d117f74c0/src/llamafactory/extras/__init__.py -------------------------------------------------------------------------------- /src/llamafactory/extras/env.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 HuggingFace Inc. and the LlamaFactory team. 2 | # 3 | # This code is inspired by the HuggingFace's transformers library. 4 | # https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/commands/env.py 5 | # 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at 9 | # 10 | # http://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # Unless required by applicable law or agreed to in writing, software 13 | # distributed under the License is distributed on an "AS IS" BASIS, 14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | # See the License for the specific language governing permissions and 16 | # limitations under the License. 
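# Illustrative only: a small demonstration of the evaluation template defined
# above. The question is made up, and CHOICES is assumed to be the usual
# ["A", "B", "C", "D"] list from extras.constants.
from llamafactory.eval.template import get_eval_template

item = {
    "question": "Which planet is known as the Red Planet?",
    "A": "Venus", "B": "Mars", "C": "Jupiter", "D": "Mercury",
    "answer": "B",
}
messages = get_eval_template("en").format_example(item, support_set=[], subject_name="astronomy")
# messages[0] holds the zero-shot user prompt (system text + question + choices + "\nAnswer:")
# and messages[1] holds the expected assistant reply, here "B".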
17 | 18 | import platform 19 | 20 | import accelerate 21 | import datasets 22 | import peft 23 | import torch 24 | import transformers 25 | import trl 26 | from transformers.utils import is_torch_cuda_available, is_torch_npu_available 27 | 28 | 29 | VERSION = "0.9.2.dev0" 30 | 31 | 32 | def print_env() -> None: 33 | info = { 34 | "`llamafactory` version": VERSION, 35 | "Platform": platform.platform(), 36 | "Python version": platform.python_version(), 37 | "PyTorch version": torch.__version__, 38 | "Transformers version": transformers.__version__, 39 | "Datasets version": datasets.__version__, 40 | "Accelerate version": accelerate.__version__, 41 | "PEFT version": peft.__version__, 42 | "TRL version": trl.__version__, 43 | } 44 | 45 | if is_torch_cuda_available(): 46 | info["PyTorch version"] += " (GPU)" 47 | info["GPU type"] = torch.cuda.get_device_name() 48 | 49 | if is_torch_npu_available(): 50 | info["PyTorch version"] += " (NPU)" 51 | info["NPU type"] = torch.npu.get_device_name() 52 | info["CANN version"] = torch.version.cann 53 | 54 | try: 55 | import deepspeed # type: ignore 56 | 57 | info["DeepSpeed version"] = deepspeed.__version__ 58 | except Exception: 59 | pass 60 | 61 | try: 62 | import bitsandbytes 63 | 64 | info["Bitsandbytes version"] = bitsandbytes.__version__ 65 | except Exception: 66 | pass 67 | 68 | try: 69 | import vllm 70 | 71 | info["vLLM version"] = vllm.__version__ 72 | except Exception: 73 | pass 74 | 75 | print("\n" + "\n".join([f"- {key}: {value}" for key, value in info.items()]) + "\n") 76 | -------------------------------------------------------------------------------- /src/llamafactory/extras/packages.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 HuggingFace Inc. and the LlamaFactory team. 2 | # 3 | # This code is inspired by the HuggingFace's transformers library. 4 | # https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/utils/import_utils.py 5 | # 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at 9 | # 10 | # http://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # Unless required by applicable law or agreed to in writing, software 13 | # distributed under the License is distributed on an "AS IS" BASIS, 14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | # See the License for the specific language governing permissions and 16 | # limitations under the License. 
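# Illustrative only: print_env above can be called directly to dump the installed
# dependency versions, which is handy when reporting issues.
from llamafactory.extras.env import print_env

print_env()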
17 | 18 | import importlib.metadata 19 | import importlib.util 20 | from functools import lru_cache 21 | from typing import TYPE_CHECKING 22 | 23 | from packaging import version 24 | 25 | 26 | if TYPE_CHECKING: 27 | from packaging.version import Version 28 | 29 | 30 | def _is_package_available(name: str) -> bool: 31 | return importlib.util.find_spec(name) is not None 32 | 33 | 34 | def _get_package_version(name: str) -> "Version": 35 | try: 36 | return version.parse(importlib.metadata.version(name)) 37 | except Exception: 38 | return version.parse("0.0.0") 39 | 40 | 41 | def is_pyav_available(): 42 | return _is_package_available("av") 43 | 44 | 45 | def is_fastapi_available(): 46 | return _is_package_available("fastapi") 47 | 48 | 49 | def is_galore_available(): 50 | return _is_package_available("galore_torch") 51 | 52 | 53 | def is_gradio_available(): 54 | return _is_package_available("gradio") 55 | 56 | 57 | def is_matplotlib_available(): 58 | return _is_package_available("matplotlib") 59 | 60 | 61 | def is_pillow_available(): 62 | return _is_package_available("PIL") 63 | 64 | 65 | def is_requests_available(): 66 | return _is_package_available("requests") 67 | 68 | 69 | def is_rouge_available(): 70 | return _is_package_available("rouge_chinese") 71 | 72 | 73 | def is_starlette_available(): 74 | return _is_package_available("sse_starlette") 75 | 76 | 77 | @lru_cache 78 | def is_transformers_version_greater_than(content: str): 79 | return _get_package_version("transformers") >= version.parse(content) 80 | 81 | 82 | @lru_cache 83 | def is_transformers_version_equal_to_4_46(): 84 | return version.parse("4.46.0") <= _get_package_version("transformers") <= version.parse("4.46.1") 85 | 86 | 87 | def is_uvicorn_available(): 88 | return _is_package_available("uvicorn") 89 | 90 | 91 | def is_vllm_available(): 92 | return _is_package_available("vllm") 93 | -------------------------------------------------------------------------------- /src/llamafactory/extras/ploting.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 the LlamaFactory team. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import json 16 | import math 17 | import os 18 | from typing import Any, Dict, List 19 | 20 | from transformers.trainer import TRAINER_STATE_NAME 21 | 22 | from . import logging 23 | from .packages import is_matplotlib_available 24 | 25 | 26 | if is_matplotlib_available(): 27 | import matplotlib.figure 28 | import matplotlib.pyplot as plt 29 | 30 | 31 | logger = logging.get_logger(__name__) 32 | 33 | 34 | def smooth(scalars: List[float]) -> List[float]: 35 | r""" 36 | EMA implementation according to TensorBoard. 
37 | """ 38 | if len(scalars) == 0: 39 | return [] 40 | 41 | last = scalars[0] 42 | smoothed = [] 43 | weight = 1.8 * (1 / (1 + math.exp(-0.05 * len(scalars))) - 0.5) # a sigmoid function 44 | for next_val in scalars: 45 | smoothed_val = last * weight + (1 - weight) * next_val 46 | smoothed.append(smoothed_val) 47 | last = smoothed_val 48 | return smoothed 49 | 50 | 51 | def gen_loss_plot(trainer_log: List[Dict[str, Any]]) -> "matplotlib.figure.Figure": 52 | r""" 53 | Plots loss curves in LlamaBoard. 54 | """ 55 | plt.close("all") 56 | plt.switch_backend("agg") 57 | fig = plt.figure() 58 | ax = fig.add_subplot(111) 59 | steps, losses = [], [] 60 | for log in trainer_log: 61 | if log.get("loss", None): 62 | steps.append(log["current_steps"]) 63 | losses.append(log["loss"]) 64 | 65 | ax.plot(steps, losses, color="#1f77b4", alpha=0.4, label="original") 66 | ax.plot(steps, smooth(losses), color="#1f77b4", label="smoothed") 67 | ax.legend() 68 | ax.set_xlabel("step") 69 | ax.set_ylabel("loss") 70 | return fig 71 | 72 | 73 | def plot_loss(save_dictionary: str, keys: List[str] = ["loss"]) -> None: 74 | r""" 75 | Plots loss curves and saves the image. 76 | """ 77 | plt.switch_backend("agg") 78 | with open(os.path.join(save_dictionary, TRAINER_STATE_NAME), encoding="utf-8") as f: 79 | data = json.load(f) 80 | 81 | for key in keys: 82 | steps, metrics = [], [] 83 | for i in range(len(data["log_history"])): 84 | if key in data["log_history"][i]: 85 | steps.append(data["log_history"][i]["step"]) 86 | metrics.append(data["log_history"][i][key]) 87 | 88 | if len(metrics) == 0: 89 | logger.warning_rank0(f"No metric {key} to plot.") 90 | continue 91 | 92 | plt.figure() 93 | plt.plot(steps, metrics, color="#1f77b4", alpha=0.4, label="original") 94 | plt.plot(steps, smooth(metrics), color="#1f77b4", label="smoothed") 95 | plt.title(f"training {key} of {save_dictionary}") 96 | plt.xlabel("step") 97 | plt.ylabel(key) 98 | plt.legend() 99 | figure_path = os.path.join(save_dictionary, "training_{}.png".format(key.replace("/", "_"))) 100 | plt.savefig(figure_path, format="png", dpi=100) 101 | print("Figure saved at:", figure_path) 102 | -------------------------------------------------------------------------------- /src/llamafactory/hparams/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 the LlamaFactory team. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
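# Illustrative only: a sketch of turning a finished run's trainer_state.json into
# loss-curve images with plot_loss above. The output directory is a placeholder;
# eval_loss is only plotted if evaluation was enabled during training.
from llamafactory.extras.ploting import plot_loss

plot_loss("saves/qwen2_vl-7b/lora/sft", keys=["loss", "eval_loss"])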
14 | 15 | from .data_args import DataArguments 16 | from .evaluation_args import EvaluationArguments 17 | from .finetuning_args import FinetuningArguments 18 | from .generating_args import GeneratingArguments 19 | from .model_args import ModelArguments 20 | from .parser import get_eval_args, get_infer_args, get_train_args 21 | 22 | 23 | __all__ = [ 24 | "DataArguments", 25 | "EvaluationArguments", 26 | "FinetuningArguments", 27 | "GeneratingArguments", 28 | "ModelArguments", 29 | "get_eval_args", 30 | "get_infer_args", 31 | "get_train_args", 32 | ] 33 | -------------------------------------------------------------------------------- /src/llamafactory/hparams/evaluation_args.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 the LlamaFactory team. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import os 16 | from dataclasses import dataclass, field 17 | from typing import Literal, Optional 18 | 19 | from datasets import DownloadMode 20 | 21 | 22 | @dataclass 23 | class EvaluationArguments: 24 | r""" 25 | Arguments pertaining to specify the evaluation parameters. 26 | """ 27 | 28 | task: str = field( 29 | metadata={"help": "Name of the evaluation task."}, 30 | ) 31 | task_dir: str = field( 32 | default="evaluation", 33 | metadata={"help": "Path to the folder containing the evaluation datasets."}, 34 | ) 35 | batch_size: int = field( 36 | default=4, 37 | metadata={"help": "The batch size per GPU for evaluation."}, 38 | ) 39 | seed: int = field( 40 | default=42, 41 | metadata={"help": "Random seed to be used with data loaders."}, 42 | ) 43 | lang: Literal["en", "zh"] = field( 44 | default="en", 45 | metadata={"help": "Language used at evaluation."}, 46 | ) 47 | n_shot: int = field( 48 | default=5, 49 | metadata={"help": "Number of examplars for few-shot learning."}, 50 | ) 51 | save_dir: Optional[str] = field( 52 | default=None, 53 | metadata={"help": "Path to save the evaluation results."}, 54 | ) 55 | download_mode: DownloadMode = field( 56 | default=DownloadMode.REUSE_DATASET_IF_EXISTS, 57 | metadata={"help": "Download mode used for the evaluation datasets."}, 58 | ) 59 | 60 | def __post_init__(self): 61 | if self.save_dir is not None and os.path.exists(self.save_dir): 62 | raise ValueError("`save_dir` already exists, use another one.") 63 | -------------------------------------------------------------------------------- /src/llamafactory/hparams/generating_args.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 the LlamaFactory team. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from dataclasses import asdict, dataclass, field 16 | from typing import Any, Dict, Optional 17 | 18 | 19 | @dataclass 20 | class GeneratingArguments: 21 | r""" 22 | Arguments pertaining to specify the decoding parameters. 23 | """ 24 | 25 | do_sample: bool = field( 26 | default=True, 27 | metadata={"help": "Whether or not to use sampling, use greedy decoding otherwise."}, 28 | ) 29 | temperature: float = field( 30 | default=0.95, 31 | metadata={"help": "The value used to modulate the next token probabilities."}, 32 | ) 33 | top_p: float = field( 34 | default=0.7, 35 | metadata={ 36 | "help": "The smallest set of most probable tokens with probabilities that add up to top_p or higher are kept." 37 | }, 38 | ) 39 | top_k: int = field( 40 | default=50, 41 | metadata={"help": "The number of highest probability vocabulary tokens to keep for top-k filtering."}, 42 | ) 43 | num_beams: int = field( 44 | default=1, 45 | metadata={"help": "Number of beams for beam search. 1 means no beam search."}, 46 | ) 47 | max_length: int = field( 48 | default=1024, 49 | metadata={"help": "The maximum length the generated tokens can have. It can be overridden by max_new_tokens."}, 50 | ) 51 | max_new_tokens: int = field( 52 | default=1024, 53 | metadata={"help": "The maximum numbers of tokens to generate, ignoring the number of tokens in the prompt."}, 54 | ) 55 | repetition_penalty: float = field( 56 | default=1.0, 57 | metadata={"help": "The parameter for repetition penalty. 1.0 means no penalty."}, 58 | ) 59 | length_penalty: float = field( 60 | default=1.0, 61 | metadata={"help": "Exponential penalty to the length that is used with beam-based generation."}, 62 | ) 63 | default_system: Optional[str] = field( 64 | default=None, 65 | metadata={"help": "Default system message to use in chat completion."}, 66 | ) 67 | 68 | def to_dict(self) -> Dict[str, Any]: 69 | args = asdict(self) 70 | if args.get("max_new_tokens", -1) > 0: 71 | args.pop("max_length", None) 72 | else: 73 | args.pop("max_new_tokens", None) 74 | return args 75 | -------------------------------------------------------------------------------- /src/llamafactory/launcher.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 the LlamaFactory team. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
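# Illustrative only: to_dict above makes max_new_tokens and max_length mutually
# exclusive, dropping whichever one the generate() call should not see.
from llamafactory.hparams import GeneratingArguments

gen_args = GeneratingArguments(max_new_tokens=256)
decoded = gen_args.to_dict()
assert "max_length" not in decoded and decoded["max_new_tokens"] == 256

greedy = GeneratingArguments(do_sample=False, max_new_tokens=0)
assert "max_new_tokens" not in greedy.to_dict()  # falls back to max_length instead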
14 | 15 | from llamafactory.train.tuner import run_exp # use absolute import 16 | 17 | 18 | def launch(): 19 | run_exp() 20 | 21 | 22 | if __name__ == "__main__": 23 | launch() 24 | -------------------------------------------------------------------------------- /src/llamafactory/model/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 the LlamaFactory team. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from .loader import load_config, load_model, load_tokenizer 16 | from .model_utils.misc import find_all_linear_modules 17 | from .model_utils.quantization import QuantizationMethod 18 | from .model_utils.valuehead import load_valuehead_params 19 | 20 | 21 | __all__ = [ 22 | "QuantizationMethod", 23 | "load_config", 24 | "load_model", 25 | "load_tokenizer", 26 | "find_all_linear_modules", 27 | "load_valuehead_params", 28 | ] 29 | -------------------------------------------------------------------------------- /src/llamafactory/model/model_utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AMAP-ML/RealQA/d0b178194b02d7f9f96b428e10bccc5d117f74c0/src/llamafactory/model/model_utils/__init__.py -------------------------------------------------------------------------------- /src/llamafactory/model/model_utils/attention.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 the LlamaFactory team. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
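# Illustrative only: a sketch of driving the loader helpers re-exported above.
# The checkpoint name is a placeholder, and load_model's keyword signature is
# assumed from how the loaders are used elsewhere in the training code.
from llamafactory.hparams import FinetuningArguments, ModelArguments
from llamafactory.model import load_model, load_tokenizer

model_args = ModelArguments(model_name_or_path="Qwen/Qwen2-VL-2B-Instruct")  # placeholder
tokenizer = load_tokenizer(model_args)["tokenizer"]
model = load_model(tokenizer, model_args, FinetuningArguments(), is_trainable=False)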
14 | 15 | from typing import TYPE_CHECKING 16 | 17 | from transformers.utils import is_flash_attn_2_available, is_torch_sdpa_available 18 | from transformers.utils.versions import require_version 19 | 20 | from ...extras import logging 21 | 22 | 23 | if TYPE_CHECKING: 24 | from transformers import PretrainedConfig 25 | 26 | from ...hparams import ModelArguments 27 | 28 | 29 | logger = logging.get_logger(__name__) 30 | 31 | 32 | def configure_attn_implementation( 33 | config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool 34 | ) -> None: 35 | if getattr(config, "model_type", None) == "gemma2" and is_trainable: 36 | if model_args.flash_attn == "auto" or model_args.flash_attn == "fa2": 37 | if is_flash_attn_2_available(): 38 | require_version("transformers>=4.42.4", "To fix: pip install transformers>=4.42.4") 39 | require_version("flash_attn>=2.6.3", "To fix: pip install flash_attn>=2.6.3") 40 | if model_args.flash_attn != "fa2": 41 | logger.warning_rank0("Gemma-2 should use flash attention 2, change `flash_attn` to fa2.") 42 | model_args.flash_attn = "fa2" 43 | else: 44 | logger.warning_rank0("FlashAttention-2 is not installed, use eager attention.") 45 | model_args.flash_attn = "disabled" 46 | elif model_args.flash_attn == "sdpa": 47 | logger.warning_rank0( 48 | "Gemma-2 should use soft-capping attention, while the SDPA attention does not support it." 49 | ) 50 | 51 | if model_args.flash_attn == "auto": 52 | return 53 | 54 | elif model_args.flash_attn == "disabled": 55 | requested_attn_implementation = "eager" 56 | 57 | elif model_args.flash_attn == "sdpa": 58 | if not is_torch_sdpa_available(): 59 | logger.warning_rank0("torch>=2.1.1 is required for SDPA attention.") 60 | return 61 | 62 | requested_attn_implementation = "sdpa" 63 | elif model_args.flash_attn == "fa2": 64 | if not is_flash_attn_2_available(): 65 | logger.warning_rank0("FlashAttention-2 is not installed.") 66 | return 67 | 68 | requested_attn_implementation = "flash_attention_2" 69 | else: 70 | raise NotImplementedError(f"Unknown attention type: {model_args.flash_attn}") 71 | 72 | if getattr(config, "model_type", None) == "internlm2": # special case for custom models 73 | setattr(config, "attn_implementation", requested_attn_implementation) 74 | else: 75 | setattr(config, "_attn_implementation", requested_attn_implementation) 76 | 77 | 78 | def print_attn_implementation(config: "PretrainedConfig") -> None: 79 | if getattr(config, "model_type", None) == "internlm2": # special case for custom models 80 | attn_implementation = getattr(config, "attn_implementation", None) 81 | else: 82 | attn_implementation = getattr(config, "_attn_implementation", None) 83 | 84 | if attn_implementation == "flash_attention_2": 85 | logger.info_rank0("Using FlashAttention-2 for faster training and inference.") 86 | elif attn_implementation == "sdpa": 87 | logger.info_rank0("Using torch SDPA for faster training and inference.") 88 | else: 89 | logger.info_rank0("Using vanilla attention implementation.") 90 | -------------------------------------------------------------------------------- /src/llamafactory/model/model_utils/embedding.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 the LlamaFactory team. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import math 16 | from contextlib import nullcontext 17 | from typing import TYPE_CHECKING 18 | 19 | import torch 20 | from transformers.integrations import is_deepspeed_zero3_enabled 21 | 22 | from ...extras import logging 23 | 24 | 25 | if TYPE_CHECKING: 26 | from transformers import PreTrainedModel, PreTrainedTokenizer 27 | 28 | 29 | logger = logging.get_logger(__name__) 30 | 31 | 32 | def _noisy_mean_initialization(embed_weight: "torch.Tensor", num_new_tokens: int) -> None: 33 | embedding_dim = embed_weight.size(1) 34 | avg_weight = embed_weight[:-num_new_tokens].mean(dim=0, keepdim=True) 35 | noise_weight = torch.empty_like(embed_weight[-num_new_tokens:]) 36 | noise_weight.normal_(mean=0, std=(1.0 / math.sqrt(embedding_dim))) 37 | embed_weight[-num_new_tokens:] = avg_weight + noise_weight 38 | 39 | 40 | def resize_embedding_layer(model: "PreTrainedModel", tokenizer: "PreTrainedTokenizer") -> None: 41 | r""" 42 | Resize token embeddings. 43 | """ 44 | if is_deepspeed_zero3_enabled(): 45 | import deepspeed # type: ignore 46 | 47 | params = [model.get_input_embeddings().weight] 48 | if model.get_output_embeddings() is not None and not model.config.tie_word_embeddings: 49 | params.append(model.get_output_embeddings().weight) 50 | 51 | context_maybe_zero3 = deepspeed.zero.GatheredParameters(params, modifier_rank=0) 52 | else: 53 | context_maybe_zero3 = nullcontext() 54 | 55 | with context_maybe_zero3: 56 | current_embedding_size = model.get_input_embeddings().weight.size(0) 57 | 58 | if len(tokenizer) > current_embedding_size: 59 | if getattr(model, "quantization_method", None): 60 | raise ValueError("Cannot resize embedding layers of a quantized model.") 61 | 62 | if not isinstance(model.get_output_embeddings(), torch.nn.Linear): 63 | raise ValueError("Current model does not support resizing embedding layers.") 64 | 65 | model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=64) 66 | with context_maybe_zero3: 67 | new_embedding_size = model.get_input_embeddings().weight.size(0) 68 | num_new_tokens = new_embedding_size - current_embedding_size 69 | _noisy_mean_initialization(model.get_input_embeddings().weight.data, num_new_tokens) 70 | _noisy_mean_initialization(model.get_output_embeddings().weight.data, num_new_tokens) 71 | 72 | logger.info_rank0(f"Resized token embeddings from {current_embedding_size} to {new_embedding_size}.") 73 | -------------------------------------------------------------------------------- /src/llamafactory/model/model_utils/liger_kernel.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 the LlamaFactory team. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import inspect 16 | from typing import TYPE_CHECKING 17 | 18 | from ...extras import logging 19 | 20 | 21 | if TYPE_CHECKING: 22 | from transformers import PretrainedConfig 23 | 24 | from ...hparams import ModelArguments 25 | 26 | 27 | logger = logging.get_logger(__name__) 28 | 29 | 30 | def apply_liger_kernel( 31 | config: "PretrainedConfig", 32 | model_args: "ModelArguments", 33 | is_trainable: bool, 34 | require_logits: bool, 35 | ) -> None: 36 | if not is_trainable or not model_args.enable_liger_kernel: 37 | return 38 | 39 | model_type = getattr(config, "model_type", None) 40 | if model_type == "gemma": 41 | from liger_kernel.transformers import apply_liger_kernel_to_gemma as apply_liger_kernel 42 | elif model_type == "gemma2": 43 | from liger_kernel.transformers import apply_liger_kernel_to_gemma2 as apply_liger_kernel 44 | elif model_type == "llama": 45 | from liger_kernel.transformers import apply_liger_kernel_to_llama as apply_liger_kernel 46 | elif model_type == "mistral": 47 | from liger_kernel.transformers import apply_liger_kernel_to_mistral as apply_liger_kernel 48 | elif model_type == "mixtral": 49 | from liger_kernel.transformers import apply_liger_kernel_to_mixtral as apply_liger_kernel 50 | elif model_type == "phi3": 51 | from liger_kernel.transformers import apply_liger_kernel_to_phi3 as apply_liger_kernel 52 | elif model_type == "qwen2": 53 | from liger_kernel.transformers import apply_liger_kernel_to_qwen2 as apply_liger_kernel 54 | elif model_type == "qwen2_vl": 55 | from liger_kernel.transformers import apply_liger_kernel_to_qwen2_vl as apply_liger_kernel 56 | else: 57 | logger.warning_rank0("Current model does not support liger kernel.") 58 | return 59 | 60 | if require_logits and "fused_linear_cross_entropy" in inspect.signature(apply_liger_kernel).parameters: 61 | logger.info_rank0("Current training stage does not support chunked cross entropy.") 62 | kwargs = {"fused_linear_cross_entropy": False} 63 | else: 64 | kwargs = {} 65 | 66 | apply_liger_kernel(**kwargs) 67 | logger.info_rank0("Liger kernel has been applied to the model.") 68 | -------------------------------------------------------------------------------- /src/llamafactory/model/model_utils/misc.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 the LlamaFactory team. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
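# Illustrative only: apply_liger_kernel above is driven by the enable_liger_kernel
# flag on ModelArguments and only acts on trainable runs. The stand-in objects
# below are ad hoc, and the call is left commented out because it requires the
# optional liger_kernel package.
from types import SimpleNamespace

config = SimpleNamespace(model_type="qwen2")            # pretend Qwen2 config
model_args = SimpleNamespace(enable_liger_kernel=True)  # stand-in for ModelArguments
# apply_liger_kernel(config, model_args, is_trainable=True, require_logits=False)
# would dispatch to liger_kernel.transformers.apply_liger_kernel_to_qwen2 here.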
14 | 15 | from typing import TYPE_CHECKING, List 16 | 17 | from ...extras import logging 18 | 19 | 20 | if TYPE_CHECKING: 21 | from transformers import PretrainedConfig, PreTrainedModel, PreTrainedTokenizer 22 | 23 | 24 | logger = logging.get_logger(__name__) 25 | 26 | 27 | def find_all_linear_modules(model: "PreTrainedModel", freeze_vision_tower: bool) -> List[str]: 28 | r""" 29 | Finds all available modules to apply lora or galore. 30 | """ 31 | model_type = getattr(model.config, "model_type", None) 32 | forbidden_modules = {"lm_head"} 33 | if model_type == "chatglm": 34 | forbidden_modules.add("output_layer") 35 | elif model_type == "internlm2": 36 | forbidden_modules.add("output") 37 | elif model_type in ["llava", "llava_next", "llava_next_video", "mllama", "paligemma", "video_llava"]: 38 | forbidden_modules.add("multi_modal_projector") 39 | elif model_type == "qwen2_vl": 40 | forbidden_modules.add("merger") 41 | 42 | if freeze_vision_tower: 43 | if model_type == "mllama": 44 | forbidden_modules.add("vision_model") 45 | elif model_type == "qwen2_vl": 46 | forbidden_modules.add("visual") 47 | else: 48 | forbidden_modules.add("vision_tower") 49 | 50 | module_names = set() 51 | for name, module in model.named_modules(): 52 | if any(forbidden_module in name for forbidden_module in forbidden_modules): 53 | continue 54 | 55 | if "Linear" in module.__class__.__name__ and "Embedding" not in module.__class__.__name__: 56 | module_names.add(name.split(".")[-1]) 57 | 58 | logger.info_rank0("Found linear modules: {}".format(",".join(module_names))) 59 | return list(module_names) 60 | 61 | 62 | def find_expanded_modules(model: "PreTrainedModel", target_modules: List[str], num_layer_trainable: int) -> List[str]: 63 | r""" 64 | Finds the modules in the expanded blocks to apply lora. 65 | """ 66 | num_layers = getattr(model.config, "num_hidden_layers", None) 67 | if not num_layers: 68 | raise ValueError("Model was not supported.") 69 | 70 | if num_layers % num_layer_trainable != 0: 71 | raise ValueError( 72 | f"`num_layers` {num_layers} should be divisible by `num_layer_trainable` {num_layer_trainable}." 73 | ) 74 | 75 | stride = num_layers // num_layer_trainable 76 | trainable_layer_ids = range(stride - 1, num_layers + stride - 1, stride) 77 | trainable_layers = [f".{idx:d}." for idx in trainable_layer_ids] 78 | module_names = [] 79 | for name, _ in model.named_modules(): 80 | if any(target_module in name for target_module in target_modules) and any( 81 | trainable_layer in name for trainable_layer in trainable_layers 82 | ): 83 | module_names.append(name) 84 | 85 | logger.info_rank0("Apply lora to layers: {}".format(",".join(map(str, trainable_layer_ids)))) 86 | return module_names 87 | 88 | 89 | def register_autoclass(config: "PretrainedConfig", model: "PreTrainedModel", tokenizer: "PreTrainedTokenizer"): 90 | if "AutoConfig" in getattr(config, "auto_map", {}): 91 | config.__class__.register_for_auto_class() 92 | if "AutoModelForCausalLM" in getattr(config, "auto_map", {}): 93 | model.__class__.register_for_auto_class() 94 | if "AutoTokenizer" in tokenizer.init_kwargs.get("auto_map", {}): 95 | tokenizer.__class__.register_for_auto_class() 96 | -------------------------------------------------------------------------------- /src/llamafactory/model/model_utils/mod.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 the LlamaFactory team. 
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from typing import TYPE_CHECKING 16 | 17 | from ...extras.constants import MOD_SUPPORTED_MODELS 18 | 19 | 20 | if TYPE_CHECKING: 21 | from transformers import PretrainedConfig, PreTrainedModel 22 | 23 | from ...hparams import ModelArguments 24 | 25 | 26 | def load_mod_pretrained_model(**init_kwargs) -> "PreTrainedModel": 27 | from MoD import AutoMoDModelForCausalLM 28 | 29 | return AutoMoDModelForCausalLM.from_pretrained(**init_kwargs) 30 | 31 | 32 | def convert_pretrained_model_to_mod( 33 | model: "PreTrainedModel", config: "PretrainedConfig", model_args: "ModelArguments" 34 | ) -> "PreTrainedModel": 35 | from MoD import apply_mod_to_hf 36 | 37 | if getattr(config, "model_type", None) not in MOD_SUPPORTED_MODELS: 38 | raise ValueError("Current model is not supported by mixture-of-depth.") 39 | 40 | model = apply_mod_to_hf(model) 41 | model = model.to(model_args.compute_dtype) 42 | return model 43 | -------------------------------------------------------------------------------- /src/llamafactory/model/model_utils/moe.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 the LlamaFactory team. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from typing import TYPE_CHECKING, Sequence 16 | 17 | import torch 18 | from transformers.integrations import is_deepspeed_zero3_enabled 19 | from transformers.utils.versions import require_version 20 | 21 | 22 | if TYPE_CHECKING: 23 | from transformers import PretrainedConfig, PreTrainedModel 24 | 25 | from ...hparams import ModelArguments 26 | 27 | 28 | def _set_z3_leaf_modules(model: "PreTrainedModel", leaf_modules: Sequence["torch.nn.Module"]) -> None: 29 | require_version("deepspeed>=0.13.0", "To fix: pip install deepspeed>=0.13.0") 30 | from deepspeed.utils import set_z3_leaf_modules # type: ignore 31 | 32 | set_z3_leaf_modules(model, leaf_modules) 33 | 34 | 35 | def add_z3_leaf_module(model: "PreTrainedModel") -> None: 36 | r""" 37 | Sets module as a leaf module to skip partitioning in deepspeed zero3. 
38 | """ 39 | if not is_deepspeed_zero3_enabled(): 40 | return 41 | 42 | model_type = getattr(model.config, "model_type", None) 43 | if model_type == "dbrx": 44 | from transformers.models.dbrx.modeling_dbrx import DbrxFFN 45 | 46 | _set_z3_leaf_modules(model, [DbrxFFN]) 47 | 48 | if model_type == "jamba": 49 | from transformers.models.jamba.modeling_jamba import JambaSparseMoeBlock 50 | 51 | _set_z3_leaf_modules(model, [JambaSparseMoeBlock]) 52 | 53 | if model_type == "jetmoe": 54 | from transformers.models.jetmoe.modeling_jetmoe import JetMoeMoA, JetMoeMoE 55 | 56 | _set_z3_leaf_modules(model, [JetMoeMoA, JetMoeMoE]) 57 | 58 | if model_type == "mixtral": 59 | from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock 60 | 61 | _set_z3_leaf_modules(model, [MixtralSparseMoeBlock]) 62 | 63 | if model_type == "qwen2moe": 64 | from transformers.models.qwen2_moe.modeling_qwen2_moe import Qwen2MoeSparseMoeBlock 65 | 66 | _set_z3_leaf_modules(model, [Qwen2MoeSparseMoeBlock]) 67 | 68 | 69 | def configure_moe(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None: 70 | model_type = getattr(config, "model_type", None) 71 | if model_args.moe_aux_loss_coef is not None: 72 | if model_type in ["jamba", "mixtral", "qwen2_moe"]: 73 | setattr(config, "router_aux_loss_coef", model_args.moe_aux_loss_coef) 74 | 75 | elif model_type == "deepseek": 76 | setattr(config, "aux_loss_alpha", model_args.moe_aux_loss_coef) 77 | 78 | elif model_type == "jetmoe": 79 | setattr(config, "aux_loss_coef", model_args.moe_aux_loss_coef) 80 | 81 | if model_type in ["dbrx", "jamba", "jetmoe", "mixtral", "qwen2_moe"]: 82 | setattr(config, "output_router_logits", is_trainable) 83 | -------------------------------------------------------------------------------- /src/llamafactory/model/model_utils/rope.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 LMSYS and the LlamaFactory team. 2 | # Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li 3 | # 4 | # This code is inspired by the LMSYS's FastChat library. 5 | # https://github.com/lm-sys/FastChat/blob/v0.2.30/fastchat/train/train.py 6 | # 7 | # Licensed under the Apache License, Version 2.0 (the "License"); 8 | # you may not use this file except in compliance with the License. 9 | # You may obtain a copy of the License at 10 | # 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | # 13 | # Unless required by applicable law or agreed to in writing, software 14 | # distributed under the License is distributed on an "AS IS" BASIS, 15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 16 | # See the License for the specific language governing permissions and 17 | # limitations under the License. 
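# Illustrative only: configure_moe above simply mirrors moe_aux_loss_coef onto the
# architecture-specific config field; a toy config object makes that visible.
from types import SimpleNamespace

from llamafactory.model.model_utils.moe import configure_moe

config = SimpleNamespace(model_type="mixtral")
model_args = SimpleNamespace(moe_aux_loss_coef=0.01)
configure_moe(config, model_args, is_trainable=True)
assert config.router_aux_loss_coef == 0.01 and config.output_router_logits is True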
18 |  19 | import math 20 | from typing import TYPE_CHECKING 21 |  22 | from ...extras import logging 23 |  24 |  25 | if TYPE_CHECKING: 26 |     from transformers import PretrainedConfig 27 |  28 |     from ...hparams import ModelArguments 29 |  30 |  31 | logger = logging.get_logger(__name__) 32 |  33 |  34 | def configure_rope(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None: 35 |     if model_args.rope_scaling is None: 36 |         return 37 |  38 |     if not hasattr(config, "rope_scaling"): 39 |         logger.warning_rank0("Current model does not support RoPE scaling.") 40 |         return 41 |  42 |     if model_args.model_max_length is not None: 43 |         if is_trainable and model_args.rope_scaling == "dynamic": 44 |             logger.warning_rank0( 45 |                 "Dynamic NTK scaling may not work well with fine-tuning. " 46 |                 "See: https://github.com/huggingface/transformers/pull/24653" 47 |             ) 48 |  49 |         current_max_length = getattr(config, "max_position_embeddings", None) 50 |         if current_max_length and model_args.model_max_length > current_max_length: 51 |             logger.info_rank0(f"Enlarge max model length from {current_max_length} to {model_args.model_max_length}.") 52 |             setattr(config, "max_position_embeddings", model_args.model_max_length) 53 |             scaling_factor = float(math.ceil(model_args.model_max_length / current_max_length)) 54 |         else: 55 |             logger.warning_rank0("Input length is smaller than max length. Consider increasing the input length.") 56 |             scaling_factor = 1.0 57 |     else: 58 |         scaling_factor = 2.0 59 |  60 |     setattr(config, "rope_scaling", {"type": model_args.rope_scaling, "factor": scaling_factor}) 61 |     logger.info_rank0( 62 |         f"Using {model_args.rope_scaling} scaling strategy and setting scaling factor to {scaling_factor}" 63 |     ) 64 | -------------------------------------------------------------------------------- /src/llamafactory/model/model_utils/unsloth.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 the LlamaFactory team. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
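# NOTE: a rough sketch of the keyword dict assembled by `_get_unsloth_kwargs`
# below -- the model name and dtype are hypothetical placeholders, not values
# taken from this repository. All `unsloth` imports are deferred into the
# functions so the dependency stays optional.
#
#     {
#         "model_name": "meta-llama/Meta-Llama-3-8B-Instruct",  # placeholder
#         "max_seq_length": 4096,            # model_args.model_max_length or 4096
#         "dtype": torch.bfloat16,           # model_args.compute_dtype
#         "load_in_4bit": False,             # model_args.quantization_bit == 4
#         "device_map": {"": 0},             # current accelerator
#         "use_gradient_checkpointing": "unsloth",
#     }
#
# `load_unsloth_pretrained_model` returns None and switches `model_args.use_unsloth`
# off when unsloth does not support the architecture, so callers must handle the
# fallback to the regular loader.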
14 | 15 | from typing import TYPE_CHECKING, Any, Dict, Optional 16 | 17 | from ...extras import logging 18 | from ...extras.misc import get_current_device 19 | 20 | 21 | if TYPE_CHECKING: 22 | from transformers import PretrainedConfig, PreTrainedModel 23 | 24 | from ...hparams import ModelArguments 25 | 26 | 27 | logger = logging.get_logger(__name__) 28 | 29 | 30 | def _get_unsloth_kwargs( 31 | config: "PretrainedConfig", model_name_or_path: str, model_args: "ModelArguments" 32 | ) -> Dict[str, Any]: 33 | return { 34 | "model_name": model_name_or_path, 35 | "max_seq_length": model_args.model_max_length or 4096, 36 | "dtype": model_args.compute_dtype, 37 | "load_in_4bit": model_args.quantization_bit == 4, 38 | "token": model_args.hf_hub_token, 39 | "device_map": {"": get_current_device()}, 40 | "rope_scaling": getattr(config, "rope_scaling", None), 41 | "fix_tokenizer": False, 42 | "trust_remote_code": True, 43 | "use_gradient_checkpointing": "unsloth", 44 | } 45 | 46 | 47 | def load_unsloth_pretrained_model( 48 | config: "PretrainedConfig", model_args: "ModelArguments" 49 | ) -> Optional["PreTrainedModel"]: 50 | r""" 51 | Optionally loads pretrained model with unsloth. Used in training. 52 | """ 53 | from unsloth import FastLanguageModel 54 | 55 | unsloth_kwargs = _get_unsloth_kwargs(config, model_args.model_name_or_path, model_args) 56 | try: 57 | model, _ = FastLanguageModel.from_pretrained(**unsloth_kwargs) 58 | except NotImplementedError: 59 | logger.warning_rank0("Unsloth does not support model type {}.".format(getattr(config, "model_type", None))) 60 | model = None 61 | model_args.use_unsloth = False 62 | 63 | return model 64 | 65 | 66 | def get_unsloth_peft_model( 67 | model: "PreTrainedModel", model_args: "ModelArguments", peft_kwargs: Dict[str, Any] 68 | ) -> "PreTrainedModel": 69 | r""" 70 | Gets the peft model for the pretrained model with unsloth. Used in training. 71 | """ 72 | from unsloth import FastLanguageModel 73 | 74 | unsloth_peft_kwargs = { 75 | "model": model, 76 | "max_seq_length": model_args.model_max_length, 77 | "use_gradient_checkpointing": "unsloth", 78 | } 79 | return FastLanguageModel.get_peft_model(**peft_kwargs, **unsloth_peft_kwargs) 80 | 81 | 82 | def load_unsloth_peft_model( 83 | config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool 84 | ) -> "PreTrainedModel": 85 | r""" 86 | Loads peft model with unsloth. Used in both training and inference. 87 | """ 88 | from unsloth import FastLanguageModel 89 | 90 | unsloth_kwargs = _get_unsloth_kwargs(config, model_args.adapter_name_or_path[0], model_args) 91 | try: 92 | if not is_trainable: 93 | unsloth_kwargs["use_gradient_checkpointing"] = False 94 | 95 | model, _ = FastLanguageModel.from_pretrained(**unsloth_kwargs) 96 | except NotImplementedError: 97 | raise ValueError("Unsloth does not support model type {}.".format(getattr(config, "model_type", None))) 98 | 99 | if not is_trainable: 100 | FastLanguageModel.for_inference(model) 101 | 102 | return model 103 | -------------------------------------------------------------------------------- /src/llamafactory/model/model_utils/valuehead.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 the LlamaFactory team. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from typing import TYPE_CHECKING, Dict 16 | 17 | import torch 18 | from transformers.utils import cached_file 19 | 20 | from ...extras import logging 21 | from ...extras.constants import V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME 22 | 23 | 24 | if TYPE_CHECKING: 25 | from transformers import PreTrainedModel 26 | 27 | from ...hparams import ModelArguments 28 | 29 | 30 | logger = logging.get_logger(__name__) 31 | 32 | 33 | def load_valuehead_params(path_or_repo_id: str, model_args: "ModelArguments") -> Dict[str, torch.Tensor]: 34 | r""" 35 | Loads value head parameters from Hugging Face Hub or local disk. 36 | 37 | Returns: dict with keys `v_head.summary.weight` and `v_head.summary.bias`. 38 | """ 39 | kwargs = {"path_or_repo_id": path_or_repo_id, "cache_dir": model_args.cache_dir, "token": model_args.hf_hub_token} 40 | err_text = "" 41 | 42 | try: 43 | from safetensors import safe_open 44 | 45 | vhead_file = cached_file(filename=V_HEAD_SAFE_WEIGHTS_NAME, **kwargs) 46 | with safe_open(vhead_file, framework="pt", device="cpu") as f: 47 | return {key: f.get_tensor(key) for key in f.keys()} 48 | except Exception as err: 49 | err_text = str(err) 50 | 51 | try: 52 | vhead_file = cached_file(filename=V_HEAD_WEIGHTS_NAME, **kwargs) 53 | return torch.load(vhead_file, map_location="cpu") 54 | except Exception as err: 55 | err_text = str(err) 56 | 57 | logger.info_rank0(f"Provided path ({path_or_repo_id}) does not contain value head weights: {err_text}.") 58 | logger.info_rank0("Ignore the above message if you are not resuming the training of a value head model.") 59 | return None 60 | 61 | 62 | def prepare_valuehead_model(model: "PreTrainedModel") -> None: 63 | if getattr(model.config, "model_type", None) == "llava": 64 | setattr(model, "lm_head", model.language_model.get_output_embeddings()) 65 | setattr(model, "_keys_to_ignore_on_save", ["lm_head.weight"]) 66 | 67 | if getattr(model.config, "model_type", None) == "chatglm": 68 | setattr(model, "lm_head", model.transformer.output_layer) 69 | setattr(model, "_keys_to_ignore_on_save", ["lm_head.weight"]) 70 | 71 | if getattr(model.config, "model_type", None) == "internlm2": 72 | setattr(model, "lm_head", model.output) 73 | setattr(model, "_keys_to_ignore_on_save", ["lm_head.weight"]) 74 | -------------------------------------------------------------------------------- /src/llamafactory/train/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AMAP-ML/RealQA/d0b178194b02d7f9f96b428e10bccc5d117f74c0/src/llamafactory/train/__init__.py -------------------------------------------------------------------------------- /src/llamafactory/train/dpo/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 the LlamaFactory team. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from .workflow import run_dpo 16 | 17 | 18 | __all__ = ["run_dpo"] 19 | -------------------------------------------------------------------------------- /src/llamafactory/train/kto/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 the LlamaFactory team. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from .workflow import run_kto 16 | 17 | 18 | __all__ = ["run_kto"] 19 | -------------------------------------------------------------------------------- /src/llamafactory/train/kto/workflow.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 HuggingFace Inc. and the LlamaFactory team. 2 | # 3 | # This code is inspired by the HuggingFace's TRL library. 4 | # https://github.com/huggingface/trl/blob/v0.8.0/examples/scripts/kto.py 5 | # 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at 9 | # 10 | # http://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # Unless required by applicable law or agreed to in writing, software 13 | # distributed under the License is distributed on an "AS IS" BASIS, 14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | # See the License for the specific language governing permissions and 16 | # limitations under the License. 
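# NOTE: `run_kto` below follows the same wiring as the other stage workflows
# (tokenizer -> template -> dataset -> model -> trainer). The KTO-specific
# detail is the reference model: when `finetuning_args.ref_model` is unset and
# no training is requested, the policy model doubles as its own reference, and
# reward metrics are dropped at evaluation time because they cannot be computed
# in that case. A minimal call sketch (normally issued by the tuner, which is
# not part of this excerpt):
#
#     run_kto(model_args, data_args, training_args, finetuning_args, callbacks=[])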
17 | 18 | from typing import TYPE_CHECKING, List, Optional 19 | 20 | from ...data import KTODataCollatorWithPadding, get_dataset, get_template_and_fix_tokenizer 21 | from ...extras.constants import IGNORE_INDEX 22 | from ...extras.ploting import plot_loss 23 | from ...hparams import ModelArguments 24 | from ...model import load_model, load_tokenizer 25 | from ..trainer_utils import create_modelcard_and_push, create_ref_model 26 | from .trainer import CustomKTOTrainer 27 | 28 | 29 | if TYPE_CHECKING: 30 | from transformers import Seq2SeqTrainingArguments, TrainerCallback 31 | 32 | from ...hparams import DataArguments, FinetuningArguments 33 | 34 | 35 | def run_kto( 36 | model_args: "ModelArguments", 37 | data_args: "DataArguments", 38 | training_args: "Seq2SeqTrainingArguments", 39 | finetuning_args: "FinetuningArguments", 40 | callbacks: Optional[List["TrainerCallback"]] = None, 41 | ): 42 | tokenizer_module = load_tokenizer(model_args) 43 | tokenizer = tokenizer_module["tokenizer"] 44 | template = get_template_and_fix_tokenizer(tokenizer, data_args) 45 | dataset_module = get_dataset(template, model_args, data_args, training_args, stage="kto", **tokenizer_module) 46 | model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train) 47 | 48 | data_collator = KTODataCollatorWithPadding( 49 | template=template, 50 | pad_to_multiple_of=8, 51 | label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id, 52 | **tokenizer_module, 53 | ) 54 | 55 | # Create reference model 56 | if finetuning_args.ref_model is None and (not training_args.do_train): # use the model itself 57 | ref_model = model 58 | else: 59 | ref_model = create_ref_model(model_args, finetuning_args) 60 | 61 | # Update arguments 62 | training_args.remove_unused_columns = False # important for multimodal and pairwise dataset 63 | 64 | # Initialize our Trainer 65 | trainer = CustomKTOTrainer( 66 | model=model, 67 | ref_model=ref_model, 68 | args=training_args, 69 | finetuning_args=finetuning_args, 70 | data_collator=data_collator, 71 | callbacks=callbacks, 72 | **dataset_module, 73 | **tokenizer_module, 74 | ) 75 | 76 | # Training 77 | if training_args.do_train: 78 | train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint) 79 | trainer.save_model() 80 | trainer.log_metrics("train", train_result.metrics) 81 | trainer.save_metrics("train", train_result.metrics) 82 | trainer.save_state() 83 | if trainer.is_world_process_zero() and finetuning_args.plot_loss: 84 | plot_loss(training_args.output_dir, keys=["loss", "eval_loss", "rewards/chosen"]) 85 | 86 | # Evaluation 87 | if training_args.do_eval: 88 | metrics = trainer.evaluate(metric_key_prefix="eval") 89 | if id(model) == id(ref_model): # unable to compute rewards without a reference model 90 | remove_keys = [key for key in metrics.keys() if "rewards" in key] 91 | for key in remove_keys: 92 | metrics.pop(key) 93 | trainer.log_metrics("eval", metrics) 94 | trainer.save_metrics("eval", metrics) 95 | 96 | # Create model card 97 | create_modelcard_and_push(trainer, model_args, data_args, training_args, finetuning_args) 98 | -------------------------------------------------------------------------------- /src/llamafactory/train/ppo/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 the LlamaFactory team. 
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from .workflow import run_ppo 16 | 17 | 18 | __all__ = ["run_ppo"] 19 | -------------------------------------------------------------------------------- /src/llamafactory/train/ppo/ppo_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 the LlamaFactory team. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import json 16 | from contextlib import nullcontext 17 | from typing import TYPE_CHECKING, Dict, List, Literal, Optional 18 | 19 | import torch 20 | from transformers.integrations import is_deepspeed_zero3_enabled 21 | 22 | from ...extras.packages import is_requests_available 23 | 24 | 25 | if is_requests_available(): 26 | import requests 27 | 28 | 29 | if TYPE_CHECKING: 30 | from transformers import PreTrainedModel 31 | from trl import AutoModelForCausalLMWithValueHead 32 | 33 | 34 | def get_rewards_from_server(server_url: str, messages: List[str]) -> List["torch.Tensor"]: 35 | r""" 36 | Gets reward scores from the API server. 37 | """ 38 | headers = {"Content-Type": "application/json"} 39 | payload = {"model": "model", "messages": messages} 40 | response = requests.post(server_url, json=payload, headers=headers) 41 | rewards = json.loads(response.text)["scores"] 42 | return torch.Tensor(rewards) 43 | 44 | 45 | def replace_model(model: "AutoModelForCausalLMWithValueHead", target: Literal["default", "reward"]) -> None: 46 | r""" 47 | Replaces the default/reward modules in the model. The model is already unwrapped. 
48 | """ 49 | v_head_layer = model.v_head.summary 50 | if is_deepspeed_zero3_enabled(): 51 | import deepspeed # type: ignore 52 | 53 | params = [v_head_layer.weight, v_head_layer.bias] 54 | context_maybe_zero3 = deepspeed.zero.GatheredParameters(params, modifier_rank=0) 55 | else: 56 | context_maybe_zero3 = nullcontext() 57 | 58 | model.pretrained_model.set_adapter(target) # set the LoRA adapter to be active 59 | with context_maybe_zero3: 60 | if target == "reward": # save default head temporarily 61 | setattr(model, "default_head_weight", v_head_layer.weight.data.detach().clone()) 62 | setattr(model, "default_head_bias", v_head_layer.bias.data.detach().clone()) 63 | 64 | device = v_head_layer.weight.device 65 | v_head_layer.weight.data = model.get_buffer(f"{target}_head_weight").detach().clone().to(device) 66 | v_head_layer.bias.data = model.get_buffer(f"{target}_head_bias").detach().clone().to(device) 67 | 68 | 69 | def dump_layernorm(model: "PreTrainedModel") -> Dict[str, "torch.Tensor"]: 70 | r""" 71 | Dumps the layernorm parameters in the model. The model is already unwrapped (and gathered). 72 | """ 73 | layer_norm_params = {} 74 | for name, param in model.named_parameters(): 75 | if param.data.dtype == torch.float32: 76 | layer_norm_params[name] = param.data.detach().clone() 77 | param.data = param.data.to(model.config.torch_dtype) 78 | 79 | return layer_norm_params 80 | 81 | 82 | def restore_layernorm(model: "PreTrainedModel", layernorm_params: Optional[Dict[str, "torch.Tensor"]] = None) -> None: 83 | r""" 84 | Restores the layernorm parameters in the model. The model is already unwrapped (and gathered). 85 | """ 86 | for name, param in model.named_parameters(): 87 | if name in layernorm_params: 88 | param.data = layernorm_params[name] 89 | -------------------------------------------------------------------------------- /src/llamafactory/train/ppo/workflow.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 HuggingFace Inc. and the LlamaFactory team. 2 | # 3 | # This code is inspired by the HuggingFace's TRL library. 4 | # https://github.com/huggingface/trl/blob/v0.8.0/examples/scripts/ppo.py 5 | # 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at 9 | # 10 | # http://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # Unless required by applicable law or agreed to in writing, software 13 | # distributed under the License is distributed on an "AS IS" BASIS, 14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | # See the License for the specific language governing permissions and 16 | # limitations under the License. 
17 | 18 | from typing import TYPE_CHECKING, List, Optional 19 | 20 | from ...data import MultiModalDataCollatorForSeq2Seq, get_dataset, get_template_and_fix_tokenizer 21 | from ...extras.ploting import plot_loss 22 | from ...model import load_model, load_tokenizer 23 | from ..callbacks import fix_valuehead_checkpoint 24 | from ..trainer_utils import create_ref_model, create_reward_model 25 | from .trainer import CustomPPOTrainer 26 | 27 | 28 | if TYPE_CHECKING: 29 | from transformers import Seq2SeqTrainingArguments, TrainerCallback 30 | 31 | from ...hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments 32 | 33 | 34 | def run_ppo( 35 | model_args: "ModelArguments", 36 | data_args: "DataArguments", 37 | training_args: "Seq2SeqTrainingArguments", 38 | finetuning_args: "FinetuningArguments", 39 | generating_args: "GeneratingArguments", 40 | callbacks: Optional[List["TrainerCallback"]] = None, 41 | ): 42 | tokenizer_module = load_tokenizer(model_args) 43 | tokenizer = tokenizer_module["tokenizer"] 44 | template = get_template_and_fix_tokenizer(tokenizer, data_args) 45 | dataset_module = get_dataset(template, model_args, data_args, training_args, stage="ppo", **tokenizer_module) 46 | model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train, add_valuehead=True) 47 | 48 | tokenizer.padding_side = "left" # use left-padding in generation while using right-padding in training 49 | data_collator = MultiModalDataCollatorForSeq2Seq(template=template, **tokenizer_module) 50 | 51 | # Create reference model and reward model 52 | ref_model = create_ref_model(model_args, finetuning_args, add_valuehead=True) 53 | reward_model = create_reward_model(model, model_args, finetuning_args) 54 | 55 | # Initialize our Trainer 56 | ppo_trainer: "CustomPPOTrainer" = CustomPPOTrainer( 57 | model_args=model_args, 58 | training_args=training_args, 59 | finetuning_args=finetuning_args, 60 | generating_args=generating_args, 61 | callbacks=callbacks, 62 | model=model, 63 | reward_model=reward_model, 64 | ref_model=ref_model, 65 | data_collator=data_collator, 66 | **dataset_module, 67 | **tokenizer_module, 68 | ) 69 | 70 | # Training 71 | if training_args.do_train: 72 | ppo_trainer.ppo_train(resume_from_checkpoint=training_args.resume_from_checkpoint) 73 | ppo_trainer.save_model() 74 | if training_args.should_save: 75 | fix_valuehead_checkpoint(model, training_args.output_dir, training_args.save_safetensors) 76 | 77 | ppo_trainer.save_state() # must be called after save_model to have a folder 78 | if ppo_trainer.is_world_process_zero() and finetuning_args.plot_loss: 79 | plot_loss(training_args.output_dir, keys=["loss", "reward"]) 80 | -------------------------------------------------------------------------------- /src/llamafactory/train/pt/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 the LlamaFactory team. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
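# NOTE: every training stage package under `llamafactory.train` (dpo, kto, ppo,
# pt, rm, sft) follows this same layout and simply re-exports `run_<stage>`
# from its `workflow` module, e.g.:
#
#     from llamafactory.train.pt import run_pt
#
# For this stage, `run_pt` performs plain causal-LM (continued pre-training)
# fine-tuning with `DataCollatorForLanguageModeling(mlm=False)` and reports
# perplexity at evaluation time.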
14 | 15 | from .workflow import run_pt 16 | 17 | 18 | __all__ = ["run_pt"] 19 | -------------------------------------------------------------------------------- /src/llamafactory/train/pt/trainer.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 the LlamaFactory team. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from types import MethodType 16 | from typing import TYPE_CHECKING, Optional 17 | 18 | from transformers import Trainer 19 | from typing_extensions import override 20 | 21 | from ...extras.packages import is_transformers_version_equal_to_4_46 22 | from ..callbacks import PissaConvertCallback, SaveProcessorCallback 23 | from ..trainer_utils import create_custom_optimizer, create_custom_scheduler 24 | 25 | 26 | if TYPE_CHECKING: 27 | import torch 28 | from transformers import ProcessorMixin 29 | 30 | from ...hparams import FinetuningArguments 31 | 32 | 33 | class CustomTrainer(Trainer): 34 | r""" 35 | Inherits Trainer for custom optimizer. 36 | """ 37 | 38 | def __init__( 39 | self, finetuning_args: "FinetuningArguments", processor: Optional["ProcessorMixin"], **kwargs 40 | ) -> None: 41 | super().__init__(**kwargs) 42 | self.finetuning_args = finetuning_args 43 | 44 | if processor is not None: 45 | self.add_callback(SaveProcessorCallback(processor)) 46 | 47 | if finetuning_args.pissa_convert: 48 | self.add_callback(PissaConvertCallback) 49 | 50 | if finetuning_args.use_badam: 51 | from badam import BAdamCallback, clip_grad_norm_old_version # type: ignore 52 | 53 | self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator) 54 | self.add_callback(BAdamCallback) 55 | 56 | @override 57 | def create_optimizer(self) -> "torch.optim.Optimizer": 58 | if self.optimizer is None: 59 | self.optimizer = create_custom_optimizer(self.model, self.args, self.finetuning_args) 60 | return super().create_optimizer() 61 | 62 | @override 63 | def create_scheduler( 64 | self, num_training_steps: int, optimizer: Optional["torch.optim.Optimizer"] = None 65 | ) -> "torch.optim.lr_scheduler.LRScheduler": 66 | create_custom_scheduler(self.args, num_training_steps, optimizer) 67 | return super().create_scheduler(num_training_steps, optimizer) 68 | 69 | @override 70 | def compute_loss(self, model, inputs, return_outputs=False, **kwargs): 71 | r""" 72 | Fixes the loss value for transformers 4.46.0. 
73 | https://github.com/huggingface/transformers/blob/v4.46.0/src/transformers/trainer.py#L3605 74 | """ 75 | loss = super().compute_loss(model, inputs, return_outputs, **kwargs) 76 | if is_transformers_version_equal_to_4_46() and not getattr(self, "model_accepts_loss_kwargs", False): 77 | # other model should not scale the loss 78 | if return_outputs: 79 | return (loss[0] / self.args.gradient_accumulation_steps, *loss[1:]) 80 | else: 81 | return loss / self.args.gradient_accumulation_steps 82 | 83 | return loss 84 | -------------------------------------------------------------------------------- /src/llamafactory/train/pt/workflow.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 HuggingFace Inc. and the LlamaFactory team. 2 | # 3 | # This code is inspired by the HuggingFace's transformers library. 4 | # https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/language-modeling/run_clm.py 5 | # 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at 9 | # 10 | # http://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # Unless required by applicable law or agreed to in writing, software 13 | # distributed under the License is distributed on an "AS IS" BASIS, 14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | # See the License for the specific language governing permissions and 16 | # limitations under the License. 17 | 18 | import math 19 | from typing import TYPE_CHECKING, List, Optional 20 | 21 | from transformers import DataCollatorForLanguageModeling 22 | 23 | from ...data import get_dataset, get_template_and_fix_tokenizer 24 | from ...extras.ploting import plot_loss 25 | from ...model import load_model, load_tokenizer 26 | from ..trainer_utils import create_modelcard_and_push 27 | from .trainer import CustomTrainer 28 | 29 | 30 | if TYPE_CHECKING: 31 | from transformers import Seq2SeqTrainingArguments, TrainerCallback 32 | 33 | from ...hparams import DataArguments, FinetuningArguments, ModelArguments 34 | 35 | 36 | def run_pt( 37 | model_args: "ModelArguments", 38 | data_args: "DataArguments", 39 | training_args: "Seq2SeqTrainingArguments", 40 | finetuning_args: "FinetuningArguments", 41 | callbacks: Optional[List["TrainerCallback"]] = None, 42 | ): 43 | tokenizer_module = load_tokenizer(model_args) 44 | tokenizer = tokenizer_module["tokenizer"] 45 | template = get_template_and_fix_tokenizer(tokenizer, data_args) 46 | dataset_module = get_dataset(template, model_args, data_args, training_args, stage="pt", **tokenizer_module) 47 | model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train) 48 | data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False) 49 | 50 | # Initialize our Trainer 51 | trainer = CustomTrainer( 52 | model=model, 53 | args=training_args, 54 | finetuning_args=finetuning_args, 55 | data_collator=data_collator, 56 | callbacks=callbacks, 57 | **dataset_module, 58 | **tokenizer_module, 59 | ) 60 | 61 | # Training 62 | if training_args.do_train: 63 | train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint) 64 | trainer.save_model() 65 | trainer.log_metrics("train", train_result.metrics) 66 | trainer.save_metrics("train", train_result.metrics) 67 | trainer.save_state() 68 | if trainer.is_world_process_zero() and finetuning_args.plot_loss: 69 | 
plot_loss(training_args.output_dir, keys=["loss", "eval_loss"]) 70 | 71 | # Evaluation 72 | if training_args.do_eval: 73 | metrics = trainer.evaluate(metric_key_prefix="eval") 74 | try: 75 | perplexity = math.exp(metrics["eval_loss"]) 76 | except OverflowError: 77 | perplexity = float("inf") 78 | 79 | metrics["perplexity"] = perplexity 80 | trainer.log_metrics("eval", metrics) 81 | trainer.save_metrics("eval", metrics) 82 | 83 | # Create model card 84 | create_modelcard_and_push(trainer, model_args, data_args, training_args, finetuning_args) 85 | -------------------------------------------------------------------------------- /src/llamafactory/train/rm/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 the LlamaFactory team. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from .workflow import run_rm 16 | 17 | 18 | __all__ = ["run_rm"] 19 | -------------------------------------------------------------------------------- /src/llamafactory/train/rm/metric.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 the LlamaFactory team. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from dataclasses import dataclass 16 | from typing import TYPE_CHECKING, Dict, Optional 17 | 18 | import numpy as np 19 | 20 | from ...extras.misc import numpify 21 | 22 | 23 | if TYPE_CHECKING: 24 | from transformers import EvalPrediction 25 | 26 | 27 | @dataclass 28 | class ComputeAccuracy: 29 | r""" 30 | Computes reward accuracy and supports `batch_eval_metrics`. 
31 | """ 32 | 33 | def _dump(self) -> Optional[Dict[str, float]]: 34 | result = None 35 | if hasattr(self, "score_dict"): 36 | result = {k: float(np.mean(v)) for k, v in self.score_dict.items()} 37 | 38 | self.score_dict = {"accuracy": []} 39 | return result 40 | 41 | def __post_init__(self): 42 | self._dump() 43 | 44 | def __call__(self, eval_preds: "EvalPrediction", compute_result: bool = True) -> Optional[Dict[str, float]]: 45 | chosen_scores, rejected_scores = numpify(eval_preds.predictions[0]), numpify(eval_preds.predictions[1]) 46 | if not chosen_scores.shape: 47 | self.score_dict["accuracy"].append(chosen_scores > rejected_scores) 48 | else: 49 | for i in range(len(chosen_scores)): 50 | self.score_dict["accuracy"].append(chosen_scores[i] > rejected_scores[i]) 51 | 52 | if compute_result: 53 | return self._dump() 54 | -------------------------------------------------------------------------------- /src/llamafactory/train/rm/workflow.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 HuggingFace Inc. and the LlamaFactory team. 2 | # 3 | # This code is inspired by the HuggingFace's transformers library. 4 | # https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/summarization/run_summarization.py 5 | # 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at 9 | # 10 | # http://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # Unless required by applicable law or agreed to in writing, software 13 | # distributed under the License is distributed on an "AS IS" BASIS, 14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | # See the License for the specific language governing permissions and 16 | # limitations under the License. 
17 | 18 | from typing import TYPE_CHECKING, List, Optional 19 | 20 | from ...data import PairwiseDataCollatorWithPadding, get_dataset, get_template_and_fix_tokenizer 21 | from ...extras.ploting import plot_loss 22 | from ...model import load_model, load_tokenizer 23 | from ..callbacks import fix_valuehead_checkpoint 24 | from ..trainer_utils import create_modelcard_and_push 25 | from .metric import ComputeAccuracy 26 | from .trainer import PairwiseTrainer 27 | 28 | 29 | if TYPE_CHECKING: 30 | from transformers import Seq2SeqTrainingArguments, TrainerCallback 31 | 32 | from ...hparams import DataArguments, FinetuningArguments, ModelArguments 33 | 34 | 35 | def run_rm( 36 | model_args: "ModelArguments", 37 | data_args: "DataArguments", 38 | training_args: "Seq2SeqTrainingArguments", 39 | finetuning_args: "FinetuningArguments", 40 | callbacks: Optional[List["TrainerCallback"]] = None, 41 | ): 42 | tokenizer_module = load_tokenizer(model_args) 43 | tokenizer = tokenizer_module["tokenizer"] 44 | template = get_template_and_fix_tokenizer(tokenizer, data_args) 45 | dataset_module = get_dataset(template, model_args, data_args, training_args, stage="rm", **tokenizer_module) 46 | model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train, add_valuehead=True) 47 | data_collator = PairwiseDataCollatorWithPadding(template=template, pad_to_multiple_of=8, **tokenizer_module) 48 | 49 | # Update arguments 50 | training_args.remove_unused_columns = False # important for multimodal and pairwise dataset 51 | 52 | # Initialize our Trainer 53 | trainer = PairwiseTrainer( 54 | model=model, 55 | args=training_args, 56 | finetuning_args=finetuning_args, 57 | data_collator=data_collator, 58 | callbacks=callbacks, 59 | compute_metrics=ComputeAccuracy(), 60 | **dataset_module, 61 | **tokenizer_module, 62 | ) 63 | 64 | # Training 65 | if training_args.do_train: 66 | train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint) 67 | trainer.save_model() 68 | if training_args.should_save: 69 | fix_valuehead_checkpoint(model, training_args.output_dir, training_args.save_safetensors) 70 | 71 | trainer.log_metrics("train", train_result.metrics) 72 | trainer.save_metrics("train", train_result.metrics) 73 | trainer.save_state() 74 | if trainer.is_world_process_zero() and finetuning_args.plot_loss: 75 | plot_loss(training_args.output_dir, keys=["loss", "eval_loss", "eval_accuracy"]) 76 | 77 | # Evaluation 78 | if training_args.do_eval: 79 | metrics = trainer.evaluate(metric_key_prefix="eval") 80 | trainer.log_metrics("eval", metrics) 81 | trainer.save_metrics("eval", metrics) 82 | 83 | # Predict 84 | if training_args.do_predict: 85 | predict_results = trainer.predict(dataset_module["eval_dataset"], metric_key_prefix="predict") 86 | trainer.log_metrics("predict", predict_results.metrics) 87 | trainer.save_metrics("predict", predict_results.metrics) 88 | trainer.save_predictions(predict_results) 89 | 90 | # Create model card 91 | create_modelcard_and_push(trainer, model_args, data_args, training_args, finetuning_args) 92 | -------------------------------------------------------------------------------- /src/llamafactory/train/sft/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 the LlamaFactory team. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from .workflow import run_sft 16 | 17 | 18 | __all__ = ["run_sft"] 19 | -------------------------------------------------------------------------------- /src/llamafactory/webui/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AMAP-ML/RealQA/d0b178194b02d7f9f96b428e10bccc5d117f74c0/src/llamafactory/webui/__init__.py -------------------------------------------------------------------------------- /src/llamafactory/webui/components/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 the LlamaFactory team. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from .chatbot import create_chat_box 16 | from .eval import create_eval_tab 17 | from .export import create_export_tab 18 | from .infer import create_infer_tab 19 | from .top import create_top 20 | from .train import create_train_tab 21 | 22 | 23 | __all__ = [ 24 | "create_chat_box", 25 | "create_eval_tab", 26 | "create_export_tab", 27 | "create_infer_tab", 28 | "create_top", 29 | "create_train_tab", 30 | ] 31 | -------------------------------------------------------------------------------- /src/llamafactory/webui/components/chatbot.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 the LlamaFactory team. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
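# NOTE: `create_chat_box` below returns the chatbot component, its message
# state, and a dict of named sub-components. After the caller registers that
# dict with the Manager (see manager.py near the end of this listing), each
# entry becomes addressable by a dotted id such as "infer.chat_box" or
# "infer.mm_box", which is how engine.py shows and hides the multimodal box.
# A rough usage sketch mirroring infer.py later in this listing:
#
#     chatbot, messages, chat_elems = create_chat_box(engine, visible=False)
#     elem_dict.update(chat_elems)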
14 | 15 | from typing import TYPE_CHECKING, Dict, Tuple 16 | 17 | from ...data import Role 18 | from ...extras.packages import is_gradio_available 19 | from ..utils import check_json_schema 20 | 21 | 22 | if is_gradio_available(): 23 | import gradio as gr 24 | 25 | 26 | if TYPE_CHECKING: 27 | from gradio.components import Component 28 | 29 | from ..engine import Engine 30 | 31 | 32 | def create_chat_box( 33 | engine: "Engine", visible: bool = False 34 | ) -> Tuple["Component", "Component", Dict[str, "Component"]]: 35 | with gr.Column(visible=visible) as chat_box: 36 | chatbot = gr.Chatbot(show_copy_button=True) 37 | messages = gr.State([]) 38 | with gr.Row(): 39 | with gr.Column(scale=4): 40 | with gr.Row(): 41 | with gr.Column(): 42 | role = gr.Dropdown(choices=[Role.USER.value, Role.OBSERVATION.value], value=Role.USER.value) 43 | system = gr.Textbox(show_label=False) 44 | tools = gr.Textbox(show_label=False, lines=3) 45 | 46 | with gr.Column() as mm_box: 47 | with gr.Tab("Image"): 48 | image = gr.Image(sources=["upload"], type="pil") 49 | 50 | with gr.Tab("Video"): 51 | video = gr.Video(sources=["upload"]) 52 | 53 | query = gr.Textbox(show_label=False, lines=8) 54 | submit_btn = gr.Button(variant="primary") 55 | 56 | with gr.Column(scale=1): 57 | max_new_tokens = gr.Slider(minimum=8, maximum=4096, value=512, step=1) 58 | top_p = gr.Slider(minimum=0.01, maximum=1.0, value=0.7, step=0.01) 59 | temperature = gr.Slider(minimum=0.01, maximum=1.5, value=0.95, step=0.01) 60 | clear_btn = gr.Button() 61 | 62 | tools.input(check_json_schema, inputs=[tools, engine.manager.get_elem_by_id("top.lang")]) 63 | 64 | submit_btn.click( 65 | engine.chatter.append, 66 | [chatbot, messages, role, query], 67 | [chatbot, messages, query], 68 | ).then( 69 | engine.chatter.stream, 70 | [chatbot, messages, system, tools, image, video, max_new_tokens, top_p, temperature], 71 | [chatbot, messages], 72 | ) 73 | clear_btn.click(lambda: ([], []), outputs=[chatbot, messages]) 74 | 75 | return ( 76 | chatbot, 77 | messages, 78 | dict( 79 | chat_box=chat_box, 80 | role=role, 81 | system=system, 82 | tools=tools, 83 | mm_box=mm_box, 84 | image=image, 85 | video=video, 86 | query=query, 87 | submit_btn=submit_btn, 88 | max_new_tokens=max_new_tokens, 89 | top_p=top_p, 90 | temperature=temperature, 91 | clear_btn=clear_btn, 92 | ), 93 | ) 94 | -------------------------------------------------------------------------------- /src/llamafactory/webui/components/eval.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 the LlamaFactory team. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
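# NOTE: the evaluation tab built below follows the Web UI's usual pattern --
# interactive widgets are collected into `input_elems` / `elem_dict`, while two
# hidden controls (`resume_btn`, `progress_bar`) let a running job be
# re-attached: toggling `resume_btn` triggers `engine.runner.monitor`, which
# streams progress into `output_box`. A compressed sketch of that wiring:
#
#     start_btn.click(engine.runner.run_eval, input_elems, [output_box, progress_bar])
#     resume_btn.change(engine.runner.monitor, outputs=[output_box, progress_bar])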
14 | 15 | from typing import TYPE_CHECKING, Dict 16 | 17 | from ...extras.packages import is_gradio_available 18 | from ..common import DEFAULT_DATA_DIR, list_datasets 19 | from .data import create_preview_box 20 | 21 | 22 | if is_gradio_available(): 23 | import gradio as gr 24 | 25 | 26 | if TYPE_CHECKING: 27 | from gradio.components import Component 28 | 29 | from ..engine import Engine 30 | 31 | 32 | def create_eval_tab(engine: "Engine") -> Dict[str, "Component"]: 33 | input_elems = engine.manager.get_base_elems() 34 | elem_dict = dict() 35 | 36 | with gr.Row(): 37 | dataset_dir = gr.Textbox(value=DEFAULT_DATA_DIR, scale=2) 38 | dataset = gr.Dropdown(multiselect=True, allow_custom_value=True, scale=4) 39 | preview_elems = create_preview_box(dataset_dir, dataset) 40 | 41 | input_elems.update({dataset_dir, dataset}) 42 | elem_dict.update(dict(dataset_dir=dataset_dir, dataset=dataset, **preview_elems)) 43 | 44 | with gr.Row(): 45 | cutoff_len = gr.Slider(minimum=4, maximum=131072, value=1024, step=1) 46 | max_samples = gr.Textbox(value="100000") 47 | batch_size = gr.Slider(minimum=1, maximum=1024, value=2, step=1) 48 | predict = gr.Checkbox(value=True) 49 | 50 | input_elems.update({cutoff_len, max_samples, batch_size, predict}) 51 | elem_dict.update(dict(cutoff_len=cutoff_len, max_samples=max_samples, batch_size=batch_size, predict=predict)) 52 | 53 | with gr.Row(): 54 | max_new_tokens = gr.Slider(minimum=8, maximum=4096, value=512, step=1) 55 | top_p = gr.Slider(minimum=0.01, maximum=1, value=0.7, step=0.01) 56 | temperature = gr.Slider(minimum=0.01, maximum=1.5, value=0.95, step=0.01) 57 | output_dir = gr.Textbox() 58 | 59 | input_elems.update({max_new_tokens, top_p, temperature, output_dir}) 60 | elem_dict.update(dict(max_new_tokens=max_new_tokens, top_p=top_p, temperature=temperature, output_dir=output_dir)) 61 | 62 | with gr.Row(): 63 | cmd_preview_btn = gr.Button() 64 | start_btn = gr.Button(variant="primary") 65 | stop_btn = gr.Button(variant="stop") 66 | 67 | with gr.Row(): 68 | resume_btn = gr.Checkbox(visible=False, interactive=False) 69 | progress_bar = gr.Slider(visible=False, interactive=False) 70 | 71 | with gr.Row(): 72 | output_box = gr.Markdown() 73 | 74 | elem_dict.update( 75 | dict( 76 | cmd_preview_btn=cmd_preview_btn, 77 | start_btn=start_btn, 78 | stop_btn=stop_btn, 79 | resume_btn=resume_btn, 80 | progress_bar=progress_bar, 81 | output_box=output_box, 82 | ) 83 | ) 84 | output_elems = [output_box, progress_bar] 85 | 86 | cmd_preview_btn.click(engine.runner.preview_eval, input_elems, output_elems, concurrency_limit=None) 87 | start_btn.click(engine.runner.run_eval, input_elems, output_elems) 88 | stop_btn.click(engine.runner.set_abort) 89 | resume_btn.change(engine.runner.monitor, outputs=output_elems, concurrency_limit=None) 90 | 91 | dataset.focus(list_datasets, [dataset_dir], [dataset], queue=False) 92 | 93 | return elem_dict 94 | -------------------------------------------------------------------------------- /src/llamafactory/webui/components/infer.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 the LlamaFactory team. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from typing import TYPE_CHECKING, Dict 16 | 17 | from ...extras.packages import is_gradio_available 18 | from ..common import get_visual 19 | from .chatbot import create_chat_box 20 | 21 | 22 | if is_gradio_available(): 23 | import gradio as gr 24 | 25 | 26 | if TYPE_CHECKING: 27 | from gradio.components import Component 28 | 29 | from ..engine import Engine 30 | 31 | 32 | def create_infer_tab(engine: "Engine") -> Dict[str, "Component"]: 33 | input_elems = engine.manager.get_base_elems() 34 | elem_dict = dict() 35 | 36 | with gr.Row(): 37 | infer_backend = gr.Dropdown(choices=["huggingface", "vllm"], value="huggingface") 38 | infer_dtype = gr.Dropdown(choices=["auto", "float16", "bfloat16", "float32"], value="auto") 39 | 40 | with gr.Row(): 41 | load_btn = gr.Button() 42 | unload_btn = gr.Button() 43 | 44 | info_box = gr.Textbox(show_label=False, interactive=False) 45 | 46 | input_elems.update({infer_backend, infer_dtype}) 47 | elem_dict.update( 48 | dict( 49 | infer_backend=infer_backend, 50 | infer_dtype=infer_dtype, 51 | load_btn=load_btn, 52 | unload_btn=unload_btn, 53 | info_box=info_box, 54 | ) 55 | ) 56 | 57 | chatbot, messages, chat_elems = create_chat_box(engine, visible=False) 58 | elem_dict.update(chat_elems) 59 | 60 | load_btn.click(engine.chatter.load_model, input_elems, [info_box]).then( 61 | lambda: gr.Column(visible=engine.chatter.loaded), outputs=[chat_elems["chat_box"]] 62 | ) 63 | 64 | unload_btn.click(engine.chatter.unload_model, input_elems, [info_box]).then( 65 | lambda: ([], []), outputs=[chatbot, messages] 66 | ).then(lambda: gr.Column(visible=engine.chatter.loaded), outputs=[chat_elems["chat_box"]]) 67 | 68 | engine.manager.get_elem_by_id("top.model_name").change( 69 | lambda model_name: gr.Column(visible=get_visual(model_name)), 70 | [engine.manager.get_elem_by_id("top.model_name")], 71 | [chat_elems["mm_box"]], 72 | ) 73 | 74 | return elem_dict 75 | -------------------------------------------------------------------------------- /src/llamafactory/webui/components/top.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 the LlamaFactory team. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
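# NOTE: `create_top` below builds the settings row shared by every tab
# (language, model name/path, fine-tuning type, checkpoints, quantization,
# template, RoPE scaling, booster). The returned dict is registered under the
# "top" namespace, which is why other modules address these widgets as
# "top.model_name", "top.finetuning_type", etc. (see Manager.get_base_elems
# near the end of this listing). The registration step presumably looks like
# the following in interface.py, which is not part of this excerpt:
#
#     engine.manager.add_elems("top", create_top())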
14 | 15 | from typing import TYPE_CHECKING, Dict 16 | 17 | from ...data import TEMPLATES 18 | from ...extras.constants import METHODS, SUPPORTED_MODELS 19 | from ...extras.packages import is_gradio_available 20 | from ..common import get_model_info, list_checkpoints, save_config 21 | from ..utils import can_quantize, can_quantize_to 22 | 23 | 24 | if is_gradio_available(): 25 | import gradio as gr 26 | 27 | 28 | if TYPE_CHECKING: 29 | from gradio.components import Component 30 | 31 | 32 | def create_top() -> Dict[str, "Component"]: 33 | available_models = list(SUPPORTED_MODELS.keys()) + ["Custom"] 34 | 35 | with gr.Row(): 36 | lang = gr.Dropdown(choices=["en", "ru", "zh", "ko"], scale=1) 37 | model_name = gr.Dropdown(choices=available_models, scale=3) 38 | model_path = gr.Textbox(scale=3) 39 | 40 | with gr.Row(): 41 | finetuning_type = gr.Dropdown(choices=METHODS, value="lora", scale=1) 42 | checkpoint_path = gr.Dropdown(multiselect=True, allow_custom_value=True, scale=6) 43 | 44 | with gr.Row(): 45 | quantization_bit = gr.Dropdown(choices=["none", "8", "4"], value="none", allow_custom_value=True, scale=2) 46 | quantization_method = gr.Dropdown(choices=["bitsandbytes", "hqq", "eetq"], value="bitsandbytes", scale=2) 47 | template = gr.Dropdown(choices=list(TEMPLATES.keys()), value="default", scale=2) 48 | rope_scaling = gr.Radio(choices=["none", "linear", "dynamic"], value="none", scale=3) 49 | booster = gr.Radio(choices=["auto", "flashattn2", "unsloth", "liger_kernel"], value="auto", scale=5) 50 | 51 | model_name.change(get_model_info, [model_name], [model_path, template], queue=False).then( 52 | list_checkpoints, [model_name, finetuning_type], [checkpoint_path], queue=False 53 | ) 54 | model_name.input(save_config, inputs=[lang, model_name], queue=False) 55 | model_path.input(save_config, inputs=[lang, model_name, model_path], queue=False) 56 | finetuning_type.change(can_quantize, [finetuning_type], [quantization_bit], queue=False).then( 57 | list_checkpoints, [model_name, finetuning_type], [checkpoint_path], queue=False 58 | ) 59 | checkpoint_path.focus(list_checkpoints, [model_name, finetuning_type], [checkpoint_path], queue=False) 60 | quantization_method.change(can_quantize_to, [quantization_method], [quantization_bit], queue=False) 61 | 62 | return dict( 63 | lang=lang, 64 | model_name=model_name, 65 | model_path=model_path, 66 | finetuning_type=finetuning_type, 67 | checkpoint_path=checkpoint_path, 68 | quantization_bit=quantization_bit, 69 | quantization_method=quantization_method, 70 | template=template, 71 | rope_scaling=rope_scaling, 72 | booster=booster, 73 | ) 74 | -------------------------------------------------------------------------------- /src/llamafactory/webui/css.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 the LlamaFactory team. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
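# NOTE: this module only exposes a raw CSS string. It is presumably injected
# when the Blocks app is created in interface.py (not part of this excerpt),
# along the lines of:
#
#     with gr.Blocks(css=CSS) as demo:
#         ...
#
# The ".modal-box" rules below style a centered overlay (fixed position,
# translated to the middle of the viewport, high z-index), with a white border
# variant for dark mode.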
14 | 15 | CSS = r""" 16 | .duplicate-button { 17 | margin: auto !important; 18 | color: white !important; 19 | background: black !important; 20 | border-radius: 100vh !important; 21 | } 22 | 23 | .modal-box { 24 | position: fixed !important; 25 | top: 50%; 26 | left: 50%; 27 | transform: translate(-50%, -50%); /* center horizontally */ 28 | max-width: 1000px; 29 | max-height: 750px; 30 | overflow-y: auto; 31 | background-color: var(--input-background-fill); 32 | flex-wrap: nowrap !important; 33 | border: 2px solid black !important; 34 | z-index: 1000; 35 | padding: 10px; 36 | } 37 | 38 | .dark .modal-box { 39 | border: 2px solid white !important; 40 | } 41 | """ 42 | -------------------------------------------------------------------------------- /src/llamafactory/webui/engine.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 the LlamaFactory team. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from typing import TYPE_CHECKING, Any, Dict 16 | 17 | from .chatter import WebChatModel 18 | from .common import load_config 19 | from .locales import LOCALES 20 | from .manager import Manager 21 | from .runner import Runner 22 | from .utils import create_ds_config, get_time 23 | 24 | 25 | if TYPE_CHECKING: 26 | from gradio.components import Component 27 | 28 | 29 | class Engine: 30 | def __init__(self, demo_mode: bool = False, pure_chat: bool = False) -> None: 31 | self.demo_mode = demo_mode 32 | self.pure_chat = pure_chat 33 | self.manager = Manager() 34 | self.runner = Runner(self.manager, demo_mode) 35 | self.chatter = WebChatModel(self.manager, demo_mode, lazy_init=(not pure_chat)) 36 | if not demo_mode: 37 | create_ds_config() 38 | 39 | def _update_component(self, input_dict: Dict[str, Dict[str, Any]]) -> Dict["Component", "Component"]: 40 | r""" 41 | Gets the dict to update the components. 
--------------------------------------------------------------------------------
/src/llamafactory/webui/engine.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 the LlamaFactory team.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | 
15 | from typing import TYPE_CHECKING, Any, Dict
16 | 
17 | from .chatter import WebChatModel
18 | from .common import load_config
19 | from .locales import LOCALES
20 | from .manager import Manager
21 | from .runner import Runner
22 | from .utils import create_ds_config, get_time
23 | 
24 | 
25 | if TYPE_CHECKING:
26 |     from gradio.components import Component
27 | 
28 | 
29 | class Engine:
30 |     def __init__(self, demo_mode: bool = False, pure_chat: bool = False) -> None:
31 |         self.demo_mode = demo_mode
32 |         self.pure_chat = pure_chat
33 |         self.manager = Manager()
34 |         self.runner = Runner(self.manager, demo_mode)
35 |         self.chatter = WebChatModel(self.manager, demo_mode, lazy_init=(not pure_chat))
36 |         if not demo_mode:
37 |             create_ds_config()
38 | 
39 |     def _update_component(self, input_dict: Dict[str, Dict[str, Any]]) -> Dict["Component", "Component"]:
40 |         r"""
41 |         Gets the dict to update the components.
42 |         """
43 |         output_dict: Dict["Component", "Component"] = {}
44 |         for elem_id, elem_attr in input_dict.items():
45 |             elem = self.manager.get_elem_by_id(elem_id)
46 |             output_dict[elem] = elem.__class__(**elem_attr)
47 | 
48 |         return output_dict
49 | 
50 |     def resume(self):
51 |         user_config = load_config() if not self.demo_mode else {}
52 |         lang = user_config.get("lang", None) or "en"
53 | 
54 |         init_dict = {"top.lang": {"value": lang}, "infer.chat_box": {"visible": self.chatter.loaded}}
55 | 
56 |         if not self.pure_chat:
57 |             current_time = get_time()
58 |             init_dict["train.current_time"] = {"value": current_time}
59 |             init_dict["train.output_dir"] = {"value": f"train_{current_time}"}
60 |             init_dict["train.config_path"] = {"value": f"{current_time}.yaml"}
61 |             init_dict["eval.output_dir"] = {"value": f"eval_{current_time}"}
62 |             init_dict["infer.mm_box"] = {"visible": False}
63 | 
64 |         if user_config.get("last_model", None):
65 |             init_dict["top.model_name"] = {"value": user_config["last_model"]}
66 | 
67 |         yield self._update_component(init_dict)
68 | 
69 |         if self.runner.running and not self.demo_mode and not self.pure_chat:
70 |             yield {elem: elem.__class__(value=value) for elem, value in self.runner.running_data.items()}
71 |             if self.runner.do_train:
72 |                 yield self._update_component({"train.resume_btn": {"value": True}})
73 |             else:
74 |                 yield self._update_component({"eval.resume_btn": {"value": True}})
75 | 
76 |     def change_lang(self, lang: str):
77 |         return {
78 |             elem: elem.__class__(**LOCALES[elem_name][lang])
79 |             for elem_name, elem in self.manager.get_elem_iter()
80 |             if elem_name in LOCALES
81 |         }
82 | 
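A hedged wiring sketch (assumed; the real assembly happens in interface.py, which registers every tab): registering only the "top" tab is enough to show how Engine, its Manager, and the locale-driven change_lang callback fit together.

import gradio as gr

from llamafactory.webui.components.top import create_top
from llamafactory.webui.engine import Engine

engine = Engine()  # demo_mode=False, pure_chat=False
with gr.Blocks() as demo:
    engine.manager.add_elems("top", create_top())  # ids become "top.lang", "top.model_name", ...
    lang = engine.manager.get_elem_by_id("top.lang")
    # change_lang returns {component: component.__class__(**LOCALES[name][lang])},
    # so every registered element that appears in LOCALES is re-labeled in place
    lang.change(engine.change_lang, [lang], engine.manager.get_elem_list(), queue=False)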
--------------------------------------------------------------------------------
/src/llamafactory/webui/manager.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 the LlamaFactory team.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | 
15 | from typing import TYPE_CHECKING, Dict, Generator, List, Set, Tuple
16 | 
17 | 
18 | if TYPE_CHECKING:
19 |     from gradio.components import Component
20 | 
21 | 
22 | class Manager:
23 |     def __init__(self) -> None:
24 |         self._id_to_elem: Dict[str, "Component"] = {}
25 |         self._elem_to_id: Dict["Component", str] = {}
26 | 
27 |     def add_elems(self, tab_name: str, elem_dict: Dict[str, "Component"]) -> None:
28 |         r"""
29 |         Adds elements to manager.
30 |         """
31 |         for elem_name, elem in elem_dict.items():
32 |             elem_id = f"{tab_name}.{elem_name}"
33 |             self._id_to_elem[elem_id] = elem
34 |             self._elem_to_id[elem] = elem_id
35 | 
36 |     def get_elem_list(self) -> List["Component"]:
37 |         r"""
38 |         Returns the list of all elements.
39 |         """
40 |         return list(self._id_to_elem.values())
41 | 
42 |     def get_elem_iter(self) -> Generator[Tuple[str, "Component"], None, None]:
43 |         r"""
44 |         Returns an iterator over all elements with their names.
45 |         """
46 |         for elem_id, elem in self._id_to_elem.items():
47 |             yield elem_id.split(".")[-1], elem
48 | 
49 |     def get_elem_by_id(self, elem_id: str) -> "Component":
50 |         r"""
51 |         Gets element by id.
52 | 
53 |         Example: top.lang, train.dataset
54 |         """
55 |         return self._id_to_elem[elem_id]
56 | 
57 |     def get_id_by_elem(self, elem: "Component") -> str:
58 |         r"""
59 |         Gets id by element.
60 |         """
61 |         return self._elem_to_id[elem]
62 | 
63 |     def get_base_elems(self) -> Set["Component"]:
64 |         r"""
65 |         Gets the base elements that are commonly used.
66 |         """
67 |         return {
68 |             self._id_to_elem["top.lang"],
69 |             self._id_to_elem["top.model_name"],
70 |             self._id_to_elem["top.model_path"],
71 |             self._id_to_elem["top.finetuning_type"],
72 |             self._id_to_elem["top.checkpoint_path"],
73 |             self._id_to_elem["top.quantization_bit"],
74 |             self._id_to_elem["top.quantization_method"],
75 |             self._id_to_elem["top.template"],
76 |             self._id_to_elem["top.rope_scaling"],
77 |             self._id_to_elem["top.booster"],
78 |         }
79 | 
--------------------------------------------------------------------------------
/src/train.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 the LlamaFactory team.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | 
15 | from llamafactory.train.tuner import run_exp
16 | 
17 | 
18 | def main():
19 |     run_exp()
20 | 
21 | 
22 | def _mp_fn(index):
23 |     # For xla_spawn (TPUs)
24 |     run_exp()
25 | 
26 | 
27 | if __name__ == "__main__":
28 |     main()
29 | 
--------------------------------------------------------------------------------
/src/webui.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 the LlamaFactory team.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | 
15 | import os
16 | 
17 | from llamafactory.webui.interface import create_ui
18 | 
19 | 
20 | def main():
21 |     gradio_ipv6 = os.getenv("GRADIO_IPV6", "0").lower() in ["true", "1"]
22 |     gradio_share = os.getenv("GRADIO_SHARE", "0").lower() in ["true", "1"]
23 |     server_name = os.getenv("GRADIO_SERVER_NAME", "[::]" if gradio_ipv6 else "0.0.0.0")
24 |     create_ui().queue().launch(share=gradio_share, server_name=server_name, inbrowser=True)
25 | 
26 | 
27 | if __name__ == "__main__":
28 |     main()
29 | 
--------------------------------------------------------------------------------
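Closing with a hedged launch sketch (an assumption, not a repository file): the programmatic equivalent of running `GRADIO_SHARE=1 GRADIO_SERVER_NAME=0.0.0.0 python src/webui.py`, with inbrowser turned off for headless machines.

from llamafactory.webui.interface import create_ui

# mirrors webui.main(), but with explicit arguments instead of the GRADIO_* environment variables
create_ui().queue().launch(share=True, server_name="0.0.0.0", inbrowser=False)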