├── .gitignore ├── README.md ├── assets ├── pipeline.png └── rewards.png ├── requirements.txt └── src ├── generate.py ├── grpo.py └── train_supervised.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # UV 98 | # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | #uv.lock 102 | 103 | # poetry 104 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 105 | # This is especially recommended for binary packages to ensure reproducibility, and is more 106 | # commonly ignored for libraries. 107 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 108 | #poetry.lock 109 | 110 | # pdm 111 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 112 | #pdm.lock 113 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 114 | # in version control. 115 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 116 | .pdm.toml 117 | .pdm-python 118 | .pdm-build/ 119 | 120 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 121 | __pypackages__/ 122 | 123 | # Celery stuff 124 | celerybeat-schedule 125 | celerybeat.pid 126 | 127 | # SageMath parsed files 128 | *.sage.py 129 | 130 | # Environments 131 | .env 132 | .venv 133 | env/ 134 | venv/ 135 | ENV/ 136 | env.bak/ 137 | venv.bak/ 138 | 139 | # Spyder project settings 140 | .spyderproject 141 | .spyproject 142 | 143 | # Rope project settings 144 | .ropeproject 145 | 146 | # mkdocs documentation 147 | /site 148 | 149 | # mypy 150 | .mypy_cache/ 151 | .dmypy.json 152 | dmypy.json 153 | 154 | # Pyre type checker 155 | .pyre/ 156 | 157 | # pytype static type analyzer 158 | .pytype/ 159 | 160 | # Cython debug symbols 161 | cython_debug/ 162 | 163 | # PyCharm 164 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 165 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 166 | # and can be added to the global gitignore or merged into this file. For a more nuclear 167 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 168 | #.idea/ 169 | 170 | # PyPI configuration file 171 | .pypirc 172 | 173 | # Temp folders 174 | data/ 175 | wandb/ 176 | logs/ 177 | eval_results/ 178 | results/ -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # open-r1-text2graph 2 | 3 | The goal of this project is to reproduce the **DeepSeek R1** training schema, particularly for **text-to-graph** information extraction. 4 | 5 | The project is built on top of and inspired by Hugging Face [Open-R1](https://github.com/huggingface/open-r1/tree/main) and [trl](https://github.com/huggingface/trl/tree/main/trl). 6 | 7 | ### Structure 8 | The project currently consists of the following components: 9 | * `src` contains the scripts to train a model using GRPO or supervised learning, as well as data generation scripts. 10 | * `grpo.py` - trains a model using GRPO with text-to-graph related reward functions; 11 | * `train_supervised.py` - trains a model in a supervised manner for text-to-graph extraction; 12 | * `generate.py` - generates thinking chains given input text and extracted JSON; 13 | 14 | ### Pipeline 15 | 16 | ![image/png](assets/pipeline.png) 17 | 18 | The training process consists of three major stages: **synthetic data generation, supervised training, and reinforcement learning (RL) training**. Each of these stages plays a crucial role in improving the model’s ability to perform structured information extraction. 19 | 20 | 1. **Synthetic Data Generation** 21 | 22 | To bootstrap the process, we start with **data collection**, where we gather diverse text sources relevant to our target domain. The **text-to-graph** generation step, powered by **Llama 70B** structured generation, converts unstructured text into graph-based representations (the target format is illustrated below). However, this step is imperfect, and therefore selecting and augmenting data becomes essential to filter out low-quality extractions and enrich the dataset with more diverse structures. 23 | 24 | Additionally, we take the JSON data generated with structured prediction and feed it, together with the source text, into **DeepSeek-R1 Llama 70B** to generate a chain of thought that explains the extraction process. 25 | 26 | We experimented with both thinking-enabled and thinking-disabled modes and found that small models struggle to discover some interesting and important thinking strategies on their own.
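For illustration, the graph representation targeted at this stage uses entities with `id`, `text`, and `type` fields and relations with `head`, `tail`, and `type` fields, which is the same structure validated by the reward functions in `src/grpo.py`. The concrete sentence and labels below are invented for the example:

```json
{
  "entities": [
    {"type": "person", "text": "Marie Curie", "id": 0},
    {"type": "award", "text": "Nobel Prize in Physics", "id": 1}
  ],
  "relations": [
    {"head": "Marie Curie", "tail": "Nobel Prize in Physics", "type": "received"}
  ]
}
```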
27 | 28 | 2. **Supervised Training** 29 | 30 | Before starting reinforcement learning, and considering that we use small models, additional supervised training is required to push the model to return data in the right format. We used only 1k examples for this purpose. 31 | 32 | 3. **Reinforcement Learning with GRPO** 33 | 34 | Supervised training alone does not fully solve the problem, especially when it comes to conditioning model outputs on predefined entity and relation types. To address this, we employ **Group Relative Policy Optimization (GRPO)** for reinforcement learning. 35 | 36 | * **Format reward** ensures that the output follows the expected structure, with the reasoning encapsulated in `<think>...</think>` tags (in the case of thinking mode). 37 | * **JSON reward** validates that the output contains a single well-formed, machine-readable JSON object and that its structure aligns with the desired format. 38 | * **F1 reward** evaluates the accuracy of extracted entities and relations by comparing them to ground-truth graphs. 39 | 40 | Below you can see how the different rewards change over training steps in one of our experiments. 41 | 42 | ![image/png](assets/rewards.png) 43 | 44 | 45 | ### Try Model 46 | 47 | You can try one of the models fine-tuned with this framework. A minimal sketch for parsing the graph JSON out of the raw response is shown at the end of this README. 48 | 49 | ```python 50 | from transformers import AutoModelForCausalLM, AutoTokenizer 51 | 52 | model_name = "Ihor/Text2Graph-R1-Qwen2.5-0.5b" 53 | 54 | model = AutoModelForCausalLM.from_pretrained( 55 | model_name, 56 | torch_dtype="auto", 57 | device_map="auto" 58 | ) 59 | tokenizer = AutoTokenizer.from_pretrained(model_name) 60 | 61 | text = """Your text here...""" 62 | prompt = "Analyze this text, identify the entities, and extract meaningful relationships as per given instructions:{}" 63 | messages = [ 64 | {"role": "system", "content": ( 65 | "You are an assistant trained to process any text and extract named entities and relations from it. " 66 | "Your task is to analyze user-provided text, identify all unique and contextually relevant entities, and infer meaningful relationships between them. " 67 | "Output the annotated data in JSON format, structured as follows:\n\n" 68 | """{"entities": [{"type": "entity_type_0", "text": "entity_0", "id": 0}, {"type": "entity_type_1", "text": "entity_1", "id": 1}], "relations": [{"head": "entity_0", "tail": "entity_1", "type": "re_type_0"}]}""" 69 | )}, 70 | {"role": "user", "content": prompt.format(text)} 71 | ] 72 | text = tokenizer.apply_chat_template( 73 | messages, 74 | tokenize=False, 75 | add_generation_prompt=True 76 | ) 77 | model_inputs = tokenizer([text], return_tensors="pt").to(model.device) 78 | 79 | generated_ids = model.generate( 80 | **model_inputs, 81 | max_new_tokens=512 82 | ) 83 | generated_ids = [ 84 | output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids) 85 | ] 86 | 87 | response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0] 88 | ``` 89 | 90 | ### Blog 91 | 92 | Read more details about the project in the following [blog](https://huggingface.co/blog/Ihor/replicating-deepseek-r1-for-information-extraction).
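### Parsing the Output

The raw `response` from the snippet above contains the model's reasoning followed by the extracted graph as a JSON object. Below is a minimal sketch for pulling that JSON out of the text; it mirrors the `extract_json_from_text` helper in `src/grpo.py`, and it assumes the `response` variable from the Try Model snippet:

```python
import json

def extract_json_from_text(text):
    """Return every top-level JSON object found in the text."""
    extracted, depth, start = [], 0, 0
    for idx, char in enumerate(text):
        if char == '{':
            if depth == 0:
                start = idx
            depth += 1
        elif char == '}':
            depth -= 1
            if depth == 0:
                try:
                    extracted.append(json.loads(text[start:idx + 1]))
                except json.JSONDecodeError:
                    pass  # skip fragments that are not valid JSON
    return extracted

graphs = extract_json_from_text(response)  # `response` comes from the snippet above
if graphs:
    graph = graphs[0]
    print(graph["entities"])
    print(graph["relations"])
```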
-------------------------------------------------------------------------------- /assets/pipeline.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ingvarstep/open-r1-text2graph/ba2c3f1628f2e2cc2941b916fc313d6288cc1416/assets/pipeline.png -------------------------------------------------------------------------------- /assets/rewards.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ingvarstep/open-r1-text2graph/ba2c3f1628f2e2cc2941b916fc313d6288cc1416/assets/rewards.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | trl==0.14.0 2 | transformers==4.48.1 3 | accelerate==1.3.0 4 | vllm==0.7.0 5 | bitsandbytes==0.45.1 6 | datasets==3.1.0 7 | peft==0.14.0 -------------------------------------------------------------------------------- /src/generate.py: -------------------------------------------------------------------------------- 1 | import re 2 | import os 3 | import json 4 | import torch 5 | import copy 6 | import random 7 | import argparse 8 | from tqdm import tqdm 9 | 10 | from transformers import AutoTokenizer 11 | from vllm import LLM, SamplingParams 12 | 13 | ZERO_SHOT_MESSAGES = [ 14 | { 15 | "role": "system", 16 | "content": ( 17 | "You are an assistant trained to process any text and extract named entities and relations from it. " 18 | "Your task is to analyze user-provided text, identify all unique and contextually relevant entities, and infer meaningful relationships between them. " 19 | "Additionally, you will get already extracted results, and your task is to generate the thinking process you would follow if you did not know the final JSON output. " 20 | "You need to formulate your reasoning process and encapsulate it in a <think> tag; this is the only thing you return."
21 | ), 22 | }, 23 | 24 | ] 25 | 26 | ZERO_SHOT_PROMPT = """Analyze this text and JSON output and produce your thinking""" 27 | 28 | 29 | def create_chat(prompt, text): 30 | return prompt.format(text) 31 | 32 | def generate_response(llm, chats, sampling_params): 33 | responses = llm.generate(chats, sampling_params, use_tqdm=False) 34 | return responses 35 | 36 | def process_input(example): 37 | messages = copy.deepcopy(ZERO_SHOT_MESSAGES) 38 | text = messages[-1]['content'] 39 | text+="Here is the JSON output:" 40 | solution = example['solution'] 41 | text+=solution 42 | messages.append({ 43 | "role": 'user', 44 | "content": text 45 | }) 46 | prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) 47 | return prompt 48 | 49 | def generate_dataset(examples, llm, sampling_params, batch_size=8, max_lines=None): 50 | batch_chats = [] 51 | batch_examples = [] 52 | text = True 53 | final_results = [] 54 | for i in tqdm(range(0, max_lines)): 55 | example = examples[i] 56 | 57 | batch_examples.append(example) 58 | chat = process_input(example) 59 | batch_chats.append(chat) 60 | 61 | if len(batch_chats)==batch_size: 62 | try: 63 | responses = generate_response(llm, batch_chats, sampling_params) 64 | except Exception as err: 65 | print(err) 66 | continue 67 | 68 | batch_results = [ {'prompt': batch_examples[j]['prompt'], 69 | 'solution': response.outputs[0].text +'\n'+batch_examples[j]['solution']} 70 | for j, response in enumerate(responses)] 71 | 72 | final_results.extend(batch_results) 73 | 74 | batch_chats = [] 75 | batch_examples = [] 76 | 77 | return final_results 78 | 79 | 80 | if __name__ == '__main__': 81 | parser = argparse.ArgumentParser() 82 | parser.add_argument('--data_path', type=str, default= "data/text2graph.json") 83 | parser.add_argument('--save_path', type=str, default= "data/text2graph_with_thinking.json") 84 | parser.add_argument('--model', type=str, default= "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B") 85 | parser.add_argument('--quantization', type=str, default= "fp8") 86 | parser.add_argument('--max_examples', type=int, default= 500) 87 | parser.add_argument('--batch_size', type=int, default= 8) 88 | parser.add_argument('--temperature', type=float, default= 0.75) 89 | args = parser.parse_args() 90 | 91 | with open(args.data_path, 'r') as f: 92 | texts = json.load(f) 93 | random.shuffle(texts) 94 | print('Texts count: ', len(texts)) 95 | 96 | llm = LLM(model=args.model, 97 | max_model_len = 8129, 98 | tensor_parallel_size=1, dtype="half", 99 | gpu_memory_utilization = 0.9, quantization = args.quantization) 100 | 101 | sampling_params = SamplingParams(temperature = args.temperature, repetition_penalty = 1.1, top_k=100, max_tokens=4096, top_p=0.8, stop="") 102 | 103 | tokenizer = AutoTokenizer.from_pretrained(args.model) 104 | 105 | results = generate_dataset(texts, llm, sampling_params, 106 | batch_size=args.batch_size, max_lines=args.max_examples) 107 | 108 | with open(args.save_path, 'w') as f: 109 | json.dump(results, f, indent=1) -------------------------------------------------------------------------------- /src/grpo.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import re 16 | import json 17 | import random 18 | from dataclasses import dataclass, field 19 | 20 | from datasets import load_dataset 21 | 22 | from trl import GRPOConfig, GRPOTrainer, ModelConfig, ScriptArguments, TrlParser, get_peft_config 23 | 24 | 25 | @dataclass 26 | class GRPOScriptArguments(ScriptArguments): 27 | """ 28 | Script arguments for the GRPO training script. 29 | 30 | Args: 31 | reward_funcs (`list[str]`): 32 | List of reward functions. Possible values: 'accuracy', 'format'. 33 | """ 34 | 35 | reward_funcs: list[str] = field( 36 | default_factory=lambda: ['format', 'json_consistency', 'json_structure', 'f1_ents', 'f1_rels'], 37 | metadata={"help": "List of reward functions. Possible values: 'accuracy', 'format'"}, 38 | ) 39 | 40 | def extract_json_from_text(text): 41 | json_start = 0 42 | json_end = 0 43 | close_brace_count = 0 44 | extracted_jsons = [] 45 | for idx, char in enumerate(text): 46 | if char == '{': 47 | if close_brace_count == 0: 48 | json_start = idx 49 | close_brace_count += 1 50 | elif char == '}': 51 | close_brace_count -= 1 52 | if close_brace_count == 0: 53 | json_end = idx + 1 54 | extracted_json = text[json_start:json_end] 55 | try: 56 | extracted_jsons.append(json.loads(extracted_json)) 57 | except json.JSONDecodeError: 58 | pass 59 | return extracted_jsons 60 | 61 | def validate_json_structure(data): 62 | required_keys = {"entities", "relations"} 63 | entity_required_keys = {"id", "text", "type"} 64 | relation_required_keys = {"head", "tail", "type"} 65 | 66 | if not isinstance(data, dict) or not required_keys.issubset(data.keys()): 67 | return False 68 | 69 | if not isinstance(data["entities"], list): 70 | return False 71 | 72 | for entity in data["entities"]: 73 | if not isinstance(entity, dict) or not entity_required_keys.issubset(entity.keys()): 74 | return False 75 | if not isinstance(entity["id"], int) or not isinstance(entity["text"], str) or not isinstance(entity["type"], str): 76 | return False 77 | 78 | if not isinstance(data["relations"], list): 79 | return False 80 | 81 | for relation in data["relations"]: 82 | if not isinstance(relation, dict) or not relation_required_keys.issubset(relation.keys()): 83 | return False 84 | if not isinstance(relation["head"], str) or not isinstance(relation["tail"], str) or not isinstance(relation["type"], str): 85 | return False 86 | 87 | return True 88 | 89 | def compute_f1(pred_list, true_list): 90 | tp = len(set(pred_list) & set(true_list)) 91 | fp = len(set(pred_list) - set(true_list)) 92 | fn = len(set(true_list) - set(pred_list)) 93 | 94 | # Compute precision, recall, and F1 score 95 | precision = tp / (tp + fp) if tp + fp > 0 else 0 96 | recall = tp / (tp + fn) if tp + fn > 0 else 0 97 | f1 = 2 * precision * recall / (precision + recall) if precision + recall > 0 else 0 98 | return f1 99 | 100 | def get_entities_f1_score(pred_entities, true_entities): 101 | pred_list = {f"{entity['text']}_{entity['type']}" for entity in pred_entities} 102 | true_list = {f"{entity['text']}_{entity['type']}" for entity in true_entities} 103 | f1 = 
compute_f1(pred_list, true_list) 104 | return f1 105 | 106 | def get_relations_f1_score(pred_relations, true_relations): 107 | pred_list = {f"{rel['head']}_{rel['tail']}_{rel['type']}" for rel in pred_relations} 108 | true_list = {f"{rel['head']}_{rel['tail']}_{rel['type']}" for rel in true_relations} 109 | f1 = compute_f1(pred_list, true_list) 110 | return f1 111 | 112 | def json_consistency_reward(completions, solution, **kwargs): 113 | contents = [completion[0]["content"] for completion in completions] 114 | rewards = [] 115 | for content, sol in zip(contents, solution): 116 | extracted_jsons = extract_json_from_text(content) 117 | if len(extracted_jsons)==1: 118 | rewards.append(0.1) 119 | else: 120 | rewards.append(0.0) 121 | return rewards 122 | 123 | def json_structure_reward(completions, solution, **kwargs): 124 | contents = [completion[0]["content"] for completion in completions] 125 | rewards = [] 126 | for content, sol in zip(contents, solution): 127 | extracted_jsons = extract_json_from_text(content) 128 | if len(extracted_jsons)==1: 129 | extracted_json = extracted_jsons[0] 130 | val = validate_json_structure(extracted_json) 131 | if val: 132 | rewards.append(0.1) 133 | else: 134 | rewards.append(0.0) 135 | else: 136 | rewards.append(0.0) 137 | return rewards 138 | 139 | def f1_entities_reward(completions, solution, **kwargs): 140 | contents = [completion[0]["content"] for completion in completions] 141 | rewards = [] 142 | for content, sol in zip(contents, solution): 143 | extracted_jsons_pred = extract_json_from_text(content) 144 | extracted_jsons_true = extract_json_from_text(sol) 145 | 146 | if len(extracted_jsons_pred)==1 and len(extracted_jsons_true)==1: 147 | json_pred = extracted_jsons_pred[0] 148 | json_true = extracted_jsons_true[0] 149 | 150 | f1_reward = 0 151 | try: 152 | f1_reward += get_entities_f1_score(json_pred['entities'], json_true['entities']) 153 | except Exception: 154 | pass 155 | 156 | rewards.append(f1_reward) 157 | else: 158 | rewards.append(0) 159 | return rewards 160 | 161 | def f1_relations_reward(completions, solution, **kwargs): 162 | contents = [completion[0]["content"] for completion in completions] 163 | rewards = [] 164 | for content, sol in zip(contents, solution): 165 | extracted_jsons_pred = extract_json_from_text(content) 166 | extracted_jsons_true = extract_json_from_text(sol) 167 | 168 | if len(extracted_jsons_pred)==1 and len(extracted_jsons_true)==1: 169 | json_pred = extracted_jsons_pred[0] 170 | json_true = extracted_jsons_true[0] 171 | 172 | f1_reward = 0 173 | try: 174 | f1_reward += get_relations_f1_score(json_pred['relations'], json_true['relations']) 175 | except Exception: 176 | pass 177 | 178 | rewards.append(f1_reward) 179 | else: 180 | rewards.append(0) 181 | return rewards 182 | 183 | def format_reward(completions, **kwargs): 184 | """Reward function that checks if the completion wraps its reasoning in <think>...</think> tags.""" 185 | completion_contents = [completion[0]["content"] for completion in completions] 186 | matches = [("<think>" in content and "</think>" in content) for content in completion_contents] 187 | return [1.0 if match else 0.0 for match in matches] 188 | 189 | reward_funcs_registry = { 190 | "format": format_reward, 191 | "json_consistency": json_consistency_reward, 192 | "json_structure": json_structure_reward, 193 | "f1_ents": f1_entities_reward, 194 | "f1_rels": f1_relations_reward, 195 | 196 | } 197 | 198 | def main(script_args, training_args, model_args): 199 | # Get reward functions 200 | reward_funcs = [reward_funcs_registry[func] for func in
script_args.reward_funcs] 201 | 202 | # Load the dataset 203 | with open(script_args.dataset_name, 'r', encoding='utf-8') as f: 204 | dataset = json.load(f) 205 | random.shuffle(dataset) 206 | 207 | train_dataset = dataset[:int(len(dataset)*0.8)] 208 | test_dataset = dataset[int(len(dataset)*0.8):] 209 | 210 | print(train_dataset[0]) 211 | 212 | # Initialize the GRPO trainer 213 | trainer = GRPOTrainer( 214 | model=model_args.model_name_or_path, 215 | reward_funcs=reward_funcs, 216 | args=training_args, 217 | train_dataset=train_dataset, 218 | eval_dataset=test_dataset if training_args.eval_strategy != "no" else None, 219 | peft_config=get_peft_config(model_args), 220 | ) 221 | 222 | # Train and push the model to the Hub 223 | trainer.train() 224 | 225 | # Save and push to hub 226 | trainer.save_model(training_args.output_dir) 227 | 228 | 229 | if __name__ == "__main__": 230 | parser = TrlParser((GRPOScriptArguments, GRPOConfig, ModelConfig)) 231 | script_args, training_args, model_args = parser.parse_args_and_config() 232 | main(script_args, training_args, model_args) 233 | -------------------------------------------------------------------------------- /src/train_supervised.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import copy 4 | import logging 5 | import argparse 6 | from tqdm import tqdm 7 | 8 | import bitsandbytes as bnb 9 | import torch 10 | from torch.utils.data import Dataset 11 | 12 | from peft import LoraConfig 13 | from trl import SFTTrainer 14 | from transformers import ( 15 | AutoTokenizer, AutoModelForCausalLM, TrainingArguments, BitsAndBytesConfig 16 | ) 17 | 18 | # Setup logging 19 | logging.basicConfig( 20 | format='%(asctime)s - %(levelname)s - %(message)s', 21 | level=logging.INFO 22 | ) 23 | logger = logging.getLogger(__name__) 24 | 25 | # Define the dataset class 26 | class Text2JSONDataset(Dataset): 27 | def __init__(self, data, tokenizer, max_length=1024): 28 | self.tokenizer = tokenizer 29 | self.max_length = max_length 30 | self.dataset = data 31 | 32 | def __len__(self): 33 | return len(self.dataset) 34 | 35 | def __getitem__(self, idx): 36 | item = self.dataset[idx] 37 | chat = item['prompt'] 38 | 39 | output = item['solution'] 40 | 41 | chat.extend([ 42 | {"role": "assistant", "content": str(output)} 43 | ]) 44 | 45 | input_ = self.tokenizer.apply_chat_template( 46 | chat, 47 | tokenize=False, 48 | add_generation_prompt=False 49 | ) 50 | 51 | inputs = self.tokenizer( 52 | input_, return_tensors="pt", max_length=self.max_length, truncation=True, padding='max_length' 53 | ) 54 | 55 | input_ids = inputs["input_ids"].squeeze(0) 56 | attention_mask = inputs["attention_mask"].squeeze(0) 57 | labels_ids = input_ids.clone().squeeze(0) 58 | 59 | return { 60 | 'input_ids': input_ids, 61 | 'attention_mask': attention_mask, 62 | 'labels': labels_ids 63 | } 64 | 65 | # Helper function to find linear module names 66 | def find_all_linear_names(model, quantize=False): 67 | cls = bnb.nn.Linear4bit if quantize else torch.nn.Linear 68 | 69 | lora_module_names = set() 70 | for name, module in model.named_modules(): 71 | if isinstance(module, cls): 72 | names = name.split('.') 73 | lora_module_names.add(names[0] if len(names) == 1 else names[-1]) 74 | if 'lm_head' in lora_module_names: # needed for 16 bit 75 | lora_module_names.remove('lm_head') 76 | return list(lora_module_names) 77 | 78 | # Main function 79 | def main(args): 80 | logger.info("Starting script with configuration: %s", args) 81 | 82 | 
QUANTIZE = args.quantize 83 | USE_LORA = args.use_lora 84 | 85 | model_path = args.model_path 86 | 87 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 88 | 89 | if QUANTIZE: 90 | bnb_config = BitsAndBytesConfig( 91 | load_in_4bit=True, 92 | bnb_4bit_use_double_quant=False, 93 | bnb_4bit_quant_type="nf4", 94 | bnb_4bit_compute_dtype="float16", 95 | ) 96 | else: 97 | bnb_config = None 98 | 99 | model = AutoModelForCausalLM.from_pretrained( 100 | model_path, 101 | device_map=device, 102 | # torch_dtype=torch.bfloat16, 103 | quantization_config=bnb_config, 104 | trust_remote_code=True, 105 | token=args.hf_token, 106 | # attn_implementation="flash_attention_2" 107 | ) 108 | 109 | if USE_LORA: 110 | modules = find_all_linear_names(model, quantize=QUANTIZE) 111 | 112 | peft_config = LoraConfig( 113 | lora_alpha=32, 114 | lora_dropout=0.1, 115 | r=64, 116 | bias="none", 117 | task_type="CAUSAL_LM", 118 | target_modules=modules, 119 | ) 120 | else: 121 | peft_config = None 122 | 123 | model.config.use_cache = False 124 | model.config.pretraining_tp = 1 125 | 126 | tokenizer = AutoTokenizer.from_pretrained(model_path) 127 | tokenizer.pad_token_id = tokenizer.eos_token_id 128 | # tokenizer.chat_template = CHAT_TEMPLATE 129 | 130 | with open(args.data_path, encoding='utf-8') as f: 131 | data = json.load(f) 132 | train_data = data[:int(len(data)*0.8)] 133 | test_data = data[int(len(data)*0.8):] 134 | train_dataset = Text2JSONDataset(train_data, tokenizer, max_length=args.max_length) 135 | test_dataset = Text2JSONDataset(test_data, tokenizer, max_length=args.max_length) 136 | 137 | logger.info("Dataset lengths - Train: %d, Test: %d", len(train_dataset), len(test_dataset)) 138 | 139 | training_arguments = TrainingArguments( 140 | output_dir=args.output_dir, 141 | num_train_epochs=args.num_train_epochs, 142 | per_device_train_batch_size=args.batch_size, 143 | gradient_accumulation_steps=args.gradient_accumulation_steps, 144 | gradient_checkpointing=args.gradient_checkpointing, 145 | optim="paged_adamw_32bit", 146 | logging_steps=args.logging_steps, 147 | learning_rate=args.learning_rate, 148 | weight_decay=args.weight_decay, 149 | fp16=args.fp16, 150 | bf16=args.bf16, 151 | max_grad_norm=args.max_grad_norm, 152 | max_steps=args.max_steps, 153 | warmup_ratio=args.warmup_ratio, 154 | group_by_length=False, 155 | lr_scheduler_type=args.lr_scheduler_type, 156 | save_total_limit=3, 157 | report_to="none", 158 | eval_strategy="steps", 159 | eval_steps=args.eval_steps, 160 | save_steps=args.save_steps, 161 | ) 162 | 163 | trainer = SFTTrainer( 164 | model=model, 165 | args=training_arguments, 166 | train_dataset=train_dataset, 167 | eval_dataset=test_dataset, 168 | peft_config=peft_config, 169 | tokenizer=tokenizer, 170 | ) 171 | 172 | trainer.train() 173 | 174 | if peft_config is not None: 175 | logger.info("Merging LoRA weights into the base model (since use_lora=True).") 176 | trainer.model = trainer.model.merge_and_unload() 177 | trainer.model.save_pretrained(os.path.join(args.output_dir, 'merged')) 178 | tokenizer.save_pretrained(os.path.join(args.output_dir, 'merged')) 179 | 180 | if __name__ == "__main__": 181 | parser = argparse.ArgumentParser(description="Text2JSON Dataset Training Script") 182 | 183 | parser.add_argument('--model_path', type=str, required=True, help="Path to the model.") 184 | parser.add_argument('--data_path', type=str, required=True, help="Path to the training dataset.") 185 | parser.add_argument('--output_dir', type=str, required=True, help="Directory to save 
trained models.") 186 | parser.add_argument('--hf_token', type=str, required=False, help="Hugging Face authentication token.") 187 | parser.add_argument('--max_length', type=int, default=4096, help="Maximum sequence length.") 188 | parser.add_argument('--num_train_epochs', type=int, default=2, help="Number of training epochs.") 189 | parser.add_argument('--batch_size', type=int, default=2, help="Training batch size.") 190 | parser.add_argument('--gradient_accumulation_steps', type=int, default=1, help="Gradient accumulation steps.") 191 | parser.add_argument('--gradient_checkpointing', action='store_true', help="Enable gradient checkpointing.") 192 | parser.add_argument('--learning_rate', type=float, default=5e-6, help="Learning rate.") 193 | parser.add_argument('--weight_decay', type=float, default=0.01, help="Weight decay.") 194 | parser.add_argument('--fp16', action='store_true', help="Enable FP16 training.") 195 | parser.add_argument('--bf16', action='store_true', help="Enable BF16 training.") 196 | parser.add_argument('--max_grad_norm', type=float, default=0.9, help="Maximum gradient norm.") 197 | parser.add_argument('--max_steps', type=int, default=-1, help="Maximum training steps.") 198 | parser.add_argument('--warmup_ratio', type=float, default=0.1, help="Warmup ratio for learning rate scheduler.") 199 | parser.add_argument('--lr_scheduler_type', type=str, default="cosine", help="Learning rate scheduler type.") 200 | parser.add_argument('--logging_steps', type=int, default=1, help="Logging steps.") 201 | parser.add_argument('--eval_steps', type=int, default=10000, help="Evaluation steps.") 202 | parser.add_argument('--save_steps', type=int, default=1000, help="Save steps.") 203 | parser.add_argument('--quantize', action='store_true', help="Enable quantization.") 204 | parser.add_argument('--use_lora', action='store_true', help="Enable LoRA training.") 205 | 206 | args = parser.parse_args() 207 | 208 | main(args) --------------------------------------------------------------------------------