├── .gitignore
├── LICENSE
├── README.md
├── create_finetuning_data_from_refinements.py
├── environment.yml
├── eval_mbpp.py
├── finetune.py
├── finetune_refinement_model.py
├── generate_code_for_mbpp.py
├── generate_refinements_codegen_finetuned.py
├── ilf_for_code_gen.pdf
├── ilf_pipeline.sh
├── preprocess_feedback_spreadsheet.py
└── surge_annotations.jsonl
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 ML² AT CILVR
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Improving Code Generation by Training with Natural Language Feedback
2 | Authors: Angelica Chen, Jérémy Scheurer, Tomasz Korbak, Jon Ander Campos, Jun Shern Chan, Samuel R. Bowman, Kyunghyun Cho, Ethan Perez
3 |
4 | This repository contains the code and data (human-written feedback and refinements) for running the Imitation learning from Language Feedback (ILF) algorithm
5 | for code generation from "Improving Code Generation by Training with Natural Language Feedback" by [Chen et al. (2023)](https://arxiv.org/abs/2303.16749). This paper has since been superseded by our TMLR publication, ["Learning from Natural Language Feedback"](https://openreview.net/forum?id=xo3hI5MwvU).
6 |
7 |
8 |
9 |
10 |
11 | ## Installation
12 |
13 | Our code relies upon the [`jaxformer` repository](https://github.com/salesforce/jaxformer) and open-source [CodeGen-Mono checkpoints](https://github.com/salesforce/CodeGen).
14 |
15 | To install all dependencies and download the necessary model checkpoints:
16 | ```{bash}
17 | git clone git@github.com:nyu-mll/ILF-for-code-generation.git
18 | cd ILF-for-code-generation
19 | conda env create -f environment.yml
20 |
21 | # Install codegen repo and reset to old commit
22 | git clone git@github.com:salesforce/CodeGen.git
23 | cd CodeGen
24 | git reset --hard 9cc1f971c83ad606cce5da292d3c58523dd920a2
25 | git clean -df
26 | pip3 install -r requirements.txt
27 | cd ..
28 |
29 | # To download codegen-6B-mono
30 | wget -P checkpoints https://storage.googleapis.com/sfr-codegen-research/checkpoints/codegen-6B-mono.tar.gz && tar -xvf checkpoints/codegen-6B-mono.tar.gz -C checkpoints/
31 |
32 | ```
33 |
34 | In our paper we use the CodeGen-Mono 6B checkpoint, but you can easily replace the above `wget` command with the download links for the [other CodeGen models](https://github.com/salesforce/CodeGen#sampling-with-repository).
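
For example, to use the CodeGen-Mono 2B checkpoint instead (assuming it is hosted at the analogous URL, following the same pattern as the 6B link above):
```{bash}
wget -P checkpoints https://storage.googleapis.com/sfr-codegen-research/checkpoints/codegen-2B-mono.tar.gz && tar -xvf checkpoints/codegen-2B-mono.tar.gz -C checkpoints/
```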
35 |
36 | ## To run the ILF pipeline
37 | To run the ILF pipeline using our dataset, run (from this directory):
38 | ```{bash}
39 | source ilf_pipeline.sh -d $(pwd) -n <subdir_name>
40 | ```
41 | with `<subdir_name>` replaced by the name of the subdirectory in which you wish to store results.
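
For example, to store results under a subdirectory named `ilf_run1` (a hypothetical name):
```{bash}
source ilf_pipeline.sh -d $(pwd) -n ilf_run1
```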
42 |
43 | ## Citation
44 | ```
45 | @article{
46 | chen2024learning,
47 | title={Learning from Natural Language Feedback},
48 | author={Angelica Chen and J{\'e}r{\'e}my Scheurer and Jon Ander Campos and Tomasz Korbak and Jun Shern Chan and Samuel R. Bowman and Kyunghyun Cho and Ethan Perez},
49 | journal={Transactions on Machine Learning Research},
50 | issn={2835-8856},
51 | year={2024},
52 | url={https://openreview.net/forum?id=xo3hI5MwvU},
53 | note={}
54 | }
55 | ```
56 |
--------------------------------------------------------------------------------
/create_finetuning_data_from_refinements.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import logging
3 | import re
4 |
5 | from datasets import Dataset, load_dataset, concatenate_datasets
6 |
7 |
8 | def format_prompt(mbpp, task_id):
9 | idx = mbpp["task_id"].index(task_id)
10 | text = mbpp["text"][idx]
11 | tests = mbpp["test_list"][idx]
12 | sample_code = mbpp["code"][idx]
13 |
14 | # Create prompt from scratch
15 | prompt = f'"""\n{text}\n\n'
16 | # Add the first unit test as an input-output example
17 | example = tests[0].split("assert ")[-1].replace("==", "=")
18 | prompt += f">>> Example: {example}\n"
19 |
20 | # Add code prefix
21 | fn_name = tests[0].split("assert ")[-1].split("(")[0]
22 |     fn_search = re.search(rf"def {fn_name}\(.*\):", sample_code)
23 |     if fn_search is None:
24 |         raise ValueError(
25 |             rf"Could not find 'def {fn_name}\(.*\):' in code for task {task_id}."
26 |         )
27 | code_prefix = sample_code[: fn_search.end()]
28 | prompt = f'{prompt}"""\n\n{code_prefix}\n'
29 | return prompt
30 |
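# For illustration (hypothetical MBPP-style task): given the text "Write a function to find the
# maximum of a list." and the first test "assert find_max([1, 2, 3]) == 3", format_prompt returns:
#     """
#     Write a function to find the maximum of a list.
#
#     >>> Example: find_max([1, 2, 3]) = 3
#     """
#
#     def find_max(lst):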
31 |
32 | def load_scored_data(feedback_path):
33 | d = load_dataset("json", data_files={"train": feedback_path})["train"].map(
34 | lambda _, idx: {"row_id": idx},
35 | with_indices=True,
36 | )
37 | print(f"Initial length of d: {len(d)}")
38 | d = d.filter(lambda example: example["passed"])
39 | print(f"Length of d after filtering for passed: {len(d)}")
40 | return d
41 |
42 |
43 | def dedupe_dataset(dataset):
44 | cols = dataset.column_names
45 | row_set = set()
46 | for ex in dataset:
47 | ex_tuple = tuple(ex[col] for col in cols)
48 | row_set.add(ex_tuple)
49 | deduped = {k: [row[i] for row in row_set] for i, k in enumerate(cols)}
50 | return Dataset.from_dict(deduped)
51 |
52 |
53 | def remove_prefix_and_func_sig(code, func_sig):
54 | if f"{func_sig}\r\n" in code:
55 | return code[code.rfind(f"{func_sig}\r\n") + len(f"{func_sig}\r\n") :]
56 | elif f"{func_sig} \r\n" in code:
57 | return code[code.rfind(f"{func_sig} \r\n") + len(f"{func_sig} \r\n") :]
58 | elif f"{func_sig}\n" in code:
59 | return code[code.rfind(f"{func_sig}\n") + len(f"{func_sig}\n") :]
60 | elif f"{func_sig}" in code:
61 | return code[code.rfind(f"{func_sig}") + len(f"{func_sig}") :]
62 |     else:
63 |         return None  # Signature not found; the caller skips such examples.
64 |
65 |
66 | def get_completion(prompt, completion):
67 | """If 'REFINEMENT:' is in the completion, remove it. Also remove prompt prefix if present."""
68 | ref_str = "REFINEMENT:"
69 | if ref_str in completion:
70 | idx = completion.rfind(ref_str)
71 | completion = completion[idx + len(ref_str) :]
72 | if prompt in completion:
73 | idx = completion.rfind(prompt)
74 | completion = completion[idx + len(prompt) :]
75 | return completion
76 |
77 |
78 | def create_prompts(args):
79 | mbpp = load_dataset("mbpp")
80 | mbpp = concatenate_datasets([mbpp[k] for k in mbpp.keys()])
81 | ref_data = load_scored_data(args.refinement_file)
82 | print(f"Length of scored data: {len(ref_data)}")
83 |
84 | # Get unique pairs of (task ID, prompt) from the scored refinements.
85 | tasks = set([(example["task_id"], example["prompt"]) for example in ref_data])
86 |
87 | if not args.no_output_gold_data:
88 | mbpp_ft_data = {
89 | "finetuning_prompt": [],
90 | "finetuning_completion": [],
91 | "task_id": [],
92 | }
93 | task_id_to_func_sig = {}
94 | for task_id, prompt in tasks:
95 | mbpp_idx = mbpp["task_id"].index(task_id)
96 |
97 | # Get the original reformatted MBPP prompt
98 | orig_prompt = format_prompt(mbpp, task_id)
99 |
100 | # Remove method signature prefix
101 | gold_code = mbpp["code"][mbpp_idx]
102 | sig_idx = prompt.rfind("def ")
103 | colon_idx = prompt.rfind(":")
104 | func_sig = prompt[sig_idx : colon_idx + 1]
105 | task_id_to_func_sig[task_id] = func_sig
106 |         stripped_gold_code = remove_prefix_and_func_sig(gold_code, func_sig)
107 |         if stripped_gold_code is None:
108 |             logging.warning(
109 |                 f"Could not find function signature {func_sig} in gold code.\nGold code:\n{gold_code}"
110 |             )
111 |             continue
112 |         mbpp_ft_data["finetuning_prompt"].append(orig_prompt)
113 |         mbpp_ft_data["finetuning_completion"].append(stripped_gold_code)
114 |         mbpp_ft_data["task_id"].append(task_id)
115 | mbpp_ft_data = Dataset.from_dict(mbpp_ft_data)
116 |
117 | if args.sample_size is not None:
118 | n = min(len(mbpp_ft_data), args.sample_size)
119 | mbpp_ft_data = mbpp_ft_data.shuffle().select(range(n))
120 | mbpp_ft_data.to_json(
121 | f"{args.output_dir}/finetuning_prompts_mbpp_gold_{args.output_file_suffix}.jsonl"
122 | )
123 |
124 | refs_ft_data = ref_data.map(
125 | lambda ex: {
126 | "finetuning_prompt": format_prompt(mbpp, ex["task_id"]),
127 | }
128 | ).map(
129 | lambda ex: {
130 | "finetuning_completion": get_completion(
131 | ex["finetuning_prompt"], ex["completion"]
132 | )
133 | }
134 | )
135 | cols_to_remove = list(
136 | set(refs_ft_data.column_names)
137 | - set(["task_id", "finetuning_prompt", "finetuning_completion"])
138 | )
139 | refs_ft_data = refs_ft_data.remove_columns(cols_to_remove)
140 | refs_ft_data = dedupe_dataset(refs_ft_data)
141 | if args.one_per_task:
142 | df = refs_ft_data.shuffle().to_pandas()
143 | df = df.groupby("task_id").first()
144 | refs_ft_data = Dataset.from_pandas(df)
145 |
146 | if args.sample_size is not None:
147 | n = min(len(refs_ft_data), args.sample_size)
148 | refs_ft_data = refs_ft_data.shuffle().select(range(n))
149 | refs_ft_data.to_json(
150 | f"{args.output_dir}/finetuning_prompts_mbpp_refinements_{args.output_file_suffix}.jsonl"
151 | )
152 |
153 |
154 | def parse_args(input_args):
155 | parser = argparse.ArgumentParser(
156 | description="Generate fine-tuning prompts from model-generated refinements. Also generate FT prompts for those same task IDs from the original MBPP dataset using gold code."
157 | )
158 | parser.add_argument(
159 | "--refinement-file",
160 | type=str,
161 | help="Path to file containing evaluated refinements. Needs to have the following columns: passed, task_id, prompt, completion.",
162 | )
163 | parser.add_argument(
164 | "--output-dir", type=str, help="Directory to output data files in."
165 | )
166 | parser.add_argument(
167 | "--no-output-gold-data",
168 | action="store_true",
169 | help="If set, will not output finetuning files for gold completions.",
170 | )
171 | parser.add_argument("--output-file-suffix", type=str, default="")
172 | parser.add_argument(
173 | "-n",
174 | "--sample-size",
175 | default=None,
176 | type=int,
177 | help="If set, will limit the number of outputs to this value.",
178 | )
179 | parser.add_argument(
180 | "--one-per-task",
181 | action="store_true",
182 | help="If set, will randomly select one correct refinement per task.",
183 | )
184 | args = parser.parse_args()
185 | return args
186 |
187 |
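# Example invocation (hypothetical file and directory names):
#   python create_finetuning_data_from_refinements.py \
#       --refinement-file refinements_eval_results.jsonl \
#       --output-dir finetuning_data \
#       --output-file-suffix run1 \
#       --one-per-task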
188 | def main():
189 | args = parse_args(None)
190 | create_prompts(args)
191 |
192 |
193 | if __name__ == "__main__":
194 | main()
195 |
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
1 | name: ilf
2 | channels:
3 | - pytorch
4 | - conda-forge
5 | - defaults
6 | dependencies:
7 | - cudatoolkit=11.6.0=hecad31d_10
8 | - pip=22.1.2
9 | - python=3.7.13=h12debd9_0
10 | - pytorch=1.12.1=py3.7_cuda11.6_cudnn8.3.2_0
11 | - pytorch-mutex=1.0=cuda
12 | - readline=8.1.2=h7f8727e_1
13 | - setuptools=63.4.1
14 | - pip:
15 | - argparse==1.4.0
16 | - datasets==2.7.1
17 | - evaluate==0.3.0
18 | - huggingface-hub==0.9.1
19 | - matplotlib==3.5.3
20 | - nltk==3.7
21 | - numpy==1.21.6
22 | - openai==0.23.0
23 | - pytest==7.2.2
24 | - python-dateutil==2.8.2
25 | - pytz==2022.2.1
26 | - regex==2022.9.13
27 | - sacremoses==0.0.53
28 | - scikit-learn==1.0.2
29 | - scipy==1.7.3
30 | - six==1.16.0
31 | - sklearn==0.0
32 | - timeout-decorator==0.5.0
33 | - tokenizers==0.10.3
34 | - tqdm==4.64.1
35 | - transformers==4.12.5
36 |
--------------------------------------------------------------------------------
/eval_mbpp.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import gzip
3 | import io
4 | import itertools
5 | import json
6 | import pprint
7 | import numpy as np
8 | import re
9 | import sys
10 | import timeout_decorator
11 | import traceback
12 |
13 |
14 | from collections import defaultdict
15 | from datasets import concatenate_datasets, load_dataset
16 | from multiprocessing import Process, Queue
17 | from tqdm import tqdm
18 | from typing import Dict, List, Union
19 |
20 |
21 | def parse_args():
22 | parser = argparse.ArgumentParser(
23 | description="Evaluate model completions on the MBPP benchmark."
24 | )
25 | parser.add_argument(
26 | "--input-file",
27 | type=str,
28 | help="File containing columns , 'completion', and 'task_id'.",
29 | )
30 | parser.add_argument("--k", default="1,10")
31 | parser.add_argument("--file-suffix", default="results")
32 | parser.add_argument(
33 | "--prompt-column-name", default="prompt", help="Name of prompt column."
34 | )
35 | args = parser.parse_args()
36 | return args
37 |
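# Example invocation (hypothetical input file):
#   python eval_mbpp.py --input-file mbpp_completions.jsonl --k 1,10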
38 |
39 | def estimate_pass_at_k(
40 | num_samples: Union[int, List[int], np.ndarray],
41 | num_correct: Union[List[int], np.ndarray],
42 | k: int,
43 | ) -> np.ndarray:
44 | """
45 | Estimates pass@k of each problem and returns them in an array.
46 | Taken from https://github.com/openai/human-eval/blob/master/human_eval/evaluation.py#L13.
47 | """
48 |
49 | def estimator(n: int, c: int, k: int) -> float:
50 | """
51 | Calculates 1 - comb(n - c, k) / comb(n, k).
52 | """
53 | if n - c < k:
54 | return 1.0
55 | return 1.0 - np.prod(1.0 - k / np.arange(n - c + 1, n + 1))
56 |
57 | if isinstance(num_samples, int):
58 | num_samples_it = itertools.repeat(num_samples, len(num_correct))
59 | else:
60 | assert len(num_samples) == len(num_correct)
61 | num_samples_it = iter(num_samples)
62 |
63 | return np.array(
64 | [estimator(int(n), int(c), k) for n, c in zip(num_samples_it, num_correct)]
65 | )
66 |
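# Worked example (hypothetical counts): with n = 10 samples for a task, of which c = 3 pass,
# pass@1 = 1 - C(7, 1)/C(10, 1) = 0.3 and pass@5 = 1 - C(7, 5)/C(10, 5) ≈ 0.917, i.e. the
# probability that at least one of k uniformly drawn completions passes the tests.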
67 |
68 | def compute_results(eval_results):
69 | results = defaultdict(list)
70 | for row in eval_results:
71 | ti = row["task_id"]
72 | passed = row["passed"]
73 | results[ti].append(passed)
74 | outputs = {
75 | ti: {"num_correct": np.sum(r), "num_total": len(r)} for ti, r in results.items()
76 | }
77 | return outputs
78 |
79 |
80 | def compute_at_least_one_pass_per_task(results):
81 | total = 0
82 | task_ids = []
83 | for task_id, results_dict in results.items():
84 | if results_dict["num_correct"] > 0:
85 | total += 1
86 | task_ids.append(task_id)
87 | return total, task_ids
88 |
89 |
90 | def compute_pass_at_ks(results, ks):
91 | output = {
92 | k: estimate_pass_at_k(
93 | [x["num_total"] for _, x in results.items()],
94 | [x["num_correct"] for _, x in results.items()],
95 | k,
96 | ).mean()
97 | for k in ks
98 | }
99 | return output
100 |
101 |
102 | @timeout_decorator.timeout(3)
103 | def eval_code(q, src, test, entry_point):
104 | all_src = f"{src}\n{test}\ncheck({entry_point})\n"
105 | try:
106 | exec(all_src, {})
107 | except Exception:
108 | with io.StringIO() as f:
109 | traceback.print_exception(*sys.exc_info(), file=f)
110 | q.put((False, f.getvalue()))
111 | return
112 | q.put((True, None))
113 |
114 |
115 | def eval_code_wrapper(src, test, entry_point):
116 | queue = Queue()
117 | p = Process(target=eval_code, args=(queue, src, test, entry_point))
118 | p.start()
119 | p.join(3)
120 | if p.is_alive():
121 | p.kill()
122 | if not queue.empty():
123 | return queue.get()
124 | else:
125 | return False, f"Exit code: {p.exitcode}"
126 |
127 |
128 | def is_float(element: str) -> bool:
129 | try:
130 | float(element)
131 | return True
132 | except ValueError:
133 | return False
134 |
135 |
136 | def format_test(mbpp, entrypoint, task_id):
137 | idx = mbpp["task_id"].index(task_id)
138 | test_list = mbpp["test_list"][idx]
139 |
140 | test_str = "def check(candidate):\n"
141 |
142 | # use pytest.approx() for float results
143 | if is_float(test_list[0].split("==")[-1]):
144 | test_str = "from pytest import approx\n\n" + test_str
145 | for i in range(len(test_list)):
146 | split = test_list[i].split("==")
147 | split[-1] = f"approx({split[-1]})"
148 | test_list[i] = "==".join(split)
149 |
150 | for test in test_list:
151 | test_str += f"\t{test}\n"
152 | test_str += "\n"
153 |
154 | if entrypoint != "check":
155 | test_str = test_str.replace(entrypoint, "candidate")
156 | else:
157 | test_str = test_str.replace(f"assert {entrypoint}", "assert candidate")
158 | return test_str
159 |
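# For illustration (hypothetical MBPP-style test): given the single test
# "assert reverse_string('abc') == 'cba'" with entry point "reverse_string",
# format_test returns roughly:
#     def check(candidate):
#         assert candidate('abc') == 'cba'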
160 |
161 | def get_entry_point(mbpp, task_id):
162 | idx = mbpp["task_id"].index(task_id)
163 | assert_statement = mbpp["test_list"][idx][0]
164 | assert_statement = assert_statement[len("assert ") :]
165 | lparen_idx = assert_statement.index("(")
166 | entrypoint = assert_statement[:lparen_idx]
167 | return entrypoint
168 |
169 |
170 | def get_dict_list(filename: str) -> List[Dict]:
171 | output_list = []
172 | if filename.endswith(".gz"):
173 | with open(filename, "rb") as gzfp:
174 | with gzip.open(gzfp, "rt") as fp:
175 | for line in fp:
176 | if any(not x.isspace() for x in line):
177 | output_list.append(json.loads(line))
178 | elif filename.endswith(".jsonl"):
179 | with open(filename, "r") as fp:
180 | for line in fp:
181 | if any(not x.isspace() for x in line):
182 | output_list.append(json.loads(line))
183 | elif filename.endswith(".csv"):
184 | d = load_dataset("csv", data_files={"train": filename})["train"]
185 | for i in range(len(d[d.column_names[0]])):
186 | output_list.append({col: d[col][i] for col in d.column_names})
187 | else:
188 | raise ValueError(f"Unrecognized file extension type for file {filename}!")
189 | return output_list
190 |
191 |
192 | def truncate_code(completion, prompt):
193 | if isinstance(completion, list):
194 | completion = completion[0]
195 |
196 | # if code is refinement, remove everything else before it.
197 | if "REFINEMENT:" in completion or "Refinement:\n" in completion:
198 | refinement_str = (
199 | "REFINEMENT:" if "REFINEMENT:" in completion else "Refinement:\n"
200 | )
201 | ref_end_idx = completion.rfind(refinement_str) + len(refinement_str)
202 | completion = completion[ref_end_idx:]
203 |
204 | if not completion.startswith(prompt):
205 | # completion doesn't start with exact prompt for some reason, even though it should
206 | # return early
207 | return completion
208 |
209 | # Remove prompt first so that we can fix the indentation of the completion.
210 | code = completion[len(prompt) :]
211 |
212 | # sometimes indentation on the first line is messed up
213 | if not code.startswith(" "):
214 | # find the first line
215 | eo_fl_idx = code.find("\n")
216 | first_line = code[:eo_fl_idx].strip()
217 | first_line = " " + first_line
218 | code = first_line + code[eo_fl_idx:]
219 |
220 | # Find end of function and truncate there
221 | eof_m = re.search(r'\n[A-Za-z#"]+?', code)
222 | if eof_m is not None:
223 | code = code[: eof_m.start() + 1]
224 |
225 | # Now re-add the prompt
226 | code = prompt + code
227 | completion = code
228 | return completion
229 |
230 |
231 | def eval_samples(args):
232 | ks = [int(elem) for elem in args.k.split(",")]
233 | output_file_prefix = args.input_file + f"_{args.file_suffix}"
234 | ext = args.input_file.split(".")[-1]
235 | output_file = f"{output_file_prefix}.{ext}"
236 | output_summ_file = f"{output_file_prefix}_summary.{ext}"
237 |
238 | mbpp = load_dataset("mbpp")
239 | mbpp = concatenate_datasets([mbpp[k] for k in mbpp.keys()])
240 | samples = get_dict_list(args.input_file)
241 | for sample_dict in tqdm(samples, desc="Evaluating and scoring..."):
242 | completion = sample_dict["completion"]
243 | prompt = sample_dict[args.prompt_column_name]
244 | completion = truncate_code(completion, prompt)
245 | entrypoint = get_entry_point(mbpp, sample_dict["task_id"])
246 | test_str = format_test(mbpp, entrypoint, sample_dict["task_id"])
247 | try:
248 | p, r = eval_code_wrapper(completion, test_str, entrypoint)
249 | except Exception as e:
250 | with io.StringIO() as f:
251 | traceback.print_exception(*sys.exc_info(), file=f)
252 | r = f.getvalue()
253 | p = False
254 | print(f"Caught exception from eval_code: {e}\n{r}")
255 | sample_dict["passed"] = p
256 | sample_dict["result"] = r
257 | num_corr_results = compute_results(samples)
258 | pass_at_k_results = compute_pass_at_ks(num_corr_results, ks)
259 | at_least_one_correct, _ = compute_at_least_one_pass_per_task(num_corr_results)
260 | pc_one_correct = at_least_one_correct / len(num_corr_results.keys())
261 | pass_at_k_results["% tasks with at least one passed completion"] = pc_one_correct
262 | print(pass_at_k_results)
263 |
264 | with open(output_file, "w") as f:
265 | for d in samples:
266 | f.write(json.dumps(d) + "\n")
267 | with open(output_summ_file, "w") as f:
268 | f.write(json.dumps(pass_at_k_results))
269 |
270 |
271 | def main(args):
272 | argsdict = vars(args)
273 | print(pprint.pformat(argsdict))
274 | eval_samples(args)
275 |
276 |
277 | if __name__ == "__main__":
278 | main(parse_args())
279 |
--------------------------------------------------------------------------------
/finetune.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # coding=utf-8
3 | """
4 | Fine-tuning CodeGen models on the input data.
5 | Adapted from a HuggingFace transformers example for training seq2seq models.
6 |
7 | Assumes that CodeGen model checkpoints are stored in {data_args.codegen_repo}/codegen-[6B|16B]-mono.
8 | """
9 | import os
10 |
11 | import sys
12 |
13 | import logging
14 | import torch
15 | from dataclasses import dataclass, field
16 | from typing import Dict, List, Optional
17 |
18 | import datasets
19 | from datasets import load_dataset, load_metric, DatasetDict
20 |
21 | from jaxformer.hf import sample # from the CodeGen repository
22 | from jaxformer.hf.codegen import modeling_codegen # from the CodeGen repository
23 |
24 | import transformers
25 | from transformers import (
26 | DataCollatorForSeq2Seq,
27 | HfArgumentParser,
28 | Seq2SeqTrainer,
29 | Seq2SeqTrainingArguments,
30 | set_seed,
31 | )
32 | from transformers.trainer_utils import (
33 | get_last_checkpoint,
34 | )
35 |
36 | logger = logging.getLogger(__name__)
37 |
38 |
39 | @dataclass
40 | class ModelArguments:
41 | """
42 | Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
43 | """
44 |
45 | model_name_or_path: str = field(
46 | default=None, metadata={"help": "Can be codegen-16B, or codegen-6B."}
47 | )
48 | config_name: Optional[str] = field(
49 | default=None,
50 | metadata={
51 | "help": "Pretrained config name or path if not the same as model_name"
52 | },
53 | )
54 | cache_dir: Optional[str] = field(
55 | default=None,
56 | metadata={
57 | "help": "Path to directory to store the pretrained models downloaded from huggingface.co"
58 | },
59 | )
60 | use_fast_tokenizer: bool = field(
61 | default=True,
62 | metadata={
63 | "help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."
64 | },
65 | )
66 | model_revision: str = field(
67 | default="main",
68 | metadata={
69 | "help": "The specific model version to use (can be a branch name, tag name or commit id)."
70 | },
71 | )
72 | use_auth_token: bool = field(
73 | default=False,
74 | metadata={
75 | "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
76 | "with private models)."
77 | },
78 | )
79 | parallelize: bool = field(
80 | default=False,
81 | )
82 |
83 |
84 | @dataclass
85 | class DataTrainingArguments:
86 | """
87 | Arguments pertaining to what data we are going to input our model for training and eval.
88 | """
89 |
90 | codegen_repo: Optional[str] = field(
91 | default=None,
92 | metadata={"help": "Path to the cloned SalesForce codegen repo."},
93 | )
94 | dataset_name: Optional[str] = field(
95 | default=None,
96 | metadata={"help": "The name of the dataset to use (via the datasets library)."},
97 | )
98 | dataset_config_name: Optional[str] = field(
99 | default=None,
100 | metadata={
101 | "help": "The configuration name of the dataset to use (via the datasets library)."
102 | },
103 | )
104 | prompt_column: Optional[str] = field(
105 | default="finetuning_prompt",
106 | metadata={
107 | "help": "The name of the column in the datasets containing the task prompt."
108 | },
109 | )
110 | completion_column: Optional[str] = field(
111 | default="finetuning_completion",
112 | metadata={
113 | "help": "The name of the column in the datasets containing the refinement of the code."
114 | },
115 | )
116 | train_file: Optional[str] = field(
117 | default=None,
118 | metadata={"help": "The input training data file (a text file)."},
119 | )
120 | validation_file: Optional[str] = field(
121 | default=None,
122 | metadata={
123 | "help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."
124 | },
125 | )
126 | test_file: Optional[str] = field(
127 | default=None,
128 | metadata={
129 | "help": "An optional input test data file to evaluate the perplexity on (a text file)."
130 | },
131 | )
132 | overwrite_cache: bool = field(
133 | default=False,
134 | metadata={"help": "Overwrite the cached training and evaluation sets"},
135 | )
136 | preprocessing_num_workers: Optional[int] = field(
137 | default=None,
138 | metadata={"help": "The number of processes to use for the preprocessing."},
139 | )
140 | max_seq_length: int = field(
141 | default=1024,
142 | metadata={
143 | "help": "The maximum total input sequence length after tokenization. Sequences longer "
144 | "than this will be truncated, sequences shorter will be padded."
145 | },
146 | )
147 | max_answer_length: int = field(
148 | default=1024,
149 | metadata={
150 | "help": "The maximum length of an answer that can be generated. This is needed because the start "
151 | "and end predictions are not conditioned on one another."
152 | },
153 | )
154 | val_max_answer_length: Optional[int] = field(
155 | default=None,
156 | metadata={
157 | "help": "The maximum total sequence length for validation target text after tokenization. Sequences longer "
158 | "than this will be truncated, sequences shorter will be padded. Will default to `max_answer_length`."
159 | "This argument is also used to override the ``max_length`` param of ``model.generate``, which is used "
160 | "during ``evaluate`` and ``predict``."
161 | },
162 | )
163 | pad_to_max_length: bool = field(
164 | default=True,
165 | metadata={
166 | "help": "Whether to pad all samples to `max_seq_length`. "
167 | "If False, will pad the samples dynamically when batching to the maximum length in the batch (which can "
168 | "be faster on GPU but will be slower on TPU)."
169 | },
170 | )
171 | max_train_samples: Optional[int] = field(
172 | default=None,
173 | metadata={
174 | "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
175 | "value if set."
176 | },
177 | )
178 | max_eval_samples: Optional[int] = field(
179 | default=None,
180 | metadata={
181 | "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
182 | "value if set."
183 | },
184 | )
185 | max_predict_samples: Optional[int] = field(
186 | default=None,
187 | metadata={
188 | "help": "For debugging purposes or quicker training, truncate the number of prediction examples to this "
189 | "value if set."
190 | },
191 | )
192 | version_2_with_negative: bool = field(
193 | default=False,
194 | metadata={"help": "If true, some of the examples do not have an answer."},
195 | )
196 | null_score_diff_threshold: float = field(
197 | default=0.0,
198 | metadata={
199 | "help": "The threshold used to select the null answer: if the best answer has a score that is less than "
200 | "the score of the null answer minus this threshold, the null answer is selected for this example. "
201 | "Only useful when `version_2_with_negative=True`."
202 | },
203 | )
204 | doc_stride: int = field(
205 | default=128,
206 | metadata={
207 | "help": "When splitting up a long document into chunks, how much stride to take between chunks."
208 | },
209 | )
210 | n_best_size: int = field(
211 | default=20,
212 | metadata={
213 | "help": "The total number of n-best predictions to generate when looking for an answer."
214 | },
215 | )
216 | num_beams: Optional[int] = field(
217 | default=5,
218 | metadata={
219 | "help": "Number of beams to use for evaluation. This argument will be passed to ``model.generate``, "
220 | "which is used during ``evaluate`` and ``predict``."
221 | },
222 | )
223 | ignore_pad_token_for_loss: bool = field(
224 | default=True,
225 | metadata={
226 | "help": "Whether to ignore the tokens corresponding to padded labels in the loss computation or not."
227 | },
228 | )
229 |
230 | def __post_init__(self):
231 | if (
232 | self.dataset_name is None
233 | and self.train_file is None
234 | and self.validation_file is None
235 | and self.test_file is None
236 | ):
237 | raise ValueError(
238 | "Need either a dataset name or a training/validation file/test_file."
239 | )
240 | else:
241 | if self.train_file is not None:
242 | extension = self.train_file.split(".")[-1]
243 | assert extension in [
244 | "csv",
245 | "json",
246 | "jsonl",
247 | ], "`train_file` should be a csv or a json file."
248 | if self.validation_file is not None:
249 | extension = self.validation_file.split(".")[-1]
250 | assert extension in [
251 | "csv",
252 | "json",
253 | ], "`validation_file` should be a csv or a json file."
254 | if self.test_file is not None:
255 | extension = self.test_file.split(".")[-1]
256 | assert extension in [
257 | "csv",
258 | "json",
259 | ], "`test_file` should be a csv or a json file."
260 | if self.val_max_answer_length is None:
261 | self.val_max_answer_length = self.max_answer_length
262 |
263 |
264 | question_answering_column_name_mapping = {
265 | "squad_v2": ("question", "context", "answer"),
266 | }
267 |
268 |
269 | def main():
270 | # See all possible arguments in src/transformers/training_args.py
271 | # or by passing the --help flag to this script.
272 | # We now keep distinct sets of args, for a cleaner separation of concerns.
273 |
274 | parser = HfArgumentParser(
275 | (ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments)
276 | )
277 | if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
278 | # If we pass only one argument to the script and it's the path to a json file,
279 | # let's parse it to get our arguments.
280 | model_args, data_args, training_args = parser.parse_json_file(
281 | json_file=os.path.abspath(sys.argv[1])
282 | )
283 | else:
284 | model_args, data_args, training_args = parser.parse_args_into_dataclasses()
285 |
286 | # Setup logging
287 | logging.basicConfig(
288 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
289 | datefmt="%m/%d/%Y %H:%M:%S",
290 | handlers=[logging.StreamHandler(sys.stdout)],
291 | )
292 |
293 | log_level = training_args.get_process_log_level()
294 | logger.setLevel(log_level)
295 | datasets.utils.logging.set_verbosity(log_level)
296 | transformers.utils.logging.set_verbosity(log_level)
297 | transformers.utils.logging.enable_default_handler()
298 | transformers.utils.logging.enable_explicit_format()
299 |
300 | # Log on each process the small summary:
301 | logger.warning(
302 | f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
303 | + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
304 | )
305 | logger.info(f"Training/evaluation parameters {training_args}")
306 |
307 | # Detecting last checkpoint.
308 | last_checkpoint = None
309 | if (
310 | os.path.isdir(training_args.output_dir)
311 | and training_args.do_train
312 | and not training_args.overwrite_output_dir
313 | ):
314 | last_checkpoint = get_last_checkpoint(training_args.output_dir)
315 | if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
316 | raise ValueError(
317 | f"Output directory ({training_args.output_dir}) already exists and is not empty. "
318 | "Use --overwrite_output_dir to overcome."
319 | )
320 | elif (
321 | last_checkpoint is not None and training_args.resume_from_checkpoint is None
322 | ):
323 | logger.info(
324 | f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
325 | "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
326 | )
327 |
328 | # Set seed before initializing model.
329 | set_seed(training_args.seed)
330 |
331 | # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
332 | # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
333 | # (the dataset will be downloaded automatically from the datasets Hub).
334 | #
335 | # For CSV/JSON files, this script will use the column called 'text' or the first column if no column called
336 | # 'text' is found. You can easily tweak this behavior (see below).
337 | #
338 | # In distributed training, the load_dataset function guarantee that only one local process can concurrently
339 | # download the dataset.
340 | if data_args.dataset_name is not None:
341 | # Downloading and loading a dataset from the hub.
342 | raw_datasets = load_dataset(
343 | data_args.dataset_name,
344 | data_args.dataset_config_name,
345 | cache_dir=model_args.cache_dir,
346 | )
347 | else:
348 | data_files = {}
349 | if data_args.train_file is not None:
350 | data_files["train"] = data_args.train_file
351 | extension = data_args.train_file.split(".")[-1]
352 | if extension == "jsonl":
353 | extension = "json"
354 |
355 | if data_args.validation_file is not None:
356 | data_files["validation"] = data_args.validation_file
357 | extension = data_args.validation_file.split(".")[-1]
358 | if data_args.test_file is not None:
359 | data_files["test"] = data_args.test_file
360 | extension = data_args.test_file.split(".")[-1]
361 | # raw_datasets = load_dataset(extension, data_files=data_files, cache_dir=model_args.cache_dir)
362 | if extension == "json":
363 | raw_datasets = DatasetDict.from_json(data_files)
364 | else:
365 | raw_datasets = DatasetDict.from_csv(data_files)
366 | # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
367 | # https://huggingface.co/docs/datasets/loading_datasets.html.
368 |
369 | # Load pretrained model and tokenizer
370 | #
371 | # Distributed training:
372 | # The .from_pretrained methods guarantee that only one local process can concurrently
373 | # download model & vocab.
374 |
375 | if model_args.model_name_or_path.startswith("codegen-"):
376 | if last_checkpoint is not None:
377 | model = modeling_codegen.CodeGenForCausalLM.from_pretrained(
378 | last_checkpoint, low_cpu_mem_usage=True
379 | )
380 | else:
381 | model = modeling_codegen.CodeGenForCausalLM.from_pretrained(
382 | f"{data_args.codegen_repo}/{model_args.model_name_or_path}-mono",
383 | low_cpu_mem_usage=True,
384 | )
385 | ## IMPORTANT: DO NOT REMOVE
386 | model = model.to(torch.float32)
387 |
388 | tokenizer = sample.create_custom_gpt2_tokenizer()
389 | # tokenizer.padding_side = 'left'
390 | tokenizer.pad_token = 50256
391 | if model_args.parallelize:
392 | model.parallelize()
393 | else:
394 | model = model.cuda()
395 | else:
396 | raise ValueError(
397 | f"{model_args.model_name_or_path} is not a valid model name or path."
398 | )
399 |
400 | model.resize_token_embeddings(len(tokenizer))
401 |
402 | # Preprocessing the datasets.
403 | # We need to generate and tokenize inputs and targets.
404 | if training_args.do_train:
405 | column_names = list(raw_datasets["train"].features.keys())
406 | elif training_args.do_eval:
407 | column_names = list(raw_datasets["validation"].features.keys())
408 | elif training_args.do_predict:
409 | column_names = list(raw_datasets["test"].features.keys())
410 | else:
411 | logger.info(
412 | "There is nothing to do. Please pass `do_train`, `do_eval` and/or `do_predict`."
413 | )
414 | return
415 |
416 | # Get the column names for input/target.
417 | dataset_columns = question_answering_column_name_mapping.get(
418 | data_args.dataset_name, None
419 | )
420 | if data_args.prompt_column is None:
421 | prompt_column = (
422 | dataset_columns[0] if dataset_columns is not None else column_names[0]
423 | )
424 | else:
425 | prompt_column = data_args.prompt_column
426 | if prompt_column not in column_names:
427 | raise ValueError(
428 | f"--prompt_column' value '{data_args.prompt_column}' needs to be one of: {', '.join(column_names)}"
429 | )
430 | if data_args.completion_column is None:
431 | completion_column = (
432 | dataset_columns[2] if dataset_columns is not None else column_names[2]
433 | )
434 | else:
435 | completion_column = data_args.completion_column
436 | if completion_column not in column_names:
437 | raise ValueError(
438 | f"--completion_column' value '{data_args.completion_column}' needs to be one of: {', '.join(column_names)}"
439 | )
440 |
441 | # Temporarily set max_answer_length for training.
442 | max_answer_length = data_args.max_answer_length
443 | padding = "max_length" if data_args.pad_to_max_length else False
444 |
445 | if training_args.label_smoothing_factor > 0 and not hasattr(
446 | model, "prepare_decoder_input_ids_from_labels"
447 | ):
448 | logger.warning(
449 | "label_smoothing is enabled but the `prepare_decoder_input_ids_from_labels` method is not defined for"
450 | f"`{model.__class__.__name__}`. This will lead to loss being calculated twice and will take up more memory"
451 | )
452 |
453 | if data_args.max_seq_length > tokenizer.model_max_length:
454 | logger.warning(
455 | f"The max_seq_length passed ({data_args.max_seq_length}) is larger than the maximum length for the"
456 | f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}."
457 | )
458 | max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)
459 |
460 | def truncate(ex, tokenizer, max_length):
461 | return tokenizer.decode(
462 | tokenizer(ex, max_length=max_length, truncation=True).input_ids
463 | )
464 |
465 | def preprocess_example(example):
466 | input_str = truncate(example[prompt_column], tokenizer, max_seq_length)
467 | r = example[completion_column]
468 | input_token_ids = tokenizer.encode(input_str, verbose=False)
469 | target_token_ids = tokenizer.encode(r, verbose=False) + [tokenizer.eos_token_id]
470 | input_ids = input_token_ids + target_token_ids
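        # Mask the prompt positions with -100 so that the loss is computed only on the completion tokens.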
471 | labels_input_ids = ([-100] * len(input_token_ids)) + target_token_ids
472 |
473 | if len(input_ids) > max_seq_length:
474 | input_ids = input_ids[:max_seq_length]
475 | labels_input_ids = labels_input_ids[:max_seq_length]
476 | return {
477 | "input_ids": torch.IntTensor(input_ids).cuda(),
478 | "labels": torch.IntTensor(labels_input_ids).cuda(),
479 | }
480 |
481 | if training_args.do_train:
482 | if "train" not in raw_datasets:
483 | raise ValueError("--do_train requires a train dataset")
484 | train_dataset = raw_datasets["train"]
485 | if data_args.max_train_samples is not None:
486 | max_train_samples = min(len(train_dataset), data_args.max_train_samples)
487 | train_dataset = train_dataset.select(range(max_train_samples))
488 | with training_args.main_process_first(desc="train dataset map pre-processing"):
489 | train_dataset = train_dataset.map(
490 | preprocess_example,
491 | remove_columns=column_names,
492 | )
493 | if data_args.max_train_samples is not None:
494 |             # The number of samples might increase during feature creation, so we select only the specified max samples.
495 | max_train_samples = min(len(train_dataset), data_args.max_train_samples)
496 | train_dataset = train_dataset.select(range(max_train_samples))
497 |
498 | # Data collator
499 | label_pad_token_id = (
500 | -100 if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id
501 | )
502 | data_collator = DataCollatorForSeq2Seq(
503 | tokenizer,
504 | model=model,
505 | label_pad_token_id=label_pad_token_id,
506 | pad_to_multiple_of=8 if training_args.fp16 else None,
507 | )
508 |
509 | # Initialize our Trainer
510 | trainer = Seq2SeqTrainer(
511 | model=model,
512 | args=training_args,
513 | train_dataset=train_dataset if training_args.do_train else None,
514 | tokenizer=tokenizer,
515 | data_collator=data_collator,
516 | )
517 |
518 | old_collator = trainer.data_collator
519 | trainer.data_collator = lambda data: dict(old_collator(data))
520 |
521 | # Training
522 | if training_args.do_train:
523 | train_result = trainer.train()
524 | trainer.save_model() # Saves the tokenizer too for easy upload
525 |
526 | metrics = train_result.metrics
527 | max_train_samples = (
528 | data_args.max_train_samples
529 | if data_args.max_train_samples is not None
530 | else len(train_dataset)
531 | )
532 | metrics["train_samples"] = min(max_train_samples, len(train_dataset))
533 |
534 | trainer.log_metrics("train", metrics)
535 | trainer.save_metrics("train", metrics)
536 | trainer.save_state()
537 |
538 |
539 | def _mp_fn(index):
540 | # For xla_spawn (TPUs)
541 | main()
542 |
543 |
544 | if __name__ == "__main__":
545 | main()
546 |
--------------------------------------------------------------------------------
/finetune_refinement_model.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | """
3 | Fine-tuning transformers models to generate refinements given old code and NL feedback.
4 | Adapted from a HuggingFace transformers example for training seq2seq models.
5 |
6 | Assumes that CodeGen model checkpoints are stored in {model_args.codegen_model_dir}/codegen-[6B|16B]-mono.
7 | """
8 | import os
9 |
10 |
11 | import sys
12 | import logging
13 | import json
14 | import torch
15 | from dataclasses import dataclass, field
16 | from typing import Dict, List, Optional, Tuple
17 |
18 | import datasets
19 | from datasets import load_dataset, load_metric
20 |
21 | from jaxformer.hf import sample
22 | from jaxformer.hf.codegen import modeling_codegen
23 |
24 | from tqdm import tqdm
25 |
26 | import transformers
27 | from transformers import (
28 | DataCollatorForSeq2Seq,
29 | HfArgumentParser,
30 | Seq2SeqTrainer,
31 | Seq2SeqTrainingArguments,
32 | set_seed,
33 | )
34 | from transformers.trainer_utils import (
35 | get_last_checkpoint,
36 | )
37 | from transformers.utils import check_min_version
38 | from transformers.utils.versions import require_version
39 |
40 | from torch.utils.data import Dataset
41 |
42 | # Will error if the minimal version of Transformers is not installed.
43 | check_min_version("4.12.5")
44 |
45 | logger = logging.getLogger(__name__)
46 |
47 |
48 | @dataclass
49 | class ModelArguments:
50 | """
51 | Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
52 | """
53 |
54 | codegen_model_dir: Optional[str] = field(
55 | default="checkpoints",
56 | metadata={
57 | "help": "Path to directory containing CodeGen model checkpoints."
58 | "Assumes the model checkpoints are stored in {codegen_model_dir}/."
59 | },
60 | )
61 | model_name_or_path: str = field(
62 | default=None, metadata={"help": "Can be codegen-16B or codegen-6B."}
63 | )
64 | config_name: Optional[str] = field(
65 | default=None,
66 | metadata={
67 | "help": "Pretrained config name or path if not the same as model_name"
68 | },
69 | )
70 | cache_dir: Optional[str] = field(
71 | default=None,
72 | metadata={
73 | "help": "Path to directory to store the pretrained models downloaded from huggingface.co"
74 | },
75 | )
76 | use_fast_tokenizer: bool = field(
77 | default=True,
78 | metadata={
79 | "help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."
80 | },
81 | )
82 | model_revision: str = field(
83 | default="main",
84 | metadata={
85 | "help": "The specific model version to use (can be a branch name, tag name or commit id)."
86 | },
87 | )
88 | use_auth_token: bool = field(
89 | default=False,
90 | metadata={
91 | "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
92 | "with private models)."
93 | },
94 | )
95 | parallelize: bool = field(
96 | default=False,
97 | )
98 |
99 |
100 | @dataclass
101 | class DataTrainingArguments:
102 | """
103 | Arguments pertaining to what data we are going to input our model for training and eval.
104 | """
105 |
106 | dataset_name: Optional[str] = field(
107 | default=None,
108 | metadata={"help": "The name of the dataset to use (via the datasets library)."},
109 | )
110 | dataset_config_name: Optional[str] = field(
111 | default=None,
112 | metadata={
113 | "help": "The configuration name of the dataset to use (via the datasets library)."
114 | },
115 | )
116 | feedback_column: Optional[str] = field(
117 | default="Feedback",
118 | metadata={
119 | "help": "The name of the column in the datasets containing the NL feedback (for code refinement)."
120 | },
121 | )
122 | question_column: Optional[str] = field(
123 | default="completion",
124 | metadata={
125 | "help": "The name of the column in the datasets containing the original task description and code."
126 | },
127 | )
128 | refinement_column: Optional[str] = field(
129 | default="Refinement",
130 | metadata={
131 | "help": "The name of the column in the datasets containing the refinement of the code."
132 | },
133 | )
134 | train_file: Optional[str] = field(
135 | default=None,
136 | metadata={"help": "The input training data file (a text file)."},
137 | )
138 | validation_file: Optional[str] = field(
139 | default=None,
140 | metadata={
141 | "help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."
142 | },
143 | )
144 | test_file: Optional[str] = field(
145 | default=None,
146 | metadata={
147 | "help": "An optional input test data file to evaluate the perplexity on (a text file)."
148 | },
149 | )
150 | overwrite_cache: bool = field(
151 | default=False,
152 | metadata={"help": "Overwrite the cached training and evaluation sets"},
153 | )
154 | preprocessing_num_workers: Optional[int] = field(
155 | default=None,
156 | metadata={"help": "The number of processes to use for the preprocessing."},
157 | )
158 | max_seq_length: int = field(
159 | default=1024,
160 | metadata={
161 | "help": "The maximum total input sequence length after tokenization. Sequences longer "
162 | "than this will be truncated, sequences shorter will be padded."
163 | },
164 | )
165 | max_answer_length: int = field(
166 | default=1024,
167 | metadata={
168 | "help": "The maximum length of an answer that can be generated. This is needed because the start "
169 | "and end predictions are not conditioned on one another."
170 | },
171 | )
172 | val_max_answer_length: Optional[int] = field(
173 | default=None,
174 | metadata={
175 | "help": "The maximum total sequence length for validation target text after tokenization. Sequences longer "
176 | "than this will be truncated, sequences shorter will be padded. Will default to `max_answer_length`."
177 | "This argument is also used to override the ``max_length`` param of ``model.generate``, which is used "
178 | "during ``evaluate`` and ``predict``."
179 | },
180 | )
181 | pad_to_max_length: bool = field(
182 | default=True,
183 | metadata={
184 | "help": "Whether to pad all samples to `max_seq_length`. "
185 | "If False, will pad the samples dynamically when batching to the maximum length in the batch (which can "
186 | "be faster on GPU but will be slower on TPU)."
187 | },
188 | )
189 | max_train_samples: Optional[int] = field(
190 | default=None,
191 | metadata={
192 | "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
193 | "value if set."
194 | },
195 | )
196 | max_eval_samples: Optional[int] = field(
197 | default=None,
198 | metadata={
199 | "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
200 | "value if set."
201 | },
202 | )
203 | max_predict_samples: Optional[int] = field(
204 | default=None,
205 | metadata={
206 | "help": "For debugging purposes or quicker training, truncate the number of prediction examples to this "
207 | "value if set."
208 | },
209 | )
210 | version_2_with_negative: bool = field(
211 | default=False,
212 | metadata={"help": "If true, some of the examples do not have an answer."},
213 | )
214 | null_score_diff_threshold: float = field(
215 | default=0.0,
216 | metadata={
217 | "help": "The threshold used to select the null answer: if the best answer has a score that is less than "
218 | "the score of the null answer minus this threshold, the null answer is selected for this example. "
219 | "Only useful when `version_2_with_negative=True`."
220 | },
221 | )
222 | doc_stride: int = field(
223 | default=128,
224 | metadata={
225 | "help": "When splitting up a long document into chunks, how much stride to take between chunks."
226 | },
227 | )
228 | n_best_size: int = field(
229 | default=20,
230 | metadata={
231 | "help": "The total number of n-best predictions to generate when looking for an answer."
232 | },
233 | )
234 | num_beams: Optional[int] = field(
235 | default=5,
236 | metadata={
237 | "help": "Number of beams to use for evaluation. This argument will be passed to ``model.generate``, "
238 | "which is used during ``evaluate`` and ``predict``."
239 | },
240 | )
241 | ignore_pad_token_for_loss: bool = field(
242 | default=True,
243 | metadata={
244 | "help": "Whether to ignore the tokens corresponding to padded labels in the loss computation or not."
245 | },
246 | )
247 |
248 | def __post_init__(self):
249 | if (
250 | self.dataset_name is None
251 | and self.train_file is None
252 | and self.validation_file is None
253 | and self.test_file is None
254 | ):
255 | raise ValueError(
256 | "Need either a dataset name or a training/validation file/test_file."
257 | )
258 | else:
259 | if self.train_file is not None:
260 | extension = self.train_file.split(".")[-1]
261 | assert extension in [
262 | "csv",
263 | "json",
264 | "jsonl",
265 | ], "`train_file` should be a csv or a json file."
266 | if self.validation_file is not None:
267 | extension = self.validation_file.split(".")[-1]
268 | assert extension in [
269 | "csv",
270 | "json",
271 | ], "`validation_file` should be a csv or a json file."
272 | if self.test_file is not None:
273 | extension = self.test_file.split(".")[-1]
274 | assert extension in [
275 | "csv",
276 | "json",
277 | ], "`test_file` should be a csv or a json file."
278 | if self.val_max_answer_length is None:
279 | self.val_max_answer_length = self.max_answer_length
280 |
281 |
282 | def main():
283 | # See all possible arguments by passing the --help flag to this script.
284 | parser = HfArgumentParser(
285 | (ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments)
286 | )
287 | if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
288 | # If we pass only one argument to the script and it's the path to a json file,
289 | # let's parse it to get our arguments.
290 | model_args, data_args, training_args = parser.parse_json_file(
291 | json_file=os.path.abspath(sys.argv[1])
292 | )
293 | else:
294 | model_args, data_args, training_args = parser.parse_args_into_dataclasses()
295 |
296 | # Setup logging
297 | logging.basicConfig(
298 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
299 | datefmt="%m/%d/%Y %H:%M:%S",
300 | handlers=[logging.StreamHandler(sys.stdout)],
301 | )
302 |
303 | log_level = training_args.get_process_log_level()
304 | logger.setLevel(log_level)
305 | datasets.utils.logging.set_verbosity(log_level)
306 | transformers.utils.logging.set_verbosity(log_level)
307 | transformers.utils.logging.enable_default_handler()
308 | transformers.utils.logging.enable_explicit_format()
309 |
310 | logger.warning(
311 | f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
312 | + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
313 | )
314 | logger.info(f"Training/evaluation parameters {training_args}")
315 |
316 | # Detecting last checkpoint.
317 | last_checkpoint = None
318 | if (
319 | os.path.isdir(training_args.output_dir)
320 | and training_args.do_train
321 | and not training_args.overwrite_output_dir
322 | ):
323 | last_checkpoint = get_last_checkpoint(training_args.output_dir)
324 | if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
325 | raise ValueError(
326 | f"Output directory ({training_args.output_dir}) already exists and is not empty. "
327 | "Use --overwrite_output_dir to overcome."
328 | )
329 | elif (
330 | last_checkpoint is not None and training_args.resume_from_checkpoint is None
331 | ):
332 | logger.info(
333 | f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
334 | "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
335 | )
336 |
337 | set_seed(training_args.seed)
338 |
339 | if data_args.dataset_name is not None:
340 | # Downloading and loading a dataset from the hub.
341 | raw_datasets = load_dataset(
342 | data_args.dataset_name,
343 | data_args.dataset_config_name,
344 | cache_dir=model_args.cache_dir,
345 | )
346 | else:
347 | data_files = {}
348 | if data_args.train_file is not None:
349 | data_files["train"] = data_args.train_file
350 | extension = data_args.train_file.split(".")[-1]
351 | if extension == "jsonl":
352 | extension = "json"
353 |
354 | if data_args.validation_file is not None:
355 | data_files["validation"] = data_args.validation_file
356 | extension = data_args.validation_file.split(".")[-1]
357 | if data_args.test_file is not None:
358 | data_files["test"] = data_args.test_file
359 | extension = data_args.test_file.split(".")[-1]
360 | raw_datasets = load_dataset(
361 | extension, data_files=data_files, cache_dir=model_args.cache_dir
362 | )
363 |
364 | if model_args.model_name_or_path.startswith("codegen-"):
365 | if last_checkpoint is not None:
366 | model = modeling_codegen.CodeGenForCausalLM.from_pretrained(
367 | last_checkpoint, low_cpu_mem_usage=True
368 | )
369 | else:
370 | model = modeling_codegen.CodeGenForCausalLM.from_pretrained(
371 | f"{model_args.codegen_model_dir}/{model_args.model_name_or_path}-mono",
372 | low_cpu_mem_usage=True,
373 | )
374 | ## IMPORTANT: DO NOT REMOVE
375 | model = model.to(torch.float32)
376 |
377 | tokenizer = sample.create_custom_gpt2_tokenizer()
378 | tokenizer.pad_token = 50256
379 | if model_args.parallelize:
380 | model.parallelize()
381 | else:
382 | model = model.cuda()
383 | else:
384 | raise ValueError(
385 | f"{model_args.model_name_or_path} is not a valid model name or path."
386 | )
387 |
388 | model.resize_token_embeddings(len(tokenizer))
389 |
390 | if training_args.do_train:
391 | column_names = raw_datasets["train"].column_names
392 | elif training_args.do_eval:
393 | column_names = raw_datasets["validation"].column_names
394 | elif training_args.do_predict:
395 | column_names = raw_datasets["test"].column_names
396 | else:
397 | logger.info(
398 | "There is nothing to do. Please pass `do_train`, `do_eval` and/or `do_predict`."
399 | )
400 | return
401 |
402 | # Get the column names for input/target.
403 | if data_args.question_column is None:
404 | question_column = column_names[0]
405 | else:
406 | question_column = data_args.question_column
407 | if question_column not in column_names:
408 | raise ValueError(
409 |                 f"'--question_column' value '{data_args.question_column}' needs to be one of: {', '.join(column_names)}"
410 | )
411 | if data_args.feedback_column is None:
412 | feedback_column = column_names[1]
413 | else:
414 | feedback_column = data_args.feedback_column
415 | if feedback_column not in column_names:
416 | raise ValueError(
417 |                 f"'--feedback_column' value '{data_args.feedback_column}' needs to be one of: {', '.join(column_names)}"
418 | )
419 | if data_args.refinement_column is None:
420 | refinement_column = column_names[2]
421 | else:
422 | refinement_column = data_args.refinement_column
423 | if refinement_column not in column_names:
424 | raise ValueError(
425 |                 f"'--refinement_column' value '{data_args.refinement_column}' needs to be one of: {', '.join(column_names)}"
426 | )
427 |
428 | # Temporarily set max_answer_length for training.
429 | max_answer_length = data_args.max_answer_length
430 | padding = "max_length" if data_args.pad_to_max_length else False
431 |
432 | if training_args.label_smoothing_factor > 0 and not hasattr(
433 | model, "prepare_decoder_input_ids_from_labels"
434 | ):
435 | logger.warning(
436 |             "label_smoothing is enabled but the `prepare_decoder_input_ids_from_labels` method is not defined for "
437 | f"`{model.__class__.__name__}`. This will lead to loss being calculated twice and will take up more memory"
438 | )
439 |
440 | if data_args.max_seq_length > tokenizer.model_max_length:
441 | logger.warning(
442 |             f"The max_seq_length passed ({data_args.max_seq_length}) is larger than the maximum length for the "
443 | f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}."
444 | )
445 | max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)
446 |
447 | def truncate(ex, tokenizer, max_length):
448 | return tokenizer.decode(
449 | tokenizer(ex, max_length=max_length, truncation=True).input_ids
450 | )
451 |
452 | def preprocess_example(example):
453 | # Encode prompt prefix and suffix
454 | f = example[feedback_column]
455 | input_prefix = "OLD CODE:\n"
456 | prefix_encoded = tokenizer.encode(input_prefix, verbose=False)
457 | input_suffix = f"\n\nFEEDBACK:\n{f}\n\nREFINEMENT:\n"
458 | suffix_encoded = tokenizer.encode(input_suffix, verbose=False)
459 |
460 | # Encode the refinement
461 | r = example[refinement_column]
462 | target_token_ids = tokenizer.encode(r, verbose=False) + [tokenizer.eos_token_id]
463 |
464 | # We only truncate the old code
465 | q_max_length = (
466 | max_seq_length
467 | - len(prefix_encoded)
468 | - len(suffix_encoded)
469 | - len(target_token_ids)
470 | )
471 | q_encoded = tokenizer.encode(example[question_column], verbose=False)[
472 | :q_max_length
473 | ]
474 | input_token_ids = prefix_encoded + q_encoded + suffix_encoded
475 |
476 | # Combine everything
477 | input_ids = input_token_ids + target_token_ids
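 |         # Label tokens set to -100 are ignored by the loss, so only the refinement (and EOS) tokens are trained on.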
478 | labels_input_ids = ([-100] * len(input_token_ids)) + target_token_ids
479 |
480 | if len(input_ids) > max_seq_length:
481 | input_ids = input_ids[:max_seq_length]
482 | labels_input_ids = labels_input_ids[:max_seq_length]
483 | return {
484 | "input_ids": torch.IntTensor(input_ids).cuda(),
485 | "labels": torch.IntTensor(labels_input_ids).cuda(),
486 | }
487 |
488 | if training_args.do_train:
489 | if "train" not in raw_datasets:
490 | raise ValueError("--do_train requires a train dataset")
491 | train_dataset = raw_datasets["train"]
492 | if data_args.max_train_samples is not None:
493 | max_train_samples = min(len(train_dataset), data_args.max_train_samples)
494 | train_dataset = train_dataset.select(range(max_train_samples))
495 | with training_args.main_process_first(desc="train dataset map pre-processing"):
496 | train_dataset = train_dataset.filter(
497 | lambda e: e["Refinement"] is not None and e["Refinement"]
498 | ).map(
499 | preprocess_example,
500 | num_proc=data_args.preprocessing_num_workers,
501 | remove_columns=column_names,
502 | load_from_cache_file=not data_args.overwrite_cache,
503 | desc="Running tokenizer on train dataset",
504 | )
505 | if data_args.max_train_samples is not None:
506 |             # The number of samples might increase during feature creation, so select only the specified max samples
507 | max_train_samples = min(len(train_dataset), data_args.max_train_samples)
508 | train_dataset = train_dataset.select(range(max_train_samples))
509 |
510 | # Data collator
511 | label_pad_token_id = (
512 | -100 if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id
513 | )
514 | data_collator = DataCollatorForSeq2Seq(
515 | tokenizer,
516 | model=model,
517 | label_pad_token_id=label_pad_token_id,
518 | pad_to_multiple_of=8 if training_args.fp16 else None,
519 | )
520 |
521 | # Initialize our Trainer
522 | trainer = Seq2SeqTrainer(
523 | model=model,
524 | args=training_args,
525 | train_dataset=train_dataset if training_args.do_train else None,
526 | tokenizer=tokenizer,
527 | data_collator=data_collator,
528 | )
529 |
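 |     # Re-wrap the collator so that each batch reaches the model as a plain dict rather than a
 |     # BatchEncoding object (presumably required by the custom CodeGen forward pass).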
530 | old_collator = trainer.data_collator
531 | trainer.data_collator = lambda data: dict(old_collator(data))
532 |
533 | # Training
534 | if training_args.do_train:
535 | train_result = trainer.train()
536 | trainer.save_model()
537 |
538 | metrics = train_result.metrics
539 | max_train_samples = (
540 | data_args.max_train_samples
541 | if data_args.max_train_samples is not None
542 | else len(train_dataset)
543 | )
544 | metrics["train_samples"] = min(max_train_samples, len(train_dataset))
545 |
546 | trainer.log_metrics("train", metrics)
547 | trainer.save_metrics("train", metrics)
548 | trainer.save_state()
549 |
550 |
551 | def _mp_fn(index):
552 | # For xla_spawn (TPUs)
553 | main()
554 |
555 |
556 | if __name__ == "__main__":
557 | main()
558 |
--------------------------------------------------------------------------------
/generate_code_for_mbpp.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import logging
4 | import openai
5 | import os
6 | import pprint
7 | import re
8 | import time
9 | import torch
10 |
11 | from jaxformer.hf import sample
12 | from jaxformer.hf.codegen import modeling_codegen
13 | from datasets import load_dataset, concatenate_datasets
14 | from tqdm import tqdm
15 |
16 |
17 | def format_prompt(task_id, text, tests, sample_code, num_prompts):
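 |     # The assembled prompt has roughly this shape (illustrative; the task text and
 |     # function name come from the MBPP example itself):
 |     #     """
 |     #     <MBPP task description>
 |     #
 |     #     >>> Example: some_fn(arg1, arg2) = expected_output
 |     #     """
 |     #
 |     #     def some_fn(arg1, arg2):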
18 | # Create prompt from scratch
19 | prompt = f'"""\n{text}\n\n'
20 | if num_prompts > 0:
21 | for i in range(num_prompts):
22 | example = tests[i].split("assert ")[-1].replace("==", "=")
23 | prompt += f">>> Example: {example}\n"
24 |
25 | # Add code prefix
26 | fn_name = tests[0].split("assert ")[-1].split("(")[0]
27 | fn_search = re.search(f"def {fn_name}\(.*\):", sample_code)
28 | if fn_search is None:
29 | raise ValueError(
30 | f"Could not find 'def {fn_name}\(.*\):' in code for task {task_id}."
31 | )
32 | code_prefix = sample_code[: fn_search.end()]
33 | prompt = f'{prompt}"""\n\n{code_prefix}\n'
34 | return prompt
35 |
36 |
37 | # GPT-J
38 | def sample_code_from_gpt_models(args, prompt, model, tokenizer):
39 | output_strs = []
40 | num_samples = args.num_samples
41 | temperature = args.temperature
42 | debug = args.debug
43 | try:
44 | with torch.no_grad():
45 | input_ids = (
46 | torch.LongTensor(tokenizer.encode(prompt, verbose=False))
47 | .unsqueeze(0)
48 | .cuda()
49 | )
50 | output_ids = model.generate(
51 | input_ids,
52 | do_sample=True,
53 | temperature=temperature, # 0.2, 0.8
54 | max_length=1024 - len(input_ids),
55 | num_return_sequences=num_samples,
56 | )
57 | output_strs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
58 | if debug:
59 | print(f"Input: {prompt}")
60 | print(f"Outputs: {output_strs}")
61 | except Exception as e:
62 | if (
63 | isinstance(e, UnboundLocalError)
64 | and str(e) == "local variable 'next_tokens' referenced before assignment"
65 | ):
66 | # See https://github.com/huggingface/transformers/issues/5118
67 | if debug:
68 | print("Problem text was > 1024 tokens, so cannot do generation")
69 | print(e)
70 | print(e)
71 | return output_strs
72 |
73 |
74 | def sample_code_from_codegen(args, prompt, model, tokenizer):
75 | device = "cuda:0"
76 | completions = []
77 | input_ids = tokenizer(
78 | prompt, truncation=True, max_length=1024, return_tensors="pt"
79 | ).input_ids.cuda()
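 |     # Greedy decoding is deterministic, so a single sample is enough when temperature is 0.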
80 | if args.temperature == 0.0:
81 | args.num_samples = 1
82 | for i in range(args.num_samples):
83 | try:
84 |             # Note: max_length passed to generate() is the total sequence length, i.e. the input IDs plus up to 1024 newly generated tokens
85 | if args.temperature > 0:
86 | tokens = model.generate(
87 | input_ids,
88 | do_sample=True,
89 | num_return_sequences=1,
90 | max_length=input_ids.shape[1] + 1024,
91 | temperature=args.temperature,
92 | use_cache=True,
93 | )
94 | else:
95 | tokens = model.generate(
96 | input_ids,
97 | num_return_sequences=1,
98 | max_length=input_ids.shape[1] + 1024,
99 | use_cache=True,
100 | )
101 | text = tokenizer.decode(tokens[0])
102 | if "<|endoftext|>" in text:
103 | text = text[: text.find("<|endoftext|>")]
104 | completions.append(text)
105 | except RuntimeError as e:
106 | logging.error(f"Could not sample from model: {e}")
107 | return completions
108 |
109 |
110 | def initialize_openai(args):
111 | api_key = open(f"{args.openai_creds_dir}/openai_api_key.txt").read()
112 | openai.organization = open(
113 | f"{args.openai_creds_dir}/openai_organization_id.txt"
114 | ).read()
115 | openai.api_key = api_key
116 |
117 |
118 | def sample_code_from_openai_model(args, prompt_text):
119 | output_strs = []
120 | start = time.time()
121 |
122 | arch_mapping = {
123 | "codex": "code-davinci-002",
124 | "gpt3": "text-davinci-001",
125 | "davinci-002": "text-davinci-002",
126 | "davinci-003": "text-davinci-003",
127 | "ada": "text-ada-001",
128 | "babbage": "text-babbage-001",
129 | "curie": "text-curie-001",
130 | }
131 | engine_name = arch_mapping[args.arch]
132 |
133 | for i in range(args.num_samples):
134 | while time.time() - start < args.max_request_time:
135 | try:
136 | response = openai.Completion.create(
137 | engine=engine_name,
138 | prompt=prompt_text,
139 | max_tokens=1024,
140 | n=1,
141 | temperature=args.temperature,
142 | )
143 | output_strs += [
144 | prompt_text + choice["text"] for choice in response["choices"]
145 | ]
146 | break
147 | except Exception as e:
148 | print(
149 | f"Unexpected exception in generating solution. Sleeping again: {e}"
150 | )
151 | time.sleep(args.sleep_time)
152 | return output_strs
153 |
154 |
155 | def write_jsonl(data, output_filepath):
156 | with open(output_filepath, "w") as f:
157 | for row in data:
158 | f.write(json.dumps(row) + "\n")
159 |
160 |
161 | def generate_code_for_problems(args):
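 |     # Load MBPP and merge all of its splits so that tasks can be selected purely by task ID.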
162 | mbpp = load_dataset("mbpp")
163 | mbpp = concatenate_datasets([mbpp[k] for k in mbpp.keys()])
164 |
165 | output = []
166 | if args.arch in ["gpt3", "codex"]:
167 | initialize_openai(args)
168 | generate_code_fn = sample_code_from_openai_model
169 | elif args.arch in ["codegen-6B", "codegen-16B"]:
170 | if args.model_path is None:
171 | model = modeling_codegen.CodeGenForCausalLM.from_pretrained(
172 | f"{args.codegen_model_dir}/{args.arch}-mono",
173 | revision="float16",
174 | torch_dtype=torch.float16,
175 | low_cpu_mem_usage=True,
176 | ).cuda()
177 | else:
178 | model = modeling_codegen.CodeGenForCausalLM.from_pretrained(
179 | args.model_path, low_cpu_mem_usage=True, torch_dtype=torch.float32
180 | ).cuda()
181 | tokenizer = sample.create_custom_gpt2_tokenizer(truncation_side="left")
182 | tokenizer.padding_side = "left"
183 | tokenizer.pad_token = 50256
184 | generate_code_fn = lambda args, prompt: sample_code_from_codegen(
185 | args, prompt, model, tokenizer
186 | )
187 |
188 | task_ids_range = set(range(args.start, args.end))
189 | for i in tqdm(range(len(mbpp))):
190 | if mbpp["task_id"][i] not in task_ids_range:
191 | continue
192 | try:
193 | prompt = format_prompt(
194 | mbpp["task_id"][i],
195 | mbpp["text"][i],
196 | mbpp["test_list"][i],
197 | mbpp["code"][i],
198 | args.num_shots,
199 | )
200 | except ValueError as e:
201 | logging.error(e)
202 | continue
203 |
204 | task_id = mbpp["task_id"][i]
205 | for completion in generate_code_fn(args, prompt):
206 | output.append(
207 | {
208 | "task_id": task_id,
209 | "prompt": prompt,
210 | "completion": completion,
211 | }
212 | )
213 | return output
214 |
215 |
216 | def parse_args():
217 | parser = argparse.ArgumentParser(
218 | description="Run a trained model to generate Python code for the MBPP benchmark."
219 | )
220 | parser.add_argument(
221 | "--arch",
222 | default="gptj",
223 | choices=[
224 | "gptj",
225 | "codex",
226 | "gpt3",
227 | "codegen-16B",
228 | "codegen-6B",
229 | "davinci-002",
230 | "davinci-003",
231 | "ada",
232 | "babbage",
233 | "curie",
234 | ],
235 | )
236 | parser.add_argument(
237 | "--codegen-model-dir",
238 | default="checkpoints",
239 | help="Directory where pre-trained CodeGen model checkpoints are saved.",
240 | )
241 | parser.add_argument(
242 | "--model-path",
243 | default=None,
244 | help="Directory to load model checkpoint from. If None, will load a pre-trained "
245 | "CodeGen model using the --arch argument instead.",
246 | )
247 | parser.add_argument("--num-samples", default=1, type=int)
248 | parser.add_argument("-d", "--debug", action="store_true")
249 | parser.add_argument("--output-dir", type=str)
250 | parser.add_argument("--output-file-suffix", type=str, default="")
251 | parser.add_argument("--temperature", default=0.8, type=float)
252 | parser.add_argument(
253 | "--split",
254 | default="test",
255 | type=str,
256 | help="Which MBPP split to use. In datasets v1.16.1, MBPP only has the split 'test'.",
257 | )
258 | parser.add_argument(
259 | "-s", "--start", default=1, type=int, help="Task ID to start with."
260 | )
261 | parser.add_argument(
262 | "-e", "--end", default=975, type=int, help="Task ID to end with (exclusive)."
263 | )
264 | parser.add_argument(
265 | "-n",
266 | "--num-shots",
267 | default=0,
268 | type=int,
269 |         help="Number of assert statements (test examples) to include in the task description.",
270 | )
271 | parser.add_argument(
272 | "--max-request-time",
273 | type=int,
274 | default=80,
275 | help="Max. time to wait for a successful GPT-3 request.",
276 | )
277 | parser.add_argument(
278 | "--sleep-time",
279 | type=int,
280 | default=10,
281 | help="Time to sleep (in seconds) between each GPT-3 call.",
282 | )
283 | parser.add_argument(
284 | "--openai-creds-dir",
285 | type=str,
286 | default=None,
287 | help="Directory where OpenAI API credentials are stored. Assumes the presence of "
288 | "openai_api_key.txt and openai_organization_id.txt files.",
289 | )
290 | args = parser.parse_args()
291 | return args
292 |
293 |
294 | def main(args):
295 | argsdict = vars(args)
296 | print(pprint.pformat(argsdict))
297 | completions = generate_code_for_problems(args)
298 | output_filepath = os.path.join(
299 | args.output_dir,
300 | f"samples_{args.split}_{args.arch}_{args.num_shots}shot_temp{args.temperature}_{args.start}-{args.end}{args.output_file_suffix}.jsonl",
301 | )
302 | write_jsonl(completions, output_filepath)
303 |
304 |
305 | if __name__ == "__main__":
306 | main(parse_args())
307 |
--------------------------------------------------------------------------------
/generate_refinements_codegen_finetuned.py:
--------------------------------------------------------------------------------
1 | from tqdm import tqdm
2 | from datasets import Dataset, load_dataset, concatenate_datasets
3 | from jaxformer.hf.codegen import modeling_codegen
4 | from jaxformer.hf import sample
5 | import torch
6 | import pprint
7 | import os
8 | import logging
9 | import json
10 | import csv
11 | import argparse
12 | import re
13 |
14 |
15 | def load_jsonl(filepath):
16 | data = [json.loads(line) for line in open(filepath).readlines()]
17 | fields = data[0].keys()
18 | data_dict = {k: [x[k] for x in data] for k in fields}
19 | ds = Dataset.from_dict(data_dict)
20 | return ds
21 |
22 |
23 | def load_csv(filepath):
24 | data = list(csv.DictReader(open(filepath)))
25 | fields = data[0].keys()
26 | data_dict = {k: [x[k] for x in data] for k in fields}
27 | ds = Dataset.from_dict(data_dict)
28 | return ds
29 |
30 |
31 | def load_feedback(feedback_path):
32 | extension = "csv" if feedback_path.endswith("csv") else "json"
33 | if extension == "json":
34 | d = load_jsonl(feedback_path)
35 | else:
36 | d = load_csv(feedback_path)
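 |     # Record each row's original index, then keep only rows that actually contain a refinement.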
37 | d = d.map(
38 | lambda _, idx: {"row_id": idx},
39 | with_indices=True,
40 | )
41 | d = d.filter(
42 | lambda example: example["Refinement"] is not None and example["Refinement"]
43 | )
44 | return d
45 |
46 |
47 | def sample_code_from_codegen(args, prompt, model, tokenizer):
48 | device = "cuda:0"
49 | completions = []
50 | print(f"Tokenizing input: {prompt}")
51 | input_ids = tokenizer(
52 | prompt, truncation=True, max_length=1024, return_tensors="pt"
53 | ).input_ids.cuda()
54 | if args.temperature == 0.0:
55 | args.num_samples = 1
56 | for i in range(args.num_samples):
57 | try:
58 |             # Note: max_length passed to generate() is the total sequence length, i.e. the input IDs plus up to 1024 newly generated tokens
59 | if args.temperature > 0:
60 | tokens = model.generate(
61 | input_ids,
62 | do_sample=True,
63 | num_return_sequences=1,
64 | max_length=input_ids.shape[1] + 1024,
65 | temperature=args.temperature,
66 | use_cache=True,
67 | )
68 | else:
69 | tokens = model.generate(
70 | input_ids,
71 | num_return_sequences=1,
72 | max_length=input_ids.shape[1] + 1024,
73 | use_cache=True,
74 | )
75 | text = tokenizer.decode(tokens[0])
76 | if "<|endoftext|>" in text:
77 | text = text[: text.find("<|endoftext|>")]
78 | completions.append(text)
79 | except RuntimeError as e:
80 | logging.error(f"Could not sample from model: {e}")
81 | return completions
82 |
83 |
84 | def truncate(ex, tokenizer, max_length):
85 | return tokenizer.decode(
86 | tokenizer(ex, max_length=max_length, truncation=True).input_ids
87 | )
88 |
89 |
90 | def format_mbpp_prompt(mbpp, task_id):
91 | idx = mbpp["task_id"].index(task_id)
92 | text = mbpp["text"][idx]
93 | tests = mbpp["test_list"][idx]
94 | sample_code = mbpp["code"][idx]
95 |
96 | # Create prompt from scratch
97 | prompt = f'"""\n{text}\n\n'
98 | # Add the first unit test as an input-output example
99 | example = tests[0].split("assert ")[-1].replace("==", "=")
100 | prompt += f">>> Example: {example}\n"
101 |
102 | # Add code prefix
103 | fn_name = tests[0].split("assert ")[-1].split("(")[0]
104 | fn_search = re.search(f"def {fn_name}\(.*\):", sample_code)
105 | if fn_search is None:
106 | raise ValueError(
107 | f"Could not find 'def {fn_name}\(.*\):' in code for task {task_id}."
108 | )
109 | code_prefix = sample_code[: fn_search.end()]
110 | prompt = f'{prompt}"""\n\n{code_prefix}\n'
111 | return prompt
112 |
113 |
114 | def gen_refinement_prompt(args, example, tokenizer, mbpp):
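 |     # Prompt layout (mirrors the format used when fine-tuning the refinement model):
 |     #   OLD CODE:\n<previous completion, truncated>\n\nFEEDBACK:\n<feedback>\n\nREFINEMENT:\n<MBPP-style prompt>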
115 | prompt = (
116 | f"OLD CODE:\n{truncate(example[args.completion_column], tokenizer, 512)}"
117 | f"\n\nFEEDBACK:\n{example['Feedback']}\n\n"
118 | f"REFINEMENT:\n{format_mbpp_prompt(mbpp, example['task_id'])}"
119 | )
120 | return prompt
121 |
122 |
123 | def gen_code(args, data, model, tokenizer):
124 | mbpp = load_dataset("mbpp")
125 | mbpp = concatenate_datasets([mbpp[k] for k in mbpp.keys()])
126 | output = data.map(
127 | lambda ex: {"input_str": gen_refinement_prompt(args, ex, tokenizer, mbpp)}
128 | )
129 | output = output.map(
130 | lambda ex: {
131 | "output_strs": sample_code_from_codegen(
132 | args, ex["input_str"], model, tokenizer
133 | )
134 | },
135 | desc="Sampling code from codegen...",
136 | )
137 | return output
138 |
139 |
140 | def generate_code_for_problems(args):
141 | data = load_feedback(args.feedback_file)
142 |
143 | if args.model_path is None:
144 | model = modeling_codegen.CodeGenForCausalLM.from_pretrained(
145 | f"{args.codegen_model_dir}/{args.arch}-mono",
146 | revision="float16",
147 | torch_dtype=torch.float16,
148 | low_cpu_mem_usage=True,
149 | ).cuda()
150 | else:
151 | model = modeling_codegen.CodeGenForCausalLM.from_pretrained(
152 | args.model_path, low_cpu_mem_usage=True, torch_dtype=torch.float32
153 | ).cuda()
154 | tokenizer = sample.create_custom_gpt2_tokenizer()
155 | tokenizer.pad_token = 50256
156 | val = gen_code(args, data, model, tokenizer)
157 |
158 | output = []
159 | for row in tqdm(val):
160 | for completion in row["output_strs"]:
161 | output.append(
162 | {
163 | "task_id": row["task_id"],
164 | "prompt": row["input_str"],
165 | "feedback": row["Feedback"],
166 | "old_completion": row[args.completion_column],
167 | "completion": completion,
168 | }
169 | )
170 | return output
171 |
172 |
173 | def write_jsonl(data, output_filepath):
174 | with open(output_filepath, "w") as f:
175 | for row in data:
176 | f.write(json.dumps(row) + "\n")
177 |
178 |
179 | def parse_args():
180 | parser = argparse.ArgumentParser(
181 | description="Run a trained model to generate Python code for the MBPP benchmark."
182 | )
183 | parser.add_argument(
184 | "--arch", default="codegen-6B", choices=["codegen-16B", "codegen-6B"]
185 | )
186 | parser.add_argument(
187 | "--codegen-model-dir",
188 | default="checkpoints",
189 | help="Directory where pre-trained CodeGen model checkpoints are saved.",
190 | )
191 | parser.add_argument(
192 | "--model-path",
193 | default=None,
194 | required=True,
195 | help="Directory to load model checkpoint from. If None, will load a pre-trained "
196 | "CodeGen model using the --arch argument instead.",
197 | )
198 | parser.add_argument("--num-samples", default=1, type=int)
199 | parser.add_argument("-d", "--debug", action="store_true")
200 | parser.add_argument("--output-dir", type=str)
201 | parser.add_argument("--output-file-suffix", type=str, default="")
202 | parser.add_argument("--temperature", default=0.8, type=float)
203 | parser.add_argument(
204 | "--feedback-file",
205 | default=None,
206 | required=True,
207 | help="CSV file containing feedback and past completions.",
208 | )
209 | parser.add_argument("--completion-column", default="completion")
210 | args = parser.parse_args()
211 | return args
212 |
213 |
214 | def main(args):
215 | argsdict = vars(args)
216 | print(pprint.pformat(argsdict))
217 | completions = generate_code_for_problems(args)
218 |
219 | if args.model_path is None:
220 | output_filepath = os.path.join(
221 | args.output_dir,
222 | f"refinements_{args.arch}_temp{args.temperature}_{args.output_file_suffix}.jsonl",
223 | )
224 | else:
225 | output_filepath = os.path.join(
226 | args.model_path,
227 | f"refinements_{args.arch}_temp{args.temperature}_{args.output_file_suffix}.jsonl",
228 | )
229 | write_jsonl(completions, output_filepath)
230 |
231 |
232 | if __name__ == "__main__":
233 | main(parse_args())
234 |
--------------------------------------------------------------------------------
/ilf_for_code_gen.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nyu-mll/ILF-for-code-generation/1bbccca2934b26e2d8745e5afab65eb677cbe92a/ilf_for_code_gen.pdf
--------------------------------------------------------------------------------
/ilf_pipeline.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Assumes that the Codegen checkpoints are stored in a directory
4 | # named "checkpoints" that is a subdirectory of the current directory.
5 |
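 | # Example invocation (illustrative; the experiment name and results directory are placeholders):
 | #   bash ilf_pipeline.sh -i surge_annotations.jsonl -n my_ilf_run -d results
 |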
6 | FEEDBACK_COLUMN="Feedback"
7 | REFINEMENTS_COLUMN="Refinement"
8 | INPUT_FILE="surge_annotations.jsonl"
9 | LEARNING_RATE=5e-6
10 | GRADIENT_ACCUMULATION_STEPS=32
11 | NUM_OUTPUT_SAMPLES=30
12 | NUM_EPOCHS=2
13 | while getopts "i:f:r:n:l:g:o:e:d:" option; do
14 | case $option in
15 |     i) # File containing all Surge annotations. Feedback is in the column named via option -f, and refinements are under "unedited_annotator_completion".
16 | INPUT_FILE=$OPTARG;;
17 | f) # Name of feedback column
18 | FEEDBACK_COLUMN=$OPTARG;;
19 | r) # Name of refinements column that will be created in all intermediate outputs.
20 | REFINEMENTS_COLUMN=$OPTARG;;
21 | n) # Experiment name
22 | EXP_NAME=$OPTARG;;
23 | l) # Learning rate
24 | LEARNING_RATE=$OPTARG;;
25 | g) # Gradient accumulation steps. Determines effective batch size because the per-device train batch size is 1
26 | GRADIENT_ACCUMULATION_STEPS=$OPTARG;;
27 | o) # Number of final MBPP samples to output from the final fine-tuned CodeGen-6B model.
28 | NUM_OUTPUT_SAMPLES=$OPTARG;;
29 | e) # Number of epochs to train for.
30 | NUM_EPOCHS=$OPTARG;;
31 | d) # Parent directory to save results in. Experiment results will be saved in a subdirectory of this directory named ${EXP_NAME}.
32 | PARENT_DIR=$OPTARG;;
33 | \?) # Invalid option
34 | echo "Error: Invalid option ${option}"
35 | exit;;
36 | esac
37 | done
38 |
39 | TRAIN_START_TASK_ID=111
40 | TRAIN_END_TASK_ID=310 # inclusive
41 | TRAIN_N=$(( $TRAIN_END_TASK_ID - $TRAIN_START_TASK_ID + 1 ))
42 | VAL_START_TASK_ID=311
43 | VAL_END_TASK_ID=974 # inclusive
44 | VAL_N=$(( $VAL_END_TASK_ID - $VAL_START_TASK_ID + 1 ))
45 | TEST_START_TASK_ID=11
46 | TEST_END_TASK_ID=111 # exclusive
47 |
48 | CONDA_ENV="ilf"
49 | EXPERIMENT_DIR="${PARENT_DIR}/${EXP_NAME}"
50 |
51 | echo "Running with arguments -i=${INPUT_FILE}, -f=${FEEDBACK_COLUMN}, -r=${REFINEMENTS_COLUMN}," \
52 | "-n=${EXP_NAME}, -l=${LEARNING_RATE}, -g=${GRADIENT_ACCUMULATION_STEPS}, -o=${NUM_OUTPUT_SAMPLES}," \
53 | "-e=${NUM_EPOCHS}, -d=${PARENT_DIR}."
54 | echo "Outputting experiment results in ${EXPERIMENT_DIR}."
55 |
56 | conda deactivate
57 | conda activate ${CONDA_ENV}
58 |
59 | mkdir -p ${EXPERIMENT_DIR}
60 | python preprocess_feedback_spreadsheet.py --input_file=${INPUT_FILE} \
61 | --model_completion_column=original_model_completion \
62 | --old_refinement_column=unedited_annotator_completion \
63 | --training_n=$TRAIN_N --val_n=$VAL_N \
64 | --feedback_column=${FEEDBACK_COLUMN} --refinement_column=${REFINEMENTS_COLUMN} \
65 | --one_per_task --filter_for_correct --output_dir=${EXPERIMENT_DIR} \
66 | --training_start_id=${TRAIN_START_TASK_ID} --training_end_id=${TRAIN_END_TASK_ID} \
67 | --val_start_id=${VAL_START_TASK_ID} --val_end_id=${VAL_END_TASK_ID} || exit
68 | OUTPUT_FILE_PREFIX=$(python -c "print(''.join('${INPUT_FILE}'.split('.')[:-1]).split('/')[-1])")
69 | OUTPUT_FILE_PREFIX=${EXPERIMENT_DIR}/${OUTPUT_FILE_PREFIX}
70 | REF_TRAINING_FILE="${OUTPUT_FILE_PREFIX}-train.jsonl"
71 | REF_VAL_FILE="${OUTPUT_FILE_PREFIX}-val.jsonl"
72 |
73 | echo "Training data for Pi_Ref: ${REF_TRAINING_FILE}"
74 | echo "Val data for Pi_Ref: ${REF_VAL_FILE}"
75 |
76 | # Fine-tune a model to generate refinements.
77 | # We trained with per-device batch size of 1 due to computational constraints
78 | # (but used gradient accumulation to reach the desired effective batch size).
79 | PI_REF_DIR="${EXPERIMENT_DIR}/mref_lr${LEARNING_RATE}_ga${GRADIENT_ACCUMULATION_STEPS}_${NUM_EPOCHS}epochs"
80 | CHECKPOINTS_DIR="$(pwd)/checkpoints"
81 | python finetune_refinement_model.py \
82 | --do_train \
83 | --codegen_model_dir=${CHECKPOINTS_DIR} \
84 | --model_name_or_path=codegen-6B \
85 | --num_train_epochs=${NUM_EPOCHS} \
86 | --save_strategy=no \
87 | --learning_rate=${LEARNING_RATE} \
88 | --per_device_train_batch_size=1 \
89 | --gradient_accumulation_steps=${GRADIENT_ACCUMULATION_STEPS} \
90 | --logging_steps=1 \
91 | --output_dir ${PI_REF_DIR} \
92 | --pad_to_max_length \
93 | --generation_max_length=512 \
94 | --max_seq_length=1024 \
95 | --max_answer_length=512 \
96 | --parallelize \
97 | --overwrite_output_dir \
98 | --save_total_limit=2 \
99 | --feedback_column=${FEEDBACK_COLUMN} \
100 | --refinement_column=${REFINEMENTS_COLUMN} \
101 | --train_file ${REF_TRAINING_FILE} || exit
102 |
103 | # Generate refinements using Pi_Ref
104 | python generate_refinements_codegen_finetuned.py \
105 | --arch=codegen-6B \
106 | --codegen-model-dir=${CHECKPOINTS_DIR} \
107 | --num-samples=${NUM_OUTPUT_SAMPLES} --output-dir=${PI_REF_DIR} \
108 | --temperature=0.8 --feedback-file=${REF_VAL_FILE} \
109 | --output-file-suffix=${EXP_NAME} \
110 | --model-path=${PI_REF_DIR} || exit
111 |
112 | # Evaluate refinements generated for tasks in MBPP_Train, and
113 | # keep only the correct ones for training Pi_Theta
114 | python eval_mbpp.py \
115 | --input-file=${PI_REF_DIR}/refinements_codegen-6B_temp0.8_${EXP_NAME}.jsonl \
116 | --k=1,10 || exit
117 | python create_finetuning_data_from_refinements.py \
118 | --one-per-task \
119 | --refinement-file=${PI_REF_DIR}/refinements_codegen-6B_temp0.8_${EXP_NAME}.jsonl_results.jsonl \
120 | --output-dir=${PI_REF_DIR} \
121 | --output-file-suffix=surge_final || exit
122 |
123 | # Fine-tune two separate models:
124 | # 1) fine-tuned on MBPP gold data,
125 | # 2) fine-tuned on Pi_Refine-generated refinements
126 | TRAINING_FILE="${PI_REF_DIR}/finetuning_prompts_mbpp_refinements_surge_final.jsonl"
127 | GOLD_TRAINING_FILE="${PI_REF_DIR}/finetuning_prompts_mbpp_gold_surge_final.jsonl"
128 | # Fine-tune (1)
129 | FINAL_GOLD_FINETUNE_DIR=${EXPERIMENT_DIR}/final_gold_finetune_lr${LEARNING_RATE}_ga${GRADIENT_ACCUMULATION_STEPS}_${NUM_EPOCHS}epochs
130 | python finetune.py \
131 | --codegen_repo=${CHECKPOINTS_DIR} \
132 | --do_train \
133 | --model_name_or_path=codegen-6B \
134 | --save_strategy=no \
135 | --num_train_epochs=${NUM_EPOCHS} \
136 | --learning_rate=${LEARNING_RATE} \
137 | --per_device_train_batch_size=1 \
138 | --gradient_accumulation_steps=${GRADIENT_ACCUMULATION_STEPS} \
139 | --logging_steps=1 \
140 | --output_dir ${FINAL_GOLD_FINETUNE_DIR} \
141 | --parallelize \
142 | --pad_to_max_length \
143 | --generation_max_length=512 \
144 | --max_seq_length=1024 \
145 | --max_answer_length=512 \
146 | --save_total_limit=2 \
147 | --parallelize \
148 | --prompt_column=finetuning_prompt \
149 | --completion_column=finetuning_completion \
150 | --overwrite_output_dir \
151 | --train_file ${GOLD_TRAINING_FILE} || exit
152 | # Fine-tune (2)
153 | FINAL_FINETUNE_DIR=${EXPERIMENT_DIR}/final_finetune_lr${LEARNING_RATE}_ga${GRADIENT_ACCUMULATION_STEPS}_${NUM_EPOCHS}epochs
154 | python finetune.py \
155 | --codegen_repo=${CHECKPOINTS_DIR} \
156 | --do_train \
157 | --model_name_or_path=codegen-6B \
158 | --save_strategy=no \
159 | --num_train_epochs=${NUM_EPOCHS} \
160 | --learning_rate=${LEARNING_RATE} \
161 | --per_device_train_batch_size=1 \
162 | --gradient_accumulation_steps=${GRADIENT_ACCUMULATION_STEPS} \
163 | --logging_steps=1 \
164 | --output_dir ${FINAL_FINETUNE_DIR} \
165 | --parallelize \
166 | --pad_to_max_length \
167 | --generation_max_length=512 \
168 | --max_seq_length=1024 \
169 | --max_answer_length=512 \
170 | --save_total_limit=2 \
171 | --parallelize \
172 | --prompt_column=finetuning_prompt \
173 | --completion_column=finetuning_completion \
174 | --overwrite_output_dir \
175 | --train_file ${TRAINING_FILE} || exit
176 |
177 | # Evaluate models (1) and (2) on MBPP_Test
178 | ## First generate programs for MBPP_Test
179 | python generate_code_for_mbpp.py \
180 | --codegen-model-dir=${CHECKPOINTS_DIR} \
181 | --num-samples=${NUM_OUTPUT_SAMPLES} \
182 | --output-dir=${FINAL_GOLD_FINETUNE_DIR} \
183 | --arch=codegen-6B \
184 | -n=1 \
185 | --temperature=0.8 \
186 | --debug -s ${TEST_START_TASK_ID} -e ${TEST_END_TASK_ID} \
187 | --model-path=${FINAL_GOLD_FINETUNE_DIR} || exit
188 | python generate_code_for_mbpp.py \
189 | --codegen-model-dir=${CHECKPOINTS_DIR} \
190 | --num-samples=${NUM_OUTPUT_SAMPLES} \
191 | --output-dir=${FINAL_FINETUNE_DIR} \
192 | --arch=codegen-6B \
193 | -n=1 \
194 | --temperature=0.8 \
195 | --debug -s ${TEST_START_TASK_ID} -e ${TEST_END_TASK_ID} \
196 | --model-path=${FINAL_FINETUNE_DIR} || exit
197 | ## Now evaluate final generations
198 | python eval_mbpp.py \
199 | --input-file=${FINAL_GOLD_FINETUNE_DIR}/samples_test_codegen-6B_1shot_temp0.8_${TEST_START_TASK_ID}-${TEST_END_TASK_ID}.jsonl \
200 | --k=1,10 || exit
201 | python eval_mbpp.py \
202 | --input-file=${FINAL_FINETUNE_DIR}/samples_test_codegen-6B_1shot_temp0.8_${TEST_START_TASK_ID}-${TEST_END_TASK_ID}.jsonl \
203 | --k=1,10 || exit
--------------------------------------------------------------------------------
/preprocess_feedback_spreadsheet.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from datasets import Dataset, load_dataset
3 |
4 |
5 | def group_by_and_select_one(ds, group_by_col):
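 |     # Shuffle first so that groupby(...).first() keeps one randomly chosen row per group.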
6 | df = ds.shuffle().to_pandas()
7 | df = df.groupby(group_by_col).first()
8 | ds = Dataset.from_pandas(df)
9 | return ds
10 |
11 |
12 | def truncate_completion(src):
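 |     # Keep only the text after the last "Refinement:\n" marker, if one is present.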
13 | ref_str = "Refinement:\n"
14 | if ref_str in src:
15 | src = src[src.rfind(ref_str) + len(ref_str) :]
16 | return src
17 |
18 |
19 | def preprocess_data(args):
20 | orig_ext = args.input_file.split(".")[-1]
21 | if orig_ext not in ["csv", "json", "jsonl"]:
22 |         raise ValueError(f"{orig_ext} is not a supported file extension.")
23 | if orig_ext == "jsonl":
24 | ext = "json"
25 | else:
26 | ext = orig_ext
27 | d = load_dataset(ext, data_files={"train": args.input_file})["train"].filter(
28 | lambda ex: ex[args.feedback_column] is not None and ex[args.feedback_column]
29 | )
30 |
31 | if args.old_refinement_column is not None:
32 | d = d.map(
33 | lambda ex: {args.refinement_column: ex[args.old_refinement_column]},
34 | remove_columns=[args.old_refinement_column],
35 | )
36 |
37 | d = d.map(
38 | lambda ex: {"completion": ex[args.model_completion_column]},
39 | )
40 |
41 | d = d.filter(
42 | lambda ex: ex[args.refinement_column] is not None and ex[args.refinement_column]
43 | ).map(
44 | lambda ex: {
45 | args.refinement_column: truncate_completion(ex[args.refinement_column])
46 | }
47 | )
48 |
49 | if args.filter_for_correct and "passed" in d.column_names:
50 | # Filter for correct ones only, if the column exists in the spreadsheet
51 | d = d.filter(lambda ex: ex["passed"])
52 |
53 | if args.one_per_task:
54 | # Filter for just one sample per task ID.
55 | d = group_by_and_select_one(d, args.id_col)
56 |
57 | # Split data and print out filenames
58 | output_file_prefix = ".".join(args.input_file.split(".")[:-1])
59 | if args.output_dir is not None:
60 | fname_prefix = output_file_prefix.split("/")[-1]
61 | output_file_prefix = f"{args.output_dir}/{fname_prefix}"
62 |
63 | df = d.to_pandas().set_index(args.id_col)
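 |     # Select training and validation examples by task-ID range, then randomly subsample to the requested sizes.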
64 | train_df = df[
65 | (df.index >= args.training_start_id) & (df.index <= args.training_end_id)
66 | ]
67 | train_n = min(len(train_df), args.training_n)
68 | train_df = train_df.sample(n=train_n)
69 | train_output_filepath = f"{output_file_prefix}-train.jsonl"
70 | train_df.reset_index().to_json(train_output_filepath, orient="records", lines=True)
71 | val_df = df[(df.index >= args.val_start_id) & (df.index <= args.val_end_id)]
72 | val_n = min(len(val_df), args.val_n)
73 | val_df = val_df.sample(n=val_n)
74 | val_output_filepath = f"{output_file_prefix}-val.jsonl"
75 | val_df.reset_index().to_json(val_output_filepath, orient="records", lines=True)
76 | print("\n".join([train_output_filepath, val_output_filepath]))
77 |
78 |
79 | def parse_args():
80 | parser = argparse.ArgumentParser(
81 | description="Filter and pre-process CSV or JSONL input file containing feedback and refinements."
82 | )
83 | parser.add_argument(
84 | "--input_file",
85 | default="",
86 | required=True,
87 | help="Input CSV or JSONL file containing feedback and refinements.",
88 | )
89 | parser.add_argument(
90 | "--feedback_column", default="Feedback", help="Name of feedback column."
91 | )
92 | parser.add_argument(
93 | "--old_refinement_column",
94 | default=None,
95 |         help="If set, the column with this name will be renamed to --refinement_column.",
96 | )
97 | parser.add_argument(
98 | "--refinement_column", default="Refinement", help="Name of refinement column."
99 | )
100 | parser.add_argument(
101 | "--model_completion_column", default="original_model_completion"
102 | )
103 | parser.add_argument(
104 | "--training_n",
105 | default=None,
106 | type=int,
107 | help="Number of examples to be used for training data. If None, does not split data into train/val.",
108 | )
109 | parser.add_argument(
110 | "--val_n",
111 | default=None,
112 | type=int,
113 | help="Number of examples to be used for validation data. If None, just uses all non-training examples as validation data.",
114 | )
115 | parser.add_argument(
116 | "--id_col",
117 | type=str,
118 | default="task_id",
119 | help="Which column to index on and to split data by.",
120 | )
121 | parser.add_argument(
122 | "--one_per_task",
123 | action="store_true",
124 | help="If set, then will filter only one sample per task.",
125 | )
126 | parser.add_argument(
127 | "--filter_for_correct",
128 | action="store_true",
129 | help="Filter for only the rows for which passed=True. "
130 |         + "(You may want to leave this off for feedback spreadsheets where the 'passed' column corresponds to the original model completion rather than the refinement.)",
131 | )
132 | parser.add_argument(
133 | "--training_start_id",
134 | type=int,
135 | default=601,
136 | )
137 | parser.add_argument("--training_end_id", type=int, default=974)
138 | parser.add_argument("--val_start_id", type=int, default=511)
139 | parser.add_argument("--val_end_id", type=int, default=600)
140 | parser.add_argument(
141 | "--output_dir",
142 | type=str,
143 | default=None,
144 | help="Output directory. If None, outputs to the same directory that the input file is already in.",
145 | )
146 | args = parser.parse_args()
147 |
148 | # if training_n is set, then val_n must also be set.
149 | assert (args.training_n is None) or (
150 | args.val_n is not None
151 | ), "Error: if --training_n is set, then --val_n must also be set."
152 | return args
153 |
154 |
155 | def main(args):
156 | argsdict = vars(args)
157 | preprocess_data(args)
158 |
159 |
160 | if __name__ == "__main__":
161 | main(parse_args())
162 |
--------------------------------------------------------------------------------