├── .gitignore
├── README.md
├── assets
│   ├── apr-paper.pdf
│   └── apr.png
├── src
│   ├── __init__.py
│   ├── countdown_utils.py
│   ├── eval
│   │   ├── eval_apr.py
│   │   ├── eval_sosp.py
│   │   └── parallel_infernce_utils.py
│   └── utils.py
├── supervised-jax
│   ├── README.md
│   ├── launcher.py
│   ├── llama3_train
│   │   ├── .gitignore
│   │   ├── README.md
│   │   ├── gpt2_train_script.py
│   │   ├── llama_train
│   │   │   ├── __init__.py
│   │   │   ├── gpt2.py
│   │   │   ├── llama3.py
│   │   │   ├── optimizer.py
│   │   │   ├── serve.py
│   │   │   ├── splash.py
│   │   │   └── utils.py
│   │   ├── llama_train_script.py
│   │   ├── requirements.txt
│   │   └── setup.py
│   ├── scripts
│   │   ├── hs-v3.sh
│   │   ├── hsp-v3.sh
│   │   ├── sos-v3.sh
│   │   ├── training_hs-v2_llama.sh
│   │   ├── training_hsp-v2_llama.sh
│   │   ├── training_sos_llama.sh
│   │   ├── training_sos_llama_1ep.sh
│   │   ├── training_sos_llama_gpt2tok.sh
│   │   ├── training_sos_llama_gpt2tok_1ep.sh
│   │   └── training_sos_xiuyu.sh
│   ├── to_dataset.ipynb
│   ├── training_hsp.sh
│   ├── training_hsp_2x.sh
│   ├── training_run.sh
│   └── training_run_test.sh
└── tinyrl
    ├── README.md
    ├── configs
    │   ├── apr.yaml
    │   ├── apr_cond10.yaml
    │   └── sosp.yaml
    ├── data
    │   └── .gitkeep
    ├── requirements.txt
    ├── rollout
    │   ├── apr_utils.py
    │   └── sos_utils.py
    ├── trainer.py
    └── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | wandb/
2 | outputs/
3 | data/
4 | # Byte-compiled / optimized / DLL files
5 | __pycache__/
6 | *.py[cod]
7 | *$py.class
8 |
9 | # C extensions
10 | *.so
11 |
12 | # Distribution / packaging
13 | .Python
14 | build/
15 | develop-eggs/
16 | dist/
17 | downloads/
18 | eggs/
19 | .eggs/
20 | lib/
21 | lib64/
22 | parts/
23 | sdist/
24 | var/
25 | wheels/
26 | share/python-wheels/
27 | *.egg-info/
28 | .installed.cfg
29 | *.egg
30 | MANIFEST
31 |
32 | # PyInstaller
33 | # Usually these files are written by a python script from a template
34 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
35 | *.manifest
36 | *.spec
37 |
38 | # Installer logs
39 | pip-log.txt
40 | pip-delete-this-directory.txt
41 |
42 | # Unit test / coverage reports
43 | htmlcov/
44 | .tox/
45 | .nox/
46 | .coverage
47 | .coverage.*
48 | .cache
49 | nosetests.xml
50 | coverage.xml
51 | *.cover
52 | *.py,cover
53 | .hypothesis/
54 | .pytest_cache/
55 | cover/
56 |
57 | # Translations
58 | *.mo
59 | *.pot
60 |
61 | # Django stuff:
62 | *.log
63 | local_settings.py
64 | db.sqlite3
65 | db.sqlite3-journal
66 |
67 | # Flask stuff:
68 | instance/
69 | .webassets-cache
70 |
71 | # Scrapy stuff:
72 | .scrapy
73 |
74 | # Sphinx documentation
75 | docs/_build/
76 |
77 | # PyBuilder
78 | .pybuilder/
79 | target/
80 |
81 | # Jupyter Notebook
82 | .ipynb_checkpoints
83 |
84 | # IPython
85 | profile_default/
86 | ipython_config.py
87 |
88 | # pyenv
89 | # For a library or package, you might want to ignore these files since the code is
90 | # intended to run in multiple environments; otherwise, check them in:
91 | # .python-version
92 |
93 | # pipenv
94 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
95 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
96 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
97 | # install all needed dependencies.
98 | #Pipfile.lock
99 |
100 | # UV
101 | # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
102 | # This is especially recommended for binary packages to ensure reproducibility, and is more
103 | # commonly ignored for libraries.
104 | #uv.lock
105 |
106 | # poetry
107 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
108 | # This is especially recommended for binary packages to ensure reproducibility, and is more
109 | # commonly ignored for libraries.
110 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
111 | #poetry.lock
112 |
113 | # pdm
114 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
115 | #pdm.lock
116 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
117 | # in version control.
118 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
119 | .pdm.toml
120 | .pdm-python
121 | .pdm-build/
122 |
123 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
124 | __pypackages__/
125 |
126 | # Celery stuff
127 | celerybeat-schedule
128 | celerybeat.pid
129 |
130 | # SageMath parsed files
131 | *.sage.py
132 |
133 | # Environments
134 | .env
135 | .venv
136 | env/
137 | venv/
138 | ENV/
139 | env.bak/
140 | venv.bak/
141 |
142 | # Spyder project settings
143 | .spyderproject
144 | .spyproject
145 |
146 | # Rope project settings
147 | .ropeproject
148 |
149 | # mkdocs documentation
150 | /site
151 |
152 | # mypy
153 | .mypy_cache/
154 | .dmypy.json
155 | dmypy.json
156 |
157 | # Pyre type checker
158 | .pyre/
159 |
160 | # pytype static type analyzer
161 | .pytype/
162 |
163 | # Cython debug symbols
164 | cython_debug/
165 |
166 | # PyCharm
167 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
168 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
169 | # and can be added to the global gitignore or merged into this file. For a more nuclear
170 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
171 | #.idea/
172 |
173 | # Ruff stuff:
174 | .ruff_cache/
175 |
176 | # PyPI configuration file
177 | .pypirc
178 |
179 | *_ignored*
180 | checkpoints/
181 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Learning Adaptive Parallel Reasoning with Language Models
2 |
3 | Jiayi Pan*, Xiuyu Li*, Long Lian*, Charlie Victor Snell, Yifei Zhou, Adam Yala, Trevor Darrell, Kurt Keutzer, Alane Suhr
4 |
5 | UC Berkeley and UCSF (\*Equal Contribution)
6 |
7 | 📃 Paper • 💻 Code • 🤗 Data & Models
8 |
9 | ![APR overview](assets/apr.png)
10 |
30 | **TL;DR**:
31 | We present Adaptive Parallel Reasoning (APR), a novel framework that enables language models to learn to orchestrate both serialized and parallel computations. APR trains language models to use `spawn()` and `join()` operations through end-to-end supervised training and reinforcement learning, allowing models to dynamically orchestrate their own computational workflows.
32 | APR efficiently distributes compute, reduces latency, overcomes context window limits, and achieves state‑of‑the‑art performance on complex reasoning tasks (e.g., 83.4% vs. 60.0% accuracy at 4K context on Countdown).
33 |
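At inference time this corresponds to a simple spawn/join loop (mirroring `src/eval/eval_apr.py`): the parent thread decodes until it requests child searches, the children decode in parallel, and their results are joined back into the parent's context. The sketch below is purely conceptual; every helper name is a placeholder rather than the model's actual output format.

```python
from concurrent.futures import ThreadPoolExecutor

def apr_search(prefix, decode, wants_children, child_prefixes, join):
    """Conceptual spawn/join loop; all callables are placeholders."""
    trace = decode(prefix)                      # parent thread decodes serially
    while wants_children(trace):                # model asked to spawn child searches
        with ThreadPoolExecutor() as pool:      # children decode in parallel
            children = list(pool.map(decode, child_prefixes(trace)))
        trace = decode(join(trace, children))   # join child results, resume the parent
    return trace
```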
34 | > The full code will be released soon!
35 | ## Data Preparation
36 |
37 | ## Supervised Training
38 | We use TPU-v3-128 for supervised training, with a codebase built on [JAX_llama](https://github.com/Sea-Snell/JAX_llama).
39 |
40 | Please refer to [the instructions](supervised-jax/README.md) for more details.
41 |
42 | ## Reinforcement Learning
43 | We present TinyRL, a simple implementation of the GRPO training framework for our experiments. TinyRL is a lightweight yet performant reinforcement learning library designed to be both easy to use and extend. It integrates with [SGLang](https://github.com/sgl-project/sglang) for efficient rollout. Given the small size of the model we're training, we haven't implemented model parallelism, so it runs on two GPUs: one for training and one for rollout.
44 |
45 | It supports asynchronous, multi-turn, multi-agent rollouts through a general `rollout_fun` interface, with the minimal assumption that your rollout mechanism relies on calling an OpenAI-compatible API endpoint.
46 | ```python
47 | def rollout_fun(server_url, prefix_list, bos_token, temperature=0.5, max_workers=32):
48 | pass
49 | ```
50 |
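For reference, a compatible rollout function could look like the minimal sketch below. It assumes only an OpenAI-style `/completions` endpoint (as served by SGLang); the payload fields, the 4096-token cap, and the return value are illustrative and may differ from the actual implementation in `tinyrl/rollout/`.

```python
import requests
from concurrent.futures import ThreadPoolExecutor

def rollout_fun(server_url, prefix_list, bos_token, temperature=0.5, max_workers=32):
    """Illustrative rollout: complete each prefix via an OpenAI-compatible endpoint."""
    def complete(prefix):
        resp = requests.post(
            f"{server_url}/completions",
            json={
                "model": "model",            # served model name (placeholder)
                "prompt": bos_token + prefix,
                "temperature": temperature,
                "max_tokens": 4096,          # illustrative cap
            },
        )
        resp.raise_for_status()
        return prefix + resp.json()["choices"][0]["text"]

    # Fan the requests out so many rollouts run concurrently against the server.
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        return list(pool.map(complete, prefix_list))
```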
51 | Please refer to [the instructions](tinyrl/README.md) for more details.
52 |
53 | ## Evaluation
54 |
55 | > [!IMPORTANT]
56 | > **For evaluation, SGLang needs to be patched**.
57 | > Remove this check in `python/sglang/srt/managers/tokenizer_manager.py` in your local SGLang repo:
58 | > ```
59 | > # if (
60 | > # obj.sampling_params.get("max_new_tokens") is not None
61 | > # and obj.sampling_params.get("max_new_tokens") + input_token_num
62 | > # >= self.context_len
63 | > # ):
64 | > # raise ValueError(
65 | > # f"Requested token count exceeds the model's maximum context length "
66 | > # f"of {self.context_len} tokens. You requested a total of "
67 | > # f"{obj.sampling_params.get('max_new_tokens') + input_token_num} "
68 | > # f"tokens: {input_token_num} tokens from the input messages and "
69 | > # f"{obj.sampling_params.get('max_new_tokens')} tokens for the "
70 | > # f"completion. Please reduce the number of tokens in the input "
71 | > # f"messages or the completion to fit within the limit."
72 | > # )
73 | > ```
74 | >
75 | > This file is located at [tokenizer_manager.py](https://github.com/sgl-project/sglang/blob/45205d88a08606d5875476fbbbc76815a5107edd/python/sglang/srt/managers/tokenizer_manager.py#L350)
76 |
77 | > [!NOTE]
78 | > sgl-project/sglang#3721 introduces an `--allow-auto-truncate` option that makes this patch unnecessary. Once merged, you can use that directly.
79 |
80 | ### SoS+
81 |
82 | The following command evaluates the SoS+ model on the validation set.
83 | ```bash
 84 | python -m src.eval.eval_sosp --ckpt <ckpt> --temperature <temperature> --batch_size 256 --gens 1 --output_dir <output_dir> --num_gpus 8 --n_samples <n_samples> --budget <budget>
 85 | ```
 86 | Here `<n_samples>` is the number of best-of-N samples at inference time, and `<budget>` is the budget for conditional generation (omit `--budget` when not using a conditioned model). For instance, the following command evaluates the SoS+ model with 8 samples using an unconditioned checkpoint.
87 | ```bash
88 | python -m src.eval.eval_sosp --ckpt Parallel-Reasoning/llama-sosp --temperature 1.0 --batch_size 256 --gens 1 --output_dir results/llama-sosp/ --num_gpus 8 --n_samples 8
89 | ```
90 |
91 | ### APR
92 |
93 | First, you need to start the SGLang server for the target model. For instance:
94 | ```bash
95 | python -m sglang.launch_server --served-model-name model --model-path Parallel-Reasoning/llama-apr_cond10_grpo --port 2346 --dp-size 8
96 | ```
97 |
98 | Then the following command evaluates the APR model on the validation set.
99 | ```bash
100 | python -m src.eval.eval_apr --model_name llama-apr_cond10_grpo --ckpt Parallel-Reasoning/llama-apr_cond10_grpo --temperature 1.0 --budget 10 --use_subcall_cond
101 | ```
102 | This evaluates the APR model with a budget of 10 child threads, using child-thread-count conditioning. Omit `--budget` and `--use_subcall_cond` for unconditioned models.
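Results are saved as a JSON file containing `trajectories`, `trace_dicts`, `ratings`, and `true_ratings`. A minimal sketch for inspecting it afterwards (the path below is a placeholder; see `run_inference_experiment` in `src/eval/eval_apr.py` for the actual naming scheme):

```python
import json
import numpy as np

# Placeholder path; the real file name encodes the model name, temperature, and budget.
with open("results/llama-apr_cond10_grpo/val_apr/results.json") as f:
    results = json.load(f)

print("verified success rate:  ", np.mean(results["true_ratings"]))
print("unverified success rate:", np.mean(results["ratings"]))
```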
103 |
104 |
105 | ## Citation
106 | If you find this work useful in your research, please consider citing:
107 |
108 | ```bibtex
109 | @article{pan2025learning,
110 | title = {Learning Adaptive Parallel Reasoning with Language Models},
111 | author = {Jiayi Pan and Xiuyu Li and Long Lian and Charlie Snell and Yifei Zhou and Adam Yala and Trevor Darrell and Kurt Keutzer and Alane Suhr},
112 | year = {2025},
113 | journal = {arXiv preprint arXiv: 2504.15466}
114 | }
115 | ```
116 |
--------------------------------------------------------------------------------
/assets/apr-paper.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Parallel-Reasoning/APR/8936e8db46bf938242bf5e0a6ebe79ff48ba267a/assets/apr-paper.pdf
--------------------------------------------------------------------------------
/assets/apr.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Parallel-Reasoning/APR/8936e8db46bf938242bf5e0a6ebe79ff48ba267a/assets/apr.png
--------------------------------------------------------------------------------
/src/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Parallel-Reasoning/APR/8936e8db46bf938242bf5e0a6ebe79ff48ba267a/src/__init__.py
--------------------------------------------------------------------------------
/src/eval/eval_apr.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import argparse
4 | import numpy as np
5 | from tqdm import tqdm
6 | from concurrent.futures import ThreadPoolExecutor, as_completed
7 | from termcolor import colored
8 |
9 | from transformers import AutoTokenizer
10 | from litellm import text_completion, APIError
11 | from src.eval.parallel_infernce_utils import (
12 | get_search_result,
13 | get_main_trace_after_sub_search,
14 | get_subsearch_info,
15 | check_solution
16 | )
17 |
18 | def parse_bool(value):
19 | if value.lower() in ['true', '1', 'yes', 'y']:
20 | return True
21 | elif value.lower() in ['false', '0', 'no', 'n']:
22 | return False
23 | else:
24 | raise ValueError(f"Invalid boolean value: {value}")
25 |
26 | # Example command for sglang:
27 | # python -m sglang.launch_server --served-model-name model --model-path Parallel-Reasoning/llama-apr_cond10_grpo --port 2346 --dp-size 8
28 |
29 | # Parse command line arguments
30 | parser = argparse.ArgumentParser(description='Run APR inference experiments')
31 | parser.add_argument('--model_name', type=str, default="llama-apr",
32 | help='Model name')
33 | parser.add_argument("-d", "--data",type=str, default="data/val.json")
34 | parser.add_argument("--ckpt", type=str, help="path to checkpoint")
35 | parser.add_argument('--disable_parallel_inference', action='store_true', default=False,
36 | help='Whether to use parallel inference')
37 | parser.add_argument('--use_subcall_cond', action='store_true', default=False,
38 | help='Whether to use subcall count conditioning instead of token count')
39 | parser.add_argument('--max_workers', type=int, default=16,
40 | help='Maximum number of workers for parallel inference')
41 | parser.add_argument('--max_tokens', type=int, default=4096,
42 | help='Maximum tokens for generation')
43 | parser.add_argument('--temperature', type=float, default=1.0,
44 | help='Temperature for generation')
45 | parser.add_argument('--budget', type=int, default=None,
46 | help='Budget for generation')
47 | parser.add_argument('--output_dir', type=str,
48 | default=None,
49 | help='Directory to save results. If None, defaults to "../results/{model_name}/val_apr"')
50 |
51 | args = parser.parse_args()
52 |
53 | # Set global variables from arguments
54 | MODEL_NAME = args.model_name
55 | PARALLEL_INFERENCE = not args.disable_parallel_inference
56 | USE_SUBCALL_COND = args.use_subcall_cond
57 | MAX_WORKERS = args.max_workers
58 | TEMPERATURE = args.temperature
59 | BUDGET = args.budget
60 | SAVE_DIR = args.output_dir
61 |
62 | # Load validation data
63 | data_path = args.data
64 | with open(data_path, "r") as f:
65 | val_data = json.load(f)
66 |
67 | # Initialize tokenizer
68 | ckpt = args.ckpt
69 | tokenizer = AutoTokenizer.from_pretrained(ckpt)
70 | bos_token = tokenizer.bos_token
71 |
72 | # API configuration
73 | api_base_url = "http://127.0.0.1:2346/v1"
74 | api_key = "api_key"
75 | model_name = "model"
76 | max_tokens = args.max_tokens
77 |
78 | # Configuration for text generation
79 | ADD_ANGLE_BRACKETS = False
80 | ADD_BOS = True
81 |
82 | # BUDGET and TEMPERATURE will be handled outside in bash script
83 |
84 | def add_angle_brackets(text):
85 | lines = text.split('\n')
86 | result_lines = []
87 | for line in lines:
88 | if '>' in line and '<' not in line:
89 | line = '<' + line
90 | result_lines.append(line)
91 | return '\n'.join(result_lines)
92 |
93 | def generate(prefix, stop = [], temperature = 0.0):
94 | if ADD_BOS:
95 | prefix = bos_token + prefix
96 | result = text_completion(
97 | model=f"openai/{model_name}",
98 | prompt=prefix,
99 | api_base=api_base_url,
100 | api_key=api_key,
101 | temperature=temperature,
102 | max_tokens=max_tokens,
103 | stop=stop,
104 | )
105 | text = result['choices'][0]['text']
106 | complete_text = prefix + text
107 | complete_text = complete_text.replace(bos_token, ' ')
108 | if complete_text[0] == ' ':
109 | complete_text = complete_text[1:]
110 | if ADD_ANGLE_BRACKETS:
111 | complete_text = add_angle_brackets(complete_text)
112 | return complete_text, result
113 |
114 | def add_all_calls(trace_dict):
115 | """Recursively collect all call traces from a trace_dict."""
116 | all_calls = []
117 | all_calls += trace_dict['main_calls']
118 | for sub in trace_dict.get('sub_calls', []):
119 | for sub_trace in sub:
120 | all_calls += add_all_calls(sub_trace)
121 | return all_calls
122 |
123 | def calculate_tokens(item, ds_name="apr"):
124 | token_count = 0
125 | if 'apr' in ds_name:
126 | seqs = add_all_calls(item['trace_dict'])
127 |
128 | if len(seqs) > 1:
129 | # Find all sequences that start with "Moving to Node #0"
130 | root_seqs = [seq for seq in seqs if "Moving to Node #0\n" in seq]
131 |
132 | # Sort root sequences by length (shortest first)
133 | root_seqs = sorted(root_seqs, key=len)
134 |
135 | # Calculate total tokens without considering redundancy
136 | total_tokens = 0
137 | # Calculate total token count and track longest sequence in one pass
138 | longest_seq_tokens = 0
139 | for seq in seqs:
140 | tokens = tokenizer.encode(tokenizer.bos_token + seq + tokenizer.eos_token)
141 | total_tokens += len(tokens)
142 | longest_seq_tokens = max(longest_seq_tokens, len(tokens))
143 |
144 | # Calculate redundant tokens between root sequences
145 | redundant_tokens = 0
146 | if len(root_seqs) > 1:
147 | # Find common prefixes between each pair of sequences
148 | for i in range(len(root_seqs) - 1):
149 | j = i + 1
150 | seq1 = root_seqs[i]
151 | seq2 = root_seqs[j]
152 |
153 | # Find common prefix
154 | prefix_len = 0
155 | for k in range(min(len(seq1), len(seq2))):
156 | if seq1[k] == seq2[k]:
157 | prefix_len += 1
158 | else:
159 | break
160 |
161 | if prefix_len > 0:
162 | common_prefix = seq1[:prefix_len]
163 | # Count tokens in this prefix
164 | prefix_tokens = len(tokenizer.encode(common_prefix)) - 2 # Subtract BOS/EOS
165 | redundant_tokens += max(0, prefix_tokens)
166 |
167 | # Final token count is total minus redundant
168 | token_count = total_tokens - redundant_tokens
169 |
170 | item['longest_seq_token_count'] = longest_seq_tokens
171 | sub_calls = [seq for seq in seqs if not "Moving to Node #0\n" in seq]
172 | item['avg_seq_token_count'] = token_count / (len(sub_calls) + 1)
173 | else:
174 | tokens = tokenizer.encode(tokenizer.bos_token + seqs[0] + tokenizer.eos_token)
175 | token_count = len(tokens)
176 | item['longest_seq_token_count'] = token_count
177 | item['avg_seq_token_count'] = token_count
178 | else:
179 | seq = item['search_path']
180 | tokens = tokenizer.encode(tokenizer.bos_token + seq + tokenizer.eos_token)
181 | token_count = len(tokens)
182 |
183 | item['token_count'] = token_count
184 | return item
185 |
186 |
187 | def decode_trace(prefix, temperature):
188 | # we should never let the model generate
189 | # whenever it happens, we replace it with
190 | while True:
191 | trace = generate(prefix, stop = [""], temperature = temperature)
192 | prefix = trace[0]
193 | if trace[1].choices[0].matched_stop == "":
194 | prefix += ""
195 | else:
196 | break
197 | prefix = trace[0]
198 | if prefix.split('\n')[-1] == "":
199 | # TODO: why is this happening?
200 | prefix = prefix[:-1]
201 | return prefix
202 |
203 | if PARALLEL_INFERENCE:
204 | def batch_decode_trace(prefix_list, temperature):
205 | with ThreadPoolExecutor(max_workers=MAX_WORKERS) as executor:
206 | # Submit all tasks and store futures
207 | future_to_prefix = {executor.submit(decode_trace, prefix, temperature): prefix for prefix in prefix_list}
208 |
209 | # Initialize results list with the same length as prefix_list
210 | results = [None] * len(prefix_list)
211 |
212 | # As futures complete, store results in the correct order
213 | for future in as_completed(future_to_prefix):
214 | prefix = future_to_prefix[future]
215 | try:
216 | result = future.result()
217 | original_idx = prefix_list.index(prefix)
218 | results[original_idx] = result
219 | except Exception as e:
220 | print(f"Error processing prefix: {e}")
221 | original_idx = prefix_list.index(prefix)
222 | results[original_idx] = None
223 |
224 | return results
225 | else:
226 | def batch_decode_trace(prefix_list, temperature):
227 | # TODO make this parallel in the future
228 | result_list = []
229 | for prefix in prefix_list:
230 | result_list.append(decode_trace(prefix, temperature))
231 | return result_list
232 |
233 |
234 | def is_calling_subsearch(trace):
235 | return "" in trace.split('\n')[-1]
236 | def call_search(prefix, temperature, budget=None, avg_sub_call=False):
237 | try:
238 | trace_dict = {"main_calls": [], "sub_calls": []}
239 | trace = decode_trace(prefix, temperature)
240 | trace_dict["main_calls"].append(trace)
241 | while is_calling_subsearch(trace):
242 | sub_search_prefix_list, sub_search_nodes = get_subsearch_info(trace, budget, avg_sub_call)
243 | # call sub searchers
244 | # TODO: this assumes we only nest one level of sub searchers
245 | # In the future, we need to support nested sub searchers by recursion
246 | sub_search_traces = batch_decode_trace(sub_search_prefix_list, temperature)
247 | trace_dict["sub_calls"].append([])
248 | for sub_search_trace in sub_search_traces:
249 | trace_dict["sub_calls"][-1].append({"main_calls": [sub_search_trace]})
250 | sub_search_results = [get_search_result(trace) for trace in sub_search_traces]
251 | new_prefix = get_main_trace_after_sub_search(trace, sub_search_nodes, sub_search_results)
252 | trace = decode_trace(new_prefix, temperature)
253 | trace_dict["main_calls"].append(trace)
254 | return get_search_result(trace), trace_dict
255 | except APIError as e:
256 | print(f"Error at call_search: {e}")
257 | raise e
258 | except Exception as e:
259 | print(f"Error at call_search: {e}")
260 | return None, trace_dict
261 |
262 |
263 | def get_prefix(dataset, idx, is_subcall_cond=False, use_budget=None):
264 | dp = dataset[idx]
265 | sub_call_budget = None
266 | prefix = f"Moving to Node #0\nCurrent State: {dp['target']}:{dp['nums']}, Operations: []"
267 |
268 | if is_subcall_cond:
269 | # Just use the provided budget for subcall conditioning
270 | assert use_budget is not None, "Subcall conditioning requires a budget"
271 | subcall_count = use_budget
272 | prefix = f"Sub Call Budget: {subcall_count} " + prefix
273 |
274 | if is_subcall_cond:
275 | # For subcall conditioning, we don't need to set sub_call_budget
276 | sub_call_budget = None
277 | elif sub_call_budget is not None:
278 | sub_call_budget = ((sub_call_budget - 1) // 512 + 1) * 512
279 | else:
280 | sub_call_budget = dp['token_count'] if 'token_count' in dp else None
281 | return prefix, sub_call_budget
282 |
283 |
284 | def run_inference_experiment(
285 | model_name,
286 | val_data,
287 | n_tasks=1000,
288 | temperature=0.0,
289 | budget=None,
290 | use_subcall_cond=False,
291 | parallel_inference=True,
292 | max_workers=16,
293 | save_dir=None
294 | ):
295 | """
296 | Run an inference experiment with the specified configuration.
297 |
298 | Args:
299 | model_name (str): Name of the model to use
300 | val_data (list): Validation data
301 | n_tasks (int): Number of tasks to run
302 | temperature (float): Temperature for generation
303 | budget (int, optional): Token budget. If None, use default from data
304 | use_subcall_cond (bool): Whether to use subcall count conditioning
305 | parallel_inference (bool): Whether to use parallel inference
306 | max_workers (int): Maximum number of workers for parallel inference
307 | save_dir (str, optional): Directory to save results. If None, use default
308 |
309 | Returns:
310 | dict: Results dictionary with trajectories, trace_dicts, ratings, true_ratings
311 | """
312 | os.makedirs(save_dir, exist_ok=True)
313 | base_name = f"{model_name.replace('/', '_')}_{n_tasks}_0_temp_{str(temperature).replace('.','_')}"
314 | if budget is not None:
315 | base_name += f"_budget_{budget}"
316 |
317 | save_path = os.path.join(save_dir, f"{base_name}.json")
318 |
319 | # Check if results already exist
320 | if os.path.exists(save_path):
321 | print(colored(f"Results already exist at {save_path}. Loading...", "green"))
322 | with open(save_path, 'r') as f:
323 | results = json.load(f)
324 |
325 | # Calculate and print metrics
326 | true_ratings = results["true_ratings"]
327 | ratings = results["ratings"]
328 | true_success_rate = np.mean(true_ratings)
329 | success_rate = np.mean(ratings)
330 |
331 | print(f"model: {model_name}, temperature: {temperature}, budget: {budget}\n"
332 | f"success_rate: {true_success_rate:.3f}, unverified_success_rate: {success_rate:.3f}")
333 |
334 | return results
335 |
336 | # If results don't exist, run the experiment
337 | success_count = []
338 | true_success_count = []
339 | logs = []
340 |
341 | pbar = tqdm(range(n_tasks))
342 |
343 | def process_sample(i):
344 | prefix, sub_call_budget = get_prefix(
345 | dataset=val_data, idx=i,
346 | is_subcall_cond=use_subcall_cond,
347 | use_budget=budget,
348 | )
349 | out = call_search(prefix, temperature, budget=None, avg_sub_call=False)
350 | solution = out[0]
351 | true_success = check_solution(prefix, solution) if solution is not None else False
352 | return out, solution is not None, true_success
353 |
354 | if parallel_inference:
355 | with ThreadPoolExecutor(max_workers=max_workers) as executor:
356 | futures = [executor.submit(process_sample, i) for i in range(n_tasks)]
357 |
358 | for i, future in enumerate(as_completed(futures)):
359 | out, success, true_success = future.result()
360 | logs.append(out)
361 | success_count.append(success)
362 | true_success_count.append(true_success)
363 | pbar.update(1)
364 | pbar.set_postfix({
365 | 'true_success_rate': f"{np.mean(true_success_count):.3f}",
366 | 'success_rate': f"{np.mean(success_count):.3f}"
367 | })
368 | else:
369 | for i in pbar:
370 | out, success, true_success = process_sample(i)
371 | logs.append(out)
372 | success_count.append(success)
373 | true_success_count.append(true_success)
374 | pbar.update(1)
375 | pbar.set_postfix({
376 | 'true_success_rate': f"{np.mean(true_success_count):.3f}",
377 | 'success_rate': f"{np.mean(success_count):.3f}"
378 | })
379 |
380 | # Prepare results
381 | trajectories = [x[0] if x[0] is not None else "" for x in logs]
382 | trace_dicts = [x[1] for x in logs]
383 | ratings = [1 if x else 0 for x in success_count]
384 | true_ratings = [1 if x else 0 for x in true_success_count]
385 |
386 | print(colored("Results Summary:", "cyan", attrs=["bold"]))
387 | print(colored(f"Model: ", "yellow") + colored(f"{model_name}", "green") +
388 | colored(f", Temperature: ", "yellow") + colored(f"{temperature}", "green") +
389 | colored(f", Budget: ", "yellow") + colored(f"{budget}", "green"))
390 | # print(colored(f"Unverified Success Rate: ", "yellow") + colored(f"{np.mean(success_count):.4f}", "green"))
391 | print(colored(f"Success Rate: ", "yellow") + colored(f"{np.mean(true_success_count):.4f}", "green"))
392 |
393 | results = {
394 | "trajectories": trajectories,
395 | "trace_dicts": trace_dicts,
396 | "ratings": ratings,
397 | "true_ratings": true_ratings
398 | }
399 |
400 |
401 | os.makedirs(save_dir, exist_ok=True)
402 | base_name = f"{model_name}_{n_tasks}_0_temp_{str(temperature).replace('.','_')}"
403 | if budget is not None:
404 | if use_subcall_cond:
405 | base_name += f"_subcall_budget_{budget}"
406 | else:
407 | base_name += f"_budget_{budget}"
408 |
409 | # Save results
410 | print(f"Saving results to {save_path}")
411 | with open(save_path, 'w') as f:
412 | json.dump(results, f)
413 |
414 | return results
415 |
416 |
417 | if __name__ == "__main__":
418 | print(colored("Model name: ", "cyan", attrs=["bold"]) + colored(MODEL_NAME, "yellow"))
419 | print(colored("Data path: ", "cyan", attrs=["bold"]) + colored(data_path, "yellow"))
420 | print(colored("Checkpoint: ", "cyan", attrs=["bold"]) + colored(ckpt, "yellow"))
421 |
422 | print(colored("Conditions: ", "cyan", attrs=["bold"]) +
423 | colored(f"use_subcall_cond={USE_SUBCALL_COND}", "yellow"))
424 | print(colored("Inference: ", "cyan", attrs=["bold"]) +
425 | colored(f"parallel_inference={PARALLEL_INFERENCE}, "
426 | f"max_workers={MAX_WORKERS}, "
427 | f"max_tokens={max_tokens}", "yellow"))
428 | print(colored("Generation: ", "cyan", attrs=["bold"]) +
429 | colored(f"temperature={TEMPERATURE}, budget={BUDGET}", "yellow"))
430 |
431 | # Set default save directory if not provided
432 | save_directory = SAVE_DIR if SAVE_DIR is not None else f"results/{MODEL_NAME}/val_apr"
433 | print(colored(f"save_dir: {save_directory}", "cyan"))
434 |
435 | run_inference_experiment(
436 | model_name=MODEL_NAME,
437 | val_data=val_data,
438 | temperature=TEMPERATURE,
439 | budget=BUDGET,
440 | use_subcall_cond=USE_SUBCALL_COND,
441 | parallel_inference=PARALLEL_INFERENCE,
442 | max_workers=MAX_WORKERS,
443 | save_dir=save_directory
444 | )
445 |
--------------------------------------------------------------------------------
/src/eval/eval_sosp.py:
--------------------------------------------------------------------------------
1 | import os, sys
2 | import json
3 | import random
4 | import argparse
5 | import tqdm
6 | import numpy as np
7 | from transformers import AutoTokenizer
8 | import sglang as sgl
9 | from src.countdown_utils import *
10 | from src.utils import seed_everything
11 | from termcolor import colored
12 | import re
13 |
14 | parser = argparse.ArgumentParser()
15 | parser.add_argument("--seed", type=int, default=4)
16 | parser.add_argument("--ckpt", type=str, help="path to checkpoint")
17 | parser.add_argument("--tknz", type=str, help="path to tokenizer")
18 | parser.add_argument("--n_samples", type=int, default=1, help="Number of samples for best-of-n evaluation")
19 | parser.add_argument("-d", "--data",type=str, default="data/val.json")
20 | parser.add_argument("--temperature", type=float, default=0.0)
21 | parser.add_argument("--batch_size", type=int, default=64)
22 | parser.add_argument("--max_tokens", type=int, default=131072)
23 | parser.add_argument("--gens", type=int, default=1)
24 | parser.add_argument("--num_gpus", type=int, default=1)
25 | parser.add_argument("--output_dir", type=str, default="results/")
26 | parser.add_argument("--skip-save-results", action="store_true", help="Skip saving results to file")
27 | parser.add_argument("--num_token_cond", action="store_true", help="Use num token condition")
28 | parser.add_argument("--budget", type=int, default=None, help="Set a fixed budget condition")
29 |
30 | # Calculate mean token count
31 | def count_tokens(text, tokenizer):
32 | cleaned_text = re.sub(r'^.*?Current State:', 'Current State:', text)
33 | # Add offset (so it matches the token count in the train/val data json)
34 | return len(tokenizer.encode(cleaned_text)) + 2
35 |
36 | def eval_ll(args):
37 | """
38 | Evaluate the model on the data using sglang
39 | """
40 | with open(args.data, "r") as json_file:
41 | raw_data = json.load(json_file)
42 |
43 | if not args.skip_save_results:
44 | output_dir = os.path.join(args.output_dir, os.path.splitext(os.path.basename(args.data))[0])
45 | os.makedirs(output_dir, exist_ok=True)
46 | base_name = f"{args.ckpt.split('outputs/')[-1].replace('/','_')}_temp_{str(args.temperature).replace('.','_')}"
47 | if args.budget is not None:
48 | base_name += f"_budget_{args.budget}"
49 | if args.n_samples > 1:
50 | base_name += f"_n_samples_{args.n_samples}"
51 | results_file = os.path.join(output_dir, f"{base_name}.json")
52 | print(f"Results file: {results_file}")
53 | if os.path.exists(results_file):
54 | print(f"Loading existing results from {results_file}")
55 | with open(results_file, 'r') as f:
56 | results = json.load(f)
57 | pred_ratings = results['ratings']
58 |
59 | # Initialize tokenizer for token counting
60 | tokenizer = AutoTokenizer.from_pretrained(args.tknz) if args.tknz else AutoTokenizer.from_pretrained(args.ckpt)
61 |
62 | token_counts = [count_tokens(pred, tokenizer) for pred in results['trajectories']]
63 | mean_token_count = sum(token_counts) / len(token_counts)
64 |
65 | print(colored("Results Summary:", "cyan", attrs=["bold"]))
66 | print(colored(f"Mean token count: ", "yellow") + colored(f"{mean_token_count:.2f}", "green"))
67 | print(colored(f"Model Accuracy: ", "yellow") + colored(f"{np.mean([r > 0 for r in pred_ratings]):.4f}", "green"))
68 | sys.exit(0)
69 |
70 | # Initialize sglang engine
71 | llm = sgl.Engine(
72 | model_path=args.ckpt,
73 | tokenizer_path=args.tknz if args.tknz else args.ckpt,
74 | allow_auto_truncate=True,
75 | # log_level='warning',
76 | # tp_size=args.num_gpus,
77 | # dp is slightly faster than tp due to the small model size; also sgl gpt2 has bug with tp
78 | dp_size=args.num_gpus,
79 | )
80 |
81 | # Prepare prompts
82 | tokenizer = AutoTokenizer.from_pretrained(args.tknz) if args.tknz else AutoTokenizer.from_pretrained(args.ckpt)
83 | if args.num_token_cond or args.budget:
84 | def get_token_budget(sample):
85 | if args.budget is not None:
86 | return args.budget
87 | # Define token budget bins (512, 1024, 1536, 2048, 2560, 3072, 3584, 4096)
88 | budget_bins = list(range(512, 4096+1, 512))
89 | token_count = sample['token_count']
90 | # Find the appropriate budget bin
91 | budget = 4096 if token_count > 4096 else \
92 | next((bin_value for bin_value in budget_bins if token_count <= bin_value), 4096)
93 | return budget
94 |
95 | test_prompts = [
96 | f"{tokenizer.bos_token}Token Budget: {get_token_budget(sample)} "
97 | f"Current State: {sample['target']}:{sample['nums']}, Operations: []"
98 | for sample in raw_data
99 | ]
100 | else:
101 | test_prompts = [
102 | f"{tokenizer.bos_token}Current State: {sample['target']}:{sample['nums']}, Operations: []"
103 | for sample in raw_data
104 | ]
105 | len_nums = [len(sample['nums']) for sample in raw_data]
106 | data = [d for d, l in zip(test_prompts, len_nums) if l == 4]
107 |
108 | # Set up sampling parameters
109 | sampling_params = {
110 | "temperature": args.temperature,
111 | "max_new_tokens": args.max_tokens,
112 | "top_k": 50,
113 | }
114 |
115 | # Process in batches
116 | batch_size = args.batch_size * args.num_gpus
117 |
118 | if args.n_samples > 1:
119 | # Initialize data structures for multiple samples
120 | all_sample_predictions = []
121 | all_sample_ratings = []
122 | all_sample_reasons = []
123 |
124 | # Run inference n_samples times
125 | for sample_idx in range(args.n_samples):
126 | print(colored(f"Running sample {sample_idx+1}/{args.n_samples}", "cyan", attrs=["bold"]))
127 | predictions = []
128 |
129 | for b in tqdm.trange(0, len(data), batch_size):
130 | batch = data[b:min(b+batch_size, len(data))]
131 |
132 | if args.gens == 1:
133 | # Generate outputs
134 | outputs = llm.generate(batch, sampling_params)
135 | # Combine prompts with generated text
136 | batch_predictions = [prompt + output['text'] for prompt, output in zip(batch, outputs)]
137 | predictions.extend(batch_predictions)
138 | else:
139 | assert args.temperature > 0.0, "Temperature must be greater than 0 for sampling"
140 | all_outputs = []
141 | all_ratings = []
142 |
143 | # Generate multiple times for each prompt
144 | for _ in range(args.gens):
145 | outputs = llm.generate(batch, sampling_params)
146 | # Combine prompts with generated text for each generation
147 | output_texts = [prompt + output['text'] for prompt, output in zip(batch, outputs)]
148 | # Get rating for each output
149 | ratings = [metric_fn(ot, mode="sft")[0] for ot in output_texts]
150 | all_ratings.append(ratings)
151 | all_outputs.append(output_texts)
152 |
153 | # Convert to numpy array for easier processing
154 | all_ratings = np.array(all_ratings)
155 | print(all_ratings)
156 | print(f"average rating", np.mean(all_ratings))
157 |
158 | # Get the best output for each prompt
159 | max_ratings = np.argmax(all_ratings, axis=0)
160 | max_rating_vals = np.max(all_ratings, axis=0)
161 | print(f"max ratings", np.mean(max_rating_vals))
162 |
163 | # Select the best outputs
164 | batch_predictions = [all_outputs[max_r][i] for i, max_r in enumerate(max_ratings)]
165 | predictions.extend(batch_predictions)
166 |
167 | # Rate outputs for this sample
168 | pred_ratings = []
169 | pred_reasons = []
170 | for i, pred in enumerate(predictions):
171 | rating, reason = metric_fn(pred, mode="sft")
172 | pred_ratings.append(rating)
173 | pred_reasons.append(reason)
174 |
175 | # Store results for this sample
176 | all_sample_predictions.append(predictions)
177 | all_sample_ratings.append(pred_ratings)
178 | all_sample_reasons.append(pred_reasons)
179 |
180 | all_sample_ratings_array = np.array(all_sample_ratings)
181 | binary_correctness = all_sample_ratings_array > 0
182 |
183 | # Print results
184 | print(colored("Results Summary:", "cyan", attrs=["bold"]))
185 | print(colored(f"Number of samples: ", "yellow") + colored(f"{args.n_samples}", "green"))
186 | print(colored(f"Individual sample accuracies: ", "yellow") +
187 | colored(f"{[np.mean(binary_correctness[i]) for i in range(args.n_samples)]}", "green"))
188 | # TODO: cons@n and pass@n
189 |
190 | # Save results
191 | if not args.skip_save_results:
192 | with open(results_file, "w") as f:
193 | json.dump({
194 | "n_samples": args.n_samples,
195 | "individual_sample_accuracies": [float(np.mean(binary_correctness[i])) for i in range(args.n_samples)],
196 | "sample_trajectories": all_sample_predictions,
197 | "sample_ratings": all_sample_ratings_array.tolist(),
198 | "sample_reasons": all_sample_reasons,
199 | }, f, indent=4)
200 |
201 | else:
202 | # Original single-sample code
203 | predictions = []
204 | pred_ratings = []
205 | pred_reasons = []
206 |
207 | for b in tqdm.trange(0, len(data), batch_size):
208 | batch = data[b:min(b+batch_size, len(data))]
209 |
210 | if args.gens == 1:
211 | # Generate outputs
212 | outputs = llm.generate(batch, sampling_params)
213 | # Combine prompts with generated text
214 | batch_predictions = [prompt + output['text'] for prompt, output in zip(batch, outputs)]
215 | predictions.extend(batch_predictions)
216 | else:
217 | assert args.temperature > 0.0, "Temperature must be greater than 0 for sampling"
218 | all_outputs = []
219 | all_ratings = []
220 |
221 | # Generate multiple times for each prompt
222 | for _ in range(args.gens):
223 | outputs = llm.generate(batch, sampling_params)
224 | # Combine prompts with generated text for each generation
225 | output_texts = [prompt + output['text'] for prompt, output in zip(batch, outputs)]
226 | # Get rating for each output
227 | ratings = [metric_fn(ot, mode="sft")[0] for ot in output_texts]
228 | all_ratings.append(ratings)
229 | all_outputs.append(output_texts)
230 |
231 | # Convert to numpy array for easier processing
232 | all_ratings = np.array(all_ratings)
233 | print(all_ratings)
234 | print(f"average rating", np.mean(all_ratings))
235 |
236 | # Get the best output for each prompt
237 | max_ratings = np.argmax(all_ratings, axis=0)
238 | max_rating_vals = np.max(all_ratings, axis=0)
239 | print(f"max ratings", np.mean(max_rating_vals))
240 |
241 | # Select the best outputs
242 | batch_predictions = [all_outputs[max_r][i] for i, max_r in enumerate(max_ratings)]
243 | predictions.extend(batch_predictions)
244 |
245 | # Rate outputs
246 | true_rating = []
247 | for i, pred in enumerate(predictions):
248 | rating, reason = metric_fn(pred, mode="sft")
249 | tr, _ = metric_fn(f"{raw_data[i]['search_path']}", mode="sft")
250 | pred_ratings.append(rating)
251 | true_rating.append(tr)
252 | pred_reasons.append(reason)
253 |
254 | token_counts = [count_tokens(pred, tokenizer) for pred in predictions]
255 | mean_token_count = sum(token_counts) / len(token_counts)
256 |
257 | # Print results
258 | pred_ratings = np.array(pred_ratings)
259 | print(colored("Results Summary:", "cyan", attrs=["bold"]))
260 | print(colored(f"Mean token count: ", "yellow") + colored(f"{mean_token_count:.2f}", "green"))
261 | print(colored(f"Model Accuracy: ", "yellow") + colored(f"{np.mean([r > 0 for r in pred_ratings]):.4f}", "green"))
262 | print(colored(f"Original Symbolic Solver Accuracy: ", "yellow") + colored(f"{np.mean([r > 0 for r in true_rating]):.4f}", "green"))
263 |
264 | # Save results
265 | if not args.skip_save_results:
266 | with open(results_file, "w") as f:
267 | json.dump({
268 | "trajectories": predictions,
269 | "ratings": pred_ratings.tolist(),
270 | "reasons": pred_reasons,
271 | "token_counts": token_counts,
272 | "mean_token_count": mean_token_count
273 | }, f, indent=4)
274 |
275 | if __name__ == "__main__":
276 | args = parser.parse_args()
277 | seed_everything(args.seed)
278 |
279 | print(colored("Evaluating model: ", "cyan", attrs=["bold"]) + colored(args.ckpt, "yellow"))
280 | print(colored("Data file: ", "cyan", attrs=["bold"]) + colored(args.data, "yellow"))
281 | print(colored("Temperature: ", "cyan", attrs=["bold"]) + colored(args.temperature, "yellow"))
282 | print(colored("Number of GPUs: ", "cyan", attrs=["bold"]) + colored(args.num_gpus, "yellow"))
283 | if args.n_samples > 1:
284 | print(colored("Number of samples (best-of-n): ", "cyan", attrs=["bold"]) + colored(args.n_samples, "yellow"))
285 | if args.num_token_cond:
286 | print(colored("Using token condition", "cyan", attrs=["bold"]))
287 | if args.budget:
288 | print(colored("Using a fixed budget: ", "cyan", attrs=["bold"]) + colored(args.budget, "yellow"))
289 | # eval
290 | eval_ll(args)
291 |
--------------------------------------------------------------------------------
/src/eval/parallel_infernce_utils.py:
--------------------------------------------------------------------------------
1 | import re
2 |
3 | def check_solution(prefix: str, solution: str) -> bool:
4 | """
5 | Parses the prefix and solution to verify if the solution actually
6 | solves the puzzle of reaching `target` from the given list of numbers.
7 |
8 | :param prefix: A line like:
9 | "Moving to Node #1\nCurrent State: 62:[5, 50, 79, 27], Operations: []"
10 | :param solution: The multiline string describing the step-by-step solution.
11 | :return: True if the solution's final result matches the target and
12 | all stated operations are valid. False otherwise.
13 | """
14 | # -----------------------------------------------------------------
15 | # 1. Parse the prefix to extract target and initial numbers
16 | # -----------------------------------------------------------------
17 | # Example prefix line to parse:
18 | # Current State: 62:[5, 50, 79, 27], Operations: []
19 | #
20 | # We'll look for something matching:
21 | # Current State: :[], ...
22 | prefix_pattern = r"Current State:\s*(\d+):\[(.*?)\]"
23 | match = re.search(prefix_pattern, prefix)
24 | if not match:
25 | print("ERROR: Could not parse the prefix for target and numbers.")
26 | return False
27 |
28 | target_str, numbers_str = match.groups()
29 | target = int(target_str.strip())
30 | # Now parse something like "5, 50, 79, 27" into a list of integers
31 | if numbers_str.strip():
32 | initial_numbers = [int(x.strip()) for x in numbers_str.split(",")]
33 | else:
34 | initial_numbers = []
35 |
36 | # We'll keep track of our current working list of numbers
37 | current_numbers = initial_numbers
38 |
39 | # -----------------------------------------------------------------
40 | # 2. Parse solution to extract lines with "Exploring Operation:"
41 | # -----------------------------------------------------------------
42 | # Example lines:
43 | # Exploring Operation: 79-27=52, Resulting Numbers: [5, 50, 52]
44 | # We want to parse out: operand1=79, operator='-', operand2=27, result=52
45 | # Then parse the new list: [5, 50, 52]
46 |
47 | operation_pattern = r"Exploring Operation:\s*([\d]+)([\+\-\*/])([\d]+)=(\d+),\s*Resulting Numbers:\s*\[(.*?)\]"
48 |
49 | # We'll process the solution line-by-line
50 | # so that we can also capture the final "Goal Reached" line.
51 | lines = solution.splitlines()
52 |
53 | for line in lines:
54 | line = line.strip()
55 |
56 | # Check for "Exploring Operation"
57 | op_match = re.search(operation_pattern, line)
58 | if op_match:
59 | # Parse out the operation parts
60 | x_str, op, y_str, z_str, new_nums_str = op_match.groups()
61 | x_val = int(x_str)
62 | y_val = int(y_str)
63 | z_val = int(z_str)
64 |
65 | # Parse the new list of numbers from something like "5, 50, 52"
66 | new_numbers = []
67 | if new_nums_str.strip():
68 | new_numbers = [int(n.strip()) for n in new_nums_str.split(",")]
69 |
70 | # -------------------------------------------------------------
71 | # Verify that applying X op Y => Z to current_numbers is valid
72 | # -------------------------------------------------------------
73 | # 1. X and Y must both be present in current_numbers
74 | # 2. Remove X and Y from current_numbers
75 | # 3. Add Z
76 | # 4. The new list must match exactly the "Resulting Numbers"
77 | # 5. Also verify the arithmetic was correct (if you want to be strict)
78 | #
79 | # NOTE: we do not handle repeating values carefully here if X or Y
80 | # appear multiple times, but you can adapt as needed (e.g. remove once).
81 |
82 | temp_list = current_numbers[:]
83 |
84 | # Try removing X, Y once each
85 | try:
86 | temp_list.remove(x_val)
87 | temp_list.remove(y_val)
88 | except ValueError:
89 | print(f"ERROR: {x_val} or {y_val} not found in current_numbers {current_numbers}.")
90 | return False
91 |
92 | # Check that the stated Z matches the arithmetic operation
93 | # (If you want to skip verifying the math, remove these lines.)
94 | computed_result = None
95 | if op == '+':
96 | computed_result = x_val + y_val
97 | elif op == '-':
98 | computed_result = x_val - y_val
99 | elif op == '*':
100 | computed_result = x_val * y_val
101 | elif op == '/':
102 | # watch for zero division or non-integer division if you want
103 | # to require exact integer results
104 | if y_val == 0:
105 | print("ERROR: Division by zero encountered.")
106 | return False
107 | # For a typical "24 game" style puzzle, we allow float or integer check
108 | computed_result = x_val / y_val
109 |
110 | # Compare the stated z_val to the computed result
111 | # (if it's integer-based arithmetic, we might check int(...) or round)
112 | if computed_result is None:
113 | print("ERROR: Unknown operation encountered.")
114 | return False
115 |
116 | # If we want exact integer match (for e.g. 50/5=10):
117 | # If float is possible, we might do a small epsilon check:
118 | # e.g. if abs(computed_result - z_val) > 1e-9
119 | if computed_result != z_val:
120 | print(f"ERROR: Operation {x_val}{op}{y_val} does not equal {z_val}. Got {computed_result} instead.")
121 | return False
122 |
123 | # Now add the result to temp_list
124 | temp_list.append(z_val)
125 | # Sort if you do not care about order, or keep order if you do
126 | # and compare to new_numbers
127 | # We'll assume exact order is not critical, so let's do a sorted comparison:
128 | if sorted(temp_list) != sorted(new_numbers):
129 | print(f"ERROR: After applying {x_val}{op}{y_val}={z_val} to {current_numbers}, "
130 | f"got {sorted(temp_list)} but solution says {sorted(new_numbers)}.")
131 | return False
132 |
133 | # If we got here, it means the operation is consistent
134 | current_numbers = new_numbers
135 |
136 | # ---------------------------------------------------------
137 | # 3. Check for "Goal Reached" line
138 | # ---------------------------------------------------------
139 | # Something like: "62,62 equal: Goal Reached"
140 | # We'll check if the final single number is indeed `target`.
141 | if "Goal Reached" in line:
142 | # For a simple check, if "Goal Reached" is present,
143 | # confirm that current_numbers is [target].
144 | if len(current_numbers) == 1 and current_numbers[0] == target:
145 | return True
146 | else:
147 | print("ERROR: 'Goal Reached' declared but final numbers don't match the target.")
148 | return False
149 |
150 | # If we never saw "Goal Reached," then it's incomplete
151 | # or didn't declare success. Return False by default
152 | print("ERROR: Did not find 'Goal Reached' in solution.")
153 | return False
154 |
155 | def get_search_result(search_trace):
156 | # Given a search trace, return the result of the search
157 | # If the search is successful, return the result optimal path
158 | # If the search is unsuccessful, return None
159 | if search_trace.count("Goal Reached") >= 2:
160 | # Find all occurrences of "Goal Reached"
161 | goal_indices = [i for i in range(len(search_trace)) if search_trace.startswith("Goal Reached", i)]
162 | # Get the second to last index, this is where we begin generate
163 | # the optimal path
164 | goal_idx = goal_indices[-2]
165 | return search_trace[goal_idx:].strip()[13:]
166 | else:
167 | return None
168 |
169 | def get_subsearch_info(search_trace, budget=None, avg_sub_call=False):
170 | try:
171 | return _get_subsearch_info(search_trace, budget, avg_sub_call)
172 | except Exception as e:
173 | print(f"Error at get_subsearch_info: {e}")
174 | print(search_trace)
175 | raise e
176 |
177 | def _get_subsearch_info(search_trace, budget=None, avg_sub_call=False):
178 | # Given a search trace, return the information of the
179 | # subsearch that it wants to invoke
180 | # sub_search= {"node": "#1,1,2", "target": 39, 'nums':[2, 11], "operations": ["51-49=2", "36-25=11"]}
181 | last_line = search_trace.split("\n")[-1]
182 | assert "" in last_line, "This is not a valid subsearch trace"
183 |
184 | # --- Parse the search trace to get the generated nodes ---
185 | generated_nodes = {}
186 | # First find any "Moving to Node" lines followed by "Current State" lines
187 | lines = search_trace.split("\n")
188 | for i in range(len(lines)-1):
189 | # Moving to Node #1,1
190 | # Current State: 39:[25, 36, 2], Operations: ['51-49=2']
191 | if "Moving to Node #" in lines[i] and "Current State:" in lines[i+1]:
192 | # Extract node id from first line like:
193 | # Moving to Node #1,1
194 | node_id = lines[i].split("Moving to Node #")[1].strip()
195 |
196 | # Extract state from second line like:
197 | # Current State: 39:[25, 36, 2], Operations: ['51-49=2']
198 | state_line = lines[i+1]
199 | state_part = state_line.split("Current State:")[1].split("],")[0].strip()
200 | operations_part = state_line.split("Operations:")[1].strip()
201 |
202 | # Parse state like "39:[25, 36, 2]"
203 | target = int(state_part.split(":")[0])
204 | # nums = eval(state_part.split(":")[1].strip())
205 | nums = eval(state_part.split(":")[1].strip() + "]")
206 | operations = eval(operations_part)
207 |
208 | # Parse operations list
209 |
210 | generated_nodes[node_id] = {
211 | "node": f"#{node_id}",
212 | "target": target,
213 | "nums": nums,
214 | "operations": operations
215 | }
216 | for line in search_trace.split("\n"):
217 | if "Generated Node" in line:
218 | # Extract node id and info from line like:
219 | # Generated Node #1,1,2: 39:[2, 11] Operation: 36-25=11
220 | node_id = line.split(":")[0].split("#")[1]
221 | if node_id in generated_nodes:
222 | continue
223 | rest = line.split(":", 1)[1].strip()
224 | state = rest.split("Operation:")[0].strip()
225 | operation = rest.split("Operation:")[1].strip()
226 |
227 | # Parse state like "39:[2, 11]" into target and nums
228 | target = int(state.split(":")[0])
229 | nums = eval(state.split(":")[1].strip())
230 |
231 | parent_node_id = ",".join(node_id.split(",")[:-1])
232 | parent_node = generated_nodes[parent_node_id]
233 | new_operations = parent_node["operations"] + [operation]
234 |
235 | generated_nodes[node_id] = {
236 | "node": f"#{node_id}",
237 | "target": target,
238 | "nums": nums,
239 | "operations": new_operations
240 | }
241 | # then we construct the sub_searches
242 | sub_search_nodes = []
243 | # Split on and take the last chunk
244 | last_chunk = search_trace.split("\n")[-1]
245 | # Split that chunk on and take first part
246 | sub_search_section = last_chunk.split("\n")[0]
247 |
248 | for line in sub_search_section.split("\n"):
249 | if " Moving to Node #1,1,2
252 | node_id = line.split("Moving to Node #")[1].strip()
253 | sub_search_nodes.append(generated_nodes[node_id])
254 |
255 | if avg_sub_call:
256 | assert budget is not None, "Budget must be provided if avg_sub_call is True"
257 | budget = budget / len(sub_search_nodes)
258 | budget = int((budget - 1) // 512 + 1) * 512
259 |
260 | def construct_sub_search_prefix(node):
261 | # example
262 | # "Moving to Node #1,1,0\nCurrent State: 39:[36, 50], Operations: ['51-49=2', '25*2=50']
263 | prefix = f"Moving to Node {node['node']}\nCurrent State: {node['target']}:[{', '.join(map(str, node['nums']))}], Operations: {node['operations']}"
264 | if budget is not None:
265 | prefix = f"Token Budget: {budget} {prefix}"
266 | return prefix
267 | sub_search_prefix_list = [construct_sub_search_prefix(node) for node in sub_search_nodes]
268 | return sub_search_prefix_list, sub_search_nodes
269 |
270 | def get_main_trace_after_sub_search(main_trace, sub_search_nodes, sub_search_result_list):
271 | last_line = main_trace.split("\n")[-1]
272 | assert "" in last_line, "This is not a valid subsearch trace"
273 | sub_searches = []
274 | # Split on and take the last chunk
275 | last_chunk = main_trace.split("\n")[-1]
276 | # Split that chunk on and take first part
277 | sub_search_section = last_chunk.split("\n")[0]
278 | main_trace_after_sub_search = "\n".join(main_trace.split("\n")[:-1])
279 | assert main_trace_after_sub_search in main_trace
280 | main_trace_after_sub_search += "\n"
281 | for i, (this_node, this_result) in enumerate(zip(sub_search_nodes, sub_search_result_list)):
282 | if this_result is None:
283 | #
284 | main_trace_after_sub_search += f"\n"
285 | else:
286 | # \nMoving to Node #1,2,0\nCurrent State: 39:[51, 12], Operations: ['49-36=13', '25-13=12']\nExploring Operation: 51-12=39, Resulting Numbers: [39]\n39,39 equal: Goal Reached\n
287 | main_trace_after_sub_search += f"\n"
288 | main_trace_after_sub_search += this_result + "\n"
289 | main_trace_after_sub_search += "\n"
290 | return main_trace_after_sub_search
--------------------------------------------------------------------------------
/src/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 | import numpy as np
4 | import torch
5 |
6 |
7 | def seed_everything(seed: int):
8 | random.seed(seed)
9 | os.environ['PYTHONHASHSEED'] = str(seed)
10 | np.random.seed(seed)
11 | torch.manual_seed(seed)
12 | torch.cuda.manual_seed(seed)
13 | torch.backends.cudnn.deterministic = True
14 | torch.backends.cudnn.benchmark = True
--------------------------------------------------------------------------------
/supervised-jax/README.md:
--------------------------------------------------------------------------------
1 | # Example Supervised Training Instructions
2 | > Fork from [JAX_llama](https://github.com/Sea-Snell/JAX_llama)
3 |
4 | First, install [tpu_pod_launcher](https://github.com/Sea-Snell/tpu_pod_launcher), which is used to orchestrate the TPU pod:
5 |
6 | ```
7 | pip install git+https://github.com/Sea-Snell/tpu_pod_launcher.git@main
8 | ```
10 |
11 | There is a training config in `training_run.sh` and a launcher defined in `launcher.py`.
12 |
13 | The launcher uses this tool to run the installation and training scripts on all hosts in the pod at once.
14 |
15 | You will probably want to edit the launcher's `available_tpus` to reflect the TPUs you have access to.
16 |
17 | You will also need a Google Cloud credentials JSON file with write permission to your GCS buckets. Change the path on `line 51` of the launcher to point to this file.
18 |
19 | You may also want to modify the ssh info and copy path in `lines 56-69` in the launcher.
20 |
21 | Make sure you set the API keys correctly at the top of the `training_run.sh` script. And also edit the bucket paths as needed.
22 |
23 | The data should be stored as a JSON list of strings:
24 |
25 | ```
26 | [
27 | "seq1",
28 | "seq2",
29 | ...
30 | ]
31 | ```
32 |
33 | Use the `to_dataset.ipynb` notebook to convert the dataset to the correct format and upload to GCS.
34 |
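As a rough sketch of that conversion (assuming your training sequences are already plain strings; the bucket path below is a placeholder for the one referenced in your training script):

```python
import json
import gcsfs  # installed on the TPU hosts by the setup script

# `sequences` stands in for the list of training strings described above.
sequences = ["seq1", "seq2"]

# Write the JSON list locally, then copy it to GCS.
with open("train.json", "w") as f:
    json.dump(sequences, f)

gcsfs.GCSFileSystem().put("train.json", "gs://<your-bucket>/data/train.json")
```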
35 |
36 | To install all dependencies on the TPU hosts run:
37 |
38 | ```
39 | python launcher.py setup --project=jiayi-128-eu-2
41 | ```
42 |
43 | You only need to do this once for each TPU pod.
44 |
45 | The `--project` value refers to the name of your TPU in the `available_tpus` list in the launcher.
46 |
47 | To launch the training run:
48 |
49 | ```
50 | conda activate GPML
51 | cd /home/jiayipan/code/25SP/LM-Parallel/JAX-Train
52 | python launcher.py launch training_sos-split-digit-v2.sh --project=jiayi-128-eu
53 | python launcher.py launch training_hsp_2x.sh --project=jiayi-128-eu
54 | python launcher.py launch training_run.sh --project=jiayi-128-eu
55 | python launcher.py launch scripts/training_sos_llama_10ep_v2.sh --project=jiayi-128-eu
56 | python launcher.py launch scripts/hsp-v3.sh --project=jiayi-128-eu
57 | python launcher.py launch scripts/hs-v3.sh --project=jiayi-128-eu-2
58 | python launcher.py launch scripts/sos-v3.sh --project=jiayi-64-eu
59 | ```
60 | This will: 1) copy the latest version of `llama3_train` to the TPUs; 2) stop anything running on the TPUs; 3) run the training script on the TPUs.
61 |
62 | To print the output of the training run, you can run:
63 |
64 | ```
65 | python launcher.py check --project=your_tpu_name
66 | python launcher.py check --project=jiayi-128-eu
67 | ```
68 |
69 | To terminate an ongoing training run, you can run:
70 |
71 | ```
72 | python launcher.py stop --project=your_tpu_name
73 | ```
74 |
75 | The 3 mesh dimensions in the config currently correspond to (replica, fsdp, tensor). We could also add a sequence-parallel dimension without much difficulty if needed.
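For example, this is roughly how the `--sharding` flag of the train scripts maps onto those axes (the sizes below are hypothetical and must multiply to the number of devices in the pod; the LLaMA script names the same axes `dp`, `fsdp`, `mp`):

```
# --sharding="1,32,1" -> 1 replica, 32-way FSDP, no tensor parallelism
sharding = "1,32,1"
axis_sizes = [int(x.strip()) for x in sharding.split(",")]  # [1, 32, 1]
axis_names = ["replica", "fsdp", "tensor"]
```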
76 |
77 | ### Test Model
78 | Quickly test the model with Flax/JAX:
79 | https://colab.research.google.com/drive/1X7ElvcwrAk5nt_dkAsZUZo6IOp4nsW3L?usp=sharing
80 |
81 | ### Export To PyTorch
82 | You can use this script to export the model to PyTorch (Hugging Face compatible):
83 |
84 | https://colab.research.google.com/drive/1XD3bJ1PQuHKwSO2cB0NCc6BkyJoptF19?usp=sharing
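Once exported, the checkpoint should load like any other Hugging Face model. A sketch, assuming the export produced a standard model directory (the path is a placeholder, and the tokenizer files may need to come from the base model instead):

```
from transformers import AutoModelForCausalLM, AutoTokenizer

path = "/path/to/exported-checkpoint"  # placeholder
model = AutoModelForCausalLM.from_pretrained(path)
tokenizer = AutoTokenizer.from_pretrained(path)
```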
85 |
86 |
--------------------------------------------------------------------------------
/supervised-jax/launcher.py:
--------------------------------------------------------------------------------
1 | import os
2 | from tpu_pod_launcher import TPUPodClient, TPUPodProject, create_cli
3 |
4 | SETUP_SCRIPT = """\
5 | cd ~/
6 | # install basics
7 | apt-get update -q \
8 | && DEBIAN_FRONTEND=noninteractive apt-get install -y \
9 | apt-utils \
10 | curl \
11 | git \
12 | vim \
13 | wget \
14 | tmux \
15 | redis-server \
16 | && apt-get clean \
17 | && rm -rf /var/lib/apt/lists/*
18 |
19 | # install miniconda
20 | rm -rf ~/Miniconda3-py39_4.12.0-Linux-x86_64.sh
21 | wget https://repo.anaconda.com/miniconda/Miniconda3-py310_23.1.0-1-Linux-x86_64.sh -P ~/
22 | bash ~/Miniconda3-py310_23.1.0-1-Linux-x86_64.sh -b
23 |
24 | # install dependencies
25 | source ~/miniconda3/bin/activate
26 | conda init bash
27 | conda create -n llama3_train python=3.10 -y
28 | conda activate llama3_train
29 | cd ~/llama3_train
30 | python -m pip install -e .
31 | pip install -U "jax[tpu]==0.4.38" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
32 | python -m pip install tyro flax scalax transformers gcsfs optax wandb
33 | pip install -U transformers==4.47.1 flax==0.10.2
34 |
35 | # clean up
36 | cd ~/
37 | rm -rf ~/Miniconda3-py310_23.1.0-1-Linux-x86_64.sh
38 | """.strip()
39 |
40 | CHECK_DEVICES = r"""
41 | source ~/miniconda3/bin/activate llama3_train
42 | python -c "import jax; print(jax.devices())"
43 | """.strip()
44 |
45 | def check_devices(project: TPUPodProject, verbose: bool=False):
46 | project.ssh(CHECK_DEVICES, verbose=verbose)
47 |
48 | def setup(project: TPUPodProject, verbose: bool=False):
49 | project.copy(verbose=verbose)
50 | project.ssh(SETUP_SCRIPT, verbose=verbose)
51 | project.ssh('mkdir -p ~/.config/', verbose=verbose)
52 | project.ssh('mkdir -p ~/.config/gcloud/', verbose=verbose)
53 | project.scp('/home/jiayipan/code/25SP/TPU-Train/civic-boulder-204700-3052e43e8c80.json', '~/.config/gcloud/', verbose=verbose)
54 |
55 | def debug(project: TPUPodProject, verbose: bool=False):
56 | import IPython; IPython.embed()
57 |
58 | def create_project(tpu_name: str, zone: str) -> TPUPodProject:
59 | return TPUPodProject(
60 | client=TPUPodClient(
61 | tpu_project='civic-boulder-204700',
62 | tpu_zone=zone,
63 | user='jiayipan',
64 | key_path='/home/jiayipan/.ssh/id_rsa',
65 | ),
66 | tpu_name=tpu_name,
67 | copy_dirs=[('/home/jiayipan/code/25SP/LM-Parallel/JAX-Train/llama3_train/', '~/llama3_train/')],
68 | working_dir='~/llama3_train/',
69 | copy_excludes=['.git', '__pycache__', '*.pkl', '*.json', '*.jsonl', '*.ipynb'],
70 | kill_commands=['pkill -9 python'],
71 | )
72 |
73 | if __name__ == "__main__":
74 | launch_config_path = os.path.join(os.path.dirname(__file__), 'launch_config.json')
75 |
76 | available_tpus = [
77 | ('jiayi-64-eu', 'europe-west4-a'), # v3-64
78 | ('jiayi-128-eu', 'europe-west4-a'), # v3-128
79 | ('jiayi-128-eu-2', 'europe-west4-a'), # v3-128
80 | ]
81 |
82 | tpu_projects = {name: create_project(name, zone) for name, zone in available_tpus}
83 |
84 | create_cli(
85 | projects=tpu_projects,
86 | setup=setup,
87 | custom_commands={'debug': debug, 'check_devices': check_devices},
88 | launch_config_path=launch_config_path,
89 | )
90 |
--------------------------------------------------------------------------------
/supervised-jax/llama3_train/.gitignore:
--------------------------------------------------------------------------------
1 | .ipynb_checkpoints
2 | __pycache__/
3 | .DS_Store
4 | wandb/
5 | *.pyc
6 | dump.rdb
7 | *.egg-info/
--------------------------------------------------------------------------------
/supervised-jax/llama3_train/README.md:
--------------------------------------------------------------------------------
1 | # LLaMA-3 Train
2 |
3 | ## Installation
4 |
5 | ```
6 | python -m pip install -e .
7 | ```
8 |
9 | Also see [here](https://github.com/google/jax#instructions) to install whatever version of jax you need for your accelerator.
--------------------------------------------------------------------------------
/supervised-jax/llama3_train/gpt2_train_script.py:
--------------------------------------------------------------------------------
1 | from typing import List, Dict, Any, Optional
2 | import json
3 | from functools import partial
4 | import os
5 | import collections
6 | import tempfile
7 |
8 | import tyro
9 | import jax
10 | from scalax.sharding import MeshShardingHelper, TreePathShardingRule
11 | from flax.training.train_state import TrainState
12 | import numpy as np
13 | import itertools
14 | import pickle as pkl
15 | from transformers import AutoTokenizer
16 | from tqdm.auto import tqdm
17 | import optax
18 | from jax.sharding import PartitionSpec as PS
19 |
20 | from llama_train.utils import (
21 | load_checkpoint, open_with_bucket, save_checkpoint,
22 | get_float_dtype_by_name, get_weight_decay_mask,
23 | cross_entropy_loss_and_accuracy, global_norm,
24 | average_metrics, WandbLogger, delete_with_bucket,
25 | jax_distributed_initalize, jax_distributed_barrier
26 | )
27 | from llama_train.gpt2 import (
28 | GPT2Config, FlaxGPT2LMHeadModel,
29 | )
30 | from llama_train.optimizer import load_adamw_optimizer, load_palm_optimizer
31 |
32 | def process_pretrain_example(
33 | seq: str,
34 | max_length: int,
35 | tokenizer: AutoTokenizer,
36 | ):
37 | tokenization = tokenizer(
38 | [tokenizer.bos_token+seq+tokenizer.eos_token],
39 | padding='max_length',
40 | truncation=True,
41 | max_length=max_length+1,
42 | return_tensors='np',
43 | )
44 |
45 | input_ids = tokenization.input_ids[:, :-1]
46 | target_ids = tokenization.input_ids[:, 1:]
47 | attention_mask = tokenization.attention_mask[:, :-1]
48 | position_ids = np.maximum(np.cumsum(attention_mask, axis=-1) - 1, 0)
49 | loss_masks = attention_mask
50 |
51 | batch_items = dict(
52 | input_ids=input_ids,
53 | attention_mask=attention_mask,
54 | position_ids=position_ids,
55 | target_ids=target_ids,
56 | loss_masks=loss_masks,
57 | )
58 | return batch_items
59 |
60 |
61 | def checkpointer(
62 | path: str,
63 | train_state: Any,
64 | config: Any,
65 | gather_fns: Any,
66 | metadata: Any=None,
67 | save_optimizer_state: bool=False,
68 | save_float_dtype: str='bf16',
69 | active=True,
70 | ):
71 | if not path.startswith('gcs://'):
72 | os.makedirs(path, exist_ok=True)
73 | if save_optimizer_state:
74 | checkpoint_state = train_state
75 | if not active:
76 | checkpoint_path = '/dev/null'
77 | else:
78 | checkpoint_path = os.path.join(path, 'train_state.msgpack')
79 | checkpoint_gather_fns = gather_fns
80 | else:
81 | checkpoint_state = train_state.params
82 | if not active:
83 | checkpoint_path = '/dev/null'
84 | else:
85 | checkpoint_path = os.path.join(path, 'params.msgpack')
86 | checkpoint_gather_fns = gather_fns.params
87 | metadata_path = os.path.join(path, 'metadata.pkl')
88 | config_path = os.path.join(path, 'config.json')
89 |
90 | save_checkpoint(
91 | checkpoint_state,
92 | checkpoint_path,
93 | gather_fns=checkpoint_gather_fns,
94 | float_dtype=save_float_dtype,
95 | )
96 | if active:
97 | with open_with_bucket(metadata_path, 'wb') as f:
98 | pkl.dump(metadata, f)
99 | with open_with_bucket(config_path, 'w') as f:
100 | json.dump(config, f)
101 |
102 | def main(
103 | load_model: str,
104 | train_data_path: str,
105 | eval_data_path: str,
106 | output_dir: Optional[str],
107 | sharding: str,
108 | num_train_steps: int,
109 | max_length: int,
110 | bsize: int,
111 | log_freq: int,
112 | num_eval_steps: int,
113 | save_model_freq: int,
114 | wandb_project: str,
115 | param_dtype: str='fp32',
116 | activation_dtype: str='fp32',
117 | optim_config: str='adamw:{}',
118 | logger_config: str='{}',
119 | checkpointer_config: str='{}',
120 | model_config_override: str='{}',
121 | inputs_tokenizer_override: str='{}',
122 | outputs_tokenizer_override: str='{}',
123 | jax_distributed_initalize_config: str='{}',
124 | save_initial_checkpoint: bool=False,
125 | log_initial_step: bool=True,
126 | max_checkpoints: Optional[int]=None,
127 | eval_bsize: Optional[int]=None,
128 | physical_axis_splitting: bool=False,
129 | shuffle_train_data: bool=True,
130 | hf_repo_id: str='LM-Parallel/sample',
131 | ):
132 | args_dict = dict(locals())
133 | print(args_dict)
134 | sharding: List[int] = list(map(lambda x: int(x.strip()), sharding.split(',')))
135 | if eval_bsize is None:
136 | eval_bsize = bsize
137 |
138 | param_dtype = get_float_dtype_by_name(param_dtype)
139 | activation_dtype = get_float_dtype_by_name(activation_dtype)
140 |
141 | logger_config: Dict[str, Any] = json.loads(logger_config)
142 | checkpointer_config: Dict[str, Any] = json.loads(checkpointer_config)
143 | model_config_override: Dict[str, Any] = json.loads(model_config_override)
144 | inputs_tokenizer_override: Dict[str, Any] = json.loads(inputs_tokenizer_override)
145 | outputs_tokenizer_override: Dict[str, Any] = json.loads(outputs_tokenizer_override)
146 | jax_distributed_initalize_config: Dict[str, Any] = json.loads(jax_distributed_initalize_config)
147 |
148 | jax_distributed_initalize(**jax_distributed_initalize_config)
149 | jax_distributed_barrier()
150 |
151 | if optim_config.startswith('adamw:'):
152 | optim_config = json.loads(optim_config[len('adamw:'):])
153 | optim_config['weight_decay_mask'] = get_weight_decay_mask(optim_config.pop('weight_decay_exclusions', tuple()))
154 | grad_accum_steps = optim_config.pop('grad_accum_steps', 1)
155 | optimizer, optimizer_info = load_adamw_optimizer(**optim_config)
156 | elif optim_config.startswith('palm:'):
157 | optim_config = json.loads(optim_config[len('palm:'):])
158 | optim_config['weight_decay_mask'] = get_weight_decay_mask(optim_config.pop('weight_decay_exclusions', tuple()))
159 | grad_accum_steps = optim_config.pop('grad_accum_steps', 1)
160 | optimizer, optimizer_info = load_palm_optimizer(**optim_config)
161 | else:
162 | raise ValueError(f'Unknown optimizer config: {optim_config}')
163 | if grad_accum_steps > 1:
164 | optimizer = optax.MultiSteps(
165 | optimizer,
166 | grad_accum_steps,
167 | )
168 |
169 | mesh = MeshShardingHelper(sharding, ['replica', 'fsdp', 'tensor'], mesh_axis_splitting=physical_axis_splitting) # Create a 3D mesh with data, fsdp, and model parallelism axes
170 | with mesh.get_context():
171 | print('mesh:', mesh.mesh)
172 | print('loading model ...')
173 |
174 | if load_model.startswith('paths:'):
175 | model_paths = json.loads(load_model[len('paths:'):])
176 | if not ('remove_dict_prefix' in model_paths):
177 | model_paths['remove_dict_prefix'] = None
178 | else:
179 | raise ValueError(f'Unknown model info type: {load_model}')
180 |
181 | config_is_temp = False
182 | if 'config' in model_paths and model_paths['config'].startswith('gcs://'):
183 | temp_file = tempfile.NamedTemporaryFile('wb', delete=False)
184 | with open_with_bucket(model_paths['config'], 'rb') as f:
185 | temp_file.write(f.read())
186 | temp_file.close()
187 | model_paths['config'] = temp_file.name
188 | config_is_temp = True
189 |
190 | if 'config' in model_paths:
191 | config = GPT2Config.from_pretrained(model_paths['config'], **model_config_override)
192 | # elif 'default_config_name' in model_paths:
193 | # config = GPT2Config(**GPT2_STANDARD_CONFIGS[model_paths['default_config_name']], **model_config_override)
194 | else:
195 | config = GPT2Config(**model_config_override)
196 |
197 | if config_is_temp:
198 | os.remove(model_paths['config'])
199 |
200 | model = FlaxGPT2LMHeadModel(config, dtype=activation_dtype, _do_init=False, param_dtype=param_dtype, input_shape=(bsize, 1024))
201 | # TODO: the input sequence length is hardcoded to 1024 here
202 |
203 | tokenizer_is_temp = False
204 | if model_paths['tokenizer'].startswith('gcs://'):
205 | temp_file = tempfile.NamedTemporaryFile('wb', delete=False)
206 | with open_with_bucket(model_paths['tokenizer'], 'rb') as f:
207 | temp_file.write(f.read())
208 | temp_file.close()
209 | model_paths['tokenizer'] = temp_file.name
210 | tokenizer_is_temp = True
211 |
212 | tokenizer_kwargs = dict(
213 | truncation_side='right',
214 | padding_side='right',
215 | )
216 | tokenizer_kwargs.update(outputs_tokenizer_override)
217 | tokenizer = AutoTokenizer.from_pretrained(model_paths['tokenizer'], **tokenizer_kwargs)
218 | tokenizer.pad_token = tokenizer.convert_ids_to_tokens(config.pad_token_id)
219 |
220 | if tokenizer_is_temp:
221 | os.remove(model_paths['tokenizer'])
222 |
223 | sharding_rules = TreePathShardingRule(*config.get_partition_rules())
224 |
225 | @partial(
226 | mesh.sjit,
227 | in_shardings=(sharding_rules,),
228 | out_shardings=sharding_rules,
229 | )
230 | def create_train_state_from_params(params):
231 | return TrainState.create(params=params, tx=optimizer, apply_fn=None)
232 |
233 | @partial(
234 | mesh.sjit,
235 | in_shardings=(PS(),),
236 | out_shardings=sharding_rules,
237 | )
238 | def init_fn(rng):
239 | params = model.init_weights(rng, (bsize, 1024))
240 | return create_train_state_from_params(params)
241 |
242 | train_state_shape = jax.eval_shape(lambda: init_fn(jax.random.PRNGKey(0)))
243 | shard_train_state_fns, gather_train_state_fns = mesh.make_shard_and_gather_fns(train_state_shape, sharding_rules)
244 |
245 | if 'params' in model_paths:
246 | train_state = create_train_state_from_params(load_checkpoint(
247 | model_paths['params'],
248 | shard_fns=shard_train_state_fns.params,
249 | remove_dict_prefix=model_paths['remove_dict_prefix'],
250 | convert_to_dtypes=jax.tree_util.tree_map(lambda x: x.dtype, train_state_shape.params),
251 | ))
252 | elif 'train_state' in model_paths:
253 | train_state = load_checkpoint(
254 | model_paths['train_state'],
255 | shard_fns=shard_train_state_fns,
256 | remove_dict_prefix=model_paths['remove_dict_prefix'],
257 | convert_to_dtypes=jax.tree_util.tree_map(lambda x: x.dtype, train_state_shape),
258 | )
259 | else:
260 | train_state = init_fn(jax.random.PRNGKey(0))
261 |
262 | print('model loaded.')
263 |
264 | @partial(
265 | mesh.sjit,
266 | in_shardings=(sharding_rules, PS(),PS()),
267 | out_shardings=(sharding_rules, PS()),
268 | args_sharding_constraint=(sharding_rules, None, PS(('replica', 'fsdp'))),
269 | donate_argnums=(0,),
270 | )
271 | def train_step(train_state, rng, batch):
272 | def loss_and_accuracy(params):
273 | logits = model(
274 | input_ids=batch['input_ids'],
275 | attention_mask=batch['attention_mask'],
276 | position_ids=batch['position_ids'],
277 | params=params,
278 | dropout_rng=rng,
279 | train=True,
280 | ).logits
281 | return cross_entropy_loss_and_accuracy(
282 | logits, batch['target_ids'], batch['loss_masks'],
283 | )
284 | grad_fn = jax.value_and_grad(loss_and_accuracy, has_aux=True)
285 | (loss, accuracy), grads = grad_fn(train_state.params)
286 | train_state = train_state.apply_gradients(grads=grads)
287 | metrics = dict(
288 | loss=loss,
289 | accuracy=accuracy,
290 | learning_rate=optimizer_info['learning_rate_schedule'](train_state.step),
291 | gradient_norm=global_norm(grads),
292 | param_norm=global_norm(train_state.params),
293 | )
294 | return train_state, metrics
295 |
296 | @partial(
297 | mesh.sjit,
298 | in_shardings=(sharding_rules, PS()),
299 | out_shardings=PS(),
300 | args_sharding_constraint=(sharding_rules, PS(('replica', 'fsdp'))),
301 | )
302 | def eval_step(params, batch):
303 | logits = model(
304 | input_ids=batch['input_ids'],
305 | attention_mask=batch['attention_mask'],
306 | position_ids=batch['position_ids'],
307 | params=params,
308 | train=False,
309 | ).logits
310 | loss, accuracy = cross_entropy_loss_and_accuracy(
311 | logits, batch['target_ids'], batch['loss_masks'],
312 | )
313 | metrics = dict(
314 | eval_loss=loss,
315 | eval_accuracy=accuracy,
316 | )
317 | return metrics
318 |
319 | print('loading data ...')
320 | if train_data_path.endswith('jsonl'):
321 | train_examples = []
322 | with open_with_bucket(train_data_path, 'r') as f:
323 | for line in f:
324 | train_examples.append(json.loads(line.strip()))
325 | else:
326 | with open_with_bucket(train_data_path, 'r') as f:
327 | train_examples = json.load(f)
328 | if eval_data_path.endswith('jsonl'):
329 | eval_examples = []
330 | with open_with_bucket(eval_data_path, 'r') as f:
331 | for line in f:
332 | eval_examples.append(json.loads(line.strip()))
333 | else:
334 | with open_with_bucket(eval_data_path, 'r') as f:
335 | eval_examples = json.load(f)
336 | print('done.')
337 |
338 | def data_iterable(data_items, rng, bsize, shuffle=True, loop=True):
339 | while True:
340 | with jax.default_device(jax.devices('cpu')[0]):
341 | idxs = []
342 | for _ in range((bsize + (len(data_items) - 1)) // len(data_items)):
343 | if shuffle:
344 | rng, subrng = jax.random.split(rng)
345 | curr_idxs = jax.random.permutation(subrng, np.arange(len(data_items)))
346 | idxs.extend(curr_idxs.tolist())
347 | else:
348 | curr_idxs = np.arange(len(data_items))
349 | idxs.extend(curr_idxs.tolist())
350 | idxs = np.asarray(idxs)
351 | for batch_idx in range(len(idxs) // bsize):
352 | batch_idxs = idxs[batch_idx*bsize:(batch_idx+1)*bsize]
353 | batch_examples = [data_items[idx] for idx in batch_idxs]
354 | processed_batch_examples = []
355 | for example in batch_examples:
356 | processed_batch_examples.append(process_pretrain_example(
357 | example,
358 | max_length,
359 | tokenizer,
360 | ))
361 | batch = dict()
362 | for key in processed_batch_examples[0]:
363 | batch[key] = np.concatenate([example[key] for example in processed_batch_examples], axis=0)
364 | yield batch
365 | if not loop:
366 | break
367 |
368 | if 'enable' not in logger_config:
369 | logger_config['enable'] = (jax.process_index() == 0)
370 | if 'config_to_log' in logger_config:
371 | logger_config['config_to_log'].update(args_dict)
372 | else:
373 | logger_config['config_to_log'] = args_dict
374 | logger = WandbLogger(
375 | wandb_project,
376 | output_dir=output_dir,
377 | **logger_config,
378 | )
379 |
380 | checkpoint_queue = collections.deque()
381 |
382 | def _save_checkpoint(
383 | train_state,
384 | step,
385 | ):
386 | old_step = None
387 | if (max_checkpoints is not None) and (len(checkpoint_queue) >= max_checkpoints):
388 | old_step = checkpoint_queue.popleft()
389 | if logger.can_save():
390 | print(f'saving checkpoint at step {step} ...')
391 | # delete old checkpoint if max checkpoints is reached
392 | if old_step is not None:
393 | old_path = os.path.join(logger.output_dir, 'checkpoints', f'step_{old_step}')
394 | delete_with_bucket(old_path, recursive=True)
395 |
396 | metadata = dict(
397 | step=step,
398 | args_dict=args_dict,
399 | )
400 |
401 | checkpointer(
402 | path=os.path.join(logger.output_dir, 'checkpoints', f'step_{step}'),
403 | train_state=train_state,
404 | config=config.to_dict(),
405 | gather_fns=gather_train_state_fns,
406 | metadata=metadata,
407 | active=logger.can_save(),
408 | **checkpointer_config,
409 | )
410 |
411 | checkpoint_queue.append(step)
412 |
413 | if logger.can_save():
414 | print('saved.')
415 |
416 | if save_initial_checkpoint:
417 | _save_checkpoint(train_state, 0)
418 |
419 | rng = jax.random.PRNGKey(0)
420 |
421 | rng, eval_iterable_rng = jax.random.split(rng)
422 | rng, subrng = jax.random.split(rng)
423 | train_iterable = data_iterable(train_examples, subrng, bsize, shuffle=shuffle_train_data, loop=True)
424 | for step, train_batch in tqdm(itertools.islice(enumerate(train_iterable), num_train_steps), total=num_train_steps):
425 | rng, subrng = jax.random.split(rng)
426 | train_state, metrics = train_step(train_state, subrng, train_batch)
427 | if log_freq > 0 and ((step+1) % log_freq == 0 or (log_initial_step and step == 0)):
428 | if num_eval_steps > 0:
429 | eval_metric_list = []
430 | eval_iterable = data_iterable(eval_examples, eval_iterable_rng, eval_bsize, shuffle=True, loop=False)
431 | for eval_batch in itertools.islice(eval_iterable, num_eval_steps):
432 | eval_metric_list.append(eval_step(train_state.params, eval_batch))
433 | metrics.update(average_metrics(jax.device_get(eval_metric_list)))
434 | log_metrics = {"step": step+1}
435 | log_metrics.update(metrics)
436 | log_metrics = jax.device_get(log_metrics)
437 | logger.log(log_metrics)
438 | print(log_metrics)
439 |
440 | if save_model_freq > 0 and (step+1) % save_model_freq == 0:
441 | _save_checkpoint(train_state, step+1)
442 |
443 | if save_model_freq > 0 and (num_train_steps not in checkpoint_queue):
444 | _save_checkpoint(train_state, num_train_steps)
445 |
446 | jax_distributed_barrier()
447 | logger.finish()
448 | jax_distributed_barrier()
449 |
450 | # Only have the first worker push to hub to avoid conflicts
451 | if jax.process_index() == 0 and logger.can_save():
452 | import shutil
453 | print("First worker copying final checkpoint to hub...")
454 |
455 | # Create temp directory for checkpoint
456 | temp_dir = tempfile.mkdtemp()
457 | final_ckpt_path = os.path.join(logger.output_dir, 'checkpoints', f'step_{num_train_steps}')
458 |
459 | # Copy checkpoint files to temp dir
460 | if final_ckpt_path.startswith('gcs://'):
461 | with open_with_bucket(os.path.join(final_ckpt_path, 'params.msgpack'), 'rb') as f:
462 | with open(os.path.join(temp_dir, 'params.msgpack'), 'wb') as f_out:
463 | f_out.write(f.read())
464 | with open_with_bucket(os.path.join(final_ckpt_path, 'config.json'), 'rb') as f:
465 | with open(os.path.join(temp_dir, 'config.json'), 'wb') as f_out:
466 | f_out.write(f.read())
467 | else:
468 | shutil.copy2(os.path.join(final_ckpt_path, 'params.msgpack'), temp_dir)
469 | shutil.copy2(os.path.join(final_ckpt_path, 'config.json'), temp_dir)
470 |
471 | # Push to hub
472 | try:
473 | from huggingface_hub import HfApi
474 | api = HfApi()
475 | repo_type = "model"
476 | repo_name = hf_repo_id
477 |
478 | api.create_repo(repo_name, repo_type=repo_type, private=False, exist_ok=True)
479 | api.upload_folder(
480 | folder_path=temp_dir,
481 | repo_id=repo_name,
482 | repo_type=repo_type
483 | )
484 | print("Successfully pushed checkpoint to hub")
485 | except Exception as e:
486 | print(f"Error pushing to hub: {e}")
487 | finally:
488 | # Cleanup temp directory
489 | shutil.rmtree(temp_dir)
490 | if __name__ == "__main__":
491 | tyro.cli(main)
492 |
--------------------------------------------------------------------------------
/supervised-jax/llama3_train/llama_train/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Parallel-Reasoning/APR/8936e8db46bf938242bf5e0a6ebe79ff48ba267a/supervised-jax/llama3_train/llama_train/__init__.py
--------------------------------------------------------------------------------
/supervised-jax/llama3_train/llama_train/optimizer.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from typing import Tuple, NamedTuple, Dict, Any, Optional, Callable
3 | import optax
4 | import jax.numpy as jnp
5 | import jax
6 |
7 | # Adapted from: https://github.com/young-geng/EasyLM/blob/main/EasyLM/optimizers.py
8 |
9 | def warmup_linear_decay_schedule(
10 | init_value: float,
11 | peak_value: float,
12 | warmup_steps: int,
13 | decay_steps: int,
14 | end_value: float = 0.0,
15 | ) -> optax.Schedule:
16 | """Linear warmup followed by linear decay.
17 |
18 | Args:
19 | init_value: Initial value for the scalar to be annealed.
20 | peak_value: Peak value for scalar to be annealed at end of warmup.
21 | warmup_steps: Positive integer, the length of the linear warmup.
22 | decay_steps: Positive integer, the total length of the schedule. Note that
23 | this includes the warmup time, so the number of steps during which linear
24 | decay is applied is `decay_steps - warmup_steps`.
25 | end_value: End value of the scalar to be annealed.
26 | Returns:
27 | schedule: A function that maps step counts to values.
28 | """
29 | schedules = [
30 | optax.linear_schedule(
31 | init_value=init_value,
32 | end_value=peak_value,
33 | transition_steps=warmup_steps),
34 | optax.linear_schedule(
35 | init_value=peak_value,
36 | end_value=end_value,
37 | transition_steps=decay_steps - warmup_steps)]
38 | return optax.join_schedules(schedules, [warmup_steps])
39 |
40 | schedule_by_name = dict(
41 | cos=optax.warmup_cosine_decay_schedule,
42 | linear=warmup_linear_decay_schedule,
43 | )
44 |
45 | def load_adamw_optimizer(
46 | init_lr: float=0.0,
47 | end_lr: float=3e-5,
48 | lr: float=3e-4,
49 | lr_warmup_steps: int=3000,
50 | lr_decay_steps: int=300000,
51 | b1: float=0.9,
52 | b2: float=0.95,
53 | clip_gradient: float=1.0,
54 | weight_decay: float=0.1,
55 | bf16_momentum: bool=False,
56 | multiply_by_parameter_scale: bool=False,
57 | weight_decay_mask: Optional[Callable]=None,
58 | schedule: str='cos',
59 | ) -> Tuple[optax.GradientTransformation, Dict[str, Any]]:
60 | learning_rate_schedule = schedule_by_name[schedule](
61 | init_value=init_lr,
62 | peak_value=lr,
63 | warmup_steps=lr_warmup_steps,
64 | decay_steps=lr_decay_steps,
65 | end_value=end_lr,
66 | )
67 |
68 | optimizer_info = dict(
69 | learning_rate_schedule=learning_rate_schedule,
70 | )
71 |
72 | if multiply_by_parameter_scale:
73 | optimizer = optax.chain(
74 | optax.clip_by_global_norm(clip_gradient),
75 | optax.adafactor(
76 | learning_rate=learning_rate_schedule,
77 | multiply_by_parameter_scale=True,
78 | momentum=b1,
79 | decay_rate=b2,
80 | factored=False,
81 | clipping_threshold=None,
82 | dtype_momentum=jnp.bfloat16 if bf16_momentum else jnp.float32,
83 | ),
84 | optax_add_scheduled_weight_decay(
85 | lambda step: -learning_rate_schedule(step) * weight_decay,
86 | weight_decay_mask,
87 | ),
88 | )
89 | else:
90 | optimizer = optax.chain(
91 | optax.clip_by_global_norm(clip_gradient),
92 | optax.adamw(
93 | learning_rate=learning_rate_schedule,
94 | weight_decay=weight_decay,
95 | b1=b1,
96 | b2=b2,
97 | mask=weight_decay_mask,
98 | mu_dtype=jnp.bfloat16 if bf16_momentum else jnp.float32,
99 | ),
100 | )
101 |
102 | return optimizer, optimizer_info
103 |
104 | def load_palm_optimizer(
105 | lr: float=0.01,
106 | lr_warmup_steps: int=10000,
107 | b1: float=0.9,
108 | b2: float=0.99,
109 | clip_gradient: float=1.0,
110 | weight_decay: float=1e-4,
111 | bf16_momentum: bool=False,
112 | weight_decay_mask: Optional[Callable]=None,
113 | ) -> Tuple[optax.GradientTransformation, Dict[str, Any]]:
114 | def learning_rate_schedule(step):
115 | multiplier = lr / 0.01
116 | return multiplier / jnp.sqrt(jnp.maximum(step, lr_warmup_steps))
117 |
118 | def weight_decay_schedule(step):
119 | multiplier = weight_decay / 1e-4
120 | return -multiplier * jnp.square(learning_rate_schedule(step))
121 |
122 | optimizer_info = dict(
123 | learning_rate_schedule=learning_rate_schedule,
124 | weight_decay_schedule=weight_decay_schedule,
125 | )
126 |
127 | optimizer = optax.chain(
128 | optax.clip_by_global_norm(clip_gradient),
129 | optax.adafactor(
130 | learning_rate=learning_rate_schedule,
131 | multiply_by_parameter_scale=True,
132 | momentum=b1,
133 | decay_rate=b2,
134 | factored=False,
135 | clipping_threshold=None,
136 | dtype_momentum=jnp.bfloat16 if bf16_momentum else jnp.float32,
137 | ),
138 | optax_add_scheduled_weight_decay(
139 | weight_decay_schedule, weight_decay_mask,
140 | ),
141 | )
142 |
143 | return optimizer, optimizer_info
144 |
145 | class OptaxScheduledWeightDecayState(NamedTuple):
146 | count: jnp.ndarray
147 |
148 | def optax_add_scheduled_weight_decay(schedule_fn, mask=None):
149 | """ Apply weight decay with schedule. """
150 |
151 | def init_fn(params):
152 | del params
153 | return OptaxScheduledWeightDecayState(count=jnp.zeros([], jnp.int32))
154 |
155 | def update_fn(updates, state, params):
156 | if params is None:
157 | raise ValueError('Params cannot be None for weight decay!')
158 |
159 | weight_decay = schedule_fn(state.count)
160 | updates = jax.tree_util.tree_map(
161 | lambda g, p: g + weight_decay * p, updates, params
162 | )
163 | return updates, OptaxScheduledWeightDecayState(
164 | count=optax.safe_int32_increment(state.count)
165 | )
166 |
167 | if mask is not None:
168 | return optax.masked(optax.GradientTransformation(init_fn, update_fn), mask)
169 | return optax.GradientTransformation(init_fn, update_fn)
170 |
--------------------------------------------------------------------------------
/supervised-jax/llama3_train/llama_train/serve.py:
--------------------------------------------------------------------------------
1 | from typing import Generator, Any, Optional
2 | import redis
3 | import time
4 | import pickle as pkl
5 | import multiprocessing as mp
6 | from functools import partial
7 | import six
8 | import json
9 | from collections import OrderedDict
10 | # config for server
11 |
12 | class Config:
13 | redis_host = 'localhost'
14 | redis_port = 6379
15 | redis_db = 0
16 | client_refresh_delay = 0.01
17 | self_indicator = '__self__'
18 | init_message = '__init_message__'
19 | sse_channel_prefix = "__sse_channel__"
20 | sse_exit_type = "__EXIT__"
21 |
22 | """
23 | =====
24 | Below is the code for running a class on a separate process.
25 | Each call to a method on the class is executed through a queue.
26 | This is useful when serving models that should process one request at a time.
27 | =====
28 | """
29 |
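# Usage sketch (illustrative only; the class and method below are hypothetical,
# and a local redis-server must be running, see Config above):
#
#   class Model:
#       def generate(self, prompt: str) -> str:
#           return prompt.upper()
#
#   ServedModel = serve_class(Model)
#   model = ServedModel()        # spawns the background server process
#   print(model.generate("hi"))  # each call is queued and served one at a time
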
30 | def serve_class(cls):
31 | cache_cls = pkl.dumps(cls)
32 |
33 | class WrappedModel:
34 | def __init__(self, *args, **kwargs):
35 | self.r = redis.Redis(host=Config.redis_host, port=Config.redis_port, db=Config.redis_db)
36 | self.Q = initalize_server(self, super().__getattribute__('r'), cache_cls, args, kwargs)
37 |
38 | def __getattribute__(self, name):
39 | return partial(build_method(name, super().__getattribute__('r'), super().__getattribute__('Q')), self)
40 |
41 | def __call__(self, *args, **kwargs):
42 | return build_method('__call__', super().__getattribute__('r'), super().__getattribute__('Q'))(self, *args, **kwargs)
43 |
44 | def __getitem__(self, key):
45 | return build_method('__getitem__', super().__getattribute__('r'), super().__getattribute__('Q'))(self, key)
46 |
47 | def __setitem__(self, key, value):
48 | return build_method('__setitem__', super().__getattribute__('r'), super().__getattribute__('Q'))(self, key, value)
49 |
50 | def __contains__(self, key):
51 | return build_method('__contains__', super().__getattribute__('r'), super().__getattribute__('Q'))(self, key)
52 |
53 | def __len__(self):
54 | return build_method('__len__', super().__getattribute__('r'), super().__getattribute__('Q'))(self)
55 |
56 | return WrappedModel
57 |
58 | def build_method(method, r, Q):
59 | def call_method(self, *args, **kwargs):
60 | request_id = int(r.incr('request_id_counter'))
61 | Q.put((request_id, method, args, kwargs,))
62 | while not r.exists(f'result_{request_id}'):
63 | time.sleep(Config.client_refresh_delay)
64 | result = pkl.loads(r.get(f'result_{request_id}'))
65 | r.delete(f'result_{request_id}')
66 | if result == Config.self_indicator:
67 | return self
68 | return result
69 | return call_method
70 |
71 | def server_process(Q, cls_pkl, args, kwargs):
72 | r = redis.Redis(host=Config.redis_host, port=Config.redis_port, db=Config.redis_db)
73 | model = pkl.loads(cls_pkl)(*args, **kwargs)
74 | while True:
75 | try:
76 | request_id, method, args, kwargs = Q.get()
77 | if method == Config.init_message:
78 | r.set(f'result_{request_id}', pkl.dumps(method))
79 | continue
80 | result = getattr(model, method)(*args, **kwargs)
81 | if isinstance(result, Generator):
82 | result = tuple(result)
83 | if result == model:
84 | result = Config.self_indicator
85 | r.set(f'result_{request_id}', pkl.dumps(result))
86 | except EOFError:
87 | return
88 | except Exception as e:
89 | raise
90 |
91 | def initalize_server(self, r, cls_pkl, args, kwargs):
92 | Q = mp.Manager().Queue()
93 | p = mp.Process(target=server_process, args=(Q, cls_pkl, args, kwargs))
94 | p.start()
95 | build_method(Config.init_message, r, Q)(self)
96 | return Q
97 |
98 |
--------------------------------------------------------------------------------
/supervised-jax/llama3_train/llama_train/splash.py:
--------------------------------------------------------------------------------
1 | # adapted from: https://github.com/stanford-crfm/levanter/blob/main/src/levanter/models/attention.py
2 | from typing import Optional
3 | import jax.numpy as jnp
4 | import warnings
5 | import jax
6 | from jax.experimental.shard_map import shard_map
7 | from jax.sharding import PartitionSpec as PS
8 | from scalax.sharding import MeshShardingHelper
9 | import functools
10 |
11 | # CF https://github.com/google/maxtext/blob/db31dd4b0b686bca4cd7cf940917ec372faa183a/MaxText/layers/attentions.py#L179
12 | def _tpu_splash_attention(
13 | query: jnp.ndarray,
14 | key: jnp.ndarray,
15 | value: jnp.ndarray,
16 | attention_mask: jnp.ndarray,
17 | dropout: float = 0.0,
18 | *,
19 | attention_dtype: Optional[jnp.dtype] = None,
20 | block_size: Optional[int] = None,
21 | ) -> Optional[jnp.ndarray]:
22 | from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_kernel, splash_attention_mask
23 |
24 | # Splash attention requires BHSD format
25 | # We need to reshape the input to match this format
26 | if dropout != 0.0:
27 | raise NotImplementedError("Splash attention does not support dropout")
28 |
29 | if attention_dtype is not None and attention_dtype != jnp.float32:
30 | warnings.warn("Splash attention only supports float32. Switching to float32.")
31 |
32 | attention_dtype = jnp.float32
33 |
34 | B, Sq, Hq, D = query.shape
35 | Bk, Sk, Hk, Dk = key.shape
36 |
37 | # pre-divide q_ by sqrt(d) to match the reference implementation
38 | query = query / jnp.sqrt(D)
39 |
40 | # the key/value sequence length must be a multiple of 128 for the splash kernel
41 | if Sk % 128 != 0:
42 | raise NotImplementedError(f"Splash attention requires KPos to be a multiple of 128, got {Sk}")
43 |
44 | if block_size is not None and block_size % 128 != 0:
45 | raise NotImplementedError(f"Splash attention requires block_size to be a multiple of 128, got {block_size}")
46 |
47 | # TODO: must Dk == Dv?
48 | if key.shape != value.shape:
49 | raise ValueError("k and v must have the same axes")
50 |
51 | # TODO: this isn't really necessary on TPU?
52 | if B != Bk:
53 | raise ValueError(f"Batch axes must be the same for q, k, and v: {B} != {Bk}")
54 |
55 | if D != Dk:
56 | raise ValueError(f"Embedding axes must be the same for q, k, and v: {D} != {Dk}")
57 |
58 | # MaxText uses a block size of 512
59 | block_size = block_size or 512
60 |
61 | # copied from MaxText
62 | @functools.partial(
63 | shard_map,
64 | mesh=MeshShardingHelper.get_global_mesh(),
65 | in_specs=(
66 | PS(("dp", "fsdp"), "mp", None, None),
67 | PS(("dp", "fsdp"), "mp", None, None),
68 | PS(("dp", "fsdp"), "mp", None, None),
69 | PS(("dp", "fsdp"), None),
70 | ),
71 | out_specs=PS(("dp", "fsdp"), "mp", None, None),
72 | check_rep=False,
73 | )
74 | def wrap_flash_attention(q, k, v, attention_mask):
75 | block_sizes = splash_attention_kernel.BlockSizes(
76 | block_q=min(block_size, Sq),
77 | block_kv_compute=min(block_size, Sk),
78 | block_kv=min(block_size, Sk),
79 | block_q_dkv=min(block_size, Sq),
80 | block_kv_dkv=min(block_size, Sk),
81 | block_kv_dkv_compute=min(block_size, Sq),
82 | block_q_dq=min(block_size, Sq),
83 | block_kv_dq=min(block_size, Sq),
84 | )
85 |
86 | segment_ids = splash_attention_kernel.SegmentIds(
87 | q=attention_mask,
88 | kv=attention_mask,
89 | )
90 |
91 | kernel_mask = splash_attention_mask.MultiHeadMask(
92 | [splash_attention_mask.CausalMask((Sq, Sq)) for _ in range(Hq)],
93 | )
94 |
95 | # copied from MaxText
96 | splash_kernel = splash_attention_kernel.make_splash_mha(
97 | mask=kernel_mask, head_shards=1, q_seq_shards=1, block_sizes=block_sizes
98 | )
99 |
100 | q = q.astype(attention_dtype)
101 | k = k.astype(attention_dtype)
102 | v = v.astype(attention_dtype)
103 | return jax.vmap(splash_kernel)(q, k, v, segment_ids=segment_ids)
104 |
105 | query = query.transpose(0, 2, 1, 3)
106 | key = key.transpose(0, 2, 1, 3)
107 | value = value.transpose(0, 2, 1, 3)
108 | attn_output = wrap_flash_attention(query, key, value, attention_mask)
109 | attn_output = attn_output.transpose(0, 2, 1, 3)
110 | return attn_output
111 |
--------------------------------------------------------------------------------
/supervised-jax/llama3_train/llama_train/utils.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, Any, List
2 | import os
3 | import gcsfs
4 | import jax
5 | from flax.serialization import from_bytes, to_state_dict, from_state_dict, to_bytes
6 | from flax.traverse_util import flatten_dict, unflatten_dict, empty_node
7 | import msgpack
8 | import jax.numpy as jnp
9 | from functools import partial
10 | import re
11 | from optax import softmax_cross_entropy_with_integer_labels
12 | import wandb
13 | import uuid
14 | from socket import gethostname
15 | import tempfile
16 | from jax import lax
17 |
18 | GCLOUD_TOKEN_PATH = os.environ.get('GCLOUD_TOKEN_PATH', None)
19 | GCLOUD_PROJECT = os.environ.get('GCLOUD_PROJECT', None)
20 |
21 | def open_with_bucket(
22 | path: Any,
23 | mode: str="rb",
24 | gcloud_project: Optional[str]=None,
25 | gcloud_token: Optional[Any]=None,
26 | **kwargs,
27 | ):
28 | # backup to env vars if None
29 | if gcloud_project is None:
30 | gcloud_project = GCLOUD_PROJECT
31 | if gcloud_token is None:
32 | gcloud_token = GCLOUD_TOKEN_PATH
33 | # load from google cloud storage if starts with "gcs://"
34 | if path.startswith('gcs://'):
35 | f = gcsfs.GCSFileSystem(project=gcloud_project, token=gcloud_token).open(path[len('gcs://'):], mode=mode, **kwargs)
36 | else:
37 | f = open(path, mode=mode, **kwargs)
38 | return f
39 |
40 | def delete_with_bucket(
41 | path: str,
42 | recursive: bool=True,
43 | gcloud_project: Optional[str]=None,
44 | gcloud_token: Optional[Any]=None,
45 | ) -> None:
46 | # backup to env vars if None
47 | if gcloud_project is None:
48 | gcloud_project = GCLOUD_PROJECT
49 | if gcloud_token is None:
50 | gcloud_token = GCLOUD_TOKEN_PATH
51 | # delete from google cloud storage if starts with "gcs://"
52 | if path.startswith('gcs://'):
53 | path = path[len('gcs://'):]
54 | gcsfs.GCSFileSystem(project=gcloud_project, token=gcloud_token).rm(path, recursive=recursive)
55 | else:
56 | os.system(f"rm -{'r' if recursive else ''}f {path}")
57 |
58 | def get_gradient_checkpoint_policy(name):
59 | return {
60 | 'everything_saveable': jax.checkpoint_policies.everything_saveable,
61 | 'nothing_saveable': jax.checkpoint_policies.nothing_saveable,
62 | 'checkpoint_dots': jax.checkpoint_policies.checkpoint_dots,
63 | 'checkpoint_dots_with_no_batch_dims': jax.checkpoint_policies.checkpoint_dots_with_no_batch_dims,
64 | }[name]
65 |
66 | def get_float_dtype_by_name(dtype):
67 | return {
68 | 'bf16': jnp.bfloat16,
69 | 'bfloat16': jnp.bfloat16,
70 | 'fp16': jnp.float16,
71 | 'float16': jnp.float16,
72 | 'fp32': jnp.float32,
73 | 'float32': jnp.float32,
74 | 'fp64': jnp.float64,
75 | 'float64': jnp.float64,
76 | }[dtype]
77 |
78 |
79 | def float_tensor_to_dtype(tensor, dtype):
80 | if dtype is None or dtype == '':
81 | return tensor
82 | if isinstance(dtype, str):
83 | dtype = get_float_dtype_by_name(dtype)
84 | float_dtypes = (jnp.bfloat16, jnp.float16, jnp.float32, jnp.float64)
85 | if getattr(tensor, 'dtype', None) in float_dtypes:
86 | tensor = tensor.astype(dtype)
87 | return tensor
88 |
89 |
90 | def float_to_dtype(tree, dtype):
91 | return jax.tree_util.tree_map(
92 | partial(float_tensor_to_dtype, dtype=dtype), tree
93 | )
94 |
95 | def load_checkpoint(path, target=None, shard_fns=None, remove_dict_prefix=None, convert_to_dtypes=None):
96 | if shard_fns is not None:
97 | shard_fns = flatten_dict(
98 | to_state_dict(shard_fns)
99 | )
100 | if convert_to_dtypes is not None:
101 | convert_to_dtypes = flatten_dict(
102 | to_state_dict(convert_to_dtypes)
103 | )
104 | if remove_dict_prefix is not None:
105 | remove_dict_prefix = tuple(remove_dict_prefix)
106 | flattend_train_state = {}
107 | with open_with_bucket(path, 'rb') as fin:
108 | # 83886080 bytes = 80 MB, which is 16 blocks on GCS
109 | unpacker = msgpack.Unpacker(fin, read_size=83886080, max_buffer_size=0)
110 | for key, value in unpacker:
111 | key = tuple(key)
112 | if remove_dict_prefix is not None:
113 | if key[:len(remove_dict_prefix)] == remove_dict_prefix:
114 | key = key[len(remove_dict_prefix):]
115 | else:
116 | continue
117 |
118 | tensor = from_bytes(None, value)
119 | if convert_to_dtypes is not None:
120 | tensor = float_tensor_to_dtype(tensor, convert_to_dtypes[key])
121 | if shard_fns is not None:
122 | tensor = shard_fns[key](tensor)
123 | flattend_train_state[key] = tensor
124 |
125 | if target is not None:
126 | flattened_target = flatten_dict(
127 | to_state_dict(target), keep_empty_nodes=True
128 | )
129 | for key, value in flattened_target.items():
130 | if key not in flattend_train_state and value == empty_node:
131 | flattend_train_state[key] = value
132 |
133 | train_state = unflatten_dict(flattend_train_state)
134 | if target is None:
135 | return train_state
136 |
137 | return from_state_dict(target, train_state)
138 |
139 | def save_checkpoint(train_state, path, gather_fns=None, float_dtype=None):
140 | train_state = to_state_dict(train_state)
141 | packer = msgpack.Packer()
142 | flattend_train_state = flatten_dict(train_state)
143 | if gather_fns is not None:
144 | gather_fns = flatten_dict(to_state_dict(gather_fns))
145 |
146 | with open_with_bucket(path, "wb") as fout:
147 | for key, value in flattend_train_state.items():
148 | if gather_fns is not None:
149 | value = gather_fns[key](value)
150 | value = float_tensor_to_dtype(value, float_dtype)
151 | fout.write(packer.pack((key, to_bytes(value))))
152 |
153 | def tree_path_to_string(path, sep=None):
154 | keys = []
155 | for key in path:
156 | if isinstance(key, jax.tree_util.SequenceKey):
157 | keys.append(str(key.idx))
158 | elif isinstance(key, jax.tree_util.DictKey):
159 | keys.append(str(key.key))
160 | elif isinstance(key, jax.tree_util.GetAttrKey):
161 | keys.append(str(key.name))
162 | elif isinstance(key, jax.tree_util.FlattenedIndexKey):
163 | keys.append(str(key.key))
164 | else:
165 | keys.append(str(key))
166 | if sep is None:
167 | return tuple(keys)
168 | return sep.join(keys)
169 |
170 | def named_tree_map(f, tree, *rest, is_leaf=None, sep=None):
171 | """ An extended version of jax.tree_util.tree_map, where the mapped function
172 | f takes both the name (path) and the tree leaf as input.
173 | """
174 | return jax.tree_util.tree_map_with_path(
175 | lambda path, x, *r: f(tree_path_to_string(path, sep=sep), x, *r),
176 | tree, *rest,
177 | is_leaf=is_leaf
178 | )
179 |
180 | def get_weight_decay_mask(exclusions):
181 | """ Return a weight decay mask function that computes the pytree masks
182 | according to the given exclusion rules.
183 | """
184 | def decay(name, _):
185 | for rule in exclusions:
186 | if re.search(rule, name) is not None:
187 | return False
188 | return True
189 |
190 | def weight_decay_mask(params):
191 | return named_tree_map(decay, params, sep='/')
192 |
193 | return weight_decay_mask
194 |
195 | def cross_entropy_loss_and_accuracy(logits, tokens, valid=None):
196 | if valid is None:
197 | valid = jnp.ones(tokens.shape[:2])
198 | valid = valid.astype(jnp.float32)
199 | logits = logits.astype(jnp.float32) # for numerical stability
200 | token_loss = softmax_cross_entropy_with_integer_labels(logits, tokens)
201 | loss = jnp.mean(token_loss, where=valid > 0.0)
202 | accuracy = jnp.mean(jnp.argmax(logits, axis=-1) == tokens, where=valid > 0.0)
203 | return loss, accuracy
204 |
205 | def global_norm(tree):
206 | """ Return the global L2 norm of a pytree. """
207 | squared = jax.tree_util.tree_map(lambda x: jnp.sum(jnp.square(x)), tree)
208 | flattened, _ = jax.flatten_util.ravel_pytree(squared)
209 | return jnp.sqrt(jnp.sum(flattened))
210 |
211 | def average_metrics(metrics):
212 | return jax.tree_util.tree_map(
213 | lambda *args: jnp.mean(jnp.stack(args)),
214 | *metrics
215 | )
216 |
217 | def flatten_config_dict(config, prefix=None):
218 | output = {}
219 | for key, val in config.items():
220 | if isinstance(val, dict):
221 | output.update(flatten_config_dict(val, prefix=key))
222 | else:
223 | if prefix is not None:
224 | output["{}.{}".format(prefix, key)] = val
225 | else:
226 | output[key] = val
227 | return output
228 |
229 | # logger adapted from mlxu: https://github.com/young-geng/mlxu/blob/main/mlxu/logging.py
230 |
231 | class WandbLogger:
232 | def __init__(
233 | self,
234 | project: str,
235 | output_dir: Optional[str]=None,
236 | config_to_log: Optional[Any]=None,
237 | enable: bool=True,
238 | online: bool=False,
239 | prefix: Optional[str]=None,
240 | experiment_id: Optional[str]=None,
241 | wandb_dir: Optional[str]=None,
242 | notes: Optional[str]=None,
243 | entity: Optional[str]=None,
244 | prefix_to_id: bool=False,
245 | ):
246 | self.enable = enable
247 | self.notes = notes
248 | self.entity = entity
249 | self.project = project
250 | self.online = online
251 | self.experiment_id = experiment_id
252 |
253 | if self.experiment_id is None:
254 | self.experiment_id = f'{uuid.uuid4().hex}-{uuid.uuid1().hex}'
255 | if prefix is not None:
256 | if prefix_to_id:
257 | self.experiment_id = f"{prefix}--{self.experiment_id}"
258 | else:
259 | self.project = f"{prefix}--{self.project}"
260 |
261 | self.wandb_dir = wandb_dir
262 | self.output_dir = output_dir
263 | if self.enable:
264 | if self.output_dir is not None:
265 | self.output_dir = os.path.join(self.output_dir, self.experiment_id)
266 | if not self.output_dir.startswith('gcs://'):
267 | os.makedirs(self.output_dir, exist_ok=True)
268 | if self.wandb_dir is None:
269 | if (self.output_dir is not None) and (not self.output_dir.startswith('gcs://')):
270 | self.wandb_dir = self.output_dir
271 | else:
272 | assert not self.wandb_dir.startswith('gcs://')
273 | self.wandb_dir = os.path.join(self.wandb_dir, self.experiment_id)
274 | os.makedirs(self.wandb_dir, exist_ok=True)
275 |
276 | if config_to_log is not None:
277 | self.config_to_log = flatten_config_dict(config_to_log)
278 | if "hostname" not in self.config_to_log:
279 | self.config_to_log["hostname"] = gethostname()
280 | if "experiment_id" not in self.config_to_log:
281 | self.config_to_log["experiment_id"] = self.experiment_id
282 | if "logger_output_dir" not in self.config_to_log:
283 | self.config_to_log["logger_output_dir"] = self.output_dir
284 | if "wandb_dir" not in self.config_to_log:
285 | self.config_to_log["wandb_dir"] = self.wandb_dir
286 | else:
287 | self.config_to_log = None
288 |
289 | if self.enable:
290 | self.run = wandb.init(
291 | reinit=True,
292 | config=self.config_to_log,
293 | project=self.project,
294 | dir=self.wandb_dir,
295 | id=self.experiment_id,
296 | notes=self.notes,
297 | entity=self.entity,
298 | settings=wandb.Settings(
299 | start_method="thread",
300 | _disable_stats=True,
301 | ),
302 | mode="online" if self.online else "offline",
303 | )
304 | else:
305 | self.run = None
306 |
307 | def log(self, *args, **kwargs):
308 | if self.enable:
309 | self.run.log(*args, **kwargs)
310 |
311 | def finish(self):
312 | if self.enable:
313 | wandb.finish()
314 |
315 | def can_save(self) -> bool:
316 | return self.enable and (self.output_dir is not None)
317 |
318 | def jax_distributed_initalize(
319 | initialize_jax_distributed: bool=False,
320 | local_device_ids: Optional[List[int]]=None,
321 | coordinator_address: Optional[str]=None,
322 | num_processes: Optional[int]=None,
323 | process_id: Optional[int]=None,
324 | ):
325 | if initialize_jax_distributed:
326 | if local_device_ids is not None:
327 | local_device_ids = [int(x) for x in local_device_ids.split(',')]
328 | else:
329 | local_device_ids = None
330 |
331 | jax.distributed.initialize(
332 | coordinator_address=coordinator_address,
333 | num_processes=num_processes,
334 | process_id=process_id,
335 | local_device_ids=local_device_ids,
336 | )
337 |
338 | def jax_distributed_barrier():
339 | # Dummy function that all processes run
340 | def computation(x):
341 | result = x * x
342 | return result
343 |
344 | @partial(jax.pmap, axis_name='i')
345 | def sync_barrier(x):
346 | # Perform a trivial collective operation, acting as a barrier
347 | c = lax.psum(x, axis_name='i')
348 | return computation(x) + computation(c)
349 |
350 | # Dummy input
351 | x = jnp.ones((jax.local_device_count(),))
352 |
353 | # Run the barrier + computation
354 | results = sync_barrier(x)
355 |
356 | jax.block_until_ready(results)
357 |
--------------------------------------------------------------------------------
/supervised-jax/llama3_train/llama_train_script.py:
--------------------------------------------------------------------------------
1 | from typing import List, Dict, Any, Optional
2 | import json
3 | from functools import partial
4 | import os
5 | import collections
6 | import tempfile
7 |
8 | import tyro
9 | import jax
10 | from scalax.sharding import MeshShardingHelper, TreePathShardingRule
11 | from flax.training.train_state import TrainState
12 | import numpy as np
13 | import itertools
14 | import pickle as pkl
15 | from transformers import AutoTokenizer
16 | from tqdm.auto import tqdm
17 | import optax
18 | from jax.sharding import PartitionSpec as PS
19 |
20 | from llama_train.utils import (
21 | load_checkpoint, open_with_bucket, save_checkpoint,
22 | get_float_dtype_by_name, get_weight_decay_mask,
23 | cross_entropy_loss_and_accuracy, global_norm,
24 | average_metrics, WandbLogger, delete_with_bucket,
25 | jax_distributed_initalize, jax_distributed_barrier
26 | )
27 | from llama_train.llama3 import (
28 | LLaMAConfig, LLAMA_STANDARD_CONFIGS,
29 | FlaxLLaMAForCausalLM, download_openllama_easylm,
30 | )
31 | from llama_train.optimizer import load_adamw_optimizer, load_palm_optimizer
32 |
33 | def process_pretrain_example(
34 | seq: str,
35 | max_length: int,
36 | tokenizer: AutoTokenizer,
37 | ):
38 | tokenization = tokenizer(
39 | [tokenizer.bos_token+seq+tokenizer.eos_token],
40 | padding='max_length',
41 | truncation=True,
42 | max_length=max_length+1,
43 | return_tensors='np',
44 | )
45 |
46 | input_ids = tokenization.input_ids[:, :-1]
47 | target_ids = tokenization.input_ids[:, 1:]
48 | attention_mask = tokenization.attention_mask[:, :-1]
49 | position_ids = np.maximum(np.cumsum(attention_mask, axis=-1) - 1, 0)
50 | loss_masks = attention_mask
51 |
52 | batch_items = dict(
53 | input_ids=input_ids,
54 | attention_mask=attention_mask,
55 | position_ids=position_ids,
56 | target_ids=target_ids,
57 | loss_masks=loss_masks,
58 | )
59 | return batch_items
60 |
61 |
62 | def checkpointer(
63 | path: str,
64 | train_state: Any,
65 | config: Any,
66 | gather_fns: Any,
67 | metadata: Any=None,
68 | save_optimizer_state: bool=False,
69 | save_float_dtype: str='bf16',
70 | active=True,
71 | ):
72 | if not path.startswith('gcs://'):
73 | os.makedirs(path, exist_ok=True)
74 | if save_optimizer_state:
75 | checkpoint_state = train_state
76 | if not active:
77 | checkpoint_path = '/dev/null'
78 | else:
79 | checkpoint_path = os.path.join(path, 'train_state.msgpack')
80 | checkpoint_gather_fns = gather_fns
81 | else:
82 | checkpoint_state = train_state.params
83 | if not active:
84 | checkpoint_path = '/dev/null'
85 | else:
86 | checkpoint_path = os.path.join(path, 'params.msgpack')
87 | checkpoint_gather_fns = gather_fns.params
88 | metadata_path = os.path.join(path, 'metadata.pkl')
89 | config_path = os.path.join(path, 'config.json')
90 |
91 | save_checkpoint(
92 | checkpoint_state,
93 | checkpoint_path,
94 | gather_fns=checkpoint_gather_fns,
95 | float_dtype=save_float_dtype,
96 | )
97 | if active:
98 | with open_with_bucket(metadata_path, 'wb') as f:
99 | pkl.dump(metadata, f)
100 | with open_with_bucket(config_path, 'w') as f:
101 | json.dump(config, f)
102 |
103 | def main(
104 | load_model: str,
105 | train_data_path: str,
106 | eval_data_path: str,
107 | output_dir: Optional[str],
108 | sharding: str,
109 | num_train_steps: int,
110 | max_length: int,
111 | bsize: int,
112 | log_freq: int,
113 | num_eval_steps: int,
114 | save_model_freq: int,
115 | wandb_project: str,
116 | param_dtype: str='fp32',
117 | activation_dtype: str='fp32',
118 | optim_config: str='adamw:{}',
119 | logger_config: str='{}',
120 | checkpointer_config: str='{}',
121 | model_config_override: str='{}',
122 | inputs_tokenizer_override: str='{}',
123 | outputs_tokenizer_override: str='{}',
124 | jax_distributed_initalize_config: str='{}',
125 | save_initial_checkpoint: bool=False,
126 | log_initial_step: bool=True,
127 | max_checkpoints: Optional[int]=None,
128 | eval_bsize: Optional[int]=None,
129 | physical_axis_splitting: bool=False,
130 | shuffle_train_data: bool=True,
131 | hf_repo_id: str='LM-Parallel/sample',
132 | ):
133 | args_dict = dict(locals())
134 | print(args_dict)
135 | sharding: List[int] = list(map(lambda x: int(x.strip()), sharding.split(',')))
136 | if eval_bsize is None:
137 | eval_bsize = bsize
138 |
139 | param_dtype = get_float_dtype_by_name(param_dtype)
140 | activation_dtype = get_float_dtype_by_name(activation_dtype)
141 |
142 | logger_config: Dict[str, Any] = json.loads(logger_config)
143 | checkpointer_config: Dict[str, Any] = json.loads(checkpointer_config)
144 | model_config_override: Dict[str, Any] = json.loads(model_config_override)
145 | inputs_tokenizer_override: Dict[str, Any] = json.loads(inputs_tokenizer_override)
146 | outputs_tokenizer_override: Dict[str, Any] = json.loads(outputs_tokenizer_override)
147 | jax_distributed_initalize_config: Dict[str, Any] = json.loads(jax_distributed_initalize_config)
148 |
149 | jax_distributed_initalize(**jax_distributed_initalize_config)
150 | jax_distributed_barrier()
151 |
152 | if optim_config.startswith('adamw:'):
153 | optim_config = json.loads(optim_config[len('adamw:'):])
154 | optim_config['weight_decay_mask'] = get_weight_decay_mask(optim_config.pop('weight_decay_exclusions', tuple()))
155 | grad_accum_steps = optim_config.pop('grad_accum_steps', 1)
156 | optimizer, optimizer_info = load_adamw_optimizer(**optim_config)
157 | elif optim_config.startswith('palm:'):
158 | optim_config = json.loads(optim_config[len('palm:'):])
159 | optim_config['weight_decay_mask'] = get_weight_decay_mask(optim_config.pop('weight_decay_exclusions', tuple()))
160 | grad_accum_steps = optim_config.pop('grad_accum_steps', 1)
161 | optimizer, optimizer_info = load_palm_optimizer(**optim_config)
162 | else:
163 | raise ValueError(f'Unknown optimizer config: {optim_config}')
164 | if grad_accum_steps > 1:
165 | optimizer = optax.MultiSteps(
166 | optimizer,
167 | grad_accum_steps,
168 | )
169 |
170 | mesh = MeshShardingHelper(sharding, ['dp', 'fsdp', 'mp'], mesh_axis_splitting=physical_axis_splitting) # Create a 3D mesh with data, fsdp, and model parallelism axes
171 | with mesh.get_context():
172 | print('mesh:', mesh.mesh)
173 | print('loading model ...')
174 |
175 | if load_model.startswith('paths:'):
176 | model_paths = json.loads(load_model[len('paths:'):])
177 | if not ('remove_dict_prefix' in model_paths):
178 | model_paths['remove_dict_prefix'] = None
179 | else:
180 | raise ValueError(f'Unknown model info type: {load_model}')
181 |
182 | config_is_temp = False
183 | if 'config' in model_paths and model_paths['config'].startswith('gcs://'):
184 | temp_file = tempfile.NamedTemporaryFile('wb', delete=False)
185 | with open_with_bucket(model_paths['config'], 'rb') as f:
186 | temp_file.write(f.read())
187 | temp_file.close()
188 | model_paths['config'] = temp_file.name
189 | config_is_temp = True
190 |
191 | if 'config' in model_paths:
192 | config = LLaMAConfig.from_pretrained(model_paths['config'], **model_config_override)
193 | elif 'default_config_name' in model_paths:
194 | config = LLaMAConfig(**LLAMA_STANDARD_CONFIGS[model_paths['default_config_name']], **model_config_override)
195 | else:
196 | config = LLaMAConfig(**model_config_override)
197 |
198 | if config_is_temp:
199 | os.remove(model_paths['config'])
200 |
201 | model = FlaxLLaMAForCausalLM(config, dtype=activation_dtype, _do_init=False, param_dtype=param_dtype, input_shape=(bsize, 1024))
202 |         # TODO: the 1024 in input_shape is a hardcoded init sequence length; ideally it should be tied to max_length
203 |
204 | tokenizer_is_temp = False
205 | if model_paths['tokenizer'].startswith('gcs://'):
206 | temp_file = tempfile.NamedTemporaryFile('wb', delete=False)
207 | with open_with_bucket(model_paths['tokenizer'], 'rb') as f:
208 | temp_file.write(f.read())
209 | temp_file.close()
210 | model_paths['tokenizer'] = temp_file.name
211 | tokenizer_is_temp = True
212 |
213 | tokenizer_kwargs = dict(
214 | truncation_side='right',
215 | padding_side='right',
216 | )
217 | tokenizer_kwargs.update(outputs_tokenizer_override)
218 | tokenizer = AutoTokenizer.from_pretrained(model_paths['tokenizer'], **tokenizer_kwargs)
219 | tokenizer.add_special_tokens({'pad_token': tokenizer.convert_ids_to_tokens(config.pad_token_id)})
220 |
221 | if tokenizer_is_temp:
222 | os.remove(model_paths['tokenizer'])
223 |
224 | sharding_rules = TreePathShardingRule(*config.get_partition_rules())
225 |
226 | @partial(
227 | mesh.sjit,
228 | in_shardings=(sharding_rules,),
229 | out_shardings=sharding_rules,
230 | )
231 | def create_train_state_from_params(params):
232 | return TrainState.create(params=params, tx=optimizer, apply_fn=None)
233 |
234 | @partial(
235 | mesh.sjit,
236 | in_shardings=(PS(),),
237 | out_shardings=sharding_rules,
238 | )
239 | def init_fn(rng):
240 | params = model.init_weights(rng, (bsize, 1024))
241 | return create_train_state_from_params(params)
242 |
243 | train_state_shape = jax.eval_shape(lambda: init_fn(jax.random.PRNGKey(0)))
244 | shard_train_state_fns, gather_train_state_fns = mesh.make_shard_and_gather_fns(train_state_shape, sharding_rules)
245 |
246 | if 'params' in model_paths:
247 | train_state = create_train_state_from_params(load_checkpoint(
248 | model_paths['params'],
249 | shard_fns=shard_train_state_fns.params,
250 | remove_dict_prefix=model_paths['remove_dict_prefix'],
251 | convert_to_dtypes=jax.tree_util.tree_map(lambda x: x.dtype, train_state_shape.params),
252 | ))
253 | elif 'train_state' in model_paths:
254 | train_state = load_checkpoint(
255 | model_paths['train_state'],
256 | shard_fns=shard_train_state_fns,
257 | remove_dict_prefix=model_paths['remove_dict_prefix'],
258 | convert_to_dtypes=jax.tree_util.tree_map(lambda x: x.dtype, train_state_shape),
259 | )
260 | else:
261 | train_state = init_fn(jax.random.PRNGKey(0))
262 |
263 | print(model)
264 | print('model loaded.')
265 |
266 | @partial(
267 | mesh.sjit,
268 |             in_shardings=(sharding_rules, PS(), PS()),
269 | out_shardings=(sharding_rules, PS()),
270 | args_sharding_constraint=(sharding_rules, None, PS(('dp', 'fsdp'))),
271 | donate_argnums=(0,),
272 | )
273 | def train_step(train_state, rng, batch):
274 | def loss_and_accuracy(params):
275 | logits = model(
276 | input_ids=batch['input_ids'],
277 | attention_mask=batch['attention_mask'],
278 | position_ids=batch['position_ids'],
279 | params=params,
280 | dropout_rng=rng,
281 | train=True,
282 | ).logits
283 | return cross_entropy_loss_and_accuracy(
284 | logits, batch['target_ids'], batch['loss_masks'],
285 | )
286 | # print("start training...")
287 | grad_fn = jax.value_and_grad(loss_and_accuracy, has_aux=True)
288 | (loss, accuracy), grads = grad_fn(train_state.params)
289 | # print(f"loss: {loss}, accuracy: {accuracy}")
290 | train_state = train_state.apply_gradients(grads=grads)
291 | # print("gradients applied.")
292 | metrics = dict(
293 | loss=loss,
294 | accuracy=accuracy,
295 | learning_rate=optimizer_info['learning_rate_schedule'](train_state.step),
296 | gradient_norm=global_norm(grads),
297 | param_norm=global_norm(train_state.params),
298 | )
299 | return train_state, metrics
300 |
301 | @partial(
302 | mesh.sjit,
303 | in_shardings=(sharding_rules, PS()),
304 | out_shardings=PS(),
305 | args_sharding_constraint=(sharding_rules, PS(('dp', 'fsdp'))),
306 | )
307 | def eval_step(params, batch):
308 | logits = model(
309 | input_ids=batch['input_ids'],
310 | attention_mask=batch['attention_mask'],
311 | position_ids=batch['position_ids'],
312 | params=params,
313 | train=False,
314 | ).logits
315 | loss, accuracy = cross_entropy_loss_and_accuracy(
316 | logits, batch['target_ids'], batch['loss_masks'],
317 | )
318 | metrics = dict(
319 | eval_loss=loss,
320 | eval_accuracy=accuracy,
321 | )
322 | return metrics
323 |
324 | print('loading data ...')
325 | if train_data_path.endswith('jsonl'):
326 | train_examples = []
327 | with open_with_bucket(train_data_path, 'r') as f:
328 | for line in f:
329 | train_examples.append(json.loads(line.strip()))
330 | else:
331 | with open_with_bucket(train_data_path, 'r') as f:
332 | train_examples = json.load(f)
333 | if eval_data_path.endswith('jsonl'):
334 | eval_examples = []
335 | with open_with_bucket(eval_data_path, 'r') as f:
336 | for line in f:
337 | eval_examples.append(json.loads(line.strip()))
338 | else:
339 | with open_with_bucket(eval_data_path, 'r') as f:
340 | eval_examples = json.load(f)
341 | print('done.')
342 |
343 | def data_iterable(data_items, rng, bsize, shuffle=True, loop=True):
344 | while True:
345 | with jax.default_device(jax.devices('cpu')[0]):
346 | idxs = []
347 | for _ in range((bsize + (len(data_items) - 1)) // len(data_items)):
348 | if shuffle:
349 | rng, subrng = jax.random.split(rng)
350 | curr_idxs = jax.random.permutation(subrng, np.arange(len(data_items)))
351 | idxs.extend(curr_idxs.tolist())
352 | else:
353 | curr_idxs = np.arange(len(data_items))
354 | idxs.extend(curr_idxs.tolist())
355 | idxs = np.asarray(idxs)
356 | for batch_idx in range(len(idxs) // bsize):
357 | batch_idxs = idxs[batch_idx*bsize:(batch_idx+1)*bsize]
358 | batch_examples = [data_items[idx] for idx in batch_idxs]
359 | processed_batch_examples = []
360 | for example in batch_examples:
361 | processed_batch_examples.append(process_pretrain_example(
362 | example,
363 | max_length,
364 | tokenizer,
365 | ))
366 | batch = dict()
367 | for key in processed_batch_examples[0]:
368 | batch[key] = np.concatenate([example[key] for example in processed_batch_examples], axis=0)
369 | yield batch
370 | if not loop:
371 | break
372 |
373 | if 'enable' not in logger_config:
374 | logger_config['enable'] = (jax.process_index() == 0)
375 | if 'config_to_log' in logger_config:
376 | logger_config['config_to_log'].update(args_dict)
377 | else:
378 | logger_config['config_to_log'] = args_dict
379 | logger = WandbLogger(
380 | wandb_project,
381 | output_dir=output_dir,
382 | **logger_config,
383 | )
384 | print('wandb logger initialized.')
385 |
386 | checkpoint_queue = collections.deque()
387 |
388 | def _save_checkpoint(
389 | train_state,
390 | step,
391 | ):
392 | old_step = None
393 | if (max_checkpoints is not None) and (len(checkpoint_queue) >= max_checkpoints):
394 | old_step = checkpoint_queue.popleft()
395 | if logger.can_save():
396 | print(f'saving checkpoint at step {step} ...')
397 | # delete old checkpoint if max checkpoints is reached
398 | if old_step is not None:
399 | old_path = os.path.join(logger.output_dir, 'checkpoints', f'step_{old_step}')
400 | delete_with_bucket(old_path, recursive=True)
401 |
402 | metadata = dict(
403 | step=step,
404 | args_dict=args_dict,
405 | )
406 |
407 | checkpointer(
408 | path=os.path.join(logger.output_dir, 'checkpoints', f'step_{step}'),
409 | train_state=train_state,
410 | config=config.to_dict(),
411 | gather_fns=gather_train_state_fns,
412 | metadata=metadata,
413 | active=logger.can_save(),
414 | **checkpointer_config,
415 | )
416 |
417 | checkpoint_queue.append(step)
418 |
419 | if logger.can_save():
420 | print('saved.')
421 |
422 | if save_initial_checkpoint:
423 | _save_checkpoint(train_state, 0)
424 |
425 | rng = jax.random.PRNGKey(0)
426 |
427 | rng, eval_iterable_rng = jax.random.split(rng)
428 | rng, subrng = jax.random.split(rng)
429 | train_iterable = data_iterable(train_examples, subrng, bsize, shuffle=shuffle_train_data, loop=True)
430 | for step, train_batch in tqdm(itertools.islice(enumerate(train_iterable), num_train_steps), total=num_train_steps):
431 | rng, subrng = jax.random.split(rng)
432 | train_state, metrics = train_step(train_state, subrng, train_batch)
433 | # print(f"step {step} metrics: {metrics}")
434 | if log_freq > 0 and ((step+1) % log_freq == 0 or (log_initial_step and step == 0)):
435 | if num_eval_steps > 0:
436 | eval_metric_list = []
437 | eval_iterable = data_iterable(eval_examples, eval_iterable_rng, eval_bsize, shuffle=True, loop=False)
438 | for eval_batch in itertools.islice(eval_iterable, num_eval_steps):
439 | eval_metric_list.append(eval_step(train_state.params, eval_batch))
440 | metrics.update(average_metrics(jax.device_get(eval_metric_list)))
441 | log_metrics = {"step": step+1}
442 | log_metrics.update(metrics)
443 | log_metrics = jax.device_get(log_metrics)
444 | logger.log(log_metrics)
445 | print(log_metrics)
446 |
447 | if save_model_freq > 0 and (step+1) % save_model_freq == 0:
448 | _save_checkpoint(train_state, step+1)
449 |
450 | if save_model_freq > 0 and (num_train_steps not in checkpoint_queue):
451 | _save_checkpoint(train_state, num_train_steps)
452 |
453 | jax_distributed_barrier()
454 | logger.finish()
455 | jax_distributed_barrier()
456 |
457 | # Only have the first worker push to hub to avoid conflicts
458 | if jax.process_index() == 0 and logger.can_save():
459 | import shutil
460 | print("First worker copying final checkpoint to hub...")
461 |
462 | # Create temp directory for checkpoint
463 | temp_dir = tempfile.mkdtemp()
464 | final_ckpt_path = os.path.join(logger.output_dir, 'checkpoints', f'step_{num_train_steps}')
465 |
466 | # Copy checkpoint files to temp dir
467 | if final_ckpt_path.startswith('gcs://'):
468 | with open_with_bucket(os.path.join(final_ckpt_path, 'params.msgpack'), 'rb') as f:
469 | with open(os.path.join(temp_dir, 'params.msgpack'), 'wb') as f_out:
470 | f_out.write(f.read())
471 | with open_with_bucket(os.path.join(final_ckpt_path, 'config.json'), 'rb') as f:
472 | with open(os.path.join(temp_dir, 'config.json'), 'wb') as f_out:
473 | f_out.write(f.read())
474 | else:
475 | shutil.copy2(os.path.join(final_ckpt_path, 'params.msgpack'), temp_dir)
476 | shutil.copy2(os.path.join(final_ckpt_path, 'config.json'), temp_dir)
477 |
478 | # Push to hub
479 | try:
480 | from huggingface_hub import HfApi
481 | api = HfApi()
482 | repo_type = "model"
483 | repo_name = hf_repo_id
484 |
485 | api.create_repo(repo_name, repo_type=repo_type, private=False, exist_ok=True)
486 | api.upload_folder(
487 | folder_path=temp_dir,
488 | repo_id=repo_name,
489 | repo_type=repo_type
490 | )
491 | print("Successfully pushed checkpoint to hub")
492 | except Exception as e:
493 | print(f"Error pushing to hub: {e}")
494 | finally:
495 | # Cleanup temp directory
496 | shutil.rmtree(temp_dir)
497 | if __name__ == "__main__":
498 | tyro.cli(main)
499 |
--------------------------------------------------------------------------------
/supervised-jax/llama3_train/requirements.txt:
--------------------------------------------------------------------------------
1 | gcsfs==2023.10.0
2 | jax==0.4.31
3 | transformers==4.47.1
4 | flax==0.10.2
5 | sentencepiece==0.1.99
6 | wget==3.2
7 | jaxtyping==0.2.23
8 | scalax @ git+https://github.com/young-geng/scalax.git@main
9 | tyro==0.8.11
10 | tqdm==4.66.1
11 | wandb==0.16.1
12 | einops==0.8.0
13 | numpy<2.0.0
14 | redis==4.3.4
15 | Flask==3.0.3
16 | flask-cors==5.0.0
--------------------------------------------------------------------------------
/supervised-jax/llama3_train/setup.py:
--------------------------------------------------------------------------------
1 | import os
2 | from setuptools import find_packages, setup
3 |
4 |
5 | def read_requirements_file(filename):
6 | req_file_path = os.path.join(os.path.dirname(os.path.realpath(__file__)),
7 | filename)
8 | with open(req_file_path) as f:
9 | return [line.strip() for line in f if line.strip() != '']
10 |
11 |
12 | setup(
13 | name='llama3_train',
14 | version='1.0.0',
15 | description='LLaMA-3 Train.',
16 | url='https://github.com/Sea-Snell/llama3_train',
17 | author='Charlie Snell',
18 | packages=find_packages(),
19 | install_requires=read_requirements_file('requirements.txt'),
20 | license='LICENCE',
21 | )
--------------------------------------------------------------------------------
/supervised-jax/scripts/hs-v3.sh:
--------------------------------------------------------------------------------
1 | # 12/9/24
2 |
3 | # test gpt2 training, this will train a randomly initialized model
4 |
5 | # charlie-pod
6 | # \"tokenizer\": \"gpt2\",
7 | # \"config\": \"gpt2\"
8 | # gs://
9 |
10 | (
11 | source ~/miniconda3/bin/activate llama3_train
12 | export RUN_NAME="llama-200m-hs-v2"
13 | export GCLOUD_TOKEN_PATH="$HOME/.config/gcloud/civic-boulder-204700-3052e43e8c80.json"
14 | export GCLOUD_PROJECT="civic-boulder-204700"
15 | export HF_TOKEN="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"
16 | export WANDB_API_KEY="53a3e8edb945646eb837622d6422755f5a3131b2"
17 | cd ~/llama3_train
18 | # pip install optax wandb
19 | source ~/miniconda3/bin/activate llama3_train
20 |
21 | TRAIN_STEPS=19000
22 | python llama_train_script.py \
23 | --load_model="paths:{
24 | \"tokenizer\": \"meta-llama/Llama-2-7b-hf\",
25 | \"default_config_name\": \"200m\"
26 | }" \
27 | --train_data_path="gcs://jiayi-eu/data/v3/train_hs.json" \
28 | --eval_data_path="gcs://jiayi-eu/data/v3/val_hs.json" \
29 | --output_dir="gcs://jiayi-eu/lm-parallel-exp/exp-hs/" \
30 | --sharding="-1,1,1" \
31 | --num_train_steps=$TRAIN_STEPS \
32 | --max_length=4096 \
33 | --bsize=256 \
34 | --log_freq=100 \
35 | --num_eval_steps=500 \
36 | --save_model_freq=100000000 \
37 | --wandb_project="sos" \
38 | --param_dtype="fp32" \
39 | --activation_dtype="fp32" \
40 | --optim_config="adamw:{
41 | \"init_lr\": 5e-6,
42 | \"end_lr\": 5e-7,
43 | \"lr\": 5e-5,
44 | \"lr_warmup_steps\": 1,
45 | \"lr_decay_steps\": $TRAIN_STEPS,
46 | \"b1\": 0.9,
47 | \"b2\": 0.999,
48 | \"clip_gradient\": 1.0,
49 | \"weight_decay\": 0.01,
50 | \"bf16_momentum\": false,
51 | \"multiply_by_parameter_scale\": false,
52 | \"weight_decay_exclusions\": [],
53 | \"schedule\": \"cos\",
54 | \"grad_accum_steps\": 1
55 | }" \
56 | --logger_config="{
57 | \"online\": true,
58 | \"prefix\": \"$RUN_NAME\",
59 | \"prefix_to_id\": true
60 | }" \
61 | --checkpointer_config="{
62 | \"save_optimizer_state\": false,
63 | \"save_float_dtype\": \"bf16\"
64 | }" \
65 | --model_config_override="{
66 | \"bos_token_id\": 1,
67 | \"eos_token_id\": 2,
68 | \"pad_token_id\": 0,
69 | \"remat_block\": \"nothing_saveable\"
70 | }" \
71 | --eval_bsize=512 \
72 | --no_shuffle_train_data \
73 | --hf_repo_id="LM-Parallel/llama-hsp-v3"
74 | )
75 |
76 |
77 |
--------------------------------------------------------------------------------
/supervised-jax/scripts/hsp-v3.sh:
--------------------------------------------------------------------------------
1 | # 12/9/24
2 |
3 | # test gpt2 training, this will train a randomly initialized model
4 |
5 | # charlie-pod
6 | # \"tokenizer\": \"gpt2\",
7 | # \"config\": \"gpt2\"
8 | # gs://
9 |
10 | (
11 | source ~/miniconda3/bin/activate llama3_train
12 | export RUN_NAME="llama-200m-hsp-v2"
13 | export GCLOUD_TOKEN_PATH="$HOME/.config/gcloud/civic-boulder-204700-3052e43e8c80.json"
14 | export GCLOUD_PROJECT="civic-boulder-204700"
15 | export HF_TOKEN="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"
16 | export WANDB_API_KEY="53a3e8edb945646eb837622d6422755f5a3131b2"
17 | cd ~/llama3_train
18 | # pip install optax wandb
19 | source ~/miniconda3/bin/activate llama3_train
20 |
21 | TRAIN_STEPS=19000
22 | python llama_train_script.py \
23 | --load_model="paths:{
24 | \"tokenizer\": \"meta-llama/Llama-2-7b-hf\",
25 | \"default_config_name\": \"200m\"
26 | }" \
27 | --train_data_path="gcs://jiayi-eu/data/v3/train_hsp.json" \
28 | --eval_data_path="gcs://jiayi-eu/data/v3/val_hsp.json" \
29 | --output_dir="gcs://jiayi-eu/lm-parallel-exp/exp-hsp/" \
30 | --sharding="-1,1,1" \
31 | --num_train_steps=$TRAIN_STEPS \
32 | --max_length=4096 \
33 | --bsize=256 \
34 | --log_freq=100 \
35 | --num_eval_steps=500 \
36 | --save_model_freq=100000000 \
37 | --wandb_project="sos" \
38 | --param_dtype="fp32" \
39 | --activation_dtype="fp32" \
40 | --optim_config="adamw:{
41 | \"init_lr\": 5e-6,
42 | \"end_lr\": 5e-7,
43 | \"lr\": 5e-5,
44 | \"lr_warmup_steps\": 1,
45 | \"lr_decay_steps\": $TRAIN_STEPS,
46 | \"b1\": 0.9,
47 | \"b2\": 0.999,
48 | \"clip_gradient\": 1.0,
49 | \"weight_decay\": 0.01,
50 | \"bf16_momentum\": false,
51 | \"multiply_by_parameter_scale\": false,
52 | \"weight_decay_exclusions\": [],
53 | \"schedule\": \"cos\",
54 | \"grad_accum_steps\": 1
55 | }" \
56 | --logger_config="{
57 | \"online\": true,
58 | \"prefix\": \"$RUN_NAME\",
59 | \"prefix_to_id\": true
60 | }" \
61 | --checkpointer_config="{
62 | \"save_optimizer_state\": false,
63 | \"save_float_dtype\": \"bf16\"
64 | }" \
65 | --model_config_override="{
66 | \"bos_token_id\": 1,
67 | \"eos_token_id\": 2,
68 | \"pad_token_id\": 0,
69 | \"remat_block\": \"nothing_saveable\"
70 | }" \
71 | --eval_bsize=512 \
72 | --no_shuffle_train_data \
73 | --hf_repo_id="LM-Parallel/llama-hsp-v3"
74 | )
75 |
76 |
77 |
--------------------------------------------------------------------------------
/supervised-jax/scripts/sos-v3.sh:
--------------------------------------------------------------------------------
1 | # 12/9/24
2 |
3 | # test gpt2 training, this will train a randomly initialized model
4 |
5 | # charlie-pod
6 | # \"tokenizer\": \"gpt2\",
7 | # \"config\": \"gpt2\"
8 | # gs://
9 |
10 | (
11 | source ~/miniconda3/bin/activate llama3_train
12 | export RUN_NAME="llama-200m-sos-v2"
13 | export GCLOUD_TOKEN_PATH="$HOME/.config/gcloud/civic-boulder-204700-3052e43e8c80.json"
14 | export GCLOUD_PROJECT="civic-boulder-204700"
15 | export HF_TOKEN="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"
16 | export WANDB_API_KEY="53a3e8edb945646eb837622d6422755f5a3131b2"
17 | cd ~/llama3_train
18 | # pip install optax wandb
19 | source ~/miniconda3/bin/activate llama3_train
20 |
21 | TRAIN_STEPS=19000
22 | python llama_train_script.py \
23 | --load_model="paths:{
24 | \"tokenizer\": \"meta-llama/Llama-2-7b-hf\",
25 | \"default_config_name\": \"200m\"
26 | }" \
27 | --train_data_path="gcs://jiayi-eu/data/v3/train_sos.json" \
28 | --eval_data_path="gcs://jiayi-eu/data/v3/val_sos.json" \
29 | --output_dir="gcs://jiayi-eu/lm-parallel-exp/exp-sos/" \
30 | --sharding="-1,1,1" \
31 | --num_train_steps=$TRAIN_STEPS \
32 | --max_length=4096 \
33 | --bsize=256 \
34 | --log_freq=100 \
35 | --num_eval_steps=500 \
36 | --save_model_freq=100000000 \
37 | --wandb_project="sos" \
38 | --param_dtype="fp32" \
39 | --activation_dtype="fp32" \
40 | --optim_config="adamw:{
41 | \"init_lr\": 5e-6,
42 | \"end_lr\": 5e-7,
43 | \"lr\": 5e-5,
44 | \"lr_warmup_steps\": 1,
45 | \"lr_decay_steps\": $TRAIN_STEPS,
46 | \"b1\": 0.9,
47 | \"b2\": 0.999,
48 | \"clip_gradient\": 1.0,
49 | \"weight_decay\": 0.01,
50 | \"bf16_momentum\": false,
51 | \"multiply_by_parameter_scale\": false,
52 | \"weight_decay_exclusions\": [],
53 | \"schedule\": \"cos\",
54 | \"grad_accum_steps\": 1
55 | }" \
56 | --logger_config="{
57 | \"online\": true,
58 | \"prefix\": \"$RUN_NAME\",
59 | \"prefix_to_id\": true
60 | }" \
61 | --checkpointer_config="{
62 | \"save_optimizer_state\": false,
63 | \"save_float_dtype\": \"bf16\"
64 | }" \
65 | --model_config_override="{
66 | \"bos_token_id\": 1,
67 | \"eos_token_id\": 2,
68 | \"pad_token_id\": 0,
69 | \"remat_block\": \"nothing_saveable\"
70 | }" \
71 | --eval_bsize=512 \
72 | --no_shuffle_train_data \
73 | --hf_repo_id="LM-Parallel/llama-sos-v3"
74 | )
75 |
76 |
77 |
--------------------------------------------------------------------------------
/supervised-jax/scripts/training_hs-v2_llama.sh:
--------------------------------------------------------------------------------
1 | # 12/9/24
2 |
3 | # test gpt2 training, this will train a randomly initialized model
4 |
5 | # charlie-pod
6 | # \"tokenizer\": \"gpt2\",
7 | # \"config\": \"gpt2\"
8 | # gs://
9 |
10 | (
11 | source ~/miniconda3/bin/activate llama3_train
12 | export RUN_NAME="llama-200m-hs-v2"
13 | export GCLOUD_TOKEN_PATH="$HOME/.config/gcloud/civic-boulder-204700-3052e43e8c80.json"
14 | export GCLOUD_PROJECT="civic-boulder-204700"
15 | export HF_TOKEN="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"
16 | export WANDB_API_KEY="53a3e8edb945646eb837622d6422755f5a3131b2"
17 | cd ~/llama3_train
18 | # pip install optax wandb
19 | source ~/miniconda3/bin/activate llama3_train
20 |
21 | TRAIN_STEPS=8000
22 | python llama_train_script.py \
23 | --load_model="paths:{
24 | \"tokenizer\": \"meta-llama/Llama-2-7b-hf\",
25 | \"default_config_name\": \"200m\"
26 | }" \
27 | --train_data_path="gcs://jiayi-eu/data/hs-v2/train.json" \
28 | --eval_data_path="gcs://jiayi-eu/data/hs-v2/test.json" \
29 | --output_dir="gcs://jiayi-eu/lm-parallel-exp/exp-sos/" \
30 | --sharding="-1,1,1" \
31 | --num_train_steps=$TRAIN_STEPS \
32 | --max_length=4096 \
33 | --bsize=256 \
34 | --log_freq=100 \
35 | --num_eval_steps=500 \
36 | --save_model_freq=100000000 \
37 | --wandb_project="sos" \
38 | --param_dtype="fp32" \
39 | --activation_dtype="fp32" \
40 | --optim_config="adamw:{
41 | \"init_lr\": 5e-6,
42 | \"end_lr\": 5e-7,
43 | \"lr\": 5e-5,
44 | \"lr_warmup_steps\": 1,
45 | \"lr_decay_steps\": $TRAIN_STEPS,
46 | \"b1\": 0.9,
47 | \"b2\": 0.999,
48 | \"clip_gradient\": 1.0,
49 | \"weight_decay\": 0.01,
50 | \"bf16_momentum\": false,
51 | \"multiply_by_parameter_scale\": false,
52 | \"weight_decay_exclusions\": [],
53 | \"schedule\": \"cos\",
54 | \"grad_accum_steps\": 1
55 | }" \
56 | --logger_config="{
57 | \"online\": true,
58 | \"prefix\": \"$RUN_NAME\",
59 | \"prefix_to_id\": true
60 | }" \
61 | --checkpointer_config="{
62 | \"save_optimizer_state\": false,
63 | \"save_float_dtype\": \"bf16\"
64 | }" \
65 | --model_config_override="{
66 | \"bos_token_id\": 1,
67 | \"eos_token_id\": 2,
68 | \"pad_token_id\": 0,
69 | \"remat_block\": \"nothing_saveable\"
70 | }" \
71 | --eval_bsize=512 \
72 | --no_shuffle_train_data \
73 | --hf_repo_id="LM-Parallel/llama-hs-v2-8k-step"
74 | )
75 |
76 |
77 |
--------------------------------------------------------------------------------
/supervised-jax/scripts/training_hsp-v2_llama.sh:
--------------------------------------------------------------------------------
1 | # 12/9/24
2 |
3 | # test gpt2 training, this will train a randomly initialized model
4 |
5 | # charlie-pod
6 | # \"tokenizer\": \"gpt2\",
7 | # \"config\": \"gpt2\"
8 | # gs://
9 |
10 | (
11 | source ~/miniconda3/bin/activate llama3_train
12 | export RUN_NAME="llama-200m-hsp-v2"
13 | export GCLOUD_TOKEN_PATH="$HOME/.config/gcloud/civic-boulder-204700-3052e43e8c80.json"
14 | export GCLOUD_PROJECT="civic-boulder-204700"
15 | export HF_TOKEN="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"
16 | export WANDB_API_KEY="53a3e8edb945646eb837622d6422755f5a3131b2"
17 | cd ~/llama3_train
18 | # pip install optax wandb
19 | source ~/miniconda3/bin/activate llama3_train
20 |
21 | TRAIN_STEPS=8000
22 | python llama_train_script.py \
23 | --load_model="paths:{
24 | \"tokenizer\": \"meta-llama/Llama-2-7b-hf\",
25 | \"default_config_name\": \"200m\"
26 | }" \
27 | --train_data_path="gcs://jiayi-eu/data/hsp-v2/train.json" \
28 | --eval_data_path="gcs://jiayi-eu/data/hsp-v2/test.json" \
29 | --output_dir="gcs://jiayi-eu/lm-parallel-exp/exp-sos/" \
30 | --sharding="-1,1,1" \
31 | --num_train_steps=$TRAIN_STEPS \
32 | --max_length=4096 \
33 | --bsize=256 \
34 | --log_freq=100 \
35 | --num_eval_steps=500 \
36 | --save_model_freq=100000000 \
37 | --wandb_project="sos" \
38 | --param_dtype="fp32" \
39 | --activation_dtype="fp32" \
40 | --optim_config="adamw:{
41 | \"init_lr\": 5e-6,
42 | \"end_lr\": 5e-7,
43 | \"lr\": 5e-5,
44 | \"lr_warmup_steps\": 1,
45 | \"lr_decay_steps\": $TRAIN_STEPS,
46 | \"b1\": 0.9,
47 | \"b2\": 0.999,
48 | \"clip_gradient\": 1.0,
49 | \"weight_decay\": 0.01,
50 | \"bf16_momentum\": false,
51 | \"multiply_by_parameter_scale\": false,
52 | \"weight_decay_exclusions\": [],
53 | \"schedule\": \"cos\",
54 | \"grad_accum_steps\": 1
55 | }" \
56 | --logger_config="{
57 | \"online\": true,
58 | \"prefix\": \"$RUN_NAME\",
59 | \"prefix_to_id\": true
60 | }" \
61 | --checkpointer_config="{
62 | \"save_optimizer_state\": false,
63 | \"save_float_dtype\": \"bf16\"
64 | }" \
65 | --model_config_override="{
66 | \"bos_token_id\": 1,
67 | \"eos_token_id\": 2,
68 | \"pad_token_id\": 0,
69 | \"remat_block\": \"nothing_saveable\"
70 | }" \
71 | --eval_bsize=512 \
72 | --no_shuffle_train_data \
73 | --hf_repo_id="LM-Parallel/llama-hsp-v2-8k-step"
74 | )
75 |
76 |
77 |
--------------------------------------------------------------------------------
/supervised-jax/scripts/training_sos_llama.sh:
--------------------------------------------------------------------------------
1 | # 12/9/24
2 |
3 | # test gpt2 training, this will train a randomly initialized model
4 |
5 | # charlie-pod
6 | # \"tokenizer\": \"gpt2\",
7 | # \"config\": \"gpt2\"
8 | # gs://
9 |
10 | (
11 | source ~/miniconda3/bin/activate llama3_train
12 | export RUN_NAME="llama-300m-standard-500k-bs256-valid"
13 | export GCLOUD_TOKEN_PATH="$HOME/.config/gcloud/civic-boulder-204700-3052e43e8c80.json"
14 | export GCLOUD_PROJECT="civic-boulder-204700"
15 | export HF_TOKEN="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"
16 | export WANDB_API_KEY="53a3e8edb945646eb837622d6422755f5a3131b2"
17 | cd ~/llama3_train
18 | # pip install optax wandb
19 | source ~/miniconda3/bin/activate llama3_train
20 |
21 | TRAIN_STEPS=19000
22 | python llama_train_script.py \
23 | --load_model="paths:{
24 | \"tokenizer\": \"meta-llama/Llama-2-7b-hf\",
25 | \"default_config_name\": \"300m\"
26 | }" \
27 | --train_data_path="gcs://jiayi-eu/data/sos-jan12-xiuyu/train.json" \
28 | --eval_data_path="gcs://jiayi-eu/data/sos-jan12-xiuyu/test.json" \
29 | --output_dir="gcs://jiayi-eu/lm-parallel-exp/exp-sos/" \
30 | --sharding="-1,1,1" \
31 | --num_train_steps=$TRAIN_STEPS \
32 | --max_length=4096 \
33 | --bsize=256 \
34 | --log_freq=100 \
35 | --num_eval_steps=500 \
36 | --save_model_freq=100000000 \
37 | --wandb_project="sos" \
38 | --param_dtype="fp32" \
39 | --activation_dtype="fp32" \
40 | --optim_config="adamw:{
41 | \"init_lr\": 5e-6,
42 | \"end_lr\": 5e-7,
43 | \"lr\": 5e-5,
44 | \"lr_warmup_steps\": 1,
45 | \"lr_decay_steps\": $TRAIN_STEPS,
46 | \"b1\": 0.9,
47 | \"b2\": 0.999,
48 | \"clip_gradient\": 1.0,
49 | \"weight_decay\": 0.01,
50 | \"bf16_momentum\": false,
51 | \"multiply_by_parameter_scale\": false,
52 | \"weight_decay_exclusions\": [],
53 | \"schedule\": \"cos\",
54 | \"grad_accum_steps\": 1
55 | }" \
56 | --logger_config="{
57 | \"online\": true,
58 | \"prefix\": \"$RUN_NAME\",
59 | \"prefix_to_id\": true
60 | }" \
61 | --checkpointer_config="{
62 | \"save_optimizer_state\": false,
63 | \"save_float_dtype\": \"bf16\"
64 | }" \
65 | --model_config_override="{
66 | \"bos_token_id\": 1,
67 | \"eos_token_id\": 2,
68 | \"pad_token_id\": 0,
69 | \"remat_block\": \"nothing_saveable\"
70 | }" \
71 | --eval_bsize=512 \
72 | --no_shuffle_train_data \
73 | --hf_repo_id="LM-Parallel/llama-300m-standard-valid-5e-5-bs256"
74 | )
75 |
76 |
77 |
--------------------------------------------------------------------------------
/supervised-jax/scripts/training_sos_llama_1ep.sh:
--------------------------------------------------------------------------------
1 | # 12/9/24
2 |
3 | # test gpt2 training, this will train a randomly initialized model
4 |
5 | # charlie-pod
6 | # \"tokenizer\": \"gpt2\",
7 | # \"config\": \"gpt2\"
8 | # gs://
9 |
10 | (
11 | source ~/miniconda3/bin/activate llama3_train
12 | export RUN_NAME="llama-300m-standard-500k-1ep-bs128-valid"
13 | export GCLOUD_TOKEN_PATH="$HOME/.config/gcloud/civic-boulder-204700-3052e43e8c80.json"
14 | export GCLOUD_PROJECT="civic-boulder-204700"
15 | export HF_TOKEN="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"
16 | export WANDB_API_KEY="53a3e8edb945646eb837622d6422755f5a3131b2"
17 | cd ~/llama3_train
18 | pip install wget
19 | # pip install optax wandb
20 | source ~/miniconda3/bin/activate llama3_train
21 |
22 | TRAIN_STEPS=3900
23 | python llama_train_script.py \
24 | --load_model="paths:{
25 | \"tokenizer\": \"meta-llama/Llama-2-7b-hf\",
26 | \"default_config_name\": \"300m\"
27 | }" \
28 | --train_data_path="gcs://jiayi-eu/data/sos-jan12-xiuyu/train.json" \
29 | --eval_data_path="gcs://jiayi-eu/data/sos-jan12-xiuyu/test.json" \
30 | --output_dir="gcs://jiayi-eu/lm-parallel-exp/exp-sos/" \
31 | --sharding="-1,1,1" \
32 | --num_train_steps=$TRAIN_STEPS \
33 | --max_length=4096 \
34 | --bsize=128 \
35 | --log_freq=100 \
36 | --num_eval_steps=500 \
37 | --save_model_freq=100000000 \
38 | --wandb_project="sos" \
39 | --param_dtype="fp32" \
40 | --activation_dtype="fp32" \
41 | --optim_config="adamw:{
42 | \"init_lr\": 5e-6,
43 | \"end_lr\": 5e-7,
44 | \"lr\": 5e-5,
45 | \"lr_warmup_steps\": 1,
46 | \"lr_decay_steps\": $TRAIN_STEPS,
47 | \"b1\": 0.9,
48 | \"b2\": 0.999,
49 | \"clip_gradient\": 1.0,
50 | \"weight_decay\": 0.01,
51 | \"bf16_momentum\": false,
52 | \"multiply_by_parameter_scale\": false,
53 | \"weight_decay_exclusions\": [],
54 | \"schedule\": \"cos\",
55 | \"grad_accum_steps\": 1
56 | }" \
57 | --logger_config="{
58 | \"online\": true,
59 | \"prefix\": \"$RUN_NAME\",
60 | \"prefix_to_id\": true
61 | }" \
62 | --checkpointer_config="{
63 | \"save_optimizer_state\": false,
64 | \"save_float_dtype\": \"bf16\"
65 | }" \
66 | --model_config_override="{
67 | \"bos_token_id\": 1,
68 | \"eos_token_id\": 2,
69 | \"pad_token_id\": 0,
70 | \"remat_block\": \"nothing_saveable\"
71 | }" \
72 | --eval_bsize=512 \
73 | --no_shuffle_train_data \
74 | --hf_repo_id="LM-Parallel/llama-300m-standard-valid-5e-5-bs128-1ep"
75 | )
76 |
77 |
78 |
--------------------------------------------------------------------------------
/supervised-jax/scripts/training_sos_llama_gpt2tok.sh:
--------------------------------------------------------------------------------
1 | # 12/9/24
2 |
3 | # test gpt2 training, this will train a randomly initialized model
4 |
5 | # charlie-pod
6 | # \"tokenizer\": \"gpt2\",
7 | # \"config\": \"gpt2\"
8 | # gs://
9 |
10 | (
11 | source ~/miniconda3/bin/activate llama3_train
12 | export RUN_NAME="llama-300m-gpt2tok-standard-500k-bs256-valid"
13 | export GCLOUD_TOKEN_PATH="$HOME/.config/gcloud/civic-boulder-204700-3052e43e8c80.json"
14 | export GCLOUD_PROJECT="civic-boulder-204700"
15 | export HF_TOKEN="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"
16 | export WANDB_API_KEY="53a3e8edb945646eb837622d6422755f5a3131b2"
17 | cd ~/llama3_train
18 | # pip install optax wandb
19 | source ~/miniconda3/bin/activate llama3_train
20 |
21 | TRAIN_STEPS=19000
22 | python llama_train_script.py \
23 | --load_model="paths:{
24 | \"tokenizer\": \"gpt2\",
25 | \"default_config_name\": \"300m\"
26 | }" \
27 | --train_data_path="gcs://jiayi-eu/data/sos-jan12-xiuyu/train.json" \
28 | --eval_data_path="gcs://jiayi-eu/data/sos-jan12-xiuyu/test.json" \
29 | --output_dir="gcs://jiayi-eu/lm-parallel-exp/exp-sos/" \
30 | --sharding="-1,1,1" \
31 | --num_train_steps=$TRAIN_STEPS \
32 | --max_length=4096 \
33 | --bsize=256 \
34 | --log_freq=100 \
35 | --num_eval_steps=500 \
36 | --save_model_freq=100000000 \
37 | --wandb_project="sos" \
38 | --param_dtype="fp32" \
39 | --activation_dtype="fp32" \
40 | --optim_config="adamw:{
41 | \"init_lr\": 5e-6,
42 | \"end_lr\": 5e-7,
43 | \"lr\": 5e-5,
44 | \"lr_warmup_steps\": 1,
45 | \"lr_decay_steps\": $TRAIN_STEPS,
46 | \"b1\": 0.9,
47 | \"b2\": 0.999,
48 | \"clip_gradient\": 1.0,
49 | \"weight_decay\": 0.01,
50 | \"bf16_momentum\": false,
51 | \"multiply_by_parameter_scale\": false,
52 | \"weight_decay_exclusions\": [],
53 | \"schedule\": \"cos\",
54 | \"grad_accum_steps\": 1
55 | }" \
56 | --logger_config="{
57 | \"online\": true,
58 | \"prefix\": \"$RUN_NAME\",
59 | \"prefix_to_id\": true
60 | }" \
61 | --checkpointer_config="{
62 | \"save_optimizer_state\": false,
63 | \"save_float_dtype\": \"bf16\"
64 | }" \
65 | --model_config_override="{
66 | \"bos_token_id\": 50256,
67 | \"eos_token_id\": 50256,
68 | \"pad_token_id\": 50256,
69 | \"remat_block\": \"nothing_saveable\"
70 | }" \
71 | --eval_bsize=512 \
72 | --no_shuffle_train_data \
73 | --hf_repo_id="LM-Parallel/llama-300m-gpt2tok-standard-valid-5e-5"
74 | )
75 |
76 |
77 |
--------------------------------------------------------------------------------
/supervised-jax/scripts/training_sos_llama_gpt2tok_1ep.sh:
--------------------------------------------------------------------------------
1 | # 12/9/24
2 |
3 | # test gpt2 training, this will train a randomly initialized model
4 |
5 | # charlie-pod
6 | # \"tokenizer\": \"gpt2\",
7 | # \"config\": \"gpt2\"
8 | # gs://
9 |
10 | (
11 | source ~/miniconda3/bin/activate llama3_train
12 | export RUN_NAME="llama-300m-gpt2tok-standard-500k-1ep-bs128-valid"
13 | export GCLOUD_TOKEN_PATH="$HOME/.config/gcloud/civic-boulder-204700-3052e43e8c80.json"
14 | export GCLOUD_PROJECT="civic-boulder-204700"
15 | export HF_TOKEN="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"
16 | export WANDB_API_KEY="53a3e8edb945646eb837622d6422755f5a3131b2"
17 | cd ~/llama3_train
18 | pip install wget
19 | # pip install optax wandb
20 | source ~/miniconda3/bin/activate llama3_train
21 |
22 | TRAIN_STEPS=3900
23 | python llama_train_script.py \
24 | --load_model="paths:{
25 | \"tokenizer\": \"gpt2\",
26 | \"default_config_name\": \"300m\"
27 | }" \
28 | --train_data_path="gcs://jiayi-eu/data/sos-jan12-xiuyu/train.json" \
29 | --eval_data_path="gcs://jiayi-eu/data/sos-jan12-xiuyu/test.json" \
30 | --output_dir="gcs://jiayi-eu/lm-parallel-exp/exp-sos/" \
31 | --sharding="-1,1,1" \
32 | --num_train_steps=$TRAIN_STEPS \
33 | --max_length=4096 \
34 | --bsize=128 \
35 | --log_freq=100 \
36 | --num_eval_steps=500 \
37 | --save_model_freq=100000000 \
38 | --wandb_project="sos" \
39 | --param_dtype="fp32" \
40 | --activation_dtype="fp32" \
41 | --optim_config="adamw:{
42 | \"init_lr\": 5e-6,
43 | \"end_lr\": 5e-7,
44 | \"lr\": 5e-5,
45 | \"lr_warmup_steps\": 1,
46 | \"lr_decay_steps\": $TRAIN_STEPS,
47 | \"b1\": 0.9,
48 | \"b2\": 0.999,
49 | \"clip_gradient\": 1.0,
50 | \"weight_decay\": 0.01,
51 | \"bf16_momentum\": false,
52 | \"multiply_by_parameter_scale\": false,
53 | \"weight_decay_exclusions\": [],
54 | \"schedule\": \"cos\",
55 | \"grad_accum_steps\": 1
56 | }" \
57 | --logger_config="{
58 | \"online\": true,
59 | \"prefix\": \"$RUN_NAME\",
60 | \"prefix_to_id\": true
61 | }" \
62 | --checkpointer_config="{
63 | \"save_optimizer_state\": false,
64 | \"save_float_dtype\": \"bf16\"
65 | }" \
66 | --model_config_override="{
67 | \"bos_token_id\": 50256,
68 | \"eos_token_id\": 50256,
69 | \"pad_token_id\": 50256,
70 | \"remat_block\": \"nothing_saveable\"
71 | }" \
72 | --eval_bsize=512 \
73 | --no_shuffle_train_data \
74 | --hf_repo_id="LM-Parallel/llama-300m-gpt2tok-standard-valid-5e-5-bs128-1ep"
75 | )
76 |
77 |
78 |
--------------------------------------------------------------------------------
/supervised-jax/scripts/training_sos_xiuyu.sh:
--------------------------------------------------------------------------------
1 | # 12/9/24
2 |
3 | # test gpt2 training, this will train a randomly initialized model
4 |
5 | # charlie-pod
6 | # \"tokenizer\": \"gpt2\",
7 | # \"config\": \"gpt2\"
8 | # gs://
9 |
10 | (
11 | source ~/miniconda3/bin/activate llama3_train
12 | export RUN_NAME="gpt2-standard-500k-bs128-valid"
13 | export GCLOUD_TOKEN_PATH="$HOME/.config/gcloud/civic-boulder-204700-3052e43e8c80.json"
14 | export GCLOUD_PROJECT="civic-boulder-204700"
15 | export HF_TOKEN="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"
16 | export WANDB_API_KEY="0929e692448f1bc929d71d7e3ece80073c3041e6"
17 | cd ~/llama3_train
18 | # pip install optax wandb
19 | source ~/miniconda3/bin/activate llama3_train
20 |
21 | TRAIN_STEPS=39000
22 | python gpt2_train_script.py \
23 | --load_model="paths:{
24 | \"tokenizer\": \"gpt2\",
25 | \"config\": \"LM-Parallel/jax-reference-gp2-s\"
26 | }" \
27 | --train_data_path="gcs://jiayi-eu/data/sos-jan12-xiuyu/train.json" \
28 | --eval_data_path="gcs://jiayi-eu/data/sos-jan12-xiuyu/test.json" \
29 | --output_dir="gcs://jiayi-eu/lm-parallel-exp/exp-sos/" \
30 | --sharding="-1,1,1" \
31 | --num_train_steps=$TRAIN_STEPS \
32 | --max_length=4096 \
33 | --bsize=128 \
34 | --log_freq=100 \
35 | --num_eval_steps=500 \
36 | --save_model_freq=100000000 \
37 | --wandb_project="sos" \
38 | --param_dtype="fp32" \
39 | --activation_dtype="fp32" \
40 | --optim_config="adamw:{
41 | \"init_lr\": 1e-5,
42 | \"end_lr\": 0,
43 | \"lr\": 1e-5,
44 | \"lr_warmup_steps\": 0,
45 | \"lr_decay_steps\": $TRAIN_STEPS,
46 | \"b1\": 0.9,
47 | \"b2\": 0.999,
48 | \"clip_gradient\": 1.0,
49 | \"weight_decay\": 0.01,
50 | \"bf16_momentum\": false,
51 | \"multiply_by_parameter_scale\": false,
52 | \"weight_decay_exclusions\": [],
53 | \"schedule\": \"cos\",
54 | \"grad_accum_steps\": 1
55 | }" \
56 | --logger_config="{
57 | \"online\": true,
58 | \"prefix\": \"$RUN_NAME\",
59 | \"prefix_to_id\": true
60 | }" \
61 | --checkpointer_config="{
62 | \"save_optimizer_state\": false,
63 | \"save_float_dtype\": \"bf16\"
64 | }" \
65 | --model_config_override="{
66 | \"bos_token_id\": 50256,
67 | \"eos_token_id\": 50256,
68 | \"pad_token_id\": 50256,
69 | \"remat_block\": \"nothing_saveable\",
70 | \"n_positions\": 4096
71 | }" \
72 | --eval_bsize=512 \
73 | --no_shuffle_train_data \
74 | --hf_repo_id="LM-Parallel/gpt2-standard-valid-1e-5"
75 | )
76 |
77 |
78 |
--------------------------------------------------------------------------------
/supervised-jax/to_dataset.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "import json, os\n",
10 | "from google.cloud import storage\n",
11 | "import os\n",
12 | "os.environ[\"GOOGLE_APPLICATION_CREDENTIALS\"] = \"CREDENTIALS_PATH\" # TODO: replace with your own credentials path\n",
13 | "# Upload to GCS\n",
14 | "bucket_name = \"BUCKET_NAME\" # TODO: replace with your own bucket name\n",
15 | "local_dir = \"LOCAL_DIR\" # TODO: replace with your own local directory\n",
16 | "prefix = \"PREFIX\" # TODO: replace with your own prefix"
17 | ]
18 | },
19 | {
20 | "cell_type": "code",
21 | "execution_count": 5,
22 | "metadata": {},
23 | "outputs": [],
24 | "source": [
25 | "def process_to_gcp(file_path, name):\n",
26 | " # load \n",
27 | " with open(file_path, 'r') as f:\n",
28 | " data = json.load(f)\n",
29 | "\n",
30 | " # get all seqs\n",
31 | " if \"hsp\" not in name:\n",
32 | " seqs = [d['search_path'] for d in data]\n",
33 | " else:\n",
34 | " def add_all_calls(trace_dict):\n",
35 | " all_calls = []\n",
36 | " all_calls += trace_dict['main_calls']\n",
37 | " for sub in trace_dict['sub_calls']:\n",
38 | " for sub_trace in sub:\n",
39 | " all_calls += add_all_calls(sub_trace)\n",
40 | " return all_calls\n",
41 | " seqs = []\n",
42 | " for dp in data:\n",
43 | " seqs += add_all_calls(dp['trace_dict'])\n",
44 | " print(f\"name: {len(seqs)}\")\n",
45 | "\n",
46 | " # local save\n",
47 | " os.makedirs(local_dir, exist_ok=True)\n",
48 | " local_path = os.path.join(local_dir, f\"{name}.json\")\n",
49 | " with open(local_path, \"w\") as f:\n",
50 | " json.dump(seqs, f)\n",
51 | "\n",
52 | " # send to gcp\n",
53 | " storage_client = storage.Client()\n",
54 | " bucket = storage_client.bucket(bucket_name)\n",
55 | " # # Upload train file\n",
56 | " train_blob = bucket.blob(prefix+f\"{name}.json\")\n",
57 | " train_blob.upload_from_filename(local_path)"
58 | ]
59 | },
60 | {
61 | "cell_type": "code",
62 | "execution_count": 6,
63 | "metadata": {},
64 | "outputs": [],
65 | "source": [
66 | "todos = [\n",
67 | " [\"./train_apr.json\", \"train_apr\"],\n",
68 | " # TODO: add your own data here\n",
69 | "]"
70 | ]
71 | },
72 | {
73 | "cell_type": "code",
74 | "execution_count": null,
75 | "metadata": {},
76 | "outputs": [],
77 | "source": [
78 | "from tqdm import tqdm\n",
79 | "for file_path, name in tqdm(todos):\n",
80 | " process_to_gcp(file_path, name)"
81 | ]
82 | }
83 | ],
84 | "metadata": {
85 | "kernelspec": {
86 | "display_name": "GPML",
87 | "language": "python",
88 | "name": "python3"
89 | },
90 | "language_info": {
91 | "codemirror_mode": {
92 | "name": "ipython",
93 | "version": 3
94 | },
95 | "file_extension": ".py",
96 | "mimetype": "text/x-python",
97 | "name": "python",
98 | "nbconvert_exporter": "python",
99 | "pygments_lexer": "ipython3",
100 | "version": "3.10.14"
101 | }
102 | },
103 | "nbformat": 4,
104 | "nbformat_minor": 2
105 | }
106 |
--------------------------------------------------------------------------------
/supervised-jax/training_hsp.sh:
--------------------------------------------------------------------------------
1 | # 12/9/24
2 |
3 | # test gpt2 training, this will train a randomly initialized model
4 |
5 | # charlie-pod
6 | # \"tokenizer\": \"gpt2\",
7 | # \"config\": \"gpt2\"
8 | # gs://
9 |
10 | (
11 | source ~/miniconda3/bin/activate llama3_train
12 | pip install -U "jax[tpu]==0.4.38" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
13 | export RUN_NAME="hsp_v0_500k"
14 | export GCLOUD_TOKEN_PATH="$HOME/.config/gcloud/civic-boulder-204700-3052e43e8c80.json"
15 | export GCLOUD_PROJECT="civic-boulder-204700"
16 | export HF_TOKEN="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"
17 | export WANDB_API_KEY="0929e692448f1bc929d71d7e3ece80073c3041e6"
18 | cd ~/llama3_train
19 | # pip install optax wandb
20 | source ~/miniconda3/bin/activate llama3_train
21 |
22 | TRAIN_STEPS=19000
23 | python gpt2_train_script.py \
24 | --load_model="paths:{
25 | \"tokenizer\": \"gpt2\",
26 | \"config\": \"LM-Parallel/jax-reference-gp2-s\"
27 | }" \
28 | --train_data_path="gcs://jiayi-eu/data/hsp_v0_500k/train.json" \
29 | --eval_data_path="gcs://jiayi-eu/data/hsp_v0_500k/test.json" \
30 | --output_dir="gcs://jiayi-eu/lm-parallel-exp/exp-hsp/" \
31 | --sharding="-1,1,1" \
32 | --num_train_steps=$TRAIN_STEPS \
33 | --max_length=4096 \
34 | --bsize=256 \
35 | --log_freq=512 \
36 | --num_eval_steps=512 \
37 | --save_model_freq=100000000 \
38 | --wandb_project="sos" \
39 | --param_dtype="fp32" \
40 | --activation_dtype="fp32" \
41 | --optim_config="adamw:{
42 | \"init_lr\": 5e-6,
43 | \"end_lr\": 5e-7,
44 | \"lr\": 5e-5,
45 | \"lr_warmup_steps\": 1,
46 | \"lr_decay_steps\": $TRAIN_STEPS,
47 | \"b1\": 0.9,
48 | \"b2\": 0.999,
49 | \"clip_gradient\": 100.0,
50 | \"weight_decay\": 0.01,
51 | \"bf16_momentum\": false,
52 | \"multiply_by_parameter_scale\": false,
53 | \"weight_decay_exclusions\": [],
54 | \"schedule\": \"cos\",
55 | \"grad_accum_steps\": 1
56 | }" \
57 | --logger_config="{
58 | \"online\": true,
59 | \"prefix\": \"$RUN_NAME\",
60 | \"prefix_to_id\": true
61 | }" \
62 | --checkpointer_config="{
63 | \"save_optimizer_state\": false,
64 | \"save_float_dtype\": \"bf16\"
65 | }" \
66 | --model_config_override="{
67 | \"bos_token_id\": 50256,
68 | \"eos_token_id\": 50256,
69 | \"pad_token_id\": 50256,
70 | \"remat_block\": \"nothing_saveable\",
71 | \"resid_pdrop\": 0.00,
72 | \"embd_pdrop\": 0.00,
73 | \"attn_pdrop\": 0.00,
74 | \"n_positions\": 4096
75 | }" \
76 | --eval_bsize=512 \
77 | --no_shuffle_train_data \
78 | --hf_repo_id="LM-Parallel/hsp-v0"
79 | )
80 |
81 |
82 |
--------------------------------------------------------------------------------
/supervised-jax/training_hsp_2x.sh:
--------------------------------------------------------------------------------
1 | # 12/9/24
2 |
3 | # test gpt2 training, this will train a randomly initialized model
4 |
5 | # charlie-pod
6 | # \"tokenizer\": \"gpt2\",
7 | # \"config\": \"gpt2\"
8 | # gs://
9 |
10 | (
11 | source ~/miniconda3/bin/activate llama3_train
12 | export RUN_NAME="hsp_v0_500k"
13 | export GCLOUD_TOKEN_PATH="$HOME/.config/gcloud/civic-boulder-204700-3052e43e8c80.json"
14 | export GCLOUD_PROJECT="civic-boulder-204700"
15 | export HF_TOKEN="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"
16 | export WANDB_API_KEY="0929e692448f1bc929d71d7e3ece80073c3041e6"
17 | cd ~/llama3_train
18 | # pip install optax wandb
19 | source ~/miniconda3/bin/activate llama3_train
20 |
21 | TRAIN_STEPS=19000
22 | python gpt2_train_script.py \
23 | --load_model="paths:{
24 | \"tokenizer\": \"gpt2\",
25 | \"config\": \"LM-Parallel/jax-reference-gpt2-2x-s\"
26 | }" \
27 | --train_data_path="gcs://jiayi-eu/data/hsp_v0_500k/train.json" \
28 | --eval_data_path="gcs://jiayi-eu/data/hsp_v0_500k/test.json" \
29 | --output_dir="gcs://jiayi-eu/lm-parallel-exp/exp-hsp/" \
30 | --sharding="-1,1,1" \
31 | --num_train_steps=$TRAIN_STEPS \
32 | --max_length=4096 \
33 | --bsize=256 \
34 | --log_freq=512 \
35 | --num_eval_steps=512 \
36 | --save_model_freq=100000000 \
37 | --wandb_project="sos" \
38 | --param_dtype="fp32" \
39 | --activation_dtype="fp32" \
40 | --optim_config="adamw:{
41 | \"init_lr\": 5e-6,
42 | \"end_lr\": 5e-7,
43 | \"lr\": 5e-5,
44 | \"lr_warmup_steps\": 1,
45 | \"lr_decay_steps\": $TRAIN_STEPS,
46 | \"b1\": 0.9,
47 | \"b2\": 0.999,
48 | \"clip_gradient\": 100.0,
49 | \"weight_decay\": 0.01,
50 | \"bf16_momentum\": false,
51 | \"multiply_by_parameter_scale\": false,
52 | \"weight_decay_exclusions\": [],
53 | \"schedule\": \"cos\",
54 | \"grad_accum_steps\": 1
55 | }" \
56 | --logger_config="{
57 | \"online\": true,
58 | \"prefix\": \"$RUN_NAME\",
59 | \"prefix_to_id\": true
60 | }" \
61 | --checkpointer_config="{
62 | \"save_optimizer_state\": false,
63 | \"save_float_dtype\": \"bf16\"
64 | }" \
65 | --model_config_override="{
66 | \"bos_token_id\": 50256,
67 | \"eos_token_id\": 50256,
68 | \"pad_token_id\": 50256,
69 | \"remat_block\": \"nothing_saveable\",
70 | \"resid_pdrop\": 0.00,
71 | \"embd_pdrop\": 0.00,
72 | \"attn_pdrop\": 0.00,
73 | \"n_positions\": 4096
74 | }" \
75 | --eval_bsize=512 \
76 | --no_shuffle_train_data \
77 | --hf_repo_id="LM-Parallel/hsp-v0-2x"
78 | )
79 |
80 |
81 |
--------------------------------------------------------------------------------
/supervised-jax/training_run.sh:
--------------------------------------------------------------------------------
1 | # 12/9/24
2 |
3 | # test gpt2 training, this will train a randomly initialized model
4 |
5 | # charlie-pod
6 | # \"tokenizer\": \"gpt2\",
7 | # \"config\": \"gpt2\"
8 | # gs://
9 |
10 | (
11 | source ~/miniconda3/bin/activate llama3_train
12 | export RUN_NAME="gpt2_sos_jan12"
13 | export GCLOUD_TOKEN_PATH="$HOME/.config/gcloud/civic-boulder-204700-3052e43e8c80.json"
14 | export GCLOUD_PROJECT="civic-boulder-204700"
15 | export HF_TOKEN="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"
16 | export WANDB_API_KEY="0929e692448f1bc929d71d7e3ece80073c3041e6"
17 | cd ~/llama3_train
18 | # pip install optax wandb
19 | source ~/miniconda3/bin/activate llama3_train
20 |
21 | TRAIN_STEPS=19000
22 | python gpt2_train_script.py \
23 | --load_model="paths:{
24 | \"tokenizer\": \"gpt2\",
25 | \"config\": \"LM-Parallel/jax-reference-gp2-s\"
26 | }" \
27 | --train_data_path="gcs://jiayi-eu/data/sos-jan12-xiuyu/train.json" \
28 | --eval_data_path="gcs://jiayi-eu/data/sos-jan12-xiuyu/test.json" \
29 | --output_dir="gcs://jiayi-eu/lm-parallel-exp/exp-sos/" \
30 | --sharding="-1,1,1" \
31 | --num_train_steps=$TRAIN_STEPS \
32 | --max_length=4096 \
33 | --bsize=256 \
34 | --log_freq=512 \
35 | --num_eval_steps=512 \
36 | --save_model_freq=100000000 \
37 | --wandb_project="sos" \
38 | --param_dtype="fp32" \
39 | --activation_dtype="fp32" \
40 | --optim_config="adamw:{
41 | \"init_lr\": 5e-6,
42 | \"end_lr\": 5e-7,
43 | \"lr\": 5e-5,
44 | \"lr_warmup_steps\": 1,
45 | \"lr_decay_steps\": $TRAIN_STEPS,
46 | \"b1\": 0.9,
47 | \"b2\": 0.999,
48 | \"clip_gradient\": 100.0,
49 | \"weight_decay\": 0.01,
50 | \"bf16_momentum\": false,
51 | \"multiply_by_parameter_scale\": false,
52 | \"weight_decay_exclusions\": [],
53 | \"schedule\": \"cos\",
54 | \"grad_accum_steps\": 1
55 | }" \
56 | --logger_config="{
57 | \"online\": true,
58 | \"prefix\": \"$RUN_NAME\",
59 | \"prefix_to_id\": true
60 | }" \
61 | --checkpointer_config="{
62 | \"save_optimizer_state\": false,
63 | \"save_float_dtype\": \"bf16\"
64 | }" \
65 | --model_config_override="{
66 | \"bos_token_id\": 50256,
67 | \"eos_token_id\": 50256,
68 | \"pad_token_id\": 50256,
69 | \"remat_block\": \"nothing_saveable\",
70 | \"resid_pdrop\": 0.1,
71 | \"embd_pdrop\": 0.00,
72 | \"attn_pdrop\": 0.00,
73 | \"n_positions\": 4096
74 | }" \
75 | --eval_bsize=512 \
76 | --no_shuffle_train_data \
77 | --hf_repo_id="LM-Parallel/standard-ref-5e-5-rdrop01"
78 | )
79 |
80 |
81 |
--------------------------------------------------------------------------------
/supervised-jax/training_run_test.sh:
--------------------------------------------------------------------------------
1 | # 12/9/24
2 |
3 | # test gpt2 training, this will train a randomly initialized model
4 |
5 | # charlie-pod
6 | # \"tokenizer\": \"gpt2\",
7 | # \"config\": \"gpt2\"
8 | # gs://
9 |
10 | (
11 | source ~/miniconda3/bin/activate llama3_train
12 | export RUN_NAME="gpt2_sos_jan15"
13 | export GCLOUD_TOKEN_PATH="$HOME/.config/gcloud/civic-boulder-204700-3052e43e8c80.json"
14 | export GCLOUD_PROJECT="civic-boulder-204700"
15 | export HF_TOKEN="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"
16 | export WANDB_API_KEY="0929e692448f1bc929d71d7e3ece80073c3041e6"
17 | cd ~/llama3_train
18 | # pip install optax wandb
19 | source ~/miniconda3/bin/activate llama3_train
20 |
21 | TRAIN_STEPS=5000
22 | python gpt2_train_script.py \
23 | --load_model="paths:{
24 | \"tokenizer\": \"gpt2\",
25 | \"config\": \"LM-Parallel/jax-reference-gp2-s\"
26 | }" \
27 | --train_data_path="gcs://jiayi-eu/data/sos-jan12-xiuyu/train.json" \
28 | --eval_data_path="gcs://jiayi-eu/data/sos-jan12-xiuyu/test.json" \
29 | --output_dir="gcs://jiayi-eu/lm-parallel-exp/exp-sos/" \
30 | --sharding="-1,1,1" \
31 | --num_train_steps=$TRAIN_STEPS \
32 | --max_length=4096 \
33 | --bsize=128 \
34 | --log_freq=512 \
35 | --num_eval_steps=512 \
36 | --save_model_freq=100000000 \
37 | --wandb_project="sos" \
38 | --param_dtype="fp32" \
39 | --activation_dtype="fp32" \
40 | --optim_config="adamw:{
41 | \"init_lr\": 5e-6,
42 | \"end_lr\": 5e-7,
43 | \"lr\": 1e-5,
44 | \"lr_warmup_steps\": 1,
45 | \"lr_decay_steps\": $TRAIN_STEPS,
46 | \"b1\": 0.9,
47 | \"b2\": 0.95,
48 | \"clip_gradient\": 1.0,
49 | \"weight_decay\": 0.01,
50 | \"bf16_momentum\": false,
51 | \"multiply_by_parameter_scale\": false,
52 | \"weight_decay_exclusions\": [],
53 | \"schedule\": \"cos\",
54 | \"grad_accum_steps\": 1
55 | }" \
56 | --logger_config="{
57 | \"online\": true,
58 | \"prefix\": \"$RUN_NAME\",
59 | \"prefix_to_id\": true
60 | }" \
61 | --checkpointer_config="{
62 | \"save_optimizer_state\": false,
63 | \"save_float_dtype\": \"bf16\"
64 | }" \
65 | --model_config_override="{
66 | \"bos_token_id\": 50256,
67 | \"eos_token_id\": 50256,
68 | \"pad_token_id\": 50256,
69 | \"remat_block\": \"nothing_saveable\",
70 | \"resid_pdrop\": 0.00,
71 | \"embd_pdrop\": 0.00,
72 | \"attn_pdrop\": 0.00,
73 | \"n_positions\": 4096
74 | }" \
75 | --eval_bsize=512 \
76 | --no_shuffle_train_data \
77 | --hf_repo_id="LM-Parallel/standard-ref-test"
78 | )
79 |
80 |
81 |
--------------------------------------------------------------------------------
/tinyrl/README.md:
--------------------------------------------------------------------------------
1 | # Installation
2 | ```
3 | conda create -n grpo python=3.10
4 | # install sgl
5 | pip install --upgrade pip
6 | pip install sgl-kernel --force-reinstall --no-deps
7 | pip install "sglang[all]>=0.4.3.post1" --find-links https://flashinfer.ai/whl/cu124/torch2.5/flashinfer-python
8 |
9 | # other utils
10 | pip install hydra-core omegaconf wandb
11 | ```
12 |
13 | ## For APR, SGLang needs to be patched
14 | Comment out (or remove) this check in `python/sglang/srt/managers/tokenizer_manager.py` in your local SGLang repo, so that it reads as follows:
15 | ```
16 | # if (
17 | # obj.sampling_params.get("max_new_tokens") is not None
18 | # and obj.sampling_params.get("max_new_tokens") + input_token_num
19 | # >= self.context_len
20 | # ):
21 | # raise ValueError(
22 | # f"Requested token count exceeds the model's maximum context length "
23 | # f"of {self.context_len} tokens. You requested a total of "
24 | # f"{obj.sampling_params.get('max_new_tokens') + input_token_num} "
25 | # f"tokens: {input_token_num} tokens from the input messages and "
26 | # f"{obj.sampling_params.get('max_new_tokens')} tokens for the "
27 | # f"completion. Please reduce the number of tokens in the input "
28 | # f"messages or the completion to fit within the limit."
29 | # )
30 | ```
31 |
32 | This file is at https://github.com/sgl-project/sglang/blob/45205d88a08606d5875476fbbbc76815a5107edd/python/sglang/srt/managers/tokenizer_manager.py#L350
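
If SGLang was installed with pip rather than from a local clone, the file to patch lives inside the installed package. One way to locate it (a small sketch, assuming a standard `sglang` install):

```
python -c "import sglang.srt.managers.tokenizer_manager as m; print(m.__file__)"
```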
33 |
34 | # Data Preparation
35 | Please put `sosp_train_prefix.json`, `sosp_val_prefix.json`, `apr_train_beam10_subbeam15_prefix.json`, and `apr_val_prefix.json` in the `data` folder.
36 |
37 | You can download them from [https://huggingface.co/datasets/Parallel-Reasoning/apr_rl_data](https://huggingface.co/datasets/Parallel-Reasoning/apr_rl_data).
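
One way to fetch them is with the `huggingface_hub` CLI (a sketch assuming it is installed; adjust `--local-dir` if your `data` folder lives elsewhere):

```
huggingface-cli download Parallel-Reasoning/apr_rl_data --repo-type dataset --local-dir data
```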
38 |
39 | # Run
40 | Each run requires two GPUs: one for model training and one for serving with SGLang.
41 |
42 | ## RL on APR without subcall condition
43 | ```
44 | export CUDA_VISIBLE_DEVICES=0,1
45 | python trainer.py --config-name apr
46 | ```
47 |
48 | Reference checkpoint: [https://huggingface.co/Parallel-Reasoning/apr_grpo](https://huggingface.co/Parallel-Reasoning/apr_grpo)
49 |
50 | ## RL on APR with subcall condition (condition set to 10)
51 | ```
52 | export CUDA_VISIBLE_DEVICES=0,1
53 | python trainer.py --config-name apr_cond10
54 | ```
55 |
56 | Reference checkpoint: [https://huggingface.co/Parallel-Reasoning/apr_cond10_grpo](https://huggingface.co/Parallel-Reasoning/apr_cond10_grpo)
57 |
58 | ## RL on SOS+
59 | ```
60 | export CUDA_VISIBLE_DEVICES=0,1
61 | python trainer.py --config-name sosp
62 | ```
63 |
64 | Reference checkpoint: [https://huggingface.co/Parallel-Reasoning/sosp_grpo](https://huggingface.co/Parallel-Reasoning/sosp_grpo)
65 |
66 | You can also create your own config files and select them with `--config-name <config_name>`.
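
Since the trainer reads these YAML files through Hydra, individual fields can typically also be overridden on the command line with standard Hydra syntax, e.g. (field names taken from the configs above):

```
python trainer.py --config-name apr training.num_steps=300 logging.wandb_name=apr_grpo_debug
```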
67 |
--------------------------------------------------------------------------------
/tinyrl/configs/apr.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - _self_
3 | - override hydra/hydra_logging: disabled
4 | - override hydra/job_logging: disabled
5 |
6 | # Data configuration
7 | data:
8 | train_data_path: data/apr_train_beam10_subbeam15_prefix.json
9 | val_data_path: data/apr_val_prefix.json
10 | train_batch_size: 64
11 | eval_batch_size: 128
12 |
13 | # Model configuration
14 | model:
15 | name: "Parallel-Reasoning/llama-apr-beam10_subbeam15"
16 | server_url: "http://localhost:0"
17 | master_port: 0
18 | mem_fraction_static: 0.6
19 | gradient_checkpointing: true # Enable gradient checkpointing to save memory
20 |
21 | # Rollout configuration
22 | rollout:
23 | mode: "apr"
24 | sample_temperature: 1.0
25 | eval_temperature: 0.0
26 | group_size: 5
27 | condition_prefix: null
28 |
29 | # Training configuration
30 | training:
31 | learning_rate: 1e-5
32 | ppo_clip_ratio: 0.2
33 | num_steps: 150
34 | inner_steps: 2
35 | kl_beta: 0.001
36 | grad_clip: 1.0
37 | # This grad_accum_chunk_size is used for gradient accumulation
38 | # It is used to avoid OOM when the batch size is too large
39 | grad_accum_chunk_size: 4
40 | log_probs_chunk_size: 16
41 | validate_per_steps: 50
42 | # Number of validation samples to use (null for full validation set), set to 16 for testing (skipping validation), set to a larger number for actual validation
43 | val_samples: null
44 | output_dir: checkpoints/apr
45 | save_steps: 10
46 |
47 | # Logging configuration
48 | logging:
49 | verbose: false
50 | wandb_project: "tinyrl-apr"
51 | wandb_name: "apr_grpo"
52 | log_interval: 1
53 | use_current_time: true
54 |
55 | hydra:
56 | run:
57 | dir: outputs/${now:%Y-%m-%d}/${now:%H-%M-%S}
58 | sweep:
59 | dir: multirun/${now:%Y-%m-%d}/${now:%H-%M-%S}
60 | subdir: ${hydra.job.num}
61 |
--------------------------------------------------------------------------------
/tinyrl/configs/apr_cond10.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - _self_
3 | - override hydra/hydra_logging: disabled
4 | - override hydra/job_logging: disabled
5 |
6 | # Data configuration
7 | data:
8 | train_data_path: data/apr_train_beam10_subbeam15_prefix.json
9 | val_data_path: data/apr_val_prefix.json
10 | train_batch_size: 64
11 | eval_batch_size: 128
12 |
13 | # Model configuration
14 | model:
15 | name: "Parallel-Reasoning/llama-apr-beam10_subbeam15_num_subcall_cond"
16 | server_url: "http://localhost:0"
17 | master_port: 0
18 | mem_fraction_static: 0.6
19 | gradient_checkpointing: true # Enable gradient checkpointing to save memory
20 |
21 | # Rollout configuration
22 | rollout:
23 | mode: "apr"
24 | sample_temperature: 1.0
25 | eval_temperature: 1.0
26 | group_size: 5
27 | # condition_prefix: null
28 | condition_prefix: "Sub Call Budget: 10 "
29 |
30 | # Training configuration
31 | training:
32 | learning_rate: 1e-5
33 | ppo_clip_ratio: 0.2
34 | num_steps: 150
35 | inner_steps: 2
36 | kl_beta: 0.001
37 | grad_clip: 1.0
38 | # This grad_accum_chunk_size is used for gradient accumulation
39 | # It is used to avoid OOM when the batch size is too large
40 | grad_accum_chunk_size: 4
41 | log_probs_chunk_size: 16
42 | validate_per_steps: 50
43 | # Number of validation samples to use (null for full validation set), set to 16 for testing, set to a larger number for actual validation
44 | val_samples: 16
45 | output_dir: checkpoints/apr
46 | save_steps: 10
47 |
48 | # Logging configuration
49 | logging:
50 | verbose: false
51 | wandb_project: "tinyrl-apr"
52 | wandb_name: "apr_cond10_grpo"
53 | log_interval: 1
54 | use_current_time: true
55 |
56 | hydra:
57 | run:
58 | dir: outputs/${now:%Y-%m-%d}/${now:%H-%M-%S}
59 | sweep:
60 | dir: multirun/${now:%Y-%m-%d}/${now:%H-%M-%S}
61 | subdir: ${hydra.job.num}
--------------------------------------------------------------------------------
/tinyrl/configs/sosp.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - _self_
3 | - override hydra/hydra_logging: disabled
4 | - override hydra/job_logging: disabled
5 |
6 | # Data configuration
7 | data:
8 | train_data_path: data/sosp_train_prefix.json # To be specified by user
9 | val_data_path: data/sosp_val_prefix.json # To be specified by user
10 | train_batch_size: 64
11 | eval_batch_size: 128
12 |
13 | # Model configuration
14 | model:
15 | name: "Parallel-Reasoning/llama-sosp"
16 | server_url: "http://localhost:29215"
17 | master_port: 45516
18 | mem_fraction_static: 0.9
19 | gradient_checkpointing: true # Enable gradient checkpointing to save memory
20 |
21 | # Rollout configuration
22 | rollout:
23 | mode: "sos"
24 | sample_temperature: 1.0
25 | eval_temperature: 0.5
26 | group_size: 5
27 |
28 | # Training configuration
29 | training:
30 | learning_rate: 1e-5
31 | ppo_clip_ratio: 0.2
32 | num_steps: 150
33 | inner_steps: 2
34 | kl_beta: 0.01
35 | grad_clip: 1.0
36 | # This grad_accum_chunk_size is used for gradient accumulation
37 | # It is used to avoid OOM when the batch size is too large
38 | grad_accum_chunk_size: 4
39 | log_probs_chunk_size: 16
40 | validate_per_steps: 25
41 | # Number of validation samples to use (null for full validation set), set to 16 for testing, set to a larger number for actual validation
42 | val_samples: null
43 | output_dir: checkpoints/sosp
44 | save_steps: 50
45 |
46 | # Logging configuration
47 | logging:
48 | verbose: false
49 | wandb_project: "tinyrl"
50 | wandb_name: "sosp_grpo"
51 | log_interval: 1
52 | use_current_time: true
53 |
54 | hydra:
55 | run:
56 | dir: outputs/${now:%Y-%m-%d}/${now:%H-%M-%S}
57 | sweep:
58 | dir: multirun/${now:%Y-%m-%d}/${now:%H-%M-%S}
59 | subdir: ${hydra.job.num}
60 |
--------------------------------------------------------------------------------
/tinyrl/data/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Parallel-Reasoning/APR/8936e8db46bf938242bf5e0a6ebe79ff48ba267a/tinyrl/data/.gitkeep
--------------------------------------------------------------------------------
/tinyrl/requirements.txt:
--------------------------------------------------------------------------------
1 | torch
2 | sglang
3 | wandb
4 | hydra-core
5 | termcolor
6 |
--------------------------------------------------------------------------------
/tinyrl/rollout/apr_utils.py:
--------------------------------------------------------------------------------
1 | """
2 | export CUDA_VISIBLE_DEVICES=7
3 | python -m sglang.launch_server --model-path LM-Parallel/llama-apr-v3 --host localhost --served-model-name model
4 | """
5 | import re
6 | from transformers import AutoTokenizer
7 | from litellm import text_completion
8 | from concurrent.futures import ThreadPoolExecutor, as_completed
9 | import sglang
10 |
11 | # Initialize tokenizer
12 | MODEL_NAME = "LM-Parallel/llama-apr-v3"
13 | tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
14 | print(f"Tokenizer loaded for {MODEL_NAME}")
15 |
16 | VERBOSE = False
17 |
18 | def check_solution(prefix: str, solution: str) -> bool:
19 | """
20 | Parses the prefix and solution to verify if the solution actually
21 | solves the puzzle of reaching `target` from the given list of numbers.
22 |
23 | :param prefix: A line like:
24 | "Moving to Node #1\nCurrent State: 62:[5, 50, 79, 27], Operations: []"
25 | :param solution: The multiline string describing the step-by-step solution.
26 | :return: True if the solution's final result matches the target and
27 | all stated operations are valid. False otherwise.
28 | """
29 | # -----------------------------------------------------------------
30 | # 1. Parse the prefix to extract target and initial numbers
31 | # -----------------------------------------------------------------
32 | # Example prefix line to parse:
33 | # Current State: 62:[5, 50, 79, 27], Operations: []
34 | #
35 | # We'll look for something matching:
36 |     #    Current State: <target>:[<numbers>], ...
37 | prefix_pattern = r"Current State:\s*(\d+):\[(.*?)\]"
38 | match = re.search(prefix_pattern, prefix)
39 | if not match:
40 | if VERBOSE:
41 | print("ERROR: Could not parse the prefix for target and numbers.")
42 | return False
43 |
44 | target_str, numbers_str = match.groups()
45 | target = int(target_str.strip())
46 | # Now parse something like "5, 50, 79, 27" into a list of integers
47 | if numbers_str.strip():
48 | initial_numbers = [int(x.strip()) for x in numbers_str.split(",")]
49 | else:
50 | initial_numbers = []
51 |
52 | # We'll keep track of our current working list of numbers
53 | current_numbers = initial_numbers
54 |
55 | # -----------------------------------------------------------------
56 | # 2. Parse solution to extract lines with "Exploring Operation:"
57 | # -----------------------------------------------------------------
58 | # Example lines:
59 | # Exploring Operation: 79-27=52, Resulting Numbers: [5, 50, 52]
60 | # We want to parse out: operand1=79, operator='-', operand2=27, result=52
61 | # Then parse the new list: [5, 50, 52]
62 |
63 | operation_pattern = r"Exploring Operation:\s*([\d]+)([\+\-\*/])([\d]+)=(\d+),\s*Resulting Numbers:\s*\[(.*?)\]"
64 |
65 | # We'll process the solution line-by-line
66 | # so that we can also capture the final "Goal Reached" line.
67 | lines = solution.splitlines()
68 |
69 | for line in lines:
70 | line = line.strip()
71 |
72 | # Check for "Exploring Operation"
73 | op_match = re.search(operation_pattern, line)
74 | if op_match:
75 | # Parse out the operation parts
76 | x_str, op, y_str, z_str, new_nums_str = op_match.groups()
77 | x_val = int(x_str)
78 | y_val = int(y_str)
79 | z_val = int(z_str)
80 |
81 | # Parse the new list of numbers from something like "5, 50, 52"
82 | new_numbers = []
83 | if new_nums_str.strip():
84 | new_numbers = [int(n.strip()) for n in new_nums_str.split(",")]
85 |
86 | # -------------------------------------------------------------
87 | # Verify that applying X op Y => Z to current_numbers is valid
88 | # -------------------------------------------------------------
89 | # 1. X and Y must both be present in current_numbers
90 | # 2. Remove X and Y from current_numbers
91 | # 3. Add Z
92 | # 4. The new list must match exactly the "Resulting Numbers"
93 | # 5. Also verify the arithmetic was correct (if you want to be strict)
94 | #
95 | # NOTE: we do not handle repeating values carefully here if X or Y
96 | # appear multiple times, but you can adapt as needed (e.g. remove once).
97 |
98 | temp_list = current_numbers[:]
99 |
100 | # Try removing X, Y once each
101 | try:
102 | temp_list.remove(x_val)
103 | temp_list.remove(y_val)
104 | except ValueError:
105 | if VERBOSE:
106 | print(f"ERROR: {x_val} or {y_val} not found in current_numbers {current_numbers}.")
107 | return False
108 |
109 | # Check that the stated Z matches the arithmetic operation
110 | # (If you want to skip verifying the math, remove these lines.)
111 | computed_result = None
112 | if op == '+':
113 | computed_result = x_val + y_val
114 | elif op == '-':
115 | computed_result = x_val - y_val
116 | elif op == '*':
117 | computed_result = x_val * y_val
118 | elif op == '/':
119 | # watch for zero division or non-integer division if you want
120 | # to require exact integer results
121 | if y_val == 0:
122 | if VERBOSE:
123 | print("ERROR: Division by zero encountered.")
124 | return False
125 | # For a typical "24 game" style puzzle, we allow float or integer check
126 | computed_result = x_val / y_val
127 | # If the puzzle requires integer arithmetic only, check remainder:
128 | # if x_val % y_val != 0:
129 | # print("ERROR: Non-integer division result.")
130 | # return False
131 |
132 | # Compare the stated z_val to the computed result
133 | # (if it's integer-based arithmetic, we might check int(...) or round)
134 | if computed_result is None:
135 | if VERBOSE:
136 | print("ERROR: Unknown operation encountered.")
137 | return False
138 |
139 | # If we want exact integer match (for e.g. 50/5=10):
140 | # If float is possible, we might do a small epsilon check:
141 | # e.g. if abs(computed_result - z_val) > 1e-9
142 | if computed_result != z_val:
143 | if VERBOSE:
144 | print(f"ERROR: Operation {x_val}{op}{y_val} does not equal {z_val}. Got {computed_result} instead.")
145 | return False
146 |
147 | # Now add the result to temp_list
148 | temp_list.append(z_val)
149 | # Sort if you do not care about order, or keep order if you do
150 | # and compare to new_numbers
151 | # We'll assume exact order is not critical, so let's do a sorted comparison:
152 | if sorted(temp_list) != sorted(new_numbers):
153 | if VERBOSE:
154 | print(f"ERROR: After applying {x_val}{op}{y_val}={z_val} to {current_numbers}, "
155 | f"got {sorted(temp_list)} but solution says {sorted(new_numbers)}.")
156 | return False
157 |
158 | # If we got here, it means the operation is consistent
159 | current_numbers = new_numbers
160 |
161 | # ---------------------------------------------------------
162 | # 3. Check for "Goal Reached" line
163 | # ---------------------------------------------------------
164 | # Something like: "62,62 equal: Goal Reached"
165 | # We'll check if the final single number is indeed `target`.
166 | if "Goal Reached" in line:
167 | # For a simple check, if "Goal Reached" is present,
168 | # confirm that current_numbers is [target].
169 | if len(current_numbers) == 1 and current_numbers[0] == target:
170 | return True
171 | else:
172 | if VERBOSE:
173 | print("ERROR: 'Goal Reached' declared but final numbers don't match the target.")
174 | return False
175 |
176 | # If we never saw "Goal Reached," then it's incomplete
177 | # or didn't declare success. Return False by default
178 | if VERBOSE:
179 | print("ERROR: Did not find 'Goal Reached' in solution.")
180 | return False
181 |
182 | def get_search_result(search_trace):
183 | # Given a search trace, return the result of the search
184 |     # If the search is successful, return the optimal path it found
185 | # If the search is unsuccessful, return None
186 | if search_trace.count("Goal Reached") >= 2:
187 | # Find all occurrences of "Goal Reached"
188 | goal_indices = [i for i in range(len(search_trace)) if search_trace.startswith("Goal Reached", i)]
189 |         # Get the second-to-last index; the optimal path is generated
190 |         # starting from this occurrence of "Goal Reached"
191 | goal_idx = goal_indices[-2]
192 | return search_trace[goal_idx:].strip()[13:]
193 | else:
194 | return None
195 |
196 | def get_subsearch_info(search_trace):
197 | try:
198 | return _get_subsearch_info(search_trace)
199 | except Exception as e:
200 | print(f"Error at get_subsearch_info: {e}")
201 | print(search_trace)
202 | raise e
203 |
204 | def _get_subsearch_info(search_trace):
205 | # Given a search trace, return the information of the
206 | # subsearch that it wants to invoke
207 | # sub_search= {"node": "#1,1,2", "target": 39, 'nums':[2, 11], "operations": ["51-49=2", "36-25=11"]}
208 | last_line = search_trace.split("\n")[-1]
209 | assert "" in last_line, "This is not a valid subsearch trace"
210 |
211 | # --- Parse the search trace to get the generated nodes ---
212 | generated_nodes = {}
213 | # First find any "Moving to Node" lines followed by "Current State" lines
214 | lines = search_trace.split("\n")
215 | for i in range(len(lines)-1):
216 | # Moving to Node #1,1
217 | # Current State: 39:[25, 36, 2], Operations: ['51-49=2']
218 | if "Moving to Node #" in lines[i] and "Current State:" in lines[i+1]:
219 | # Extract node id from first line like:
220 | # Moving to Node #1,1
221 | node_id = lines[i].split("Moving to Node #")[1].strip()
222 |
223 | # Extract state from second line like:
224 | # Current State: 39:[25, 36, 2], Operations: ['51-49=2']
225 | state_line = lines[i+1]
226 | state_part = state_line.split("Current State:")[1].split("],")[0].strip()
227 | operations_part = state_line.split("Operations:")[1].strip()
228 |
229 | # Parse state like "39:[25, 36, 2]"
230 | target = int(state_part.split(":")[0])
231 | # nums = eval(state_part.split(":")[1].strip())
232 | nums = eval(state_part.split(":")[1].strip() + "]")
233 | operations = eval(operations_part)
234 |
235 | # Parse operations list
236 |
237 | generated_nodes[node_id] = {
238 | "node": f"#{node_id}",
239 | "target": target,
240 | "nums": nums,
241 | "operations": operations
242 | }
243 | for line in search_trace.split("\n"):
244 | if "Generated Node" in line:
245 | # Extract node id and info from line like:
246 | # Generated Node #1,1,2: 39:[2, 11] Operation: 36-25=11
247 | node_id = line.split(":")[0].split("#")[1]
248 | if node_id in generated_nodes:
249 | continue
250 | rest = line.split(":", 1)[1].strip()
251 | state = rest.split("Operation:")[0].strip()
252 | operation = rest.split("Operation:")[1].strip()
253 |
254 | # Parse state like "39:[2, 11]" into target and nums
255 | target = int(state.split(":")[0])
256 | nums = eval(state.split(":")[1].strip())
257 |
258 | parent_node_id = ",".join(node_id.split(",")[:-1])
259 | parent_node = generated_nodes[parent_node_id]
260 | new_operations = parent_node["operations"] + [operation]
261 |
262 | generated_nodes[node_id] = {
263 | "node": f"#{node_id}",
264 | "target": target,
265 | "nums": nums,
266 | "operations": new_operations
267 | }
268 | # then we construct the sub_searches
269 | sub_search_nodes = []
270 |     # Split the trace and take the last chunk (the sub-search call section)
271 |     last_chunk = search_trace.split("\n")[-1]
272 |     # Split that chunk and take the first part
273 |     sub_search_section = last_chunk.split("\n")[0]
274 | 
275 |     for line in sub_search_section.split("\n"):
276 |         if "Moving to Node #" in line:
277 |             # example line:
278 |             #   ... Moving to Node #1,1,2
279 |             node_id = line.split("Moving to Node #")[1].strip()
280 | sub_search_nodes.append(generated_nodes[node_id])
281 |
282 | def construct_sub_search_prefix(node):
283 |         # example
284 | # "Moving to Node #1,1,0\nCurrent State: 39:[36, 50], Operations: ['51-49=2', '25*2=50']
285 | return f"Moving to Node {node['node']}\nCurrent State: {node['target']}:[{', '.join(map(str, node['nums']))}], Operations: {node['operations']}"
286 | sub_search_prefix_list = [construct_sub_search_prefix(node) for node in sub_search_nodes]
287 | return sub_search_prefix_list, sub_search_nodes
288 |
289 | def get_main_trace_after_sub_search(main_trace, sub_search_nodes, sub_search_result_list):
290 | last_line = main_trace.split("\n")[-1]
291 | assert "" in last_line, "This is not a valid subsearch trace"
292 | sub_searches = []
293 |     # Split the trace and take the last chunk (the sub-search call section)
294 |     last_chunk = main_trace.split("\n")[-1]
295 |     # Split that chunk and take the first part
296 |     sub_search_section = last_chunk.split("\n")[0]
297 | main_trace_after_sub_search = "\n".join(main_trace.split("\n")[:-1])
298 | assert main_trace_after_sub_search in main_trace
299 | main_trace_after_sub_search += "\n"
300 | for i, (this_node, this_result) in enumerate(zip(sub_search_nodes, sub_search_result_list)):
301 | if this_result is None:
302 | #
303 | main_trace_after_sub_search += f"\n"
304 | else:
305 | # \nMoving to Node #1,2,0\nCurrent State: 39:[51, 12], Operations: ['49-36=13', '25-13=12']\nExploring Operation: 51-12=39, Resulting Numbers: [39]\n39,39 equal: Goal Reached\n
306 | main_trace_after_sub_search += f"\n"
307 | main_trace_after_sub_search += this_result + "\n"
308 | main_trace_after_sub_search += "\n"
309 | return main_trace_after_sub_search
310 |
311 |
312 |
313 | def add_angle_brackets(text):
314 | lines = text.split('\n')
315 | result_lines = []
316 | for line in lines:
317 | if '>' in line and '<' not in line:
318 | line = '<' + line
319 | result_lines.append(line)
320 | return '\n'.join(result_lines)
321 |
322 | def generate(prefix, tokenizer, api_base_url, temperature=0.5, stop=[]):
323 | """Generate text using the model API"""
324 | bos_token = tokenizer.bos_token
325 | prefix = bos_token + prefix
326 |
327 | result = text_completion(
328 | model="openai/model",
329 | prompt=prefix,
330 | api_base=api_base_url,
331 | api_key="api_key",
332 | temperature=temperature,
333 | max_tokens=4096,
334 | stop=stop,
335 | )
336 |
337 | text = result['choices'][0]['text']
338 | complete_text = prefix + text
339 | complete_text = complete_text.replace(bos_token, ' ')
340 | if complete_text[0] == ' ':
341 | complete_text = complete_text[1:]
342 | return complete_text, result
343 |
344 | def decode_trace(prefix, tokenizer, api_base_url, temperature=0.5):
345 | """Decode a single trace"""
346 | while True:
347 | trace = generate(prefix, tokenizer, api_base_url, temperature=temperature, stop=[""])
348 | llm_call_info = {
349 | "prefix": prefix, # Store the original prefix
350 | "output": trace[1]["choices"][0]["text"]
351 | }
352 | # Store the prefix and output in a dictionary
353 | prefix = trace[0]
354 | if trace[1].choices[0].matched_stop == "":
355 | prefix += ""
356 | else:
357 | break
358 | prefix = trace[0]
359 | if prefix.split('\n')[-1] == "":
360 | prefix = prefix[:-1]
361 | return prefix, llm_call_info
362 |
363 | def batch_decode_trace(prefix_list, tokenizer, api_base_url, temperature=0.5, max_workers=16):
364 | """Decode multiple traces in parallel"""
365 | with ThreadPoolExecutor(max_workers=max_workers) as executor:
366 | future_to_prefix = {
367 | executor.submit(decode_trace, prefix, tokenizer, api_base_url, temperature): prefix
368 | for prefix in prefix_list
369 | }
370 |
371 | results = [None] * len(prefix_list)
372 | llm_calls = [None] * len(prefix_list)
373 |
374 | for future in as_completed(future_to_prefix):
375 | prefix = future_to_prefix[future]
376 | try:
377 | result, llm_call_info = future.result()
378 | original_idx = prefix_list.index(prefix)
379 | results[original_idx] = result
380 | llm_calls[original_idx] = llm_call_info
381 | except Exception as e:
382 | print(f"Error processing prefix: {e}")
383 | original_idx = prefix_list.index(prefix)
384 | results[original_idx] = None
385 | llm_calls[original_idx] = None
386 |
387 | return results, llm_calls
388 |
389 | def is_calling_subsearch(trace):
390 | return "" in trace.split('\n')[-1]
391 |
392 | def call_search(prefix, api_base_url, temperature=0.5):
393 | """
394 | Main search function that processes a given prefix
395 |
396 | Args:
397 | prefix (str): The initial state and goal
398 | api_base_url (str): The URL for the model API
399 | temperature (float): Temperature parameter for text generation
400 |
401 | Returns:
402 | tuple: (solution, trace_dict, success, true_success)
403 | """
404 | condition_prefix = prefix.split("\n")[0].split("Moving to Node")[0]
405 | try:
406 | trace_dict = {"main_calls": [], "sub_calls": [], "llm_calls": []}
407 | trace, llm_call_info = decode_trace(prefix, tokenizer, api_base_url, temperature=temperature)
408 | trace_dict["main_calls"].append(trace)
409 | trace_dict["llm_calls"].append(llm_call_info)
410 |
411 | while is_calling_subsearch(trace):
412 | sub_search_prefix_list, sub_search_nodes = get_subsearch_info(trace)
413 | for idx, sub_search_prefix in enumerate(sub_search_prefix_list):
414 | if condition_prefix.startswith("Sub Call Budget: "):
415 | # we don't need to add the condition prefix to the sub search prefix if the condition prefix starts with "Sub Call Budget: "
416 | sub_search_prefix_list[idx] = sub_search_prefix
417 | elif condition_prefix.startswith("Token Budget: "):
418 | # we need to add the condition prefix to the sub search prefix if the condition prefix starts with "Token Budget: "
419 | sub_search_prefix_list[idx] = condition_prefix + sub_search_prefix
420 | elif condition_prefix:
421 | raise ValueError(f"Unknown condition prefix: {condition_prefix}")
422 | sub_search_traces, sub_search_llm_calls = batch_decode_trace(sub_search_prefix_list, tokenizer, api_base_url, temperature=temperature)
423 |
424 | trace_dict["sub_calls"].append([])
425 | for sub_search_trace, sub_search_llm_call in zip(sub_search_traces, sub_search_llm_calls):
426 | trace_dict["sub_calls"][-1].append({
427 | "main_calls": [sub_search_trace],
428 | "llm_calls": [sub_search_llm_call]
429 | })
430 |
431 | sub_search_results = [get_search_result(trace) for trace in sub_search_traces]
432 | new_prefix = get_main_trace_after_sub_search(trace, sub_search_nodes, sub_search_results)
433 | trace, llm_call_info = decode_trace(new_prefix, tokenizer, api_base_url, temperature=temperature)
434 | trace_dict["main_calls"].append(trace)
435 | trace_dict["llm_calls"].append(llm_call_info)
436 |
437 | solution = get_search_result(trace)
438 | success = solution is not None
439 | true_success = check_solution(prefix, solution) if success else False
440 |
441 | return solution, trace_dict, success, true_success
442 |
443 | except Exception as e:
444 | print(f"Error in call_search: {e}")
445 | return None, trace_dict, False, False
446 |
447 | def process_single_prefix(prefix, server_url, bos_token, temperature=0.5):
448 | """Helper function to process a single prefix for parallel execution"""
449 | solution, trace_dict, success, true_success = call_search(prefix, server_url, temperature)
450 | seqs = []
451 | for lm_call in trace_dict['llm_calls']:
452 | seqs.append(lm_call)
453 | for sub_calls in trace_dict['sub_calls']:
454 | for sub_call in sub_calls:
455 | for lm_call in sub_call['llm_calls']:
456 | seqs.append(lm_call)
457 | return {
458 | "seqs": seqs,
459 | "is_correct": true_success
460 | }
461 |
462 | def rollout_apr(server_url, prefix_list, bos_token, temperature=0.5,
463 |                 max_workers=32, condition_prefix=""):
464 | """
465 | Parallel implementation of rollout function using ThreadPoolExecutor
466 | """
467 | with ThreadPoolExecutor(max_workers=max_workers) as executor:
468 | # Create futures with index tracking
469 | futures = []
470 | for idx, prefix in enumerate(prefix_list):
471 | future = executor.submit(
472 | process_single_prefix,
473 | condition_prefix + prefix,
474 | server_url,
475 | bos_token,
476 | temperature
477 | )
478 | futures.append((idx, future))
479 |
480 | # Initialize results list with correct size
481 | results = [None] * len(prefix_list)
482 |
483 | # Collect results as they complete
484 | for idx, future in futures:
485 | try:
486 | result = future.result()
487 | results[idx] = result
488 | except Exception as e:
489 | print(f"Error processing prefix at index {idx}: {e}")
490 | results[idx] = {
491 | "seqs": [],
492 | "is_correct": False
493 | }
494 | # if the result is None, something's wrong in the decoding process
495 | # we use a dummy result to avoid breaking the loop
496 | for idx in range(len(results)):
497 | if not results[idx]:
498 | print(f"Error processing prefix {prefix_list[idx]} at index {idx}")
499 | results[idx] = {
500 | "seqs": [],
501 | "is_correct": False
502 | }
503 | for result_idx in range(len(results)):
504 | for seq_id in reversed(range(len(results[result_idx]['seqs']))):
505 | if not results[result_idx]['seqs'][seq_id]:
506 | # if one seq is None, we remove that seq
507 | results[result_idx]['seqs'].pop(seq_id)
508 | print(f"Removing seq {seq_id} from result {result_idx} since it's None")
509 | return results
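510 | 
511 | # Minimal usage sketch (hypothetical values): assumes an SGLang server serving the
512 | # APR model is already running (see the module docstring above) and that the URL
513 | # below points at its OpenAI-compatible endpoint; the prefix is a made-up
514 | # countdown example.
515 | if __name__ == "__main__":
516 |     example_prefix = "Moving to Node #0\nCurrent State: 10:[2, 3, 4], Operations: []"
517 |     example_results = rollout_apr(
518 |         server_url="http://localhost:30000/v1",  # hypothetical; point at your server's endpoint
519 |         prefix_list=[example_prefix],
520 |         bos_token=tokenizer.bos_token,
521 |         temperature=1.0,
522 |         max_workers=1,
523 |     )
524 |     print(example_results[0]["is_correct"])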
--------------------------------------------------------------------------------
/tinyrl/rollout/sos_utils.py:
--------------------------------------------------------------------------------
1 | from litellm import text_completion
2 | import re
3 | def _parse_trajectory(search_path):
4 | # Find the first occurrence of "Current State" and trim everything before it
5 | start_idx = search_path.find("Current State")
6 | if start_idx == -1:
7 | return "Invalid input: Cannot find the initial state."
8 | search_path = search_path[start_idx:]
9 |
10 | # Extracting the target and initial numbers from the first line
11 | first_line = search_path.strip().split('\n')[0]
12 |
13 | # if mode == "dt":
14 | # first_line = first_line.split("->")[1]
15 | target_nums_match = re.match(r"Current State: (\d+):\[(.*?)\]", first_line)
16 | if not target_nums_match:
17 | return "Invalid input: Cannot find the initial state in the first line."
18 |
19 | target, nums = int(target_nums_match.group(1)), [int(n) for n in target_nums_match.group(2).split(", ")]
20 |
21 | # Extract the operations from the line that claims the goal is reached.
22 | goal_lines = re.finditer(r"\d+,\d+ equal: Goal Reached", search_path)
23 | goal_lines = list(goal_lines)
24 | if not goal_lines:
25 | return "No goal reached statement found."
26 |
27 | goal_line = goal_lines[-1]
28 | # get the last operation line before the goal reached statement
29 | operations = re.findall(r"Exploring Operation: (.*?=\d+), Resulting Numbers: \[(.*?)\]",
30 | search_path[:goal_line.start()])
31 | if not operations:
32 | return "No operations found leading to the goal."
33 |
34 | final_operation = operations[-1][0]
35 | try:
36 | predicted_result = int(final_operation.split('=')[1])
37 | except:
38 |         print("couldn't parse last op", final_operation)
39 |         return "Couldn't parse last op"
40 | if predicted_result != target:
41 | return "Invalid path: Final operation does not result in target."
42 |
43 | # get the last current state, operations before the goal reached statement, and extract the operations
44 | try:
45 | core_path = search_path[:goal_line.start()].split("Goal Reached\n")[1]
46 | except:
47 | print("invalid, no summarized answer")
48 | return "Invalid path: no summarized answer."
49 | operation_list = re.findall(r"Current State: \d+:\[.*?\], Operations: \[(.*?)\]", core_path)[
50 | -1].split(', ')
51 | operation_list = [op.replace("'", "") for op in operation_list]
52 | operation_list += [final_operation]
53 |
54 | # Verify each operation and keep track of the numbers involved
55 | available_numbers = nums
56 | for operation in operation_list:
57 | # Verify the operation
58 | try:
59 | left, right = operation.split('=')
60 | except:
61 |             return f"Could not split operation '{operation}' into lhs and rhs"
62 | try:
63 | if eval(left) != int(right):
64 | return f"Invalid operation: {operation}"
65 | except Exception as e:
66 | return f"Error in evaluating operation {operation}: {e}"
67 | # get the numbers involved
68 |         used_numbers = [int(n) for n in re.findall(r"\d+", left)]
69 |         for n in used_numbers:
70 |             if n not in available_numbers:
71 |                 return f"Invalid operation: {operation}, number {n} not available in {available_numbers}"
72 | 
73 |         available_numbers = [n for n in available_numbers if n not in used_numbers]
74 |         available_numbers.append(int(right))
75 |
76 | return "Valid path."
77 |
78 | def is_correct(search_path):
79 | try:
80 | return _parse_trajectory(search_path) == "Valid path."
81 | except Exception as e:
82 | print(f"Error in is_correct: {e}, treating as incorrect")
83 | return False
84 |
85 |
86 | def rollout_sos(server_url, prefix_list, bos_token, temperature=0.5, condition_prefix=None):
87 | if condition_prefix is not None:
88 | assert not prefix_list[0].startswith("Token Budget: "), f"Condition prefix already in the prefix: {prefix_list[0]}, please use data without condition prefix"
89 | all_prefixes = [bos_token + condition_prefix + prefix for prefix in prefix_list]
90 | else:
91 | all_prefixes = [bos_token + prefix for prefix in prefix_list]
92 |
93 | outputs = text_completion(
94 | model="openai/model",
95 | prompt=all_prefixes,
96 | api_base=server_url,
97 | api_key="api_key",
98 | temperature=temperature,
99 | max_tokens=4000,
100 | )
101 | return_dict = []
102 | for output, prefix in zip(outputs["choices"], all_prefixes):
103 | whole_text = prefix + output["text"]
104 | return_dict.append({
105 | 'seqs': [
106 | {
107 | "prefix": prefix,
108 | "output": output["text"],
109 | }
110 | ],
111 | "is_correct": is_correct(whole_text)
112 | })
113 | return return_dict
114 |
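115 | # Minimal self-contained check of is_correct, using a tiny hand-written trace in
116 | # the countdown search format (made up for illustration): the search reaches the
117 | # goal once, then the summarized optimal path repeats it before the final
118 | # "Goal Reached" line.
119 | if __name__ == "__main__":
120 |     example_trace = (
121 |         "Current State: 10:[2, 3, 4], Operations: []\n"
122 |         "Exploring Operation: 2*3=6, Resulting Numbers: [4, 6]\n"
123 |         "Generated Node #1: 10:[4, 6] Operation: 2*3=6\n"
124 |         "Moving to Node #1\n"
125 |         "Current State: 10:[4, 6], Operations: ['2*3=6']\n"
126 |         "Exploring Operation: 6+4=10, Resulting Numbers: [10]\n"
127 |         "10,10 equal: Goal Reached\n"
128 |         "Moving to Node #1\n"
129 |         "Current State: 10:[4, 6], Operations: ['2*3=6']\n"
130 |         "Exploring Operation: 6+4=10, Resulting Numbers: [10]\n"
131 |         "10,10 equal: Goal Reached\n"
132 |     )
133 |     print(is_correct(example_trace))  # expected: True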
--------------------------------------------------------------------------------
/tinyrl/utils.py:
--------------------------------------------------------------------------------
1 | import subprocess
2 | import time
3 | import requests
4 | from typing import Optional
5 | from sglang.srt.utils import kill_process_tree
6 |
7 | def popen_launch_server(
8 | model: str,
9 | base_url: str,
10 | timeout: float,
11 | model_name: str = "model",
12 | api_key: Optional[str] = None,
13 | other_args: list[str] = (),
14 | env: Optional[dict] = None,
15 | return_stdout_stderr: Optional[tuple] = None,
16 | ):
17 | _, host, port = base_url.split(":")
18 | host = host[2:]
19 |
20 | command = [
21 | "python3",
22 | "-m",
23 | "sglang.launch_server",
24 | "--model-path",
25 | model,
26 | "--host",
27 | host,
28 | "--port",
29 | port,
30 | "--served-model-name",
31 | model_name,
32 | *other_args,
33 | ]
34 |
35 | if api_key:
36 | command += ["--api-key", api_key]
37 |
38 | if return_stdout_stderr:
39 | process = subprocess.Popen(
40 | command,
41 | stdout=return_stdout_stderr[0],
42 | stderr=return_stdout_stderr[1],
43 | env=env,
44 | text=True,
45 | )
46 | else:
47 | process = subprocess.Popen(
48 | command,
49 | stdout=subprocess.DEVNULL,
50 | stderr=subprocess.DEVNULL,
51 | env=env
52 | )
53 |
54 | start_time = time.time()
55 | with requests.Session() as session:
56 | while time.time() - start_time < timeout:
57 | try:
58 | headers = {
59 | "Content-Type": "application/json; charset=utf-8",
60 | "Authorization": f"Bearer {api_key}",
61 | }
62 | response = session.get(
63 | f"{base_url}/health_generate",
64 | headers=headers,
65 | )
66 | if response.status_code == 200:
67 | return process
68 | except requests.RequestException:
69 | pass
70 | time.sleep(10)
71 | raise TimeoutError("Server failed to start within the timeout period.")
72 |
73 | def terminate_process(process):
74 | kill_process_tree(process.pid)
75 |
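76 | # Example (hypothetical usage, mirroring the values in configs/sosp.yaml): launch
77 | # an SGLang server for the SOS+ model, wait until its health check passes, run
78 | # rollouts against it, then shut it down. The timeout is a placeholder.
79 | #
80 | # proc = popen_launch_server(
81 | #     model="Parallel-Reasoning/llama-sosp",
82 | #     base_url="http://localhost:29215",
83 | #     timeout=600,
84 | #     other_args=["--mem-fraction-static", "0.9"],
85 | # )
86 | # ...  # run rollouts against http://localhost:29215
87 | # terminate_process(proc)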
--------------------------------------------------------------------------------