├── .gitignore ├── LICENSE ├── README.md ├── conda-recipe.yaml ├── framework-colorblindfriendly.jpg ├── mcts_rl ├── __init__.py ├── algorithms │ ├── __init__.py │ ├── dpo │ │ ├── __init__.py │ │ ├── __main__.py │ │ ├── main.py │ │ └── trainer.py │ ├── mcts │ │ ├── __init__.py │ │ ├── __main__.py │ │ ├── main.py │ │ ├── mcts │ │ │ ├── __init__.py │ │ │ ├── base.py │ │ │ ├── embedding_retrieve.py │ │ │ ├── mcts.py │ │ │ ├── search_config.py │ │ │ └── world_model.py │ │ └── trainer.py │ └── ppo │ │ ├── __init__.py │ │ ├── __main__.py │ │ ├── main.py │ │ └── trainer.py ├── configs │ ├── __init__.py │ ├── constants.py │ ├── deepspeed_config.py │ ├── ds_eval_config_template.json │ ├── ds_train_config_template.json │ └── fsdp_config.json ├── datasets │ ├── __init__.py │ ├── base.py │ ├── preference.py │ ├── prompt_only.py │ ├── raw │ │ ├── __init__.py │ │ ├── alpaca.py │ │ ├── aqua.py │ │ ├── arithmo.py │ │ ├── exam.py │ │ ├── firefly.py │ │ ├── gsm8k.py │ │ ├── hh_rlhf.py │ │ ├── math.py │ │ ├── math_qa.py │ │ ├── mcq.py │ │ ├── mcq_for_eval.py │ │ ├── mcq_pairs.py │ │ ├── moss.py │ │ ├── prm800k.py │ │ ├── qa_feedback.py │ │ └── safe_rlhf.py │ ├── safety_preference.py │ ├── supervised.py │ └── utils.py ├── finetune │ ├── __init__.py │ ├── __main__.py │ ├── deepspeed.py │ ├── huggingface.py │ ├── main.py │ └── trainer.py ├── logger.py ├── models │ ├── __init__.py │ ├── normalizer.py │ ├── pretrained.py │ └── score_model │ │ ├── __init__.py │ │ ├── bloom │ │ ├── __init__.py │ │ └── modeling_bloom.py │ │ ├── gpt2 │ │ ├── __init__.py │ │ └── modeling_gpt2.py │ │ ├── gpt_neo │ │ ├── __init__.py │ │ └── modeling_gpt_neo.py │ │ ├── gpt_neox │ │ ├── __init__.py │ │ └── modeling_gpt_neox.py │ │ ├── gptj │ │ ├── __init__.py │ │ └── modeling_gptj.py │ │ ├── llama │ │ ├── __init__.py │ │ └── modeling_llama.py │ │ ├── mistral │ │ ├── __init__.py │ │ └── modeling_mistral.py │ │ ├── open_llama │ │ ├── __init__.py │ │ └── modeling_open_llama.py │ │ └── opt │ │ ├── __init__.py │ │ └── modeling_opt.py ├── trainers │ ├── __init__.py │ ├── base.py │ ├── rl_trainer.py │ ├── supervised_trainer.py │ └── tsrl_trainer.py ├── utils.py └── version.py ├── requirements.txt ├── scripts ├── eval │ ├── mctseval_math.sh │ └── mctseval_sqa.sh ├── mcts_csr.sh ├── mcts_mathqa.sh ├── mcts_mathqa2.sh └── mcts_mathqa_llama3.sh └── visualize.ipynb /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
160 | #.idea/ 161 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Monte Carlo Tree Search Boosts Reasoning via Iterative Preference Learning 2 | 3 | This repository contains code and analysis for the paper: [Monte Carlo Tree Search Boosts Reasoning via Iterative Preference Learning](https://arxiv.org/abs/2405.00451). 4 | Below is the framework of our proposed method. 5 | 6 | ![Model Framework](framework-colorblindfriendly.jpg) 7 | 8 | #### Environment Setup 9 | 10 | ```sh 11 | conda env create --file conda-recipe.yaml 12 | pip install -r requirements.txt 13 | ``` 14 | 15 | #### Dataset Download 16 | 17 | * Arithmo: [akjindal53244/Arithmo-Data](https://huggingface.co/datasets/akjindal53244/Arithmo-Data) 18 | 19 | * GSM8K: [openai/grade-school-math](https://github.com/openai/grade-school-math/tree/master/grade_school_math/data) 20 | 21 | * MATH: [hendrycks/math](https://github.com/hendrycks/math/) 22 | 23 | * ARC: [AI2 Reasoning Challenge](https://paperswithcode.com/dataset/arc) 24 | 25 | * AI2S: [AI2 Science Questions](http://data.allenai.org/ai2-science-questions) 26 | 27 | * OBQA: [Openbook QA](https://allenai.org/data/open-book-qa) 28 | 29 | * CSQA: [tau/commonsense_qa](https://huggingface.co/datasets/tau/commonsense_qa) 30 | 31 | * SciQ: [SciQ Dataset](https://allenai.org/data/sciq) 32 | 33 | 34 | #### Run MCTS-DPO 35 | 36 | Our main code includes `./mcts_rl/algorithms/mcts` and `./mcts_rl/trainers/tsrl_trainer.py`; a minimal sketch of the underlying search abstractions is given at the end of this README. 37 | 38 | To run MCTS-DPO for MathQA on Mistral (SFT): 39 | ```sh 40 | bash scripts/mcts_mathqa.sh 41 | ``` 42 | 43 | To run MCTS-DPO for CSR on Mistral (SFT): 44 | ```sh 45 | bash scripts/mcts_csr.sh 46 | ``` 47 | 48 | ## Citation 49 | 50 | ``` 51 | @article{xie2024monte, 52 | title={Monte Carlo Tree Search Boosts Reasoning via Iterative Preference Learning}, 53 | author={Xie, Yuxi and Goyal, Anirudh and Zheng, Wenyue and Kan, Min-Yen and Lillicrap, Timothy P and Kawaguchi, Kenji and Shieh, Michael}, 54 | journal={arXiv preprint arXiv:2405.00451}, 55 | year={2024} 56 | } 57 | ``` 58 | 59 | --- 60 | This repository is adapted from the codebase of [Safe-RLHF](https://github.com/PKU-Alignment/safe-rlhf).
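
#### Search Abstractions

The MCTS components are built on the small set of interfaces in `mcts_rl/algorithms/mcts/mcts/base.py` (`WorldModel`, `SearchConfig`, `SearchAlgorithm`, `TreeConstructor`). The sketch below is a minimal, self-contained illustration of how they compose. The toy world model, config, and greedy stand-in for MCTS are written for this README only and are not part of the repository; the real counterparts are `StepLMWorldModel`, `StepLMConfig`, and `MCTS`.

```python
from types import SimpleNamespace

from mcts_rl.algorithms.mcts.mcts.base import (
    SearchAlgorithm,
    SearchConfig,
    TreeConstructor,
    WorldModel,
)


class CountdownWorldModel(WorldModel[int, int, int]):
    """Toy world model: the state is an integer, an action subtracts from it."""

    def init_state(self) -> int:
        return self.example  # `example` is set by TreeConstructor.__call__

    def step(self, state: int, action: int) -> int:
        return state - action

    def is_terminal(self, state: int) -> bool:
        return state <= 0


class CountdownConfig(SearchConfig[int, int, int]):
    """Toy search config: landing exactly on zero earns reward 1."""

    def get_actions(self, state: int) -> list[int]:
        return [1, 2, 3]

    def reward(self, state, action, **kwargs) -> tuple[float, dict]:
        return float(state - action == 0), {}

    def get_values(self, state: int, action: int) -> list[tuple[float, bool]]:
        return [(float(state - action == 0), state - action <= 0)]


class GreedySearch(SearchAlgorithm):
    """One-step-lookahead stand-in for the MCTS search algorithm."""

    def __call__(self, world_model, search_config, **kwargs):
        state, states, actions = world_model.init_state(), [], []
        while not world_model.is_terminal(state):
            # pick the immediate-reward-maximizing action at each step
            action = max(search_config.get_actions(state),
                         key=lambda a: search_config.reward(state, a)[0])
            states.append(state)
            actions.append(action)
            state = world_model.step(state, action)
        return SimpleNamespace(terminal_state=state, trace=(states, actions))


searcher = TreeConstructor(CountdownWorldModel(), CountdownConfig(), GreedySearch())
result = searcher(example=7)
print(result.terminal_state)  # 0
print(result.trace)           # ([7, 6, 5, 4, 3], [1, 1, 1, 1, 3])
```

In the trainer, the analogous wiring uses `StepLMWorldModel` (whose states are sequences of reasoning-step token ids) together with `StepLMConfig` and `MCTS`, and the resulting search traces supply the step-level preference pairs for iterative DPO.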
61 | -------------------------------------------------------------------------------- /conda-recipe.yaml: -------------------------------------------------------------------------------- 1 | # 2 | # Create virtual environment with command: 3 | # 4 | # $ CONDA_OVERRIDE_CUDA=11.7 conda env create --file conda-recipe.yaml 5 | # 6 | 7 | name: mcts-dpo 8 | channels: 9 | - huggingface 10 | - pytorch 11 | - nvidia/label/cuda-11.7.1 12 | - defaults 13 | - conda-forge 14 | dependencies: 15 | - python = 3.10 16 | - pip 17 | 18 | - pytorch::pytorch >= 1.13 19 | - pytorch::pytorch-mutex =*=*cuda* 20 | - transformers >= 4.29.0 21 | - datasets 22 | - tokenizers >= 0.13.3 23 | - sentencepiece 24 | - tensorboard 25 | - wandb 26 | - pip: 27 | - accelerate 28 | 29 | - nvidia/label/cuda-11.7.1::cuda-toolkit = 11.7 30 | 31 | - optree 32 | - scipy 33 | - nvitop 34 | - matplotlib-base 35 | - rich 36 | - tqdm 37 | - typing-extensions 38 | - ipdb 39 | - jsonlines 40 | - func_timeout 41 | -------------------------------------------------------------------------------- /framework-colorblindfriendly.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YuxiXie/MCTS-DPO/fb23437f502189b21e0824e7f6ad9c822cc8e091/framework-colorblindfriendly.jpg -------------------------------------------------------------------------------- /mcts_rl/__init__.py: -------------------------------------------------------------------------------- 1 | from mcts_rl import algorithms, configs, datasets, models, trainers, utils 2 | from mcts_rl.algorithms import * # noqa: F403 3 | from mcts_rl.configs import * # noqa: F403 4 | from mcts_rl.datasets import * # noqa: F403 5 | from mcts_rl.models import * # noqa: F403 6 | from mcts_rl.trainers import * # noqa: F403 7 | from mcts_rl.utils import * # noqa: F403 8 | from mcts_rl.version import __version__ 9 | 10 | 11 | __all__ = [ 12 | *algorithms.__all__, 13 | *configs.__all__, 14 | *datasets.__all__, 15 | *models.__all__, 16 | *trainers.__all__, 17 | *utils.__all__, 18 | ] 19 | -------------------------------------------------------------------------------- /mcts_rl/algorithms/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 PKU-Alignment Team. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """RL algorithms for RLHF.""" 16 | 17 | from mcts_rl.algorithms.dpo import DPOTrainer 18 | from mcts_rl.algorithms.ppo import PPOTrainer 19 | from mcts_rl.algorithms.mcts import MCTSTrainer 20 | 21 | 22 | __all__ = ['PPOTrainer', 'DPOTrainer', 'MCTSTrainer'] 23 | -------------------------------------------------------------------------------- /mcts_rl/algorithms/dpo/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 PKU-Alignment Team. All Rights Reserved. 
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """The Direct Preference Optimization (DPO) algorithm.""" 16 | 17 | from mcts_rl.algorithms.dpo.trainer import DPOTrainer 18 | 19 | 20 | __all__ = ['DPOTrainer'] 21 | -------------------------------------------------------------------------------- /mcts_rl/algorithms/dpo/__main__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 PKU-Alignment Team. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """The main training script to run the DPO algorithm.""" 16 | 17 | import sys 18 | 19 | from mcts_rl.algorithms.dpo.main import main 20 | 21 | 22 | if __name__ == '__main__': 23 | sys.exit(main()) 24 | -------------------------------------------------------------------------------- /mcts_rl/algorithms/mcts/__init__.py: -------------------------------------------------------------------------------- 1 | """RL with the MCTS algorithm.""" 2 | 3 | from mcts_rl.algorithms.mcts.trainer import MCTSTrainer 4 | 5 | 6 | __all__ = ['MCTSTrainer'] 7 | -------------------------------------------------------------------------------- /mcts_rl/algorithms/mcts/__main__.py: -------------------------------------------------------------------------------- 1 | """The main training script to train RLHF using the MCTS algorithm.""" 2 | 3 | import sys 4 | 5 | from mcts_rl.algorithms.mcts.main import main 6 | 7 | 8 | if __name__ == '__main__': 9 | sys.exit(main()) 10 | -------------------------------------------------------------------------------- /mcts_rl/algorithms/mcts/mcts/__init__.py: -------------------------------------------------------------------------------- 1 | from mcts_rl.algorithms.mcts.mcts.base import TreeConstructor 2 | from mcts_rl.algorithms.mcts.mcts.mcts import MCTS, MCTSNode, MCTSResult, MCTSConfig 3 | from mcts_rl.algorithms.mcts.mcts.world_model import StepLMWorldModel, LMExample 4 | from mcts_rl.algorithms.mcts.mcts.search_config import StepLMConfig, SearchArgs -------------------------------------------------------------------------------- /mcts_rl/algorithms/mcts/mcts/base.py: -------------------------------------------------------------------------------- 1 | from typing import Generic, TypeVar, Protocol 2 | from abc import ABC,
abstractmethod 3 | 4 | State = TypeVar("State") 5 | Action = TypeVar("Action") 6 | Example = TypeVar("Example") 7 | Trace = tuple[list[State], list[Action]] 8 | Args = TypeVar("Args") 9 | 10 | 11 | class WorldModel(ABC, Generic[State, Action, Example]): 12 | def __init__(self) -> None: 13 | self.example = None 14 | 15 | @abstractmethod 16 | def init_state(self) -> State: ... 17 | 18 | @abstractmethod 19 | def step(self, state: State, action: Action) -> State: ... 20 | 21 | @abstractmethod 22 | def is_terminal(self, state: State) -> bool: ... 23 | 24 | def update_example(self, example: Example) -> None: 25 | self.example = example 26 | 27 | 28 | class SearchConfig(ABC, Generic[State, Action, Example]): 29 | def __init__(self) -> None: 30 | self.example = None 31 | 32 | @abstractmethod 33 | def get_actions(self, state: State) -> list[Action]: ... 34 | 35 | @abstractmethod 36 | def reward(self, state, action, **kwargs) -> tuple[float, dict]: ... 37 | 38 | @abstractmethod 39 | def get_values(self, state: State, action: Action) -> list[tuple[float, bool]]: ... 40 | 41 | def update_example(self, example: Example) -> None: 42 | self.example = example 43 | 44 | 45 | class HasTerminalStateAndTrace(Protocol[State]): 46 | terminal_state: State 47 | trace: Trace 48 | 49 | 50 | class SearchAlgorithm(ABC): 51 | def __init__(self, **kwargs): ... 52 | 53 | @abstractmethod 54 | def __call__(self, world_model: WorldModel, search_config: SearchConfig, **kwargs) -> HasTerminalStateAndTrace: ... 55 | 56 | 57 | class TreeConstructor(ABC, Generic[State, Action, Example]): 58 | def __init__(self, 59 | world_model: WorldModel[State, Action, Example], 60 | search_config: SearchConfig[State, Action, Example], 61 | search_algo: SearchAlgorithm) -> None: 62 | self.world_model = world_model 63 | self.search_config = search_config 64 | self.search_algo = search_algo 65 | 66 | def __call__(self, example: Example, node=None, **kwargs) -> HasTerminalStateAndTrace[State]: 67 | self.world_model.update_example(example) 68 | self.search_config.update_example(example) 69 | return self.search_algo(self.world_model, 70 | self.search_config, 71 | root_node=node, 72 | **kwargs) 73 | -------------------------------------------------------------------------------- /mcts_rl/algorithms/mcts/mcts/embedding_retrieve.py: -------------------------------------------------------------------------------- 1 | import math 2 | import pickle 3 | import torch 4 | import torch.distributed as dist 5 | from bert_score.utils import greedy_cos_idf 6 | from torch.nn.utils.rnn import pad_sequence 7 | 8 | 9 | def get_embs_masks(seq_ids: list[list[int]], seq_embs: list[torch.Tensor], 10 | special_ids: list = [1, 2, 32000]): 11 | seq_ids = [ids[:len(embs)] for ids, embs in zip(seq_ids, seq_embs)] 12 | seq_embs = pad_sequence(seq_embs, batch_first=True, padding_value=2.0) 13 | lengths = torch.tensor([len(seq) for seq in seq_ids], dtype=torch.long) 14 | seq_masks = torch.arange(max(lengths), dtype=torch.long).expand(len(lengths), max(lengths)) 15 | seq_masks = seq_masks < lengths.unsqueeze(1) 16 | seq_idfs = pad_sequence([ 17 | torch.tensor([float(_id not in special_ids) for _id in seq]) 18 | for seq in seq_ids], batch_first=True, padding_value=0.0) 19 | return seq_embs, seq_masks.to(seq_embs.device), seq_idfs.to(seq_embs.device) 20 | 21 | 22 | def filter_with_similarity(candidates, tokenizer): 23 | sequences = [tokenizer.decode(x[0]) for x in candidates] 24 | special_ids = [tokenizer.bos_token_id, tokenizer.eos_token_id, tokenizer.pad_token_id] 25 | 26 | 
seq_ids, seq_embs = [c[0][:len(c[2])] for c in candidates], [c[2] for c in candidates] 27 | seq_embs, seq_masks, seq_idfs = get_embs_masks(seq_ids, seq_embs, special_ids=special_ids) 28 | 29 | scores = {'P': {}, 'R': {}, 'F': {}} 30 | for i, rst in enumerate(candidates): 31 | P, R, F = check_match(seq_embs, seq_masks, seq_idfs, 32 | rst[0][:len(rst[2])], rst[2], 33 | special_ids=special_ids) 34 | scores['P'][i], scores['R'][i], scores['F'][i] = P, R, F 35 | # an F score close to 1.0 indicates that candidate i nearly duplicates 36 | # another candidate, so the caller can filter it out 37 | 38 | F_scores = list(dict(sorted(scores['F'].items(), key=lambda x: x[0])).values()) 39 | 40 | return F_scores 41 | 42 | 43 | def check_match(keys_embs: torch.Tensor, keys_masks: torch.BoolTensor, keys_idfs: torch.Tensor, 44 | query: list[int], query_emb: torch.Tensor, 45 | special_ids: list = [1, 2, 32000]): 46 | query_embs = torch.stack([query_emb for _ in keys_embs], dim=0) 47 | query_masks = torch.ones(query_embs.size()[:-1]).bool().to(query_embs.device) 48 | query_idfs = torch.stack([ 49 | torch.tensor([float(_id not in special_ids) for _id in query[:query_emb.size(0)]]) 50 | for _ in keys_idfs], dim=0).to(query_embs.device) 51 | 52 | P, R, F = greedy_cos_idf(keys_embs.float(), keys_masks, keys_idfs, 53 | query_embs.float(), query_masks, query_idfs) 54 | return P, R, F -------------------------------------------------------------------------------- /mcts_rl/algorithms/mcts/mcts/world_model.py: -------------------------------------------------------------------------------- 1 | # Adapted from: https://github.com/maitrix-org/llm-reasoners/blob/main/examples/RAP/gsm8k/world_model.py 2 | 3 | from typing import NamedTuple, TypedDict 4 | 5 | import torch 6 | from transformers import GenerationConfig, PreTrainedTokenizerBase 7 | 8 | from mcts_rl.algorithms.mcts.mcts.base import WorldModel 9 | 10 | 11 | class StepSubResult(NamedTuple): 12 | next_step_ids: torch.LongTensor 13 | log_probs: torch.Tensor 14 | 15 | 16 | StepLMState = list[StepSubResult] 17 | StepLMAction = torch.LongTensor 18 | 19 | 20 | class WorldModelArgs(NamedTuple): 21 | base_tokenizer: PreTrainedTokenizerBase 22 | generation_config: GenerationConfig 23 | stop_tokens: list[str] = [] 24 | 25 | 26 | class LMExample(TypedDict): 27 | input_ids: torch.LongTensor # (L,) 28 | attention_mask: torch.BoolTensor # (L,) 29 | 30 | 31 | class StepLMWorldModel(WorldModel[StepLMState, StepLMAction, LMExample]): 32 | def __init__(self, 33 | max_length: int, 34 | base_tokenizer: PreTrainedTokenizerBase, 35 | generation_config: GenerationConfig, 36 | stop_tokens=[]) -> None: 37 | super().__init__() 38 | self.base_tokenizer = base_tokenizer 39 | self.generation_config = generation_config 40 | self.max_tokens_num = max_length 41 | self.stop_tokens = list(set( 42 | stop_tokens + [self.base_tokenizer.decode([self.generation_config.eos_token_id])] 43 | )) 44 | 45 | def init_state(self) -> list: 46 | return [] 47 | 48 | def step(self, state: StepLMState, action: StepLMAction, log_probs: torch.Tensor) -> StepLMState: 49 | state = state.copy() 50 | state.append(StepSubResult(action, log_probs)) 51 | return state 52 | 53 | def is_terminal(self, state: StepLMState) -> bool: 54 | input_length = self.example['attention_mask'].nonzero()[-1].item() + 1 55 | sum_tokens_num = sum(x.next_step_ids.size(0) for x in state) + input_length 56 | 57 | if sum_tokens_num >= self.max_tokens_num - 5: 58 | return True 59 | elif state[-1].next_step_ids.eq(self.base_tokenizer.eos_token_id).sum(): 60 | return True 61 | 
elif state[-1].next_step_ids.eq(self.base_tokenizer.convert_tokens_to_ids("<|eot_id|>")).sum(): 62 | return True 63 | elif self.base_tokenizer.decode(state[-1].next_step_ids).count('QUESTION: '): 64 | return True 65 | else: 66 | return False 67 | -------------------------------------------------------------------------------- /mcts_rl/algorithms/ppo/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 PKU-Alignment Team. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """RLHF with the PPO algorithm.""" 16 | 17 | from mcts_rl.algorithms.ppo.trainer import PPOTrainer 18 | 19 | 20 | __all__ = ['PPOTrainer'] 21 | -------------------------------------------------------------------------------- /mcts_rl/algorithms/ppo/__main__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 PKU-Alignment Team. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """The main training script to train RLHF using the PPO algorithm.""" 16 | 17 | import sys 18 | 19 | from mcts_rl.algorithms.ppo.main import main 20 | 21 | 22 | if __name__ == '__main__': 23 | sys.exit(main()) 24 | -------------------------------------------------------------------------------- /mcts_rl/configs/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 PKU-Alignment Team. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License.
14 | # ============================================================================== 15 | """Configurations and constants.""" 16 | 17 | from mcts_rl.configs import constants 18 | from mcts_rl.configs.constants import * # noqa: F403 19 | from mcts_rl.configs.deepspeed_config import ( 20 | TEMPLATE_DIR, 21 | get_deepspeed_eval_config, 22 | get_deepspeed_train_config, 23 | ) 24 | 25 | 26 | __all__ = [ 27 | *constants.__all__, 28 | 'TEMPLATE_DIR', 29 | 'get_deepspeed_eval_config', 30 | 'get_deepspeed_train_config', 31 | ] 32 | -------------------------------------------------------------------------------- /mcts_rl/configs/deepspeed_config.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 PKU-Alignment Team. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """DeepSpeed configuration for training and evaluation.""" 16 | 17 | from __future__ import annotations 18 | 19 | import json 20 | import pathlib 21 | from typing import Any, Literal 22 | 23 | import torch.distributed as dist 24 | 25 | 26 | __all__ = ['TEMPLATE_DIR', 'get_deepspeed_train_config', 'get_deepspeed_eval_config'] 27 | 28 | 29 | TEMPLATE_DIR = pathlib.Path(__file__).absolute().parent 30 | TRAIN_TEMPLATE_FILE = TEMPLATE_DIR / 'ds_train_config_template.json' 31 | EVAL_TEMPLATE_FILE = TEMPLATE_DIR / 'ds_eval_config_template.json' 32 | 33 | 34 | def get_deepspeed_train_config( 35 | *, 36 | micro_batch_size_per_gpu: int = 16, 37 | gradient_accumulation_steps: int = 1, 38 | stage: int = 3, 39 | offload: Literal['none', 'parameter', 'optimizer', 'all'] = 'none', 40 | enable_hybrid_engine: bool = False, 41 | max_length: int = 512, 42 | fp16: bool = False, 43 | bf16: bool = False, 44 | ) -> dict[str, Any]: 45 | """Get the DeepSpeed config for training. 46 | 47 | Args: 48 | micro_batch_size_per_gpu (int, optional): The micro batch size per GPU. Defaults to 16. 49 | gradient_accumulation_steps (int, optional): The number of gradient accumulation steps. 50 | Defaults to 1. 51 | stage (int, optional): The stage of ZeRO. Defaults to 3. 52 | offload (Literal['none', 'parameter', 'optimizer', 'all'], optional): The offload mode. 53 | enable_hybrid_engine (bool, optional): Whether to enable the DeepSpeed hybrid engine. 54 | Defaults to False. 55 | max_length (int, optional): The maximum length of the input sequence. Defaults to 512. 56 | fp16 (bool, optional): Whether to use FP16 precision. Defaults to False. 57 | bf16 (bool, optional): Whether to use BF16 precision. Defaults to False. 58 | 59 | Returns: 60 | The DeepSpeed config for training. 
61 | """ 62 | assert offload in {'none', 'parameter', 'optimizer', 'all'} 63 | 64 | with TRAIN_TEMPLATE_FILE.open(mode='rt', encoding='utf-8') as f: 65 | train_config = json.load(f) 66 | 67 | word_size = dist.get_world_size() if dist.is_initialized() else 1 68 | train_batch_size = micro_batch_size_per_gpu * word_size * gradient_accumulation_steps 69 | 70 | train_config['train_batch_size'] = train_batch_size 71 | train_config['train_micro_batch_size_per_gpu'] = micro_batch_size_per_gpu 72 | train_config['gradient_accumulation_steps'] = gradient_accumulation_steps 73 | train_config['zero_optimization']['stage'] = stage 74 | if offload in {'parameter', 'all'}: 75 | train_config['zero_optimization'].setdefault('offload_param', {}) 76 | train_config['zero_optimization']['offload_param']['device'] = 'cpu' 77 | if offload in {'optimizer', 'all'}: 78 | train_config['zero_optimization'].setdefault('offload_optimizer', {}) 79 | train_config['zero_optimization']['offload_optimizer']['device'] = 'cpu' 80 | train_config['hybrid_engine']['enabled'] = enable_hybrid_engine 81 | train_config['hybrid_engine']['max_out_tokens'] = max_length 82 | if fp16 or 'fp16' in train_config: 83 | train_config.setdefault('fp16', {}) 84 | train_config['fp16']['enabled'] = fp16 85 | if bf16 or 'bf16' in train_config: 86 | train_config.setdefault('bf16', {}) 87 | train_config['bf16']['enabled'] = bf16 88 | return train_config 89 | 90 | 91 | def get_deepspeed_eval_config( 92 | *, 93 | stage: int = 3, 94 | offload: Literal['none', 'parameter', 'optimizer', 'all'] = 'none', 95 | fp16: bool = False, 96 | bf16: bool = False, 97 | ) -> dict[str, Any]: 98 | """Get the DeepSpeed config for evaluation. 99 | 100 | Args: 101 | stage (int, optional): The stage of ZeRO. Defaults to 3. 102 | offload (Literal['none', 'parameter', 'optimizer', 'all'], optional): The offload mode. 103 | fp16 (bool, optional): Whether to use FP16 precision. Defaults to False. 104 | bf16 (bool, optional): Whether to use BF16 precision. Defaults to False. 105 | 106 | Returns: 107 | The DeepSpeed config for evaluation. 
108 | """ 109 | assert offload in {'none', 'parameter', 'optimizer', 'all'} 110 | 111 | with EVAL_TEMPLATE_FILE.open(mode='rt', encoding='utf-8') as f: 112 | eval_config = json.load(f) 113 | 114 | if stage in {1, 2}: 115 | # The evaluation config only works for ZeRO stage 0 and ZeRO stage 3 116 | stage = 0 117 | 118 | eval_config['train_batch_size'] = None 119 | eval_config['train_micro_batch_size_per_gpu'] = 1 120 | eval_config['gradient_accumulation_steps'] = 1 121 | eval_config['zero_optimization']['stage'] = stage 122 | if offload in {'parameter', 'all'}: 123 | eval_config['zero_optimization'].setdefault('offload_param', {}) 124 | eval_config['zero_optimization']['offload_param']['device'] = 'cpu' 125 | if fp16 or 'fp16' in eval_config: 126 | eval_config.setdefault('fp16', {}) 127 | eval_config['fp16']['enabled'] = fp16 128 | if bf16 or 'bf16' in eval_config: 129 | eval_config.setdefault('bf16', {}) 130 | eval_config['bf16']['enabled'] = bf16 131 | return eval_config 132 | -------------------------------------------------------------------------------- /mcts_rl/configs/ds_eval_config_template.json: -------------------------------------------------------------------------------- 1 | { 2 | "train_batch_size": null, 3 | "train_micro_batch_size_per_gpu": 1, 4 | "gradient_accumulation_steps": 1, 5 | "steps_per_print": 10, 6 | "zero_optimization": { 7 | "stage": 3, 8 | "offload_param": { 9 | "device": "none" 10 | }, 11 | "param_persistence_threshold": 1e4, 12 | "memory_efficient_linear": false 13 | }, 14 | "gradient_clipping": 1.0, 15 | "prescale_gradients": false, 16 | "wall_clock_breakdown": false 17 | } 18 | -------------------------------------------------------------------------------- /mcts_rl/configs/ds_train_config_template.json: -------------------------------------------------------------------------------- 1 | { 2 | "train_batch_size": 128, 3 | "train_micro_batch_size_per_gpu": 16, 4 | "gradient_accumulation_steps": "auto", 5 | "steps_per_print": 10, 6 | "zero_optimization": { 7 | "stage": 2, 8 | "offload_param": { 9 | "device": "none" 10 | }, 11 | "offload_optimizer": { 12 | "device": "none" 13 | }, 14 | "param_persistence_threshold": 1e4, 15 | "max_live_parameters": 3e7, 16 | "prefetch_bucket_size": 3e7, 17 | "memory_efficient_linear": false, 18 | "gather_16bit_weights_on_model_save": true, 19 | "overlap_comm": true, 20 | "contiguous_gradients": true, 21 | "sub_group_size": 1e9, 22 | "reduce_bucket_size": "auto" 23 | }, 24 | "gradient_clipping": 1.0, 25 | "prescale_gradients": false, 26 | "wall_clock_breakdown": false, 27 | "hybrid_engine": { 28 | "enabled": true, 29 | "max_out_tokens": 512, 30 | "inference_tp_size": 1, 31 | "release_inference_cache": false, 32 | "pin_parameters": true, 33 | "tp_gather_partition_size": 8 34 | } 35 | } 36 | -------------------------------------------------------------------------------- /mcts_rl/configs/fsdp_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "fsdp_transformer_layer_cls_to_wrap": "LlamaDecoderLayer" 3 | } 4 | -------------------------------------------------------------------------------- /mcts_rl/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 PKU-Alignment Team. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Dataset classes.""" 16 | 17 | from __future__ import annotations 18 | 19 | from typing import Dict 20 | 21 | import torch 22 | from torch.utils.data import Dataset 23 | 24 | from mcts_rl.datasets import raw 25 | from mcts_rl.datasets.base import ( 26 | CollatorBase, 27 | RawDataset, 28 | RawSample, 29 | TokenizedDataset, 30 | parse_dataset, 31 | ) 32 | from mcts_rl.datasets.preference import ( 33 | PreferenceBatch, 34 | PreferenceCollator, 35 | PreferenceDataset, 36 | PreferenceSample, 37 | ) 38 | from mcts_rl.datasets.prompt_only import ( 39 | PromptOnlyBatch, 40 | PromptOnlyCollator, 41 | PromptOnlyDataset, 42 | PromptOnlySample, 43 | PromptOnlyPostDataset, 44 | PromptOnlyPostCollator, 45 | PromptOnlyPostSample, 46 | PromptOnlyPostBatch, 47 | ) 48 | from mcts_rl.datasets.raw import * # noqa: F403 49 | from mcts_rl.datasets.safety_preference import ( 50 | SafetyPreferenceBatch, 51 | SafetyPreferenceCollator, 52 | SafetyPreferenceDataset, 53 | SafetyPreferenceSample, 54 | ) 55 | from mcts_rl.datasets.supervised import ( 56 | SupervisedBatch, 57 | SupervisedCollator, 58 | SupervisedDataset, 59 | SupervisedSample, 60 | ) 61 | 62 | 63 | __all__ = [ 64 | 'DummyDataset', 65 | 'parse_dataset', 66 | 'RawDataset', 67 | 'RawSample', 68 | 'TokenizedDataset', 69 | 'CollatorBase', 70 | 'PreferenceDataset', 71 | 'PreferenceSample', 72 | 'PreferenceBatch', 73 | 'PreferenceCollator', 74 | 'PromptOnlyDataset', 75 | 'PromptOnlyCollator', 76 | 'PromptOnlySample', 77 | 'PromptOnlyBatch', 78 | 'PromptOnlyPostDataset', 79 | 'PromptOnlyPostCollator', 80 | 'PromptOnlyPostSample', 81 | 'PromptOnlyPostBatch', 82 | 'SafetyPreferenceDataset', 83 | 'SafetyPreferenceCollator', 84 | 'SafetyPreferenceSample', 85 | 'SafetyPreferenceBatch', 86 | 'SupervisedDataset', 87 | 'SupervisedCollator', 88 | 'SupervisedSample', 89 | 'SupervisedBatch', 90 | *raw.__all__, 91 | ] 92 | 93 | 94 | class DummyDataset(Dataset[Dict[str, torch.Tensor]]): 95 | def __init__(self, length: int) -> None: 96 | self.length = length 97 | 98 | def __len__(self) -> int: 99 | return self.length 100 | 101 | def __getitem__(self, index: int) -> dict[str, torch.Tensor]: 102 | return {} 103 | -------------------------------------------------------------------------------- /mcts_rl/datasets/preference.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 PKU-Alignment Team. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """Dataset class for preference training.""" 16 | 17 | from __future__ import annotations 18 | 19 | from typing import Callable 20 | from typing_extensions import TypedDict # Python 3.10+ 21 | 22 | import torch 23 | 24 | from mcts_rl.datasets.base import CollatorBase, RawSample, TokenizedDataset 25 | from mcts_rl.datasets.utils import format_prompt, right_padding 26 | 27 | 28 | __all__ = [ 29 | 'PreferenceDataset', 30 | 'PreferenceCollator', 31 | 'PreferenceSample', 32 | 'PreferenceBatch', 33 | ] 34 | 35 | 36 | class PreferenceSample(TypedDict, total=True): 37 | better_input_ids: torch.LongTensor # size = (L,) 38 | worse_input_ids: torch.LongTensor # size = (L,) 39 | 40 | 41 | class PreferenceBatch(TypedDict, total=True): 42 | better_input_ids: torch.LongTensor # size = (B, L) 43 | better_attention_mask: torch.BoolTensor # size = (B, L) 44 | 45 | worse_input_ids: torch.LongTensor # size = (B, L) 46 | worse_attention_mask: torch.BoolTensor # size = (B, L) 47 | 48 | 49 | class PreferenceDataset(TokenizedDataset): 50 | def preprocess(self, raw_sample: RawSample) -> PreferenceSample: 51 | prompt = format_prompt(input=raw_sample['input'], eos_token=self.tokenizer.eos_token, use_mcq=self.use_mcq) 52 | better_answer = raw_sample['answer'] 53 | worse_answer = raw_sample['other_answer'] 54 | better = raw_sample['better'] 55 | if not better: 56 | better_answer, worse_answer = worse_answer, better_answer 57 | 58 | better_input_ids = self.tokenize(prompt + better_answer + self.tokenizer.eos_token) 59 | worse_input_ids = self.tokenize(prompt + worse_answer + self.tokenizer.eos_token) 60 | if ( 61 | better_input_ids.size() == worse_input_ids.size() 62 | and torch.all(torch.eq(better_input_ids, worse_input_ids)).item() 63 | ): 64 | raise ValueError( 65 | 'Two responses get the same `input_ids` after tokenization.\n\n' 66 | f'Prompt: {prompt}\n\n' 67 | f'Better answer: {better_answer}\n\n' 68 | f'Worse answer: {worse_answer}', 69 | ) 70 | return { 71 | 'better_input_ids': better_input_ids, # size = (L,) 72 | 'worse_input_ids': worse_input_ids, # size = (L,) 73 | } 74 | 75 | def get_collator(self) -> Callable[[list[dict[str, torch.Tensor]]], dict[str, torch.Tensor]]: 76 | return PreferenceCollator(self.tokenizer.pad_token_id) 77 | 78 | 79 | class PreferenceCollator(CollatorBase): 80 | def __call__(self, samples: list[PreferenceSample]) -> PreferenceBatch: 81 | input_ids = [sample['better_input_ids'] for sample in samples] + [ 82 | sample['worse_input_ids'] for sample in samples 83 | ] # size = (2 * B, L) 84 | attention_mask = [ 85 | input_id.new_ones(input_id.size(), dtype=torch.bool) for input_id in input_ids 86 | ] # size = (2 * B, L) 87 | 88 | input_ids = right_padding(input_ids, padding_value=self.pad_token_id) # size = (2 * B, L) 89 | attention_mask = right_padding(attention_mask, padding_value=0) # size = (2 * B, L) 90 | 91 | ( 92 | better_input_ids, # size = (B, L) 93 | worse_input_ids, # size = (B, L) 94 | ) = input_ids.chunk(chunks=2, dim=0) 95 | ( 96 | better_attention_mask, # size = (B, L) 97 | worse_attention_mask, # size = (B, L) 98 | ) = attention_mask.chunk(chunks=2, dim=0) 99 | 100 | return { 101 | 'better_input_ids': better_input_ids, # size = (B, L) 102 | 'better_attention_mask': better_attention_mask, # size = (B, L) 103 | 'worse_input_ids': worse_input_ids, # size = (B, L) 104 | 'worse_attention_mask': worse_attention_mask, # size = (B, L) 105 | } 106 | 
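107 | 
108 | # Illustrative usage sketch (not part of the training pipeline): the collator
109 | # right-pads the concatenated better/worse sequences to a shared length and
110 | # splits them back into two aligned (B, L) halves. With a pad token id of 0:
111 | #
112 | #     collator = PreferenceCollator(0)  # pad token id, passed as in get_collator()
113 | #     batch = collator([{'better_input_ids': torch.tensor([5, 6, 7]),
114 | #                        'worse_input_ids': torch.tensor([5, 8])}])
115 | #     batch['better_input_ids']      # tensor([[5, 6, 7]])
116 | #     batch['worse_input_ids']       # tensor([[5, 8, 0]])
117 | #     batch['worse_attention_mask']  # tensor([[ True,  True, False]])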
-------------------------------------------------------------------------------- /mcts_rl/datasets/prompt_only.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 PKU-Alignment Team. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | from __future__ import annotations 17 | 18 | from typing import Callable, Hashable 19 | from typing_extensions import TypedDict # Python 3.10+ 20 | 21 | import torch 22 | from torch.utils.data import Dataset, Subset 23 | 24 | from mcts_rl.datasets.base import CollatorBase, RawSample, RawSamplePost, TokenizedDataset 25 | from mcts_rl.datasets.utils import format_prompt, left_padding 26 | 27 | 28 | __all__ = [ 29 | 'PromptOnlyDataset', 30 | 'PromptOnlyCollator', 31 | 'PromptOnlySample', 32 | 'PromptOnlyBatch', 33 | 'PromptOnlyPostDataset', 34 | 'PromptOnlyPostCollator', 35 | 'PromptOnlyPostSample', 36 | 'PromptOnlyPostBatch', 37 | ] 38 | 39 | 40 | class PromptOnlySample(TypedDict, total=True): 41 | input_ids: torch.LongTensor # size = (L,) 42 | 43 | 44 | class PromptOnlyBatch(TypedDict, total=True): 45 | input_ids: torch.LongTensor # size = (B, L) 46 | attention_mask: torch.BoolTensor # size = (B, L) 47 | 48 | 49 | class PromptOnlyDataset(TokenizedDataset): 50 | def preprocess(self, raw_sample: RawSample) -> PromptOnlySample: 51 | prompt = format_prompt( 52 | input=raw_sample['input'], eos_token=self.tokenizer.eos_token, 53 | use_mcq=self.use_mcq, few_shot=self.few_shot, 54 | model_type=self.model_type, 55 | ) 56 | input_ids = self.tokenize(prompt) 57 | return { 58 | 'input_ids': input_ids, # size = (L,) 59 | 'answer': raw_sample.get('final_answer', ''), # str 60 | 'reasoning': raw_sample.get('answer', ''), 61 | 'answer_content': raw_sample.get('final_answer_content', raw_sample.get('final_answer', '')) 62 | } 63 | 64 | def get_collator(self) -> Callable[[list[dict[str, torch.Tensor]]], dict[str, torch.Tensor]]: 65 | return PromptOnlyCollator(self.tokenizer.pad_token_id) 66 | 67 | def _merge_raw_datasets(self, seed: int | None = None) -> Dataset[RawSample]: 68 | """Merge multiple raw datasets into one dataset and remove duplicates.""" 69 | 70 | def to_hashable(raw_sample: RawSample) -> Hashable: 71 | input = raw_sample['input'] # pylint: disable=redefined-builtin 72 | return input if isinstance(input, str) else tuple(input) 73 | 74 | merged = super()._merge_raw_datasets(seed) 75 | inputs = {to_hashable(merged[i]): i for i in range(len(merged)) if isinstance(merged[i]['input'], str) or len(merged[i]['input']) == 1} 76 | return Subset(merged, sorted(inputs.values())) 77 | 78 | 79 | class PromptOnlyCollator(CollatorBase): 80 | def __call__(self, samples: list[PromptOnlySample]) -> PromptOnlyBatch: 81 | input_ids = [sample['input_ids'] for sample in samples] 82 | attention_mask = [ 83 | 
input_id.new_ones(input_id.size(), dtype=torch.bool) for input_id in input_ids 84 | ] 85 | 86 | input_ids = left_padding(input_ids, padding_value=self.pad_token_id) 87 | attention_mask = left_padding(attention_mask, padding_value=0) 88 | return { 89 | 'input_ids': input_ids, # size = (B, L) 90 | 'attention_mask': attention_mask, # size = (B, L) 91 | 'answer': [sample['answer'] for sample in samples], 92 | 'reasoning': [sample['reasoning'] for sample in samples], 93 | 'answer_content': [sample['answer_content'] for sample in samples], 94 | } 95 | 96 | 97 | class PromptOnlyPostSample(TypedDict, total=True): 98 | prompts_list: list[torch.LongTensor] 99 | input_ids_list: list[torch.LongTensor] 100 | answer: str 101 | answer_content: str 102 | 103 | 104 | class PromptOnlyPostBatch(TypedDict, total=True): 105 | prompts_list: list[torch.LongTensor] 106 | input_ids_list: list[torch.LongTensor] 107 | answer: list[str] 108 | answer_content: list[str] 109 | 110 | 111 | class PromptOnlyPostDataset(TokenizedDataset): 112 | def preprocess(self, raw_sample: RawSamplePost) -> PromptOnlyPostSample: 113 | return { 114 | 'prompts_list': [raw_sample['prompt']], 115 | 'input_ids_list': raw_sample['input_ids_list'], 116 | 'answer': raw_sample['final_answer'], # str 117 | 'answer_content': raw_sample.get('final_answer_content', raw_sample['final_answer']) 118 | } 119 | 120 | def get_collator(self) -> Callable[[list[dict[str, torch.Tensor]]], dict[str, torch.Tensor]]: 121 | return PromptOnlyPostCollator(self.tokenizer.pad_token_id) 122 | 123 | def _merge_raw_datasets(self, seed: int | None = None) -> Dataset[RawSamplePost]: 124 | """Merge multiple raw datasets into one dataset and remove duplicates.""" 125 | 126 | def to_hashable(raw_sample: RawSamplePost) -> Hashable: 127 | input = raw_sample['prompt'] # pylint: disable=redefined-builtin 128 | return input if isinstance(input, str) else tuple(input.tolist()) 129 | 130 | merged = super()._merge_raw_datasets(seed) 131 | inputs = {to_hashable(merged[i]): i for i in range(len(merged))} 132 | return Subset(merged, sorted(inputs.values())) 133 | 134 | 135 | class PromptOnlyPostCollator(CollatorBase): 136 | def __call__(self, samples: list[PromptOnlyPostSample]) -> PromptOnlyPostBatch: 137 | prompts_list = [sample['prompts_list'] for sample in samples] 138 | input_ids_list = [sample['input_ids_list'] for sample in samples] 139 | attention_mask_list = [[ 140 | input_ids.not_equal(self.pad_token_id) for input_ids in sample['input_ids_list'] 141 | ] for sample in samples] 142 | 143 | return { 144 | 'prompts_list': prompts_list, 145 | 'input_ids_list': input_ids_list, 146 | 'attention_mask_list': attention_mask_list, 147 | 'answer': [sample['answer'] for sample in samples], 148 | 'answer_content': [sample['answer_content'] for sample in samples], 149 | } 150 | -------------------------------------------------------------------------------- /mcts_rl/datasets/raw/__init__.py: -------------------------------------------------------------------------------- 1 | """Raw datasets.""" 2 | 3 | from mcts_rl.datasets.raw.alpaca import AlpacaDataset 4 | from mcts_rl.datasets.raw.firefly import FireflyDataset 5 | from mcts_rl.datasets.raw.hh_rlhf import ( 6 | HhRLHFDialogueDataset, 7 | HhRLHFHarmlessDialogueDataset, 8 | HhRLHFHelpfulDialogueDataset, 9 | ) 10 | from mcts_rl.datasets.raw.moss import MOSS002SFT, MOSS003SFT 11 | from mcts_rl.datasets.raw.safe_rlhf import ( 12 | SafeRLHF10KTrainDataset, 13 | SafeRLHFDataset, 14 | SafeRLHFTestDataset, 15 | SafeRLHFTrainDataset, 16 | ) 17 | 
from mcts_rl.datasets.raw.prm800k import ( 18 | PRM800KDataset, 19 | PRM800KTestDataset, 20 | PRM800KTrainDataset, 21 | ) 22 | from mcts_rl.datasets.raw.mcq import ( 23 | MCQDataset, 24 | SQATestDataset, 25 | SQATrainDataset, 26 | CSRTestDataset, 27 | CSRTrainDataset, 28 | SciQTestDataset, 29 | NLITestDataset, 30 | MCQTestDataset, 31 | MCQTrainDataset, 32 | ) 33 | from mcts_rl.datasets.raw.math import ( 34 | MATHDataset, 35 | MATHTestDataset, 36 | MATHTrainDataset, 37 | MATHSFTTrainDataset, 38 | MATHSFTTestDataset, 39 | ) 40 | from mcts_rl.datasets.raw.gsm8k import ( 41 | GSM8KDataset, 42 | GSM8KTestDataset, 43 | GSM8KTrainDataset, 44 | GSM8KPoTTestDataset, 45 | GSM8KPoTTrainDataset, 46 | GSM8KSFTTrainDataset, 47 | GSM8KSFTTestDataset, 48 | ) 49 | from mcts_rl.datasets.raw.arithmo import ( 50 | ArithmoDataset, 51 | ArithmoTestDataset, 52 | ArithmoTrainDataset, 53 | ArithmoMATHTrainDataset, 54 | ArithmoMCQTrainDataset, 55 | ArithmoCodeTrainDataset, 56 | ) 57 | from mcts_rl.datasets.raw.qa_feedback import ( 58 | QAFBDataset, 59 | QAFBTestDataset, 60 | QAFBTrainDataset, 61 | ) 62 | from mcts_rl.datasets.raw.mcq_pairs import ( 63 | MCQPreferenceDataset, 64 | SQAPreferenceTestDataset, 65 | SQAPreferenceTrainDataset, 66 | CSRPreferenceTestDataset, 67 | CSRPreferenceTrainDataset, 68 | GSMPreferenceTrainDataset, 69 | GSMPreferenceTestDataset, 70 | ) 71 | from mcts_rl.datasets.raw.mcq_for_eval import ( 72 | MCQEvalDataset, 73 | SQAEvalTestDataset, 74 | SQAEvalTrainDataset, 75 | CSREvalTestDataset, 76 | CSREvalTrainDataset, 77 | GSMEvalTestDataset, 78 | GSMEvalTrainDataset, 79 | ) 80 | from mcts_rl.datasets.raw.math_qa import ( 81 | MathQADataset, 82 | MathQATestDataset, 83 | MathQATrainDataset, 84 | MathQACodeTestDataset, 85 | MathQACodeTrainDataset, 86 | MathQAAllTrainDataset, 87 | MathQAAllTestDataset, 88 | MathQASFTTrainDataset, 89 | ) 90 | from mcts_rl.datasets.raw.aqua import ( 91 | AQuADataset, 92 | AQuAPoTTestDataset, 93 | AQuATestDataset, 94 | ) 95 | from mcts_rl.datasets.raw.exam import ( 96 | ExamTestDataset, 97 | ExamDataset, 98 | ) 99 | 100 | 101 | __all__ = [ 102 | 'AlpacaDataset', 103 | 'FireflyDataset', 104 | 'HhRLHFDialogueDataset', 105 | 'HhRLHFHarmlessDialogueDataset', 106 | 'HhRLHFHelpfulDialogueDataset', 107 | 'MOSS002SFT', 108 | 'MOSS003SFT', 109 | 'SafeRLHFDataset', 110 | 'SafeRLHFTrainDataset', 111 | 'SafeRLHFTestDataset', 112 | 'SafeRLHF10KTrainDataset', 113 | 'PRM800KDataset', 114 | 'PRM800KTrainDataset', 115 | 'PRM800KTestDataset', 116 | 'MCQDataset', 117 | 'SQATestDataset', 118 | 'SQATrainDataset', 119 | 'CSRTestDataset', 120 | 'CSRTrainDataset', 121 | 'SciQTestDataset', 122 | 'NLITestDataset', 123 | 'MATHDataset', 124 | 'MATHTrainDataset', 125 | 'MATHTestDataset', 126 | 'MathQAAllTrainDataset', 127 | 'GSM8KDataset', 128 | 'GSM8KTestDataset', 129 | 'GSM8KTrainDataset', 130 | 'GSM8KPoTTestDataset', 131 | 'GSM8KPoTTrainDataset', 132 | 'ArithmoDataset', 133 | 'ArithmoTestDataset', 134 | 'ArithmoTrainDataset', 135 | 'ArithmoMATHTrainDataset', 136 | 'ArithmoMCQTrainDataset', 137 | 'ArithmoCodeTrainDataset', 138 | 'QAFBDataset', 139 | 'QAFBTestDataset', 140 | 'QAFBTrainDataset', 141 | 'MCQPreferenceDataset', 142 | 'SQAPreferenceTestDataset', 143 | 'SQAPreferenceTrainDataset', 144 | 'CSRPreferenceTestDataset', 145 | 'CSRPreferenceTrainDataset', 146 | 'MCQTestDataset', 147 | 'MCQTrainDataset', 148 | 'GSMPreferenceTrainDataset', 149 | 'GSMPreferenceTestDataset', 150 | 'MCQEvalDataset', 151 | 'SQAEvalTestDataset', 152 | 'SQAEvalTrainDataset', 153 | 'CSREvalTestDataset', 154 | 
'CSREvalTrainDataset', 155 | 'GSMEvalTestDataset', 156 | 'GSMEvalTrainDataset', 157 | 'MathQADataset', 158 | 'MathQATestDataset', 159 | 'MathQATrainDataset', 160 | 'MathQAAllTestDataset', 161 | 'MathQACodeTestDataset', 162 | 'MathQACodeTrainDataset', 163 | 'AQuADataset', 164 | 'AQuAPoTTestDataset', 165 | 'AQuATestDataset', 166 | 'ExamTestDataset', 167 | 'ExamDataset', 168 | 'MathQASFTTrainDataset', 169 | 'MATHSFTTrainDataset', 170 | 'MATHSFTTestDataset', 171 | 'GSM8KSFTTrainDataset', 172 | 'GSM8KSFTTestDataset', 173 | ] 174 | -------------------------------------------------------------------------------- /mcts_rl/datasets/raw/alpaca.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 PKU-Alignment Team. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Stanford Alpaca dataset for supervised instruction fine-tuning.""" 16 | 17 | from __future__ import annotations 18 | 19 | from datasets import load_dataset 20 | from mcts_rl.datasets.base import RawDataset, RawSample 21 | 22 | 23 | __all__ = ['AlpacaDataset'] 24 | 25 | 26 | class AlpacaDataset(RawDataset): 27 | NAME: str = 'alpaca' 28 | ALIASES: tuple[str, ...] 
= ('stanford-alpaca',) 29 | 30 | def __init__(self, path: str | None = None) -> None: 31 | alpaca = load_dataset(path or 'tatsu-lab/alpaca', split='train') 32 | self.data = [] 33 | # for data in load_dataset('McGill-NLP/feedbackQA', split='train'): 34 | # question = data['question'] 35 | # answer = data['answer'].replace('\n', ' ') 36 | # comment = '\n'.join([f'{r}: {e}' for r, e in zip(data['feedback']['rating'], data['feedback']['explanation'])]) 37 | # if ('Excellent' in comment or 'Acceptable' in comment) and 'Bad' not in comment: 38 | # self.data.append({'instruction': question, 'input': '', 'output': answer}) 39 | # self.data += list(alpaca)[:len(self.data) * 7] 40 | safe_data = {} 41 | for data in load_dataset('PKU-Alignment/PKU-SafeRLHF', split='train'): 42 | question = data['prompt'] 43 | idx = data['better_response_id'] 44 | if data[f'is_response_{idx}_safe']: 45 | answer = data[f'response_{idx}'] 46 | if question not in safe_data or data['safer_response_id'] == idx: 47 | safe_data[question] = {'instruction': question, 'input': '', 'output': answer} 48 | self.data = list(safe_data.values()) + list(alpaca)[:len(self.data) * 7] 49 | 50 | def __getitem__(self, index: int) -> RawSample: 51 | data = self.data[index] 52 | input = ( # pylint: disable=redefined-builtin 53 | ' '.join((data['instruction'], data['input'])) if data['input'] else data['instruction'] 54 | ) 55 | answer = data['output'] 56 | return RawSample(input=input, answer=answer) 57 | 58 | def __len__(self) -> int: 59 | return len(self.data) 60 | -------------------------------------------------------------------------------- /mcts_rl/datasets/raw/aqua.py: -------------------------------------------------------------------------------- 1 | """MATH datasets.""" 2 | 3 | from __future__ import annotations 4 | 5 | import os 6 | from typing import ClassVar 7 | 8 | from mcts_rl.datasets.base import RawDataset, RawSample, jsonlines_load 9 | 10 | 11 | __all__ = [ 12 | 'AQuADataset', 13 | 'AQuATestDataset', 14 | 'AQuAPoTTestDataset', 15 | ] 16 | 17 | DATA_DIR = "path_to_dataset_folder" 18 | 19 | 20 | class AQuADataset(RawDataset): 21 | SPLIT: ClassVar[str] 22 | PTYPE: ClassVar[str] 23 | 24 | def __init__(self) -> None: 25 | self.data = jsonlines_load(os.path.join(DATA_DIR, f'auqa/aqua_{self.SPLIT}.jsonl')) 26 | 27 | def __getitem__(self, index: int) -> RawSample: 28 | data = self.data[index] 29 | question = data['question'] + '\nAnswer Choices: (' + ' ('.join(data['options']) 30 | prompt = question + '\nWrite a Python program to solve this.' 
--------------------------------------------------------------------------------
/mcts_rl/datasets/raw/aqua.py:
--------------------------------------------------------------------------------
1 | """AQuA datasets."""
2 |
3 | from __future__ import annotations
4 |
5 | import os
6 | from typing import ClassVar
7 |
8 | from mcts_rl.datasets.base import RawDataset, RawSample, jsonlines_load
9 |
10 |
11 | __all__ = [
12 |     'AQuADataset',
13 |     'AQuATestDataset',
14 |     'AQuAPoTTestDataset',
15 | ]
16 |
17 | DATA_DIR = "path_to_dataset_folder"
18 |
19 |
20 | class AQuADataset(RawDataset):
21 |     SPLIT: ClassVar[str]
22 |     PTYPE: ClassVar[str]
23 |
24 |     def __init__(self) -> None:
25 |         self.data = jsonlines_load(os.path.join(DATA_DIR, f'aqua/aqua_{self.SPLIT}.jsonl'))
26 |
27 |     def __getitem__(self, index: int) -> RawSample:
28 |         data = self.data[index]
29 |         question = data['question'] + '\nAnswer Choices: (' + ' ('.join(data['options'])
30 |         prompt = question + '\nWrite a Python program to solve this.' if self.PTYPE == 'pot' else question
31 |         return RawSample(
32 |             input=prompt,
33 |             answer=data['rationale'],
34 |             final_answer=data.get('correct', None),
35 |             final_answer_content=data.get('correct', None),
36 |         )
37 |
38 |     def __len__(self) -> int:
39 |         return len(self.data)
40 |
41 |
42 | class AQuATestDataset(AQuADataset):
43 |     NAME: str = 'AQuA/test'
44 |     SPLIT: str = 'test'
45 |     PTYPE: str = 'cot'
46 |
47 |
48 | class AQuAPoTTestDataset(AQuADataset):
49 |     NAME: str = 'AQuACode/test'
50 |     SPLIT: str = 'test'
51 |     PTYPE: str = 'pot'
--------------------------------------------------------------------------------
/mcts_rl/datasets/raw/arithmo.py:
--------------------------------------------------------------------------------
1 | """Arithmo math instruction datasets."""
2 | from __future__ import annotations
3 |
4 | import os
5 | import regex
6 | from typing import ClassVar
7 |
8 | from datasets import load_dataset
9 | from mcts_rl.datasets.base import RawDataset, RawSample, jsonlines_load
10 |
11 |
12 | __all__ = [
13 |     'ArithmoDataset',
14 |     'ArithmoTrainDataset',
15 |     'ArithmoTestDataset',
16 |     'ArithmoMATHTrainDataset',
17 |     'ArithmoMCQTrainDataset',
18 |     'ArithmoCodeTrainDataset',
19 | ]
20 |
21 | DATA_DIR = "path_to_dataset_folder"
22 |
23 | class ArithmoDataset(RawDataset):
24 |     SPLIT: ClassVar[str]
25 |     PATH: ClassVar[str]
26 |     TYPE: ClassVar[str]
27 |
28 |     def __init__(self, path: str | None = None) -> None:
29 |         try:
30 |             self.data = load_dataset(path or self.PATH, split=self.SPLIT)
31 |         except Exception:  # fall back to a local copy when the Hub load fails
32 |             self.data = jsonlines_load(os.path.join(DATA_DIR, f'arithmo/{self.SPLIT}.jsonl'))
33 |         if self.TYPE == 'math':
34 |             self.data = [dt for dt in self.data if ' answer is' in dt['answer']]
35 |         elif self.TYPE == 'mcq':
36 |             self.data = [
37 |                 dt for dt in self.data if ' answer is' in dt['answer'] and not dt['answer'].startswith('The answer is') \
38 |                     and 'answer choices' in dt['question'].lower()
39 |             ]
40 |         elif self.TYPE == 'code':
41 |             self.data = [dt for dt in self.data if regex.search(r'print\(.+\)', dt['answer'])]
42 |
43 |     def __getitem__(self, index: int) -> RawSample:
44 |         data = self.data[index]
45 |         return RawSample(
46 |             input=data['question'],
47 |             answer=data['answer'],
48 |         )
49 |
50 |     def __len__(self) -> int:
51 |         return len(self.data)
52 |
53 |
54 | class ArithmoTrainDataset(ArithmoDataset):
55 |     NAME: str = 'Arithmo/train'
56 |     ALIASES: tuple[str, ...] = ('akjindal53244/Arithmo-Data/train',)
57 |     PATH: str = 'akjindal53244/Arithmo-Data'
58 |     SPLIT: str = 'train'
59 |     TYPE: str = 'all'
60 |
61 |
62 | class ArithmoTestDataset(ArithmoDataset):
63 |     NAME: str = 'Arithmo/test'
64 |     ALIASES: tuple[str, ...] = ('akjindal53244/Arithmo-Data/test',)
65 |     PATH: str = 'akjindal53244/Arithmo-Data'
66 |     SPLIT: str = 'test'
67 |     TYPE: str = 'all'
68 |
69 |
70 | class ArithmoMATHTrainDataset(ArithmoDataset):
71 |     NAME: str = 'ArithmoMATH/train'
72 |     ALIASES: tuple[str, ...] = ('akjindal53244/Arithmo-Data/train/mathqa',)
73 |     PATH: str = 'akjindal53244/Arithmo-Data'
74 |     SPLIT: str = 'train'
75 |     TYPE: str = 'math'
76 |
77 |
78 | class ArithmoMCQTrainDataset(ArithmoDataset):
79 |     NAME: str = 'ArithmoMCQ/train'
80 |     ALIASES: tuple[str, ...] = ('akjindal53244/Arithmo-Data/train/mcq',)
81 |     PATH: str = 'akjindal53244/Arithmo-Data'
82 |     SPLIT: str = 'train'
83 |     TYPE: str = 'mcq'
84 |
85 |
86 | class ArithmoCodeTrainDataset(ArithmoDataset):
87 |     NAME: str = 'ArithmoCode/train'
88 |     ALIASES: tuple[str, ...] = ('akjindal53244/Arithmo-Data/train/code',)
89 |     PATH: str = 'akjindal53244/Arithmo-Data'
90 |     SPLIT: str = 'train'
91 |     TYPE: str = 'code'
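
The TYPE-based filters above select subsets by surface cues: 'math' keeps rationales that state an answer, 'mcq' keeps those that reference answer choices without merely restating the answer, and 'code' keeps programs that print a result. A standalone illustration on hypothetical rows:

import regex  # the third-party `regex` module, as used above

samples = [
    {'question': 'What is 2 + 2? Answer Choices: (A) 3 (B) 4', 'answer': 'Adding gives 4, so the answer is B.'},
    {'question': 'Compute the sum of 2 and 2.', 'answer': 'total = 2 + 2\nprint(total)'},
]
code_subset = [dt for dt in samples if regex.search(r'print\(.+\)', dt['answer'])]
mcq_subset = [
    dt for dt in samples
    if ' answer is' in dt['answer'] and not dt['answer'].startswith('The answer is')
    and 'answer choices' in dt['question'].lower()
]
assert len(code_subset) == 1 and len(mcq_subset) == 1
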
--------------------------------------------------------------------------------
/mcts_rl/datasets/raw/exam.py:
--------------------------------------------------------------------------------
1 | """Hungarian national high-school finals exam dataset (evaluation only)."""
2 | from __future__ import annotations
3 |
4 | from typing import ClassVar
5 |
6 | from datasets import load_dataset
7 | from mcts_rl.datasets.base import RawDataset, RawSample
8 |
9 |
10 | __all__ = [
11 |     'ExamDataset', 'ExamTestDataset',
12 | ]
13 |
14 | class ExamDataset(RawDataset):
15 |     SPLIT: ClassVar[str]
16 |     PATH: ClassVar[str]
17 |
18 |     def __init__(self, path: str | None = None) -> None:
19 |         self.data = load_dataset(path or self.PATH, split=self.SPLIT)
20 |
21 |     def __getitem__(self, index: int) -> RawSample:
22 |         data = self.data[index]
23 |         return RawSample(
24 |             input=data['Question'],
25 |             answer=None,  # prompt-only evaluation data
26 |         )
27 |
28 |     def __len__(self) -> int:
29 |         return len(self.data)
30 |
31 |
32 | class ExamTestDataset(ExamDataset):
33 |     NAME: str = 'Exam/test'
34 |     PATH: str = 'keirp/hungarian_national_hs_finals_exam'
35 |     SPLIT: str = 'test'
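
Several raw-dataset modules in this package (aqua.py and arithmo.py above, mcq.py, mcq_pairs.py, and prm800k.py below) import `jsonlines_load` from `mcts_rl.datasets.base`, which is outside this listing. A minimal compatible implementation, shown only as a sketch of the expected contract:

import json

def jsonlines_load(path: str) -> list[dict]:
    """Read a .jsonl file: one JSON object per line, returned as a list of dicts."""
    with open(path, mode='rt', encoding='utf-8') as f:
        return [json.loads(line) for line in f if line.strip()]
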
--------------------------------------------------------------------------------
/mcts_rl/datasets/raw/firefly.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 PKU-Alignment Team. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Firefly (流萤) dataset for supervised instruction fine-tuning."""
16 |
17 | from __future__ import annotations
18 |
19 | from datasets import load_dataset
20 | from mcts_rl.datasets.base import RawDataset, RawSample
21 |
22 |
23 | __all__ = ['FireflyDataset']
24 |
25 |
26 | class FireflyDataset(RawDataset):
27 |     NAME: str = 'firefly'
28 |
29 |     def __init__(self, path: str | None = None) -> None:
30 |         self.data = load_dataset(path or 'YeungNLP/firefly-train-1.1M', split='train')
31 |
32 |     def __getitem__(self, index: int) -> RawSample:
33 |         data = self.data[index]
34 |         return RawSample(input=data['input'], answer=data['target'])
35 |
36 |     def __len__(self) -> int:
37 |         return len(self.data)
38 |
--------------------------------------------------------------------------------
/mcts_rl/datasets/raw/gsm8k.py:
--------------------------------------------------------------------------------
1 | """GSM8K datasets."""
2 |
3 | from __future__ import annotations
4 |
5 | import os
6 | from typing import ClassVar
7 | from datasets import load_dataset
8 |
9 | from mcts_rl.datasets.base import RawDataset, RawSample
10 | from mcts_rl.utils import extract_answer, list_to_dict, get_math_data
11 |
12 |
13 | __all__ = [
14 |     'GSM8KDataset',
15 |     'GSM8KTrainDataset',
16 |     'GSM8KTestDataset',
17 |     'GSM8KPoTTrainDataset',
18 |     'GSM8KPoTTestDataset',
19 |     'GSM8KSFTTrainDataset',
20 |     'GSM8KSFTTestDataset',
21 | ]
22 |
23 |
24 | class GSM8KDataset(RawDataset):
25 |     SPLIT: ClassVar[str]
26 |     PTYPE: ClassVar[str]
27 |     DTYPE: ClassVar[str]
28 |
29 |     def __init__(self) -> None:
30 |         if self.PTYPE != 'pot':
31 |             self.data = load_dataset('openai/gsm8k', 'main', split=self.SPLIT, trust_remote_code=True)
32 |         else:
33 |             raise ValueError('PoT is not supported for now.')
34 |         if self.DTYPE == 'arithmo':
35 |             gsm_dict = list_to_dict(self.data)
36 |             arithmo_dict = list_to_dict(get_math_data(load_dataset('akjindal53244/Arithmo-Data', split=self.SPLIT)))
37 |             arithmo = {k:v for k, v in arithmo_dict.items() if k in gsm_dict}
38 |             self.data = [vv for v in arithmo.values() for vv in v]
39 |             # self.data = get_arithmo_data(arithmo)
40 |
41 |     def __getitem__(self, index: int) -> RawSample:
42 |         data = self.data[index]
43 |         prompt = data['problem'] if 'problem' in data else data['question']
44 |         prompt = prompt + '\nWrite a Python program to solve this.'
if self.PTYPE == 'pot' else prompt 45 | solution = data['solution'] if 'solution' in data else data['answer'] 46 | answer = extract_answer(solution) 47 | if self.DTYPE == 'default': 48 | solution = f'{solution}\nThe answer is {answer}' 49 | return RawSample( 50 | input=prompt, 51 | answer=solution, 52 | final_answer=answer, 53 | final_answer_content=answer, 54 | ) 55 | 56 | def __len__(self) -> int: 57 | return len(self.data) 58 | 59 | 60 | class GSM8KSFTTrainDataset(GSM8KDataset): 61 | NAME: str = 'GSM8KSFT/train' 62 | SPLIT: str = 'train' 63 | PTYPE: str = 'cot' 64 | DTYPE: str = 'arithmo' 65 | 66 | 67 | class GSM8KSFTTestDataset(GSM8KDataset): 68 | NAME: str = 'GSM8KSFT/test' 69 | SPLIT: str = 'test' 70 | PTYPE: str = 'cot' 71 | DTYPE: str = 'arithmo' 72 | 73 | 74 | class GSM8KTrainDataset(GSM8KDataset): 75 | NAME: str = 'GSM8K/train' 76 | SPLIT: str = 'train' 77 | PTYPE: str = 'cot' 78 | DTYPE: str = 'default' 79 | 80 | 81 | class GSM8KTestDataset(GSM8KDataset): 82 | NAME: str = 'GSM8K/test' 83 | SPLIT: str = 'test' 84 | PTYPE: str = 'cot' 85 | DTYPE: str = 'default' 86 | 87 | 88 | class GSM8KPoTTrainDataset(GSM8KDataset): 89 | NAME: str = 'GSM8KCode/train' 90 | SPLIT: str = 'train' 91 | PTYPE: str = 'pot' 92 | DTYPE: str = 'default' 93 | 94 | 95 | class GSM8KPoTTestDataset(GSM8KDataset): 96 | NAME: str = 'GSM8KCode/test' 97 | SPLIT: str = 'test' 98 | PTYPE: str = 'pot' 99 | DTYPE: str = 'default' -------------------------------------------------------------------------------- /mcts_rl/datasets/raw/hh_rlhf.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 PKU-Alignment Team. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Helpful and Harmless Dialogue Datasets from Anthropic.""" 16 | 17 | from __future__ import annotations 18 | 19 | from typing import ClassVar 20 | 21 | from datasets import load_dataset 22 | from mcts_rl.datasets.base import RawDataset, RawSample 23 | 24 | 25 | __all__ = [ 26 | 'HhRLHFDialogueDataset', 27 | 'HhRLHFHarmlessDialogueDataset', 28 | 'HhRLHFHelpfulDialogueDataset', 29 | 'HhRLHFPreferenceDataset', 30 | 'HhRLHFHarmlessPreferenceTrainDataset', 31 | 'HhRLHFHarmlessPreferenceTestDataset', 32 | 'HhRLHFHelpfulPreferenceTrainDataset', 33 | 'HhRLHFHelpfulPreferenceTestDataset', 34 | ] 35 | 36 | 37 | class HhRLHFDialogueDataset(RawDataset): 38 | NAME: ClassVar[str] = 'hh-rlhf-dialogue' 39 | ALIASES: tuple[str, ...] 
= ('hh-dialogue',) 40 | DATA_DIR: ClassVar[str | None] = None 41 | 42 | def __init__(self, path: str | None = None) -> None: 43 | self.data = load_dataset( 44 | path or 'PKU-Alignment/processed-hh-rlhf', 45 | data_dir=self.DATA_DIR, 46 | split='train', 47 | ) 48 | 49 | def __getitem__(self, index: int) -> RawSample: 50 | data = self.data[index] 51 | dialogue = [content['text'] for content in data['context']] 52 | dialogue.append(data['chosen']['text']) 53 | return RawSample(dialogue=dialogue) 54 | 55 | def __len__(self) -> int: 56 | return len(self.data) 57 | 58 | 59 | class HhRLHFHarmlessDialogueDataset(HhRLHFDialogueDataset): 60 | NAME: str = 'hh-rlhf-harmless-dialogue' 61 | ALIASES: tuple[str, ...] = ( 62 | 'hh-rlhf-dialogue/harmless-base', 63 | 'hh-harmless-dialogue', 64 | 'hh-dialogue/harmless-base', 65 | ) 66 | DATA_DIR: str = 'harmless-base' 67 | 68 | 69 | class HhRLHFHelpfulDialogueDataset(HhRLHFDialogueDataset): 70 | NAME: str = 'hh-rlhf-helpful-dialogue' 71 | ALIASES: tuple[str, ...] = ( 72 | 'hh-rlhf-dialogue/helpful-base', 73 | 'hh-helpful-dialogue', 74 | 'hh-dialogue/helpful-base', 75 | ) 76 | DATA_DIR: str = 'helpful-base' 77 | 78 | 79 | class HhRLHFPreferenceDataset(RawDataset): 80 | NAME: ClassVar[str] = 'hh-rlhf-preference' 81 | ALIASES: tuple[str, ...] = ('hh-preference',) 82 | DATA_DIR: ClassVar[str | None] = None 83 | SPLIT: ClassVar[str] 84 | 85 | def __init__(self, path: str | None = None) -> None: 86 | self.data = load_dataset( 87 | path or 'PKU-Alignment/processed-hh-rlhf', 88 | data_dir=self.DATA_DIR, 89 | split=self.SPLIT, 90 | ) 91 | 92 | def __getitem__(self, index: int) -> RawSample: 93 | data = self.data[index] 94 | dialogue = [content['text'] for content in data['context']] 95 | answer = data['chosen']['text'] 96 | other_answer = data['rejected']['text'] 97 | 98 | return RawSample( 99 | input=dialogue, 100 | answer=answer, 101 | other_answer=other_answer, 102 | better=True, 103 | ) 104 | 105 | def __len__(self) -> int: 106 | return len(self.data) 107 | 108 | 109 | class HhRLHFHarmlessPreferenceTrainDataset(HhRLHFPreferenceDataset): 110 | NAME: str = 'hh-rlhf-harmless-preference/train' 111 | ALIASES: tuple[str, ...] = ( 112 | 'hh-rlhf-preference/harmless-base/train', 113 | 'hh-harmless-preference/train', 114 | 'hh-preference/harmless-base/train', 115 | ) 116 | DATA_DIR: str = 'harmless-base' 117 | SPLIT: str = 'train' 118 | 119 | 120 | class HhRLHFHarmlessPreferenceTestDataset(HhRLHFPreferenceDataset): 121 | NAME: str = 'hh-rlhf-harmless-preference/test' 122 | ALIASES: tuple[str, ...] = ( 123 | 'hh-rlhf-preference/harmless-base/test', 124 | 'hh-harmless-preference/test', 125 | 'hh-preference/harmless-base/test', 126 | ) 127 | DATA_DIR: str = 'harmless-base' 128 | SPLIT: str = 'test' 129 | 130 | 131 | class HhRLHFHelpfulPreferenceTrainDataset(HhRLHFPreferenceDataset): 132 | NAME: str = 'hh-rlhf-helpful-preference/train' 133 | ALIASES: tuple[str, ...] = ( 134 | 'hh-rlhf-preference/helpful-base/train', 135 | 'hh-helpful-preference/train', 136 | 'hh-preference/helpful-base/train', 137 | ) 138 | DATA_DIR: str = 'helpful-base' 139 | SPLIT: str = 'train' 140 | 141 | 142 | class HhRLHFHelpfulPreferenceTestDataset(HhRLHFPreferenceDataset): 143 | NAME: str = 'hh-rlhf-helpful-preference/test' 144 | ALIASES: tuple[str, ...] 
= ( 145 | 'hh-rlhf-preference/helpful-base/test', 146 | 'hh-helpful-preference/test', 147 | 'hh-preference/helpful-base/test', 148 | ) 149 | DATA_DIR: str = 'helpful-base' 150 | SPLIT: str = 'test' 151 | -------------------------------------------------------------------------------- /mcts_rl/datasets/raw/math.py: -------------------------------------------------------------------------------- 1 | """MATH datasets.""" 2 | 3 | from __future__ import annotations 4 | 5 | import os 6 | from typing import ClassVar 7 | from datasets import load_dataset 8 | 9 | from mcts_rl.datasets.base import RawDataset, RawSample 10 | from mcts_rl.utils import extract_answer, get_math_data, list_to_dict, get_arithmo_data 11 | 12 | 13 | __all__ = [ 14 | 'MATHDataset', 15 | 'MATHTrainDataset', 16 | 'MATHTestDataset', 17 | 'MATHSFTTrainDataset', 18 | 'MATHSFTTestDataset', 19 | ] 20 | 21 | 22 | class MATHDataset(RawDataset): 23 | SPLIT: ClassVar[str] 24 | DTYPE: ClassVar[str] 25 | 26 | def __init__(self) -> None: 27 | self.data = load_dataset('hendrycks/competition_math', split=self.SPLIT, trust_remote_code=True) 28 | if self.DTYPE == 'arithmo': 29 | math_dict = list_to_dict(self.data) 30 | arithmo_dict = list_to_dict(get_math_data(load_dataset('akjindal53244/Arithmo-Data', split=self.SPLIT))) 31 | arithmo = {k:v for k, v in arithmo_dict.items() if k in math_dict} 32 | self.data = [vv for v in arithmo.values() for vv in v] 33 | # self.data = get_arithmo_data(arithmo) 34 | 35 | 36 | def __getitem__(self, index: int) -> RawSample: 37 | data = self.data[index] 38 | solution = data['solution'] 39 | answer = extract_answer(solution) 40 | if self.DTYPE == 'default': 41 | solution = f'{solution}\nThe answer is {answer}' 42 | return RawSample( 43 | input=data['problem'] if 'problem' in data else data['question'], 44 | answer=solution, 45 | final_answer=answer, 46 | final_answer_content=answer, 47 | ) 48 | 49 | def __len__(self) -> int: 50 | return len(self.data) 51 | 52 | 53 | class MATHTrainDataset(MATHDataset): 54 | NAME: str = 'MATH/train' 55 | SPLIT: str = 'train' 56 | DTYPE: str = 'default' 57 | 58 | 59 | class MATHTestDataset(MATHDataset): 60 | NAME: str = 'MATH/test' 61 | SPLIT: str = 'test' 62 | DTYPE: str = 'default' 63 | 64 | 65 | class MATHSFTTrainDataset(MATHDataset): 66 | NAME: str = 'MATHSFT/train' 67 | SPLIT: str = 'train' 68 | DTYPE: str = 'arithmo' 69 | 70 | 71 | class MATHSFTTestDataset(MATHDataset): 72 | NAME: str = 'MATHSFT/test' 73 | SPLIT: str = 'test' 74 | DTYPE: str = 'arithmo' 75 | -------------------------------------------------------------------------------- /mcts_rl/datasets/raw/math_qa.py: -------------------------------------------------------------------------------- 1 | """MATH datasets.""" 2 | 3 | from __future__ import annotations 4 | 5 | import os 6 | from typing import ClassVar 7 | 8 | from datasets import load_dataset 9 | from mcts_rl.utils import get_math_data, get_arithmo_data, list_to_dict, tqdm 10 | from mcts_rl.datasets.base import RawDataset, RawSample, jsonlines_load 11 | 12 | 13 | __all__ = [ 14 | 'MathQADataset', 15 | 'MathQATrainDataset', 16 | 'MathQATestDataset', 17 | 'MathQACodeTrainDataset', 18 | 'MathQACodeTestDataset', 19 | 'MathQASFTTrainDataset', 20 | ] 21 | 22 | DATA_DIR = "path_to_dataset_folder" 23 | 24 | 25 | class MathQADataset(RawDataset): 26 | SPLIT: ClassVar[str] 27 | TYPE: ClassVar[str] 28 | 29 | def __init__(self) -> None: 30 | if self.TYPE == 'pot': 31 | raise ValueError('Do not Support PoT for now.') 32 | ## PoT data 33 | gsm8k = 
jsonlines_load(os.path.join(DATA_DIR, f'gsm8k/gsm8k_{self.SPLIT}.jsonl'))
34 |             raw_arithmo = jsonlines_load(os.path.join(DATA_DIR, f'arithmo/arithmo_code_{self.SPLIT}.jsonl'))
35 |             arithmo = []
36 |             for dt in raw_arithmo:
37 |                 prompt = dt['question'] if 'question' in dt else dt['problem']
38 |                 if all(not prompt.strip().startswith(x['question'].strip()) for x in gsm8k):
39 |                     arithmo.append(dt)
40 |             for i, dt in enumerate(gsm8k):
41 |                 gsm8k[i]['question'] = dt['question'] + ' Write a Python program to solve this.'
42 |             self.data = gsm8k
43 |         elif self.TYPE == 'all':
44 |             raise ValueError('The combined CoT + PoT setting is not supported for now.')
45 |             ## CoT + PoT data (the construction below is disabled by the raise above)
46 |             gsm8k = jsonlines_load(os.path.join(DATA_DIR, f'gsm8k/gsm8k_{self.SPLIT}.jsonl'))
47 |             math = jsonlines_load(os.path.join(DATA_DIR, f'math/math_{self.SPLIT}.jsonl'))
48 |             self.data = gsm8k + math
49 |             if self.SPLIT == 'train':
50 |                 raw_arithmo = jsonlines_load(os.path.join(DATA_DIR, f'arithmo/arithmo_code_{self.SPLIT}.jsonl'))
51 |                 for dt in raw_arithmo[::-1]:
52 |                     prompt = dt['question'] if 'question' in dt else dt['problem']
53 |                     if any(prompt.strip().startswith(x['question'].strip()) for x in gsm8k):
54 |                         self.data.append(dt)
55 |             else:
56 |                 for i, dt in enumerate(gsm8k[:]):
57 |                     dt['question'] = dt['question'] + ' Write a Python program to solve this.'
58 |                     self.data.append(dt)
59 |         else:
60 |             gsm8k = load_dataset('openai/gsm8k', 'main', split=self.SPLIT, trust_remote_code=True)
61 |             math = load_dataset('hendrycks/competition_math', split=self.SPLIT, trust_remote_code=True)
62 |             try:
63 |                 arithmo = get_math_data(load_dataset('akjindal53244/Arithmo-Data', split=self.SPLIT))
64 |             except Exception:  # fall back to a local copy when the Hub load fails
65 |                 arithmo = get_math_data(jsonlines_load(os.path.join(DATA_DIR, 'arithmo/train.jsonl')))
66 |             if self.TYPE == 'sft':
67 |                 arithmo, gsm8k, math = list_to_dict(arithmo), list_to_dict(gsm8k), list_to_dict(math)
68 |                 ## use the corresponding training data seen in SFT
69 |                 mathqa_dict = {k:v for k,v in arithmo.items() if k in math or k in gsm8k}
70 |                 self.data = [vv for v in mathqa_dict.values() for vv in v]
71 |                 # self.data = get_arithmo_data(mathqa_dict)
72 |             else:
73 |                 self.data = gsm8k + math
74 |
75 |     def __getitem__(self, index: int) -> RawSample:
76 |         data = self.data[index]
77 |         prompt = data['question'] if 'question' in data else data['problem']
78 |         return RawSample(
79 |             input=prompt,
80 |             answer=data['solution'],
81 |             final_answer=data.get('answer', None),
82 |             final_answer_content=data.get('answer_content', data.get('answer', None)),
83 |         )
84 |
85 |     def __len__(self) -> int:
86 |         return len(self.data)
87 |
88 |
89 | class MathQASFTTrainDataset(MathQADataset):
90 |     NAME: str = 'MathQASFT/train'
91 |     SPLIT: str = 'train'
92 |     TYPE: str = 'sft'
93 |
94 |
95 | class MathQATrainDataset(MathQADataset):
96 |     NAME: str = 'MathQA/train'
97 |     SPLIT: str = 'train'
98 |     TYPE: str = 'cot'
99 |
100 |
101 | class MathQAAllTrainDataset(MathQADataset):
102 |     NAME: str = 'MathQAAll/train'
103 |     SPLIT: str = 'train'
104 |     TYPE: str = 'all'
105 |
106 |
107 | class MathQAAllTestDataset(MathQADataset):
108 |     NAME: str = 'MathQAAll/test'
109 |     SPLIT: str = 'test'
110 |     TYPE: str = 'all'
111 |
112 |
113 | class MathQATestDataset(MathQADataset):
114 |     NAME: str = 'MathQA/test'
115 |     SPLIT: str = 'test'
116 |     TYPE: str = 'cot'
117 |
118 |
119 | class MathQACodeTrainDataset(MathQADataset):
120 |     NAME: str = 'MathQACode/train'
121 |     SPLIT: str = 'train'
122 |     TYPE: str = 'pot'
123 |
124 |
125 | class MathQACodeTestDataset(MathQADataset):
126 |     NAME: str = 'MathQACode/test'
127 |     SPLIT: str = 'test'
128 |     TYPE: str = 'pot'
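
The 'sft' branch above intersects Arithmo with the exact GSM8K/MATH questions seen during supervised fine-tuning. `list_to_dict` comes from `mcts_rl.utils` and is outside this listing; a plausible minimal version, sketched under the assumption that the helper groups samples by question text:

from collections import defaultdict

def list_to_dict(samples) -> dict[str, list[dict]]:
    """Group samples by their question/problem text (assumed contract of the helper)."""
    grouped: dict[str, list[dict]] = defaultdict(list)
    for dt in samples:
        key = (dt['question'] if 'question' in dt else dt['problem']).strip()
        grouped[key].append(dt)
    return dict(grouped)
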
--------------------------------------------------------------------------------
/mcts_rl/datasets/raw/mcq.py:
--------------------------------------------------------------------------------
1 | """CSR (commonsense reasoning) multiple-choice datasets."""
2 |
3 | from __future__ import annotations
4 |
5 | import os
6 | from typing import ClassVar
7 |
8 | from mcts_rl.datasets.base import RawDataset, RawSample, jsonlines_load
9 |
10 |
11 | __all__ = ['MCQDataset', 'MCQTrainDataset', 'MCQTestDataset', 'SQATrainDataset', 'SQATestDataset', 'CSRTrainDataset', 'CSRTestDataset', 'SciQTestDataset', 'NLITestDataset']
12 |
13 | DATA_DIR = "path_to_dataset_folder"
14 |
15 |
16 | class MCQDataset(RawDataset):
17 |     SPLIT: ClassVar[str]
18 |     DTYPE: ClassVar[str]
19 |
20 |     def __init__(self, path: str | None = None) -> None:
21 |         if self.DTYPE == 'all':
22 |             self.data = jsonlines_load(os.path.join(DATA_DIR, f'csr/mcq_{self.SPLIT}.jsonl'))
23 |         else:
24 |             self.data = jsonlines_load(os.path.join(DATA_DIR, f'csr/mcq_{self.DTYPE}_{self.SPLIT}.jsonl'))
25 |
26 |     def __getitem__(self, index: int) -> RawSample:
27 |         data = self.data[index]
28 |         question = data['question']
29 |         return RawSample(input=question, final_answer=data['answer'],
30 |                          final_answer_content=data.get('answer_content', data['answer']))
31 |
32 |     def __len__(self) -> int:
33 |         return len(self.data)
34 |
35 |
36 | class MCQTrainDataset(MCQDataset):
37 |     NAME: str = 'MCQ/train'
38 |     DTYPE: str = 'all'
39 |     SPLIT: str = 'train'
40 |
41 |
42 | class MCQTestDataset(MCQDataset):
43 |     NAME: str = 'MCQ/test'
44 |     DTYPE: str = 'all'
45 |     SPLIT: str = 'test'
46 |
47 |
48 | class SQATrainDataset(MCQDataset):
49 |     NAME: str = 'SQA/train'
50 |     DTYPE: str = 'sqa'
51 |     SPLIT: str = 'train'
52 |
53 |
54 | class CSRTrainDataset(MCQDataset):
55 |     NAME: str = 'CSR/train'
56 |     DTYPE: str = 'csqa'
57 |     SPLIT: str = 'train'
58 |
59 |
60 | class SQATestDataset(MCQDataset):
61 |     NAME: str = 'SQA/test'
62 |     DTYPE: str = 'sqa'
63 |     SPLIT: str = 'fulltest'
64 |
65 |
66 | class CSRTestDataset(MCQDataset):
67 |     NAME: str = 'CSR/test'
68 |     DTYPE: str = 'csqa'
69 |     SPLIT: str = 'test'
70 |
71 |
72 | class SciQTestDataset(MCQDataset):
73 |     NAME: str = 'SciQ/test'
74 |     DTYPE: str = 'sciq'
75 |     SPLIT: str = 'test'
76 |
77 |
78 | class NLITestDataset(MCQDataset):
79 |     NAME: str = 'NLI/test'
80 |     DTYPE: str = 'nli'
81 |     SPLIT: str = 'test'
82 |
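
Usage sketch (not part of the package; assumes DATA_DIR has been pointed at the prepared csr/*.jsonl files): multiple-choice samples carry the gold option in `final_answer` and its surface text in `final_answer_content`.

from mcts_rl.datasets.raw.mcq import SQATrainDataset

dataset = SQATrainDataset()
sample = dataset[0]
print(sample['input'])         # question text, typically with its answer choices
print(sample['final_answer'])  # gold option label; the exact format depends on the data files
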
--------------------------------------------------------------------------------
/mcts_rl/datasets/raw/mcq_for_eval.py:
--------------------------------------------------------------------------------
1 | """Evaluation-prompt datasets built from stepwise model generations."""
2 | from __future__ import annotations
3 |
4 | import os
5 | from typing import ClassVar
6 |
7 | from mcts_rl.datasets.base import RawDataset, RawSample, jsonlines_load
8 | from mcts_rl.configs.constants import HINTED_EVAL_PROMPT
9 |
10 |
11 | __all__ = [
12 |     'MCQEvalDataset', 'SQAEvalTrainDataset', 'SQAEvalTestDataset', 'CSREvalTrainDataset', 'CSREvalTestDataset', 'GSMEvalTrainDataset', 'GSMEvalTestDataset', 'ArithmoEvalTestDataset',
13 | ]
14 |
15 | DATA_DIR = "path_to_dataset_folder"
16 |
17 |
18 | class MCQEvalDataset(RawDataset):
19 |     SPLIT: ClassVar[str]
20 |     DTYPE: ClassVar[str]
21 |
22 |     def __init__(self, path: str | None = None) -> None:
23 |         data = jsonlines_load(os.path.join(DATA_DIR, 'arithmo/stepwise_generations.jsonl'))
24 |         self.data = []
25 |         for dt in data:
26 |             prompt = dt['question']
27 |             solution = dt['solution']
28 |             eval_prompt = HINTED_EVAL_PROMPT.format(
29 |                 input=f'{prompt}\n\n',
30 |                 solution=dt['answer'],
31 |                 prompt=solution,
32 |             ).replace('\n\nANSWER: The answer is', '').strip()
33 |             self.data.append({
34 |                 'question': eval_prompt,
35 |                 'answer': '',
36 |             })
37 |
38 |
39 |     def __getitem__(self, index: int) -> RawSample:
40 |         data = self.data[index]
41 |         return RawSample(
42 |             input=data['question'],
43 |             final_answer=data['answer'],
44 |         )
45 |
46 |     def __len__(self) -> int:
47 |         return len(self.data)
48 |
49 |
50 | class SQAEvalTrainDataset(MCQEvalDataset):
51 |     NAME: str = 'SQAEval/train'
52 |     DTYPE: str = 'sqa_all'
53 |     SPLIT: str = 'train'
54 |
55 |
56 | class SQAEvalTestDataset(MCQEvalDataset):
57 |     NAME: str = 'SQAEval/test'
58 |     DTYPE: str = 'sqa'
59 |     SPLIT: str = 'train'
60 |
61 |
62 | class CSREvalTrainDataset(MCQEvalDataset):
63 |     NAME: str = 'CSREval/train'
64 |     DTYPE: str = 'csr_all'
65 |     SPLIT: str = 'train'
66 |
67 |
68 | class CSREvalTestDataset(MCQEvalDataset):
69 |     NAME: str = 'CSREval/test'
70 |     DTYPE: str = 'csr'
71 |     SPLIT: str = 'test'
72 |
73 |
74 | class GSMEvalTrainDataset(MCQEvalDataset):
75 |     NAME: str = 'GSMEval/train'
76 |     DTYPE: str = 'gsm_all'
77 |     SPLIT: str = 'train'
78 |
79 |
80 | class GSMEvalTestDataset(MCQEvalDataset):
81 |     NAME: str = 'GSMEval/test'
82 |     DTYPE: str = 'gsm'
83 |     SPLIT: str = 'train'
84 |
85 |
86 | class ArithmoEvalTestDataset(MCQEvalDataset):  # distinct name: a second `GSMEvalTestDataset` here would shadow the class defined above
87 |     NAME: str = 'arithmo/test'
88 |     DTYPE: str = 'arithmo'
89 |     SPLIT: str = 'train'
--------------------------------------------------------------------------------
/mcts_rl/datasets/raw/mcq_pairs.py:
--------------------------------------------------------------------------------
1 | """Paired-response preference datasets for MCQ-style tasks."""
2 | from __future__ import annotations
3 |
4 | import os
5 | from typing import ClassVar
6 |
7 | from mcts_rl.datasets.base import RawDataset, RawSample, jsonlines_load
8 |
9 |
10 | __all__ = [
11 |     'MCQPreferenceDataset',
12 | ]
13 |
14 | DATA_DIR = "path_to_dataset_folder"
15 |
16 |
17 | class MCQPreferenceDataset(RawDataset):
18 |     SPLIT: ClassVar[str]
19 |     DTYPE: ClassVar[str]
20 |
21 |     def __init__(self, path: str | None = None) -> None:
22 |         self.data = jsonlines_load(os.path.join(DATA_DIR, f'{self.DTYPE}_pairs_{self.SPLIT}.jsonl'))
23 |
24 |     def __getitem__(self, index: int) -> RawSample:
25 |         data = self.data[index]
26 |         return RawSample(
27 |             input=data['prompt'].replace('QUESTION: ', ''),
28 |             answer=f"\n{data['response_0']}",
29 |             other_answer=f"\n{data['response_1']}",
30 |             better=True,
31 |             is_safe=bool(data['is_response_0_correct']),
32 |             is_other_safe=bool(data['is_response_1_correct']),
33 |         )
34 |
35 |     def __len__(self) -> int:
36 |         return len(self.data)
37 |
38 |
39 | class SQAPreferenceTrainDataset(MCQPreferenceDataset):
40 |     NAME: str = 'SQAPreference/train'
41 |     DTYPE: str = 'sqa_all'
42 |     SPLIT: str = 'train'
43 |
44 |
45 | class SQAPreferenceTestDataset(MCQPreferenceDataset):
46 |     NAME: str = 'SQAPreference/test'
47 |     DTYPE: str = 'sqa'
48 |     SPLIT: str = 'train'
49 |
50 |
51 | class CSRPreferenceTrainDataset(MCQPreferenceDataset):
52 |     NAME: str = 'CSRPreference/train'
53 |     DTYPE: str = 'csr_all'
54 |     SPLIT: str = 'train'
55 |
56 |
57 | class CSRPreferenceTestDataset(MCQPreferenceDataset):
58 |     NAME: str = 'CSRPreference/test'
59 |     DTYPE: str = 'csr'
60 |     SPLIT: str = 'test'
61 |
62 |
63 | class GSMPreferenceTrainDataset(MCQPreferenceDataset):
64 |     NAME: str = 'GSMPreference/train'
65 |     DTYPE: str = 'gsm'
66 |     SPLIT: str = 'train'
67 |
68 |
69 | class GSMPreferenceTestDataset(MCQPreferenceDataset):
70 |     NAME: str = 'GSMPreference/test'
71 |     DTYPE: str = 'gsm'
72 |     SPLIT: str = 'test'
--------------------------------------------------------------------------------
/mcts_rl/datasets/raw/moss.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 PKU-Alignment Team. All Rights Reserved.
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """MOSS datasets for supervised instruction fine-tuning.""" 16 | 17 | from __future__ import annotations 18 | 19 | import json 20 | import pathlib 21 | import re 22 | import zipfile 23 | 24 | from mcts_rl.datasets.base import RawDataset, RawSample 25 | 26 | 27 | __all__ = ['MOSS002SFT', 'MOSS003SFT'] 28 | 29 | 30 | class MOSS002SFT(RawDataset): 31 | NAME: str = 'moss-002-sft' 32 | PATTERN = re.compile( 33 | r""" 34 | ^ 35 | \[(?P(?PHuman)|(?PMOSS))\]: 36 | \s* 37 | (?P.*?) 38 | \s* 39 | (?(human)|) 40 | \s* 41 | """, 42 | flags=re.DOTALL | re.VERBOSE, 43 | ) 44 | 45 | def __init__(self, path: str | None = None) -> None: 46 | if path is None: # fnlp/moss-002-sft-data cannot load with `load_dataset` 47 | raise ValueError('moss-002-sft dataset requires a local path to the dataset.') 48 | 49 | path = pathlib.Path(path).expanduser().absolute() 50 | if not path.exists(): 51 | raise ValueError('moss-002-sft dataset path does not exist.') 52 | if not path.is_dir(): 53 | raise ValueError('moss-002-sft dataset path is not a directory.') 54 | 55 | data_files = sorted(path.glob('*.json')) 56 | self.data = [] 57 | for file in data_files: 58 | with file.open(mode='rt', encoding='utf-8') as f: 59 | self.data.extend(json.load(f)) 60 | 61 | def __getitem__(self, index: int) -> RawSample: 62 | data = self.data[index] 63 | plain_text = data['plain_text'].strip() 64 | if not plain_text.startswith(('[Human]:', '[MOSS]:')): 65 | raise ValueError(f'Invalid plain text: {plain_text}') 66 | 67 | dialogue = [] 68 | text = plain_text 69 | while len(text) > 0: 70 | match = self.PATTERN.match(text) 71 | if match is None: 72 | raise ValueError(f'Invalid plain text: {plain_text}') 73 | if (match.group('human') is not None and len(dialogue) % 2 != 0) or ( 74 | match.group('assistant') is not None and len(dialogue) % 2 != 1 75 | ): 76 | raise ValueError(f'Invalid plain text: {plain_text}') 77 | dialogue.append(match.group('value')) 78 | text = text[match.end() :] 79 | 80 | return RawSample(dialogue=dialogue) 81 | 82 | def __len__(self) -> int: 83 | return len(self.data) 84 | 85 | 86 | class MOSS003SFT(RawDataset): 87 | NAME: str = 'moss-003-sft' 88 | 89 | def __init__(self, path: str | None = None) -> None: 90 | if path is None: # fnlp/moss-003-sft-data cannot load with `load_dataset` 91 | raise ValueError('moss-003-sft dataset requires a local path to the dataset.') 92 | 93 | path = pathlib.Path(path).expanduser().absolute() 94 | if not path.exists(): 95 | raise ValueError('moss-003-sft dataset path does not exist.') 96 | if not path.is_dir(): 97 | raise ValueError('moss-003-sft dataset path is not a directory.') 98 | 99 | data_file = path / 'moss-003-sft-no-tools.jsonl' 100 | archive_file = path / 'moss-003-sft-no-tools.jsonl.zip' 101 | 102 | if not data_file.exists(): 103 | if not archive_file.exists(): 104 | raise ValueError('moss-003-sft dataset 
requires a local path to the dataset.') 105 | with zipfile.ZipFile(archive_file, mode='r') as archive: 106 | archive.extractall(path) 107 | 108 | self.data = [] 109 | with data_file.open(mode='rt', encoding='utf-8') as f: 110 | for line in f: 111 | self.data.append(json.loads(line)) 112 | 113 | def __getitem__(self, index: int) -> RawSample: 114 | data = self.data[index] 115 | num_turns = data['num_turns'] 116 | chat = data['chat'] 117 | dialogue = [] 118 | for i in range(1, num_turns + 1): 119 | turn = chat[f'turn_{i}'] 120 | dialogue.append(turn['Human'].replace('<|Human|>:', '').replace('', '').strip()) 121 | dialogue.append(turn['MOSS'].replace('<|MOSS|>', '').replace('', '').strip()) 122 | return RawSample(dialogue=dialogue) 123 | 124 | def __len__(self) -> int: 125 | return len(self.data) 126 | -------------------------------------------------------------------------------- /mcts_rl/datasets/raw/prm800k.py: -------------------------------------------------------------------------------- 1 | """PRM800K preference datasets.""" 2 | 3 | from __future__ import annotations 4 | 5 | import os 6 | from typing import ClassVar 7 | 8 | from mcts_rl.datasets.base import RawDataset, RawSample, jsonlines_load 9 | 10 | 11 | __all__ = [ 12 | 'PRM800KDataset', 13 | 'PRM800KTrainDataset', 14 | 'PRM800KTestDataset', 15 | ] 16 | 17 | DATA_DIR = "path_to_dataset_folder" 18 | 19 | 20 | class PRM800KDataset(RawDataset): 21 | SPLIT: ClassVar[str] 22 | PATH: ClassVar[str] 23 | 24 | def __init__(self, path: str | None = None) -> None: 25 | self.data = jsonlines_load(os.path.join(DATA_DIR, f'prm800k/preference_prm_{self.SPLIT}.jsonl')) 26 | 27 | def __getitem__(self, index: int) -> RawSample: 28 | data = self.data[index] 29 | prompt = data['prompt'] 30 | # prompt = f'{prompt}\n\nANSWER: {data["solution"]["solution"]}\nThe answer is {data["solution"]["answer"]}\n\nQUESTION: {prompt}' 31 | # from mcts_rl.utils import extract_answer, math_equal 32 | # if not math_equal(extract_answer(data.get('generation-base', 'None')), data.get('answer', None)): 33 | # prompt = '{}\n\nANSWER: {}\n\nREVISION REQUEST: Please revise the above answer to get the correct answer {}.'.format( 34 | # prompt, data.get('generation-base', 'None'), data.get('answer', 'None'), 35 | # ) 36 | # import ipdb; ipdb.set_trace() 37 | return RawSample( 38 | input=prompt, 39 | # answer=data['response_0'], 40 | answer=data['solution']['solution'], 41 | other_answer=data['response_1'], 42 | better=int(data['better_response_id']) == 0, 43 | safer=int(data['better_response_id']) == 0, 44 | is_safe=bool(data['is_response_0_correct_answer']), 45 | is_other_safe=bool(data['is_response_1_correct_answer']), 46 | final_answer=data.get('answer', None), 47 | ) 48 | 49 | def __len__(self) -> int: 50 | return len(self.data) 51 | 52 | 53 | class PRM800KTrainDataset(PRM800KDataset): 54 | NAME: str = 'PRM800K/train' 55 | ALIASES: tuple[str, ...] = ('OpenAI/PRM800K/train',) 56 | PATH: str = 'OpenAI/PRM800K' 57 | SPLIT: str = 'train' 58 | 59 | 60 | class PRM800KTestDataset(PRM800KDataset): 61 | NAME: str = 'PRM800K/test' 62 | ALIASES: tuple[str, ...] 
= ('OpenAI/PRM800K/test',) 63 | PATH: str = 'OpenAI/PRM800K' 64 | SPLIT: str = 'test' 65 | -------------------------------------------------------------------------------- /mcts_rl/datasets/raw/qa_feedback.py: -------------------------------------------------------------------------------- 1 | 2 | from __future__ import annotations 3 | 4 | from typing import ClassVar 5 | 6 | from datasets import load_dataset 7 | from mcts_rl.datasets.base import RawDataset, RawSample 8 | from mcts_rl.configs.constants import QA_EVAL_PROMPT, PROMPT_ASSISTANT 9 | 10 | 11 | __all__ = [ 12 | 'QAFBDataset', 13 | 'QAFBTrainDataset', 14 | 'QAFBTestDataset', 15 | ] 16 | 17 | 18 | class QAFBDataset(RawDataset): 19 | SPLIT: ClassVar[str] 20 | PATH: ClassVar[str] 21 | 22 | def __init__(self, path: str | None = None) -> None: 23 | self.data = load_dataset(path or self.PATH, split=self.SPLIT) 24 | 25 | def __getitem__(self, index: int) -> RawSample: 26 | data = self.data[index] 27 | question = data['question'] 28 | # question = QA_EVAL_PROMPT.format(input=question, prompt=PROMPT_ASSISTANT + f' {data["answer"]}').rstrip('ASSISTANT: ').rstrip() 29 | return RawSample( 30 | input=question, 31 | final_answer=data['answer'], 32 | final_answer_content='\n'.join([f'{r}: {e}' for r, e in zip(data['feedback']['rating'], data['feedback']['explanation'])]) 33 | ) 34 | 35 | def __len__(self) -> int: 36 | return len(self.data) 37 | 38 | 39 | class QAFBTrainDataset(QAFBDataset): 40 | NAME: str = 'FeedbackQA/train' 41 | PATH: str = 'McGill-NLP/feedbackQA' 42 | SPLIT: str = 'train' 43 | 44 | 45 | class QAFBTestDataset(QAFBDataset): 46 | NAME: str = 'FeedbackQA/test' 47 | PATH: str = 'McGill-NLP/feedbackQA' 48 | SPLIT: str = 'test' 49 | -------------------------------------------------------------------------------- /mcts_rl/datasets/raw/safe_rlhf.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 PKU-Alignment Team. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """Safe-RLHF preference datasets.""" 16 | 17 | from __future__ import annotations 18 | 19 | from typing import ClassVar 20 | 21 | from datasets import load_dataset 22 | from mcts_rl.datasets.base import RawDataset, RawSample 23 | 24 | 25 | __all__ = [ 26 | 'SafeRLHFDataset', 27 | 'SafeRLHFTrainDataset', 28 | 'SafeRLHFTestDataset', 29 | 'SafeRLHF10KTrainDataset', 30 | ] 31 | 32 | 33 | class SafeRLHFDataset(RawDataset): 34 | SPLIT: ClassVar[str] 35 | PATH: ClassVar[str] 36 | 37 | def __init__(self, path: str | None = None) -> None: 38 | self.data = load_dataset(path or self.PATH, split=self.SPLIT) 39 | 40 | def __getitem__(self, index: int) -> RawSample: 41 | data = self.data[index] 42 | return RawSample( 43 | input=data['prompt'], 44 | answer=data['response_0'], 45 | other_answer=data['response_1'], 46 | better=int(data['better_response_id']) == 0, 47 | safer=int(data['safer_response_id']) == 0, 48 | is_safe=bool(data['is_response_0_safe']), 49 | is_other_safe=bool(data['is_response_1_safe']), 50 | ) 51 | 52 | def __len__(self) -> int: 53 | return len(self.data) 54 | 55 | 56 | class SafeRLHFTrainDataset(SafeRLHFDataset): 57 | NAME: str = 'PKU-SafeRLHF/train' 58 | ALIASES: tuple[str, ...] = ('PKU-Alignment/PKU-SafeRLHF/train',) 59 | PATH: str = 'PKU-Alignment/PKU-SafeRLHF' 60 | SPLIT: str = 'train' 61 | 62 | 63 | class SafeRLHFTestDataset(SafeRLHFDataset): 64 | NAME: str = 'PKU-SafeRLHF/test' 65 | ALIASES: tuple[str, ...] = ('PKU-Alignment/PKU-SafeRLHF/test',) 66 | PATH: str = 'PKU-Alignment/PKU-SafeRLHF' 67 | SPLIT: str = 'test' 68 | 69 | 70 | class SafeRLHF10KTrainDataset(SafeRLHFDataset): 71 | NAME: str = 'PKU-SafeRLHF-10K/train' 72 | ALIASES: tuple[str, ...] = ('PKU-Alignment/PKU-SafeRLHF-10K/train',) 73 | PATH: str = 'PKU-Alignment/PKU-SafeRLHF-10K' 74 | SPLIT: str = 'train' 75 | -------------------------------------------------------------------------------- /mcts_rl/datasets/safety_preference.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 PKU-Alignment Team. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | 16 | from __future__ import annotations 17 | 18 | from typing import Callable 19 | from typing_extensions import TypedDict # Python 3.10+ 20 | 21 | import torch 22 | 23 | from mcts_rl.datasets.base import CollatorBase, RawSample, TokenizedDataset 24 | from mcts_rl.datasets.utils import format_prompt, right_padding 25 | 26 | 27 | __all__ = [ 28 | 'SafetyPreferenceDataset', 29 | 'SafetyPreferenceCollator', 30 | 'SafetyPreferenceSample', 31 | 'SafetyPreferenceBatch', 32 | ] 33 | 34 | 35 | class SafetyPreferenceSample(TypedDict, total=True): 36 | safer_input_ids: torch.LongTensor # size = (L,) 37 | # +1 for safe / -1 for unsafe 38 | safer_sign: torch.LongTensor # size = (L,) 39 | 40 | unsafer_input_ids: torch.LongTensor # size = (L,) 41 | # +1 for safe / -1 for unsafe 42 | unsafer_sign: torch.LongTensor # size = (L,) 43 | 44 | 45 | class SafetyPreferenceBatch(TypedDict, total=True): 46 | safer_input_ids: torch.LongTensor # size = (B, L) 47 | safer_attention_mask: torch.BoolTensor # size = (B, L) 48 | # +1 for safe / -1 for unsafe 49 | safer_safety_sign: torch.LongTensor # size = (B,) 50 | 51 | unsafer_input_ids: torch.LongTensor # size = (B, L) 52 | unsafer_attention_mask: torch.BoolTensor # size = (B, L) 53 | # +1 for safe / -1 for unsafe 54 | unsafer_safety_sign: torch.LongTensor # size = (B,) 55 | 56 | 57 | class SafetyPreferenceDataset(TokenizedDataset): 58 | def preprocess(self, raw_sample: RawSample) -> SafetyPreferenceSample: 59 | prompt = format_prompt(input=raw_sample['input'], eos_token=self.tokenizer.eos_token) 60 | answer = raw_sample['answer'] 61 | other_answer = raw_sample['other_answer'] 62 | safer = raw_sample['safer'] 63 | is_safe = raw_sample['is_safe'] 64 | is_other_safe = raw_sample['is_other_safe'] 65 | 66 | safer_answer, unsafer_answer = answer, other_answer 67 | safer_sign, unsafer_sign = ( # +1 for safe / -1 for unsafe 68 | 2 * int(is_safe) - 1, 69 | 2 * int(is_other_safe) - 1, 70 | ) 71 | if not safer: 72 | safer_answer, unsafer_answer = unsafer_answer, safer_answer 73 | safer_sign, unsafer_sign = unsafer_sign, safer_sign 74 | 75 | if safer_sign < unsafer_sign: 76 | raise ValueError( 77 | 'The safer answer is not safer than the unsafer answer.\n\n' 78 | f'Prompt: {prompt}\n\n' 79 | f'Safer answer (labeled as unsafe): {safer_answer}\n\n' 80 | f'Unsafer answer (labeled as safe): {unsafer_answer}', 81 | ) 82 | 83 | # size = (L,) 84 | safer_input_ids = self.tokenize(prompt + safer_answer + self.tokenizer.eos_token) 85 | unsafer_input_ids = self.tokenize(prompt + unsafer_answer + self.tokenizer.eos_token) 86 | if ( 87 | safer_input_ids.size() == unsafer_input_ids.size() 88 | and torch.all(torch.eq(safer_input_ids, unsafer_input_ids)).item() 89 | ): 90 | raise ValueError( 91 | 'Two responses get the same `input_ids` after tokenization.\n\n' 92 | f'Prompt: {prompt}\n\n' 93 | f'Safer answer: {safer_answer}\n\n' 94 | f'Unsafer answer: {unsafer_answer}', 95 | ) 96 | return { 97 | 'safer_input_ids': safer_input_ids, # size = (L,) 98 | 'safer_sign': torch.tensor(safer_sign), # size = () 99 | 'unsafer_input_ids': unsafer_input_ids, # size = (L,) 100 | 'unsafer_sign': torch.tensor(unsafer_sign), # size = () 101 | } 102 | 103 | def get_collator(self) -> Callable[[list[dict[str, torch.Tensor]]], dict[str, torch.Tensor]]: 104 | return SafetyPreferenceCollator(self.tokenizer.pad_token_id) 105 | 106 | 107 | class SafetyPreferenceCollator(CollatorBase): 108 | def __call__(self, samples: 
list[SafetyPreferenceSample]) -> SafetyPreferenceBatch: 109 | input_ids = [sample['safer_input_ids'] for sample in samples] + [ 110 | sample['unsafer_input_ids'] for sample in samples 111 | ] 112 | attention_mask = [ 113 | input_id.new_ones(input_id.size(), dtype=torch.bool) for input_id in input_ids 114 | ] 115 | safety_sign = [sample['safer_sign'] for sample in samples] + [ 116 | sample['unsafer_sign'] for sample in samples 117 | ] 118 | 119 | # size = (2 * B, L) 120 | input_ids = right_padding(input_ids, padding_value=self.pad_token_id) 121 | attention_mask = right_padding(attention_mask, padding_value=0) 122 | # size = (2 * B,) 123 | safety_sign = torch.tensor(safety_sign, dtype=torch.long) 124 | 125 | # size = (B, L) 126 | safer_input_ids, unsafer_input_ids = input_ids.chunk(chunks=2, dim=0) 127 | safer_attention_mask, unsafer_attention_mask = attention_mask.chunk(chunks=2, dim=0) 128 | # size = (B,) 129 | safer_safety_sign, unsafer_safety_sign = safety_sign.chunk(chunks=2, dim=0) 130 | return { 131 | 'safer_input_ids': safer_input_ids, # size = (B, L) 132 | 'safer_attention_mask': safer_attention_mask, # size = (B, L) 133 | 'safer_safety_sign': safer_safety_sign, # size = (B,) 134 | 'unsafer_input_ids': unsafer_input_ids, # size = (B, L) 135 | 'unsafer_attention_mask': unsafer_attention_mask, # size = (B, L) 136 | 'unsafer_safety_sign': unsafer_safety_sign, # size = (B,) 137 | } 138 | -------------------------------------------------------------------------------- /mcts_rl/datasets/supervised.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 PKU-Alignment Team. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | 16 | from __future__ import annotations 17 | 18 | from typing import Callable 19 | from typing_extensions import TypedDict # Python 3.10+ 20 | 21 | import torch 22 | 23 | from mcts_rl.configs import IGNORE_INDEX, PROMPT_ASSISTANT, PROMPT_BEGIN, PROMPT_USER 24 | from mcts_rl.datasets.base import CollatorBase, RawSample, TokenizedDataset 25 | from mcts_rl.datasets.utils import format_prompt, right_padding 26 | 27 | 28 | __all__ = [ 29 | 'SupervisedDataset', 30 | 'SupervisedCollator', 31 | 'SupervisedSample', 32 | 'SupervisedBatch', 33 | ] 34 | 35 | 36 | class SupervisedSample(TypedDict, total=True): 37 | input_ids: torch.LongTensor # size = (L,) 38 | labels: torch.LongTensor # size = (L,) 39 | 40 | 41 | class SupervisedBatch(TypedDict, total=True): 42 | input_ids: torch.LongTensor # size = (B, L) 43 | labels: torch.LongTensor # size = (B, L) 44 | attention_mask: torch.BoolTensor # size = (B, L) 45 | 46 | 47 | class SupervisedDataset(TokenizedDataset): 48 | def preprocess(self, raw_sample: RawSample) -> SupervisedSample: 49 | if raw_sample.get('input') is None and raw_sample.get('dialogue') is None: 50 | raise ValueError('Either `input` or `dialogue` must be provided.') 51 | if raw_sample.get('input') is not None and raw_sample.get('dialogue') is not None: 52 | raise ValueError('At most one of `input` and `dialogue` can be provided.') 53 | 54 | if raw_sample.get('input') is not None: 55 | input = raw_sample['input'] # pylint: disable=redefined-builtin 56 | if not isinstance(input, str): 57 | raise ValueError(f'Unsupported type of `input`: {type(input)}. Expected: str.') 58 | prompt = format_prompt(input=input, eos_token=self.tokenizer.eos_token) 59 | answer = raw_sample['answer'] 60 | if len(raw_sample) == 2 and PROMPT_ASSISTANT.endswith('think step by step.'): 61 | prompt = prompt.split(PROMPT_ASSISTANT)[0] + 'ANSWER: ' 62 | if raw_sample.get('final_answer', None) is not None: 63 | text = f'{prompt}{answer}\nThe answer is: {raw_sample["final_answer"]}' + self.tokenizer.eos_token 64 | else: 65 | text = prompt + answer + self.tokenizer.eos_token 66 | input_ids = self.tokenize(text) 67 | labels = input_ids.clone() 68 | # Mask non-assistant input 69 | labels[: len(self.tokenize(prompt))] = IGNORE_INDEX 70 | return {'input_ids': input_ids, 'labels': labels} 71 | 72 | dialogue = raw_sample['dialogue'] # is not None 73 | text = PROMPT_BEGIN 74 | offsets = [0] 75 | input_ids = torch.empty(0, dtype=torch.long) 76 | for i, line in enumerate(dialogue): 77 | if i % 2 == 0: 78 | # User input 79 | text += PROMPT_USER.format(input=line) + PROMPT_ASSISTANT 80 | else: 81 | # Assistant input 82 | text += line + self.tokenizer.eos_token 83 | input_ids = self.tokenize(text) 84 | offsets.append(len(input_ids)) 85 | 86 | labels = input_ids.clone() 87 | # Mask non-assistant input 88 | for begin, end in zip(offsets[::2], offsets[1::2]): 89 | labels[begin:end] = IGNORE_INDEX 90 | 91 | return { 92 | 'input_ids': input_ids, # size = (L,) 93 | 'labels': labels, # size = (L,) 94 | } 95 | 96 | def get_collator(self) -> Callable[[list[dict[str, torch.Tensor]]], dict[str, torch.Tensor]]: 97 | return SupervisedCollator(self.tokenizer.pad_token_id) 98 | 99 | 100 | class SupervisedCollator(CollatorBase): 101 | def __call__(self, samples: list[SupervisedSample]) -> SupervisedBatch: 102 | input_ids = right_padding( 103 | [sample['input_ids'] for sample in samples], 104 | padding_value=self.pad_token_id, 105 | ) 106 | labels = 
right_padding( 107 | [sample['labels'] for sample in samples], 108 | padding_value=IGNORE_INDEX, 109 | ) 110 | attention_mask = input_ids.ne(self.pad_token_id) 111 | return { 112 | 'input_ids': input_ids, # size = (B, L) 113 | 'labels': labels, # size = (B, L) 114 | 'attention_mask': attention_mask, # size = (B, L) 115 | } 116 | -------------------------------------------------------------------------------- /mcts_rl/datasets/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 PKU-Alignment Team. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | from __future__ import annotations 17 | 18 | import random 19 | random.seed(0) 20 | 21 | import torch 22 | from torch.nn.utils.rnn import pad_sequence 23 | from torch.types import Number 24 | 25 | from mcts_rl.configs import ( 26 | PROMPT_BEGIN, PROMPT_USER, 27 | PROMPT_ASSISTANT, PROMPT_ASSISTANT_MCQ, 28 | SQA_PROMPT, 29 | LLAMA3_PROMPT_USER, 30 | LLAMA3_PROMPT_ASSISTANT, 31 | LLAMA3_PROMPT_ASSISTANT_MCQ, 32 | ) 33 | 34 | 35 | def format_prompt( 36 | input: str | list[str], # pylint: disable=redefined-builtin 37 | eos_token: str, 38 | use_mcq: bool = False, 39 | few_shot: bool = False, 40 | model_type: str = 'mistral', 41 | ) -> str: 42 | if isinstance(input, str): 43 | input = [input] 44 | elif not isinstance(input, list): 45 | raise ValueError(f'Unsupported type of `input`: {type(input)}. 
Expected: str or list[str].') 46 | 47 | if len(input) % 2 != 1: 48 | raise ValueError( 49 | 'The length of `input` must be odd, while `input` must end at the user question.', 50 | ) 51 | 52 | if 'USER:' in PROMPT_USER: 53 | buffer = [PROMPT_BEGIN] 54 | elif few_shot: 55 | # buffer = [GSM8K_PROMPT] 56 | buffer = [SQA_PROMPT] 57 | else: 58 | # exp = random.choice(GSM8K_EXP) 59 | # buffer = [PROMPT_USER.format(input=exp['Q']) + PROMPT_ASSISTANT + ' ' + exp['A'] + DEFAULT_EOS_TOKEN + '\n\n'] 60 | # buffer = ['At the end of your answer output #### {final answer}.\n\n'] 61 | buffer = [] 62 | 63 | for i, line in enumerate(input): 64 | if i % 2 == 0: 65 | # User input 66 | if model_type == 'llama3': 67 | buffer.extend((LLAMA3_PROMPT_USER.format(input=line), 68 | LLAMA3_PROMPT_ASSISTANT_MCQ if use_mcq else LLAMA3_PROMPT_ASSISTANT)) 69 | else: 70 | buffer.extend((PROMPT_USER.format(input=line), 71 | PROMPT_ASSISTANT_MCQ if use_mcq else PROMPT_ASSISTANT)) 72 | else: 73 | # Assistant response 74 | buffer.extend((line, eos_token)) 75 | return ''.join(buffer) 76 | 77 | 78 | def right_padding(sequences: list[torch.Tensor], padding_value: Number) -> torch.Tensor: 79 | return pad_sequence(sequences, batch_first=True, padding_value=padding_value) 80 | 81 | 82 | def left_padding(sequences: list[torch.Tensor], padding_value: Number) -> torch.Tensor: 83 | return right_padding( 84 | [seq.flip(0) for seq in sequences], 85 | padding_value=padding_value, 86 | ).flip(1) 87 | -------------------------------------------------------------------------------- /mcts_rl/finetune/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 PKU-Alignment Team. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Supervised Fine-Tuning (SFT).""" 16 | 17 | from mcts_rl.finetune.trainer import SupervisedFinetuneTrainer 18 | 19 | 20 | __all__ = ['SupervisedFinetuneTrainer'] 21 | -------------------------------------------------------------------------------- /mcts_rl/finetune/__main__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 PKU-Alignment Team. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """The main training script to supervised finetune a model.""" 16 | 17 | import sys 18 | 19 | from mcts_rl.finetune.main import main 20 | 21 | 22 | if __name__ == '__main__': 23 | sys.exit(main()) 24 | -------------------------------------------------------------------------------- /mcts_rl/finetune/huggingface.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 PKU-Alignment Team. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """The main training script to supervised finetune a model using Hugging Face Transformers Trainer.""" 16 | 17 | import argparse 18 | from dataclasses import dataclass, field 19 | from typing import List, Optional, Tuple, Union 20 | 21 | import transformers 22 | from transformers.training_args import OptimizerNames 23 | 24 | from mcts_rl.datasets import SupervisedDataset, parse_dataset 25 | from mcts_rl.models import load_pretrained_models 26 | 27 | 28 | @dataclass 29 | class ModelArguments: 30 | """Arguments for models.""" 31 | 32 | model_name_or_path: str 33 | 34 | 35 | @dataclass 36 | class DataArguments: 37 | """Arguments for datasets.""" 38 | 39 | datasets: List[parse_dataset] = field( 40 | default=None, 41 | metadata={'help': 'Path to the training data.'}, 42 | ) 43 | 44 | 45 | @dataclass 46 | class TrainingArguments(transformers.TrainingArguments): 47 | """Arguments for the training loop.""" 48 | 49 | cache_dir: Optional[str] = field(default=None) 50 | optim: Union[OptimizerNames, str] = field( 51 | default=OptimizerNames.ADAMW_TORCH, 52 | metadata={'help': 'The optimizer to use.'}, 53 | ) 54 | model_max_length: int = field( 55 | default=512, 56 | metadata={ 57 | 'help': 'Maximum sequence length. 
Sequences will be right padded (and possibly truncated).', 58 | }, 59 | ) 60 | 61 | 62 | def parse_arguments() -> Tuple[argparse.Namespace, argparse.Namespace, argparse.Namespace]: 63 | """Parse the command-line arguments.""" 64 | parser = transformers.HfArgumentParser([TrainingArguments, ModelArguments, DataArguments]) 65 | # pylint: disable-next=unbalanced-tuple-unpacking 66 | training_args, model_args, data_args = parser.parse_args_into_dataclasses() 67 | return training_args, model_args, data_args 68 | 69 | 70 | def main() -> None: 71 | """Main training routine.""" 72 | # pylint: disable=no-member 73 | training_args, model_args, data_args = parse_arguments() 74 | 75 | model, tokenizer = load_pretrained_models( 76 | model_args.model_name_or_path, 77 | model_max_length=training_args.model_max_length, 78 | padding_side='right', 79 | cache_dir=training_args.cache_dir, 80 | trust_remote_code=True, 81 | ) 82 | 83 | train_dataset = SupervisedDataset( 84 | data_args.datasets, 85 | tokenizer=tokenizer, 86 | seed=training_args.seed, 87 | ) 88 | data_collator = train_dataset.get_collator() 89 | 90 | trainer = transformers.Trainer( 91 | model=model, 92 | tokenizer=tokenizer, 93 | args=training_args, 94 | train_dataset=train_dataset, 95 | data_collator=data_collator, 96 | ) 97 | trainer.train() 98 | trainer.save_state() 99 | trainer.save_model() 100 | 101 | 102 | if __name__ == '__main__': 103 | main() 104 | -------------------------------------------------------------------------------- /mcts_rl/finetune/main.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 PKU-Alignment Team. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """The main training script to supervised finetune a model.""" 16 | 17 | from mcts_rl.finetune.deepspeed import main 18 | 19 | 20 | if __name__ == '__main__': 21 | main() 22 | -------------------------------------------------------------------------------- /mcts_rl/finetune/trainer.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 PKU-Alignment Team. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
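# [Usage sketch] How the three dataclasses in finetune/huggingface.py above are
# parsed from the command line; the CLI values here are hypothetical:
import transformers

parser = transformers.HfArgumentParser([TrainingArguments, ModelArguments, DataArguments])
training_args, model_args, data_args = parser.parse_args_into_dataclasses(
    args=[
        '--model_name_or_path', 'path/to/base-model',  # hypothetical checkpoint
        '--datasets', 'alpaca',                        # each entry goes through parse_dataset
        '--output_dir', './sft-output',
    ],
)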
14 | # ==============================================================================
15 | """Trainer class for supervised finetuning."""
16 | 
17 | from __future__ import annotations
18 | 
19 | from typing import Any
20 | from tqdm import tqdm
21 | 
22 | import torch
23 | import torch.distributed as dist
24 | from transformers import AutoModelForCausalLM
25 | from transformers.modeling_outputs import CausalLMOutputWithPast
26 | 
27 | from mcts_rl.datasets import SupervisedDataset
28 | from mcts_rl.trainers import SupervisedTrainer
29 | from mcts_rl.utils import (
30 |     get_all_reduce_mean,
31 |     is_main_process,
32 |     to_device,
33 | )
34 | 
35 | 
36 | class SupervisedFinetuneTrainer(SupervisedTrainer):
37 |     """Trainer class for supervised finetuning."""
38 | 
39 |     TRAINING_TYPE = 'sft'
40 |     DATASET_TYPE = SupervisedDataset
41 |     MODEL_TYPE = AutoModelForCausalLM
42 | 
43 |     def loss(
44 |         self,
45 |         input_ids: torch.LongTensor,  # size = (B, L)
46 |         labels: torch.LongTensor,  # size = (B, L)
47 |         attention_mask: torch.BoolTensor,  # size = (B, L)
48 |     ) -> dict[str, torch.Tensor]:
49 |         """Loss function for supervised finetuning."""
50 |         outputs: CausalLMOutputWithPast = self.model(
51 |             input_ids=input_ids,
52 |             attention_mask=attention_mask,
53 |             labels=labels,
54 |         )
55 |         return {
56 |             'loss': outputs.loss,
57 |         }
58 | 
59 |     def eval(
60 |         self,
61 |     ) -> dict[str, Any]:
62 |         if self.eval_dataloader is None:
63 |             return {}
64 | 
65 |         self.set_eval()
66 |         eval_dataloader = tqdm(
67 |             self.eval_dataloader,
68 |             desc='Evaluating',
69 |             disable=not is_main_process(),
70 |             position=1,
71 |             leave=False,
72 |         )
73 | 
74 |         losses, batch = [], None
75 |         for batch in eval_dataloader:
76 |             batch = to_device(batch, self.args.device)
77 |             with torch.no_grad():
78 |                 loss = self.loss(
79 |                     input_ids=batch['input_ids'],
80 |                     labels=batch['labels'],
81 |                     attention_mask=batch['attention_mask'],
82 |                 )['loss']
83 |             losses.append(loss)
84 | 
85 |         if batch is None:
86 |             self.logger.print('WARNING: `eval_dataloader` is empty.')
87 |             return {}
88 | 
89 |         losses = torch.stack(losses, dim=0)
90 |         if is_main_process():
91 |             gathered_losses = [torch.empty_like(losses) for _ in range(dist.get_world_size())]
92 |         else:
93 |             gathered_losses = []
94 |         dist.gather(losses, gathered_losses, dst=0)
95 |         if is_main_process():
96 |             losses = torch.cat(gathered_losses, dim=0)
97 | 
98 |         self.set_train()
99 | 
100 |         return {
101 |             'eval/loss': losses.mean().item(),
102 |         }
103 | 
104 |     def train_step(
105 |         self,
106 |         input_ids: torch.LongTensor,  # size = (B, L)
107 |         labels: torch.LongTensor,  # size = (B, L)
108 |         attention_mask: torch.BoolTensor,  # size = (B, L)
109 |     ) -> dict[str, Any]:
110 |         """Perform a single training step.
111 | 
112 |         Args:
113 |             input_ids (torch.LongTensor): input IDs of the batched training sequences.
114 |             labels (torch.LongTensor): labels for the full sequences.
115 |             attention_mask (torch.BoolTensor): attention mask for the input sequences.
116 | 117 | Returns: 118 | dict[str, Any]: training loss, learning rate 119 | """ 120 | loss = self.loss( 121 | input_ids=input_ids, 122 | labels=labels, 123 | attention_mask=attention_mask, 124 | )['loss'] 125 | self.model.backward(loss) 126 | self.model.step() 127 | 128 | loss = get_all_reduce_mean(loss) 129 | 130 | return { 131 | 'train/loss': loss.item(), 132 | 'train/lr': self.model.optimizer.param_groups[0]['lr'], 133 | } 134 | -------------------------------------------------------------------------------- /mcts_rl/models/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 PKU-Alignment Team. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Utility functions for Hugging Face auto-models.""" 16 | 17 | from mcts_rl.models.pretrained import load_pretrained_models 18 | from mcts_rl.models.score_model import AutoModelForScore, ScoreModelOutput 19 | 20 | 21 | __all__ = ['load_pretrained_models', 'AutoModelForScore', 'ScoreModelOutput'] 22 | -------------------------------------------------------------------------------- /mcts_rl/models/normalizer.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 PKU-Alignment Team. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
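# [Usage sketch] The public surface re-exported by mcts_rl/models/__init__.py
# above; the checkpoint path is hypothetical, and the call mirrors the pattern
# used in finetune/huggingface.py:
from mcts_rl.models import load_pretrained_models

model, tokenizer = load_pretrained_models(
    'path/to/base-model',  # hypothetical checkpoint
    model_max_length=512,
    padding_side='right',
    trust_remote_code=True,
)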
14 | # ============================================================================== 15 | """Normalizer for score models.""" 16 | 17 | from __future__ import annotations 18 | 19 | from abc import abstractmethod 20 | from typing import Any, Literal 21 | 22 | import torch 23 | from torch import nn 24 | from torch.types import Number 25 | 26 | 27 | NormalizeFunction = Literal['affine', 'scale', 'translate', 'identity'] 28 | NormalizerType = Literal['RunningMeanStd', 'ExponentialMovingAverage'] 29 | 30 | 31 | class Normalizer(nn.Module): 32 | """Normalize input to have zero mean and unit variance.""" 33 | 34 | mean: torch.Tensor 35 | var: torch.Tensor 36 | count: torch.LongTensor 37 | normalize_function: NormalizeFunction 38 | 39 | def __init__( 40 | self, 41 | normalize_function: NormalizeFunction, 42 | shape: tuple[int, ...], 43 | device: torch.device | str | None = None, 44 | ) -> None: 45 | """Initialize.""" 46 | super().__init__() 47 | if normalize_function not in {'affine', 'scale', 'translate', 'identity'}: 48 | raise ValueError( 49 | f'Invalid normalization function type: {normalize_function}. ', 50 | 'Expected one of "affine", "scale", "translate", "identity".', 51 | ) 52 | self.normalize_function = normalize_function 53 | self.register_buffer('mean', torch.zeros(shape, device=device)) 54 | self.register_buffer('var', torch.ones(shape, device=device)) 55 | self.register_buffer('count', torch.zeros(1, dtype=torch.long, device=device)) 56 | 57 | @abstractmethod 58 | def update(self, data: torch.Tensor) -> None: 59 | """Update mean and variance.""" 60 | raise NotImplementedError 61 | 62 | @property 63 | def std(self) -> torch.Tensor: 64 | """Return standard deviation.""" 65 | return self.var.sqrt() 66 | 67 | def set_mean_var( 68 | self, 69 | mean: torch.Tensor | list[float] | tuple[float, ...] | None, 70 | var: torch.Tensor | list[float] | tuple[float, ...] | None, 71 | ) -> None: 72 | """Set mean and variance.""" 73 | mean = ( 74 | torch.as_tensor(mean, dtype=self.mean.dtype, device=self.mean.device) 75 | if mean is not None 76 | else self.mean 77 | ) 78 | var = ( 79 | torch.as_tensor(var, dtype=self.var.dtype, device=self.var.device) 80 | if var is not None 81 | else self.var 82 | ) 83 | 84 | assert mean.shape == self.mean.shape 85 | assert var.shape == self.var.shape 86 | 87 | self.mean = mean 88 | self.var = var 89 | 90 | def forward( 91 | self, 92 | data: torch.Tensor, 93 | epsilon: Number = 1e-8, 94 | ) -> torch.Tensor: 95 | """Update and normalize input.""" 96 | if self.training: 97 | self.update(data) 98 | return self.normalize(data, epsilon=epsilon) 99 | 100 | def normalize( 101 | self, 102 | data: torch.Tensor, 103 | epsilon: Number = 1e-8, 104 | ) -> torch.Tensor: 105 | """Normalize input.""" 106 | if self.normalize_function == 'affine': 107 | return (data - self.mean.detach()) / (self.std.detach() + epsilon) 108 | if self.normalize_function == 'scale': 109 | return data / (self.std.detach() + epsilon) 110 | if self.normalize_function == 'translate': 111 | return data - self.mean.detach() 112 | if self.normalize_function == 'identity': 113 | return data 114 | raise ValueError( 115 | f'Invalid normalization function type: {self.normalize_function}. 
'
116 |             'Expected one of "affine", "scale", "translate", "identity".',
117 |         )
118 | 
119 |     @classmethod
120 |     def instantiate(
121 |         cls,
122 |         normalizer_type: NormalizerType | None,
123 |         normalize_function: NormalizeFunction,
124 |         shape: tuple[int, ...],
125 |         device: torch.device | str | None = None,
126 |         **kwargs: Any,
127 |     ) -> Normalizer:
128 |         """Get a normalizer."""
129 |         if normalizer_type == 'RunningMeanStd':
130 |             return RunningMeanStd(
131 |                 normalize_function,
132 |                 shape=shape,
133 |                 device=device,
134 |             )
135 |         if normalizer_type == 'ExponentialMovingAverage':
136 |             return ExponentialMovingAverage(
137 |                 normalize_function,
138 |                 shape=shape,
139 |                 device=device,
140 |                 **kwargs,
141 |             )
142 |         if normalizer_type is None:
143 |             return IdentityNormalizer(
144 |                 normalize_function,
145 |                 shape=shape,
146 |                 device=device,
147 |             )
148 |         raise ValueError(
149 |             f'Invalid normalizer type: {normalizer_type}. '
150 |             'Expected one of "RunningMeanStd", "ExponentialMovingAverage", or None.',
151 |         )
152 | 
153 | 
154 | class RunningMeanStd(Normalizer):
155 |     """Running mean and standard deviation."""
156 | 
157 |     def update(self, data: torch.Tensor) -> None:
158 |         """Update mean and variance."""
159 |         batch_mean = data.mean(dim=0)
160 |         batch_var = data.var(dim=0)
161 |         batch_count = data.size(0)
162 | 
163 |         delta = batch_mean - self.mean
164 |         total_count = self.count + batch_count
165 | 
166 |         new_mean = self.mean + delta * batch_count / total_count
167 |         m_a = self.var * self.count
168 |         m_b = batch_var * batch_count
169 |         m2 = (  # pylint: disable=invalid-name
170 |             m_a + m_b + torch.square(delta) * (self.count * batch_count / total_count)
171 |         )
172 |         new_var = m2 / total_count
173 | 
174 |         self.mean = new_mean
175 |         self.var = new_var
176 |         self.count = total_count
177 | 
178 | 
179 | class ExponentialMovingAverage(Normalizer):
180 |     """Exponential moving average."""
181 | 
182 |     def __init__(
183 |         self,
184 |         normalize_function: NormalizeFunction,
185 |         shape: tuple[int, ...],
186 |         device: torch.device | str | None = None,
187 |         momentum: float = 0.9,
188 |     ) -> None:
189 |         super().__init__(normalize_function, shape=shape, device=device)
190 |         self.momentum = momentum
191 | 
192 |     def update(self, data: torch.Tensor) -> None:
193 |         """Update mean and variance."""
194 |         batch_mean = data.mean(dim=0)
195 |         batch_var = data.var(dim=0)
196 |         batch_count = data.size(0)
197 | 
198 |         self.mean = self.momentum * self.mean + (1.0 - self.momentum) * batch_mean
199 |         self.var = self.momentum * self.var + (1.0 - self.momentum) * batch_var
200 |         self.count += batch_count  # pylint: disable=no-member
201 | 
202 | 
203 | class IdentityNormalizer(Normalizer):
204 |     """Identity normalizer."""
205 | 
206 |     def update(self, data: torch.Tensor) -> None:
207 |         """Update mean and variance."""
208 |         self.count += data.size(0)  # pylint: disable=no-member
209 | 
--------------------------------------------------------------------------------
/mcts_rl/models/score_model/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 PKU-Alignment Team. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
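# [Usage sketch] Instantiating and applying a normalizer from normalizer.py
# above (values illustrative):
import torch

from mcts_rl.models.normalizer import Normalizer

normalizer = Normalizer.instantiate(
    'ExponentialMovingAverage',  # or 'RunningMeanStd'; None -> IdentityNormalizer
    normalize_function='affine',
    shape=(1,),
    momentum=0.9,  # forwarded to ExponentialMovingAverage via **kwargs
)
batch = torch.randn(8, 1)
normalizer.update(batch)                  # EMA update of mean/var
normalized = normalizer.normalize(batch)  # (batch - mean) / (std + epsilon)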
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Auto-models for score models.""" 16 | 17 | from __future__ import annotations 18 | 19 | import functools 20 | import importlib 21 | from collections import OrderedDict 22 | from dataclasses import dataclass 23 | from typing import Any 24 | 25 | import torch 26 | import torch.nn as nn 27 | import transformers.models.auto as auto_module 28 | from torch import distributed as dist 29 | from transformers import PretrainedConfig 30 | from transformers.models.auto.auto_factory import ( 31 | _BaseAutoModelClass, 32 | _LazyAutoMapping, 33 | auto_class_update, 34 | getattribute_from_module, 35 | ) 36 | from transformers.models.auto.configuration_auto import ( 37 | CONFIG_MAPPING_NAMES, 38 | model_type_to_module_name, 39 | ) 40 | from transformers.utils.generic import ModelOutput 41 | 42 | from mcts_rl.models.normalizer import NormalizeFunction, Normalizer 43 | 44 | 45 | class _LazyAutoMappingInSafeRLHF(_LazyAutoMapping): 46 | def _load_attr_from_module(self, model_type: str, attr: str) -> Any: 47 | module_name = model_type_to_module_name(model_type) 48 | if module_name not in self._modules: 49 | self._modules[module_name] = importlib.import_module( 50 | f'.{module_name}', 51 | 'mcts_rl.models.score_model', 52 | ) 53 | return getattribute_from_module(self._modules[module_name], attr) 54 | 55 | 56 | MODEL_FOR_SCORE_MAPPING_NAMES: OrderedDict[str, str] = OrderedDict( 57 | [ 58 | # Score model mapping 59 | ('llama', 'LlamaModelForScore'), 60 | ('bloom', 'BloomModelForScore'), 61 | ('open_llama', 'OpenLlamaForScore'), 62 | ('opt', 'OPTForScore'), 63 | ('gpt_neo', 'GPTNeoForScore'), 64 | ('gptj', 'GPTJForScore'), 65 | ('gpt2', 'GPT2ForScore'), 66 | ('gpt_neox', 'GPTNeoXForScore'), 67 | ('mistral', 'MistralModelForScore'), 68 | ], 69 | ) 70 | MODEL_FOR_SCORE_MAPPING: OrderedDict[str, Any] = _LazyAutoMappingInSafeRLHF( 71 | CONFIG_MAPPING_NAMES, 72 | MODEL_FOR_SCORE_MAPPING_NAMES, 73 | ) 74 | 75 | 76 | @functools.partial(auto_class_update, head_doc='score model') 77 | class AutoModelForScore(_BaseAutoModelClass): 78 | _model_mapping: OrderedDict[str, Any] = MODEL_FOR_SCORE_MAPPING 79 | 80 | 81 | setattr(auto_module, 'MODEL_FOR_SCORE_MAPPING', MODEL_FOR_SCORE_MAPPING) # noqa: B010 82 | setattr(auto_module, AutoModelForScore.__name__, AutoModelForScore) 83 | 84 | 85 | @dataclass 86 | class ScoreModelOutput(ModelOutput): 87 | """ 88 | Output of the score model. 89 | 90 | Args: 91 | scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, score_dim, sequence_length)`): 92 | Prediction scores of the score model. 93 | end_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, score_dim)`): 94 | Prediction scores of the end of the sequence. 
95 |     """
96 | 
97 |     scores: torch.Tensor | None = None  # size = (B, L, D)
98 |     end_scores: torch.Tensor | None = None  # size = (B, D)
99 | 
100 | 
101 | class ScoreModelMixin:
102 |     """Base class for score models."""
103 | 
104 |     score_head: nn.Linear
105 |     normalizer: Normalizer
106 |     do_normalize: bool = False
107 |     normalize_function: NormalizeFunction = 'affine'
108 |     _initialized: bool = False
109 | 
110 |     def init_score_head(self, config: PretrainedConfig, hidden_size: int, **kwargs: Any) -> None:
111 |         """Initialize the score head."""
112 |         if self._initialized:
113 |             return
114 | 
115 |         config.score_dim = kwargs.pop('score_dim', getattr(config, 'score_dim', 1))
116 |         config.bias = kwargs.pop('bias', getattr(config, 'bias', False))
117 | 
118 |         config.score_type = kwargs.pop('score_type', getattr(config, 'score_type', 'reward'))
119 |         if config.score_type == 'reward':
120 |             self.normalize_function = 'affine'
121 |         elif config.score_type == 'cost':
122 |             self.normalize_function = 'scale'
123 |         elif config.score_type == 'critic':
124 |             self.normalize_function = 'identity'
125 |         else:
126 |             raise ValueError(
127 |                 f"Invalid score type: {config.score_type}. Expected one of 'reward', 'cost', or 'critic'.",
128 |             )
129 | 
130 |         config.do_normalize = kwargs.pop(
131 |             'do_normalize',
132 |             getattr(config, 'do_normalize', False),
133 |         )
134 |         self.do_normalize = config.do_normalize
135 | 
136 |         config.normalizer_type = kwargs.pop(
137 |             'normalizer_type',
138 |             getattr(config, 'normalizer_type', None),
139 |         )
140 |         if config.normalizer_type not in {'RunningMeanStd', 'ExponentialMovingAverage', None}:
141 |             raise ValueError(
142 |                 f'Invalid normalizer type: {config.normalizer_type}. '
143 |                 "Expected one of 'RunningMeanStd', 'ExponentialMovingAverage', or None.",
144 |             )
145 |         if config.normalizer_type == 'ExponentialMovingAverage':
146 |             config.momentum = kwargs.pop('momentum', getattr(config, 'momentum', None))
147 |         momentum = getattr(config, 'momentum', None)
148 | 
149 |         self.score_head = nn.Linear(hidden_size, config.score_dim, bias=config.bias)
150 |         self.normalizer = Normalizer.instantiate(
151 |             normalizer_type=config.normalizer_type,
152 |             normalize_function=self.normalize_function,
153 |             shape=(config.score_dim,),
154 |             momentum=momentum,
155 |         )
156 | 
157 |         mean = getattr(config, 'mean', None)
158 |         var = getattr(config, 'var', None)
159 |         self.normalizer.set_mean_var(mean, var)
160 | 
161 |         self._initialized = True
162 | 
163 |     def get_score(
164 |         self,
165 |         hidden_state: torch.Tensor,  # size = (B, L, E)
166 |         attention_mask: torch.BoolTensor,  # size = (B, L)
167 |         return_dict: bool | None = None,
168 |     ) -> ScoreModelOutput:
169 |         """Forward pass of the score model."""
170 |         scores = self.score_head(hidden_state)  # size = (B, L, D)
171 | 
172 |         end_score = []
173 |         for i in range(hidden_state.size(0)):
174 |             end_index = attention_mask[i].nonzero()[-1].item()
175 |             end_score.append(scores[i, end_index])  # size = (D,)
176 |         end_score = torch.stack(end_score, dim=0)  # size = (B, D)
177 | 
178 |         if self.training:
179 |             if dist.is_initialized():
180 |                 gathered_end_score_list = [
181 |                     torch.zeros_like(end_score) for _ in range(dist.get_world_size())
182 |                 ]
183 |                 dist.all_gather(gathered_end_score_list, end_score)
184 |                 gathered_end_score = torch.cat(gathered_end_score_list, dim=0)
185 |                 self.normalizer.update(gathered_end_score)
186 |             else:
187 |                 self.normalizer.update(end_score)
188 |             self.config.mean = self.normalizer.mean.tolist()
189 |             self.config.var = self.normalizer.var.tolist()
190 | 
191 |         if self.do_normalize:
192 |             scores = self.normalizer.normalize(scores)
193 |             end_score = self.normalizer.normalize(end_score)
194 | 
195 |         if not return_dict:
196 |             return scores, end_score
197 | 
198 |         return ScoreModelOutput(
199 |             scores=scores,  # size = (B, L, D)
200 |             end_scores=end_score,  # size = (B, D)
201 |         )
202 | 
203 |     def set_normalize(self, mode: bool = True) -> None:
204 |         if self.do_normalize == mode:
205 |             return
206 | 
207 |         self.do_normalize = self.config.do_normalize = mode
208 | 
--------------------------------------------------------------------------------
/mcts_rl/models/score_model/bloom/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 PKU-Alignment Team. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | 
16 | from mcts_rl.models.score_model.bloom.modeling_bloom import BloomModelForScore
17 | 
--------------------------------------------------------------------------------
/mcts_rl/models/score_model/bloom/modeling_bloom.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 PKU-Alignment Team. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
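# [Usage sketch] How ScoreModelMixin.get_score above locates the last
# non-padding position in each row of the attention mask:
import torch

attention_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]], dtype=torch.bool)
end_index = attention_mask[0].nonzero()[-1].item()  # -> 2, last real token of row 0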
14 | # ============================================================================== 15 | 16 | from __future__ import annotations 17 | 18 | import warnings 19 | from typing import Any, ClassVar 20 | 21 | import torch 22 | from transformers import BloomConfig, BloomModel, BloomPreTrainedModel, PreTrainedModel 23 | from transformers.models.bloom.modeling_bloom import ( 24 | _CHECKPOINT_FOR_DOC, 25 | _CONFIG_FOR_DOC, 26 | BLOOM_INPUTS_DOCSTRING, 27 | ) 28 | from transformers.utils.doc import add_code_sample_docstrings, add_start_docstrings_to_model_forward 29 | 30 | from mcts_rl.models.score_model import ScoreModelMixin, ScoreModelOutput 31 | 32 | 33 | class BloomModelForScore(ScoreModelMixin, BloomPreTrainedModel): 34 | _keys_to_ignore_on_load_missing: ClassVar[list[str]] = [ 35 | 'h.*.self_attention.scale_mask_softmax.causal_mask', 36 | 'lm_head.weight', 37 | ] 38 | 39 | def __init__(self, config: BloomConfig, **kwargs: Any) -> None: 40 | super().__init__(config) 41 | self.transformer = BloomModel(config) 42 | 43 | config.architectures = [self.__class__.__name__] 44 | self.init_score_head(config, hidden_size=config.hidden_size, **kwargs) 45 | 46 | # Initialize weights and apply final processing 47 | self.post_init() 48 | 49 | def get_output_embeddings(self) -> None: 50 | return None 51 | 52 | def set_decoder(self, decoder: PreTrainedModel) -> None: 53 | self.transformer = decoder 54 | 55 | def get_decoder(self) -> PreTrainedModel: 56 | return self.transformer 57 | 58 | @add_start_docstrings_to_model_forward(BLOOM_INPUTS_DOCSTRING) 59 | @add_code_sample_docstrings( 60 | checkpoint=_CHECKPOINT_FOR_DOC, 61 | output_type=ScoreModelOutput, 62 | config_class=_CONFIG_FOR_DOC, 63 | ) 64 | def forward( # pylint: disable=too-many-arguments 65 | self, 66 | input_ids: torch.LongTensor | None = None, 67 | past_key_values: tuple[tuple[torch.Tensor, torch.Tensor], ...] | None = None, 68 | attention_mask: torch.Tensor | None = None, 69 | head_mask: torch.Tensor | None = None, 70 | inputs_embeds: torch.Tensor | None = None, 71 | use_cache: bool | None = None, 72 | output_attentions: bool | None = None, 73 | output_hidden_states: bool | None = None, 74 | return_dict: bool | None = None, 75 | **deprecated_arguments: dict[str, Any], 76 | ) -> tuple[torch.Tensor, torch.Tensor] | ScoreModelOutput: 77 | """ 78 | Args: 79 | 80 | Returns: 81 | 82 | Examples: 83 | 84 | ```python 85 | >>> from safe_rlhf.models.llama.modeling_llama import LlamaModelForScore 86 | >>> from transformers import LlamaTokenizer 87 | 88 | >>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) 89 | >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) 90 | 91 | >>> prompt = "Hey, are you conscious? Can you talk to me?" 92 | >>> inputs = tokenizer(prompt, return_tensors="pt") 93 | 94 | # got score 95 | >>> outputs = model(**inputs) 96 | >>> scores = outputs.scores 97 | >>> scores 98 | tensor([[[0.0000]]]) 99 | ``` 100 | """ 101 | assert attention_mask is not None 102 | if deprecated_arguments.pop('position_ids', False) is not False: 103 | # `position_ids` could have been `torch.Tensor` or `None` 104 | # so defaulting pop to `False` allows to detect if users were passing explicitly `None` 105 | warnings.warn( 106 | '`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. 
' 107 | 'You can safely ignore passing `position_ids`.', 108 | FutureWarning, 109 | stacklevel=1, 110 | ) 111 | if len(deprecated_arguments) > 0: 112 | raise ValueError(f'Got unexpected arguments: {deprecated_arguments}') 113 | 114 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 115 | 116 | transformer_outputs = self.transformer( 117 | input_ids, 118 | past_key_values=past_key_values, 119 | attention_mask=attention_mask, 120 | head_mask=head_mask, 121 | inputs_embeds=inputs_embeds, 122 | use_cache=use_cache, 123 | output_attentions=output_attentions, 124 | output_hidden_states=output_hidden_states, 125 | return_dict=return_dict, 126 | ) 127 | hidden_states = transformer_outputs[0] # size = (B, L, E) 128 | return self.get_score( 129 | hidden_states, 130 | attention_mask=attention_mask, 131 | return_dict=return_dict, 132 | ) 133 | -------------------------------------------------------------------------------- /mcts_rl/models/score_model/gpt2/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 PKU-Alignment Team. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | from mcts_rl.models.score_model.gpt2.modeling_gpt2 import GPT2ForScore 17 | -------------------------------------------------------------------------------- /mcts_rl/models/score_model/gpt2/modeling_gpt2.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 PKU-Alignment Team. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
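# [Usage sketch] Running a full forward pass through the Bloom score model
# above; the checkpoint path is hypothetical:
from transformers import AutoTokenizer

from mcts_rl.models.score_model.bloom.modeling_bloom import BloomModelForScore

model = BloomModelForScore.from_pretrained('path/to/bloom-reward-model')  # hypothetical
tokenizer = AutoTokenizer.from_pretrained('path/to/bloom-reward-model')   # hypothetical
inputs = tokenizer('Hey, are you conscious?', return_tensors='pt')
outputs = model(**inputs, return_dict=True)
scores, end_scores = outputs.scores, outputs.end_scores  # (B, L, D), (B, D)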
14 | # ============================================================================== 15 | 16 | from __future__ import annotations 17 | 18 | import warnings 19 | from typing import Any, ClassVar 20 | 21 | import torch 22 | from transformers import GPT2Model, GPT2PreTrainedModel, PreTrainedModel 23 | from transformers.configuration_utils import PretrainedConfig 24 | from transformers.models.gpt2.modeling_gpt2 import ( 25 | DEPARALLELIZE_DOCSTRING, 26 | GPT2_INPUTS_DOCSTRING, 27 | GPT2_START_DOCSTRING, 28 | PARALLELIZE_DOCSTRING, 29 | ) 30 | from transformers.utils.doc import add_start_docstrings, add_start_docstrings_to_model_forward 31 | from transformers.utils.model_parallel_utils import assert_device_map, get_device_map 32 | 33 | from mcts_rl.models.score_model import ScoreModelMixin, ScoreModelOutput 34 | 35 | 36 | @add_start_docstrings( 37 | """ 38 | The GPT2 Model transformer with a score head on top (linear layer with weights tied to the input 39 | embeddings). 40 | """, 41 | GPT2_START_DOCSTRING, 42 | ) 43 | class GPT2ForScore(ScoreModelMixin, GPT2PreTrainedModel): 44 | _keys_to_ignore_on_load_missing: ClassVar[list[str]] = [ 45 | 'attn.masked_bias', 46 | 'attn.bias', 47 | 'lm_head.weight', 48 | ] 49 | 50 | def __init__(self, config: PretrainedConfig, **kwargs: Any) -> None: 51 | super().__init__(config) 52 | self.transformer = GPT2Model(config) 53 | 54 | config.architectures = [self.__class__.__name__] 55 | self.init_score_head(config, hidden_size=config.hidden_size, **kwargs) 56 | 57 | # Model parallel 58 | self.model_parallel = False 59 | self.device_map = None 60 | 61 | # Initialize weights and apply final processing 62 | self.post_init() 63 | 64 | @add_start_docstrings(PARALLELIZE_DOCSTRING) 65 | def parallelize(self, device_map: str | None = None) -> None: 66 | warnings.warn( 67 | '`GPT2LMHeadModel.parallelize` is deprecated and will be removed in v5 of Transformers, you should load' 68 | " your model with `device_map='balanced'` in the call to `from_pretrained`. 
You can also provide your own" 69 | " `device_map` but it needs to be a dictionary module_name to device, so for instance {'transformer.h.0':" 70 | " 0, 'transformer.h.1': 1, ...}", 71 | FutureWarning, 72 | stacklevel=1, 73 | ) 74 | self.device_map = ( 75 | get_device_map(len(self.transformer.h), range(torch.cuda.device_count())) 76 | if device_map is None 77 | else device_map 78 | ) 79 | assert_device_map(self.device_map, len(self.transformer.h)) 80 | self.transformer.parallelize(self.device_map) 81 | self.score_head = self.score_head.to(self.transformer.first_device) 82 | self.model_parallel = True 83 | 84 | @add_start_docstrings(DEPARALLELIZE_DOCSTRING) 85 | def deparallelize(self) -> None: 86 | warnings.warn( 87 | 'Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.', 88 | FutureWarning, 89 | stacklevel=1, 90 | ) 91 | self.transformer.deparallelize() 92 | self.transformer = self.transformer.to('cpu') 93 | self.score_head = self.score_head.to('cpu') 94 | self.model_parallel = False 95 | torch.cuda.empty_cache() 96 | 97 | def get_output_embeddings(self) -> None: 98 | return None 99 | 100 | def set_decoder(self, decoder: PreTrainedModel) -> None: 101 | self.transformer = decoder 102 | 103 | def get_decoder(self) -> PreTrainedModel: 104 | return self.transformer 105 | 106 | @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING) 107 | def forward( # pylint: disable=too-many-arguments 108 | self, 109 | input_ids: torch.LongTensor | None = None, 110 | attention_mask: torch.FloatTensor | None = None, 111 | past_key_values: tuple[tuple[torch.Tensor]] | None = None, 112 | token_type_ids: torch.LongTensor | None = None, 113 | position_ids: torch.LongTensor | None = None, 114 | head_mask: torch.FloatTensor | None = None, 115 | inputs_embeds: torch.FloatTensor | None = None, 116 | encoder_hidden_states: torch.Tensor | None = None, 117 | encoder_attention_mask: torch.FloatTensor | None = None, 118 | use_cache: bool | None = None, 119 | output_attentions: bool | None = None, 120 | output_hidden_states: bool | None = None, 121 | return_dict: bool | None = None, 122 | ) -> tuple[torch.Tensor, torch.Tensor] | ScoreModelOutput: 123 | """ 124 | Args: 125 | 126 | Returns: 127 | 128 | Examples: 129 | 130 | ```python 131 | >>> from mcts_rl.models.llama.modeling_llama import LlamaModelForScore 132 | >>> from transformers import LlamaTokenizer 133 | 134 | >>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) 135 | >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) 136 | 137 | >>> prompt = "Hey, are you conscious? Can you talk to me?" 
138 | >>> inputs = tokenizer(prompt, return_tensors="pt") 139 | 140 | # got score 141 | >>> outputs = model(**inputs) 142 | >>> scores = outputs.scores 143 | >>> scores 144 | tensor([[[0.0000]]]) 145 | ``` 146 | """ 147 | assert attention_mask is not None 148 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 149 | 150 | transformer_outputs = self.transformer( 151 | input_ids, 152 | past_key_values=past_key_values, 153 | attention_mask=attention_mask, 154 | token_type_ids=token_type_ids, 155 | position_ids=position_ids, 156 | head_mask=head_mask, 157 | inputs_embeds=inputs_embeds, 158 | encoder_hidden_states=encoder_hidden_states, 159 | encoder_attention_mask=encoder_attention_mask, 160 | use_cache=use_cache, 161 | output_attentions=output_attentions, 162 | output_hidden_states=output_hidden_states, 163 | return_dict=return_dict, 164 | ) 165 | hidden_states = transformer_outputs[0] # size = (B, L, E) 166 | 167 | # Set device for model parallelism 168 | if self.model_parallel: 169 | torch.cuda.set_device(self.transformer.first_device) 170 | hidden_states = hidden_states.to(self.lm_head.weight.device) 171 | 172 | return self.get_score( 173 | hidden_states, 174 | attention_mask=attention_mask, 175 | return_dict=return_dict, 176 | ) 177 | -------------------------------------------------------------------------------- /mcts_rl/models/score_model/gpt_neo/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 PKU-Alignment Team. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | from mcts_rl.models.score_model.gpt_neo.modeling_gpt_neo import GPTNeoForScore 17 | -------------------------------------------------------------------------------- /mcts_rl/models/score_model/gpt_neo/modeling_gpt_neo.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 PKU-Alignment Team. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
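# [Usage sketch] With `return_dict=False`, get_score returns a plain
# (scores, end_scores) tuple instead of a ScoreModelOutput; this reuses the
# hypothetical `model` and `inputs` from the earlier sketch:
scores, end_scores = model(**inputs, return_dict=False)
assert scores.dim() == 3 and end_scores.dim() == 2  # (B, L, D) and (B, D)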
14 | # ============================================================================== 15 | 16 | from __future__ import annotations 17 | 18 | from typing import Any, ClassVar 19 | 20 | import torch 21 | from transformers import GPTNeoModel, GPTNeoPreTrainedModel, PretrainedConfig, PreTrainedModel 22 | from transformers.models.gpt_neo.modeling_gpt_neo import ( 23 | GPT_NEO_INPUTS_DOCSTRING, 24 | GPT_NEO_START_DOCSTRING, 25 | ) 26 | from transformers.utils.doc import add_start_docstrings, add_start_docstrings_to_model_forward 27 | 28 | from mcts_rl.models.score_model import ScoreModelMixin, ScoreModelOutput 29 | 30 | 31 | @add_start_docstrings( 32 | """ 33 | The GPT Neo Model transformer with a language modeling head on top (linear layer with weights tied to the input 34 | embeddings). 35 | """, 36 | GPT_NEO_START_DOCSTRING, 37 | ) 38 | class GPTNeoForScore(ScoreModelMixin, GPTNeoPreTrainedModel): 39 | _keys_to_ignore_on_load_missing: ClassVar[list[str]] = [ 40 | r'h\.\d+\.attn\.masked_bias', 41 | r'lm_head.weight', 42 | r'h\.\d+\.attn\.attention\.bias', 43 | ] 44 | _keys_to_ignore_on_save: ClassVar[list[str]] = [r'lm_head.weight'] 45 | 46 | def __init__(self, config: PretrainedConfig, **kwargs: Any) -> None: 47 | super().__init__(config) 48 | self.transformer = GPTNeoModel(config) 49 | 50 | config.architectures = [self.__class__.__name__] 51 | self.init_score_head(config, hidden_size=config.hidden_size, **kwargs) 52 | 53 | # Initialize weights and apply final processing 54 | self.post_init() 55 | 56 | def get_output_embeddings(self) -> None: 57 | return None 58 | 59 | def set_decoder(self, decoder: PreTrainedModel) -> None: 60 | self.transformer = decoder 61 | 62 | def get_decoder(self) -> PreTrainedModel: 63 | return self.transformer 64 | 65 | @add_start_docstrings_to_model_forward(GPT_NEO_INPUTS_DOCSTRING) 66 | def forward( # pylint: disable=too-many-arguments 67 | self, 68 | input_ids: torch.Tensor | None = None, 69 | past_key_values: tuple[torch.FloatTensor] | None = None, 70 | attention_mask: torch.Tensor | None = None, 71 | token_type_ids: torch.Tensor | None = None, 72 | position_ids: torch.Tensor | None = None, 73 | head_mask: torch.Tensor | None = None, 74 | inputs_embeds: torch.Tensor | None = None, 75 | use_cache: bool | None = None, 76 | output_attentions: bool | None = None, 77 | output_hidden_states: bool | None = None, 78 | return_dict: bool | None = None, 79 | ) -> tuple[torch.Tensor, torch.Tensor] | ScoreModelOutput: 80 | r""" 81 | Args: 82 | 83 | Returns: 84 | 85 | Examples: 86 | 87 | ```python 88 | >>> from safe_rlhf.models.llama.modeling_llama import LlamaModelForScore 89 | >>> from transformers import LlamaTokenizer 90 | 91 | >>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) 92 | >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) 93 | 94 | >>> prompt = "Hey, are you conscious? Can you talk to me?" 
95 | >>> inputs = tokenizer(prompt, return_tensors="pt") 96 | 97 | # got score 98 | >>> outputs = model(**inputs) 99 | >>> scores = outputs.scores 100 | >>> scores 101 | tensor([[[0.0000]]]) 102 | ``` 103 | """ 104 | assert attention_mask is not None 105 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 106 | 107 | outputs = self.transformer( 108 | input_ids, 109 | past_key_values=past_key_values, 110 | attention_mask=attention_mask, 111 | token_type_ids=token_type_ids, 112 | position_ids=position_ids, 113 | head_mask=head_mask, 114 | inputs_embeds=inputs_embeds, 115 | use_cache=use_cache, 116 | output_attentions=output_attentions, 117 | output_hidden_states=output_hidden_states, 118 | return_dict=return_dict, 119 | ) 120 | hidden_states = outputs[0] # size = (B, L, E) 121 | return self.get_score( 122 | hidden_states, 123 | attention_mask=attention_mask, 124 | return_dict=return_dict, 125 | ) 126 | -------------------------------------------------------------------------------- /mcts_rl/models/score_model/gpt_neox/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 PKU-Alignment Team. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | from mcts_rl.models.score_model.gpt_neox.modeling_gpt_neox import GPTNeoXForScore 17 | -------------------------------------------------------------------------------- /mcts_rl/models/score_model/gpt_neox/modeling_gpt_neox.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 PKU-Alignment Team. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
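# [Usage sketch] Toggling score normalization on any *ForScore model via
# ScoreModelMixin.set_normalize defined earlier (`model` is hypothetical):
model.set_normalize(True)   # normalize scores with the stored mean/var
assert model.do_normalize and model.config.do_normalize
model.set_normalize(False)  # back to raw score-head outputs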
14 | # ==============================================================================
15 | 
16 | from __future__ import annotations
17 | 
18 | from typing import Any, ClassVar
19 | 
20 | import torch
21 | from transformers import GPTNeoXModel, GPTNeoXPreTrainedModel, PreTrainedModel
22 | from transformers.configuration_utils import PretrainedConfig
23 | from transformers.models.gpt_neox.modeling_gpt_neox import (
24 |     _CONFIG_FOR_DOC,
25 |     GPT_NEOX_INPUTS_DOCSTRING,
26 | )
27 | from transformers.utils.doc import add_start_docstrings_to_model_forward, replace_return_docstrings
28 | 
29 | from mcts_rl.models.score_model import ScoreModelMixin, ScoreModelOutput
30 | 
31 | 
32 | class GPTNeoXForScore(ScoreModelMixin, GPTNeoXPreTrainedModel):
33 |     _keys_to_ignore_on_load_missing: ClassVar[list[str]] = [
34 |         r'position_ids',
35 |         r'predictions.decoder.bias',
36 |     ]
37 | 
38 |     def __init__(self, config: PretrainedConfig, **kwargs: Any) -> None:
39 |         super().__init__(config)
40 |         self.gpt_neox = GPTNeoXModel(config)
41 | 
42 |         config.architectures = [self.__class__.__name__]
43 |         self.init_score_head(config, hidden_size=config.hidden_size, **kwargs)
44 | 
45 |         # Initialize weights and apply final processing
46 |         self.post_init()
47 | 
48 |     def get_output_embeddings(self) -> None:
49 |         return None
50 | 
51 |     def set_decoder(self, decoder: PreTrainedModel) -> None:
52 |         self.gpt_neox = decoder
53 | 
54 |     def get_decoder(self) -> PreTrainedModel:
55 |         return self.gpt_neox
56 | 
57 |     @add_start_docstrings_to_model_forward(
58 |         GPT_NEOX_INPUTS_DOCSTRING.format('batch_size, sequence_length'),
59 |     )
60 |     @replace_return_docstrings(output_type=ScoreModelOutput, config_class=_CONFIG_FOR_DOC)
61 |     def forward(  # pylint: disable=too-many-arguments
62 |         self,
63 |         input_ids: torch.LongTensor,
64 |         attention_mask: torch.Tensor,
65 |         position_ids: torch.LongTensor | None = None,
66 |         inputs_embeds: torch.FloatTensor | None = None,
67 |         head_mask: torch.FloatTensor | None = None,
68 |         past_key_values: tuple[tuple[torch.FloatTensor]] | None = None,
69 |         use_cache: bool | None = None,
70 |         output_attentions: bool | None = None,
71 |         output_hidden_states: bool | None = None,
72 |         return_dict: bool | None = None,
73 |     ) -> tuple[torch.Tensor, torch.Tensor] | ScoreModelOutput:
74 |         """
75 |         Args:
76 | 
77 |         Returns:
78 | 
79 |         Examples:
80 | 
81 |         ```python
82 |         >>> from mcts_rl.models.score_model.gpt_neox.modeling_gpt_neox import GPTNeoXForScore
83 |         >>> from transformers import AutoTokenizer
84 | 
85 |         >>> model = GPTNeoXForScore.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
86 |         >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
87 | 
88 |         >>> prompt = "Hey, are you conscious? Can you talk to me?"
89 | >>> inputs = tokenizer(prompt, return_tensors="pt") 90 | 91 | # got score 92 | >>> outputs = model(**inputs) 93 | >>> scores = outputs.scores 94 | >>> scores 95 | tensor([[[0.0000]]]) 96 | ``` 97 | """ 98 | assert attention_mask is not None 99 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 100 | 101 | outputs = self.gpt_neox( 102 | input_ids, 103 | attention_mask=attention_mask, 104 | position_ids=position_ids, 105 | head_mask=head_mask, 106 | inputs_embeds=inputs_embeds, 107 | past_key_values=past_key_values, 108 | use_cache=use_cache, 109 | output_attentions=output_attentions, 110 | output_hidden_states=output_hidden_states, 111 | return_dict=return_dict, 112 | ) 113 | hidden_states = outputs[0] # size = (B, L, E) 114 | return self.get_score( 115 | hidden_states, 116 | attention_mask=attention_mask, 117 | return_dict=return_dict, 118 | ) 119 | -------------------------------------------------------------------------------- /mcts_rl/models/score_model/gptj/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 PKU-Alignment Team. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | from mcts_rl.models.score_model.gptj.modeling_gptj import GPTJForScore 17 | -------------------------------------------------------------------------------- /mcts_rl/models/score_model/gptj/modeling_gptj.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 PKU-Alignment Team. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
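# [Usage sketch] Loading through the auto-class registered in
# score_model/__init__.py; dispatch is keyed on `config.model_type` via
# MODEL_FOR_SCORE_MAPPING (checkpoint path hypothetical):
from mcts_rl.models import AutoModelForScore

reward_model = AutoModelForScore.from_pretrained('path/to/any-supported-backbone')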
14 | # ============================================================================== 15 | 16 | from __future__ import annotations 17 | 18 | import warnings 19 | from typing import Any, ClassVar 20 | 21 | import torch 22 | from transformers import GPTJModel, GPTJPreTrainedModel, PreTrainedModel 23 | from transformers.configuration_utils import PretrainedConfig 24 | from transformers.models.gptj.modeling_gptj import ( 25 | DEPARALLELIZE_DOCSTRING, 26 | GPTJ_INPUTS_DOCSTRING, 27 | GPTJ_START_DOCSTRING, 28 | PARALLELIZE_DOCSTRING, 29 | ) 30 | from transformers.utils.doc import add_start_docstrings, add_start_docstrings_to_model_forward 31 | from transformers.utils.model_parallel_utils import assert_device_map, get_device_map 32 | 33 | from mcts_rl.models.score_model import ScoreModelMixin, ScoreModelOutput 34 | 35 | 36 | @add_start_docstrings( 37 | """ 38 | The GPT-J Model transformer with a score head on top. 39 | """, 40 | GPTJ_START_DOCSTRING, 41 | ) 42 | class GPTJForScore(ScoreModelMixin, GPTJPreTrainedModel): 43 | _keys_to_ignore_on_load_missing: ClassVar[list[str]] = [ 44 | r'h\.\d+\.attn\.masked_bias', 45 | r'h\.\d+\.attn\.bias', 46 | ] 47 | 48 | def __init__(self, config: PretrainedConfig, **kwargs: Any) -> None: 49 | super().__init__(config) 50 | self.transformer = GPTJModel(config) 51 | 52 | config.architectures = [self.__class__.__name__] 53 | self.init_score_head(config, hidden_size=config.n_embd, **kwargs) 54 | 55 | # Initialize weights and apply final processing 56 | self.post_init() 57 | 58 | self.device_map: dict[Any, Any] 59 | self.model_parallel: bool 60 | 61 | @add_start_docstrings(PARALLELIZE_DOCSTRING) 62 | def parallelize(self, device_map: str | None = None) -> None: 63 | warnings.warn( 64 | '`GPTJForCausalLM.parallelize` is deprecated and will be removed in v5 of Transformers, you should load' 65 | " your model with `device_map='balanced'` in the call to `from_pretrained`. 
You can also provide your own" 66 | " `device_map` but it needs to be a dictionary module_name to device, so for instance {'transformer.h.0':" 67 | " 0, 'transformer.h.1': 1, ...}", 68 | FutureWarning, 69 | stacklevel=1, 70 | ) 71 | self.device_map = ( 72 | get_device_map(len(self.transformer.h), range(torch.cuda.device_count())) 73 | if device_map is None 74 | else device_map 75 | ) 76 | assert_device_map(self.device_map, len(self.transformer.h)) 77 | self.transformer.parallelize(self.device_map) 78 | self.score_head = self.score_head.to(self.transformer.first_device) 79 | self.model_parallel = True 80 | 81 | @add_start_docstrings(DEPARALLELIZE_DOCSTRING) 82 | def deparallelize(self) -> None: 83 | warnings.warn( 84 | 'Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.', 85 | FutureWarning, 86 | stacklevel=1, 87 | ) 88 | self.transformer.deparallelize() 89 | self.transformer = self.transformer.to('cpu') 90 | self.score_head = self.score_head.to('cpu') 91 | self.model_parallel = False 92 | torch.cuda.empty_cache() 93 | 94 | def get_output_embeddings(self) -> None: 95 | return None 96 | 97 | def set_decoder(self, decoder: PreTrainedModel) -> None: 98 | self.transformer = decoder 99 | 100 | def get_decoder(self) -> PreTrainedModel: 101 | return self.transformer 102 | 103 | @add_start_docstrings_to_model_forward( 104 | GPTJ_INPUTS_DOCSTRING.format('batch_size, sequence_length'), 105 | ) 106 | def forward( # pylint: disable=too-many-arguments 107 | self, 108 | input_ids: torch.LongTensor, 109 | attention_mask: torch.FloatTensor, 110 | past_key_values: tuple[tuple[torch.Tensor]] | None = None, 111 | token_type_ids: torch.LongTensor | None = None, 112 | position_ids: torch.LongTensor | None = None, 113 | head_mask: torch.FloatTensor | None = None, 114 | inputs_embeds: torch.FloatTensor | None = None, 115 | use_cache: bool | None = None, 116 | output_attentions: bool | None = None, 117 | output_hidden_states: bool | None = None, 118 | return_dict: bool | None = None, 119 | ) -> tuple[torch.Tensor, torch.Tensor] | ScoreModelOutput: 120 | """ 121 | Args: 122 | 123 | Returns: 124 | 125 | Examples: 126 | 127 | ```python 128 | >>> from mcts_rl.models.llama.modeling_llama import LlamaModelForScore 129 | >>> from transformers import LlamaTokenizer 130 | 131 | >>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) 132 | >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) 133 | 134 | >>> prompt = "Hey, are you conscious? Can you talk to me?" 
135 | >>> inputs = tokenizer(prompt, return_tensors="pt") 136 | 137 | # got score 138 | >>> outputs = model(**inputs) 139 | >>> scores = outputs.scores 140 | >>> scores 141 | tensor([[[0.0000]]]) 142 | ``` 143 | """ 144 | assert attention_mask is not None 145 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 146 | 147 | transformer_outputs = self.transformer( 148 | input_ids, 149 | past_key_values=past_key_values, 150 | attention_mask=attention_mask, 151 | token_type_ids=token_type_ids, 152 | position_ids=position_ids, 153 | head_mask=head_mask, 154 | inputs_embeds=inputs_embeds, 155 | use_cache=use_cache, 156 | output_attentions=output_attentions, 157 | output_hidden_states=output_hidden_states, 158 | return_dict=return_dict, 159 | ) 160 | hidden_states = transformer_outputs[0] # size = (B, L, E) 161 | 162 | # Set device for model parallelism 163 | if self.model_parallel: 164 | torch.cuda.set_device(self.transformer.first_device) 165 | hidden_states = hidden_states.to(self.lm_head.weight.device) 166 | 167 | return self.get_score( 168 | hidden_states, 169 | attention_mask=attention_mask, 170 | return_dict=return_dict, 171 | ) 172 | -------------------------------------------------------------------------------- /mcts_rl/models/score_model/llama/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 PKU-Alignment Team. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | from mcts_rl.models.score_model.llama.modeling_llama import LlamaModelForScore 17 | -------------------------------------------------------------------------------- /mcts_rl/models/score_model/llama/modeling_llama.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 PKU-Alignment Team. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
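# [Usage sketch] init_score_head (score_model/__init__.py) accepts keyword
# overrides that flow through from_pretrained(**kwargs); the values and the
# checkpoint are hypothetical, and LlamaModelForScore is defined in the file
# that follows:
from mcts_rl.models.score_model.llama.modeling_llama import LlamaModelForScore

model = LlamaModelForScore.from_pretrained(
    'path/to/llama-checkpoint',  # hypothetical
    score_type='reward',         # 'reward' | 'cost' | 'critic'
    normalizer_type='ExponentialMovingAverage',
    momentum=0.9,
    do_normalize=False,
)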
14 | # ============================================================================== 15 | 16 | from __future__ import annotations 17 | 18 | from typing import Any, ClassVar 19 | 20 | import torch 21 | import torch.nn as nn 22 | from transformers import LlamaModel, LlamaPreTrainedModel, PreTrainedModel 23 | from transformers.configuration_utils import PretrainedConfig 24 | from transformers.models.llama.modeling_llama import _CONFIG_FOR_DOC, LLAMA_INPUTS_DOCSTRING 25 | from transformers.utils.doc import add_start_docstrings_to_model_forward, replace_return_docstrings 26 | 27 | from mcts_rl.models.score_model import ScoreModelMixin, ScoreModelOutput 28 | 29 | 30 | class LlamaModelForScore(ScoreModelMixin, LlamaPreTrainedModel): 31 | _keys_to_ignore_on_load_missing: ClassVar[list[str]] = ['lm_head.weight'] 32 | 33 | def __init__(self, config: PretrainedConfig, **kwargs: Any) -> None: 34 | super().__init__(config) 35 | self.model = LlamaModel(config) 36 | 37 | config.architectures = [self.__class__.__name__] 38 | self.init_score_head(config, hidden_size=config.hidden_size, **kwargs) 39 | 40 | # Initialize weights and apply final processing 41 | self.post_init() 42 | 43 | def get_input_embeddings(self) -> nn.Embedding: 44 | return self.model.embed_tokens 45 | 46 | def set_input_embeddings(self, value: nn.Embedding) -> None: 47 | self.model.embed_tokens = value 48 | 49 | def get_output_embeddings(self) -> None: 50 | return None 51 | 52 | def set_decoder(self, decoder: PreTrainedModel) -> None: 53 | self.model = decoder 54 | 55 | def get_decoder(self) -> PreTrainedModel: 56 | return self.model 57 | 58 | @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) 59 | @replace_return_docstrings(output_type=ScoreModelOutput, config_class=_CONFIG_FOR_DOC) 60 | def forward( # pylint: disable=too-many-arguments 61 | self, 62 | input_ids: torch.LongTensor, 63 | attention_mask: torch.Tensor, 64 | position_ids: torch.LongTensor | None = None, 65 | past_key_values: list[torch.FloatTensor] | None = None, 66 | inputs_embeds: torch.FloatTensor | None = None, 67 | use_cache: bool | None = None, 68 | output_attentions: bool | None = None, 69 | output_hidden_states: bool | None = None, 70 | return_dict: bool | None = None, 71 | ) -> tuple[torch.Tensor, torch.Tensor] | ScoreModelOutput: 72 | """ 73 | Args: 74 | 75 | Returns: 76 | 77 | Examples: 78 | 79 | ```python 80 | >>> from mcts_rl.models.score_model.llama.modeling_llama import LlamaModelForScore 81 | >>> from transformers import AutoTokenizer 82 | 83 | >>> model = LlamaModelForScore.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) 84 | >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) 85 | 86 | >>> prompt = "Hey, are you conscious? Can you talk to me?"
87 | >>> inputs = tokenizer(prompt, return_tensors="pt") 88 | 89 | # get score 90 | >>> outputs = model(**inputs) 91 | >>> scores = outputs.scores 92 | >>> scores 93 | tensor([[[0.0000]]]) 94 | ``` 95 | """ 96 | assert attention_mask is not None 97 | output_attentions = ( 98 | output_attentions if output_attentions is not None else self.config.output_attentions 99 | ) 100 | output_hidden_states = ( 101 | output_hidden_states 102 | if output_hidden_states is not None 103 | else self.config.output_hidden_states 104 | ) 105 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 106 | 107 | outputs = self.model( 108 | input_ids=input_ids, 109 | attention_mask=attention_mask, 110 | position_ids=position_ids, 111 | past_key_values=past_key_values, 112 | inputs_embeds=inputs_embeds, 113 | use_cache=use_cache, 114 | output_attentions=output_attentions, 115 | output_hidden_states=output_hidden_states, 116 | return_dict=return_dict, 117 | ) 118 | hidden_states = outputs[0]  # size = (B, L, E) 119 | return self.get_score( 120 | hidden_states, 121 | attention_mask=attention_mask, 122 | return_dict=return_dict, 123 | ) 124 | -------------------------------------------------------------------------------- /mcts_rl/models/score_model/mistral/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 PKU-Alignment Team. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | from mcts_rl.models.score_model.mistral.modeling_mistral import MistralModelForScore 17 | -------------------------------------------------------------------------------- /mcts_rl/models/score_model/mistral/modeling_mistral.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 PKU-Alignment Team. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License.
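# ------------------------------------------------------------------------------
# Editor's note: every *ForScore wrapper in this package shares one pattern:
# run the base decoder, take the final hidden states of size (B, L, E), and
# hand them to ScoreModelMixin.get_score. Schematically (a sketch of the idea,
# not the mixin's exact code; `batch_indices`/`last_indices` are hypothetical
# names):
#
#     hidden_states = self.model(...)[0]                # size = (B, L, E)
#     scores = self.score_head(hidden_states)           # size = (B, L, D)
#     last_indices = attention_mask.sum(dim=-1) - 1     # last real token per row
#     batch_indices = torch.arange(scores.size(0))
#     end_scores = scores[batch_indices, last_indices]  # size = (B, D)
# ------------------------------------------------------------------------------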
14 | # ============================================================================== 15 | 16 | from __future__ import annotations 17 | 18 | from typing import Any, ClassVar 19 | 20 | import torch 21 | import torch.nn as nn 22 | from transformers import MistralModel, MistralPreTrainedModel, PreTrainedModel 23 | from transformers.configuration_utils import PretrainedConfig 24 | from transformers.models.mistral.modeling_mistral import _CONFIG_FOR_DOC, MISTRAL_INPUTS_DOCSTRING 25 | from transformers.utils.doc import add_start_docstrings_to_model_forward, replace_return_docstrings 26 | 27 | from mcts_rl.models.score_model import ScoreModelMixin, ScoreModelOutput 28 | 29 | 30 | class MistralModelForScore(ScoreModelMixin, MistralPreTrainedModel): 31 | _keys_to_ignore_on_load_missing: ClassVar[list[str]] = ['lm_head.weight'] 32 | 33 | def __init__(self, config: PretrainedConfig, **kwargs: Any) -> None: 34 | super().__init__(config) 35 | self.model = MistralModel(config) 36 | 37 | config.architectures = [self.__class__.__name__] 38 | self.init_score_head(config, hidden_size=config.hidden_size, **kwargs) 39 | 40 | # Initialize weights and apply final processing 41 | self.post_init() 42 | 43 | def get_input_embeddings(self) -> nn.Embedding: 44 | return self.model.embed_tokens 45 | 46 | def set_input_embeddings(self, value: nn.Embedding) -> None: 47 | self.model.embed_tokens = value 48 | 49 | def get_output_embeddings(self) -> None: 50 | return None 51 | 52 | def set_decoder(self, decoder: PreTrainedModel) -> None: 53 | self.model = decoder 54 | 55 | def get_decoder(self) -> PreTrainedModel: 56 | return self.model 57 | 58 | @add_start_docstrings_to_model_forward(MISTRAL_INPUTS_DOCSTRING) 59 | @replace_return_docstrings(output_type=ScoreModelOutput, config_class=_CONFIG_FOR_DOC) 60 | def forward( # pylint: disable=too-many-arguments 61 | self, 62 | input_ids: torch.LongTensor, 63 | attention_mask: torch.Tensor, 64 | position_ids: torch.LongTensor | None = None, 65 | past_key_values: list[torch.FloatTensor] | None = None, 66 | inputs_embeds: torch.FloatTensor | None = None, 67 | use_cache: bool | None = None, 68 | output_attentions: bool | None = None, 69 | output_hidden_states: bool | None = None, 70 | return_dict: bool | None = None, 71 | ) -> tuple[torch.Tensor, torch.Tensor] | ScoreModelOutput: 72 | """ 73 | Args: 74 | 75 | Returns: 76 | 77 | Examples: 78 | 79 | ```python 80 | >>> from mcts_rl.models.score_model.mistral.modeling_mistral import MistralModelForScore 81 | >>> from transformers import AutoTokenizer 82 | 83 | >>> model = MistralModelForScore.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) 84 | >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) 85 | 86 | >>> prompt = "Hey, are you conscious? Can you talk to me?"
87 | >>> inputs = tokenizer(prompt, return_tensors="pt") 88 | 89 | # get score 90 | >>> outputs = model(**inputs) 91 | >>> scores = outputs.scores 92 | >>> scores 93 | tensor([[[0.0000]]]) 94 | ``` 95 | """ 96 | assert attention_mask is not None 97 | output_attentions = ( 98 | output_attentions if output_attentions is not None else self.config.output_attentions 99 | ) 100 | output_hidden_states = ( 101 | output_hidden_states 102 | if output_hidden_states is not None 103 | else self.config.output_hidden_states 104 | ) 105 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 106 | 107 | outputs = self.model( 108 | input_ids=input_ids, 109 | attention_mask=attention_mask, 110 | position_ids=position_ids, 111 | past_key_values=past_key_values, 112 | inputs_embeds=inputs_embeds, 113 | use_cache=use_cache, 114 | output_attentions=output_attentions, 115 | output_hidden_states=output_hidden_states, 116 | return_dict=return_dict, 117 | ) 118 | hidden_states = outputs[0]  # size = (B, L, E) 119 | return self.get_score( 120 | hidden_states, 121 | attention_mask=attention_mask, 122 | return_dict=return_dict, 123 | ) 124 | -------------------------------------------------------------------------------- /mcts_rl/models/score_model/open_llama/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 PKU-Alignment Team. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | from mcts_rl.models.score_model.open_llama.modeling_open_llama import OpenLlamaForScore 17 | -------------------------------------------------------------------------------- /mcts_rl/models/score_model/open_llama/modeling_open_llama.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 PKU-Alignment Team. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License.
14 | # ============================================================================== 15 | 16 | from __future__ import annotations 17 | 18 | from typing import Any, ClassVar 19 | 20 | import torch 21 | import torch.nn as nn 22 | from transformers import OpenLlamaModel, OpenLlamaPreTrainedModel, PreTrainedModel 23 | from transformers.configuration_utils import PretrainedConfig 24 | from transformers.models.open_llama.modeling_open_llama import ( 25 | _CONFIG_FOR_DOC, 26 | OPEN_LLAMA_INPUTS_DOCSTRING, 27 | ) 28 | from transformers.utils.doc import add_start_docstrings_to_model_forward, replace_return_docstrings 29 | 30 | from mcts_rl.models.score_model import ScoreModelMixin, ScoreModelOutput 31 | 32 | 33 | class OpenLlamaForScore(ScoreModelMixin, OpenLlamaPreTrainedModel): 34 | _keys_to_ignore_on_load_missing: ClassVar[list[str]] = ['lm_head.weight'] 35 | 36 | def __init__(self, config: PretrainedConfig, **kwargs: Any) -> None: 37 | super().__init__(config) 38 | self.model = OpenLlamaModel(config) 39 | 40 | config.shared_input_output_embedding = False 41 | config.architectures = [self.__class__.__name__] 42 | self.init_score_head(config, hidden_size=config.hidden_size, **kwargs) 43 | 44 | # Initialize weights and apply final processing 45 | self.post_init() 46 | 47 | def get_input_embeddings(self) -> nn.Embedding: 48 | return self.model.embed_tokens 49 | 50 | def set_input_embeddings(self, value: nn.Embedding) -> None: 51 | self.model.embed_tokens = value 52 | 53 | def get_output_embeddings(self) -> None: 54 | return None 55 | 56 | def set_decoder(self, decoder: PreTrainedModel) -> None: 57 | self.model = decoder 58 | 59 | def get_decoder(self) -> PreTrainedModel: 60 | return self.model 61 | 62 | @add_start_docstrings_to_model_forward(OPEN_LLAMA_INPUTS_DOCSTRING) 63 | @replace_return_docstrings(output_type=ScoreModelOutput, config_class=_CONFIG_FOR_DOC) 64 | def forward( # pylint: disable=too-many-arguments 65 | self, 66 | input_ids: torch.LongTensor, 67 | attention_mask: torch.Tensor, 68 | position_ids: torch.LongTensor | None = None, 69 | past_key_values: list[torch.FloatTensor] | None = None, 70 | inputs_embeds: torch.FloatTensor | None = None, 71 | use_cache: bool | None = None, 72 | output_attentions: bool | None = None, 73 | output_hidden_states: bool | None = None, 74 | return_dict: bool | None = None, 75 | ) -> tuple[torch.Tensor, torch.Tensor] | ScoreModelOutput: 76 | """ 77 | Args: 78 | 79 | Returns: 80 | 81 | Examples: 82 | 83 | ```python 84 | >>> from mcts_rl.models.score_model.open_llama.modeling_open_llama import OpenLlamaForScore 85 | >>> from transformers import AutoTokenizer 86 | 87 | >>> model = OpenLlamaForScore.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) 88 | >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) 89 | 90 | >>> prompt = "Hey, are you conscious? Can you talk to me?"
91 | >>> inputs = tokenizer(prompt, return_tensors="pt") 92 | 93 | # get score 94 | >>> outputs = model(**inputs) 95 | >>> scores = outputs.scores 96 | >>> scores 97 | tensor([[[0.0000]]]) 98 | ``` 99 | """ 100 | assert attention_mask is not None 101 | output_attentions = ( 102 | output_attentions if output_attentions is not None else self.config.output_attentions 103 | ) 104 | output_hidden_states = ( 105 | output_hidden_states 106 | if output_hidden_states is not None 107 | else self.config.output_hidden_states 108 | ) 109 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 110 | 111 | # decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn) 112 | outputs = self.model( 113 | input_ids=input_ids, 114 | attention_mask=attention_mask, 115 | position_ids=position_ids, 116 | past_key_values=past_key_values, 117 | inputs_embeds=inputs_embeds, 118 | use_cache=use_cache, 119 | output_attentions=output_attentions, 120 | output_hidden_states=output_hidden_states, 121 | return_dict=return_dict, 122 | ) 123 | hidden_states = outputs[0]  # size = (B, L, E) 124 | return self.get_score( 125 | hidden_states, 126 | attention_mask=attention_mask, 127 | return_dict=return_dict, 128 | ) 129 | -------------------------------------------------------------------------------- /mcts_rl/models/score_model/opt/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 PKU-Alignment Team. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | from mcts_rl.models.score_model.opt.modeling_opt import OPTForScore 17 | -------------------------------------------------------------------------------- /mcts_rl/models/score_model/opt/modeling_opt.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 PKU-Alignment Team. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License.
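# ------------------------------------------------------------------------------
# Editor's note: unlike its siblings, the OPT wrapper below sizes its score
# head with `config.word_embed_proj_dim` instead of `config.hidden_size`,
# since OPT may project decoder outputs back down to a smaller word-embedding
# dimension; the hidden states reaching the score head have that width.
# ------------------------------------------------------------------------------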
14 | # ============================================================================== 15 | 16 | from __future__ import annotations 17 | 18 | from typing import Any, ClassVar 19 | 20 | import torch 21 | import torch.nn as nn 22 | from transformers import OPTModel, OPTPreTrainedModel, PreTrainedModel 23 | from transformers.configuration_utils import PretrainedConfig 24 | from transformers.models.opt.modeling_opt import _CONFIG_FOR_DOC 25 | from transformers.utils.doc import replace_return_docstrings 26 | 27 | from mcts_rl.models.score_model import ScoreModelMixin, ScoreModelOutput 28 | 29 | 30 | class OPTForScore(ScoreModelMixin, OPTPreTrainedModel): 31 | _keys_to_ignore_on_load_missing: ClassVar[list[str]] = ['lm_head.weight'] 32 | 33 | def __init__(self, config: PretrainedConfig, **kwargs: Any) -> None: 34 | super().__init__(config) 35 | self.model = OPTModel(config) 36 | 37 | config.architectures = [self.__class__.__name__] 38 | self.init_score_head(config, hidden_size=config.word_embed_proj_dim, **kwargs) 39 | 40 | # Initialize weights and apply final processing 41 | self.post_init() 42 | 43 | def get_input_embeddings(self) -> nn.Embedding: 44 | return self.model.decoder.embed_tokens 45 | 46 | def set_input_embeddings(self, value: nn.Embedding) -> None: 47 | self.model.decoder.embed_tokens = value 48 | 49 | def get_output_embeddings(self) -> None: 50 | return None 51 | 52 | def set_decoder(self, decoder: PreTrainedModel) -> None: 53 | self.model = decoder 54 | 55 | def get_decoder(self) -> PreTrainedModel: 56 | return self.model 57 | 58 | @replace_return_docstrings(output_type=ScoreModelOutput, config_class=_CONFIG_FOR_DOC) 59 | def forward( # pylint: disable=too-many-arguments 60 | self, 61 | input_ids: torch.LongTensor, 62 | attention_mask: torch.Tensor, 63 | head_mask: torch.Tensor | None = None, 64 | past_key_values: list[torch.FloatTensor] | None = None, 65 | inputs_embeds: torch.FloatTensor | None = None, 66 | use_cache: bool | None = None, 67 | output_attentions: bool | None = None, 68 | output_hidden_states: bool | None = None, 69 | return_dict: bool | None = None, 70 | ) -> tuple[torch.Tensor, torch.Tensor] | ScoreModelOutput: 71 | """ 72 | Args: 73 | 74 | Returns: 75 | 76 | Examples: 77 | 78 | ```python 79 | >>> from mcts_rl.models.score_model.opt.modeling_opt import OPTForScore 80 | >>> from transformers import AutoTokenizer 81 | 82 | >>> model = OPTForScore.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) 83 | >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) 84 | 85 | >>> prompt = "Hey, are you conscious? Can you talk to me?"
86 | >>> inputs = tokenizer(prompt, return_tensors="pt") 87 | 88 | # get score 89 | >>> outputs = model(**inputs) 90 | >>> scores = outputs.scores 91 | >>> scores 92 | tensor([[[0.0000]]]) 93 | ``` 94 | """ 95 | assert attention_mask is not None 96 | output_attentions = ( 97 | output_attentions if output_attentions is not None else self.config.output_attentions 98 | ) 99 | output_hidden_states = ( 100 | output_hidden_states 101 | if output_hidden_states is not None 102 | else self.config.output_hidden_states 103 | ) 104 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 105 | 106 | # decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn) 107 | outputs = self.model.decoder( 108 | input_ids=input_ids, 109 | attention_mask=attention_mask, 110 | head_mask=head_mask, 111 | past_key_values=past_key_values, 112 | inputs_embeds=inputs_embeds, 113 | use_cache=use_cache, 114 | output_attentions=output_attentions, 115 | output_hidden_states=output_hidden_states, 116 | return_dict=return_dict, 117 | ) 118 | hidden_states = outputs[0]  # size = (B, L, E) 119 | return self.get_score( 120 | hidden_states, 121 | attention_mask=attention_mask, 122 | return_dict=return_dict, 123 | ) 124 | -------------------------------------------------------------------------------- /mcts_rl/trainers/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 PKU-Alignment Team. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Trainer base classes.""" 16 | 17 | from mcts_rl.trainers.base import TrainerBase 18 | from mcts_rl.trainers.rl_trainer import RLTrainer 19 | from mcts_rl.trainers.tsrl_trainer import TSRLTrainer 20 | from mcts_rl.trainers.supervised_trainer import SupervisedTrainer 21 | 22 | 23 | __all__ = ['TrainerBase', 'RLTrainer', 'TSRLTrainer', 'SupervisedTrainer'] 24 | -------------------------------------------------------------------------------- /mcts_rl/trainers/base.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 PKU-Alignment Team. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License.
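# ------------------------------------------------------------------------------
# Editor's note: TrainerBase below is an abstract shell; the concrete trainers
# in this package (RLTrainer, TSRLTrainer, SupervisedTrainer) supply the hooks.
# A minimal hypothetical subclass, for illustration only:
#
#     class ToyTrainer(TrainerBase):
#         TRAINING_TYPE = 'toy'
#
#         def init_models(self) -> None: ...     # load model and tokenizer
#         def init_datasets(self) -> None: ...   # build train/eval dataloaders
#         def init_engines(self) -> None: ...    # wrap models in DeepSpeed engines
#         def train(self) -> None: ...           # the training loop
#         def set_train(self, mode: bool = True) -> None: ...  # toggle train/eval
# ------------------------------------------------------------------------------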
14 | # ============================================================================== 15 | """Trainer base class.""" 16 | 17 | from __future__ import annotations 18 | 19 | import abc 20 | import argparse 21 | import os 22 | import subprocess 23 | import sys 24 | from datetime import datetime 25 | from typing import Any, ClassVar 26 | 27 | import deepspeed 28 | import torch.distributed as dist 29 | from transformers import CONFIG_NAME, WEIGHTS_NAME, PreTrainedModel, PreTrainedTokenizerBase 30 | 31 | from mcts_rl.logger import Logger 32 | from mcts_rl.utils import is_main_process 33 | 34 | 35 | class TrainerBase(metaclass=abc.ABCMeta): 36 | """Trainer base class. 37 | 38 | Abstract methods: 39 | init_models: Initialize model and tokenizer. 40 | init_datasets: Initialize training and evaluation datasets. 41 | init_engines: Initialize DeepSpeed engines. 42 | train: Train model. 43 | set_train: Set training mode for all models. 44 | """ 45 | 46 | TRAINING_TYPE: ClassVar[str] 47 | 48 | tokenizer: PreTrainedTokenizerBase 49 | 50 | args: argparse.Namespace 51 | logger: Logger 52 | 53 | @abc.abstractmethod 54 | def init_models(self) -> None: 55 | """Initialize model and tokenizer.""" 56 | raise NotImplementedError 57 | 58 | @abc.abstractmethod 59 | def init_datasets(self) -> None: 60 | """Initialize training and evaluation datasets.""" 61 | raise NotImplementedError 62 | 63 | @abc.abstractmethod 64 | def init_engines(self) -> None: 65 | """Initialize DeepSpeed engines.""" 66 | raise NotImplementedError 67 | 68 | def init_logger(self) -> None: 69 | """Set logger.""" 70 | if self.args.log_type is None: 71 | self.logger = Logger(config=self.args) 72 | return 73 | 74 | time = datetime.now().strftime('%Y-%m-%d-%H-%M-%S') 75 | 76 | self.args.log_dir = self.args.log_dir or self.args.output_dir 77 | self.args.log_project = self.args.log_project or 'safe-rlhf' 78 | self.args.log_run_name = self.args.log_run_name or f'{self.TRAINING_TYPE}-{time}' 79 | 80 | self.logger = Logger( 81 | log_type=self.args.log_type, 82 | log_dir=self.args.log_dir, 83 | log_project=self.args.log_project, 84 | log_run_name=self.args.log_run_name, 85 | config=self.args, 86 | ) 87 | 88 | @abc.abstractmethod 89 | def train(self) -> None: 90 | """Train model.""" 91 | raise NotImplementedError 92 | 93 | def eval(self) -> dict[str, Any]: 94 | """Evaluate model.""" 95 | return {} 96 | 97 | @abc.abstractmethod 98 | def set_train(self, mode: bool = True) -> None: 99 | """Set training mode for all models.""" 100 | raise NotImplementedError 101 | 102 | def set_eval(self) -> None: 103 | """Set model to evaluation mode.""" 104 | self.set_train(mode=False) 105 | 106 | def save( 107 | self, 108 | model: deepspeed.DeepSpeedEngine | None = None, 109 | ds_config: dict[str, Any] | None = None, 110 | global_steps: int = -1, 111 | ) -> None: 112 | """Save model and tokenizer in Hugging Face format.""" 113 | dist.barrier() 114 | 115 | if model is None: 116 | model = self.model # pylint: disable=no-member 117 | if ds_config is None: 118 | ds_config = self.ds_config # pylint: disable=no-member 119 | 120 | output_dir = self.args.output_dir 121 | if global_steps > 0: 122 | output_dir = os.path.join(output_dir, f'steps{global_steps}') 123 | os.makedirs(output_dir, exist_ok=True) 124 | 125 | self.logger.print(f'Saving model to "{output_dir}" ...') 126 | 127 | output_config_file = os.path.join(output_dir, CONFIG_NAME) 128 | model_to_save: PreTrainedModel = getattr(model, 'module', model) 129 | if is_main_process(): 130 |
model_to_save.config.to_json_file(output_config_file) 131 | self.tokenizer.save_pretrained(output_dir) 132 | 133 | if self.args.save_16bit: 134 | self.logger.print('Saving 16-bit model...') 135 | model.save_16bit_model(output_dir) 136 | else: 137 | # Save model checkpoint 138 | if ds_config['zero_optimization']['stage'] >= 2: 139 | self.logger.print('Saving DeepSpeed Checkpoints...') 140 | model.save_checkpoint(output_dir) 141 | self.logger.print('Converting DeepSpeed Checkpoints to Hugging Face format...') 142 | if is_main_process(): 143 | subprocess.check_call( 144 | [sys.executable, 'zero_to_fp32.py', '.', WEIGHTS_NAME], # noqa: S603 145 | cwd=output_dir, 146 | ) 147 | dist.barrier() 148 | else: 149 | self.logger.print('Saving Hugging Face Checkpoints...') 150 | if is_main_process(): 151 | model_to_save.save_pretrained(output_dir, is_main_process=True) 152 | 153 | self.logger.print('Model saved!') 154 | -------------------------------------------------------------------------------- /mcts_rl/version.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 PKU-Alignment Team. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Safe-RLHF: Safe Reinforcement Learning with Human Feedback.""" 16 | 17 | __version__ = '0.0.1dev0' 18 | __license__ = 'Apache License, Version 2.0' 19 | __author__ = 'PKU-Alignment Team' 20 | __release__ = False 21 | 22 | if not __release__: 23 | import os 24 | import subprocess 25 | 26 | try: 27 | prefix, sep, suffix = ( 28 | subprocess.check_output( 29 | ['git', 'describe', '--abbrev=7'], # noqa: S603,S607 30 | cwd=os.path.dirname(os.path.abspath(__file__)), 31 | stderr=subprocess.DEVNULL, 32 | text=True, 33 | ) 34 | .strip() 35 | .lstrip('v') 36 | .replace('-', '.dev', 1) 37 | .replace('-', '+', 1) 38 | .partition('.dev') 39 | ) 40 | if sep: 41 | version_prefix, dot, version_tail = prefix.rpartition('.') 42 | prefix = f'{version_prefix}{dot}{int(version_tail) + 1}' 43 | __version__ = sep.join((prefix, suffix)) 44 | del version_prefix, dot, version_tail 45 | else: 46 | __version__ = prefix 47 | del prefix, sep, suffix 48 | except (OSError, subprocess.CalledProcessError): 49 | pass 50 | 51 | del os, subprocess 52 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch >= 1.13 2 | transformers >= 4.28 3 | datasets 4 | tokenizers >= 0.13.3 5 | accelerate 6 | pyproject-toml 7 | numpy 8 | scipy 9 | sentencepiece 10 | wandb 11 | tensorboard 12 | optree 13 | matplotlib 14 | tqdm 15 | rich 16 | nltk 17 | peft 18 | bert_score 19 | graphviz -------------------------------------------------------------------------------- /scripts/eval/mctseval_math.sh: -------------------------------------------------------------------------------- 1 | if [ -z 
"${BASH_VERSION}" ]; then 2 | echo "Please use bash to run this script." >&2 3 | exit 1 4 | fi 5 | 6 | set -x 7 | 8 | SCRIPT_DIR="$(cd "$(dirname "$0")" &>/dev/null && pwd)" 9 | ROOT_DIR="$(dirname "${SCRIPT_DIR}")" 10 | export PYTHONPATH="${ROOT_DIR}${PYTHONPATH:+:${PYTHONPATH}}" 11 | export LOGLEVEL="${LOGLEVEL:-WARNING}" 12 | 13 | MODEL_NAME="" 14 | STEP_NUM= 15 | ACTOR_MODEL_NAME_OR_PATH=MCTS-DPO/outputs/checkpoints/arithmetic/${MODEL_NAME}/steps${STEP_NUM} 16 | 17 | OUTPUT_DIR="MCTS-DPO/outputs/eval" 18 | unset HOSTFILE 19 | ZERO_STAGE=2 20 | OFFLOAD="all" 21 | 22 | if [[ -z "${REWARD_CRITIC_MODEL_NAME_OR_PATH+x}" ]]; then 23 | REWARD_CRITIC_MODEL_NAME_OR_PATH="${REWARD_MODEL_NAME_OR_PATH}" 24 | fi 25 | 26 | mkdir -p "${OUTPUT_DIR}" 27 | OUTPUT_DIR="$(cd "${OUTPUT_DIR}" &>/dev/null && pwd)" 28 | if [[ ! -f "${OUTPUT_DIR}/.gitignore" ]]; then 29 | echo '*' >"${OUTPUT_DIR}/.gitignore" 30 | fi 31 | 32 | cp -f "$0" "${OUTPUT_DIR}/script.sh" 33 | 34 | if [[ -z "${WANDB_API_KEY}" ]]; then 35 | export WANDB_MODE="offline" 36 | fi 37 | 38 | MASTER_PORT_START=10000 39 | MASTER_PORT_END=65535 40 | MASTER_PORT="$( 41 | comm -23 \ 42 | <(seq "${MASTER_PORT_START}" "${MASTER_PORT_END}" | sort) \ 43 | <(ss -Htan | awk '{ print $4 }' | awk -F ':' '{ print $NF }' | sort -u) | 44 | shuf | head -n 1 45 | )" 46 | 47 | DEEPSPEED_ARGS=() 48 | if [[ -n "${HOSTFILE+x}" ]]; then 49 | DEEPSPEED_ARGS+=("--hostfile" "${HOSTFILE}") 50 | fi 51 | DEEPSPEED_ARGS+=("--master_port" "${MASTER_PORT}") 52 | 53 | exec 1> >(tee "${OUTPUT_DIR}/stdout.log" >&1) 2> >(tee "${OUTPUT_DIR}/stderr.log" >&2) 54 | 55 | export WANDB_API_KEY="1396a7d2a29a8e8241dff6e0e6371f2ad61e11e2" 56 | export WANDB_MODE=dryrun 57 | 58 | export NCCL_DEBUG=INFO 59 | export NCCL_DEBUG_SUBSYS=INIT,P2P 60 | 61 | gpu_vis=$1 62 | 63 | deepspeed --include localhost:$gpu_vis --master_port $MASTER_PORT \ 64 | --module mcts_rl.algorithms.mcts \ 65 | --train_datasets GSM8K/train \ 66 | --eval_datasets MATH/test \ 67 | --actor_model_name_or_path "${ACTOR_MODEL_NAME_OR_PATH}" \ 68 | --actor_ref_model_name_or_path "${ACTOR_MODEL_NAME_OR_PATH}" \ 69 | --max_length 768 \ 70 | --repetition_penalty 1.0 \ 71 | --trust_remote_code True \ 72 | --epochs 1 \ 73 | --update_iters 1 \ 74 | --save_interval 128 \ 75 | --per_device_prompt_batch_size 1 \ 76 | --per_device_train_batch_size 1 \ 77 | --per_device_eval_batch_size 1 \ 78 | --gradient_accumulation_steps 8 \ 79 | --actor_lr 1e-6 \ 80 | --actor_weight_decay 0.01 \ 81 | --actor_lr_scheduler_type cosine \ 82 | --actor_lr_warmup_ratio 0.03 \ 83 | --actor_gradient_checkpointing \ 84 | --need_eval \ 85 | --seed 42 \ 86 | --kl_coeff 0.02 \ 87 | --clip_range_ratio 0.2 \ 88 | --clip_range_score 50.0 \ 89 | --clip_range_value 5.0 \ 90 | --output_dir "${OUTPUT_DIR}" \ 91 | --log_type wandb \ 92 | --log_project MCTS-DPO-EVAL \ 93 | --zero_stage "${ZERO_STAGE}" \ 94 | --offload "${OFFLOAD}" \ 95 | --bf16 True \ 96 | --tf32 True \ 97 | --force_terminating_on_depth_limit \ 98 | --max_new_tokens 64 \ 99 | --n_iters 5 \ 100 | --depth_limit 3 \ 101 | --n_init_actions 5 \ 102 | --n_actions 3 \ 103 | --mcts_temperature 0.0 \ 104 | --num_return_sequences 1 \ 105 | --temperature 1.0 \ 106 | --model_type mistral \ 107 | --prediction_file_path MCTS-DPO/outputs/checkpoints/arithmetic/predictions/mistral_${modelname}_math${stp}.jsonl 108 | -------------------------------------------------------------------------------- /scripts/eval/mctseval_sqa.sh: -------------------------------------------------------------------------------- 1 | 2 | if [ 
-z "${BASH_VERSION}" ]; then 3 | echo "Please use bash to run this script." >&2 4 | exit 1 5 | fi 6 | 7 | set -x 8 | 9 | SCRIPT_DIR="$(cd "$(dirname "$0")" &>/dev/null && pwd)" 10 | ROOT_DIR="$(dirname "${SCRIPT_DIR}")" 11 | export PYTHONPATH="${ROOT_DIR}${PYTHONPATH:+:${PYTHONPATH}}" 12 | export LOGLEVEL="${LOGLEVEL:-WARNING}" 13 | 14 | MODEL_NAME="" 15 | STEP_NUM= 16 | ACTOR_MODEL_NAME_OR_PATH=MCTS-DPO/outputs/checkpoints/sqa/${MODEL_NAME}/steps${STEP_NUM} 17 | 18 | OUTPUT_DIR="MCTS-DPO/outputs/eval" 19 | unset HOSTFILE 20 | ZERO_STAGE=2 21 | OFFLOAD="all" 22 | 23 | 24 | mkdir -p "${OUTPUT_DIR}" 25 | OUTPUT_DIR="$(cd "${OUTPUT_DIR}" &>/dev/null && pwd)" 26 | if [[ ! -f "${OUTPUT_DIR}/.gitignore" ]]; then 27 | echo '*' >"${OUTPUT_DIR}/.gitignore" 28 | fi 29 | 30 | cp -f "$0" "${OUTPUT_DIR}/script.sh" 31 | 32 | if [[ -z "${WANDB_API_KEY}" ]]; then 33 | export WANDB_MODE="offline" 34 | fi 35 | 36 | MASTER_PORT_START=10000 37 | MASTER_PORT_END=65535 38 | MASTER_PORT="$( 39 | comm -23 \ 40 | <(seq "${MASTER_PORT_START}" "${MASTER_PORT_END}" | sort) \ 41 | <(ss -Htan | awk '{ print $4 }' | awk -F ':' '{ print $NF }' | sort -u) | 42 | shuf | head -n 1 43 | )" 44 | 45 | DEEPSPEED_ARGS=() 46 | if [[ -n "${HOSTFILE+x}" ]]; then 47 | DEEPSPEED_ARGS+=("--hostfile" "${HOSTFILE}") 48 | fi 49 | DEEPSPEED_ARGS+=("--master_port" "${MASTER_PORT}") 50 | 51 | exec 1> >(tee "${OUTPUT_DIR}/stdout.log" >&1) 2> >(tee "${OUTPUT_DIR}/stderr.log" >&2) 52 | 53 | export NCCL_DEBUG=INFO 54 | export NCCL_DEBUG_SUBSYS=INIT,P2P 55 | 56 | 57 | gpu_vis=$1 58 | 59 | 60 | deepspeed --include localhost:$gpu_vis --master_port $MASTER_PORT \ 61 | --module mcts_rl.algorithms.mcts \ 62 | --train_datasets SQA/train \ 63 | --eval_datasets MCQ/test \ 64 | --actor_model_name_or_path "${ACTOR_MODEL_NAME_OR_PATH}" \ 65 | --actor_ref_model_name_or_path "${ACTOR_MODEL_NAME_OR_PATH}" \ 66 | --max_length 512 \ 67 | --repetition_penalty 1.0 \ 68 | --trust_remote_code True \ 69 | --epochs 1 \ 70 | --update_iters 1 \ 71 | --save_interval 128 \ 72 | --per_device_prompt_batch_size 1 \ 73 | --per_device_train_batch_size 1 \ 74 | --per_device_eval_batch_size 1 \ 75 | --gradient_accumulation_steps 8 \ 76 | --actor_lr 1e-6 \ 77 | --actor_weight_decay 0.01 \ 78 | --actor_lr_scheduler_type cosine \ 79 | --actor_lr_warmup_ratio 0.03 \ 80 | --actor_gradient_checkpointing \ 81 | --need_eval \ 82 | --seed 42 \ 83 | --kl_coeff 0.02 \ 84 | --clip_range_ratio 0.2 \ 85 | --clip_range_score 50.0 \ 86 | --clip_range_value 5.0 \ 87 | --output_dir "${OUTPUT_DIR}" \ 88 | --log_type wandb \ 89 | --log_project MCTS-DPO-EVAL \ 90 | --zero_stage "${ZERO_STAGE}" \ 91 | --offload "${OFFLOAD}" \ 92 | --bf16 True \ 93 | --tf32 True \ 94 | --max_new_tokens 32 \ 95 | --depth_limit 3 \ 96 | --n_init_actions 4 \ 97 | --n_actions 3 \ 98 | --n_iters 16 \ 99 | --mcts_temperature 0.0 \ 100 | --num_return_sequences 1 \ 101 | --temperature 1.0 \ 102 | --init_temperature 1.0 \ 103 | --model_type mistral \ 104 | --use_mcq \ 105 | --prediction_file_path MCTS-DPO/outputs/checkpoints/sqa/predictions/mistral_${MODEL_NAME}_${STEP_NUM}.jsonl 106 | -------------------------------------------------------------------------------- /scripts/mcts_csr.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | if [ -z "${BASH_VERSION}" ]; then 4 | echo "Please use bash to run this script." 
>&2 5 | exit 1 6 | fi 7 | 8 | set -x 9 | 10 | SCRIPT_DIR="$(cd "$(dirname "$0")" &>/dev/null && pwd)" 11 | ROOT_DIR="$(dirname "${SCRIPT_DIR}")" 12 | export PYTHONPATH="${ROOT_DIR}${PYTHONPATH:+:${PYTHONPATH}}" 13 | export LOGLEVEL="${LOGLEVEL:-WARNING}" 14 | 15 | 16 | ACTOR_MODEL_NAME_OR_PATH="upaya07/Arithmo2-Mistral-7B" 17 | ACTOR_REF_MODEL_NAME_OR_PATH="upaya07/Arithmo2-Mistral-7B" 18 | 19 | OUTPUT_DIR="MCTS-DPO/outputs/checkpoints/sqa/cdpo-4x2" 20 | unset HOSTFILE 21 | ZERO_STAGE=3 22 | OFFLOAD="optimizer" 23 | 24 | 25 | mkdir -p "${OUTPUT_DIR}" 26 | OUTPUT_DIR="$(cd "${OUTPUT_DIR}" &>/dev/null && pwd)" 27 | if [[ ! -f "${OUTPUT_DIR}/.gitignore" ]]; then 28 | echo '*' >"${OUTPUT_DIR}/.gitignore" 29 | fi 30 | 31 | cp -f "$0" "${OUTPUT_DIR}/script.sh" 32 | 33 | export WANDB_API_KEY="" 34 | export WANDB_MODE=online 35 | if [[ -z "${WANDB_API_KEY}" ]]; then 36 | export WANDB_MODE="offline" 37 | fi 38 | 39 | MASTER_PORT_START=10000 40 | MASTER_PORT_END=65535 41 | MASTER_PORT="$( 42 | comm -23 \ 43 | <(seq "${MASTER_PORT_START}" "${MASTER_PORT_END}" | sort) \ 44 | <(ss -Htan | awk '{ print $4 }' | awk -F ':' '{ print $NF }' | sort -u) | 45 | shuf | head -n 1 46 | )" 47 | 48 | DEEPSPEED_ARGS=() 49 | if [[ -n "${HOSTFILE+x}" ]]; then 50 | DEEPSPEED_ARGS+=("--hostfile" "${HOSTFILE}") 51 | fi 52 | DEEPSPEED_ARGS+=("--master_port" "${MASTER_PORT}") 53 | 54 | exec 1> >(tee "${OUTPUT_DIR}/stdout.log" >&1) 2> >(tee "${OUTPUT_DIR}/stderr.log" >&2) 55 | 56 | gpu_vis=$1 57 | 58 | deepspeed --include localhost:$gpu_vis --master_port $MASTER_PORT \ 59 | --module mcts_rl.algorithms.mcts \ 60 | --train_datasets SQA/train \ 61 | --model_type mistral \ 62 | --save_mcts_data \ 63 | --choose_worst \ 64 | --use_mcq \ 65 | --filter \ 66 | --conservative \ 67 | --actor_model_name_or_path "${ACTOR_MODEL_NAME_OR_PATH}" \ 68 | --actor_ref_model_name_or_path "${ACTOR_REF_MODEL_NAME_OR_PATH}" \ 69 | --scale_coeff 0.1 \ 70 | --max_length 512 \ 71 | --temperature 1.0 \ 72 | --init_temperature 1.0 \ 73 | --num_return_sequences 1 \ 74 | --repetition_penalty 1.0 \ 75 | --mcts_length_penalty 1.25 \ 76 | --trust_remote_code True \ 77 | --epochs 1 \ 78 | --update_iters 1 \ 79 | --save_interval 64 \ 80 | --per_device_ptx_batch_size 4 \ 81 | --per_device_prompt_batch_size 1 \ 82 | --per_device_train_batch_size 1 \ 83 | --gradient_accumulation_steps 64 \ 84 | --actor_lr 2e-6 \ 85 | --actor_weight_decay 0.05 \ 86 | --actor_lr_scheduler_type cosine \ 87 | --actor_lr_warmup_ratio 0.2 \ 88 | --actor_gradient_checkpointing \ 89 | --seed 42 \ 90 | --kl_coeff 0.02 \ 91 | --clip_range_ratio 0.2 \ 92 | --clip_range_score 50.0 \ 93 | --clip_range_value 5.0 \ 94 | --output_dir "${OUTPUT_DIR}" \ 95 | --log_type wandb \ 96 | --log_project MCTS-IPL-SQA \ 97 | --zero_stage "${ZERO_STAGE}" \ 98 | --offload "${OFFLOAD}" \ 99 | --bf16 True \ 100 | --tf32 True \ 101 | --max_new_tokens 32 \ 102 | --n_iters 64 \ 103 | --depth_limit 4 \ 104 | --n_init_actions 4 \ 105 | --n_actions 2 \ 106 | --mcts_temperature 0.0 107 | -------------------------------------------------------------------------------- /scripts/mcts_mathqa.sh: -------------------------------------------------------------------------------- 1 | if [ -z "${BASH_VERSION}" ]; then 2 | echo "Please use bash to run this script." 
>&2 3 | exit 1 4 | fi 5 | 6 | set -x 7 | 8 | SCRIPT_DIR="$(cd "$(dirname "$0")" &>/dev/null && pwd)" 9 | ROOT_DIR="$(dirname "${SCRIPT_DIR}")" 10 | export PYTHONPATH="${ROOT_DIR}${PYTHONPATH:+:${PYTHONPATH}}" 11 | export LOGLEVEL="${LOGLEVEL:-WARNING}" 12 | 13 | ACTOR_MODEL_NAME_OR_PATH="SFT-Arithmo" 14 | ACTOR_REF_MODEL_NAME_OR_PATH="SFT-Arithmo" 15 | 16 | OUTPUT_DIR="MCTS-DPO/outputs/checkpoints/arithmetic/cdpo-2x2-gtsft" 17 | unset HOSTFILE 18 | ZERO_STAGE=3 19 | OFFLOAD="optimizer" 20 | 21 | mkdir -p "${OUTPUT_DIR}" 22 | OUTPUT_DIR="$(cd "${OUTPUT_DIR}" &>/dev/null && pwd)" 23 | if [[ ! -f "${OUTPUT_DIR}/.gitignore" ]]; then 24 | echo '*' >"${OUTPUT_DIR}/.gitignore" 25 | fi 26 | 27 | cp -f "$0" "${OUTPUT_DIR}/script.sh" 28 | 29 | export WANDB_API_KEY="" 30 | export WANDB_MODE=online 31 | if [[ -z "${WANDB_API_KEY}" ]]; then 32 | export WANDB_MODE="offline" 33 | fi 34 | 35 | MASTER_PORT_START=10000 36 | MASTER_PORT_END=65535 37 | MASTER_PORT="$( 38 | comm -23 \ 39 | <(seq "${MASTER_PORT_START}" "${MASTER_PORT_END}" | sort) \ 40 | <(ss -Htan | awk '{ print $4 }' | awk -F ':' '{ print $NF }' | sort -u) | 41 | shuf | head -n 1 42 | )" 43 | 44 | DEEPSPEED_ARGS=() 45 | if [[ -n "${HOSTFILE+x}" ]]; then 46 | DEEPSPEED_ARGS+=("--hostfile" "${HOSTFILE}") 47 | fi 48 | DEEPSPEED_ARGS+=("--master_port" "${MASTER_PORT}") 49 | 50 | exec 1> >(tee "${OUTPUT_DIR}/stdout.log" >&1) 2> >(tee "${OUTPUT_DIR}/stderr.log" >&2) 51 | 52 | gpu_vis=$1 53 | 54 | deepspeed --include localhost:$gpu_vis --master_port $MASTER_PORT \ 55 | --module mcts_rl.algorithms.mcts \ 56 | --train_datasets MathQA/train \ 57 | --model_type mistral \ 58 | --choose_worst \ 59 | --save_mcts_data \ 60 | --filter \ 61 | --iteration_interval 64 \ 62 | --actor_model_name_or_path "${ACTOR_MODEL_NAME_OR_PATH}" \ 63 | --actor_ref_model_name_or_path "${ACTOR_REF_MODEL_NAME_OR_PATH}" \ 64 | --scale_coeff 0.1 \ 65 | --max_length 512 \ 66 | --temperature 1.0 \ 67 | --init_temperature 1.0 \ 68 | --mcts_length_penalty 1.25 \ 69 | --num_return_sequences 1 \ 70 | --repetition_penalty 1.0 \ 71 | --trust_remote_code True \ 72 | --epochs 1 \ 73 | --conservative \ 74 | --update_iters 1 \ 75 | --save_interval 64 \ 76 | --per_device_ptx_batch_size 4 \ 77 | --per_device_prompt_batch_size 1 \ 78 | --per_device_train_batch_size 1 \ 79 | --gradient_accumulation_steps 64 \ 80 | --actor_lr 1e-6 \ 81 | --actor_weight_decay 0.05 \ 82 | --actor_lr_scheduler_type cosine \ 83 | --actor_lr_warmup_ratio 0.03 \ 84 | --actor_gradient_checkpointing \ 85 | --seed 42 \ 86 | --kl_coeff 0.02 \ 87 | --clip_range_ratio 0.2 \ 88 | --clip_range_score 50.0 \ 89 | --clip_range_value 5.0 \ 90 | --ptx_coeff 0.0 \ 91 | --output_dir "${OUTPUT_DIR}" \ 92 | --log_type wandb \ 93 | --log_project MCTS-IPL-Math \ 94 | --zero_stage "${ZERO_STAGE}" \ 95 | --offload "${OFFLOAD}" \ 96 | --bf16 True \ 97 | --tf32 True \ 98 | --max_new_tokens 128 \ 99 | --n_iters 64 \ 100 | --depth_limit 3 \ 101 | --n_init_actions 2 \ 102 | --n_actions 2 \ 103 | --force_terminating_on_depth_limit \ 104 | --mcts_temperature 0.0 -------------------------------------------------------------------------------- /scripts/mcts_mathqa2.sh: -------------------------------------------------------------------------------- 1 | if [ -z "${BASH_VERSION}" ]; then 2 | echo "Please use bash to run this script." 
>&2 3 | exit 1 4 | fi 5 | 6 | set -x 7 | 8 | SCRIPT_DIR="$(cd "$(dirname "$0")" &>/dev/null && pwd)" 9 | ROOT_DIR="$(dirname "${SCRIPT_DIR}")" 10 | export PYTHONPATH="${ROOT_DIR}${PYTHONPATH:+:${PYTHONPATH}}" 11 | export LOGLEVEL="${LOGLEVEL:-WARNING}" 12 | 13 | ACTOR_MODEL_NAME_OR_PATH="SFT-Arithmo" 14 | ACTOR_REF_MODEL_NAME_OR_PATH="SFT-Arithmo" 15 | 16 | OUTPUT_DIR="MCTS-DPO/outputs/checkpoints/arithmetic/cdpo-2x2-nogt" 17 | unset HOSTFILE 18 | ZERO_STAGE=3 19 | OFFLOAD="optimizer" 20 | 21 | mkdir -p "${OUTPUT_DIR}" 22 | OUTPUT_DIR="$(cd "${OUTPUT_DIR}" &>/dev/null && pwd)" 23 | if [[ ! -f "${OUTPUT_DIR}/.gitignore" ]]; then 24 | echo '*' >"${OUTPUT_DIR}/.gitignore" 25 | fi 26 | 27 | cp -f "$0" "${OUTPUT_DIR}/script.sh" 28 | 29 | export WANDB_API_KEY="" 30 | export WANDB_MODE=online 31 | if [[ -z "${WANDB_API_KEY}" ]]; then 32 | export WANDB_MODE="offline" 33 | fi 34 | 35 | MASTER_PORT_START=10000 36 | MASTER_PORT_END=65535 37 | MASTER_PORT="$( 38 | comm -23 \ 39 | <(seq "${MASTER_PORT_START}" "${MASTER_PORT_END}" | sort) \ 40 | <(ss -Htan | awk '{ print $4 }' | awk -F ':' '{ print $NF }' | sort -u) | 41 | shuf | head -n 1 42 | )" 43 | 44 | DEEPSPEED_ARGS=() 45 | if [[ -n "${HOSTFILE+x}" ]]; then 46 | DEEPSPEED_ARGS+=("--hostfile" "${HOSTFILE}") 47 | fi 48 | DEEPSPEED_ARGS+=("--master_port" "${MASTER_PORT}") 49 | 50 | exec 1> >(tee "${OUTPUT_DIR}/stdout.log" >&1) 2> >(tee "${OUTPUT_DIR}/stderr.log" >&2) 51 | 52 | gpu_vis=$1 53 | 54 | deepspeed --include localhost:$gpu_vis --master_port $MASTER_PORT \ 55 | --module mcts_rl.algorithms.mcts \ 56 | --train_datasets MathQA/train \ 57 | --model_type mistral \ 58 | --choose_worst \ 59 | --save_mcts_data \ 60 | --filter \ 61 | --not_include_gt \ 62 | --iteration_interval 64 \ 63 | --actor_model_name_or_path "${ACTOR_MODEL_NAME_OR_PATH}" \ 64 | --actor_ref_model_name_or_path "${ACTOR_REF_MODEL_NAME_OR_PATH}" \ 65 | --scale_coeff 0.1 \ 66 | --max_length 512 \ 67 | --temperature 1.0 \ 68 | --init_temperature 1.0 \ 69 | --mcts_length_penalty 1.25 \ 70 | --num_return_sequences 1 \ 71 | --repetition_penalty 1.0 \ 72 | --trust_remote_code True \ 73 | --epochs 1 \ 74 | --conservative \ 75 | --update_iters 1 \ 76 | --save_interval 64 \ 77 | --per_device_ptx_batch_size 4 \ 78 | --per_device_prompt_batch_size 1 \ 79 | --per_device_train_batch_size 1 \ 80 | --gradient_accumulation_steps 64 \ 81 | --actor_lr 1e-6 \ 82 | --actor_weight_decay 0.05 \ 83 | --actor_lr_scheduler_type cosine \ 84 | --actor_lr_warmup_ratio 0.03 \ 85 | --actor_gradient_checkpointing \ 86 | --seed 42 \ 87 | --kl_coeff 0.02 \ 88 | --clip_range_ratio 0.2 \ 89 | --clip_range_score 50.0 \ 90 | --clip_range_value 5.0 \ 91 | --ptx_coeff 0.0 \ 92 | --output_dir "${OUTPUT_DIR}" \ 93 | --log_type wandb \ 94 | --log_project MCTS-IPL-Math \ 95 | --zero_stage "${ZERO_STAGE}" \ 96 | --offload "${OFFLOAD}" \ 97 | --bf16 True \ 98 | --tf32 True \ 99 | --max_new_tokens 128 \ 100 | --n_iters 64 \ 101 | --depth_limit 3 \ 102 | --n_init_actions 2 \ 103 | --n_actions 2 \ 104 | --force_terminating_on_depth_limit \ 105 | --mcts_temperature 0.0 -------------------------------------------------------------------------------- /scripts/mcts_mathqa_llama3.sh: -------------------------------------------------------------------------------- 1 | if [ -z "${BASH_VERSION}" ]; then 2 | echo "Please use bash to run this script." 
>&2 3 | exit 1 4 | fi 5 | 6 | set -x 7 | 8 | SCRIPT_DIR="$(cd "$(dirname "$0")" &>/dev/null && pwd)" 9 | ROOT_DIR="$(dirname "${SCRIPT_DIR}")" 10 | export PYTHONPATH="${ROOT_DIR}${PYTHONPATH:+:${PYTHONPATH}}" 11 | export LOGLEVEL="${LOGLEVEL:-WARNING}" 12 | 13 | ACTOR_MODEL_NAME_OR_PATH="meta-llama/Meta-Llama-3-8B-Instruct" 14 | ACTOR_REF_MODEL_NAME_OR_PATH="meta-llama/Meta-Llama-3-8B-Instruct" 15 | 16 | OUTPUT_DIR="MCTS-DPO/outputs/checkpoints/arithmetic/llama3-cdpo-2x2-gtsft" 17 | unset HOSTFILE 18 | ZERO_STAGE=3 19 | OFFLOAD="optimizer" 20 | 21 | 22 | mkdir -p "${OUTPUT_DIR}" 23 | OUTPUT_DIR="$(cd "${OUTPUT_DIR}" &>/dev/null && pwd)" 24 | if [[ ! -f "${OUTPUT_DIR}/.gitignore" ]]; then 25 | echo '*' >"${OUTPUT_DIR}/.gitignore" 26 | fi 27 | 28 | cp -f "$0" "${OUTPUT_DIR}/script.sh" 29 | 30 | export WANDB_API_KEY="" 31 | export WANDB_MODE=online 32 | if [[ -z "${WANDB_API_KEY}" ]]; then 33 | export WANDB_MODE="offline" 34 | fi 35 | 36 | MASTER_PORT_START=10000 37 | MASTER_PORT_END=65535 38 | MASTER_PORT="$( 39 | comm -23 \ 40 | <(seq "${MASTER_PORT_START}" "${MASTER_PORT_END}" | sort) \ 41 | <(ss -Htan | awk '{ print $4 }' | awk -F ':' '{ print $NF }' | sort -u) | 42 | shuf | head -n 1 43 | )" 44 | 45 | DEEPSPEED_ARGS=() 46 | if [[ -n "${HOSTFILE+x}" ]]; then 47 | DEEPSPEED_ARGS+=("--hostfile" "${HOSTFILE}") 48 | fi 49 | DEEPSPEED_ARGS+=("--master_port" "${MASTER_PORT}") 50 | 51 | exec 1> >(tee "${OUTPUT_DIR}/stdout.log" >&1) 2> >(tee "${OUTPUT_DIR}/stderr.log" >&2) 52 | 53 | gpu_vis=$1 54 | 55 | deepspeed --include localhost:$gpu_vis --master_port $MASTER_PORT \ 56 | --module mcts_rl.algorithms.mcts \ 57 | --train_datasets MathQA/train \ 58 | --model_type llama3 \ 59 | --choose_worst \ 60 | --save_mcts_data \ 61 | --filter \ 62 | --iteration_interval 64 \ 63 | --actor_model_name_or_path "${ACTOR_MODEL_NAME_OR_PATH}" \ 64 | --actor_ref_model_name_or_path "${ACTOR_REF_MODEL_NAME_OR_PATH}" \ 65 | --scale_coeff 0.1 \ 66 | --max_length 512 \ 67 | --temperature 1.0 \ 68 | --init_temperature 1.0 \ 69 | --mcts_length_penalty 1.25 \ 70 | --num_return_sequences 1 \ 71 | --repetition_penalty 1.0 \ 72 | --trust_remote_code True \ 73 | --epochs 1 \ 74 | --conservative \ 75 | --update_iters 1 \ 76 | --save_interval 64 \ 77 | --per_device_ptx_batch_size 4 \ 78 | --per_device_prompt_batch_size 1 \ 79 | --per_device_train_batch_size 1 \ 80 | --gradient_accumulation_steps 64 \ 81 | --actor_lr 1e-6 \ 82 | --actor_weight_decay 0.05 \ 83 | --actor_lr_scheduler_type cosine \ 84 | --actor_lr_warmup_ratio 0.03 \ 85 | --actor_gradient_checkpointing \ 86 | --seed 42 \ 87 | --kl_coeff 0.02 \ 88 | --clip_range_ratio 0.2 \ 89 | --clip_range_score 50.0 \ 90 | --clip_range_value 5.0 \ 91 | --ptx_coeff 0.0 \ 92 | --output_dir "${OUTPUT_DIR}" \ 93 | --log_type wandb \ 94 | --log_project MCTS-IPL-Math \ 95 | --zero_stage "${ZERO_STAGE}" \ 96 | --offload "${OFFLOAD}" \ 97 | --bf16 True \ 98 | --tf32 True \ 99 | --max_new_tokens 128 \ 100 | --n_iters 64 \ 101 | --depth_limit 3 \ 102 | --n_init_actions 2 \ 103 | --n_actions 2 \ 104 | --force_terminating_on_depth_limit \ 105 | --mcts_temperature 0.0 --------------------------------------------------------------------------------
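Editor's note: each launch script above selects a random free master port by diffing the full port range against the listening sockets reported by `ss` (the comm/seq/shuf pipeline). A Python equivalent of that idea, offered only as an illustrative sketch and not part of the repository, is:

    import socket

    def find_free_master_port() -> int:
        # Bind to port 0 so the OS assigns an unused TCP port, then reuse that
        # number as the DeepSpeed --master_port. Simpler than scanning a range,
        # though the port could in principle be taken again between this call
        # and the actual launch.
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
            sock.bind(('', 0))
            return sock.getsockname()[1]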