├── .gitignore ├── LICENSE ├── README.md ├── create_finetuning_data_from_refinements.py ├── environment.yml ├── eval_mbpp.py ├── finetune.py ├── finetune_refinement_model.py ├── generate_code_for_mbpp.py ├── generate_refinements_codegen_finetuned.py ├── ilf_for_code_gen.pdf ├── ilf_pipeline.sh ├── preprocess_feedback_spreadsheet.py └── surge_annotations.jsonl /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 ML² AT CILVR 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Improving Code Generation by Training with Natural Language Feedback 2 | Authors: Angelica Chen, Jérémy Scheurer, Tomasz Korbak, Jon Ander Campos, Jun Shern Chan, Samuel R. Bowman, Kyunghyun Cho, Ethan Perez 3 | 4 | This repository contains the code and data (human-written feedback and refinements) for running the Imitation learning from Language Feedback (ILF) algorithm 5 | for code generation from "Improving Code Generation by Training with Natural Language Feedback" by [Chen et al. (2023)](https://arxiv.org/abs/2303.16749). This paper has since been superseded by our TMLR publication, ["Learning from Natural Language Feedback"](https://openreview.net/forum?id=xo3hI5MwvU). 6 | 7 |
8 | 9 |
10 | 11 | ## Installation 12 | 13 | Our code relies upon the [`jaxformer` repository](https://github.com/salesforce/jaxformer) and open-source [CodeGen-Mono checkpoints](https://github.com/salesforce/CodeGen). 14 | 15 | To install all dependencies and download the necessary model checkpoints: 16 | ```{bash} 17 | git clone git@github.com:nyu-mll/ILF-for-code-generation.git 18 | cd ILF-for-code-generation 19 | conda env create -f environment.yml 20 | 21 | # Install codegen repo and reset to old commit 22 | git clone git@github.com:salesforce/CodeGen.git 23 | cd CodeGen 24 | git reset --hard 9cc1f971c83ad606cce5da292d3c58523dd920a2 25 | git clean -df 26 | pip3 install -r requirements.txt 27 | cd .. 28 | 29 | # To download codegen-6B-mono 30 | wget -P checkpoints https://storage.googleapis.com/sfr-codegen-research/checkpoints/codegen-6B-mono.tar.gz && tar -xvf checkpoints/codegen-6B-mono.tar.gz -C checkpoints/ 31 | 32 | ``` 33 | 34 | In our paper we use the CodeGen-Mono 6B checkpoint, but you can easily replace the above `wget` command with the download links for the [other CodeGen models](https://github.com/salesforce/CodeGen#sampling-with-repository). 35 | 36 | ## To run the ILF pipeline 37 | To run the ILF pipeline using our dataset, run (from this directory): 38 | ```{bash} 39 | source ilf_pipeline.sh -d $(pwd) -n <name> 40 | ``` 41 | with `<name>` replaced with the name of the subdirectory that you wish to store results in. 42 | 43 | ## Citation 44 | ``` 45 | @article{ 46 | chen2024learning, 47 | title={Learning from Natural Language Feedback}, 48 | author={Angelica Chen and J{\'e}r{\'e}my Scheurer and Jon Ander Campos and Tomasz Korbak and Jun Shern Chan and Samuel R. Bowman and Kyunghyun Cho and Ethan Perez}, 49 | journal={Transactions on Machine Learning Research}, 50 | issn={2835-8856}, 51 | year={2024}, 52 | url={https://openreview.net/forum?id=xo3hI5MwvU}, 53 | note={} 54 | } 55 | ``` 56 | -------------------------------------------------------------------------------- /create_finetuning_data_from_refinements.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | import re 4 | 5 | from datasets import Dataset, load_dataset, concatenate_datasets 6 | 7 | 8 | def format_prompt(mbpp, task_id): 9 | idx = mbpp["task_id"].index(task_id) 10 | text = mbpp["text"][idx] 11 | tests = mbpp["test_list"][idx] 12 | sample_code = mbpp["code"][idx] 13 | 14 | # Create prompt from scratch 15 | prompt = f'"""\n{text}\n\n' 16 | # Add the first unit test as an input-output example 17 | example = tests[0].split("assert ")[-1].replace("==", "=") 18 | prompt += f">>> Example: {example}\n" 19 | 20 | # Add code prefix 21 | fn_name = tests[0].split("assert ")[-1].split("(")[0] 22 | fn_search = re.search(f"def {fn_name}\(.*\):", sample_code) 23 | if fn_search is None: 24 | raise ValueError( 25 | f"Could not find 'def {fn_name}\(.*\):' in code for task {task_id}." 
26 | ) 27 | code_prefix = sample_code[: fn_search.end()] 28 | prompt = f'{prompt}"""\n\n{code_prefix}\n' 29 | return prompt 30 | 31 | 32 | def load_scored_data(feedback_path): 33 | d = load_dataset("json", data_files={"train": feedback_path})["train"].map( 34 | lambda _, idx: {"row_id": idx}, 35 | with_indices=True, 36 | ) 37 | print(f"Initial length of d: {len(d)}") 38 | d = d.filter(lambda example: example["passed"]) 39 | print(f"Length of d after filtering for passed: {len(d)}") 40 | return d 41 | 42 | 43 | def dedupe_dataset(dataset): 44 | cols = dataset.column_names 45 | row_set = set() 46 | for ex in dataset: 47 | ex_tuple = tuple(ex[col] for col in cols) 48 | row_set.add(ex_tuple) 49 | deduped = {k: [row[i] for row in row_set] for i, k in enumerate(cols)} 50 | return Dataset.from_dict(deduped) 51 | 52 | 53 | def remove_prefix_and_func_sig(code, func_sig): 54 | if f"{func_sig}\r\n" in code: 55 | return code[code.rfind(f"{func_sig}\r\n") + len(f"{func_sig}\r\n") :] 56 | elif f"{func_sig} \r\n" in code: 57 | return code[code.rfind(f"{func_sig} \r\n") + len(f"{func_sig} \r\n") :] 58 | elif f"{func_sig}\n" in code: 59 | return code[code.rfind(f"{func_sig}\n") + len(f"{func_sig}\n") :] 60 | elif f"{func_sig}" in code: 61 | return code[code.rfind(f"{func_sig}") + len(f"{func_sig}") :] 62 | else: 63 | return code 64 | 65 | 66 | def get_completion(prompt, completion): 67 | """If 'REFINEMENT:' is in the completion, remove it. Also remove prompt prefix if present.""" 68 | ref_str = "REFINEMENT:" 69 | if ref_str in completion: 70 | idx = completion.rfind(ref_str) 71 | completion = completion[idx + len(ref_str) :] 72 | if prompt in completion: 73 | idx = completion.rfind(prompt) 74 | completion = completion[idx + len(prompt) :] 75 | return completion 76 | 77 | 78 | def create_prompts(args): 79 | mbpp = load_dataset("mbpp") 80 | mbpp = concatenate_datasets([mbpp[k] for k in mbpp.keys()]) 81 | ref_data = load_scored_data(args.refinement_file) 82 | print(f"Length of scored data: {len(ref_data)}") 83 | 84 | # Get unique pairs of (task ID, prompt) from the scored refinements. 
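    # The prompt is kept alongside the task ID because its trailing "def ...:" function
    # signature is recovered again below (via rfind) and stripped from the gold MBPP
    # code, so that gold completions start right after the signature, just like the refinements.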
85 | tasks = set([(example["task_id"], example["prompt"]) for example in ref_data]) 86 | 87 | if not args.no_output_gold_data: 88 | mbpp_ft_data = { 89 | "finetuning_prompt": [], 90 | "finetuning_completion": [], 91 | "task_id": [], 92 | } 93 | task_id_to_func_sig = {} 94 | for task_id, prompt in tasks: 95 | mbpp_idx = mbpp["task_id"].index(task_id) 96 | 97 | # Get the original reformatted MBPP prompt 98 | orig_prompt = format_prompt(mbpp, task_id) 99 | 100 | # Remove method signature prefix 101 | gold_code = mbpp["code"][mbpp_idx] 102 | sig_idx = prompt.rfind("def ") 103 | colon_idx = prompt.rfind(":") 104 | func_sig = prompt[sig_idx : colon_idx + 1] 105 | task_id_to_func_sig[task_id] = func_sig 106 | gold_code = remove_prefix_and_func_sig(gold_code, func_sig) 107 | if gold_code is None: 108 | logging.warning( 109 | f"Could not find function signature {func_sig} in gold code.\nGold code:\n{gold_code}" 110 | ) 111 | continue 112 | mbpp_ft_data["finetuning_prompt"].append(orig_prompt) 113 | mbpp_ft_data["finetuning_completion"].append(gold_code) 114 | mbpp_ft_data["task_id"].append(task_id) 115 | mbpp_ft_data = Dataset.from_dict(mbpp_ft_data) 116 | 117 | if args.sample_size is not None: 118 | n = min(len(mbpp_ft_data), args.sample_size) 119 | mbpp_ft_data = mbpp_ft_data.shuffle().select(range(n)) 120 | mbpp_ft_data.to_json( 121 | f"{args.output_dir}/finetuning_prompts_mbpp_gold_{args.output_file_suffix}.jsonl" 122 | ) 123 | 124 | refs_ft_data = ref_data.map( 125 | lambda ex: { 126 | "finetuning_prompt": format_prompt(mbpp, ex["task_id"]), 127 | } 128 | ).map( 129 | lambda ex: { 130 | "finetuning_completion": get_completion( 131 | ex["finetuning_prompt"], ex["completion"] 132 | ) 133 | } 134 | ) 135 | cols_to_remove = list( 136 | set(refs_ft_data.column_names) 137 | - set(["task_id", "finetuning_prompt", "finetuning_completion"]) 138 | ) 139 | refs_ft_data = refs_ft_data.remove_columns(cols_to_remove) 140 | refs_ft_data = dedupe_dataset(refs_ft_data) 141 | if args.one_per_task: 142 | df = refs_ft_data.shuffle().to_pandas() 143 | df = df.groupby("task_id").first() 144 | refs_ft_data = Dataset.from_pandas(df) 145 | 146 | if args.sample_size is not None: 147 | n = min(len(refs_ft_data), args.sample_size) 148 | refs_ft_data = refs_ft_data.shuffle().select(range(n)) 149 | refs_ft_data.to_json( 150 | f"{args.output_dir}/finetuning_prompts_mbpp_refinements_{args.output_file_suffix}.jsonl" 151 | ) 152 | 153 | 154 | def parse_args(input_args): 155 | parser = argparse.ArgumentParser( 156 | description="Generate fine-tuning prompts from model-generated refinements. Also generate FT prompts for those same task IDs from the original MBPP dataset using gold code." 157 | ) 158 | parser.add_argument( 159 | "--refinement-file", 160 | type=str, 161 | help="Path to file containing evaluated refinements. Needs to have the following columns: passed, task_id, prompt, completion.", 162 | ) 163 | parser.add_argument( 164 | "--output-dir", type=str, help="Directory to output data files in." 
165 | ) 166 | parser.add_argument( 167 | "--no-output-gold-data", 168 | action="store_true", 169 | help="If set, will not output finetuning files for gold completions.", 170 | ) 171 | parser.add_argument("--output-file-suffix", type=str, default="") 172 | parser.add_argument( 173 | "-n", 174 | "--sample-size", 175 | default=None, 176 | type=int, 177 | help="If set, will limit the number of outputs to this value.", 178 | ) 179 | parser.add_argument( 180 | "--one-per-task", 181 | action="store_true", 182 | help="If set, will randomly select one correct refinement per task.", 183 | ) 184 | args = parser.parse_args() 185 | return args 186 | 187 | 188 | def main(): 189 | args = parse_args(None) 190 | create_prompts(args) 191 | 192 | 193 | if __name__ == "__main__": 194 | main() 195 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: ilf 2 | channels: 3 | - pytorch 4 | - conda-forge 5 | - defaults 6 | dependencies: 7 | - cudatoolkit=11.6.0=hecad31d_10 8 | - pip=22.1.2 9 | - python=3.7.13=h12debd9_0 10 | - pytorch=1.12.1=py3.7_cuda11.6_cudnn8.3.2_0 11 | - pytorch-mutex=1.0=cuda 12 | - readline=8.1.2=h7f8727e_1 13 | - setuptools=63.4.1 14 | - pip: 15 | - argparse==1.4.0 16 | - datasets==2.7.1 17 | - evaluate==0.3.0 18 | - huggingface-hub==0.9.1 19 | - matplotlib==3.5.3 20 | - nltk==3.7 21 | - numpy==1.21.6 22 | - openai==0.23.0 23 | - pytest==7.2.2 24 | - python-dateutil==2.8.2 25 | - pytz==2022.2.1 26 | - regex==2022.9.13 27 | - sacremoses==0.0.53 28 | - scikit-learn==1.0.2 29 | - scipy==1.7.3 30 | - six==1.16.0 31 | - sklearn==0.0 32 | - timeout-decorator==0.5.0 33 | - tokenizers==0.10.3 34 | - tqdm==4.64.1 35 | - transformers==4.12.5 36 | -------------------------------------------------------------------------------- /eval_mbpp.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import gzip 3 | import io 4 | import itertools 5 | import json 6 | import pprint 7 | import numpy as np 8 | import re 9 | import sys 10 | import timeout_decorator 11 | import traceback 12 | 13 | 14 | from collections import defaultdict 15 | from datasets import concatenate_datasets, load_dataset 16 | from multiprocessing import Process, Queue 17 | from tqdm import tqdm 18 | from typing import Dict, List, Union 19 | 20 | 21 | def parse_args(): 22 | parser = argparse.ArgumentParser( 23 | description="Evaluate model completions on the MBPP benchmark." 24 | ) 25 | parser.add_argument( 26 | "--input-file", 27 | type=str, 28 | help="File containing columns 'prompt' (or the column named by --prompt-column-name), 'completion', and 'task_id'.", 29 | ) 30 | parser.add_argument("--k", default="1,10") 31 | parser.add_argument("--file-suffix", default="results") 32 | parser.add_argument( 33 | "--prompt-column-name", default="prompt", help="Name of prompt column." 34 | ) 35 | args = parser.parse_args() 36 | return args 37 | 38 | 39 | def estimate_pass_at_k( 40 | num_samples: Union[int, List[int], np.ndarray], 41 | num_correct: Union[List[int], np.ndarray], 42 | k: int, 43 | ) -> np.ndarray: 44 | """ 45 | Estimates pass@k of each problem and returns them in an array. 46 | Taken from https://github.com/openai/human-eval/blob/master/human_eval/evaluation.py#L13. 47 | """ 48 | 49 | def estimator(n: int, c: int, k: int) -> float: 50 | """ 51 | Calculates 1 - comb(n - c, k) / comb(n, k). 
52 | """ 53 | if n - c < k: 54 | return 1.0 55 | return 1.0 - np.prod(1.0 - k / np.arange(n - c + 1, n + 1)) 56 | 57 | if isinstance(num_samples, int): 58 | num_samples_it = itertools.repeat(num_samples, len(num_correct)) 59 | else: 60 | assert len(num_samples) == len(num_correct) 61 | num_samples_it = iter(num_samples) 62 | 63 | return np.array( 64 | [estimator(int(n), int(c), k) for n, c in zip(num_samples_it, num_correct)] 65 | ) 66 | 67 | 68 | def compute_results(eval_results): 69 | results = defaultdict(list) 70 | for row in eval_results: 71 | ti = row["task_id"] 72 | passed = row["passed"] 73 | results[ti].append(passed) 74 | outputs = { 75 | ti: {"num_correct": np.sum(r), "num_total": len(r)} for ti, r in results.items() 76 | } 77 | return outputs 78 | 79 | 80 | def compute_at_least_one_pass_per_task(results): 81 | total = 0 82 | task_ids = [] 83 | for task_id, results_dict in results.items(): 84 | if results_dict["num_correct"] > 0: 85 | total += 1 86 | task_ids.append(task_id) 87 | return total, task_ids 88 | 89 | 90 | def compute_pass_at_ks(results, ks): 91 | output = { 92 | k: estimate_pass_at_k( 93 | [x["num_total"] for _, x in results.items()], 94 | [x["num_correct"] for _, x in results.items()], 95 | k, 96 | ).mean() 97 | for k in ks 98 | } 99 | return output 100 | 101 | 102 | @timeout_decorator.timeout(3) 103 | def eval_code(q, src, test, entry_point): 104 | all_src = f"{src}\n{test}\ncheck({entry_point})\n" 105 | try: 106 | exec(all_src, {}) 107 | except Exception: 108 | with io.StringIO() as f: 109 | traceback.print_exception(*sys.exc_info(), file=f) 110 | q.put((False, f.getvalue())) 111 | return 112 | q.put((True, None)) 113 | 114 | 115 | def eval_code_wrapper(src, test, entry_point): 116 | queue = Queue() 117 | p = Process(target=eval_code, args=(queue, src, test, entry_point)) 118 | p.start() 119 | p.join(3) 120 | if p.is_alive(): 121 | p.kill() 122 | if not queue.empty(): 123 | return queue.get() 124 | else: 125 | return False, f"Exit code: {p.exitcode}" 126 | 127 | 128 | def is_float(element: str) -> bool: 129 | try: 130 | float(element) 131 | return True 132 | except ValueError: 133 | return False 134 | 135 | 136 | def format_test(mbpp, entrypoint, task_id): 137 | idx = mbpp["task_id"].index(task_id) 138 | test_list = mbpp["test_list"][idx] 139 | 140 | test_str = "def check(candidate):\n" 141 | 142 | # use pytest.approx() for float results 143 | if is_float(test_list[0].split("==")[-1]): 144 | test_str = "from pytest import approx\n\n" + test_str 145 | for i in range(len(test_list)): 146 | split = test_list[i].split("==") 147 | split[-1] = f"approx({split[-1]})" 148 | test_list[i] = "==".join(split) 149 | 150 | for test in test_list: 151 | test_str += f"\t{test}\n" 152 | test_str += "\n" 153 | 154 | if entrypoint != "check": 155 | test_str = test_str.replace(entrypoint, "candidate") 156 | else: 157 | test_str = test_str.replace(f"assert {entrypoint}", "assert candidate") 158 | return test_str 159 | 160 | 161 | def get_entry_point(mbpp, task_id): 162 | idx = mbpp["task_id"].index(task_id) 163 | assert_statement = mbpp["test_list"][idx][0] 164 | assert_statement = assert_statement[len("assert ") :] 165 | lparen_idx = assert_statement.index("(") 166 | entrypoint = assert_statement[:lparen_idx] 167 | return entrypoint 168 | 169 | 170 | def get_dict_list(filename: str) -> List[Dict]: 171 | output_list = [] 172 | if filename.endswith(".gz"): 173 | with open(filename, "rb") as gzfp: 174 | with gzip.open(gzfp, "rt") as fp: 175 | for line in fp: 176 | if any(not x.isspace() 
for x in line): 177 | output_list.append(json.loads(line)) 178 | elif filename.endswith(".jsonl"): 179 | with open(filename, "r") as fp: 180 | for line in fp: 181 | if any(not x.isspace() for x in line): 182 | output_list.append(json.loads(line)) 183 | elif filename.endswith(".csv"): 184 | d = load_dataset("csv", data_files={"train": filename})["train"] 185 | for i in range(len(d[d.column_names[0]])): 186 | output_list.append({col: d[col][i] for col in d.column_names}) 187 | else: 188 | raise ValueError(f"Unrecognized file extension type for file {filename}!") 189 | return output_list 190 | 191 | 192 | def truncate_code(completion, prompt): 193 | if isinstance(completion, list): 194 | completion = completion[0] 195 | 196 | # if code is refinement, remove everything else before it. 197 | if "REFINEMENT:" in completion or "Refinement:\n" in completion: 198 | refinement_str = ( 199 | "REFINEMENT:" if "REFINEMENT:" in completion else "Refinement:\n" 200 | ) 201 | ref_end_idx = completion.rfind(refinement_str) + len(refinement_str) 202 | completion = completion[ref_end_idx:] 203 | 204 | if not completion.startswith(prompt): 205 | # completion doesn't start with exact prompt for some reason, even though it should 206 | # return early 207 | return completion 208 | 209 | # Remove prompt first so that we can fix the indentation of the completion. 210 | code = completion[len(prompt) :] 211 | 212 | # sometimes indentation on the first line is messed up 213 | if not code.startswith(" "): 214 | # find the first line 215 | eo_fl_idx = code.find("\n") 216 | first_line = code[:eo_fl_idx].strip() 217 | first_line = " " + first_line 218 | code = first_line + code[eo_fl_idx:] 219 | 220 | # Find end of function and truncate there 221 | eof_m = re.search(r'\n[A-Za-z#"]+?', code) 222 | if eof_m is not None: 223 | code = code[: eof_m.start() + 1] 224 | 225 | # Now re-add the prompt 226 | code = prompt + code 227 | completion = code 228 | return completion 229 | 230 | 231 | def eval_samples(args): 232 | ks = [int(elem) for elem in args.k.split(",")] 233 | output_file_prefix = args.input_file + f"_{args.file_suffix}" 234 | ext = args.input_file.split(".")[-1] 235 | output_file = f"{output_file_prefix}.{ext}" 236 | output_summ_file = f"{output_file_prefix}_summary.{ext}" 237 | 238 | mbpp = load_dataset("mbpp") 239 | mbpp = concatenate_datasets([mbpp[k] for k in mbpp.keys()]) 240 | samples = get_dict_list(args.input_file) 241 | for sample_dict in tqdm(samples, desc="Evaluating and scoring..."): 242 | completion = sample_dict["completion"] 243 | prompt = sample_dict[args.prompt_column_name] 244 | completion = truncate_code(completion, prompt) 245 | entrypoint = get_entry_point(mbpp, sample_dict["task_id"]) 246 | test_str = format_test(mbpp, entrypoint, sample_dict["task_id"]) 247 | try: 248 | p, r = eval_code_wrapper(completion, test_str, entrypoint) 249 | except Exception as e: 250 | with io.StringIO() as f: 251 | traceback.print_exception(*sys.exc_info(), file=f) 252 | r = f.getvalue() 253 | p = False 254 | print(f"Caught exception from eval_code: {e}\n{r}") 255 | sample_dict["passed"] = p 256 | sample_dict["result"] = r 257 | num_corr_results = compute_results(samples) 258 | pass_at_k_results = compute_pass_at_ks(num_corr_results, ks) 259 | at_least_one_correct, _ = compute_at_least_one_pass_per_task(num_corr_results) 260 | pc_one_correct = at_least_one_correct / len(num_corr_results.keys()) 261 | pass_at_k_results["% tasks with at least one passed completion"] = pc_one_correct 262 | print(pass_at_k_results) 263 | 
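    # Write the per-sample results (one JSON object per line, now including the
    # "passed" and "result" fields) and the aggregate pass@k summary to separate files.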
264 | with open(output_file, "w") as f: 265 | for d in samples: 266 | f.write(json.dumps(d) + "\n") 267 | with open(output_summ_file, "w") as f: 268 | f.write(json.dumps(pass_at_k_results)) 269 | 270 | 271 | def main(args): 272 | argsdict = vars(args) 273 | print(pprint.pformat(argsdict)) 274 | eval_samples(args) 275 | 276 | 277 | if __name__ == "__main__": 278 | main(parse_args()) 279 | -------------------------------------------------------------------------------- /finetune.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | """ 4 | Fine-tuning CodeGen models on the input data. 5 | Adapted from a HuggingFace transformers example for training seq2seq models. 6 | 7 | Assumes that CodeGen model checkpoints are stored in {data_args.codegen_repo}/codegen-[6B|16B]-mono. 8 | """ 9 | import os 10 | 11 | import sys 12 | 13 | import logging 14 | import torch 15 | from dataclasses import dataclass, field 16 | from typing import Dict, List, Optional 17 | 18 | import datasets 19 | from datasets import load_dataset, load_metric, DatasetDict 20 | 21 | from jaxformer.hf import sample # from the CodeGen repository 22 | from jaxformer.hf.codegen import modeling_codegen # from the CodeGen repository 23 | 24 | import transformers 25 | from transformers import ( 26 | DataCollatorForSeq2Seq, 27 | HfArgumentParser, 28 | Seq2SeqTrainer, 29 | Seq2SeqTrainingArguments, 30 | set_seed, 31 | ) 32 | from transformers.trainer_utils import ( 33 | get_last_checkpoint, 34 | ) 35 | 36 | logger = logging.getLogger(__name__) 37 | 38 | 39 | @dataclass 40 | class ModelArguments: 41 | """ 42 | Arguments pertaining to which model/config/tokenizer we are going to fine-tune from. 43 | """ 44 | 45 | model_name_or_path: str = field( 46 | default=None, metadata={"help": "Can be codegen-16B, or codegen-6B."} 47 | ) 48 | config_name: Optional[str] = field( 49 | default=None, 50 | metadata={ 51 | "help": "Pretrained config name or path if not the same as model_name" 52 | }, 53 | ) 54 | cache_dir: Optional[str] = field( 55 | default=None, 56 | metadata={ 57 | "help": "Path to directory to store the pretrained models downloaded from huggingface.co" 58 | }, 59 | ) 60 | use_fast_tokenizer: bool = field( 61 | default=True, 62 | metadata={ 63 | "help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not." 64 | }, 65 | ) 66 | model_revision: str = field( 67 | default="main", 68 | metadata={ 69 | "help": "The specific model version to use (can be a branch name, tag name or commit id)." 70 | }, 71 | ) 72 | use_auth_token: bool = field( 73 | default=False, 74 | metadata={ 75 | "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script " 76 | "with private models)." 77 | }, 78 | ) 79 | parallelize: bool = field( 80 | default=False, 81 | ) 82 | 83 | 84 | @dataclass 85 | class DataTrainingArguments: 86 | """ 87 | Arguments pertaining to what data we are going to input our model for training and eval. 88 | """ 89 | 90 | codegen_repo: Optional[str] = field( 91 | default=None, 92 | metadata={"help": "Path to the cloned SalesForce codegen repo."}, 93 | ) 94 | dataset_name: Optional[str] = field( 95 | default=None, 96 | metadata={"help": "The name of the dataset to use (via the datasets library)."}, 97 | ) 98 | dataset_config_name: Optional[str] = field( 99 | default=None, 100 | metadata={ 101 | "help": "The configuration name of the dataset to use (via the datasets library)." 
102 | }, 103 | ) 104 | prompt_column: Optional[str] = field( 105 | default="finetuning_prompt", 106 | metadata={ 107 | "help": "The name of the column in the datasets containing the task prompt." 108 | }, 109 | ) 110 | completion_column: Optional[str] = field( 111 | default="finetuning_completion", 112 | metadata={ 113 | "help": "The name of the column in the datasets containing the refinement of the code." 114 | }, 115 | ) 116 | train_file: Optional[str] = field( 117 | default=None, 118 | metadata={"help": "The input training data file (a text file)."}, 119 | ) 120 | validation_file: Optional[str] = field( 121 | default=None, 122 | metadata={ 123 | "help": "An optional input evaluation data file to evaluate the perplexity on (a text file)." 124 | }, 125 | ) 126 | test_file: Optional[str] = field( 127 | default=None, 128 | metadata={ 129 | "help": "An optional input test data file to evaluate the perplexity on (a text file)." 130 | }, 131 | ) 132 | overwrite_cache: bool = field( 133 | default=False, 134 | metadata={"help": "Overwrite the cached training and evaluation sets"}, 135 | ) 136 | preprocessing_num_workers: Optional[int] = field( 137 | default=None, 138 | metadata={"help": "The number of processes to use for the preprocessing."}, 139 | ) 140 | max_seq_length: int = field( 141 | default=1024, 142 | metadata={ 143 | "help": "The maximum total input sequence length after tokenization. Sequences longer " 144 | "than this will be truncated, sequences shorter will be padded." 145 | }, 146 | ) 147 | max_answer_length: int = field( 148 | default=1024, 149 | metadata={ 150 | "help": "The maximum length of an answer that can be generated. This is needed because the start " 151 | "and end predictions are not conditioned on one another." 152 | }, 153 | ) 154 | val_max_answer_length: Optional[int] = field( 155 | default=None, 156 | metadata={ 157 | "help": "The maximum total sequence length for validation target text after tokenization. Sequences longer " 158 | "than this will be truncated, sequences shorter will be padded. Will default to `max_answer_length`." 159 | "This argument is also used to override the ``max_length`` param of ``model.generate``, which is used " 160 | "during ``evaluate`` and ``predict``." 161 | }, 162 | ) 163 | pad_to_max_length: bool = field( 164 | default=True, 165 | metadata={ 166 | "help": "Whether to pad all samples to `max_seq_length`. " 167 | "If False, will pad the samples dynamically when batching to the maximum length in the batch (which can " 168 | "be faster on GPU but will be slower on TPU)." 169 | }, 170 | ) 171 | max_train_samples: Optional[int] = field( 172 | default=None, 173 | metadata={ 174 | "help": "For debugging purposes or quicker training, truncate the number of training examples to this " 175 | "value if set." 176 | }, 177 | ) 178 | max_eval_samples: Optional[int] = field( 179 | default=None, 180 | metadata={ 181 | "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this " 182 | "value if set." 183 | }, 184 | ) 185 | max_predict_samples: Optional[int] = field( 186 | default=None, 187 | metadata={ 188 | "help": "For debugging purposes or quicker training, truncate the number of prediction examples to this " 189 | "value if set." 
190 | }, 191 | ) 192 | version_2_with_negative: bool = field( 193 | default=False, 194 | metadata={"help": "If true, some of the examples do not have an answer."}, 195 | ) 196 | null_score_diff_threshold: float = field( 197 | default=0.0, 198 | metadata={ 199 | "help": "The threshold used to select the null answer: if the best answer has a score that is less than " 200 | "the score of the null answer minus this threshold, the null answer is selected for this example. " 201 | "Only useful when `version_2_with_negative=True`." 202 | }, 203 | ) 204 | doc_stride: int = field( 205 | default=128, 206 | metadata={ 207 | "help": "When splitting up a long document into chunks, how much stride to take between chunks." 208 | }, 209 | ) 210 | n_best_size: int = field( 211 | default=20, 212 | metadata={ 213 | "help": "The total number of n-best predictions to generate when looking for an answer." 214 | }, 215 | ) 216 | num_beams: Optional[int] = field( 217 | default=5, 218 | metadata={ 219 | "help": "Number of beams to use for evaluation. This argument will be passed to ``model.generate``, " 220 | "which is used during ``evaluate`` and ``predict``." 221 | }, 222 | ) 223 | ignore_pad_token_for_loss: bool = field( 224 | default=True, 225 | metadata={ 226 | "help": "Whether to ignore the tokens corresponding to padded labels in the loss computation or not." 227 | }, 228 | ) 229 | 230 | def __post_init__(self): 231 | if ( 232 | self.dataset_name is None 233 | and self.train_file is None 234 | and self.validation_file is None 235 | and self.test_file is None 236 | ): 237 | raise ValueError( 238 | "Need either a dataset name or a training/validation file/test_file." 239 | ) 240 | else: 241 | if self.train_file is not None: 242 | extension = self.train_file.split(".")[-1] 243 | assert extension in [ 244 | "csv", 245 | "json", 246 | "jsonl", 247 | ], "`train_file` should be a csv or a json file." 248 | if self.validation_file is not None: 249 | extension = self.validation_file.split(".")[-1] 250 | assert extension in [ 251 | "csv", 252 | "json", 253 | ], "`validation_file` should be a csv or a json file." 254 | if self.test_file is not None: 255 | extension = self.test_file.split(".")[-1] 256 | assert extension in [ 257 | "csv", 258 | "json", 259 | ], "`test_file` should be a csv or a json file." 260 | if self.val_max_answer_length is None: 261 | self.val_max_answer_length = self.max_answer_length 262 | 263 | 264 | question_answering_column_name_mapping = { 265 | "squad_v2": ("question", "context", "answer"), 266 | } 267 | 268 | 269 | def main(): 270 | # See all possible arguments in src/transformers/training_args.py 271 | # or by passing the --help flag to this script. 272 | # We now keep distinct sets of args, for a cleaner separation of concerns. 273 | 274 | parser = HfArgumentParser( 275 | (ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments) 276 | ) 277 | if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): 278 | # If we pass only one argument to the script and it's the path to a json file, 279 | # let's parse it to get our arguments. 
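        # For example, a hypothetical args.json might look like (keys must match the
        # dataclass fields defined above; the file names here are illustrative only):
        #   {"model_name_or_path": "codegen-6B", "codegen_repo": "checkpoints",
        #    "train_file": "finetuning_data.jsonl", "output_dir": "finetuned_model",
        #    "do_train": true}
        # and would be used via `python finetune.py args.json`.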
280 | model_args, data_args, training_args = parser.parse_json_file( 281 | json_file=os.path.abspath(sys.argv[1]) 282 | ) 283 | else: 284 | model_args, data_args, training_args = parser.parse_args_into_dataclasses() 285 | 286 | # Setup logging 287 | logging.basicConfig( 288 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 289 | datefmt="%m/%d/%Y %H:%M:%S", 290 | handlers=[logging.StreamHandler(sys.stdout)], 291 | ) 292 | 293 | log_level = training_args.get_process_log_level() 294 | logger.setLevel(log_level) 295 | datasets.utils.logging.set_verbosity(log_level) 296 | transformers.utils.logging.set_verbosity(log_level) 297 | transformers.utils.logging.enable_default_handler() 298 | transformers.utils.logging.enable_explicit_format() 299 | 300 | # Log on each process the small summary: 301 | logger.warning( 302 | f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}" 303 | + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}" 304 | ) 305 | logger.info(f"Training/evaluation parameters {training_args}") 306 | 307 | # Detecting last checkpoint. 308 | last_checkpoint = None 309 | if ( 310 | os.path.isdir(training_args.output_dir) 311 | and training_args.do_train 312 | and not training_args.overwrite_output_dir 313 | ): 314 | last_checkpoint = get_last_checkpoint(training_args.output_dir) 315 | if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0: 316 | raise ValueError( 317 | f"Output directory ({training_args.output_dir}) already exists and is not empty. " 318 | "Use --overwrite_output_dir to overcome." 319 | ) 320 | elif ( 321 | last_checkpoint is not None and training_args.resume_from_checkpoint is None 322 | ): 323 | logger.info( 324 | f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change " 325 | "the `--output_dir` or add `--overwrite_output_dir` to train from scratch." 326 | ) 327 | 328 | # Set seed before initializing model. 329 | set_seed(training_args.seed) 330 | 331 | # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below) 332 | # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/ 333 | # (the dataset will be downloaded automatically from the datasets Hub). 334 | # 335 | # For CSV/JSON files, this script will use the column called 'text' or the first column if no column called 336 | # 'text' is found. You can easily tweak this behavior (see below). 337 | # 338 | # In distributed training, the load_dataset function guarantee that only one local process can concurrently 339 | # download the dataset. 340 | if data_args.dataset_name is not None: 341 | # Downloading and loading a dataset from the hub. 
342 | raw_datasets = load_dataset( 343 | data_args.dataset_name, 344 | data_args.dataset_config_name, 345 | cache_dir=model_args.cache_dir, 346 | ) 347 | else: 348 | data_files = {} 349 | if data_args.train_file is not None: 350 | data_files["train"] = data_args.train_file 351 | extension = data_args.train_file.split(".")[-1] 352 | if extension == "jsonl": 353 | extension = "json" 354 | 355 | if data_args.validation_file is not None: 356 | data_files["validation"] = data_args.validation_file 357 | extension = data_args.validation_file.split(".")[-1] 358 | if data_args.test_file is not None: 359 | data_files["test"] = data_args.test_file 360 | extension = data_args.test_file.split(".")[-1] 361 | # raw_datasets = load_dataset(extension, data_files=data_files, cache_dir=model_args.cache_dir) 362 | if extension == "json": 363 | raw_datasets = DatasetDict.from_json(data_files) 364 | else: 365 | raw_datasets = DatasetDict.from_csv(data_files) 366 | # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at 367 | # https://huggingface.co/docs/datasets/loading_datasets.html. 368 | 369 | # Load pretrained model and tokenizer 370 | # 371 | # Distributed training: 372 | # The .from_pretrained methods guarantee that only one local process can concurrently 373 | # download model & vocab. 374 | 375 | if model_args.model_name_or_path.startswith("codegen-"): 376 | if last_checkpoint is not None: 377 | model = modeling_codegen.CodeGenForCausalLM.from_pretrained( 378 | last_checkpoint, low_cpu_mem_usage=True 379 | ) 380 | else: 381 | model = modeling_codegen.CodeGenForCausalLM.from_pretrained( 382 | f"{data_args.codegen_repo}/{model_args.model_name_or_path}-mono", 383 | low_cpu_mem_usage=True, 384 | ) 385 | ## IMPORTANT: DO NOT REMOVE 386 | model = model.to(torch.float32) 387 | 388 | tokenizer = sample.create_custom_gpt2_tokenizer() 389 | # tokenizer.padding_side = 'left' 390 | tokenizer.pad_token = 50256 391 | if model_args.parallelize: 392 | model.parallelize() 393 | else: 394 | model = model.cuda() 395 | else: 396 | raise ValueError( 397 | f"{model_args.model_name_or_path} is not a valid model name or path." 398 | ) 399 | 400 | model.resize_token_embeddings(len(tokenizer)) 401 | 402 | # Preprocessing the datasets. 403 | # We need to generate and tokenize inputs and targets. 404 | if training_args.do_train: 405 | column_names = list(raw_datasets["train"].features.keys()) 406 | elif training_args.do_eval: 407 | column_names = list(raw_datasets["validation"].features.keys()) 408 | elif training_args.do_predict: 409 | column_names = list(raw_datasets["test"].features.keys()) 410 | else: 411 | logger.info( 412 | "There is nothing to do. Please pass `do_train`, `do_eval` and/or `do_predict`." 413 | ) 414 | return 415 | 416 | # Get the column names for input/target. 
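    # The prompt/completion column names default to "finetuning_prompt" and
    # "finetuning_completion", i.e. the columns written by
    # create_finetuning_data_from_refinements.py; the question-answering mapping
    # below is only a fallback inherited from the HuggingFace example this script
    # was adapted from.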
417 | dataset_columns = question_answering_column_name_mapping.get( 418 | data_args.dataset_name, None 419 | ) 420 | if data_args.prompt_column is None: 421 | prompt_column = ( 422 | dataset_columns[0] if dataset_columns is not None else column_names[0] 423 | ) 424 | else: 425 | prompt_column = data_args.prompt_column 426 | if prompt_column not in column_names: 427 | raise ValueError( 428 | f"--prompt_column' value '{data_args.prompt_column}' needs to be one of: {', '.join(column_names)}" 429 | ) 430 | if data_args.completion_column is None: 431 | completion_column = ( 432 | dataset_columns[2] if dataset_columns is not None else column_names[2] 433 | ) 434 | else: 435 | completion_column = data_args.completion_column 436 | if completion_column not in column_names: 437 | raise ValueError( 438 | f"--completion_column' value '{data_args.completion_column}' needs to be one of: {', '.join(column_names)}" 439 | ) 440 | 441 | # Temporarily set max_answer_length for training. 442 | max_answer_length = data_args.max_answer_length 443 | padding = "max_length" if data_args.pad_to_max_length else False 444 | 445 | if training_args.label_smoothing_factor > 0 and not hasattr( 446 | model, "prepare_decoder_input_ids_from_labels" 447 | ): 448 | logger.warning( 449 | "label_smoothing is enabled but the `prepare_decoder_input_ids_from_labels` method is not defined for" 450 | f"`{model.__class__.__name__}`. This will lead to loss being calculated twice and will take up more memory" 451 | ) 452 | 453 | if data_args.max_seq_length > tokenizer.model_max_length: 454 | logger.warning( 455 | f"The max_seq_length passed ({data_args.max_seq_length}) is larger than the maximum length for the" 456 | f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}." 
457 | ) 458 | max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length) 459 | 460 | def truncate(ex, tokenizer, max_length): 461 | return tokenizer.decode( 462 | tokenizer(ex, max_length=max_length, truncation=True).input_ids 463 | ) 464 | 465 | def preprocess_example(example): 466 | input_str = truncate(example[prompt_column], tokenizer, max_seq_length) 467 | r = example[completion_column] 468 | input_token_ids = tokenizer.encode(input_str, verbose=False) 469 | target_token_ids = tokenizer.encode(r, verbose=False) + [tokenizer.eos_token_id] 470 | input_ids = input_token_ids + target_token_ids 471 | labels_input_ids = ([-100] * len(input_token_ids)) + target_token_ids 472 | 473 | if len(input_ids) > max_seq_length: 474 | input_ids = input_ids[:max_seq_length] 475 | labels_input_ids = labels_input_ids[:max_seq_length] 476 | return { 477 | "input_ids": torch.IntTensor(input_ids).cuda(), 478 | "labels": torch.IntTensor(labels_input_ids).cuda(), 479 | } 480 | 481 | if training_args.do_train: 482 | if "train" not in raw_datasets: 483 | raise ValueError("--do_train requires a train dataset") 484 | train_dataset = raw_datasets["train"] 485 | if data_args.max_train_samples is not None: 486 | max_train_samples = min(len(train_dataset), data_args.max_train_samples) 487 | train_dataset = train_dataset.select(range(max_train_samples)) 488 | with training_args.main_process_first(desc="train dataset map pre-processing"): 489 | train_dataset = train_dataset.map( 490 | preprocess_example, 491 | remove_columns=column_names, 492 | ) 493 | if data_args.max_train_samples is not None: 494 | # Number of samples might increase during Feature Creation, We select only specified max samples 495 | max_train_samples = min(len(train_dataset), data_args.max_train_samples) 496 | train_dataset = train_dataset.select(range(max_train_samples)) 497 | 498 | # Data collator 499 | label_pad_token_id = ( 500 | -100 if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id 501 | ) 502 | data_collator = DataCollatorForSeq2Seq( 503 | tokenizer, 504 | model=model, 505 | label_pad_token_id=label_pad_token_id, 506 | pad_to_multiple_of=8 if training_args.fp16 else None, 507 | ) 508 | 509 | # Initialize our Trainer 510 | trainer = Seq2SeqTrainer( 511 | model=model, 512 | args=training_args, 513 | train_dataset=train_dataset if training_args.do_train else None, 514 | tokenizer=tokenizer, 515 | data_collator=data_collator, 516 | ) 517 | 518 | old_collator = trainer.data_collator 519 | trainer.data_collator = lambda data: dict(old_collator(data)) 520 | 521 | # Training 522 | if training_args.do_train: 523 | train_result = trainer.train() 524 | trainer.save_model() # Saves the tokenizer too for easy upload 525 | 526 | metrics = train_result.metrics 527 | max_train_samples = ( 528 | data_args.max_train_samples 529 | if data_args.max_train_samples is not None 530 | else len(train_dataset) 531 | ) 532 | metrics["train_samples"] = min(max_train_samples, len(train_dataset)) 533 | 534 | trainer.log_metrics("train", metrics) 535 | trainer.save_metrics("train", metrics) 536 | trainer.save_state() 537 | 538 | 539 | def _mp_fn(index): 540 | # For xla_spawn (TPUs) 541 | main() 542 | 543 | 544 | if __name__ == "__main__": 545 | main() 546 | -------------------------------------------------------------------------------- /finetune_refinement_model.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | """ 3 | Fine-tuning transformers models to generate refinements 
given old code and NL feedback. 4 | Adapted from a HuggingFace transformers example for training seq2seq models. 5 | 6 | Assumes that CodeGen model checkpoints are stored in {model_args.codegen_model_dir}/codegen-[6B|16B]-mono. 7 | """ 8 | import os 9 | 10 | 11 | import sys 12 | import logging 13 | import json 14 | import torch 15 | from dataclasses import dataclass, field 16 | from typing import Dict, List, Optional, Tuple 17 | 18 | import datasets 19 | from datasets import load_dataset, load_metric 20 | 21 | from jaxformer.hf import sample 22 | from jaxformer.hf.codegen import modeling_codegen 23 | 24 | from tqdm import tqdm 25 | 26 | import transformers 27 | from transformers import ( 28 | DataCollatorForSeq2Seq, 29 | HfArgumentParser, 30 | Seq2SeqTrainer, 31 | Seq2SeqTrainingArguments, 32 | set_seed, 33 | ) 34 | from transformers.trainer_utils import ( 35 | get_last_checkpoint, 36 | ) 37 | from transformers.utils import check_min_version 38 | from transformers.utils.versions import require_version 39 | 40 | from torch.utils.data import Dataset 41 | 42 | # Will error if the minimal version of Transformers is not installed. 43 | check_min_version("4.12.5") 44 | 45 | logger = logging.getLogger(__name__) 46 | 47 | 48 | @dataclass 49 | class ModelArguments: 50 | """ 51 | Arguments pertaining to which model/config/tokenizer we are going to fine-tune from. 52 | """ 53 | 54 | codegen_model_dir: Optional[str] = field( 55 | default="checkpoints", 56 | metadata={ 57 | "help": "Path to directory containing CodeGen model checkpoints." 58 | "Assumes the model checkpoints are stored in {codegen_model_dir}/." 59 | }, 60 | ) 61 | model_name_or_path: str = field( 62 | default=None, metadata={"help": "Can be codegen-16B or codegen-6B."} 63 | ) 64 | config_name: Optional[str] = field( 65 | default=None, 66 | metadata={ 67 | "help": "Pretrained config name or path if not the same as model_name" 68 | }, 69 | ) 70 | cache_dir: Optional[str] = field( 71 | default=None, 72 | metadata={ 73 | "help": "Path to directory to store the pretrained models downloaded from huggingface.co" 74 | }, 75 | ) 76 | use_fast_tokenizer: bool = field( 77 | default=True, 78 | metadata={ 79 | "help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not." 80 | }, 81 | ) 82 | model_revision: str = field( 83 | default="main", 84 | metadata={ 85 | "help": "The specific model version to use (can be a branch name, tag name or commit id)." 86 | }, 87 | ) 88 | use_auth_token: bool = field( 89 | default=False, 90 | metadata={ 91 | "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script " 92 | "with private models)." 93 | }, 94 | ) 95 | parallelize: bool = field( 96 | default=False, 97 | ) 98 | 99 | 100 | @dataclass 101 | class DataTrainingArguments: 102 | """ 103 | Arguments pertaining to what data we are going to input our model for training and eval. 104 | """ 105 | 106 | dataset_name: Optional[str] = field( 107 | default=None, 108 | metadata={"help": "The name of the dataset to use (via the datasets library)."}, 109 | ) 110 | dataset_config_name: Optional[str] = field( 111 | default=None, 112 | metadata={ 113 | "help": "The configuration name of the dataset to use (via the datasets library)." 114 | }, 115 | ) 116 | feedback_column: Optional[str] = field( 117 | default="Feedback", 118 | metadata={ 119 | "help": "The name of the column in the datasets containing the NL feedback (for code refinement)." 
120 | }, 121 | ) 122 | question_column: Optional[str] = field( 123 | default="completion", 124 | metadata={ 125 | "help": "The name of the column in the datasets containing the original task description and code." 126 | }, 127 | ) 128 | refinement_column: Optional[str] = field( 129 | default="Refinement", 130 | metadata={ 131 | "help": "The name of the column in the datasets containing the refinement of the code." 132 | }, 133 | ) 134 | train_file: Optional[str] = field( 135 | default=None, 136 | metadata={"help": "The input training data file (a text file)."}, 137 | ) 138 | validation_file: Optional[str] = field( 139 | default=None, 140 | metadata={ 141 | "help": "An optional input evaluation data file to evaluate the perplexity on (a text file)." 142 | }, 143 | ) 144 | test_file: Optional[str] = field( 145 | default=None, 146 | metadata={ 147 | "help": "An optional input test data file to evaluate the perplexity on (a text file)." 148 | }, 149 | ) 150 | overwrite_cache: bool = field( 151 | default=False, 152 | metadata={"help": "Overwrite the cached training and evaluation sets"}, 153 | ) 154 | preprocessing_num_workers: Optional[int] = field( 155 | default=None, 156 | metadata={"help": "The number of processes to use for the preprocessing."}, 157 | ) 158 | max_seq_length: int = field( 159 | default=1024, 160 | metadata={ 161 | "help": "The maximum total input sequence length after tokenization. Sequences longer " 162 | "than this will be truncated, sequences shorter will be padded." 163 | }, 164 | ) 165 | max_answer_length: int = field( 166 | default=1024, 167 | metadata={ 168 | "help": "The maximum length of an answer that can be generated. This is needed because the start " 169 | "and end predictions are not conditioned on one another." 170 | }, 171 | ) 172 | val_max_answer_length: Optional[int] = field( 173 | default=None, 174 | metadata={ 175 | "help": "The maximum total sequence length for validation target text after tokenization. Sequences longer " 176 | "than this will be truncated, sequences shorter will be padded. Will default to `max_answer_length`." 177 | "This argument is also used to override the ``max_length`` param of ``model.generate``, which is used " 178 | "during ``evaluate`` and ``predict``." 179 | }, 180 | ) 181 | pad_to_max_length: bool = field( 182 | default=True, 183 | metadata={ 184 | "help": "Whether to pad all samples to `max_seq_length`. " 185 | "If False, will pad the samples dynamically when batching to the maximum length in the batch (which can " 186 | "be faster on GPU but will be slower on TPU)." 187 | }, 188 | ) 189 | max_train_samples: Optional[int] = field( 190 | default=None, 191 | metadata={ 192 | "help": "For debugging purposes or quicker training, truncate the number of training examples to this " 193 | "value if set." 194 | }, 195 | ) 196 | max_eval_samples: Optional[int] = field( 197 | default=None, 198 | metadata={ 199 | "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this " 200 | "value if set." 201 | }, 202 | ) 203 | max_predict_samples: Optional[int] = field( 204 | default=None, 205 | metadata={ 206 | "help": "For debugging purposes or quicker training, truncate the number of prediction examples to this " 207 | "value if set." 
208 | }, 209 | ) 210 | version_2_with_negative: bool = field( 211 | default=False, 212 | metadata={"help": "If true, some of the examples do not have an answer."}, 213 | ) 214 | null_score_diff_threshold: float = field( 215 | default=0.0, 216 | metadata={ 217 | "help": "The threshold used to select the null answer: if the best answer has a score that is less than " 218 | "the score of the null answer minus this threshold, the null answer is selected for this example. " 219 | "Only useful when `version_2_with_negative=True`." 220 | }, 221 | ) 222 | doc_stride: int = field( 223 | default=128, 224 | metadata={ 225 | "help": "When splitting up a long document into chunks, how much stride to take between chunks." 226 | }, 227 | ) 228 | n_best_size: int = field( 229 | default=20, 230 | metadata={ 231 | "help": "The total number of n-best predictions to generate when looking for an answer." 232 | }, 233 | ) 234 | num_beams: Optional[int] = field( 235 | default=5, 236 | metadata={ 237 | "help": "Number of beams to use for evaluation. This argument will be passed to ``model.generate``, " 238 | "which is used during ``evaluate`` and ``predict``." 239 | }, 240 | ) 241 | ignore_pad_token_for_loss: bool = field( 242 | default=True, 243 | metadata={ 244 | "help": "Whether to ignore the tokens corresponding to padded labels in the loss computation or not." 245 | }, 246 | ) 247 | 248 | def __post_init__(self): 249 | if ( 250 | self.dataset_name is None 251 | and self.train_file is None 252 | and self.validation_file is None 253 | and self.test_file is None 254 | ): 255 | raise ValueError( 256 | "Need either a dataset name or a training/validation file/test_file." 257 | ) 258 | else: 259 | if self.train_file is not None: 260 | extension = self.train_file.split(".")[-1] 261 | assert extension in [ 262 | "csv", 263 | "json", 264 | "jsonl", 265 | ], "`train_file` should be a csv or a json file." 266 | if self.validation_file is not None: 267 | extension = self.validation_file.split(".")[-1] 268 | assert extension in [ 269 | "csv", 270 | "json", 271 | ], "`validation_file` should be a csv or a json file." 272 | if self.test_file is not None: 273 | extension = self.test_file.split(".")[-1] 274 | assert extension in [ 275 | "csv", 276 | "json", 277 | ], "`test_file` should be a csv or a json file." 278 | if self.val_max_answer_length is None: 279 | self.val_max_answer_length = self.max_answer_length 280 | 281 | 282 | def main(): 283 | # See all possible arguments by passing the --help flag to this script. 284 | parser = HfArgumentParser( 285 | (ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments) 286 | ) 287 | if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): 288 | # If we pass only one argument to the script and it's the path to a json file, 289 | # let's parse it to get our arguments. 
290 | model_args, data_args, training_args = parser.parse_json_file( 291 | json_file=os.path.abspath(sys.argv[1]) 292 | ) 293 | else: 294 | model_args, data_args, training_args = parser.parse_args_into_dataclasses() 295 | 296 | # Setup logging 297 | logging.basicConfig( 298 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 299 | datefmt="%m/%d/%Y %H:%M:%S", 300 | handlers=[logging.StreamHandler(sys.stdout)], 301 | ) 302 | 303 | log_level = training_args.get_process_log_level() 304 | logger.setLevel(log_level) 305 | datasets.utils.logging.set_verbosity(log_level) 306 | transformers.utils.logging.set_verbosity(log_level) 307 | transformers.utils.logging.enable_default_handler() 308 | transformers.utils.logging.enable_explicit_format() 309 | 310 | logger.warning( 311 | f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}" 312 | + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}" 313 | ) 314 | logger.info(f"Training/evaluation parameters {training_args}") 315 | 316 | # Detecting last checkpoint. 317 | last_checkpoint = None 318 | if ( 319 | os.path.isdir(training_args.output_dir) 320 | and training_args.do_train 321 | and not training_args.overwrite_output_dir 322 | ): 323 | last_checkpoint = get_last_checkpoint(training_args.output_dir) 324 | if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0: 325 | raise ValueError( 326 | f"Output directory ({training_args.output_dir}) already exists and is not empty. " 327 | "Use --overwrite_output_dir to overcome." 328 | ) 329 | elif ( 330 | last_checkpoint is not None and training_args.resume_from_checkpoint is None 331 | ): 332 | logger.info( 333 | f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change " 334 | "the `--output_dir` or add `--overwrite_output_dir` to train from scratch." 335 | ) 336 | 337 | set_seed(training_args.seed) 338 | 339 | if data_args.dataset_name is not None: 340 | # Downloading and loading a dataset from the hub. 
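        # Whether it comes from the hub or from local CSV/JSON files (handled below),
        # the dataset must provide the columns configured above: by default "completion"
        # (the task description plus the old code), "Feedback" (the human-written
        # feedback), and "Refinement" (the corrected code).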
341 | raw_datasets = load_dataset( 342 | data_args.dataset_name, 343 | data_args.dataset_config_name, 344 | cache_dir=model_args.cache_dir, 345 | ) 346 | else: 347 | data_files = {} 348 | if data_args.train_file is not None: 349 | data_files["train"] = data_args.train_file 350 | extension = data_args.train_file.split(".")[-1] 351 | if extension == "jsonl": 352 | extension = "json" 353 | 354 | if data_args.validation_file is not None: 355 | data_files["validation"] = data_args.validation_file 356 | extension = data_args.validation_file.split(".")[-1] 357 | if data_args.test_file is not None: 358 | data_files["test"] = data_args.test_file 359 | extension = data_args.test_file.split(".")[-1] 360 | raw_datasets = load_dataset( 361 | extension, data_files=data_files, cache_dir=model_args.cache_dir 362 | ) 363 | 364 | if model_args.model_name_or_path.startswith("codegen-"): 365 | if last_checkpoint is not None: 366 | model = modeling_codegen.CodeGenForCausalLM.from_pretrained( 367 | last_checkpoint, low_cpu_mem_usage=True 368 | ) 369 | else: 370 | model = modeling_codegen.CodeGenForCausalLM.from_pretrained( 371 | f"{model_args.codegen_model_dir}/{model_args.model_name_or_path}-mono", 372 | low_cpu_mem_usage=True, 373 | ) 374 | ## IMPORTANT: DO NOT REMOVE 375 | model = model.to(torch.float32) 376 | 377 | tokenizer = sample.create_custom_gpt2_tokenizer() 378 | tokenizer.pad_token = 50256 379 | if model_args.parallelize: 380 | model.parallelize() 381 | else: 382 | model = model.cuda() 383 | else: 384 | raise ValueError( 385 | f"{model_args.model_name_or_path} is not a valid model name or path." 386 | ) 387 | 388 | model.resize_token_embeddings(len(tokenizer)) 389 | 390 | if training_args.do_train: 391 | column_names = raw_datasets["train"].column_names 392 | elif training_args.do_eval: 393 | column_names = raw_datasets["validation"].column_names 394 | elif training_args.do_predict: 395 | column_names = raw_datasets["test"].column_names 396 | else: 397 | logger.info( 398 | "There is nothing to do. Please pass `do_train`, `do_eval` and/or `do_predict`." 399 | ) 400 | return 401 | 402 | # Get the column names for input/target. 403 | if data_args.question_column is None: 404 | question_column = column_names[0] 405 | else: 406 | question_column = data_args.question_column 407 | if question_column not in column_names: 408 | raise ValueError( 409 | f"--question_column' value '{data_args.question_column}' needs to be one of: {', '.join(column_names)}" 410 | ) 411 | if data_args.feedback_column is None: 412 | feedback_column = column_names[1] 413 | else: 414 | feedback_column = data_args.feedback_column 415 | if feedback_column not in column_names: 416 | raise ValueError( 417 | f"--feedback_column' value '{data_args.feedback_column}' needs to be one of: {', '.join(column_names)}" 418 | ) 419 | if data_args.refinement_column is None: 420 | refinement_column = column_names[2] 421 | else: 422 | refinement_column = data_args.refinement_column 423 | if refinement_column not in column_names: 424 | raise ValueError( 425 | f"--refinement_column' value '{data_args.refinement_column}' needs to be one of: {', '.join(column_names)}" 426 | ) 427 | 428 | # Temporarily set max_answer_length for training. 
429 | max_answer_length = data_args.max_answer_length 430 | padding = "max_length" if data_args.pad_to_max_length else False 431 | 432 | if training_args.label_smoothing_factor > 0 and not hasattr( 433 | model, "prepare_decoder_input_ids_from_labels" 434 | ): 435 | logger.warning( 436 | "label_smoothing is enabled but the `prepare_decoder_input_ids_from_labels` method is not defined for" 437 | f"`{model.__class__.__name__}`. This will lead to loss being calculated twice and will take up more memory" 438 | ) 439 | 440 | if data_args.max_seq_length > tokenizer.model_max_length: 441 | logger.warning( 442 | f"The max_seq_length passed ({data_args.max_seq_length}) is larger than the maximum length for the" 443 | f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}." 444 | ) 445 | max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length) 446 | 447 | def truncate(ex, tokenizer, max_length): 448 | return tokenizer.decode( 449 | tokenizer(ex, max_length=max_length, truncation=True).input_ids 450 | ) 451 | 452 | def preprocess_example(example): 453 | # Encode prompt prefix and suffix 454 | f = example[feedback_column] 455 | input_prefix = "OLD CODE:\n" 456 | prefix_encoded = tokenizer.encode(input_prefix, verbose=False) 457 | input_suffix = f"\n\nFEEDBACK:\n{f}\n\nREFINEMENT:\n" 458 | suffix_encoded = tokenizer.encode(input_suffix, verbose=False) 459 | 460 | # Encode the refinement 461 | r = example[refinement_column] 462 | target_token_ids = tokenizer.encode(r, verbose=False) + [tokenizer.eos_token_id] 463 | 464 | # We only truncate the old code 465 | q_max_length = ( 466 | max_seq_length 467 | - len(prefix_encoded) 468 | - len(suffix_encoded) 469 | - len(target_token_ids) 470 | ) 471 | q_encoded = tokenizer.encode(example[question_column], verbose=False)[ 472 | :q_max_length 473 | ] 474 | input_token_ids = prefix_encoded + q_encoded + suffix_encoded 475 | 476 | # Combine everything 477 | input_ids = input_token_ids + target_token_ids 478 | labels_input_ids = ([-100] * len(input_token_ids)) + target_token_ids 479 | 480 | if len(input_ids) > max_seq_length: 481 | input_ids = input_ids[:max_seq_length] 482 | labels_input_ids = labels_input_ids[:max_seq_length] 483 | return { 484 | "input_ids": torch.IntTensor(input_ids).cuda(), 485 | "labels": torch.IntTensor(labels_input_ids).cuda(), 486 | } 487 | 488 | if training_args.do_train: 489 | if "train" not in raw_datasets: 490 | raise ValueError("--do_train requires a train dataset") 491 | train_dataset = raw_datasets["train"] 492 | if data_args.max_train_samples is not None: 493 | max_train_samples = min(len(train_dataset), data_args.max_train_samples) 494 | train_dataset = train_dataset.select(range(max_train_samples)) 495 | with training_args.main_process_first(desc="train dataset map pre-processing"): 496 | train_dataset = train_dataset.filter( 497 | lambda e: e["Refinement"] is not None and e["Refinement"] 498 | ).map( 499 | preprocess_example, 500 | num_proc=data_args.preprocessing_num_workers, 501 | remove_columns=column_names, 502 | load_from_cache_file=not data_args.overwrite_cache, 503 | desc="Running tokenizer on train dataset", 504 | ) 505 | if data_args.max_train_samples is not None: 506 | # Number of samples might increase during Feature Creation, We select only specified max samples 507 | max_train_samples = min(len(train_dataset), data_args.max_train_samples) 508 | train_dataset = train_dataset.select(range(max_train_samples)) 509 | 510 | # Data collator 511 | label_pad_token_id = ( 512 
| -100 if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id 513 | ) 514 | data_collator = DataCollatorForSeq2Seq( 515 | tokenizer, 516 | model=model, 517 | label_pad_token_id=label_pad_token_id, 518 | pad_to_multiple_of=8 if training_args.fp16 else None, 519 | ) 520 | 521 | # Initialize our Trainer 522 | trainer = Seq2SeqTrainer( 523 | model=model, 524 | args=training_args, 525 | train_dataset=train_dataset if training_args.do_train else None, 526 | tokenizer=tokenizer, 527 | data_collator=data_collator, 528 | ) 529 | 530 | old_collator = trainer.data_collator 531 | trainer.data_collator = lambda data: dict(old_collator(data)) 532 | 533 | # Training 534 | if training_args.do_train: 535 | train_result = trainer.train() 536 | trainer.save_model() 537 | 538 | metrics = train_result.metrics 539 | max_train_samples = ( 540 | data_args.max_train_samples 541 | if data_args.max_train_samples is not None 542 | else len(train_dataset) 543 | ) 544 | metrics["train_samples"] = min(max_train_samples, len(train_dataset)) 545 | 546 | trainer.log_metrics("train", metrics) 547 | trainer.save_metrics("train", metrics) 548 | trainer.save_state() 549 | 550 | 551 | def _mp_fn(index): 552 | # For xla_spawn (TPUs) 553 | main() 554 | 555 | 556 | if __name__ == "__main__": 557 | main() 558 | -------------------------------------------------------------------------------- /generate_code_for_mbpp.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import logging 4 | import openai 5 | import os 6 | import pprint 7 | import re 8 | import time 9 | import torch 10 | 11 | from jaxformer.hf import sample 12 | from jaxformer.hf.codegen import modeling_codegen 13 | from datasets import load_dataset, concatenate_datasets 14 | from tqdm import tqdm 15 | 16 | 17 | def format_prompt(task_id, text, tests, sample_code, num_prompts): 18 | # Create prompt from scratch 19 | prompt = f'"""\n{text}\n\n' 20 | if num_prompts > 0: 21 | for i in range(num_prompts): 22 | example = tests[i].split("assert ")[-1].replace("==", "=") 23 | prompt += f">>> Example: {example}\n" 24 | 25 | # Add code prefix 26 | fn_name = tests[0].split("assert ")[-1].split("(")[0] 27 | fn_search = re.search(f"def {fn_name}\(.*\):", sample_code) 28 | if fn_search is None: 29 | raise ValueError( 30 | f"Could not find 'def {fn_name}\(.*\):' in code for task {task_id}." 
31 | ) 32 | code_prefix = sample_code[: fn_search.end()] 33 | prompt = f'{prompt}"""\n\n{code_prefix}\n' 34 | return prompt 35 | 36 | 37 | # GPT-J 38 | def sample_code_from_gpt_models(args, prompt, model, tokenizer): 39 | output_strs = [] 40 | num_samples = args.num_samples 41 | temperature = args.temperature 42 | debug = args.debug 43 | try: 44 | with torch.no_grad(): 45 | input_ids = ( 46 | torch.LongTensor(tokenizer.encode(prompt, verbose=False)) 47 | .unsqueeze(0) 48 | .cuda() 49 | ) 50 | output_ids = model.generate( 51 | input_ids, 52 | do_sample=True, 53 | temperature=temperature, # 0.2, 0.8 54 | max_length=1024 - len(input_ids), 55 | num_return_sequences=num_samples, 56 | ) 57 | output_strs = tokenizer.batch_decode(output_ids, skip_special_tokens=True) 58 | if debug: 59 | print(f"Input: {prompt}") 60 | print(f"Outputs: {output_strs}") 61 | except Exception as e: 62 | if ( 63 | isinstance(e, UnboundLocalError) 64 | and str(e) == "local variable 'next_tokens' referenced before assignment" 65 | ): 66 | # See https://github.com/huggingface/transformers/issues/5118 67 | if debug: 68 | print("Problem text was > 1024 tokens, so cannot do generation") 69 | print(e) 70 | print(e) 71 | return output_strs 72 | 73 | 74 | def sample_code_from_codegen(args, prompt, model, tokenizer): 75 | device = "cuda:0" 76 | completions = [] 77 | input_ids = tokenizer( 78 | prompt, truncation=True, max_length=1024, return_tensors="pt" 79 | ).input_ids.cuda() 80 | if args.temperature == 0.0: 81 | args.num_samples = 1 82 | for i in range(args.num_samples): 83 | try: 84 | # Note: max_length is max length of input IDs, and max_length_sample is max length for completion (not including input IDs) 85 | if args.temperature > 0: 86 | tokens = model.generate( 87 | input_ids, 88 | do_sample=True, 89 | num_return_sequences=1, 90 | max_length=input_ids.shape[1] + 1024, 91 | temperature=args.temperature, 92 | use_cache=True, 93 | ) 94 | else: 95 | tokens = model.generate( 96 | input_ids, 97 | num_return_sequences=1, 98 | max_length=input_ids.shape[1] + 1024, 99 | use_cache=True, 100 | ) 101 | text = tokenizer.decode(tokens[0]) 102 | if "<|endoftext|>" in text: 103 | text = text[: text.find("<|endoftext|>")] 104 | completions.append(text) 105 | except RuntimeError as e: 106 | logging.error(f"Could not sample from model: {e}") 107 | return completions 108 | 109 | 110 | def initialize_openai(args): 111 | api_key = open(f"{args.openai_creds_dir}/openai_api_key.txt").read() 112 | openai.organization = open( 113 | f"{args.openai_creds_dir}/openai_organization_id.txt" 114 | ).read() 115 | openai.api_key = api_key 116 | 117 | 118 | def sample_code_from_openai_model(args, prompt_text): 119 | output_strs = [] 120 | start = time.time() 121 | 122 | arch_mapping = { 123 | "codex": "code-davinci-002", 124 | "gpt3": "text-davinci-001", 125 | "davinci-002": "text-davinci-002", 126 | "davinci-003": "text-davinci-003", 127 | "ada": "text-ada-001", 128 | "babbage": "text-babbage-001", 129 | "curie": "text-curie-001", 130 | } 131 | engine_name = arch_mapping[args.arch] 132 | 133 | for i in range(args.num_samples): 134 | while time.time() - start < args.max_request_time: 135 | try: 136 | response = openai.Completion.create( 137 | engine=engine_name, 138 | prompt=prompt_text, 139 | max_tokens=1024, 140 | n=1, 141 | temperature=args.temperature, 142 | ) 143 | output_strs += [ 144 | prompt_text + choice["text"] for choice in response["choices"] 145 | ] 146 | break 147 | except Exception as e: 148 | print( 149 | f"Unexpected exception in generating 
solution. Sleeping again: {e}" 150 | ) 151 | time.sleep(args.sleep_time) 152 | return output_strs 153 | 154 | 155 | def write_jsonl(data, output_filepath): 156 | with open(output_filepath, "w") as f: 157 | for row in data: 158 | f.write(json.dumps(row) + "\n") 159 | 160 | 161 | def generate_code_for_problems(args): 162 | mbpp = load_dataset("mbpp") 163 | mbpp = concatenate_datasets([mbpp[k] for k in mbpp.keys()]) 164 | 165 | output = [] 166 | if args.arch in ["gpt3", "codex"]: 167 | initialize_openai(args) 168 | generate_code_fn = sample_code_from_openai_model 169 | elif args.arch in ["codegen-6B", "codegen-16B"]: 170 | if args.model_path is None: 171 | model = modeling_codegen.CodeGenForCausalLM.from_pretrained( 172 | f"{args.codegen_model_dir}/{args.arch}-mono", 173 | revision="float16", 174 | torch_dtype=torch.float16, 175 | low_cpu_mem_usage=True, 176 | ).cuda() 177 | else: 178 | model = modeling_codegen.CodeGenForCausalLM.from_pretrained( 179 | args.model_path, low_cpu_mem_usage=True, torch_dtype=torch.float32 180 | ).cuda() 181 | tokenizer = sample.create_custom_gpt2_tokenizer(truncation_side="left") 182 | tokenizer.padding_side = "left" 183 | tokenizer.pad_token = 50256 184 | generate_code_fn = lambda args, prompt: sample_code_from_codegen( 185 | args, prompt, model, tokenizer 186 | ) 187 | 188 | task_ids_range = set(range(args.start, args.end)) 189 | for i in tqdm(range(len(mbpp))): 190 | if mbpp["task_id"][i] not in task_ids_range: 191 | continue 192 | try: 193 | prompt = format_prompt( 194 | mbpp["task_id"][i], 195 | mbpp["text"][i], 196 | mbpp["test_list"][i], 197 | mbpp["code"][i], 198 | args.num_shots, 199 | ) 200 | except ValueError as e: 201 | logging.error(e) 202 | continue 203 | 204 | task_id = mbpp["task_id"][i] 205 | for completion in generate_code_fn(args, prompt): 206 | output.append( 207 | { 208 | "task_id": task_id, 209 | "prompt": prompt, 210 | "completion": completion, 211 | } 212 | ) 213 | return output 214 | 215 | 216 | def parse_args(): 217 | parser = argparse.ArgumentParser( 218 | description="Run a trained model to generate Python code for the MBPP benchmark." 219 | ) 220 | parser.add_argument( 221 | "--arch", 222 | default="gptj", 223 | choices=[ 224 | "gptj", 225 | "codex", 226 | "gpt3", 227 | "codegen-16B", 228 | "codegen-6B", 229 | "davinci-002", 230 | "davinci-003", 231 | "ada", 232 | "babbage", 233 | "curie", 234 | ], 235 | ) 236 | parser.add_argument( 237 | "--codegen-model-dir", 238 | default="checkpoints", 239 | help="Directory where pre-trained CodeGen model checkpoints are saved.", 240 | ) 241 | parser.add_argument( 242 | "--model-path", 243 | default=None, 244 | help="Directory to load model checkpoint from. If None, will load a pre-trained " 245 | "CodeGen model using the --arch argument instead.", 246 | ) 247 | parser.add_argument("--num-samples", default=1, type=int) 248 | parser.add_argument("-d", "--debug", action="store_true") 249 | parser.add_argument("--output-dir", type=str) 250 | parser.add_argument("--output-file-suffix", type=str, default="") 251 | parser.add_argument("--temperature", default=0.8, type=float) 252 | parser.add_argument( 253 | "--split", 254 | default="test", 255 | type=str, 256 | help="Which MBPP split to use. In datasets v1.16.1, MBPP only has the split 'test'.", 257 | ) 258 | parser.add_argument( 259 | "-s", "--start", default=1, type=int, help="Task ID to start with." 260 | ) 261 | parser.add_argument( 262 | "-e", "--end", default=975, type=int, help="Task ID to end with (exclusive)." 
263 | ) 264 | parser.add_argument( 265 | "-n", 266 | "--num-shots", 267 | default=0, 268 | type=int, 269 | help="Number of assert (test examples) to give in the task description.", 270 | ) 271 | parser.add_argument( 272 | "--max-request-time", 273 | type=int, 274 | default=80, 275 | help="Max. time to wait for a successful GPT-3 request.", 276 | ) 277 | parser.add_argument( 278 | "--sleep-time", 279 | type=int, 280 | default=10, 281 | help="Time to sleep (in seconds) between each GPT-3 call.", 282 | ) 283 | parser.add_argument( 284 | "--openai-creds-dir", 285 | type=str, 286 | default=None, 287 | help="Directory where OpenAI API credentials are stored. Assumes the presence of " 288 | "openai_api_key.txt and openai_organization_id.txt files.", 289 | ) 290 | args = parser.parse_args() 291 | return args 292 | 293 | 294 | def main(args): 295 | argsdict = vars(args) 296 | print(pprint.pformat(argsdict)) 297 | completions = generate_code_for_problems(args) 298 | output_filepath = os.path.join( 299 | args.output_dir, 300 | f"samples_{args.split}_{args.arch}_{args.num_shots}shot_temp{args.temperature}_{args.start}-{args.end}{args.output_file_suffix}.jsonl", 301 | ) 302 | write_jsonl(completions, output_filepath) 303 | 304 | 305 | if __name__ == "__main__": 306 | main(parse_args()) 307 | -------------------------------------------------------------------------------- /generate_refinements_codegen_finetuned.py: -------------------------------------------------------------------------------- 1 | from tqdm import tqdm 2 | from datasets import Dataset, load_dataset, concatenate_datasets 3 | from jaxformer.hf.codegen import modeling_codegen 4 | from jaxformer.hf import sample 5 | import torch 6 | import pprint 7 | import os 8 | import logging 9 | import json 10 | import csv 11 | import argparse 12 | import re 13 | 14 | 15 | def load_jsonl(filepath): 16 | data = [json.loads(line) for line in open(filepath).readlines()] 17 | fields = data[0].keys() 18 | data_dict = {k: [x[k] for x in data] for k in fields} 19 | ds = Dataset.from_dict(data_dict) 20 | return ds 21 | 22 | 23 | def load_csv(filepath): 24 | data = list(csv.DictReader(open(filepath))) 25 | fields = data[0].keys() 26 | data_dict = {k: [x[k] for x in data] for k in fields} 27 | ds = Dataset.from_dict(data_dict) 28 | return ds 29 | 30 | 31 | def load_feedback(feedback_path): 32 | extension = "csv" if feedback_path.endswith("csv") else "json" 33 | if extension == "json": 34 | d = load_jsonl(feedback_path) 35 | else: 36 | d = load_csv(feedback_path) 37 | d = d.map( 38 | lambda _, idx: {"row_id": idx}, 39 | with_indices=True, 40 | ) 41 | d = d.filter( 42 | lambda example: example["Refinement"] is not None and example["Refinement"] 43 | ) 44 | return d 45 | 46 | 47 | def sample_code_from_codegen(args, prompt, model, tokenizer): 48 | device = "cuda:0" 49 | completions = [] 50 | print(f"Tokenizing input: {prompt}") 51 | input_ids = tokenizer( 52 | prompt, truncation=True, max_length=1024, return_tensors="pt" 53 | ).input_ids.cuda() 54 | if args.temperature == 0.0: 55 | args.num_samples = 1 56 | for i in range(args.num_samples): 57 | try: 58 | # Note: max_length is max length of input IDs, and max_length_sample is max length for completion (not including input IDs) 59 | if args.temperature > 0: 60 | tokens = model.generate( 61 | input_ids, 62 | do_sample=True, 63 | num_return_sequences=1, 64 | max_length=input_ids.shape[1] + 1024, 65 | temperature=args.temperature, 66 | use_cache=True, 67 | ) 68 | else: 69 | tokens = model.generate( 70 | input_ids, 71 | 
num_return_sequences=1, 72 | max_length=input_ids.shape[1] + 1024, 73 | use_cache=True, 74 | ) 75 | text = tokenizer.decode(tokens[0]) 76 | if "<|endoftext|>" in text: 77 | text = text[: text.find("<|endoftext|>")] 78 | completions.append(text) 79 | except RuntimeError as e: 80 | logging.error(f"Could not sample from model: {e}") 81 | return completions 82 | 83 | 84 | def truncate(ex, tokenizer, max_length): 85 | return tokenizer.decode( 86 | tokenizer(ex, max_length=max_length, truncation=True).input_ids 87 | ) 88 | 89 | 90 | def format_mbpp_prompt(mbpp, task_id): 91 | idx = mbpp["task_id"].index(task_id) 92 | text = mbpp["text"][idx] 93 | tests = mbpp["test_list"][idx] 94 | sample_code = mbpp["code"][idx] 95 | 96 | # Create prompt from scratch 97 | prompt = f'"""\n{text}\n\n' 98 | # Add the first unit test as an input-output example 99 | example = tests[0].split("assert ")[-1].replace("==", "=") 100 | prompt += f">>> Example: {example}\n" 101 | 102 | # Add code prefix 103 | fn_name = tests[0].split("assert ")[-1].split("(")[0] 104 | fn_search = re.search(f"def {fn_name}\(.*\):", sample_code) 105 | if fn_search is None: 106 | raise ValueError( 107 | f"Could not find 'def {fn_name}\(.*\):' in code for task {task_id}." 108 | ) 109 | code_prefix = sample_code[: fn_search.end()] 110 | prompt = f'{prompt}"""\n\n{code_prefix}\n' 111 | return prompt 112 | 113 | 114 | def gen_refinement_prompt(args, example, tokenizer, mbpp): 115 | prompt = ( 116 | f"OLD CODE:\n{truncate(example[args.completion_column], tokenizer, 512)}" 117 | f"\n\nFEEDBACK:\n{example['Feedback']}\n\n" 118 | f"REFINEMENT:\n{format_mbpp_prompt(mbpp, example['task_id'])}" 119 | ) 120 | return prompt 121 | 122 | 123 | def gen_code(args, data, model, tokenizer): 124 | mbpp = load_dataset("mbpp") 125 | mbpp = concatenate_datasets([mbpp[k] for k in mbpp.keys()]) 126 | output = data.map( 127 | lambda ex: {"input_str": gen_refinement_prompt(args, ex, tokenizer, mbpp)} 128 | ) 129 | output = output.map( 130 | lambda ex: { 131 | "output_strs": sample_code_from_codegen( 132 | args, ex["input_str"], model, tokenizer 133 | ) 134 | }, 135 | desc="Sampling code from codegen...", 136 | ) 137 | return output 138 | 139 | 140 | def generate_code_for_problems(args): 141 | data = load_feedback(args.feedback_file) 142 | 143 | if args.model_path is None: 144 | model = modeling_codegen.CodeGenForCausalLM.from_pretrained( 145 | f"{args.codegen_model_dir}/{args.arch}-mono", 146 | revision="float16", 147 | torch_dtype=torch.float16, 148 | low_cpu_mem_usage=True, 149 | ).cuda() 150 | else: 151 | model = modeling_codegen.CodeGenForCausalLM.from_pretrained( 152 | args.model_path, low_cpu_mem_usage=True, torch_dtype=torch.float32 153 | ).cuda() 154 | tokenizer = sample.create_custom_gpt2_tokenizer() 155 | tokenizer.pad_token = 50256 156 | val = gen_code(args, data, model, tokenizer) 157 | 158 | output = [] 159 | for row in tqdm(val): 160 | for completion in row["output_strs"]: 161 | output.append( 162 | { 163 | "task_id": row["task_id"], 164 | "prompt": row["input_str"], 165 | "feedback": row["Feedback"], 166 | "old_completion": row[args.completion_column], 167 | "completion": completion, 168 | } 169 | ) 170 | return output 171 | 172 | 173 | def write_jsonl(data, output_filepath): 174 | with open(output_filepath, "w") as f: 175 | for row in data: 176 | f.write(json.dumps(row) + "\n") 177 | 178 | 179 | def parse_args(): 180 | parser = argparse.ArgumentParser( 181 | description="Run a trained model to generate Python code for the MBPP benchmark." 
182 | ) 183 | parser.add_argument( 184 | "--arch", default="codegen-6B", choices=["codegen-16B", "codegen-6B"] 185 | ) 186 | parser.add_argument( 187 | "--codegen-model-dir", 188 | default="checkpoints", 189 | help="Directory where pre-trained CodeGen model checkpoints are saved.", 190 | ) 191 | parser.add_argument( 192 | "--model-path", 193 | default=None, 194 | required=True, 195 | help="Directory to load model checkpoint from. If None, will load a pre-trained " 196 | "CodeGen model using the --arch argument instead.", 197 | ) 198 | parser.add_argument("--num-samples", default=1, type=int) 199 | parser.add_argument("-d", "--debug", action="store_true") 200 | parser.add_argument("--output-dir", type=str) 201 | parser.add_argument("--output-file-suffix", type=str, default="") 202 | parser.add_argument("--temperature", default=0.8, type=float) 203 | parser.add_argument( 204 | "--feedback-file", 205 | default=None, 206 | required=True, 207 | help="CSV file containing feedback and past completions.", 208 | ) 209 | parser.add_argument("--completion-column", default="completion") 210 | args = parser.parse_args() 211 | return args 212 | 213 | 214 | def main(args): 215 | argsdict = vars(args) 216 | print(pprint.pformat(argsdict)) 217 | completions = generate_code_for_problems(args) 218 | 219 | if args.model_path is None: 220 | output_filepath = os.path.join( 221 | args.output_dir, 222 | f"refinements_{args.arch}_temp{args.temperature}_{args.output_file_suffix}.jsonl", 223 | ) 224 | else: 225 | output_filepath = os.path.join( 226 | args.model_path, 227 | f"refinements_{args.arch}_temp{args.temperature}_{args.output_file_suffix}.jsonl", 228 | ) 229 | write_jsonl(completions, output_filepath) 230 | 231 | 232 | if __name__ == "__main__": 233 | main(parse_args()) 234 | -------------------------------------------------------------------------------- /ilf_for_code_gen.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nyu-mll/ILF-for-code-generation/1bbccca2934b26e2d8745e5afab65eb677cbe92a/ilf_for_code_gen.pdf -------------------------------------------------------------------------------- /ilf_pipeline.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Assumes that the Codegen checkpoints are stored in a directory 4 | # named "checkpoints" that is a subdirectory of the current directory. 5 | 6 | FEEDBACK_COLUMN="Feedback" 7 | REFINEMENTS_COLUMN="Refinement" 8 | INPUT_FILE="surge_annotations.jsonl" 9 | LEARNING_RATE=5e-6 10 | GRADIENT_ACCUMULATION_STEPS=32 11 | NUM_OUTPUT_SAMPLES=30 12 | NUM_EPOCHS=2 13 | while getopts "i:f:r:n:l:g:o:e:d:" option; do 14 | case $option in 15 | i) # File containing all Surge annotations. Feedback is in column named via option -f, and refinements are under "unedited_annotator_completion". 16 | INPUT_FILE=$OPTARG;; 17 | f) # Name of feedback column 18 | FEEDBACK_COLUMN=$OPTARG;; 19 | r) # Name of refinements column that will be created in all intermediate outputs. 20 | REFINEMENTS_COLUMN=$OPTARG;; 21 | n) # Experiment name 22 | EXP_NAME=$OPTARG;; 23 | l) # Learning rate 24 | LEARNING_RATE=$OPTARG;; 25 | g) # Gradient accumulation steps. Determines effective batch size because the per-device train batch size is 1 26 | GRADIENT_ACCUMULATION_STEPS=$OPTARG;; 27 | o) # Number of final MBPP samples to output from the final fine-tuned CodeGen-6B model. 28 | NUM_OUTPUT_SAMPLES=$OPTARG;; 29 | e) # Number of epochs to train for. 
30 | NUM_EPOCHS=$OPTARG;; 31 | d) # Parent directory to save results in. Experiment results will be saved in a subdirectory of this directory named ${EXP_NAME}. 32 | PARENT_DIR=$OPTARG;; 33 | \?) # Invalid option 34 | echo "Error: Invalid option ${option}" 35 | exit;; 36 | esac 37 | done 38 | 39 | TRAIN_START_TASK_ID=111 40 | TRAIN_END_TASK_ID=310 # inclusive 41 | TRAIN_N=$(( $TRAIN_END_TASK_ID - $TRAIN_START_TASK_ID + 1 )) 42 | VAL_START_TASK_ID=311 43 | VAL_END_TASK_ID=974 # inclusive 44 | VAL_N=$(( $VAL_END_TASK_ID - $VAL_START_TASK_ID + 1 )) 45 | TEST_START_TASK_ID=11 46 | TEST_END_TASK_ID=111 # (should be exclusive) 47 | 48 | CONDA_ENV="ilf" 49 | EXPERIMENT_DIR="${PARENT_DIR}/${EXP_NAME}" 50 | 51 | echo "Running with arguments -i=${INPUT_FILE}, -f=${FEEDBACK_COLUMN}, -r=${REFINEMENTS_COLUMN}," \ 52 | "-n=${EXP_NAME}, -l=${LEARNING_RATE}, -g=${GRADIENT_ACCUMULATION_STEPS}, -o=${NUM_OUTPUT_SAMPLES}," \ 53 | "-e=${NUM_EPOCHS}, -d=${PARENT_DIR}." 54 | echo "Outputting experiment results in ${EXPERIMENT_DIR}." 55 | 56 | conda deactivate 57 | conda activate ${CONDA_ENV} 58 | 59 | mkdir -p ${EXPERIMENT_DIR} 60 | python preprocess_feedback_spreadsheet.py --input_file=${INPUT_FILE} \ 61 | --model_completion_column=original_model_completion \ 62 | --old_refinement_column=unedited_annotator_completion \ 63 | --training_n=$TRAIN_N --val_n=$VAL_N \ 64 | --feedback_column=${FEEDBACK_COLUMN} --refinement_column=${REFINEMENTS_COLUMN} \ 65 | --one_per_task --filter_for_correct --output_dir=${EXPERIMENT_DIR} \ 66 | --training_start_id=${TRAIN_START_TASK_ID} --training_end_id=${TRAIN_END_TASK_ID} \ 67 | --val_start_id=${VAL_START_TASK_ID} --val_end_id=${VAL_END_TASK_ID} || exit 68 | OUTPUT_FILE_PREFIX=$(python -c "print(''.join('${INPUT_FILE}'.split('.')[:-1]).split('/')[-1])") 69 | OUTPUT_FILE_PREFIX=${EXPERIMENT_DIR}/${OUTPUT_FILE_PREFIX} 70 | REF_TRAINING_FILE="${OUTPUT_FILE_PREFIX}-train.jsonl" 71 | REF_VAL_FILE="${OUTPUT_FILE_PREFIX}-val.jsonl" 72 | 73 | echo "Training data for Pi_Ref: ${REF_TRAINING_FILE}" 74 | echo "Val data for Pi_Ref: ${REF_VAL_FILE}" 75 | 76 | # Fine-tune a model to generate refinements. 77 | # We trained with per-device batch size of 1 due to computational constraints 78 | # (but used gradient accumulation to reach the desired effective batch size). 
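# (Illustrative arithmetic, assuming a single training process: the effective batch size
#  is per_device_train_batch_size x GRADIENT_ACCUMULATION_STEPS, e.g. 1 x 32 = 32
#  examples per optimizer step with the defaults above.)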
79 | PI_REF_DIR="${EXPERIMENT_DIR}/mref_lr${LEARNING_RATE}_ga${GRADIENT_ACCUMULATION_STEPS}_${NUM_EPOCHS}epochs" 80 | CHECKPOINTS_DIR="$(pwd)/checkpoints" 81 | python finetune_refinement_model.py \ 82 | --do_train \ 83 | --codegen_model_dir=${CHECKPOINTS_DIR} \ 84 | --model_name_or_path=codegen-6B \ 85 | --num_train_epochs=${NUM_EPOCHS} \ 86 | --save_strategy=no \ 87 | --learning_rate=${LEARNING_RATE} \ 88 | --per_device_train_batch_size=1 \ 89 | --gradient_accumulation_steps=${GRADIENT_ACCUMULATION_STEPS} \ 90 | --logging_steps=1 \ 91 | --output_dir ${PI_REF_DIR} \ 92 | --pad_to_max_length \ 93 | --generation_max_length=512 \ 94 | --max_seq_length=1024 \ 95 | --max_answer_length=512 \ 96 | --parallelize \ 97 | --overwrite_output_dir \ 98 | --save_total_limit=2 \ 99 | --feedback_column=${FEEDBACK_COLUMN} \ 100 | --refinement_column=${REFINEMENTS_COLUMN} \ 101 | --train_file ${REF_TRAINING_FILE} || exit 102 | 103 | # Generate refinements using Pi_Ref 104 | python generate_refinements_codegen_finetuned.py \ 105 | --arch=codegen-6B \ 106 | --codegen-model-dir=${CHECKPOINTS_DIR} \ 107 | --num-samples=${NUM_OUTPUT_SAMPLES} --output-dir=${PI_REF_DIR} \ 108 | --temperature=0.8 --feedback-file=${REF_VAL_FILE} \ 109 | --output-file-suffix=${EXP_NAME} \ 110 | --model-path=${PI_REF_DIR} || exit 111 | 112 | # Evaluate refinements generated for tasks in MBPP_Train, and 113 | # keep only the correct ones for training Pi_Theta 114 | python eval_mbpp.py \ 115 | --input-file=${PI_REF_DIR}/refinements_codegen-6B_temp0.8_${EXP_NAME}.jsonl \ 116 | --k=1,10 || exit 117 | python create_finetuning_data_from_refinements.py \ 118 | --one-per-task \ 119 | --refinement-file=${PI_REF_DIR}/refinements_codegen-6B_temp0.8_${EXP_NAME}.jsonl_results.jsonl \ 120 | --output-dir=${PI_REF_DIR} \ 121 | --output-file-suffix=surge_final || exit 122 | 123 | # Fine-tune two separate models: 124 | # 1) fine-tuned on MBPP gold data, 125 | # 2) fine-tuned on Pi_Refine-generated refinements 126 | TRAINING_FILE="${PI_REF_DIR}/finetuning_prompts_mbpp_refinements_surge_final.jsonl" 127 | GOLD_TRAINING_FILE="${PI_REF_DIR}/finetuning_prompts_mbpp_gold_surge_final.jsonl" 128 | # Fine-tune (1) 129 | FINAL_GOLD_FINETUNE_DIR=${EXPERIMENT_DIR}/final_gold_finetune_lr${LEARNING_RATE}_ga${GRADIENT_ACCUMULATION_STEPS}_${NUM_EPOCHS}epochs 130 | python finetune.py \ 131 | --codegen_repo=${CHECKPOINTS_DIR} \ 132 | --do_train \ 133 | --model_name_or_path=codegen-6B \ 134 | --save_strategy=no \ 135 | --num_train_epochs=${NUM_EPOCHS} \ 136 | --learning_rate=${LEARNING_RATE} \ 137 | --per_device_train_batch_size=1 \ 138 | --gradient_accumulation_steps=${GRADIENT_ACCUMULATION_STEPS} \ 139 | --logging_steps=1 \ 140 | --output_dir ${FINAL_GOLD_FINETUNE_DIR} \ 141 | --parallelize \ 142 | --pad_to_max_length \ 143 | --generation_max_length=512 \ 144 | --max_seq_length=1024 \ 145 | --max_answer_length=512 \ 146 | --save_total_limit=2 \ 147 | --parallelize \ 148 | --prompt_column=finetuning_prompt \ 149 | --completion_column=finetuning_completion \ 150 | --overwrite_output_dir \ 151 | --train_file ${GOLD_TRAINING_FILE} || exit 152 | # Fine-tune (2) 153 | FINAL_FINETUNE_DIR=${EXPERIMENT_DIR}/final_finetune_lr${LEARNING_RATE}_ga${GRADIENT_ACCUMULATION_STEPS}_${NUM_EPOCHS}epochs 154 | python finetune.py \ 155 | --codegen_repo=${CHECKPOINTS_DIR} \ 156 | --do_train \ 157 | --model_name_or_path=codegen-6B \ 158 | --save_strategy=no \ 159 | --num_train_epochs=${NUM_EPOCHS} \ 160 | --learning_rate=${LEARNING_RATE} \ 161 | --per_device_train_batch_size=1 \ 162 | 
--gradient_accumulation_steps=${GRADIENT_ACCUMULATION_STEPS} \ 163 | --logging_steps=1 \ 164 | --output_dir ${FINAL_FINETUNE_DIR} \ 165 | --parallelize \ 166 | --pad_to_max_length \ 167 | --generation_max_length=512 \ 168 | --max_seq_length=1024 \ 169 | --max_answer_length=512 \ 170 | --save_total_limit=2 \ 171 | --parallelize \ 172 | --prompt_column=finetuning_prompt \ 173 | --completion_column=finetuning_completion \ 174 | --overwrite_output_dir \ 175 | --train_file ${TRAINING_FILE} || exit 176 | 177 | # Evaluate models (1) and (2) on MBPP_Test 178 | ## First generate programs for MBPP_Test 179 | python generate_code_for_mbpp.py \ 180 | --codegen-model-dir=${CHECKPOINTS_DIR} \ 181 | --num-samples=${NUM_OUTPUT_SAMPLES} \ 182 | --output-dir=${FINAL_GOLD_FINETUNE_DIR} \ 183 | --arch=codegen-6B \ 184 | -n=1 \ 185 | --temperature=0.8 \ 186 | --debug -s ${TEST_START_TASK_ID} -e ${TEST_END_TASK_ID} \ 187 | --model-path=${FINAL_GOLD_FINETUNE_DIR} || exit 188 | python generate_code_for_mbpp.py \ 189 | --codegen-model-dir=${CHECKPOINTS_DIR} \ 190 | --num-samples=${NUM_OUTPUT_SAMPLES} \ 191 | --output-dir=${FINAL_FINETUNE_DIR} \ 192 | --arch=codegen-6B \ 193 | -n=1 \ 194 | --temperature=0.8 \ 195 | --debug -s ${TEST_START_TASK_ID} -e ${TEST_END_TASK_ID} \ 196 | --model-path=${FINAL_FINETUNE_DIR} || exit 197 | ## Now evaluate final generations 198 | python eval_mbpp.py \ 199 | --input-file=${FINAL_GOLD_FINETUNE_DIR}/samples_test_codegen-6B_1shot_temp0.8_${TEST_START_TASK_ID}-${TEST_END_TASK_ID}.jsonl \ 200 | --k=1,10 || exit 201 | python eval_mbpp.py \ 202 | --input-file=${FINAL_FINETUNE_DIR}/samples_test_codegen-6B_1shot_temp0.8_${TEST_START_TASK_ID}-${TEST_END_TASK_ID}.jsonl \ 203 | --k=1,10 || exit -------------------------------------------------------------------------------- /preprocess_feedback_spreadsheet.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from datasets import Dataset, load_dataset 3 | 4 | 5 | def group_by_and_select_one(ds, group_by_col): 6 | df = ds.shuffle().to_pandas() 7 | df = df.groupby(group_by_col).first() 8 | ds = Dataset.from_pandas(df) 9 | return ds 10 | 11 | 12 | def truncate_completion(src): 13 | ref_str = "Refinement:\n" 14 | if ref_str in src: 15 | src = src[src.rfind(ref_str) + len(ref_str) :] 16 | return src 17 | 18 | 19 | def preprocess_data(args): 20 | orig_ext = args.input_file.split(".")[-1] 21 | if orig_ext not in ["csv", "json", "jsonl"]: 22 | raise ValueError(f"{ext} is not a supported file extension.") 23 | if orig_ext == "jsonl": 24 | ext = "json" 25 | else: 26 | ext = orig_ext 27 | d = load_dataset(ext, data_files={"train": args.input_file})["train"].filter( 28 | lambda ex: ex[args.feedback_column] is not None and ex[args.feedback_column] 29 | ) 30 | 31 | if args.old_refinement_column is not None: 32 | d = d.map( 33 | lambda ex: {args.refinement_column: ex[args.old_refinement_column]}, 34 | remove_columns=[args.old_refinement_column], 35 | ) 36 | 37 | d = d.map( 38 | lambda ex: {"completion": ex[args.model_completion_column]}, 39 | ) 40 | 41 | d = d.filter( 42 | lambda ex: ex[args.refinement_column] is not None and ex[args.refinement_column] 43 | ).map( 44 | lambda ex: { 45 | args.refinement_column: truncate_completion(ex[args.refinement_column]) 46 | } 47 | ) 48 | 49 | if args.filter_for_correct and "passed" in d.column_names: 50 | # Filter for correct ones only, if the column exists in the spreadsheet 51 | d = d.filter(lambda ex: ex["passed"]) 52 | 53 | if args.one_per_task: 54 | # Filter 
for just one sample per task ID. 55 | d = group_by_and_select_one(d, args.id_col) 56 | 57 | # Split data and print out filenames 58 | output_file_prefix = ".".join(args.input_file.split(".")[:-1]) 59 | if args.output_dir is not None: 60 | fname_prefix = output_file_prefix.split("/")[-1] 61 | output_file_prefix = f"{args.output_dir}/{fname_prefix}" 62 | 63 | df = d.to_pandas().set_index(args.id_col) 64 | train_df = df[ 65 | (df.index >= args.training_start_id) & (df.index <= args.training_end_id) 66 | ] 67 | train_n = min(len(train_df), args.training_n) 68 | train_df = train_df.sample(n=train_n) 69 | train_output_filepath = f"{output_file_prefix}-train.jsonl" 70 | train_df.reset_index().to_json(train_output_filepath, orient="records", lines=True) 71 | val_df = df[(df.index >= args.val_start_id) & (df.index <= args.val_end_id)] 72 | val_n = min(len(val_df), args.val_n) 73 | val_df = val_df.sample(n=val_n) 74 | val_output_filepath = f"{output_file_prefix}-val.jsonl" 75 | val_df.reset_index().to_json(val_output_filepath, orient="records", lines=True) 76 | print("\n".join([train_output_filepath, val_output_filepath])) 77 | 78 | 79 | def parse_args(): 80 | parser = argparse.ArgumentParser( 81 | description="Filter and pre-process CSV or JSONL input file containing feedback and refinements." 82 | ) 83 | parser.add_argument( 84 | "--input_file", 85 | default="", 86 | required=True, 87 | help="Input CSV or JSONL file containing feedback and refinements.", 88 | ) 89 | parser.add_argument( 90 | "--feedback_column", default="Feedback", help="Name of feedback column." 91 | ) 92 | parser.add_argument( 93 | "--old_refinement_column", 94 | default=None, 95 | help="If set, will change the column with this name to --refinement_column.", 96 | ) 97 | parser.add_argument( 98 | "--refinement_column", default="Refinement", help="Name of refinement column." 99 | ) 100 | parser.add_argument( 101 | "--model_completion_column", default="original_model_completion" 102 | ) 103 | parser.add_argument( 104 | "--training_n", 105 | default=None, 106 | type=int, 107 | help="Number of examples to be used for training data. If None, does not split data into train/val.", 108 | ) 109 | parser.add_argument( 110 | "--val_n", 111 | default=None, 112 | type=int, 113 | help="Number of examples to be used for validation data. If None, just uses all non-training examples as validation data.", 114 | ) 115 | parser.add_argument( 116 | "--id_col", 117 | type=str, 118 | default="task_id", 119 | help="Which column to index on and to split data by.", 120 | ) 121 | parser.add_argument( 122 | "--one_per_task", 123 | action="store_true", 124 | help="If set, then will filter only one sample per task.", 125 | ) 126 | parser.add_argument( 127 | "--filter_for_correct", 128 | action="store_true", 129 | help="Filter for only the rows for which passed=True. " 130 | + "(May want to keep off for feedback spreadsheets where the 'passed' column corresponds to the original model completion instead of the Refinement.)", 131 | ) 132 | parser.add_argument( 133 | "--training_start_id", 134 | type=int, 135 | default=601, 136 | ) 137 | parser.add_argument("--training_end_id", type=int, default=974) 138 | parser.add_argument("--val_start_id", type=int, default=511) 139 | parser.add_argument("--val_end_id", type=int, default=600) 140 | parser.add_argument( 141 | "--output_dir", 142 | type=str, 143 | default=None, 144 | help="Output directory. 
If None, outputs to the same directory that the input file is already in.", 145 | ) 146 | args = parser.parse_args() 147 | 148 | # if training_n is set, then val_n must also be set. 149 | assert (args.training_n is None) or ( 150 | args.val_n is not None 151 | ), "Error: if --training_n is set, then --val_n must also be set." 152 | return args 153 | 154 | 155 | def main(args): 156 | argsdict = vars(args) 157 | preprocess_data(args) 158 | 159 | 160 | if __name__ == "__main__": 161 | main(parse_args()) 162 | --------------------------------------------------------------------------------