├── .gitignore ├── LICENSE ├── README.md ├── eval ├── README.md ├── eval_bleu.py ├── generate_gpt_codes.py ├── merge_codes.py ├── reindent.py ├── sbatch │ ├── start_slurm_gen.sbatch │ ├── submit_all_jobs.sh │ └── test_all_sols.sbatch ├── test_one_solution.py └── testing_util.py ├── requirements.txt └── train ├── CustomTensorboardCallback.py ├── README.md ├── apps_create_split.py ├── dataset_apps └── APPSBaseDataset.py ├── dataset_lm ├── base_lm_dataset.py ├── reindent.py └── util.py ├── deepspeed_config.json └── tune_apps_gpt.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Project specific ignores 2 | APPS.tar.gz 3 | apps_dataset/* 4 | 5 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Dan Hendrycks 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Measuring Coding Challenge Competence With APPS 2 | This is the repository for [Measuring Coding Challenge Competence With APPS](https://arxiv.org/pdf/2105.09938) by 3 | [Dan Hendrycks\*](https://danhendrycks.com/), [Steven Basart\*](https://stevenbas.art), [Saurav Kadavath](http://www.sauravkadavath.com), Mantas Mazeika, [Akul Arora](https://github.com/akulaarora), Ethan Guo, [Collin Burns](http://collinpburns.com), Samir Puranik, [Horace He](http://horace.io), [Dawn Song](https://people.eecs.berkeley.edu/~dawnsong/), and [Jacob Steinhardt](https://www.stat.berkeley.edu/~jsteinhardt/). 4 | 5 | Download the [**APPS dataset here**](https://people.eecs.berkeley.edu/~hendrycks/APPS.tar.gz). (~1.3GB) 6 | 7 | This repository contains both training and evaluation code. 8 | 9 | Fine-tuned GPT-2 1.5B and GPT-Neo 2.7B weights are available [here](https://drive.google.com/file/d/1XW1Od9L-5l9zXl1HUCyER5pS9zQTbIvU/view?usp=sharing). 10 | 11 | For other benchmarks of enormous Transformers, see a dataset which tests ability in [competition math](https://github.com/hendrycks/math), a dataset which tests knowledge of [ethics](https://github.com/hendrycks/ethics), and [a dataset spanning 50+ academic subjects](https://github.com/hendrycks/test). 
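The dataset can also be loaded programmatically through the Hugging Face `datasets` library (see the Hugging Face section below); this is the same `codeparrot/apps` dataset that the evaluation scripts in this repository use. A minimal example:

    from datasets import load_dataset

    # Downloads and caches the APPS test split from the Hugging Face Hub
    apps_test = load_dataset("codeparrot/apps", split="test")
    print(apps_test[0]["question"][:200])  # first problem statement, truncated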
12 | 13 | ## How to Use 14 | 15 | The training instructions are specified in [train/README](train/README.md), and the evaluation instructions are specified in [eval/README](eval/README.md). 16 | 17 | ### Hugging Face 18 | 19 | The dataset is also available in [Hugging Face datasets](https://huggingface.co/datasets/codeparrot/apps) as `codeparrot/apps`. 20 | 21 | ## Citation 22 | 23 | If you find this useful in your research, please consider citing: 24 | 25 | @article{hendrycksapps2021, 26 | title={Measuring Coding Challenge Competence With APPS}, 27 | author={Dan Hendrycks and Steven Basart and Saurav Kadavath and Mantas Mazeika and Akul Arora and Ethan Guo and Collin Burns and Samir Puranik and Horace He and Dawn Song and Jacob Steinhardt}, 28 | journal={NeurIPS}, 29 | year={2021} 30 | } 31 | -------------------------------------------------------------------------------- /eval/README.md: --------------------------------------------------------------------------------
1 | # Evaluation 2 | 3 | Note: the evaluation code has been updated to pull the data from Hugging Face, which makes it easier to use. 4 | 5 | ## Single-Threaded Generation and Evaluation 6 | 7 | See the Slurm instructions below for how we parallelized generation and evaluation. 8 | 9 | ### Prerequisites 10 | 11 | pip install -r requirements.txt 12 | 13 | ### First, generate the code outputs 14 | 15 | python3 generate_gpt_codes.py --save /path/to/save_dir 16 | 17 | ### Second, evaluate the accuracy of the generated code 18 | 19 | python3 test_one_solution.py --save /path/to/save_dir 20 | # Because the above may fail on account of poorly generated Python programs, 21 | # we suggest running a for loop over each problem index in "all_codes.json": 22 | for i in {0..#Num_Problems#} ; do 23 | python3 test_one_solution.py --save /path/to/save_dir -i $i ; 24 | done 25 | 26 | The above prints the accuracy. To print it again once the evaluations have completed, run: 27 | 28 | python3 test_one_solution.py --save /path/to/save_dir --print_results 29 | 30 | ### Third, evaluate the BLEU scores of the generated code 31 | 32 | python3 eval_bleu.py --save /path/to/save_dir 33 | 34 | Note: the third step does not depend on the second step. 35 | 36 | ## Parallelized Slurm Evaluation 37 | 38 | Modify the APPS path in submit_all_jobs.sh so that it points to the evaluation folder, along with any other paths in that file as necessary. Install the requirements.txt file if you haven't already. 39 | 40 | ### Parallel Generation and Evaluation 41 | 42 | cd sbatch 43 | bash submit_all_jobs.sh 44 | 45 | ### Viewing the results 46 | 47 | Once the jobs have completed, we provide a utility script to combine the smaller result files into one larger file for easier processing. 48 | 49 | python3 merge_codes.py --root /path/to/slurm/save_files.json 50 | 51 | After the smaller files are combined, we can view the accuracy with the following: 52 | 53 | python3 test_one_solution.py --save /path/to/save_dir --print_results
-------------------------------------------------------------------------------- /eval/eval_bleu.py: -------------------------------------------------------------------------------- 1 | """ 2 | Evaluate outputs via the BLEU measure.
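Scores are computed with sacrebleu against each problem's reference solutions and written to a *_bleu_results.json file in the --save directory.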
3 | """ 4 | 5 | import json 6 | import logging 7 | import math 8 | import numpy as np 9 | import os 10 | import pprint 11 | import random 12 | import sys 13 | import time 14 | 15 | # for timing debugging 16 | from datetime import datetime, date 17 | from tqdm import tqdm 18 | 19 | from typing import List 20 | 21 | #bleu imports 22 | import sacrebleu 23 | from sacremoses import MosesDetokenizer 24 | md = MosesDetokenizer(lang='en') 25 | 26 | random.seed(12345678987654321) 27 | 28 | def calc_bleu(output:List[str], targets:List[List[str]]): 29 | max_bleu = 0 30 | bleu = sacrebleu.corpus_bleu(output, targets) 31 | for item in targets[0]: 32 | tmp_bleu = sacrebleu.corpus_bleu(output, [[item]]) 33 | if tmp_bleu.score > max_bleu: 34 | max_bleu = tmp_bleu.score 35 | return bleu.score, max_bleu 36 | 37 | def eval_and_save_bleu_scores(args): 38 | with open(args.test_loc, "r") as f: 39 | problems = json.load(f) 40 | 41 | gpt_codes = {} 42 | gpt_bleu = {} 43 | codes_loc = os.path.join(args.save, f"all_codes.json") 44 | if not os.path.exists(codes_loc): 45 | codes_loc = os.path.join(args.save, f"{args.start}-{args.end}_codes.json") 46 | 47 | if os.path.exists(codes_loc): 48 | with open(codes_loc, "r") as f: 49 | gpt_codes = json.load(f) 50 | 51 | if args.index: 52 | problems = [problems[args.index]] 53 | else: 54 | if args.start > len(problems) or args.start < 0: 55 | print(f"start index {args.start} > number of problems {len(problems)}") 56 | return 57 | start = args.start 58 | if args.end is None or args.end > len(problems): 59 | end = len(problems) 60 | else: 61 | end = args.end 62 | problems = problems[start:end] 63 | 64 | # main eval loop 65 | for index, problem in enumerate(tqdm(problems)): 66 | prob_path = os.path.join(args.root, problem) 67 | if args.debug: 68 | print(f"problem path = {problem}") 69 | try: 70 | output_strs = gpt_codes[str(index+args.start)] 71 | except: 72 | continue 73 | 74 | if args.debug: 75 | print(output_str) 76 | 77 | with open(os.path.join(prob_path, "solutions.json"), "r") as f: 78 | sols = json.load(f) 79 | 80 | random.shuffle(sols) 81 | if args.debug: 82 | sols = sols[:100] 83 | 84 | tmp = [] 85 | for sol in sols: 86 | tmp.append([sol]) 87 | 88 | sols = tmp 89 | 90 | # this is if we generated multiple outputs per problem 91 | if isinstance(output_strs, list): 92 | gpt_bleu[index+args.start] = [] 93 | for output_str in output_strs: 94 | gpt_bleu[index+args.start].extend(calc_bleu([output_str], sols)) 95 | # one output per problem 96 | else: 97 | output_str = output_strs 98 | gpt_bleu[index+args.start] = calc_bleu([output_str], sols) 99 | 100 | if not os.path.exists(args.save): 101 | os.makedirs(args.save) 102 | 103 | if args.end is None and args.index is None: 104 | bleu_loc = os.path.join(args.save, f"all_bleu_results.json") 105 | elif args.index: 106 | bleu_loc = os.path.join(args.save, f"{args.index}_bleu_results.json") 107 | else: 108 | bleu_loc = os.path.join(args.save, f"{args.start}-{args.end}_bleu_results.json") 109 | 110 | with open(bleu_loc, "w") as f: 111 | json.dump(gpt_bleu, f) 112 | 113 | return gpt_bleu 114 | 115 | def print_results(results): 116 | bleu_scores = [] 117 | max_bleu_scores = [] 118 | for res in results: 119 | bleu_scores.append(results[res][0]) 120 | max_bleu_scores.append(results[res][1]) 121 | 122 | print(f"Average BLEU Score = {np.mean(bleu_scores)}") 123 | print(f"Average of Max BLEU Score = {np.mean(max_bleu_scores)}") 124 | 125 | 126 | def main(args): 127 | 128 | argsdict = vars(args) 129 | print(pprint.pformat(argsdict)) 130 | 131 | 
if args.print_results: 132 | bleu_loc = os.path.join(args.save, f"all_bleu_results.json") 133 | if os.path.exists(bleu_loc): 134 | with open(bleu_loc, "r") as f: 135 | results = json.load(f) 136 | else: 137 | print(f"Error file does not exist in this path {bleu_loc}. Exiting.") 138 | return 139 | else: 140 | results = eval_and_save_bleu_scores(args) 141 | 142 | print_results(results) 143 | 144 | 145 | if __name__ == "__main__": 146 | import argparse 147 | 148 | parser = argparse.ArgumentParser(description="BLEU Evaluation") 149 | parser.add_argument("-t","--test_loc", default="../data_split/test.json", type=str) 150 | parser.add_argument("-r","--root", default="../", type=str, help="where the data is stored.") 151 | parser.add_argument("-s","--start", default=0, type=int) 152 | parser.add_argument("-e","--end", default=None, type=int) 153 | parser.add_argument("-i", "--index", default=None, type=int) 154 | parser.add_argument("-d", "--debug", action="store_true") 155 | parser.add_argument("-p", "--print_results", action="store_true", help="If you have already evaluated the results and only want to print them.") 156 | parser.add_argument("--save", type=str, default="./results") 157 | 158 | args = parser.parse_args() 159 | 160 | main(args) 161 | -------------------------------------------------------------------------------- /eval/generate_gpt_codes.py: -------------------------------------------------------------------------------- 1 | """ 2 | Run a tranined model to generate Python code. 3 | """ 4 | 5 | import io 6 | import json 7 | import logging 8 | import math 9 | import random 10 | import numpy as np 11 | import os 12 | import pprint 13 | import sys 14 | import time 15 | import transformers 16 | import torch 17 | 18 | from datasets import load_dataset 19 | from reindent import run as run_reindent 20 | from transformers import AutoTokenizer, AutoModelWithLMHead, AutoModelForCausalLM 21 | 22 | # for timing and debugging 23 | from datetime import datetime, date 24 | from tqdm import tqdm 25 | 26 | 27 | def reindent_code(codestr): 28 | """ 29 | Given code string, reindent it in the same way that the 30 | Github dataset was indented 31 | """ 32 | codestr = io.StringIO(codestr) 33 | ret = io.StringIO() 34 | 35 | run_reindent( 36 | codestr, 37 | ret, 38 | config = { 39 | "dry-run": False, 40 | "help": False, 41 | "to": 10, 42 | "from": -1, 43 | "tabs": True, 44 | "encoding": "utf-8", 45 | "is-tabs": False, 46 | "tabsize": 10, 47 | "all-tabs": False 48 | } 49 | ) 50 | 51 | return ret.getvalue() 52 | 53 | def generate_prompt(args, test_case, prompt, solutions, tokenizer, starter_code=None): 54 | _input = "\nQUESTION:\n" 55 | data = prompt 56 | _input += data 57 | if starter_code != None: 58 | data = starter_code 59 | data = "\n" + data #+ "\n" 60 | _input += data 61 | else: 62 | #_input += "\n\n" 63 | pass 64 | 65 | data = test_case 66 | if not data.get("fn_name"): 67 | _input += "\nUse Standard Input format"#\n" 68 | else: 69 | _input += "\nUse Call-Based format"#\n" 70 | 71 | _input += "\nANSWER:\n" 72 | 73 | if args.peeking > 0.0: 74 | # Need to do some peeking. 75 | 76 | # Read one example solution 77 | sols = solutions 78 | 79 | # Choose the shortest solution for the model to use. 
80 | # This is so we can conserve tokens (1024 max) 81 | # sample_sol = min(sols, key=len) 82 | 83 | # # Add args.peeking% of that solution to the prompt 84 | # sample_sol_token_ids = tokenizer.encode(sample_sol, verbose=False) 85 | # num_to_keep = int(len(sample_sol_token_ids) * args.peeking) 86 | # sample_sol_token_ids = sample_sol_token_ids[:num_to_keep] 87 | # _input += tokenizer.decode(sample_sol_token_ids) 88 | 89 | # Alternatively take a random solution 90 | sample_sol = random.choice(sols) 91 | rand_sol = reindent_code(sample_sol) 92 | rand_sol = tokenizer.encode(rand_sol, verbose=False) 93 | tokens_taken = int(args.peek_frac * len(rand_sol)) 94 | rand_sol = rand_sol[:tokens_taken] 95 | _input += tokenizer.decode(rand_sol) 96 | else: 97 | sample_sol = None 98 | 99 | return _input, sample_sol 100 | 101 | 102 | def main(args): 103 | 104 | argsdict = vars(args) 105 | print(pprint.pformat(argsdict)) 106 | 107 | problems = load_dataset("codeparrot/apps", split=f"{args.split}") 108 | 109 | gpt_codes = {} 110 | if not os.path.exists(args.save): 111 | os.makedirs(args.save, exist_ok=True) 112 | if not args.end: 113 | codes_loc = os.path.join(args.save, f"all_codes.json") 114 | else: 115 | codes_loc = os.path.join(args.save, f"{args.start}-{args.end}_codes.json") 116 | 117 | # Only do the problems that are specified. 118 | if args.index: 119 | problems = load_dataset("codeparrot/apps", split=f"{args.split}[{args.index}]") 120 | else: 121 | if args.start > len(problems) or args.start < 0: 122 | print(f"start index {args.start} > number of problems {len(problems)}") 123 | return 124 | start = args.start 125 | if args.end is None or args.end > len(problems): 126 | end = len(problems) 127 | else: 128 | end = args.end 129 | problems = load_dataset("codeparrot/apps", split=f"{args.split}[{start}:{end}]") 130 | 131 | 132 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 133 | 134 | if args.load: 135 | # Set up model 136 | # Tokenizer 137 | tokenizer = transformers.GPT2Tokenizer.from_pretrained(args.arch) 138 | print("Loading model...") 139 | model = transformers.GPT2LMHeadModel.from_pretrained(args.load) 140 | model.to(device) 141 | print(f"Loaded {args.load}.") 142 | else: 143 | tokenizer = AutoTokenizer.from_pretrained(args.arch) 144 | model = AutoModelForCausalLM.from_pretrained(args.arch, device_map="auto").eval() 145 | 146 | # main eval loop 147 | for index, problem in enumerate(tqdm(problems)): 148 | problem["solutions"] = json.loads(problem["solutions"]) 149 | problem["input_output"] = json.loads(problem["input_output"]) 150 | test_case = problem["input_output"] 151 | prompt = problem["question"] 152 | starter_code = problem["starter_code"] 153 | solutions = problem["solutions"] 154 | if not starter_code: 155 | starter_code = None 156 | 157 | # Read the question in 158 | prompt_text, sample_sol = generate_prompt(args, test_case, prompt, solutions, tokenizer, starter_code) 159 | if args.debug: 160 | print("PROMPT_TEXT:") 161 | print(prompt_text) 162 | 163 | # Feed this into the model. 
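        # The prompt built above is tokenized and decoded with beam search via
        # model.generate(); if generation raises an exception (e.g. the prompt
        # already exceeds the 1024-token GPT-2 context), output_str falls back to "".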
164 | start = time.time() 165 | try: 166 | with torch.no_grad(): 167 | input_ids = torch.LongTensor(tokenizer.encode(prompt_text, verbose=False)).unsqueeze(0).to(device) 168 | output_ids = model.generate( 169 | input_ids, 170 | num_beams=args.num_beams, 171 | early_stopping=True, 172 | max_length=1024 - len(input_ids) 173 | ) 174 | output_str = tokenizer.decode(output_ids[0]) 175 | except Exception as e: 176 | if isinstance(e, UnboundLocalError) and str(e) == "local variable 'next_tokens' referenced before assignment": 177 | # See https://github.com/huggingface/transformers/issues/5118 178 | if args.debug: 179 | print("Problem text was > 1024 tokens, so cannot do generation") 180 | print(e) 181 | else: 182 | print("Unexpected exception in generating solution") 183 | print(e) 184 | # Default to empty string on errors 185 | output_str = "" 186 | end = time.time() 187 | 188 | if args.peeking == 1.0: 189 | output_str = sample_sol 190 | elif len(output_str): 191 | output_str = output_str.split("ANSWER:\n")[1].replace("<|endoftext|>", "") 192 | 193 | # Save the generated sol 194 | gpt_codes[index+args.start] = output_str 195 | 196 | if args.debug: 197 | print(f"Generation time: {end - start}") 198 | print(f"Generated output string:") 199 | print(output_str) 200 | print("------------------------------------------------------------") 201 | 202 | with open(codes_loc, "w") as f: 203 | json.dump(gpt_codes, f) 204 | 205 | 206 | if __name__ == "__main__": 207 | import argparse 208 | 209 | parser = argparse.ArgumentParser(description="Run a tranined model to generate Python code.") 210 | parser.add_argument("--arch", default="gpt2") 211 | parser.add_argument("-t","--test_loc", default="~/apps/data_split/test.json", type=str, help="path to the test folder.") 212 | parser.add_argument("-r","--root", default="../", type=str, help="where the data is stored.") 213 | parser.add_argument("-l","--load", default="", type=str) 214 | parser.add_argument("--peeking", default=0.0, type=float) 215 | parser.add_argument("--num-beams", default=5, type=int) 216 | parser.add_argument("-s","--start", default=0, type=int) 217 | parser.add_argument("-e","--end", default=None, type=int) 218 | parser.add_argument("-i", "--index", default=None, type=int) 219 | parser.add_argument("-d", "--debug", action="store_true") 220 | parser.add_argument("--split", type=str, default="test", help="What split to use.") 221 | parser.add_argument("--save", type=str, default="./results") 222 | 223 | args = parser.parse_args() 224 | 225 | main(args) 226 | -------------------------------------------------------------------------------- /eval/merge_codes.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import numpy as np 4 | import os 5 | 6 | def combine_codes(args): 7 | result_files = os.listdir(args.root) 8 | tmp_codes = {} 9 | 10 | # load the results and combine them 11 | for r_file in result_files: 12 | path = os.path.join(args.root, r_file) 13 | if args.debug: 14 | print(path) 15 | elif "bleu" in path: 16 | continue 17 | elif "results.json" in path: 18 | continue 19 | elif "codes" in path and args.save not in path: 20 | with open(path, "r") as f: 21 | results = json.load(f) 22 | for res in results: 23 | tmp_codes[res] = results[res] 24 | continue 25 | with open(os.path.join(args.root, args.save), 'w') as f: 26 | json.dump(tmp_codes, f) 27 | 28 | 29 | def main(): 30 | parser = argparse.ArgumentParser() 31 | parser.add_argument("--debug", help="print debugging 
statements", 32 | action="store_true") 33 | parser.add_argument("--root", default="./results", type=str, help="which folder to merge the results") 34 | parser.add_argument("-s","--save", default="all_codes.json", type=str, help="Large final save file name. Note other files use the default value.") 35 | args = parser.parse_args() 36 | 37 | combine_codes(args) 38 | 39 | if __name__ == "__main__": 40 | main() 41 | -------------------------------------------------------------------------------- /eval/reindent.py: -------------------------------------------------------------------------------- 1 | 2 | """ 3 | Reindent files. 4 | """ 5 | 6 | from __future__ import print_function 7 | import sys 8 | import getopt 9 | import codecs 10 | import tempfile 11 | import shutil 12 | import os 13 | 14 | 15 | def _find_indentation(line, config): 16 | if len(line) and line[0] in (" ", "\t") and not line.isspace(): 17 | if line[0] == "\t": 18 | config['is-tabs'] = True 19 | # Find indentation 20 | i = 0 21 | for char in list(line): 22 | if char not in (" ", "\t"): 23 | break 24 | i += 1 25 | config["from"] = i 26 | 27 | 28 | def find_indentation(line, config): 29 | # Find indentation level used in file 30 | if config['from'] < 0: 31 | _find_indentation(line, config) 32 | 33 | if config['from'] >= 0: 34 | # Set old indent 35 | indent = " " if not config['is-tabs'] else "\t" 36 | indent = indent * config['from'] 37 | 38 | # Set new indent 39 | newindent = " " if not config['tabs'] else "\t" 40 | if not config['tabs']: 41 | newindent = newindent * config['to'] 42 | 43 | return indent, newindent 44 | 45 | # Continue to the next line, indentation not found 46 | return False 47 | 48 | 49 | def replace_inline_tabs(content, config): 50 | newcontent = "" 51 | imagined_i = 0 52 | for i in range(0, len(content)): 53 | char = content[i] 54 | if char == '\t': 55 | spaces = config['tabsize']-(imagined_i % config['tabsize']) 56 | newcontent += " " * spaces 57 | imagined_i += spaces 58 | else: 59 | newcontent += char 60 | imagined_i += 1 61 | return newcontent 62 | 63 | 64 | def run(fd_in, fd_out, config): 65 | while True: 66 | line = fd_in.readline() 67 | if not line: 68 | break 69 | line = line.rstrip('\r\n') 70 | 71 | # Find indentation style used in file if not set 72 | if config['from'] < 0: 73 | indent = find_indentation(line, config) 74 | if not indent: 75 | print(line, file=fd_out) 76 | continue 77 | indent, newindent = indent 78 | 79 | # Find current indentation level 80 | level = 0 81 | while True: 82 | whitespace = line[:len(indent) * (level + 1)] 83 | if whitespace == indent * (level + 1): 84 | level += 1 85 | else: 86 | break 87 | 88 | content = line[len(indent) * level:] 89 | if config['all-tabs']: 90 | content = replace_inline_tabs(content, config) 91 | 92 | line = (newindent * level) + content 93 | print(line, file=fd_out) 94 | 95 | 96 | def run_files(filenames, config): 97 | for filename in filenames: 98 | with codecs.open(filename, encoding=config['encoding']) as fd_in: 99 | if config['dry-run']: 100 | print("Filename: %s" % filename) 101 | fd_out = sys.stdout 102 | else: 103 | fd_out = tempfile.NamedTemporaryFile(mode='wb', delete=False) 104 | fd_out.close() 105 | fd_out = codecs.open(fd_out.name, "wb", encoding=config['encoding']) 106 | 107 | run(fd_in, fd_out, config) 108 | 109 | if not config["dry-run"]: 110 | fd_out.close() 111 | shutil.copy(fd_out.name, filename) 112 | os.remove(fd_out.name) 113 | 114 | 115 | def main(args): 116 | config = { 117 | "dry-run": False, 118 | "help": False, 119 | "to": 4, 
120 | "from": -1, 121 | "tabs": False, 122 | "encoding": "utf-8", 123 | "is-tabs": False, 124 | "tabsize": 4, 125 | "all-tabs": False 126 | } 127 | possible_args = { 128 | "d": "dry-run", 129 | "h": "help", 130 | "t:": "to=", 131 | "f:": "from=", 132 | "n": "tabs", 133 | "e:": "encoding=", 134 | "s:": "tabsize=", 135 | "a": "all-tabs", 136 | } 137 | optlist, filenames = getopt.getopt( 138 | args[1:], 139 | "".join(possible_args.keys()), 140 | possible_args.values() 141 | ) 142 | 143 | shortargs, longargs = [], [] 144 | for shortarg in possible_args: 145 | shortargs.append(shortarg.rstrip(":")) 146 | longargs.append(possible_args[shortarg].rstrip("=")) 147 | 148 | for opt, val in optlist: 149 | opt = opt.lstrip("-") 150 | if opt in shortargs: 151 | opt = longargs[shortargs.index(opt)] 152 | if isinstance(config[opt], bool): 153 | config[opt] = True 154 | elif isinstance(config[opt], int): 155 | config[opt] = int(val) 156 | else: 157 | config[opt] = val 158 | 159 | if config['help']: 160 | help = """ 161 | Usage: %s [options] filename(s) 162 | Options: 163 | -h, --help Show this message 164 | -d, --dry-run Don't save anything, just print 165 | the result 166 | -t , --to Convert to this number of spaces 167 | (default: 4) 168 | -f , --from Convert from this number of spaces 169 | (default: auto-detect, will also 170 | detect tabs) 171 | -n, --tabs Don't convert indentation to spaces, 172 | convert to tabs instead. -t and 173 | --to will have no effect. 174 | -a, --all-tabs Also convert tabs used for alignment 175 | in the code (Warning: will replace 176 | all tabs in the file, even if inside 177 | a string) 178 | -s , --tabsize Set how many spaces one tab is 179 | (only has an effect on -a, default: 4) 180 | -e , --encoding Open files with specified encoding 181 | (default: utf-8) 182 | """ % args[0] 183 | 184 | # Also removes 8 leading spaces to remove our indentation 185 | print("\n".join([x[8:] for x in help[1:].split("\n")])) 186 | sys.exit(0) 187 | 188 | if filenames: 189 | run_files(filenames, config) 190 | else: 191 | run(sys.stdin, sys.stdout, config) 192 | 193 | if __name__ == "__main__": 194 | main(sys.argv) 195 | -------------------------------------------------------------------------------- /eval/sbatch/start_slurm_gen.sbatch: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH -p gpu_jsteinhardt 3 | #SBATCH -w balrog-gpu 4 | #SBATCH --gres=gpu:1 5 | #SBATCH --output=R-%j.out 6 | 7 | echo $1 $2 $3 $4 $5 $6 8 | pushd . 
9 | 10 | cd $1 11 | echo "python generate_gpt_codes.py -s $2 -e $3 --save $4 --load $5 --test_loc $6" 12 | python generate_gpt_codes.py -s $2 -e $3 --save $4 --load $5 --test_loc $6 --debug 13 | 14 | popd 15 | mkdir slurm-output 16 | mv R-${SLURM_JOB_ID}.out slurm-output 17 | -------------------------------------------------------------------------------- /eval/sbatch/submit_all_jobs.sh: -------------------------------------------------------------------------------- 1 | APPS_EVAL_DIR="~/apps-beta/min_eval/" 2 | SKIP_AMT=20 3 | SAVE_LOC="~/apps-beta/min_eval/results" 4 | MODEL_LOC='~/apps-beta/modelling/checkpoints/final_checkpoint/' 5 | TEST_LOC="~/apps-beta/data_split/test.json" 6 | TOTAL_PROBLEMS=10640 7 | 8 | for (( i=0; i <= $TOTAL_PROBLEMS ; i+=$SKIP_AMT)) ; 9 | do 10 | echo "$frac $i" 11 | jid1=$(sbatch --parsable start_slurm_gen.sbatch $APPS_EVAL_DIR $i $(($i+$SKIP_AMT)) $SAVE_LOC $MODEL_LOC $TEST_LOC ) 12 | jid2=$(sbatch --dependency=afterany:$jid1 test_all_sols.sbatch $APPS_EVAL_DIR $i $(($i+$SKIP_AMT)) $SAVE_LOC $TEST_LOC ) 13 | done 14 | 15 | -------------------------------------------------------------------------------- /eval/sbatch/test_all_sols.sbatch: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH -p high_pre 3 | #SBATCH -w balrog-cpu 4 | #SBATCH -c 2 5 | #SBATCH --output=R-%j.out 6 | 7 | echo $1 $2 $3 $4 $5 8 | pushd . 9 | 10 | cd $1 11 | echo "python eval_bleu.py -s $2 -e $3 --save $4 --test_loc $5" 12 | python eval_bleu.py -s $2 -e $3 --save $4 --test_loc $5 #--debug 13 | echo "python test_one_solution.py -s $2 -e $3 --save $4 --test_loc $5" 14 | python test_one_solution.py -s $2 -e $3 --save $4 --test_loc $5 #--debug 15 | 16 | popd 17 | mkdir slurm-output 18 | mv R-${SLURM_JOB_ID}.out slurm-output 19 | 20 | -------------------------------------------------------------------------------- /eval/test_one_solution.py: -------------------------------------------------------------------------------- 1 | """ 2 | Run solutions from one problem. 3 | """ 4 | import argparse 5 | import json 6 | import numpy as np 7 | import os 8 | import pprint 9 | import multiprocessing 10 | import time 11 | import testing_util as test_util 12 | 13 | # for timing debugging 14 | from datetime import datetime, date 15 | from tqdm import tqdm 16 | 17 | from datasets import load_dataset 18 | from types import SimpleNamespace 19 | from typing import Dict 20 | 21 | 22 | EXAMPLE_RESULTS = {"0": [[-2]],"1": [[False,False,False]],"2": [[True,True]],"3": [[False,True,False,True,False,False,False,True,False,True,False,True,True,True,False,True]],"4": [[-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1]]} 23 | EXAMPLE_ARGS = SimpleNamespace(debug=True) 24 | TIMEOUT = 10 25 | 26 | def print_results(results: Dict, args:argparse.Namespace=None): 27 | """ 28 | Given the results evaluated against the testcases we output some statistics. 
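    Each problem index maps to a list of per-test results, where -2 = compile
    error, -1 = runtime error, False = failed test case, and True = passed test case.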
29 | 30 | >>> print_results(EXAMPLE_RESULTS, EXAMPLE_ARGS) 31 | number of compile errors = 1 avg = 0.2 32 | number of runtime errors = 1 avg = 0.2 33 | number of test cases run = 5 34 | Test Case Average (average accuracy over problems) = 0.3 35 | Strict Accuracy (all test cases passed / total problems) = 0.2 36 | """ 37 | res = [] 38 | per_prob_res = [] 39 | all_correct = [] 40 | for index in results: 41 | problem_results = np.asarray(results[index]) 42 | res.extend(problem_results) 43 | per_prob_res.append(np.mean(problem_results > 0)) 44 | all_correct.append(np.all(problem_results > 0)) 45 | 46 | # We count both compile errors and runtime errors for multiple tests as one error. 47 | compile_errors = len([e for e in res if -2 in e]) 48 | runtime_errors = len([e for e in res if -1 in e]) 49 | total_testcases = len(res) 50 | if args and args.debug: 51 | print(f"number of compile errors = {compile_errors} avg = {compile_errors / total_testcases }") 52 | print(f"number of runtime errors = {runtime_errors} avg = {runtime_errors / total_testcases}") 53 | print(f"number of test cases run = {total_testcases}") 54 | 55 | print(f"Test Case Average (average accuracy over problems) = {np.mean(per_prob_res)}") 56 | print(f"Strict Accuracy (all test cases passed / total problems) = {np.mean(all_correct)}") 57 | 58 | # Dummy `test_util.run_test` function for debugging multiprocessing. 59 | def run_test(problem, test, debug): 60 | time.sleep(1) # Simulate some work 61 | return [1] # Dummy test result 62 | 63 | def check_correctness(problem, generation, timeout, debug): 64 | """Check correctness of code generation with a global timeout. 65 | The global timeout is to catch some extreme/rare cases not handled by the timeouts 66 | inside `run_test`""" 67 | def _temp_run(problem, generation, debug, result): 68 | try: 69 | if debug: 70 | print(f"Running test for problem: {problem}") 71 | result.append(test_util.run_test(problem=problem, test=generation, debug=debug)) 72 | # Useful for debugging the multiprocessing. 73 | # result.append(run_test(problem=problem, test=generation, debug=debug)) 74 | if debug: 75 | print(f"Test completed with result: {result}") 76 | except Exception as e: 77 | if debug: 78 | print(f"Error in _temp_run: {e}") 79 | 80 | manager = multiprocessing.Manager() 81 | result = manager.list() 82 | p = multiprocessing.Process(target=_temp_run, args=(problem, generation, debug, result)) 83 | p.start() 84 | p.join(timeout=timeout + 1) 85 | if p.is_alive(): 86 | if debug: 87 | print(f"Process is still alive. 
Killing the process.") 88 | p.kill() 89 | if not result: 90 | # Remark: ideally we would consider that all tests failed but we can't access number of tests here easily 91 | # so we use 21=the average number of tests for a smaple in the test split instead 92 | avg_number_tests = 21 93 | result = [[-1] * avg_number_tests] 94 | if debug: 95 | print(f"Global timeout occurred, returning default result.") 96 | if debug: 97 | print(f"Final result: {result}") 98 | return result[0] 99 | 100 | 101 | def eval_and_save_problems(args): 102 | problems = load_dataset("codeparrot/apps", split=f"{args.split}") 103 | 104 | codes = {} 105 | gpt_bleu = {} 106 | gpt_codebleu = {} 107 | results = {} 108 | codes_loc = os.path.join(args.save, f"all_codes.json") 109 | if not os.path.exists(codes_loc): 110 | codes_loc = os.path.join(args.save, f"{args.start}-{args.end}_codes.json") 111 | 112 | if os.path.exists(codes_loc): 113 | results_loc = os.path.join(args.save, f"all_results.json") 114 | else: 115 | results_loc = os.path.join(args.save, f"{args.start}-{args.end}_results.json") 116 | # print(codes_loc, results_loc) 117 | 118 | with open(codes_loc, "r") as f: 119 | codes = json.load(f) 120 | 121 | # Only do the problems that are specified. 122 | if args.index: 123 | problems = load_dataset("codeparrot/apps", split=f"{args.split}[{args.index}]") 124 | else: 125 | if args.start > len(problems) or args.start < 0: 126 | print(f"start index {args.start} > number of problems {len(problems)}") 127 | return 128 | start = args.start 129 | if args.end is None or args.end > len(problems): 130 | end = len(problems) 131 | else: 132 | end = args.end 133 | problems = load_dataset("codeparrot/apps", split=f"{args.split}[{start}:{end}]") 134 | 135 | if args.stop_early: 136 | problems = load_dataset("codeparrot/apps", split=f"{args.split}[{start}:{args.stop_early}]") 137 | 138 | # main eval loop 139 | for index, problem in enumerate(tqdm(problems)): 140 | try: 141 | if isinstance(codes, dict): 142 | output_strings = codes[str(index+args.start)] 143 | else: 144 | output_strings = codes[index+args.start] 145 | except: 146 | # print("CANNOT FIND OUTPUT_STR FOR", problem) 147 | continue 148 | 149 | problem["solutions"] = json.loads(problem["solutions"]) 150 | problem["input_output"] = json.loads(problem["input_output"]) 151 | sols = problem["solutions"] 152 | 153 | if not os.path.exists(args.save): 154 | os.makedirs(args.save) 155 | 156 | res = [] 157 | if isinstance(output_strings, str): 158 | output_strings = [output_strings] 159 | for generation_idx, generation in enumerate(output_strings): 160 | if args.debug: 161 | print(f"\nTesting solution {generation_idx}, {generation=}") 162 | curr_res = [-2] 163 | try: 164 | curr_res = check_correctness(problem, generation=generation, timeout=TIMEOUT, debug=args.debug) 165 | fixed = [] 166 | for e in curr_res: 167 | if isinstance(e, np.ndarray): 168 | e = e.item(0) 169 | if isinstance(e, np.bool_): 170 | e = bool(e) 171 | fixed.append(e) 172 | curr_res = fixed 173 | if not np.all(curr_res): 174 | print(f"Results were not all True: {curr_res}") 175 | except Exception as e: 176 | print(f"test framework exception = {repr(e)}{e}\n") 177 | break 178 | finally: 179 | assert isinstance(curr_res, list) 180 | res.append(curr_res) 181 | 182 | if args.debug: 183 | print(f"\nHow to read results [-2] = compile error, [-1] = runtime error, [False] = failed test case, [True] = passed test case") 184 | #print(f"results = {res}") 185 | 186 | results[index+args.start+args.index] = res 187 | 188 | with 
open(results_loc, "w") as f: 189 | try: 190 | f.write(json.dumps(results)) 191 | except Exception as e: 192 | import pdb; pdb.set_trace() 193 | print("didn't save problem due to {e}") 194 | 195 | return results 196 | 197 | 198 | def main(args): 199 | 200 | argsdict = vars(args) 201 | print(pprint.pformat(argsdict)) 202 | 203 | if args.print_results: 204 | results = {} 205 | results_loc = os.path.join(args.save, f"all_results.json") 206 | if os.path.exists(results_loc): 207 | results_loc = os.path.join(args.save, f"all_results.json") 208 | elif os.path.exists(f"{args.start}-{args.end}_results.json"): 209 | results_loc = os.path.join(args.save, f"{args.start}-{args.end}_results.json") 210 | else: 211 | print("No results to print exiting.") 212 | exit() 213 | 214 | with open(results_loc, "r") as f: 215 | results = json.load(f) 216 | print_results(results, args) 217 | exit() 218 | 219 | if not args.skip_evals: 220 | results = eval_and_save_problems(args) 221 | 222 | print_results(results, args) 223 | 224 | 225 | if __name__ == "__main__": 226 | # import doctest 227 | # doctest.testmod() 228 | 229 | parser = argparse.ArgumentParser(description="Testing a Language Model on Python Code") 230 | parser.add_argument("-t","--test_loc", default="../data_split/test.json", type=str, help="path to the json containing problem paths to be evaluated.") 231 | parser.add_argument("-r","--root", default="../", type=str, help="where the data is stored.") 232 | parser.add_argument("-s","--start", default=0, type=int) 233 | parser.add_argument("-e","--end", default=None, type=int, help="If you want to evaluate a subset of problems specify start and ending index. File with start and ending prefix must exist typically used with batch evaluation.") 234 | parser.add_argument("-i", "--index", default=0, type=int) 235 | parser.add_argument("-p", "--print_results", action="store_true", help="If you have already evaluated the results and only want to print them.") 236 | parser.add_argument("--skip_evals", action="store_true", help="If you want to skip the evals similar to print results.") 237 | parser.add_argument("-d", "--debug", action="store_true") 238 | parser.add_argument("--save", type=str, default="./results", help="Where the evaluated data is loaded from and results saved to.") 239 | parser.add_argument("--split", type=str, default="test", help="What split to use.") 240 | parser.add_argument("--stop-early", default=None, type=int) 241 | 242 | args = parser.parse_args() 243 | 244 | main(args) 245 | -------------------------------------------------------------------------------- /eval/testing_util.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | import sys 5 | import io 6 | import faulthandler 7 | import platform 8 | 9 | # used for debugging to time steps 10 | from datetime import datetime 11 | 12 | # to run the solution files we're using a timing based approach 13 | import signal 14 | 15 | import numpy as np 16 | # for capturing the stdout 17 | from io import StringIO 18 | from typing import get_type_hints 19 | from typing import List, Tuple 20 | # used for testing the code that reads from input 21 | from unittest.mock import patch, mock_open 22 | 23 | from pyext import RuntimeModule 24 | 25 | from enum import Enum 26 | class CODE_TYPE(Enum): 27 | call_based = 0 28 | standard_input = 1 29 | 30 | # stuff for setting up signal timer 31 | class TimeoutException(Exception): 32 | pass 33 | def timeout_handler(signum, frame): 34 | 
print("alarm went off") 35 | #return 36 | raise TimeoutException 37 | signal.signal(signal.SIGALRM, timeout_handler) 38 | timeout = 4 # seconds 39 | 40 | # used to capture stdout as a list 41 | # from https://stackoverflow.com/a/16571630/6416660 42 | # alternative use redirect_stdout() from contextlib 43 | class Capturing(list): 44 | def __enter__(self): 45 | self._stdout = sys.stdout 46 | sys.stdout = self._stringio = StringIO() 47 | # Make closing the StringIO a no-op 48 | self._stringio.close = lambda x: 1 49 | return self 50 | def __exit__(self, *args): 51 | self.extend(self._stringio.getvalue().splitlines()) 52 | del self._stringio # free up some memory 53 | sys.stdout = self._stdout 54 | 55 | 56 | def parse_args(): 57 | parser = argparse.ArgumentParser(description="Utility for testing code generation.") 58 | parser.add_argument("-v", "--verbosity-level", action="store", type=int, 59 | help="") 60 | parser.add_argument("-s", "--source", type=str, default="leetcode", 61 | choices=["leetcode", "atcoder", "codewars",], 62 | help="which data source to gather from.") 63 | parser.add_argument("-d", "--data", type=str, default="question", 64 | choices=["question", "q", "solutions", "sol", "s", "starter", "tests", "t"], 65 | help="which type of data to receive.") 66 | parser.add_argument("-n", "--number", type=int, default=0, 67 | help="which problem to query.") 68 | 69 | args = parser.parse_args() 70 | return args 71 | 72 | 73 | def get_valid_problems(data_dir="leetcode"): 74 | # these are unnecessary atm 75 | if data_dir == "leetcode": 76 | root = os.path.join(args.source, "data") 77 | elif data_dir == "atcoder": 78 | pass 79 | 80 | root = os.path.join(data_dir, "data") 81 | if os.path.exists(os.path.join(data_dir, "valid_problems.json")): 82 | with open(os.path.join(data_dir, "valid_problems.json"), "r") as f: 83 | return json.load(f) 84 | 85 | # after we compute it once let's save it and load that instead 86 | # TODO determine if might be better to reload each time 87 | tmp = os.listdir(root) 88 | valid_probs = [] 89 | for folder in tmp: 90 | prob_path = os.path.join(root, folder) 91 | files = os.listdir(prob_path) 92 | #TODO add more validity checks 93 | if "input_output.json" in files or "sols.json" in files: 94 | valid_probs.append(prob_path) 95 | valid_probs = sorted(valid_probs) 96 | #with open(os.path.join(args.source,"valid_problems.json"), "w") as f: 97 | # json.dump(valid_probs, f) 98 | return valid_probs 99 | 100 | 101 | def get_question(problem_list, prob_index): 102 | root = problem_list[prob_index] 103 | #print("get q", root) 104 | if os.path.exists(os.path.join(root, "question.txt")): 105 | with open(os.path.join(root, "question.txt")) as f: 106 | question = f.readlines() 107 | else: 108 | print("question prompt not found") 109 | question = "" 110 | question = "".join(question) 111 | return question 112 | 113 | 114 | def get_solutions(problem_list, prob_index): 115 | root = problem_list[prob_index] 116 | if os.path.exists(os.path.join(root, "solutions.json")): 117 | with open(os.path.join(root, "solutions.json")) as f: 118 | sols = json.load(f) 119 | return sols 120 | 121 | 122 | def run_test(problem=None, problem_list:List[str]=None, prob_index:int=None, 123 | test:str=None, debug:bool=False): 124 | """ 125 | if test is not None it'll try to run the code. 126 | otherwise it'll just return an input and output pair. 
127 | """ 128 | 129 | if debug: 130 | print(f"start = {datetime.now().time()}") 131 | 132 | if problem_list is not None: 133 | root = problem_list[prob_index] 134 | 135 | 136 | in_outs = problem["input_output"] 137 | if debug: 138 | print(f"test cases json = {in_outs['inputs']} {in_outs['outputs']}") 139 | 140 | if in_outs.get("fn_name") is None: 141 | which_type = CODE_TYPE.standard_input # Standard input 142 | method_name = None 143 | else: 144 | which_type = CODE_TYPE.call_based # Call-based 145 | method_name = in_outs["fn_name"] 146 | if debug: 147 | print(f"loaded json = {datetime.now().time()}") 148 | 149 | if test is None: 150 | return in_outs 151 | elif test is not None: 152 | # Disable functionalities that can make destructive changes to the test. 153 | reliability_guard() 154 | 155 | results = [] 156 | sol = "import sys\nimport time\nimport itertools\nfrom itertools import accumulate, product, permutations, combinations\nimport collections\nfrom collections import Counter, OrderedDict, deque, defaultdict, ChainMap\nfrom functools import lru_cache\nimport math\nfrom math import sqrt, sin, cos, tan, ceil, fabs, floor, gcd, exp, log, log2\nimport fractions\nfrom typing import List, Tuple\nimport numpy as np\nimport random\nimport heapq\nfrom heapq import *\n" 157 | if debug: 158 | print(f"loading test code = {datetime.now().time()}") 159 | 160 | if which_type == CODE_TYPE.call_based: 161 | sol += test 162 | if debug: # or True: 163 | print(f"sol = {sol}") 164 | signal.alarm(timeout) 165 | try: 166 | tmp_sol = RuntimeModule.from_string("tmp_sol", "", sol) 167 | if "class Solution" not in test: 168 | tmp = tmp_sol 169 | else: 170 | tmp = tmp_sol.Solution() 171 | signal.alarm(0) 172 | except Exception as e: 173 | signal.alarm(0) 174 | print(f"type 0 compilation error = {e}") 175 | results.append(-2) 176 | return results 177 | signal.alarm(0) 178 | 179 | elif which_type == CODE_TYPE.standard_input: 180 | # sol 181 | tmp_test = test.split("\n") 182 | 183 | new_test = [] 184 | for x in tmp_test: 185 | if (not x.startswith("from ")) and (not x.startswith("import ")): 186 | new_test.append("\t" + x + "\n") 187 | else: 188 | new_test.append(x + "\n") 189 | tmp_test = new_test 190 | 191 | new_test = "" 192 | started = False 193 | for i in tmp_test: 194 | if i.startswith("\t") and not started: 195 | new_test += "stdin = sys.stdin\nstdout = sys.stdout\n" 196 | new_test += "def code():\n" 197 | new_test += i 198 | started = True 199 | elif started and ((i.startswith("from ")) or (i.startswith("import "))): 200 | new_test += "\t" + i 201 | else: 202 | new_test += i 203 | tmp_test = new_test 204 | 205 | sol += tmp_test 206 | if debug: 207 | print(f"sol = {sol}") 208 | # print(f"{o}") 209 | method_name = "code" 210 | signal.alarm(timeout) 211 | try: 212 | tmp_sol = RuntimeModule.from_string("tmp_sol", "", sol) 213 | tmp = tmp_sol 214 | signal.alarm(0) 215 | except Exception as e: 216 | signal.alarm(0) 217 | print(f"type 1 compilation error = {e}") 218 | results.append(-2) 219 | return results 220 | signal.alarm(0) 221 | if debug: 222 | print(f"get method = {datetime.now().time()}") 223 | 224 | try: 225 | method = getattr(tmp, method_name) # get_attr second arg must be str 226 | except: 227 | signal.alarm(0) 228 | e = sys.exc_info() 229 | print(f"unable to get function error = {e}") 230 | results.append(-2) 231 | return results 232 | 233 | for index, inputs in enumerate(in_outs["inputs"]): 234 | # JSON forces dictionaries to have string keys; this undoes this (assuming a singleton list) 235 | try: 
236 | if isinstance(inputs[0], dict): 237 | inputs = [{int(k): v for k,v in inputs[0].items()}] 238 | except: 239 | True 240 | try: 241 | if isinstance(in_outs["outputs"][index], dict): 242 | in_outs["outputs"][index] = [{int(k): v for k,v in in_outs["outputs"][index].items()}] 243 | except: 244 | True 245 | try: 246 | if isinstance(in_outs["outputs"][index][0], dict): 247 | in_outs["outputs"][index] = [{int(k): v for k,v in in_outs["outputs"][index][0].items()}] 248 | except: 249 | True 250 | 251 | if debug: 252 | print(f"time: {datetime.now().time()} testing index = {index} inputs = {inputs}, {type(inputs)}. type = {which_type}") 253 | if which_type == CODE_TYPE.call_based: # Call-based 254 | signal.alarm(timeout) 255 | faulthandler.enable() 256 | try: 257 | # print("------------") 258 | # print(inputs) 259 | output = method(*inputs) 260 | 261 | # ground truth sequences are not tuples 262 | if isinstance(output, tuple): 263 | output = list(output) 264 | 265 | tmp_result = output == in_outs["outputs"][index] 266 | if isinstance(in_outs["outputs"][index], list) and in_outs["outputs"][index]: 267 | tmp_result = tmp_result or (output == in_outs["outputs"][index][0]) 268 | 269 | # ground truth sequences are not tuples 270 | try: 271 | if isinstance(output[0], tuple): 272 | tmp_result = tmp_result or ([list(x) for x in output] == in_outs["outputs"][index][0]) 273 | except: 274 | True 275 | results.append(tmp_result) 276 | 277 | # reset the alarm 278 | signal.alarm(0) 279 | except Exception as e: 280 | signal.alarm(0) 281 | faulthandler.disable() 282 | print(f"Standard input runtime error or time limit exceeded error = {e}") 283 | results.append(-1) 284 | continue 285 | faulthandler.disable() 286 | signal.alarm(0) 287 | if debug: 288 | print(f"outputs = {output}, test outputs = {in_outs['outputs'][index]}, inputs = {inputs}, {type(inputs)}, {output == [in_outs['outputs'][index]]}") 289 | elif which_type == CODE_TYPE.standard_input: # Standard input 290 | faulthandler.enable() 291 | signal.alarm(timeout) 292 | passed = False 293 | 294 | if isinstance(inputs, list): 295 | inputs = "\n".join(inputs) 296 | if isinstance(in_outs['outputs'][index], list): 297 | in_outs['outputs'][index] = "\n".join(in_outs['outputs'][index]) 298 | 299 | with Capturing() as output: 300 | try: 301 | call_method(method, inputs) 302 | # reset the alarm 303 | signal.alarm(0) 304 | passed = True 305 | except Exception as e: 306 | # runtime error or took too long 307 | signal.alarm(0) 308 | print(f"Call-based runtime error or time limit exceeded error = {repr(e)}{e}") 309 | results.append(-1) 310 | signal.alarm(0) 311 | 312 | if not passed: 313 | if debug: 314 | nl = "\n" 315 | if not isinstance(inputs, list): 316 | print(f"not passed output = {output}, test outputs = {in_outs['outputs'][index]}, inputs = {inputs.replace(nl,' new-line ')}, {type(inputs)}, {output == [in_outs['outputs'][index]]}") 317 | else: 318 | print(f"not passed output = {output}, test outputs = {in_outs['outputs'][index]}, inputs = {inputs}, {type(inputs)}, {output == [in_outs['outputs'][index]]}") 319 | continue 320 | 321 | if passed and debug: 322 | print(f"==> output = {output}, test outputs = {in_outs['outputs'][index]}") 323 | 324 | if custom_compare_(output, in_outs['outputs'][index]): 325 | tmp_result = True 326 | results.append(tmp_result) 327 | continue 328 | 329 | # ground truth sequences are expressed as lists not tuples 330 | if isinstance(output, tuple): 331 | output = list(output) 332 | 333 | tmp_result = False 334 | try: 335 | tmp_result 
= (output == [in_outs["outputs"][index]]) 336 | if isinstance(in_outs["outputs"][index], list): 337 | tmp_result = tmp_result or (output == in_outs["outputs"][index]) 338 | if isinstance(output[0], str): 339 | tmp_result = tmp_result or ([e.strip() for e in output] == in_outs["outputs"][index]) 340 | except Exception as e: 341 | print(f"Failed check1 exception = {e}") 342 | pass 343 | 344 | if tmp_result == True: 345 | results.append(tmp_result) 346 | continue 347 | 348 | # try one more time without \n 349 | if isinstance(in_outs["outputs"][index], list): 350 | for tmp_index, i in enumerate(in_outs["outputs"][index]): 351 | in_outs["outputs"][index][tmp_index] = i.split("\n") 352 | in_outs["outputs"][index][tmp_index] = [x.strip() for x in in_outs["outputs"][index][tmp_index] if x] 353 | else: 354 | in_outs["outputs"][index] = in_outs["outputs"][index].split("\n") 355 | in_outs["outputs"][index] = list(filter(len, in_outs["outputs"][index])) 356 | in_outs["outputs"][index] = list(map(lambda x:x.strip(), in_outs["outputs"][index])) 357 | 358 | try: 359 | tmp_result = (output == [in_outs["outputs"][index]]) 360 | if isinstance(in_outs["outputs"][index], list): 361 | tmp_result = tmp_result or (output == in_outs["outputs"][index]) 362 | except Exception as e: 363 | print(f"Failed check2 exception = {e}") 364 | pass 365 | 366 | if tmp_result == True: 367 | results.append(tmp_result) 368 | continue 369 | 370 | # try by converting the output into a split up list too 371 | if isinstance(output, list): 372 | output = list(filter(len, output)) 373 | 374 | if debug: 375 | nl = "\n" 376 | if not isinstance(inputs, list): 377 | print(f"output = {output}, test outputs = {in_outs['outputs'][index]}, inputs = {inputs.replace(nl,' new-line ')}, {type(inputs)}, {output == [in_outs['outputs'][index]]}") 378 | else: 379 | print(f"output = {output}, test outputs = {in_outs['outputs'][index]}, inputs = {inputs}, {type(inputs)}, {output == [in_outs['outputs'][index]]}") 380 | 381 | if tmp_result == True: 382 | results.append(tmp_result) 383 | continue 384 | 385 | try: 386 | tmp_result = (output == [in_outs["outputs"][index]]) 387 | if isinstance(in_outs["outputs"][index], list): 388 | tmp_result = tmp_result or (output == in_outs["outputs"][index]) 389 | except Exception as e: 390 | print(f"Failed check3 exception = {e}") 391 | pass 392 | 393 | try: 394 | output_float = [float(e) for e in output] 395 | gt_float = [float(e) for e in in_outs['outputs'][index]] 396 | tmp_result = tmp_result or ((len(output_float) == len(gt_float)) and np.allclose(output_float, gt_float)) 397 | except Exception as e: 398 | pass 399 | try: 400 | if isinstance(output[0], list): 401 | output_float = [float(e) for e in output[0]] 402 | gt_float = [float(e) for e in in_outs['outputs'][index][0]] 403 | tmp_result = tmp_result or ((len(output_float) == len(gt_float)) and np.allclose(output_float, gt_float)) 404 | except Exception as e: 405 | pass 406 | 407 | if tmp_result == True: 408 | results.append(tmp_result) 409 | continue 410 | 411 | # try by converting the stuff into split up list 412 | if isinstance(in_outs["outputs"][index], list): 413 | for tmp_index, i in enumerate(in_outs["outputs"][index]): 414 | in_outs["outputs"][index][tmp_index] = set(i.split()) 415 | else: 416 | in_outs["outputs"][index] = set(in_outs["outputs"][index].split()) 417 | 418 | try: 419 | tmp_result = (output == in_outs["outputs"][index]) 420 | except Exception as e: 421 | print(f"Failed check4 exception = {e}") 422 | continue 423 | 424 | if tmp_result == 
True: 425 | results.append(tmp_result) 426 | continue 427 | 428 | # try by converting the output into a split up list too 429 | if isinstance(output, list): 430 | for tmp_index, i in enumerate(output): 431 | output[tmp_index] = i.split() 432 | output = list(filter(len, output)) 433 | for tmp_index, i in enumerate(output): 434 | output[tmp_index] = set(i) 435 | else: 436 | output = output.split() 437 | output = list(filter(len, output)) 438 | output = set(output) 439 | 440 | try: 441 | tmp_result = (set(frozenset(s) for s in output) == set(frozenset(s) for s in in_outs["outputs"][index])) 442 | except Exception as e: 443 | print(f"Failed check5 exception = {e}") 444 | 445 | 446 | # if they are all numbers, round so that similar numbers are treated as identical 447 | try: 448 | tmp_result = tmp_result or (set(frozenset(round(float(t),3) for t in s) for s in output) ==\ 449 | set(frozenset(round(float(t),3) for t in s) for s in in_outs["outputs"][index])) 450 | except Exception as e: 451 | print(f"Failed check6 exception = {e}") 452 | 453 | if tmp_result == True and debug: 454 | print("PASSED") 455 | 456 | results.append(tmp_result) 457 | 458 | if debug: 459 | nl = "\n" 460 | if not isinstance(inputs, list): 461 | print(f"output = {output}, test outputs = {in_outs['outputs'][index]}, inputs = {inputs.replace(nl,' new-line ')}, {type(inputs)}, {output == [in_outs['outputs'][index]]}") 462 | else: 463 | print(f"output = {output}, test outputs = {in_outs['outputs'][index]}, inputs = {inputs}, {type(inputs)}, {output == [in_outs['outputs'][index]]}") 464 | 465 | 466 | return results 467 | 468 | def custom_compare_(output, ground_truth): 469 | 470 | if isinstance(output, list): 471 | output_1 = "\n".join(output) 472 | if stripped_string_compare(output_1, ground_truth): 473 | return True 474 | 475 | if isinstance(output, list): 476 | output_2 = [o.lstrip().rstrip() for o in output] 477 | output_2 = "\n".join(output_2) 478 | if stripped_string_compare(output_2, ground_truth): 479 | return True 480 | 481 | return False 482 | 483 | def stripped_string_compare(s1, s2): 484 | s1 = s1.lstrip().rstrip() 485 | s2 = s2.lstrip().rstrip() 486 | return s1 == s2 487 | 488 | def call_method(method, inputs): 489 | 490 | if isinstance(inputs, list): 491 | inputs = "\n".join(inputs) 492 | 493 | inputs_line_iterator = iter(inputs.split("\n")) 494 | 495 | # sys.setrecursionlimit(10000) 496 | 497 | # @patch('builtins.input', side_effect=inputs.split("\n")) 498 | @patch('builtins.open', mock_open(read_data=inputs)) 499 | @patch('sys.stdin', StringIO(inputs)) 500 | @patch('sys.stdin.readline', lambda *args: next(inputs_line_iterator)) 501 | @patch('sys.stdin.readlines', lambda *args: inputs.split("\n")) 502 | @patch('sys.stdin.read', lambda *args: inputs) 503 | # @patch('sys.stdout.write', print) 504 | def _inner_call_method(_method): 505 | try: 506 | return _method() 507 | except SystemExit as e: 508 | pass 509 | finally: 510 | pass 511 | return _inner_call_method(method) 512 | 513 | def reliability_guard(maximum_memory_bytes=None): 514 | """ 515 | source: https://github.com/openai/human-eval 516 | This disables various destructive functions and prevents the generated code 517 | from interfering with the test (e.g. fork bomb, killing other processes, 518 | removing filesystem files, etc.) 519 | WARNING 520 | This function is NOT a security sandbox. Untrusted code, including, model- 521 | generated code, should not be blindly executed outside of one. 
See the 522 | Codex paper for more information about OpenAI's code sandbox, and proceed 523 | with caution. 524 | """ 525 | 526 | if maximum_memory_bytes is not None: 527 | import resource 528 | 529 | resource.setrlimit(resource.RLIMIT_AS, (maximum_memory_bytes, maximum_memory_bytes)) 530 | resource.setrlimit(resource.RLIMIT_DATA, (maximum_memory_bytes, maximum_memory_bytes)) 531 | if not platform.uname().system == "Darwin": 532 | resource.setrlimit(resource.RLIMIT_STACK, (maximum_memory_bytes, maximum_memory_bytes)) 533 | 534 | faulthandler.disable() 535 | 536 | import builtins 537 | 538 | builtins.exit = None 539 | builtins.quit = None 540 | 541 | import os 542 | 543 | os.environ["OMP_NUM_THREADS"] = "1" 544 | 545 | os.kill = None 546 | os.system = None 547 | os.putenv = None 548 | os.remove = None 549 | os.removedirs = None 550 | os.rmdir = None 551 | os.fchdir = None 552 | os.setuid = None 553 | os.fork = None 554 | os.forkpty = None 555 | os.killpg = None 556 | os.rename = None 557 | os.renames = None 558 | os.truncate = None 559 | os.replace = None 560 | os.unlink = None 561 | os.fchmod = None 562 | os.fchown = None 563 | os.chmod = None 564 | os.chown = None 565 | os.chroot = None 566 | os.fchdir = None 567 | os.lchflags = None 568 | os.lchmod = None 569 | os.lchown = None 570 | os.getcwd = None 571 | os.chdir = None 572 | 573 | import shutil 574 | 575 | shutil.rmtree = None 576 | shutil.move = None 577 | shutil.chown = None 578 | 579 | import subprocess 580 | 581 | subprocess.Popen = None # type: ignore 582 | 583 | __builtins__["help"] = None 584 | 585 | import sys 586 | 587 | sys.modules["ipdb"] = None 588 | sys.modules["joblib"] = None 589 | sys.modules["resource"] = None 590 | sys.modules["psutil"] = None 591 | sys.modules["tkinter"] = None 592 | 593 | def main(args): 594 | print(args) 595 | problem_list = sorted(get_valid_problems(args.source)) 596 | prob_index = args.number 597 | 598 | # This checks it correctly loaded. 
remove this later 599 | assert prob_index < len(problem_list) 600 | 601 | if args.data == "q" or args.data == "question": 602 | tmp = get_question(problem_list, prob_index) 603 | print("q", tmp) 604 | elif args.data in ["solutions", "sol", "s",]: 605 | tmp = get_solutions(problem_list, prob_index) 606 | print("sol", tmp) 607 | elif args.data == "starter": 608 | tmp = get_starter(problem_list, prob_index) 609 | print("starter", tmp) 610 | elif args.data in ["test", "t"]: 611 | # test it with sols 612 | sols = get_solutions(problem_list, prob_index) 613 | tmp = run_test(problem_list, prob_index, test=sols[0]) 614 | 615 | print("results = ", tmp) 616 | print("-2 = compile error, -1 is runtime error, False failed test, True passed test") 617 | 618 | if __name__ == "__main__": 619 | args = parse_args() 620 | main(args) 621 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # Used for evaluation and some common for train too 2 | numpy 3 | pyext>=0.7 4 | sacrebleu # we used 1.5.1 5 | sacremoses # we used 0.0.45 6 | torch>=1.7 7 | transformers>=4 8 | psutil # we used 5.7.0 9 | 10 | # Used during training 11 | deepspeed>=0.4.0 12 | tensorboardX>=2.2 13 | -------------------------------------------------------------------------------- /train/CustomTensorboardCallback.py: -------------------------------------------------------------------------------- 1 | import os 2 | import psutil 3 | import transformers 4 | from tensorboardX import SummaryWriter 5 | 6 | def get_system_info(): 7 | this = psutil.Process(os.getpid()) 8 | mem_usage_bytes = this.memory_info().rss 9 | mem_usage_gb = mem_usage_bytes / (1024 ** 3) 10 | 11 | total_mem_usage_gb = psutil.virtual_memory().used / (1024 ** 3) 12 | total_mem_usage_percent = psutil.virtual_memory().percent 13 | 14 | return { 15 | 'system/proc_mem_usage_gb' : mem_usage_gb, 16 | 'system/total_mem_usage_gb' : total_mem_usage_gb, 17 | 'system/total_mem_usage_percent' : total_mem_usage_percent 18 | } 19 | 20 | 21 | logger = transformers.utils.logging.get_logger(__name__) 22 | 23 | def rewrite_logs(d): 24 | new_d = {} 25 | eval_prefix = "eval_" 26 | eval_prefix_len = len(eval_prefix) 27 | for k, v in d.items(): 28 | if k.startswith(eval_prefix): 29 | new_d["eval/" + k[eval_prefix_len:]] = v 30 | else: 31 | new_d["train/" + k] = v 32 | return new_d 33 | 34 | 35 | class CustomTensorBoardCallback(transformers.trainer_callback.TrainerCallback): 36 | """ 37 | A :class:`~transformers.TrainerCallback` that sends the logs to `TensorBoard 38 | `__. 39 | 40 | Args: 41 | tb_writer (:obj:`SummaryWriter`, `optional`): 42 | The writer to use. Will instantiate one if not set. 
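Besides the standard Trainer metrics, every on_log call also records process and system memory usage from get_system_info() under the system/ prefix, and log keys are namespaced into train/ and eval/ by rewrite_logs().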
43 | """ 44 | 45 | def __init__(self, tb_writer=None): 46 | self.tb_writer = tb_writer 47 | 48 | def _init_summary_writer(self, args, log_dir=None): 49 | log_dir = log_dir or args.logging_dir 50 | self.tb_writer = SummaryWriter(log_dir=log_dir) 51 | 52 | def on_train_begin(self, args, state, control, **kwargs): 53 | if not state.is_world_process_zero: 54 | return 55 | 56 | log_dir = None 57 | 58 | if state.is_hyper_param_search: 59 | trial_name = state.trial_name 60 | if trial_name is not None: 61 | log_dir = os.path.join(args.logging_dir, trial_name) 62 | 63 | self._init_summary_writer(args, log_dir) 64 | 65 | if self.tb_writer is not None: 66 | self.tb_writer.add_text("args", args.to_json_string()) 67 | if "model" in kwargs: 68 | model = kwargs["model"] 69 | if hasattr(model, "config") and model.config is not None: 70 | model_config_json = model.config.to_json_string() 71 | self.tb_writer.add_text("model_config", model_config_json) 72 | # Version of TensorBoard coming from tensorboardX does not have this method. 73 | if hasattr(self.tb_writer, "add_hparams"): 74 | self.tb_writer.add_hparams(args.to_sanitized_dict(), metric_dict={}) 75 | 76 | def on_log(self, args, state, control, logs=None, **kwargs): 77 | 78 | logs = rewrite_logs(logs) 79 | logs.update(get_system_info()) 80 | 81 | if state.is_world_process_zero: 82 | if self.tb_writer is None: 83 | self._init_summary_writer(args) 84 | 85 | if self.tb_writer: 86 | for k, v in logs.items(): 87 | if isinstance(v, (int, float)): 88 | self.tb_writer.add_scalar(k, v, state.global_step) 89 | else: 90 | logger.warning( 91 | "Trainer is attempting to log a value of " 92 | '"%s" of type %s for key "%s" as a scalar. ' 93 | "This invocation of Tensorboard's writer.add_scalar() " 94 | "is incorrect so we dropped this attribute.", 95 | v, 96 | type(v), 97 | k, 98 | ) 99 | self.tb_writer.flush() 100 | 101 | def on_train_end(self, args, state, control, **kwargs): 102 | if self.tb_writer: 103 | self.tb_writer.close() 104 | -------------------------------------------------------------------------------- /train/README.md: -------------------------------------------------------------------------------- 1 | # Training 2 | 3 | ## How to train 4 | 5 | First use `apps_create_split.py` to create the `train.json` and `test.json`. Note the paths specified in `apps_create_split.py` should point to relative paths from training directory or absolute paths. 6 | 7 | We use the following command to run and train. Note the configuration file is called deepspeed_config.json. 
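As a rule of thumb, the effective global batch size under DeepSpeed is batch-size-per-replica × grad-acc-steps × number of GPUs, so adjust these flags together when moving to different hardware.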
8 | 9 | USE_TF=NO deepspeed tune_apps_gpt.py \ 10 | --save-dir=/path/to/save_dir \ 11 | --load=/path/to/model \ # Can be used to restart from checkpoint 12 | --apps-train-files ~/apps/train \ 13 | --apps-dataroot ~/apps/train/ \ 14 | --grad-acc-steps=8 \ 15 | --epochs=10 \ 16 | --fp16 \ 17 | --deepspeed deepspeed_config.json \ 18 | --batch-size-per-replica=2 19 | -------------------------------------------------------------------------------- /train/apps_create_split.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import pathlib 4 | 5 | def create_split(split="train", name="train"): 6 | paths = [] 7 | roots = sorted(os.listdir(split)) 8 | for folder in roots: 9 | root_path = os.path.join(split, folder) 10 | paths.append(root_path) 11 | 12 | 13 | with open(name+".json", "w") as f: 14 | json.dump(paths, f) 15 | 16 | return paths 17 | 18 | # insert path to train and test 19 | # path should be relative to root directory or absolute paths 20 | paths_to_probs = ["APPS/train", "APPS/test"] 21 | names = ["train", "test"] 22 | 23 | all_paths = [] 24 | for index in range(len(paths_to_probs)): 25 | all_paths.extend(create_split(split=paths_to_probs[index], name=names[index])) 26 | 27 | with open("train_and_test.json", "w") as f: 28 | print(f"Writing all paths. Length = {len(all_paths)}") 29 | json.dump(all_paths, f) 30 | -------------------------------------------------------------------------------- /train/dataset_apps/APPSBaseDataset.py: -------------------------------------------------------------------------------- 1 | """ 2 | Dataset to be used for APPS Training 3 | """ 4 | 5 | import torch 6 | import glob 7 | import logging 8 | import random 9 | import fnmatch 10 | 11 | from multiprocessing import Manager 12 | # from multiprocessing.shared_memory import ShareableList 13 | 14 | import dataset_lm.util as dsutil 15 | import numpy as np 16 | import gc 17 | import os 18 | import io 19 | 20 | import transformers 21 | 22 | from dataset_lm.reindent import run as run_reindent 23 | from tqdm import tqdm 24 | 25 | import json 26 | 27 | class APPSBaseDataset(torch.utils.data.Dataset): 28 | def __init__(self, dataroot, problem_dirs, mode, max_tokens, sample_mode): 29 | self.dataroot = dataroot 30 | self.problem_dirs = problem_dirs # Loaded from train/test split json files 31 | 32 | self.mode = mode 33 | self.sample_mode = sample_mode # Either "uniform_sol" or "uniform_prob" 34 | self.max_tokens = max_tokens 35 | 36 | self.samples = [] # Should be set in initialize() 37 | self.initialize() 38 | 39 | if ('EleutherAI' in mode or '2700' in mode): 40 | self.tokenizer = transformers.GPT2Tokenizer.from_pretrained("EleutherAI/gpt-neo-2.7B") 41 | elif 'gpt' in self.mode: # Should handle GPT-2 and GPT-Neo 42 | self.tokenizer = transformers.GPT2Tokenizer.from_pretrained(mode) 43 | elif self.mode in {'codebert'}: 44 | self.tokenizer = transformers.RobertaTokenizer.from_pretrained("microsoft/codebert-base") 45 | else: 46 | raise NotImplementedError() 47 | 48 | 49 | def initialize(self): 50 | """ 51 | Assume self.dataroot is set to folderName/data 52 | """ 53 | 54 | all_samples = [] 55 | skipped_problems = [] 56 | 57 | all_samples_dict = {} # Mapping from question_fname to list of samples 58 | 59 | print(f"Loading {len(self.problem_dirs)} problems from {self.dataroot}.") 60 | for problem_name in tqdm(self.problem_dirs): 61 | question_fname = os.path.join(self.dataroot, problem_name, "question.txt") 62 | sols_fname = os.path.join(self.dataroot, 
problem_name, "solutions.json") 63 | starter_code = os.path.join(self.dataroot, problem_name, "starter_code.py") 64 | 65 | # print(question_fname) 66 | 67 | if os.path.exists(starter_code): 68 | answer_type = "\nUse Call-Based format\n" 69 | else: 70 | answer_type = "\nUse Standard Input format\n" 71 | 72 | if (not os.path.isfile(question_fname)) or (not os.path.isfile(sols_fname)): 73 | skipped_problems.append(problem_name) 74 | continue 75 | 76 | if (os.path.isfile(starter_code)): 77 | with open(starter_code, 'r') as f: 78 | starter_code = f.read() 79 | else: 80 | starter_code = "" 81 | 82 | # Read the question description 83 | with open(question_fname, 'r') as f: 84 | question_str = f.read() 85 | 86 | # Read all the solutions 87 | with open(sols_fname, 'r') as f: 88 | sols_str_list = json.load(f) 89 | for sol_str in sols_str_list: 90 | sol_str = reindent_code(sol_str) 91 | sample = (question_str, starter_code, sol_str, answer_type) 92 | 93 | all_samples.append(sample) 94 | if question_str in all_samples_dict: 95 | all_samples_dict[question_str].append(sample) 96 | else: 97 | all_samples_dict[question_str] = [sample] 98 | 99 | print(f"Loaded {len(all_samples)} saamples from {self.dataroot}.") 100 | print(f"Skipped {len(skipped_problems)} problems from {self.dataroot}.") 101 | self.samples = all_samples 102 | self.samples_dict = all_samples_dict 103 | 104 | 105 | def __len__(self): 106 | return len(self.samples) 107 | 108 | 109 | def pack_samples(self, idx): 110 | """ 111 | Repeatedly pick question, answer pairs from self.dataroot until we hit max_tokens. 112 | This will not include the tokens for the QUESTION and ANSWER prompt, as well as the 113 | self.question_prefix. These will be added later and the total input will be 114 | truncated if necessary. 115 | 116 | Always include the sample at idx at the beginning. 117 | """ 118 | curr_num_tokens = 0 119 | curr_samples = [] 120 | 121 | if self.sample_mode == 'uniform_sol': 122 | curr_q, curr_s, curr_a, curr_q_prefix = self.samples[idx] 123 | elif self.sample_mode == 'uniform_prob': 124 | curr_q = random.choice(list(self.samples_dict.keys())) 125 | curr_q, curr_s, curr_a, curr_q_prefix = random.choice(self.samples_dict[curr_q]) 126 | else: 127 | raise NotImplementedError() 128 | 129 | while curr_num_tokens < self.max_tokens: 130 | 131 | # Never remove. Fixes stalling bug. 
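# Truncating to 150k characters bounds the tokenizer work on pathologically long questions/solutions.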
132 | curr_q = curr_q[:150000] 133 | curr_s = curr_s[:150000] 134 | curr_a = curr_a[:150000] 135 | 136 | if self.mode in {'codebert'}: 137 | curr_q = curr_q.replace('\t', '\0') 138 | curr_s = curr_s.replace('\t', '\0') 139 | curr_a = curr_a.replace('\t', '\0') 140 | 141 | curr_num_tokens += len(self.tokenizer.tokenize(curr_q)) 142 | curr_num_tokens += len(self.tokenizer.tokenize(curr_s)) 143 | curr_num_tokens += len(self.tokenizer.tokenize(curr_a)) 144 | 145 | curr_samples.append((curr_q, curr_s, curr_a, curr_q_prefix)) 146 | 147 | if self.sample_mode == 'uniform_sol': 148 | curr_q, curr_s, curr_a, curr_q_prefix = random.choice(self.samples) 149 | elif self.sample_mode == 'uniform_prob': 150 | curr_q = random.choice(list(self.samples_dict.keys())) 151 | curr_q, curr_s, curr_a, curr_q_prefix = random.choice(self.samples_dict[curr_q]) 152 | else: 153 | raise NotImplementedError() 154 | 155 | return curr_samples 156 | 157 | def __getitem__(self, idx): 158 | 159 | raw_samples = self.pack_samples(idx) 160 | 161 | if 'gpt' in self.mode: 162 | retval = sample_gpt_task( 163 | raw_samples, 164 | max_tokens=self.max_tokens, 165 | tokenizer=self.tokenizer, 166 | ) 167 | elif self.mode in {'codebert'}: 168 | retval = sample_gpt_task( 169 | raw_samples, 170 | max_tokens=self.max_tokens, 171 | tokenizer=self.tokenizer, 172 | ) 173 | else: 174 | raise NotImplementedError() 175 | 176 | gc.collect() 177 | return retval 178 | 179 | def sample_gpt_task(raw_samples, max_tokens, tokenizer): 180 | """ 181 | Create the true sample used for the GPT task 182 | """ 183 | 184 | input_ids = [] 185 | label_ids = [] 186 | 187 | for q_str, s_str, a_str, answer_type in raw_samples: 188 | 189 | # Loss is not calculated on this 190 | q_str = "\nQUESTION:\n" + q_str + "\n" + s_str + "\n" + answer_type + "\nANSWER:\n" 191 | 192 | question_token_ids = tokenizer.encode(q_str, verbose=False) 193 | answer_token_ids = tokenizer.encode(a_str, verbose=False) 194 | answer_token_ids.append(tokenizer.eos_token_id) 195 | 196 | input_ids.extend(question_token_ids) 197 | input_ids.extend(answer_token_ids) 198 | 199 | label_ids.extend([-100] * len(question_token_ids)) 200 | label_ids.extend(answer_token_ids) 201 | 202 | # Sanity check 203 | assert len(input_ids) == len(label_ids) 204 | 205 | if len(input_ids) < max_tokens: 206 | print(len(input_ids)) 207 | import pdb; pdb.set_trace() 208 | 209 | # Cut off the excess 210 | input_ids = input_ids[:max_tokens] 211 | label_ids = label_ids[:max_tokens] 212 | 213 | return { 214 | "input_ids" : torch.LongTensor(input_ids), 215 | "labels" : torch.LongTensor(label_ids) 216 | } 217 | 218 | 219 | def reindent_code(codestr): 220 | """ 221 | Given code string, reindent it in the same way that the 222 | Github dataset was indented 223 | """ 224 | codestr = io.StringIO(codestr) 225 | ret = io.StringIO() 226 | 227 | run_reindent( 228 | codestr, 229 | ret, 230 | config = { 231 | "dry-run": False, 232 | "help": False, 233 | "to": 4, 234 | "from": -1, 235 | "tabs": True, 236 | "encoding": "utf-8", 237 | "is-tabs": False, 238 | "tabsize": 4, 239 | "all-tabs": False 240 | } 241 | ) 242 | 243 | return ret.getvalue() 244 | 245 | 246 | if __name__ == '__main__': 247 | import json 248 | 249 | # Do sanity checking 250 | with open("~/apps/data_split/train.json") as f: 251 | fnames = json.load(f) 252 | 253 | tokenizer = transformers.GPT2Tokenizer.from_pretrained('gpt2') 254 | dataset = APPSBaseDataset( 255 | dataroot='~/apps/', 256 | problem_dirs=fnames, 257 | mode='gpt2', 258 | max_tokens=1024 259 | ) 260 | 261 | e 
= dataset[0] 262 | print(e) 263 | print("------- input_ids ------------------------------------------------------------------------------------") 264 | print(tokenizer.decode(e['input_ids'])) 265 | print("------- labels ------------------------------------------------------------------------------------") 266 | labels = e['labels'] 267 | labels[labels == -100] = tokenizer.eos_token_id 268 | labels_str = tokenizer.decode(labels) 269 | print(labels_str) 270 | 271 | import pdb; pdb.set_trace() 272 | -------------------------------------------------------------------------------- /train/dataset_lm/base_lm_dataset.py: -------------------------------------------------------------------------------- 1 | """ 2 | Dataset to be used for Language Modelling 3 | """ 4 | 5 | import torch 6 | import glob 7 | import logging 8 | import random 9 | import fnmatch 10 | 11 | from multiprocessing import Manager 12 | # from multiprocessing.shared_memory import ShareableList 13 | 14 | import dataset_lm.util as dsutil 15 | import numpy as np 16 | import gc 17 | import os 18 | import time 19 | 20 | import transformers 21 | 22 | class BaseLMDataset(torch.utils.data.Dataset): 23 | """Configurable LMDataset. 24 | """ 25 | 26 | def __init__(self, dataroots, mode, max_tokens, mask_probability=None, english_data=None): 27 | """Initializes the dataset with given configuration. 28 | Args: 29 | dataroot: str 30 | Glob format data. 31 | """ 32 | self.dataroots = dataroots 33 | self.mode = mode 34 | self.max_tokens = max_tokens 35 | self.mask_probability = mask_probability 36 | 37 | self.start_iteration = 0 # Set elsewhere right before training 38 | 39 | if self.mode == 'dummy': 40 | self.num_examples = 1000000 41 | else: 42 | 43 | if self.mode in {'gpt2', 'gpt2-medium'}: 44 | self.tokenizer = transformers.GPT2Tokenizer.from_pretrained(mode) 45 | elif self.mode in {'facebook/bart-large'}: 46 | self.tokenizer = transformers.BartTokenizer.from_pretrained(mode) 47 | else: 48 | raise NotImplementedError() 49 | 50 | # Fixes some memory leak issues 51 | # https://gist.github.com/mprostock/2850f3cd465155689052f0fa3a177a50 52 | # https://gist.github.com/vadimkantorov/86c3a46bf25bed3ad45d043ae86fff57 53 | manager = Manager() 54 | 55 | # Ensure ordering since we want to be able to resume in the middle of training 56 | # - glob.glob() does not guarantee the same ordering across machines or arcoss runs 57 | # - sorting() guarantees ordering but might give us entire batches from the same git repo. 58 | # - Setting random seed before shuffling should be reproducible across machines. 
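# Gather files from every globstr, sort for a stable base order, then shuffle with a fixed seed; the shuffled list is stored in a multiprocessing Manager list so dataloader workers share a single copy (see the gists above).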
59 | l = [] 60 | for dataroot_info in self.dataroots: 61 | globstr = dataroot_info['globstr'] 62 | print(f"Loading globstr {globstr}") 63 | l.extend(glob.glob(globstr)) 64 | 65 | l = sorted(l) 66 | random.seed(1234) 67 | random.shuffle(l) 68 | 69 | self.all_files = manager.list(l) 70 | del l 71 | 72 | self.num_examples = len(self.all_files) 73 | print(f"Found {self.num_examples} training examples") 74 | 75 | self.english_data = english_data 76 | print(f"English data has {len(self.english_data)} samples") 77 | 78 | def _get_english_fraction(self, filename): 79 | """ 80 | Given a filename, return the english fraction for that filename 81 | """ 82 | for dataroot_info in self.dataroots: 83 | globstr = dataroot_info['globstr'] 84 | english_frac = dataroot_info['english_frac'] 85 | if fnmatch.fnmatch(filename, globstr): 86 | return english_frac 87 | raise RuntimeError(f"{filename} does not match any globstr.") 88 | 89 | def _get_english_sample(self): 90 | sample_str = "" 91 | curr_num_tokens = 0 92 | while curr_num_tokens < self.max_tokens: 93 | rand_index = random.randint(0, len(self.english_data) - 10000) 94 | sample_str += self.english_data[rand_index]['text'] 95 | # print(f"{os.getpid()}: _get_english_sample 1") 96 | curr_num_tokens += len(self.tokenizer.tokenize(sample_str)) 97 | # print(f"{os.getpid()}: _get_english_sample 2") 98 | rand_index += 1 99 | return sample_str 100 | 101 | def __len__(self): 102 | return self.num_examples - self.start_iteration 103 | 104 | def __getitem__(self, index): 105 | # Each worker needs a different seed.... 106 | random.seed(os.getpid() + time.time()) 107 | 108 | index = index + self.start_iteration 109 | 110 | if self.mode == 'dummy': 111 | return dsutil.dummy_gpt_task( 112 | max_tokens=self.max_tokens 113 | ) 114 | 115 | # Get a file from self.all_files 116 | fname = self.all_files[index] 117 | english_frac = self._get_english_fraction(fname) 118 | if random.random() < english_frac: 119 | # Use English data 120 | sample_str = self._get_english_sample() 121 | else: 122 | with open(fname, 'r') as f: 123 | sample_str = f.read() 124 | 125 | # Never remove. Fixes stalling bug. 
126 | sample_str = sample_str[:150000] 127 | 128 | if self.mode in {'gpt2', 'gpt2-medium'}: 129 | retval = dsutil.batch_gpt_task( 130 | sample_str, 131 | max_tokens=self.max_tokens, 132 | tokenizer=self.tokenizer, 133 | ) 134 | elif self.mode in {'facebook/bart-large'}: 135 | retval = dsutil.batch_bart_task( 136 | sample_str, 137 | max_tokens=self.max_tokens, 138 | tokenizer=self.tokenizer, 139 | mask_probability=self.mask_probability 140 | ) 141 | else: 142 | raise NotImplementedError() 143 | 144 | gc.collect() 145 | return retval 146 | 147 | 148 | if __name__ == '__main__': 149 | 150 | from datasets import load_dataset 151 | 152 | print("Loading english data") 153 | english_data = load_dataset( 154 | 'wikipedia', 155 | '20200501.en', 156 | beam_runner='DirectRunner', 157 | cache_dir='/data/hendrycks/english_datasets', 158 | split='train' 159 | ) 160 | english_data.set_format(type=None, columns=['text']) 161 | print("Loaded english data") 162 | 163 | dataroots = [] 164 | dataroots.append({ 165 | "globstr" : '/data/sauravkadavath/code_datasets/stackoverflow/cleaned_noTagFilter/*.txt', 166 | "english_frac" : 0.0 167 | }) 168 | dataroots.append({ 169 | "globstr" : '/data/sauravkadavath/code_datasets/github_scraped_noempty_fixspacing_GPT_MaxLen1024_Packed_Cleaned_12.22.2020/worker_0/*.txt', 170 | "english_frac" : 0.15 171 | }) 172 | 173 | tokenizer = transformers.BartTokenizer.from_pretrained('facebook/bart-large') 174 | train_data = BaseLMDataset( 175 | dataroots=dataroots, 176 | mode='facebook/bart-large', 177 | max_tokens=1024, 178 | mask_probability=0.15, # GPT does not need masking 179 | english_data=english_data 180 | ) 181 | 182 | import pdb; pdb.set_trace() 183 | -------------------------------------------------------------------------------- /train/dataset_lm/reindent.py: -------------------------------------------------------------------------------- 1 | 2 | """ 3 | Reindent files. 
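Converts a file's indentation (tabs or an arbitrary number of spaces) to a uniform target style; APPSBaseDataset imports run() from this module to normalise solution code before tokenization.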
4 | """ 5 | 6 | from __future__ import print_function 7 | import sys 8 | import getopt 9 | import codecs 10 | import tempfile 11 | import shutil 12 | import os 13 | 14 | 15 | def _find_indentation(line, config): 16 | if len(line) and line[0] in (" ", "\t") and not line.isspace(): 17 | if line[0] == "\t": 18 | config['is-tabs'] = True 19 | # Find indentation 20 | i = 0 21 | for char in list(line): 22 | if char not in (" ", "\t"): 23 | break 24 | i += 1 25 | config["from"] = i 26 | 27 | 28 | def find_indentation(line, config): 29 | # Find indentation level used in file 30 | if config['from'] < 0: 31 | _find_indentation(line, config) 32 | 33 | if config['from'] >= 0: 34 | # Set old indent 35 | indent = " " if not config['is-tabs'] else "\t" 36 | indent = indent * config['from'] 37 | 38 | # Set new indent 39 | newindent = " " if not config['tabs'] else "\t" 40 | if not config['tabs']: 41 | newindent = newindent * config['to'] 42 | 43 | return indent, newindent 44 | 45 | # Continue to the next line, indentation not found 46 | return False 47 | 48 | 49 | def replace_inline_tabs(content, config): 50 | newcontent = "" 51 | imagined_i = 0 52 | for i in range(0, len(content)): 53 | char = content[i] 54 | if char == '\t': 55 | spaces = config['tabsize']-(imagined_i % config['tabsize']) 56 | newcontent += " " * spaces 57 | imagined_i += spaces 58 | else: 59 | newcontent += char 60 | imagined_i += 1 61 | return newcontent 62 | 63 | 64 | def run(fd_in, fd_out, config): 65 | while True: 66 | line = fd_in.readline() 67 | if not line: 68 | break 69 | line = line.rstrip('\r\n') 70 | 71 | # Find indentation style used in file if not set 72 | if config['from'] < 0: 73 | indent = find_indentation(line, config) 74 | if not indent: 75 | print(line, file=fd_out) 76 | continue 77 | indent, newindent = indent 78 | 79 | # Find current indentation level 80 | level = 0 81 | while True: 82 | whitespace = line[:len(indent) * (level + 1)] 83 | if whitespace == indent * (level + 1): 84 | level += 1 85 | else: 86 | break 87 | 88 | content = line[len(indent) * level:] 89 | if config['all-tabs']: 90 | content = replace_inline_tabs(content, config) 91 | 92 | line = (newindent * level) + content 93 | print(line, file=fd_out) 94 | 95 | 96 | def run_files(filenames, config): 97 | for filename in filenames: 98 | with codecs.open(filename, encoding=config['encoding']) as fd_in: 99 | if config['dry-run']: 100 | print("Filename: %s" % filename) 101 | fd_out = sys.stdout 102 | else: 103 | fd_out = tempfile.NamedTemporaryFile(mode='wb', delete=False) 104 | fd_out.close() 105 | fd_out = codecs.open(fd_out.name, "wb", encoding=config['encoding']) 106 | 107 | run(fd_in, fd_out, config) 108 | 109 | if not config["dry-run"]: 110 | fd_out.close() 111 | shutil.copy(fd_out.name, filename) 112 | os.remove(fd_out.name) 113 | 114 | 115 | def main(args): 116 | config = { 117 | "dry-run": False, 118 | "help": False, 119 | "to": 4, 120 | "from": -1, 121 | "tabs": False, 122 | "encoding": "utf-8", 123 | "is-tabs": False, 124 | "tabsize": 4, 125 | "all-tabs": False 126 | } 127 | possible_args = { 128 | "d": "dry-run", 129 | "h": "help", 130 | "t:": "to=", 131 | "f:": "from=", 132 | "n": "tabs", 133 | "e:": "encoding=", 134 | "s:": "tabsize=", 135 | "a": "all-tabs", 136 | } 137 | optlist, filenames = getopt.getopt( 138 | args[1:], 139 | "".join(possible_args.keys()), 140 | possible_args.values() 141 | ) 142 | 143 | shortargs, longargs = [], [] 144 | for shortarg in possible_args: 145 | shortargs.append(shortarg.rstrip(":")) 146 | 
longargs.append(possible_args[shortarg].rstrip("=")) 147 | 148 | for opt, val in optlist: 149 | opt = opt.lstrip("-") 150 | if opt in shortargs: 151 | opt = longargs[shortargs.index(opt)] 152 | if isinstance(config[opt], bool): 153 | config[opt] = True 154 | elif isinstance(config[opt], int): 155 | config[opt] = int(val) 156 | else: 157 | config[opt] = val 158 | 159 | if config['help']: 160 | help = """ 161 | Usage: %s [options] filename(s) 162 | Options: 163 | -h, --help Show this message 164 | -d, --dry-run Don't save anything, just print 165 | the result 166 | -t , --to Convert to this number of spaces 167 | (default: 4) 168 | -f , --from Convert from this number of spaces 169 | (default: auto-detect, will also 170 | detect tabs) 171 | -n, --tabs Don't convert indentation to spaces, 172 | convert to tabs instead. -t and 173 | --to will have no effect. 174 | -a, --all-tabs Also convert tabs used for alignment 175 | in the code (Warning: will replace 176 | all tabs in the file, even if inside 177 | a string) 178 | -s , --tabsize Set how many spaces one tab is 179 | (only has an effect on -a, default: 4) 180 | -e , --encoding Open files with specified encoding 181 | (default: utf-8) 182 | """ % args[0] 183 | 184 | # Also removes 8 leading spaces to remove our indentation 185 | print("\n".join([x[8:] for x in help[1:].split("\n")])) 186 | sys.exit(0) 187 | 188 | if filenames: 189 | run_files(filenames, config) 190 | else: 191 | run(sys.stdin, sys.stdout, config) 192 | 193 | if __name__ == "__main__": 194 | main(sys.argv) 195 | -------------------------------------------------------------------------------- /train/dataset_lm/util.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utils 3 | """ 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | import os 10 | 11 | ################################################################# 12 | ### GPT-style denoising LM task 13 | ################################################################# 14 | 15 | def batch_gpt_task(sample, max_tokens, tokenizer): 16 | """ 17 | Take sample, which is a raw string, and then 18 | """ 19 | 20 | # print(f"{os.getpid()}: batch_gpt_task 1: {sample[:200]} END") 21 | sample_input_ids = torch.LongTensor(tokenizer.encode(sample, max_length=max_tokens, truncation=True)) 22 | # print(f"{os.getpid()}: batch_gpt_task 2") 23 | 24 | assert len(sample_input_ids) <= max_tokens 25 | 26 | N_pad_inputs = max_tokens - len(sample_input_ids) 27 | if N_pad_inputs > 0: 28 | sample_input_ids = F.pad(sample_input_ids, [0, N_pad_inputs], mode='constant', value=tokenizer.eos_token_id) 29 | 30 | target_ids = sample_input_ids.detach().clone() # Will be shifted right inside the model. 
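# -100 is the ignore_index of the Hugging Face cross-entropy loss, so the EOS padding added above is excluded from the loss, e.g. input_ids [t1, t2, t3, eos, eos] -> labels [t1, t2, t3, -100, -100].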
31 | target_ids[target_ids == tokenizer.eos_token_id] = -100 32 | 33 | # import pdb; pdb.set_trace() 34 | 35 | return { 36 | "input_ids" : sample_input_ids, 37 | "labels" : target_ids 38 | } 39 | 40 | def batch_bart_task(sample, max_tokens, tokenizer, mask_probability): 41 | 42 | sample_input_ids = torch.LongTensor(tokenizer.encode(sample, max_length=max_tokens, truncation=True)) 43 | 44 | N_pad_inputs = max_tokens - len(sample_input_ids) 45 | if N_pad_inputs > 0: 46 | sample_input_ids = F.pad(sample_input_ids, [0, N_pad_inputs], mode='constant', value=tokenizer.pad_token_id) 47 | 48 | mask = torch.bernoulli(torch.ones_like(sample_input_ids) * mask_probability) # 1's in mask_probability% of the places 49 | 50 | target_ids = sample_input_ids.detach().clone() 51 | target_ids[target_ids == tokenizer.pad_token_id] = -100 52 | target_ids[mask == 0] = -100 53 | 54 | sample_input_ids = (sample_input_ids * (1 - mask)) + (torch.ones_like(sample_input_ids) * tokenizer.mask_token_id * mask) # Mask input 55 | sample_input_ids = sample_input_ids.long() 56 | 57 | return { 58 | "input_ids" : sample_input_ids, 59 | "labels" : target_ids 60 | } 61 | 62 | def dummy_gpt_task(max_tokens): 63 | seq = torch.zeros((max_tokens)).long() 64 | return { 65 | "input_ids" : seq, 66 | "labels" : seq, 67 | "attention_mask" : torch.ones_like(seq) 68 | } 69 | 70 | ################################################################# 71 | ### T5-style denoising LM task 72 | ################################################################# 73 | 74 | def _T5_mask(sample, mask_probability, tokenizer): 75 | 76 | assert len(sample.shape) == 1 77 | 78 | mask = torch.bernoulli(torch.ones_like(sample) * mask_probability).bool() # 15 % are 1s 79 | 80 | new_sample = _T5_apply_mask(sample, mask, tokenizer) 81 | target = _T5_apply_mask(sample, torch.logical_not(mask), tokenizer) 82 | 83 | return new_sample, target 84 | 85 | def _T5_apply_mask(sample, mask, tokenizer, hide_sentinels=False): 86 | """ 87 | Applies T5's masking scheme to batch. From the paper: 88 | 89 | Inspired by BERT’s “masked language modeling” objective and the “word dropout” regularization technique 90 | (Bowman et al., 2015), we design an objective that randomly samples and then drops out 15% of tokens in the input 91 | sequence. All consecutive spans of dropped-out tokens are replaced by a single sentinel token. Each sentinel token 92 | is assigned a token ID that is unique to the sequence. The sentinel IDs are special tokens which are added to our 93 | vocabulary and do not correspond to any wordpiece. The target then corresponds to all of the dropped-out spans of 94 | tokens, delimited by the same sentinel tokens used in the input sequence plus a final sentinel token to mark the end of 95 | the target sequence. Our choices to mask consecutive spans of tokens and only predict dropped-out tokens were 96 | made to reduce the computational cost of pre-training.  97 | """ 98 | 99 | assert len(sample.shape) == 1 100 | 101 | sample_not_padding_tokens = torch.logical_not(torch.eq(sample, tokenizer.pad_token_id)) 102 | 103 | # Do masking. See below link for more info: 104 | # TODO: Right now, this is being done twice per mask. Move it out so it is only done once per mask? 
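# The first token of each masked span is replaced by a unique sentinel id and the remaining tokens of that span are dropped, mirroring T5's span-corruption input format.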
105 | # https://github.com/google-research/text-to-text-transfer-transformer/blob/9fd7b14a769417be33bc6c850f9598764913c833/t5/data/preprocessors.py#L2117 106 | # Shift to the right 107 | prev_token_is_masked = F.pad(mask[:-1], (1, 0), mode='constant', value=0) 108 | first_mask_tokens = torch.logical_and(mask, torch.logical_not(prev_token_is_masked)) 109 | subsequent_mask_tokens = torch.logical_and(mask, prev_token_is_masked) 110 | # Magic formula. See https://github.com/huggingface/transformers/blob/master/src/transformers/tokenization_t5.py#L241 111 | # Note we do NOT need to subtract the number of tokens added with T5_new_tokens since 112 | # tokenizer.vocab_size does NOT include those. 113 | sentinel_idxs = tokenizer.vocab_size - torch.cumsum(first_mask_tokens, dim=0) 114 | 115 | sample = torch.where( 116 | torch.logical_and(first_mask_tokens, sample_not_padding_tokens), 117 | sentinel_idxs, 118 | sample 119 | ) 120 | sample = torch.masked_select(sample, torch.logical_not(subsequent_mask_tokens)) 121 | 122 | return sample 123 | 124 | def apply_mask_denoising(sample, max_tokens, tokenizer, mask_probability): 125 | """ 126 | Arguments: 127 | sample: string, already with bad characters replaced with T5_replace_chars() 128 | Returns: 129 | dict: With the input (raw string), input_ids (Tensor), labels (Tensor), attention_mask (Tensor) 130 | """ 131 | sample_input_ids = torch.LongTensor(tokenizer.encode(sample, padding='max_length', max_length=max_tokens, truncation=True)) 132 | 133 | masked_sample_input_ids, masked_sample_labels = _T5_mask(sample_input_ids, mask_probability, tokenizer) 134 | 135 | assert len(masked_sample_input_ids) <= max_tokens 136 | assert len(masked_sample_labels) <= max_tokens 137 | 138 | N_pad_inputs = max_tokens - len(masked_sample_input_ids) 139 | if N_pad_inputs > 0: 140 | masked_sample_input_ids = F.pad(masked_sample_input_ids, [0, N_pad_inputs], mode='constant', value=tokenizer.pad_token_id) 141 | 142 | N_pad_labels = max_tokens - len(masked_sample_labels) 143 | if N_pad_labels > 0: 144 | masked_sample_labels = F.pad(masked_sample_labels, [0, N_pad_labels], mode='constant', value=tokenizer.pad_token_id) 145 | 146 | attention_mask = ~ torch.eq(masked_sample_input_ids, tokenizer.pad_token_id) 147 | 148 | return { 149 | "raw_strings" : sample, 150 | "input_ids" : masked_sample_input_ids, 151 | "labels" : masked_sample_labels, 152 | "attention_mask" : attention_mask 153 | } 154 | 155 | 156 | ################################################################# 157 | ### Clasic BERT-style masked LM task 158 | ################################################################# 159 | 160 | def _BERT_mlm_mask(sample, mask_probability, tokenizer): 161 | mask = torch.bernoulli(torch.ones_like(sample) * mask_probability).bool() # 15 % are 1s 162 | sentinel_idxs = tokenizer.vocab_size - torch.ones_like(sample) 163 | 164 | new_sample = torch.where( 165 | mask, 166 | sentinel_idxs, 167 | sample 168 | ) 169 | 170 | target = torch.where( 171 | mask, 172 | sample, 173 | torch.ones_like(sample) * -100, 174 | ) 175 | 176 | return new_sample, target 177 | 178 | 179 | def apply_mask_bert_mlm(sample, max_tokens, tokenizer, mask_probability): 180 | """ 181 | Apply BERT-MLM-style masking to the given sample 182 | """ 183 | sample_input_ids = torch.LongTensor(tokenizer.encode(sample, padding='max_length', max_length=max_tokens, truncation=True)) 184 | 185 | masked_sample_input_ids, masked_sample_labels = _BERT_mlm_mask(sample_input_ids, mask_probability, tokenizer) 186 | 187 | assert 
len(masked_sample_input_ids) <= max_tokens 188 | assert len(masked_sample_labels) <= max_tokens 189 | 190 | N_pad_inputs = max_tokens - len(masked_sample_input_ids) 191 | if N_pad_inputs > 0: 192 | masked_sample_input_ids = F.pad(masked_sample_input_ids, [0, N_pad_inputs], mode='constant', value=tokenizer.pad_token_id) 193 | 194 | N_pad_labels = max_tokens - len(masked_sample_labels) 195 | if N_pad_labels > 0: 196 | masked_sample_labels = F.pad(masked_sample_labels, [0, N_pad_labels], mode='constant', value=tokenizer.pad_token_id) 197 | 198 | attention_mask = ~ torch.eq(masked_sample_input_ids, tokenizer.pad_token_id) 199 | 200 | return { 201 | "raw_strings" : sample, 202 | "input_ids" : masked_sample_input_ids, 203 | "labels" : masked_sample_labels, 204 | "attention_mask" : attention_mask 205 | } 206 | 207 | 208 | -------------------------------------------------------------------------------- /train/deepspeed_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "fp16": { 3 | "enabled": true, 4 | "loss_scale": 0, 5 | "loss_scale_window": 1000, 6 | "hysteresis": 2, 7 | "min_loss_scale": 1 8 | }, 9 | 10 | "zero_optimization": { 11 | "stage": 2, 12 | "allgather_partitions": true, 13 | "allgather_bucket_size": 1e8, 14 | "overlap_comm": true, 15 | "reduce_scatter": true, 16 | "reduce_bucket_size": 1e8, 17 | "contiguous_gradients": true, 18 | "cpu_offload": true 19 | }, 20 | 21 | "zero_allow_untested_optimizer": true, 22 | 23 | "steps_per_print": 2000, 24 | "wall_clock_breakdown": false, 25 | "dump_state": false, 26 | "train_batch_size": 8, 27 | 28 | "optimizer": { 29 | "type": "AdamW", 30 | "params": { 31 | "lr": 1e-4, 32 | "betas": [ 0.9, 0.999 ], 33 | "eps": 1e-8, 34 | "weight_decay": 0.05 35 | } 36 | }, 37 | 38 | "scheduler": { 39 | "type": "WarmupLR", 40 | "params": { 41 | "warmup_min_lr": 0, 42 | "warmup_max_lr": 1e-4, 43 | "warmup_num_steps": 500 44 | } 45 | } 46 | } 47 | -------------------------------------------------------------------------------- /train/tune_apps_gpt.py: -------------------------------------------------------------------------------- 1 | """ 2 | Tune LM on Code 3 | """ 4 | 5 | import io 6 | import logging 7 | import math 8 | import os 9 | import pprint 10 | import sys 11 | import time 12 | import json 13 | 14 | import transformers 15 | 16 | from tqdm import tqdm 17 | from datasets import load_dataset 18 | from datetime import datetime 19 | 20 | import torch 21 | import torch.distributed as dist 22 | import torch.nn as nn 23 | import torch.nn.functional as F 24 | import torch.optim as optim 25 | import torch.multiprocessing as mp 26 | 27 | from dataset_lm.base_lm_dataset import BaseLMDataset 28 | from dataset_apps.APPSBaseDataset import APPSBaseDataset 29 | from CustomTensorboardCallback import CustomTensorBoardCallback 30 | 31 | # torch.set_num_threads(2) 32 | 33 | # https://github.com/pytorch/pytorch/issues/11201 34 | import torch.multiprocessing 35 | torch.multiprocessing.set_sharing_strategy('file_system') 36 | 37 | 38 | def run_training(args, train_data): 39 | 40 | ## Checkpoint Loading ######################################################## 41 | if args.load: 42 | if '2700' in args.load: 43 | model = transformers.GPTNeoForCausalLM.from_pretrained(args.load) 44 | else: 45 | model = transformers.GPT2LMHeadModel.from_pretrained(args.load) 46 | print(f"Loaded model from {args.load}") 47 | else: 48 | if "EleutherAI" in args.arch: 49 | model = transformers.GPTNeoForCausalLM.from_pretrained(args.arch) 50 | else: 51 | 
model = transformers.GPT2LMHeadModel.from_pretrained(args.arch) 52 | 53 | if args.resume: 54 | raise NotImplementedError 55 | model = transformers.GPT2LMHeadModel.from_pretrained(args.resume) 56 | print(f"Loaded model from {args.resume}") 57 | start_epoch = 0 58 | start_iteration = int(args.resume.split("-")[-1]) 59 | print("start_iteration = ", start_iteration) 60 | else: 61 | start_iteration = 0 62 | 63 | ## Dataloading ######################################################## 64 | train_data.start_iteration = start_iteration 65 | 66 | ## Start Loop ######################################################## 67 | print(f"Starting main loop") 68 | 69 | training_args = transformers.TrainingArguments( 70 | output_dir=args.save_dir, 71 | overwrite_output_dir=False, 72 | 73 | do_train=True, 74 | do_eval=False, 75 | do_predict=True, 76 | evaluation_strategy='no', 77 | eval_steps=0, 78 | 79 | num_train_epochs=args.epochs, 80 | per_device_train_batch_size=args.batch_size_per_replica, 81 | gradient_accumulation_steps=args.grad_acc_steps, 82 | 83 | learning_rate=args.lr, 84 | weight_decay=0.05, 85 | # warmup_steps=args.lr_warmup_steps, 86 | # max_grad_norm=100000.0, 87 | 88 | logging_dir=args.save_dir, 89 | logging_first_step=True, 90 | logging_steps=args.log_freq, 91 | save_steps=args.save_freq, 92 | save_total_limit=2, 93 | 94 | dataloader_drop_last=True, 95 | dataloader_num_workers=3, 96 | 97 | local_rank=args.local_rank, 98 | 99 | deepspeed=args.deepspeed, 100 | fp16=args.fp16, 101 | ) 102 | 103 | trainer = transformers.Trainer( 104 | model=model, 105 | args=training_args, 106 | train_dataset=train_data, 107 | ) 108 | trainer.remove_callback(transformers.integrations.TensorBoardCallback) 109 | trainer.add_callback(CustomTensorBoardCallback()) 110 | 111 | trainer.train() 112 | 113 | if args.local_rank == 0: 114 | model.save_pretrained(os.path.join(args.save_dir, "final_checkpoint")) 115 | 116 | 117 | def get_dataset(args): 118 | 119 | fnames = os.listdir(args.apps_train_files) 120 | 121 | train_data = APPSBaseDataset( 122 | dataroot=args.apps_dataroot, 123 | problem_dirs=fnames, 124 | mode=args.arch, 125 | max_tokens=2048 if ('EleutherAI' in args.arch or '2700' in args.load) else 1024, 126 | sample_mode=args.apps_sample_mode 127 | ) 128 | 129 | return train_data 130 | 131 | 132 | def main(args): 133 | 134 | argsdict = vars(args) 135 | print(pprint.pformat(argsdict)) 136 | 137 | os.makedirs(args.save_dir, exist_ok=True) 138 | 139 | train_data = get_dataset(args) 140 | 141 | # Save command to file 142 | with open(os.path.join(args.save_dir, "command.txt"), 'w') as f: 143 | f.write(pprint.pformat(argsdict)) 144 | 145 | run_training(args, train_data) 146 | 147 | 148 | if __name__ == "__main__": 149 | import argparse 150 | 151 | parser = argparse.ArgumentParser(description="Language Modelling on Code") 152 | parser.add_argument('--arch', default='gpt2', choices=transformers.GPT2_PRETRAINED_MODEL_ARCHIVE_LIST + ["EleutherAI/gpt-neo-2.7B"]) 153 | parser.add_argument('--dummy-model', action='store_true') 154 | parser.add_argument('--load', default=None, type=str) 155 | parser.add_argument('--resume', default=None, type=str) 156 | 157 | # Dataloading 158 | parser.add_argument('--apps-dataroot', default='../apps/', type=str) 159 | parser.add_argument('--apps-train-files', default='../apps/data_split/train.json', type=str) 160 | parser.add_argument('--apps-sample-mode', default='uniform_sol') 161 | 162 | # Training 163 | parser.add_argument('--epochs', default=10, type=int) 164 | parser.add_argument('--lr', 
default=5e-5, type=float) 165 | # parser.add_argument('--lr-warmup-steps', default=500, type=int) 166 | parser.add_argument('--batch-size-per-replica', default=8, type=int) 167 | parser.add_argument('--grad-acc-steps', default=4, type=int) 168 | parser.add_argument('--local_rank', default=-1, type=int) 169 | parser.add_argument('--deepspeed', default=None, type=str) 170 | parser.add_argument('--fp16', default=False, action='store_true') 171 | 172 | # Logging and stuff 173 | parser.add_argument('--save-dir', default="checkpoints/TEMP", type=str) 174 | parser.add_argument('--log-freq', default=5, type=int) 175 | parser.add_argument('--save-freq', default=200, type=int) 176 | 177 | args = parser.parse_args() 178 | 179 | args.save_dir = os.path.join(args.save_dir, datetime.now().strftime("%m-%d-%Y__%H:%M:%S")) 180 | 181 | main(args) 182 | --------------------------------------------------------------------------------