├── .gitignore ├── LICENSE ├── README.md ├── load_data.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # UV 98 | # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | #uv.lock 102 | 103 | # poetry 104 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 105 | # This is especially recommended for binary packages to ensure reproducibility, and is more 106 | # commonly ignored for libraries. 107 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 108 | #poetry.lock 109 | 110 | # pdm 111 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 112 | #pdm.lock 113 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 114 | # in version control. 115 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 116 | .pdm.toml 117 | .pdm-python 118 | .pdm-build/ 119 | 120 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 121 | __pypackages__/ 122 | 123 | # Celery stuff 124 | celerybeat-schedule 125 | celerybeat.pid 126 | 127 | # SageMath parsed files 128 | *.sage.py 129 | 130 | # Environments 131 | .env 132 | .venv 133 | env/ 134 | venv/ 135 | ENV/ 136 | env.bak/ 137 | venv.bak/ 138 | 139 | # Spyder project settings 140 | .spyderproject 141 | .spyproject 142 | 143 | # Rope project settings 144 | .ropeproject 145 | 146 | # mkdocs documentation 147 | /site 148 | 149 | # mypy 150 | .mypy_cache/ 151 | .dmypy.json 152 | dmypy.json 153 | 154 | # Pyre type checker 155 | .pyre/ 156 | 157 | # pytype static type analyzer 158 | .pytype/ 159 | 160 | # Cython debug symbols 161 | cython_debug/ 162 | 163 | # PyCharm 164 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 165 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 166 | # and can be added to the global gitignore or merged into this file. For a more nuclear 167 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
168 | #.idea/ 169 | 170 | # PyPI configuration file 171 | .pypirc 172 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 Hao Sun 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Embedding-based LLM Alignment: 2 | ## A Minimalist, Efficient, and Effective Infrastructure of Reward Modeling Research. 
3 | Codebase for report _[Reusing Embeddings: Reproducible Reward Model Research in Large Language Model Alignment without GPUs](https://arxiv.org/pdf/2502.04357)_ 4 | 5 | 6 | ----- 7 | ### 🚀 Example Usage 8 | 9 | ```python3 10 | # Specify Task, Embedding Model, Response Generation Model 11 | args.task = 'Harmless' 12 | args.res_gen_model = 'Gemma2b-sft' 13 | args.embed_model = 'Gemma2b' 14 | 15 | # Load Training Data 16 | train_embeddings, train_rewards = load_embd_data(task=args.task, res_gen_model=args.res_gen_model, embed_model=args.embed_model, split='train') 17 | ### train_embeddings.shape = (40000, 10, 2048), 40000 prompts, 10 responses for each prompt, Gemma2b has a 2048-dim embedding space 18 | ### train_rewards.shape = (40000, 10, 1), corresponding reward 19 | 20 | # Load Testing Data 21 | test_embeddings, test_rewards = load_embd_data(task=args.task, res_gen_model=args.res_gen_model, embed_model=args.embed_model, split='test') 22 | ### test_embeddings.shape = (2000, 500, 2048) 23 | ### test_rewards.shape = (2000, 500, 1) 24 | 25 | # Generation of Pairwise Comparisons 26 | train_comparisons, train_labels = pair_annotate(train_embeddings, train_rewards, annotation_quality = 0.1) 27 | # annotation noise can be adjusted through "annotation_quality" 28 | 29 | # Train Embedding-based Reward Model (e.g., use a Bradley-Terry MLP) 30 | reward_model = BT_MLP() 31 | reward_model.fit(train_comparisons, train_labels) 32 | 33 | # Make Predictions with the Reward Model on Testset 34 | rm_predictions = reward_model.predict(test_embeddings) 35 | print(rm_predictions.shape) 36 | ### (2000, 500, 1) 37 | 38 | # Calculate Evaluation Metrics on Testset 39 | bon_500 = calc_bon(rm_predictions, test_rewards, N=500) 40 | spearmanr = calc_spearmanr(rm_predictions, test_rewards) 41 | ``` 42 | 43 | ---- 44 | ### 🔨 Build (TBD) 45 | 46 | ```python3 47 | pip install 48 | ``` 49 | 50 | ---- 51 | ### 📊 Embedding Data Downloading 52 | Here is a Google Drive link for **a single experiment 
setup**, which is about 10GB. It can be used for a quick start/reproduction: 53 | 54 | [Google Drive Link (10GB)](https://drive.google.com/drive/folders/1Op0B1jc4Zr6t6DFWyLcCulpq67CJOYsU?usp=sharing) 55 | 56 | The full 300GB embedding files can be found at: 57 | 58 | [Google Drive Link (300GB)](https://drive.google.com/drive/folders/1cRiwvZDxlq_5DVHBIIVYjeunse42ALMO?usp=sharing) 59 | 60 | 61 | 62 | --- 63 | ### Demonstrative Use Cases (TBD) 64 | 65 | #### 1. A Quick Implementation of Reward Model Ensembles 66 | [This Repo.](https://github.com/holarissun/RewardModelingBeyondBradleyTerry/blob/c09e0971d360546e41c64879f7cd343ae0f845e6/step5_train_rms.py#L28C24-L28C39) 67 | 68 | #### 2. A Quick Implementation of Active Reward Modeling 69 | [This Repo.](https://github.com/YunyiShen/ARM-FI) 70 | 71 | #### 3. A Quick Implementation of Classification-based Reward Models 72 | [This Repo.](https://github.com/holarissun/RewardModelingBeyondBradleyTerry) 73 | 74 | #### 4. Exciting Future Work! 75 | - (Input) More RM data formats other than (pairwise) preferences? 76 | - (Input) Optimizing the embeddings for discriminative tasks? 77 | - (Objective) Beyond order consistency --- partial order consistency? 
# ---------------------------------------------------------------------------
# /load_data.py
# ---------------------------------------------------------------------------
import os

import numpy as np


def load_embd_data(task, res_gen_model, embed_model, split,
                   data_dir="../embd", mmap_mode=None):
    """Load the pre-computed embedding and reward arrays for one experiment split.

    Two ``.npy`` files are read from ``data_dir``:
    ``embd_{res_gen_model}_{embed_model}_{task}_{split}.npy`` and
    ``reward_{res_gen_model}_{embed_model}_{task}_{split}.npy``.

    Args:
        task: Task name used in the file name (the README example uses
            ``'Harmless'``).
        res_gen_model: Name of the response-generation model.
        embed_model: Name of the embedding model.
        split: Dataset split used in the file name (the README example uses
            ``'train'`` and ``'test'``).
        data_dir: Directory holding the ``.npy`` files. Defaults to the
            historical hard-coded location ``"../embd"``.
        mmap_mode: Forwarded to ``np.load``. Pass e.g. ``'r'`` to memory-map
            the arrays instead of reading them fully into RAM; ``None``
            (default) keeps the original eager-loading behaviour.

    Returns:
        Tuple ``(raw_embd, raw_reward)`` of the two loaded arrays. Per the
        project README these are expected to be shaped ``(N, K, D)`` and
        ``(N, K, 1)``; the loader itself does not check this.

    Raises:
        FileNotFoundError: If either file does not exist (raised by ``np.load``).
    """
    stem = f"{res_gen_model}_{embed_model}_{task}_{split}.npy"
    raw_embd = np.load(os.path.join(data_dir, f"embd_{stem}"), mmap_mode=mmap_mode)
    raw_reward = np.load(os.path.join(data_dir, f"reward_{stem}"), mmap_mode=mmap_mode)
    return raw_embd, raw_reward


# ---------------------------------------------------------------------------
# /utils.py
# ---------------------------------------------------------------------------
import numpy as np
import itertools


def _sample_unique_pairs(num_responses, num_pairs, all_pairs=None):
    """Draw ``num_pairs`` distinct unordered index pairs ``(j, k)`` with ``j < k``.

    The caller must guarantee ``num_pairs <= C(num_responses, 2)``. When
    ``all_pairs`` (the full enumeration of pairs) is supplied, a random subset
    of it is returned; otherwise pairs are drawn by rejection sampling.
    """
    if all_pairs is not None:
        picked = np.random.permutation(len(all_pairs))[:num_pairs]
        return [all_pairs[t] for t in picked]

    chosen = set()
    # Terminates with probability 1 because num_pairs never exceeds the number
    # of distinct pairs; the caller only uses this path when the target is
    # under half of them, so collisions stay rare.
    while len(chosen) < num_pairs:
        j, k = np.random.choice(num_responses, 2, replace=False)
        chosen.add((int(min(j, k)), int(max(j, k))))
    return list(chosen)


def pair_annotate(embeddings, rewards, annotation_quality=0.1, pair_strategy='random', pairs_per_prompt=10):
    """
    Construct pairwise training data (winner vs loser) from reward-annotated embeddings.

    Args:
        embeddings: np.ndarray, shape (N, K, D)
            Embeddings of responses for each prompt.
        rewards: np.ndarray, shape (N, K, 1) (shape (N, K) is also accepted)
            Corresponding reward scores. Only index 0 of the last axis is used
            for comparisons.
        annotation_quality: float
            Simulated annotation quality.
            >= 0: stochastic Bradley-Terry-style labeling; the first response
                  wins with probability sigmoid(quality * (r1 - r2) / std(rewards)).
            < 0 : perfect annotation (based on true rewards), e.g. -1.
        pair_strategy: str
            Strategy to generate pairs. Supported: 'random'
        pairs_per_prompt: int
            Number of distinct pairs to sample per prompt. Capped at the
            number of distinct pairs K * (K - 1) / 2 that exist.

    Returns:
        pairwise_embeddings: np.ndarray, shape (M, 2, D)
            Embeddings of (winner, loser) pairs; index 0 on axis 1 is the winner.
        labels: np.ndarray, shape (M,)
            All ones (int32), since the winner is always stored first.

    Raises:
        NotImplementedError: If ``pair_strategy`` is not supported.
        ValueError: If ``rewards`` does not match ``embeddings`` on (N, K).
    """
    # Validate the strategy once instead of once per prompt.
    if pair_strategy != 'random':
        raise NotImplementedError(f"Unknown pair_strategy: {pair_strategy}")

    embeddings = np.asarray(embeddings)
    rewards = np.asarray(rewards)
    N, K, D = embeddings.shape

    if rewards.ndim == 2:
        rewards = rewards[:, :, np.newaxis]
    if rewards.ndim != 3 or rewards.shape[:2] != (N, K) or rewards.shape[2] < 1:
        raise ValueError(
            f"rewards shape {rewards.shape} is incompatible with embeddings shape {embeddings.shape}"
        )

    # Never ask for more distinct pairs than exist; otherwise sampling would
    # spin without being able to reach the target.
    max_pairs = K * (K - 1) // 2
    num_pairs = min(pairs_per_prompt, max_pairs)
    if N == 0 or num_pairs <= 0:
        return np.empty((0, 2, D), dtype=embeddings.dtype), np.ones(0, dtype=np.int32)

    perfect_annotator = annotation_quality < 0
    rew_scale_factor = np.std(rewards) + 1e-8  # for BT-model prob normalization
    scores = rewards[:, :, 0]

    # When most of the possible pairs are requested, rejection sampling gets
    # slow; enumerate all pairs once and shuffle instead.
    all_pairs = None
    if 2 * num_pairs >= max_pairs:
        all_pairs = list(itertools.combinations(range(K), 2))

    prompt_idx = []
    winner_idx = []
    loser_idx = []

    for i in range(N):  # iterate over prompts
        for j, k in _sample_unique_pairs(K, num_pairs, all_pairs):
            # flip order randomly to avoid position bias
            if np.random.rand() < 0.5:
                idx1, idx2 = j, k
            else:
                idx1, idx2 = k, j

            rew1 = scores[i, idx1]
            rew2 = scores[i, idx2]

            if perfect_annotator:
                # perfect annotator: choose the higher reward (ties go to idx2)
                first_wins = rew1 > rew2
            else:
                # noisy label via BT model; the tanh form equals
                # 1 / (1 + exp(-x)) but cannot overflow for large |x|
                delta_reward = (rew1 - rew2) / rew_scale_factor
                prob = 0.5 * (1.0 + np.tanh(0.5 * delta_reward * annotation_quality))
                first_wins = np.random.rand() < prob

            prompt_idx.append(i)
            if first_wins:
                winner_idx.append(idx1)
                loser_idx.append(idx2)
            else:
                winner_idx.append(idx2)
                loser_idx.append(idx1)

    # Gather all rows in one fancy-index instead of building Python lists of
    # per-row array copies.
    prompt_idx = np.asarray(prompt_idx, dtype=np.intp)
    winner_idx = np.asarray(winner_idx, dtype=np.intp)
    loser_idx = np.asarray(loser_idx, dtype=np.intp)
    positive_sample = embeddings[prompt_idx, winner_idx]  # shape (M, D)
    negative_sample = embeddings[prompt_idx, loser_idx]  # shape (M, D)
    pairwise_embeddings = np.stack([positive_sample, negative_sample], axis=1)  # shape (M, 2, D)
    labels = np.ones(pairwise_embeddings.shape[0], dtype=np.int32)  # dummy labels
    return pairwise_embeddings, labels