├── .gitignore ├── README.md ├── data ├── poem_generation │ └── test-00000-of-00001.parquet └── story_generation │ └── test-00000-of-00001.parquet ├── docs └── GEM-2025-03-23.pdf ├── evaluation ├── convert_response_for_if_eval.py ├── evaluation_diversity.py ├── evaluation_gsm8k.py ├── evaluation_gsm8k_voting.py ├── evaluation_reward.py ├── generate_response.py └── utils │ └── gsm8k.py ├── img ├── gem_vs_ce.png └── gem_with_remax.png ├── preprocess_data.py ├── requirements.txt ├── scripts ├── eval │ ├── creative_writing.sh │ ├── gsm8k_eval.sh │ ├── gsm8k_voting_eval.sh │ └── reward_eval.sh ├── llama3.1 │ ├── tokenize_data.sh │ ├── train_ce_ultrafeedback.sh │ └── train_gem_ultrafeedback.sh ├── qwen2.5 │ ├── tokenize_data.sh │ ├── train_ce_numina.sh │ └── train_gem_numina.sh ├── zero2.json └── zero3.json ├── sft_trainer.py ├── sft_trainer_v2.py ├── tests ├── test_gem_loss_triton.py └── test_gem_loss_triton_distributed.py ├── train.py └── utils ├── README.md ├── gem_triton_loss.py └── gem_triton_ops.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Result 10 | result/ 11 | log/ 12 | 13 | # Distribution / packaging 14 | .idea 15 | .Python 16 | .DS_Store 17 | build/ 18 | develop-eggs/ 19 | dist/ 20 | downloads/ 21 | applications/DeepSpeed-Chat/data 22 | eggs/ 23 | .eggs/ 24 | lib/ 25 | lib64/ 26 | parts/ 27 | sdist/ 28 | var/ 29 | wheels/ 30 | pip-wheel-metadata/ 31 | share/python-wheels/ 32 | *.egg-info/ 33 | .installed.cfg 34 | *.egg 35 | MANIFEST 36 | 37 | # PyInstaller 38 | # Usually these files are written by a python script from a template 39 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 40 | *.manifest 41 | *.spec 42 | 43 | # Installer logs 44 | pip-log.txt 45 | pip-delete-this-directory.txt 46 | 47 | # Unit test / coverage reports 48 | htmlcov/ 49 | .tox/ 50 | .nox/ 51 | .coverage 52 | .coverage.* 53 | .cache 54 | nosetests.xml 55 | coverage.xml 56 | *.cover 57 | *.py,cover 58 | .hypothesis/ 59 | .pytest_cache/ 60 | 61 | # Translations 62 | *.mo 63 | *.pot 64 | 65 | # Django stuff: 66 | *.log 67 | local_settings.py 68 | db.sqlite3 69 | db.sqlite3-journal 70 | 71 | # Flask stuff: 72 | instance/ 73 | .webassets-cache 74 | 75 | # Scrapy stuff: 76 | .scrapy 77 | 78 | # Sphinx documentation 79 | docs/_build/ 80 | 81 | # PyBuilder 82 | target/ 83 | 84 | # Jupyter Notebook 85 | .ipynb_checkpoints 86 | 87 | # IPython 88 | profile_default/ 89 | ipython_config.py 90 | 91 | # pyenv 92 | .python-version 93 | 94 | # pipenv 95 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 96 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 97 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 98 | # install all needed dependencies. 99 | #Pipfile.lock 100 | 101 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 102 | __pypackages__/ 103 | 104 | # Celery stuff 105 | celerybeat-schedule 106 | celerybeat.pid 107 | 108 | # SageMath parsed files 109 | *.sage.py 110 | 111 | # Environments 112 | .env 113 | .venv 114 | env/ 115 | venv/ 116 | ENV/ 117 | env.bak/ 118 | venv.bak/ 119 | 120 | # Spyder project settings 121 | .spyderproject 122 | .spyproject 123 | 124 | # Rope project settings 125 | .ropeproject 126 | 127 | # mkdocs documentation 128 | /site 129 | 130 | # mypy 131 | .mypy_cache/ 132 | .dmypy.json 133 | dmypy.json 134 | 135 | # Pyre type checker 136 | .pyre/ -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # 🚀 PyTorch Implementation of GEM 🌟 2 | 3 | Welcome to the official PyTorch implementation of **GEM**! 🎉 4 | 5 | GEM was introduced in our [ICLR 2025 paper](https://openreview.net/forum?id=dulz3WVhMR) **"Preserving Diversity in Supervised Fine-tuning of Large Language Models"**. 6 | 7 | > This work was previously titled "Entropic Distribution Matching in Supervised Fine-tuning of LLMs: Less Overfitting and Better Diversity" and received the Best Paper Runner-up Award at the NeurIPS 2024 FITML Workshop. 8 | 9 | 10 | 11 | GEM can replace the CE loss during SFT (supervised fine-tuning) or RFT (reinforced fine-tuning) to preserve diversity and mitigate overfitting. 🌍✨ 12 | 13 | For an overview of GEM, please refer to our [presentation slides](docs/GEM-2025-03-23.pdf). 14 | 15 | For more insights on GEM's potential to enhance RL training through improved cold-start strategies, check out our blog post: ["Can Better Cold-Start Strategies Improve RL Training for LLMs?"](https://tangible-polo-203.notion.site/Can-Better-Cold-Start-Strategies-Improve-RL-Training-for-LLMs-17aa0742a51680828616c867ed53bc6b) 16 | 17 | 18 | 19 | ## Quickstart Guide 💻 20 | 21 | ### Setup 🔧 22 | 23 | First, create a new environment and install the required packages: 24 | 25 | ```bash 26 | conda create -n gem python=3.10 27 | conda activate gem 28 | pip install -r requirements.txt 29 | ``` 30 | 31 | Note that the package versions pinned in `requirements.txt` are the ones used in the paper. You may use a newer version of transformers (>= 4.46.0), which fixes a potential bug in gradient accumulation. 32 | 33 | We also provide a **Triton** implementation of the GEM loss in the `utils` folder, which may be faster than the original implementation when training large-scale models. Please refer to the [README](utils/README.md) for more details. You can enable this implementation with the following command: 34 | 35 | ```bash 36 | python train.py --loss gem_triton 37 | ``` 38 | 39 | 40 | ### Training 🏋️‍♂️ 41 | 42 | Kickstart your training process using the `UltraFeedback` dataset from HuggingFace (the corresponding scripts live under `scripts/llama3.1`; `scripts/qwen2.5` provides analogous scripts for the Numina dataset). Here's how: 43 | 44 | **Tokenize Data** 45 | 46 | ```bash 47 | bash scripts/llama3.1/tokenize_data.sh 48 | ``` 49 | 50 | **Training** 51 | 52 | ```bash 53 | bash scripts/llama3.1/train_gem_ultrafeedback.sh 54 | ``` 55 | 56 | > **Note:** The `ce_loss` metric in training logs represents the cross-entropy loss calculated on a single machine without accounting for gradient accumulation. This value may differ from the reported `loss` metric. When monitoring training progress, you can use `ce_loss` as a diagnostic indicator to verify proper training behavior: the cross-entropy loss should decrease over time regardless of whether you are using CE or GEM as the primary loss function.
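For intuition about what `train.py` optimizes when you choose GEM instead of CE, here is a minimal, hedged sketch of a GEM-style token loss in plain PyTorch. It is an illustrative simplification, not the code used by this repository (see `sft_trainer.py` and `utils/gem_triton_loss.py` for the actual implementations); the function name `gem_style_loss` and the value `beta=0.7` are assumptions chosen for the example.

```python
import torch
import torch.nn.functional as F


def gem_style_loss(
    logits: torch.Tensor,      # (batch, seq_len, vocab)
    labels: torch.Tensor,      # (batch, seq_len); ignore_index marks masked positions
    beta: float = 0.7,         # illustrative value
    ignore_index: int = -100,
) -> torch.Tensor:
    """Sketch of a GEM-style loss: the logit gradient pulls a tempered,
    detached self-distribution q toward the data token, instead of pulling
    the full softmax p as CE does, so probability mass on plausible
    alternative tokens is largely left in place."""
    # Next-token prediction: token t is predicted from the prefix before t.
    logits = logits[:, :-1, :].contiguous()
    labels = labels[:, 1:].contiguous()

    mask = labels.ne(ignore_index)
    logits = logits[mask]                           # (num_tokens, vocab)
    labels = labels[mask]                           # (num_tokens,)

    log_p = F.log_softmax(logits, dim=-1)           # model distribution p (keeps grad)
    q = F.softmax(logits.detach() / beta, dim=-1)   # tempered self-distribution (no grad)

    log_p_label = log_p.gather(-1, labels.unsqueeze(-1)).squeeze(-1)
    # loss = -( log p(y) - E_q[log p(v)] ); its gradient w.r.t. the logits is
    # (q - one_hot(y)), which coincides with the CE gradient (p - one_hot(y))
    # when beta == 1.
    return -(log_p_label - (q * log_p).sum(dim=-1)).mean()
```

With `beta = 1` this sketch matches standard CE at the gradient level; a smaller `beta` sharpens `q`, so only tokens the model already strongly favors are pushed down in favor of the data token, which is roughly the intuition behind the entropy and diversity preservation described above.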
57 | 58 | ### Evaluation 🧪 59 | 60 | Run evaluations for different tasks: 61 | 62 | **GSM8K** 63 | 64 | ```bash 65 | bash scripts/eval/gsm8k_eval.sh 66 | ``` 67 | 68 | **GSM8K (Voting)** 69 | 70 | ```bash 71 | bash scripts/eval/gsm8k_voting_eval.sh 72 | ``` 73 | 74 | **Creative Writing** 75 | 76 | ```bash 77 | bash scripts/eval/creative_writing.sh 78 | ``` 79 | 80 | ## To Do 81 | 82 | - [ ] Add the adaptive mechanism for choosing the hyper-parameter $\beta$. 83 | 84 | ## 📜 Citation 85 | 86 | If you find this repository helpful in your research or projects, please consider citing the GEM paper in your academic work. Your support is much appreciated! 🙌 87 | 88 | 89 | ``` 90 | @inproceedings{li2025preserving, 91 | title={Preserving Diversity in Supervised Fine-Tuning of Large Language Models}, 92 | author={Ziniu Li and Congliang Chen and Tian Xu and Zeyu Qin and Jiancong Xiao and Zhi-Quan Luo and Ruoyu Sun}, 93 | booktitle={The Thirteenth International Conference on Learning Representations}, 94 | year={2025}, 95 | url={https://openreview.net/forum?id=NQEe7B7bSw} 96 | } 97 | ``` 98 | 99 | Our work was previously titled "Entropic Distribution Matching in Supervised Fine-tuning of LLMs: Less Overfitting and Better Diversity", available on arXiv. 100 | 101 | ```bibtex 102 | @article{li2024entropic, 103 | title={Entropic Distribution Matching in Supervised Fine-tuning of LLMs: Less Overfitting and Better Diversity}, 104 | author={Li, Ziniu and Chen, Congliang and Xu, Tian and Qin, Zeyu and Xiao, Jiancong and Sun, Ruoyu and Luo, Zhi-Quan}, 105 | journal={arXiv preprint arXiv:2408.16673}, 106 | year={2024} 107 | } 108 | ``` 109 | 110 | Ziniu Li would like to acknowledge Zhengyang Tang for his minimalistic and clean implementation of SFT. 111 | -------------------------------------------------------------------------------- /data/poem_generation/test-00000-of-00001.parquet: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liziniu/GEM/e5b1430979fa12bf8ab7398b2ccc71dff795bbee/data/poem_generation/test-00000-of-00001.parquet -------------------------------------------------------------------------------- /data/story_generation/test-00000-of-00001.parquet: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liziniu/GEM/e5b1430979fa12bf8ab7398b2ccc71dff795bbee/data/story_generation/test-00000-of-00001.parquet -------------------------------------------------------------------------------- /docs/GEM-2025-03-23.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liziniu/GEM/e5b1430979fa12bf8ab7398b2ccc71dff795bbee/docs/GEM-2025-03-23.pdf -------------------------------------------------------------------------------- /evaluation/convert_response_for_if_eval.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | from pprint import pprint 4 | from tqdm import tqdm 5 | import pandas as pd 6 | 7 | from dataclasses import dataclass, field 8 | from transformers import AutoTokenizer, HfArgumentParser 9 | from datasets import load_dataset, Dataset 10 | 11 | 12 | @dataclass 13 | class Arguments: 14 | response_path: str = field( 15 | default=None, 16 | metadata={"help": "Response path (json file) to convert."}, 17 | ) 18 | tokenizer_path: str = field( 19 | default="meta-llama/Meta-Llama-3-8B-Instruct", 20 | metadata={"help": "Tokenizer path to help clean 
str."}, 21 | ) 22 | save_path: str = field(default="alpaca_eval_response.json") 23 | 24 | 25 | def main(): 26 | parser = HfArgumentParser((Arguments,)) 27 | (args,) = parser.parse_args_into_dataclasses() 28 | 29 | pprint(args.__dict__) 30 | 31 | old_data = json.load(open(args.response_path, "r")) 32 | 33 | tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_path) 34 | 35 | dataset = [] 36 | with open("./instruction_following_eval/data/input_data.jsonl") as f: 37 | for line in f.readlines(): 38 | dataset.append(json.loads(line)) 39 | if_eval_dataset = Dataset.from_pandas(pd.DataFrame(dataset)) 40 | 41 | new_data = [] 42 | 43 | for i in tqdm(range(len(old_data))): 44 | prompt = old_data[i]["prompt"] 45 | 46 | prompt_clean = ( 47 | tokenizer.decode( 48 | tokenizer(prompt.replace(tokenizer.bos_token, "")).input_ids, 49 | skip_special_tokens=True, 50 | ) 51 | .replace("user\n\n", "") 52 | .replace("assistant\n\n", "") 53 | ) 54 | prompt_ref = if_eval_dataset[i]["prompt"] 55 | 56 | if prompt_clean.strip()[:10] != prompt_ref.strip()[:10]: 57 | import ipdb 58 | 59 | ipdb.set_trace() 60 | 61 | new_data.append( 62 | { 63 | "id": i, 64 | "prompt": prompt_ref, 65 | "response": ( 66 | old_data[i]["answer"] 67 | if isinstance(old_data[i]["answer"], str) 68 | else old_data[i]["answer"][0] 69 | .replace("<|eot_id|>", "") 70 | .replace(tokenizer.eos_token, "") 71 | .strip() 72 | ), 73 | "generator": old_data[i]["model_name"], 74 | } 75 | ) 76 | os.makedirs( 77 | args.save_path.replace(args.save_path.split("/")[-1], ""), exist_ok=True 78 | ) 79 | 80 | with open(args.save_path, "w") as outfile: 81 | for entry in new_data: 82 | json.dump(entry, outfile) 83 | outfile.write("\n") 84 | print(f"Save response to {args.save_path}") 85 | 86 | 87 | if __name__ == "__main__": 88 | main() 89 | -------------------------------------------------------------------------------- /evaluation/evaluation_diversity.py: -------------------------------------------------------------------------------- 1 | ################# 2 | # This code is modified from https://github.com/facebookresearch/rlfh-gen-div 3 | ################# 4 | import os 5 | from dataclasses import dataclass, field 6 | import json 7 | from pprint import pprint 8 | 9 | import torch 10 | import numpy as np 11 | import sentence_transformers 12 | from tqdm import tqdm 13 | 14 | # from sklearn.metrics.pairwise import cosine_similarity 15 | from transformers import set_seed, HfArgumentParser, AutoTokenizer 16 | 17 | from nltk.util import ngrams 18 | from nltk import word_tokenize 19 | from collections import Counter 20 | 21 | import sacrebleu 22 | 23 | @dataclass 24 | class AllArguments: 25 | response_path: str = field( 26 | default="./results/responses", metadata={"help": "Response path (json file)."} 27 | ) 28 | 29 | tokenizer_path: str = field(default=None) 30 | detokenizer_path: str = field(default=None) 31 | 32 | 33 | class SentBertSimilarity: 34 | def __init__(self): 35 | 36 | self.model_name = "bert-large-nli-stsb-mean-tokens" # FIXME - hard coded 37 | self.model = sentence_transformers.SentenceTransformer(self.model_name) 38 | if torch.cuda.is_available(): 39 | self.model.to(torch.device("cuda")) 40 | 41 | # @functools.cache 42 | def embed(self, sentence): 43 | return self.model.encode(sentence) 44 | 45 | # @functools.cache 46 | def sent_bert_cosine_similarity(self, resps_1, resps_2): 47 | embeds_1 = self.model.encode( 48 | resps_1, batch_size=1024, convert_to_tensor=True, show_progress_bar=False 49 | ) 50 | embeds_2 = self.model.encode( 51 | resps_2, 
batch_size=1024, convert_to_tensor=True, show_progress_bar=False 52 | ) 53 | 54 | if torch.cuda.is_available(): 55 | embeds_1 = embeds_1.to(torch.device("cuda")) 56 | embeds_2 = embeds_2.to(torch.device("cuda")) 57 | 58 | dot_product = (embeds_1 * embeds_2).sum(dim=1) 59 | 60 | # Calculate cosine similarity 61 | cosine_similarity = dot_product / (embeds_1.norm(dim=1) * embeds_2.norm(dim=1)) 62 | 63 | return cosine_similarity.detach().cpu().numpy() 64 | 65 | def __call__(self, resp_a, resp_b): 66 | return self.sent_bert_cosine_similarity(resp_a, resp_b) 67 | 68 | 69 | class SentBertDiversity: 70 | """ 71 | Implements the diversity to similarity reduction specified on section 5 in the paper 72 | (https://arxiv.org/pdf/2004.02990.pdf) 73 | for any similarity metric. 74 | 75 | config: 76 | shared with the original similarity metric. 77 | 78 | usage: 79 | metric = Similarity2DiversityMetric(config, SimilarityMetricClassName) 80 | metric(response_set) 81 | 82 | inheritance guidelines: 83 | implement __init__ only 84 | 85 | inheritance example: 86 | see CosineSimilarity2Diversity 87 | """ 88 | 89 | def __init__(self): 90 | self.similarity_metric = SentBertSimilarity() 91 | 92 | def __call__(self, response_set): 93 | similarity_list = [] 94 | for i in tqdm(range(len(response_set))): 95 | for j in range(i): 96 | similarity_list.append( 97 | self.similarity_metric(response_set[i], response_set[j]) 98 | ) 99 | diversity_score = 1 - np.mean(similarity_list) 100 | return diversity_score 101 | 102 | 103 | class AveragedNgramDiversityMetric: 104 | """ 105 | Calculates the mean values of an n-gram based diversity metric in range n in [n_min, n_max]. 106 | 107 | config: 108 | shared with the original n-gram metric. 109 | n_min(int) > 0 - Specify the lowest n-gram value to be averaged 110 | n_max(int) > 0 - Specify the highest n-gram value to be averaged 111 | 112 | usage: 113 | metric = AveragedNgramDiversityMetric(config, NgramMetricClassName) 114 | metric(response_set) 115 | 116 | inheritance guidelines: 117 | implement __init__ only 118 | 119 | inheritance example: 120 | see AveragedDistinctNgrams 121 | """ 122 | 123 | def __init__(self, n_min, n_max): 124 | # add n field 125 | self.n_min = n_min 126 | self.n_max = n_max 127 | 128 | def __call__(self, response_set): 129 | ngrams_results = [] 130 | num_set = len(response_set) 131 | for i in range(len(response_set[0])): 132 | for n in range(self.n_min, self.n_max + 1): 133 | result = self.calculate_distinct_n( 134 | [response_set[j][i] for j in range(num_set)], n 135 | ) 136 | ngrams_results.append(result) 137 | return np.mean(ngrams_results) 138 | 139 | def calculate_distinct_n(self, responses, n): 140 | all_ngrams = [] 141 | for response in responses: 142 | tokens = word_tokenize(response) 143 | response_ngrams = list(ngrams(tokens, n)) 144 | all_ngrams.extend(response_ngrams) 145 | unique_ngrams = len(set(all_ngrams)) 146 | total_ngrams = len(all_ngrams) 147 | 148 | return unique_ngrams / total_ngrams if total_ngrams > 0 else 0 149 | 150 | 151 | class SelfBLEUMetric: 152 | def __call__(self, response_set): 153 | """Calculate the average Self-BLEU score for a list of texts.""" 154 | bleu_scores = [] 155 | k = len(response_set) 156 | for i in range(len(response_set[0])): 157 | texts = [response_set[j][i] for j in range(k)] 158 | bleu_scores.append(self.calculate_bleu_score(texts)) 159 | 160 | return np.mean(bleu_scores) 161 | 162 | def calculate_bleu_score(self, texts): 163 | bleu_scores = [] 164 | for i in range(len(texts)): 165 | # Treat the current 
text as the hypothesis 166 | hypothesis = texts[i] 167 | # Treat all other texts as references 168 | references = texts[:i] + texts[i + 1 :] 169 | 170 | if references: # Ensure there are references to compare against 171 | bleu_score = sacrebleu.corpus_bleu([hypothesis], [references]) 172 | bleu_scores.append(bleu_score.score) 173 | 174 | # Compute the average BLEU score 175 | average_bleu = sum(bleu_scores) / len(bleu_scores) if bleu_scores else 0 176 | return average_bleu 177 | 178 | 179 | def main(): 180 | parser = HfArgumentParser((AllArguments,)) 181 | (args,) = parser.parse_args_into_dataclasses() 182 | pprint(args.__dict__) 183 | 184 | if os.path.exists(args.response_path.replace(".json", "-cleaned.json")): 185 | args.response_path = args.response_path.replace(".json", "-cleaned.json") 186 | 187 | if args.response_path.endswith("-cleaned.json"): 188 | response_set = json.load(open(args.response_path, "r")) 189 | else: 190 | data = json.load(open(args.response_path, "r")) 191 | 192 | tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_path) 193 | if args.detokenizer_path is not None: 194 | detokenizer = AutoTokenizer.from_pretrained(args.detokenizer_path) 195 | else: 196 | detokenizer = None 197 | 198 | response_set = [] 199 | for i in tqdm(range(len(data))): 200 | n = len(data[i]["answer"]) 201 | if len(response_set) == 0: 202 | response_set = [[] for _ in range(n)] 203 | else: 204 | assert len(response_set) == n 205 | for j in range(n): 206 | x = data[i] 207 | if detokenizer: 208 | prompt_str = ( 209 | detokenizer.decode( 210 | detokenizer.encode(x["prompt"]), skip_special_tokens=True 211 | ) 212 | .replace("user\n\n", "") 213 | .replace("assistant\n\n", "") 214 | ) 215 | else: 216 | prompt_str = x["prompt"] 217 | if detokenizer: 218 | # ans_str = detokenizer.decode( 219 | # detokenizer.encode(data[i]["answer"][j]), skip_special_tokens=True 220 | # ) 221 | ans_str = data[i]["answer"][j].replace("<|eot_id|>", "") 222 | else: 223 | ans_str = data[i]["answer"][j] 224 | chat = [ 225 | { 226 | "role": "user", 227 | "content": prompt_str, 228 | }, 229 | {"role": "assistant", "content": ans_str}, 230 | ] 231 | res = tokenizer.apply_chat_template(chat, tokenize=False) 232 | response_set[j].append(res) 233 | json.dump( 234 | response_set, 235 | open(args.response_path.replace(".json", "-cleaned.json"), "w"), 236 | indent=2, 237 | ) 238 | 239 | response_set = json.load( 240 | open(args.response_path.replace(".json", "-cleaned.json"), "r") 241 | ) 242 | print("Finished Data Preparation.") 243 | 244 | evaluation_results = { 245 | "sentbert_diversity_score": None, 246 | "bleu_diversity_score": None, 247 | "averaged_ngram_diversity_score": None, 248 | } 249 | 250 | print("Calculating N-gram diversity score...") 251 | metric = AveragedNgramDiversityMetric(n_min=1, n_max=3) 252 | diversity_score = metric(response_set) 253 | evaluation_results["averaged_ngram_diversity_score"] = np.round( 254 | diversity_score * 100, 2 255 | ) 256 | print("N-gram diversity score: {}".format(diversity_score)) 257 | 258 | print("Calculating BLEU similarity score...") 259 | metric = SelfBLEUMetric() 260 | similarity_score = metric(response_set) 261 | evaluation_results["bleu_diversity_score"] = np.round(100 - similarity_score, 2) 262 | print("BLEU similarity score: {}".format(100 - similarity_score)) 263 | 264 | print("Calculating Bert diversity score...") 265 | metric = SentBertDiversity() 266 | diversity_score = metric(response_set) 267 | evaluation_results["sentbert_diversity_score"] = np.round(diversity_score 
* 100, 2) 268 | print("Bert diversity score: {}".format(diversity_score)) 269 | 270 | pprint(evaluation_results) 271 | 272 | 273 | if __name__ == "__main__": 274 | main() 275 | -------------------------------------------------------------------------------- /evaluation/evaluation_gsm8k.py: -------------------------------------------------------------------------------- 1 | import re 2 | import os 3 | import sys 4 | import json 5 | from dataclasses import dataclass, field 6 | from pprint import pprint 7 | from tqdm import tqdm 8 | from fraction import Fraction 9 | 10 | import torch 11 | import numpy as np 12 | 13 | from datasets import load_dataset 14 | # from evaluate import load 15 | 16 | sys.path.append( 17 | os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir, "..")) 18 | ) 19 | 20 | from evaluation.utils.gsm8k import extract_answer_number 21 | 22 | from transformers import HfArgumentParser, set_seed, AutoModelForCausalLM, AutoTokenizer 23 | 24 | import vllm 25 | from vllm import SamplingParams 26 | 27 | TEMPLATE = """ 28 | Your task is to answer the question below. Give step by step reasoning before you answer, and when you’re ready to answer, please use the format "The answer is: ..."\nQuestion: {question} 29 | """ 30 | 31 | 32 | @dataclass 33 | class Arguments: 34 | dataset_name_or_path: str = field(default="gms8k") 35 | 36 | # model 37 | model_name_or_path: str = field(default="meta-llama/Llama-2-7b-chat") 38 | tokenizer_name_or_path: str = field(default="meta-llama/Llama-2-7b-chat") 39 | 40 | # generation 41 | do_sample: bool = field(default=False) 42 | 43 | use_vllm: bool = field( 44 | default=False, metadata={"help": "Whether use vLLM for generation."} 45 | ) 46 | vllm_gpu_memory_utilization: float = field( 47 | default=0.9, metadata={"help": "vLLM GPU consumption ratio."} 48 | ) 49 | 50 | seed: int = field( 51 | default=42, metadata={"help": "Random Seed for reproducing results."} 52 | ) 53 | 54 | batch_size: int = field(default=10) 55 | top_k: int = field(default=-1) 56 | top_p: float = field(default=1.0) 57 | temperature: float = field(default=0.0, metadata={"help": "Temperature."}) 58 | max_new_tokens: int = field(default=512, metadata={"help": "Max response length."}) 59 | 60 | # save 61 | remove_old: bool = field( 62 | default=False, metadata={"help": "Whether to remove old file."} 63 | ) 64 | save_path: str = field( 65 | default="evaluation_gsm8k.json", 66 | metadata={"help": "Evaluation results save path."}, 67 | ) 68 | 69 | 70 | def save_prompts_and_answers( 71 | model_name, prompts, labels, answers, evaluations, file_path 72 | ): 73 | assert len(prompts) == len(answers), "Mismatched lengths!" 
74 | assert file_path.endswith(".json") 75 | data = [ 76 | { 77 | "id": i, 78 | "model_name": model_name, 79 | "prompt": prompts[i], 80 | "label": labels[i], 81 | "answer": answers[i], 82 | "evaluation": evaluations[i], 83 | } 84 | for i in range(len(prompts)) 85 | ] 86 | if not os.path.exists(file_path): 87 | with open(file_path, "w", encoding="utf-8") as file: 88 | json.dump(data, file, indent=2) 89 | else: 90 | with open(file_path, "r", encoding="utf-8") as file: 91 | data = json.load(file) 92 | 93 | # Determine the next id value 94 | next_id = data[-1]["id"] + 1 if data else 0 95 | 96 | # Create new entries and append them to the data list 97 | new_entries = [ 98 | { 99 | "id": i + next_id, 100 | "model_name": model_name, 101 | "prompt": prompts[i], 102 | "label": labels[i], 103 | "answer": answers[i], 104 | "evaluation": evaluations[i], 105 | } 106 | for i in range(len(prompts)) 107 | ] 108 | data.extend(new_entries) 109 | 110 | with open(file_path, "w", encoding="utf-8") as file: 111 | json.dump(data, file, indent=2) 112 | 113 | 114 | def main(): 115 | parser = HfArgumentParser((Arguments,)) 116 | (args,) = parser.parse_args_into_dataclasses() 117 | pprint(args.__dict__) 118 | 119 | if args.remove_old: 120 | if os.path.exists(args.save_path): 121 | os.remove(args.save_path) 122 | 123 | dataset = load_dataset(args.dataset_name_or_path, "main") 124 | dataset = dataset["test"] 125 | 126 | device = torch.device("cuda") 127 | 128 | tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name_or_path) 129 | if "llama-3" in tokenizer.name_or_path.lower(): 130 | tokenizer.pad_token = tokenizer.decode(len(tokenizer) - 1) 131 | tokenizer.pad_token_id = len(tokenizer) - 1 132 | elif "llama-2" in tokenizer.name_or_path.lower(): 133 | tokenizer.pad_token = tokenizer.unk_token 134 | tokenizer.pad_token_id = tokenizer.unk_token_id 135 | # tokenizer.model_max_length = int(8096) 136 | 137 | if args.use_vllm: 138 | model = vllm.LLM( 139 | model=args.model_name_or_path, 140 | tokenizer=args.tokenizer_name_or_path, 141 | tensor_parallel_size=torch.cuda.device_count(), 142 | dtype=torch.bfloat16, 143 | gpu_memory_utilization=args.vllm_gpu_memory_utilization, 144 | seed=args.seed, 145 | ) 146 | args.batch_size = len(dataset) 147 | else: 148 | model = AutoModelForCausalLM.from_pretrained( 149 | args.model_name_or_path, 150 | torch_dtype=torch.bfloat16, 151 | attn_implementation="flash_attention_2", 152 | ) 153 | model.to(device) 154 | model.eval() 155 | 156 | prompt_to_save = [] 157 | ans_to_save = [] 158 | labels_to_save = [] 159 | evaluations_to_save = [] 160 | count = 0 161 | total_acc = 0 162 | total_num = 0 163 | for i in tqdm(range(0, len(dataset), args.batch_size)): 164 | prompt = dataset[i : i + args.batch_size]["question"] 165 | prompt_conv = [ 166 | [{"role": "user", "content": TEMPLATE.format(question=x)}] for x in prompt 167 | ] 168 | labels = [ 169 | x.replace("####", "The answer is:") 170 | for x in dataset[i : i + args.batch_size]["answer"] 171 | ] 172 | prompt_str = tokenizer.apply_chat_template( 173 | prompt_conv, tokenize=False, add_generation_prompt=True 174 | ) 175 | 176 | tokenizer.padding_side = "left" 177 | prompt_token = tokenizer.apply_chat_template( 178 | prompt_conv, 179 | padding="longest", 180 | add_generation_prompt=True, 181 | return_dict=True, 182 | return_tensors="pt", 183 | ).to(device) 184 | 185 | prompt_length = prompt_token.input_ids.size(-1) 186 | 187 | if args.use_vllm: 188 | prompt_token_ids = [ 189 | prompt_token.input_ids[ 190 | j, 
prompt_token.attention_mask[j].bool() 191 | ].tolist() 192 | for j in range(len(prompt_conv)) 193 | ] 194 | 195 | sampling_params = SamplingParams( 196 | top_k=args.top_k, 197 | top_p=args.top_p, 198 | temperature=args.temperature, 199 | max_tokens=args.max_new_tokens, 200 | ) 201 | with torch.no_grad(): 202 | output_results = model.generate( 203 | prompt_token_ids=prompt_token_ids, sampling_params=sampling_params 204 | ) 205 | ans_str = [] 206 | for j in range(len(output_results)): 207 | ans_str.append(output_results[j].outputs[0].text) 208 | 209 | evaluation_results = [] 210 | for j in range(len(output_results)): 211 | try: 212 | answer = extract_answer_number(ans_str[j]) 213 | except Exception as e: 214 | print("========Error=========") 215 | print(e) 216 | print(ans_str[i]) 217 | print() 218 | answer = None 219 | true_answer = extract_answer_number(labels[j]) 220 | if answer is not None: 221 | evaluation_results.append( 222 | (answer, true_answer, answer == true_answer) 223 | ) 224 | else: 225 | evaluation_results.append((answer, true_answer, None)) 226 | 227 | total_num += len([x for x in evaluation_results if x[-1] is not None]) 228 | total_acc += len( 229 | [x for x in evaluation_results if x is not None and x[-1] is True] 230 | ) 231 | 232 | prompt_to_save.extend(prompt_str) 233 | ans_to_save.extend(ans_str) 234 | labels_to_save.extend(labels) 235 | evaluations_to_save.extend(evaluation_results) 236 | count += 1 237 | 238 | print("===========Prompt=============") 239 | print(prompt_str[0]) 240 | print("===========Label=============") 241 | print(labels[0]) 242 | print("===========Response=============") 243 | print(ans_str[0]) 244 | print("===========Evaluation=============") 245 | print(evaluation_results[0]) 246 | else: 247 | with torch.no_grad(): 248 | outputs = model.generate( 249 | prompt_token.input_ids, 250 | attention_mask=prompt_token.attention_mask, 251 | max_new_tokens=args.max_new_tokens, 252 | pad_token_id=tokenizer.pad_token_id, 253 | do_sample=args.do_sample, 254 | ) 255 | ans_token = outputs[:, prompt_length:] 256 | ans_str = tokenizer.batch_decode(ans_token, skip_special_tokens=True) 257 | 258 | evaluation_results = [] 259 | for j in range(len(output_results)): 260 | try: 261 | answer = extract_answer_number(ans_str[j]) 262 | except Exception as e: 263 | print("========Error=========") 264 | print(e) 265 | print(ans_str[i]) 266 | print() 267 | answer = None 268 | true_answer = extract_answer_number(labels[j]) 269 | if answer is not None: 270 | evaluation_results.append( 271 | (answer, true_answer, answer == true_answer) 272 | ) 273 | else: 274 | evaluation_results.append((answer, true_answer, None)) 275 | 276 | total_num += len([x for x in evaluation_results if x[-1] is not None]) 277 | total_acc += len( 278 | [x for x in evaluation_results if x is not None and x[-1] is True] 279 | ) 280 | 281 | prompt_to_save.extend(prompt_str) 282 | ans_to_save.extend(ans_str) 283 | labels_to_save.extend(labels) 284 | evaluations_to_save.extend(evaluation_results) 285 | count += 1 286 | 287 | print("===========Prompt=============") 288 | print(prompt_str[0]) 289 | print("===========Label=============") 290 | print(labels[0]) 291 | print("===========Response=============") 292 | print(ans_str[0]) 293 | print("===========Evaluation=============") 294 | print(evaluation_results[0]) 295 | 296 | if count % 10 == 0: 297 | save_prompts_and_answers( 298 | args.model_name_or_path, 299 | prompt_to_save, 300 | labels_to_save, 301 | ans_to_save, 302 | evaluations_to_save, 303 | 
args.save_path, 304 | ) 305 | prompt_to_save.clear() 306 | ans_to_save.clear() 307 | labels_to_save.clear() 308 | evaluations_to_save.clear() 309 | 310 | if len(prompt_to_save) > 0: 311 | save_prompts_and_answers( 312 | args.model_name_or_path, 313 | prompt_to_save, 314 | labels_to_save, 315 | ans_to_save, 316 | evaluations_to_save, 317 | args.save_path, 318 | ) 319 | prompt_to_save.clear() 320 | ans_to_save.clear() 321 | labels_to_save.clear() 322 | evaluations_to_save.clear() 323 | 324 | if total_num > 0: 325 | pprint(args.__dict__) 326 | print( 327 | "Acc over {} valid answers is {:.4f}, over {} all answers is {:.4f}".format( 328 | total_num, total_acc / total_num, len(dataset), total_acc / len(dataset) 329 | ) 330 | ) 331 | 332 | 333 | if __name__ == "__main__": 334 | main() 335 | -------------------------------------------------------------------------------- /evaluation/evaluation_gsm8k_voting.py: -------------------------------------------------------------------------------- 1 | import re 2 | import os 3 | import sys 4 | import json 5 | from dataclasses import dataclass, field 6 | from pprint import pprint 7 | from tqdm import tqdm 8 | from collections import Counter 9 | 10 | import torch 11 | import numpy as np 12 | 13 | from datasets import load_dataset 14 | from evaluate import load 15 | 16 | sys.path.append( 17 | os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir, "..")) 18 | ) 19 | 20 | from evaluation.utils.gsm8k import extract_answer_number 21 | 22 | from transformers import HfArgumentParser, set_seed, AutoModelForCausalLM, AutoTokenizer 23 | 24 | import vllm 25 | from vllm import SamplingParams 26 | 27 | TEMPLATE = """ 28 | Your task is to answer the question below. Give step by step reasoning before you answer, and when you’re ready to answer, please use the format "The answer is: ..."\nQuestion: {question} 29 | """ 30 | 31 | 32 | @dataclass 33 | class Arguments: 34 | dataset_name_or_path: str = field(default="gms8k") 35 | 36 | # model 37 | model_name_or_path: str = field(default="meta-llama/Llama-2-7b-chat") 38 | tokenizer_name_or_path: str = field(default="meta-llama/Llama-2-7b-chat") 39 | dtype: str = field(default="bf16", metadata={"choices": ["fp16", "bf16"]}) 40 | 41 | # generation 42 | do_sample: bool = field(default=False) 43 | temperature: float = field( 44 | default=0.6, 45 | ) 46 | top_k: int = field(default=50) 47 | top_p: float = field(default=1.0) 48 | n: int = field(default=1) 49 | 50 | use_vllm: bool = field( 51 | default=False, metadata={"help": "Whether use vLLM for generation."} 52 | ) 53 | vllm_gpu_memory_utilization: float = field( 54 | default=0.9, metadata={"help": "vLLM GPU consumption ratio."} 55 | ) 56 | 57 | seed: int = field( 58 | default=42, metadata={"help": "Random Seed for reproducing results."} 59 | ) 60 | 61 | batch_size: int = field(default=16) 62 | max_new_tokens: int = field(default=512, metadata={"help": "Max response length."}) 63 | 64 | # save 65 | remove_old: bool = field( 66 | default=False, metadata={"help": "Whether to remove old file."} 67 | ) 68 | save_path: str = field( 69 | default="evaluation_gsm8k.json", 70 | metadata={"help": "Evaluation results save path."}, 71 | ) 72 | 73 | 74 | def save_prompts_and_answers( 75 | model_name, prompts, labels, answers, evaluations, file_path 76 | ): 77 | assert len(prompts) == len(answers), "Mismatched lengths!" 
78 | assert file_path.endswith(".json") 79 | data = [ 80 | { 81 | "id": i, 82 | "model_name": model_name, 83 | "prompt": prompts[i], 84 | "label": labels[i], 85 | "answer": answers[i], 86 | "evaluation": evaluations[i], 87 | } 88 | for i in range(len(prompts)) 89 | ] 90 | if not os.path.exists(file_path): 91 | with open(file_path, "w", encoding="utf-8") as file: 92 | json.dump(data, file, indent=2) 93 | else: 94 | with open(file_path, "r", encoding="utf-8") as file: 95 | data = json.load(file) 96 | 97 | # Determine the next id value 98 | next_id = data[-1]["id"] + 1 if data else 0 99 | 100 | # Create new entries and append them to the data list 101 | new_entries = [ 102 | { 103 | "id": i + next_id, 104 | "model_name": model_name, 105 | "prompt": prompts[i], 106 | "label": labels[i], 107 | "answer": answers[i], 108 | "evaluation": evaluations[i], 109 | } 110 | for i in range(len(prompts)) 111 | ] 112 | data.extend(new_entries) 113 | 114 | with open(file_path, "w", encoding="utf-8") as file: 115 | json.dump(data, file, indent=2) 116 | 117 | 118 | def calculate_accuracy_voting(reference, candidates, depths=[1, 4, 8, 16, 32]): 119 | majority_voting_accuracies = [] 120 | best_of_n_accuracies = [] 121 | 122 | for depth in depths: 123 | if depth > len(candidates): 124 | break 125 | else: 126 | # Slice the candidates list to the current depth 127 | current_candidates = candidates[:depth] 128 | count = Counter(current_candidates) 129 | most_common = count.most_common(1)[0][0] # Get the most frequent answer 130 | 131 | # Majority voting accuracy 132 | is_correct_majority = most_common == reference 133 | majority_voting_accuracies.append(is_correct_majority) 134 | 135 | # Best of n accuracy 136 | is_correct_best_of_n = reference in current_candidates 137 | best_of_n_accuracies.append(is_correct_best_of_n) 138 | 139 | return majority_voting_accuracies, best_of_n_accuracies 140 | 141 | 142 | def main(): 143 | parser = HfArgumentParser((Arguments,)) 144 | (args,) = parser.parse_args_into_dataclasses() 145 | pprint(args.__dict__) 146 | 147 | if args.remove_old: 148 | if os.path.exists(args.save_path): 149 | os.remove(args.save_path) 150 | 151 | dataset = load_dataset(args.dataset_name_or_path, "main") 152 | dataset = dataset["test"] 153 | 154 | device = torch.device("cuda") 155 | 156 | tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name_or_path) 157 | if "llama-3" in tokenizer.name_or_path.lower(): 158 | tokenizer.pad_token = tokenizer.decode(len(tokenizer) - 1) 159 | tokenizer.pad_token_id = len(tokenizer) - 1 160 | elif "llama-2" in tokenizer.name_or_path.lower(): 161 | tokenizer.pad_token = tokenizer.unk_token 162 | tokenizer.pad_token_id = tokenizer.unk_token_id 163 | # tokenizer.model_max_length = int(8096) 164 | dtype = {"bf16": torch.bfloat16, "fp16": torch.float16}[args.dtype] 165 | 166 | if args.use_vllm: 167 | model = vllm.LLM( 168 | model=args.model_name_or_path, 169 | tokenizer=args.tokenizer_name_or_path, 170 | tensor_parallel_size=torch.cuda.device_count(), 171 | dtype=dtype, 172 | gpu_memory_utilization=args.vllm_gpu_memory_utilization, 173 | seed=args.seed, 174 | ) 175 | else: 176 | model = AutoModelForCausalLM.from_pretrained( 177 | args.model_name_or_path, 178 | torch_dtype=dtype, 179 | attn_implementation="flash_attention_2", 180 | ) 181 | model.to(device) 182 | model.eval() 183 | 184 | prompt_to_save = [] 185 | ans_to_save = [] 186 | labels_to_save = [] 187 | evaluations_to_save = [] 188 | 189 | count = 0 190 | max_depth = max(1, int(np.log2(args.n))) 191 | 
majority_voting_all = np.zeros([len(dataset), max_depth], dtype=int) 192 | best_of_n_all = np.zeros([len(dataset), max_depth], dtype=int) 193 | for i in tqdm(range(0, len(dataset), args.batch_size)): 194 | prompt = dataset[i : i + args.batch_size]["question"] 195 | prompt_conv = [ 196 | [{"role": "user", "content": TEMPLATE.format(question=x)}] for x in prompt 197 | ] 198 | labels = [ 199 | # x.replace("####", "Final answer:") 200 | x.replace("####", "The answer is:") 201 | for x in dataset[i : i + args.batch_size]["answer"] 202 | ] 203 | prompt_str = tokenizer.apply_chat_template( 204 | prompt_conv, tokenize=False, add_generation_prompt=True 205 | ) 206 | 207 | tokenizer.padding_side = "left" 208 | prompt_token = tokenizer.apply_chat_template( 209 | prompt_conv, 210 | padding="longest", 211 | add_generation_prompt=True, 212 | return_dict=True, 213 | return_tensors="pt", 214 | ).to(device) 215 | 216 | prompt_length = prompt_token.input_ids.size(-1) 217 | 218 | if args.use_vllm: 219 | prompt_token_ids = [ 220 | prompt_token.input_ids[ 221 | j, prompt_token.attention_mask[j].bool() 222 | ].tolist() 223 | for j in range(len(prompt_conv)) 224 | ] 225 | 226 | sampling_params = SamplingParams( 227 | n=args.n, 228 | temperature=args.temperature, 229 | top_k=args.top_k, 230 | top_p=args.top_p, 231 | max_tokens=args.max_new_tokens, 232 | ) 233 | with torch.no_grad(): 234 | output_results = model.generate( 235 | prompt_token_ids=prompt_token_ids, sampling_params=sampling_params 236 | ) 237 | ans_str = [] 238 | evaluation_results = [] 239 | for j in range(len(output_results)): 240 | final_answers = [] 241 | for k in range(args.n): 242 | try: 243 | answer = extract_answer_number( 244 | output_results[j].outputs[k].text 245 | ) 246 | except Exception as e: 247 | print("========Error=========") 248 | print(e) 249 | print(output_results[j].outputs[k].text) 250 | print() 251 | answer = None 252 | final_answers.append(answer) 253 | 254 | ans_str.append( 255 | [output_results[j].outputs[k].text for k in range(args.n)] 256 | ) 257 | true_answer = extract_answer_number(labels[j]) 258 | majority_evaluation, best_of_n_evaluation = calculate_accuracy_voting( 259 | true_answer, final_answers 260 | ) 261 | majority_voting_all[count] = np.array( 262 | majority_evaluation, dtype=np.int32 263 | ) 264 | best_of_n_all[count] = np.array(best_of_n_evaluation, dtype=np.int32) 265 | count += 1 266 | 267 | evaluation_results.append( 268 | { 269 | "true_answer": true_answer, 270 | "majority_evaluation": majority_evaluation, 271 | "best_of_n_evaluation": best_of_n_evaluation, 272 | } 273 | ) 274 | 275 | prompt_to_save.extend(prompt_str) 276 | ans_to_save.extend(ans_str) 277 | labels_to_save.extend(labels) 278 | evaluations_to_save.extend(evaluation_results) 279 | 280 | pprint("===========Prompt=============") 281 | pprint(prompt_str[0]) 282 | pprint("===========Label=============") 283 | pprint(labels[0]) 284 | pprint("===========Response=============") 285 | pprint(ans_str[0]) 286 | pprint("===========Evaluation=============") 287 | pprint(evaluation_results[0]) 288 | pprint( 289 | "Majority Acc so far: {} Best Acc so far: {}".format( 290 | np.round(np.mean(majority_voting_all[:count], axis=0) * 100, 2), 291 | np.round(np.mean(best_of_n_all[:count], axis=0) * 100, 2), 292 | ) 293 | ) 294 | else: 295 | raise NotImplementedError 296 | 297 | if count % 128 == 0: 298 | save_prompts_and_answers( 299 | args.model_name_or_path, 300 | prompt_to_save, 301 | labels_to_save, 302 | ans_to_save, 303 | evaluations_to_save, 304 | 
args.save_path, 305 | ) 306 | prompt_to_save.clear() 307 | ans_to_save.clear() 308 | labels_to_save.clear() 309 | evaluations_to_save.clear() 310 | 311 | if len(prompt_to_save) > 0: 312 | save_prompts_and_answers( 313 | args.model_name_or_path, 314 | prompt_to_save, 315 | labels_to_save, 316 | ans_to_save, 317 | evaluations_to_save, 318 | args.save_path, 319 | ) 320 | prompt_to_save.clear() 321 | ans_to_save.clear() 322 | labels_to_save.clear() 323 | evaluations_to_save.clear() 324 | 325 | pprint(args.__dict__) 326 | print( 327 | "==> Majority Acc over the dataset: {} Best Acc over the dataset: {}".format( 328 | np.round(np.mean(majority_voting_all, axis=0) * 100, 2), 329 | np.round(np.mean(best_of_n_all, axis=0) * 100, 2), 330 | ) 331 | ) 332 | 333 | 334 | if __name__ == "__main__": 335 | main() 336 | -------------------------------------------------------------------------------- /evaluation/evaluation_reward.py: -------------------------------------------------------------------------------- 1 | import os 2 | from dataclasses import dataclass, field 3 | from pprint import pprint 4 | import json 5 | from types import MethodType 6 | from tqdm import tqdm 7 | import numpy as np 8 | 9 | import torch 10 | from transformers import AutoTokenizer, pipeline 11 | from transformers import AutoModel, AutoModelForSequenceClassification, HfArgumentParser 12 | 13 | 14 | @dataclass 15 | class Arguments: 16 | model_name_or_path: str = field(default="sfairXC/FsfairX-LLaMA3-RM-v0.1") 17 | tokenizer_path: str = field(default=None) 18 | 19 | detokenizer_path: str = field(default=None) 20 | 21 | data_path: str = field(default=None) 22 | batch_size: int = field(default=1) 23 | max_size: int = field(default=None) 24 | 25 | save_path: str = field(default=None) 26 | 27 | 28 | def forward_value_fn( 29 | self, 30 | input_ids=None, 31 | attention_mask=None, 32 | past_key_values=None, 33 | position_ids=None, 34 | inputs_embeds=None, 35 | return_value_only=False, 36 | prompt_length=0, 37 | use_cache=False, 38 | **kwargs, 39 | ): 40 | transformer_outputs = self.model( 41 | input_ids, 42 | past_key_values=past_key_values, 43 | attention_mask=attention_mask, 44 | inputs_embeds=inputs_embeds, 45 | use_cache=use_cache, 46 | **kwargs, 47 | ) 48 | hidden_states = transformer_outputs[0] 49 | values = self.score(hidden_states).squeeze(-1) 50 | if return_value_only: 51 | return values 52 | else: 53 | if attention_mask is None: 54 | chosen_end_scores = values[:, -1] 55 | else: 56 | last_index = attention_mask.cumsum(dim=1).argmax(dim=1) 57 | chosen_end_scores = values.gather(1, last_index.unsqueeze(1)).squeeze(1) 58 | return { 59 | "values": values, 60 | "chosen_end_scores": chosen_end_scores, 61 | } 62 | 63 | 64 | def calculation_best_of_n(data): 65 | print("Calculating best of n reward ....") 66 | best_n = np.zeros([len(data), 6]) # 1, 2, 4, 8, 16 67 | mean_n = np.zeros([len(data), 6]) # 1, 2, 4, 8, 16 68 | for i in tqdm(range(len(data))): 69 | rewards = data[i]["reward"] 70 | best_n[i][0] = rewards[0] 71 | best_n[i][1] = max(rewards[:2]) 72 | best_n[i][2] = max(rewards[:4]) 73 | best_n[i][3] = max(rewards[:8]) 74 | best_n[i][4] = max(rewards[:16]) 75 | best_n[i][5] = max(rewards[:32]) 76 | 77 | mean_n[i][0] = rewards[0] 78 | mean_n[i][1] = np.mean(rewards[:2]) 79 | mean_n[i][2] = np.mean(rewards[:4]) 80 | mean_n[i][3] = np.mean(rewards[:8]) 81 | mean_n[i][4] = np.mean(rewards[:16]) 82 | mean_n[i][5] = np.mean(rewards[:32]) 83 | best_n = np.mean(best_n, axis=0) 84 | print("Best of n: {}".format(np.round(best_n, 2))) 85 | 
mean_n = np.mean(mean_n, axis=0) 86 | print("Mean of n: {}".format(np.round(mean_n, 2))) 87 | return best_n, mean_n 88 | 89 | 90 | def main(): 91 | parser = HfArgumentParser((Arguments,)) 92 | (args,) = parser.parse_args_into_dataclasses() 93 | pprint(args.__dict__) 94 | assert args.data_path is not None 95 | assert args.save_path is not None 96 | 97 | device = torch.device("cuda") 98 | 99 | model_class = AutoModelForSequenceClassification 100 | flash_attn = True 101 | model = model_class.from_pretrained( 102 | args.model_name_or_path, 103 | torch_dtype=torch.bfloat16, 104 | attn_implementation="flash_attention_2" if flash_attn else "eager", 105 | trust_remote_code=True, 106 | ) 107 | # model.forward_value = forward_value_fn 108 | model.forward_value = MethodType(forward_value_fn, model) 109 | model.eval() 110 | model.to(device) 111 | 112 | tokenizer = AutoTokenizer.from_pretrained( 113 | args.tokenizer_path or args.model_name_or_path 114 | ) 115 | tokenizer.padding_side = "right" 116 | if args.detokenizer_path is not None: 117 | detokenizer = AutoTokenizer.from_pretrained(args.detokenizer_path) 118 | else: 119 | detokenizer = None 120 | 121 | response_data = json.load(open(args.data_path, "r")) 122 | 123 | if args.max_size: 124 | response_data = response_data[: args.max_size] 125 | if os.path.exists(args.save_path): 126 | response_data = json.load(open(args.save_path, "r")) 127 | calculation_best_of_n(response_data) 128 | return 129 | for start in tqdm(range(0, len(response_data), args.batch_size)): 130 | end = start + args.batch_size 131 | prompts = [] 132 | answers = [] 133 | for x in response_data[start:end]: 134 | if detokenizer: 135 | prompt_str = ( 136 | detokenizer.decode( 137 | detokenizer.encode(x["prompt"]), skip_special_tokens=True 138 | ) 139 | .replace("user\n\n", "") 140 | .replace("assistant\n\n", "") 141 | ) 142 | else: 143 | if "prompt" in x: 144 | prompt_str = x["prompt"] 145 | elif "instruction" in x: 146 | prompt_str = x["instruction"] 147 | else: 148 | raise ValueError(x) 149 | if "answer" in x: 150 | for ans in x["answer"]: 151 | if detokenizer: 152 | ans_str = detokenizer.decode( 153 | detokenizer.encode(ans), skip_special_tokens=True 154 | ) 155 | else: 156 | ans_str = ans 157 | prompts.append(prompt_str) 158 | answers.append(ans_str) 159 | elif "output" in x: 160 | ans_str = x["output"] 161 | prompts.append(prompt_str) 162 | answers.append(ans_str) 163 | else: 164 | raise ValueError(x) 165 | 166 | chat = [] 167 | for i in range(len(prompts)): 168 | chat.append( 169 | [ 170 | {"role": "user", "content": prompts[i]}, 171 | {"role": "assistant", "content": answers[i]}, 172 | ] 173 | ) 174 | inputs = tokenizer.apply_chat_template( 175 | chat, 176 | padding="longest", 177 | add_generation_prompt=True, 178 | return_dict=True, 179 | return_tensors="pt", 180 | ).to(device) 181 | 182 | with torch.no_grad(): 183 | if "FsfairX-LLaMA3-RM-v0.1" in args.model_name_or_path: 184 | outputs = model.forward_value(**inputs)["chosen_end_scores"] 185 | else: 186 | outputs = model(**inputs, use_cahe=False) 187 | 188 | c_start = 0 189 | for x in response_data[start:end]: 190 | if "answer" in x: 191 | x["reward"] = outputs[c_start : c_start + len(x["answer"])].tolist() 192 | c_start += len(x["answer"]) 193 | elif "output" in x: 194 | x["reward"] = outputs[c_start].tolist() 195 | c_start += 1 196 | else: 197 | raise ValueError(x) 198 | 199 | print(chat[0]) 200 | print(outputs[0]) 201 | 202 | if "answer" in x: 203 | calculation_best_of_n(response_data) 204 | 205 | json.dump(response_data, 
open(args.save_path, "w"), indent=2) 206 | print("saving result to {}".format(args.save_path)) 207 | 208 | 209 | if __name__ == "__main__": 210 | main() 211 | -------------------------------------------------------------------------------- /evaluation/generate_response.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | from pprint import pprint 4 | from dataclasses import dataclass, field 5 | from tqdm import tqdm 6 | import pandas as pd 7 | 8 | import vllm 9 | from vllm import SamplingParams 10 | 11 | import torch 12 | from datasets import load_dataset, load_from_disk, Dataset 13 | from transformers import AutoModelForCausalLM, AutoTokenizer, set_seed, HfArgumentParser 14 | 15 | 16 | @dataclass 17 | class Arguments: 18 | model_name_or_path: str = field( 19 | default="meta-llama/Llama-2-7b-chat", metadata={"help": "Model name or path."} 20 | ) 21 | 22 | tokenizer_path: str = field( 23 | default="meta-llama/Llama-2-7b-chat", metadata={"help": "Tokenizer path."} 24 | ) 25 | 26 | # dataset 27 | dataset_path: str = field( 28 | default="RLHFlow/Orca-distibalel-standard", metadata={"help": "Dataset path."} 29 | ) 30 | split: str = field(default=None, metadata={"help": "Split of the dataset."}) 31 | column_name: str = field( 32 | default=None, metadata={"help": "Column name to extract prompt."} 33 | ) 34 | standard_format: bool = field( 35 | default=None, metadata={"help": "Dataset in the standard format."} 36 | ) 37 | load_from_disk: bool = field( 38 | default=False, metadata={"help": "Whether use the load_from_disk method."} 39 | ) 40 | max_size: int = field( 41 | default=None, metadata={"help": "Max data size for evaluation."} 42 | ) 43 | 44 | use_vllm: bool = field( 45 | default=False, metadata={"help": "Whether use vLLM for generation."} 46 | ) 47 | vllm_gpu_memory_utilization: float = field( 48 | default=0.9, metadata={"help": "vLLM GPU consumption ratio."} 49 | ) 50 | 51 | seed: int = field( 52 | default=42, metadata={"help": "Random Seed for reproducing results."} 53 | ) 54 | 55 | # generation 56 | batch_size: int = field(default=10) 57 | 58 | n: int = field(default=1, metadata={"help": "num of responses for each prompt."}) 59 | do_sample: bool = field( 60 | default=True, metadata={"help": "Do sample for generation."} 61 | ) 62 | top_k: int = field(default=50, metadata={"help": "Top k for generation."}) 63 | top_p: float = field(default=0.9, metadata={"help": "Top p for generation."}) 64 | temperature: float = field( 65 | default=0.6, metadata={"help": "Temperature for generation."} 66 | ) 67 | max_new_tokens: int = field(default=1024, metadata={"help": "Max response length."}) 68 | 69 | # save 70 | remove_old: bool = field( 71 | default=False, metadata={"help": "Whether to remove old file."} 72 | ) 73 | save_path: str = field( 74 | default="evaluation_log_probability.json", 75 | metadata={"help": "Evaluation results save path."}, 76 | ) 77 | 78 | def __post_init__(self): 79 | if self.column_name is None: 80 | if "tatsu-lab/alpaca_eval" in self.dataset_path: 81 | self.column_name = "instruction" 82 | if "HuggingFaceH4/ultrachat_200k" in self.dataset_path: 83 | self.column_name = "prompt" 84 | if "if_eval" in self.dataset_path: 85 | self.column_name = "prompt" 86 | if "poem_generation" in self.dataset_path: 87 | self.column_name = "instruction" 88 | if "story_generation" in self.dataset_path: 89 | self.column_name = "instruction" 90 | 91 | if self.split is None: 92 | if "tatsu-lab/alpaca_eval" in self.dataset_path: 93 | 
self.split = "eval" 94 | if "HuggingFaceH4/ultrachat_200k" in self.dataset_path: 95 | self.split = "test_sft" 96 | if "if_eval" in self.dataset_path: 97 | self.split = "train" 98 | if "poem_generation" in self.dataset_path: 99 | self.split = "test" 100 | if "story_generation" in self.dataset_path: 101 | self.split = "test" 102 | 103 | if self.standard_format is None: 104 | if "tatsu-lab/alpaca_eval" in self.dataset_path: 105 | self.standard_format = False 106 | if "HuggingFaceH4/ultrachat_200k" in self.dataset_path: 107 | self.standard_format = False 108 | if "if_eval" in self.dataset_path: 109 | self.standard_format = False 110 | if "poem_generation" in self.dataset_path: 111 | self.standard_format = False 112 | if "story_generation" in self.dataset_path: 113 | self.standard_format = False 114 | 115 | 116 | def get_dataset(dataset_name, split="test", from_disk=False): 117 | if from_disk: 118 | dataset = load_from_disk(dataset_name) 119 | else: 120 | if "tatsu-lab/alpaca_eval" in dataset_name: 121 | dataset = load_dataset(dataset_name, "alpaca_eval") 122 | if "if_eval" in dataset_name: 123 | dataset = [] 124 | with open("./data/if_eval_data.jsonl") as f: 125 | for line in f.readlines(): 126 | dataset.append(json.loads(line)) 127 | dataset = Dataset.from_pandas(pd.DataFrame(dataset)) 128 | return dataset 129 | else: 130 | dataset = load_dataset(dataset_name) 131 | if split in dataset: 132 | return dataset[split] 133 | else: 134 | assert "train" in dataset 135 | total_size = len(dataset["train"]) 136 | eval_size = min(1000, int(total_size * 0.1)) 137 | train_size = total_size - eval_size 138 | print( 139 | "There is no {} in the dataset. I set {} samples from the train split.".format( 140 | split, eval_size 141 | ) 142 | ) 143 | return dataset["train"].shuffle(seed=42).select(range(train_size, total_size)) 144 | 145 | 146 | def save_prompts_and_answers(model_name, prompts, answers, file_path): 147 | assert len(prompts) == len(answers), "Mismatched lengths!" 
148 | assert file_path.endswith(".json") 149 | data = [ 150 | { 151 | "id": i, 152 | "model_name": model_name, 153 | "prompt": prompts[i], 154 | "answer": answers[i], 155 | } 156 | for i in range(len(prompts)) 157 | ] 158 | if not os.path.exists(file_path): 159 | with open(file_path, "w", encoding="utf-8") as file: 160 | json.dump(data, file, indent=2, ensure_ascii=False) 161 | else: 162 | with open(file_path, "r", encoding="utf-8") as file: 163 | data = json.load(file) 164 | 165 | # Determine the next id value 166 | next_id = data[-1]["id"] + 1 if data else 0 167 | 168 | # Create new entries and append them to the data list 169 | new_entries = [ 170 | { 171 | "id": next_id + i, 172 | "model_name": model_name, 173 | "prompt": prompts[i], 174 | "answer": answers[i], 175 | } 176 | for i in range(len(prompts)) 177 | ] 178 | data.extend(new_entries) 179 | 180 | with open(file_path, "w", encoding="utf-8") as file: 181 | json.dump(data, file, indent=2) 182 | 183 | 184 | def main(): 185 | parser = HfArgumentParser((Arguments,)) 186 | (args,) = parser.parse_args_into_dataclasses() 187 | pprint(args.__dict__) 188 | 189 | training_config = {} 190 | if os.path.exists(os.path.join(args.model_name_or_path, "args.json")) and False: 191 | # f = open(os.path.join(args.model_name_or_path, "args.json"), "r") 192 | # for line in f.readlines()[:-1]: 193 | # line_dict = json.loads(line) 194 | # for key, val in line_dict.items(): 195 | # training_config[key] = val 196 | training_config = json.load( 197 | open(os.path.join(args.model_name_or_path, "args.json"), "r") 198 | ) 199 | if training_config["algo"] == "sft": 200 | key_parameters = { 201 | "algo": "sft", 202 | "model_name_or_path": training_config["model_name_or_path"], 203 | "dataset": training_config["data_path"], 204 | "data_max_size": training_config["max_size"], 205 | "learning_rate": training_config["learning_rate"], 206 | "num_train_epochs": (training_config["num_train_epochs"]), 207 | "max_seq_len": training_config["max_seq_len"], 208 | } 209 | elif training_config["algo"] == "dpo": 210 | key_parameters = { 211 | "algo": "dpo", 212 | "actor_model_name_or_path": training_config["actor_model_name_or_path"], 213 | "max_entropy": training_config["max_entropy"], 214 | "beta": training_config["beta"], 215 | "tau": training_config["tau"] if "tau" in training_config else None, 216 | "gamma": training_config["gamma"], 217 | "alpha": training_config["alpha"], 218 | "dataset": training_config["data_path"], 219 | "learning_rate": training_config["actor_learning_rate"], 220 | "enable_ema": ( 221 | training_config["enable_ema"] 222 | if "enable_ema" in training_config 223 | else None 224 | ), 225 | "ema_coeff": ( 226 | training_config["ema_coeff"] 227 | if "ema_coeff" in training_config 228 | else None 229 | ), 230 | } 231 | print("===========Your Training Key Parameters===============") 232 | pprint(key_parameters) 233 | print("Your save path: {}".format(args.save_path)) 234 | print() 235 | 236 | # if input("Is the save path correct (yes or no)?:\n") != "yes": 237 | # assert 0 238 | set_seed(args.seed) 239 | if os.path.exists(args.save_path): 240 | if args.remove_old: 241 | # if ( 242 | # input( 243 | # "The given save path exists a file. Do you continue (yes ot no)?:\n" 244 | # ) 245 | # != "yes" 246 | # ): 247 | # assert 0 248 | os.remove(args.save_path) 249 | else: 250 | print("{} exists. 
Exit!".format(args.save_path)) 251 | return 252 | 253 | dataset = get_dataset(args.dataset_path, args.split, args.load_from_disk) 254 | if args.max_size: 255 | dataset = dataset.select(range(0, min(len(dataset), args.max_size))) 256 | 257 | device = torch.device("cuda") 258 | 259 | tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_path) 260 | if "llama-3" in tokenizer.name_or_path.lower(): 261 | tokenizer.pad_token = tokenizer.decode(len(tokenizer) - 1) 262 | tokenizer.pad_token_id = len(tokenizer) - 1 263 | elif "llama-2" in tokenizer.name_or_path.lower(): 264 | tokenizer.pad_token = tokenizer.unk_token 265 | tokenizer.pad_token_id = tokenizer.unk_token_id 266 | 267 | if args.use_vllm: 268 | model = vllm.LLM( 269 | model=args.model_name_or_path, 270 | tokenizer=args.tokenizer_path, 271 | tensor_parallel_size=torch.cuda.device_count(), 272 | dtype=torch.bfloat16, 273 | gpu_memory_utilization=args.vllm_gpu_memory_utilization, 274 | seed=args.seed, 275 | swap_space=16, 276 | ) 277 | args.batch_size = len(dataset) 278 | else: 279 | model = AutoModelForCausalLM.from_pretrained( 280 | args.model_name_or_path, 281 | torch_dtype=torch.bfloat16, 282 | attn_implementation="flash_attention_2", 283 | ) 284 | model.to(device) 285 | model.eval() 286 | 287 | # eos_token_id = [tokenizer.eos_token_id] 288 | # if "llama-3" in tokenizer.name_or_path.lower(): 289 | # eos_token_id.append(tokenizer("<|eot_id|>").input_ids[-1]) 290 | 291 | prompt_to_save = [] 292 | ans_to_save = [] 293 | count = 0 294 | for i in tqdm(range(0, len(dataset), args.batch_size)): 295 | if args.standard_format: 296 | chosen_conv = dataset[i : i + args.batch_size][args.column_name] 297 | prompt_conv = [x[:-1] for x in chosen_conv] 298 | 299 | prompt_str = tokenizer.apply_chat_template( 300 | prompt_conv, tokenize=False, add_generation_prompt=True 301 | ) 302 | else: 303 | prompt = dataset[i : i + args.batch_size][args.column_name] 304 | prompt_conv = [[{"role": "user", "content": x}] for x in prompt] 305 | 306 | prompt_str = tokenizer.apply_chat_template( 307 | prompt_conv, tokenize=False, add_generation_prompt=True 308 | ) 309 | 310 | tokenizer.padding_side = "left" 311 | prompt_token = tokenizer.apply_chat_template( 312 | prompt_conv, 313 | padding="longest", 314 | add_generation_prompt=True, 315 | return_dict=True, 316 | return_tensors="pt", 317 | ).to(device) 318 | 319 | prompt_length = prompt_token.input_ids.size(-1) 320 | 321 | if args.use_vllm: 322 | prompt_token_ids = [ 323 | prompt_token.input_ids[ 324 | j, prompt_token.attention_mask[j].bool() 325 | ].tolist() 326 | for j in range(len(prompt_conv)) 327 | ] 328 | 329 | sampling_params = SamplingParams( 330 | n=args.n, 331 | temperature=args.temperature, 332 | top_p=args.top_p, 333 | top_k=args.top_k, 334 | # stop_token_ids=[ 335 | # tokenizer.eos_token_id, 336 | # tokenizer("<|eot_id|>").input_ids[-1], 337 | # ], 338 | max_tokens=args.max_new_tokens, 339 | ) 340 | with torch.no_grad(): 341 | output_results = model.generate( 342 | prompt_token_ids=prompt_token_ids, sampling_params=sampling_params 343 | ) 344 | ans_str = [] 345 | for j in range(len(output_results)): 346 | # if not output_results[j].outputs[0].finish_reason == "stop": 347 | # print(output_results[j].outputs[0].finish_reason) 348 | ans_str.append( 349 | [output_results[j].outputs[k].text for k in range(args.n)] 350 | ) 351 | # ans_str = [x.replace("<|eot_id|>", "") for x in ans_str] 352 | 353 | prompt_to_save.extend(prompt_str) 354 | ans_to_save.extend(ans_str) 355 | else: 356 | with torch.no_grad(): 357 | 
outputs = model.generate( 358 | prompt_token.input_ids, 359 | attention_mask=prompt_token.attention_mask, 360 | top_k=args.top_k, 361 | top_p=args.top_p, 362 | temperature=args.temperature, 363 | max_new_tokens=args.max_new_tokens, 364 | pad_token_id=tokenizer.pad_token_id, 365 | # eos_token_id=eos_token_id, 366 | do_sample=args.do_sample, 367 | ) 368 | ans_token = outputs[:, prompt_length:] 369 | ans_str = tokenizer.batch_decode(ans_token, skip_special_tokens=True) 370 | # ans_str = [x.replace("<|eot_id|>", "") for x in ans_str] 371 | 372 | prompt_to_save.extend(prompt_str) 373 | ans_to_save.extend(ans_str) 374 | count += 1 375 | 376 | print(prompt_str[0]) 377 | print(ans_str[0]) 378 | 379 | if count % 10 == 0: 380 | save_prompts_and_answers( 381 | args.model_name_or_path, 382 | prompt_to_save, 383 | ans_to_save, 384 | args.save_path, 385 | ) 386 | prompt_to_save.clear() 387 | ans_to_save.clear() 388 | 389 | if len(prompt_to_save) > 0: 390 | save_prompts_and_answers( 391 | args.model_name_or_path, 392 | prompt_to_save, 393 | ans_to_save, 394 | args.save_path, 395 | ) 396 | prompt_to_save.clear() 397 | ans_to_save.clear() 398 | 399 | 400 | if __name__ == "__main__": 401 | main() 402 | -------------------------------------------------------------------------------- /evaluation/utils/gsm8k.py: -------------------------------------------------------------------------------- 1 | import re 2 | from fraction import Fraction 3 | import numpy as np 4 | 5 | 6 | def is_number(s): 7 | try: 8 | float(s) 9 | return True 10 | except ValueError: 11 | pass 12 | try: 13 | import unicodedata 14 | 15 | unicodedata.numeric(s) 16 | return True 17 | except (TypeError, ValueError): 18 | pass 19 | return False 20 | 21 | 22 | def extract_answer_number(completion): 23 | completion_new = completion.replace("the answer is", "The answer is") 24 | text = completion_new.split("The answer is") 25 | if len(text) > 1: 26 | extract_ans = text[-1].strip() 27 | match = re.search(r"[\-+]?\d*[\.,/]?\d+", extract_ans) 28 | if match: 29 | if "/" in match.group(): 30 | denominator = match.group().split("/")[1] 31 | numerator = match.group().split("/")[0] 32 | if is_number(denominator) == True and is_number(numerator) == True: 33 | if denominator == "0": 34 | return round(float(numerator.replace(",", ""))) 35 | else: 36 | frac = Fraction(match.group().replace(",", "")) 37 | num_numerator = frac.numerator 38 | num_denominator = frac.denominator 39 | return round(float(num_numerator / num_denominator)) 40 | else: 41 | return None 42 | else: 43 | if float(match.group().replace(",", "")) == float("inf"): 44 | return None 45 | return round(float(match.group().replace(",", ""))) 46 | else: 47 | return None 48 | else: 49 | return None 50 | 51 | def evaluation_answers( 52 | true_answers, predicted_answers, eos_tokens=["<|eot_id|>"], print_results=False 53 | ): 54 | assert len(true_answers) == len( 55 | predicted_answers 56 | ), "true answers (size={}) != predicted answers (size={})".format( 57 | len(true_answers), len(predicted_answers) 58 | ) 59 | predictions = [] 60 | skipped_indices = [] 61 | for i in range(len(predicted_answers)): 62 | final_answer = extract_answer_number(predicted_answers[i]) 63 | if final_answer is not None: 64 | predictions.append(final_answer) 65 | else: 66 | skipped_indices.append(i) 67 | 68 | references = [] 69 | for i in range(len(true_answers)): 70 | if i not in skipped_indices: 71 | final_answer = extract_answer_number(true_answers[i]) 72 | references.append(final_answer) 73 | 74 | assert len(predictions) == 
len( 75 | references 76 | ), "predictions (size={}) != references (size={})".format( 77 | len(predictions), len(references) 78 | ) 79 | 80 | evaluations = [] 81 | c = 0 82 | for i in range(len(true_answers)): 83 | if i in skipped_indices: 84 | evaluations.append((None, None, None)) 85 | else: 86 | evaluations.append( 87 | (references[c], predictions[c], references[c] == predictions[c]) 88 | ) 89 | c += 1 90 | if print_results: 91 | print( 92 | "There are {}/{} matched results from the given strings. Acc: {:.4f}".format( 93 | len(predictions), 94 | len(true_answers), 95 | np.mean([x[-1] for x in evaluations if x[-1] is not None]), 96 | ) 97 | ) 98 | return evaluations 99 | -------------------------------------------------------------------------------- /img/gem_vs_ce.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liziniu/GEM/e5b1430979fa12bf8ab7398b2ccc71dff795bbee/img/gem_vs_ce.png -------------------------------------------------------------------------------- /img/gem_with_remax.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liziniu/GEM/e5b1430979fa12bf8ab7398b2ccc71dff795bbee/img/gem_with_remax.png -------------------------------------------------------------------------------- /preprocess_data.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | 4 | os.environ["TOKENIZERS_PARALLELISM"] = "true" 5 | import torch 6 | 7 | from datasets import load_dataset 8 | 9 | from argparse import ArgumentParser 10 | from transformers import AutoTokenizer 11 | from multiprocessing import Pool 12 | from tqdm import tqdm 13 | 14 | parser = ArgumentParser() 15 | parser.add_argument( 16 | "--dataset_name_or_path", 17 | type=str, 18 | default="HuggingFaceH4/ultrafeedback_binarized", 19 | ) 20 | parser.add_argument( 21 | "--split", 22 | type=str, 23 | default="train", 24 | ) 25 | parser.add_argument( 26 | "--start", 27 | type=int, 28 | default=0, 29 | ) 30 | parser.add_argument( 31 | "--end", 32 | type=int, 33 | default=None, 34 | ) 35 | parser.add_argument( 36 | "--output_file", 37 | type=str, 38 | ) 39 | parser.add_argument( 40 | "--tokenizer_name_or_path", 41 | type=str, 42 | required=True 43 | ) 44 | parser.add_argument("--max_seq_length", type=int, default=4096) 45 | parser.add_argument("--preprocessing_num_workers", type=int, default=64) 46 | args = parser.parse_args() 47 | 48 | tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name_or_path) 49 | print(f"load tokenizer from {args.tokenizer_name_or_path} done.") 50 | max_seq_length = args.max_seq_length 51 | 52 | input_data = load_dataset(args.dataset_name_or_path) 53 | if args.split: 54 | input_data = input_data[args.split] 55 | if args.end is None: 56 | args.end = len(input_data) 57 | input_data = input_data.select(range(args.start, args.end)) 58 | print( 59 | f"load input data from {args.dataset_name_or_path} done. len(input_data): {len(input_data)}" 60 | ) 61 | 62 | 63 | def encode_sft_example(example, verbose=False): 64 | """ 65 | This function encodes a single example into a format that can be used for sft training. 66 | Here, we assume each example has a 'messages' field. Each message in it is a dict with 'role' and 'content' fields. 67 | We use the `apply_chat_template` function from the tokenizer to tokenize the messages and prepare the input and label tensors. 
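Only assistant messages contribute to the loss: label positions that belong to non-assistant messages are set to -100. The returned dict has the keys "input_ids", "labels", and "attention_mask", each flattened to a plain Python list.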
68 | """ 69 | messages = example["messages"] 70 | if len(messages) == 0: 71 | raise ValueError("messages field is empty.") 72 | if verbose: 73 | chat_messages = tokenizer.apply_chat_template( 74 | conversation=messages, 75 | tokenize=False, 76 | return_tensors="pt", 77 | padding=False, 78 | truncation=True, 79 | max_length=max_seq_length, 80 | add_generation_prompt=False, 81 | ) 82 | print(f"chat_messages:\n[{chat_messages}]") 83 | input_ids = tokenizer.apply_chat_template( 84 | conversation=messages, 85 | tokenize=True, 86 | return_tensors="pt", 87 | padding=False, 88 | truncation=True, 89 | max_length=max_seq_length, 90 | add_generation_prompt=False, 91 | ) 92 | labels = input_ids.clone() 93 | # mask the non-assistant part for avoiding loss 94 | for message_idx, message in enumerate(messages): 95 | if message["role"] != "assistant": 96 | # we calculate the start index of this non-assistant message 97 | if message_idx == 0: 98 | message_start_idx = 0 99 | else: 100 | message_start_idx = tokenizer.apply_chat_template( 101 | conversation=messages[ 102 | :message_idx 103 | ], # here marks the end of the previous messages 104 | tokenize=True, 105 | return_tensors="pt", 106 | padding=False, 107 | truncation=True, 108 | max_length=max_seq_length, 109 | add_generation_prompt=False, 110 | ).shape[1] 111 | # next, we calculate the end index of this non-assistant message 112 | if ( 113 | message_idx < len(messages) - 1 114 | and messages[message_idx + 1]["role"] == "assistant" 115 | ): 116 | # for intermediate messages that follow with an assistant message, we need to 117 | # set `add_generation_prompt=True` to avoid the assistant generation prefix being included in the loss 118 | # (e.g., `<|assistant|>`) 119 | message_end_idx = tokenizer.apply_chat_template( 120 | conversation=messages[: message_idx + 1], 121 | tokenize=True, 122 | return_tensors="pt", 123 | padding=False, 124 | truncation=True, 125 | max_length=max_seq_length, 126 | add_generation_prompt=True, 127 | ).shape[1] 128 | else: 129 | # for the last message or the message that doesn't follow with an assistant message, 130 | # we don't need to add the assistant generation prefix 131 | message_end_idx = tokenizer.apply_chat_template( 132 | conversation=messages[: message_idx + 1], 133 | tokenize=True, 134 | return_tensors="pt", 135 | padding=False, 136 | truncation=True, 137 | max_length=max_seq_length, 138 | add_generation_prompt=False, 139 | ).shape[1] 140 | # set the label to -100 for the non-assistant part 141 | labels[:, message_start_idx:message_end_idx] = -100 142 | if max_seq_length and message_end_idx >= max_seq_length: 143 | break 144 | attention_mask = torch.ones_like(input_ids) 145 | return { 146 | "input_ids": input_ids.flatten().tolist(), 147 | "labels": labels.flatten().tolist(), 148 | "attention_mask": attention_mask.flatten().tolist(), 149 | } 150 | 151 | 152 | print(encode_sft_example(input_data[0], verbose=True)) 153 | 154 | tokenized_data = [] 155 | with Pool(args.preprocessing_num_workers) as p: 156 | pbar = tqdm(input_data, desc=f"tokenizing") 157 | for tokenized_example in p.imap(encode_sft_example, pbar): 158 | dump = json.dumps(tokenized_example) 159 | tokenized_data.append(dump) 160 | 161 | with open(args.output_file, "w") as fw: 162 | for dump in tokenized_data: 163 | fw.write(dump + "\n") 164 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==2.4.0 2 | 
transformers==4.45.2 3 | datasets==2.14.6 4 | deepspeed==0.15.0 5 | flash-attn==2.6.3 6 | accelerate 7 | vllm==0.6.1 8 | sentence-transformers==3.0.1 9 | nltk 10 | fraction 11 | sacrebleu -------------------------------------------------------------------------------- /scripts/eval/creative_writing.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | set -e 4 | set -x 5 | 6 | export HF_DATASETS_OFFLINE=1 7 | export TRANSFORMERS_OFFLINE=1 8 | export CUDA_VISIBLE_DEVICES="0" 9 | 10 | MODEL_PATH="model-path" 11 | TOKENIZER_PATH="meta-llama/Meta-Llama-3-8B-Instruct" 12 | 13 | SEED=42 14 | N=16 15 | T=1.0 16 | K=50 17 | P=0.9 18 | 19 | 20 | ############################################ 21 | # poem writing 22 | ############################################ 23 | 24 | DATA_PATH="./data/poem_generation" 25 | 26 | python evaluation/generate_response.py \ 27 | --model_name_or_path $MODEL_PATH \ 28 | --tokenizer_path $TOKENIZER_PATH \ 29 | --dataset_path $DATA_PATH \ 30 | --max_size 1000 \ 31 | --seed $SEED \ 32 | --temperature $T \ 33 | --top_k $K \ 34 | --top_p $P \ 35 | --max_new_tokens 512 \ 36 | --n $N \ 37 | --use_vllm True \ 38 | --do_sample True \ 39 | --remove_old True \ 40 | --save_path "${MODEL_PATH}/poem-seed_${SEED}-n_${N}-T_${T}_K_${K}_P_${P}.json" 41 | 42 | 43 | python evaluation/evaluation_diversity.py \ 44 | --tokenizer_path $TOKENIZER_PATH \ 45 | --detokenizer_path $TOKENIZER_PATH \ 46 | --response_path "${MODEL_PATH}/poem-seed_${SEED}-n_${N}-T_${T}_K_${K}_P_${P}.json" \ 47 | 2>&1 | tee ${MODEL_PATH}/diversity_eval-poem-seed_${SEED}-n_${N}-T_${T}_K_${K}_P_${P}.log 48 | 49 | ############################################ 50 | # story writing 51 | ############################################ 52 | 53 | DATA_PATH="./data/story_generation" 54 | 55 | python evaluation/generate_response.py \ 56 | --model_name_or_path $MODEL_PATH \ 57 | --tokenizer_path $TOKENIZER_PATH \ 58 | --dataset_path $DATA_PATH \ 59 | --max_size 500 \ 60 | --seed $SEED \ 61 | --temperature $T \ 62 | --top_k $K \ 63 | --top_p $P \ 64 | --max_new_tokens 512 \ 65 | --n $N \ 66 | --use_vllm True \ 67 | --do_sample True \ 68 | --remove_old True \ 69 | --save_path "${MODEL_PATH}/story-seed_${SEED}-n_${N}-T_${T}_K_${K}_P_${P}.json" 70 | 71 | python evaluation/evaluation_diversity.py \ 72 | --tokenizer_path $TOKENIZER_PATH \ 73 | --detokenizer_path $TOKENIZER_PATH \ 74 | --response_path "${MODEL_PATH}/story-seed_${SEED}-n_${N}-T_${T}_K_${K}_P_${P}.json" \ 75 | 2>&1 | tee ${MODEL_PATH}/diversity_eval-story-seed_${SEED}-n_${N}-T_${T}_K_${K}_P_${P}.log -------------------------------------------------------------------------------- /scripts/eval/gsm8k_eval.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | set -e 4 | set -x 5 | 6 | export HF_DATASETS_OFFLINE=1 7 | export TRANSFORMERS_OFFLINE=1 8 | export CUDA_VISIBLE_DEVICES="0" 9 | 10 | DATA_PATH="gsm8k" 11 | MODEL_PATH="model-path" 12 | TOKENIZER_PATH="meta-llama/Meta-Llama-3-8B-Instruct" 13 | 14 | T=0.0 15 | K=-1 16 | P=1.0 17 | 18 | python evaluation/evaluation_gsm8k.py \ 19 | --model_name_or_path $MODEL_PATH \ 20 | --tokenizer_name_or_path $TOKENIZER_PATH \ 21 | --dataset_name_or_path $DATA_PATH \ 22 | --batch_size 20 \ 23 | --max_new_tokens 512 \ 24 | --use_vllm True \ 25 | --remove_old True \ 26 | --temperature $T \ 27 | --top_p $P \ 28 | --top_k $K \ 29 | --save_path "${MODEL_PATH}/gsm8k_T_${T}_K_${K}_P_${P}.json" \ 30 | 2>&1 | tee 
${MODEL_PATH}/gsm8k_T_${T}_K_${K}_P_${P}.log 31 | -------------------------------------------------------------------------------- /scripts/eval/gsm8k_voting_eval.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | set -e 4 | set -x 5 | 6 | export HF_DATASETS_OFFLINE=1 7 | export TRANSFORMERS_OFFLINE=1 8 | export CUDA_VISIBLE_DEVICES="0" 9 | 10 | DATA_PATH="gsm8k" 11 | MODEL_PATH="model-path" 12 | TOKENIZER_PATH="meta-llama/Meta-Llama-3-8B-Instruct" 13 | 14 | SEED=42 15 | N=32 16 | T=0.6 17 | K=50 18 | P=0.9 19 | 20 | 21 | python evaluation/evaluation_gsm8k_voting.py \ 22 | --model_name_or_path $MODEL_PATH \ 23 | --tokenizer_name_or_path $TOKENIZER_PATH \ 24 | --dataset_name_or_path $DATA_PATH \ 25 | --dtype bf16 \ 26 | --batch_size 128 \ 27 | --max_new_tokens 512 \ 28 | --seed $SEED \ 29 | --n $N \ 30 | --temperature $T \ 31 | --top_k $K \ 32 | --top_p $P \ 33 | --use_vllm True \ 34 | --remove_old True \ 35 | --save_path "${MODEL_PATH}/gsm8k_voting-seed_${SEED}-n_${N}-T_${T}-K_${K}-P_${P}.json" \ 36 | 2>&1 | tee "${MODEL_PATH}/gsm8k_evaluation-seed_${SEED}-n_${N}-T_${T}-K_${K}-P_${P}.log" 37 | -------------------------------------------------------------------------------- /scripts/eval/reward_eval.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | set -e 4 | set -x 5 | 6 | export HF_DATASETS_OFFLINE=1 7 | export TRANSFORMERS_OFFLINE=1 8 | export CUDA_VISIBLE_DEVICES="0" 9 | 10 | DATA_PATH="tatsu-lab/alpaca_eval" 11 | MODEL_PATH="model-path" 12 | TOKENIZER_PATH="meta-llama/Meta-Llama-3-8B-Instruct" 13 | REWARD_MODEL="/sfairXC/FsfairX-LLaMA3-RM-v0.1" 14 | 15 | SEED=42 16 | T=0.6 17 | K=50 18 | P=0.9 19 | N=16 20 | 21 | python evaluation/generate_response.py \ 22 | --model_name_or_path $MODEL_PATH \ 23 | --tokenizer_path $TOKENIZER_PATH \ 24 | --dataset_path $DATA_PATH \ 25 | --max_size 1000 \ 26 | --seed $SEED \ 27 | --temperature $T \ 28 | --top_k $K \ 29 | --top_p $P \ 30 | --max_new_tokens 2048 \ 31 | --n $N \ 32 | --use_vllm True \ 33 | --save_path "${MODEL_PATH}/alpaca_eval-seed_${SEED}-n_${N}-T_${T}-K_${K}-P_${P}.json" 34 | 35 | python evaluation/evaluation_reward.py \ 36 | --model_name_or_path $REWARD_MODEL \ 37 | --batch_size 8 \ 38 | --detokenizer_path $TOKENIZER_PATH \ 39 | --data_path "${MODEL_PATH}/alpaca_eval-seed_${SEED}-n_${N}-T_${T}-K_${K}-P_${P}.json" \ 40 | --save_path "${MODEL_PATH}/alpaca_eval-seed_${SEED}-n_${N}-T_${T}-K_${K}-P_${P}-reward.json" \ 41 | 2>&1 | tee ${MODEL_PATH}/reward_eval-seed_${SEED}-n_${N}-T_${T}-K_${K}-P_${P}.log 42 | -------------------------------------------------------------------------------- /scripts/llama3.1/tokenize_data.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -e 4 | set -x 5 | 6 | export HF_DATASETS_OFFLINE=1 7 | export TRANSFORMERS_OFFLINE=1 8 | export FLASH_ATTENTION_DETERMINISTIC="1" 9 | export CUDA_VISIBLE_DEVICES="0" 10 | 11 | # tokenize train data 12 | python preprocess_data.py \ 13 | --dataset_name_or_path "HuggingFaceH4/ultrafeedback_binarized" \ 14 | --split "train_sft" \ 15 | --tokenizer_name_or_path "meta-llama/Llama-3.1-8B-Instruct" \ 16 | --max_seq_length 2048 \ 17 | --output_file "./data/ultrafeedback_sft_train_llama3.1_tokenized.jsonl" 18 | 19 | # tokenize test data 20 | python preprocess_data.py \ 21 | --dataset_name_or_path "HuggingFaceH4/ultrafeedback_binarized" \ 22 | --split "test_sft" \ 23 | 
--tokenizer_name_or_path "meta-llama/Llama-3.1-8B-Instruct" \ 24 | --max_seq_length 2048 \ 25 | --output_file "./data/ultrafeedback_sft_test_llama3.1_tokenized.jsonl" -------------------------------------------------------------------------------- /scripts/llama3.1/train_ce_ultrafeedback.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -e 4 | set -x 5 | 6 | export HF_DATASETS_OFFLINE=1 7 | export TRANSFORMERS_OFFLINE=1 8 | export FLASH_ATTENTION_DETERMINISTIC="1" 9 | 10 | TRAIN_TOKENIZED_FILE="./data/ultrafeedback_sft_train_llama3.1_tokenized.jsonl" 11 | TEST_TOKENIZED_FILE="./data/ultrafeedback_sft_test_llama3.1_tokenized.jsonl" 12 | 13 | MODEL_NAME_OR_PATH="meta-llama/Llama-3.1-8B" 14 | SEED=1234 15 | 16 | TIME_STEP=`date "+%Y-%m-%d-%H-%M-%S"` 17 | OUTPUT_DIR="./log/sft_ce-llama3.1-8b-ultrafeedback-$TIME_STEP-$SEED" 18 | 19 | mkdir -p $OUTPUT_DIR 20 | 21 | deepspeed train.py \ 22 | --deepspeed scripts/zero2.json \ 23 | --seed $SEED \ 24 | --model_name_or_path $MODEL_NAME_OR_PATH \ 25 | --train_tokenized_file $TRAIN_TOKENIZED_FILE \ 26 | --test_tokenized_file $TEST_TOKENIZED_FILE \ 27 | --output_dir $OUTPUT_DIR \ 28 | --per_device_train_batch_size 4 \ 29 | --gradient_accumulation_steps 4 \ 30 | --evaluation_strategy "epoch" \ 31 | --save_strategy "no" \ 32 | --loss "ce" \ 33 | --learning_rate 2e-5 \ 34 | --lr_scheduler_type cosine \ 35 | --warmup_ratio 0.03 \ 36 | --num_train_epochs 3 \ 37 | --logging_steps 10 \ 38 | --report_to "tensorboard" \ 39 | --gradient_checkpointing True \ 40 | --overwrite_output_dir \ 41 | --bf16 True \ 42 | 2>&1 | tee $OUTPUT_DIR/training.log -------------------------------------------------------------------------------- /scripts/llama3.1/train_gem_ultrafeedback.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -e 4 | set -x 5 | 6 | export HF_DATASETS_OFFLINE=1 7 | export TRANSFORMERS_OFFLINE=1 8 | export FLASH_ATTENTION_DETERMINISTIC="1" 9 | 10 | TRAIN_TOKENIZED_FILE="./data/ultrafeedback_sft_train_llama3.1_tokenized.jsonl" 11 | TEST_TOKENIZED_FILE="./data/ultrafeedback_sft_test_llama3.1_tokenized.jsonl" 12 | 13 | MODEL_NAME_OR_PATH="meta-llama/Llama-3.1-8B" 14 | SEED=1234 15 | 16 | TIME_STEP=`date "+%Y-%m-%d-%H-%M-%S"` 17 | OUTPUT_DIR="./log/sft_ce-llama3.1-8b-ultrafeedback-$TIME_STEP-$SEED" 18 | 19 | mkdir -p $OUTPUT_DIR 20 | 21 | deepspeed train.py \ 22 | --deepspeed scripts/zero2.json \ 23 | --seed $SEED \ 24 | --model_name_or_path $MODEL_NAME_OR_PATH \ 25 | --train_tokenized_file $TRAIN_TOKENIZED_FILE \ 26 | --test_tokenized_file $TEST_TOKENIZED_FILE \ 27 | --output_dir $OUTPUT_DIR \ 28 | --per_device_train_batch_size 4 \ 29 | --gradient_accumulation_steps 4 \ 30 | --evaluation_strategy "epoch" \ 31 | --save_strategy "no" \ 32 | --loss "gem" \ 33 | --gem_beta 0.7 \ 34 | --gem_h "logsigmoid" \ 35 | --learning_rate 2e-5 \ 36 | --lr_scheduler_type cosine \ 37 | --warmup_ratio 0.03 \ 38 | --num_train_epochs 3 \ 39 | --logging_steps 10 \ 40 | --report_to "tensorboard" \ 41 | --gradient_checkpointing True \ 42 | --overwrite_output_dir \ 43 | --bf16 True \ 44 | 2>&1 | tee $OUTPUT_DIR/training.log -------------------------------------------------------------------------------- /scripts/qwen2.5/tokenize_data.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -e 4 | set -x 5 | 6 | export HF_DATASETS_OFFLINE=1 7 | export TRANSFORMERS_OFFLINE=1 8 | export 
FLASH_ATTENTION_DETERMINISTIC="1" 9 | export CUDA_VISIBLE_DEVICES="0" 10 | 11 | # tokenize train data 12 | python preprocess_data.py \ 13 | --dataset_name_or_path "AIMO/NuminaMath-CoT" \ 14 | --split "train" \ 15 | --tokenizer_name_or_path "Qwen/Qwen2.5-Math-7B" \ 16 | --max_seq_length 2048 \ 17 | --output_file "./data/numina_sft_train_qwen2.5_tokenized.jsonl" \ 18 | --start 0 \ 19 | --end 20000 20 | 21 | # tokenize test data 22 | python preprocess_data.py \ 23 | --dataset_name_or_path "AIMO/NuminaMath-CoT" \ 24 | --split "train" \ 25 | --tokenizer_name_or_path "Qwen/Qwen2.5-Math-7B" \ 26 | --max_seq_length 2048 \ 27 | --output_file "./data/numina_sft_test_qwen2.5_tokenized.jsonl" \ 28 | --start 20000 \ 29 | --end 21000 -------------------------------------------------------------------------------- /scripts/qwen2.5/train_ce_numina.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -e 4 | set -x 5 | 6 | export HF_DATASETS_OFFLINE=1 7 | export TRANSFORMERS_OFFLINE=1 8 | export FLASH_ATTENTION_DETERMINISTIC="1" 9 | 10 | TRAIN_TOKENIZED_FILE="./data/numina_sft_train_qwen2.5_tokenized.jsonl" 11 | TEST_TOKENIZED_FILE="./data/numina_sft_test_qwen2.5_tokenized.jsonl" 12 | 13 | MODEL_NAME_OR_PATH="Qwen/Qwen2.5-Math-7B" 14 | SEED=1234 15 | 16 | TIME_STEP=`date "+%Y-%m-%d-%H-%M-%S"` 17 | OUTPUT_DIR="./log/sft_ce-qwen2.5_7b-numina-$TIME_STEP-$SEED" 18 | 19 | mkdir -p $OUTPUT_DIR 20 | 21 | deepspeed train.py \ 22 | --deepspeed scripts/zero3.json \ 23 | --seed $SEED \ 24 | --model_name_or_path $MODEL_NAME_OR_PATH \ 25 | --train_tokenized_file $TRAIN_TOKENIZED_FILE \ 26 | --test_tokenized_file $TEST_TOKENIZED_FILE \ 27 | --output_dir $OUTPUT_DIR \ 28 | --per_device_train_batch_size 8 \ 29 | --gradient_accumulation_steps 2 \ 30 | --evaluation_strategy "epoch" \ 31 | --save_strategy "no" \ 32 | --loss "ce" \ 33 | --learning_rate 2e-5 \ 34 | --lr_scheduler_type cosine \ 35 | --warmup_ratio 0.03 \ 36 | --num_train_epochs 3 \ 37 | --logging_steps 10 \ 38 | --report_to "tensorboard" \ 39 | --gradient_checkpointing True \ 40 | --overwrite_output_dir \ 41 | --bf16 True \ 42 | 2>&1 | tee $OUTPUT_DIR/training.log -------------------------------------------------------------------------------- /scripts/qwen2.5/train_gem_numina.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -e 4 | set -x 5 | 6 | export HF_DATASETS_OFFLINE=1 7 | export TRANSFORMERS_OFFLINE=1 8 | export FLASH_ATTENTION_DETERMINISTIC="1" 9 | 10 | TRAIN_TOKENIZED_FILE="./data/numina_sft_train_qwen2.5_tokenized.jsonl" 11 | TEST_TOKENIZED_FILE="./data/numina_sft_test_qwen2.5_tokenized.jsonl" 12 | 13 | MODEL_NAME_OR_PATH="Qwen/Qwen2.5-Math-7B" 14 | SEED=1234 15 | 16 | TIME_STEP=`date "+%Y-%m-%d-%H-%M-%S"` 17 | OUTPUT_DIR="./log/sft_gem-qwen2.5_7b-numina-$TIME_STEP-$SEED" 18 | 19 | mkdir -p $OUTPUT_DIR 20 | 21 | deepspeed train.py \ 22 | --deepspeed scripts/zero3.json \ 23 | --seed $SEED \ 24 | --model_name_or_path $MODEL_NAME_OR_PATH \ 25 | --train_tokenized_file $TRAIN_TOKENIZED_FILE \ 26 | --test_tokenized_file $TEST_TOKENIZED_FILE \ 27 | --output_dir $OUTPUT_DIR \ 28 | --per_device_train_batch_size 4 \ 29 | --gradient_accumulation_steps 4 \ 30 | --evaluation_strategy "epoch" \ 31 | --save_strategy "no" \ 32 | --loss "gem" \ 33 | --gem_beta 0.7 \ 34 | --gem_h "logsigmoid" \ 35 | --learning_rate 2e-5 \ 36 | --lr_scheduler_type cosine \ 37 | --warmup_ratio 0.03 \ 38 | --num_train_epochs 3 \ 39 | --logging_steps 10 \ 40 | 
--report_to "tensorboard" \ 41 | --gradient_checkpointing True \ 42 | --overwrite_output_dir \ 43 | --bf16 True \ 44 | 2>&1 | tee $OUTPUT_DIR/training.log -------------------------------------------------------------------------------- /scripts/zero2.json: -------------------------------------------------------------------------------- 1 | { 2 | "fp16": { 3 | "enabled": "auto", 4 | "loss_scale": 0, 5 | "loss_scale_window": 1000, 6 | "initial_scale_power": 16, 7 | "hysteresis": 2, 8 | "min_loss_scale": 1 9 | }, 10 | "bf16": { 11 | "enabled": "auto" 12 | }, 13 | "train_micro_batch_size_per_gpu": "auto", 14 | "train_batch_size": "auto", 15 | "gradient_accumulation_steps": "auto", 16 | "zero_optimization": { 17 | "stage": 2, 18 | "overlap_comm": false, 19 | "contiguous_gradients": true, 20 | "sub_group_size": 1e9, 21 | "reduce_bucket_size": "auto" 22 | } 23 | } -------------------------------------------------------------------------------- /scripts/zero3.json: -------------------------------------------------------------------------------- 1 | { 2 | "fp16": { 3 | "enabled": "auto", 4 | "loss_scale": 0, 5 | "loss_scale_window": 1000, 6 | "initial_scale_power": 16, 7 | "hysteresis": 2, 8 | "min_loss_scale": 1 9 | }, 10 | "bf16": { 11 | "enabled": "auto" 12 | }, 13 | "train_micro_batch_size_per_gpu": "auto", 14 | "train_batch_size": "auto", 15 | "gradient_accumulation_steps": "auto", 16 | "zero_optimization": { 17 | "stage": 3, 18 | "overlap_comm": false, 19 | "contiguous_gradients": true, 20 | "sub_group_size": 1e9, 21 | "reduce_bucket_size": "auto", 22 | "stage3_param_persistence_threshold": "auto", 23 | "stage3_max_live_parameters": 1e9, 24 | "stage3_max_reuse_distance": 1e9, 25 | "stage3_gather_16bit_weights_on_model_save": true 26 | } 27 | } -------------------------------------------------------------------------------- /sft_trainer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | from transformers import Trainer 5 | from transformers.trainer import ( 6 | ### 7 | _is_peft_model, 8 | MODEL_FOR_CAUSAL_LM_MAPPING_NAMES, 9 | is_torch_xla_available, 10 | ) 11 | from typing import List, Optional, Dict 12 | from utils.gem_triton_loss import GEMLoss 13 | 14 | 15 | class SFTTrainer(Trainer): 16 | 17 | @torch.no_grad 18 | def compute_training_logs(self, logits, labels): 19 | shift_logits = logits[..., :-1, :] 20 | shift_labels = labels[..., 1:] 21 | 22 | mask = shift_labels != -100 23 | shift_logits = shift_logits[mask] 24 | shift_labels = shift_labels[mask] 25 | 26 | training_logs = {} 27 | if self.args.print_entropy: 28 | entropy = chunked_entropy_from_logits( 29 | shift_logits, 30 | batch_size=max(1, shift_logits.size(0) // 4), 31 | ).mean() 32 | training_logs["entropy"] = round(entropy.item(), 2) 33 | 34 | return training_logs 35 | 36 | def gem_loss(self, logits, labels, beta=0.7, ignore_index=-100, h="logsigmoid"): 37 | shift_logits = logits[..., :-1, :].contiguous() 38 | shift_labels = labels[..., 1:].contiguous() 39 | 40 | mask = shift_labels != -100 41 | shift_logits = shift_logits[mask] 42 | shift_labels = shift_labels[mask] 43 | 44 | with torch.no_grad(): 45 | logits_on_labels = torch.gather( 46 | shift_logits, dim=-1, index=shift_labels.unsqueeze(-1) 47 | ).squeeze(-1) 48 | 49 | logits_diff = shift_logits - logits_on_labels.unsqueeze(-1) 50 | if h == "linear": 51 | weights = torch.ones_like(logits_diff) 52 | elif h == "logsigmoid": 53 | weights = F.sigmoid(0.01 * logits_diff) 54 | else: 55 | 
raise ValueError(h) 56 | 57 | gene_log_probs = F.log_softmax(shift_logits, dim=-1) 58 | q_probs = torch.exp(F.log_softmax(shift_logits / beta, dim=-1)).detach() 59 | 60 | real_log_probs = torch.gather( 61 | gene_log_probs, dim=-1, index=shift_labels.unsqueeze(-1) 62 | ) 63 | 64 | loss = -torch.sum( 65 | q_probs * weights * (real_log_probs - gene_log_probs), dim=-1 66 | ).mean() 67 | 68 | return loss 69 | 70 | def gem_loss_triton(self, logits, labels, beta=0.7, ignore_index=-100, h="linear"): 71 | if h != "linear": 72 | print(f"[warning] only linear is supported for gem_loss_triton for now. Got {h}.") 73 | 74 | gem_loss_func = GEMLoss(beta=beta, ignore_index=ignore_index, reduction="mean") 75 | 76 | shift_logits = logits[..., :-1, :].contiguous() 77 | shift_labels = labels[..., 1:].contiguous() 78 | 79 | mask = shift_labels != -100 80 | shift_logits = shift_logits[mask] 81 | shift_labels = shift_labels[mask] 82 | 83 | loss = gem_loss_func(shift_logits, shift_labels) 84 | 85 | return loss 86 | 87 | # copied from Transformer's trainer with 88 | def compute_loss(self, model, inputs, return_outputs=False): 89 | """ 90 | How the loss is computed by Trainer. By default, all models return the loss in the first element. 91 | 92 | Subclass and override for custom behavior. 93 | """ 94 | if self.label_smoother is not None and "labels" in inputs: 95 | labels = inputs.pop("labels") 96 | else: 97 | labels = None 98 | outputs = model(**inputs) 99 | # Save past state if it exists 100 | # TODO: this needs to be fixed and made cleaner later. 101 | if self.args.past_index >= 0: 102 | self._past = outputs[self.args.past_index] 103 | 104 | if labels is not None: 105 | unwrapped_model = self.accelerator.unwrap_model(model) 106 | if _is_peft_model(unwrapped_model): 107 | model_name = unwrapped_model.base_model.model._get_name() 108 | else: 109 | model_name = unwrapped_model._get_name() 110 | if model_name in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.values(): 111 | loss = self.label_smoother(outputs, labels, shift_labels=True) 112 | else: 113 | loss = self.label_smoother(outputs, labels) 114 | else: 115 | if isinstance(outputs, dict) and "loss" not in outputs: 116 | raise ValueError( 117 | "The model did not return a loss from the inputs, only the following keys: " 118 | f"{','.join(outputs.keys())}. For reference, the inputs it received are {','.join(inputs.keys())}." 119 | ) 120 | # We don't use .loss here since the model may return tuples instead of ModelOutput. 
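# Loss selection: fall back to the model's built-in CE loss during evaluation (or when --loss ce is set); otherwise compute the GEM loss from the shifted logits and labels.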
121 | if self.args.loss == "ce" or self.control.should_evaluate: 122 | loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0] 123 | else: 124 | loss = self.gem_loss( 125 | outputs.logits, 126 | inputs["labels"], 127 | beta=self.args.gem_beta, 128 | h=self.args.gem_h, 129 | ) 130 | 131 | # ziniu add logs 132 | if not self.control.should_evaluate: 133 | self.training_logs = self.compute_training_logs( 134 | outputs.logits, inputs["labels"] 135 | ) 136 | self.training_logs["ce_loss"] = ( 137 | outputs["loss"] if isinstance(outputs, dict) else outputs[0] 138 | ) 139 | self.training_logs["ce_loss"] = round(self.training_logs["ce_loss"].item(), 4) 140 | 141 | return (loss, outputs) if return_outputs else loss 142 | 143 | def _maybe_log_save_evaluate( 144 | self, tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval 145 | ): 146 | if ( 147 | self.control.should_log 148 | and self.state.global_step > self._globalstep_last_logged 149 | ): 150 | if is_torch_xla_available(): 151 | xm.mark_step() 152 | 153 | logs: Dict[str, float] = {} 154 | 155 | # all_gather + mean() to get average loss over all processes 156 | tr_loss_scalar = self._nested_gather(tr_loss).mean().item() 157 | 158 | # reset tr_loss to zero 159 | tr_loss -= tr_loss 160 | 161 | logs["loss"] = round( 162 | tr_loss_scalar 163 | / (self.state.global_step - self._globalstep_last_logged), 164 | 4, 165 | ) 166 | if grad_norm is not None: 167 | logs["grad_norm"] = round( 168 | ( 169 | grad_norm.detach().item() 170 | if isinstance(grad_norm, torch.Tensor) 171 | else grad_norm 172 | ), 173 | 4, 174 | ) 175 | logs["learning_rate"] = self._get_learning_rate() 176 | ### update logs 177 | if getattr(self, "training_logs", {}): 178 | logs.update(getattr(self, "training_logs", {})) 179 | 180 | self._total_loss_scalar += tr_loss_scalar 181 | self._globalstep_last_logged = self.state.global_step 182 | self.store_flos() 183 | 184 | self.log(logs) 185 | 186 | metrics = None 187 | if self.control.should_evaluate: 188 | metrics = self._evaluate(trial, ignore_keys_for_eval) 189 | 190 | if self.control.should_save: 191 | self._save_checkpoint(model, trial, metrics=metrics) 192 | self.control = self.callback_handler.on_save( 193 | self.args, self.state, self.control 194 | ) 195 | 196 | def chunked_entropy_from_logits(chunk_logits, batch_size=None): 197 | """ 198 | Compute entropy from logits in a memory-efficient manner by introducing a batch_size parameter. 199 | 200 | Args: 201 | chunk_logits (torch.Tensor): Logits tensor of shape (total_samples, num_classes). 202 | batch_size (int): Number of samples to process per batch. 203 | 204 | Returns: 205 | torch.Tensor: Entropy tensor of shape (total_samples,). 
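Example (illustrative):
    >>> logits = torch.randn(8, 32000)
    >>> chunked_entropy_from_logits(logits, batch_size=4).shape
    torch.Size([8])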
206 | """ 207 | total_samples, num_classes = chunk_logits.shape 208 | entropy_list = [] 209 | if batch_size is None: 210 | batch_size = total_samples 211 | 212 | # Process logits in batches 213 | for start_idx in range(0, total_samples, batch_size): 214 | end_idx = min(start_idx + batch_size, total_samples) 215 | logits_batch = chunk_logits[start_idx:end_idx] # Get a batch of logits 216 | 217 | # Compute logsumexp for the current batch 218 | logsumexp_batch = torch.logsumexp(logits_batch, dim=-1, keepdim=False) # Shape: (batch_size,) 219 | # Compute probabilities in log-space without computing softmax 220 | normalized_logits = logits_batch - logsumexp_batch.unsqueeze(-1) # Shape: (batch_size, num_classes) 221 | exp_normalized_logits = torch.exp(normalized_logits) # Shape: (batch_size, num_classes) 222 | # Compute entropy for the batch 223 | entropy_batch = logsumexp_batch - (logits_batch * exp_normalized_logits).sum(dim=-1) # Shape: (batch_size,) 224 | 225 | entropy_list.append(entropy_batch) # Store entropy for the current batch 226 | 227 | # Concatenate results from all batches 228 | if len(entropy_list) > 0: 229 | return torch.cat(entropy_list, dim=0) 230 | else: 231 | return torch.tensor(0.0) -------------------------------------------------------------------------------- /sft_trainer_v2.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | from transformers import Trainer 5 | from transformers.trainer import ( 6 | ### 7 | _is_peft_model, 8 | MODEL_FOR_CAUSAL_LM_MAPPING_NAMES, 9 | is_torch_xla_available, 10 | SaveStrategy 11 | ) 12 | from typing import List, Optional, Dict 13 | from utils.gem_triton_loss import GEMLoss 14 | 15 | 16 | class SFTTrainer(Trainer): 17 | 18 | @torch.no_grad 19 | def compute_training_logs(self, logits, labels): 20 | shift_logits = logits[..., :-1, :] 21 | shift_labels = labels[..., 1:] 22 | 23 | mask = shift_labels != -100 24 | shift_logits = shift_logits[mask] 25 | shift_labels = shift_labels[mask] 26 | 27 | training_logs = {} 28 | if self.args.print_entropy: 29 | entropy = chunked_entropy_from_logits( 30 | shift_logits, 31 | batch_size=max(1, shift_logits.size(0) // 4), 32 | ).mean() 33 | training_logs["entropy"] = round(entropy.item(), 2) 34 | 35 | return training_logs 36 | 37 | def gem_loss(self, logits, labels, num_items_in_batch, beta=0.7, ignore_index=-100, h="logsigmoid"): 38 | shift_logits = logits[..., :-1, :].contiguous() 39 | shift_labels = labels[..., 1:].contiguous() 40 | 41 | mask = shift_labels != -100 42 | shift_logits = shift_logits[mask] 43 | shift_labels = shift_labels[mask] 44 | 45 | with torch.no_grad(): 46 | logits_on_labels = torch.gather( 47 | shift_logits, dim=-1, index=shift_labels.unsqueeze(-1) 48 | ).squeeze(-1) 49 | 50 | logits_diff = shift_logits - logits_on_labels.unsqueeze(-1) 51 | if h == "linear": 52 | weights = torch.ones_like(logits_diff) 53 | elif h == "logsigmoid": 54 | weights = F.sigmoid(0.01 * logits_diff) 55 | else: 56 | raise ValueError(h) 57 | 58 | gene_log_probs = F.log_softmax(shift_logits, dim=-1) 59 | q_probs = torch.exp(F.log_softmax(shift_logits / beta, dim=-1)).detach() 60 | 61 | real_log_probs = torch.gather( 62 | gene_log_probs, dim=-1, index=shift_labels.unsqueeze(-1) 63 | ) 64 | 65 | if num_items_in_batch is not None: 66 | loss = -torch.sum( 67 | q_probs * weights * (real_log_probs - gene_log_probs), dim=-1 68 | ).sum() / num_items_in_batch 69 | else: 70 | loss = -torch.sum( 71 | q_probs * weights * 
(real_log_probs - gene_log_probs), dim=-1 72 | ).mean() 73 | 74 | return loss 75 | 76 | def gem_loss_triton(self, logits, labels, num_items_in_batch, beta=0.7, ignore_index=-100, h="linear"): 77 | assert h == "linear", "Only linear is supported for gem_loss_triton for now." 78 | 79 | if num_items_in_batch is not None: 80 | gem_loss_func = GEMLoss(beta=beta, ignore_index=ignore_index, reduction="none") 81 | else: 82 | gem_loss_func = GEMLoss(beta=beta, ignore_index=ignore_index, reduction="mean") 83 | 84 | shift_logits = logits[..., :-1, :].contiguous() 85 | shift_labels = labels[..., 1:].contiguous() 86 | 87 | mask = shift_labels != -100 88 | shift_logits = shift_logits[mask] 89 | shift_labels = shift_labels[mask] 90 | 91 | loss = gem_loss_func(shift_logits, shift_labels) 92 | 93 | if num_items_in_batch is not None: 94 | loss = loss.sum() / num_items_in_batch 95 | else: 96 | loss = loss 97 | return loss 98 | 99 | def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None): 100 | """ 101 | How the loss is computed by Trainer. By default, all models return the loss in the first element. 102 | 103 | Subclass and override for custom behavior. 104 | """ 105 | if (self.label_smoother is not None or self.compute_loss_func is not None) and "labels" in inputs: 106 | labels = inputs.pop("labels") 107 | else: 108 | labels = None 109 | if self.model_accepts_loss_kwargs: 110 | loss_kwargs = {} 111 | if num_items_in_batch is not None: 112 | loss_kwargs["num_items_in_batch"] = num_items_in_batch 113 | inputs = {**inputs, **loss_kwargs} 114 | outputs = model(**inputs) 115 | # Save past state if it exists 116 | # TODO: this needs to be fixed and made cleaner later. 117 | if self.args.past_index >= 0: 118 | self._past = outputs[self.args.past_index] 119 | 120 | if labels is not None: 121 | unwrapped_model = self.accelerator.unwrap_model(model) 122 | if _is_peft_model(unwrapped_model): 123 | model_name = unwrapped_model.base_model.model._get_name() 124 | else: 125 | model_name = unwrapped_model._get_name() 126 | # User-defined compute_loss function 127 | if self.compute_loss_func is not None: 128 | loss = self.compute_loss_func(outputs, labels, num_items_in_batch=num_items_in_batch) 129 | elif model_name in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.values(): 130 | loss = self.label_smoother(outputs, labels, shift_labels=True) 131 | else: 132 | loss = self.label_smoother(outputs, labels) 133 | else: 134 | if isinstance(outputs, dict) and "loss" not in outputs: 135 | raise ValueError( 136 | "The model did not return a loss from the inputs, only the following keys: " 137 | f"{','.join(outputs.keys())}. For reference, the inputs it received are {','.join(inputs.keys())}." 138 | ) 139 | # We don't use .loss here since the model may return tuples instead of ModelOutput. 
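# Loss selection: built-in CE loss during evaluation (or when --loss ce is set); otherwise the PyTorch or Triton GEM loss, normalized by num_items_in_batch when the Trainer supplies it for gradient accumulation.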
140 | if self.args.loss == "ce" or self.control.should_evaluate: 141 | loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0] 142 | elif self.args.loss == "gem": 143 | loss = self.gem_loss( 144 | outputs.logits, 145 | inputs["labels"], 146 | num_items_in_batch=num_items_in_batch, 147 | beta=self.args.gem_beta, 148 | h=self.args.gem_h 149 | ) 150 | elif self.args.loss == "gem_triton": 151 | loss = self.gem_loss_triton( 152 | outputs.logits, 153 | inputs["labels"], 154 | num_items_in_batch=num_items_in_batch, 155 | beta=self.args.gem_beta, 156 | h=self.args.gem_h 157 | ) 158 | 159 | if self.args.average_tokens_across_devices and self.model_accepts_loss_kwargs: 160 | loss *= self.accelerator.num_processes 161 | 162 | # ziniu add logs 163 | if not self.control.should_evaluate: 164 | self.training_logs = self.compute_training_logs( 165 | outputs.logits, inputs["labels"] 166 | ) 167 | self.training_logs["ce_loss"] = ( 168 | outputs["loss"] if isinstance(outputs, dict) else outputs[0] 169 | ) 170 | self.training_logs["ce_loss"] = round(self.training_logs["ce_loss"].item(), 4) 171 | 172 | return (loss, outputs) if return_outputs else loss 173 | 174 | def _maybe_log_save_evaluate(self, tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval, start_time): 175 | if self.control.should_log and self.state.global_step > self._globalstep_last_logged: 176 | if is_torch_xla_available(): 177 | xm.mark_step() 178 | 179 | logs: Dict[str, float] = {} 180 | 181 | # all_gather + mean() to get average loss over all processes 182 | tr_loss_scalar = self._nested_gather(tr_loss).mean().item() 183 | 184 | # reset tr_loss to zero 185 | tr_loss -= tr_loss 186 | 187 | logs["loss"] = round(tr_loss_scalar / (self.state.global_step - self._globalstep_last_logged), 4) 188 | if grad_norm is not None: 189 | logs["grad_norm"] = grad_norm.detach().item() if isinstance(grad_norm, torch.Tensor) else grad_norm 190 | logs["learning_rate"] = self._get_learning_rate() 191 | if getattr(self, "training_logs", None): 192 | logs.update(self.training_logs) 193 | 194 | self._total_loss_scalar += tr_loss_scalar 195 | self._globalstep_last_logged = self.state.global_step 196 | self.store_flos() 197 | 198 | self.log(logs, start_time) 199 | 200 | metrics = None 201 | if self.control.should_evaluate: 202 | metrics = self._evaluate(trial, ignore_keys_for_eval) 203 | is_new_best_metric = self._determine_best_metric(metrics=metrics, trial=trial) 204 | 205 | if self.args.save_strategy == SaveStrategy.BEST: 206 | self.control.should_save = is_new_best_metric 207 | 208 | if self.control.should_save: 209 | self._save_checkpoint(model, trial) 210 | self.control = self.callback_handler.on_save(self.args, self.state, self.control) 211 | 212 | 213 | def chunked_entropy_from_logits(chunk_logits, batch_size=None): 214 | """ 215 | Compute entropy from logits in a memory-efficient manner by introducing a batch_size parameter. 216 | 217 | Args: 218 | chunk_logits (torch.Tensor): Logits tensor of shape (total_samples, num_classes). 219 | batch_size (int): Number of samples to process per batch. 220 | 221 | Returns: 222 | torch.Tensor: Entropy tensor of shape (total_samples,). 
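Note: the per-row entropy is computed as logsumexp(z) - sum_i softmax(z)_i * z_i, evaluated batch by batch so the full probability matrix is never materialized at once.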
223 | """ 224 | total_samples, num_classes = chunk_logits.shape 225 | entropy_list = [] 226 | if batch_size is None: 227 | batch_size = total_samples 228 | 229 | # Process logits in batches 230 | for start_idx in range(0, total_samples, batch_size): 231 | end_idx = min(start_idx + batch_size, total_samples) 232 | logits_batch = chunk_logits[start_idx:end_idx] # Get a batch of logits 233 | 234 | # Compute logsumexp for the current batch 235 | logsumexp_batch = torch.logsumexp(logits_batch, dim=-1, keepdim=False) # Shape: (batch_size,) 236 | # Compute probabilities in log-space without computing softmax 237 | normalized_logits = logits_batch - logsumexp_batch.unsqueeze(-1) # Shape: (batch_size, num_classes) 238 | exp_normalized_logits = torch.exp(normalized_logits) # Shape: (batch_size, num_classes) 239 | # Compute entropy for the batch 240 | entropy_batch = logsumexp_batch - (logits_batch * exp_normalized_logits).sum(dim=-1) # Shape: (batch_size,) 241 | 242 | entropy_list.append(entropy_batch) # Store entropy for the current batch 243 | 244 | # Concatenate results from all batches 245 | if len(entropy_list) > 0: 246 | return torch.cat(entropy_list, dim=0) 247 | else: 248 | return torch.tensor(0.0) -------------------------------------------------------------------------------- /tests/test_gem_loss_triton.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | # Add project root to Python path 5 | import os 6 | import sys 7 | from pathlib import Path 8 | project_root = str(Path(__file__).parent.parent) 9 | if project_root not in sys.path: 10 | sys.path.insert(0, project_root) 11 | 12 | from utils.gem_triton_loss import GEMLoss 13 | 14 | class ReferenceGEMLoss(torch.nn.Module): 15 | def forward(self, logits, labels, beta=0.7, ignore_index=-100, h="linear", reduction="mean"): 16 | """Reference implementation of GEM loss""" 17 | mask = labels != ignore_index 18 | masked_logits = logits[mask] 19 | masked_labels = labels[mask] 20 | 21 | with torch.no_grad(): 22 | logits_on_labels = torch.gather( 23 | masked_logits, dim=-1, index=masked_labels.unsqueeze(-1) 24 | ).squeeze(-1) 25 | logits_diff = masked_logits - logits_on_labels.unsqueeze(-1) 26 | if h == "linear": 27 | weights = torch.ones_like(logits_diff) 28 | else: 29 | raise ValueError(f"Unsupported h function: {h}") 30 | 31 | gene_log_probs = F.log_softmax(masked_logits, dim=-1) 32 | with torch.no_grad(): 33 | q_probs = torch.exp(F.log_softmax(masked_logits / beta, dim=-1)).detach() 34 | 35 | real_log_probs = torch.gather( 36 | gene_log_probs, dim=-1, index=masked_labels.unsqueeze(-1) 37 | ) 38 | 39 | if reduction == "mean": 40 | loss = -torch.sum( 41 | q_probs * weights * (real_log_probs - gene_log_probs), dim=-1 42 | ).mean() 43 | elif reduction == "sum": 44 | loss = -torch.sum( 45 | q_probs * weights * (real_log_probs - gene_log_probs) 46 | ) 47 | else: 48 | raise ValueError(f"Unsupported reduction: {reduction}") 49 | 50 | return loss 51 | 52 | def test_gem_loss(): 53 | # Set random seed for reproducibility 54 | torch.manual_seed(42) 55 | 56 | # Test parameters 57 | batch_size = 10 58 | vocab_size = 120000 59 | beta = 0.7 60 | ignore_index = -100 61 | 62 | # Create random inputs 63 | logits = torch.randn(batch_size, vocab_size, device='cuda', requires_grad=True) 64 | labels = torch.randint(0, vocab_size, (batch_size,), device='cuda') 65 | # Add some ignored indices 66 | # labels[0] = ignore_index 67 | 68 | # Create loss functions 69 | triton_loss_fn = 
GEMLoss(beta=beta, ignore_index=ignore_index, reduction="sum") 70 | ref_loss_fn = ReferenceGEMLoss() 71 | 72 | # Forward pass 73 | triton_loss = triton_loss_fn(logits, labels) 74 | ref_loss = ref_loss_fn(logits, labels, beta=beta, ignore_index=ignore_index, h="linear", reduction="sum") 75 | 76 | test_failed = False 77 | # Check forward pass results 78 | print("*" * 100) 79 | print("triton_loss:", triton_loss) 80 | print("ref_loss:", ref_loss) 81 | print("Forward pass difference:", torch.abs((triton_loss - ref_loss)).mean().item()) 82 | try: 83 | torch.testing.assert_close(triton_loss, ref_loss, rtol=1e-4, atol=1e-4) 84 | except Exception as e: 85 | print(e) 86 | test_failed = True 87 | 88 | # Backward pass 89 | triton_logits = logits.clone().detach().requires_grad_(True) 90 | ref_logits = logits.clone().detach().requires_grad_(True) 91 | 92 | triton_loss = triton_loss_fn(triton_logits, labels) 93 | ref_loss = ref_loss_fn(ref_logits, labels, beta=beta, ignore_index=ignore_index, h="linear", reduction="sum") 94 | 95 | triton_loss.mean().backward() 96 | ref_loss.mean().backward() 97 | 98 | # Check backward pass results 99 | print("*" * 100) 100 | print("Max gradient difference:", torch.max(torch.abs(triton_logits.grad - ref_logits.grad)).item()) 101 | try: 102 | torch.testing.assert_close(triton_logits.grad, ref_logits.grad, rtol=1e-4, atol=1e-4) 103 | except Exception as e: 104 | print(e) 105 | test_failed = True 106 | 107 | if test_failed: 108 | print("Test failed!") 109 | else: 110 | print("All tests passed!") 111 | 112 | if __name__ == "__main__": 113 | test_gem_loss() 114 | -------------------------------------------------------------------------------- /tests/test_gem_loss_triton_distributed.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import torch 4 | import torch.nn.functional as F 5 | import torch.distributed as dist 6 | import torch.multiprocessing as mp 7 | from torch.nn.parallel import DistributedDataParallel as DDP 8 | 9 | # Add project root to Python path 10 | import os 11 | import sys 12 | from pathlib import Path 13 | project_root = str(Path(__file__).parent.parent) 14 | if project_root not in sys.path: 15 | sys.path.insert(0, project_root) 16 | 17 | from utils.gem_triton_loss import GEMLoss as TritonGEMLoss 18 | 19 | class ReferenceGEMLoss(torch.nn.Module): 20 | def forward(self, logits, labels, beta=0.7, ignore_index=-100, h="linear"): 21 | """Reference implementation of GEM loss""" 22 | mask = labels != ignore_index 23 | masked_logits = logits[mask] 24 | masked_labels = labels[mask] 25 | 26 | with torch.no_grad(): 27 | logits_on_labels = torch.gather( 28 | masked_logits, dim=-1, index=masked_labels.unsqueeze(-1) 29 | ).squeeze(-1) 30 | logits_diff = masked_logits - logits_on_labels.unsqueeze(-1) 31 | if h == "linear": 32 | weights = torch.ones_like(logits_diff) 33 | else: 34 | raise ValueError(f"Unsupported h function: {h}") 35 | 36 | gene_log_probs = F.log_softmax(masked_logits, dim=-1) 37 | with torch.no_grad(): 38 | q_probs = torch.exp(F.log_softmax(masked_logits / beta, dim=-1)).detach() 39 | 40 | real_log_probs = torch.gather( 41 | gene_log_probs, dim=-1, index=masked_labels.unsqueeze(-1) 42 | ) 43 | 44 | loss = -torch.sum( 45 | q_probs * weights * (real_log_probs - gene_log_probs), dim=-1 46 | ).mean() 47 | 48 | return loss 49 | 50 | def setup(local_rank=None): 51 | """Initialize the distributed environment.""" 52 | # When using torch.distributed.launch, use the environment variables it sets 53 | 
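# (LOCAL_RANK / WORLD_SIZE / RANK); otherwise fall back to the rank passed in by mp.spawn and a default localhost rendezvous.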
if local_rank is None: 54 | if 'LOCAL_RANK' in os.environ: 55 | local_rank = int(os.environ['LOCAL_RANK']) 56 | world_size = int(os.environ['WORLD_SIZE']) 57 | rank = int(os.environ['RANK']) 58 | else: 59 | raise ValueError("LOCAL_RANK not found in environment variables") 60 | else: 61 | # For mp.spawn path 62 | rank = local_rank 63 | world_size = int(os.environ.get('WORLD_SIZE', '1')) 64 | os.environ['MASTER_ADDR'] = 'localhost' 65 | os.environ['MASTER_PORT'] = '12355' 66 | 67 | # Initialize the process group 68 | dist.init_process_group("nccl", rank=rank, world_size=world_size) 69 | 70 | # Set device for this process 71 | torch.cuda.set_device(local_rank) 72 | 73 | print(f"Rank {rank}/{world_size} initialized") 74 | return rank, world_size 75 | 76 | def cleanup(): 77 | """Clean up the distributed environment.""" 78 | dist.destroy_process_group() 79 | 80 | def run_test(rank=None, world_size=None): 81 | """Run the distributed GEM loss test on a single process.""" 82 | # Initialize distributed environment 83 | if rank is None: # When called directly (not through mp.spawn) 84 | rank, world_size = setup() 85 | else: # When called through mp.spawn 86 | setup(rank) 87 | 88 | # Set random seed for reproducibility 89 | torch.manual_seed(42 + rank) # Different seed per rank 90 | 91 | # Test parameters 92 | batch_size = 100 93 | vocab_size = 102400 94 | beta = 0.7 95 | ignore_index = -100 96 | 97 | # Each rank handles a portion of the vocabulary 98 | local_vocab_size = vocab_size // world_size 99 | 100 | # Create random inputs for this rank 101 | logits = torch.randn(batch_size, local_vocab_size, device=rank, requires_grad=True) 102 | 103 | # All ranks have the same labels (in the range of the full vocabulary) 104 | # We use the same seed for labels to ensure consistency 105 | torch.manual_seed(42) # Same seed for labels across all ranks 106 | labels = torch.randint(0, vocab_size, (batch_size,), device=rank) 107 | 108 | # Create loss function with process group 109 | triton_loss_fn = TritonGEMLoss(beta=beta, ignore_index=ignore_index, process_group=dist.group.WORLD) 110 | 111 | # Forward pass 112 | triton_loss = triton_loss_fn(logits, labels) 113 | 114 | # Basic sanity checks 115 | assert not torch.isnan(triton_loss).any(), f"Rank {rank}: Loss contains NaN values" 116 | assert not torch.isinf(triton_loss).any(), f"Rank {rank}: Loss contains Inf values" 117 | 118 | # Backward pass 119 | triton_loss.mean().backward() 120 | 121 | # Check gradients 122 | assert not torch.isnan(logits.grad).any(), f"Rank {rank}: Gradients contain NaN values" 123 | assert not torch.isinf(logits.grad).any(), f"Rank {rank}: Gradients contain Inf values" 124 | 125 | # Gather all logits to rank 0 for verification (optional) 126 | if rank == 0: 127 | all_logits = [torch.zeros_like(logits) for _ in range(world_size)] 128 | else: 129 | all_logits = None 130 | 131 | dist.gather(logits, all_logits if rank == 0 else None, dst=0) 132 | 133 | # On rank 0, verify the loss against reference implementation 134 | if rank == 0: 135 | print(f"Distributed test on {world_size} GPUs:") 136 | print(f"Triton GEM loss: {triton_loss.mean().item()}") 137 | 138 | # Concatenate all logits to get the full vocabulary 139 | full_logits = torch.cat(all_logits, dim=1).detach().requires_grad_(True) 140 | 141 | # Compute reference loss 142 | ref_loss_fn = ReferenceGEMLoss() 143 | ref_loss = ref_loss_fn(full_logits, labels, beta=beta, ignore_index=ignore_index) 144 | 145 | print(f"Reference loss: {ref_loss.item()}") 146 | print(f"Difference: 
{abs(triton_loss.mean().item() - ref_loss.item())}") 147 | 148 | # Note: We don't expect exact match due to distributed computation differences 149 | print("Distributed test completed successfully!") 150 | 151 | # Wait for all processes 152 | dist.barrier() 153 | 154 | # Clean up 155 | cleanup() 156 | 157 | def main(): 158 | """Main function to launch the distributed test.""" 159 | # Check if CUDA is available 160 | if not torch.cuda.is_available(): 161 | print("CUDA not available. Skipping distributed test.") 162 | return 163 | 164 | # Check if we're being launched by torch.distributed.launch 165 | if 'LOCAL_RANK' in os.environ: 166 | # We're being launched by torch.distributed.launch, so just run the test 167 | run_test() 168 | else: 169 | # Manual launch with mp.spawn 170 | world_size = int(os.environ.get("WORLD_SIZE", 2)) 171 | 172 | # Ensure we have enough GPUs 173 | if torch.cuda.device_count() < world_size: 174 | print(f"Not enough GPUs available. Need {world_size}, found {torch.cuda.device_count()}") 175 | world_size = torch.cuda.device_count() 176 | print(f"Reducing world_size to {world_size}") 177 | 178 | if world_size < 2: 179 | print("Need at least 2 GPUs for a meaningful distributed test") 180 | return 181 | 182 | print(f"Running distributed test with world_size={world_size}") 183 | 184 | # Spawn processes 185 | mp.spawn(run_test, args=(world_size,), nprocs=world_size, join=True) 186 | 187 | if __name__ == "__main__": 188 | main() -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | """ 4 | This file is modified from the huggingface example for finetuning language models 5 | [run_clm.py](https://github.com/huggingface/transformers/blob/main/examples/pytorch/language-modeling/run_clm.py) 6 | """ 7 | 8 | import logging 9 | import os 10 | 11 | os.environ["TOKENIZERS_PARALLELISM"] = "true" 12 | import sys 13 | from typing import Optional 14 | from functools import partial 15 | import datasets 16 | import torch 17 | import torch.distributed as dist 18 | import deepspeed 19 | from datasets import load_dataset 20 | from torch.utils.data import Dataset 21 | from dataclasses import dataclass, field 22 | from typing import Optional, List, Union 23 | 24 | import transformers 25 | from transformers import ( 26 | AutoModelForCausalLM, 27 | AutoTokenizer, 28 | HfArgumentParser, 29 | DataCollatorForSeq2Seq, 30 | set_seed, 31 | ) 32 | from transformers.trainer_utils import get_last_checkpoint 33 | 34 | from packaging import version 35 | 36 | if version.parse(transformers.__version__) >= version.parse("4.46.0"): 37 | from sft_trainer_v2 import SFTTrainer 38 | else: 39 | from sft_trainer import SFTTrainer 40 | 41 | logging.basicConfig(level=logging.INFO) 42 | logger = logging.getLogger(__name__) 43 | 44 | 45 | @dataclass 46 | class TrainingArguments(transformers.TrainingArguments): 47 | adam_beta2: float = field(default=0.95, metadata={"help": "Beta2 for AdamW"}) 48 | loss: str = field( 49 | default="gem", metadata={"help": "Loss name", "choices": ["gem", "ce", "gem_triton"]} 50 | ) 51 | gem_beta: float = field(default=0.7, metadata={"help": "Hyper-parameter in GEM. A value between 0 and 1. A value close to 1.0 makes GEM behave more like CE, while a value close to 0.0 preserves more diversity."}) 52 | gem_h: str = field( 53 | default="linear", metadata={"help": "Function $h$ in GEM. 
The 'logsigmoid' function is more adaptive, but the difference between 'logsigmoid' and 'linear' is usually negligible.", "choices": ["logsigmoid", "linear"]} 54 | ) 55 | print_entropy: bool = field( 56 | default=False, metadata={"help": "Print entropy during training"} 57 | ) 58 | 59 | 60 | @dataclass 61 | class ModelArguments: 62 | model_name_or_path: str = field( 63 | metadata={ 64 | "help": "Path to pretrained model or model identifier from huggingface.co/models" 65 | } 66 | ) 67 | cache_dir: Optional[str] = field( 68 | default=None, 69 | metadata={ 70 | "help": "Where do you want to store the pretrained models downloaded from huggingface.co" 71 | }, 72 | ) 73 | use_flash_attn: bool = field( 74 | default=True, 75 | metadata={"help": "Overwrite the cached training and evaluation sets"}, 76 | ) 77 | 78 | 79 | @dataclass 80 | class DataArguments: 81 | train_tokenized_file: str = field( 82 | default=None, metadata={"help": "huggingface dataset name or local data path"} 83 | ) 84 | test_tokenized_file: str = field( 85 | default=None, metadata={"help": "huggingface dataset name or local data path"} 86 | ) 87 | max_train_samples: Optional[int] = field( 88 | default=None, 89 | metadata={ 90 | "help": ( 91 | "For debugging purposes or quicker training, truncate the number of training examples to this " 92 | "value if set." 93 | ) 94 | }, 95 | ) 96 | max_seq_length: Optional[int] = field( 97 | default=None, 98 | metadata={ 99 | "help": ( 100 | "The maximum total input sequence length after tokenization. Sequences longer than this will be truncated," 101 | ) 102 | }, 103 | ) 104 | overwrite_cache: bool = field( 105 | default=False, 106 | metadata={"help": "Overwrite the cached training and evaluation sets"}, 107 | ) 108 | 109 | 110 | class CustomDataset(Dataset): 111 | def __init__( 112 | self, 113 | training_args, 114 | data_args, 115 | model_args, 116 | train_tokenized_file, 117 | ): 118 | self.training_args = training_args 119 | self.data_args = data_args 120 | self.model_args = model_args 121 | 122 | raw_datasets = load_dataset( 123 | "json", 124 | data_files=[train_tokenized_file], 125 | cache_dir=self.model_args.cache_dir, 126 | ) 127 | self.data = raw_datasets["train"] 128 | 129 | if self.data_args.max_train_samples is not None: 130 | max_samples = min(len(self.data), self.data_args.max_train_samples) 131 | self.data = self.data.select(range(max_samples)) 132 | 133 | def __len__(self): 134 | return len(self.data) 135 | 136 | def __getitem__(self, item): 137 | example = self.data[item] 138 | assert "input_ids" in example 139 | assert "labels" in example 140 | example = {k: torch.tensor(v, dtype=torch.long) for k, v in example.items()} 141 | return example 142 | 143 | 144 | def main(): 145 | parser = HfArgumentParser((ModelArguments, DataArguments, TrainingArguments)) 146 | if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): 147 | model_args, data_args, training_args = parser.parse_json_file( 148 | json_file=os.path.abspath(sys.argv[1]) 149 | ) 150 | else: 151 | model_args, data_args, training_args = parser.parse_args_into_dataclasses() 152 | 153 | # Setup logging 154 | logging.basicConfig( 155 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 156 | datefmt="%m/%d/%Y %H:%M:%S", 157 | handlers=[logging.StreamHandler(sys.stdout)], 158 | ) 159 | 160 | if training_args.should_log: 161 | # The default of training_args.log_level is passive, so we set log level at info here to have that default. 
162 | transformers.utils.logging.set_verbosity_info() 163 | 164 | log_level = training_args.get_process_log_level() 165 | logger.setLevel(log_level) 166 | datasets.utils.logging.set_verbosity(log_level) 167 | transformers.utils.logging.set_verbosity(log_level) 168 | transformers.utils.logging.enable_default_handler() 169 | transformers.utils.logging.enable_explicit_format() 170 | 171 | # Log on each process the small summary: 172 | global_rank = dist.get_rank() 173 | logger.warning( 174 | f"Process rank: {global_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}" 175 | ) 176 | logger.info(f"Training parameters {training_args}") 177 | 178 | # Set seed before initializing model. 179 | set_seed(training_args.seed) 180 | 181 | tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path) 182 | if "llama-3" in tokenizer.name_or_path.lower() and tokenizer.pad_token is None: 183 | tokenizer.pad_token_id = len(tokenizer) - 1 184 | tokenizer.pad_token = tokenizer.decode(tokenizer.pad_token_id) 185 | 186 | model = AutoModelForCausalLM.from_pretrained( 187 | model_args.model_name_or_path, 188 | torch_dtype="auto", 189 | attn_implementation=( 190 | "flash_attention_2" if model_args.use_flash_attn else "eager" 191 | ), 192 | ) 193 | 194 | # We resize the embeddings only when necessary to avoid index errors. If you are creating a model from scratch 195 | # on a small vocab and want a smaller embedding size, remove this test. 196 | # gather deepspeed to get "real" embedding size 197 | embeddings = model.get_input_embeddings() 198 | with deepspeed.zero.GatheredParameters(embeddings.weight, modifier_rank=None): 199 | embedding_size = embeddings.weight.shape[0] 200 | # resize does its own gather 201 | if len(tokenizer) > embedding_size: 202 | # pad to multiple for tensor cores. 203 | logging.warning(f"len(tokenizer) > embedding_size!!! 
we are resizing...") 204 | model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=8) 205 | 206 | # set up datasets 207 | train_dataset = CustomDataset(training_args, data_args, model_args, data_args.train_tokenized_file) 208 | if data_args.test_tokenized_file: 209 | test_dataset = CustomDataset(training_args, data_args, model_args, data_args.test_tokenized_file) 210 | else: 211 | test_dataset = None 212 | 213 | # initialize a trainer 214 | # here we use a custom trainer that moves the model to CPU when saving the checkpoint in FSDP mode 215 | # we can switch to the default trainer after moving to deepspeed (let's not change too much for now) 216 | 217 | trainer = SFTTrainer( 218 | model=model, 219 | args=training_args, 220 | train_dataset=train_dataset, 221 | eval_dataset=test_dataset, 222 | tokenizer=tokenizer, 223 | data_collator=DataCollatorForSeq2Seq( 224 | tokenizer=tokenizer, model=model, padding="longest" 225 | ), 226 | preprocess_logits_for_metrics=None, 227 | compute_metrics=None, 228 | ) 229 | 230 | # Training 231 | logger.info("*** Train ***") 232 | checkpoint = None 233 | if training_args.resume_from_checkpoint is not None: 234 | checkpoint = training_args.resume_from_checkpoint 235 | train_result = trainer.train(resume_from_checkpoint=checkpoint) 236 | if "llama-3" in model.config.name_or_path.lower() and isinstance(model.generation_config.eos_token_id, int): 237 | model.generation_config.eos_token_id = [128001, 128009] 238 | trainer.save_model() # Saves the tokenizer too for easy upload 239 | 240 | metrics = train_result.metrics 241 | metrics["train_samples"] = len(train_dataset) 242 | trainer.log_metrics("train", metrics) 243 | trainer.save_metrics("train", metrics) 244 | 245 | 246 | if __name__ == "__main__": 247 | main() 248 | -------------------------------------------------------------------------------- /utils/README.md: -------------------------------------------------------------------------------- 1 | # Triton Implementation of GEM Loss 2 | 3 | This folder contains the Triton implementation of GEM loss. 4 | 5 | ## Test 6 | 7 | We have tested the implementation against the PyTorch reference implementation of GEM loss (see the test scripts below). 8 | 9 | To run the tests, use the following commands: 10 | 11 | ```bash 12 | python tests/test_gem_loss_triton.py 13 | python tests/test_gem_loss_triton_distributed.py 14 | ``` 15 | 16 | Please contact Ziniu Li (ziniuli@link.cuhk.edu.cn) if you find any issues. 17 | 18 | 19 | ## To Do 20 | 21 | - [ ] Add the implementation of GEM with $h = logsigmoid$ (currently only $h = linear$ is supported). 22 | - [ ] Add more tests. 23 | 24 | 25 | ## Acknowledgement 26 | 27 | We thank the authors of [flash-attention](https://github.com/Dao-AILab/flash-attention) for providing the Triton implementation of CE loss, on which the GEM loss implementation is based. 28 | -------------------------------------------------------------------------------- /utils/gem_triton_loss.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025, Ziniu Li. 
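# Example usage (an illustrative sketch: the batch size and vocab size below are arbitrary,
# and a CUDA device with Triton installed is assumed):
#
#     import torch
#     from utils.gem_triton_loss import GEMLoss
#
#     loss_fn = GEMLoss(beta=0.7, ignore_index=-100, reduction="mean")
#     logits = torch.randn(8, 32000, device="cuda", requires_grad=True)  # (batch, vocab_size)
#     labels = torch.randint(0, 32000, (8,), device="cuda")              # (batch,)
#     loss = loss_fn(logits, labels)  # scalar, averaged over non-ignored positions
#     loss.backward()                 # gradients flow back into `logits`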
2 | # The code is modified from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/losses/cross_entropy.py 3 | 4 | 5 | import torch 6 | import torch.nn as nn 7 | 8 | from .gem_triton_ops import gem_loss 9 | 10 | 11 | class GEMLoss(nn.Module): 12 | def __init__( 13 | self, 14 | ignore_index=-100, 15 | reduction="mean", 16 | beta=1.0, 17 | logit_scale=1.0, 18 | lse_square_scale=0.0, 19 | inplace_backward=False, 20 | process_group=None, 21 | return_z_loss=False, 22 | ): 23 | """ 24 | Arguments: 25 | ignore_index: int. If labels == ignore_index, the loss is set to 0.0. 26 | beta: float 27 | lse_square_scale: float. If > 0, we add lse_square_scale * lse(logits) ^ 2 to the loss. 28 | This is also referred to as "z-loss". 29 | inplace_backward: bool. If True, we do the backward pass in-place by modifying the logits. 30 | This saves memory. 31 | process_group: if not None, we're doing Tensor Parallel: each process is responsible for 32 | one part of the vocab. The loss will be aggregated across processes. 33 | return_z_loss: bool. If True, we return the component of the loss contributed by 34 | the lse_square_scale value. This value is only for logging and does not support 35 | backprop. 36 | """ 37 | super().__init__() 38 | if reduction not in ["mean", "none", "sum"]: 39 | raise NotImplementedError("Only support reduction = 'mean' or 'none' or 'sum'") 40 | self.ignore_index = ignore_index 41 | self.reduction = reduction 42 | self.beta = beta 43 | self.logit_scale = logit_scale 44 | self.lse_square_scale = lse_square_scale 45 | self.inplace_backward = inplace_backward 46 | self.process_group = process_group 47 | self.return_z_loss = return_z_loss 48 | 49 | def forward(self, input, target, precomputed_lse=None): 50 | """ 51 | Arguments: 52 | input: (batch, vocab_size) 53 | target: (batch,) 54 | Returns: 55 | losses: (batch,) if reduction is 'none', else (1,), dtype float 56 | z_loss: (batch,) if reduction is 'none', else (1,), dtype float (if self.return_z_loss) 57 | """ 58 | assert input.is_cuda and target.is_cuda, "Only support CUDA tensors" 59 | loss, z_loss = gem_loss( 60 | input, 61 | target, 62 | precomputed_lse=precomputed_lse, 63 | beta=self.beta, 64 | logit_scale=self.logit_scale, 65 | lse_square_scale=self.lse_square_scale, 66 | ignore_index=self.ignore_index, 67 | inplace_backward=self.inplace_backward, 68 | process_group=self.process_group, 69 | ) 70 | if self.reduction == "mean": 71 | loss = loss.sum() / (target != self.ignore_index).sum() 72 | elif self.reduction == "sum": 73 | loss = loss.sum() 74 | else: 75 | loss = loss 76 | 77 | if not self.return_z_loss: 78 | return loss 79 | 80 | if self.reduction == "mean": 81 | z_loss = z_loss.sum() / (target != self.ignore_index).sum() 82 | elif self.reduction == "sum": 83 | z_loss = z_loss.sum() 84 | else: 85 | z_loss = z_loss 86 | 87 | return loss, z_loss 88 | -------------------------------------------------------------------------------- /utils/gem_triton_ops.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025, Ziniu Li. 2 | # The code is modified from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/ops/triton/cross_entropy.py 3 | 4 | from typing import Tuple, Optional, Union 5 | 6 | import torch 7 | import torch.nn.functional as F 8 | 9 | import triton 10 | import triton.language as tl 11 | 12 | # `all_gather_into_tensor` and `reduce_scatter_tensor` are new placeholders for 13 | # `_all_gather_base` and `_reduce_scatter_base`. 
They require the most recent 14 | # version of PyTorch. The following 2 lines are for backward compatibility with 15 | # older PyTorch. 16 | if "all_gather_into_tensor" not in dir(torch.distributed): 17 | torch.distributed.all_gather_into_tensor = torch.distributed._all_gather_base 18 | 19 | @triton.jit 20 | def gem_fwd_kernel( 21 | loss_ptr, # data ptrs 22 | lse_ptr, 23 | z_loss_ptr, 24 | logits_ptr, 25 | labels_ptr, 26 | beta, 27 | logit_scale, 28 | lse_square_scale, 29 | ignore_index, 30 | total_classes, 31 | class_start_idx, # Useful for tensor parallel when each rank only has a subset of classes 32 | n_cols, # shapes 33 | logits_row_stride, # strides 34 | BLOCK_SIZE: tl.constexpr, 35 | # if SPLIT (e.g. tensor parallel), don't include the LSE in the loss since it's not the final LSE 36 | SPLIT: tl.constexpr, 37 | PRECOMPUTED_LSE: tl.constexpr, # If LSE is already computed (also no smoothing and logit_scale == 1.0) 38 | ): 39 | # GEM Loss (h = linear) 40 | # Loss = -log p(y|x) + sum_y q(y|x) * log p(y|x) 41 | # q (y|x) = softmax (1 / beta * log p(y|x)) 42 | # Note that the first term is the same as the CE loss. 43 | 44 | # Prepare for calculating lse. 45 | row_idx = tl.program_id(0) 46 | logits_ptr = logits_ptr + row_idx * logits_row_stride.to(tl.int64) 47 | 48 | if not PRECOMPUTED_LSE: 49 | # Statistics for online softmax 50 | m_i = -float("inf") 51 | l_i = 0.0 52 | for col_offset in range(0, n_cols, BLOCK_SIZE): 53 | cols = col_offset + tl.arange(0, BLOCK_SIZE) 54 | logits = tl.load(logits_ptr + cols, mask=cols < n_cols, other=-float("inf")).to( 55 | tl.float32 56 | ) * logit_scale 57 | m_i_new = tl.maximum(m_i, tl.max(logits)) 58 | l_i = tl.exp(m_i - m_i_new) * l_i + tl.sum(tl.exp(logits - m_i_new)) 59 | m_i = m_i_new 60 | lse = tl.log(l_i) + m_i 61 | tl.store(lse_ptr + row_idx, lse) 62 | else: 63 | lse = tl.load(lse_ptr + row_idx) 64 | label_idx = tl.load(labels_ptr + row_idx) 65 | 66 | # Second term: q-regularized loss 67 | m_q = -float("inf") # running max 68 | s_q = 0.0 # running sum for denominator 69 | for col_offset in range(0, n_cols, BLOCK_SIZE): 70 | cols = col_offset + tl.arange(0, BLOCK_SIZE) 71 | logits = tl.load(logits_ptr + cols, mask=cols < n_cols, other=-float("inf")).to( 72 | tl.float32 73 | ) * logit_scale 74 | logits_scaled = logits / beta 75 | 76 | # Update running max and rescale previous sum 77 | m_q_prev = m_q 78 | m_q = tl.maximum(m_q, tl.max(logits_scaled)) 79 | s_q = tl.exp(m_q_prev - m_q) * s_q 80 | 81 | # Add contribution from current block 82 | numerator = tl.exp(logits_scaled - m_q) 83 | s_q += tl.sum(tl.where(cols < n_cols, numerator, 0.0)) 84 | 85 | # Second pass: compute q_loss using final m_q and s_q 86 | q_loss = 0.0 87 | for col_offset in range(0, n_cols, BLOCK_SIZE): 88 | cols = col_offset + tl.arange(0, BLOCK_SIZE) 89 | logits = tl.load(logits_ptr + cols, mask=cols < n_cols, other=-float("inf")).to( 90 | tl.float32 91 | ) * logit_scale 92 | logits_scaled = logits / beta 93 | 94 | # Compute q_probs using final normalization constants 95 | q_probs = tl.exp(logits_scaled - m_q) / s_q 96 | q_loss += tl.sum(tl.where(cols < n_cols, q_probs * (logits - lse), 0.0)) 97 | 98 | # Compute CE loss term (same as before) 99 | label_idx = tl.load(labels_ptr + row_idx) 100 | if label_idx == ignore_index: 101 | loss = 0.0 102 | z_loss = 0.0 103 | else: 104 | label_idx -= class_start_idx 105 | if label_idx >= 0 and label_idx < n_cols: 106 | logits_label = tl.load(logits_ptr + label_idx) * logit_scale 107 | # GEM loss = CE loss + q_loss 108 | loss = (lse if not 
SPLIT else 0.0) - logits_label + q_loss 109 | else: 110 | loss = q_loss 111 | if not SPLIT: 112 | z_loss = lse_square_scale * lse * lse 113 | loss += z_loss 114 | else: 115 | z_loss = 0.0 116 | tl.store(loss_ptr + row_idx, loss) 117 | if not SPLIT: 118 | tl.store(z_loss_ptr + row_idx, z_loss) 119 | 120 | 121 | @triton.jit 122 | def gem_bwd_kernel( 123 | dlogits_ptr, # data ptrs 124 | dloss_ptr, 125 | logits_ptr, 126 | lse_ptr, 127 | labels_ptr, 128 | beta, 129 | logit_scale, 130 | lse_square_scale, 131 | ignore_index, 132 | total_classes, 133 | class_start_idx, # Useful for tensor parallel when each rank only has a subset of classes 134 | n_cols, # shapes 135 | logits_row_stride, # strides 136 | dlogits_row_stride, 137 | dloss_row_stride, 138 | BLOCK_SIZE: tl.constexpr, 139 | ): 140 | row_idx = tl.program_id(0) 141 | col_block_idx = tl.program_id(1) 142 | logits_ptr = logits_ptr + row_idx * logits_row_stride.to(tl.int64) 143 | dlogits_ptr = dlogits_ptr + row_idx * dlogits_row_stride.to(tl.int64) 144 | col_offsets = col_block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) 145 | label_idx = tl.load(labels_ptr + row_idx) 146 | if label_idx != ignore_index: 147 | dloss = tl.load(dloss_ptr + row_idx * dloss_row_stride) 148 | else: 149 | dloss = 0.0 150 | logits = tl.load(logits_ptr + col_offsets, mask=col_offsets < n_cols, other=-float("inf")).to( 151 | tl.float32 152 | ) * logit_scale 153 | lse = tl.load(lse_ptr + row_idx) 154 | 155 | # GEM: gradient of y w.r.t. logits = q(y|x) if y != label else q(y|x) - 1. 156 | # First pass: compute max and sum for proper softmax normalization 157 | m_q = -float("inf") 158 | s_q = 0.0 159 | for offset in range(0, n_cols, BLOCK_SIZE): 160 | cols = offset + tl.arange(0, BLOCK_SIZE) 161 | logits_block = tl.load(logits_ptr + cols, mask=cols < n_cols, other=-float("inf")).to(tl.float32) * logit_scale 162 | logits_scaled = logits_block / beta 163 | m_q = tl.maximum(m_q, tl.max(logits_scaled)) 164 | 165 | # Second pass: compute sum with stable numerics 166 | for offset in range(0, n_cols, BLOCK_SIZE): 167 | cols = offset + tl.arange(0, BLOCK_SIZE) 168 | logits_block = tl.load(logits_ptr + cols, mask=cols < n_cols, other=-float("inf")).to(tl.float32) * logit_scale 169 | logits_scaled = logits_block / beta 170 | numerator = tl.exp(logits_scaled - m_q) 171 | s_q += tl.sum(tl.where(cols < n_cols, numerator, 0.0)) 172 | 173 | # Final pass: compute gradients for current block 174 | col_offsets = col_block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) 175 | logits = tl.load(logits_ptr + col_offsets, mask=col_offsets < n_cols, other=-float("inf")).to(tl.float32) * logit_scale 176 | logits_scaled = logits / beta 177 | probs = tl.exp(logits_scaled - m_q) / s_q 178 | 179 | probs += 2.0 * lse_square_scale * lse * probs 180 | label_idx -= class_start_idx 181 | probs = tl.where(col_offsets == label_idx, probs - 1.0, probs) 182 | tl.store(dlogits_ptr + col_offsets, (dloss * logit_scale) * probs, mask=col_offsets < n_cols) 183 | 184 | 185 | class GEMLoss(torch.autograd.Function): 186 | 187 | @staticmethod 188 | def forward( 189 | ctx, 190 | logits, 191 | labels, 192 | precomputed_lse=None, 193 | beta=1.0, 194 | logit_scale=1.0, 195 | lse_square_scale=0.0, 196 | ignore_index=-100, 197 | inplace_backward=False, 198 | process_group=None, 199 | ): 200 | # For some reason Triton generates wrong code when labels has dtype long and its address 201 | # is not aligned to 16 bytes. The ld.global.b64 seems to load the wrong label index. 
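        # Padding with F.pad below copies `labels` into a freshly allocated, contiguous
        # buffer (new CUDA allocations are well aligned), and slicing off the trailing pad
        # element keeps the original values; the assert re-checks the 16-byte alignment.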
202 | if labels.dtype == torch.long and labels.data_ptr() % 16 != 0: 203 | labels = F.pad(labels, (0, 1))[..., :-1] 204 | assert labels.data_ptr() % 16 == 0 205 | assert logit_scale > 0.0 206 | n_rows, n_cols = logits.shape 207 | assert labels.shape == (n_rows,) 208 | world_size = 1 if process_group is None else torch.distributed.get_world_size(process_group) 209 | total_classes = world_size * n_cols 210 | rank = 0 if process_group is None else torch.distributed.get_rank(process_group) 211 | class_start_idx = rank * n_cols 212 | use_precomputed_lse = precomputed_lse is not None and logit_scale == 1.0 213 | 214 | if logits.stride(-1) != 1: 215 | logits = logits.contiguous() 216 | MAX_BLOCK_SIZE = 16 * 1024 217 | BLOCK_SIZE = min(triton.next_power_of_2(n_cols), MAX_BLOCK_SIZE) 218 | num_warps = ( 219 | 4 220 | if BLOCK_SIZE < 2048 221 | else (8 if BLOCK_SIZE < 8192 else (16 if BLOCK_SIZE < 128 * 1024 else 32)) 222 | ) 223 | losses = torch.empty(n_rows, dtype=torch.float, device=logits.device) 224 | if use_precomputed_lse: 225 | assert precomputed_lse.shape == (n_rows,) 226 | lse = precomputed_lse.contiguous() 227 | else: 228 | lse = torch.empty(n_rows, dtype=torch.float, device=logits.device) 229 | z_losses = torch.empty(n_rows, dtype=torch.float, device=logits.device) 230 | # Need this, otherwise Triton tries to launch from cuda:0 and we get 231 | # ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?) 232 | with torch.cuda.device(logits.device.index): 233 | gem_fwd_kernel[(n_rows,)]( 234 | losses, # data ptrs 235 | lse, 236 | z_losses, 237 | logits, 238 | labels, 239 | beta, 240 | logit_scale, 241 | lse_square_scale, 242 | ignore_index, 243 | total_classes, 244 | class_start_idx, 245 | n_cols, # shapes 246 | logits.stride(0), # strides 247 | BLOCK_SIZE=BLOCK_SIZE, # constants 248 | SPLIT=world_size > 1, 249 | PRECOMPUTED_LSE=use_precomputed_lse, 250 | num_warps=num_warps, 251 | ) 252 | 253 | if world_size > 1: 254 | # For GEM loss, if labels are in the vocab of this partition, losses contains 255 | # - predicted logit + q_loss for this partition, and 0 otherwise. 256 | # For labels not in the vocab of this partition, losses contains 257 | # only q_loss for this partition 258 | lse_allgather = torch.empty(world_size, n_rows, dtype=lse.dtype, device=lse.device) 259 | torch.distributed.all_gather_into_tensor(lse_allgather, lse, group=process_group) 260 | handle_losses = torch.distributed.all_reduce( 261 | losses, op=torch.distributed.ReduceOp.SUM, group=process_group, async_op=True 262 | ) 263 | lse = torch.logsumexp(lse_allgather, dim=0) 264 | handle_losses.wait() 265 | # After the allreduce: 266 | # 1. For labels in any partition, we have the sum of all q_losses from each partition 267 | # 2. 
For the partition containing the label, we also have -predicted_logit 268 | # We just need to add the global lse to complete the GEM loss 269 | losses += lse 270 | if lse_square_scale != 0.0: 271 | z_losses = lse_square_scale * lse.square() 272 | z_losses.masked_fill_(labels == ignore_index, 0.0) 273 | losses += z_losses 274 | else: 275 | z_losses = torch.zeros_like(losses) 276 | losses.masked_fill_(labels == ignore_index, 0.0) 277 | 278 | ctx.save_for_backward(logits, lse, labels) 279 | ctx.mark_non_differentiable(z_losses) 280 | ctx.beta = beta 281 | ctx.logit_scale = logit_scale 282 | ctx.lse_square_scale = lse_square_scale 283 | ctx.ignore_index = ignore_index 284 | ctx.total_classes = total_classes 285 | ctx.class_start_idx = class_start_idx 286 | ctx.inplace_backward = inplace_backward 287 | return losses, z_losses 288 | 289 | @staticmethod 290 | def backward(ctx, grad_losses, grad_z_losses): 291 | del grad_z_losses # z_losses are only for logging. 292 | 293 | logits, lse, labels = ctx.saved_tensors 294 | dlogits = logits if ctx.inplace_backward else torch.empty_like(logits) 295 | n_rows, n_cols = logits.shape 296 | BLOCK_SIZE = min(triton.next_power_of_2(n_cols), 4 * 1024) 297 | num_warps = 4 if BLOCK_SIZE < 2048 else (8 if BLOCK_SIZE < 8192 else 16) 298 | grid = lambda META: (n_rows, triton.cdiv(n_cols, META["BLOCK_SIZE"])) # noqa 299 | # Need this, otherwise Triton tries to launch from cuda:0 and we get 300 | # ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?) 301 | with torch.cuda.device(logits.device.index): 302 | gem_bwd_kernel[grid]( 303 | dlogits, # data ptrs 304 | grad_losses, 305 | logits, 306 | lse, 307 | labels, 308 | ctx.beta, 309 | ctx.logit_scale, 310 | ctx.lse_square_scale, 311 | ctx.ignore_index, 312 | ctx.total_classes, 313 | ctx.class_start_idx, 314 | n_cols, # shapes 315 | logits.stride(0), # strides 316 | dlogits.stride(0), 317 | grad_losses.stride(0), 318 | BLOCK_SIZE=BLOCK_SIZE, # constants 319 | num_warps=num_warps, 320 | ) 321 | return dlogits, None, None, None, None, None, None, None, None, None 322 | 323 | 324 | def gem_loss( 325 | logits: torch.Tensor, 326 | labels: torch.Tensor, 327 | precomputed_lse: Optional[torch.Tensor] = None, 328 | beta: float = 1.0, 329 | logit_scale: float = 1.0, 330 | lse_square_scale: float = 0.0, 331 | ignore_index=-100, 332 | inplace_backward: bool = False, 333 | process_group=None, 334 | ) -> Tuple[torch.Tensor, torch.Tensor]: 335 | """ 336 | Arguments: 337 | logits: (batch, vocab_size) 338 | labels: (batch,) 339 | beta: float 340 | logit_scale: float. Multiply logits by this scale before calculating the loss. 341 | lse_square_scale: float. If > 0, we add lse_square_scale * lse(logits) ^ 2 to the loss. 342 | This is also referred to as "z-loss". 343 | ignore_index: int. If labels == ignore_index, the loss is set to 0.0. 344 | inplace_backward: bool. If True, we do the backward pass in-place by modifying the logits. 345 | This saves memory. 346 | process_group: if not None, we're doing Tensor Parallel: each process is responsible for 347 | one part of the vocab. The loss will be aggregated across processes. 348 | Returns: 349 | losses: (batch,), float 350 | z_losses: (batch,), float 351 | """ 352 | return GEMLoss.apply( 353 | logits, 354 | labels, 355 | precomputed_lse, 356 | beta, 357 | logit_scale, 358 | lse_square_scale, 359 | ignore_index, 360 | inplace_backward, 361 | process_group, 362 | ) 363 | --------------------------------------------------------------------------------
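A side note on the vocab-parallel path in `gem_triton_ops.py`: after all-gathering the per-partition LSE values, the forward pass reconstructs the global log-sum-exp with `torch.logsumexp` over the shard LSEs. A minimal, self-contained check of that identity (illustrative only; the row length and the two-way split are arbitrary):

```python
import torch

# One row of logits, split into two vocabulary shards as tensor parallelism would.
logits = torch.randn(1, 8)
shard_a, shard_b = logits[:, :4], logits[:, 4:]

# The LSE of the per-shard LSEs equals the LSE over the full vocabulary,
# which is what the distributed forward pass relies on.
lse_shards = torch.stack([shard_a.logsumexp(dim=-1), shard_b.logsumexp(dim=-1)])
global_lse = torch.logsumexp(lse_shards, dim=0)

assert torch.allclose(global_lse, logits.logsumexp(dim=-1))
```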