├── .gitignore
├── CITATION.cff
├── Convert2HF
│   ├── README.md
│   ├── convert.sh
│   ├── convert_neox_pt_to_huggingface_neox.py
│   ├── generate.py
│   └── polycoder
│       └── configs
│           ├── config_0-4B.json
│           ├── config_160M.json
│           └── config_2-7B.json
├── Data
│   ├── README.md
│   ├── clone_repo.sh
│   ├── code-merges.txt
│   ├── code-vocab.json
│   ├── collect_data.sh
│   ├── deduplicate.py
│   ├── extract_code.py
│   ├── gh_crawler.py
│   ├── requirements.txt
│   └── yield_from_code_files.py
├── Evaluation
│   └── eval_codex_all.py
├── LICENSE.md
├── README.md
└── images
    └── fig6.png


/.gitignore:
--------------------------------------------------------------------------------
TopLists/
Code/
Repos/
Preprocessed/
--------------------------------------------------------------------------------
/CITATION.cff:
--------------------------------------------------------------------------------
@article{xu2022systematic,
  title={A Systematic Evaluation of Large Language Models of Code},
  author={Xu, Frank F and Alon, Uri and Neubig, Graham and Hellendoorn, Vincent J},
  journal={arXiv preprint arXiv:2202.13169},
  year={2022}
}
--------------------------------------------------------------------------------
/Convert2HF/README.md:
--------------------------------------------------------------------------------
# Convert to HuggingFace
This directory contains a script, `convert_neox_pt_to_huggingface_neox.py`, that converts PolyCoder checkpoints trained with [gpt-neox](https://github.com/EleutherAI/gpt-neox) into HuggingFace format, and a script, `generate.py`, that loads the converted model and generates a completion for the example prompt defined in the script.
Shoutout to @NinedayWang for implementing this!

## Environment
transformers 4.23.1

## Convert
You can use the `convert.sh` script to convert a specified model to the HuggingFace format, e.g. `./convert.sh 0-4B` (or pass a different model size, such as `2-7B`).
This script in turn invokes `convert_neox_pt_to_huggingface_neox.py`, which you can also call directly as follows:
```
python convert_neox_pt_to_huggingface_neox.py \
    --checkpoint_dir ../checkpoints/checkpoints-0-4B/global_step150000 \
    --vocab_file ../Data/code-vocab.json \
    --merge_file ../Data/code-merges.txt \
    --hf_config_path ./polycoder/configs/config_0-4B.json \
    --hf_save_dir ./polycoder/0-4B
```
HuggingFace configuration files for the different model sizes are provided in `polycoder/configs/`: `config_0-4B.json`, `config_2-7B.json` and `config_160M.json`.

After running, you will find a complete HuggingFace model in the directory specified by `--hf_save_dir`. If the directory does not exist, it is created automatically.

## Generate
The following example loads the converted 0.4B HuggingFace model and generates code from the prompt:
```
python generate.py \
    --model_name_or_path ./polycoder/0-4B \
    --temperature 0.2 \
    --top_p 0.95 \
    --max_length 128
```
To generate with a model of a different size, point `--model_name_or_path` at its converted directory.
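For reference, the parameter renaming that `get_hf_state_dict_from_pt_files` performs can be illustrated without loading any tensors. The sketch below is a hypothetical, simplified restatement of those rules (the function name `rename_neox_keys` is ours, not part of the repository): it works on parameter names only, takes one list of names per `layer_*.pt` file in sorted file order, and the example names are illustrative rather than read from a real checkpoint.

```python
from typing import Dict, List


def rename_neox_keys(files: List[List[str]]) -> Dict[str, str]:
    """Hypothetical, tensor-free restatement of the renaming rules in
    get_hf_state_dict_from_pt_files().

    `files` holds the parameter names found in each layer_*.pt file, in
    sorted file order. The result maps "<file index>:<old name>" to the new
    HuggingFace name, because the same old name repeats in every layer file.
    """
    mapping: Dict[str, str] = {}
    layer_id = -1
    for file_idx, keys in enumerate(files):
        new_layer = True  # the first block-level key in a file opens a new layer
        for key in keys:
            if "word_embeddings" in key:
                new_key = key.replace("word_embeddings", "gpt_neox.embed_in")
            elif "_layernorm" in key or "attention" in key or "mlp" in key:
                if new_layer:
                    layer_id += 1
                    new_layer = False
                new_key = f"gpt_neox.layers.{layer_id}.{key}"
            elif key.startswith("norm."):
                new_key = "gpt_neox.final_layer_norm." + key.split(".")[-1]
            elif "final_linear" in key:
                new_key = "embed_out." + key.split(".")[-1]
            else:
                continue  # names matching no rule are not copied
            mapping[f"{file_idx}:{key}"] = new_key
    return mapping


if __name__ == "__main__":
    # Illustrative names only, not read from a real checkpoint.
    example = [
        ["word_embeddings.weight"],
        ["input_layernorm.weight", "attention.query_key_value.weight"],
        ["input_layernorm.weight", "mlp.dense_h_to_4h.weight"],
        ["norm.weight", "norm.bias"],
        ["final_linear.weight"],
    ]
    for old, new in rename_neox_keys(example).items():
        print(old, "->", new)
```

Because a new block index is assigned to the first block-level key (`*_layernorm`, `attention`, `mlp`) seen in each file, the result depends on processing the `layer_*` files in sorted order, which is why the script sorts them before loading.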

--------------------------------------------------------------------------------
/Convert2HF/convert.sh:
--------------------------------------------------------------------------------
size=${1:-0-4B}

python convert_neox_pt_to_huggingface_neox.py \
    --checkpoint_dir ../checkpoints/checkpoints-${size}/global_step150000 \
    --vocab_file ../Data/code-vocab.json \
    --merge_file ../Data/code-merges.txt \
    --hf_config_path ./polycoder/configs/config_${size}.json \
    --hf_save_dir ./polycoder/${size}

--------------------------------------------------------------------------------
/Convert2HF/convert_neox_pt_to_huggingface_neox.py:
--------------------------------------------------------------------------------
import os
import torch
import argparse
from collections import OrderedDict
from transformers import GPTNeoXConfig, GPTNeoXForCausalLM, GPT2Tokenizer


def get_hf_state_dict_from_pt_files(checkpoint_dir):
    layer_files = []
    for root, dirs, files in os.walk(checkpoint_dir):
        for file in files:
            if file.startswith("layer_"):
                # print(file)
                layer_files.append(os.path.join(root, file))
    layer_files = sorted(layer_files)

    layer_id = -1
    state_dict = OrderedDict()
    for file in layer_files:
        print(f"Loading: {file}")
        new_layer = True

        module = torch.load(file, map_location=torch.device('cpu'))
        for key, value in module.items():
            if "word_embeddings" in key:
                new_key = key.replace("word_embeddings", "gpt_neox.embed_in")
                state_dict[new_key] = value
            elif "_layernorm" in key or "attention" in key or "mlp" in key:
                if new_layer:
                    layer_id += 1
                    new_layer = False
                new_key = "gpt_neox.layers." + str(layer_id) + "." + key
                state_dict[new_key] = value
            elif key.startswith("norm."):
                new_key = "gpt_neox.final_layer_norm." + key.split(".")[-1]
                state_dict[new_key] = value
            elif "final_linear" in key:
                new_key = "embed_out." + key.split(".")[-1]
                state_dict[new_key] = value
            else:
                # Without this branch, the print below would report a stale (or undefined) new_key.
                print(f"Skipping unrecognized key \"{key}\"")
                continue
            print(f"Convert \"{key}\" to \"{new_key}\"")

    return state_dict


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--checkpoint_dir",
                        type=str,
                        required=True,
                        help="Directory that contains .pt files.")
    parser.add_argument("--vocab_file",
                        type=str,
                        required=True,
                        help="Path to the vocab file.")
    parser.add_argument('--merge_file',
                        type=str,
                        required=True,
                        help='Path to the BPE merge file.')
    parser.add_argument("--hf_config_path",
                        type=str,
                        required=True,
                        help="Path to HuggingFace configuration file.")
    parser.add_argument("--hf_save_dir",
                        type=str,
                        required=True,
                        help="Directory to save HuggingFace model.")
    args = parser.parse_args()

    config = GPTNeoXConfig.from_json_file(args.hf_config_path)

    model = GPTNeoXForCausalLM(config)
    state_dict = get_hf_state_dict_from_pt_files(args.checkpoint_dir)
    missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
    print(f"missing keys: {missing_keys}")
    print(f"unexpected keys: {unexpected_keys}")

    tokenizer = GPT2Tokenizer(args.vocab_file, args.merge_file)

    if not os.path.exists(args.hf_save_dir):
        os.makedirs(args.hf_save_dir)
    print(f"Save HuggingFace model to {args.hf_save_dir} ...")
    model.save_pretrained(args.hf_save_dir)
    tokenizer.save_pretrained(args.hf_save_dir)
    print("Finished.")

--------------------------------------------------------------------------------
/Convert2HF/generate.py:
--------------------------------------------------------------------------------
from transformers import GPTNeoXForCausalLM, GPT2Tokenizer
import torch
import argparse


def load_model_and_generate(model_name_or_path, prompt, gen_kwargs):
    tokenizer = GPT2Tokenizer.from_pretrained(model_name_or_path)
    model = GPTNeoXForCausalLM.from_pretrained(model_name_or_path)

    encoded_input = tokenizer(prompt, return_tensors="pt")
    input_ids, attention_mask = encoded_input['input_ids'], encoded_input['attention_mask']
    if torch.cuda.is_available():
        model = model.cuda()
        input_ids = input_ids.cuda()
        attention_mask = attention_mask.cuda()

    prediction_ids = model.generate(input_ids=input_ids, attention_mask=attention_mask, **gen_kwargs)[0]
    prediction_tokens = tokenizer.decode(prediction_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[len(prompt):]
    print(prompt + prediction_tokens)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()

    parser.add_argument("--model_name_or_path", type=str, default="./polycoder/0-4B")
    parser.add_argument('--temperature', type=float, default=0.2)
    parser.add_argument('--top_p', type=float, default=0.95)
    parser.add_argument('--max_length', type=int, default=128)

    args = parser.parse_args()

    gen_kwargs = {
        "do_sample": True,
        "temperature": args.temperature,
        "max_length": args.max_length,
        "top_p": args.top_p,
    }

    prompt = "\ndef add(x: int, y: int):\n    \"\"\"Add two numbers x and y\n    >>> add(2, 3)\n    5\n    >>> add(5, 7)\n    12\n    \"\"\"\n"

    load_model_and_generate(args.model_name_or_path, prompt, gen_kwargs)

--------------------------------------------------------------------------------
/Convert2HF/polycoder/configs/config_0-4B.json:
--------------------------------------------------------------------------------
{
  "hidden_act": "gelu",
  "architectures": [
    "GPTNeoXForCausalLM"
  ],
  "bos_token_id": 0,
  "eos_token_id": 0,
  "initializer_range": 0.02,
  "layer_norm_eps": 1e-05,
  "model_type": "gpt_neox",
  "hidden_size": 1024,
  "intermediate_size": 4096,
  "num_attention_heads": 16,
  "num_hidden_layers": 24,
  "max_position_embeddings": 2048,
  "rotary_pct": 1.0,
  "rotary_emb_base": 10000,
  "torch_dtype": "float16",
  "use_cache": true,
  "use_parallel_residual": false,
  "vocab_size": 50304,
  "transformers_version": "4.23.1",
  "tokenizer_class": "GPT2Tokenizer"
}

--------------------------------------------------------------------------------
/Convert2HF/polycoder/configs/config_160M.json:
--------------------------------------------------------------------------------
{
  "hidden_act": "gelu",
  "architectures": [
    "GPTNeoXForCausalLM"
  ],
  "bos_token_id": 0,
  "eos_token_id": 0,
  "initializer_range": 0.02,
  "layer_norm_eps": 1e-05,
  "model_type": "gpt_neox",
  "hidden_size": 768,
  "intermediate_size": 3072,
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "max_position_embeddings": 2048,
  "rotary_pct": 1.0,
  "rotary_emb_base": 10000,
  "torch_dtype": "float16",
  "use_cache": true,
  "use_parallel_residual": false,
  "vocab_size": 50304,
  "transformers_version": "4.23.1",
  "tokenizer_class": "GPT2Tokenizer"
}

--------------------------------------------------------------------------------
/Convert2HF/polycoder/configs/config_2-7B.json:
--------------------------------------------------------------------------------
{
  "hidden_act": "gelu",
  "architectures": [
    "GPTNeoXForCausalLM"
  ],
  "bos_token_id": 0,
  "eos_token_id": 0,
  "initializer_range": 0.02,
  "layer_norm_eps": 1e-05,
  "model_type": "gpt_neox",
  "hidden_size": 2560,
  "intermediate_size": 10240,
  "num_attention_heads": 32,
  "num_hidden_layers": 32,
  "max_position_embeddings": 2048,
  "rotary_pct": 1.0,
  "rotary_emb_base": 10000,
  "torch_dtype": "float16",
  "use_cache": true,
  "use_parallel_residual": false,
  "vocab_size": 50304,
  "transformers_version": "4.23.1",
  "tokenizer_class": "GPT2Tokenizer"
}

--------------------------------------------------------------------------------
/Data/README.md:
--------------------------------------------------------------------------------
# Code Data Collection
This directory contains scripts to [mine](#mining) a dataset of code similar to the one used to train [PolyCoder](https://arxiv.org/pdf/2202.13169.pdf), as well as details of that dataset. Note that because of the nature of the GH API, the exact results of each query will differ between runs, so this will not precisely replicate the training data.

## Mining
Update `gh_crawler.py` by adding your GH API token (line 6). Then, run `collect_data.sh`, which invokes the GitHub API crawler (`gh_crawler.py`), followed by a repo cloning script (`clone_repo.sh`, run in parallel), which uses `extract_code.py` to extract all source code files in the corresponding language (filtering out very long and very short files), and finally `deduplicate.py` to remove duplicate files.

Once this is completed, you can use [gpt-neox](https://github.com/EleutherAI/gpt-neox)'s `preprocess_data.py` (currently in `tools/`) to tokenize this dataset for the model, either using the pretrained code vocabulary (by providing the `code-vocab.json` and `code-merges.txt` files) or producing a new one.

At the time of this writing*, the following command processes the entire `Code/` directory into a new directory named `Preprocessed/` using the pretrained vocabulary across 16 parallel workers (assuming that `gpt-neox` is checked out in the current directory):
```
mkdir Preprocessed
sudo python3 gpt-neox/tools/preprocess_data.py --input Code --tokenizer-type GPT2BPETokenizer --vocab-file code-vocab.json --merge-file code-merges.txt --output-prefix Preprocessed/code --workers 16
```
And that's it! Just modify the `local_setup.yml` config in the gpt-neox toolkit to point it to the new vocab & merges files and data directory, and it should be able to train.

*I did have to modify the `yield_from_files` function to recursively yield all (shuffled) files from a directory; the default version uses `lm_dataformat`, which balks at code file extensions. The updated function can be found in `yield_from_code_files.py`.

## PolyCoder Data
The approach above was used to collect 249GB of multi-lingual training data to train [PolyCoder](https://arxiv.org/pdf/2202.13169.pdf) -- see the paper and top-level directory for details. Because of the ever-changing nature of repos on GitHub, running the above won't get you back the exact data, which is quite fine for most purposes (our training run didn't even use all of it), but it's naturally useful to know what data we used. We therefore release a list of all files used for training and their SHA-256 hashes in [this file](https://zenodo.org/record/6341643/files/index.zip) (warning: zipped, still large), formatted as `{language}__{organization}__{project}__{full__file__path}\tSHA` (using double underscores instead of slashes in the file path).

To check whether a file was used during training, I strongly encourage considering not just its path but also its hashed contents, since files are often duplicated verbatim across and within projects. The following Python code (with `file_path` pointing at the file to check) was used to create the hash values in the file above, which allows fast deduplication of a new file against the set of all hashes used in our training data:

```python
import hashlib

with open(file_path, 'rb') as f:
    content = f.read()
sha = hashlib.sha256(content).hexdigest()
```

--------------------------------------------------------------------------------
/Data/clone_repo.sh:
--------------------------------------------------------------------------------
# Clone a given repository, extract any files belonging to the given language, and delete the repository afterwards to save space.
in=$1
language=$2

# Extract the org and name from lines formatted as stars\thttps://github.com/org/name
repo=$(echo "$in" | cut -d$'\t' -f2);
name_part=$(echo $repo | cut -d"/" -f4-6);
name=$(echo $name_part | cut -d"/" -f2);
org=$(echo $name_part | cut -d"/" -f1);
echo "Cloning $org/$name"
DIR=Repos/$language/$org; \
OUT=Code/$language/$org; \
# Skip repositories for which we already have extracted code files.
if [ -d $OUT/$name ]; then echo "deja vu"; exit; fi;
mkdir -p $DIR; \
mkdir -p $OUT; \

# Clone with depth=1 to only get most recent files, rather than entire history.
if [ ! -d $DIR/$name ]; then
  git clone -q --depth 1 https://github.com/$org/$name $DIR/$name;
fi;

# Extract all language-specific code files from the repository and delete it afterwards.
python3 extract_code.py $language $DIR/$name $OUT/$name;
rm -rf $DIR/$name
--------------------------------------------------------------------------------
/Data/collect_data.sh:
--------------------------------------------------------------------------------
# Hand-picked set of languages.
langs=("C" "C#" "C++" "Go" "Java" "JavaScript" "PHP" "Python" "Ruby" "Rust" "Scala" "TypeScript")

if [ ! -d TopLists ]; then
  mkdir TopLists;
fi

# Install required Python packages.
pip install -r requirements.txt

# Collect up to 25K repos per language with at least 50 stars.
# NOTE: the GH API neither guarantees nor (remotely) achieves completeness or consistency, so the resulting set of repositories will be different on each run.
# NOTE: make sure to insert your GH API key into the gh_crawler.py file.
for lang in ${langs[@]}; do
  python3 gh_crawler.py $lang;
done

# Clone repositories in parallel and extract all language-specific files.
for lang in ${langs[@]}; do
  cat 'TopLists/'$lang'-top-repos.txt' | xargs -P16 -n1 -I% bash clone_repo.sh % $lang
done

# Deduplicate code files.
python3 deduplicate.py

--------------------------------------------------------------------------------
/Data/deduplicate.py:
--------------------------------------------------------------------------------
import hashlib
import os

ROOT = 'Code'  # NOTE: hard-coded.
seen = set()
count = 0
dups = 0

for root_dir, _, files in os.walk(ROOT):
    for file in files:
        count += 1
        file_path = os.path.join(root_dir, file)
        # Hash the entire file's content.
        with open(file_path, 'rb') as f:
            content = f.read()
            digest = hashlib.sha256(content).hexdigest()

        # Delete identical files.
        if digest in seen:
            os.remove(file_path)
            dups += 1
        else:
            seen.add(digest)

        # Periodically print progress and the running duplication ratio.
        if count % 10000 == 0:
            print(f'Processed {count:,} files, duplicates so far: {dups:,} ({dups/count:.1%})')

--------------------------------------------------------------------------------
/Data/extract_code.py:
--------------------------------------------------------------------------------
"""Copies all files belonging to a given language to a new directory."""
import os
import sys
from shutil import copyfile

import pygments
from pygments.lexers import get_lexer_by_name
from pygments.token import Token

# Basic config options.
MAX_FILE_SIZE = 1024 ** 2  # 1 MB
MIN_FILE_TOKENS = 100

def main():
    if len(sys.argv) <= 3:
        raise ValueError('Provide a language, source directory and target directory.')

    language = sys.argv[1]
    proj_dir = sys.argv[2]
    out_dir = sys.argv[3]

    # Use Pygments to get language extensions.
    lexer = get_lexer_by_name(language)
    language_extensions = set(ext.lower()[1:] for ext in lexer.filenames)

    print(f'Processing: {proj_dir}')
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)

    files_found = 0
    for root, _, files in os.walk(proj_dir):
        for file in files:
            if any(file.endswith(ext) for ext in language_extensions):
                in_path = os.path.join(root, file)
                if not os.path.exists(in_path):  # Can happen due to broken symlinks.
                    continue
                if os.path.getsize(in_path) > MAX_FILE_SIZE:  # Drop excessively long files.
                    continue
                with open(in_path, errors='ignore') as f_in:
                    text = f_in.read()
                if sum(1 for _ in pygments.lex(text, lexer)) < MIN_FILE_TOKENS:  # Drop files with too few tokens.
                    continue

                # Copy all other files to the target directory using a simplified path.
                rel_path = root[len(proj_dir)+1:].replace('/', '__')
                out_path = os.path.join(out_dir, rel_path + ('__' if rel_path else '') + file)
                if not os.path.exists(out_path):
                    try:
                        copyfile(in_path, out_path)
                    except Exception as e:
                        print(f'Skipping problematic file {in_path} due to: {e}')
                files_found += 1
    print(f'Done processing; copied {files_found} files.')


if __name__ == '__main__':
    main()
--------------------------------------------------------------------------------
/Data/gh_crawler.py:
--------------------------------------------------------------------------------
import requests
import sys
import time

# Insert GitHub API token here, in place of *TOKEN*.
headers = {"Authorization": "token *TOKEN*"}

# Constants & language argument.
NUM_REPOS = 25_000
MIN_STARS = 50
LAST_ACTIVE = '2020-01-01'
LANGUAGE = "java" if len(sys.argv) <= 1 else sys.argv[1]  # Default to Java, if none passed.

def main():
    repositories = set()  # Keep track of a set of repositories seen to avoid duplicate entries across pages.
    next_max_stars = 1_000_000_000  # Initialize to a very high value.
    with open(f'TopLists/{LANGUAGE}-top-repos.txt', 'w') as f:
        while len(repositories) < NUM_REPOS:
            results = run_query(next_max_stars)  # Get the next set of pages.
            if not results:
                break
            new_repositories = [repository for repository, _ in results]
            next_max_stars = min([stars for _, stars in results])

            # Stop crawling if a query returns no new repositories.
            if len(repositories | set(new_repositories)) == len(repositories):
                break
            for repository, stars in sorted(results, key=lambda e: e[1], reverse=True):
                if repository not in repositories:
                    repositories.add(repository)
                    f.write(f'{stars}\t{repository}\n')
            f.flush()
            print(f'Collected {len(repositories):,} repositories so far; lowest number of stars: {next_max_stars:,}')


def run_query(max_stars):
    end_cursor = None  # Used to track pagination.
    repositories = set()

    while end_cursor != "":
        # Extracts non-fork, recently active repositories in the provided language, in groups of 100.
        # Leaves placeholders for maximum stars and page cursor. The former allows us to retrieve more than 1,000 repositories
        # by repeatedly lowering the bar.
        query = f"""
        {{
          search(query: "language:{LANGUAGE} fork:false pushed:>{LAST_ACTIVE} sort:stars stars:<{max_stars}", type: REPOSITORY, first: 100 {', after: "' + end_cursor + '"' if end_cursor else ''}) {{
            edges {{
              node {{
                ... on Repository {{
                  url
                  isPrivate
                  isDisabled
                  isLocked
                  stargazers {{
                    totalCount
                  }}
                }}
              }}
            }}
            pageInfo {{
              hasNextPage
              endCursor
            }}
          }}
        }}
        """
        print(f'  Retrieving next page; {len(repositories)} repositories in this batch so far.')
        # Attempt a query up to three times, pausing when a query limit is hit.
        attempts = 0
        success = False
        while not success and attempts < 3:
            request = requests.post('https://api.github.com/graphql', json={'query': query}, headers=headers)
            content = request.json()
            if 'data' not in content or 'search' not in content['data']:
                # If this is simply a signal to pause querying, wait two minutes.
                if 'message' in content and 'wait' in content['message']:
                    attempts += 1
                    time.sleep(120)
                # Otherwise, assume we've hit the end of the stream.
                else:
                    break
            else:
                success = True
        if not success:
            break
        end_cursor = get_end_cursor(content)
        new_repositories, is_done = get_repositories(content)
        repositories.update(new_repositories)
        if len(repositories) > NUM_REPOS or is_done:
            break
    return repositories


def get_end_cursor(content):
    page_info = content['data']['search']['pageInfo']
    has_next_page = page_info['hasNextPage']
    if has_next_page:
        return page_info['endCursor']
    return ""


def get_repositories(content):
    edges = content['data']['search']['edges']
    repositories_with_stars = []
    for edge in edges:
        if edge['node']['isPrivate'] is False and edge['node']['isDisabled'] is False and edge['node']['isLocked'] is False:
            repository = edge['node']['url']
            star_count = edge['node']['stargazers']['totalCount']
            if star_count < MIN_STARS:
                return repositories_with_stars, True
            repositories_with_stars.append((repository, star_count))
    return repositories_with_stars, False


if __name__ == '__main__':
    main()

--------------------------------------------------------------------------------
/Data/requirements.txt:
--------------------------------------------------------------------------------
# Just need Pygments for lexing.
pygments

--------------------------------------------------------------------------------
/Data/yield_from_code_files.py:
--------------------------------------------------------------------------------
"""A drop-in replacement for `yield_from_files` in the gpt-neox tools/preprocess_data.py version which does not rely on lm_dataformat."""

import os
import random

def yield_from_files(dir, semaphore):
    """
    Iterator over input documents, treated as plaintext.

    :param dir: directory to recursively extract files from.
    """
    fnames = []
    for root, _, files in os.walk(dir):
        for file in files:
            fnames.append(os.path.join(root, file))
    random.shuffle(fnames)

    def read(fname):
        # Ignore undecodable bytes, as extract_code.py does when reading the same files.
        with open(fname, errors='ignore') as inp:
            doc = inp.read()
        return doc

    def yielder(fname, semaphore):
        f = read(fname)
        if f:
            semaphore.acquire()
            yield f

    for fname in fnames:
        yield from yielder(fname, semaphore)
--------------------------------------------------------------------------------
/Evaluation/eval_codex_all.py:
--------------------------------------------------------------------------------
import argparse
import glob
import json
import os
import time
import math
import openai
import shutil
import pathlib

languages_to_run = {'C', 'C#', 'C++', 'Go', 'Java', 'JavaScript',
                    'PHP', 'Python', 'Ruby', 'Rust', 'Scala', 'TypeScript'}

# The private OpenAI API key needs to be an environment variable
openai.api_key = os.getenv('OPENAI_API_KEY')
# As instructed here: https://community.openai.com/t/token-logprobs-when-echo-is-true/9626/2
# "Transformer models don’t predict the probability of the first token. If you want to get the probability
# for your first token you can try to use <|endoftext|> as the first token as a workaround."
endoftext_token = '<|endoftext|>'

def ppl(avg_logprob):
    return 2 ** (-avg_logprob / math.log(2))

def call_codex(code_str, save_probs):
    eos_code_str = endoftext_token + code_str
    # engine: 'davinci-codex' is currently the best codex model
    # max_tokens=0 means that we don't want the model to generate additional tokens
    # logprobs=0 means that we don't want the logprobs of the alternative tokens, only the actual tokens
    # echo=True means that we want the model to echo our prompt, in addition to our (not existing) completion
    completion = openai.Completion.create(engine="davinci-codex", prompt=eos_code_str,
                                          max_tokens=0,
                                          temperature=0.0,
                                          logprobs=0,
                                          n=1,
                                          echo=True)

    c = completion.choices[0]
    # skipping the <|endoftext|> token
    sum_logprobs = sum(c.logprobs.token_logprobs[1:])
    num_tokens = len(c.logprobs.token_logprobs[1:])
    if save_probs:
        saved_probs = {
            'text': code_str,
            'tokens': c.logprobs.tokens[1:],
            'logprobs': c.logprobs.token_logprobs[1:],
            'sum_logprobs': sum_logprobs
        }
    else:
        saved_probs = None

    return sum_logprobs, num_tokens, saved_probs

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--dirs', type=str, help='path to a directory that contains a subdirectory for each evaluated language', required=True)
    parser.add_argument('--save-probs', type=str, required=False, default=None)
    parser.add_argument('--output', type=str, required=False, default=os.devnull)
    args = parser.parse_args()

    results = {}
    dirs = glob.glob(os.path.join(args.dirs, '*'), recursive=False)
    excluded_dirs = args.dirs + '-excluded'
    pathlib.Path(excluded_dirs).mkdir(parents=True, exist_ok=True)
    for language in dirs:
        if language.split('/')[-1] not in languages_to_run:
            continue
        print('Language:', language)
        files = glob.glob(os.path.join(language, '**/*'), recursive=True)
        files = [f for f in files if os.path.isfile(f)]

        log_probs_sum = 0
        tokens_count = 0
        ignored_files = []
        all_per_token_probs = []
        with open(args.output, 'w') as out_file:
            for file in files:
                try:
                    with open(file, 'r') as f:
                        code_str = f.read()
                    logprobs_sum, logprobs_count, per_token_probs = call_codex(code_str, args.save_probs is not None)
                except Exception as e:
                    print(f'EXCEPTION in file {file}: {e}')
                    ignored_files.append(file)
                    # OpenAI limits the request rate to 20/min
                    time.sleep(10)
                    continue
                out_str = f'{logprobs_sum}\t{logprobs_count}\t{file}'
                if args.output != os.devnull:
                    out_file.write(out_str + '\n')

                log_probs_sum += logprobs_sum
                tokens_count += logprobs_count
                # OpenAI limits the request rate to 20/min
                time.sleep(10)

        print(f'\n\n\nlogprobs sum: {log_probs_sum}')
        print(f'total tokens: {tokens_count}')
        print(f'Average loss: {-log_probs_sum / tokens_count}')
        print(f'Perplexity: {ppl(log_probs_sum / tokens_count)}')
        print('Ignored files:')
        for f in ignored_files:
            print(f'\t{f}')
            new_location = os.path.join(excluded_dirs, os.path.dirname(f))
            pathlib.Path(new_location).mkdir(parents=True, exist_ok=True)
            shutil.move(f, new_location)
        results[language] = {
            'log_probs_sum': log_probs_sum,
            'tokens_count': tokens_count,
            'average_loss': -log_probs_sum / tokens_count,
            'perplexity': ppl(log_probs_sum / tokens_count),
        }

    print('Language, sum_logprobs, average_loss, perplexity, num_tokens')
    for language in results:
        print(f'{language.split("/")[-1]}, {results[language]["log_probs_sum"]}, {results[language]["average_loss"]}, {results[language]["perplexity"]}, {results[language]["tokens_count"]}')
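A note on the numbers `eval_codex_all.py` reports: `ppl` computes `2 ** (-avg_logprob / math.log(2))`, which is algebraically the same as `exp(-avg_logprob)`, and the script sums log-probabilities and token counts over all files of a language before dividing, so longer files carry more weight. A minimal sketch of that aggregation follows; `corpus_perplexity` is a hypothetical helper name, and the (sum, count) pairs are made up rather than real Codex outputs.

```python
import math
from typing import Iterable, Tuple


def ppl(avg_logprob: float) -> float:
    """Perplexity from the average per-token log-probability (natural log).

    Equivalent to eval_codex_all.py's 2 ** (-avg_logprob / math.log(2)).
    """
    return math.exp(-avg_logprob)


def corpus_perplexity(per_file: Iterable[Tuple[float, int]]) -> float:
    """Token-weighted perplexity over (sum_logprobs, num_tokens) pairs, one per file."""
    total_logprob = 0.0
    total_tokens = 0
    for sum_logprobs, num_tokens in per_file:
        total_logprob += sum_logprobs
        total_tokens += num_tokens
    if total_tokens == 0:
        # eval_codex_all.py would raise ZeroDivisionError here instead.
        raise ValueError("no tokens were scored")
    return ppl(total_logprob / total_tokens)


if __name__ == "__main__":
    # Hypothetical: 10 tokens at p=0.5 each, then 30 tokens at p=0.25 each.
    files = [(10 * math.log(0.5), 10), (30 * math.log(0.25), 30)]
    print(corpus_perplexity(files))  # 2 ** 1.75, roughly 3.364
```

Averaging per-file perplexities instead would give the short first file as much weight as the three-times-longer second one, which is not what the script's summed totals compute.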

--------------------------------------------------------------------------------
/LICENSE.md:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2022 Vincent Hellendoorn

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Large Models of Source Code
I occasionally train and publicly release large neural language models on programs, including [PolyCoder](https://arxiv.org/pdf/2202.13169.pdf). Here, I describe how to use these.

## October 2022 - PolyCoder is available on Huggingface!
Thanks to [@NinedayWang](https://github.com/NinedayWang), PolyCoder is available on the Huggingface Hub!
6 | 7 | The available models are: 8 | * `NinedayWang/PolyCoder-160M` 9 | * `NinedayWang/PolyCoder-0.4B` 10 | * `NinedayWang/PolyCoder-2.7B` 11 | 12 | To use the models in Huggingface, run the following (requires `transformers` 4.23.0 or later, e.g., `pip install transformers==4.23.0`): 13 | ```python 14 | import transformers 15 | from transformers import AutoTokenizer, AutoModelForCausalLM 16 | 17 | from packaging import version 18 | assert version.parse(transformers.__version__) >= version.parse("4.23.0") 19 | 20 | tokenizer = AutoTokenizer.from_pretrained("NinedayWang/PolyCoder-2.7B") 21 | model = AutoModelForCausalLM.from_pretrained("NinedayWang/PolyCoder-2.7B") 22 | ``` 23 | 24 | The model can then be used, for example, as follows: 25 | ```python 26 | prompt = '''def binarySearch(arr, left, right, x): 27 | mid = (left +''' 28 | input_ids = tokenizer.encode(prompt, return_tensors='pt') 29 | result = model.generate(input_ids, max_length=50, num_beams=4, num_return_sequences=4) 30 | for res in result: 31 | print(tokenizer.decode(res)) 32 | ``` 33 | 34 | ## Table of Contents 35 | 1. [Setup](#getting-started) 36 | 2. [Models (incl. PolyCoder)](#multi-lingual-models) 37 | 3. [Datasets](#datasets) 38 | 4. [Evaluation](#evaluation) 39 | 5. [How to cite](#citation) 40 | 41 | 42 | ## Getting Started 43 | All current models were trained using the [GPT-NeoX toolkit](https://github.com/EleutherAI/gpt-neox). First, download a pretrained checkpoint as described below, then use it either with [a Docker image](#via-docker) or through our fork of this toolkit [from source](#from-source) to [generate code](#code-generation) or [replicate our evaluation](#evaluation). 44 | 45 | ### Retrieving Checkpoints 46 | Checkpoint files from training PolyCoder are hosted on this [public Zenodo repository](https://zenodo.org/record/6363556). See [this section](#available-models) for details on currently available models.
Model checkpoints are up to 6GB in size, which is also the amount of GPU memory they require to run (running on CPU is neither tested nor recommended). Download and untar a checkpoint file (in this case for a 2.7B parameter model trained for 150K steps) to a directory called `checkpoints/`, using: 47 | 48 | ``` 49 | mkdir checkpoints 50 | cd checkpoints 51 | wget https://zenodo.org/record/6363556/files/2-7B-150K.tar 52 | tar -xvf 2-7B-150K.tar 53 | ``` 54 | 55 | ### From Source 56 | We maintain a public fork of the NeoX repository [here](https://github.com/frankxu2004/gpt-neox), which includes the (minor) changes we made to the codebase to allow for tabs & newlines in the tokenization, and also includes instructions for running the perplexity and HumanEval tasks. Note that this repository uses [a forked version](https://github.com/frankxu2004/lm-evaluation-harness) of the LM Evaluation Harness with the code benchmark from [our work](#citation). 57 | 58 | Building this repository should match the process for GPT-NeoX almost exactly. You may also use the Docker image described next, mounting a checkout of the latest version of this fork over the `/gpt-neox` directory inside the container. Once set up, use the `generate.py` entrypoint (described [below](#code-generation)) for free-form code generation, or use one of the commands [here](https://github.com/frankxu2004/gpt-neox#a-modified-version-for-polycoder-code-pretraining) to calculate perplexity and HumanEval results as in [the paper](https://arxiv.org/pdf/2202.13169).
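Since the fork keeps tabs and newlines as tokens, a multi-line snippet has to be flattened before it can be typed at the single-line interactive prompt, where tabs and newlines are entered as the escapes `\t` and `\n`. What follows is a minimal sketch of that conversion. `escape_prompt` is a hypothetical helper, not part of the fork, and it assumes no other escape sequences are interpreted, so backslashes already present in the code are left untouched.

```python
def escape_prompt(snippet: str) -> str:
    """Flatten a multi-line snippet into one line for the interactive prompt.

    Hypothetical helper: assumes the interactive mode expands only the
    escape sequences \\t and \\n; spaces are passed through unchanged.
    """
    return snippet.replace("\t", "\\t").replace("\n", "\\n")


snippet = 'def return1():\n\t"""Returns 1."""\n\t'
print(escape_prompt(snippet))
```

Spaces used for indentation pass through unchanged; only literal tab and newline characters are rewritten.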
59 | 60 | ### Via Docker 61 | A *base* Docker image containing a slightly modified version of the [gpt-neox repository](https://github.com/EleutherAI/gpt-neox) is [available via Docker Hub](https://hub.docker.com/repository/docker/vhellendoorn/code-lms-neox): 62 | ``` 63 | docker pull vhellendoorn/code-lms-neox:base 64 | ``` 65 | 66 | This image can be used together with a checkpoint file hosted on this [public Zenodo repository](https://zenodo.org/record/6363556). The base Docker image size is 5.4GB. Once a checkpoint has been retrieved, start the container with the following command (substituting another GPU device index if needed): 67 | ``` 68 | nvidia-docker run --rm -it -e NVIDIA_VISIBLE_DEVICES=0 --shm-size=1g --ulimit memlock=-1 --mount type=bind,src=$PWD/checkpoints,dst=/gpt-neox/checkpoints vhellendoorn/code-lms-neox:base 69 | ``` 70 | 71 | ### Code Generation 72 | The following command can be used to generate code from a prompt: 73 | ``` 74 | sudo ./deepy.py generate.py configs/text_generation.yml checkpoints/configs/local_setup.yml checkpoints/configs/2-7B.yml 75 | ``` 76 | **Note:** if not using the 2.7B parameter model, replace the final config file with the appropriate model size (e.g., `small` = 160M parameters, `medium` = 405M). 77 | 78 | Once the checkpoint has been loaded, you can feed it an example such as `def return1():\n """Returns 1."""\n ` (note the whitespace tokens) and watch it predict `return 1` (and then probably a bunch of other `returnX` methods, depending on the sample). 79 | 80 | The modifications to gpt-neox mentioned above center around the need to allow tabs and newlines in the prompt input. For the _interactive_ mode, these can be added using their escaped versions (`\t`, `\n`); when using file-based input, the project will read the entire file instead of treating each line as a prompt.
By default, the command above will create an interactive prompt and return relatively short outputs (256 tokens) with a sampling temperature of 0.5; this behavior can be changed in `/gpt-neox/checkpoints/configs/text_generation.yml`. 81 | 82 | A lower temperature (e.g., 0.2) will produce more consistent and plausible (to the model) predictions; a higher temperature such as the default may be useful for generating and evaluating many candidates (see [our paper](https://arxiv.org/pdf/2202.13169) for recommendations). For the latter setting, consider switching to the `input-file` mode and providing an entire snippet (without escaping whitespace) in the corresponding file. 83 | 84 | ## Multi-lingual Models 85 | Several models have been trained on a [large corpus](#249gb-multi-lingual-corpus) of code spanning 12 programming languages. This includes a 2.7B parameter model (nicknamed **PolyCoder**, trained for 100K and 150K steps), a 405M parameter model (100K & 150K steps) and a 160M parameter model (150K steps). 86 | 87 | ### Available Models 88 | All models are available [at a public Zenodo repository](https://zenodo.org/record/6363556), in the form of `.tar` files with fairly self-explanatory names (e.g., 2-7B-100K => a 2.7B parameter model trained for 100K steps). Currently available models include: 89 | 90 | * **[GPT2 - 2.7B](https://zenodo.org/record/6363556/files/2-7B-150K.tar):** A 32-layer, 2,560-dimensional Transformer model, trained with a batch size of 128 sequences (256K tokens). Checkpoints are available at both 100K and 150K steps. 91 | * Note that GPT-NeoX's [default config](https://github.com/EleutherAI/gpt-neox/blob/main/configs/2-7B.yml) for this model was modified to reduce the number of training steps (and learning rate decay steps accordingly) to 160K, down from 320K, to better match the available training resources. Hence, this model may not have reached its peak performance.
92 | * **[GPT2 - 0.4B](https://zenodo.org/record/6363556/files/0-4B-150K.tar):** A 24-layer, 1,024-dimensional Transformer model based on the [`medium` config](https://github.com/EleutherAI/gpt-neox/blob/main/configs/medium.yml), trained with 256K tokens per batch. 93 | * **[GPT2 - 160M](https://zenodo.org/record/6363556/files/160M-150K.tar):** A 12-layer, 768-dimensional Transformer model based on the [`small` config](https://github.com/EleutherAI/gpt-neox/blob/main/configs/small.yml), trained with 256K tokens per batch. 94 | 95 | ### Training Process 96 | Training was done on 4 to 8 NVIDIA RTX 8000 GPUs, largely following the standard config values, except for enabling "scaled-upper-triang-masked-softmax-fusion" and "bias-gelu-fusion" for performance and slightly changing the batch size (see [model details](#available-models)), data split (changed to 98.9%, 0.1%, 1%), initial loss scale (2^16), and print/eval intervals. 97 | 98 | The image below shows the validation loss curves of the various models during training. 99 | ![image](https://user-images.githubusercontent.com/1426353/153651075-a0ceb8ef-6207-4853-b801-40dd6172d5a6.png) 100 | 101 | ### Caveats 102 | The trained models come with a few minor known limitations: 103 | - This model was not trained to solve programming problems and may not perform well on a benchmark such as [HumanEval](https://github.com/openai/human-eval). Models like Codex (powering Copilot) are pretrained on natural language, which may boost their ability to interpret NL prompts; this model only learned language from comments in code. 104 | - The model appears to start generating a random new file once it reaches the (predicted) end of the current one. It is possible that the end-of-document token was not properly added to the training data. 105 | - Whitespace is **very important** to the model, since no preprocessing was done on the input files.
For instance, the following snippet will yield poor predictions, because in Java we would never expect an instance method at the top level, as implied by the single level of indentation (`\t`) of the two lines within this method: 106 | ``` 107 | public int getTotalWeight(List weights) {\n\t// Sum weights in parallel.\n\treturn 108 | ``` 109 | Adjusting the indentation makes it predict more reasonable continuations: 110 | ``` 111 | public int getTotalWeight(List weights) {\n\t\t// Sum weights in parallel.\n\t\treturn 112 | ``` 113 | The Codex paper discusses controlling for this to increase usability; this may be worth doing in a future version of the model. 114 | 115 | 116 | ## Datasets 117 | 118 | ### 249GB Multi-Lingual Corpus 119 | This is the corpus used to train PolyCoder. 120 | 121 | The repositories were cloned overnight on October 9-10, 2021. To mine a similar training set, see [Data](https://github.com/VHellendoorn/Code-LMs/tree/main/Data). 122 | 123 | The list of file paths can be downloaded from: [https://zenodo.org/record/6363556/files/index.zip](https://zenodo.org/record/6363556/files/index.zip). 124 | Each row in the file contains a file path along with its SHA-256 hash, to ease deduplication. In particular, the hashes allow checking whether files from any future test set were already contained in the training set. 125 | 126 | The data collection and filtering process is described in detail in [the paper](https://arxiv.org/pdf/2202.13169.pdf) and below.
The final, filtered dataset statistics are: 127 | 128 | |Language|Repositories|Size (GB)|Files| 129 | |------|-----|-----|-------| 130 | |C | 10,749 | 55 | 3,037,112 | 131 | |C# | 9,511 | 21 | 2,514,494 | 132 | |C++ | 13,726 | 52 | 4,289,506 | 133 | |Go | 12,371 | 15 | 1,416,789 | 134 | |Java | 15,044 | 41 | 5,120,129 | 135 | |JavaScript | 25,144 | 22 | 1,774,174 | 136 | |PHP | 9,960 | 13 | 1,714,058 | 137 | |Python | 25,446 | 16 | 1,550,208 | 138 | |Ruby | 5,826 | 4.1 | 674,343 | 139 | |Rust | 4,991 | 3.5 | 304,842 | 140 | |Scala | 1,497 | 1.8 | 245,100 | 141 | |TypeScript | 12,830 | 9.2 | 1,441,926 | 142 | 143 | ### Data Collection & Filtering 144 | I cloned the most popular GitHub repositories for 12 popular programming languages in October 2021, keeping repositories with at least 50 stars and stopping at ~25K repositories per language. For each project, every file written in that project's majority language was extracted. This initial, unfiltered dataset spanned 631GB and 38.9M files; after the cleaning steps described next, it yielded the training set summarized above. 145 | 146 | Next, similar to Codex and CodeParrot, very large (>1MB) and very short (<100 tokens) files were filtered out, reducing the dataset to 424GB. Files were then deduplicated based on a hash of their content, which reduced the number of files by another 30% or so, leaving 249GB of data and 24.1M files. No tokenization filters were applied; the model processes entire files including all comments. A code-specific vocabulary was constructed on a random 5% subset of the files above. 147 | 148 | ## Evaluation 149 | Please find detailed instructions for replicating our perplexity and HumanEval results on [our public fork](https://github.com/frankxu2004/gpt-neox#a-modified-version-for-polycoder-code-pretraining) of the NeoX repository. This in turn leverages [our extension](https://github.com/frankxu2004/lm-evaluation-harness) of the LM Evaluation Harness.
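HumanEval results are reported as pass@k. The code below is a minimal sketch of the unbiased pass@k estimator introduced with Codex (Chen et al., 2021). It assumes n samples were generated per problem and c of them passed the unit tests, and it treats the benchmark score as the mean over problems. Whether the evaluation-harness fork computes it exactly this way is an assumption, and `pass_at_k` is a hypothetical name.

```python
from math import comb


def pass_at_k(n: int, c: int, k: int) -> float:
    """Unbiased pass@k estimate for one problem.

    n: samples generated, c: samples that passed, k: budget (k <= n).
    Probability that at least one of k samples drawn without replacement
    from the n generated samples passes the tests.
    """
    if not 0 <= c <= n or not 1 <= k <= n:
        raise ValueError("need 0 <= c <= n and 1 <= k <= n")
    # comb(n - c, k) is 0 when fewer than k samples failed, giving 1.0.
    return 1.0 - comb(n - c, k) / comb(n, k)


# Benchmark score = mean over problems; two toy problems with n = 100 samples.
results = [(100, 5), (100, 0)]
score = sum(pass_at_k(n, c, 10) for n, c in results) / len(results)
print(score)
```

For k = 1 the estimate reduces to c / n. Using exact integer binomial coefficients from `math.comb` (Python 3.8+) avoids overflow and rounding issues from computing factorials in floating point.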
150 | 151 | ### Evaluating Codex 152 | To download the test sets that we used in the paper (12 programming languages), use: 153 | ``` 154 | wget https://zenodo.org/record/6363556/files/unseen_test_sets.tar.gz 155 | tar -xvzf unseen_test_sets.tar.gz 156 | ``` 157 | 158 | To get perplexity results on these samples using the Codex API, use: 159 | ``` 160 | export OPENAI_API_KEY= 161 | python3 -u Evaluation/eval_codex_all.py --dirs Code-sampled100 162 | ``` 163 | Set `OPENAI_API_KEY` to your private API key, which can be obtained by signing up for [OpenAI's beta](https://beta.openai.com/account/api-keys). 164 | 165 | As of **March 2022**, getting an API key is free for 3 months, and afterwards a credit card needs to be entered. However, even after entering a credit card, using our evaluation script does not lead to any costs. 166 | 167 | ### Results - HumanEval 168 | These are PolyCoder's results on the [HumanEval benchmark](https://github.com/openai/human-eval), alongside those of other models: 169 | 170 | |Model|Pass@1|Pass@10|Pass@100| 171 | |------|-----|-----|-------| 172 | |PolyCoder (160M) | 2.13% | 3.35% | 4.88% | 173 | |PolyCoder (400M) | 2.96% | 5.29% | 11.59% | 174 | |PolyCoder (2.7B) | 5.59% | 9.87% | 17.68% | 175 | | CodeParrot (110M) | 3.80% | 6.57% | 12.78% | 176 | | CodeParrot (1.5B) | 3.58% | 8.03% | 14.96% | 177 | | GPT-Neo (125M) | 0.75% | 1.88% | 2.97% | 178 | | GPT-Neo (1.3B) | 4.79% | 7.47% | 16.30% | 179 | | GPT-Neo (2.7B) | 6.41% | 11.27% | 21.37% | 180 | | GPT-J (6B) | 11.62% | 15.74% | 27.74% | 181 | | Codex (300M) | 13.17% | 20.37% | 36.27% | 182 | | Codex (2.5B) | 21.36% | 35.42% | 59.50% | 183 | | Codex (12B) | 28.81% | 46.81% | 72.31% | 184 | 185 | 186 | ### Results - Multilingual Language Modeling 187 | These are the perplexity results of PolyCoder on the [multilingual test sets](https://zenodo.org/record/6363556/files/unseen_test_sets.tar.gz): 188 | 189 | |Language| Perplexity | 190 | |------|-----| 191 | |C | 2.3464 | 192 | |C# | 2.5832 | 193 | |C++ | 2.9189 | 194 | |Go | 2.567 |
|Java | 2.9194 | 196 | |JavaScript | 3.0611 | 197 | |PHP | 3.6954 | 198 | |Python | 3.1767 | 199 | |Ruby | 3.9742 | 200 | |Rust | 3.2449 | 201 | |Scala | 3.8735 | 202 | |TypeScript | 3.6143 | 203 | 204 | A comparison with the other models is available in Figure 6 in the paper: 205 | ![image](images/fig6.png) 206 | 207 | ## Citation 208 | 209 | [A Systematic Evaluation of Large Language Models of Code](https://arxiv.org/pdf/2202.13169) 210 | 211 | ``` 212 | @article{xu2022systematic, 213 | title={A Systematic Evaluation of Large Language Models of Code}, 214 | author={Xu, Frank F and Alon, Uri and Neubig, Graham and Hellendoorn, Vincent J}, 215 | journal={arXiv preprint arXiv:2202.13169}, 216 | year={2022} 217 | } 218 | ``` 219 | -------------------------------------------------------------------------------- /images/fig6.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VHellendoorn/Code-LMs/570feba44e72a165cbd82d089fd8a9a14414d1d4/images/fig6.png --------------------------------------------------------------------------------