├── .gitignore
├── .python-version
├── LICENSE
├── README.md
├── accelerate
│   ├── fsdp_full_config.yaml
│   ├── fsdp_full_config_gemma3.yaml
│   ├── fsdp_gradop_config.yaml
│   ├── fsdp_gradop_config_gemma3.yaml
│   ├── multigpu_dp_config.yaml
│   ├── stage2_config.yaml
│   ├── stage3_config.yaml
│   └── tp_config.yaml
├── deepspeed_configs
│   ├── ds_config_stage_2.json
│   └── ds_config_stage_3.json
├── poetry.lock
├── pyproject.toml
├── scripts
│   ├── inference
│   │   ├── batched_generation.py
│   │   ├── gradio_openai_chatbot.py
│   │   ├── rm_rejection_sampling.py
│   │   └── rm_scoring.py
│   ├── model_training
│   │   ├── classification.py
│   │   ├── cpo.py
│   │   ├── distill.py
│   │   ├── dpo.py
│   │   ├── gpo.py
│   │   ├── orpo.py
│   │   ├── rewards.py
│   │   ├── sft.py
│   │   └── smpo.py
│   ├── post_training
│   │   ├── merge_peft_adapters.py
│   │   └── mergekit_example_config.yaml
│   └── prompts_training
│       ├── reward.py
│       └── sft.py
├── src
│   ├── __init__.py
│   ├── callbacks
│   │   ├── __init__.py
│   │   ├── attr_scheduling.py
│   │   ├── efficiency_callback.py
│   │   ├── generate_examples.py
│   │   └── training_parameters_callback.py
│   ├── collators
│   │   ├── __init__.py
│   │   └── completions_only.py
│   ├── configs
│   │   ├── __init__.py
│   │   ├── additional
│   │   │   ├── __init__.py
│   │   │   ├── classification_args.py
│   │   │   ├── common_script_args.py
│   │   │   ├── cpo_args.py
│   │   │   ├── dpo_args.py
│   │   │   ├── gpo_args.py
│   │   │   ├── orpo_args.py
│   │   │   ├── reward_args.py
│   │   │   ├── sft_args.py
│   │   │   └── smpo_args.py
│   │   ├── classificaion_config.py
│   │   ├── gpo_config.py
│   │   ├── prompts_optimization_comfig.py
│   │   └── smpo_config.py
│   ├── trainers
│   │   ├── __init__.py
│   │   ├── classification_trainer.py
│   │   ├── gpo_trainer.py
│   │   ├── prompts_optimization
│   │   │   ├── __init__.py
│   │   │   ├── prompts_reward_trainer.py
│   │   │   ├── prompts_sft_trainer.py
│   │   │   └── vq_prompts_tuner_module.py
│   │   └── smpo_trainer.py
│   └── utils
│       ├── __init__.py
│       ├── array_utils.py
│       ├── datasets.py
│       ├── embeddings_utils.py
│       ├── logger.py
│       ├── model_preparation.py
│       └── yaml_args_parser.py
└── training_configs
    ├── classification
    │   └── controllable-clf-qwen-1.5b-no-tags-full.yaml
    ├── preference
    │   ├── rpo-sigmoid-llama-3-lora-best-rs.yaml
    │   ├── simpo-llama-3.1-lora-best-rs.yaml
    │   ├── slic-llama-3-lora-best-rs.yaml
    │   ├── smpo-llama-3.1-lora-best-rs.yaml
    │   ├── smpo-mistral-nemo-lora-best-rs.yaml
    │   ├── smpo-phi4-lora-best-rs-v1.yaml
    │   ├── smpo-phi4-lora-best-rs-v12.yaml
    │   ├── smpo-phi4-lora-best-rs-v14.yaml
    │   ├── smpo-phi4-lora-best-rs-v2.yaml
    │   ├── smpo-phi4-lora-best-rs-v3.yaml
    │   ├── smpo-phi4-lora-best-rs-v4.yaml
    │   ├── smpo-phi4-lora-best-rs-v7.yaml
    │   ├── smpo-phi4-lora-best-rs-v9.yaml
    │   ├── smpo-qvikhr2.5-1.5b-lora-best-rs.yaml
    │   └── sppo-llama-3-lora-best-rs.yaml
    ├── prompts-tuning
    │   ├── controllable-rm-qwen-1.5b-no-init.yaml
    │   └── controllable-rm-qwen-1.5b.yaml
    ├── reward
    │   ├── controllable-rm-qwen-1.5b-no-tags-full.yaml
    │   ├── controllable-rm-qwen-1.5b-no-tags-lora.yaml
    │   ├── controllable-rm-qwen-3b-no-tags-full.yaml
    │   ├── controllable-rm-qwen-7b-no-tags-full.yaml
    │   ├── controllable-rm-qwen-7b-no-tags-lora.yaml
    │   ├── rm-llama-3-fsfairx-lora-arena.yaml
    │   ├── rm-llama-3.1-8b-it-lora-arena.yaml
    │   ├── rm-llama-3.1-8b-lora-arena.yaml
    │   └── rm-qwen-14b-v2.yaml
    └── sft
        ├── sft-gemma-2-2b-it-lora-ficbook.yaml
        ├── sft-llama-3.1-8b-it-full-Grandmaster.yaml
        ├── sft-llama-3.1-8b-it-lora-GRAG.yaml
        ├── sft-llama-3.1-8b-it-lora-GrandmasterRAG-v1.yaml
        ├── sft-mistral-nemo-12b-lora-GrandmasterRAG-v1.yaml
        ├── sft-phi4-full-GrandmasterRAG-v2.yaml
        ├── sft-phi4-lora-GrandmasterRAG-v4.yaml
        └── sft-yandex-lora-GrandmasterRAG.yaml
-------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C
extensions 7 | *.so 8 | 9 | # checkpoints 10 | checkpoints/* 11 | 12 | # Distribution / packaging 13 | .Python 14 | build/ 15 | develop-eggs/ 16 | dist/ 17 | downloads/ 18 | eggs/ 19 | .eggs/ 20 | lib/ 21 | lib64/ 22 | parts/ 23 | sdist/ 24 | var/ 25 | wheels/ 26 | share/python-wheels/ 27 | *.egg-info/ 28 | .installed.cfg 29 | *.egg 30 | MANIFEST 31 | 32 | # PyInstaller 33 | # Usually these files are written by a python script from a template 34 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 35 | *.manifest 36 | *.spec 37 | 38 | # Installer logs 39 | pip-log.txt 40 | pip-delete-this-directory.txt 41 | 42 | # Unit test / coverage reports 43 | htmlcov/ 44 | .tox/ 45 | .nox/ 46 | .coverage 47 | .coverage.* 48 | .cache 49 | nosetests.xml 50 | coverage.xml 51 | *.cover 52 | *.py,cover 53 | .hypothesis/ 54 | .pytest_cache/ 55 | cover/ 56 | 57 | data/ 58 | 59 | # Translations 60 | *.mo 61 | *.pot 62 | 63 | # Django stuff: 64 | *.log 65 | local_settings.py 66 | db.sqlite3 67 | db.sqlite3-journal 68 | 69 | # Flask stuff: 70 | instance/ 71 | .webassets-cache 72 | 73 | # Scrapy stuff: 74 | .scrapy 75 | 76 | # Sphinx documentation 77 | docs/_build/ 78 | 79 | # PyBuilder 80 | .pybuilder/ 81 | target/ 82 | 83 | # Jupyter Notebook 84 | .ipynb_checkpoints 85 | 86 | # IPython 87 | profile_default/ 88 | ipython_config.py 89 | 90 | # pyenv 91 | # For a library or package, you might want to ignore these files since the code is 92 | # intended to run in multiple environments; otherwise, check them in: 93 | # .python-version 94 | 95 | # pipenv 96 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 97 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 98 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 99 | # install all needed dependencies. 100 | #Pipfile.lock 101 | 102 | # poetry 103 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 104 | # This is especially recommended for binary packages to ensure reproducibility, and is more 105 | # commonly ignored for libraries. 106 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 107 | #poetry.lock 108 | 109 | # pdm 110 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 111 | #pdm.lock 112 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 113 | # in version control. 114 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 115 | .pdm.toml 116 | .pdm-python 117 | .pdm-build/ 118 | 119 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 120 | __pypackages__/ 121 | 122 | # Celery stuff 123 | celerybeat-schedule 124 | celerybeat.pid 125 | 126 | # SageMath parsed files 127 | *.sage.py 128 | 129 | # Environments 130 | .env 131 | .venv 132 | env/ 133 | venv/ 134 | ENV/ 135 | env.bak/ 136 | venv.bak/ 137 | 138 | # Spyder project settings 139 | .spyderproject 140 | .spyproject 141 | 142 | # Rope project settings 143 | .ropeproject 144 | 145 | # mkdocs documentation 146 | /site 147 | 148 | # mypy 149 | .mypy_cache/ 150 | .dmypy.json 151 | dmypy.json 152 | 153 | # Pyre type checker 154 | .pyre/ 155 | 156 | # pytype static type analyzer 157 | .pytype/ 158 | 159 | # Cython debug symbols 160 | cython_debug/ 161 | 162 | # PyCharm 163 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 164 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 165 | # and can be added to the global gitignore or merged into this file. For a more nuclear 166 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 167 | .idea/ 168 | 169 | notebooks/ 170 | -------------------------------------------------------------------------------- /.python-version: -------------------------------------------------------------------------------- 1 | 3.10.16 2 | -------------------------------------------------------------------------------- /accelerate/fsdp_full_config.yaml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | debug: false 3 | distributed_type: FSDP 4 | downcast_bf16: 'no' 5 | enable_cpu_affinity: false 6 | fsdp_config: 7 | fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP 8 | fsdp_backward_prefetch: BACKWARD_PRE 9 | fsdp_cpu_ram_efficient_loading: true 10 | fsdp_forward_prefetch: true 11 | fsdp_offload_params: false 12 | fsdp_sharding_strategy: FULL_SHARD 13 | fsdp_state_dict_type: SHARDED_STATE_DICT 14 | fsdp_sync_module_states: true 15 | fsdp_use_orig_params: true 16 | machine_rank: 0 17 | main_training_function: main 18 | num_machines: 1 19 | num_processes: 3 20 | rdzv_backend: static 21 | same_network: true 22 | tpu_env: [] 23 | tpu_use_cluster: false 24 | tpu_use_sudo: false 25 | use_cpu: false -------------------------------------------------------------------------------- /accelerate/fsdp_full_config_gemma3.yaml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | debug: false 3 | distributed_type: FSDP 4 | downcast_bf16: 'no' 5 | enable_cpu_affinity: false 6 | fsdp_config: 7 | fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP 8 | fsdp_transformer_layer_cls_to_wrap: Gemma3DecoderLayer 9 | fsdp_backward_prefetch: BACKWARD_PRE 10 | fsdp_cpu_ram_efficient_loading: true 11 | fsdp_forward_prefetch: true 12 | fsdp_offload_params: false 13 | fsdp_sharding_strategy: FULL_SHARD 14 | fsdp_state_dict_type: FULL_STATE_DICT 15 | fsdp_sync_module_states: true 16 | fsdp_use_orig_params: true 17 | machine_rank: 0 18 | main_training_function: main 19 | num_machines: 1 20 | num_processes: 4 21 | rdzv_backend: static 22 | same_network: true 23 | tpu_env: [] 24 | tpu_use_cluster: false 25 | tpu_use_sudo: false 26 | use_cpu: false -------------------------------------------------------------------------------- /accelerate/fsdp_gradop_config.yaml: -------------------------------------------------------------------------------- 1 | compute_environment: 
LOCAL_MACHINE 2 | debug: false 3 | distributed_type: FSDP 4 | downcast_bf16: 'no' 5 | enable_cpu_affinity: false 6 | fsdp_config: 7 | fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP 8 | fsdp_backward_prefetch: BACKWARD_PRE 9 | fsdp_cpu_ram_efficient_loading: true 10 | fsdp_forward_prefetch: true 11 | fsdp_offload_params: False 12 | fsdp_sharding_strategy: SHARD_GRAD_OP 13 | fsdp_state_dict_type: FULL_STATE_DICT 14 | fsdp_sync_module_states: true 15 | fsdp_use_orig_params: true 16 | machine_rank: 0 17 | main_training_function: main 18 | num_machines: 1 19 | num_processes: 4 20 | rdzv_backend: static 21 | same_network: true 22 | tpu_env: [] 23 | tpu_use_cluster: false 24 | tpu_use_sudo: false 25 | use_cpu: false -------------------------------------------------------------------------------- /accelerate/fsdp_gradop_config_gemma3.yaml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | debug: false 3 | distributed_type: FSDP 4 | downcast_bf16: 'no' 5 | enable_cpu_affinity: false 6 | fsdp_config: 7 | fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP 8 | fsdp_transformer_layer_cls_to_wrap: Gemma3DecoderLayer 9 | fsdp_backward_prefetch: BACKWARD_PRE 10 | fsdp_cpu_ram_efficient_loading: true 11 | fsdp_forward_prefetch: true 12 | fsdp_offload_params: False 13 | fsdp_sharding_strategy: SHARD_GRAD_OP 14 | fsdp_state_dict_type: FULL_STATE_DICT 15 | fsdp_sync_module_states: true 16 | fsdp_use_orig_params: true 17 | machine_rank: 0 18 | main_training_function: main 19 | num_machines: 1 20 | num_processes: 4 21 | rdzv_backend: static 22 | same_network: true 23 | tpu_env: [] 24 | tpu_use_cluster: false 25 | tpu_use_sudo: false 26 | use_cpu: false -------------------------------------------------------------------------------- /accelerate/multigpu_dp_config.yaml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | debug: false 3 | distributed_type: MULTI_GPU 4 | downcast_bf16: 'no' 5 | gpu_ids: all 6 | machine_rank: 0 7 | main_training_function: main 8 | num_machines: 1 9 | num_processes: 4 10 | rdzv_backend: static 11 | same_network: true 12 | tpu_env: [] 13 | tpu_use_cluster: false 14 | tpu_use_sudo: false 15 | use_cpu: false -------------------------------------------------------------------------------- /accelerate/stage2_config.yaml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | debug: false 3 | deepspeed_config: 4 | deepspeed_config_file: deepspeed_configs/ds_config_stage_2.json 5 | zero3_init_flag: false 6 | distributed_type: DEEPSPEED 7 | downcast_bf16: 'no' 8 | enable_cpu_affinity: false 9 | machine_rank: 0 10 | main_training_function: main 11 | num_machines: 1 12 | num_processes: 2 13 | rdzv_backend: static 14 | same_network: true 15 | tpu_env: [] 16 | tpu_use_cluster: false 17 | tpu_use_sudo: false 18 | use_cpu: false 19 | -------------------------------------------------------------------------------- /accelerate/stage3_config.yaml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | debug: false 3 | deepspeed_config: 4 | deepspeed_config_file: deepspeed_configs/ds_config_stage_3.json 5 | zero3_init_flag: true 6 | distributed_type: DEEPSPEED 7 | downcast_bf16: 'no' 8 | enable_cpu_affinity: false 9 | machine_rank: 0 10 | main_training_function: main 11 | num_machines: 1 12 | num_processes: 6 13 | rdzv_backend: 
static 14 | same_network: true 15 | tpu_env: [] 16 | tpu_use_cluster: false 17 | tpu_use_sudo: false 18 | use_cpu: false 19 | -------------------------------------------------------------------------------- /accelerate/tp_config.yaml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | debug: false 3 | distributed_type: TP 4 | downcast_bf16: 'no' 5 | enable_cpu_affinity: false 6 | machine_rank: 0 7 | main_training_function: main 8 | num_machines: 1 9 | num_processes: 4 10 | rdzv_backend: static 11 | same_network: true 12 | tp_config: 13 | tp_size: 4 14 | tpu_env: [] 15 | tpu_use_cluster: false 16 | tpu_use_sudo: false 17 | use_cpu: false -------------------------------------------------------------------------------- /deepspeed_configs/ds_config_stage_2.json: -------------------------------------------------------------------------------- 1 | { 2 | "fp16": { 3 | "enabled": "auto", 4 | "loss_scale": 512 5 | }, 6 | "bf16": { 7 | "enabled": "auto" 8 | }, 9 | "optimizer": { 10 | "type": "AdamW", 11 | "params": { 12 | "lr": "auto", 13 | "betas": "auto", 14 | "eps": "auto", 15 | "weight_decay": "auto" 16 | } 17 | }, 18 | "scheduler": { 19 | "type": "WarmupDecayLR", 20 | "params": { 21 | "warmup_min_lr": "auto", 22 | "warmup_max_lr": "auto", 23 | "warmup_num_steps": "auto", 24 | "total_num_steps": "auto" 25 | } 26 | }, 27 | "zero_optimization": { 28 | "stage": 2, 29 | "overlap_comm": true, 30 | "contiguous_gradients": true, 31 | "sub_group_size": 1e9, 32 | "reduce_bucket_size": "auto" 33 | }, 34 | "gradient_accumulation_steps": "auto", 35 | "gradient_clipping": "auto", 36 | "steps_per_print": 2, 37 | "train_batch_size": "auto", 38 | "train_micro_batch_size_per_gpu": "auto", 39 | "wall_clock_breakdown": false 40 | } 41 | -------------------------------------------------------------------------------- /deepspeed_configs/ds_config_stage_3.json: -------------------------------------------------------------------------------- 1 | { 2 | "fp16": { 3 | "enabled": "auto", 4 | "loss_scale": 512 5 | }, 6 | "bf16": { 7 | "enabled": "auto" 8 | }, 9 | "optimizer": { 10 | "type": "AdamW", 11 | "params": { 12 | "lr": "auto", 13 | "betas": "auto", 14 | "eps": "auto", 15 | "weight_decay": "auto" 16 | } 17 | }, 18 | "scheduler": { 19 | "type": "WarmupDecayLR", 20 | "params": { 21 | "warmup_min_lr": "auto", 22 | "warmup_max_lr": "auto", 23 | "warmup_num_steps": "auto", 24 | "total_num_steps": "auto" 25 | } 26 | }, 27 | 28 | "zero_optimization": { 29 | "stage": 3, 30 | "allgather_partitions": true, 31 | "allgather_bucket_size": 2e8, 32 | "overlap_comm": true, 33 | "reduce_scatter": true, 34 | "reduce_bucket_size": "auto", 35 | "contiguous_gradients": true, 36 | "stage3_prefetch_bucket_size": "auto", 37 | "stage3_param_persistence_threshold": "auto", 38 | "stage3_max_live_parameters": 1e9, 39 | "stage3_max_reuse_distance": 1e9, 40 | "stage3_gather_16bit_weights_on_model_save": true 41 | }, 42 | 43 | "gradient_accumulation_steps": "auto", 44 | "gradient_clipping": "auto", 45 | "steps_per_print": 2, 46 | "train_batch_size": "auto", 47 | "train_micro_batch_size_per_gpu": "auto", 48 | "wall_clock_breakdown": false 49 | } 50 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "llm_alignment_playground" 3 | version = "0.1.0" 4 | description = "Fine-Tuning Engine" 5 | authors = ["Sergei Bratchikov 
"] 6 | readme = "README.md" 7 | 8 | [tool.poetry.dependencies] 9 | python = "3.10.*" 10 | jupyter = "^1.0.0" 11 | pandas = "^2.2.1" 12 | datasets = "^3.5.0" 13 | tokenizers = "==0.21.0" 14 | loguru = "^0.7.2" 15 | pydantic = "^2.6.4" 16 | pydantic-settings = "^2.2.1" 17 | wandb = "^0.19.6" 18 | evaluate = "^0.4.2" 19 | transformers = "^4.51.0" 20 | accelerate = "^1.6.0" 21 | peft = "^0.12.0" 22 | s3cmd = "^2.4.0" 23 | clearml = "^1.16.1" 24 | torch = "==2.5.1" 25 | deepspeed = {platform = 'linux', version = "==0.14.5"} 26 | sentence-transformers = "^3.0.1" 27 | trl = "0.16.0" 28 | vllm = {platform = 'linux', version="==0.6.5"} 29 | xformers = {platform = 'linux', version="==0.0.28.post3"} 30 | tabulate = "^0.9.0" 31 | orjson = "^3.10.15" 32 | gradio = "==4.43.0" 33 | dvc = "==3.53.1" 34 | liger-kernel = "^0.5.6" 35 | huggingface-hub = "^0.30.1" 36 | openai = "^1.44.1" 37 | bitsandbytes = "^0.45.2" 38 | galore-torch = "^1.0" 39 | apollo-torch = "^1.0.3" 40 | torchao = "^0.10.0" 41 | ruff = "^0.9.9" 42 | faiss-cpu = "^1.10.0" 43 | 44 | 45 | [build-system] 46 | requires = ["poetry-core"] 47 | build-backend = "poetry.core.masonry.api" 48 | -------------------------------------------------------------------------------- /scripts/inference/gradio_openai_chatbot.py: -------------------------------------------------------------------------------- 1 | # 2 | # Python script to start Gradio chat application for using served models using OpenAI API 3 | # VLLM: python -m vllm.entrypoints.openai.api_server --model NousResearch/Meta-Llama-3-8B-Instruct --served-model-name custom_model -tp 1 --max-model-len 8000 --dtype auto --api-key secret_token_228 --host localhost 4 | # 5 | 6 | import argparse 7 | 8 | import gradio as gr 9 | from openai import OpenAI 10 | 11 | # Argument parser setup 12 | parser = argparse.ArgumentParser( 13 | description='Chatbot Interface with Customizable Parameters') 14 | parser.add_argument('--model-url', 15 | type=str, 16 | default='http://localhost:8000/v1', 17 | help='Model URL') 18 | parser.add_argument('--api-key', 19 | type=str, 20 | default='secret_token_228', 21 | help='api key') 22 | parser.add_argument('-m', 23 | '--model', 24 | type=str, 25 | default='custom_model', 26 | help='Model name for the chatbot') 27 | parser.add_argument('--stop-token-ids', 28 | type=str, 29 | default='', 30 | help='Comma-separated stop token IDs') 31 | parser.add_argument("--host", type=str, default=None) 32 | parser.add_argument("--port", type=int, default=8730) 33 | parser.add_argument("--share", type=bool, default=True) 34 | 35 | # Parse the arguments 36 | args = parser.parse_args() 37 | 38 | # Set OpenAI's API key and API base to use vLLM's API server. 
39 | openai_api_key = args.api_key 40 | openai_api_base = args.model_url 41 | 42 | # Create an OpenAI client to interact with the API server 43 | client = OpenAI( 44 | api_key=openai_api_key, 45 | base_url=openai_api_base, 46 | ) 47 | 48 | 49 | def predict(message: str, history: list, system_prompt: str, temp: float, top_p: float, top_k: int): 50 | # Convert chat history to OpenAI format 51 | history_openai_format = [] 52 | if system_prompt: 53 | history_openai_format.append({"role": "system", "content": system_prompt}) 54 | for human, assistant in history: 55 | history_openai_format.append({"role": "user", "content": human}) 56 | history_openai_format.append({ 57 | "role": "assistant", 58 | "content": assistant 59 | }) 60 | history_openai_format.append({"role": "user", "content": message}) 61 | 62 | # Create a chat completion request and send it to the API server 63 | stream = client.chat.completions.create( 64 | model=args.model, # Model name to use 65 | messages=history_openai_format, # Chat history 66 | temperature=temp, # Temperature for text generation 67 | top_p=top_p, 68 | stream=True, # Stream response 69 | extra_body={ 70 | 'top_k': top_k, 71 | 'repetition_penalty': 1, 72 | 'stop_token_ids': [ 73 | int(id.strip()) for id in args.stop_token_ids.split(',') if id.strip() 74 | ] if args.stop_token_ids else [] 75 | }) 76 | 77 | # Read and return generated text from response stream 78 | partial_message = "" 79 | for chunk in stream: 80 | partial_message += (chunk.choices[0].delta.content or "") 81 | yield partial_message 82 | 83 | 84 | # Create and launch a chat interface with Gradio 85 | with gr.Blocks() as demo: 86 | gr.ChatInterface( 87 | predict, 88 | fill_height=False, 89 | additional_inputs=[ 90 | gr.Textbox(label="System prompt", max_lines=2, render=False), 91 | gr.Slider(0, 1, step=0.1, label="Temperature", value=1.0, render=False), 92 | gr.Slider(0.3, 1, step=0.1, label="Top P", value=1.0, render=False), 93 | gr.Slider(10, 100, step=10, label="Top K", value=50, render=False) 94 | ], 95 | title='Custom model vLLM test' 96 | ) 97 | 98 | demo.queue().launch( 99 | server_name=args.host, 100 | server_port=args.port, 101 | share=args.share 102 | ) 103 | -------------------------------------------------------------------------------- /scripts/model_training/classification.py: -------------------------------------------------------------------------------- 1 | import multiprocessing 2 | import os 3 | import random 4 | import uuid 5 | import warnings 6 | from typing import List 7 | 8 | import torch 9 | from accelerate import PartialState 10 | from accelerate.logging import get_logger 11 | from transformers import AutoModelForSequenceClassification, AutoTokenizer, set_seed 12 | from trl import ModelConfig, get_peft_config 13 | 14 | from src.callbacks.training_parameters_callback import ParameterStatsCallback 15 | from src.configs.classificaion_config import ClassificationConfig 16 | from src.configs.additional.classification_args import CLFScriptArguments 17 | from src.trainers.classification_trainer import ClassificationTrainer 18 | from src.utils.datasets import load_datasets 19 | from src.utils.logger import setup_logging 20 | from src.utils.model_preparation import setup_model_and_tokenizer, unfreeze_modules_by_patterns 21 | from src.utils.yaml_args_parser import H4ArgumentParser 22 | 23 | 24 | logger = get_logger(__name__) 25 | 26 | os.environ['WANDB_RUN_ID'] = str(random.randint(100000, 999999)) 27 | 28 | DATASET_PROCESSING_THREADS = min(multiprocessing.cpu_count() // 2, 16) 29 | 
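# (Added illustration; the sample data below is hypothetical.) get_label_list supports both
# label schemas:
#   ds["train"]["label"] == [["a", "b"], ["b"]]  ->  ["a", "b"]  (multi-label, order not guaranteed)
#   ds["train"]["label"] == [0, 1, 1]            ->  ["0", "1"]  (single-label, cast to str)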
30 | 31 | def get_label_list(raw_dataset, split="train") -> List[str]: 32 | """Get the list of labels from a single- or multi-label dataset""" 33 | 34 | if isinstance(raw_dataset[split]["label"][0], list): 35 | label_list = [label for sample in raw_dataset[split]["label"] for label in sample] 36 | label_list = list(set(label_list)) 37 | else: 38 | label_list = raw_dataset[split].unique("label") 39 | # we will treat the label list as a list of string instead of int, consistent with model.config.label2id 40 | label_list = [str(label) for label in label_list] 41 | return label_list 42 | 43 | 44 | def main(): 45 | parser = H4ArgumentParser((CLFScriptArguments, ClassificationConfig, ModelConfig)) 46 | args, classification_config, model_config = parser.parse() 47 | 48 | setup_logging(logger, classification_config) 49 | set_seed(classification_config.seed) # in case of new tokens added without initialize... 50 | 51 | os.environ["WANDB_PROJECT"] = args.project_name 52 | os.environ['CLEARML_PROJECT'] = args.project_name 53 | 54 | os.environ['WANDB_NAME'] = classification_config.run_name.split("/")[-1] 55 | os.environ['CLEARML_TASK'] = classification_config.run_name.split("/")[-1] 56 | 57 | ################ 58 | # Model & Tokenizer 59 | ################ 60 | tokenizer = AutoTokenizer.from_pretrained(model_config.model_name_or_path) 61 | model = AutoModelForSequenceClassification.from_pretrained( 62 | model_config.model_name_or_path, 63 | torch_dtype=torch.bfloat16 if classification_config.bf16 else torch.float16, 64 | attn_implementation=model_config.attn_implementation, 65 | num_labels=classification_config.num_labels 66 | ) 67 | 68 | setup_model_and_tokenizer(args, model, tokenizer, classification_config.max_length) 69 | 70 | if model_config.use_peft: 71 | for n, p in model.named_parameters(): 72 | p.requires_grad = False 73 | if model_config.lora_task_type != "SEQ_CLS": 74 | warnings.warn( 75 | "You are using a `task_type` that is different than `SEQ_CLS` for PEFT. This will lead to silent bugs" 76 | " Make sure to pass --lora_task_type SEQ_CLS when using this script." 77 | ) 78 | if args.unfreeze_layers_patterns: 79 | warnings.warn( 80 | "You can't use non-empty unfreeze_layers_patterns and peft together at this time, only peft config will be used" 81 | ) 82 | peft_config = get_peft_config(model_config) 83 | else: 84 | if args.unfreeze_layers_patterns: 85 | unfreeze_modules_by_patterns(model, args.unfreeze_layers_patterns) 86 | peft_config = None 87 | 88 | if PartialState().is_main_process: 89 | logger.info(f'Tokenizer: {tokenizer}') 90 | logger.info(f'Model config: {model.config}') 91 | logger.info(f'Model: {model}') 92 | 93 | ################ 94 | # Dataset 95 | ################ 96 | ds = load_datasets(args.dataset, args.test_size, args.dataset_ratio) 97 | 98 | is_multi_label = False 99 | if ds["train"].features["label"].dtype == "list": # multi-label classification 100 | is_multi_label = True 101 | logger.info("Label type is list, doing multi-label classification") 102 | 103 | if is_multi_label: 104 | model.config.problem_type = "multi_label_classification" 105 | logger.info("setting problem type to multi label classification") 106 | else: 107 | model.config.problem_type = "single_label_classification" 108 | logger.info("setting problem type to single label classification") 109 | 110 | # Trying to find the number of labels in a multi-label classification task 111 | # We have to deal with common cases that labels appear in the training set but not in the validation/test set.
112 | # So we build the label list from the union of labels in train/val/test. 113 | label_list = get_label_list(ds, split="train") 114 | for split in ["validation", "test"]: 115 | if split in ds: 116 | val_or_test_labels = get_label_list(ds, split=split) 117 | diff = set(val_or_test_labels).difference(set(label_list)) 118 | if len(diff) > 0: 119 | # add the labels that appear in val/test but not in train, throw a warning 120 | logger.warning( 121 | f"Labels {diff} in {split} set but not in training set, adding them to the label list" 122 | ) 123 | label_list += list(diff) 124 | # if label is -1, we throw a warning and remove it from the label list 125 | for label in label_list: 126 | if label == -1: 127 | logger.warning("Label -1 found in label list, removing it.") 128 | label_list.remove(label) 129 | 130 | label_list.sort() 131 | num_labels = len(label_list) 132 | if num_labels <= 1: 133 | raise ValueError("You need more than one label to do classification.") 134 | 135 | label_to_id = {v: i for i, v in enumerate(label_list)} 136 | # update config with label infos 137 | if model.config.label2id != label_to_id: 138 | logger.warning( 139 | "The label2id key in the model config.json is not equal to the label2id key of this " 140 | "run. You can ignore this if you are doing finetuning." 141 | ) 142 | model.config.label2id = label_to_id 143 | model.config.id2label = {id: label for label, id in label_to_id.items()} 144 | 145 | logger.info(f'Label2id Mapping: {str(label_to_id)}') 146 | 147 | def multi_labels_to_ids(labels: List[str]) -> List[float]: 148 | ids = [0.0] * len(label_to_id) # BCELoss requires float as target type 149 | for label in labels: 150 | ids[label_to_id[label]] = 1.0 151 | return ids 152 | 153 | def preprocess_function(example): 154 | text = tokenizer.apply_chat_template(example["prompt"], tokenize=False, add_generation_prompt=False) 155 | tokenized = tokenizer(text=text, truncation=True, max_length=classification_config.max_length) 156 | 157 | if label_to_id is not None and "label" in example: 158 | if is_multi_label: 159 | tokenized["label"] = multi_labels_to_ids(example["label"]) 160 | else: 161 | tokenized["label"] = label_to_id[str(example["label"])] if example["label"] != -1 else -1 162 | 163 | return tokenized 164 | 165 | # Preprocess the dataset and filter out examples that are longer than args.max_length 166 | with PartialState().local_main_process_first(): 167 | ds = ds.map( 168 | preprocess_function, 169 | batched=False, 170 | num_proc=DATASET_PROCESSING_THREADS, 171 | keep_in_memory=True, 172 | load_from_cache_file=True 173 | ) 174 | train_dataset = ds["train"] 175 | eval_dataset = ds["test"] 176 | 177 | if PartialState().is_main_process: 178 | logger.info('Example from train dataset:') 179 | logger.info(train_dataset[0]) 180 | logger.info('Example from test dataset:') 181 | logger.info(eval_dataset[0]) 182 | 183 | PartialState().wait_for_everyone() 184 | 185 | ################ 186 | # Training 187 | ################ 188 | trainer = ClassificationTrainer( 189 | model=model, 190 | processing_class=tokenizer, 191 | args=classification_config, 192 | train_dataset=train_dataset, 193 | eval_dataset=eval_dataset, 194 | peft_config=peft_config, 195 | callbacks=[ParameterStatsCallback], 196 | is_binary=len(label_to_id) == 2 197 | ) 198 | 199 | # train and save the model 200 | trainer.train() 201 | 202 | if trainer.is_fsdp_enabled: 203 | trainer.accelerator.state.fsdp_plugin.set_state_dict_type("FULL_STATE_DICT") 204 | 205 | 
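    # (Added note.) save_model writes the weights and config only; to make the checkpoint
    # loadable standalone, saving the tokenizer next to it is a common companion step
    # (hedged suggestion, not in the original script):
    #   tokenizer.save_pretrained(classification_config.output_dir)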
trainer.save_model(classification_config.output_dir) 206 | 207 | 208 | if __name__ == '__main__': 209 | main() 210 | -------------------------------------------------------------------------------- /scripts/model_training/cpo.py: -------------------------------------------------------------------------------- 1 | import multiprocessing 2 | import os 3 | import random 4 | import uuid 5 | import warnings 6 | from functools import partial 7 | 8 | import torch 9 | from accelerate import PartialState 10 | from accelerate.logging import get_logger 11 | from transformers import AutoModelForCausalLM, AutoTokenizer, set_seed 12 | from transformers.integrations import is_deepspeed_zero3_enabled 13 | from trl import ModelConfig, get_peft_config, CPOConfig, CPOTrainer 14 | 15 | from src.callbacks.generate_examples import GenerateExamplesCallback 16 | from src.configs.additional.cpo_args import CPOScriptArguments 17 | from src.utils.datasets import load_datasets, prepare_generative_row 18 | from src.utils.logger import setup_logging 19 | from src.utils.model_preparation import setup_model_and_tokenizer 20 | from src.utils.yaml_args_parser import H4ArgumentParser 21 | 22 | 23 | logger = get_logger(__name__) 24 | 25 | LOGGING_TASK_NAME = str(uuid.uuid4()) 26 | 27 | os.environ['WANDB_RUN_ID'] = str(random.randint(100000, 999999)) 28 | os.environ['WANDB_NAME'] = LOGGING_TASK_NAME 29 | os.environ['CLEARML_TASK'] = LOGGING_TASK_NAME 30 | 31 | DATASET_PROCESSING_THREADS = min(multiprocessing.cpu_count() // 2, 16) 32 | 33 | 34 | def main(): 35 | parser = H4ArgumentParser((CPOScriptArguments, CPOConfig, ModelConfig)) 36 | args, cpo_config, model_config = parser.parse() 37 | 38 | setup_logging(logger, cpo_config) 39 | set_seed(cpo_config.seed) # in case of new tokens added without initialize... 40 | 41 | os.environ["WANDB_PROJECT"] = args.project_name 42 | os.environ['CLEARML_PROJECT'] = args.project_name 43 | 44 | ################ 45 | # Model & Tokenizer 46 | ################ 47 | tokenizer = AutoTokenizer.from_pretrained(model_config.model_name_or_path) 48 | model = AutoModelForCausalLM.from_pretrained( 49 | model_config.model_name_or_path, 50 | torch_dtype=torch.bfloat16 if cpo_config.bf16 else torch.float16, 51 | # max_position_embeddings=sft_config.max_seq_length, 52 | attn_implementation=model_config.attn_implementation 53 | ) 54 | 55 | for n, p in model.named_parameters(): 56 | p.requires_grad = not model_config.use_peft 57 | 58 | peft_config = get_peft_config(model_config) 59 | 60 | if model_config.lora_task_type != "CAUSAL_LM": 61 | warnings.warn( 62 | "You are using a `task_type` that is different than `CAUSAL_LM` for PEFT. This will lead to silent bugs" 63 | " Make sure to pass --lora_task_type CAUSAL_LM when using this script." 
64 | ) 65 | 66 | setup_model_and_tokenizer(args, model, tokenizer) 67 | 68 | if PartialState().is_main_process: 69 | logger.info(f'Tokenizer: {tokenizer}') 70 | logger.info(f'Model config: {model.config}') 71 | 72 | ################ 73 | # Dataset 74 | ################ 75 | ds = load_datasets(args.dataset, args.test_size) 76 | generate_dataset = ds['test'] 77 | 78 | def apply_chat_templates(row): 79 | row["prompt"] = tokenizer.apply_chat_template(row["prompt"], tokenize=False) 80 | row["chosen"] = tokenizer.apply_chat_template(row["chosen"], tokenize=False) 81 | row["rejected"] = tokenizer.apply_chat_template(row["rejected"], tokenize=False) 82 | return row 83 | 84 | with PartialState().main_process_first(): 85 | ds = ds.map( 86 | apply_chat_templates, 87 | num_proc=DATASET_PROCESSING_THREADS, 88 | load_from_cache_file=True, 89 | ) 90 | generate_dataset = generate_dataset.map( 91 | partial(prepare_generative_row, tokenizer=tokenizer, max_length=cpo_config.max_prompt_length), 92 | num_proc=DATASET_PROCESSING_THREADS, 93 | load_from_cache_file=True 94 | ) 95 | 96 | train_dataset = ds["train"] 97 | eval_dataset = ds["test"] 98 | 99 | if PartialState().is_main_process: 100 | logger.info('Example from train dataset:') 101 | logger.info(train_dataset[0]) 102 | logger.info('Example from test dataset:') 103 | logger.info(eval_dataset[0]) 104 | logger.info('Example from gen dataset:') 105 | logger.info(generate_dataset[0]) 106 | 107 | generate_callback = GenerateExamplesCallback( 108 | preprocessed_dataset=generate_dataset, 109 | tokenizer=tokenizer, 110 | num_examples=args.num_gen_examples, 111 | is_deepspeed_zero3=is_deepspeed_zero3_enabled(), 112 | logger_backend=cpo_config.report_to[0] 113 | ) 114 | 115 | PartialState().wait_for_everyone() 116 | 117 | ################ 118 | # Training 119 | ################ 120 | trainer = CPOTrainer( 121 | model, 122 | args=cpo_config, 123 | train_dataset=train_dataset, 124 | eval_dataset=eval_dataset, 125 | processing_class=tokenizer, 126 | peft_config=peft_config, 127 | callbacks=[generate_callback] if args.generate_eval_examples else [] 128 | ) 129 | 130 | # train and save the model 131 | trainer.train() 132 | 133 | if trainer.is_fsdp_enabled: 134 | trainer.accelerator.state.fsdp_plugin.set_state_dict_type("FULL_STATE_DICT") 135 | 136 | trainer.save_model(cpo_config.output_dir) 137 | 138 | 139 | if __name__ == '__main__': 140 | main() -------------------------------------------------------------------------------- scripts/model_training/dpo.py: -------------------------------------------------------------------------------- 1 | import multiprocessing 2 | import os 3 | import random 4 | import uuid 5 | import warnings 6 | from functools import partial 7 | 8 | import torch 9 | from accelerate import PartialState 10 | from accelerate.logging import get_logger 11 | from transformers import AutoModelForCausalLM, AutoTokenizer, set_seed 12 | from transformers.integrations import is_deepspeed_zero3_enabled 13 | from trl import ModelConfig, get_peft_config, DPOConfig, DPOTrainer 14 | 15 | from src.callbacks.generate_examples import GenerateExamplesCallback 16 | from src.configs.additional.dpo_args import DPOScriptArguments 17 | from src.utils.datasets import load_datasets, prepare_generative_row 18 | from src.utils.logger import setup_logging 19 | from src.utils.model_preparation import setup_model_and_tokenizer 20 | from src.utils.yaml_args_parser import H4ArgumentParser 21 | 22 | 23 | logger = get_logger(__name__) 24 | 25 | os.environ['WANDB_RUN_ID'] =
str(random.randint(100000, 999999)) 26 | 27 | DATASET_PROCESSING_THREADS = min(multiprocessing.cpu_count() // 2, 16) 28 | 29 | 30 | def main(): 31 | parser = H4ArgumentParser((DPOScriptArguments, DPOConfig, ModelConfig)) 32 | args, dpo_config, model_config = parser.parse() 33 | 34 | setup_logging(logger, dpo_config) 35 | set_seed(dpo_config.seed) # in case of new tokens added without initialize... 36 | 37 | os.environ["WANDB_PROJECT"] = args.project_name 38 | os.environ['CLEARML_PROJECT'] = args.project_name 39 | 40 | os.environ['WANDB_NAME'] = dpo_config.run_name.split("/")[-1] 41 | os.environ['CLEARML_TASK'] = dpo_config.run_name.split("/")[-1] 42 | 43 | ################ 44 | # Model & Tokenizer 45 | ################ 46 | tokenizer = AutoTokenizer.from_pretrained(model_config.model_name_or_path) 47 | model = AutoModelForCausalLM.from_pretrained( 48 | model_config.model_name_or_path, 49 | torch_dtype=torch.bfloat16 if dpo_config.bf16 else torch.float16, 50 | attn_implementation=model_config.attn_implementation 51 | ) 52 | 53 | for n, p in model.named_parameters(): 54 | p.requires_grad = not model_config.use_peft 55 | 56 | peft_config = get_peft_config(model_config) 57 | if peft_config is None: 58 | model_ref = AutoModelForCausalLM.from_pretrained( 59 | model_config.model_name_or_path, 60 | torch_dtype=torch.bfloat16 if dpo_config.bf16 else torch.float16, 61 | attn_implementation=model_config.attn_implementation 62 | ) 63 | else: 64 | model_ref = None 65 | 66 | if model_config.lora_task_type != "CAUSAL_LM": 67 | warnings.warn( 68 | "You are using a `task_type` that is different than `CAUSAL_LM` for PEFT. This will lead to silent bugs" 69 | " Make sure to pass --lora_task_type CAUSAL_LM when using this script." 70 | ) 71 | 72 | setup_model_and_tokenizer(args, model, tokenizer) 73 | if model_ref: 74 | setup_model_and_tokenizer(args, model_ref, tokenizer) 75 | 76 | if PartialState().is_main_process: 77 | logger.info(f'Tokenizer: {tokenizer}') 78 | logger.info(f'Model config: {model.config}') 79 | 80 | ################ 81 | # Dataset 82 | ################ 83 | ds = load_datasets(args.dataset, args.test_size, args.dataset_ratio) 84 | generate_dataset = ds['test'] 85 | 86 | def apply_chat_templates(row): 87 | row["prompt"] = tokenizer.apply_chat_template(row["prompt"], tokenize=False) 88 | row["chosen"] = tokenizer.apply_chat_template(row["chosen"], tokenize=False) 89 | row["rejected"] = tokenizer.apply_chat_template(row["rejected"], tokenize=False) 90 | return row 91 | 92 | with PartialState().main_process_first(): 93 | ds = ds.map( 94 | apply_chat_templates, 95 | num_proc=DATASET_PROCESSING_THREADS, 96 | load_from_cache_file=True, 97 | ) 98 | generate_dataset = generate_dataset.map( 99 | partial(prepare_generative_row, tokenizer=tokenizer, max_length=dpo_config.max_prompt_length), 100 | num_proc=DATASET_PROCESSING_THREADS, 101 | load_from_cache_file=True 102 | ) 103 | 104 | train_dataset = ds["train"] 105 | eval_dataset = ds["test"] 106 | 107 | if PartialState().is_main_process: 108 | logger.info('Example from train dataset:') 109 | logger.info(train_dataset[0]) 110 | logger.info('Example from test dataset:') 111 | logger.info(eval_dataset[0]) 112 | logger.info('Example from gen dataset:') 113 | logger.info(generate_dataset[0]) 114 | 115 | generate_callback = GenerateExamplesCallback( 116 | preprocessed_dataset=generate_dataset, 117 | tokenizer=tokenizer, 118 | num_examples=args.num_gen_examples, 119 | is_deepspeed_zero3=is_deepspeed_zero3_enabled(), 120 | 
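        # (added note) report_to[0] assumes at least one tracker (e.g. wandb) is configured; an empty report_to list would raise IndexError here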
logger_backend=dpo_config.report_to[0] 121 | ) 122 | 123 | PartialState().wait_for_everyone() 124 | 125 | ################ 126 | # Training 127 | ################ 128 | trainer = DPOTrainer( 129 | model, 130 | args=dpo_config, 131 | ref_model=model_ref, 132 | train_dataset=train_dataset, 133 | eval_dataset=eval_dataset, 134 | processing_class=tokenizer, 135 | peft_config=peft_config, 136 | callbacks=[generate_callback] if args.generate_eval_examples else [] 137 | ) 138 | 139 | # train and save the model 140 | trainer.train() 141 | 142 | if trainer.is_fsdp_enabled: 143 | trainer.accelerator.state.fsdp_plugin.set_state_dict_type("FULL_STATE_DICT") 144 | 145 | trainer.save_model(dpo_config.output_dir) 146 | 147 | 148 | if __name__ == '__main__': 149 | main() -------------------------------------------------------------------------------- /scripts/model_training/gpo.py: -------------------------------------------------------------------------------- 1 | import multiprocessing 2 | import os 3 | import random 4 | import uuid 5 | import warnings 6 | from functools import partial 7 | 8 | import torch 9 | from accelerate import PartialState 10 | from accelerate.logging import get_logger 11 | from transformers import AutoModelForCausalLM, AutoTokenizer, set_seed 12 | from transformers.integrations import is_deepspeed_zero3_enabled 13 | from trl import ModelConfig, get_peft_config 14 | 15 | from src.callbacks.generate_examples import GenerateExamplesCallback 16 | from src.callbacks.training_parameters_callback import ParameterStatsCallback 17 | from src.configs.additional.gpo_args import GPOScriptArguments 18 | from src.configs.gpo_config import GroupedPOConfig 19 | from src.trainers.gpo_trainer import GroupedPOTrainer 20 | from src.utils.datasets import load_datasets, prepare_generative_row 21 | from src.utils.logger import setup_logging 22 | from src.utils.model_preparation import setup_model_and_tokenizer, unfreeze_modules_by_patterns 23 | from src.utils.yaml_args_parser import H4ArgumentParser 24 | 25 | logger = get_logger(__name__) 26 | 27 | os.environ['WANDB_RUN_ID'] = str(random.randint(100000, 999999)) 28 | 29 | DATASET_PROCESSING_THREADS = min(multiprocessing.cpu_count() // 2, 16) 30 | 31 | 32 | def main(): 33 | parser = H4ArgumentParser((GPOScriptArguments, GroupedPOConfig, ModelConfig)) 34 | args, gpo_config, model_config = parser.parse() 35 | 36 | setup_logging(logger, gpo_config) 37 | set_seed(gpo_config.seed) # in case of new tokens added without initialize... 38 | 39 | os.environ["WANDB_PROJECT"] = args.project_name 40 | os.environ['CLEARML_PROJECT'] = args.project_name 41 | os.environ['WANDB_NAME'] = gpo_config.run_name.split("/")[-1] 42 | os.environ['CLEARML_TASK'] = gpo_config.run_name.split("/")[-1] 43 | 44 | ################ 45 | # Model & Tokenizer 46 | ################ 47 | tokenizer = AutoTokenizer.from_pretrained(model_config.model_name_or_path) 48 | model = AutoModelForCausalLM.from_pretrained( 49 | model_config.model_name_or_path, 50 | torch_dtype=torch.bfloat16 if gpo_config.bf16 else torch.float16, 51 | attn_implementation=model_config.attn_implementation 52 | ) 53 | 54 | setup_model_and_tokenizer(args, model, tokenizer) 55 | 56 | if model_config.use_peft: 57 | for n, p in model.named_parameters(): 58 | p.requires_grad = False 59 | if model_config.lora_task_type != "CAUSAL_LM": 60 | warnings.warn( 61 | "You are using a `task_type` that is different than `CAUSAL_LM` for PEFT. 
This will lead to silent bugs" 62 | " Make sure to pass --lora_task_type CAUSAL_LM when using this script." 63 | ) 64 | if args.unfreeze_layers_patterns: 65 | warnings.warn( 66 | "You can't use non-empty unfreeze_layers_patterns and peft together at this time, only peft config will be used" 67 | ) 68 | peft_config = get_peft_config(model_config) 69 | else: 70 | if args.unfreeze_layers_patterns: 71 | unfreeze_modules_by_patterns(model, args.unfreeze_layers_patterns) 72 | peft_config = None 73 | 74 | if PartialState().is_main_process: 75 | logger.info(f'Tokenizer: {tokenizer}') 76 | logger.info(f'Model config: {model.config}') 77 | logger.info(f'Model: {model}') 78 | 79 | ################ 80 | # Dataset 81 | ################ 82 | ds = load_datasets(args.dataset, args.test_size, args.dataset_ratio) 83 | generate_dataset = ds['test'] 84 | 85 | def apply_chat_templates(row): 86 | row["prompt"] = tokenizer.apply_chat_template(row["prompt"], tokenize=False) 87 | row["completions"] = [tokenizer.apply_chat_template(chosen, tokenize=False) for chosen in row["completions"]] 88 | return row 89 | 90 | with PartialState().main_process_first(): 91 | ds = ds.map( 92 | apply_chat_templates, 93 | num_proc=DATASET_PROCESSING_THREADS, 94 | keep_in_memory=True, 95 | load_from_cache_file=True 96 | ) 97 | generate_dataset = generate_dataset.map( 98 | partial(prepare_generative_row, tokenizer=tokenizer, max_length=gpo_config.max_prompt_length), 99 | num_proc=DATASET_PROCESSING_THREADS, 100 | keep_in_memory=True, 101 | load_from_cache_file=True 102 | ) 103 | 104 | train_dataset = ds["train"] 105 | eval_dataset = ds["test"] 106 | 107 | if PartialState().is_main_process: 108 | logger.info('Example from train dataset:') 109 | logger.info(train_dataset[0]) 110 | logger.info('Example from test dataset:') 111 | logger.info(eval_dataset[0]) 112 | logger.info('Example from gen dataset:') 113 | logger.info(generate_dataset[0]) 114 | 115 | generate_callback = GenerateExamplesCallback( 116 | preprocessed_dataset=generate_dataset, 117 | tokenizer=tokenizer, 118 | num_examples=args.num_gen_examples, 119 | is_deepspeed_zero3=is_deepspeed_zero3_enabled(), 120 | logger_backend=gpo_config.report_to[0] 121 | ) 122 | 123 | PartialState().wait_for_everyone() 124 | 125 | ################ 126 | # Training 127 | ################ 128 | trainer = GroupedPOTrainer( 129 | model, 130 | args=gpo_config, 131 | train_dataset=train_dataset, 132 | eval_dataset=eval_dataset, 133 | processing_class=tokenizer, 134 | peft_config=peft_config, 135 | callbacks=[generate_callback, ParameterStatsCallback] if args.generate_eval_examples else [ParameterStatsCallback] 136 | ) 137 | 138 | # train and save the model 139 | trainer.train() 140 | 141 | if trainer.is_fsdp_enabled: 142 | trainer.accelerator.state.fsdp_plugin.set_state_dict_type("FULL_STATE_DICT") 143 | 144 | trainer.save_model(gpo_config.output_dir) 145 | 146 | 147 | if __name__ == '__main__': 148 | main() -------------------------------------------------------------------------------- /scripts/model_training/orpo.py: -------------------------------------------------------------------------------- 1 | import multiprocessing 2 | import os 3 | import random 4 | import uuid 5 | import warnings 6 | from functools import partial 7 | 8 | import torch 9 | from accelerate import PartialState 10 | from accelerate.logging import get_logger 11 | from transformers import AutoModelForCausalLM, AutoTokenizer, set_seed 12 | from transformers.integrations import is_deepspeed_zero3_enabled 13 | from trl import 
ModelConfig, get_peft_config, ORPOConfig, ORPOTrainer 14 | 15 | from src.callbacks.generate_examples import GenerateExamplesCallback 16 | from src.configs.additional.orpo_args import ORPOScriptArguments 17 | from src.utils.datasets import load_datasets, prepare_generative_row 18 | from src.utils.logger import setup_logging 19 | from src.utils.model_preparation import setup_model_and_tokenizer 20 | from src.utils.yaml_args_parser import H4ArgumentParser 21 | 22 | logger = get_logger(__name__) 23 | 24 | LOGGING_TASK_NAME = str(uuid.uuid4()) 25 | 26 | os.environ['WANDB_RUN_ID'] = str(random.randint(100000, 999999)) 27 | os.environ['WANDB_NAME'] = LOGGING_TASK_NAME 28 | os.environ['CLEARML_TASK'] = LOGGING_TASK_NAME 29 | 30 | 31 | def main(): 32 | parser = H4ArgumentParser((ORPOScriptArguments, ORPOConfig, ModelConfig)) 33 | args, orpo_config, model_config = parser.parse() 34 | 35 | setup_logging(logger, orpo_config) 36 | set_seed(orpo_config.seed) # in case of new tokens added without initialize... 37 | 38 | os.environ["WANDB_PROJECT"] = args.project_name 39 | os.environ['CLEARML_PROJECT'] = args.project_name 40 | 41 | ################ 42 | # Model & Tokenizer 43 | ################ 44 | tokenizer = AutoTokenizer.from_pretrained(model_config.model_name_or_path) 45 | model = AutoModelForCausalLM.from_pretrained( 46 | model_config.model_name_or_path, 47 | torch_dtype=torch.bfloat16 if orpo_config.bf16 else torch.float16, 48 | # max_position_embeddings=sft_config.max_seq_length, 49 | attn_implementation=model_config.attn_implementation 50 | ) 51 | 52 | for n, p in model.named_parameters(): 53 | p.requires_grad = not model_config.use_peft 54 | 55 | peft_config = get_peft_config(model_config) 56 | 57 | if model_config.lora_task_type != "CAUSAL_LM": 58 | warnings.warn( 59 | "You are using a `task_type` that is different than `CAUSAL_LM` for PEFT. This will lead to silent bugs" 60 | " Make sure to pass --lora_task_type CAUSAL_LM when using this script." 
61 | ) 62 | 63 | setup_model_and_tokenizer(args, model, tokenizer) 64 | 65 | if PartialState().is_main_process: 66 | logger.info(f'Tokenizer: {tokenizer}') 67 | logger.info(f'Model config: {model.config}') 68 | 69 | ################ 70 | # Dataset 71 | ################ 72 | ds = load_datasets(args.dataset, args.test_size) 73 | generate_dataset = ds['test'] 74 | 75 | def apply_chat_templates(row): 76 | row["prompt"] = tokenizer.apply_chat_template(row["prompt"], tokenize=False) 77 | row["chosen"] = tokenizer.apply_chat_template(row["chosen"], tokenize=False) 78 | row["rejected"] = tokenizer.apply_chat_template(row["rejected"], tokenize=False) 79 | return row 80 | 81 | with PartialState().main_process_first(): 82 | ds = ds.map( 83 | apply_chat_templates, 84 | num_proc=multiprocessing.cpu_count(), 85 | load_from_cache_file=True, 86 | ) 87 | generate_dataset = generate_dataset.map( 88 | partial(prepare_generative_row, tokenizer=tokenizer, max_length=orpo_config.max_prompt_length), 89 | num_proc=multiprocessing.cpu_count(), 90 | load_from_cache_file=True 91 | ) 92 | 93 | train_dataset = ds["train"] 94 | eval_dataset = ds["test"] 95 | 96 | if PartialState().is_main_process: 97 | logger.info('Example from train dataset:') 98 | logger.info(train_dataset[0]) 99 | logger.info('Example from test dataset:') 100 | logger.info(eval_dataset[0]) 101 | logger.info('Example from gen dataset:') 102 | logger.info(generate_dataset[0]) 103 | 104 | generate_callback = GenerateExamplesCallback( 105 | preprocessed_dataset=generate_dataset, 106 | tokenizer=tokenizer, 107 | num_examples=args.num_gen_examples, 108 | is_deepspeed_zero3=is_deepspeed_zero3_enabled(), 109 | logger_backend=orpo_config.report_to[0] 110 | ) 111 | 112 | PartialState().wait_for_everyone() 113 | 114 | ################ 115 | # Training 116 | ################ 117 | trainer = ORPOTrainer( 118 | model, 119 | args=orpo_config, 120 | train_dataset=train_dataset, 121 | eval_dataset=eval_dataset, 122 | processing_class=tokenizer, 123 | peft_config=peft_config, 124 | callbacks=[generate_callback] if args.generate_eval_examples else [] 125 | ) 126 | 127 | # train and save the model 128 | trainer.train() 129 | 130 | if trainer.is_fsdp_enabled: 131 | trainer.accelerator.state.fsdp_plugin.set_state_dict_type("FULL_STATE_DICT") 132 | 133 | trainer.save_model(orpo_config.output_dir) 134 | 135 | 136 | if __name__ == '__main__': 137 | main() -------------------------------------------------------------------------------- /scripts/model_training/rewards.py: -------------------------------------------------------------------------------- 1 | import multiprocessing 2 | import os 3 | import sys 4 | import random 5 | import uuid 6 | import warnings 7 | from dataclasses import dataclass 8 | 9 | import torch 10 | from accelerate import PartialState 11 | from accelerate.logging import get_logger 12 | from transformers import AutoModelForSequenceClassification, AutoTokenizer, set_seed, HfArgumentParser 13 | from trl import RewardTrainer, RewardConfig, ModelConfig, get_peft_config 14 | 15 | from src.configs.additional.reward_args import RMScriptArguments 16 | from src.utils.logger import setup_logging 17 | from src.callbacks.training_parameters_callback import ParameterStatsCallback 18 | from src.utils.datasets import load_datasets 19 | from src.utils.model_preparation import setup_model_and_tokenizer, unfreeze_modules_by_patterns 20 | 21 | 22 | logger = get_logger(__name__) 23 | 24 | os.environ['WANDB_RUN_ID'] = str(random.randint(100000, 999999)) 25 | 26 | 
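# (Added note, an assumption about intent.) Half the CPU cores, capped at 16, keeps the
# datasets.map() tokenization workers from competing with the trainer's own dataloader workers.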
DATASET_PROCESSING_THREADS = min(multiprocessing.cpu_count() // 2, 16) 27 | 28 | 29 | def main(): 30 | parser = HfArgumentParser((RMScriptArguments, RewardConfig, ModelConfig)) 31 | args, reward_config, model_config = parser.parse_yaml_file(os.path.abspath(sys.argv[1])) 32 | 33 | setup_logging(logger, reward_config) 34 | set_seed(reward_config.seed) # in case of new tokens added without initialize... 35 | 36 | os.environ["WANDB_PROJECT"] = args.project_name 37 | os.environ['CLEARML_PROJECT'] = args.project_name 38 | 39 | os.environ['WANDB_NAME'] = reward_config.run_name.split("/")[-1] 40 | os.environ['CLEARML_TASK'] = reward_config.run_name.split("/")[-1] 41 | 42 | ################ 43 | # Model & Tokenizer 44 | ################ 45 | tokenizer = AutoTokenizer.from_pretrained(model_config.model_name_or_path) 46 | model = AutoModelForSequenceClassification.from_pretrained( 47 | model_config.model_name_or_path, 48 | torch_dtype=torch.bfloat16 if reward_config.bf16 else torch.float16, 49 | attn_implementation=model_config.attn_implementation, 50 | num_labels=1 51 | ) 52 | 53 | setup_model_and_tokenizer(args, model, tokenizer, reward_config.max_length) 54 | 55 | if model_config.use_peft: 56 | for n, p in model.named_parameters(): 57 | p.requires_grad = False 58 | if model_config.lora_task_type != "SEQ_CLS": 59 | warnings.warn( 60 | "You are using a `task_type` that is different than `SEQ_CLS` for PEFT. This will lead to silent bugs" 61 | " Make sure to pass --lora_task_type SEQ_CLS when using this script." 62 | ) 63 | if args.unfreeze_layers_patterns: 64 | warnings.warn( 65 | "You can't use non-empty unfreeze_layers_patterns and peft together at this time, only peft config will be used" 66 | ) 67 | peft_config = get_peft_config(model_config) 68 | else: 69 | if args.unfreeze_layers_patterns: 70 | unfreeze_modules_by_patterns(model, args.unfreeze_layers_patterns) 71 | peft_config = None 72 | 73 | if PartialState().is_main_process: 74 | logger.info(f'Tokenizer: {tokenizer}') 75 | logger.info(f'Model config: {model.config}') 76 | logger.info(f'Model: {model}') 77 | 78 | ################ 79 | # Dataset 80 | ################ 81 | ds = load_datasets(args.dataset, args.test_size, args.dataset_ratio) 82 | 83 | def preprocess_function(examples): 84 | new_examples = { 85 | "input_ids_chosen": [], 86 | "attention_mask_chosen": [], 87 | "input_ids_rejected": [], 88 | "attention_mask_rejected": [], 89 | } 90 | for prompt, chosen, rejected in zip(examples["prompt"], examples["chosen"], examples["rejected"]): 91 | chosen = tokenizer.apply_chat_template(prompt + chosen, tokenize=False, add_generation_prompt=False) 92 | rejected = tokenizer.apply_chat_template(prompt + rejected, tokenize=False, add_generation_prompt=False) 93 | 94 | tokenized_chosen = tokenizer(text=chosen, truncation=True, max_length=reward_config.max_length) 95 | tokenized_rejected = tokenizer(text=rejected, truncation=True, max_length=reward_config.max_length) 96 | 97 | new_examples["input_ids_chosen"].append(tokenized_chosen["input_ids"]) 98 | new_examples["attention_mask_chosen"].append(tokenized_chosen["attention_mask"]) 99 | new_examples["input_ids_rejected"].append(tokenized_rejected["input_ids"]) 100 | new_examples["attention_mask_rejected"].append(tokenized_rejected["attention_mask"]) 101 | 102 | return new_examples 103 | 104 | # Preprocess the dataset and filter out examples that are longer than args.max_length 105 | with PartialState().local_main_process_first(): 106 | ds = ds.map( 107 | preprocess_function, 108 | batched=True, 109 | 
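            # (added note) batched=True hands preprocess_function dict-of-lists batches, which is what the zip(...) over prompt/chosen/rejected expects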
num_proc=DATASET_PROCESSING_THREADS, 110 | keep_in_memory=True, 111 | load_from_cache_file=True 112 | ) 113 | train_dataset = ds["train"] 114 | eval_dataset = ds["test"] 115 | 116 | if PartialState().is_main_process: 117 | logger.info('Example from train dataset:') 118 | logger.info(train_dataset[0]) 119 | logger.info('Example from test dataset:') 120 | logger.info(eval_dataset[0]) 121 | 122 | PartialState().wait_for_everyone() 123 | 124 | ################ 125 | # Training 126 | ################ 127 | trainer = RewardTrainer( 128 | model=model, 129 | processing_class=tokenizer, 130 | args=reward_config, 131 | train_dataset=train_dataset, 132 | eval_dataset=eval_dataset, 133 | peft_config=peft_config, 134 | callbacks=[ParameterStatsCallback] 135 | ) 136 | 137 | # train and save the model 138 | trainer.train() 139 | 140 | if trainer.is_fsdp_enabled: 141 | trainer.accelerator.state.fsdp_plugin.set_state_dict_type("FULL_STATE_DICT") 142 | 143 | trainer.save_model(reward_config.output_dir) 144 | 145 | 146 | if __name__ == '__main__': 147 | assert len(sys.argv) >= 2 and sys.argv[1].endswith(".yaml"), "You must provide .yaml file with training config as argument" 148 | main() 149 | -------------------------------------------------------------------------------- /scripts/model_training/sft.py: -------------------------------------------------------------------------------- 1 | import multiprocessing 2 | import os 3 | import random 4 | import uuid 5 | import warnings 6 | 7 | import torch 8 | from accelerate import PartialState 9 | from accelerate.logging import get_logger 10 | from transformers import AutoModelForCausalLM, AutoTokenizer, set_seed 11 | from transformers.integrations import is_deepspeed_zero3_enabled 12 | from trl import SFTTrainer, SFTConfig, ModelConfig, get_peft_config 13 | 14 | from src.callbacks.generate_examples import GenerateExamplesCallback 15 | from src.callbacks.training_parameters_callback import ParameterStatsCallback 16 | from src.collators.completions_only import DataCollatorForCompletionOnlyLM 17 | from src.configs.additional.sft_args import SFTScriptArguments 18 | from src.utils.datasets import load_datasets 19 | from src.utils.logger import setup_logging 20 | from src.utils.model_preparation import setup_model_and_tokenizer, unfreeze_modules_by_patterns 21 | from src.utils.yaml_args_parser import H4ArgumentParser 22 | 23 | 24 | logger = get_logger(__name__) 25 | 26 | os.environ['WANDB_RUN_ID'] = str(random.randint(100000, 999999)) 27 | 28 | DATASET_PROCESSING_THREADS = min(multiprocessing.cpu_count() // 2, 16) 29 | 30 | 31 | def main(): 32 | parser = H4ArgumentParser((SFTScriptArguments, SFTConfig, ModelConfig)) 33 | args, sft_config, model_config = parser.parse() 34 | 35 | setup_logging(logger, sft_config) 36 | set_seed(sft_config.seed) # in case of new tokens added without initialize... 
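# (Added note.) The experiment trackers are configured purely through environment variables:
# W&B reads WANDB_PROJECT/WANDB_NAME and ClearML reads CLEARML_PROJECT/CLEARML_TASK, so no
# tracker-specific code is needed in the training loop itself.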
37 | 
38 |     os.environ["WANDB_PROJECT"] = args.project_name
39 |     os.environ['CLEARML_PROJECT'] = args.project_name
40 | 
41 |     os.environ['WANDB_NAME'] = sft_config.run_name.split("/")[-1]
42 |     os.environ['CLEARML_TASK'] = sft_config.run_name.split("/")[-1]
43 | 
44 |     ################
45 |     # Model & Tokenizer
46 |     ################
47 |     tokenizer = AutoTokenizer.from_pretrained(model_config.model_name_or_path)
48 |     model = AutoModelForCausalLM.from_pretrained(
49 |         model_config.model_name_or_path,
50 |         torch_dtype=torch.bfloat16 if sft_config.bf16 else torch.float16,
51 |         # max_position_embeddings=sft_config.max_seq_length,
52 |         attn_implementation=model_config.attn_implementation
53 |     )
54 | 
55 |     setup_model_and_tokenizer(args, model, tokenizer, sft_config.max_seq_length)
56 | 
57 |     if model_config.use_peft:
58 |         for n, p in model.named_parameters():
59 |             p.requires_grad = False
60 |         if model_config.lora_task_type != "CAUSAL_LM":
61 |             warnings.warn(
62 |                 "You are using a `task_type` that is different from `CAUSAL_LM` for PEFT. This will lead to silent bugs."
63 |                 " Make sure to pass --lora_task_type CAUSAL_LM when using this script."
64 |             )
65 |         if args.unfreeze_layers_patterns:
66 |             warnings.warn(
67 |                 "You can't use non-empty unfreeze_layers_patterns and peft together at this time, only the peft config will be used"
68 |             )
69 |         peft_config = get_peft_config(model_config)
70 |     else:
71 |         if args.unfreeze_layers_patterns:
72 |             unfreeze_modules_by_patterns(model, args.unfreeze_layers_patterns)
73 |         peft_config = None
74 | 
75 |     if PartialState().is_main_process:
76 |         logger.info(f'Tokenizer: {tokenizer}')
77 |         logger.info(f'Model config: {model.config}')
78 |         logger.info(f'Model: {model}')
79 | 
80 |     ################
81 |     # Dataset
82 |     ################
83 |     def process_row(row, add_gen_prompt=False):
84 |         system_message = [{'role': 'system', 'content': args.system_prompt}] if args.system_prompt else []
85 |         history = row[args.conversation_field] if not add_gen_prompt else row[args.conversation_field][:-1]
86 |         if not args.model_support_system_role and history[0]["role"] == "system":
87 |             if len(history) > 1 and history[1]["role"] == "user":
88 |                 # prepend the system prompt to the first user message
89 |                 history[1]["content"] = history[0]["content"] + "\n" + history[1]["content"]
90 |                 history = history[1:]
91 |             else:
92 |                 history[0]["role"] = "user"
93 | 
94 |         constructed_prompt = tokenizer.apply_chat_template(
95 |             system_message + history,
96 |             tokenize=False,
97 |             add_generation_prompt=add_gen_prompt
98 |         )
99 |         if tokenizer.bos_token is not None:
100 |             if constructed_prompt.startswith(tokenizer.bos_token):  # Remove the extra bos token
101 |                 constructed_prompt = constructed_prompt[len(tokenizer.bos_token):]
102 |         return tokenizer(constructed_prompt, truncation=True, padding=True, max_length=sft_config.max_seq_length)
103 | 
104 |     ds = load_datasets(args.dataset, args.test_size, args.dataset_ratio)
105 |     generate_dataset = ds['test']
106 | 
107 |     signature_columns = ["input_ids", "labels", "attention_mask"]
108 |     extra_columns = list(set(ds['train'].column_names) - set(signature_columns))
109 | 
110 |     with PartialState().local_main_process_first():
111 |         ds = ds.map(
112 |             process_row,
113 |             num_proc=DATASET_PROCESSING_THREADS,
114 |             keep_in_memory=True,
115 |             load_from_cache_file=True,
116 |             remove_columns=extra_columns
117 |         )
118 |         generate_dataset = generate_dataset.map(
119 |             lambda row: process_row(row, add_gen_prompt=True),
120 |             num_proc=DATASET_PROCESSING_THREADS,
121 |             keep_in_memory=True,
122 | 
load_from_cache_file=True, 123 | remove_columns=extra_columns 124 | ) 125 | 126 | train_dataset = ds["train"] 127 | eval_dataset = ds["test"] 128 | 129 | if PartialState().is_main_process: 130 | logger.info('Example from train dataset:') 131 | logger.info(train_dataset[0]) 132 | logger.info('Example from test dataset:') 133 | logger.info(eval_dataset[0]) 134 | logger.info('Example from gen dataset:') 135 | logger.info(generate_dataset[0]) 136 | 137 | collator = DataCollatorForCompletionOnlyLM( 138 | response_prompt_template=args.assistant_message_template, 139 | tokenizer=tokenizer 140 | ) if args.train_only_on_completions else None 141 | 142 | generate_callback = GenerateExamplesCallback( 143 | preprocessed_dataset=generate_dataset, 144 | tokenizer=tokenizer, 145 | num_examples=args.num_gen_examples, 146 | is_deepspeed_zero3=is_deepspeed_zero3_enabled(), 147 | logger_backend=sft_config.report_to[0] 148 | ) 149 | 150 | PartialState().wait_for_everyone() 151 | 152 | sft_config.dataset_kwargs = { 153 | "skip_prepare_dataset": True 154 | } 155 | 156 | ################ 157 | # Training 158 | ################ 159 | trainer = SFTTrainer( 160 | model, 161 | args=sft_config, 162 | train_dataset=train_dataset, 163 | eval_dataset=eval_dataset, 164 | processing_class=tokenizer, 165 | peft_config=peft_config, 166 | data_collator=collator, 167 | callbacks=[generate_callback, ParameterStatsCallback] if args.generate_eval_examples else [ParameterStatsCallback] 168 | ) 169 | 170 | # train and save the model 171 | trainer.train() 172 | 173 | if trainer.is_fsdp_enabled: 174 | trainer.accelerator.state.fsdp_plugin.set_state_dict_type("FULL_STATE_DICT") 175 | 176 | trainer.save_model(sft_config.output_dir) 177 | 178 | 179 | if __name__ == '__main__': 180 | main() 181 | -------------------------------------------------------------------------------- /scripts/model_training/smpo.py: -------------------------------------------------------------------------------- 1 | import multiprocessing 2 | import os 3 | import random 4 | import uuid 5 | import warnings 6 | from functools import partial 7 | 8 | import torch 9 | from accelerate import PartialState 10 | from accelerate.logging import get_logger 11 | from transformers import AutoModelForCausalLM, AutoTokenizer, set_seed 12 | from transformers.integrations import is_deepspeed_zero3_enabled 13 | from trl import ModelConfig, get_peft_config 14 | 15 | from src.callbacks.generate_examples import GenerateExamplesCallback 16 | from src.callbacks.training_parameters_callback import ParameterStatsCallback 17 | from src.configs.additional.smpo_args import SMPOScriptArguments 18 | from src.configs.smpo_config import SimpleMarginPOConfig 19 | from src.trainers.smpo_trainer import SimpleMarginPOTrainer 20 | from src.utils.datasets import load_datasets, prepare_generative_row 21 | from src.utils.logger import setup_logging 22 | from src.utils.model_preparation import setup_model_and_tokenizer, unfreeze_modules_by_patterns 23 | from src.utils.yaml_args_parser import H4ArgumentParser 24 | 25 | 26 | logger = get_logger(__name__) 27 | 28 | LOGGING_TASK_NAME = str(uuid.uuid4()) 29 | 30 | os.environ['WANDB_RUN_ID'] = str(random.randint(100000, 999999)) 31 | os.environ['WANDB_NAME'] = LOGGING_TASK_NAME 32 | os.environ['CLEARML_TASK'] = LOGGING_TASK_NAME 33 | 34 | DATASET_PROCESSING_THREADS = min(multiprocessing.cpu_count() // 2, 16) 35 | 36 | 37 | def main(): 38 | parser = H4ArgumentParser((SMPOScriptArguments, SimpleMarginPOConfig, ModelConfig)) 39 | args, smpo_config, model_config 
= parser.parse()
40 | 
41 |     setup_logging(logger, smpo_config)
42 |     set_seed(smpo_config.seed)  # needed in case new tokens were added without initializing their embeddings
43 | 
44 |     os.environ["WANDB_PROJECT"] = args.project_name
45 |     os.environ['CLEARML_PROJECT'] = args.project_name
46 | 
47 |     ################
48 |     # Model & Tokenizer
49 |     ################
50 |     tokenizer = AutoTokenizer.from_pretrained(model_config.model_name_or_path)
51 |     model = AutoModelForCausalLM.from_pretrained(
52 |         model_config.model_name_or_path,
53 |         torch_dtype=torch.bfloat16 if smpo_config.bf16 else torch.float16,
54 |         attn_implementation=model_config.attn_implementation
55 |     )
56 | 
57 |     setup_model_and_tokenizer(args, model, tokenizer)
58 | 
59 |     if model_config.use_peft:
60 |         for n, p in model.named_parameters():
61 |             p.requires_grad = False
62 |         if model_config.lora_task_type != "CAUSAL_LM":
63 |             warnings.warn(
64 |                 "You are using a `task_type` that is different from `CAUSAL_LM` for PEFT. This will lead to silent bugs."
65 |                 " Make sure to pass --lora_task_type CAUSAL_LM when using this script."
66 |             )
67 |         if args.unfreeze_layers_patterns:
68 |             warnings.warn(
69 |                 "You can't use non-empty unfreeze_layers_patterns and peft together at this time, only the peft config will be used"
70 |             )
71 |         peft_config = get_peft_config(model_config)
72 |     else:
73 |         if args.unfreeze_layers_patterns:
74 |             unfreeze_modules_by_patterns(model, args.unfreeze_layers_patterns)
75 |         peft_config = None
76 | 
77 |     if PartialState().is_main_process:
78 |         logger.info(f'Tokenizer: {tokenizer}')
79 |         logger.info(f'Model config: {model.config}')
80 |         logger.info(f'Model: {model}')
81 | 
82 |     ################
83 |     # Dataset
84 |     ################
85 |     ds = load_datasets(args.dataset, args.test_size, args.dataset_ratio)
86 |     generate_dataset = ds['test']
87 | 
88 |     def apply_chat_templates(row):
89 |         row["prompt"] = tokenizer.apply_chat_template(row["prompt"], tokenize=False)
90 |         row["chosen"] = tokenizer.apply_chat_template(row["chosen"], tokenize=False)
91 |         row["rejected"] = tokenizer.apply_chat_template(row["rejected"], tokenize=False)
92 |         return row
93 | 
94 |     with PartialState().main_process_first():
95 |         ds = ds.map(
96 |             apply_chat_templates,
97 |             num_proc=DATASET_PROCESSING_THREADS,
98 |             keep_in_memory=True,
99 |             load_from_cache_file=True
100 |         )
101 |         generate_dataset = generate_dataset.map(
102 |             partial(prepare_generative_row, tokenizer=tokenizer, max_length=smpo_config.max_prompt_length),
103 |             num_proc=DATASET_PROCESSING_THREADS,
104 |             keep_in_memory=True,
105 |             load_from_cache_file=True
106 |         )
107 | 
108 |     train_dataset = ds["train"]
109 |     eval_dataset = ds["test"]
110 | 
111 |     if PartialState().is_main_process:
112 |         logger.info('Example from train dataset:')
113 |         logger.info(train_dataset[0])
114 |         logger.info('Example from test dataset:')
115 |         logger.info(eval_dataset[0])
116 |         logger.info('Example from gen dataset:')
117 |         logger.info(generate_dataset[0])
118 | 
119 |     generate_callback = GenerateExamplesCallback(
120 |         preprocessed_dataset=generate_dataset,
121 |         tokenizer=tokenizer,
122 |         num_examples=args.num_gen_examples,
123 |         is_deepspeed_zero3=is_deepspeed_zero3_enabled(),
124 |         logger_backend=smpo_config.report_to[0]
125 |     )
126 | 
127 |     PartialState().wait_for_everyone()
128 | 
129 |     ################
130 |     # Training
131 |     ################
132 |     trainer = SimpleMarginPOTrainer(
133 |         model,
134 |         args=smpo_config,
135 |         train_dataset=train_dataset,
136 |         eval_dataset=eval_dataset,
137 |         tokenizer=tokenizer,
138 |         peft_config=peft_config,
139 |         callbacks=[generate_callback, ParameterStatsCallback] if args.generate_eval_examples else [ParameterStatsCallback]
140 |     )
141 | 
142 |     # train and save the model
143 |     trainer.train()
144 | 
145 |     if trainer.is_fsdp_enabled:
146 |         trainer.accelerator.state.fsdp_plugin.set_state_dict_type("FULL_STATE_DICT")
147 | 
148 |     trainer.save_model(smpo_config.output_dir)
149 | 
150 | 
151 | if __name__ == '__main__':
152 |     main()
--------------------------------------------------------------------------------
/scripts/post_training/merge_peft_adapters.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | 
3 | import torch
4 | from peft import AutoPeftModelForCausalLM, AutoPeftModelForSequenceClassification
5 | from transformers import AutoTokenizer
6 | import os
7 | 
8 | 
9 | def parse_args():
10 |     parser = argparse.ArgumentParser(description="Merge a LoRA adapter into the base model")
11 |     parser.add_argument('--source', type=str, required=True,
12 |                         help="Path to the directory with the adapter and the model configuration")
13 |     parser.add_argument('--output', type=str, required=True,
14 |                         help="Output directory for saving the model with the merged adapter")
15 |     parser.add_argument('--is_clf', action='store_true',
16 |                         help="Whether the model is an AutoPeftModelForSequenceClassification rather than an AutoPeftModelForCausalLM")
17 |     return parser.parse_args()
18 | 
19 | 
20 | def merge(source_path, output_path, is_clf):
21 |     # Load the source model and its configuration
22 |     tokenizer = AutoTokenizer.from_pretrained(source_path)
23 | 
24 |     if not is_clf:
25 |         adapter_model = AutoPeftModelForCausalLM.from_pretrained(source_path, torch_dtype=torch.bfloat16)
26 |     else:
27 |         adapter_model = AutoPeftModelForSequenceClassification.from_pretrained(source_path, torch_dtype=torch.bfloat16, num_labels=1)
28 | 
29 |     # Save the original adapter
30 |     adapter_save_path = os.path.join(output_path, 'original_adapter')
31 |     os.makedirs(adapter_save_path, exist_ok=True)
32 |     adapter_model.save_pretrained(adapter_save_path)
33 | 
34 |     # Merge the adapter
35 |     merged_model = adapter_model.merge_and_unload()
36 | 
37 |     # Save the full model and the tokenizer
38 |     merged_model.save_pretrained(output_path)
39 |     tokenizer.save_pretrained(output_path)
40 | 
41 | 
42 | if __name__ == "__main__":
43 |     args = parse_args()
44 |     merge(args.source, args.output, args.is_clf)
45 | 
--------------------------------------------------------------------------------
/scripts/post_training/mergekit_example_config.yaml:
--------------------------------------------------------------------------------
1 | models:
2 |   - model: Vikhrmodels/Phikhr-14B-Instruct-R-19-12-24-SFT
3 |     parameters:
4 |       weight: 0.5
5 |   - model: Vikhrmodels/Phikhr-14B-Instruct-R-25-12-24-SMPO-v9.1
6 |     parameters:
7 |       weight: 0.8
8 |   - model: Vikhrmodels/Phikhr-14B-Instruct-R-25-12-24-SMPO-v7
9 |     parameters:
10 |       weight: 1.0
11 | 
12 | merge_method: della_linear
13 | base_model: NyxKrage/Microsoft_Phi-4
14 | parameters:
15 |   epsilon: 0.05
16 |   lambda: 1
17 |   density: 0.6
18 |   normalize: true
19 | dtype: float16
20 | tokenizer_source: Vikhrmodels/Phikhr-14B-Instruct-R-19-12-24-SFT
--------------------------------------------------------------------------------
/scripts/prompts_training/reward.py:
--------------------------------------------------------------------------------
1 | import multiprocessing
2 | import os
3 | import random
4 | import sys
5 | import uuid
6 | 
7 | import torch
8 | from accelerate import PartialState
9 | from accelerate.logging import get_logger
10 | from transformers import AutoModelForSequenceClassification, AutoTokenizer, set_seed, HfArgumentParser
11 | from trl import RewardConfig, ModelConfig
12 | 
13 | from src.callbacks.training_parameters_callback import ParameterStatsCallback
14 | from src.configs.additional.reward_args import RMScriptArguments
15 | from src.configs.prompts_optimization_comfig import PromptsOptimizationConfig
16 | from src.trainers.prompts_optimization.prompts_reward_trainer import PromptsRewardTrainer
17 | from src.utils.datasets import load_datasets
18 | from src.utils.logger import setup_logging
19 | from src.utils.model_preparation import setup_model_and_tokenizer
20 | 
21 | 
22 | logger = get_logger(__name__)
23 | 
24 | LOGGING_TASK_NAME = str(uuid.uuid4())
25 | 
26 | os.environ['WANDB_RUN_ID'] = str(random.randint(100000, 999999))
27 | os.environ['WANDB_NAME'] = LOGGING_TASK_NAME
28 | os.environ['CLEARML_TASK'] = LOGGING_TASK_NAME
29 | 
30 | DATASET_PROCESSING_THREADS = min(multiprocessing.cpu_count() // 2, 16)
31 | 
32 | 
33 | def main():
34 |     parser = HfArgumentParser((RMScriptArguments, RewardConfig, ModelConfig, PromptsOptimizationConfig))
35 |     args, reward_config, model_config, prompts_config = parser.parse_yaml_file(os.path.abspath(sys.argv[1]))
36 | 
37 |     setup_logging(logger, reward_config)
38 |     set_seed(reward_config.seed)  # needed in case new tokens were added without initializing their embeddings
39 | 
40 |     os.environ["WANDB_PROJECT"] = args.project_name
41 |     os.environ['CLEARML_PROJECT'] = args.project_name
42 | 
43 |     ################
44 |     # Model & Tokenizer
45 |     ################
46 |     tokenizer = AutoTokenizer.from_pretrained(model_config.model_name_or_path)
47 |     model = AutoModelForSequenceClassification.from_pretrained(
48 |         model_config.model_name_or_path,
49 |         torch_dtype=torch.bfloat16 if reward_config.bf16 else torch.float16,
50 |         attn_implementation=model_config.attn_implementation,
51 |         num_labels=1
52 |     )
53 | 
54 |     setup_model_and_tokenizer(args, model, tokenizer, reward_config.max_length)
55 | 
56 |     # # Freeze the base model
57 |     # for param in model.parameters():
58 |     #     param.requires_grad = False
59 | 
60 |     if PartialState().is_main_process:
61 |         logger.info(f'Tokenizer: {tokenizer}')
62 |         logger.info(f'Model config: {model.config}')
63 |         logger.info(f'Model: {model}')
64 | 
65 |     ################
66 |     # Dataset
67 |     ################
68 |     ds = load_datasets(args.dataset, args.test_size, args.dataset_ratio)
69 | 
70 |     def preprocess_function(examples):
71 |         new_examples = {
72 |             "input_ids_chosen": [],
73 |             "attention_mask_chosen": [],
74 |             "input_ids_rejected": [],
75 |             "attention_mask_rejected": [],
76 |         }
77 |         for prompt, chosen, rejected in zip(examples["prompt"], examples["chosen"], examples["rejected"]):
78 |             prompt = [x for x in prompt if x['role'] != prompts_config.inserted_chat_role]  # needed only for prompts tuning
79 |             chosen = tokenizer.apply_chat_template(prompt + chosen, tokenize=False, add_generation_prompt=False)
80 |             rejected = tokenizer.apply_chat_template(prompt + rejected, tokenize=False, add_generation_prompt=False)
81 | 
82 |             tokenized_chosen = tokenizer(text=chosen, truncation=True, max_length=reward_config.max_length)
83 |             tokenized_rejected = tokenizer(text=rejected, truncation=True, max_length=reward_config.max_length)
84 | 
85 |             new_examples["input_ids_chosen"].append(tokenized_chosen["input_ids"])
86 |             new_examples["attention_mask_chosen"].append(tokenized_chosen["attention_mask"])
87 |             new_examples["input_ids_rejected"].append(tokenized_rejected["input_ids"])
88 |             new_examples["attention_mask_rejected"].append(tokenized_rejected["attention_mask"])
89 | 
90 |         return new_examples
91 | 
92 |     # Preprocess the dataset and filter out examples that are longer than args.max_length
93 |     with PartialState().local_main_process_first():
94 |         ds = ds.map(
95 |             preprocess_function,
96 |             batched=True,
97 |             num_proc=DATASET_PROCESSING_THREADS,
98 |             keep_in_memory=True,
99 |             load_from_cache_file=False
100 |         )
101 |         train_dataset = ds["train"]
102 |         eval_dataset = ds["test"]
103 | 
104 |     if PartialState().is_main_process:
105 |         logger.info('Example from train dataset:')
106 |         logger.info(train_dataset[0])
107 |         logger.info('Example from test dataset:')
108 |         logger.info(eval_dataset[0])
109 | 
110 |     PartialState().wait_for_everyone()
111 | 
112 |     ################
113 |     # Training
114 |     ################
115 |     trainer = PromptsRewardTrainer(
116 |         model=model,
117 |         tokenizer=tokenizer,
118 |         args=reward_config,
119 |         prompt_args=prompts_config,
120 |         train_dataset=train_dataset,
121 |         eval_dataset=eval_dataset,
122 |         peft_config=None,
123 |         callbacks=[ParameterStatsCallback]
124 |     )
125 | 
126 |     # train and save the model
127 |     trainer.train()
128 | 
129 |     if trainer.is_fsdp_enabled:
130 |         trainer.accelerator.state.fsdp_plugin.set_state_dict_type("FULL_STATE_DICT")
131 | 
132 |     trainer.save_model(reward_config.output_dir)
133 | 
134 | 
135 | if __name__ == '__main__':
136 |     assert len(sys.argv) >= 2 and sys.argv[1].endswith(".yaml"), "You must provide a .yaml file with the training config as an argument"
137 |     main()
138 | 
--------------------------------------------------------------------------------
/scripts/prompts_training/sft.py:
--------------------------------------------------------------------------------
1 | import multiprocessing
2 | import os
3 | import random
4 | import uuid
5 | 
6 | import torch
7 | from accelerate import PartialState
8 | from accelerate.logging import get_logger
9 | from transformers import AutoModelForCausalLM, AutoTokenizer, set_seed
10 | from trl import SFTConfig, ModelConfig
11 | 
12 | from src.callbacks.training_parameters_callback import ParameterStatsCallback
13 | from src.collators.completions_only import DataCollatorForCompletionOnlyLM
14 | from src.configs.additional.sft_args import SFTScriptArguments
15 | from src.configs.prompts_optimization_comfig import PromptsOptimizationConfig
16 | from src.trainers.prompts_optimization.prompts_sft_trainer import PromptsSFTTrainer
17 | from src.utils.datasets import load_datasets
18 | from src.utils.logger import setup_logging
19 | from src.utils.model_preparation import setup_model_and_tokenizer
20 | from src.utils.yaml_args_parser import H4ArgumentParser
21 | 
22 | 
23 | logger = get_logger(__name__)
24 | 
25 | LOGGING_TASK_NAME = str(uuid.uuid4())
26 | 
27 | os.environ['WANDB_RUN_ID'] = str(random.randint(100000, 999999))
28 | os.environ['WANDB_NAME'] = LOGGING_TASK_NAME
29 | os.environ['CLEARML_TASK'] = LOGGING_TASK_NAME
30 | 
31 | DATASET_PROCESSING_THREADS = multiprocessing.cpu_count() // 2
32 | 
33 | 
34 | def main():
35 |     parser = H4ArgumentParser((SFTScriptArguments, SFTConfig, ModelConfig, PromptsOptimizationConfig))
36 |     args, sft_config, model_config, prompts_config = parser.parse()
37 | 
38 |     setup_logging(logger, sft_config)
39 |     set_seed(sft_config.seed)  # needed in case new tokens were added without initializing their embeddings
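    # --- Hedged config sketch (added note, not part of the original script) ---
    # The parser above merges four dataclasses from one YAML file; the key names
    # come from SFTScriptArguments, SFTConfig, ModelConfig and
    # PromptsOptimizationConfig, but the values and layout here are assumptions:
    #
    #   model_name_or_path: path/to/base-model     # ModelConfig
    #   dataset: path/to/conversations.jsonl       # SFTScriptArguments
    #   test_size: 0.05
    #   num_prompts: 3                             # PromptsOptimizationConfig
    #   prompt_len: 128
    #   inserted_chat_role: system
    #   bf16: true                                 # SFTConfig
    #   max_seq_length: 2048
    #   output_dir: checkpoints/prompts-sft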
40 | 41 | os.environ["WANDB_PROJECT"] = args.project_name 42 | os.environ['CLEARML_PROJECT'] = args.project_name 43 | 44 | ################ 45 | # Model & Tokenizer 46 | ################ 47 | tokenizer = AutoTokenizer.from_pretrained(model_config.model_name_or_path) 48 | model = AutoModelForCausalLM.from_pretrained( 49 | model_config.model_name_or_path, 50 | torch_dtype=torch.bfloat16 if sft_config.bf16 else torch.float16, 51 | # max_position_embeddings=sft_config.max_seq_length, 52 | attn_implementation=model_config.attn_implementation 53 | ) 54 | if sft_config.use_liger: 55 | from liger_kernel.transformers import apply_liger_kernel_to_llama, apply_liger_kernel_to_mistral, apply_liger_kernel_to_qwen2 56 | apply_liger_kernel_to_llama( 57 | rope=False, 58 | swiglu=True, 59 | cross_entropy=False, 60 | fused_linear_cross_entropy=True, 61 | rms_norm=True 62 | ) 63 | apply_liger_kernel_to_mistral( 64 | rope=False, 65 | swiglu=True, 66 | cross_entropy=False, 67 | fused_linear_cross_entropy=True, 68 | rms_norm=True 69 | ) 70 | apply_liger_kernel_to_qwen2( 71 | rope=False, 72 | swiglu=True, 73 | cross_entropy=False, 74 | fused_linear_cross_entropy=True, 75 | rms_norm=True 76 | ) 77 | 78 | setup_model_and_tokenizer(args, model, tokenizer, sft_config.max_seq_length) 79 | 80 | if PartialState().is_main_process: 81 | logger.info(f'Tokenizer: {tokenizer}') 82 | logger.info(f'Model config: {model.config}') 83 | logger.info(f'Model: {model}') 84 | 85 | ################ 86 | # Dataset 87 | ################ 88 | def process_row(row, add_gen_prompt=False): 89 | system_message = [{'role': 'system', 'content': args.system_prompt}] if args.system_prompt else [] 90 | history = row[args.conversation_field] if not add_gen_prompt else row[args.conversation_field][:-1] 91 | history = [x for x in history if x['role'] != prompts_config.inserted_chat_role] # needed only for prompts tuning 92 | if not args.model_support_system_role and history[0]["role"] == "system": 93 | if len(history) > 1 and history[1]["role"] == "user": 94 | # add sys prompt to first user message 95 | history[1]["content"] = history[0]["content"] + "\n" + history[1]["content"] 96 | history = history[1:] 97 | else: 98 | history[0]["role"] = "user" 99 | 100 | constructed_prompt = tokenizer.apply_chat_template( 101 | system_message + history, 102 | tokenize=False, 103 | add_generation_prompt=add_gen_prompt 104 | ) 105 | if tokenizer.bos_token is not None: 106 | if constructed_prompt.startswith(tokenizer.bos_token): # Remove extra bos token 107 | constructed_prompt = constructed_prompt[len(tokenizer.bos_token):] 108 | return tokenizer(constructed_prompt, truncation=True, padding=True, max_length=sft_config.max_seq_length) 109 | 110 | ds = load_datasets(args.dataset, args.test_size, args.dataset_ratio) 111 | 112 | signature_columns = ["input_ids", "labels", "attention_mask"] 113 | extra_columns = list(set(ds['train'].column_names) - set(signature_columns)) 114 | 115 | with PartialState().local_main_process_first(): 116 | ds = ds.map( 117 | process_row, 118 | num_proc=DATASET_PROCESSING_THREADS, 119 | keep_in_memory=True, 120 | load_from_cache_file=True, 121 | remove_columns=extra_columns 122 | ) 123 | 124 | train_dataset = ds["train"] 125 | eval_dataset = ds["test"] 126 | 127 | if PartialState().is_main_process: 128 | logger.info('Example from train dataset:') 129 | logger.info(train_dataset[0]) 130 | logger.info('Example from test dataset:') 131 | logger.info(eval_dataset[0]) 132 | 133 | collator = DataCollatorForCompletionOnlyLM( 134 | 
response_prompt_template=args.assistant_message_template, 135 | tokenizer=tokenizer 136 | ) if args.train_only_on_completions else None 137 | 138 | PartialState().wait_for_everyone() 139 | 140 | sft_config.dataset_kwargs = { 141 | "skip_prepare_dataset": True 142 | } 143 | 144 | ################ 145 | # Training 146 | ################ 147 | trainer = PromptsSFTTrainer( 148 | model, 149 | args=sft_config, 150 | prompt_args=prompts_config, 151 | train_dataset=train_dataset, 152 | eval_dataset=eval_dataset, 153 | tokenizer=tokenizer, 154 | peft_config=None, 155 | data_collator=collator, 156 | callbacks=[ParameterStatsCallback] 157 | ) 158 | 159 | # train and save the model 160 | trainer.train() 161 | 162 | if trainer.is_fsdp_enabled: 163 | trainer.accelerator.state.fsdp_plugin.set_state_dict_type("FULL_STATE_DICT") 164 | 165 | trainer.save_model(sft_config.output_dir) 166 | 167 | 168 | if __name__ == '__main__': 169 | main() 170 | -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import sys 3 | 4 | import datasets 5 | import transformers 6 | from transformers import TrainingArguments 7 | 8 | logging.basicConfig( 9 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 10 | datefmt="%m/%d/%Y %H:%M:%S", 11 | handlers=[logging.StreamHandler(sys.stdout)], 12 | ) 13 | 14 | datasets.utils.logging.set_verbosity(transformers.logging.INFO) 15 | transformers.logging.set_verbosity_info() 16 | transformers.logging.enable_default_handler() 17 | transformers.logging.enable_explicit_format() 18 | -------------------------------------------------------------------------------- /src/callbacks/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VikhrModels/effective_llm_alignment/c48ca2905ac75f82dea4ba17a9e6995e88b7007f/src/callbacks/__init__.py -------------------------------------------------------------------------------- /src/callbacks/attr_scheduling.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import Literal 3 | 4 | from transformers import TrainerCallback 5 | 6 | 7 | class VariableSchedulerCallback(TrainerCallback): 8 | """General-purpose variable scheduler callback with multiple schedule types.""" 9 | 10 | def __init__( 11 | self, 12 | attribute_name: str, 13 | initial_value: float, 14 | final_value: float, 15 | schedule_type: str = "cosine", 16 | warmup_steps: int = 0, 17 | cycle: bool = False, 18 | cycle_scale: float = 1.0, 19 | target: Literal["model", "trainer"] = "model", 20 | ): 21 | """ 22 | Args: 23 | attribute_name: Name of the attribute to update in the model 24 | initial_value: Starting value of the variable 25 | final_value: Target end value of the variable 26 | schedule_type: Type of schedule ('cosine', 'linear', 'exponential') 27 | warmup_steps: Number of steps to maintain initial value before starting schedule 28 | cycle: Whether to cycle the schedule (for cosine only) 29 | cycle_scale: Scale factor for cycling (number of cycles) 30 | """ 31 | self.attribute_name = attribute_name 32 | self.initial_value = initial_value 33 | self.final_value = final_value 34 | self.schedule_type = schedule_type 35 | self.warmup_steps = warmup_steps 36 | self.cycle = cycle 37 | self.cycle_scale = cycle_scale 38 | self.total_steps = None 39 | self.target = target 40 | 41 | if self.target not in 
["model", "trainer"]: 42 | raise ValueError( 43 | f"Invalid target '{target}'. Must be 'model' or 'trainer'." 44 | ) 45 | 46 | def on_train_begin(self, args, state, control, **kwargs): 47 | self.total_steps = state.max_steps - self.warmup_steps 48 | if self.total_steps <= 0: 49 | raise ValueError("Total training steps must be greater than warmup steps") 50 | 51 | def on_step_begin(self, args, state, control, **kwargs): 52 | """Update variable at the beginning of each step""" 53 | current_step = state.global_step 54 | 55 | # Handle warmup period 56 | if current_step < self.warmup_steps: 57 | current_value = self.initial_value 58 | else: 59 | progress = (current_step - self.warmup_steps) / self.total_steps 60 | progress = min(progress, 1.0) # Clamp progress to 1.0 61 | 62 | if self.cycle and self.schedule_type == "cosine": 63 | progress = progress * self.cycle_scale % 1.0 64 | 65 | # Calculate current value based on schedule type 66 | if self.schedule_type == "cosine": 67 | current_value = self.final_value + 0.5 * ( 68 | self.initial_value - self.final_value 69 | ) * (1 + math.cos(math.pi * progress)) 70 | elif self.schedule_type == "linear": 71 | current_value = ( 72 | self.initial_value 73 | + (self.final_value - self.initial_value) * progress 74 | ) 75 | elif self.schedule_type == "exponential": 76 | ratio = progress 77 | current_value = ( 78 | self.initial_value 79 | * (self.final_value / self.initial_value) ** ratio 80 | ) 81 | else: 82 | raise ValueError(f"Unknown schedule type: {self.schedule_type}") 83 | 84 | # Get target object 85 | target_obj = kwargs.get(self.target) 86 | if not target_obj: 87 | raise ValueError(f"Could not find {self.target} in callback arguments") 88 | 89 | # Handle distributed/DataParallel models 90 | target_obj = target_obj.module if hasattr(target_obj, "module") else target_obj 91 | 92 | if hasattr(target_obj, self.attribute_name): 93 | setattr(target_obj, self.attribute_name, current_value) 94 | else: 95 | raise AttributeError( 96 | f"{type(target_obj)} does not have attribute {self.attribute_name}" 97 | ) 98 | 99 | def get_current_value(self, model): 100 | """Get current value of the scheduled variable""" 101 | model_obj = model.module if hasattr(model, "module") else model 102 | return getattr(model_obj, self.attribute_name) 103 | -------------------------------------------------------------------------------- /src/callbacks/generate_examples.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | import random 4 | import shutil 5 | import tempfile 6 | 7 | import pandas as pd 8 | import torch 9 | import wandb 10 | from accelerate import PartialState 11 | from datasets import Dataset 12 | from tabulate import tabulate 13 | from tqdm import tqdm 14 | from transformers import ( 15 | TrainerCallback, 16 | TrainerState, 17 | TrainerControl, 18 | PreTrainedTokenizer, 19 | ) 20 | 21 | try: 22 | from clearml import Task 23 | except ImportError: 24 | pass 25 | 26 | 27 | def pretty_print_dataframe(df, max_prompt_length=200): 28 | formatted_df = df.copy() 29 | formatted_df["Prompt"] = formatted_df["Prompt"].apply( 30 | lambda x: ("..." 
+ x[-max_prompt_length:]) if len(x) > max_prompt_length else x
31 |     )
32 |     print(tabulate(formatted_df, headers="keys", tablefmt="simple", showindex=False))
33 | 
34 | 
35 | def save_dataframe(df, process_index, temp_dir_path):
36 |     if not os.path.exists(temp_dir_path):
37 |         os.makedirs(temp_dir_path)
38 |     filename = os.path.join(temp_dir_path, f"df_{process_index}.csv")
39 |     df.to_csv(filename, index=False)
40 | 
41 | 
42 | class GenerateExamplesCallback(TrainerCallback):
43 |     def __init__(
44 |         self,
45 |         preprocessed_dataset: Dataset,
46 |         tokenizer: PreTrainedTokenizer,
47 |         num_examples=5,
48 |         max_new_tokens=256,
49 |         is_deepspeed_zero3=False,
50 |         logger_backend="wandb",
51 |     ):
52 |         """
53 |         :param preprocessed_dataset: Preprocessed datasets.Dataset after applying the tokenizer
54 |         :param tokenizer: Tokenizer for decoding
55 |         :param num_examples: Number of examples to generate at each evaluation
56 |         :param max_new_tokens: Max new tokens length of generated examples
57 |         :param logger_backend: 'clearml' or 'wandb' for choosing the logging tool
58 |         """
59 |         self.dataset = preprocessed_dataset
60 |         self.tokenizer = tokenizer
61 |         self.num_examples = min(num_examples, len(self.dataset))
62 |         self.max_new_tokens = max_new_tokens
63 |         self.logger_backend = logger_backend
64 |         self.is_deepspeed_zero3 = is_deepspeed_zero3
65 | 
66 |         temp_dir_base = tempfile.gettempdir()
67 |         self.temp_dir_path = os.path.join(
68 |             temp_dir_base, "callback_generate_examples_dir"
69 |         )
70 | 
71 |         sample_indices = random.choices(range(len(self.dataset)), k=self.num_examples)
72 |         self.samples = self.dataset.select(sample_indices)
73 | 
74 |     def on_evaluate(self, args, state: TrainerState, control: TrainerControl, **kwargs):
75 |         model = kwargs["model"]
76 |         model.eval()
77 | 
78 |         records = []
79 | 
80 |         PartialState().wait_for_everyone()
81 | 
82 |         with PartialState().split_between_processes(self.samples) as samples:
83 |             for example in tqdm(samples, desc="Generating examples"):
84 |                 input_ids = (
85 |                     torch.tensor(example["input_ids"]).unsqueeze(0).to(model.device)
86 |                 )
87 |                 attention_mask = (
88 |                     torch.tensor(example["attention_mask"])
89 |                     .unsqueeze(0)
90 |                     .to(model.device)
91 |                 )
92 | 
93 |                 with torch.no_grad():
94 |                     outputs = model.generate(
95 |                         input_ids,
96 |                         attention_mask=attention_mask,
97 |                         use_cache=not self.is_deepspeed_zero3,
98 |                         synced_gpus=self.is_deepspeed_zero3,
99 |                         max_new_tokens=self.max_new_tokens,
100 |                     )
101 | 
102 |                 prompt_text = self.tokenizer.decode(
103 |                     input_ids.squeeze(), skip_special_tokens=False
104 |                 )
105 |                 completion_text = self.tokenizer.decode(
106 |                     outputs.squeeze(), skip_special_tokens=False
107 |                 )[len(prompt_text) :]
108 | 
109 |                 pred_dict = {"Prompt": prompt_text, "Completion": completion_text}
110 |                 if "chosen" in example.keys():
111 |                     pred_dict["Chosen"] = example["chosen"][0]["content"]
112 |                 if "rejected" in example.keys():
113 |                     pred_dict["Rejected"] = example["rejected"][0]["content"]
114 |                 records.append(pred_dict)
115 | 
116 |         # Save each process's own shard of the results
117 |         save_dataframe(
118 |             pd.DataFrame(records), PartialState().process_index, self.temp_dir_path
119 |         )
120 | 
121 |         PartialState().wait_for_everyone()
122 | 
123 |         # Print and log only on the main process
124 |         if PartialState().is_main_process:
125 |             # Read all DataFrame files from the temporary directory and concatenate them
126 |             all_files = glob.glob(os.path.join(self.temp_dir_path, "df_*.csv"))
127 |             df_list = [pd.read_csv(file) for file in all_files]
128 |             combined_df = pd.concat(df_list, ignore_index=True)
129 | 
130 |             # Once merged, the temporary directory can be removed
131 |             shutil.rmtree(self.temp_dir_path)
132 | 
133 |             if self.logger_backend == "clearml":
134 |                 task = Task.current_task()
135 |                 if task:
136 |                     logger = task.get_logger()
137 |                     logger.report_table(
138 |                         "Generated Text Samples",
139 |                         "DataFrame",
140 |                         iteration=state.global_step,
141 |                         table_plot=combined_df,
142 |                     )
143 |             elif self.logger_backend == "wandb":
144 |                 wandb.log(
145 |                     {
146 |                         f"eval/generated_text_{state.global_step}": wandb.Table(
147 |                             dataframe=combined_df
148 |                         )
149 |                     }
150 |                 )
151 | 
152 |             pretty_print_dataframe(combined_df.sample(3))
153 | 
--------------------------------------------------------------------------------
/src/callbacks/training_parameters_callback.py:
--------------------------------------------------------------------------------
1 | import re
2 | 
3 | from accelerate import PartialState
4 | from accelerate.utils import gather_object
5 | from transformers import (
6 |     TrainerCallback,
7 | )
8 | from transformers.utils import logging
9 | 
10 | logger = logging.get_logger(__name__)
11 | 
12 | 
13 | def count_model_parameters(model):
14 |     """
15 |     Returns a tuple (total_params, trainable_params),
16 |     where total_params is the total number of parameters,
17 |     and trainable_params is the number of parameters with requires_grad=True.
18 |     """
19 |     total_params = sum(p.numel() for p in model.parameters())
20 |     trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
21 |     return total_params, trainable_params
22 | 
23 | 
24 | def normalize_module_name(module_name):
25 |     """
26 |     Converts the module name into a normalized (grouping) form,
27 |     replacing numeric indices with the "X" character. Example:
28 |     "model.layers.0.self_attn" -> "model.layers.X.self_attn"
29 |     """
30 |     if not module_name:
31 |         return ""
32 |     normalized = re.sub(r"\b\d+\b", "X", module_name)
33 |     return normalized
34 | 
35 | 
36 | def compute_module_trainable_stats(model):
37 |     stats = {}
38 |     for param_name, param in model.named_parameters():
39 |         total = param.numel()
40 |         trainable = param.numel() if param.requires_grad else 0
41 |         norm_name = normalize_module_name(param_name)
42 |         if norm_name in stats:
43 |             prev_total, prev_trainable = stats[norm_name]
44 |             stats[norm_name] = (prev_total + total, prev_trainable + trainable)
45 |         else:
46 |             stats[norm_name] = (total, trainable)
47 |     return stats
48 | 
49 | 
50 | class ParameterStatsCallback(TrainerCallback):
51 |     """
52 |     Callback for Trainer that logs the following before training begins:
53 |     - Total number of parameters
54 |     - Number of trainable parameters and the percentage of trainable parameters
55 |     - List of "deduplicated" modules with the percentage of trainable parameters per group
56 |     """
57 | 
58 |     def on_train_begin(self, args, state, control, **kwargs):
59 |         model = kwargs.get("model", None)
60 |         if model is None:
61 |             trainer = kwargs.get("trainer", None)
62 |             if trainer is not None:
63 |                 model = trainer.model
64 | 
65 |         total_params, trainable_params = count_model_parameters(model)
66 |         percent = 100 * trainable_params / total_params if total_params > 0 else 0
67 | 
68 |         print(
69 |             f"\n===== Model view inside Trainer (process: {PartialState().process_index}) ====="
70 |         )
71 |         print(model)
72 | 
73 |         if PartialState().is_main_process:
74 |             print(f"Total number of parameters     : {total_params:,}")
75 |             print(f"Number of trainable parameters : {trainable_params:,}")
76 |             print(f"Percentage of trainable parameters: {percent:.2f}%\n")
77 | 
78 |         print(
79 |             f"\n===== Model parameter statistics (process: {PartialState().process_index}) ====="
80 |         )
81 | 
82 |         module_stats = compute_module_trainable_stats(model)
83 |         module_stats_percent = []
84 |         for mod_name, (mod_total, mod_trainable) in module_stats.items():
85 |             mod_percent = 100 * mod_trainable / mod_total if mod_total > 0 else 0
86 |             module_stats_percent.append(
87 |                 (mod_name, mod_total, mod_trainable, mod_percent)
88 |             )
89 | 
90 |         module_stats_percent.sort(key=lambda x: x[3], reverse=True)
91 | 
92 |         print("List of module groups with normalized names:")
93 |         for mod_name, mod_total, mod_trainable, mod_percent in module_stats_percent:
94 |             print(
95 |                 f"  {mod_name:30s} - trainable: {mod_trainable:,} / {mod_total:,} ({mod_percent:.2f}%)"
96 |             )
97 |         print("========================================\n")
98 | 
--------------------------------------------------------------------------------
/src/collators/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VikhrModels/effective_llm_alignment/c48ca2905ac75f82dea4ba17a9e6995e88b7007f/src/collators/__init__.py
--------------------------------------------------------------------------------
/src/collators/completions_only.py:
--------------------------------------------------------------------------------
1 | import warnings
2 | from typing import Union, List, Any, Dict
3 | 
4 | import torch
5 | from transformers import DataCollatorForLanguageModeling
6 | 
7 | from src.utils.array_utils import filter_indices
8 | 
9 | 
10 | class DataCollatorForCompletionOnlyLM(DataCollatorForLanguageModeling):
11 |     """
12 |     Data collator used for completion tasks. It ensures that all the tokens of the labels are set to an 'ignore_index'
13 |     when they do not come from the assistant. This ensures that the loss is only
14 |     calculated on the completions made by the assistant.
15 | 
16 |     Args:
17 |         response_prompt_template (`Union[str, List[int]]`): the template that indicates the start of the response, typically something like
18 |             '### Response:\n'. It can also be passed as tokenized ids, which can be useful when using a tokenizer that encodes the response
19 |             differently if it does not have proper context.
20 |         mlm (`bool`, *optional*, defaults to `False`): Whether or not to use masked language modeling in the underlying
21 |             `DataCollatorForLanguageModeling` class. Note that this option currently has no effect but is present
22 |             for flexibility and backwards-compatibility.
23 |         ignore_index (`int`, *optional*, defaults to `-100`):
24 |             The index used to mask tokens that should be excluded from the loss computation.
25 |     """
26 | 
27 |     def __init__(
28 |         self,
29 |         response_prompt_template: Union[str, List[int]],
30 |         *args,
31 |         mlm: bool = False,
32 |         ignore_index: int = -100,
33 |         **kwargs,
34 |     ):
35 |         super().__init__(*args, mlm=mlm, **kwargs)
36 | 
37 |         self.response_prompt_template = response_prompt_template
38 | 
39 |         if isinstance(response_prompt_template, str):
40 |             self.response_token_ids = self.tokenizer.encode(
41 |                 self.response_prompt_template, add_special_tokens=False
42 |             )
43 |         else:
44 |             self.response_token_ids = self.response_prompt_template
45 | 
46 |         self.eos_token_id = self.tokenizer.eos_token_id
47 | 
48 |         if not self.mlm and self.tokenizer.pad_token_id == self.tokenizer.eos_token_id:
49 |             warnings.warn(
50 |                 "The pad_token_id and eos_token_id values of this tokenizer are identical.
" 51 | "If you are planning for multi-turn training, " 52 | "it can result in the model continuously generating questions and answers without eos token. " 53 | "To avoid this, set the pad_token_id to a different value." 54 | ) 55 | 56 | self.ignore_index = ignore_index 57 | 58 | def torch_call( 59 | self, examples: List[Union[List[int], Any, Dict[str, Any]]] 60 | ) -> Dict[str, Any]: 61 | batch = super().torch_call(examples) 62 | 63 | for i in range(len(examples)): 64 | response_token_ids_start_indexes = [] 65 | eos_token_indexes = [] 66 | 67 | for idx in torch.where(batch["labels"][i] == self.response_token_ids[0])[0]: 68 | if ( 69 | self.response_token_ids 70 | == batch["labels"][i][ 71 | idx : idx + len(self.response_token_ids) 72 | ].tolist() 73 | ): 74 | response_token_ids_start_indexes.append(idx.item()) 75 | 76 | for idx in torch.where(batch["labels"][i] == self.eos_token_id)[0]: 77 | eos_token_indexes.append(idx.item()) 78 | 79 | eos_token_indexes = filter_indices( 80 | response_token_ids_start_indexes, eos_token_indexes 81 | ) 82 | 83 | if not response_token_ids_start_indexes or not eos_token_indexes: 84 | warnings.warn( 85 | f"Could not find response key `{self.response_prompt_template}` in the " 86 | f"following instance: {self.tokenizer.decode(batch['input_ids'][i])} " 87 | f"This instance will be ignored in loss calculation. " 88 | f"Note, if this happens often, consider increasing the `max_seq_length`." 89 | ) 90 | batch["labels"][i, :] = self.ignore_index 91 | else: 92 | new_labels = torch.full_like(batch["labels"][i], self.ignore_index).to( 93 | device=batch["labels"][i].device 94 | ) 95 | 96 | for start, end in zip( 97 | response_token_ids_start_indexes, eos_token_indexes 98 | ): 99 | new_labels[start : end + 1] = batch["labels"][i, start : end + 1] 100 | 101 | batch["labels"][i] = new_labels 102 | 103 | return batch 104 | -------------------------------------------------------------------------------- /src/configs/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VikhrModels/effective_llm_alignment/c48ca2905ac75f82dea4ba17a9e6995e88b7007f/src/configs/__init__.py -------------------------------------------------------------------------------- /src/configs/additional/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VikhrModels/effective_llm_alignment/c48ca2905ac75f82dea4ba17a9e6995e88b7007f/src/configs/additional/__init__.py -------------------------------------------------------------------------------- /src/configs/additional/classification_args.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | from src.configs.additional.common_script_args import CommonScriptArguments 4 | 5 | 6 | @dataclass 7 | class CLFScriptArguments(CommonScriptArguments): 8 | def __post_init__(self): 9 | self.project_name = ( 10 | "classification" 11 | if self.project_name == "default-project" 12 | else self.project_name 13 | ) 14 | -------------------------------------------------------------------------------- /src/configs/additional/common_script_args.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | from typing import List 3 | 4 | 5 | @dataclass 6 | class CommonScriptArguments: 7 | dataset: str | List[str] = field( 8 | default="/path/to/dataset", 9 | metadata={ 10 | "help": "The 
name on HF or path to a jsonl file of the dataset to use. Can be a list of paths."
11 |         },
12 |     )
13 |     dataset_ratio: float | None = field(
14 |         default=None,
15 |         metadata={
16 |             "help": "How much of the dataset to take. Each ratio should be between 0 and 1"
17 |         },
18 |     )
19 |     test_size: float = field(
20 |         default=None,
21 |         metadata={
22 |             "help": "Test set split proportion (like 0.05). If the dataset already contains a test split, leave empty"
23 |         },
24 |     )
25 |     project_name: str | None = field(
26 |         default="default-project",
27 |         metadata={"help": "Name of logging project (wandb or clearml)"},
28 |     )
29 |     pad_token: str | None = field(default=None, metadata={"help": "Special pad token"})
30 |     bos_token: str | None = field(default=None, metadata={"help": "Special bos token"})
31 |     eos_token: str | None = field(default=None, metadata={"help": "Special eos token"})
32 |     chat_template: str | None = field(
33 |         default="{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{{ content }}{% endfor %}{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}{% endif %}",
34 |         metadata={"help": "Chat template for the model"},
35 |     )
36 |     force_chat_template: bool = field(
37 |         default=False,
38 |         metadata={"help": "Force custom chat template from chat_template argument"},
39 |     )
40 |     added_special_tokens: List[str] | None = field(
41 |         default=None, metadata={"help": "Additional special tokens"}
42 |     )
43 |     unfreeze_layers_patterns: List[str] | None = field(
44 |         default=None,
45 |         metadata={"help": "Patterns of layer names that need to be unfrozen for training"},
46 |     )
47 | 
--------------------------------------------------------------------------------
/src/configs/additional/cpo_args.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass, field
2 | 
3 | from src.configs.additional.common_script_args import CommonScriptArguments
4 | 
5 | 
6 | @dataclass
7 | class CPOScriptArguments(CommonScriptArguments):
8 |     generate_eval_examples: bool | None = field(
9 |         default=True, metadata={"help": "Whether to generate examples during evaluation"}
10 |     )
11 |     num_gen_examples: int | None = field(
12 |         default=50, metadata={"help": "Number of examples to generate during the eval phase"}
13 |     )
14 | 
15 |     def __post_init__(self):
16 |         self.project_name = (
17 |             "cpo-tuning"
18 |             if self.project_name == "default-project"
19 |             else self.project_name
20 |         )
21 | 
--------------------------------------------------------------------------------
/src/configs/additional/dpo_args.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass, field
2 | 
3 | from src.configs.additional.common_script_args import CommonScriptArguments
4 | 
5 | 
6 | @dataclass
7 | class DPOScriptArguments(CommonScriptArguments):
8 |     generate_eval_examples: bool | None = field(
9 |         default=True, metadata={"help": "Whether to generate examples during evaluation"}
10 |     )
11 |     num_gen_examples: int | None = field(
12 |         default=50, metadata={"help": "Number of examples to generate during the eval phase"}
13 |     )
14 | 
15 |     def __post_init__(self):
16 |         self.project_name = (
17 |             "dpo-tuning"
18 |             if self.project_name == "default-project"
19 |             else self.project_name
20 |         )
21 | 
--------------------------------------------------------------------------------
/src/configs/additional/gpo_args.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass, field
2 | 
3 | from src.configs.additional.common_script_args import CommonScriptArguments
4 | 
5 | 
6 | @dataclass
7 | class GPOScriptArguments(CommonScriptArguments):
8 |     generate_eval_examples: bool | None = field(
9 |         default=True, metadata={"help": "Whether to generate examples during evaluation"}
10 |     )
11 |     num_gen_examples: int | None = field(
12 |         default=50, metadata={"help": "Number of examples to generate during the eval phase"}
13 |     )
14 | 
15 |     def __post_init__(self):
16 |         self.project_name = (
17 |             "gpo-tuning"
18 |             if self.project_name == "default-project"
19 |             else self.project_name
20 |         )
21 | 
--------------------------------------------------------------------------------
/src/configs/additional/orpo_args.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass, field
2 | 
3 | from src.configs.additional.common_script_args import CommonScriptArguments
4 | 
5 | 
6 | @dataclass
7 | class ORPOScriptArguments(CommonScriptArguments):
8 |     generate_eval_examples: bool | None = field(
9 |         default=True, metadata={"help": "Whether to generate examples during evaluation"}
10 |     )
11 |     num_gen_examples: int | None = field(
12 |         default=50, metadata={"help": "Number of examples to generate during the eval phase"}
13 |     )
14 | 
15 |     def __post_init__(self):
16 |         self.project_name = (
17 |             "orpo-tuning"
18 |             if self.project_name == "default-project"
19 |             else self.project_name
20 |         )
21 | 
--------------------------------------------------------------------------------
/src/configs/additional/reward_args.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass, field
2 | 
3 | from src.configs.additional.common_script_args import CommonScriptArguments
4 | 
5 | 
6 | @dataclass
7 | class RMScriptArguments(CommonScriptArguments):
8 |     def __post_init__(self):
9 |         self.project_name = (
10 |             "reward-modeling"
11 |             if self.project_name == "default-project"
12 |             else self.project_name
13 |         )
14 | 
--------------------------------------------------------------------------------
/src/configs/additional/sft_args.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass, field
2 | 
3 | from src.configs.additional.common_script_args import CommonScriptArguments
4 | 
5 | 
6 | @dataclass
7 | class SFTScriptArguments(CommonScriptArguments):
8 |     conversation_field: str | None = field(
9 |         default="prompt",
10 |         metadata={
11 |             "help": "Field in dataset with conversations (in list of dicts format)"
12 |         },
13 |     )
14 |     system_prompt: str | None = field(
15 |         default=None,
16 |         metadata={
17 |             "help": "System prompt to use if the dialogue does not already contain one; set to None to disable"
18 |         },
19 |     )
20 |     train_only_on_completions: bool | None = field(
21 |         default=True, metadata={"help": "Whether to train only on completions"}
22 |     )
23 |     generate_eval_examples: bool | None = field(
24 |         default=True, metadata={"help": "Whether to generate examples during evaluation"}
25 |     )
26 |     assistant_message_template: str | None = field(
27 |         default="<|start_header_id|>assistant<|end_header_id|>\n\n",
28 |         metadata={
29 |             "help": "Assistant message template used when training only on completions"
30 |         },
31 |     )
32 |     num_gen_examples: int | None = field(
33 |         default=50, metadata={"help": "Number of examples to generate during the eval phase"}
34 |     )
35 |     model_support_system_role: bool | None = field(
36 |         default=True,
37 |         metadata={
38 |             "help": "Flag that indicates whether the model supports the system role. If not, the user role will be used to set the system prompt"
39 |         },
40 |     )
41 | 
42 |     def __post_init__(self):
43 |         self.project_name = (
44 |             "sft-tuning"
45 |             if self.project_name == "default-project"
46 |             else self.project_name
47 |         )
48 | 
--------------------------------------------------------------------------------
/src/configs/additional/smpo_args.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass, field
2 | 
3 | from src.configs.additional.common_script_args import CommonScriptArguments
4 | 
5 | 
6 | @dataclass
7 | class SMPOScriptArguments(CommonScriptArguments):
8 |     generate_eval_examples: bool | None = field(
9 |         default=True, metadata={"help": "Whether to generate examples during evaluation"}
10 |     )
11 |     num_gen_examples: int | None = field(
12 |         default=50, metadata={"help": "Number of examples to generate during the eval phase"}
13 |     )
14 | 
15 |     def __post_init__(self):
16 |         self.project_name = (
17 |             "smpo-tuning"
18 |             if self.project_name == "default-project"
19 |             else self.project_name
20 |         )
21 | 
--------------------------------------------------------------------------------
/src/configs/classificaion_config.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass, field
2 | from typing import Optional
3 | 
4 | from transformers import TrainingArguments
5 | 
6 | 
7 | @dataclass
8 | class ClassificationConfig(TrainingArguments):
9 |     r"""
10 |     Configuration class for the [`ClassificationTrainer`].
11 | 
12 |     Using [`~transformers.HfArgumentParser`] we can turn this class into
13 |     [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
14 |     command line.
15 | 
16 |     Parameters:
17 |         max_length (`int` or `None`, *optional*, defaults to `1024`):
18 |             Maximum length of the sequences in the batch, filters out entries that exceed the limit.
19 |         disable_dropout (`bool`, *optional*, defaults to `True`):
20 |             Whether to disable dropout in the model.
21 |         dataset_num_proc (`int`, *optional*, defaults to `None`):
22 |             Number of processes to use for processing the dataset.
23 |         remove_unused_columns (`bool`, *optional*, defaults to `False`):
24 |             Whether to remove the columns that are not used by the model's forward pass. Can be `True` only if
25 |             the dataset is pretokenized.
26 |     """
27 | 
28 |     max_length: Optional[int] = field(
29 |         default=1024,
30 |         metadata={
31 |             "help": "Maximum length of the sequences in the batch, filters out entries that exceed the limit."
32 |         },
33 |     )
34 |     num_labels: Optional[int] = field(
35 |         default=2,
36 |         metadata={"help": "Number of classes used in dataset labels"},
37 |     )
38 |     disable_dropout: bool = field(
39 |         default=True,
40 |         metadata={"help": "Whether to disable dropout in the model."},
41 |     )
42 |     dataset_num_proc: Optional[int] = field(
43 |         default=None,
44 |         metadata={"help": "Number of processes to use for processing the dataset."},
45 |     )
46 |     remove_unused_columns: bool = field(
47 |         default=False,
48 |         metadata={
49 |             "help": "Whether to remove the columns that are not used by the model's forward pass. Can be `True` only "
50 |             "if the dataset is pretokenized."
51 |         },
52 |     )
53 | 
--------------------------------------------------------------------------------
/src/configs/gpo_config.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from typing import Dict, Literal, Optional
3 | from transformers import TrainingArguments
4 | 
5 | 
6 | @dataclass
7 | class GroupedPOConfig(TrainingArguments):
8 |     r"""
9 |     GroupedPOConfig collects all training arguments related to the [`GroupedPOTrainer`] class.
10 | 
11 |     Using [`HfArgumentParser`] we can turn this class into
12 |     [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
13 |     command line.
14 | 
15 |     Parameters:
16 |         max_prompt_length (`int`, defaults to `None`):
17 |             The maximum length of the prompt. This argument is required if you want to use the default data collator.
18 |         lower_clip_percentile (`Optional[float]`, defaults to 0.02):
19 |             Lower percentile of token log prob values allowed in the PO loss calculation for rejected completions. Works like winsorizing. Recommended range: [0.01, 0.05]
20 |         min_log_prob (`Optional[float]`, defaults to -2.3):
21 |             Lowest possible token log prob value allowed in rejected completions. Clips all log probs; applied after percentile winsorizing.
22 |         upper_clip_percentile (`Optional[float]`, defaults to `None`):
23 |             Upper percentile of token log prob values allowed in the PO loss calculation for chosen completions. Works like winsorizing. Recommended range: [0.95, 0.99]
24 |         label_pad_token_id (`int`, defaults to `-100`):
25 |             The label pad token id. This argument is required if you want to use the default data collator.
26 |         padding_value (`int`, defaults to `None`):
27 |             The padding value if it is different from the tokenizer's pad_token_id.
28 |         disable_dropout (`bool`, defaults to `True`):
29 |             Whether or not to disable dropouts in `model`.
30 |         model_init_kwargs (`Optional[Dict]`, *optional*):
31 |             Dict of optional kwargs to pass when instantiating the model from a string.
32 |         dataset_num_proc (`Optional[int]`, *optional*):
33 |             The number of workers to use to tokenize the data. Defaults to None.
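    Example (a hedged sketch added for illustration; it assumes the standard
    [`HfArgumentParser`] flow used by the training scripts above, and the
    argument values are illustrative only):

        from transformers import HfArgumentParser

        parser = HfArgumentParser(GroupedPOConfig)
        (gpo_config,) = parser.parse_args_into_dataclasses(
            args=["--output_dir", "out", "--max_prompt_length", "512"]
        )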
34 | """ 35 | 36 | max_prompt_length: Optional[int] = 512 37 | max_completion_length: Optional[int] = 1024 38 | 39 | kl_beta: float = 0.0 40 | # lower_clip_percentile: Optional[float] = 0.02 41 | # upper_clip_percentile: Optional[float] = None 42 | # min_log_prob: Optional[float] = -2.3 43 | 44 | disable_dropout: bool = True 45 | label_pad_token_id: int = -100 46 | padding_value: int = None 47 | 48 | model_init_kwargs: Optional[Dict] = None 49 | 50 | dataset_num_proc: Optional[int] = None 51 | -------------------------------------------------------------------------------- /src/configs/prompts_optimization_comfig.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | from typing import Optional, List 3 | 4 | 5 | @dataclass 6 | class PromptsOptimizationConfig: 7 | num_prompts: Optional[int] = field( 8 | default=3, 9 | metadata={"help": "Number of optimized prompts."}, 10 | ) 11 | prompt_len: Optional[int] = field( 12 | default=128, 13 | metadata={"help": "Prompt length of each prompt"}, 14 | ) 15 | dissim_coef: Optional[float] = field( 16 | default=0.3, 17 | metadata={"help": "Used in aux loss for prompts similarity penalty"}, 18 | ) 19 | special_token_coef: Optional[float] = field( 20 | default=0.8, 21 | metadata={ 22 | "help": "Used in aux loss for penalty of using forbidden (special) tokens" 23 | }, 24 | ) 25 | gumbel_temp: Optional[float] = field( 26 | default=0.5, 27 | metadata={"help": "Temperature for gumbel softmax trick"}, 28 | ) 29 | gumbel_noise_scale: Optional[float] = field( 30 | default=0.05, 31 | metadata={"help": "Multiplier of added gumbel noise inside softmax"}, 32 | ) 33 | forbidden_token_ids: Optional[List[int]] = field( 34 | default=None, 35 | metadata={"help": "List of ids of forbidden tokens in created prompts"}, 36 | ) 37 | inserted_chat_role: str = field( 38 | default="system", 39 | metadata={"help": "Chat role used for templating of created prompts insertion"}, 40 | ) 41 | fused_forward: bool = field( 42 | default=True, 43 | metadata={ 44 | "help": "Use full in-batch forward, instead of for loop, memory usage increase." 45 | }, 46 | ) 47 | init_prompt: Optional[str] = field( 48 | default=None, 49 | metadata={"help": "Prompt to init optimization from"}, 50 | ) 51 | -------------------------------------------------------------------------------- /src/configs/smpo_config.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Dict, Literal, Optional 3 | from transformers import TrainingArguments 4 | 5 | 6 | @dataclass 7 | class SimpleMarginPOConfig(TrainingArguments): 8 | r""" 9 | SimpleMarginPOConfig collects all training arguments related to the [`MarginPOTrainer`] class. 10 | 11 | Using [`HfArgumentParser`] we can turn this class into 12 | [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the 13 | command line. 14 | 15 | Parameters: 16 | max_length (`int`, defaults to `None`): 17 | The maximum length of the sequences in the batch. This argument is required if you want to use the default data collator. 18 | max_prompt_length (`int`, defaults to `None`): 19 | The maximum length of the prompt. This argument is required if you want to use the default data collator. 20 | max_target_length (`int`, defaults to `None`): 21 | The maximum length of the target. 
-------------------------------------------------------------------------------- /src/configs/smpo_config.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Dict, Literal, Optional 3 | from transformers import TrainingArguments 4 | 5 | 6 | @dataclass 7 | class SimpleMarginPOConfig(TrainingArguments): 8 | r""" 9 | SimpleMarginPOConfig collects all training arguments related to the [`MarginPOTrainer`] class. 10 | 11 | Using [`HfArgumentParser`] we can turn this class into 12 | [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the 13 | command line. 14 | 15 | Parameters: 16 | max_length (`int`, defaults to `None`): 17 | The maximum length of the sequences in the batch. This argument is required if you want to use the default data collator. 18 | max_prompt_length (`int`, defaults to `None`): 19 | The maximum length of the prompt. This argument is required if you want to use the default data collator. 20 | max_target_length (`int`, defaults to `None`): 21 | The maximum length of the target. This argument is required if you want to use the default data collator and your model is an encoder-decoder. 22 | beta (`float`, defaults to 1.2): 23 | The beta factor in SimpleMarginPO loss. 24 | target_margin (`float`, defaults to 0.35): 25 | The target reward margin in SimpleMarginPO loss. Can be zero for sigmoid and hinge losses. 26 | chosen_sft_ratio (`float`, defaults to 0.8): 27 | SFT loss balance weight between chosen and rejected, used in the SimpleMarginPO loss (1.0 puts all SFT weight on the chosen completions and none on the rejected ones). 28 | loss_type (`str`, defaults to `smooth_lower_bound`): 29 | The type of loss to use. 30 | use_margin_schedule (`bool`, defaults to `True`): 31 | The margin will gradually increase (linear schedule) from near zero to the target value during training. 32 | lower_clip_percentile (`Optional[float]`, defaults to 0.02): 33 | Lower percentile of token log probs value allowed for PO loss calculation for rejected completions. Works like winsorizing. Recommended range [0.01, 0.05] 34 | min_log_prob (`Optional[float]`, defaults to -2.3): 35 | Lowest possible token log prob value allowed in rejected completions. Will clip all log probs, works after percentile winsorizing. 36 | upper_clip_percentile (`Optional[float]`, defaults to `None`): 37 | Upper percentile of token log probs value allowed for PO loss calculation for chosen completions. Works like winsorizing. Recommended range [0.95, 0.99] 38 | label_pad_token_id (`int`, defaults to `-100`): 39 | The label pad token id. This argument is required if you want to use the default data collator. 40 | padding_value (`int`, defaults to `None`): 41 | The padding value if it is different from the tokenizer's pad_token_id. 42 | truncation_mode (`str`, defaults to `keep_end`): 43 | The truncation mode to use, either `keep_end` or `keep_start`. This argument is required if you want to use the default data collator. 44 | is_encoder_decoder (`Optional[bool]`, *optional*, defaults to `None`): 45 | If no model is provided, we need to know if the model_init returns an encoder-decoder. 46 | disable_dropout (`bool`, defaults to `True`): 47 | Whether or not to disable dropouts in `model`. 48 | model_init_kwargs (`Optional[Dict]`, *optional*): 49 | Dict of optional kwargs to pass when instantiating the model from a string. 50 | dataset_num_proc (`Optional[int]`, *optional*): 51 | The number of workers to use to tokenize the data. Defaults to None. 52 | """ 53 | 54 | max_length: Optional[int] = None 55 | max_prompt_length: Optional[int] = None 56 | max_completion_length: Optional[int] = None 57 | max_target_length: Optional[int] = None 58 | 59 | beta: float = 1.2 60 | target_margin: float = 0.35 61 | chosen_sft_ratio: float = 0.8 62 | lower_clip_percentile: Optional[float] = 0.02 63 | upper_clip_percentile: Optional[float] = None 64 | min_log_prob: Optional[float] = -2.3 65 | loss_type: Literal["sigmoid", "hinge", "ipo", "smooth_lower_bound"] = ( 66 | "smooth_lower_bound" 67 | ) 68 | use_margin_schedule: bool = True 69 | 70 | disable_dropout: bool = True 71 | label_pad_token_id: int = -100 72 | padding_value: Optional[int] = None 73 | truncation_mode: str = "keep_end" 74 | is_encoder_decoder: Optional[bool] = None 75 | 76 | model_init_kwargs: Optional[Dict] = None 77 | 78 | dataset_num_proc: Optional[int] = None 79 |
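A tiny numeric sketch (an editor's assumption based on the use_margin_schedule description above, not the trainer's actual code) of the linear margin schedule: the effective margin ramps from near zero to target_margin over the course of training.

def scheduled_margin(target_margin: float, step: int, total_steps: int, start: float = 0.01) -> float:
    # Linear ramp from `start` (near zero) to `target_margin`, with progress clamped to [0, 1].
    progress = min(max(step / max(total_steps, 1), 0.0), 1.0)
    return start + (target_margin - start) * progress

# With target_margin=0.35 over 1000 steps:
# step 0 -> 0.01, step 500 -> 0.18, step 1000 -> 0.35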
52 | """ 53 | 54 | max_length: Optional[int] = None 55 | max_prompt_length: Optional[int] = None 56 | max_completion_length: Optional[int] = None 57 | max_target_length: Optional[int] = None 58 | 59 | beta: float = 1.2 60 | target_margin: float = 0.35 61 | chosen_sft_ratio: float = 0.8 62 | lower_clip_percentile: Optional[float] = 0.02 63 | upper_clip_percentile: Optional[float] = None 64 | min_log_prob: Optional[float] = -2.3 65 | loss_type: Literal["sigmoid", "hinge", "ipo", "smooth_lower_bound"] = ( 66 | "smooth_lower_bound" 67 | ) 68 | use_margin_schedule: bool = True 69 | 70 | disable_dropout: bool = True 71 | label_pad_token_id: int = -100 72 | padding_value: int = None 73 | truncation_mode: str = "keep_end" 74 | is_encoder_decoder: Optional[bool] = None 75 | 76 | model_init_kwargs: Optional[Dict] = None 77 | 78 | dataset_num_proc: Optional[int] = None 79 | -------------------------------------------------------------------------------- /src/trainers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VikhrModels/effective_llm_alignment/c48ca2905ac75f82dea4ba17a9e6995e88b7007f/src/trainers/__init__.py -------------------------------------------------------------------------------- /src/trainers/prompts_optimization/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VikhrModels/effective_llm_alignment/c48ca2905ac75f82dea4ba17a9e6995e88b7007f/src/trainers/prompts_optimization/__init__.py -------------------------------------------------------------------------------- /src/trainers/prompts_optimization/prompts_reward_trainer.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | from typing import Union, Any, Optional, Dict, Literal 3 | 4 | import pandas as pd 5 | import torch 6 | from torch import nn as nn 7 | from transformers import ( 8 | PreTrainedModel, 9 | PreTrainedTokenizerBase, 10 | Trainer, 11 | ) 12 | from trl import RewardTrainer, RewardConfig 13 | from trl.trainer.utils import log_table_to_comet_experiment, print_rich_table 14 | 15 | from src.callbacks.attr_scheduling import VariableSchedulerCallback 16 | from src.configs.prompts_optimization_comfig import PromptsOptimizationConfig 17 | from src.trainers.prompts_optimization.vq_prompts_tuner_module import ( 18 | PromptCodebookTuner, 19 | ) 20 | 21 | 22 | class PromptsRewardTrainer(RewardTrainer): 23 | def __init__( 24 | self, 25 | model: Union[PreTrainedModel, nn.Module], 26 | args: RewardConfig, 27 | prompt_args: PromptsOptimizationConfig, 28 | tokenizer: PreTrainedTokenizerBase, 29 | **kwargs, 30 | ): 31 | self.prompt_args = prompt_args 32 | 33 | # Wrap the model with PromptCodebookTuner 34 | tuned_model = PromptCodebookTuner( 35 | model=model, 36 | tokenizer=tokenizer, 37 | num_prompts=prompt_args.num_prompts, 38 | prompt_len=prompt_args.prompt_len, 39 | forbidden_token_ids=prompt_args.forbidden_token_ids, 40 | dissim_coef=prompt_args.dissim_coef, 41 | special_token_coef=prompt_args.special_token_coef, 42 | role=prompt_args.inserted_chat_role, 43 | init_prompt=prompt_args.init_prompt, 44 | fused_forward=prompt_args.fused_forward, 45 | gumbel_temp=prompt_args.gumbel_temp, 46 | gumbel_noise_scale=prompt_args.gumbel_noise_scale, 47 | ) 48 | 49 | # Initialize the parent RewardTrainer with the tuned_model and other parameters 50 | super().__init__(model=tuned_model, args=args, tokenizer=tokenizer, 
**kwargs) 51 | 52 | # Initialize stored metrics 53 | self._stored_metrics = defaultdict(lambda: defaultdict(list)) 54 | 55 | self.log_codebook_prompts(no_gumbel=False) 56 | 57 | self.add_callback( 58 | VariableSchedulerCallback( 59 | attribute_name="gumbel_temp", 60 | initial_value=prompt_args.gumbel_temp, 61 | final_value=0.005, 62 | schedule_type="cosine", 63 | ) 64 | ) 65 | 66 | def store_metrics( 67 | self, metrics: Dict[str, float], train_eval: Literal["train", "eval"] = "train" 68 | ) -> None: 69 | """Store metrics for later logging.""" 70 | for key, value in metrics.items(): 71 | self._stored_metrics[train_eval][key].append(value) 72 | 73 | def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> None: 74 | """Log metrics, including stored ones.""" 75 | # Determine if we're in train or eval mode 76 | train_eval = "train" if "loss" in logs else "eval" 77 | 78 | # Add stored metrics to the logs 79 | for key, values in self._stored_metrics[train_eval].items(): 80 | if values: 81 | logs[key] = torch.stack(values).mean().item() 82 | # Clear stored metrics after logging 83 | self._stored_metrics[train_eval].clear() 84 | 85 | # Call the original log method 86 | super().log(logs, start_time) 87 | 88 | def compute_loss( 89 | self, 90 | model: Union[PreTrainedModel, nn.Module], 91 | inputs: dict[str, Union[torch.Tensor, Any]], 92 | return_outputs: bool = False, 93 | **kwargs, 94 | ) -> Union[torch.Tensor, tuple[torch.Tensor, dict[str, torch.Tensor]]]: 95 | # Forward passes and loss calculation as before 96 | outputs_chosen = model( 97 | input_ids=inputs["input_ids_chosen"], 98 | attention_mask=inputs.get("attention_mask_chosen", None), 99 | ) 100 | logits_chosen = outputs_chosen["logits"].squeeze( 101 | -1 102 | ) # (num_prompts, batch_size) 103 | 104 | outputs_rejected = model( 105 | input_ids=inputs["input_ids_rejected"], 106 | attention_mask=inputs.get("attention_mask_rejected", None), 107 | cached_noise=outputs_chosen["cached_noise"], 108 | ) 109 | logits_rejected = outputs_rejected["logits"].squeeze( 110 | -1 111 | ) # (num_prompts, batch_size) 112 | 113 | diff = logits_chosen - logits_rejected 114 | if "margin" in inputs: 115 | margin = inputs["margin"].unsqueeze(0) 116 | diff -= margin 117 | 118 | loss_per_element = -nn.functional.logsigmoid(diff) 119 | loss_per_prompt = loss_per_element.mean(dim=1) 120 | total_pairwise_loss = loss_per_prompt.mean() 121 | 122 | # Center rewards regularization 123 | if self.args.center_rewards_coefficient is not None: 124 | sum_sq = (logits_chosen + logits_rejected) ** 2 125 | center_loss = sum_sq.mean() 126 | total_pairwise_loss += self.args.center_rewards_coefficient * center_loss 127 | 128 | # Combine with auxiliary loss 129 | aux_loss = outputs_chosen.get("aux_loss", 0.0) 130 | total_loss = total_pairwise_loss + aux_loss 131 | 132 | # Calculate accuracy per prompt 133 | accuracy_per_prompt = (diff > 0).float().mean(dim=1) 134 | 135 | # Collect metrics 136 | metrics = {} 137 | for i in range(self.prompt_args.num_prompts): 138 | metrics[f"loss_prompt_{i}"] = loss_per_prompt[i].detach().cpu() 139 | metrics[f"accuracy_prompt_{i}"] = accuracy_per_prompt[i].detach().cpu() 140 | metrics["aux_loss"] = aux_loss.detach().cpu() 141 | metrics["gumbel_temp"] = torch.tensor( 142 | self.accelerator.unwrap_model(model).gumbel_temp 143 | ) 144 | metrics["gumbel_noise_scale"] = ( 145 | self.accelerator.unwrap_model(model).gumbel_noise_scale.data.detach().cpu() 146 | ) 147 | metrics["mean_pairwise_loss"] = total_pairwise_loss.detach().cpu() 148 | 
metrics["mean_accuracy"] = accuracy_per_prompt.mean().detach().cpu() 149 | 150 | # Store metrics based on current phase 151 | train_eval = "eval" if return_outputs else "train" 152 | self.store_metrics(metrics, train_eval=train_eval) 153 | 154 | if return_outputs: 155 | rewards_chosen_avg = logits_chosen[0].unsqueeze( 156 | -1 157 | ) # eval only first prompt 158 | rewards_rejected_avg = logits_rejected[0].unsqueeze(-1) 159 | return total_loss, { 160 | "rewards_chosen": rewards_chosen_avg, 161 | "rewards_rejected": rewards_rejected_avg, 162 | "aux_loss": aux_loss, 163 | } 164 | 165 | return total_loss 166 | 167 | def prediction_step( 168 | self, 169 | model: Union[PreTrainedModel, nn.Module], 170 | inputs: dict[str, Union[torch.Tensor, Any]], 171 | prediction_loss_only: bool, 172 | ignore_keys: Optional[list[str]] = None, 173 | ) -> tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]: 174 | # Ensure aux_loss is ignored in logits processing 175 | if ignore_keys is None: 176 | ignore_keys = [] 177 | ignore_keys.append("aux_loss") 178 | return super().prediction_step(model, inputs, prediction_loss_only, ignore_keys) 179 | 180 | def evaluate(self, *args, **kwargs): 181 | # num_print_samples = kwargs.pop("num_print_samples", 4) 182 | # self.visualize_samples(num_print_samples) 183 | self.log_codebook_prompts() # Log codebook prompts during evaluation 184 | return Trainer.evaluate(self, *args, **kwargs) 185 | 186 | def log_codebook_prompts(self, no_gumbel=True): 187 | """Log current codebook prompts to console and logging services.""" 188 | if self.accelerator.process_index != 0: 189 | return # Only log from main process 190 | 191 | # Retrieve current codebook prompts 192 | prompts_info = self.model.get_codebook_tokens( 193 | return_strings=True, no_gumbel=no_gumbel 194 | ) 195 | prompts = prompts_info["prompts"] 196 | tokens = prompts_info["tokens"] 197 | 198 | # Create DataFrame for logging 199 | df = pd.DataFrame( 200 | {"Prompt Index": range(len(prompts)), "Prompt": prompts, "Tokens": tokens} 201 | ) 202 | 203 | # Print to console 204 | print(f"\nCurrent Codebook Prompts (no_gumbel: {no_gumbel}):") 205 | print_rich_table(df) 206 | 207 | # Log to WandB 208 | if "wandb" in self.args.report_to: 209 | import wandb 210 | 211 | if wandb.run is not None: 212 | wandb.log({"codebook_prompts": wandb.Table(dataframe=df)}) 213 | 214 | # Log to Comet.ml 215 | if "comet_ml" in self.args.report_to: 216 | log_table_to_comet_experiment( 217 | name="codebook_prompts.csv", 218 | table=df, 219 | ) 220 | -------------------------------------------------------------------------------- /src/trainers/prompts_optimization/prompts_sft_trainer.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | from typing import Union, Any, Optional, Dict, Literal 3 | 4 | import pandas as pd 5 | import torch 6 | from torch import nn as nn 7 | from transformers import ( 8 | PreTrainedModel, 9 | PreTrainedTokenizerBase, 10 | Trainer, 11 | ) 12 | from trl import RewardTrainer, RewardConfig, SFTTrainer 13 | from trl.trainer.utils import log_table_to_comet_experiment, print_rich_table 14 | 15 | from src.callbacks.attr_scheduling import VariableSchedulerCallback 16 | from src.configs.prompts_optimization_comfig import PromptsOptimizationConfig 17 | from src.trainers.prompts_optimization.vq_prompts_tuner_module import ( 18 | PromptCodebookTuner, 19 | ) 20 | 21 | 22 | class PromptsSFTTrainer(SFTTrainer): 23 | def __init__( 24 | self, 25 | 
model: Union[PreTrainedModel, nn.Module], 26 | args: RewardConfig, 27 | prompt_args: PromptsOptimizationConfig, 28 | tokenizer: PreTrainedTokenizerBase, 29 | **kwargs, 30 | ): 31 | self.prompt_args = prompt_args 32 | 33 | # Wrap the model with PromptCodebookTuner 34 | tuned_model = PromptCodebookTuner( 35 | model=model, 36 | tokenizer=tokenizer, 37 | num_prompts=prompt_args.num_prompts, 38 | prompt_len=prompt_args.prompt_len, 39 | forbidden_token_ids=prompt_args.forbidden_token_ids, 40 | dissim_coef=prompt_args.dissim_coef, 41 | special_token_coef=prompt_args.special_token_coef, 42 | role=prompt_args.inserted_chat_role, 43 | init_prompt=prompt_args.init_prompt, 44 | fused_forward=prompt_args.fused_forward, 45 | gumbel_temp=prompt_args.gumbel_temp, 46 | gumbel_noise_scale=prompt_args.gumbel_noise_scale, 47 | ) 48 | # Initialize the parent SFTTrainer with the tuned_model and other parameters 49 | super().__init__(model=tuned_model, args=args, tokenizer=tokenizer, **kwargs) 50 | 51 | # Initialize stored metrics 52 | self._stored_metrics = defaultdict(lambda: defaultdict(list)) 53 | 54 | self.log_codebook_prompts(no_gumbel=False) 55 | 56 | self.add_callback( 57 | VariableSchedulerCallback( 58 | attribute_name="gumbel_temp", 59 | initial_value=prompt_args.gumbel_temp, 60 | final_value=0.005, 61 | schedule_type="cosine", 62 | ) 63 | ) 64 | 65 | def store_metrics( 66 | self, metrics: Dict[str, float], train_eval: Literal["train", "eval"] = "train" 67 | ) -> None: 68 | """Store metrics for later logging.""" 69 | for key, value in metrics.items(): 70 | self._stored_metrics[train_eval][key].append(value) 71 | 72 | def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> None: 73 | """Log metrics, including stored ones.""" 74 | # Determine if we're in train or eval mode 75 | train_eval = "train" if "loss" in logs else "eval" 76 | 77 | # Add stored metrics to the logs 78 | for key, values in self._stored_metrics[train_eval].items(): 79 | if values: 80 | logs[key] = torch.stack(values).mean().item() 81 | # Clear stored metrics after logging 82 | self._stored_metrics[train_eval].clear() 83 | 84 | # Call the original log method 85 | super().log(logs, start_time) 86 | 87 | def compute_loss( 88 | self, 89 | model: Union[PreTrainedModel, nn.Module], 90 | inputs: dict[str, Union[torch.Tensor, Any]], 91 | return_outputs: bool = False, 92 | **kwargs, 93 | ) -> Union[torch.Tensor, tuple[torch.Tensor, dict[str, torch.Tensor]]]: 94 | # Forward pass over the wrapped model; it returns per-prompt SFT losses 95 | 96 | outputs = model(**inputs) 97 | 98 | loss_per_prompt = outputs["losses"] 99 | total_sft_loss = loss_per_prompt.mean() 100 | 101 | # Combine with auxiliary loss 102 | aux_loss = outputs.get("aux_loss", 0.0) 103 | total_loss = total_sft_loss + aux_loss 104 | 105 | # Collect metrics 106 | metrics = {} 107 | for i in range(self.prompt_args.num_prompts): 108 | metrics[f"loss_prompt_{i}"] = loss_per_prompt[i].detach().cpu() 109 | metrics["aux_loss"] = aux_loss.detach().cpu() 110 | metrics["gumbel_temp"] = torch.tensor( 111 | self.accelerator.unwrap_model(model).gumbel_temp 112 | ) 113 | metrics["gumbel_noise_scale"] = ( 114 | self.accelerator.unwrap_model(model).gumbel_noise_scale.data.detach().cpu() 115 | ) 116 | metrics["mean_sft_loss"] = total_sft_loss.detach().cpu() 117 | 118 | # Store metrics based on current phase 119 | train_eval = "eval" if return_outputs else "train" 120 | self.store_metrics(metrics, train_eval=train_eval) 121 | 122 | if return_outputs: 123 | return total_loss, { 124 | "logits": outputs["logits"].mean(dim=0), 125 | "aux_loss": aux_loss, 126 | } 127 | 128 | return total_loss 129 | 130 | def prediction_step( 131 | self, 132 | model: Union[PreTrainedModel, nn.Module], 133 | inputs: dict[str, Union[torch.Tensor, Any]], 134 | prediction_loss_only: bool, 135 | ignore_keys: Optional[list[str]] = None, 136 | ) -> tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]: 137 | # Ensure aux_loss is ignored in logits processing 138 | if ignore_keys is None: 139 | ignore_keys = [] 140 | ignore_keys.append("aux_loss") 141 | return super().prediction_step(model, inputs, prediction_loss_only, ignore_keys) 142 | 143 | def evaluate(self, *args, **kwargs): 144 | # num_print_samples = kwargs.pop("num_print_samples", 4) 145 | # self.visualize_samples(num_print_samples) 146 | self.log_codebook_prompts() # Log codebook prompts during evaluation 147 | return Trainer.evaluate(self, *args, **kwargs) 148 | 149 | def log_codebook_prompts(self, no_gumbel=True): 150 | """Log current codebook prompts to console and logging services.""" 151 | if self.accelerator.process_index != 0: 152 | return # Only log from main process 153 | 154 | # Retrieve current codebook prompts 155 | prompts_info = self.model.get_codebook_tokens( 156 | return_strings=True, no_gumbel=no_gumbel 157 | ) 158 | prompts = prompts_info["prompts"] 159 | tokens = prompts_info["tokens"] 160 | 161 | # Create DataFrame for logging 162 | df = pd.DataFrame( 163 | {"Prompt Index": range(len(prompts)), "Prompt": prompts, "Tokens": tokens} 164 | ) 165 | 166 | # Print to console 167 | print(f"\nCurrent Codebook Prompts (no_gumbel: {no_gumbel}):") 168 | print_rich_table(df) 169 | 170 | # Log to WandB 171 | if "wandb" in self.args.report_to: 172 | import wandb 173 | 174 | if wandb.run is not None: 175 | wandb.log({"codebook_prompts": wandb.Table(dataframe=df)}) 176 | 177 | # Log to Comet.ml 178 | if "comet_ml" in self.args.report_to: 179 | log_table_to_comet_experiment( 180 | name="codebook_prompts.csv", 181 | table=df, 182 | ) 183 |
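A minimal wiring sketch (an editor's assumption, not a file from this repository) of how PromptsSFTTrainer is meant to be driven: the raw causal LM is passed in and gets wrapped by PromptCodebookTuner internally. The model name, the SFTConfig fields, and the tiny inline dataset are placeholders, and SFTConfig is assumed here even though the signature above annotates args as RewardConfig.

from datasets import Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import SFTConfig

from src.configs.prompts_optimization_comfig import PromptsOptimizationConfig
from src.trainers.prompts_optimization.prompts_sft_trainer import PromptsSFTTrainer

model_name = "Qwen/Qwen2.5-1.5B-Instruct"  # placeholder base model
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Toy chat dataset; real runs use the configs under training_configs/prompts-tuning.
train_dataset = Dataset.from_list(
    [{"messages": [{"role": "user", "content": "Hi"}, {"role": "assistant", "content": "Hello!"}]}]
)

trainer = PromptsSFTTrainer(
    model=model,
    args=SFTConfig(output_dir="checkpoints/prompts-sft-demo", report_to="none"),
    prompt_args=PromptsOptimizationConfig(num_prompts=3, prompt_len=64),
    tokenizer=tokenizer,
    train_dataset=train_dataset,
)
trainer.train()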
outputs["logits"].mean(dim=0), 125 | "aux_loss": aux_loss, 126 | } 127 | 128 | return total_loss 129 | 130 | def prediction_step( 131 | self, 132 | model: Union[PreTrainedModel, nn.Module], 133 | inputs: dict[str, Union[torch.Tensor, Any]], 134 | prediction_loss_only: bool, 135 | ignore_keys: Optional[list[str]] = None, 136 | ) -> tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]: 137 | # Ensure aux_loss is ignored in logits processing 138 | if ignore_keys is None: 139 | ignore_keys = [] 140 | ignore_keys.append("aux_loss") 141 | return super().prediction_step(model, inputs, prediction_loss_only, ignore_keys) 142 | 143 | def evaluate(self, *args, **kwargs): 144 | # num_print_samples = kwargs.pop("num_print_samples", 4) 145 | # self.visualize_samples(num_print_samples) 146 | self.log_codebook_prompts() # Log codebook prompts during evaluation 147 | return Trainer.evaluate(self, *args, **kwargs) 148 | 149 | def log_codebook_prompts(self, no_gumbel=True): 150 | """Log current codebook prompts to console and logging services.""" 151 | if self.accelerator.process_index != 0: 152 | return # Only log from main process 153 | 154 | # Retrieve current codebook prompts 155 | prompts_info = self.model.get_codebook_tokens( 156 | return_strings=True, no_gumbel=no_gumbel 157 | ) 158 | prompts = prompts_info["prompts"] 159 | tokens = prompts_info["tokens"] 160 | 161 | # Create DataFrame for logging 162 | df = pd.DataFrame( 163 | {"Prompt Index": range(len(prompts)), "Prompt": prompts, "Tokens": tokens} 164 | ) 165 | 166 | # Print to console 167 | print(f"\nCurrent Codebook Prompts (no_gumbel: {no_gumbel}):") 168 | print_rich_table(df) 169 | 170 | # Log to WandB 171 | if "wandb" in self.args.report_to: 172 | import wandb 173 | 174 | if wandb.run is not None: 175 | wandb.log({"codebook_prompts": wandb.Table(dataframe=df)}) 176 | 177 | # Log to Comet.ml 178 | if "comet_ml" in self.args.report_to: 179 | log_table_to_comet_experiment( 180 | name="codebook_prompts.csv", 181 | table=df, 182 | ) 183 | -------------------------------------------------------------------------------- /src/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VikhrModels/effective_llm_alignment/c48ca2905ac75f82dea4ba17a9e6995e88b7007f/src/utils/__init__.py -------------------------------------------------------------------------------- /src/utils/array_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def filter_indices(a, b): 5 | """ 6 | Фильтрует массив b, оставляя только те элементы, которые следуют сразу после элементов массива a. 7 | """ 8 | filtered_b = [] 9 | a_len = len(a) 10 | b_len = len(b) 11 | 12 | # Указатель для массива b 13 | j = 0 14 | 15 | for i in range(a_len): 16 | # Ищем индекс элемента из a в b 17 | while j < b_len and b[j] <= a[i]: 18 | j += 1 19 | # Если следующий элемент в b существует, добавляем его в отфильтрованный массив 20 | if j < b_len: 21 | filtered_b.append(b[j]) 22 | j += 1 # Переходим к следующему элементу в b 23 | 24 | return filtered_b 25 | 26 | 27 | def find_occurrences_v3(arr, subarr): 28 | """ 29 | Находит все позиции вхождения подмассива subarr в массиве arr. 30 | 31 | Параметры: 32 | arr (numpy.ndarray): Исходный массив, в котором выполняется поиск. 33 | subarr (numpy.ndarray): Подмассив, который нужно найти в arr. 34 | 35 | Возвращает: 36 | list: Список индексов, где начинается вхождение subarr в arr. 
37 | """ 38 | 39 | # Длина подмассива 40 | m = len(subarr) 41 | 42 | # Выполняем свертку исходного массива с подмассивом 43 | conv_result = np.convolve(arr, subarr[::-1], mode="valid") 44 | 45 | # Вычисляем сумму элементов подмассива 46 | subarr_sum = np.sum(subarr) 47 | 48 | # Выполняем свертку исходного массива с вектором из единиц такой же длины, как подмассив 49 | window_sum = np.convolve(arr, np.ones(m), mode="valid") 50 | 51 | # Индексы, где произведение совпадает 52 | positions = np.where( 53 | (conv_result == subarr_sum * window_sum) & (window_sum == subarr_sum) 54 | )[0] 55 | 56 | return positions.tolist() 57 | -------------------------------------------------------------------------------- /src/utils/datasets.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import warnings 3 | from accelerate import PartialState 4 | from datasets import load_dataset, DatasetDict, Dataset, concatenate_datasets 5 | 6 | 7 | def _load_dataset_from_path(path: str, test_size: float | None) -> DatasetDict: 8 | if path.endswith("jsonl"): 9 | dataset = load_dataset("json", data_files=path) 10 | else: 11 | dataset = load_dataset(path) 12 | if test_size is not None: 13 | dataset = dataset["train"].train_test_split( 14 | test_size, seed=42, load_from_cache_file=True 15 | ) 16 | return dataset 17 | 18 | 19 | def _get_subset_from_dataset(dataset: Dataset, dataset_ratio: float | None) -> Dataset: 20 | indices = np.random.choice( 21 | range(len(dataset)), int(dataset_ratio * len(dataset)), replace=False 22 | ) 23 | dataset = dataset.select(indices) 24 | return dataset 25 | 26 | 27 | def _get_subset_from_dataset_dict( 28 | dataset: DatasetDict, dataset_ratio: float | None 29 | ) -> DatasetDict: 30 | dataset["train"] = _get_subset_from_dataset(dataset["train"], dataset_ratio) 31 | dataset["test"] = _get_subset_from_dataset(dataset["test"], dataset_ratio) 32 | return dataset 33 | 34 | 35 | def load_datasets( 36 | path: str | list, test_size: float | None, dataset_ratio: float | list | None 37 | ): 38 | with PartialState().local_main_process_first(): 39 | if dataset_ratio is None: 40 | warnings.warn( 41 | "You haven't set dataset ratio for your datasets. Assuming that it's 1 for all datasets." 42 | ) 43 | dataset_ratio = [1] * len(path) if isinstance(path, list) else 1 44 | if isinstance(path, list) and not isinstance(dataset_ratio, list): 45 | raise ValueError("You shold pass dataset ratio for all of your datasets.") 46 | if not isinstance(path, list) and isinstance(dataset_ratio, list): 47 | raise ValueError("You shold pass datasets for all of your dataset ratios.") 48 | if ( 49 | isinstance(path, list) 50 | and isinstance(dataset_ratio, list) 51 | and len(path) != len(dataset_ratio) 52 | ): 53 | raise ValueError( 54 | f"You have set {len(path)} datasets and {len(dataset_ratio)} dataset ratios, but it should be equal." 
55 | ) 56 | if isinstance(path, list): 57 | all_datasets = [_load_dataset_from_path(d, test_size) for d in path] 58 | truncated_datasets = [ 59 | _get_subset_from_dataset_dict(d, ratio) 60 | for d, ratio in zip(all_datasets, dataset_ratio) 61 | ] 62 | ds = DatasetDict() 63 | ds["train"] = concatenate_datasets([d["train"] for d in truncated_datasets]) 64 | ds["test"] = concatenate_datasets([d["test"] for d in truncated_datasets]) 65 | else: 66 | ds = _load_dataset_from_path(path, test_size) 67 | ds = _get_subset_from_dataset_dict(ds, dataset_ratio) 68 | return ds 69 | 70 | 71 | def prepare_generative_row(row, tokenizer, max_length): 72 | constructed_prompt = tokenizer.apply_chat_template( 73 | row["prompt"], tokenize=False, add_generation_prompt=True 74 | ) 75 | return tokenizer( 76 | constructed_prompt, 77 | truncation=True, 78 | padding=True, 79 | max_length=max_length, 80 | add_special_tokens=False, 81 | ) 82 | -------------------------------------------------------------------------------- /src/utils/logger.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import sys 3 | 4 | import datasets 5 | import transformers 6 | from transformers import TrainingArguments 7 | 8 | 9 | def setup_logging(logger, training_args: TrainingArguments): 10 | if training_args.should_log: 11 | # The default of training_args.log_level is passive, so we set log level at info here to have that default. 12 | transformers.logging.set_verbosity_info() 13 | else: 14 | transformers.logging.set_verbosity_error() 15 | log_level = training_args.get_process_log_level() 16 | logger.setLevel(log_level) 17 | -------------------------------------------------------------------------------- /src/utils/model_preparation.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | 3 | from src.configs.additional.common_script_args import CommonScriptArguments 4 | 5 | try: 6 | import deepspeed 7 | except Exception as e: 8 | print(e) 9 | 10 | import torch 11 | from accelerate import Accelerator 12 | from torch import nn 13 | from transformers import PreTrainedModel, PreTrainedTokenizer 14 | 15 | 16 | def setup_model_and_tokenizer( 17 | args: CommonScriptArguments, 18 | model: PreTrainedModel, 19 | tokenizer: PreTrainedTokenizer, 20 | max_seq_len: int = None, 21 | ): 22 | if max_seq_len is not None: 23 | tokenizer.model_max_length = max_seq_len 24 | if tokenizer.eos_token != args.eos_token: 25 | tokenizer.eos_token = args.eos_token 26 | model.config.eos_token_id = tokenizer.eos_token_id 27 | if model.generation_config: 28 | model.generation_config.eos_token_id = tokenizer.eos_token_id 29 | if ( 30 | tokenizer.bos_token is None or args.bos_token is not None 31 | ) and tokenizer.bos_token != args.bos_token: 32 | tokenizer.bos_token = args.bos_token 33 | model.config.bos_token_id = tokenizer.bos_token_id 34 | if model.generation_config: 35 | model.generation_config.bos_token_id = tokenizer.bos_token_id 36 | if tokenizer.pad_token != args.pad_token: 37 | tokenizer.pad_token = args.pad_token 38 | model.config.pad_token_id = tokenizer.pad_token_id 39 | if model.generation_config: 40 | model.generation_config.pad_token_id = tokenizer.pad_token_id 41 | if tokenizer.chat_template is None or ( 42 | args.chat_template is not None and args.force_chat_template 43 | ): 44 | tokenizer.chat_template = args.chat_template 45 | if args.added_special_tokens is not None: 46 | tokenizer.add_special_tokens( 47 | {"additional_special_tokens": 
args.added_special_tokens} 48 | ) 49 | model.resize_token_embeddings(len(tokenizer)) 50 | 51 | 52 | def unfreeze_modules_by_patterns(model, patterns): 53 | """ 54 | Freezes all model parameters, then unfreezes the modules whose 55 | full name matches at least one pattern from the list. 56 | 57 | Arguments: 58 | model: torch.nn.Module – the model (e.g., a Qwen2ForSequenceClassification instance). 59 | patterns: list of strings – patterns for module names (the wildcards * and ? are supported). 60 | 61 | Pattern examples: 62 | ["*.mlp.up_proj", "score", "model.layers.0.self_attn.*"] 63 | """ 64 | import fnmatch 65 | 66 | # First, freeze all parameters 67 | for param in model.parameters(): 68 | param.requires_grad = False 69 | 70 | # Walk over all model modules. model.named_modules() returns (module_name, module) pairs. 71 | for module_name, module in model.named_modules(): 72 | # For each module, check whether its name matches any of the patterns 73 | for pattern in patterns: 74 | if fnmatch.fnmatch(module_name, pattern): 75 | # A match was found – unfreeze all parameters of this module 76 | for param in module.parameters(): 77 | param.requires_grad = True 78 | # The module is unfrozen already – move on to the next one 79 | break 80 | 81 | 82 | def prepare_ref_model_for_deepspeed( 83 | model: PreTrainedModel | nn.Module, accelerator: Accelerator 84 | ) -> PreTrainedModel | nn.Module: 85 | deepspeed_plugin = accelerator.state.deepspeed_plugin 86 | config_kwargs = deepcopy(deepspeed_plugin.deepspeed_config) 87 | if model is not None: 88 | if hasattr(model, "config"): 89 | hidden_size: int | None = ( # type: ignore 90 | max(model.config.hidden_sizes) # type: ignore 91 | if getattr(model.config, "hidden_sizes", None) # type: ignore 92 | else getattr(model.config, "hidden_size", None) # type: ignore 93 | ) 94 | 95 | if ( 96 | hidden_size is not None 97 | and config_kwargs["zero_optimization"]["stage"] == 3 98 | ): 99 | config_kwargs.update( 100 | { 101 | "zero_optimization.reduce_bucket_size": hidden_size 102 | * hidden_size, 103 | "zero_optimization.stage3_param_persistence_threshold": 10 104 | * hidden_size, 105 | "zero_optimization.stage3_prefetch_bucket_size": 0.9 106 | * hidden_size 107 | * hidden_size, 108 | } 109 | ) 110 | 111 | if config_kwargs["zero_optimization"]["stage"] != 3: 112 | config_kwargs["zero_optimization"]["stage"] = 0 113 | # if not "offload_optimizer" in config_kwargs['zero_optimization']: 114 | # config_kwargs['zero_optimization']['offload_optimizer'] = { 115 | # "device": "cpu", 116 | # "pin_memory": True 117 | # } 118 | if "offload_param" in config_kwargs["zero_optimization"]: 119 | del config_kwargs["zero_optimization"]["offload_param"] 120 | 121 | config_kwargs["optimizer"] = {"type": None} 122 | 123 | model, *_ = deepspeed.initialize(model=model, config=config_kwargs) 124 | model.eval() 125 | return model 126 | 127 | 128 | def peft_module_casting_to_bf16(model): 129 | from peft.tuners.tuners_utils import BaseTunerLayer 130 | 131 | for name, module in model.named_modules(): 132 | if isinstance(module, BaseTunerLayer): 133 | module = module.to(torch.bfloat16) 134 | elif isinstance(module, torch.nn.LayerNorm) or "norm" in name: 135 | module = module.to(torch.float32) 136 | elif any(x in name for x in ["lm_head", "embed_tokens", "wte", "wpe"]): 137 | if hasattr(module, "weight"): 138 | if module.weight.dtype == torch.float32: 139 | module = module.to(torch.bfloat16) 140 |
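A usage sketch (editor's assumption, not repository code) for unfreeze_modules_by_patterns, mirroring the unfreeze_layers_patterns option in the classification YAML shown below: freeze everything, then selectively make the classification head and the transformer layers trainable.

from transformers import AutoModelForSequenceClassification

from src.utils.model_preparation import unfreeze_modules_by_patterns

# Model name and patterns are illustrative; compare unfreeze_layers_patterns
# in training_configs/classification/controllable-clf-qwen-1.5b-no-tags-full.yaml.
model = AutoModelForSequenceClassification.from_pretrained(
    "Qwen/Qwen2.5-1.5B-Instruct", num_labels=2
)
unfreeze_modules_by_patterns(model, ["score", "*.layers.*"])
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
total = sum(p.numel() for p in model.parameters())
print(f"trainable: {trainable}/{total}")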
-------------------------------------------------------------------------------- /src/utils/yaml_args_parser.py: -------------------------------------------------------------------------------- 1 | import dataclasses 2 | import os 3 | import sys 4 | from dataclasses import dataclass, field 5 | from typing import Any, Dict, List, NewType, Optional, Tuple 6 | 7 | from transformers import MODEL_FOR_CAUSAL_LM_MAPPING, HfArgumentParser 8 | 9 | import trl 10 | 11 | 12 | MODEL_CONFIG_CLASSES = list(MODEL_FOR_CAUSAL_LM_MAPPING.keys()) 13 | MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES) 14 | 15 | 16 | DataClassType = NewType("DataClassType", Any) 17 | 18 | 19 | class H4ArgumentParser(HfArgumentParser): 20 | def parse_yaml_and_args( 21 | self, yaml_arg: str, other_args: Optional[List[str]] = None 22 | ) -> List[dataclass]: 23 | """ 24 | Parse a YAML file and overwrite the default/loaded values with the values provided to the command line. 25 | 26 | Args: 27 | yaml_arg (`str`): 28 | The path to the config file used 29 | other_args (`List[str]`, *optional*): 30 | A list of strings to parse as command line arguments, e.g. ['--arg=val', '--arg2=val2']. 31 | 32 | Returns: 33 | [`List[dataclass]`]: a list of dataclasses with the values from the YAML file and the command line 34 | """ 35 | arg_list = self.parse_yaml_file(os.path.abspath(yaml_arg)) 36 | 37 | outputs = [] 38 | # strip other args list into dict of key-value pairs 39 | other_args = { 40 | arg.split("=")[0].strip("-"): arg.split("=")[1] for arg in other_args 41 | } 42 | used_args = {} 43 | 44 | # overwrite the default/loaded value with the value provided to the command line 45 | # adapted from https://github.com/huggingface/transformers/blob/d0b5002378daabf62769159add3e7d66d3f83c3b/src/transformers/hf_argparser.py#L327 46 | for data_yaml, data_class in zip(arg_list, self.dataclass_types): 47 | keys = {f.name for f in dataclasses.fields(data_yaml) if f.init} 48 | inputs = {k: v for k, v in vars(data_yaml).items() if k in keys} 49 | for arg, val in other_args.items(): 50 | # add only if in keys 51 | 52 | if arg in keys: 53 | base_type = data_yaml.__dataclass_fields__[arg].type 54 | inputs[arg] = val 55 | 56 | # cast type for ints, floats (default to strings) 57 | if base_type in [int, float]: 58 | inputs[arg] = base_type(val) 59 | 60 | if base_type == List[str]: 61 | inputs[arg] = [str(v) for v in val.split(",")] 62 | 63 | # bool of a non-empty string is True, so we manually check for bools 64 | if base_type is bool: 65 | if val in ["true", "True"]: 66 | inputs[arg] = True 67 | else: 68 | inputs[arg] = False 69 | 70 | # add to used-args so we can check if double add 71 | if arg not in used_args: 72 | used_args[arg] = val 73 | else: 74 | raise ValueError( 75 | f"Duplicate argument provided: {arg}, may cause unexpected behavior" 76 | ) 77 | 78 | obj = data_class(**inputs) 79 | outputs.append(obj) 80 | 81 | return outputs 82 | 83 | def parse(self) -> DataClassType | Tuple[DataClassType]: 84 | if len(sys.argv) == 2 and sys.argv[1].endswith(".yaml"): 85 | # If we pass only one argument to the script and it's the path to a YAML file, 86 | # let's parse it to get our arguments. 87 | output = self.parse_yaml_file(os.path.abspath(sys.argv[1])) 88 | # parse command line args and yaml file 89 | elif len(sys.argv) > 2 and sys.argv[1].endswith(".yaml"): 90 | output = self.parse_yaml_and_args( 91 | os.path.abspath(sys.argv[1]), sys.argv[2:] 92 | ) 93 | # parse command line args only 94 | else: 95 | output = self.parse_args_into_dataclasses() 96 | 97 | if len(output) == 1: 98 | output = output[0] 99 | return output 100 |
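A launch sketch (editor's assumption; the exact script/config pairing is illustrative) of the H4ArgumentParser flow above: a training script parses its dataclasses from a YAML, and any extra --key=value flags override the YAML values.

from src.configs.additional.smpo_args import SMPOScriptArguments
from src.configs.smpo_config import SimpleMarginPOConfig
from src.utils.yaml_args_parser import H4ArgumentParser

parser = H4ArgumentParser((SMPOScriptArguments, SimpleMarginPOConfig))
# Invoked as, e.g. (overriding beta from the YAML on the command line):
#   python scripts/model_training/smpo.py training_configs/preference/smpo-phi4-lora-best-rs-v1.yaml --beta=1.0
script_args, smpo_config = parser.parse()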
-------------------------------------------------------------------------------- /training_configs/classification/controllable-clf-qwen-1.5b-no-tags-full.yaml: -------------------------------------------------------------------------------- 1 | model_name_or_path: "Qwen/Qwen2.5-1.5B-Instruct" 2 | dataset: "Vikhrmodels/controllable_clf_dataset_no_tags" 3 | per_device_train_batch_size: 4 4 | per_device_eval_batch_size: 6 5 | num_train_epochs: 1 6 | num_labels: 2 7 | save_strategy: "steps" 8 | evaluation_strategy: "steps" 9 | save_steps: 125 10 | eval_steps: 125 11 | save_only_model: True 12 | save_total_limit: 10 13 | learning_rate: 0.00003 14 | weight_decay: 0.05 15 | lr_scheduler_type: "cosine" 16 | warmup_ratio: 0.05 17 | gradient_accumulation_steps: 16 18 | gradient_checkpointing: True 19 | logging_steps: 1 20 | remove_unused_columns: True 21 | dataloader_num_workers: 2 22 | max_length: 16384 23 | attn_implementation: "sdpa" 24 | torch_compile: False 25 | run_name: "controllable-clf-qwen-1.5b-no-tags-full-no-emb" 26 | output_dir: "checkpoints/controllable-clf-qwen-1.5b-no-tags-full-no-emb" 27 | report_to: "wandb" 28 | bf16: True 29 | fp16: False 30 | seed: 42 31 | logging_first_step: True 32 | unfreeze_layers_patterns: 33 | - "score" 34 | - "*.layers.*" 35 | use_peft: False 36 | lora_task_type: SEQ_CLS 37 | lora_target_modules: 38 | - "k_proj" 39 | - "v_proj" 40 | - "q_proj" 41 | - "o_proj" 42 | - "gate_proj" 43 | - "up_proj" 44 | - "down_proj" 45 | lora_modules_to_save: 46 | - "score" 47 | lora_r: 128 48 | lora_alpha: 128 49 | use_rslora: True 50 | pad_token: "<|image_pad|>" 51 | eos_token: "<|im_end|>" 52 | chat_template: "{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|im_start|>' + message['role'] + '\n' + message['content'] | trim + '<|im_end|>' %}{{ content }}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant<|im_end|>\n' }}{% endif %}" 53 | force_chat_template: True -------------------------------------------------------------------------------- /training_configs/preference/rpo-sigmoid-llama-3-lora-best-rs.yaml: -------------------------------------------------------------------------------- 1 | model_name_or_path: "Vikhrmodels/Vikhr-Llama3.1-8B-Instruct-R-01-09-24" 2 | dataset: "data/vikhr-llama-3.1-rm4-scored-answers-rs7-nw2-t0.8-pref.jsonl" 3 | per_device_train_batch_size: 1 4 | per_device_eval_batch_size: 1 5 | num_train_epochs: 2 6 | log_level: "info" 7 | attn_implementation: "sdpa" 8 | save_strategy: "steps" 9 | save_steps: 50 10 | save_total_limit: 20 11 | learning_rate: 0.000005 12 | lr_scheduler_type: cosine 13 | loss_type: "sigmoid" 14 | gradient_accumulation_steps: 8 15 | gradient_checkpointing: True 16 | logging_steps: 1 17 | remove_unused_columns: True 18 | dataloader_num_workers: 2 19 | dataset_num_proc: 10 20 | max_length: 8192 21 | max_prompt_length: 4096 22 | save_only_model: True 23 | generate_eval_examples: False 24 | test_size: 0.05 25 | evaluation_strategy: "steps" 26 | eval_steps: 50 27 | run_name: "rpo-sigmoid-top-rs-llama-3.1-01-09-24-lora-64-qkvogudlm-b0.1-rpa-1.0"
28 | output_dir: "models/rpo-sigmoid-top-rs-llama-3.1-01-09-24-lora-64-qkvogudlm-b0.1-rpa-1.0" 29 | warmup_steps: 10 30 | report_to: "wandb" 31 | beta: 0.1 32 | rpo_alpha: 1.0 33 | reference_free: True 34 | bf16: False 35 | fp16: True 36 | seed: 42 37 | logging_first_step: True 38 | use_peft: True 39 | lora_task_type: CAUSAL_LM 40 | lora_target_modules: 41 | - "k_proj" 42 | - "v_proj" 43 | - "q_proj" 44 | - "o_proj" 45 | - "gate_proj" 46 | - "up_proj" 47 | - "down_proj" 48 | - "lm_head" 49 | lora_r: 64 50 | lora_alpha: 64 51 | pad_token: "<|reserved_special_token_0|>" 52 | eos_token: "<|eot_id|>" 53 | chat_template: "{{ bos_token }}{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{{ content }}{% endfor %}{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}{% endif %}" 54 | force_chat_template: False -------------------------------------------------------------------------------- /training_configs/preference/simpo-llama-3.1-lora-best-rs.yaml: -------------------------------------------------------------------------------- 1 | model_name_or_path: "Vikhrmodels/Vikhr-Llama3.1-8B-Instruct-R-01-09-24" 2 | dataset: "data/vikhr-llama-3.1-rm4-scored-answers-rs7-nw2-t0.8-pref.jsonl" 3 | per_device_train_batch_size: 1 4 | per_device_eval_batch_size: 1 5 | num_train_epochs: 3 6 | log_level: "info" 7 | attn_implementation: "sdpa" 8 | save_strategy: "steps" 9 | save_steps: 50 10 | save_total_limit: 20 11 | learning_rate: 0.000008 12 | lr_scheduler_type: cosine 13 | gradient_accumulation_steps: 8 14 | gradient_checkpointing: True 15 | logging_steps: 1 16 | remove_unused_columns: True 17 | dataloader_num_workers: 2 18 | dataset_num_proc: 10 19 | max_length: 8192 20 | max_prompt_length: 4096 21 | save_only_model: True 22 | generate_eval_examples: True 23 | test_size: 0.05 24 | evaluation_strategy: "steps" 25 | eval_steps: 50 26 | run_name: "simpo-smooth_ipo-top-rs-llama-3.1-01-09-24-lora-64-qkvogudlm-b0.6-gbr0.8-sft-1.0" 27 | output_dir: "models/simpo-smooth_ipo-top-rs-llama-3.1-01-09-24-lora-64-qkvogudlm-b0.6-gbr0.8-sft-1.0" 28 | warmup_steps: 10 29 | report_to: "wandb" 30 | beta: 0.6 31 | gamma_beta_ratio: 0.8 32 | sft_weight: 1.0 33 | loss_type: "smooth_ipo" 34 | bf16: False 35 | fp16: True 36 | seed: 42 37 | logging_first_step: True 38 | use_peft: True 39 | lora_task_type: CAUSAL_LM 40 | lora_target_modules: 41 | - "k_proj" 42 | - "v_proj" 43 | - "q_proj" 44 | - "o_proj" 45 | - "gate_proj" 46 | - "up_proj" 47 | - "down_proj" 48 | - "lm_head" 49 | lora_r: 64 50 | lora_alpha: 64 51 | pad_token: "<|reserved_special_token_0|>" 52 | eos_token: "<|eot_id|>" 53 | chat_template: "{{ bos_token }}{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{{ content }}{% endfor %}{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}{% endif %}" 54 | force_chat_template: False -------------------------------------------------------------------------------- /training_configs/preference/slic-llama-3-lora-best-rs.yaml: -------------------------------------------------------------------------------- 1 | model_name_or_path: "Vikhrmodels/Vikhr-Llama3.1-8B-Instruct-R-01-09-24" 2 | dataset: "data/vikhr-llama-3.1-rm4-scored-answers-rs7-nw2-t0.8-pref.jsonl" 3 | 
per_device_train_batch_size: 1 4 | per_device_eval_batch_size: 1 5 | num_train_epochs: 2 6 | log_level: "info" 7 | attn_implementation: "sdpa" 8 | save_strategy: "steps" 9 | save_steps: 50 10 | save_total_limit: 20 11 | learning_rate: 0.000005 12 | lr_scheduler_type: cosine 13 | loss_type: "hinge" 14 | gradient_accumulation_steps: 8 15 | gradient_checkpointing: True 16 | logging_steps: 1 17 | remove_unused_columns: True 18 | dataloader_num_workers: 2 19 | dataset_num_proc: 10 20 | max_length: 8192 21 | max_prompt_length: 4096 22 | save_only_model: True 23 | generate_eval_examples: True 24 | test_size: 0.05 25 | evaluation_strategy: "steps" 26 | eval_steps: 50 27 | run_name: "slic-top-rs-llama-3.1-01-09-24-lora-64-qkvogudlm-b0.1" 28 | output_dir: "models/slic-top-rs-llama-3.1-01-09-24-lora-64-qkvogudlm-b0.1" 29 | warmup_steps: 10 30 | report_to: "wandb" 31 | beta: 0.1 32 | reference_free: True 33 | bf16: False 34 | fp16: True 35 | seed: 42 36 | logging_first_step: True 37 | use_peft: True 38 | lora_task_type: CAUSAL_LM 39 | lora_target_modules: 40 | - "k_proj" 41 | - "v_proj" 42 | - "q_proj" 43 | - "o_proj" 44 | - "gate_proj" 45 | - "up_proj" 46 | - "down_proj" 47 | - "lm_head" 48 | lora_r: 64 49 | lora_alpha: 64 50 | pad_token: "<|reserved_special_token_0|>" 51 | eos_token: "<|eot_id|>" 52 | chat_template: "{{ bos_token }}{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{{ content }}{% endfor %}{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}{% endif %}" 53 | force_chat_template: False -------------------------------------------------------------------------------- /training_configs/preference/smpo-llama-3.1-lora-best-rs.yaml: -------------------------------------------------------------------------------- 1 | model_name_or_path: "Vikhrmodels/Vikhr-Llama3.1-8B-Instruct-R-01-09-24" 2 | dataset: "data/vikhr-llama-3.1-rm4-scored-answers-rs7-nw2-t0.8-pref-fixed.jsonl" 3 | per_device_train_batch_size: 1 4 | per_device_eval_batch_size: 1 5 | num_train_epochs: 2 6 | log_level: "info" 7 | attn_implementation: "sdpa" 8 | save_strategy: "steps" 9 | save_steps: 50 10 | save_total_limit: 20 11 | learning_rate: 0.00002 12 | lr_scheduler_type: cosine 13 | gradient_accumulation_steps: 8 14 | gradient_checkpointing: True 15 | logging_steps: 1 16 | remove_unused_columns: True 17 | dataloader_num_workers: 2 18 | dataset_num_proc: 10 19 | max_length: 8192 20 | max_prompt_length: 4096 21 | save_only_model: True 22 | generate_eval_examples: True 23 | test_size: 0.05 24 | evaluation_strategy: "steps" 25 | eval_steps: 50 26 | run_name: "smpo-smooth_lower_bound-top-rs-llama-3.1-01-09-24-lora-96-qkvogud-b1.1-mm0.4-sft0.75-fixed" 27 | output_dir: "models/smpo-smooth_lower_bound-top-rs-llama-3.1-01-09-24-lora-96-qkvogud-b1.1-mm0.4-sft0.75-fixed" 28 | warmup_steps: 10 29 | report_to: "wandb" 30 | beta: 1.1 31 | margin_min: 0.4 32 | margin_delta: 0.2 33 | chosen_sft_ratio: 0.75 34 | loss_type: "smooth_lower_bound" 35 | bf16: False 36 | fp16: True 37 | seed: 42 38 | logging_first_step: True 39 | use_peft: True 40 | lora_task_type: CAUSAL_LM 41 | lora_target_modules: 42 | - "k_proj" 43 | - "v_proj" 44 | - "q_proj" 45 | - "o_proj" 46 | - "gate_proj" 47 | - "up_proj" 48 | - "down_proj" 49 | lora_r: 96 50 | lora_alpha: 96 51 | pad_token: "<|reserved_special_token_0|>" 52 | eos_token: "<|eot_id|>" 53 | chat_template: "{{ bos_token }}{% set 
loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{{ content }}{% endfor %}{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}{% endif %}" 54 | force_chat_template: False -------------------------------------------------------------------------------- /training_configs/preference/smpo-mistral-nemo-lora-best-rs.yaml: -------------------------------------------------------------------------------- 1 | model_name_or_path: "Vikhrmodels/Vikhr-Nemo-12B-Instruct-R-05-09-24" 2 | dataset: "data/vikhr-nemo-rm4-scored-answers-rs7-nw2-t0.8-pref.jsonl" 3 | per_device_train_batch_size: 1 4 | per_device_eval_batch_size: 1 5 | num_train_epochs: 2 6 | log_level: "info" 7 | attn_implementation: "sdpa" 8 | save_strategy: "steps" 9 | save_steps: 50 10 | save_total_limit: 20 11 | learning_rate: 0.00002 12 | lr_scheduler_type: cosine 13 | gradient_accumulation_steps: 8 14 | gradient_checkpointing: True 15 | logging_steps: 1 16 | remove_unused_columns: True 17 | dataloader_num_workers: 2 18 | dataset_num_proc: 10 19 | max_length: 8192 20 | max_prompt_length: 4096 21 | save_only_model: True 22 | generate_eval_examples: True 23 | test_size: 0.05 24 | evaluation_strategy: "steps" 25 | eval_steps: 50 26 | run_name: "smpo-smooth_lower_bound-top-rs-nemo-05-09-24-lora-96-qkvogud-b1.2-mm0.4-sft0.75" 27 | output_dir: "models/smpo-smooth_lower_bound-top-rs-nemo-05-09-24-lora-96-qkvogud-b1.2-mm0.4-sft0.75" 28 | warmup_steps: 20 29 | report_to: "wandb" 30 | beta: 1.2 31 | margin_min: 0.4 32 | margin_delta: 0.2 33 | chosen_sft_ratio: 0.75 34 | loss_type: "smooth_lower_bound" 35 | bf16: False 36 | fp16: True 37 | seed: 42 38 | logging_first_step: True 39 | use_peft: True 40 | lora_task_type: CAUSAL_LM 41 | lora_target_modules: 42 | - "k_proj" 43 | - "v_proj" 44 | - "q_proj" 45 | - "o_proj" 46 | - "gate_proj" 47 | - "up_proj" 48 | - "down_proj" 49 | lora_r: 96 50 | lora_alpha: 96 51 | pad_token: "" 52 | eos_token: "" 53 | chat_template: "{{ bos_token }}{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '' %}{{ content }}{% endfor %}{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}{% endif %}" 54 | force_chat_template: False -------------------------------------------------------------------------------- /training_configs/preference/smpo-phi4-lora-best-rs-v1.yaml: -------------------------------------------------------------------------------- 1 | model_name_or_path: "Vikhrmodels/Phikhr-14B-Instruct-R-19-12-24-SFT" 2 | dataset: "Vikhrmodels/phi4-sft-po-dataset" 3 | per_device_train_batch_size: 1 4 | per_device_eval_batch_size: 1 5 | num_train_epochs: 2 6 | log_level: "info" 7 | attn_implementation: "sdpa" 8 | save_strategy: "steps" 9 | save_steps: 60 10 | save_total_limit: 10 11 | learning_rate: 0.00002 12 | lr_scheduler_type: cosine 13 | gradient_accumulation_steps: 8 14 | gradient_checkpointing: True 15 | logging_steps: 1 16 | remove_unused_columns: True 17 | dataloader_num_workers: 2 18 | dataset_num_proc: 20 19 | max_length: 10240 20 | max_prompt_length: 5120 21 | save_only_model: True 22 | generate_eval_examples: True 23 | generate_during_eval: False 24 | test_size: 0.05 25 | evaluation_strategy: "steps" 26 | eval_steps: 60 27 | run_name: 
"smpo-hinge-phikr-14b-19-12-24-lora-196-b0.7-mm0.6-sft0.8" 28 | output_dir: "checkpoints/smpo-hinge-phikr-14b-19-12-24-lora-196-b0.7-mm0.6-sft0.8" 29 | warmup_steps: 10 30 | report_to: "wandb" 31 | beta: 0.7 32 | margin_min: 0.6 33 | margin_delta: 0.2 34 | chosen_sft_ratio: 0.8 35 | loss_type: "hinge" 36 | bf16: True 37 | seed: 42 38 | logging_first_step: True 39 | use_peft: True 40 | lora_task_type: CAUSAL_LM 41 | lora_target_modules: 42 | - "qkv_proj" 43 | - "o_proj" 44 | - "gate_up_proj" 45 | - "down_proj" 46 | - "lm_head" 47 | lora_r: 256 48 | lora_alpha: 256 49 | pad_token: "<|dummy_0|>" 50 | eos_token: "<|im_end|>" 51 | chat_template: "{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|im_start|>' + message['role'] + '<|im_sep|>'+ message['content'] | trim + '<|im_end|>' %}{{ content }}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant<|im_sep|>' }}{% endif %}" 52 | force_chat_template: False -------------------------------------------------------------------------------- /training_configs/preference/smpo-phi4-lora-best-rs-v12.yaml: -------------------------------------------------------------------------------- 1 | model_name_or_path: "Vikhrmodels/Phikhr-14B-Instruct-R-19-12-24-SFT" 2 | dataset: "hivaze/phi4-sft-po-dataset" 3 | per_device_train_batch_size: 1 4 | per_device_eval_batch_size: 2 5 | num_train_epochs: 2 6 | log_level: "info" 7 | attn_implementation: "sdpa" 8 | save_strategy: "steps" 9 | save_steps: 50 10 | save_total_limit: 12 11 | learning_rate: 0.00002 12 | lr_scheduler_type: cosine 13 | gradient_accumulation_steps: 16 14 | gradient_checkpointing: True 15 | logging_steps: 1 16 | remove_unused_columns: True 17 | dataloader_num_workers: 2 18 | dataset_num_proc: 20 19 | max_length: 10240 20 | max_prompt_length: 5120 21 | save_only_model: True 22 | generate_eval_examples: True 23 | generate_during_eval: False 24 | test_size: 0.05 25 | evaluation_strategy: "steps" 26 | eval_steps: 50 27 | run_name: "smpo-smooth_lower_bound-phikr-14b-19-12-24-lora-196-b1.1-mm0.3-sft1.0-ltp0.02-mlp2.3-2" 28 | output_dir: "checkpoints/smpo-smooth_lower_bound-phikr-14b-19-12-24-lora-196-b1.1-mm0.3-sft1.0-ltp0.02-mlp2.3-2" 29 | warmup_steps: 10 30 | report_to: "wandb" 31 | beta: 1.1 32 | margin_min: 0.3 33 | margin_delta: 0.2 34 | chosen_sft_ratio: 1.0 35 | loss_type: "smooth_lower_bound" 36 | lower_clip_percentile: 0.02 37 | min_log_prob: -2.3 38 | bf16: True 39 | seed: 42 40 | logging_first_step: True 41 | use_peft: True 42 | lora_task_type: CAUSAL_LM 43 | lora_target_modules: 44 | - "qkv_proj" 45 | - "o_proj" 46 | - "gate_up_proj" 47 | - "down_proj" 48 | - "lm_head" 49 | lora_r: 256 50 | lora_alpha: 256 51 | pad_token: "<|dummy_0|>" 52 | eos_token: "<|im_end|>" 53 | chat_template: "{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|im_start|>' + message['role'] + '<|im_sep|>'+ message['content'] | trim + '<|im_end|>' %}{{ content }}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant<|im_sep|>' }}{% endif %}" 54 | force_chat_template: False -------------------------------------------------------------------------------- /training_configs/preference/smpo-phi4-lora-best-rs-v14.yaml: -------------------------------------------------------------------------------- 1 | model_name_or_path: "Vikhrmodels/Phikhr-14B-Instruct-R-19-12-24-SFT" 2 | dataset: "hivaze/phi4-sft-po-dataset" 3 | per_device_train_batch_size: 1 4 | per_device_eval_batch_size: 1 5 | num_train_epochs: 1 6 | log_level: "info" 7 
| attn_implementation: "sdpa" 8 | save_strategy: "steps" 9 | save_steps: 50 10 | save_total_limit: 12 11 | learning_rate: 0.00002 12 | lr_scheduler_type: cosine 13 | gradient_accumulation_steps: 16 14 | gradient_checkpointing: True 15 | logging_steps: 1 16 | remove_unused_columns: True 17 | dataloader_num_workers: 2 18 | dataset_num_proc: 20 19 | max_length: 10240 20 | max_prompt_length: 5120 21 | save_only_model: True 22 | generate_eval_examples: True 23 | generate_during_eval: False 24 | test_size: 0.05 25 | evaluation_strategy: "steps" 26 | eval_steps: 50 27 | run_name: "smpo-smooth_lower_bound-phikr-14b-19-12-24-lora-196-b1.1-mm0.5-sft1.0-ltp0.02-mlp2.3" 28 | output_dir: "checkpoints/smpo-smooth_lower_bound-phikr-14b-19-12-24-lora-196-b1.1-mm0.5-sft1.0-ltp0.02-mlp2.3" 29 | warmup_steps: 10 30 | report_to: "wandb" 31 | beta: 1.1 32 | margin_min: 0.5 33 | margin_delta: 0.2 34 | chosen_sft_ratio: 1.0 35 | loss_type: "smooth_lower_bound" 36 | lower_clip_percentile: 0.02 37 | min_log_prob: -2.3 38 | bf16: True 39 | seed: 42 40 | logging_first_step: True 41 | use_peft: True 42 | lora_task_type: CAUSAL_LM 43 | lora_target_modules: 44 | - "qkv_proj" 45 | - "o_proj" 46 | - "gate_up_proj" 47 | - "down_proj" 48 | lora_r: 256 49 | lora_alpha: 256 50 | pad_token: "<|dummy_0|>" 51 | eos_token: "<|im_end|>" 52 | chat_template: "{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|im_start|>' + message['role'] + '<|im_sep|>'+ message['content'] | trim + '<|im_end|>' %}{{ content }}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant<|im_sep|>' }}{% endif %}" 53 | force_chat_template: False -------------------------------------------------------------------------------- /training_configs/preference/smpo-phi4-lora-best-rs-v2.yaml: -------------------------------------------------------------------------------- 1 | model_name_or_path: "Vikhrmodels/Phikhr-14B-Instruct-R-19-12-24-SFT" 2 | dataset: "Vikhrmodels/phi4-sft-po-dataset" 3 | per_device_train_batch_size: 1 4 | per_device_eval_batch_size: 1 5 | num_train_epochs: 2 6 | log_level: "info" 7 | attn_implementation: "sdpa" 8 | save_strategy: "steps" 9 | save_steps: 60 10 | save_total_limit: 10 11 | learning_rate: 0.00002 12 | lr_scheduler_type: cosine 13 | gradient_accumulation_steps: 8 14 | gradient_checkpointing: True 15 | logging_steps: 1 16 | remove_unused_columns: True 17 | dataloader_num_workers: 2 18 | dataset_num_proc: 20 19 | max_length: 10240 20 | max_prompt_length: 5120 21 | save_only_model: True 22 | generate_eval_examples: True 23 | generate_during_eval: False 24 | test_size: 0.05 25 | evaluation_strategy: "steps" 26 | eval_steps: 60 27 | run_name: "smpo-smooth_lower_bound-phikr-14b-19-12-24-lora-196-b1.25-mm0.5-sft0.8" 28 | output_dir: "checkpoints/smpo-smooth_lower_bound-phikr-14b-19-12-24-lora-196-b1.25-mm0.5-sft0.8" 29 | warmup_steps: 10 30 | report_to: "wandb" 31 | beta: 1.25 32 | margin_min: 0.5 33 | margin_delta: 0.2 34 | chosen_sft_ratio: 0.8 35 | loss_type: "smooth_lower_bound" 36 | bf16: True 37 | seed: 42 38 | logging_first_step: True 39 | use_peft: True 40 | lora_task_type: CAUSAL_LM 41 | lora_target_modules: 42 | - "qkv_proj" 43 | - "o_proj" 44 | - "gate_up_proj" 45 | - "down_proj" 46 | - "lm_head" 47 | lora_r: 256 48 | lora_alpha: 256 49 | pad_token: "<|dummy_0|>" 50 | eos_token: "<|im_end|>" 51 | chat_template: "{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|im_start|>' + message['role'] + '<|im_sep|>'+ message['content'] | trim + 
'<|im_end|>' %}{{ content }}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant<|im_sep|>' }}{% endif %}" 52 | force_chat_template: False -------------------------------------------------------------------------------- /training_configs/preference/smpo-phi4-lora-best-rs-v3.yaml: -------------------------------------------------------------------------------- 1 | model_name_or_path: "Vikhrmodels/Phikhr-14B-Instruct-R-19-12-24-SFT" 2 | dataset: "hivaze/phi4-sft-po-dataset" 3 | per_device_train_batch_size: 1 4 | per_device_eval_batch_size: 1 5 | num_train_epochs: 2 6 | log_level: "info" 7 | attn_implementation: "sdpa" 8 | save_strategy: "steps" 9 | save_steps: 60 10 | save_total_limit: 10 11 | learning_rate: 0.00002 12 | lr_scheduler_type: cosine 13 | gradient_accumulation_steps: 8 14 | gradient_checkpointing: True 15 | logging_steps: 1 16 | remove_unused_columns: True 17 | dataloader_num_workers: 2 18 | dataset_num_proc: 20 19 | max_length: 10240 20 | max_prompt_length: 5120 21 | save_only_model: True 22 | generate_eval_examples: True 23 | generate_during_eval: False 24 | test_size: 0.05 25 | evaluation_strategy: "steps" 26 | eval_steps: 60 27 | run_name: "smpo-smooth_lower_bound-phikr-14b-19-12-24-lora-196-b1.2-mm0.35-sft0.85" 28 | output_dir: "checkpoints/smpo-smooth_lower_bound-phikr-14b-19-12-24-lora-196-b1.2-mm0.35-sft0.85" 29 | warmup_steps: 10 30 | report_to: "wandb" 31 | beta: 1.2 32 | margin_min: 0.35 33 | margin_delta: 0.2 34 | chosen_sft_ratio: 0.85 35 | loss_type: "smooth_lower_bound" 36 | bf16: True 37 | seed: 42 38 | logging_first_step: True 39 | use_peft: True 40 | lora_task_type: CAUSAL_LM 41 | lora_target_modules: 42 | - "qkv_proj" 43 | - "o_proj" 44 | - "gate_up_proj" 45 | - "down_proj" 46 | - "lm_head" 47 | lora_r: 256 48 | lora_alpha: 256 49 | pad_token: "<|dummy_0|>" 50 | eos_token: "<|im_end|>" 51 | chat_template: "{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|im_start|>' + message['role'] + '<|im_sep|>'+ message['content'] | trim + '<|im_end|>' %}{{ content }}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant<|im_sep|>' }}{% endif %}" 52 | force_chat_template: False -------------------------------------------------------------------------------- /training_configs/preference/smpo-phi4-lora-best-rs-v4.yaml: -------------------------------------------------------------------------------- 1 | model_name_or_path: "Vikhrmodels/Phikhr-14B-Instruct-R-19-12-24-SFT" 2 | dataset: "hivaze/phi4-sft-po-dataset-v2" 3 | per_device_train_batch_size: 1 4 | per_device_eval_batch_size: 1 5 | num_train_epochs: 2 6 | log_level: "info" 7 | attn_implementation: "sdpa" 8 | save_strategy: "steps" 9 | save_steps: 50 10 | save_total_limit: 10 11 | learning_rate: 0.00002 12 | lr_scheduler_type: cosine 13 | gradient_accumulation_steps: 8 14 | gradient_checkpointing: True 15 | logging_steps: 1 16 | remove_unused_columns: True 17 | dataloader_num_workers: 2 18 | dataset_num_proc: 20 19 | max_length: 10240 20 | max_prompt_length: 5120 21 | save_only_model: True 22 | generate_eval_examples: True 23 | generate_during_eval: False 24 | test_size: 0.05 25 | evaluation_strategy: "steps" 26 | eval_steps: 50 27 | run_name: "smpo-smooth_lower_bound-phikr-14b-19-12-24-lora-196-b1.3-mm0.4-sft0.8" 28 | output_dir: "checkpoints/smpo-smooth_lower_bound-phikr-14b-19-12-24-lora-196-b1.3-mm0.4-sft0.8" 29 | warmup_steps: 10 30 | report_to: "wandb" 31 | beta: 1.3 32 | max_grad_norm: 0.5 33 | margin_min: 0.4 34 | margin_delta: 0.2 35 | 
chosen_sft_ratio: 0.8 36 | loss_type: "smooth_lower_bound" 37 | bf16: True 38 | seed: 42 39 | logging_first_step: True 40 | use_peft: True 41 | lora_task_type: CAUSAL_LM 42 | lora_target_modules: 43 | - "qkv_proj" 44 | - "o_proj" 45 | - "gate_up_proj" 46 | - "down_proj" 47 | - "lm_head" 48 | lora_r: 256 49 | lora_alpha: 256 50 | pad_token: "<|dummy_0|>" 51 | eos_token: "<|im_end|>" 52 | chat_template: "{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|im_start|>' + message['role'] + '<|im_sep|>'+ message['content'] | trim + '<|im_end|>' %}{{ content }}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant<|im_sep|>' }}{% endif %}" 53 | force_chat_template: False -------------------------------------------------------------------------------- /training_configs/preference/smpo-phi4-lora-best-rs-v7.yaml: -------------------------------------------------------------------------------- 1 | model_name_or_path: "Vikhrmodels/Phikhr-14B-Instruct-R-19-12-24-SFT" 2 | dataset: "hivaze/phi4-sft-po-dataset" 3 | per_device_train_batch_size: 1 4 | per_device_eval_batch_size: 1 5 | num_train_epochs: 2 6 | log_level: "info" 7 | attn_implementation: "sdpa" 8 | save_strategy: "steps" 9 | save_steps: 60 10 | save_total_limit: 10 11 | learning_rate: 0.00002 12 | lr_scheduler_type: cosine 13 | gradient_accumulation_steps: 8 14 | gradient_checkpointing: True 15 | logging_steps: 1 16 | remove_unused_columns: True 17 | dataloader_num_workers: 2 18 | dataset_num_proc: 20 19 | max_length: 10240 20 | max_prompt_length: 5120 21 | save_only_model: True 22 | generate_eval_examples: True 23 | generate_during_eval: False 24 | test_size: 0.05 25 | evaluation_strategy: "steps" 26 | eval_steps: 60 27 | run_name: "smpo-smooth_lower_bound-phikr-14b-19-12-24-lora-196-b1.0-mm0.4-sft0.75-1e" 28 | output_dir: "checkpoints/smpo-smooth_lower_bound-phikr-14b-19-12-24-lora-196-b1.0-mm0.4-sft0.75-1e" 29 | warmup_steps: 10 30 | report_to: "wandb" 31 | beta: 1.0 32 | margin_min: 0.4 33 | margin_delta: 0.2 34 | chosen_sft_ratio: 0.75 35 | loss_type: "smooth_lower_bound" 36 | bf16: True 37 | seed: 42 38 | logging_first_step: True 39 | use_peft: True 40 | lora_task_type: CAUSAL_LM 41 | lora_target_modules: 42 | - "qkv_proj" 43 | - "o_proj" 44 | - "gate_up_proj" 45 | - "down_proj" 46 | lora_r: 256 47 | lora_alpha: 256 48 | pad_token: "<|dummy_0|>" 49 | eos_token: "<|im_end|>" 50 | chat_template: "{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|im_start|>' + message['role'] + '<|im_sep|>'+ message['content'] | trim + '<|im_end|>' %}{{ content }}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant<|im_sep|>' }}{% endif %}" 51 | force_chat_template: False -------------------------------------------------------------------------------- /training_configs/preference/smpo-phi4-lora-best-rs-v9.yaml: -------------------------------------------------------------------------------- 1 | model_name_or_path: "Vikhrmodels/Phikhr-14B-Instruct-R-19-12-24-SFT" 2 | dataset: "hivaze/phi4-sft-po-dataset" 3 | per_device_train_batch_size: 1 4 | per_device_eval_batch_size: 1 5 | num_train_epochs: 2 6 | log_level: "info" 7 | attn_implementation: "sdpa" 8 | save_strategy: "steps" 9 | save_steps: 50 10 | save_total_limit: 12 11 | learning_rate: 0.00002 12 | lr_scheduler_type: cosine 13 | gradient_accumulation_steps: 8 14 | gradient_checkpointing: True 15 | logging_steps: 1 16 | remove_unused_columns: True 17 | dataloader_num_workers: 2 18 | dataset_num_proc: 20 19 
| max_length: 10240 20 | max_prompt_length: 5120 21 | save_only_model: True 22 | generate_eval_examples: True 23 | generate_during_eval: False 24 | test_size: 0.05 25 | evaluation_strategy: "steps" 26 | eval_steps: 50 27 | run_name: "smpo-smooth_lower_bound-phikr-14b-19-12-24-lora-196-b1.2-mm0.45-sft0.85-ltp0.025" 28 | output_dir: "checkpoints/smpo-smooth_lower_bound-phikr-14b-19-12-24-lora-196-b1.2-mm0.45-sft0.85-ltp0.025" 29 | warmup_steps: 10 30 | report_to: "wandb" 31 | beta: 1.2 32 | margin_min: 0.45 33 | margin_delta: 0.2 34 | chosen_sft_ratio: 0.85 35 | loss_type: "smooth_lower_bound" 36 | lower_trim_percentile: 0.025 37 | bf16: True 38 | seed: 42 39 | logging_first_step: True 40 | use_peft: True 41 | lora_task_type: CAUSAL_LM 42 | lora_target_modules: 43 | - "qkv_proj" 44 | - "o_proj" 45 | - "gate_up_proj" 46 | - "down_proj" 47 | lora_r: 256 48 | lora_alpha: 256 49 | pad_token: "<|dummy_0|>" 50 | eos_token: "<|im_end|>" 51 | chat_template: "{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|im_start|>' + message['role'] + '<|im_sep|>'+ message['content'] | trim + '<|im_end|>' %}{{ content }}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant<|im_sep|>' }}{% endif %}" 52 | force_chat_template: False -------------------------------------------------------------------------------- /training_configs/preference/smpo-qvikhr2.5-1.5b-lora-best-rs.yaml: -------------------------------------------------------------------------------- 1 | model_name_or_path: "Vikhrmodels/Vikhr-Qwen-2.5-1.5B-Instruct" 2 | dataset: "filtered_dataset.jsonl" 3 | per_device_train_batch_size: 1 4 | per_device_eval_batch_size: 1 5 | num_train_epochs: 1.2 6 | log_level: "info" 7 | attn_implementation: "sdpa" 8 | save_strategy: "steps" 9 | save_steps: 50 10 | save_total_limit: 20 11 | learning_rate: 0.00002 12 | lr_scheduler_type: cosine 13 | gradient_accumulation_steps: 4 14 | gradient_checkpointing: True 15 | logging_steps: 1 16 | remove_unused_columns: True 17 | dataloader_num_workers: 2 18 | dataset_num_proc: 10 19 | max_length: 8192 20 | max_prompt_length: 4096 21 | save_only_model: True 22 | generate_eval_examples: False 23 | test_size: 0.05 24 | evaluation_strategy: "steps" 25 | eval_steps: 50 26 | run_name: "smpo-qvikhr2.5-1.5b-lora-128-b1.6-mm0.5-sft0.90-new-epochs-1.2" 27 | output_dir: "smpo-qvikhr2.5-1.5b-lora-128-b1.6-mm0.5-sft0.90-new-epochs-1.2" 28 | warmup_steps: 20 29 | report_to: "wandb" 30 | beta: 1.6 31 | margin_min: 0.5 32 | margin_delta: 0.2 33 | chosen_sft_ratio: 0.90 34 | loss_type: "smooth_lower_bound" 35 | bf16: False 36 | fp16: True 37 | seed: 42 38 | logging_first_step: True 39 | use_peft: True 40 | lora_task_type: CAUSAL_LM 41 | lora_target_modules: 42 | - "k_proj" 43 | - "v_proj" 44 | - "q_proj" 45 | - "o_proj" 46 | - "gate_proj" 47 | - "up_proj" 48 | - "down_proj" 49 | lora_r: 128 50 | lora_alpha: 128 51 | pad_token: "<|endoftext|>" 52 | eos_token: "<|im_end|>" 53 | chat_template: "{%- if tools %}\n {{- '<|im_start|>system\\n' }}\n {%- if messages[0]['role'] == 'system' %}\n {{- messages[0]['content'] }}\n {%- else %}\n {{- 'You are Qwen, created by Alibaba Cloud. You are a helpful assistant.' 
}}\n {%- endif %}\n {{- \"\\n\\n# Tools\\n\\nYou may call one or more functions to assist with the user query.\\n\\nYou are provided with function signatures within <tools></tools> XML tags:\\n<tools>\" }}\n {%- for tool in tools %}\n {{- \"\\n\" }}\n {{- tool | tojson }}\n {%- endfor %}\n {{- \"\\n</tools>\\n\\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\\n<tool_call>\\n{\\\"name\\\": <function-name>, \\\"arguments\\\": <args-json-object>}\\n</tool_call><|im_end|>\\n\" }}\n{%- else %}\n {%- if messages[0]['role'] == 'system' %}\n {{- '<|im_start|>system\\n' + messages[0]['content'] + '<|im_end|>\\n' }}\n {%- else %}\n {{- '<|im_start|>system\\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\\n' }}\n {%- endif %}\n{%- endif %}\n{%- for message in messages %}\n {%- if (message.role == \"user\") or (message.role == \"system\" and not loop.first) or (message.role == \"assistant\" and not message.tool_calls) %}\n {{- '<|im_start|>' + message.role + '\\n' + message.content + '<|im_end|>' + '\\n' }}\n {%- elif message.role == \"assistant\" %}\n {{- '<|im_start|>' + message.role }}\n {%- if message.content %}\n {{- '\\n' + message.content }}\n {%- endif %}\n {%- for tool_call in message.tool_calls %}\n {%- if tool_call.function is defined %}\n {%- set tool_call = tool_call.function %}\n {%- endif %}\n {{- '\\n<tool_call>\\n{\"name\": \"' }}\n {{- tool_call.name }}\n {{- '\", \"arguments\": ' }}\n {{- tool_call.arguments | tojson }}\n {{- '}\\n</tool_call>' }}\n {%- endfor %}\n {{- '<|im_end|>\\n' }}\n {%- elif message.role == \"tool\" %}\n {%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != \"tool\") %}\n {{- '<|im_start|>user' }}\n {%- endif %}\n {{- '\\n<tool_response>\\n' }}\n {{- message.content }}\n {{- '\\n</tool_response>' }}\n {%- if loop.last or (messages[loop.index0 + 1].role != \"tool\") %}\n {{- '<|im_end|>\\n' }}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|im_start|>assistant\\n' }}\n{%- endif %}\n" 54 | force_chat_template: False 55 | -------------------------------------------------------------------------------- /training_configs/preference/sppo-llama-3-lora-best-rs.yaml: -------------------------------------------------------------------------------- 1 | model_name_or_path: "Vikhrmodels/Vikhr-Llama3.1-8B-Instruct-R-01-09-24" 2 | dataset: "data/vikhr-llama-3.1-rm4-scored-answers-rs7-nw2-t0.8-pref.jsonl" 3 | per_device_train_batch_size: 1 4 | per_device_eval_batch_size: 1 5 | num_train_epochs: 2 6 | log_level: "info" 7 | attn_implementation: "sdpa" 8 | save_strategy: "steps" 9 | save_steps: 50 10 | save_total_limit: 20 11 | learning_rate: 0.000005 12 | lr_scheduler_type: cosine 13 | loss_type: "sppo_hard" 14 | gradient_accumulation_steps: 8 15 | gradient_checkpointing: True 16 | logging_steps: 1 17 | remove_unused_columns: True 18 | dataloader_num_workers: 2 19 | dataset_num_proc: 10 20 | max_length: 8192 21 | max_prompt_length: 4096 22 | save_only_model: True 23 | generate_eval_examples: True 24 | test_size: 0.05 25 | evaluation_strategy: "steps" 26 | eval_steps: 50 27 | run_name: "sppo-top-rs-llama-3.1-01-09-24-lora-64-qkvogudlm-b0.1" 28 | output_dir: "models/sppo-top-rs-llama-3.1-01-09-24-lora-64-qkvogudlm-b0.1" 29 | warmup_steps: 10 30 | report_to: "wandb" 31 | beta: 0.1 32 | reference_free: True 33 | bf16: False 34 | fp16: True 35 | seed: 42 36 | logging_first_step: True 37 | use_peft: True 38 | lora_task_type: CAUSAL_LM 39 | lora_target_modules: 40 | - "k_proj" 41 | - "v_proj" 42 | - "q_proj" 43 | - "o_proj" 44 | - "gate_proj" 45 | - "up_proj" 46 | - "down_proj" 47 | - "lm_head"
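# (Illustrative note added to this excerpt; not part of the original config.)
# With standard LoRA the adapter update is scaled by lora_alpha / lora_r, so the
# 64 / 64 pair below applies updates at a scaling factor of 1.0. Since "lm_head"
# is listed among the target modules above, the output projection also receives
# a low-rank delta and has to be merged along with the attention/MLP adapters
# (cf. scripts/post_training/merge_peft_adapters.py).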
48 | lora_r: 64 49 | lora_alpha: 64 50 | pad_token: "<|reserved_special_token_0|>" 51 | eos_token: "<|eot_id|>" 52 | chat_template: "{{ bos_token }}{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{{ content }}{% endfor %}{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}{% endif %}" 53 | force_chat_template: False -------------------------------------------------------------------------------- /training_configs/prompts-tuning/controllable-rm-qwen-1.5b-no-init.yaml: -------------------------------------------------------------------------------- 1 | model_name_or_path: "Vikhrmodels/ControllableRM-Qwen2.5-1.5B" 2 | dataset: "Vikhrmodels/rm_training_dataset_16_12_24" 3 | dataset_ratio: 0.03 4 | test_size: 0.1 5 | per_device_train_batch_size: 8 6 | per_device_eval_batch_size: 16 7 | num_train_epochs: 5 8 | save_strategy: "epoch" 9 | evaluation_strategy: "epoch" 10 | num_prompts: 1 11 | dissim_coef: 0.0 12 | gumbel_temp: 1.0 13 | gumbel_noise_scale: 0.1 14 | prompt_len: 96 15 | fused_forward: True 16 | save_steps: 360 17 | eval_steps: 60 18 | save_only_model: True 19 | save_total_limit: 10 20 | learning_rate: 0.001 21 | weight_decay: 0.01 22 | lr_scheduler_type: "constant" 23 | warmup_ratio: 0.01 24 | center_rewards_coefficient: 0.01 25 | gradient_accumulation_steps: 8 26 | gradient_checkpointing: True 27 | logging_steps: 1 28 | remove_unused_columns: False 29 | dataloader_num_workers: 2 30 | max_length: 16384 31 | attn_implementation: "sdpa" 32 | torch_compile: False 33 | run_name: "controllable-rm-qwen-1.5b-prompts-optimization-np1-dc0.0-pl96-pi-gumbel-dynamic-no-init" 34 | output_dir: "checkpoints/controllable-rm-qwen-1.5b-prompts-optimization-np1-dc0.0-pl96-pi-gumbel-dynamic-no-init" 35 | report_to: "wandb" 36 | bf16: True 37 | fp16: False 38 | seed: 42 39 | logging_first_step: True 40 | pad_token: "<|endoftext|>" 41 | eos_token: "<|im_end|>" 42 | chat_template: "{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|im_start|>' + message['role'] + '\n' + message['content'] | trim + '<|im_end|>' %}{{ content }}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant<|im_end|>\n' }}{% endif %}" 43 | force_chat_template: False -------------------------------------------------------------------------------- /training_configs/prompts-tuning/controllable-rm-qwen-1.5b.yaml: -------------------------------------------------------------------------------- 1 | model_name_or_path: "Vikhrmodels/ControllableRM-Qwen2.5-1.5B" 2 | dataset: "Vikhrmodels/rm_training_dataset_16_12_24" 3 | dataset_ratio: 0.03 4 | test_size: 0.1 5 | per_device_train_batch_size: 8 6 | per_device_eval_batch_size: 16 7 | num_train_epochs: 5 8 | save_strategy: "epoch" 9 | evaluation_strategy: "epoch" 10 | num_prompts: 1 11 | dissim_coef: 0.0 12 | gumbel_temp: 1.0 13 | gumbel_noise_scale: 0.05 14 | prompt_len: 96 15 | fused_forward: True 16 | init_prompt: "The answer must be good, safe, high-quality, and relevant to the user's question. The answer should provide well-researched, factual information while avoiding harmful, biased, or misleading content. Additionally, it should prioritize user safety by avoiding inappropriate or dangerous advice and maintaining ethical standards." 
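# (Illustrative note added to this excerpt; not part of the original config.)
# This run tunes a discrete prompt of prompt_len = 96 tokens against the reward
# model: token choices are relaxed with Gumbel-softmax noise (gumbel_temp and
# gumbel_noise_scale above), and init_prompt above seeds the search. The
# companion controllable-rm-qwen-1.5b-no-init.yaml differs mainly in omitting
# init_prompt and raising gumbel_noise_scale from 0.05 to 0.1.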
17 | save_steps: 360 18 | eval_steps: 60 19 | save_only_model: True 20 | save_total_limit: 10 21 | learning_rate: 0.001 22 | weight_decay: 0.01 23 | lr_scheduler_type: "constant" 24 | warmup_ratio: 0.01 25 | center_rewards_coefficient: 0.01 26 | gradient_accumulation_steps: 8 27 | gradient_checkpointing: True 28 | logging_steps: 1 29 | remove_unused_columns: False 30 | dataloader_num_workers: 2 31 | max_length: 16384 32 | attn_implementation: "sdpa" 33 | torch_compile: False 34 | run_name: "controllable-rm-qwen-1.5b-prompts-optimization-np1-dc0.0-pl96-pi-gumbel-dynamic" 35 | output_dir: "checkpoints/controllable-rm-qwen-1.5b-prompts-optimization-np1-dc0.0-pl96-pi-gumbel-dynamic" 36 | report_to: "wandb" 37 | bf16: True 38 | fp16: False 39 | seed: 42 40 | logging_first_step: True 41 | pad_token: "<|endoftext|>" 42 | eos_token: "<|im_end|>" 43 | chat_template: "{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|im_start|>' + message['role'] + '\n' + message['content'] | trim + '<|im_end|>' %}{{ content }}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant<|im_end|>\n' }}{% endif %}" 44 | force_chat_template: False -------------------------------------------------------------------------------- /training_configs/reward/controllable-rm-qwen-1.5b-no-tags-full.yaml: -------------------------------------------------------------------------------- 1 | model_name_or_path: "Qwen/Qwen2.5-1.5B-Instruct" 2 | dataset: "Vikhrmodels/controllable_rm_dataset_no_tags" 3 | per_device_train_batch_size: 2 4 | per_device_eval_batch_size: 6 5 | num_train_epochs: 1 6 | save_strategy: "steps" 7 | evaluation_strategy: "steps" 8 | save_steps: 125 9 | eval_steps: 125 10 | save_only_model: True 11 | save_total_limit: 10 12 | learning_rate: 0.00003 13 | weight_decay: 0.05 14 | lr_scheduler_type: "cosine" 15 | warmup_ratio: 0.05 16 | center_rewards_coefficient: 0.01 17 | gradient_accumulation_steps: 16 18 | gradient_checkpointing: True 19 | logging_steps: 1 20 | remove_unused_columns: True 21 | dataloader_num_workers: 2 22 | max_length: 16384 23 | attn_implementation: "sdpa" 24 | torch_compile: False 25 | run_name: "controllable-rm-qwen-1.5b-no-tags-full-no-emb" 26 | output_dir: "checkpoints/controllable-rm-qwen-1.5b-no-tags-full-no-emb" 27 | report_to: "wandb" 28 | bf16: True 29 | fp16: False 30 | seed: 42 31 | logging_first_step: True 32 | unfreeze_layers_patterns: 33 | - "score" 34 | - "*.layers.*" 35 | use_peft: False 36 | lora_task_type: SEQ_CLS 37 | lora_target_modules: 38 | - "k_proj" 39 | - "v_proj" 40 | - "q_proj" 41 | - "o_proj" 42 | - "gate_proj" 43 | - "up_proj" 44 | - "down_proj" 45 | lora_modules_to_save: 46 | - "score" 47 | lora_r: 128 48 | lora_alpha: 128 49 | use_rslora: True 50 | pad_token: "<|image_pad|>" 51 | eos_token: "<|im_end|>" 52 | chat_template: "{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|im_start|>' + message['role'] + '\n' + message['content'] | trim + '<|im_end|>' %}{{ content }}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant<|im_end|>\n' }}{% endif %}" 53 | force_chat_template: True -------------------------------------------------------------------------------- /training_configs/reward/controllable-rm-qwen-1.5b-no-tags-lora.yaml: -------------------------------------------------------------------------------- 1 | model_name_or_path: "Qwen/Qwen2.5-1.5B-Instruct" 2 | dataset: "Vikhrmodels/controllable_rm_dataset_no_tags" 3 | per_device_train_batch_size: 2 4 | 
per_device_eval_batch_size: 6 5 | num_train_epochs: 1 6 | save_strategy: "steps" 7 | evaluation_strategy: "steps" 8 | save_steps: 125 9 | eval_steps: 125 10 | save_total_limit: 10 11 | learning_rate: 0.0002 12 | weight_decay: 0.00 13 | lr_scheduler_type: "cosine" 14 | warmup_ratio: 0.05 15 | center_rewards_coefficient: 0.01 16 | gradient_accumulation_steps: 16 17 | gradient_checkpointing: True 18 | logging_steps: 1 19 | remove_unused_columns: True 20 | dataloader_num_workers: 2 21 | max_length: 16384 22 | attn_implementation: "sdpa" 23 | torch_compile: False 24 | run_name: "controllable-rm-qwen-1.5b-no-tags-lora-r32-32-all" 25 | output_dir: "checkpoints/controllable-rm-qwen-1.5b-no-tags-lora-r32-32-all" 26 | report_to: "clearml" 27 | bf16: True 28 | fp16: False 29 | seed: 42 30 | logging_first_step: True 31 | use_peft: True 32 | lora_task_type: SEQ_CLS 33 | lora_target_modules: 34 | - "k_proj" 35 | - "v_proj" 36 | - "q_proj" 37 | - "o_proj" 38 | - "gate_proj" 39 | - "up_proj" 40 | - "down_proj" 41 | lora_modules_to_save: 42 | - "score" 43 | lora_r: 32 44 | lora_alpha: 32 45 | use_rslora: False 46 | pad_token: "<|image_pad|>" 47 | eos_token: "<|im_end|>" 48 | chat_template: "{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|im_start|>' + message['role'] + '\n' + message['content'] | trim + '<|im_end|>' %}{{ content }}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant<|im_end|>\n' }}{% endif %}" 49 | force_chat_template: True -------------------------------------------------------------------------------- /training_configs/reward/controllable-rm-qwen-3b-no-tags-full.yaml: -------------------------------------------------------------------------------- 1 | model_name_or_path: "Qwen/Qwen2.5-3B-Instruct" 2 | dataset: "Vikhrmodels/controllable_rm_dataset_no_tags" 3 | per_device_train_batch_size: 2 4 | per_device_eval_batch_size: 6 5 | num_train_epochs: 1 6 | save_strategy: "steps" 7 | evaluation_strategy: "steps" 8 | save_steps: 125 9 | eval_steps: 125 10 | save_total_limit: 10 11 | save_only_model: True 12 | optim: "adamw_8bit" 13 | learning_rate: 0.00003 14 | weight_decay: 0.05 15 | lr_scheduler_type: "cosine" 16 | warmup_ratio: 0.05 17 | center_rewards_coefficient: 0.01 18 | gradient_accumulation_steps: 16 19 | gradient_checkpointing: True 20 | logging_steps: 1 21 | remove_unused_columns: True 22 | dataloader_num_workers: 2 23 | max_length: 16384 24 | attn_implementation: "sdpa" 25 | torch_compile: False 26 | run_name: "controllable-rm-qwen-3b-no-tags-full-no-emb-adamw_8bit" 27 | output_dir: "checkpoints/controllable-rm-qwen-3b-no-tags-full-no-emb-adamw_8bit" 28 | report_to: "wandb" 29 | bf16: True 30 | fp16: False 31 | seed: 42 32 | logging_first_step: True 33 | unfreeze_layers_patterns: 34 | - "score" 35 | - "*.layers.*" 36 | use_peft: False 37 | lora_task_type: SEQ_CLS 38 | lora_target_modules: 39 | - "k_proj" 40 | - "v_proj" 41 | - "q_proj" 42 | - "o_proj" 43 | - "gate_proj" 44 | - "up_proj" 45 | - "down_proj" 46 | lora_modules_to_save: 47 | - "score" 48 | lora_r: 128 49 | lora_alpha: 128 50 | use_rslora: True 51 | pad_token: "<|image_pad|>" 52 | eos_token: "<|im_end|>" 53 | chat_template: "{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|im_start|>' + message['role'] + '\n' + message['content'] | trim + '<|im_end|>' %}{{ content }}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant<|im_end|>\n' }}{% endif %}" 54 | force_chat_template: True 
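A side note on how these flat reward configs could be consumed: the sketch below is illustrative, not the repo's actual loader. RMScriptArgs and its field subset are hypothetical stand-ins for the dataclasses under src/configs/; only PyYAML and transformers' HfArgumentParser are assumed.

import yaml
from dataclasses import dataclass
from transformers import HfArgumentParser

# Hypothetical script-args subset; the repo's real dataclasses declare many
# more keys (tokens, chat_template, unfreeze patterns, ...).
@dataclass
class RMScriptArgs:
    model_name_or_path: str
    dataset: str
    max_length: int = 16384
    use_peft: bool = False
    lora_r: int = 128
    lora_alpha: int = 128

with open("training_configs/reward/controllable-rm-qwen-3b-no-tags-full.yaml") as f:
    raw = yaml.safe_load(f)

# Route only the declared keys here; the remaining keys (learning_rate,
# save_steps, ...) would feed a TrainingArguments-style parser the same way.
known = {k: v for k, v in raw.items() if k in RMScriptArgs.__dataclass_fields__}
(script_args,) = HfArgumentParser(RMScriptArgs).parse_dict(known)
print(script_args.model_name_or_path, script_args.lora_r)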
-------------------------------------------------------------------------------- /training_configs/reward/controllable-rm-qwen-7b-no-tags-full.yaml: -------------------------------------------------------------------------------- 1 | model_name_or_path: "Qwen/Qwen2.5-7B-Instruct" 2 | dataset: "Vikhrmodels/controllable_rm_dataset_no_tags" 3 | per_device_train_batch_size: 1 4 | per_device_eval_batch_size: 2 5 | num_train_epochs: 1 6 | evaluation_strategy: "steps" 7 | eval_steps: 125 8 | save_strategy: "steps" 9 | save_steps: 125 10 | save_total_limit: 10 11 | save_only_model: True 12 | optim: "adamw_8bit" 13 | learning_rate: 0.00003 14 | weight_decay: 0.05 15 | lr_scheduler_type: "cosine" 16 | warmup_ratio: 0.05 17 | center_rewards_coefficient: 0.01 18 | gradient_accumulation_steps: 32 19 | gradient_checkpointing: True 20 | logging_steps: 1 21 | remove_unused_columns: True 22 | dataloader_num_workers: 2 23 | max_length: 16384 24 | attn_implementation: "sdpa" 25 | torch_compile: False 26 | run_name: "controllable-rm-qwen-7b-no-tags-full-no-emb-adamw_8bit" 27 | output_dir: "checkpoints/controllable-rm-qwen-7b-no-tags-full-no-emb-adamw_8bit" 28 | report_to: "wandb" 29 | bf16: True 30 | fp16: False 31 | seed: 42 32 | logging_first_step: True 33 | unfreeze_layers_patterns: 34 | - "score" 35 | - "*.layers.*" 36 | use_peft: False 37 | lora_task_type: SEQ_CLS 38 | lora_target_modules: 39 | - "k_proj" 40 | - "v_proj" 41 | - "q_proj" 42 | - "o_proj" 43 | - "gate_proj" 44 | - "up_proj" 45 | - "down_proj" 46 | lora_modules_to_save: 47 | - "score" 48 | lora_r: 128 49 | lora_alpha: 128 50 | use_rslora: True 51 | pad_token: "<|image_pad|>" 52 | eos_token: "<|im_end|>" 53 | chat_template: "{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|im_start|>' + message['role'] + '\n' + message['content'] | trim + '<|im_end|>' %}{{ content }}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant<|im_end|>\n' }}{% endif %}" 54 | force_chat_template: True -------------------------------------------------------------------------------- /training_configs/reward/controllable-rm-qwen-7b-no-tags-lora.yaml: -------------------------------------------------------------------------------- 1 | model_name_or_path: "Qwen/Qwen2.5-7B-Instruct" 2 | dataset: "Vikhrmodels/controllable_rm_dataset_no_tags" 3 | per_device_train_batch_size: 2 4 | per_device_eval_batch_size: 3 5 | num_train_epochs: 1 6 | save_strategy: "steps" 7 | evaluation_strategy: "steps" 8 | save_steps: 125 9 | eval_steps: 125 10 | save_total_limit: 10 11 | learning_rate: 0.0001 12 | lr_scheduler_type: "cosine" 13 | warmup_ratio: 0.05 14 | center_rewards_coefficient: 0.01 15 | gradient_accumulation_steps: 16 16 | gradient_checkpointing: True 17 | logging_steps: 1 18 | remove_unused_columns: True 19 | dataloader_num_workers: 2 20 | max_length: 16384 21 | attn_implementation: "sdpa" 22 | torch_compile: False 23 | run_name: "controllable-rm-qwen-7b-no-tags-rslora-r256-256-all-2" 24 | output_dir: "checkpoints/controllable-rm-qwen-7b-no-tags-rslora-r256-256-all-2" 25 | report_to: "clearml" 26 | bf16: True 27 | fp16: False 28 | seed: 42 29 | logging_first_step: True 30 | use_peft: True 31 | lora_task_type: SEQ_CLS 32 | lora_target_modules: 33 | - "k_proj" 34 | - "v_proj" 35 | - "q_proj" 36 | - "o_proj" 37 | - "gate_proj" 38 | - "up_proj" 39 | - "down_proj" 40 | lora_modules_to_save: 41 | - "score" 42 | lora_r: 256 43 | lora_alpha: 256 44 | use_rslora: True 45 | pad_token: "<|image_pad|>" 46 | eos_token: "<|im_end|>" 47 | 
chat_template: "{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|im_start|>' + message['role'] + '\n' + message['content'] | trim + '<|im_end|>' %}{{ content }}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant<|im_end|>\n' }}{% endif %}" 48 | force_chat_template: True -------------------------------------------------------------------------------- /training_configs/reward/rm-llama-3-fsfairx-lora-arena.yaml: -------------------------------------------------------------------------------- 1 | model_name_or_path: "sfairXC/FsfairX-LLaMA3-RM-v0.1" 2 | dataset: "Vikhrmodels/ru-arena-general-rankings" 3 | per_device_train_batch_size: 4 4 | per_device_eval_batch_size: 4 5 | num_train_epochs: 3 6 | save_strategy: "steps" 7 | save_steps: 100 8 | save_total_limit: 6 9 | learning_rate: 0.0004 10 | gradient_accumulation_steps: 4 11 | gradient_checkpointing: True 12 | logging_steps: 1 13 | remove_unused_columns: True 14 | dataloader_num_workers: 2 15 | max_length: 8192 16 | center_rewards_coefficient: 0.01 17 | test_size: 0.05 18 | evaluation_strategy: "steps" 19 | eval_steps: 50 20 | run_name: "rm-arena-llama-3-fsfairx-lora-32-qkvougd-rc-0.01" 21 | output_dir: "/mnt/models/rm-arena-llama-3-fsfairx-lora-32-qkvougd-0.01" 22 | warmup_steps: 20 23 | report_to: "wandb" 24 | bf16: True 25 | seed: 42 26 | logging_first_step: True 27 | use_peft: True 28 | lora_task_type: SEQ_CLS 29 | lora_target_modules: 30 | - "k_proj" 31 | - "v_proj" 32 | - "q_proj" 33 | - "o_proj" 34 | - "up_proj" 35 | - "gate_proj" 36 | - "down_proj" 37 | lora_modules_to_save: 38 | - "score" 39 | lora_r: 32 40 | lora_alpha: 32 41 | pad_token: "<|reserved_special_token_0|>" 42 | eos_token: "<|eot_id|>" 43 | chat_template: "{{ bos_token }}{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{{ content }}{% endfor %}{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}{% endif %}" 44 | force_chat_template: True 45 | -------------------------------------------------------------------------------- /training_configs/reward/rm-llama-3.1-8b-it-lora-arena.yaml: -------------------------------------------------------------------------------- 1 | model_name_or_path: "unsloth/Meta-Llama-3.1-8B-Instruct" 2 | dataset: "Vikhrmodels/ru-arena-general-rankings" 3 | per_device_train_batch_size: 1 4 | per_device_eval_batch_size: 1 5 | num_train_epochs: 2 6 | save_strategy: "steps" 7 | save_steps: 100 8 | save_total_limit: 6 9 | learning_rate: 0.0004 10 | gradient_accumulation_steps: 8 11 | gradient_checkpointing: True 12 | logging_steps: 1 13 | remove_unused_columns: True 14 | dataloader_num_workers: 2 15 | max_length: 8192 16 | test_size: 0.05 17 | evaluation_strategy: "steps" 18 | eval_steps: 100 19 | run_name: "rm-arena-llama-3.1-it-unsloth-lora-32-qkvogud" 20 | output_dir: "/home/models/rm-arena-llama-3.1-it-unsloth-lora-32-qkvogud" 21 | warmup_steps: 20 22 | report_to: "wandb" 23 | bf16: True 24 | seed: 42 25 | logging_first_step: True 26 | use_peft: True 27 | lora_task_type: SEQ_CLS 28 | lora_target_modules: 29 | - "k_proj" 30 | - "v_proj" 31 | - "q_proj" 32 | - "o_proj" 33 | - "gate_proj" 34 | - "up_proj" 35 | - "down_proj" 36 | lora_modules_to_save: 37 | - "score" 38 | lora_r: 32 39 | lora_alpha: 32 40 | pad_token: "<|reserved_special_token_0|>" 41 | eos_token: "<|eot_id|>" 42 | chat_template: "{{ bos_token }}{% set 
loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{{ content }}{% endfor %}{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}{% endif %}" 43 | force_chat_template: True -------------------------------------------------------------------------------- /training_configs/reward/rm-llama-3.1-8b-lora-arena.yaml: -------------------------------------------------------------------------------- 1 | model_name_or_path: "unsloth/Meta-Llama-3.1-8B" 2 | dataset: "Vikhrmodels/ru-arena-general-rankings" 3 | per_device_train_batch_size: 1 4 | per_device_eval_batch_size: 1 5 | num_train_epochs: 1 6 | save_strategy: "steps" 7 | save_steps: 100 8 | save_total_limit: 6 9 | learning_rate: 0.0004 10 | gradient_accumulation_steps: 8 11 | gradient_checkpointing: True 12 | logging_steps: 1 13 | remove_unused_columns: True 14 | dataloader_num_workers: 2 15 | max_length: 8192 16 | test_size: 0.05 17 | evaluation_strategy: "steps" 18 | eval_steps: 100 19 | run_name: "rm-arena-llama-3.1-unsloth-lora-32-qkvogud" 20 | output_dir: "/home/models/rm-arena-llama-3.1-unsloth-lora-32-qkvogud" 21 | warmup_steps: 20 22 | report_to: "wandb" 23 | bf16: True 24 | seed: 42 25 | logging_first_step: True 26 | use_peft: True 27 | lora_task_type: SEQ_CLS 28 | lora_target_modules: 29 | - "k_proj" 30 | - "v_proj" 31 | - "q_proj" 32 | - "o_proj" 33 | - "gate_proj" 34 | - "up_proj" 35 | - "down_proj" 36 | lora_modules_to_save: 37 | - "score" 38 | lora_r: 32 39 | lora_alpha: 32 40 | pad_token: "<|reserved_special_token_0|>" 41 | eos_token: "<|eot_id|>" 42 | chat_template: "{{ bos_token }}{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{{ content }}{% endfor %}{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}{% endif %}" 43 | force_chat_template: True -------------------------------------------------------------------------------- /training_configs/reward/rm-qwen-14b-v2.yaml: -------------------------------------------------------------------------------- 1 | model_name_or_path: "Qwen/Qwen2.5-14B-Instruct" 2 | dataset: "Vikhrmodels/rm_training_dataset_16_12_24" 3 | per_device_train_batch_size: 2 4 | per_device_eval_batch_size: 2 5 | num_train_epochs: 1 6 | save_strategy: "steps" 7 | save_steps: 250 8 | save_total_limit: 6 9 | learning_rate: 0.0002 10 | center_rewards_coefficient: 0.01 11 | gradient_accumulation_steps: 8 12 | gradient_checkpointing: True 13 | logging_steps: 1 14 | remove_unused_columns: True 15 | dataloader_num_workers: 2 16 | max_length: 16384 17 | attn_implementation: "sdpa" 18 | test_size: 0.05 19 | evaluation_strategy: "steps" 20 | eval_steps: 250 21 | run_name: "rm-qwen2.5-14b-lora-128-all" 22 | output_dir: "checkpoints/rm-qwen2.5-14b-lora-128-all" 23 | warmup_steps: 20 24 | report_to: "wandb" 25 | bf16: True 26 | seed: 42 27 | logging_first_step: True 28 | use_peft: True 29 | lora_task_type: SEQ_CLS 30 | lora_target_modules: 31 | - "k_proj" 32 | - "v_proj" 33 | - "q_proj" 34 | - "o_proj" 35 | - "gate_proj" 36 | - "up_proj" 37 | - "down_proj" 38 | lora_modules_to_save: 39 | - "score" 40 | lora_r: 128 41 | lora_alpha: 128 42 | pad_token: "<|image_pad|>" 43 | eos_token: "<|im_end|>" 44 | chat_template: "{% set loop_messages = messages %}{% for message in loop_messages 
%}{% set content = '<|im_start|>' + message['role'] + '\n' + message['content'] | trim + '<|im_end|>' %}{{ content }}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant<|im_end|>\n' }}{% endif %}" 45 | force_chat_template: True -------------------------------------------------------------------------------- /training_configs/sft/sft-gemma-2-2b-it-lora-ficbook.yaml: -------------------------------------------------------------------------------- 1 | model_name_or_path: "unsloth/gemma-2-2b-it" 2 | dataset: 3 | - "qklent/ficbook_no_url_default_sys_prompt" 4 | - "qklent/roleplay_ru_gusev_preprocessed" 5 | - "Vikhrmodels/GrandMaster-PRO-MAX" 6 | dataset_ratio: 7 | - 1 8 | - 1 9 | - 0.1 10 | train_only_on_completions: False 11 | per_device_train_batch_size: 2 12 | per_device_eval_batch_size: 2 13 | num_train_epochs: 1 14 | save_strategy: "steps" 15 | save_steps: 2000 16 | save_total_limit: 3 17 | learning_rate: 0.000002 18 | lr_scheduler_type: "cosine" 19 | gradient_accumulation_steps: 4 20 | gradient_checkpointing: True 21 | logging_steps: 10 22 | remove_unused_columns: True 23 | dataloader_num_workers: 4 24 | test_size: 0.005 25 | generate_eval_examples: True 26 | num_gen_examples: 20 27 | evaluation_strategy: "steps" 28 | eval_steps: 2000 29 | run_name: "sft-ficbook-gemma-2-2b-it-unsloth-lora-16-64-all-proj" 30 | output_dir: "/mnt/storage/ficbook_models/gemma-2-2b-zero-16-64-all-proj" 31 | warmup_ratio: 0.03 32 | report_to: "wandb" 33 | conversation_field: "conversation" 34 | bf16: false 35 | fp16: true 36 | seed: 42 37 | max_seq_length: 2048 38 | logging_first_step: False 39 | use_peft: True 40 | lora_target_modules: 41 | - "k_proj" 42 | - "v_proj" 43 | - "q_proj" 44 | - "o_proj" 45 | - "gate_proj" 46 | - "up_proj" 47 | - "down_proj" 48 | lora_r: 16 49 | lora_alpha: 64 50 | assistant_message_template: "<|im_start|>assistant\n" 51 | pad_token: "<|reserved_special_token_0|>" 52 | eos_token: "<|im_end|>" 53 | chat_template: "{% for message in messages %}{% if message['role'] == 'user' %}{{'<|im_start|>user\n' + message['content'] + '<|im_end|>\n'}}{% elif message['role'] == 'assistant' %}{{'<|im_start|>assistant\n' + message['content'] + '<|im_end|>\n' }}{% else %}{{ '<|im_start|>system\n' + message['content'] + '<|im_end|>\n' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}" 54 | force_chat_template: True 55 | model_support_system_role: False 56 | attn_implementation: "eager" -------------------------------------------------------------------------------- /training_configs/sft/sft-llama-3.1-8b-it-full-Grandmaster.yaml: -------------------------------------------------------------------------------- 1 | model_name_or_path: "unsloth/Meta-Llama-3.1-8B-Instruct" 2 | dataset: "Vikhrmodels/GrandMaster-PRO-MAX" 3 | train_only_on_completions: True 4 | per_device_train_batch_size: 1 5 | per_device_eval_batch_size: 1 6 | num_train_epochs: 1 7 | save_strategy: "steps" 8 | save_steps: 300 9 | save_total_limit: 3 10 | learning_rate: 0.00004 11 | gradient_accumulation_steps: 8 12 | gradient_checkpointing: True 13 | logging_steps: 1 14 | remove_unused_columns: True 15 | dataloader_num_workers: 2 16 | generate_eval_examples: False 17 | evaluation_strategy: "steps" 18 | eval_steps: 300 19 | run_name: "sft-grndm-llama-3.1-unsloth-full" 20 | output_dir: "/home/models/sft-grndm-llama-3.1-unsloth-full" 21 | warmup_steps: 20 22 | report_to: "wandb" 23 | conversation_field: "conversation" 24 | bf16: True 25 | seed: 42 26 | logging_first_step: True 27 | 
use_peft: False 28 | assistant_message_template: "<|start_header_id|>assistant<|end_header_id|>" 29 | pad_token: "<|reserved_special_token_0|>" 30 | eos_token: "<|eot_id|>" 31 | chat_template: "{{ bos_token }}{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{{ content }}{% endfor %}{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}{% endif %}" 32 | force_chat_template: True -------------------------------------------------------------------------------- /training_configs/sft/sft-llama-3.1-8b-it-lora-GRAG.yaml: -------------------------------------------------------------------------------- 1 | model_name_or_path: "unsloth/Meta-Llama-3.1-8B-Instruct" 2 | dataset: "Vikhrmodels/Grounded-RAG-RU-v2" 3 | train_only_on_completions: True 4 | per_device_train_batch_size: 1 5 | per_device_eval_batch_size: 1 6 | num_train_epochs: 1 7 | save_strategy: "steps" 8 | save_steps: 300 9 | save_total_limit: 6 10 | learning_rate: 0.00004 11 | gradient_accumulation_steps: 8 12 | gradient_checkpointing: True 13 | logging_steps: 1 14 | remove_unused_columns: True 15 | dataloader_num_workers: 2 16 | generate_eval_examples: True 17 | evaluation_strategy: "steps" 18 | eval_steps: 300 19 | run_name: "sft-grag-llama-3.1-unsloth-lora-32-qkvo" 20 | output_dir: "/home/models/sft-grag-llama-3.1-unsloth-lora-32-qkvo" 21 | warmup_steps: 20 22 | report_to: "wandb" 23 | conversation_field: "conversation" 24 | bf16: True 25 | seed: 42 26 | logging_first_step: True 27 | use_peft: True 28 | lora_target_modules: 29 | - "k_proj" 30 | - "v_proj" 31 | - "q_proj" 32 | - "o_proj" 33 | lora_r: 32 34 | lora_alpha: 32 35 | assistant_message_template: "<|start_header_id|>assistant<|end_header_id|>" 36 | pad_token: "<|reserved_special_token_0|>" 37 | eos_token: "<|eot_id|>" 38 | chat_template: "{{ bos_token }}{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{{ content }}{% endfor %}{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}{% endif %}" 39 | force_chat_template: True -------------------------------------------------------------------------------- /training_configs/sft/sft-llama-3.1-8b-it-lora-GrandmasterRAG-v1.yaml: -------------------------------------------------------------------------------- 1 | model_name_or_path: "unsloth/Meta-Llama-3.1-8B-Instruct" 2 | dataset: 3 | - "Vikhrmodels/GrandMaster-PRO-MAX" 4 | - "Vikhrmodels/Grounded-RAG-RU-v2" 5 | train_only_on_completions: True 6 | per_device_train_batch_size: 1 7 | per_device_eval_batch_size: 1 8 | num_train_epochs: 1 9 | save_strategy: "steps" 10 | save_steps: 400 11 | save_total_limit: 6 12 | learning_rate: 0.00004 13 | gradient_accumulation_steps: 8 14 | gradient_checkpointing: True 15 | logging_steps: 1 16 | remove_unused_columns: False 17 | dataloader_num_workers: 2 18 | save_only_model: True 19 | generate_eval_examples: True 20 | use_liger: True 21 | max_seq_length: 16000 22 | evaluation_strategy: "steps" 23 | eval_steps: 400 24 | run_name: "sft-grndmrag-llama-3.1-unsloth-lora-256-qkvogudlm-v1" 25 | output_dir: "/home/models/sft-grndmrag-llama-3.1-unsloth-lora-256-qkvogudlm-v1" 26 | warmup_steps: 20 27 | report_to: "wandb" 28 | conversation_field: "conversation" 29 | bf16: True 30 | seed: 42 31 | 
logging_first_step: True 32 | use_peft: True 33 | lora_target_modules: 34 | - "k_proj" 35 | - "v_proj" 36 | - "q_proj" 37 | - "o_proj" 38 | - "gate_proj" 39 | - "up_proj" 40 | - "down_proj" 41 | - "lm_head" 42 | lora_r: 256 43 | lora_alpha: 256 44 | assistant_message_template: "<|start_header_id|>assistant<|end_header_id|>\n\n" 45 | pad_token: "<|reserved_special_token_0|>" 46 | eos_token: "<|eot_id|>" 47 | chat_template: "{{ bos_token }}{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{{ content }}{% endfor %}{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}{% endif %}" 48 | force_chat_template: True 49 | -------------------------------------------------------------------------------- /training_configs/sft/sft-mistral-nemo-12b-lora-GrandmasterRAG-v1.yaml: -------------------------------------------------------------------------------- 1 | model_name_or_path: "mistralai/Mistral-Nemo-Instruct-2407" 2 | dataset: 3 | - "Vikhrmodels/GrandMaster-PRO-MAX" 4 | - "Vikhrmodels/Grounded-RAG-RU-v2" 5 | train_only_on_completions: True 6 | per_device_train_batch_size: 1 7 | per_device_eval_batch_size: 1 8 | num_train_epochs: 1 9 | save_strategy: "steps" 10 | save_steps: 400 11 | save_total_limit: 6 12 | learning_rate: 0.00004 13 | gradient_accumulation_steps: 8 14 | gradient_checkpointing: True 15 | logging_steps: 1 16 | remove_unused_columns: False 17 | dataloader_num_workers: 2 18 | save_only_model: True 19 | generate_eval_examples: True 20 | use_liger: True 21 | max_seq_length: 16000 22 | evaluation_strategy: "steps" 23 | eval_steps: 400 24 | run_name: "sft-grndmrag-mistral-nemo-lora-256-qkvogudlm-v1" 25 | output_dir: "/mnt/models/sft-grndmrag-mistral-nemo-lora-256-qkvogudlm-v1" 26 | warmup_steps: 20 27 | report_to: "wandb" 28 | conversation_field: "conversation" 29 | bf16: True 30 | seed: 42 31 | logging_first_step: True 32 | use_peft: True 33 | lora_target_modules: 34 | - "k_proj" 35 | - "v_proj" 36 | - "q_proj" 37 | - "o_proj" 38 | - "gate_proj" 39 | - "up_proj" 40 | - "down_proj" 41 | - "lm_head" 42 | lora_r: 256 43 | lora_alpha: 256 44 | assistant_message_template: "<|start_header_id|>assistant<|end_header_id|>\n\n" 45 | pad_token: "<pad>" 46 | eos_token: "</s>" 47 | chat_template: "{{ bos_token }}{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '</s>' %}{{ content }}{% endfor %}{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}{% endif %}" 48 | force_chat_template: True 49 | added_special_tokens: 50 | - "<|start_header_id|>" 51 | - "<|end_header_id|>" 52 | -------------------------------------------------------------------------------- /training_configs/sft/sft-phi4-full-GrandmasterRAG-v2.yaml: -------------------------------------------------------------------------------- 1 | model_name_or_path: "checkpoints/phi-4-original/phi-4/" 2 | dataset: 3 | - "Vikhrmodels/GrandMaster-PRO-MAX" 4 | - "Vikhrmodels/Grounded-RAG-RU-v2" 5 | train_only_on_completions: True 6 | per_device_train_batch_size: 1 7 | per_device_eval_batch_size: 1 8 | num_train_epochs: 1 9 | save_strategy: "steps" 10 | save_steps: 400 11 | save_total_limit: 6 12 | learning_rate: 0.00004 13 | gradient_accumulation_steps: 8 14 | gradient_checkpointing: True 15 | logging_steps: 1 16 |
remove_unused_columns: False 17 | dataloader_num_workers: 2 18 | save_only_model: True 19 | generate_eval_examples: True 20 | use_liger: True 21 | max_seq_length: 16384 22 | attn_implementation: "sdpa" 23 | evaluation_strategy: "steps" 24 | eval_steps: 400 25 | run_name: "sft-grndmrag-phi4-full-v2" 26 | output_dir: "checkpoints/sft-grndmrag-phi4-full-v2" 27 | warmup_steps: 20 28 | report_to: "wandb" 29 | conversation_field: "conversation" 30 | bf16: True 31 | seed: 42 32 | logging_first_step: True 33 | use_peft: False 34 | assistant_message_template: "<|im_start|>assistant<|im_sep|>" 35 | pad_token: "<|dummy_0|>" 36 | eos_token: "<|im_end|>" 37 | chat_template: "{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|im_start|>' + message['role'] + '<|im_sep|>'+ message['content'] | trim + '<|im_end|>' %}{{ content }}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant<|im_sep|>' }}{% endif %}" 38 | force_chat_template: True 39 | -------------------------------------------------------------------------------- /training_configs/sft/sft-phi4-lora-GrandmasterRAG-v4.yaml: -------------------------------------------------------------------------------- 1 | model_name_or_path: "checkpoints/phi4-original/phi-4/" 2 | dataset: 3 | - "Vikhrmodels/GrandMaster-PRO-MAX" 4 | - "Vikhrmodels/Grounded-RAG-RU-v2" 5 | train_only_on_completions: True 6 | per_device_train_batch_size: 1 7 | per_device_eval_batch_size: 1 8 | num_train_epochs: 1 9 | save_strategy: "steps" 10 | save_steps: 400 11 | save_total_limit: 6 12 | learning_rate: 0.00004 13 | gradient_accumulation_steps: 16 14 | gradient_checkpointing: True 15 | logging_steps: 1 16 | remove_unused_columns: False 17 | dataloader_num_workers: 2 18 | save_only_model: True 19 | generate_eval_examples: True 20 | use_liger: True 21 | max_seq_length: 16384 22 | attn_implementation: "sdpa" 23 | evaluation_strategy: "steps" 24 | eval_steps: 400 25 | run_name: "sft-grndmrag-phi4-lora-384-qkvogudlm-v4" 26 | output_dir: "checkpoints/sft-grndmrag-phi4-lora-384-qkvogudlm-v4" 27 | warmup_steps: 20 28 | report_to: "wandb" 29 | conversation_field: "conversation" 30 | bf16: True 31 | seed: 42 32 | logging_first_step: True 33 | use_peft: True 34 | lora_target_modules: 35 | - "qkv_proj" 36 | - "o_proj" 37 | - "gate_up_proj" 38 | - "down_proj" 39 | - "lm_head" 40 | lora_r: 384 41 | lora_alpha: 384 42 | assistant_message_template: "<|im_start|>assistant<|im_sep|>" 43 | pad_token: "<|dummy_0|>" 44 | eos_token: "<|im_end|>" 45 | chat_template: "{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|im_start|>' + message['role'] + '<|im_sep|>'+ message['content'] | trim + '<|im_end|>' %}{{ content }}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant<|im_sep|>' }}{% endif %}" 46 | force_chat_template: True 47 | -------------------------------------------------------------------------------- /training_configs/sft/sft-yandex-lora-GrandmasterRAG.yaml: -------------------------------------------------------------------------------- 1 | model_name_or_path: "yandex/YandexGPT-5-Lite-8B-pretrain" 2 | dataset: 3 | - "Vikhrmodels/GrandMaster-PRO-MAX" 4 | - "Vikhrmodels/Grounded-RAG-RU-v2" 5 | train_only_on_completions: True 6 | per_device_train_batch_size: 1 7 | per_device_eval_batch_size: 1 8 | num_train_epochs: 1.5 9 | save_strategy: "steps" 10 | save_steps: 400 11 | save_total_limit: 6 12 | learning_rate: 0.000005 13 | gradient_accumulation_steps: 16 14 | gradient_checkpointing: True 15 | 
logging_steps: 1 16 | remove_unused_columns: False 17 | dataloader_num_workers: 40 18 | save_only_model: True 19 | generate_eval_examples: True 20 | use_liger: True 21 | max_seq_length: 16384 22 | attn_implementation: "sdpa" 23 | evaluation_strategy: "steps" 24 | eval_steps: 500 25 | run_name: "sft-grndmrag-YandexGPT-lora-256-qkvogudlm-v1" 26 | output_dir: "checkpoints/sft-grndmrag-YandexGPT-lora-256-qkvogudlm-v1" 27 | warmup_steps: 20 28 | report_to: "wandb" 29 | conversation_field: "conversation" 30 | bf16: True 31 | seed: 42 32 | logging_first_step: True 33 | use_peft: True 34 | lora_target_modules: 35 | - "qkv_proj" 36 | - "o_proj" 37 | - "gate_up_proj" 38 | - "down_proj" 39 | - "lm_head" 40 | lora_r: 256 41 | lora_alpha: 256 42 | pad_token: "[SPEC_TOKEN_1001]" 43 | assistant_message_template: "<s>assistant\n" 44 | eos_token: "</s>" 45 | chat_template: "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<s>' + message['role'] + '\n' + message['content'] + '</s>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<s>assistant\n' }}{% endif %}" 46 | force_chat_template: True 47 | --------------------------------------------------------------------------------
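As a closing illustration (not part of the repository): the phi-4 style chat_template that recurs in the SFT and SMPO configs above can be rendered directly with jinja2, the same templating engine transformers uses for tokenizer.apply_chat_template. The messages below are invented example data; the template string is copied verbatim from those configs.

from jinja2 import Template

chat_template = (
    "{% set loop_messages = messages %}{% for message in loop_messages %}"
    "{% set content = '<|im_start|>' + message['role'] + '<|im_sep|>'"
    "+ message['content'] | trim + '<|im_end|>' %}{{ content }}{% endfor %}"
    "{% if add_generation_prompt %}{{ '<|im_start|>assistant<|im_sep|>' }}{% endif %}"
)

messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Hi!"},
]

# Prints one line:
# <|im_start|>system<|im_sep|>You are a helpful assistant.<|im_end|><|im_start|>user<|im_sep|>Hi!<|im_end|><|im_start|>assistant<|im_sep|>
print(Template(chat_template).render(messages=messages, add_generation_prompt=True))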