├── .gitignore ├── README.md ├── assets ├── apr-paper.pdf └── apr.png ├── src ├── __init__.py ├── countdown_utils.py ├── eval │ ├── eval_apr.py │ ├── eval_sosp.py │ └── parallel_infernce_utils.py └── utils.py ├── supervised-jax ├── README.md ├── launcher.py ├── llama3_train │ ├── .gitignore │ ├── README.md │ ├── gpt2_train_script.py │ ├── llama_train │ │ ├── __init__.py │ │ ├── gpt2.py │ │ ├── llama3.py │ │ ├── optimizer.py │ │ ├── serve.py │ │ ├── splash.py │ │ └── utils.py │ ├── llama_train_script.py │ ├── requirements.txt │ └── setup.py ├── scripts │ ├── hs-v3.sh │ ├── hsp-v3.sh │ ├── sos-v3.sh │ ├── training_hs-v2_llama.sh │ ├── training_hsp-v2_llama.sh │ ├── training_sos_llama.sh │ ├── training_sos_llama_1ep.sh │ ├── training_sos_llama_gpt2tok.sh │ ├── training_sos_llama_gpt2tok_1ep.sh │ └── training_sos_xiuyu.sh ├── to_dataset.ipynb ├── training_hsp.sh ├── training_hsp_2x.sh ├── training_run.sh └── training_run_test.sh └── tinyrl ├── README.md ├── configs ├── apr.yaml ├── apr_cond10.yaml └── sosp.yaml ├── data └── .gitkeep ├── requirements.txt ├── rollout ├── apr_utils.py └── sos_utils.py ├── trainer.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | wandb/ 2 | outputs/ 3 | data/ 4 | # Byte-compiled / optimized / DLL files 5 | __pycache__/ 6 | *.py[cod] 7 | *$py.class 8 | 9 | # C extensions 10 | *.so 11 | 12 | # Distribution / packaging 13 | .Python 14 | build/ 15 | develop-eggs/ 16 | dist/ 17 | downloads/ 18 | eggs/ 19 | .eggs/ 20 | lib/ 21 | lib64/ 22 | parts/ 23 | sdist/ 24 | var/ 25 | wheels/ 26 | share/python-wheels/ 27 | *.egg-info/ 28 | .installed.cfg 29 | *.egg 30 | MANIFEST 31 | 32 | # PyInstaller 33 | # Usually these files are written by a python script from a template 34 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 35 | *.manifest 36 | *.spec 37 | 38 | # Installer logs 39 | pip-log.txt 40 | pip-delete-this-directory.txt 41 | 42 | # Unit test / coverage reports 43 | htmlcov/ 44 | .tox/ 45 | .nox/ 46 | .coverage 47 | .coverage.* 48 | .cache 49 | nosetests.xml 50 | coverage.xml 51 | *.cover 52 | *.py,cover 53 | .hypothesis/ 54 | .pytest_cache/ 55 | cover/ 56 | 57 | # Translations 58 | *.mo 59 | *.pot 60 | 61 | # Django stuff: 62 | *.log 63 | local_settings.py 64 | db.sqlite3 65 | db.sqlite3-journal 66 | 67 | # Flask stuff: 68 | instance/ 69 | .webassets-cache 70 | 71 | # Scrapy stuff: 72 | .scrapy 73 | 74 | # Sphinx documentation 75 | docs/_build/ 76 | 77 | # PyBuilder 78 | .pybuilder/ 79 | target/ 80 | 81 | # Jupyter Notebook 82 | .ipynb_checkpoints 83 | 84 | # IPython 85 | profile_default/ 86 | ipython_config.py 87 | 88 | # pyenv 89 | # For a library or package, you might want to ignore these files since the code is 90 | # intended to run in multiple environments; otherwise, check them in: 91 | # .python-version 92 | 93 | # pipenv 94 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 95 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 96 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 97 | # install all needed dependencies. 98 | #Pipfile.lock 99 | 100 | # UV 101 | # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control. 102 | # This is especially recommended for binary packages to ensure reproducibility, and is more 103 | # commonly ignored for libraries. 
104 | #uv.lock 105 | 106 | # poetry 107 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 108 | # This is especially recommended for binary packages to ensure reproducibility, and is more 109 | # commonly ignored for libraries. 110 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 111 | #poetry.lock 112 | 113 | # pdm 114 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 115 | #pdm.lock 116 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 117 | # in version control. 118 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 119 | .pdm.toml 120 | .pdm-python 121 | .pdm-build/ 122 | 123 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 124 | __pypackages__/ 125 | 126 | # Celery stuff 127 | celerybeat-schedule 128 | celerybeat.pid 129 | 130 | # SageMath parsed files 131 | *.sage.py 132 | 133 | # Environments 134 | .env 135 | .venv 136 | env/ 137 | venv/ 138 | ENV/ 139 | env.bak/ 140 | venv.bak/ 141 | 142 | # Spyder project settings 143 | .spyderproject 144 | .spyproject 145 | 146 | # Rope project settings 147 | .ropeproject 148 | 149 | # mkdocs documentation 150 | /site 151 | 152 | # mypy 153 | .mypy_cache/ 154 | .dmypy.json 155 | dmypy.json 156 | 157 | # Pyre type checker 158 | .pyre/ 159 | 160 | # pytype static type analyzer 161 | .pytype/ 162 | 163 | # Cython debug symbols 164 | cython_debug/ 165 | 166 | # PyCharm 167 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 168 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 169 | # and can be added to the global gitignore or merged into this file. For a more nuclear 170 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 171 | #.idea/ 172 | 173 | # Ruff stuff: 174 | .ruff_cache/ 175 | 176 | # PyPI configuration file 177 | .pypirc 178 | 179 | *_ignored* 180 | checkpoints/ 181 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

Learning Adaptive Parallel Reasoning with Language Models

Jiayi Pan*, Xiuyu Li*, Long Lian*, Charlie Victor Snell, Yifei Zhou, Adam Yala, Trevor Darrell, Kurt Keutzer, Alane Suhr

UC Berkeley and UCSF    * Equal Contribution

📃 Paper • 💻 Code • 🤗 Data & Models
26 | 
27 | 
28 | ![APR](./assets/apr.png)
29 | 
30 | **TL;DR**:
31 | We present Adaptive Parallel Reasoning (APR), a novel framework that enables language models to learn to orchestrate both serialized and parallel computations. APR trains language models to use `spawn()` and `join()` operations through end-to-end supervised training and reinforcement learning, allowing models to dynamically orchestrate their own computational workflows.
32 | APR efficiently distributes compute, reduces latency, overcomes context window limits, and achieves state-of-the-art performance on complex reasoning tasks (e.g., 83.4% vs. 60.0% accuracy at 4K context on Countdown).
33 | 
34 | > The full code will be released soon!
35 | ## Data Preparation
36 | 
37 | ## Supervised Training
38 | We use TPU-v3-128 for supervised training, with a codebase built upon [JAX_llama](https://github.com/Sea-Snell/JAX_llama).
39 | 
40 | Please refer to [the instructions](supervised-jax/README.md) for more details.
41 | 
42 | ## Reinforcement Learning
43 | We present TinyRL, a simple implementation of the GRPO training framework used for our experiments. TinyRL is a lightweight yet performant reinforcement learning library designed to be both easy to use and easy to extend. It integrates with [SGLang](https://github.com/sgl-project/sglang) for efficient rollout. Given the small size of the model we're training, we haven't implemented model parallelism, so it runs on two GPUs: one for training and one for rollout.
44 | 
45 | It supports asynchronous, multi-turn, multi-agent rollouts through a general `rollout_fun` interface, with the minimal assumption that your rollout mechanism relies on calling an OpenAI-compatible API endpoint.
46 | ```python
47 | def rollout_fun(server_url, prefix_list, bos_token, temperature=0.5, max_workers=32):
48 |     pass
49 | ```
50 | 
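For concreteness, the sketch below shows what a minimal single-turn `rollout_fun` satisfying this interface might look like. It is an illustration under stated assumptions, not the actual implementation in `tinyrl/rollout/`: it assumes `server_url` points at an OpenAI-compatible API base (e.g. `http://127.0.0.1:2346/v1` for an SGLang server), that the model is served under the name `model`, and that the `requests` library and a 4096-token generation cap are acceptable choices.

```python
from concurrent.futures import ThreadPoolExecutor

import requests


def rollout_fun(server_url, prefix_list, bos_token, temperature=0.5, max_workers=32):
    # Illustrative sketch of the rollout interface; not the repo's actual code.
    def complete(prefix):
        # One completion request against the OpenAI-compatible endpoint.
        resp = requests.post(
            f"{server_url}/completions",
            json={
                "model": "model",              # served model name (assumed)
                "prompt": bos_token + prefix,  # prepend BOS to the raw text prefix
                "temperature": temperature,
                "max_tokens": 4096,            # assumed generation budget
            },
            timeout=600,
        )
        resp.raise_for_status()
        # Return the prefix extended with the sampled continuation.
        return prefix + resp.json()["choices"][0]["text"]

    # Fan the prefixes out over a thread pool so rollouts run concurrently.
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        return list(pool.map(complete, prefix_list))
```

Any rollout logic with this signature, including multi-turn or multi-agent variants, can be plugged into the trainer as long as it ultimately issues requests to such an endpoint.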
51 | Please refer to [the instructions](tinyrl/README.md) for more details.
52 | 
53 | ## Evaluation
54 | 
55 | > [!IMPORTANT]
56 | > **For evaluation, SGLang needs to be patched**.
57 | > Remove this check in `python/sglang/srt/managers/tokenizer_manager.py` in your local SGLang repo:
58 | > ```
59 | > # if (
60 | > #     obj.sampling_params.get("max_new_tokens") is not None
61 | > #     and obj.sampling_params.get("max_new_tokens") + input_token_num
62 | > #     >= self.context_len
63 | > # ):
64 | > #     raise ValueError(
65 | > #         f"Requested token count exceeds the model's maximum context length "
66 | > #         f"of {self.context_len} tokens. You requested a total of "
67 | > #         f"{obj.sampling_params.get('max_new_tokens') + input_token_num} "
68 | > #         f"tokens: {input_token_num} tokens from the input messages and "
69 | > #         f"{obj.sampling_params.get('max_new_tokens')} tokens for the "
70 | > #         f"completion. Please reduce the number of tokens in the input "
71 | > #         f"messages or the completion to fit within the limit."
72 | > #     )
73 | > ```
74 | > 
75 | > This file is located at [tokenizer_manager.py](https://github.com/sgl-project/sglang/blob/45205d88a08606d5875476fbbbc76815a5107edd/python/sglang/srt/managers/tokenizer_manager.py#L350).
76 | 
77 | > [!NOTE]
78 | > sgl-project/sglang#3721 introduces an `--allow-auto-truncate` option that makes this patch unnecessary. Once it is merged, you can use that option directly.
79 | 
80 | ### SoS+
81 | 
82 | The following command evaluates the SoS+ model on the validation set:
83 | ```bash
84 | python -m src.eval.eval_sosp --ckpt <ckpt> --temperature <temperature> --batch_size 256 --gens 1 --output_dir <output_dir> --num_gpus 8 --n_samples <n_samples> --budget <budget>
85 | ```
86 | Where `<n_samples>` is the number of best-of-N samples at inference time, and `<budget>` is the budget for conditional generation (leave it empty if not using conditioned models). For instance, the following command evaluates the SoS+ model with 8 samples using an unconditioned checkpoint.
87 | ```bash
88 | python -m src.eval.eval_sosp --ckpt Parallel-Reasoning/llama-sosp --temperature 1.0 --batch_size 256 --gens 1 --output_dir results/llama-sosp/ --num_gpus 8 --n_samples 8
89 | ```
90 | 
91 | ### APR
92 | 
93 | First, you need to start the SGLang server for the target model. For instance:
94 | ```bash
95 | python -m sglang.launch_server --served-model-name model --model-path Parallel-Reasoning/llama-apr_cond10_grpo --port 2346 --dp-size 8
96 | ```
97 | 
98 | Then the following command evaluates the APR model on the validation set:
99 | ```bash
100 | python -m src.eval.eval_apr --model_name llama-apr_cond10_grpo --ckpt Parallel-Reasoning/llama-apr_cond10_grpo --temperature 1.0 --budget 10 --use_subcall_cond
101 | ```
102 | This evaluates the APR model with a budget of 10 child threads and uses child-thread-count conditioning. Do not use `--budget` and `--use_subcall_cond` for unconditioned models.
103 | 
104 | 
105 | ## Citation
106 | If you find this work useful in your research, please consider citing:
107 | 
108 | ```bibtex
109 | @article{pan2025learning,
110 |   title   = {Learning Adaptive Parallel Reasoning with Language Models},
111 |   author  = {Jiayi Pan and Xiuyu Li and Long Lian and Charlie Snell and Yifei Zhou and Adam Yala and Trevor Darrell and Kurt Keutzer and Alane Suhr},
112 |   year    = {2025},
113 |   journal = {arXiv preprint arXiv:2504.15466}
114 | }
115 | ```
116 | 
--------------------------------------------------------------------------------
/assets/apr-paper.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Parallel-Reasoning/APR/8936e8db46bf938242bf5e0a6ebe79ff48ba267a/assets/apr-paper.pdf
--------------------------------------------------------------------------------
/assets/apr.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Parallel-Reasoning/APR/8936e8db46bf938242bf5e0a6ebe79ff48ba267a/assets/apr.png
--------------------------------------------------------------------------------
/src/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Parallel-Reasoning/APR/8936e8db46bf938242bf5e0a6ebe79ff48ba267a/src/__init__.py
--------------------------------------------------------------------------------
/src/eval/eval_apr.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import argparse
4 | import numpy as np
5 | from tqdm import tqdm
6 | from concurrent.futures import ThreadPoolExecutor, as_completed
7 | from termcolor import colored
8 | 
9 | from transformers import AutoTokenizer
10 | from litellm import text_completion, APIError
11 | from src.eval.parallel_infernce_utils import (
12 |     get_search_result,
13 |     get_main_trace_after_sub_search,
14 |     get_subsearch_info,
15 |     check_solution
16 | )
17 | 
18 | def parse_bool(value):
19 |     if value.lower() in ['true', '1', 'yes', 'y']:
20 |         return True
21 |     elif value.lower() in ['false', '0', 'no', 'n']:
22 |         return False
23 |     else:
24 | raise ValueError(f"Invalid boolean value: {value}") 25 | 26 | # Example command for sglang: 27 | # python -m sglang.launch_server --served-model-name model --model-path Parallel-Reasoning/llama-apr_cond10_grpo --port 2346 --dp-size 8 28 | 29 | # Parse command line arguments 30 | parser = argparse.ArgumentParser(description='Run APR inference experiments') 31 | parser.add_argument('--model_name', type=str, default="llama-apr", 32 | help='Model name') 33 | parser.add_argument("-d", "--data",type=str, default="data/val.json") 34 | parser.add_argument("--ckpt", type=str, help="path to checkpoint") 35 | parser.add_argument('--disable_parallel_inference', action='store_true', default=False, 36 | help='Whether to use parallel inference') 37 | parser.add_argument('--use_subcall_cond', action='store_true', default=False, 38 | help='Whether to use subcall count conditioning instead of token count') 39 | parser.add_argument('--max_workers', type=int, default=16, 40 | help='Maximum number of workers for parallel inference') 41 | parser.add_argument('--max_tokens', type=int, default=4096, 42 | help='Maximum tokens for generation') 43 | parser.add_argument('--temperature', type=float, default=1.0, 44 | help='Temperature for generation') 45 | parser.add_argument('--budget', type=int, default=None, 46 | help='Budget for generation') 47 | parser.add_argument('--output_dir', type=str, 48 | default=None, 49 | help='Directory to save results. If None, defaults to "../results/{model_name}/val_apr"') 50 | 51 | args = parser.parse_args() 52 | 53 | # Set global variables from arguments 54 | MODEL_NAME = args.model_name 55 | PARALLEL_INFERENCE = not args.disable_parallel_inference 56 | USE_SUBCALL_COND = args.use_subcall_cond 57 | MAX_WORKERS = args.max_workers 58 | TEMPERATURE = args.temperature 59 | BUDGET = args.budget 60 | SAVE_DIR = args.output_dir 61 | 62 | # Load validation data 63 | data_path = args.data 64 | with open(data_path, "r") as f: 65 | val_data = json.load(f) 66 | 67 | # Initialize tokenizer 68 | ckpt = args.ckpt 69 | tokenizer = AutoTokenizer.from_pretrained(ckpt) 70 | bos_token = tokenizer.bos_token 71 | 72 | # API configuration 73 | api_base_url = "http://127.0.0.1:2346/v1" 74 | api_key = "api_key" 75 | model_name = "model" 76 | max_tokens = args.max_tokens 77 | 78 | # Configuration for text generation 79 | ADD_ANGLE_BRACKETS = False 80 | ADD_BOS = True 81 | 82 | # BUDGET and TEMPERATURE will be handled outside in bash script 83 | 84 | def add_angle_brackets(text): 85 | lines = text.split('\n') 86 | result_lines = [] 87 | for line in lines: 88 | if '>' in line and '<' not in line: 89 | line = '<' + line 90 | result_lines.append(line) 91 | return '\n'.join(result_lines) 92 | 93 | def generate(prefix, stop = [], temperature = 0.0): 94 | if ADD_BOS: 95 | prefix = bos_token + prefix 96 | result = text_completion( 97 | model=f"openai/{model_name}", 98 | prompt=prefix, 99 | api_base=api_base_url, 100 | api_key=api_key, 101 | temperature=temperature, 102 | max_tokens=max_tokens, 103 | stop=stop, 104 | ) 105 | text = result['choices'][0]['text'] 106 | complete_text = prefix + text 107 | complete_text = complete_text.replace(bos_token, ' ') 108 | if complete_text[0] == ' ': 109 | complete_text = complete_text[1:] 110 | if ADD_ANGLE_BRACKETS: 111 | complete_text = add_angle_brackets(complete_text) 112 | return complete_text, result 113 | 114 | def add_all_calls(trace_dict): 115 | """Recursively collect all call traces from a trace_dict.""" 116 | all_calls = [] 117 | all_calls += 
trace_dict['main_calls'] 118 | for sub in trace_dict.get('sub_calls', []): 119 | for sub_trace in sub: 120 | all_calls += add_all_calls(sub_trace) 121 | return all_calls 122 | 123 | def calculate_tokens(item, ds_name="apr"): 124 | token_count = 0 125 | if 'apr' in ds_name: 126 | seqs = add_all_calls(item['trace_dict']) 127 | 128 | if len(seqs) > 1: 129 | # Find all sequences that start with "Moving to Node #0" 130 | root_seqs = [seq for seq in seqs if "Moving to Node #0\n" in seq] 131 | 132 | # Sort root sequences by length (shortest first) 133 | root_seqs = sorted(root_seqs, key=len) 134 | 135 | # Calculate total tokens without considering redundancy 136 | total_tokens = 0 137 | # Calculate total token count and track longest sequence in one pass 138 | longest_seq_tokens = 0 139 | for seq in seqs: 140 | tokens = tokenizer.encode(tokenizer.bos_token + seq + tokenizer.eos_token) 141 | total_tokens += len(tokens) 142 | longest_seq_tokens = max(longest_seq_tokens, len(tokens)) 143 | 144 | # Calculate redundant tokens between root sequences 145 | redundant_tokens = 0 146 | if len(root_seqs) > 1: 147 | # Find common prefixes between each pair of sequences 148 | for i in range(len(root_seqs) - 1): 149 | j = i + 1 150 | seq1 = root_seqs[i] 151 | seq2 = root_seqs[j] 152 | 153 | # Find common prefix 154 | prefix_len = 0 155 | for k in range(min(len(seq1), len(seq2))): 156 | if seq1[k] == seq2[k]: 157 | prefix_len += 1 158 | else: 159 | break 160 | 161 | if prefix_len > 0: 162 | common_prefix = seq1[:prefix_len] 163 | # Count tokens in this prefix 164 | prefix_tokens = len(tokenizer.encode(common_prefix)) - 2 # Subtract BOS/EOS 165 | redundant_tokens += max(0, prefix_tokens) 166 | 167 | # Final token count is total minus redundant 168 | token_count = total_tokens - redundant_tokens 169 | 170 | item['longest_seq_token_count'] = longest_seq_tokens 171 | sub_calls = [seq for seq in seqs if not "Moving to Node #0\n" in seq] 172 | item['avg_seq_token_count'] = token_count / (len(sub_calls) + 1) 173 | else: 174 | tokens = tokenizer.encode(tokenizer.bos_token + seqs[0] + tokenizer.eos_token) 175 | token_count = len(tokens) 176 | item['longest_seq_token_count'] = token_count 177 | item['avg_seq_token_count'] = token_count 178 | else: 179 | seq = item['search_path'] 180 | tokens = tokenizer.encode(tokenizer.bos_token + seq + tokenizer.eos_token) 181 | token_count = len(tokens) 182 | 183 | item['token_count'] = token_count 184 | return item 185 | 186 | 187 | def decode_trace(prefix, temperature): 188 | # we should never let the model generate 189 | # whenever it happens, we replace it with 190 | while True: 191 | trace = generate(prefix, stop = [""], temperature = temperature) 192 | prefix = trace[0] 193 | if trace[1].choices[0].matched_stop == "": 194 | prefix += "" 195 | else: 196 | break 197 | prefix = trace[0] 198 | if prefix.split('\n')[-1] == "": 199 | # TODO: why is this happening? 
200 | prefix = prefix[:-1] 201 | return prefix 202 | 203 | if PARALLEL_INFERENCE: 204 | def batch_decode_trace(prefix_list, temperature): 205 | with ThreadPoolExecutor(max_workers=MAX_WORKERS) as executor: 206 | # Submit all tasks and store futures 207 | future_to_prefix = {executor.submit(decode_trace, prefix, temperature): prefix for prefix in prefix_list} 208 | 209 | # Initialize results list with the same length as prefix_list 210 | results = [None] * len(prefix_list) 211 | 212 | # As futures complete, store results in the correct order 213 | for future in as_completed(future_to_prefix): 214 | prefix = future_to_prefix[future] 215 | try: 216 | result = future.result() 217 | original_idx = prefix_list.index(prefix) 218 | results[original_idx] = result 219 | except Exception as e: 220 | print(f"Error processing prefix: {e}") 221 | original_idx = prefix_list.index(prefix) 222 | results[original_idx] = None 223 | 224 | return results 225 | else: 226 | def batch_decode_trace(prefix_list, temperature): 227 | # TODO make this parallel in the future 228 | result_list = [] 229 | for prefix in prefix_list: 230 | result_list.append(decode_trace(prefix, temperature)) 231 | return result_list 232 | 233 | 234 | def is_calling_subsearch(trace): 235 | return "" in trace.split('\n')[-1] 236 | def call_search(prefix, temperature, budget=None, avg_sub_call=False): 237 | try: 238 | trace_dict = {"main_calls": [], "sub_calls": []} 239 | trace = decode_trace(prefix, temperature) 240 | trace_dict["main_calls"].append(trace) 241 | while is_calling_subsearch(trace): 242 | sub_search_prefix_list, sub_search_nodes = get_subsearch_info(trace, budget, avg_sub_call) 243 | # call sub searchers 244 | # TODO: this assumes we only nest one level of sub searchers 245 | # In the future, we need to support nested sub searchers by recursion 246 | sub_search_traces = batch_decode_trace(sub_search_prefix_list, temperature) 247 | trace_dict["sub_calls"].append([]) 248 | for sub_search_trace in sub_search_traces: 249 | trace_dict["sub_calls"][-1].append({"main_calls": [sub_search_trace]}) 250 | sub_search_results = [get_search_result(trace) for trace in sub_search_traces] 251 | new_prefix = get_main_trace_after_sub_search(trace, sub_search_nodes, sub_search_results) 252 | trace = decode_trace(new_prefix, temperature) 253 | trace_dict["main_calls"].append(trace) 254 | return get_search_result(trace), trace_dict 255 | except APIError as e: 256 | print(f"Error at call_search: {e}") 257 | raise e 258 | except Exception as e: 259 | print(f"Error at call_search: {e}") 260 | return None, trace_dict 261 | 262 | 263 | def get_prefix(dataset, idx, is_subcall_cond=False, use_budget=None): 264 | dp = dataset[idx] 265 | sub_call_budget = None 266 | prefix = f"Moving to Node #0\nCurrent State: {dp['target']}:{dp['nums']}, Operations: []" 267 | 268 | if is_subcall_cond: 269 | # Just use the provided budget for subcall conditioning 270 | assert use_budget is not None, "Subcall conditioning requires a budget" 271 | subcall_count = use_budget 272 | prefix = f"Sub Call Budget: {subcall_count} " + prefix 273 | 274 | if is_subcall_cond: 275 | # For subcall conditioning, we don't need to set sub_call_budget 276 | sub_call_budget = None 277 | elif sub_call_budget is not None: 278 | sub_call_budget = ((sub_call_budget - 1) // 512 + 1) * 512 279 | else: 280 | sub_call_budget = dp['token_count'] if 'token_count' in dp else None 281 | return prefix, sub_call_budget 282 | 283 | 284 | def run_inference_experiment( 285 | model_name, 286 | val_data, 287 | 
n_tasks=1000, 288 | temperature=0.0, 289 | budget=None, 290 | use_subcall_cond=False, 291 | parallel_inference=True, 292 | max_workers=16, 293 | save_dir=None 294 | ): 295 | """ 296 | Run an inference experiment with the specified configuration. 297 | 298 | Args: 299 | model_name (str): Name of the model to use 300 | val_data (list): Validation data 301 | n_tasks (int): Number of tasks to run 302 | temperature (float): Temperature for generation 303 | budget (int, optional): Token budget. If None, use default from data 304 | use_subcall_cond (bool): Whether to use subcall count conditioning 305 | parallel_inference (bool): Whether to use parallel inference 306 | max_workers (int): Maximum number of workers for parallel inference 307 | save_dir (str, optional): Directory to save results. If None, use default 308 | 309 | Returns: 310 | dict: Results dictionary with trajectories, trace_dicts, ratings, true_ratings 311 | """ 312 | os.makedirs(save_dir, exist_ok=True) 313 | base_name = f"{model_name.replace('/', '_')}_{n_tasks}_0_temp_{str(temperature).replace('.','_')}" 314 | if budget is not None: 315 | base_name += f"_budget_{budget}" 316 | 317 | save_path = os.path.join(save_dir, f"{base_name}.json") 318 | 319 | # Check if results already exist 320 | if os.path.exists(save_path): 321 | print(colored(f"Results already exist at {save_path}. Loading...", "green")) 322 | with open(save_path, 'r') as f: 323 | results = json.load(f) 324 | 325 | # Calculate and print metrics 326 | true_ratings = results["true_ratings"] 327 | ratings = results["ratings"] 328 | true_success_rate = np.mean(true_ratings) 329 | success_rate = np.mean(ratings) 330 | 331 | print(f"model: {model_name}, temperature: {temperature}, budget: {budget}\n" 332 | f"success_rate: {true_success_rate:.3f}, unverified_success_rate: {success_rate:.3f}") 333 | 334 | return results 335 | 336 | # If results don't exist, run the experiment 337 | success_count = [] 338 | true_success_count = [] 339 | logs = [] 340 | 341 | pbar = tqdm(range(n_tasks)) 342 | 343 | def process_sample(i): 344 | prefix, sub_call_budget = get_prefix( 345 | dataset=val_data, idx=i, 346 | is_subcall_cond=use_subcall_cond, 347 | use_budget=budget, 348 | ) 349 | out = call_search(prefix, temperature, budget=None, avg_sub_call=False) 350 | solution = out[0] 351 | true_success = check_solution(prefix, solution) if solution is not None else False 352 | return out, solution is not None, true_success 353 | 354 | if parallel_inference: 355 | with ThreadPoolExecutor(max_workers=max_workers) as executor: 356 | futures = [executor.submit(process_sample, i) for i in range(n_tasks)] 357 | 358 | for i, future in enumerate(as_completed(futures)): 359 | out, success, true_success = future.result() 360 | logs.append(out) 361 | success_count.append(success) 362 | true_success_count.append(true_success) 363 | pbar.update(1) 364 | pbar.set_postfix({ 365 | 'true_success_rate': f"{np.mean(true_success_count):.3f}", 366 | 'success_rate': f"{np.mean(success_count):.3f}" 367 | }) 368 | else: 369 | for i in pbar: 370 | out, success, true_success = process_sample(i) 371 | logs.append(out) 372 | success_count.append(success) 373 | true_success_count.append(true_success) 374 | pbar.update(1) 375 | pbar.set_postfix({ 376 | 'true_success_rate': f"{np.mean(true_success_count):.3f}", 377 | 'success_rate': f"{np.mean(success_count):.3f}" 378 | }) 379 | 380 | # Prepare results 381 | trajectories = [x[0] if x[0] is not None else "" for x in logs] 382 | trace_dicts = [x[1] for x in logs] 383 | ratings 
= [1 if x else 0 for x in success_count] 384 | true_ratings = [1 if x else 0 for x in true_success_count] 385 | 386 | print(colored("Results Summary:", "cyan", attrs=["bold"])) 387 | print(colored(f"Model: ", "yellow") + colored(f"{model_name}", "green") + 388 | colored(f", Temperature: ", "yellow") + colored(f"{temperature}", "green") + 389 | colored(f", Budget: ", "yellow") + colored(f"{budget}", "green")) 390 | # print(colored(f"Unverified Success Rate: ", "yellow") + colored(f"{np.mean(success_count):.4f}", "green")) 391 | print(colored(f"Success Rate: ", "yellow") + colored(f"{np.mean(true_success_count):.4f}", "green")) 392 | 393 | results = { 394 | "trajectories": trajectories, 395 | "trace_dicts": trace_dicts, 396 | "ratings": ratings, 397 | "true_ratings": true_ratings 398 | } 399 | 400 | 401 | os.makedirs(save_dir, exist_ok=True) 402 | base_name = f"{model_name}_{n_tasks}_0_temp_{str(temperature).replace('.','_')}" 403 | if budget is not None: 404 | if use_subcall_cond: 405 | base_name += f"_subcall_budget_{budget}" 406 | else: 407 | base_name += f"_budget_{budget}" 408 | 409 | # Save results 410 | print(f"Saving results to {save_path}") 411 | with open(save_path, 'w') as f: 412 | json.dump(results, f) 413 | 414 | return results 415 | 416 | 417 | if __name__ == "__main__": 418 | print(colored("Model name: ", "cyan", attrs=["bold"]) + colored(MODEL_NAME, "yellow")) 419 | print(colored("Data path: ", "cyan", attrs=["bold"]) + colored(data_path, "yellow")) 420 | print(colored("Checkpoint: ", "cyan", attrs=["bold"]) + colored(ckpt, "yellow")) 421 | 422 | print(colored("Conditions: ", "cyan", attrs=["bold"]) + 423 | colored(f"use_subcall_cond={USE_SUBCALL_COND}", "yellow")) 424 | print(colored("Inference: ", "cyan", attrs=["bold"]) + 425 | colored(f"parallel_inference={PARALLEL_INFERENCE}, " 426 | f"max_workers={MAX_WORKERS}, " 427 | f"max_tokens={max_tokens}", "yellow")) 428 | print(colored("Generation: ", "cyan", attrs=["bold"]) + 429 | colored(f"temperature={TEMPERATURE}, budget={BUDGET}", "yellow")) 430 | 431 | # Set default save directory if not provided 432 | save_directory = SAVE_DIR if SAVE_DIR is not None else f"results/{MODEL_NAME}/val_apr" 433 | print(colored(f"save_dir: {save_directory}", "cyan")) 434 | 435 | run_inference_experiment( 436 | model_name=MODEL_NAME, 437 | val_data=val_data, 438 | temperature=TEMPERATURE, 439 | budget=BUDGET, 440 | use_subcall_cond=USE_SUBCALL_COND, 441 | parallel_inference=PARALLEL_INFERENCE, 442 | max_workers=MAX_WORKERS, 443 | save_dir=save_directory 444 | ) 445 | -------------------------------------------------------------------------------- /src/eval/eval_sosp.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | import json 3 | import random 4 | import argparse 5 | import tqdm 6 | import numpy as np 7 | from transformers import AutoTokenizer 8 | import sglang as sgl 9 | from src.countdown_utils import * 10 | from src.utils import seed_everything 11 | from termcolor import colored 12 | import re 13 | 14 | parser = argparse.ArgumentParser() 15 | parser.add_argument("--seed", type=int, default=4) 16 | parser.add_argument("--ckpt", type=str, help="path to checkpoint") 17 | parser.add_argument("--tknz", type=str, help="path to tokenizer") 18 | parser.add_argument("--n_samples", type=int, default=1, help="Number of samples for best-of-n evaluation") 19 | parser.add_argument("-d", "--data",type=str, default="data/val.json") 20 | parser.add_argument("--temperature", type=float, default=0.0) 
21 | parser.add_argument("--batch_size", type=int, default=64) 22 | parser.add_argument("--max_tokens", type=int, default=131072) 23 | parser.add_argument("--gens", type=int, default=1) 24 | parser.add_argument("--num_gpus", type=int, default=1) 25 | parser.add_argument("--output_dir", type=str, default="results/") 26 | parser.add_argument("--skip-save-results", action="store_true", help="Skip saving results to file") 27 | parser.add_argument("--num_token_cond", action="store_true", help="Use num token condition") 28 | parser.add_argument("--budget", type=int, default=None, help="Set a fixed budget condition") 29 | 30 | # Calculate mean token count 31 | def count_tokens(text, tokenizer): 32 | cleaned_text = re.sub(r'^.*?Current State:', 'Current State:', text) 33 | # Add offset (so it matches the token count in the train/val data json) 34 | return len(tokenizer.encode(cleaned_text)) + 2 35 | 36 | def eval_ll(args): 37 | """ 38 | Evaluate the model on the data using sglang 39 | """ 40 | with open(args.data, "r") as json_file: 41 | raw_data = json.load(json_file) 42 | 43 | if not args.skip_save_results: 44 | output_dir = os.path.join(args.output_dir, os.path.splitext(os.path.basename(args.data))[0]) 45 | os.makedirs(output_dir, exist_ok=True) 46 | base_name = f"{args.ckpt.split('outputs/')[-1].replace('/','_')}_temp_{str(args.temperature).replace('.','_')}" 47 | if args.budget is not None: 48 | base_name += f"_budget_{args.budget}" 49 | if args.n_samples > 1: 50 | base_name += f"_n_samples_{args.n_samples}" 51 | results_file = os.path.join(output_dir, f"{base_name}.json") 52 | print(f"Results file: {results_file}") 53 | if os.path.exists(results_file): 54 | print(f"Loading existing results from {results_file}") 55 | with open(results_file, 'r') as f: 56 | results = json.load(f) 57 | pred_ratings = results['ratings'] 58 | 59 | # Initialize tokenizer for token counting 60 | tokenizer = AutoTokenizer.from_pretrained(args.tknz) if args.tknz else AutoTokenizer.from_pretrained(args.ckpt) 61 | 62 | token_counts = [count_tokens(pred, tokenizer) for pred in results['trajectories']] 63 | mean_token_count = sum(token_counts) / len(token_counts) 64 | 65 | print(colored("Results Summary:", "cyan", attrs=["bold"])) 66 | print(colored(f"Mean token count: ", "yellow") + colored(f"{mean_token_count:.2f}", "green")) 67 | print(colored(f"Model Accuracy: ", "yellow") + colored(f"{np.mean([r > 0 for r in pred_ratings]):.4f}", "green")) 68 | sys.exit(0) 69 | 70 | # Initialize sglang engine 71 | llm = sgl.Engine( 72 | model_path=args.ckpt, 73 | tokenizer_path=args.tknz if args.tknz else args.ckpt, 74 | allow_auto_truncate=True, 75 | # log_level='warning', 76 | # tp_size=args.num_gpus, 77 | # dp is slightly faster than tp due to the small model size; also sgl gpt2 has bug with tp 78 | dp_size=args.num_gpus, 79 | ) 80 | 81 | # Prepare prompts 82 | tokenizer = AutoTokenizer.from_pretrained(args.tknz) if args.tknz else AutoTokenizer.from_pretrained(args.ckpt) 83 | if args.num_token_cond or args.budget: 84 | def get_token_budget(sample): 85 | if args.budget is not None: 86 | return args.budget 87 | # Define token budget bins (512, 1024, 1536, 2048, 2560, 3072, 3584, 4096) 88 | budget_bins = list(range(512, 4096+1, 512)) 89 | token_count = sample['token_count'] 90 | # Find the appropriate budget bin 91 | budget = 4096 if token_count > 4096 else \ 92 | next((bin_value for bin_value in budget_bins if token_count <= bin_value), 4096) 93 | return budget 94 | 95 | test_prompts = [ 96 | f"{tokenizer.bos_token}Token Budget: 
{get_token_budget(sample)} " 97 | f"Current State: {sample['target']}:{sample['nums']}, Operations: []" 98 | for sample in raw_data 99 | ] 100 | else: 101 | test_prompts = [ 102 | f"{tokenizer.bos_token}Current State: {sample['target']}:{sample['nums']}, Operations: []" 103 | for sample in raw_data 104 | ] 105 | len_nums = [len(sample['nums']) for sample in raw_data] 106 | data = [d for d, l in zip(test_prompts, len_nums) if l == 4] 107 | 108 | # Set up sampling parameters 109 | sampling_params = { 110 | "temperature": args.temperature, 111 | "max_new_tokens": args.max_tokens, 112 | "top_k": 50, 113 | } 114 | 115 | # Process in batches 116 | batch_size = args.batch_size * args.num_gpus 117 | 118 | if args.n_samples > 1: 119 | # Initialize data structures for multiple samples 120 | all_sample_predictions = [] 121 | all_sample_ratings = [] 122 | all_sample_reasons = [] 123 | 124 | # Run inference n_samples times 125 | for sample_idx in range(args.n_samples): 126 | print(colored(f"Running sample {sample_idx+1}/{args.n_samples}", "cyan", attrs=["bold"])) 127 | predictions = [] 128 | 129 | for b in tqdm.trange(0, len(data), batch_size): 130 | batch = data[b:min(b+batch_size, len(data))] 131 | 132 | if args.gens == 1: 133 | # Generate outputs 134 | outputs = llm.generate(batch, sampling_params) 135 | # Combine prompts with generated text 136 | batch_predictions = [prompt + output['text'] for prompt, output in zip(batch, outputs)] 137 | predictions.extend(batch_predictions) 138 | else: 139 | assert args.temperature > 0.0, "Temperature must be greater than 0 for sampling" 140 | all_outputs = [] 141 | all_ratings = [] 142 | 143 | # Generate multiple times for each prompt 144 | for _ in range(args.gens): 145 | outputs = llm.generate(batch, sampling_params) 146 | # Combine prompts with generated text for each generation 147 | output_texts = [prompt + output['text'] for prompt, output in zip(batch, outputs)] 148 | # Get rating for each output 149 | ratings = [metric_fn(ot, mode="sft")[0] for ot in output_texts] 150 | all_ratings.append(ratings) 151 | all_outputs.append(output_texts) 152 | 153 | # Convert to numpy array for easier processing 154 | all_ratings = np.array(all_ratings) 155 | print(all_ratings) 156 | print(f"average rating", np.mean(all_ratings)) 157 | 158 | # Get the best output for each prompt 159 | max_ratings = np.argmax(all_ratings, axis=0) 160 | max_rating_vals = np.max(all_ratings, axis=0) 161 | print(f"max ratings", np.mean(max_rating_vals)) 162 | 163 | # Select the best outputs 164 | batch_predictions = [all_outputs[max_r][i] for i, max_r in enumerate(max_ratings)] 165 | predictions.extend(batch_predictions) 166 | 167 | # Rate outputs for this sample 168 | pred_ratings = [] 169 | pred_reasons = [] 170 | for i, pred in enumerate(predictions): 171 | rating, reason = metric_fn(pred, mode="sft") 172 | pred_ratings.append(rating) 173 | pred_reasons.append(reason) 174 | 175 | # Store results for this sample 176 | all_sample_predictions.append(predictions) 177 | all_sample_ratings.append(pred_ratings) 178 | all_sample_reasons.append(pred_reasons) 179 | 180 | all_sample_ratings_array = np.array(all_sample_ratings) 181 | binary_correctness = all_sample_ratings_array > 0 182 | 183 | # Print results 184 | print(colored("Results Summary:", "cyan", attrs=["bold"])) 185 | print(colored(f"Number of samples: ", "yellow") + colored(f"{args.n_samples}", "green")) 186 | print(colored(f"Individual sample accuracies: ", "yellow") + 187 | colored(f"{[np.mean(binary_correctness[i]) for i in 
range(args.n_samples)]}", "green")) 188 | # TODO: cons@n and pass@n 189 | 190 | # Save results 191 | if not args.skip_save_results: 192 | with open(results_file, "w") as f: 193 | json.dump({ 194 | "n_samples": args.n_samples, 195 | "individual_sample_accuracies": [float(np.mean(binary_correctness[i])) for i in range(args.n_samples)], 196 | "sample_trajectories": all_sample_predictions, 197 | "sample_ratings": all_sample_ratings_array.tolist(), 198 | "sample_reasons": all_sample_reasons, 199 | }, f, indent=4) 200 | 201 | else: 202 | # Original single-sample code 203 | predictions = [] 204 | pred_ratings = [] 205 | pred_reasons = [] 206 | 207 | for b in tqdm.trange(0, len(data), batch_size): 208 | batch = data[b:min(b+batch_size, len(data))] 209 | 210 | if args.gens == 1: 211 | # Generate outputs 212 | outputs = llm.generate(batch, sampling_params) 213 | # Combine prompts with generated text 214 | batch_predictions = [prompt + output['text'] for prompt, output in zip(batch, outputs)] 215 | predictions.extend(batch_predictions) 216 | else: 217 | assert args.temperature > 0.0, "Temperature must be greater than 0 for sampling" 218 | all_outputs = [] 219 | all_ratings = [] 220 | 221 | # Generate multiple times for each prompt 222 | for _ in range(args.gens): 223 | outputs = llm.generate(batch, sampling_params) 224 | # Combine prompts with generated text for each generation 225 | output_texts = [prompt + output['text'] for prompt, output in zip(batch, outputs)] 226 | # Get rating for each output 227 | ratings = [metric_fn(ot, mode="sft")[0] for ot in output_texts] 228 | all_ratings.append(ratings) 229 | all_outputs.append(output_texts) 230 | 231 | # Convert to numpy array for easier processing 232 | all_ratings = np.array(all_ratings) 233 | print(all_ratings) 234 | print(f"average rating", np.mean(all_ratings)) 235 | 236 | # Get the best output for each prompt 237 | max_ratings = np.argmax(all_ratings, axis=0) 238 | max_rating_vals = np.max(all_ratings, axis=0) 239 | print(f"max ratings", np.mean(max_rating_vals)) 240 | 241 | # Select the best outputs 242 | batch_predictions = [all_outputs[max_r][i] for i, max_r in enumerate(max_ratings)] 243 | predictions.extend(batch_predictions) 244 | 245 | # Rate outputs 246 | true_rating = [] 247 | for i, pred in enumerate(predictions): 248 | rating, reason = metric_fn(pred, mode="sft") 249 | tr, _ = metric_fn(f"{raw_data[i]['search_path']}", mode="sft") 250 | pred_ratings.append(rating) 251 | true_rating.append(tr) 252 | pred_reasons.append(reason) 253 | 254 | token_counts = [count_tokens(pred, tokenizer) for pred in predictions] 255 | mean_token_count = sum(token_counts) / len(token_counts) 256 | 257 | # Print results 258 | pred_ratings = np.array(pred_ratings) 259 | print(colored("Results Summary:", "cyan", attrs=["bold"])) 260 | print(colored(f"Mean token count: ", "yellow") + colored(f"{mean_token_count:.2f}", "green")) 261 | print(colored(f"Model Accuracy: ", "yellow") + colored(f"{np.mean([r > 0 for r in pred_ratings]):.4f}", "green")) 262 | print(colored(f"Original Symbolic Solver Accuracy: ", "yellow") + colored(f"{np.mean([r > 0 for r in true_rating]):.4f}", "green")) 263 | 264 | # Save results 265 | if not args.skip_save_results: 266 | with open(results_file, "w") as f: 267 | json.dump({ 268 | "trajectories": predictions, 269 | "ratings": pred_ratings.tolist(), 270 | "reasons": pred_reasons, 271 | "token_counts": token_counts, 272 | "mean_token_count": mean_token_count 273 | }, f, indent=4) 274 | 275 | if __name__ == "__main__": 276 | args = 
parser.parse_args() 277 | seed_everything(args.seed) 278 | 279 | print(colored("Evaluating model: ", "cyan", attrs=["bold"]) + colored(args.ckpt, "yellow")) 280 | print(colored("Data file: ", "cyan", attrs=["bold"]) + colored(args.data, "yellow")) 281 | print(colored("Temperature: ", "cyan", attrs=["bold"]) + colored(args.temperature, "yellow")) 282 | print(colored("Number of GPUs: ", "cyan", attrs=["bold"]) + colored(args.num_gpus, "yellow")) 283 | if args.n_samples > 1: 284 | print(colored("Number of samples (best-of-n): ", "cyan", attrs=["bold"]) + colored(args.n_samples, "yellow")) 285 | if args.num_token_cond: 286 | print(colored("Using token condition", "cyan", attrs=["bold"])) 287 | if args.budget: 288 | print(colored("Using a fixed budget: ", "cyan", attrs=["bold"]) + colored(args.budget, "yellow")) 289 | # eval 290 | eval_ll(args) 291 | -------------------------------------------------------------------------------- /src/eval/parallel_infernce_utils.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | def check_solution(prefix: str, solution: str) -> bool: 4 | """ 5 | Parses the prefix and solution to verify if the solution actually 6 | solves the puzzle of reaching `target` from the given list of numbers. 7 | 8 | :param prefix: A line like: 9 | "Moving to Node #1\nCurrent State: 62:[5, 50, 79, 27], Operations: []" 10 | :param solution: The multiline string describing the step-by-step solution. 11 | :return: True if the solution's final result matches the target and 12 | all stated operations are valid. False otherwise. 13 | """ 14 | # ----------------------------------------------------------------- 15 | # 1. Parse the prefix to extract target and initial numbers 16 | # ----------------------------------------------------------------- 17 | # Example prefix line to parse: 18 | # Current State: 62:[5, 50, 79, 27], Operations: [] 19 | # 20 | # We'll look for something matching: 21 | # Current State: :[], ... 22 | prefix_pattern = r"Current State:\s*(\d+):\[(.*?)\]" 23 | match = re.search(prefix_pattern, prefix) 24 | if not match: 25 | print("ERROR: Could not parse the prefix for target and numbers.") 26 | return False 27 | 28 | target_str, numbers_str = match.groups() 29 | target = int(target_str.strip()) 30 | # Now parse something like "5, 50, 79, 27" into a list of integers 31 | if numbers_str.strip(): 32 | initial_numbers = [int(x.strip()) for x in numbers_str.split(",")] 33 | else: 34 | initial_numbers = [] 35 | 36 | # We'll keep track of our current working list of numbers 37 | current_numbers = initial_numbers 38 | 39 | # ----------------------------------------------------------------- 40 | # 2. Parse solution to extract lines with "Exploring Operation:" 41 | # ----------------------------------------------------------------- 42 | # Example lines: 43 | # Exploring Operation: 79-27=52, Resulting Numbers: [5, 50, 52] 44 | # We want to parse out: operand1=79, operator='-', operand2=27, result=52 45 | # Then parse the new list: [5, 50, 52] 46 | 47 | operation_pattern = r"Exploring Operation:\s*([\d]+)([\+\-\*/])([\d]+)=(\d+),\s*Resulting Numbers:\s*\[(.*?)\]" 48 | 49 | # We'll process the solution line-by-line 50 | # so that we can also capture the final "Goal Reached" line. 
51 | lines = solution.splitlines() 52 | 53 | for line in lines: 54 | line = line.strip() 55 | 56 | # Check for "Exploring Operation" 57 | op_match = re.search(operation_pattern, line) 58 | if op_match: 59 | # Parse out the operation parts 60 | x_str, op, y_str, z_str, new_nums_str = op_match.groups() 61 | x_val = int(x_str) 62 | y_val = int(y_str) 63 | z_val = int(z_str) 64 | 65 | # Parse the new list of numbers from something like "5, 50, 52" 66 | new_numbers = [] 67 | if new_nums_str.strip(): 68 | new_numbers = [int(n.strip()) for n in new_nums_str.split(",")] 69 | 70 | # ------------------------------------------------------------- 71 | # Verify that applying X op Y => Z to current_numbers is valid 72 | # ------------------------------------------------------------- 73 | # 1. X and Y must both be present in current_numbers 74 | # 2. Remove X and Y from current_numbers 75 | # 3. Add Z 76 | # 4. The new list must match exactly the "Resulting Numbers" 77 | # 5. Also verify the arithmetic was correct (if you want to be strict) 78 | # 79 | # NOTE: we do not handle repeating values carefully here if X or Y 80 | # appear multiple times, but you can adapt as needed (e.g. remove once). 81 | 82 | temp_list = current_numbers[:] 83 | 84 | # Try removing X, Y once each 85 | try: 86 | temp_list.remove(x_val) 87 | temp_list.remove(y_val) 88 | except ValueError: 89 | print(f"ERROR: {x_val} or {y_val} not found in current_numbers {current_numbers}.") 90 | return False 91 | 92 | # Check that the stated Z matches the arithmetic operation 93 | # (If you want to skip verifying the math, remove these lines.) 94 | computed_result = None 95 | if op == '+': 96 | computed_result = x_val + y_val 97 | elif op == '-': 98 | computed_result = x_val - y_val 99 | elif op == '*': 100 | computed_result = x_val * y_val 101 | elif op == '/': 102 | # watch for zero division or non-integer division if you want 103 | # to require exact integer results 104 | if y_val == 0: 105 | print("ERROR: Division by zero encountered.") 106 | return False 107 | # For a typical "24 game" style puzzle, we allow float or integer check 108 | computed_result = x_val / y_val 109 | 110 | # Compare the stated z_val to the computed result 111 | # (if it's integer-based arithmetic, we might check int(...) or round) 112 | if computed_result is None: 113 | print("ERROR: Unknown operation encountered.") 114 | return False 115 | 116 | # If we want exact integer match (for e.g. 50/5=10): 117 | # If float is possible, we might do a small epsilon check: 118 | # e.g. if abs(computed_result - z_val) > 1e-9 119 | if computed_result != z_val: 120 | print(f"ERROR: Operation {x_val}{op}{y_val} does not equal {z_val}. Got {computed_result} instead.") 121 | return False 122 | 123 | # Now add the result to temp_list 124 | temp_list.append(z_val) 125 | # Sort if you do not care about order, or keep order if you do 126 | # and compare to new_numbers 127 | # We'll assume exact order is not critical, so let's do a sorted comparison: 128 | if sorted(temp_list) != sorted(new_numbers): 129 | print(f"ERROR: After applying {x_val}{op}{y_val}={z_val} to {current_numbers}, " 130 | f"got {sorted(temp_list)} but solution says {sorted(new_numbers)}.") 131 | return False 132 | 133 | # If we got here, it means the operation is consistent 134 | current_numbers = new_numbers 135 | 136 | # --------------------------------------------------------- 137 | # 3. 
Check for "Goal Reached" line 138 | # --------------------------------------------------------- 139 | # Something like: "62,62 equal: Goal Reached" 140 | # We'll check if the final single number is indeed `target`. 141 | if "Goal Reached" in line: 142 | # For a simple check, if "Goal Reached" is present, 143 | # confirm that current_numbers is [target]. 144 | if len(current_numbers) == 1 and current_numbers[0] == target: 145 | return True 146 | else: 147 | print("ERROR: 'Goal Reached' declared but final numbers don't match the target.") 148 | return False 149 | 150 | # If we never saw "Goal Reached," then it's incomplete 151 | # or didn't declare success. Return False by default 152 | print("ERROR: Did not find 'Goal Reached' in solution.") 153 | return False 154 | 155 | def get_search_result(search_trace): 156 | # Given a search trace, return the result of the search 157 | # If the search is successful, return the result optimal path 158 | # If the search is unsuccessful, return None 159 | if search_trace.count("Goal Reached") >= 2: 160 | # Find all occurrences of "Goal Reached" 161 | goal_indices = [i for i in range(len(search_trace)) if search_trace.startswith("Goal Reached", i)] 162 | # Get the second to last index, this is where we begin generate 163 | # the optimal path 164 | goal_idx = goal_indices[-2] 165 | return search_trace[goal_idx:].strip()[13:] 166 | else: 167 | return None 168 | 169 | def get_subsearch_info(search_trace, budget=None, avg_sub_call=False): 170 | try: 171 | return _get_subsearch_info(search_trace, budget, avg_sub_call) 172 | except Exception as e: 173 | print(f"Error at get_subsearch_info: {e}") 174 | print(search_trace) 175 | raise e 176 | 177 | def _get_subsearch_info(search_trace, budget=None, avg_sub_call=False): 178 | # Given a search trace, return the information of the 179 | # subsearch that it wants to invoke 180 | # sub_search= {"node": "#1,1,2", "target": 39, 'nums':[2, 11], "operations": ["51-49=2", "36-25=11"]} 181 | last_line = search_trace.split("\n")[-1] 182 | assert "" in last_line, "This is not a valid subsearch trace" 183 | 184 | # --- Parse the search trace to get the generated nodes --- 185 | generated_nodes = {} 186 | # First find any "Moving to Node" lines followed by "Current State" lines 187 | lines = search_trace.split("\n") 188 | for i in range(len(lines)-1): 189 | # Moving to Node #1,1 190 | # Current State: 39:[25, 36, 2], Operations: ['51-49=2'] 191 | if "Moving to Node #" in lines[i] and "Current State:" in lines[i+1]: 192 | # Extract node id from first line like: 193 | # Moving to Node #1,1 194 | node_id = lines[i].split("Moving to Node #")[1].strip() 195 | 196 | # Extract state from second line like: 197 | # Current State: 39:[25, 36, 2], Operations: ['51-49=2'] 198 | state_line = lines[i+1] 199 | state_part = state_line.split("Current State:")[1].split("],")[0].strip() 200 | operations_part = state_line.split("Operations:")[1].strip() 201 | 202 | # Parse state like "39:[25, 36, 2]" 203 | target = int(state_part.split(":")[0]) 204 | # nums = eval(state_part.split(":")[1].strip()) 205 | nums = eval(state_part.split(":")[1].strip() + "]") 206 | operations = eval(operations_part) 207 | 208 | # Parse operations list 209 | 210 | generated_nodes[node_id] = { 211 | "node": f"#{node_id}", 212 | "target": target, 213 | "nums": nums, 214 | "operations": operations 215 | } 216 | for line in search_trace.split("\n"): 217 | if "Generated Node" in line: 218 | # Extract node id and info from line like: 219 | # Generated Node #1,1,2: 39:[2, 11] 
Operation: 36-25=11 220 | node_id = line.split(":")[0].split("#")[1] 221 | if node_id in generated_nodes: 222 | continue 223 | rest = line.split(":", 1)[1].strip() 224 | state = rest.split("Operation:")[0].strip() 225 | operation = rest.split("Operation:")[1].strip() 226 | 227 | # Parse state like "39:[2, 11]" into target and nums 228 | target = int(state.split(":")[0]) 229 | nums = eval(state.split(":")[1].strip()) 230 | 231 | parent_node_id = ",".join(node_id.split(",")[:-1]) 232 | parent_node = generated_nodes[parent_node_id] 233 | new_operations = parent_node["operations"] + [operation] 234 | 235 | generated_nodes[node_id] = { 236 | "node": f"#{node_id}", 237 | "target": target, 238 | "nums": nums, 239 | "operations": new_operations 240 | } 241 | # then we construct the sub_searches 242 | sub_search_nodes = [] 243 | # Split on and take the last chunk 244 | last_chunk = search_trace.split("\n")[-1] 245 | # Split that chunk on and take first part 246 | sub_search_section = last_chunk.split("\n")[0] 247 | 248 | for line in sub_search_section.split("\n"): 249 | if " Moving to Node #1,1,2 252 | node_id = line.split("Moving to Node #")[1].strip() 253 | sub_search_nodes.append(generated_nodes[node_id]) 254 | 255 | if avg_sub_call: 256 | assert budget is not None, "Budget must be provided if avg_sub_call is True" 257 | budget = budget / len(sub_search_nodes) 258 | budget = int((budget - 1) // 512 + 1) * 512 259 | 260 | def construct_sub_search_prefix(node): 261 | # exmaple 262 | # "Moving to Node #1,1,0\nCurrent State: 39:[36, 50], Operations: ['51-49=2', '25*2=50'] 263 | prefix = f"Moving to Node {node['node']}\nCurrent State: {node['target']}:[{', '.join(map(str, node['nums']))}], Operations: {node['operations']}" 264 | if budget is not None: 265 | prefix = f"Token Budget: {budget} {prefix}" 266 | return prefix 267 | sub_search_prefix_list = [construct_sub_search_prefix(node) for node in sub_search_nodes] 268 | return sub_search_prefix_list, sub_search_nodes 269 | 270 | def get_main_trace_after_sub_search(main_trace, sub_search_nodes, sub_search_result_list): 271 | last_line = main_trace.split("\n")[-1] 272 | assert "" in last_line, "This is not a valid subsearch trace" 273 | sub_searches = [] 274 | # Split on and take the last chunk 275 | last_chunk = main_trace.split("\n")[-1] 276 | # Split that chunk on and take first part 277 | sub_search_section = last_chunk.split("\n")[0] 278 | main_trace_after_sub_search = "\n".join(main_trace.split("\n")[:-1]) 279 | assert main_trace_after_sub_search in main_trace 280 | main_trace_after_sub_search += "\n" 281 | for i, (this_node, this_result) in enumerate(zip(sub_search_nodes, sub_search_result_list)): 282 | if this_result is None: 283 | # 284 | main_trace_after_sub_search += f"\n" 285 | else: 286 | # \nMoving to Node #1,2,0\nCurrent State: 39:[51, 12], Operations: ['49-36=13', '25-13=12']\nExploring Operation: 51-12=39, Resulting Numbers: [39]\n39,39 equal: Goal Reached\n 287 | main_trace_after_sub_search += f"\n" 288 | main_trace_after_sub_search += this_result + "\n" 289 | main_trace_after_sub_search += "\n" 290 | return main_trace_after_sub_search -------------------------------------------------------------------------------- /src/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import numpy as np 4 | import torch 5 | 6 | 7 | def seed_everything(seed: int): 8 | random.seed(seed) 9 | os.environ['PYTHONHASHSEED'] = str(seed) 10 | np.random.seed(seed) 11 | 
torch.manual_seed(seed)
12 | torch.cuda.manual_seed(seed)
13 | torch.backends.cudnn.deterministic = True
14 | torch.backends.cudnn.benchmark = True
--------------------------------------------------------------------------------
/supervised-jax/README.md:
--------------------------------------------------------------------------------
1 | # Example Supervised Training Instructions
2 | > Forked from [JAX_llama](https://github.com/Sea-Snell/JAX_llama)
3 | 
4 | First do the following:
5 | 
6 | To orchestrate the TPU pod:
7 | ```
8 | pip install git+https://github.com/Sea-Snell/tpu_pod_launcher.git@main
9 | ```
10 | 
11 | There is a config in `training_run.sh` and a launcher defined in `launcher.py`.
12 | 
13 | The launcher basically uses a tool to run the installation and training scripts on all hosts in the pod at once.
14 | 
15 | You will probably want to edit the launcher's `available_tpus` to reflect the TPUs you have access to.
16 | 
17 | You will also need a Google Cloud token in a JSON file somewhere, which has write permissions to buckets. You should change the path on `line 51` of the launcher to point to this file.
18 | 
19 | You may also want to modify the ssh info and copy path in `lines 56-69` of the launcher.
20 | 
21 | Make sure you set the API keys correctly at the top of the `training_run.sh` script, and also edit the bucket paths as needed.
22 | 
23 | The data should be stored as a JSON list of strings:
24 | 
25 | ```
26 | [
27 |     "seq1",
28 |     "seq2",
29 |     ...
30 | ]
31 | ```
32 | 
33 | Use the `to_dataset.ipynb` notebook to convert the dataset to the correct format and upload it to GCS.
34 | 
35 | 
36 | To install all dependencies on the TPU hosts, run:
37 | 
38 | ```
39 | python launcher.py setup --project=jiayi-128-eu-2
40 | ```
41 | 
42 | 
43 | You only need to do this once for each TPU pod.
44 | 
45 | where `your_tpu_name` (the value passed to `--project`) refers to the name of the TPU in the `available_tpus` list in the launcher.
46 | 
47 | To launch the training run:
48 | 
49 | ```
50 | conda activate GPML
51 | cd /home/jiayipan/code/25SP/LM-Parallel/JAX-Train
52 | python launcher.py launch training_sos-split-digit-v2.sh --project=jiayi-128-eu
53 | python launcher.py launch training_hsp_2x.sh --project=jiayi-128-eu
54 | python launcher.py launch training_run.sh --project=jiayi-128-eu
55 | python launcher.py launch scripts/training_sos_llama_10ep_v2.sh --project=jiayi-128-eu
56 | python launcher.py launch scripts/hsp-v3.sh --project=jiayi-128-eu
57 | python launcher.py launch scripts/hs-v3.sh --project=jiayi-128-eu-2
58 | python launcher.py launch scripts/sos-v3.sh --project=jiayi-64-eu
59 | ```
60 | This will: 1) copy the latest version of `llama3_train` to the TPUs; 2) stop anything running on the TPUs; 3) run the training script on the TPUs.
61 | 
62 | To print the output of the training run, you can run:
63 | 
64 | ```
65 | python launcher.py check --project=your_tpu_name
66 | python launcher.py check --project=jiayi-128-eu
67 | ```
68 | 
69 | To terminate an ongoing training run, you can run:
70 | 
71 | ```
72 | python launcher.py stop --project=your_tpu_name
73 | ```
74 | 
75 | The 3 mesh dimensions in the config currently correspond to (replica, fsdp, tensor). We can also add a sequence-parallel dimension without much difficulty if needed.
76 | 
77 | ### Test Model
78 | Quickly test the model with Flax/JAX:
79 | https://colab.research.google.com/drive/1X7ElvcwrAk5nt_dkAsZUZo6IOp4nsW3L?usp=sharing
80 | 
81 | ### Export To PyTorch
82 | You can use this script to easily export the model to PyTorch (Hugging Face compatible):
83 | 84 | https://colab.research.google.com/drive/1XD3bJ1PQuHKwSO2cB0NCc6BkyJoptF19?usp=sharing 85 | 86 | -------------------------------------------------------------------------------- /supervised-jax/launcher.py: -------------------------------------------------------------------------------- 1 | import os 2 | from tpu_pod_launcher import TPUPodClient, TPUPodProject, create_cli 3 | 4 | SETUP_SCRIPT = """\ 5 | cd ~/ 6 | # install basics 7 | apt-get update -q \ 8 | && DEBIAN_FRONTEND=noninteractive apt-get install -y \ 9 | apt-utils \ 10 | curl \ 11 | git \ 12 | vim \ 13 | wget \ 14 | tmux \ 15 | redis-server \ 16 | && apt-get clean \ 17 | && rm -rf /var/lib/apt/lists/* 18 | 19 | # install miniforge 20 | rm -rf ~/Miniconda3-py39_4.12.0-Linux-x86_64.sh 21 | wget https://repo.anaconda.com/miniconda/Miniconda3-py310_23.1.0-1-Linux-x86_64.sh -P ~/ 22 | bash ~/Miniconda3-py310_23.1.0-1-Linux-x86_64.sh -b 23 | 24 | # install dependencies 25 | source ~/miniconda3/bin/activate 26 | conda init bash 27 | conda create -n llama3_train python=3.10 -y 28 | conda activate llama3_train 29 | cd ~/llama3_train 30 | python -m pip install -e . 31 | pip install -U "jax[tpu]==0.4.38" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html 32 | python -m pip install tyro flax scalax transformers gcsfs optax wandb 33 | pip install -U transformers==4.47.1 flax==0.10.2 34 | 35 | # clean up 36 | cd ~/ 37 | rm -rf ~/Miniconda3-py310_23.1.0-1-Linux-x86_64.sh 38 | """.strip() 39 | 40 | CHECK_DEVICES = r""" 41 | source ~/miniconda3/bin/activate llama3_train 42 | python -c "import jax; print(jax.devices())" 43 | """.strip() 44 | 45 | def check_devices(project: TPUPodProject, verbose: bool=False): 46 | project.ssh(CHECK_DEVICES, verbose=verbose) 47 | 48 | def setup(project: TPUPodProject, verbose: bool=False): 49 | project.copy(verbose=verbose) 50 | project.ssh(SETUP_SCRIPT, verbose=verbose) 51 | project.ssh('mkdir ~/.config/', verbose=verbose) 52 | project.ssh('mkdir ~/.config/gcloud/', verbose=verbose) 53 | project.scp('/home/jiayipan/code/25SP/TPU-Train/civic-boulder-204700-3052e43e8c80.json', '~/.config/gcloud/', verbose=verbose) 54 | 55 | def debug(project: TPUPodProject, verbose: bool=False): 56 | import IPython; IPython.embed() 57 | 58 | def create_project(tpu_name: str, zone: str) -> TPUPodProject: 59 | return TPUPodProject( 60 | client=TPUPodClient( 61 | tpu_project='civic-boulder-204700', 62 | tpu_zone=zone, 63 | user='jiayipan', 64 | key_path='/home/jiayipan/.ssh/id_rsa', 65 | ), 66 | tpu_name=tpu_name, 67 | copy_dirs=[('/home/jiayipan/code/25SP/LM-Parallel/JAX-Train/llama3_train/', '~/llama3_train/')], 68 | working_dir='~/llama3_train/', 69 | copy_excludes=['.git', '__pycache__', '*.pkl', '*.json', '*.jsonl', '*.ipynb'], 70 | kill_commands=['pkill -9 python'], 71 | ) 72 | 73 | if __name__ == "__main__": 74 | launch_config_path = os.path.join(os.path.dirname(__file__), 'launch_config.json') 75 | 76 | available_tpus = [ 77 | ('jiayi-64-eu', 'europe-west4-a'), # v3-64 78 | ('jiayi-128-eu', 'europe-west4-a'), # v3-128 79 | ('jiayi-128-eu-2', 'europe-west4-a'), # v3-128 80 | ] 81 | 82 | tpu_projects = {name: create_project(name, zone) for name, zone in available_tpus} 83 | 84 | create_cli( 85 | projects=tpu_projects, 86 | setup=setup, 87 | custom_commands={'debug': debug, 'check_devices': check_devices}, 88 | launch_config_path=launch_config_path, 89 | ) 90 | -------------------------------------------------------------------------------- /supervised-jax/llama3_train/.gitignore: 
-------------------------------------------------------------------------------- 1 | .ipynb_checkpoints 2 | __pycache__/ 3 | .DS_Store 4 | wandb/ 5 | *.pyc 6 | dump.rdb 7 | *.egg-info/ -------------------------------------------------------------------------------- /supervised-jax/llama3_train/README.md: -------------------------------------------------------------------------------- 1 | # LLaMA-3 Train 2 | 3 | ## Installation 4 | 5 | ``` 6 | python -m pip install -e . 7 | ``` 8 | 9 | Also see [here](https://github.com/google/jax#instructions) to install whatever version of jax you need for your accelerator. -------------------------------------------------------------------------------- /supervised-jax/llama3_train/gpt2_train_script.py: -------------------------------------------------------------------------------- 1 | from typing import List, Dict, Any, Optional 2 | import json 3 | from functools import partial 4 | import os 5 | import collections 6 | import tempfile 7 | 8 | import tyro 9 | import jax 10 | from scalax.sharding import MeshShardingHelper, TreePathShardingRule 11 | from flax.training.train_state import TrainState 12 | import numpy as np 13 | import itertools 14 | import pickle as pkl 15 | from transformers import AutoTokenizer 16 | from tqdm.auto import tqdm 17 | import optax 18 | from jax.sharding import PartitionSpec as PS 19 | 20 | from llama_train.utils import ( 21 | load_checkpoint, open_with_bucket, save_checkpoint, 22 | get_float_dtype_by_name, get_weight_decay_mask, 23 | cross_entropy_loss_and_accuracy, global_norm, 24 | average_metrics, WandbLogger, delete_with_bucket, 25 | jax_distributed_initalize, jax_distributed_barrier 26 | ) 27 | from llama_train.gpt2 import ( 28 | GPT2Config, FlaxGPT2LMHeadModel, 29 | ) 30 | from llama_train.optimizer import load_adamw_optimizer, load_palm_optimizer 31 | 32 | def process_pretrain_example( 33 | seq: str, 34 | max_length: int, 35 | tokenizer: AutoTokenizer, 36 | ): 37 | tokenization = tokenizer( 38 | [tokenizer.bos_token+seq+tokenizer.eos_token], 39 | padding='max_length', 40 | truncation=True, 41 | max_length=max_length+1, 42 | return_tensors='np', 43 | ) 44 | 45 | input_ids = tokenization.input_ids[:, :-1] 46 | target_ids = tokenization.input_ids[:, 1:] 47 | attention_mask = tokenization.attention_mask[:, :-1] 48 | position_ids = np.maximum(np.cumsum(attention_mask, axis=-1) - 1, 0) 49 | loss_masks = attention_mask 50 | 51 | batch_items = dict( 52 | input_ids=input_ids, 53 | attention_mask=attention_mask, 54 | position_ids=position_ids, 55 | target_ids=target_ids, 56 | loss_masks=loss_masks, 57 | ) 58 | return batch_items 59 | 60 | 61 | def checkpointer( 62 | path: str, 63 | train_state: Any, 64 | config: Any, 65 | gather_fns: Any, 66 | metadata: Any=None, 67 | save_optimizer_state: bool=False, 68 | save_float_dtype: str='bf16', 69 | active=True, 70 | ): 71 | if not path.startswith('gcs://'): 72 | os.makedirs(path, exist_ok=True) 73 | if save_optimizer_state: 74 | checkpoint_state = train_state 75 | if not active: 76 | checkpoint_path = '/dev/null' 77 | else: 78 | checkpoint_path = os.path.join(path, 'train_state.msgpack') 79 | checkpoint_gather_fns = gather_fns 80 | else: 81 | checkpoint_state = train_state.params 82 | if not active: 83 | checkpoint_path = '/dev/null' 84 | else: 85 | checkpoint_path = os.path.join(path, 'params.msgpack') 86 | checkpoint_gather_fns = gather_fns.params 87 | metadata_path = os.path.join(path, 'metadata.pkl') 88 | config_path = os.path.join(path, 'config.json') 89 | 90 | save_checkpoint( 91 | 
checkpoint_state, 92 | checkpoint_path, 93 | gather_fns=checkpoint_gather_fns, 94 | float_dtype=save_float_dtype, 95 | ) 96 | if active: 97 | with open_with_bucket(metadata_path, 'wb') as f: 98 | pkl.dump(metadata, f) 99 | with open_with_bucket(config_path, 'w') as f: 100 | json.dump(config, f) 101 | 102 | def main( 103 | load_model: str, 104 | train_data_path: str, 105 | eval_data_path: str, 106 | output_dir: Optional[str], 107 | sharding: str, 108 | num_train_steps: int, 109 | max_length: int, 110 | bsize: int, 111 | log_freq: int, 112 | num_eval_steps: int, 113 | save_model_freq: int, 114 | wandb_project: str, 115 | param_dtype: str='fp32', 116 | activation_dtype: str='fp32', 117 | optim_config: str='adamw:{}', 118 | logger_config: str='{}', 119 | checkpointer_config: str='{}', 120 | model_config_override: str='{}', 121 | inputs_tokenizer_override: str='{}', 122 | outputs_tokenizer_override: str='{}', 123 | jax_distributed_initalize_config: str='{}', 124 | save_initial_checkpoint: bool=False, 125 | log_initial_step: bool=True, 126 | max_checkpoints: Optional[int]=None, 127 | eval_bsize: Optional[int]=None, 128 | physical_axis_splitting: bool=False, 129 | shuffle_train_data: bool=True, 130 | hf_repo_id: str='LM-Parallel/sample', 131 | ): 132 | args_dict = dict(locals()) 133 | print(args_dict) 134 | sharding: List[int] = list(map(lambda x: int(x.strip()), sharding.split(','))) 135 | if eval_bsize is None: 136 | eval_bsize = bsize 137 | 138 | param_dtype = get_float_dtype_by_name(param_dtype) 139 | activation_dtype = get_float_dtype_by_name(activation_dtype) 140 | 141 | logger_config: Dict[str, Any] = json.loads(logger_config) 142 | checkpointer_config: Dict[str, Any] = json.loads(checkpointer_config) 143 | model_config_override: Dict[str, Any] = json.loads(model_config_override) 144 | inputs_tokenizer_override: Dict[str, Any] = json.loads(inputs_tokenizer_override) 145 | outputs_tokenizer_override: Dict[str, Any] = json.loads(outputs_tokenizer_override) 146 | jax_distributed_initalize_config: Dict[str, Any] = json.loads(jax_distributed_initalize_config) 147 | 148 | jax_distributed_initalize(**jax_distributed_initalize_config) 149 | jax_distributed_barrier() 150 | 151 | if optim_config.startswith('adamw:'): 152 | optim_config = json.loads(optim_config[len('adamw:'):]) 153 | optim_config['weight_decay_mask'] = get_weight_decay_mask(optim_config.pop('weight_decay_exclusions', tuple())) 154 | grad_accum_steps = optim_config.pop('grad_accum_steps', 1) 155 | optimizer, optimizer_info = load_adamw_optimizer(**optim_config) 156 | elif optim_config.startswith('palm:'): 157 | optim_config = json.loads(optim_config[len('palm:'):]) 158 | optim_config['weight_decay_mask'] = get_weight_decay_mask(optim_config.pop('weight_decay_exclusions', tuple())) 159 | grad_accum_steps = optim_config.pop('grad_accum_steps', 1) 160 | optimizer, optimizer_info = load_palm_optimizer(**optim_config) 161 | else: 162 | raise ValueError(f'Unknown optimizer config: {optim_config}') 163 | if grad_accum_steps > 1: 164 | optimizer = optax.MultiSteps( 165 | optimizer, 166 | grad_accum_steps, 167 | ) 168 | 169 | mesh = MeshShardingHelper(sharding, ['replica', 'fsdp', 'tensor'], mesh_axis_splitting=physical_axis_splitting) # Create a 3D mesh with data, fsdp, and model parallelism axes 170 | with mesh.get_context(): 171 | print('mesh:', mesh.mesh) 172 | print('loading model ...') 173 | 174 | if load_model.startswith('paths:'): 175 | model_paths = json.loads(load_model[len('paths:'):]) 176 | if not ('remove_dict_prefix' in 
model_paths): 177 | model_paths['remove_dict_prefix'] = None 178 | else: 179 | raise ValueError(f'Unknown model info type: {load_model}') 180 | 181 | config_is_temp = False 182 | if 'config' in model_paths and model_paths['config'].startswith('gcs://'): 183 | temp_file = tempfile.NamedTemporaryFile('wb', delete=False) 184 | with open_with_bucket(model_paths['config'], 'rb') as f: 185 | temp_file.write(f.read()) 186 | temp_file.close() 187 | model_paths['config'] = temp_file.name 188 | config_is_temp = True 189 | 190 | if 'config' in model_paths: 191 | config = GPT2Config.from_pretrained(model_paths['config'], **model_config_override) 192 | # elif 'default_config_name' in model_paths: 193 | # config = GPT2Config(**GPT2_STANDARD_CONFIGS[model_paths['default_config_name']], **model_config_override) 194 | else: 195 | config = GPT2Config(**model_config_override) 196 | 197 | if config_is_temp: 198 | os.remove(model_paths['config']) 199 | 200 | model = FlaxGPT2LMHeadModel(config, dtype=activation_dtype, _do_init=False, param_dtype=param_dtype, input_shape=(bsize, 1024)) 201 | # TODO: embedding dim is hardcoded to 1024, it's 202 | 203 | tokenizer_is_temp = False 204 | if model_paths['tokenizer'].startswith('gcs://'): 205 | temp_file = tempfile.NamedTemporaryFile('wb', delete=False) 206 | with open_with_bucket(model_paths['tokenizer'], 'rb') as f: 207 | temp_file.write(f.read()) 208 | temp_file.close() 209 | model_paths['tokenizer'] = temp_file.name 210 | tokenizer_is_temp = True 211 | 212 | tokenizer_kwargs = dict( 213 | truncation_side='right', 214 | padding_side='right', 215 | ) 216 | tokenizer_kwargs.update(outputs_tokenizer_override) 217 | tokenizer = AutoTokenizer.from_pretrained(model_paths['tokenizer'], **tokenizer_kwargs) 218 | tokenizer.pad_token = tokenizer.convert_ids_to_tokens(config.pad_token_id) 219 | 220 | if tokenizer_is_temp: 221 | os.remove(model_paths['tokenizer']) 222 | 223 | sharding_rules = TreePathShardingRule(*config.get_partition_rules()) 224 | 225 | @partial( 226 | mesh.sjit, 227 | in_shardings=(sharding_rules,), 228 | out_shardings=sharding_rules, 229 | ) 230 | def create_train_state_from_params(params): 231 | return TrainState.create(params=params, tx=optimizer, apply_fn=None) 232 | 233 | @partial( 234 | mesh.sjit, 235 | in_shardings=(PS(),), 236 | out_shardings=sharding_rules, 237 | ) 238 | def init_fn(rng): 239 | params = model.init_weights(rng, (bsize, 1024)) 240 | return create_train_state_from_params(params) 241 | 242 | train_state_shape = jax.eval_shape(lambda: init_fn(jax.random.PRNGKey(0))) 243 | shard_train_state_fns, gather_train_state_fns = mesh.make_shard_and_gather_fns(train_state_shape, sharding_rules) 244 | 245 | if 'params' in model_paths: 246 | train_state = create_train_state_from_params(load_checkpoint( 247 | model_paths['params'], 248 | shard_fns=shard_train_state_fns.params, 249 | remove_dict_prefix=model_paths['remove_dict_prefix'], 250 | convert_to_dtypes=jax.tree_util.tree_map(lambda x: x.dtype, train_state_shape.params), 251 | )) 252 | elif 'train_state' in model_paths: 253 | train_state = load_checkpoint( 254 | model_paths['train_state'], 255 | shard_fns=shard_train_state_fns, 256 | remove_dict_prefix=model_paths['remove_dict_prefix'], 257 | convert_to_dtypes=jax.tree_util.tree_map(lambda x: x.dtype, train_state_shape), 258 | ) 259 | else: 260 | train_state = init_fn(jax.random.PRNGKey(0)) 261 | 262 | print('model loaded.') 263 | 264 | @partial( 265 | mesh.sjit, 266 | in_shardings=(sharding_rules, PS(),PS()), 267 | 
out_shardings=(sharding_rules, PS()), 268 | args_sharding_constraint=(sharding_rules, None, PS(('replica', 'fsdp'))), 269 | donate_argnums=(0,), 270 | ) 271 | def train_step(train_state, rng, batch): 272 | def loss_and_accuracy(params): 273 | logits = model( 274 | input_ids=batch['input_ids'], 275 | attention_mask=batch['attention_mask'], 276 | position_ids=batch['position_ids'], 277 | params=params, 278 | dropout_rng=rng, 279 | train=True, 280 | ).logits 281 | return cross_entropy_loss_and_accuracy( 282 | logits, batch['target_ids'], batch['loss_masks'], 283 | ) 284 | grad_fn = jax.value_and_grad(loss_and_accuracy, has_aux=True) 285 | (loss, accuracy), grads = grad_fn(train_state.params) 286 | train_state = train_state.apply_gradients(grads=grads) 287 | metrics = dict( 288 | loss=loss, 289 | accuracy=accuracy, 290 | learning_rate=optimizer_info['learning_rate_schedule'](train_state.step), 291 | gradient_norm=global_norm(grads), 292 | param_norm=global_norm(train_state.params), 293 | ) 294 | return train_state, metrics 295 | 296 | @partial( 297 | mesh.sjit, 298 | in_shardings=(sharding_rules, PS()), 299 | out_shardings=PS(), 300 | args_sharding_constraint=(sharding_rules, PS(('replica', 'fsdp'))), 301 | ) 302 | def eval_step(params, batch): 303 | logits = model( 304 | input_ids=batch['input_ids'], 305 | attention_mask=batch['attention_mask'], 306 | position_ids=batch['position_ids'], 307 | params=params, 308 | train=False, 309 | ).logits 310 | loss, accuracy = cross_entropy_loss_and_accuracy( 311 | logits, batch['target_ids'], batch['loss_masks'], 312 | ) 313 | metrics = dict( 314 | eval_loss=loss, 315 | eval_accuracy=accuracy, 316 | ) 317 | return metrics 318 | 319 | print('loading data ...') 320 | if train_data_path.endswith('jsonl'): 321 | train_examples = [] 322 | with open_with_bucket(train_data_path, 'r') as f: 323 | for line in f: 324 | train_examples.append(json.loads(line.strip())) 325 | else: 326 | with open_with_bucket(train_data_path, 'r') as f: 327 | train_examples = json.load(f) 328 | if eval_data_path.endswith('jsonl'): 329 | eval_examples = [] 330 | with open_with_bucket(eval_data_path, 'r') as f: 331 | for line in f: 332 | eval_examples.append(json.loads(line.strip())) 333 | else: 334 | with open_with_bucket(eval_data_path, 'r') as f: 335 | eval_examples = json.load(f) 336 | print('done.') 337 | 338 | def data_iterable(data_items, rng, bsize, shuffle=True, loop=True): 339 | while True: 340 | with jax.default_device(jax.devices('cpu')[0]): 341 | idxs = [] 342 | for _ in range((bsize + (len(data_items) - 1)) // len(data_items)): 343 | if shuffle: 344 | rng, subrng = jax.random.split(rng) 345 | curr_idxs = jax.random.permutation(subrng, np.arange(len(data_items))) 346 | idxs.extend(curr_idxs.tolist()) 347 | else: 348 | curr_idxs = np.arange(len(data_items)) 349 | idxs.extend(curr_idxs.tolist()) 350 | idxs = np.asarray(idxs) 351 | for batch_idx in range(len(idxs) // bsize): 352 | batch_idxs = idxs[batch_idx*bsize:(batch_idx+1)*bsize] 353 | batch_examples = [data_items[idx] for idx in batch_idxs] 354 | processed_batch_examples = [] 355 | for example in batch_examples: 356 | processed_batch_examples.append(process_pretrain_example( 357 | example, 358 | max_length, 359 | tokenizer, 360 | )) 361 | batch = dict() 362 | for key in processed_batch_examples[0]: 363 | batch[key] = np.concatenate([example[key] for example in processed_batch_examples], axis=0) 364 | yield batch 365 | if not loop: 366 | break 367 | 368 | if 'enable' not in logger_config: 369 | logger_config['enable'] = 
(jax.process_index() == 0) 370 | if 'config_to_log' in logger_config: 371 | logger_config['config_to_log'].update(args_dict) 372 | else: 373 | logger_config['config_to_log'] = args_dict 374 | logger = WandbLogger( 375 | wandb_project, 376 | output_dir=output_dir, 377 | **logger_config, 378 | ) 379 | 380 | checkpoint_queue = collections.deque() 381 | 382 | def _save_checkpoint( 383 | train_state, 384 | step, 385 | ): 386 | old_step = None 387 | if (max_checkpoints is not None) and (len(checkpoint_queue) >= max_checkpoints): 388 | old_step = checkpoint_queue.popleft() 389 | if logger.can_save(): 390 | print(f'saving checkpoint at step {step} ...') 391 | # delete old checkpoint if max checkpoints is reached 392 | if old_step is not None: 393 | old_path = os.path.join(logger.output_dir, 'checkpoints', f'step_{old_step}') 394 | delete_with_bucket(old_path, recursive=True) 395 | 396 | metadata = dict( 397 | step=step, 398 | args_dict=args_dict, 399 | ) 400 | 401 | checkpointer( 402 | path=os.path.join(logger.output_dir, 'checkpoints', f'step_{step}'), 403 | train_state=train_state, 404 | config=config.to_dict(), 405 | gather_fns=gather_train_state_fns, 406 | metadata=metadata, 407 | active=logger.can_save(), 408 | **checkpointer_config, 409 | ) 410 | 411 | checkpoint_queue.append(step) 412 | 413 | if logger.can_save(): 414 | print('saved.') 415 | 416 | if save_initial_checkpoint: 417 | _save_checkpoint(train_state, 0) 418 | 419 | rng = jax.random.PRNGKey(0) 420 | 421 | rng, eval_iterable_rng = jax.random.split(rng) 422 | rng, subrng = jax.random.split(rng) 423 | train_iterable = data_iterable(train_examples, subrng, bsize, shuffle=shuffle_train_data, loop=True) 424 | for step, train_batch in tqdm(itertools.islice(enumerate(train_iterable), num_train_steps), total=num_train_steps): 425 | rng, subrng = jax.random.split(rng) 426 | train_state, metrics = train_step(train_state, subrng, train_batch) 427 | if log_freq > 0 and ((step+1) % log_freq == 0 or (log_initial_step and step == 0)): 428 | if num_eval_steps > 0: 429 | eval_metric_list = [] 430 | eval_iterable = data_iterable(eval_examples, eval_iterable_rng, eval_bsize, shuffle=True, loop=False) 431 | for eval_batch in itertools.islice(eval_iterable, num_eval_steps): 432 | eval_metric_list.append(eval_step(train_state.params, eval_batch)) 433 | metrics.update(average_metrics(jax.device_get(eval_metric_list))) 434 | log_metrics = {"step": step+1} 435 | log_metrics.update(metrics) 436 | log_metrics = jax.device_get(log_metrics) 437 | logger.log(log_metrics) 438 | print(log_metrics) 439 | 440 | if save_model_freq > 0 and (step+1) % save_model_freq == 0: 441 | _save_checkpoint(train_state, step+1) 442 | 443 | if save_model_freq > 0 and (num_train_steps not in checkpoint_queue): 444 | _save_checkpoint(train_state, num_train_steps) 445 | 446 | jax_distributed_barrier() 447 | logger.finish() 448 | jax_distributed_barrier() 449 | 450 | # Only have the first worker push to hub to avoid conflicts 451 | if jax.process_index() == 0 and logger.can_save(): 452 | import shutil 453 | print("First worker copying final checkpoint to hub...") 454 | 455 | # Create temp directory for checkpoint 456 | temp_dir = tempfile.mkdtemp() 457 | final_ckpt_path = os.path.join(logger.output_dir, 'checkpoints', f'step_{num_train_steps}') 458 | 459 | # Copy checkpoint files to temp dir 460 | if final_ckpt_path.startswith('gcs://'): 461 | with open_with_bucket(os.path.join(final_ckpt_path, 'params.msgpack'), 'rb') as f: 462 | with open(os.path.join(temp_dir, 'params.msgpack'), 
'wb') as f_out:
463 | f_out.write(f.read())
464 | with open_with_bucket(os.path.join(final_ckpt_path, 'config.json'), 'rb') as f:
465 | with open(os.path.join(temp_dir, 'config.json'), 'wb') as f_out:
466 | f_out.write(f.read())
467 | else:
468 | shutil.copy2(os.path.join(final_ckpt_path, 'params.msgpack'), temp_dir)
469 | shutil.copy2(os.path.join(final_ckpt_path, 'config.json'), temp_dir)
470 | 
471 | # Push to hub
472 | try:
473 | from huggingface_hub import HfApi
474 | api = HfApi()
475 | repo_type = "model"
476 | repo_name = hf_repo_id
477 | 
478 | api.create_repo(repo_name, repo_type=repo_type, private=False, exist_ok=True)
479 | api.upload_folder(
480 | folder_path=temp_dir,
481 | repo_id=repo_name,
482 | repo_type=repo_type
483 | )
484 | print("Successfully pushed checkpoint to hub")
485 | except Exception as e:
486 | print(f"Error pushing to hub: {e}")
487 | finally:
488 | # Cleanup temp directory
489 | shutil.rmtree(temp_dir)
490 | if __name__ == "__main__":
491 | tyro.cli(main)
492 | 
--------------------------------------------------------------------------------
/supervised-jax/llama3_train/llama_train/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Parallel-Reasoning/APR/8936e8db46bf938242bf5e0a6ebe79ff48ba267a/supervised-jax/llama3_train/llama_train/__init__.py
--------------------------------------------------------------------------------
/supervised-jax/llama3_train/llama_train/optimizer.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from typing import Tuple, NamedTuple, Dict, Any, Optional, Callable
3 | import optax
4 | import jax.numpy as jnp
5 | import jax
6 | 
7 | # Adapted from: https://github.com/young-geng/EasyLM/blob/main/EasyLM/optimizers.py
8 | 
9 | def warmup_linear_decay_schedule(
10 | init_value: float,
11 | peak_value: float,
12 | warmup_steps: int,
13 | decay_steps: int,
14 | end_value: float = 0.0,
15 | ) -> optax.Schedule:
16 | """Linear warmup followed by linear decay.
17 | 
18 | Args:
19 | init_value: Initial value for the scalar to be annealed.
20 | peak_value: Peak value for scalar to be annealed at end of warmup.
21 | warmup_steps: Positive integer, the length of the linear warmup.
22 | decay_steps: Positive integer, the total length of the schedule. Note that
23 | this includes the warmup time, so the number of steps during which linear
24 | decay is applied is `decay_steps - warmup_steps`.
25 | end_value: End value of the scalar to be annealed.
26 | Returns:
27 | schedule: A function that maps step counts to values.
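    Example (illustrative values, not the defaults used elsewhere in this repo):

        schedule = warmup_linear_decay_schedule(
            init_value=0.0, peak_value=3e-4, warmup_steps=100,
            decay_steps=1000, end_value=3e-5)
        schedule(0)     # -> 0.0 (start of warmup)
        schedule(100)   # -> 3e-4 (peak, end of warmup)
        schedule(1000)  # -> 3e-5 (end of decay; stays at end_value afterwards)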
28 | """ 29 | schedules = [ 30 | optax.linear_schedule( 31 | init_value=init_value, 32 | end_value=peak_value, 33 | transition_steps=warmup_steps), 34 | optax.linear_schedule( 35 | init_value=peak_value, 36 | end_value=end_value, 37 | transition_steps=decay_steps - warmup_steps)] 38 | return optax.join_schedules(schedules, [warmup_steps]) 39 | 40 | schedule_by_name = dict( 41 | cos=optax.warmup_cosine_decay_schedule, 42 | linear=warmup_linear_decay_schedule, 43 | ) 44 | 45 | def load_adamw_optimizer( 46 | init_lr: float=0.0, 47 | end_lr: float=3e-5, 48 | lr: float=3e-4, 49 | lr_warmup_steps: int=3000, 50 | lr_decay_steps: int=300000, 51 | b1: float=0.9, 52 | b2: float=0.95, 53 | clip_gradient: float=1.0, 54 | weight_decay: float=0.1, 55 | bf16_momentum: bool=False, 56 | multiply_by_parameter_scale: bool=False, 57 | weight_decay_mask: Optional[Callable]=None, 58 | schedule: str='cos', 59 | ) -> Tuple[optax.GradientTransformation, Dict[str, Any]]: 60 | learning_rate_schedule = schedule_by_name[schedule]( 61 | init_value=init_lr, 62 | peak_value=lr, 63 | warmup_steps=lr_warmup_steps, 64 | decay_steps=lr_decay_steps, 65 | end_value=end_lr, 66 | ) 67 | 68 | optimizer_info = dict( 69 | learning_rate_schedule=learning_rate_schedule, 70 | ) 71 | 72 | if multiply_by_parameter_scale: 73 | optimizer = optax.chain( 74 | optax.clip_by_global_norm(clip_gradient), 75 | optax.adafactor( 76 | learning_rate=learning_rate_schedule, 77 | multiply_by_parameter_scale=True, 78 | momentum=b1, 79 | decay_rate=b2, 80 | factored=False, 81 | clipping_threshold=None, 82 | dtype_momentum=jnp.bfloat16 if bf16_momentum else jnp.float32, 83 | ), 84 | optax_add_scheduled_weight_decay( 85 | lambda step: -learning_rate_schedule(step) * weight_decay, 86 | weight_decay_mask, 87 | ), 88 | ) 89 | else: 90 | optimizer = optax.chain( 91 | optax.clip_by_global_norm(clip_gradient), 92 | optax.adamw( 93 | learning_rate=learning_rate_schedule, 94 | weight_decay=weight_decay, 95 | b1=b1, 96 | b2=b2, 97 | mask=weight_decay_mask, 98 | mu_dtype=jnp.bfloat16 if bf16_momentum else jnp.float32, 99 | ), 100 | ) 101 | 102 | return optimizer, optimizer_info 103 | 104 | def load_palm_optimizer( 105 | lr: float=0.01, 106 | lr_warmup_steps: int=10000, 107 | b1: float=0.9, 108 | b2: float=0.99, 109 | clip_gradient: float=1.0, 110 | weight_decay: float=1e-4, 111 | bf16_momentum: bool=False, 112 | weight_decay_mask: Optional[Callable]=None, 113 | ) -> Tuple[optax.GradientTransformation, Dict[str, Any]]: 114 | def learning_rate_schedule(step): 115 | multiplier = lr / 0.01 116 | return multiplier / jnp.sqrt(jnp.maximum(step, lr_warmup_steps)) 117 | 118 | def weight_decay_schedule(step): 119 | multiplier = weight_decay / 1e-4 120 | return -multiplier * jnp.square(learning_rate_schedule(step)) 121 | 122 | optimizer_info = dict( 123 | learning_rate_schedule=learning_rate_schedule, 124 | weight_decay_schedule=weight_decay_schedule, 125 | ) 126 | 127 | optimizer = optax.chain( 128 | optax.clip_by_global_norm(clip_gradient), 129 | optax.adafactor( 130 | learning_rate=learning_rate_schedule, 131 | multiply_by_parameter_scale=True, 132 | momentum=b1, 133 | decay_rate=b2, 134 | factored=False, 135 | clipping_threshold=None, 136 | dtype_momentum=jnp.bfloat16 if bf16_momentum else jnp.float32, 137 | ), 138 | optax_add_scheduled_weight_decay( 139 | weight_decay_schedule, weight_decay_mask, 140 | ), 141 | ) 142 | 143 | return optimizer, optimizer_info 144 | 145 | class OptaxScheduledWeightDecayState(NamedTuple): 146 | count: jnp.ndarray 147 | 148 | def 
optax_add_scheduled_weight_decay(schedule_fn, mask=None):
149 | """ Apply weight decay with schedule. """
150 | 
151 | def init_fn(params):
152 | del params
153 | return OptaxScheduledWeightDecayState(count=jnp.zeros([], jnp.int32))
154 | 
155 | def update_fn(updates, state, params):
156 | if params is None:
157 | raise ValueError('Params cannot be None for weight decay!')
158 | 
159 | weight_decay = schedule_fn(state.count)
160 | updates = jax.tree_util.tree_map(
161 | lambda g, p: g + weight_decay * p, updates, params
162 | )
163 | return updates, OptaxScheduledWeightDecayState(
164 | count=optax.safe_int32_increment(state.count)
165 | )
166 | 
167 | if mask is not None:
168 | return optax.masked(optax.GradientTransformation(init_fn, update_fn), mask)
169 | return optax.GradientTransformation(init_fn, update_fn)
170 | 
--------------------------------------------------------------------------------
/supervised-jax/llama3_train/llama_train/serve.py:
--------------------------------------------------------------------------------
1 | from typing import Generator, Any, Optional
2 | import redis
3 | import time
4 | import pickle as pkl
5 | import multiprocessing as mp
6 | from functools import partial
7 | import six
8 | import json
9 | from collections import OrderedDict
10 | # config for server
11 | 
12 | class Config:
13 | redis_host = 'localhost'
14 | redis_port = 6379
15 | redis_db = 0
16 | client_refresh_delay = 0.01
17 | self_indicator = '__self__'
18 | init_message = '__init_message__'
19 | sse_channel_prefix = "__sse_channel__"
20 | sse_exit_type = "__EXIT__"
21 | 
22 | """
23 | =====
24 | Below is the code for running a class on a separate process.
25 | Each call to a method on the class is put on a queue and executed in order.
26 | You want to do this when serving models, so that requests are processed 1 at a time.
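Typical usage (a sketch; `MyModel`, `checkpoint_path` and `generate` are illustrative
names, and a redis server must be reachable at Config.redis_host:Config.redis_port):

    ServedModel = serve_class(MyModel)
    model = ServedModel(checkpoint_path)  # spawns a worker process that owns the real MyModel
    out = model.generate(prompt)          # proxied through redis + a queue, one request at a time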
27 | ===== 28 | """ 29 | 30 | def serve_class(cls): 31 | cache_cls = pkl.dumps(cls) 32 | 33 | class WrappedModel: 34 | def __init__(self, *args, **kwargs): 35 | self.r = redis.Redis(host=Config.redis_host, port=Config.redis_port, db=Config.redis_db) 36 | self.Q = initalize_server(self, super().__getattribute__('r'), cache_cls, args, kwargs) 37 | 38 | def __getattribute__(self, name): 39 | return partial(build_method(name, super().__getattribute__('r'), super().__getattribute__('Q')), self) 40 | 41 | def __call__(self, *args, **kwargs): 42 | return build_method('__call__', super().__getattribute__('r'), super().__getattribute__('Q'))(self, *args, **kwargs) 43 | 44 | def __getitem__(self, key): 45 | return build_method('__getitem__', super().__getattribute__('r'), super().__getattribute__('Q'))(self, key) 46 | 47 | def __setitem__(self, key, value): 48 | return build_method('__setitem__', super().__getattribute__('r'), super().__getattribute__('Q'))(self, key, value) 49 | 50 | def __contains__(self, key): 51 | return build_method('__contains__', super().__getattribute__('r'), super().__getattribute__('Q'))(self, key) 52 | 53 | def __len__(self): 54 | return build_method('__len__', super().__getattribute__('r'), super().__getattribute__('Q'))(self) 55 | 56 | return WrappedModel 57 | 58 | def build_method(method, r, Q): 59 | def call_method(self, *args, **kwargs): 60 | request_id = int(r.incr('request_id_counter')) 61 | Q.put((request_id, method, args, kwargs,)) 62 | while not r.exists(f'result_{request_id}'): 63 | time.sleep(Config.client_refresh_delay) 64 | result = pkl.loads(r.get(f'result_{request_id}')) 65 | r.delete(f'result_{request_id}') 66 | if result == Config.self_indicator: 67 | return self 68 | return result 69 | return call_method 70 | 71 | def server_process(Q, cls_pkl, args, kwargs): 72 | r = redis.Redis(host=Config.redis_host, port=Config.redis_port, db=Config.redis_db) 73 | model = pkl.loads(cls_pkl)(*args, **kwargs) 74 | while True: 75 | try: 76 | request_id, method, args, kwargs = Q.get() 77 | if method == Config.init_message: 78 | r.set(f'result_{request_id}', pkl.dumps(method)) 79 | continue 80 | result = getattr(model, method)(*args, **kwargs) 81 | if isinstance(result, Generator): 82 | result = tuple(result) 83 | if result == model: 84 | result = Config.self_indicator 85 | r.set(f'result_{request_id}', pkl.dumps(result)) 86 | except EOFError: 87 | return 88 | except Exception as e: 89 | raise Exception 90 | 91 | def initalize_server(self, r, cls_pkl, args, kwargs): 92 | Q = mp.Manager().Queue() 93 | p = mp.Process(target=server_process, args=(Q, cls_pkl, args, kwargs)) 94 | p.start() 95 | build_method(Config.init_message, r, Q)(self) 96 | return Q 97 | 98 | -------------------------------------------------------------------------------- /supervised-jax/llama3_train/llama_train/splash.py: -------------------------------------------------------------------------------- 1 | # adapted from: https://github.com/stanford-crfm/levanter/blob/main/src/levanter/models/attention.py 2 | from typing import Optional 3 | import jax.numpy as jnp 4 | import warnings 5 | import jax 6 | from jax.experimental.shard_map import shard_map 7 | from jax.sharding import PartitionSpec as PS 8 | from scalax.sharding import MeshShardingHelper 9 | import functools 10 | 11 | # CF https://github.com/google/maxtext/blob/db31dd4b0b686bca4cd7cf940917ec372faa183a/MaxText/layers/attentions.py#L179 12 | def _tpu_splash_attention( 13 | query: jnp.ndarray, 14 | key: jnp.ndarray, 15 | value: jnp.ndarray, 16 | 
attention_mask: jnp.ndarray, 17 | dropout: float = 0.0, 18 | *, 19 | attention_dtype: Optional[jnp.dtype] = None, 20 | block_size: Optional[int] = None, 21 | ) -> Optional[jnp.ndarray]: 22 | from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_kernel, splash_attention_mask 23 | 24 | # Splash attention requires BHSD format 25 | # We need to reshape the input to match this format 26 | if dropout != 0.0: 27 | raise NotImplementedError("Splash attention does not support dropout") 28 | 29 | if attention_dtype is not None and attention_dtype != jnp.float32: 30 | warnings.warn("Splash attention only supports float32. Switching to float32.") 31 | 32 | attention_dtype = jnp.float32 33 | 34 | B, Sq, Hq, D = query.shape 35 | Bk, Sk, Hk, Dk = key.shape 36 | 37 | # pre-divide q_ by sqrt(d) to match the reference implementation 38 | query = query / jnp.sqrt(D) 39 | 40 | # number 41 | if Sk % 128 != 0: 42 | raise NotImplementedError(f"Splash attention requires KPos to be a multiple of 128, got {Sk}") 43 | 44 | if block_size is not None and block_size % 128 != 0: 45 | raise NotImplementedError(f"Splash attention requires block_size to be a multiple of 128, got {block_size}") 46 | 47 | # TODO: must Dk == Dv? 48 | if key.shape != value.shape: 49 | raise ValueError("k and v must have the same axes") 50 | 51 | # TODO: this isn't really necessary on TPU? 52 | if B != Bk: 53 | raise ValueError(f"Batch axes must be the same for q, k, and v: {B} != {Bk}") 54 | 55 | if D != Dk: 56 | raise ValueError(f"Embedding axes must be the same for q, k, and v: {D} != {Dk}") 57 | 58 | # MaxText uses a block size of 512 59 | block_size = block_size or 512 60 | 61 | # copied from MaxText 62 | @functools.partial( 63 | shard_map, 64 | mesh=MeshShardingHelper.get_global_mesh(), 65 | in_specs=( 66 | PS(("dp", "fsdp"), "mp", None, None), 67 | PS(("dp", "fsdp"), "mp", None, None), 68 | PS(("dp", "fsdp"), "mp", None, None), 69 | PS(("dp", "fsdp"), None), 70 | ), 71 | out_specs=PS(("dp", "fsdp"), "mp", None, None), 72 | check_rep=False, 73 | ) 74 | def wrap_flash_attention(q, k, v, attention_mask): 75 | block_sizes = splash_attention_kernel.BlockSizes( 76 | block_q=min(block_size, Sq), 77 | block_kv_compute=min(block_size, Sk), 78 | block_kv=min(block_size, Sk), 79 | block_q_dkv=min(block_size, Sq), 80 | block_kv_dkv=min(block_size, Sk), 81 | block_kv_dkv_compute=min(block_size, Sq), 82 | block_q_dq=min(block_size, Sq), 83 | block_kv_dq=min(block_size, Sq), 84 | ) 85 | 86 | segment_ids = splash_attention_kernel.SegmentIds( 87 | q=attention_mask, 88 | kv=attention_mask, 89 | ) 90 | 91 | kernel_mask = splash_attention_mask.MultiHeadMask( 92 | [splash_attention_mask.CausalMask((Sq, Sq)) for _ in range(Hq)], 93 | ) 94 | 95 | # copied from MaxText 96 | splash_kernel = splash_attention_kernel.make_splash_mha( 97 | mask=kernel_mask, head_shards=1, q_seq_shards=1, block_sizes=block_sizes 98 | ) 99 | 100 | q = q.astype(attention_dtype) 101 | k = k.astype(attention_dtype) 102 | v = v.astype(attention_dtype) 103 | return jax.vmap(splash_kernel)(q, k, v, segment_ids=segment_ids) 104 | 105 | query = query.transpose(0, 2, 1, 3) 106 | key = key.transpose(0, 2, 1, 3) 107 | value = value.transpose(0, 2, 1, 3) 108 | attn_output = wrap_flash_attention(query, key, value, attention_mask) 109 | attn_output = attn_output.transpose(0, 2, 1, 3) 110 | return attn_output 111 | -------------------------------------------------------------------------------- /supervised-jax/llama3_train/llama_train/utils.py: 
-------------------------------------------------------------------------------- 1 | from typing import Optional, Any, List 2 | import os 3 | import gcsfs 4 | import jax 5 | from flax.serialization import from_bytes, to_state_dict, from_state_dict, to_bytes 6 | from flax.traverse_util import flatten_dict, unflatten_dict, empty_node 7 | import msgpack 8 | import jax.numpy as jnp 9 | from functools import partial 10 | import re 11 | from optax import softmax_cross_entropy_with_integer_labels 12 | import wandb 13 | import uuid 14 | from socket import gethostname 15 | import tempfile 16 | from jax import lax 17 | 18 | GCLOUD_TOKEN_PATH = os.environ.get('GCLOUD_TOKEN_PATH', None) 19 | GCLOUD_PROJECT = os.environ.get('GCLOUD_PROJECT', None) 20 | 21 | def open_with_bucket( 22 | path: Any, 23 | mode: str="rb", 24 | gcloud_project: Optional[str]=None, 25 | gcloud_token: Optional[Any]=None, 26 | **kwargs, 27 | ): 28 | # backup to env vars if None 29 | if gcloud_project is None: 30 | gcloud_project = GCLOUD_PROJECT 31 | if gcloud_token is None: 32 | gcloud_token = GCLOUD_TOKEN_PATH 33 | # load from google cloud storage if starts with "gcs://" 34 | if path.startswith('gcs://'): 35 | f = gcsfs.GCSFileSystem(project=gcloud_project, token=gcloud_token).open(path[len('gcs://'):], mode=mode, **kwargs) 36 | else: 37 | f = open(path, mode=mode, **kwargs) 38 | return f 39 | 40 | def delete_with_bucket( 41 | path: str, 42 | recursive: bool=True, 43 | gcloud_project: Optional[str]=None, 44 | gcloud_token: Optional[Any]=None, 45 | ) -> None: 46 | # backup to env vars if None 47 | if gcloud_project is None: 48 | gcloud_project = GCLOUD_PROJECT 49 | if gcloud_token is None: 50 | gcloud_token = GCLOUD_TOKEN_PATH 51 | # delete from google cloud storage if starts with "gcs://" 52 | if path.startswith('gcs://'): 53 | path = path[len('gcs://'):] 54 | gcsfs.GCSFileSystem(project=gcloud_project, token=gcloud_token).rm(path, recursive=recursive) 55 | else: 56 | os.system(f"rm -{'r' if recursive else ''}f {path}") 57 | 58 | def get_gradient_checkpoint_policy(name): 59 | return { 60 | 'everything_saveable': jax.checkpoint_policies.everything_saveable, 61 | 'nothing_saveable': jax.checkpoint_policies.nothing_saveable, 62 | 'checkpoint_dots': jax.checkpoint_policies.checkpoint_dots, 63 | 'checkpoint_dots_with_no_batch_dims': jax.checkpoint_policies.checkpoint_dots_with_no_batch_dims, 64 | }[name] 65 | 66 | def get_float_dtype_by_name(dtype): 67 | return { 68 | 'bf16': jnp.bfloat16, 69 | 'bfloat16': jnp.bfloat16, 70 | 'fp16': jnp.float16, 71 | 'float16': jnp.float16, 72 | 'fp32': jnp.float32, 73 | 'float32': jnp.float32, 74 | 'fp64': jnp.float64, 75 | 'float64': jnp.float64, 76 | }[dtype] 77 | 78 | 79 | def float_tensor_to_dtype(tensor, dtype): 80 | if dtype is None or dtype == '': 81 | return tensor 82 | if isinstance(dtype, str): 83 | dtype = get_float_dtype_by_name(dtype) 84 | float_dtypes = (jnp.bfloat16, jnp.float16, jnp.float32, jnp.float64) 85 | if getattr(tensor, 'dtype', None) in float_dtypes: 86 | tensor = tensor.astype(dtype) 87 | return tensor 88 | 89 | 90 | def float_to_dtype(tree, dtype): 91 | return jax.tree_util.tree_map( 92 | partial(float_tensor_to_dtype, dtype=dtype), tree 93 | ) 94 | 95 | def load_checkpoint(path, target=None, shard_fns=None, remove_dict_prefix=None, convert_to_dtypes=None): 96 | if shard_fns is not None: 97 | shard_fns = flatten_dict( 98 | to_state_dict(shard_fns) 99 | ) 100 | if convert_to_dtypes is not None: 101 | convert_to_dtypes = flatten_dict( 102 | to_state_dict(convert_to_dtypes) 
103 | ) 104 | if remove_dict_prefix is not None: 105 | remove_dict_prefix = tuple(remove_dict_prefix) 106 | flattend_train_state = {} 107 | with open_with_bucket(path, 'rb') as fin: 108 | # 83886080 bytes = 80 MB, which is 16 blocks on GCS 109 | unpacker = msgpack.Unpacker(fin, read_size=83886080, max_buffer_size=0) 110 | for key, value in unpacker: 111 | key = tuple(key) 112 | if remove_dict_prefix is not None: 113 | if key[:len(remove_dict_prefix)] == remove_dict_prefix: 114 | key = key[len(remove_dict_prefix):] 115 | else: 116 | continue 117 | 118 | tensor = from_bytes(None, value) 119 | if convert_to_dtypes is not None: 120 | tensor = float_tensor_to_dtype(tensor, convert_to_dtypes[key]) 121 | if shard_fns is not None: 122 | tensor = shard_fns[key](tensor) 123 | flattend_train_state[key] = tensor 124 | 125 | if target is not None: 126 | flattened_target = flatten_dict( 127 | to_state_dict(target), keep_empty_nodes=True 128 | ) 129 | for key, value in flattened_target.items(): 130 | if key not in flattend_train_state and value == empty_node: 131 | flattend_train_state[key] = value 132 | 133 | train_state = unflatten_dict(flattend_train_state) 134 | if target is None: 135 | return train_state 136 | 137 | return from_state_dict(target, train_state) 138 | 139 | def save_checkpoint(train_state, path, gather_fns=None, float_dtype=None): 140 | train_state = to_state_dict(train_state) 141 | packer = msgpack.Packer() 142 | flattend_train_state = flatten_dict(train_state) 143 | if gather_fns is not None: 144 | gather_fns = flatten_dict(to_state_dict(gather_fns)) 145 | 146 | with open_with_bucket(path, "wb") as fout: 147 | for key, value in flattend_train_state.items(): 148 | if gather_fns is not None: 149 | value = gather_fns[key](value) 150 | value = float_tensor_to_dtype(value, float_dtype) 151 | fout.write(packer.pack((key, to_bytes(value)))) 152 | 153 | def tree_path_to_string(path, sep=None): 154 | keys = [] 155 | for key in path: 156 | if isinstance(key, jax.tree_util.SequenceKey): 157 | keys.append(str(key.idx)) 158 | elif isinstance(key, jax.tree_util.DictKey): 159 | keys.append(str(key.key)) 160 | elif isinstance(key, jax.tree_util.GetAttrKey): 161 | keys.append(str(key.name)) 162 | elif isinstance(key, jax.tree_util.FlattenedIndexKey): 163 | keys.append(str(key.key)) 164 | else: 165 | keys.append(str(key)) 166 | if sep is None: 167 | return tuple(keys) 168 | return sep.join(keys) 169 | 170 | def named_tree_map(f, tree, *rest, is_leaf=None, sep=None): 171 | """ An extended version of jax.tree_util.tree_map, where the mapped function 172 | f takes both the name (path) and the tree leaf as input. 173 | """ 174 | return jax.tree_util.tree_map_with_path( 175 | lambda path, x, *r: f(tree_path_to_string(path, sep=sep), x, *r), 176 | tree, *rest, 177 | is_leaf=is_leaf 178 | ) 179 | 180 | def get_weight_decay_mask(exclusions): 181 | """ Return a weight decay mask function that computes the pytree masks 182 | according to the given exclusion rules. 
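    For example (illustrative rules), get_weight_decay_mask(['bias', 'ln_.*'])
    returns a function that maps a params pytree to a pytree of booleans: any
    leaf whose '/'-joined path matches one of the regexes (e.g.
    'transformer/h/0/ln_1/bias') is marked False (excluded from weight decay),
    and everything else is marked True.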
183 | """ 184 | def decay(name, _): 185 | for rule in exclusions: 186 | if re.search(rule, name) is not None: 187 | return False 188 | return True 189 | 190 | def weight_decay_mask(params): 191 | return named_tree_map(decay, params, sep='/') 192 | 193 | return weight_decay_mask 194 | 195 | def cross_entropy_loss_and_accuracy(logits, tokens, valid=None): 196 | if valid is None: 197 | valid = jnp.ones(tokens.shape[:2]) 198 | valid = valid.astype(jnp.float32) 199 | logits = logits.astype(jnp.float32) # for numerical stability 200 | token_loss = softmax_cross_entropy_with_integer_labels(logits, tokens) 201 | loss = jnp.mean(token_loss, where=valid > 0.0) 202 | accuracy = jnp.mean(jnp.argmax(logits, axis=-1) == tokens, where=valid > 0.0) 203 | return loss, accuracy 204 | 205 | def global_norm(tree): 206 | """ Return the global L2 norm of a pytree. """ 207 | squared = jax.tree_util.tree_map(lambda x: jnp.sum(jnp.square(x)), tree) 208 | flattened, _ = jax.flatten_util.ravel_pytree(squared) 209 | return jnp.sqrt(jnp.sum(flattened)) 210 | 211 | def average_metrics(metrics): 212 | return jax.tree_map( 213 | lambda *args: jnp.mean(jnp.stack(args)), 214 | *metrics 215 | ) 216 | 217 | def flatten_config_dict(config, prefix=None): 218 | output = {} 219 | for key, val in config.items(): 220 | if isinstance(val, dict): 221 | output.update(flatten_config_dict(val, prefix=key)) 222 | else: 223 | if prefix is not None: 224 | output["{}.{}".format(prefix, key)] = val 225 | else: 226 | output[key] = val 227 | return output 228 | 229 | # logger adapted from mlxu: https://github.com/young-geng/mlxu/blob/main/mlxu/logging.py 230 | 231 | class WandbLogger: 232 | def __init__( 233 | self, 234 | project: str, 235 | output_dir: Optional[str]=None, 236 | config_to_log: Optional[Any]=None, 237 | enable: bool=True, 238 | online: bool=False, 239 | prefix: Optional[str]=None, 240 | experiment_id: Optional[str]=None, 241 | wandb_dir: Optional[str]=None, 242 | notes: Optional[str]=None, 243 | entity: Optional[str]=None, 244 | prefix_to_id: bool=False, 245 | ): 246 | self.enable = enable 247 | self.notes = notes 248 | self.entity = entity 249 | self.project = project 250 | self.online = online 251 | self.experiment_id = experiment_id 252 | 253 | if self.experiment_id is None: 254 | self.experiment_id = f'{uuid.uuid4().hex}-{uuid.uuid1().hex}' 255 | if prefix is not None: 256 | if prefix_to_id: 257 | self.experiment_id = f"{prefix}--{self.experiment_id}" 258 | else: 259 | self.project = f"{prefix}--{self.project}" 260 | 261 | self.wandb_dir = wandb_dir 262 | self.output_dir = output_dir 263 | if self.enable: 264 | if self.output_dir is not None: 265 | self.output_dir = os.path.join(self.output_dir, self.experiment_id) 266 | if not self.output_dir.startswith('gcs://'): 267 | os.makedirs(self.output_dir, exist_ok=True) 268 | if self.wandb_dir is None: 269 | if (self.output_dir is not None) and (not self.output_dir.startswith('gcs://')): 270 | self.wandb_dir = self.output_dir 271 | else: 272 | assert not self.wandb_dir.startswith('gcs://') 273 | self.wandb_dir = os.path.join(self.wandb_dir, self.experiment_id) 274 | os.makedirs(self.wandb_dir, exist_ok=True) 275 | 276 | if config_to_log is not None: 277 | self.config_to_log = flatten_config_dict(config_to_log) 278 | if "hostname" not in self.config_to_log: 279 | self.config_to_log["hostname"] = gethostname() 280 | if "experiment_id" not in self.config_to_log: 281 | self.config_to_log["experiment_id"] = self.experiment_id 282 | if "logger_output_dir" not in self.config_to_log: 
283 | self.config_to_log["logger_output_dir"] = self.output_dir 284 | if "wandb_dir" not in self.config_to_log: 285 | self.config_to_log["wandb_dir"] = self.wandb_dir 286 | else: 287 | self.config_to_log = None 288 | 289 | if self.enable: 290 | self.run = wandb.init( 291 | reinit=True, 292 | config=self.config_to_log, 293 | project=self.project, 294 | dir=self.wandb_dir, 295 | id=self.experiment_id, 296 | notes=self.notes, 297 | entity=self.entity, 298 | settings=wandb.Settings( 299 | start_method="thread", 300 | _disable_stats=True, 301 | ), 302 | mode="online" if self.online else "offline", 303 | ) 304 | else: 305 | self.run = None 306 | 307 | def log(self, *args, **kwargs): 308 | if self.enable: 309 | self.run.log(*args, **kwargs) 310 | 311 | def finish(self): 312 | if self.enable: 313 | wandb.finish() 314 | 315 | def can_save(self) -> bool: 316 | return self.enable and (self.output_dir is not None) 317 | 318 | def jax_distributed_initalize( 319 | initialize_jax_distributed: bool=False, 320 | local_device_ids: Optional[List[int]]=None, 321 | coordinator_address: Optional[str]=None, 322 | num_processes: Optional[int]=None, 323 | process_id: Optional[int]=None, 324 | ): 325 | if initialize_jax_distributed: 326 | if local_device_ids is not None: 327 | local_device_ids = [int(x) for x in local_device_ids.split(',')] 328 | else: 329 | local_device_ids = None 330 | 331 | jax.distributed.initialize( 332 | coordinator_address=coordinator_address, 333 | num_processes=num_processes, 334 | process_id=process_id, 335 | local_device_ids=local_device_ids, 336 | ) 337 | 338 | def jax_distributed_barrier(): 339 | # Dummy function that all processes run 340 | def computation(x): 341 | result = x * x 342 | return result 343 | 344 | @partial(jax.pmap, axis_name='i') 345 | def sync_barrier(x): 346 | # Perform a trivial collective operation, acting as a barrier 347 | c = lax.psum(x, axis_name='i') 348 | return computation(x) + computation(c) 349 | 350 | # Dummy input 351 | x = jnp.ones((jax.local_device_count(),)) 352 | 353 | # Run the barrier + computation 354 | results = sync_barrier(x) 355 | 356 | jax.block_until_ready(results) 357 | -------------------------------------------------------------------------------- /supervised-jax/llama3_train/llama_train_script.py: -------------------------------------------------------------------------------- 1 | from typing import List, Dict, Any, Optional 2 | import json 3 | from functools import partial 4 | import os 5 | import collections 6 | import tempfile 7 | 8 | import tyro 9 | import jax 10 | from scalax.sharding import MeshShardingHelper, TreePathShardingRule 11 | from flax.training.train_state import TrainState 12 | import numpy as np 13 | import itertools 14 | import pickle as pkl 15 | from transformers import AutoTokenizer 16 | from tqdm.auto import tqdm 17 | import optax 18 | from jax.sharding import PartitionSpec as PS 19 | 20 | from llama_train.utils import ( 21 | load_checkpoint, open_with_bucket, save_checkpoint, 22 | get_float_dtype_by_name, get_weight_decay_mask, 23 | cross_entropy_loss_and_accuracy, global_norm, 24 | average_metrics, WandbLogger, delete_with_bucket, 25 | jax_distributed_initalize, jax_distributed_barrier 26 | ) 27 | from llama_train.llama3 import ( 28 | LLaMAConfig, LLAMA_STANDARD_CONFIGS, 29 | FlaxLLaMAForCausalLM, download_openllama_easylm, 30 | ) 31 | from llama_train.optimizer import load_adamw_optimizer, load_palm_optimizer 32 | 33 | def process_pretrain_example( 34 | seq: str, 35 | max_length: int, 36 | tokenizer: 
AutoTokenizer, 37 | ): 38 | tokenization = tokenizer( 39 | [tokenizer.bos_token+seq+tokenizer.eos_token], 40 | padding='max_length', 41 | truncation=True, 42 | max_length=max_length+1, 43 | return_tensors='np', 44 | ) 45 | 46 | input_ids = tokenization.input_ids[:, :-1] 47 | target_ids = tokenization.input_ids[:, 1:] 48 | attention_mask = tokenization.attention_mask[:, :-1] 49 | position_ids = np.maximum(np.cumsum(attention_mask, axis=-1) - 1, 0) 50 | loss_masks = attention_mask 51 | 52 | batch_items = dict( 53 | input_ids=input_ids, 54 | attention_mask=attention_mask, 55 | position_ids=position_ids, 56 | target_ids=target_ids, 57 | loss_masks=loss_masks, 58 | ) 59 | return batch_items 60 | 61 | 62 | def checkpointer( 63 | path: str, 64 | train_state: Any, 65 | config: Any, 66 | gather_fns: Any, 67 | metadata: Any=None, 68 | save_optimizer_state: bool=False, 69 | save_float_dtype: str='bf16', 70 | active=True, 71 | ): 72 | if not path.startswith('gcs://'): 73 | os.makedirs(path, exist_ok=True) 74 | if save_optimizer_state: 75 | checkpoint_state = train_state 76 | if not active: 77 | checkpoint_path = '/dev/null' 78 | else: 79 | checkpoint_path = os.path.join(path, 'train_state.msgpack') 80 | checkpoint_gather_fns = gather_fns 81 | else: 82 | checkpoint_state = train_state.params 83 | if not active: 84 | checkpoint_path = '/dev/null' 85 | else: 86 | checkpoint_path = os.path.join(path, 'params.msgpack') 87 | checkpoint_gather_fns = gather_fns.params 88 | metadata_path = os.path.join(path, 'metadata.pkl') 89 | config_path = os.path.join(path, 'config.json') 90 | 91 | save_checkpoint( 92 | checkpoint_state, 93 | checkpoint_path, 94 | gather_fns=checkpoint_gather_fns, 95 | float_dtype=save_float_dtype, 96 | ) 97 | if active: 98 | with open_with_bucket(metadata_path, 'wb') as f: 99 | pkl.dump(metadata, f) 100 | with open_with_bucket(config_path, 'w') as f: 101 | json.dump(config, f) 102 | 103 | def main( 104 | load_model: str, 105 | train_data_path: str, 106 | eval_data_path: str, 107 | output_dir: Optional[str], 108 | sharding: str, 109 | num_train_steps: int, 110 | max_length: int, 111 | bsize: int, 112 | log_freq: int, 113 | num_eval_steps: int, 114 | save_model_freq: int, 115 | wandb_project: str, 116 | param_dtype: str='fp32', 117 | activation_dtype: str='fp32', 118 | optim_config: str='adamw:{}', 119 | logger_config: str='{}', 120 | checkpointer_config: str='{}', 121 | model_config_override: str='{}', 122 | inputs_tokenizer_override: str='{}', 123 | outputs_tokenizer_override: str='{}', 124 | jax_distributed_initalize_config: str='{}', 125 | save_initial_checkpoint: bool=False, 126 | log_initial_step: bool=True, 127 | max_checkpoints: Optional[int]=None, 128 | eval_bsize: Optional[int]=None, 129 | physical_axis_splitting: bool=False, 130 | shuffle_train_data: bool=True, 131 | hf_repo_id: str='LM-Parallel/sample', 132 | ): 133 | args_dict = dict(locals()) 134 | print(args_dict) 135 | sharding: List[int] = list(map(lambda x: int(x.strip()), sharding.split(','))) 136 | if eval_bsize is None: 137 | eval_bsize = bsize 138 | 139 | param_dtype = get_float_dtype_by_name(param_dtype) 140 | activation_dtype = get_float_dtype_by_name(activation_dtype) 141 | 142 | logger_config: Dict[str, Any] = json.loads(logger_config) 143 | checkpointer_config: Dict[str, Any] = json.loads(checkpointer_config) 144 | model_config_override: Dict[str, Any] = json.loads(model_config_override) 145 | inputs_tokenizer_override: Dict[str, Any] = json.loads(inputs_tokenizer_override) 146 | outputs_tokenizer_override: 
Dict[str, Any] = json.loads(outputs_tokenizer_override) 147 | jax_distributed_initalize_config: Dict[str, Any] = json.loads(jax_distributed_initalize_config) 148 | 149 | jax_distributed_initalize(**jax_distributed_initalize_config) 150 | jax_distributed_barrier() 151 | 152 | if optim_config.startswith('adamw:'): 153 | optim_config = json.loads(optim_config[len('adamw:'):]) 154 | optim_config['weight_decay_mask'] = get_weight_decay_mask(optim_config.pop('weight_decay_exclusions', tuple())) 155 | grad_accum_steps = optim_config.pop('grad_accum_steps', 1) 156 | optimizer, optimizer_info = load_adamw_optimizer(**optim_config) 157 | elif optim_config.startswith('palm:'): 158 | optim_config = json.loads(optim_config[len('palm:'):]) 159 | optim_config['weight_decay_mask'] = get_weight_decay_mask(optim_config.pop('weight_decay_exclusions', tuple())) 160 | grad_accum_steps = optim_config.pop('grad_accum_steps', 1) 161 | optimizer, optimizer_info = load_palm_optimizer(**optim_config) 162 | else: 163 | raise ValueError(f'Unknown optimizer config: {optim_config}') 164 | if grad_accum_steps > 1: 165 | optimizer = optax.MultiSteps( 166 | optimizer, 167 | grad_accum_steps, 168 | ) 169 | 170 | mesh = MeshShardingHelper(sharding, ['dp', 'fsdp', 'mp'], mesh_axis_splitting=physical_axis_splitting) # Create a 3D mesh with data, fsdp, and model parallelism axes 171 | with mesh.get_context(): 172 | print('mesh:', mesh.mesh) 173 | print('loading model ...') 174 | 175 | if load_model.startswith('paths:'): 176 | model_paths = json.loads(load_model[len('paths:'):]) 177 | if not ('remove_dict_prefix' in model_paths): 178 | model_paths['remove_dict_prefix'] = None 179 | else: 180 | raise ValueError(f'Unknown model info type: {load_model}') 181 | 182 | config_is_temp = False 183 | if 'config' in model_paths and model_paths['config'].startswith('gcs://'): 184 | temp_file = tempfile.NamedTemporaryFile('wb', delete=False) 185 | with open_with_bucket(model_paths['config'], 'rb') as f: 186 | temp_file.write(f.read()) 187 | temp_file.close() 188 | model_paths['config'] = temp_file.name 189 | config_is_temp = True 190 | 191 | if 'config' in model_paths: 192 | config = LLaMAConfig.from_pretrained(model_paths['config'], **model_config_override) 193 | elif 'default_config_name' in model_paths: 194 | config = LLaMAConfig(**LLAMA_STANDARD_CONFIGS[model_paths['default_config_name']], **model_config_override) 195 | else: 196 | config = LLaMAConfig(**model_config_override) 197 | 198 | if config_is_temp: 199 | os.remove(model_paths['config']) 200 | 201 | model = FlaxLLaMAForCausalLM(config, dtype=activation_dtype, _do_init=False, param_dtype=param_dtype, input_shape=(bsize, 1024)) 202 | # TODO: embedding dim is hardcoded to 1024, it's 203 | 204 | tokenizer_is_temp = False 205 | if model_paths['tokenizer'].startswith('gcs://'): 206 | temp_file = tempfile.NamedTemporaryFile('wb', delete=False) 207 | with open_with_bucket(model_paths['tokenizer'], 'rb') as f: 208 | temp_file.write(f.read()) 209 | temp_file.close() 210 | model_paths['tokenizer'] = temp_file.name 211 | tokenizer_is_temp = True 212 | 213 | tokenizer_kwargs = dict( 214 | truncation_side='right', 215 | padding_side='right', 216 | ) 217 | tokenizer_kwargs.update(outputs_tokenizer_override) 218 | tokenizer = AutoTokenizer.from_pretrained(model_paths['tokenizer'], **tokenizer_kwargs) 219 | tokenizer.add_special_tokens({'pad_token': tokenizer.convert_ids_to_tokens(config.pad_token_id)}) 220 | 221 | if tokenizer_is_temp: 222 | os.remove(model_paths['tokenizer']) 223 | 224 | 
sharding_rules = TreePathShardingRule(*config.get_partition_rules()) 225 | 226 | @partial( 227 | mesh.sjit, 228 | in_shardings=(sharding_rules,), 229 | out_shardings=sharding_rules, 230 | ) 231 | def create_train_state_from_params(params): 232 | return TrainState.create(params=params, tx=optimizer, apply_fn=None) 233 | 234 | @partial( 235 | mesh.sjit, 236 | in_shardings=(PS(),), 237 | out_shardings=sharding_rules, 238 | ) 239 | def init_fn(rng): 240 | params = model.init_weights(rng, (bsize, 1024)) 241 | return create_train_state_from_params(params) 242 | 243 | train_state_shape = jax.eval_shape(lambda: init_fn(jax.random.PRNGKey(0))) 244 | shard_train_state_fns, gather_train_state_fns = mesh.make_shard_and_gather_fns(train_state_shape, sharding_rules) 245 | 246 | if 'params' in model_paths: 247 | train_state = create_train_state_from_params(load_checkpoint( 248 | model_paths['params'], 249 | shard_fns=shard_train_state_fns.params, 250 | remove_dict_prefix=model_paths['remove_dict_prefix'], 251 | convert_to_dtypes=jax.tree_util.tree_map(lambda x: x.dtype, train_state_shape.params), 252 | )) 253 | elif 'train_state' in model_paths: 254 | train_state = load_checkpoint( 255 | model_paths['train_state'], 256 | shard_fns=shard_train_state_fns, 257 | remove_dict_prefix=model_paths['remove_dict_prefix'], 258 | convert_to_dtypes=jax.tree_util.tree_map(lambda x: x.dtype, train_state_shape), 259 | ) 260 | else: 261 | train_state = init_fn(jax.random.PRNGKey(0)) 262 | 263 | print(model) 264 | print('model loaded.') 265 | 266 | @partial( 267 | mesh.sjit, 268 | in_shardings=(sharding_rules, PS(),PS()), 269 | out_shardings=(sharding_rules, PS()), 270 | args_sharding_constraint=(sharding_rules, None, PS(('dp', 'fsdp'))), 271 | donate_argnums=(0,), 272 | ) 273 | def train_step(train_state, rng, batch): 274 | def loss_and_accuracy(params): 275 | logits = model( 276 | input_ids=batch['input_ids'], 277 | attention_mask=batch['attention_mask'], 278 | position_ids=batch['position_ids'], 279 | params=params, 280 | dropout_rng=rng, 281 | train=True, 282 | ).logits 283 | return cross_entropy_loss_and_accuracy( 284 | logits, batch['target_ids'], batch['loss_masks'], 285 | ) 286 | # print("start training...") 287 | grad_fn = jax.value_and_grad(loss_and_accuracy, has_aux=True) 288 | (loss, accuracy), grads = grad_fn(train_state.params) 289 | # print(f"loss: {loss}, accuracy: {accuracy}") 290 | train_state = train_state.apply_gradients(grads=grads) 291 | # print("gradients applied.") 292 | metrics = dict( 293 | loss=loss, 294 | accuracy=accuracy, 295 | learning_rate=optimizer_info['learning_rate_schedule'](train_state.step), 296 | gradient_norm=global_norm(grads), 297 | param_norm=global_norm(train_state.params), 298 | ) 299 | return train_state, metrics 300 | 301 | @partial( 302 | mesh.sjit, 303 | in_shardings=(sharding_rules, PS()), 304 | out_shardings=PS(), 305 | args_sharding_constraint=(sharding_rules, PS(('dp', 'fsdp'))), 306 | ) 307 | def eval_step(params, batch): 308 | logits = model( 309 | input_ids=batch['input_ids'], 310 | attention_mask=batch['attention_mask'], 311 | position_ids=batch['position_ids'], 312 | params=params, 313 | train=False, 314 | ).logits 315 | loss, accuracy = cross_entropy_loss_and_accuracy( 316 | logits, batch['target_ids'], batch['loss_masks'], 317 | ) 318 | metrics = dict( 319 | eval_loss=loss, 320 | eval_accuracy=accuracy, 321 | ) 322 | return metrics 323 | 324 | print('loading data ...') 325 | if train_data_path.endswith('jsonl'): 326 | train_examples = [] 327 | with 
open_with_bucket(train_data_path, 'r') as f: 328 | for line in f: 329 | train_examples.append(json.loads(line.strip())) 330 | else: 331 | with open_with_bucket(train_data_path, 'r') as f: 332 | train_examples = json.load(f) 333 | if eval_data_path.endswith('jsonl'): 334 | eval_examples = [] 335 | with open_with_bucket(eval_data_path, 'r') as f: 336 | for line in f: 337 | eval_examples.append(json.loads(line.strip())) 338 | else: 339 | with open_with_bucket(eval_data_path, 'r') as f: 340 | eval_examples = json.load(f) 341 | print('done.') 342 | 343 | def data_iterable(data_items, rng, bsize, shuffle=True, loop=True): 344 | while True: 345 | with jax.default_device(jax.devices('cpu')[0]): 346 | idxs = [] 347 | for _ in range((bsize + (len(data_items) - 1)) // len(data_items)): 348 | if shuffle: 349 | rng, subrng = jax.random.split(rng) 350 | curr_idxs = jax.random.permutation(subrng, np.arange(len(data_items))) 351 | idxs.extend(curr_idxs.tolist()) 352 | else: 353 | curr_idxs = np.arange(len(data_items)) 354 | idxs.extend(curr_idxs.tolist()) 355 | idxs = np.asarray(idxs) 356 | for batch_idx in range(len(idxs) // bsize): 357 | batch_idxs = idxs[batch_idx*bsize:(batch_idx+1)*bsize] 358 | batch_examples = [data_items[idx] for idx in batch_idxs] 359 | processed_batch_examples = [] 360 | for example in batch_examples: 361 | processed_batch_examples.append(process_pretrain_example( 362 | example, 363 | max_length, 364 | tokenizer, 365 | )) 366 | batch = dict() 367 | for key in processed_batch_examples[0]: 368 | batch[key] = np.concatenate([example[key] for example in processed_batch_examples], axis=0) 369 | yield batch 370 | if not loop: 371 | break 372 | 373 | if 'enable' not in logger_config: 374 | logger_config['enable'] = (jax.process_index() == 0) 375 | if 'config_to_log' in logger_config: 376 | logger_config['config_to_log'].update(args_dict) 377 | else: 378 | logger_config['config_to_log'] = args_dict 379 | logger = WandbLogger( 380 | wandb_project, 381 | output_dir=output_dir, 382 | **logger_config, 383 | ) 384 | print('wandb logger initialized.') 385 | 386 | checkpoint_queue = collections.deque() 387 | 388 | def _save_checkpoint( 389 | train_state, 390 | step, 391 | ): 392 | old_step = None 393 | if (max_checkpoints is not None) and (len(checkpoint_queue) >= max_checkpoints): 394 | old_step = checkpoint_queue.popleft() 395 | if logger.can_save(): 396 | print(f'saving checkpoint at step {step} ...') 397 | # delete old checkpoint if max checkpoints is reached 398 | if old_step is not None: 399 | old_path = os.path.join(logger.output_dir, 'checkpoints', f'step_{old_step}') 400 | delete_with_bucket(old_path, recursive=True) 401 | 402 | metadata = dict( 403 | step=step, 404 | args_dict=args_dict, 405 | ) 406 | 407 | checkpointer( 408 | path=os.path.join(logger.output_dir, 'checkpoints', f'step_{step}'), 409 | train_state=train_state, 410 | config=config.to_dict(), 411 | gather_fns=gather_train_state_fns, 412 | metadata=metadata, 413 | active=logger.can_save(), 414 | **checkpointer_config, 415 | ) 416 | 417 | checkpoint_queue.append(step) 418 | 419 | if logger.can_save(): 420 | print('saved.') 421 | 422 | if save_initial_checkpoint: 423 | _save_checkpoint(train_state, 0) 424 | 425 | rng = jax.random.PRNGKey(0) 426 | 427 | rng, eval_iterable_rng = jax.random.split(rng) 428 | rng, subrng = jax.random.split(rng) 429 | train_iterable = data_iterable(train_examples, subrng, bsize, shuffle=shuffle_train_data, loop=True) 430 | for step, train_batch in tqdm(itertools.islice(enumerate(train_iterable), 
num_train_steps), total=num_train_steps): 431 | rng, subrng = jax.random.split(rng) 432 | train_state, metrics = train_step(train_state, subrng, train_batch) 433 | # print(f"step {step} metrics: {metrics}") 434 | if log_freq > 0 and ((step+1) % log_freq == 0 or (log_initial_step and step == 0)): 435 | if num_eval_steps > 0: 436 | eval_metric_list = [] 437 | eval_iterable = data_iterable(eval_examples, eval_iterable_rng, eval_bsize, shuffle=True, loop=False) 438 | for eval_batch in itertools.islice(eval_iterable, num_eval_steps): 439 | eval_metric_list.append(eval_step(train_state.params, eval_batch)) 440 | metrics.update(average_metrics(jax.device_get(eval_metric_list))) 441 | log_metrics = {"step": step+1} 442 | log_metrics.update(metrics) 443 | log_metrics = jax.device_get(log_metrics) 444 | logger.log(log_metrics) 445 | print(log_metrics) 446 | 447 | if save_model_freq > 0 and (step+1) % save_model_freq == 0: 448 | _save_checkpoint(train_state, step+1) 449 | 450 | if save_model_freq > 0 and (num_train_steps not in checkpoint_queue): 451 | _save_checkpoint(train_state, num_train_steps) 452 | 453 | jax_distributed_barrier() 454 | logger.finish() 455 | jax_distributed_barrier() 456 | 457 | # Only have the first worker push to hub to avoid conflicts 458 | if jax.process_index() == 0 and logger.can_save(): 459 | import shutil 460 | print("First worker copying final checkpoint to hub...") 461 | 462 | # Create temp directory for checkpoint 463 | temp_dir = tempfile.mkdtemp() 464 | final_ckpt_path = os.path.join(logger.output_dir, 'checkpoints', f'step_{num_train_steps}') 465 | 466 | # Copy checkpoint files to temp dir 467 | if final_ckpt_path.startswith('gcs://'): 468 | with open_with_bucket(os.path.join(final_ckpt_path, 'params.msgpack'), 'rb') as f: 469 | with open(os.path.join(temp_dir, 'params.msgpack'), 'wb') as f_out: 470 | f_out.write(f.read()) 471 | with open_with_bucket(os.path.join(final_ckpt_path, 'config.json'), 'rb') as f: 472 | with open(os.path.join(temp_dir, 'config.json'), 'wb') as f_out: 473 | f_out.write(f.read()) 474 | else: 475 | shutil.copy2(os.path.join(final_ckpt_path, 'params.msgpack'), temp_dir) 476 | shutil.copy2(os.path.join(final_ckpt_path, 'config.json'), temp_dir) 477 | 478 | # Push to hub 479 | try: 480 | from huggingface_hub import HfApi 481 | api = HfApi() 482 | repo_type = "model" 483 | repo_name = hf_repo_id 484 | 485 | api.create_repo(repo_name, repo_type=repo_type, private=False, exist_ok=True) 486 | api.upload_folder( 487 | folder_path=temp_dir, 488 | repo_id=repo_name, 489 | repo_type=repo_type 490 | ) 491 | print("Successfully pushed checkpoint to hub") 492 | except Exception as e: 493 | print(f"Error pushing to hub: {e}") 494 | finally: 495 | # Cleanup temp directory 496 | shutil.rmtree(temp_dir) 497 | if __name__ == "__main__": 498 | tyro.cli(main) 499 | -------------------------------------------------------------------------------- /supervised-jax/llama3_train/requirements.txt: -------------------------------------------------------------------------------- 1 | gcsfs==2023.10.0 2 | jax==0.4.31 3 | transformers==4.47.1 4 | flax==0.10.2 5 | sentencepiece==0.1.99 6 | wget==3.2 7 | jaxtyping==0.2.23 8 | scalax @ git+https://github.com/young-geng/scalax.git@main 9 | tyro==0.8.11 10 | tqdm==4.66.1 11 | wandb==0.16.1 12 | einops==0.8.0 13 | numpy<2.0.0 14 | redis==4.3.4 15 | Flask==3.0.3 16 | flask-cors==5.0.0 -------------------------------------------------------------------------------- /supervised-jax/llama3_train/setup.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | from setuptools import find_packages, setup 3 | 4 | 5 | def read_requirements_file(filename): 6 | req_file_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), 7 | filename) 8 | with open(req_file_path) as f: 9 | return [line.strip() for line in f if line.strip() != ''] 10 | 11 | 12 | setup( 13 | name='llama3_train', 14 | version='1.0.0', 15 | description='LLaMA-3 Train.', 16 | url='https://github.com/Sea-Snell/llama3_train', 17 | author='Charlie Snell', 18 | packages=find_packages(), 19 | install_requires=read_requirements_file('requirements.txt'), 20 | license='LICENCE', 21 | ) -------------------------------------------------------------------------------- /supervised-jax/scripts/hs-v3.sh: -------------------------------------------------------------------------------- 1 | # 12/9/24 2 | 3 | # test gpt2 training, this will train a randomly initialized model 4 | 5 | # charlie-pod 6 | # \"tokenizer\": \"gpt2\", 7 | # \"config\": \"gpt2\" 8 | # gs:// 9 | 10 | ( 11 | source ~/miniconda3/bin/activate llama3_train 12 | export RUN_NAME="llama-200m-hs-v2" 13 | export GCLOUD_TOKEN_PATH="$HOME/.config/gcloud/civic-boulder-204700-3052e43e8c80.json" 14 | export GCLOUD_PROJECT="civic-boulder-204700" 15 | export HF_TOKEN="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" 16 | export WANDB_API_KEY="53a3e8edb945646eb837622d6422755f5a3131b2" 17 | cd ~/llama3_train 18 | # pip install optax wandb 19 | source ~/miniconda3/bin/activate llama3_train 20 | 21 | TRAIN_STEPS=19000 22 | python llama_train_script.py \ 23 | --load_model="paths:{ 24 | \"tokenizer\": \"meta-llama/Llama-2-7b-hf\", 25 | \"default_config_name\": \"200m\" 26 | }" \ 27 | --train_data_path="gcs://jiayi-eu/data/v3/train_hs.json" \ 28 | --eval_data_path="gcs://jiayi-eu/data/v3/val_hs.json" \ 29 | --output_dir="gcs://jiayi-eu/lm-parallel-exp/exp-hs/" \ 30 | --sharding="-1,1,1" \ 31 | --num_train_steps=$TRAIN_STEPS \ 32 | --max_length=4096 \ 33 | --bsize=256 \ 34 | --log_freq=100 \ 35 | --num_eval_steps=500 \ 36 | --save_model_freq=100000000 \ 37 | --wandb_project="sos" \ 38 | --param_dtype="fp32" \ 39 | --activation_dtype="fp32" \ 40 | --optim_config="adamw:{ 41 | \"init_lr\": 5e-6, 42 | \"end_lr\": 5e-7, 43 | \"lr\": 5e-5, 44 | \"lr_warmup_steps\": 1, 45 | \"lr_decay_steps\": $TRAIN_STEPS, 46 | \"b1\": 0.9, 47 | \"b2\": 0.999, 48 | \"clip_gradient\": 1.0, 49 | \"weight_decay\": 0.01, 50 | \"bf16_momentum\": false, 51 | \"multiply_by_parameter_scale\": false, 52 | \"weight_decay_exclusions\": [], 53 | \"schedule\": \"cos\", 54 | \"grad_accum_steps\": 1 55 | }" \ 56 | --logger_config="{ 57 | \"online\": true, 58 | \"prefix\": \"$RUN_NAME\", 59 | \"prefix_to_id\": true 60 | }" \ 61 | --checkpointer_config="{ 62 | \"save_optimizer_state\": false, 63 | \"save_float_dtype\": \"bf16\" 64 | }" \ 65 | --model_config_override="{ 66 | \"bos_token_id\": 1, 67 | \"eos_token_id\": 2, 68 | \"pad_token_id\": 0, 69 | \"remat_block\": \"nothing_saveable\" 70 | }" \ 71 | --eval_bsize=512 \ 72 | --no_shuffle_train_data \ 73 | --hf_repo_id="LM-Parallel/llama-hsp-v3" 74 | ) 75 | 76 | 77 | -------------------------------------------------------------------------------- /supervised-jax/scripts/hsp-v3.sh: -------------------------------------------------------------------------------- 1 | # 12/9/24 2 | 3 | # test gpt2 training, this will train a randomly initialized model 4 | 5 | # charlie-pod 6 | # \"tokenizer\": \"gpt2\", 7 | # \"config\": \"gpt2\" 8 | # gs:// 9 | 10 
| ( 11 | source ~/miniconda3/bin/activate llama3_train 12 | export RUN_NAME="llama-200m-hsp-v2" 13 | export GCLOUD_TOKEN_PATH="$HOME/.config/gcloud/civic-boulder-204700-3052e43e8c80.json" 14 | export GCLOUD_PROJECT="civic-boulder-204700" 15 | export HF_TOKEN="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" 16 | export WANDB_API_KEY="53a3e8edb945646eb837622d6422755f5a3131b2" 17 | cd ~/llama3_train 18 | # pip install optax wandb 19 | source ~/miniconda3/bin/activate llama3_train 20 | 21 | TRAIN_STEPS=19000 22 | python llama_train_script.py \ 23 | --load_model="paths:{ 24 | \"tokenizer\": \"meta-llama/Llama-2-7b-hf\", 25 | \"default_config_name\": \"200m\" 26 | }" \ 27 | --train_data_path="gcs://jiayi-eu/data/v3/train_hsp.json" \ 28 | --eval_data_path="gcs://jiayi-eu/data/v3/val_hsp.json" \ 29 | --output_dir="gcs://jiayi-eu/lm-parallel-exp/exp-hsp/" \ 30 | --sharding="-1,1,1" \ 31 | --num_train_steps=$TRAIN_STEPS \ 32 | --max_length=4096 \ 33 | --bsize=256 \ 34 | --log_freq=100 \ 35 | --num_eval_steps=500 \ 36 | --save_model_freq=100000000 \ 37 | --wandb_project="sos" \ 38 | --param_dtype="fp32" \ 39 | --activation_dtype="fp32" \ 40 | --optim_config="adamw:{ 41 | \"init_lr\": 5e-6, 42 | \"end_lr\": 5e-7, 43 | \"lr\": 5e-5, 44 | \"lr_warmup_steps\": 1, 45 | \"lr_decay_steps\": $TRAIN_STEPS, 46 | \"b1\": 0.9, 47 | \"b2\": 0.999, 48 | \"clip_gradient\": 1.0, 49 | \"weight_decay\": 0.01, 50 | \"bf16_momentum\": false, 51 | \"multiply_by_parameter_scale\": false, 52 | \"weight_decay_exclusions\": [], 53 | \"schedule\": \"cos\", 54 | \"grad_accum_steps\": 1 55 | }" \ 56 | --logger_config="{ 57 | \"online\": true, 58 | \"prefix\": \"$RUN_NAME\", 59 | \"prefix_to_id\": true 60 | }" \ 61 | --checkpointer_config="{ 62 | \"save_optimizer_state\": false, 63 | \"save_float_dtype\": \"bf16\" 64 | }" \ 65 | --model_config_override="{ 66 | \"bos_token_id\": 1, 67 | \"eos_token_id\": 2, 68 | \"pad_token_id\": 0, 69 | \"remat_block\": \"nothing_saveable\" 70 | }" \ 71 | --eval_bsize=512 \ 72 | --no_shuffle_train_data \ 73 | --hf_repo_id="LM-Parallel/llama-hsp-v3" 74 | ) 75 | 76 | 77 | -------------------------------------------------------------------------------- /supervised-jax/scripts/sos-v3.sh: -------------------------------------------------------------------------------- 1 | # 12/9/24 2 | 3 | # test gpt2 training, this will train a randomly initialized model 4 | 5 | # charlie-pod 6 | # \"tokenizer\": \"gpt2\", 7 | # \"config\": \"gpt2\" 8 | # gs:// 9 | 10 | ( 11 | source ~/miniconda3/bin/activate llama3_train 12 | export RUN_NAME="llama-200m-sos-v2" 13 | export GCLOUD_TOKEN_PATH="$HOME/.config/gcloud/civic-boulder-204700-3052e43e8c80.json" 14 | export GCLOUD_PROJECT="civic-boulder-204700" 15 | export HF_TOKEN="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" 16 | export WANDB_API_KEY="53a3e8edb945646eb837622d6422755f5a3131b2" 17 | cd ~/llama3_train 18 | # pip install optax wandb 19 | source ~/miniconda3/bin/activate llama3_train 20 | 21 | TRAIN_STEPS=19000 22 | python llama_train_script.py \ 23 | --load_model="paths:{ 24 | \"tokenizer\": \"meta-llama/Llama-2-7b-hf\", 25 | \"default_config_name\": \"200m\" 26 | }" \ 27 | --train_data_path="gcs://jiayi-eu/data/v3/train_sos.json" \ 28 | --eval_data_path="gcs://jiayi-eu/data/v3/val_sos.json" \ 29 | --output_dir="gcs://jiayi-eu/lm-parallel-exp/exp-sos/" \ 30 | --sharding="-1,1,1" \ 31 | --num_train_steps=$TRAIN_STEPS \ 32 | --max_length=4096 \ 33 | --bsize=256 \ 34 | --log_freq=100 \ 35 | --num_eval_steps=500 \ 36 | --save_model_freq=100000000 \ 37 | --wandb_project="sos" \ 38 
| --param_dtype="fp32" \ 39 | --activation_dtype="fp32" \ 40 | --optim_config="adamw:{ 41 | \"init_lr\": 5e-6, 42 | \"end_lr\": 5e-7, 43 | \"lr\": 5e-5, 44 | \"lr_warmup_steps\": 1, 45 | \"lr_decay_steps\": $TRAIN_STEPS, 46 | \"b1\": 0.9, 47 | \"b2\": 0.999, 48 | \"clip_gradient\": 1.0, 49 | \"weight_decay\": 0.01, 50 | \"bf16_momentum\": false, 51 | \"multiply_by_parameter_scale\": false, 52 | \"weight_decay_exclusions\": [], 53 | \"schedule\": \"cos\", 54 | \"grad_accum_steps\": 1 55 | }" \ 56 | --logger_config="{ 57 | \"online\": true, 58 | \"prefix\": \"$RUN_NAME\", 59 | \"prefix_to_id\": true 60 | }" \ 61 | --checkpointer_config="{ 62 | \"save_optimizer_state\": false, 63 | \"save_float_dtype\": \"bf16\" 64 | }" \ 65 | --model_config_override="{ 66 | \"bos_token_id\": 1, 67 | \"eos_token_id\": 2, 68 | \"pad_token_id\": 0, 69 | \"remat_block\": \"nothing_saveable\" 70 | }" \ 71 | --eval_bsize=512 \ 72 | --no_shuffle_train_data \ 73 | --hf_repo_id="LM-Parallel/llama-sos-v3" 74 | ) 75 | 76 | 77 | -------------------------------------------------------------------------------- /supervised-jax/scripts/training_hs-v2_llama.sh: -------------------------------------------------------------------------------- 1 | # 12/9/24 2 | 3 | # test gpt2 training, this will train a randomly initialized model 4 | 5 | # charlie-pod 6 | # \"tokenizer\": \"gpt2\", 7 | # \"config\": \"gpt2\" 8 | # gs:// 9 | 10 | ( 11 | source ~/miniconda3/bin/activate llama3_train 12 | export RUN_NAME="llama-200m-hs-v2" 13 | export GCLOUD_TOKEN_PATH="$HOME/.config/gcloud/civic-boulder-204700-3052e43e8c80.json" 14 | export GCLOUD_PROJECT="civic-boulder-204700" 15 | export HF_TOKEN="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" 16 | export WANDB_API_KEY="53a3e8edb945646eb837622d6422755f5a3131b2" 17 | cd ~/llama3_train 18 | # pip install optax wandb 19 | source ~/miniconda3/bin/activate llama3_train 20 | 21 | TRAIN_STEPS=8000 22 | python llama_train_script.py \ 23 | --load_model="paths:{ 24 | \"tokenizer\": \"meta-llama/Llama-2-7b-hf\", 25 | \"default_config_name\": \"200m\" 26 | }" \ 27 | --train_data_path="gcs://jiayi-eu/data/hs-v2/train.json" \ 28 | --eval_data_path="gcs://jiayi-eu/data/hs-v2/test.json" \ 29 | --output_dir="gcs://jiayi-eu/lm-parallel-exp/exp-sos/" \ 30 | --sharding="-1,1,1" \ 31 | --num_train_steps=$TRAIN_STEPS \ 32 | --max_length=4096 \ 33 | --bsize=256 \ 34 | --log_freq=100 \ 35 | --num_eval_steps=500 \ 36 | --save_model_freq=100000000 \ 37 | --wandb_project="sos" \ 38 | --param_dtype="fp32" \ 39 | --activation_dtype="fp32" \ 40 | --optim_config="adamw:{ 41 | \"init_lr\": 5e-6, 42 | \"end_lr\": 5e-7, 43 | \"lr\": 5e-5, 44 | \"lr_warmup_steps\": 1, 45 | \"lr_decay_steps\": $TRAIN_STEPS, 46 | \"b1\": 0.9, 47 | \"b2\": 0.999, 48 | \"clip_gradient\": 1.0, 49 | \"weight_decay\": 0.01, 50 | \"bf16_momentum\": false, 51 | \"multiply_by_parameter_scale\": false, 52 | \"weight_decay_exclusions\": [], 53 | \"schedule\": \"cos\", 54 | \"grad_accum_steps\": 1 55 | }" \ 56 | --logger_config="{ 57 | \"online\": true, 58 | \"prefix\": \"$RUN_NAME\", 59 | \"prefix_to_id\": true 60 | }" \ 61 | --checkpointer_config="{ 62 | \"save_optimizer_state\": false, 63 | \"save_float_dtype\": \"bf16\" 64 | }" \ 65 | --model_config_override="{ 66 | \"bos_token_id\": 1, 67 | \"eos_token_id\": 2, 68 | \"pad_token_id\": 0, 69 | \"remat_block\": \"nothing_saveable\" 70 | }" \ 71 | --eval_bsize=512 \ 72 | --no_shuffle_train_data \ 73 | --hf_repo_id="LM-Parallel/llama-hs-v2-8k-step" 74 | ) 75 | 76 | 77 | 
-------------------------------------------------------------------------------- /supervised-jax/scripts/training_hsp-v2_llama.sh: -------------------------------------------------------------------------------- 1 | # 12/9/24 2 | 3 | # test gpt2 training, this will train a randomly initialized model 4 | 5 | # charlie-pod 6 | # \"tokenizer\": \"gpt2\", 7 | # \"config\": \"gpt2\" 8 | # gs:// 9 | 10 | ( 11 | source ~/miniconda3/bin/activate llama3_train 12 | export RUN_NAME="llama-200m-hsp-v2" 13 | export GCLOUD_TOKEN_PATH="$HOME/.config/gcloud/civic-boulder-204700-3052e43e8c80.json" 14 | export GCLOUD_PROJECT="civic-boulder-204700" 15 | export HF_TOKEN="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" 16 | export WANDB_API_KEY="53a3e8edb945646eb837622d6422755f5a3131b2" 17 | cd ~/llama3_train 18 | # pip install optax wandb 19 | source ~/miniconda3/bin/activate llama3_train 20 | 21 | TRAIN_STEPS=8000 22 | python llama_train_script.py \ 23 | --load_model="paths:{ 24 | \"tokenizer\": \"meta-llama/Llama-2-7b-hf\", 25 | \"default_config_name\": \"200m\" 26 | }" \ 27 | --train_data_path="gcs://jiayi-eu/data/hsp-v2/train.json" \ 28 | --eval_data_path="gcs://jiayi-eu/data/hsp-v2/test.json" \ 29 | --output_dir="gcs://jiayi-eu/lm-parallel-exp/exp-sos/" \ 30 | --sharding="-1,1,1" \ 31 | --num_train_steps=$TRAIN_STEPS \ 32 | --max_length=4096 \ 33 | --bsize=256 \ 34 | --log_freq=100 \ 35 | --num_eval_steps=500 \ 36 | --save_model_freq=100000000 \ 37 | --wandb_project="sos" \ 38 | --param_dtype="fp32" \ 39 | --activation_dtype="fp32" \ 40 | --optim_config="adamw:{ 41 | \"init_lr\": 5e-6, 42 | \"end_lr\": 5e-7, 43 | \"lr\": 5e-5, 44 | \"lr_warmup_steps\": 1, 45 | \"lr_decay_steps\": $TRAIN_STEPS, 46 | \"b1\": 0.9, 47 | \"b2\": 0.999, 48 | \"clip_gradient\": 1.0, 49 | \"weight_decay\": 0.01, 50 | \"bf16_momentum\": false, 51 | \"multiply_by_parameter_scale\": false, 52 | \"weight_decay_exclusions\": [], 53 | \"schedule\": \"cos\", 54 | \"grad_accum_steps\": 1 55 | }" \ 56 | --logger_config="{ 57 | \"online\": true, 58 | \"prefix\": \"$RUN_NAME\", 59 | \"prefix_to_id\": true 60 | }" \ 61 | --checkpointer_config="{ 62 | \"save_optimizer_state\": false, 63 | \"save_float_dtype\": \"bf16\" 64 | }" \ 65 | --model_config_override="{ 66 | \"bos_token_id\": 1, 67 | \"eos_token_id\": 2, 68 | \"pad_token_id\": 0, 69 | \"remat_block\": \"nothing_saveable\" 70 | }" \ 71 | --eval_bsize=512 \ 72 | --no_shuffle_train_data \ 73 | --hf_repo_id="LM-Parallel/llama-hsp-v2-8k-step" 74 | ) 75 | 76 | 77 | -------------------------------------------------------------------------------- /supervised-jax/scripts/training_sos_llama.sh: -------------------------------------------------------------------------------- 1 | # 12/9/24 2 | 3 | # test gpt2 training, this will train a randomly initialized model 4 | 5 | # charlie-pod 6 | # \"tokenizer\": \"gpt2\", 7 | # \"config\": \"gpt2\" 8 | # gs:// 9 | 10 | ( 11 | source ~/miniconda3/bin/activate llama3_train 12 | export RUN_NAME="llama-300m-standard-500k-bs256-valid" 13 | export GCLOUD_TOKEN_PATH="$HOME/.config/gcloud/civic-boulder-204700-3052e43e8c80.json" 14 | export GCLOUD_PROJECT="civic-boulder-204700" 15 | export HF_TOKEN="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" 16 | export WANDB_API_KEY="53a3e8edb945646eb837622d6422755f5a3131b2" 17 | cd ~/llama3_train 18 | # pip install optax wandb 19 | source ~/miniconda3/bin/activate llama3_train 20 | 21 | TRAIN_STEPS=19000 22 | python llama_train_script.py \ 23 | --load_model="paths:{ 24 | \"tokenizer\": \"meta-llama/Llama-2-7b-hf\", 25 | 
\"default_config_name\": \"300m\" 26 | }" \ 27 | --train_data_path="gcs://jiayi-eu/data/sos-jan12-xiuyu/train.json" \ 28 | --eval_data_path="gcs://jiayi-eu/data/sos-jan12-xiuyu/test.json" \ 29 | --output_dir="gcs://jiayi-eu/lm-parallel-exp/exp-sos/" \ 30 | --sharding="-1,1,1" \ 31 | --num_train_steps=$TRAIN_STEPS \ 32 | --max_length=4096 \ 33 | --bsize=256 \ 34 | --log_freq=100 \ 35 | --num_eval_steps=500 \ 36 | --save_model_freq=100000000 \ 37 | --wandb_project="sos" \ 38 | --param_dtype="fp32" \ 39 | --activation_dtype="fp32" \ 40 | --optim_config="adamw:{ 41 | \"init_lr\": 5e-6, 42 | \"end_lr\": 5e-7, 43 | \"lr\": 5e-5, 44 | \"lr_warmup_steps\": 1, 45 | \"lr_decay_steps\": $TRAIN_STEPS, 46 | \"b1\": 0.9, 47 | \"b2\": 0.999, 48 | \"clip_gradient\": 1.0, 49 | \"weight_decay\": 0.01, 50 | \"bf16_momentum\": false, 51 | \"multiply_by_parameter_scale\": false, 52 | \"weight_decay_exclusions\": [], 53 | \"schedule\": \"cos\", 54 | \"grad_accum_steps\": 1 55 | }" \ 56 | --logger_config="{ 57 | \"online\": true, 58 | \"prefix\": \"$RUN_NAME\", 59 | \"prefix_to_id\": true 60 | }" \ 61 | --checkpointer_config="{ 62 | \"save_optimizer_state\": false, 63 | \"save_float_dtype\": \"bf16\" 64 | }" \ 65 | --model_config_override="{ 66 | \"bos_token_id\": 1, 67 | \"eos_token_id\": 2, 68 | \"pad_token_id\": 0, 69 | \"remat_block\": \"nothing_saveable\" 70 | }" \ 71 | --eval_bsize=512 \ 72 | --no_shuffle_train_data \ 73 | --hf_repo_id="LM-Parallel/llama-300m-standard-valid-5e-5-bs256" 74 | ) 75 | 76 | 77 | -------------------------------------------------------------------------------- /supervised-jax/scripts/training_sos_llama_1ep.sh: -------------------------------------------------------------------------------- 1 | # 12/9/24 2 | 3 | # test gpt2 training, this will train a randomly initialized model 4 | 5 | # charlie-pod 6 | # \"tokenizer\": \"gpt2\", 7 | # \"config\": \"gpt2\" 8 | # gs:// 9 | 10 | ( 11 | source ~/miniconda3/bin/activate llama3_train 12 | export RUN_NAME="llama-300m-standard-500k-1ep-bs128-valid" 13 | export GCLOUD_TOKEN_PATH="$HOME/.config/gcloud/civic-boulder-204700-3052e43e8c80.json" 14 | export GCLOUD_PROJECT="civic-boulder-204700" 15 | export HF_TOKEN="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" 16 | export WANDB_API_KEY="53a3e8edb945646eb837622d6422755f5a3131b2" 17 | cd ~/llama3_train 18 | pip install wget 19 | # pip install optax wandb 20 | source ~/miniconda3/bin/activate llama3_train 21 | 22 | TRAIN_STEPS=3900 23 | python llama_train_script.py \ 24 | --load_model="paths:{ 25 | \"tokenizer\": \"meta-llama/Llama-2-7b-hf\", 26 | \"default_config_name\": \"300m\" 27 | }" \ 28 | --train_data_path="gcs://jiayi-eu/data/sos-jan12-xiuyu/train.json" \ 29 | --eval_data_path="gcs://jiayi-eu/data/sos-jan12-xiuyu/test.json" \ 30 | --output_dir="gcs://jiayi-eu/lm-parallel-exp/exp-sos/" \ 31 | --sharding="-1,1,1" \ 32 | --num_train_steps=$TRAIN_STEPS \ 33 | --max_length=4096 \ 34 | --bsize=128 \ 35 | --log_freq=100 \ 36 | --num_eval_steps=500 \ 37 | --save_model_freq=100000000 \ 38 | --wandb_project="sos" \ 39 | --param_dtype="fp32" \ 40 | --activation_dtype="fp32" \ 41 | --optim_config="adamw:{ 42 | \"init_lr\": 5e-6, 43 | \"end_lr\": 5e-7, 44 | \"lr\": 5e-5, 45 | \"lr_warmup_steps\": 1, 46 | \"lr_decay_steps\": $TRAIN_STEPS, 47 | \"b1\": 0.9, 48 | \"b2\": 0.999, 49 | \"clip_gradient\": 1.0, 50 | \"weight_decay\": 0.01, 51 | \"bf16_momentum\": false, 52 | \"multiply_by_parameter_scale\": false, 53 | \"weight_decay_exclusions\": [], 54 | \"schedule\": \"cos\", 55 | \"grad_accum_steps\": 1 56 | }" 
\ 57 | --logger_config="{ 58 | \"online\": true, 59 | \"prefix\": \"$RUN_NAME\", 60 | \"prefix_to_id\": true 61 | }" \ 62 | --checkpointer_config="{ 63 | \"save_optimizer_state\": false, 64 | \"save_float_dtype\": \"bf16\" 65 | }" \ 66 | --model_config_override="{ 67 | \"bos_token_id\": 1, 68 | \"eos_token_id\": 2, 69 | \"pad_token_id\": 0, 70 | \"remat_block\": \"nothing_saveable\" 71 | }" \ 72 | --eval_bsize=512 \ 73 | --no_shuffle_train_data \ 74 | --hf_repo_id="LM-Parallel/llama-300m-standard-valid-5e-5-bs128-1ep" 75 | ) 76 | 77 | 78 | -------------------------------------------------------------------------------- /supervised-jax/scripts/training_sos_llama_gpt2tok.sh: -------------------------------------------------------------------------------- 1 | # 12/9/24 2 | 3 | # test gpt2 training, this will train a randomly initialized model 4 | 5 | # charlie-pod 6 | # \"tokenizer\": \"gpt2\", 7 | # \"config\": \"gpt2\" 8 | # gs:// 9 | 10 | ( 11 | source ~/miniconda3/bin/activate llama3_train 12 | export RUN_NAME="llama-300m-gpt2tok-standard-500k-bs256-valid" 13 | export GCLOUD_TOKEN_PATH="$HOME/.config/gcloud/civic-boulder-204700-3052e43e8c80.json" 14 | export GCLOUD_PROJECT="civic-boulder-204700" 15 | export HF_TOKEN="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" 16 | export WANDB_API_KEY="53a3e8edb945646eb837622d6422755f5a3131b2" 17 | cd ~/llama3_train 18 | # pip install optax wandb 19 | source ~/miniconda3/bin/activate llama3_train 20 | 21 | TRAIN_STEPS=19000 22 | python llama_train_script.py \ 23 | --load_model="paths:{ 24 | \"tokenizer\": \"gpt2\", 25 | \"default_config_name\": \"300m\" 26 | }" \ 27 | --train_data_path="gcs://jiayi-eu/data/sos-jan12-xiuyu/train.json" \ 28 | --eval_data_path="gcs://jiayi-eu/data/sos-jan12-xiuyu/test.json" \ 29 | --output_dir="gcs://jiayi-eu/lm-parallel-exp/exp-sos/" \ 30 | --sharding="-1,1,1" \ 31 | --num_train_steps=$TRAIN_STEPS \ 32 | --max_length=4096 \ 33 | --bsize=256 \ 34 | --log_freq=100 \ 35 | --num_eval_steps=500 \ 36 | --save_model_freq=100000000 \ 37 | --wandb_project="sos" \ 38 | --param_dtype="fp32" \ 39 | --activation_dtype="fp32" \ 40 | --optim_config="adamw:{ 41 | \"init_lr\": 5e-6, 42 | \"end_lr\": 5e-7, 43 | \"lr\": 5e-5, 44 | \"lr_warmup_steps\": 1, 45 | \"lr_decay_steps\": $TRAIN_STEPS, 46 | \"b1\": 0.9, 47 | \"b2\": 0.999, 48 | \"clip_gradient\": 1.0, 49 | \"weight_decay\": 0.01, 50 | \"bf16_momentum\": false, 51 | \"multiply_by_parameter_scale\": false, 52 | \"weight_decay_exclusions\": [], 53 | \"schedule\": \"cos\", 54 | \"grad_accum_steps\": 1 55 | }" \ 56 | --logger_config="{ 57 | \"online\": true, 58 | \"prefix\": \"$RUN_NAME\", 59 | \"prefix_to_id\": true 60 | }" \ 61 | --checkpointer_config="{ 62 | \"save_optimizer_state\": false, 63 | \"save_float_dtype\": \"bf16\" 64 | }" \ 65 | --model_config_override="{ 66 | \"bos_token_id\": 50256, 67 | \"eos_token_id\": 50256, 68 | \"pad_token_id\": 50256, 69 | \"remat_block\": \"nothing_saveable\" 70 | }" \ 71 | --eval_bsize=512 \ 72 | --no_shuffle_train_data \ 73 | --hf_repo_id="LM-Parallel/llama-300m-gpt2tok-standard-valid-5e-5" 74 | ) 75 | 76 | 77 | -------------------------------------------------------------------------------- /supervised-jax/scripts/training_sos_llama_gpt2tok_1ep.sh: -------------------------------------------------------------------------------- 1 | # 12/9/24 2 | 3 | # test gpt2 training, this will train a randomly initialized model 4 | 5 | # charlie-pod 6 | # \"tokenizer\": \"gpt2\", 7 | # \"config\": \"gpt2\" 8 | # gs:// 9 | 10 | ( 11 | source 
~/miniconda3/bin/activate llama3_train 12 | export RUN_NAME="llama-300m-gpt2tok-standard-500k-1ep-bs128-valid" 13 | export GCLOUD_TOKEN_PATH="$HOME/.config/gcloud/civic-boulder-204700-3052e43e8c80.json" 14 | export GCLOUD_PROJECT="civic-boulder-204700" 15 | export HF_TOKEN="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" 16 | export WANDB_API_KEY="53a3e8edb945646eb837622d6422755f5a3131b2" 17 | cd ~/llama3_train 18 | pip install wget 19 | # pip install optax wandb 20 | source ~/miniconda3/bin/activate llama3_train 21 | 22 | TRAIN_STEPS=3900 23 | python llama_train_script.py \ 24 | --load_model="paths:{ 25 | \"tokenizer\": \"gpt2\", 26 | \"default_config_name\": \"300m\" 27 | }" \ 28 | --train_data_path="gcs://jiayi-eu/data/sos-jan12-xiuyu/train.json" \ 29 | --eval_data_path="gcs://jiayi-eu/data/sos-jan12-xiuyu/test.json" \ 30 | --output_dir="gcs://jiayi-eu/lm-parallel-exp/exp-sos/" \ 31 | --sharding="-1,1,1" \ 32 | --num_train_steps=$TRAIN_STEPS \ 33 | --max_length=4096 \ 34 | --bsize=128 \ 35 | --log_freq=100 \ 36 | --num_eval_steps=500 \ 37 | --save_model_freq=100000000 \ 38 | --wandb_project="sos" \ 39 | --param_dtype="fp32" \ 40 | --activation_dtype="fp32" \ 41 | --optim_config="adamw:{ 42 | \"init_lr\": 5e-6, 43 | \"end_lr\": 5e-7, 44 | \"lr\": 5e-5, 45 | \"lr_warmup_steps\": 1, 46 | \"lr_decay_steps\": $TRAIN_STEPS, 47 | \"b1\": 0.9, 48 | \"b2\": 0.999, 49 | \"clip_gradient\": 1.0, 50 | \"weight_decay\": 0.01, 51 | \"bf16_momentum\": false, 52 | \"multiply_by_parameter_scale\": false, 53 | \"weight_decay_exclusions\": [], 54 | \"schedule\": \"cos\", 55 | \"grad_accum_steps\": 1 56 | }" \ 57 | --logger_config="{ 58 | \"online\": true, 59 | \"prefix\": \"$RUN_NAME\", 60 | \"prefix_to_id\": true 61 | }" \ 62 | --checkpointer_config="{ 63 | \"save_optimizer_state\": false, 64 | \"save_float_dtype\": \"bf16\" 65 | }" \ 66 | --model_config_override="{ 67 | \"bos_token_id\": 50256, 68 | \"eos_token_id\": 50256, 69 | \"pad_token_id\": 50256, 70 | \"remat_block\": \"nothing_saveable\" 71 | }" \ 72 | --eval_bsize=512 \ 73 | --no_shuffle_train_data \ 74 | --hf_repo_id="LM-Parallel/llama-300m-gpt2tok-standard-valid-5e-5-bs128-1ep" 75 | ) 76 | 77 | 78 | -------------------------------------------------------------------------------- /supervised-jax/scripts/training_sos_xiuyu.sh: -------------------------------------------------------------------------------- 1 | # 12/9/24 2 | 3 | # test gpt2 training, this will train a randomly initialized model 4 | 5 | # charlie-pod 6 | # \"tokenizer\": \"gpt2\", 7 | # \"config\": \"gpt2\" 8 | # gs:// 9 | 10 | ( 11 | source ~/miniconda3/bin/activate llama3_train 12 | export RUN_NAME="gpt2-standard-500k-bs128-valid" 13 | export GCLOUD_TOKEN_PATH="$HOME/.config/gcloud/civic-boulder-204700-3052e43e8c80.json" 14 | export GCLOUD_PROJECT="civic-boulder-204700" 15 | export HF_TOKEN="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" 16 | export WANDB_API_KEY="0929e692448f1bc929d71d7e3ece80073c3041e6" 17 | cd ~/llama3_train 18 | # pip install optax wandb 19 | source ~/miniconda3/bin/activate llama3_train 20 | 21 | TRAIN_STEPS=39000 22 | python gpt2_train_script.py \ 23 | --load_model="paths:{ 24 | \"tokenizer\": \"gpt2\", 25 | \"config\": \"LM-Parallel/jax-reference-gp2-s\" 26 | }" \ 27 | --train_data_path="gcs://jiayi-eu/data/sos-jan12-xiuyu/train.json" \ 28 | --eval_data_path="gcs://jiayi-eu/data/sos-jan12-xiuyu/test.json" \ 29 | --output_dir="gcs://jiayi-eu/lm-parallel-exp/exp-sos/" \ 30 | --sharding="-1,1,1" \ 31 | --num_train_steps=$TRAIN_STEPS \ 32 | --max_length=4096 \ 33 | 
--bsize=128 \ 34 | --log_freq=100 \ 35 | --num_eval_steps=500 \ 36 | --save_model_freq=100000000 \ 37 | --wandb_project="sos" \ 38 | --param_dtype="fp32" \ 39 | --activation_dtype="fp32" \ 40 | --optim_config="adamw:{ 41 | \"init_lr\": 1e-5, 42 | \"end_lr\": 0, 43 | \"lr\": 1e-5, 44 | \"lr_warmup_steps\": 0, 45 | \"lr_decay_steps\": $TRAIN_STEPS, 46 | \"b1\": 0.9, 47 | \"b2\": 0.999, 48 | \"clip_gradient\": 1.0, 49 | \"weight_decay\": 0.01, 50 | \"bf16_momentum\": false, 51 | \"multiply_by_parameter_scale\": false, 52 | \"weight_decay_exclusions\": [], 53 | \"schedule\": \"cos\", 54 | \"grad_accum_steps\": 1 55 | }" \ 56 | --logger_config="{ 57 | \"online\": true, 58 | \"prefix\": \"$RUN_NAME\", 59 | \"prefix_to_id\": true 60 | }" \ 61 | --checkpointer_config="{ 62 | \"save_optimizer_state\": false, 63 | \"save_float_dtype\": \"bf16\" 64 | }" \ 65 | --model_config_override="{ 66 | \"bos_token_id\": 50256, 67 | \"eos_token_id\": 50256, 68 | \"pad_token_id\": 50256, 69 | \"remat_block\": \"nothing_saveable\", 70 | \"n_positions\": 4096 71 | }" \ 72 | --eval_bsize=512 \ 73 | --no_shuffle_train_data \ 74 | --hf_repo_id="LM-Parallel/gpt2-standard-valid-1e-5" 75 | ) 76 | 77 | 78 | -------------------------------------------------------------------------------- /supervised-jax/to_dataset.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import json, os\n", 10 | "from google.cloud import storage\n", 11 | "import os\n", 12 | "os.environ[\"GOOGLE_APPLICATION_CREDENTIALS\"] = \"CREDENTIALS_PATH\" # TODO: replace with your own credentials path\n", 13 | "# Upload to GCS\n", 14 | "bucket_name = \"BUCKET_NAME\" # TODO: replace with your own bucket name\n", 15 | "local_dir = \"LOCAL_DIR\" # TODO: replace with your own local directory\n", 16 | "prefix = \"PREFIX\" # TODO: replace with your own prefix" 17 | ] 18 | }, 19 | { 20 | "cell_type": "code", 21 | "execution_count": 5, 22 | "metadata": {}, 23 | "outputs": [], 24 | "source": [ 25 | "def process_to_gcp(file_path, name):\n", 26 | " # load \n", 27 | " with open(file_path, 'r') as f:\n", 28 | " data = json.load(f)\n", 29 | "\n", 30 | " # get all seqs\n", 31 | " if \"hsp\" not in name:\n", 32 | " seqs = [d['search_path'] for d in data]\n", 33 | " else:\n", 34 | " def add_all_calls(trace_dict):\n", 35 | " all_calls = []\n", 36 | " all_calls += trace_dict['main_calls']\n", 37 | " for sub in trace_dict['sub_calls']:\n", 38 | " for sub_trace in sub:\n", 39 | " all_calls += add_all_calls(sub_trace)\n", 40 | " return all_calls\n", 41 | " seqs = []\n", 42 | " for dp in data:\n", 43 | " seqs += add_all_calls(dp['trace_dict'])\n", 44 | " print(f\"name: {len(seqs)}\")\n", 45 | "\n", 46 | " # local save\n", 47 | " os.makedirs(local_dir, exist_ok=True)\n", 48 | " local_path = os.path.join(local_dir, f\"{name}.json\")\n", 49 | " with open(local_path, \"w\") as f:\n", 50 | " json.dump(seqs, f)\n", 51 | "\n", 52 | " # send to gcp\n", 53 | " storage_client = storage.Client()\n", 54 | " bucket = storage_client.bucket(bucket_name)\n", 55 | " # # Upload train file\n", 56 | " train_blob = bucket.blob(prefix+f\"{name}.json\")\n", 57 | " train_blob.upload_from_filename(local_path)" 58 | ] 59 | }, 60 | { 61 | "cell_type": "code", 62 | "execution_count": 6, 63 | "metadata": {}, 64 | "outputs": [], 65 | "source": [ 66 | "todos = [\n", 67 | " [\"./train_apr.json\", \"train_apr\"],\n", 68 | " # TODO: add your 
own data here\n", 69 | "]" 70 | ] 71 | }, 72 | { 73 | "cell_type": "code", 74 | "execution_count": null, 75 | "metadata": {}, 76 | "outputs": [], 77 | "source": [ 78 | "from tqdm import tqdm\n", 79 | "for file_path, name in tqdm(todos):\n", 80 | " process_to_gcp(file_path, name)" 81 | ] 82 | } 83 | ], 84 | "metadata": { 85 | "kernelspec": { 86 | "display_name": "GPML", 87 | "language": "python", 88 | "name": "python3" 89 | }, 90 | "language_info": { 91 | "codemirror_mode": { 92 | "name": "ipython", 93 | "version": 3 94 | }, 95 | "file_extension": ".py", 96 | "mimetype": "text/x-python", 97 | "name": "python", 98 | "nbconvert_exporter": "python", 99 | "pygments_lexer": "ipython3", 100 | "version": "3.10.14" 101 | } 102 | }, 103 | "nbformat": 4, 104 | "nbformat_minor": 2 105 | } 106 | -------------------------------------------------------------------------------- /supervised-jax/training_hsp.sh: -------------------------------------------------------------------------------- 1 | # 12/9/24 2 | 3 | # test gpt2 training, this will train a randomly initialized model 4 | 5 | # charlie-pod 6 | # \"tokenizer\": \"gpt2\", 7 | # \"config\": \"gpt2\" 8 | # gs:// 9 | 10 | ( 11 | source ~/miniconda3/bin/activate llama3_train 12 | pip install -U "jax[tpu]==0.4.38" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html 13 | export RUN_NAME="hsp_v0_500k" 14 | export GCLOUD_TOKEN_PATH="$HOME/.config/gcloud/civic-boulder-204700-3052e43e8c80.json" 15 | export GCLOUD_PROJECT="civic-boulder-204700" 16 | export HF_TOKEN="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" 17 | export WANDB_API_KEY="0929e692448f1bc929d71d7e3ece80073c3041e6" 18 | cd ~/llama3_train 19 | # pip install optax wandb 20 | source ~/miniconda3/bin/activate llama3_train 21 | 22 | TRAIN_STEPS=19000 23 | python gpt2_train_script.py \ 24 | --load_model="paths:{ 25 | \"tokenizer\": \"gpt2\", 26 | \"config\": \"LM-Parallel/jax-reference-gp2-s\" 27 | }" \ 28 | --train_data_path="gcs://jiayi-eu/data/hsp_v0_500k/train.json" \ 29 | --eval_data_path="gcs://jiayi-eu/data/hsp_v0_500k/test.json" \ 30 | --output_dir="gcs://jiayi-eu/lm-parallel-exp/exp-hsp/" \ 31 | --sharding="-1,1,1" \ 32 | --num_train_steps=$TRAIN_STEPS \ 33 | --max_length=4096 \ 34 | --bsize=256 \ 35 | --log_freq=512 \ 36 | --num_eval_steps=512 \ 37 | --save_model_freq=100000000 \ 38 | --wandb_project="sos" \ 39 | --param_dtype="fp32" \ 40 | --activation_dtype="fp32" \ 41 | --optim_config="adamw:{ 42 | \"init_lr\": 5e-6, 43 | \"end_lr\": 5e-7, 44 | \"lr\": 5e-5, 45 | \"lr_warmup_steps\": 1, 46 | \"lr_decay_steps\": $TRAIN_STEPS, 47 | \"b1\": 0.9, 48 | \"b2\": 0.999, 49 | \"clip_gradient\": 100.0, 50 | \"weight_decay\": 0.01, 51 | \"bf16_momentum\": false, 52 | \"multiply_by_parameter_scale\": false, 53 | \"weight_decay_exclusions\": [], 54 | \"schedule\": \"cos\", 55 | \"grad_accum_steps\": 1 56 | }" \ 57 | --logger_config="{ 58 | \"online\": true, 59 | \"prefix\": \"$RUN_NAME\", 60 | \"prefix_to_id\": true 61 | }" \ 62 | --checkpointer_config="{ 63 | \"save_optimizer_state\": false, 64 | \"save_float_dtype\": \"bf16\" 65 | }" \ 66 | --model_config_override="{ 67 | \"bos_token_id\": 50256, 68 | \"eos_token_id\": 50256, 69 | \"pad_token_id\": 50256, 70 | \"remat_block\": \"nothing_saveable\", 71 | \"resid_pdrop\": 0.00, 72 | \"embd_pdrop\": 0.00, 73 | \"attn_pdrop\": 0.00, 74 | \"n_positions\": 4096 75 | }" \ 76 | --eval_bsize=512 \ 77 | --no_shuffle_train_data \ 78 | --hf_repo_id="LM-Parallel/hsp-v0" 79 | ) 80 | 81 | 82 | 
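For readers parsing these launch flags, a small illustrative sketch of how `--sharding="-1,1,1"` is interpreted: the training scripts split the string into three integers for the `dp`, `fsdp`, and `mp` mesh axes, with `-1` absorbing all remaining devices (the values below are only an example).

```python
# Illustrative only; mirrors the flag parsing and mesh axis names used in
# llama_train_script.py / gpt2_train_script.py earlier in this repo.
sharding_flag = "-1,1,1"
dims = [int(x.strip()) for x in sharding_flag.split(",")]
axes = ["dp", "fsdp", "mp"]  # data, fully-sharded data, and model parallel axes
print(dict(zip(axes, dims)))  # {'dp': -1, 'fsdp': 1, 'mp': 1}
```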
-------------------------------------------------------------------------------- /supervised-jax/training_hsp_2x.sh: -------------------------------------------------------------------------------- 1 | # 12/9/24 2 | 3 | # test gpt2 training, this will train a randomly initialized model 4 | 5 | # charlie-pod 6 | # \"tokenizer\": \"gpt2\", 7 | # \"config\": \"gpt2\" 8 | # gs:// 9 | 10 | ( 11 | source ~/miniconda3/bin/activate llama3_train 12 | export RUN_NAME="hsp_v0_500k" 13 | export GCLOUD_TOKEN_PATH="$HOME/.config/gcloud/civic-boulder-204700-3052e43e8c80.json" 14 | export GCLOUD_PROJECT="civic-boulder-204700" 15 | export HF_TOKEN="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" 16 | export WANDB_API_KEY="0929e692448f1bc929d71d7e3ece80073c3041e6" 17 | cd ~/llama3_train 18 | # pip install optax wandb 19 | source ~/miniconda3/bin/activate llama3_train 20 | 21 | TRAIN_STEPS=19000 22 | python gpt2_train_script.py \ 23 | --load_model="paths:{ 24 | \"tokenizer\": \"gpt2\", 25 | \"config\": \"LM-Parallel/jax-reference-gpt2-2x-s\" 26 | }" \ 27 | --train_data_path="gcs://jiayi-eu/data/hsp_v0_500k/train.json" \ 28 | --eval_data_path="gcs://jiayi-eu/data/hsp_v0_500k/test.json" \ 29 | --output_dir="gcs://jiayi-eu/lm-parallel-exp/exp-hsp/" \ 30 | --sharding="-1,1,1" \ 31 | --num_train_steps=$TRAIN_STEPS \ 32 | --max_length=4096 \ 33 | --bsize=256 \ 34 | --log_freq=512 \ 35 | --num_eval_steps=512 \ 36 | --save_model_freq=100000000 \ 37 | --wandb_project="sos" \ 38 | --param_dtype="fp32" \ 39 | --activation_dtype="fp32" \ 40 | --optim_config="adamw:{ 41 | \"init_lr\": 5e-6, 42 | \"end_lr\": 5e-7, 43 | \"lr\": 5e-5, 44 | \"lr_warmup_steps\": 1, 45 | \"lr_decay_steps\": $TRAIN_STEPS, 46 | \"b1\": 0.9, 47 | \"b2\": 0.999, 48 | \"clip_gradient\": 100.0, 49 | \"weight_decay\": 0.01, 50 | \"bf16_momentum\": false, 51 | \"multiply_by_parameter_scale\": false, 52 | \"weight_decay_exclusions\": [], 53 | \"schedule\": \"cos\", 54 | \"grad_accum_steps\": 1 55 | }" \ 56 | --logger_config="{ 57 | \"online\": true, 58 | \"prefix\": \"$RUN_NAME\", 59 | \"prefix_to_id\": true 60 | }" \ 61 | --checkpointer_config="{ 62 | \"save_optimizer_state\": false, 63 | \"save_float_dtype\": \"bf16\" 64 | }" \ 65 | --model_config_override="{ 66 | \"bos_token_id\": 50256, 67 | \"eos_token_id\": 50256, 68 | \"pad_token_id\": 50256, 69 | \"remat_block\": \"nothing_saveable\", 70 | \"resid_pdrop\": 0.00, 71 | \"embd_pdrop\": 0.00, 72 | \"attn_pdrop\": 0.00, 73 | \"n_positions\": 4096 74 | }" \ 75 | --eval_bsize=512 \ 76 | --no_shuffle_train_data \ 77 | --hf_repo_id="LM-Parallel/hsp-v0-2x" 78 | ) 79 | 80 | 81 | -------------------------------------------------------------------------------- /supervised-jax/training_run.sh: -------------------------------------------------------------------------------- 1 | # 12/9/24 2 | 3 | # test gpt2 training, this will train a randomly initialized model 4 | 5 | # charlie-pod 6 | # \"tokenizer\": \"gpt2\", 7 | # \"config\": \"gpt2\" 8 | # gs:// 9 | 10 | ( 11 | source ~/miniconda3/bin/activate llama3_train 12 | export RUN_NAME="gpt2_sos_jan12" 13 | export GCLOUD_TOKEN_PATH="$HOME/.config/gcloud/civic-boulder-204700-3052e43e8c80.json" 14 | export GCLOUD_PROJECT="civic-boulder-204700" 15 | export HF_TOKEN="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" 16 | export WANDB_API_KEY="0929e692448f1bc929d71d7e3ece80073c3041e6" 17 | cd ~/llama3_train 18 | # pip install optax wandb 19 | source ~/miniconda3/bin/activate llama3_train 20 | 21 | TRAIN_STEPS=19000 22 | python gpt2_train_script.py \ 23 | --load_model="paths:{ 24 | 
\"tokenizer\": \"gpt2\", 25 | \"config\": \"LM-Parallel/jax-reference-gp2-s\" 26 | }" \ 27 | --train_data_path="gcs://jiayi-eu/data/sos-jan12-xiuyu/train.json" \ 28 | --eval_data_path="gcs://jiayi-eu/data/sos-jan12-xiuyu/test.json" \ 29 | --output_dir="gcs://jiayi-eu/lm-parallel-exp/exp-sos/" \ 30 | --sharding="-1,1,1" \ 31 | --num_train_steps=$TRAIN_STEPS \ 32 | --max_length=4096 \ 33 | --bsize=256 \ 34 | --log_freq=512 \ 35 | --num_eval_steps=512 \ 36 | --save_model_freq=100000000 \ 37 | --wandb_project="sos" \ 38 | --param_dtype="fp32" \ 39 | --activation_dtype="fp32" \ 40 | --optim_config="adamw:{ 41 | \"init_lr\": 5e-6, 42 | \"end_lr\": 5e-7, 43 | \"lr\": 5e-5, 44 | \"lr_warmup_steps\": 1, 45 | \"lr_decay_steps\": $TRAIN_STEPS, 46 | \"b1\": 0.9, 47 | \"b2\": 0.999, 48 | \"clip_gradient\": 100.0, 49 | \"weight_decay\": 0.01, 50 | \"bf16_momentum\": false, 51 | \"multiply_by_parameter_scale\": false, 52 | \"weight_decay_exclusions\": [], 53 | \"schedule\": \"cos\", 54 | \"grad_accum_steps\": 1 55 | }" \ 56 | --logger_config="{ 57 | \"online\": true, 58 | \"prefix\": \"$RUN_NAME\", 59 | \"prefix_to_id\": true 60 | }" \ 61 | --checkpointer_config="{ 62 | \"save_optimizer_state\": false, 63 | \"save_float_dtype\": \"bf16\" 64 | }" \ 65 | --model_config_override="{ 66 | \"bos_token_id\": 50256, 67 | \"eos_token_id\": 50256, 68 | \"pad_token_id\": 50256, 69 | \"remat_block\": \"nothing_saveable\", 70 | \"resid_pdrop\": 0.1, 71 | \"embd_pdrop\": 0.00, 72 | \"attn_pdrop\": 0.00, 73 | \"n_positions\": 4096 74 | }" \ 75 | --eval_bsize=512 \ 76 | --no_shuffle_train_data \ 77 | --hf_repo_id="LM-Parallel/standard-ref-5e-5-rdrop01" 78 | ) 79 | 80 | 81 | -------------------------------------------------------------------------------- /supervised-jax/training_run_test.sh: -------------------------------------------------------------------------------- 1 | # 12/9/24 2 | 3 | # test gpt2 training, this will train a randomly initialized model 4 | 5 | # charlie-pod 6 | # \"tokenizer\": \"gpt2\", 7 | # \"config\": \"gpt2\" 8 | # gs:// 9 | 10 | ( 11 | source ~/miniconda3/bin/activate llama3_train 12 | export RUN_NAME="gpt2_sos_jan15" 13 | export GCLOUD_TOKEN_PATH="$HOME/.config/gcloud/civic-boulder-204700-3052e43e8c80.json" 14 | export GCLOUD_PROJECT="civic-boulder-204700" 15 | export HF_TOKEN="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" 16 | export WANDB_API_KEY="0929e692448f1bc929d71d7e3ece80073c3041e6" 17 | cd ~/llama3_train 18 | # pip install optax wandb 19 | source ~/miniconda3/bin/activate llama3_train 20 | 21 | TRAIN_STEPS=5000 22 | python gpt2_train_script.py \ 23 | --load_model="paths:{ 24 | \"tokenizer\": \"gpt2\", 25 | \"config\": \"LM-Parallel/jax-reference-gp2-s\" 26 | }" \ 27 | --train_data_path="gcs://jiayi-eu/data/sos-jan12-xiuyu/train.json" \ 28 | --eval_data_path="gcs://jiayi-eu/data/sos-jan12-xiuyu/test.json" \ 29 | --output_dir="gcs://jiayi-eu/lm-parallel-exp/exp-sos/" \ 30 | --sharding="-1,1,1" \ 31 | --num_train_steps=$TRAIN_STEPS \ 32 | --max_length=4096 \ 33 | --bsize=128 \ 34 | --log_freq=512 \ 35 | --num_eval_steps=512 \ 36 | --save_model_freq=100000000 \ 37 | --wandb_project="sos" \ 38 | --param_dtype="fp32" \ 39 | --activation_dtype="fp32" \ 40 | --optim_config="adamw:{ 41 | \"init_lr\": 5e-6, 42 | \"end_lr\": 5e-7, 43 | \"lr\": 1e-5, 44 | \"lr_warmup_steps\": 1, 45 | \"lr_decay_steps\": $TRAIN_STEPS, 46 | \"b1\": 0.9, 47 | \"b2\": 0.95, 48 | \"clip_gradient\": 1.0, 49 | \"weight_decay\": 0.01, 50 | \"bf16_momentum\": false, 51 | \"multiply_by_parameter_scale\": false, 52 | 
\"weight_decay_exclusions\": [], 53 | \"schedule\": \"cos\", 54 | \"grad_accum_steps\": 1 55 | }" \ 56 | --logger_config="{ 57 | \"online\": true, 58 | \"prefix\": \"$RUN_NAME\", 59 | \"prefix_to_id\": true 60 | }" \ 61 | --checkpointer_config="{ 62 | \"save_optimizer_state\": false, 63 | \"save_float_dtype\": \"bf16\" 64 | }" \ 65 | --model_config_override="{ 66 | \"bos_token_id\": 50256, 67 | \"eos_token_id\": 50256, 68 | \"pad_token_id\": 50256, 69 | \"remat_block\": \"nothing_saveable\", 70 | \"resid_pdrop\": 0.00, 71 | \"embd_pdrop\": 0.00, 72 | \"attn_pdrop\": 0.00, 73 | \"n_positions\": 4096 74 | }" \ 75 | --eval_bsize=512 \ 76 | --no_shuffle_train_data \ 77 | --hf_repo_id="LM-Parallel/standard-ref-test" 78 | ) 79 | 80 | 81 | -------------------------------------------------------------------------------- /tinyrl/README.md: -------------------------------------------------------------------------------- 1 | # Installation 2 | ``` 3 | conda create -n grpo python=3.10 4 | # install sgl 5 | pip install --upgrade pip 6 | pip install sgl-kernel --force-reinstall --no-deps 7 | pip install "sglang[all]>=0.4.3.post1" --find-links https://flashinfer.ai/whl/cu124/torch2.5/flashinfer-python 8 | 9 | # other utils 10 | pip install hydra-core omegaconf wandb 11 | ``` 12 | 13 | ## For APR, SGLang needs to be patched 14 | Remove this check in `python/sglang/srt/managers/tokenizer_manager.py` in your local SGLang repo: 15 | ``` 16 | # if ( 17 | # obj.sampling_params.get("max_new_tokens") is not None 18 | # and obj.sampling_params.get("max_new_tokens") + input_token_num 19 | # >= self.context_len 20 | # ): 21 | # raise ValueError( 22 | # f"Requested token count exceeds the model's maximum context length " 23 | # f"of {self.context_len} tokens. You requested a total of " 24 | # f"{obj.sampling_params.get('max_new_tokens') + input_token_num} " 25 | # f"tokens: {input_token_num} tokens from the input messages and " 26 | # f"{obj.sampling_params.get('max_new_tokens')} tokens for the " 27 | # f"completion. Please reduce the number of tokens in the input " 28 | # f"messages or the completion to fit within the limit." 29 | # ) 30 | ``` 31 | 32 | This file is at https://github.com/sgl-project/sglang/blob/45205d88a08606d5875476fbbbc76815a5107edd/python/sglang/srt/managers/tokenizer_manager.py#L350 33 | 34 | # Data Preparation 35 | Please put `sosp_train_prefix.json`, `sosp_val_prefix.json`, `apr_train_beam10_subbeam15_prefix.json`, and `apr_val_prefix.json` in the `data` folder. 36 | 37 | You can download them from [https://huggingface.co/datasets/Parallel-Reasoning/apr_rl_data](https://huggingface.co/datasets/Parallel-Reasoning/apr_rl_data). 38 | 39 | # Run 40 | Each run requires two GPUs: one for model training and one for serving with SGLang. 
41 | 42 | ## RL on APR without subcall condition 43 | ``` 44 | export CUDA_VISIBLE_DEVICES=0,1 45 | python trainer.py --config-name apr 46 | ``` 47 | 48 | Reference checkpoint: [https://huggingface.co/Parallel-Reasoning/apr_grpo](https://huggingface.co/Parallel-Reasoning/apr_grpo) 49 | 50 | ## RL on APR with subcall condition (condition set to 10) 51 | ``` 52 | export CUDA_VISIBLE_DEVICES=0,1 53 | python trainer.py --config-name apr_cond10 54 | ``` 55 | 56 | Reference checkpoint: [https://huggingface.co/Parallel-Reasoning/apr_cond10_grpo](https://huggingface.co/Parallel-Reasoning/apr_cond10_grpo) 57 | 58 | ## RL on SOS+ 59 | ``` 60 | export CUDA_VISIBLE_DEVICES=0,1 61 | python trainer.py --config-name sosp 62 | ``` 63 | 64 | Reference checkpoint: [https://huggingface.co/Parallel-Reasoning/sosp_grpo](https://huggingface.co/Parallel-Reasoning/sosp_grpo) 65 | 66 | You can set your own config files and specify different configs with `--config-name `. 67 | -------------------------------------------------------------------------------- /tinyrl/configs/apr.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _self_ 3 | - override hydra/hydra_logging: disabled 4 | - override hydra/job_logging: disabled 5 | 6 | # Data configuration 7 | data: 8 | train_data_path: data/apr_train_beam10_subbeam15_prefix.json 9 | val_data_path: data/apr_val_prefix.json 10 | train_batch_size: 64 11 | eval_batch_size: 128 12 | 13 | # Model configuration 14 | model: 15 | name: "Parallel-Reasoning/llama-apr-beam10_subbeam15" 16 | server_url: "http://localhost:0" 17 | master_port: 0 18 | mem_fraction_static: 0.6 19 | gradient_checkpointing: true # Enable gradient checkpointing to save memory 20 | 21 | # Rollout configuration 22 | rollout: 23 | mode: "apr" 24 | sample_temperature: 1.0 25 | eval_temperature: 0.0 26 | group_size: 5 27 | condition_prefix: null 28 | 29 | # Training configuration 30 | training: 31 | learning_rate: 1e-5 32 | ppo_clip_ratio: 0.2 33 | num_steps: 150 34 | inner_steps: 2 35 | kl_beta: 0.001 36 | grad_clip: 1.0 37 | # This grad_accum_chunk_size is used for gradient accumulation 38 | # It is used to avoid OOM when the batch size is too large 39 | grad_accum_chunk_size: 4 40 | log_probs_chunk_size: 16 41 | validate_per_steps: 50 42 | # Number of validation samples to use (null for full validation set), set to 16 for testing (skipping validation), set to a larger number for actual validation 43 | val_samples: null 44 | output_dir: checkpoints/apr 45 | save_steps: 10 46 | 47 | # Logging configuration 48 | logging: 49 | verbose: false 50 | wandb_project: "tinyrl-apr" 51 | wandb_name: "apr_grpo" 52 | log_interval: 1 53 | use_current_time: true 54 | 55 | hydra: 56 | run: 57 | dir: outputs/${now:%Y-%m-%d}/${now:%H-%M-%S} 58 | sweep: 59 | dir: multirun/${now:%Y-%m-%d}/${now:%H-%M-%S} 60 | subdir: ${hydra.job.num} 61 | -------------------------------------------------------------------------------- /tinyrl/configs/apr_cond10.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _self_ 3 | - override hydra/hydra_logging: disabled 4 | - override hydra/job_logging: disabled 5 | 6 | # Data configuration 7 | data: 8 | train_data_path: data/apr_train_beam10_subbeam15_prefix.json 9 | val_data_path: data/apr_val_prefix.json 10 | train_batch_size: 64 11 | eval_batch_size: 128 12 | 13 | # Model configuration 14 | model: 15 | name: "Parallel-Reasoning/llama-apr-beam10_subbeam15_num_subcall_cond" 16 | server_url: 
"http://localhost:0" 17 | master_port: 0 18 | mem_fraction_static: 0.6 19 | gradient_checkpointing: true # Enable gradient checkpointing to save memory 20 | 21 | # Rollout configuration 22 | rollout: 23 | mode: "apr" 24 | sample_temperature: 1.0 25 | eval_temperature: 1.0 26 | group_size: 5 27 | # condition_prefix: null 28 | condition_prefix: "Sub Call Budget: 10 " 29 | 30 | # Training configuration 31 | training: 32 | learning_rate: 1e-5 33 | ppo_clip_ratio: 0.2 34 | num_steps: 150 35 | inner_steps: 2 36 | kl_beta: 0.001 37 | grad_clip: 1.0 38 | # This grad_accum_chunk_size is used for gradient accumulation 39 | # It is used to avoid OOM when the batch size is too large 40 | grad_accum_chunk_size: 4 41 | log_probs_chunk_size: 16 42 | validate_per_steps: 50 43 | # Number of validation samples to use (null for full validation set), set to 16 for testing, set to a larger number for actual validation 44 | val_samples: 16 45 | output_dir: checkpoints/apr 46 | save_steps: 10 47 | 48 | # Logging configuration 49 | logging: 50 | verbose: false 51 | wandb_project: "tinyrl-apr" 52 | wandb_name: "apr_cond10_grpo" 53 | log_interval: 1 54 | use_current_time: true 55 | 56 | hydra: 57 | run: 58 | dir: outputs/${now:%Y-%m-%d}/${now:%H-%M-%S} 59 | sweep: 60 | dir: multirun/${now:%Y-%m-%d}/${now:%H-%M-%S} 61 | subdir: ${hydra.job.num} -------------------------------------------------------------------------------- /tinyrl/configs/sosp.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _self_ 3 | - override hydra/hydra_logging: disabled 4 | - override hydra/job_logging: disabled 5 | 6 | # Data configuration 7 | data: 8 | train_data_path: data/sosp_train_prefix.json # To be specified by user 9 | val_data_path: data/sosp_val_prefix.json # To be specified by user 10 | train_batch_size: 64 11 | eval_batch_size: 128 12 | 13 | # Model configuration 14 | model: 15 | name: "Parallel-Reasoning/llama-sosp" 16 | server_url: "http://localhost:29215" 17 | master_port: 45516 18 | mem_fraction_static: 0.9 19 | gradient_checkpointing: true # Enable gradient checkpointing to save memory 20 | 21 | # Rollout configuration 22 | rollout: 23 | mode: "sos" 24 | sample_temperature: 1.0 25 | eval_temperature: 0.5 26 | group_size: 5 27 | 28 | # Training configuration 29 | training: 30 | learning_rate: 1e-5 31 | ppo_clip_ratio: 0.2 32 | num_steps: 150 33 | inner_steps: 2 34 | kl_beta: 0.01 35 | grad_clip: 1.0 36 | # This grad_accum_chunk_size is used for gradient accumulation 37 | # It is used to avoid OOM when the batch size is too large 38 | grad_accum_chunk_size: 4 39 | log_probs_chunk_size: 16 40 | validate_per_steps: 25 41 | # Number of validation samples to use (null for full validation set), set to 16 for testing, set to a larger number for actual validation 42 | val_samples: null 43 | output_dir: checkpoints/sosp 44 | save_steps: 50 45 | 46 | # Logging configuration 47 | logging: 48 | verbose: false 49 | wandb_project: "tinyrl" 50 | wandb_name: "sosp_grpo" 51 | log_interval: 1 52 | use_current_time: true 53 | 54 | hydra: 55 | run: 56 | dir: outputs/${now:%Y-%m-%d}/${now:%H-%M-%S} 57 | sweep: 58 | dir: multirun/${now:%Y-%m-%d}/${now:%H-%M-%S} 59 | subdir: ${hydra.job.num} 60 | -------------------------------------------------------------------------------- /tinyrl/data/.gitkeep: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Parallel-Reasoning/APR/8936e8db46bf938242bf5e0a6ebe79ff48ba267a/tinyrl/data/.gitkeep -------------------------------------------------------------------------------- /tinyrl/requirements.txt: -------------------------------------------------------------------------------- 1 | torch 2 | sglang 3 | wandb 4 | hydra-core 5 | termcolor 6 | -------------------------------------------------------------------------------- /tinyrl/rollout/apr_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | export CUDA_VISIBLE_DEVICES=7 3 | python -m sglang.launch_server --model-path LM-Parallel/llama-apr-v3 --host localhost --served-model-name model 4 | """ 5 | import re 6 | from transformers import AutoTokenizer 7 | from litellm import text_completion 8 | from concurrent.futures import ThreadPoolExecutor, as_completed 9 | import sglang 10 | 11 | # Initialize tokenizer 12 | MODEL_NAME = "LM-Parallel/llama-apr-v3" 13 | tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) 14 | print(f"Tokenizer loaded for {MODEL_NAME}") 15 | 16 | VERBOSE = False 17 | 18 | def check_solution(prefix: str, solution: str) -> bool: 19 | """ 20 | Parses the prefix and solution to verify if the solution actually 21 | solves the puzzle of reaching `target` from the given list of numbers. 22 | 23 | :param prefix: A line like: 24 | "Moving to Node #1\nCurrent State: 62:[5, 50, 79, 27], Operations: []" 25 | :param solution: The multiline string describing the step-by-step solution. 26 | :return: True if the solution's final result matches the target and 27 | all stated operations are valid. False otherwise. 28 | """ 29 | # ----------------------------------------------------------------- 30 | # 1. Parse the prefix to extract target and initial numbers 31 | # ----------------------------------------------------------------- 32 | # Example prefix line to parse: 33 | # Current State: 62:[5, 50, 79, 27], Operations: [] 34 | # 35 | # We'll look for something matching: 36 | # Current State: :[], ... 37 | prefix_pattern = r"Current State:\s*(\d+):\[(.*?)\]" 38 | match = re.search(prefix_pattern, prefix) 39 | if not match: 40 | if VERBOSE: 41 | print("ERROR: Could not parse the prefix for target and numbers.") 42 | return False 43 | 44 | target_str, numbers_str = match.groups() 45 | target = int(target_str.strip()) 46 | # Now parse something like "5, 50, 79, 27" into a list of integers 47 | if numbers_str.strip(): 48 | initial_numbers = [int(x.strip()) for x in numbers_str.split(",")] 49 | else: 50 | initial_numbers = [] 51 | 52 | # We'll keep track of our current working list of numbers 53 | current_numbers = initial_numbers 54 | 55 | # ----------------------------------------------------------------- 56 | # 2. Parse solution to extract lines with "Exploring Operation:" 57 | # ----------------------------------------------------------------- 58 | # Example lines: 59 | # Exploring Operation: 79-27=52, Resulting Numbers: [5, 50, 52] 60 | # We want to parse out: operand1=79, operator='-', operand2=27, result=52 61 | # Then parse the new list: [5, 50, 52] 62 | 63 | operation_pattern = r"Exploring Operation:\s*([\d]+)([\+\-\*/])([\d]+)=(\d+),\s*Resulting Numbers:\s*\[(.*?)\]" 64 | 65 | # We'll process the solution line-by-line 66 | # so that we can also capture the final "Goal Reached" line. 
67 | lines = solution.splitlines() 68 | 69 | for line in lines: 70 | line = line.strip() 71 | 72 | # Check for "Exploring Operation" 73 | op_match = re.search(operation_pattern, line) 74 | if op_match: 75 | # Parse out the operation parts 76 | x_str, op, y_str, z_str, new_nums_str = op_match.groups() 77 | x_val = int(x_str) 78 | y_val = int(y_str) 79 | z_val = int(z_str) 80 | 81 | # Parse the new list of numbers from something like "5, 50, 52" 82 | new_numbers = [] 83 | if new_nums_str.strip(): 84 | new_numbers = [int(n.strip()) for n in new_nums_str.split(",")] 85 | 86 | # ------------------------------------------------------------- 87 | # Verify that applying X op Y => Z to current_numbers is valid 88 | # ------------------------------------------------------------- 89 | # 1. X and Y must both be present in current_numbers 90 | # 2. Remove X and Y from current_numbers 91 | # 3. Add Z 92 | # 4. The new list must match exactly the "Resulting Numbers" 93 | # 5. Also verify the arithmetic was correct (if you want to be strict) 94 | # 95 | # NOTE: we do not handle repeating values carefully here if X or Y 96 | # appear multiple times, but you can adapt as needed (e.g. remove once). 97 | 98 | temp_list = current_numbers[:] 99 | 100 | # Try removing X, Y once each 101 | try: 102 | temp_list.remove(x_val) 103 | temp_list.remove(y_val) 104 | except ValueError: 105 | if VERBOSE: 106 | print(f"ERROR: {x_val} or {y_val} not found in current_numbers {current_numbers}.") 107 | return False 108 | 109 | # Check that the stated Z matches the arithmetic operation 110 | # (If you want to skip verifying the math, remove these lines.) 111 | computed_result = None 112 | if op == '+': 113 | computed_result = x_val + y_val 114 | elif op == '-': 115 | computed_result = x_val - y_val 116 | elif op == '*': 117 | computed_result = x_val * y_val 118 | elif op == '/': 119 | # watch for zero division or non-integer division if you want 120 | # to require exact integer results 121 | if y_val == 0: 122 | if VERBOSE: 123 | print("ERROR: Division by zero encountered.") 124 | return False 125 | # For a typical "24 game" style puzzle, we allow float or integer check 126 | computed_result = x_val / y_val 127 | # If the puzzle requires integer arithmetic only, check remainder: 128 | # if x_val % y_val != 0: 129 | # print("ERROR: Non-integer division result.") 130 | # return False 131 | 132 | # Compare the stated z_val to the computed result 133 | # (if it's integer-based arithmetic, we might check int(...) or round) 134 | if computed_result is None: 135 | if VERBOSE: 136 | print("ERROR: Unknown operation encountered.") 137 | return False 138 | 139 | # If we want exact integer match (for e.g. 50/5=10): 140 | # If float is possible, we might do a small epsilon check: 141 | # e.g. if abs(computed_result - z_val) > 1e-9 142 | if computed_result != z_val: 143 | if VERBOSE: 144 | print(f"ERROR: Operation {x_val}{op}{y_val} does not equal {z_val}. 
Got {computed_result} instead.") 145 | return False 146 | 147 | # Now add the result to temp_list 148 | temp_list.append(z_val) 149 | # Sort if you do not care about order, or keep order if you do 150 | # and compare to new_numbers 151 | # We'll assume exact order is not critical, so let's do a sorted comparison: 152 | if sorted(temp_list) != sorted(new_numbers): 153 | if VERBOSE: 154 | print(f"ERROR: After applying {x_val}{op}{y_val}={z_val} to {current_numbers}, " 155 | f"got {sorted(temp_list)} but solution says {sorted(new_numbers)}.") 156 | return False 157 | 158 | # If we got here, it means the operation is consistent 159 | current_numbers = new_numbers 160 | 161 | # --------------------------------------------------------- 162 | # 3. Check for "Goal Reached" line 163 | # --------------------------------------------------------- 164 | # Something like: "62,62 equal: Goal Reached" 165 | # We'll check if the final single number is indeed `target`. 166 | if "Goal Reached" in line: 167 | # For a simple check, if "Goal Reached" is present, 168 | # confirm that current_numbers is [target]. 169 | if len(current_numbers) == 1 and current_numbers[0] == target: 170 | return True 171 | else: 172 | if VERBOSE: 173 | print("ERROR: 'Goal Reached' declared but final numbers don't match the target.") 174 | return False 175 | 176 | # If we never saw "Goal Reached," then it's incomplete 177 | # or didn't declare success. Return False by default 178 | if VERBOSE: 179 | print("ERROR: Did not find 'Goal Reached' in solution.") 180 | return False 181 | 182 | def get_search_result(search_trace): 183 | # Given a search trace, return the result of the search 184 | # If the search is successful, return the result optimal path 185 | # If the search is unsuccessful, return None 186 | if search_trace.count("Goal Reached") >= 2: 187 | # Find all occurrences of "Goal Reached" 188 | goal_indices = [i for i in range(len(search_trace)) if search_trace.startswith("Goal Reached", i)] 189 | # Get the second to last index, this is where we begin generate 190 | # the optimal path 191 | goal_idx = goal_indices[-2] 192 | return search_trace[goal_idx:].strip()[13:] 193 | else: 194 | return None 195 | 196 | def get_subsearch_info(search_trace): 197 | try: 198 | return _get_subsearch_info(search_trace) 199 | except Exception as e: 200 | print(f"Error at get_subsearch_info: {e}") 201 | print(search_trace) 202 | raise e 203 | 204 | def _get_subsearch_info(search_trace): 205 | # Given a search trace, return the information of the 206 | # subsearch that it wants to invoke 207 | # sub_search= {"node": "#1,1,2", "target": 39, 'nums':[2, 11], "operations": ["51-49=2", "36-25=11"]} 208 | last_line = search_trace.split("\n")[-1] 209 | assert "" in last_line, "This is not a valid subsearch trace" 210 | 211 | # --- Parse the search trace to get the generated nodes --- 212 | generated_nodes = {} 213 | # First find any "Moving to Node" lines followed by "Current State" lines 214 | lines = search_trace.split("\n") 215 | for i in range(len(lines)-1): 216 | # Moving to Node #1,1 217 | # Current State: 39:[25, 36, 2], Operations: ['51-49=2'] 218 | if "Moving to Node #" in lines[i] and "Current State:" in lines[i+1]: 219 | # Extract node id from first line like: 220 | # Moving to Node #1,1 221 | node_id = lines[i].split("Moving to Node #")[1].strip() 222 | 223 | # Extract state from second line like: 224 | # Current State: 39:[25, 36, 2], Operations: ['51-49=2'] 225 | state_line = lines[i+1] 226 | state_part = state_line.split("Current 
State:")[1].split("],")[0].strip() 227 | operations_part = state_line.split("Operations:")[1].strip() 228 | 229 | # Parse state like "39:[25, 36, 2]" 230 | target = int(state_part.split(":")[0]) 231 | # nums = eval(state_part.split(":")[1].strip()) 232 | nums = eval(state_part.split(":")[1].strip() + "]") 233 | operations = eval(operations_part) 234 | 235 | # Parse operations list 236 | 237 | generated_nodes[node_id] = { 238 | "node": f"#{node_id}", 239 | "target": target, 240 | "nums": nums, 241 | "operations": operations 242 | } 243 | for line in search_trace.split("\n"): 244 | if "Generated Node" in line: 245 | # Extract node id and info from line like: 246 | # Generated Node #1,1,2: 39:[2, 11] Operation: 36-25=11 247 | node_id = line.split(":")[0].split("#")[1] 248 | if node_id in generated_nodes: 249 | continue 250 | rest = line.split(":", 1)[1].strip() 251 | state = rest.split("Operation:")[0].strip() 252 | operation = rest.split("Operation:")[1].strip() 253 | 254 | # Parse state like "39:[2, 11]" into target and nums 255 | target = int(state.split(":")[0]) 256 | nums = eval(state.split(":")[1].strip()) 257 | 258 | parent_node_id = ",".join(node_id.split(",")[:-1]) 259 | parent_node = generated_nodes[parent_node_id] 260 | new_operations = parent_node["operations"] + [operation] 261 | 262 | generated_nodes[node_id] = { 263 | "node": f"#{node_id}", 264 | "target": target, 265 | "nums": nums, 266 | "operations": new_operations 267 | } 268 | # then we construct the sub_searches 269 | sub_search_nodes = [] 270 | # Split on and take the last chunk 271 | last_chunk = search_trace.split("\n")[-1] 272 | # Split that chunk on and take first part 273 | sub_search_section = last_chunk.split("\n")[0] 274 | 275 | for line in sub_search_section.split("\n"): 276 | if " Moving to Node #1,1,2 279 | node_id = line.split("Moving to Node #")[1].strip() 280 | sub_search_nodes.append(generated_nodes[node_id]) 281 | 282 | def construct_sub_search_prefix(node): 283 | # exmaple 284 | # "Moving to Node #1,1,0\nCurrent State: 39:[36, 50], Operations: ['51-49=2', '25*2=50'] 285 | return f"Moving to Node {node['node']}\nCurrent State: {node['target']}:[{', '.join(map(str, node['nums']))}], Operations: {node['operations']}" 286 | sub_search_prefix_list = [construct_sub_search_prefix(node) for node in sub_search_nodes] 287 | return sub_search_prefix_list, sub_search_nodes 288 | 289 | def get_main_trace_after_sub_search(main_trace, sub_search_nodes, sub_search_result_list): 290 | last_line = main_trace.split("\n")[-1] 291 | assert "" in last_line, "This is not a valid subsearch trace" 292 | sub_searches = [] 293 | # Split on and take the last chunk 294 | last_chunk = main_trace.split("\n")[-1] 295 | # Split that chunk on and take first part 296 | sub_search_section = last_chunk.split("\n")[0] 297 | main_trace_after_sub_search = "\n".join(main_trace.split("\n")[:-1]) 298 | assert main_trace_after_sub_search in main_trace 299 | main_trace_after_sub_search += "\n" 300 | for i, (this_node, this_result) in enumerate(zip(sub_search_nodes, sub_search_result_list)): 301 | if this_result is None: 302 | # 303 | main_trace_after_sub_search += f"\n" 304 | else: 305 | # \nMoving to Node #1,2,0\nCurrent State: 39:[51, 12], Operations: ['49-36=13', '25-13=12']\nExploring Operation: 51-12=39, Resulting Numbers: [39]\n39,39 equal: Goal Reached\n 306 | main_trace_after_sub_search += f"\n" 307 | main_trace_after_sub_search += this_result + "\n" 308 | main_trace_after_sub_search += "\n" 309 | return main_trace_after_sub_search 310 | 
311 | 312 | 313 | def add_angle_brackets(text): 314 | lines = text.split('\n') 315 | result_lines = [] 316 | for line in lines: 317 | if '>' in line and '<' not in line: 318 | line = '<' + line 319 | result_lines.append(line) 320 | return '\n'.join(result_lines) 321 | 322 | def generate(prefix, tokenizer, api_base_url, temperature=0.5, stop=[]): 323 | """Generate text using the model API""" 324 | bos_token = tokenizer.bos_token 325 | prefix = bos_token + prefix 326 | 327 | result = text_completion( 328 | model="openai/model", 329 | prompt=prefix, 330 | api_base=api_base_url, 331 | api_key="api_key", 332 | temperature=temperature, 333 | max_tokens=4096, 334 | stop=stop, 335 | ) 336 | 337 | text = result['choices'][0]['text'] 338 | complete_text = prefix + text 339 | complete_text = complete_text.replace(bos_token, ' ') 340 | if complete_text[0] == ' ': 341 | complete_text = complete_text[1:] 342 | return complete_text, result 343 | 344 | def decode_trace(prefix, tokenizer, api_base_url, temperature=0.5): 345 | """Decode a single trace""" 346 | while True: 347 | trace = generate(prefix, tokenizer, api_base_url, temperature=temperature, stop=[""]) 348 | llm_call_info = { 349 | "prefix": prefix, # Store the original prefix 350 | "output": trace[1]["choices"][0]["text"] 351 | } 352 | # Store the prefix and output in a dictionary 353 | prefix = trace[0] 354 | if trace[1].choices[0].matched_stop == "": 355 | prefix += "" 356 | else: 357 | break 358 | prefix = trace[0] 359 | if prefix.split('\n')[-1] == "": 360 | prefix = prefix[:-1] 361 | return prefix, llm_call_info 362 | 363 | def batch_decode_trace(prefix_list, tokenizer, api_base_url, temperature=0.5, max_workers=16): 364 | """Decode multiple traces in parallel""" 365 | with ThreadPoolExecutor(max_workers=max_workers) as executor: 366 | future_to_prefix = { 367 | executor.submit(decode_trace, prefix, tokenizer, api_base_url, temperature): prefix 368 | for prefix in prefix_list 369 | } 370 | 371 | results = [None] * len(prefix_list) 372 | llm_calls = [None] * len(prefix_list) 373 | 374 | for future in as_completed(future_to_prefix): 375 | prefix = future_to_prefix[future] 376 | try: 377 | result, llm_call_info = future.result() 378 | original_idx = prefix_list.index(prefix) 379 | results[original_idx] = result 380 | llm_calls[original_idx] = llm_call_info 381 | except Exception as e: 382 | print(f"Error processing prefix: {e}") 383 | original_idx = prefix_list.index(prefix) 384 | results[original_idx] = None 385 | llm_calls[original_idx] = None 386 | 387 | return results, llm_calls 388 | 389 | def is_calling_subsearch(trace): 390 | return "" in trace.split('\n')[-1] 391 | 392 | def call_search(prefix, api_base_url, temperature=0.5): 393 | """ 394 | Main search function that processes a given prefix 395 | 396 | Args: 397 | prefix (str): The initial state and goal 398 | api_base_url (str): The URL for the model API 399 | temperature (float): Temperature parameter for text generation 400 | 401 | Returns: 402 | tuple: (solution, trace_dict, success, true_success) 403 | """ 404 | condition_prefix = prefix.split("\n")[0].split("Moving to Node")[0] 405 | try: 406 | trace_dict = {"main_calls": [], "sub_calls": [], "llm_calls": []} 407 | trace, llm_call_info = decode_trace(prefix, tokenizer, api_base_url, temperature=temperature) 408 | trace_dict["main_calls"].append(trace) 409 | trace_dict["llm_calls"].append(llm_call_info) 410 | 411 | while is_calling_subsearch(trace): 412 | sub_search_prefix_list, sub_search_nodes = get_subsearch_info(trace) 413 | 
for idx, sub_search_prefix in enumerate(sub_search_prefix_list): 414 | if condition_prefix.startswith("Sub Call Budget: "): 415 | # we don't need to add the condition prefix to the sub search prefix if the condition prefix starts with "Sub Call Budget: " 416 | sub_search_prefix_list[idx] = sub_search_prefix 417 | elif condition_prefix.startswith("Token Budget: "): 418 | # we need to add the condition prefix to the sub search prefix if the condition prefix starts with "Token Budget: " 419 | sub_search_prefix_list[idx] = condition_prefix + sub_search_prefix 420 | elif condition_prefix: 421 | raise ValueError(f"Unknown condition prefix: {condition_prefix}") 422 | sub_search_traces, sub_search_llm_calls = batch_decode_trace(sub_search_prefix_list, tokenizer, api_base_url, temperature=temperature) 423 | 424 | trace_dict["sub_calls"].append([]) 425 | for sub_search_trace, sub_search_llm_call in zip(sub_search_traces, sub_search_llm_calls): 426 | trace_dict["sub_calls"][-1].append({ 427 | "main_calls": [sub_search_trace], 428 | "llm_calls": [sub_search_llm_call] 429 | }) 430 | 431 | sub_search_results = [get_search_result(trace) for trace in sub_search_traces] 432 | new_prefix = get_main_trace_after_sub_search(trace, sub_search_nodes, sub_search_results) 433 | trace, llm_call_info = decode_trace(new_prefix, tokenizer, api_base_url, temperature=temperature) 434 | trace_dict["main_calls"].append(trace) 435 | trace_dict["llm_calls"].append(llm_call_info) 436 | 437 | solution = get_search_result(trace) 438 | success = solution is not None 439 | true_success = check_solution(prefix, solution) if success else False 440 | 441 | return solution, trace_dict, success, true_success 442 | 443 | except Exception as e: 444 | print(f"Error in call_search: {e}") 445 | return None, trace_dict, False, False 446 | 447 | def process_single_prefix(prefix, server_url, bos_token, temperature=0.5): 448 | """Helper function to process a single prefix for parallel execution""" 449 | solution, trace_dict, success, true_success = call_search(prefix, server_url, temperature) 450 | seqs = [] 451 | for lm_call in trace_dict['llm_calls']: 452 | seqs.append(lm_call) 453 | for sub_calls in trace_dict['sub_calls']: 454 | for sub_call in sub_calls: 455 | for lm_call in sub_call['llm_calls']: 456 | seqs.append(lm_call) 457 | return { 458 | "seqs": seqs, 459 | "is_correct": true_success 460 | } 461 | 462 | def rollout_apr(server_url, prefix_list, bos_token, temperature=0.5, max_workers=32, condition_prefix="" 463 | ): 464 | """ 465 | Parallel implementation of rollout function using ThreadPoolExecutor 466 | """ 467 | with ThreadPoolExecutor(max_workers=max_workers) as executor: 468 | # Create futures with index tracking 469 | futures = [] 470 | for idx, prefix in enumerate(prefix_list): 471 | future = executor.submit( 472 | process_single_prefix, 473 | condition_prefix + prefix, 474 | server_url, 475 | bos_token, 476 | temperature 477 | ) 478 | futures.append((idx, future)) 479 | 480 | # Initialize results list with correct size 481 | results = [None] * len(prefix_list) 482 | 483 | # Collect results as they complete 484 | for idx, future in futures: 485 | try: 486 | result = future.result() 487 | results[idx] = result 488 | except Exception as e: 489 | print(f"Error processing prefix at index {idx}: {e}") 490 | results[idx] = { 491 | "seqs": [], 492 | "is_correct": False 493 | } 494 | # if the result is None, something's wrong in the decoding process 495 | # we use a dummy result to avoid breaking the loop 496 | for idx in 
range(len(results)): 497 | if not results[idx]: 498 | print(f"Error processing prefix {prefix_list[idx]} at index {idx}") 499 | results[idx] = { 500 | "seqs": [], 501 | "is_correct": False 502 | } 503 | for result_idx in range(len(results)): 504 | for seq_id in reversed(range(len(results[result_idx]['seqs']))): 505 | if not results[result_idx]['seqs'][seq_id]: 506 | # if one seq is None, we remove that seq 507 | results[result_idx]['seqs'].pop(seq_id) 508 | print(f"Removing seq {seq_id} from result {result_idx} since it's None") 509 | return results -------------------------------------------------------------------------------- /tinyrl/rollout/sos_utils.py: -------------------------------------------------------------------------------- 1 | from litellm import text_completion 2 | import re 3 | def _parse_trajectory(search_path): 4 | # Find the first occurrence of "Current State" and trim everything before it 5 | start_idx = search_path.find("Current State") 6 | if start_idx == -1: 7 | return "Invalid input: Cannot find the initial state." 8 | search_path = search_path[start_idx:] 9 | 10 | # Extracting the target and initial numbers from the first line 11 | first_line = search_path.strip().split('\n')[0] 12 | 13 | # if mode == "dt": 14 | # first_line = first_line.split("->")[1] 15 | target_nums_match = re.match(r"Current State: (\d+):\[(.*?)\]", first_line) 16 | if not target_nums_match: 17 | return "Invalid input: Cannot find the initial state in the first line." 18 | 19 | target, nums = int(target_nums_match.group(1)), [int(n) for n in target_nums_match.group(2).split(", ")] 20 | 21 | # Extract the operations from the line that claims the goal is reached. 22 | goal_lines = re.finditer(r"\d+,\d+ equal: Goal Reached", search_path) 23 | goal_lines = list(goal_lines) 24 | if not goal_lines: 25 | return "No goal reached statement found." 26 | 27 | goal_line = goal_lines[-1] 28 | # get the last operation line before the goal reached statement 29 | operations = re.findall(r"Exploring Operation: (.*?=\d+), Resulting Numbers: \[(.*?)\]", 30 | search_path[:goal_line.start()]) 31 | if not operations: 32 | return "No operations found leading to the goal." 33 | 34 | final_operation = operations[-1][0] 35 | try: 36 | predicted_result = int(final_operation.split('=')[1]) 37 | except: 38 | print("couldnt parse last op", final_operation) 39 | return "Couldnt parse last op" 40 | if predicted_result != target: 41 | return "Invalid path: Final operation does not result in target." 42 | 43 | # get the last current state, operations before the goal reached statement, and extract the operations 44 | try: 45 | core_path = search_path[:goal_line.start()].split("Goal Reached\n")[1] 46 | except: 47 | print("invalid, no summarized answer") 48 | return "Invalid path: no summarized answer." 
49 | operation_list = re.findall(r"Current State: \d+:\[.*?\], Operations: \[(.*?)\]", core_path)[ 50 | -1].split(', ') 51 | operation_list = [op.replace("'", "") for op in operation_list] 52 | operation_list += [final_operation] 53 | 54 | # Verify each operation and keep track of the numbers involved 55 | available_numbers = nums 56 | for operation in operation_list: 57 | # Verify the operation 58 | try: 59 | left, right = operation.split('=') 60 | except: 61 | return f"Could not split operation into lhs, rhs" 62 | try: 63 | if eval(left) != int(right): 64 | return f"Invalid operation: {operation}" 65 | except Exception as e: 66 | return f"Error in evaluating operation {operation}: {e}" 67 | # get the numbers involved 68 | used_numbers = re.findall(r"\d+", left) 69 | for n in used_numbers: 70 | if int(n) not in available_numbers: 71 | return f"Invalid operation: {operation}, number {n} not available in {available_numbers}" 72 | 73 | available_numbers = [n for n in available_numbers if n not in used_numbers] 74 | available_numbers.append(int(right)) 75 | 76 | return "Valid path." 77 | 78 | def is_correct(search_path): 79 | try: 80 | return _parse_trajectory(search_path) == "Valid path." 81 | except Exception as e: 82 | print(f"Error in is_correct: {e}, treating as incorrect") 83 | return False 84 | 85 | 86 | def rollout_sos(server_url, prefix_list, bos_token, temperature=0.5, condition_prefix=None): 87 | if condition_prefix is not None: 88 | assert not prefix_list[0].startswith("Token Budget: "), f"Condition prefix already in the prefix: {prefix_list[0]}, please use data without condition prefix" 89 | all_prefixes = [bos_token + condition_prefix + prefix for prefix in prefix_list] 90 | else: 91 | all_prefixes = [bos_token + prefix for prefix in prefix_list] 92 | 93 | outputs = text_completion( 94 | model="openai/model", 95 | prompt=all_prefixes, 96 | api_base=server_url, 97 | api_key="api_key", 98 | temperature=temperature, 99 | max_tokens=4000, 100 | ) 101 | return_dict = [] 102 | for output, prefix in zip(outputs["choices"], all_prefixes): 103 | whole_text = prefix + output["text"] 104 | return_dict.append({ 105 | 'seqs': [ 106 | { 107 | "prefix": prefix, 108 | "output": output["text"], 109 | } 110 | ], 111 | "is_correct": is_correct(whole_text) 112 | }) 113 | return return_dict 114 | -------------------------------------------------------------------------------- /tinyrl/utils.py: -------------------------------------------------------------------------------- 1 | import subprocess 2 | import time 3 | import requests 4 | from typing import Optional 5 | from sglang.srt.utils import kill_process_tree 6 | 7 | def popen_launch_server( 8 | model: str, 9 | base_url: str, 10 | timeout: float, 11 | model_name: str = "model", 12 | api_key: Optional[str] = None, 13 | other_args: list[str] = (), 14 | env: Optional[dict] = None, 15 | return_stdout_stderr: Optional[tuple] = None, 16 | ): 17 | _, host, port = base_url.split(":") 18 | host = host[2:] 19 | 20 | command = [ 21 | "python3", 22 | "-m", 23 | "sglang.launch_server", 24 | "--model-path", 25 | model, 26 | "--host", 27 | host, 28 | "--port", 29 | port, 30 | "--served-model-name", 31 | model_name, 32 | *other_args, 33 | ] 34 | 35 | if api_key: 36 | command += ["--api-key", api_key] 37 | 38 | if return_stdout_stderr: 39 | process = subprocess.Popen( 40 | command, 41 | stdout=return_stdout_stderr[0], 42 | stderr=return_stdout_stderr[1], 43 | env=env, 44 | text=True, 45 | ) 46 | else: 47 | process = subprocess.Popen( 48 | command, 49 | 
stdout=subprocess.DEVNULL, 50 | stderr=subprocess.DEVNULL, 51 | env=env 52 | ) 53 | 54 | start_time = time.time() 55 | with requests.Session() as session: 56 | while time.time() - start_time < timeout: 57 | try: 58 | headers = { 59 | "Content-Type": "application/json; charset=utf-8", 60 | "Authorization": f"Bearer {api_key}", 61 | } 62 | response = session.get( 63 | f"{base_url}/health_generate", 64 | headers=headers, 65 | ) 66 | if response.status_code == 200: 67 | return process 68 | except requests.RequestException: 69 | pass 70 | time.sleep(10) 71 | raise TimeoutError("Server failed to start within the timeout period.") 72 | 73 | def terminate_process(process): 74 | kill_process_tree(process.pid) 75 | --------------------------------------------------------------------------------
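
The helpers above are the building blocks the RL configs rely on: `popen_launch_server` / `terminate_process` (tinyrl/utils.py) manage the SGLang inference server, while `rollout_sos` and `rollout_apr` (tinyrl/rollout/) turn countdown prefixes into scored rollouts. A minimal, illustrative sketch of how these pieces might be wired together — assuming it is run from the `tinyrl/` directory, using the `Parallel-Reasoning/llama-sosp` checkpoint and the server address from `configs/sosp.yaml`, with the example prefix format taken from the docstrings above — could look like this (not a verbatim excerpt of `trainer.py`):

```python
import os

from transformers import AutoTokenizer

from rollout.sos_utils import rollout_sos
from utils import popen_launch_server, terminate_process

MODEL = "Parallel-Reasoning/llama-sosp"   # model name from configs/sosp.yaml
BASE_URL = "http://localhost:29215"        # server_url from configs/sosp.yaml

# Pick the GPU(s) before the SGLang server subprocess is spawned.
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "0")

tokenizer = AutoTokenizer.from_pretrained(MODEL)

# Launch the SGLang server and poll /health_generate for up to 10 minutes.
server = popen_launch_server(MODEL, BASE_URL, timeout=600)
try:
    # Example countdown prefix in the format the rollout utilities expect.
    prefixes = [
        "Moving to Node #1\nCurrent State: 62:[5, 50, 79, 27], Operations: []",
    ]
    # Depending on the litellm/sglang setup, BASE_URL may need a "/v1" suffix here.
    results = rollout_sos(
        server_url=BASE_URL,
        prefix_list=prefixes,
        bos_token=tokenizer.bos_token,
        temperature=0.5,
    )
    for r in results:
        print(r["is_correct"], r["seqs"][0]["output"][:200])
finally:
    terminate_process(server)
```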