├── .gitignore
├── README.md
├── codeon.gif
├── data
│   └── README.md
├── demo1.gif
├── demo2.gif
├── gen
│   ├── MainDataset.py
│   └── run.py
├── jsdemo.gif
├── search
│   ├── code
│   │   ├── add_search.py
│   │   ├── make_index.py
│   │   ├── model.py
│   │   └── run.py
│   └── scripts
│       └── train.sh
├── search_demo_js.gif
└── search_demo_py.gif

/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Codeon
2 |
3 | Code generation and code search for Python and JavaScript.
4 |
5 | Similar to [GitHub Copilot](https://copilot.github.com/), with one major difference: code search is leveraged to make up for smaller models and less data. The generation model uses the search results along with code context to generate contextual code. Moreover, code search results from all over GitHub are displayed in a separate file along with the exact URL of the source code.
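For a concrete picture of what "search results along with code context" means, the training data in the 'gen' directory assembles each example into a single prompt roughly like the one below. This is an illustrative sketch based on gen/MainDataset.py; the placeholder text and spacing are not literal.

```text
Context:
<code surrounding the query in the current file, or None>
Query:
<the natural-language query>
Examples from search:
<up to two snippets retrieved by the search model, or None>
Generate Code:
<the target code produced by the model>
```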
6 |
7 |
8 | ### VSCode extension:
9 | https://marketplace.visualstudio.com/items?itemName=samipdahal.codeon
10 |
11 |
12 | ### Code Generation:
13 |
14 | Currently, the [GPT-Neo-125M model](https://huggingface.co/EleutherAI/gpt-neo-125M) is used for generation.
15 |
16 | Training code is under the 'gen' directory.
17 |
18 | ![Demo Video Dark](demo2.gif)
19 |
20 |
21 | ![Demo Video Light](demo1.gif)
22 |
23 |
24 | ### Code Search:
25 |
26 | The [CodeBERT-base model](https://huggingface.co/microsoft/codebert-base) is used for code search, together with the approximate nearest neighbor library [ScaNN](https://github.com/google-research/google-research/tree/master/scann). The [CodeSearchNet](https://github.com/github/CodeSearchNet) dataset is used as the codebase to search over.
27 |
28 | Code to fine-tune the model and set up ScaNN is under the 'search' directory.
29 |
30 | ![Demo Video Js](search_demo_js.gif)
31 |
32 |
33 | ![Demo Video Py](search_demo_py.gif)
34 |
35 | ## Usage:
36 | ### Python:
37 |
38 | `#YOUR_QUERY.`
39 |
40 | ### JavaScript:
41 |
42 | `//YOUR_QUERY.`
43 |
44 | (Note the dot ‘.’ at the end.)
45 |
46 | ## Example:
47 | ### Python:
48 |
49 | ```python
50 | # concat two dicts.
51 | ```
52 |
53 | ### JavaScript:
54 |
55 | ```javascript
56 | // merge two arrays.
57 | ```
58 |
59 | ## Notes:
60 |
61 | 1. The extension currently supports only Python and JavaScript, and won't run on files that don't end in '.py' or '.js'.
62 |
63 | ## Requirements:
64 | VSCode 1.59.0 or later.
65 |
66 | ## Feedback/Contact:
67 |
68 | If you spot any mistakes or possible improvements, please feel free to let me know; contributions are welcome!
69 |
70 | [Form](https://forms.gle/urfKTGLcLrSnEdLG9) or sdpmas@live.unc.edu
71 |
72 | ### Some of the code is adapted from the following repositories:
73 | 1. [CodeSearchNet](https://github.com/github/CodeSearchNet)
74 | 2. [CodeXGLUE](https://github.com/microsoft/CodeXGLUE)
75 | 3. [APPS](https://github.com/hendrycks/apps)
76 |
77 |
78 | -----------------------------------------------------------------------------------------------------------
79 |
80 |
--------------------------------------------------------------------------------
/codeon.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sdpmas/Codeon/917beea6ef25e9dd213b43620c28859892e927e8/codeon.gif
--------------------------------------------------------------------------------
/data/README.md:
--------------------------------------------------------------------------------
1 | ## Data
2 | Code to scrape data from GitHub.
3 |
4 | (TODO: will be uploaded soon.)
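Until that code is uploaded, the sketch below gives a rough idea of the kind of scraping involved. It is illustrative only, not the project's actual pipeline, and it assumes a GitHub personal access token for the REST code-search endpoint.

```python
# Illustrative sketch only: query GitHub's code search API and record where each hit lives.
import requests

def search_github_code(query, token, per_page=10):
    resp = requests.get(
        'https://api.github.com/search/code',
        params={'q': query, 'per_page': per_page},
        headers={'Authorization': f'token {token}',
                 'Accept': 'application/vnd.github+json'},
    )
    resp.raise_for_status()
    # Each item records the repository, the file path, and a link back to the source.
    return [(item['repository']['full_name'], item['path'], item['html_url'])
            for item in resp.json()['items']]
```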
-------------------------------------------------------------------------------- /demo1.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sdpmas/Codeon/917beea6ef25e9dd213b43620c28859892e927e8/demo1.gif -------------------------------------------------------------------------------- /demo2.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sdpmas/Codeon/917beea6ef25e9dd213b43620c28859892e927e8/demo2.gif -------------------------------------------------------------------------------- /gen/MainDataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import sys 4 | import inspect 5 | currentdir = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe()))) 6 | parentdir = os.path.dirname(currentdir) 7 | sys.path.insert(0, parentdir) 8 | import numpy as np 9 | import os 10 | import io, pickle 11 | from search.code.add_search import ExCode 12 | import transformers 13 | import io, tokenize 14 | from tqdm import tqdm 15 | from docstring_parser import parse 16 | import json 17 | 18 | #adapted from: https://stackoverflow.com/a/62074206 19 | def remove_comments_and_docstrings(source): 20 | io_obj = io.StringIO(source) 21 | out = "" 22 | prev_toktype = tokenize.INDENT 23 | last_lineno = -1 24 | last_col = 0 25 | for tok in tokenize.generate_tokens(io_obj.readline): 26 | token_type = tok[0] 27 | token_string = tok[1] 28 | start_line, start_col = tok[2] 29 | end_line, end_col = tok[3] 30 | ltext = tok[4] 31 | if start_line > last_lineno: 32 | last_col = 0 33 | if start_col > last_col: 34 | out += (" " * (start_col - last_col)) 35 | if token_type == tokenize.COMMENT: 36 | pass 37 | elif token_type == tokenize.STRING: 38 | if prev_toktype != tokenize.INDENT: 39 | if prev_toktype != tokenize.NEWLINE: 40 | if start_col > 0: 41 | out += token_string 42 | else: 43 | out += token_string 44 | prev_toktype = token_type 45 | last_col = end_col 46 | last_lineno = end_line 47 | out = '\n'.join(l for l in out.splitlines() if l.strip()) 48 | return out 49 | 50 | class Example(object): 51 | def __init__(self, 52 | idx, 53 | nl, 54 | code, 55 | source, 56 | context 57 | ): 58 | self.idx = idx 59 | self.nl = nl 60 | self.code = code 61 | self.source=source 62 | self.context=context 63 | 64 | 65 | 66 | class MainDataset(torch.utils.data.Dataset): 67 | def __init__(self, max_input_len,max_target_len,max_context_len,mode=None): 68 | self.max_input_len=max_input_len 69 | self.max_target_len=max_target_len 70 | self.max_context_len=max_context_len 71 | self.tokenizer = transformers.GPT2Tokenizer.from_pretrained('EleutherAI/gpt-neo-125M') 72 | self.mode=mode 73 | self.inp_labels=[] 74 | self.initialize() 75 | 76 | 77 | def read_examples(self,filename,source): 78 | examples=[] 79 | f=pickle.load(open(filename,'rb')) 80 | for i, line in enumerate(f): 81 | js=line 82 | if 'idx' not in js or 'source' not in js: 83 | print("doesn't have any idx") 84 | exit() 85 | nl=js['doc'].replace('\n','') 86 | code=js['body'] 87 | context=js['context'] 88 | examples.append( 89 | Example( 90 | idx = js['idx'], 91 | nl=nl, 92 | code = code, 93 | source=source,context=context 94 | ) 95 | ) 96 | return examples 97 | def initialize(self): 98 | context_examples=self.read_examples(f'dataset/context_js/{self.mode}.bin',source='context') 99 | examples=context_examples 100 | np.random.seed(69) 101 | np.random.shuffle(examples) 102 | 
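# Note (added for clarity): the two flag arrays below decide, per example, whether the
# retrieved search results and the surrounding context are kept in the training prompt
# (roughly 55% and 60% of the time, respectively), so the model also learns to generate
# without them; for the validation split both are always kept (p=[0, 1]).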
if self.mode=='valid': 103 | self.choose_rand_ex=np.random.choice([0,1],len(examples),p=[0,1]) 104 | self.choose_rand_context=np.random.choice([0,1],len(examples),p=[0,1]) 105 | else: 106 | self.choose_rand_ex=np.random.choice([0,1],len(examples),p=[0.45,0.55]) 107 | self.choose_rand_context=np.random.choice([0,1],len(examples),p=[0.40,0.60]) 108 | 109 | context_excodes=pickle.load(open(f'data/{self.mode}_context_excodes.bin','rb')) 110 | inp_labels=[] 111 | 112 | for i,ex in enumerate(examples): 113 | query,code,idx,src,context=ex.nl, ex.code,ex.idx,ex.source,ex.context 114 | 115 | excode=context_excodes[idx] 116 | if not context: 117 | context_tokens='None\n' 118 | else: 119 | context_tokens=context 120 | context_tokens+='\n' 121 | 122 | query_tokens="Query:\n"+query+'\n' 123 | query_ids=self.tokenizer.encode(query_tokens) 124 | 125 | if self.choose_rand_context[i]: 126 | context_ids=self.tokenizer.encode('Context:\n')+self.tokenizer.encode(context_tokens)[-self.max_context_len:] 127 | else: 128 | context_ids=self.tokenizer.encode('Context:\n')+self.tokenizer.encode('None\n')[-self.max_context_len:] 129 | excode_tokens='Examples from search:\n' 130 | if self.choose_rand_ex[i]: 131 | for i_e_code,e_code in enumerate(excode.codes): 132 | 133 | if i_e_code>1:break 134 | excode_tokens+=e_code 135 | excode_tokens+='\n' 136 | else: 137 | excode_tokens+='None\n' 138 | 139 | excode_ids=self.tokenizer.encode(excode_tokens) 140 | total_query_ids=context_ids+query_ids+excode_ids 141 | add_newline=False 142 | if len(total_query_ids)>self.max_input_len:add_newline=True 143 | total_query_ids=total_query_ids[:self.max_input_len] 144 | if not add_newline: 145 | code_prompt='Generate Code:\n'+code 146 | else: 147 | code_prompt='\nGenerate Code:\n'+code 148 | code_ids=self.tokenizer.encode(code_prompt,verbose=False)[:self.max_target_len-1] 149 | code_ids.append(self.tokenizer.eos_token_id) 150 | input_ids=total_query_ids+code_ids 151 | labels=[-100]*len(total_query_ids)+code_ids 152 | 153 | #add paddings 154 | padding_length=self.max_input_len+self.max_target_len-len(input_ids) 155 | input_ids+=[self.tokenizer.eos_token_id]*padding_length 156 | labels+=[-100]*padding_length 157 | inp_labels.append({ 158 | 'input_ids':torch.LongTensor(input_ids), 159 | 'labels':torch.LongTensor(labels) 160 | }) 161 | 162 | self.inp_labels=inp_labels 163 | 164 | def __len__(self): 165 | return len(self.inp_labels) 166 | 167 | 168 | def __getitem__(self, idx): 169 | return self.inp_labels[idx] 170 | 171 | -------------------------------------------------------------------------------- /gen/run.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import math 3 | import os 4 | import pprint 5 | import sys 6 | import transformers 7 | 8 | from tqdm import tqdm 9 | from datetime import datetime 10 | 11 | import torch 12 | import torch.distributed as dist 13 | import torch.nn as nn 14 | import torch.nn.functional as F 15 | import torch.optim as optim 16 | import torch.multiprocessing as mp 17 | 18 | from MainDataset import MainDataset 19 | from search.code.add_search import ExCode 20 | 21 | def run_training(args, train_data,eval_data): 22 | 23 | model=transformers.GPTNeoForCausalLM.from_pretrained("saved_models/checkpoint/") 24 | train_data.start_iteration = 0 25 | training_args = transformers.TrainingArguments( 26 | output_dir=args.save_dir, 27 | overwrite_output_dir=False, 28 | 29 | do_train=True, 30 | do_eval=False, 31 | do_predict=True, 32 | evaluation_strategy='steps', 33 
| eval_steps=args.eval_steps, 34 | 35 | num_train_epochs=args.epochs, 36 | per_device_train_batch_size=args.batch_size_per_replica, 37 | gradient_accumulation_steps=args.grad_acc_steps, 38 | 39 | learning_rate=args.lr, 40 | 41 | logging_dir=args.save_dir, 42 | logging_first_step=True, 43 | logging_steps=args.log_freq, 44 | save_steps=args.save_freq, 45 | save_total_limit=3, 46 | 47 | dataloader_drop_last=True, 48 | dataloader_num_workers=1, 49 | 50 | local_rank=args.local_rank, 51 | 52 | deepspeed=args.deepspeed, 53 | fp16=args.fp16, 54 | ) 55 | trainer = transformers.Trainer( 56 | model=model, 57 | args=training_args, 58 | train_dataset=train_data, 59 | eval_dataset=eval_data 60 | ) 61 | trainer.train() 62 | model.save_pretrained(os.path.join(args.save_dir, "final_checkpoint")) 63 | 64 | 65 | def get_dataset(args): 66 | #remove this 67 | #TODO: max tokens? play with it. 68 | train_data = MainDataset( 69 | max_input_len=550, 70 | max_target_len=300, 71 | max_context_len=250, 72 | mode='train' 73 | ) 74 | eval_data = MainDataset( 75 | max_input_len=550, 76 | max_target_len=300, 77 | max_context_len=250, 78 | mode='valid' 79 | ) 80 | torch.save(train_data,'data/train.pt') 81 | torch.save(eval_data,'data/valid.pt') 82 | 83 | print('saved tensors') 84 | pickle.dump(train_data,open('data/train.bin','wb')) 85 | pickle.dump(eval_data,open('data/valid.bin','wb')) 86 | # train_data=None 87 | return train_data,eval_data 88 | 89 | if __name__ == "__main__": 90 | import argparse 91 | 92 | parser = argparse.ArgumentParser(description="Language Modelling on Code") 93 | parser.add_argument('--arch', default='gpt2') 94 | parser.add_argument('--dummy-model', action='store_true') 95 | parser.add_argument('--load', default=None, type=str) 96 | parser.add_argument('--load_train_dataset', default='data/train.bin', type=str) 97 | parser.add_argument('--load_eval_dataset', default='data/valid.bin', type=str) 98 | parser.add_argument('--resume', default=None, type=str) 99 | # Dataloading 100 | parser.add_argument('--context-dataroot', default='dataset/context_js/', type=str) 101 | # Training 102 | parser.add_argument('--epochs', default=15, type=int) 103 | parser.add_argument('--lr', default=5e-5, type=float) 104 | # parser.add_argument('--lr-warmup-steps', default=500, type=int) 105 | parser.add_argument('--batch-size-per-replica', default=3, type=int) 106 | parser.add_argument('--grad-acc-steps', default=2, type=int) 107 | parser.add_argument('--local_rank', default=-1, type=int) 108 | parser.add_argument('--deepspeed', default=None, type=str) 109 | parser.add_argument('--fp16', default=False, action='store_true') 110 | # Logging and stuff 111 | parser.add_argument('--save-dir', default="saved_gen/", type=str) 112 | parser.add_argument('--log_freq', default=1000, type=int) 113 | parser.add_argument('--save-freq', default=10000, type=int) 114 | parser.add_argument('--eval_steps', default=5000, type=int) 115 | 116 | args = parser.parse_args() 117 | 118 | argsdict = vars(args) 119 | print(pprint.pformat(argsdict)) 120 | 121 | os.makedirs(args.save_dir, exist_ok=True) 122 | if os.path.exists(args.load_train_dataset) and os.path.exists(args.load_eval_dataset): 123 | eval_data=pickle.load(open(args.load_eval_dataset,'rb')) 124 | train_data=pickle.load(open(args.load_train_dataset,'rb')) 125 | print('original train len: ',len(train_data)) 126 | else: 127 | train_data,eval_data = get_dataset(args) 128 | 129 | run_training(args, train_data,eval_data=eval_data) 130 | 131 | 
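# Illustrative sketch (not part of the training script above): how the saved checkpoint
# might be used for generation, assuming a prompt in the Context/Query/Examples-from-search
# format built by MainDataset.py. The model directory and sampling settings are assumptions.
import torch
import transformers

def generate_code(prompt, model_dir='saved_gen/final_checkpoint'):
    # Load the tokenizer the training code uses and the fine-tuned GPT-Neo checkpoint.
    tokenizer = transformers.GPT2Tokenizer.from_pretrained('EleutherAI/gpt-neo-125M')
    model = transformers.GPTNeoForCausalLM.from_pretrained(model_dir).eval()
    input_ids = tokenizer.encode(prompt, return_tensors='pt')
    with torch.no_grad():
        out = model.generate(input_ids, max_length=input_ids.shape[1] + 128,
                             do_sample=True, top_p=0.95,
                             pad_token_id=tokenizer.eos_token_id)
    # Strip the prompt and return only the newly generated tokens.
    return tokenizer.decode(out[0][input_ids.shape[1]:], skip_special_tokens=True)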
-------------------------------------------------------------------------------- /jsdemo.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sdpmas/Codeon/917beea6ef25e9dd213b43620c28859892e927e8/jsdemo.gif -------------------------------------------------------------------------------- /search/code/add_search.py: -------------------------------------------------------------------------------- 1 | from transformers import (WEIGHTS_NAME, AdamW, get_linear_schedule_with_warmup, 2 | RobertaConfig, RobertaModel, RobertaTokenizer) 3 | import sys 4 | from search.code.model import Model 5 | import torch 6 | import scann 7 | import pickle,json 8 | import time 9 | from torch.utils.data import DataLoader, Dataset 10 | from tqdm import tqdm 11 | import os 12 | 13 | class InputFeatures(object): 14 | """A single training/test features for a example.""" 15 | def __init__(self, 16 | docstring_ids, 17 | idx, 18 | docstring,source,code,orig_doc 19 | 20 | ): 21 | self.docstring_ids=docstring_ids 22 | self.idx=idx 23 | self.docstring=docstring 24 | self.source=source 25 | self.code=code 26 | self.orig_doc=orig_doc 27 | 28 | def convert_examples_to_features(js,tokenizer,block_size): 29 | #choose the language 30 | docstring='Language: Javascript'+' NL: '+js['doc'] 31 | orig_doc=js['doc'] 32 | code=js['body'] 33 | 34 | docstring_tokens=[tokenizer.cls_token]+tokenizer.tokenize(docstring)[:block_size-2]+[tokenizer.sep_token] 35 | 36 | docstring_ids = tokenizer.convert_tokens_to_ids(docstring_tokens) 37 | padding_length =block_size - len(docstring_ids) 38 | docstring_ids+=[tokenizer.pad_token_id]*padding_length 39 | 40 | return InputFeatures(docstring_ids=docstring_ids,idx=js['idx'],docstring=docstring,source=js['source'],code=code,orig_doc=orig_doc) 41 | class TextDataset(Dataset): 42 | def __init__(self, tokenizer, block_size=100, file_path=None): 43 | self.examples = [] 44 | data=[] 45 | f=pickle.load(open(file_path,'rb')) 46 | for i,line in enumerate(f): 47 | js=line 48 | data.append(js) 49 | for js in data: 50 | converted_ex=convert_examples_to_features(js,tokenizer,block_size=block_size) 51 | if converted_ex: 52 | self.examples.append(converted_ex) 53 | 54 | def __len__(self): 55 | return len(self.examples) 56 | 57 | def __getitem__(self, i): 58 | return (torch.LongTensor(self.examples[i].docstring_ids),self.examples[i].idx,self.examples[i].docstring,self.examples[i].source,self.examples[i].code,self.examples[i].orig_doc) 59 | class ExCode(object): 60 | """example of retrived code with corresponding nl and idx""" 61 | def __init__(self,codes,idx,nl,source): 62 | self.codes=codes 63 | self.idx=idx 64 | self.nl=nl 65 | self.source=source 66 | def save_excodes(loader,model,device,searcher,codebase,mode,source_data): 67 | excodes={} 68 | 69 | for step, batch in tqdm(enumerate(loader),total=len(loader)): 70 | docstring_ids=batch[0] 71 | idxs=batch[1] 72 | docstrings=batch[2] 73 | sources=batch[3] 74 | src_codes=batch[4] 75 | orig_docs=batch[5] 76 | 77 | with torch.no_grad(): 78 | embeds=model.get_representation_batch(qc_ids=docstring_ids,device=device) 79 | embeds=embeds.detach().cpu().numpy() 80 | 81 | assert len(idxs)==len(embeds)==len(docstrings)==len(sources)==len(src_codes) 82 | for idx, embed,docstring,source,src_code,orig_doc in zip(idxs,embeds,docstrings,sources,src_codes,orig_docs): 83 | assert source==source_data 84 | idx=idx.item() 85 | code_idx,_=searcher.search(embed) 86 | filtered_code_idx=[] 87 | for c_id in code_idx: 88 | 89 | if 
codebase[c_id]['docstring'].strip() == orig_doc.strip() or codebase[c_id]['language']!='javascript': 90 | continue 91 | else: 92 | set_c_id_code=set(codebase[c_id]['code'].split(' ')) 93 | set_src_code=set(src_code.split(' ')) 94 | common=set_c_id_code & set_src_code 95 | max_per=max(len(common)/len(set_c_id_code),len(common)/len(set_src_code)) 96 | if max_per>=0.95: 97 | continue 98 | else: 99 | filtered_code_idx.append(c_id) 100 | 101 | codes=[codebase[c_id]['code'] for c_id in filtered_code_idx if codebase[c_id]['language']=='javascript'] 102 | 103 | excodes[idx]=ExCode(idx=idx,source=source,codes=codes,nl=docstring) 104 | pickle.dump(excodes,open(f'data/{mode}_{source_data}_excodes.bin','wb')) 105 | print('done') 106 | 107 | def main(model,tokenizer,codebase): 108 | searcher = scann.scann_ops_pybind.load_searcher('data/scann_searcher') 109 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 110 | print('the device is : ',device) 111 | model.to(device) 112 | model.eval() 113 | 114 | context_train_dataset=TextDataset(tokenizer,file_path='dataset/context_js/train.bin') 115 | context_valid_dataset=TextDataset(tokenizer,file_path='dataset/context_js/valid.bin') 116 | print('lengths: ',len(context_train_dataset),len(context_valid_dataset)) 117 | context_train_dataloader = DataLoader(context_train_dataset,batch_size=30) 118 | context_valid_dataloader = DataLoader(context_valid_dataset,batch_size=30) 119 | save_excodes(context_train_dataloader,model,device,searcher,codebase,'train',source_data='context') 120 | save_excodes(context_valid_dataloader,model,device,searcher,codebase,'valid',source_data='context') 121 | return 122 | 123 | if __name__=='__main__': 124 | model_name='saved_search/checkpoint-best-mrr/model.bin' 125 | config = RobertaConfig.from_pretrained('codebert', 126 | cache_dir= None) 127 | tokenizer = RobertaTokenizer.from_pretrained('codebert', 128 | cache_dir=None) 129 | model = RobertaModel.from_pretrained('codebert', 130 | config=config, 131 | cache_dir=None) 132 | 133 | codebase= pickle.load(open('data/codebase.bin','rb')) 134 | 135 | model=Model(model,config,tokenizer,args=None) 136 | model.load_state_dict(torch.load(model_name)) 137 | main(model,tokenizer,codebase) 138 | 139 | 140 | -------------------------------------------------------------------------------- /search/code/make_index.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from torch.utils import data 3 | from transformers import (WEIGHTS_NAME, AdamW, get_linear_schedule_with_warmup, 4 | RobertaConfig, RobertaModel, RobertaTokenizer) 5 | from model import Model 6 | import torch 7 | import pickle 8 | from torch.utils.data import DataLoader, Dataset, SequentialSampler, RandomSampler,TensorDataset 9 | import torch.nn.functional as f 10 | from tqdm import tqdm 11 | import scann 12 | 13 | class InputFeatures(object): 14 | """A single training/test features for a example.""" 15 | def __init__(self, 16 | code_tokens, 17 | code_ids, 18 | original_code, 19 | docstring,url,language 20 | 21 | ): 22 | self.code_tokens = code_tokens 23 | self.original_code=original_code 24 | self.code_ids = code_ids 25 | self.docstring=docstring 26 | self.url=url 27 | self.language=language 28 | 29 | def convert_examples_to_features(js,tokenizer,block_size): 30 | #code 31 | docstring=' '.join(js['docstring_tokens']).strip() 32 | code='Code: '+' '.join(js['function_tokens']) 33 | code_tokens=tokenizer.tokenize(code)[:block_size-2] 34 | code_tokens 
=[tokenizer.cls_token]+code_tokens+[tokenizer.sep_token] 35 | code_ids = tokenizer.convert_tokens_to_ids(code_tokens) 36 | padding_length =block_size - len(code_ids) 37 | code_ids+=[tokenizer.pad_token_id]*padding_length 38 | # print('js: ',js) 39 | 40 | original_code=js['function'] 41 | 42 | return InputFeatures(code_tokens,code_ids,original_code=original_code,docstring=docstring,url=js['url'],language=js['language']) 43 | class TextDataset(Dataset): 44 | def __init__(self, tokenizer, block_size=256, file_paths=None): 45 | self.examples = [] 46 | data=[] 47 | f_py=pickle.load(open(file_paths['py'],'rb')) 48 | f_js=pickle.load(open(file_paths['js'],'rb')) 49 | for i,line in enumerate(f_py): 50 | js=line 51 | data.append(js) 52 | print('len of py: ',len(data)) 53 | for i,line in enumerate(f_js): 54 | js=line 55 | data.append(js) 56 | print('len of js+py: ',len(data)) 57 | np.random.seed(69) 58 | np.random.shuffle(data) 59 | for js in data: 60 | self.examples.append(convert_examples_to_features(js,tokenizer,block_size=block_size)) 61 | 62 | def __len__(self): 63 | return len(self.examples) 64 | 65 | def __getitem__(self, i): 66 | return (torch.tensor(self.examples[i].code_ids),self.examples[i].original_code,self.examples[i].docstring,self.examples[i].url,self.examples[i].language) 67 | 68 | def main(model,tokenizer,js_dataset_file,py_dataset_file): 69 | # show_gpu('GPU memory usage initially:') 70 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 71 | print('the device is : ',device) 72 | model.to(device) 73 | model.eval() 74 | # model=model.cuda() 75 | dataset=TextDataset(tokenizer,file_paths={'js':js_dataset_file,'py':py_dataset_file}) 76 | sampler = SequentialSampler(dataset) 77 | 78 | dataloader = DataLoader(dataset, sampler=sampler, 79 | batch_size=15) 80 | 81 | codebase={} 82 | idx_count=0 83 | reprs=[] 84 | for step, batch in tqdm(enumerate(dataloader),total=len(dataloader)): 85 | codes=batch[0] 86 | orig=batch[1] 87 | docstrings=batch[2] 88 | urls=batch[3] 89 | languages=batch[4] 90 | 91 | with torch.no_grad(): 92 | embeds=model.get_representation_batch(qc_ids=codes,device=device) 93 | embeds=embeds.cpu() 94 | 95 | assert len(embeds)==len(orig)==len(docstrings) 96 | 97 | for embed,orig_code,docstring,url,language in zip(embeds,orig,docstrings,urls,languages): 98 | codebase[idx_count]={'code':orig_code,'docstring':docstring,'url':url,'language':language} 99 | reprs.append(embed) 100 | idx_count+=1 101 | scann_dataset=torch.stack(reprs) 102 | normalized_dataset=f.normalize(scann_dataset) 103 | searcher = scann.scann_ops_pybind.builder(normalized_dataset, 10, "dot_product").tree( 104 | num_leaves=2000, num_leaves_to_search=100, training_sample_size=250000).score_ah( 105 | 2, anisotropic_quantization_threshold=0.2).reorder(100).build() 106 | searcher.serialize('data/scann_searcher') 107 | print('idx count: ',idx_count) 108 | print('searcher saved') 109 | pickle.dump(codebase,open('data/codebase.bin','wb')) 110 | print('codebase saved') 111 | 112 | 113 | if __name__=='__main__': 114 | model_name='saved_search/checkpoint-best-mrr/model.bin' 115 | config = RobertaConfig.from_pretrained('codebert', 116 | cache_dir= None) 117 | tokenizer = RobertaTokenizer.from_pretrained('codebert', 118 | cache_dir=None) 119 | model = RobertaModel.from_pretrained('codebert', 120 | config=config, 121 | cache_dir=None) 122 | 123 | py_dataset_file='dataset/python/py_superset.bin' 124 | js_dataset_file='dataset/js/js_superset.bin' 125 | 126 | model=Model(model,config,tokenizer,args=None) 
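# The pretrained encoder above is wrapped in the same Model class used for fine-tuning;
# the line below loads the fine-tuned weights saved by search/code/run.py under
# checkpoint-best-mrr. On a CPU-only machine, torch.load may need map_location='cpu'.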
127 | model.load_state_dict(torch.load(model_name)) 128 | main(model,tokenizer,js_dataset_file=js_dataset_file,py_dataset_file=py_dataset_file) -------------------------------------------------------------------------------- /search/code/model.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | import torch 4 | import torch.nn as nn 5 | import torch 6 | from torch.autograd import Variable 7 | import copy 8 | import torch.nn.functional as F 9 | from torch.nn import CrossEntropyLoss, MSELoss 10 | 11 | 12 | 13 | class Model(nn.Module): 14 | def __init__(self, encoder,config,tokenizer,args=None): 15 | super(Model, self).__init__() 16 | self.encoder = encoder 17 | self.config=config 18 | self.tokenizer=tokenizer 19 | self.args=args 20 | def get_representation_batch(self,qc_ids=None,device=None): 21 | """get represenations in batch for either queries or codes""" 22 | return self.encoder(qc_ids.to(device),attention_mask=qc_ids.ne(1).to(device))[1] 23 | 24 | def get_representation_one(self,query,device=None): 25 | """get representation for a single query: less dataset stuffs.""" 26 | query_tokens=[self.tokenizer.cls_token]+self.tokenizer.tokenize(query)[:298]+[self.tokenizer.sep_token] 27 | query_ids=torch.tensor(self.tokenizer.convert_tokens_to_ids(query_tokens)).unsqueeze(dim=0).to(device) 28 | return self.encoder(query_ids,attention_mask=query_ids.ne(1))[1].squeeze(dim=0) 29 | 30 | def forward(self, code_inputs,nl_inputs,return_vec=False): 31 | bs=code_inputs.shape[0] 32 | inputs=torch.cat((code_inputs,nl_inputs),0) 33 | outputs=self.encoder(inputs,attention_mask=inputs.ne(1))[1] 34 | code_vec=outputs[:bs] 35 | nl_vec=outputs[bs:] 36 | 37 | if return_vec: 38 | return code_vec,nl_vec 39 | scores=(nl_vec[:,None,:]*code_vec[None,:,:]).sum(-1) 40 | loss_fct = CrossEntropyLoss() 41 | loss = loss_fct(scores, torch.arange(bs, device=scores.device)) 42 | return loss,code_vec,nl_vec 43 | 44 | 45 | 46 | 47 | -------------------------------------------------------------------------------- /search/code/run.py: -------------------------------------------------------------------------------- 1 | 2 | #code adapted from https://github.com/microsoft/CodeXGLUE/tree/main/Text-Code/text-to-code 3 | 4 | from __future__ import absolute_import, division, print_function 5 | 6 | import argparse 7 | import glob 8 | import logging 9 | import os 10 | import pickle 11 | import random 12 | import re 13 | import shutil 14 | import gzip 15 | from pathlib import Path 16 | import numpy as np 17 | import torch 18 | from torch.utils.data import DataLoader, Dataset, SequentialSampler, RandomSampler,TensorDataset 19 | from torch.utils.data.distributed import DistributedSampler 20 | import json 21 | try: 22 | from torch.utils.tensorboard import SummaryWriter 23 | except: 24 | from tensorboardX import SummaryWriter 25 | 26 | from tqdm import tqdm, trange 27 | import multiprocessing 28 | from model import Model 29 | cpu_cont = multiprocessing.cpu_count() 30 | from transformers import (WEIGHTS_NAME, AdamW, get_linear_schedule_with_warmup, 31 | BertConfig, BertForMaskedLM, BertTokenizer, 32 | GPT2Config, GPT2LMHeadModel, GPT2Tokenizer, 33 | OpenAIGPTConfig, OpenAIGPTLMHeadModel, OpenAIGPTTokenizer, 34 | RobertaConfig, RobertaModel, RobertaTokenizer, 35 | DistilBertConfig, DistilBertForMaskedLM, DistilBertTokenizer) 36 | 37 | logger = logging.getLogger(__name__) 38 | 39 | MODEL_CLASSES = { 40 | 'gpt2': (GPT2Config, 
GPT2LMHeadModel, GPT2Tokenizer), 41 | 'openai-gpt': (OpenAIGPTConfig, OpenAIGPTLMHeadModel, OpenAIGPTTokenizer), 42 | 'bert': (BertConfig, BertForMaskedLM, BertTokenizer), 43 | 'roberta': (RobertaConfig, RobertaModel, RobertaTokenizer), 44 | 'distilbert': (DistilBertConfig, DistilBertForMaskedLM, DistilBertTokenizer) 45 | } 46 | 47 | 48 | class InputFeatures(object): 49 | """A single training/test features for a example.""" 50 | def __init__(self, 51 | code_tokens, 52 | code_ids, 53 | nl_tokens, 54 | nl_ids, 55 | url 56 | 57 | ): 58 | self.code_tokens = code_tokens 59 | self.code_ids = code_ids 60 | self.nl_tokens = nl_tokens 61 | self.nl_ids = nl_ids 62 | self.url=url 63 | 64 | 65 | 66 | def convert_examples_to_features(js,tokenizer,args): 67 | #code 68 | if js['language']!='python' and js['language']!='javascript': 69 | print('not python or javascript',js['language']) 70 | code='Code: '+' '.join(js['code_tokens']) 71 | 72 | code_tokens=tokenizer.tokenize(code)[:args.block_size-2] 73 | code_tokens =[tokenizer.cls_token]+code_tokens+[tokenizer.sep_token] 74 | code_ids = tokenizer.convert_tokens_to_ids(code_tokens) 75 | padding_length = args.block_size - len(code_ids) 76 | code_ids+=[tokenizer.pad_token_id]*padding_length 77 | if js['language']=='python': 78 | nl='Language: Python'+' NL: '+' '.join(js['docstring_tokens']) 79 | else: 80 | assert js['language']=='javascript' 81 | nl='Language: Javascript'+' NL: '+' '.join(js['docstring_tokens']) 82 | 83 | nl_tokens=tokenizer.tokenize(nl)[:args.block_size-2] 84 | nl_tokens =[tokenizer.cls_token]+nl_tokens+[tokenizer.sep_token] 85 | nl_ids = tokenizer.convert_tokens_to_ids(nl_tokens) 86 | padding_length = args.block_size - len(nl_ids) 87 | nl_ids+=[tokenizer.pad_token_id]*padding_length 88 | 89 | return InputFeatures(code_tokens,code_ids,nl_tokens,nl_ids,js['url']) 90 | 91 | class TextDataset(Dataset): 92 | def __init__(self, tokenizer, args, data_path=None): 93 | self.examples = [] 94 | data=[] 95 | py_file_paths=sorted(Path(data_path['py']).glob('*.gz')) 96 | js_file_paths=sorted(Path(data_path['js']).glob('*.gz')) 97 | py_data=[] 98 | js_data=[] 99 | for py_file_path in py_file_paths: 100 | logger.info("Processing file: %s", py_file_path) 101 | with gzip.open(py_file_path,'r') as f: 102 | for i,line in enumerate(f): 103 | line=line.strip() 104 | js=json.loads(line) 105 | assert js['language']=='python' 106 | py_data.append(js) 107 | for js_file_path in js_file_paths: 108 | logger.info("Processing file: %s", js_file_path) 109 | with gzip.open(js_file_path,'r') as f: 110 | for i,line in enumerate(f): 111 | line=line.strip() 112 | js=json.loads(line) 113 | assert js['language']=='javascript' 114 | js_data.append(js) 115 | data=js_data+py_data 116 | np.random.shuffle(data) 117 | print('lengths: ',len(data),len(py_data),len(js_data)) 118 | for d in data: 119 | self.examples.append(convert_examples_to_features(d,tokenizer,args)) 120 | 121 | def __len__(self): 122 | return len(self.examples) 123 | 124 | def __getitem__(self, i): 125 | return (torch.tensor(self.examples[i].code_ids),torch.tensor(self.examples[i].nl_ids)) 126 | 127 | 128 | def set_seed(seed=42): 129 | random.seed(seed) 130 | os.environ['PYHTONHASHSEED'] = str(seed) 131 | np.random.seed(seed) 132 | torch.manual_seed(seed) 133 | torch.cuda.manual_seed(seed) 134 | torch.backends.cudnn.deterministic = True 135 | 136 | 137 | def train(args, train_dataset, model, tokenizer): 138 | """ Train the model """ 139 | print(f'len of train dataset: {len(train_dataset)}') 140 | args.train_batch_size 
= args.per_gpu_train_batch_size * max(1, args.n_gpu) 141 | train_sampler = RandomSampler(train_dataset) if args.local_rank == -1 else DistributedSampler(train_dataset) 142 | 143 | train_dataloader = DataLoader(train_dataset, sampler=train_sampler, 144 | batch_size=args.train_batch_size,num_workers=4,pin_memory=True) 145 | args.max_steps=args.epoch*len( train_dataloader) 146 | # args.save_steps=len( train_dataloader)//10 147 | args.warmup_steps=len( train_dataloader) 148 | args.logging_steps=len( train_dataloader) 149 | args.num_train_epochs=args.epoch 150 | model.to(args.device) 151 | # Prepare optimizer and schedule (linear warmup and decay) 152 | no_decay = ['bias', 'LayerNorm.weight'] 153 | optimizer_grouped_parameters = [ 154 | {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 155 | 'weight_decay': args.weight_decay}, 156 | {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0} 157 | ] 158 | optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon) 159 | scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=args.max_steps*0.1, 160 | num_training_steps=args.max_steps) 161 | if args.fp16: 162 | try: 163 | from apex import amp 164 | except ImportError: 165 | raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.") 166 | model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level) 167 | 168 | # multi-gpu training (should be after apex fp16 initialization) 169 | if args.n_gpu > 1: 170 | model = torch.nn.DataParallel(model) 171 | 172 | # Distributed training (should be after apex fp16 initialization) 173 | if args.local_rank != -1: 174 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank], 175 | output_device=args.local_rank, 176 | find_unused_parameters=True) 177 | 178 | checkpoint_last = os.path.join(args.output_dir, 'checkpoint-last') 179 | scheduler_last = os.path.join(checkpoint_last, 'scheduler.pt') 180 | optimizer_last = os.path.join(checkpoint_last, 'optimizer.pt') 181 | if os.path.exists(scheduler_last): 182 | scheduler.load_state_dict(torch.load(scheduler_last)) 183 | if os.path.exists(optimizer_last): 184 | optimizer.load_state_dict(torch.load(optimizer_last)) 185 | # Train! 186 | logger.info("***** Running training *****") 187 | logger.info(" Num examples = %d", len(train_dataset)) 188 | logger.info(" Num Epochs = %d", args.num_train_epochs) 189 | logger.info(" Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size) 190 | logger.info(" Total train batch size (w. 
parallel, distributed & accumulation) = %d", 191 | args.train_batch_size * args.gradient_accumulation_steps * ( 192 | torch.distributed.get_world_size() if args.local_rank != -1 else 1)) 193 | logger.info(" Gradient Accumulation steps = %d", args.gradient_accumulation_steps) 194 | logger.info(" Total optimization steps = %d", args.max_steps) 195 | 196 | global_step = args.start_step 197 | tr_loss, logging_loss,avg_loss,tr_nb,tr_num,train_loss = 0.0, 0.0,0.0,0,0,0 198 | best_mrr=0.0 199 | best_acc=0.0 200 | # model.resize_token_embeddings(len(tokenizer)) 201 | model.zero_grad() 202 | 203 | print('starting training loop') 204 | for idx in range(args.start_epoch, int(args.num_train_epochs)): 205 | bar = train_dataloader 206 | tr_num=0 207 | train_loss=0 208 | for step, batch in enumerate(bar): 209 | code_inputs = batch[0].to(args.device) 210 | nl_inputs = batch[1].to(args.device) 211 | 212 | model.train() 213 | loss,code_vec,nl_vec = model(code_inputs,nl_inputs) 214 | assert len(code_vec)==len(code_inputs) 215 | assert len(nl_vec)==len(nl_inputs) 216 | if args.n_gpu > 1: 217 | loss = loss.mean() # mean() to average on multi-gpu parallel training 218 | if args.gradient_accumulation_steps > 1: 219 | loss = loss / args.gradient_accumulation_steps 220 | 221 | if args.fp16: 222 | with amp.scale_loss(loss, optimizer) as scaled_loss: 223 | scaled_loss.backward() 224 | torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm) 225 | else: 226 | loss.backward() 227 | torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm) 228 | 229 | tr_loss += loss.item() 230 | tr_num+=1 231 | train_loss+=loss.item() 232 | if avg_loss==0: 233 | avg_loss=tr_loss 234 | avg_loss=round(train_loss/tr_num,5) 235 | if (step+1)% 100==0: 236 | logger.info("epoch {} step {} loss {}".format(idx,step+1,avg_loss)) 237 | #bar.set_description("epoch {} loss {}".format(idx,avg_loss)) 238 | 239 | 240 | if (step + 1) % args.gradient_accumulation_steps == 0: 241 | optimizer.step() 242 | optimizer.zero_grad() 243 | scheduler.step() 244 | global_step += 1 245 | output_flag=True 246 | avg_loss=round(np.exp((tr_loss - logging_loss) /(global_step- tr_nb)),4) 247 | if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0: 248 | logging_loss = tr_loss 249 | tr_nb=global_step 250 | 251 | if args.local_rank in [-1, 0] and args.save_steps > 0 and global_step % args.save_steps == 0: 252 | 253 | if args.local_rank == -1 and args.evaluate_during_training: # Only evaluate when single GPU otherwise metrics may not average well 254 | results = evaluate(args, model, tokenizer,eval_when_training=True) 255 | for key, value in results.items(): 256 | logger.info(" %s = %s", key, round(value,4)) 257 | # Save model checkpoint 258 | tr_num=0 259 | train_loss=0 260 | checkpoint_prefix = 'checkpoint-last' 261 | output_dir = os.path.join(args.output_dir, '{}'.format(checkpoint_prefix)) 262 | if not os.path.exists(output_dir): 263 | os.makedirs(output_dir) 264 | model_to_save = model.module if hasattr(model,'module') else model 265 | output_dir = os.path.join(output_dir, '{}'.format('model.bin')) 266 | torch.save(model_to_save.state_dict(), output_dir) 267 | logger.info("Saving last model checkpoint to %s", output_dir) 268 | if results['eval_mrr']>best_acc: 269 | 270 | best_acc=results['eval_mrr'] 271 | logger.info(' current epoch:%s',idx) 272 | logger.info(" "+"*"*20) 273 | logger.info(" Best mrr:%s",round(best_acc,4)) 274 | logger.info(" "+"*"*20) 275 | 276 | checkpoint_prefix = 
'checkpoint-best-mrr' 277 | output_dir = os.path.join(args.output_dir, '{}'.format(checkpoint_prefix)) 278 | if not os.path.exists(output_dir): 279 | os.makedirs(output_dir) 280 | model_to_save = model.module if hasattr(model,'module') else model 281 | output_dir = os.path.join(output_dir, '{}'.format('model.bin')) 282 | torch.save(model_to_save.state_dict(), output_dir) 283 | logger.info("Saving model checkpoint to %s", output_dir) 284 | 285 | 286 | eval_dataset=None 287 | def evaluate(args, model, tokenizer,eval_when_training=False): 288 | # Loop to handle MNLI double evaluation (matched, mis-matched) 289 | eval_output_dir = args.output_dir 290 | global eval_dataset 291 | if eval_dataset is None: 292 | eval_dataset = TextDataset(tokenizer, args,{'py':args.py_eval_data_file,'js':args.js_eval_data_file}) 293 | 294 | if not os.path.exists(eval_output_dir) and args.local_rank in [-1, 0]: 295 | os.makedirs(eval_output_dir) 296 | 297 | args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu) 298 | # Note that DistributedSampler samples randomly 299 | eval_sampler = SequentialSampler(eval_dataset) if args.local_rank == -1 else DistributedSampler(eval_dataset) 300 | eval_dataloader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size,num_workers=4,pin_memory=True) 301 | 302 | # multi-gpu evaluate 303 | if args.n_gpu > 1 and eval_when_training is False: 304 | model = torch.nn.DataParallel(model) 305 | 306 | # Eval! 307 | logger.info("***** Running evaluation *****") 308 | logger.info(" Num examples = %d", len(eval_dataset)) 309 | logger.info(" Batch size = %d", args.eval_batch_size) 310 | eval_loss = 0.0 311 | nb_eval_steps = 0 312 | model.eval() 313 | code_vecs=[] 314 | nl_vecs=[] 315 | for batch in eval_dataloader: 316 | code_inputs = batch[0].to(args.device) 317 | nl_inputs = batch[1].to(args.device) 318 | with torch.no_grad(): 319 | lm_loss,code_vec,nl_vec = model(code_inputs,nl_inputs) 320 | eval_loss += lm_loss.mean().item() 321 | code_vecs.append(code_vec.cpu().numpy()) 322 | nl_vecs.append(nl_vec.cpu().numpy()) 323 | # try: 324 | assert len(code_vec)==len(code_inputs) 325 | assert len(nl_vec)==len(nl_inputs) 326 | nb_eval_steps += 1 327 | code_vecs=np.concatenate(code_vecs,0) 328 | nl_vecs=np.concatenate(nl_vecs,0) 329 | eval_loss = eval_loss / nb_eval_steps 330 | perplexity = torch.tensor(eval_loss) 331 | 332 | scores=np.matmul(nl_vecs,code_vecs.T) 333 | ranks=[] 334 | for i in range(len(scores)): 335 | score=scores[i,i] 336 | rank=1 337 | for j in range(len(scores)): 338 | if i!=j and scores[i,j]>=score: 339 | rank+=1 340 | ranks.append(1/rank) 341 | 342 | 343 | result = { 344 | "eval_loss": float(perplexity), 345 | "eval_mrr":float(np.mean(ranks)) 346 | } 347 | 348 | 349 | return result 350 | 351 | def main(): 352 | parser = argparse.ArgumentParser() 353 | 354 | ## Required parameters 355 | parser.add_argument("--py_train_data_file", default=None, type=str, required=True, 356 | help="The input training data file (a text file).") 357 | parser.add_argument("--js_train_data_file", default=None, type=str, required=True, 358 | help="The input training data file (a text file).") 359 | parser.add_argument("--output_dir", default=None, type=str, required=True, 360 | help="The output directory where the model predictions and checkpoints will be written.") 361 | 362 | ## Other parameters 363 | parser.add_argument("--py_eval_data_file", default=None, type=str, 364 | help="An optional input evaluation data file to evaluate the perplexity on (a text 
file).") 365 | parser.add_argument("--js_eval_data_file", default=None, type=str, 366 | help="An optional input evaluation data file to evaluate the perplexity on (a text file).") 367 | parser.add_argument("--py_test_data_file", default=None, type=str, 368 | help="An optional input evaluation data file to evaluate the perplexity on (a text file).") 369 | parser.add_argument("--js_test_data_file", default=None, type=str, 370 | help="An optional input evaluation data file to evaluate the perplexity on (a text file).") 371 | 372 | parser.add_argument("--model_type", default="bert", type=str, 373 | help="The model architecture to be fine-tuned.") 374 | parser.add_argument("--model_name_or_path", default=None, type=str, 375 | help="The model checkpoint for weights initialization.") 376 | 377 | parser.add_argument("--mlm", action='store_true', 378 | help="Train with masked-language modeling loss instead of language modeling.") 379 | parser.add_argument("--mlm_probability", type=float, default=0.15, 380 | help="Ratio of tokens to mask for masked language modeling loss") 381 | 382 | parser.add_argument("--config_name", default="", type=str, 383 | help="Optional pretrained config name or path if not the same as model_name_or_path") 384 | parser.add_argument("--tokenizer_name", default="", type=str, 385 | help="Optional pretrained tokenizer name or path if not the same as model_name_or_path") 386 | parser.add_argument("--cache_dir", default="", type=str, 387 | help="Optional directory to store the pre-trained models downloaded from s3 (instread of the default one)") 388 | parser.add_argument("--block_size", default=-1, type=int, 389 | help="Optional input sequence length after tokenization." 390 | "The training dataset will be truncated in block of this size for training." 
391 | "Default to the model max input length for single sentence inputs (take into account special tokens).") 392 | parser.add_argument("--do_train", action='store_true', 393 | help="Whether to run training.") 394 | parser.add_argument("--do_eval", action='store_true', 395 | help="Whether to run eval on the dev set.") 396 | parser.add_argument("--do_test", action='store_true', 397 | help="Whether to run eval on the dev set.") 398 | parser.add_argument("--evaluate_during_training", action='store_true', 399 | help="Run evaluation during training at each logging step.") 400 | parser.add_argument("--do_lower_case", action='store_true', 401 | help="Set this flag if you are using an uncased model.") 402 | 403 | parser.add_argument("--train_batch_size", default=4, type=int, 404 | help="Batch size per GPU/CPU for training.") 405 | parser.add_argument("--eval_batch_size", default=4, type=int, 406 | help="Batch size per GPU/CPU for evaluation.") 407 | parser.add_argument('--gradient_accumulation_steps', type=int, default=1, 408 | help="Number of updates steps to accumulate before performing a backward/update pass.") 409 | parser.add_argument("--learning_rate", default=5e-5, type=float, 410 | help="The initial learning rate for Adam.") 411 | parser.add_argument("--weight_decay", default=0.0, type=float, 412 | help="Weight deay if we apply some.") 413 | parser.add_argument("--adam_epsilon", default=1e-8, type=float, 414 | help="Epsilon for Adam optimizer.") 415 | parser.add_argument("--max_grad_norm", default=1.0, type=float, 416 | help="Max gradient norm.") 417 | parser.add_argument("--num_train_epochs", default=1.0, type=float, 418 | help="Total number of training epochs to perform.") 419 | parser.add_argument("--max_steps", default=-1, type=int, 420 | help="If > 0: set total number of training steps to perform. Override num_train_epochs.") 421 | parser.add_argument("--warmup_steps", default=0, type=int, 422 | help="Linear warmup over warmup_steps.") 423 | 424 | parser.add_argument('--logging_steps', type=int, default=50, 425 | help="Log every X updates steps.") 426 | parser.add_argument('--save_steps', type=int, default=1000, 427 | help="Save checkpoint every X updates steps.") 428 | parser.add_argument('--save_total_limit', type=int, default=None, 429 | help='Limit the total amount of checkpoints, delete the older checkpoints in the output_dir, does not delete by default') 430 | parser.add_argument("--eval_all_checkpoints", action='store_true', 431 | help="Evaluate all checkpoints starting with the same prefix as model_name_or_path ending and ending with step number") 432 | parser.add_argument("--no_cuda", action='store_true', 433 | help="Avoid using CUDA when available") 434 | parser.add_argument('--overwrite_output_dir', action='store_true', 435 | help="Overwrite the content of the output directory") 436 | parser.add_argument('--overwrite_cache', action='store_true', 437 | help="Overwrite the cached training and evaluation sets") 438 | parser.add_argument('--seed', type=int, default=42, 439 | help="random seed for initialization") 440 | parser.add_argument('--epoch', type=int, default=42, 441 | help="random seed for initialization") 442 | parser.add_argument('--fp16', action='store_true', 443 | help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit") 444 | parser.add_argument('--fp16_opt_level', type=str, default='O1', 445 | help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']." 
446 | "See details at https://nvidia.github.io/apex/amp.html") 447 | parser.add_argument("--local_rank", type=int, default=-1, 448 | help="For distributed training: local_rank") 449 | parser.add_argument('--server_ip', type=str, default='', help="For distant debugging.") 450 | parser.add_argument('--server_port', type=str, default='', help="For distant debugging.") 451 | 452 | 453 | 454 | args = parser.parse_args() 455 | 456 | # Setup distant debugging if needed 457 | if args.server_ip and args.server_port: 458 | # Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script 459 | import ptvsd 460 | print("Waiting for debugger attach") 461 | ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True) 462 | ptvsd.wait_for_attach() 463 | 464 | # Setup CUDA, GPU & distributed training 465 | if args.local_rank == -1 or args.no_cuda: 466 | device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu") 467 | args.n_gpu = torch.cuda.device_count() 468 | else: # Initializes the distributed backend which will take care of sychronizing nodes/GPUs 469 | torch.cuda.set_device(args.local_rank) 470 | device = torch.device("cuda", args.local_rank) 471 | torch.distributed.init_process_group(backend='nccl') 472 | args.n_gpu = 1 473 | args.device = device 474 | args.per_gpu_train_batch_size=args.train_batch_size//args.n_gpu 475 | args.per_gpu_eval_batch_size=args.eval_batch_size//args.n_gpu 476 | # Setup logging 477 | logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', 478 | datefmt='%m/%d/%Y %H:%M:%S', 479 | level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN) 480 | logger.warning("Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s", 481 | args.local_rank, device, args.n_gpu, bool(args.local_rank != -1), args.fp16) 482 | 483 | 484 | # Set seed 485 | set_seed(args.seed) 486 | 487 | # Load pretrained model and tokenizer 488 | if args.local_rank not in [-1, 0]: 489 | torch.distributed.barrier() # Barrier to make sure only the first process in distributed training download model & vocab 490 | 491 | args.start_epoch = 0 492 | args.start_step = 0 493 | checkpoint_last = os.path.join(args.output_dir, 'checkpoint-last') 494 | if os.path.exists(checkpoint_last) and os.listdir(checkpoint_last): 495 | args.model_name_or_path = os.path.join(checkpoint_last, 'pytorch_model.bin') 496 | args.config_name = os.path.join(checkpoint_last, 'config.json') 497 | idx_file = os.path.join(checkpoint_last, 'idx_file.txt') 498 | with open(idx_file, encoding='utf-8') as idxf: 499 | args.start_epoch = int(idxf.readlines()[0].strip()) + 1 500 | 501 | step_file = os.path.join(checkpoint_last, 'step_file.txt') 502 | if os.path.exists(step_file): 503 | with open(step_file, encoding='utf-8') as stepf: 504 | args.start_step = int(stepf.readlines()[0].strip()) 505 | 506 | logger.info("reload model from {}, resume from {} epoch".format(checkpoint_last, args.start_epoch)) 507 | 508 | config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type] 509 | config = config_class.from_pretrained(args.config_name if args.config_name else args.model_name_or_path, 510 | cache_dir=args.cache_dir if args.cache_dir else None) 511 | config.num_labels=1 512 | tokenizer = tokenizer_class.from_pretrained(args.tokenizer_name, 513 | do_lower_case=args.do_lower_case, 514 | cache_dir=args.cache_dir if args.cache_dir else None) 515 | if args.block_size <= 0: 516 | 
args.block_size = tokenizer.max_len_single_sentence # Our input block size will be the max possible for the model 517 | args.block_size = min(args.block_size, tokenizer.max_len_single_sentence) 518 | if args.model_name_or_path: 519 | model = model_class.from_pretrained(args.model_name_or_path, 520 | from_tf=bool('.ckpt' in args.model_name_or_path), 521 | config=config, 522 | cache_dir=args.cache_dir if args.cache_dir else None) 523 | else: 524 | model = model_class(config) 525 | 526 | model=Model(model,config,tokenizer,args) 527 | if args.local_rank == 0: 528 | torch.distributed.barrier() # End of barrier to make sure only the first process in distributed training download model & vocab 529 | 530 | logger.info("Training/evaluation parameters haha jk back%s", args) 531 | if args.do_train: 532 | if args.local_rank not in [-1, 0]: 533 | torch.distributed.barrier() # Barrier to make sure only the first process in distributed training process the dataset, and the others will use the cache 534 | train_dataset = TextDataset(tokenizer, args,{'py':args.py_train_data_file,'js':args.js_train_data_file}) 535 | if args.local_rank == 0: 536 | torch.distributed.barrier() 537 | 538 | train(args, train_dataset, model, tokenizer) 539 | 540 | 541 | if __name__ == "__main__": 542 | main() 543 | 544 | 545 | -------------------------------------------------------------------------------- /search/scripts/train.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | echo "begin training" 3 | python search/code/run.py \ 4 | --output_dir=./saved_search \ 5 | --model_type=roberta \ 6 | --config_name=microsoft/codebert-base \ 7 | --model_name_or_path=codebert \ 8 | --tokenizer_name=roberta-base \ 9 | --do_train \ 10 | --py_train_data_file=dataset/python/train \ 11 | --py_eval_data_file=dataset/python/valid \ 12 | --py_test_data_file=dataset/python/test \ 13 | --js_train_data_file=dataset/javascript/train \ 14 | --js_eval_data_file=dataset/javascript/valid \ 15 | --js_test_data_file=dataset/javascript/test \ 16 | --epoch 10 \ 17 | --block_size 300 \ 18 | --train_batch_size 70 \ 19 | --eval_batch_size 70\ 20 | --learning_rate 5e-5 \ 21 | --max_grad_norm 1.0 \ 22 | --evaluate_during_training \ 23 | --seed 123456 2>&1| tee train.log -------------------------------------------------------------------------------- /search_demo_js.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sdpmas/Codeon/917beea6ef25e9dd213b43620c28859892e927e8/search_demo_js.gif -------------------------------------------------------------------------------- /search_demo_py.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sdpmas/Codeon/917beea6ef25e9dd213b43620c28859892e927e8/search_demo_py.gif --------------------------------------------------------------------------------
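For reference, here is a minimal sketch of how the pieces above fit together at query time: embed a query with the fine-tuned search model, look up neighbors in the serialized ScaNN index, and print the matching snippets with their source URLs. This is illustrative only (the VSCode extension's backend is not part of this repository); paths follow the defaults used in search/code/make_index.py and search/code/add_search.py, and the example query string is an assumption.

```python
import pickle
import torch
import scann
from transformers import RobertaConfig, RobertaModel, RobertaTokenizer
from search.code.model import Model

# Load the fine-tuned CodeBERT search model (saved by search/code/run.py).
config = RobertaConfig.from_pretrained('codebert')
tokenizer = RobertaTokenizer.from_pretrained('codebert')
encoder = RobertaModel.from_pretrained('codebert', config=config)
model = Model(encoder, config, tokenizer)
model.load_state_dict(torch.load('saved_search/checkpoint-best-mrr/model.bin', map_location='cpu'))
model.eval()

# Load the ScaNN index and the codebase built by search/code/make_index.py.
searcher = scann.scann_ops_pybind.load_searcher('data/scann_searcher')
codebase = pickle.load(open('data/codebase.bin', 'rb'))

# Embed the query the same way the training code formats it, then search.
query = 'Language: Python NL: read a json file'
with torch.no_grad():
    embed = model.get_representation_one(query, device='cpu').numpy()
neighbors, _ = searcher.search(embed)
for idx in neighbors:
    entry = codebase[int(idx)]
    print(entry['url'])
    print(entry['code'][:200])
```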