├── .gitignore
├── LICENSE
├── README.md
├── conda-recipe.yaml
├── framework-colorblindfriendly.jpg
├── mcts_rl
│   ├── __init__.py
│   ├── algorithms
│   │   ├── __init__.py
│   │   ├── dpo
│   │   │   ├── __init__.py
│   │   │   ├── __main__.py
│   │   │   ├── main.py
│   │   │   └── trainer.py
│   │   ├── mcts
│   │   │   ├── __init__.py
│   │   │   ├── __main__.py
│   │   │   ├── main.py
│   │   │   ├── mcts
│   │   │   │   ├── __init__.py
│   │   │   │   ├── base.py
│   │   │   │   ├── embedding_retrieve.py
│   │   │   │   ├── mcts.py
│   │   │   │   ├── search_config.py
│   │   │   │   └── world_model.py
│   │   │   └── trainer.py
│   │   └── ppo
│   │       ├── __init__.py
│   │       ├── __main__.py
│   │       ├── main.py
│   │       └── trainer.py
│   ├── configs
│   │   ├── __init__.py
│   │   ├── constants.py
│   │   ├── deepspeed_config.py
│   │   ├── ds_eval_config_template.json
│   │   ├── ds_train_config_template.json
│   │   └── fsdp_config.json
│   ├── datasets
│   │   ├── __init__.py
│   │   ├── base.py
│   │   ├── preference.py
│   │   ├── prompt_only.py
│   │   ├── raw
│   │   │   ├── __init__.py
│   │   │   ├── alpaca.py
│   │   │   ├── aqua.py
│   │   │   ├── arithmo.py
│   │   │   ├── exam.py
│   │   │   ├── firefly.py
│   │   │   ├── gsm8k.py
│   │   │   ├── hh_rlhf.py
│   │   │   ├── math.py
│   │   │   ├── math_qa.py
│   │   │   ├── mcq.py
│   │   │   ├── mcq_for_eval.py
│   │   │   ├── mcq_pairs.py
│   │   │   ├── moss.py
│   │   │   ├── prm800k.py
│   │   │   ├── qa_feedback.py
│   │   │   └── safe_rlhf.py
│   │   ├── safety_preference.py
│   │   ├── supervised.py
│   │   └── utils.py
│   ├── finetune
│   │   ├── __init__.py
│   │   ├── __main__.py
│   │   ├── deepspeed.py
│   │   ├── huggingface.py
│   │   ├── main.py
│   │   └── trainer.py
│   ├── logger.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── normalizer.py
│   │   ├── pretrained.py
│   │   └── score_model
│   │       ├── __init__.py
│   │       ├── bloom
│   │       │   ├── __init__.py
│   │       │   └── modeling_bloom.py
│   │       ├── gpt2
│   │       │   ├── __init__.py
│   │       │   └── modeling_gpt2.py
│   │       ├── gpt_neo
│   │       │   ├── __init__.py
│   │       │   └── modeling_gpt_neo.py
│   │       ├── gpt_neox
│   │       │   ├── __init__.py
│   │       │   └── modeling_gpt_neox.py
│   │       ├── gptj
│   │       │   ├── __init__.py
│   │       │   └── modeling_gptj.py
│   │       ├── llama
│   │       │   ├── __init__.py
│   │       │   └── modeling_llama.py
│   │       ├── mistral
│   │       │   ├── __init__.py
│   │       │   └── modeling_mistral.py
│   │       ├── open_llama
│   │       │   ├── __init__.py
│   │       │   └── modeling_open_llama.py
│   │       └── opt
│   │           ├── __init__.py
│   │           └── modeling_opt.py
│   ├── trainers
│   │   ├── __init__.py
│   │   ├── base.py
│   │   ├── rl_trainer.py
│   │   ├── supervised_trainer.py
│   │   └── tsrl_trainer.py
│   ├── utils.py
│   └── version.py
├── requirements.txt
├── scripts
│   ├── eval
│   │   ├── mctseval_math.sh
│   │   └── mctseval_sqa.sh
│   ├── mcts_csr.sh
│   ├── mcts_mathqa.sh
│   ├── mcts_mathqa2.sh
│   └── mcts_mathqa_llama3.sh
└── visualize.ipynb
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 |
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/#use-with-ide
110 | .pdm.toml
111 |
112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113 | __pypackages__/
114 |
115 | # Celery stuff
116 | celerybeat-schedule
117 | celerybeat.pid
118 |
119 | # SageMath parsed files
120 | *.sage.py
121 |
122 | # Environments
123 | .env
124 | .venv
125 | env/
126 | venv/
127 | ENV/
128 | env.bak/
129 | venv.bak/
130 |
131 | # Spyder project settings
132 | .spyderproject
133 | .spyproject
134 |
135 | # Rope project settings
136 | .ropeproject
137 |
138 | # mkdocs documentation
139 | /site
140 |
141 | # mypy
142 | .mypy_cache/
143 | .dmypy.json
144 | dmypy.json
145 |
146 | # Pyre type checker
147 | .pyre/
148 |
149 | # pytype static type analyzer
150 | .pytype/
151 |
152 | # Cython debug symbols
153 | cython_debug/
154 |
155 | # PyCharm
156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158 | # and can be added to the global gitignore or merged into this file. For a more nuclear
159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160 | #.idea/
161 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Monte Carlo Tree Search Boosts Reasoning via Iterative Preference Learning
2 |
3 | This repository contains code and analysis for the paper: [Monte Carlo Tree Search Boosts Reasoning via Iterative Preference Learning](https://arxiv.org/abs/2405.00451).
4 | Below is the framework of our proposed method.
5 |
6 | 
7 |
8 | #### Environment Setup
9 |
10 | ```sh
11 | conda env create --file conda-recipe.yaml
12 | pip install -r requirements.txt
13 | ```
14 |
15 | #### Dataset Download
16 |
17 | * Arithmo: [akjindal53244/Arithmo-Data](https://huggingface.co/datasets/akjindal53244/Arithmo-Data)
18 |
19 | * GSM8K: [openai/grade-school-math](https://github.com/openai/grade-school-math/tree/master/grade_school_math/data)
20 |
21 | * MATH: [hendrycks/math](https://github.com/hendrycks/math/)
22 |
23 | * ARC: [AI2 Reasoning Challenge](https://paperswithcode.com/dataset/arc)
24 |
25 | * AI2S: [AI2 Science Questions](http://data.allenai.org/ai2-science-questions)
26 |
27 | * OBQA: [Openbook QA](https://allenai.org/data/open-book-qa)
28 |
29 | * CSQA: [tau/commonsense_qa](https://huggingface.co/datasets/tau/commonsense_qa)
30 |
31 | * SciQ: [SciQ Dataset](https://allenai.org/data/sciq)
32 |
33 |
34 | #### Run MCTS-DPO
35 |
36 | Our main code is in `./mcts_rl/algorithms/mcts` and `./mcts_rl/trainers/tsrl_trainer.py`.
37 |
38 | To run MCTS-DPO for MathQA on Mistral (SFT):
39 | ```sh
40 | bash scripts/mcts_mathqa.sh
41 | ```
42 |
43 | To run MCTS-DPO for CSR on Mistral (SFT):
44 | ```sh
45 | bash scripts/mcts_csr.sh
46 | ```
47 |
48 | ## Citation
49 |
50 | ```
51 | @article{xie2024monte,
52 | title={Monte Carlo Tree Search Boosts Reasoning via Iterative Preference Learning},
53 | author={Xie, Yuxi and Goyal, Anirudh and Zheng, Wenyue and Kan, Min-Yen and Lillicrap, Timothy P and Kawaguchi, Kenji and Shieh, Michael},
54 | journal={arXiv preprint arXiv:2405.00451},
55 | year={2024}
56 | }
57 | ```
58 |
59 | ---
60 | This repository is adapted from the codebase of [Safe-RLHF](https://github.com/PKU-Alignment/safe-rlhf).
61 |
--------------------------------------------------------------------------------
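
A minimal sketch of fetching the HuggingFace-hosted sets listed in the README with `datasets.load_dataset`, which mirrors what the loaders under `mcts_rl/datasets/raw` do; the GitHub/AllenAI-hosted sets must be downloaded manually into the local data folder (its `DATA_DIR` is left as a `path_to_dataset_folder` placeholder in the loader modules):

```python
from datasets import load_dataset

# Hub-hosted datasets referenced in the README and in mcts_rl/datasets/raw
arithmo = load_dataset('akjindal53244/Arithmo-Data', split='train')
csqa = load_dataset('tau/commonsense_qa', split='train')
print(len(arithmo), len(csqa))
```
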
/conda-recipe.yaml:
--------------------------------------------------------------------------------
1 | #
2 | # Create virtual environment with command:
3 | #
4 | # $ CONDA_OVERRIDE_CUDA=11.7 conda env create --file conda-recipe.yaml
5 | #
6 |
7 | name: mcts-dpo
8 | channels:
9 | - huggingface
10 | - pytorch
11 | - nvidia/label/cuda-11.7.1
12 | - defaults
13 | - conda-forge
14 | dependencies:
15 | - python = 3.10
16 | - pip
17 |
18 | - pytorch::pytorch >= 1.13
19 | - pytorch::pytorch-mutex =*=*cuda*
20 | - transformers >= 4.29.0
21 | - datasets
22 | - tokenizers >= 0.13.3
23 | - sentencepiece
24 | - tensorboard
25 | - wandb
26 | - pip:
27 | - accelerate
28 |
29 | - nvidia/label/cuda-11.7.1::cuda-toolkit = 11.7
30 |
31 | - optree
32 | - scipy
33 | - nvitop
34 | - matplotlib-base
35 | - rich
36 | - tqdm
37 | - typing-extensions
38 | - ipdb
39 | - jsonlines
40 | - func_timeout
41 |
--------------------------------------------------------------------------------
/framework-colorblindfriendly.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YuxiXie/MCTS-DPO/fb23437f502189b21e0824e7f6ad9c822cc8e091/framework-colorblindfriendly.jpg
--------------------------------------------------------------------------------
/mcts_rl/__init__.py:
--------------------------------------------------------------------------------
1 | from mcts_rl import algorithms, configs, datasets, models, trainers, utils
2 | from mcts_rl.algorithms import * # noqa: F403
3 | from mcts_rl.configs import * # noqa: F403
4 | from mcts_rl.datasets import * # noqa: F403
5 | from mcts_rl.models import * # noqa: F403
6 | from mcts_rl.trainers import * # noqa: F403
7 | from mcts_rl.utils import * # noqa: F403
8 | from mcts_rl.version import __version__
9 |
10 |
11 | __all__ = [
12 | *algorithms.__all__,
13 | *configs.__all__,
14 | *datasets.__all__,
15 | *models.__all__,
16 | *trainers.__all__,
17 | *utils.__all__,
18 | ]
19 |
--------------------------------------------------------------------------------
/mcts_rl/algorithms/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 PKU-Alignment Team. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """RL algorithms for RLHF."""
16 |
17 | from mcts_rl.algorithms.dpo import DPOTrainer
18 | from mcts_rl.algorithms.ppo import PPOTrainer
19 | from mcts_rl.algorithms.mcts import MCTSTrainer
20 |
21 |
22 | __all__ = ['PPOTrainer', 'DPOTrainer', 'MCTSTrainer']
23 |
--------------------------------------------------------------------------------
/mcts_rl/algorithms/dpo/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 PKU-Alignment Team. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """The Direct Preference Optimization (DPO) algorithm."""
16 |
17 | from mcts_rl.algorithms.dpo.trainer import DPOTrainer
18 |
19 |
20 | __all__ = ['DPOTrainer']
21 |
--------------------------------------------------------------------------------
/mcts_rl/algorithms/dpo/__main__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 PKU-Alignment Team. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """The main training script to run the DPO algorithm."""
16 |
17 | import sys
18 |
19 | from mcts_rl.algorithms.dpo.main import main
20 |
21 |
22 | if __name__ == '__main__':
23 | sys.exit(main())
24 |
--------------------------------------------------------------------------------
/mcts_rl/algorithms/mcts/__init__.py:
--------------------------------------------------------------------------------
1 | """RL with MCTS algorithm."""
2 |
3 | from mcts_rl.algorithms.mcts.trainer import MCTSTrainer
4 |
5 |
6 | __all__ = ['MCTSTrainer']
7 |
--------------------------------------------------------------------------------
/mcts_rl/algorithms/mcts/__main__.py:
--------------------------------------------------------------------------------
1 | """The main training script to train RLHF using PPO algorithm."""
2 |
3 | import sys
4 |
5 | from mcts_rl.algorithms.mcts.main import main
6 |
7 |
8 | if __name__ == '__main__':
9 | sys.exit(main())
10 |
--------------------------------------------------------------------------------
/mcts_rl/algorithms/mcts/mcts/__init__.py:
--------------------------------------------------------------------------------
1 | from mcts_rl.algorithms.mcts.mcts.base import TreeConstructor
2 | from mcts_rl.algorithms.mcts.mcts.mcts import MCTS, MCTSNode, MCTSResult, MCTSConfig
3 | from mcts_rl.algorithms.mcts.mcts.world_model import StepLMWorldModel, LMExample
4 | from mcts_rl.algorithms.mcts.mcts.search_config import StepLMConfig, SearchArgs
--------------------------------------------------------------------------------
/mcts_rl/algorithms/mcts/mcts/base.py:
--------------------------------------------------------------------------------
1 | from typing import Generic, TypeVar, Protocol
2 | from abc import ABC, abstractmethod
3 |
4 | State = TypeVar("State")
5 | Action = TypeVar("Action")
6 | Example = TypeVar("Example")
7 | Trace = tuple[list[State], list[Action]]
8 | Args = TypeVar("Args")
9 |
10 |
11 | class WorldModel(ABC, Generic[State, Action, Example]):
12 | def __init__(self) -> None:
13 | self.example = None
14 |
15 | @abstractmethod
16 | def init_state(self) -> State: ...
17 |
18 | @abstractmethod
19 | def step(self, state: State, action: Action) -> State: ...
20 |
21 | @abstractmethod
22 | def is_terminal(self, state: State) -> bool: ...
23 |
24 | def update_example(self, example: Example) -> None:
25 | self.example = example
26 |
27 |
28 | class SearchConfig(ABC, Generic[State, Action, Example]):
29 | def __init__(self) -> None:
30 | self.example = None
31 |
32 | @abstractmethod
33 | def get_actions(self, state: State) -> list[Action]: ...
34 |
35 | @abstractmethod
36 |     def reward(self, state: State, action: Action, **kwargs) -> tuple[float, dict]: ...
37 |
38 | @abstractmethod
39 | def get_values(self, state: State, action: Action) -> list[tuple[float, bool]]: ...
40 |
41 | def update_example(self, example: Example) -> None:
42 | self.example = example
43 |
44 |
45 | class HasTerminalStateAndTrace(Protocol[State]):
46 | terminal_state: State
47 | trace: Trace
48 |
49 |
50 | class SearchAlgorithm(ABC):
51 | def __init__(self, **kwargs): ...
52 |
53 | @abstractmethod
54 | def __call__(self, world_model: WorldModel, search_config: SearchConfig, **kwargs) -> HasTerminalStateAndTrace: ...
55 |
56 |
57 | class TreeConstructor(ABC, Generic[State, Action, Example]):
58 | def __init__(self,
59 | world_model: WorldModel[State, Action, Example],
60 | search_config: SearchConfig[State, Action, Example],
61 | search_algo: SearchAlgorithm) -> None:
62 | self.world_model = world_model
63 | self.search_config = search_config
64 | self.search_algo = search_algo
65 |
66 | def __call__(self, example: Example, node=None, **kwargs) -> HasTerminalStateAndTrace[State]:
67 | self.world_model.update_example(example)
68 | self.search_config.update_example(example)
69 | return self.search_algo(self.world_model,
70 | self.search_config,
71 | root_node=node,
72 | **kwargs)
73 |
--------------------------------------------------------------------------------
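
To see how these abstractions compose, here is a toy example (not part of the repository): a counting world model whose state is the list of actions taken so far, wired through `TreeConstructor` with a greedy stand-in for the real MCTS search algorithm.

```python
from mcts_rl.algorithms.mcts.mcts.base import (
    SearchAlgorithm, SearchConfig, TreeConstructor, WorldModel,
)


class CountWorldModel(WorldModel[list, int, int]):
    """Toy world model: the state is the list of actions taken so far."""

    def init_state(self) -> list:
        return []

    def step(self, state: list, action: int) -> list:
        return state + [action]

    def is_terminal(self, state: list) -> bool:
        return sum(state) >= self.example  # `example` is the target sum


class CountSearchConfig(SearchConfig[list, int, int]):
    def get_actions(self, state: list) -> list[int]:
        return [1, 2]

    def reward(self, state, action, **kwargs) -> tuple[float, dict]:
        return float(action), {}

    def get_values(self, state, action) -> list[tuple[float, bool]]:
        return [(float(action), sum(state) + action >= self.example)]


class GreedySearch(SearchAlgorithm):
    """Stand-in for MCTS: always takes the first available action."""

    def __call__(self, world_model, search_config, **kwargs):
        state = world_model.init_state()
        while not world_model.is_terminal(state):
            state = world_model.step(state, search_config.get_actions(state)[0])
        return state


searcher = TreeConstructor(CountWorldModel(), CountSearchConfig(), GreedySearch())
print(searcher(example=3))  # [1, 1, 1]
```
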
/mcts_rl/algorithms/mcts/mcts/embedding_retrieve.py:
--------------------------------------------------------------------------------
1 | import math
2 | import pickle
3 | import torch
4 | import torch.distributed as dist
5 | from bert_score.utils import greedy_cos_idf
6 | from torch.nn.utils.rnn import pad_sequence
7 |
8 |
9 | def get_embs_masks(seq_ids: list[list[int]], seq_embs: list[torch.Tensor],
10 | special_ids: list = [1, 2, 32000]):
11 | seq_ids = [ids[:len(embs)] for ids, embs in zip(seq_ids, seq_embs)]
12 | seq_embs = pad_sequence(seq_embs, batch_first=True, padding_value=2.0)
13 | lengths = torch.tensor([len(seq) for seq in seq_ids], dtype=torch.long)
14 | seq_masks = torch.arange(max(lengths), dtype=torch.long).expand(len(lengths), max(lengths))
15 | seq_masks = seq_masks < lengths.unsqueeze(1)
16 | seq_idfs = pad_sequence([
17 | torch.tensor([float(_id not in special_ids) for _id in seq])
18 | for seq in seq_ids], batch_first=True, padding_value=0.0)
19 | return seq_embs, seq_masks.to(seq_embs.device), seq_idfs.to(seq_embs.device)
20 |
21 |
22 | def filter_with_similarity(candidates, tokenizer):
23 | sequences = [tokenizer.decode(x[0]) for x in candidates]
24 | special_ids = [tokenizer.bos_token_id, tokenizer.eos_token_id, tokenizer.pad_token_id]
25 |
26 | seq_ids, seq_embs = [c[0][:len(c[2])] for c in candidates], [c[2] for c in candidates]
27 | seq_embs, seq_masks, seq_idfs = get_embs_masks(seq_ids, seq_embs, special_ids=special_ids)
28 |
29 | scores = {'P': {}, 'R': {}, 'F': {}}
30 | for i, rst in enumerate(candidates):
31 |         P, R, F = check_match(seq_embs, seq_masks, seq_idfs,
32 |                               rst[0][:len(rst[2])], rst[2],
33 |                               special_ids=special_ids)
34 | scores['P'][i], scores['R'][i], scores['F'][i] = P, R, F
35 | max_f = max(x.item() for xid, x in enumerate(F) if xid != i)
36 | if max_f > .9:
37 | import ipdb; ipdb.set_trace()
38 | F_scores = list(dict(sorted(scores['F'].items(), key=lambda x: x[0])).values())
39 |
40 | return F_scores
41 |
42 |
43 | def check_match(keys_embs: torch.Tensor, keys_masks: torch.BoolTensor, keys_idfs: torch.Tensor,
44 | query: list[int], query_emb: torch.Tensor,
45 | special_ids: list = [1, 2, 32000]):
46 | query_embs = torch.stack([query_emb for _ in keys_embs], dim=0)
47 | query_masks = torch.ones(query_embs.size()[:-1]).bool().to(query_embs.device)
48 | query_idfs = torch.stack([
49 | torch.tensor([float(_id not in special_ids) for _id in query[:query_emb.size(0)]])
50 | for _ in keys_idfs], dim=0).to(query_embs.device)
51 |
52 | P, R, F = greedy_cos_idf(keys_embs.float(), keys_masks, keys_idfs,
53 | query_embs.float(), query_masks, query_idfs)
54 | return P, R, F
--------------------------------------------------------------------------------
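
A small sketch of the padding helper above, with random tensors standing in for real token embeddings and `2` playing the role of a special (EOS-like) id:

```python
import torch
from mcts_rl.algorithms.mcts.mcts.embedding_retrieve import get_embs_masks

# two "sequences" of token ids with one embedding vector per token (dim 4)
seq_ids = [[5, 6, 2], [7, 8]]
seq_embs = [torch.randn(3, 4), torch.randn(2, 4)]

embs, masks, idfs = get_embs_masks(seq_ids, seq_embs, special_ids=[2])
print(embs.shape)  # torch.Size([2, 3, 4]) -- padded to the longest sequence
print(masks)       # [[True, True, True], [True, True, False]]
print(idfs)        # special id 2 and padded positions get idf weight 0.0
```
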
/mcts_rl/algorithms/mcts/mcts/world_model.py:
--------------------------------------------------------------------------------
1 | # Adapted from: https://github.com/maitrix-org/llm-reasoners/blob/main/examples/RAP/gsm8k/world_model.py
2 |
3 | from typing import NamedTuple, TypedDict
4 |
5 | import torch
6 | from transformers import GenerationConfig, PreTrainedTokenizerBase
7 |
8 | from mcts_rl.algorithms.mcts.mcts.base import WorldModel
9 |
10 |
11 | class StepSubResult(NamedTuple):
12 | next_step_ids: torch.LongTensor
13 | log_probs: torch.Tensor
14 |
15 |
16 | StepLMState = list[StepSubResult]
17 | StepLMAction = torch.LongTensor
18 |
19 |
20 | class WorldModelArgs(NamedTuple):
21 | base_tokenizer: PreTrainedTokenizerBase
22 | generation_config: GenerationConfig
23 | stop_tokens: list[str] = []
24 |
25 |
26 | class LMExample(TypedDict):
27 | input_ids: torch.LongTensor # (L,)
28 | attention_mask: torch.BoolTensor # (L,)
29 |
30 |
31 | class StepLMWorldModel(WorldModel[StepLMState, StepLMAction, LMExample]):
32 | def __init__(self,
33 | max_length: int,
34 | base_tokenizer: PreTrainedTokenizerBase,
35 | generation_config: GenerationConfig,
36 | stop_tokens=[]) -> None:
37 | super().__init__()
38 | self.base_tokenizer = base_tokenizer
39 | self.generation_config = generation_config
40 | self.max_tokens_num = max_length
41 | self.stop_tokens = list(set(
42 | stop_tokens + [self.base_tokenizer.decode([self.generation_config.eos_token_id])]
43 | ))
44 |
45 | def init_state(self) -> list:
46 | return []
47 |
48 | def step(self, state: StepLMState, action: StepLMAction, log_probs: torch.Tensor) -> StepLMState:
49 | state = state.copy()
50 | state.append(StepSubResult(action, log_probs))
51 | return state
52 |
53 | def is_terminal(self, state: StepLMState) -> bool:
54 | input_length = self.example['attention_mask'].nonzero()[-1].item() + 1
55 | sum_tokens_num = sum(x.next_step_ids.size(0) for x in state) + input_length
56 |
57 | if sum_tokens_num >= self.max_tokens_num - 5:
58 | return True
59 | elif state[-1].next_step_ids.eq(self.base_tokenizer.eos_token_id).sum():
60 | return True
61 | elif state[-1].next_step_ids.eq(self.base_tokenizer.convert_tokens_to_ids("<|eot_id|>")).sum():
62 | return True
63 | elif self.base_tokenizer.decode(state[-1].next_step_ids).count('QUESTION: '):
64 | return True
65 | else:
66 | return False
67 |
--------------------------------------------------------------------------------
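
A usage sketch of the step/terminal contract; the GPT-2 tokenizer here is purely illustrative (the repository itself targets Mistral/LLaMA-family models):

```python
import torch
from transformers import AutoTokenizer, GenerationConfig
from mcts_rl.algorithms.mcts.mcts.world_model import StepLMWorldModel

tokenizer = AutoTokenizer.from_pretrained('gpt2')  # illustration only
world_model = StepLMWorldModel(
    max_length=64,
    base_tokenizer=tokenizer,
    generation_config=GenerationConfig(eos_token_id=tokenizer.eos_token_id),
)
prompt = tokenizer('Q: 1 + 1 = ?', return_tensors='pt')
world_model.update_example({
    'input_ids': prompt['input_ids'][0],
    'attention_mask': prompt['attention_mask'][0].bool(),
})

state = world_model.init_state()
state = world_model.step(state, torch.tensor([32, 25, 362]), torch.zeros(3))
print(world_model.is_terminal(state))  # False: no EOS yet, budget not spent
state = world_model.step(state, torch.tensor([tokenizer.eos_token_id]), torch.zeros(1))
print(world_model.is_terminal(state))  # True: the last step emitted EOS
```
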
/mcts_rl/algorithms/ppo/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 PKU-Alignment Team. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """RLHF with PPO algorithm."""
16 |
17 | from mcts_rl.algorithms.ppo.trainer import PPOTrainer
18 |
19 |
20 | __all__ = ['PPOTrainer']
21 |
--------------------------------------------------------------------------------
/mcts_rl/algorithms/ppo/__main__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 PKU-Alignment Team. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """The main training script to train RLHF using PPO algorithm."""
16 |
17 | import sys
18 |
19 | from mcts_rl.algorithms.ppo.main import main
20 |
21 |
22 | if __name__ == '__main__':
23 | sys.exit(main())
24 |
--------------------------------------------------------------------------------
/mcts_rl/configs/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 PKU-Alignment Team. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Configurations and constants."""
16 |
17 | from mcts_rl.configs import constants
18 | from mcts_rl.configs.constants import * # noqa: F403
19 | from mcts_rl.configs.deepspeed_config import (
20 | TEMPLATE_DIR,
21 | get_deepspeed_eval_config,
22 | get_deepspeed_train_config,
23 | )
24 |
25 |
26 | __all__ = [
27 | *constants.__all__,
28 | 'TEMPLATE_DIR',
29 | 'get_deepspeed_eval_config',
30 | 'get_deepspeed_train_config',
31 | ]
32 |
--------------------------------------------------------------------------------
/mcts_rl/configs/deepspeed_config.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 PKU-Alignment Team. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """DeepSpeed configuration for training and evaluation."""
16 |
17 | from __future__ import annotations
18 |
19 | import json
20 | import pathlib
21 | from typing import Any, Literal
22 |
23 | import torch.distributed as dist
24 |
25 |
26 | __all__ = ['TEMPLATE_DIR', 'get_deepspeed_train_config', 'get_deepspeed_eval_config']
27 |
28 |
29 | TEMPLATE_DIR = pathlib.Path(__file__).absolute().parent
30 | TRAIN_TEMPLATE_FILE = TEMPLATE_DIR / 'ds_train_config_template.json'
31 | EVAL_TEMPLATE_FILE = TEMPLATE_DIR / 'ds_eval_config_template.json'
32 |
33 |
34 | def get_deepspeed_train_config(
35 | *,
36 | micro_batch_size_per_gpu: int = 16,
37 | gradient_accumulation_steps: int = 1,
38 | stage: int = 3,
39 | offload: Literal['none', 'parameter', 'optimizer', 'all'] = 'none',
40 | enable_hybrid_engine: bool = False,
41 | max_length: int = 512,
42 | fp16: bool = False,
43 | bf16: bool = False,
44 | ) -> dict[str, Any]:
45 | """Get the DeepSpeed config for training.
46 |
47 | Args:
48 | micro_batch_size_per_gpu (int, optional): The micro batch size per GPU. Defaults to 16.
49 | gradient_accumulation_steps (int, optional): The number of gradient accumulation steps.
50 | Defaults to 1.
51 | stage (int, optional): The stage of ZeRO. Defaults to 3.
52 | offload (Literal['none', 'parameter', 'optimizer', 'all'], optional): The offload mode.
53 | enable_hybrid_engine (bool, optional): Whether to enable the DeepSpeed hybrid engine.
54 | Defaults to False.
55 | max_length (int, optional): The maximum length of the input sequence. Defaults to 512.
56 | fp16 (bool, optional): Whether to use FP16 precision. Defaults to False.
57 | bf16 (bool, optional): Whether to use BF16 precision. Defaults to False.
58 |
59 | Returns:
60 | The DeepSpeed config for training.
61 | """
62 | assert offload in {'none', 'parameter', 'optimizer', 'all'}
63 |
64 | with TRAIN_TEMPLATE_FILE.open(mode='rt', encoding='utf-8') as f:
65 | train_config = json.load(f)
66 |
67 |     world_size = dist.get_world_size() if dist.is_initialized() else 1
68 |     train_batch_size = micro_batch_size_per_gpu * world_size * gradient_accumulation_steps
69 |
70 | train_config['train_batch_size'] = train_batch_size
71 | train_config['train_micro_batch_size_per_gpu'] = micro_batch_size_per_gpu
72 | train_config['gradient_accumulation_steps'] = gradient_accumulation_steps
73 | train_config['zero_optimization']['stage'] = stage
74 | if offload in {'parameter', 'all'}:
75 | train_config['zero_optimization'].setdefault('offload_param', {})
76 | train_config['zero_optimization']['offload_param']['device'] = 'cpu'
77 | if offload in {'optimizer', 'all'}:
78 | train_config['zero_optimization'].setdefault('offload_optimizer', {})
79 | train_config['zero_optimization']['offload_optimizer']['device'] = 'cpu'
80 | train_config['hybrid_engine']['enabled'] = enable_hybrid_engine
81 | train_config['hybrid_engine']['max_out_tokens'] = max_length
82 | if fp16 or 'fp16' in train_config:
83 | train_config.setdefault('fp16', {})
84 | train_config['fp16']['enabled'] = fp16
85 | if bf16 or 'bf16' in train_config:
86 | train_config.setdefault('bf16', {})
87 | train_config['bf16']['enabled'] = bf16
88 | return train_config
89 |
90 |
91 | def get_deepspeed_eval_config(
92 | *,
93 | stage: int = 3,
94 | offload: Literal['none', 'parameter', 'optimizer', 'all'] = 'none',
95 | fp16: bool = False,
96 | bf16: bool = False,
97 | ) -> dict[str, Any]:
98 | """Get the DeepSpeed config for evaluation.
99 |
100 | Args:
101 | stage (int, optional): The stage of ZeRO. Defaults to 3.
102 | offload (Literal['none', 'parameter', 'optimizer', 'all'], optional): The offload mode.
103 | fp16 (bool, optional): Whether to use FP16 precision. Defaults to False.
104 | bf16 (bool, optional): Whether to use BF16 precision. Defaults to False.
105 |
106 | Returns:
107 | The DeepSpeed config for evaluation.
108 | """
109 | assert offload in {'none', 'parameter', 'optimizer', 'all'}
110 |
111 | with EVAL_TEMPLATE_FILE.open(mode='rt', encoding='utf-8') as f:
112 | eval_config = json.load(f)
113 |
114 | if stage in {1, 2}:
115 | # The evaluation config only works for ZeRO stage 0 and ZeRO stage 3
116 | stage = 0
117 |
118 | eval_config['train_batch_size'] = None
119 | eval_config['train_micro_batch_size_per_gpu'] = 1
120 | eval_config['gradient_accumulation_steps'] = 1
121 | eval_config['zero_optimization']['stage'] = stage
122 | if offload in {'parameter', 'all'}:
123 | eval_config['zero_optimization'].setdefault('offload_param', {})
124 | eval_config['zero_optimization']['offload_param']['device'] = 'cpu'
125 | if fp16 or 'fp16' in eval_config:
126 | eval_config.setdefault('fp16', {})
127 | eval_config['fp16']['enabled'] = fp16
128 | if bf16 or 'bf16' in eval_config:
129 | eval_config.setdefault('bf16', {})
130 | eval_config['bf16']['enabled'] = bf16
131 | return eval_config
132 |
--------------------------------------------------------------------------------
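
A quick sketch of calling these helpers (run outside a distributed context, so the world size falls back to 1):

```python
from mcts_rl.configs import get_deepspeed_eval_config, get_deepspeed_train_config

train_config = get_deepspeed_train_config(
    micro_batch_size_per_gpu=4,
    gradient_accumulation_steps=8,
    stage=3,
    offload='optimizer',  # keep optimizer state on CPU
    bf16=True,
)
# train_batch_size = micro_batch * world_size * accumulation = 4 * 1 * 8
assert train_config['train_batch_size'] == 32
assert train_config['zero_optimization']['offload_optimizer']['device'] == 'cpu'

eval_config = get_deepspeed_eval_config(stage=3, bf16=True)
assert eval_config['train_micro_batch_size_per_gpu'] == 1
```
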
/mcts_rl/configs/ds_eval_config_template.json:
--------------------------------------------------------------------------------
1 | {
2 | "train_batch_size": null,
3 | "train_micro_batch_size_per_gpu": 1,
4 | "gradient_accumulation_steps": 1,
5 | "steps_per_print": 10,
6 | "zero_optimization": {
7 | "stage": 3,
8 | "offload_param": {
9 | "device": "none"
10 | },
11 | "param_persistence_threshold": 1e4,
12 | "memory_efficient_linear": false
13 | },
14 | "gradient_clipping": 1.0,
15 | "prescale_gradients": false,
16 | "wall_clock_breakdown": false
17 | }
18 |
--------------------------------------------------------------------------------
/mcts_rl/configs/ds_train_config_template.json:
--------------------------------------------------------------------------------
1 | {
2 | "train_batch_size": 128,
3 | "train_micro_batch_size_per_gpu": 16,
4 | "gradient_accumulation_steps": "auto",
5 | "steps_per_print": 10,
6 | "zero_optimization": {
7 | "stage": 2,
8 | "offload_param": {
9 | "device": "none"
10 | },
11 | "offload_optimizer": {
12 | "device": "none"
13 | },
14 | "param_persistence_threshold": 1e4,
15 | "max_live_parameters": 3e7,
16 | "prefetch_bucket_size": 3e7,
17 | "memory_efficient_linear": false,
18 | "gather_16bit_weights_on_model_save": true,
19 | "overlap_comm": true,
20 | "contiguous_gradients": true,
21 | "sub_group_size": 1e9,
22 | "reduce_bucket_size": "auto"
23 | },
24 | "gradient_clipping": 1.0,
25 | "prescale_gradients": false,
26 | "wall_clock_breakdown": false,
27 | "hybrid_engine": {
28 | "enabled": true,
29 | "max_out_tokens": 512,
30 | "inference_tp_size": 1,
31 | "release_inference_cache": false,
32 | "pin_parameters": true,
33 | "tp_gather_partition_size": 8
34 | }
35 | }
36 |
--------------------------------------------------------------------------------
/mcts_rl/configs/fsdp_config.json:
--------------------------------------------------------------------------------
1 | {
2 | "fsdp_transformer_layer_cls_to_wrap": "LlamaDecoderLayer"
3 | }
4 |
--------------------------------------------------------------------------------
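
This file uses the key format of HuggingFace `TrainingArguments.fsdp_config`; a hedged sketch of one way to consume it (the argument values below are illustrative, not taken from the repository's scripts):

```python
from transformers import TrainingArguments

# illustrative: shard with FSDP and wrap at the Llama decoder layer,
# as specified in mcts_rl/configs/fsdp_config.json
args = TrainingArguments(
    output_dir='outputs',
    fsdp='full_shard auto_wrap',
    fsdp_config='mcts_rl/configs/fsdp_config.json',
)
```
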
/mcts_rl/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 PKU-Alignment Team. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Dataset classes."""
16 |
17 | from __future__ import annotations
18 |
19 | from typing import Dict
20 |
21 | import torch
22 | from torch.utils.data import Dataset
23 |
24 | from mcts_rl.datasets import raw
25 | from mcts_rl.datasets.base import (
26 | CollatorBase,
27 | RawDataset,
28 | RawSample,
29 | TokenizedDataset,
30 | parse_dataset,
31 | )
32 | from mcts_rl.datasets.preference import (
33 | PreferenceBatch,
34 | PreferenceCollator,
35 | PreferenceDataset,
36 | PreferenceSample,
37 | )
38 | from mcts_rl.datasets.prompt_only import (
39 | PromptOnlyBatch,
40 | PromptOnlyCollator,
41 | PromptOnlyDataset,
42 | PromptOnlySample,
43 | PromptOnlyPostDataset,
44 | PromptOnlyPostCollator,
45 | PromptOnlyPostSample,
46 | PromptOnlyPostBatch,
47 | )
48 | from mcts_rl.datasets.raw import * # noqa: F403
49 | from mcts_rl.datasets.safety_preference import (
50 | SafetyPreferenceBatch,
51 | SafetyPreferenceCollator,
52 | SafetyPreferenceDataset,
53 | SafetyPreferenceSample,
54 | )
55 | from mcts_rl.datasets.supervised import (
56 | SupervisedBatch,
57 | SupervisedCollator,
58 | SupervisedDataset,
59 | SupervisedSample,
60 | )
61 |
62 |
63 | __all__ = [
64 | 'DummyDataset',
65 | 'parse_dataset',
66 | 'RawDataset',
67 | 'RawSample',
68 | 'TokenizedDataset',
69 | 'CollatorBase',
70 | 'PreferenceDataset',
71 | 'PreferenceSample',
72 | 'PreferenceBatch',
73 | 'PreferenceCollator',
74 | 'PromptOnlyDataset',
75 | 'PromptOnlyCollator',
76 | 'PromptOnlySample',
77 | 'PromptOnlyBatch',
78 | 'PromptOnlyPostDataset',
79 | 'PromptOnlyPostCollator',
80 | 'PromptOnlyPostSample',
81 | 'PromptOnlyPostBatch',
82 | 'SafetyPreferenceDataset',
83 | 'SafetyPreferenceCollator',
84 | 'SafetyPreferenceSample',
85 | 'SafetyPreferenceBatch',
86 | 'SupervisedDataset',
87 | 'SupervisedCollator',
88 | 'SupervisedSample',
89 | 'SupervisedBatch',
90 | *raw.__all__,
91 | ]
92 |
93 |
94 | class DummyDataset(Dataset[Dict[str, torch.Tensor]]):
95 | def __init__(self, length: int) -> None:
96 | self.length = length
97 |
98 | def __len__(self) -> int:
99 | return self.length
100 |
101 | def __getitem__(self, index: int) -> dict[str, torch.Tensor]:
102 | return {}
103 |
--------------------------------------------------------------------------------
/mcts_rl/datasets/preference.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 PKU-Alignment Team. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Dataset class for preference training."""
16 |
17 | from __future__ import annotations
18 |
19 | from typing import Callable
20 | from typing_extensions import TypedDict # Python 3.10+
21 |
22 | import torch
23 |
24 | from mcts_rl.datasets.base import CollatorBase, RawSample, TokenizedDataset
25 | from mcts_rl.datasets.utils import format_prompt, right_padding
26 |
27 |
28 | __all__ = [
29 | 'PreferenceDataset',
30 | 'PreferenceCollator',
31 | 'PreferenceSample',
32 | 'PreferenceBatch',
33 | ]
34 |
35 |
36 | class PreferenceSample(TypedDict, total=True):
37 | better_input_ids: torch.LongTensor # size = (L,)
38 | worse_input_ids: torch.LongTensor # size = (L,)
39 |
40 |
41 | class PreferenceBatch(TypedDict, total=True):
42 | better_input_ids: torch.LongTensor # size = (B, L)
43 | better_attention_mask: torch.BoolTensor # size = (B, L)
44 |
45 | worse_input_ids: torch.LongTensor # size = (B, L)
46 | worse_attention_mask: torch.BoolTensor # size = (B, L)
47 |
48 |
49 | class PreferenceDataset(TokenizedDataset):
50 | def preprocess(self, raw_sample: RawSample) -> PreferenceSample:
51 | prompt = format_prompt(input=raw_sample['input'], eos_token=self.tokenizer.eos_token, use_mcq=self.use_mcq)
52 | better_answer = raw_sample['answer']
53 | worse_answer = raw_sample['other_answer']
54 | better = raw_sample['better']
55 | if not better:
56 | better_answer, worse_answer = worse_answer, better_answer
57 |
58 | better_input_ids = self.tokenize(prompt + better_answer + self.tokenizer.eos_token)
59 | worse_input_ids = self.tokenize(prompt + worse_answer + self.tokenizer.eos_token)
60 | if (
61 | better_input_ids.size() == worse_input_ids.size()
62 | and torch.all(torch.eq(better_input_ids, worse_input_ids)).item()
63 | ):
64 | raise ValueError(
65 | 'Two responses get the same `input_ids` after tokenization.\n\n'
66 | f'Prompt: {prompt}\n\n'
67 | f'Better answer: {better_answer}\n\n'
68 | f'Worse answer: {worse_answer}',
69 | )
70 | return {
71 | 'better_input_ids': better_input_ids, # size = (L,)
72 | 'worse_input_ids': worse_input_ids, # size = (L,)
73 | }
74 |
75 | def get_collator(self) -> Callable[[list[dict[str, torch.Tensor]]], dict[str, torch.Tensor]]:
76 | return PreferenceCollator(self.tokenizer.pad_token_id)
77 |
78 |
79 | class PreferenceCollator(CollatorBase):
80 | def __call__(self, samples: list[PreferenceSample]) -> PreferenceBatch:
81 | input_ids = [sample['better_input_ids'] for sample in samples] + [
82 | sample['worse_input_ids'] for sample in samples
83 | ] # size = (2 * B, L)
84 | attention_mask = [
85 | input_id.new_ones(input_id.size(), dtype=torch.bool) for input_id in input_ids
86 | ] # size = (2 * B, L)
87 |
88 | input_ids = right_padding(input_ids, padding_value=self.pad_token_id) # size = (2 * B, L)
89 | attention_mask = right_padding(attention_mask, padding_value=0) # size = (2 * B, L)
90 |
91 | (
92 | better_input_ids, # size = (B, L)
93 | worse_input_ids, # size = (B, L)
94 | ) = input_ids.chunk(chunks=2, dim=0)
95 | (
96 | better_attention_mask, # size = (B, L)
97 | worse_attention_mask, # size = (B, L)
98 | ) = attention_mask.chunk(chunks=2, dim=0)
99 |
100 | return {
101 | 'better_input_ids': better_input_ids, # size = (B, L)
102 | 'better_attention_mask': better_attention_mask, # size = (B, L)
103 | 'worse_input_ids': worse_input_ids, # size = (B, L)
104 | 'worse_attention_mask': worse_attention_mask, # size = (B, L)
105 | }
106 |
--------------------------------------------------------------------------------
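
A toy sketch of the collation round-trip, assuming pad id 0 and hand-written token ids: the collator stacks better and worse responses, right-pads them together, then splits the batch back in two.

```python
import torch
from mcts_rl.datasets.preference import PreferenceCollator

collate = PreferenceCollator(0)  # pad_token_id = 0
batch = collate([
    {'better_input_ids': torch.tensor([5, 6, 7]), 'worse_input_ids': torch.tensor([5, 6])},
    {'better_input_ids': torch.tensor([8, 9]), 'worse_input_ids': torch.tensor([8])},
])
print(batch['better_input_ids'].shape)  # (2, 3): right-padded to the longest
print(batch['worse_attention_mask'])    # False over the padded positions
```
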
/mcts_rl/datasets/prompt_only.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 PKU-Alignment Team. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | from __future__ import annotations
17 |
18 | from typing import Callable, Hashable
19 | from typing_extensions import TypedDict # Python 3.10+
20 |
21 | import torch
22 | from torch.utils.data import Dataset, Subset
23 |
24 | from mcts_rl.datasets.base import CollatorBase, RawSample, RawSamplePost, TokenizedDataset
25 | from mcts_rl.datasets.utils import format_prompt, left_padding
26 |
27 |
28 | __all__ = [
29 | 'PromptOnlyDataset',
30 | 'PromptOnlyCollator',
31 | 'PromptOnlySample',
32 | 'PromptOnlyBatch',
33 | 'PromptOnlyPostDataset',
34 | 'PromptOnlyPostCollator',
35 | 'PromptOnlyPostSample',
36 | 'PromptOnlyPostBatch',
37 | ]
38 |
39 |
40 | class PromptOnlySample(TypedDict, total=True):
41 | input_ids: torch.LongTensor # size = (L,)
42 |
43 |
44 | class PromptOnlyBatch(TypedDict, total=True):
45 | input_ids: torch.LongTensor # size = (B, L)
46 | attention_mask: torch.BoolTensor # size = (B, L)
47 |
48 |
49 | class PromptOnlyDataset(TokenizedDataset):
50 | def preprocess(self, raw_sample: RawSample) -> PromptOnlySample:
51 |         try:
52 |             prompt = format_prompt(input=raw_sample['input'], eos_token=self.tokenizer.eos_token,
53 |                                    use_mcq=self.use_mcq, few_shot=self.few_shot, model_type=self.model_type)
54 |         except Exception:
55 |             import ipdb; ipdb.set_trace(); raise  # inspect the failing sample, then re-raise
56 | input_ids = self.tokenize(prompt)
57 | return {
58 | 'input_ids': input_ids, # size = (L,)
59 | 'answer': raw_sample.get('final_answer', ''), # str
60 | 'reasoning': raw_sample.get('answer', ''),
61 | 'answer_content': raw_sample.get('final_answer_content', raw_sample['final_answer'] if 'final_answer' in raw_sample else '')
62 | }
63 |
64 | def get_collator(self) -> Callable[[list[dict[str, torch.Tensor]]], dict[str, torch.Tensor]]:
65 | return PromptOnlyCollator(self.tokenizer.pad_token_id)
66 |
67 | def _merge_raw_datasets(self, seed: int | None = None) -> Dataset[RawSample]:
68 | """Merge multiple raw datasets into one dataset and remove duplicates."""
69 |
70 | def to_hashable(raw_sample: RawSample) -> Hashable:
71 | input = raw_sample['input'] # pylint: disable=redefined-builtin
72 | return input if isinstance(input, str) else tuple(input)
73 |
74 | merged = super()._merge_raw_datasets(seed)
75 | inputs = {to_hashable(merged[i]): i for i in range(len(merged)) if isinstance(merged[i]['input'], str) or len(merged[i]['input']) == 1}
76 | return Subset(merged, sorted(inputs.values()))
77 |
78 |
79 | class PromptOnlyCollator(CollatorBase):
80 | def __call__(self, samples: list[PromptOnlySample]) -> PromptOnlyBatch:
81 | input_ids = [sample['input_ids'] for sample in samples]
82 | attention_mask = [
83 | input_id.new_ones(input_id.size(), dtype=torch.bool) for input_id in input_ids
84 | ]
85 |
86 | input_ids = left_padding(input_ids, padding_value=self.pad_token_id)
87 | attention_mask = left_padding(attention_mask, padding_value=0)
88 | return {
89 | 'input_ids': input_ids, # size = (B, L)
90 | 'attention_mask': attention_mask, # size = (B, L)
91 | 'answer': [sample['answer'] for sample in samples],
92 | 'reasoning': [sample['reasoning'] for sample in samples],
93 | 'answer_content': [sample['answer_content'] for sample in samples],
94 | }
95 |
96 |
97 | class PromptOnlyPostSample(TypedDict, total=True):
98 | prompts_list: list[torch.LongTensor]
99 | input_ids_list: list[torch.LongTensor]
100 | answer: str
101 | answer_content: str
102 |
103 |
104 | class PromptOnlyPostBatch(TypedDict, total=True):
105 | prompts_list: list[torch.LongTensor]
106 | input_ids_list: list[torch.LongTensor]
107 | answer: list[str]
108 | answer_content: list[str]
109 |
110 |
111 | class PromptOnlyPostDataset(TokenizedDataset):
112 | def preprocess(self, raw_sample: RawSamplePost) -> PromptOnlyPostSample:
113 | return {
114 | 'prompts_list': [raw_sample['prompt']],
115 | 'input_ids_list': raw_sample['input_ids_list'],
116 | 'answer': raw_sample['final_answer'], # str
117 | 'answer_content': raw_sample.get('final_answer_content', raw_sample['final_answer'])
118 | }
119 |
120 | def get_collator(self) -> Callable[[list[dict[str, torch.Tensor]]], dict[str, torch.Tensor]]:
121 | return PromptOnlyPostCollator(self.tokenizer.pad_token_id)
122 |
123 | def _merge_raw_datasets(self, seed: int | None = None) -> Dataset[RawSamplePost]:
124 | """Merge multiple raw datasets into one dataset and remove duplicates."""
125 |
126 | def to_hashable(raw_sample: RawSamplePost) -> Hashable:
127 | input = raw_sample['prompt'] # pylint: disable=redefined-builtin
128 | return input if isinstance(input, str) else tuple(input.tolist())
129 |
130 | merged = super()._merge_raw_datasets(seed)
131 | inputs = {to_hashable(merged[i]): i for i in range(len(merged))}
132 | return Subset(merged, sorted(inputs.values()))
133 |
134 |
135 | class PromptOnlyPostCollator(CollatorBase):
136 | def __call__(self, samples: list[PromptOnlyPostSample]) -> PromptOnlyPostBatch:
137 | prompts_list = [sample['prompts_list'] for sample in samples]
138 | input_ids_list = [sample['input_ids_list'] for sample in samples]
139 | attention_mask_list = [[
140 | input_ids.not_equal(self.pad_token_id) for input_ids in sample['input_ids_list']
141 | ] for sample in samples]
142 |
143 | return {
144 | 'prompts_list': prompts_list,
145 | 'input_ids_list': input_ids_list,
146 | 'attention_mask_list': attention_mask_list,
147 | 'answer': [sample['answer'] for sample in samples],
148 | 'answer_content': [sample['answer_content'] for sample in samples],
149 | }
150 |
--------------------------------------------------------------------------------
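
For contrast with `PreferenceCollator`, prompts here are left-padded so that a generation pass starts from a right-aligned batch; a toy sketch under the same assumptions (pad id 0, hand-written ids):

```python
import torch
from mcts_rl.datasets.prompt_only import PromptOnlyCollator

collate = PromptOnlyCollator(0)  # pad_token_id = 0
batch = collate([
    {'input_ids': torch.tensor([5, 6, 7]), 'answer': '42', 'reasoning': '', 'answer_content': '42'},
    {'input_ids': torch.tensor([8, 9]), 'answer': '7', 'reasoning': '', 'answer_content': '7'},
])
print(batch['input_ids'])  # [[5, 6, 7], [0, 8, 9]] -- padding on the left
```
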
/mcts_rl/datasets/raw/__init__.py:
--------------------------------------------------------------------------------
1 | """Raw datasets."""
2 |
3 | from mcts_rl.datasets.raw.alpaca import AlpacaDataset
4 | from mcts_rl.datasets.raw.firefly import FireflyDataset
5 | from mcts_rl.datasets.raw.hh_rlhf import (
6 | HhRLHFDialogueDataset,
7 | HhRLHFHarmlessDialogueDataset,
8 | HhRLHFHelpfulDialogueDataset,
9 | )
10 | from mcts_rl.datasets.raw.moss import MOSS002SFT, MOSS003SFT
11 | from mcts_rl.datasets.raw.safe_rlhf import (
12 | SafeRLHF10KTrainDataset,
13 | SafeRLHFDataset,
14 | SafeRLHFTestDataset,
15 | SafeRLHFTrainDataset,
16 | )
17 | from mcts_rl.datasets.raw.prm800k import (
18 | PRM800KDataset,
19 | PRM800KTestDataset,
20 | PRM800KTrainDataset,
21 | )
22 | from mcts_rl.datasets.raw.mcq import (
23 | MCQDataset,
24 | SQATestDataset,
25 | SQATrainDataset,
26 | CSRTestDataset,
27 | CSRTrainDataset,
28 | SciQTestDataset,
29 | NLITestDataset,
30 | MCQTestDataset,
31 | MCQTrainDataset,
32 | )
33 | from mcts_rl.datasets.raw.math import (
34 | MATHDataset,
35 | MATHTestDataset,
36 | MATHTrainDataset,
37 | MATHSFTTrainDataset,
38 | MATHSFTTestDataset,
39 | )
40 | from mcts_rl.datasets.raw.gsm8k import (
41 | GSM8KDataset,
42 | GSM8KTestDataset,
43 | GSM8KTrainDataset,
44 | GSM8KPoTTestDataset,
45 | GSM8KPoTTrainDataset,
46 | GSM8KSFTTrainDataset,
47 | GSM8KSFTTestDataset,
48 | )
49 | from mcts_rl.datasets.raw.arithmo import (
50 | ArithmoDataset,
51 | ArithmoTestDataset,
52 | ArithmoTrainDataset,
53 | ArithmoMATHTrainDataset,
54 | ArithmoMCQTrainDataset,
55 | ArithmoCodeTrainDataset,
56 | )
57 | from mcts_rl.datasets.raw.qa_feedback import (
58 | QAFBDataset,
59 | QAFBTestDataset,
60 | QAFBTrainDataset,
61 | )
62 | from mcts_rl.datasets.raw.mcq_pairs import (
63 | MCQPreferenceDataset,
64 | SQAPreferenceTestDataset,
65 | SQAPreferenceTrainDataset,
66 | CSRPreferenceTestDataset,
67 | CSRPreferenceTrainDataset,
68 | GSMPreferenceTrainDataset,
69 | GSMPreferenceTestDataset,
70 | )
71 | from mcts_rl.datasets.raw.mcq_for_eval import (
72 | MCQEvalDataset,
73 | SQAEvalTestDataset,
74 | SQAEvalTrainDataset,
75 | CSREvalTestDataset,
76 | CSREvalTrainDataset,
77 | GSMEvalTestDataset,
78 | GSMEvalTrainDataset,
79 | )
80 | from mcts_rl.datasets.raw.math_qa import (
81 | MathQADataset,
82 | MathQATestDataset,
83 | MathQATrainDataset,
84 | MathQACodeTestDataset,
85 | MathQACodeTrainDataset,
86 | MathQAAllTrainDataset,
87 | MathQAAllTestDataset,
88 | MathQASFTTrainDataset,
89 | )
90 | from mcts_rl.datasets.raw.aqua import (
91 | AQuADataset,
92 | AQuAPoTTestDataset,
93 | AQuATestDataset,
94 | )
95 | from mcts_rl.datasets.raw.exam import (
96 | ExamTestDataset,
97 | ExamDataset,
98 | )
99 |
100 |
101 | __all__ = [
102 | 'AlpacaDataset',
103 | 'FireflyDataset',
104 | 'HhRLHFDialogueDataset',
105 | 'HhRLHFHarmlessDialogueDataset',
106 | 'HhRLHFHelpfulDialogueDataset',
107 | 'MOSS002SFT',
108 | 'MOSS003SFT',
109 | 'SafeRLHFDataset',
110 | 'SafeRLHFTrainDataset',
111 | 'SafeRLHFTestDataset',
112 | 'SafeRLHF10KTrainDataset',
113 | 'PRM800KDataset',
114 | 'PRM800KTrainDataset',
115 | 'PRM800KTestDataset',
116 | 'MCQDataset',
117 | 'SQATestDataset',
118 | 'SQATrainDataset',
119 | 'CSRTestDataset',
120 | 'CSRTrainDataset',
121 | 'SciQTestDataset',
122 | 'NLITestDataset',
123 | 'MATHDataset',
124 | 'MATHTrainDataset',
125 | 'MATHTestDataset',
126 | 'MathQAAllTrainDataset',
127 | 'GSM8KDataset',
128 | 'GSM8KTestDataset',
129 | 'GSM8KTrainDataset',
130 | 'GSM8KPoTTestDataset',
131 | 'GSM8KPoTTrainDataset',
132 | 'ArithmoDataset',
133 | 'ArithmoTestDataset',
134 | 'ArithmoTrainDataset',
135 | 'ArithmoMATHTrainDataset',
136 | 'ArithmoMCQTrainDataset',
137 | 'ArithmoCodeTrainDataset',
138 | 'QAFBDataset',
139 | 'QAFBTestDataset',
140 | 'QAFBTrainDataset',
141 | 'MCQPreferenceDataset',
142 | 'SQAPreferenceTestDataset',
143 | 'SQAPreferenceTrainDataset',
144 | 'CSRPreferenceTestDataset',
145 | 'CSRPreferenceTrainDataset',
146 | 'MCQTestDataset',
147 | 'MCQTrainDataset',
148 | 'GSMPreferenceTrainDataset',
149 | 'GSMPreferenceTestDataset',
150 | 'MCQEvalDataset',
151 | 'SQAEvalTestDataset',
152 | 'SQAEvalTrainDataset',
153 | 'CSREvalTestDataset',
154 | 'CSREvalTrainDataset',
155 | 'GSMEvalTestDataset',
156 | 'GSMEvalTrainDataset',
157 | 'MathQADataset',
158 | 'MathQATestDataset',
159 | 'MathQATrainDataset',
160 | 'MathQAAllTestDataset',
161 | 'MathQACodeTestDataset',
162 | 'MathQACodeTrainDataset',
163 | 'AQuADataset',
164 | 'AQuAPoTTestDataset',
165 | 'AQuATestDataset',
166 | 'ExamTestDataset',
167 | 'ExamDataset',
168 | 'MathQASFTTrainDataset',
169 | 'MATHSFTTrainDataset',
170 | 'MATHSFTTestDataset',
171 | 'GSM8KSFTTrainDataset',
172 | 'GSM8KSFTTestDataset',
173 | ]
174 |
--------------------------------------------------------------------------------
/mcts_rl/datasets/raw/alpaca.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 PKU-Alignment Team. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Stanford Alpaca dataset for supervised instruction fine-tuning."""
16 |
17 | from __future__ import annotations
18 |
19 | from datasets import load_dataset
20 | from mcts_rl.datasets.base import RawDataset, RawSample
21 |
22 |
23 | __all__ = ['AlpacaDataset']
24 |
25 |
26 | class AlpacaDataset(RawDataset):
27 | NAME: str = 'alpaca'
28 | ALIASES: tuple[str, ...] = ('stanford-alpaca',)
29 |
30 | def __init__(self, path: str | None = None) -> None:
31 | alpaca = load_dataset(path or 'tatsu-lab/alpaca', split='train')
32 | self.data = []
33 | # for data in load_dataset('McGill-NLP/feedbackQA', split='train'):
34 | # question = data['question']
35 | # answer = data['answer'].replace('\n', ' ')
36 | # comment = '\n'.join([f'{r}: {e}' for r, e in zip(data['feedback']['rating'], data['feedback']['explanation'])])
37 | # if ('Excellent' in comment or 'Acceptable' in comment) and 'Bad' not in comment:
38 | # self.data.append({'instruction': question, 'input': '', 'output': answer})
39 | # self.data += list(alpaca)[:len(self.data) * 7]
40 | safe_data = {}
41 | for data in load_dataset('PKU-Alignment/PKU-SafeRLHF', split='train'):
42 | question = data['prompt']
43 | idx = data['better_response_id']
44 | if data[f'is_response_{idx}_safe']:
45 | answer = data[f'response_{idx}']
46 | if question not in safe_data or data['safer_response_id'] == idx:
47 | safe_data[question] = {'instruction': question, 'input': '', 'output': answer}
48 |         self.data = list(safe_data.values()) + list(alpaca)[:len(safe_data) * 7]  # 7x Alpaca relative to the safe subset
49 |
50 | def __getitem__(self, index: int) -> RawSample:
51 | data = self.data[index]
52 | input = ( # pylint: disable=redefined-builtin
53 | ' '.join((data['instruction'], data['input'])) if data['input'] else data['instruction']
54 | )
55 | answer = data['output']
56 | return RawSample(input=input, answer=answer)
57 |
58 | def __len__(self) -> int:
59 | return len(self.data)
60 |
--------------------------------------------------------------------------------
/mcts_rl/datasets/raw/aqua.py:
--------------------------------------------------------------------------------
1 | """MATH datasets."""
2 |
3 | from __future__ import annotations
4 |
5 | import os
6 | from typing import ClassVar
7 |
8 | from mcts_rl.datasets.base import RawDataset, RawSample, jsonlines_load
9 |
10 |
11 | __all__ = [
12 | 'AQuADataset',
13 | 'AQuATestDataset',
14 | 'AQuAPoTTestDataset',
15 | ]
16 |
17 | DATA_DIR = "path_to_dataset_folder"
18 |
19 |
20 | class AQuADataset(RawDataset):
21 | SPLIT: ClassVar[str]
22 | PTYPE: ClassVar[str]
23 |
24 | def __init__(self) -> None:
25 | self.data = jsonlines_load(os.path.join(DATA_DIR, f'auqa/aqua_{self.SPLIT}.jsonl'))
26 |
27 | def __getitem__(self, index: int) -> RawSample:
28 | data = self.data[index]
29 | question = data['question'] + '\nAnswer Choices: (' + ' ('.join(data['options'])
30 |         prompt = (question + '\nWrite a Python program to solve this.') if self.PTYPE == 'pot' else question
31 | return RawSample(
32 | input=prompt,
33 | answer=data['rationale'],
34 | final_answer=data.get('correct', None),
35 | final_answer_content=data.get('correct', None),
36 | )
37 |
38 | def __len__(self) -> int:
39 | return len(self.data)
40 |
41 |
42 | class AQuATestDataset(AQuADataset):
43 | NAME: str = 'AQuA/test'
44 | SPLIT: str = 'test'
45 | PTYPE: str = 'cot'
46 |
47 |
48 | class AQuAPoTTestDataset(AQuADataset):
49 | NAME: str = 'AQuACode/test'
50 | SPLIT: str = 'test'
51 | PTYPE: str = 'pot'
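The prompt assembly in `__getitem__` above, sketched on invented values; the `'A)...'` option format is an assumption about the local AQuA files, which live under `DATA_DIR` and are not part of this repo:

```python
data = {
    'question': 'What is 3 * 7?',
    'options': ['A)10', 'B)21', 'C)24'],  # assumed AQuA-style option strings
    'rationale': '3 * 7 = 21. Answer: B',
    'correct': 'B',
}
question = data['question'] + '\nAnswer Choices: (' + ' ('.join(data['options'])
print(question)
# What is 3 * 7?
# Answer Choices: (A)10 (B)21 (C)24
```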
--------------------------------------------------------------------------------
/mcts_rl/datasets/raw/arithmo.py:
--------------------------------------------------------------------------------
1 | """Arithmo datasets."""
2 | from __future__ import annotations
3 |
4 | import os
5 | import regex
6 | from typing import ClassVar
7 |
8 | from datasets import load_dataset
9 | from mcts_rl.datasets.base import RawDataset, RawSample, jsonlines_load
10 |
11 |
12 | __all__ = [
13 | 'ArithmoDataset',
14 | 'ArithmoTrainDataset',
15 | 'ArithmoTestDataset',
16 | 'ArithmoMATHTrainDataset',
17 | 'ArithmoMCQTrainDataset',
18 | 'ArithmoCodeTrainDataset',
19 | ]
20 |
21 | DATA_DIR = "path_to_dataset_folder"
22 |
23 | class ArithmoDataset(RawDataset):
24 | SPLIT: ClassVar[str]
25 | PATH: ClassVar[str]
26 | TYPE: ClassVar[str]
27 |
28 | def __init__(self, path: str | None = None) -> None:
29 | try:
30 | self.data = load_dataset(path or self.PATH, split=self.SPLIT)
31 |         except Exception:  # fall back to a local copy if the Hub is unavailable
32 | self.data = jsonlines_load(os.path.join(DATA_DIR, f'arithmo/{self.SPLIT}.jsonl'))
33 | if self.TYPE == 'math':
34 | self.data = [dt for dt in self.data if ' answer is' in dt['answer']]
35 | elif self.TYPE == 'mcq':
36 | self.data = [
37 |                 dt for dt in self.data if ' answer is' in dt['answer'] and not dt['answer'].startswith('The answer is')
38 |                 and 'answer choices' in dt['question'].lower()
39 | ]
40 | elif self.TYPE == 'code':
41 | self.data = [dt for dt in self.data if regex.search(r'print\(.+\)', dt['answer'])]
42 |
43 | def __getitem__(self, index: int) -> RawSample:
44 | data = self.data[index]
45 | return RawSample(
46 | input=data['question'],
47 | answer=data['answer'],
48 | )
49 |
50 | def __len__(self) -> int:
51 | return len(self.data)
52 |
53 |
54 | class ArithmoTrainDataset(ArithmoDataset):
55 | NAME: str = 'Arithmo/train'
56 | ALIASES: tuple[str, ...] = ('akjindal53244/Arithmo-Data/train',)
57 | PATH: str = 'akjindal53244/Arithmo-Data'
58 | SPLIT: str = 'train'
59 | TYPE: str = 'all'
60 |
61 |
62 | class ArithmoTestDataset(ArithmoDataset):
63 | NAME: str = 'Arithmo/test'
64 | ALIASES: tuple[str, ...] = ('akjindal53244/Arithmo-Data/test',)
65 | PATH: str = 'akjindal53244/Arithmo-Data'
66 | SPLIT: str = 'test'
67 | TYPE: str = 'all'
68 |
69 |
70 | class ArithmoMATHTrainDataset(ArithmoDataset):
71 | NAME: str = 'ArithmoMATH/train'
72 | ALIASES: tuple[str, ...] = ('akjindal53244/Arithmo-Data/train/mathqa',)
73 | PATH: str = 'akjindal53244/Arithmo-Data'
74 | SPLIT: str = 'train'
75 | TYPE: str = 'math'
76 |
77 |
78 | class ArithmoMCQTrainDataset(ArithmoDataset):
79 | NAME: str = 'ArithmoMCQ/train'
80 | ALIASES: tuple[str, ...] = ('akjindal53244/Arithmo-Data/train/mcq',)
81 | PATH: str = 'akjindal53244/Arithmo-Data'
82 | SPLIT: str = 'train'
83 | TYPE: str = 'mcq'
84 |
85 |
86 | class ArithmoCodeTrainDataset(ArithmoDataset):
87 | NAME: str = 'ArithmoCode/train'
88 | ALIASES: tuple[str, ...] = ('akjindal53244/Arithmo-Data/train/code',)
89 | PATH: str = 'akjindal53244/Arithmo-Data'
90 | SPLIT: str = 'train'
91 | TYPE: str = 'code'
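The three `TYPE` filters above, replayed on invented records with the Arithmo-Data field names; each filter keeps a different slice of the instruction data:

```python
import regex  # same third-party module the dataset class uses

records = [
    {'question': 'Q1 with answer choices: (A) ...', 'answer': 'Step 1 ... The final answer is 42'},
    {'question': 'Q2', 'answer': 'The answer is 7'},
    {'question': 'Q3', 'answer': 'x = 6 * 7\nprint(x)'},
]

math_only = [dt for dt in records if ' answer is' in dt['answer']]
mcq_only = [
    dt for dt in records
    if ' answer is' in dt['answer'] and not dt['answer'].startswith('The answer is')
    and 'answer choices' in dt['question'].lower()
]
code_only = [dt for dt in records if regex.search(r'print\(.+\)', dt['answer'])]

print(len(math_only), len(mcq_only), len(code_only))  # 2 1 1
```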
--------------------------------------------------------------------------------
/mcts_rl/datasets/raw/exam.py:
--------------------------------------------------------------------------------
1 | """Exam datasets."""
2 | from __future__ import annotations
3 |
4 | from typing import ClassVar
5 |
6 | from datasets import load_dataset
7 | from mcts_rl.datasets.base import RawDataset, RawSample
8 |
9 |
10 | __all__ = [
11 |     'ExamDataset', 'ExamTestDataset',
12 | ]
13 |
14 | class ExamDataset(RawDataset):
15 | SPLIT: ClassVar[str]
16 | PATH: ClassVar[str]
17 |
18 | def __init__(self, path: str | None = None) -> None:
19 | self.data = load_dataset(path or self.PATH, split=self.SPLIT)
20 |
21 | def __getitem__(self, index: int) -> RawSample:
22 | data = self.data[index]
23 | return RawSample(
24 | input=data['Question'],
25 | answer=None,
26 | )
27 |
28 | def __len__(self) -> int:
29 | return len(self.data)
30 |
31 |
32 | class ExamTestDataset(ExamDataset):
33 | NAME: str = 'Exam/test'
34 | PATH: str = 'keirp/hungarian_national_hs_finals_exam'
35 | SPLIT: str = 'test'
--------------------------------------------------------------------------------
/mcts_rl/datasets/raw/firefly.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 PKU-Alignment Team. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Firefly (流萤) dataset for supervised instruction fine-tuning."""
16 |
17 | from __future__ import annotations
18 |
19 | from datasets import load_dataset
20 | from mcts_rl.datasets.base import RawDataset, RawSample
21 |
22 |
23 | __all__ = ['FireflyDataset']
24 |
25 |
26 | class FireflyDataset(RawDataset):
27 | NAME: str = 'firefly'
28 |
29 | def __init__(self, path: str | None = None) -> None:
30 | self.data = load_dataset(path or 'YeungNLP/firefly-train-1.1M', split='train')
31 |
32 | def __getitem__(self, index: int) -> RawSample:
33 | data = self.data[index]
34 | return RawSample(input=data['input'], answer=data['target'])
35 |
36 | def __len__(self) -> int:
37 | return len(self.data)
38 |
--------------------------------------------------------------------------------
/mcts_rl/datasets/raw/gsm8k.py:
--------------------------------------------------------------------------------
1 | """MATH datasets."""
2 |
3 | from __future__ import annotations
4 |
5 | import os
6 | from typing import ClassVar
7 | from datasets import load_dataset
8 |
9 | from mcts_rl.datasets.base import RawDataset, RawSample
10 | from mcts_rl.utils import extract_answer, list_to_dict, get_math_data
11 |
12 |
13 | __all__ = [
14 | 'GSM8KDataset',
15 | 'GSM8KTrainDataset',
16 | 'GSM8KTestDataset',
17 | 'GSM8KPoTTrainDataset',
18 | 'GSM8KPoTTestDataset',
19 | 'GSM8KSFTTrainDataset',
20 | 'GSM8KSFTTestDataset',
21 | ]
22 |
23 |
24 | class GSM8KDataset(RawDataset):
25 | SPLIT: ClassVar[str]
26 | PTYPE: ClassVar[str]
27 | DTYPE: ClassVar[str]
28 |
29 | def __init__(self) -> None:
30 | if self.PTYPE != 'pot':
31 | self.data = load_dataset('openai/gsm8k', 'main', split=self.SPLIT, trust_remote_code=True)
32 | else:
33 |             raise ValueError('PoT is not supported for now.')
34 | if self.DTYPE == 'arithmo':
35 | gsm_dict = list_to_dict(self.data)
36 | arithmo_dict = list_to_dict(get_math_data(load_dataset('akjindal53244/Arithmo-Data', split=self.SPLIT)))
37 | arithmo = {k:v for k, v in arithmo_dict.items() if k in gsm_dict}
38 | self.data = [vv for v in arithmo.values() for vv in v]
39 | # self.data = get_arithmo_data(arithmo)
40 |
41 | def __getitem__(self, index: int) -> RawSample:
42 | data = self.data[index]
43 | prompt = data['problem'] if 'problem' in data else data['question']
44 |         prompt = (prompt + '\nWrite a Python program to solve this.') if self.PTYPE == 'pot' else prompt
45 | solution = data['solution'] if 'solution' in data else data['answer']
46 | answer = extract_answer(solution)
47 | if self.DTYPE == 'default':
48 | solution = f'{solution}\nThe answer is {answer}'
49 | return RawSample(
50 | input=prompt,
51 | answer=solution,
52 | final_answer=answer,
53 | final_answer_content=answer,
54 | )
55 |
56 | def __len__(self) -> int:
57 | return len(self.data)
58 |
59 |
60 | class GSM8KSFTTrainDataset(GSM8KDataset):
61 | NAME: str = 'GSM8KSFT/train'
62 | SPLIT: str = 'train'
63 | PTYPE: str = 'cot'
64 | DTYPE: str = 'arithmo'
65 |
66 |
67 | class GSM8KSFTTestDataset(GSM8KDataset):
68 | NAME: str = 'GSM8KSFT/test'
69 | SPLIT: str = 'test'
70 | PTYPE: str = 'cot'
71 | DTYPE: str = 'arithmo'
72 |
73 |
74 | class GSM8KTrainDataset(GSM8KDataset):
75 | NAME: str = 'GSM8K/train'
76 | SPLIT: str = 'train'
77 | PTYPE: str = 'cot'
78 | DTYPE: str = 'default'
79 |
80 |
81 | class GSM8KTestDataset(GSM8KDataset):
82 | NAME: str = 'GSM8K/test'
83 | SPLIT: str = 'test'
84 | PTYPE: str = 'cot'
85 | DTYPE: str = 'default'
86 |
87 |
88 | class GSM8KPoTTrainDataset(GSM8KDataset):
89 | NAME: str = 'GSM8KCode/train'
90 | SPLIT: str = 'train'
91 | PTYPE: str = 'pot'
92 | DTYPE: str = 'default'
93 |
94 |
95 | class GSM8KPoTTestDataset(GSM8KDataset):
96 | NAME: str = 'GSM8KCode/test'
97 | SPLIT: str = 'test'
98 | PTYPE: str = 'pot'
99 | DTYPE: str = 'default'
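`extract_answer` is imported from `mcts_rl.utils` and is not shown in this section; the stub below only assumes the standard GSM8K convention that solutions end with `#### <answer>`, which is what the `DTYPE == 'default'` formatting above relies on:

```python
def extract_answer_stub(solution: str) -> str:
    # Stand-in with the assumed behavior: take the value after the '####' marker.
    return solution.split('####')[-1].strip()

solution = 'Natalia sold 48 / 2 = 24 clips in May.\n#### 24'
answer = extract_answer_stub(solution)
print(f'{solution}\nThe answer is {answer}')  # the DTYPE == 'default' suffix
```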
--------------------------------------------------------------------------------
/mcts_rl/datasets/raw/hh_rlhf.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 PKU-Alignment Team. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Helpful and Harmless Dialogue Datasets from Anthropic."""
16 |
17 | from __future__ import annotations
18 |
19 | from typing import ClassVar
20 |
21 | from datasets import load_dataset
22 | from mcts_rl.datasets.base import RawDataset, RawSample
23 |
24 |
25 | __all__ = [
26 | 'HhRLHFDialogueDataset',
27 | 'HhRLHFHarmlessDialogueDataset',
28 | 'HhRLHFHelpfulDialogueDataset',
29 | 'HhRLHFPreferenceDataset',
30 | 'HhRLHFHarmlessPreferenceTrainDataset',
31 | 'HhRLHFHarmlessPreferenceTestDataset',
32 | 'HhRLHFHelpfulPreferenceTrainDataset',
33 | 'HhRLHFHelpfulPreferenceTestDataset',
34 | ]
35 |
36 |
37 | class HhRLHFDialogueDataset(RawDataset):
38 | NAME: ClassVar[str] = 'hh-rlhf-dialogue'
39 | ALIASES: tuple[str, ...] = ('hh-dialogue',)
40 | DATA_DIR: ClassVar[str | None] = None
41 |
42 | def __init__(self, path: str | None = None) -> None:
43 | self.data = load_dataset(
44 | path or 'PKU-Alignment/processed-hh-rlhf',
45 | data_dir=self.DATA_DIR,
46 | split='train',
47 | )
48 |
49 | def __getitem__(self, index: int) -> RawSample:
50 | data = self.data[index]
51 | dialogue = [content['text'] for content in data['context']]
52 | dialogue.append(data['chosen']['text'])
53 | return RawSample(dialogue=dialogue)
54 |
55 | def __len__(self) -> int:
56 | return len(self.data)
57 |
58 |
59 | class HhRLHFHarmlessDialogueDataset(HhRLHFDialogueDataset):
60 | NAME: str = 'hh-rlhf-harmless-dialogue'
61 | ALIASES: tuple[str, ...] = (
62 | 'hh-rlhf-dialogue/harmless-base',
63 | 'hh-harmless-dialogue',
64 | 'hh-dialogue/harmless-base',
65 | )
66 | DATA_DIR: str = 'harmless-base'
67 |
68 |
69 | class HhRLHFHelpfulDialogueDataset(HhRLHFDialogueDataset):
70 | NAME: str = 'hh-rlhf-helpful-dialogue'
71 | ALIASES: tuple[str, ...] = (
72 | 'hh-rlhf-dialogue/helpful-base',
73 | 'hh-helpful-dialogue',
74 | 'hh-dialogue/helpful-base',
75 | )
76 | DATA_DIR: str = 'helpful-base'
77 |
78 |
79 | class HhRLHFPreferenceDataset(RawDataset):
80 | NAME: ClassVar[str] = 'hh-rlhf-preference'
81 | ALIASES: tuple[str, ...] = ('hh-preference',)
82 | DATA_DIR: ClassVar[str | None] = None
83 | SPLIT: ClassVar[str]
84 |
85 | def __init__(self, path: str | None = None) -> None:
86 | self.data = load_dataset(
87 | path or 'PKU-Alignment/processed-hh-rlhf',
88 | data_dir=self.DATA_DIR,
89 | split=self.SPLIT,
90 | )
91 |
92 | def __getitem__(self, index: int) -> RawSample:
93 | data = self.data[index]
94 | dialogue = [content['text'] for content in data['context']]
95 | answer = data['chosen']['text']
96 | other_answer = data['rejected']['text']
97 |
98 | return RawSample(
99 | input=dialogue,
100 | answer=answer,
101 | other_answer=other_answer,
102 | better=True,
103 | )
104 |
105 | def __len__(self) -> int:
106 | return len(self.data)
107 |
108 |
109 | class HhRLHFHarmlessPreferenceTrainDataset(HhRLHFPreferenceDataset):
110 | NAME: str = 'hh-rlhf-harmless-preference/train'
111 | ALIASES: tuple[str, ...] = (
112 | 'hh-rlhf-preference/harmless-base/train',
113 | 'hh-harmless-preference/train',
114 | 'hh-preference/harmless-base/train',
115 | )
116 | DATA_DIR: str = 'harmless-base'
117 | SPLIT: str = 'train'
118 |
119 |
120 | class HhRLHFHarmlessPreferenceTestDataset(HhRLHFPreferenceDataset):
121 | NAME: str = 'hh-rlhf-harmless-preference/test'
122 | ALIASES: tuple[str, ...] = (
123 | 'hh-rlhf-preference/harmless-base/test',
124 | 'hh-harmless-preference/test',
125 | 'hh-preference/harmless-base/test',
126 | )
127 | DATA_DIR: str = 'harmless-base'
128 | SPLIT: str = 'test'
129 |
130 |
131 | class HhRLHFHelpfulPreferenceTrainDataset(HhRLHFPreferenceDataset):
132 | NAME: str = 'hh-rlhf-helpful-preference/train'
133 | ALIASES: tuple[str, ...] = (
134 | 'hh-rlhf-preference/helpful-base/train',
135 | 'hh-helpful-preference/train',
136 | 'hh-preference/helpful-base/train',
137 | )
138 | DATA_DIR: str = 'helpful-base'
139 | SPLIT: str = 'train'
140 |
141 |
142 | class HhRLHFHelpfulPreferenceTestDataset(HhRLHFPreferenceDataset):
143 | NAME: str = 'hh-rlhf-helpful-preference/test'
144 | ALIASES: tuple[str, ...] = (
145 | 'hh-rlhf-preference/helpful-base/test',
146 | 'hh-helpful-preference/test',
147 | 'hh-preference/helpful-base/test',
148 | )
149 | DATA_DIR: str = 'helpful-base'
150 | SPLIT: str = 'test'
151 |
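In the preference datasets the chosen response always lands in `answer` with `better=True`; a toy record showing the mapping (invented values, plain `dict` in place of `RawSample` to stay self-contained):

```python
data = {
    'context': [{'text': 'How do I bake bread?'}],
    'chosen': {'text': 'Start with flour, water, salt, and yeast ...'},
    'rejected': {'text': 'I cannot help with that.'},
}
sample = dict(
    input=[content['text'] for content in data['context']],  # ends at the user turn
    answer=data['chosen']['text'],
    other_answer=data['rejected']['text'],
    better=True,  # `answer` is always the chosen response
)
print(sample['input'], '->', sample['answer'])
```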
--------------------------------------------------------------------------------
/mcts_rl/datasets/raw/math.py:
--------------------------------------------------------------------------------
1 | """MATH datasets."""
2 |
3 | from __future__ import annotations
4 |
5 | import os
6 | from typing import ClassVar
7 | from datasets import load_dataset
8 |
9 | from mcts_rl.datasets.base import RawDataset, RawSample
10 | from mcts_rl.utils import extract_answer, get_math_data, list_to_dict, get_arithmo_data
11 |
12 |
13 | __all__ = [
14 | 'MATHDataset',
15 | 'MATHTrainDataset',
16 | 'MATHTestDataset',
17 | 'MATHSFTTrainDataset',
18 | 'MATHSFTTestDataset',
19 | ]
20 |
21 |
22 | class MATHDataset(RawDataset):
23 | SPLIT: ClassVar[str]
24 | DTYPE: ClassVar[str]
25 |
26 | def __init__(self) -> None:
27 | self.data = load_dataset('hendrycks/competition_math', split=self.SPLIT, trust_remote_code=True)
28 | if self.DTYPE == 'arithmo':
29 | math_dict = list_to_dict(self.data)
30 | arithmo_dict = list_to_dict(get_math_data(load_dataset('akjindal53244/Arithmo-Data', split=self.SPLIT)))
31 | arithmo = {k:v for k, v in arithmo_dict.items() if k in math_dict}
32 | self.data = [vv for v in arithmo.values() for vv in v]
33 | # self.data = get_arithmo_data(arithmo)
34 |
35 |
36 | def __getitem__(self, index: int) -> RawSample:
37 | data = self.data[index]
38 | solution = data['solution']
39 | answer = extract_answer(solution)
40 | if self.DTYPE == 'default':
41 | solution = f'{solution}\nThe answer is {answer}'
42 | return RawSample(
43 | input=data['problem'] if 'problem' in data else data['question'],
44 | answer=solution,
45 | final_answer=answer,
46 | final_answer_content=answer,
47 | )
48 |
49 | def __len__(self) -> int:
50 | return len(self.data)
51 |
52 |
53 | class MATHTrainDataset(MATHDataset):
54 | NAME: str = 'MATH/train'
55 | SPLIT: str = 'train'
56 | DTYPE: str = 'default'
57 |
58 |
59 | class MATHTestDataset(MATHDataset):
60 | NAME: str = 'MATH/test'
61 | SPLIT: str = 'test'
62 | DTYPE: str = 'default'
63 |
64 |
65 | class MATHSFTTrainDataset(MATHDataset):
66 | NAME: str = 'MATHSFT/train'
67 | SPLIT: str = 'train'
68 | DTYPE: str = 'arithmo'
69 |
70 |
71 | class MATHSFTTestDataset(MATHDataset):
72 | NAME: str = 'MATHSFT/test'
73 | SPLIT: str = 'test'
74 | DTYPE: str = 'arithmo'
75 |
--------------------------------------------------------------------------------
/mcts_rl/datasets/raw/math_qa.py:
--------------------------------------------------------------------------------
1 | """MATH datasets."""
2 |
3 | from __future__ import annotations
4 |
5 | import os
6 | from typing import ClassVar
7 |
8 | from datasets import load_dataset
9 | from mcts_rl.utils import get_math_data, get_arithmo_data, list_to_dict, tqdm
10 | from mcts_rl.datasets.base import RawDataset, RawSample, jsonlines_load
11 |
12 |
13 | __all__ = [
14 | 'MathQADataset',
15 | 'MathQATrainDataset',
16 | 'MathQATestDataset',
17 | 'MathQACodeTrainDataset',
18 | 'MathQACodeTestDataset',
19 | 'MathQASFTTrainDataset',
20 | ]
21 |
22 | DATA_DIR = "path_to_dataset_folder"
23 |
24 |
25 | class MathQADataset(RawDataset):
26 | SPLIT: ClassVar[str]
27 | TYPE: ClassVar[str]
28 |
29 | def __init__(self) -> None:
30 | if self.TYPE == 'pot':
31 |             raise ValueError('PoT is not supported for now.')  # the PoT data loading below is intentionally unreachable
32 | ## PoT data
33 | gsm8k = jsonlines_load(os.path.join(DATA_DIR, f'gsm8k/gsm8k_{self.SPLIT}.jsonl'))
34 | raw_arithmo = jsonlines_load(os.path.join(DATA_DIR, f'arithmo/arithmo_code_{self.SPLIT}.jsonl'))
35 | arithmo = []
36 | for dt in raw_arithmo:
37 | prompt = dt['question'] if 'question' in dt else dt['problem']
38 | if all(not prompt.strip().startswith(x['question'].strip()) for x in gsm8k):
39 | arithmo.append(dt)
40 | for i, dt in enumerate(gsm8k):
41 | gsm8k[i]['question'] = dt['question'] + ' Write a Python program to solve this.'
42 | self.data = gsm8k
43 | elif self.TYPE == 'all':
44 |             raise ValueError('CoT + PoT is not supported for now.')  # the data loading below is intentionally unreachable
45 | ## CoT + PoT data
46 | gsm8k = jsonlines_load(os.path.join(DATA_DIR, f'gsm8k/gsm8k_{self.SPLIT}.jsonl'))
47 | math = jsonlines_load(os.path.join(DATA_DIR, f'math/math_{self.SPLIT}.jsonl'))
48 | self.data = gsm8k + math
49 | if self.SPLIT == 'train':
50 | raw_arithmo = jsonlines_load(os.path.join(DATA_DIR, f'arithmo/arithmo_code_{self.SPLIT}.jsonl'))
51 | for dt in raw_arithmo[::-1]:
52 | prompt = dt['question'] if 'question' in dt else dt['problem']
53 | if any(prompt.strip().startswith(x['question'].strip()) for x in gsm8k):
54 | self.data.append(dt)
55 | else:
56 | for i, dt in enumerate(gsm8k[:]):
57 | dt['question'] = dt['question'] + ' Write a Python program to solve this.'
58 | self.data.append(dt)
59 | else:
60 | gsm8k = load_dataset('openai/gsm8k', 'main', split=self.SPLIT, trust_remote_code=True)
61 | math = load_dataset('hendrycks/competition_math', split=self.SPLIT, trust_remote_code=True)
62 | try:
63 | arithmo = get_math_data(load_dataset('akjindal53244/Arithmo-Data', split=self.SPLIT))
64 |         except Exception:  # fall back to a local copy if the Hub is unavailable
65 | arithmo = get_math_data(jsonlines_load(os.path.join(DATA_DIR, 'arithmo/train.jsonl')))
66 | if self.TYPE == 'sft':
67 | arithmo, gsm8k, math = list_to_dict(arithmo), list_to_dict(gsm8k), list_to_dict(math)
68 | ## use the corresponding training data seen in SFT
69 | mathqa_dict = {k:v for k,v in arithmo.items() if k in math or k in gsm8k}
70 | self.data = [vv for v in mathqa_dict.values() for vv in v]
71 | # self.data = get_arithmo_data(mathqa_dict)
72 | else:
73 | self.data = gsm8k + math
74 |
75 | def __getitem__(self, index: int) -> RawSample:
76 | data = self.data[index]
77 | prompt = data['question'] if 'question' in data else data['problem']
78 | return RawSample(
79 | input=prompt,
80 | answer=data['solution'],
81 | final_answer=data.get('answer', None),
82 | final_answer_content=data.get('answer_content', data.get('answer', None)),
83 | )
84 |
85 | def __len__(self) -> int:
86 | return len(self.data)
87 |
88 |
89 | class MathQASFTTrainDataset(MathQADataset):
90 | NAME: str = 'MathQASFT/train'
91 | SPLIT: str = 'train'
92 | TYPE: str = 'sft'
93 |
94 |
95 | class MathQATrainDataset(MathQADataset):
96 | NAME: str = 'MathQA/train'
97 | SPLIT: str = 'train'
98 | TYPE: str = 'cot'
99 |
100 |
101 | class MathQAAllTrainDataset(MathQADataset):
102 | NAME: str = 'MathQAAll/train'
103 | SPLIT: str = 'train'
104 | TYPE: str = 'all'
105 |
106 |
107 | class MathQAAllTestDataset(MathQADataset):
108 | NAME: str = 'MathQAAll/test'
109 | SPLIT: str = 'test'
110 | TYPE: str = 'all'
111 |
112 |
113 | class MathQATestDataset(MathQADataset):
114 | NAME: str = 'MathQA/test'
115 | SPLIT: str = 'test'
116 | TYPE: str = 'cot'
117 |
118 |
119 | class MathQACodeTrainDataset(MathQADataset):
120 | NAME: str = 'MathQACode/train'
121 | SPLIT: str = 'train'
122 | TYPE: str = 'pot'
123 |
124 |
125 | class MathQACodeTestDataset(MathQADataset):
126 | NAME: str = 'MathQACode/test'
127 | SPLIT: str = 'test'
128 | TYPE: str = 'pot'
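The `'sft'` branch keeps only the Arithmo items whose question also appears in GSM8K or MATH, so training reuses exactly the problems covered at SFT time. `list_to_dict` lives in `mcts_rl.utils` and is not shown here; the stub below assumes it groups records by question:

```python
def list_to_dict_stub(items):
    # Assumed behavior of mcts_rl.utils.list_to_dict: {question: [records...]}.
    grouped = {}
    for item in items:
        grouped.setdefault(item['question'], []).append(item)
    return grouped

arithmo = list_to_dict_stub([{'question': 'Q1', 'solution': 'S1'},
                             {'question': 'Q9', 'solution': 'S9'}])
gsm8k = list_to_dict_stub([{'question': 'Q1', 'solution': '...'}])
math = {}

mathqa_dict = {k: v for k, v in arithmo.items() if k in math or k in gsm8k}
data = [vv for v in mathqa_dict.values() for vv in v]
print([dt['question'] for dt in data])  # ['Q1'] -- only the overlap survives
```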
--------------------------------------------------------------------------------
/mcts_rl/datasets/raw/mcq.py:
--------------------------------------------------------------------------------
1 | """CSR Datasets"""
2 |
3 | from __future__ import annotations
4 |
5 | import os
6 | from typing import ClassVar
7 |
8 | from mcts_rl.datasets.base import RawDataset, RawSample, jsonlines_load
9 |
10 |
11 | __all__ = ['MCQDataset', 'MCQTrainDataset', 'MCQTestDataset', 'SQATrainDataset', 'SQATestDataset', 'CSRTrainDataset', 'CSRTestDataset', 'SciQTestDataset', 'NLITestDataset']
12 |
13 | DATA_DIR = "path_to_dataset_folder"
14 |
15 |
16 | class MCQDataset(RawDataset):
17 | SPLIT: ClassVar[str]
18 | DTYPE: ClassVar[str]
19 |
20 | def __init__(self, path: str | None = None) -> None:
21 | if self.DTYPE == 'all':
22 | self.data = jsonlines_load(os.path.join(DATA_DIR, f'csr/mcq_{self.SPLIT}.jsonl'))
23 | else:
24 | self.data = jsonlines_load(os.path.join(DATA_DIR, f'csr/mcq_{self.DTYPE}_{self.SPLIT}.jsonl'))
25 |
26 | def __getitem__(self, index: int) -> RawSample:
27 | data = self.data[index]
28 | question = data['question']
29 | return RawSample(input=question, final_answer=data['answer'],
30 | final_answer_content=data.get('answer_content', data['answer']))
31 |
32 | def __len__(self) -> int:
33 | return len(self.data)
34 |
35 |
36 | class MCQTrainDataset(MCQDataset):
37 | NAME: str = 'MCQ/train'
38 | DTYPE: str = 'all'
39 | SPLIT: str = 'train'
40 |
41 |
42 | class MCQTestDataset(MCQDataset):
43 | NAME: str = 'MCQ/test'
44 | DTYPE: str = 'all'
45 | SPLIT: str = 'test'
46 |
47 |
48 | class SQATrainDataset(MCQDataset):
49 | NAME: str = 'SQA/train'
50 | DTYPE: str = 'sqa'
51 | SPLIT: str = 'train'
52 |
53 |
54 | class CSRTrainDataset(MCQDataset):
55 | NAME: str = 'CSR/train'
56 | DTYPE: str = 'csqa'
57 | SPLIT: str = 'train'
58 |
59 |
60 | class SQATestDataset(MCQDataset):
61 | NAME: str = 'SQA/test'
62 | DTYPE: str = 'sqa'
63 | SPLIT: str = 'fulltest'
64 |
65 |
66 | class CSRTestDataset(MCQDataset):
67 | NAME: str = 'CSR/test'
68 | DTYPE: str = 'csqa'
69 | SPLIT: str = 'test'
70 |
71 |
72 | class SciQTestDataset(MCQDataset):
73 | NAME: str = 'SciQ/test'
74 | DTYPE: str = 'sciq'
75 | SPLIT: str = 'test'
76 |
77 |
78 | class NLITestDataset(MCQDataset):
79 | NAME: str = 'NLI/test'
80 | DTYPE: str = 'nli'
81 | SPLIT: str = 'test'
82 |
--------------------------------------------------------------------------------
/mcts_rl/datasets/raw/mcq_for_eval.py:
--------------------------------------------------------------------------------
1 | """MCQ evaluation datasets."""
2 | from __future__ import annotations
3 |
4 | import os
5 | from typing import ClassVar
6 |
7 | from mcts_rl.datasets.base import RawDataset, RawSample, jsonlines_load
8 | from mcts_rl.configs.constants import HINTED_EVAL_PROMPT
9 |
10 |
11 | __all__ = [
12 | 'MCQEvalDataset',
13 | ]
14 |
15 | DATA_DIR = "path_to_dataset_folder"
16 |
17 |
18 | class MCQEvalDataset(RawDataset):
19 | SPLIT: ClassVar[str]
20 | DTYPE: ClassVar[str]
21 |
22 | def __init__(self, path: str | None = None) -> None:
23 | data = jsonlines_load(os.path.join(DATA_DIR, 'arithmo/stepwise_generations.jsonl'))
24 | self.data = []
25 | for dt in data:
26 | prompt = dt['question']
27 | solution = dt['solution']
28 | eval_prompt = HINTED_EVAL_PROMPT.format(
29 | input=f'{prompt}\n\n',
30 | solution=dt['answer'],
31 | prompt=solution,
32 | ).replace('\n\nANSWER: The answer is', '').strip()
33 | self.data.append({
34 | 'question': eval_prompt,
35 | 'answer': '',
36 | })
37 |
38 |
39 | def __getitem__(self, index: int) -> RawSample:
40 | data = self.data[index]
41 | return RawSample(
42 | input=data['question'],
43 | final_answer=data['answer'],
44 | )
45 |
46 | def __len__(self) -> int:
47 | return len(self.data)
48 |
49 |
50 | class SQAEvalTrainDataset(MCQEvalDataset):
51 | NAME: str = 'SQAEval/train'
52 | DTYPE: str = 'sqa_all'
53 | SPLIT: str = 'train'
54 |
55 |
56 | class SQAEvalTestDataset(MCQEvalDataset):
57 | NAME: str = 'SQAEval/test'
58 | DTYPE: str = 'sqa'
59 | SPLIT: str = 'train'
60 |
61 |
62 | class CSREvalTrainDataset(MCQEvalDataset):
63 | NAME: str = 'CSREval/train'
64 | DTYPE: str = 'csr_all'
65 | SPLIT: str = 'train'
66 |
67 |
68 | class CSREvalTestDataset(MCQEvalDataset):
69 | NAME: str = 'CSREval/test'
70 | DTYPE: str = 'csr'
71 | SPLIT: str = 'test'
72 |
73 |
74 | class GSMEvalTrainDataset(MCQEvalDataset):
75 | NAME: str = 'GSMEval/train'
76 | DTYPE: str = 'gsm_all'
77 | SPLIT: str = 'train'
78 |
79 |
80 | class GSMEvalTestDataset(MCQEvalDataset):
81 | NAME: str = 'GSMEval/test'
82 | DTYPE: str = 'gsm'
83 | SPLIT: str = 'train'
84 |
85 |
86 | class ArithmoEvalTestDataset(MCQEvalDataset):
87 | NAME: str = 'arithmo/test'
88 | DTYPE: str = 'arithmo'
89 | SPLIT: str = 'train'
--------------------------------------------------------------------------------
/mcts_rl/datasets/raw/mcq_pairs.py:
--------------------------------------------------------------------------------
1 | """MCQ preference datasets."""
2 | from __future__ import annotations
3 |
4 | import os
5 | from typing import ClassVar
6 |
7 | from mcts_rl.datasets.base import RawDataset, RawSample, jsonlines_load
8 |
9 |
10 | __all__ = [
11 | 'MCQPreferenceDataset',
12 | ]
13 |
14 | DATA_DIR = "path_to_dataset_folder"
15 |
16 |
17 | class MCQPreferenceDataset(RawDataset):
18 | SPLIT: ClassVar[str]
19 | DTYPE: ClassVar[str]
20 |
21 | def __init__(self, path: str | None = None) -> None:
22 | self.data = jsonlines_load(os.path.join(DATA_DIR, f'{self.DTYPE}_pairs_{self.SPLIT}.jsonl'))
23 |
24 | def __getitem__(self, index: int) -> RawSample:
25 | data = self.data[index]
26 | return RawSample(
27 | input=data['prompt'].replace('QUESTION: ', ''),
28 | answer=f"\n{data['response_0']}",
29 | other_answer=f"\n{data['response_1']}",
30 | better=True,
31 | is_safe=bool(data['is_response_0_correct']),
32 | is_other_safe=bool(data['is_response_1_correct']),
33 | )
34 |
35 | def __len__(self) -> int:
36 | return len(self.data)
37 |
38 |
39 | class SQAPreferenceTrainDataset(MCQPreferenceDataset):
40 | NAME: str = 'SQAPreference/train'
41 | DTYPE: str = 'sqa_all'
42 | SPLIT: str = 'train'
43 |
44 |
45 | class SQAPreferenceTestDataset(MCQPreferenceDataset):
46 | NAME: str = 'SQAPreference/test'
47 | DTYPE: str = 'sqa'
48 | SPLIT: str = 'train'
49 |
50 |
51 | class CSRPreferenceTrainDataset(MCQPreferenceDataset):
52 | NAME: str = 'CSRPreference/train'
53 | DTYPE: str = 'csr_all'
54 | SPLIT: str = 'train'
55 |
56 |
57 | class CSRPreferenceTestDataset(MCQPreferenceDataset):
58 | NAME: str = 'CSRPreference/test'
59 | DTYPE: str = 'csr'
60 | SPLIT: str = 'test'
61 |
62 |
63 | class GSMPreferenceTrainDataset(MCQPreferenceDataset):
64 | NAME: str = 'GSMPreference/train'
65 | DTYPE: str = 'gsm'
66 | SPLIT: str = 'train'
67 |
68 |
69 | class GSMPreferenceTestDataset(MCQPreferenceDataset):
70 | NAME: str = 'GSMPreference/test'
71 | DTYPE: str = 'gsm'
72 | SPLIT: str = 'test'
--------------------------------------------------------------------------------
/mcts_rl/datasets/raw/moss.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 PKU-Alignment Team. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """MOSS datasets for supervised instruction fine-tuning."""
16 |
17 | from __future__ import annotations
18 |
19 | import json
20 | import pathlib
21 | import re
22 | import zipfile
23 |
24 | from mcts_rl.datasets.base import RawDataset, RawSample
25 |
26 |
27 | __all__ = ['MOSS002SFT', 'MOSS003SFT']
28 |
29 |
30 | class MOSS002SFT(RawDataset):
31 | NAME: str = 'moss-002-sft'
32 | PATTERN = re.compile(
33 | r"""
34 | ^
35 |         \[(?P<from>(?P<human>Human)|(?P<assistant>MOSS))\]:
36 |         \s*
37 |         (?P<value>.*?)
38 |         \s*
39 |         (?(human)<eoh>|<eoa>)
40 | \s*
41 | """,
42 | flags=re.DOTALL | re.VERBOSE,
43 | )
44 |
45 | def __init__(self, path: str | None = None) -> None:
46 | if path is None: # fnlp/moss-002-sft-data cannot load with `load_dataset`
47 | raise ValueError('moss-002-sft dataset requires a local path to the dataset.')
48 |
49 | path = pathlib.Path(path).expanduser().absolute()
50 | if not path.exists():
51 | raise ValueError('moss-002-sft dataset path does not exist.')
52 | if not path.is_dir():
53 | raise ValueError('moss-002-sft dataset path is not a directory.')
54 |
55 | data_files = sorted(path.glob('*.json'))
56 | self.data = []
57 | for file in data_files:
58 | with file.open(mode='rt', encoding='utf-8') as f:
59 | self.data.extend(json.load(f))
60 |
61 | def __getitem__(self, index: int) -> RawSample:
62 | data = self.data[index]
63 | plain_text = data['plain_text'].strip()
64 | if not plain_text.startswith(('[Human]:', '[MOSS]:')):
65 | raise ValueError(f'Invalid plain text: {plain_text}')
66 |
67 | dialogue = []
68 | text = plain_text
69 | while len(text) > 0:
70 | match = self.PATTERN.match(text)
71 | if match is None:
72 | raise ValueError(f'Invalid plain text: {plain_text}')
73 | if (match.group('human') is not None and len(dialogue) % 2 != 0) or (
74 | match.group('assistant') is not None and len(dialogue) % 2 != 1
75 | ):
76 | raise ValueError(f'Invalid plain text: {plain_text}')
77 | dialogue.append(match.group('value'))
78 | text = text[match.end() :]
79 |
80 | return RawSample(dialogue=dialogue)
81 |
82 | def __len__(self) -> int:
83 | return len(self.data)
84 |
85 |
86 | class MOSS003SFT(RawDataset):
87 | NAME: str = 'moss-003-sft'
88 |
89 | def __init__(self, path: str | None = None) -> None:
90 | if path is None: # fnlp/moss-003-sft-data cannot load with `load_dataset`
91 | raise ValueError('moss-003-sft dataset requires a local path to the dataset.')
92 |
93 | path = pathlib.Path(path).expanduser().absolute()
94 | if not path.exists():
95 | raise ValueError('moss-003-sft dataset path does not exist.')
96 | if not path.is_dir():
97 | raise ValueError('moss-003-sft dataset path is not a directory.')
98 |
99 | data_file = path / 'moss-003-sft-no-tools.jsonl'
100 | archive_file = path / 'moss-003-sft-no-tools.jsonl.zip'
101 |
102 | if not data_file.exists():
103 | if not archive_file.exists():
104 | raise ValueError('moss-003-sft dataset requires a local path to the dataset.')
105 | with zipfile.ZipFile(archive_file, mode='r') as archive:
106 | archive.extractall(path)
107 |
108 | self.data = []
109 | with data_file.open(mode='rt', encoding='utf-8') as f:
110 | for line in f:
111 | self.data.append(json.loads(line))
112 |
113 | def __getitem__(self, index: int) -> RawSample:
114 | data = self.data[index]
115 | num_turns = data['num_turns']
116 | chat = data['chat']
117 | dialogue = []
118 | for i in range(1, num_turns + 1):
119 | turn = chat[f'turn_{i}']
120 |             dialogue.append(turn['Human'].replace('<|Human|>:', '').replace('<eoh>', '').strip())
121 |             dialogue.append(turn['MOSS'].replace('<|MOSS|>', '').replace('<eom>', '').strip())
122 | return RawSample(dialogue=dialogue)
123 |
124 | def __len__(self) -> int:
125 | return len(self.data)
126 |
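A replay of the `MOSS002SFT` turn regex and parsing loop on an invented two-turn transcript, as a quick sanity check of the `<eoh>`/`<eoa>` terminators:

```python
import re

PATTERN = re.compile(
    r"""
    ^
    \[(?P<from>(?P<human>Human)|(?P<assistant>MOSS))\]:
    \s*
    (?P<value>.*?)
    \s*
    (?(human)<eoh>|<eoa>)
    \s*
    """,
    flags=re.DOTALL | re.VERBOSE,
)

text = '[Human]: Hi there <eoh> [MOSS]: Hello! How can I help? <eoa>'
dialogue = []
while text:
    match = PATTERN.match(text)
    dialogue.append(match.group('value'))  # alternating human/assistant turns
    text = text[match.end():]
print(dialogue)  # ['Hi there', 'Hello! How can I help?']
```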
--------------------------------------------------------------------------------
/mcts_rl/datasets/raw/prm800k.py:
--------------------------------------------------------------------------------
1 | """PRM800K preference datasets."""
2 |
3 | from __future__ import annotations
4 |
5 | import os
6 | from typing import ClassVar
7 |
8 | from mcts_rl.datasets.base import RawDataset, RawSample, jsonlines_load
9 |
10 |
11 | __all__ = [
12 | 'PRM800KDataset',
13 | 'PRM800KTrainDataset',
14 | 'PRM800KTestDataset',
15 | ]
16 |
17 | DATA_DIR = "path_to_dataset_folder"
18 |
19 |
20 | class PRM800KDataset(RawDataset):
21 | SPLIT: ClassVar[str]
22 | PATH: ClassVar[str]
23 |
24 | def __init__(self, path: str | None = None) -> None:
25 | self.data = jsonlines_load(os.path.join(DATA_DIR, f'prm800k/preference_prm_{self.SPLIT}.jsonl'))
26 |
27 | def __getitem__(self, index: int) -> RawSample:
28 | data = self.data[index]
29 | prompt = data['prompt']
30 | # prompt = f'{prompt}\n\nANSWER: {data["solution"]["solution"]}\nThe answer is {data["solution"]["answer"]}\n\nQUESTION: {prompt}'
31 | # from mcts_rl.utils import extract_answer, math_equal
32 | # if not math_equal(extract_answer(data.get('generation-base', 'None')), data.get('answer', None)):
33 | # prompt = '{}\n\nANSWER: {}\n\nREVISION REQUEST: Please revise the above answer to get the correct answer {}.'.format(
34 | # prompt, data.get('generation-base', 'None'), data.get('answer', 'None'),
35 | # )
37 | return RawSample(
38 | input=prompt,
39 | # answer=data['response_0'],
40 | answer=data['solution']['solution'],
41 | other_answer=data['response_1'],
42 | better=int(data['better_response_id']) == 0,
43 | safer=int(data['better_response_id']) == 0,
44 | is_safe=bool(data['is_response_0_correct_answer']),
45 | is_other_safe=bool(data['is_response_1_correct_answer']),
46 | final_answer=data.get('answer', None),
47 | )
48 |
49 | def __len__(self) -> int:
50 | return len(self.data)
51 |
52 |
53 | class PRM800KTrainDataset(PRM800KDataset):
54 | NAME: str = 'PRM800K/train'
55 | ALIASES: tuple[str, ...] = ('OpenAI/PRM800K/train',)
56 | PATH: str = 'OpenAI/PRM800K'
57 | SPLIT: str = 'train'
58 |
59 |
60 | class PRM800KTestDataset(PRM800KDataset):
61 | NAME: str = 'PRM800K/test'
62 | ALIASES: tuple[str, ...] = ('OpenAI/PRM800K/test',)
63 | PATH: str = 'OpenAI/PRM800K'
64 | SPLIT: str = 'test'
65 |
--------------------------------------------------------------------------------
/mcts_rl/datasets/raw/qa_feedback.py:
--------------------------------------------------------------------------------
1 | """FeedbackQA datasets."""
2 | from __future__ import annotations
3 |
4 | from typing import ClassVar
5 |
6 | from datasets import load_dataset
7 | from mcts_rl.datasets.base import RawDataset, RawSample
8 | from mcts_rl.configs.constants import QA_EVAL_PROMPT, PROMPT_ASSISTANT
9 |
10 |
11 | __all__ = [
12 | 'QAFBDataset',
13 | 'QAFBTrainDataset',
14 | 'QAFBTestDataset',
15 | ]
16 |
17 |
18 | class QAFBDataset(RawDataset):
19 | SPLIT: ClassVar[str]
20 | PATH: ClassVar[str]
21 |
22 | def __init__(self, path: str | None = None) -> None:
23 | self.data = load_dataset(path or self.PATH, split=self.SPLIT)
24 |
25 | def __getitem__(self, index: int) -> RawSample:
26 | data = self.data[index]
27 | question = data['question']
28 | # question = QA_EVAL_PROMPT.format(input=question, prompt=PROMPT_ASSISTANT + f' {data["answer"]}').rstrip('ASSISTANT: ').rstrip()
29 | return RawSample(
30 | input=question,
31 | final_answer=data['answer'],
32 | final_answer_content='\n'.join([f'{r}: {e}' for r, e in zip(data['feedback']['rating'], data['feedback']['explanation'])])
33 | )
34 |
35 | def __len__(self) -> int:
36 | return len(self.data)
37 |
38 |
39 | class QAFBTrainDataset(QAFBDataset):
40 | NAME: str = 'FeedbackQA/train'
41 | PATH: str = 'McGill-NLP/feedbackQA'
42 | SPLIT: str = 'train'
43 |
44 |
45 | class QAFBTestDataset(QAFBDataset):
46 | NAME: str = 'FeedbackQA/test'
47 | PATH: str = 'McGill-NLP/feedbackQA'
48 | SPLIT: str = 'test'
49 |
--------------------------------------------------------------------------------
/mcts_rl/datasets/raw/safe_rlhf.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 PKU-Alignment Team. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Safe-RLHF preference datasets."""
16 |
17 | from __future__ import annotations
18 |
19 | from typing import ClassVar
20 |
21 | from datasets import load_dataset
22 | from mcts_rl.datasets.base import RawDataset, RawSample
23 |
24 |
25 | __all__ = [
26 | 'SafeRLHFDataset',
27 | 'SafeRLHFTrainDataset',
28 | 'SafeRLHFTestDataset',
29 | 'SafeRLHF10KTrainDataset',
30 | ]
31 |
32 |
33 | class SafeRLHFDataset(RawDataset):
34 | SPLIT: ClassVar[str]
35 | PATH: ClassVar[str]
36 |
37 | def __init__(self, path: str | None = None) -> None:
38 | self.data = load_dataset(path or self.PATH, split=self.SPLIT)
39 |
40 | def __getitem__(self, index: int) -> RawSample:
41 | data = self.data[index]
42 | return RawSample(
43 | input=data['prompt'],
44 | answer=data['response_0'],
45 | other_answer=data['response_1'],
46 | better=int(data['better_response_id']) == 0,
47 | safer=int(data['safer_response_id']) == 0,
48 | is_safe=bool(data['is_response_0_safe']),
49 | is_other_safe=bool(data['is_response_1_safe']),
50 | )
51 |
52 | def __len__(self) -> int:
53 | return len(self.data)
54 |
55 |
56 | class SafeRLHFTrainDataset(SafeRLHFDataset):
57 | NAME: str = 'PKU-SafeRLHF/train'
58 | ALIASES: tuple[str, ...] = ('PKU-Alignment/PKU-SafeRLHF/train',)
59 | PATH: str = 'PKU-Alignment/PKU-SafeRLHF'
60 | SPLIT: str = 'train'
61 |
62 |
63 | class SafeRLHFTestDataset(SafeRLHFDataset):
64 | NAME: str = 'PKU-SafeRLHF/test'
65 | ALIASES: tuple[str, ...] = ('PKU-Alignment/PKU-SafeRLHF/test',)
66 | PATH: str = 'PKU-Alignment/PKU-SafeRLHF'
67 | SPLIT: str = 'test'
68 |
69 |
70 | class SafeRLHF10KTrainDataset(SafeRLHFDataset):
71 | NAME: str = 'PKU-SafeRLHF-10K/train'
72 | ALIASES: tuple[str, ...] = ('PKU-Alignment/PKU-SafeRLHF-10K/train',)
73 | PATH: str = 'PKU-Alignment/PKU-SafeRLHF-10K'
74 | SPLIT: str = 'train'
75 |
--------------------------------------------------------------------------------
/mcts_rl/datasets/safety_preference.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 PKU-Alignment Team. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | from __future__ import annotations
17 |
18 | from typing import Callable
19 | from typing_extensions import TypedDict # Python 3.10+
20 |
21 | import torch
22 |
23 | from mcts_rl.datasets.base import CollatorBase, RawSample, TokenizedDataset
24 | from mcts_rl.datasets.utils import format_prompt, right_padding
25 |
26 |
27 | __all__ = [
28 | 'SafetyPreferenceDataset',
29 | 'SafetyPreferenceCollator',
30 | 'SafetyPreferenceSample',
31 | 'SafetyPreferenceBatch',
32 | ]
33 |
34 |
35 | class SafetyPreferenceSample(TypedDict, total=True):
36 | safer_input_ids: torch.LongTensor # size = (L,)
37 | # +1 for safe / -1 for unsafe
38 |     safer_sign: torch.LongTensor  # size = ()
39 |
40 | unsafer_input_ids: torch.LongTensor # size = (L,)
41 | # +1 for safe / -1 for unsafe
42 |     unsafer_sign: torch.LongTensor  # size = ()
43 |
44 |
45 | class SafetyPreferenceBatch(TypedDict, total=True):
46 | safer_input_ids: torch.LongTensor # size = (B, L)
47 | safer_attention_mask: torch.BoolTensor # size = (B, L)
48 | # +1 for safe / -1 for unsafe
49 | safer_safety_sign: torch.LongTensor # size = (B,)
50 |
51 | unsafer_input_ids: torch.LongTensor # size = (B, L)
52 | unsafer_attention_mask: torch.BoolTensor # size = (B, L)
53 | # +1 for safe / -1 for unsafe
54 | unsafer_safety_sign: torch.LongTensor # size = (B,)
55 |
56 |
57 | class SafetyPreferenceDataset(TokenizedDataset):
58 | def preprocess(self, raw_sample: RawSample) -> SafetyPreferenceSample:
59 | prompt = format_prompt(input=raw_sample['input'], eos_token=self.tokenizer.eos_token)
60 | answer = raw_sample['answer']
61 | other_answer = raw_sample['other_answer']
62 | safer = raw_sample['safer']
63 | is_safe = raw_sample['is_safe']
64 | is_other_safe = raw_sample['is_other_safe']
65 |
66 | safer_answer, unsafer_answer = answer, other_answer
67 | safer_sign, unsafer_sign = ( # +1 for safe / -1 for unsafe
68 | 2 * int(is_safe) - 1,
69 | 2 * int(is_other_safe) - 1,
70 | )
71 | if not safer:
72 | safer_answer, unsafer_answer = unsafer_answer, safer_answer
73 | safer_sign, unsafer_sign = unsafer_sign, safer_sign
74 |
75 | if safer_sign < unsafer_sign:
76 | raise ValueError(
77 | 'The safer answer is not safer than the unsafer answer.\n\n'
78 | f'Prompt: {prompt}\n\n'
79 | f'Safer answer (labeled as unsafe): {safer_answer}\n\n'
80 | f'Unsafer answer (labeled as safe): {unsafer_answer}',
81 | )
82 |
83 | # size = (L,)
84 | safer_input_ids = self.tokenize(prompt + safer_answer + self.tokenizer.eos_token)
85 | unsafer_input_ids = self.tokenize(prompt + unsafer_answer + self.tokenizer.eos_token)
86 | if (
87 | safer_input_ids.size() == unsafer_input_ids.size()
88 | and torch.all(torch.eq(safer_input_ids, unsafer_input_ids)).item()
89 | ):
90 | raise ValueError(
91 | 'Two responses get the same `input_ids` after tokenization.\n\n'
92 | f'Prompt: {prompt}\n\n'
93 | f'Safer answer: {safer_answer}\n\n'
94 | f'Unsafer answer: {unsafer_answer}',
95 | )
96 | return {
97 | 'safer_input_ids': safer_input_ids, # size = (L,)
98 | 'safer_sign': torch.tensor(safer_sign), # size = ()
99 | 'unsafer_input_ids': unsafer_input_ids, # size = (L,)
100 | 'unsafer_sign': torch.tensor(unsafer_sign), # size = ()
101 | }
102 |
103 | def get_collator(self) -> Callable[[list[dict[str, torch.Tensor]]], dict[str, torch.Tensor]]:
104 | return SafetyPreferenceCollator(self.tokenizer.pad_token_id)
105 |
106 |
107 | class SafetyPreferenceCollator(CollatorBase):
108 | def __call__(self, samples: list[SafetyPreferenceSample]) -> SafetyPreferenceBatch:
109 | input_ids = [sample['safer_input_ids'] for sample in samples] + [
110 | sample['unsafer_input_ids'] for sample in samples
111 | ]
112 | attention_mask = [
113 | input_id.new_ones(input_id.size(), dtype=torch.bool) for input_id in input_ids
114 | ]
115 | safety_sign = [sample['safer_sign'] for sample in samples] + [
116 | sample['unsafer_sign'] for sample in samples
117 | ]
118 |
119 | # size = (2 * B, L)
120 | input_ids = right_padding(input_ids, padding_value=self.pad_token_id)
121 | attention_mask = right_padding(attention_mask, padding_value=0)
122 | # size = (2 * B,)
123 | safety_sign = torch.tensor(safety_sign, dtype=torch.long)
124 |
125 | # size = (B, L)
126 | safer_input_ids, unsafer_input_ids = input_ids.chunk(chunks=2, dim=0)
127 | safer_attention_mask, unsafer_attention_mask = attention_mask.chunk(chunks=2, dim=0)
128 | # size = (B,)
129 | safer_safety_sign, unsafer_safety_sign = safety_sign.chunk(chunks=2, dim=0)
130 | return {
131 | 'safer_input_ids': safer_input_ids, # size = (B, L)
132 | 'safer_attention_mask': safer_attention_mask, # size = (B, L)
133 | 'safer_safety_sign': safer_safety_sign, # size = (B,)
134 | 'unsafer_input_ids': unsafer_input_ids, # size = (B, L)
135 | 'unsafer_attention_mask': unsafer_attention_mask, # size = (B, L)
136 | 'unsafer_safety_sign': unsafer_safety_sign, # size = (B,)
137 | }
138 |
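The sign convention in `preprocess` above on two invented label combinations: booleans map to ±1 via `2 * int(flag) - 1`, and the pair is swapped so the first element always corresponds to the safer response:

```python
for is_safe, is_other_safe, safer in [(True, False, True), (True, False, False)]:
    safer_sign, unsafer_sign = 2 * int(is_safe) - 1, 2 * int(is_other_safe) - 1
    if not safer:
        safer_sign, unsafer_sign = unsafer_sign, safer_sign
    print(safer_sign, unsafer_sign)
# (True, False, True)  -> 1 -1   consistent labels, signs stay in place
# (True, False, False) -> -1 1   inconsistent labels: after the swap the "safer"
#                                response carries -1, which preprocess() rejects
```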
--------------------------------------------------------------------------------
/mcts_rl/datasets/supervised.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 PKU-Alignment Team. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | from __future__ import annotations
17 |
18 | from typing import Callable
19 | from typing_extensions import TypedDict # Python 3.10+
20 |
21 | import torch
22 |
23 | from mcts_rl.configs import IGNORE_INDEX, PROMPT_ASSISTANT, PROMPT_BEGIN, PROMPT_USER
24 | from mcts_rl.datasets.base import CollatorBase, RawSample, TokenizedDataset
25 | from mcts_rl.datasets.utils import format_prompt, right_padding
26 |
27 |
28 | __all__ = [
29 | 'SupervisedDataset',
30 | 'SupervisedCollator',
31 | 'SupervisedSample',
32 | 'SupervisedBatch',
33 | ]
34 |
35 |
36 | class SupervisedSample(TypedDict, total=True):
37 | input_ids: torch.LongTensor # size = (L,)
38 | labels: torch.LongTensor # size = (L,)
39 |
40 |
41 | class SupervisedBatch(TypedDict, total=True):
42 | input_ids: torch.LongTensor # size = (B, L)
43 | labels: torch.LongTensor # size = (B, L)
44 | attention_mask: torch.BoolTensor # size = (B, L)
45 |
46 |
47 | class SupervisedDataset(TokenizedDataset):
48 | def preprocess(self, raw_sample: RawSample) -> SupervisedSample:
49 | if raw_sample.get('input') is None and raw_sample.get('dialogue') is None:
50 | raise ValueError('Either `input` or `dialogue` must be provided.')
51 | if raw_sample.get('input') is not None and raw_sample.get('dialogue') is not None:
52 | raise ValueError('At most one of `input` and `dialogue` can be provided.')
53 |
54 | if raw_sample.get('input') is not None:
55 | input = raw_sample['input'] # pylint: disable=redefined-builtin
56 | if not isinstance(input, str):
57 | raise ValueError(f'Unsupported type of `input`: {type(input)}. Expected: str.')
58 | prompt = format_prompt(input=input, eos_token=self.tokenizer.eos_token)
59 | answer = raw_sample['answer']
60 | if len(raw_sample) == 2 and PROMPT_ASSISTANT.endswith('think step by step.'):
61 | prompt = prompt.split(PROMPT_ASSISTANT)[0] + 'ANSWER: '
62 | if raw_sample.get('final_answer', None) is not None:
63 | text = f'{prompt}{answer}\nThe answer is: {raw_sample["final_answer"]}' + self.tokenizer.eos_token
64 | else:
65 | text = prompt + answer + self.tokenizer.eos_token
66 | input_ids = self.tokenize(text)
67 | labels = input_ids.clone()
68 | # Mask non-assistant input
69 | labels[: len(self.tokenize(prompt))] = IGNORE_INDEX
70 | return {'input_ids': input_ids, 'labels': labels}
71 |
72 | dialogue = raw_sample['dialogue'] # is not None
73 | text = PROMPT_BEGIN
74 | offsets = [0]
75 | input_ids = torch.empty(0, dtype=torch.long)
76 | for i, line in enumerate(dialogue):
77 | if i % 2 == 0:
78 | # User input
79 | text += PROMPT_USER.format(input=line) + PROMPT_ASSISTANT
80 | else:
81 | # Assistant input
82 | text += line + self.tokenizer.eos_token
83 | input_ids = self.tokenize(text)
84 | offsets.append(len(input_ids))
85 |
86 | labels = input_ids.clone()
87 | # Mask non-assistant input
88 | for begin, end in zip(offsets[::2], offsets[1::2]):
89 | labels[begin:end] = IGNORE_INDEX
90 |
91 | return {
92 | 'input_ids': input_ids, # size = (L,)
93 | 'labels': labels, # size = (L,)
94 | }
95 |
96 | def get_collator(self) -> Callable[[list[dict[str, torch.Tensor]]], dict[str, torch.Tensor]]:
97 | return SupervisedCollator(self.tokenizer.pad_token_id)
98 |
99 |
100 | class SupervisedCollator(CollatorBase):
101 | def __call__(self, samples: list[SupervisedSample]) -> SupervisedBatch:
102 | input_ids = right_padding(
103 | [sample['input_ids'] for sample in samples],
104 | padding_value=self.pad_token_id,
105 | )
106 | labels = right_padding(
107 | [sample['labels'] for sample in samples],
108 | padding_value=IGNORE_INDEX,
109 | )
110 | attention_mask = input_ids.ne(self.pad_token_id)
111 | return {
112 | 'input_ids': input_ids, # size = (B, L)
113 | 'labels': labels, # size = (B, L)
114 | 'attention_mask': attention_mask, # size = (B, L)
115 | }
116 |
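How the prompt prefix is masked out of the loss in `preprocess` above, on stand-in token ids (`IGNORE_INDEX` is imported from `mcts_rl.configs`; `-100` is the value `torch.nn.CrossEntropyLoss` ignores by default):

```python
import torch

IGNORE_INDEX = -100
prompt_ids = torch.tensor([101, 7592, 102])         # stand-in prompt tokens
answer_ids = torch.tensor([2023, 2003, 1037, 102])  # stand-in answer tokens

input_ids = torch.cat([prompt_ids, answer_ids])
labels = input_ids.clone()
labels[: len(prompt_ids)] = IGNORE_INDEX  # mask non-assistant input

print(labels.tolist())  # [-100, -100, -100, 2023, 2003, 1037, 102]
```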
--------------------------------------------------------------------------------
/mcts_rl/datasets/utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 PKU-Alignment Team. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | from __future__ import annotations
17 |
18 | import random
19 | random.seed(0)
20 |
21 | import torch
22 | from torch.nn.utils.rnn import pad_sequence
23 | from torch.types import Number
24 |
25 | from mcts_rl.configs import (
26 | PROMPT_BEGIN, PROMPT_USER,
27 | PROMPT_ASSISTANT, PROMPT_ASSISTANT_MCQ,
28 | SQA_PROMPT,
29 | LLAMA3_PROMPT_USER,
30 | LLAMA3_PROMPT_ASSISTANT,
31 | LLAMA3_PROMPT_ASSISTANT_MCQ,
32 | )
33 |
34 |
35 | def format_prompt(
36 | input: str | list[str], # pylint: disable=redefined-builtin
37 | eos_token: str,
38 | use_mcq: bool = False,
39 | few_shot: bool = False,
40 | model_type: str = 'mistral',
41 | ) -> str:
42 | if isinstance(input, str):
43 | input = [input]
44 | elif not isinstance(input, list):
45 | raise ValueError(f'Unsupported type of `input`: {type(input)}. Expected: str or list[str].')
46 |
47 | if len(input) % 2 != 1:
48 | raise ValueError(
49 |             'The length of `input` must be odd, and `input` must end with the user question.',
50 | )
51 |
52 | if 'USER:' in PROMPT_USER:
53 | buffer = [PROMPT_BEGIN]
54 | elif few_shot:
55 | # buffer = [GSM8K_PROMPT]
56 | buffer = [SQA_PROMPT]
57 | else:
58 | # exp = random.choice(GSM8K_EXP)
59 | # buffer = [PROMPT_USER.format(input=exp['Q']) + PROMPT_ASSISTANT + ' ' + exp['A'] + DEFAULT_EOS_TOKEN + '\n\n']
60 | # buffer = ['At the end of your answer output #### {final answer}.\n\n']
61 | buffer = []
62 |
63 | for i, line in enumerate(input):
64 | if i % 2 == 0:
65 | # User input
66 | if model_type == 'llama3':
67 | buffer.extend((LLAMA3_PROMPT_USER.format(input=line),
68 | LLAMA3_PROMPT_ASSISTANT_MCQ if use_mcq else LLAMA3_PROMPT_ASSISTANT))
69 | else:
70 | buffer.extend((PROMPT_USER.format(input=line),
71 | PROMPT_ASSISTANT_MCQ if use_mcq else PROMPT_ASSISTANT))
72 | else:
73 | # Assistant response
74 | buffer.extend((line, eos_token))
75 | return ''.join(buffer)
76 |
77 |
78 | def right_padding(sequences: list[torch.Tensor], padding_value: Number) -> torch.Tensor:
79 | return pad_sequence(sequences, batch_first=True, padding_value=padding_value)
80 |
81 |
82 | def left_padding(sequences: list[torch.Tensor], padding_value: Number) -> torch.Tensor:
83 | return right_padding(
84 | [seq.flip(0) for seq in sequences],
85 | padding_value=padding_value,
86 | ).flip(1)
87 |
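Why the double flip in `left_padding` works: reversing each sequence turns left padding into right padding, and flipping the padded batch back restores the original order with the padding on the left:

```python
import torch
from torch.nn.utils.rnn import pad_sequence

seqs = [torch.tensor([1, 2, 3]), torch.tensor([4, 5])]

right = pad_sequence(seqs, batch_first=True, padding_value=0)
left = pad_sequence([s.flip(0) for s in seqs], batch_first=True, padding_value=0).flip(1)

print(right.tolist())  # [[1, 2, 3], [4, 5, 0]]
print(left.tolist())   # [[1, 2, 3], [0, 4, 5]]
```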
--------------------------------------------------------------------------------
/mcts_rl/finetune/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 PKU-Alignment Team. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Supervised Fine-Tuning (SFT)."""
16 |
17 | from mcts_rl.finetune.trainer import SupervisedFinetuneTrainer
18 |
19 |
20 | __all__ = ['SupervisedFinetuneTrainer']
21 |
--------------------------------------------------------------------------------
/mcts_rl/finetune/__main__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 PKU-Alignment Team. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """The main training script to supervised finetune a model."""
16 |
17 | import sys
18 |
19 | from mcts_rl.finetune.main import main
20 |
21 |
22 | if __name__ == '__main__':
23 | sys.exit(main())
24 |
--------------------------------------------------------------------------------
/mcts_rl/finetune/huggingface.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 PKU-Alignment Team. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """The main training script to supervised finetune a model using Hugging Face Transformers Trainer."""
16 |
17 | import argparse
18 | from dataclasses import dataclass, field
19 | from typing import List, Optional, Tuple, Union
20 |
21 | import transformers
22 | from transformers.training_args import OptimizerNames
23 |
24 | from mcts_rl.datasets import SupervisedDataset, parse_dataset
25 | from mcts_rl.models import load_pretrained_models
26 |
27 |
28 | @dataclass
29 | class ModelArguments:
30 | """Arguments for models."""
31 |
32 | model_name_or_path: str
33 |
34 |
35 | @dataclass
36 | class DataArguments:
37 | """Arguments for datasets."""
38 |
39 | datasets: List[parse_dataset] = field(
40 | default=None,
41 | metadata={'help': 'Path to the training data.'},
42 | )
43 |
44 |
45 | @dataclass
46 | class TrainingArguments(transformers.TrainingArguments):
47 | """Arguments for the training loop."""
48 |
49 | cache_dir: Optional[str] = field(default=None)
50 | optim: Union[OptimizerNames, str] = field(
51 | default=OptimizerNames.ADAMW_TORCH,
52 | metadata={'help': 'The optimizer to use.'},
53 | )
54 | model_max_length: int = field(
55 | default=512,
56 | metadata={
57 | 'help': 'Maximum sequence length. Sequences will be right padded (and possibly truncated).',
58 | },
59 | )
60 |
61 |
62 | def parse_arguments() -> Tuple[TrainingArguments, ModelArguments, DataArguments]:
63 | """Parse the command-line arguments."""
64 | parser = transformers.HfArgumentParser([TrainingArguments, ModelArguments, DataArguments])
65 | # pylint: disable-next=unbalanced-tuple-unpacking
66 | training_args, model_args, data_args = parser.parse_args_into_dataclasses()
67 | return training_args, model_args, data_args
68 |
69 |
70 | def main() -> None:
71 | """Main training routine."""
72 | # pylint: disable=no-member
73 | training_args, model_args, data_args = parse_arguments()
74 |
75 | model, tokenizer = load_pretrained_models(
76 | model_args.model_name_or_path,
77 | model_max_length=training_args.model_max_length,
78 | padding_side='right',
79 | cache_dir=training_args.cache_dir,
80 | trust_remote_code=True,
81 | )
82 |
83 | train_dataset = SupervisedDataset(
84 | data_args.datasets,
85 | tokenizer=tokenizer,
86 | seed=training_args.seed,
87 | )
88 | data_collator = train_dataset.get_collator()
89 |
90 | trainer = transformers.Trainer(
91 | model=model,
92 | tokenizer=tokenizer,
93 | args=training_args,
94 | train_dataset=train_dataset,
95 | data_collator=data_collator,
96 | )
97 | trainer.train()
98 | trainer.save_state()
99 | trainer.save_model()
100 |
101 |
102 | if __name__ == '__main__':
103 | main()
104 |
--------------------------------------------------------------------------------
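
A minimal sketch of how `HfArgumentParser` dispatches command-line flags to the three dataclasses defined in `mcts_rl/finetune/huggingface.py`. The checkpoint name and dataset identifier below are illustrative placeholders, not values prescribed by this repository:

```python
# Sketch: HfArgumentParser routes each flag to the dataclass that declares the
# matching field. 'gpt2' and 'GSM8K' are illustrative placeholders.
import sys

import transformers

from mcts_rl.finetune.huggingface import DataArguments, ModelArguments, TrainingArguments

sys.argv = [
    'huggingface.py',
    '--output_dir', 'outputs/sft',        # TrainingArguments (required by HF)
    '--model_max_length', '512',          # TrainingArguments (declared above)
    '--model_name_or_path', 'gpt2',       # ModelArguments
    '--datasets', 'GSM8K',                # DataArguments, converted via parse_dataset
]
parser = transformers.HfArgumentParser([TrainingArguments, ModelArguments, DataArguments])
training_args, model_args, data_args = parser.parse_args_into_dataclasses()
print(model_args.model_name_or_path, data_args.datasets)
```
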
/mcts_rl/finetune/main.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 PKU-Alignment Team. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """The main training script for supervised finetuning of a model."""
16 |
17 | from mcts_rl.finetune.deepspeed import main
18 |
19 |
20 | if __name__ == '__main__':
21 | main()
22 |
--------------------------------------------------------------------------------
/mcts_rl/finetune/trainer.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 PKU-Alignment Team. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Trainer class for supervised finetuning."""
16 |
17 | from __future__ import annotations
18 |
19 | from typing import Any
20 | from tqdm import tqdm
21 |
22 | import torch
23 | import torch.distributed as dist
24 | from transformers import AutoModelForCausalLM
25 | from transformers.modeling_outputs import CausalLMOutputWithPast
26 |
27 | from mcts_rl.datasets import SupervisedDataset
28 | from mcts_rl.trainers import SupervisedTrainer
29 | from mcts_rl.utils import (
30 | get_all_reduce_mean,
31 | is_main_process,
32 | to_device,
33 | )
34 |
35 |
36 | class SupervisedFinetuneTrainer(SupervisedTrainer):
37 | """Trainer class for supervised finetuning."""
38 |
39 | TRAINING_TYPE = 'sft'
40 | DATASET_TYPE = SupervisedDataset
41 | MODEL_TYPE = AutoModelForCausalLM
42 |
43 | def loss(
44 | self,
45 | input_ids: torch.LongTensor, # size = (B, L)
46 | labels: torch.LongTensor, # size = (B, L)
47 | attention_mask: torch.BoolTensor, # size = (B, L)
48 | ) -> dict[str, torch.Tensor]:
49 | """Loss function for supervised finetuning."""
50 | outputs: CausalLMOutputWithPast = self.model(
51 | input_ids=input_ids,
52 | attention_mask=attention_mask,
53 | labels=labels,
54 | )
55 | return {
56 | 'loss': outputs.loss,
57 | }
58 |
59 | def eval(
60 | self,
61 | ) -> dict[str, Any]:
62 | if self.eval_dataloader is None:
63 | return {}
64 |
65 | self.set_eval()
66 | eval_dataloader = tqdm(
67 | self.eval_dataloader,
68 | desc='Evaluating',
69 | disable=not is_main_process(),
70 | position=1,
71 | leave=False,
72 | )
73 |
74 | losses, batch = [], None
75 | for batch in eval_dataloader:
76 | batch = to_device(batch, self.args.device)
77 | with torch.no_grad():
78 | loss = self.loss(
79 | input_ids=batch['input_ids'],
80 | labels=batch['labels'],
81 | attention_mask=batch['attention_mask'],
82 | )['loss']
83 |             losses.append(loss)
84 |
85 | if batch is None:
86 | self.logger.print('WARNING: `eval_dataloader` is empty.')
87 | return {}
88 |
89 | losses = torch.stack(losses, dim=0)
90 | if is_main_process():
91 | gathered_losses = [torch.empty_like(losses) for _ in range(dist.get_world_size())]
92 | else:
93 | gathered_losses = []
94 | dist.gather(losses, gathered_losses, dst=0)
95 | if is_main_process():
96 | losses = torch.cat(gathered_losses, dim=0)
97 |
98 | self.set_train()
99 |
100 | return {
101 | 'eval/loss': losses.mean().item(),
102 | }
103 |
104 | def train_step(
105 | self,
106 | input_ids: torch.LongTensor, # size = (B, L)
107 | labels: torch.LongTensor, # size = (B, L)
108 | attention_mask: torch.BoolTensor, # size = (B, L)
109 | ) -> dict[str, Any]:
110 | """Performs a single training step.
111 |
112 | Args:
113 |             input_ids (torch.LongTensor): input ids of the sequences to train on.
114 |             labels (torch.LongTensor): labels for the full sequence.
115 |             attention_mask (torch.BoolTensor): attention mask for the input sequences.
116 |
117 | Returns:
118 | dict[str, Any]: training loss, learning rate
119 | """
120 | loss = self.loss(
121 | input_ids=input_ids,
122 | labels=labels,
123 | attention_mask=attention_mask,
124 | )['loss']
125 | self.model.backward(loss)
126 | self.model.step()
127 |
128 | loss = get_all_reduce_mean(loss)
129 |
130 | return {
131 | 'train/loss': loss.item(),
132 | 'train/lr': self.model.optimizer.param_groups[0]['lr'],
133 | }
134 |
--------------------------------------------------------------------------------
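
The `loss` method in `SupervisedFinetuneTrainer` delegates to the model's built-in causal-LM cross-entropy, which skips every position whose label equals `-100`. A minimal sketch of that label convention, assuming the usual practice of masking prompt tokens so only the response is scored (token ids and prompt length are toy values):

```python
# Sketch of the -100 label convention that the built-in causal-LM loss relies
# on: positions labeled IGNORE_INDEX contribute nothing to the cross-entropy.
import torch

IGNORE_INDEX = -100  # the value transformers uses to skip positions in the loss

input_ids = torch.tensor([[101, 7592, 1010, 2088, 102]])  # toy prompt+response ids
prompt_len = 3                                            # assumed prompt length

labels = input_ids.clone()
labels[:, :prompt_len] = IGNORE_INDEX  # mask the prompt; score only the response
attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
print(labels)  # tensor([[-100, -100, -100, 2088, 102]])
```
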
/mcts_rl/models/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 PKU-Alignment Team. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Utility functions for Hugging Face auto-models."""
16 |
17 | from mcts_rl.models.pretrained import load_pretrained_models
18 | from mcts_rl.models.score_model import AutoModelForScore, ScoreModelOutput
19 |
20 |
21 | __all__ = ['load_pretrained_models', 'AutoModelForScore', 'ScoreModelOutput']
22 |
--------------------------------------------------------------------------------
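
A minimal usage sketch of the `load_pretrained_models` helper re-exported here. The keyword arguments mirror the call in `mcts_rl/finetune/huggingface.py`; the checkpoint name is an illustrative placeholder and the remaining parameters are assumed to keep their defaults:

```python
# Usage sketch for load_pretrained_models; 'gpt2' stands in for any causal-LM
# checkpoint name or local path.
from mcts_rl.models import load_pretrained_models

model, tokenizer = load_pretrained_models(
    'gpt2',
    model_max_length=512,
    padding_side='right',
    trust_remote_code=True,
)
print(type(model).__name__, tokenizer.model_max_length)
```
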
/mcts_rl/models/normalizer.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 PKU-Alignment Team. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Normalizer for score models."""
16 |
17 | from __future__ import annotations
18 |
19 | from abc import abstractmethod
20 | from typing import Any, Literal
21 |
22 | import torch
23 | from torch import nn
24 | from torch.types import Number
25 |
26 |
27 | NormalizeFunction = Literal['affine', 'scale', 'translate', 'identity']
28 | NormalizerType = Literal['RunningMeanStd', 'ExponentialMovingAverage']
29 |
30 |
31 | class Normalizer(nn.Module):
32 | """Normalize input to have zero mean and unit variance."""
33 |
34 | mean: torch.Tensor
35 | var: torch.Tensor
36 | count: torch.LongTensor
37 | normalize_function: NormalizeFunction
38 |
39 | def __init__(
40 | self,
41 | normalize_function: NormalizeFunction,
42 | shape: tuple[int, ...],
43 | device: torch.device | str | None = None,
44 | ) -> None:
45 | """Initialize."""
46 | super().__init__()
47 | if normalize_function not in {'affine', 'scale', 'translate', 'identity'}:
48 |             raise ValueError(
49 |                 f'Invalid normalization function type: {normalize_function}. '
50 |                 'Expected one of "affine", "scale", "translate", "identity".',
51 |             )
52 | self.normalize_function = normalize_function
53 | self.register_buffer('mean', torch.zeros(shape, device=device))
54 | self.register_buffer('var', torch.ones(shape, device=device))
55 | self.register_buffer('count', torch.zeros(1, dtype=torch.long, device=device))
56 |
57 | @abstractmethod
58 | def update(self, data: torch.Tensor) -> None:
59 | """Update mean and variance."""
60 | raise NotImplementedError
61 |
62 | @property
63 | def std(self) -> torch.Tensor:
64 | """Return standard deviation."""
65 | return self.var.sqrt()
66 |
67 | def set_mean_var(
68 | self,
69 | mean: torch.Tensor | list[float] | tuple[float, ...] | None,
70 | var: torch.Tensor | list[float] | tuple[float, ...] | None,
71 | ) -> None:
72 | """Set mean and variance."""
73 | mean = (
74 | torch.as_tensor(mean, dtype=self.mean.dtype, device=self.mean.device)
75 | if mean is not None
76 | else self.mean
77 | )
78 | var = (
79 | torch.as_tensor(var, dtype=self.var.dtype, device=self.var.device)
80 | if var is not None
81 | else self.var
82 | )
83 |
84 | assert mean.shape == self.mean.shape
85 | assert var.shape == self.var.shape
86 |
87 | self.mean = mean
88 | self.var = var
89 |
90 | def forward(
91 | self,
92 | data: torch.Tensor,
93 | epsilon: Number = 1e-8,
94 | ) -> torch.Tensor:
95 | """Update and normalize input."""
96 | if self.training:
97 | self.update(data)
98 | return self.normalize(data, epsilon=epsilon)
99 |
100 | def normalize(
101 | self,
102 | data: torch.Tensor,
103 | epsilon: Number = 1e-8,
104 | ) -> torch.Tensor:
105 | """Normalize input."""
106 | if self.normalize_function == 'affine':
107 | return (data - self.mean.detach()) / (self.std.detach() + epsilon)
108 | if self.normalize_function == 'scale':
109 | return data / (self.std.detach() + epsilon)
110 | if self.normalize_function == 'translate':
111 | return data - self.mean.detach()
112 | if self.normalize_function == 'identity':
113 | return data
114 |         raise ValueError(
115 |             f'Invalid normalization function type: {self.normalize_function}. '
116 |             'Expected one of "affine", "scale", "translate", "identity".',
117 |         )
118 |
119 | @classmethod
120 | def instantiate(
121 | cls,
122 | normalizer_type: NormalizerType | None,
123 | normalize_function: NormalizeFunction,
124 | shape: tuple[int, ...],
125 | device: torch.device | str | None = None,
126 | **kwargs: Any,
127 | ) -> Normalizer:
128 | """Get a normalizer."""
129 | if normalizer_type == 'RunningMeanStd':
130 | return RunningMeanStd(
131 | normalize_function,
132 | shape=shape,
133 | device=device,
134 | )
135 | if normalizer_type == 'ExponentialMovingAverage':
136 | return ExponentialMovingAverage(
137 | normalize_function,
138 | shape=shape,
139 | device=device,
140 | **kwargs,
141 | )
142 | if normalizer_type is None:
143 | return IdentityNormalizer(
144 | normalize_function,
145 | shape=shape,
146 | device=device,
147 | )
148 |         raise ValueError(
149 |             f'Invalid normalizer type: {normalizer_type}. '
150 |             'Expected one of "RunningMeanStd", "ExponentialMovingAverage", or None.',
151 |         )
152 |
153 |
154 | class RunningMeanStd(Normalizer):
155 | """Running mean and standard deviation."""
156 |
157 | def update(self, data: torch.Tensor) -> None:
158 | """Update mean and variance."""
159 | batch_mean = data.mean(dim=0)
160 | batch_var = data.var(dim=0)
161 | batch_count = data.size(0)
162 |
163 | delta = batch_mean - self.mean
164 | total_count = self.count + batch_count
165 |
166 | new_mean = self.mean + delta * batch_count / total_count
167 | m_a = self.var * self.count
168 | m_b = batch_var * batch_count
169 | m2 = ( # pylint: disable=invalid-name
170 | m_a + m_b + torch.square(delta) * (self.count * batch_count / total_count)
171 | )
172 | new_var = m2 / total_count
173 |
174 | self.mean = new_mean
175 | self.var = new_var
176 | self.count = total_count
177 |
178 |
179 | class ExponentialMovingAverage(Normalizer):
180 | """Exponential moving average."""
181 |
182 | def __init__(
183 | self,
184 | normalize_function: NormalizeFunction,
185 | shape: tuple[int, ...],
186 | device: torch.device | str | None = None,
187 | momentum: float = 0.9,
188 | ) -> None:
189 | super().__init__(normalize_function, shape=shape, device=device)
190 | self.momentum = momentum
191 |
192 | def update(self, data: torch.Tensor) -> None:
193 | """Update mean and variance."""
194 | batch_mean = data.mean(dim=0)
195 | batch_var = data.var(dim=0)
196 | batch_count = data.size(0)
197 |
198 | self.mean = self.momentum * self.mean + (1.0 - self.momentum) * batch_mean
199 | self.var = self.momentum * self.var + (1.0 - self.momentum) * batch_var
200 | self.count += batch_count # pylint: disable=no-member
201 |
202 |
203 | class IdentityNormalizer(Normalizer):
204 | """Identity normalizer."""
205 |
206 | def update(self, data: torch.Tensor) -> None:
207 | """Update mean and variance."""
208 | self.count += data.size(0) # pylint: disable=no-member
209 |
--------------------------------------------------------------------------------
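
A quick standalone check (not part of the repository) of the `RunningMeanStd` update above: the combined mean matches the mean of the concatenated batches exactly, while the variance, which is combined from unbiased per-batch estimates, only approximates the population variance of the concatenation:

```python
# Sanity sketch for RunningMeanStd: two batched updates vs. one-shot statistics.
import torch

from mcts_rl.models.normalizer import RunningMeanStd

torch.manual_seed(0)
normalizer = RunningMeanStd('affine', shape=(1,))
a, b = torch.randn(8, 1), torch.randn(8, 1) + 3.0
normalizer.update(a)
normalizer.update(b)

full = torch.cat([a, b], dim=0)
print(torch.allclose(normalizer.mean, full.mean(dim=0), atol=1e-5))  # True
# The update uses unbiased batch variances, so expect a close (not exact) match:
print(normalizer.var.item(), full.var(dim=0, unbiased=False).item())
```
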
/mcts_rl/models/score_model/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 PKU-Alignment Team. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Auto-models for score models."""
16 |
17 | from __future__ import annotations
18 |
19 | import functools
20 | import importlib
21 | from collections import OrderedDict
22 | from dataclasses import dataclass
23 | from typing import Any
24 |
25 | import torch
26 | import torch.nn as nn
27 | import transformers.models.auto as auto_module
28 | from torch import distributed as dist
29 | from transformers import PretrainedConfig
30 | from transformers.models.auto.auto_factory import (
31 | _BaseAutoModelClass,
32 | _LazyAutoMapping,
33 | auto_class_update,
34 | getattribute_from_module,
35 | )
36 | from transformers.models.auto.configuration_auto import (
37 | CONFIG_MAPPING_NAMES,
38 | model_type_to_module_name,
39 | )
40 | from transformers.utils.generic import ModelOutput
41 |
42 | from mcts_rl.models.normalizer import NormalizeFunction, Normalizer
43 |
44 |
45 | class _LazyAutoMappingInSafeRLHF(_LazyAutoMapping):
46 | def _load_attr_from_module(self, model_type: str, attr: str) -> Any:
47 | module_name = model_type_to_module_name(model_type)
48 | if module_name not in self._modules:
49 | self._modules[module_name] = importlib.import_module(
50 | f'.{module_name}',
51 | 'mcts_rl.models.score_model',
52 | )
53 | return getattribute_from_module(self._modules[module_name], attr)
54 |
55 |
56 | MODEL_FOR_SCORE_MAPPING_NAMES: OrderedDict[str, str] = OrderedDict(
57 | [
58 | # Score model mapping
59 | ('llama', 'LlamaModelForScore'),
60 | ('bloom', 'BloomModelForScore'),
61 | ('open_llama', 'OpenLlamaForScore'),
62 | ('opt', 'OPTForScore'),
63 | ('gpt_neo', 'GPTNeoForScore'),
64 | ('gptj', 'GPTJForScore'),
65 | ('gpt2', 'GPT2ForScore'),
66 | ('gpt_neox', 'GPTNeoXForScore'),
67 | ('mistral', 'MistralModelForScore'),
68 | ],
69 | )
70 | MODEL_FOR_SCORE_MAPPING: OrderedDict[str, Any] = _LazyAutoMappingInSafeRLHF(
71 | CONFIG_MAPPING_NAMES,
72 | MODEL_FOR_SCORE_MAPPING_NAMES,
73 | )
74 |
75 |
76 | @functools.partial(auto_class_update, head_doc='score model')
77 | class AutoModelForScore(_BaseAutoModelClass):
78 | _model_mapping: OrderedDict[str, Any] = MODEL_FOR_SCORE_MAPPING
79 |
80 |
81 | setattr(auto_module, 'MODEL_FOR_SCORE_MAPPING', MODEL_FOR_SCORE_MAPPING) # noqa: B010
82 | setattr(auto_module, AutoModelForScore.__name__, AutoModelForScore)
83 |
84 |
85 | @dataclass
86 | class ScoreModelOutput(ModelOutput):
87 | """
88 | Output of the score model.
89 |
90 | Args:
91 |         scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, score_dim)`):
92 | Prediction scores of the score model.
93 | end_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, score_dim)`):
94 | Prediction scores of the end of the sequence.
95 | """
96 |
97 | scores: torch.Tensor | None = None # size = (B, L, D)
98 | end_scores: torch.Tensor | None = None # size = (B, D)
99 |
100 |
101 | class ScoreModelMixin:
102 | """Base class for score models."""
103 |
104 | score_head: nn.Linear
105 | normalizer: Normalizer
106 | do_normalize: bool = False
107 | normalize_function: NormalizeFunction = 'affine'
108 | _initialized: bool = False
109 |
110 | def init_score_head(self, config: PretrainedConfig, hidden_size: int, **kwargs: Any) -> None:
111 | """Initialize the score head."""
112 | if self._initialized:
113 | return
114 |
115 | config.score_dim = kwargs.pop('score_dim', getattr(config, 'score_dim', 1))
116 | config.bias = kwargs.pop('bias', getattr(config, 'bias', False))
117 |
118 | config.score_type = kwargs.pop('score_type', getattr(config, 'score_type', 'reward'))
119 | if config.score_type == 'reward':
120 | self.normalize_function = 'affine'
121 | elif config.score_type == 'cost':
122 | self.normalize_function = 'scale'
123 | elif config.score_type == 'critic':
124 | self.normalize_function = 'identity'
125 | else:
126 | raise ValueError(
127 | f"Invalid score type: {config.score_type}. Expected one of 'reward', 'cost', or 'critic'.",
128 | )
129 |
130 | config.do_normalize = kwargs.pop(
131 | 'do_normalize',
132 | getattr(config, 'do_normalize', False),
133 | )
134 | self.do_normalize = config.do_normalize
135 |
136 | config.normalizer_type = kwargs.pop(
137 | 'normalizer_type',
138 | getattr(config, 'normalizer_type', None),
139 | )
140 | if config.normalizer_type not in {'RunningMeanStd', 'ExponentialMovingAverage', None}:
141 |             raise ValueError(
142 |                 f'Invalid normalizer type: {config.normalizer_type}. '
143 |                 "Expected one of 'RunningMeanStd', 'ExponentialMovingAverage', or None.",
144 |             )
145 | if config.normalizer_type == 'ExponentialMovingAverage':
146 | config.momentum = kwargs.pop('momentum', getattr(config, 'momentum', None))
147 | momentum = getattr(config, 'momentum', None)
148 |
149 | self.score_head = nn.Linear(hidden_size, config.score_dim, bias=config.bias)
150 | self.normalizer = Normalizer.instantiate(
151 | normalizer_type=config.normalizer_type,
152 | normalize_function=self.normalize_function,
153 | shape=(config.score_dim,),
154 | momentum=momentum,
155 | )
156 |
157 | mean = getattr(config, 'mean', None)
158 | var = getattr(config, 'var', None)
159 | self.normalizer.set_mean_var(mean, var)
160 |
161 | self._initialized = True
162 |
163 | def get_score(
164 | self,
165 | hidden_state: torch.Tensor, # size = (B, L, E)
166 | attention_mask: torch.BoolTensor, # size = (B, L)
167 | return_dict: bool | None = None,
168 | ) -> ScoreModelOutput:
169 | """Forward pass of the score model."""
170 | scores = self.score_head(hidden_state) # size = (B, L, D)
171 |
172 | end_score = []
173 | for i in range(hidden_state.size(0)):
174 | end_index = attention_mask[i].nonzero()[-1].item()
175 | end_score.append(scores[i, end_index]) # size = (D,)
176 | end_score = torch.stack(end_score, dim=0) # size = (B, D)
177 |
178 | if self.training:
179 | if dist.is_initialized():
180 | gathered_end_score_list = [
181 | torch.zeros_like(end_score) for _ in range(dist.get_world_size())
182 | ]
183 | dist.all_gather(gathered_end_score_list, end_score)
184 | gathered_end_score = torch.cat(gathered_end_score_list, dim=0)
185 | self.normalizer.update(gathered_end_score)
186 | else:
187 | self.normalizer.update(end_score)
188 | self.config.mean = self.normalizer.mean.tolist()
189 | self.config.var = self.normalizer.var.tolist()
190 |
191 | if self.do_normalize:
192 | scores = self.normalizer.normalize(scores)
193 | end_score = self.normalizer.normalize(end_score)
194 |
195 | if not return_dict:
196 | return scores, end_score
197 |
198 | return ScoreModelOutput(
199 | scores=scores, # size = (B, L, D)
200 | end_scores=end_score, # size = (B, D)
201 | )
202 |
203 | def set_normalize(self, mode: bool = True) -> None:
204 | if self.do_normalize == mode:
205 | return
206 |
207 | self.do_normalize = self.config.do_normalize = mode
208 |
--------------------------------------------------------------------------------
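
A standalone sketch of the end-score selection inside `ScoreModelMixin.get_score`: each sequence's summary score is read at its last unmasked position. The tensors below are toy values chosen so the selected entries are easy to verify:

```python
# Sketch of get_score's end-index selection with toy tensors.
import torch

scores = torch.arange(12, dtype=torch.float32).reshape(2, 6, 1)  # (B=2, L=6, D=1)
attention_mask = torch.tensor(
    [[1, 1, 1, 1, 0, 0],
     [1, 1, 1, 1, 1, 1]],
    dtype=torch.bool,
)

end_score = torch.stack(
    [scores[i, attention_mask[i].nonzero()[-1].item()] for i in range(scores.size(0))],
    dim=0,
)
print(end_score.squeeze(-1))  # tensor([ 3., 11.]): last unmasked position per row
```
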
/mcts_rl/models/score_model/bloom/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 PKU-Alignment Team. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | from mcts_rl.models.score_model.bloom.modeling_bloom import BloomModelForScore
17 |
--------------------------------------------------------------------------------
/mcts_rl/models/score_model/bloom/modeling_bloom.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 PKU-Alignment Team. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | from __future__ import annotations
17 |
18 | import warnings
19 | from typing import Any, ClassVar
20 |
21 | import torch
22 | from transformers import BloomConfig, BloomModel, BloomPreTrainedModel, PreTrainedModel
23 | from transformers.models.bloom.modeling_bloom import (
24 | _CHECKPOINT_FOR_DOC,
25 | _CONFIG_FOR_DOC,
26 | BLOOM_INPUTS_DOCSTRING,
27 | )
28 | from transformers.utils.doc import add_code_sample_docstrings, add_start_docstrings_to_model_forward
29 |
30 | from mcts_rl.models.score_model import ScoreModelMixin, ScoreModelOutput
31 |
32 |
33 | class BloomModelForScore(ScoreModelMixin, BloomPreTrainedModel):
34 | _keys_to_ignore_on_load_missing: ClassVar[list[str]] = [
35 | 'h.*.self_attention.scale_mask_softmax.causal_mask',
36 | 'lm_head.weight',
37 | ]
38 |
39 | def __init__(self, config: BloomConfig, **kwargs: Any) -> None:
40 | super().__init__(config)
41 | self.transformer = BloomModel(config)
42 |
43 | config.architectures = [self.__class__.__name__]
44 | self.init_score_head(config, hidden_size=config.hidden_size, **kwargs)
45 |
46 | # Initialize weights and apply final processing
47 | self.post_init()
48 |
49 | def get_output_embeddings(self) -> None:
50 | return None
51 |
52 | def set_decoder(self, decoder: PreTrainedModel) -> None:
53 | self.transformer = decoder
54 |
55 | def get_decoder(self) -> PreTrainedModel:
56 | return self.transformer
57 |
58 | @add_start_docstrings_to_model_forward(BLOOM_INPUTS_DOCSTRING)
59 | @add_code_sample_docstrings(
60 | checkpoint=_CHECKPOINT_FOR_DOC,
61 | output_type=ScoreModelOutput,
62 | config_class=_CONFIG_FOR_DOC,
63 | )
64 | def forward( # pylint: disable=too-many-arguments
65 | self,
66 | input_ids: torch.LongTensor | None = None,
67 | past_key_values: tuple[tuple[torch.Tensor, torch.Tensor], ...] | None = None,
68 | attention_mask: torch.Tensor | None = None,
69 | head_mask: torch.Tensor | None = None,
70 | inputs_embeds: torch.Tensor | None = None,
71 | use_cache: bool | None = None,
72 | output_attentions: bool | None = None,
73 | output_hidden_states: bool | None = None,
74 | return_dict: bool | None = None,
75 | **deprecated_arguments: dict[str, Any],
76 | ) -> tuple[torch.Tensor, torch.Tensor] | ScoreModelOutput:
77 | """
78 | Args:
79 |
80 | Returns:
81 |
82 | Examples:
83 |
84 |         ```python
85 |         >>> from mcts_rl.models.score_model.bloom.modeling_bloom import BloomModelForScore
86 |         >>> from transformers import AutoTokenizer
87 |
88 |         >>> model = BloomModelForScore.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
89 |         >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
90 |
91 |         >>> prompt = "Hey, are you conscious? Can you talk to me?"
92 |         >>> inputs = tokenizer(prompt, return_tensors="pt")
93 |
94 |         >>> # get the score
95 |         >>> outputs = model(**inputs)
96 |         >>> scores = outputs.scores
97 |         >>> scores
98 |         tensor([[[0.0000]]])
99 |         ```
100 | """
101 | assert attention_mask is not None
102 | if deprecated_arguments.pop('position_ids', False) is not False:
103 | # `position_ids` could have been `torch.Tensor` or `None`
104 | # so defaulting pop to `False` allows to detect if users were passing explicitly `None`
105 | warnings.warn(
106 | '`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. '
107 | 'You can safely ignore passing `position_ids`.',
108 | FutureWarning,
109 | stacklevel=1,
110 | )
111 | if len(deprecated_arguments) > 0:
112 | raise ValueError(f'Got unexpected arguments: {deprecated_arguments}')
113 |
114 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict
115 |
116 | transformer_outputs = self.transformer(
117 | input_ids,
118 | past_key_values=past_key_values,
119 | attention_mask=attention_mask,
120 | head_mask=head_mask,
121 | inputs_embeds=inputs_embeds,
122 | use_cache=use_cache,
123 | output_attentions=output_attentions,
124 | output_hidden_states=output_hidden_states,
125 | return_dict=return_dict,
126 | )
127 | hidden_states = transformer_outputs[0] # size = (B, L, E)
128 | return self.get_score(
129 | hidden_states,
130 | attention_mask=attention_mask,
131 | return_dict=return_dict,
132 | )
133 |
--------------------------------------------------------------------------------
/mcts_rl/models/score_model/gpt2/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 PKU-Alignment Team. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | from mcts_rl.models.score_model.gpt2.modeling_gpt2 import GPT2ForScore
17 |
--------------------------------------------------------------------------------
/mcts_rl/models/score_model/gpt2/modeling_gpt2.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 PKU-Alignment Team. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | from __future__ import annotations
17 |
18 | import warnings
19 | from typing import Any, ClassVar
20 |
21 | import torch
22 | from transformers import GPT2Model, GPT2PreTrainedModel, PreTrainedModel
23 | from transformers.configuration_utils import PretrainedConfig
24 | from transformers.models.gpt2.modeling_gpt2 import (
25 | DEPARALLELIZE_DOCSTRING,
26 | GPT2_INPUTS_DOCSTRING,
27 | GPT2_START_DOCSTRING,
28 | PARALLELIZE_DOCSTRING,
29 | )
30 | from transformers.utils.doc import add_start_docstrings, add_start_docstrings_to_model_forward
31 | from transformers.utils.model_parallel_utils import assert_device_map, get_device_map
32 |
33 | from mcts_rl.models.score_model import ScoreModelMixin, ScoreModelOutput
34 |
35 |
36 | @add_start_docstrings(
37 | """
38 |     The GPT2 Model transformer with a score head on top (a linear layer on top of the final
39 |     hidden states).
40 | """,
41 | GPT2_START_DOCSTRING,
42 | )
43 | class GPT2ForScore(ScoreModelMixin, GPT2PreTrainedModel):
44 | _keys_to_ignore_on_load_missing: ClassVar[list[str]] = [
45 | 'attn.masked_bias',
46 | 'attn.bias',
47 | 'lm_head.weight',
48 | ]
49 |
50 | def __init__(self, config: PretrainedConfig, **kwargs: Any) -> None:
51 | super().__init__(config)
52 | self.transformer = GPT2Model(config)
53 |
54 | config.architectures = [self.__class__.__name__]
55 | self.init_score_head(config, hidden_size=config.hidden_size, **kwargs)
56 |
57 | # Model parallel
58 | self.model_parallel = False
59 | self.device_map = None
60 |
61 | # Initialize weights and apply final processing
62 | self.post_init()
63 |
64 | @add_start_docstrings(PARALLELIZE_DOCSTRING)
65 | def parallelize(self, device_map: str | None = None) -> None:
66 | warnings.warn(
67 | '`GPT2LMHeadModel.parallelize` is deprecated and will be removed in v5 of Transformers, you should load'
68 | " your model with `device_map='balanced'` in the call to `from_pretrained`. You can also provide your own"
69 | " `device_map` but it needs to be a dictionary module_name to device, so for instance {'transformer.h.0':"
70 | " 0, 'transformer.h.1': 1, ...}",
71 | FutureWarning,
72 | stacklevel=1,
73 | )
74 | self.device_map = (
75 | get_device_map(len(self.transformer.h), range(torch.cuda.device_count()))
76 | if device_map is None
77 | else device_map
78 | )
79 | assert_device_map(self.device_map, len(self.transformer.h))
80 | self.transformer.parallelize(self.device_map)
81 | self.score_head = self.score_head.to(self.transformer.first_device)
82 | self.model_parallel = True
83 |
84 | @add_start_docstrings(DEPARALLELIZE_DOCSTRING)
85 | def deparallelize(self) -> None:
86 | warnings.warn(
87 | 'Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.',
88 | FutureWarning,
89 | stacklevel=1,
90 | )
91 | self.transformer.deparallelize()
92 | self.transformer = self.transformer.to('cpu')
93 | self.score_head = self.score_head.to('cpu')
94 | self.model_parallel = False
95 | torch.cuda.empty_cache()
96 |
97 | def get_output_embeddings(self) -> None:
98 | return None
99 |
100 | def set_decoder(self, decoder: PreTrainedModel) -> None:
101 | self.transformer = decoder
102 |
103 | def get_decoder(self) -> PreTrainedModel:
104 | return self.transformer
105 |
106 | @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING)
107 | def forward( # pylint: disable=too-many-arguments
108 | self,
109 | input_ids: torch.LongTensor | None = None,
110 | attention_mask: torch.FloatTensor | None = None,
111 | past_key_values: tuple[tuple[torch.Tensor]] | None = None,
112 | token_type_ids: torch.LongTensor | None = None,
113 | position_ids: torch.LongTensor | None = None,
114 | head_mask: torch.FloatTensor | None = None,
115 | inputs_embeds: torch.FloatTensor | None = None,
116 | encoder_hidden_states: torch.Tensor | None = None,
117 | encoder_attention_mask: torch.FloatTensor | None = None,
118 | use_cache: bool | None = None,
119 | output_attentions: bool | None = None,
120 | output_hidden_states: bool | None = None,
121 | return_dict: bool | None = None,
122 | ) -> tuple[torch.Tensor, torch.Tensor] | ScoreModelOutput:
123 | """
124 | Args:
125 |
126 | Returns:
127 |
128 | Examples:
129 |
130 |         ```python
131 |         >>> from mcts_rl.models.score_model.gpt2.modeling_gpt2 import GPT2ForScore
132 |         >>> from transformers import AutoTokenizer
133 |
134 |         >>> model = GPT2ForScore.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
135 |         >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
136 |
137 |         >>> prompt = "Hey, are you conscious? Can you talk to me?"
138 |         >>> inputs = tokenizer(prompt, return_tensors="pt")
139 |
140 |         >>> # get the score
141 |         >>> outputs = model(**inputs)
142 |         >>> scores = outputs.scores
143 |         >>> scores
144 |         tensor([[[0.0000]]])
145 |         ```
146 | """
147 | assert attention_mask is not None
148 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict
149 |
150 | transformer_outputs = self.transformer(
151 | input_ids,
152 | past_key_values=past_key_values,
153 | attention_mask=attention_mask,
154 | token_type_ids=token_type_ids,
155 | position_ids=position_ids,
156 | head_mask=head_mask,
157 | inputs_embeds=inputs_embeds,
158 | encoder_hidden_states=encoder_hidden_states,
159 | encoder_attention_mask=encoder_attention_mask,
160 | use_cache=use_cache,
161 | output_attentions=output_attentions,
162 | output_hidden_states=output_hidden_states,
163 | return_dict=return_dict,
164 | )
165 | hidden_states = transformer_outputs[0] # size = (B, L, E)
166 |
167 | # Set device for model parallelism
168 | if self.model_parallel:
169 | torch.cuda.set_device(self.transformer.first_device)
170 |             hidden_states = hidden_states.to(self.score_head.weight.device)
171 |
172 | return self.get_score(
173 | hidden_states,
174 | attention_mask=attention_mask,
175 | return_dict=return_dict,
176 | )
177 |
--------------------------------------------------------------------------------
/mcts_rl/models/score_model/gpt_neo/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 PKU-Alignment Team. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | from mcts_rl.models.score_model.gpt_neo.modeling_gpt_neo import GPTNeoForScore
17 |
--------------------------------------------------------------------------------
/mcts_rl/models/score_model/gpt_neo/modeling_gpt_neo.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 PKU-Alignment Team. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | from __future__ import annotations
17 |
18 | from typing import Any, ClassVar
19 |
20 | import torch
21 | from transformers import GPTNeoModel, GPTNeoPreTrainedModel, PretrainedConfig, PreTrainedModel
22 | from transformers.models.gpt_neo.modeling_gpt_neo import (
23 | GPT_NEO_INPUTS_DOCSTRING,
24 | GPT_NEO_START_DOCSTRING,
25 | )
26 | from transformers.utils.doc import add_start_docstrings, add_start_docstrings_to_model_forward
27 |
28 | from mcts_rl.models.score_model import ScoreModelMixin, ScoreModelOutput
29 |
30 |
31 | @add_start_docstrings(
32 | """
33 |     The GPT Neo Model transformer with a score head on top (a linear layer on top of the final
34 |     hidden states).
35 | """,
36 | GPT_NEO_START_DOCSTRING,
37 | )
38 | class GPTNeoForScore(ScoreModelMixin, GPTNeoPreTrainedModel):
39 | _keys_to_ignore_on_load_missing: ClassVar[list[str]] = [
40 | r'h\.\d+\.attn\.masked_bias',
41 | r'lm_head.weight',
42 | r'h\.\d+\.attn\.attention\.bias',
43 | ]
44 | _keys_to_ignore_on_save: ClassVar[list[str]] = [r'lm_head.weight']
45 |
46 | def __init__(self, config: PretrainedConfig, **kwargs: Any) -> None:
47 | super().__init__(config)
48 | self.transformer = GPTNeoModel(config)
49 |
50 | config.architectures = [self.__class__.__name__]
51 | self.init_score_head(config, hidden_size=config.hidden_size, **kwargs)
52 |
53 | # Initialize weights and apply final processing
54 | self.post_init()
55 |
56 | def get_output_embeddings(self) -> None:
57 | return None
58 |
59 | def set_decoder(self, decoder: PreTrainedModel) -> None:
60 | self.transformer = decoder
61 |
62 | def get_decoder(self) -> PreTrainedModel:
63 | return self.transformer
64 |
65 | @add_start_docstrings_to_model_forward(GPT_NEO_INPUTS_DOCSTRING)
66 | def forward( # pylint: disable=too-many-arguments
67 | self,
68 | input_ids: torch.Tensor | None = None,
69 | past_key_values: tuple[torch.FloatTensor] | None = None,
70 | attention_mask: torch.Tensor | None = None,
71 | token_type_ids: torch.Tensor | None = None,
72 | position_ids: torch.Tensor | None = None,
73 | head_mask: torch.Tensor | None = None,
74 | inputs_embeds: torch.Tensor | None = None,
75 | use_cache: bool | None = None,
76 | output_attentions: bool | None = None,
77 | output_hidden_states: bool | None = None,
78 | return_dict: bool | None = None,
79 | ) -> tuple[torch.Tensor, torch.Tensor] | ScoreModelOutput:
80 | r"""
81 | Args:
82 |
83 | Returns:
84 |
85 | Examples:
86 |
87 |         ```python
88 |         >>> from mcts_rl.models.score_model.gpt_neo.modeling_gpt_neo import GPTNeoForScore
89 |         >>> from transformers import AutoTokenizer
90 |
91 |         >>> model = GPTNeoForScore.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
92 |         >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
93 |
94 |         >>> prompt = "Hey, are you conscious? Can you talk to me?"
95 |         >>> inputs = tokenizer(prompt, return_tensors="pt")
96 |
97 |         >>> # get the score
98 |         >>> outputs = model(**inputs)
99 |         >>> scores = outputs.scores
100 |         >>> scores
101 |         tensor([[[0.0000]]])
102 |         ```
103 | """
104 | assert attention_mask is not None
105 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict
106 |
107 | outputs = self.transformer(
108 | input_ids,
109 | past_key_values=past_key_values,
110 | attention_mask=attention_mask,
111 | token_type_ids=token_type_ids,
112 | position_ids=position_ids,
113 | head_mask=head_mask,
114 | inputs_embeds=inputs_embeds,
115 | use_cache=use_cache,
116 | output_attentions=output_attentions,
117 | output_hidden_states=output_hidden_states,
118 | return_dict=return_dict,
119 | )
120 | hidden_states = outputs[0] # size = (B, L, E)
121 | return self.get_score(
122 | hidden_states,
123 | attention_mask=attention_mask,
124 | return_dict=return_dict,
125 | )
126 |
--------------------------------------------------------------------------------
/mcts_rl/models/score_model/gpt_neox/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 PKU-Alignment Team. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | from mcts_rl.models.score_model.gpt_neox.modeling_gpt_neox import GPTNeoXForScore
17 |
--------------------------------------------------------------------------------
/mcts_rl/models/score_model/gpt_neox/modeling_gpt_neox.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 PKU-Alignment Team. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | from __future__ import annotations
17 |
18 | from typing import Any, ClassVar
19 |
20 | import torch
21 | from transformers import GPTNeoXModel, GPTNeoXPreTrainedModel, PreTrainedModel
22 | from transformers.configuration_utils import PretrainedConfig
23 | from transformers.models.gpt_neox.modeling_gpt_neox import (
24 | _CONFIG_FOR_DOC,
25 | GPT_NEOX_INPUTS_DOCSTRING,
26 | )
27 | from transformers.utils.doc import add_start_docstrings_to_model_forward, replace_return_docstrings
28 |
29 | from mcts_rl.models.score_model import ScoreModelMixin, ScoreModelOutput
30 |
31 |
32 | class GPTNeoXForScore(ScoreModelMixin, GPTNeoXPreTrainedModel):
33 | _keys_to_ignore_on_load_missing: ClassVar[list[str]] = [
34 | r'position_ids',
35 | r'predictions.decoder.bias',
36 | ]
37 |
38 | def __init__(self, config: PretrainedConfig, **kwargs: Any) -> None:
39 | super().__init__(config)
40 | self.gpt_neox = GPTNeoXModel(config)
41 |
42 | config.architectures = [self.__class__.__name__]
43 | self.init_score_head(config, hidden_size=config.hidden_size, **kwargs)
44 |
45 | # Initialize weights and apply final processing
46 | self.post_init()
47 |
48 | def get_output_embeddings(self) -> None:
49 | return None
50 |
51 | def set_decoder(self, decoder: PreTrainedModel) -> None:
52 | self.gpt_neox = decoder
53 |
54 | def get_decoder(self) -> PreTrainedModel:
55 | return self.gpt_neox
56 |
57 | @add_start_docstrings_to_model_forward(
58 | GPT_NEOX_INPUTS_DOCSTRING.format('batch_size, sequence_length'),
59 | )
60 | @replace_return_docstrings(output_type=ScoreModelOutput, config_class=_CONFIG_FOR_DOC)
61 | def forward( # pylint: disable=too-many-arguments
62 | self,
63 | input_ids: torch.LongTensor,
64 | attention_mask: torch.Tensor,
65 | position_ids: torch.LongTensor | None = None,
66 | inputs_embeds: torch.FloatTensor | None = None,
67 | head_mask: torch.FloatTensor | None = None,
68 | past_key_values: tuple[tuple[torch.FloatTensor]] | None = None,
69 | use_cache: bool | None = None,
70 | output_attentions: bool | None = None,
71 | output_hidden_states: bool | None = None,
72 | return_dict: bool | None = None,
73 | ) -> tuple[torch.Tensor, torch.Tensor] | ScoreModelOutput:
74 | """
75 | Args:
76 |
77 | Returns:
78 |
79 | Examples:
80 |
81 |         ```python
82 |         >>> from mcts_rl.models.score_model.gpt_neox.modeling_gpt_neox import GPTNeoXForScore
83 |         >>> from transformers import AutoTokenizer
84 |
85 |         >>> model = GPTNeoXForScore.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
86 |         >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
87 |
88 |         >>> prompt = "Hey, are you conscious? Can you talk to me?"
89 |         >>> inputs = tokenizer(prompt, return_tensors="pt")
90 |
91 |         >>> # get the score
92 |         >>> outputs = model(**inputs)
93 |         >>> scores = outputs.scores
94 |         >>> scores
95 |         tensor([[[0.0000]]])
96 |         ```
97 | """
98 | assert attention_mask is not None
99 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict
100 |
101 | outputs = self.gpt_neox(
102 | input_ids,
103 | attention_mask=attention_mask,
104 | position_ids=position_ids,
105 | head_mask=head_mask,
106 | inputs_embeds=inputs_embeds,
107 | past_key_values=past_key_values,
108 | use_cache=use_cache,
109 | output_attentions=output_attentions,
110 | output_hidden_states=output_hidden_states,
111 | return_dict=return_dict,
112 | )
113 | hidden_states = outputs[0] # size = (B, L, E)
114 | return self.get_score(
115 | hidden_states,
116 | attention_mask=attention_mask,
117 | return_dict=return_dict,
118 | )
119 |
--------------------------------------------------------------------------------
/mcts_rl/models/score_model/gptj/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 PKU-Alignment Team. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | from mcts_rl.models.score_model.gptj.modeling_gptj import GPTJForScore
17 |
--------------------------------------------------------------------------------
/mcts_rl/models/score_model/gptj/modeling_gptj.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 PKU-Alignment Team. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | from __future__ import annotations
17 |
18 | import warnings
19 | from typing import Any, ClassVar
20 |
21 | import torch
22 | from transformers import GPTJModel, GPTJPreTrainedModel, PreTrainedModel
23 | from transformers.configuration_utils import PretrainedConfig
24 | from transformers.models.gptj.modeling_gptj import (
25 | DEPARALLELIZE_DOCSTRING,
26 | GPTJ_INPUTS_DOCSTRING,
27 | GPTJ_START_DOCSTRING,
28 | PARALLELIZE_DOCSTRING,
29 | )
30 | from transformers.utils.doc import add_start_docstrings, add_start_docstrings_to_model_forward
31 | from transformers.utils.model_parallel_utils import assert_device_map, get_device_map
32 |
33 | from mcts_rl.models.score_model import ScoreModelMixin, ScoreModelOutput
34 |
35 |
36 | @add_start_docstrings(
37 | """
38 | The GPT-J Model transformer with a score head on top.
39 | """,
40 | GPTJ_START_DOCSTRING,
41 | )
42 | class GPTJForScore(ScoreModelMixin, GPTJPreTrainedModel):
43 | _keys_to_ignore_on_load_missing: ClassVar[list[str]] = [
44 | r'h\.\d+\.attn\.masked_bias',
45 | r'h\.\d+\.attn\.bias',
46 | ]
47 |
48 | def __init__(self, config: PretrainedConfig, **kwargs: Any) -> None:
49 | super().__init__(config)
50 | self.transformer = GPTJModel(config)
51 |
52 | config.architectures = [self.__class__.__name__]
53 | self.init_score_head(config, hidden_size=config.n_embd, **kwargs)
54 |
55 | # Initialize weights and apply final processing
56 | self.post_init()
57 |
58 |         self.device_map = None
59 |         self.model_parallel = False
60 |
61 | @add_start_docstrings(PARALLELIZE_DOCSTRING)
62 | def parallelize(self, device_map: str | None = None) -> None:
63 | warnings.warn(
64 | '`GPTJForCausalLM.parallelize` is deprecated and will be removed in v5 of Transformers, you should load'
65 | " your model with `device_map='balanced'` in the call to `from_pretrained`. You can also provide your own"
66 | " `device_map` but it needs to be a dictionary module_name to device, so for instance {'transformer.h.0':"
67 | " 0, 'transformer.h.1': 1, ...}",
68 | FutureWarning,
69 | stacklevel=1,
70 | )
71 | self.device_map = (
72 | get_device_map(len(self.transformer.h), range(torch.cuda.device_count()))
73 | if device_map is None
74 | else device_map
75 | )
76 | assert_device_map(self.device_map, len(self.transformer.h))
77 | self.transformer.parallelize(self.device_map)
78 | self.score_head = self.score_head.to(self.transformer.first_device)
79 | self.model_parallel = True
80 |
81 | @add_start_docstrings(DEPARALLELIZE_DOCSTRING)
82 | def deparallelize(self) -> None:
83 | warnings.warn(
84 | 'Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.',
85 | FutureWarning,
86 | stacklevel=1,
87 | )
88 | self.transformer.deparallelize()
89 | self.transformer = self.transformer.to('cpu')
90 | self.score_head = self.score_head.to('cpu')
91 | self.model_parallel = False
92 | torch.cuda.empty_cache()
93 |
94 | def get_output_embeddings(self) -> None:
95 | return None
96 |
97 | def set_decoder(self, decoder: PreTrainedModel) -> None:
98 | self.transformer = decoder
99 |
100 | def get_decoder(self) -> PreTrainedModel:
101 | return self.transformer
102 |
103 | @add_start_docstrings_to_model_forward(
104 | GPTJ_INPUTS_DOCSTRING.format('batch_size, sequence_length'),
105 | )
106 | def forward( # pylint: disable=too-many-arguments
107 | self,
108 | input_ids: torch.LongTensor,
109 | attention_mask: torch.FloatTensor,
110 | past_key_values: tuple[tuple[torch.Tensor]] | None = None,
111 | token_type_ids: torch.LongTensor | None = None,
112 | position_ids: torch.LongTensor | None = None,
113 | head_mask: torch.FloatTensor | None = None,
114 | inputs_embeds: torch.FloatTensor | None = None,
115 | use_cache: bool | None = None,
116 | output_attentions: bool | None = None,
117 | output_hidden_states: bool | None = None,
118 | return_dict: bool | None = None,
119 | ) -> tuple[torch.Tensor, torch.Tensor] | ScoreModelOutput:
120 | """
121 | Args:
122 |
123 | Returns:
124 |
125 | Examples:
126 |
127 |         ```python
128 |         >>> from mcts_rl.models.score_model.gptj.modeling_gptj import GPTJForScore
129 |         >>> from transformers import AutoTokenizer
130 |
131 |         >>> model = GPTJForScore.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
132 |         >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
133 |
134 |         >>> prompt = "Hey, are you conscious? Can you talk to me?"
135 |         >>> inputs = tokenizer(prompt, return_tensors="pt")
136 |
137 |         >>> # get the score
138 |         >>> outputs = model(**inputs)
139 |         >>> scores = outputs.scores
140 |         >>> scores
141 |         tensor([[[0.0000]]])
142 |         ```
143 | """
144 | assert attention_mask is not None
145 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict
146 |
147 | transformer_outputs = self.transformer(
148 | input_ids,
149 | past_key_values=past_key_values,
150 | attention_mask=attention_mask,
151 | token_type_ids=token_type_ids,
152 | position_ids=position_ids,
153 | head_mask=head_mask,
154 | inputs_embeds=inputs_embeds,
155 | use_cache=use_cache,
156 | output_attentions=output_attentions,
157 | output_hidden_states=output_hidden_states,
158 | return_dict=return_dict,
159 | )
160 | hidden_states = transformer_outputs[0] # size = (B, L, E)
161 |
162 | # Set device for model parallelism
163 | if self.model_parallel:
164 | torch.cuda.set_device(self.transformer.first_device)
165 |             hidden_states = hidden_states.to(self.score_head.weight.device)
166 |
167 | return self.get_score(
168 | hidden_states,
169 | attention_mask=attention_mask,
170 | return_dict=return_dict,
171 | )
172 |
--------------------------------------------------------------------------------
/mcts_rl/models/score_model/llama/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 PKU-Alignment Team. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | from mcts_rl.models.score_model.llama.modeling_llama import LlamaModelForScore
17 |
--------------------------------------------------------------------------------
/mcts_rl/models/score_model/llama/modeling_llama.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 PKU-Alignment Team. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | from __future__ import annotations
17 |
18 | from typing import Any, ClassVar
19 |
20 | import torch
21 | import torch.nn as nn
22 | from transformers import LlamaModel, LlamaPreTrainedModel, PreTrainedModel
23 | from transformers.configuration_utils import PretrainedConfig
24 | from transformers.models.llama.modeling_llama import _CONFIG_FOR_DOC, LLAMA_INPUTS_DOCSTRING
25 | from transformers.utils.doc import add_start_docstrings_to_model_forward, replace_return_docstrings
26 |
27 | from mcts_rl.models.score_model import ScoreModelMixin, ScoreModelOutput
28 |
29 |
30 | class LlamaModelForScore(ScoreModelMixin, LlamaPreTrainedModel):
31 | _keys_to_ignore_on_load_missing: ClassVar[list[str]] = ['lm_head.weight']
32 |
33 | def __init__(self, config: PretrainedConfig, **kwargs: Any) -> None:
34 | super().__init__(config)
35 | self.model = LlamaModel(config)
36 |
37 | config.architectures = [self.__class__.__name__]
38 | self.init_score_head(config, hidden_size=config.hidden_size, **kwargs)
39 |
40 | # Initialize weights and apply final processing
41 | self.post_init()
42 |
43 | def get_input_embeddings(self) -> nn.Embedding:
44 | return self.model.embed_tokens
45 |
46 | def set_input_embeddings(self, value: nn.Embedding) -> None:
47 | self.model.embed_tokens = value
48 |
49 | def get_output_embeddings(self) -> None:
50 | return None
51 |
52 | def set_decoder(self, decoder: PreTrainedModel) -> None:
53 | self.model = decoder
54 |
55 | def get_decoder(self) -> PreTrainedModel:
56 | return self.model
57 |
58 | @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
59 | @replace_return_docstrings(output_type=ScoreModelOutput, config_class=_CONFIG_FOR_DOC)
60 | def forward( # pylint: disable=too-many-arguments
61 | self,
62 | input_ids: torch.LongTensor,
63 | attention_mask: torch.Tensor,
64 | position_ids: torch.LongTensor | None = None,
65 | past_key_values: list[torch.FloatTensor] | None = None,
66 | inputs_embeds: torch.FloatTensor | None = None,
67 | use_cache: bool | None = None,
68 | output_attentions: bool | None = None,
69 | output_hidden_states: bool | None = None,
70 | return_dict: bool | None = None,
71 | ) -> tuple[torch.Tensor, torch.Tensor] | ScoreModelOutput:
72 | """
73 | Args:
74 |
75 | Returns:
76 |
77 | Examples:
78 |
79 |         ```python
80 |         >>> from mcts_rl.models.score_model.llama.modeling_llama import LlamaModelForScore
81 |         >>> from transformers import LlamaTokenizer
82 | 
83 |         >>> model = LlamaModelForScore.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
84 |         >>> tokenizer = LlamaTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
85 | 
86 |         >>> prompt = "Hey, are you conscious? Can you talk to me?"
87 |         >>> inputs = tokenizer(prompt, return_tensors="pt")
88 | 
89 |         >>> # get the score for the prompt
90 |         >>> outputs = model(**inputs)
91 |         >>> scores = outputs.scores
92 |         >>> scores
93 |         tensor([[[0.0000]]])
94 |         ```
95 | """
96 | assert attention_mask is not None
97 | output_attentions = (
98 | output_attentions if output_attentions is not None else self.config.output_attentions
99 | )
100 | output_hidden_states = (
101 | output_hidden_states
102 | if output_hidden_states is not None
103 | else self.config.output_hidden_states
104 | )
105 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict
106 |
107 | outputs = self.model(
108 | input_ids=input_ids,
109 | attention_mask=attention_mask,
110 | position_ids=position_ids,
111 | past_key_values=past_key_values,
112 | inputs_embeds=inputs_embeds,
113 | use_cache=use_cache,
114 | output_attentions=output_attentions,
115 | output_hidden_states=output_hidden_states,
116 | return_dict=return_dict,
117 | )
118 | hidden_states = outputs[0] # size = (B, L, E)
119 | return self.get_score(
120 | hidden_states,
121 | attention_mask=attention_mask,
122 | return_dict=return_dict,
123 | )
124 |
--------------------------------------------------------------------------------
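All of these `*ForScore` wrappers delegate the final reduction to `ScoreModelMixin.get_score`, whose implementation lives in `mcts_rl/models/score_model/__init__.py` and is not reproduced in this listing. A minimal sketch of the usual reduction, assuming the mixin applies a linear `score_head` and keeps the hidden state of the last non-padding token (the function and names below are illustrative, not the repository's actual code):

```python
import torch
import torch.nn as nn


def score_last_token(
    hidden_states: torch.Tensor,   # size = (B, L, E), from the base transformer
    attention_mask: torch.Tensor,  # size = (B, L), 1 for real tokens, 0 for padding
    score_head: nn.Linear,         # maps E -> score_dim
) -> torch.Tensor:
    """Illustrative: score every position, then keep the last real token's score."""
    scores = score_head(hidden_states)         # size = (B, L, score_dim)
    end_index = attention_mask.sum(dim=1) - 1  # index of the last non-padding token
    return scores[torch.arange(scores.size(0)), end_index]  # size = (B, score_dim)
```

The real `get_score` additionally honors `return_dict`, wrapping its result in a `ScoreModelOutput` as the forward signatures above indicate.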
/mcts_rl/models/score_model/mistral/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 PKU-Alignment Team. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | from mcts_rl.models.score_model.mistral.modeling_mistral import MistralModelForScore
17 |
--------------------------------------------------------------------------------
/mcts_rl/models/score_model/mistral/modeling_mistral.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 PKU-Alignment Team. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | from __future__ import annotations
17 |
18 | from typing import Any, ClassVar
19 |
20 | import torch
21 | import torch.nn as nn
22 | from transformers import MistralModel, MistralPreTrainedModel, PreTrainedModel
23 | from transformers.configuration_utils import PretrainedConfig
24 | from transformers.models.mistral.modeling_mistral import _CONFIG_FOR_DOC, MISTRAL_INPUTS_DOCSTRING
25 | from transformers.utils.doc import add_start_docstrings_to_model_forward, replace_return_docstrings
26 |
27 | from mcts_rl.models.score_model import ScoreModelMixin, ScoreModelOutput
28 |
29 |
30 | class MistralModelForScore(ScoreModelMixin, MistralPreTrainedModel):
31 | _keys_to_ignore_on_load_missing: ClassVar[list[str]] = ['lm_head.weight']
32 |
33 | def __init__(self, config: PretrainedConfig, **kwargs: Any) -> None:
34 | super().__init__(config)
35 | self.model = MistralModel(config)
36 |
37 | config.architectures = [self.__class__.__name__]
38 | self.init_score_head(config, hidden_size=config.hidden_size, **kwargs)
39 |
40 | # Initialize weights and apply final processing
41 | self.post_init()
42 |
43 | def get_input_embeddings(self) -> nn.Embedding:
44 | return self.model.embed_tokens
45 |
46 | def set_input_embeddings(self, value: nn.Embedding) -> None:
47 | self.model.embed_tokens = value
48 |
49 | def get_output_embeddings(self) -> None:
50 | return None
51 |
52 | def set_decoder(self, decoder: PreTrainedModel) -> None:
53 | self.model = decoder
54 |
55 | def get_decoder(self) -> PreTrainedModel:
56 | return self.model
57 |
58 | @add_start_docstrings_to_model_forward(MISTRAL_INPUTS_DOCSTRING)
59 | @replace_return_docstrings(output_type=ScoreModelOutput, config_class=_CONFIG_FOR_DOC)
60 | def forward( # pylint: disable=too-many-arguments
61 | self,
62 | input_ids: torch.LongTensor,
63 | attention_mask: torch.Tensor,
64 | position_ids: torch.LongTensor | None = None,
65 | past_key_values: list[torch.FloatTensor] | None = None,
66 | inputs_embeds: torch.FloatTensor | None = None,
67 | use_cache: bool | None = None,
68 | output_attentions: bool | None = None,
69 | output_hidden_states: bool | None = None,
70 | return_dict: bool | None = None,
71 | ) -> tuple[torch.Tensor, torch.Tensor] | ScoreModelOutput:
72 | """
73 | Args:
74 |
75 | Returns:
76 |
77 | Examples:
78 |
79 |         ```python
80 |         >>> from mcts_rl.models.score_model.mistral.modeling_mistral import MistralModelForScore
81 |         >>> from transformers import AutoTokenizer
82 | 
83 |         >>> model = MistralModelForScore.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
84 |         >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
85 | 
86 |         >>> prompt = "Hey, are you conscious? Can you talk to me?"
87 |         >>> inputs = tokenizer(prompt, return_tensors="pt")
88 | 
89 |         >>> # get the score for the prompt
90 |         >>> outputs = model(**inputs)
91 |         >>> scores = outputs.scores
92 |         >>> scores
93 |         tensor([[[0.0000]]])
94 |         ```
95 | """
96 | assert attention_mask is not None
97 | output_attentions = (
98 | output_attentions if output_attentions is not None else self.config.output_attentions
99 | )
100 | output_hidden_states = (
101 | output_hidden_states
102 | if output_hidden_states is not None
103 | else self.config.output_hidden_states
104 | )
105 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict
106 |
107 | outputs = self.model(
108 | input_ids=input_ids,
109 | attention_mask=attention_mask,
110 | position_ids=position_ids,
111 | past_key_values=past_key_values,
112 | inputs_embeds=inputs_embeds,
113 | use_cache=use_cache,
114 | output_attentions=output_attentions,
115 | output_hidden_states=output_hidden_states,
116 | return_dict=return_dict,
117 | )
118 | hidden_states = outputs[0] # size = (B, L, E)
119 | return self.get_score(
120 | hidden_states,
121 | attention_mask=attention_mask,
122 | return_dict=return_dict,
123 | )
124 |
--------------------------------------------------------------------------------
/mcts_rl/models/score_model/open_llama/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 PKU-Alignment Team. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | from mcts_rl.models.score_model.open_llama.modeling_open_llama import OpenLlamaForScore
17 |
--------------------------------------------------------------------------------
/mcts_rl/models/score_model/open_llama/modeling_open_llama.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 PKU-Alignment Team. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | from __future__ import annotations
17 |
18 | from typing import Any, ClassVar
19 |
20 | import torch
21 | import torch.nn as nn
22 | from transformers import OpenLlamaModel, OpenLlamaPreTrainedModel, PreTrainedModel
23 | from transformers.configuration_utils import PretrainedConfig
24 | from transformers.models.open_llama.modeling_open_llama import (
25 | _CONFIG_FOR_DOC,
26 | OPEN_LLAMA_INPUTS_DOCSTRING,
27 | )
28 | from transformers.utils.doc import add_start_docstrings_to_model_forward, replace_return_docstrings
29 |
30 | from mcts_rl.models.score_model import ScoreModelMixin, ScoreModelOutput
31 |
32 |
33 | class OpenLlamaForScore(ScoreModelMixin, OpenLlamaPreTrainedModel):
34 | _keys_to_ignore_on_load_missing: ClassVar[list[str]] = ['lm_head.weight']
35 |
36 | def __init__(self, config: PretrainedConfig, **kwargs: Any) -> None:
37 | super().__init__(config)
38 | self.model = OpenLlamaModel(config)
39 |
40 | config.shared_input_output_embedding = False
41 | config.architectures = [self.__class__.__name__]
42 | self.init_score_head(config, hidden_size=config.hidden_size, **kwargs)
43 |
44 | # Initialize weights and apply final processing
45 | self.post_init()
46 |
47 | def get_input_embeddings(self) -> nn.Embedding:
48 | return self.model.embed_tokens
49 |
50 | def set_input_embeddings(self, value: nn.Embedding) -> None:
51 | self.model.embed_tokens = value
52 |
53 | def get_output_embeddings(self) -> None:
54 | return None
55 |
56 | def set_decoder(self, decoder: PreTrainedModel) -> None:
57 | self.model = decoder
58 |
59 | def get_decoder(self) -> PreTrainedModel:
60 | return self.model
61 |
62 | @add_start_docstrings_to_model_forward(OPEN_LLAMA_INPUTS_DOCSTRING)
63 | @replace_return_docstrings(output_type=ScoreModelOutput, config_class=_CONFIG_FOR_DOC)
64 | def forward( # pylint: disable=too-many-arguments
65 | self,
66 | input_ids: torch.LongTensor,
67 | attention_mask: torch.Tensor,
68 | position_ids: torch.LongTensor | None = None,
69 | past_key_values: list[torch.FloatTensor] | None = None,
70 | inputs_embeds: torch.FloatTensor | None = None,
71 | use_cache: bool | None = None,
72 | output_attentions: bool | None = None,
73 | output_hidden_states: bool | None = None,
74 | return_dict: bool | None = None,
75 | ) -> tuple[torch.Tensor, torch.Tensor] | ScoreModelOutput:
76 | """
77 | Args:
78 |
79 | Returns:
80 |
81 | Examples:
82 |
83 |         ```python
84 |         >>> from mcts_rl.models.score_model.open_llama.modeling_open_llama import OpenLlamaForScore
85 |         >>> from transformers import AutoTokenizer
86 | 
87 |         >>> model = OpenLlamaForScore.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
88 |         >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
89 | 
90 |         >>> prompt = "Hey, are you conscious? Can you talk to me?"
91 |         >>> inputs = tokenizer(prompt, return_tensors="pt")
92 | 
93 |         >>> # get the score for the prompt
94 |         >>> outputs = model(**inputs)
95 |         >>> scores = outputs.scores
96 |         >>> scores
97 |         tensor([[[0.0000]]])
98 |         ```
99 | """
100 | assert attention_mask is not None
101 | output_attentions = (
102 | output_attentions if output_attentions is not None else self.config.output_attentions
103 | )
104 | output_hidden_states = (
105 | output_hidden_states
106 | if output_hidden_states is not None
107 | else self.config.output_hidden_states
108 | )
109 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict
110 |
111 | # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
112 | outputs = self.model(
113 | input_ids=input_ids,
114 | attention_mask=attention_mask,
115 | position_ids=position_ids,
116 | past_key_values=past_key_values,
117 | inputs_embeds=inputs_embeds,
118 | use_cache=use_cache,
119 | output_attentions=output_attentions,
120 | output_hidden_states=output_hidden_states,
121 | return_dict=return_dict,
122 | )
123 | hidden_states = outputs[0] # size = (B, L, E)
124 | return self.get_score(
125 | hidden_states,
126 | attention_mask=attention_mask,
127 | return_dict=return_dict,
128 | )
129 |
--------------------------------------------------------------------------------
/mcts_rl/models/score_model/opt/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 PKU-Alignment Team. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | from mcts_rl.models.score_model.opt.modeling_opt import OPTForScore
17 |
--------------------------------------------------------------------------------
/mcts_rl/models/score_model/opt/modeling_opt.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 PKU-Alignment Team. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | from __future__ import annotations
17 |
18 | from typing import Any, ClassVar
19 |
20 | import torch
21 | import torch.nn as nn
22 | from transformers import OPTModel, OPTPreTrainedModel, PreTrainedModel
23 | from transformers.configuration_utils import PretrainedConfig
24 | from transformers.models.opt.modeling_opt import _CONFIG_FOR_DOC
25 | from transformers.utils.doc import replace_return_docstrings
26 |
27 | from mcts_rl.models.score_model import ScoreModelMixin, ScoreModelOutput
28 |
29 |
30 | class OPTForScore(ScoreModelMixin, OPTPreTrainedModel):
31 | _keys_to_ignore_on_load_missing: ClassVar[list[str]] = ['lm_head.weight']
32 |
33 | def __init__(self, config: PretrainedConfig, **kwargs: Any) -> None:
34 | super().__init__(config)
35 | self.model = OPTModel(config)
36 |
37 | config.architectures = [self.__class__.__name__]
38 | self.init_score_head(config, hidden_size=config.word_embed_proj_dim, **kwargs)
39 |
40 | # Initialize weights and apply final processing
41 | self.post_init()
42 |
43 | def get_input_embeddings(self) -> nn.Embedding:
44 | return self.model.decoder.embed_tokens
45 |
46 | def set_input_embeddings(self, value: nn.Embedding) -> None:
47 | self.model.decoder.embed_tokens = value
48 |
49 | def get_output_embeddings(self) -> None:
50 | return None
51 |
52 | def set_decoder(self, decoder: PreTrainedModel) -> None:
53 | self.model = decoder
54 |
55 | def get_decoder(self) -> PreTrainedModel:
56 | return self.model
57 |
58 | @replace_return_docstrings(output_type=ScoreModelOutput, config_class=_CONFIG_FOR_DOC)
59 | def forward( # pylint: disable=too-many-arguments
60 | self,
61 | input_ids: torch.LongTensor,
62 | attention_mask: torch.Tensor,
63 | head_mask: torch.Tensor | None = None,
64 | past_key_values: list[torch.FloatTensor] | None = None,
65 | inputs_embeds: torch.FloatTensor | None = None,
66 | use_cache: bool | None = None,
67 | output_attentions: bool | None = None,
68 | output_hidden_states: bool | None = None,
69 | return_dict: bool | None = None,
70 | ) -> tuple[torch.Tensor, torch.Tensor] | ScoreModelOutput:
71 | """
72 | Args:
73 |
74 | Returns:
75 |
76 | Examples:
77 |
78 |         ```python
79 |         >>> from mcts_rl.models.score_model.opt.modeling_opt import OPTForScore
80 |         >>> from transformers import AutoTokenizer
81 | 
82 |         >>> model = OPTForScore.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
83 |         >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
84 | 
85 |         >>> prompt = "Hey, are you conscious? Can you talk to me?"
86 |         >>> inputs = tokenizer(prompt, return_tensors="pt")
87 | 
88 |         >>> # get the score for the prompt
89 |         >>> outputs = model(**inputs)
90 |         >>> scores = outputs.scores
91 |         >>> scores
92 |         tensor([[[0.0000]]])
93 |         ```
94 | """
95 | assert attention_mask is not None
96 | output_attentions = (
97 | output_attentions if output_attentions is not None else self.config.output_attentions
98 | )
99 | output_hidden_states = (
100 | output_hidden_states
101 | if output_hidden_states is not None
102 | else self.config.output_hidden_states
103 | )
104 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict
105 |
106 | # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
107 | outputs = self.model.decoder(
108 | input_ids=input_ids,
109 | attention_mask=attention_mask,
110 | head_mask=head_mask,
111 | past_key_values=past_key_values,
112 | inputs_embeds=inputs_embeds,
113 | use_cache=use_cache,
114 | output_attentions=output_attentions,
115 | output_hidden_states=output_hidden_states,
116 | return_dict=return_dict,
117 | )
118 | hidden_states = outputs[0] # size = (B, L, E)
119 | return self.get_score(
120 | hidden_states,
121 | attention_mask=attention_mask,
122 | return_dict=return_dict,
123 | )
124 |
--------------------------------------------------------------------------------
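One detail specific to OPT: `init_score_head` above is sized with `config.word_embed_proj_dim` rather than `config.hidden_size`, because OPT can decouple the transformer width from the embedding width, and when they differ the decoder projects its final hidden states down to `word_embed_proj_dim`. A small check of that config relationship (the dimensions below are arbitrary example values):

```python
from transformers import OPTConfig

# When word_embed_proj_dim != hidden_size, OPT's decoder ends with a
# projection down to word_embed_proj_dim, so that is the feature size
# the score head receives.
config = OPTConfig(hidden_size=768, word_embed_proj_dim=512)
assert config.word_embed_proj_dim == 512  # score-head input size
assert config.hidden_size == 768          # internal transformer width
```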
/mcts_rl/trainers/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 PKU-Alignment Team. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Trainer base classes."""
16 |
17 | from mcts_rl.trainers.base import TrainerBase
18 | from mcts_rl.trainers.rl_trainer import RLTrainer
19 | from mcts_rl.trainers.tsrl_trainer import TSRLTrainer
20 | from mcts_rl.trainers.supervised_trainer import SupervisedTrainer
21 |
22 |
23 | __all__ = ['TrainerBase', 'RLTrainer', 'TSRLTrainer', 'SupervisedTrainer']
24 |
--------------------------------------------------------------------------------
/mcts_rl/trainers/base.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 PKU-Alignment Team. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Trainer base class."""
16 |
17 | from __future__ import annotations
18 |
19 | import abc
20 | import argparse
21 | import os
22 | import subprocess
23 | import sys
24 | from datetime import datetime
25 | from typing import Any, ClassVar
26 |
27 | import deepspeed
28 | import torch.distributed as dist
29 | from transformers import CONFIG_NAME, WEIGHTS_NAME, PreTrainedModel, PreTrainedTokenizerBase
30 |
31 | from mcts_rl.logger import Logger
32 | from mcts_rl.utils import is_main_process
33 |
34 |
35 | class TrainerBase(metaclass=abc.ABCMeta):
36 | """Trainer base class.
37 |
38 | Abstract methods:
39 | init_models: Initialize model and tokenizer.
40 | init_datasets: Initialize training and evaluation datasets.
41 | init_engines: Initialize DeepSpeed engines.
42 | train: Train model.
43 | set_train: Set training mode for all models.
44 | """
45 |
46 | TRAINING_TYPE: ClassVar[str]
47 |
48 | tokenizer: PreTrainedTokenizerBase
49 |
50 | args: argparse.Namespace
51 | logger: Logger
52 |
53 | @abc.abstractmethod
54 | def init_models(self) -> None:
55 | """Initialize model and tokenizer."""
56 | raise NotImplementedError
57 |
58 | @abc.abstractmethod
59 | def init_datasets(self) -> None:
60 | """Initialize training and evaluation datasets."""
61 | raise NotImplementedError
62 |
63 | @abc.abstractmethod
64 | def init_engines(self) -> None:
65 | """Initialize DeepSpeed engines."""
66 | raise NotImplementedError
67 |
68 | def init_logger(self) -> None:
69 | """Set logger."""
70 | if self.args.log_type is None:
71 | self.logger = Logger(config=self.args)
72 | return
73 |
74 | time = datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
75 |
76 | self.args.log_dir = self.args.log_dir or self.args.output_dir
77 | self.args.log_project = self.args.log_project or 'safe-rlhf'
78 | self.args.log_run_name = self.args.log_run_name or f'{self.TRAINING_TYPE}-{time}'
79 |
80 | self.logger = Logger(
81 | log_type=self.args.log_type,
82 | log_dir=self.args.log_dir,
83 | log_project=self.args.log_project,
84 | log_run_name=self.args.log_run_name,
85 | config=self.args,
86 | )
87 |
88 | @abc.abstractmethod
89 | def train(self) -> None:
90 | """Train model."""
91 | raise NotImplementedError
92 |
93 | def eval(self) -> dict[str, Any]:
94 | """Evaluate model."""
95 | return {}
96 |
97 | @abc.abstractmethod
98 | def set_train(self, mode: bool = True) -> None:
99 | """Set training mode for all models."""
100 | raise NotImplementedError
101 |
102 | def set_eval(self) -> None:
103 | """Set model to evaluation mode."""
104 | self.set_train(mode=False)
105 |
106 | def save(
107 | self,
108 | model: deepspeed.DeepSpeedEngine | None = None,
109 | ds_config: dict[str, Any] | None = None,
110 | global_steps: int = -1,
111 | ) -> None:
112 | """Save model and tokenizer in Hugging Face format."""
113 | dist.barrier()
114 |
115 | if model is None:
116 | model = self.model # pylint: disable=no-member
117 | if ds_config is None:
118 | ds_config = self.ds_config # pylint: disable=no-member
119 |
120 | output_dir = self.args.output_dir
121 | if global_steps > 0:
122 | output_dir = os.path.join(output_dir, f'steps{global_steps}')
123 | os.makedirs(output_dir, exist_ok=True)
124 |
125 |         self.logger.print(f'Saving model to "{output_dir}" ...')
126 |
127 | output_config_file = os.path.join(output_dir, CONFIG_NAME)
128 | model_to_save: PreTrainedModel = getattr(model, 'module', model)
129 | if is_main_process():
130 | model_to_save.config.to_json_file(output_config_file)
131 | self.tokenizer.save_pretrained(output_dir)
132 |
133 | if self.args.save_16bit:
134 | self.logger.print('Saving 16-bit model...')
135 | model.save_16bit_model(output_dir)
136 | else:
137 | # Save model checkpoint
138 | if ds_config['zero_optimization']['stage'] >= 2:
139 | self.logger.print('Saving DeepSpeed Checkpoints...')
140 | model.save_checkpoint(output_dir)
141 | self.logger.print('Converting DeepSpeed Checkpoints to Hugging Face format...')
142 | if is_main_process():
143 | subprocess.check_call(
144 | [sys.executable, 'zero_to_fp32.py', '.', WEIGHTS_NAME], # noqa: S603
145 | cwd=output_dir,
146 | )
147 | dist.barrier()
148 | else:
149 | self.logger.print('Saving Hugging Face Checkpoints...')
150 | if is_main_process():
151 | model_to_save.save_pretrained(output_dir, is_main_process=True)
152 |
153 | self.logger.print('Model saved!')
154 |
--------------------------------------------------------------------------------
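`TrainerBase` leaves five hooks abstract: `init_models`, `init_datasets`, `init_engines`, `train`, and `set_train`. A minimal sketch of a conforming subclass (the class below is illustrative and not part of the repository; the real implementations live in `rl_trainer.py`, `supervised_trainer.py`, and `tsrl_trainer.py`):

```python
from mcts_rl.trainers.base import TrainerBase


class ToyTrainer(TrainerBase):
    """Illustrative subclass showing the abstract surface of TrainerBase."""

    TRAINING_TYPE = 'toy'

    def init_models(self) -> None:
        ...  # load model(s) and set self.tokenizer

    def init_datasets(self) -> None:
        ...  # build training/evaluation dataloaders

    def init_engines(self) -> None:
        ...  # wrap models in DeepSpeed engines (sets self.model, self.ds_config)

    def train(self) -> None:
        ...  # main optimization loop; may call self.save(...)

    def set_train(self, mode: bool = True) -> None:
        ...  # toggle train/eval mode on all wrapped models
```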
/mcts_rl/version.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 PKU-Alignment Team. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Safe-RLHF: Safe Reinforcement Learning with Human Feedback."""
16 |
17 | __version__ = '0.0.1dev0'
18 | __license__ = 'Apache License, Version 2.0'
19 | __author__ = 'PKU-Alignment Team'
20 | __release__ = False
21 |
22 | if not __release__:
23 | import os
24 | import subprocess
25 |
26 | try:
27 | prefix, sep, suffix = (
28 | subprocess.check_output(
29 | ['git', 'describe', '--abbrev=7'], # noqa: S603,S607
30 | cwd=os.path.dirname(os.path.abspath(__file__)),
31 | stderr=subprocess.DEVNULL,
32 | text=True,
33 | )
34 | .strip()
35 | .lstrip('v')
36 | .replace('-', '.dev', 1)
37 | .replace('-', '+', 1)
38 | .partition('.dev')
39 | )
40 | if sep:
41 | version_prefix, dot, version_tail = prefix.rpartition('.')
42 | prefix = f'{version_prefix}{dot}{int(version_tail) + 1}'
43 | __version__ = sep.join((prefix, suffix))
44 | del version_prefix, dot, version_tail
45 | else:
46 | __version__ = prefix
47 | del prefix, sep, suffix
48 | except (OSError, subprocess.CalledProcessError):
49 | pass
50 |
51 | del os, subprocess
52 |
--------------------------------------------------------------------------------
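The `git describe` mangling in `version.py` is compact; here is a worked example of what it computes, using a made-up tag and commit hash (hypothetical values, not taken from the repository):

```python
# Hypothetical `git describe --abbrev=7` output: tag v0.0.1, 5 commits later.
described = 'v0.0.1-5-gabc1234'
prefix, sep, suffix = (
    described.strip()
    .lstrip('v')
    .replace('-', '.dev', 1)  # '0.0.1.dev5-gabc1234'
    .replace('-', '+', 1)     # '0.0.1.dev5+gabc1234'
    .partition('.dev')        # ('0.0.1', '.dev', '5+gabc1234')
)
# Bump the patch number so the dev build sorts after the tagged release
# and before the next release under PEP 440:
version_prefix, dot, version_tail = prefix.rpartition('.')
prefix = f'{version_prefix}{dot}{int(version_tail) + 1}'
assert sep.join((prefix, suffix)) == '0.0.2.dev5+gabc1234'
```

When the checkout sits exactly on a tag, `git describe` emits just the tag (e.g. `v0.0.1`), `sep` is empty, and `__version__` falls through to the bare `0.0.1`.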
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch >= 1.13
2 | transformers >= 4.28
3 | datasets
4 | tokenizers >= 0.13.3
5 | accelerate
6 | pyproject-toml
7 | numpy
8 | scipy
9 | sentencepiece
10 | wandb
11 | tensorboard
12 | optree
13 | matplotlib
14 | tqdm
15 | rich
16 | nltk
17 | peft
18 | bert_score
19 | graphviz
--------------------------------------------------------------------------------
/scripts/eval/mctseval_math.sh:
--------------------------------------------------------------------------------
1 | if [ -z "${BASH_VERSION}" ]; then
2 | echo "Please use bash to run this script." >&2
3 | exit 1
4 | fi
5 |
6 | set -x
7 |
8 | SCRIPT_DIR="$(cd "$(dirname "$0")" &>/dev/null && pwd)"
9 | ROOT_DIR="$(dirname "${SCRIPT_DIR}")"
10 | export PYTHONPATH="${ROOT_DIR}${PYTHONPATH:+:${PYTHONPATH}}"
11 | export LOGLEVEL="${LOGLEVEL:-WARNING}"
12 |
13 | MODEL_NAME=""
14 | STEP_NUM=
15 | ACTOR_MODEL_NAME_OR_PATH=MCTS-DPO/outputs/checkpoints/arithmetic/${MODEL_NAME}/steps${STEP_NUM}
16 |
17 | OUTPUT_DIR="MCTS-DPO/outputs/eval"
18 | unset HOSTFILE
19 | ZERO_STAGE=2
20 | OFFLOAD="all"
21 |
22 | if [[ -z "${REWARD_CRITIC_MODEL_NAME_OR_PATH+x}" ]]; then
23 | REWARD_CRITIC_MODEL_NAME_OR_PATH="${REWARD_MODEL_NAME_OR_PATH}"
24 | fi
25 |
26 | mkdir -p "${OUTPUT_DIR}"
27 | OUTPUT_DIR="$(cd "${OUTPUT_DIR}" &>/dev/null && pwd)"
28 | if [[ ! -f "${OUTPUT_DIR}/.gitignore" ]]; then
29 | echo '*' >"${OUTPUT_DIR}/.gitignore"
30 | fi
31 |
32 | cp -f "$0" "${OUTPUT_DIR}/script.sh"
33 |
34 | if [[ -z "${WANDB_API_KEY}" ]]; then
35 | export WANDB_MODE="offline"
36 | fi
37 |
38 | MASTER_PORT_START=10000
39 | MASTER_PORT_END=65535
40 | MASTER_PORT="$(
41 | comm -23 \
42 | <(seq "${MASTER_PORT_START}" "${MASTER_PORT_END}" | sort) \
43 | <(ss -Htan | awk '{ print $4 }' | awk -F ':' '{ print $NF }' | sort -u) |
44 | shuf | head -n 1
45 | )"
46 |
47 | DEEPSPEED_ARGS=()
48 | if [[ -n "${HOSTFILE+x}" ]]; then
49 | DEEPSPEED_ARGS+=("--hostfile" "${HOSTFILE}")
50 | fi
51 | DEEPSPEED_ARGS+=("--master_port" "${MASTER_PORT}")
52 |
53 | exec 1> >(tee "${OUTPUT_DIR}/stdout.log" >&1) 2> >(tee "${OUTPUT_DIR}/stderr.log" >&2)
54 |
55 | export WANDB_API_KEY=""
56 | export WANDB_MODE=dryrun
57 |
58 | export NCCL_DEBUG=INFO
59 | export NCCL_DEBUG_SUBSYS=INIT,P2P
60 |
61 | gpu_vis=$1
62 |
63 | deepspeed --include localhost:$gpu_vis --master_port $MASTER_PORT \
64 | --module mcts_rl.algorithms.mcts \
65 | --train_datasets GSM8K/train \
66 | --eval_datasets MATH/test \
67 | --actor_model_name_or_path "${ACTOR_MODEL_NAME_OR_PATH}" \
68 | --actor_ref_model_name_or_path "${ACTOR_MODEL_NAME_OR_PATH}" \
69 | --max_length 768 \
70 | --repetition_penalty 1.0 \
71 | --trust_remote_code True \
72 | --epochs 1 \
73 | --update_iters 1 \
74 | --save_interval 128 \
75 | --per_device_prompt_batch_size 1 \
76 | --per_device_train_batch_size 1 \
77 | --per_device_eval_batch_size 1 \
78 | --gradient_accumulation_steps 8 \
79 | --actor_lr 1e-6 \
80 | --actor_weight_decay 0.01 \
81 | --actor_lr_scheduler_type cosine \
82 | --actor_lr_warmup_ratio 0.03 \
83 | --actor_gradient_checkpointing \
84 | --need_eval \
85 | --seed 42 \
86 | --kl_coeff 0.02 \
87 | --clip_range_ratio 0.2 \
88 | --clip_range_score 50.0 \
89 | --clip_range_value 5.0 \
90 | --output_dir "${OUTPUT_DIR}" \
91 | --log_type wandb \
92 | --log_project MCTS-DPO-EVAL \
93 | --zero_stage "${ZERO_STAGE}" \
94 | --offload "${OFFLOAD}" \
95 | --bf16 True \
96 | --tf32 True \
97 | --force_terminating_on_depth_limit \
98 | --max_new_tokens 64 \
99 | --n_iters 5 \
100 | --depth_limit 3 \
101 | --n_init_actions 5 \
102 | --n_actions 3 \
103 | --mcts_temperature 0.0 \
104 | --num_return_sequences 1 \
105 | --temperature 1.0 \
106 | --model_type mistral \
107 |     --prediction_file_path MCTS-DPO/outputs/checkpoints/arithmetic/predictions/mistral_${MODEL_NAME}_math${STEP_NUM}.jsonl
108 |
--------------------------------------------------------------------------------
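All of these launch scripts pick `MASTER_PORT` with the same `comm -23` pipeline: it diffs the sorted candidate range 10000-65535 against the local ports that `ss -Htan` reports as in use, then takes one surviving port at random. A hedged Python equivalent using a different, simpler technique (bind to port 0 and let the kernel choose; it shares the shell version's inherent race that the port may be taken before it is actually used):

```python
import socket


def pick_free_port() -> int:
    """Ask the OS for a currently unused TCP port."""
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        sock.bind(('', 0))  # port 0 means "any free port"
        return sock.getsockname()[1]


if __name__ == '__main__':
    print(pick_free_port())  # e.g. 54017
```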
/scripts/eval/mctseval_sqa.sh:
--------------------------------------------------------------------------------
1 |
2 | if [ -z "${BASH_VERSION}" ]; then
3 | echo "Please use bash to run this script." >&2
4 | exit 1
5 | fi
6 |
7 | set -x
8 |
9 | SCRIPT_DIR="$(cd "$(dirname "$0")" &>/dev/null && pwd)"
10 | ROOT_DIR="$(dirname "${SCRIPT_DIR}")"
11 | export PYTHONPATH="${ROOT_DIR}${PYTHONPATH:+:${PYTHONPATH}}"
12 | export LOGLEVEL="${LOGLEVEL:-WARNING}"
13 |
14 | MODEL_NAME=""
15 | STEP_NUM=
16 | ACTOR_MODEL_NAME_OR_PATH=MCTS-DPO/outputs/checkpoints/sqa/${MODEL_NAME}/steps${STEP_NUM}
17 |
18 | OUTPUT_DIR="MCTS-DPO/outputs/eval"
19 | unset HOSTFILE
20 | ZERO_STAGE=2
21 | OFFLOAD="all"
22 |
23 |
24 | mkdir -p "${OUTPUT_DIR}"
25 | OUTPUT_DIR="$(cd "${OUTPUT_DIR}" &>/dev/null && pwd)"
26 | if [[ ! -f "${OUTPUT_DIR}/.gitignore" ]]; then
27 | echo '*' >"${OUTPUT_DIR}/.gitignore"
28 | fi
29 |
30 | cp -f "$0" "${OUTPUT_DIR}/script.sh"
31 |
32 | if [[ -z "${WANDB_API_KEY}" ]]; then
33 | export WANDB_MODE="offline"
34 | fi
35 |
36 | MASTER_PORT_START=10000
37 | MASTER_PORT_END=65535
38 | MASTER_PORT="$(
39 | comm -23 \
40 | <(seq "${MASTER_PORT_START}" "${MASTER_PORT_END}" | sort) \
41 | <(ss -Htan | awk '{ print $4 }' | awk -F ':' '{ print $NF }' | sort -u) |
42 | shuf | head -n 1
43 | )"
44 |
45 | DEEPSPEED_ARGS=()
46 | if [[ -n "${HOSTFILE+x}" ]]; then
47 | DEEPSPEED_ARGS+=("--hostfile" "${HOSTFILE}")
48 | fi
49 | DEEPSPEED_ARGS+=("--master_port" "${MASTER_PORT}")
50 |
51 | exec 1> >(tee "${OUTPUT_DIR}/stdout.log" >&1) 2> >(tee "${OUTPUT_DIR}/stderr.log" >&2)
52 |
53 | export NCCL_DEBUG=INFO
54 | export NCCL_DEBUG_SUBSYS=INIT,P2P
55 |
56 |
57 | gpu_vis=$1
58 |
59 |
60 | deepspeed --include localhost:$gpu_vis --master_port $MASTER_PORT \
61 | --module mcts_rl.algorithms.mcts \
62 | --train_datasets SQA/train \
63 | --eval_datasets MCQ/test \
64 | --actor_model_name_or_path "${ACTOR_MODEL_NAME_OR_PATH}" \
65 | --actor_ref_model_name_or_path "${ACTOR_MODEL_NAME_OR_PATH}" \
66 | --max_length 512 \
67 | --repetition_penalty 1.0 \
68 | --trust_remote_code True \
69 | --epochs 1 \
70 | --update_iters 1 \
71 | --save_interval 128 \
72 | --per_device_prompt_batch_size 1 \
73 | --per_device_train_batch_size 1 \
74 | --per_device_eval_batch_size 1 \
75 | --gradient_accumulation_steps 8 \
76 | --actor_lr 1e-6 \
77 | --actor_weight_decay 0.01 \
78 | --actor_lr_scheduler_type cosine \
79 | --actor_lr_warmup_ratio 0.03 \
80 | --actor_gradient_checkpointing \
81 | --need_eval \
82 | --seed 42 \
83 | --kl_coeff 0.02 \
84 | --clip_range_ratio 0.2 \
85 | --clip_range_score 50.0 \
86 | --clip_range_value 5.0 \
87 | --output_dir "${OUTPUT_DIR}" \
88 | --log_type wandb \
89 | --log_project MCTS-DPO-EVAL \
90 | --zero_stage "${ZERO_STAGE}" \
91 | --offload "${OFFLOAD}" \
92 | --bf16 True \
93 | --tf32 True \
94 | --max_new_tokens 32 \
95 | --depth_limit 3 \
96 | --n_init_actions 4 \
97 | --n_actions 3 \
98 | --n_iters 16 \
99 | --mcts_temperature 0.0 \
100 | --num_return_sequences 1 \
101 | --temperature 1.0 \
102 | --init_temperature 1.0 \
103 | --model_type mistral \
104 | --use_mcq \
105 | --prediction_file_path MCTS-DPO/outputs/checkpoints/sqa/predictions/mistral_${MODEL_NAME}_${STEP_NUM}.jsonl
106 |
--------------------------------------------------------------------------------
/scripts/mcts_csr.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | if [ -z "${BASH_VERSION}" ]; then
4 | echo "Please use bash to run this script." >&2
5 | exit 1
6 | fi
7 |
8 | set -x
9 |
10 | SCRIPT_DIR="$(cd "$(dirname "$0")" &>/dev/null && pwd)"
11 | ROOT_DIR="$(dirname "${SCRIPT_DIR}")"
12 | export PYTHONPATH="${ROOT_DIR}${PYTHONPATH:+:${PYTHONPATH}}"
13 | export LOGLEVEL="${LOGLEVEL:-WARNING}"
14 |
15 |
16 | ACTOR_MODEL_NAME_OR_PATH="upaya07/Arithmo2-Mistral-7B"
17 | ACTOR_REF_MODEL_NAME_OR_PATH="upaya07/Arithmo2-Mistral-7B"
18 |
19 | OUTPUT_DIR="MCTS-DPO/outputs/checkpoints/sqa/cdpo-4x2"
20 | unset HOSTFILE
21 | ZERO_STAGE=3
22 | OFFLOAD="optimizer"
23 |
24 |
25 | mkdir -p "${OUTPUT_DIR}"
26 | OUTPUT_DIR="$(cd "${OUTPUT_DIR}" &>/dev/null && pwd)"
27 | if [[ ! -f "${OUTPUT_DIR}/.gitignore" ]]; then
28 | echo '*' >"${OUTPUT_DIR}/.gitignore"
29 | fi
30 |
31 | cp -f "$0" "${OUTPUT_DIR}/script.sh"
32 |
33 | export WANDB_API_KEY=""
34 | export WANDB_MODE=online
35 | if [[ -z "${WANDB_API_KEY}" ]]; then
36 | export WANDB_MODE="offline"
37 | fi
38 |
39 | MASTER_PORT_START=10000
40 | MASTER_PORT_END=65535
41 | MASTER_PORT="$(
42 | comm -23 \
43 | <(seq "${MASTER_PORT_START}" "${MASTER_PORT_END}" | sort) \
44 | <(ss -Htan | awk '{ print $4 }' | awk -F ':' '{ print $NF }' | sort -u) |
45 | shuf | head -n 1
46 | )"
47 |
48 | DEEPSPEED_ARGS=()
49 | if [[ -n "${HOSTFILE+x}" ]]; then
50 | DEEPSPEED_ARGS+=("--hostfile" "${HOSTFILE}")
51 | fi
52 | DEEPSPEED_ARGS+=("--master_port" "${MASTER_PORT}")
53 |
54 | exec 1> >(tee "${OUTPUT_DIR}/stdout.log" >&1) 2> >(tee "${OUTPUT_DIR}/stderr.log" >&2)
55 |
56 | gpu_vis=$1
57 |
58 | deepspeed --include localhost:$gpu_vis --master_port $MASTER_PORT \
59 | --module mcts_rl.algorithms.mcts \
60 | --train_datasets SQA/train \
61 | --model_type mistral \
62 | --save_mcts_data \
63 | --choose_worst \
64 | --use_mcq \
65 | --filter \
66 | --conservative \
67 | --actor_model_name_or_path "${ACTOR_MODEL_NAME_OR_PATH}" \
68 | --actor_ref_model_name_or_path "${ACTOR_REF_MODEL_NAME_OR_PATH}" \
69 | --scale_coeff 0.1 \
70 | --max_length 512 \
71 | --temperature 1.0 \
72 | --init_temperature 1.0 \
73 | --num_return_sequences 1 \
74 | --repetition_penalty 1.0 \
75 | --mcts_length_penalty 1.25 \
76 | --trust_remote_code True \
77 | --epochs 1 \
78 | --update_iters 1 \
79 | --save_interval 64 \
80 | --per_device_ptx_batch_size 4 \
81 | --per_device_prompt_batch_size 1 \
82 | --per_device_train_batch_size 1 \
83 | --gradient_accumulation_steps 64 \
84 | --actor_lr 2e-6 \
85 | --actor_weight_decay 0.05 \
86 | --actor_lr_scheduler_type cosine \
87 | --actor_lr_warmup_ratio 0.2 \
88 | --actor_gradient_checkpointing \
89 | --seed 42 \
90 | --kl_coeff 0.02 \
91 | --clip_range_ratio 0.2 \
92 | --clip_range_score 50.0 \
93 | --clip_range_value 5.0 \
94 | --output_dir "${OUTPUT_DIR}" \
95 | --log_type wandb \
96 | --log_project MCTS-IPL-SQA \
97 | --zero_stage "${ZERO_STAGE}" \
98 | --offload "${OFFLOAD}" \
99 | --bf16 True \
100 | --tf32 True \
101 | --max_new_tokens 32 \
102 | --n_iters 64 \
103 | --depth_limit 4 \
104 | --n_init_actions 4 \
105 | --n_actions 2 \
106 | --mcts_temperature 0.0
107 |
--------------------------------------------------------------------------------
/scripts/mcts_mathqa.sh:
--------------------------------------------------------------------------------
1 | if [ -z "${BASH_VERSION}" ]; then
2 | echo "Please use bash to run this script." >&2
3 | exit 1
4 | fi
5 |
6 | set -x
7 |
8 | SCRIPT_DIR="$(cd "$(dirname "$0")" &>/dev/null && pwd)"
9 | ROOT_DIR="$(dirname "${SCRIPT_DIR}")"
10 | export PYTHONPATH="${ROOT_DIR}${PYTHONPATH:+:${PYTHONPATH}}"
11 | export LOGLEVEL="${LOGLEVEL:-WARNING}"
12 |
13 | ACTOR_MODEL_NAME_OR_PATH="SFT-Arithmo"
14 | ACTOR_REF_MODEL_NAME_OR_PATH="SFT-Arithmo"
15 |
16 | OUTPUT_DIR="MCTS-DPO/outputs/checkpoints/arithmetic/cdpo-2x2-gtsft"
17 | unset HOSTFILE
18 | ZERO_STAGE=3
19 | OFFLOAD="optimizer"
20 |
21 | mkdir -p "${OUTPUT_DIR}"
22 | OUTPUT_DIR="$(cd "${OUTPUT_DIR}" &>/dev/null && pwd)"
23 | if [[ ! -f "${OUTPUT_DIR}/.gitignore" ]]; then
24 | echo '*' >"${OUTPUT_DIR}/.gitignore"
25 | fi
26 |
27 | cp -f "$0" "${OUTPUT_DIR}/script.sh"
28 |
29 | export WANDB_API_KEY=""
30 | export WANDB_MODE=online
31 | if [[ -z "${WANDB_API_KEY}" ]]; then
32 | export WANDB_MODE="offline"
33 | fi
34 |
35 | MASTER_PORT_START=10000
36 | MASTER_PORT_END=65535
37 | MASTER_PORT="$(
38 | comm -23 \
39 | <(seq "${MASTER_PORT_START}" "${MASTER_PORT_END}" | sort) \
40 | <(ss -Htan | awk '{ print $4 }' | awk -F ':' '{ print $NF }' | sort -u) |
41 | shuf | head -n 1
42 | )"
43 |
44 | DEEPSPEED_ARGS=()
45 | if [[ -n "${HOSTFILE+x}" ]]; then
46 | DEEPSPEED_ARGS+=("--hostfile" "${HOSTFILE}")
47 | fi
48 | DEEPSPEED_ARGS+=("--master_port" "${MASTER_PORT}")
49 |
50 | exec 1> >(tee "${OUTPUT_DIR}/stdout.log" >&1) 2> >(tee "${OUTPUT_DIR}/stderr.log" >&2)
51 |
52 | gpu_vis=$1
53 |
54 | deepspeed --include localhost:$gpu_vis --master_port $MASTER_PORT \
55 | --module mcts_rl.algorithms.mcts \
56 | --train_datasets MathQA/train \
57 | --model_type mistral \
58 | --choose_worst \
59 | --save_mcts_data \
60 | --filter \
61 | --iteration_interval 64 \
62 | --actor_model_name_or_path "${ACTOR_MODEL_NAME_OR_PATH}" \
63 | --actor_ref_model_name_or_path "${ACTOR_REF_MODEL_NAME_OR_PATH}" \
64 | --scale_coeff 0.1 \
65 | --max_length 512 \
66 | --temperature 1.0 \
67 | --init_temperature 1.0 \
68 | --mcts_length_penalty 1.25 \
69 | --num_return_sequences 1 \
70 | --repetition_penalty 1.0 \
71 | --trust_remote_code True \
72 | --epochs 1 \
73 | --conservative \
74 | --update_iters 1 \
75 | --save_interval 64 \
76 | --per_device_ptx_batch_size 4 \
77 | --per_device_prompt_batch_size 1 \
78 | --per_device_train_batch_size 1 \
79 | --gradient_accumulation_steps 64 \
80 | --actor_lr 1e-6 \
81 | --actor_weight_decay 0.05 \
82 | --actor_lr_scheduler_type cosine \
83 | --actor_lr_warmup_ratio 0.03 \
84 | --actor_gradient_checkpointing \
85 | --seed 42 \
86 | --kl_coeff 0.02 \
87 | --clip_range_ratio 0.2 \
88 | --clip_range_score 50.0 \
89 | --clip_range_value 5.0 \
90 | --ptx_coeff 0.0 \
91 | --output_dir "${OUTPUT_DIR}" \
92 | --log_type wandb \
93 | --log_project MCTS-IPL-Math \
94 | --zero_stage "${ZERO_STAGE}" \
95 | --offload "${OFFLOAD}" \
96 | --bf16 True \
97 | --tf32 True \
98 | --max_new_tokens 128 \
99 | --n_iters 64 \
100 | --depth_limit 3 \
101 | --n_init_actions 2 \
102 | --n_actions 2 \
103 | --force_terminating_on_depth_limit \
104 | --mcts_temperature 0.0
--------------------------------------------------------------------------------
/scripts/mcts_mathqa2.sh:
--------------------------------------------------------------------------------
1 | if [ -z "${BASH_VERSION}" ]; then
2 | echo "Please use bash to run this script." >&2
3 | exit 1
4 | fi
5 |
6 | set -x
7 |
8 | SCRIPT_DIR="$(cd "$(dirname "$0")" &>/dev/null && pwd)"
9 | ROOT_DIR="$(dirname "${SCRIPT_DIR}")"
10 | export PYTHONPATH="${ROOT_DIR}${PYTHONPATH:+:${PYTHONPATH}}"
11 | export LOGLEVEL="${LOGLEVEL:-WARNING}"
12 |
13 | ACTOR_MODEL_NAME_OR_PATH="SFT-Arithmo"
14 | ACTOR_REF_MODEL_NAME_OR_PATH="SFT-Arithmo"
15 |
16 | OUTPUT_DIR="MCTS-DPO/outputs/checkpoints/arithmetic/cdpo-2x2-nogt"
17 | unset HOSTFILE
18 | ZERO_STAGE=3
19 | OFFLOAD="optimizer"
20 |
21 | mkdir -p "${OUTPUT_DIR}"
22 | OUTPUT_DIR="$(cd "${OUTPUT_DIR}" &>/dev/null && pwd)"
23 | if [[ ! -f "${OUTPUT_DIR}/.gitignore" ]]; then
24 | echo '*' >"${OUTPUT_DIR}/.gitignore"
25 | fi
26 |
27 | cp -f "$0" "${OUTPUT_DIR}/script.sh"
28 |
29 | export WANDB_API_KEY=""
30 | export WANDB_MODE=online
31 | if [[ -z "${WANDB_API_KEY}" ]]; then
32 | export WANDB_MODE="offline"
33 | fi
34 |
35 | MASTER_PORT_START=10000
36 | MASTER_PORT_END=65535
37 | MASTER_PORT="$(
38 | comm -23 \
39 | <(seq "${MASTER_PORT_START}" "${MASTER_PORT_END}" | sort) \
40 | <(ss -Htan | awk '{ print $4 }' | awk -F ':' '{ print $NF }' | sort -u) |
41 | shuf | head -n 1
42 | )"
43 |
44 | DEEPSPEED_ARGS=()
45 | if [[ -n "${HOSTFILE+x}" ]]; then
46 | DEEPSPEED_ARGS+=("--hostfile" "${HOSTFILE}")
47 | fi
48 | DEEPSPEED_ARGS+=("--master_port" "${MASTER_PORT}")
49 |
50 | exec 1> >(tee "${OUTPUT_DIR}/stdout.log" >&1) 2> >(tee "${OUTPUT_DIR}/stderr.log" >&2)
51 |
52 | gpu_vis=$1
53 |
54 | deepspeed --include localhost:$gpu_vis --master_port $MASTER_PORT \
55 | --module mcts_rl.algorithms.mcts \
56 | --train_datasets MathQA/train \
57 | --model_type mistral \
58 | --choose_worst \
59 | --save_mcts_data \
60 | --filter \
61 | --not_include_gt \
62 | --iteration_interval 64 \
63 | --actor_model_name_or_path "${ACTOR_MODEL_NAME_OR_PATH}" \
64 | --actor_ref_model_name_or_path "${ACTOR_REF_MODEL_NAME_OR_PATH}" \
65 | --scale_coeff 0.1 \
66 | --max_length 512 \
67 | --temperature 1.0 \
68 | --init_temperature 1.0 \
69 | --mcts_length_penalty 1.25 \
70 | --num_return_sequences 1 \
71 | --repetition_penalty 1.0 \
72 | --trust_remote_code True \
73 | --epochs 1 \
74 | --conservative \
75 | --update_iters 1 \
76 | --save_interval 64 \
77 | --per_device_ptx_batch_size 4 \
78 | --per_device_prompt_batch_size 1 \
79 | --per_device_train_batch_size 1 \
80 | --gradient_accumulation_steps 64 \
81 | --actor_lr 1e-6 \
82 | --actor_weight_decay 0.05 \
83 | --actor_lr_scheduler_type cosine \
84 | --actor_lr_warmup_ratio 0.03 \
85 | --actor_gradient_checkpointing \
86 | --seed 42 \
87 | --kl_coeff 0.02 \
88 | --clip_range_ratio 0.2 \
89 | --clip_range_score 50.0 \
90 | --clip_range_value 5.0 \
91 | --ptx_coeff 0.0 \
92 | --output_dir "${OUTPUT_DIR}" \
93 | --log_type wandb \
94 | --log_project MCTS-IPL-Math \
95 | --zero_stage "${ZERO_STAGE}" \
96 | --offload "${OFFLOAD}" \
97 | --bf16 True \
98 | --tf32 True \
99 | --max_new_tokens 128 \
100 | --n_iters 64 \
101 | --depth_limit 3 \
102 | --n_init_actions 2 \
103 | --n_actions 2 \
104 | --force_terminating_on_depth_limit \
105 | --mcts_temperature 0.0
--------------------------------------------------------------------------------
/scripts/mcts_mathqa_llama3.sh:
--------------------------------------------------------------------------------
1 | if [ -z "${BASH_VERSION}" ]; then
2 | echo "Please use bash to run this script." >&2
3 | exit 1
4 | fi
5 |
6 | set -x
7 |
8 | SCRIPT_DIR="$(cd "$(dirname "$0")" &>/dev/null && pwd)"
9 | ROOT_DIR="$(dirname "${SCRIPT_DIR}")"
10 | export PYTHONPATH="${ROOT_DIR}${PYTHONPATH:+:${PYTHONPATH}}"
11 | export LOGLEVEL="${LOGLEVEL:-WARNING}"
12 |
13 | ACTOR_MODEL_NAME_OR_PATH="meta-llama/Meta-Llama-3-8B-Instruct"
14 | ACTOR_REF_MODEL_NAME_OR_PATH="meta-llama/Meta-Llama-3-8B-Instruct"
15 |
16 | OUTPUT_DIR="MCTS-DPO/outputs/checkpoints/arithmetic/llama3-cdpo-2x2-gtsft"
17 | unset HOSTFILE
18 | ZERO_STAGE=3
19 | OFFLOAD="optimizer"
20 |
21 |
22 | mkdir -p "${OUTPUT_DIR}"
23 | OUTPUT_DIR="$(cd "${OUTPUT_DIR}" &>/dev/null && pwd)"
24 | if [[ ! -f "${OUTPUT_DIR}/.gitignore" ]]; then
25 | echo '*' >"${OUTPUT_DIR}/.gitignore"
26 | fi
27 |
28 | cp -f "$0" "${OUTPUT_DIR}/script.sh"
29 |
30 | export WANDB_API_KEY=""
31 | export WANDB_MODE=online
32 | if [[ -z "${WANDB_API_KEY}" ]]; then
33 | export WANDB_MODE="offline"
34 | fi
35 |
36 | MASTER_PORT_START=10000
37 | MASTER_PORT_END=65535
38 | MASTER_PORT="$(
39 | comm -23 \
40 | <(seq "${MASTER_PORT_START}" "${MASTER_PORT_END}" | sort) \
41 | <(ss -Htan | awk '{ print $4 }' | awk -F ':' '{ print $NF }' | sort -u) |
42 | shuf | head -n 1
43 | )"
44 |
45 | DEEPSPEED_ARGS=()
46 | if [[ -n "${HOSTFILE+x}" ]]; then
47 | DEEPSPEED_ARGS+=("--hostfile" "${HOSTFILE}")
48 | fi
49 | DEEPSPEED_ARGS+=("--master_port" "${MASTER_PORT}")
50 |
51 | exec 1> >(tee "${OUTPUT_DIR}/stdout.log" >&1) 2> >(tee "${OUTPUT_DIR}/stderr.log" >&2)
52 |
53 | gpu_vis=$1
54 |
55 | deepspeed --include localhost:$gpu_vis --master_port $MASTER_PORT \
56 | --module mcts_rl.algorithms.mcts \
57 | --train_datasets MathQA/train \
58 | --model_type llama3 \
59 | --choose_worst \
60 | --save_mcts_data \
61 | --filter \
62 | --iteration_interval 64 \
63 | --actor_model_name_or_path "${ACTOR_MODEL_NAME_OR_PATH}" \
64 | --actor_ref_model_name_or_path "${ACTOR_REF_MODEL_NAME_OR_PATH}" \
65 | --scale_coeff 0.1 \
66 | --max_length 512 \
67 | --temperature 1.0 \
68 | --init_temperature 1.0 \
69 | --mcts_length_penalty 1.25 \
70 | --num_return_sequences 1 \
71 | --repetition_penalty 1.0 \
72 | --trust_remote_code True \
73 | --epochs 1 \
74 | --conservative \
75 | --update_iters 1 \
76 | --save_interval 64 \
77 | --per_device_ptx_batch_size 4 \
78 | --per_device_prompt_batch_size 1 \
79 | --per_device_train_batch_size 1 \
80 | --gradient_accumulation_steps 64 \
81 | --actor_lr 1e-6 \
82 | --actor_weight_decay 0.05 \
83 | --actor_lr_scheduler_type cosine \
84 | --actor_lr_warmup_ratio 0.03 \
85 | --actor_gradient_checkpointing \
86 | --seed 42 \
87 | --kl_coeff 0.02 \
88 | --clip_range_ratio 0.2 \
89 | --clip_range_score 50.0 \
90 | --clip_range_value 5.0 \
91 | --ptx_coeff 0.0 \
92 | --output_dir "${OUTPUT_DIR}" \
93 | --log_type wandb \
94 | --log_project MCTS-IPL-Math \
95 | --zero_stage "${ZERO_STAGE}" \
96 | --offload "${OFFLOAD}" \
97 | --bf16 True \
98 | --tf32 True \
99 | --max_new_tokens 128 \
100 | --n_iters 64 \
101 | --depth_limit 3 \
102 | --n_init_actions 2 \
103 | --n_actions 2 \
104 | --force_terminating_on_depth_limit \
105 | --mcts_temperature 0.0
--------------------------------------------------------------------------------