├── version.txt ├── requirements.txt ├── src ├── config │ ├── config.yaml │ └── trainer │ │ └── default.yaml ├── loss.py ├── utils.py ├── trainer.py ├── datacollator.py ├── model.py └── dataset.py ├── pyproject.toml ├── README.md ├── .pre-commit-config.yaml ├── .github └── workflows │ └── ci.yml ├── .gitignore └── LICENSE /version.txt: -------------------------------------------------------------------------------- 1 | 0.0.1 2 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | datasets>=2.11.0 2 | hydra-core>=1.3.2 3 | tokenizers>=0.13.2 4 | torch>=2.0.0 5 | transformers>=4.27.4 6 | wandb>=0.14.0 7 | -------------------------------------------------------------------------------- /src/config/config.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - trainer: default 3 | model: "EleutherAI/pythia-160m" 4 | log_dir: "safetyfiles" 5 | log_wandb: true 6 | run_name: neox-160m-3epochs 7 | wandb_entity: "shahules786" 8 | max_length: 512 9 | per_digit_tokens: False 10 | special_tokens: 11 | eos_token: "" 12 | sep_token: "" 13 | pad_token: "" 14 | datasets: 15 | - hf_summary: 16 | split: ["validation","test"] 17 | 18 | - webgpt: 19 | split: "train" 20 | 21 | validation_size: 0.15 22 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "Reward-Model" 3 | description = "Reward Model training for LLM alignment" 4 | version = "0.0.1" 5 | authors = [ 6 | { name = "Exploding gradients", email = "shahules786@gmail.com" } 7 | ] 8 | readme = "README.md" 9 | dependencies = [ 10 | "datasets>=2.11.0", 11 | "torch>=2.0.0", 12 | "transformers>=4.27.4", 13 | "hydra-core>=1.3.2", 14 | "tokenizers>=0.13.2", 15 | "wandb>=0.14.0", 16 | 17 | ] 18 | 19 | 
[build-system] 20 | build-backend = "flit_core.buildapi" 21 | requires = ["flit_core >=3.2,<4"] 22 | -------------------------------------------------------------------------------- /src/config/trainer/default.yaml: -------------------------------------------------------------------------------- 1 | _target_: transformers.TrainingArguments 2 | output_dir: "." 3 | learning_rate: 1e-3 4 | gradient_checkpointing: false 5 | gradient_accumulation_steps: 1 6 | per_device_train_batch_size: 4 7 | per_device_eval_batch_size: 4 8 | adam_beta1: 0.9 9 | adam_beta2: 0.95 10 | adam_epsilon: 1e-12 11 | weight_decay: 0.001 12 | eval_steps: 10 13 | save_steps: 10 14 | num_train_epochs: 3 15 | logging_steps: 10 16 | max_grad_norm: 1.0 17 | save_total_limit: 4 18 | fp16: false 19 | lr_scheduler_type: cosine 20 | warmup_ratio: 0.15 21 | evaluation_strategy: steps 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Reward-Model 2 | Reward Model training framework for LLM RLHF. For an in-depth understanding of reward modeling, check out our [blog](https://explodinggradients.com/). 3 | The word nemesis originally meant the distributor of fortune, neither good nor bad, simply in due proportion to each according to what was deserved. 4 | ### Quick Start 5 | * Inference 6 | ```python 7 | from transformers import AutoModelForSequenceClassification, AutoTokenizer 8 | MODEL = "shahules786/Reward-model-gptneox-410M" 9 | 10 | model = AutoModelForSequenceClassification.from_pretrained(MODEL) 11 | tokenizer = AutoTokenizer.from_pretrained(MODEL) 12 | 13 | ``` 14 | 15 | * Training 16 | ```bash 17 | python src/trainer.py --config-name config 18 | ``` 19 | 20 | 21 | 22 | ## Contributions 23 | * All contributions are welcome. 
class RMLoss(nn.Module):
    """Pairwise ranking loss for reward-model training.

    The scores of a batch are stacked along dimension 0 and partitioned into
    groups by ``k_lens`` (cumulative offsets).  Within a group every index
    pair ``(i, j)`` with ``i < j`` is compared, the earlier item being
    treated as the preferred one.  Each pair contributes

        -logsigmoid(s_i - s_j) + beta * 0.5 * (s_i ** 2 + s_j ** 2)

    and the pair losses of a group are averaged into one value per group.

    Args:
        reduction: ``"mean"`` averages the per-group losses into a scalar;
            any other value returns the per-group losses as a 1-D tensor.
        beta: weight of the L2 penalty on the raw scores.
    """

    def __init__(
        self,
        reduction=None,
        beta=0.001,
    ):
        super().__init__()
        self.reduction = reduction
        self.beta = beta

    def forward(
        self,
        logits,
        k_lens=None,
    ):
        """Compute the ranking loss.

        Args:
            logits: scores indexed along dimension 0.
            k_lens: cumulative group offsets, e.g. ``[0, 2, 5]`` for a group
                of two followed by a group of three.  ``None`` treats the
                whole batch as a single group.

        Returns:
            A scalar when ``reduction == "mean"``, otherwise a 1-D tensor
            holding one loss per group that contains at least two items.
            Groups with fewer than two items have no pair to compare and are
            skipped.
        """
        if k_lens is None:
            k_lens = [0, logits.shape[0]]

        group_losses = []
        for start, end in zip(k_lens[:-1], k_lens[1:]):
            # A group needs two items to form a pair; averaging over an empty
            # pair set would yield NaN and poison the whole batch loss.
            if end - start < 2:
                continue
            pairs = torch.combinations(
                torch.arange(start, end, device=logits.device), 2
            )
            positive = logits[pairs[:, 0]]
            negative = logits[pairs[:, 1]]
            l2 = 0.5 * (positive**2 + negative**2)
            pair_loss = -nn.functional.logsigmoid(positive - negative) + self.beta * l2
            group_losses.append(pair_loss.mean())

        if not group_losses:
            if self.reduction == "mean":
                # Zero loss that is still attached to the graph, so a
                # subsequent ``backward()`` call does not fail.
                return logits.sum() * 0.0
            return logits.new_zeros((0,))

        stacked = torch.stack(group_losses)
        if self.reduction == "mean":
            return stacked.mean()
        return stacked
def get_tokenizer(config):
    """Load the tokenizer for ``config.model`` and register the special tokens.

    Args:
        config: configuration object exposing ``model``, ``special_tokens``
            (with ``pad_token``, ``eos_token`` and ``sep_token``) and,
            optionally, ``per_digit_tokens``.

    Returns:
        The tokenizer with the configured special tokens and the role
        markers from ``SPECIAL_TOKENS`` added.
    """
    tokenizer = AutoTokenizer.from_pretrained(config.model)

    if hasattr(config, "per_digit_tokens") and config.per_digit_tokens:
        # The backend attribute is ``pre_tokenizer``; the previously used
        # ``pre_processor`` does not exist, so digit splitting was never
        # applied.  Digit splitting is chained after the existing
        # pre-tokenizer instead of replacing it, so the model's own
        # pre-tokenization step is kept.
        backend = tokenizer.backend_tokenizer
        digits = pre_tokenizers.Digits(individual_digits=True)
        existing = backend.pre_tokenizer
        backend.pre_tokenizer = (
            digits
            if existing is None
            else pre_tokenizers.Sequence([existing, digits])
        )

    if config.special_tokens:
        tokenizer.add_special_tokens(
            {
                "pad_token": config.special_tokens.pad_token,
                "eos_token": config.special_tokens.eos_token,
                "sep_token": config.special_tokens.sep_token,
            }
        )

    tokenizer.add_special_tokens(
        {"additional_special_tokens": list(SPECIAL_TOKENS.values())}
    )

    return tokenizer
class RMTrainer(Trainer):
    """``transformers.Trainer`` that optimises the pairwise ``RMLoss``.

    Batches are expected to carry a ``"k_lens"`` entry next to the model
    inputs; it is handed to the loss and never forwarded to the model.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.loss = RMLoss(reduction="mean")

    def compute_loss(
        self, model, inputs, return_outputs=False, num_items_in_batch=None
    ):
        """Run the model and return the ranking loss.

        Args:
            model: the reward model; its output must expose ``.logits``.
            inputs: batch mapping containing ``"k_lens"`` plus model inputs.
            return_outputs: when true, return ``(loss, logits)``.
            num_items_in_batch: accepted for compatibility with newer
                ``transformers`` releases that pass it; not used here.
        """
        k_lens = inputs["k_lens"]
        # Build a filtered copy instead of popping so the caller's batch is
        # left intact.
        model_inputs = {
            key: value for key, value in inputs.items() if key != "k_lens"
        }
        logits = model(**model_inputs).logits
        loss = self.loss(logits, k_lens)
        return (loss, logits) if return_outputs else loss

    def prediction_step(
        self,
        model: nn.Module,
        inputs: Dict[str, Union[torch.Tensor, Any]],
        prediction_loss_only: bool,
        ignore_keys: Optional[List[str]] = None,
    ):
        """Evaluate one batch and return ``(loss, logits, labels)``.

        ``labels`` is always ``None``; ``logits`` is ``None`` as well when
        ``prediction_loss_only`` is set.
        """
        # The base implementation moves the batch to the model device inside
        # ``prediction_step``; an override has to do the same.
        inputs = self._prepare_inputs(inputs)
        with torch.no_grad():
            loss, logits = self.compute_loss(model, inputs, return_outputs=True)

        if prediction_loss_only:
            return (loss, None, None)
        return (loss, logits, None)
@dataclass
class RMDataCollator:
    """Collate ``(prompts, answers)`` examples into padded model inputs.

    Every answer of an example becomes one row: the formatted dialogue
    prefix followed by the formatted answer, cut to ``max_length`` tokens
    and padded with the tokenizer's pad token.
    """

    tokenizer: PreTrainedTokenizer
    max_length: int = 512

    def format_prefix(self, prompts, eos):
        """Join dialogue turns, alternating prompter and assistant markers."""
        turns = []
        for i, prompt in enumerate(prompts):
            role = "prompter" if i % 2 == 0 else "assistant"
            turns.append(f"{SPECIAL_TOKENS[role]}{prompt}{eos}")
        return "".join(turns)

    def format_suffix(self, answer, eos):
        """Format a candidate answer as an assistant turn."""
        return "{}{}{}".format(SPECIAL_TOKENS["assistant"], answer, eos)

    def process_example(self, example):
        """Tokenize one example.

        Args:
            example: ``(prompts, answers)`` pair.

        Returns:
            ``(input_ids, attention_masks)``: two lists with one
            ``max_length``-long row per answer.
        """
        eos = self.tokenizer.eos_token
        prompts, answers = example
        prefix_tokens = self.tokenizer.encode(self.format_prefix(prompts, eos))

        input_ids, attention_masks = [], []
        for answer in answers:
            answer_tokens = self.tokenizer.encode(self.format_suffix(answer, eos))
            # Drop the oldest prefix tokens when the row is too long.  The
            # overflow is computed per answer against the untouched prefix,
            # so one long answer no longer shortens the context of the
            # answers that follow it.
            overflow = max(
                0, len(prefix_tokens) + len(answer_tokens) - self.max_length
            )
            tokens = (prefix_tokens[overflow:] + answer_tokens)[: self.max_length]
            pad_len = self.max_length - len(tokens)
            attention_masks.append([1] * len(tokens) + [0] * pad_len)
            input_ids.append(tokens + [self.tokenizer.pad_token_id] * pad_len)
        return input_ids, attention_masks

    def __call__(self, examples):
        """Build a batch.

        Returns:
            Dict with ``input_ids`` and ``attention_mask`` tensors and
            ``k_lens``, the cumulative row offsets delimiting the answers of
            each example (starting at 0).
        """
        batch_k_lens = [0]
        input_ids, attention_masks = [], []
        for example in examples:
            inp_ids, attn_masks = self.process_example(example)
            input_ids.extend(inp_ids)
            attention_masks.extend(attn_masks)
            batch_k_lens.append(batch_k_lens[-1] + len(inp_ids))

        return {
            "input_ids": torch.tensor(input_ids),
            "attention_mask": torch.tensor(attention_masks),
            "k_lens": batch_k_lens,
        }
41 | 42 | def __init__( 43 | self, 44 | config, 45 | ): 46 | super().__init__(config) 47 | self.gpt_neox = GPTNeoXModel(config) 48 | self.pooling = config.pooling 49 | hidden_size = ( 50 | config.hidden_size if self.pooling != "mean-max" else config.hidden_size * 2 51 | ) 52 | self.out_layer = nn.Linear(hidden_size, 1) 53 | 54 | def forward( 55 | self, 56 | input_ids, 57 | attention_mask, 58 | **kwargs, 59 | ): 60 | return_dict = ( 61 | kwargs.get("return_dict") 62 | if kwargs.get("return_dict") is not None 63 | else self.config.use_return_dict 64 | ) 65 | outputs = self.gpt_neox( 66 | input_ids, 67 | attention_mask, 68 | return_dict=return_dict, 69 | **kwargs, 70 | ) 71 | hidden_states = outputs[0] 72 | if self.pooling == "mean": 73 | if attention_mask is None: 74 | hidden_states = hidden_states.mean(dim=1) 75 | else: 76 | hidden_states = (hidden_states * attention_mask.unsqueeze(-1)).sum( 77 | dim=1 78 | ) / attention_mask.sum(dim=1).unsqueeze(-1) 79 | elif self.pooling == "last": 80 | if attention_mask is None: 81 | hidden_states = hidden_states[:, -1, :] 82 | else: 83 | last_idx = attention_mask.cumsum(1).argmax(1) 84 | last_idx = last_idx.view(-1, 1, 1).expand(-1, 1, hidden_states.size(-1)) 85 | hidden_states = torch.gather(hidden_states, 1, last_idx).squeeze(1) 86 | elif self.pooling == "mean-max": 87 | if attention_mask is None: 88 | mean, max = hidden_states.mean(dim=1), hidden_states.max(dim=1).values 89 | hidden_states = torch.cat([mean, max], 1) 90 | else: 91 | mean = (hidden_states * attention_mask.unsqueeze(-1)).sum( 92 | dim=1 93 | ) / attention_mask.sum(dim=1).unsqueeze(-1) 94 | max = (hidden_states * attention_mask.unsqueeze(-1)).max(dim=1).values 95 | hidden_states = torch.cat([mean, max], 1) 96 | else: 97 | raise ValueError(f"invalid pooling {self.pooling}") 98 | 99 | lm_logits = self.out_layer(hidden_states) 100 | 101 | if not return_dict: 102 | return (lm_logits,) + outputs[1:] 103 | 104 | return GPTNeoxRMOuptput(logits=lm_logits) 105 | 106 | 
107 | AutoConfig.register("rm_gptneox_config", GPTNeoXConfigRM) 108 | AutoModelForSequenceClassification.register(GPTNeoXConfigRM, GPTNeoXRM) 109 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is 
recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
class HFSummary(Dataset):
    """Ranked summaries from ``openai/summarize_from_feedback`` ("axis" config).

    Each item is ``([instruction], summaries)`` where ``summaries`` are the
    summary texts of one post ordered from highest to lowest "overall"
    rating.
    """

    name = "openai/summarize_from_feedback"

    def __init__(self, split: Union[List[str], str] = "train"):
        """Load the requested split(s) and group the summaries by post.

        Args:
            split: one split name or a list of them; an OmegaConf list
                coming from the Hydra config is accepted as well.
        """
        super().__init__()
        # Config containers are ListConfig/DictConfig instances, never
        # instances of the ``OmegaConf`` helper class, so ``isinstance``
        # against it can not detect them.
        if OmegaConf.is_config(split):
            split = OmegaConf.to_object(split)
        if isinstance(split, str):
            split = [split]
        self.split = split
        dataset = load_dataset(self.name, "axis", split=self.split)
        self.data_dict = self.prepare_axis(dataset)
        self.postids = list(self.data_dict.keys())

    def prepare_axis(self, dataset):
        """Group rated summaries by post id.

        Args:
            dataset: iterable of splits, each an iterable of rows with
                ``info`` (``id``, ``post``, ``article``) and ``summary``
                (``text``, ``axes``) entries.

        Returns:
            Mapping ``post id -> {"post": instruction, "summaries": [...]}``;
            rows without an "overall" rating are ignored.
        """
        data_dict = defaultdict(dict)
        for split_rows in dataset:
            for item in split_rows:
                if item["summary"]["axes"].get("overall") is None:
                    continue
                postid = item["info"]["id"]
                summary = {k: item["summary"][k] for k in ("text", "axes")}
                if postid in data_dict:
                    data_dict[postid]["summaries"].append(summary)
                else:
                    instruction = "summarize: " + (
                        item["info"]["post"] or item["info"]["article"]
                    )
                    data_dict[postid].update(
                        {"post": instruction, "summaries": [summary]}
                    )

        return data_dict

    def __len__(self):
        return len(self.postids)

    def __getitem__(self, idx):
        """Return ``([instruction], summaries)`` for the ``idx``-th post."""
        entry = self.data_dict[self.postids[idx]]
        ranked = sorted(
            entry["summaries"], key=lambda s: s["axes"]["overall"], reverse=True
        )
        # Keep a single text per rating (the last one seen wins), then drop
        # repeated texts while preserving the ranking order.
        text_by_rating = {s["axes"]["overall"]: s["text"] for s in ranked}
        summaries = list(dict.fromkeys(text_by_rating.values()))
        return [entry["post"]], summaries
class WebGPT(Dataset):
    """Ranked answers from ``openai/webgpt_comparisons``.

    Each item is ``([question], answers)``.  Every non-tied comparison of a
    question appends its preferred answer followed by the rejected one.
    """

    # NOTE(review): answers from several comparisons of the same question
    # end up in one list; whether consumers may treat that whole list as a
    # single ranking should be confirmed.

    name = "openai/webgpt_comparisons"

    def __init__(self, split: str = "train", dataset=None):
        """Build the question -> answers mapping.

        Args:
            split: split passed to ``load_dataset``.
            dataset: optional iterable of already loaded comparison rows;
                when given, nothing is downloaded.
        """
        super().__init__()
        self.split = split
        if dataset is None:
            dataset = load_dataset(self.name, split=self.split)

        self.dataset_dict = defaultdict(dict)
        for item in dataset:
            if item["score_0"] > 0:
                ranked = [item["answer_0"], item["answer_1"]]
            elif item["score_0"] < 0:
                ranked = [item["answer_1"], item["answer_0"]]
            else:
                # A tie carries no preference.  Skip just this comparison;
                # removing the whole question would also throw away answers
                # collected from its other comparisons.
                continue

            post_id = item["question"]["id"]
            if post_id not in self.dataset_dict:
                self.dataset_dict[post_id] = {
                    "full_text": item["question"]["full_text"],
                    "answers": [],
                }
            self.dataset_dict[post_id]["answers"].extend(
                self._clean_answer(answer) for answer in ranked
            )

        self.post_ids = list(self.dataset_dict.keys())

    @staticmethod
    def _clean_answer(answer):
        """Strip ``[n]`` citation markers and whitespace around sentences."""
        answer = re.sub(r"\[\d+\]", "", answer)
        return ".".join(sent.strip() for sent in answer.split("."))

    def __len__(self):
        return len(self.post_ids)

    def __getitem__(self, idx):
        """Return ``([question], answers)`` for the ``idx``-th question."""
        entry = self.dataset_dict[self.post_ids[idx]]
        return [entry["full_text"]], entry["answers"]
{"prompt": dialogs, "answers": [item["chosen"], item["rejected"]]} 120 | ) 121 | id += 1 122 | 123 | self.prompt_ids = list(self.data_dict.keys()) 124 | 125 | def __len__( 126 | self, 127 | ): 128 | return len(self.prompt_ids) 129 | 130 | def __getitem__(self, idx): 131 | prompt, answers = self.data_dict.get(self.prompt_ids[idx], {}).values() 132 | return prompt, answers 133 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 
29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 
61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | --------------------------------------------------------------------------------