├── .gitignore
├── LICENSE
├── README.md
├── assets
│   ├── logo.png
│   ├── overview.png
│   ├── pipeline.png
│   ├── res_task3.png
│   ├── result.png
│   └── wechat.png
├── eval
│   ├── cal_score_benchmarks_for_close_source.py
│   ├── cal_score_benchmarks_for_open_source.py
│   └── eval_by_vllm_for_open_source.py
├── requirements_rl.txt
├── requirements_sft.txt
├── scripts
│   ├── eval
│   │   ├── close_source_models
│   │   │   ├── evaluate_gemini.sh
│   │   │   ├── evaluate_gemini_only_calculate_score.sh
│   │   │   ├── evaluate_gpt4o.sh
│   │   │   └── evaluate_gpt4o_only_calculate_score.sh
│   │   └── open_source_models
│   │       ├── calculate_score
│   │       │   └── calculate_score.sh
│   │       ├── multi_gpu_eval
│   │       │   ├── eval_by_vllm_all_tasks_ans_sft_multi_gpu.sh
│   │       │   ├── eval_by_vllm_all_tasks_cot_sft_multi_gpu.sh
│   │       │   ├── eval_by_vllm_all_tasks_reason_rft_multi_gpu.sh
│   │       │   ├── eval_by_vllm_all_tasks_zero_shot_multi_gpu.sh
│   │       │   ├── eval_by_vllm_task1_ans_sft_multi_gpu.sh
│   │       │   ├── eval_by_vllm_task1_cot_sft_multi_gpu.sh
│   │       │   ├── eval_by_vllm_task1_reason_rft_multi_gpu.sh
│   │       │   ├── eval_by_vllm_task1_zero_shot_multi_gpu.sh
│   │       │   ├── eval_by_vllm_task2_ans_sft_multi_gpu.sh
│   │       │   ├── eval_by_vllm_task2_cot_sft_multi_gpu.sh
│   │       │   ├── eval_by_vllm_task2_reason_rft_multi_gpu.sh
│   │       │   ├── eval_by_vllm_task2_zero_shot_multi_gpu.sh
│   │       │   ├── eval_by_vllm_task3_ans_sft_multi_gpu.sh
│   │       │   ├── eval_by_vllm_task3_cot_sft_multi_gpu.sh
│   │       │   ├── eval_by_vllm_task3_reason_rft_multi_gpu.sh
│   │       │   └── eval_by_vllm_task3_zero_shot_multi_gpu.sh
│   │       └── single_gpu_eval
│   │           ├── eval_by_vllm_all_tasks_ans_sft_single_gpu.sh
│   │           ├── eval_by_vllm_all_tasks_cot_sft_single_gpu.sh
│   │           ├── eval_by_vllm_all_tasks_reason_rft_single_gpu.sh
│   │           ├── eval_by_vllm_all_tasks_zero_shot_single_gpu.sh
│   │           ├── eval_by_vllm_task1_ans_sft_single_gpu.sh
│   │           ├── eval_by_vllm_task1_cot_sft_single_gpu.sh
│   │           ├── eval_by_vllm_task1_reason_rft_single_gpu.sh
│   │           ├── eval_by_vllm_task1_zero_shot_single_gpu.sh
│   │           ├── eval_by_vllm_task2_ans_sft_single_gpu.sh
│   │           ├── eval_by_vllm_task2_cot_sft_single_gpu.sh
│   │           ├── eval_by_vllm_task2_reason_rft_single_gpu.sh
│   │           ├── eval_by_vllm_task2_zero_shot_single_gpu.sh
│   │           ├── eval_by_vllm_task3_ans_sft_single_gpu.sh
│   │           ├── eval_by_vllm_task3_cot_sft_single_gpu.sh
│   │           ├── eval_by_vllm_task3_reason_rft_single_gpu.sh
│   │           └── eval_by_vllm_task3_zero_shot_single_gpu.sh
│   └── train
│       ├── ans_sft
│       │   ├── resume_finetune_qwen2vl_2b_task1_ans_sft.sh
│       │   ├── resume_finetune_qwen2vl_2b_task2_ans_sft.sh
│       │   ├── resume_finetune_qwen2vl_2b_task3_ans_sft.sh
│       │   ├── resume_finetune_qwen2vl_7b_task1_ans_sft.sh
│       │   ├── resume_finetune_qwen2vl_7b_task2_ans_sft.sh
│       │   └── resume_finetune_qwen2vl_7b_task3_ans_sft.sh
│       ├── cot_sft
│       │   ├── resume_finetune_qwen2vl_2b_task1_cot_sft.sh
│       │   ├── resume_finetune_qwen2vl_2b_task2_cot_sft.sh
│       │   ├── resume_finetune_qwen2vl_2b_task3_cot_sft.sh
│       │   ├── resume_finetune_qwen2vl_7b_task1_cot_sft.sh
│       │   ├── resume_finetune_qwen2vl_7b_task2_cot_sft.sh
│       │   └── resume_finetune_qwen2vl_7b_task3_cot_sft.sh
│       ├── reason_rft
│       │   ├── stage_rl
│       │   │   ├── resume_finetune_qwen2vl_2b_task1_stage2_rl.sh
│       │   │   ├── resume_finetune_qwen2vl_2b_task2_stage2_rl.sh
│       │   │   ├── resume_finetune_qwen2vl_2b_task3_stage2_rl.sh
│       │   │   ├── resume_finetune_qwen2vl_7b_task1_stage2_rl.sh
│       │   │   ├── resume_finetune_qwen2vl_7b_task2_stage2_rl.sh
│       │   │   └── resume_finetune_qwen2vl_7b_task3_stage2_rl.sh
│       │   └── stage_sft
│       │       ├── resume_finetune_qwen2vl_2b_task1_stage1_sft.sh
│       │       ├── resume_finetune_qwen2vl_2b_task2_stage1_sft.sh
│       │       ├── resume_finetune_qwen2vl_2b_task3_stage1_sft.sh
│       │       ├── resume_finetune_qwen2vl_7b_task1_stage1_sft.sh
│       │       ├── resume_finetune_qwen2vl_7b_task2_stage1_sft.sh
│       │       └── resume_finetune_qwen2vl_7b_task3_stage1_sft.sh
│       ├── reason_rft_zero
│       │   ├── resume_finetune_qwen2vl_2b_task1_only_rl.sh
│       │   ├── resume_finetune_qwen2vl_2b_task2_only_rl.sh
│       │   ├── resume_finetune_qwen2vl_2b_task3_only_rl.sh
│       │   ├── resume_finetune_qwen2vl_7b_task1_only_rl.sh
│       │   ├── resume_finetune_qwen2vl_7b_task2_only_rl.sh
│       │   └── resume_finetune_qwen2vl_7b_task3_only_rl.sh
│       └── zero3.json
├── train
│   ├── stage_rl
│   │   ├── __init__.py
│   │   ├── configs.py
│   │   ├── grpo.py
│   │   ├── prompt.py
│   │   ├── reward.py
│   │   ├── trainer
│   │   │   ├── __init__.py
│   │   │   └── mm_grpo_trainer.py
│   │   └── utils
│   │       ├── __init__.py
│   │       ├── callbacks.py
│   │       ├── evaluation.py
│   │       ├── hub.py
│   │       └── upload_details.py
│   └── stage_sft
│       ├── api.py
│       ├── data
│       │   └── dataset_info.json
│       ├── llamafactory
│       │   ├── __init__.py
│       │   ├── api
│       │   │   ├── __init__.py
│       │   │   ├── app.py
│       │   │   ├── chat.py
│       │   │   ├── common.py
│       │   │   └── protocol.py
│       │   ├── chat
│       │   │   ├── __init__.py
│       │   │   ├── base_engine.py
│       │   │   ├── chat_model.py
│       │   │   ├── hf_engine.py
│       │   │   └── vllm_engine.py
│       │   ├── cli.py
│       │   ├── data
│       │   │   ├── __init__.py
│       │   │   ├── aligner.py
│       │   │   ├── collator.py
│       │   │   ├── data_utils.py
│       │   │   ├── formatter.py
│       │   │   ├── loader.py
│       │   │   ├── mm_plugin.py
│       │   │   ├── parser.py
│       │   │   ├── preprocess.py
│       │   │   ├── processors
│       │   │   │   ├── __init__.py
│       │   │   │   ├── feedback.py
│       │   │   │   ├── pairwise.py
│       │   │   │   ├── pretrain.py
│       │   │   │   ├── processor_utils.py
│       │   │   │   ├── supervised.py
│       │   │   │   └── unsupervised.py
│       │   │   ├── template.py
│       │   │   └── tool_utils.py
│       │   ├── eval
│       │   │   ├── __init__.py
│       │   │   ├── evaluator.py
│       │   │   └── template.py
│       │   ├── extras
│       │   │   ├── __init__.py
│       │   │   ├── constants.py
│       │   │   ├── env.py
│       │   │   ├── logging.py
│       │   │   ├── misc.py
│       │   │   ├── packages.py
│       │   │   └── ploting.py
│       │   ├── hparams
│       │   │   ├── __init__.py
│       │   │   ├── data_args.py
│       │   │   ├── evaluation_args.py
│       │   │   ├── finetuning_args.py
│       │   │   ├── generating_args.py
│       │   │   ├── model_args.py
│       │   │   ├── parser.py
│       │   │   └── training_args.py
│       │   ├── launcher.py
│       │   ├── model
│       │   │   ├── __init__.py
│       │   │   ├── adapter.py
│       │   │   ├── loader.py
│       │   │   ├── model_utils
│       │   │   │   ├── __init__.py
│       │   │   │   ├── attention.py
│       │   │   │   ├── checkpointing.py
│       │   │   │   ├── embedding.py
│       │   │   │   ├── liger_kernel.py
│       │   │   │   ├── longlora.py
│       │   │   │   ├── misc.py
│       │   │   │   ├── mod.py
│       │   │   │   ├── moe.py
│       │   │   │   ├── packing.py
│       │   │   │   ├── quantization.py
│       │   │   │   ├── rope.py
│       │   │   │   ├── unsloth.py
│       │   │   │   ├── valuehead.py
│       │   │   │   └── visual.py
│       │   │   └── patcher.py
│       │   ├── train
│       │   │   ├── __init__.py
│       │   │   ├── callbacks.py
│       │   │   ├── dpo
│       │   │   │   ├── __init__.py
│       │   │   │   ├── trainer.py
│       │   │   │   └── workflow.py
│       │   │   ├── kto
│       │   │   │   ├── __init__.py
│       │   │   │   ├── trainer.py
│       │   │   │   └── workflow.py
│       │   │   ├── ppo
│       │   │   │   ├── __init__.py
│       │   │   │   ├── ppo_utils.py
│       │   │   │   ├── trainer.py
│       │   │   │   └── workflow.py
│       │   │   ├── pt
│       │   │   │   ├── __init__.py
│       │   │   │   ├── trainer.py
│       │   │   │   └── workflow.py
│       │   │   ├── rm
│       │   │   │   ├── __init__.py
│       │   │   │   ├── metric.py
│       │   │   │   ├── trainer.py
│       │   │   │   └── workflow.py
│       │   │   ├── sft
│       │   │   │   ├── __init__.py
│       │   │   │   ├── metric.py
│       │   │   │   ├── trainer.py
│       │   │   │   └── workflow.py
│       │   │   ├── test_utils.py
│       │   │   ├── trainer_utils.py
│       │   │   └── tuner.py
│       │   └── webui
│       │       ├── __init__.py
│       │       ├── chatter.py
│       │       ├── common.py
│       │       ├── components
│       │       │   ├── __init__.py
│       │       │   ├── chatbot.py
│       │       │   ├── data.py
│       │       │   ├── eval.py
│       │       │   ├── export.py
│       │       │   ├── infer.py
│       │       │   ├── top.py
│       │       │   └── train.py
│       │       ├── css.py
│       │       ├── engine.py
│       │       ├── interface.py
│       │       ├── locales.py
│       │       ├── manager.py
│       │       ├── runner.py
│       │       └── utils.py
│       ├── train.py
│       └── webui.py
└── utils
    ├── convert_qwen2vl_format.py
    ├── convert_sft_data_trance.py
    ├── distill_cot_data.py
    ├── distill_cot_data_trance.py
    └── prompts.py
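The tree above is the layout of the Reason-RFT repository. A minimal way to fetch it, assuming the GitHub location implied by the raw.githubusercontent.com asset links further down in this listing (tanhuajie/Reason-RFT):

    git clone https://github.com/tanhuajie/Reason-RFT.git
    cd Reason-RFT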
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
venv/
# ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/
--------------------------------------------------------------------------------
/assets/logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tanhuajie/Reason-RFT/3180557dce2065e92754352dacf83eccb26ac032/assets/logo.png
--------------------------------------------------------------------------------
/assets/overview.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tanhuajie/Reason-RFT/3180557dce2065e92754352dacf83eccb26ac032/assets/overview.png
--------------------------------------------------------------------------------
/assets/pipeline.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tanhuajie/Reason-RFT/3180557dce2065e92754352dacf83eccb26ac032/assets/pipeline.png
--------------------------------------------------------------------------------
/assets/res_task3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tanhuajie/Reason-RFT/3180557dce2065e92754352dacf83eccb26ac032/assets/res_task3.png
--------------------------------------------------------------------------------
/assets/result.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tanhuajie/Reason-RFT/3180557dce2065e92754352dacf83eccb26ac032/assets/result.png
--------------------------------------------------------------------------------
/assets/wechat.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tanhuajie/Reason-RFT/3180557dce2065e92754352dacf83eccb26ac032/assets/wechat.png
--------------------------------------------------------------------------------
/requirements_sft.txt:
--------------------------------------------------------------------------------
accelerate==1.0.1
aiofiles==23.2.1
aiohappyeyeballs==2.4.4
aiohttp==3.11.11
aiosignal==1.3.2
annotated-types==0.7.0
anyio==4.8.0
async-timeout==5.0.1
attrs==25.1.0
av==14.1.0
certifi==2024.12.14
charset-normalizer==3.4.1
click==8.1.8
contourpy==1.3.1
cycler==0.12.1
datasets==3.1.0
deepspeed==0.14.4
dill==0.3.8
docstring_parser==0.16
einops==0.8.0
exceptiongroup==1.2.2
fastapi==0.115.7
ffmpy==0.5.0
filelock==3.17.0
fire==0.7.0
fonttools==4.55.6
frozenlist==1.5.0
fsspec==2024.9.0
gradio==5.12.0
gradio_client==1.5.4
h11==0.14.0
hjson==3.1.0
httpcore==1.0.7
httpx==0.28.1
huggingface-hub==0.27.1
idna==3.10
jieba==0.42.1
Jinja2==3.1.5
joblib==1.4.2
kiwisolver==1.4.8
markdown-it-py==3.0.0
MarkupSafe==2.1.5
matplotlib==3.10.0
mdurl==0.1.2
mpmath==1.3.0
multidict==6.1.0
multiprocess==0.70.16
networkx==3.4.2
ninja==1.11.1.3
nltk==3.9.1
numpy==1.26.4
nvidia-cublas-cu12==12.4.5.8
nvidia-cuda-cupti-cu12==12.4.127
nvidia-cuda-nvrtc-cu12==12.4.127
nvidia-cuda-runtime-cu12==12.4.127
nvidia-cudnn-cu12==9.1.0.70
nvidia-cufft-cu12==11.2.1.3
nvidia-curand-cu12==10.3.5.147
nvidia-cusolver-cu12==11.6.1.9
nvidia-cusparse-cu12==12.3.1.170
nvidia-ml-py==12.570.86
nvidia-nccl-cu12==2.21.5
nvidia-nvjitlink-cu12==12.4.127
nvidia-nvtx-cu12==12.4.127
orjson==3.10.15
packaging==24.2
pandas==2.2.3
peft==0.12.0
pillow==11.1.0
propcache==0.2.1
protobuf==5.29.3
psutil==6.1.1
py-cpuinfo==9.0.0
pyarrow==19.0.0
pydantic==2.10.6
pydantic_core==2.27.2
pydub==0.25.1
Pygments==2.19.1
pyparsing==3.2.1
python-dateutil==2.9.0.post0
python-multipart==0.0.20
pytz==2024.2
PyYAML==6.0.2
qwen-vl-utils==0.0.9
regex==2024.11.6
requests==2.32.3
rich==13.9.4
rouge-chinese==1.0.3
ruff==0.9.3
safehttpx==0.1.6
safetensors==0.5.2
scipy==1.15.1
semantic-version==2.10.0
sentencepiece==0.2.0
shellingham==1.5.4
shtab==1.7.1
six==1.17.0
sniffio==1.3.1
sse-starlette==2.2.1
starlette==0.45.3
sympy==1.13.1
termcolor==2.5.0
tiktoken==0.8.0
tokenizers==0.20.3
tomlkit==0.13.2
torch==2.5.1
torchvision==0.20.1
tqdm==4.67.1
transformers==4.46.1
transformers-stream-generator==0.0.5
triton==3.1.0
trl==0.9.6
typer==0.15.1
typing_extensions==4.12.2
tyro==0.8.14
tzdata==2025.1
urllib3==2.3.0
uvicorn==0.34.0
websockets==14.2
xxhash==3.5.0
yarl==1.18.3
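A minimal environment-setup sketch for the dependency pins above, assuming Python 3.10 (no interpreter version is pinned anywhere in this listing) and reusing the reasonrft_sft environment name that the training scripts below activate. requirements_rl.txt appears in the tree but its contents are not included here; the RL environment would be set up analogously:

    conda create -n reasonrft_sft python=3.10 -y
    conda activate reasonrft_sft
    pip install -r requirements_sft.txt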
--------------------------------------------------------------------------------
/scripts/eval/close_source_models/evaluate_gemini.sh:
--------------------------------------------------------------------------------
#!/bin/bash
python eval/cal_score_benchmarks_for_close_source.py --task_name clevr-math --model_type gemini
python eval/cal_score_benchmarks_for_close_source.py --task_name super-clevr --model_type gemini
python eval/cal_score_benchmarks_for_close_source.py --task_name geomath --model_type gemini
python eval/cal_score_benchmarks_for_close_source.py --task_name geometry3k --model_type gemini
python eval/cal_score_benchmarks_for_close_source.py --task_name trance --model_type gemini
python eval/cal_score_benchmarks_for_close_source.py --task_name trance-left --model_type gemini
python eval/cal_score_benchmarks_for_close_source.py --task_name trance-right --model_type gemini
--------------------------------------------------------------------------------
/scripts/eval/close_source_models/evaluate_gemini_only_calculate_score.sh:
--------------------------------------------------------------------------------
#!/bin/bash
python eval/cal_score_benchmarks_for_close_source.py --task_name clevr-math --model_type gemini --only_score
python eval/cal_score_benchmarks_for_close_source.py --task_name super-clevr --model_type gemini --only_score
python eval/cal_score_benchmarks_for_close_source.py --task_name geomath --model_type gemini --only_score
python eval/cal_score_benchmarks_for_close_source.py --task_name geometry3k --model_type gemini --only_score
python eval/cal_score_benchmarks_for_close_source.py --task_name trance --model_type gemini --only_score
python eval/cal_score_benchmarks_for_close_source.py --task_name trance-left --model_type gemini --only_score
python eval/cal_score_benchmarks_for_close_source.py --task_name trance-right --model_type gemini --only_score
--------------------------------------------------------------------------------
/scripts/eval/close_source_models/evaluate_gpt4o.sh:
--------------------------------------------------------------------------------
#!/bin/bash
python eval/cal_score_benchmarks_for_close_source.py --task_name clevr-math --model_type gpt4o
python eval/cal_score_benchmarks_for_close_source.py --task_name super-clevr --model_type gpt4o
python eval/cal_score_benchmarks_for_close_source.py --task_name geomath --model_type gpt4o
python eval/cal_score_benchmarks_for_close_source.py --task_name geometry3k --model_type gpt4o
python eval/cal_score_benchmarks_for_close_source.py --task_name trance --model_type gpt4o
python eval/cal_score_benchmarks_for_close_source.py --task_name trance-left --model_type gpt4o
python eval/cal_score_benchmarks_for_close_source.py --task_name trance-right --model_type gpt4o
--------------------------------------------------------------------------------
/scripts/eval/close_source_models/evaluate_gpt4o_only_calculate_score.sh:
--------------------------------------------------------------------------------
#!/bin/bash
python eval/cal_score_benchmarks_for_close_source.py --task_name clevr-math --model_type gpt4o --only_score
python eval/cal_score_benchmarks_for_close_source.py --task_name super-clevr --model_type gpt4o --only_score
python eval/cal_score_benchmarks_for_close_source.py --task_name geomath --model_type gpt4o --only_score
python eval/cal_score_benchmarks_for_close_source.py --task_name geometry3k --model_type gpt4o --only_score
python eval/cal_score_benchmarks_for_close_source.py --task_name trance --model_type gpt4o --only_score
python eval/cal_score_benchmarks_for_close_source.py --task_name trance-left --model_type gpt4o --only_score
python eval/cal_score_benchmarks_for_close_source.py --task_name trance-right --model_type gpt4o --only_score
--------------------------------------------------------------------------------
/scripts/eval/open_source_models/calculate_score/calculate_score.sh:
--------------------------------------------------------------------------------
#!/bin/bash

MODEL_NAME_OR_PATH=/path/to/your/checkpoint/

python eval/cal_score_benchmarks_for_open_source.py --ckpt_path $MODEL_NAME_OR_PATH
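A usage sketch for the scoring helper above: set MODEL_NAME_OR_PATH to a trained checkpoint and run from the repository root, since the eval/ path is relative to it. The exact layout of the prediction files that cal_score_benchmarks_for_open_source.py expects under --ckpt_path is not shown in this listing:

    # after editing MODEL_NAME_OR_PATH inside the script:
    bash scripts/eval/open_source_models/calculate_score/calculate_score.sh

    # equivalently, calling the entry point directly:
    python eval/cal_score_benchmarks_for_open_source.py --ckpt_path /path/to/your/checkpoint/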
--------------------------------------------------------------------------------
/scripts/eval/open_source_models/multi_gpu_eval/eval_by_vllm_all_tasks_ans_sft_multi_gpu.sh:
--------------------------------------------------------------------------------
#!/bin/bash
BATCH_SIZE=32

DEVICE_IDS=(0 1 2 3 4 5 6 7)
MODEL_NAME_OR_PATH_LIST=(
    "/path/to/your/checkpoint_0/"
    "/path/to/your/checkpoint_1/"
    "/path/to/your/checkpoint_2/"
    "/path/to/your/checkpoint_3/"
    "/path/to/your/checkpoint_4/"
    "/path/to/your/checkpoint_5/"
    "/path/to/your/checkpoint_6/"
    "/path/to/your/checkpoint_7/"
)

BENCHMARK_LIST="trance trance-left trance-right clevr-math super-clevr geomath geometry3k"
STRATAGE_LIST="sft sft sft sft sft sft sft"

for i in "${!DEVICE_IDS[@]}"; do
    CUDA_VISIBLE_DEVICES=${DEVICE_IDS[$i]} python eval/eval_by_vllm_for_open_source.py \
        --batch_size $BATCH_SIZE \
        --model_name_or_path ${MODEL_NAME_OR_PATH_LIST[$i]} \
        --benchmark_list $BENCHMARK_LIST \
        --stratage_list $STRATAGE_LIST &
done

wait
echo "All tasks finished."
--------------------------------------------------------------------------------
/scripts/eval/open_source_models/multi_gpu_eval/eval_by_vllm_all_tasks_cot_sft_multi_gpu.sh:
--------------------------------------------------------------------------------
#!/bin/bash
BATCH_SIZE=32

DEVICE_IDS=(0 1 2 3 4 5 6 7)
MODEL_NAME_OR_PATH_LIST=(
    "/path/to/your/checkpoint_0/"
    "/path/to/your/checkpoint_1/"
    "/path/to/your/checkpoint_2/"
    "/path/to/your/checkpoint_3/"
    "/path/to/your/checkpoint_4/"
    "/path/to/your/checkpoint_5/"
    "/path/to/your/checkpoint_6/"
    "/path/to/your/checkpoint_7/"
)

BENCHMARK_LIST="trance trance-left trance-right clevr-math super-clevr geomath geometry3k"
STRATAGE_LIST="cot-sft cot-sft cot-sft cot-sft cot-sft cot-sft cot-sft"

for i in "${!DEVICE_IDS[@]}"; do
    CUDA_VISIBLE_DEVICES=${DEVICE_IDS[$i]} python eval/eval_by_vllm_for_open_source.py \
        --batch_size $BATCH_SIZE \
        --model_name_or_path ${MODEL_NAME_OR_PATH_LIST[$i]} \
        --benchmark_list $BENCHMARK_LIST \
        --stratage_list $STRATAGE_LIST &
done

wait
echo "All tasks finished."
--------------------------------------------------------------------------------
/scripts/eval/open_source_models/multi_gpu_eval/eval_by_vllm_all_tasks_reason_rft_multi_gpu.sh:
--------------------------------------------------------------------------------
#!/bin/bash
BATCH_SIZE=32

DEVICE_IDS=(0 1 2 3 4 5 6 7)
MODEL_NAME_OR_PATH_LIST=(
    "/path/to/your/checkpoint_0/"
    "/path/to/your/checkpoint_1/"
    "/path/to/your/checkpoint_2/"
    "/path/to/your/checkpoint_3/"
    "/path/to/your/checkpoint_4/"
    "/path/to/your/checkpoint_5/"
    "/path/to/your/checkpoint_6/"
    "/path/to/your/checkpoint_7/"
)

BENCHMARK_LIST="trance trance-left trance-right clevr-math super-clevr geomath geometry3k"
STRATAGE_LIST="cot-sft cot-sft cot-sft cot-sft cot-sft cot-sft cot-sft"

for i in "${!DEVICE_IDS[@]}"; do
    CUDA_VISIBLE_DEVICES=${DEVICE_IDS[$i]} python eval/eval_by_vllm_for_open_source.py \
        --batch_size $BATCH_SIZE \
        --model_name_or_path ${MODEL_NAME_OR_PATH_LIST[$i]} \
        --benchmark_list $BENCHMARK_LIST \
        --stratage_list $STRATAGE_LIST &
done

wait
echo "All tasks finished."
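The multi-GPU scripts above and below all follow the same pattern: one background evaluation job per entry in DEVICE_IDS, with MODEL_NAME_OR_PATH_LIST indexed in lockstep, and a final wait. To run on fewer GPUs, shrink both arrays together; a hypothetical 2-GPU variant of the header:

    DEVICE_IDS=(0 1)
    MODEL_NAME_OR_PATH_LIST=(
        "/path/to/your/checkpoint_0/"
        "/path/to/your/checkpoint_1/"
    )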
--------------------------------------------------------------------------------
/scripts/eval/open_source_models/multi_gpu_eval/eval_by_vllm_all_tasks_zero_shot_multi_gpu.sh:
--------------------------------------------------------------------------------
#!/bin/bash
BATCH_SIZE=32

DEVICE_IDS=(0 1 2 3 4 5 6 7)
MODEL_NAME_OR_PATH_LIST=(
    "/path/to/your/checkpoint_0/"
    "/path/to/your/checkpoint_1/"
    "/path/to/your/checkpoint_2/"
    "/path/to/your/checkpoint_3/"
    "/path/to/your/checkpoint_4/"
    "/path/to/your/checkpoint_5/"
    "/path/to/your/checkpoint_6/"
    "/path/to/your/checkpoint_7/"
)

BENCHMARK_LIST="trance trance-left trance-right clevr-math super-clevr geomath geometry3k"
STRATAGE_LIST="zero-shot zero-shot zero-shot zero-shot zero-shot zero-shot zero-shot"

for i in "${!DEVICE_IDS[@]}"; do
    CUDA_VISIBLE_DEVICES=${DEVICE_IDS[$i]} python eval/eval_by_vllm_for_open_source.py \
        --batch_size $BATCH_SIZE \
        --model_name_or_path ${MODEL_NAME_OR_PATH_LIST[$i]} \
        --benchmark_list $BENCHMARK_LIST \
        --stratage_list $STRATAGE_LIST &
done

wait
echo "All tasks finished."
--------------------------------------------------------------------------------
/scripts/eval/open_source_models/multi_gpu_eval/eval_by_vllm_task1_ans_sft_multi_gpu.sh:
--------------------------------------------------------------------------------
#!/bin/bash
BATCH_SIZE=32

DEVICE_IDS=(0 1 2 3 4 5 6 7)
MODEL_NAME_OR_PATH_LIST=(
    "/path/to/your/checkpoint_0/"
    "/path/to/your/checkpoint_1/"
    "/path/to/your/checkpoint_2/"
    "/path/to/your/checkpoint_3/"
    "/path/to/your/checkpoint_4/"
    "/path/to/your/checkpoint_5/"
    "/path/to/your/checkpoint_6/"
    "/path/to/your/checkpoint_7/"
)

BENCHMARK_LIST="clevr-math super-clevr"
STRATAGE_LIST="sft sft"

for i in "${!DEVICE_IDS[@]}"; do
    CUDA_VISIBLE_DEVICES=${DEVICE_IDS[$i]} python eval/eval_by_vllm_for_open_source.py \
        --batch_size $BATCH_SIZE \
        --model_name_or_path ${MODEL_NAME_OR_PATH_LIST[$i]} \
        --benchmark_list $BENCHMARK_LIST \
        --stratage_list $STRATAGE_LIST &
done

wait
echo "All tasks finished."
--------------------------------------------------------------------------------
/scripts/eval/open_source_models/multi_gpu_eval/eval_by_vllm_task1_cot_sft_multi_gpu.sh:
--------------------------------------------------------------------------------
#!/bin/bash
BATCH_SIZE=32

DEVICE_IDS=(0 1 2 3 4 5 6 7)
MODEL_NAME_OR_PATH_LIST=(
    "/path/to/your/checkpoint_0/"
    "/path/to/your/checkpoint_1/"
    "/path/to/your/checkpoint_2/"
    "/path/to/your/checkpoint_3/"
    "/path/to/your/checkpoint_4/"
    "/path/to/your/checkpoint_5/"
    "/path/to/your/checkpoint_6/"
    "/path/to/your/checkpoint_7/"
)

BENCHMARK_LIST="clevr-math super-clevr"
STRATAGE_LIST="cot-sft cot-sft"

for i in "${!DEVICE_IDS[@]}"; do
    CUDA_VISIBLE_DEVICES=${DEVICE_IDS[$i]} python eval/eval_by_vllm_for_open_source.py \
        --batch_size $BATCH_SIZE \
        --model_name_or_path ${MODEL_NAME_OR_PATH_LIST[$i]} \
        --benchmark_list $BENCHMARK_LIST \
        --stratage_list $STRATAGE_LIST &
done

wait
echo "All tasks finished."
--------------------------------------------------------------------------------
/scripts/eval/open_source_models/multi_gpu_eval/eval_by_vllm_task1_reason_rft_multi_gpu.sh:
--------------------------------------------------------------------------------
#!/bin/bash
BATCH_SIZE=32

DEVICE_IDS=(0 1 2 3 4 5 6 7)
MODEL_NAME_OR_PATH_LIST=(
    "/path/to/your/checkpoint_0/"
    "/path/to/your/checkpoint_1/"
    "/path/to/your/checkpoint_2/"
    "/path/to/your/checkpoint_3/"
    "/path/to/your/checkpoint_4/"
    "/path/to/your/checkpoint_5/"
    "/path/to/your/checkpoint_6/"
    "/path/to/your/checkpoint_7/"
)

BENCHMARK_LIST="clevr-math super-clevr"
STRATAGE_LIST="cot-sft cot-sft"

for i in "${!DEVICE_IDS[@]}"; do
    CUDA_VISIBLE_DEVICES=${DEVICE_IDS[$i]} python eval/eval_by_vllm_for_open_source.py \
        --batch_size $BATCH_SIZE \
        --model_name_or_path ${MODEL_NAME_OR_PATH_LIST[$i]} \
        --benchmark_list $BENCHMARK_LIST \
        --stratage_list $STRATAGE_LIST &
done

wait
echo "All tasks finished."
--------------------------------------------------------------------------------
/scripts/eval/open_source_models/multi_gpu_eval/eval_by_vllm_task1_zero_shot_multi_gpu.sh:
--------------------------------------------------------------------------------
#!/bin/bash
BATCH_SIZE=32

DEVICE_IDS=(0 1 2 3 4 5 6 7)
MODEL_NAME_OR_PATH_LIST=(
    "/path/to/your/checkpoint_0/"
    "/path/to/your/checkpoint_1/"
    "/path/to/your/checkpoint_2/"
    "/path/to/your/checkpoint_3/"
    "/path/to/your/checkpoint_4/"
    "/path/to/your/checkpoint_5/"
    "/path/to/your/checkpoint_6/"
    "/path/to/your/checkpoint_7/"
)

BENCHMARK_LIST="clevr-math super-clevr"
STRATAGE_LIST="zero-shot zero-shot"

for i in "${!DEVICE_IDS[@]}"; do
    CUDA_VISIBLE_DEVICES=${DEVICE_IDS[$i]} python eval/eval_by_vllm_for_open_source.py \
        --batch_size $BATCH_SIZE \
        --model_name_or_path ${MODEL_NAME_OR_PATH_LIST[$i]} \
        --benchmark_list $BENCHMARK_LIST \
        --stratage_list $STRATAGE_LIST &
done

wait
echo "All tasks finished."
--------------------------------------------------------------------------------
/scripts/eval/open_source_models/multi_gpu_eval/eval_by_vllm_task2_ans_sft_multi_gpu.sh:
--------------------------------------------------------------------------------
#!/bin/bash
BATCH_SIZE=32

DEVICE_IDS=(0 1 2 3 4 5 6 7)
MODEL_NAME_OR_PATH_LIST=(
    "/path/to/your/checkpoint_0/"
    "/path/to/your/checkpoint_1/"
    "/path/to/your/checkpoint_2/"
    "/path/to/your/checkpoint_3/"
    "/path/to/your/checkpoint_4/"
    "/path/to/your/checkpoint_5/"
    "/path/to/your/checkpoint_6/"
    "/path/to/your/checkpoint_7/"
)

BENCHMARK_LIST="geomath geometry3k"
STRATAGE_LIST="sft sft"

for i in "${!DEVICE_IDS[@]}"; do
    CUDA_VISIBLE_DEVICES=${DEVICE_IDS[$i]} python eval/eval_by_vllm_for_open_source.py \
        --batch_size $BATCH_SIZE \
        --model_name_or_path ${MODEL_NAME_OR_PATH_LIST[$i]} \
        --benchmark_list $BENCHMARK_LIST \
        --stratage_list $STRATAGE_LIST &
done

wait
echo "All tasks finished."
--------------------------------------------------------------------------------
/scripts/eval/open_source_models/multi_gpu_eval/eval_by_vllm_task2_cot_sft_multi_gpu.sh:
--------------------------------------------------------------------------------
#!/bin/bash
BATCH_SIZE=32

DEVICE_IDS=(0 1 2 3 4 5 6 7)
MODEL_NAME_OR_PATH_LIST=(
    "/path/to/your/checkpoint_0/"
    "/path/to/your/checkpoint_1/"
    "/path/to/your/checkpoint_2/"
    "/path/to/your/checkpoint_3/"
    "/path/to/your/checkpoint_4/"
    "/path/to/your/checkpoint_5/"
    "/path/to/your/checkpoint_6/"
    "/path/to/your/checkpoint_7/"
)

BENCHMARK_LIST="geomath geometry3k"
STRATAGE_LIST="cot-sft cot-sft"

for i in "${!DEVICE_IDS[@]}"; do
    CUDA_VISIBLE_DEVICES=${DEVICE_IDS[$i]} python eval/eval_by_vllm_for_open_source.py \
        --batch_size $BATCH_SIZE \
        --model_name_or_path ${MODEL_NAME_OR_PATH_LIST[$i]} \
        --benchmark_list $BENCHMARK_LIST \
        --stratage_list $STRATAGE_LIST &
done

wait
echo "All tasks finished."
--------------------------------------------------------------------------------
/scripts/eval/open_source_models/multi_gpu_eval/eval_by_vllm_task2_reason_rft_multi_gpu.sh:
--------------------------------------------------------------------------------
#!/bin/bash
BATCH_SIZE=32

DEVICE_IDS=(0 1 2 3 4 5 6 7)
MODEL_NAME_OR_PATH_LIST=(
    "/path/to/your/checkpoint_0/"
    "/path/to/your/checkpoint_1/"
    "/path/to/your/checkpoint_2/"
    "/path/to/your/checkpoint_3/"
    "/path/to/your/checkpoint_4/"
    "/path/to/your/checkpoint_5/"
    "/path/to/your/checkpoint_6/"
    "/path/to/your/checkpoint_7/"
)

BENCHMARK_LIST="geomath geometry3k"
STRATAGE_LIST="cot-sft cot-sft"

for i in "${!DEVICE_IDS[@]}"; do
    CUDA_VISIBLE_DEVICES=${DEVICE_IDS[$i]} python eval/eval_by_vllm_for_open_source.py \
        --batch_size $BATCH_SIZE \
        --model_name_or_path ${MODEL_NAME_OR_PATH_LIST[$i]} \
        --benchmark_list $BENCHMARK_LIST \
        --stratage_list $STRATAGE_LIST &
done

wait
echo "All tasks finished."
--------------------------------------------------------------------------------
/scripts/eval/open_source_models/multi_gpu_eval/eval_by_vllm_task2_zero_shot_multi_gpu.sh:
--------------------------------------------------------------------------------
#!/bin/bash
BATCH_SIZE=32

DEVICE_IDS=(0 1 2 3 4 5 6 7)
MODEL_NAME_OR_PATH_LIST=(
    "/path/to/your/checkpoint_0/"
    "/path/to/your/checkpoint_1/"
    "/path/to/your/checkpoint_2/"
    "/path/to/your/checkpoint_3/"
    "/path/to/your/checkpoint_4/"
    "/path/to/your/checkpoint_5/"
    "/path/to/your/checkpoint_6/"
    "/path/to/your/checkpoint_7/"
)

BENCHMARK_LIST="geomath geometry3k"
STRATAGE_LIST="zero-shot zero-shot"

for i in "${!DEVICE_IDS[@]}"; do
    CUDA_VISIBLE_DEVICES=${DEVICE_IDS[$i]} python eval/eval_by_vllm_for_open_source.py \
        --batch_size $BATCH_SIZE \
        --model_name_or_path ${MODEL_NAME_OR_PATH_LIST[$i]} \
        --benchmark_list $BENCHMARK_LIST \
        --stratage_list $STRATAGE_LIST &
done

wait
echo "All tasks finished."
--------------------------------------------------------------------------------
/scripts/eval/open_source_models/multi_gpu_eval/eval_by_vllm_task3_ans_sft_multi_gpu.sh:
--------------------------------------------------------------------------------
#!/bin/bash
BATCH_SIZE=32

DEVICE_IDS=(0 1 2 3 4 5 6 7)
MODEL_NAME_OR_PATH_LIST=(
    "/path/to/your/checkpoint_0/"
    "/path/to/your/checkpoint_1/"
    "/path/to/your/checkpoint_2/"
    "/path/to/your/checkpoint_3/"
    "/path/to/your/checkpoint_4/"
    "/path/to/your/checkpoint_5/"
    "/path/to/your/checkpoint_6/"
    "/path/to/your/checkpoint_7/"
)

BENCHMARK_LIST="trance trance-left trance-right"
STRATAGE_LIST="sft sft sft"

for i in "${!DEVICE_IDS[@]}"; do
    CUDA_VISIBLE_DEVICES=${DEVICE_IDS[$i]} python eval/eval_by_vllm_for_open_source.py \
        --batch_size $BATCH_SIZE \
        --model_name_or_path ${MODEL_NAME_OR_PATH_LIST[$i]} \
        --benchmark_list $BENCHMARK_LIST \
        --stratage_list $STRATAGE_LIST &
done

wait
echo "All tasks finished."
--------------------------------------------------------------------------------
/scripts/eval/open_source_models/multi_gpu_eval/eval_by_vllm_task3_cot_sft_multi_gpu.sh:
--------------------------------------------------------------------------------
#!/bin/bash
BATCH_SIZE=32

DEVICE_IDS=(0 1 2 3 4 5 6 7)
MODEL_NAME_OR_PATH_LIST=(
    "/path/to/your/checkpoint_0/"
    "/path/to/your/checkpoint_1/"
    "/path/to/your/checkpoint_2/"
    "/path/to/your/checkpoint_3/"
    "/path/to/your/checkpoint_4/"
    "/path/to/your/checkpoint_5/"
    "/path/to/your/checkpoint_6/"
    "/path/to/your/checkpoint_7/"
)

BENCHMARK_LIST="trance trance-left trance-right"
STRATAGE_LIST="cot-sft cot-sft cot-sft"

for i in "${!DEVICE_IDS[@]}"; do
    CUDA_VISIBLE_DEVICES=${DEVICE_IDS[$i]} python eval/eval_by_vllm_for_open_source.py \
        --batch_size $BATCH_SIZE \
        --model_name_or_path ${MODEL_NAME_OR_PATH_LIST[$i]} \
        --benchmark_list $BENCHMARK_LIST \
        --stratage_list $STRATAGE_LIST &
done

wait
echo "All tasks finished."
--------------------------------------------------------------------------------
/scripts/eval/open_source_models/multi_gpu_eval/eval_by_vllm_task3_reason_rft_multi_gpu.sh:
--------------------------------------------------------------------------------
#!/bin/bash
BATCH_SIZE=32

DEVICE_IDS=(0 1 2 3 4 5 6 7)
MODEL_NAME_OR_PATH_LIST=(
    "/path/to/your/checkpoint_0/"
    "/path/to/your/checkpoint_1/"
    "/path/to/your/checkpoint_2/"
    "/path/to/your/checkpoint_3/"
    "/path/to/your/checkpoint_4/"
    "/path/to/your/checkpoint_5/"
    "/path/to/your/checkpoint_6/"
    "/path/to/your/checkpoint_7/"
)

BENCHMARK_LIST="trance trance-left trance-right"
STRATAGE_LIST="cot-sft cot-sft cot-sft"

for i in "${!DEVICE_IDS[@]}"; do
    CUDA_VISIBLE_DEVICES=${DEVICE_IDS[$i]} python eval/eval_by_vllm_for_open_source.py \
        --batch_size $BATCH_SIZE \
        --model_name_or_path ${MODEL_NAME_OR_PATH_LIST[$i]} \
        --benchmark_list $BENCHMARK_LIST \
        --stratage_list $STRATAGE_LIST &
done

wait
echo "All tasks finished."
--------------------------------------------------------------------------------
/scripts/eval/open_source_models/multi_gpu_eval/eval_by_vllm_task3_zero_shot_multi_gpu.sh:
--------------------------------------------------------------------------------
#!/bin/bash
BATCH_SIZE=32

DEVICE_IDS=(0 1 2 3 4 5 6 7)
MODEL_NAME_OR_PATH_LIST=(
    "/path/to/your/checkpoint_0/"
    "/path/to/your/checkpoint_1/"
    "/path/to/your/checkpoint_2/"
    "/path/to/your/checkpoint_3/"
    "/path/to/your/checkpoint_4/"
    "/path/to/your/checkpoint_5/"
    "/path/to/your/checkpoint_6/"
    "/path/to/your/checkpoint_7/"
)

BENCHMARK_LIST="trance trance-left trance-right"
STRATAGE_LIST="zero-shot zero-shot zero-shot"

for i in "${!DEVICE_IDS[@]}"; do
    CUDA_VISIBLE_DEVICES=${DEVICE_IDS[$i]} python eval/eval_by_vllm_for_open_source.py \
        --batch_size $BATCH_SIZE \
        --model_name_or_path ${MODEL_NAME_OR_PATH_LIST[$i]} \
        --benchmark_list $BENCHMARK_LIST \
        --stratage_list $STRATAGE_LIST &
done

wait
echo "All tasks finished."
--------------------------------------------------------------------------------
/scripts/eval/open_source_models/single_gpu_eval/eval_by_vllm_all_tasks_ans_sft_single_gpu.sh:
--------------------------------------------------------------------------------
#!/bin/bash

DEVICE_ID=0
BATCH_SIZE=32

MODEL_NAME_OR_PATH=/path/to/your/checkpoint/
BENCHMARK_LIST="trance trance-left trance-right clevr-math super-clevr geomath geometry3k"
STRATAGE_LIST="sft sft sft sft sft sft sft"

CUDA_VISIBLE_DEVICES=$DEVICE_ID python eval/eval_by_vllm_for_open_source.py \
    --batch_size $BATCH_SIZE \
    --model_name_or_path $MODEL_NAME_OR_PATH \
    --benchmark_list $BENCHMARK_LIST \
    --stratage_list $STRATAGE_LIST
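The single-GPU variants drop the loop and evaluate one checkpoint on one device. BENCHMARK_LIST and STRATAGE_LIST (the spelling follows the --stratage_list flag of eval_by_vllm_for_open_source.py) are kept the same length in every script, which suggests they are paired positionally; a single-benchmark invocation would then plausibly look like:

    CUDA_VISIBLE_DEVICES=0 python eval/eval_by_vllm_for_open_source.py \
        --batch_size 32 \
        --model_name_or_path /path/to/your/checkpoint/ \
        --benchmark_list clevr-math \
        --stratage_list sft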
--------------------------------------------------------------------------------
/scripts/eval/open_source_models/single_gpu_eval/eval_by_vllm_all_tasks_cot_sft_single_gpu.sh:
--------------------------------------------------------------------------------
#!/bin/bash

DEVICE_ID=0
BATCH_SIZE=32

MODEL_NAME_OR_PATH=/path/to/your/checkpoint/
BENCHMARK_LIST="trance trance-left trance-right clevr-math super-clevr geomath geometry3k"
STRATAGE_LIST="cot-sft cot-sft cot-sft cot-sft cot-sft cot-sft cot-sft"

CUDA_VISIBLE_DEVICES=$DEVICE_ID python eval/eval_by_vllm_for_open_source.py \
    --batch_size $BATCH_SIZE \
    --model_name_or_path $MODEL_NAME_OR_PATH \
    --benchmark_list $BENCHMARK_LIST \
    --stratage_list $STRATAGE_LIST
--------------------------------------------------------------------------------
/scripts/eval/open_source_models/single_gpu_eval/eval_by_vllm_all_tasks_reason_rft_single_gpu.sh:
--------------------------------------------------------------------------------
#!/bin/bash

DEVICE_ID=0
BATCH_SIZE=32

MODEL_NAME_OR_PATH=/path/to/your/checkpoint/
BENCHMARK_LIST="trance trance-left trance-right clevr-math super-clevr geomath geometry3k"
STRATAGE_LIST="cot-sft cot-sft cot-sft cot-sft cot-sft cot-sft cot-sft"

CUDA_VISIBLE_DEVICES=$DEVICE_ID python eval/eval_by_vllm_for_open_source.py \
    --batch_size $BATCH_SIZE \
    --model_name_or_path $MODEL_NAME_OR_PATH \
    --benchmark_list $BENCHMARK_LIST \
    --stratage_list $STRATAGE_LIST
--------------------------------------------------------------------------------
/scripts/eval/open_source_models/single_gpu_eval/eval_by_vllm_all_tasks_zero_shot_single_gpu.sh:
--------------------------------------------------------------------------------
#!/bin/bash

DEVICE_ID=0
BATCH_SIZE=32

MODEL_NAME_OR_PATH=/path/to/your/checkpoint/
BENCHMARK_LIST="trance trance-left trance-right clevr-math super-clevr geomath geometry3k"
STRATAGE_LIST="zero-shot zero-shot zero-shot zero-shot zero-shot zero-shot zero-shot"

CUDA_VISIBLE_DEVICES=$DEVICE_ID python eval/eval_by_vllm_for_open_source.py \
    --batch_size $BATCH_SIZE \
    --model_name_or_path $MODEL_NAME_OR_PATH \
    --benchmark_list $BENCHMARK_LIST \
    --stratage_list $STRATAGE_LIST
--------------------------------------------------------------------------------
/scripts/eval/open_source_models/single_gpu_eval/eval_by_vllm_task1_ans_sft_single_gpu.sh:
--------------------------------------------------------------------------------
#!/bin/bash

DEVICE_ID=0
BATCH_SIZE=32

MODEL_NAME_OR_PATH=/path/to/your/checkpoint/
BENCHMARK_LIST="clevr-math super-clevr"
STRATAGE_LIST="sft sft"

CUDA_VISIBLE_DEVICES=$DEVICE_ID python eval/eval_by_vllm_for_open_source.py \
    --batch_size $BATCH_SIZE \
    --model_name_or_path $MODEL_NAME_OR_PATH \
    --benchmark_list $BENCHMARK_LIST \
    --stratage_list $STRATAGE_LIST
--------------------------------------------------------------------------------
/scripts/eval/open_source_models/single_gpu_eval/eval_by_vllm_task1_cot_sft_single_gpu.sh:
--------------------------------------------------------------------------------
#!/bin/bash

DEVICE_ID=0
BATCH_SIZE=32

MODEL_NAME_OR_PATH=/path/to/your/checkpoint/
BENCHMARK_LIST="clevr-math super-clevr"
STRATAGE_LIST="cot-sft cot-sft"

CUDA_VISIBLE_DEVICES=$DEVICE_ID python eval/eval_by_vllm_for_open_source.py \
    --batch_size $BATCH_SIZE \
    --model_name_or_path $MODEL_NAME_OR_PATH \
    --benchmark_list $BENCHMARK_LIST \
    --stratage_list $STRATAGE_LIST
--------------------------------------------------------------------------------
/scripts/eval/open_source_models/single_gpu_eval/eval_by_vllm_task1_reason_rft_single_gpu.sh:
--------------------------------------------------------------------------------
#!/bin/bash

DEVICE_ID=0
BATCH_SIZE=32

MODEL_NAME_OR_PATH=/path/to/your/checkpoint/
BENCHMARK_LIST="clevr-math super-clevr"
STRATAGE_LIST="cot-sft cot-sft"

CUDA_VISIBLE_DEVICES=$DEVICE_ID python eval/eval_by_vllm_for_open_source.py \
    --batch_size $BATCH_SIZE \
    --model_name_or_path $MODEL_NAME_OR_PATH \
    --benchmark_list $BENCHMARK_LIST \
    --stratage_list $STRATAGE_LIST
--------------------------------------------------------------------------------
/scripts/eval/open_source_models/single_gpu_eval/eval_by_vllm_task1_zero_shot_single_gpu.sh:
--------------------------------------------------------------------------------
#!/bin/bash

DEVICE_ID=0
BATCH_SIZE=32

MODEL_NAME_OR_PATH=/path/to/your/checkpoint/
BENCHMARK_LIST="clevr-math super-clevr"
STRATAGE_LIST="zero-shot zero-shot"

CUDA_VISIBLE_DEVICES=$DEVICE_ID python eval/eval_by_vllm_for_open_source.py \
    --batch_size $BATCH_SIZE \
    --model_name_or_path $MODEL_NAME_OR_PATH \
    --benchmark_list $BENCHMARK_LIST \
    --stratage_list $STRATAGE_LIST
--------------------------------------------------------------------------------
/scripts/eval/open_source_models/single_gpu_eval/eval_by_vllm_task2_ans_sft_single_gpu.sh:
--------------------------------------------------------------------------------
#!/bin/bash

DEVICE_ID=0
BATCH_SIZE=32

MODEL_NAME_OR_PATH=/path/to/your/checkpoint/
BENCHMARK_LIST="geomath geometry3k"
STRATAGE_LIST="sft sft"

CUDA_VISIBLE_DEVICES=$DEVICE_ID python eval/eval_by_vllm_for_open_source.py \
    --batch_size $BATCH_SIZE \
    --model_name_or_path $MODEL_NAME_OR_PATH \
    --benchmark_list $BENCHMARK_LIST \
    --stratage_list $STRATAGE_LIST
--------------------------------------------------------------------------------
/scripts/eval/open_source_models/single_gpu_eval/eval_by_vllm_task2_cot_sft_single_gpu.sh:
--------------------------------------------------------------------------------
#!/bin/bash

DEVICE_ID=0
BATCH_SIZE=32

MODEL_NAME_OR_PATH=/path/to/your/checkpoint/
BENCHMARK_LIST="geomath geometry3k"
STRATAGE_LIST="cot-sft cot-sft"

CUDA_VISIBLE_DEVICES=$DEVICE_ID python eval/eval_by_vllm_for_open_source.py \
    --batch_size $BATCH_SIZE \
    --model_name_or_path $MODEL_NAME_OR_PATH \
    --benchmark_list $BENCHMARK_LIST \
    --stratage_list $STRATAGE_LIST
--------------------------------------------------------------------------------
/scripts/eval/open_source_models/single_gpu_eval/eval_by_vllm_task2_reason_rft_single_gpu.sh:
--------------------------------------------------------------------------------
#!/bin/bash

DEVICE_ID=0
BATCH_SIZE=32

MODEL_NAME_OR_PATH=/path/to/your/checkpoint/
BENCHMARK_LIST="geomath geometry3k"
STRATAGE_LIST="cot-sft cot-sft"

CUDA_VISIBLE_DEVICES=$DEVICE_ID python eval/eval_by_vllm_for_open_source.py \
    --batch_size $BATCH_SIZE \
    --model_name_or_path $MODEL_NAME_OR_PATH \
    --benchmark_list $BENCHMARK_LIST \
    --stratage_list $STRATAGE_LIST
--------------------------------------------------------------------------------
/scripts/eval/open_source_models/single_gpu_eval/eval_by_vllm_task2_zero_shot_single_gpu.sh:
--------------------------------------------------------------------------------
#!/bin/bash

DEVICE_ID=0
BATCH_SIZE=32

MODEL_NAME_OR_PATH=/path/to/your/checkpoint/
BENCHMARK_LIST="geomath geometry3k"
STRATAGE_LIST="zero-shot zero-shot"

CUDA_VISIBLE_DEVICES=$DEVICE_ID python eval/eval_by_vllm_for_open_source.py \
    --batch_size $BATCH_SIZE \
    --model_name_or_path $MODEL_NAME_OR_PATH \
    --benchmark_list $BENCHMARK_LIST \
    --stratage_list $STRATAGE_LIST
--------------------------------------------------------------------------------
/scripts/eval/open_source_models/single_gpu_eval/eval_by_vllm_task3_ans_sft_single_gpu.sh:
--------------------------------------------------------------------------------
#!/bin/bash

DEVICE_ID=0
BATCH_SIZE=32

MODEL_NAME_OR_PATH=/path/to/your/checkpoint/
BENCHMARK_LIST="trance trance-left trance-right"
STRATAGE_LIST="sft sft sft"

CUDA_VISIBLE_DEVICES=$DEVICE_ID python eval/eval_by_vllm_for_open_source.py \
    --batch_size $BATCH_SIZE \
    --model_name_or_path $MODEL_NAME_OR_PATH \
    --benchmark_list $BENCHMARK_LIST \
    --stratage_list $STRATAGE_LIST
--------------------------------------------------------------------------------
/scripts/eval/open_source_models/single_gpu_eval/eval_by_vllm_task3_cot_sft_single_gpu.sh:
--------------------------------------------------------------------------------
#!/bin/bash

DEVICE_ID=0
BATCH_SIZE=32

MODEL_NAME_OR_PATH=/path/to/your/checkpoint/
BENCHMARK_LIST="trance trance-left trance-right"
STRATAGE_LIST="cot-sft cot-sft cot-sft"

CUDA_VISIBLE_DEVICES=$DEVICE_ID python eval/eval_by_vllm_for_open_source.py \
    --batch_size $BATCH_SIZE \
    --model_name_or_path $MODEL_NAME_OR_PATH \
    --benchmark_list $BENCHMARK_LIST \
    --stratage_list $STRATAGE_LIST
--------------------------------------------------------------------------------
/scripts/eval/open_source_models/single_gpu_eval/eval_by_vllm_task3_reason_rft_single_gpu.sh:
--------------------------------------------------------------------------------
#!/bin/bash

DEVICE_ID=0
BATCH_SIZE=32

MODEL_NAME_OR_PATH=/path/to/your/checkpoint/
BENCHMARK_LIST="trance trance-left trance-right"
STRATAGE_LIST="cot-sft cot-sft cot-sft"

CUDA_VISIBLE_DEVICES=$DEVICE_ID python eval/eval_by_vllm_for_open_source.py \
    --batch_size $BATCH_SIZE \
    --model_name_or_path $MODEL_NAME_OR_PATH \
    --benchmark_list $BENCHMARK_LIST \
    --stratage_list $STRATAGE_LIST
--------------------------------------------------------------------------------
/scripts/eval/open_source_models/single_gpu_eval/eval_by_vllm_task3_zero_shot_single_gpu.sh:
--------------------------------------------------------------------------------
#!/bin/bash

DEVICE_ID=0
BATCH_SIZE=32

MODEL_NAME_OR_PATH=/path/to/your/checkpoint/
BENCHMARK_LIST="trance trance-left trance-right"
STRATAGE_LIST="zero-shot zero-shot zero-shot"

CUDA_VISIBLE_DEVICES=$DEVICE_ID python eval/eval_by_vllm_for_open_source.py \
    --batch_size $BATCH_SIZE \
    --model_name_or_path $MODEL_NAME_OR_PATH \
    --benchmark_list $BENCHMARK_LIST \
    --stratage_list $STRATAGE_LIST
-d "$OUTPUT_PATH" ]; then 14 | mkdir "$OUTPUT_PATH" 15 | fi 16 | 17 | torchrun --nproc_per_node=8 --nnodes=1 --master_port=29514 \ 18 | train/stage_sft/train.py \ 19 | --deepspeed scripts/train/zero3.json \ 20 | --stage sft \ 21 | --do_train \ 22 | --model_name_or_path $PRETRAIN_MODEL_PATH \ 23 | --dataset $DATASET \ 24 | --image_dir $IMAGE_DIR \ 25 | --template qwen2_vl \ 26 | --finetuning_type full \ 27 | --output_dir $OUTPUT_PATH \ 28 | --overwrite_cache \ 29 | --overwrite_output_dir \ 30 | --warmup_steps 100 \ 31 | --weight_decay 0.1 \ 32 | --per_device_train_batch_size 1 \ 33 | --gradient_accumulation_steps 2 \ 34 | --ddp_timeout 90000 \ 35 | --learning_rate 1e-5 \ 36 | --lr_scheduler_type cosine \ 37 | --logging_steps 5 \ 38 | --cutoff_len 4096 \ 39 | --save_steps 200 \ 40 | --plot_loss \ 41 | --num_train_epochs 1 \ 42 | --bf16 \ 43 | 2>&1 | tee ${OUTPUT_DIR}/train.log -------------------------------------------------------------------------------- /scripts/train/ans_sft/resume_finetune_qwen2vl_2b_task2_ans_sft.sh: -------------------------------------------------------------------------------- 1 | conda activate reasonrft_sft 2 | 3 | export PYTHONPATH=$(pwd)/train/stage_sft 4 | 5 | export WANDB_MODE=offline 6 | export ACCELERATE_CPU_AFFINITY=1 7 | 8 | export IMAGE_DIR=/path/to/your/train_images/ 9 | export PRETRAIN_MODEL_PATH=/path/to/your/pretrain_model/Qwen2-VL-2B-Instruct 10 | export OUTPUT_PATH=/path/to/your/checkpoints/qwen2vl_2b_task2_ans_sft 11 | export DATASET=geo_math_sft 12 | 13 | if [ ! -d "$OUTPUT_PATH" ]; then 14 | mkdir "$OUTPUT_PATH" 15 | fi 16 | 17 | torchrun --nproc_per_node=8 --nnodes=1 --master_port=29514 \ 18 | train/stage_sft/train.py \ 19 | --deepspeed scripts/train/zero3.json \ 20 | --stage sft \ 21 | --do_train \ 22 | --model_name_or_path $PRETRAIN_MODEL_PATH \ 23 | --dataset $DATASET \ 24 | --image_dir $IMAGE_DIR \ 25 | --template qwen2_vl \ 26 | --finetuning_type full \ 27 | --output_dir $OUTPUT_PATH \ 28 | --overwrite_cache \ 29 | --overwrite_output_dir \ 30 | --warmup_steps 100 \ 31 | --weight_decay 0.1 \ 32 | --per_device_train_batch_size 1 \ 33 | --gradient_accumulation_steps 2 \ 34 | --ddp_timeout 90000 \ 35 | --learning_rate 1e-5 \ 36 | --lr_scheduler_type cosine \ 37 | --logging_steps 5 \ 38 | --cutoff_len 4096 \ 39 | --save_steps 200 \ 40 | --plot_loss \ 41 | --num_train_epochs 1 \ 42 | --bf16 \ 43 | 2>&1 | tee ${OUTPUT_DIR}/train.log -------------------------------------------------------------------------------- /scripts/train/ans_sft/resume_finetune_qwen2vl_2b_task3_ans_sft.sh: -------------------------------------------------------------------------------- 1 | conda activate reasonrft_sft 2 | 3 | export PYTHONPATH=$(pwd)/train/stage_sft 4 | 5 | export WANDB_MODE=offline 6 | export ACCELERATE_CPU_AFFINITY=1 7 | 8 | export IMAGE_DIR=/path/to/your/train_images/ 9 | export PRETRAIN_MODEL_PATH=/path/to/your/pretrain_model/Qwen2-VL-2B-Instruct 10 | export OUTPUT_PATH=/path/to/your/checkpoints/qwen2vl_2b_task3_ans_sft 11 | export DATASET=trance_sft 12 | 13 | if [ ! 
-d "$OUTPUT_PATH" ]; then 14 | mkdir "$OUTPUT_PATH" 15 | fi 16 | 17 | torchrun --nproc_per_node=8 --nnodes=1 --master_port=29514 \ 18 | train/stage_sft/train.py \ 19 | --deepspeed scripts/train/zero3.json \ 20 | --stage sft \ 21 | --do_train \ 22 | --model_name_or_path $PRETRAIN_MODEL_PATH \ 23 | --dataset $DATASET \ 24 | --image_dir $IMAGE_DIR \ 25 | --template qwen2_vl \ 26 | --finetuning_type full \ 27 | --output_dir $OUTPUT_PATH \ 28 | --overwrite_cache \ 29 | --overwrite_output_dir \ 30 | --warmup_steps 100 \ 31 | --weight_decay 0.1 \ 32 | --per_device_train_batch_size 1 \ 33 | --gradient_accumulation_steps 2 \ 34 | --ddp_timeout 90000 \ 35 | --learning_rate 1e-5 \ 36 | --lr_scheduler_type cosine \ 37 | --logging_steps 5 \ 38 | --cutoff_len 4096 \ 39 | --save_steps 200 \ 40 | --plot_loss \ 41 | --num_train_epochs 1 \ 42 | --bf16 \ 43 | 2>&1 | tee ${OUTPUT_DIR}/train.log -------------------------------------------------------------------------------- /scripts/train/ans_sft/resume_finetune_qwen2vl_7b_task1_ans_sft.sh: -------------------------------------------------------------------------------- 1 | conda activate reasonrft_sft 2 | 3 | export PYTHONPATH=$(pwd)/train/stage_sft 4 | 5 | export WANDB_MODE=offline 6 | export ACCELERATE_CPU_AFFINITY=1 7 | 8 | export IMAGE_DIR=/path/to/your/train_images/ 9 | export PRETRAIN_MODEL_PATH=/path/to/your/pretrain_model/Qwen2-VL-7B-Instruct 10 | export OUTPUT_PATH=/path/to/your/checkpoints/qwen2vl_7b_task1_ans_sft 11 | export DATASET=clevr_math_sft 12 | 13 | if [ ! -d "$OUTPUT_PATH" ]; then 14 | mkdir "$OUTPUT_PATH" 15 | fi 16 | 17 | torchrun --nproc_per_node=8 --nnodes=1 --master_port=29514 \ 18 | train/stage_sft/train.py \ 19 | --deepspeed scripts/train/zero3.json \ 20 | --stage sft \ 21 | --do_train \ 22 | --model_name_or_path $PRETRAIN_MODEL_PATH \ 23 | --dataset $DATASET \ 24 | --image_dir $IMAGE_DIR \ 25 | --template qwen2_vl \ 26 | --finetuning_type full \ 27 | --output_dir $OUTPUT_PATH \ 28 | --overwrite_cache \ 29 | --overwrite_output_dir \ 30 | --warmup_steps 100 \ 31 | --weight_decay 0.1 \ 32 | --per_device_train_batch_size 1 \ 33 | --gradient_accumulation_steps 2 \ 34 | --ddp_timeout 90000 \ 35 | --learning_rate 1e-5 \ 36 | --lr_scheduler_type cosine \ 37 | --logging_steps 5 \ 38 | --cutoff_len 4096 \ 39 | --save_steps 200 \ 40 | --plot_loss \ 41 | --num_train_epochs 1 \ 42 | --bf16 \ 43 | 2>&1 | tee ${OUTPUT_DIR}/train.log -------------------------------------------------------------------------------- /scripts/train/ans_sft/resume_finetune_qwen2vl_7b_task2_ans_sft.sh: -------------------------------------------------------------------------------- 1 | conda activate reasonrft_sft 2 | 3 | export PYTHONPATH=$(pwd)/train/stage_sft 4 | 5 | export WANDB_MODE=offline 6 | export ACCELERATE_CPU_AFFINITY=1 7 | 8 | export IMAGE_DIR=/path/to/your/train_images/ 9 | export PRETRAIN_MODEL_PATH=/path/to/your/pretrain_model/Qwen2-VL-7B-Instruct 10 | export OUTPUT_PATH=/path/to/your/checkpoints/qwen2vl_7b_task2_ans_sft 11 | export DATASET=geo_math_sft 12 | 13 | if [ ! 
-d "$OUTPUT_PATH" ]; then 14 | mkdir "$OUTPUT_PATH" 15 | fi 16 | 17 | torchrun --nproc_per_node=8 --nnodes=1 --master_port=29514 \ 18 | train/stage_sft/train.py \ 19 | --deepspeed scripts/train/zero3.json \ 20 | --stage sft \ 21 | --do_train \ 22 | --model_name_or_path $PRETRAIN_MODEL_PATH \ 23 | --dataset $DATASET \ 24 | --image_dir $IMAGE_DIR \ 25 | --template qwen2_vl \ 26 | --finetuning_type full \ 27 | --output_dir $OUTPUT_PATH \ 28 | --overwrite_cache \ 29 | --overwrite_output_dir \ 30 | --warmup_steps 100 \ 31 | --weight_decay 0.1 \ 32 | --per_device_train_batch_size 1 \ 33 | --gradient_accumulation_steps 2 \ 34 | --ddp_timeout 90000 \ 35 | --learning_rate 1e-5 \ 36 | --lr_scheduler_type cosine \ 37 | --logging_steps 5 \ 38 | --cutoff_len 4096 \ 39 | --save_steps 200 \ 40 | --plot_loss \ 41 | --num_train_epochs 1 \ 42 | --bf16 \ 43 | 2>&1 | tee ${OUTPUT_DIR}/train.log -------------------------------------------------------------------------------- /scripts/train/ans_sft/resume_finetune_qwen2vl_7b_task3_ans_sft.sh: -------------------------------------------------------------------------------- 1 | conda activate reasonrft_sft 2 | 3 | export PYTHONPATH=$(pwd)/train/stage_sft 4 | 5 | export WANDB_MODE=offline 6 | export ACCELERATE_CPU_AFFINITY=1 7 | 8 | export IMAGE_DIR=/path/to/your/train_images/ 9 | export PRETRAIN_MODEL_PATH=/path/to/your/pretrain_model/Qwen2-VL-7B-Instruct 10 | export OUTPUT_PATH=/path/to/your/checkpoints/qwen2vl_7b_task3_ans_sft 11 | export DATASET=trance_sft 12 | 13 | if [ ! -d "$OUTPUT_PATH" ]; then 14 | mkdir "$OUTPUT_PATH" 15 | fi 16 | 17 | torchrun --nproc_per_node=8 --nnodes=1 --master_port=29514 \ 18 | train/stage_sft/train.py \ 19 | --deepspeed scripts/train/zero3.json \ 20 | --stage sft \ 21 | --do_train \ 22 | --model_name_or_path $PRETRAIN_MODEL_PATH \ 23 | --dataset $DATASET \ 24 | --image_dir $IMAGE_DIR \ 25 | --template qwen2_vl \ 26 | --finetuning_type full \ 27 | --output_dir $OUTPUT_PATH \ 28 | --overwrite_cache \ 29 | --overwrite_output_dir \ 30 | --warmup_steps 100 \ 31 | --weight_decay 0.1 \ 32 | --per_device_train_batch_size 1 \ 33 | --gradient_accumulation_steps 2 \ 34 | --ddp_timeout 90000 \ 35 | --learning_rate 1e-5 \ 36 | --lr_scheduler_type cosine \ 37 | --logging_steps 5 \ 38 | --cutoff_len 4096 \ 39 | --save_steps 200 \ 40 | --plot_loss \ 41 | --num_train_epochs 1 \ 42 | --bf16 \ 43 | 2>&1 | tee ${OUTPUT_DIR}/train.log -------------------------------------------------------------------------------- /scripts/train/cot_sft/resume_finetune_qwen2vl_2b_task1_cot_sft.sh: -------------------------------------------------------------------------------- 1 | conda activate reasonrft_sft 2 | 3 | export PYTHONPATH=$(pwd)/train/stage_sft 4 | 5 | export WANDB_MODE=offline 6 | export ACCELERATE_CPU_AFFINITY=1 7 | 8 | export IMAGE_DIR=/path/to/your/train_images/ 9 | export PRETRAIN_MODEL_PATH=/path/to/your/pretrain_model/Qwen2-VL-2B-Instruct 10 | export OUTPUT_PATH=/path/to/your/checkpoints/qwen2vl_2b_task1_cot_sft 11 | export DATASET=clevr_math_cot_sft 12 | 13 | if [ ! 
-d "$OUTPUT_PATH" ]; then 14 | mkdir "$OUTPUT_PATH" 15 | fi 16 | 17 | torchrun --nproc_per_node=8 --nnodes=1 --master_port=29514 \ 18 | train/stage_sft/train.py \ 19 | --deepspeed scripts/train/zero3.json \ 20 | --stage sft \ 21 | --do_train \ 22 | --model_name_or_path $PRETRAIN_MODEL_PATH \ 23 | --dataset $DATASET \ 24 | --image_dir $IMAGE_DIR \ 25 | --template qwen2_vl \ 26 | --finetuning_type full \ 27 | --output_dir $OUTPUT_PATH \ 28 | --overwrite_cache \ 29 | --overwrite_output_dir \ 30 | --warmup_steps 100 \ 31 | --weight_decay 0.1 \ 32 | --per_device_train_batch_size 1 \ 33 | --gradient_accumulation_steps 2 \ 34 | --ddp_timeout 90000 \ 35 | --learning_rate 1e-5 \ 36 | --lr_scheduler_type cosine \ 37 | --logging_steps 5 \ 38 | --cutoff_len 4096 \ 39 | --save_steps 200 \ 40 | --plot_loss \ 41 | --num_train_epochs 1 \ 42 | --bf16 \ 43 | 2>&1 | tee ${OUTPUT_DIR}/train.log -------------------------------------------------------------------------------- /scripts/train/cot_sft/resume_finetune_qwen2vl_2b_task2_cot_sft.sh: -------------------------------------------------------------------------------- 1 | conda activate reasonrft_sft 2 | 3 | export PYTHONPATH=$(pwd)/train/stage_sft 4 | 5 | export WANDB_MODE=offline 6 | export ACCELERATE_CPU_AFFINITY=1 7 | 8 | export IMAGE_DIR=/path/to/your/train_images/ 9 | export PRETRAIN_MODEL_PATH=/path/to/your/pretrain_model/Qwen2-VL-2B-Instruct 10 | export OUTPUT_PATH=/path/to/your/checkpoints/qwen2vl_2b_task2_cot_sft 11 | export DATASET=geo_math_cot_sft 12 | 13 | if [ ! -d "$OUTPUT_PATH" ]; then 14 | mkdir "$OUTPUT_PATH" 15 | fi 16 | 17 | torchrun --nproc_per_node=8 --nnodes=1 --master_port=29514 \ 18 | train/stage_sft/train.py \ 19 | --deepspeed scripts/train/zero3.json \ 20 | --stage sft \ 21 | --do_train \ 22 | --model_name_or_path $PRETRAIN_MODEL_PATH \ 23 | --dataset $DATASET \ 24 | --image_dir $IMAGE_DIR \ 25 | --template qwen2_vl \ 26 | --finetuning_type full \ 27 | --output_dir $OUTPUT_PATH \ 28 | --overwrite_cache \ 29 | --overwrite_output_dir \ 30 | --warmup_steps 100 \ 31 | --weight_decay 0.1 \ 32 | --per_device_train_batch_size 1 \ 33 | --gradient_accumulation_steps 2 \ 34 | --ddp_timeout 90000 \ 35 | --learning_rate 1e-5 \ 36 | --lr_scheduler_type cosine \ 37 | --logging_steps 5 \ 38 | --cutoff_len 4096 \ 39 | --save_steps 200 \ 40 | --plot_loss \ 41 | --num_train_epochs 1 \ 42 | --bf16 \ 43 | 2>&1 | tee ${OUTPUT_DIR}/train.log -------------------------------------------------------------------------------- /scripts/train/cot_sft/resume_finetune_qwen2vl_2b_task3_cot_sft.sh: -------------------------------------------------------------------------------- 1 | conda activate reasonrft_sft 2 | 3 | export PYTHONPATH=$(pwd)/train/stage_sft 4 | 5 | export WANDB_MODE=offline 6 | export ACCELERATE_CPU_AFFINITY=1 7 | 8 | export IMAGE_DIR=/path/to/your/train_images/ 9 | export PRETRAIN_MODEL_PATH=/path/to/your/pretrain_model/Qwen2-VL-2B-Instruct 10 | export OUTPUT_PATH=/path/to/your/checkpoints/qwen2vl_2b_task3_cot_sft 11 | export DATASET=trance_cot_sft 12 | 13 | if [ ! 
-d "$OUTPUT_PATH" ]; then 14 | mkdir "$OUTPUT_PATH" 15 | fi 16 | 17 | torchrun --nproc_per_node=8 --nnodes=1 --master_port=29514 \ 18 | train/stage_sft/train.py \ 19 | --deepspeed scripts/train/zero3.json \ 20 | --stage sft \ 21 | --do_train \ 22 | --model_name_or_path $PRETRAIN_MODEL_PATH \ 23 | --dataset $DATASET \ 24 | --image_dir $IMAGE_DIR \ 25 | --template qwen2_vl \ 26 | --finetuning_type full \ 27 | --output_dir $OUTPUT_PATH \ 28 | --overwrite_cache \ 29 | --overwrite_output_dir \ 30 | --warmup_steps 100 \ 31 | --weight_decay 0.1 \ 32 | --per_device_train_batch_size 1 \ 33 | --gradient_accumulation_steps 2 \ 34 | --ddp_timeout 90000 \ 35 | --learning_rate 1e-5 \ 36 | --lr_scheduler_type cosine \ 37 | --logging_steps 5 \ 38 | --cutoff_len 4096 \ 39 | --save_steps 200 \ 40 | --plot_loss \ 41 | --num_train_epochs 1 \ 42 | --bf16 \ 43 | 2>&1 | tee ${OUTPUT_DIR}/train.log -------------------------------------------------------------------------------- /scripts/train/cot_sft/resume_finetune_qwen2vl_7b_task1_cot_sft.sh: -------------------------------------------------------------------------------- 1 | conda activate reasonrft_sft 2 | 3 | export PYTHONPATH=$(pwd)/train/stage_sft 4 | 5 | export WANDB_MODE=offline 6 | export ACCELERATE_CPU_AFFINITY=1 7 | 8 | export IMAGE_DIR=/path/to/your/train_images/ 9 | export PRETRAIN_MODEL_PATH=/path/to/your/pretrain_model/Qwen2-VL-7B-Instruct 10 | export OUTPUT_PATH=/path/to/your/checkpoints/qwen2vl_7b_task1_cot_sft 11 | export DATASET=clevr_math_cot_sft 12 | 13 | if [ ! -d "$OUTPUT_PATH" ]; then 14 | mkdir "$OUTPUT_PATH" 15 | fi 16 | 17 | torchrun --nproc_per_node=8 --nnodes=1 --master_port=29514 \ 18 | train/stage_sft/train.py \ 19 | --deepspeed scripts/train/zero3.json \ 20 | --stage sft \ 21 | --do_train \ 22 | --model_name_or_path $PRETRAIN_MODEL_PATH \ 23 | --dataset $DATASET \ 24 | --image_dir $IMAGE_DIR \ 25 | --template qwen2_vl \ 26 | --finetuning_type full \ 27 | --output_dir $OUTPUT_PATH \ 28 | --overwrite_cache \ 29 | --overwrite_output_dir \ 30 | --warmup_steps 100 \ 31 | --weight_decay 0.1 \ 32 | --per_device_train_batch_size 1 \ 33 | --gradient_accumulation_steps 2 \ 34 | --ddp_timeout 90000 \ 35 | --learning_rate 1e-5 \ 36 | --lr_scheduler_type cosine \ 37 | --logging_steps 5 \ 38 | --cutoff_len 4096 \ 39 | --save_steps 200 \ 40 | --plot_loss \ 41 | --num_train_epochs 1 \ 42 | --bf16 \ 43 | 2>&1 | tee ${OUTPUT_DIR}/train.log -------------------------------------------------------------------------------- /scripts/train/cot_sft/resume_finetune_qwen2vl_7b_task2_cot_sft.sh: -------------------------------------------------------------------------------- 1 | conda activate reasonrft_sft 2 | 3 | export PYTHONPATH=$(pwd)/train/stage_sft 4 | 5 | export WANDB_MODE=offline 6 | export ACCELERATE_CPU_AFFINITY=1 7 | 8 | export IMAGE_DIR=/path/to/your/train_images/ 9 | export PRETRAIN_MODEL_PATH=/path/to/your/pretrain_model/Qwen2-VL-7B-Instruct 10 | export OUTPUT_PATH=/path/to/your/checkpoints/qwen2vl_7b_task2_cot_sft 11 | export DATASET=geo_math_cot_sft 12 | 13 | if [ ! 
-d "$OUTPUT_PATH" ]; then 14 | mkdir "$OUTPUT_PATH" 15 | fi 16 | 17 | torchrun --nproc_per_node=8 --nnodes=1 --master_port=29514 \ 18 | train/stage_sft/train.py \ 19 | --deepspeed scripts/train/zero3.json \ 20 | --stage sft \ 21 | --do_train \ 22 | --model_name_or_path $PRETRAIN_MODEL_PATH \ 23 | --dataset $DATASET \ 24 | --image_dir $IMAGE_DIR \ 25 | --template qwen2_vl \ 26 | --finetuning_type full \ 27 | --output_dir $OUTPUT_PATH \ 28 | --overwrite_cache \ 29 | --overwrite_output_dir \ 30 | --warmup_steps 100 \ 31 | --weight_decay 0.1 \ 32 | --per_device_train_batch_size 1 \ 33 | --gradient_accumulation_steps 2 \ 34 | --ddp_timeout 90000 \ 35 | --learning_rate 1e-5 \ 36 | --lr_scheduler_type cosine \ 37 | --logging_steps 5 \ 38 | --cutoff_len 4096 \ 39 | --save_steps 200 \ 40 | --plot_loss \ 41 | --num_train_epochs 1 \ 42 | --bf16 \ 43 | 2>&1 | tee ${OUTPUT_DIR}/train.log -------------------------------------------------------------------------------- /scripts/train/cot_sft/resume_finetune_qwen2vl_7b_task3_cot_sft.sh: -------------------------------------------------------------------------------- 1 | conda activate reasonrft_sft 2 | 3 | export PYTHONPATH=$(pwd)/train/stage_sft 4 | 5 | export WANDB_MODE=offline 6 | export ACCELERATE_CPU_AFFINITY=1 7 | 8 | export IMAGE_DIR=/path/to/your/train_images/ 9 | export PRETRAIN_MODEL_PATH=/path/to/your/pretrain_model/Qwen2-VL-7B-Instruct 10 | export OUTPUT_PATH=/path/to/your/checkpoints/qwen2vl_7b_task3_cot_sft 11 | export DATASET=trance_cot_sft 12 | 13 | if [ ! -d "$OUTPUT_PATH" ]; then 14 | mkdir "$OUTPUT_PATH" 15 | fi 16 | 17 | torchrun --nproc_per_node=8 --nnodes=1 --master_port=29514 \ 18 | train/stage_sft/train.py \ 19 | --deepspeed scripts/train/zero3.json \ 20 | --stage sft \ 21 | --do_train \ 22 | --model_name_or_path $PRETRAIN_MODEL_PATH \ 23 | --dataset $DATASET \ 24 | --image_dir $IMAGE_DIR \ 25 | --template qwen2_vl \ 26 | --finetuning_type full \ 27 | --output_dir $OUTPUT_PATH \ 28 | --overwrite_cache \ 29 | --overwrite_output_dir \ 30 | --warmup_steps 100 \ 31 | --weight_decay 0.1 \ 32 | --per_device_train_batch_size 1 \ 33 | --gradient_accumulation_steps 2 \ 34 | --ddp_timeout 90000 \ 35 | --learning_rate 1e-5 \ 36 | --lr_scheduler_type cosine \ 37 | --logging_steps 5 \ 38 | --cutoff_len 4096 \ 39 | --save_steps 200 \ 40 | --plot_loss \ 41 | --num_train_epochs 1 \ 42 | --bf16 \ 43 | 2>&1 | tee ${OUTPUT_DIR}/train.log -------------------------------------------------------------------------------- /scripts/train/reason_rft/stage_rl/resume_finetune_qwen2vl_2b_task1_stage2_rl.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | conda activate reasonrft_rl 3 | export PYTHONPATH=$(pwd)/train 4 | # Wandb 5 | # export WANDB_MODE=disabled 6 | export WANDB_BASE_URL=https://api.wandb.ai 7 | export WANDB_PROJECT=vison-reason-rft 8 | export WANDB_API_KEY="8b05b6xxxxxf224" 9 | export WANDB_RUN_NAME=resume_finetune_qwen2vl_2b_task1_stage2_rl-$(date +%Y-%m-%d-%H-%M-%S) 10 | wandb login $WANDB_API_KEY 11 | 12 | # Dataset 13 | export TASK_NAME=clevr-math 14 | export DATASET_NAME=/path/to/your/dataset/clevr-math-train.json 15 | export IMAGE_PATH=/path/to/your/train_images/ 16 | export MODEL_NAME_OR_PATH=/path/to/your/checkpoints/from/stage1/ 17 | export OUTPUT_DIR=/path/to/your/checkpoints/${WANDB_RUN_NAME} 18 | 19 | if [ ! 
-d "$OUTPUT_DIR" ]; then 20 | mkdir "$OUTPUT_DIR" 21 | fi 22 | 23 | # Debug 24 | export DEBUG_MODE="True" 25 | export LOG_PATH=${OUTPUT_DIR}/reward.log 26 | 27 | torchrun --nproc_per_node=7 --nnodes=1 --master_port=29514 \ 28 | train/stage_rl/grpo.py \ 29 | --deepspeed scripts/train/zero3.json \ 30 | --output_dir ${OUTPUT_DIR} \ 31 | --model_name_or_path ${MODEL_NAME_OR_PATH} \ 32 | --dataset_name ${DATASET_NAME} \ 33 | --image_path ${IMAGE_PATH} \ 34 | --task_name ${TASK_NAME} \ 35 | --use_vllm_for_gen true \ 36 | --use_system_prompt false \ 37 | --max_prompt_length 4096 \ 38 | --max_completion_length 512 \ 39 | --num_generations 8 \ 40 | --per_device_train_batch_size 1 \ 41 | --gradient_accumulation_steps 2 \ 42 | --logging_steps 1 \ 43 | --bf16 \ 44 | --report_to wandb \ 45 | --gradient_checkpointing true \ 46 | --attn_implementation flash_attention_2 \ 47 | --max_pixels 480000 \ 48 | --save_steps 100 \ 49 | --num_train_epochs 1 \ 50 | --run_name ${WANDB_RUN_NAME} \ 51 | 2>&1 | tee ${OUTPUT_DIR}/train.log 52 | -------------------------------------------------------------------------------- /scripts/train/reason_rft/stage_rl/resume_finetune_qwen2vl_2b_task2_stage2_rl.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | conda activate reasonrft_rl 3 | export PYTHONPATH=$(pwd)/train 4 | # Wandb 5 | # export WANDB_MODE=disabled 6 | export WANDB_BASE_URL=https://api.wandb.ai 7 | export WANDB_PROJECT=vison-reason-rft 8 | export WANDB_API_KEY="8b05b6xxxxxf224" 9 | export WANDB_RUN_NAME=resume_finetune_qwen2vl_2b_task2_stage2_rl-$(date +%Y-%m-%d-%H-%M-%S) 10 | wandb login $WANDB_API_KEY 11 | 12 | # Dataset 13 | export TASK_NAME=geomath 14 | export DATASET_NAME=/path/to/your/dataset/geomath-train.json 15 | export IMAGE_PATH=/path/to/your/train_images/ 16 | export MODEL_NAME_OR_PATH=/path/to/your/checkpoints/from/stage1/ 17 | export OUTPUT_DIR=/path/to/your/checkpoints/${WANDB_RUN_NAME} 18 | 19 | if [ ! 
-d "$OUTPUT_DIR" ]; then 20 | mkdir "$OUTPUT_DIR" 21 | fi 22 | 23 | # Debug 24 | export DEBUG_MODE="True" 25 | export LOG_PATH=${OUTPUT_DIR}/reward.log 26 | 27 | torchrun --nproc_per_node=7 --nnodes=1 --master_port=29514 \ 28 | train/stage_rl/grpo.py \ 29 | --deepspeed scripts/train/zero3.json \ 30 | --output_dir ${OUTPUT_DIR} \ 31 | --model_name_or_path ${MODEL_NAME_OR_PATH} \ 32 | --dataset_name ${DATASET_NAME} \ 33 | --image_path ${IMAGE_PATH} \ 34 | --task_name ${TASK_NAME} \ 35 | --use_vllm_for_gen true \ 36 | --use_system_prompt false \ 37 | --max_prompt_length 4096 \ 38 | --max_completion_length 512 \ 39 | --num_generations 8 \ 40 | --per_device_train_batch_size 1 \ 41 | --gradient_accumulation_steps 2 \ 42 | --logging_steps 1 \ 43 | --bf16 \ 44 | --report_to wandb \ 45 | --gradient_checkpointing true \ 46 | --attn_implementation flash_attention_2 \ 47 | --max_pixels 480000 \ 48 | --save_steps 100 \ 49 | --num_train_epochs 5 \ 50 | --run_name ${WANDB_RUN_NAME} \ 51 | 2>&1 | tee ${OUTPUT_DIR}/train.log 52 | -------------------------------------------------------------------------------- /scripts/train/reason_rft/stage_rl/resume_finetune_qwen2vl_2b_task3_stage2_rl.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | conda activate reasonrft_rl 3 | export PYTHONPATH=$(pwd)/train 4 | # Wandb 5 | # export WANDB_MODE=disabled 6 | export WANDB_BASE_URL=https://api.wandb.ai 7 | export WANDB_PROJECT=vison-reason-rft 8 | export WANDB_API_KEY="8b05b6xxxxxf224" 9 | export WANDB_RUN_NAME=resume_finetune_qwen2vl_2b_task3_stage2_rl-$(date +%Y-%m-%d-%H-%M-%S) 10 | wandb login $WANDB_API_KEY 11 | 12 | # Dataset 13 | export TASK_NAME=trance-only-full 14 | export DATASET_NAME=/path/to/your/dataset/trance-train.json 15 | export IMAGE_PATH=/path/to/your/train_images/ 16 | export MODEL_NAME_OR_PATH=/path/to/your/checkpoints/from/stage1/ 17 | export OUTPUT_DIR=/path/to/your/checkpoints/${WANDB_RUN_NAME} 18 | 19 | if [ ! 
-d "$OUTPUT_DIR" ]; then 20 | mkdir "$OUTPUT_DIR" 21 | fi 22 | 23 | # Debug 24 | export DEBUG_MODE="True" 25 | export LOG_PATH=${OUTPUT_DIR}/reward.log 26 | 27 | torchrun --nproc_per_node=7 --nnodes=1 --master_port=29514 \ 28 | train/stage_rl/grpo.py \ 29 | --deepspeed scripts/train/zero3.json \ 30 | --output_dir ${OUTPUT_DIR} \ 31 | --model_name_or_path ${MODEL_NAME_OR_PATH} \ 32 | --dataset_name ${DATASET_NAME} \ 33 | --image_path ${IMAGE_PATH} \ 34 | --task_name ${TASK_NAME} \ 35 | --use_vllm_for_gen true \ 36 | --use_system_prompt false \ 37 | --max_prompt_length 4096 \ 38 | --max_completion_length 512 \ 39 | --num_generations 8 \ 40 | --per_device_train_batch_size 1 \ 41 | --gradient_accumulation_steps 2 \ 42 | --logging_steps 1 \ 43 | --bf16 \ 44 | --report_to wandb \ 45 | --gradient_checkpointing true \ 46 | --attn_implementation flash_attention_2 \ 47 | --max_pixels 480000 \ 48 | --save_steps 100 \ 49 | --num_train_epochs 1 \ 50 | --run_name ${WANDB_RUN_NAME} \ 51 | 2>&1 | tee ${OUTPUT_DIR}/train.log 52 | -------------------------------------------------------------------------------- /scripts/train/reason_rft/stage_rl/resume_finetune_qwen2vl_7b_task1_stage2_rl.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | conda activate reasonrft_rl 3 | export PYTHONPATH=$(pwd)/train 4 | # Wandb 5 | # export WANDB_MODE=disabled 6 | export WANDB_BASE_URL=https://api.wandb.ai 7 | export WANDB_PROJECT=vison-reason-rft 8 | export WANDB_API_KEY="8b05b6xxxxxf224" 9 | export WANDB_RUN_NAME=resume_finetune_qwen2vl_7b_task1_stage2_rl-$(date +%Y-%m-%d-%H-%M-%S) 10 | wandb login $WANDB_API_KEY 11 | 12 | # Dataset 13 | export TASK_NAME=clevr-math 14 | export DATASET_NAME=/path/to/your/dataset/clevr-math-train.json 15 | export IMAGE_PATH=/path/to/your/train_images/ 16 | export MODEL_NAME_OR_PATH=/path/to/your/checkpoints/from/stage1/ 17 | export OUTPUT_DIR=/path/to/your/checkpoints/${WANDB_RUN_NAME} 18 | 19 | if [ ! 
-d "$OUTPUT_DIR" ]; then 20 | mkdir "$OUTPUT_DIR" 21 | fi 22 | 23 | # Debug 24 | export DEBUG_MODE="True" 25 | export LOG_PATH=${OUTPUT_DIR}/reward.log 26 | 27 | torchrun --nproc_per_node=7 --nnodes=1 --master_port=29514 \ 28 | train/stage_rl/grpo.py \ 29 | --deepspeed scripts/train/zero3.json \ 30 | --output_dir ${OUTPUT_DIR} \ 31 | --model_name_or_path ${MODEL_NAME_OR_PATH} \ 32 | --dataset_name ${DATASET_NAME} \ 33 | --image_path ${IMAGE_PATH} \ 34 | --task_name ${TASK_NAME} \ 35 | --use_vllm_for_gen true \ 36 | --use_system_prompt false \ 37 | --max_prompt_length 4096 \ 38 | --max_completion_length 512 \ 39 | --num_generations 8 \ 40 | --per_device_train_batch_size 1 \ 41 | --gradient_accumulation_steps 2 \ 42 | --logging_steps 1 \ 43 | --bf16 \ 44 | --report_to wandb \ 45 | --gradient_checkpointing true \ 46 | --attn_implementation flash_attention_2 \ 47 | --max_pixels 480000 \ 48 | --save_steps 100 \ 49 | --num_train_epochs 1 \ 50 | --run_name ${WANDB_RUN_NAME} \ 51 | 2>&1 | tee ${OUTPUT_DIR}/train.log 52 | -------------------------------------------------------------------------------- /scripts/train/reason_rft/stage_rl/resume_finetune_qwen2vl_7b_task2_stage2_rl.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | conda activate reasonrft_rl 3 | export PYTHONPATH=$(pwd)/train 4 | # Wandb 5 | # export WANDB_MODE=disabled 6 | export WANDB_BASE_URL=https://api.wandb.ai 7 | export WANDB_PROJECT=vison-reason-rft 8 | export WANDB_API_KEY="8b05b6xxxxxf224" 9 | export WANDB_RUN_NAME=resume_finetune_qwen2vl_7b_task2_stage2_rl-$(date +%Y-%m-%d-%H-%M-%S) 10 | wandb login $WANDB_API_KEY 11 | 12 | # Dataset 13 | export TASK_NAME=geomath 14 | export DATASET_NAME=/path/to/your/dataset/geomath-train.json 15 | export IMAGE_PATH=/path/to/your/train_images/ 16 | export MODEL_NAME_OR_PATH=/path/to/your/checkpoints/from/stage1/ 17 | export OUTPUT_DIR=/path/to/your/checkpoints/${WANDB_RUN_NAME} 18 | 19 | if [ ! 
-d "$OUTPUT_DIR" ]; then 20 | mkdir "$OUTPUT_DIR" 21 | fi 22 | 23 | # Debug 24 | export DEBUG_MODE="True" 25 | export LOG_PATH=${OUTPUT_DIR}/reward.log 26 | 27 | torchrun --nproc_per_node=7 --nnodes=1 --master_port=29514 \ 28 | train/stage_rl/grpo.py \ 29 | --deepspeed scripts/train/zero3.json \ 30 | --output_dir ${OUTPUT_DIR} \ 31 | --model_name_or_path ${MODEL_NAME_OR_PATH} \ 32 | --dataset_name ${DATASET_NAME} \ 33 | --image_path ${IMAGE_PATH} \ 34 | --task_name ${TASK_NAME} \ 35 | --use_vllm_for_gen true \ 36 | --use_system_prompt false \ 37 | --max_prompt_length 4096 \ 38 | --max_completion_length 512 \ 39 | --num_generations 8 \ 40 | --per_device_train_batch_size 1 \ 41 | --gradient_accumulation_steps 2 \ 42 | --logging_steps 1 \ 43 | --bf16 \ 44 | --report_to wandb \ 45 | --gradient_checkpointing true \ 46 | --attn_implementation flash_attention_2 \ 47 | --max_pixels 480000 \ 48 | --save_steps 100 \ 49 | --num_train_epochs 5 \ 50 | --run_name ${WANDB_RUN_NAME} \ 51 | 2>&1 | tee ${OUTPUT_DIR}/train.log 52 | -------------------------------------------------------------------------------- /scripts/train/reason_rft/stage_rl/resume_finetune_qwen2vl_7b_task3_stage2_rl.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | conda activate reasonrft_rl 3 | export PYTHONPATH=$(pwd)/train 4 | # Wandb 5 | # export WANDB_MODE=disabled 6 | export WANDB_BASE_URL=https://api.wandb.ai 7 | export WANDB_PROJECT=vison-reason-rft 8 | export WANDB_API_KEY="8b05b6xxxxxf224" 9 | export WANDB_RUN_NAME=resume_finetune_qwen2vl_7b_task3_stage2_rl-$(date +%Y-%m-%d-%H-%M-%S) 10 | wandb login $WANDB_API_KEY 11 | 12 | # Dataset 13 | export TASK_NAME=trance-only-full 14 | export DATASET_NAME=/path/to/your/dataset/trance-train.json 15 | export IMAGE_PATH=/path/to/your/train_images/ 16 | export MODEL_NAME_OR_PATH=/path/to/your/checkpoints/from/stage1/ 17 | export OUTPUT_DIR=/path/to/your/checkpoints/${WANDB_RUN_NAME} 18 | 19 | if [ ! 
-d "$OUTPUT_DIR" ]; then 20 | mkdir "$OUTPUT_DIR" 21 | fi 22 | 23 | # Debug 24 | export DEBUG_MODE="True" 25 | export LOG_PATH=${OUTPUT_DIR}/reward.log 26 | 27 | torchrun --nproc_per_node=7 --nnodes=1 --master_port=29514 \ 28 | train/stage_rl/grpo.py \ 29 | --deepspeed scripts/train/zero3.json \ 30 | --output_dir ${OUTPUT_DIR} \ 31 | --model_name_or_path ${MODEL_NAME_OR_PATH} \ 32 | --dataset_name ${DATASET_NAME} \ 33 | --image_path ${IMAGE_PATH} \ 34 | --task_name ${TASK_NAME} \ 35 | --use_vllm_for_gen true \ 36 | --use_system_prompt false \ 37 | --max_prompt_length 4096 \ 38 | --max_completion_length 512 \ 39 | --num_generations 8 \ 40 | --per_device_train_batch_size 1 \ 41 | --gradient_accumulation_steps 2 \ 42 | --logging_steps 1 \ 43 | --bf16 \ 44 | --report_to wandb \ 45 | --gradient_checkpointing true \ 46 | --attn_implementation flash_attention_2 \ 47 | --max_pixels 480000 \ 48 | --save_steps 100 \ 49 | --num_train_epochs 1 \ 50 | --run_name ${WANDB_RUN_NAME} \ 51 | 2>&1 | tee ${OUTPUT_DIR}/train.log 52 | -------------------------------------------------------------------------------- /scripts/train/reason_rft/stage_sft/resume_finetune_qwen2vl_2b_task1_stage1_sft.sh: -------------------------------------------------------------------------------- 1 | conda activate reasonrft_sft 2 | 3 | export PYTHONPATH=$(pwd)/train/stage_sft 4 | 5 | export WANDB_MODE=offline 6 | export ACCELERATE_CPU_AFFINITY=1 7 | 8 | export IMAGE_DIR=/path/to/your/train_images/ 9 | export PRETRAIN_MODEL_PATH=/path/to/your/pretrain_model/Qwen2-VL-2B-Instruct 10 | export OUTPUT_PATH=/path/to/your/checkpoints/qwen2vl_2b_task1_cot_sft 11 | export DATASET=clevr_math_cot_sft_1k6 12 | 13 | if [ ! -d "$OUTPUT_PATH" ]; then 14 | mkdir "$OUTPUT_PATH" 15 | fi 16 | 17 | torchrun --nproc_per_node=8 --nnodes=1 --master_port=29514 \ 18 | train/stage_sft/train.py \ 19 | --deepspeed scripts/train/zero3.json \ 20 | --stage sft \ 21 | --do_train \ 22 | --model_name_or_path $PRETRAIN_MODEL_PATH \ 23 | --dataset $DATASET \ 24 | --image_dir $IMAGE_DIR \ 25 | --template qwen2_vl \ 26 | --finetuning_type full \ 27 | --output_dir $OUTPUT_PATH \ 28 | --overwrite_cache \ 29 | --overwrite_output_dir \ 30 | --warmup_steps 100 \ 31 | --weight_decay 0.1 \ 32 | --per_device_train_batch_size 1 \ 33 | --gradient_accumulation_steps 2 \ 34 | --ddp_timeout 90000 \ 35 | --learning_rate 1e-5 \ 36 | --lr_scheduler_type cosine \ 37 | --logging_steps 5 \ 38 | --cutoff_len 4096 \ 39 | --save_steps 200 \ 40 | --plot_loss \ 41 | --num_train_epochs 1 \ 42 | --bf16 \ 43 | 2>&1 | tee ${OUTPUT_DIR}/train.log -------------------------------------------------------------------------------- /scripts/train/reason_rft/stage_sft/resume_finetune_qwen2vl_2b_task2_stage1_sft.sh: -------------------------------------------------------------------------------- 1 | conda activate reasonrft_sft 2 | 3 | export PYTHONPATH=$(pwd)/train/stage_sft 4 | 5 | export WANDB_MODE=offline 6 | export ACCELERATE_CPU_AFFINITY=1 7 | 8 | export IMAGE_DIR=/path/to/your/train_images/ 9 | export PRETRAIN_MODEL_PATH=/path/to/your/pretrain_model/Qwen2-VL-2B-Instruct 10 | export OUTPUT_PATH=/path/to/your/checkpoints/qwen2vl_2b_task2_cot_sft 11 | export DATASET=geo_math_cot_sft_1k6 12 | 13 | if [ ! 
-d "$OUTPUT_PATH" ]; then 14 | mkdir "$OUTPUT_PATH" 15 | fi 16 | 17 | torchrun --nproc_per_node=8 --nnodes=1 --master_port=29514 \ 18 | train/stage_sft/train.py \ 19 | --deepspeed scripts/train/zero3.json \ 20 | --stage sft \ 21 | --do_train \ 22 | --model_name_or_path $PRETRAIN_MODEL_PATH \ 23 | --dataset $DATASET \ 24 | --image_dir $IMAGE_DIR \ 25 | --template qwen2_vl \ 26 | --finetuning_type full \ 27 | --output_dir $OUTPUT_PATH \ 28 | --overwrite_cache \ 29 | --overwrite_output_dir \ 30 | --warmup_steps 100 \ 31 | --weight_decay 0.1 \ 32 | --per_device_train_batch_size 1 \ 33 | --gradient_accumulation_steps 2 \ 34 | --ddp_timeout 90000 \ 35 | --learning_rate 1e-5 \ 36 | --lr_scheduler_type cosine \ 37 | --logging_steps 5 \ 38 | --cutoff_len 4096 \ 39 | --save_steps 200 \ 40 | --plot_loss \ 41 | --num_train_epochs 1 \ 42 | --bf16 \ 43 | 2>&1 | tee ${OUTPUT_DIR}/train.log -------------------------------------------------------------------------------- /scripts/train/reason_rft/stage_sft/resume_finetune_qwen2vl_2b_task3_stage1_sft.sh: -------------------------------------------------------------------------------- 1 | conda activate reasonrft_sft 2 | 3 | export PYTHONPATH=$(pwd)/train/stage_sft 4 | 5 | export WANDB_MODE=offline 6 | export ACCELERATE_CPU_AFFINITY=1 7 | 8 | export IMAGE_DIR=/path/to/your/train_images/ 9 | export PRETRAIN_MODEL_PATH=/path/to/your/pretrain_model/Qwen2-VL-2B-Instruct 10 | export OUTPUT_PATH=/path/to/your/checkpoints/qwen2vl_2b_task3_cot_sft 11 | export DATASET=trance_cot_sft_1k6 12 | 13 | if [ ! -d "$OUTPUT_PATH" ]; then 14 | mkdir "$OUTPUT_PATH" 15 | fi 16 | 17 | torchrun --nproc_per_node=8 --nnodes=1 --master_port=29514 \ 18 | train/stage_sft/train.py \ 19 | --deepspeed scripts/train/zero3.json \ 20 | --stage sft \ 21 | --do_train \ 22 | --model_name_or_path $PRETRAIN_MODEL_PATH \ 23 | --dataset $DATASET \ 24 | --image_dir $IMAGE_DIR \ 25 | --template qwen2_vl \ 26 | --finetuning_type full \ 27 | --output_dir $OUTPUT_PATH \ 28 | --overwrite_cache \ 29 | --overwrite_output_dir \ 30 | --warmup_steps 100 \ 31 | --weight_decay 0.1 \ 32 | --per_device_train_batch_size 1 \ 33 | --gradient_accumulation_steps 2 \ 34 | --ddp_timeout 90000 \ 35 | --learning_rate 1e-5 \ 36 | --lr_scheduler_type cosine \ 37 | --logging_steps 5 \ 38 | --cutoff_len 4096 \ 39 | --save_steps 200 \ 40 | --plot_loss \ 41 | --num_train_epochs 1 \ 42 | --bf16 \ 43 | 2>&1 | tee ${OUTPUT_DIR}/train.log -------------------------------------------------------------------------------- /scripts/train/reason_rft/stage_sft/resume_finetune_qwen2vl_7b_task1_stage1_sft.sh: -------------------------------------------------------------------------------- 1 | conda activate reasonrft_sft 2 | 3 | export PYTHONPATH=$(pwd)/train/stage_sft 4 | 5 | export WANDB_MODE=offline 6 | export ACCELERATE_CPU_AFFINITY=1 7 | 8 | export IMAGE_DIR=/path/to/your/train_images/ 9 | export PRETRAIN_MODEL_PATH=/path/to/your/pretrain_model/Qwen2-VL-7B-Instruct 10 | export OUTPUT_PATH=/path/to/your/checkpoints/qwen2vl_7b_task1_cot_sft 11 | export DATASET=clevr_math_cot_sft_1k6 12 | 13 | if [ ! 
-d "$OUTPUT_PATH" ]; then 14 | mkdir "$OUTPUT_PATH" 15 | fi 16 | 17 | torchrun --nproc_per_node=8 --nnodes=1 --master_port=29514 \ 18 | train/stage_sft/train.py \ 19 | --deepspeed scripts/train/zero3.json \ 20 | --stage sft \ 21 | --do_train \ 22 | --model_name_or_path $PRETRAIN_MODEL_PATH \ 23 | --dataset $DATASET \ 24 | --image_dir $IMAGE_DIR \ 25 | --template qwen2_vl \ 26 | --finetuning_type full \ 27 | --output_dir $OUTPUT_PATH \ 28 | --overwrite_cache \ 29 | --overwrite_output_dir \ 30 | --warmup_steps 100 \ 31 | --weight_decay 0.1 \ 32 | --per_device_train_batch_size 1 \ 33 | --gradient_accumulation_steps 2 \ 34 | --ddp_timeout 90000 \ 35 | --learning_rate 1e-5 \ 36 | --lr_scheduler_type cosine \ 37 | --logging_steps 5 \ 38 | --cutoff_len 4096 \ 39 | --save_steps 200 \ 40 | --plot_loss \ 41 | --num_train_epochs 1 \ 42 | --bf16 \ 43 | 2>&1 | tee ${OUTPUT_DIR}/train.log -------------------------------------------------------------------------------- /scripts/train/reason_rft/stage_sft/resume_finetune_qwen2vl_7b_task2_stage1_sft.sh: -------------------------------------------------------------------------------- 1 | conda activate reasonrft_sft 2 | 3 | export PYTHONPATH=$(pwd)/train/stage_sft 4 | 5 | export WANDB_MODE=offline 6 | export ACCELERATE_CPU_AFFINITY=1 7 | 8 | export IMAGE_DIR=/path/to/your/train_images/ 9 | export PRETRAIN_MODEL_PATH=/path/to/your/pretrain_model/Qwen2-VL-7B-Instruct 10 | export OUTPUT_PATH=/path/to/your/checkpoints/qwen2vl_7b_task2_cot_sft 11 | export DATASET=geo_math_cot_sft_1k6 12 | 13 | if [ ! -d "$OUTPUT_PATH" ]; then 14 | mkdir "$OUTPUT_PATH" 15 | fi 16 | 17 | torchrun --nproc_per_node=8 --nnodes=1 --master_port=29514 \ 18 | train/stage_sft/train.py \ 19 | --deepspeed scripts/train/zero3.json \ 20 | --stage sft \ 21 | --do_train \ 22 | --model_name_or_path $PRETRAIN_MODEL_PATH \ 23 | --dataset $DATASET \ 24 | --image_dir $IMAGE_DIR \ 25 | --template qwen2_vl \ 26 | --finetuning_type full \ 27 | --output_dir $OUTPUT_PATH \ 28 | --overwrite_cache \ 29 | --overwrite_output_dir \ 30 | --warmup_steps 100 \ 31 | --weight_decay 0.1 \ 32 | --per_device_train_batch_size 1 \ 33 | --gradient_accumulation_steps 2 \ 34 | --ddp_timeout 90000 \ 35 | --learning_rate 1e-5 \ 36 | --lr_scheduler_type cosine \ 37 | --logging_steps 5 \ 38 | --cutoff_len 4096 \ 39 | --save_steps 200 \ 40 | --plot_loss \ 41 | --num_train_epochs 1 \ 42 | --bf16 \ 43 | 2>&1 | tee ${OUTPUT_DIR}/train.log -------------------------------------------------------------------------------- /scripts/train/reason_rft/stage_sft/resume_finetune_qwen2vl_7b_task3_stage1_sft.sh: -------------------------------------------------------------------------------- 1 | conda activate reasonrft_sft 2 | 3 | export PYTHONPATH=$(pwd)/train/stage_sft 4 | 5 | export WANDB_MODE=offline 6 | export ACCELERATE_CPU_AFFINITY=1 7 | 8 | export IMAGE_DIR=/path/to/your/train_images/ 9 | export PRETRAIN_MODEL_PATH=/path/to/your/pretrain_model/Qwen2-VL-7B-Instruct 10 | export OUTPUT_PATH=/path/to/your/checkpoints/qwen2vl_7b_task3_cot_sft 11 | export DATASET=trance_cot_sft_1k6 12 | 13 | if [ ! 
-d "$OUTPUT_PATH" ]; then 14 | mkdir "$OUTPUT_PATH" 15 | fi 16 | 17 | torchrun --nproc_per_node=8 --nnodes=1 --master_port=29514 \ 18 | train/stage_sft/train.py \ 19 | --deepspeed scripts/train/zero3.json \ 20 | --stage sft \ 21 | --do_train \ 22 | --model_name_or_path $PRETRAIN_MODEL_PATH \ 23 | --dataset $DATASET \ 24 | --image_dir $IMAGE_DIR \ 25 | --template qwen2_vl \ 26 | --finetuning_type full \ 27 | --output_dir $OUTPUT_PATH \ 28 | --overwrite_cache \ 29 | --overwrite_output_dir \ 30 | --warmup_steps 100 \ 31 | --weight_decay 0.1 \ 32 | --per_device_train_batch_size 1 \ 33 | --gradient_accumulation_steps 2 \ 34 | --ddp_timeout 90000 \ 35 | --learning_rate 1e-5 \ 36 | --lr_scheduler_type cosine \ 37 | --logging_steps 5 \ 38 | --cutoff_len 4096 \ 39 | --save_steps 200 \ 40 | --plot_loss \ 41 | --num_train_epochs 1 \ 42 | --bf16 \ 43 | 2>&1 | tee ${OUTPUT_DIR}/train.log -------------------------------------------------------------------------------- /scripts/train/reason_rft_zero/resume_finetune_qwen2vl_2b_task1_only_rl.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | conda activate reasonrft_rl 3 | export PYTHONPATH=$(pwd)/train 4 | # Wandb 5 | # export WANDB_MODE=disabled 6 | export WANDB_BASE_URL=https://api.wandb.ai 7 | export WANDB_PROJECT=vison-reason-rft 8 | export WANDB_API_KEY="8b05b6xxxxxf224" 9 | export WANDB_RUN_NAME=resume_finetune_qwen2vl_2b_task1_only_rl-$(date +%Y-%m-%d-%H-%M-%S) 10 | wandb login $WANDB_API_KEY 11 | 12 | # Dataset 13 | export TASK_NAME=clevr-math 14 | export DATASET_NAME=/path/to/your/dataset/clevr-math-train.json 15 | export IMAGE_PATH=/path/to/your/train_images/ 16 | export MODEL_NAME_OR_PATH=/path/to/your/pretrain_model/Qwen2-VL-2B-Instruct 17 | export OUTPUT_DIR=/path/to/your/checkpoints/${WANDB_RUN_NAME} 18 | 19 | if [ ! 
-d "$OUTPUT_DIR" ]; then 20 | mkdir "$OUTPUT_DIR" 21 | fi 22 | 23 | # Debug 24 | export DEBUG_MODE="True" 25 | export LOG_PATH=${OUTPUT_DIR}/reward.log 26 | 27 | torchrun --nproc_per_node=7 --nnodes=1 --master_port=29514 \ 28 | train/stage_rl/grpo.py \ 29 | --deepspeed scripts/train/zero3.json \ 30 | --output_dir ${OUTPUT_DIR} \ 31 | --model_name_or_path ${MODEL_NAME_OR_PATH} \ 32 | --dataset_name ${DATASET_NAME} \ 33 | --image_path ${IMAGE_PATH} \ 34 | --task_name ${TASK_NAME} \ 35 | --use_vllm_for_gen true \ 36 | --use_system_prompt false \ 37 | --max_prompt_length 4096 \ 38 | --max_completion_length 512 \ 39 | --num_generations 8 \ 40 | --per_device_train_batch_size 1 \ 41 | --gradient_accumulation_steps 2 \ 42 | --logging_steps 1 \ 43 | --bf16 \ 44 | --report_to wandb \ 45 | --gradient_checkpointing true \ 46 | --attn_implementation flash_attention_2 \ 47 | --max_pixels 480000 \ 48 | --save_steps 100 \ 49 | --num_train_epochs 1 \ 50 | --run_name ${WANDB_RUN_NAME} \ 51 | 2>&1 | tee ${OUTPUT_DIR}/train.log 52 | -------------------------------------------------------------------------------- /scripts/train/reason_rft_zero/resume_finetune_qwen2vl_2b_task2_only_rl.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | conda activate reasonrft_rl 3 | export PYTHONPATH=$(pwd)/train 4 | # Wandb 5 | # export WANDB_MODE=disabled 6 | export WANDB_BASE_URL=https://api.wandb.ai 7 | export WANDB_PROJECT=vison-reason-rft 8 | export WANDB_API_KEY="8b05b6xxxxxf224" 9 | export WANDB_RUN_NAME=resume_finetune_qwen2vl_2b_task2_only_rl-$(date +%Y-%m-%d-%H-%M-%S) 10 | wandb login $WANDB_API_KEY 11 | 12 | # Dataset 13 | export TASK_NAME=geomath 14 | export DATASET_NAME=/path/to/your/dataset/geomath-train.json 15 | export IMAGE_PATH=/path/to/your/train_images/ 16 | export MODEL_NAME_OR_PATH=/path/to/your/pretrain_model/Qwen2-VL-2B-Instruct 17 | export OUTPUT_DIR=/path/to/your/checkpoints/${WANDB_RUN_NAME} 18 | 19 | if [ ! 
-d "$OUTPUT_DIR" ]; then 20 | mkdir "$OUTPUT_DIR" 21 | fi 22 | 23 | # Debug 24 | export DEBUG_MODE="True" 25 | export LOG_PATH=${OUTPUT_DIR}/reward.log 26 | 27 | torchrun --nproc_per_node=7 --nnodes=1 --master_port=29514 \ 28 | train/stage_rl/grpo.py \ 29 | --deepspeed scripts/train/zero3.json \ 30 | --output_dir ${OUTPUT_DIR} \ 31 | --model_name_or_path ${MODEL_NAME_OR_PATH} \ 32 | --dataset_name ${DATASET_NAME} \ 33 | --image_path ${IMAGE_PATH} \ 34 | --task_name ${TASK_NAME} \ 35 | --use_vllm_for_gen true \ 36 | --use_system_prompt false \ 37 | --max_prompt_length 4096 \ 38 | --max_completion_length 512 \ 39 | --num_generations 8 \ 40 | --per_device_train_batch_size 1 \ 41 | --gradient_accumulation_steps 2 \ 42 | --logging_steps 1 \ 43 | --bf16 \ 44 | --report_to wandb \ 45 | --gradient_checkpointing true \ 46 | --attn_implementation flash_attention_2 \ 47 | --max_pixels 480000 \ 48 | --save_steps 100 \ 49 | --num_train_epochs 5 \ 50 | --run_name ${WANDB_RUN_NAME} \ 51 | 2>&1 | tee ${OUTPUT_DIR}/train.log 52 | -------------------------------------------------------------------------------- /scripts/train/reason_rft_zero/resume_finetune_qwen2vl_2b_task3_only_rl.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | conda activate reasonrft_rl 3 | export PYTHONPATH=$(pwd)/train 4 | # Wandb 5 | # export WANDB_MODE=disabled 6 | export WANDB_BASE_URL=https://api.wandb.ai 7 | export WANDB_PROJECT=vison-reason-rft 8 | export WANDB_API_KEY="8b05b6xxxxxf224" 9 | export WANDB_RUN_NAME=resume_finetune_qwen2vl_2b_task3_only_rl-$(date +%Y-%m-%d-%H-%M-%S) 10 | wandb login $WANDB_API_KEY 11 | 12 | # Dataset 13 | export TASK_NAME=trance-only-full 14 | export DATASET_NAME=/path/to/your/dataset/trance-train.json 15 | export IMAGE_PATH=/path/to/your/train_images/ 16 | export MODEL_NAME_OR_PATH=/path/to/your/pretrain_model/Qwen2-VL-2B-Instruct 17 | export OUTPUT_DIR=/path/to/your/checkpoints/${WANDB_RUN_NAME} 18 | 19 | if [ ! 
-d "$OUTPUT_DIR" ]; then 20 | mkdir "$OUTPUT_DIR" 21 | fi 22 | 23 | # Debug 24 | export DEBUG_MODE="True" 25 | export LOG_PATH=${OUTPUT_DIR}/reward.log 26 | 27 | torchrun --nproc_per_node=7 --nnodes=1 --master_port=29514 \ 28 | train/stage_rl/grpo.py \ 29 | --deepspeed scripts/train/zero3.json \ 30 | --output_dir ${OUTPUT_DIR} \ 31 | --model_name_or_path ${MODEL_NAME_OR_PATH} \ 32 | --dataset_name ${DATASET_NAME} \ 33 | --image_path ${IMAGE_PATH} \ 34 | --task_name ${TASK_NAME} \ 35 | --use_vllm_for_gen true \ 36 | --use_system_prompt false \ 37 | --max_prompt_length 4096 \ 38 | --max_completion_length 512 \ 39 | --num_generations 8 \ 40 | --per_device_train_batch_size 1 \ 41 | --gradient_accumulation_steps 2 \ 42 | --logging_steps 1 \ 43 | --bf16 \ 44 | --report_to wandb \ 45 | --gradient_checkpointing true \ 46 | --attn_implementation flash_attention_2 \ 47 | --max_pixels 480000 \ 48 | --save_steps 100 \ 49 | --num_train_epochs 1 \ 50 | --run_name ${WANDB_RUN_NAME} \ 51 | 2>&1 | tee ${OUTPUT_DIR}/train.log 52 | -------------------------------------------------------------------------------- /scripts/train/reason_rft_zero/resume_finetune_qwen2vl_7b_task1_only_rl.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | conda activate reasonrft_rl 3 | export PYTHONPATH=$(pwd)/train 4 | # Wandb 5 | # export WANDB_MODE=disabled 6 | export WANDB_BASE_URL=https://api.wandb.ai 7 | export WANDB_PROJECT=vison-reason-rft 8 | export WANDB_API_KEY="8b05b6xxxxxf224" 9 | export WANDB_RUN_NAME=resume_finetune_qwen2vl_7b_task1_only_rl-$(date +%Y-%m-%d-%H-%M-%S) 10 | wandb login $WANDB_API_KEY 11 | 12 | # Dataset 13 | export TASK_NAME=clevr-math 14 | export DATASET_NAME=/path/to/your/dataset/clevr-math-train.json 15 | export IMAGE_PATH=/path/to/your/train_images/ 16 | export MODEL_NAME_OR_PATH=/path/to/your/pretrain_model/Qwen2-VL-7B-Instruct 17 | export OUTPUT_DIR=/path/to/your/checkpoints/${WANDB_RUN_NAME} 18 | 19 | if [ ! 
-d "$OUTPUT_DIR" ]; then 20 | mkdir "$OUTPUT_DIR" 21 | fi 22 | 23 | # Debug 24 | export DEBUG_MODE="True" 25 | export LOG_PATH=${OUTPUT_DIR}/reward.log 26 | 27 | torchrun --nproc_per_node=7 --nnodes=1 --master_port=29514 \ 28 | train/stage_rl/grpo.py \ 29 | --deepspeed scripts/train/zero3.json \ 30 | --output_dir ${OUTPUT_DIR} \ 31 | --model_name_or_path ${MODEL_NAME_OR_PATH} \ 32 | --dataset_name ${DATASET_NAME} \ 33 | --image_path ${IMAGE_PATH} \ 34 | --task_name ${TASK_NAME} \ 35 | --use_vllm_for_gen true \ 36 | --use_system_prompt false \ 37 | --max_prompt_length 4096 \ 38 | --max_completion_length 512 \ 39 | --num_generations 8 \ 40 | --per_device_train_batch_size 1 \ 41 | --gradient_accumulation_steps 2 \ 42 | --logging_steps 1 \ 43 | --bf16 \ 44 | --report_to wandb \ 45 | --gradient_checkpointing true \ 46 | --attn_implementation flash_attention_2 \ 47 | --max_pixels 480000 \ 48 | --save_steps 100 \ 49 | --num_train_epochs 1 \ 50 | --run_name ${WANDB_RUN_NAME} \ 51 | 2>&1 | tee ${OUTPUT_DIR}/train.log 52 | -------------------------------------------------------------------------------- /scripts/train/reason_rft_zero/resume_finetune_qwen2vl_7b_task2_only_rl.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | conda activate reasonrft_rl 3 | export PYTHONPATH=$(pwd)/train 4 | # Wandb 5 | # export WANDB_MODE=disabled 6 | export WANDB_BASE_URL=https://api.wandb.ai 7 | export WANDB_PROJECT=vison-reason-rft 8 | export WANDB_API_KEY="8b05b6xxxxxf224" 9 | export WANDB_RUN_NAME=resume_finetune_qwen2vl_7b_task2_only_rl-$(date +%Y-%m-%d-%H-%M-%S) 10 | wandb login $WANDB_API_KEY 11 | 12 | # Dataset 13 | export TASK_NAME=geomath 14 | export DATASET_NAME=/path/to/your/dataset/geomath-train.json 15 | export IMAGE_PATH=/path/to/your/train_images/ 16 | export MODEL_NAME_OR_PATH=/path/to/your/pretrain_model/Qwen2-VL-7B-Instruct 17 | export OUTPUT_DIR=/path/to/your/checkpoints/${WANDB_RUN_NAME} 18 | 19 | if [ ! 
-d "$OUTPUT_DIR" ]; then 20 | mkdir "$OUTPUT_DIR" 21 | fi 22 | 23 | # Debug 24 | export DEBUG_MODE="True" 25 | export LOG_PATH=${OUTPUT_DIR}/reward.log 26 | 27 | torchrun --nproc_per_node=7 --nnodes=1 --master_port=29514 \ 28 | train/stage_rl/grpo.py \ 29 | --deepspeed scripts/train/zero3.json \ 30 | --output_dir ${OUTPUT_DIR} \ 31 | --model_name_or_path ${MODEL_NAME_OR_PATH} \ 32 | --dataset_name ${DATASET_NAME} \ 33 | --image_path ${IMAGE_PATH} \ 34 | --task_name ${TASK_NAME} \ 35 | --use_vllm_for_gen true \ 36 | --use_system_prompt false \ 37 | --max_prompt_length 4096 \ 38 | --max_completion_length 512 \ 39 | --num_generations 8 \ 40 | --per_device_train_batch_size 1 \ 41 | --gradient_accumulation_steps 2 \ 42 | --logging_steps 1 \ 43 | --bf16 \ 44 | --report_to wandb \ 45 | --gradient_checkpointing true \ 46 | --attn_implementation flash_attention_2 \ 47 | --max_pixels 480000 \ 48 | --save_steps 100 \ 49 | --num_train_epochs 5 \ 50 | --run_name ${WANDB_RUN_NAME} \ 51 | 2>&1 | tee ${OUTPUT_DIR}/train.log 52 | -------------------------------------------------------------------------------- /scripts/train/reason_rft_zero/resume_finetune_qwen2vl_7b_task3_only_rl.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | conda activate reasonrft_rl 3 | export PYTHONPATH=$(pwd)/train 4 | # Wandb 5 | # export WANDB_MODE=disabled 6 | export WANDB_BASE_URL=https://api.wandb.ai 7 | export WANDB_PROJECT=vison-reason-rft 8 | export WANDB_API_KEY="8b05b6xxxxxf224" 9 | export WANDB_RUN_NAME=resume_finetune_qwen2vl_7b_task3_only_rl-$(date +%Y-%m-%d-%H-%M-%S) 10 | wandb login $WANDB_API_KEY 11 | 12 | # Dataset 13 | export TASK_NAME=trance-only-full 14 | export DATASET_NAME=/path/to/your/dataset/trance-train.json 15 | export IMAGE_PATH=/path/to/your/train_images/ 16 | export MODEL_NAME_OR_PATH=/path/to/your/pretrain_model/Qwen2-VL-7B-Instruct 17 | export OUTPUT_DIR=/path/to/your/checkpoints/${WANDB_RUN_NAME} 18 | 19 | if [ ! 
-d "$OUTPUT_DIR" ]; then 20 | mkdir "$OUTPUT_DIR" 21 | fi 22 | 23 | # Debug 24 | export DEBUG_MODE="True" 25 | export LOG_PATH=${OUTPUT_DIR}/reward.log 26 | 27 | torchrun --nproc_per_node=7 --nnodes=1 --master_port=29514 \ 28 | train/stage_rl/grpo.py \ 29 | --deepspeed scripts/train/zero3.json \ 30 | --output_dir ${OUTPUT_DIR} \ 31 | --model_name_or_path ${MODEL_NAME_OR_PATH} \ 32 | --dataset_name ${DATASET_NAME} \ 33 | --image_path ${IMAGE_PATH} \ 34 | --task_name ${TASK_NAME} \ 35 | --use_vllm_for_gen true \ 36 | --use_system_prompt false \ 37 | --max_prompt_length 4096 \ 38 | --max_completion_length 512 \ 39 | --num_generations 8 \ 40 | --per_device_train_batch_size 1 \ 41 | --gradient_accumulation_steps 2 \ 42 | --logging_steps 1 \ 43 | --bf16 \ 44 | --report_to wandb \ 45 | --gradient_checkpointing true \ 46 | --attn_implementation flash_attention_2 \ 47 | --max_pixels 480000 \ 48 | --save_steps 100 \ 49 | --num_train_epochs 1 \ 50 | --run_name ${WANDB_RUN_NAME} \ 51 | 2>&1 | tee ${OUTPUT_DIR}/train.log 52 | -------------------------------------------------------------------------------- /scripts/train/zero3.json: -------------------------------------------------------------------------------- 1 | { 2 | "fp16": { 3 | "enabled": "auto", 4 | "loss_scale": 0, 5 | "loss_scale_window": 1000, 6 | "initial_scale_power": 16, 7 | "hysteresis": 2, 8 | "min_loss_scale": 1 9 | }, 10 | "bf16": { 11 | "enabled": "auto" 12 | }, 13 | 14 | "zero_optimization": { 15 | "stage": 3, 16 | "offload_optimizer": { 17 | "device": "none", 18 | "pin_memory": true 19 | }, 20 | "offload_param": { 21 | "device": "none", 22 | "pin_memory": true 23 | }, 24 | "overlap_comm": true, 25 | "contiguous_gradients": true, 26 | "sub_group_size": 1e9, 27 | "reduce_bucket_size": "auto", 28 | "stage3_prefetch_bucket_size": "auto", 29 | "stage3_param_persistence_threshold": "auto", 30 | "stage3_max_live_parameters": 1e9, 31 | "stage3_max_reuse_distance": 1e9, 32 | "stage3_gather_16bit_weights_on_model_save": true 33 | }, 34 | 35 | "gradient_accumulation_steps": "auto", 36 | "gradient_clipping": "auto", 37 | "steps_per_print": 100, 38 | "train_batch_size": "auto", 39 | "train_micro_batch_size_per_gpu": "auto", 40 | "wall_clock_breakdown": false 41 | } -------------------------------------------------------------------------------- /train/stage_rl/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /train/stage_rl/configs.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2025 The HuggingFace Team. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | from dataclasses import dataclass, field 17 | from typing import Optional 18 | 19 | import trl 20 | 21 | 22 | # TODO: add the shared options with a mixin to reduce code duplication 23 | @dataclass 24 | class GRPOConfig(trl.GRPOConfig): 25 | """ 26 | args for callbacks, benchmarks etc 27 | """ 28 | 29 | benchmarks: list[str] = field( 30 | default_factory=lambda: [], metadata={"help": "The benchmarks to run after training."} 31 | ) 32 | callbacks: list[str] = field( 33 | default_factory=lambda: [], metadata={"help": "The callbacks to run during training."} 34 | ) 35 | system_prompt: Optional[str] = field( 36 | default=None, metadata={"help": "The optional system prompt to use for benchmarking."} 37 | ) 38 | hub_model_revision: Optional[str] = field( 39 | default="main", metadata={"help": "The Hub model branch to push the model to."} 40 | ) 41 | overwrite_hub_revision: bool = field(default=False, metadata={"help": "Whether to overwrite the Hub revision."}) 42 | push_to_hub_revision: bool = field(default=False, metadata={"help": "Whether to push to a Hub revision/branch."}) 43 | 44 | 45 | @dataclass 46 | class SFTConfig(trl.SFTConfig): 47 | """ 48 | args for callbacks, benchmarks etc 49 | """ 50 | 51 | benchmarks: list[str] = field( 52 | default_factory=lambda: [], metadata={"help": "The benchmarks to run after training."} 53 | ) 54 | callbacks: list[str] = field( 55 | default_factory=lambda: [], metadata={"help": "The callbacks to run during training."} 56 | ) 57 | system_prompt: Optional[str] = field( 58 | default=None, 59 | metadata={"help": "The optional system prompt to use for benchmarking."}, 60 | ) 61 | hub_model_revision: Optional[str] = field( 62 | default="main", 63 | metadata={"help": "The Hub model branch to push the model to."}, 64 | ) 65 | overwrite_hub_revision: bool = field(default=False, metadata={"help": "Whether to overwrite the Hub revision."}) 66 | push_to_hub_revision: bool = field(default=False, metadata={"help": "Whether to push to a Hub revision/branch."}) 67 | -------------------------------------------------------------------------------- /train/stage_rl/trainer/__init__.py: -------------------------------------------------------------------------------- 1 | from .mm_grpo_trainer import MultiModalGRPOTrainer 2 | 3 | 4 | __all__ = ["MultiModalGRPOTrainer"] -------------------------------------------------------------------------------- /train/stage_rl/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tanhuajie/Reason-RFT/3180557dce2065e92754352dacf83eccb26ac032/train/stage_rl/utils/__init__.py -------------------------------------------------------------------------------- /train/stage_rl/utils/callbacks.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | # Copyright 2025 The HuggingFace Inc. team. All rights reserved. 
4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | import subprocess 18 | from typing import List 19 | 20 | from transformers import TrainerCallback 21 | from transformers.trainer_callback import TrainerControl, TrainerState 22 | from transformers.training_args import TrainingArguments 23 | 24 | from .evaluation import run_benchmark_jobs 25 | from .hub import push_to_hub_revision 26 | 27 | 28 | def is_slurm_available() -> bool: 29 | # returns true if a slurm queueing system is available 30 | try: 31 | subprocess.run(["sinfo"], check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE) 32 | return True 33 | except FileNotFoundError: 34 | return False 35 | 36 | 37 | class DummyConfig: 38 | def __init__(self, **kwargs): 39 | for k, v in kwargs.items(): 40 | setattr(self, k, v) 41 | 42 | 43 | class PushToHubRevisionCallback(TrainerCallback): 44 | def __init__(self, model_config) -> None: 45 | self.model_config = model_config 46 | 47 | def on_save(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs): 48 | if state.is_world_process_zero: 49 | global_step = state.global_step 50 | 51 | # WARNING: if you use dataclasses.replace(args, ...) the accelerator dist state will be broken, so I do this workaround 52 | # Also if you instantiate a new SFTConfig, the accelerator dist state will be broken 53 | dummy_config = DummyConfig( 54 | hub_model_id=args.hub_model_id, 55 | hub_model_revision=f"{args.hub_model_revision}-step-{global_step:09d}", 56 | output_dir=f"{args.output_dir}/checkpoint-{global_step}", 57 | system_prompt=args.system_prompt, 58 | ) 59 | 60 | future = push_to_hub_revision( 61 | dummy_config, extra_ignore_patterns=["*.pt"] 62 | ) # don't push the optimizer states 63 | 64 | if is_slurm_available(): 65 | dummy_config.benchmarks = args.benchmarks 66 | 67 | def run_benchmark_callback(_): 68 | print(f"Checkpoint {global_step} pushed to hub.") 69 | run_benchmark_jobs(dummy_config, self.model_config) 70 | 71 | future.add_done_callback(run_benchmark_callback) 72 | 73 | 74 | CALLBACKS = { 75 | "push_to_hub_revision": PushToHubRevisionCallback, 76 | } 77 | 78 | 79 | def get_callbacks(train_config, model_config) -> List[TrainerCallback]: 80 | callbacks = [] 81 | for callback_name in train_config.callbacks: 82 | if callback_name not in CALLBACKS: 83 | raise ValueError(f"Callback {callback_name} not found in CALLBACKS.") 84 | callbacks.append(CALLBACKS[callback_name](model_config)) 85 | 86 | return callbacks 87 | -------------------------------------------------------------------------------- /train/stage_rl/utils/upload_details.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2025 The HuggingFace Inc. team. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """ 16 | Push the details from a LightEval run to the Hub. 17 | 18 | Usage: 19 | 20 | python src/open_r1/utils/upload_details.py \ 21 | --data_files {path_to_parquet_file} \ 22 | --hub_repo_id {hub_repo_id} \ 23 | --config_name {config_name} 24 | """ 25 | 26 | from dataclasses import dataclass, field 27 | from typing import List 28 | 29 | from datasets import load_dataset 30 | from transformers import HfArgumentParser 31 | 32 | 33 | @dataclass 34 | class ScriptArguments: 35 | data_files: List[str] = field(default_factory=list) 36 | hub_repo_id: str = None 37 | config_name: str = None 38 | 39 | 40 | def main(): 41 | parser = HfArgumentParser(ScriptArguments) 42 | args = parser.parse() 43 | 44 | if all(file.endswith(".json") for file in args.data_files): 45 | ds = load_dataset("json", data_files=args.data_files) 46 | elif all(file.endswith(".jsonl") for file in args.data_files): 47 | ds = load_dataset("json", data_files=args.data_files) 48 | else: 49 | ds = load_dataset("parquet", data_files=args.data_files) 50 | url = ds.push_to_hub(args.hub_repo_id, config_name=args.config_name, private=True) 51 | print(f"Dataset available at: {url}") 52 | 53 | 54 | if __name__ == "__main__": 55 | main() 56 | -------------------------------------------------------------------------------- /train/stage_sft/api.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 the LlamaFactory team. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | import os 16 | 17 | import uvicorn 18 | 19 | from llamafactory.api.app import create_app 20 | from llamafactory.chat import ChatModel 21 | 22 | 23 | def main(): 24 | chat_model = ChatModel() 25 | app = create_app(chat_model) 26 | api_host = os.getenv("API_HOST", "0.0.0.0") 27 | api_port = int(os.getenv("API_PORT", "8000")) 28 | print(f"Visit http://localhost:{api_port}/docs for API document.") 29 | uvicorn.run(app, host=api_host, port=api_port) 30 | 31 | 32 | if __name__ == "__main__": 33 | main() 34 | -------------------------------------------------------------------------------- /train/stage_sft/data/dataset_info.json: -------------------------------------------------------------------------------- 1 | { 2 | "clevr_math_sft": { 3 | "file_name": "/path/to/your/clevr-math-train-sft.json", 4 | "formatting": "sharegpt", 5 | "columns": { 6 | "messages": "messages", 7 | "images": "images" 8 | }, 9 | "tags": { 10 | "role_tag": "role", 11 | "content_tag": "content", 12 | "user_tag": "user", 13 | "assistant_tag": "assistant" 14 | } 15 | }, 16 | "clevr_math_cot_sft": { 17 | "file_name": "/path/to/your/clevr-math-train-cot-sft.json", 18 | "formatting": "sharegpt", 19 | "columns": { 20 | "messages": "messages", 21 | "images": "images" 22 | }, 23 | "tags": { 24 | "role_tag": "role", 25 | "content_tag": "content", 26 | "user_tag": "user", 27 | "assistant_tag": "assistant" 28 | } 29 | }, 30 | "clevr_math_cot_sft-1k6": { 31 | "file_name": "/path/to/your/clevr-math-train-cot-sft-1k6.json", 32 | "formatting": "sharegpt", 33 | "columns": { 34 | "messages": "messages", 35 | "images": "images" 36 | }, 37 | "tags": { 38 | "role_tag": "role", 39 | "content_tag": "content", 40 | "user_tag": "user", 41 | "assistant_tag": "assistant" 42 | } 43 | }, 44 | "geo_math_sft": { 45 | "file_name": "/path/to/your/geo-math-train-sft.json", 46 | "formatting": "sharegpt", 47 | "columns": { 48 | "messages": "messages", 49 | "images": "images" 50 | }, 51 | "tags": { 52 | "role_tag": "role", 53 | "content_tag": "content", 54 | "user_tag": "user", 55 | "assistant_tag": "assistant" 56 | } 57 | }, 58 | "geo_math_cot_sft": { 59 | "file_name": "/path/to/your/geo-math-train-cot-sft.json", 60 | "formatting": "sharegpt", 61 | "columns": { 62 | "messages": "messages", 63 | "images": "images" 64 | }, 65 | "tags": { 66 | "role_tag": "role", 67 | "content_tag": "content", 68 | "user_tag": "user", 69 | "assistant_tag": "assistant" 70 | } 71 | }, 72 | "geo_math_cot_sft-1k6": { 73 | "file_name": "/path/to/your/geo-math-train-cot-sft-1k6.json", 74 | "formatting": "sharegpt", 75 | "columns": { 76 | "messages": "messages", 77 | "images": "images" 78 | }, 79 | "tags": { 80 | "role_tag": "role", 81 | "content_tag": "content", 82 | "user_tag": "user", 83 | "assistant_tag": "assistant" 84 | } 85 | }, 86 | "trance_sft": { 87 | "file_name": "/path/to/your/trance-train-sft.json", 88 | "formatting": "sharegpt", 89 | "columns": { 90 | "messages": "messages", 91 | "images": "images" 92 | }, 93 | "tags": { 94 | "role_tag": "role", 95 | "content_tag": "content", 96 | "user_tag": "user", 97 | "assistant_tag": "assistant" 98 | } 99 | }, 100 | "trance_cot_sft": { 101 | "file_name": "/path/to/your/trance-train-cot-sft.json", 102 | "formatting": "sharegpt", 103 | "columns": { 104 | "messages": "messages", 105 | "images": "images" 106 | }, 107 | "tags": { 108 | "role_tag": "role", 109 | "content_tag": "content", 110 | "user_tag": "user", 111 | "assistant_tag": "assistant" 112 | } 113 | }, 114 | "trance_cot_sft-1k6": { 115 | "file_name": 
"/path/to/your/trance-train-cot-sft-1k6.json", 116 | "formatting": "sharegpt", 117 | "columns": { 118 | "messages": "messages", 119 | "images": "images" 120 | }, 121 | "tags": { 122 | "role_tag": "role", 123 | "content_tag": "content", 124 | "user_tag": "user", 125 | "assistant_tag": "assistant" 126 | } 127 | }, 128 | "trance_add_caption_cot_sft": { 129 | "file_name": "/path/to/your/trance-add-caption-train-cot-sft.json", 130 | "formatting": "sharegpt", 131 | "columns": { 132 | "messages": "messages", 133 | "images": "images" 134 | }, 135 | "tags": { 136 | "role_tag": "role", 137 | "content_tag": "content", 138 | "user_tag": "user", 139 | "assistant_tag": "assistant" 140 | } 141 | } 142 | } 143 | -------------------------------------------------------------------------------- /train/stage_sft/llamafactory/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 the LlamaFactory team. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | r""" 16 | Efficient fine-tuning of large language models. 17 | 18 | Level: 19 | api, webui > chat, eval, train > data, model > hparams > extras 20 | 21 | Dependency graph: 22 | main: 23 | transformers>=4.41.2,<=4.46.1 24 | datasets>=2.16.0,<=3.1.0 25 | accelerate>=0.34.0,<=1.0.1 26 | peft>=0.11.1,<=0.12.0 27 | trl>=0.8.6,<=0.9.6 28 | attention: 29 | transformers>=4.42.4 (gemma+fa2) 30 | longlora: 31 | transformers>=4.41.2,<=4.46.1 32 | packing: 33 | transformers>=4.43.0,<=4.46.1 34 | 35 | Disable version checking: DISABLE_VERSION_CHECK=1 36 | Enable VRAM recording: RECORD_VRAM=1 37 | Force check imports: FORCE_CHECK_IMPORTS=1 38 | Force using torchrun: FORCE_TORCHRUN=1 39 | Set logging verbosity: LLAMAFACTORY_VERBOSITY=WARN 40 | Use modelscope: USE_MODELSCOPE_HUB=1 41 | Use openmind: USE_OPENMIND_HUB=1 42 | """ 43 | 44 | from .extras.env import VERSION 45 | 46 | 47 | __version__ = VERSION 48 | -------------------------------------------------------------------------------- /train/stage_sft/llamafactory/api/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tanhuajie/Reason-RFT/3180557dce2065e92754352dacf83eccb26ac032/train/stage_sft/llamafactory/api/__init__.py -------------------------------------------------------------------------------- /train/stage_sft/llamafactory/api/common.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 the LlamaFactory team. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import json 16 | from typing import TYPE_CHECKING, Any, Dict 17 | 18 | 19 | if TYPE_CHECKING: 20 | from pydantic import BaseModel 21 | 22 | 23 | def dictify(data: "BaseModel") -> Dict[str, Any]: 24 | try: # pydantic v2 25 | return data.model_dump(exclude_unset=True) 26 | except AttributeError: # pydantic v1 27 | return data.dict(exclude_unset=True) 28 | 29 | 30 | def jsonify(data: "BaseModel") -> str: 31 | try: # pydantic v2 32 | return json.dumps(data.model_dump(exclude_unset=True), ensure_ascii=False) 33 | except AttributeError: # pydantic v1 34 | return data.json(exclude_unset=True, ensure_ascii=False) 35 | -------------------------------------------------------------------------------- /train/stage_sft/llamafactory/chat/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 the LlamaFactory team. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from .base_engine import BaseEngine 16 | from .chat_model import ChatModel 17 | 18 | 19 | __all__ = ["BaseEngine", "ChatModel"] 20 | -------------------------------------------------------------------------------- /train/stage_sft/llamafactory/chat/base_engine.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 the LlamaFactory team. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from abc import ABC, abstractmethod 16 | from dataclasses import dataclass 17 | from typing import TYPE_CHECKING, Any, AsyncGenerator, Dict, List, Literal, Optional, Sequence, Union 18 | 19 | 20 | if TYPE_CHECKING: 21 | from transformers import PreTrainedModel, PreTrainedTokenizer 22 | from vllm import AsyncLLMEngine 23 | 24 | from ..data import Template 25 | from ..data.mm_plugin import ImageInput, VideoInput 26 | from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments 27 | 28 | 29 | @dataclass 30 | class Response: 31 | response_text: str 32 | response_length: int 33 | prompt_length: int 34 | finish_reason: Literal["stop", "length"] 35 | 36 | 37 | class BaseEngine(ABC): 38 | r""" 39 | Base class for inference engines of chat models. 40 | 41 | Must implement the async methods: chat(), stream_chat() and get_scores().
42 | """ 43 | 44 | model: Union["PreTrainedModel", "AsyncLLMEngine"] 45 | tokenizer: "PreTrainedTokenizer" 46 | can_generate: bool 47 | template: "Template" 48 | generating_args: Dict[str, Any] 49 | 50 | @abstractmethod 51 | def __init__( 52 | self, 53 | model_args: "ModelArguments", 54 | data_args: "DataArguments", 55 | finetuning_args: "FinetuningArguments", 56 | generating_args: "GeneratingArguments", 57 | ) -> None: 58 | r""" 59 | Initializes an inference engine. 60 | """ 61 | ... 62 | 63 | @abstractmethod 64 | async def chat( 65 | self, 66 | messages: Sequence[Dict[str, str]], 67 | system: Optional[str] = None, 68 | tools: Optional[str] = None, 69 | images: Optional[Sequence["ImageInput"]] = None, 70 | videos: Optional[Sequence["VideoInput"]] = None, 71 | **input_kwargs, 72 | ) -> List["Response"]: 73 | r""" 74 | Gets a list of responses of the chat model. 75 | """ 76 | ... 77 | 78 | @abstractmethod 79 | async def stream_chat( 80 | self, 81 | messages: Sequence[Dict[str, str]], 82 | system: Optional[str] = None, 83 | tools: Optional[str] = None, 84 | images: Optional[Sequence["ImageInput"]] = None, 85 | videos: Optional[Sequence["VideoInput"]] = None, 86 | **input_kwargs, 87 | ) -> AsyncGenerator[str, None]: 88 | r""" 89 | Gets the response token-by-token of the chat model. 90 | """ 91 | ... 92 | 93 | @abstractmethod 94 | async def get_scores( 95 | self, 96 | batch_input: List[str], 97 | **input_kwargs, 98 | ) -> List[float]: 99 | r""" 100 | Gets a list of scores of the reward model. 101 | """ 102 | ... 103 | -------------------------------------------------------------------------------- /train/stage_sft/llamafactory/data/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 the LlamaFactory team. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from .collator import ( 16 | KTODataCollatorWithPadding, 17 | MultiModalDataCollatorForSeq2Seq, 18 | PairwiseDataCollatorWithPadding, 19 | SFTDataCollatorWith4DAttentionMask, 20 | ) 21 | from .data_utils import Role, split_dataset 22 | from .loader import get_dataset 23 | from .template import TEMPLATES, Template, get_template_and_fix_tokenizer 24 | 25 | 26 | __all__ = [ 27 | "KTODataCollatorWithPadding", 28 | "MultiModalDataCollatorForSeq2Seq", 29 | "PairwiseDataCollatorWithPadding", 30 | "SFTDataCollatorWith4DAttentionMask", 31 | "Role", 32 | "split_dataset", 33 | "get_dataset", 34 | "TEMPLATES", 35 | "Template", 36 | "get_template_and_fix_tokenizer", 37 | ] 38 | -------------------------------------------------------------------------------- /train/stage_sft/llamafactory/data/data_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 the LlamaFactory team. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from enum import Enum, unique 16 | from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Set, TypedDict, Union 17 | 18 | from datasets import DatasetDict, concatenate_datasets, interleave_datasets 19 | 20 | from ..extras import logging 21 | 22 | 23 | if TYPE_CHECKING: 24 | from datasets import Dataset, IterableDataset 25 | 26 | from ..hparams import DataArguments 27 | 28 | 29 | logger = logging.get_logger(__name__) 30 | 31 | 32 | SLOTS = Sequence[Union[str, Set[str], Dict[str, str]]] 33 | 34 | 35 | @unique 36 | class Role(str, Enum): 37 | USER = "user" 38 | ASSISTANT = "assistant" 39 | SYSTEM = "system" 40 | FUNCTION = "function" 41 | OBSERVATION = "observation" 42 | 43 | 44 | class DatasetModule(TypedDict): 45 | train_dataset: Optional[Union["Dataset", "IterableDataset"]] 46 | eval_dataset: Optional[Union["Dataset", "IterableDataset"]] 47 | 48 | 49 | def merge_dataset( 50 | all_datasets: List[Union["Dataset", "IterableDataset"]], data_args: "DataArguments", seed: int 51 | ) -> Union["Dataset", "IterableDataset"]: 52 | r""" 53 | Merges multiple datasets into a unified dataset. 54 | """ 55 | if len(all_datasets) == 1: 56 | return all_datasets[0] 57 | elif data_args.mix_strategy == "concat": 58 | if data_args.streaming: 59 | logger.warning_rank0_once("Samples from different datasets will not be mixed in streaming mode.") 60 | 61 | return concatenate_datasets(all_datasets) 62 | elif data_args.mix_strategy.startswith("interleave"): 63 | if not data_args.streaming: 64 | logger.warning_rank0_once("We recommend using `mix_strategy=concat` in non-streaming mode.") 65 | 66 | return interleave_datasets( 67 | datasets=all_datasets, 68 | probabilities=data_args.interleave_probs, 69 | seed=seed, 70 | stopping_strategy="first_exhausted" if data_args.mix_strategy.endswith("under") else "all_exhausted", 71 | ) 72 | else: 73 | raise ValueError(f"Unknown mixing strategy: {data_args.mix_strategy}.") 74 | 75 | 76 | def split_dataset( 77 | dataset: Union["Dataset", "IterableDataset"], data_args: "DataArguments", seed: int 78 | ) -> "DatasetDict": 79 | r""" 80 | Splits the dataset and returns a dataset dict containing the train set and the validation set. 81 | 82 | Supports both map-style and iterable datasets.
83 | """ 84 | if data_args.streaming: 85 | dataset = dataset.shuffle(buffer_size=data_args.buffer_size, seed=seed) 86 | val_set = dataset.take(int(data_args.val_size)) 87 | train_set = dataset.skip(int(data_args.val_size)) 88 | return DatasetDict({"train": train_set, "validation": val_set}) 89 | else: 90 | val_size = int(data_args.val_size) if data_args.val_size > 1 else data_args.val_size 91 | dataset = dataset.train_test_split(test_size=val_size, seed=seed) 92 | return DatasetDict({"train": dataset["train"], "validation": dataset["test"]}) 93 | -------------------------------------------------------------------------------- /train/stage_sft/llamafactory/data/processors/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tanhuajie/Reason-RFT/3180557dce2065e92754352dacf83eccb26ac032/train/stage_sft/llamafactory/data/processors/__init__.py -------------------------------------------------------------------------------- /train/stage_sft/llamafactory/data/processors/pretrain.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 HuggingFace Inc. and the LlamaFactory team. 2 | # 3 | # This code is inspired by the HuggingFace's transformers library. 4 | # https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/language-modeling/run_clm.py 5 | # 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at 9 | # 10 | # http://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # Unless required by applicable law or agreed to in writing, software 13 | # distributed under the License is distributed on an "AS IS" BASIS, 14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | # See the License for the specific language governing permissions and 16 | # limitations under the License. 
17 | 18 | from itertools import chain 19 | from typing import TYPE_CHECKING, Any, Dict, List 20 | 21 | 22 | if TYPE_CHECKING: 23 | from transformers import PreTrainedTokenizer 24 | 25 | from ...hparams import DataArguments 26 | 27 | 28 | def preprocess_pretrain_dataset( 29 | examples: Dict[str, List[Any]], tokenizer: "PreTrainedTokenizer", data_args: "DataArguments" 30 | ) -> Dict[str, List[Any]]: 31 | # build grouped texts with format `X1 X2 X3 ...` if packing is enabled 32 | eos_token = "<|end_of_text|>" if data_args.template == "llama3" else tokenizer.eos_token 33 | text_examples = [messages[0]["content"] + eos_token for messages in examples["_prompt"]] 34 | 35 | if not data_args.packing: 36 | if getattr(tokenizer, "add_bos_token", False): 37 | text_examples = [tokenizer.bos_token + example for example in text_examples] 38 | 39 | result = tokenizer(text_examples, add_special_tokens=False, truncation=True, max_length=data_args.cutoff_len) 40 | else: 41 | tokenized_examples = tokenizer(text_examples, add_special_tokens=False) 42 | concatenated_examples = {k: list(chain(*tokenized_examples[k])) for k in tokenized_examples.keys()} 43 | total_length = len(concatenated_examples[list(concatenated_examples.keys())[0]]) 44 | block_size = data_args.cutoff_len 45 | total_length = (total_length // block_size) * block_size 46 | result = { 47 | k: [t[i : i + block_size] for i in range(0, total_length, block_size)] 48 | for k, t in concatenated_examples.items() 49 | } 50 | if getattr(tokenizer, "add_bos_token", False): 51 | for i in range(len(result["input_ids"])): 52 | result["input_ids"][i][0] = tokenizer.bos_token_id 53 | 54 | return result 55 | 56 | 57 | def print_pretrain_dataset_example(example: Dict[str, List[int]], tokenizer: "PreTrainedTokenizer") -> None: 58 | print("input_ids:\n{}".format(example["input_ids"])) 59 | print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False))) 60 | -------------------------------------------------------------------------------- /train/stage_sft/llamafactory/data/processors/processor_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 the LlamaFactory team. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import bisect 16 | from typing import List, Sequence, Tuple 17 | 18 | 19 | def search_for_fit(numbers: Sequence[int], capacity: int) -> int: 20 | r""" 21 | Finds the index of the largest number that fits into the knapsack with the given capacity. 22 | """ 23 | index = bisect.bisect(numbers, capacity) 24 | return -1 if index == 0 else (index - 1) 25 | 26 | 27 | def greedy_knapsack(numbers: List[int], capacity: int) -> List[List[int]]: 28 | r""" 29 | An efficient greedy algorithm with binary search for the knapsack problem.
30 | """ 31 | numbers.sort() # sort numbers in ascending order for binary search 32 | knapsacks = [] 33 | 34 | while numbers: 35 | current_knapsack = [] 36 | remaining_capacity = capacity 37 | 38 | while True: 39 | index = search_for_fit(numbers, remaining_capacity) 40 | if index == -1: 41 | break # no more numbers fit in this knapsack 42 | 43 | remaining_capacity -= numbers[index] # update the remaining capacity 44 | current_knapsack.append(numbers.pop(index)) # add the number to knapsack 45 | 46 | knapsacks.append(current_knapsack) 47 | 48 | return knapsacks 49 | 50 | 51 | def infer_seqlen(source_len: int, target_len: int, cutoff_len: int) -> Tuple[int, int]: 52 | r""" 53 | Computes the real sequence length after truncation by the cutoff_len. 54 | """ 55 | if target_len * 2 < cutoff_len: # truncate source 56 | max_target_len = cutoff_len 57 | elif source_len * 2 < cutoff_len: # truncate target 58 | max_target_len = cutoff_len - source_len 59 | else: # truncate both 60 | max_target_len = int(cutoff_len * (target_len / (source_len + target_len))) 61 | 62 | new_target_len = min(max_target_len, target_len) 63 | max_source_len = max(cutoff_len - new_target_len, 0) 64 | new_source_len = min(max_source_len, source_len) 65 | return new_source_len, new_target_len 66 | -------------------------------------------------------------------------------- /train/stage_sft/llamafactory/eval/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tanhuajie/Reason-RFT/3180557dce2065e92754352dacf83eccb26ac032/train/stage_sft/llamafactory/eval/__init__.py -------------------------------------------------------------------------------- /train/stage_sft/llamafactory/eval/template.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 the LlamaFactory team. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from dataclasses import dataclass 16 | from typing import Dict, List, Sequence, Tuple 17 | 18 | from ..data import Role 19 | from ..extras.constants import CHOICES 20 | 21 | 22 | @dataclass 23 | class EvalTemplate: 24 | system: str 25 | choice: str 26 | answer: str 27 | 28 | def _parse_example(self, example: Dict[str, str]) -> Tuple[str, str]: 29 | r""" 30 | input: a dict with keys {"question", "A", "B", "C", "D", "answer"} 31 | output: a tuple of (prompt, response) 32 | """ 33 | candidates = [self.choice.format(choice=ch, content=example[ch]) for ch in CHOICES if ch in example] 34 | return "".join([example["question"]] + candidates + [self.answer]), example["answer"] 35 | 36 | def format_example( 37 | self, target_data: Dict[str, str], support_set: Sequence[Dict[str, str]], subject_name: str 38 | ) -> List[Dict[str, str]]: 39 | r""" 40 | Converts dataset examples to messages. 
41 | """ 42 | messages = [] 43 | for k in range(len(support_set)): 44 | prompt, response = self._parse_example(support_set[k]) 45 | messages.append({"role": Role.USER.value, "content": prompt}) 46 | messages.append({"role": Role.ASSISTANT.value, "content": response}) 47 | 48 | prompt, response = self._parse_example(target_data) 49 | messages.append({"role": Role.USER.value, "content": prompt}) 50 | messages.append({"role": Role.ASSISTANT.value, "content": response}) 51 | messages[0]["content"] = self.system.format(subject=subject_name) + messages[0]["content"] 52 | return messages 53 | 54 | 55 | eval_templates: Dict[str, "EvalTemplate"] = {} 56 | 57 | 58 | def _register_eval_template(name: str, system: str, choice: str, answer: str) -> None: 59 | eval_templates[name] = EvalTemplate(system=system, choice=choice, answer=answer) 60 | 61 | 62 | def get_eval_template(name: str) -> "EvalTemplate": 63 | eval_template = eval_templates.get(name, None) 64 | assert eval_template is not None, f"Template {name} does not exist." 65 | return eval_template 66 | 67 | 68 | _register_eval_template( 69 | name="en", 70 | system="The following are multiple choice questions (with answers) about {subject}.\n\n", 71 | choice="\n{choice}. {content}", 72 | answer="\nAnswer:", 73 | ) 74 | 75 | 76 | _register_eval_template( 77 | name="zh", 78 | system="以下是中国关于{subject}考试的单项选择题,请选出其中的正确答案。\n\n", 79 | choice="\n{choice}. {content}", 80 | answer="\n答案:", 81 | ) 82 | -------------------------------------------------------------------------------- /train/stage_sft/llamafactory/extras/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tanhuajie/Reason-RFT/3180557dce2065e92754352dacf83eccb26ac032/train/stage_sft/llamafactory/extras/__init__.py -------------------------------------------------------------------------------- /train/stage_sft/llamafactory/extras/env.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 HuggingFace Inc. and the LlamaFactory team. 2 | # 3 | # This code is inspired by the HuggingFace's transformers library. 4 | # https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/commands/env.py 5 | # 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at 9 | # 10 | # http://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # Unless required by applicable law or agreed to in writing, software 13 | # distributed under the License is distributed on an "AS IS" BASIS, 14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | # See the License for the specific language governing permissions and 16 | # limitations under the License. 
17 | 18 | import platform 19 | 20 | import accelerate 21 | import datasets 22 | import peft 23 | import torch 24 | import transformers 25 | import trl 26 | from transformers.utils import is_torch_cuda_available, is_torch_npu_available 27 | 28 | 29 | VERSION = "0.9.2.dev0" 30 | 31 | 32 | def print_env() -> None: 33 | info = { 34 | "`llamafactory` version": VERSION, 35 | "Platform": platform.platform(), 36 | "Python version": platform.python_version(), 37 | "PyTorch version": torch.__version__, 38 | "Transformers version": transformers.__version__, 39 | "Datasets version": datasets.__version__, 40 | "Accelerate version": accelerate.__version__, 41 | "PEFT version": peft.__version__, 42 | "TRL version": trl.__version__, 43 | } 44 | 45 | if is_torch_cuda_available(): 46 | info["PyTorch version"] += " (GPU)" 47 | info["GPU type"] = torch.cuda.get_device_name() 48 | 49 | if is_torch_npu_available(): 50 | info["PyTorch version"] += " (NPU)" 51 | info["NPU type"] = torch.npu.get_device_name() 52 | info["CANN version"] = torch.version.cann 53 | 54 | try: 55 | import deepspeed # type: ignore 56 | 57 | info["DeepSpeed version"] = deepspeed.__version__ 58 | except Exception: 59 | pass 60 | 61 | try: 62 | import bitsandbytes 63 | 64 | info["Bitsandbytes version"] = bitsandbytes.__version__ 65 | except Exception: 66 | pass 67 | 68 | try: 69 | import vllm 70 | 71 | info["vLLM version"] = vllm.__version__ 72 | except Exception: 73 | pass 74 | 75 | print("\n" + "\n".join([f"- {key}: {value}" for key, value in info.items()]) + "\n") 76 | -------------------------------------------------------------------------------- /train/stage_sft/llamafactory/extras/packages.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 HuggingFace Inc. and the LlamaFactory team. 2 | # 3 | # This code is inspired by the HuggingFace's transformers library. 4 | # https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/utils/import_utils.py 5 | # 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at 9 | # 10 | # http://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # Unless required by applicable law or agreed to in writing, software 13 | # distributed under the License is distributed on an "AS IS" BASIS, 14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | # See the License for the specific language governing permissions and 16 | # limitations under the License. 
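To make the eval templates in eval/template.py above concrete, here is how the registered "en" template would render a single zero-shot item (the question, choices, and subject are invented):

```python
# Rendering sketch for the "en" EvalTemplate registered above.
from llamafactory.eval.template import get_eval_template

template = get_eval_template("en")
example = {"question": "2 + 2 = ?", "A": "3", "B": "4", "C": "5", "D": "6", "answer": "B"}
messages = template.format_example(example, support_set=[], subject_name="math")
# messages[0]["content"] ==
#   "The following are multiple choice questions (with answers) about math.\n\n"
#   "2 + 2 = ?\nA. 3\nB. 4\nC. 5\nD. 6\nAnswer:"
# messages[1] == {"role": "assistant", "content": "B"}
```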
17 | 18 | import importlib.metadata 19 | import importlib.util 20 | from functools import lru_cache 21 | from typing import TYPE_CHECKING 22 | 23 | from packaging import version 24 | 25 | 26 | if TYPE_CHECKING: 27 | from packaging.version import Version 28 | 29 | 30 | def _is_package_available(name: str) -> bool: 31 | return importlib.util.find_spec(name) is not None 32 | 33 | 34 | def _get_package_version(name: str) -> "Version": 35 | try: 36 | return version.parse(importlib.metadata.version(name)) 37 | except Exception: 38 | return version.parse("0.0.0") 39 | 40 | 41 | def is_pyav_available(): 42 | return _is_package_available("av") 43 | 44 | 45 | def is_fastapi_available(): 46 | return _is_package_available("fastapi") 47 | 48 | 49 | def is_galore_available(): 50 | return _is_package_available("galore_torch") 51 | 52 | 53 | def is_apollo_available(): 54 | return _is_package_available("apollo_torch") 55 | 56 | 57 | def is_gradio_available(): 58 | return _is_package_available("gradio") 59 | 60 | 61 | def is_matplotlib_available(): 62 | return _is_package_available("matplotlib") 63 | 64 | 65 | def is_pillow_available(): 66 | return _is_package_available("PIL") 67 | 68 | 69 | def is_ray_available(): 70 | return _is_package_available("ray") 71 | 72 | 73 | def is_requests_available(): 74 | return _is_package_available("requests") 75 | 76 | 77 | def is_rouge_available(): 78 | return _is_package_available("rouge_chinese") 79 | 80 | 81 | def is_starlette_available(): 82 | return _is_package_available("sse_starlette") 83 | 84 | 85 | @lru_cache 86 | def is_transformers_version_greater_than(content: str): 87 | return _get_package_version("transformers") >= version.parse(content) 88 | 89 | 90 | @lru_cache 91 | def is_transformers_version_equal_to_4_46(): 92 | return version.parse("4.46.0") <= _get_package_version("transformers") <= version.parse("4.46.1") 93 | 94 | 95 | def is_uvicorn_available(): 96 | return _is_package_available("uvicorn") 97 | 98 | 99 | def is_vllm_available(): 100 | return _is_package_available("vllm") 101 | -------------------------------------------------------------------------------- /train/stage_sft/llamafactory/extras/ploting.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 the LlamaFactory team. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import json 16 | import math 17 | import os 18 | from typing import Any, Dict, List 19 | 20 | from transformers.trainer import TRAINER_STATE_NAME 21 | 22 | from . import logging 23 | from .packages import is_matplotlib_available 24 | 25 | 26 | if is_matplotlib_available(): 27 | import matplotlib.figure 28 | import matplotlib.pyplot as plt 29 | 30 | 31 | logger = logging.get_logger(__name__) 32 | 33 | 34 | def smooth(scalars: List[float]) -> List[float]: 35 | r""" 36 | EMA implementation according to TensorBoard. 
37 | """ 38 | if len(scalars) == 0: 39 | return [] 40 | 41 | last = scalars[0] 42 | smoothed = [] 43 | weight = 1.8 * (1 / (1 + math.exp(-0.05 * len(scalars))) - 0.5) # a sigmoid function 44 | for next_val in scalars: 45 | smoothed_val = last * weight + (1 - weight) * next_val 46 | smoothed.append(smoothed_val) 47 | last = smoothed_val 48 | return smoothed 49 | 50 | 51 | def gen_loss_plot(trainer_log: List[Dict[str, Any]]) -> "matplotlib.figure.Figure": 52 | r""" 53 | Plots loss curves in LlamaBoard. 54 | """ 55 | plt.close("all") 56 | plt.switch_backend("agg") 57 | fig = plt.figure() 58 | ax = fig.add_subplot(111) 59 | steps, losses = [], [] 60 | for log in trainer_log: 61 | if log.get("loss", None): 62 | steps.append(log["current_steps"]) 63 | losses.append(log["loss"]) 64 | 65 | ax.plot(steps, losses, color="#1f77b4", alpha=0.4, label="original") 66 | ax.plot(steps, smooth(losses), color="#1f77b4", label="smoothed") 67 | ax.legend() 68 | ax.set_xlabel("step") 69 | ax.set_ylabel("loss") 70 | return fig 71 | 72 | 73 | def plot_loss(save_dictionary: str, keys: List[str] = ["loss"]) -> None: 74 | r""" 75 | Plots loss curves and saves the image. 76 | """ 77 | plt.switch_backend("agg") 78 | with open(os.path.join(save_dictionary, TRAINER_STATE_NAME), encoding="utf-8") as f: 79 | data = json.load(f) 80 | 81 | for key in keys: 82 | steps, metrics = [], [] 83 | for i in range(len(data["log_history"])): 84 | if key in data["log_history"][i]: 85 | steps.append(data["log_history"][i]["step"]) 86 | metrics.append(data["log_history"][i][key]) 87 | 88 | if len(metrics) == 0: 89 | logger.warning_rank0(f"No metric {key} to plot.") 90 | continue 91 | 92 | plt.figure() 93 | plt.plot(steps, metrics, color="#1f77b4", alpha=0.4, label="original") 94 | plt.plot(steps, smooth(metrics), color="#1f77b4", label="smoothed") 95 | plt.title(f"training {key} of {save_dictionary}") 96 | plt.xlabel("step") 97 | plt.ylabel(key) 98 | plt.legend() 99 | figure_path = os.path.join(save_dictionary, "training_{}.png".format(key.replace("/", "_"))) 100 | plt.savefig(figure_path, format="png", dpi=100) 101 | print("Figure saved at:", figure_path) 102 | -------------------------------------------------------------------------------- /train/stage_sft/llamafactory/hparams/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 the LlamaFactory team. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | from .data_args import DataArguments 16 | from .evaluation_args import EvaluationArguments 17 | from .finetuning_args import FinetuningArguments 18 | from .generating_args import GeneratingArguments 19 | from .model_args import ModelArguments 20 | from .parser import get_eval_args, get_infer_args, get_ray_args, get_train_args, read_args 21 | from .training_args import RayArguments, TrainingArguments 22 | 23 | 24 | __all__ = [ 25 | "DataArguments", 26 | "EvaluationArguments", 27 | "FinetuningArguments", 28 | "GeneratingArguments", 29 | "ModelArguments", 30 | "RayArguments", 31 | "TrainingArguments", 32 | "get_eval_args", 33 | "get_infer_args", 34 | "get_ray_args", 35 | "get_train_args", 36 | "read_args", 37 | ] 38 | -------------------------------------------------------------------------------- /train/stage_sft/llamafactory/hparams/evaluation_args.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 the LlamaFactory team. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import os 16 | from dataclasses import dataclass, field 17 | from typing import Literal, Optional 18 | 19 | from datasets import DownloadMode 20 | 21 | 22 | @dataclass 23 | class EvaluationArguments: 24 | r""" 25 | Arguments pertaining to the evaluation parameters. 26 | """ 27 | 28 | task: str = field( 29 | metadata={"help": "Name of the evaluation task."}, 30 | ) 31 | task_dir: str = field( 32 | default="evaluation", 33 | metadata={"help": "Path to the folder containing the evaluation datasets."}, 34 | ) 35 | batch_size: int = field( 36 | default=4, 37 | metadata={"help": "The batch size per GPU for evaluation."}, 38 | ) 39 | seed: int = field( 40 | default=42, 41 | metadata={"help": "Random seed to be used with data loaders."}, 42 | ) 43 | lang: Literal["en", "zh"] = field( 44 | default="en", 45 | metadata={"help": "Language used for evaluation."}, 46 | ) 47 | n_shot: int = field( 48 | default=5, 49 | metadata={"help": "Number of exemplars for few-shot learning."}, 50 | ) 51 | save_dir: Optional[str] = field( 52 | default=None, 53 | metadata={"help": "Path to save the evaluation results."}, 54 | ) 55 | download_mode: DownloadMode = field( 56 | default=DownloadMode.REUSE_DATASET_IF_EXISTS, 57 | metadata={"help": "Download mode used for the evaluation datasets."}, 58 | ) 59 | 60 | def __post_init__(self): 61 | if self.save_dir is not None and os.path.exists(self.save_dir): 62 | raise ValueError("`save_dir` already exists, use another one.") 63 | -------------------------------------------------------------------------------- /train/stage_sft/llamafactory/hparams/generating_args.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 the LlamaFactory team. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from dataclasses import asdict, dataclass, field 16 | from typing import Any, Dict, Optional 17 | 18 | from transformers import GenerationConfig 19 | 20 | 21 | @dataclass 22 | class GeneratingArguments: 23 | r""" 24 | Arguments pertaining to the decoding parameters. 25 | """ 26 | 27 | do_sample: bool = field( 28 | default=True, 29 | metadata={"help": "Whether or not to use sampling, use greedy decoding otherwise."}, 30 | ) 31 | temperature: float = field( 32 | default=0.95, 33 | metadata={"help": "The value used to modulate the next token probabilities."}, 34 | ) 35 | top_p: float = field( 36 | default=0.7, 37 | metadata={ 38 | "help": "The smallest set of most probable tokens with probabilities that add up to top_p or higher is kept." 39 | }, 40 | ) 41 | top_k: int = field( 42 | default=50, 43 | metadata={"help": "The number of highest probability vocabulary tokens to keep for top-k filtering."}, 44 | ) 45 | num_beams: int = field( 46 | default=1, 47 | metadata={"help": "Number of beams for beam search. 1 means no beam search."}, 48 | ) 49 | max_length: int = field( 50 | default=1024, 51 | metadata={"help": "The maximum length the generated tokens can have. It can be overridden by max_new_tokens."}, 52 | ) 53 | max_new_tokens: int = field( 54 | default=1024, 55 | metadata={"help": "The maximum number of tokens to generate, ignoring the number of tokens in the prompt."}, 56 | ) 57 | repetition_penalty: float = field( 58 | default=1.0, 59 | metadata={"help": "The parameter for repetition penalty. 1.0 means no penalty."}, 60 | ) 61 | length_penalty: float = field( 62 | default=1.0, 63 | metadata={"help": "Exponential penalty to the length that is used with beam-based generation."}, 64 | ) 65 | default_system: Optional[str] = field( 66 | default=None, 67 | metadata={"help": "Default system message to use in chat completion."}, 68 | ) 69 | skip_special_tokens: bool = field( 70 | default=True, 71 | metadata={"help": "Whether or not to remove special tokens during decoding."}, 72 | ) 73 | 74 | def to_dict(self, obey_generation_config: bool = False) -> Dict[str, Any]: 75 | args = asdict(self) 76 | if args.get("max_new_tokens", -1) > 0: 77 | args.pop("max_length", None) 78 | else: 79 | args.pop("max_new_tokens", None) 80 | 81 | if obey_generation_config: 82 | generation_config = GenerationConfig() 83 | for key in list(args.keys()): 84 | if not hasattr(generation_config, key): 85 | args.pop(key) 86 | 87 | return args 88 | -------------------------------------------------------------------------------- /train/stage_sft/llamafactory/hparams/training_args.py: -------------------------------------------------------------------------------- 1 | import json 2 | from dataclasses import dataclass, field 3 | from typing import Literal, Optional, Union 4 | 5 | from transformers import Seq2SeqTrainingArguments 6 | from transformers.training_args import _convert_str_dict 7 | 8 | from ..extras.misc import use_ray 9 | 10 | 11 | @dataclass 12 | class RayArguments: 13 | r""" 14 | Arguments pertaining to the Ray training.
15 | """ 16 | 17 | ray_run_name: Optional[str] = field( 18 | default=None, 19 | metadata={"help": "The training results will be saved at `saves/ray_run_name`."}, 20 | ) 21 | ray_num_workers: int = field( 22 | default=1, 23 | metadata={"help": "The number of workers for Ray training. Default is 1 worker."}, 24 | ) 25 | resources_per_worker: Union[dict, str] = field( 26 | default_factory=lambda: {"GPU": 1}, 27 | metadata={"help": "The resources per worker for Ray training. Default is to use 1 GPU per worker."}, 28 | ) 29 | placement_strategy: Literal["SPREAD", "PACK", "STRICT_SPREAD", "STRICT_PACK"] = field( 30 | default="PACK", 31 | metadata={"help": "The placement strategy for Ray training. Default is PACK."}, 32 | ) 33 | 34 | def __post_init__(self): 35 | self.use_ray = use_ray() 36 | if isinstance(self.resources_per_worker, str) and self.resources_per_worker.startswith("{"): 37 | self.resources_per_worker = _convert_str_dict(json.loads(self.resources_per_worker)) 38 | 39 | 40 | @dataclass 41 | class TrainingArguments(RayArguments, Seq2SeqTrainingArguments): 42 | r""" 43 | Arguments pertaining to the trainer. 44 | """ 45 | 46 | def __post_init__(self): 47 | Seq2SeqTrainingArguments.__post_init__(self) 48 | RayArguments.__post_init__(self) 49 | -------------------------------------------------------------------------------- /train/stage_sft/llamafactory/launcher.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 the LlamaFactory team. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from llamafactory.train.tuner import run_exp # use absolute import 16 | 17 | 18 | def launch(): 19 | run_exp() 20 | 21 | 22 | if __name__ == "__main__": 23 | launch() 24 | -------------------------------------------------------------------------------- /train/stage_sft/llamafactory/model/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 the LlamaFactory team. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | from .loader import load_config, load_model, load_tokenizer 16 | from .model_utils.misc import find_all_linear_modules 17 | from .model_utils.quantization import QuantizationMethod 18 | from .model_utils.valuehead import load_valuehead_params 19 | 20 | 21 | __all__ = [ 22 | "QuantizationMethod", 23 | "load_config", 24 | "load_model", 25 | "load_tokenizer", 26 | "find_all_linear_modules", 27 | "load_valuehead_params", 28 | ] 29 | -------------------------------------------------------------------------------- /train/stage_sft/llamafactory/model/model_utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tanhuajie/Reason-RFT/3180557dce2065e92754352dacf83eccb26ac032/train/stage_sft/llamafactory/model/model_utils/__init__.py -------------------------------------------------------------------------------- /train/stage_sft/llamafactory/model/model_utils/attention.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 the LlamaFactory team. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from typing import TYPE_CHECKING 16 | 17 | from transformers.utils import is_flash_attn_2_available, is_torch_sdpa_available 18 | 19 | from ...extras import logging 20 | from ...extras.misc import check_version 21 | 22 | 23 | if TYPE_CHECKING: 24 | from transformers import PretrainedConfig 25 | 26 | from ...hparams import ModelArguments 27 | 28 | 29 | logger = logging.get_logger(__name__) 30 | 31 | 32 | def configure_attn_implementation( 33 | config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool 34 | ) -> None: 35 | if getattr(config, "model_type", None) == "gemma2" and is_trainable: 36 | if model_args.flash_attn == "auto" or model_args.flash_attn == "fa2": 37 | if is_flash_attn_2_available(): 38 | check_version("transformers>=4.42.4") 39 | check_version("flash_attn>=2.6.3") 40 | if model_args.flash_attn != "fa2": 41 | logger.warning_rank0("Gemma-2 should use flash attention 2, change `flash_attn` to fa2.") 42 | model_args.flash_attn = "fa2" 43 | else: 44 | logger.warning_rank0("FlashAttention-2 is not installed, use eager attention.") 45 | model_args.flash_attn = "disabled" 46 | elif model_args.flash_attn == "sdpa": 47 | logger.warning_rank0( 48 | "Gemma-2 should use soft-capping attention, while the SDPA attention does not support it." 
49 | ) 50 | 51 | if model_args.flash_attn == "auto": 52 | return 53 | 54 | elif model_args.flash_attn == "disabled": 55 | requested_attn_implementation = "eager" 56 | 57 | elif model_args.flash_attn == "sdpa": 58 | if not is_torch_sdpa_available(): 59 | logger.warning_rank0("torch>=2.1.1 is required for SDPA attention.") 60 | return 61 | 62 | requested_attn_implementation = "sdpa" 63 | elif model_args.flash_attn == "fa2": 64 | if not is_flash_attn_2_available(): 65 | logger.warning_rank0("FlashAttention-2 is not installed.") 66 | return 67 | 68 | requested_attn_implementation = "flash_attention_2" 69 | else: 70 | raise NotImplementedError(f"Unknown attention type: {model_args.flash_attn}") 71 | 72 | if getattr(config, "model_type", None) == "internlm2": # special case for custom models 73 | setattr(config, "attn_implementation", requested_attn_implementation) 74 | else: 75 | setattr(config, "_attn_implementation", requested_attn_implementation) 76 | 77 | 78 | def print_attn_implementation(config: "PretrainedConfig") -> None: 79 | if getattr(config, "model_type", None) == "internlm2": # special case for custom models 80 | attn_implementation = getattr(config, "attn_implementation", None) 81 | else: 82 | attn_implementation = getattr(config, "_attn_implementation", None) 83 | 84 | if attn_implementation == "flash_attention_2": 85 | logger.info_rank0("Using FlashAttention-2 for faster training and inference.") 86 | elif attn_implementation == "sdpa": 87 | logger.info_rank0("Using torch SDPA for faster training and inference.") 88 | else: 89 | logger.info_rank0("Using vanilla attention implementation.") 90 | -------------------------------------------------------------------------------- /train/stage_sft/llamafactory/model/model_utils/embedding.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 the LlamaFactory team. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import math 16 | from contextlib import nullcontext 17 | from typing import TYPE_CHECKING 18 | 19 | import torch 20 | from transformers.integrations import is_deepspeed_zero3_enabled 21 | 22 | from ...extras import logging 23 | 24 | 25 | if TYPE_CHECKING: 26 | from transformers import PreTrainedModel, PreTrainedTokenizer 27 | 28 | 29 | logger = logging.get_logger(__name__) 30 | 31 | 32 | def _noisy_mean_initialization(embed_weight: "torch.Tensor", num_new_tokens: int) -> None: 33 | embedding_dim = embed_weight.size(1) 34 | avg_weight = embed_weight[:-num_new_tokens].mean(dim=0, keepdim=True) 35 | noise_weight = torch.empty_like(embed_weight[-num_new_tokens:]) 36 | noise_weight.normal_(mean=0, std=(1.0 / math.sqrt(embedding_dim))) 37 | embed_weight[-num_new_tokens:] = avg_weight + noise_weight 38 | 39 | 40 | def resize_embedding_layer(model: "PreTrainedModel", tokenizer: "PreTrainedTokenizer") -> None: 41 | r""" 42 | Resize token embeddings. 
43 | """ 44 | if is_deepspeed_zero3_enabled(): 45 | import deepspeed # type: ignore 46 | 47 | params = [model.get_input_embeddings().weight] 48 | if model.get_output_embeddings() is not None and not model.config.tie_word_embeddings: 49 | params.append(model.get_output_embeddings().weight) 50 | 51 | context_maybe_zero3 = deepspeed.zero.GatheredParameters(params, modifier_rank=0) 52 | else: 53 | context_maybe_zero3 = nullcontext() 54 | 55 | with context_maybe_zero3: 56 | current_embedding_size = model.get_input_embeddings().weight.size(0) 57 | 58 | if len(tokenizer) > current_embedding_size: 59 | if getattr(model, "quantization_method", None): 60 | raise ValueError("Cannot resize embedding layers of a quantized model.") 61 | 62 | if not isinstance(model.get_output_embeddings(), torch.nn.Linear): 63 | raise ValueError("Current model does not support resizing embedding layers.") 64 | 65 | model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=64) 66 | with context_maybe_zero3: 67 | new_embedding_size = model.get_input_embeddings().weight.size(0) 68 | num_new_tokens = new_embedding_size - current_embedding_size 69 | _noisy_mean_initialization(model.get_input_embeddings().weight.data, num_new_tokens) 70 | _noisy_mean_initialization(model.get_output_embeddings().weight.data, num_new_tokens) 71 | 72 | logger.info_rank0(f"Resized token embeddings from {current_embedding_size} to {new_embedding_size}.") 73 | -------------------------------------------------------------------------------- /train/stage_sft/llamafactory/model/model_utils/liger_kernel.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 the LlamaFactory team. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | import inspect 16 | from typing import TYPE_CHECKING 17 | 18 | from ...extras import logging 19 | 20 | 21 | if TYPE_CHECKING: 22 | from transformers import PretrainedConfig 23 | 24 | from ...hparams import ModelArguments 25 | 26 | 27 | logger = logging.get_logger(__name__) 28 | 29 | 30 | def apply_liger_kernel( 31 | config: "PretrainedConfig", 32 | model_args: "ModelArguments", 33 | is_trainable: bool, 34 | require_logits: bool, 35 | ) -> None: 36 | if not is_trainable or not model_args.enable_liger_kernel: 37 | return 38 | 39 | model_type = getattr(config, "model_type", None) 40 | if model_type == "gemma": 41 | from liger_kernel.transformers import apply_liger_kernel_to_gemma as apply_liger_kernel 42 | elif model_type == "gemma2": 43 | from liger_kernel.transformers import apply_liger_kernel_to_gemma2 as apply_liger_kernel 44 | elif model_type == "llama": 45 | from liger_kernel.transformers import apply_liger_kernel_to_llama as apply_liger_kernel 46 | elif model_type == "mistral": 47 | from liger_kernel.transformers import apply_liger_kernel_to_mistral as apply_liger_kernel 48 | elif model_type == "mixtral": 49 | from liger_kernel.transformers import apply_liger_kernel_to_mixtral as apply_liger_kernel 50 | elif model_type == "phi3": 51 | from liger_kernel.transformers import apply_liger_kernel_to_phi3 as apply_liger_kernel 52 | elif model_type == "qwen2": 53 | from liger_kernel.transformers import apply_liger_kernel_to_qwen2 as apply_liger_kernel 54 | elif model_type == "qwen2_vl": 55 | from liger_kernel.transformers import apply_liger_kernel_to_qwen2_vl as apply_liger_kernel 56 | else: 57 | logger.warning_rank0("Current model does not support liger kernel.") 58 | return 59 | 60 | if require_logits and "fused_linear_cross_entropy" in inspect.signature(apply_liger_kernel).parameters: 61 | logger.info_rank0("Current training stage does not support chunked cross entropy.") 62 | kwargs = {"fused_linear_cross_entropy": False} 63 | else: 64 | kwargs = {} 65 | 66 | apply_liger_kernel(**kwargs) 67 | logger.info_rank0("Liger kernel has been applied to the model.") 68 | -------------------------------------------------------------------------------- /train/stage_sft/llamafactory/model/model_utils/misc.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 the LlamaFactory team. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from typing import TYPE_CHECKING, List 16 | 17 | from ...extras import logging 18 | from .visual import COMPOSITE_MODELS 19 | 20 | 21 | if TYPE_CHECKING: 22 | from transformers import PretrainedConfig, PreTrainedModel, PreTrainedTokenizer 23 | 24 | 25 | logger = logging.get_logger(__name__) 26 | 27 | 28 | def find_all_linear_modules(model: "PreTrainedModel", freeze_vision_tower: bool) -> List[str]: 29 | r""" 30 | Finds all available modules to apply LoRA, GaLore or APOLLO. 
31 | """ 32 | model_type = getattr(model.config, "model_type", None) 33 | forbidden_modules = {"lm_head"} 34 | if model_type == "chatglm": 35 | forbidden_modules.add("output_layer") 36 | elif model_type == "internlm2": 37 | forbidden_modules.add("output") 38 | 39 | if model_type in COMPOSITE_MODELS: 40 | forbidden_modules.add(COMPOSITE_MODELS[model_type].projector_key) 41 | 42 | if freeze_vision_tower and model_type in COMPOSITE_MODELS: 43 | forbidden_modules.update(COMPOSITE_MODELS[model_type].vision_model_keys) 44 | 45 | module_names = set() 46 | for name, module in model.named_modules(): 47 | if any(forbidden_module in name for forbidden_module in forbidden_modules): 48 | continue 49 | 50 | if "Linear" in module.__class__.__name__ and "Embedding" not in module.__class__.__name__: 51 | module_names.add(name.split(".")[-1]) 52 | 53 | logger.info_rank0("Found linear modules: {}".format(",".join(module_names))) 54 | return list(module_names) 55 | 56 | 57 | def find_expanded_modules(model: "PreTrainedModel", target_modules: List[str], num_layer_trainable: int) -> List[str]: 58 | r""" 59 | Finds the modules in the expanded blocks to apply lora. 60 | """ 61 | num_layers = getattr(model.config, "num_hidden_layers", None) 62 | if not num_layers: 63 | raise ValueError("Model was not supported.") 64 | 65 | if num_layers % num_layer_trainable != 0: 66 | raise ValueError( 67 | f"`num_layers` {num_layers} should be divisible by `num_layer_trainable` {num_layer_trainable}." 68 | ) 69 | 70 | stride = num_layers // num_layer_trainable 71 | trainable_layer_ids = range(stride - 1, num_layers + stride - 1, stride) 72 | trainable_layers = [f".{idx:d}." for idx in trainable_layer_ids] 73 | module_names = [] 74 | for name, _ in model.named_modules(): 75 | if any(target_module in name for target_module in target_modules) and any( 76 | trainable_layer in name for trainable_layer in trainable_layers 77 | ): 78 | module_names.append(name) 79 | 80 | logger.info_rank0("Apply lora to layers: {}".format(",".join(map(str, trainable_layer_ids)))) 81 | return module_names 82 | 83 | 84 | def register_autoclass(config: "PretrainedConfig", model: "PreTrainedModel", tokenizer: "PreTrainedTokenizer"): 85 | if "AutoConfig" in getattr(config, "auto_map", {}): 86 | config.__class__.register_for_auto_class() 87 | if "AutoModelForCausalLM" in getattr(config, "auto_map", {}): 88 | model.__class__.register_for_auto_class() 89 | if "AutoTokenizer" in tokenizer.init_kwargs.get("auto_map", {}): 90 | tokenizer.__class__.register_for_auto_class() 91 | -------------------------------------------------------------------------------- /train/stage_sft/llamafactory/model/model_utils/mod.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 the LlamaFactory team. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | from typing import TYPE_CHECKING 16 | 17 | from ...extras.constants import MOD_SUPPORTED_MODELS 18 | 19 | 20 | if TYPE_CHECKING: 21 | from transformers import PretrainedConfig, PreTrainedModel 22 | 23 | from ...hparams import ModelArguments 24 | 25 | 26 | def load_mod_pretrained_model(**init_kwargs) -> "PreTrainedModel": 27 | from MoD import AutoMoDModelForCausalLM 28 | 29 | return AutoMoDModelForCausalLM.from_pretrained(**init_kwargs) 30 | 31 | 32 | def convert_pretrained_model_to_mod( 33 | model: "PreTrainedModel", config: "PretrainedConfig", model_args: "ModelArguments" 34 | ) -> "PreTrainedModel": 35 | from MoD import apply_mod_to_hf 36 | 37 | if getattr(config, "model_type", None) not in MOD_SUPPORTED_MODELS: 38 | raise ValueError("Current model is not supported by mixture-of-depth.") 39 | 40 | model = apply_mod_to_hf(model) 41 | model = model.to(model_args.compute_dtype) 42 | return model 43 | -------------------------------------------------------------------------------- /train/stage_sft/llamafactory/model/model_utils/moe.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 the LlamaFactory team. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from typing import TYPE_CHECKING, Sequence 16 | 17 | import torch 18 | from transformers.integrations import is_deepspeed_zero3_enabled 19 | 20 | from ...extras.misc import check_version 21 | 22 | 23 | if TYPE_CHECKING: 24 | from transformers import PretrainedConfig, PreTrainedModel 25 | 26 | from ...hparams import ModelArguments 27 | 28 | 29 | def _set_z3_leaf_modules(model: "PreTrainedModel", leaf_modules: Sequence["torch.nn.Module"]) -> None: 30 | check_version("deepspeed>=0.13.0") 31 | from deepspeed.utils import set_z3_leaf_modules # type: ignore 32 | 33 | set_z3_leaf_modules(model, leaf_modules) 34 | 35 | 36 | def add_z3_leaf_module(model: "PreTrainedModel") -> None: 37 | r""" 38 | Sets module as a leaf module to skip partitioning in deepspeed zero3. 
39 | """ 40 | if not is_deepspeed_zero3_enabled(): 41 | return 42 | 43 | model_type = getattr(model.config, "model_type", None) 44 | if model_type == "dbrx": 45 | from transformers.models.dbrx.modeling_dbrx import DbrxFFN 46 | 47 | _set_z3_leaf_modules(model, [DbrxFFN]) 48 | 49 | if model_type == "jamba": 50 | from transformers.models.jamba.modeling_jamba import JambaSparseMoeBlock 51 | 52 | _set_z3_leaf_modules(model, [JambaSparseMoeBlock]) 53 | 54 | if model_type == "jetmoe": 55 | from transformers.models.jetmoe.modeling_jetmoe import JetMoeMoA, JetMoeMoE 56 | 57 | _set_z3_leaf_modules(model, [JetMoeMoA, JetMoeMoE]) 58 | 59 | if model_type == "mixtral": 60 | from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock 61 | 62 | _set_z3_leaf_modules(model, [MixtralSparseMoeBlock]) 63 | 64 | if model_type == "qwen2_moe": 65 | from transformers.models.qwen2_moe.modeling_qwen2_moe import Qwen2MoeSparseMoeBlock 66 | 67 | _set_z3_leaf_modules(model, [Qwen2MoeSparseMoeBlock]) 68 | 69 | 70 | def configure_moe(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None: 71 | model_type = getattr(config, "model_type", None) 72 | if model_args.moe_aux_loss_coef is not None: 73 | if model_type in ["jamba", "mixtral", "qwen2_moe"]: 74 | setattr(config, "router_aux_loss_coef", model_args.moe_aux_loss_coef) 75 | 76 | elif model_type == "deepseek": 77 | setattr(config, "aux_loss_alpha", model_args.moe_aux_loss_coef) 78 | 79 | elif model_type == "jetmoe": 80 | setattr(config, "aux_loss_coef", model_args.moe_aux_loss_coef) 81 | 82 | if model_type in ["dbrx", "jamba", "jetmoe", "mixtral", "qwen2_moe"]: 83 | setattr(config, "output_router_logits", is_trainable) 84 | -------------------------------------------------------------------------------- /train/stage_sft/llamafactory/model/model_utils/rope.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 LMSYS and the LlamaFactory team. 2 | # Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li 3 | # 4 | # This code is inspired by the LMSYS's FastChat library. 5 | # https://github.com/lm-sys/FastChat/blob/v0.2.30/fastchat/train/train.py 6 | # 7 | # Licensed under the Apache License, Version 2.0 (the "License"); 8 | # you may not use this file except in compliance with the License. 9 | # You may obtain a copy of the License at 10 | # 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | # 13 | # Unless required by applicable law or agreed to in writing, software 14 | # distributed under the License is distributed on an "AS IS" BASIS, 15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 16 | # See the License for the specific language governing permissions and 17 | # limitations under the License. 
18 | 19 | import math 20 | from typing import TYPE_CHECKING 21 | 22 | from ...extras import logging 23 | 24 | 25 | if TYPE_CHECKING: 26 | from transformers import PretrainedConfig 27 | 28 | from ...hparams import ModelArguments 29 | 30 | 31 | logger = logging.get_logger(__name__) 32 | 33 | 34 | def configure_rope(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None: 35 | if model_args.rope_scaling is None: 36 | return 37 | 38 | if not hasattr(config, "rope_scaling"): 39 | logger.warning_rank0("Current model does not support RoPE scaling.") 40 | return 41 | 42 | rope_kwargs = {} 43 | if model_args.model_max_length is not None: 44 | if is_trainable and model_args.rope_scaling == "dynamic": 45 | logger.warning_rank0( 46 | "Dynamic NTK scaling may not work well with fine-tuning. " 47 | "See: https://github.com/huggingface/transformers/pull/24653" 48 | ) 49 | 50 | current_max_length = getattr(config, "max_position_embeddings", None) 51 | if current_max_length and model_args.model_max_length > current_max_length: 52 | logger.info_rank0(f"Enlarge max model length from {current_max_length} to {model_args.model_max_length}.") 53 | setattr(config, "max_position_embeddings", model_args.model_max_length) 54 | rope_kwargs["factor"] = float(math.ceil(model_args.model_max_length / current_max_length)) 55 | else: 56 | logger.warning_rank0("Input length is smaller than max length. Consider increasing the input length.") 57 | rope_kwargs["factor"] = 1.0 58 | 59 | if model_args.rope_scaling == "dynamic": 60 | rope_kwargs["original_max_position_embeddings"] = current_max_length 61 | elif model_args.rope_scaling == "llama3": 62 | rope_kwargs["original_max_position_embeddings"] = current_max_length 63 | rope_kwargs["low_freq_factor"] = 1.0 64 | rope_kwargs["high_freq_factor"] = 4.0 65 | else: 66 | rope_kwargs["factor"] = 2.0 67 | 68 | setattr(config, "rope_scaling", {"rope_type": model_args.rope_scaling, **rope_kwargs}) 69 | logger.info_rank0( 70 | f"Using {model_args.rope_scaling} scaling strategy and setting scaling factor to {rope_kwargs['factor']}." 71 | ) 72 | -------------------------------------------------------------------------------- /train/stage_sft/llamafactory/model/model_utils/unsloth.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 the LlamaFactory team. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License.
14 | 15 | from typing import TYPE_CHECKING, Any, Dict, Optional 16 | 17 | from ...extras import logging 18 | from ...extras.misc import get_current_device 19 | 20 | 21 | if TYPE_CHECKING: 22 | from transformers import PretrainedConfig, PreTrainedModel 23 | 24 | from ...hparams import ModelArguments 25 | 26 | 27 | logger = logging.get_logger(__name__) 28 | 29 | 30 | def _get_unsloth_kwargs( 31 | config: "PretrainedConfig", model_name_or_path: str, model_args: "ModelArguments" 32 | ) -> Dict[str, Any]: 33 | return { 34 | "model_name": model_name_or_path, 35 | "max_seq_length": model_args.model_max_length or 4096, 36 | "dtype": model_args.compute_dtype, 37 | "load_in_4bit": model_args.quantization_bit == 4, 38 | "token": model_args.hf_hub_token, 39 | "device_map": {"": get_current_device()}, 40 | "rope_scaling": getattr(config, "rope_scaling", None), 41 | "fix_tokenizer": False, 42 | "trust_remote_code": model_args.trust_remote_code, 43 | "use_gradient_checkpointing": "unsloth", 44 | } 45 | 46 | 47 | def load_unsloth_pretrained_model( 48 | config: "PretrainedConfig", model_args: "ModelArguments" 49 | ) -> Optional["PreTrainedModel"]: 50 | r""" 51 | Optionally loads pretrained model with unsloth. Used in training. 52 | """ 53 | from unsloth import FastLanguageModel 54 | 55 | unsloth_kwargs = _get_unsloth_kwargs(config, model_args.model_name_or_path, model_args) 56 | try: 57 | model, _ = FastLanguageModel.from_pretrained(**unsloth_kwargs) 58 | except NotImplementedError: 59 | logger.warning_rank0("Unsloth does not support model type {}.".format(getattr(config, "model_type", None))) 60 | model = None 61 | model_args.use_unsloth = False 62 | 63 | return model 64 | 65 | 66 | def get_unsloth_peft_model( 67 | model: "PreTrainedModel", model_args: "ModelArguments", peft_kwargs: Dict[str, Any] 68 | ) -> "PreTrainedModel": 69 | r""" 70 | Gets the peft model for the pretrained model with unsloth. Used in training. 71 | """ 72 | from unsloth import FastLanguageModel 73 | 74 | unsloth_peft_kwargs = { 75 | "model": model, 76 | "max_seq_length": model_args.model_max_length, 77 | "use_gradient_checkpointing": "unsloth", 78 | } 79 | return FastLanguageModel.get_peft_model(**peft_kwargs, **unsloth_peft_kwargs) 80 | 81 | 82 | def load_unsloth_peft_model( 83 | config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool 84 | ) -> "PreTrainedModel": 85 | r""" 86 | Loads peft model with unsloth. Used in both training and inference. 87 | """ 88 | from unsloth import FastLanguageModel 89 | 90 | unsloth_kwargs = _get_unsloth_kwargs(config, model_args.adapter_name_or_path[0], model_args) 91 | try: 92 | if not is_trainable: 93 | unsloth_kwargs["use_gradient_checkpointing"] = False 94 | 95 | model, _ = FastLanguageModel.from_pretrained(**unsloth_kwargs) 96 | except NotImplementedError: 97 | raise ValueError("Unsloth does not support model type {}.".format(getattr(config, "model_type", None))) 98 | 99 | if not is_trainable: 100 | FastLanguageModel.for_inference(model) 101 | 102 | return model 103 | -------------------------------------------------------------------------------- /train/stage_sft/llamafactory/model/model_utils/valuehead.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 the LlamaFactory team. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from typing import TYPE_CHECKING, Dict, Optional 16 | 17 | import torch 18 | from transformers.utils import cached_file 19 | 20 | from ...extras import logging 21 | from ...extras.constants import V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME 22 | 23 | 24 | if TYPE_CHECKING: 25 | from transformers import PreTrainedModel 26 | 27 | from ...hparams import ModelArguments 28 | 29 | 30 | logger = logging.get_logger(__name__) 31 | 32 | 33 | def load_valuehead_params(path_or_repo_id: str, model_args: "ModelArguments") -> Optional[Dict[str, torch.Tensor]]: 34 | r""" 35 | Loads value head parameters from Hugging Face Hub or local disk. 36 | 37 | Returns: dict with keys `v_head.summary.weight` and `v_head.summary.bias`, or None if no value head weights are found. 38 | """ 39 | kwargs = {"path_or_repo_id": path_or_repo_id, "cache_dir": model_args.cache_dir, "token": model_args.hf_hub_token} 40 | err_text = "" 41 | 42 | try: 43 | from safetensors import safe_open 44 | 45 | vhead_file = cached_file(filename=V_HEAD_SAFE_WEIGHTS_NAME, **kwargs) 46 | with safe_open(vhead_file, framework="pt", device="cpu") as f: 47 | return {key: f.get_tensor(key) for key in f.keys()} 48 | except Exception as err: 49 | err_text = str(err) 50 | 51 | try: 52 | vhead_file = cached_file(filename=V_HEAD_WEIGHTS_NAME, **kwargs) 53 | return torch.load(vhead_file, map_location="cpu") 54 | except Exception as err: 55 | err_text = str(err) 56 | 57 | logger.info_rank0(f"Provided path ({path_or_repo_id}) does not contain value head weights: {err_text}.") 58 | logger.info_rank0("Ignore the above message if you are not resuming the training of a value head model.") 59 | return None 60 | 61 | 62 | def prepare_valuehead_model(model: "PreTrainedModel") -> None: 63 | if getattr(model.config, "model_type", None) == "llava": 64 | setattr(model, "lm_head", model.language_model.get_output_embeddings()) 65 | setattr(model, "_keys_to_ignore_on_save", ["lm_head.weight"]) 66 | 67 | if getattr(model.config, "model_type", None) == "chatglm": 68 | setattr(model, "lm_head", model.transformer.output_layer) 69 | setattr(model, "_keys_to_ignore_on_save", ["lm_head.weight"]) 70 | 71 | if getattr(model.config, "model_type", None) == "internlm2": 72 | setattr(model, "lm_head", model.output) 73 | setattr(model, "_keys_to_ignore_on_save", ["lm_head.weight"]) 74 | -------------------------------------------------------------------------------- /train/stage_sft/llamafactory/train/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tanhuajie/Reason-RFT/3180557dce2065e92754352dacf83eccb26ac032/train/stage_sft/llamafactory/train/__init__.py -------------------------------------------------------------------------------- /train/stage_sft/llamafactory/train/dpo/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 the LlamaFactory team. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from .workflow import run_dpo 16 | 17 | 18 | __all__ = ["run_dpo"] 19 | -------------------------------------------------------------------------------- /train/stage_sft/llamafactory/train/kto/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 the LlamaFactory team. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from .workflow import run_kto 16 | 17 | 18 | __all__ = ["run_kto"] 19 | -------------------------------------------------------------------------------- /train/stage_sft/llamafactory/train/ppo/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 the LlamaFactory team. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from .workflow import run_ppo 16 | 17 | 18 | __all__ = ["run_ppo"] 19 | -------------------------------------------------------------------------------- /train/stage_sft/llamafactory/train/ppo/ppo_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 the LlamaFactory team. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | import json 16 | from contextlib import nullcontext 17 | from typing import TYPE_CHECKING, Dict, List, Literal, Optional 18 | 19 | import torch 20 | from transformers.integrations import is_deepspeed_zero3_enabled 21 | 22 | from ...extras.packages import is_requests_available 23 | 24 | 25 | if is_requests_available(): 26 | import requests 27 | 28 | 29 | if TYPE_CHECKING: 30 | from transformers import PreTrainedModel 31 | from trl import AutoModelForCausalLMWithValueHead 32 | 33 | 34 | def get_rewards_from_server(server_url: str, messages: List[str]) -> List["torch.Tensor"]: 35 | r""" 36 | Gets reward scores from the API server. 37 | """ 38 | headers = {"Content-Type": "application/json"} 39 | payload = {"model": "model", "messages": messages} 40 | response = requests.post(server_url, json=payload, headers=headers) 41 | rewards = json.loads(response.text)["scores"] 42 | return torch.Tensor(rewards) 43 | 44 | 45 | def replace_model(model: "AutoModelForCausalLMWithValueHead", target: Literal["default", "reward"]) -> None: 46 | r""" 47 | Replaces the default/reward modules in the model. The model is already unwrapped. 48 | """ 49 | v_head_layer = model.v_head.summary 50 | if is_deepspeed_zero3_enabled(): 51 | import deepspeed # type: ignore 52 | 53 | params = [v_head_layer.weight, v_head_layer.bias] 54 | context_maybe_zero3 = deepspeed.zero.GatheredParameters(params, modifier_rank=0) 55 | else: 56 | context_maybe_zero3 = nullcontext() 57 | 58 | model.pretrained_model.set_adapter(target) # set the LoRA adapter to be active 59 | with context_maybe_zero3: 60 | if target == "reward": # save default head temporarily 61 | setattr(model, "default_head_weight", v_head_layer.weight.data.detach().clone()) 62 | setattr(model, "default_head_bias", v_head_layer.bias.data.detach().clone()) 63 | 64 | device = v_head_layer.weight.device 65 | v_head_layer.weight.data = model.get_buffer(f"{target}_head_weight").detach().clone().to(device) 66 | v_head_layer.bias.data = model.get_buffer(f"{target}_head_bias").detach().clone().to(device) 67 | 68 | 69 | def dump_layernorm(model: "PreTrainedModel") -> Dict[str, "torch.Tensor"]: 70 | r""" 71 | Dumps the layernorm parameters in the model. The model is already unwrapped (and gathered). 72 | """ 73 | layer_norm_params = {} 74 | for name, param in model.named_parameters(): 75 | if param.data.dtype == torch.float32: 76 | layer_norm_params[name] = param.data.detach().clone() 77 | param.data = param.data.to(model.config.torch_dtype) 78 | 79 | return layer_norm_params 80 | 81 | 82 | def restore_layernorm(model: "PreTrainedModel", layernorm_params: Optional[Dict[str, "torch.Tensor"]] = None) -> None: 83 | r""" 84 | Restores the layernorm parameters in the model. The model is already unwrapped (and gathered). 85 | """ 86 | for name, param in model.named_parameters(): 87 | if layernorm_params is not None and name in layernorm_params: 88 | param.data = layernorm_params[name] 89 | -------------------------------------------------------------------------------- /train/stage_sft/llamafactory/train/ppo/workflow.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 HuggingFace Inc. and the LlamaFactory team. 2 | # 3 | # This code is inspired by the HuggingFace's TRL library. 4 | # https://github.com/huggingface/trl/blob/v0.8.0/examples/scripts/ppo.py 5 | # 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License.
8 | # You may obtain a copy of the License at 9 | # 10 | # http://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # Unless required by applicable law or agreed to in writing, software 13 | # distributed under the License is distributed on an "AS IS" BASIS, 14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | # See the License for the specific language governing permissions and 16 | # limitations under the License. 17 | 18 | from typing import TYPE_CHECKING, List, Optional 19 | 20 | from ...data import MultiModalDataCollatorForSeq2Seq, get_dataset, get_template_and_fix_tokenizer 21 | from ...extras.ploting import plot_loss 22 | from ...model import load_model, load_tokenizer 23 | from ..callbacks import fix_valuehead_checkpoint 24 | from ..trainer_utils import create_ref_model, create_reward_model 25 | from .trainer import CustomPPOTrainer 26 | 27 | 28 | if TYPE_CHECKING: 29 | from transformers import Seq2SeqTrainingArguments, TrainerCallback 30 | 31 | from ...hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments 32 | 33 | 34 | def run_ppo( 35 | model_args: "ModelArguments", 36 | data_args: "DataArguments", 37 | training_args: "Seq2SeqTrainingArguments", 38 | finetuning_args: "FinetuningArguments", 39 | generating_args: "GeneratingArguments", 40 | callbacks: Optional[List["TrainerCallback"]] = None, 41 | ): 42 | tokenizer_module = load_tokenizer(model_args) 43 | tokenizer = tokenizer_module["tokenizer"] 44 | template = get_template_and_fix_tokenizer(tokenizer, data_args) 45 | dataset_module = get_dataset(template, model_args, data_args, training_args, stage="ppo", **tokenizer_module) 46 | model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train, add_valuehead=True) 47 | 48 | tokenizer.padding_side = "left" # use left-padding in generation while using right-padding in training 49 | data_collator = MultiModalDataCollatorForSeq2Seq(template=template, model=model, **tokenizer_module) 50 | 51 | # Create reference model and reward model 52 | ref_model = create_ref_model(model_args, finetuning_args, add_valuehead=True) 53 | reward_model = create_reward_model(model, model_args, finetuning_args) 54 | 55 | # Initialize our Trainer 56 | ppo_trainer: "CustomPPOTrainer" = CustomPPOTrainer( 57 | model_args=model_args, 58 | training_args=training_args, 59 | finetuning_args=finetuning_args, 60 | generating_args=generating_args, 61 | callbacks=callbacks, 62 | model=model, 63 | reward_model=reward_model, 64 | ref_model=ref_model, 65 | data_collator=data_collator, 66 | **dataset_module, 67 | **tokenizer_module, 68 | ) 69 | 70 | # Training 71 | if training_args.do_train: 72 | ppo_trainer.ppo_train(resume_from_checkpoint=training_args.resume_from_checkpoint) 73 | ppo_trainer.save_model() 74 | if training_args.should_save: 75 | fix_valuehead_checkpoint(model, training_args.output_dir, training_args.save_safetensors) 76 | 77 | ppo_trainer.save_state() # must be called after save_model to have a folder 78 | if ppo_trainer.is_world_process_zero() and finetuning_args.plot_loss: 79 | plot_loss(training_args.output_dir, keys=["loss", "reward"]) 80 | -------------------------------------------------------------------------------- /train/stage_sft/llamafactory/train/pt/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 the LlamaFactory team. 
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from .workflow import run_pt 16 | 17 | 18 | __all__ = ["run_pt"] 19 | -------------------------------------------------------------------------------- /train/stage_sft/llamafactory/train/pt/trainer.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 the LlamaFactory team. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from types import MethodType 16 | from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union 17 | 18 | import torch 19 | from transformers import Trainer 20 | from typing_extensions import override 21 | 22 | from ...extras.packages import is_transformers_version_greater_than 23 | from ..callbacks import SaveProcessorCallback 24 | from ..trainer_utils import create_custom_optimizer, create_custom_scheduler 25 | 26 | 27 | if TYPE_CHECKING: 28 | from transformers import PreTrainedModel, ProcessorMixin 29 | 30 | from ...hparams import FinetuningArguments 31 | 32 | 33 | class CustomTrainer(Trainer): 34 | r""" 35 | Inherits Trainer for custom optimizer. 
36 | """ 37 | 38 | def __init__( 39 | self, finetuning_args: "FinetuningArguments", processor: Optional["ProcessorMixin"], **kwargs 40 | ) -> None: 41 | if is_transformers_version_greater_than("4.46"): 42 | kwargs["processing_class"] = kwargs.pop("tokenizer") 43 | 44 | super().__init__(**kwargs) 45 | self.finetuning_args = finetuning_args 46 | 47 | if processor is not None: 48 | self.add_callback(SaveProcessorCallback(processor)) 49 | 50 | if finetuning_args.use_badam: 51 | from badam import BAdamCallback, clip_grad_norm_old_version # type: ignore 52 | 53 | self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator) 54 | self.add_callback(BAdamCallback) 55 | 56 | @override 57 | def create_optimizer(self) -> "torch.optim.Optimizer": 58 | if self.optimizer is None: 59 | self.optimizer = create_custom_optimizer(self.model, self.args, self.finetuning_args) 60 | return super().create_optimizer() 61 | 62 | @override 63 | def create_scheduler( 64 | self, num_training_steps: int, optimizer: Optional["torch.optim.Optimizer"] = None 65 | ) -> "torch.optim.lr_scheduler.LRScheduler": 66 | create_custom_scheduler(self.args, num_training_steps, optimizer) 67 | return super().create_scheduler(num_training_steps, optimizer) 68 | 69 | @override 70 | def _get_train_sampler(self) -> Optional["torch.utils.data.Sampler"]: 71 | if self.finetuning_args.disable_shuffling: 72 | return torch.utils.data.SequentialSampler(self.train_dataset) 73 | 74 | return super()._get_train_sampler() 75 | 76 | @override 77 | def compute_loss( 78 | self, model: "PreTrainedModel", inputs: Dict[str, "torch.Tensor"], return_outputs: bool = False, **kwargs 79 | ) -> Union["torch.Tensor", Tuple["torch.Tensor", List["torch.Tensor"]]]: 80 | r""" 81 | Fixes the loss value. See https://github.com/huggingface/transformers/pull/35438 for details. 82 | 83 | It should be removed after https://github.com/huggingface/transformers/pull/35651 is merged. 84 | """ 85 | loss = super().compute_loss(model, inputs, return_outputs, **kwargs) 86 | if kwargs.get("num_items_in_batch") and not getattr(self, "model_accepts_loss_kwargs", False): 87 | if return_outputs: 88 | loss = (loss[0] / self.args.gradient_accumulation_steps, *loss[1:]) 89 | else: 90 | loss = loss / self.args.gradient_accumulation_steps 91 | 92 | return loss 93 | -------------------------------------------------------------------------------- /train/stage_sft/llamafactory/train/pt/workflow.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 HuggingFace Inc. and the LlamaFactory team. 2 | # 3 | # This code is inspired by the HuggingFace's transformers library. 4 | # https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/language-modeling/run_clm.py 5 | # 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at 9 | # 10 | # http://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # Unless required by applicable law or agreed to in writing, software 13 | # distributed under the License is distributed on an "AS IS" BASIS, 14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | # See the License for the specific language governing permissions and 16 | # limitations under the License. 
17 | 18 | import math 19 | from typing import TYPE_CHECKING, List, Optional 20 | 21 | from transformers import DataCollatorForLanguageModeling 22 | 23 | from ...data import get_dataset, get_template_and_fix_tokenizer 24 | from ...extras.ploting import plot_loss 25 | from ...model import load_model, load_tokenizer 26 | from ..trainer_utils import create_modelcard_and_push 27 | from .trainer import CustomTrainer 28 | 29 | 30 | if TYPE_CHECKING: 31 | from transformers import Seq2SeqTrainingArguments, TrainerCallback 32 | 33 | from ...hparams import DataArguments, FinetuningArguments, ModelArguments 34 | 35 | 36 | def run_pt( 37 | model_args: "ModelArguments", 38 | data_args: "DataArguments", 39 | training_args: "Seq2SeqTrainingArguments", 40 | finetuning_args: "FinetuningArguments", 41 | callbacks: Optional[List["TrainerCallback"]] = None, 42 | ): 43 | tokenizer_module = load_tokenizer(model_args) 44 | tokenizer = tokenizer_module["tokenizer"] 45 | template = get_template_and_fix_tokenizer(tokenizer, data_args) 46 | dataset_module = get_dataset(template, model_args, data_args, training_args, stage="pt", **tokenizer_module) 47 | model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train) 48 | data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False) 49 | 50 | # Initialize our Trainer 51 | trainer = CustomTrainer( 52 | model=model, 53 | args=training_args, 54 | finetuning_args=finetuning_args, 55 | data_collator=data_collator, 56 | callbacks=callbacks, 57 | **dataset_module, 58 | **tokenizer_module, 59 | ) 60 | 61 | # Training 62 | if training_args.do_train: 63 | train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint) 64 | trainer.save_model() 65 | trainer.log_metrics("train", train_result.metrics) 66 | trainer.save_metrics("train", train_result.metrics) 67 | trainer.save_state() 68 | if trainer.is_world_process_zero() and finetuning_args.plot_loss: 69 | plot_loss(training_args.output_dir, keys=["loss", "eval_loss"]) 70 | 71 | # Evaluation 72 | if training_args.do_eval: 73 | metrics = trainer.evaluate(metric_key_prefix="eval") 74 | try: 75 | perplexity = math.exp(metrics["eval_loss"]) 76 | except OverflowError: 77 | perplexity = float("inf") 78 | 79 | metrics["perplexity"] = perplexity 80 | trainer.log_metrics("eval", metrics) 81 | trainer.save_metrics("eval", metrics) 82 | 83 | # Create model card 84 | create_modelcard_and_push(trainer, model_args, data_args, training_args, finetuning_args) 85 | -------------------------------------------------------------------------------- /train/stage_sft/llamafactory/train/rm/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 the LlamaFactory team. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | from .workflow import run_rm 16 | 17 | 18 | __all__ = ["run_rm"] 19 | -------------------------------------------------------------------------------- /train/stage_sft/llamafactory/train/rm/metric.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 the LlamaFactory team. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from dataclasses import dataclass 16 | from typing import TYPE_CHECKING, Dict, Optional 17 | 18 | import numpy as np 19 | 20 | from ...extras.misc import numpify 21 | 22 | 23 | if TYPE_CHECKING: 24 | from transformers import EvalPrediction 25 | 26 | 27 | @dataclass 28 | class ComputeAccuracy: 29 | r""" 30 | Computes reward accuracy and supports `batch_eval_metrics`. 31 | """ 32 | 33 | def _dump(self) -> Optional[Dict[str, float]]: 34 | result = None 35 | if hasattr(self, "score_dict"): 36 | result = {k: float(np.mean(v)) for k, v in self.score_dict.items()} 37 | 38 | self.score_dict = {"accuracy": []} 39 | return result 40 | 41 | def __post_init__(self): 42 | self._dump() 43 | 44 | def __call__(self, eval_preds: "EvalPrediction", compute_result: bool = True) -> Optional[Dict[str, float]]: 45 | chosen_scores, rejected_scores = numpify(eval_preds.predictions[0]), numpify(eval_preds.predictions[1]) 46 | if not chosen_scores.shape: 47 | self.score_dict["accuracy"].append(chosen_scores > rejected_scores) 48 | else: 49 | for i in range(len(chosen_scores)): 50 | self.score_dict["accuracy"].append(chosen_scores[i] > rejected_scores[i]) 51 | 52 | if compute_result: 53 | return self._dump() 54 | -------------------------------------------------------------------------------- /train/stage_sft/llamafactory/train/sft/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 the LlamaFactory team. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | from .workflow import run_sft 16 | 17 | 18 | __all__ = ["run_sft"] 19 | -------------------------------------------------------------------------------- /train/stage_sft/llamafactory/webui/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tanhuajie/Reason-RFT/3180557dce2065e92754352dacf83eccb26ac032/train/stage_sft/llamafactory/webui/__init__.py -------------------------------------------------------------------------------- /train/stage_sft/llamafactory/webui/components/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 the LlamaFactory team. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from .chatbot import create_chat_box 16 | from .eval import create_eval_tab 17 | from .export import create_export_tab 18 | from .infer import create_infer_tab 19 | from .top import create_top 20 | from .train import create_train_tab 21 | 22 | 23 | __all__ = [ 24 | "create_chat_box", 25 | "create_eval_tab", 26 | "create_export_tab", 27 | "create_infer_tab", 28 | "create_top", 29 | "create_train_tab", 30 | ] 31 | -------------------------------------------------------------------------------- /train/stage_sft/llamafactory/webui/components/chatbot.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 the LlamaFactory team. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | from typing import TYPE_CHECKING, Dict, Tuple 16 | 17 | from ...data import Role 18 | from ...extras.packages import is_gradio_available 19 | from ..utils import check_json_schema 20 | 21 | 22 | if is_gradio_available(): 23 | import gradio as gr 24 | 25 | 26 | if TYPE_CHECKING: 27 | from gradio.components import Component 28 | 29 | from ..engine import Engine 30 | 31 | 32 | def create_chat_box( 33 | engine: "Engine", visible: bool = False 34 | ) -> Tuple["Component", "Component", Dict[str, "Component"]]: 35 | with gr.Column(visible=visible) as chat_box: 36 | chatbot = gr.Chatbot(type="messages", show_copy_button=True) 37 | messages = gr.State([]) 38 | with gr.Row(): 39 | with gr.Column(scale=4): 40 | with gr.Row(): 41 | with gr.Column(): 42 | role = gr.Dropdown(choices=[Role.USER.value, Role.OBSERVATION.value], value=Role.USER.value) 43 | system = gr.Textbox(show_label=False) 44 | tools = gr.Textbox(show_label=False, lines=3) 45 | 46 | with gr.Column() as mm_box: 47 | with gr.Tab("Image"): 48 | image = gr.Image(sources=["upload"], type="pil") 49 | 50 | with gr.Tab("Video"): 51 | video = gr.Video(sources=["upload"]) 52 | 53 | query = gr.Textbox(show_label=False, lines=8) 54 | submit_btn = gr.Button(variant="primary") 55 | 56 | with gr.Column(scale=1): 57 | max_new_tokens = gr.Slider(minimum=8, maximum=4096, value=512, step=1) 58 | top_p = gr.Slider(minimum=0.01, maximum=1.0, value=0.7, step=0.01) 59 | temperature = gr.Slider(minimum=0.01, maximum=1.5, value=0.95, step=0.01) 60 | clear_btn = gr.Button() 61 | 62 | tools.input(check_json_schema, inputs=[tools, engine.manager.get_elem_by_id("top.lang")]) 63 | 64 | submit_btn.click( 65 | engine.chatter.append, 66 | [chatbot, messages, role, query], 67 | [chatbot, messages, query], 68 | ).then( 69 | engine.chatter.stream, 70 | [chatbot, messages, system, tools, image, video, max_new_tokens, top_p, temperature], 71 | [chatbot, messages], 72 | ) 73 | clear_btn.click(lambda: ([], []), outputs=[chatbot, messages]) 74 | 75 | return ( 76 | chatbot, 77 | messages, 78 | dict( 79 | chat_box=chat_box, 80 | role=role, 81 | system=system, 82 | tools=tools, 83 | mm_box=mm_box, 84 | image=image, 85 | video=video, 86 | query=query, 87 | submit_btn=submit_btn, 88 | max_new_tokens=max_new_tokens, 89 | top_p=top_p, 90 | temperature=temperature, 91 | clear_btn=clear_btn, 92 | ), 93 | ) 94 | -------------------------------------------------------------------------------- /train/stage_sft/llamafactory/webui/components/eval.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 the LlamaFactory team. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | from typing import TYPE_CHECKING, Dict 16 | 17 | from ...extras.packages import is_gradio_available 18 | from ..common import DEFAULT_DATA_DIR, list_datasets 19 | from .data import create_preview_box 20 | 21 | 22 | if is_gradio_available(): 23 | import gradio as gr 24 | 25 | 26 | if TYPE_CHECKING: 27 | from gradio.components import Component 28 | 29 | from ..engine import Engine 30 | 31 | 32 | def create_eval_tab(engine: "Engine") -> Dict[str, "Component"]: 33 | input_elems = engine.manager.get_base_elems() 34 | elem_dict = dict() 35 | 36 | with gr.Row(): 37 | dataset_dir = gr.Textbox(value=DEFAULT_DATA_DIR, scale=2) 38 | dataset = gr.Dropdown(multiselect=True, allow_custom_value=True, scale=4) 39 | preview_elems = create_preview_box(dataset_dir, dataset) 40 | 41 | input_elems.update({dataset_dir, dataset}) 42 | elem_dict.update(dict(dataset_dir=dataset_dir, dataset=dataset, **preview_elems)) 43 | 44 | with gr.Row(): 45 | cutoff_len = gr.Slider(minimum=4, maximum=131072, value=1024, step=1) 46 | max_samples = gr.Textbox(value="100000") 47 | batch_size = gr.Slider(minimum=1, maximum=1024, value=2, step=1) 48 | predict = gr.Checkbox(value=True) 49 | 50 | input_elems.update({cutoff_len, max_samples, batch_size, predict}) 51 | elem_dict.update(dict(cutoff_len=cutoff_len, max_samples=max_samples, batch_size=batch_size, predict=predict)) 52 | 53 | with gr.Row(): 54 | max_new_tokens = gr.Slider(minimum=8, maximum=4096, value=512, step=1) 55 | top_p = gr.Slider(minimum=0.01, maximum=1, value=0.7, step=0.01) 56 | temperature = gr.Slider(minimum=0.01, maximum=1.5, value=0.95, step=0.01) 57 | output_dir = gr.Textbox() 58 | 59 | input_elems.update({max_new_tokens, top_p, temperature, output_dir}) 60 | elem_dict.update(dict(max_new_tokens=max_new_tokens, top_p=top_p, temperature=temperature, output_dir=output_dir)) 61 | 62 | with gr.Row(): 63 | cmd_preview_btn = gr.Button() 64 | start_btn = gr.Button(variant="primary") 65 | stop_btn = gr.Button(variant="stop") 66 | 67 | with gr.Row(): 68 | resume_btn = gr.Checkbox(visible=False, interactive=False) 69 | progress_bar = gr.Slider(visible=False, interactive=False) 70 | 71 | with gr.Row(): 72 | output_box = gr.Markdown() 73 | 74 | elem_dict.update( 75 | dict( 76 | cmd_preview_btn=cmd_preview_btn, 77 | start_btn=start_btn, 78 | stop_btn=stop_btn, 79 | resume_btn=resume_btn, 80 | progress_bar=progress_bar, 81 | output_box=output_box, 82 | ) 83 | ) 84 | output_elems = [output_box, progress_bar] 85 | 86 | cmd_preview_btn.click(engine.runner.preview_eval, input_elems, output_elems, concurrency_limit=None) 87 | start_btn.click(engine.runner.run_eval, input_elems, output_elems) 88 | stop_btn.click(engine.runner.set_abort) 89 | resume_btn.change(engine.runner.monitor, outputs=output_elems, concurrency_limit=None) 90 | 91 | dataset.focus(list_datasets, [dataset_dir], [dataset], queue=False) 92 | 93 | return elem_dict 94 | -------------------------------------------------------------------------------- /train/stage_sft/llamafactory/webui/components/infer.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 the LlamaFactory team. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from typing import TYPE_CHECKING, Dict 16 | 17 | from ...extras.packages import is_gradio_available 18 | from ..common import get_visual 19 | from .chatbot import create_chat_box 20 | 21 | 22 | if is_gradio_available(): 23 | import gradio as gr 24 | 25 | 26 | if TYPE_CHECKING: 27 | from gradio.components import Component 28 | 29 | from ..engine import Engine 30 | 31 | 32 | def create_infer_tab(engine: "Engine") -> Dict[str, "Component"]: 33 | input_elems = engine.manager.get_base_elems() 34 | elem_dict = dict() 35 | 36 | with gr.Row(): 37 | infer_backend = gr.Dropdown(choices=["huggingface", "vllm"], value="huggingface") 38 | infer_dtype = gr.Dropdown(choices=["auto", "float16", "bfloat16", "float32"], value="auto") 39 | 40 | with gr.Row(): 41 | load_btn = gr.Button() 42 | unload_btn = gr.Button() 43 | 44 | info_box = gr.Textbox(show_label=False, interactive=False) 45 | 46 | input_elems.update({infer_backend, infer_dtype}) 47 | elem_dict.update( 48 | dict( 49 | infer_backend=infer_backend, 50 | infer_dtype=infer_dtype, 51 | load_btn=load_btn, 52 | unload_btn=unload_btn, 53 | info_box=info_box, 54 | ) 55 | ) 56 | 57 | chatbot, messages, chat_elems = create_chat_box(engine, visible=False) 58 | elem_dict.update(chat_elems) 59 | 60 | load_btn.click(engine.chatter.load_model, input_elems, [info_box]).then( 61 | lambda: gr.Column(visible=engine.chatter.loaded), outputs=[chat_elems["chat_box"]] 62 | ) 63 | 64 | unload_btn.click(engine.chatter.unload_model, input_elems, [info_box]).then( 65 | lambda: ([], []), outputs=[chatbot, messages] 66 | ).then(lambda: gr.Column(visible=engine.chatter.loaded), outputs=[chat_elems["chat_box"]]) 67 | 68 | engine.manager.get_elem_by_id("top.model_name").change( 69 | lambda model_name: gr.Column(visible=get_visual(model_name)), 70 | [engine.manager.get_elem_by_id("top.model_name")], 71 | [chat_elems["mm_box"]], 72 | ) 73 | 74 | return elem_dict 75 | -------------------------------------------------------------------------------- /train/stage_sft/llamafactory/webui/components/top.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 the LlamaFactory team. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | from typing import TYPE_CHECKING, Dict 16 | 17 | from ...data import TEMPLATES 18 | from ...extras.constants import METHODS, SUPPORTED_MODELS 19 | from ...extras.packages import is_gradio_available 20 | from ..common import get_model_info, list_checkpoints, save_config 21 | from ..utils import can_quantize, can_quantize_to 22 | 23 | 24 | if is_gradio_available(): 25 | import gradio as gr 26 | 27 | 28 | if TYPE_CHECKING: 29 | from gradio.components import Component 30 | 31 | 32 | def create_top() -> Dict[str, "Component"]: 33 | with gr.Row(): 34 | lang = gr.Dropdown(choices=["en", "ru", "zh", "ko", "ja"], value=None, scale=1) 35 | available_models = list(SUPPORTED_MODELS.keys()) + ["Custom"] 36 | model_name = gr.Dropdown(choices=available_models, value=None, scale=3) 37 | model_path = gr.Textbox(scale=3) 38 | 39 | with gr.Row(): 40 | finetuning_type = gr.Dropdown(choices=METHODS, value="lora", scale=1) 41 | checkpoint_path = gr.Dropdown(multiselect=True, allow_custom_value=True, scale=6) 42 | 43 | with gr.Row(): 44 | quantization_bit = gr.Dropdown(choices=["none", "8", "4"], value="none", allow_custom_value=True) 45 | quantization_method = gr.Dropdown(choices=["bitsandbytes", "hqq", "eetq"], value="bitsandbytes") 46 | template = gr.Dropdown(choices=list(TEMPLATES.keys()), value="default") 47 | rope_scaling = gr.Dropdown(choices=["none", "linear", "dynamic", "yarn", "llama3"], value="none") 48 | booster = gr.Dropdown(choices=["auto", "flashattn2", "unsloth", "liger_kernel"], value="auto") 49 | 50 | model_name.change(get_model_info, [model_name], [model_path, template], queue=False).then( 51 | list_checkpoints, [model_name, finetuning_type], [checkpoint_path], queue=False 52 | ) 53 | model_name.input(save_config, inputs=[lang, model_name], queue=False) 54 | model_path.input(save_config, inputs=[lang, model_name, model_path], queue=False) 55 | finetuning_type.change(can_quantize, [finetuning_type], [quantization_bit], queue=False).then( 56 | list_checkpoints, [model_name, finetuning_type], [checkpoint_path], queue=False 57 | ) 58 | checkpoint_path.focus(list_checkpoints, [model_name, finetuning_type], [checkpoint_path], queue=False) 59 | quantization_method.change(can_quantize_to, [quantization_method], [quantization_bit], queue=False) 60 | 61 | return dict( 62 | lang=lang, 63 | model_name=model_name, 64 | model_path=model_path, 65 | finetuning_type=finetuning_type, 66 | checkpoint_path=checkpoint_path, 67 | quantization_bit=quantization_bit, 68 | quantization_method=quantization_method, 69 | template=template, 70 | rope_scaling=rope_scaling, 71 | booster=booster, 72 | ) 73 | -------------------------------------------------------------------------------- /train/stage_sft/llamafactory/webui/css.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 the LlamaFactory team. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | CSS = r""" 16 | .duplicate-button { 17 | margin: auto !important; 18 | color: white !important; 19 | background: black !important; 20 | border-radius: 100vh !important; 21 | } 22 | 23 | .modal-box { 24 | position: fixed !important; 25 | top: 50%; 26 | left: 50%; 27 | transform: translate(-50%, -50%); /* center horizontally */ 28 | max-width: 1000px; 29 | max-height: 750px; 30 | overflow-y: auto; 31 | background-color: var(--input-background-fill); 32 | flex-wrap: nowrap !important; 33 | border: 2px solid black !important; 34 | z-index: 1000; 35 | padding: 10px; 36 | } 37 | 38 | .dark .modal-box { 39 | border: 2px solid white !important; 40 | } 41 | """ 42 | -------------------------------------------------------------------------------- /train/stage_sft/llamafactory/webui/engine.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 the LlamaFactory team. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from typing import TYPE_CHECKING, Any, Dict 16 | 17 | from .chatter import WebChatModel 18 | from .common import load_config 19 | from .locales import LOCALES 20 | from .manager import Manager 21 | from .runner import Runner 22 | from .utils import create_ds_config, get_time 23 | 24 | 25 | if TYPE_CHECKING: 26 | from gradio.components import Component 27 | 28 | 29 | class Engine: 30 | def __init__(self, demo_mode: bool = False, pure_chat: bool = False) -> None: 31 | self.demo_mode = demo_mode 32 | self.pure_chat = pure_chat 33 | self.manager = Manager() 34 | self.runner = Runner(self.manager, demo_mode) 35 | self.chatter = WebChatModel(self.manager, demo_mode, lazy_init=(not pure_chat)) 36 | if not demo_mode: 37 | create_ds_config() 38 | 39 | def _update_component(self, input_dict: Dict[str, Dict[str, Any]]) -> Dict["Component", "Component"]: 40 | r""" 41 | Gets the dict to update the components. 
42 | """ 43 | output_dict: Dict["Component", "Component"] = {} 44 | for elem_id, elem_attr in input_dict.items(): 45 | elem = self.manager.get_elem_by_id(elem_id) 46 | output_dict[elem] = elem.__class__(**elem_attr) 47 | 48 | return output_dict 49 | 50 | def resume(self): 51 | user_config = load_config() if not self.demo_mode else {} 52 | lang = user_config.get("lang", None) or "en" 53 | 54 | init_dict = {"top.lang": {"value": lang}, "infer.chat_box": {"visible": self.chatter.loaded}} 55 | 56 | if not self.pure_chat: 57 | current_time = get_time() 58 | init_dict["train.current_time"] = {"value": current_time} 59 | init_dict["train.output_dir"] = {"value": f"train_{current_time}"} 60 | init_dict["train.config_path"] = {"value": f"{current_time}.yaml"} 61 | init_dict["eval.output_dir"] = {"value": f"eval_{current_time}"} 62 | init_dict["infer.mm_box"] = {"visible": False} 63 | 64 | if user_config.get("last_model", None): 65 | init_dict["top.model_name"] = {"value": user_config["last_model"]} 66 | 67 | yield self._update_component(init_dict) 68 | 69 | if self.runner.running and not self.demo_mode and not self.pure_chat: 70 | yield {elem: elem.__class__(value=value) for elem, value in self.runner.running_data.items()} 71 | if self.runner.do_train: 72 | yield self._update_component({"train.resume_btn": {"value": True}}) 73 | else: 74 | yield self._update_component({"eval.resume_btn": {"value": True}}) 75 | 76 | def change_lang(self, lang: str): 77 | return { 78 | elem: elem.__class__(**LOCALES[elem_name][lang]) 79 | for elem_name, elem in self.manager.get_elem_iter() 80 | if elem_name in LOCALES 81 | } 82 | -------------------------------------------------------------------------------- /train/stage_sft/llamafactory/webui/manager.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 the LlamaFactory team. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from typing import TYPE_CHECKING, Dict, Generator, List, Set, Tuple 16 | 17 | 18 | if TYPE_CHECKING: 19 | from gradio.components import Component 20 | 21 | 22 | class Manager: 23 | def __init__(self) -> None: 24 | self._id_to_elem: Dict[str, "Component"] = {} 25 | self._elem_to_id: Dict["Component", str] = {} 26 | 27 | def add_elems(self, tab_name: str, elem_dict: Dict[str, "Component"]) -> None: 28 | r""" 29 | Adds elements to manager. 30 | """ 31 | for elem_name, elem in elem_dict.items(): 32 | elem_id = f"{tab_name}.{elem_name}" 33 | self._id_to_elem[elem_id] = elem 34 | self._elem_to_id[elem] = elem_id 35 | 36 | def get_elem_list(self) -> List["Component"]: 37 | r""" 38 | Returns the list of all elements. 39 | """ 40 | return list(self._id_to_elem.values()) 41 | 42 | def get_elem_iter(self) -> Generator[Tuple[str, "Component"], None, None]: 43 | r""" 44 | Returns an iterator over all elements with their names. 
45 | """ 46 | for elem_id, elem in self._id_to_elem.items(): 47 | yield elem_id.split(".")[-1], elem 48 | 49 | def get_elem_by_id(self, elem_id: str) -> "Component": 50 | r""" 51 | Gets element by id. 52 | 53 | Example: top.lang, train.dataset 54 | """ 55 | return self._id_to_elem[elem_id] 56 | 57 | def get_id_by_elem(self, elem: "Component") -> str: 58 | r""" 59 | Gets id by element. 60 | """ 61 | return self._elem_to_id[elem] 62 | 63 | def get_base_elems(self) -> Set["Component"]: 64 | r""" 65 | Gets the base elements that are commonly used. 66 | """ 67 | return { 68 | self._id_to_elem["top.lang"], 69 | self._id_to_elem["top.model_name"], 70 | self._id_to_elem["top.model_path"], 71 | self._id_to_elem["top.finetuning_type"], 72 | self._id_to_elem["top.checkpoint_path"], 73 | self._id_to_elem["top.quantization_bit"], 74 | self._id_to_elem["top.quantization_method"], 75 | self._id_to_elem["top.template"], 76 | self._id_to_elem["top.rope_scaling"], 77 | self._id_to_elem["top.booster"], 78 | } 79 | -------------------------------------------------------------------------------- /train/stage_sft/train.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 the LlamaFactory team. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from llamafactory.train.tuner import run_exp 16 | 17 | 18 | def main(): 19 | run_exp() 20 | 21 | 22 | def _mp_fn(index): 23 | # For xla_spawn (TPUs) 24 | run_exp() 25 | 26 | 27 | if __name__ == "__main__": 28 | main() 29 | -------------------------------------------------------------------------------- /train/stage_sft/webui.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 the LlamaFactory team. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 

import os

from llamafactory.webui.interface import create_ui


def main():
    gradio_ipv6 = os.getenv("GRADIO_IPV6", "0").lower() in ["true", "1"]
    gradio_share = os.getenv("GRADIO_SHARE", "0").lower() in ["true", "1"]
    server_name = os.getenv("GRADIO_SERVER_NAME", "[::]" if gradio_ipv6 else "0.0.0.0")
    create_ui().queue().launch(share=gradio_share, server_name=server_name, inbrowser=True)


if __name__ == "__main__":
    main()
--------------------------------------------------------------------------------
/utils/convert_qwen2vl_format.py:
--------------------------------------------------------------------------------
import json
import argparse

from tqdm import tqdm


def transform_item(item):
    """
    Transform a single item from the conversations format into the
    messages/images format expected for Qwen2-VL fine-tuning.
    """
    # Ensure images are a list
    images = item.get("image", [])
    if isinstance(images, str):
        images = [images]

    # Transform conversations into messages
    messages = [
        {
            "content": conversation["value"],
            "role": "user" if conversation["from"] == "human" else "assistant"
        }
        for conversation in item["conversations"]
    ]

    return {
        "messages": messages,
        "images": images
    }


def process_json(input_file, output_file):
    """
    Read the input JSON file, transform each item, and write the transformed items to the output file.
    """
    with open(input_file, "r", encoding="utf-8") as infile:
        data = json.load(infile)

    transformed_data = []
    for item in tqdm(data, desc="Processing items"):
        transformed_data.append(transform_item(item))

    with open(output_file, "w", encoding="utf-8") as outfile:
        json.dump(transformed_data, outfile, indent=4, ensure_ascii=False)

    print(f"Transformation complete! Output saved to {output_file}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Transform JSON items into the Qwen2-VL format.")
    parser.add_argument("--input", default="/home/vlm/finetune_json/llava_v1_5_mix665k.json", help="Path to the input JSON file.")
    parser.add_argument("--output", default="/home/vlm/finetune_json/llava_v1_5_mix665k_qwen2vl_format.json", help="Path to the output JSON file.")
    args = parser.parse_args()

    process_json(args.input, args.output)
--------------------------------------------------------------------------------
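
For reference, a minimal sketch of the conversion that transform_item performs; the image path and conversation text below are made-up placeholders, not data shipped with the repository:

    # Hypothetical input item in the "conversations" format:
    input_item = {
        "image": "images/example_0001.png",
        "conversations": [
            {"from": "human", "value": "<image>\nWhat is shown in the picture?"},
            {"from": "gpt", "value": "A red cube stacked on a blue cylinder."},
        ],
    }

    # transform_item(input_item) returns the messages/images format:
    # {
    #     "messages": [
    #         {"content": "<image>\nWhat is shown in the picture?", "role": "user"},
    #         {"content": "A red cube stacked on a blue cylinder.", "role": "assistant"},
    #     ],
    #     "images": ["images/example_0001.png"],
    # }

The script is driven by argparse, so a typical invocation (with placeholder paths) is:

    python utils/convert_qwen2vl_format.py --input raw_conversations.json --output qwen2vl_format.json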