├── .gitattributes ├── fig └── demo.png ├── requirements.txt ├── edit_eval ├── templates │ ├── plain_completion.json │ ├── alpaca_short.json │ ├── alpaca.json │ ├── alpaca_legacy.json │ ├── alpaca_infer.json │ └── README.md ├── evaluate_functional_correctness.py ├── utils.py ├── data.py ├── prompter.py ├── evaluation.py ├── validate_dataset.py └── execution.py ├── LICENSE ├── src ├── template │ └── prompt_template.py ├── postprocess.py ├── generate_instance.py ├── generate_scenarios.py ├── request_helper.py └── bootstrap_instructions.py └── README.md /.gitattributes: -------------------------------------------------------------------------------- 1 | *.json filter=lfs diff=lfs merge=lfs -text 2 | -------------------------------------------------------------------------------- /fig/demo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qishenghu/InstructCoder/HEAD/fig/demo.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | datasketch==1.5.9 2 | numpy==1.23.2 3 | openai==0.27.0 4 | pandas==1.5.3 5 | rouge_score==0.1.2 6 | tenacity==8.2.2 7 | tiktoken==0.3.3 8 | tqdm==4.64.1 9 | -------------------------------------------------------------------------------- /edit_eval/templates/plain_completion.json: -------------------------------------------------------------------------------- 1 | { 2 | "description": "Plain completion template", 3 | "prompt_input": "{instruction}{input}", 4 | "prompt_no_input": "{instruction}", 5 | "response_split": "{instruction}" 6 | } 7 | -------------------------------------------------------------------------------- /edit_eval/templates/alpaca_short.json: -------------------------------------------------------------------------------- 1 | { 2 | "description": "A shorter template to experiment with.", 3 | "prompt_input": "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n", 4 | "prompt_no_input": "### Instruction:\n{instruction}\n\n### Response:\n", 5 | "response_split": "### Response:" 6 | } 7 | -------------------------------------------------------------------------------- /edit_eval/templates/alpaca.json: -------------------------------------------------------------------------------- 1 | { 2 | "description": "Template used by Alpaca-LoRA.", 3 | "prompt_input": "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n", 4 | "prompt_no_input": "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Response:\n", 5 | "response_split": "### Response:" 6 | } 7 | -------------------------------------------------------------------------------- /edit_eval/templates/alpaca_legacy.json: -------------------------------------------------------------------------------- 1 | { 2 | "description": "Legacy template, used by Original Alpaca repository.", 3 | "prompt_input": "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:", 4 | "prompt_no_input": "Below is an instruction that describes a task. 
Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Response:",
5 |     "response_split": "### Response:"
6 | }
7 | 
--------------------------------------------------------------------------------
/edit_eval/templates/alpaca_infer.json:
--------------------------------------------------------------------------------
1 | {
2 |     "description": "Template for inferring code edit instructions from GitHub commit messages.",
3 |     "prompt_input": "Below is a commit message from GitHub. Please infer the code editing instruction from the message. It is paired with an input that provides the code before the commit as further context. Write a response that appropriately completes the request. If you are not sure how to modify the code, just output it directly.\n\n### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n",
4 |     "prompt_no_input": "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Response:\n",
5 |     "response_split": "### Response:"
6 | }
7 | 
--------------------------------------------------------------------------------
/edit_eval/evaluate_functional_correctness.py:
--------------------------------------------------------------------------------
1 | import fire
2 | import sys
3 | 
4 | from evaluation import evaluate_functional_correctness
5 | 
6 | 
7 | def entry_point(
8 |     sample_file: str,
9 |     k: str = "1,10,100",
10 |     n_workers: int = 4,
11 |     timeout: float = 3.0,
12 |     problem_file: str = "edit_eval.jsonl",
13 | ):
14 |     """
15 |     Evaluates the functional correctness of generated samples, and writes
16 |     results to f"{sample_file}_results.jsonl"
17 |     """
18 |     k = list(map(int, k.split(",")))
19 |     results = evaluate_functional_correctness(sample_file, k, n_workers, timeout, problem_file)
20 |     print(results)
21 | 
22 | 
23 | def main():
24 |     fire.Fire(entry_point)
25 | 
26 | 
27 | sys.exit(main())
28 | 
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2023 qshu
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | -------------------------------------------------------------------------------- /edit_eval/utils.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import colorama 3 | from colorama import Fore, Style 4 | 5 | 6 | 7 | class ColoredLoggingFormatter(logging.Formatter): 8 | FORMATS = { 9 | logging.DEBUG: Fore.LIGHTBLACK_EX + "{format}" + Style.RESET_ALL, 10 | logging.INFO: Fore.LIGHTBLACK_EX + "{format}" + Style.RESET_ALL, 11 | logging.WARNING: Fore.YELLOW + "{format}" + Style.RESET_ALL, 12 | logging.ERROR: Fore.RED + "{format}" + Style.RESET_ALL, 13 | logging.CRITICAL: Fore.LIGHTRED_EX + "{format}" + Style.RESET_ALL, 14 | } 15 | 16 | def __init__(self, 17 | fmt="%(asctime)s - %(name)s - %(levelname)s - %(message)s" # (%(filename)s:%(lineno)d) 18 | ): 19 | # Initialize Colorama to work on Windows as well 20 | colorama.init() 21 | self.formatters = dict() 22 | for levelno in self.FORMATS.keys(): 23 | log_fmt = self.FORMATS.get(levelno).format(format=fmt) 24 | self.formatters[levelno] = logging.Formatter(log_fmt) 25 | 26 | 27 | def format(self, record): 28 | return self.formatters.get(record.levelno).format(record) 29 | 30 | 31 | # The convention is to call this at the top of submodule py files. 32 | def init_logger(name): 33 | # Create a custom logger 34 | logger = logging.getLogger(name) 35 | logger.setLevel(logging.DEBUG) 36 | 37 | # Create a console handler and set the custom formatter 38 | console_handler = logging.StreamHandler() 39 | console_handler.setFormatter(ColoredLoggingFormatter()) 40 | 41 | # Add the console handler to the custom logger 42 | logger.addHandler(console_handler) 43 | 44 | return logger -------------------------------------------------------------------------------- /edit_eval/data.py: -------------------------------------------------------------------------------- 1 | from typing import Iterable, Dict 2 | import gzip 3 | import json 4 | import os 5 | 6 | 7 | ROOT = os.path.dirname(os.path.abspath(__file__)) 8 | HUMAN_EVAL = os.path.join(ROOT, "..", "data", "HumanEval.jsonl.gz") 9 | 10 | 11 | def read_problems(evalset_file: str = HUMAN_EVAL) -> Dict[str, Dict]: 12 | return {task["task_id"]: task for task in stream_jsonl(evalset_file)} 13 | 14 | 15 | def stream_jsonl(filename: str) -> Iterable[Dict]: 16 | """ 17 | Parses each jsonl line and yields it as a dictionary 18 | """ 19 | if filename.endswith(".gz"): 20 | with open(filename, "rb") as gzfp: 21 | with gzip.open(gzfp, 'rt') as fp: 22 | for line in fp: 23 | if any(not x.isspace() for x in line): 24 | yield json.loads(line) 25 | else: 26 | with open(filename, "r") as fp: 27 | for line in fp: 28 | if any(not x.isspace() for x in line): 29 | yield json.loads(line) 30 | 31 | 32 | def write_jsonl(filename: str, data: Iterable[Dict], append: bool = False): 33 | """ 34 | Writes an iterable of dictionaries to jsonl 35 | """ 36 | if append: 37 | mode = 'ab' 38 | else: 39 | mode = 'wb' 40 | filename = os.path.expanduser(filename) 41 | if filename.endswith(".gz"): 42 | with open(filename, mode) as fp: 43 | with gzip.GzipFile(fileobj=fp, mode='wb') as gzfp: 44 | for x in data: 45 | gzfp.write((json.dumps(x) + "\n").encode('utf-8')) 46 | else: 47 | with open(filename, mode) as fp: 48 | for x in data: 49 | fp.write((json.dumps(x) + "\n").encode('utf-8')) 50 | -------------------------------------------------------------------------------- /edit_eval/prompter.py: -------------------------------------------------------------------------------- 1 | """ 2 | A dedicated 
helper to manage templates and prompt building. 3 | """ 4 | import os 5 | import json 6 | import os.path as osp 7 | from typing import Union 8 | 9 | 10 | class Prompter(object): 11 | __slots__ = ("template", "_verbose") 12 | 13 | def __init__(self, template_name: str = "", verbose: bool = False): 14 | self._verbose = verbose 15 | if not template_name: 16 | # Enforce the default here, so the constructor can be called with '' and will not break. 17 | template_name = "alpaca" 18 | 19 | root_directory = os.path.dirname(os.path.abspath(__file__)) 20 | file_name = osp.join(root_directory, "templates", f"{template_name}.json") 21 | if not osp.exists(file_name): 22 | raise ValueError(f"Can't read {file_name}") 23 | with open(file_name) as fp: 24 | self.template = json.load(fp) 25 | if self._verbose: 26 | print( 27 | f"Using prompt template {template_name}: {self.template['description']}" 28 | ) 29 | 30 | def generate_prompt( 31 | self, 32 | instruction: str, 33 | input: Union[None, str] = None, 34 | label: Union[None, str] = None, 35 | ) -> str: 36 | # returns the full prompt from instruction and optional input 37 | # if a label (=response, =output) is provided, it's also appended. 38 | if input: 39 | res = self.template["prompt_input"].format( 40 | instruction=instruction, input=input 41 | ) 42 | else: 43 | res = self.template["prompt_no_input"].format( 44 | instruction=instruction 45 | ) 46 | if label: 47 | res = f"{res}{label}" 48 | if self._verbose: 49 | print(res) 50 | return res 51 | 52 | def get_response(self, output: str) -> str: 53 | return output.split(self.template["response_split"])[1].strip() 54 | -------------------------------------------------------------------------------- /edit_eval/templates/README.md: -------------------------------------------------------------------------------- 1 | # Prompt templates 2 | 3 | This directory contains template styles for the prompts used to finetune LoRA models. 4 | 5 | ## Format 6 | 7 | A template is described via a JSON file with the following keys: 8 | 9 | - `prompt_input`: The template to use when input is not None. Uses `{instruction}` and `{input}` placeholders. 10 | - `prompt_no_input`: The template to use when input is None. Uses `{instruction}` placeholders. 11 | - `description`: A short description of the template, with possible use cases. 12 | - `response_split`: The text to use as separator when cutting real response from the model output. 13 | 14 | No `{response}` placeholder was used, since the response is always the last element of the template and is just to be concatenated to the rest. 15 | 16 | ## Example template 17 | 18 | The default template, used unless otherwise specified, is `alpaca.json` 19 | 20 | ```json 21 | { 22 | "description": "Template used by Alpaca-LoRA.", 23 | "prompt_input": "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n", 24 | "prompt_no_input": "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Response:\n", 25 | "response_split": "### Response:" 26 | } 27 | 28 | ``` 29 | 30 | ## Current templates 31 | 32 | ### alpaca 33 | 34 | Default template used for generic LoRA fine tunes so far. 35 | 36 | ### alpaca_legacy 37 | 38 | Legacy template used by the original alpaca repo, with no `\n` after the response field. 
Kept for reference and experiments.
39 | 
40 | ### alpaca_short
41 | 
42 | A trimmed down alpaca template which seems to perform just as well and spare some tokens. Models created with the default template seem to be queryable by the short template as well. More experiments are welcome.
43 | 
44 | ### vigogne
45 | 
46 | The default alpaca template, translated to French. This template was used to train the "Vigogne" LoRA and is to be used to query it, or for extra fine tuning.
47 | 
--------------------------------------------------------------------------------
/src/template/prompt_template.py:
--------------------------------------------------------------------------------
1 | 
2 | # Prompt template for generating instructions
3 | INSTRUCTION_PROMPT = "Given the existing instructions, please generate a list of diverse python code editing instructions. The new instructions should address diverse editing tasks. Please ensure that the instructions are clear and diverse. Include any relevant variable name in the instruction."
4 | INSTRUCTION_PROMPT_WITH_TYPE = "Given the existing instructions, please generate a list of diverse python code editing instructions. The new instructions should address diverse editing tasks related to `{edit_task}`. Please ensure that the instructions are clear and diverse. Include any relevant variable name in the instruction."
5 | EDIT_TYPES = [
6 |     "Completing incomplete code",
7 |     "Optimizing code time complexity or space (memory) complexity",
8 |     "Fixing bugs",
9 |     "Adding comments or docstring",
10 |     "Creating unit tests",
11 |     "Refactoring code",
12 |     "Improving code readability",
13 |     "Implementing error handling",
14 |     "Enhancing performance",
15 |     "Updating library dependencies",
16 |     "Adapting code for accessibility",
17 |     "Adhering to coding standards",
18 |     "Ensuring code security",
19 |     "Modularizing code",
20 |     "Migrating to a new framework",
21 |     "Adding logging and monitoring",
22 |     "Implementing code linting",
23 |     "Internationalizing code",
24 |     "Removing dead code",
25 |     "Implementing caching mechanisms",
26 |     "Applying design patterns",
27 |     "Addressing memory leaks",
28 |     "Implementing multithreading",
29 |     "Reducing code duplication",
30 |     "Integrating APIs and services",
31 |     "Adding new features/functionality",
32 | ]
33 | 
34 | # Prompt template for generating scenarios
35 | SCENARIO_PROMPT = """Given a python code editing task, please come up with 10 diverse, concise scenario descriptions of where this python code editing task could be performed or come from.
36 | 
37 | {SHOT1}
38 | 
39 | 
40 | {SHOT2}
41 | """
42 | 
43 | 
44 | # Prompt template for generating the input/output code pair
45 | INSTANCE_PROMPT = "Given python code editing task instructions and their scenarios where the task instruction could be used, you need to come up with examples for the following code editing tasks. You need to generate an input and output code pair and make sure your variable names are suitable for the scenario. The input code is related to the task instruction, but must NOT meet the task requirements.
The output code fulfills the task requirements based on input code.\n\n" -------------------------------------------------------------------------------- /src/postprocess.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | import random 5 | import string 6 | 7 | from datasketch import MinHash, MinHashLSH 8 | 9 | 10 | def format_instance(res): 11 | input, output = '', '' 12 | if 'Input:' in res and 'Output:' in res: 13 | input = res.split('Output:')[0].split('Input:')[1].strip('\n') 14 | output = res.split('Output:')[1].strip('\n') 15 | if input.startswith(' \n'): 16 | input = input.lstrip(' \n') 17 | if output.startswith(' \n'): 18 | output = output.lstrip(' \n') 19 | return input, output 20 | 21 | 22 | def format_filter(input, output): 23 | failure_indicators = ['sorry', 'please provide', 'cannot', 'not able', 'already'] 24 | if input == "" or output == "": 25 | return False 26 | if input == output: 27 | return False 28 | if input.strip().endswith(":") or output.strip().endswith(":"): 29 | return False 30 | if any([phrase in input.lower() for phrase in failure_indicators]): 31 | return False 32 | if any([phrase in output.lower() for phrase in failure_indicators]): 33 | return False 34 | return True 35 | 36 | 37 | def load_file(path): 38 | files = [] 39 | if not os.path.exists(path): 40 | return files 41 | if path.endswith('.jsonl'): 42 | files = [json.loads(line) for line in open(path).readlines()] 43 | elif path.endswith('json'): 44 | files = json.loads(open(path).read()) 45 | else: 46 | raise NotImplementedError 47 | return files 48 | 49 | 50 | def filter_duplicate_instance(instances, threshold=0.75): 51 | unique_instances = [] 52 | lsh = MinHashLSH(threshold=threshold) 53 | for idx, instance in enumerate(instances): 54 | instance_string = instance['input'] 55 | instance_string = instance_string.translate(str.maketrans('', '', string.punctuation)) 56 | instance_string = instance_string.strip().split() 57 | mh = MinHash() 58 | for d in instance_string: 59 | mh.update(d.encode('utf8')) 60 | if not lsh.is_empty(): 61 | res = lsh.query(mh) 62 | if len(res) != 0: 63 | continue 64 | lsh.insert(str(idx), mh) 65 | unique_instances.append(instance) 66 | del lsh 67 | return unique_instances 68 | 69 | 70 | def parse_args(): 71 | parser = argparse.ArgumentParser() 72 | parser.add_argument( 73 | "--input_file", 74 | type=str, 75 | default="temp/machine_generated_instances.jsonl", 76 | ) 77 | parser.add_argument( 78 | "--output_file", 79 | type=str, 80 | default="temp/instances.jsonl", 81 | ) 82 | 83 | parser.add_argument 84 | return parser.parse_args() 85 | 86 | 87 | if __name__ == "__main__": 88 | args = parse_args() 89 | machine_generated_instances = load_file(args.input_file) 90 | print(f"Total #machine_generated_instances: {len(machine_generated_instances)}") 91 | formatted_instances = [] 92 | for instance in machine_generated_instances: 93 | instance['input'], instance['output'] = format_instance(instance['result']) 94 | if format_filter(instance['input'], instance['output']): 95 | formatted_instances.append(instance) 96 | print(f"Total #formatted_instances: {len(formatted_instances)}") 97 | 98 | unique_instances = filter_duplicate_instance(formatted_instances) 99 | print(f"Total #unique_instances: {len(unique_instances)}") 100 | 101 | with open(args.output_file, 'w') as f: 102 | for instance in unique_instances: 103 | f.write(json.dumps({ 104 | 'instruction': instance['instruction'], 105 | 'input': 
instance['input'], 106 | 'output': instance['output'], 107 | 'scenario': instance['scenario'], 108 | 'scenario_list': instance['scenario_list'], 109 | }) + '\n') -------------------------------------------------------------------------------- /edit_eval/evaluation.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('.') 3 | 4 | from collections import defaultdict, Counter 5 | from concurrent.futures import ThreadPoolExecutor, as_completed 6 | from typing import List, Union, Iterable, Dict 7 | import itertools 8 | 9 | import numpy as np 10 | import tqdm 11 | 12 | from data import HUMAN_EVAL, read_problems, stream_jsonl, write_jsonl 13 | from execution import check_correctness 14 | 15 | CODE_MARKER = r"{{Code}}" 16 | 17 | def build_program_and_tests(problem, code): 18 | if "context" in problem.keys() and CODE_MARKER in problem["context"]: 19 | code = problem["context"].replace(CODE_MARKER, code) 20 | 21 | return ( 22 | code + "\n\n" + 23 | problem["test"] + "\n\n" + 24 | # f"check({problem['entry_point']})" 25 | f"check()" 26 | ).strip() 27 | 28 | def estimate_pass_at_k( 29 | num_samples: Union[int, List[int], np.ndarray], 30 | num_correct: Union[List[int], np.ndarray], 31 | k: int 32 | ) -> np.ndarray: 33 | """ 34 | Estimates pass@k of each problem and returns them in an array. 35 | """ 36 | 37 | def estimator(n: int, c: int, k: int) -> float: 38 | """ 39 | Calculates 1 - comb(n - c, k) / comb(n, k). 40 | """ 41 | if n - c < k: 42 | return 1.0 43 | return 1.0 - np.prod(1.0 - k / np.arange(n - c + 1, n + 1)) 44 | 45 | if isinstance(num_samples, int): 46 | num_samples_it = itertools.repeat(num_samples, len(num_correct)) 47 | else: 48 | assert len(num_samples) == len(num_correct) 49 | num_samples_it = iter(num_samples) 50 | 51 | return np.array([estimator(int(n), int(c), k) for n, c in zip(num_samples_it, num_correct)]) 52 | 53 | def evaluate_functional_correctness( 54 | sample_file: str, 55 | k: List[int] = [1, 10, 100], 56 | n_workers: int = 4, 57 | timeout: float = 3.0, 58 | problem_file: str = HUMAN_EVAL, 59 | ): 60 | """ 61 | Evaluates the functional correctness of generated samples, and writes 62 | results to f"{sample_file}_results.jsonl.gz" 63 | """ 64 | 65 | problems = read_problems(problem_file) 66 | 67 | # Check the generated samples against test suites. 68 | with ThreadPoolExecutor(max_workers=n_workers) as executor: 69 | 70 | futures = [] 71 | code_id = Counter() 72 | n_samples = 0 73 | results = defaultdict(list) 74 | 75 | print("Reading samples...") 76 | for sample in tqdm.tqdm(stream_jsonl(sample_file)): 77 | task_id = sample["task_id"] 78 | edited_code = sample["output"] 79 | check_program = build_program_and_tests(sample, edited_code) 80 | args = (problems[task_id], check_program, timeout, code_id[task_id]) 81 | future = executor.submit(check_correctness, *args) 82 | futures.append(future) 83 | code_id[task_id] += 1 84 | n_samples += 1 85 | assert len(code_id) == len(problems), "Some problems are not attempted." 86 | 87 | print("Running test suites...") 88 | for future in tqdm.tqdm(as_completed(futures), total=len(futures)): 89 | result = future.result() 90 | results[result["task_id"]].append((result["run_id"], result)) 91 | 92 | # Calculate pass@k. 
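    # The estimator above (estimate_pass_at_k) is the unbiased pass@k from the HumanEval/Codex
    # evaluation: pass@k = 1 - C(n - c, k) / C(n, k), where n is the number of generated samples
    # for a task and c is how many of them passed. For example, with n = 10 samples of which
    # c = 3 pass, pass@1 = 1 - C(7, 1) / C(10, 1) = 0.3, i.e. the plain pass rate.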
93 | total, correct = [], [] 94 | for result in results.values(): 95 | result.sort() 96 | passed = [r[1]["passed"] for r in result] 97 | total.append(len(passed)) 98 | correct.append(sum(passed)) 99 | total = np.array(total) 100 | correct = np.array(correct) 101 | 102 | ks = k 103 | pass_at_k = {f"pass@{k}": estimate_pass_at_k(total, correct, k).mean() 104 | for k in ks if (total >= k).all()} 105 | 106 | # Finally, save the results in one file: 107 | def combine_results(): 108 | for sample in stream_jsonl(sample_file): 109 | task_id = sample["task_id"] 110 | result = results[task_id].pop(0) 111 | sample["result"] = result[1]["result"] 112 | sample["passed"] = result[1]["passed"] 113 | yield sample 114 | 115 | out_file = sample_file + "_results.jsonl" 116 | print(f"Writing results to {out_file}...") 117 | write_jsonl(out_file, tqdm.tqdm(combine_results(), total=n_samples)) 118 | 119 | return pass_at_k 120 | 121 | -------------------------------------------------------------------------------- /edit_eval/validate_dataset.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('.') 3 | 4 | from concurrent.futures import ThreadPoolExecutor, as_completed 5 | import argparse 6 | 7 | import numpy as np 8 | import tqdm 9 | 10 | from data import HUMAN_EVAL, read_problems, stream_jsonl, write_jsonl 11 | from execution import check_correctness 12 | 13 | CODE_MARKER = r"{{Code}}" 14 | 15 | def build_program_and_tests(problem, code): 16 | if "context" in problem.keys() and CODE_MARKER in problem["context"]: 17 | code = problem["context"].replace(CODE_MARKER, code) 18 | 19 | return ( 20 | code + "\n\n" + 21 | problem["test"] + "\n\n" + 22 | # f"check({problem['entry_point']})" 23 | f"check()" 24 | ) 25 | 26 | ###################################### 27 | # Methods for validating the dataset # 28 | ###################################### 29 | 30 | def assert_inputs_fail_tests( 31 | problem_file: str = None, 32 | n_workers: int = 4, 33 | timeout: float = 10.0, 34 | ): 35 | """ 36 | Checks if all input programs (before edit) fails all the tests. 37 | """ 38 | 39 | problems = read_problems(problem_file) 40 | 41 | # Check the generated samples against test suites. 42 | with ThreadPoolExecutor(max_workers=n_workers) as executor: 43 | 44 | futures = [] 45 | results = [] 46 | 47 | print("Reading samples...") 48 | for sample in tqdm.tqdm(stream_jsonl(problem_file)): 49 | task_id = sample["task_id"] 50 | code = sample["input"] 51 | check_program = build_program_and_tests(sample, code) 52 | args = (problems[task_id], check_program, timeout, task_id) 53 | future = executor.submit(check_correctness, *args) 54 | futures.append(future) 55 | 56 | print("Running test suites...") 57 | for future in tqdm.tqdm(as_completed(futures), total=len(futures)): 58 | result = future.result() 59 | results.append((result["run_id"], result)) 60 | 61 | all_correct = True 62 | for result in results: 63 | run_id, result = result 64 | passed = result["passed"] # Only one code per task 65 | if passed: 66 | print(f"The input code of {run_id} should not pass the tests.") 67 | all_correct = False 68 | 69 | if all_correct: 70 | print("All input code (before edit) in the dataset failed as expected.") 71 | 72 | 73 | 74 | def assert_target_pass_tests( 75 | problem_file: str = None, 76 | n_workers: int = 4, 77 | timeout: float = 10.0, 78 | ): 79 | """ 80 | Checks if all input programs (before edit) fails all the tests. 
81 |     """
82 | 
83 |     problems = read_problems(problem_file)
84 | 
85 |     # Check the generated samples against test suites.
86 |     with ThreadPoolExecutor(max_workers=n_workers) as executor:
87 | 
88 |         futures = []
89 |         results = []
90 | 
91 |         print("Reading samples...")
92 |         for sample in tqdm.tqdm(stream_jsonl(problem_file)):
93 |             task_id = sample["task_id"]
94 |             code = sample["output"]
95 |             check_program = build_program_and_tests(sample, code)
96 |             args = (problems[task_id], check_program, timeout, task_id)
97 |             future = executor.submit(check_correctness, *args)
98 |             futures.append(future)
99 | 
100 |         print("Running test suites...")
101 |         for future in tqdm.tqdm(as_completed(futures), total=len(futures)):
102 |             result = future.result()
103 |             results.append((result["run_id"], result))
104 | 
105 |     all_correct = True
106 |     for result in results:
107 |         run_id, result = result
108 |         passed = result["passed"]  # Only one code per task
109 |         if not passed:
110 |             print(f"The target code of {run_id} should pass the tests.")
111 |             all_correct = False
112 |     if all_correct:
113 |         print("All target code (after edit) in the dataset passed as expected.")
114 | 
115 | 
116 | if __name__ == "__main__":
117 |     parser = argparse.ArgumentParser(description="Validates the specified JSONL dataset.")
118 |     parser.add_argument("file_path", type=str, help="Path to the JSONL file")
119 | 
120 |     args = parser.parse_args()
121 | 
122 |     assert_inputs_fail_tests(args.file_path)
123 |     print("_" * 100)
124 |     print("Testing target code")
125 |     assert_target_pass_tests(args.file_path)
--------------------------------------------------------------------------------
/src/generate_instance.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import os
4 | import random
5 | import re
6 | from collections import OrderedDict
7 | 
8 | import openai
9 | import pandas as pd
10 | import tiktoken
11 | import tqdm
12 | from tenacity import retry, stop_after_attempt, wait_random_exponential
13 | 
14 | from request_helper import ParallelRunner
15 | from template.prompt_template import INSTANCE_PROMPT
16 | 
17 | 
18 | def parse_args():
19 |     parser = argparse.ArgumentParser()
20 |     parser.add_argument(
21 |         "--input_file",
22 |         type=str,
23 |         default="temp/machine_generated_scenarios.jsonl",
24 |     )
25 |     parser.add_argument(
26 |         "--output_file",
27 |         type=str,
28 |         default="temp/machine_generated_instances.jsonl",
29 |     )
30 |     parser.add_argument(
31 |         "--engine",
32 |         type=str,
33 |         default="gpt-3.5-turbo",
34 |         help="The engine to use."
35 |     )
36 |     parser.add_argument(
37 |         "--api_key",
38 |         type=str,
39 |         help="OpenAI API key"
40 |     )
41 |     parser.add_argument(
42 |         "--seed_tasks_path",
43 |         type=str,
44 |         default='data/additional_seed.json',
45 |         help="The path to the seed human-evaluated data.",
46 |     )
47 |     parser.add_argument(
48 |         "--proc_num",
49 |         type=int,
50 |         default=2,
51 |         help="The number of concurrent processes."
52 | ) 53 | 54 | parser.add_argument 55 | return parser.parse_args() 56 | 57 | 58 | @retry(wait=wait_random_exponential(min=8, max=16), stop=stop_after_attempt(8)) 59 | def query_for_instance(item, key=None): 60 | if key is not None: 61 | openai.api_key = key 62 | 63 | task = item['instruction'] 64 | scenario = item['scenario'] 65 | scenario_list = item['scenario_list'] 66 | prompt_str = item['prompt'] 67 | 68 | response = openai.ChatCompletion.create( 69 | model = args.engine, 70 | messages = [ 71 | {"role": "user", "content": prompt_str}, 72 | ], 73 | temperature = 0., 74 | max_tokens = 1024, 75 | presence_penalty = 1.5, 76 | ) 77 | 78 | output = response['choices'][0]['message']['content'] 79 | query_result = { 80 | 'instruction': task, 81 | 'scenario': scenario, 82 | 'scenario_list': scenario_list, 83 | 'result': output, 84 | 'response': response 85 | } 86 | return query_result 87 | 88 | 89 | def load_file(path): 90 | files = [] 91 | if not os.path.exists(path): 92 | return files 93 | if path.endswith('.jsonl'): 94 | files = [json.loads(line) for line in open(path).readlines()] 95 | elif path.endswith('json'): 96 | files = json.loads(open(path).read()) 97 | else: 98 | raise NotImplementedError 99 | return files 100 | 101 | 102 | def encode_prompt(item, tasks): 103 | start_prompt = INSTANCE_PROMPT 104 | shots = random.sample(tasks, 4) 105 | for n_shot in range(4, 0, -1): 106 | prompt = start_prompt 107 | for shot in shots[:n_shot]: 108 | prompt += f"Scenario: {shot['scenario']}\nTask: {shot['instruction']}\n" 109 | prompt += f"Input: \n{shot['input']}\n\nOutput: \n{shot['output']}\n\n\n" 110 | prompt += f"Scenario: {item['scenario']}\nTask: {item['instruction']}\n" 111 | if len(tokenizer.encode(prompt)) < 3072: 112 | return prompt 113 | return start_prompt 114 | 115 | 116 | if __name__ == "__main__": 117 | args = parse_args() 118 | tokenizer = tiktoken.encoding_for_model("gpt2") 119 | seed_tasks = load_file(args.seed_tasks_path) 120 | print(f"Total #seed tasks: {len(seed_tasks)}") 121 | 122 | tasks = load_file(args.input_file) 123 | print(f"Total #tasks: {len(tasks)}") 124 | 125 | exist_requests = set([item['instruction'] for item in load_file(args.output_file)]) 126 | 127 | unresolved_tasks, resolved_tasks = [], [] 128 | for task in tasks: 129 | hash = task['instruction'] 130 | if hash in exist_requests: 131 | resolved_tasks.append(task) 132 | else: 133 | task['prompt'] = encode_prompt(task, seed_tasks) 134 | unresolved_tasks.append(task) 135 | 136 | with open(args.output_file, "a") as fout: 137 | runner = ParallelRunner(key=args.api_key, num_workers=args.proc_num, verbose=True) 138 | results = runner.start(data=unresolved_tasks, 139 | query_func=query_for_instance, 140 | output_filename=args.output_file) -------------------------------------------------------------------------------- /src/generate_scenarios.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | import random 5 | import re 6 | from collections import OrderedDict 7 | 8 | import openai 9 | import pandas as pd 10 | import tqdm 11 | from tenacity import retry, stop_after_attempt, wait_random_exponential 12 | 13 | from request_helper import ParallelRunner 14 | from template.prompt_template import SCENARIO_PROMPT 15 | 16 | 17 | def parse_args(): 18 | parser = argparse.ArgumentParser() 19 | parser.add_argument( 20 | "--input_file", 21 | type=str, 22 | default="temp/machine_generated_instructions.jsonl", 23 | ) 24 | parser.add_argument( 25 | 
"--output_file", 26 | type=str, 27 | default="temp/machine_generated_scenarios.jsonl", 28 | ) 29 | parser.add_argument( 30 | "--engine", 31 | type=str, 32 | default="gpt-3.5-turbo", 33 | help="The engine to use." 34 | ) 35 | parser.add_argument( 36 | "--api_key", 37 | type=str, 38 | help="openai api key" 39 | ) 40 | parser.add_argument( 41 | "--seed_tasks_path", 42 | type=str, 43 | default='data/additional_seed.json', 44 | help="The path to the seed human-evaluated data.", 45 | ) 46 | parser.add_argument( 47 | "--proc_num", 48 | type=int, 49 | default=2, 50 | help="The number of concurrent process." 51 | ) 52 | 53 | parser.add_argument 54 | return parser.parse_args() 55 | 56 | 57 | @retry(wait=wait_random_exponential(min=8, max=16), stop=stop_after_attempt(8)) 58 | def query_for_scenarios(item, key=None): 59 | if key is not None: 60 | openai.api_key = key 61 | 62 | task = item['instruction'] 63 | prompt_str = item['prompt'] 64 | 65 | response = openai.ChatCompletion.create( 66 | model = args.engine, 67 | messages = [ 68 | {"role": "user", "content": prompt_str}, 69 | {"role": "user", "content": f"Task: {task}"}, 70 | ], 71 | temperature = 0.7, 72 | max_tokens = 1300, 73 | presence_penalty = 1.5, 74 | ) 75 | 76 | scenario_list = response['choices'][0]['message']['content'] 77 | scenario = "" 78 | if re.search("Scenario \d+:.+", scenario_list): 79 | scenarios = re.findall("Scenario \d+:.+", scenario_list) 80 | scenarios = [s.split(':')[1].strip() for s in scenarios] 81 | scenario = random.choice(scenarios) 82 | 83 | 84 | query_result = { 85 | 'instruction': task, 86 | 'scenario_list': scenario_list, 87 | 'scenario': scenario, 88 | 'response': response, 89 | } 90 | return query_result 91 | 92 | 93 | def encode_prompt(seed_tasks): 94 | (shot1, shot2) = random.sample(seed_tasks, 2) 95 | # shot1 96 | shot1_scenarios = shot1['scenario_list'].strip() 97 | shot1_prompt = f"Task: {shot1['instruction']}\n{shot1_scenarios}" 98 | 99 | # shot2 100 | shot2_scenarios = shot2['scenario_list'].strip() 101 | shot2_prompt = f"Task: {shot2['instruction']}\n{shot2_scenarios}" 102 | 103 | few_shot_prompt = SCENARIO_PROMPT.format(SHOT1=shot1_prompt, SHOT2=shot2_prompt) 104 | return few_shot_prompt 105 | 106 | 107 | def load_file(path): 108 | files = [] 109 | if not os.path.exists(path): 110 | return files 111 | if path.endswith('.jsonl'): 112 | files = [json.loads(line) for line in open(path).readlines()] 113 | elif path.endswith('json'): 114 | files = json.loads(open(path).read()) 115 | else: 116 | raise NotImplementedError 117 | return files 118 | 119 | 120 | if __name__ == "__main__": 121 | args = parse_args() 122 | seed_tasks = load_file(args.seed_tasks_path) 123 | print(f"Total #seed tasks: {len(seed_tasks)}") 124 | 125 | tasks = load_file(args.input_file) 126 | print(f"Total #tasks: {len(tasks)}") 127 | 128 | exist_requests = set([item['instruction'] for item in load_file(args.output_file)]) 129 | 130 | unresolved_tasks, resolved_tasks = [], [] 131 | for task in tasks: 132 | hash = task['instruction'] 133 | if hash in exist_requests: 134 | resolved_tasks.append(task) 135 | else: 136 | task['prompt'] = encode_prompt(seed_tasks) 137 | unresolved_tasks.append(task) 138 | 139 | with open(args.output_file, "a") as fout: 140 | runner = ParallelRunner(key=args.api_key, num_workers=args.proc_num, verbose=True) 141 | results = runner.start(data=unresolved_tasks, 142 | query_func=query_for_scenarios, 143 | output_filename=args.output_file) 
-------------------------------------------------------------------------------- /src/request_helper.py: -------------------------------------------------------------------------------- 1 | 2 | from __future__ import annotations 3 | import json 4 | import logging 5 | import multiprocessing 6 | from multiprocessing.dummy import Pool as ThreadPool 7 | import os 8 | import random 9 | 10 | from tqdm import tqdm 11 | 12 | logger = logging.getLogger('ParallelRunner') 13 | 14 | 15 | def chunk(li: list, chunk_size: int) -> list[list]: 16 | return [li[i:i+chunk_size] for i in range(0, len(li), chunk_size)] 17 | 18 | 19 | class ParallelRunner: 20 | def __init__(self, 21 | key: str, 22 | num_workers=2, 23 | chunk_size=1, 24 | verbose=True) -> None: 25 | self.key = key 26 | self.num_workers = num_workers 27 | self.chunk_size = chunk_size 28 | self.query_results = [] 29 | self._query_func = None 30 | self.verbose = verbose 31 | self._file = None 32 | 33 | 34 | @staticmethod 35 | def query_once(item, query_func: callable, key: str, debug=False): 36 | result = None 37 | if debug: 38 | result = query_func(item, key) 39 | else: 40 | try: 41 | result = query_func(item, key) 42 | except Exception as e: 43 | logger.error(str(e)) 44 | return None 45 | return result 46 | 47 | 48 | @staticmethod 49 | def query(data, query_func, key): 50 | """The entrance function executed by worker process(es).""" 51 | results = [] 52 | for item in data: 53 | while True: 54 | result = ParallelRunner.query_once(item, query_func, key) 55 | # Got result 56 | if result is not None: 57 | results.append(result) 58 | break 59 | return results 60 | 61 | 62 | @staticmethod 63 | def query_batch(data, query_func, key): 64 | """The entrance function executed by worker process(es).""" 65 | results = None 66 | while True: 67 | results = ParallelRunner.query_once(data, query_func, key) 68 | if results is not None and isinstance(results, list): 69 | break 70 | return results 71 | 72 | 73 | def handle_result(self, results): 74 | if results: 75 | for item in results: 76 | self._file.write(json.dumps(item)) 77 | self._file.write('\n') 78 | self._file.flush() 79 | self.query_results.extend(results) 80 | if self.num_workers > 0 and self.verbose: 81 | self.progress_bar.update(self.chunk_size) 82 | 83 | 84 | def handle_error(self, error): 85 | print(error, flush=True) 86 | 87 | 88 | def filter_unfinished_jobs(self, data, output_filename): 89 | with open(output_filename, 'r', encoding='utf-8') as f: 90 | saved_jobs = (json.loads(js_str) for js_str in f.read().strip().split('\n')) 91 | saved_job_ids = set([job['job_id'] for job in saved_jobs]) 92 | print(f"Found {len(saved_job_ids)} saved jobs.") 93 | return list(filter(lambda item: item['job_id'] not in saved_job_ids, data)) 94 | 95 | 96 | def start(self, data, query_func, output_filename, resume=False, batch=False): 97 | if resume: 98 | data = self.filter_unfinished_jobs(data, output_filename) 99 | 100 | self._file = open(output_filename, 'a', encoding='utf-8') 101 | 102 | if self.verbose: 103 | print("Starting jobs...") 104 | if self.num_workers == 0: 105 | iterator = tqdm(data) if self.verbose else data 106 | for item in iterator: 107 | results = self.query((item,), query_func, self.keys, self.cfg) 108 | self.handle_result(results) 109 | else: 110 | pool = ThreadPool(self.num_workers) 111 | self._query_func = query_func 112 | self.progress_bar = tqdm(total=len(data)) if self.verbose else None 113 | data_chunks = chunk(data, self.chunk_size) 114 | for dt_i, data in enumerate(data_chunks): 115 | if 
batch: 116 | pool.apply_async(self.query_batch, 117 | (data, query_func, self.key), 118 | callback=self.handle_result, error_callback=self.handle_error) 119 | else: 120 | pool.apply_async(self.query, 121 | (data, query_func, self.key), 122 | callback=self.handle_result, error_callback=self.handle_error) 123 | pool.close() 124 | pool.join() 125 | if self.verbose: 126 | self.progress_bar.close() 127 | if self.verbose: 128 | print("Jobs finished.") 129 | return self.query_results 130 | -------------------------------------------------------------------------------- /edit_eval/execution.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Callable, Dict 2 | import ast 3 | import contextlib 4 | import faulthandler 5 | import io 6 | import os 7 | import multiprocessing 8 | import platform 9 | import signal 10 | import tempfile 11 | 12 | 13 | def check_correctness(problem: Dict, code: str, timeout: float, 14 | run_id: Optional[int] = None) -> Dict: 15 | """ 16 | Evaluates the functional correctness of a generated code by running the test 17 | suite provided in the problem. 18 | 19 | :param run_id: an optional run ID so we can match 20 | the results later even if execution finishes asynchronously. 21 | """ 22 | 23 | def unsafe_execute(): 24 | 25 | with create_tempdir(): 26 | 27 | # These system calls are needed when cleaning up tempdir. 28 | import os 29 | import shutil 30 | rmtree = shutil.rmtree 31 | rmdir = os.rmdir 32 | chdir = os.chdir 33 | 34 | # Disable functionalities that can make destructive changes to the test. 35 | reliability_guard() 36 | 37 | # Construct the check program and run it. 38 | check_program = code 39 | # print(check_program) 40 | 41 | try: 42 | exec_globals = {} 43 | with swallow_io(): 44 | with time_limit(timeout): 45 | # WARNING 46 | # This program exists to execute untrusted model-generated code. Although 47 | # it is highly unlikely that model-generated code will do something overtly 48 | # malicious in response to this test suite, model-generated code may act 49 | # destructively due to a lack of model capability or alignment. 50 | # Users are strongly encouraged to sandbox this evaluation suite so that it 51 | # does not perform destructive actions on their host or network. For more 52 | # information on how OpenAI sandboxes its code, see the accompanying paper. 53 | # Once you have read this disclaimer and taken appropriate precautions, 54 | # uncomment the following line and proceed at your own risk: 55 | exec(check_program, exec_globals) 56 | result.append("passed") 57 | except TimeoutException: 58 | result.append("timed out") 59 | except BaseException as e: 60 | # print(e) 61 | result.append(f"failed: {e}") 62 | 63 | # Needed for cleaning up. 
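            # reliability_guard() above set shutil.rmtree, os.rmdir and os.chdir to None,
            # so the originals saved earlier are restored here to let the TemporaryDirectory
            # context manager clean up the tempdir.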
64 | shutil.rmtree = rmtree 65 | os.rmdir = rmdir 66 | os.chdir = chdir 67 | 68 | manager = multiprocessing.Manager() 69 | result = manager.list() 70 | 71 | p = multiprocessing.Process(target=unsafe_execute) 72 | p.start() 73 | p.join(timeout=timeout + 1) 74 | if p.is_alive(): 75 | p.kill() 76 | 77 | if not result: 78 | result.append("timed out") 79 | 80 | return dict( 81 | task_id=problem["task_id"], 82 | passed=result[0] == "passed", 83 | result=result[0], 84 | run_id=run_id, 85 | ) 86 | 87 | 88 | @contextlib.contextmanager 89 | def time_limit(seconds: float): 90 | def signal_handler(signum, frame): 91 | raise TimeoutException("Timed out!") 92 | signal.setitimer(signal.ITIMER_REAL, seconds) 93 | signal.signal(signal.SIGALRM, signal_handler) 94 | try: 95 | yield 96 | finally: 97 | signal.setitimer(signal.ITIMER_REAL, 0) 98 | 99 | 100 | @contextlib.contextmanager 101 | def swallow_io(): 102 | stream = WriteOnlyStringIO() 103 | with contextlib.redirect_stdout(stream): 104 | with contextlib.redirect_stderr(stream): 105 | with redirect_stdin(stream): 106 | yield 107 | 108 | 109 | @contextlib.contextmanager 110 | def create_tempdir(): 111 | with tempfile.TemporaryDirectory() as dirname: 112 | with chdir(dirname): 113 | yield dirname 114 | 115 | 116 | class TimeoutException(Exception): 117 | pass 118 | 119 | 120 | class WriteOnlyStringIO(io.StringIO): 121 | """ StringIO that throws an exception when it's read from """ 122 | 123 | def read(self, *args, **kwargs): 124 | raise IOError 125 | 126 | def readline(self, *args, **kwargs): 127 | raise IOError 128 | 129 | def readlines(self, *args, **kwargs): 130 | raise IOError 131 | 132 | def readable(self, *args, **kwargs): 133 | """ Returns True if the IO object can be read. """ 134 | return False 135 | 136 | 137 | class redirect_stdin(contextlib._RedirectStream): # type: ignore 138 | _stream = 'stdin' 139 | 140 | 141 | @contextlib.contextmanager 142 | def chdir(root): 143 | if root == ".": 144 | yield 145 | return 146 | cwd = os.getcwd() 147 | os.chdir(root) 148 | try: 149 | yield 150 | except BaseException as exc: 151 | raise exc 152 | finally: 153 | os.chdir(cwd) 154 | 155 | 156 | def reliability_guard(maximum_memory_bytes: Optional[int] = None): 157 | """ 158 | This disables various destructive functions and prevents the generated code 159 | from interfering with the test (e.g. fork bomb, killing other processes, 160 | removing filesystem files, etc.) 161 | 162 | WARNING 163 | This function is NOT a security sandbox. Untrusted code, including, model- 164 | generated code, should not be blindly executed outside of one. See the 165 | Codex paper for more information about OpenAI's code sandbox, and proceed 166 | with caution. 
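    If maximum_memory_bytes is given, address-space, data and (except on macOS) stack
    limits are also applied via resource.setrlimit before the other guards are installed.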
167 | """ 168 | 169 | if maximum_memory_bytes is not None: 170 | import resource 171 | resource.setrlimit(resource.RLIMIT_AS, (maximum_memory_bytes, maximum_memory_bytes)) 172 | resource.setrlimit(resource.RLIMIT_DATA, (maximum_memory_bytes, maximum_memory_bytes)) 173 | if not platform.uname().system == 'Darwin': 174 | resource.setrlimit(resource.RLIMIT_STACK, (maximum_memory_bytes, maximum_memory_bytes)) 175 | 176 | faulthandler.disable() 177 | 178 | import builtins 179 | builtins.exit = None 180 | builtins.quit = None 181 | 182 | import os 183 | os.environ['OMP_NUM_THREADS'] = '1' 184 | 185 | os.kill = None 186 | os.system = None 187 | os.putenv = None 188 | os.remove = None 189 | os.removedirs = None 190 | os.rmdir = None 191 | os.fchdir = None 192 | os.setuid = None 193 | os.fork = None 194 | os.forkpty = None 195 | os.killpg = None 196 | os.rename = None 197 | os.renames = None 198 | os.truncate = None 199 | os.replace = None 200 | os.unlink = None 201 | os.fchmod = None 202 | os.fchown = None 203 | os.chmod = None 204 | os.chown = None 205 | os.chroot = None 206 | os.fchdir = None 207 | os.lchflags = None 208 | os.lchmod = None 209 | os.lchown = None 210 | os.getcwd = None 211 | os.chdir = None 212 | 213 | import shutil 214 | shutil.rmtree = None 215 | shutil.move = None 216 | shutil.chown = None 217 | 218 | import subprocess 219 | subprocess.Popen = None # type: ignore 220 | 221 | __builtins__['help'] = None 222 | 223 | import sys 224 | sys.modules['ipdb'] = None 225 | sys.modules['joblib'] = None 226 | sys.modules['resource'] = None 227 | sys.modules['psutil'] = None 228 | sys.modules['tkinter'] = None 229 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # InstructCoder: Instruction Tuning Large Language Models for Code Editing 2 | 3 | **[Paper](https://arxiv.org/abs/2310.20329) | [Dataset](https://huggingface.co/datasets/happylkx/InstructCoder) | [Blog](https://blog.nus.edu.sg/kaixinli/2023/05/23/codeinstruct/)** 4 | 5 | ![Alt text](./fig/demo.png "Pipeline & Example") 6 | 7 | 8 | 9 | 10 | # Overview 11 | InstructCoder is the first dataset designed to adapt LLMs for general code editing. It consists of over 114,000 instruction-input-output triplets and covers multiple distinct code editing scenarios, generated by ChatGPT. 12 | 13 | 14 | # Updates 15 | [2024-2-25] The test set **EditEval** and code are released. 16 | 17 | [2023-11-2] The training set is available at [Huggingface](https://huggingface.co/datasets/happylkx/InstructCoder). 18 | 19 | [2023-5-22] The code is released. 20 | 21 | # Data Collection 22 | To generate instructional data for code editing, we employed a similar method based on [Self-Instruct](https://github.com/yizhongw/self-instruct). This methodology of generating training data using LLMs requires minimal human-labeled data as seed tasks while still maintaining the quality and relevance of the tasks in the dataset. InstructCoder is systematically expanded through an iterative process that commences with editing data sourced from GitHub commits as seed tasks. Seed and generated tasks are used subsequently bootstrapped to prompt ChatGPT for more task data. 23 | 24 | For each generated instruction, we also prompt ChatGPT to generate a list of practical events as 'real-world' scenarios where the editing instruction could be performed, and randomly select one for subsequent generation. 
During instance generation, ChatGPT is instructed to generate examples that correspond with the operation in the instruction while ensuring the codebases and variable names are appropriate for the given scenario. We observe that instances generated with the inclusion of a scenario demonstrate higher quality in terms of richer context and code structure compared to those without.
25 | 
26 | 
27 | # Results and Examples
28 | The largest model fine-tuned, LLaMA-33B, performs on par with ChatGPT, with 79.3% accuracy on the test set. Some qualitative examples generated by our model are shown [here](#examples).
29 | 
30 | 
31 | # Release
32 | We are planning to release the following assets:
33 | - [x] Full dataset: Over 114,000 code-editing instructional samples.
34 | - [x] Train / Validation data: Around 95% / 5% of the full dataset.
35 | - [x] Test data: 134 GitHub commit samples.
36 | - [x] GitHub seed data: A total of 634 GitHub commit samples.
37 | - [x] Additional seed data: 592 unused generated samples.
38 | - [x] Source code for generating data
39 | - [ ] TODO: Checkpoints
40 | 
41 | We are currently working on a clean release of our code and will upload it as soon as it is ready.
42 | 
43 | We also plan to release our LoRA checkpoints [here](https://huggingface.co/happylkx).
44 | 
45 | 
46 | # Examples
47 | 
48 | 1. Instruction: Transform the code into a class method called 'calculate_distance' that calculates the distance between two points in 3D space.
49 | 
50 | Input code:
51 | ```python
52 | import math
53 | 
54 | def distance(point1, point2):
55 |     x1, y1, z1 = point1
56 |     x2, y2, z2 = point2
57 |     dx = x2 - x1
58 |     dy = y2 - y1
59 |     dz = z2 - z1
60 |     return math.sqrt(dx**2 + dy**2 + dz**2)
61 | 
62 | point1 = (0, 0, 0)
63 | point2 = (1, 1, 1)
64 | print(distance(point1, point2))
65 | ```
66 | 
67 | Output code:
68 | ```python
69 | import math
70 | 
71 | class Point:
72 |     def __init__(self, x, y, z):
73 |         self.x = x
74 |         self.y = y
75 |         self.z = z
76 | 
77 |     def calculate_distance(self, other_point):
78 |         dx = other_point.x - self.x
79 |         dy = other_point.y - self.y
80 |         dz = other_point.z - self.z
81 |         return math.sqrt(dx**2 + dy**2 + dz**2)
82 | 
83 | point1 = Point(0, 0, 0)
84 | point2 = Point(1, 1, 1)
85 | print(point1.calculate_distance(point2))
86 | ```
87 | 
88 | 2. Instruction: Integrate error handling into the code that performs batch processing of large datasets, including job failures or incomplete runs.
89 | 
90 | Input code:
91 | ```python
92 | import pandas as pd
93 | 
94 | def process_orders(orders_file, shipping_file, inventory_file):
95 |     orders = pd.read_csv(orders_file)
96 |     shipping = pd.read_csv(shipping_file)
97 |     inventory = pd.read_csv(inventory_file)
98 | 
99 |     # Perform data processing and analysis here...
100 | 101 | results = {"success": True, "message": "Batch processing completed successfully."} 102 | return results 103 | 104 | orders_file = "orders.csv" 105 | shipping_file = "shipping.csv" 106 | inventory_file = "inventory.csv" 107 | 108 | results = process_orders(orders_file, shipping_file, inventory_file) 109 | print(results) 110 | ``` 111 | 112 | Output code: 113 | ```python 114 | import pandas as pd 115 | import logging 116 | 117 | logging.basicConfig(filename='batch_processing.log', level=logging.ERROR) 118 | 119 | def process_orders(orders_file, shipping_file, inventory_file): 120 | try: 121 | orders = pd.read_csv(orders_file) 122 | shipping = pd.read_csv(shipping_file) 123 | inventory = pd.read_csv(inventory_file) 124 | 125 | # Perform data processing and analysis here... 126 | 127 | results = {"success": True, "message": "Batch processing completed successfully."} 128 | return results 129 | except Exception as e: 130 | logging.error(f"Error processing batch: {e}") 131 | results = {"success": False, "message": "Batch processing failed. Please check the input files for any issues."} 132 | return results 133 | 134 | orders_file = "orders.csv" 135 | shipping_file = "shipping.csv" 136 | inventory_file = "inventory.csv" 137 | 138 | results = process_orders(orders_file, shipping_file, inventory_file) 139 | print(results) 140 | ``` 141 | 142 | 2. Instruction: Refactor any code that uses 'exec' or 'eval' for security reasons. 143 | 144 | Input code: 145 | ```python 146 | game_data = { 147 | "score": 0, 148 | "level": 1 149 | } 150 | 151 | def execute_script(script): 152 | exec(script) 153 | 154 | script = """ 155 | game_data["score"] += 10 156 | if game_data["score"] >= 100: 157 | game_data["level"] += 1 158 | """ 159 | 160 | execute_script(script) 161 | print(game_data) 162 | ``` 163 | 164 | Output code: 165 | ```python 166 | game_data = { 167 | "score": 0, 168 | "level": 1 169 | } 170 | 171 | def execute_script(script): 172 | compiled_script = compile(script, "", "exec") 173 | exec(compiled_script, game_data) 174 | 175 | script = """ 176 | game_data["score"] += 10 177 | if game_data["score"] >= 100: 178 | game_data["level"] += 1 179 | """ 180 | 181 | execute_script(script) 182 | print(game_data) 183 | ``` 184 | Now the script is not able to access the data in `game_data`. 185 | 186 | 187 | # Citation 188 | Feel free to cite this work if you use the data or code in this repo. 
189 | ```bibtex 190 | @inproceedings{li2024instructcoder, 191 | title={Instructcoder: Instruction tuning large language models for code editing}, 192 | author={Li, Kaixin and Hu, Qisheng and Zhao, James and Chen, Hui and Xie, Yuxi and Liu, Tiedong and Shieh, Michael and He, Junxian}, 193 | booktitle={Proceedings of the 62nd Annual Meeting of the Association for Computational Linguistics (Volume 4: Student Research Workshop)}, 194 | pages={50--70}, 195 | year={2024} 196 | } 197 | ``` 198 | -------------------------------------------------------------------------------- /src/bootstrap_instructions.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | import random 5 | import re 6 | import string 7 | from functools import partial 8 | from multiprocessing import Pool 9 | 10 | import numpy as np 11 | import openai 12 | import pandas as pd 13 | import tqdm 14 | from rouge_score import rouge_scorer 15 | from tenacity import retry, stop_after_attempt, wait_random_exponential 16 | 17 | from request_helper import ParallelRunner 18 | from template.prompt_template import EDIT_TYPES, INSTRUCTION_PROMPT, INSTRUCTION_PROMPT_WITH_TYPE 19 | 20 | 21 | def parse_args(): 22 | parser = argparse.ArgumentParser() 23 | parser.add_argument( 24 | "--save_dir", 25 | type=str, 26 | default="./temp/", 27 | help="The directory where the generated file will be saved", 28 | ) 29 | parser.add_argument( 30 | "--seed_tasks_path", 31 | type=str, 32 | nargs="+", 33 | default=['./data/github_seed.json', './data/additional_seed.json'], 34 | help="The path to the seed human-evaluated data.", 35 | ) 36 | parser.add_argument( 37 | "--num_instructions_to_generate", 38 | type=int, 39 | default=5, 40 | help="th", 41 | ) 42 | parser.add_argument( 43 | "--use_edit_type", 44 | action="store_true", 45 | help="If specified, will use template with edit type for generating instructions.", 46 | ) 47 | parser.add_argument( 48 | "--engine", 49 | type=str, 50 | default="gpt-3.5-turbo", 51 | help="The engine to use." 52 | ) 53 | parser.add_argument( 54 | "--num_prompt_instructions", 55 | type=int, 56 | default=8, 57 | help="The number of instructions to use in the prompt." 58 | ) 59 | parser.add_argument( 60 | "--api_key", 61 | type=str, 62 | help="openai api key" 63 | ) 64 | parser.add_argument( 65 | "--seed", 66 | type=int, 67 | default=42, 68 | help="The number of concurrent process." 69 | ) 70 | parser.add_argument( 71 | "--proc_num", 72 | type=int, 73 | default=2, 74 | help="The number of concurrent process." 
75 |     )
76 | 
77 |     return parser.parse_args()
78 | 
79 | 
80 | def sample_items(total, n):
81 |     """Sample n items from a list of total instructions."""
82 |     samples = []
83 |     if len(total) == 0 or n == 0:
84 |         return samples
85 |     samples = random.sample(total, min(n, len(total)))
86 |     return samples
87 | 
88 | 
89 | @retry(wait=wait_random_exponential(min=6, max=16), stop=stop_after_attempt(1))
90 | def query_for_instruction(items, key):
91 |     item = items if isinstance(items, dict) else items[0]
92 |     user_message_1, user_message_2 = item['user_message_1'], item['user_message_2']
93 |     response = openai.ChatCompletion.create(
94 |         api_key = key,
95 |         model = args.engine,
96 |         messages = [
97 |             {"role": "system", "content": "You are an experienced python developer."},
98 |             {"role": "user", "content": user_message_1},
99 |             {"role": "user", "content": user_message_2},
100 |         ],
101 |         temperature = 0.7,
102 |         top_p=0.5,
103 |         presence_penalty=1.5,
104 |         max_tokens = 2048,
105 |     )
106 |     output = response['choices'][0]['message']['content']
107 |     pt = re.compile("(?<=\").+(?=\")")
108 |     if pt.search(output):
109 |         instructions = pt.findall(output)
110 |         query_results = []
111 |         for instruction in instructions:
112 |             query_result = {
113 |                 'instruction': instruction,
114 |             }
115 |             query_results.append(query_result)
116 |         return query_results
117 |     else:
118 |         instruction = "None"
119 |         query_result = {
120 |             'instruction': instruction,
121 |         }
122 |         return [query_result]
123 | 
124 | 
125 | def get_rouge_scores(inst, all_instructions):
126 |     if inst in all_instructions:
127 |         all_instructions.remove(inst)
128 |     rouge_scores = []
129 |     scorer = rouge_scorer.RougeScorer(["rougeL"], use_stemmer=False)
130 |     for e_inst in all_instructions:
131 |         rouge_scores.append(scorer.score(inst, e_inst))
132 |     rouge_scores = [score["rougeL"].fmeasure for score in rouge_scores]
133 |     max_score = max(rouge_scores)
134 |     most_similar_instructions = {
135 |         all_instructions[i]: rouge_scores[i] for i in np.argsort(rouge_scores)[-10:][::-1]
136 |     }
137 |     mean_score = float(np.mean(rouge_scores))
138 |     all_instructions.append(inst)
139 |     return max_score, most_similar_instructions, mean_score
140 | 
141 | 
142 | def encode_conversation(instructions, edit_type=''):
143 |     if edit_type:
144 |         user_message_1 = INSTRUCTION_PROMPT_WITH_TYPE.format(edit_task=edit_type)
145 |     else:
146 |         user_message_1 = INSTRUCTION_PROMPT
147 |     user_message_2 = ""
148 |     for i in range(len(instructions)):
149 |         user_message_2 += f"{i+1}. \"{instructions[i]}\"\n"
150 |     return user_message_1, user_message_2
151 | 
152 | 
153 | def load_file(path):
154 |     files = []
155 |     if not os.path.exists(path):
156 |         return files
157 |     if path.endswith('.jsonl'):
158 |         files = [json.loads(line) for line in open(path).readlines()]
159 |     elif path.endswith('json'):
160 |         files = json.loads(open(path).read())
161 |     else:
162 |         raise NotImplementedError
163 |     return files
164 | 
165 | 
166 | 
167 | if __name__ == "__main__":
168 |     args = parse_args()
169 |     random.seed(args.seed)
170 |     # Seed task loading
171 |     seed_tasks = []
172 |     print(args.seed_tasks_path)
173 |     for path in args.seed_tasks_path:
174 |         seed_tasks.extend(load_file(path))
175 |     seed_instructions = [t["instruction"] for t in seed_tasks]
176 |     print("Loaded {} seed tasks.".format(len(seed_instructions)))
177 |     os.makedirs(args.save_dir, exist_ok=True)
178 |     batch_size = args.proc_num
179 | 
180 |     # LM-generated task loading
181 |     machine_tasks_path = os.path.join(args.save_dir, "machine_generated_instructions.jsonl")
182 |     machine_tasks = load_file(machine_tasks_path)
183 |     machine_instructions = [row['instruction'] for row in machine_tasks]
184 |     print(f"Found {len(machine_instructions)} machine-generated instructions.")
185 | 
186 |     # Intermediate output path
187 |     inter_inst_path = os.path.join(args.save_dir, "inter_instructions.jsonl")
188 | 
189 |     exist_num = len(machine_tasks)
190 |     print(f"Generate {args.num_instructions_to_generate} instructions in total. \
191 |         Found {exist_num} existing instructions. \
192 |         Still need to generate {args.num_instructions_to_generate - exist_num} instructions.")
193 | 
194 |     progress_bar = tqdm.tqdm(total=args.num_instructions_to_generate)
195 |     if machine_instructions:
196 |         progress_bar.update(exist_num)
197 | 
198 |     edit_type_idx = 0
199 |     while len(machine_instructions) < args.num_instructions_to_generate:
200 |         batch_inputs = []
201 |         for idx in range(min(batch_size, args.num_instructions_to_generate - len(machine_instructions))):
202 |             # Select an edit_type
203 |             edit_type = None
204 |             if args.use_edit_type:
205 |                 edit_type = EDIT_TYPES[edit_type_idx % len(EDIT_TYPES)]
206 |                 edit_type_idx += 1
207 | 
208 |             # Sample batch instructions from the pool
209 |             prompt_tasks = []
210 |             if machine_instructions:
211 |                 prompt_tasks = sample_items(machine_instructions, n=1)
212 |             prompt_tasks += sample_items(seed_instructions, args.num_prompt_instructions - len(prompt_tasks))
213 |             random.shuffle(prompt_tasks)
214 | 
215 |             # Encode the conversation
216 |             user_message_1, user_message_2 = encode_conversation(prompt_tasks, edit_type=edit_type)
217 |             batch_inputs.append({'user_message_1': user_message_1, 'user_message_2': user_message_2})
218 | 
219 |         # Query to generate instruction
220 |         runner = ParallelRunner(key=args.api_key, num_workers=args.proc_num, verbose=False)
221 |         results = runner.start(data=batch_inputs,
222 |                                query_func=query_for_instruction,
223 |                                output_filename=inter_inst_path,
224 |                                batch=True)
225 | 
226 |         # Filter and write results
227 |         instructions = [result['instruction'] for result in results]
228 |         all_instructions = seed_instructions + machine_instructions + instructions
229 |         scorer = rouge_scorer.RougeScorer(["rougeL"], use_stemmer=False)
230 | 
231 |         for inst in instructions:
232 |             with Pool(4) as p:
233 |                 rouge_scores = p.map(partial(scorer.score, inst), seed_instructions + machine_instructions)
234 |             rouge_scores = [score["rougeL"].fmeasure for score in rouge_scores]
235 |             if max(rouge_scores) > 0.7:
236 |                 continue
237 |             failure_indicators = [
238 |                 'no new task',
239 |                 'no task instruction',
240 |                 'no change',
241 |                 'no new instruction',
242 |                 'no new task',
243 |                 'sorry',
244 |             ]
245 |             if any([phrase in inst.lower() for phrase in failure_indicators]):
246 |                 continue
247 | 
248 |             machine_instructions.append(inst)
249 |             with open(machine_tasks_path, 'a') as f:
250 |                 f.write(json.dumps({'instruction': inst}) + '\n')
251 |             progress_bar.update(1)
252 | 
253 |             if len(machine_instructions) >= args.num_instructions_to_generate:
254 |                 break
255 | 
256 | 
257 |     print(f"Saved {len(machine_instructions)} instructions.")
--------------------------------------------------------------------------------
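For reference, the acceptance check in `src/bootstrap_instructions.py` above keeps a candidate instruction only if its highest ROUGE-L F-measure against the existing pool is at most 0.7 and it contains no failure phrase. Below is a minimal, self-contained sketch of that filter; the instruction strings are invented for illustration, and only the `rouge_score` package listed in `requirements.txt` is assumed.

```python
from rouge_score import rouge_scorer

# Existing pool of seed + machine-generated instructions (made-up examples).
existing_instructions = [
    "Add error handling and logging to the batch processing script.",
    "Refactor any code that uses 'exec' or 'eval' for security reasons.",
]
candidate = "Log all errors raised during batch processing to a file."

# Same scorer configuration as in bootstrap_instructions.py.
scorer = rouge_scorer.RougeScorer(["rougeL"], use_stemmer=False)
rouge_scores = [
    scorer.score(candidate, inst)["rougeL"].fmeasure for inst in existing_instructions
]

# Reject near-duplicates and obvious refusal/no-op responses.
failure_indicators = ['no new task', 'no task instruction', 'no change', 'no new instruction', 'sorry']
is_novel = max(rouge_scores) <= 0.7
is_valid = not any(phrase in candidate.lower() for phrase in failure_indicators)

print(f"max ROUGE-L = {max(rouge_scores):.2f}, keep candidate: {is_novel and is_valid}")
```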