├── .gitignore
├── README.md
├── data
│   ├── poem_generation
│   │   └── test-00000-of-00001.parquet
│   └── story_generation
│       └── test-00000-of-00001.parquet
├── docs
│   └── GEM-2025-03-23.pdf
├── evaluation
│   ├── convert_response_for_if_eval.py
│   ├── evaluation_diversity.py
│   ├── evaluation_gsm8k.py
│   ├── evaluation_gsm8k_voting.py
│   ├── evaluation_reward.py
│   ├── generate_response.py
│   └── utils
│       └── gsm8k.py
├── img
│   ├── gem_vs_ce.png
│   └── gem_with_remax.png
├── preprocess_data.py
├── requirements.txt
├── scripts
│   ├── eval
│   │   ├── creative_writing.sh
│   │   ├── gsm8k_eval.sh
│   │   ├── gsm8k_voting_eval.sh
│   │   └── reward_eval.sh
│   ├── llama3.1
│   │   ├── tokenize_data.sh
│   │   ├── train_ce_ultrafeedback.sh
│   │   └── train_gem_ultrafeedback.sh
│   ├── qwen2.5
│   │   ├── tokenize_data.sh
│   │   ├── train_ce_numina.sh
│   │   └── train_gem_numina.sh
│   ├── zero2.json
│   └── zero3.json
├── sft_trainer.py
├── sft_trainer_v2.py
├── tests
│   ├── test_gem_loss_triton.py
│   └── test_gem_loss_triton_distributed.py
├── train.py
└── utils
    ├── README.md
    ├── gem_triton_loss.py
    └── gem_triton_ops.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Result
10 | result/
11 | log/
12 |
13 | # Distribution / packaging
14 | .idea
15 | .Python
16 | .DS_Store
17 | build/
18 | develop-eggs/
19 | dist/
20 | downloads/
21 | applications/DeepSpeed-Chat/data
22 | eggs/
23 | .eggs/
24 | lib/
25 | lib64/
26 | parts/
27 | sdist/
28 | var/
29 | wheels/
30 | pip-wheel-metadata/
31 | share/python-wheels/
32 | *.egg-info/
33 | .installed.cfg
34 | *.egg
35 | MANIFEST
36 |
37 | # PyInstaller
38 | # Usually these files are written by a python script from a template
39 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
40 | *.manifest
41 | *.spec
42 |
43 | # Installer logs
44 | pip-log.txt
45 | pip-delete-this-directory.txt
46 |
47 | # Unit test / coverage reports
48 | htmlcov/
49 | .tox/
50 | .nox/
51 | .coverage
52 | .coverage.*
53 | .cache
54 | nosetests.xml
55 | coverage.xml
56 | *.cover
57 | *.py,cover
58 | .hypothesis/
59 | .pytest_cache/
60 |
61 | # Translations
62 | *.mo
63 | *.pot
64 |
65 | # Django stuff:
66 | *.log
67 | local_settings.py
68 | db.sqlite3
69 | db.sqlite3-journal
70 |
71 | # Flask stuff:
72 | instance/
73 | .webassets-cache
74 |
75 | # Scrapy stuff:
76 | .scrapy
77 |
78 | # Sphinx documentation
79 | docs/_build/
80 |
81 | # PyBuilder
82 | target/
83 |
84 | # Jupyter Notebook
85 | .ipynb_checkpoints
86 |
87 | # IPython
88 | profile_default/
89 | ipython_config.py
90 |
91 | # pyenv
92 | .python-version
93 |
94 | # pipenv
95 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
96 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
97 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
98 | # install all needed dependencies.
99 | #Pipfile.lock
100 |
101 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
102 | __pypackages__/
103 |
104 | # Celery stuff
105 | celerybeat-schedule
106 | celerybeat.pid
107 |
108 | # SageMath parsed files
109 | *.sage.py
110 |
111 | # Environments
112 | .env
113 | .venv
114 | env/
115 | venv/
116 | ENV/
117 | env.bak/
118 | venv.bak/
119 |
120 | # Spyder project settings
121 | .spyderproject
122 | .spyproject
123 |
124 | # Rope project settings
125 | .ropeproject
126 |
127 | # mkdocs documentation
128 | /site
129 |
130 | # mypy
131 | .mypy_cache/
132 | .dmypy.json
133 | dmypy.json
134 |
135 | # Pyre type checker
136 | .pyre/
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # 🚀 PyTorch Implementation of GEM 🌟
2 |
3 | Welcome to the official PyTorch implementation of **GEM**! 🎉
4 |
5 | GEM was introduced in our [ICLR 2025 paper](https://openreview.net/forum?id=dulz3WVhMR) **"Preserving Diversity in Supervised Fine-tuning of Large Language Models"**.
6 |
7 | > This work was previously titled "Entropic Distribution Matching in Supervised Fine-tuning of LLMs: Less Overfitting and Better Diversity" and received the Best Paper Runner-up Award at the NeurIPS 2024 FITML Workshop.
8 |
9 |
10 |
11 | GEM can replace the CE loss during SFT (supervised fine-tuning) or RFT (reinforced fine-tuning) to preserve diversity and mitigate overfitting. 🌍✨
12 |
13 | For an overview of GEM, please refer to our [presentation slides](docs/GEM-2025-03-23.pdf).
14 |
15 | For more insights on GEM's potential to enhance RL training through improved cold-start strategies, check out our blog post: ["Can Better Cold-Start Strategies Improve RL Training for LLMs?"](https://tangible-polo-203.notion.site/Can-Better-Cold-Start-Strategies-Improve-RL-Training-for-LLMs-17aa0742a51680828616c867ed53bc6b)
16 |
17 |
18 |
19 | ## Quickstart Guide 💻
20 |
21 | ### Setup 🔧
22 |
23 | First, create a new environment and install the required packages:
24 |
25 | ```bash
26 | conda create -n gem python=3.10
27 | conda activate gem
28 | pip install -r requirements.txt
29 | ```
30 |
31 | Note that the package versions pinned in `requirements.txt` are the ones used in the paper. You may use a newer version of transformers (>= 4.46.0), which fixes a potential bug in gradient accumulation.
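
For example, to upgrade within the environment created above (optional):

```bash
pip install -U "transformers>=4.46.0"
```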
32 |
33 | We also provide a **Triton** implementation of GEM loss in the `utils` folder, which may be faster than the original implementation when training large-scale models. Please refer to the [README](utils/README.md) for more details. You may use this implementation with the following command:
34 |
35 | ```bash
36 | python train.py --loss gem_triton
37 | ```
38 |
39 |
40 | ### Training 🏋️‍♂️
41 |
42 | Kickstart your training process using the `UltraFeedback` dataset from HuggingFace (the corresponding scripts live under `scripts/llama3.1/`; Qwen2.5 counterparts are under `scripts/qwen2.5/`). Here's how:
43 |
44 | **Tokenize Data**
45 |
46 | ```bash
47 | bash scripts/llama3.1/tokenize_data.sh
48 | ```
49 |
50 | **Training**
51 |
52 | ```bash
53 | bash scripts/llama3.1/train_gem_ultrafeedback.sh
54 | ```
55 |
56 | > **Note:** The `ce_loss` metric in the training logs is the cross-entropy loss computed on a single machine without accounting for gradient accumulation, so it may differ from the reported `loss` metric. When monitoring training, you can use `ce_loss` as a diagnostic indicator: it should decrease over time regardless of whether CE or GEM is your primary loss function.
57 |
58 | ### Evaluation 🧪
59 |
60 | Run evaluations for different tasks:
61 |
62 | **GSM8K**
63 |
64 | ```bash
65 | bash scripts/eval/gsm8k_eval.sh
66 | ```
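
If you prefer to invoke the underlying script directly, the flags below mirror the `Arguments` dataclass in `evaluation/evaluation_gsm8k.py`; the checkpoint path is a placeholder, and `scripts/eval/gsm8k_eval.sh` remains the reference invocation:

```bash
python evaluation/evaluation_gsm8k.py \
    --dataset_name_or_path gsm8k \
    --model_name_or_path <path-to-your-checkpoint> \
    --tokenizer_name_or_path <path-to-your-checkpoint> \
    --use_vllm True \
    --temperature 0.0 \
    --max_new_tokens 512 \
    --save_path evaluation_gsm8k_greedy.json
```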
67 |
68 | **GSM8K (Voting)**
69 |
70 | ```bash
71 | bash scripts/eval/gsm8k_voting_eval.sh
72 | ```
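
As with the greedy evaluation, the script can be called directly. Sampling several completions per question requires `--use_vllm True` (the non-vLLM path in `evaluation/evaluation_gsm8k_voting.py` is not implemented); the checkpoint path below is a placeholder:

```bash
python evaluation/evaluation_gsm8k_voting.py \
    --dataset_name_or_path gsm8k \
    --model_name_or_path <path-to-your-checkpoint> \
    --tokenizer_name_or_path <path-to-your-checkpoint> \
    --use_vllm True \
    --n 32 \
    --temperature 0.6 \
    --save_path evaluation_gsm8k_voting.json
```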
73 |
74 | **Creative Writing**
75 |
76 | ```bash
77 | bash scripts/eval/creative_writing.sh
78 | ```
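
The creative-writing evaluation builds on the scripts under `evaluation/`. As a minimal sketch, diversity metrics can also be computed directly on a generated response file (paths are placeholders; the flags mirror `evaluation/evaluation_diversity.py`):

```bash
python evaluation/evaluation_diversity.py \
    --response_path results/creative_writing_responses.json \
    --tokenizer_path meta-llama/Meta-Llama-3-8B-Instruct
```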
79 |
80 | ## To Do
81 |
82 | - [ ] Add the adaptive mechanism for choosing the hyper-parameter $\beta$.
83 |
84 | ## 📜 Citation
85 |
86 | If you find this repository helpful in your research or projects, please consider citing the GEM paper in your academic work. Your support is much appreciated! 🙌
87 |
88 |
89 | ```bibtex
90 | @inproceedings{li2025preserving,
91 | title={Preserving Diversity in Supervised Fine-Tuning of Large Language Models},
92 | author={Ziniu Li and Congliang Chen and Tian Xu and Zeyu Qin and Jiancong Xiao and Zhi-Quan Luo and Ruoyu Sun},
93 | booktitle={The Thirteenth International Conference on Learning Representations},
94 | year={2025},
95 | url={https://openreview.net/forum?id=NQEe7B7bSw}
96 | }
97 | ```
98 |
99 | Our work was previously titled "Entropic Distribution Matching in Supervised Fine-tuning of LLMs: Less Overfitting and Better Diversity", available on arXiv.
100 |
101 | ```bibtex
102 | @article{li2024entropic,
103 | title={Entropic Distribution Matching in Supervised Fine-tuning of LLMs: Less Overfitting and Better Diversity},
104 | author={Li, Ziniu and Chen, Congliang and Xu, Tian and Qin, Zeyu and Xiao, Jiancong and Sun, Ruoyu and Luo, Zhi-Quan},
105 | journal={arXiv preprint arXiv:2408.16673},
106 | year={2024}
107 | }
108 | ```
109 |
110 | Ziniu Li would like to acknowledge Zhengyang Tang for his minimalistic and clean implementation of SFT.
111 |
--------------------------------------------------------------------------------
/data/poem_generation/test-00000-of-00001.parquet:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/liziniu/GEM/e5b1430979fa12bf8ab7398b2ccc71dff795bbee/data/poem_generation/test-00000-of-00001.parquet
--------------------------------------------------------------------------------
/data/story_generation/test-00000-of-00001.parquet:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/liziniu/GEM/e5b1430979fa12bf8ab7398b2ccc71dff795bbee/data/story_generation/test-00000-of-00001.parquet
--------------------------------------------------------------------------------
/docs/GEM-2025-03-23.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/liziniu/GEM/e5b1430979fa12bf8ab7398b2ccc71dff795bbee/docs/GEM-2025-03-23.pdf
--------------------------------------------------------------------------------
/evaluation/convert_response_for_if_eval.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | from pprint import pprint
4 | from tqdm import tqdm
5 | import pandas as pd
6 |
7 | from dataclasses import dataclass, field
8 | from transformers import AutoTokenizer, HfArgumentParser
9 | from datasets import load_dataset, Dataset
10 |
11 |
12 | @dataclass
13 | class Arguments:
14 | response_path: str = field(
15 | default=None,
16 | metadata={"help": "Response path (json file) to convert."},
17 | )
18 | tokenizer_path: str = field(
19 | default="meta-llama/Meta-Llama-3-8B-Instruct",
20 | metadata={"help": "Tokenizer path to help clean str."},
21 | )
22 | save_path: str = field(default="alpaca_eval_response.json")
23 |
24 |
25 | def main():
26 | parser = HfArgumentParser((Arguments,))
27 | (args,) = parser.parse_args_into_dataclasses()
28 |
29 | pprint(args.__dict__)
30 |
31 | old_data = json.load(open(args.response_path, "r"))
32 |
33 | tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_path)
34 |
35 | dataset = []
36 | with open("./instruction_following_eval/data/input_data.jsonl") as f:
37 | for line in f.readlines():
38 | dataset.append(json.loads(line))
39 | if_eval_dataset = Dataset.from_pandas(pd.DataFrame(dataset))
40 |
41 | new_data = []
42 |
43 | for i in tqdm(range(len(old_data))):
44 | prompt = old_data[i]["prompt"]
45 |
46 | prompt_clean = (
47 | tokenizer.decode(
48 | tokenizer(prompt.replace(tokenizer.bos_token, "")).input_ids,
49 | skip_special_tokens=True,
50 | )
51 | .replace("user\n\n", "")
52 | .replace("assistant\n\n", "")
53 | )
54 | prompt_ref = if_eval_dataset[i]["prompt"]
55 |
56 | if prompt_clean.strip()[:10] != prompt_ref.strip()[:10]:
57 | import ipdb
58 |
59 | ipdb.set_trace()
60 |
61 | new_data.append(
62 | {
63 | "id": i,
64 | "prompt": prompt_ref,
65 | "response": (
66 | old_data[i]["answer"]
67 | if isinstance(old_data[i]["answer"], str)
68 | else old_data[i]["answer"][0]
69 | .replace("<|eot_id|>", "")
70 | .replace(tokenizer.eos_token, "")
71 | .strip()
72 | ),
73 | "generator": old_data[i]["model_name"],
74 | }
75 | )
76 | os.makedirs(
77 | args.save_path.replace(args.save_path.split("/")[-1], ""), exist_ok=True
78 | )
79 |
80 | with open(args.save_path, "w") as outfile:
81 | for entry in new_data:
82 | json.dump(entry, outfile)
83 | outfile.write("\n")
84 | print(f"Save response to {args.save_path}")
85 |
86 |
87 | if __name__ == "__main__":
88 | main()
89 |
--------------------------------------------------------------------------------
/evaluation/evaluation_diversity.py:
--------------------------------------------------------------------------------
1 | #################
2 | # This code is modified from https://github.com/facebookresearch/rlfh-gen-div
3 | #################
4 | import os
5 | from dataclasses import dataclass, field
6 | import json
7 | from pprint import pprint
8 |
9 | import torch
10 | import numpy as np
11 | import sentence_transformers
12 | from tqdm import tqdm
13 |
14 | # from sklearn.metrics.pairwise import cosine_similarity
15 | from transformers import set_seed, HfArgumentParser, AutoTokenizer
16 |
17 | from nltk.util import ngrams
18 | from nltk import word_tokenize
19 | from collections import Counter
20 |
21 | import sacrebleu
22 |
23 | @dataclass
24 | class AllArguments:
25 | response_path: str = field(
26 | default="./results/responses", metadata={"help": "Response path (json file)."}
27 | )
28 |
29 | tokenizer_path: str = field(default=None)
30 | detokenizer_path: str = field(default=None)
31 |
32 |
33 | class SentBertSimilarity:
34 | def __init__(self):
35 |
36 | self.model_name = "bert-large-nli-stsb-mean-tokens" # FIXME - hard coded
37 | self.model = sentence_transformers.SentenceTransformer(self.model_name)
38 | if torch.cuda.is_available():
39 | self.model.to(torch.device("cuda"))
40 |
41 | # @functools.cache
42 | def embed(self, sentence):
43 | return self.model.encode(sentence)
44 |
45 | # @functools.cache
46 | def sent_bert_cosine_similarity(self, resps_1, resps_2):
47 | embeds_1 = self.model.encode(
48 | resps_1, batch_size=1024, convert_to_tensor=True, show_progress_bar=False
49 | )
50 | embeds_2 = self.model.encode(
51 | resps_2, batch_size=1024, convert_to_tensor=True, show_progress_bar=False
52 | )
53 |
54 | if torch.cuda.is_available():
55 | embeds_1 = embeds_1.to(torch.device("cuda"))
56 | embeds_2 = embeds_2.to(torch.device("cuda"))
57 |
58 | dot_product = (embeds_1 * embeds_2).sum(dim=1)
59 |
60 | # Calculate cosine similarity
61 | cosine_similarity = dot_product / (embeds_1.norm(dim=1) * embeds_2.norm(dim=1))
62 |
63 | return cosine_similarity.detach().cpu().numpy()
64 |
65 | def __call__(self, resp_a, resp_b):
66 | return self.sent_bert_cosine_similarity(resp_a, resp_b)
67 |
68 |
69 | class SentBertDiversity:
70 | """
71 | Implements the diversity to similarity reduction specified on section 5 in the paper
72 | (https://arxiv.org/pdf/2004.02990.pdf)
73 | for any similarity metric.
74 |
75 | config:
76 | shared with the original similarity metric.
77 |
78 | usage:
79 | metric = Similarity2DiversityMetric(config, SimilarityMetricClassName)
80 | metric(response_set)
81 |
82 | inheritance guidelines:
83 | implement __init__ only
84 |
85 | inheritance example:
86 | see CosineSimilarity2Diversity
87 | """
88 |
89 | def __init__(self):
90 | self.similarity_metric = SentBertSimilarity()
91 |
92 | def __call__(self, response_set):
93 | similarity_list = []
94 | for i in tqdm(range(len(response_set))):
95 | for j in range(i):
96 | similarity_list.append(
97 | self.similarity_metric(response_set[i], response_set[j])
98 | )
99 | diversity_score = 1 - np.mean(similarity_list)
100 | return diversity_score
101 |
102 |
103 | class AveragedNgramDiversityMetric:
104 | """
105 | Calculates the mean values of an n-gram based diversity metric in range n in [n_min, n_max].
106 |
107 | config:
108 | shared with the original n-gram metric.
109 | n_min(int) > 0 - Specify the lowest n-gram value to be averaged
110 | n_max(int) > 0 - Specify the highest n-gram value to be averaged
111 |
112 | usage:
113 | metric = AveragedNgramDiversityMetric(config, NgramMetricClassName)
114 | metric(response_set)
115 |
116 | inheritance guidelines:
117 | implement __init__ only
118 |
119 | inheritance example:
120 | see AveragedDistinctNgrams
121 | """
122 |
123 | def __init__(self, n_min, n_max):
124 | # add n field
125 | self.n_min = n_min
126 | self.n_max = n_max
127 |
128 | def __call__(self, response_set):
129 | ngrams_results = []
130 | num_set = len(response_set)
131 | for i in range(len(response_set[0])):
132 | for n in range(self.n_min, self.n_max + 1):
133 | result = self.calculate_distinct_n(
134 | [response_set[j][i] for j in range(num_set)], n
135 | )
136 | ngrams_results.append(result)
137 | return np.mean(ngrams_results)
138 |
139 | def calculate_distinct_n(self, responses, n):
140 | all_ngrams = []
141 | for response in responses:
142 | tokens = word_tokenize(response)
143 | response_ngrams = list(ngrams(tokens, n))
144 | all_ngrams.extend(response_ngrams)
145 | unique_ngrams = len(set(all_ngrams))
146 | total_ngrams = len(all_ngrams)
147 |
148 | return unique_ngrams / total_ngrams if total_ngrams > 0 else 0
149 |
150 |
151 | class SelfBLEUMetric:
152 | def __call__(self, response_set):
153 | """Calculate the average Self-BLEU score for a list of texts."""
154 | bleu_scores = []
155 | k = len(response_set)
156 | for i in range(len(response_set[0])):
157 | texts = [response_set[j][i] for j in range(k)]
158 | bleu_scores.append(self.calculate_bleu_score(texts))
159 |
160 | return np.mean(bleu_scores)
161 |
162 | def calculate_bleu_score(self, texts):
163 | bleu_scores = []
164 | for i in range(len(texts)):
165 | # Treat the current text as the hypothesis
166 | hypothesis = texts[i]
167 | # Treat all other texts as references
168 | references = texts[:i] + texts[i + 1 :]
169 |
170 | if references: # Ensure there are references to compare against
171 | bleu_score = sacrebleu.corpus_bleu([hypothesis], [references])
172 | bleu_scores.append(bleu_score.score)
173 |
174 | # Compute the average BLEU score
175 | average_bleu = sum(bleu_scores) / len(bleu_scores) if bleu_scores else 0
176 | return average_bleu
177 |
178 |
179 | def main():
180 | parser = HfArgumentParser((AllArguments,))
181 | (args,) = parser.parse_args_into_dataclasses()
182 | pprint(args.__dict__)
183 |
184 | if os.path.exists(args.response_path.replace(".json", "-cleaned.json")):
185 | args.response_path = args.response_path.replace(".json", "-cleaned.json")
186 |
187 | if args.response_path.endswith("-cleaned.json"):
188 | response_set = json.load(open(args.response_path, "r"))
189 | else:
190 | data = json.load(open(args.response_path, "r"))
191 |
192 | tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_path)
193 | if args.detokenizer_path is not None:
194 | detokenizer = AutoTokenizer.from_pretrained(args.detokenizer_path)
195 | else:
196 | detokenizer = None
197 |
198 | response_set = []
199 | for i in tqdm(range(len(data))):
200 | n = len(data[i]["answer"])
201 | if len(response_set) == 0:
202 | response_set = [[] for _ in range(n)]
203 | else:
204 | assert len(response_set) == n
205 | for j in range(n):
206 | x = data[i]
207 | if detokenizer:
208 | prompt_str = (
209 | detokenizer.decode(
210 | detokenizer.encode(x["prompt"]), skip_special_tokens=True
211 | )
212 | .replace("user\n\n", "")
213 | .replace("assistant\n\n", "")
214 | )
215 | else:
216 | prompt_str = x["prompt"]
217 | if detokenizer:
218 | # ans_str = detokenizer.decode(
219 | # detokenizer.encode(data[i]["answer"][j]), skip_special_tokens=True
220 | # )
221 | ans_str = data[i]["answer"][j].replace("<|eot_id|>", "")
222 | else:
223 | ans_str = data[i]["answer"][j]
224 | chat = [
225 | {
226 | "role": "user",
227 | "content": prompt_str,
228 | },
229 | {"role": "assistant", "content": ans_str},
230 | ]
231 | res = tokenizer.apply_chat_template(chat, tokenize=False)
232 | response_set[j].append(res)
233 | json.dump(
234 | response_set,
235 | open(args.response_path.replace(".json", "-cleaned.json"), "w"),
236 | indent=2,
237 | )
238 |
239 | response_set = json.load(
240 | open(args.response_path.replace(".json", "-cleaned.json"), "r")
241 | )
242 | print("Finished Data Preparation.")
243 |
244 | evaluation_results = {
245 | "sentbert_diversity_score": None,
246 | "bleu_diversity_score": None,
247 | "averaged_ngram_diversity_score": None,
248 | }
249 |
250 | print("Calculating N-gram diversity score...")
251 | metric = AveragedNgramDiversityMetric(n_min=1, n_max=3)
252 | diversity_score = metric(response_set)
253 | evaluation_results["averaged_ngram_diversity_score"] = np.round(
254 | diversity_score * 100, 2
255 | )
256 | print("N-gram diversity score: {}".format(diversity_score))
257 |
258 | print("Calculating BLEU similarity score...")
259 | metric = SelfBLEUMetric()
260 | similarity_score = metric(response_set)
261 | evaluation_results["bleu_diversity_score"] = np.round(100 - similarity_score, 2)
262 | print("BLEU similarity score: {}".format(100 - similarity_score))
263 |
264 | print("Calculating Bert diversity score...")
265 | metric = SentBertDiversity()
266 | diversity_score = metric(response_set)
267 | evaluation_results["sentbert_diversity_score"] = np.round(diversity_score * 100, 2)
268 | print("Bert diversity score: {}".format(diversity_score))
269 |
270 | pprint(evaluation_results)
271 |
272 |
273 | if __name__ == "__main__":
274 | main()
275 |
--------------------------------------------------------------------------------
/evaluation/evaluation_gsm8k.py:
--------------------------------------------------------------------------------
1 | import re
2 | import os
3 | import sys
4 | import json
5 | from dataclasses import dataclass, field
6 | from pprint import pprint
7 | from tqdm import tqdm
8 | from fraction import Fraction
9 |
10 | import torch
11 | import numpy as np
12 |
13 | from datasets import load_dataset
14 | # from evaluate import load
15 |
16 | sys.path.append(
17 | os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir, ".."))
18 | )
19 |
20 | from evaluation.utils.gsm8k import extract_answer_number
21 |
22 | from transformers import HfArgumentParser, set_seed, AutoModelForCausalLM, AutoTokenizer
23 |
24 | import vllm
25 | from vllm import SamplingParams
26 |
27 | TEMPLATE = """
28 | Your task is to answer the question below. Give step by step reasoning before you answer, and when you’re ready to answer, please use the format "The answer is: ..."\nQuestion: {question}
29 | """
30 |
31 |
32 | @dataclass
33 | class Arguments:
34 | dataset_name_or_path: str = field(default="gsm8k")
35 |
36 | # model
37 | model_name_or_path: str = field(default="meta-llama/Llama-2-7b-chat")
38 | tokenizer_name_or_path: str = field(default="meta-llama/Llama-2-7b-chat")
39 |
40 | # generation
41 | do_sample: bool = field(default=False)
42 |
43 | use_vllm: bool = field(
44 | default=False, metadata={"help": "Whether use vLLM for generation."}
45 | )
46 | vllm_gpu_memory_utilization: float = field(
47 | default=0.9, metadata={"help": "vLLM GPU consumption ratio."}
48 | )
49 |
50 | seed: int = field(
51 | default=42, metadata={"help": "Random Seed for reproducing results."}
52 | )
53 |
54 | batch_size: int = field(default=10)
55 | top_k: int = field(default=-1)
56 | top_p: float = field(default=1.0)
57 | temperature: float = field(default=0.0, metadata={"help": "Temperature."})
58 | max_new_tokens: int = field(default=512, metadata={"help": "Max response length."})
59 |
60 | # save
61 | remove_old: bool = field(
62 | default=False, metadata={"help": "Whether to remove old file."}
63 | )
64 | save_path: str = field(
65 | default="evaluation_gsm8k.json",
66 | metadata={"help": "Evaluation results save path."},
67 | )
68 |
69 |
70 | def save_prompts_and_answers(
71 | model_name, prompts, labels, answers, evaluations, file_path
72 | ):
73 | assert len(prompts) == len(answers), "Mismatched lengths!"
74 | assert file_path.endswith(".json")
75 | data = [
76 | {
77 | "id": i,
78 | "model_name": model_name,
79 | "prompt": prompts[i],
80 | "label": labels[i],
81 | "answer": answers[i],
82 | "evaluation": evaluations[i],
83 | }
84 | for i in range(len(prompts))
85 | ]
86 | if not os.path.exists(file_path):
87 | with open(file_path, "w", encoding="utf-8") as file:
88 | json.dump(data, file, indent=2)
89 | else:
90 | with open(file_path, "r", encoding="utf-8") as file:
91 | data = json.load(file)
92 |
93 | # Determine the next id value
94 | next_id = data[-1]["id"] + 1 if data else 0
95 |
96 | # Create new entries and append them to the data list
97 | new_entries = [
98 | {
99 | "id": i + next_id,
100 | "model_name": model_name,
101 | "prompt": prompts[i],
102 | "label": labels[i],
103 | "answer": answers[i],
104 | "evaluation": evaluations[i],
105 | }
106 | for i in range(len(prompts))
107 | ]
108 | data.extend(new_entries)
109 |
110 | with open(file_path, "w", encoding="utf-8") as file:
111 | json.dump(data, file, indent=2)
112 |
113 |
114 | def main():
115 | parser = HfArgumentParser((Arguments,))
116 | (args,) = parser.parse_args_into_dataclasses()
117 | pprint(args.__dict__)
118 |
119 | if args.remove_old:
120 | if os.path.exists(args.save_path):
121 | os.remove(args.save_path)
122 |
123 | dataset = load_dataset(args.dataset_name_or_path, "main")
124 | dataset = dataset["test"]
125 |
126 | device = torch.device("cuda")
127 |
128 | tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name_or_path)
129 | if "llama-3" in tokenizer.name_or_path.lower():
130 | tokenizer.pad_token = tokenizer.decode(len(tokenizer) - 1)
131 | tokenizer.pad_token_id = len(tokenizer) - 1
132 | elif "llama-2" in tokenizer.name_or_path.lower():
133 | tokenizer.pad_token = tokenizer.unk_token
134 | tokenizer.pad_token_id = tokenizer.unk_token_id
135 | # tokenizer.model_max_length = int(8096)
136 |
137 | if args.use_vllm:
138 | model = vllm.LLM(
139 | model=args.model_name_or_path,
140 | tokenizer=args.tokenizer_name_or_path,
141 | tensor_parallel_size=torch.cuda.device_count(),
142 | dtype=torch.bfloat16,
143 | gpu_memory_utilization=args.vllm_gpu_memory_utilization,
144 | seed=args.seed,
145 | )
146 | args.batch_size = len(dataset)
147 | else:
148 | model = AutoModelForCausalLM.from_pretrained(
149 | args.model_name_or_path,
150 | torch_dtype=torch.bfloat16,
151 | attn_implementation="flash_attention_2",
152 | )
153 | model.to(device)
154 | model.eval()
155 |
156 | prompt_to_save = []
157 | ans_to_save = []
158 | labels_to_save = []
159 | evaluations_to_save = []
160 | count = 0
161 | total_acc = 0
162 | total_num = 0
163 | for i in tqdm(range(0, len(dataset), args.batch_size)):
164 | prompt = dataset[i : i + args.batch_size]["question"]
165 | prompt_conv = [
166 | [{"role": "user", "content": TEMPLATE.format(question=x)}] for x in prompt
167 | ]
168 | labels = [
169 | x.replace("####", "The answer is:")
170 | for x in dataset[i : i + args.batch_size]["answer"]
171 | ]
172 | prompt_str = tokenizer.apply_chat_template(
173 | prompt_conv, tokenize=False, add_generation_prompt=True
174 | )
175 |
176 | tokenizer.padding_side = "left"
177 | prompt_token = tokenizer.apply_chat_template(
178 | prompt_conv,
179 | padding="longest",
180 | add_generation_prompt=True,
181 | return_dict=True,
182 | return_tensors="pt",
183 | ).to(device)
184 |
185 | prompt_length = prompt_token.input_ids.size(-1)
186 |
187 | if args.use_vllm:
188 | prompt_token_ids = [
189 | prompt_token.input_ids[
190 | j, prompt_token.attention_mask[j].bool()
191 | ].tolist()
192 | for j in range(len(prompt_conv))
193 | ]
194 |
195 | sampling_params = SamplingParams(
196 | top_k=args.top_k,
197 | top_p=args.top_p,
198 | temperature=args.temperature,
199 | max_tokens=args.max_new_tokens,
200 | )
201 | with torch.no_grad():
202 | output_results = model.generate(
203 | prompt_token_ids=prompt_token_ids, sampling_params=sampling_params
204 | )
205 | ans_str = []
206 | for j in range(len(output_results)):
207 | ans_str.append(output_results[j].outputs[0].text)
208 |
209 | evaluation_results = []
210 | for j in range(len(output_results)):
211 | try:
212 | answer = extract_answer_number(ans_str[j])
213 | except Exception as e:
214 | print("========Error=========")
215 | print(e)
216 | print(ans_str[j])
217 | print()
218 | answer = None
219 | true_answer = extract_answer_number(labels[j])
220 | if answer is not None:
221 | evaluation_results.append(
222 | (answer, true_answer, answer == true_answer)
223 | )
224 | else:
225 | evaluation_results.append((answer, true_answer, None))
226 |
227 | total_num += len([x for x in evaluation_results if x[-1] is not None])
228 | total_acc += len(
229 | [x for x in evaluation_results if x is not None and x[-1] is True]
230 | )
231 |
232 | prompt_to_save.extend(prompt_str)
233 | ans_to_save.extend(ans_str)
234 | labels_to_save.extend(labels)
235 | evaluations_to_save.extend(evaluation_results)
236 | count += 1
237 |
238 | print("===========Prompt=============")
239 | print(prompt_str[0])
240 | print("===========Label=============")
241 | print(labels[0])
242 | print("===========Response=============")
243 | print(ans_str[0])
244 | print("===========Evaluation=============")
245 | print(evaluation_results[0])
246 | else:
247 | with torch.no_grad():
248 | outputs = model.generate(
249 | prompt_token.input_ids,
250 | attention_mask=prompt_token.attention_mask,
251 | max_new_tokens=args.max_new_tokens,
252 | pad_token_id=tokenizer.pad_token_id,
253 | do_sample=args.do_sample,
254 | )
255 | ans_token = outputs[:, prompt_length:]
256 | ans_str = tokenizer.batch_decode(ans_token, skip_special_tokens=True)
257 |
258 | evaluation_results = []
259 | for j in range(len(ans_str)):
260 | try:
261 | answer = extract_answer_number(ans_str[j])
262 | except Exception as e:
263 | print("========Error=========")
264 | print(e)
265 | print(ans_str[j])
266 | print()
267 | answer = None
268 | true_answer = extract_answer_number(labels[j])
269 | if answer is not None:
270 | evaluation_results.append(
271 | (answer, true_answer, answer == true_answer)
272 | )
273 | else:
274 | evaluation_results.append((answer, true_answer, None))
275 |
276 | total_num += len([x for x in evaluation_results if x[-1] is not None])
277 | total_acc += len(
278 | [x for x in evaluation_results if x is not None and x[-1] is True]
279 | )
280 |
281 | prompt_to_save.extend(prompt_str)
282 | ans_to_save.extend(ans_str)
283 | labels_to_save.extend(labels)
284 | evaluations_to_save.extend(evaluation_results)
285 | count += 1
286 |
287 | print("===========Prompt=============")
288 | print(prompt_str[0])
289 | print("===========Label=============")
290 | print(labels[0])
291 | print("===========Response=============")
292 | print(ans_str[0])
293 | print("===========Evaluation=============")
294 | print(evaluation_results[0])
295 |
296 | if count % 10 == 0:
297 | save_prompts_and_answers(
298 | args.model_name_or_path,
299 | prompt_to_save,
300 | labels_to_save,
301 | ans_to_save,
302 | evaluations_to_save,
303 | args.save_path,
304 | )
305 | prompt_to_save.clear()
306 | ans_to_save.clear()
307 | labels_to_save.clear()
308 | evaluations_to_save.clear()
309 |
310 | if len(prompt_to_save) > 0:
311 | save_prompts_and_answers(
312 | args.model_name_or_path,
313 | prompt_to_save,
314 | labels_to_save,
315 | ans_to_save,
316 | evaluations_to_save,
317 | args.save_path,
318 | )
319 | prompt_to_save.clear()
320 | ans_to_save.clear()
321 | labels_to_save.clear()
322 | evaluations_to_save.clear()
323 |
324 | if total_num > 0:
325 | pprint(args.__dict__)
326 | print(
327 | "Acc over {} valid answers is {:.4f}, over {} all answers is {:.4f}".format(
328 | total_num, total_acc / total_num, len(dataset), total_acc / len(dataset)
329 | )
330 | )
331 |
332 |
333 | if __name__ == "__main__":
334 | main()
335 |
--------------------------------------------------------------------------------
/evaluation/evaluation_gsm8k_voting.py:
--------------------------------------------------------------------------------
1 | import re
2 | import os
3 | import sys
4 | import json
5 | from dataclasses import dataclass, field
6 | from pprint import pprint
7 | from tqdm import tqdm
8 | from collections import Counter
9 |
10 | import torch
11 | import numpy as np
12 |
13 | from datasets import load_dataset
14 | from evaluate import load
15 |
16 | sys.path.append(
17 | os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir, ".."))
18 | )
19 |
20 | from evaluation.utils.gsm8k import extract_answer_number
21 |
22 | from transformers import HfArgumentParser, set_seed, AutoModelForCausalLM, AutoTokenizer
23 |
24 | import vllm
25 | from vllm import SamplingParams
26 |
27 | TEMPLATE = """
28 | Your task is to answer the question below. Give step by step reasoning before you answer, and when you’re ready to answer, please use the format "The answer is: ..."\nQuestion: {question}
29 | """
30 |
31 |
32 | @dataclass
33 | class Arguments:
34 | dataset_name_or_path: str = field(default="gsm8k")
35 |
36 | # model
37 | model_name_or_path: str = field(default="meta-llama/Llama-2-7b-chat")
38 | tokenizer_name_or_path: str = field(default="meta-llama/Llama-2-7b-chat")
39 | dtype: str = field(default="bf16", metadata={"choices": ["fp16", "bf16"]})
40 |
41 | # generation
42 | do_sample: bool = field(default=False)
43 | temperature: float = field(
44 | default=0.6,
45 | )
46 | top_k: int = field(default=50)
47 | top_p: float = field(default=1.0)
48 | n: int = field(default=1)
49 |
50 | use_vllm: bool = field(
51 | default=False, metadata={"help": "Whether use vLLM for generation."}
52 | )
53 | vllm_gpu_memory_utilization: float = field(
54 | default=0.9, metadata={"help": "vLLM GPU consumption ratio."}
55 | )
56 |
57 | seed: int = field(
58 | default=42, metadata={"help": "Random Seed for reproducing results."}
59 | )
60 |
61 | batch_size: int = field(default=16)
62 | max_new_tokens: int = field(default=512, metadata={"help": "Max response length."})
63 |
64 | # save
65 | remove_old: bool = field(
66 | default=False, metadata={"help": "Whether to remove old file."}
67 | )
68 | save_path: str = field(
69 | default="evaluation_gsm8k.json",
70 | metadata={"help": "Evaluation results save path."},
71 | )
72 |
73 |
74 | def save_prompts_and_answers(
75 | model_name, prompts, labels, answers, evaluations, file_path
76 | ):
77 | assert len(prompts) == len(answers), "Mismatched lengths!"
78 | assert file_path.endswith(".json")
79 | data = [
80 | {
81 | "id": i,
82 | "model_name": model_name,
83 | "prompt": prompts[i],
84 | "label": labels[i],
85 | "answer": answers[i],
86 | "evaluation": evaluations[i],
87 | }
88 | for i in range(len(prompts))
89 | ]
90 | if not os.path.exists(file_path):
91 | with open(file_path, "w", encoding="utf-8") as file:
92 | json.dump(data, file, indent=2)
93 | else:
94 | with open(file_path, "r", encoding="utf-8") as file:
95 | data = json.load(file)
96 |
97 | # Determine the next id value
98 | next_id = data[-1]["id"] + 1 if data else 0
99 |
100 | # Create new entries and append them to the data list
101 | new_entries = [
102 | {
103 | "id": i + next_id,
104 | "model_name": model_name,
105 | "prompt": prompts[i],
106 | "label": labels[i],
107 | "answer": answers[i],
108 | "evaluation": evaluations[i],
109 | }
110 | for i in range(len(prompts))
111 | ]
112 | data.extend(new_entries)
113 |
114 | with open(file_path, "w", encoding="utf-8") as file:
115 | json.dump(data, file, indent=2)
116 |
117 |
118 | def calculate_accuracy_voting(reference, candidates, depths=[1, 4, 8, 16, 32]):
119 | majority_voting_accuracies = []
120 | best_of_n_accuracies = []
121 |
122 | for depth in depths:
123 | if depth > len(candidates):
124 | break
125 | else:
126 | # Slice the candidates list to the current depth
127 | current_candidates = candidates[:depth]
128 | count = Counter(current_candidates)
129 | most_common = count.most_common(1)[0][0] # Get the most frequent answer
130 |
131 | # Majority voting accuracy
132 | is_correct_majority = most_common == reference
133 | majority_voting_accuracies.append(is_correct_majority)
134 |
135 | # Best of n accuracy
136 | is_correct_best_of_n = reference in current_candidates
137 | best_of_n_accuracies.append(is_correct_best_of_n)
138 |
139 | return majority_voting_accuracies, best_of_n_accuracies
140 |
141 |
142 | def main():
143 | parser = HfArgumentParser((Arguments,))
144 | (args,) = parser.parse_args_into_dataclasses()
145 | pprint(args.__dict__)
146 |
147 | if args.remove_old:
148 | if os.path.exists(args.save_path):
149 | os.remove(args.save_path)
150 |
151 | dataset = load_dataset(args.dataset_name_or_path, "main")
152 | dataset = dataset["test"]
153 |
154 | device = torch.device("cuda")
155 |
156 | tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name_or_path)
157 | if "llama-3" in tokenizer.name_or_path.lower():
158 | tokenizer.pad_token = tokenizer.decode(len(tokenizer) - 1)
159 | tokenizer.pad_token_id = len(tokenizer) - 1
160 | elif "llama-2" in tokenizer.name_or_path.lower():
161 | tokenizer.pad_token = tokenizer.unk_token
162 | tokenizer.pad_token_id = tokenizer.unk_token_id
163 | # tokenizer.model_max_length = int(8096)
164 | dtype = {"bf16": torch.bfloat16, "fp16": torch.float16}[args.dtype]
165 |
166 | if args.use_vllm:
167 | model = vllm.LLM(
168 | model=args.model_name_or_path,
169 | tokenizer=args.tokenizer_name_or_path,
170 | tensor_parallel_size=torch.cuda.device_count(),
171 | dtype=dtype,
172 | gpu_memory_utilization=args.vllm_gpu_memory_utilization,
173 | seed=args.seed,
174 | )
175 | else:
176 | model = AutoModelForCausalLM.from_pretrained(
177 | args.model_name_or_path,
178 | torch_dtype=dtype,
179 | attn_implementation="flash_attention_2",
180 | )
181 | model.to(device)
182 | model.eval()
183 |
184 | prompt_to_save = []
185 | ans_to_save = []
186 | labels_to_save = []
187 | evaluations_to_save = []
188 |
189 | count = 0
190 | max_depth = max(1, int(np.log2(args.n)))
191 | majority_voting_all = np.zeros([len(dataset), max_depth], dtype=int)
192 | best_of_n_all = np.zeros([len(dataset), max_depth], dtype=int)
193 | for i in tqdm(range(0, len(dataset), args.batch_size)):
194 | prompt = dataset[i : i + args.batch_size]["question"]
195 | prompt_conv = [
196 | [{"role": "user", "content": TEMPLATE.format(question=x)}] for x in prompt
197 | ]
198 | labels = [
199 | # x.replace("####", "Final answer:")
200 | x.replace("####", "The answer is:")
201 | for x in dataset[i : i + args.batch_size]["answer"]
202 | ]
203 | prompt_str = tokenizer.apply_chat_template(
204 | prompt_conv, tokenize=False, add_generation_prompt=True
205 | )
206 |
207 | tokenizer.padding_side = "left"
208 | prompt_token = tokenizer.apply_chat_template(
209 | prompt_conv,
210 | padding="longest",
211 | add_generation_prompt=True,
212 | return_dict=True,
213 | return_tensors="pt",
214 | ).to(device)
215 |
216 | prompt_length = prompt_token.input_ids.size(-1)
217 |
218 | if args.use_vllm:
219 | prompt_token_ids = [
220 | prompt_token.input_ids[
221 | j, prompt_token.attention_mask[j].bool()
222 | ].tolist()
223 | for j in range(len(prompt_conv))
224 | ]
225 |
226 | sampling_params = SamplingParams(
227 | n=args.n,
228 | temperature=args.temperature,
229 | top_k=args.top_k,
230 | top_p=args.top_p,
231 | max_tokens=args.max_new_tokens,
232 | )
233 | with torch.no_grad():
234 | output_results = model.generate(
235 | prompt_token_ids=prompt_token_ids, sampling_params=sampling_params
236 | )
237 | ans_str = []
238 | evaluation_results = []
239 | for j in range(len(output_results)):
240 | final_answers = []
241 | for k in range(args.n):
242 | try:
243 | answer = extract_answer_number(
244 | output_results[j].outputs[k].text
245 | )
246 | except Exception as e:
247 | print("========Error=========")
248 | print(e)
249 | print(output_results[j].outputs[k].text)
250 | print()
251 | answer = None
252 | final_answers.append(answer)
253 |
254 | ans_str.append(
255 | [output_results[j].outputs[k].text for k in range(args.n)]
256 | )
257 | true_answer = extract_answer_number(labels[j])
258 | majority_evaluation, best_of_n_evaluation = calculate_accuracy_voting(
259 | true_answer, final_answers
260 | )
261 | majority_voting_all[count] = np.array(
262 | majority_evaluation, dtype=np.int32
263 | )
264 | best_of_n_all[count] = np.array(best_of_n_evaluation, dtype=np.int32)
265 | count += 1
266 |
267 | evaluation_results.append(
268 | {
269 | "true_answer": true_answer,
270 | "majority_evaluation": majority_evaluation,
271 | "best_of_n_evaluation": best_of_n_evaluation,
272 | }
273 | )
274 |
275 | prompt_to_save.extend(prompt_str)
276 | ans_to_save.extend(ans_str)
277 | labels_to_save.extend(labels)
278 | evaluations_to_save.extend(evaluation_results)
279 |
280 | pprint("===========Prompt=============")
281 | pprint(prompt_str[0])
282 | pprint("===========Label=============")
283 | pprint(labels[0])
284 | pprint("===========Response=============")
285 | pprint(ans_str[0])
286 | pprint("===========Evaluation=============")
287 | pprint(evaluation_results[0])
288 | pprint(
289 | "Majority Acc so far: {} Best Acc so far: {}".format(
290 | np.round(np.mean(majority_voting_all[:count], axis=0) * 100, 2),
291 | np.round(np.mean(best_of_n_all[:count], axis=0) * 100, 2),
292 | )
293 | )
294 | else:
295 | raise NotImplementedError
296 |
297 | if count % 128 == 0:
298 | save_prompts_and_answers(
299 | args.model_name_or_path,
300 | prompt_to_save,
301 | labels_to_save,
302 | ans_to_save,
303 | evaluations_to_save,
304 | args.save_path,
305 | )
306 | prompt_to_save.clear()
307 | ans_to_save.clear()
308 | labels_to_save.clear()
309 | evaluations_to_save.clear()
310 |
311 | if len(prompt_to_save) > 0:
312 | save_prompts_and_answers(
313 | args.model_name_or_path,
314 | prompt_to_save,
315 | labels_to_save,
316 | ans_to_save,
317 | evaluations_to_save,
318 | args.save_path,
319 | )
320 | prompt_to_save.clear()
321 | ans_to_save.clear()
322 | labels_to_save.clear()
323 | evaluations_to_save.clear()
324 |
325 | pprint(args.__dict__)
326 | print(
327 | "==> Majority Acc over the dataset: {} Best Acc over the dataset: {}".format(
328 | np.round(np.mean(majority_voting_all, axis=0) * 100, 2),
329 | np.round(np.mean(best_of_n_all, axis=0) * 100, 2),
330 | )
331 | )
332 |
333 |
334 | if __name__ == "__main__":
335 | main()
336 |
--------------------------------------------------------------------------------
/evaluation/evaluation_reward.py:
--------------------------------------------------------------------------------
1 | import os
2 | from dataclasses import dataclass, field
3 | from pprint import pprint
4 | import json
5 | from types import MethodType
6 | from tqdm import tqdm
7 | import numpy as np
8 |
9 | import torch
10 | from transformers import AutoTokenizer, pipeline
11 | from transformers import AutoModel, AutoModelForSequenceClassification, HfArgumentParser
12 |
13 |
14 | @dataclass
15 | class Arguments:
16 | model_name_or_path: str = field(default="sfairXC/FsfairX-LLaMA3-RM-v0.1")
17 | tokenizer_path: str = field(default=None)
18 |
19 | detokenizer_path: str = field(default=None)
20 |
21 | data_path: str = field(default=None)
22 | batch_size: int = field(default=1)
23 | max_size: int = field(default=None)
24 |
25 | save_path: str = field(default=None)
26 |
27 |
28 | def forward_value_fn(
29 | self,
30 | input_ids=None,
31 | attention_mask=None,
32 | past_key_values=None,
33 | position_ids=None,
34 | inputs_embeds=None,
35 | return_value_only=False,
36 | prompt_length=0,
37 | use_cache=False,
38 | **kwargs,
39 | ):
40 | transformer_outputs = self.model(
41 | input_ids,
42 | past_key_values=past_key_values,
43 | attention_mask=attention_mask,
44 | inputs_embeds=inputs_embeds,
45 | use_cache=use_cache,
46 | **kwargs,
47 | )
48 | hidden_states = transformer_outputs[0]
49 | values = self.score(hidden_states).squeeze(-1)
50 | if return_value_only:
51 | return values
52 | else:
53 | if attention_mask is None:
54 | chosen_end_scores = values[:, -1]
55 | else:
56 | last_index = attention_mask.cumsum(dim=1).argmax(dim=1)
57 | chosen_end_scores = values.gather(1, last_index.unsqueeze(1)).squeeze(1)
58 | return {
59 | "values": values,
60 | "chosen_end_scores": chosen_end_scores,
61 | }
62 |
63 |
64 | def calculation_best_of_n(data):
65 | print("Calculating best of n reward ....")
66 | best_n = np.zeros([len(data), 6]) # 1, 2, 4, 8, 16, 32
67 | mean_n = np.zeros([len(data), 6]) # 1, 2, 4, 8, 16, 32
68 | for i in tqdm(range(len(data))):
69 | rewards = data[i]["reward"]
70 | best_n[i][0] = rewards[0]
71 | best_n[i][1] = max(rewards[:2])
72 | best_n[i][2] = max(rewards[:4])
73 | best_n[i][3] = max(rewards[:8])
74 | best_n[i][4] = max(rewards[:16])
75 | best_n[i][5] = max(rewards[:32])
76 |
77 | mean_n[i][0] = rewards[0]
78 | mean_n[i][1] = np.mean(rewards[:2])
79 | mean_n[i][2] = np.mean(rewards[:4])
80 | mean_n[i][3] = np.mean(rewards[:8])
81 | mean_n[i][4] = np.mean(rewards[:16])
82 | mean_n[i][5] = np.mean(rewards[:32])
83 | best_n = np.mean(best_n, axis=0)
84 | print("Best of n: {}".format(np.round(best_n, 2)))
85 | mean_n = np.mean(mean_n, axis=0)
86 | print("Mean of n: {}".format(np.round(mean_n, 2)))
87 | return best_n, mean_n
88 |
89 |
90 | def main():
91 | parser = HfArgumentParser((Arguments,))
92 | (args,) = parser.parse_args_into_dataclasses()
93 | pprint(args.__dict__)
94 | assert args.data_path is not None
95 | assert args.save_path is not None
96 |
97 | device = torch.device("cuda")
98 |
99 | model_class = AutoModelForSequenceClassification
100 | flash_attn = True
101 | model = model_class.from_pretrained(
102 | args.model_name_or_path,
103 | torch_dtype=torch.bfloat16,
104 | attn_implementation="flash_attention_2" if flash_attn else "eager",
105 | trust_remote_code=True,
106 | )
107 | # model.forward_value = forward_value_fn
108 | model.forward_value = MethodType(forward_value_fn, model)
109 | model.eval()
110 | model.to(device)
111 |
112 | tokenizer = AutoTokenizer.from_pretrained(
113 | args.tokenizer_path or args.model_name_or_path
114 | )
115 | tokenizer.padding_side = "right"
116 | if args.detokenizer_path is not None:
117 | detokenizer = AutoTokenizer.from_pretrained(args.detokenizer_path)
118 | else:
119 | detokenizer = None
120 |
121 | response_data = json.load(open(args.data_path, "r"))
122 |
123 | if args.max_size:
124 | response_data = response_data[: args.max_size]
125 | if os.path.exists(args.save_path):
126 | response_data = json.load(open(args.save_path, "r"))
127 | calculation_best_of_n(response_data)
128 | return
129 | for start in tqdm(range(0, len(response_data), args.batch_size)):
130 | end = start + args.batch_size
131 | prompts = []
132 | answers = []
133 | for x in response_data[start:end]:
134 | if detokenizer:
135 | prompt_str = (
136 | detokenizer.decode(
137 | detokenizer.encode(x["prompt"]), skip_special_tokens=True
138 | )
139 | .replace("user\n\n", "")
140 | .replace("assistant\n\n", "")
141 | )
142 | else:
143 | if "prompt" in x:
144 | prompt_str = x["prompt"]
145 | elif "instruction" in x:
146 | prompt_str = x["instruction"]
147 | else:
148 | raise ValueError(x)
149 | if "answer" in x:
150 | for ans in x["answer"]:
151 | if detokenizer:
152 | ans_str = detokenizer.decode(
153 | detokenizer.encode(ans), skip_special_tokens=True
154 | )
155 | else:
156 | ans_str = ans
157 | prompts.append(prompt_str)
158 | answers.append(ans_str)
159 | elif "output" in x:
160 | ans_str = x["output"]
161 | prompts.append(prompt_str)
162 | answers.append(ans_str)
163 | else:
164 | raise ValueError(x)
165 |
166 | chat = []
167 | for i in range(len(prompts)):
168 | chat.append(
169 | [
170 | {"role": "user", "content": prompts[i]},
171 | {"role": "assistant", "content": answers[i]},
172 | ]
173 | )
174 | inputs = tokenizer.apply_chat_template(
175 | chat,
176 | padding="longest",
177 | add_generation_prompt=True,
178 | return_dict=True,
179 | return_tensors="pt",
180 | ).to(device)
181 |
182 | with torch.no_grad():
183 | if "FsfairX-LLaMA3-RM-v0.1" in args.model_name_or_path:
184 | outputs = model.forward_value(**inputs)["chosen_end_scores"]
185 | else:
186 | outputs = model(**inputs, use_cache=False)
187 |
188 | c_start = 0
189 | for x in response_data[start:end]:
190 | if "answer" in x:
191 | x["reward"] = outputs[c_start : c_start + len(x["answer"])].tolist()
192 | c_start += len(x["answer"])
193 | elif "output" in x:
194 | x["reward"] = outputs[c_start].tolist()
195 | c_start += 1
196 | else:
197 | raise ValueError(x)
198 |
199 | print(chat[0])
200 | print(outputs[0])
201 |
202 | if "answer" in x:
203 | calculation_best_of_n(response_data)
204 |
205 | json.dump(response_data, open(args.save_path, "w"), indent=2)
206 | print("saving result to {}".format(args.save_path))
207 |
208 |
209 | if __name__ == "__main__":
210 | main()
211 |
--------------------------------------------------------------------------------
/evaluation/generate_response.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | from pprint import pprint
4 | from dataclasses import dataclass, field
5 | from tqdm import tqdm
6 | import pandas as pd
7 |
8 | import vllm
9 | from vllm import SamplingParams
10 |
11 | import torch
12 | from datasets import load_dataset, load_from_disk, Dataset
13 | from transformers import AutoModelForCausalLM, AutoTokenizer, set_seed, HfArgumentParser
14 |
15 |
16 | @dataclass
17 | class Arguments:
18 | model_name_or_path: str = field(
19 | default="meta-llama/Llama-2-7b-chat", metadata={"help": "Model name or path."}
20 | )
21 |
22 | tokenizer_path: str = field(
23 | default="meta-llama/Llama-2-7b-chat", metadata={"help": "Tokenizer path."}
24 | )
25 |
26 | # dataset
27 | dataset_path: str = field(
28 | default="RLHFlow/Orca-distibalel-standard", metadata={"help": "Dataset path."}
29 | )
30 | split: str = field(default=None, metadata={"help": "Split of the dataset."})
31 | column_name: str = field(
32 | default=None, metadata={"help": "Column name to extract prompt."}
33 | )
34 | standard_format: bool = field(
35 | default=None, metadata={"help": "Dataset in the standard format."}
36 | )
37 | load_from_disk: bool = field(
38 | default=False, metadata={"help": "Whether use the load_from_disk method."}
39 | )
40 | max_size: int = field(
41 | default=None, metadata={"help": "Max data size for evaluation."}
42 | )
43 |
44 | use_vllm: bool = field(
45 | default=False, metadata={"help": "Whether use vLLM for generation."}
46 | )
47 | vllm_gpu_memory_utilization: float = field(
48 | default=0.9, metadata={"help": "vLLM GPU consumption ratio."}
49 | )
50 |
51 | seed: int = field(
52 | default=42, metadata={"help": "Random Seed for reproducing results."}
53 | )
54 |
55 | # generation
56 | batch_size: int = field(default=10)
57 |
58 | n: int = field(default=1, metadata={"help": "num of responses for each prompt."})
59 | do_sample: bool = field(
60 | default=True, metadata={"help": "Do sample for generation."}
61 | )
62 | top_k: int = field(default=50, metadata={"help": "Top k for generation."})
63 | top_p: float = field(default=0.9, metadata={"help": "Top p for generation."})
64 | temperature: float = field(
65 | default=0.6, metadata={"help": "Temperature for generation."}
66 | )
67 | max_new_tokens: int = field(default=1024, metadata={"help": "Max response length."})
68 |
69 | # save
70 | remove_old: bool = field(
71 | default=False, metadata={"help": "Whether to remove old file."}
72 | )
73 | save_path: str = field(
74 | default="evaluation_log_probability.json",
75 | metadata={"help": "Evaluation results save path."},
76 | )
77 |
78 | def __post_init__(self):
79 | if self.column_name is None:
80 | if "tatsu-lab/alpaca_eval" in self.dataset_path:
81 | self.column_name = "instruction"
82 | if "HuggingFaceH4/ultrachat_200k" in self.dataset_path:
83 | self.column_name = "prompt"
84 | if "if_eval" in self.dataset_path:
85 | self.column_name = "prompt"
86 | if "poem_generation" in self.dataset_path:
87 | self.column_name = "instruction"
88 | if "story_generation" in self.dataset_path:
89 | self.column_name = "instruction"
90 |
91 | if self.split is None:
92 | if "tatsu-lab/alpaca_eval" in self.dataset_path:
93 | self.split = "eval"
94 | if "HuggingFaceH4/ultrachat_200k" in self.dataset_path:
95 | self.split = "test_sft"
96 | if "if_eval" in self.dataset_path:
97 | self.split = "train"
98 | if "poem_generation" in self.dataset_path:
99 | self.split = "test"
100 | if "story_generation" in self.dataset_path:
101 | self.split = "test"
102 |
103 | if self.standard_format is None:
104 | if "tatsu-lab/alpaca_eval" in self.dataset_path:
105 | self.standard_format = False
106 | if "HuggingFaceH4/ultrachat_200k" in self.dataset_path:
107 | self.standard_format = False
108 | if "if_eval" in self.dataset_path:
109 | self.standard_format = False
110 | if "poem_generation" in self.dataset_path:
111 | self.standard_format = False
112 | if "story_generation" in self.dataset_path:
113 | self.standard_format = False
114 |
115 |
116 | def get_dataset(dataset_name, split="test", from_disk=False):
117 | if from_disk:
118 | dataset = load_from_disk(dataset_name)
119 | else:
120 | if "tatsu-lab/alpaca_eval" in dataset_name:
121 | dataset = load_dataset(dataset_name, "alpaca_eval")
122 | if "if_eval" in dataset_name:
123 | dataset = []
124 | with open("./data/if_eval_data.jsonl") as f:
125 | for line in f.readlines():
126 | dataset.append(json.loads(line))
127 | dataset = Dataset.from_pandas(pd.DataFrame(dataset))
128 | return dataset
129 | else:
130 | dataset = load_dataset(dataset_name)
131 | if split in dataset:
132 | return dataset[split]
133 | else:
134 | assert "train" in dataset
135 | total_size = len(dataset["train"])
136 | eval_size = min(1000, int(total_size * 0.1))
137 | train_size = total_size - eval_size
138 | print(
139 | "There is no {} in the dataset. I set {} samples from the train split.".format(
140 | split, eval_size
141 | )
142 | )
143 | return dataset["train"].shuffle(seed=42).select(range(train_size, total_size))
144 |
145 |
146 | def save_prompts_and_answers(model_name, prompts, answers, file_path):
147 | assert len(prompts) == len(answers), "Mismatched lengths!"
148 | assert file_path.endswith(".json")
149 | data = [
150 | {
151 | "id": i,
152 | "model_name": model_name,
153 | "prompt": prompts[i],
154 | "answer": answers[i],
155 | }
156 | for i in range(len(prompts))
157 | ]
158 | if not os.path.exists(file_path):
159 | with open(file_path, "w", encoding="utf-8") as file:
160 | json.dump(data, file, indent=2, ensure_ascii=False)
161 | else:
162 | with open(file_path, "r", encoding="utf-8") as file:
163 | data = json.load(file)
164 |
165 | # Determine the next id value
166 | next_id = data[-1]["id"] + 1 if data else 0
167 |
168 | # Create new entries and append them to the data list
169 | new_entries = [
170 | {
171 | "id": next_id + i,
172 | "model_name": model_name,
173 | "prompt": prompts[i],
174 | "answer": answers[i],
175 | }
176 | for i in range(len(prompts))
177 | ]
178 | data.extend(new_entries)
179 |
180 | with open(file_path, "w", encoding="utf-8") as file:
181 | json.dump(data, file, indent=2)
182 |
183 |
184 | def main():
185 | parser = HfArgumentParser((Arguments,))
186 | (args,) = parser.parse_args_into_dataclasses()
187 | pprint(args.__dict__)
188 |
189 | training_config = {}
190 | if os.path.exists(os.path.join(args.model_name_or_path, "args.json")) and False:
191 | # f = open(os.path.join(args.model_name_or_path, "args.json"), "r")
192 | # for line in f.readlines()[:-1]:
193 | # line_dict = json.loads(line)
194 | # for key, val in line_dict.items():
195 | # training_config[key] = val
196 | training_config = json.load(
197 | open(os.path.join(args.model_name_or_path, "args.json"), "r")
198 | )
199 | if training_config["algo"] == "sft":
200 | key_parameters = {
201 | "algo": "sft",
202 | "model_name_or_path": training_config["model_name_or_path"],
203 | "dataset": training_config["data_path"],
204 | "data_max_size": training_config["max_size"],
205 | "learning_rate": training_config["learning_rate"],
206 | "num_train_epochs": (training_config["num_train_epochs"]),
207 | "max_seq_len": training_config["max_seq_len"],
208 | }
209 | elif training_config["algo"] == "dpo":
210 | key_parameters = {
211 | "algo": "dpo",
212 | "actor_model_name_or_path": training_config["actor_model_name_or_path"],
213 | "max_entropy": training_config["max_entropy"],
214 | "beta": training_config["beta"],
215 | "tau": training_config["tau"] if "tau" in training_config else None,
216 | "gamma": training_config["gamma"],
217 | "alpha": training_config["alpha"],
218 | "dataset": training_config["data_path"],
219 | "learning_rate": training_config["actor_learning_rate"],
220 | "enable_ema": (
221 | training_config["enable_ema"]
222 | if "enable_ema" in training_config
223 | else None
224 | ),
225 | "ema_coeff": (
226 | training_config["ema_coeff"]
227 | if "ema_coeff" in training_config
228 | else None
229 | ),
230 | }
231 | print("===========Your Training Key Parameters===============")
232 | pprint(key_parameters)
233 | print("Your save path: {}".format(args.save_path))
234 | print()
235 |
236 | # if input("Is the save path correct (yes or no)?:\n") != "yes":
237 | # assert 0
238 | set_seed(args.seed)
239 | if os.path.exists(args.save_path):
240 | if args.remove_old:
241 | # if (
242 | # input(
243 | # "The given save path exists a file. Do you continue (yes ot no)?:\n"
244 | # )
245 | # != "yes"
246 | # ):
247 | # assert 0
248 | os.remove(args.save_path)
249 | else:
250 | print("{} exists. Exit!".format(args.save_path))
251 | return
252 |
253 | dataset = get_dataset(args.dataset_path, args.split, args.load_from_disk)
254 | if args.max_size:
255 | dataset = dataset.select(range(0, min(len(dataset), args.max_size)))
256 |
257 | device = torch.device("cuda")
258 |
259 | tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_path)
260 | if "llama-3" in tokenizer.name_or_path.lower():
261 | tokenizer.pad_token = tokenizer.decode(len(tokenizer) - 1)
262 | tokenizer.pad_token_id = len(tokenizer) - 1
263 | elif "llama-2" in tokenizer.name_or_path.lower():
264 | tokenizer.pad_token = tokenizer.unk_token
265 | tokenizer.pad_token_id = tokenizer.unk_token_id
266 |
267 | if args.use_vllm:
268 | model = vllm.LLM(
269 | model=args.model_name_or_path,
270 | tokenizer=args.tokenizer_path,
271 | tensor_parallel_size=torch.cuda.device_count(),
272 | dtype=torch.bfloat16,
273 | gpu_memory_utilization=args.vllm_gpu_memory_utilization,
274 | seed=args.seed,
275 | swap_space=16,
276 | )
277 | args.batch_size = len(dataset)
278 | else:
279 | model = AutoModelForCausalLM.from_pretrained(
280 | args.model_name_or_path,
281 | torch_dtype=torch.bfloat16,
282 | attn_implementation="flash_attention_2",
283 | )
284 | model.to(device)
285 | model.eval()
286 |
287 | # eos_token_id = [tokenizer.eos_token_id]
288 | # if "llama-3" in tokenizer.name_or_path.lower():
289 | # eos_token_id.append(tokenizer("<|eot_id|>").input_ids[-1])
290 |
291 | prompt_to_save = []
292 | ans_to_save = []
293 | count = 0
294 | for i in tqdm(range(0, len(dataset), args.batch_size)):
295 | if args.standard_format:
296 | chosen_conv = dataset[i : i + args.batch_size][args.column_name]
297 | prompt_conv = [x[:-1] for x in chosen_conv]
298 |
299 | prompt_str = tokenizer.apply_chat_template(
300 | prompt_conv, tokenize=False, add_generation_prompt=True
301 | )
302 | else:
303 | prompt = dataset[i : i + args.batch_size][args.column_name]
304 | prompt_conv = [[{"role": "user", "content": x}] for x in prompt]
305 |
306 | prompt_str = tokenizer.apply_chat_template(
307 | prompt_conv, tokenize=False, add_generation_prompt=True
308 | )
309 |
310 | tokenizer.padding_side = "left"
311 | prompt_token = tokenizer.apply_chat_template(
312 | prompt_conv,
313 | padding="longest",
314 | add_generation_prompt=True,
315 | return_dict=True,
316 | return_tensors="pt",
317 | ).to(device)
318 |
319 | prompt_length = prompt_token.input_ids.size(-1)
320 |
321 | if args.use_vllm:
322 | prompt_token_ids = [
323 | prompt_token.input_ids[
324 | j, prompt_token.attention_mask[j].bool()
325 | ].tolist()
326 | for j in range(len(prompt_conv))
327 | ]
328 |
329 | sampling_params = SamplingParams(
330 | n=args.n,
331 | temperature=args.temperature,
332 | top_p=args.top_p,
333 | top_k=args.top_k,
334 | # stop_token_ids=[
335 | # tokenizer.eos_token_id,
336 | # tokenizer("<|eot_id|>").input_ids[-1],
337 | # ],
338 | max_tokens=args.max_new_tokens,
339 | )
340 | with torch.no_grad():
341 | output_results = model.generate(
342 | prompt_token_ids=prompt_token_ids, sampling_params=sampling_params
343 | )
344 | ans_str = []
345 | for j in range(len(output_results)):
346 | # if not output_results[j].outputs[0].finish_reason == "stop":
347 | # print(output_results[j].outputs[0].finish_reason)
348 | ans_str.append(
349 | [output_results[j].outputs[k].text for k in range(args.n)]
350 | )
351 | # ans_str = [x.replace("<|eot_id|>", "") for x in ans_str]
352 |
353 | prompt_to_save.extend(prompt_str)
354 | ans_to_save.extend(ans_str)
355 | else:
356 | with torch.no_grad():
357 | outputs = model.generate(
358 | prompt_token.input_ids,
359 | attention_mask=prompt_token.attention_mask,
360 | top_k=args.top_k,
361 | top_p=args.top_p,
362 | temperature=args.temperature,
363 | max_new_tokens=args.max_new_tokens,
364 | pad_token_id=tokenizer.pad_token_id,
365 | # eos_token_id=eos_token_id,
366 | do_sample=args.do_sample,
367 | )
368 | ans_token = outputs[:, prompt_length:]
369 | ans_str = tokenizer.batch_decode(ans_token, skip_special_tokens=True)
370 | # ans_str = [x.replace("<|eot_id|>", "") for x in ans_str]
371 |
372 | prompt_to_save.extend(prompt_str)
373 | ans_to_save.extend(ans_str)
374 | count += 1
375 |
376 | print(prompt_str[0])
377 | print(ans_str[0])
378 |
379 | if count % 10 == 0:
380 | save_prompts_and_answers(
381 | args.model_name_or_path,
382 | prompt_to_save,
383 | ans_to_save,
384 | args.save_path,
385 | )
386 | prompt_to_save.clear()
387 | ans_to_save.clear()
388 |
389 | if len(prompt_to_save) > 0:
390 | save_prompts_and_answers(
391 | args.model_name_or_path,
392 | prompt_to_save,
393 | ans_to_save,
394 | args.save_path,
395 | )
396 | prompt_to_save.clear()
397 | ans_to_save.clear()
398 |
399 |
400 | if __name__ == "__main__":
401 | main()
402 |
--------------------------------------------------------------------------------
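
Note: save_prompts_and_answers above writes (or appends to) a flat JSON list of records with id, model_name, prompt, and answer fields; with --use_vllm and n > 1, answer is a list of n completions. A minimal sketch for loading such a file back (the path is a placeholder for whatever --save_path was used):

import json

# Placeholder path; substitute the actual --save_path value.
with open("responses.json", "r", encoding="utf-8") as f:
    records = json.load(f)

for rec in records[:3]:
    print(rec["id"], rec["model_name"])
    print(rec["prompt"])
    print(rec["answer"])   # a string, or a list of n strings when vLLM sampling with n > 1 was used
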
/evaluation/utils/gsm8k.py:
--------------------------------------------------------------------------------
1 | import re
2 | from fraction import Fraction
3 | import numpy as np
4 |
5 |
6 | def is_number(s):
7 | try:
8 | float(s)
9 | return True
10 | except ValueError:
11 | pass
12 | try:
13 | import unicodedata
14 |
15 | unicodedata.numeric(s)
16 | return True
17 | except (TypeError, ValueError):
18 | pass
19 | return False
20 |
21 |
22 | def extract_answer_number(completion):
23 | completion_new = completion.replace("the answer is", "The answer is")
24 | text = completion_new.split("The answer is")
25 | if len(text) > 1:
26 | extract_ans = text[-1].strip()
27 | match = re.search(r"[\-+]?\d*[\.,/]?\d+", extract_ans)
28 | if match:
29 | if "/" in match.group():
30 | denominator = match.group().split("/")[1]
31 | numerator = match.group().split("/")[0]
32 | if is_number(denominator) and is_number(numerator):
33 | if denominator == "0":
34 | return round(float(numerator.replace(",", "")))
35 | else:
36 | frac = Fraction(match.group().replace(",", ""))
37 | num_numerator = frac.numerator
38 | num_denominator = frac.denominator
39 | return round(float(num_numerator / num_denominator))
40 | else:
41 | return None
42 | else:
43 | if float(match.group().replace(",", "")) == float("inf"):
44 | return None
45 | return round(float(match.group().replace(",", "")))
46 | else:
47 | return None
48 | else:
49 | return None
50 |
51 | def evaluation_answers(
52 | true_answers, predicted_answers, eos_tokens=["<|eot_id|>"], print_results=False
53 | ):
54 | assert len(true_answers) == len(
55 | predicted_answers
56 | ), "true answers (size={}) != predicted answers (size={})".format(
57 | len(true_answers), len(predicted_answers)
58 | )
59 | predictions = []
60 | skipped_indices = []
61 | for i in range(len(predicted_answers)):
62 | final_answer = extract_answer_number(predicted_answers[i])
63 | if final_answer is not None:
64 | predictions.append(final_answer)
65 | else:
66 | skipped_indices.append(i)
67 |
68 | references = []
69 | for i in range(len(true_answers)):
70 | if i not in skipped_indices:
71 | final_answer = extract_answer_number(true_answers[i])
72 | references.append(final_answer)
73 |
74 | assert len(predictions) == len(
75 | references
76 | ), "predictions (size={}) != references (size={})".format(
77 | len(predictions), len(references)
78 | )
79 |
80 | evaluations = []
81 | c = 0
82 | for i in range(len(true_answers)):
83 | if i in skipped_indices:
84 | evaluations.append((None, None, None))
85 | else:
86 | evaluations.append(
87 | (references[c], predictions[c], references[c] == predictions[c])
88 | )
89 | c += 1
90 | if print_results:
91 | print(
92 | "There are {}/{} matched results from the given strings. Acc: {:.4f}".format(
93 | len(predictions),
94 | len(true_answers),
95 | np.mean([x[-1] for x in evaluations if x[-1] is not None]),
96 | )
97 | )
98 | return evaluations
99 |
--------------------------------------------------------------------------------
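
A quick illustration of the gsm8k helpers above (the strings are invented examples, and the import assumes the repository root is on sys.path):

# Illustrative usage only; the questions and answers here are made up.
from evaluation.utils.gsm8k import extract_answer_number, evaluation_answers

print(extract_answer_number("There are 4 boxes with 3 pens each. The answer is 12."))  # 12
print(extract_answer_number("I am not sure about this one."))                          # None

evals = evaluation_answers(
    true_answers=["... so she pays $18 in total. The answer is 18."],
    predicted_answers=["18 dollars in total. The answer is 18."],
    print_results=True,
)
# evals == [(18, 18, True)]; predictions that cannot be parsed appear as (None, None, None).
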
/img/gem_vs_ce.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/liziniu/GEM/e5b1430979fa12bf8ab7398b2ccc71dff795bbee/img/gem_vs_ce.png
--------------------------------------------------------------------------------
/img/gem_with_remax.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/liziniu/GEM/e5b1430979fa12bf8ab7398b2ccc71dff795bbee/img/gem_with_remax.png
--------------------------------------------------------------------------------
/preprocess_data.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 |
4 | os.environ["TOKENIZERS_PARALLELISM"] = "true"
5 | import torch
6 |
7 | from datasets import load_dataset
8 |
9 | from argparse import ArgumentParser
10 | from transformers import AutoTokenizer
11 | from multiprocessing import Pool
12 | from tqdm import tqdm
13 |
14 | parser = ArgumentParser()
15 | parser.add_argument(
16 | "--dataset_name_or_path",
17 | type=str,
18 | default="HuggingFaceH4/ultrafeedback_binarized",
19 | )
20 | parser.add_argument(
21 | "--split",
22 | type=str,
23 | default="train",
24 | )
25 | parser.add_argument(
26 | "--start",
27 | type=int,
28 | default=0,
29 | )
30 | parser.add_argument(
31 | "--end",
32 | type=int,
33 | default=None,
34 | )
35 | parser.add_argument(
36 | "--output_file",
37 | type=str,
38 | )
39 | parser.add_argument(
40 | "--tokenizer_name_or_path",
41 | type=str,
42 | required=True
43 | )
44 | parser.add_argument("--max_seq_length", type=int, default=4096)
45 | parser.add_argument("--preprocessing_num_workers", type=int, default=64)
46 | args = parser.parse_args()
47 |
48 | tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name_or_path)
49 | print(f"load tokenizer from {args.tokenizer_name_or_path} done.")
50 | max_seq_length = args.max_seq_length
51 |
52 | input_data = load_dataset(args.dataset_name_or_path)
53 | if args.split:
54 | input_data = input_data[args.split]
55 | if args.end is None:
56 | args.end = len(input_data)
57 | input_data = input_data.select(range(args.start, args.end))
58 | print(
59 | f"load input data from {args.dataset_name_or_path} done. len(input_data): {len(input_data)}"
60 | )
61 |
62 |
63 | def encode_sft_example(example, verbose=False):
64 | """
65 | This function encodes a single example into a format that can be used for sft training.
66 | Here, we assume each example has a 'messages' field. Each message in it is a dict with 'role' and 'content' fields.
67 | We use the `apply_chat_template` function from the tokenizer to tokenize the messages and prepare the input and label tensors.
68 | """
69 | messages = example["messages"]
70 | if len(messages) == 0:
71 | raise ValueError("messages field is empty.")
72 | if verbose:
73 | chat_messages = tokenizer.apply_chat_template(
74 | conversation=messages,
75 | tokenize=False,
76 | return_tensors="pt",
77 | padding=False,
78 | truncation=True,
79 | max_length=max_seq_length,
80 | add_generation_prompt=False,
81 | )
82 | print(f"chat_messages:\n[{chat_messages}]")
83 | input_ids = tokenizer.apply_chat_template(
84 | conversation=messages,
85 | tokenize=True,
86 | return_tensors="pt",
87 | padding=False,
88 | truncation=True,
89 | max_length=max_seq_length,
90 | add_generation_prompt=False,
91 | )
92 | labels = input_ids.clone()
93 | # mask the non-assistant parts so they do not contribute to the loss
94 | for message_idx, message in enumerate(messages):
95 | if message["role"] != "assistant":
96 | # we calculate the start index of this non-assistant message
97 | if message_idx == 0:
98 | message_start_idx = 0
99 | else:
100 | message_start_idx = tokenizer.apply_chat_template(
101 | conversation=messages[
102 | :message_idx
103 | ], # here marks the end of the previous messages
104 | tokenize=True,
105 | return_tensors="pt",
106 | padding=False,
107 | truncation=True,
108 | max_length=max_seq_length,
109 | add_generation_prompt=False,
110 | ).shape[1]
111 | # next, we calculate the end index of this non-assistant message
112 | if (
113 | message_idx < len(messages) - 1
114 | and messages[message_idx + 1]["role"] == "assistant"
115 | ):
116 | # for intermediate messages that are followed by an assistant message, we need to
117 | # set `add_generation_prompt=True` to avoid the assistant generation prefix being included in the loss
118 | # (e.g., `<|assistant|>`)
119 | message_end_idx = tokenizer.apply_chat_template(
120 | conversation=messages[: message_idx + 1],
121 | tokenize=True,
122 | return_tensors="pt",
123 | padding=False,
124 | truncation=True,
125 | max_length=max_seq_length,
126 | add_generation_prompt=True,
127 | ).shape[1]
128 | else:
129 | # for the last message, or a message that is not followed by an assistant message,
130 | # we don't need to add the assistant generation prefix
131 | message_end_idx = tokenizer.apply_chat_template(
132 | conversation=messages[: message_idx + 1],
133 | tokenize=True,
134 | return_tensors="pt",
135 | padding=False,
136 | truncation=True,
137 | max_length=max_seq_length,
138 | add_generation_prompt=False,
139 | ).shape[1]
140 | # set the label to -100 for the non-assistant part
141 | labels[:, message_start_idx:message_end_idx] = -100
142 | if max_seq_length and message_end_idx >= max_seq_length:
143 | break
144 | attention_mask = torch.ones_like(input_ids)
145 | return {
146 | "input_ids": input_ids.flatten().tolist(),
147 | "labels": labels.flatten().tolist(),
148 | "attention_mask": attention_mask.flatten().tolist(),
149 | }
150 |
151 |
152 | print(encode_sft_example(input_data[0], verbose=True))
153 |
154 | tokenized_data = []
155 | with Pool(args.preprocessing_num_workers) as p:
156 | pbar = tqdm(input_data, desc="tokenizing")
157 | for tokenized_example in p.imap(encode_sft_example, pbar):
158 | dump = json.dumps(tokenized_example)
159 | tokenized_data.append(dump)
160 |
161 | with open(args.output_file, "w") as fw:
162 | for dump in tokenized_data:
163 | fw.write(dump + "\n")
164 |
--------------------------------------------------------------------------------
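
Each line of the output file written above is one JSON object with input_ids, labels, and attention_mask of equal length, where labels is -100 on non-assistant tokens. A small sketch for inspecting the result (the path is just an example matching the llama3.1 tokenize script below):

import json

with open("./data/ultrafeedback_sft_train_llama3.1_tokenized.jsonl") as f:
    example = json.loads(f.readline())

assert len(example["input_ids"]) == len(example["labels"]) == len(example["attention_mask"])
print(len(example["input_ids"]))                        # sequence length (capped at max_seq_length)
print(sum(tok != -100 for tok in example["labels"]))    # number of tokens that receive loss
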
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch==2.4.0
2 | transformers==4.45.2
3 | datasets==2.14.6
4 | deepspeed==0.15.0
5 | flash-attn==2.6.3
6 | accelerate
7 | vllm==0.6.1
8 | sentence-transformers==3.0.1
9 | nltk
10 | fraction
11 | sacrebleu
--------------------------------------------------------------------------------
/scripts/eval/creative_writing.sh:
--------------------------------------------------------------------------------
1 | #!/bin/sh
2 |
3 | set -e
4 | set -x
5 |
6 | export HF_DATASETS_OFFLINE=1
7 | export TRANSFORMERS_OFFLINE=1
8 | export CUDA_VISIBLE_DEVICES="0"
9 |
10 | MODEL_PATH="model-path"
11 | TOKENIZER_PATH="meta-llama/Meta-Llama-3-8B-Instruct"
12 |
13 | SEED=42
14 | N=16
15 | T=1.0
16 | K=50
17 | P=0.9
18 |
19 |
20 | ############################################
21 | # poem writing
22 | ############################################
23 |
24 | DATA_PATH="./data/poem_generation"
25 |
26 | python evaluation/generate_response.py \
27 | --model_name_or_path $MODEL_PATH \
28 | --tokenizer_path $TOKENIZER_PATH \
29 | --dataset_path $DATA_PATH \
30 | --max_size 1000 \
31 | --seed $SEED \
32 | --temperature $T \
33 | --top_k $K \
34 | --top_p $P \
35 | --max_new_tokens 512 \
36 | --n $N \
37 | --use_vllm True \
38 | --do_sample True \
39 | --remove_old True \
40 | --save_path "${MODEL_PATH}/poem-seed_${SEED}-n_${N}-T_${T}_K_${K}_P_${P}.json"
41 |
42 |
43 | python evaluation/evaluation_diversity.py \
44 | --tokenizer_path $TOKENIZER_PATH \
45 | --detokenizer_path $TOKENIZER_PATH \
46 | --response_path "${MODEL_PATH}/poem-seed_${SEED}-n_${N}-T_${T}_K_${K}_P_${P}.json" \
47 | 2>&1 | tee ${MODEL_PATH}/diversity_eval-poem-seed_${SEED}-n_${N}-T_${T}_K_${K}_P_${P}.log
48 |
49 | ############################################
50 | # story writing
51 | ############################################
52 |
53 | DATA_PATH="./data/story_generation"
54 |
55 | python evaluation/generate_response.py \
56 | --model_name_or_path $MODEL_PATH \
57 | --tokenizer_path $TOKENIZER_PATH \
58 | --dataset_path $DATA_PATH \
59 | --max_size 500 \
60 | --seed $SEED \
61 | --temperature $T \
62 | --top_k $K \
63 | --top_p $P \
64 | --max_new_tokens 512 \
65 | --n $N \
66 | --use_vllm True \
67 | --do_sample True \
68 | --remove_old True \
69 | --save_path "${MODEL_PATH}/story-seed_${SEED}-n_${N}-T_${T}_K_${K}_P_${P}.json"
70 |
71 | python evaluation/evaluation_diversity.py \
72 | --tokenizer_path $TOKENIZER_PATH \
73 | --detokenizer_path $TOKENIZER_PATH \
74 | --response_path "${MODEL_PATH}/story-seed_${SEED}-n_${N}-T_${T}_K_${K}_P_${P}.json" \
75 | 2>&1 | tee ${MODEL_PATH}/diversity_eval-story-seed_${SEED}-n_${N}-T_${T}_K_${K}_P_${P}.log
--------------------------------------------------------------------------------
/scripts/eval/gsm8k_eval.sh:
--------------------------------------------------------------------------------
1 | #!/bin/sh
2 |
3 | set -e
4 | set -x
5 |
6 | export HF_DATASETS_OFFLINE=1
7 | export TRANSFORMERS_OFFLINE=1
8 | export CUDA_VISIBLE_DEVICES="0"
9 |
10 | DATA_PATH="gsm8k"
11 | MODEL_PATH="model-path"
12 | TOKENIZER_PATH="meta-llama/Meta-Llama-3-8B-Instruct"
13 |
14 | T=0.0
15 | K=-1
16 | P=1.0
17 |
18 | python evaluation/evaluation_gsm8k.py \
19 | --model_name_or_path $MODEL_PATH \
20 | --tokenizer_name_or_path $TOKENIZER_PATH \
21 | --dataset_name_or_path $DATA_PATH \
22 | --batch_size 20 \
23 | --max_new_tokens 512 \
24 | --use_vllm True \
25 | --remove_old True \
26 | --temperature $T \
27 | --top_p $P \
28 | --top_k $K \
29 | --save_path "${MODEL_PATH}/gsm8k_T_${T}_K_${K}_P_${P}.json" \
30 | 2>&1 | tee ${MODEL_PATH}/gsm8k_T_${T}_K_${K}_P_${P}.log
31 |
--------------------------------------------------------------------------------
/scripts/eval/gsm8k_voting_eval.sh:
--------------------------------------------------------------------------------
1 | #!/bin/sh
2 |
3 | set -e
4 | set -x
5 |
6 | export HF_DATASETS_OFFLINE=1
7 | export TRANSFORMERS_OFFLINE=1
8 | export CUDA_VISIBLE_DEVICES="0"
9 |
10 | DATA_PATH="gsm8k"
11 | MODEL_PATH="model-path"
12 | TOKENIZER_PATH="meta-llama/Meta-Llama-3-8B-Instruct"
13 |
14 | SEED=42
15 | N=32
16 | T=0.6
17 | K=50
18 | P=0.9
19 |
20 |
21 | python evaluation/evaluation_gsm8k_voting.py \
22 | --model_name_or_path $MODEL_PATH \
23 | --tokenizer_name_or_path $TOKENIZER_PATH \
24 | --dataset_name_or_path $DATA_PATH \
25 | --dtype bf16 \
26 | --batch_size 128 \
27 | --max_new_tokens 512 \
28 | --seed $SEED \
29 | --n $N \
30 | --temperature $T \
31 | --top_k $K \
32 | --top_p $P \
33 | --use_vllm True \
34 | --remove_old True \
35 | --save_path "${MODEL_PATH}/gsm8k_voting-seed_${SEED}-n_${N}-T_${T}-K_${K}-P_${P}.json" \
36 | 2>&1 | tee "${MODEL_PATH}/gsm8k_evaluation-seed_${SEED}-n_${N}-T_${T}-K_${K}-P_${P}.log"
37 |
--------------------------------------------------------------------------------
/scripts/eval/reward_eval.sh:
--------------------------------------------------------------------------------
1 | #!/bin/sh
2 |
3 | set -e
4 | set -x
5 |
6 | export HF_DATASETS_OFFLINE=1
7 | export TRANSFORMERS_OFFLINE=1
8 | export CUDA_VISIBLE_DEVICES="0"
9 |
10 | DATA_PATH="tatsu-lab/alpaca_eval"
11 | MODEL_PATH="model-path"
12 | TOKENIZER_PATH="meta-llama/Meta-Llama-3-8B-Instruct"
13 | REWARD_MODEL="/sfairXC/FsfairX-LLaMA3-RM-v0.1"
14 |
15 | SEED=42
16 | T=0.6
17 | K=50
18 | P=0.9
19 | N=16
20 |
21 | python evaluation/generate_response.py \
22 | --model_name_or_path $MODEL_PATH \
23 | --tokenizer_path $TOKENIZER_PATH \
24 | --dataset_path $DATA_PATH \
25 | --max_size 1000 \
26 | --seed $SEED \
27 | --temperature $T \
28 | --top_k $K \
29 | --top_p $P \
30 | --max_new_tokens 2048 \
31 | --n $N \
32 | --use_vllm True \
33 | --save_path "${MODEL_PATH}/alpaca_eval-seed_${SEED}-n_${N}-T_${T}-K_${K}-P_${P}.json"
34 |
35 | python evaluation/evaluation_reward.py \
36 | --model_name_or_path $REWARD_MODEL \
37 | --batch_size 8 \
38 | --detokenizer_path $TOKENIZER_PATH \
39 | --data_path "${MODEL_PATH}/alpaca_eval-seed_${SEED}-n_${N}-T_${T}-K_${K}-P_${P}.json" \
40 | --save_path "${MODEL_PATH}/alpaca_eval-seed_${SEED}-n_${N}-T_${T}-K_${K}-P_${P}-reward.json" \
41 | 2>&1 | tee ${MODEL_PATH}/reward_eval-seed_${SEED}-n_${N}-T_${T}-K_${K}-P_${P}.log
42 |
--------------------------------------------------------------------------------
/scripts/llama3.1/tokenize_data.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | set -e
4 | set -x
5 |
6 | export HF_DATASETS_OFFLINE=1
7 | export TRANSFORMERS_OFFLINE=1
8 | export FLASH_ATTENTION_DETERMINISTIC="1"
9 | export CUDA_VISIBLE_DEVICES="0"
10 |
11 | # tokenize train data
12 | python preprocess_data.py \
13 | --dataset_name_or_path "HuggingFaceH4/ultrafeedback_binarized" \
14 | --split "train_sft" \
15 | --tokenizer_name_or_path "meta-llama/Llama-3.1-8B-Instruct" \
16 | --max_seq_length 2048 \
17 | --output_file "./data/ultrafeedback_sft_train_llama3.1_tokenized.jsonl"
18 |
19 | # tokenize test data
20 | python preprocess_data.py \
21 | --dataset_name_or_path "HuggingFaceH4/ultrafeedback_binarized" \
22 | --split "test_sft" \
23 | --tokenizer_name_or_path "meta-llama/Llama-3.1-8B-Instruct" \
24 | --max_seq_length 2048 \
25 | --output_file "./data/ultrafeedback_sft_test_llama3.1_tokenized.jsonl"
--------------------------------------------------------------------------------
/scripts/llama3.1/train_ce_ultrafeedback.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | set -e
4 | set -x
5 |
6 | export HF_DATASETS_OFFLINE=1
7 | export TRANSFORMERS_OFFLINE=1
8 | export FLASH_ATTENTION_DETERMINISTIC="1"
9 |
10 | TRAIN_TOKENIZED_FILE="./data/ultrafeedback_sft_train_llama3.1_tokenized.jsonl"
11 | TEST_TOKENIZED_FILE="./data/ultrafeedback_sft_test_llama3.1_tokenized.jsonl"
12 |
13 | MODEL_NAME_OR_PATH="meta-llama/Llama-3.1-8B"
14 | SEED=1234
15 |
16 | TIME_STEP=`date "+%Y-%m-%d-%H-%M-%S"`
17 | OUTPUT_DIR="./log/sft_ce-llama3.1-8b-ultrafeedback-$TIME_STEP-$SEED"
18 |
19 | mkdir -p $OUTPUT_DIR
20 |
21 | deepspeed train.py \
22 | --deepspeed scripts/zero2.json \
23 | --seed $SEED \
24 | --model_name_or_path $MODEL_NAME_OR_PATH \
25 | --train_tokenized_file $TRAIN_TOKENIZED_FILE \
26 | --test_tokenized_file $TEST_TOKENIZED_FILE \
27 | --output_dir $OUTPUT_DIR \
28 | --per_device_train_batch_size 4 \
29 | --gradient_accumulation_steps 4 \
30 | --evaluation_strategy "epoch" \
31 | --save_strategy "no" \
32 | --loss "ce" \
33 | --learning_rate 2e-5 \
34 | --lr_scheduler_type cosine \
35 | --warmup_ratio 0.03 \
36 | --num_train_epochs 3 \
37 | --logging_steps 10 \
38 | --report_to "tensorboard" \
39 | --gradient_checkpointing True \
40 | --overwrite_output_dir \
41 | --bf16 True \
42 | 2>&1 | tee $OUTPUT_DIR/training.log
--------------------------------------------------------------------------------
/scripts/llama3.1/train_gem_ultrafeedback.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | set -e
4 | set -x
5 |
6 | export HF_DATASETS_OFFLINE=1
7 | export TRANSFORMERS_OFFLINE=1
8 | export FLASH_ATTENTION_DETERMINISTIC="1"
9 |
10 | TRAIN_TOKENIZED_FILE="./data/ultrafeedback_sft_train_llama3.1_tokenized.jsonl"
11 | TEST_TOKENIZED_FILE="./data/ultrafeedback_sft_test_llama3.1_tokenized.jsonl"
12 |
13 | MODEL_NAME_OR_PATH="meta-llama/Llama-3.1-8B"
14 | SEED=1234
15 |
16 | TIME_STEP=`date "+%Y-%m-%d-%H-%M-%S"`
17 | OUTPUT_DIR="./log/sft_ce-llama3.1-8b-ultrafeedback-$TIME_STEP-$SEED"
18 |
19 | mkdir -p $OUTPUT_DIR
20 |
21 | deepspeed train.py \
22 | --deepspeed scripts/zero2.json \
23 | --seed $SEED \
24 | --model_name_or_path $MODEL_NAME_OR_PATH \
25 | --train_tokenized_file $TRAIN_TOKENIZED_FILE \
26 | --test_tokenized_file $TEST_TOKENIZED_FILE \
27 | --output_dir $OUTPUT_DIR \
28 | --per_device_train_batch_size 4 \
29 | --gradient_accumulation_steps 4 \
30 | --evaluation_strategy "epoch" \
31 | --save_strategy "no" \
32 | --loss "gem" \
33 | --gem_beta 0.7 \
34 | --gem_h "logsigmoid" \
35 | --learning_rate 2e-5 \
36 | --lr_scheduler_type cosine \
37 | --warmup_ratio 0.03 \
38 | --num_train_epochs 3 \
39 | --logging_steps 10 \
40 | --report_to "tensorboard" \
41 | --gradient_checkpointing True \
42 | --overwrite_output_dir \
43 | --bf16 True \
44 | 2>&1 | tee $OUTPUT_DIR/training.log
--------------------------------------------------------------------------------
/scripts/qwen2.5/tokenize_data.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | set -e
4 | set -x
5 |
6 | export HF_DATASETS_OFFLINE=1
7 | export TRANSFORMERS_OFFLINE=1
8 | export FLASH_ATTENTION_DETERMINISTIC="1"
9 | export CUDA_VISIBLE_DEVICES="0"
10 |
11 | # tokenize train data
12 | python preprocess_data.py \
13 | --dataset_name_or_path "AIMO/NuminaMath-CoT" \
14 | --split "train" \
15 | --tokenizer_name_or_path "Qwen/Qwen2.5-Math-7B" \
16 | --max_seq_length 2048 \
17 | --output_file "./data/numina_sft_train_qwen2.5_tokenized.jsonl" \
18 | --start 0 \
19 | --end 20000
20 |
21 | # tokenize test data
22 | python preprocess_data.py \
23 | --dataset_name_or_path "AIMO/NuminaMath-CoT" \
24 | --split "train" \
25 | --tokenizer_name_or_path "Qwen/Qwen2.5-Math-7B" \
26 | --max_seq_length 2048 \
27 | --output_file "./data/numina_sft_test_qwen2.5_tokenized.jsonl" \
28 | --start 20000 \
29 | --end 21000
--------------------------------------------------------------------------------
/scripts/qwen2.5/train_ce_numina.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | set -e
4 | set -x
5 |
6 | export HF_DATASETS_OFFLINE=1
7 | export TRANSFORMERS_OFFLINE=1
8 | export FLASH_ATTENTION_DETERMINISTIC="1"
9 |
10 | TRAIN_TOKENIZED_FILE="./data/numina_sft_train_qwen2.5_tokenized.jsonl"
11 | TEST_TOKENIZED_FILE="./data/numina_sft_test_qwen2.5_tokenized.jsonl"
12 |
13 | MODEL_NAME_OR_PATH="Qwen/Qwen2.5-Math-7B"
14 | SEED=1234
15 |
16 | TIME_STEP=`date "+%Y-%m-%d-%H-%M-%S"`
17 | OUTPUT_DIR="./log/sft_ce-qwen2.5_7b-numina-$TIME_STEP-$SEED"
18 |
19 | mkdir -p $OUTPUT_DIR
20 |
21 | deepspeed train.py \
22 | --deepspeed scripts/zero3.json \
23 | --seed $SEED \
24 | --model_name_or_path $MODEL_NAME_OR_PATH \
25 | --train_tokenized_file $TRAIN_TOKENIZED_FILE \
26 | --test_tokenized_file $TEST_TOKENIZED_FILE \
27 | --output_dir $OUTPUT_DIR \
28 | --per_device_train_batch_size 8 \
29 | --gradient_accumulation_steps 2 \
30 | --evaluation_strategy "epoch" \
31 | --save_strategy "no" \
32 | --loss "ce" \
33 | --learning_rate 2e-5 \
34 | --lr_scheduler_type cosine \
35 | --warmup_ratio 0.03 \
36 | --num_train_epochs 3 \
37 | --logging_steps 10 \
38 | --report_to "tensorboard" \
39 | --gradient_checkpointing True \
40 | --overwrite_output_dir \
41 | --bf16 True \
42 | 2>&1 | tee $OUTPUT_DIR/training.log
--------------------------------------------------------------------------------
/scripts/qwen2.5/train_gem_numina.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | set -e
4 | set -x
5 |
6 | export HF_DATASETS_OFFLINE=1
7 | export TRANSFORMERS_OFFLINE=1
8 | export FLASH_ATTENTION_DETERMINISTIC="1"
9 |
10 | TRAIN_TOKENIZED_FILE="./data/numina_sft_train_qwen2.5_tokenized.jsonl"
11 | TEST_TOKENIZED_FILE="./data/numina_sft_test_qwen2.5_tokenized.jsonl"
12 |
13 | MODEL_NAME_OR_PATH="Qwen/Qwen2.5-Math-7B"
14 | SEED=1234
15 |
16 | TIME_STEP=`date "+%Y-%m-%d-%H-%M-%S"`
17 | OUTPUT_DIR="./log/sft_gem-qwen2.5_7b-numina-$TIME_STEP-$SEED"
18 |
19 | mkdir -p $OUTPUT_DIR
20 |
21 | deepspeed train.py \
22 | --deepspeed scripts/zero3.json \
23 | --seed $SEED \
24 | --model_name_or_path $MODEL_NAME_OR_PATH \
25 | --train_tokenized_file $TRAIN_TOKENIZED_FILE \
26 | --test_tokenized_file $TEST_TOKENIZED_FILE \
27 | --output_dir $OUTPUT_DIR \
28 | --per_device_train_batch_size 4 \
29 | --gradient_accumulation_steps 4 \
30 | --evaluation_strategy "epoch" \
31 | --save_strategy "no" \
32 | --loss "gem" \
33 | --gem_beta 0.7 \
34 | --gem_h "logsigmoid" \
35 | --learning_rate 2e-5 \
36 | --lr_scheduler_type cosine \
37 | --warmup_ratio 0.03 \
38 | --num_train_epochs 3 \
39 | --logging_steps 10 \
40 | --report_to "tensorboard" \
41 | --gradient_checkpointing True \
42 | --overwrite_output_dir \
43 | --bf16 True \
44 | 2>&1 | tee $OUTPUT_DIR/training.log
--------------------------------------------------------------------------------
/scripts/zero2.json:
--------------------------------------------------------------------------------
1 | {
2 | "fp16": {
3 | "enabled": "auto",
4 | "loss_scale": 0,
5 | "loss_scale_window": 1000,
6 | "initial_scale_power": 16,
7 | "hysteresis": 2,
8 | "min_loss_scale": 1
9 | },
10 | "bf16": {
11 | "enabled": "auto"
12 | },
13 | "train_micro_batch_size_per_gpu": "auto",
14 | "train_batch_size": "auto",
15 | "gradient_accumulation_steps": "auto",
16 | "zero_optimization": {
17 | "stage": 2,
18 | "overlap_comm": false,
19 | "contiguous_gradients": true,
20 | "sub_group_size": 1e9,
21 | "reduce_bucket_size": "auto"
22 | }
23 | }
--------------------------------------------------------------------------------
/scripts/zero3.json:
--------------------------------------------------------------------------------
1 | {
2 | "fp16": {
3 | "enabled": "auto",
4 | "loss_scale": 0,
5 | "loss_scale_window": 1000,
6 | "initial_scale_power": 16,
7 | "hysteresis": 2,
8 | "min_loss_scale": 1
9 | },
10 | "bf16": {
11 | "enabled": "auto"
12 | },
13 | "train_micro_batch_size_per_gpu": "auto",
14 | "train_batch_size": "auto",
15 | "gradient_accumulation_steps": "auto",
16 | "zero_optimization": {
17 | "stage": 3,
18 | "overlap_comm": false,
19 | "contiguous_gradients": true,
20 | "sub_group_size": 1e9,
21 | "reduce_bucket_size": "auto",
22 | "stage3_param_persistence_threshold": "auto",
23 | "stage3_max_live_parameters": 1e9,
24 | "stage3_max_reuse_distance": 1e9,
25 | "stage3_gather_16bit_weights_on_model_save": true
26 | }
27 | }
--------------------------------------------------------------------------------
/sft_trainer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 |
4 | from transformers import Trainer
5 | from transformers.trainer import (
6 | ###
7 | _is_peft_model,
8 | MODEL_FOR_CAUSAL_LM_MAPPING_NAMES,
9 | is_torch_xla_available,
10 | )
11 | from typing import List, Optional, Dict
12 | from utils.gem_triton_loss import GEMLoss
13 |
14 |
15 | class SFTTrainer(Trainer):
16 |
17 | @torch.no_grad
18 | def compute_training_logs(self, logits, labels):
19 | shift_logits = logits[..., :-1, :]
20 | shift_labels = labels[..., 1:]
21 |
22 | mask = shift_labels != -100
23 | shift_logits = shift_logits[mask]
24 | shift_labels = shift_labels[mask]
25 |
26 | training_logs = {}
27 | if self.args.print_entropy:
28 | entropy = chunked_entropy_from_logits(
29 | shift_logits,
30 | batch_size=max(1, shift_logits.size(0) // 4),
31 | ).mean()
32 | training_logs["entropy"] = round(entropy.item(), 2)
33 |
34 | return training_logs
35 |
36 | def gem_loss(self, logits, labels, beta=0.7, ignore_index=-100, h="logsigmoid"):
37 | shift_logits = logits[..., :-1, :].contiguous()
38 | shift_labels = labels[..., 1:].contiguous()
39 |
40 | mask = shift_labels != -100
41 | shift_logits = shift_logits[mask]
42 | shift_labels = shift_labels[mask]
43 |
44 | with torch.no_grad():
45 | logits_on_labels = torch.gather(
46 | shift_logits, dim=-1, index=shift_labels.unsqueeze(-1)
47 | ).squeeze(-1)
48 |
49 | logits_diff = shift_logits - logits_on_labels.unsqueeze(-1)
50 | if h == "linear":
51 | weights = torch.ones_like(logits_diff)
52 | elif h == "logsigmoid":
53 | weights = F.sigmoid(0.01 * logits_diff)
54 | else:
55 | raise ValueError(h)
56 |
57 | gene_log_probs = F.log_softmax(shift_logits, dim=-1)
58 | q_probs = torch.exp(F.log_softmax(shift_logits / beta, dim=-1)).detach()
59 |
60 | real_log_probs = torch.gather(
61 | gene_log_probs, dim=-1, index=shift_labels.unsqueeze(-1)
62 | )
63 |
64 | loss = -torch.sum(
65 | q_probs * weights * (real_log_probs - gene_log_probs), dim=-1
66 | ).mean()
67 |
68 | return loss
69 |
70 | def gem_loss_triton(self, logits, labels, beta=0.7, ignore_index=-100, h="linear"):
71 | if h != "linear":
72 | print(f"[warning] only linear is supported for gem_loss_triton for now. Got {h}.")
73 |
74 | gem_loss_func = GEMLoss(beta=beta, ignore_index=ignore_index, reduction="mean")
75 |
76 | shift_logits = logits[..., :-1, :].contiguous()
77 | shift_labels = labels[..., 1:].contiguous()
78 |
79 | mask = shift_labels != -100
80 | shift_logits = shift_logits[mask]
81 | shift_labels = shift_labels[mask]
82 |
83 | loss = gem_loss_func(shift_logits, shift_labels)
84 |
85 | return loss
86 |
87 | # Copied from the transformers Trainer, with modifications for the GEM loss and extra training logs.
88 | def compute_loss(self, model, inputs, return_outputs=False):
89 | """
90 | How the loss is computed by Trainer. By default, all models return the loss in the first element.
91 |
92 | Subclass and override for custom behavior.
93 | """
94 | if self.label_smoother is not None and "labels" in inputs:
95 | labels = inputs.pop("labels")
96 | else:
97 | labels = None
98 | outputs = model(**inputs)
99 | # Save past state if it exists
100 | # TODO: this needs to be fixed and made cleaner later.
101 | if self.args.past_index >= 0:
102 | self._past = outputs[self.args.past_index]
103 |
104 | if labels is not None:
105 | unwrapped_model = self.accelerator.unwrap_model(model)
106 | if _is_peft_model(unwrapped_model):
107 | model_name = unwrapped_model.base_model.model._get_name()
108 | else:
109 | model_name = unwrapped_model._get_name()
110 | if model_name in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.values():
111 | loss = self.label_smoother(outputs, labels, shift_labels=True)
112 | else:
113 | loss = self.label_smoother(outputs, labels)
114 | else:
115 | if isinstance(outputs, dict) and "loss" not in outputs:
116 | raise ValueError(
117 | "The model did not return a loss from the inputs, only the following keys: "
118 | f"{','.join(outputs.keys())}. For reference, the inputs it received are {','.join(inputs.keys())}."
119 | )
120 | # We don't use .loss here since the model may return tuples instead of ModelOutput.
121 | if self.args.loss == "ce" or self.control.should_evaluate:
122 | loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]
123 | else:
124 | loss = self.gem_loss(
125 | outputs.logits,
126 | inputs["labels"],
127 | beta=self.args.gem_beta,
128 | h=self.args.gem_h,
129 | )
130 |
131 | # ziniu add logs
132 | if not self.control.should_evaluate:
133 | self.training_logs = self.compute_training_logs(
134 | outputs.logits, inputs["labels"]
135 | )
136 | self.training_logs["ce_loss"] = (
137 | outputs["loss"] if isinstance(outputs, dict) else outputs[0]
138 | )
139 | self.training_logs["ce_loss"] = round(self.training_logs["ce_loss"].item(), 4)
140 |
141 | return (loss, outputs) if return_outputs else loss
142 |
143 | def _maybe_log_save_evaluate(
144 | self, tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval
145 | ):
146 | if (
147 | self.control.should_log
148 | and self.state.global_step > self._globalstep_last_logged
149 | ):
150 | if is_torch_xla_available():
151 | xm.mark_step()
152 |
153 | logs: Dict[str, float] = {}
154 |
155 | # all_gather + mean() to get average loss over all processes
156 | tr_loss_scalar = self._nested_gather(tr_loss).mean().item()
157 |
158 | # reset tr_loss to zero
159 | tr_loss -= tr_loss
160 |
161 | logs["loss"] = round(
162 | tr_loss_scalar
163 | / (self.state.global_step - self._globalstep_last_logged),
164 | 4,
165 | )
166 | if grad_norm is not None:
167 | logs["grad_norm"] = round(
168 | (
169 | grad_norm.detach().item()
170 | if isinstance(grad_norm, torch.Tensor)
171 | else grad_norm
172 | ),
173 | 4,
174 | )
175 | logs["learning_rate"] = self._get_learning_rate()
176 | ### update logs
177 | if getattr(self, "training_logs", {}):
178 | logs.update(getattr(self, "training_logs", {}))
179 |
180 | self._total_loss_scalar += tr_loss_scalar
181 | self._globalstep_last_logged = self.state.global_step
182 | self.store_flos()
183 |
184 | self.log(logs)
185 |
186 | metrics = None
187 | if self.control.should_evaluate:
188 | metrics = self._evaluate(trial, ignore_keys_for_eval)
189 |
190 | if self.control.should_save:
191 | self._save_checkpoint(model, trial, metrics=metrics)
192 | self.control = self.callback_handler.on_save(
193 | self.args, self.state, self.control
194 | )
195 |
196 | def chunked_entropy_from_logits(chunk_logits, batch_size=None):
197 | """
198 | Compute entropy from logits in a memory-efficient manner by introducing a batch_size parameter.
199 |
200 | Args:
201 | chunk_logits (torch.Tensor): Logits tensor of shape (total_samples, num_classes).
202 | batch_size (int): Number of samples to process per batch.
203 |
204 | Returns:
205 | torch.Tensor: Entropy tensor of shape (total_samples,).
206 | """
207 | total_samples, num_classes = chunk_logits.shape
208 | entropy_list = []
209 | if batch_size is None:
210 | batch_size = total_samples
211 |
212 | # Process logits in batches
213 | for start_idx in range(0, total_samples, batch_size):
214 | end_idx = min(start_idx + batch_size, total_samples)
215 | logits_batch = chunk_logits[start_idx:end_idx] # Get a batch of logits
216 |
217 | # Compute logsumexp for the current batch
218 | logsumexp_batch = torch.logsumexp(logits_batch, dim=-1, keepdim=False) # Shape: (batch_size,)
219 | # Compute probabilities in log-space without computing softmax
220 | normalized_logits = logits_batch - logsumexp_batch.unsqueeze(-1) # Shape: (batch_size, num_classes)
221 | exp_normalized_logits = torch.exp(normalized_logits) # Shape: (batch_size, num_classes)
222 | # Compute entropy for the batch
223 | entropy_batch = logsumexp_batch - (logits_batch * exp_normalized_logits).sum(dim=-1) # Shape: (batch_size,)
224 |
225 | entropy_list.append(entropy_batch) # Store entropy for the current batch
226 |
227 | # Concatenate results from all batches
228 | if len(entropy_list) > 0:
229 | return torch.cat(entropy_list, dim=0)
230 | else:
231 | return torch.tensor(0.0)
--------------------------------------------------------------------------------
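
For reference, a self-contained toy run of the gem_loss math above on tiny tensors (shapes and values are invented; this mirrors the default beta=0.7, h="logsigmoid" path, not the Triton kernel):

import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(1, 4, 8, requires_grad=True)   # (batch, seq_len, vocab_size)
labels = torch.tensor([[2, 5, -100, 7]])             # -100 marks positions excluded from the loss
beta = 0.7

shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
mask = shift_labels != -100
shift_logits, shift_labels = shift_logits[mask], shift_labels[mask]

with torch.no_grad():
    logits_on_labels = torch.gather(shift_logits, -1, shift_labels.unsqueeze(-1)).squeeze(-1)
    weights = F.sigmoid(0.01 * (shift_logits - logits_on_labels.unsqueeze(-1)))   # h = "logsigmoid"

gene_log_probs = F.log_softmax(shift_logits, dim=-1)
q_probs = torch.exp(F.log_softmax(shift_logits / beta, dim=-1)).detach()
real_log_probs = torch.gather(gene_log_probs, -1, shift_labels.unsqueeze(-1))

loss = -torch.sum(q_probs * weights * (real_log_probs - gene_log_probs), dim=-1).mean()
loss.backward()
print(round(loss.item(), 4), logits.grad.shape)   # scalar loss; gradients flow back to the logits
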
/sft_trainer_v2.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 |
4 | from transformers import Trainer
5 | from transformers.trainer import (
6 | ###
7 | _is_peft_model,
8 | MODEL_FOR_CAUSAL_LM_MAPPING_NAMES,
9 | is_torch_xla_available,
10 | SaveStrategy
11 | )
12 | from typing import List, Optional, Dict
13 | from utils.gem_triton_loss import GEMLoss
14 |
15 |
16 | class SFTTrainer(Trainer):
17 |
18 | @torch.no_grad
19 | def compute_training_logs(self, logits, labels):
20 | shift_logits = logits[..., :-1, :]
21 | shift_labels = labels[..., 1:]
22 |
23 | mask = shift_labels != -100
24 | shift_logits = shift_logits[mask]
25 | shift_labels = shift_labels[mask]
26 |
27 | training_logs = {}
28 | if self.args.print_entropy:
29 | entropy = chunked_entropy_from_logits(
30 | shift_logits,
31 | batch_size=max(1, shift_logits.size(0) // 4),
32 | ).mean()
33 | training_logs["entropy"] = round(entropy.item(), 2)
34 |
35 | return training_logs
36 |
37 | def gem_loss(self, logits, labels, num_items_in_batch, beta=0.7, ignore_index=-100, h="logsigmoid"):
38 | shift_logits = logits[..., :-1, :].contiguous()
39 | shift_labels = labels[..., 1:].contiguous()
40 |
41 | mask = shift_labels != -100
42 | shift_logits = shift_logits[mask]
43 | shift_labels = shift_labels[mask]
44 |
45 | with torch.no_grad():
46 | logits_on_labels = torch.gather(
47 | shift_logits, dim=-1, index=shift_labels.unsqueeze(-1)
48 | ).squeeze(-1)
49 |
50 | logits_diff = shift_logits - logits_on_labels.unsqueeze(-1)
51 | if h == "linear":
52 | weights = torch.ones_like(logits_diff)
53 | elif h == "logsigmoid":
54 | weights = F.sigmoid(0.01 * logits_diff)
55 | else:
56 | raise ValueError(h)
57 |
58 | gene_log_probs = F.log_softmax(shift_logits, dim=-1)
59 | q_probs = torch.exp(F.log_softmax(shift_logits / beta, dim=-1)).detach()
60 |
61 | real_log_probs = torch.gather(
62 | gene_log_probs, dim=-1, index=shift_labels.unsqueeze(-1)
63 | )
64 |
65 | if num_items_in_batch is not None:
66 | loss = -torch.sum(
67 | q_probs * weights * (real_log_probs - gene_log_probs), dim=-1
68 | ).sum() / num_items_in_batch
69 | else:
70 | loss = -torch.sum(
71 | q_probs * weights * (real_log_probs - gene_log_probs), dim=-1
72 | ).mean()
73 |
74 | return loss
75 |
76 | def gem_loss_triton(self, logits, labels, num_items_in_batch, beta=0.7, ignore_index=-100, h="linear"):
77 | assert h == "linear", "Only linear is supported for gem_loss_triton for now."
78 |
79 | if num_items_in_batch is not None:
80 | gem_loss_func = GEMLoss(beta=beta, ignore_index=ignore_index, reduction="none")
81 | else:
82 | gem_loss_func = GEMLoss(beta=beta, ignore_index=ignore_index, reduction="mean")
83 |
84 | shift_logits = logits[..., :-1, :].contiguous()
85 | shift_labels = labels[..., 1:].contiguous()
86 |
87 | mask = shift_labels != -100
88 | shift_logits = shift_logits[mask]
89 | shift_labels = shift_labels[mask]
90 |
91 | loss = gem_loss_func(shift_logits, shift_labels)
92 |
93 | if num_items_in_batch is not None:
94 | loss = loss.sum() / num_items_in_batch
95 | else:
96 | loss = loss
97 | return loss
98 |
99 | def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
100 | """
101 | How the loss is computed by Trainer. By default, all models return the loss in the first element.
102 |
103 | Subclass and override for custom behavior.
104 | """
105 | if (self.label_smoother is not None or self.compute_loss_func is not None) and "labels" in inputs:
106 | labels = inputs.pop("labels")
107 | else:
108 | labels = None
109 | if self.model_accepts_loss_kwargs:
110 | loss_kwargs = {}
111 | if num_items_in_batch is not None:
112 | loss_kwargs["num_items_in_batch"] = num_items_in_batch
113 | inputs = {**inputs, **loss_kwargs}
114 | outputs = model(**inputs)
115 | # Save past state if it exists
116 | # TODO: this needs to be fixed and made cleaner later.
117 | if self.args.past_index >= 0:
118 | self._past = outputs[self.args.past_index]
119 |
120 | if labels is not None:
121 | unwrapped_model = self.accelerator.unwrap_model(model)
122 | if _is_peft_model(unwrapped_model):
123 | model_name = unwrapped_model.base_model.model._get_name()
124 | else:
125 | model_name = unwrapped_model._get_name()
126 | # User-defined compute_loss function
127 | if self.compute_loss_func is not None:
128 | loss = self.compute_loss_func(outputs, labels, num_items_in_batch=num_items_in_batch)
129 | elif model_name in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.values():
130 | loss = self.label_smoother(outputs, labels, shift_labels=True)
131 | else:
132 | loss = self.label_smoother(outputs, labels)
133 | else:
134 | if isinstance(outputs, dict) and "loss" not in outputs:
135 | raise ValueError(
136 | "The model did not return a loss from the inputs, only the following keys: "
137 | f"{','.join(outputs.keys())}. For reference, the inputs it received are {','.join(inputs.keys())}."
138 | )
139 | # We don't use .loss here since the model may return tuples instead of ModelOutput.
140 | if self.args.loss == "ce" or self.control.should_evaluate:
141 | loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]
142 | elif self.args.loss == "gem":
143 | loss = self.gem_loss(
144 | outputs.logits,
145 | inputs["labels"],
146 | num_items_in_batch=num_items_in_batch,
147 | beta=self.args.gem_beta,
148 | h=self.args.gem_h
149 | )
150 | elif self.args.loss == "gem_triton":
151 | loss = self.gem_loss_triton(
152 | outputs.logits,
153 | inputs["labels"],
154 | num_items_in_batch=num_items_in_batch,
155 | beta=self.args.gem_beta,
156 | h=self.args.gem_h
157 | )
158 |
159 | if self.args.average_tokens_across_devices and self.model_accepts_loss_kwargs:
160 | loss *= self.accelerator.num_processes
161 |
162 | # ziniu add logs
163 | if not self.control.should_evaluate:
164 | self.training_logs = self.compute_training_logs(
165 | outputs.logits, inputs["labels"]
166 | )
167 | self.training_logs["ce_loss"] = (
168 | outputs["loss"] if isinstance(outputs, dict) else outputs[0]
169 | )
170 | self.training_logs["ce_loss"] = round(self.training_logs["ce_loss"].item(), 4)
171 |
172 | return (loss, outputs) if return_outputs else loss
173 |
174 | def _maybe_log_save_evaluate(self, tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval, start_time):
175 | if self.control.should_log and self.state.global_step > self._globalstep_last_logged:
176 | if is_torch_xla_available():
177 | xm.mark_step()
178 |
179 | logs: Dict[str, float] = {}
180 |
181 | # all_gather + mean() to get average loss over all processes
182 | tr_loss_scalar = self._nested_gather(tr_loss).mean().item()
183 |
184 | # reset tr_loss to zero
185 | tr_loss -= tr_loss
186 |
187 | logs["loss"] = round(tr_loss_scalar / (self.state.global_step - self._globalstep_last_logged), 4)
188 | if grad_norm is not None:
189 | logs["grad_norm"] = grad_norm.detach().item() if isinstance(grad_norm, torch.Tensor) else grad_norm
190 | logs["learning_rate"] = self._get_learning_rate()
191 | if getattr(self, "training_logs", None):
192 | logs.update(self.training_logs)
193 |
194 | self._total_loss_scalar += tr_loss_scalar
195 | self._globalstep_last_logged = self.state.global_step
196 | self.store_flos()
197 |
198 | self.log(logs, start_time)
199 |
200 | metrics = None
201 | if self.control.should_evaluate:
202 | metrics = self._evaluate(trial, ignore_keys_for_eval)
203 | is_new_best_metric = self._determine_best_metric(metrics=metrics, trial=trial)
204 |
205 | if self.args.save_strategy == SaveStrategy.BEST:
206 | self.control.should_save = is_new_best_metric
207 |
208 | if self.control.should_save:
209 | self._save_checkpoint(model, trial)
210 | self.control = self.callback_handler.on_save(self.args, self.state, self.control)
211 |
212 |
213 | def chunked_entropy_from_logits(chunk_logits, batch_size=None):
214 | """
215 | Compute entropy from logits in a memory-efficient manner by introducing a batch_size parameter.
216 |
217 | Args:
218 | chunk_logits (torch.Tensor): Logits tensor of shape (total_samples, num_classes).
219 | batch_size (int): Number of samples to process per batch.
220 |
221 | Returns:
222 | torch.Tensor: Entropy tensor of shape (total_samples,).
223 | """
224 | total_samples, num_classes = chunk_logits.shape
225 | entropy_list = []
226 | if batch_size is None:
227 | batch_size = total_samples
228 |
229 | # Process logits in batches
230 | for start_idx in range(0, total_samples, batch_size):
231 | end_idx = min(start_idx + batch_size, total_samples)
232 | logits_batch = chunk_logits[start_idx:end_idx] # Get a batch of logits
233 |
234 | # Compute logsumexp for the current batch
235 | logsumexp_batch = torch.logsumexp(logits_batch, dim=-1, keepdim=False) # Shape: (batch_size,)
236 | # Compute probabilities in log-space without computing softmax
237 | normalized_logits = logits_batch - logsumexp_batch.unsqueeze(-1) # Shape: (batch_size, num_classes)
238 | exp_normalized_logits = torch.exp(normalized_logits) # Shape: (batch_size, num_classes)
239 | # Compute entropy for the batch
240 | entropy_batch = logsumexp_batch - (logits_batch * exp_normalized_logits).sum(dim=-1) # Shape: (batch_size,)
241 |
242 | entropy_list.append(entropy_batch) # Store entropy for the current batch
243 |
244 | # Concatenate results from all batches
245 | if len(entropy_list) > 0:
246 | return torch.cat(entropy_list, dim=0)
247 | else:
248 | return torch.tensor(0.0)
--------------------------------------------------------------------------------
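
The main behavioral change relative to sft_trainer.py is the num_items_in_batch handling: when transformers supplies the global count of supervised tokens, gem_loss and gem_loss_triton divide a token-level sum by that count instead of taking a per-micro-batch mean, which keeps gradient accumulation equivalent to one large batch. A toy calculation with made-up numbers:

# Two micro-batches with 3 and 5 supervised tokens and unequal per-token losses.
per_token = [[2.0, 2.0, 2.0], [1.0, 1.0, 1.0, 1.0, 1.0]]
num_items_in_batch = 8

mean_of_means = sum(sum(b) / len(b) for b in per_token) / len(per_token)   # (2.0 + 1.0) / 2 = 1.5
global_mean = sum(sum(b) for b in per_token) / num_items_in_batch          # 11.0 / 8     = 1.375
# Only global_mean equals the loss of a single batch holding all 8 tokens, which is why
# the v2 trainer divides by num_items_in_batch when it is provided.
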
/tests/test_gem_loss_triton.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 |
4 | # Add project root to Python path
5 | import os
6 | import sys
7 | from pathlib import Path
8 | project_root = str(Path(__file__).parent.parent)
9 | if project_root not in sys.path:
10 | sys.path.insert(0, project_root)
11 |
12 | from utils.gem_triton_loss import GEMLoss
13 |
14 | class ReferenceGEMLoss(torch.nn.Module):
15 | def forward(self, logits, labels, beta=0.7, ignore_index=-100, h="linear", reduction="mean"):
16 | """Reference implementation of GEM loss"""
17 | mask = labels != ignore_index
18 | masked_logits = logits[mask]
19 | masked_labels = labels[mask]
20 |
21 | with torch.no_grad():
22 | logits_on_labels = torch.gather(
23 | masked_logits, dim=-1, index=masked_labels.unsqueeze(-1)
24 | ).squeeze(-1)
25 | logits_diff = masked_logits - logits_on_labels.unsqueeze(-1)
26 | if h == "linear":
27 | weights = torch.ones_like(logits_diff)
28 | else:
29 | raise ValueError(f"Unsupported h function: {h}")
30 |
31 | gene_log_probs = F.log_softmax(masked_logits, dim=-1)
32 | with torch.no_grad():
33 | q_probs = torch.exp(F.log_softmax(masked_logits / beta, dim=-1)).detach()
34 |
35 | real_log_probs = torch.gather(
36 | gene_log_probs, dim=-1, index=masked_labels.unsqueeze(-1)
37 | )
38 |
39 | if reduction == "mean":
40 | loss = -torch.sum(
41 | q_probs * weights * (real_log_probs - gene_log_probs), dim=-1
42 | ).mean()
43 | elif reduction == "sum":
44 | loss = -torch.sum(
45 | q_probs * weights * (real_log_probs - gene_log_probs)
46 | )
47 | else:
48 | raise ValueError(f"Unsupported reduction: {reduction}")
49 |
50 | return loss
51 |
52 | def test_gem_loss():
53 | # Set random seed for reproducibility
54 | torch.manual_seed(42)
55 |
56 | # Test parameters
57 | batch_size = 10
58 | vocab_size = 120000
59 | beta = 0.7
60 | ignore_index = -100
61 |
62 | # Create random inputs
63 | logits = torch.randn(batch_size, vocab_size, device='cuda', requires_grad=True)
64 | labels = torch.randint(0, vocab_size, (batch_size,), device='cuda')
65 | # Add some ignored indices
66 | # labels[0] = ignore_index
67 |
68 | # Create loss functions
69 | triton_loss_fn = GEMLoss(beta=beta, ignore_index=ignore_index, reduction="sum")
70 | ref_loss_fn = ReferenceGEMLoss()
71 |
72 | # Forward pass
73 | triton_loss = triton_loss_fn(logits, labels)
74 | ref_loss = ref_loss_fn(logits, labels, beta=beta, ignore_index=ignore_index, h="linear", reduction="sum")
75 |
76 | test_failed = False
77 | # Check forward pass results
78 | print("*" * 100)
79 | print("triton_loss:", triton_loss)
80 | print("ref_loss:", ref_loss)
81 | print("Forward pass difference:", torch.abs((triton_loss - ref_loss)).mean().item())
82 | try:
83 | torch.testing.assert_close(triton_loss, ref_loss, rtol=1e-4, atol=1e-4)
84 | except Exception as e:
85 | print(e)
86 | test_failed = True
87 |
88 | # Backward pass
89 | triton_logits = logits.clone().detach().requires_grad_(True)
90 | ref_logits = logits.clone().detach().requires_grad_(True)
91 |
92 | triton_loss = triton_loss_fn(triton_logits, labels)
93 | ref_loss = ref_loss_fn(ref_logits, labels, beta=beta, ignore_index=ignore_index, h="linear", reduction="sum")
94 |
95 | triton_loss.mean().backward()
96 | ref_loss.mean().backward()
97 |
98 | # Check backward pass results
99 | print("*" * 100)
100 | print("Max gradient difference:", torch.max(torch.abs(triton_logits.grad - ref_logits.grad)).item())
101 | try:
102 | torch.testing.assert_close(triton_logits.grad, ref_logits.grad, rtol=1e-4, atol=1e-4)
103 | except Exception as e:
104 | print(e)
105 | test_failed = True
106 |
107 | if test_failed:
108 | print("Test failed!")
109 | else:
110 | print("All tests passed!")
111 |
112 | if __name__ == "__main__":
113 | test_gem_loss()
114 |
--------------------------------------------------------------------------------
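
Outside the test, the Triton loss can be used directly; a minimal sketch, assuming a CUDA device (the kernel is Triton-based) and the repository root on sys.path; the vocabulary size is arbitrary:

import torch
from utils.gem_triton_loss import GEMLoss

loss_fn = GEMLoss(beta=0.7, ignore_index=-100, reduction="mean")
logits = torch.randn(16, 32000, device="cuda", requires_grad=True)   # (num_tokens, vocab_size)
labels = torch.randint(0, 32000, (16,), device="cuda")

loss = loss_fn(logits, labels)   # scalar with reduction="mean"
loss.backward()
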
/tests/test_gem_loss_triton_distributed.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import torch
4 | import torch.nn.functional as F
5 | import torch.distributed as dist
6 | import torch.multiprocessing as mp
7 | from torch.nn.parallel import DistributedDataParallel as DDP
8 |
9 | # Add project root to Python path
10 | import os
11 | import sys
12 | from pathlib import Path
13 | project_root = str(Path(__file__).parent.parent)
14 | if project_root not in sys.path:
15 | sys.path.insert(0, project_root)
16 |
17 | from utils.gem_triton_loss import GEMLoss as TritonGEMLoss
18 |
19 | class ReferenceGEMLoss(torch.nn.Module):
20 | def forward(self, logits, labels, beta=0.7, ignore_index=-100, h="linear"):
21 | """Reference implementation of GEM loss"""
22 | mask = labels != ignore_index
23 | masked_logits = logits[mask]
24 | masked_labels = labels[mask]
25 |
26 | with torch.no_grad():
27 | logits_on_labels = torch.gather(
28 | masked_logits, dim=-1, index=masked_labels.unsqueeze(-1)
29 | ).squeeze(-1)
30 | logits_diff = masked_logits - logits_on_labels.unsqueeze(-1)
31 | if h == "linear":
32 | weights = torch.ones_like(logits_diff)
33 | else:
34 | raise ValueError(f"Unsupported h function: {h}")
35 |
36 | gene_log_probs = F.log_softmax(masked_logits, dim=-1)
37 | with torch.no_grad():
38 | q_probs = torch.exp(F.log_softmax(masked_logits / beta, dim=-1)).detach()
39 |
40 | real_log_probs = torch.gather(
41 | gene_log_probs, dim=-1, index=masked_labels.unsqueeze(-1)
42 | )
43 |
44 | loss = -torch.sum(
45 | q_probs * weights * (real_log_probs - gene_log_probs), dim=-1
46 | ).mean()
47 |
48 | return loss
49 |
50 | def setup(local_rank=None):
51 | """Initialize the distributed environment."""
52 | # When using torch.distributed.launch, use the environment variables it sets
53 | if local_rank is None:
54 | if 'LOCAL_RANK' in os.environ:
55 | local_rank = int(os.environ['LOCAL_RANK'])
56 | world_size = int(os.environ['WORLD_SIZE'])
57 | rank = int(os.environ['RANK'])
58 | else:
59 | raise ValueError("LOCAL_RANK not found in environment variables")
60 | else:
61 | # For mp.spawn path
62 | rank = local_rank
63 | world_size = int(os.environ.get('WORLD_SIZE', '1'))
64 | os.environ['MASTER_ADDR'] = 'localhost'
65 | os.environ['MASTER_PORT'] = '12355'
66 |
67 | # Initialize the process group
68 | dist.init_process_group("nccl", rank=rank, world_size=world_size)
69 |
70 | # Set device for this process
71 | torch.cuda.set_device(local_rank)
72 |
73 | print(f"Rank {rank}/{world_size} initialized")
74 | return rank, world_size
75 |
76 | def cleanup():
77 | """Clean up the distributed environment."""
78 | dist.destroy_process_group()
79 |
80 | def run_test(rank=None, world_size=None):
81 | """Run the distributed GEM loss test on a single process."""
82 | # Initialize distributed environment
83 | if rank is None: # When called directly (not through mp.spawn)
84 | rank, world_size = setup()
85 | else: # When called through mp.spawn
86 | setup(rank)
87 |
88 | # Set random seed for reproducibility
89 | torch.manual_seed(42 + rank) # Different seed per rank
90 |
91 | # Test parameters
92 | batch_size = 100
93 | vocab_size = 102400
94 | beta = 0.7
95 | ignore_index = -100
96 |
97 | # Each rank handles a portion of the vocabulary
98 | local_vocab_size = vocab_size // world_size
99 |
100 | # Create random inputs for this rank
101 | logits = torch.randn(batch_size, local_vocab_size, device=rank, requires_grad=True)
102 |
103 | # All ranks have the same labels (in the range of the full vocabulary)
104 | # We use the same seed for labels to ensure consistency
105 | torch.manual_seed(42) # Same seed for labels across all ranks
106 | labels = torch.randint(0, vocab_size, (batch_size,), device=rank)
107 |
108 | # Create loss function with process group
109 | triton_loss_fn = TritonGEMLoss(beta=beta, ignore_index=ignore_index, process_group=dist.group.WORLD)
110 |
111 | # Forward pass
112 | triton_loss = triton_loss_fn(logits, labels)
113 |
114 | # Basic sanity checks
115 | assert not torch.isnan(triton_loss).any(), f"Rank {rank}: Loss contains NaN values"
116 | assert not torch.isinf(triton_loss).any(), f"Rank {rank}: Loss contains Inf values"
117 |
118 | # Backward pass
119 | triton_loss.mean().backward()
120 |
121 | # Check gradients
122 | assert not torch.isnan(logits.grad).any(), f"Rank {rank}: Gradients contain NaN values"
123 | assert not torch.isinf(logits.grad).any(), f"Rank {rank}: Gradients contain Inf values"
124 |
125 | # Gather all logits to rank 0 for verification (optional)
126 | if rank == 0:
127 | all_logits = [torch.zeros_like(logits) for _ in range(world_size)]
128 | else:
129 | all_logits = None
130 |
131 | dist.gather(logits, all_logits if rank == 0 else None, dst=0)
132 |
133 | # On rank 0, verify the loss against reference implementation
134 | if rank == 0:
135 | print(f"Distributed test on {world_size} GPUs:")
136 | print(f"Triton GEM loss: {triton_loss.mean().item()}")
137 |
138 | # Concatenate all logits to get the full vocabulary
139 | full_logits = torch.cat(all_logits, dim=1).detach().requires_grad_(True)
140 |
141 | # Compute reference loss
142 | ref_loss_fn = ReferenceGEMLoss()
143 | ref_loss = ref_loss_fn(full_logits, labels, beta=beta, ignore_index=ignore_index)
144 |
145 | print(f"Reference loss: {ref_loss.item()}")
146 | print(f"Difference: {abs(triton_loss.mean().item() - ref_loss.item())}")
147 |
148 | # Note: We don't expect exact match due to distributed computation differences
149 | print("Distributed test completed successfully!")
150 |
151 | # Wait for all processes
152 | dist.barrier()
153 |
154 | # Clean up
155 | cleanup()
156 |
157 | def main():
158 | """Main function to launch the distributed test."""
159 | # Check if CUDA is available
160 | if not torch.cuda.is_available():
161 | print("CUDA not available. Skipping distributed test.")
162 | return
163 |
164 |     # Check whether we were launched by torchrun / torch.distributed.launch (they set LOCAL_RANK)
165 |     if 'LOCAL_RANK' in os.environ:
166 |         # Launched by a distributed launcher, so just run the test directly
167 | run_test()
168 | else:
169 | # Manual launch with mp.spawn
170 | world_size = int(os.environ.get("WORLD_SIZE", 2))
171 |
172 | # Ensure we have enough GPUs
173 | if torch.cuda.device_count() < world_size:
174 | print(f"Not enough GPUs available. Need {world_size}, found {torch.cuda.device_count()}")
175 | world_size = torch.cuda.device_count()
176 | print(f"Reducing world_size to {world_size}")
177 |
178 | if world_size < 2:
179 | print("Need at least 2 GPUs for a meaningful distributed test")
180 | return
181 |
182 | print(f"Running distributed test with world_size={world_size}")
183 |
184 | # Spawn processes
185 | mp.spawn(run_test, args=(world_size,), nprocs=world_size, join=True)
186 |
187 | if __name__ == "__main__":
188 | main()
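# Note (editorial sketch): this test supports two launch modes. Running the file directly
# spawns WORLD_SIZE processes via mp.spawn (default 2). Alternatively, a launcher such as
# torchrun sets LOCAL_RANK/RANK/WORLD_SIZE, in which case run_test() is executed directly, e.g.:
#     torchrun --nproc_per_node=2 tests/test_gem_loss_triton_distributed.py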
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # coding=utf-8
3 | """
4 | This file is modified from the huggingface example for finetuning language models
5 | [run_clm.py](https://github.com/huggingface/transformers/blob/main/examples/pytorch/language-modeling/run_clm.py)
6 | """
7 |
8 | import logging
9 | import os
10 |
11 | os.environ["TOKENIZERS_PARALLELISM"] = "true"
12 | import sys
13 | from typing import Optional
14 | from functools import partial
15 | import datasets
16 | import torch
17 | import torch.distributed as dist
18 | import deepspeed
19 | from datasets import load_dataset
20 | from torch.utils.data import Dataset
21 | from dataclasses import dataclass, field
22 | from typing import Optional, List, Union
23 |
24 | import transformers
25 | from transformers import (
26 | AutoModelForCausalLM,
27 | AutoTokenizer,
28 | HfArgumentParser,
29 | DataCollatorForSeq2Seq,
30 | set_seed,
31 | )
32 | from transformers.trainer_utils import get_last_checkpoint
33 |
34 | from packaging import version
35 |
36 | if version.parse(transformers.__version__) >= version.parse("4.46.0"):
37 | from sft_trainer_v2 import SFTTrainer
38 | else:
39 | from sft_trainer import SFTTrainer
40 |
41 | logging.basicConfig(level=logging.INFO)
42 | logger = logging.getLogger(__name__)
43 |
44 |
45 | @dataclass
46 | class TrainingArguments(transformers.TrainingArguments):
47 | adam_beta2: float = field(default=0.95, metadata={"help": "Beta2 for AdamW"})
48 | loss: str = field(
49 | default="gem", metadata={"help": "Loss name", "choices": ["gem", "ce", "gem_triton"]}
50 | )
51 | gem_beta: float = field(default=0.7, metadata={"help": "Hyper-parameter in GEM. A value between 0 and 1. A value close to 1.0 makes GEM behave more like CE, while a value close to 0.0 preserves more diversity."})
52 | gem_h: str = field(
53 | default="linear", metadata={"help": "Function $h$ in GEM. The 'logsigmoid' function is more adaptive, but the difference between 'logsigmoid' and 'linear' is usually negligible.", "choices": ["logsigmoid", "linear"]}
54 | )
55 | print_entropy: bool = field(
56 | default=False, metadata={"help": "Print entropy during training"}
57 | )
58 |
59 |
60 | @dataclass
61 | class ModelArguments:
62 | model_name_or_path: str = field(
63 | metadata={
64 | "help": "Path to pretrained model or model identifier from huggingface.co/models"
65 | }
66 | )
67 | cache_dir: Optional[str] = field(
68 | default=None,
69 | metadata={
70 | "help": "Where do you want to store the pretrained models downloaded from huggingface.co"
71 | },
72 | )
73 |     use_flash_attn: bool = field(
74 |         default=True,
75 |         metadata={"help": "Whether to use FlashAttention-2 as the attention implementation."},
76 |     )
77 |
78 |
79 | @dataclass
80 | class DataArguments:
81 |     train_tokenized_file: str = field(
82 |         default=None, metadata={"help": "Path to the tokenized training data (a JSON/JSONL file containing input_ids and labels)"}
83 |     )
84 |     test_tokenized_file: str = field(
85 |         default=None, metadata={"help": "Path to the tokenized evaluation data (a JSON/JSONL file containing input_ids and labels)"}
86 |     )
87 | max_train_samples: Optional[int] = field(
88 | default=None,
89 | metadata={
90 | "help": (
91 | "For debugging purposes or quicker training, truncate the number of training examples to this "
92 | "value if set."
93 | )
94 | },
95 | )
96 | max_seq_length: Optional[int] = field(
97 | default=None,
98 | metadata={
99 | "help": (
100 |                 "The maximum total input sequence length after tokenization. Sequences longer than this will be truncated."
101 | )
102 | },
103 | )
104 | overwrite_cache: bool = field(
105 | default=False,
106 | metadata={"help": "Overwrite the cached training and evaluation sets"},
107 | )
108 |
109 |
110 | class CustomDataset(Dataset):
111 | def __init__(
112 | self,
113 | training_args,
114 | data_args,
115 | model_args,
116 | train_tokenized_file,
117 | ):
118 | self.training_args = training_args
119 | self.data_args = data_args
120 | self.model_args = model_args
121 |
122 | raw_datasets = load_dataset(
123 | "json",
124 | data_files=[train_tokenized_file],
125 | cache_dir=self.model_args.cache_dir,
126 | )
127 | self.data = raw_datasets["train"]
128 |
129 | if self.data_args.max_train_samples is not None:
130 | max_samples = min(len(self.data), self.data_args.max_train_samples)
131 | self.data = self.data.select(range(max_samples))
132 |
133 | def __len__(self):
134 | return len(self.data)
135 |
136 | def __getitem__(self, item):
137 | example = self.data[item]
138 | assert "input_ids" in example
139 | assert "labels" in example
140 | example = {k: torch.tensor(v, dtype=torch.long) for k, v in example.items()}
141 | return example
142 |
143 |
144 | def main():
145 | parser = HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
146 | if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
147 | model_args, data_args, training_args = parser.parse_json_file(
148 | json_file=os.path.abspath(sys.argv[1])
149 | )
150 | else:
151 | model_args, data_args, training_args = parser.parse_args_into_dataclasses()
152 |
153 | # Setup logging
154 | logging.basicConfig(
155 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
156 | datefmt="%m/%d/%Y %H:%M:%S",
157 | handlers=[logging.StreamHandler(sys.stdout)],
158 | )
159 |
160 | if training_args.should_log:
161 | # The default of training_args.log_level is passive, so we set log level at info here to have that default.
162 | transformers.utils.logging.set_verbosity_info()
163 |
164 | log_level = training_args.get_process_log_level()
165 | logger.setLevel(log_level)
166 | datasets.utils.logging.set_verbosity(log_level)
167 | transformers.utils.logging.set_verbosity(log_level)
168 | transformers.utils.logging.enable_default_handler()
169 | transformers.utils.logging.enable_explicit_format()
170 |
171 | # Log on each process the small summary:
172 | global_rank = dist.get_rank()
173 | logger.warning(
174 | f"Process rank: {global_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
175 | )
176 | logger.info(f"Training parameters {training_args}")
177 |
178 | # Set seed before initializing model.
179 | set_seed(training_args.seed)
180 |
181 | tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path)
182 | if "llama-3" in tokenizer.name_or_path.lower() and tokenizer.pad_token is None:
183 | tokenizer.pad_token_id = len(tokenizer) - 1
184 | tokenizer.pad_token = tokenizer.decode(tokenizer.pad_token_id)
185 |
186 | model = AutoModelForCausalLM.from_pretrained(
187 | model_args.model_name_or_path,
188 | torch_dtype="auto",
189 | attn_implementation=(
190 | "flash_attention_2" if model_args.use_flash_attn else "eager"
191 | ),
192 | )
193 |
194 | # We resize the embeddings only when necessary to avoid index errors. If you are creating a model from scratch
195 | # on a small vocab and want a smaller embedding size, remove this test.
196 | # gather deepspeed to get "real" embedding size
197 | embeddings = model.get_input_embeddings()
198 | with deepspeed.zero.GatheredParameters(embeddings.weight, modifier_rank=None):
199 | embedding_size = embeddings.weight.shape[0]
200 | # resize does its own gather
201 | if len(tokenizer) > embedding_size:
202 | # pad to multiple for tensor cores.
203 |         logger.warning("len(tokenizer) > embedding_size; resizing token embeddings to match the tokenizer.")
204 | model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=8)
205 |
206 | # set up datasets
207 | train_dataset = CustomDataset(training_args, data_args, model_args, data_args.train_tokenized_file)
208 | if data_args.test_tokenized_file:
209 | test_dataset = CustomDataset(training_args, data_args, model_args, data_args.test_tokenized_file)
210 | else:
211 | test_dataset = None
212 |
213 |     # Initialize the trainer.
214 |     # We use a custom trainer that moves the model to CPU when saving checkpoints in FSDP mode.
215 |     # We could switch to the default trainer after fully moving to DeepSpeed; we keep changes minimal for now.
216 |
217 | trainer = SFTTrainer(
218 | model=model,
219 | args=training_args,
220 | train_dataset=train_dataset,
221 | eval_dataset=test_dataset,
222 | tokenizer=tokenizer,
223 | data_collator=DataCollatorForSeq2Seq(
224 | tokenizer=tokenizer, model=model, padding="longest"
225 | ),
226 | preprocess_logits_for_metrics=None,
227 | compute_metrics=None,
228 | )
229 |
230 | # Training
231 | logger.info("*** Train ***")
232 | checkpoint = None
233 | if training_args.resume_from_checkpoint is not None:
234 | checkpoint = training_args.resume_from_checkpoint
235 | train_result = trainer.train(resume_from_checkpoint=checkpoint)
236 | if "llama-3" in model.config.name_or_path.lower() and isinstance(model.generation_config.eos_token_id, int):
237 | model.generation_config.eos_token_id = [128001, 128009]
238 | trainer.save_model() # Saves the tokenizer too for easy upload
239 |
240 | metrics = train_result.metrics
241 | metrics["train_samples"] = len(train_dataset)
242 | trainer.log_metrics("train", metrics)
243 | trainer.save_metrics("train", metrics)
244 |
245 |
246 | if __name__ == "__main__":
247 | main()
248 |
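# Example launch (a sketch only; the flags mirror the ModelArguments/DataArguments/TrainingArguments
# defined above, but the model identifier and data path are placeholders. See
# scripts/llama3.1/train_gem_ultrafeedback.sh and scripts/qwen2.5/train_gem_numina.sh for the
# actual configurations used in this repo):
#
#   deepspeed train.py \
#       --deepspeed scripts/zero3.json \
#       --model_name_or_path <path-or-hub-id> \
#       --train_tokenized_file <tokenized-train-data.jsonl> \
#       --loss gem --gem_beta 0.7 --gem_h linear \
#       --bf16 True --output_dir result/gem_sft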
--------------------------------------------------------------------------------
/utils/README.md:
--------------------------------------------------------------------------------
1 | # Triton Implementation of GEM Loss
2 |
3 | This folder contains the Triton implementation of GEM loss.
4 |
5 | ## Test
6 |
7 | The implementation has been tested against a pure-PyTorch reference implementation of the GEM loss (see `tests/`).
8 |
9 | To run the tests, use the following commands:
10 |
11 | ```bash
12 | python tests/test_gem_loss_triton.py
13 | python tests/test_gem_loss_triton_distributed.py
14 | ```
15 |
16 | Please contact Ziniu Li (ziniuli@link.cuhk.edu.cn) if you find any issues.
17 |
18 |
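## Usage

A minimal usage sketch of the Triton GEM loss (this assumes a CUDA device and that the repository root is on `PYTHONPATH` so that `utils` is importable as a package; shapes are illustrative):

```python
import torch
from utils.gem_triton_loss import GEMLoss

# Illustrative shapes only.
batch_size, vocab_size = 8, 32000
logits = torch.randn(batch_size, vocab_size, device="cuda", requires_grad=True)
labels = torch.randint(0, vocab_size, (batch_size,), device="cuda")

loss_fn = GEMLoss(beta=0.7, ignore_index=-100, reduction="mean")
loss = loss_fn(logits, labels)
loss.backward()
```

For tensor-parallel training, pass `process_group` so that each rank holds its own shard of the vocabulary; see `tests/test_gem_loss_triton_distributed.py` for an example.
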
19 | ## To Do
20 |
21 | - [ ] Add the implementation of GEM with $h = \text{logsigmoid}$ (currently only $h = \text{linear}$ is supported).
22 | - [ ] Add more tests.
23 |
24 |
25 | ## Acknowledgement
26 |
27 | We thank the authors of [flash-attention](https://github.com/Dao-AILab/flash-attention) for providing the Triton implementation of the CE loss, on which this GEM loss implementation is based.
28 |
--------------------------------------------------------------------------------
/utils/gem_triton_loss.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2025, Ziniu Li.
2 | # The code is modified from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/losses/cross_entropy.py
3 |
4 |
5 | import torch
6 | import torch.nn as nn
7 |
8 | from .gem_triton_ops import gem_loss
9 |
10 |
11 | class GEMLoss(nn.Module):
12 | def __init__(
13 | self,
14 | ignore_index=-100,
15 | reduction="mean",
16 | beta=1.0,
17 | logit_scale=1.0,
18 | lse_square_scale=0.0,
19 | inplace_backward=False,
20 | process_group=None,
21 | return_z_loss=False,
22 | ):
23 | """
24 | Arguments:
25 | ignore_index: int. If labels == ignore_index, the loss is set to 0.0.
26 |             beta: float. GEM hyper-parameter in (0, 1]; values close to 1.0 make GEM behave more like CE, while smaller values preserve more diversity.
27 | lse_square_scale: float. If > 0, we add lse_square_scale * lse(logits) ^ 2 to the loss.
28 | This is also referred to as "z-loss".
29 | inplace_backward: bool. If True, we do the backward pass in-place by modifying the logits.
30 | This saves memory.
31 | process_group: if not None, we're doing Tensor Parallel: each process is responsible for
32 | one part of the vocab. The loss will be aggregated across processes.
33 | return_z_loss: bool. If True, we return the component of the loss contributed by
34 | the lse_square_scale value. This value is only for logging and does not support
35 | backprop.
36 | """
37 | super().__init__()
38 | if reduction not in ["mean", "none", "sum"]:
39 |             raise NotImplementedError("Only reduction = 'mean', 'sum', or 'none' is supported")
40 | self.ignore_index = ignore_index
41 | self.reduction = reduction
42 | self.beta = beta
43 | self.logit_scale = logit_scale
44 | self.lse_square_scale = lse_square_scale
45 | self.inplace_backward = inplace_backward
46 | self.process_group = process_group
47 | self.return_z_loss = return_z_loss
48 |
49 | def forward(self, input, target, precomputed_lse=None):
50 | """
51 | Arguments:
52 | input: (batch, vocab_size)
53 | target: (batch,)
54 | Returns:
55 | losses: (batch,) if reduction is 'none', else (1,), dtype float
56 | z_loss: (batch,) if reduction is 'none', else (1,), dtype float (if self.return_z_loss)
57 | """
58 | assert input.is_cuda and target.is_cuda, "Only support CUDA tensors"
59 | loss, z_loss = gem_loss(
60 | input,
61 | target,
62 | precomputed_lse=precomputed_lse,
63 | beta=self.beta,
64 | logit_scale=self.logit_scale,
65 | lse_square_scale=self.lse_square_scale,
66 | ignore_index=self.ignore_index,
67 | inplace_backward=self.inplace_backward,
68 | process_group=self.process_group,
69 | )
70 | if self.reduction == "mean":
71 | loss = loss.sum() / (target != self.ignore_index).sum()
72 | elif self.reduction == "sum":
73 | loss = loss.sum()
74 | else:
75 | loss = loss
76 |
77 | if not self.return_z_loss:
78 | return loss
79 |
80 | if self.reduction == "mean":
81 | z_loss = z_loss.sum() / (target != self.ignore_index).sum()
82 | elif self.reduction == "sum":
83 | z_loss = z_loss.sum()
84 | else:
85 | z_loss = z_loss
86 |
87 | return loss, z_loss
88 |
--------------------------------------------------------------------------------
/utils/gem_triton_ops.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2025, Ziniu Li.
2 | # The code is modified from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/ops/triton/cross_entropy.py
3 |
4 | from typing import Tuple, Optional, Union
5 |
6 | import torch
7 | import torch.nn.functional as F
8 |
9 | import triton
10 | import triton.language as tl
11 |
12 | # `all_gather_into_tensor` is the new name for `_all_gather_base`. It requires a
13 | # recent version of PyTorch; the following two lines provide backward
14 | # compatibility with older PyTorch releases (only the all-gather collective is
15 | # needed in this file).
16 | if "all_gather_into_tensor" not in dir(torch.distributed):
17 | torch.distributed.all_gather_into_tensor = torch.distributed._all_gather_base
18 |
19 | @triton.jit
20 | def gem_fwd_kernel(
21 | loss_ptr, # data ptrs
22 | lse_ptr,
23 | z_loss_ptr,
24 | logits_ptr,
25 | labels_ptr,
26 | beta,
27 | logit_scale,
28 | lse_square_scale,
29 | ignore_index,
30 | total_classes,
31 | class_start_idx, # Useful for tensor parallel when each rank only has a subset of classes
32 | n_cols, # shapes
33 | logits_row_stride, # strides
34 | BLOCK_SIZE: tl.constexpr,
35 | # if SPLIT (e.g. tensor parallel), don't include the LSE in the loss since it's not the final LSE
36 | SPLIT: tl.constexpr,
37 | PRECOMPUTED_LSE: tl.constexpr, # If LSE is already computed (also no smoothing and logit_scale == 1.0)
38 | ):
39 |     # GEM Loss (h = linear)
40 |     # Loss = -log p(label|x) + sum_y q(y|x) * log p(y|x)
41 |     # q(y|x) = softmax((1 / beta) * log p(y|x))
42 |     # Note that the first term is the same as the CE loss.
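    # Structure of this kernel (descriptive summary):
    #   pass 1: online logsumexp over the row -> lse (skipped if PRECOMPUTED_LSE);
    #   pass 2: running max m_q and normalizer s_q for q = softmax(logits / beta);
    #   pass 3: accumulate q_loss = sum_y q(y|x) * (logits_y - lse) = sum_y q(y|x) * log p(y|x);
    #   finally add the CE term lse - logits[label] (the label is handled only on the owning rank when SPLIT).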
43 |
44 | # Prepare for calculating lse.
45 | row_idx = tl.program_id(0)
46 | logits_ptr = logits_ptr + row_idx * logits_row_stride.to(tl.int64)
47 |
48 | if not PRECOMPUTED_LSE:
49 | # Statistics for online softmax
50 | m_i = -float("inf")
51 | l_i = 0.0
52 | for col_offset in range(0, n_cols, BLOCK_SIZE):
53 | cols = col_offset + tl.arange(0, BLOCK_SIZE)
54 | logits = tl.load(logits_ptr + cols, mask=cols < n_cols, other=-float("inf")).to(
55 | tl.float32
56 | ) * logit_scale
57 | m_i_new = tl.maximum(m_i, tl.max(logits))
58 | l_i = tl.exp(m_i - m_i_new) * l_i + tl.sum(tl.exp(logits - m_i_new))
59 | m_i = m_i_new
60 | lse = tl.log(l_i) + m_i
61 | tl.store(lse_ptr + row_idx, lse)
62 | else:
63 | lse = tl.load(lse_ptr + row_idx)
64 | label_idx = tl.load(labels_ptr + row_idx)
65 |
66 | # Second term: q-regularized loss
67 | m_q = -float("inf") # running max
68 | s_q = 0.0 # running sum for denominator
69 | for col_offset in range(0, n_cols, BLOCK_SIZE):
70 | cols = col_offset + tl.arange(0, BLOCK_SIZE)
71 | logits = tl.load(logits_ptr + cols, mask=cols < n_cols, other=-float("inf")).to(
72 | tl.float32
73 | ) * logit_scale
74 | logits_scaled = logits / beta
75 |
76 | # Update running max and rescale previous sum
77 | m_q_prev = m_q
78 | m_q = tl.maximum(m_q, tl.max(logits_scaled))
79 | s_q = tl.exp(m_q_prev - m_q) * s_q
80 |
81 | # Add contribution from current block
82 | numerator = tl.exp(logits_scaled - m_q)
83 | s_q += tl.sum(tl.where(cols < n_cols, numerator, 0.0))
84 |
85 | # Second pass: compute q_loss using final m_q and s_q
86 | q_loss = 0.0
87 | for col_offset in range(0, n_cols, BLOCK_SIZE):
88 | cols = col_offset + tl.arange(0, BLOCK_SIZE)
89 | logits = tl.load(logits_ptr + cols, mask=cols < n_cols, other=-float("inf")).to(
90 | tl.float32
91 | ) * logit_scale
92 | logits_scaled = logits / beta
93 |
94 | # Compute q_probs using final normalization constants
95 | q_probs = tl.exp(logits_scaled - m_q) / s_q
96 | q_loss += tl.sum(tl.where(cols < n_cols, q_probs * (logits - lse), 0.0))
97 |
98 | # Compute CE loss term (same as before)
99 | label_idx = tl.load(labels_ptr + row_idx)
100 | if label_idx == ignore_index:
101 | loss = 0.0
102 | z_loss = 0.0
103 | else:
104 | label_idx -= class_start_idx
105 | if label_idx >= 0 and label_idx < n_cols:
106 | logits_label = tl.load(logits_ptr + label_idx) * logit_scale
107 | # GEM loss = CE loss + q_loss
108 | loss = (lse if not SPLIT else 0.0) - logits_label + q_loss
109 | else:
110 | loss = q_loss
111 | if not SPLIT:
112 | z_loss = lse_square_scale * lse * lse
113 | loss += z_loss
114 | else:
115 | z_loss = 0.0
116 | tl.store(loss_ptr + row_idx, loss)
117 | if not SPLIT:
118 | tl.store(z_loss_ptr + row_idx, z_loss)
119 |
120 |
121 | @triton.jit
122 | def gem_bwd_kernel(
123 | dlogits_ptr, # data ptrs
124 | dloss_ptr,
125 | logits_ptr,
126 | lse_ptr,
127 | labels_ptr,
128 | beta,
129 | logit_scale,
130 | lse_square_scale,
131 | ignore_index,
132 | total_classes,
133 | class_start_idx, # Useful for tensor parallel when each rank only has a subset of classes
134 | n_cols, # shapes
135 | logits_row_stride, # strides
136 | dlogits_row_stride,
137 | dloss_row_stride,
138 | BLOCK_SIZE: tl.constexpr,
139 | ):
140 | row_idx = tl.program_id(0)
141 | col_block_idx = tl.program_id(1)
142 | logits_ptr = logits_ptr + row_idx * logits_row_stride.to(tl.int64)
143 | dlogits_ptr = dlogits_ptr + row_idx * dlogits_row_stride.to(tl.int64)
144 | col_offsets = col_block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
145 | label_idx = tl.load(labels_ptr + row_idx)
146 | if label_idx != ignore_index:
147 | dloss = tl.load(dloss_ptr + row_idx * dloss_row_stride)
148 | else:
149 | dloss = 0.0
150 | logits = tl.load(logits_ptr + col_offsets, mask=col_offsets < n_cols, other=-float("inf")).to(
151 | tl.float32
152 | ) * logit_scale
153 | lse = tl.load(lse_ptr + row_idx)
154 |
155 |     # GEM: gradient of the loss w.r.t. logit y is q(y|x), or q(y|x) - 1 if y is the label.
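    # Why this is the gradient (treating q as a constant, i.e. with a stop-gradient on q):
    #   loss = lse - z_label + sum_y q(y|x) * (z_y - lse)
    #        = -z_label + sum_y q(y|x) * z_y        (the lse terms cancel since sum_y q(y|x) = 1)
    #   so d loss / d z_k = q(k|x) - 1[k == label].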
156 | # First pass: compute max and sum for proper softmax normalization
157 | m_q = -float("inf")
158 | s_q = 0.0
159 | for offset in range(0, n_cols, BLOCK_SIZE):
160 | cols = offset + tl.arange(0, BLOCK_SIZE)
161 | logits_block = tl.load(logits_ptr + cols, mask=cols < n_cols, other=-float("inf")).to(tl.float32) * logit_scale
162 | logits_scaled = logits_block / beta
163 | m_q = tl.maximum(m_q, tl.max(logits_scaled))
164 |
165 | # Second pass: compute sum with stable numerics
166 | for offset in range(0, n_cols, BLOCK_SIZE):
167 | cols = offset + tl.arange(0, BLOCK_SIZE)
168 | logits_block = tl.load(logits_ptr + cols, mask=cols < n_cols, other=-float("inf")).to(tl.float32) * logit_scale
169 | logits_scaled = logits_block / beta
170 | numerator = tl.exp(logits_scaled - m_q)
171 | s_q += tl.sum(tl.where(cols < n_cols, numerator, 0.0))
172 |
173 | # Final pass: compute gradients for current block
174 | col_offsets = col_block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
175 | logits = tl.load(logits_ptr + col_offsets, mask=col_offsets < n_cols, other=-float("inf")).to(tl.float32) * logit_scale
176 | logits_scaled = logits / beta
177 | probs = tl.exp(logits_scaled - m_q) / s_q
178 |
179 | probs += 2.0 * lse_square_scale * lse * probs
180 | label_idx -= class_start_idx
181 | probs = tl.where(col_offsets == label_idx, probs - 1.0, probs)
182 | tl.store(dlogits_ptr + col_offsets, (dloss * logit_scale) * probs, mask=col_offsets < n_cols)
183 |
184 |
185 | class GEMLoss(torch.autograd.Function):
186 |
187 | @staticmethod
188 | def forward(
189 | ctx,
190 | logits,
191 | labels,
192 | precomputed_lse=None,
193 | beta=1.0,
194 | logit_scale=1.0,
195 | lse_square_scale=0.0,
196 | ignore_index=-100,
197 | inplace_backward=False,
198 | process_group=None,
199 | ):
200 | # For some reason Triton generates wrong code when labels has dtype long and its address
201 | # is not aligned to 16 bytes. The ld.global.b64 seems to load the wrong label index.
202 | if labels.dtype == torch.long and labels.data_ptr() % 16 != 0:
203 | labels = F.pad(labels, (0, 1))[..., :-1]
204 | assert labels.data_ptr() % 16 == 0
205 | assert logit_scale > 0.0
206 | n_rows, n_cols = logits.shape
207 | assert labels.shape == (n_rows,)
208 | world_size = 1 if process_group is None else torch.distributed.get_world_size(process_group)
209 | total_classes = world_size * n_cols
210 | rank = 0 if process_group is None else torch.distributed.get_rank(process_group)
211 | class_start_idx = rank * n_cols
212 | use_precomputed_lse = precomputed_lse is not None and logit_scale == 1.0
213 |
214 | if logits.stride(-1) != 1:
215 | logits = logits.contiguous()
216 | MAX_BLOCK_SIZE = 16 * 1024
217 | BLOCK_SIZE = min(triton.next_power_of_2(n_cols), MAX_BLOCK_SIZE)
218 | num_warps = (
219 | 4
220 | if BLOCK_SIZE < 2048
221 | else (8 if BLOCK_SIZE < 8192 else (16 if BLOCK_SIZE < 128 * 1024 else 32))
222 | )
223 | losses = torch.empty(n_rows, dtype=torch.float, device=logits.device)
224 | if use_precomputed_lse:
225 | assert precomputed_lse.shape == (n_rows,)
226 | lse = precomputed_lse.contiguous()
227 | else:
228 | lse = torch.empty(n_rows, dtype=torch.float, device=logits.device)
229 | z_losses = torch.empty(n_rows, dtype=torch.float, device=logits.device)
230 | # Need this, otherwise Triton tries to launch from cuda:0 and we get
231 | # ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?)
232 | with torch.cuda.device(logits.device.index):
233 | gem_fwd_kernel[(n_rows,)](
234 | losses, # data ptrs
235 | lse,
236 | z_losses,
237 | logits,
238 | labels,
239 | beta,
240 | logit_scale,
241 | lse_square_scale,
242 | ignore_index,
243 | total_classes,
244 | class_start_idx,
245 | n_cols, # shapes
246 | logits.stride(0), # strides
247 | BLOCK_SIZE=BLOCK_SIZE, # constants
248 | SPLIT=world_size > 1,
249 | PRECOMPUTED_LSE=use_precomputed_lse,
250 | num_warps=num_warps,
251 | )
252 |
253 | if world_size > 1:
254 | # For GEM loss, if labels are in the vocab of this partition, losses contains
255 | # - predicted logit + q_loss for this partition, and 0 otherwise.
256 | # For labels not in the vocab of this partition, losses contains
257 | # only q_loss for this partition
258 | lse_allgather = torch.empty(world_size, n_rows, dtype=lse.dtype, device=lse.device)
259 | torch.distributed.all_gather_into_tensor(lse_allgather, lse, group=process_group)
260 | handle_losses = torch.distributed.all_reduce(
261 | losses, op=torch.distributed.ReduceOp.SUM, group=process_group, async_op=True
262 | )
263 | lse = torch.logsumexp(lse_allgather, dim=0)
264 | handle_losses.wait()
265 | # After the allreduce:
266 | # 1. For labels in any partition, we have the sum of all q_losses from each partition
267 | # 2. For the partition containing the label, we also have -predicted_logit
268 | # We just need to add the global lse to complete the GEM loss
269 | losses += lse
270 | if lse_square_scale != 0.0:
271 | z_losses = lse_square_scale * lse.square()
272 | z_losses.masked_fill_(labels == ignore_index, 0.0)
273 | losses += z_losses
274 | else:
275 | z_losses = torch.zeros_like(losses)
276 | losses.masked_fill_(labels == ignore_index, 0.0)
277 |
278 | ctx.save_for_backward(logits, lse, labels)
279 | ctx.mark_non_differentiable(z_losses)
280 | ctx.beta = beta
281 | ctx.logit_scale = logit_scale
282 | ctx.lse_square_scale = lse_square_scale
283 | ctx.ignore_index = ignore_index
284 | ctx.total_classes = total_classes
285 | ctx.class_start_idx = class_start_idx
286 | ctx.inplace_backward = inplace_backward
287 | return losses, z_losses
288 |
289 | @staticmethod
290 | def backward(ctx, grad_losses, grad_z_losses):
291 | del grad_z_losses # z_losses are only for logging.
292 |
293 | logits, lse, labels = ctx.saved_tensors
294 | dlogits = logits if ctx.inplace_backward else torch.empty_like(logits)
295 | n_rows, n_cols = logits.shape
296 | BLOCK_SIZE = min(triton.next_power_of_2(n_cols), 4 * 1024)
297 | num_warps = 4 if BLOCK_SIZE < 2048 else (8 if BLOCK_SIZE < 8192 else 16)
298 | grid = lambda META: (n_rows, triton.cdiv(n_cols, META["BLOCK_SIZE"])) # noqa
299 | # Need this, otherwise Triton tries to launch from cuda:0 and we get
300 | # ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?)
301 | with torch.cuda.device(logits.device.index):
302 | gem_bwd_kernel[grid](
303 | dlogits, # data ptrs
304 | grad_losses,
305 | logits,
306 | lse,
307 | labels,
308 | ctx.beta,
309 | ctx.logit_scale,
310 | ctx.lse_square_scale,
311 | ctx.ignore_index,
312 | ctx.total_classes,
313 | ctx.class_start_idx,
314 | n_cols, # shapes
315 | logits.stride(0), # strides
316 | dlogits.stride(0),
317 | grad_losses.stride(0),
318 | BLOCK_SIZE=BLOCK_SIZE, # constants
319 | num_warps=num_warps,
320 | )
321 | return dlogits, None, None, None, None, None, None, None, None, None
322 |
323 |
324 | def gem_loss(
325 | logits: torch.Tensor,
326 | labels: torch.Tensor,
327 | precomputed_lse: Optional[torch.Tensor] = None,
328 | beta: float = 1.0,
329 | logit_scale: float = 1.0,
330 | lse_square_scale: float = 0.0,
331 | ignore_index=-100,
332 | inplace_backward: bool = False,
333 | process_group=None,
334 | ) -> Tuple[torch.Tensor, torch.Tensor]:
335 | """
336 | Arguments:
337 | logits: (batch, vocab_size)
338 | labels: (batch,)
339 |         beta: float. GEM hyper-parameter in (0, 1]; values close to 1.0 make GEM behave more like CE.
340 | logit_scale: float. Multiply logits by this scale before calculating the loss.
341 | lse_square_scale: float. If > 0, we add lse_square_scale * lse(logits) ^ 2 to the loss.
342 | This is also referred to as "z-loss".
343 | ignore_index: int. If labels == ignore_index, the loss is set to 0.0.
344 | inplace_backward: bool. If True, we do the backward pass in-place by modifying the logits.
345 | This saves memory.
346 | process_group: if not None, we're doing Tensor Parallel: each process is responsible for
347 | one part of the vocab. The loss will be aggregated across processes.
348 | Returns:
349 | losses: (batch,), float
350 | z_losses: (batch,), float
351 | """
352 | return GEMLoss.apply(
353 | logits,
354 | labels,
355 | precomputed_lse,
356 | beta,
357 | logit_scale,
358 | lse_square_scale,
359 | ignore_index,
360 | inplace_backward,
361 | process_group,
362 | )
363 |
--------------------------------------------------------------------------------