├── .gitignore ├── Annotation_Guidelines.md ├── Makefile ├── README.md ├── codet5_finetune.py ├── configs └── preproc_config.json ├── data ├── augs.py ├── code_block_parse.py ├── codex_util.py ├── extract_aligned_pairs.py ├── helper.py ├── pipeline.py └── spelling_en.txt ├── data_cleaning.py ├── datadump.py ├── instruct_augment_code_review ├── augment_code_review.py ├── compute_stats.py ├── filter_data.py ├── logger.py └── utils.py ├── models └── backtranslation.py ├── old └── scrape.py ├── requirements.txt ├── stats └── tags.json ├── t0_finetune.py ├── trace_data ├── inject_locals.py └── process_folder.py ├── utils ├── __init__.py └── parser │ ├── build_parser.py │ └── lang.json └── xml_to_json.py /.gitignore: -------------------------------------------------------------------------------- 1 | dataset 2 | *.pyc 3 | utils/parser/tmp 4 | utils/parser/build 5 | wandb/ 6 | checkpoints/ 7 | final/ 8 | preds/ 9 | -------------------------------------------------------------------------------- /Annotation_Guidelines.md: -------------------------------------------------------------------------------- 1 | # Annotation guidelines for the Code Review dataset 2 | 3 | 4 | Thank you for your help annotating this data. For context, we have a dataset of questions about code and critiques of that code, with refined code included in the answers. We hope to generate a cleaner, more structured subset of this dataset. Given the full question and answer text, we ask you to extract the original question text, the original code, the textual critiques of the code, and the refined code. Such data would be used to train models that can take some code with a description and produce a critique and improved code. 5 | 6 | Here are some notes on the format of this dataset: 7 | 8 | The original text will include some description of the code and either a general request for review/critique or a specific review. Make sure to include all this text. Sometimes additional information like example outputs may be included. You can include it if it is not too long. While minimal code is fine in the "Question Text" field, we want to limit it. Another important aspect is that sometimes code from multiple files is provided in the question. Make sure the filenames and descriptions of the files are captured. 9 | 10 | Extract only the code that is crucial to the question and is being reviewed. If there are multiple code blocks (even if indicated to be from multiple files), please merge/concatenate them together. 11 | 12 | The critique will often come in many different forms. Here is a brief overview of how to address some of the common types. 13 | 14 | One common form is a general comment like this: 15 | 16 | > Python has what you want built into the standard library: see the multiprocessing module, and in particular the map method of the Pool class. 17 | 18 | So you can keep that in the critique text you extract. Sometimes the critiques make some reference to the question text, maybe even to a couple of lines of code. It is fine to include a few lines of code, but once again the textual critique should contain a minimal amount of code. 19 | 20 | Some critiques go through various snippets of the code and write block-by-block critiques relevant to each. In these cases, take each of the individual critiques and merge/concatenate them together. You may need to reword a critique a little bit so it's a little clearer.
For example, if a line says "you should use a list comprehension here, so it looks like [SOME UPDATED CODE]", it can be rewritten as "you should use a list comprehension in function 'whatever_the_function_is_called'". Sometimes the code block is short enough that it is simply more convenient to quote it in the text; in that case it is up to you whether to include it, but keep in mind that overall we want to avoid a lot of code in the critiques, whether as large code blocks or as numerous code snippets one after another. 21 | 22 | Don't include any updated code; that's for the next section. If there is any text describing or discussing the updated code, you can ignore it. 23 | 24 | For the "refined code", sometimes a large code block of refined code is provided in the answer and you can directly extract it. Sometimes it is separated into smaller blocks. For example, the answer may go through one code block, critique it, provide a refined version of it, go through the next code block, etc. In that case, merge the refined code blocks to get the full program. Sometimes no improved code is possible. If there is enough information in the answer and it's not too difficult, you can write your own improved code. Sometimes, though, not enough is provided, for example if the critique asks for docstrings but the answer doesn't give enough information about what the code does. 25 | 26 | This is a very unstructured dataset (this is an attempt to structure it, after all), so you may have to make some judgement calls. Just remember: you want to provide enough information, and not any more, for a future model to learn from. 27 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | all: Badges.xml 2 | 3 | # Badges.xml Comments.xml PostHistory.xml PostLinks.xml Posts.xml Tags.xml Users.xml Votes.xml 4 | 5 | Badges.xml: codereview.stackexchange.com.7z 6 | [ !
-f Badges.xml ] && 7za x codereview.stackexchange.com.7z || true 7 | 8 | codereview.stackexchange.com.7z: 9 | wget -nc https://archive.org/download/stackexchange/codereview.stackexchange.com.7z 10 | 11 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # CodeReviewSE 2 | 3 | Project details: 4 | https://docs.google.com/document/d/1p81LMi_ievC7arVVXOVWxFRjiOWW7Cd2Vp7e1pWnavc 5 | -------------------------------------------------------------------------------- /codet5_finetune.py: -------------------------------------------------------------------------------- 1 | from transformers import SchedulerType, get_scheduler, set_seed, AutoConfig, AutoModelForSeq2SeqLM, AutoTokenizer 2 | from torch.optim import AdamW 3 | from datasets import Dataset, load_metric 4 | import torch 5 | import torch.nn.functional as F 6 | import pandas as pd 7 | from torch.utils.data import DataLoader 8 | from accelerate import Accelerator 9 | from functools import partial 10 | from data.helper import load_json_file 11 | from tqdm.auto import tqdm 12 | import os 13 | import argparse 14 | import math 15 | import numpy as np 16 | import nltk 17 | nltk.download('punkt') 18 | 19 | def parse_args(): 20 | parser = argparse.ArgumentParser() 21 | parser.add_argument('--dataset', type=str, default='dataset/CodeReviewSE_clean_QA.json') 22 | parser.add_argument('--batch_size', type=int, default=32) 23 | parser.add_argument('--eval_batch_size', type=int, default=64) 24 | parser.add_argument('--num_workers', type=int, default=16) 25 | parser.add_argument('--num_train_epochs', type=int, default=10) 26 | parser.add_argument('--learning_rate', type=float, default=3e-5) 27 | parser.add_argument('--model_name', type=str, default='Salesforce/codet5-base') 28 | parser.add_argument('--max_input_length', type=int, default=512) 29 | parser.add_argument('--max_target_length', type=int, default=256) 30 | parser.add_argument('--gradient_accumulation_steps', type=int, default=1) 31 | parser.add_argument('--weight_decay', type=float, default=0.01) 32 | parser.add_argument('--num_warmup_steps', type=int, default=10) 33 | parser.add_argument('--lr_scheduler_type', type=SchedulerType, default="linear", choices=["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"]) 34 | parser.add_argument('--validation_split', type=float, default=0.05) 35 | parser.add_argument('--seed', type=int, default=42) 36 | #parser.add_argument('--gpu_id', type=int, default=0) 37 | parser.add_argument('--wandb_project', type=str, default='finetune-codet5-for-codereview') 38 | parser.add_argument('--num_update_steps_per_epoch', type=int, default=100) 39 | parser.add_argument('--checkpointing_frequency', type=int, default=1) 40 | 41 | return parser.parse_args() 42 | 43 | 44 | 45 | 46 | def preprocess_examples(examples, tokenizer, max_input_length, max_target_length): 47 | # encode the question-answer pairs 48 | question = examples['question'] 49 | answer = examples['answer'] 50 | 51 | model_inputs = tokenizer(question, max_length=max_input_length, padding="max_length", truncation=True) 52 | labels = tokenizer(answer, max_length=max_target_length, padding="max_length", truncation=True).input_ids 53 | 54 | # important: we need to replace the index of the padding tokens by -100 55 | # such that they are not taken into account by the CrossEntropyLoss 56 | labels_with_ignore_index = [] 57 | for labels_example in 
labels: 58 | labels_example = [label if label != 0 else -100 for label in labels_example] 59 | labels_with_ignore_index.append(labels_example) 60 | 61 | model_inputs["labels"] = labels_with_ignore_index 62 | 63 | return model_inputs 64 | 65 | 66 | def calc_metric(metric, predictions, labels): 67 | decoded_preds = ["\n".join(nltk.sent_tokenize(pred.strip())) for pred in predictions] 68 | decoded_labels = ["\n".join(nltk.sent_tokenize(label.strip())) for label in labels] 69 | result = metric.compute(predictions=decoded_preds, references=decoded_labels) # ROUGE expects newline-separated sentences, so score the sentence-split versions 70 | result = {key: value.mid.fmeasure * 100 for key, value in result.items()} 71 | return result 72 | 73 | if __name__ == "__main__": 74 | args = parse_args() 75 | set_seed(args.seed) 76 | 77 | accelerator = Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps) 78 | 79 | config = AutoConfig.from_pretrained(args.model_name) 80 | model = AutoModelForSeq2SeqLM.from_pretrained(args.model_name, config=config) 81 | tokenizer = AutoTokenizer.from_pretrained(args.model_name) 82 | 83 | dataset = load_json_file(args.dataset) 84 | dataset = Dataset.from_pandas(pd.DataFrame(data=dataset)) 85 | dataset = dataset.map(partial(preprocess_examples, tokenizer=tokenizer, max_input_length=args.max_input_length, max_target_length=args.max_target_length), batched=True, num_proc=16) 86 | dataset.set_format(type="torch", columns=['input_ids', 'attention_mask', 'labels']) 87 | dataset = dataset.train_test_split(test_size=args.validation_split, seed=args.seed) 88 | train_dataloader = DataLoader(dataset['train'], shuffle=True, batch_size=args.batch_size, num_workers=args.num_workers) 89 | val_dataloader = DataLoader(dataset['test'], shuffle=False, batch_size=args.eval_batch_size, num_workers=args.num_workers) 90 | 91 | # Optimizer 92 | # Split weights in two groups, one with weight decay and the other not.
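# (Excluding biases and LayerNorm weights from weight decay is the standard convention when fine-tuning transformers: decaying these parameters adds little regularization benefit and can hurt optimization.)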
93 | no_decay = ["bias", "LayerNorm.weight"] 94 | optimizer_grouped_parameters = [ 95 | { 96 | "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 97 | "weight_decay": args.weight_decay, 98 | }, 99 | { 100 | "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 101 | "weight_decay": 0.0, 102 | }, 103 | ] 104 | optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate) 105 | 106 | # Scheduler 107 | lr_scheduler = get_scheduler( 108 | name=args.lr_scheduler_type, 109 | optimizer=optimizer, 110 | num_warmup_steps=args.num_warmup_steps, 111 | num_training_steps=args.num_train_epochs * len(train_dataloader) 112 | ) 113 | 114 | # metric 115 | metric = load_metric('rouge') # hard-coded to ROUGE 116 | 117 | # accelerator 118 | model, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(model, optimizer, train_dataloader, val_dataloader, lr_scheduler) 119 | 120 | # wandb 121 | if bool(args.wandb_project) and accelerator.is_main_process: 122 | import wandb 123 | wandb.init(project=args.wandb_project, config=args) 124 | 125 | max_train_steps = len(train_dataloader)*args.num_train_epochs 126 | progress_bar = tqdm(range(max_train_steps)) 127 | global_steps = 0 128 | 129 | # Train the model 130 | for epoch in range(args.num_train_epochs): 131 | accelerator.print("Epoch: {}".format(epoch)) 132 | model.train() 133 | for step, batch in enumerate(train_dataloader): 134 | input_ids = batch["input_ids"] 135 | attention_mask = batch["attention_mask"] 136 | labels = batch["labels"] 137 | with accelerator.accumulate(model): 138 | outputs = model(input_ids, attention_mask=attention_mask, labels=labels) 139 | loss = outputs.loss 140 | accelerator.backward(loss) 141 | optimizer.step() 142 | lr_scheduler.step() 143 | optimizer.zero_grad() 144 | progress_bar.update(1) 145 | global_steps += 1 146 | loss = loss.item() 147 | if bool(args.wandb_project) and accelerator.is_main_process: 148 | wandb.log({"loss": loss}) 149 | 150 | if (step + 1) % args.num_update_steps_per_epoch == 0: 151 | accelerator.print("Step: {}/{}".format(global_steps, max_train_steps)) 152 | accelerator.print("Loss: {}".format(loss)) 153 | accelerator.print("LR: {}".format(optimizer.param_groups[0]["lr"])) 154 | accelerator.print("\n") 155 | 156 | 157 | # evaluate on validation set 158 | model.eval() 159 | all_input = [] 160 | all_preds = [] 161 | all_labels = [] 162 | for step, batch in tqdm(enumerate(val_dataloader),total=len(val_dataloader)): 163 | input_ids = batch["input_ids"] 164 | attention_mask = batch["attention_mask"] 165 | labels = batch["labels"] 166 | 167 | # valid loss 168 | with torch.no_grad(): 169 | outputs = model(input_ids, attention_mask=attention_mask, labels=labels) 170 | val_loss = outputs.loss.item() 171 | 172 | # text generation 173 | generated_ids = accelerator.unwrap_model(model).generate( 174 | input_ids = input_ids, 175 | attention_mask = attention_mask, 176 | max_length=150, 177 | num_beams=2, 178 | repetition_penalty=2.5, 179 | length_penalty=1.0, 180 | early_stopping=True 181 | ) 182 | 183 | 184 | generated_ids = accelerator.pad_across_processes( 185 | generated_ids, dim=1, pad_index=tokenizer.pad_token_id 186 | ) 187 | 188 | input_ids, generated_ids, labels = accelerator.gather((input_ids, generated_ids, labels)) 189 | input_ids = input_ids.cpu().numpy() 190 | generated_ids = generated_ids.cpu().numpy() 191 | labels = labels.cpu().numpy() 192 | 193 | decoded_input = tokenizer.batch_decode(input_ids,
skip_special_tokens=True) 194 | decoded_preds = tokenizer.batch_decode(generated_ids, skip_special_tokens=True) 195 | # Replace -100 in the labels as we can't decode them. 196 | labels = np.where(labels != -100, labels, tokenizer.pad_token_id) 197 | decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True) 198 | all_input += decoded_input 199 | all_preds += decoded_preds 200 | all_labels += decoded_labels 201 | 202 | accelerator.wait_for_everyone() 203 | 204 | # evaluate 205 | if accelerator.is_main_process: 206 | eval_metric = calc_metric(metric, all_preds, all_labels) 207 | accelerator.print('Metric: ', eval_metric) 208 | 209 | # checkpoints 210 | if epoch % args.checkpointing_frequency == 0: 211 | accelerator.save_state(f'checkpoints/epoch_{epoch}') 212 | 213 | # save predictions 214 | 215 | if accelerator.is_main_process: 216 | if not os.path.exists('preds'): os.makedirs('preds') 217 | preds_df = pd.DataFrame({"input": all_input, "preds": all_preds, "labels": all_labels}) 218 | preds_df.to_json(f"preds/epoch_{epoch}.json", orient="split") 219 | 220 | accelerator.wait_for_everyone() 221 | 222 | if bool(args.wandb_project) and accelerator.is_main_process: 223 | wandb.log({"val_loss": val_loss}) 224 | wandb.log({'validation predictions': wandb.Table(dataframe=preds_df.head(1000))}) 225 | wandb.log(eval_metric) 226 | wandb.save('preds/epoch_{}.json'.format(epoch)) 227 | wandb.save('checkpoints/epoch_{}/*'.format(epoch)) 228 | 229 | accelerator.print("Valid loss: {}".format(val_loss)) 230 | accelerator.print("\n") 231 | 232 | accelerator.wait_for_everyone() 233 | 234 | if bool(args.wandb_project) and accelerator.is_main_process: 235 | accelerator.unwrap_model(model).save_pretrained('final') 236 | wandb.save('final/*') 237 | wandb.finish() 238 | 239 | 240 | -------------------------------------------------------------------------------- /configs/preproc_config.json: -------------------------------------------------------------------------------- 1 | [ { 2 | "name": "SpellingAug", 3 | "params": { 4 | "spelling_dict":"data/spelling_en.txt", 5 | "include_reverse": true 6 | } 7 | } 8 | ] -------------------------------------------------------------------------------- /data/augs.py: -------------------------------------------------------------------------------- 1 | """ 2 | Augmentations and factory function for augmentations. 3 | 4 | """ 5 | 6 | from helper import * 7 | import torch 8 | import numpy as np 9 | from typing import Any, Dict 10 | import sys 11 | import json 12 | import multiprocessing 13 | from tqdm import tqdm 14 | from tqdm.contrib.concurrent import process_map 15 | from functools import partial 16 | from models.backtranslation import BackTranslationModel # needed by BackTranslationAug below; assumes the repo root is on sys.path 17 | 18 | # specifies a dictionary of augmentations 19 | _AUGS: Dict[str, Any] = {} # registry 20 | 21 | 22 | def register_aug(name): 23 | """Decorator used to register an augmentation 24 | Args: 25 | name: Name of the augmentation 26 | """ 27 | 28 | def register_class(cls, name): 29 | _AUGS[name] = cls 30 | setattr(sys.modules[__name__], name, cls) 31 | return cls 32 | 33 | if isinstance(name, str): 34 | name = name.lower() 35 | return lambda c: register_class(c, name) 36 | 37 | cls = name 38 | name = cls.__name__ 39 | register_class(cls, name.lower()) 40 | 41 | return cls 42 | 43 | class RandomAug: 44 | """ 45 | Randomly apply one of the augmentations.
46 | """ 47 | 48 | def __init__(self, p=0.5): 49 | self.p = p 50 | 51 | def __call__(self, data): 52 | # applies rand_apply to each element; multiprocessing is handled by the caller (see ApplyAugs) 53 | return list(map(self.rand_apply, data)) 54 | 55 | def rand_apply(self, datum): 56 | """ 57 | Applies `apply` to the datum with probability p, sampled via np.random.rand() 58 | """ 59 | if np.random.rand(1) < self.p: 60 | return self.apply(datum) 61 | else: 62 | return datum 63 | 64 | 65 | def apply(self, data): 66 | """ 67 | Apply augmentation to the data. 68 | """ 69 | raise NotImplementedError 70 | 71 | def get_aug(name): 72 | return _AUGS[name.lower()] 73 | 74 | 75 | def get_aug_names(): 76 | return _AUGS.keys() 77 | 78 | 79 | class Compose: 80 | """ 81 | Compose several augmentations together. Based on torchvision's `Compose` 82 | """ 83 | 84 | def __init__(self, composition_path): 85 | # composition_path refers to a json file that contains the augmentations we're using 86 | self.augs = [] 87 | composition = load_json_file(composition_path) 88 | for c in composition: 89 | self.augs.append(get_aug(c['name'])(**c['params'])) 90 | 91 | def __call__(self, data): 92 | for t in self.augs: 93 | data = t(data) 94 | return data 95 | 96 | def __repr__(self): 97 | format_string = self.__class__.__name__ + "(" 98 | for t in self.augs: 99 | format_string += "\n" 100 | format_string += " {0}".format(t) 101 | format_string += "\n)" 102 | return format_string 103 | 104 | class ApplyAugs: 105 | def __init__(self, augs): 106 | assert isinstance(augs, Compose), "augs must be composed together with `Compose`" 107 | self.augs = augs 108 | 109 | 110 | def _call_for_questions(self, data): 111 | self.keys = list(data.keys()) 112 | 113 | # print the number of questions 114 | print("There are: ") 115 | print(str(len(self.keys)) + " questions") 116 | self.keys = [k for k in self.keys if '_q' in k] 117 | print(self.keys) 118 | 119 | # pass augmentations to iter_body 120 | iter_body_local = partial(iter_body, self.augs) 121 | 122 | # get each question's body 123 | bodies = list(map(lambda x: data[x]['body'], self.keys)) 124 | 125 | # iterate over the bodies using multiprocessing 126 | with multiprocessing.Pool(multiprocessing.cpu_count()) as p: 127 | with tqdm(total=len(bodies)) as pbar: 128 | for i, body in enumerate(p.imap(iter_body_local, bodies)): 129 | bodies[i] = body 130 | # update the progress bar 131 | pbar.update() 132 | 133 | # return the source sequence 134 | return bodies 135 | 136 | def _call_for_answers(self, data): 137 | self.keys = list(data.keys()) 138 | self.keys = [k for k in self.keys if '_ans' in k] 139 | print(self.keys) 140 | 141 | iter_answers_local = partial(iter_body, self.augs) 142 | answer_bodies = list() 143 | 144 | # get each answer's body 145 | for question in tqdm(self.keys): 146 | for answer in tqdm(data[question]['answers']): 147 | answer_bodies.append(answer['body']) 148 | 149 | # iterate over the bodies using multiprocessing 150 | with multiprocessing.Pool(multiprocessing.cpu_count()) as p: 151 | with tqdm(total=len(answer_bodies)) as pbar: 152 | for i, body in enumerate(p.imap(iter_answers_local, answer_bodies)): 153 | answer_bodies[i] = body 154 | pbar.update() 155 | 156 | # return the target sequence 157 | return answer_bodies 158 | 159 | # copies back the output of __call__ to the original data 160 | def copy_back_question_bodies(self, outputs, data): 161 | """ 162 | outputs is a list of augmented question bodies, ordered to match self.keys.
163 | Args: 164 | outputs: list of str 165 | data: dict 166 | Returns: 167 | data: dict 168 | """ 169 | for idx, question in enumerate(self.keys): 170 | data[question]['body'] = outputs[idx] 171 | 172 | 173 | self.orig_keys = [k for k in data.keys() if '_q' not in k] 174 | 175 | for k in self.orig_keys: 176 | data[k]['body'] = ' '.join(parse_html_to_str(data[k]['body'])) 177 | 178 | return data 179 | # copies back the output of __call__ to the original data 180 | def copy_back_answer_bodies(self, outputs, data): 181 | """ 182 | outputs is a flat list of augmented answer bodies, in the order produced by _call_for_answers. 183 | Args: 184 | outputs: list of str 185 | data: dict 186 | Returns: 187 | data: dict 188 | """ 189 | 190 | idx = 0 191 | for question in self.keys: 192 | for answer in data[question]['answers']: 193 | answer['body'] = outputs[idx] 194 | idx += 1 195 | 196 | self.orig_keys = [k for k in data.keys() if '_ans' not in k] 197 | 198 | for k in self.orig_keys: 199 | for answer in data[k]['answers']: 200 | answer['body'] = ' '.join(parse_html_to_str(answer['body'])) 201 | 202 | return data 203 | 204 | def __call__(self, data, for_question=True): 205 | if for_question: 206 | return self.copy_back_question_bodies(self._call_for_questions(data), data) 207 | else: 208 | return self.copy_back_answer_bodies(self._call_for_answers(data), data) 209 | 210 | @register_aug 211 | class KeyboardAug(RandomAug): 212 | pass # stub: relies on RandomAug.apply, which raises NotImplementedError 213 | 214 | 215 | @register_aug 216 | class SpellingAug(RandomAug): 217 | def __init__(self, spelling_dict, include_reverse=True, p=0.5): 218 | super().__init__(p) 219 | self.spelling_dict = spelling_dict if type(spelling_dict) == dict else self.load_spelling_dict(spelling_dict, include_reverse) 220 | 221 | def load_spelling_dict(self, file_path, include_reverse=True): 222 | """ 223 | Loads the spelling dictionary from the file. 224 | """ 225 | spelling_dict = {} 226 | with open(file_path, 'r', encoding="utf-8") as f: 227 | for line in f.readlines(): 228 | tokens = line.split(' ') 229 | # Last token includes the newline separator 230 | tokens[-1] = tokens[-1].replace('\n', '') 231 | 232 | key = tokens[0] 233 | values = tokens[1:] 234 | 235 | if key not in spelling_dict: 236 | spelling_dict[key] = [] 237 | 238 | spelling_dict[key].extend(values) 239 | # Remove duplicate mapping 240 | spelling_dict[key] = list(set(spelling_dict[key])) 241 | # Build reverse mapping 242 | if include_reverse: 243 | for value in values: 244 | if value not in spelling_dict: 245 | spelling_dict[value] = [] 246 | if key not in spelling_dict[value]: 247 | spelling_dict[value].append(key) 248 | return spelling_dict 249 | 250 | def apply(self, i): 251 | """ 252 | Apply augmentation to the element.
253 | """ 254 | words = i.split() 255 | rands = np.random.randint(2, size=len(words)) 256 | for idx, word in enumerate(words): 257 | # Replace the word with one of its spelling variants 258 | if word not in self.spelling_dict: 259 | continue 260 | else: 261 | words[idx] = self.spelling_dict[word][min(len(self.spelling_dict[word]) - 1, rands[idx])] 262 | return ' '.join(words) 263 | 264 | @register_aug 265 | class BackTranslationAug(RandomAug): 266 | def __init__(self, src_model_name, tgt_model_name, device='cpu', batch_size=32, max_length=300, p=0.5): 267 | super().__init__(p) 268 | self.model = BackTranslationModel(src_model_name, tgt_model_name, device, batch_size, max_length) 269 | 270 | def apply(self, i): 271 | return self.model.translate(i) 272 | 273 | 274 | # This processes the first question with the augmentation pipeline for testing purposes 275 | if __name__ == "__main__": 276 | print(get_aug_names()) 277 | data = load_json_file("dataset/CodeReviewSE.json") 278 | augs = Compose("configs/preproc_config.json") 279 | augs = ApplyAugs(augs) 280 | 281 | # get a subset of data with the first question (too slow for whole dataset, need multiprocessing speedup) 282 | data = {k: data[k] for k in list(data.keys())[:1]} 283 | 284 | data = duplicate_data(data, is_ans=False, n=1) # duplicate for augmenting questions 285 | data = augs(data, for_question=True) # apply augmentations to questions 286 | data = duplicate_data(data, is_ans=True, n=1) # duplicate for augmenting answers 287 | data = augs(data, for_question=False) # apply augmentations to answers 288 | 289 | # print the augmented data 290 | print(data['1']['body']) 291 | print(data['1_q0']['body']) 292 | print(data['1_q0_ans0']['body']) 293 | print(data['1']['answers'][0]['body']) 294 | print(data['1_q0']['answers'][0]['body']) 295 | print(data['1_q0_ans0']['answers'][0]['body']) 296 | 297 | 298 | -------------------------------------------------------------------------------- /data/code_block_parse.py: -------------------------------------------------------------------------------- 1 | import tokenize 2 | import keyword 3 | from io import BytesIO 4 | 5 | keyword_list = keyword.kwlist 6 | 7 | SET_DELIM = " \n" # Delimiter for code blocks 8 | 9 | def tokenize_code_snippet(code_snippet:str,exclude_keywords=True): 10 | """ 11 | Tokenize a code snippet 12 | """ 13 | try: 14 | tokens = tokenize.tokenize(BytesIO(code_snippet.encode("utf-8")).readline) 15 | if exclude_keywords: 16 | tokenized = [token.string for token in tokens if token.string not in keyword_list and len(token.string) != 0][1:-2] # drop keywords and empty tokens 17 | else: 18 | tokenized = [token.string for token in tokens if len(token.string) != 0][1:-2] 19 | except tokenize.TokenError: 20 | tokenized = None 21 | return tokenized 22 | 23 | def jaccard_similarity(code_snippet_1:str,code_snippet_2:str): 24 | """ 25 | Computes Jaccard Similarity between two lists of tokens. 26 | """ 27 | intersection = len(list(set(code_snippet_1).intersection(code_snippet_2))) 28 | union = (len(set(code_snippet_1)) + len(set(code_snippet_2))) - intersection 29 | return float(intersection) / union 30 | 31 | 32 | def merge_or_ignore(code_block_list:list[str],similarity_threshold:float): 33 | """ 34 | Merge code blocks if they are similar.
35 | args: 36 | code_block_list (list[str]): list of code blocks 37 | similarity_threshold (float): threshold for similarity 38 | returns: 39 | merged_frozen_set (frozenset[str]): the merged code blocks 40 | merge_list (list[list[int]]): index pairs mapping the merged code blocks back to the original code blocks 41 | """ 42 | merged_code_blocks = [] 43 | skip_ind = [] 44 | merge_list = [] 45 | for code_block_ind_1 in range(len(code_block_list)): 46 | for code_block_ind_2 in range(code_block_ind_1,len(code_block_list)): 47 | tok_code_block_1 = tokenize_code_snippet(code_block_list[code_block_ind_1],exclude_keywords=False) 48 | tok_code_block_2 = tokenize_code_snippet(code_block_list[code_block_ind_2],exclude_keywords=False) 49 | if tok_code_block_1 != tok_code_block_2 and code_block_ind_1 not in skip_ind and code_block_ind_2 not in skip_ind: 50 | if tok_code_block_1 is not None and tok_code_block_2 is not None: 51 | sim = jaccard_similarity(tok_code_block_1,tok_code_block_2) 52 | if sim < similarity_threshold: 53 | skip_ind.append(code_block_ind_2) 54 | skip_ind.append(code_block_ind_1) 55 | merge_list.append([code_block_ind_1,code_block_ind_2]) 56 | merged_code_blocks.append(SET_DELIM.join(sorted([code_block_list[code_block_ind_1],code_block_list[code_block_ind_2]]))) 57 | else: 58 | merged_code_blocks.append(code_block_list[code_block_ind_1]) 59 | merged_code_blocks.append(code_block_list[code_block_ind_2]) 60 | else: 61 | # If either of the pair is unparsable, skip the merge. 62 | merged_code_blocks.append(code_block_list[code_block_ind_1]) 63 | merged_code_blocks.append(code_block_list[code_block_ind_2]) 64 | #Check for missed ind blocks 65 | for ind in range(len(code_block_list)): 66 | if ind not in skip_ind: 67 | merged_code_blocks.append(code_block_list[ind]) 68 | return frozenset(merged_code_blocks),merge_list 69 | 70 | 71 | 72 | 73 | 74 | if __name__ == "__main__": 75 | code_blocks_hash = ["make_data = lambda x: x","dataset = data.append(_)","make_data = lambda x: x-1","dataset = None"] 76 | # print([tokenize_code_snippet(i) for i in code_blocks_hash]) 77 | # print(jaccard_similarity(tokenize_code_snippet(code_blocks_hash[0]),tokenize_code_snippet(code_blocks_hash[1]))) 78 | print(merge_or_ignore(code_blocks_hash,0.1)) -------------------------------------------------------------------------------- /data/codex_util.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import openai 4 | import ast 5 | import sys 6 | import subprocess 7 | from helper import parse_body_to_return,get_accepted_answer 8 | from pipeline import load_json_file 9 | import logging 10 | 11 | from bs4 import BeautifulSoup 12 | 13 | 14 | #Need the environment to have openai api key beforehand. 15 | openai.api_key = os.getenv("OPEN_AI_API_KEY") 16 | 17 | 18 | class SimpleClassVisitor(ast.NodeVisitor): 19 | def __init__(self): 20 | self.function_names = [] 21 | 22 | def visit_FunctionDef(self, node): 23 | self.function_names.append(node.name) 24 | ast.NodeVisitor.generic_visit(self, node) 25 | 26 | 27 | 28 | 29 | 30 | def parse_to_get_function_name_list(code_snippet:str): 31 | """ 32 | Return function names in a given python code snippet. 33 | """ 34 | visitor = SimpleClassVisitor() 35 | parsed = ast.parse(code_snippet) 36 | visitor.visit(parsed) 37 | return visitor.function_names 38 | 39 | 40 | def generated_codex_code_test_case(code_snippet:str,**kwargs)-> dict: 41 | """ 42 | Generate test cases for a given python code snippet using the Codex edit API.
43 | 44 | args: 45 | code_snippet: str 46 | The python code snippet to be augmented with test cases. 47 | 48 | returns: 49 | dict: The raw Codex response and the prompt used ('response' and 'prompt' keys; both None on failure). 50 | 51 | """ 52 | try: 53 | function_name_list = parse_to_get_function_name_list(code_snippet) 54 | function_name_prompt = ", ".join(function_name_list) 55 | prompt_format = f"#Generate test cases for {function_name_prompt} function arguments.\n" 56 | response = openai.Edit.create( 57 | engine="code-davinci-edit-001", 58 | input=code_snippet, 59 | instruction=prompt_format, 60 | temperature = 0.7, 61 | top_p = 1, 62 | **kwargs 63 | ) 64 | response_dict = { 65 | "response" :response, 66 | "prompt" : prompt_format 67 | } 68 | return response_dict 69 | except Exception: 70 | return { 71 | "response" : None, 72 | "prompt" : None 73 | } 74 | 75 | def generated_codex_code_type_inference(code_snippet:str,**kwargs)-> dict: 76 | """ 77 | Annotate the functions in a given python code snippet with mypy types using the Codex edit API. 78 | 79 | args: 80 | code_snippet: str 81 | The python code snippet to be annotated with types. 82 | 83 | returns: 84 | dict: The raw Codex response and the prompt used ('response' and 'prompt' keys; both None on failure). 85 | 86 | """ 87 | try: 88 | function_name_list = parse_to_get_function_name_list(code_snippet) 89 | function_name_prompt = ", ".join(function_name_list) 90 | prompt_format = f"Annotate mypy types for {function_name_prompt} function arguments.\n" 91 | response = openai.Edit.create( 92 | engine="code-davinci-edit-001", 93 | input=code_snippet, 94 | instruction=prompt_format, 95 | temperature = 0.7, 96 | top_p = 1, 97 | **kwargs 98 | ) 99 | response_dict = { 100 | "response" :response, 101 | "prompt" : prompt_format 102 | } 103 | return response_dict 104 | except Exception: 105 | return { 106 | "response" : None, 107 | "prompt" : None 108 | } 109 | 110 | 111 | class TypeChecker: 112 | def __init__(self) -> None: 113 | self.tmp_dir = "tmp/" 114 | os.makedirs(self.tmp_dir, exist_ok=True) 115 | 116 | def check_runtime(self,code_snippet:str): 117 | raise NotImplementedError 118 | 119 | def eval_type_file(self,code_snippet:str)->str: 120 | """ 121 | Type-check a code snippet with mypy and return the captured stdout 122 | """ 123 | file_dir = os.path.join(self.tmp_dir,"tmp.py") 124 | with open(file_dir, "w") as f: 125 | f.write(code_snippet) 126 | command = ["mypy","--ignore-missing-imports",file_dir] 127 | result = subprocess.run(command,stdout=subprocess.PIPE).stdout.decode('utf-8')#(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True) 128 | os.remove(file_dir) 129 | return result 130 | 131 | 132 | 133 | class ProcessDataset: 134 | """ 135 | Process Dataset for Async Augmentations 136 | """ 137 | def __init__(self,dataset_path:str="dataset/CodeReviewSE.json") -> None: 138 | self.dataset_path = dataset_path 139 | self.data = load_json_file(self.dataset_path) 140 | self.type_checker = TypeChecker() 141 | def __call__(self,index:int=10): 142 | """ 143 | Given an index, return the corresponding data entry.
144 | """ 145 | data_point = self.data[str(index)] 146 | code_blocks_data_point = parse_body_to_return(data_point["body"]) 147 | return code_blocks_data_point 148 | 149 | 150 | 151 | 152 | if __name__ == "__main__": 153 | dataset = ProcessDataset() 154 | dataset(1) -------------------------------------------------------------------------------- /data/extract_aligned_pairs.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from os.path import exists 3 | from typing import Iterable, Sequence, Tuple 4 | import pylcs 5 | from difflib import SequenceMatcher 6 | import numpy as np 7 | from functools import reduce 8 | import json 9 | from tqdm import tqdm 10 | import re 11 | 12 | log_file = 'log.txt' 13 | with open(log_file, 'w') as f: 14 | f.write("LOG\n\n") 15 | 16 | def log(txt): 17 | with open(log_file, 'a+') as f: 18 | if type(txt) == dict: 19 | f.write(json.dumps(txt, indent=2)+'\n\n') 20 | elif type(txt) == list: 21 | for ele in txt: 22 | log(ele) 23 | elif type(txt) == Block: 24 | f.write(txt.text + '\n\n') 25 | else: 26 | f.write(str(txt)+'\n\n') 27 | 28 | 29 | 30 | #data = json.load(open('CodeReviewSE_clean.json')) 31 | data = json.load(open('temp_data.json', 'r')) 32 | def cache_temp_data(): 33 | with open('temp_data.json', 'w') as f: 34 | temp_keys = list(data.keys())[:5000] 35 | temp_dict = {k: data[k] for k in temp_keys} 36 | json.dump(temp_dict, f) 37 | #cache_temp_data() 38 | #exit() 39 | 40 | 41 | #log(first) 42 | #log(second) 43 | #log(posts[2]) 44 | #exit() 45 | 46 | # Compute lcs for first post 47 | # Currently only filtering out posts with no answers. Keeping posts with no accepted answer 48 | # TODO(dahoas): Perhaps filter out posts below certain upvote threshold 49 | def get_accepted_answer(post): 50 | try: 51 | accepted_id = post['meta_data']['AcceptedAnswerId'] 52 | except KeyError: 53 | if len(post['answers']) > 0: 54 | accepted_ans = reduce(lambda ans1, ans2 : ans1 55 | if int(ans1['meta_data']['Score']) > int(ans2['meta_data']['Score']) 56 | else ans2, post['answers']) 57 | return accepted_ans['body'] 58 | #print('acc_id', accepted_id) 59 | for answer in post['answers']: 60 | if answer['meta_data']['Id'] == accepted_id: 61 | return answer['body'] 62 | 63 | def get_lcs(body, accepted_answer): 64 | s = SequenceMatcher(None, body, accepted_answer) 65 | blocks = s.get_matching_blocks() 66 | max_len_id = np.argmax([block.size for block in blocks]) 67 | max_block = blocks[max_len_id] 68 | start, max_len = max_block.a, max_block.size 69 | return body[start : start + max_len], blocks 70 | 71 | @dataclass 72 | class Block: 73 | text : str 74 | start : int 75 | end : int 76 | type : str 77 | 78 | @dataclass 79 | class CodeBlock(Block): 80 | type = 'code' 81 | 82 | @dataclass 83 | class ReviewBlock: 84 | pre_blocks : Iterable[Block] 85 | code_block : CodeBlock 86 | post_blocks : Iterable[Block] 87 | 88 | def extract_code_blocks(text : str) -> Iterable[CodeBlock]: 89 | codeblock_pattern = r'<code>(?s)((?!<code>).)*<\/code>' 90 | code_block_matches = re.finditer(codeblock_pattern, text) 91 | code_blocks = [] 92 | for match in code_block_matches: 93 | start, end = match.span() 94 | code_block = text[start + 6 : end - 7] # Want to remove <code>, </code> tags 95 | code_blocks.append(CodeBlock(code_block, start, end, 'code')) 96 | return code_blocks 97 | 98 | # Assumes body is first text argument to SequenceMatcher 99 | def block_to_text(body, block): 100 | return body[block.a : block.a + block.size] 101 | 102 | @dataclass 103 | class
Identifier: 104 | text : str 105 | start : int 106 | end : int 107 | 108 | @dataclass 109 | class CodeblockIdentifierEncoding: 110 | identifiers : Iterable[str] 111 | positions : Iterable[Tuple[int, int]] 112 | 113 | def code_to_identifiers(code): 114 | identifier_pattern = r"[a-zA-Z_][a-zA-Z0-9_]*" 115 | identifiers = re.finditer(identifier_pattern, code) 116 | identifier_list = [] 117 | identifier_positions = [] 118 | for identifier in identifiers: 119 | start = identifier.span()[0] 120 | end = identifier.span()[1] 121 | identifier_positions.append((start, end)) 122 | identifier_list.append(code[start : end]) 123 | return CodeblockIdentifierEncoding(identifier_list, identifier_positions) 124 | 125 | def code_blocks_to_identifiers(code_blocks : Iterable[CodeBlock]): 126 | return [code_to_identifiers(code_block.text) for code_block in code_blocks] 127 | 128 | 129 | def text_from_match(text1, text2, match): 130 | sub1 = text1[match.a : match.a + match.size] 131 | sub2 = text2[match.b : match.b + match.size] 132 | assert sub1 == sub2 133 | return sub1 134 | 135 | # Assumes body is first text argument to SequenceMatcher 136 | def block_to_tuple(body, answer, block, body_window=100, answer_window=200, threshold=10): 137 | if block.size < threshold: 138 | return None 139 | else: 140 | body_window = body[max(0, block.a - body_window) : block.a + block.size + body_window] 141 | answer_window = answer[max(0, block.b - answer_window) : block.b + block.size + answer_window] 142 | return (body_window, answer_window) 143 | 144 | 145 | @dataclass 146 | class CandidateRevs: 147 | pre_blocks : Iterable[Block] 148 | post_blocks : Iterable[Block] 149 | 150 | # TODO(dahoas): May also want to collect data on natural text questions asked by users? 151 | @dataclass 152 | class AlignedTriple: 153 | sub_start : int 154 | sub_end : int 155 | sub_text : str # Code being critiqued 156 | 157 | code_start : int 158 | code_end : int 159 | code_text : str # Improved code 160 | 161 | revs : CandidateRevs 162 | 163 | def choose_windows(match_size): 164 | sub_window = 1 * (match_size+5)**1.3 165 | rev_window = 2 * (match_size+8)**1.5 166 | return int(sub_window), int(rev_window) 167 | 168 | # Finds span containing two critiques of captured codeblock 169 | # This could break if code tag is very short reference, so should impose length threshold 170 | def find_maximal_rev_span(rev, rev_start, rev_end): 171 | # Search for text above and text below 172 | while rev_start > 0: 173 | tag = rev[rev_start : rev_start + 6] 174 | if tag == '<code>': break 175 | rev_start -= 1 176 | while rev_end < len(rev): 177 | tag = rev[rev_end : rev_end + 6] 178 | if tag == '<code>': break 179 | rev_end += 1 180 | return rev_start, rev_end 181 | 182 | def identifier_tuple_to_code_tuple( 183 | sub : str, 184 | sub_code_block : CodeBlock, 185 | sub_ident : CodeblockIdentifierEncoding, 186 | rev : str, 187 | rev_code_block : CodeBlock, 188 | rev_ident : CodeblockIdentifierEncoding, 189 | ident_match 190 | ) -> Tuple[str, str]: 191 | sub_matched_ident_start = sub_ident.positions[ident_match.a] 192 | sub_matched_ident_end = sub_ident.positions[ident_match.a + ident_match.size - 1] 193 | rev_matched_ident_start = rev_ident.positions[ident_match.b] 194 | rev_matched_ident_end = rev_ident.positions[ident_match.b + ident_match.size - 1] 195 | 196 | # Determine window sizes from relative lengths of match 197 | # Maybe best bet is to collect both natural language critiques that could possibly apply and decide later 198 | match_size = ident_match.size 199 | sub_window,
rev_window = choose_windows(match_size) 200 | 201 | sub_start = max(0, sub_code_block.start + sub_matched_ident_start[0] - sub_window) 202 | sub_end = min(len(sub), sub_code_block.start + sub_matched_ident_end[1] + sub_window) 203 | 204 | rev_start = rev_code_block.start + rev_matched_ident_start[0] 205 | rev_end = rev_code_block.start + rev_matched_ident_end[1] 206 | rev_start, rev_end = find_maximal_rev_span(rev, rev_start, rev_end) 207 | 208 | sub_match = sub[sub_start : sub_end] 209 | rev_match = rev[rev_start : rev_end] 210 | 211 | 212 | return AlignedTriple( 213 | sub_start=sub_start, 214 | sub_end=sub_end, 215 | sub_text=sub_match, 216 | 217 | code_start=rev_start, 218 | code_end=rev_end, 219 | code_text=rev_match, 220 | 221 | rev_start=None, 222 | rev_end=None, 223 | rev_text=None, 224 | ) 225 | 226 | ###################### 227 | 228 | 229 | def exp(): 230 | posts = list(data.values()) 231 | post = posts[4500] 232 | body = post['body'] 233 | accepted_answer = get_accepted_answer(post) 234 | 235 | 236 | log(body) 237 | log(accepted_answer) 238 | 239 | submitted_code_blocks : Iterable[CodeBlock] = extract_code_blocks(body) 240 | submitted_code_block_identifiers : Iterable[CodeblockIdentifierEncoding] = code_blocks_to_identifiers(submitted_code_blocks) 241 | 242 | reviewed_code_blocks : Iterable[CodeBlock] = extract_code_blocks(accepted_answer) 243 | reviewed_code_block_identifiers : Iterable[CodeblockIdentifierEncoding] = code_blocks_to_identifiers(reviewed_code_blocks) 244 | 245 | rev_index = 2 246 | matches = SequenceMatcher(None, submitted_code_block_identifiers[0].identifiers, reviewed_code_block_identifiers[rev_index].identifiers).get_matching_blocks() 247 | for match in matches: 248 | if match.size > 0: 249 | log(match) 250 | sub_match, rev_match = identifier_tuple_to_code_tuple( 251 | body, 252 | submitted_code_blocks[0], 253 | submitted_code_block_identifiers[0], 254 | accepted_answer, 255 | reviewed_code_blocks[rev_index], 256 | reviewed_code_block_identifiers[rev_index], 257 | match, 258 | ) 259 | log(sub_match) 260 | log(rev_match) 261 | 262 | 263 | #exp() 264 | 265 | #print(get_lcs(body, accepted_answer)) 266 | #lcs_seq_len = pylcs.lcs_sequence_length(body, accepted_answer) 267 | #print('lcs_seq_len', lcs_seq_len) 268 | #lcs_idx = pylcs.lcs_sequence_idx(body, accepted_answer) 269 | #print('lcs_idx', lcs_idx) 270 | 271 | @dataclass 272 | class AlignedQuadruple: 273 | sub_text : str 274 | pre_blocks : Iterable[Block] 275 | mid_blocks : Iterable[Block] 276 | post_blocks : Iterable[Block] 277 | 278 | def get_sub_window(size): 279 | sub_window = int(1 * (size)**1.3) 280 | return sub_window, sub_window 281 | 282 | def count_list(lst): 283 | ele_counts = {} 284 | for ele in lst: 285 | if ele_counts.get(ele) is None: 286 | ele_counts[ele] = 1 287 | else: 288 | ele_counts[ele] += 1 289 | return ele_counts 290 | 291 | def ident_sim_score(sub_ident : CodeblockIdentifierEncoding, rev_ident: CodeblockIdentifierEncoding): 292 | sub_count = count_list(sub_ident.identifiers) 293 | rev_count = count_list(rev_ident.identifiers) 294 | all_keys = set(sub_count.keys()) | set(rev_count.keys()) 295 | total_len = len(sub_ident.identifiers) + len(rev_ident.identifiers) 296 | if total_len == 0: 297 | return 1 298 | mass = 0 299 | for key in all_keys: 300 | sub_mass = sub_count.get(key) if sub_count.get(key) is not None else 0 301 | rev_mass = rev_count.get(key) if rev_count.get(key) is not None else 0 302 | mass += np.abs(sub_mass - rev_mass) 303 | mass /= total_len 304 | return mass 305 | 306 | 
def compute_sub_ident_chunks(sub_ident : CodeblockIdentifierEncoding, rev_ident: CodeblockIdentifierEncoding): 307 | sub_len = len(sub_ident.identifiers) 308 | rev_len = len(rev_ident.identifiers) 309 | # Chunk sub_code into overlapping blocks with overlaps of size rev_len (so we don't miss comparisons) 310 | chunks = [] 311 | for i in range(1, (sub_len // rev_len) + 1): 312 | center_index = rev_len * i 313 | start = max(0, center_index - rev_len) 314 | end = center_index + rev_len 315 | chunk = CodeblockIdentifierEncoding( 316 | sub_ident.identifiers[start : end], 317 | sub_ident.positions[start : end] 318 | ) 319 | #exit() 320 | chunks.append(chunk) 321 | return chunks 322 | 323 | @dataclass 324 | class MergedReviewBlock: 325 | pre_blocks : Iterable[Block] 326 | mid_blocks : Iterable[Block] 327 | post_blocks : Iterable[Block] 328 | sample_code_block : CodeBlock # Always chosen to be first code block which is usually copied from submission and revised later on 329 | 330 | MERGE_THRESH = 0.75 331 | 332 | def merge_review_blocks(block1 : MergedReviewBlock, block2 : ReviewBlock): 333 | code_block1 = code_blocks_to_identifiers([block1.sample_code_block])[0] 334 | code_block2 = code_blocks_to_identifiers([block2.code_block])[0] 335 | score = ident_sim_score(code_block1, code_block2) 336 | if score < MERGE_THRESH: 337 | block1.mid_blocks += block2.pre_blocks 338 | block1.mid_blocks += [block2.code_block] 339 | block1.post_blocks = block2.post_blocks 340 | return block1 341 | else: return None 342 | 343 | num_matches = 0 344 | num_posts = 0 345 | cum_match_len = 0 346 | cum_score = 0 347 | eval_steps = 100 348 | statistics = {"num_samples": len(data)} 349 | posts = list(data.values()) 350 | #posts = [posts[4500]] 351 | saved_data = [] 352 | for i, post in tqdm(enumerate(posts)): 353 | body = post['body'] 354 | try: 355 | accepted_answer = get_accepted_answer(post) 356 | except TypeError: 357 | continue 358 | if accepted_answer is None: 359 | continue 360 | 361 | num_posts += 1 362 | 363 | in_para = False 364 | start = 0 365 | index = 0 366 | blocks : Iterable[Block] = [] 367 | while index < len(accepted_answer): 368 | open_para_tag = accepted_answer[index : index + 3] 369 | close_para_tag = accepted_answer[index : index + 4] 370 | open_code_tag = accepted_answer[index : index + 6] 371 | closed_code_tag = accepted_answer[index : index + 7] 372 | if open_para_tag == '<p>': 373 | start = index 374 | in_para = True 375 | elif close_para_tag == '</p>': 376 | in_para = False 377 | index = index + 4 378 | para = accepted_answer[start + 3 : index - 4] 379 | blocks.append(Block(start=start, end=index, text=para, type='para')) 380 | elif not in_para and open_code_tag == '<code>': 381 | start = index 382 | elif not in_para and closed_code_tag == '</code>': 383 | index = index + 7 384 | code = accepted_answer[start + 6 : index - 7] 385 | blocks.append(Block(start=start, end=index, text=code, type='code')) 386 | index += 1 387 | 388 | review_blocks : Iterable[ReviewBlock] = [] 389 | block_type_list = [(i, block.type) for i, block in enumerate(blocks)] 390 | start = 0 391 | pre_review_block = ReviewBlock([], None, []) 392 | post_review_block = None 393 | for block in blocks: 394 | if block.type == 'code': 395 | pre_review_block.code_block = block 396 | 397 | # current post_review_block is done 398 | if post_review_block is not None: review_blocks.append(post_review_block) 399 | 400 | # pre_review_block becomes post_review_block 401 | post_review_block = pre_review_block 402 | 403 | # make new pre_review_block 404 | pre_review_block = ReviewBlock([], None, []) 405 | 406 | else: 407 | # Always add current para block to the start of the pre_review_block 408 | pre_review_block.pre_blocks.append(block) 409 | 410 | # Add current para block to end of post review block if active 411 | if post_review_block is not None: post_review_block.post_blocks.append(block) 412 | 413 | #Append last post_review_block 414 | review_blocks.append(post_review_block) 415 | review_blocks = [review_block for review_block in review_blocks if review_block is not None and len(review_block.code_block.text) > 0] 416 | 417 | # Preprocess to merge similar code review blocks 418 | merged_review_blocks = [] 419 | cur_merged_review_block = None 420 | for i in range(len(review_blocks)): 421 | cur_review_block = review_blocks[i] 422 | if i == 0: 423 | cur_merged_review_block = MergedReviewBlock( 424 | cur_review_block.pre_blocks, 425 | [cur_review_block.code_block], 426 | cur_review_block.post_blocks, 427 | cur_review_block.code_block 428 | ) 429 | else: 430 | new_merged_review_block = merge_review_blocks(cur_merged_review_block, cur_review_block) 431 | if new_merged_review_block is None: 432 | merged_review_blocks.append(cur_merged_review_block) 433 | cur_merged_review_block = MergedReviewBlock( 434 | cur_review_block.pre_blocks, 435 | [cur_review_block.code_block], 436 | cur_review_block.post_blocks, 437 | cur_review_block.code_block 438 | ) 439 | else: 440 | cur_merged_review_block = new_merged_review_block 441 | # Append last MergedReviewBlock 442 | merged_review_blocks.append(cur_merged_review_block) 443 | merged_review_blocks = [review_block for review_block in merged_review_blocks if review_block is not None] 444 | 445 | submitted_code_blocks : Iterable[CodeBlock] = extract_code_blocks(body) 446 | submitted_code_block_identifiers : Iterable[CodeblockIdentifierEncoding] = code_blocks_to_identifiers(submitted_code_blocks) 447 | 448 | for review_block in merged_review_blocks: 449 | rev_block = review_block.sample_code_block 450 | rev_ident : CodeblockIdentifierEncoding = code_blocks_to_identifiers([rev_block])[0] 451 | if len(rev_ident.identifiers) < 1: 452 | continue 453 | for sub_block, sub_ident in zip(submitted_code_blocks, submitted_code_block_identifiers): 454 | # Probably need a better matching mechanism than SequenceMatcher: need not be contiguous in reality 455 | sub_ident_chunks = compute_sub_ident_chunks(sub_ident, rev_ident) 456 | min_score = 1 457 | min_chunk = None 458 | for
chunk in sub_ident_chunks: 459 | score = ident_sim_score(chunk, rev_ident) 460 | if score < min_score: 461 | min_score = score 462 | min_chunk = chunk 463 | 464 | THRESH = 0.5 465 | 466 | if min_chunk is not None and min_score < THRESH: 467 | num_matches += 1 468 | cum_score += min_score 469 | 470 | sub_start = min_chunk.positions[0][0] 471 | sub_end = min_chunk.positions[-1][-1] 472 | 473 | sub_text = sub_block.text[sub_start : sub_end] 474 | 475 | al = AlignedQuadruple( 476 | sub_text=sub_text, 477 | pre_blocks=review_block.pre_blocks, 478 | mid_blocks=review_block.mid_blocks, 479 | post_blocks=review_block.post_blocks, 480 | ) 481 | 482 | al_dict = { 483 | 'sub_text': al.sub_text, 484 | 'pre_blocks': [pre_block.text for pre_block in al.pre_blocks], 485 | 'mid_blocks': [mid_block.text for mid_block in al.mid_blocks], 486 | 'post_blocks': [post_block.text for post_block in al.post_blocks] 487 | } 488 | 489 | saved_data.append(al_dict) 490 | 491 | if num_matches % eval_steps == 0: 492 | log("SUBMITTED") 493 | log(al.sub_text) 494 | log("PRE") 495 | log(al.pre_blocks) 496 | log("MID") 497 | log(al.mid_blocks) 498 | log("POST") 499 | log(al.post_blocks) 500 | 501 | import json 502 | with open("aligned_data.json", 'w') as f: 503 | json.dump(saved_data, f) 504 | 505 | 506 | statistics['samples_with_answers'] = num_posts 507 | statistics['avg_num_matches'] = num_matches / statistics['samples_with_answers'] 508 | statistics['avg_match_len'] = cum_match_len / num_matches 509 | statistics['num_matches'] = num_matches 510 | statistics['avg_match_score'] = cum_score / num_matches 511 | print(json.dumps(statistics, indent=2)) 512 | 513 | 514 | 515 | '''matches = SequenceMatcher(None, sub_ident.identifiers, rev_ident.identifiers).get_matching_blocks() 516 | max_len_id = np.argmax([match.size for match in matches]) 517 | max_match = matches[max_len_id] 518 | 519 | THRESH = 5 520 | 521 | if max_match.size > THRESH: 522 | num_matches += 1 523 | cum_match_len += max_match.size 524 | 525 | sub_matched_ident_start = sub_ident.positions[max_match.a] 526 | sub_matched_ident_end = sub_ident.positions[max_match.a + max_match.size - 1] 527 | sub_start = sub_matched_ident_start[0] 528 | sub_end = sub_matched_ident_end[1] 529 | pre_window, post_window = get_sub_window(sub_end - sub_start) 530 | sub_start = max(0, sub_start - pre_window) 531 | sub_end = sub_end + post_window 532 | 533 | sub_text = sub_block.text[sub_start : sub_end] 534 | 535 | al = AlignedQuadruple( 536 | sub_text=sub_text, 537 | code_text=rev_block, 538 | pre_blocks=review_block.pre_blocks, 539 | post_blocks=review_block.post_blocks, 540 | ) 541 | 542 | if num_matches % eval_steps == 0: 543 | log("SUBMITTED") 544 | log(al.sub_text) 545 | log("PRE") 546 | log(al.pre_blocks) 547 | log("CODE") 548 | log(al.code_text) 549 | log("POST") 550 | log(al.post_blocks)''' 551 | 552 | 553 | '''lcs, blocks = get_lcs(body, accepted_answer) 554 | if i == 1493: 555 | log(post) 556 | for block in blocks: 557 | tuple_t = block_to_tuple(body, accepted_answer, block) 558 | if tuple_t is not None: 559 | log("Tuple") 560 | log("BODY\n" + tuple_t[0]) 561 | log("ANSWER\n" + tuple_t[1]) 562 | exit() 563 | num_matches = len(blocks) 564 | if num_matches > statistics['max_matches_in_post']: 565 | statistics['max_matches_in_post'] = num_matches 566 | statistics['max_matches_id'] = i 567 | statistics['avg_lcs_len'] += len(lcs) 568 | #print(lcs)''' 569 | 570 | '''def extract_blocks(rev : str, pattern : str, tag : str) -> Iterable[Block]: 571 | block_matches = 
re.finditer(pattern, rev) 572 | blocks = [] 573 | for match in block_matches: 574 | start, end = match.span() 575 | code_block = rev[start + len(tag) : end - (len(tag) + 1)] # Want to remove <code>, </code> tags 576 | blocks.append(Block(code_block, start, end)) 577 | return blocks 578 | 579 | 580 | # Chunks review into code and paragraph blocks 581 | def rev_to_pcblocks(rev): 582 | codeblock_pattern = r'<code>(?s)((?!<code>).)*<\/code>' 583 | para_pattern = r'<p>(?s)((?!<p>).)*<\/p>' 584 | 585 | code_blocks : Iterable[Block] = extract_blocks(rev, codeblock_pattern) 586 | para_blocks : Iterable[Block] = extract_blocks(rev, para_pattern) 587 | 588 | merged_blocks = [] 589 | i,j = 0, 0 590 | while i < len(code_blocks) and j < len(para_blocks): 591 | code_block = code_blocks[i] 592 | para_block = para_block[]''' -------------------------------------------------------------------------------- /data/helper.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import json 3 | from pprint import pprint 4 | 5 | from bs4 import BeautifulSoup 6 | from lm_dataformat import Archive 7 | from tqdm import tqdm 8 | 9 | def load_json_file(file_path:str)-> dict: 10 | with open(file_path,"r") as f: 11 | return json.load(f) 12 | 13 | def dump_json_file(file_path:str, data:dict)->None: 14 | with open(file_path,"w") as f: 15 | json.dump(data,f) 16 | 17 | 18 | 19 | parse_body = lambda x: x 20 | 21 | def parse_html_to_str(body:str) -> list: 22 | 23 | html_parsed = BeautifulSoup(body, 'html.parser') 24 | strings = [] 25 | children = html_parsed.children 26 | 27 | for child in children: 28 | child_text = child.text 29 | if child.name == "pre": child_text = "<code>" + child_text + "</code>" 30 | strings.append(child_text) 31 | 32 | return strings 33 | 34 | 35 | 36 | def get_accepted_answer(code_review_data:dict): 37 | """ 38 | Provides the accepted answer of a question, if an accepted answer is available. 39 | """ 40 | accepted_answer_body = None 41 | if "AcceptedAnswerId" in code_review_data["meta_data"].keys(): 42 | accepted_answer_index = int(code_review_data["meta_data"]["AcceptedAnswerId"]) 43 | for code_review_answer in code_review_data["answers"]: 44 | if int(code_review_answer["meta_data"]["Id"]) == accepted_answer_index: 45 | accepted_answer_body = code_review_answer["body"] 46 | return parse_body(accepted_answer_body) 47 | 48 | def duplicate_data(data, is_ans=False, n=10): 49 | """ 50 | Duplicate the data by n times for subsequent augmentation. If is_ans is True, then key indicates that the answer is to be augmented, else the question is to be augmented. 51 | """ 52 | for question in list(data.keys()): 53 | for i in range(n): 54 | if is_ans: 55 | key_name = question + '_ans' + str(i) 56 | else: 57 | key_name = question + '_q' + str(i) 58 | data[key_name] = copy.deepcopy(data[question]) 59 | return data 60 | 61 | def iter_body(augs, body): 62 | body_strings = parse_html_to_str(body) 63 | for i in range(len(body_strings)): 64 | body_string = body_strings[i] 65 | if "<code>" in body_string.split(): 66 | continue 67 | body_strings[i] = augs([body_string])[0] 68 | return ' '.join(body_strings) 69 | 70 | 71 | def create_dataset_for_QA(data, dest, parse_body=True): 72 | """ 73 | Create a JSON file for the dataset. To do this, for each answer, append to the list a dict with the title, question, and answer, then dump to JSON.
74 | """ 75 | 76 | dataset_list = [] 77 | for question in tqdm(list(data.keys())): 78 | for answer in data[question]["answers"]: 79 | if parse_body: 80 | body_strings = parse_html_to_str(data[question]["body"]) 81 | question_body = ' '.join(body_strings) 82 | body_strings = parse_html_to_str(answer["body"]) 83 | answer_body = ' '.join(body_strings) 84 | else: 85 | question_body = data[question]["body"] 86 | answer_body = answer["body"] 87 | 88 | dataset_list.append({ 89 | "title": data[question]["meta_data"]["Title"], 90 | "question": question_body, 91 | "answer": answer_body 92 | }) 93 | dump_json_file(dest, dataset_list) 94 | 95 | 96 | 97 | def create_dataset_for_20b(data, question_token = ' ', answer_token = ' ', archive_name = 'codereview_20b', parse_body = True): 98 | """ 99 | Create dataset for 20b training with lm_dataformat (`Archive`) 100 | """ 101 | ar = Archive('dataset') 102 | for question in tqdm(list(data.keys())): 103 | text = [] 104 | body = data[question]['body'] 105 | if parse_body: 106 | body_strings = parse_html_to_str(body) 107 | body = ' '.join(body_strings) 108 | text.append(body) 109 | for answer in data[question]['answers']: 110 | body = answer['body'] 111 | if parse_body: 112 | body_strings = parse_html_to_str(body) 113 | body = ' '.join(body_strings) 114 | text.append(body) 115 | text = answer_token.join(text) 116 | text = question_token + text 117 | ar.add_data(text) # do we need to add metadata? 118 | ar.commit(archive_name) # commit the archive 119 | 120 | if __name__ == "__main__": 121 | dataset = load_json_file("dataset/CodeReviewSE_clean.json") 122 | create_dataset_for_QA(dataset, "dataset/CodeReviewSE_clean_QA.json") 123 | # dataset = dataset[list(dataset.keys())[100]] 124 | #create_dataset_for_20b(dataset) 125 | # print(dataset["body"]) 126 | # print("#######") 127 | #pprint(dataset.keys()) 128 | #pprint(dataset['meta_data']) 129 | #pprint(get_accepted_answer(dataset)) 130 | -------------------------------------------------------------------------------- /data/pipeline.py: -------------------------------------------------------------------------------- 1 | import json 2 | from data.helper import * 3 | 4 | #placeholder function to apply to each element of the dataset 5 | placeholder = lambda x: x 6 | 7 | 8 | preproc_fn_dict = { 9 | "placeholder": placeholder 10 | } 11 | 12 | class DataProcess: 13 | def __init__(self,path:str,config_path:str) -> None: 14 | """ 15 | path(str) : path to the json file. 16 | config_path(str) : path to the config file. 17 | """ 18 | self.dataset = load_json_file(path) 19 | self.config = load_json_file(config_path) 20 | 21 | def apply_config_preproc(self)->dict: 22 | """ 23 | Apply the preprocessing config to the dataset. 24 | """ 25 | for key in self.config: 26 | if key in preproc_fn_dict.keys(): 27 | dataset : dict = preproc_fn_dict[key](self.dataset) 28 | else: 29 | raise Exception("No such preprocessing function.") 30 | 31 | self.dataset = dataset 32 | 33 | def save_preproc_dataset(self)->None: 34 | """ 35 | Save the preprocessed dataset. 
36 | """ 37 | dump_json_file(self.config["output_path"],self.dataset) 38 | 39 | 40 | 41 | if __name__ == "__main__": 42 | pipeline = DataProcess("dataset/CodeReviewSE.json","configs/preproc_config.json") 43 | pipeline.apply_config_preproc() 44 | pipeline.save_preproc_dataset() -------------------------------------------------------------------------------- /data_cleaning.py: -------------------------------------------------------------------------------- 1 | from data.helper import * 2 | from tqdm import tqdm 3 | 4 | 5 | def remove_empty_questions(data): 6 | """ 7 | Remove questions with no body 8 | """ 9 | for k,v in tqdm(data.copy().items()): 10 | if v['body'] == '': 11 | del data[k] 12 | return data 13 | 14 | def remove_questions_with_space(data): 15 | """ 16 | Remove question that has body with a space at the end. This usually indicates a tag. 17 | """ 18 | for question in tqdm(list(data.copy().keys())): 19 | body = data[question]["body"] 20 | body_strings = parse_html_to_str(body) 21 | if len(body_strings) > 0: 22 | if len(body_strings[-1]) > 0: 23 | if body_strings[-1][-1] == ' ': 24 | del data[question] 25 | else: 26 | del data[question] 27 | 28 | return data 29 | 30 | if __name__ == "__main__": 31 | data = load_json_file("dataset/CodeReviewSE.json") 32 | data = remove_empty_questions(data) 33 | data = remove_questions_with_space(data) 34 | dump_json_file("dataset/CodeReviewSE_clean.json", data) -------------------------------------------------------------------------------- /datadump.py: -------------------------------------------------------------------------------- 1 | import json 2 | from collections import defaultdict 3 | from tqdm import tqdm 4 | import multiprocessing 5 | 6 | # load posts 7 | with open("dataset/Posts.json") as json_file: 8 | posts = json.load(json_file) 9 | 10 | # load comments 11 | with open("dataset/Comments.json") as json_file: 12 | comments = json.load(json_file) 13 | 14 | """ 15 | question: 16 | question_meta data (user ID, upvotes, accepted answer, date posted, tags) 17 | question_comments 18 | question_comments_meta_data (user ID, upvotes, date posted) 19 | question_comment_body 20 | question_body 21 | 22 | question_answers 23 | question_answer_meta_data (user ID, upvotes, is_accepted, date posted) 24 | question 25 | question_answer_body 26 | """ 27 | 28 | def iterate_over_posts(file): 29 | """ 30 | Constructs a dictionary of the format above 31 | TODO: (excluding comments, which we will add later) 32 | """ 33 | output_dictionary = {} 34 | answers = defaultdict(list) 35 | 36 | for post in file['posts']['row']: 37 | post_dict = {} 38 | Id = post['@Id'] 39 | meta_data = [\ 40 | ('Id', Id), 41 | ('Score', post['@Score']), 42 | ('CreationDate', post['@CreationDate']), 43 | ('CommentCount', post['@CommentCount']), 44 | ('ContentLicense', post['@ContentLicense']) 45 | ] 46 | post_dict['body'] = post['@Body'] 47 | post_dict['comments'] = list() 48 | 49 | # is an answer 50 | if post['@PostTypeId'] == '2': 51 | answer_specific_meta_data = [\ 52 | ('ParentId', post['@ParentId']) 53 | ] 54 | meta_data += answer_specific_meta_data 55 | 56 | # is a question 57 | else: 58 | try: 59 | tags = post['@Tags'] 60 | tags = tags.split('><') 61 | tags = [tag.replace('<','').replace('>','') for tag in tags] 62 | question_specific_meta_data = [ 63 | ('Tags', tags), 64 | ] 65 | except: 66 | question_specific_meta_data = list() 67 | # if the post has a title 68 | if '@Title' in post.keys(): 69 | question_specific_meta_data += [\ 70 | ('Title', post['@Title']) 71 | ] 72 | # is 
104 | def iterate_over_comments(file, post_dict):
105 | comments = defaultdict(list)
106 | for comment in file['comments']['row']:
107 | comment_output_dict = {}
108 | parent_id = comment['@PostId']
109 | comment_output_dict['body'] = comment['@Text']
110 | meta_data = [\
111 | ('Id', comment['@Id']),
112 | ('Score', comment['@Score']),
113 | ('CreationDate', comment['@CreationDate']),
114 | ('ContentLicense', comment['@ContentLicense'])
115 | ]
116 | for k,v in meta_data:
117 | comment_output_dict[k] = v
118 | comments[parent_id].append(comment_output_dict)
119 | 
120 | 
121 | for k,v in tqdm(comments.items()):
122 | try:
123 | post_dict[k]['comments'] = v
124 | except KeyError:
125 | # it isn't for a question, so it must be for an answer.
126 | # Fall back to scanning every question and its answers.
127 | for question_id, question_dict in post_dict.items():
128 | for idx, answer in enumerate(question_dict['answers']):
129 | if answer['meta_data']['Id'] == k:
130 | post_dict[question_id]['answers'][idx]['comments'] = v
131 | break
132 | 
133 | return post_dict
134 | 
135 | 
136 | post_dict = iterate_over_posts(posts)
137 | post_dict = iterate_over_comments(comments, post_dict)
138 | 
139 | # Save dictionary to json file named CodeReviewSE.json
140 | with open('dataset/CodeReviewSE.json', 'w') as outfile:
141 | json.dump(post_dict, outfile)
142 | 
143 | 
144 | 
145 | 
146 | 
147 | 
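# The fallback above rescans every question for each orphaned comment, which
# is quadratic in the worst case. A hedged sketch of a faster variant: build
# a one-off index from answer Id to its location, then attach each comment in
# O(1). build_answer_index is illustrative and not used by the dump itself.
def build_answer_index(post_dict):
    index = {}
    for question_id, question in post_dict.items():
        for idx, answer in enumerate(question['answers']):
            index[answer['meta_data']['Id']] = (question_id, idx)
    return index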
\n\n".format(body, answer) 19 | prompts.append(prompt) 20 | 21 | responses = openai.Completion.create(engine='text-davinci-003', prompt=prompts, max_tokens=max_tokens, temperature=0.1)["choices"] 22 | for prompt, response, sample in zip(prompts, responses, prompt_batch): 23 | text = response["text"] 24 | sample["prompt"] = prompt 25 | sample["response"] = text 26 | Logger.log([sample]) 27 | responses = [response["text"] for response in responses] 28 | return responses 29 | 30 | 31 | def augment_code_review(): 32 | code_review_dataset = load_dataset("Dahoas/2048_has_code_filtered_base_code_review")["train"] 33 | reformatted_dataset = [] 34 | for sample in code_review_dataset: 35 | for answer in sample["answers"]: 36 | reformatted_dataset.append({"body": sample["body"], "answer": answer, "comments": sample["comments"], "meta_data": sample["meta_data"], "question_id": sample["question_id"]}) 37 | length = len(reformatted_dataset) 38 | code_review_dataset = filter_queried(reformatted_dataset) 39 | new_length = len(code_review_dataset) 40 | print("Old len: {}, New len: {}".format(length, new_length)) 41 | 42 | prompts_per_query = 10 43 | batched_prompts = [code_review_dataset[i*prompts_per_query : (i+1)*prompts_per_query] for i in range((len(code_review_dataset) + prompts_per_query - 1) // prompts_per_query)] 44 | 45 | for prompt_batch in tqdm(batched_prompts): 46 | try: 47 | query(prompt_batch, 2048) 48 | except openai.error.RateLimitError: 49 | print("RATELIMIT ERROR") 50 | time.sleep(15) 51 | except openai.error.ServiceUnavailableError: 52 | print("SERVICE UNABAILABLE") 53 | time.sleep(15) 54 | except openai.error.Timeout: 55 | print("TIMEOUT") 56 | time.sleep(15) 57 | except: 58 | print("SOME OTHER EXCEPTION") 59 | time.sleep(30) 60 | time.sleep(10) # Sleep to prevent rate limiting 61 | 62 | def test(): 63 | code_review_dataset = load_dataset("Dahoas/2048_has_code_filtered_base_code_review")["train"] 64 | sample = code_review_dataset[25002] 65 | body = sample["body"] 66 | answer = sample["answers"][0]["body"] 67 | 68 | prompt = "Question: {} \n\n Answer: {} \n\n This is a question and answer from a forum where users review and improve the code of other users. Please output the original code, a summary of the critique, and the revised code using the format ORIGINAL: [write original code here] CRITIQUE: [write critique here] REVISED: [write revision code here]. 
\n\n".format(body, answer) 69 | query([prompt], 2048) 70 | 71 | if __name__ == "__main__": 72 | parser = argparse.ArgumentParser() 73 | parser.add_argument("--log_file") 74 | parser.add_argument("--oai_key") 75 | args = parser.parse_args() 76 | 77 | Logger.init(args.log_file) 78 | openai.api_key = args.oai_key 79 | 80 | query_instruct() 81 | -------------------------------------------------------------------------------- /instruct_augment_code_review/compute_stats.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoTokenizer 2 | from datasets import load_dataset 3 | from tqdm import tqdm 4 | 5 | 6 | def compute_stats(base_dataset): 7 | base_dataset = base_dataset.shuffle() 8 | tok = AutoTokenizer.from_pretrained("gpt2") 9 | 10 | stats = {"avg_body_len": 0, "avg_answer_len": 0, "max_body_len": 0, "max_answer_len": 0, "avg_num_answers": 0} 11 | cnt = 0 12 | for post in tqdm(base_dataset): 13 | if cnt > 5000: 14 | break 15 | cnt += 1 16 | body = post["body"] 17 | l = len(tok(body)["input_ids"]) 18 | stats["avg_body_len"] += l 19 | stats["max_body_len"] = max(l, stats["max_body_len"]) 20 | 21 | for answer in post["answers"]: 22 | l = len(tok(answer["body"])["input_ids"]) 23 | stats["avg_answer_len"] += l 24 | stats["max_answer_len"] = max(l, stats["max_answer_len"]) 25 | 26 | stats["avg_num_answers"] += len(post["answers"]) 27 | 28 | stats["avg_body_len"] = stats["avg_body_len"] / cnt 29 | stats["avg_answer_len"] = stats["avg_answer_len"] / stats["avg_num_answers"] 30 | stats["avg_num_answers"] = stats["avg_num_answers"] / cnt 31 | 32 | print(stats) 33 | 34 | if __name__ == "__main__": 35 | base_dataset = load_dataset("Dahoas/2048_has_code_filtered_base_code_review")["train"] 36 | compute_stats(base_dataset) -------------------------------------------------------------------------------- /instruct_augment_code_review/filter_data.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoTokenizer 2 | from datasets import load_dataset, Dataset 3 | from tqdm import tqdm 4 | import re 5 | import torch 6 | from utils import extract_max_code_block 7 | 8 | 9 | def filter_dataset(base_dataset): 10 | tok = AutoTokenizer.from_pretrained("gpt2") 11 | MAX_LENGTH = 2048 12 | CODE_BLOCK_THRESHOLD = 5 13 | 14 | new_dataset = {key: [] for key in base_dataset[0].keys()} 15 | cnt = 0 16 | for post in tqdm(base_dataset): 17 | cnt += 1 18 | body = post["body"] 19 | bl = len(tok(body)["input_ids"]) 20 | 21 | new_answers = [] 22 | for answer in post["answers"]: 23 | al = len(tok(answer["body"])["input_ids"]) 24 | if bl + al <= MAX_LENGTH: 25 | max_code_block = extract_max_code_block(answer["body"]) 26 | if max_code_block is not None and len(max_code_block.split(" ")) > CODE_BLOCK_THRESHOLD: 27 | new_answers.append(answer) 28 | 29 | if len(new_answers) > 0: 30 | new_dataset["body"].append(body) 31 | new_dataset["answers"].append(new_answers) 32 | new_dataset["comments"].append(post["comments"]) 33 | new_dataset["meta_data"].append(post["meta_data"]) 34 | new_dataset["question_id"].append(post["question_id"]) 35 | 36 | new_dataset = Dataset.from_dict(new_dataset) 37 | new_dataset.push_to_hub("Dahoas/2048_has_code_filtered_base_code_review") 38 | 39 | 40 | def reformat_by_question(base_dataset): 41 | new_dataset = {"body": [], "comments": [], "answer": [], "meta_data": [], "question_id": []} 42 | for sample in new_dataset: 43 | for answer in sample["answers"]: 44 | 
new_dataset["body"].append(sample["body"]) 45 | 46 | 47 | 48 | 49 | if __name__ == "__main__": 50 | #base_dataset = load_dataset("Dahoas/base_code_review")["train"] 51 | #filter_dataset(base_dataset) 52 | base_dataset = load_dataset("Dahoas/2048_has_code_filtered_base_code_review")["train"] 53 | reformat_by_question(base_dataset) -------------------------------------------------------------------------------- /instruct_augment_code_review/logger.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | class Logger: 4 | name: str 5 | 6 | @classmethod 7 | def init(cls, name): 8 | assert name is not None 9 | cls.name = name 10 | print(f"Logging in {cls.name}") 11 | 12 | @classmethod 13 | def log(cls, dicts): 14 | with open(f'{cls.name}.jsonl', 'a+') as f: 15 | for dict_t in dicts: 16 | json.dump(dict_t, f) 17 | f.write('\n') 18 | -------------------------------------------------------------------------------- /instruct_augment_code_review/utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | import re 3 | import torch 4 | from datasets import Dataset, load_dataset 5 | from tqdm import tqdm 6 | from copy import deepcopy 7 | 8 | def load_jsonl(filename): 9 | data = [] 10 | with open(filename, "r") as f: 11 | lines = f.readlines() 12 | for line in lines: 13 | response = json.loads(line) 14 | data.append(response) 15 | return data 16 | 17 | def write_jsonl(dataset, filename): 18 | with open(filename, "w") as f: 19 | for ele in dataset: 20 | json.dump(ele, f) 21 | f.write("\n") 22 | 23 | def inspect_output(): 24 | dataset = load_jsonl("augmentations.jsonl") 25 | sample = dataset[-1] 26 | with open("inspect.txt", "w") as f: 27 | f.write(sample["prompt"]) 28 | f.write("\n\n\n\n") 29 | f.write(sample["response"]) 30 | 31 | 32 | def extract_max_code_block(text : str): 33 | codeblock_pattern = r'(?s)((?!).)*<\/code>' 34 | code_block_matches = re.finditer(codeblock_pattern, text) 35 | code_blocks = [] 36 | for match in code_block_matches: 37 | start, end = match.span() 38 | code_block = text[start + 6 : end - 7] # Want to remove , tags 39 | code_blocks.append(code_block) 40 | lengths = torch.tensor([len(block) for block in code_blocks]) 41 | if len(code_blocks) == 0: 42 | return None 43 | argmax = torch.argmax(lengths) 44 | return code_blocks[argmax] 45 | 46 | def filter_queried(dataset): 47 | queried = load_jsonl("full_augmentations.jsonl") 48 | print(len(queried)) 49 | for query in tqdm(queried): 50 | flag = False 51 | QId = query["question_id"] 52 | Id = query["answer"]["meta_data"]["Id"] 53 | for i, sample in enumerate(dataset): 54 | cur_QId = sample["question_id"] 55 | cur_Id = sample["answer"]["meta_data"]["Id"] 56 | if QId == cur_QId and Id == cur_Id: 57 | flag = True 58 | dataset.pop(i) 59 | break 60 | #if not flag: 61 | #print(QId) 62 | #print(Id) 63 | #raise ValueError("Unsupported query") 64 | return dataset 65 | 66 | def filter_instruct_augments(): 67 | dataset = load_jsonl("filtered_full_augmentations.jsonl") 68 | print("dataset len", len(dataset)) 69 | removal_indices = [] 70 | cnt=0 71 | for i in tqdm(range(30000, len(dataset))): 72 | QId = dataset[i]["question_id"] 73 | Id = dataset[i]["answer"]["meta_data"]["Id"] 74 | for j in range(i+1, len(dataset)): 75 | ele_QId = dataset[j]["question_id"] 76 | ele_Id = dataset[j]["answer"]["meta_data"]["Id"] 77 | if QId == ele_QId and Id == ele_Id: 78 | removal_indices.append(j) 79 | cnt += 1 80 | print(cnt) 81 | break 82 | dataset = [ele for 
86 | def upload_dataset():
87 | dataset = load_jsonl("filtered_full_augmentations.jsonl")
88 | dict_dataset = {key: [] for key in dataset[0].keys()}
89 | for ele in dataset:
90 | for key in dict_dataset:
91 | dict_dataset[key].append(ele[key])
92 | hf_dataset = Dataset.from_dict(dict_dataset)
93 | hf_dataset.push_to_hub("Dahoas/code-review-instruct-critique-revision")
94 | 
95 | def upload_python_subset():
96 | dataset = load_dataset("Dahoas/code-review-instruct-critique-revision")
97 | sample = dataset["train"][0]
98 | print(sample["meta_data"])
99 | python_dataset = []
100 | for ele in tqdm(dataset["train"]):
101 | if "python" in ele["meta_data"]["Tags"]:
102 | python_dataset.append(ele)
103 | print(len(python_dataset))
104 | python_dict = {key: [] for key in sample.keys()}
105 | for ele in python_dataset:
106 | for key in ele.keys():
107 | python_dict[key].append(ele[key])
108 | python_dataset = Dataset.from_dict(python_dict)
109 | python_dataset.push_to_hub("Dahoas/code-review-instruct-critique-revision-python")
110 | 
111 | if __name__ == "__main__":
112 | #inspect_output()
113 | #upload_dataset()
114 | upload_python_subset()
115 | #filter_instruct_augments()
116 | --------------------------------------------------------------------------------
/models/backtranslation.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
3 | 
4 | 
5 | class BackTranslationModel:
6 | def __init__(self, src_model_name, tgt_model_name, device='cpu', batch_size=32, max_length=300):
7 | self.src_model = AutoModelForSeq2SeqLM.from_pretrained(src_model_name)
8 | self.src_model.eval()
9 | self.src_model.to(device)
10 | self.tgt_model = AutoModelForSeq2SeqLM.from_pretrained(tgt_model_name)
11 | self.tgt_model.eval()
12 | self.tgt_model.to(device)
13 | self.src_tokenizer = AutoTokenizer.from_pretrained(src_model_name)
14 | self.tgt_tokenizer = AutoTokenizer.from_pretrained(tgt_model_name)
15 | 
16 | self.batch_size = batch_size
17 | self.max_length = max_length
18 | 
19 | def translate(self, src_text):
20 | src_ids = self.src_tokenizer.encode(src_text, return_tensors='pt')
21 | src_ids = src_ids.to(self.src_model.device)
22 | output = self.src_model.generate(src_ids, do_sample=True, max_length=self.max_length)
23 | output = output.to('cpu')
24 | output = output.tolist()
25 | output = output[0]
26 | output = self.tgt_tokenizer.decode(output)
27 | return output --------------------------------------------------------------------------------
/old/scrape.py:
--------------------------------------------------------------------------------
1 | from stackapi import StackAPI
2 | import json
3 | 
4 | site = StackAPI("codereview")
5 | #questions = site.fetch("questions")
6 | 
7 | # writes questions to a json file, saved at questions.json
8 | def write_questions(questions):
9 | with open("questions.json", "w") as f:
10 | json.dump(questions, f)
11 | 
12 | #write_questions(questions)
13 | #print("Questions have been saved!")
14 | 
15 | 
16 | # loads questions from a json file named questions.json
17 | def load_questions():
18 | with open("questions.json", "r") as f:
19 | questions = json.load(f)
20 | return questions
21 | 
22 | questions = load_questions()
23 | 
24 | #print(questions)
25 | #print(list(questions['items'][0].keys()))
26 | question_id = 
questions['items'][10]["question_id"] 27 | 28 | def get_answers(site, question_id): 29 | answers = site.fetch(f"questions/{question_id}/answers", filter="withbody") 30 | return answers 31 | # utilizes StackAPI to, given a question_id, fetch the body of all associated answers 32 | def fetch_answers_given_question_id(site, question_id): 33 | answers = get_answers(site, question_id) 34 | answers_list = answers['items'] 35 | answers_body = [] 36 | for answer in answers_list: 37 | answers_body.append(answer['body']) 38 | return answers_body 39 | 40 | def fetch_comments_given_id(site, id): 41 | comments = site.fetch(f"posts/{question_id}/comments", filter='withbody') 42 | comments_list = comments['items'] 43 | comments_body = [] 44 | for comment in comments_list: 45 | comments_body.append(comment['body']) 46 | return comments_body 47 | 48 | #print(get_answers(site, question_id)['items'][0].keys()) 49 | 50 | #answer = get_answers(site, question_id)['items'][0] 51 | 52 | print() 53 | 54 | answer_body = fetch_answers_given_question_id(site, question_id) 55 | print(answer_body) 56 | 57 | """ 58 | question: 59 | question_meta data (user ID, upvotes, accepted answer, date posted, tags) 60 | question_comments 61 | question_comments_meta_data (user ID, upvotes, date posted) 62 | question_comment_body 63 | question_body 64 | 65 | question_answers 66 | question_answer_meta_data (user ID, upvotes, is_accepted, date posted) 67 | question 68 | question_answer_body 69 | """ -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | tqdm 2 | StackAPI 3 | xmltodict 4 | tree_sitter 5 | numpy 6 | torch 7 | beautifulsoup4 8 | transformers 9 | lm-dataformat 10 | datasets 11 | accelerate 12 | rouge-score -------------------------------------------------------------------------------- /stats/tags.json: -------------------------------------------------------------------------------- 1 | {"php": 4285, "mysql": 1091, "constructor": 141, "python": 14158, "optimization": 1065, "algorithm": 4793, "programming-challenge": 3134, "library": 308, "c#": 9739, "performance": 7251, "codeigniter": 103, "sql": 1447, "c++": 8638, "checksum": 71, "csv": 621, "strings": 2772, "elisp": 38, "unit-testing": 988, "cache": 265, "mvc": 548, "doctrine": 28, "vb6": 33, "random": 756, "shuffle": 90, "java": 10679, "computational-geometry": 235, "javascript": 9464, "jquery": 2401, "ajax": 375, "pagination": 126, "singleton": 179, "classes": 351, "postgresql": 176, "parsing": 1262, "url": 233, "multithreading": 1519, "objective-c": 370, "twitter": 85, "perl": 222, "subroutine": 13, ".net": 1169, "design-patterns": 1379, "security": 750, "regex": 843, "linq": 601, "file": 783, "fortran": 33, "mysqli": 254, "ruby": 1636, "ruby-on-rails": 599, "html": 1763, "css": 1008, "converting": 353, "game": 1793, "plugin": 168, "ext.js": 4, "asp.net": 452, "coordinate-system": 189, "winforms": 297, "synchronization": 44, "timeout": 38, "configuration": 126, "beginner": 7084, "django": 339, "xml": 443, "ms-word": 38, "authorization": 66, "pdo": 330, "object-oriented": 3635, "php5": 235, "array": 2002, "statistics": 368, "thread-safety": 455, "locking": 158, "lock-free": 90, "formatting": 552, "error-handling": 682, "atl": 2, "search": 367, "macros": 109, "cocoa": 34, "functional-programming": 981, "project-euler": 67, "clojure": 312, "fibonacci-sequence": 182, "c": 3544, "oracle": 79, "genetic-algorithm": 62, "hibernate": 74, "android": 
827, "sqlite": 211, "comparative-review": 891, "finance": 255, "tic-tac-toe": 569, "console": 712, "datetime": 1009, "spl": 4, "wpf": 331, "file-system": 653, "sql-server": 573, "lookup": 51, "dice": 169, "bash": 756, "http": 398, "shell": 285, "django-template-language": 6, "casting": 97, "graph": 691, "homework": 524, "linked-list": 1008, "enum": 243, "controller": 166, "timer": 293, "proxy": 47, "stream": 303, "raii": 31, "image": 683, "iterator": 397, "sudoku": 181, "tree": 872, "f#": 359, "exception": 122, "sass": 94, "networking": 324, "event-handling": 448, "gui": 290, "queue": 384, "wordpress": 148, "validation": 706, "form": 383, "zend-framework": 26, "primes": 680, "sieve-of-eratosthenes": 176, "mpi": 15, "api": 540, "exception-handling": 100, "serialization": 275, "c++11": 1704, "pathfinding": 231, "integer": 351, "interview-questions": 928, "xslt": 23, "to-do-list": 93, "fizzbuzz": 203, "authentication": 381, "session": 153, "windows": 325, "server": 224, "signal-handling": 12, "jquery-ui": 78, "ienumerable": 16, "swing": 412, "html5": 412, "template": 477, "asynchronous": 518, "hash-map": 836, "extension-methods": 175, "bioinformatics": 140, "json": 846, "erlang": 44, "ip-address": 78, "soap": 15, "groovy": 68, "grails": 9, "recursion": 1065, "asp.net-mvc-3": 65, "database": 528, "vb.net": 342, "flask": 149, "sorting": 1109, "reflection": 229, "assembly": 318, "inheritance": 296, "polymorphism": 117, "python-2.x": 1240, "io": 464, "gwt": 6, "opengl": 141, "expression-trees": 41, "lisp": 195, "quick-sort": 166, "common-lisp": 120, "rss": 35, "combinatorics": 467, "callback": 173, "actionscript-3": 41, "actionscript": 9, "flex": 4, "oauth": 34, "jersey": 5, "entity-framework": 403, "haskell": 1058, "child-process": 96, "game-of-life": 202, "stl": 70, "circular-list": 141, "google-apps-script": 73, "google-maps": 83, "collections": 312, "cryptography": 443, "url-routing": 131, "memory-management": 594, "reinventing-the-wheel": 902, "jekyll": 7, "liquid": 5, "vbscript": 36, "asp-classic": 6, "jython": 12, "floating-point": 178, "asp.net-mvc-2": 6, "palindrome": 242, "ios": 395, "svg": 55, "parsec": 20, "mathematics": 570, "playing-cards": 336, "network-file-transfer": 103, "delegates": 66, "scheme": 176, "sicp": 89, "numerical-methods": 202, "generator": 137, "r": 381, "iteration": 248, "coldfusion": 39, "cfml": 31, "coffeescript": 99, "git": 134, "make": 30, "matrix": 648, "binary-search": 278, "grammar": 29, "factor-lang": 5, "numpy": 695, "bitwise": 311, "linux": 514, "trie": 108, "google-app-engine": 37, "easing": 10, "portability": 86, "utf-8": 37, "matlab": 171, "interval": 155, "clustering": 134, "machine-learning": 245, "tower-of-hanoi": 21, "pointers": 339, "c++-cli": 8, "interpreter": 142, "geospatial": 114, "sh": 85, "mvvm": 174, "collision": 101, "boost": 171, "jdbc": 84, "null": 159, "wolfram-mathematica": 11, "n-queens": 39, "jodatime": 11, "maven": 14, "jaxb": 5, "lambda": 211, "animation": 382, "servlets": 57, "xaml": 96, "symbolic-math": 23, "math-expression-eval": 176, "i18n": 75, "insertion-sort": 73, "compression": 185, "curl": 73, "rock-paper-scissors": 154, "powershell": 263, "razor": 39, "mergesort": 291, "calculator": 526, "closure": 53, "pig-latin": 39, "caesar-cipher": 174, "twisted": 6, "ai": 199, "raytracing": 11, "socket": 316, "brainfuck": 105, "memoization": 82, "poco-libraries": 1, "vba": 1257, "scala": 517, "curses": 40, "makefile": 83, "python-3.x": 4874, "web-scraping": 574, ".net-2.0": 9, "pascal": 25, "dom": 392, "graphics": 210, "pdf": 70, 
"t-sql": 213, "email": 281, "silverlight": 5, "crud": 54, "wxpython": 9, "chat": 97, "firefox": 8, "natural-language-processing": 113, "modules": 117, "lua": 98, "edit-distance": 111, "sfml": 94, "wcf": 43, "stackexchange": 166, "unicode": 74, "audio": 131, "entity-component-system": 39, "rational-numbers": 68, "opencl": 18, "node.js": 992, "canvas": 181, "matplotlib": 150, "guava": 45, "odbc": 8, "xpath": 35, "asp.net-mvc": 242, "simulation": 329, "scheduled-tasks": 77, "pyramid": 5, "repository": 146, "delphi": 80, "layout": 83, "user-interface": 133, "c++0x": 2, "qml": 7, ".htaccess": 46, "excel": 1070, "namespaces": 27, "static": 113, "processing": 25, "ascii-art": 148, "quiz": 184, "automapper": 6, "tkinter": 317, "google-contacts-api": 3, "video": 66, "lazy": 68, "weak-references": 22, "sql-injection": 82, "stack": 392, "knockout.js": 36, "touch": 8, "file-structure": 48, "jquery-datatables": 12, "snake-game": 161, "go": 568, "qt": 131, "join": 94, "facebook": 41, "state-machine": 133, "time-limit-exceeded": 995, "roman-numerals": 66, "logging": 310, "serial-port": 60, "openssl": 41, "haxe": 6, "tcp": 156, "web-services": 100, "aes": 112, "unix": 121, "sse": 28, "mootools": 3, "factory-method": 146, "number-guessing-game": 163, "pthreads": 76, "state": 65, "system.reactive": 38, "prototypal-class-design": 25, "raphael.js": 12, "dynamic-loading": 55, "neural-network": 114, "sed": 41, "fluent-interface": 44, "mixins": 31, "smart-pointers": 12, "dining-philosophers": 15, "ms-access": 63, "traveling-salesman": 28, "ocaml": 47, "captcha": 21, "bdd": 21, "phpunit": 22, "browser-storage": 17, "pygame": 210, "minesweeper": 78, "rspec": 66, "amazon-s3": 9, "markdown": 47, "xsd": 11, "vectors": 296, "openmp": 45, "jsp": 23, "turtle-graphics": 41, "template-meta-programming": 260, "signal-processing": 87, "operator-overloading": 8, "xna": 28, "minecraft": 33, "concurrency": 430, "hangman": 174, "generics": 467, "lxml": 21, "d": 16, "lex": 7, "yacc": 3, "microdata": 8, "abstract-factory": 34, "properties": 90, "union-find": 29, "plsql": 15, "constants": 43, "monads": 73, "complexity": 399, "role-playing-game": 128, "spring": 222, "active-record": 87, "installer": 69, "rest": 250, "tex": 38, "immutability": 82, "ksh": 5, "zeromq": 7, "bit-twiddling": 3, "promise": 259, "task-parallel-library": 148, "producer-consumer": 88, "library-design": 8, "symfony2": 33, "contest-problem": 10, "cakephp": 22, "breadth-first-search": 177, "c++03": 35, "sync": 8, "gtk": 13, "scrapy": 34, "e-commerce": 101, "sliding-tile-puzzle": 55, "eigen": 16, "actor": 14, "dependency-injection": 263, "awk": 54, "variadic": 76, "jsf": 14, "heap": 162, "wikipedia": 21, "sdl": 87, "dynamic-programming": 284, "client": 106, "racket": 64, "youtube": 39, "localization": 26, "balanced-delimiters": 71, "backbone.js": 71, "change-making-problem": 36, "chess": 174, "prolog": 40, "markov-chain": 38, "vigenere-cipher": 43, "etl": 7, "linq-to-sql": 35, "garbage-collection": 6, "rc": 1, "priority-queue": 69, "cross-browser": 19, "moq": 24, "core-data": 20, "dto": 16, "tetris": 43, "postscript": 7, "type-safety": 119, "atomic": 56, "bluetooth": 13, "autocomplete": 35, "ftp": 35, "sharepoint": 34, "joomla": 3, "instagram": 19, "hashcode": 120, "linkedin": 4, "beautifulsoup": 194, "titanium": 2, "visitor-pattern": 27, "junit": 78, "backtracking": 75, "status-monitoring": 53, "ant": 3, "drupal": 13, "mobile": 28, "phonegap": 2, "compiler": 66, "d3.js": 59, "reddit": 38, "import": 7, "numbers-to-words": 69, "container": 17, "nginx": 21, 
"tornado": 16, "winapi": 105, "twig": 10, "ado.net": 44, "eclipse": 9, "svn": 6, "physics": 129, "google-chrome": 31, "mongodb": 199, "erb": 10, "simd": 35, "gadt": 2, "base64": 48, "cgi": 10, "fancybox": 3, ".net-datatable": 55, "morse-code": 26, "variant-type": 36, "cellular-automata": 25, "haml": 6, "asp.net-mvc-4": 102, "scalaz": 8, "native-code": 31, "set": 132, "benchmarking": 99, "smarty": 4, "jpa": 35, "battle-simulation": 80, "tk": 11, "cocoa-touch": 17, "number-systems": 107, "meta-programming": 140, "adventure-game": 98, "depth-first-search": 178, "pyqt": 79, "underscore.js": 53, "mustache": 10, "opencv": 117, "struts2": 5, "meteor": 22, "nunit": 32, "async-await": 334, "json.net": 27, "require.js": 25, "wrapper": 140, "cuda": 40, "embedded": 114, "trampoline": 7, "pymongo": 15, "stored-procedure": 78, "monogame": 15, "arduino": 87, "assertions": 20, "vectorization": 96, "twitter-bootstrap": 112, "parallax": 7, "collatz-sequence": 67, "category": 5, "interface": 193, "taxicab-geometry": 7, "raspberry-pi": 77, "redis": 65, "reference": 38, "postsharp": 2, "aspect-oriented": 8, "a-star": 70, "knapsack-problem": 37, "scope": 31, "angular.js": 386, "poco": 3, "tdd": 23, "rhino": 2, "https": 30, "ssl": 17, "ecmascript-6": 658, "directory": 10, "observer-pattern": 97, "express.js": 190, "socket.io": 40, "webdriver": 38, "pawn": 1, "akka": 19, "mediator": 18, "salesforce-apex": 17, "google-drive": 12, "jms": 8, "protocol-buffers": 5, "checkers-draughts": 18, "hdl": 13, "vhdl": 14, "osx": 22, "battleship": 39, "jsf-2": 3, "signalr": 11, "radix-sort": 31, "batch": 61, "less-css": 8, "helper": 24, "jquery-mobile": 4, "location-services": 9, "cassandra": 9, "ember.js": 9, "mocks": 56, "connection-pool": 30, "com": 37, "cobol": 9, "javafx": 141, "covariance": 5, "query-selector": 10, "revealing-module-pattern": 28, "pyglet": 7, "mechatronics": 6, "overloading": 86, "bottle": 8, "snap-framework": 2, "music": 43, "basic-lang": 10, "awt": 29, "cursor": 11, "progress-4gl": 3, "laravel": 257, "hy": 2, "pong": 14, "lodash.js": 57, "elixir": 58, "connect-four": 55, "tcl": 9, "libgdx": 42, "pandas": 528, "iptables": 5, "propel": 4, "device-driver": 44, "c99": 56, "mocha": 19, "google-sheets": 53, "text-editor": 37, "data-importer": 7, "hadoop": 10, "emacs": 1, "zsh": 17, "guice": 4, "data-mining": 40, "cython": 56, "ssh": 55, "divide-and-conquer": 29, "elm": 16, "optional": 47, "vimscript": 18, "typescript": 491, "framework": 60, "kivy": 16, "userscript": 53, "bacon.js": 3, "fpga": 13, "outlook": 19, "dapper": 36, "multiprocessing": 85, "applescript": 12, "kernel": 21, "cyclomatic-complexity": 57, "dojo": 6, "ninject": 23, "ldap": 16, "sqlalchemy": 44, "jasmine": 23, "ebay": 5, "coldfusion-10": 2, "windows-phone": 14, "windows-phone-7": 4, "couchdb": 9, "webgl": 8, "forth": 9, "amp": 4, "visual-studio": 15, "peg.js": 2, "pascal-script": 3, "pubnub": 2, "object-pascal": 9, "community-challenge": 119, "weekend-challenge": 9, "passport": 9, "higher-order-functions": 13, "pokemon": 27, "unity3d": 211, "mapreduce": 27, "memory-optimization": 208, "threadx": 2, "cryptocurrency": 23, "bookmarklet": 10, "apache-spark": 31, "maya": 4, "pyside": 10, "asp.net-mvc-5": 56, "gradle": 8, "ti-basic": 14, "language-design": 29, "raknet": 2, "abap": 2, "scss": 2, "multiton": 1, "bitset": 54, "data-visualization": 151, "dsl": 23, "bloom-filter": 17, "websocket": 51, "ddd": 44, "amazon-web-services": 64, "exercism": 4, "mongoose": 61, "integration-testing": 42, "sandbox": 6, "java-8": 13, "cli": 9, "bitcoin": 12, 
"kotlin": 230, "yaml": 40, "object": 3, "grunt.js": 6, "c11": 23, "autofac": 18, "enigma-machine": 9, "active-directory": 52, "unit-conversion": 93, "rpython": 2, "fluent-assertions": 5, "asp.net-web-api": 138, "jwt": 29, "bem": 9, "adodb": 24, "mvp": 77, "c++14": 467, "solaris": 1, "roslyn": 17, "dart": 20, "tcsh": 2, "fixed-point": 31, "rtti": 4, "objective-c-runtime": 5, "rust": 621, "google-bigquery": 5, "maxscript": 1, "rabbitmq": 21, "constant-expression": 15, "lucene": 3, "jstl": 2, "gmp": 10, "react.js": 579, "sprite-kit": 20, "rebol": 14, "logo-lang": 4, "lua-table": 13, "gml": 2, "swift": 649, "foundation": 3, "verilog": 23, "escaping": 34, "windows-runtime": 7, "sidekiq": 4, "eloquent": 29, "databinding": 22, "j": 7, "stan": 1, "arcpy": 19, "rags-to-riches": 90, "yii": 6, "trait": 10, "elasticsearch": 25, "selenium": 149, "netty": 8, "polymer": 5, "nim": 6, "transactions": 19, "azure": 31, "css3": 20, "sympy": 11, "tds": 1, "sputnik": 1, "udp": 32, "99-bottles-of-beer": 11, "julia": 46, "async.js": 10, "marionette.js": 5, "grand-central-dispatch": 11, "lolcode": 7, "stylus": 2, "xamarin": 63, "gulp.js": 13, "paper.js": 5, "freetype": 4, "service-broker": 2, "crypto++": 1, "posix": 89, "steganography": 17, "cordova": 13, "scipy": 75, "astropy": 3, "steam": 7, "processing.js": 13, "tpl-dataflow": 16, "ada": 6, "hiveql": 5, "ssis": 5, "firebase": 64, "mithril.js": 1, "rubberduck": 98, "odoo": 7, "internet-explorer": 9, "uikit": 23, "levenshtein-distance": 3, "skip-list": 14, "protocols": 26, "simon-says": 25, "cherrypy": 4, "raku": 15, "sfinae": 17, "db2": 10, "lombok": 12, "gson": 17, "nosql": 12, "google-translate": 6, "susy": 1, "lexer": 26, "siebel-escript": 1, "hlsl": 5, "duck-typing": 7, "parse-platform": 19, "smalltalk": 3, "phalcon": 1, "solidworks": 2, "fractals": 79, "antlr": 18, "vkscript": 1, "cmake": 22, "complex-numbers": 13, "polyglot": 2, "screen-scraping": 3, "sas": 2, "m4": 3, "qunit": 2, "sinatra": 5, "delphi-xe": 4, "c++1z": 2, "asm.js": 1, "sql-dependency": 1, "reactive-cocoa": 5, "frp": 6, "c++17": 509, "rx-java": 33, "jsx": 192, "cql": 1, "purescript": 4, "lc-3": 4, "securestring": 10, "numba": 23, "casper.js": 4, "wildfly": 1, "n-tier": 8, "ebnf": 1, "glsl": 13, "jade": 3, "unity-container": 14, "jni": 9, "subset-sum-problem": 4, "informix": 4, "automation": 26, "zephir": 1, "tis-100": 4, "neo4j": 5, "cypher": 2, "phoenix-framework": 5, "plpgsql": 5, "bluemix": 2, "thundercats": 2, "google-cloud-platform": 6, "autohotkey": 4, "powerpoint": 12, "fltk": 8, "khronos": 3, "blockchain": 13, "io-lang": 1, "chuck": 1, "ramda.js": 12, "slim": 19, "telegram": 18, "babel.js": 8, "ipc": 10, "owin": 6, "rebol2": 4, "whitespace-lang": 1, "idris": 5, "spring-mvc": 46, "amd64": 7, "scratch": 3, "arnoldc": 1, "agda": 1, "k-sum": 46, "dogescript": 1, "typo3": 1, "reactive-programming": 26, "clojurescript": 19, "redux": 87, "dockerfile": 19, "lexical-analysis": 35, "cucumber": 5, "unrealscript": 1, "linq-expressions": 15, "angular-bootstrap": 6, "virtual-machine": 31, "macos": 34, "angular-2+": 163, "rake": 1, "fish-lang": 1, "octave": 9, "pyth": 2, "snobol4": 1, "simulink": 1, "automake": 1, "drools": 2, "uwp": 20, "memcache": 5, "heap-sort": 22, "rxjs": 59, "specflow": 2, "symfony3": 11, "swift3": 44, "haproxy": 1, "react-native": 40, "phpdocx": 1, "qasm": 1, "quantum-computing": 2, "haxeflixel": 2, "vue.js": 97, "2048": 25, "asp.net-core": 150, "vscode": 1, "gcc": 20, "adobe-illustrator": 2, "stratifiedjs": 1, "wren": 1, "orm": 20, "immutable.js": 4, "ansible": 13, 
"puppet": 1, "lwjgl": 3, "tensorflow": 41, "t4": 1, "raycasting": 8, "wmi-query": 3, "rx-swift": 11, "robotframework": 9, "mongodb-query": 2, "just-mock": 2, "eslint": 7, "protractor": 2, "coq": 4, "libgit2sharp": 1, "ebuild": 1, "pari-gp": 1, "ibm-rpg": 3, "php7": 9, "electron": 10, "function-block-diagram": 2, "entity-framework-core": 42, "docker": 19, "befunge": 2, "crystal": 1, "sml": 2, "solidity": 6, "apache-kafka": 9, "lilypond": 1, "s3": 3, "simple-injector": 3, ".net-core": 78, "asp.net-identity": 7, "phaser.io": 3, "x86": 64, "leaflet": 9, "axios": 18, "firefox-webextensions": 3, "c++98": 10, "c++20": 145, "netlogo": 1, "structured-text": 1, "mef": 2, "x11": 8, "asp.net-core-webapi": 12, "freezeflame": 2, "jison": 1, "october-cms": 1, "log4net": 2, "c89": 14, "pytorch": 18, "glib": 1, "rcpp": 16, "nasm": 18, "flutter": 12, "nuget": 1, "logback": 1, "azure-cosmosdb": 4, "ecmascript-8": 19, "thymeleaf": 3, "dxl": 1, "red-lang": 1, "interactive-data-language": 1, "allegro": 1, "xquery": 1, "autoit": 1, "xunit": 12, "ddt": 1, "grpc": 3, "visual-foxpro": 1, "mobx": 3, "railway-oriented": 1, "parquet": 3, "linny": 2, "circuit-python": 1, "razor-pages": 2, "doxygen": 1, "cors": 3, "indexeddb": 3, "typesetting": 1, "swiftui": 14, "bucket-sort": 3, "mmap": 2, "latex": 4, "symfony4": 6, "timsort": 1, "emulator": 3, "graphql": 8, "webassembly": 6, "shaders": 6, "2sum": 2, "asp.net-core-3.0": 2, "dax": 1, "m-code": 1, "nodatime": 2, "binary": 14, "c-preprocessor": 3, "binary-tree": 25, "bigint": 9, "fold": 3, "tail-recursion": 2, "apache-beam": 1, "dijkstra": 7, "encryption": 27, "sdl2": 6, "apl": 13, "covid-19": 7, "jsoup": 1, "strategy-pattern": 7, "binary-search-tree": 17, "topological-sorting": 6, "jupyter": 5, "keras": 6, "lstm": 2, "object-detection": 4, "fish-shell": 1, "alloy": 2, "vim": 4, "fxml": 1, "expect": 1, "color": 6, "blazor": 2, "handlebars": 2, "pattern-matching": 14, "nested": 9, "bootstrap-4": 5, "byte": 5, "hash": 10, "heuristic": 5, "encoding": 7, ".net-5": 8, "arm": 4, "argo-workflows": 1, "discord": 8, "pydantic": 2, "interpolation": 3, "nullable-reference-types": 1, "v-language": 1, "classification": 3, "openscad": 1, "asyncio": 5, "discord.py": 2, "spotipy": 1, "ncurses": 1, "listview": 1, "semaphore": 1, "google-api": 2, "kubernetes": 1, "peewee": 1, "vyxal": 2, "polly": 1, "webcomponent": 1, "xcb": 1, "widget": 1, "file-archive": 1, "bootstrap": 1, "postfix": 1, "linear-algebra": 1, "prbs": 1, "systemverilog": 2, "rtl": 1} -------------------------------------------------------------------------------- /t0_finetune.py: -------------------------------------------------------------------------------- 1 | 2 | #!/usr/bin/env python 3 | # coding=utf-8 4 | # Copyright BigScience, The HuggingFace Team and The HuggingFace Inc. team. All rights reserved. 5 | # 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at 9 | # 10 | # http://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # Unless required by applicable law or agreed to in writing, software 13 | # distributed under the License is distributed on an "AS IS" BASIS, 14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | # See the License for the specific language governing permissions and 16 | # limitations under the License. 17 | """ 18 | Fine-tuning T0 in PyTorch, optionally few-shot. 
19 | 20 | This script is adapted from 21 | https://github.com/huggingface/transformers/blob/master/examples/pytorch/multiple-choice/run_swag_no_trainer.py 22 | as well as 23 | https://github.com/huggingface/transformers/blob/master/examples/pytorch/summarization/run_summarization_no_trainer.py 24 | """ 25 | 26 | import argparse 27 | import logging 28 | import os 29 | import random 30 | from dataclasses import dataclass 31 | from itertools import chain 32 | from typing import Optional, Union 33 | import csv 34 | import math 35 | 36 | import datasets 37 | import torch 38 | from datasets import load_dataset, load_metric 39 | from torch.utils.data import DataLoader 40 | from tqdm.auto import tqdm 41 | 42 | import transformers 43 | from accelerate import Accelerator 44 | from transformers import ( 45 | AutoConfig, 46 | AutoModelForSeq2SeqLM, 47 | AutoTokenizer, 48 | PreTrainedTokenizerBase, 49 | default_data_collator, 50 | DataCollatorForSeq2Seq, 51 | AdamW, 52 | SchedulerType, 53 | get_scheduler, 54 | set_seed, 55 | ) 56 | from transformers.file_utils import PaddingStrategy 57 | from promptsource.templates import DatasetTemplates 58 | 59 | 60 | logger = logging.getLogger(__name__) 61 | 62 | 63 | def parse_args(): 64 | parser = argparse.ArgumentParser(description="Fine-tuning T0 in PyTorch, optionally few-shot.") 65 | parser.add_argument( 66 | "-d", 67 | "--dataset_name", 68 | type=str, 69 | default=None, 70 | required=True, 71 | help="The name of the dataset to use (via the datasets library).", 72 | ) 73 | parser.add_argument( 74 | "-s", 75 | "--dataset_config_name", 76 | type=str, 77 | default=None, 78 | help="The configuration name (usually a subset) of the dataset to use (via the datasets library).", 79 | ) 80 | parser.add_argument( 81 | "-t", 82 | "--template_name", 83 | type=str, 84 | default=None, 85 | required=True, 86 | help="The template/prompt name in `promptsource`.", 87 | ) 88 | parser.add_argument( 89 | "-o", 90 | "--output_dir", 91 | type=str, 92 | default=None, 93 | required=True, 94 | help="Where to store the results CSV and (TODO) optionally the final model." 95 | ) 96 | parser.add_argument( 97 | "-m", 98 | "--model_name_or_path", 99 | type=str, 100 | required=True, 101 | help=( 102 | "Path to pretrained model or model identifier from huggingface.co/models. " 103 | "The list of T0 variants can be found on `https://huggingface.co/bigscience/T0_3B`" 104 | ), 105 | ) 106 | parser.add_argument( 107 | "-pa", 108 | "--parallelize", 109 | action="store_true", 110 | help=( 111 | "If passed, will call `model.parallelize` which splits the model on all GPUs available (model parallelism). " 112 | "Note that this feature is still experimental in HF Transformers." 113 | ), 114 | ) 115 | parser.add_argument( 116 | "-eb", 117 | "--per_device_eval_batch_size", 118 | type=int, 119 | default=8, 120 | help="Batch size (per device) for the evaluation dataloader. Will be multiplied by the number of answer choices.", 121 | ) 122 | parser.add_argument( 123 | "-tb", 124 | "--per_device_train_batch_size", 125 | type=int, 126 | default=4, 127 | help="Batch size (per device) for the training dataloader.", 128 | ) 129 | parser.add_argument( 130 | "-ns", 131 | "--num_shots", 132 | type=int, 133 | default=None, 134 | help="Number of training examples for few-shot learning. 
Default is None, which uses the entire train set.",
135 | )
136 | parser.add_argument(
137 | "-lr",
138 | "--learning_rate",
139 | type=float,
140 | default=1e-4,
141 | help="Initial learning rate (after the potential warmup period) to use.",
142 | )
143 | parser.add_argument(
144 | "-ep",
145 | "--num_train_epochs",
146 | type=int,
147 | default=10,
148 | help="Total number of training epochs to perform."
149 | )
150 | parser.add_argument(
151 | "-ms",
152 | "--max_train_steps",
153 | type=int,
154 | default=None,
155 | help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
156 | )
157 | parser.add_argument(
158 | "-ga",
159 | "--gradient_accumulation_steps",
160 | type=int,
161 | default=1,
162 | help="Number of update steps to accumulate before performing a backward/update pass.",
163 | )
164 | parser.add_argument(
165 | "-ie",
166 | "--input_eos",
167 | action="store_true",
168 | help=(
169 | "T0 was trained without EOS in its input sequences, which is the default in this script. "
170 | "However, T5 was pretrained with EOS in its input sequences. See README for more info."
171 | ),
172 | )
173 | parser.add_argument(
174 | "-db",
175 | "--debug",
176 | action="store_true",
177 | help="Activate debug mode and run training only with a subset of data.",
178 | )
179 | parser.add_argument(
180 | "-wb",
181 | "--wandb_proj",
182 | type=str,
183 | default=None,
184 | help="Project name for Weights & Biases. By default, W&B is disabled.",
185 | )
186 | parser.add_argument(
187 | "-sd",
188 | "--seed",
189 | type=int,
190 | default=42,
191 | help="Especially important for few-shot example sampling.",
192 | )
193 | parser.add_argument(
194 | "-cf",
195 | "--config_name",
196 | type=str,
197 | default=None,
198 | help="Pretrained config name or path if not the same as model_name",
199 | )
200 | parser.add_argument(
201 | "-tk",
202 | "--tokenizer_name",
203 | type=str,
204 | default=None,
205 | help="Pretrained tokenizer name or path if not the same as model_name",
206 | )
207 | parser.add_argument(
208 | "-il",
209 | "--max_length",
210 | type=int,
211 | default=1024,
212 | help=(
213 | "The maximum total input sequence length after tokenization. Sequences longer than this will be truncated,"
214 | " sequences shorter will be padded if `--pad_to_max_length` is passed."
215 | ),
216 | )
217 | parser.add_argument(
218 | "-tl",
219 | "--target_max_length",
220 | type=int,
221 | default=256,
222 | help="Target max length. Sequences longer than this will be truncated."
223 | )
224 | parser.add_argument(
225 | "-pml",
226 | "--pad_to_max_length",
227 | action="store_true",
228 | help="If passed, pad all samples to `max_length`. Otherwise, dynamic padding is used.",
229 | )
230 | parser.add_argument(
231 | "-st",
232 | "--use_slow_tokenizer",
233 | action="store_true",
234 | help="If passed, will use a slow tokenizer (not backed by the 🤗 Tokenizers library).",
235 | )
236 | parser.add_argument(
237 | "-wd",
238 | "--weight_decay",
239 | type=float,
240 | default=0.01,
241 | help="Weight decay for the AdamW optimizer."
242 | )
243 | parser.add_argument(
244 | "-ls",
245 | "--lr_scheduler_type",
246 | type=SchedulerType,
247 | default="linear",
248 | help="The scheduler type to use.",
249 | choices=["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"],
250 | )
251 | parser.add_argument(
252 | "-ws",
253 | "--num_warmup_steps",
254 | type=int,
255 | default=0,
256 | help="Number of steps for the warmup in the lr scheduler."
257 | )
258 | args = parser.parse_args()
259 | 
260 | return args
261 | 
262 | 
263 | @dataclass
264 | class DataCollatorForMultipleChoice:
265 | """
266 | Data collator that dynamically pads the inputs for multiple choice.
267 | 
268 | Args:
269 | tokenizer (:class:`~transformers.PreTrainedTokenizer` or :class:`~transformers.PreTrainedTokenizerFast`):
270 | The tokenizer used for encoding the data.
271 | padding (:obj:`bool`, :obj:`str` or :class:`~transformers.file_utils.PaddingStrategy`, `optional`, defaults to :obj:`True`):
272 | Select a strategy to pad the returned sequences (according to the model's padding side and padding index)
273 | among:
274 | 
275 | * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
276 | sequence is provided).
277 | * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
278 | maximum acceptable input length for the model if that argument is not provided.
279 | * :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of
280 | different lengths).
281 | max_length (:obj:`int`, `optional`):
282 | Maximum length of the returned list and optionally padding length (see above).
283 | pad_to_multiple_of (:obj:`int`, `optional`):
284 | If set will pad the sequence to a multiple of the provided value.
285 | 
286 | This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
287 | 7.5 (Volta).
288 | Note that it is NOT recommended to use fp16 for any kind of inference with T0, as the predictions will vastly differ from the predictions using fp32.
289 | """
290 | 
291 | tokenizer: PreTrainedTokenizerBase
292 | padding: Union[bool, str, PaddingStrategy] = True
293 | max_length: Optional[int] = None
294 | pad_to_multiple_of: Optional[int] = None
295 | 
296 | def __call__(self, features):
297 | num_choices = len(features[0]["input_ids"])
298 | flattened_features = [
299 | [
300 | {
301 | k: v[i]
302 | for k, v in feature.items()
303 | if k != "targets"
304 | }
305 | for i in range(num_choices)
306 | ]
307 | for feature in features
308 | ]
309 | flattened_features = list(chain(*flattened_features))
310 | 
311 | batch = self.tokenizer.pad(
312 | flattened_features,
313 | padding=self.padding,
314 | max_length=self.max_length,
315 | pad_to_multiple_of=self.pad_to_multiple_of,
316 | )
317 | 
318 | # Pad the labels because they are not padded automatically
319 | max_label_length = max([len(elem["labels"]) for elem in flattened_features])
320 | batch["labels"] = [
321 | l + [self.tokenizer.pad_token_id]*(max_label_length - len(l))
322 | for l in [elem["labels"] for elem in flattened_features]
323 | ]
324 | batch["labels_attention_mask"] = [
325 | m + [0]*(max_label_length - len(m))
326 | for m in [elem["labels_attention_mask"] for elem in flattened_features]
327 | ]
328 | 
329 | # Convert to tensors
330 | batch = {
331 | k: torch.tensor(v)
332 | for k, v in batch.items()
333 | }
334 | 
335 | batch["targets"] = torch.tensor([f.pop("targets") for f in features])
336 | return batch
337 | 
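# Hedged usage sketch (names are illustrative): preprocess_eval below emits
# per-example features shaped [num_choices, seq_len] plus a scalar "targets"
# index; the collator flattens and pads them so model logits come back as
# (batch_size * num_choices, seq_len).
#
#   collator = DataCollatorForMultipleChoice(tokenizer, pad_to_multiple_of=8)
#   batch = collator(list_of_eval_features)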
338 | 
339 | def main():
340 | args = parse_args()
341 | set_seed(args.seed)
342 | 
343 | # Initialize the accelerator. We will let the accelerator handle device placement for us.
344 | accelerator = Accelerator()
345 | # Make one log on every process with the configuration for debugging.
346 | logging.basicConfig(
347 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
348 | datefmt="%m/%d/%Y %H:%M:%S",
349 | level=logging.INFO,
350 | )
351 | logger.info(accelerator.state)
352 | 
353 | # Set up logging; we only want one process per machine to log things on the screen.
354 | # accelerator.is_local_main_process is only True for one process per machine.
355 | logger.setLevel(logging.INFO if accelerator.is_local_main_process else logging.ERROR)
356 | if accelerator.is_local_main_process:
357 | datasets.utils.logging.set_verbosity_warning()
358 | transformers.utils.logging.set_verbosity_info()
359 | else:
360 | datasets.utils.logging.set_verbosity_error()
361 | transformers.utils.logging.set_verbosity_error()
362 | 
363 | 
364 | # Handle the output directory creation
365 | if accelerator.is_main_process:
366 | os.makedirs(args.output_dir, exist_ok=True)
367 | accelerator.wait_for_everyone()
368 | 
369 | # In distributed evaluation, the load_dataset function guarantees that only one local process can concurrently
370 | # download the dataset.
371 | if args.dataset_name is not None:
372 | # Downloading and loading a dataset from the hub.
373 | if args.dataset_name == "anli":
374 | raw_train_dataset = load_dataset(args.dataset_name, split=f'train_{args.dataset_config_name}') # dataset_config_name = "r1", "r2", or "r3"
375 | raw_eval_dataset = load_dataset(args.dataset_name, split=f'dev_{args.dataset_config_name}')
376 | else:
377 | raw_train_dataset = load_dataset(args.dataset_name, args.dataset_config_name, split="train")
378 | raw_eval_dataset = load_dataset(args.dataset_name, args.dataset_config_name, split="validation")
379 | else:
380 | raise ValueError('Please specify `args.dataset_name` and `args.dataset_config_name` as they appear in `promptsource`.')
381 | #TODO(Victor): enable loading pre-processed dataset from https://huggingface.co/datasets/bigscience/P3
382 | 
383 | # Trim a number of evaluation examples
384 | if args.debug:
385 | raw_train_dataset = raw_train_dataset.select(range(min(100, len(raw_train_dataset))))
386 | raw_eval_dataset = raw_eval_dataset.select(range(min(100, len(raw_eval_dataset))))
387 | 
388 | column_names = raw_eval_dataset.column_names
389 | 
390 | 
391 | # Load pretrained model and tokenizer
392 | #
393 | # In distributed training, the .from_pretrained methods guarantee that only one local process can concurrently
394 | # download model & vocab.
395 | if args.config_name:
396 | config = AutoConfig.from_pretrained(args.config_name)
397 | elif args.model_name_or_path:
398 | config = AutoConfig.from_pretrained(args.model_name_or_path)
399 | else:
400 | raise ValueError(
401 | "Either `args.config_name` or `args.model_name_or_path` should be provided."
402 | )
403 | 
404 | if args.tokenizer_name:
405 | tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, use_fast=not args.use_slow_tokenizer)
406 | elif args.model_name_or_path:
407 | tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=not args.use_slow_tokenizer)
408 | else:
409 | raise ValueError(
410 | "You are instantiating a new tokenizer from scratch. This is not supported by this script. "
411 | "You can do it from another script, save it, and load it from here, using --tokenizer_name."
412 | ) 413 | 414 | if args.model_name_or_path: 415 | model = AutoModelForSeq2SeqLM.from_pretrained( 416 | args.model_name_or_path, 417 | from_tf=bool(".ckpt" in args.model_name_or_path), 418 | config=config, 419 | ) 420 | else: 421 | logger.info("Training new model from scratch") 422 | model = AutoModelForSeq2SeqLM.from_config(config) 423 | 424 | # Preprocessing the datasets. 425 | # First we tokenize all the texts. 426 | padding = "max_length" if args.pad_to_max_length else False 427 | 428 | # Get the prompt to apply and the possible targets. 429 | # TODO(Victor): If pulling from pre-processed data, remove this logic. 430 | if args.dataset_name == 'anli': 431 | prompts = DatasetTemplates('anli', None) 432 | else: 433 | prompts = DatasetTemplates( 434 | f"{args.dataset_name}" 435 | if args.dataset_config_name is None 436 | else f"{args.dataset_name}/{args.dataset_config_name}" 437 | ) 438 | template = prompts[args.template_name] 439 | 440 | def preprocess_train(examples): 441 | bs = len(examples[column_names[0]]) 442 | 443 | input_texts = [] 444 | target_texts = [] 445 | for i in range(bs): 446 | ex = { 447 | k: examples[k][i] 448 | for k in column_names 449 | } 450 | input, target = template.apply(ex) 451 | ex_answer_choices = template.get_answer_choices_list(ex) 452 | assert target in ex_answer_choices 453 | input_texts.append(input) 454 | target_texts.append(target) 455 | 456 | model_inputs = tokenizer( 457 | input_texts, 458 | padding=padding, 459 | max_length=args.max_length, 460 | truncation=True, 461 | add_special_tokens=args.input_eos, 462 | ) 463 | 464 | with tokenizer.as_target_tokenizer(): 465 | tokenized_targets = tokenizer( 466 | target_texts, 467 | padding=padding, 468 | max_length=args.target_max_length, 469 | truncation=True, 470 | add_special_tokens=False, 471 | ) 472 | model_inputs['labels'] = [ 473 | [(t if t != tokenizer.pad_token_id else -100) for t in targets] 474 | for targets in tokenized_targets["input_ids"] 475 | ] 476 | return model_inputs 477 | 478 | def preprocess_eval(examples): 479 | bs = len(examples[column_names[0]]) 480 | 481 | input_texts = [] 482 | target_texts = [] 483 | answer_choices_texts = [] 484 | for i in range(bs): 485 | ex = { 486 | k: examples[k][i] 487 | for k in column_names 488 | } 489 | input, target = template.apply(ex) 490 | ex_answer_choices = template.get_answer_choices_list(ex) 491 | assert target in ex_answer_choices 492 | input_texts.append(input) 493 | target_texts.append(target) 494 | answer_choices_texts.append(ex_answer_choices) 495 | 496 | tokenized_inputs = tokenizer( 497 | input_texts, 498 | padding=padding, 499 | max_length=args.max_length, 500 | truncation=True, 501 | add_special_tokens=False, 502 | ) 503 | tokenized_targets = [ 504 | tokenizer( 505 | ans_choi, 506 | padding=True, 507 | max_length=args.target_max_length, 508 | truncation=True, 509 | ) 510 | for ans_choi in answer_choices_texts 511 | ] 512 | 513 | features = { 514 | k: [ 515 | [elem for _ in range(len(tokenized_targets[idx]["input_ids"]))] 516 | for idx, elem in enumerate(v) 517 | ] 518 | for k, v in tokenized_inputs.items() 519 | } 520 | 521 | features["labels"] = [ 522 | tokenized_targets[idx]["input_ids"] 523 | for idx in range(bs) 524 | ] 525 | features["labels_attention_mask"] = [ 526 | tokenized_targets[idx]["attention_mask"] 527 | for idx in range(bs) 528 | ] 529 | features["targets"] = [ 530 | answer_choices_texts[idx].index(t) 531 | for idx, t in enumerate(target_texts) 532 | ] 533 | 534 | return features 535 | 536 | with 
accelerator.main_process_first():
537 |         eval_dataset = raw_eval_dataset.map(preprocess_eval, batched=True, remove_columns=column_names)
538 | 
539 |         if args.num_shots is not None:
540 |             sample_indices = random.sample(range(0, len(raw_train_dataset)), k=args.num_shots)
541 |             raw_train_dataset = raw_train_dataset.select(sample_indices)
542 |         train_dataset = raw_train_dataset.map(preprocess_train, batched=True, remove_columns=column_names)
543 | 
544 |     # Log a few random examples:
545 |     for index in random.sample(range(len(train_dataset)), 3):
546 |         logger.debug(f"Sample {index} of the training set: {train_dataset[index]}.")
547 |     for index in random.sample(range(len(eval_dataset)), 3):
548 |         logger.debug(f"Sample {index} of the evaluation set: {eval_dataset[index]}.")
549 | 
550 |     # DataLoaders creation:
551 |     train_collator = DataCollatorForSeq2Seq(
552 |         tokenizer,
553 |         model=model,
554 |         label_pad_token_id=-100,
555 |         pad_to_multiple_of=8 if accelerator.use_fp16 else None
556 |     )
557 |     train_dataloader = DataLoader(
558 |         train_dataset,
559 |         shuffle=True,
560 |         collate_fn=train_collator,
561 |         batch_size=args.per_device_train_batch_size
562 |     )
563 | 
564 |     if args.pad_to_max_length:
565 |         # If padding was already done to max length, we use the default data collator that will just convert everything
566 |         # to tensors.
567 |         eval_collator = default_data_collator
568 |     else:
569 |         # Otherwise, `DataCollatorForMultipleChoice` will apply dynamic padding for us (by padding to the maximum length
570 |         # of the samples passed). When using mixed precision, we add `pad_to_multiple_of=8` to pad all tensors to a
571 |         # multiple of 8, which will enable the use of Tensor Cores on NVIDIA hardware with compute capability >= 7.5 (Volta).
572 |         eval_collator = DataCollatorForMultipleChoice(
573 |             tokenizer, pad_to_multiple_of=(8 if accelerator.use_fp16 else None)
574 |         )
575 |     eval_dataloader = DataLoader(eval_dataset, collate_fn=eval_collator, batch_size=args.per_device_eval_batch_size)
576 | 
577 |     # Optimizer
578 |     # Split weights in two groups, one with weight decay and the other not.
579 |     no_decay = ["bias", "LayerNorm.weight"]
580 |     optimizer_grouped_parameters = [
581 |         {
582 |             "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
583 |             "weight_decay": args.weight_decay,
584 |         },
585 |         {
586 |             "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
587 |             "weight_decay": 0.0,
588 |         },
589 |     ]
590 |     optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate)
591 | 
592 |     # Scheduler and math around the number of training steps.
593 |     num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
594 |     if args.max_train_steps is None:
595 |         args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
596 |     else:
597 |         args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
598 | 
599 |     lr_scheduler = get_scheduler(
600 |         name=args.lr_scheduler_type,
601 |         optimizer=optimizer,
602 |         num_warmup_steps=args.num_warmup_steps,
603 |         num_training_steps=args.max_train_steps,
604 |     )
605 | 
606 |     if args.parallelize:
607 |         num_gpus = torch.cuda.device_count()
608 |         assert num_gpus > 1, "You need at least 2 GPUs to use `model.parallelize()`."
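        # `model.parallelize()` is the naive model parallelism built into
        # `transformers` for T5-style models: it spreads the transformer blocks
        # across all visible GPUs. The model is deliberately NOT passed through
        # `accelerator.prepare` in this branch, since Accelerate would otherwise
        # try to place it on a single device.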
609 |         model.parallelize()
610 |         optimizer, train_dataloader, eval_dataloader = accelerator.prepare(
611 |             optimizer, train_dataloader, eval_dataloader)
612 |     else:
613 |         model, optimizer, train_dataloader, eval_dataloader = accelerator.prepare(
614 |             model, optimizer, train_dataloader, eval_dataloader)
615 | 
616 |     # Metrics
617 |     metric = load_metric("accuracy")
618 | 
619 |     total_batch_size = args.per_device_train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
620 |     logger.info("***** Running training *****")
621 |     logger.info(f"  Num examples = {len(train_dataset)}")
622 |     logger.info(f"  Num Epochs = {args.num_train_epochs}")
623 |     logger.info(f"  Instantaneous batch size per device = {args.per_device_train_batch_size}")
624 |     logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
625 |     logger.info(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
626 |     logger.info(f"  Total optimization steps = {args.max_train_steps}")
627 |     # Only show the progress bar once on each machine.
628 |     progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
629 |     global_steps = 0
630 | 
631 |     if args.wandb_proj and accelerator.is_main_process:
632 |         import wandb
633 |         extra_metadata = {
634 |             'template_jinja': template.jinja,
635 |             'template_answer_choices': template.answer_choices,
636 |             'template_reflects_original_task': template.metadata.original_task,
637 |             'template_choices_in_prompt': template.metadata.choices_in_prompt,
638 |             'template_comment': template.reference,
639 |         }
640 |         run_config = vars(args)
641 |         run_config.update(extra_metadata)
642 |         wandb.init(
643 |             project=args.wandb_proj,
644 |             config=run_config,
645 |             # name=f'S{len(train_dataset)} {args.template_name} R{args.seed}',  # uncomment to customize each run's name
646 |             # reinit=True,  # uncomment if running multiple runs in one script
647 |         )
648 | 
649 |     result_table = []
650 |     for epoch in range(1, args.num_train_epochs+1):
651 |         model.train()
652 |         for step, batch in enumerate(train_dataloader):
653 |             outputs = model(**batch)
654 |             loss = outputs.loss
655 |             loss = loss / args.gradient_accumulation_steps
656 |             accelerator.backward(loss)
657 |             if step % args.gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
658 |                 optimizer.step()
659 |                 lr_scheduler.step()
660 |                 optimizer.zero_grad()
661 |                 progress_bar.update(1)
662 |                 global_steps += 1
663 |                 loss = loss.item()
664 |                 if accelerator.is_main_process:
665 |                     tqdm.write(f"epoch = {epoch}, step = {global_steps}, loss = {loss}")
666 |                 if args.wandb_proj and accelerator.is_main_process:
667 |                     wandb.log({"loss": loss}, step=global_steps)
668 | 
669 |             if global_steps >= args.max_train_steps:
670 |                 break
671 | 
672 |         # Evaluate every epoch
673 |         total_batch_size = args.per_device_eval_batch_size * accelerator.num_processes
674 |         logger.info("***** Running evaluation *****")
675 |         logger.info(f"  Num examples = {len(eval_dataset)}")
676 |         logger.info(f"  Instantaneous batch size per device = {args.per_device_eval_batch_size}")
677 |         logger.info(f"  Total eval batch size (w. parallel, distributed) = {total_batch_size}")
678 |         # Only show the progress bar once on each machine.
        # NOTE: commented out to avoid a nested progress-bar mess.
679 |         # progress_bar = tqdm(range(len(eval_dataloader)), disable=not accelerator.is_local_main_process)
680 | 
681 |         model.eval()
682 |         for batch in eval_dataloader:
683 |             model_inputs = {
684 |                 k: batch[k]
685 |                 for k in ["input_ids", "attention_mask", "labels"]
686 |             }
687 |             with torch.no_grad():
688 |                 logits = model(**model_inputs).logits
689 |             masked_log_probs = batch["labels_attention_mask"].unsqueeze(-1) * torch.log_softmax(logits, dim=-1)
690 |             seq_token_log_probs = torch.gather(masked_log_probs, -1, batch["labels"].unsqueeze(-1))
691 |             seq_log_prob = seq_token_log_probs.squeeze(dim=-1).sum(dim=-1)
692 |             seq_log_prob = seq_log_prob.view(batch["targets"].size(0), -1)  # TODO(Victor): this reshape works based on the assumption that all examples have the same number of choices. The pre-processing doesn't make this assumption.
693 |             predictions = seq_log_prob.argmax(dim=-1)
694 | 
695 |             metric.add_batch(
696 |                 predictions=accelerator.gather(predictions),
697 |                 references=accelerator.gather(batch["targets"]),
698 |             )
699 | 
700 |             # progress_bar.update(1)
701 | 
702 |         eval_metric = metric.compute()
703 |         score = eval_metric["accuracy"]  # TODO: support other metrics; currently hardcoded by load_metric() anyway
704 |         accelerator.print(f"Accuracy: {score}")
705 |         result_table.append({
706 |             "dataset_name": args.dataset_name,
707 |             "dataset_config_name": args.dataset_config_name,
708 |             "template_name": args.template_name,
709 |             "epoch": epoch,
710 |             "step": global_steps,
711 |             "metric": 'accuracy',
712 |             "score": score,
713 |         })
714 |         if args.wandb_proj and accelerator.is_main_process:
715 |             wandb.log({"accuracy": score}, step=global_steps)
716 |     # End training loop
717 | 
718 |     if accelerator.is_main_process:
719 |         if args.output_dir is not None:
720 |             with open(os.path.join(args.output_dir, "results.csv"), "w") as f:
721 |                 writer = csv.DictWriter(f, fieldnames=result_table[0].keys())
722 |                 writer.writeheader()
723 |                 writer.writerows(result_table)
724 | 
725 |         if args.wandb_proj:
726 |             wandb.finish()
727 | 
728 | 
729 | if __name__ == "__main__":
730 |     main()
--------------------------------------------------------------------------------
/trace_data/inject_locals.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import os
3 | 
4 | def get_indentation(line: str) -> str:
5 |     return line[:len(line) - len(line.lstrip())]
6 | 
7 | def inject_locals_call(offset):
8 |     return offset + 'print(post_process_output(locals()))\n'
9 | 
10 | def get_post_processing_function():
11 |     return """
12 | def post_process_output(local_context):
13 |     keys_to_ignore = [
14 |         '__name__',
15 |         '__doc__',
16 |         '__package__',
17 |         '__loader__',
18 |         '__spec__',
19 |         '__annotations__',
20 |         '__builtins__',
21 |         '__file__',
22 |         '__cached__'
23 |     ]
24 |     return {
25 |         k: v for k, v in local_context.items()
26 |         if k not in keys_to_ignore
27 |     }
28 | 
29 | """
30 | 
31 | def inject(file_name):
32 |     post_processing_function = get_post_processing_function()
33 |     new_python_code = [line + "\n" for line in post_processing_function.split('\n')]
34 |     with open(file_name, "r") as f:
35 |         python_code = f.readlines()
36 |     for line in python_code:
37 |         if line.lstrip() and not (
38 |             # a locals dump injected directly before these would produce a syntax error
39 |             line.lstrip().startswith(("def", "class", "elif", "else", "except", "finally", "@"))
40 |         ):
41 |             offset = get_indentation(line)
42 |             new_line = inject_locals_call(offset)
43 |             new_python_code.append(new_line)
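        # The locals dump is emitted *before* the original line, so each printed
        # snapshot reflects the program state just prior to executing that line.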
44 |         new_python_code.append(line)
45 |     new_python_code.append("print(post_process_output(locals()))\n")  # final snapshot of the module-level state
46 | 
47 |     new_file_name = os.path.splitext(file_name)[0] + "_injected.py"
48 |     with open(new_file_name, "w") as f:
49 |         new_code = "".join(new_python_code)
50 |         f.write(new_code)
51 | 
52 | 
53 | if __name__ == '__main__':
54 |     file_name = sys.argv[1]
55 |     inject(file_name)
56 | 
--------------------------------------------------------------------------------
/trace_data/process_folder.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | from inject_locals import inject
4 | 
5 | folder = sys.argv[1]
6 | for root, dirs, files in os.walk(folder):
7 |     for name in files:
8 |         if name.endswith(".py"):  # only instrument Python sources
9 |             inject(os.path.join(root, name))
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CarperAI/CodeReviewSE/ae4f34d056084a235bab77340987e5c622fa9fac/utils/__init__.py
--------------------------------------------------------------------------------
/utils/parser/build_parser.py:
--------------------------------------------------------------------------------
1 | import subprocess
2 | from tree_sitter import Language, Parser
3 | import logging
4 | 
5 | logging.basicConfig(level=logging.INFO)
6 | 
7 | def build_parser_in_tmp_dir(lang: str):
8 |     subprocess.call("mkdir -p tmp", shell=True)
9 |     # note: `cd tmp` in a subshell would be a no-op, so paths below include tmp/ explicitly
10 |     subprocess.call(f"git clone https://github.com/tree-sitter/tree-sitter-{lang} tmp/tree-sitter-{lang}", shell=True)
11 |     Language.build_library('build/my-languages.so', [f'./tmp/tree-sitter-{lang}'])
12 |     logging.info("Successfully built the parser")
13 |     subprocess.call("rm -rf tmp/", shell=True)
14 | 
15 | 
16 | 
17 | def load_parser(lang: str):
18 |     """
19 |     Function to load a parser given its language identifier.
20 |     """
21 |     language = Language('./build/my-languages.so', lang)
22 |     parser = Parser()
23 |     parser.set_language(language)
24 |     return parser
25 | 
26 | 
27 | def check_parseability(parser: Parser, code: str):
28 |     """
29 |     Function to check whether a code snippet is parseable by the given parser.
30 |     Returns True if the code string is parseable, False otherwise.
31 |     """
32 |     tree = parser.parse(bytes(code, "utf-8"))
33 |     # tree-sitter puts an ERROR node at the top of the tree when parsing fails
34 |     if tree.root_node.children[0].type == "ERROR":
35 |         return False
36 |     return True
37 | 
38 | 
39 | 
40 | if __name__ == "__main__":
41 |     build_parser_in_tmp_dir("javascript")
42 |     # parser = load_parser("javascript")
43 |     # print(check_parseability(parser, "var 1 = 1;"))
--------------------------------------------------------------------------------
/utils/parser/lang.json:
--------------------------------------------------------------------------------
1 | {
2 |     "lang": [
3 |         "python",
4 |         "java",
5 |         "javascript"
6 |     ]
7 | }
--------------------------------------------------------------------------------
/xml_to_json.py:
--------------------------------------------------------------------------------
1 | import xmltodict
2 | import json
3 | 
4 | src = ["Badges", "Comments", "PostHistory", "PostLinks", "Posts", "Tags", "Users", "Votes"]
5 | 
6 | for s in src:
7 |     with open(s+".xml") as xml_file:
8 |         data_dict = xmltodict.parse(xml_file.read())
9 | 
10 |     json_data = json.dumps(data_dict)
11 | 
12 |     with open(s+".json", "w") as json_file:
13 |         json_file.write(json_data)
14 | 
--------------------------------------------------------------------------------