├── models
│   ├── generator
│   │   └── .placeholder
│   └── retriever
│       └── .placeholder
├── media
│   └── overview.png
├── generator
│   └── fid
│       ├── requirements.txt
│       ├── check_reload_model.py
│       ├── setup.py
│       ├── utils
│       │   ├── save_codet5.py
│       │   └── convert_data.py
│       ├── fid_to_reload.py
│       ├── src
│       │   ├── index.py
│       │   ├── preprocess.py
│       │   ├── evaluation.py
│       │   ├── slurm.py
│       │   ├── options.py
│       │   ├── data.py
│       │   └── util.py
│       ├── README.md
│       ├── test_reader_simple.py
│       └── train_reader.py
├── .gitignore
├── CITATION.cff
├── utils
│   ├── constants.py
│   └── util.py
├── prompts
│   ├── tldr_baseline.txt
│   ├── conala_baseline.txt
│   ├── tldr_docprompting_oracle_docs.txt
│   ├── tldr_docprompting_retrieved_docs.txt
│   ├── conala_docprompting_oracle_docs.txt
│   └── conala_docpropmting_retrieved_docs.txt
├── scripts
│   └── tldr_gpt_neo.py
├── dataset_helper
│   └── conala
│       ├── execution_eval.py
│       └── gen_metric.py
├── retriever
│   ├── bm25
│   │   ├── indexer.py
│   │   └── main.py
│   ├── simcse
│   │   ├── data_utils.py
│   │   ├── run_inference.py
│   │   ├── run_train.py
│   │   └── arguments.py
│   └── eval.py
├── requirements.txt
├── LICENSE
└── README.md

--------------------------------------------------------------------------------
/models/generator/.placeholder:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/models/retriever/.placeholder:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/media/overview.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shuyanzhou/docprompting/HEAD/media/overview.png
--------------------------------------------------------------------------------
/generator/fid/requirements.txt:
--------------------------------------------------------------------------------
numpy
torch
faiss-cpu
transformers==3.0.2
tensorboard
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
models/generator/*
models/retriever/*
!models/generator/.placeholder
!models/retriever/.placeholder
**/.DS_Store
data/.DS_Store
data/*
.idea
--------------------------------------------------------------------------------
/CITATION.cff:
--------------------------------------------------------------------------------
@article{zhou2022doccoder,
  title={DocCoder: Generating Code by Retrieving and Reading Docs},
  author={Zhou, Shuyan and Alon, Uri and Xu, Frank F and Jiang, Zhengbao and Neubig, Graham},
  journal={arXiv preprint arXiv:2207.05987},
  year={2022}
}
--------------------------------------------------------------------------------
/utils/constants.py:
--------------------------------------------------------------------------------
import os

TQDM_DISABLE = True if 'TQDM_DISABLE' in os.environ and str(os.environ['TQDM_DISABLE']) == '1' else False
WANDB_DISABLE = True if 'WANDB_DISABLE' in os.environ and str(os.environ['WANDB_DISABLE']) == '1' else False

VAR_STR = "[[VAR]]"
NONE_STR = "[[NONE]]"
--------------------------------------------------------------------------------
/prompts/tldr_baseline.txt:
--------------------------------------------------------------------------------
Potential document 0: manual: manual
# get the label of a fat32 partition
fatlabel {{/dev/sda1}}

#END

Potential document 0: manual: manual
# display information without including the login, jcpu and pcpu columns
w --short

#END

Potential document 0: manual: manual
# sort a csv file by column 9
csvsort -c {{9}} {{data.csv}}

#END

--------------------------------------------------------------------------------
/prompts/conala_baseline.txt:
--------------------------------------------------------------------------------
# convert string '2011221' into a DateTime object using format '%Y%W%w'
datetime.strptime('2011221', '%Y%W%w')

#END

# Sort a list of strings 'words' such that items starting with 's' come first.
sorted(words, key=lambda x: 'a' + x if x.startswith('s') else 'b' + x)

#END

# replace all the nan values with 0 in a pandas dataframe `df`
df.fillna(0)

#END

--------------------------------------------------------------------------------
/generator/fid/check_reload_model.py:
--------------------------------------------------------------------------------
import sys
from fid_to_reload import change_name
import os

# if it is a reload checkpoint, check its existence; if it does not exist,
# convert it from the originally saved checkpoint
def main():
    args = sys.argv[1]
    print(args)
    model_name = None
    for idx, tok in enumerate(args.split()):
        if tok == '--model_name':
            model_name = args.split()[idx+1]
            break
    assert model_name

    if 'checkpoint' in model_name:
        assert '.reload' in model_name
        path = model_name.replace(".reload", "")
        if not os.path.exists(f"{model_name}/pytorch_model.bin"):
            change_name(path, model_name)
        else:
            print("good to go")

if __name__ == "__main__":
    main()
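# Expected invocation (a sketch): the single positional argument is the
# *entire* downstream command string, from which --model_name is pulled out.
# The checkpoint path below is hypothetical.
#   python check_reload_model.py "... --model_name <ckpt_dir>/checkpoint/best_dev.reload ..."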
--------------------------------------------------------------------------------
/generator/fid/setup.py:
--------------------------------------------------------------------------------
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.


import setuptools

with open("README.md", "r") as fh:
    long_description = fh.read()

with open("requirements.txt", "r") as f:
    install_requires = list(f.read().splitlines())


setuptools.setup(
    name="FiD",
    version="0.1.0",
    description="Fusion-in-Decoder",
    long_description=long_description,
    long_description_content_type="text/markdown",
    packages=setuptools.find_packages(),
    classifiers=[
        "Programming Language :: Python :: 3",
        "License :: OSI Approved :: MIT License",
        "Operating System :: OS Independent",
    ],
    python_requires=">=3.7",
    install_requires=install_requires
)
--------------------------------------------------------------------------------
/generator/fid/utils/save_codet5.py:
--------------------------------------------------------------------------------
import json
import shutil
import transformers
assert transformers.__version__ == '4.11.3', (transformers.__version__)

tokenizer = transformers.RobertaTokenizer.from_pretrained('Salesforce/codet5-base')
tokenizer.save_pretrained('data/fid/codet5-base')

with open('data/fid/codet5-base/tokenizer_config.json', "w+") as f:
    json.dump({"model_max_length": 512}, f)

with open('data/fid/codet5-base/special_tokens_map.json', 'r') as f:
    d = json.load(f)
    d['additional_special_tokens'] = [x['content'] for x in d['additional_special_tokens']]
    # add_tokens = d.pop('additional_special_tokens')
    # for item in add_tokens:
    #     d[item['content']] = item

shutil.move('data/fid/codet5-base/special_tokens_map.json', 'data/fid/codet5-base/special_tokens_map.json.bck')
with open('data/fid/codet5-base/special_tokens_map.json', 'w+') as f:
    json.dump(d, f, indent=2)


print('save tokenizer')

t5 = transformers.T5ForConditionalGeneration.from_pretrained('Salesforce/codet5-base')
t5.save_pretrained('data/fid/codet5-base')

print('save model')
--------------------------------------------------------------------------------
/scripts/tldr_gpt_neo.py:
--------------------------------------------------------------------------------
from transformers import AutoModelForCausalLM, AutoTokenizer
import json

device = 'cuda:0'
base_model_name = 'neulab/docprompting-tldr-gpt-neo-1.3B'
tokenizer = AutoTokenizer.from_pretrained(base_model_name)
model = AutoModelForCausalLM.from_pretrained(base_model_name)
model = model.to(device)

with open('data/tldr/fid.cmd_dev.codet5.t10.json', 'r') as f:
    examples = json.load(f)

for example in examples:
    manual_list = [doc['text'] for doc in example['ctxs']]
    manual_list = "\n".join(manual_list).strip()
    nl = example['question']
    prompt = f'{tokenizer.bos_token} {manual_list} '
    prompt += f'{tokenizer.sep_token} {nl} {tokenizer.sep_token}'
    print(prompt)

    input_ids = tokenizer(prompt, return_tensors="pt").input_ids
    input_ids = input_ids.to(device)
    gen_tokens = model.generate(
        input_ids,
        num_beams=5,
        max_new_tokens=150,
        num_return_sequences=2,
        pad_token_id=tokenizer.eos_token_id
    )
    gen_tokens = gen_tokens.reshape(1, -1, gen_tokens.shape[-1])[0][0]
    gen_text = tokenizer.decode(gen_tokens)
    # parse
    gen_text = gen_text.split(tokenizer.sep_token)[2].strip().split(tokenizer.eos_token)[0].strip()
    print(gen_text)
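# Sketch (an assumption, not part of the released script): the generations
# could be accumulated and dumped for later scoring instead of only printed.
# Key names and the output path below are hypothetical.
#   results = []                                                # before the loop
#   results.append({'question': nl, 'prediction': gen_text})    # inside the loop
#   json.dump(results, open('data/tldr/gpt_neo_dev_preds.json', 'w'), indent=2)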
--------------------------------------------------------------------------------
/prompts/tldr_docprompting_oracle_docs.txt:
--------------------------------------------------------------------------------
Potential document 0: fatlabel_3: fatlabel will display or change the volume label or volume ID on the MS-DOS filesystem located on DEVICE. By default it works in label mode. It can be switched to volume ID mode with the option -i or --volume-id.
# get the label of a fat32 partition
fatlabel {{/dev/sda1}}

#END

Potential document 0: w_3: w displays information about the users currently on the machine, and their processes. The header shows, in this order, the current time, how long the system has been running, how many users are currently logged on, and the system load averages for the past 1, 5, and 15 minutes.
Potential document 1: w_9: -s, --short Use the short format. Don't print the login time, JCPU or PCPU times.
# display information without including the login, jcpu and pcpu columns
w --short

#END

Potential document 0: csvsort_2: Sort CSV files. Like the Unix “sort” command, but for tabular data:
Potential document 1: csvsort_3: usage: csvsort [-h] [-d DELIMITER] [-t] [-q QUOTECHAR] [-u {0,1,2,3}] [-b] [-p ESCAPECHAR] [-z FIELD_SIZE_LIMIT] [-e ENCODING] [-L LOCALE] [-S] [--blanks] [--date-format DATE_FORMAT] [--datetime-format DATETIME_FORMAT] [-H] [-K SKIP_LINES] [-v] [-l] [--zero] [-V] [-n] [-c COLUMNS] [-r] [-y SNIFF_LIMIT] [-I
Potential document 2: csvsort_6: optional arguments: -h, --help show this help message and exit -n, --names Display column names and indices from the input CSV and exit. -c COLUMNS, --columns COLUMNS A comma separated list of column indices, names or ranges to sort by, e.g. "1,id,3-5". Defaults to all columns. -r, --reverse Sort in descending order. -y SNIFF_LIMIT, --snifflimit SNIFF_LIMIT Limit CSV dialect sniffing to the specified number of bytes. Specify "
Potential document 3: csvsort_10: csvsort -c 9 examples/realdata/FY09_EDU_Recipients_by_State.csv
Potential document 4: csvsort_12: csvcut -c 1,9 examples/realdata/FY09_EDU_Recipients_by_State.csv | csvsort -r -c 2 | head -n 5
# sort a csv file by column 9
csvsort -c {{9}} {{data.csv}}

#END

--------------------------------------------------------------------------------
/prompts/tldr_docprompting_retrieved_docs.txt:
--------------------------------------------------------------------------------
Potential document 0: fatlabel_3: fatlabel will display or change the volume label or volume ID on the MS-DOS filesystem located on DEVICE. By default it works in label mode. It can be switched to volume ID mode with the option -i or --volume-id.
# get the label of a fat32 partition
fatlabel {{/dev/sda1}}

#END

Potential document 0: w_3: w displays information about the users currently on the machine, and their processes. The header shows, in this order, the current time, how long the system has been running, how many users are currently logged on, and the system load averages for the past 1, 5, and 15 minutes.
Potential document 1: w_9: -s, --short Use the short format. Don't print the login time, JCPU or PCPU times.
# display information without including the login, jcpu and pcpu columns
w --short

#END

Potential document 0: csvsort_2: Sort CSV files. Like the Unix “sort” command, but for tabular data:
Potential document 1: csvsort_3: usage: csvsort [-h] [-d DELIMITER] [-t] [-q QUOTECHAR] [-u {0,1,2,3}] [-b] [-p ESCAPECHAR] [-z FIELD_SIZE_LIMIT] [-e ENCODING] [-L LOCALE] [-S] [--blanks] [--date-format DATE_FORMAT] [--datetime-format DATETIME_FORMAT] [-H] [-K SKIP_LINES] [-v] [-l] [--zero] [-V] [-n] [-c COLUMNS] [-r] [-y SNIFF_LIMIT] [-I
Potential document 2: csvsort_6: optional arguments: -h, --help show this help message and exit -n, --names Display column names and indices from the input CSV and exit. -c COLUMNS, --columns COLUMNS A comma separated list of column indices, names or ranges to sort by, e.g. "1,id,3-5". Defaults to all columns. -r, --reverse Sort in descending order. -y SNIFF_LIMIT, --snifflimit SNIFF_LIMIT Limit CSV dialect sniffing to the specified number of bytes. Specify "
Potential document 3: csvsort_10: csvsort -c 9 examples/realdata/FY09_EDU_Recipients_by_State.csv
Potential document 4: csvsort_12: csvcut -c 1,9 examples/realdata/FY09_EDU_Recipients_by_State.csv | csvsort -r -c 2 | head -n 5
# sort a csv file by column 9
csvsort -c {{9}} {{data.csv}}

#END

--------------------------------------------------------------------------------
/dataset_helper/conala/execution_eval.py:
--------------------------------------------------------------------------------
"""
execution-based evaluation
"""
import argparse
import json
import sys
import evaluate
from datasets import load_metric
import numpy as np
from collections import defaultdict
import os

os.environ["HF_ALLOW_CODE_EVAL"] = "1"

def pass_at_k(result_file, unittest_file):
    with open(unittest_file, 'r') as f:
        unittests = json.load(f)

    # select the examples which have a unit test
    selected_predictions = []
    with open(result_file, 'r') as f:
        for line in f:
            pred = json.loads(line)
            if pred["question_id"] in unittests:
                selected_predictions.append(pred)
    print(f"selected {len(selected_predictions)} examples with unit test")

    # run the tests
    # load the metric from huggingface
    code_eval_metric = evaluate.load("code_eval")
    preds = []
    tests = []
    for prediction in selected_predictions:
        suffix = ""
        question_id = prediction["question_id"]
        unittest = unittests[question_id]
        entry_point = unittest["entry_point"]
        test_func = f"\n{unittest['test']}\ncheck({entry_point})"

        # wrap the generated code into a runnable function
        if isinstance(prediction['clean_code'], list):
            runnable_func = [f"{unittest['prompt']}{x}{suffix}" for x in prediction['clean_code']]
        else:
            runnable_func = [f"{unittest['prompt']}{prediction['clean_code']}{suffix}"]

        preds.append(runnable_func)
        tests.append(test_func)

    r = code_eval_metric.compute(
        predictions=preds,
        references=tests,
        k=[1, 5, 10, 50, 100, 150, 200],
        num_workers=8,
    )
    print(r[0])

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--result_file", type=str, default="")
    args = parser.parse_args()
    result_file = args.result_file
    unittest_file = "data/conala/unittest_docprompting_conala.json"
    assert result_file
    pass_at_k(result_file, unittest_file)
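# Example invocation (the path is hypothetical; the file must contain one JSON
# object per line with "question_id" and "clean_code" fields, as produced by
# the generators in this repo):
#   python dataset_helper/conala/execution_eval.py \
#       --result_file data/conala/test_results.jsonl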
--------------------------------------------------------------------------------
/generator/fid/fid_to_reload.py:
--------------------------------------------------------------------------------
"""
Convert FiD checkpoints so they can be reloaded in Huggingface style.
"""

import argparse
import shutil

import torch
import os
import json

def change_name(path, new_path=None):
    if new_path is None:
        new_path = path
    state_dict = torch.load(os.path.join(path, "pytorch_model.bin"), map_location=torch.device("cpu"))
    new_state_dict = {}
    keep = []
    change = []
    for key, param in state_dict.items():
        if key.startswith("encoder.encoder"):
            key = key.replace("encoder.encoder", "encoder")
            key = key.replace("module.layer", "layer")
            change.append(key)
        else:
            keep.append(key)
        new_state_dict[key] = param

    if not os.path.exists(new_path):
        os.makedirs(new_path)

    torch.save(new_state_dict, os.path.join(new_path, "pytorch_model.bin"))
    print(f"kept keys: {keep}")
    print(f"changed keys: {change}")
    for file in os.listdir(path):
        if file != 'pytorch_model.bin':
            shutil.copyfile(os.path.join(path, file), os.path.join(new_path, file))

    for name in ['config.json', 'special_tokens_map.json', 'vocab.json', 'merges.txt', 'tokenizer_config.json']:
        shutil.copyfile(os.path.join('data/fid/codet5-base', name), os.path.join(new_path, name))
    print("Copy tokenization files from the codet5-base folder to the target folder")

    # Change architectures in config.json
    # config = json.load(open(os.path.join(path, "config.json")))
    # for i in range(len(config["architectures"])):
    #     config["architectures"][i] = config["architectures"][i].replace("ForCL", "Model")
    # json.dump(config, open(os.path.join(new_path, "config.json"), "w"), indent=2)

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--path", type=str, help="Path of the FiD checkpoint folder")
    parser.add_argument('--new_path', type=str, help='New place to save checkpoints', default=None)
    args = parser.parse_args()
    print("FiD checkpoint -> FiD reload checkpoint for {}".format(args.path))
    change_name(args.path, args.new_path)


if __name__ == "__main__":
    main()
--------------------------------------------------------------------------------
/retriever/bm25/indexer.py:
--------------------------------------------------------------------------------
import json
from elasticsearch import Elasticsearch
from elasticsearch.helpers import bulk
from tqdm import tqdm
import logging

logging.getLogger('elasticsearch').setLevel(logging.ERROR)

class ESSearch:
    def __init__(self, index: str, source: str,
                 host_address: str='localhost',
                 re_index=False,
                 manual_path=None,
                 func_descpt_path=None):
        self.es = Elasticsearch(timeout=60, host=host_address)
        self.source = source
        self.index = f"{index}.{source}"
        self.manual_path = manual_path
        self.func_descpt_path = func_descpt_path

        if re_index:
            self.es.indices.delete(index=self.index, ignore=[400, 404])
            print(f"delete {self.index}")
            self.es.indices.create(index=self.index)
            # print(self.es.indices.get_alias().keys())
            self.create_index()

        print(f"done init the index {self.index}")
        self.es.indices.refresh(self.index)
        print(self.es.cat.count(self.index, params={"format": "json"}))

    def gendata(self):
        descpt_d = None
        if self.func_descpt_path:
            with open(self.func_descpt_path, "r") as f:
                descpt_d = json.load(f)

        with open(self.manual_path, "r") as f:
            man_d = json.load(f)

        for lib_key, lib_man in tqdm(man_d.items()):
            cmd_name = '_'.join(lib_key.split("_")[:-1]) if lib_key[-1].isdigit() else lib_key
            descpt = descpt_d[lib_key] if descpt_d is not None else ""
            result = {
                '_index': self.index,
                '_type': "_doc",
                'manual': lib_man,
                'func_description': descpt,
                'library_key': lib_key,
                'cmd_name': cmd_name
            }
            yield result

    def create_index(self):
        all_docs = list(self.gendata())
        print(bulk(self.es, all_docs, index=self.index))


    def get_topk(self, search_field, query, topk):
        real_query = {'query': {'match': {search_field: query}},
                      'size': topk + 10}
        r_mans = self.es.search(index=self.index, body=real_query)['hits']['hits'][:topk]
        _r_mans = []
        for r in r_mans:
            i = {'library_key': r['_source']['library_key'], 'score': r['_score']}
            _r_mans.append(i)
        r_mans = _r_mans
        return r_mans
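# A minimal usage sketch (index/source names and the manual path are
# hypothetical; get_topk returns a list of {'library_key', 'score'} dicts):
#   searcher = ESSearch(index='tldr', source='manual', re_index=True,
#                       manual_path='data/tldr/manual_sections.json')
#   hits = searcher.get_topk('manual', 'sort a csv file by column', topk=10)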
--------------------------------------------------------------------------------
/prompts/conala_docprompting_oracle_docs.txt:
--------------------------------------------------------------------------------
Potential document 0: python datetime datetime strptime: classmethod datetime.strptime(date_string, format) Return a datetime corresponding to date_string, parsed according to format. This is equivalent to: datetime(*(time.strptime(date_string, format)[0:6])) ValueError is raised if the date_string and format can’t be parsed by time.strptime() or if it returns a value which isn’t a time tuple. For a complete list of formatting directives, see strftime() and strptime() Behavior.
# convert string '2011221' into a DateTime object using format '%Y%W%w'
datetime.strptime('2011221', '%Y%W%w')

#END

Potential document 0: python sorted: sorted(iterable, *, key=None, reverse=False) Return a new sorted list from the items in iterable. Has two optional arguments which must be specified as keyword arguments. key specifies a function of one argument that is used to extract a comparison key from each element in iterable (for example, key=str.lower). The default value is None (compare the elements directly). reverse is a boolean value. If set to True, then the list elements are sorted as if each comparison were reversed. Use functools.cmp_to_key() to convert an old-style cmp function to a key function. The built-in sorted() function is guaranteed to be stable. A sort
Potential document 1: python str startswith: str.startswith(prefix[, start[, end]]) Return True if string starts with the prefix, otherwise return False. prefix can also be a tuple of prefixes to look for. With optional start, test string beginning at that position. With optional end, stop comparing string at that position.
# Sort a list of strings 'words' such that items starting with 's' come first.
sorted(words, key=lambda x: 'a' + x if x.startswith('s') else 'b' + x)

#END

Potential document 0: pandas dataframe fillna: pandas.DataFrame.fillna DataFrame.fillna(value=None, method=None, axis=None, inplace=False, limit=None, downcast=None)[source] Fill NA/NaN values using the specified method. Parameters value:scalar, dict, Series, or DataFrame Value to use to fill holes (e.g. 0), alternately a dict/Series/DataFrame of values specifying which value to use for each index (for a Series) or column (for a DataFrame). Values not in the dict/Series/DataFrame will not be filled. This value cannot be a list.
Potential document 1: pandas dataframe loc: pandas.DataFrame.loc propertyDataFrame.loc Access a group of rows and columns by label(s) or a boolean array..loc[] is primarily label based, but may also be used with a boolean array. Allowed inputs are: A single label, e.g. 5 or 'a', (note that 5 is interpreted as a label of the index, and never as an integer position along the index). A list or array of labels, e.g. ['a', 'b', 'c']. A slice object with labels, e.g. 'a':'f'. Warning Note that contrary to usual python slices, both the start and the stop
# replace all the nan values with 0 in a pandas dataframe `df`
df.fillna(0)

#END

--------------------------------------------------------------------------------
/generator/fid/src/index.py:
--------------------------------------------------------------------------------
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import os
import logging
import pickle
from typing import List, Tuple

import faiss
import numpy as np
from tqdm import tqdm

logger = logging.getLogger()

class Indexer(object):

    def __init__(self, vector_sz, n_subquantizers=0, n_bits=8):
        if n_subquantizers > 0:
            self.index = faiss.IndexPQ(vector_sz, n_subquantizers, n_bits, faiss.METRIC_INNER_PRODUCT)
        else:
            self.index = faiss.IndexFlatIP(vector_sz)
        self.index_id_to_db_id = np.empty((0), dtype=np.int64)

    def index_data(self, ids, embeddings):
        self._update_id_mapping(ids)
        embeddings = embeddings.astype('float32')
        if not self.index.is_trained:
            self.index.train(embeddings)
        self.index.add(embeddings)

        logger.info(f'Total data indexed {len(self.index_id_to_db_id)}')

    def search_knn(self, query_vectors: np.array, top_docs: int, index_batch_size=1024) -> List[Tuple[List[object], List[float]]]:
        query_vectors = query_vectors.astype('float32')
        result = []
        nbatch = (len(query_vectors)-1) // index_batch_size + 1
        for k in tqdm(range(nbatch)):
            start_idx = k*index_batch_size
            end_idx = min((k+1)*index_batch_size, len(query_vectors))
            q = query_vectors[start_idx: end_idx]
            scores, indexes = self.index.search(q, top_docs)
            # convert to external ids
            db_ids = [[str(self.index_id_to_db_id[i]) for i in query_top_idxs] for query_top_idxs in indexes]
            result.extend([(db_ids[i], scores[i]) for i in range(len(db_ids))])
        return result

    def serialize(self, dir_path):
        index_file = dir_path / 'index.faiss'
        meta_file = dir_path / 'index_meta.dpr'
        logger.info(f'Serializing index to {index_file}, meta data to {meta_file}')

        faiss.write_index(self.index, index_file)
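        # note: faiss.write_index expects a plain string filename; if passing
        # the pathlib.Path here raises a TypeError, wrap it as str(index_file)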
        with open(meta_file, mode='wb') as f:
            pickle.dump(self.index_id_to_db_id, f)

    def deserialize_from(self, dir_path):
        index_file = dir_path / 'index.faiss'
        meta_file = dir_path / 'index_meta.dpr'
        logger.info(f'Loading index from {index_file}, meta data from {meta_file}')

        self.index = faiss.read_index(index_file)
        logger.info('Loaded index of type %s and size %d', type(self.index), self.index.ntotal)

        with open(meta_file, "rb") as reader:
            self.index_id_to_db_id = pickle.load(reader)
        assert len(
            self.index_id_to_db_id) == self.index.ntotal, 'Deserialized index_id_to_db_id should match faiss index size'

    def _update_id_mapping(self, db_ids: List):
        new_ids = np.array(db_ids, dtype=np.int64)
        self.index_id_to_db_id = np.concatenate((self.index_id_to_db_id, new_ids), axis=0)
--------------------------------------------------------------------------------
/retriever/simcse/data_utils.py:
--------------------------------------------------------------------------------
import torch
from typing import Dict
from dataclasses import dataclass, field

@dataclass
class OurDataCollatorWithPadding:
    def __init__(self, pad_token_id, idf_dict):
        self.pad_token_id = pad_token_id
        self.idf_dict = idf_dict

    def padding_func(self, arr, pad_token, dtype=torch.long):
        lens = torch.LongTensor([len(a) for a in arr])
        max_len = lens.max().item()
        padded = torch.ones(len(arr), max_len, dtype=dtype) * pad_token
        mask = torch.zeros(len(arr), max_len, dtype=torch.long)
        for i, a in enumerate(arr):
            padded[i, : lens[i]] = torch.tensor(a, dtype=dtype)
            mask[i, : lens[i]] = 1
        return padded, lens, mask

    def negative_sample_mask(self, target_sent, dtype=torch.bool):
        mask = torch.ones((len(target_sent), len(target_sent)), dtype=dtype)
        for i in range(len(target_sent)):
            s1 = target_sent[i]
            for j in range(len(target_sent)):
                s2 = target_sent[j]
                if i != j and s1 == s2:
                    mask[i, j] = 0
        return mask

    def __call__(self, batch) -> Dict[str, torch.Tensor]:
        bs = len(batch)
        assert bs
        num_sent = len(batch[0]['input_ids'])
        pad_token = self.pad_token_id

        flat_input_ids = []
        for sample in batch:
            for i in range(num_sent):
                flat_input_ids.append(sample['input_ids'][i])

        flat_idf_weights = []
        for input_ids in flat_input_ids:
            cur_idf_weights = [self.idf_dict[id] for id in input_ids]
            flat_idf_weights.append(cur_idf_weights)

        padded, lens, mask = self.padding_func(flat_input_ids, pad_token, dtype=torch.long)
        padded_idf, _, _ = self.padding_func(flat_idf_weights, 0, dtype=torch.float)
        assert padded.shape == padded_idf.shape

        target_sent = []
        for sample in batch:
            target_sent.append(sample['plain_text'][1])
        negative_sample_mask = self.negative_sample_mask(target_sent)

        # padded = padded.to(device=device)
        # mask = mask.to(device=device)
        # lens = lens.to(device=device)
        # return padded, padded_idf, lens, mask
        return {'input_ids': padded, 'attention_mask': mask, 'negative_sample_mask': negative_sample_mask,
                'lengths': lens, 'input_idf': padded_idf, 'num_sent': num_sent}

def tok_sentences(tokenizer, sentences, has_hard_neg, total, max_length=None):
    sent_features = tokenizer(
        sentences,
        add_special_tokens=True,
        # add_prefix_space=True,
        max_length=tokenizer.model_max_length if max_length is None else max_length,
        truncation=True
    )

    features = {}
    if has_hard_neg:
        for key in sent_features:
            features[key] = [[sent_features[key][i], sent_features[key][i + total], sent_features[key][i + total * 2]]
                             for i in range(total)]

    else:
        for key in sent_features:
            features[key] = [[sent_features[key][i], sent_features[key][i + total]] for i in range(total)]
    # get the plain text
    features['plain_text'] = []
    for i in range(total):
        features['plain_text'].append([sentences[i], sentences[i + total]])

    return features
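# Illustration (hypothetical batch): if the three target sentences are
# ['a', 'b', 'a'], negative_sample_mask yields
#   [[1, 1, 0],
#    [1, 1, 1],
#    [0, 1, 1]]
# so duplicated targets are never treated as in-batch negatives of each other.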
--------------------------------------------------------------------------------
/generator/fid/README.md:
--------------------------------------------------------------------------------
This repository contains code for:
- Fusion-in-Decoder models
- Distilling Knowledge from Reader to Retriever

## Dependencies

- Python 3
- [PyTorch](http://pytorch.org/) (currently tested on version 1.6.0)
- [Transformers](http://huggingface.co/transformers/) (**version 3.0.2**, unlikely to work with a different version)
- [NumPy](http://www.numpy.org/)


# Data

### Data format

The expected data format is a list of entry examples, where each entry example is a dictionary containing
- `id`: example id, optional
- `question`: question text
- `target`: answer used for model training; if not given, the target is randomly sampled from the `answers` list
- `answers`: list of answer texts for evaluation, also used for training if `target` is not given
- `ctxs`: a list of passages where each item is a dictionary containing
        - `title`: article title
        - `text`: passage text

Entry example:
```
{
    'id': '0',
    'question': 'What element did Marie Curie name after her native land?',
    'target': 'Polonium',
    'answers': ['Polonium', 'Po (chemical element)', 'Po'],
    'ctxs': [
        {
            "title": "Marie Curie",
            "text": "them on visits to Poland. She named the first chemical element that she discovered in 1898 \"polonium\", after her native country. Marie Curie died in 1934, aged 66, at a sanatorium in Sancellemoz (Haute-Savoie), France, of aplastic anemia from exposure to radiation in the course of her scientific research and in the course of her radiological work at field hospitals during World War I. Maria Sk\u0142odowska was born in Warsaw, in Congress Poland in the Russian Empire, on 7 November 1867, the fifth and youngest child of well-known teachers Bronis\u0142awa, \"n\u00e9e\" Boguska, and W\u0142adys\u0142aw Sk\u0142odowski. The elder siblings of Maria"
        },
        {
            "title": "Marie Curie",
            "text": "was present in such minute quantities that they would eventually have to process tons of the ore. In July 1898, Curie and her husband published a joint paper announcing the existence of an element which they named \"polonium\", in honour of her native Poland, which would for another twenty years remain partitioned among three empires (Russian, Austrian, and Prussian). On 26 December 1898, the Curies announced the existence of a second element, which they named \"radium\", from the Latin word for \"ray\". In the course of their research, they also coined the word \"radioactivity\". To prove their discoveries beyond any"
        }
    ]
}
```
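As a concrete (unofficial) sketch, a file in this format can be consumed like this — `load_examples` and the path argument are illustrative, not part of this codebase:

```
import json

def load_examples(path, n_context=10):
    # each entry follows the schema above; keep only the top-n_context passages
    with open(path) as f:
        examples = json.load(f)
    for ex in examples:
        ex['ctxs'] = ex['ctxs'][:n_context]
        # the real loader samples 'target' from 'answers' when it is absent;
        # taking the first answer here just keeps the sketch deterministic
        if 'target' not in ex and ex.get('answers'):
            ex['target'] = ex['answers'][0]
    return examples

examples = load_examples('open_domain_data/NQ/dev.json', n_context=10)
```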
## References

[1] G. Izacard, E. Grave [*Leveraging Passage Retrieval with Generative Models for Open Domain Question Answering*](https://arxiv.org/abs/2007.01282)

```bibtex
@misc{izacard2020leveraging,
      title={Leveraging Passage Retrieval with Generative Models for Open Domain Question Answering},
      author={Gautier Izacard and Edouard Grave},
      year={2020},
      eprint={2007.01282},
      archivePrefix={arXiv},
      primaryClass={cs.CL}
}
```

[2] G. Izacard, E. Grave [*Distilling Knowledge from Reader to Retriever for Question Answering*](https://arxiv.org/abs/2012.04584)

```bibtex
@misc{izacard2020distilling,
      title={Distilling Knowledge from Reader to Retriever for Question Answering},
      author={Gautier Izacard and Edouard Grave},
      year={2020},
      eprint={2012.04584},
      archivePrefix={arXiv},
      primaryClass={cs.CL}
}
```

## License

See the [LICENSE](LICENSE) file for more details.
--------------------------------------------------------------------------------
/generator/fid/src/preprocess.py:
--------------------------------------------------------------------------------
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import sys
import json
from pathlib import Path
import numpy as np
import util

def select_examples_TQA(data, index, passages, passages_index):
    selected_data = []
    for i, k in enumerate(index):
        ex = data[k]
        q = ex['Question']
        answers = ex['Answer']['Aliases']
        target = ex['Answer']['Value']

        ctxs = [
            {
                'id': idx,
                'title': passages[idx][1],
                'text': passages[idx][0],
            }
            for idx in passages_index[ex['QuestionId']]
        ]

        if target.isupper():
            target = target.title()
        selected_data.append(
            {
                'question': q,
                'answers': answers,
                'target': target,
                'ctxs': ctxs,
            }
        )
    return selected_data

def select_examples_NQ(data, index, passages, passages_index):
    selected_data = []
    for i, k in enumerate(index):
        ctxs = [
            {
                'id': idx,
                'title': passages[idx][1],
                'text': passages[idx][0],
            }
            for idx in passages_index[str(i)]
        ]
        dico = {
            'question': data[k]['question'],
            'answers': data[k]['answer'],
            'ctxs': ctxs,
        }
        selected_data.append(dico)

    return selected_data

if __name__ == "__main__":
    dir_path = Path(sys.argv[1])
    save_dir = Path(sys.argv[2])

    passages = util.load_passages(save_dir/'psgs_w100.tsv')
    passages = {p[0]: (p[1], p[2]) for p in passages}

    # load NQ question idx
    NQ_idx = {}
    NQ_passages = {}
    for split in ['train', 'dev', 'test']:
        with open(dir_path/('NQ.' + split + '.idx.json'), 'r') as fin:
            NQ_idx[split] = json.load(fin)
        with open(dir_path/'nq_passages' / (split + '.json'), 'r') as fin:
            NQ_passages[split] = json.load(fin)


    originaltrain, originaldev = [], []
    with open(dir_path/'NQ-open.dev.jsonl') as fin:
        for k, example in enumerate(fin):
            example = json.loads(example)
            originaldev.append(example)

    with open(dir_path/'NQ-open.train.jsonl') as fin:
        for k, example in enumerate(fin):
            example = json.loads(example)
            originaltrain.append(example)

    NQ_train = select_examples_NQ(originaltrain, NQ_idx['train'], passages, NQ_passages['train'])
    NQ_dev = select_examples_NQ(originaltrain, NQ_idx['dev'], passages, NQ_passages['dev'])
    NQ_test = select_examples_NQ(originaldev, NQ_idx['test'], passages, NQ_passages['test'])

    NQ_save_path = save_dir / 'NQ'
    NQ_save_path.mkdir(parents=True, exist_ok=True)

    with open(NQ_save_path/'train.json', 'w') as fout:
        json.dump(NQ_train, fout, indent=4)
    with open(NQ_save_path/'dev.json', 'w') as fout:
        json.dump(NQ_dev, fout, indent=4)
    with open(NQ_save_path/'test.json', 'w') as fout:
        json.dump(NQ_test, fout, indent=4)

    # load Trivia question idx
    TQA_idx, TQA_passages = {}, {}
    for split in ['train', 'dev', 'test']:
        with open(dir_path/('TQA.' + split + '.idx.json'), 'r') as fin:
            TQA_idx[split] = json.load(fin)
        with open(dir_path/'tqa_passages' / (split + '.json'), 'r') as fin:
            TQA_passages[split] = json.load(fin)


    originaltrain, originaldev = [], []
    with open(dir_path/'triviaqa-unfiltered'/'unfiltered-web-train.json') as fin:
        originaltrain = json.load(fin)['Data']

    with open(dir_path/'triviaqa-unfiltered'/'unfiltered-web-dev.json') as fin:
        originaldev = json.load(fin)['Data']

    TQA_train = select_examples_TQA(originaltrain, TQA_idx['train'], passages, TQA_passages['train'])
    TQA_dev = select_examples_TQA(originaltrain, TQA_idx['dev'], passages, TQA_passages['dev'])
    TQA_test = select_examples_TQA(originaldev, TQA_idx['test'], passages, TQA_passages['test'])

    TQA_save_path = save_dir / 'TQA'
    TQA_save_path.mkdir(parents=True, exist_ok=True)

    with open(TQA_save_path/'train.json', 'w') as fout:
        json.dump(TQA_train, fout, indent=4)
    with open(TQA_save_path/'dev.json', 'w') as fout:
        json.dump(TQA_dev, fout, indent=4)
    with open(TQA_save_path/'test.json', 'w') as fout:
        json.dump(TQA_test, fout, indent=4)
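# Example invocation (a sketch; directory names are hypothetical — sys.argv[1]
# must hold the raw NQ/TriviaQA downloads and index files, and sys.argv[2]
# must already contain psgs_w100.tsv):
#   python src/preprocess.py open_domain_data/raw open_domain_data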
--------------------------------------------------------------------------------
/generator/fid/src/evaluation.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import collections
import logging
import regex
import string
import unicodedata
from functools import partial
from multiprocessing import Pool as ProcessPool
from typing import Tuple, List, Dict
import numpy as np

"""
Evaluation code from DPR: https://github.com/facebookresearch/DPR
"""

class SimpleTokenizer(object):
    ALPHA_NUM = r'[\p{L}\p{N}\p{M}]+'
    NON_WS = r'[^\p{Z}\p{C}]'

    def __init__(self):
        """
        Args:
            annotators: None or empty set (only tokenizes).
        """
        self._regexp = regex.compile(
            '(%s)|(%s)' % (self.ALPHA_NUM, self.NON_WS),
            flags=regex.IGNORECASE + regex.UNICODE + regex.MULTILINE
        )

    def tokenize(self, text, uncased=False):
        matches = [m for m in self._regexp.finditer(text)]
        if uncased:
            tokens = [m.group().lower() for m in matches]
        else:
            tokens = [m.group() for m in matches]
        return tokens

logger = logging.getLogger(__name__)

QAMatchStats = collections.namedtuple('QAMatchStats', ['top_k_hits', 'questions_doc_hits'])

def calculate_matches(data: List, workers_num: int):
    """
    Evaluates answers presence in the set of documents. This function is supposed to be used with a large collection of
    documents and results. It internally forks multiple sub-processes for evaluation and then merges results.
    :param data: list of examples, each a dict with an 'answers' list and the retrieved 'ctxs'
    :param workers_num: amount of parallel threads to process data
    :return: matching information tuple.
    top_k_hits - a list where the index is the amount of top documents retrieved and the value is the total amount of
    valid matches across an entire dataset.
    questions_doc_hits - more detailed info with answer matches for every question and every retrieved document
    """

    logger.info('Matching answers in top docs...')

    tokenizer = SimpleTokenizer()
    get_score_partial = partial(check_answer, tokenizer=tokenizer)

    processes = ProcessPool(processes=workers_num)
    scores = processes.map(get_score_partial, data)

    logger.info('Per question validation results len=%d', len(scores))

    n_docs = len(data[0]['ctxs'])
    top_k_hits = [0] * n_docs
    for question_hits in scores:
        best_hit = next((i for i, x in enumerate(question_hits) if x), None)
        if best_hit is not None:
            top_k_hits[best_hit:] = [v + 1 for v in top_k_hits[best_hit:]]

    return QAMatchStats(top_k_hits, scores)

def check_answer(example, tokenizer) -> List[bool]:
    """Search through all the top docs to see if they have any of the answers."""
    answers = example['answers']
    ctxs = example['ctxs']

    hits = []

    for i, doc in enumerate(ctxs):
        text = doc['text']

        if text is None:  # cannot find the document for some reason
            logger.warning("no doc in db")
            hits.append(False)
            continue

        hits.append(has_answer(answers, text, tokenizer))

    return hits

def has_answer(answers, text, tokenizer) -> bool:
    """Check if a document contains an answer string."""
    text = _normalize(text)
    text = tokenizer.tokenize(text, uncased=True)

    for answer in answers:
        answer = _normalize(answer)
        answer = tokenizer.tokenize(answer, uncased=True)
        for i in range(0, len(text) - len(answer) + 1):
            if answer == text[i: i + len(answer)]:
                return True
    return False

#################################################
########        READER EVALUATION        ########
#################################################

def _normalize(text):
    return unicodedata.normalize('NFD', text)

# Normalization from SQuAD evaluation script https://worksheets.codalab.org/rest/bundles/0x6b567e1cf2e041ec80d7098f031c5c9e/contents/blob/
def normalize_answer(s):
    def remove_articles(text):
        return regex.sub(r'\b(a|an|the)\b', ' ', text)

    def white_space_fix(text):
        return ' '.join(text.split())

    def remove_punc(text):
        exclude = set(string.punctuation)
        return ''.join(ch for ch in text if ch not in exclude)

    def lower(text):
        return text.lower()

    return white_space_fix(remove_articles(remove_punc(lower(s))))

def exact_match_score(prediction, ground_truth):
    return normalize_answer(prediction) == normalize_answer(ground_truth)

def ems(prediction, ground_truths):
    return max([exact_match_score(prediction, gt) for gt in ground_truths])
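# For instance, normalize_answer("The Answer!") -> "answer", so
# ems("answer", ["The Answer!", "other"]) evaluates to True.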
####################################################
########        RETRIEVER EVALUATION        ########
####################################################

def eval_batch(scores, inversions, avg_topk, idx_topk):
    for k, s in enumerate(scores):
        s = s.cpu().numpy()
        sorted_idx = np.argsort(-s)
        score(sorted_idx, inversions, avg_topk, idx_topk)

def count_inversions(arr):
    inv_count = 0
    lenarr = len(arr)
    for i in range(lenarr):
        for j in range(i + 1, lenarr):
            if (arr[i] > arr[j]):
                inv_count += 1
    return inv_count

def score(x, inversions, avg_topk, idx_topk):
    x = np.array(x)
    inversions.append(count_inversions(x))
    for k in avg_topk:
        # ratio of passages in the predicted top-k that are
        # also in the topk given by gold score
        avg_pred_topk = (x[:k] < k).mean()
        avg_topk[k].append(avg_pred_topk)

[The remainder of evaluation.py and the beginning of /generator/fid/src/slurm.py are missing from this dump; the surviving fragment of slurm.py resumes mid-function below.]

--------------------------------------------------------------------------------
/generator/fid/src/slurm.py:
--------------------------------------------------------------------------------
    params.multi_node = params.n_nodes > 1
    params.multi_gpu = params.world_size > 1

    # summary
    PREFIX = "%i - " % params.global_rank

    # set GPU device
    if params.is_distributed:
        torch.cuda.set_device(params.local_rank)
        device = torch.device("cuda", params.local_rank)
    else:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    params.device = device

    # summary
    PREFIX = "%i - " % params.global_rank
    print(PREFIX + "Number of nodes: %i" % params.n_nodes)
    print(PREFIX + "Node ID        : %i" % params.node_id)
    print(PREFIX + "Local rank     : %i" % params.local_rank)
    print(PREFIX + "Global rank    : %i" % params.global_rank)
    print(PREFIX + "World size     : %i" % params.world_size)
    print(PREFIX + "GPUs per node  : %i" % params.n_gpu_per_node)
    print(PREFIX + "Multi-node     : %s" % str(params.multi_node))
    print(PREFIX + "Multi-GPU      : %s" % str(params.multi_gpu))
    print(PREFIX + "Hostname       : %s" % socket.gethostname())

    # initialize multi-GPU
    if params.is_distributed:

        # http://pytorch.apachecn.org/en/0.3.0/distributed.html#environment-variable-initialization
        # 'env://' will read these environment variables:
        # MASTER_PORT - required; has to be a free port on machine with rank 0
        # MASTER_ADDR - required (except for rank 0); address of rank 0 node
        # WORLD_SIZE - required; can be set either here, or in a call to init function
        # RANK - required; can be set either here, or in a call to init function

        #print("Initializing PyTorch distributed ...")
        torch.distributed.init_process_group(
            init_method='env://',
            backend='nccl',
        )
--------------------------------------------------------------------------------
/generator/fid/test_reader_simple.py:
--------------------------------------------------------------------------------
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import json
import os

import torch
import transformers
import numpy as np
from pathlib import Path
import torch.distributed as dist
from torch.utils.data import DataLoader, SequentialSampler
from dataset_helper.conala.gen_metric import _bleu as conala_bleu
from dataset_helper.tldr.gen_metric import tldr_metrics
from tqdm import tqdm

import src.slurm
import src.util
from src.options import Options
import src.data
import src.model

TQDM_DISABLED = os.environ['TQDM_DISABLED'] if 'TQDM_DISABLED' in os.environ else False

def evaluate(model, dataset, dataloader, tokenizer, opt):
    loss, curr_loss = 0.0, 0.0
    model.eval()
    if hasattr(model, "module"):
        model = model.module
    if opt.write_crossattention_scores:
        model.overwrite_forward_crossattention()
        model.reset_score_storage()
    total = 0

    with torch.no_grad():
        result_d = []

        with open(f"{opt.checkpoint_path}/gold.gold", "w+") as fg, \
                open(f'{opt.checkpoint_path}/pred.pred', 'w+') as fp, \
                open(opt.result_file, 'w+') as fr:
            for i, batch in enumerate(tqdm(dataloader, disable=TQDM_DISABLED)):
                (idx, _, _, context_ids, context_mask) = batch

                if opt.write_crossattention_scores:
                    model.reset_score_storage()

                outputs = model.generate(
                    input_ids=context_ids.cuda(),
                    attention_mask=context_mask.cuda(),
                    max_length=150,
                    lenpen=opt.lenpen,
                    num_beams=opt.num_beams,
                    temperature=opt.temperature,
                    top_p=opt.top_p,
                    num_return_sequences=opt.num_return_sequences,
                )
                if opt.num_return_sequences == 1:
                    for k, o in enumerate(outputs):
                        ans = tokenizer.decode(o, skip_special_tokens=False)
                        gold = dataset.get_example(idx[k])['target']
                        ans = ans.replace("{{", " {{").replace("\n", ' ').replace("\r", "").replace("<pad>", "").replace("</s>", "").replace("<s>", "").strip()
                        ans = " ".join(ans.split())
                        gold = gold.replace("\n", ' ')
                        fg.write(f"{gold}\n")
                        fp.write(f"{ans}\n")
                        cur_result = {'question_id': dataset.get_example(idx[k])['id'], 'gold': gold, 'clean_code': ans}
                        result_d.append(cur_result)
                        total += 1
                        fr.write(json.dumps(cur_result) + "\n")
                else:
                    outputs = outputs.view(-1, opt.num_return_sequences, outputs.size(-1))
                    for k, o in enumerate(outputs):
                        ans_list = []
                        gold = dataset.get_example(idx[k])['target']
                        gold = gold.replace("\n", ' ')
                        for j, oj in enumerate(o):
                            ans = tokenizer.decode(oj, skip_special_tokens=False)
                            ans = ans.replace("{{", " {{").replace("\n", ' ').replace("\r", "").replace("<pad>", "").replace("</s>", "").replace("<s>", "").strip()
                            ans = " ".join(ans.split())
                            ans_list.append(ans)
                        cur_result = {'question_id': dataset.get_example(idx[k])['id'], 'gold': gold, 'clean_code': ans_list}
                        result_d.append(cur_result)
                        total += 1
                        fr.write(json.dumps(cur_result) + "\n")


    if opt.num_return_sequences == 1:
        if opt.eval_metric == 'bleu':
            score = conala_bleu(
                f"{opt.checkpoint_path}/gold.gold",
                f"{opt.checkpoint_path}/pred.pred",
                smooth=False, code_tokenize=True)
            score = {'bleu': score}

        elif opt.eval_metric == 'token_f1':
            score = tldr_metrics(
                f"{opt.checkpoint_path}/gold.gold",
                f"{opt.checkpoint_path}/pred.pred",
            )
        else:
            raise NotImplementedError
    else:
        score = 0

    return score, total

if __name__ == "__main__":
    options = Options()
    options.add_reader_options()
    options.add_eval_options()
    opt = options.parse()
    src.slurm.init_distributed_mode(opt)
    src.slurm.init_signal_handler()
    opt.train_batch_size = opt.per_gpu_batch_size * max(1, opt.world_size)
    opt.checkpoint_path = Path(opt.checkpoint_dir) / opt.name
    opt.result_file = Path(opt.checkpoint_dir) / opt.name / f'test_results_{opt.result_tag}.json'

    dir_path = Path(opt.checkpoint_dir) / opt.name
    directory_exists = dir_path.exists()

    if opt.is_distributed:
        torch.distributed.barrier()
    dir_path.mkdir(parents=True, exist_ok=True)

    logger = src.util.init_logger(opt.is_main, opt.is_distributed, Path(opt.checkpoint_dir) / opt.name / 'run.log')

    if not directory_exists and opt.is_main:
        options.print_options(opt)

    if 'codet5' in opt.tokenizer_name:
        logger.info(f'load the tokenizer from codet5')
        tokenizer = transformers.RobertaTokenizer.from_pretrained(opt.tokenizer_name)
    else:
        logger.info(f'load the tokenizer from t5')
        tokenizer = transformers.T5Tokenizer.from_pretrained(opt.tokenizer_name)

    if opt.dataset == 'tldr':
        special_tokens_dict = {'additional_special_tokens': ['{{', '}}']}
        num_added_toks = tokenizer.add_special_tokens(special_tokens_dict)

    collator_function = src.data.Collator(opt.text_maxlength, tokenizer)
    eval_examples = src.data.load_data(
        opt.eval_data,
        global_rank=opt.global_rank,
        # use the global rank and world size attributes to split the eval set on multiple gpus
        world_size=opt.world_size
    )
    eval_dataset = src.data.Dataset(
        eval_examples,
        opt.n_context,
    )

    eval_sampler = SequentialSampler(eval_dataset)
    eval_dataloader = DataLoader(
        eval_dataset,
        sampler=eval_sampler,
        batch_size=opt.per_gpu_batch_size,
        num_workers=20,
        collate_fn=collator_function
    )

    model_class = src.model.FiDT5
    model = model_class.from_pretrained(opt.model_path)
    model = model.to(opt.device)

    logger.info("Start eval")
    score, total = evaluate(model, eval_dataset, eval_dataloader, tokenizer, opt)

    logger.info(f'Total number of example {total}')
    logger.info(json.dumps(score, indent=2))
0 is a legal argument for any position in the time tuple; if it is normally illegal the value is forced to a correct one. The following directives can be embedded in the format string. They are shown without the optional field width and precision specification, and are replaced by the indicated characters in the strftime() result: 3 | Potential document 2: python time strptime: time.strptime(string[, format]) Parse a string representing a time according to a format. The return value is a struct_time as returned by gmtime() or localtime(). The format parameter uses the same directives as those used by strftime(); it defaults to "%a %b %d %H:%M:%S %Y" which matches the formatting returned by ctime(). If string cannot be parsed according to format, or if it has excess data after parsing, ValueError is raised. The default values used to fill in any missing data when more accurate values cannot be inferred are (1900, 1, 1, 0, 0, 0, 0, 1, -1). Both 4 | Potential document 3: python datetime datetime strftime: datetime.strftime(format) Return a string representing the date and time, controlled by an explicit format string. For a complete list of formatting directives, see strftime() and strptime() Behavior. 5 | Potential document 4: python datetime date strftime: date.strftime(format) Return a string representing the date, controlled by an explicit format string. Format codes referring to hours, minutes or seconds will see 0 values. For a complete list of formatting directives, see strftime() and strptime() Behavior. 6 | # convert string '2011221' into a DateTime object using format '%Y%W%w' 7 | datetime.strptime('2011221', '%Y%W%w') 8 | 9 | #END 10 | 11 | Potential document 0: python str rsplit: str.rsplit(sep=None, maxsplit=-1) Return a list of the words in the string, using sep as the delimiter string. If maxsplit is given, at most maxsplit splits are done, the rightmost ones. If sep is not specified or None, any whitespace string is a separator. Except for splitting from the right, rsplit() behaves like split() which is described in detail below. 12 | Potential document 1: python sorted: sorted(iterable, *, key=None, reverse=False) Return a new sorted list from the items in iterable. Has two optional arguments which must be specified as keyword arguments. key specifies a function of one argument that is used to extract a comparison key from each element in iterable (for example, key=str.lower). The default value is None (compare the elements directly). reverse is a boolean value. If set to True, then the list elements are sorted as if each comparison were reversed. Use functools.cmp_to_key() to convert an old-style cmp function to a key function. The built-in sorted() function is guaranteed to be stable. A sort 13 | Potential document 2: python list sort: sort(*, key=None, reverse=False) This method sorts the list in place, using only < comparisons between items. Exceptions are not suppressed - if any comparison operations fail, the entire sort operation will fail (and the list will likely be left in a partially modified state). sort() accepts two arguments that can only be passed by keyword (keyword-only arguments): key specifies a function of one argument that is used to extract a comparison key from each list element (for example, key=str.lower). The key corresponding to each item in the list is calculated once and then used for the entire sorting process. The default value of None means that list items are sorted directly without calculating a separate key value. 
The 14 | Potential document 3: python operator itemgetter: operator.itemgetter(item) operator.itemgetter(*items) Return a callable object that fetches item from its operand using the operand’s __getitem__() method. If multiple items are specified, returns a tuple of lookup values. For example: After f = itemgetter(2), the call f(r) returns r[2]. After g = itemgetter(2, 5, 3), the call g(r) returns (r[2], r[5], r[3]). Equivalent to: def itemgetter(*items): if len(items) == 1: 15 | Potential document 4: python str rfind: str.rfind(sub[, start[, end]]) Return the highest index in the string where substring sub is found, such that sub is contained within s[start:end]. Optional arguments start and end are interpreted as in slice notation. Return -1 on failure. 16 | # Sort a list of strings 'words' such that items starting with 's' come first. 17 | sorted(words, key=lambda x: 'a' + x if x.startswith('s') else 'b' + x) 18 | 19 | #END 20 | 21 | Potential document 0: pandas series dropna: pandas.Series.dropna Series.dropna(axis=0, inplace=False, how=None)[source] Return a new Series with missing values removed. See the User Guide for more on which values are considered missing, and how to work with missing data. Parameters axis:{0 or ‘index’}, default 0 There is only one axis to drop values from. inplace:bool, default False If True, do operation inplace and return None. how:str, optional Not in use. Kept for compatibility. Returns Series or None Series with NA entries 22 | Potential document 1: pandas dataframe fillna: pandas.DataFrame.fillna DataFrame.fillna(value=None, method=None, axis=None, inplace=False, limit=None, downcast=None)[source] Fill NA/NaN values using the specified method. Parameters value:scalar, dict, Series, or DataFrame Value to use to fill holes (e.g. 0), alternately a dict/Series/DataFrame of values specifying which value to use for each index (for a Series) or column (for a DataFrame). Values not in the dict/Series/DataFrame will not be filled. This value cannot be a list. 23 | Potential document 2: pandas dataframe isnull: pandas.DataFrame.isnull DataFrame.isnull()[source] DataFrame.isnull is an alias for DataFrame.isna. Detect missing values. Return a boolean same-sized object indicating if the values are NA. NA values, such as None or numpy.NaN, gets mapped to True values. Everything else gets mapped to False values. Characters such as empty strings '' or numpy.inf are not considered NA values (unless you set pandas.options.mode.use_inf_as_na = True). Returns DataFrame Mask of bool values for each element in DataFrame that indicates whether an element is an NA value. 24 | Potential document 3: pandas index dropna: pandas.Index.dropna Index.dropna(how='any')[source] Return Index without NA/NaN values. Parameters how:{‘any’, ‘all’}, default ‘any’ If the Index is a MultiIndex, drop the value when any or all levels are NaN. Returns Index 25 | Potential document 4: pandas dataframe notnull: pandas.DataFrame.notnull DataFrame.notnull()[source] DataFrame.notnull is an alias for DataFrame.notna. Detect existing (non-missing) values. Return a boolean same-sized object indicating if the values are not NA. Non-missing values get mapped to True. Characters such as empty strings '' or numpy.inf are not considered NA values (unless you set pandas.options.mode.use_inf_as_na = True). NA values, such as None or numpy.NaN, get mapped to False values. 
Returns DataFrame Mask of bool values for each element in DataFrame that indicates whether an element 26 | # replace all the nan values with 0 in a pandas dataframe `df` 27 | df.fillna(0) 28 | 29 | #END 30 | 31 | -------------------------------------------------------------------------------- /generator/fid/src/options.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import argparse 8 | import os 9 | from pathlib import Path 10 | import logging 11 | 12 | logger = logging.getLogger(__name__) 13 | 14 | class Options(): 15 | def __init__(self): 16 | self.parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) 17 | self.initialize_parser() 18 | 19 | def add_optim_options(self): 20 | self.parser.add_argument('--warmup_steps', type=int, default=1000) 21 | self.parser.add_argument('--total_steps', type=int, default=10000) 22 | self.parser.add_argument('--scheduler_steps', type=int, default=None, 23 | help='total number of steps for the scheduler; if None, scheduler_steps = total_steps') 24 | self.parser.add_argument('--accumulation_steps', type=int, default=1) 25 | self.parser.add_argument('--dropout', type=float, default=0.1, help='dropout rate') 26 | self.parser.add_argument('--lr', type=float, default=0.00005, help='learning rate') 27 | self.parser.add_argument('--clip', type=float, default=1., help='gradient clipping') 28 | self.parser.add_argument('--optim', type=str, default='adamw') 29 | self.parser.add_argument('--scheduler', type=str, default='linear') 30 | self.parser.add_argument('--weight_decay', type=float, default=0.01) 31 | self.parser.add_argument('--fixed_lr', action='store_true') 32 | 33 | def add_eval_options(self): 34 | self.parser.add_argument('--write_results', action='store_true', help='save results') 35 | self.parser.add_argument('--write_crossattention_scores', action='store_true', 36 | help='save dataset with cross-attention scores') 37 | self.parser.add_argument('--use_softmax', help='use softmax instead of logits as attention score', action='store_true') 38 | self.parser.add_argument('--tokenizer_name', choices=('models/generator/codet5-base', 't5-base', 't5-large'), default='data/fid/codet5-base') 39 | self.parser.add_argument('--result_tag', type=str, help='the evaluation setting') 40 | self.parser.add_argument('--lenpen', type=float, default=1.0, help='length penalty') 41 | self.parser.add_argument('--num_beams', type=int, default=10) 42 | self.parser.add_argument('--num_return_sequences', type=int, default=1, help='number of return sequences') 43 | self.parser.add_argument('--temperature', type=float, default=0.8, help='temperature for sampling') 44 | self.parser.add_argument('--top_p', type=float, default=0.9, help='top_p for nucleus sampling') 45 | 46 | def add_reader_options(self): 47 | self.parser.add_argument('--train_data', type=str, default='none', help='path of train data') 48 | self.parser.add_argument('--eval_data', type=str, default='none', help='path of eval data') 49 | self.parser.add_argument('--model_size', type=str, default='base') 50 | self.parser.add_argument('--model_name', type=str, default='t5-base') 51 | self.parser.add_argument('--use_checkpoint', action='store_true', help='use checkpoint in the encoder') 52 |
self.parser.add_argument('--text_maxlength', type=int, default=200, 53 | help='maximum number of tokens in text segments (question+passage)') 54 | self.parser.add_argument('--answer_maxlength', type=int, default=-1, 55 | help='maximum number of tokens used to train the model, no truncation if -1') 56 | self.parser.add_argument('--no_title', action='store_true', 57 | help='article titles not included in passages') 58 | self.parser.add_argument('--n_context', type=int, default=1) 59 | self.parser.add_argument('--encoder_weights', type=str, default=None, help='path of encoder weight') 60 | self.parser.add_argument('--dataset', choices=('conala', 'tldr'), default='conala') 61 | 62 | def add_retriever_options(self): 63 | self.parser.add_argument('--train_data', type=str, default='none', help='path of train data') 64 | self.parser.add_argument('--eval_data', type=str, default='none', help='path of eval data') 65 | self.parser.add_argument('--indexing_dimension', type=int, default=768) 66 | self.parser.add_argument('--no_projection', action='store_true', 67 | help='No additional linear layer and layernorm; only works if indexing size equals 768') 68 | self.parser.add_argument('--question_maxlength', type=int, default=40, 69 | help='maximum number of tokens in questions') 70 | self.parser.add_argument('--passage_maxlength', type=int, default=200, 71 | help='maximum number of tokens in passages') 72 | self.parser.add_argument('--no_question_mask', action='store_true') 73 | self.parser.add_argument('--no_passage_mask', action='store_true') 74 | self.parser.add_argument('--extract_cls', action='store_true') 75 | self.parser.add_argument('--no_title', action='store_true', 76 | help='article titles not included in passages') 77 | self.parser.add_argument('--n_context', type=int, default=1) 78 | 79 | 80 | def initialize_parser(self): 81 | # basic parameters 82 | self.parser.add_argument('--name', type=str, default='experiment_name', help='name of the experiment') 83 | self.parser.add_argument('--checkpoint_dir', type=str, default='./checkpoint/', help='models are saved here') 84 | self.parser.add_argument('--model_path', type=str, default='none', help='path for retraining') 85 | self.parser.add_argument('--continue_from_checkpoint', action='store_true') 86 | 87 | # dataset parameters 88 | self.parser.add_argument("--per_gpu_batch_size", default=1, type=int, 89 | help="Batch size per GPU/CPU for training.") 90 | self.parser.add_argument('--maxload', type=int, default=-1) 91 | 92 | self.parser.add_argument("--local_rank", type=int, default=-1, 93 | help="For distributed training: local_rank") 94 | self.parser.add_argument("--main_port", type=int, default=-1, 95 | help="Main port (for multi-node SLURM jobs)") 96 | self.parser.add_argument('--seed', type=int, default=0, help="random seed for initialization") 97 | # training parameters 98 | self.parser.add_argument('--eval_freq', type=int, default=500, 99 | help='evaluate the model every this many steps during training') 100 | self.parser.add_argument('--save_freq', type=int, default=5000, 101 | help='save the model every this many steps during training') 102 | self.parser.add_argument('--eval_print_freq', type=int, default=1000, 103 | help='print intermediate evaluation results every this many steps') 104 | 105 | self.parser.add_argument('--eval_metric', choices=('exact_match', 'bleu', 'token_f1'), default='bleu') 106 | 107 | 108 | def print_options(self, opt): 109 | message = '\n' 110 | for k, v in sorted(vars(opt).items()): 111 | comment = '' 112 | default_value = self.parser.get_default(k)
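# Flag any option whose value differs from its parser default, recording that default next to it in the saved log.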
113 | if v != default_value: 114 | comment = f'\t(default: {default_value})' 115 | message += f'{str(k):>30}: {str(v):<40}{comment}\n' 116 | 117 | expr_dir = Path(opt.checkpoint_dir)/ opt.name 118 | model_dir = expr_dir / 'models' 119 | model_dir.mkdir(parents=True, exist_ok=True) 120 | with open(expr_dir/'opt.log', 'wt') as opt_file: 121 | opt_file.write(message) 122 | opt_file.write('\n') 123 | 124 | logger.info(message) 125 | 126 | def parse(self): 127 | opt = self.parser.parse_args() 128 | return opt 129 | 130 | 131 | def get_options(use_reader=False, 132 | use_retriever=False, 133 | use_optim=False, 134 | use_eval=False): 135 | options = Options() 136 | if use_reader: 137 | options.add_reader_options() 138 | if use_retriever: 139 | options.add_retriever_options() 140 | if use_optim: 141 | options.add_optim_options() 142 | if use_eval: 143 | options.add_eval_options() 144 | return options.parse() 145 | -------------------------------------------------------------------------------- /generator/fid/src/data.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | import random 9 | import json 10 | import numpy as np 11 | 12 | class Dataset(torch.utils.data.Dataset): 13 | def __init__(self, 14 | data, 15 | n_context=None, 16 | question_prefix='question:', 17 | title_prefix='title:', 18 | passage_prefix='context:'): 19 | self.data = data 20 | self.n_context = n_context 21 | self.question_prefix = question_prefix 22 | self.title_prefix = title_prefix 23 | self.passage_prefix = passage_prefix 24 | self.sort_data() 25 | 26 | def __len__(self): 27 | return len(self.data) 28 | 29 | def get_target(self, example): 30 | if 'target' in example: 31 | target = example['target'] 32 | return target + ' ' 33 | elif 'answers' in example: 34 | return random.choice(example['answers']) + ' ' 35 | else: 36 | return None 37 | 38 | def __getitem__(self, index): 39 | example = self.data[index] 40 | question = self.question_prefix + " " + example['question'] 41 | target = self.get_target(example) 42 | 43 | if 'ctxs' in example and self.n_context is not None: 44 | f = self.title_prefix + " {} " + self.passage_prefix + " {}" 45 | contexts = example['ctxs'][:self.n_context] 46 | passages = [f.format(c['title'], c['text']) for c in contexts] 47 | scores = [float(c['score']) for c in contexts] 48 | scores = torch.tensor(scores) 49 | # TODO(egrave): do we want to keep this? 
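# Intended fallback when retrieval returns no contexts: use the question itself (note that `passages` above is not recomputed, which the TODO hints at).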
50 | if len(contexts) == 0: 51 | contexts = [question] 52 | else: 53 | passages, scores = None, None 54 | 55 | 56 | return { 57 | 'index' : index, 58 | 'question' : question, 59 | 'target' : target, 60 | 'passages' : passages, 61 | 'scores' : scores 62 | } 63 | 64 | def sort_data(self): 65 | if self.n_context is None or not 'score' in self.data[0]['ctxs'][0]: 66 | return 67 | for ex in self.data: 68 | ex['ctxs'].sort(key=lambda x: float(x['score']), reverse=True) 69 | 70 | def get_example(self, index): 71 | return self.data[index] 72 | 73 | def encode_passages(batch_text_passages, tokenizer, max_length): 74 | passage_ids, passage_masks = [], [] 75 | for k, text_passages in enumerate(batch_text_passages): 76 | p = tokenizer.batch_encode_plus( 77 | text_passages, 78 | max_length=max_length, 79 | pad_to_max_length=True, 80 | return_tensors='pt', 81 | truncation=True 82 | ) 83 | passage_ids.append(p['input_ids'][None]) 84 | passage_masks.append(p['attention_mask'][None]) 85 | 86 | passage_ids = torch.cat(passage_ids, dim=0) 87 | passage_masks = torch.cat(passage_masks, dim=0) 88 | return passage_ids, passage_masks.bool() 89 | 90 | class Collator(object): 91 | def __init__(self, text_maxlength, tokenizer, answer_maxlength=20): 92 | self.tokenizer = tokenizer 93 | self.text_maxlength = text_maxlength 94 | self.answer_maxlength = answer_maxlength 95 | 96 | def __call__(self, batch): 97 | assert(batch[0]['target'] != None) 98 | index = torch.tensor([ex['index'] for ex in batch]) 99 | target = [ex['target'] for ex in batch] 100 | target = self.tokenizer.batch_encode_plus( 101 | target, 102 | max_length=self.answer_maxlength if self.answer_maxlength > 0 else None, 103 | pad_to_max_length=True, 104 | return_tensors='pt', 105 | truncation=True if self.answer_maxlength > 0 else False, 106 | ) 107 | target_ids = target["input_ids"] 108 | target_mask = target["attention_mask"].bool() 109 | target_ids = target_ids.masked_fill(~target_mask, -100) 110 | 111 | def append_question(example): 112 | if example['passages'] is None: 113 | return [example['question']] 114 | return [example['question'] + " " + t for t in example['passages']] 115 | text_passages = [append_question(example) for example in batch] 116 | passage_ids, passage_masks = encode_passages(text_passages, 117 | self.tokenizer, 118 | self.text_maxlength) 119 | 120 | return (index, target_ids, target_mask, passage_ids, passage_masks) 121 | 122 | def load_data(data_path=None, global_rank=-1, world_size=-1): 123 | assert data_path 124 | if data_path.endswith('.jsonl'): 125 | data = open(data_path, 'r') 126 | elif data_path.endswith('.json'): 127 | with open(data_path, 'r') as fin: 128 | data = json.load(fin) 129 | examples = [] 130 | for k, example in enumerate(data): 131 | if global_rank > -1 and not k%world_size==global_rank: 132 | continue 133 | if data_path is not None and data_path.endswith('.jsonl'): 134 | example = json.loads(example) 135 | if not 'id' in example: 136 | example['id'] = k 137 | for c in example['ctxs']: 138 | if not 'score' in c: 139 | c['score'] = 1.0 / (k + 1) 140 | examples.append(example) 141 | ## egrave: is this needed? 
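# When the data came from a .jsonl file, `data` is still an open file handle; close it now that all examples are loaded.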
142 | if data_path is not None and data_path.endswith('.jsonl'): 143 | data.close() 144 | 145 | return examples 146 | 147 | class RetrieverCollator(object): 148 | def __init__(self, tokenizer, passage_maxlength=200, question_maxlength=40): 149 | self.tokenizer = tokenizer 150 | self.passage_maxlength = passage_maxlength 151 | self.question_maxlength = question_maxlength 152 | 153 | def __call__(self, batch): 154 | index = torch.tensor([ex['index'] for ex in batch]) 155 | 156 | question = [ex['question'] for ex in batch] 157 | question = self.tokenizer.batch_encode_plus( 158 | question, 159 | pad_to_max_length=True, 160 | return_tensors="pt", 161 | max_length=self.question_maxlength, 162 | truncation=True 163 | ) 164 | question_ids = question['input_ids'] 165 | question_mask = question['attention_mask'].bool() 166 | 167 | if batch[0]['scores'] is None or batch[0]['passages'] is None: 168 | return index, question_ids, question_mask, None, None, None 169 | 170 | scores = [ex['scores'] for ex in batch] 171 | scores = torch.stack(scores, dim=0) 172 | 173 | passages = [ex['passages'] for ex in batch] 174 | passage_ids, passage_masks = encode_passages( 175 | passages, 176 | self.tokenizer, 177 | self.passage_maxlength 178 | ) 179 | 180 | return (index, question_ids, question_mask, passage_ids, passage_masks, scores) 181 | 182 | class TextDataset(torch.utils.data.Dataset): 183 | def __init__(self, 184 | data, 185 | title_prefix='title:', 186 | passage_prefix='context:'): 187 | self.data = data 188 | self.title_prefix = title_prefix 189 | self.passage_prefix = passage_prefix 190 | 191 | def __len__(self): 192 | return len(self.data) 193 | 194 | def __getitem__(self, index): 195 | example = self.data[index] 196 | text = self.title_prefix + " " + example[2] + " " + \ 197 | self.passage_prefix + " " + example[1] 198 | return example[0], text 199 | 200 | class TextCollator(object): 201 | def __init__(self, tokenizer, maxlength=200): 202 | self.tokenizer = tokenizer 203 | self.maxlength = maxlength 204 | 205 | def __call__(self, batch): 206 | index = [x[0] for x in batch] 207 | encoded_batch = self.tokenizer.batch_encode_plus( 208 | [x[1] for x in batch], 209 | pad_to_max_length=True, 210 | return_tensors="pt", 211 | max_length=self.maxlength, 212 | truncation=True 213 | ) 214 | text_ids = encoded_batch['input_ids'] 215 | text_mask = encoded_batch['attention_mask'].bool() 216 | 217 | return index, text_ids, text_mask 218 | -------------------------------------------------------------------------------- /retriever/bm25/main.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | import json 3 | import os.path 4 | 5 | from tqdm import tqdm 6 | from collections import defaultdict, OrderedDict 7 | from indexer import ESSearch 8 | from retriever.eval import calc_recall, calc_hit, eval_retrieval_from_file 9 | import argparse 10 | 11 | def retrieve_manual(test_file, source, index, host, search_conf, saved_file, top_k=10): 12 | indexer = ESSearch(index, 13 | source, 14 | host_address=host, 15 | re_index=False) 16 | 17 | query_type, search_field = search_conf['query'], search_conf['field'] 18 | tag = f"{index}.{query_type}.{search_field}" 19 | with open(test_file, "r") as f: 20 | d = json.load(f) 21 | results = [] 22 | for item in tqdm(d): 23 | query = item[query_type] 24 | try: 25 | r_mans = indexer.get_topk(search_field, query, topk=top_k) 26 | 27 | results.append({**item, 28 | f'{tag}.retrieved': [x['library_key'] for x in r_mans], 29 | f'{tag}.score': 
[x.score if not isinstance(x, dict) else x['score'] for x in r_mans]}) 30 | except Exception as e: 31 | print(repr(e)) 32 | results.append({**item, 33 | f'{tag}.retrieved': [], 34 | f'{tag}.score': []}) 35 | 36 | # metrics = calc_recall(results) 37 | saved_file = saved_file.replace(".json", f".{tag}.json") 38 | print(f"save to {saved_file}") 39 | with open(saved_file, "w+") as f: 40 | json.dump(results, f, indent=2) 41 | 42 | return saved_file 43 | 44 | 45 | def doc_base_retrieval(index, source, host, r1_result_file, 46 | retrieval_entry, saved_file, top_k_doc, 47 | top_k_result, oracle_only=False): 48 | indexer = ESSearch(index, source, host, re_index=False) 49 | 50 | print(f"index: {index}") 51 | print(f"r1 file: {r1_result_file}") 52 | print(f"r2 file: {saved_file}") 53 | 54 | with open(r1_result_file, "r") as f: 55 | r1_result = json.load(f) 56 | 57 | split_flag = False 58 | r0 = r1_result[0][retrieval_entry][0] 59 | if r0.split("_")[-1].isdigit(): 60 | split_flag = True 61 | 62 | r2_result = [] 63 | 64 | for item in tqdm(r1_result): 65 | if oracle_only: 66 | if split_flag: 67 | pred_cmd = ["_".join(item['cmd_name'].split("_")[:-1])] 68 | else: 69 | pred_cmd = [item['cmd_name']] 70 | else: 71 | pred_cmd = [] 72 | for pred in item[retrieval_entry]: 73 | cmd_name = pred 74 | if split_flag: 75 | cmd_name = "_".join(pred.split("_")[:-1]) 76 | pred_cmd.append(cmd_name) 77 | pred_cmd = list(OrderedDict.fromkeys(pred_cmd)) 78 | 79 | for cmd_idx, cmd in enumerate(pred_cmd[:top_k_doc]): 80 | item_r2 = item.copy() 81 | item_r2['parent_cmd'] = cmd 82 | item_r2['question_id'] = f"{item['question_id']}-{cmd_idx}" 83 | item_r2.pop(retrieval_entry) 84 | item_r2.pop(retrieval_entry.replace(".retrieved", ".score")) 85 | try: 86 | real_query = { 87 | "query": { 88 | "bool": { 89 | "must": [ 90 | {"term": {'cmd_name.keyword': cmd}}, 91 | {"match": {'manual': item['nl']}} 92 | ] 93 | } 94 | }, 95 | "size": top_k_result 96 | } 97 | r_mans = indexer.es.search(index=indexer.index, body=real_query)['hits']['hits'][:top_k_result] 98 | _r_mans = [] 99 | for r in r_mans: 100 | i = {'library_key': r['_source']['library_key'], 'score': r['_score']} 101 | _r_mans.append(i) 102 | r_mans = _r_mans 103 | 104 | 105 | r2_result.append({**item_r2, 106 | retrieval_entry: [x['library_key'] for x in r_mans], 107 | retrieval_entry.replace(".retrieved", ".score"): [x.score if not isinstance(x, dict) else x['score'] for x in r_mans]}) 108 | 109 | except Exception as e: 110 | print(repr(e)) 111 | r2_result.append({**item_r2, 112 | retrieval_entry: [], 113 | retrieval_entry.replace(".retrieved", ".score"): []}) 114 | 115 | 116 | print(f"size of the results: {len(r2_result)}") 117 | with open(saved_file, "w+") as f: 118 | json.dump(r2_result, f, indent=2) 119 | 120 | 121 | def config(): 122 | parser = argparse.ArgumentParser() 123 | parser.add_argument('--retrieval_stage', type=int, choices=(0, 1, 2), 124 | help='which retrieval stage to run for tldr. ' 125 | 'stage 0: build retrieval index. ' 126 | 'stage 1: stage 1 retrieval that retrieves the bash command. ' 127 | 'stage 2: stage 2 retrieval that retrieves the paragraphs') 128 | parser.add_argument('--split', type=str, 129 | choices=('cmd_train', 'cmd_dev', 'cmd_test'), 130 | default='cmd_dev', 131 | help='which data split to run') 132 | 133 | parser.add_argument('--host', type=str, default='localhost') 134 | 135 | args = parser.parse_args() 136 | return args 137 | 138 | if __name__ == "__main__": 139 | args = config() 140 | stage = args.retrieval_stage 141 | split = args.split 142
| host = args.host 143 | if stage == 0: # build the index 144 | index = "bash_man_whole" 145 | source = "chunk" 146 | _ = ESSearch(index, source, host_address=host, 147 | re_index=True, manual_path='data/tldr/manual_all_raw.json', 148 | func_descpt_path=None) 149 | 150 | index = "bash_man_para" 151 | source = "chunk" 152 | indexer = ESSearch(index, source, host_address=host, 153 | re_index=True, manual_path='data/tldr/manual_section.json', 154 | func_descpt_path=None) 155 | 156 | if stage == 1: 157 | index = 'bash_man_whole' # in the first stage, use the whole bash manual to retrieve the bash commands 158 | source = "chunk" 159 | search_config_1 = {'query': 'nl', 'field': 'manual', 'filter_result': False} 160 | 161 | print(split, index) 162 | data_file = f"./data/tldr/{split}.seed.json" 163 | save_file = data_file.replace(".seed.json", f".full.json") 164 | query_type, search_field = search_config_1['query'], search_config_1['field'] 165 | tag = f"{index}.{query_type}.{search_field}" 166 | real_save_file = save_file.replace(".json", f".{tag}.json") 167 | 168 | if not os.path.exists(real_save_file): 169 | _ = retrieve_manual(data_file, source, index, host, search_config_1, save_file, top_k=35) 170 | 171 | with open(real_save_file, 'r') as f: 172 | d = json.load(f) 173 | 174 | src = [] 175 | pred = [] 176 | for item in d: 177 | src.append(item['cmd_name']) 178 | pred.append(item[f'{tag}.retrieved']) 179 | 180 | calc_hit(src, pred, top_k=[1, 3, 5, 10, 15, 20, 30]) 181 | 182 | if stage == 2: 183 | source = "chunk" 184 | r1_index = 'bash_man_whole' 185 | r2_index = 'bash_man_para' # in the second stage, use paragraphs to retrieve descriptions of relevant arguments etc 186 | search_config_1 = {'query': 'nl', 'field': 'manual', 'filter_result': False} 187 | 188 | data_file = f"./data/tldr/{split}.seed.json" 189 | query_type, search_field = search_config_1['query'], search_config_1['field'] 190 | r1_tag = f"{r1_index}.{query_type}.{search_field}" 191 | r1_save_file = data_file.replace(".seed.json", f".full.{r1_tag}.json") 192 | 193 | r2_save_file = r1_save_file.replace(".json", f".r2-{r2_index}.json") 194 | 195 | if not os.path.exists(r2_save_file): 196 | _ = doc_base_retrieval(r2_index, 197 | source, 198 | host, 199 | r1_save_file, 200 | f'{r1_tag}.retrieved', 201 | r2_save_file, 202 | top_k_doc=5, 203 | top_k_result=30, 204 | oracle_only=False) 205 | 206 | -------------------------------------------------------------------------------- /generator/fid/src/util.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
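# Shared FiD training utilities: logging setup, checkpoint save/load, LR schedules, and
# distributed reductions. A hypothetical checkpoint round trip (the names here are
# illustrative, not taken from the repo's training scripts):
#   save(model, optimizer, scheduler, step, best_eval_metric, opt, checkpoint_path, 'best_dev')
#   model, optimizer, scheduler, opt_ckpt, step, best_eval_metric = \
#       load(src.model.FiDT5, checkpoint_path / 'checkpoint' / 'latest', opt)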
6 | 7 | import os 8 | import errno 9 | import torch 10 | import sys 11 | import logging 12 | import json 13 | from pathlib import Path 14 | import torch.distributed as dist 15 | import csv 16 | 17 | logger = logging.getLogger(__name__) 18 | 19 | def init_logger(is_main=True, is_distributed=False, filename=None): 20 | if is_distributed: 21 | torch.distributed.barrier() 22 | handlers = [logging.StreamHandler(sys.stdout)] 23 | if filename is not None: 24 | handlers.append(logging.FileHandler(filename=filename)) 25 | logging.basicConfig( 26 | datefmt="%m/%d/%Y %H:%M:%S", 27 | level=logging.INFO if is_main else logging.WARN, 28 | format="[%(asctime)s] {%(filename)s:%(lineno)d} %(levelname)s - %(message)s", 29 | handlers=handlers, 30 | ) 31 | logging.getLogger('transformers.tokenization_utils').setLevel(logging.ERROR) 32 | logging.getLogger('transformers.tokenization_utils_base').setLevel(logging.ERROR) 33 | return logger 34 | 35 | def get_checkpoint_path(opt): 36 | checkpoint_path = Path(opt.checkpoint_dir) / opt.name 37 | checkpoint_exists = checkpoint_path.exists() 38 | if opt.is_distributed: 39 | torch.distributed.barrier() 40 | checkpoint_path.mkdir(parents=True, exist_ok=True) 41 | return checkpoint_path, checkpoint_exists 42 | 43 | def symlink_force(target, link_name): 44 | try: 45 | os.symlink(target, link_name) 46 | except OSError as e: 47 | if e.errno == errno.EEXIST: 48 | os.remove(link_name) 49 | os.symlink(target, link_name) 50 | else: 51 | raise e 52 | 53 | def save(model, optimizer, scheduler, step, best_eval_metric, opt, dir_path, name): 54 | model_to_save = model.module if hasattr(model, "module") else model 55 | path = os.path.join(dir_path, "checkpoint") 56 | epoch_path = os.path.join(path, name) #"step-%s" % step) 57 | os.makedirs(epoch_path, exist_ok=True) 58 | model_to_save.save_pretrained(epoch_path) 59 | cp = os.path.join(path, "latest") 60 | fp = os.path.join(epoch_path, "optimizer.pth.tar") 61 | checkpoint = { 62 | "step": step, 63 | "optimizer": optimizer.state_dict(), 64 | "scheduler": scheduler.state_dict(), 65 | "opt": opt, 66 | "best_eval_metric": best_eval_metric, 67 | } 68 | torch.save(checkpoint, fp) 69 | symlink_force(epoch_path, cp) 70 | 71 | 72 | def load(model_class, dir_path, opt, reset_params=False): 73 | epoch_path = os.path.realpath(dir_path) 74 | optimizer_path = os.path.join(epoch_path, "optimizer.pth.tar") 75 | logger.info("Loading %s" % epoch_path) 76 | model = model_class.from_pretrained(epoch_path) 77 | model = model.to(opt.device) 78 | logger.info("loading checkpoint %s" %optimizer_path) 79 | checkpoint = torch.load(optimizer_path, map_location=opt.device) 80 | opt_checkpoint = checkpoint["opt"] 81 | step = checkpoint["step"] 82 | if "best_eval_metric" in checkpoint: 83 | best_eval_metric = checkpoint["best_eval_metric"] 84 | else: 85 | best_eval_metric = checkpoint["best_dev_em"] 86 | if not reset_params: 87 | optimizer, scheduler = set_optim(opt_checkpoint, model) 88 | scheduler.load_state_dict(checkpoint["scheduler"]) 89 | optimizer.load_state_dict(checkpoint["optimizer"]) 90 | else: 91 | optimizer, scheduler = set_optim(opt, model) 92 | 93 | return model, optimizer, scheduler, opt_checkpoint, step, best_eval_metric 94 | 95 | class WarmupLinearScheduler(torch.optim.lr_scheduler.LambdaLR): 96 | def __init__(self, optimizer, warmup_steps, scheduler_steps, min_ratio, fixed_lr, last_epoch=-1): 97 | self.warmup_steps = warmup_steps 98 | self.scheduler_steps = scheduler_steps 99 | self.min_ratio = min_ratio 100 | self.fixed_lr = fixed_lr 101 | 
super(WarmupLinearScheduler, self).__init__( 102 | optimizer, self.lr_lambda, last_epoch=last_epoch 103 | ) 104 | 105 | def lr_lambda(self, step): 106 | if step < self.warmup_steps: 107 | return (1 - self.min_ratio)*step/float(max(1, self.warmup_steps)) + self.min_ratio 108 | 109 | if self.fixed_lr: 110 | return 1.0 111 | 112 | return max(0.0, 113 | 1.0 + (self.min_ratio - 1) * (step - self.warmup_steps)/float(max(1.0, self.scheduler_steps - self.warmup_steps)), 114 | ) 115 | 116 | 117 | class FixedScheduler(torch.optim.lr_scheduler.LambdaLR): 118 | def __init__(self, optimizer, last_epoch=-1): 119 | super(FixedScheduler, self).__init__(optimizer, self.lr_lambda, last_epoch=last_epoch) 120 | def lr_lambda(self, step): 121 | return 1.0 122 | 123 | 124 | def set_dropout(model, dropout_rate): 125 | for mod in model.modules(): 126 | if isinstance(mod, torch.nn.Dropout): 127 | mod.p = dropout_rate 128 | 129 | 130 | def set_optim(opt, model): 131 | if opt.optim == 'adam': 132 | optimizer = torch.optim.Adam(model.parameters(), lr=opt.lr) 133 | elif opt.optim == 'adamw': 134 | optimizer = torch.optim.AdamW(model.parameters(), lr=opt.lr, weight_decay=opt.weight_decay) 135 | if opt.scheduler == 'fixed': 136 | scheduler = FixedScheduler(optimizer) 137 | elif opt.scheduler == 'linear': 138 | if opt.scheduler_steps is None: 139 | scheduler_steps = opt.total_steps 140 | else: 141 | scheduler_steps = opt.scheduler_steps 142 | scheduler = WarmupLinearScheduler(optimizer, warmup_steps=opt.warmup_steps, scheduler_steps=scheduler_steps, min_ratio=0., fixed_lr=opt.fixed_lr) 143 | return optimizer, scheduler 144 | 145 | 146 | def average_main(x, opt): 147 | if not opt.is_distributed: 148 | return x 149 | if opt.world_size > 1: 150 | dist.reduce(x, 0, op=dist.ReduceOp.SUM) 151 | if opt.is_main: 152 | x = x / opt.world_size 153 | return x 154 | 155 | 156 | def sum_main(x, opt): 157 | if not opt.is_distributed: 158 | return x 159 | if opt.world_size > 1: 160 | dist.reduce(x, 0, op=dist.ReduceOp.SUM) 161 | return x 162 | 163 | 164 | def weighted_average(x, count, opt): 165 | if not opt.is_distributed: 166 | return x, count 167 | t_loss = torch.tensor([x * count], device=opt.device) 168 | t_total = torch.tensor([count], device=opt.device) 169 | t_loss = sum_main(t_loss, opt) 170 | t_total = sum_main(t_total, opt) 171 | return (t_loss / t_total).item(), t_total.item() 172 | 173 | 174 | def write_output(glob_path, output_path): 175 | files = list(glob_path.glob('*.txt')) 176 | files.sort() 177 | with open(output_path, 'w') as outfile: 178 | for path in files: 179 | with open(path, 'r') as f: 180 | lines = f.readlines() 181 | for line in lines: 182 | outfile.write(line) 183 | path.unlink() 184 | glob_path.rmdir() 185 | 186 | 187 | def save_distributed_dataset(data, opt): 188 | dir_path = Path(opt.checkpoint_dir) / opt.name 189 | write_path = dir_path / 'tmp_dir' 190 | write_path.mkdir(exist_ok=True) 191 | tmp_path = write_path / f'{opt.global_rank}.json' 192 | with open(tmp_path, 'w') as fw: 193 | json.dump(data, fw) 194 | if opt.is_distributed: 195 | torch.distributed.barrier() 196 | if opt.is_main: 197 | final_path = dir_path / f'attention_score.{opt.result_tag}.json' 198 | logger.info(f'Writing dataset with scores at {final_path}') 199 | glob_path = write_path / '*' 200 | results_path = write_path.glob('*.json') 201 | alldata = [] 202 | for path in results_path: 203 | with open(path, 'r') as f: 204 | data = json.load(f) 205 | alldata.extend(data) 206 | path.unlink() 207 | with open(final_path, 'w') as fout: 208 | 
json.dump(alldata, fout, indent=4) 209 | write_path.rmdir() 210 | 211 | def load_passages(path): 212 | if not os.path.exists(path): 213 | logger.info(f'{path} does not exist') 214 | return 215 | logger.info(f'Loading passages from: {path}') 216 | passages = [] 217 | with open(path) as fin: 218 | reader = csv.reader(fin, delimiter='\t') 219 | for k, row in enumerate(reader): 220 | if not row[0] == 'id': 221 | try: 222 | passages.append((row[0], row[1], row[2])) 223 | except: 224 | logger.warning(f'The following input line has not been correctly loaded: {row}') 225 | return passages -------------------------------------------------------------------------------- /retriever/simcse/run_inference.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os.path 3 | import pickle 4 | 5 | import argparse 6 | import shlex 7 | 8 | import faiss 9 | import numpy as np 10 | import torch 11 | from tqdm import tqdm 12 | import transformers 13 | from transformers import AutoModel, AutoTokenizer, AutoConfig 14 | from retriever.eval import eval_retrieval_from_file 15 | from model import RetrievalModel 16 | TQDM_DISABLED = 'TQDM_DISABLED' in os.environ and str(os.environ['TQDM_DISABLED']) == '1'  # boolean, matching the TQDM_DISABLE check in utils/constants.py 17 | 18 | class Dummy: 19 | pass 20 | 21 | class CodeT5Retriever: 22 | def __init__(self, args): 23 | self.args = args 24 | 25 | def prepare_model(self, model=None, tokenizer=None, config=None): 26 | if self.args.log_level == 'verbose': 27 | transformers.logging.set_verbosity_info() 28 | self.model_name = self.args.model_name 29 | 30 | if model is None: 31 | self.tokenizer = transformers.RobertaTokenizer.from_pretrained(self.model_name) 32 | model_arg = Dummy() 33 | setattr(model_arg, 'sim_func', self.args.sim_func) 34 | config = AutoConfig.from_pretrained(self.model_name) 35 | self.model = RetrievalModel( 36 | config=config, 37 | model_type=self.model_name, 38 | num_layers=self.args.num_layers, 39 | tokenizer=tokenizer, 40 | training_args=None, 41 | model_args=model_arg) 42 | self.device = torch.device('cuda') if not self.args.cpu else torch.device('cpu') 43 | self.model.eval() 44 | self.model = self.model.to(self.device) 45 | else: # this is only for evaluation during training time 46 | self.model = model 47 | self.tokenizer = tokenizer 48 | self.device = self.model.device 49 | 50 | def encode_file(self, text_file, save_file, **kwargs): 51 | normalize_embed = kwargs.get('normalize_embed', False) 52 | with open(text_file, "r") as f: 53 | dataset = [] 54 | for line in f: 55 | dataset.append(line.strip()) 56 | # print(line) 57 | print(f"number of sentences in {text_file}: {len(dataset)}") 58 | 59 | def pad_batch(examples): 60 | sentences = examples 61 | sent_features = self.tokenizer( 62 | sentences, 63 | add_special_tokens=True, 64 | max_length=self.tokenizer.model_max_length, 65 | truncation=True 66 | ) 67 | arr = sent_features['input_ids'] 68 | lens = torch.LongTensor([len(a) for a in arr]) 69 | max_len = lens.max().item() 70 | padded = torch.ones(len(arr), max_len, dtype=torch.long) * self.tokenizer.pad_token_id 71 | mask = torch.zeros(len(arr), max_len, dtype=torch.long) 72 | for i, a in enumerate(arr): 73 | padded[i, : lens[i]] = torch.tensor(a, dtype=torch.long) 74 | mask[i, : lens[i]] = 1 75 | return {'input_ids': padded, 'attention_mask': mask, 'lengths': lens} 76 | 77 | bs = 128 78 | with torch.no_grad(): 79 | all_embeddings = [] 80 | for i in tqdm(range(0, len(dataset), bs), disable=TQDM_DISABLED): 81 | batch = dataset[i: i + bs] 82 | padded_batch =
pad_batch(batch) 83 | for k in padded_batch: 84 | if isinstance(padded_batch[k], torch.Tensor): 85 | padded_batch[k] = padded_batch[k].to(self.device) 86 | output = self.model.get_pooling_embedding(**padded_batch, normalize=normalize_embed).detach().cpu().numpy() 87 | all_embeddings.append(output) 88 | 89 | all_embeddings = np.concatenate(all_embeddings, axis=0) 90 | print(f"done embedding: {all_embeddings.shape}") 91 | 92 | if not os.path.exists(os.path.dirname(save_file)): 93 | os.makedirs(os.path.dirname(save_file)) 94 | 95 | np.save(save_file, all_embeddings) 96 | 97 | @staticmethod 98 | def retrieve(source_embed_file, target_embed_file, source_id_file, target_id_file, top_k, save_file): 99 | print(f'source: {source_embed_file}, target: {target_embed_file}') 100 | with open(source_id_file, "r") as f: 101 | source_id_map = {} 102 | for idx, line in enumerate(f): 103 | source_id_map[idx] = line.strip() 104 | with open(target_id_file, "r") as f: 105 | target_id_map = {} 106 | for idx, line in enumerate(f): 107 | target_id_map[idx] = line.strip() 108 | 109 | source_embed = np.load(source_embed_file + ".npy") 110 | target_embed = np.load(target_embed_file + ".npy") 111 | assert len(source_id_map) == source_embed.shape[0] 112 | assert len(target_id_map) == target_embed.shape[0] 113 | indexer = faiss.IndexFlatIP(target_embed.shape[1]) 114 | indexer.add(target_embed) 115 | print(source_embed.shape, target_embed.shape) 116 | D, I = indexer.search(source_embed, top_k) 117 | 118 | results = {} 119 | for source_idx, (dist, retrieved_index) in enumerate(zip(D, I)): 120 | source_id = source_id_map[source_idx] 121 | results[source_id] = {} 122 | retrieved_target_id = [target_id_map[x] for x in retrieved_index] 123 | results[source_id]['retrieved'] = retrieved_target_id 124 | results[source_id]['score'] = dist.tolist() 125 | 126 | with open(save_file, "w+") as f: 127 | json.dump(results, f, indent=2) 128 | 129 | return results 130 | 131 | def config(in_program_call=None): 132 | parser = argparse.ArgumentParser() 133 | parser.add_argument('--model_name', type=str) 134 | parser.add_argument('--batch_size', type=int, default=48) 135 | parser.add_argument('--source_file', default='data/conala/conala_nl.txt') 136 | parser.add_argument('--target_file', default='data/conala/python_manual_firstpara.txt') 137 | parser.add_argument('--source_embed_save_file', default='data/conala/.tmp/src_embedding') 138 | parser.add_argument('--target_embed_save_file', default='data/conala/.tmp/tgt_embedding') 139 | parser.add_argument('--save_file', default='[REPLACE]data/conala/simcse.[MODEL].[SOURCE].[TARGET].[POOLER].t[TOPK].json') 140 | parser.add_argument('--top_k', type=int, default=200) 141 | parser.add_argument('--cpu', action='store_true') 142 | parser.add_argument('--pooler', choices=('cls', 'cls_before_pooler'), default='cls') 143 | parser.add_argument('--log_level', default='verbose') 144 | parser.add_argument('--nl_cm_folder', default='data/conala/nl.cm') 145 | parser.add_argument('--sim_func', default='cls_distance.cosine', choices=('cls_distance.cosine', 'cls_distance.l2', 'bertscore')) 146 | parser.add_argument('--num_layers', type=int, default=12) 147 | parser.add_argument('--origin_mode', action='store_true') 148 | parser.add_argument('--oracle_eval_file', default='data/conala/cmd_dev.oracle_man.full.json') 149 | parser.add_argument('--eval_hit', action='store_true') 150 | parser.add_argument('--normalize_embed', action='store_true') 151 | 152 | 153 | 154 | args = parser.parse_args() if in_program_call is 
None else parser.parse_args(shlex.split(in_program_call)) 155 | 156 | args.source_idx_file = args.source_file.replace(".txt", ".id") 157 | args.target_idx_file = args.target_file.replace(".txt", ".id") 158 | 159 | if in_program_call is None and args.save_file.startswith("[REPLACE]"): 160 | args.save_file = args.save_file.replace("[REPLACE]", "") 161 | args.save_file = args.save_file.replace("[MODEL]", os.path.basename(args.model_name)) 162 | args.save_file = args.save_file.replace("[SOURCE]", os.path.basename(args.source_file).split(".")[0]) 163 | args.save_file = args.save_file.replace("[TARGET]", os.path.basename(args.target_file).split(".")[0]) 164 | args.save_file = args.save_file.replace("[POOLER]", args.pooler) 165 | args.save_file = args.save_file.replace("[TOPK]", str(args.top_k)) 166 | print(json.dumps(vars(args), indent=2)) 167 | return args 168 | 169 | if __name__ == "__main__": 170 | args = config() 171 | 172 | searcher = CodeT5Retriever(args) 173 | searcher.prepare_model() 174 | searcher.encode_file(args.source_file, args.source_embed_save_file, normalize_embed=args.normalize_embed) 175 | searcher.encode_file(args.target_file, args.target_embed_save_file, normalize_embed=args.normalize_embed) 176 | searcher.retrieve(args.source_embed_save_file, 177 | args.target_embed_save_file, args.source_idx_file, 178 | args.target_idx_file, args.top_k, args.save_file) 179 | 180 | flag = 'recall' 181 | top_n = 10 182 | m1 = eval_retrieval_from_file(args.oracle_eval_file, args.save_file) 183 | 184 | 185 | -------------------------------------------------------------------------------- /utils/util.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 
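# Command-string cleaning/anonymization helpers plus a token-id Trie used for
# constrained command-name decoding. A small worked example:
#   anonymize_command(clean_command("tar -xvf {{archive.tar}}"))
#   -> 'tar -xvf [[VAR]]_0'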
7 | 8 | import logging 9 | import os.path 10 | import socket 11 | from collections import defaultdict 12 | from typing import Dict, List 13 | import re 14 | import pickle 15 | from utils.constants import VAR_STR 16 | 17 | def init_logger(args): 18 | # setup logger 19 | logger = logging.getLogger() 20 | handler = logging.StreamHandler() 21 | formatter = logging.Formatter(f"%(asctime)s %(module)s - %(funcName)s: [ {socket.gethostname()} | Node {args.node_id} | Rank {args.global_rank} ] %(message)s", 22 | datefmt='%Y-%m-%d %H:%M:%S') 23 | handler.setFormatter(formatter) 24 | logger.handlers.clear() 25 | logger.addHandler(handler) 26 | logger.setLevel(logging.INFO) 27 | 28 | return logger 29 | 30 | def dedup_results(saved): 31 | exc_key = ['nl', 'gold', 'cmd_name'] 32 | for item in saved: 33 | if 'prediction' in item: 34 | item['code'] = item['prediction'] 35 | item.pop('prediction') 36 | if 'score' in item: 37 | item['sequence_ll'] = item['score'] 38 | item.pop('score') 39 | 40 | _saved = [saved[0]] 41 | for item in saved[1:]: 42 | if item['nl'] == _saved[-1]['nl']: 43 | for k in item.keys(): 44 | if k not in exc_key: 45 | _saved[-1][k] += item[k] 46 | else: 47 | _saved.append(item) 48 | _saved = {x['nl']: x for x in _saved} 49 | return _saved 50 | 51 | def clean_command(s): 52 | s = s.replace("sudo", "").strip() 53 | s = s.replace("`", "").replace('"', "").replace("'", "") 54 | # '>', '|', '+' 55 | s = s.replace("|", " ").replace(">", " ").replace("<", " ") 56 | s = " ".join(s.split()) 57 | return s 58 | 59 | def anonymize_command(s): 60 | s = s.replace("={", " {") 61 | var_to_pc_holder = defaultdict(lambda: len(var_to_pc_holder)) 62 | for var in re.findall("{{(.*?)}}", s): 63 | _ = var_to_pc_holder[var] 64 | for var, id in var_to_pc_holder.items(): 65 | var_str = "{{%s}}" % var 66 | s = s.replace(var_str, f"{VAR_STR}_{id}") 67 | # s = re.sub("{{.*?}}", VAR_STR, s) 68 | return s 69 | 70 | def get_bag_of_keywords(cmd): 71 | cmd = clean_anonymize_command(cmd) 72 | # try: 73 | # tokens = list(bashlex.split(cmd)) 74 | # except NotImplementedError: 75 | # tokens = cmd.strip().split() 76 | tokens = cmd.strip().split() 77 | tokens = [x for x in tokens if VAR_STR not in x] 78 | return tokens 79 | 80 | def get_bag_of_words(cmd): 81 | cmd = clean_anonymize_command(cmd) 82 | # try: 83 | # tokens = list(bashlex.split(cmd)) 84 | # except NotImplementedError: 85 | # tokens = cmd.strip().split() 86 | tokens = cmd.strip().split() 87 | return tokens 88 | 89 | def clean_manual(man_string): 90 | cur_man_line = [x.strip() for x in man_string.strip().split("\n") if len(x.strip().split()) >= 1] 91 | cur_man_line = [" ".join(x.split()) for x in cur_man_line] 92 | cur_man_line = " ".join(cur_man_line) 93 | return cur_man_line 94 | 95 | def clean_anonymize_command(s): 96 | return anonymize_command(clean_command(s)) 97 | 98 | # used for constraint command_name decoding 99 | class Trie(object): 100 | def __init__(self, sequences: List[List[int]] = []): 101 | self.trie_dict = {} 102 | self.len = 0 103 | if sequences: 104 | for sequence in sequences: 105 | Trie._add_to_trie(sequence, self.trie_dict) 106 | self.len += 1 107 | 108 | self.append_trie = None 109 | self.bos_token_id = None 110 | 111 | def append(self, trie, bos_token_id): 112 | self.append_trie = trie 113 | self.bos_token_id = bos_token_id 114 | 115 | def add(self, sequence: List[int]): 116 | Trie._add_to_trie(sequence, self.trie_dict) 117 | self.len += 1 118 | 119 | def get(self, prefix_sequence: List[int]): 120 | return Trie._get_from_trie( 121 | 
prefix_sequence, self.trie_dict, self.append_trie, self.bos_token_id 122 | ) 123 | 124 | @staticmethod 125 | def load_from_dict(trie_dict): 126 | trie = Trie() 127 | trie.trie_dict = trie_dict 128 | trie.len = sum(1 for _ in trie) 129 | return trie 130 | 131 | @staticmethod 132 | def _add_to_trie(sequence: List[int], trie_dict: Dict): 133 | if sequence: 134 | if sequence[0] not in trie_dict: 135 | trie_dict[sequence[0]] = {} 136 | Trie._add_to_trie(sequence[1:], trie_dict[sequence[0]]) 137 | 138 | @staticmethod 139 | def _get_from_trie( 140 | prefix_sequence: List[int], 141 | trie_dict: Dict, 142 | append_trie=None, 143 | bos_token_id: int = None, 144 | ): 145 | if len(prefix_sequence) == 0: 146 | output = list(trie_dict.keys()) 147 | if append_trie and bos_token_id in output: 148 | output.remove(bos_token_id) 149 | output += list(append_trie.trie_dict.keys()) 150 | return output 151 | elif prefix_sequence[0] in trie_dict: 152 | return Trie._get_from_trie( 153 | prefix_sequence[1:], 154 | trie_dict[prefix_sequence[0]], 155 | append_trie, 156 | bos_token_id, 157 | ) 158 | else: 159 | if append_trie: 160 | return append_trie.get(prefix_sequence) 161 | else: 162 | return [] 163 | 164 | def __iter__(self): 165 | def _traverse(prefix_sequence, trie_dict): 166 | if trie_dict: 167 | for next_token in trie_dict: 168 | yield from _traverse( 169 | prefix_sequence + [next_token], trie_dict[next_token] 170 | ) 171 | else: 172 | yield prefix_sequence 173 | 174 | return _traverse([], self.trie_dict) 175 | 176 | def __len__(self): 177 | return self.len 178 | 179 | def __getitem__(self, value): 180 | return self.get(value) 181 | 182 | def build_trie(): 183 | from glob import glob 184 | from transformers import AutoTokenizer 185 | tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neo-1.3B") 186 | ss = [] 187 | for cmd in glob("./data/tldr/manual_trimmed/*.txt"): 188 | cmd = os.path.basename(cmd).replace(".txt", "") 189 | tok_cmd = tokenizer(f" {cmd}")['input_ids'] 190 | ss.append([-1] + tok_cmd + [-2]) 191 | print(f"number of commands: {len(ss)}") 192 | trie = Trie(ss) 193 | 194 | with open("./data/tldr/nl.cm/cmd_trie.pkl", "wb") as f: 195 | pickle.dump(trie.trie_dict, f) 196 | 197 | 198 | def constrain_cmd_name_fn(cmd_trie, tokenizer, batch_idx, prefix_beam): 199 | sep_token_idx = tokenizer.sep_token_id 200 | if prefix_beam[-1] == sep_token_idx: # the first token 201 | next_tok = cmd_trie.get([-1]) 202 | else: 203 | # get the prefix 204 | prefix_idx = prefix_beam.index(sep_token_idx) 205 | prefix = [-1] + prefix_beam[prefix_idx+1:] 206 | next_tok = cmd_trie.get(prefix) 207 | # EOS or not a command anymore 208 | if [-2] in next_tok or next_tok == []: 209 | next_tok = [x for x in range(len(tokenizer))] 210 | 211 | return next_tok 212 | 213 | if __name__ == "__main__": 214 | cmd = "firejail --net={{eth0}} --ip={{192.168.1.244}} {{/etc/init.d/apache2}} {{start}}" 215 | print(clean_command(cmd)) 216 | print(anonymize_command(cmd)) 217 | print(anonymize_command(clean_command(cmd))) 218 | 219 | cmd = "toilet {{input_text}} -f {{font_filename}} {{font_filename}}" 220 | print(clean_command(cmd)) 221 | print(anonymize_command(cmd)) 222 | print(anonymize_command(clean_command(cmd))) 223 | # print(get_bag_of_keywords(cmd)) 224 | # build_trie() 225 | # with open("./data/tldr/nl.cm/cmd_trie.pkl", "rb") as f: 226 | # d = pickle.load(f) 227 | # trie = Trie.load_from_dict(d) 228 | # print(len(trie.get([-1]))) 229 | # print(trie.get([-1, 300])) 230 | 231 | # from transformers import AutoTokenizer 232 | # tokenizer 
= AutoTokenizer.from_pretrained("EleutherAI/gpt-neo-1.3B") 233 | # tokenizer.add_special_tokens({"sep_token": "|nl2code|"}) 234 | # nl = "disable ldap authentication |nl2code|" 235 | # tok_nl = tokenizer(nl)['input_ids'] 236 | # tt = [335, 499, 2364] 237 | # while tt: 238 | # next_tok = constrain_cmd_name_fn(trie, tokenizer, None, tok_nl) 239 | # assert tt[0] in next_tok, (tokenizer.convert_ids_to_tokens(tok_nl), tt[0], tokenizer.convert_ids_to_tokens(next_tok)) 240 | # tok_nl.append(tt.pop(0)) 241 | -------------------------------------------------------------------------------- /dataset_helper/conala/gen_metric.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Python implementation of BLEU and smooth-BLEU. 17 | 18 | This module provides a Python implementation of BLEU and smooth-BLEU. 19 | Smooth BLEU is computed following the method outlined in the paper: 20 | Chin-Yew Lin, Franz Josef Och. ORANGE: a method for evaluating automatic 21 | evaluation metrics for machine translation. COLING 2004. 22 | """ 23 | 24 | import collections 25 | import math 26 | import json 27 | import os.path 28 | import re 29 | 30 | 31 | def _get_ngrams(segment, max_order): 32 | """Extracts all n-grams up to a given maximum order from an input segment. 33 | 34 | Args: 35 | segment: text segment from which n-grams will be extracted. 36 | max_order: maximum length in tokens of the n-grams returned by this 37 | method. 38 | 39 | Returns: 40 | The Counter containing all n-grams up to max_order in segment 41 | with a count of how many times each n-gram occurred. 42 | """ 43 | ngram_counts = collections.Counter() 44 | for order in range(1, max_order + 1): 45 | for i in range(0, len(segment) - order + 1): 46 | ngram = tuple(segment[i:i + order]) 47 | ngram_counts[ngram] += 1 48 | return ngram_counts 49 | 50 | 51 | def compute_bleu(reference_corpus, translation_corpus, max_order=4, 52 | smooth=False): 53 | """Computes BLEU score of translated segments against one or more references. 54 | 55 | Args: 56 | reference_corpus: list of lists of references for each translation. Each 57 | reference should be tokenized into a list of tokens. 58 | translation_corpus: list of translations to score. Each translation 59 | should be tokenized into a list of tokens. 60 | max_order: Maximum n-gram order to use when computing BLEU score. 61 | smooth: Whether or not to apply Lin et al. 2004 smoothing. 62 | 63 | Returns: 64 | 6-tuple with the BLEU score, the per-order n-gram precisions, the brevity 65 | penalty, the length ratio, the translation length and the reference length.
66 | """ 67 | matches_by_order = [0] * max_order 68 | possible_matches_by_order = [0] * max_order 69 | reference_length = 0 70 | translation_length = 0 71 | for (references, translation) in zip(reference_corpus, 72 | translation_corpus): 73 | reference_length += min(len(r) for r in references) 74 | translation_length += len(translation) 75 | 76 | merged_ref_ngram_counts = collections.Counter() 77 | for reference in references: 78 | merged_ref_ngram_counts |= _get_ngrams(reference, max_order) 79 | translation_ngram_counts = _get_ngrams(translation, max_order) 80 | overlap = translation_ngram_counts & merged_ref_ngram_counts 81 | for ngram in overlap: 82 | matches_by_order[len(ngram) - 1] += overlap[ngram] 83 | for order in range(1, max_order + 1): 84 | possible_matches = len(translation) - order + 1 85 | if possible_matches > 0: 86 | possible_matches_by_order[order - 1] += possible_matches 87 | 88 | precisions = [0] * max_order 89 | for i in range(0, max_order): 90 | if smooth: 91 | precisions[i] = ((matches_by_order[i] + 1.) / 92 | (possible_matches_by_order[i] + 1.)) 93 | else: 94 | if possible_matches_by_order[i] > 0: 95 | precisions[i] = (float(matches_by_order[i]) / 96 | possible_matches_by_order[i]) 97 | # print(i, f"{precisions[i]:.03f}={float(matches_by_order[i]):.03f}/{possible_matches_by_order[i]}") 98 | else: 99 | precisions[i] = 0.0 100 | # print("========") 101 | if min(precisions) > 0: 102 | p_log_sum = sum((1. / max_order) * math.log(p) for p in precisions) 103 | geo_mean = math.exp(p_log_sum) 104 | else: 105 | geo_mean = 0 106 | 107 | ratio = float(translation_length) / reference_length 108 | 109 | if ratio > 1.0: 110 | bp = 1. 111 | else: 112 | bp = math.exp(1 - 1. / ratio) 113 | 114 | bleu = geo_mean * bp 115 | 116 | # print(bleu, precisions, bp, ratio, translation_length, reference_length) 117 | return (bleu, precisions, bp, ratio, translation_length, reference_length) 118 | 119 | 120 | """ The tokenizer that we use for code submissions, from Wang Ling et al., Latent Predictor Networks for Code Generation (2016) 121 | @param code: string containing a code snippet 122 | @return: list of code tokens 123 | """ 124 | 125 | 126 | def tokenize_for_bleu_eval(code): 127 | code = re.sub(r'([^A-Za-z0-9_])', r' \1 ', code) 128 | code = re.sub(r'([a-z])([A-Z])', r'\1 \2', code) 129 | code = re.sub(r'\s+', ' ', code) 130 | code = code.replace('"', '`') 131 | code = code.replace('\'', '`') 132 | tokens = [t for t in code.split(' ') if t] 133 | return tokens 134 | 135 | 136 | def _bleu(ref_file, trans_file, subword_option=None, smooth=True, code_tokenize=False): 137 | assert code_tokenize 138 | assert not smooth 139 | max_order = 4 140 | ref_files = [ref_file] 141 | reference_text = [] 142 | for reference_filename in ref_files: 143 | with open(reference_filename) as fh: 144 | reference_text.append(fh.readlines()) 145 | per_segment_references = [] 146 | for references in zip(*reference_text): 147 | reference_list = [] 148 | for reference in references: 149 | if code_tokenize: 150 | reference_list.append(tokenize_for_bleu_eval(reference.strip())) 151 | else: 152 | reference_list.append(reference.strip().split()) 153 | per_segment_references.append(reference_list) 154 | translations = [] 155 | with open(trans_file) as fh: 156 | for line in fh: 157 | if code_tokenize: 158 | translations.append(tokenize_for_bleu_eval(line.strip())) 159 | else: 160 | translations.append(line.strip().split()) 161 | print(f'src length: {len(per_segment_references)}, tgt length: {len(translations)}') 162 | 
bleu_score, _, _, _, _, _ = compute_bleu(per_segment_references, translations, max_order, smooth) 163 | return round(100 * bleu_score, 2) 164 | 165 | 166 | def enum_pred_file(split, re_tag, mode): 167 | if mode == 0: 168 | for model_name in ['model_13b_TY_MS0N1_DT', 'model_13b_TY_MS0N3_DT', 'model_13b_TY_MS0N5_DT', 169 | 'model_13b_TY_MS1N1_DT', 'model_13b_TY_MS1N3_DT', 'model_13b_TY_MS1N5_DT', 170 | 'model_13b_TY_MO_DT', 'model_13b_TY_MN_DT', 'model_13b_TY_MN_DCT', 171 | 'model_13b_TY_MO_DCT', 172 | 'model_13b_TY_MS1N3_DCT', 173 | 'model_13b_TY_MS1N5_DCT', 174 | 'model_13b_TY_MS1N3_DCT/CUT160', 175 | 'model_13b_TY_MS1T10_S160_DCT', 176 | 'model_13b_TY_MS1T10_S0_DCT', 177 | 'model_13b_TY_MS1T10_S0_RR_DCT', 178 | 'model_13b_TY_MN_DCT', 179 | 'model_13b_TY_MS1T10_DCT', 180 | 'model_13b_TY_MN_DCT_v2' 181 | ][:]: 182 | 183 | pred_file = f"./data/conala/models/{model_name}/decode.epochbest.t5.b5.l150.cn0.trie0.{split}{re_tag}.json" 184 | if os.path.exists(pred_file): 185 | yield pred_file 186 | 187 | # for i in list(range(0, 20)): 188 | # pred_file = f"./data/conala/models/{model_name}/decode.epoch{i:02d}.t5.b5.l150.cn0.trie0.{split}{re_tag}.json" 189 | # if os.path.exists(pred_file): 190 | # yield pred_file 191 | 192 | elif mode == 1: 193 | pred_d = f"./data/conala/models/model_13b_TY_MS1_DT/decode.epoch06.t5.b5.l150.cn0.trie0.random_test.oracle_man.json" 194 | yield pred_d 195 | 196 | elif mode == 2: 197 | for i in range(10): 198 | pred_file = f"./data/conala/models/model_13b_TY_MN_DM/decode.epoch{i:02d}.t5.b5.l150.cn0.trie0.{split}.json" 199 | if os.path.exists(pred_file): 200 | yield pred_file 201 | for i in range(10): 202 | pred_file = f"./data/conala/models/model_13b_TY_MN_DM/decode.epoch{i:02d}.t5.b5.l150.cn0.trie0.{split}.json.ft" 203 | if os.path.exists(pred_file): 204 | yield pred_file 205 | 206 | def clean_code(code): 207 | return code.replace("<|endoftext|>", "").strip() 208 | 209 | -------------------------------------------------------------------------------- /retriever/simcse/run_train.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import sys 4 | from dataclasses import dataclass, field 5 | from collections import defaultdict 6 | from datasets import load_dataset 7 | import wandb 8 | import transformers 9 | from transformers import ( 10 | CONFIG_MAPPING, 11 | MODEL_FOR_MASKED_LM_MAPPING, 12 | AutoConfig, 13 | AutoTokenizer, 14 | HfArgumentParser, 15 | set_seed, 16 | ) 17 | from transformers.trainer_utils import is_main_process 18 | from model import RetrievalModel 19 | from trainers import CLTrainer 20 | from data_utils import OurDataCollatorWithPadding, tok_sentences 21 | from arguments import ModelArguments, DataTrainingArguments, OurTrainingArguments, RetrieverArguments 22 | logger = logging.getLogger(__name__) 23 | MODEL_CONFIG_CLASSES = list(MODEL_FOR_MASKED_LM_MAPPING.keys()) 24 | MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES) 25 | 26 | 27 | def main(): 28 | # See all possible arguments in src/transformers/training_args.py 29 | # or by passing the --help flag to this script. 30 | # We now keep distinct sets of args, for a cleaner separation of concerns. 31 | 32 | parser = HfArgumentParser((ModelArguments, DataTrainingArguments, OurTrainingArguments, RetrieverArguments)) 33 | if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): 34 | # If we pass only one argument to the script and it's the path to a json file, 35 | # let's parse it to get our arguments. 
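# e.g. a hypothetical invocation `python run_train.py my_args.json` loads every argument from that file.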
36 | model_args, data_args, training_args, bertscore_args = parser.parse_json_file( 37 | json_file=os.path.abspath(sys.argv[1])) 38 | else: 39 | model_args, data_args, training_args, bertscore_args = parser.parse_args_into_dataclasses() 40 | 41 | if ( 42 | os.path.exists(training_args.output_dir) 43 | and os.listdir(training_args.output_dir) 44 | and training_args.do_train 45 | and not training_args.overwrite_output_dir 46 | ): 47 | raise ValueError( 48 | f"Output directory ({training_args.output_dir}) already exists and is not empty." 49 | "Use --overwrite_output_dir to overcome." 50 | ) 51 | 52 | # Setup logging 53 | logging.basicConfig( 54 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 55 | datefmt="%m/%d/%Y %H:%M:%S", 56 | level=logging.INFO if is_main_process(training_args.local_rank) else logging.WARN, 57 | ) 58 | 59 | # Log on each process the small summary: 60 | logger.warning( 61 | f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}" 62 | + f" distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}" 63 | ) 64 | # Set the verbosity to info of the Transformers logger (on main process only): 65 | if is_main_process(training_args.local_rank): 66 | transformers.utils.logging.set_verbosity_info() 67 | transformers.utils.logging.enable_default_handler() 68 | transformers.utils.logging.enable_explicit_format() 69 | logger.info("Training/evaluation parameters %s", training_args) 70 | 71 | # Set seed before initializing model. 72 | set_seed(training_args.seed) 73 | training_args.eval_file = data_args.eval_file 74 | 75 | # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below) 76 | # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/ 77 | # (the dataset will be downloaded automatically from the datasets Hub 78 | # 79 | # For CSV/JSON files, this script will use the column called 'text' or the first column. You can easily tweak this 80 | # behavior (see below) 81 | # 82 | # In distributed training, the load_dataset function guarantee that only one local process can concurrently 83 | # download the dataset. 84 | 85 | assert 'json' in data_args.train_file 86 | data_files = {'train': data_args.train_file} 87 | datasets = load_dataset('json', data_files=data_files) 88 | # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at 89 | # https://huggingface.co/docs/datasets/loading_datasets.html. 90 | 91 | # Load pretrained model and tokenizer 92 | # 93 | # Distributed training: 94 | # The .from_pretrained methods guarantee that only one local process can concurrently 95 | # download model & vocab. 
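    # Config resolution order below: --config_name wins over --model_name_or_path;
    # if neither is set, a fresh config for --model_type is instantiated from scratch.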
96 | config_kwargs = { 97 | "cache_dir": model_args.cache_dir, 98 | "revision": model_args.model_revision, 99 | "use_auth_token": True if model_args.use_auth_token else None, 100 | } 101 | if model_args.config_name: 102 | config = AutoConfig.from_pretrained(model_args.config_name, **config_kwargs) 103 | elif model_args.model_name_or_path: 104 | config = AutoConfig.from_pretrained(model_args.model_name_or_path, **config_kwargs) 105 | else: 106 | config = CONFIG_MAPPING[model_args.model_type]() 107 | logger.warning("You are instantiating a new config instance from scratch.") 108 | 109 | # tokenizer_kwargs = { 110 | # "cache_dir": model_args.cache_dir, 111 | # "use_fast": model_args.use_fast_tokenizer, 112 | # "revision": model_args.model_revision, 113 | # "use_auth_token": True if model_args.use_auth_token else None, 114 | # } 115 | assert model_args.model_name_or_path 116 | if 'codet5' in model_args.model_name_or_path: 117 | tokenizer = transformers.RobertaTokenizerFast.from_pretrained(model_args.model_name_or_path, add_prefix_space=True) 118 | else: 119 | tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path) 120 | 121 | assert model_args.model_name_or_path 122 | # assert training_args.sim_func == 'bertscore' 123 | model = RetrievalModel( 124 | config=config, 125 | model_type=model_args.model_name_or_path, 126 | num_layers=bertscore_args.num_layers, 127 | all_layers=bertscore_args.all_layers, 128 | idf=bertscore_args.idf, 129 | rescale_with_baseline=bertscore_args.rescale_with_baseline, 130 | baseline_path=bertscore_args.baseline_path, 131 | tokenizer=tokenizer, 132 | training_args = training_args, 133 | model_args=model_args) 134 | 135 | 136 | # load idf dict 137 | if bertscore_args.idf: 138 | raise NotImplementedError 139 | # assert _idf_dict, "IDF weights are not computed" 140 | # idf_dict = _idf_dict 141 | else: 142 | idf_dict = defaultdict(lambda: 1.0) 143 | idf_dict[tokenizer.sep_token_id] = 0 144 | idf_dict[tokenizer.cls_token_id] = 0 145 | 146 | def prepare_features(examples): 147 | total = len(examples['text1']) 148 | for idx in range(total): 149 | if examples['text1'][idx] == '': 150 | examples['text1'][idx] = " " 151 | if examples['text2'][idx] == '': 152 | examples['text2'][idx] = " " 153 | 154 | sentences = examples['text1'] + examples['text2'] 155 | features = tok_sentences(tokenizer, sentences, has_hard_neg=False, total=total, max_length=data_args.max_seq_length) 156 | return features 157 | 158 | 159 | if training_args.do_train: 160 | train_dataset = datasets['train'].map( 161 | prepare_features, 162 | batched=True, 163 | num_proc=data_args.preprocessing_num_workers, 164 | load_from_cache_file=not data_args.overwrite_cache, 165 | ) 166 | 167 | # Data collator 168 | data_collator = OurDataCollatorWithPadding(tokenizer.pad_token_id, idf_dict) 169 | 170 | training_args.remove_unused_columns = False 171 | trainer = CLTrainer( 172 | model=model, 173 | args=training_args, 174 | train_dataset=train_dataset if training_args.do_train else None, 175 | tokenizer=tokenizer, 176 | data_collator=data_collator, 177 | ) 178 | trainer.model_args = model_args 179 | trainer.epoch_metric = {} 180 | trainer.metric_for_best_model = training_args.metric_for_best_model 181 | training_args.do_eval = False 182 | 183 | # Training 184 | if training_args.do_train: 185 | model_path = ( 186 | model_args.model_name_or_path 187 | if (model_args.model_name_or_path is not None and os.path.isdir(model_args.model_name_or_path)) 188 | else None 189 | ) 190 | 191 | 
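        # model_path is only set when --model_name_or_path points at a local
        # checkpoint directory; otherwise training starts from the RetrievalModel
        # initialized above.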
trainer.train(model_path=model_path) 192 | 193 | # Evaluation 194 | results = {} 195 | if training_args.do_eval: 196 | logger.info("*** Evaluate ***") 197 | results = trainer.evaluate(eval_senteval_transfer=True) 198 | 199 | output_eval_file = os.path.join(training_args.output_dir, "eval_results.txt") 200 | if trainer.is_world_process_zero(): 201 | with open(output_eval_file, "w") as writer: 202 | logger.info("***** Eval results *****") 203 | for key, value in sorted(results.items()): 204 | logger.info(f" {key} = {value}") 205 | writer.write(f"{key} = {value}\n") 206 | 207 | return results 208 | 209 | 210 | def _mp_fn(index): 211 | # For xla_spawn (TPUs) 212 | main() 213 | 214 | 215 | if __name__ == "__main__": 216 | main() 217 | -------------------------------------------------------------------------------- /retriever/eval.py: -------------------------------------------------------------------------------- 1 | import json 2 | import copy 3 | from collections import OrderedDict 4 | 5 | import numpy as np 6 | import argparse 7 | 8 | TOP_K = [1, 3, 5, 8, 10, 12, 15, 20, 30, 50, 100, 200] 9 | 10 | def align_src_pred(src_file, pred_file): 11 | with open(src_file, "r") as fsrc, open(pred_file, "r") as fpred: 12 | src = json.load(fsrc) 13 | pred = json.load(fpred)['results'] 14 | # assert len(src) == len(pred), (len(src), len(pred)) 15 | 16 | # re-order src 17 | src_nl = [x['nl'] for x in src] 18 | _src = [] 19 | _pred = [] 20 | for p in pred: 21 | if p['nl'] in src_nl: 22 | _src.append(src[src_nl.index(p['nl'])]) 23 | _pred.append(p) 24 | 25 | src = _src 26 | pred = _pred 27 | 28 | for s, p in zip(src, pred): 29 | assert s['nl'] == p['nl'], (s['nl'], p['nl']) 30 | 31 | print(f"unique nl: {len(set(src_nl))}") 32 | print(f"number of samples (src/pred): {len(src)}/{len(pred)}") 33 | print("pass nl matching check") 34 | 35 | return src, pred 36 | 37 | def calc_metrics(src_file, pred_file): 38 | src, pred = align_src_pred(src_file, pred_file) 39 | 40 | _src = [] 41 | _pred = [] 42 | for s, p in zip(src, pred): 43 | cmd_name = s['cmd_name'] 44 | oracle_man = get_oracle(s, cmd_name) 45 | pred_man = p['pred'] 46 | _src.append(oracle_man) 47 | _pred.append(pred_man) 48 | calc_recall(_src, _pred) 49 | 50 | # description only 51 | _src = [] 52 | for s in src: 53 | _src.append(s['matching_info']['|main|']) 54 | calc_recall(_src, _pred) 55 | 56 | _src = [] 57 | _pred = [] 58 | for s, p in zip(src, pred): 59 | cmd_name = s['cmd_name'] 60 | pred_man = p['pred'] 61 | _src.append(cmd_name) 62 | _pred.append(pred_man) 63 | calc_hit(_src, _pred) 64 | # calc_mean_rank(src, pred) 65 | 66 | 67 | def calc_mean_rank(src, pred): 68 | rank = [] 69 | for s, p in zip(src, pred): 70 | cur_rank = [] 71 | cmd_name = s['cmd_name'] 72 | pred_man = p['pred'] 73 | oracle_man = get_oracle(s, cmd_name) 74 | for o in oracle_man: 75 | if o in pred_man: 76 | cur_rank.append(oracle_man.index(o)) 77 | else: 78 | cur_rank.append(101) 79 | if cur_rank: 80 | rank.append(np.mean(cur_rank)) 81 | 82 | print(np.mean(rank)) 83 | 84 | 85 | def calc_hit(src, pred, top_k=None): 86 | top_k = TOP_K if top_k is None else top_k 87 | hit_n = {x: 0 for x in top_k} 88 | assert len(src) == len(pred), (len(src), len(pred)) 89 | 90 | for s, p in zip(src, pred): 91 | cmd_name = s 92 | pred_man = p 93 | 94 | for tk in hit_n.keys(): 95 | cur_result_vids = pred_man[:tk] 96 | cur_hit = any([cmd_name in x for x in cur_result_vids]) 97 | hit_n[tk] += cur_hit 98 | 99 | hit_n = {k: v / len(pred) for k, v in hit_n.items()} 100 | for k in sorted(hit_n.keys()): 101 | 
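        # hit_n now holds hit@k per cutoff, e.g. gold cmd "tar" with preds
        # ["tar_1", "gzip_2"] scores hit@1 = 1 because "tar" is a substring of "tar_1"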
print(f"{hit_n[k] :.3f}", end="\t") 102 | print() 103 | return hit_n 104 | 105 | def get_oracle(item, cmd_name): 106 | # oracle = [f"{cmd_name}_{x}" for x in itertools.chain(*item['matching_info'].values())] 107 | oracle = [f"{cmd_name}_{x}" for x in item['oracle_man']] 108 | return oracle 109 | 110 | def calc_recall(src, pred, print_result=True, top_k=None): 111 | top_k = TOP_K if top_k is None else top_k 112 | recall_n = {x: 0 for x in top_k} 113 | precision_n = {x: 0 for x in top_k} 114 | 115 | for s, p in zip(src, pred): 116 | # cmd_name = s['cmd_name'] 117 | oracle_man = s 118 | pred_man = p 119 | 120 | for tk in recall_n.keys(): 121 | cur_result_vids = pred_man[:tk] 122 | cur_hit = sum([x in cur_result_vids for x in oracle_man]) 123 | # recall_n[tk] += cur_hit / (len(oracle_man) + 1e-10) 124 | recall_n[tk] += cur_hit / (len(oracle_man)) if len(oracle_man) else 1 125 | precision_n[tk] += cur_hit / tk 126 | recall_n = {k: v / len(pred) for k, v in recall_n.items()} 127 | precision_n = {k: v / len(pred) for k, v in precision_n.items()} 128 | 129 | if print_result: 130 | for k in sorted(recall_n.keys()): 131 | print(f"{recall_n[k] :.3f}", end="\t") 132 | print() 133 | for k in sorted(precision_n.keys()): 134 | print(f"{precision_n[k] :.3f}", end="\t") 135 | print() 136 | for k in sorted(recall_n.keys()): 137 | print(f"{2 * precision_n[k] * recall_n[k] / (precision_n[k] + recall_n[k] + 1e-10) :.3f}", end="\t") 138 | print() 139 | 140 | return {'recall': recall_n, 'precision': precision_n} 141 | 142 | def clean_dpr_results(result_file): 143 | results = {'results': [], 'metrics': {}} 144 | with open(result_file, "r") as f: 145 | d = json.load(f) 146 | for _item in d: 147 | item = {} 148 | item['nl'] = _item['question'] 149 | item['pred'] = [x['id'] for x in _item['ctxs']] 150 | results['results'].append(item) 151 | 152 | with open(result_file + ".clean", "w+") as f: 153 | json.dump(results, f, indent=2) 154 | 155 | def recall_per_manual(src_file, result_file, chunk_length_file, topk): 156 | 157 | def find_sum_in_list(len_list, max_num): 158 | idx = len(len_list) 159 | for i in range(len(len_list) + 1): 160 | if sum(len_list[:i]) >= max_num: 161 | idx = i - 1 162 | break 163 | assert sum(len_list[:idx]) <= max_num 164 | return idx 165 | 166 | with open(chunk_length_file, "r") as f: 167 | d = json.load(f) 168 | man_chunk_length = {k: len(v) for k, v in d.items()} 169 | 170 | src, pred = align_src_pred(src_file, result_file) 171 | hit_man = 0 172 | recall = 0 173 | tot = len(src) 174 | for s, p in zip(src, pred): 175 | cmd_name = s['cmd_name'] 176 | oracle_man = get_oracle(s, cmd_name) 177 | pred_man = p['pred'] 178 | top_k_cmd = p['top_pred_cmd'][:topk] 179 | if cmd_name in top_k_cmd: 180 | hit_man += 1 181 | pred_chunks = pred_man[cmd_name] 182 | len_list = [man_chunk_length[x] for x in pred_chunks] 183 | idx = find_sum_in_list(len_list, 1536) 184 | pred_chunks = pred_chunks[:idx] 185 | cur_hit = sum([x in pred_chunks for x in oracle_man]) 186 | recall += cur_hit / len(oracle_man) 187 | 188 | print(f"hit rate: {hit_man}/{tot}={hit_man/tot}") 189 | print(f"recall: {recall}/{hit_man}={recall/hit_man}") 190 | 191 | 192 | def eval_hit_from_file(data_file, retrieval_file, 193 | oracle_entry='oracle_man', retrieval_entry='retrieved'): 194 | assert 'tldr' in data_file 195 | with open(data_file, "r") as f: 196 | d = json.load(f) 197 | gold = ['_'.join(item[oracle_entry][0].split("_")[:-1]) for item in d] 198 | 199 | with open(retrieval_file, "r") as f: 200 | r_d = json.load(f) 201 | # check whether 
we need to process the retrieved ids 202 | split_flag = False 203 | k0 = list(r_d.keys())[0] 204 | r0 = r_d[k0][retrieval_entry][0] 205 | if r0.split("_")[-1].isdigit(): 206 | split_flag = True 207 | 208 | for k, item in r_d.items(): 209 | if split_flag: 210 | r = ['_'.join(x.split("_")[:-1]) for x in item[retrieval_entry]] 211 | else: 212 | r = item[retrieval_entry] 213 | r = list(OrderedDict.fromkeys(r)) 214 | item[retrieval_entry] = r 215 | 216 | pred = [r_d[x['question_id']][retrieval_entry] for x in d] 217 | print(gold[:3]) 218 | print(pred[0][:3]) 219 | metrics = calc_hit(gold, pred) 220 | return {'hit': metrics} 221 | 222 | def eval_retrieval_from_file(data_file, retrieval_file, 223 | oracle_entry='oracle_man', retrieval_entry='retrieved', top_k=None): 224 | 225 | assert 'oracle_man.full' in data_file or 'conala' not in data_file, (data_file) 226 | # for conala 227 | with open(data_file, "r") as f: 228 | d = json.load(f) 229 | gold = [item[oracle_entry] for item in d] 230 | 231 | with open(retrieval_file, "r") as f: 232 | r_d = json.load(f) 233 | pred = [r_d[x['question_id']][retrieval_entry] for x in d] 234 | metrics = calc_recall(gold, pred, top_k=top_k) 235 | return metrics 236 | 237 | def eval_retrieval_from_loaded(data_file, r_d): 238 | # for conala 239 | with open(data_file, "r") as f: 240 | d = json.load(f) 241 | gold = [item['oracle_man'] for item in d] 242 | pred = [r_d[x['question_id']]['retrieved'] for x in d] 243 | metrics = calc_recall(gold, pred, print_result=False) 244 | return metrics 245 | 246 | if __name__ == "__main__": 247 | parser = argparse.ArgumentParser() 248 | parser.add_argument('--result-file', default=None) 249 | parser.add_argument('--src-file', default=None) 250 | parser.add_argument('--chunk-length-file', default="data/tldr/nl.cm/manual_section.tok.json") 251 | parser.add_argument('--function', nargs='+', type=int, default=[1]) 252 | args = parser.parse_args() 253 | 254 | for cur_func in args.function: 255 | if cur_func == 0: 256 | calc_metrics(args.src_file, args.result_file) 257 | elif cur_func == 1: 258 | # convert data 259 | clean_dpr_results(args.result_file) 260 | elif cur_func == 2: 261 | clean_dpr_results(args.result_file) 262 | args.result_file += ".clean" 263 | calc_metrics(args.src_file, args.result_file) 264 | elif cur_func == 3: 265 | # measure recall for per-doc retrieval 266 | for k in [1, 10, 30, 50]: 267 | print(f"top {k}") 268 | recall_per_manual(args.src_file, args.result_file, args.chunk_length_file, topk=k) 269 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. 
For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 
134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 
193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /retriever/simcse/arguments.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from dataclasses import dataclass, field 3 | from typing import Optional, Union, List, Dict, Tuple 4 | import torch 5 | from transformers import ( 6 | MODEL_FOR_MASKED_LM_MAPPING, 7 | TrainingArguments, 8 | ) 9 | from transformers.file_utils import cached_property, torch_required, is_torch_available, is_torch_tpu_available 10 | 11 | logger = logging.getLogger(__name__) 12 | MODEL_CONFIG_CLASSES = list(MODEL_FOR_MASKED_LM_MAPPING.keys()) 13 | MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES) 14 | 15 | @dataclass 16 | class ModelArguments: 17 | """ 18 | Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch. 19 | """ 20 | 21 | # Huggingface's original arguments 22 | model_name_or_path: Optional[str] = field( 23 | default=None, 24 | metadata={ 25 | "help": "The model checkpoint for weights initialization." 26 | "Don't set if you want to train a model from scratch." 27 | }, 28 | ) 29 | 30 | mlp_weight_path: Optional[str] = field( 31 | default=None, 32 | metadata={ 33 | "help": "mlp weight path" 34 | }, 35 | ) 36 | 37 | model_type: Optional[str] = field( 38 | default=None, 39 | metadata={"help": "If training from scratch, pass a model type from the list: " + ", ".join(MODEL_TYPES)}, 40 | ) 41 | config_name: Optional[str] = field( 42 | default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"} 43 | ) 44 | tokenizer_name: Optional[str] = field( 45 | default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"} 46 | ) 47 | cache_dir: Optional[str] = field( 48 | default=None, 49 | metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"}, 50 | ) 51 | use_fast_tokenizer: bool = field( 52 | default=True, 53 | metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."}, 54 | ) 55 | model_revision: str = field( 56 | default="main", 57 | metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."}, 58 | ) 59 | use_auth_token: bool = field( 60 | default=False, 61 | metadata={ 62 | "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script " 63 | "with private models)." 64 | }, 65 | ) 66 | 67 | pooler_type: str = field( 68 | default="cls", 69 | metadata={ 70 | "help": "What kind of pooler to use (cls, cls_before_pooler, avg, avg_top2, avg_first_last)." 
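            # hedged note (SimCSE's convention, not documented in this repo): 'cls'
            # feeds the first-token representation through the extra MLP head, while
            # 'cls_before_pooler' takes the raw first-token hidden state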
71 | } 72 | ) 73 | 74 | mlp_only_train: bool = field( 75 | default=False, 76 | metadata={ 77 | "help": "Use MLP only during training" 78 | } 79 | ) 80 | 81 | sim_func: str = field( 82 | default='cls_distance', 83 | metadata={"help": "the similarity function", 84 | "choices": ['cls_distance.cosine', 'cls_distance.l2', 'bertscore']} 85 | ) 86 | 87 | 88 | bert_score_loss: str = field( 89 | default='softmax', 90 | metadata={'help': 'loss function for bertscore sim function', 91 | 'choices': ['softmax', 'hinge']} 92 | ) 93 | 94 | hinge_margin: float = field( 95 | default=1.0 96 | ) 97 | 98 | hard_negative_weight: float = field( 99 | default=0, 100 | metadata={ 101 | "help": "The **logit** of weight for hard negatives (only effective if hard negatives are used)." 102 | } 103 | ) 104 | 105 | # SimCSE's arguments 106 | temp: float = field( 107 | default=0.05, 108 | metadata={ 109 | "help": "Temperature for softmax." 110 | } 111 | ) 112 | 113 | do_mlm: bool = field( 114 | default=False, 115 | metadata={ 116 | "help": "Whether to use MLM auxiliary objective." 117 | } 118 | ) 119 | 120 | mlm_weight: float = field( 121 | default=0.1, 122 | metadata={ 123 | "help": "Weight for MLM auxiliary objective (only effective if --do_mlm)." 124 | } 125 | ) 126 | 127 | def __post_init__(self): 128 | if self.sim_func == 'cls_distance.l2': 129 | self.temp = 1 130 | 131 | 132 | 133 | @dataclass 134 | class DataTrainingArguments: 135 | """ 136 | Arguments pertaining to what data we are going to input our model for training and eval. 137 | """ 138 | 139 | # Huggingface's original arguments. 140 | dataset_name: Optional[str] = field( 141 | default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."} 142 | ) 143 | dataset_config_name: Optional[str] = field( 144 | default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."} 145 | ) 146 | overwrite_cache: bool = field( 147 | default=False, metadata={"help": "Overwrite the cached training and evaluation sets"} 148 | ) 149 | validation_split_percentage: Optional[int] = field( 150 | default=5, 151 | metadata={ 152 | "help": "The percentage of the train set used as validation set in case there's no validation split" 153 | }, 154 | ) 155 | preprocessing_num_workers: Optional[int] = field( 156 | default=None, 157 | metadata={"help": "The number of processes to use for the preprocessing."}, 158 | ) 159 | 160 | # SimCSE's arguments 161 | train_file: Optional[str] = field( 162 | default=None, 163 | metadata={"help": "The training data file (.txt or .csv)."} 164 | ) 165 | 166 | eval_file: Optional[str] = field( 167 | default=None, 168 | metadata={"help": "The eval data file (.txt or .csv)."} 169 | ) 170 | 171 | 172 | max_seq_length: Optional[int] = field( 173 | default=32, 174 | metadata={ 175 | "help": "The maximum total input sequence length after tokenization. Sequences longer " 176 | "than this will be truncated." 177 | }, 178 | ) 179 | pad_to_max_length: bool = field( 180 | default=False, 181 | metadata={ 182 | "help": "Whether to pad all samples to `max_seq_length`. " 183 | "If False, will pad the samples dynamically when batching to the maximum length in the batch." 
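            # dynamic padding (the default) pads each batch only to its own longest
            # sequence, which is typically faster than always padding to max_seq_length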
184 |         },
 185 |     )
 186 |     mlm_probability: float = field(
 187 |         default=0.15,
 188 |         metadata={"help": "Ratio of tokens to mask for MLM (only effective if --do_mlm)"}
 189 |     )
 190 | 
 191 |     def __post_init__(self):
 192 |         if self.dataset_name is None and self.train_file is None and self.eval_file is None:
 193 |             raise ValueError("Need either a dataset name or a training/evaluation file.")
 194 |         else:
 195 |             if self.train_file is not None:
 196 |                 extension = self.train_file.split(".")[-1]
 197 |                 assert extension in ["csv", "json", "txt"], "`train_file` should be a csv, a json or a txt file."
 198 | 
 202 | 
 203 | @dataclass
 204 | class OurTrainingArguments(TrainingArguments):
 205 |     # Evaluation
 206 |     ## By default, we evaluate STS (dev) during training (for selecting best checkpoints) and evaluate
 207 |     ## both STS and transfer tasks (dev) at the end of training. Using --eval_transfer will allow evaluating
 208 |     ## both STS and transfer tasks (dev) during training.
 209 |     eval_transfer: bool = field(
 210 |         default=False,
 211 |         metadata={"help": "Evaluate transfer task dev sets (in validation)."}
 212 |     )
 213 | 
 214 |     customized_eval: bool = field(
 215 |         default=True,
 216 |         metadata={"help": "If True, evaluate on the user's own data; otherwise evaluate on the original STS dev sets"}
 217 |     )
 218 | 
 219 |     customized_eval_used_split: Optional[str] = field(
 220 |         default='dev'
 221 |     )
 222 | 
 223 |     tmp_tag: Optional[str] = field(
 224 |         default='tmp',
 225 |         metadata={'help': 'tag to save tmp models in case of overwriting'}
 226 |     )
 227 | 
 228 |     report_to: Optional[str] = field(
 229 |         default='wandb'
 230 |     )
 231 | 
 232 |     logging_steps: int = field(
 233 |         default=1
 234 |     )
 235 | 
 236 |     logging_dir: Optional[str] = field(
 237 |         default='logs'
 238 |     )
 239 | 
 240 |     disable_tqdm: bool = field(
 241 |         default=True
 242 |     )
 243 | 
 244 |     eval_form: str = field(
 245 |         default='reranking',
 246 |         metadata={'choices': ['reranking', 'retrieval']}
 247 |     )
 248 | 
 249 |     eval_retriever: str = field(
 250 |         default='t5',
 251 |         metadata={'choices': ['mlm', 't5']},
 252 |     )
 253 | 
 254 |     eval_src_file: str = field(
 255 |         default='conala_nl.txt'
 256 |     )
 257 | 
 258 |     eval_tgt_file: str = field(
 259 |         default='python_manual_firstpara.tok.txt'
 260 |     )
 261 | 
 262 |     eval_root_folder: str = field(
 263 |         default='data/conala',
 264 |         metadata={'help': 'root folder of validation dataset'}
 265 |     )
 266 | 
 267 | 
 268 |     eval_oracle_file: str = field(
 269 |         default='cmd_dev.oracle_man.full.json'
 270 |     )
 271 | 
 272 |     # eval_max_length: int = field(
 273 |     #     default=None,
 274 |     #     metadata={'help': 'the length for dev set, None will call the max length of the tokenizer'}
 275 |     # )
 276 | 
 277 | 
 278 |     @cached_property
 279 |     @torch_required
 280 |     def _setup_devices(self) -> "torch.device":
 281 |         logger.info("PyTorch: setting up devices")
 282 |         if self.no_cuda:
 283 |             device = torch.device("cpu")
 284 |             self._n_gpu = 0
 285 |         elif is_torch_tpu_available():
 286 |             device = xm.xla_device()
 287 |             self._n_gpu = 0
 288 |         elif self.local_rank == -1:
 289 |             # if n_gpu is > 1 we'll use nn.DataParallel.
 290 |             # If you only want to use a specific subset of GPUs use `CUDA_VISIBLE_DEVICES=0`
 291 |             # Explicitly set CUDA to the first (index 0) CUDA device, otherwise `set_device` will
 292 |             # trigger an error that a device index is missing. Index 0 takes into account the
 293 |             # GPUs available in the environment, so `CUDA_VISIBLE_DEVICES=1,2` with `cuda:0`
 294 |             # will use the first GPU in that env, i.e.
GPU#1 295 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 296 | # Sometimes the line in the postinit has not been run before we end up here, so just checking we're not at 297 | # the default value. 298 | self._n_gpu = torch.cuda.device_count() 299 | else: 300 | # Here, we'll use torch.distributed. 301 | # Initializes the distributed backend which will take care of synchronizing nodes/GPUs 302 | # 303 | # deepspeed performs its own DDP internally, and requires the program to be started with: 304 | # deepspeed ./program.py 305 | # rather than: 306 | # python -m torch.distributed.launch --nproc_per_node=2 ./program.py 307 | if self.deepspeed: 308 | from .integrations import is_deepspeed_available 309 | 310 | if not is_deepspeed_available(): 311 | raise ImportError("--deepspeed requires deepspeed: `pip install deepspeed`.") 312 | import deepspeed 313 | 314 | deepspeed.init_distributed() 315 | else: 316 | torch.distributed.init_process_group(backend="nccl") 317 | device = torch.device("cuda", self.local_rank) 318 | self._n_gpu = 1 319 | 320 | if device.type == "cuda": 321 | torch.cuda.set_device(device) 322 | 323 | return device 324 | 325 | @dataclass 326 | class RetrieverArguments: 327 | """ 328 | model_type=model_args.model_name_or_path, 329 | num_layers=bertscore_args.bertscore_layer_num, 330 | all_layers=bertscore_args.all_layers, 331 | idf = bertscore_args.idf, 332 | idf_sents= bertscore_args.idf_sents, 333 | rescale_with_baseline=bertscore_args.rescale_with_baseline, 334 | baseline_path=bertscore_args.baseline_path 335 | """ 336 | num_layers: int = field( 337 | default=11 338 | ) 339 | all_layers: bool = field( 340 | default=False 341 | ) 342 | idf: bool = field( 343 | default=False 344 | ) 345 | rescale_with_baseline: bool = field( 346 | default=False 347 | ) 348 | baseline_path: str = field( 349 | default=None 350 | ) -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # DocPrompting: Generating Code by Retrieving the Docs 2 | This is the official implementation of 3 | 4 | Shuyan Zhou, Uri Alon, Frank F. Xu, Zhiruo Wang, Zhengbao Jiang, Graham Neubig, ["DocPrompting: Generating Code by Retrieving the Docs"](https://arxiv.org/pdf/2207.05987.pdf), 5 | ICLR'2023 (**Spotlight**) 6 | 7 | _**January 2023**_ - The paper was accepted to ICLR'2023 as a **Spotlight**! 8 | 9 | --- 10 | Publicly available source-code libraries are continuously growing and changing. 11 | This makes it impossible for models of code to keep current with all available APIs by simply training these models 12 | on existing code repositories. 13 | We introduce DocPrompting: a natural-language-to-code generation approach that explicitly leverages documentation by 14 | 1. retrieving the relevant documentation pieces given an NL intent, 15 | and 16 | 2. generating code based on the NL intent and the retrieved documentation. 17 | 18 | In this repository we provide the *best* model in each setting described in the paper. 
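
To make the two steps concrete, here is a tiny, self-contained sketch of the retrieve-then-generate loop. The lexical scorer is a toy stand-in (the repo's actual retrievers are BM25/Elasticsearch and a dense CodeT5 retriever), and the doc strings are illustrative:

```python
from collections import Counter

manuals = [
    "os.path.exists(path): Return True if path refers to an existing path.",
    "os.makedirs(name): Recursive directory creation function.",
    "json.dumps(obj): Serialize obj to a JSON formatted str.",
]

def retrieve(intent, docs, k=2):
    # step 1: score every doc by token overlap with the NL intent
    q = Counter(intent.lower().split())
    return sorted(docs, key=lambda d: -sum((q & Counter(d.lower().split())).values()))[:k]

# step 2: condition the generator on the retrieved docs plus the intent
docs = retrieve("check if a path exists", manuals)
prompt = "\n".join(docs) + "\ncheck if a path exists\n"
print(prompt)
```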
19 | 20 | ![overview](media/overview.png) 21 | 22 | ## Table of content 23 | - [Quick Dataset&Eval Access through 🤗](#huggingface--dataset--evaluation) 24 | - [Quick Models Loading 🤗](#huggingface--models) 25 | - [Preparation](#preparation) 26 | - [Retrieval](#retrieval) 27 | * [Dense retrieval](#dense-retrieval) 28 | * [Sparse retrieval](#sparse-retrieval) 29 | - [Generation](#generation) 30 | * [FID generation](#fid-generation) 31 | - [Data](#data) 32 | - [Resources](#resources) 33 | - [Citation](#citation) 34 | 35 | --- 36 | 37 | ## Huggingface 🤗 Dataset & Evaluation 38 | In this work, we introduce a new natural language to bash generation benchmark `tldr` 39 | and re-split `CoNaLa` to have *unseen* functions on the dev and test set. 40 | The datasets and the corresponding evaluations are available on huggingface 41 | * [tldr](https://huggingface.co/datasets/neulab/tldr) and [eval](https://huggingface.co/spaces/neulab/tldr_eval) 42 | * [CoNaLa](https://huggingface.co/datasets/neulab/docprompting-conala) and [eval](https://huggingface.co/spaces/neulab/python_bleu) 43 | ```python 44 | import datasets 45 | import evaluate 46 | tldr = datasets.load_dataset('neulab/tldr') 47 | tldr_metric = evaluate.load('neulab/tldr_eval') 48 | 49 | conala = datasets.load_dataset('neulab/docprompting-conala') 50 | conala_metric = evaluate.load('neulab/python_bleu') 51 | ``` 52 | 53 | ## Huggingface 🤗 Models 54 | We make the following models available on Huggingface: 55 | 56 | * neulab/docprompting-tldr-gpt-neo-125M 57 | * neulab/docprompting-tldr-gpt-neo-1.3B 58 | 59 | ### Example usage 60 | ```python 61 | from transformers import AutoTokenizer, AutoModelForCausalLM 62 | tokenizer = AutoTokenizer.from_pretrained("neulab/docprompting-tldr-gpt-neo-1.3B") 63 | model = AutoModelForCausalLM.from_pretrained("neulab/docprompting-tldr-gpt-neo-1.3B") 64 | 65 | # prompt template 66 | prompt = f"""{tokenizer.bos_token} Potential manual 0: makepkg - package build utility 67 | Potential manual 1: -c, --clean Clean up leftover work files and directories after a successful build. 68 | Potential manual 2: -r, --rmdeps Upon successful build, remove any dependencies installed by makepkg during dependency auto-resolution and installation when using -s 69 | Potential manual 3: CONTENT_OF_THE_MANUAL_3 70 | ... 71 | Potential manual 10: CONTENT_OF_THE_MANUAL_10""" 72 | prompt += f"{tokenizer.sep_token} clean up work directories after a successful build {tokenizer.sep_token}" 73 | 74 | input_ids = tokenizer(prompt, return_tensors="pt").input_ids 75 | gen_tokens = model.generate( 76 | input_ids, 77 | num_beams=5, 78 | max_new_tokens=150, 79 | num_return_sequences=2, 80 | pad_token_id=tokenizer.eos_token_id 81 | ) 82 | gen_tokens = gen_tokens.reshape(1, -1, gen_tokens.shape[-1])[0][0] 83 | # to text and clean 84 | gen_code = tokenizer.decode(gen_tokens) 85 | gen_code = gen_code.split(tokenizer.sep_token)[2].strip().split(tokenizer.eos_token)[0].strip() 86 | print(gen_code) 87 | 88 | >>> makepkg --clean {{path/to/directory}} 89 | ``` 90 | 91 | ### Example script 92 | An example script on tldr by using the retrieved docs is [here](./scripts/tldr_gpt_neo.py) 93 | 94 | ### Other models 95 | Other models require the customized implementations in our repo, please read through the corresponding sections to use them. These models are: 96 | 1. sparse retriever based on BM25 for `tldr` 97 | 2. dense retriever based on CodeT5 for `CoNaLa` 98 | 3. FiD T5 generator for `tldr` 99 | 4. 
FiD CodeT5 generator for `CoNaLa`
 100 | 
 101 | ---
 102 | > The following instructions are for reproducing the results in the paper.
 103 | 
 104 | ## Preparation
 105 | Download data for `CoNaLa` and `tldr` from [link](https://drive.google.com/file/d/1CzNlo8-e4XqrgAME5zHEWEKIQMPga0xl/view?usp=sharing)
 106 | ```bash
 107 | # unzip
 108 | unzip docprompting_data.zip
 109 | # move to the data folder
 110 | mv docprompting_data/* data
 111 | ```
 112 | 
 113 | Download trained generator weights from [link](https://drive.google.com/file/d/1NmPMxY1EOWkjM7S8VSKa13DKJmEZ3TqV/view?usp=sharing)
 114 | ```bash
 115 | unzip docprompting_generator_models.zip
 116 | # move to the model folder
 117 | mv docprompting_generator_models/* models/generator
 118 | ```
 119 | 
 120 | ## Retrieval
 121 | ### Dense retrieval
 122 | (`CoNaLa` as an example)
 123 | 
 124 | The code is based on [SimCSE](https://github.com/princeton-nlp/SimCSE).
 125 | 
 126 | 1. Run inference with our trained model on CoNaLa (Python)
 127 | ```bash
 128 | python retriever/simcse/run_inference.py \
 129 |     --model_name "neulab/docprompting-codet5-python-doc-retriever" \
 130 |     --source_file data/conala/conala_nl.txt \
 131 |     --target_file data/conala/python_manual_firstpara.tok.txt \
 132 |     --source_embed_save_file data/conala/.tmp/src_embedding \
 133 |     --target_embed_save_file data/conala/.tmp/tgt_embedding \
 134 |     --sim_func cls_distance.cosine \
 135 |     --num_layers 12 \
 136 |     --save_file data/conala/retrieval_results.json
 137 | ```
 138 | We observed that whether or not the model normalizes the embeddings can affect the retrieval results.
 139 | We therefore selected this hyper-parameter (`--normalize_embed`) on the validation set.
 140 | 
 141 | The results will be saved to `data/conala/retrieval_results.json`.
 142 | 
 143 | 2. Train your own retriever
 144 | ```bash
 145 | python retriever/simcse/run_train.py \
 146 |     --num_layers 12 \
 147 |     --model_name_or_path Salesforce/codet5-base \
 148 |     --sim_func cls_distance.cosine \
 149 |     --temp 0.05 \
 150 |     --train_file data/conala/train_retriever_sup_unsup.json \
 151 |     --eval_file data/conala/dev_retriever.json \
 152 |     --output_dir models/retriever/docprompting_codet5_python_doc_retriever \
 153 |     --eval_src_file data/conala/conala_nl.txt \
 154 |     --eval_tgt_file data/conala/python_manual_firstpara.tok.txt \
 155 |     --eval_root_folder data/conala \
 156 |     --eval_oracle_file data/conala/cmd_dev.oracle_man.full.json \
 157 |     --run_name docprompting_codet5_python_doc_retriever \
 158 |     --num_train_epochs 10 \
 159 |     --per_device_train_batch_size 512 \
 160 |     --learning_rate 1e-5 \
 161 |     --max_seq_length 32 \
 162 |     --evaluation_strategy steps \
 163 |     --metric_for_best_model recall@10 \
 164 |     --load_best_model_at_end \
 165 |     --eval_steps 125 \
 166 |     --overwrite_output_dir \
 167 |     --do_train \
 168 |     --eval_form retrieval
 170 | ```
 171 | * `train_retriever_sup_unsup.json` contains the supervised (`CoNaLa` training and mined) and unsupervised data (duplication of sentences in a doc) for training the retriever.
 172 | * Name the saved model carefully: if you use CodeT5, make sure `codet5` appears in the model name.
 173 | 
 174 | ### Sparse retrieval
 175 | (`tldr` as an example)
 176 | 
 177 | There are two stages in the retrieval procedure in `tldr`.
 178 | The first stage retrieves the bash command, and the second stage retrieves the potentially relevant paragraphs that describe the usage of its arguments.
 179 | 1. Build the index with Elasticsearch
 180 | ```bash
 181 | python retriever/bm25/main.py \
 182 |     --retrieval_stage 0
 183 | ```
 184 | 2.
First-stage retrieval
 185 | ```bash
 186 | python retriever/bm25/main.py \
 187 |     --retrieval_stage 1 \
 188 |     --split {cmd_train, cmd_dev, cmd_test}
 189 | ```
 190 | 3. Second-stage retrieval
 191 | ```bash
 192 | python retriever/bm25/main.py \
 193 |     --retrieval_stage 2 \
 194 |     --split {cmd_train, cmd_dev, cmd_test}
 195 | ```
 196 | 
 197 | ---
 198 | ## Generation
 199 | ### FID generation
 200 | The code is based on [FiD](https://github.com/facebookresearch/FiD).
 201 | A training or evaluation file should be converted to a format compatible with FiD.
 202 | An example is [here](./data/conala/example_fid_data.json).
 203 | > **Important note**: FiD has a strong dependency on the version of `transformers` (3.0.2).
 204 | > Failing to match this version may make the results irreproducible.
 205 | 1. Run generation. Here is an example with our [trained model](./models/generator/) on Python CoNaLa:
 206 | ```bash
 207 | ds='conala'
 208 | python generator/fid/test_reader_simple.py \
 209 |     --model_path models/generator/${ds}.fid.codet5.top10/checkpoint/best_dev \
 210 |     --tokenizer_name models/generator/codet5-base \
 211 |     --eval_data data/${ds}/fid.cmd_test.codet5.t10.json \
 212 |     --per_gpu_batch_size 8 \
 213 |     --n_context 10 \
 214 |     --name ${ds}.fid.codet5.top10 \
 215 |     --checkpoint_dir models/generator \
 216 |     --result_tag test_same \
 217 |     --main_port 81692
 218 | ```
 219 | The results will be saved to `models/generator/{name}/test_results_test_same.json`.
 220 | 
 221 | To evaluate `pass@k`, we need more generations, so we use nucleus sampling (instead of beam search).
 222 | ```bash
 223 | ds='conala'
 224 | t=1.0 # sweep over 0.2, 0.4, ..., 1.0 and use the dev set to find the best temperature
 225 | python generator/fid/test_reader_simple.py \
 226 |     --model_path models/generator/${ds}.fid.codet5.top10/checkpoint/best_dev \
 227 |     --tokenizer_name models/generator/codet5-base \
 228 |     --eval_data data/${ds}/fid.cmd_test.codet5.t10.ns200.json \
 229 |     --per_gpu_batch_size 8 \
 230 |     --n_context 10 \
 231 |     --name ${ds}.fid.codet5.top10.ns200 \
 232 |     --checkpoint_dir models/generator \
 233 |     --result_tag test_same \
 234 |     --num_beams 1 \
 235 |     --temperature $t \
 236 |     --top_p 0.95 \
 237 |     --num_return_sequences 200 \
 238 |     --main_port 81692
 239 | ```
 240 | Then run this [script](./dataset_helper/conala/execution_eval.py):
 241 | ```bash
 242 | python dataset_helper/conala/execution_eval.py --result_file data/${ds}/fid.cmd_test.codet5.t10.ns200.json
 243 | ```
 244 | 
 245 | 2.
Train your own generator (initialized from the `codet5-base` model):
 246 | ```bash
 247 | ds='conala'
 248 | python generator/fid/train_reader.py \
 249 |     --seed 1996 \
 250 |     --train_data data/${ds}/fid.cmd_train.codet5.t10.json \
 251 |     --eval_data data/${ds}/fid.cmd_dev.codet5.t10.json \
 252 |     --model_name models/generator/codet5-base \
 253 |     --per_gpu_batch_size 4 \
 254 |     --n_context 10 \
 255 |     --name ${ds}.fid.codet5.top10 \
 256 |     --checkpoint_dir models/generator/ \
 257 |     --eval_freq 500 \
 258 |     --accumulation_steps 2 \
 259 |     --main_port 30843 \
 260 |     --total_steps 20000 \
 261 |     --warmup_steps 2000
 262 | 
 263 | ds='tldr'
 264 | python generator/fid/train_reader.py \
 265 |     --dataset tldr \
 266 |     --train_data data/${ds}/fid.cmd_train.codet5.t10.json \
 267 |     --eval_data data/${ds}/fid.cmd_model_select.codet5.t10.json \
 268 |     --model_name models/generator/codet5-base \
 269 |     --per_gpu_batch_size 4 \
 270 |     --n_context 10 \
 271 |     --eval_metric token_f1 \
 272 |     --name ${ds}.fid.codet5.top10 \
 273 |     --checkpoint_dir models/generator/ \
 274 |     --eval_freq 1000 \
 275 |     --accumulation_steps 2 \
 276 |     --main_port 32420 \
 277 |     --total_steps 20000 \
 278 |     --warmup_steps 2000
 279 | ```
 280 | * Examples in `fid.cmd_model_select.codet5.t10.json` are the same as in `fid.cmd_dev.codet5.t10.json`.
 281 | The difference is that it uses the oracle first-stage retrieval results (the oracle bash command name).
 282 | ---
 283 | ## Data
 284 | The `data` folder contains the two benchmarks we curated or re-split.
 285 | * tldr
 286 | * CoNaLa
 287 | 
 288 | For each dataset, we provide
 289 | 1. Natural language intent (entry `nl`)
 290 | 2. Oracle code (entry `cmd`)
 291 |     * Bash for tldr
 292 |     * Python for CoNaLa
 293 | 3. Oracle docs (entry `oracle_man`)
 294 |     * In the data files, we only provide the manual ids; their contents can be found in `{dataset}/{dataset}_docs.json`.
 295 | 4. Other data, in different formats, for the different modules
 296 | 
 297 | ## Resources
 298 | * The [tldr](https://github.com/tldr-pages/tldr) GitHub repo. Thanks to all the contributors!
 299 | * [CoNaLa](https://conala-corpus.github.io)
 300 | 
 301 | ## Citation
 302 | ```
 303 | @inproceedings{zhou23docprompting,
 304 |     title = {DocPrompting: Generating Code by Retrieving the Docs},
 305 |     author = {Shuyan Zhou and Uri Alon and Frank F.
Xu and Zhiruo Wang and Zhengbao Jiang and Graham Neubig}, 306 | booktitle = {International Conference on Learning Representations (ICLR)}, 307 | address = {Kigali, Rwanda}, 308 | month = {May}, 309 | url = {https://arxiv.org/abs/2207.05987}, 310 | year = {2023} 311 | } 312 | ``` 313 | -------------------------------------------------------------------------------- /generator/fid/utils/convert_data.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import glob 3 | import os 4 | import json 5 | import sys 6 | from pathlib import Path 7 | 8 | 9 | def convert_data(src_file, manual_file, info_file, fid_file, retrieved_manual_list, topk=100, sort_ctx=False): 10 | with open(manual_file, 'r') as f: 11 | manual_d = json.load(f) 12 | for k in list(manual_d.keys()): 13 | manual_d[f'|{"_".join(k.split("_")[:-1])}|'] = "_".join(k.split("_")[:-1]) 14 | manual_d['|placeholder|'] = 'manual' 15 | 16 | if info_file: 17 | with open(info_file, 'r') as f: 18 | manual_info = json.load(f) 19 | manual_info['|placeholder|'] = {'lib_signature': 'manual'} 20 | else: 21 | manual_info = None 22 | 23 | 24 | with open(src_file, 'r') as f: 25 | src_d = json.load(f) 26 | 27 | tgt_d = [] 28 | tot = 0 29 | for src_item in src_d: 30 | tgt_item = {} 31 | tgt_item['id'] = src_item['question_id'] 32 | tgt_item['question'] = src_item['nl'] 33 | tgt_item['target'] = src_item['cmd'] 34 | tgt_item['answers'] = [src_item['cmd']] 35 | ctxs = [] 36 | _ctxs = set() 37 | for cur_retrieved in retrieved_manual_list: 38 | for man in cur_retrieved[src_item['question_id']]['retrieved']: 39 | if manual_info: 40 | title = manual_info[man]['lib_signature'].replace(".", " ") 41 | else: 42 | title = man if man != '|placeholder|' else 'manual' 43 | if man not in manual_d: 44 | text = "" 45 | print(f"[WARNING] {man} cannot be found") 46 | else: 47 | text = manual_d[man] 48 | cur_ctx = {'title': title, 'text': text, 'man_id': man} 49 | if man not in _ctxs: 50 | ctxs.append(cur_ctx) 51 | _ctxs.add(man) 52 | 53 | tgt_item['ctxs'] = ctxs[:topk] 54 | if len(tgt_item['ctxs']) < topk: 55 | for _ in range(topk - len(tgt_item['ctxs'])): 56 | tgt_item['ctxs'].append({'title': '', 'text': '', 'man_id': 'fake'}) 57 | if sort_ctx: 58 | tgt_item['ctxs'] = sorted(tgt_item['ctxs'], key=lambda x: int(x['man_id'].split("_")[-1]) if x['man_id'][-1].isdigit() else 10000) 59 | tot += len(tgt_item['ctxs']) 60 | tgt_d.append(tgt_item) 61 | 62 | 63 | with open(fid_file, 'w+') as f: 64 | json.dump(tgt_d, f, indent=2) 65 | 66 | # with open(str(fid_file).replace(".json", ".10.json"), 'w+') as f: 67 | # json.dump(tgt_d[:10], f, indent=2) 68 | print(f"save {len(tgt_d)} data to {os.path.basename(fid_file)}") 69 | 70 | 71 | def process_manual_list(manual_list): 72 | _manual_list = [] 73 | for l in manual_list: 74 | keys = list(l[0].keys()) 75 | r_key = [x for x in keys if 'retrieved' in x][0] 76 | # s_key = [x for x in keys if 'score' in x][0] 77 | for item in l: 78 | item['retrieved'] = item.pop(r_key) 79 | l = {x['question_id']: x for x in l} 80 | _manual_list.append(l) 81 | return _manual_list 82 | 83 | def run_conala(args): 84 | root = Path('data/conala/nl.cm') 85 | all_splits = [] 86 | if args.have_train: 87 | all_splits.append('cmd_train') 88 | if args.have_dev: 89 | all_splits.append('cmd_dev') 90 | if args.have_test: 91 | all_splits.append('cmd_test') 92 | 93 | # # fake 94 | if args.gen_fake: 95 | for s in all_splits: 96 | src_file = root / f'{s}.seed.json' 97 | manual_file = root / 'manual_all_raw.json' 98 | 
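            # '|placeholder|' maps to the dummy 'manual' entry added in convert_data
            # above, so the "fake" setting is a no-documentation baseline: each
            # example gets a single placeholder context instead of retrieved docs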
info_file = root / 'manual.info.json' 99 | retrieved_manual_list = [] 100 | with open(root / f'{s}.oracle_manual.es0.code.library.full.json', 'r') as f: 101 | d = json.load(f) 102 | d = {x['question_id']: {'retrieved': ['|placeholder|']} for x in d} 103 | retrieved_manual_list.append(d) 104 | fid_file = root / f'fid.{s}.nothing.json' 105 | convert_data(src_file, manual_file, info_file, fid_file, retrieved_manual_list, topk=1) 106 | 107 | if args.gen_mine_fake: 108 | s = 'cmd_mined' 109 | src_file = root / f'{s}.seed.json' 110 | manual_file = root / 'manual_all_raw.json' 111 | info_file = root / 'manual.info.json' 112 | retrieved_manual_list = [] 113 | with open(root / f'{s}.oracle_manual.es0.code.library.full.json', 'r') as f: 114 | d = json.load(f) 115 | d = {x['question_id']: {'retrieved': ['|placeholder|']} for x in d} 116 | retrieved_manual_list.append(d) 117 | fid_file = root / f'fid.{s}.nothing.json' 118 | convert_data(src_file, manual_file, info_file, fid_file, retrieved_manual_list, topk=1) 119 | 120 | with open(fid_file, 'r') as f: 121 | d = json.load(f) 122 | _d = [] 123 | for item in d: 124 | if '_' in item['id']: 125 | if len(item['question'].split()) >= 100 or \ 126 | len(item['target'].split()) >= 100 or \ 127 | len(item['question']) >= 500 or \ 128 | len(item['target']) >= 500: 129 | continue 130 | 131 | _d.append(item) 132 | 133 | print(len(d), len(_d)) 134 | with open(fid_file, 'w+') as f: 135 | json.dump(_d, f, indent=2) 136 | 137 | if args.gen_oracle: 138 | for s in all_splits: 139 | src_file = root / f'{s}.seed.json' 140 | manual_file = root / 'manual_all_raw.json' 141 | info_file = root / 'manual.info.json' 142 | retrieved_manual_list = [] 143 | with open(root / f'{s}.oracle_manual.es0.code.library.full.json', 'r') as f: 144 | retrieved_manual_list.append(json.load(f)) 145 | fid_file = root / f'fid.{s}.oracle.json' 146 | retrieved_manual_list = process_manual_list(retrieved_manual_list) 147 | convert_data(src_file, manual_file, info_file, fid_file, retrieved_manual_list, topk=3) 148 | 149 | if args.gen_retrieval: 150 | for s in all_splits: 151 | for topk in [15, 20, 25, 30][:]: 152 | src_file = root / f'{s}.seed.json' 153 | manual_file = root / 'manual_all_raw.json' 154 | info_file = root / 'manual.info.json' 155 | retrieved_manual_list = [] 156 | with open(root / f'{s}.{args.retrieval_file_tag}.json', 'r') as f: 157 | retrieved_manual_list.append(json.load(f)) 158 | fid_file = root / f'fid.{s}.{args.retrieval_file_tag}.t{topk}.json' 159 | retrieved_manual_list = process_manual_list(retrieved_manual_list) 160 | convert_data(src_file, manual_file, info_file, fid_file, retrieved_manual_list, topk=topk) 161 | 162 | if args.gen_oracle_retrieval: 163 | for s in all_splits: 164 | for topk in [1, 3, 5, 10][:]: 165 | src_file = root / f'{s}.seed.json' 166 | manual_file = root / 'manual_all_raw.json' 167 | info_file = root / 'manual.info.json' 168 | retrieved_manual_list = [] 169 | with open(root / f'{s}.oracle_manual.es0.code.library.full.json', 'r') as f: 170 | retrieved_manual_list.append(json.load(f)) 171 | with open(root / f'{s}.{args.retrieval_file_tag}.json', 'r') as f: 172 | retrieved_manual_list.append(json.load(f)) 173 | fid_file = root / f'fid.{s}.oracle.{args.retrieval_file_tag}.t{topk}.json' 174 | retrieved_manual_list = process_manual_list(retrieved_manual_list) 175 | convert_data(src_file, manual_file, info_file, fid_file, retrieved_manual_list, topk=topk) 176 | 177 | def run_tldr(args): 178 | root = Path('data/tldr/nl.cm') 179 | all_splits = [] 180 | if 
181 |         all_splits.append('cmd_train')
182 |     if args.have_dev:
183 |         all_splits.append('cmd_dev')
184 |     if args.have_test:
185 |         all_splits.append('cmd_test')
186 | 
187 |     manual_file = root / 'manual_section.json'
188 |     # fake: single placeholder context per example
189 |     if args.gen_fake:
190 |         for s in all_splits:
191 |             src_file = root / f'{s}.seed.json'
192 |             retrieved_manual_list = []
193 |             with open(root / f'{s}.oracle_manual.es1.full.oracle.json', 'r') as f:
194 |                 d = json.load(f)
195 |                 d = {x['question_id']: {'retrieved': ['|placeholder|']} for x in d}
196 |             retrieved_manual_list.append(d)
197 |             fid_file = root / f'fid.{s}.nothing.json'
198 |             convert_data(src_file, manual_file, None, fid_file, retrieved_manual_list, topk=1, sort_ctx=True)
199 | 
200 |     if args.gen_oracle:
201 |         for s in all_splits:
202 |             src_file = root / f'{s}.seed.json'
203 |             retrieved_manual_list = []
204 |             with open(root / f'{s}.oracle_manual.es1.full.oracle.json', 'r') as f:
205 |                 retrieved_manual_list.append(json.load(f))
206 |             fid_file = root / f'fid.{s}.oracle.json'
207 |             retrieved_manual_list = process_manual_list(retrieved_manual_list)
208 |             convert_data(src_file, manual_file, None, fid_file, retrieved_manual_list, topk=10, sort_ctx=True)
209 | 
210 |     if args.gen_oracle_cmd:
211 |         assert 'tldr' in str(root)
212 |         for s in all_splits:
213 |             src_file = root / f'{s}.seed.json'
214 |             retrieved_manual_list = []
215 |             with open(root / f'{s}.oracle_manual.es1.full.oracle.json', 'r') as f:
216 |                 d = json.load(f)
217 |                 d = {x['question_id']: {'retrieved': [f'|{x["cmd_name"]}|']} for x in d}
218 |             retrieved_manual_list.append(d)
219 |             fid_file = root / f'fid.{s}.oracle_cmd.json'
220 |             convert_data(src_file, manual_file, None, fid_file, retrieved_manual_list, topk=1, sort_ctx=True)
221 | 
222 |     if args.gen_retrieval:
223 |         for s in all_splits:
224 |             for topk in [5]:
225 |                 for x in range(30):
226 |                     if not (root / f'{s}.{args.retrieval_file_tag}.{x}.json').exists():
227 |                         break
228 |                     src_file = root / f'{s}.seed.json'
229 |                     retrieved_manual_list = []
230 |                     with open(root / f'{s}.{args.retrieval_file_tag}.{x}.json', 'r') as f:
231 |                         retrieved_manual_list.append(json.load(f))
232 |                     fid_file = root / f'fid.{s}.{args.retrieval_file_tag}.t{topk}.{x}.json'
233 |                     retrieved_manual_list = process_manual_list(retrieved_manual_list)
234 |                     convert_data(src_file, manual_file, None, fid_file, retrieved_manual_list, topk=topk, sort_ctx=True)
235 | 
236 |                 # merge data
237 |                 all_data = []
238 |                 for x in range(30):
239 |                     fid_file = root / f'fid.{s}.{args.retrieval_file_tag}.t{topk}.{x}.json'
240 |                     if not fid_file.exists():
241 |                         break
242 |                     with open(fid_file, 'r') as f:
243 |                         curr_data = json.load(f)
244 |                     all_data.extend(curr_data)
245 |                     # os.remove(fid_file)
246 | 
247 |                 print(f"merged: {len(all_data)}")
248 |                 with open(root / f'fid.{s}.{args.retrieval_file_tag}.t{topk}.json', 'w') as f:
249 |                     json.dump(all_data, f, indent=2)
250 | 
251 | 
252 |     if args.gen_oracle_retrieval:
253 |         for s in all_splits:
254 |             for topk in [10, 15]:
255 |                 src_file = root / f'{s}.seed.json'
256 |                 retrieved_manual_list = []
257 |                 with open(root / f'{s}.oracle_manual.es1.full.oracle.json', 'r') as f:
258 |                     retrieved_manual_list.append(json.load(f))
259 |                 with open(root / f'{s}.{args.retrieval_file_tag}.0.json', 'r') as f:
260 |                     retrieved_manual_list.append(json.load(f))
261 |                 fid_file = root / f'fid.{s}.oracle.{args.retrieval_file_tag}.t{topk}.json'
262 |                 retrieved_manual_list = process_manual_list(retrieved_manual_list)
263 |                 convert_data(src_file, manual_file, None, fid_file, retrieved_manual_list, topk=topk, sort_ctx=True)
264 | 
265 | 
266 | def config():
267 |     parser = argparse.ArgumentParser()
268 |     parser.add_argument('--gen_retrieval', action='store_true')
269 |     parser.add_argument('--gen_oracle', action='store_true')
270 |     parser.add_argument('--gen_oracle_cmd', action='store_true')
271 |     parser.add_argument('--gen_fake', action='store_true')
272 |     parser.add_argument('--gen_mine_fake', action='store_true')
273 |     parser.add_argument('--gen_oracle_retrieval', action='store_true')
274 |     parser.add_argument('--retrieval_file_tag')
275 |     parser.add_argument('--have_train', action='store_true')
276 |     parser.add_argument('--have_dev', action='store_true')
277 |     parser.add_argument('--have_test', action='store_true')
278 |     parser.add_argument('--conala', action='store_true')
279 |     parser.add_argument('--tldr', action='store_true')
280 | 
281 |     args = parser.parse_args()
282 |     return args
283 | 
284 | if __name__ == "__main__":
285 |     args = config()
286 |     if args.conala:
287 |         run_conala(args)
288 |         data = 'conala'
289 |     elif args.tldr:
290 |         run_tldr(args)
291 |         data = 'tldr'
292 |     else:
293 |         raise NotImplementedError
294 | 
--------------------------------------------------------------------------------
/generator/fid/train_reader.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | import json
7 | import os
8 | import time
9 | import sys
10 | import torch
11 | import transformers
12 | import numpy as np
13 | from pathlib import Path
14 | from torch.utils.data import DataLoader, RandomSampler, DistributedSampler, SequentialSampler
15 | from tqdm import tqdm
16 | 
17 | from dataset_helper.conala.gen_metric import _bleu as conala_bleu
18 | from dataset_helper.tldr.gen_metric import tldr_metrics
19 | from src.options import Options
20 | 
21 | import src.slurm
22 | import src.util
23 | import src.evaluation
24 | import src.data
25 | import src.model
26 | import wandb
27 | 
28 | WANDB_DISABLED = os.environ.get('WANDB_DISABLED', '') == '1'  # set to '1' to disable, matching utils/constants.py
29 | TQDM_DISABLED = os.environ.get('TQDM_DISABLED', '') == '1'
30 | 
31 | def train(model, optimizer, scheduler, step, train_dataset, eval_dataset, opt, collator, best_dev_em, checkpoint_path):
32 | 
33 |     # if opt.is_main:
34 |     #     try:
35 |     #         tb_logger = torch.utils.tensorboard.SummaryWriter(Path(opt.checkpoint_dir)/opt.name)
36 |     #     except:
37 |     #         tb_logger = None
38 |     #         logger.warning('Tensorboard is not available.')
39 | 
40 |     torch.manual_seed(opt.global_rank + opt.seed)  # different seed for different sampling depending on global_rank
41 |     train_sampler = RandomSampler(train_dataset)
42 |     train_dataloader = DataLoader(
43 |         train_dataset,
44 |         sampler=train_sampler,
45 |         batch_size=opt.per_gpu_batch_size,
46 |         drop_last=True,
47 |         num_workers=10,
48 |         collate_fn=collator
49 |     )
50 | 
51 |     loss, curr_loss = 0.0, 0.0
52 |     epoch = 0  # incremented at the start of each pass below
53 |     model.train()
54 |     while step < opt.total_steps:
55 |         epoch += 1
56 |         opt._epoch = epoch
57 |         for i, batch in enumerate(tqdm(train_dataloader, disable=TQDM_DISABLED)):
58 |             step += 1
59 |             opt._train_step = step
60 |             (idx, labels, _, context_ids, context_mask) = batch
61 | 
62 |             train_loss = model(
63 |                 input_ids=context_ids.cuda(),
64 |                 attention_mask=context_mask.cuda(),
65 |                 labels=labels.cuda()
66 |             )[0]
67 | 
68 |             train_loss.backward()
69 | 
70 |             if step % opt.accumulation_steps == 0:
71 |                 torch.nn.utils.clip_grad_norm_(model.parameters(), opt.clip)
72 |                 optimizer.step()
73 |                 scheduler.step()
74 |                 model.zero_grad()
75 | 
76 |             train_loss = src.util.average_main(train_loss, opt)
77 |             curr_loss += train_loss.item()
78 |             wandb.log({'train_loss': train_loss.item(), 'lr': scheduler.get_last_lr()[0]})
79 | 
80 |             if step % opt.eval_freq == 0:
81 |                 dev_em = evaluate(model, eval_dataset, tokenizer, collator, opt)  # tokenizer is defined at module scope in __main__
82 |                 wandb.log({f'eval_{x}': y for x, y in dev_em.items()})
83 |                 dev_em = dev_em[opt.eval_metric]
84 |                 model.train()
85 |                 if opt.is_main:
86 |                     if dev_em > best_dev_em:
87 |                         best_dev_em = dev_em
88 |                         src.util.save(model, optimizer, scheduler, step, best_dev_em,
89 |                                       opt, checkpoint_path, 'best_dev')
90 |                     log = f"{step} / {opt.total_steps} |"
91 |                     log += f" train: {curr_loss/opt.eval_freq:.3f} |"
92 |                     log += f" evaluation {opt.eval_metric}: {dev_em:.02f} |"
93 |                     log += f" lr: {scheduler.get_last_lr()[0]:.5f}"
94 |                     logger.info(log)
95 |                     # if tb_logger is not None:
96 |                     #     tb_logger.add_scalar("Evaluation", dev_em, step)
97 |                     #     tb_logger.add_scalar("Training", curr_loss / (opt.eval_freq), step)
98 |                     curr_loss = 0.
99 | 
100 | 
101 |             if opt.is_main and step % opt.save_freq == 0:
102 |                 src.util.save(model, optimizer, scheduler, step, best_dev_em,
103 |                               opt, checkpoint_path, f"step-{step}")
104 | 
105 |             if step >= opt.total_steps:
106 |                 break
107 | 
108 | def evaluate_em(model, dataset, tokenizer, collator, opt):
109 |     sampler = SequentialSampler(dataset)
110 |     dataloader = DataLoader(dataset,
111 |                             sampler=sampler,
112 |                             batch_size=opt.per_gpu_batch_size,
113 |                             drop_last=False,
114 |                             num_workers=10,
115 |                             collate_fn=collator
116 |                             )
117 |     model.eval()
118 |     total = 0
119 |     exactmatch = []
120 |     model = model.module if hasattr(model, "module") else model
121 |     with torch.no_grad():
122 |         for i, batch in enumerate(dataloader):
123 |             (idx, _, _, context_ids, context_mask) = batch
124 | 
125 |             outputs = model.generate(
126 |                 input_ids=context_ids.cuda(),
127 |                 attention_mask=context_mask.cuda(),
128 |                 max_length=50
129 |             )
130 | 
131 |             for k, o in enumerate(outputs):
132 |                 ans = tokenizer.decode(o, skip_special_tokens=True)
133 |                 gold = dataset.get_example(idx[k])['answers']
134 |                 score = src.evaluation.ems(ans, gold)
135 |                 total += 1
136 |                 exactmatch.append(score)
137 | 
138 |     exactmatch, total = src.util.weighted_average(np.mean(exactmatch), total, opt)
139 |     return exactmatch
140 | 
141 | def evaluate_customized(model, dataset, tokenizer, collator, opt, result_file=None, is_bleu=False, is_token_f1=False):
142 |     assert is_bleu != is_token_f1
143 |     sampler = SequentialSampler(dataset)
144 |     dataloader = DataLoader(dataset,
145 |                             sampler=sampler,
146 |                             batch_size=opt.per_gpu_batch_size,
147 |                             drop_last=False,
148 |                             num_workers=10,
149 |                             collate_fn=collator
150 |                             )
151 |     model.eval()
152 |     model = model.module if hasattr(model, "module") else model
153 | 
154 |     with torch.no_grad():
155 | 
156 |         result_file = f"{opt.checkpoint_path}/dev_result_{opt._train_step}.json" if result_file is None else result_file
157 |         result_d = []
158 | 
159 |         with open(f"{opt.checkpoint_path}/gold.gold", "w+") as fg, open(f'{opt.checkpoint_path}/pred.pred', 'w+') as fp, open(result_file, 'w+') as fr:
160 | 
161 |             for i, batch in enumerate(tqdm(dataloader, disable=TQDM_DISABLED)):
162 |                 (idx, _, _, context_ids, context_mask) = batch
163 | 
164 |                 outputs = model.generate(
165 |                     input_ids=context_ids.cuda(),
166 |                     attention_mask=context_mask.cuda(),
167 |                     max_length=150,
168 |                 )
169 | 
170 |                 for k, o in enumerate(outputs):
171 |                     ans = tokenizer.decode(o, skip_special_tokens=False, clean_up_tokenization_spaces=False)
172 |                     gold = dataset.get_example(idx[k])['target']
173 |                     ans = ans.replace("{{", " {{").replace("\n", ' ').replace("\r", "").replace("<pad>", "").replace("</s>", "").replace("<s>", "").strip()
174 |                     ans = " ".join(ans.split())
175 |                     gold = gold.replace("\n", ' ')
176 |                     fg.write(f"{gold}\n")
177 |                     fp.write(f"{ans}\n")
178 |                     cur_result = {'question_id': dataset.get_example(idx[k])['id'], 'gold': gold, 'clean_code': ans}
179 |                     result_d.append(cur_result)
180 | 
181 |             json.dump(result_d, fr, indent=2)
182 | 
183 | 
184 | 
185 |     if is_bleu:
186 |         score = conala_bleu(
187 |             f"{opt.checkpoint_path}/gold.gold",
188 |             f"{opt.checkpoint_path}/pred.pred",
189 |             smooth=False, code_tokenize=True)
190 |         score = {'bleu': score}
191 |     elif is_token_f1:
192 |         score = tldr_metrics(
193 |             f"{opt.checkpoint_path}/gold.gold",
194 |             f"{opt.checkpoint_path}/pred.pred")
195 | 
196 |     else:
197 |         raise NotImplementedError
198 | 
199 |     return score
200 | 
201 | def evaluate(model, dataset, tokenizer, collator, opt):
202 |     if opt.eval_metric == 'exact_match':
203 |         x = evaluate_em(model, dataset, tokenizer, collator, opt)
204 |         metric = {'exact_match': x}
205 |     elif opt.eval_metric == 'bleu':
206 |         metric = evaluate_customized(model, dataset, tokenizer, collator, opt, is_bleu=True)
207 |     elif opt.eval_metric == 'token_f1':
208 |         metric = evaluate_customized(model, dataset, tokenizer, collator, opt, is_token_f1=True)
209 |         metric['token_f1'] = metric.pop('f1')
210 |     else:
211 |         raise NotImplementedError(f'{opt.eval_metric} has not been implemented yet')
212 |     print(json.dumps(metric, indent=2))
213 |     return metric
214 | 
215 | if __name__ == "__main__":
216 |     options = Options()
217 |     options.add_reader_options()
218 |     options.add_optim_options()
219 |     opt = options.parse()
220 | 
221 |     # opt = options.get_options(use_reader=True, use_optim=True)
222 | 
223 |     torch.manual_seed(opt.seed)
224 |     src.slurm.init_distributed_mode(opt)
225 |     src.slurm.init_signal_handler()
226 | 
227 |     checkpoint_path = Path(opt.checkpoint_dir)/opt.name
228 |     checkpoint_exists = checkpoint_path.exists()
229 |     if opt.is_distributed:
230 |         torch.distributed.barrier()
231 |     checkpoint_path.mkdir(parents=True, exist_ok=True)
232 |     opt.checkpoint_path = checkpoint_path
233 |     # if not checkpoint_exists and opt.is_main:
234 |     #     options.print_options(opt)
235 |     # checkpoint_path, checkpoint_exists = util.get_checkpoint_path(opt)
236 | 
237 |     logger = src.util.init_logger(
238 |         opt.is_main,
239 |         opt.is_distributed,
240 |         checkpoint_path / 'run.log'
241 |     )
242 | 
243 |     logger.info(f"device type: {torch.cuda.get_device_name(0)}, memory: {torch.cuda.get_device_properties(0).total_memory / 1024 / 1024 / 1024}G")
244 | 
245 |     # logger.info(json.dumps(vars(opt), indent=2))
246 | 
247 |     if WANDB_DISABLED:
248 |         wandb.init(mode='disabled')
249 |     else:
250 |         if opt.is_main:
251 |             # is the master
252 |             wandb.init(project='fid')
253 |             wandb.config.update(opt)
254 |         else:
255 |             wandb.init(mode='disabled')
256 | 
257 |     # model_name = 't5-' + opt.model_size
258 |     model_name = opt.model_name
259 |     model_class = src.model.FiDT5
260 | 
261 |     # load data
262 |     if 'codet5' in model_name or 'code_t5' in model_name:
263 |         logger.info(f'load the tokenizer from codet5')
264 |         tokenizer = transformers.RobertaTokenizer.from_pretrained(model_name)
265 |     else:
266 |         logger.info(f'load the tokenizer from t5')
267 |         tokenizer = transformers.T5Tokenizer.from_pretrained(model_name)
268 | 
269 |     if opt.dataset == 'tldr':
270 |         special_tokens_dict = {'additional_special_tokens': ['{{', '}}']}
271 |         num_added_toks = tokenizer.add_special_tokens(special_tokens_dict)
272 | 
273 |     collator = src.data.Collator(opt.text_maxlength, tokenizer,
274 |                                  answer_maxlength=opt.answer_maxlength)
275 | 
276 |     # use global rank and world size to split the training set on multiple gpus
277 |     train_examples = src.data.load_data(
278 |         opt.train_data,
279 |         global_rank=opt.global_rank,
280 |         world_size=opt.world_size,
281 |     )
282 |     train_dataset = src.data.Dataset(train_examples, opt.n_context)
283 |     # use global rank and world size to split the eval set on multiple gpus
284 |     eval_examples = src.data.load_data(
285 |         opt.eval_data,
286 |         global_rank=opt.global_rank,
287 |         world_size=opt.world_size,
288 |     )
289 |     eval_dataset = src.data.Dataset(eval_examples, opt.n_context)
290 | 
291 |     if not opt.continue_from_checkpoint:
292 |         logger.info("init a model from T5")
293 |         t5 = transformers.T5ForConditionalGeneration.from_pretrained(model_name)
294 |         t5.resize_token_embeddings(len(tokenizer))
295 | 
296 |         if opt.encoder_weights is not None:
297 |             state_dict = torch.load(f'{opt.encoder_weights}/pytorch_model.bin')
298 |             # rename
299 |             new_state_dict = {}
300 |             for k, v in state_dict.items():
301 |                 if k.startswith('_model.encoder.'):
302 |                     k = k.replace('_model.encoder.', '')
303 |                     new_state_dict[k] = v
304 |             load_model_keys = list(new_state_dict.keys())
305 |             model_keys = list(t5.encoder.state_dict().keys())
306 |             ignored = []
307 |             missed = []
308 |             for k in load_model_keys:
309 |                 if k not in model_keys:
310 |                     ignored.append(k)
311 |             for k in model_keys:
312 |                 if k not in load_model_keys:
313 |                     missed.append(k)
314 | 
315 |             logger.info(f'Some weights in the checkpoint are not used when initializing the encoder : {ignored}')
316 |             logger.info(f'Some weights in the encoder were not initialized from the checkpoint : {missed}')
317 |             t5.encoder.load_state_dict(new_state_dict, strict=False)
318 |             logger.info(f'Loaded encoder weights from {opt.encoder_weights}')
319 | 
320 |         model = src.model.FiDT5(t5.config)
321 |         model.load_t5(t5.state_dict())
322 |         model = model.to(opt.local_rank)
323 |         optimizer, scheduler = src.util.set_optim(opt, model)
324 |         step, best_dev_em = 0, 0.0
325 |     elif opt.model_path == "none" and opt.continue_from_checkpoint:
326 |         load_path = checkpoint_path / 'checkpoint' / 'latest'
327 |         model, optimizer, scheduler, opt_checkpoint, step, best_dev_em = \
328 |             src.util.load(model_class, load_path, opt, reset_params=False)
329 |         logger.info(f"Model loaded from checkpoint {load_path}")
330 |     else:  # load from model path
331 |         model, optimizer, scheduler, opt_checkpoint, step, best_dev_em = \
332 |             src.util.load(model_class, opt.model_path, opt, reset_params=True)
333 |         logger.info(f"Model loaded from a model {opt.model_path}")
334 | 
335 |     model.set_checkpoint(opt.use_checkpoint)
336 | 
337 |     if opt.is_distributed:
338 |         model = torch.nn.parallel.DistributedDataParallel(
339 |             model,
340 |             device_ids=[opt.local_rank],
341 |             output_device=opt.local_rank,
342 |             find_unused_parameters=False,
343 |         )
344 | 
345 |     logger.info("Start training")
346 |     train(
347 |         model,
348 |         optimizer,
349 |         scheduler,
350 |         step,
351 |         train_dataset,
352 |         eval_dataset,
353 |         opt,
354 |         collator,
355 |         best_dev_em,
356 |         checkpoint_path
357 |     )
358 | 
--------------------------------------------------------------------------------
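For orientation, the records that `utils/convert_data.py` writes (and that `train_reader.py` loads through `src.data.load_data`) follow the FiD reader format: a JSON list in which each example carries `id`, `question`, `target`, `answers`, and a `ctxs` list of `{title, text, man_id}` documents, deduplicated by `man_id` and padded with empty `'fake'` entries up to exactly `topk`. Below is a minimal sketch of one such record; the field values and the output filename are illustrative, not artifacts shipped with the repository.

```python
import json

# One FiD-style example in the shape produced by convert_data();
# all values here are made up for illustration.
record = {
    "id": "q-0",
    "question": "compress a directory into a tar.gz archive",  # src_item['nl']
    "target": "tar czf {{archive.tar.gz}} {{dir}}",            # src_item['cmd']
    "answers": ["tar czf {{archive.tar.gz}} {{dir}}"],
    "ctxs": [
        # retrieved manual sections, deduplicated by man_id
        {"title": "tar", "text": "tar - an archiving utility ...", "man_id": "tar_0"},
        # padding appended when fewer than topk contexts were retrieved
        {"title": "", "text": "", "man_id": "fake"},
    ],
}

# FiD readers expect a JSON list of such records.
with open("fid.example.json", "w") as f:
    json.dump([record], f, indent=2)
```

Conversion is driven by the flags defined in `config()`; for instance, `python generator/fid/utils/convert_data.py --tldr --have_test --gen_oracle` would rebuild `fid.cmd_test.oracle.json` under `data/tldr/nl.cm/`, assuming the seed, manual, and oracle-retrieval files referenced above are in place.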