├── .gitignore
├── README.md
├── assets
│   ├── pipeline.png
│   └── rewards.png
├── requirements.txt
└── src
    ├── generate.py
    ├── grpo.py
    └── train_supervised.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # UV
98 | # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | #uv.lock
102 |
103 | # poetry
104 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
105 | # This is especially recommended for binary packages to ensure reproducibility, and is more
106 | # commonly ignored for libraries.
107 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
108 | #poetry.lock
109 |
110 | # pdm
111 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
112 | #pdm.lock
113 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
114 | # in version control.
115 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
116 | .pdm.toml
117 | .pdm-python
118 | .pdm-build/
119 |
120 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
121 | __pypackages__/
122 |
123 | # Celery stuff
124 | celerybeat-schedule
125 | celerybeat.pid
126 |
127 | # SageMath parsed files
128 | *.sage.py
129 |
130 | # Environments
131 | .env
132 | .venv
133 | env/
134 | venv/
135 | ENV/
136 | env.bak/
137 | venv.bak/
138 |
139 | # Spyder project settings
140 | .spyderproject
141 | .spyproject
142 |
143 | # Rope project settings
144 | .ropeproject
145 |
146 | # mkdocs documentation
147 | /site
148 |
149 | # mypy
150 | .mypy_cache/
151 | .dmypy.json
152 | dmypy.json
153 |
154 | # Pyre type checker
155 | .pyre/
156 |
157 | # pytype static type analyzer
158 | .pytype/
159 |
160 | # Cython debug symbols
161 | cython_debug/
162 |
163 | # PyCharm
164 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
165 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
166 | # and can be added to the global gitignore or merged into this file. For a more nuclear
167 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
168 | #.idea/
169 |
170 | # PyPI configuration file
171 | .pypirc
172 |
173 | # Temp folders
174 | data/
175 | wandb/
176 | logs/
177 | eval_results/
178 | results/
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # open-r1-text2graph
2 |
3 | The goal of this project is to reproduce the **DeepSeek R1** training scheme, particularly for **text-to-graph** information extraction.
4 |
5 | The project is built on top of, and inspired by, Hugging Face [Open-R1](https://github.com/huggingface/open-r1/tree/main) and [trl](https://github.com/huggingface/trl/tree/main/trl).
6 |
7 | ### Structure
8 | The project currently consists of the following components:
9 | * `src` contains the scripts to train a model using GRPO or supervised learning, as well as data generation scripts.
10 |   * `grpo.py` - trains a model using GRPO with text-to-graph-related reward functions;
11 |   * `train_supervised.py` - trains a model in a supervised manner for text-to-graph extraction;
12 |   * `generate.py` - generates thinking chains given input text and extracted JSON;
13 |
14 | ### Pipeline
15 |
16 | 
17 |
18 | The training process consists of three major stages: **synthetic data generation, supervised training, and reinforcement learning (RL) training**. Each of these stages plays a crucial role in improving the model’s ability to perform structured information extraction.
19 |
20 | 1. **Synthetic Data Generation**
21 |
22 | To bootstrap the process, we start with **data collection**, where we gather diverse text sources relevant to our target domain. The **text-to-graph** generation step, powered by **Llama 70B** structured generation, converts unstructured text into graph-based representations. However, this step is imperfect, and therefore, selecting and augmenting data becomes essential to filter out low-quality extractions and enrich the dataset with more diverse structures.
23 |
24 | Additionally, we take the JSON data produced by structured generation and feed it, together with the source text, into **DeepSeek-R1 Llama 70B** to generate a chain of thought that explains the extraction process.
25 |
26 | We experimented with both thinking-enabled and thinking-disabled modes and found that small models struggle to discover some interesting and important thinking strategies.
27 |
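To make the chain-of-thought generation step more concrete, here is a minimal sketch of how the source text and the pre-extracted JSON can be combined into a single prompt for the reasoning model. It mirrors the logic of `src/generate.py` in simplified form; the system prompt wording and the helper name `build_cot_prompt` are illustrative, and the model name is the default used in that script.

```python
from transformers import AutoTokenizer

# Reasoning model used for chain-of-thought generation (default from src/generate.py).
MODEL = "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B"
tokenizer = AutoTokenizer.from_pretrained(MODEL)

def build_cot_prompt(text: str, extracted_json: str) -> str:
    """Combine the source text and the ready-made JSON extraction into one chat prompt."""
    messages = [
        {"role": "system", "content": (
            "You are given a text and a ready-made JSON extraction. "
            "Generate the reasoning process that would lead to this JSON, "
            "encapsulated in <think>...</think> tags; return only the reasoning."
        )},
        {"role": "user", "content": f"{text}\n\nHere is the JSON output: {extracted_json}"},
    ]
    return tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
```
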
28 | 2. **Supervised Training**
29 |
30 | Before starting reinforcement learning, and considering that we use small models, additional supervised training is required to push the model to return data in the right format. We used only 1k examples for this purpose.
31 |
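For reference, below is a minimal sketch of what a single supervised training record could look like, inferred from how `src/train_supervised.py` reads its data (`prompt` holds a list of chat messages, `solution` holds the target string). The texts and the graph in this example are invented purely for illustration.

```python
# Illustrative training record with the fields expected by src/train_supervised.py.
example = {
    "prompt": [
        {"role": "system", "content": "Extract entities and relations and return them as JSON."},
        {"role": "user", "content": "Acme Corp acquired Widget Inc in 2021."},
    ],
    "solution": (
        "<think>Acme Corp and Widget Inc are companies; 'acquired' links them.</think>\n"
        '{"entities": [{"type": "company", "text": "Acme Corp", "id": 0}, '
        '{"type": "company", "text": "Widget Inc", "id": 1}], '
        '"relations": [{"head": "Acme Corp", "tail": "Widget Inc", "type": "acquired"}]}'
    ),
}
```
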
32 | 3. **Reinforcement Learning with GRPO**
33 |
34 | Supervised training alone does not fully solve the problem, especially when it comes to conditioning model outputs on predefined entity and relation types. To address this, we employ **Group Relative Policy Optimization (GRPO)** for reinforcement learning.
35 |
36 | * **Format reward** ensures that the output follows a structured format, with the reasoning encapsulated in the respective tag (when thinking mode is enabled).
37 | * **JSON reward** validates that the output is well-formed, machine-readable JSON and that its structure aligns with the desired format.
38 | * **F1 reward** evaluates the accuracy of extracted entities and relations by comparing them to ground truth graphs.
39 |
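As a small worked example of the F1 reward, the snippet below scores a toy prediction against a toy ground truth in the same way `src/grpo.py` does for entities, by comparing `text_type` strings as sets; the entities themselves are made up for illustration.

```python
# Toy illustration of the entity F1 reward computed in src/grpo.py.
def compute_f1(pred, true):
    tp = len(pred & true)                      # entities predicted correctly
    precision = tp / len(pred) if pred else 0
    recall = tp / len(true) if true else 0
    return 2 * precision * recall / (precision + recall) if precision + recall else 0

pred = {"Acme Corp_company", "2021_date"}           # model prediction (made up)
true = {"Acme Corp_company", "Widget Inc_company"}  # ground truth (made up)
print(compute_f1(pred, true))  # 0.5, used directly as the reward
```
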
40 | Below you can see how the different rewards change over training steps in one of our experiments.
41 |
42 | 
43 |
44 |
45 | ### Try Model
46 |
47 | You can try one of the models fine-tuned with this framework.
48 |
49 | ```python
50 | from transformers import AutoModelForCausalLM, AutoTokenizer
51 |
52 | model_name = "Ihor/Text2Graph-R1-Qwen2.5-0.5b"
53 |
54 | model = AutoModelForCausalLM.from_pretrained(
55 |     model_name,
56 |     torch_dtype="auto",
57 |     device_map="auto"
58 | )
59 | tokenizer = AutoTokenizer.from_pretrained(model_name)
60 |
61 | text = """Your text here..."""
62 | prompt = "Analyze this text, identify the entities, and extract meaningful relationships as per given instructions:{}"
63 | messages = [
64 |     {"role": "system", "content": (
65 |         "You are an assistant trained to process any text and extract named entities and relations from it. "
66 |         "Your task is to analyze user-provided text, identify all unique and contextually relevant entities, and infer meaningful relationships between them. "
67 |         "Output the annotated data in JSON format, structured as follows:\n\n"
68 |         """{"entities": [{"type": "entity_type_0", "text": "entity_0", "id": 0}, {"type": "entity_type_1", "text": "entity_1", "id": 1}], "relations": [{"head": "entity_0", "tail": "entity_1", "type": "re_type_0"}]}"""
69 |     )},
70 |     {"role": "user", "content": prompt.format(text)}
71 | ]
72 | text = tokenizer.apply_chat_template(
73 |     messages,
74 |     tokenize=False,
75 |     add_generation_prompt=True
76 | )
77 | model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
78 |
79 | generated_ids = model.generate(
80 |     **model_inputs,
81 |     max_new_tokens=512
82 | )
83 | generated_ids = [
84 |     output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
85 | ]
86 |
87 | response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
88 | ```
89 |
90 | ### Blog
91 |
92 | Read more details about the project in the following [blog](https://huggingface.co/blog/Ihor/replicating-deepseek-r1-for-information-extraction).
--------------------------------------------------------------------------------
/assets/pipeline.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Ingvarstep/open-r1-text2graph/ba2c3f1628f2e2cc2941b916fc313d6288cc1416/assets/pipeline.png
--------------------------------------------------------------------------------
/assets/rewards.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Ingvarstep/open-r1-text2graph/ba2c3f1628f2e2cc2941b916fc313d6288cc1416/assets/rewards.png
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | trl==0.14.0
2 | transformers==4.48.1
3 | accelerate==1.3.0
4 | vllm==0.7.0
5 | bitsandbytes==0.45.1
6 | datasets==3.1.0
7 | peft==0.14.0
--------------------------------------------------------------------------------
/src/generate.py:
--------------------------------------------------------------------------------
1 | import re
2 | import os
3 | import json
4 | import torch
5 | import copy
6 | import random
7 | import argparse
8 | from tqdm import tqdm
9 |
10 | from transformers import AutoTokenizer
11 | from vllm import LLM, SamplingParams
12 |
13 | ZERO_SHOT_MESSAGES = [
14 |     {
15 |         "role": "system",
16 |         "content": (
17 |             "You are an assistant trained to process any text and extract named entities and relations from it. "
18 |             "Your task is to analyze user-provided text, identify all unique and contextually relevant entities, and infer meaningful relationships between them. "
19 |             "Additionally, you will get ready extracted results, and your task is to generate the thinking process as you would without knowing the final JSON output. "
20 |             "You need to formulate your reasoning process and encapsulate it in a <think>...</think> tag; this is the only thing you return."
21 |         ),
22 |     },
23 |
24 | ]
25 |
26 | ZERO_SHOT_PROMPT = """Analyze this text and JSON output and produce your thinking"""
27 |
28 |
29 | def create_chat(prompt, text):
30 |     return prompt.format(text)
31 |
32 | def generate_response(llm, chats, sampling_params):
33 |     responses = llm.generate(chats, sampling_params, use_tqdm=False)
34 |     return responses
35 |
36 | def process_input(example):
37 |     messages = copy.deepcopy(ZERO_SHOT_MESSAGES)
38 |     text = messages[-1]['content']
39 |     text += "Here is the JSON output:"
40 |     solution = example['solution']
41 |     text += solution
42 |     messages.append({
43 |         "role": 'user',
44 |         "content": text
45 |     })
46 |     prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
47 |     return prompt
48 |
49 | def generate_dataset(examples, llm, sampling_params, batch_size=8, max_lines=None):
50 |     batch_chats = []
51 |     batch_examples = []
52 |     max_lines = min(max_lines or len(examples), len(examples))  # cap at the dataset size
53 |     final_results = []
54 |     for i in tqdm(range(0, max_lines)):
55 |         example = examples[i]
56 |
57 |         batch_examples.append(example)
58 |         chat = process_input(example)
59 |         batch_chats.append(chat)
60 |
61 |         if len(batch_chats) >= batch_size:
62 |             try:
63 |                 responses = generate_response(llm, batch_chats, sampling_params)
64 |             except Exception as err:
65 |                 print(err)
66 |                 continue
67 |
68 |             batch_results = [{'prompt': batch_examples[j]['prompt'],
69 |                               'solution': response.outputs[0].text + '\n' + batch_examples[j]['solution']}
70 |                              for j, response in enumerate(responses)]
71 |
72 |             final_results.extend(batch_results)
73 |
74 |             batch_chats = []
75 |             batch_examples = []
76 |
77 |     return final_results
78 |
79 |
80 | if __name__ == '__main__':
81 |     parser = argparse.ArgumentParser()
82 |     parser.add_argument('--data_path', type=str, default="data/text2graph.json")
83 |     parser.add_argument('--save_path', type=str, default="data/text2graph_with_thinking.json")
84 |     parser.add_argument('--model', type=str, default="deepseek-ai/DeepSeek-R1-Distill-Qwen-7B")
85 |     parser.add_argument('--quantization', type=str, default="fp8")
86 |     parser.add_argument('--max_examples', type=int, default=500)
87 |     parser.add_argument('--batch_size', type=int, default=8)
88 |     parser.add_argument('--temperature', type=float, default=0.75)
89 |     args = parser.parse_args()
90 |
91 |     with open(args.data_path, 'r') as f:
92 |         texts = json.load(f)
93 |     random.shuffle(texts)
94 |     print('Texts count: ', len(texts))
95 |
96 |     llm = LLM(model=args.model,
97 |               max_model_len=8129,
98 |               tensor_parallel_size=1, dtype="half",
99 |               gpu_memory_utilization=0.9, quantization=args.quantization)
100 |
101 |     sampling_params = SamplingParams(temperature=args.temperature, repetition_penalty=1.1, top_k=100, max_tokens=4096, top_p=0.8, stop="")
102 |
103 |     tokenizer = AutoTokenizer.from_pretrained(args.model)
104 |
105 |     results = generate_dataset(texts, llm, sampling_params,
106 |                                batch_size=args.batch_size, max_lines=args.max_examples)
107 |
108 |     with open(args.save_path, 'w') as f:
109 |         json.dump(results, f, indent=1)
--------------------------------------------------------------------------------
/src/grpo.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import re
16 | import json
17 | import random
18 | from dataclasses import dataclass, field
19 |
20 | from datasets import load_dataset
21 |
22 | from trl import GRPOConfig, GRPOTrainer, ModelConfig, ScriptArguments, TrlParser, get_peft_config
23 |
24 |
25 | @dataclass
26 | class GRPOScriptArguments(ScriptArguments):
27 |     """
28 |     Script arguments for the GRPO training script.
29 |
30 |     Args:
31 |         reward_funcs (`list[str]`):
32 |             List of reward functions. Possible values: 'format', 'json_consistency', 'json_structure', 'f1_ents', 'f1_rels'.
33 |     """
34 |
35 |     reward_funcs: list[str] = field(
36 |         default_factory=lambda: ['format', 'json_consistency', 'json_structure', 'f1_ents', 'f1_rels'],
37 |         metadata={"help": "List of reward functions. Possible values: 'format', 'json_consistency', 'json_structure', 'f1_ents', 'f1_rels'"},
38 |     )
39 |
40 | def extract_json_from_text(text):
41 |     json_start = 0
42 |     json_end = 0
43 |     close_brace_count = 0
44 |     extracted_jsons = []
45 |     for idx, char in enumerate(text):
46 |         if char == '{':
47 |             if close_brace_count == 0:
48 |                 json_start = idx
49 |             close_brace_count += 1
50 |         elif char == '}':
51 |             close_brace_count -= 1
52 |             if close_brace_count == 0:
53 |                 json_end = idx + 1
54 |                 extracted_json = text[json_start:json_end]
55 |                 try:
56 |                     extracted_jsons.append(json.loads(extracted_json))
57 |                 except json.JSONDecodeError:
58 |                     pass
59 |     return extracted_jsons
60 |
61 | def validate_json_structure(data):
62 |     required_keys = {"entities", "relations"}
63 |     entity_required_keys = {"id", "text", "type"}
64 |     relation_required_keys = {"head", "tail", "type"}
65 |
66 |     if not isinstance(data, dict) or not required_keys.issubset(data.keys()):
67 |         return False
68 |
69 |     if not isinstance(data["entities"], list):
70 |         return False
71 |
72 |     for entity in data["entities"]:
73 |         if not isinstance(entity, dict) or not entity_required_keys.issubset(entity.keys()):
74 |             return False
75 |         if not isinstance(entity["id"], int) or not isinstance(entity["text"], str) or not isinstance(entity["type"], str):
76 |             return False
77 |
78 |     if not isinstance(data["relations"], list):
79 |         return False
80 |
81 |     for relation in data["relations"]:
82 |         if not isinstance(relation, dict) or not relation_required_keys.issubset(relation.keys()):
83 |             return False
84 |         if not isinstance(relation["head"], str) or not isinstance(relation["tail"], str) or not isinstance(relation["type"], str):
85 |             return False
86 |
87 |     return True
88 |
89 | def compute_f1(pred_list, true_list):
90 |     tp = len(set(pred_list) & set(true_list))
91 |     fp = len(set(pred_list) - set(true_list))
92 |     fn = len(set(true_list) - set(pred_list))
93 |
94 |     # Compute precision, recall, and F1 score
95 |     precision = tp / (tp + fp) if tp + fp > 0 else 0
96 |     recall = tp / (tp + fn) if tp + fn > 0 else 0
97 |     f1 = 2 * precision * recall / (precision + recall) if precision + recall > 0 else 0
98 |     return f1
99 |
100 | def get_entities_f1_score(pred_entities, true_entities):
101 |     pred_list = {f"{entity['text']}_{entity['type']}" for entity in pred_entities}
102 |     true_list = {f"{entity['text']}_{entity['type']}" for entity in true_entities}
103 |     f1 = compute_f1(pred_list, true_list)
104 |     return f1
105 |
106 | def get_relations_f1_score(pred_relations, true_relations):
107 |     pred_list = {f"{rel['head']}_{rel['tail']}_{rel['type']}" for rel in pred_relations}
108 |     true_list = {f"{rel['head']}_{rel['tail']}_{rel['type']}" for rel in true_relations}
109 |     f1 = compute_f1(pred_list, true_list)
110 |     return f1
111 |
112 | def json_consistency_reward(completions, solution, **kwargs):
113 |     contents = [completion[0]["content"] for completion in completions]
114 |     rewards = []
115 |     for content, sol in zip(contents, solution):
116 |         extracted_jsons = extract_json_from_text(content)
117 |         if len(extracted_jsons)==1:
118 |             rewards.append(0.1)
119 |         else:
120 |             rewards.append(0.0)
121 |     return rewards
122 |
123 | def json_structure_reward(completions, solution, **kwargs):
124 |     contents = [completion[0]["content"] for completion in completions]
125 |     rewards = []
126 |     for content, sol in zip(contents, solution):
127 |         extracted_jsons = extract_json_from_text(content)
128 |         if len(extracted_jsons)==1:
129 |             extracted_json = extracted_jsons[0]
130 |             val = validate_json_structure(extracted_json)
131 |             if val:
132 |                 rewards.append(0.1)
133 |             else:
134 |                 rewards.append(0.0)
135 |         else:
136 |             rewards.append(0.0)
137 |     return rewards
138 |
139 | def f1_entities_reward(completions, solution, **kwargs):
140 |     contents = [completion[0]["content"] for completion in completions]
141 |     rewards = []
142 |     for content, sol in zip(contents, solution):
143 |         extracted_jsons_pred = extract_json_from_text(content)
144 |         extracted_jsons_true = extract_json_from_text(sol)
145 |
146 |         if len(extracted_jsons_pred)==1 and len(extracted_jsons_true)==1:
147 |             json_pred = extracted_jsons_pred[0]
148 |             json_true = extracted_jsons_true[0]
149 |
150 |             f1_reward = 0
151 |             try:
152 |                 f1_reward += get_entities_f1_score(json_pred['entities'], json_true['entities'])
153 |             except:
154 |                 pass
155 |
156 |             rewards.append(f1_reward)
157 |         else:
158 |             rewards.append(0)
159 |     return rewards
160 |
161 | def f1_relations_reward(completions, solution, **kwargs):
162 |     contents = [completion[0]["content"] for completion in completions]
163 |     rewards = []
164 |     for content, sol in zip(contents, solution):
165 |         extracted_jsons_pred = extract_json_from_text(content)
166 |         extracted_jsons_true = extract_json_from_text(sol)
167 |
168 |         if len(extracted_jsons_pred)==1 and len(extracted_jsons_true)==1:
169 |             json_pred = extracted_jsons_pred[0]
170 |             json_true = extracted_jsons_true[0]
171 |
172 |             f1_reward = 0
173 |             try:
174 |                 f1_reward += get_relations_f1_score(json_pred['relations'], json_true['relations'])
175 |             except:
176 |                 pass
177 |
178 |             rewards.append(f1_reward)
179 |         else:
180 |             rewards.append(0)
181 |     return rewards
182 |
183 | def format_reward(completions, **kwargs):
184 |     """Reward function that checks if the completion has a specific format."""
185 |     completion_contents = [completion[0]["content"] for completion in completions]
186 |     matches = [("<think>" in content and "</think>" in content) for content in completion_contents]
187 |     return [1.0 if match else 0.0 for match in matches]
188 |
189 | reward_funcs_registry = {
190 |     "format": format_reward,
191 |     "json_consistency": json_consistency_reward,
192 |     "json_structure": json_structure_reward,
193 |     "f1_ents": f1_entities_reward,
194 |     "f1_rels": f1_relations_reward,
195 |
196 | }
197 |
198 | def main(script_args, training_args, model_args):
199 |     # Get reward functions
200 |     reward_funcs = [reward_funcs_registry[func] for func in script_args.reward_funcs]
201 |
202 |     # Load the dataset
203 |     with open(script_args.dataset_name, 'r', encoding='utf-8') as f:
204 |         dataset = json.load(f)
205 |     random.shuffle(dataset)
206 |
207 |     train_dataset = dataset[:int(len(dataset)*0.8)]
208 |     test_dataset = dataset[int(len(dataset)*0.8):]
209 |
210 |     print(train_dataset[0])
211 |
212 |     # Initialize the GRPO trainer
213 |     trainer = GRPOTrainer(
214 |         model=model_args.model_name_or_path,
215 |         reward_funcs=reward_funcs,
216 |         args=training_args,
217 |         train_dataset=train_dataset,
218 |         eval_dataset=test_dataset if training_args.eval_strategy != "no" else None,
219 |         peft_config=get_peft_config(model_args),
220 |     )
221 |
222 |     # Train and push the model to the Hub
223 |     trainer.train()
224 |
225 |     # Save and push to hub
226 |     trainer.save_model(training_args.output_dir)
227 |
228 |
229 | if __name__ == "__main__":
230 |     parser = TrlParser((GRPOScriptArguments, GRPOConfig, ModelConfig))
231 |     script_args, training_args, model_args = parser.parse_args_and_config()
232 |     main(script_args, training_args, model_args)
233 |
--------------------------------------------------------------------------------
/src/train_supervised.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import copy
4 | import logging
5 | import argparse
6 | from tqdm import tqdm
7 |
8 | import bitsandbytes as bnb
9 | import torch
10 | from torch.utils.data import Dataset
11 |
12 | from peft import LoraConfig
13 | from trl import SFTTrainer
14 | from transformers import (
15 |     AutoTokenizer, AutoModelForCausalLM, TrainingArguments, BitsAndBytesConfig
16 | )
17 |
18 | # Setup logging
19 | logging.basicConfig(
20 |     format='%(asctime)s - %(levelname)s - %(message)s',
21 |     level=logging.INFO
22 | )
23 | logger = logging.getLogger(__name__)
24 |
25 | # Define the dataset class
26 | class Text2JSONDataset(Dataset):
27 |     def __init__(self, data, tokenizer, max_length=1024):
28 |         self.tokenizer = tokenizer
29 |         self.max_length = max_length
30 |         self.dataset = data
31 |
32 |     def __len__(self):
33 |         return len(self.dataset)
34 |
35 |     def __getitem__(self, idx):
36 |         item = self.dataset[idx]
37 |         chat = copy.deepcopy(item['prompt'])  # copy so repeated access doesn't mutate the stored chat
38 |
39 |         output = item['solution']
40 |
41 |         chat.extend([
42 |             {"role": "assistant", "content": str(output)}
43 |         ])
44 |
45 |         input_ = self.tokenizer.apply_chat_template(
46 |             chat,
47 |             tokenize=False,
48 |             add_generation_prompt=False
49 |         )
50 |
51 |         inputs = self.tokenizer(
52 |             input_, return_tensors="pt", max_length=self.max_length, truncation=True, padding='max_length'
53 |         )
54 |
55 |         input_ids = inputs["input_ids"].squeeze(0)
56 |         attention_mask = inputs["attention_mask"].squeeze(0)
57 |         labels_ids = input_ids.clone()
58 |
59 |         return {
60 |             'input_ids': input_ids,
61 |             'attention_mask': attention_mask,
62 |             'labels': labels_ids
63 |         }
64 |
65 | # Helper function to find linear module names
66 | def find_all_linear_names(model, quantize=False):
67 |     cls = bnb.nn.Linear4bit if quantize else torch.nn.Linear
68 |
69 |     lora_module_names = set()
70 |     for name, module in model.named_modules():
71 |         if isinstance(module, cls):
72 |             names = name.split('.')
73 |             lora_module_names.add(names[0] if len(names) == 1 else names[-1])
74 |     if 'lm_head' in lora_module_names: # needed for 16 bit
75 |         lora_module_names.remove('lm_head')
76 |     return list(lora_module_names)
77 |
78 | # Main function
79 | def main(args):
80 |     logger.info("Starting script with configuration: %s", args)
81 |
82 |     QUANTIZE = args.quantize
83 |     USE_LORA = args.use_lora
84 |
85 |     model_path = args.model_path
86 |
87 |     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
88 |
89 |     if QUANTIZE:
90 |         bnb_config = BitsAndBytesConfig(
91 |             load_in_4bit=True,
92 |             bnb_4bit_use_double_quant=False,
93 |             bnb_4bit_quant_type="nf4",
94 |             bnb_4bit_compute_dtype="float16",
95 |         )
96 |     else:
97 |         bnb_config = None
98 |
99 |     model = AutoModelForCausalLM.from_pretrained(
100 |         model_path,
101 |         device_map=device,
102 |         # torch_dtype=torch.bfloat16,
103 |         quantization_config=bnb_config,
104 |         trust_remote_code=True,
105 |         token=args.hf_token,
106 |         # attn_implementation="flash_attention_2"
107 |     )
108 |
109 |     if USE_LORA:
110 |         modules = find_all_linear_names(model, quantize=QUANTIZE)
111 |
112 |         peft_config = LoraConfig(
113 |             lora_alpha=32,
114 |             lora_dropout=0.1,
115 |             r=64,
116 |             bias="none",
117 |             task_type="CAUSAL_LM",
118 |             target_modules=modules,
119 |         )
120 |     else:
121 |         peft_config = None
122 |
123 |     model.config.use_cache = False
124 |     model.config.pretraining_tp = 1
125 |
126 |     tokenizer = AutoTokenizer.from_pretrained(model_path)
127 |     tokenizer.pad_token_id = tokenizer.eos_token_id
128 |     # tokenizer.chat_template = CHAT_TEMPLATE
129 |
130 |     with open(args.data_path, encoding='utf-8') as f:
131 |         data = json.load(f)
132 |     train_data = data[:int(len(data)*0.8)]
133 |     test_data = data[int(len(data)*0.8):]
134 |     train_dataset = Text2JSONDataset(train_data, tokenizer, max_length=args.max_length)
135 |     test_dataset = Text2JSONDataset(test_data, tokenizer, max_length=args.max_length)
136 |
137 |     logger.info("Dataset lengths - Train: %d, Test: %d", len(train_dataset), len(test_dataset))
138 |
139 |     training_arguments = TrainingArguments(
140 |         output_dir=args.output_dir,
141 |         num_train_epochs=args.num_train_epochs,
142 |         per_device_train_batch_size=args.batch_size,
143 |         gradient_accumulation_steps=args.gradient_accumulation_steps,
144 |         gradient_checkpointing=args.gradient_checkpointing,
145 |         optim="paged_adamw_32bit",
146 |         logging_steps=args.logging_steps,
147 |         learning_rate=args.learning_rate,
148 |         weight_decay=args.weight_decay,
149 |         fp16=args.fp16,
150 |         bf16=args.bf16,
151 |         max_grad_norm=args.max_grad_norm,
152 |         max_steps=args.max_steps,
153 |         warmup_ratio=args.warmup_ratio,
154 |         group_by_length=False,
155 |         lr_scheduler_type=args.lr_scheduler_type,
156 |         save_total_limit=3,
157 |         report_to="none",
158 |         eval_strategy="steps",
159 |         eval_steps=args.eval_steps,
160 |         save_steps=args.save_steps,
161 |     )
162 |
163 |     trainer = SFTTrainer(
164 |         model=model,
165 |         args=training_arguments,
166 |         train_dataset=train_dataset,
167 |         eval_dataset=test_dataset,
168 |         peft_config=peft_config,
169 |         tokenizer=tokenizer,
170 |     )
171 |
172 |     trainer.train()
173 |
174 |     if peft_config is not None:
175 |         logger.info("Merging LoRA weights into the base model (since use_lora=True).")
176 |         trainer.model = trainer.model.merge_and_unload()
177 |         trainer.model.save_pretrained(os.path.join(args.output_dir, 'merged'))
178 |         tokenizer.save_pretrained(os.path.join(args.output_dir, 'merged'))
179 |
180 | if __name__ == "__main__":
181 |     parser = argparse.ArgumentParser(description="Text2JSON Dataset Training Script")
182 |
183 |     parser.add_argument('--model_path', type=str, required=True, help="Path to the model.")
184 |     parser.add_argument('--data_path', type=str, required=True, help="Path to the training dataset.")
185 |     parser.add_argument('--output_dir', type=str, required=True, help="Directory to save trained models.")
186 |     parser.add_argument('--hf_token', type=str, required=False, help="Hugging Face authentication token.")
187 |     parser.add_argument('--max_length', type=int, default=4096, help="Maximum sequence length.")
188 |     parser.add_argument('--num_train_epochs', type=int, default=2, help="Number of training epochs.")
189 |     parser.add_argument('--batch_size', type=int, default=2, help="Training batch size.")
190 |     parser.add_argument('--gradient_accumulation_steps', type=int, default=1, help="Gradient accumulation steps.")
191 |     parser.add_argument('--gradient_checkpointing', action='store_true', help="Enable gradient checkpointing.")
192 |     parser.add_argument('--learning_rate', type=float, default=5e-6, help="Learning rate.")
193 |     parser.add_argument('--weight_decay', type=float, default=0.01, help="Weight decay.")
194 |     parser.add_argument('--fp16', action='store_true', help="Enable FP16 training.")
195 |     parser.add_argument('--bf16', action='store_true', help="Enable BF16 training.")
196 |     parser.add_argument('--max_grad_norm', type=float, default=0.9, help="Maximum gradient norm.")
197 |     parser.add_argument('--max_steps', type=int, default=-1, help="Maximum training steps.")
198 |     parser.add_argument('--warmup_ratio', type=float, default=0.1, help="Warmup ratio for learning rate scheduler.")
199 |     parser.add_argument('--lr_scheduler_type', type=str, default="cosine", help="Learning rate scheduler type.")
200 |     parser.add_argument('--logging_steps', type=int, default=1, help="Logging steps.")
201 |     parser.add_argument('--eval_steps', type=int, default=10000, help="Evaluation steps.")
202 |     parser.add_argument('--save_steps', type=int, default=1000, help="Save steps.")
203 |     parser.add_argument('--quantize', action='store_true', help="Enable quantization.")
204 |     parser.add_argument('--use_lora', action='store_true', help="Enable LoRA training.")
205 |
206 |     args = parser.parse_args()
207 |
208 |     main(args)
--------------------------------------------------------------------------------