├── .gitignore
├── LICENSE
├── README.md
├── arguments.py
├── models
│   ├── __init__.py
│   ├── data_utils
│   │   ├── __init__.py
│   │   └── data_utils.py
│   ├── model.py
│   ├── model_utils
│   │   ├── __init__.py
│   │   ├── logger.py
│   │   └── supervisor.py
│   └── modules
│       ├── __init__.py
│       └── mlp.py
├── plot_sample_extraction.py
└── run.py

/.gitignore:
--------------------------------------------------------------------------------
1 | *~
2 | *.DS_Store
3 | /.idea/
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2021 Xinyun Chen
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # PlotCoder: Hierarchical Decoding for Synthesizing Visualization Code in Programmatic Context
2 | 
3 | This repo provides the code to replicate the experiments in the [paper](https://aclanthology.org/2021.acl-long.169/):
4 | 
5 | > Xinyun Chen, Linyuan Gong, Alvin Cheung, Dawn Song, PlotCoder: Hierarchical Decoding for Synthesizing Visualization Code in Programmatic Context, in ACL 2021.
6 | 
7 | ## Prerequisites
8 | 
9 | Download the JuiCe dataset [here](https://github.com/rajasagashe/juice).
10 | 
11 | The code runs with Python 3 and PyTorch 0.4.1.
12 | 
13 | ## Data preprocessing
14 | 
15 | To extract the plot generation samples from the entire JuiCe dataset, run ``plot_sample_extraction.py``.
16 | 
17 | Note that for model training and evaluation, we may further filter out some of the samples extracted here. However, we keep these copies, which contain the maximal number of plot samples we can extract, so that we never need to enumerate the entire JuiCe dataset again.
18 | 
19 | ### Key arguments
20 | 
21 | Below we list some important arguments for data preprocessing:
22 | * `--data_folder`: path to the directory that stores the data.
23 | * `--prep_train_data_name`: filename of the plot generation samples for training. Note that it does not include all plot samples in the original JuiCe training split: some of them are merged into the hard dev/test splits. To build the training set, make sure that this filename is not `None`.
24 | * `--prep_dev_data_name`, `--prep_test_data_name`: filenames of the plot generation samples filtered from the original dev/test splits of JuiCe. To preprocess a split, make sure that the corresponding filename is not `None`.
25 | * `--prep_dev_hard_data_name`, `--prep_test_hard_data_name`: filenames of the homework and exam solutions extracted from the original training split of JuiCe. These are the larger-scale evaluation sets.
26 | * `--build_vocab`: set it to `True` to build the vocabularies of natural language words and code tokens.
27 | 
28 | ## Run experiments
29 | 
30 | 1. To run the hierarchical model:
31 | 
32 | `python run.py --nl --use_comments --code_context --nl_code_linking --copy_mechanism --hierarchy --target_code_transform`
33 | 
34 | 2. To run the non-hierarchical model with the copy mechanism:
35 | 
36 | `python run.py --nl --use_comments --code_context --nl_code_linking --copy_mechanism --target_code_transform`
37 | 
38 | 3. To run the LSTM decoder without the copy mechanism (i.e., one-hot encoding for data items), while preserving the NL correspondence in the input code sequence:
39 | 
40 | `python run.py --nl --use_comments --code_context --nl_code_linking --target_code_transform`
41 | 
42 | 4. To run the LSTM decoder with the standard copy mechanism, without preserving the NL correspondence:
43 | 
44 | `python run.py --nl --use_comments --code_context --copy_mechanism --target_code_transform`
45 | 
46 | 5. To run the LSTM decoder without the copy mechanism (i.e., one-hot encoding for data items), as in prior work:
47 | 
48 | `python run.py --nl --use_comments --code_context --target_code_transform`
49 | 
50 | ### Key arguments
51 | Below we list some important arguments for running the neural models:
52 | * `--nl`: include the preceding natural language cell as the model input. Note that the current code does not support including natural language from multiple cells, since NL instructions that describe previous code cells rather than the current one would mostly confuse the model.
53 | * `--use_comments`: include the comments in the current code cell as the model input.
54 | * `--code_context`: include the code context as the model input.
55 | * `--target_code_transform`: standardize the target code sequence into a more canonical form.
56 | * `--max_num_code_cells`: the number of code cells included as the code context. Default: `2`. Note that setting it to `0` is not equivalent to disabling the code context, because the input still includes: (1) the code in the current cell that precedes the plotting snippet; and (2) the code context that introduces the data frames and their attributes.
57 | * `--nl_code_linking`: if a code token appears in the NL, concatenate the code token embedding with the corresponding NL word embedding.
58 | * `--copy_mechanism`: use the copy mechanism for the decoder.
59 | * `--hierarchy`: use the hierarchical decoder for code generation.
60 | * `--load_model`: path to the trained model (not required when training from scratch).
61 | * `--eval`: add this flag at test time, and remember to set `--load_model` to the checkpoint you want to evaluate (see the example workflow below).
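
### Example workflow

For reference, a typical end-to-end run chains preprocessing, training, and evaluation. The commands below are only a sketch: the preprocessed filenames and the checkpoint path are placeholders, and the preprocessing outputs should match the dataset paths that `run.py` expects (`--train_dataset`, `--dev_dataset`, `--test_dataset`, which default to `../data/*.json` in `arguments.py`).

```
# Extract plot samples from JuiCe and build the vocabularies
# (see plot_sample_extraction.py for the exact flag syntax).
python plot_sample_extraction.py --data_folder ../data \
    --prep_train_data_name train_plot.json \
    --prep_dev_data_name dev_plot.json \
    --prep_test_data_name test_plot.json \
    --build_vocab

# Train the full hierarchical model.
python run.py --nl --use_comments --code_context --nl_code_linking \
    --copy_mechanism --hierarchy --target_code_transform

# Evaluate a trained checkpoint, reusing the same model flags as training.
python run.py --nl --use_comments --code_context --nl_code_linking \
    --copy_mechanism --hierarchy --target_code_transform \
    --eval --load_model <path_to_checkpoint>
```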
62 | 63 | ## Citation 64 | 65 | If you use the code in this repo, please cite the following paper: 66 | 67 | ``` 68 | @inproceedings{chen-2021-plotcoder, 69 | title={PlotCoder: Hierarchical Decoding for Synthesizing Visualization Code in Programmatic Context}, 70 | author={Chen, Xinyun and 71 | Gong, Linyuan and 72 | Cheung, Alvin and 73 | Song, Dawn}, 74 | booktitle={Proceedings of the 59th Annual Meeting of the Association for Computational Linguistics and the 11th International Joint Conference on Natural Language Processing (Volume 1: Long Papers)}, 75 | year={2021} 76 | } 77 | ``` -------------------------------------------------------------------------------- /arguments.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import time 3 | import os 4 | import sys 5 | 6 | def get_arg_parser(title): 7 | parser = argparse.ArgumentParser(description=title) 8 | parser.add_argument('--cpu', action='store_true', default=False) 9 | parser.add_argument('--eval', action='store_true') 10 | parser.add_argument('--model_dir', type=str, default='../checkpoints/model_0') 11 | parser.add_argument('--load_model', type=str, default=None) 12 | parser.add_argument('--num_LSTM_layers', type=int, default=2) 13 | parser.add_argument('--num_MLP_layers', type=int, default=1) 14 | parser.add_argument('--LSTM_hidden_size', type=int, default=512) 15 | parser.add_argument('--MLP_hidden_size', type=int, default=512) 16 | parser.add_argument('--embedding_size', type=int, default=512) 17 | 18 | parser.add_argument('--keep_last_n', type=int, default=None) 19 | parser.add_argument('--eval_every_n', type=int, default=1500) 20 | parser.add_argument('--log_interval', type=int, default=1500) 21 | parser.add_argument('--log_dir', type=str, default='../logs') 22 | parser.add_argument('--log_name', type=str, default='model_0.csv') 23 | 24 | parser.add_argument('--max_eval_size', type=int, default=1000) 25 | 26 | data_group = parser.add_argument_group('data') 27 | data_group.add_argument('--train_dataset', type=str, default='../data/train_plot.json') 28 | data_group.add_argument('--dev_dataset', type=str, default='../data/dev_plot_hard.json') 29 | data_group.add_argument('--test_dataset', type=str, default='../data/test_plot_hard.json') 30 | data_group.add_argument('--code_vocab', type=str, default='../data/code_vocab.json') 31 | data_group.add_argument('--word_vocab', type=str, default='../data/nl_vocab.json') 32 | data_group.add_argument('--word_vocab_size', type=int, default=None) 33 | data_group.add_argument('--code_vocab_size', type=int, default=None) 34 | data_group.add_argument('--num_plot_types', type=int, default=6) 35 | data_group.add_argument('--joint_plot_types', action='store_true', default=False) 36 | data_group.add_argument('--data_order_invariant', action='store_true', default=False) 37 | data_group.add_argument('--nl', action='store_true', default=False) 38 | data_group.add_argument('--use_comments', action='store_true', default=False) 39 | data_group.add_argument('--code_context', action='store_true', default=False) 40 | data_group.add_argument('--local_df_only', action='store_true', default=False) 41 | data_group.add_argument('--target_code_transform', action='store_true', default=False) 42 | data_group.add_argument('--max_num_code_cells', type=int, default=2) 43 | data_group.add_argument('--max_word_len', type=int, default=512) 44 | data_group.add_argument('--max_code_context_len', type=int, default=512) 45 | 
data_group.add_argument('--max_decode_len', type=int, default=200) 46 | 47 | model_group = parser.add_argument_group('model') 48 | model_group.add_argument('--hierarchy', action='store_true', default=False) 49 | model_group.add_argument('--copy_mechanism', action='store_true', default=False) 50 | model_group.add_argument('--nl_code_linking', action='store_true', default=False) 51 | 52 | train_group = parser.add_argument_group('train') 53 | train_group.add_argument('--optimizer', type=str, default='adam', choices=['adam', 'sgd', 'rmsprop']) 54 | train_group.add_argument('--lr', type=float, default=1e-3) 55 | train_group.add_argument('--lr_decay_steps', type=int, default=6000) 56 | train_group.add_argument('--lr_decay_rate', type=float, default=0.9) 57 | train_group.add_argument('--dropout_rate', type=float, default=0.2) 58 | train_group.add_argument('--gradient_clip', type=float, default=5.0) 59 | train_group.add_argument('--num_epochs', type=int, default=50) 60 | train_group.add_argument('--batch_size', type=int, default=32) 61 | train_group.add_argument('--param_init', type=float, default=0.1) 62 | train_group.add_argument('--seed', type=int, default=None) 63 | 64 | return parser -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jungyhuk/plotcoder/4c5fe923dc69227c58d93f55b8a89fd8bb960703/models/__init__.py -------------------------------------------------------------------------------- /models/data_utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jungyhuk/plotcoder/4c5fe923dc69227c58d93f55b8a89fd8bb960703/models/data_utils/__init__.py -------------------------------------------------------------------------------- /models/data_utils/data_utils.py: -------------------------------------------------------------------------------- 1 | """Data utils. 
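Preprocessing utilities for PlotCoder: loading the word/code vocabularies,
labeling plot types, extracting data frames, variables, and strings from the
code context, canonicalizing the target plotting calls, and batching samples
into tensors.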
2 | """ 3 | 4 | import argparse 5 | import collections 6 | import json 7 | import numpy as np 8 | import os 9 | import re 10 | import string 11 | import sys 12 | import random 13 | import enum 14 | import six 15 | import copy 16 | from six.moves import map 17 | from six.moves import range 18 | from six.moves import zip 19 | import ast 20 | import ast2json 21 | 22 | import torch 23 | from torch.autograd import Variable 24 | 25 | # Special vocabulary symbols 26 | _PAD = b"_PAD" 27 | _EOS = b"_EOS" 28 | _GO = b"_GO" 29 | _UNK = b"_UNK" 30 | _DF = b"_DF" 31 | _VAR = b"_VAR" 32 | _STR = b"_STR" 33 | _FUNC = b"_FUNC" 34 | _VALUE = b"_VALUE" 35 | _START_VOCAB = [_PAD, _EOS, _GO, _UNK, _DF, _VAR, _STR, _FUNC, _VALUE] 36 | 37 | PAD_ID = 0 38 | EOS_ID = 1 39 | GO_ID = 2 40 | UNK_ID = 3 41 | DF_ID = 4 42 | VAR_ID = 5 43 | STR_ID = 6 44 | FUNC_ID = 7 45 | VALUE_ID = 8 46 | 47 | def np_to_tensor(inp, output_type, cuda_flag): 48 | if output_type == 'float': 49 | inp_tensor = Variable(torch.FloatTensor(inp)) 50 | elif output_type == 'int': 51 | inp_tensor = Variable(torch.LongTensor(inp)) 52 | else: 53 | print('undefined tensor type') 54 | if cuda_flag: 55 | inp_tensor = inp_tensor.cuda() 56 | return inp_tensor 57 | 58 | class DataProcessor(object): 59 | def __init__(self, args): 60 | self.word_vocab = json.load(open(args.word_vocab, 'r')) 61 | self.code_vocab = json.load(open(args.code_vocab, 'r')) 62 | self.word_vocab_list = _START_VOCAB[:] 63 | self.code_vocab_list = _START_VOCAB[:] 64 | self.vocab_offset = len(_START_VOCAB) 65 | for word in self.word_vocab: 66 | while self.word_vocab[word] + self.vocab_offset >= len(self.word_vocab_list): 67 | self.word_vocab_list.append(word) 68 | self.word_vocab_list[self.word_vocab[word] + self.vocab_offset] = word 69 | for word in self.code_vocab: 70 | while self.code_vocab[word] + self.vocab_offset >= len(self.code_vocab_list): 71 | self.code_vocab_list.append(word) 72 | self.code_vocab_list[self.code_vocab[word] + self.vocab_offset] = word 73 | self.word_vocab_size = len(self.word_vocab) + self.vocab_offset 74 | self.code_vocab_size = len(self.code_vocab) + self.vocab_offset 75 | self.cuda_flag = args.cuda 76 | self.nl = args.nl 77 | self.code_context = args.code_context 78 | self.use_comments = args.use_comments 79 | self.local_df_only = args.local_df_only 80 | self.target_code_transform = args.target_code_transform 81 | self.hierarchy = args.hierarchy 82 | self.copy_mechanism = args.copy_mechanism 83 | self.max_num_code_cells = args.max_num_code_cells 84 | self.max_word_len = args.max_word_len 85 | self.max_code_context_len = args.max_code_context_len 86 | self.max_decode_len = args.max_decode_len 87 | self.scatter_word_list = ['scatter', 'scatterplot', "'scatter'", '"scatter"', 'scatter_kws', "'o'", "'bo'", "'r+'", '"o"', '"bo"', '"r+"'] 88 | self.hist_word_list = ['hist', "'hist'", '"hist"', 'bar', "'bar'", '"bar"', 'countplot', 'barplot'] 89 | self.pie_word_list = ['pie', "'pie'", '"pie"'] 90 | self.scatter_plot_word_list = ['lmplot', 'regplot'] 91 | self.hist_plot_word_list = ['distplot', 'kdeplot', 'contour'] 92 | self.normal_plot_word_list = ['plot'] 93 | self.reserved_words = ['plt', 'sns'] 94 | self.reserved_words += self.scatter_word_list + self.hist_word_list + self.pie_word_list + \ 95 | self.scatter_plot_word_list + self.hist_plot_word_list + self.normal_plot_word_list 96 | for word in self.code_vocab: 97 | if self.code_vocab[word] < 1000 and word != 'subplot' and word[-4:] == 'plot' and not ('_' in word) and not (word in self.reserved_words): 98 
| self.reserved_words.append(word) 99 | self.default_program_mask = [0] * len(self.code_vocab) 100 | for word in self.reserved_words: 101 | if word in self.code_vocab: 102 | self.default_program_mask[self.code_vocab[word]] = 1 103 | 104 | def label_extraction(self, code_seq): 105 | label = -1 106 | for word in self.scatter_word_list: 107 | if word in code_seq: 108 | label = 1 109 | break 110 | for word in self.hist_word_list: 111 | if word in code_seq: 112 | if label != -1: 113 | return -1 114 | label = 2 115 | break 116 | for word in self.pie_word_list: 117 | if word in code_seq: 118 | if label != -1: 119 | return -1 120 | label = 3 121 | break 122 | for word in self.scatter_plot_word_list: 123 | if word in code_seq: 124 | label = 4 125 | break 126 | for word in self.hist_plot_word_list: 127 | if word in code_seq: 128 | label = 5 129 | break 130 | for word in self.normal_plot_word_list: 131 | if word in code_seq: 132 | if 'scatter' in code_seq: 133 | label = 4 134 | elif 'hist' in code_seq or 'bar' in code_seq or 'countplot' in code_seq or 'barplot' in code_seq: 135 | label = 5 136 | if label == -1: 137 | label = 0 138 | return label 139 | 140 | def data_extraction(self, target_code_seq, reserved_dfs, reserved_strs, reserved_vars): 141 | target_dfs = [] 142 | target_strs = [] 143 | target_vars = [] 144 | 145 | for i, tok in enumerate(target_code_seq): 146 | if i < len(target_code_seq) - 1 and target_code_seq[i + 1] == '=': 147 | continue 148 | if tok in reserved_dfs: 149 | target_dfs.append(tok) 150 | elif len(tok) > 2 and tok[0] in ["'", '"'] and tok[-1] in ["'", '"']: 151 | if tok[1:-1] in reserved_strs: 152 | target_strs.append(tok) 153 | elif i > 0 and i < len(target_code_seq) - 1 and target_code_seq[i - 1] == '[' and target_code_seq[i + 1] == ']': 154 | if i >= 3 and target_code_seq[i - 3] == '.': 155 | continue 156 | target_strs.append(tok) 157 | elif i >= 2 and target_code_seq[i - 1] == '=' and target_code_seq[i - 2] in ['x', 'y', 'data']: 158 | target_strs.append(tok) 159 | else: 160 | if tok in self.reserved_words: 161 | continue 162 | if tok in _START_VOCAB: 163 | continue 164 | if tok[0].isdigit() or tok[0] == '-' or '.' in tok: 165 | continue 166 | if tok in ['[', ']', '(', ')', '{', '}', 'ax', '=', '\n']: 167 | continue 168 | if i >= 2 and target_code_seq[i - 1] == '=' and target_code_seq[i - 2] not in ['x', 'y', 'data']: 169 | continue 170 | if i < len(target_code_seq) - 1 and target_code_seq[i + 1] == '[': 171 | if i > 0 and target_code_seq[i - 1] == '.': 172 | continue 173 | target_dfs.append(tok) 174 | reserved_dfs.append(tok) 175 | continue 176 | if i == len(target_code_seq) - 1 or target_code_seq[i + 1] not in ['.', ',', ')', ']']: 177 | continue 178 | if i < len(target_code_seq) - 2 and target_code_seq[i + 1] == '.' 
and target_code_seq[i + 2] in self.reserved_words and target_code_seq[i + 2] not in ['hist', 'pie']: 179 | continue 180 | target_vars.append(tok) 181 | return target_dfs, target_strs, target_vars, reserved_dfs 182 | 183 | 184 | def ids_to_prog(self, sample, ids): 185 | reserved_word_list = sample['reserved_dfs'] + sample['reserved_vars'] 186 | for tok in sample['reserved_strs']: 187 | reserved_word_list.append("'" + tok + "'") 188 | prog = [] 189 | for i in ids: 190 | if i < self.code_vocab_size: 191 | prog += [self.code_vocab_list[i]] 192 | else: 193 | prog += [reserved_word_list[i - self.code_vocab_size]] 194 | if i == EOS_ID: 195 | break 196 | return prog 197 | 198 | 199 | def get_joint_plot_type(self, init_label): 200 | if init_label in [0, 3]: 201 | return init_label 202 | elif init_label in [1, 4]: 203 | return 1 204 | else: 205 | return 2 206 | 207 | 208 | def load_data(self, filename): 209 | init_samples = json.load(open(filename, 'r')) 210 | samples = [] 211 | for sample in init_samples: 212 | code_seq = sample['code_tokens'] 213 | label = self.label_extraction(code_seq) 214 | samples.append(sample) 215 | return samples 216 | 217 | 218 | def ast_to_seq(self, ast): 219 | seq = [] 220 | if ast['_type'] == 'Str': 221 | seq.append("'" + ast['s'] + "'") 222 | return seq 223 | if ast['_type'] == 'Name': 224 | seq.append(ast['id']) 225 | return seq 226 | if ast['_type'] == 'Attribute': 227 | if 'id' not in ast['value']: 228 | return seq 229 | df = ast['value']['id'] 230 | attr = ast['attr'] 231 | seq.append(df) 232 | seq.append('.') 233 | seq.append(attr) 234 | return seq 235 | if ast['_type'] == 'Subscript': 236 | if 'id' not in ast['value']: 237 | return seq 238 | df = ast['value']['id'] 239 | if 'value' not in ast['slice']: 240 | seq.append(df) 241 | return seq 242 | attr = ast['slice']['value'] 243 | attr = self.ast_to_seq(attr) 244 | if len(attr) == 0: 245 | return seq 246 | seq.append(df) 247 | seq.append('[') 248 | seq += attr 249 | seq.append(']') 250 | return seq 251 | return [] 252 | 253 | def var_extraction(self, code, reserved_vars, reserved_dfs): 254 | cur_code = [] 255 | for tok_idx, tok in enumerate(code): 256 | cur_code.append(tok) 257 | if tok == '\n': 258 | parse_error = False 259 | try: 260 | ast_tree = ast2json.ast2json(ast.parse(''.join(cur_code))) 261 | except: 262 | parse_error = True 263 | if parse_error: 264 | continue 265 | if len(ast_tree['body']) == 0: 266 | cur_code = [] 267 | continue 268 | ast_tree = ast_tree['body'][0] 269 | if ast_tree['_type'] != 'Assign': 270 | cur_code = [] 271 | continue 272 | var_list = ast_tree['targets'] 273 | for var in var_list: 274 | if var['_type'] != 'Name': 275 | continue 276 | var_name = var['id'] 277 | if var_name not in reserved_vars + reserved_dfs: 278 | reserved_vars.append(var_name) 279 | cur_code = [] 280 | return reserved_vars 281 | 282 | def code_seq_transform(self, init_code_seq, reserved_dfs): 283 | code_seq = [] 284 | st = 0 285 | 286 | while st < len(init_code_seq): 287 | ed = st 288 | cur_code_seq = [] 289 | ast_tree = None 290 | while ed < len(init_code_seq) and init_code_seq[ed] != '\n': 291 | cur_code_seq.append(init_code_seq[ed]) 292 | ed += 1 293 | while ed <= len(init_code_seq): 294 | if ed < len(init_code_seq): 295 | cur_code_seq.append(init_code_seq[ed]) 296 | ed += 1 297 | parse_error = False 298 | try: 299 | ast_tree = ast2json.ast2json(ast.parse(''.join(cur_code_seq))) 300 | except: 301 | parse_error = True 302 | if not parse_error: 303 | break 304 | if ed == len(init_code_seq): 305 | break 306 | if 
ast_tree is None: 307 | st = ed 308 | continue 309 | if len(ast_tree['body']) == 0: 310 | st = ed 311 | continue 312 | ast_tree = ast_tree['body'][0]['value'] 313 | if ast_tree is None or 'func' not in ast_tree: 314 | st = ed 315 | continue 316 | func = ast_tree['func'] 317 | plot_type = self.label_extraction(cur_code_seq) 318 | if 'value' not in func or 'id' not in func['value']: 319 | st = ed 320 | continue 321 | func_name = func['value']['id'] 322 | if func_name == 'sns': 323 | code_seq.append('sns') 324 | else: 325 | code_seq.append('plt') 326 | code_seq.append('.') 327 | if func['attr'] != 'plot': 328 | code_seq.append(func['attr']) 329 | else: 330 | if plot_type in [1, 4]: 331 | code_seq.append('scatter') 332 | elif plot_type in [2, 5]: 333 | code_seq.append('hist') 334 | elif plot_type == 3: 335 | code_seq.append('pie') 336 | else: 337 | code_seq.append('plot') 338 | code_seq.append('(') 339 | 340 | data_value = None 341 | if 'keywords' in ast_tree: 342 | kvs = ast_tree['keywords'] 343 | for i in range(len(kvs)): 344 | kv = kvs[i] 345 | if kv['arg'] == 'data': 346 | data_value = kv['value'] 347 | break 348 | 349 | if data_value is None: 350 | if func_name == 'sns' and len(ast_tree['args']) > 2: 351 | data_value = ast_tree['args'][2] 352 | 353 | if data_value is not None: 354 | data_value = self.ast_to_seq(data_value) 355 | else: 356 | if func_name in reserved_dfs: 357 | data_value = [func_name] 358 | 359 | x_value = None 360 | if 'keywords' in ast_tree: 361 | kvs = ast_tree['keywords'] 362 | for i in range(len(kvs)): 363 | kv = kvs[i] 364 | if kv['arg'] == 'x': 365 | x_value = kv['value'] 366 | break 367 | 368 | if x_value is None: 369 | if len(ast_tree['args']) > 0: 370 | x_value = ast_tree['args'][0] 371 | 372 | if x_value is not None: 373 | x_value = self.ast_to_seq(x_value) 374 | if len(x_value) == 0 or len(x_value) == 1 and x_value[0][0] in ["'", '"']: 375 | x_value = None 376 | else: 377 | if data_value is not None: 378 | code_seq += data_value 379 | code_seq.append('[') 380 | code_seq += x_value 381 | if data_value is not None: 382 | code_seq.append(']') 383 | 384 | y_value = None 385 | if 'keywords' in ast_tree: 386 | kvs = ast_tree['keywords'] 387 | for i in range(len(kvs)): 388 | kv = kvs[i] 389 | if kv['arg'] == 'y': 390 | y_value = kv['value'] 391 | break 392 | 393 | if y_value is None: 394 | if plot_type in [0, 1, 4] and len(ast_tree['args']) > 1: 395 | y_value = ast_tree['args'][1] 396 | 397 | if y_value is not None: 398 | y_value = self.ast_to_seq(y_value) 399 | if len(y_value) == 0 or len(y_value) == 1 and y_value[0][0] in ["'", '"']: 400 | y_value = None 401 | else: 402 | if code_seq[-1] not in ['(', ',']: 403 | code_seq.append(',') 404 | if data_value is not None: 405 | code_seq += data_value 406 | code_seq.append('[') 407 | code_seq += y_value 408 | if data_value is not None: 409 | code_seq.append(']') 410 | 411 | if code_seq[-1] == '(' or x_value is None or plot_type in [0, 1, 4] and y_value is None: 412 | while len(code_seq) > 0 and code_seq[-1] != '\n': 413 | code_seq = code_seq[:-1] 414 | if len(code_seq) > 0 and code_seq[-1] != '\n': 415 | code_seq.append(')') 416 | code_seq.append('\n') 417 | 418 | st = ed 419 | return code_seq 420 | 421 | def preprocess(self, samples): 422 | data = [] 423 | indices = [] 424 | cnt_word = 0 425 | cnt_code = 0 426 | max_target_code_seq_len = 0 427 | min_target_code_seq_len = 512 428 | for sample_idx, sample in enumerate(samples): 429 | init_code_seq = sample['code_tokens'] 430 | api_seq = sample['api_sequence'] 431 | 432 | 
code_seq = [] 433 | for tok in init_code_seq: 434 | if len(tok) == 0 or tok[0] == '#': 435 | continue 436 | code_seq.append(tok) 437 | 438 | reserved_df_size = 0 439 | reserved_dfs = [] 440 | reserved_df_attr_list = [] 441 | reserved_str_size = 0 442 | reserved_strs = [] 443 | reserved_vars = [] 444 | reserved_var_size = 0 445 | 446 | code_context_cell_cnt = 0 447 | if self.local_df_only: 448 | max_num_code_cells = self.max_num_code_cells 449 | else: 450 | max_num_code_cells = len(sample['context']) 451 | for ctx_idx in range(len(sample['context'])): 452 | if code_context_cell_cnt == max_num_code_cells: 453 | break 454 | if not 'code_tokens' in sample['context'][ctx_idx]: 455 | continue 456 | cur_code_context = sample['context'][ctx_idx]['code_tokens'] 457 | if type(cur_code_context) != list: 458 | continue 459 | code_context_cell_cnt += 1 460 | for i in range(len(cur_code_context)): 461 | if cur_code_context[i] in self.reserved_words: 462 | continue 463 | if i > 0 and i < len(cur_code_context) - 2 and cur_code_context[i] == '[' and cur_code_context[i + 1][0] in ["'", '"'] and cur_code_context[i + 2] == ']': 464 | if cur_code_context[i - 1] in ['[', ']', '(', ')', '=', ',']: 465 | continue 466 | if cur_code_context[i - 1] not in reserved_dfs: 467 | reserved_dfs.append(cur_code_context[i - 1]) 468 | reserved_df_attr_list.append([]) 469 | if i >= 4 and cur_code_context[i] == 'read_csv' and cur_code_context[i - 1] == '.' and cur_code_context[i - 2] == 'pd' and cur_code_context[i - 3] == '=': 470 | if cur_code_context[i - 4] in ['[', ']', '(', ')', '=', ',']: 471 | continue 472 | if not (cur_code_context[i - 4] in reserved_dfs): 473 | reserved_dfs.append(cur_code_context[i - 4]) 474 | reserved_df_attr_list.append([]) 475 | if i >= 4 and cur_code_context[i] == 'DataFrame' and cur_code_context[i - 1] == '.' and cur_code_context[i - 2] == 'pd' and cur_code_context[i - 3] == '=': 476 | if cur_code_context[i - 4] in ['[', ']', '(', ')', '=', ',']: 477 | continue 478 | if not (cur_code_context[i - 4] in reserved_dfs): 479 | reserved_dfs.append(cur_code_context[i - 4]) 480 | reserved_df_attr_list.append([]) 481 | if i >= 4 and cur_code_context[i] == 'DataReader' and cur_code_context[i - 1] == '.' and cur_code_context[i - 2] == 'data' and cur_code_context[i - 3] == '=': 482 | if cur_code_context[i - 4] in ['[', ']', '(', ')', '=', ',']: 483 | continue 484 | if not (cur_code_context[i - 4] in reserved_dfs): 485 | reserved_dfs.append(cur_code_context[i - 4]) 486 | reserved_df_attr_list.append([]) 487 | if i >= 2 and i < len(cur_code_context) - 2 and cur_code_context[i] == 'head' and cur_code_context[i - 1] == '.' \ 488 | and cur_code_context[i + 1] == '(' and cur_code_context[i + 2] == ')': 489 | if cur_code_context[i - 2] in ['[', ']', '(', ')', '=', ',']: 490 | continue 491 | if not (cur_code_context[i - 2] in reserved_dfs): 492 | reserved_dfs.append(cur_code_context[i - 2]) 493 | reserved_df_attr_list.append([]) 494 | if i >= 4 and cur_code_context[i] == 'load' and cur_code_context[i - 1] == '.' 
and cur_code_context[i - 2] == 'np' and cur_code_context[i - 3] == '=': 495 | if cur_code_context[i - 4] in ['[', ']', '(', ')', '=', ',']: 496 | continue 497 | if not (cur_code_context[i - 4] in reserved_dfs): 498 | reserved_dfs.append(cur_code_context[i - 4]) 499 | reserved_df_attr_list.append([]) 500 | 501 | code_context = [] 502 | code_context_cell_cnt = 0 503 | for ctx_idx in range(len(sample['context'])): 504 | if code_context_cell_cnt == max_num_code_cells: 505 | break 506 | if not 'code_tokens' in sample['context'][ctx_idx]: 507 | continue 508 | init_cur_code_context = sample['context'][ctx_idx]['code_tokens'] 509 | if type(init_cur_code_context) != list: 510 | continue 511 | cur_code_context = [] 512 | for tok in init_cur_code_context: 513 | if len(tok) == 0 or tok[0] == '#': 514 | continue 515 | cur_code_context.append(tok) 516 | selected_code_context = [] 517 | i = 0 518 | while i < len(cur_code_context): 519 | if cur_code_context[i] in reserved_dfs: 520 | df_idx = reserved_dfs.index(cur_code_context[i]) 521 | selected = False 522 | csv_reading = False 523 | st = i - 1 524 | while st >= 0 and cur_code_context[st] != '\n': 525 | if cur_code_context[st] == 'read_csv': 526 | csv_reading = True 527 | st -= 1 528 | ed = i + 1 529 | while ed < len(cur_code_context) and cur_code_context[ed] != '\n': 530 | if cur_code_context[ed] == 'read_csv': 531 | csv_reading = True 532 | ed += 1 533 | while ed < len(cur_code_context): 534 | if cur_code_context[ed] == 'read_csv': 535 | csv_reading = True 536 | 537 | parse_error = False 538 | try: 539 | ast_tree = ast2json.ast2json(ast.parse(''.join(cur_code_context[st + 1:ed + 1]))) 540 | except: 541 | parse_error = True 542 | if not parse_error: 543 | break 544 | ed += 1 545 | if csv_reading: 546 | i = ed + 1 547 | continue 548 | for tok_idx in range(st + 1, ed): 549 | if cur_code_context[tok_idx] in self.reserved_words: 550 | continue 551 | if len(cur_code_context[tok_idx]) > 2 and cur_code_context[tok_idx][0] in ["'", '"'] and cur_code_context[tok_idx][-1] in ["'", '"'] and not (cur_code_context[tok_idx][1:-1] in reserved_df_attr_list[df_idx]): 552 | if not ('.csv' in cur_code_context[tok_idx] or cur_code_context[tok_idx - 1] == '='): 553 | reserved_df_attr_list[df_idx].append(cur_code_context[tok_idx][1:-1]) 554 | if cur_code_context[tok_idx][1:-1] not in reserved_strs: 555 | reserved_strs.append(cur_code_context[tok_idx][1:-1]) 556 | reserved_str_size += 1 557 | selected = True 558 | if selected: 559 | if len(selected_code_context) > 0 and selected_code_context[-1] != '\n': 560 | selected_code_context += ['\n'] 561 | selected_code_context = selected_code_context + cur_code_context[st + 1: ed + 1] 562 | i = ed + 1 563 | continue 564 | 565 | if cur_code_context[i] == 'savez': 566 | st = i - 1 567 | while st >= 0 and cur_code_context[st] != '\n': 568 | st -= 1 569 | ed = i + 1 570 | while ed < len(cur_code_context) and cur_code_context[ed] != '\n': 571 | ed += 1 572 | while ed < len(cur_code_context): 573 | parse_error = False 574 | try: 575 | ast_tree = ast2json.ast2json(ast.parse(''.join(cur_code_context[st + 1:ed + 1]))) 576 | except: 577 | parse_error = True 578 | if not parse_error: 579 | break 580 | ed += 1 581 | 582 | if 'body' not in ast_tree: 583 | i += 1 584 | continue 585 | 586 | ast_tree = ast_tree['body'][0]['value'] 587 | if 'keywords' in ast_tree: 588 | kvs = ast_tree['keywords'] 589 | for i in range(len(kvs)): 590 | kv = kvs[i] 591 | if kv['arg'] not in reserved_strs: 592 | reserved_strs.append(kv['arg']) 593 | reserved_str_size += 
1 594 | selected = True 595 | 596 | if selected: 597 | if len(selected_code_context) > 0 and selected_code_context[-1] != '\n': 598 | selected_code_context += ['\n'] 599 | selected_code_context = selected_code_context + cur_code_context[st + 1: ed + 1] 600 | i = ed + 1 601 | continue 602 | i += 1 603 | if code_context_cell_cnt < self.max_num_code_cells: 604 | if len(code_context) > 0 and len(cur_code_context) > 0 and cur_code_context[-1] != '\n': 605 | code_context = ['\n'] + code_context 606 | reserved_vars = self.var_extraction(cur_code_context, reserved_vars, reserved_dfs) 607 | code_context = cur_code_context + code_context 608 | code_context_cell_cnt += 1 609 | else: 610 | if len(code_context) > 0 and len(selected_code_context) > 0 and selected_code_context[-1] != '\n': 611 | code_context = ['\n'] + code_context 612 | code_context = selected_code_context + code_context 613 | reserved_vars = self.var_extraction(selected_code_context, reserved_vars, reserved_dfs) 614 | 615 | keyword_pos = code_seq.index('plt') 616 | while code_seq[keyword_pos + 1] != '.': 617 | keyword_pos = keyword_pos + 1 + code_seq[keyword_pos + 1:].index('plt') 618 | if 'sns' in code_seq: 619 | keyword_pos = min(keyword_pos, code_seq.index('sns')) 620 | if len(code_context) > 0 and code_context[-1] != '\n': 621 | code_context += ['\n'] 622 | code_context += code_seq[:keyword_pos] 623 | 624 | i = 0 625 | while i < keyword_pos: 626 | if code_seq[i] in reserved_dfs: 627 | df_idx = reserved_dfs.index(code_seq[i]) 628 | csv_reading = False 629 | st = i - 1 630 | while st >= 0 and code_seq[st] != '\n': 631 | if code_seq[st] == 'read_csv': 632 | csv_reading = True 633 | st -= 1 634 | ed = i + 1 635 | while ed < keyword_pos and code_seq[ed] != '\n': 636 | if code_seq[ed] == 'read_csv': 637 | csv_reading = True 638 | ed += 1 639 | while ed < keyword_pos: 640 | if code_seq[ed] == 'read_csv': 641 | csv_reading = True 642 | 643 | parse_error = False 644 | try: 645 | ast_tree = ast2json.ast2json(ast.parse(''.join(code_seq[st + 1:ed + 1]))) 646 | except: 647 | parse_error = True 648 | if not parse_error: 649 | break 650 | ed += 1 651 | if csv_reading: 652 | i = ed + 1 653 | continue 654 | for tok_idx in range(st + 1, ed): 655 | if code_seq[tok_idx] in self.reserved_words: 656 | continue 657 | if len(code_seq[tok_idx]) > 2 and code_seq[tok_idx][0] in ["'", '"'] and code_seq[tok_idx][-1] in ["'", '"'] and not (code_seq[tok_idx][1:-1] in reserved_df_attr_list[df_idx]): 658 | if not ('.csv' in code_seq[tok_idx] or code_seq[tok_idx - 1] == '='): 659 | reserved_df_attr_list[df_idx].append(code_seq[tok_idx][1:-1]) 660 | if code_seq[tok_idx][1:-1] not in reserved_strs: 661 | reserved_strs.append(code_seq[tok_idx][1:-1]) 662 | reserved_str_size += 1 663 | i = ed + 1 664 | continue 665 | 666 | if code_seq[i] == 'savez': 667 | st = i - 1 668 | while st >= 0 and code_seq[st] != '\n': 669 | st -= 1 670 | ed = i + 1 671 | while ed < keyword_pos and code_seq[ed] != '\n': 672 | ed += 1 673 | while ed < keyword_pos: 674 | parse_error = False 675 | try: 676 | ast_tree = ast2json.ast2json(ast.parse(''.join(code_seq[st + 1:ed + 1]))) 677 | except: 678 | parse_error = True 679 | if not parse_error: 680 | break 681 | ed += 1 682 | 683 | if 'body' not in ast_tree: 684 | i += 1 685 | continue 686 | 687 | ast_tree = ast_tree['body'][0]['value'] 688 | 689 | if 'keywords' in ast_tree: 690 | kvs = ast_tree['keywords'] 691 | for i in range(len(kvs)): 692 | kv = kvs[i] 693 | if kv['arg'] not in reserved_strs: 694 | reserved_strs.append(kv['arg']) 695 | 
reserved_str_size += 1 696 | i = ed + 1 697 | continue 698 | i += 1 699 | 700 | code_seq = code_seq[keyword_pos:] 701 | target_code_seq = [] 702 | selected_code_idx = 0 703 | code_idx = 0 704 | while code_idx < len(code_seq): 705 | tok = code_seq[code_idx] 706 | if not (tok in self.reserved_words): 707 | code_idx += 1 708 | continue 709 | if code_idx == len(code_seq) - 1 or code_seq[code_idx + 1] != '(': 710 | code_idx += 1 711 | continue 712 | st_idx = code_idx - 1 713 | while st_idx >= 0 and code_seq[st_idx] != '\n': 714 | st_idx -= 1 715 | ed_idx = code_idx + 2 716 | include_function_calls = False 717 | while ed_idx < len(code_seq) and code_seq[ed_idx] != ')': 718 | if code_seq[ed_idx] == '(': 719 | include_function_calls = True 720 | break 721 | ed_idx += 1 722 | if include_function_calls: 723 | code_idx += 1 724 | continue 725 | while ed_idx < len(code_seq) and code_seq[ed_idx] != '\n': 726 | ed_idx += 1 727 | target_code_seq += code_seq[st_idx + 1: ed_idx + 1] 728 | code_context += code_seq[selected_code_idx:st_idx + 1] 729 | 730 | i = selected_code_idx 731 | while i <= st_idx: 732 | if code_seq[i] in reserved_dfs: 733 | df_idx = reserved_dfs.index(code_seq[i]) 734 | csv_reading = False 735 | st = i - 1 736 | while st >= selected_code_idx and code_seq[st] != '\n': 737 | if code_seq[st] == 'read_csv': 738 | csv_reading = True 739 | st -= 1 740 | ed = i + 1 741 | while ed <= st_idx and code_seq[ed] != '\n': 742 | if code_seq[ed] == 'read_csv': 743 | csv_reading = True 744 | ed += 1 745 | while ed <= st_idx: 746 | if code_seq[ed] == 'read_csv': 747 | csv_reading = True 748 | 749 | parse_error = False 750 | try: 751 | ast_tree = ast2json.ast2json(ast.parse(''.join(code_seq[st + 1:ed + 1]))) 752 | except: 753 | parse_error = True 754 | if not parse_error: 755 | break 756 | ed += 1 757 | if csv_reading: 758 | i = ed + 1 759 | continue 760 | for tok_idx in range(st + 1, ed): 761 | if code_seq[tok_idx] in self.reserved_words: 762 | continue 763 | if len(code_seq[tok_idx]) > 2 and code_seq[tok_idx][0] in ["'", '"'] and code_seq[tok_idx][-1] in ["'", '"'] and not (code_seq[tok_idx][1:-1] in reserved_df_attr_list[df_idx]): 764 | if not ('.csv' in code_seq[tok_idx] or code_seq[tok_idx - 1] == '='): 765 | reserved_df_attr_list[df_idx].append(code_seq[tok_idx][1:-1]) 766 | if code_seq[tok_idx][1:-1] not in reserved_strs: 767 | reserved_strs.append(code_seq[tok_idx][1:-1]) 768 | reserved_str_size += 1 769 | i = ed + 1 770 | continue 771 | if code_seq[i] == 'savez': 772 | st = i - 1 773 | while st >= selected_code_idx and code_seq[st] != '\n': 774 | st -= 1 775 | ed = i + 1 776 | while ed <= st_idx and code_seq[ed] != '\n': 777 | ed += 1 778 | while ed <= st_idx: 779 | parse_error = False 780 | try: 781 | ast_tree = ast2json.ast2json(ast.parse(''.join(code_seq[st + 1:ed + 1]))) 782 | except: 783 | parse_error = True 784 | if not parse_error: 785 | break 786 | ed += 1 787 | 788 | if 'body' not in ast_tree: 789 | i += 1 790 | continue 791 | ast_tree = ast_tree['body'][0]['value'] 792 | if 'keywords' in ast_tree: 793 | kvs = ast_tree['keywords'] 794 | for i in range(len(kvs)): 795 | kv = kvs[i] 796 | if kv['arg'] not in reserved_strs: 797 | reserved_strs.append(kv['arg']) 798 | reserved_str_size += 1 799 | i = ed + 1 800 | continue 801 | i += 1 802 | 803 | selected_code_idx = ed_idx + 1 804 | code_idx = ed_idx + 1 805 | 806 | init_target_code_seq = target_code_seq[:] 807 | target_code_seq = self.code_seq_transform(target_code_seq, reserved_dfs) 808 | 809 | label = self.label_extraction(target_code_seq) 
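# Samples whose plot type is ambiguous (label == -1) or whose transformed
# target snippet is too short are skipped; when --target_code_transform is
# off, the untransformed target sequence is restored below.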
810 | if label == -1: 811 | continue 812 | 813 | if len(target_code_seq) <= 5: 814 | continue 815 | 816 | if not self.target_code_transform: 817 | target_code_seq = init_target_code_seq[:] 818 | 819 | max_target_code_seq_len = max(max_target_code_seq_len, len(target_code_seq)) 820 | min_target_code_seq_len = min(min_target_code_seq_len, len(target_code_seq)) 821 | 822 | input_word_seq = [] 823 | 824 | if self.nl and not self.use_comments: 825 | nl = sample['nl'] 826 | nl = nl[:self.max_word_len - 1] 827 | elif self.use_comments and not self.nl: 828 | nl = sample['comments'] 829 | nl = nl[:self.max_word_len - 1] 830 | elif not self.nl and not self.use_comments: 831 | nl = [] 832 | else: 833 | nl = sample['nl'] + sample['comments'] 834 | if len(nl) > self.max_word_len - 1: 835 | if len(sample['comments']) <= self.max_word_len // 2: 836 | nl = sample['nl'][:self.max_word_len - 1 - len(sample['comments'])] + sample['comments'] 837 | elif len(sample['nl']) <= self.max_word_len // 2: 838 | nl = sample['nl'] + sample['comments'][:self.max_word_len - 1 - len(sample['nl'])] 839 | else: 840 | nl = sample['nl'][:self.max_word_len // 2 - 1] + sample['comments'][:self.max_word_len // 2] 841 | 842 | if not self.code_context: 843 | code_context = [] 844 | if len(code_context) > self.max_code_context_len - 1: 845 | code_context = code_context[1 - self.max_code_context_len:] 846 | 847 | target_dfs, target_strs, target_vars, reserved_dfs = self.data_extraction(target_code_seq, reserved_dfs, reserved_strs, reserved_vars) 848 | 849 | init_reserved_dfs = list(reserved_dfs) 850 | for tok in init_reserved_dfs: 851 | if not (tok in code_context): 852 | reserved_dfs.remove(tok) 853 | reserved_df_size = len(reserved_dfs) 854 | 855 | init_reserved_strs = list(reserved_strs) 856 | for tok in init_reserved_strs: 857 | if not ('"' + tok + '"' in code_context or "'" + tok + "'" in code_context or tok in code_context): 858 | reserved_strs.remove(tok) 859 | reserved_str_size = len(reserved_strs) 860 | 861 | init_reserved_vars = list(reserved_vars) 862 | for tok in init_reserved_vars: 863 | if not (tok in code_context): 864 | reserved_vars.remove(tok) 865 | reserved_var_size = len(reserved_vars) 866 | 867 | input_code_seq = [] 868 | input_code_nl_indices = [] 869 | input_code_df_seq = [] 870 | input_code_var_seq = [] 871 | input_code_str_seq = [] 872 | 873 | for i in range(len(code_context)): 874 | tok = code_context[i] 875 | input_code_nl_indices.append([]) 876 | 877 | if tok in _START_VOCAB: 878 | input_code_seq.append(_START_VOCAB.index(tok)) 879 | input_code_df_seq.append(-1) 880 | input_code_var_seq.append(-1) 881 | input_code_str_seq.append(-1) 882 | continue 883 | 884 | nl_lower = [tok.lower() for tok in nl] 885 | if tok.lower() in nl_lower: 886 | input_code_nl_indices[-1].append(nl_lower.index(tok.lower())) 887 | elif ("'" + tok.lower() + "'") in nl_lower: 888 | input_code_nl_indices[-1].append(nl_lower.index("'" + tok.lower() + "'")) 889 | elif ('"' + tok.lower() + '"') in nl_lower: 890 | input_code_nl_indices[-1].append(nl_lower.index('"' + tok.lower() + '"')) 891 | elif tok[0] in ["'", '"'] and tok[-1] in ["'", '"'] and tok[1:-1].lower() in nl_lower: 892 | input_code_nl_indices[-1].append(nl_lower.index(tok[1:-1].lower())) 893 | elif '_' in tok.lower(): 894 | if tok[0] in ["'", '"'] and tok[-1] in ["'", '"']: 895 | tok_list = tok[1:-1].split('_') 896 | else: 897 | tok_list = tok.split('_') 898 | for sub_tok in tok_list: 899 | if sub_tok.lower() in nl_lower: 900 | 
input_code_nl_indices[-1].append(nl_lower.index(sub_tok.lower())) 901 | if len(input_code_nl_indices[-1]) > 2: 902 | input_code_nl_indices[-1] = input_code_nl_indices[-1][:2] 903 | elif len(input_code_nl_indices[-1]) < 2: 904 | input_code_nl_indices[-1] = input_code_nl_indices[-1] + [len(nl)] * (2 - len(input_code_nl_indices[-1])) 905 | 906 | if tok in self.code_vocab: 907 | input_code_seq.append(self.code_vocab[tok] + self.vocab_offset) 908 | 909 | if tok in reserved_dfs: 910 | input_code_df_seq.append(reserved_dfs.index(tok)) 911 | else: 912 | input_code_df_seq.append(-1) 913 | if tok in reserved_vars: 914 | input_code_var_seq.append(reserved_vars.index(tok)) 915 | else: 916 | input_code_var_seq.append(-1) 917 | if len(tok) > 2 and tok[0] in ["'", '"'] and tok[-1] in ["'", '"'] and tok[1:-1] in reserved_strs: 918 | input_code_str_seq.append(reserved_strs.index(tok[1:-1])) 919 | else: 920 | input_code_str_seq.append(-1) 921 | elif tok in reserved_dfs: 922 | input_code_seq.append(DF_ID) 923 | input_code_df_seq.append(reserved_dfs.index(tok)) 924 | input_code_var_seq.append(-1) 925 | input_code_str_seq.append(-1) 926 | elif tok in reserved_vars: 927 | input_code_seq.append(VAR_ID) 928 | input_code_df_seq.append(-1) 929 | input_code_var_seq.append(reserved_vars.index(tok)) 930 | input_code_str_seq.append(-1) 931 | elif tok[-1] in ["'", '"']: 932 | input_code_seq.append(STR_ID) 933 | if len(tok) > 2 and tok[0] in ["'", '"'] and tok[-1] in ["'", '"'] and tok[1:-1] in reserved_strs: 934 | input_code_str_seq.append(reserved_strs.index(tok[1:-1])) 935 | else: 936 | input_code_str_seq.append(-1) 937 | input_code_df_seq.append(-1) 938 | input_code_var_seq.append(-1) 939 | elif tok[0].isdigit() or tok[0] == '-' or '.' in tok: 940 | input_code_seq.append(VALUE_ID) 941 | input_code_df_seq.append(-1) 942 | input_code_var_seq.append(-1) 943 | input_code_str_seq.append(-1) 944 | elif i < len(code_context) - 1 and code_context[i + 1] == '(': 945 | input_code_seq.append(FUNC_ID) 946 | input_code_df_seq.append(-1) 947 | input_code_var_seq.append(-1) 948 | input_code_str_seq.append(-1) 949 | else: 950 | reserved_vars.append(tok) 951 | reserved_var_size += 1 952 | input_code_seq.append(VAR_ID) 953 | input_code_df_seq.append(-1) 954 | input_code_var_seq.append(reserved_vars.index(tok)) 955 | input_code_str_seq.append(-1) 956 | 957 | if not self.copy_mechanism: 958 | for i in range(len(input_code_seq)): 959 | if input_code_seq[i] == DF_ID and input_code_df_seq[i] != -1: 960 | input_code_seq[i] = self.code_vocab_size + input_code_df_seq[i] 961 | elif input_code_seq[i] == VAR_ID and input_code_var_seq[i] != -1: 962 | input_code_seq[i] = self.code_vocab_size + reserved_df_size + input_code_var_seq[i] 963 | elif input_code_seq[i] == STR_ID and input_code_str_seq[i] != -1: 964 | input_code_seq[i] = self.code_vocab_size + reserved_df_size + reserved_var_size + input_code_str_seq[i] 965 | 966 | for word in nl: 967 | if word in self.word_vocab: 968 | input_word_seq.append(self.word_vocab[word] + self.vocab_offset) 969 | elif word in _START_VOCAB: 970 | input_word_seq.append(_START_VOCAB.index(word)) 971 | elif word in reserved_vars: 972 | input_word_seq.append(VAR_ID) 973 | elif word in reserved_dfs: 974 | input_word_seq.append(DF_ID) 975 | elif word in reserved_strs: 976 | str_idx = reserved_strs.index(word) 977 | input_word_seq.append(STR_ID) 978 | elif word[0] in ["'", '"'] and word[-1] in ["'", '"']: 979 | tok = word[1:-1] 980 | if tok in reserved_vars: 981 | input_word_seq.append(VAR_ID) 982 | elif tok in 
reserved_dfs: 983 | input_word_seq.append(DF_ID) 984 | else: 985 | input_word_seq.append(UNK_ID) 986 | else: 987 | input_word_seq.append(UNK_ID) 988 | 989 | output_code_seq = [] 990 | output_code_df_seq = [] 991 | output_code_var_seq = [] 992 | output_code_str_seq = [] 993 | output_gt = [] 994 | 995 | for i, tok in enumerate(target_code_seq): 996 | if tok in self.code_vocab: 997 | output_code_seq.append(self.code_vocab[tok] + self.vocab_offset) 998 | if tok in reserved_dfs: 999 | if self.hierarchy: 1000 | output_code_seq[-1] = DF_ID 1001 | output_code_df_seq.append(self.code_vocab_size + reserved_dfs.index(tok)) 1002 | else: 1003 | output_code_df_seq.append(-1) 1004 | 1005 | if tok in target_vars and tok in code_context: 1006 | if self.hierarchy: 1007 | output_code_seq[-1] = VAR_ID 1008 | output_code_var_seq.append(self.code_vocab[tok] + self.vocab_offset) 1009 | else: 1010 | output_code_var_seq.append(-1) 1011 | 1012 | if len(tok) > 2 and tok[0] in ["'", '"'] and tok[-1] in ["'", '"'] and tok[1:-1] in reserved_strs: 1013 | if self.hierarchy: 1014 | output_code_seq[-1] = STR_ID 1015 | output_code_str_seq.append(self.code_vocab_size + reserved_df_size + reserved_var_size + reserved_strs.index(tok[1:-1])) 1016 | elif tok in code_context and not (tok in reserved_dfs + reserved_vars + reserved_strs + self.reserved_words + sample['imports']) and tok[-1] in ['"', '"']: 1017 | output_code_str_seq.append(self.code_vocab[tok] + self.vocab_offset) 1018 | else: 1019 | output_code_str_seq.append(-1) 1020 | elif tok in _START_VOCAB: 1021 | output_code_seq.append(_START_VOCAB.index(tok)) 1022 | output_code_df_seq.append(-1) 1023 | output_code_var_seq.append(-1) 1024 | output_code_str_seq.append(-1) 1025 | elif tok in reserved_dfs: 1026 | df_idx = reserved_dfs.index(tok) 1027 | if self.hierarchy: 1028 | output_code_seq.append(DF_ID) 1029 | else: 1030 | output_code_seq.append(self.code_vocab_size + df_idx) 1031 | output_code_df_seq.append(self.code_vocab_size + df_idx) 1032 | output_code_var_seq.append(-1) 1033 | output_code_str_seq.append(-1) 1034 | elif tok in reserved_vars: 1035 | var_idx = reserved_vars.index(tok) 1036 | if self.hierarchy: 1037 | output_code_seq.append(VAR_ID) 1038 | else: 1039 | output_code_seq.append(self.code_vocab_size + reserved_df_size + var_idx) 1040 | output_code_df_seq.append(-1) 1041 | output_code_var_seq.append(self.code_vocab_size + reserved_df_size + var_idx) 1042 | output_code_str_seq.append(-1) 1043 | elif len(tok) > 2 and tok[0] in ["'", '"'] and tok[-1] in ["'", '"'] and tok[1:-1] in reserved_strs: 1044 | str_idx = reserved_strs.index(tok[1:-1]) 1045 | if self.hierarchy: 1046 | output_code_seq.append(STR_ID) 1047 | else: 1048 | output_code_seq.append(self.code_vocab_size + reserved_df_size + reserved_var_size + str_idx) 1049 | output_code_df_seq.append(-1) 1050 | output_code_var_seq.append(-1) 1051 | output_code_str_seq.append(self.code_vocab_size + reserved_df_size + reserved_var_size + str_idx) 1052 | elif tok[-1] in ["'", '"']: 1053 | output_code_seq.append(PAD_ID) 1054 | output_code_df_seq.append(-1) 1055 | output_code_var_seq.append(-1) 1056 | output_code_str_seq.append(-1) 1057 | elif tok[0].isdigit() or tok[0] == '-' or '.' 
in tok: 1058 | output_code_seq.append(PAD_ID) 1059 | output_code_df_seq.append(-1) 1060 | output_code_var_seq.append(-1) 1061 | output_code_str_seq.append(-1) 1062 | elif i < len(target_code_seq) - 1 and target_code_seq[i + 1] in ['(', '=']: 1063 | output_code_seq.append(PAD_ID) 1064 | output_code_df_seq.append(-1) 1065 | output_code_var_seq.append(-1) 1066 | output_code_str_seq.append(-1) 1067 | else: 1068 | output_code_seq.append(PAD_ID) 1069 | output_code_df_seq.append(-1) 1070 | output_code_var_seq.append(-1) 1071 | output_code_str_seq.append(-1) 1072 | if output_code_seq[-1] == DF_ID: 1073 | output_gt.append(output_code_df_seq[-1]) 1074 | elif output_code_seq[-1] == VAR_ID: 1075 | output_gt.append(output_code_var_seq[-1]) 1076 | elif output_code_seq[-1] == STR_ID: 1077 | output_gt.append(output_code_str_seq[-1]) 1078 | else: 1079 | output_gt.append(output_code_seq[-1]) 1080 | 1081 | input_word_seq += [EOS_ID] 1082 | input_code_seq += [EOS_ID] 1083 | output_code_seq += [EOS_ID] 1084 | output_gt += [EOS_ID] 1085 | output_code_df_seq += [-1] 1086 | output_code_var_seq += [-1] 1087 | output_code_str_seq += [-1] 1088 | output_code_mask = [1] * 3 + [0] * (self.vocab_offset - 3) 1089 | output_code_mask += list(self.default_program_mask) 1090 | 1091 | for tok in input_code_seq: 1092 | if tok < self.code_vocab_size: 1093 | output_code_mask[tok] = 1 1094 | 1095 | for tok in output_code_seq: 1096 | if tok < self.code_vocab_size: 1097 | output_code_mask[tok] = 1 1098 | 1099 | if not self.hierarchy: 1100 | output_code_mask += [1] * (reserved_df_size + reserved_var_size + reserved_str_size) 1101 | else: 1102 | output_code_mask += [0] * (reserved_df_size + reserved_var_size + reserved_str_size) 1103 | 1104 | output_df_mask = [0] * (self.code_vocab_size + reserved_df_size + reserved_var_size + reserved_str_size) 1105 | for df_idx in range(reserved_df_size): 1106 | output_df_mask[self.code_vocab_size + df_idx] = 1 1107 | if reserved_dfs[df_idx] in self.code_vocab: 1108 | output_df_mask[self.code_vocab[reserved_dfs[df_idx]] + self.vocab_offset] = 1 1109 | 1110 | output_var_mask = [0] * (self.code_vocab_size + reserved_df_size + reserved_var_size + reserved_str_size) 1111 | for var_idx in range(reserved_var_size): 1112 | output_var_mask[self.code_vocab_size + reserved_df_size + var_idx] = 1 1113 | for tok in code_context: 1114 | if tok in self.code_vocab and tok in target_vars: 1115 | output_var_mask[self.code_vocab[tok] + self.vocab_offset] = 1 1116 | if tok in self.code_vocab and not (tok in reserved_dfs + reserved_vars + reserved_strs + self.reserved_words + sample['imports']) and not (tok[-1] in ['"', '"']) and not (tok[0].isdigit() or tok[0] == '-' or '.' 
in tok) and (i == len(target_code_seq) - 1 or target_code_seq[i + 1] != '('): 1117 | output_var_mask[self.code_vocab[tok] + self.vocab_offset] = 1 1118 | 1119 | output_str_mask = [0] * (self.code_vocab_size + reserved_df_size + reserved_var_size + reserved_str_size) 1120 | for str_idx in range(reserved_str_size): 1121 | output_str_mask[self.code_vocab_size + reserved_df_size + reserved_var_size + str_idx] = 1 1122 | for tok in code_context: 1123 | if tok in self.code_vocab and not (tok in reserved_dfs + reserved_vars + reserved_strs + self.reserved_words + sample['imports']) and tok[-1] in ['"', '"']: 1124 | output_str_mask[self.code_vocab[tok] + self.vocab_offset] = 1 1125 | 1126 | for i in range(3, self.vocab_offset): 1127 | output_df_mask[i] = 0 1128 | output_var_mask[i] = 0 1129 | output_str_mask[i] = 0 1130 | if not self.hierarchy: 1131 | output_code_mask[i] = 0 1132 | 1133 | output_code_indices = [] 1134 | output_code_ctx_indices = [] 1135 | for tok in self.code_vocab_list: 1136 | if tok in code_context: 1137 | output_code_ctx_indices.append(code_context.index(tok)) 1138 | else: 1139 | output_code_ctx_indices.append(len(code_context)) 1140 | 1141 | output_code_nl_indices = [] 1142 | for tok in self.code_vocab_list: 1143 | output_code_nl_indices.append([]) 1144 | if tok in _START_VOCAB: 1145 | output_code_nl_indices[-1] += [len(nl), len(nl)] 1146 | elif tok in self.scatter_word_list: 1147 | nl_lower = [tok.lower() for tok in nl] 1148 | if 'scatter' in nl_lower: 1149 | output_code_nl_indices[-1] += [nl_lower.index('scatter'), len(nl)] 1150 | elif 'scatterplot' in nl_lower: 1151 | output_code_nl_indices[-1] += [nl_lower.index('scatterplot'), len(nl)] 1152 | else: 1153 | output_code_nl_indices[-1] += [len(nl), len(nl)] 1154 | elif tok in self.hist_word_list: 1155 | nl_lower = [tok.lower() for tok in nl] 1156 | if 'histogram' in nl_lower: 1157 | output_code_nl_indices[-1] += [nl_lower.index('histogram'), len(nl)] 1158 | elif 'histograms' in nl_lower: 1159 | output_code_nl_indices[-1] += [nl_lower.index('histograms'), len(nl)] 1160 | else: 1161 | output_code_nl_indices[-1] += [len(nl), len(nl)] 1162 | elif tok in self.pie_word_list: 1163 | nl_lower = [tok.lower() for tok in nl] 1164 | if 'pie' in nl_lower: 1165 | output_code_nl_indices[-1] += [nl_lower.index('pie'), len(nl)] 1166 | else: 1167 | output_code_nl_indices[-1] += [len(nl), len(nl)] 1168 | elif tok in self.scatter_plot_word_list: 1169 | nl_lower = [tok.lower() for tok in nl] 1170 | if 'scatter' in nl_lower: 1171 | output_code_nl_indices[-1] += [nl_lower.index('scatter')] 1172 | if 'line' in nl_lower: 1173 | output_code_nl_indices[-1] += [nl_lower.index('line')] 1174 | elif 'linear' in nl_lower: 1175 | output_code_nl_indices[-1] += [nl_lower.index('linear')] 1176 | if len(output_code_nl_indices[-1]) < 2: 1177 | output_code_nl_indices[-1] += [len(nl)] * (2 - len(output_code_nl_indices[-1])) 1178 | elif tok in self.hist_plot_word_list: 1179 | nl_lower = [tok.lower() for tok in nl] 1180 | if 'distribution' in nl_lower: 1181 | output_code_nl_indices[-1] += [nl_lower.index('distribution'), len(nl)] 1182 | else: 1183 | output_code_nl_indices[-1] += [len(nl), len(nl)] 1184 | elif tok in code_context: 1185 | output_code_nl_indices[-1] += input_code_nl_indices[code_context.index(tok)] 1186 | else: 1187 | output_code_nl_indices[-1] += [len(nl), len(nl)] 1188 | 1189 | 1190 | for tok in reserved_dfs + reserved_vars: 1191 | output_code_indices.append(code_context.index(tok)) 1192 | for tok in reserved_strs: 1193 | if "'" + tok + "'" in 
code_context: 1194 | output_code_indices.append(code_context.index("'" + tok + "'")) 1195 | elif '"' + tok + '"' in code_context: 1196 | output_code_indices.append(code_context.index('"' + tok + '"')) 1197 | else: 1198 | output_code_indices.append(code_context.index(tok)) 1199 | 1200 | cur_data = {} 1201 | cur_data['reserved_dfs'] = reserved_dfs 1202 | cur_data['reserved_vars'] = reserved_vars 1203 | cur_data['reserved_strs'] = reserved_strs 1204 | cur_data['target_dfs'] = target_dfs 1205 | cur_data['target_strs'] = target_strs 1206 | cur_data['target_vars'] = target_vars 1207 | cur_data['input_word_seq'] = input_word_seq 1208 | cur_data['input_code_seq'] = input_code_seq 1209 | cur_data['input_code_df_seq'] = input_code_df_seq 1210 | cur_data['input_code_var_seq'] = input_code_var_seq 1211 | cur_data['input_code_str_seq'] = input_code_str_seq 1212 | cur_data['input_code_nl_indices'] = input_code_nl_indices 1213 | cur_data['output_gt'] = output_gt 1214 | cur_data['output_code_seq'] = output_code_seq 1215 | cur_data['output_code_df_seq'] = output_code_df_seq 1216 | cur_data['output_code_var_seq'] = output_code_var_seq 1217 | cur_data['output_code_str_seq'] = output_code_str_seq 1218 | cur_data['output_code_mask'] = output_code_mask 1219 | cur_data['output_df_mask'] = output_df_mask 1220 | cur_data['output_var_mask'] = output_var_mask 1221 | cur_data['output_str_mask'] = output_str_mask 1222 | cur_data['output_code_nl_indices'] = output_code_nl_indices 1223 | cur_data['output_code_ctx_indices'] = output_code_ctx_indices 1224 | cur_data['output_code_indices'] = output_code_indices 1225 | cur_data['label'] = label 1226 | data.append(cur_data) 1227 | indices.append(sample_idx) 1228 | print('Number of samples (before preprocessing): ', len(samples)) 1229 | print('Number of samples (after filtering): ', len(data)) 1230 | print('code seq len: min: ', min_target_code_seq_len, 'max: ', max_target_code_seq_len) 1231 | return data, indices 1232 | 1233 | def get_batch(self, data, batch_size, start_idx): 1234 | data_size = len(data) 1235 | batch_vectors = [] 1236 | batch_labels = [] 1237 | batch_word_input = [] 1238 | batch_code_input = [] 1239 | batch_output_code_mask = [] 1240 | batch_output_df_mask = [] 1241 | batch_output_var_mask = [] 1242 | batch_output_str_mask = [] 1243 | batch_code_output = [] 1244 | batch_df_output = [] 1245 | batch_var_output = [] 1246 | batch_str_output = [] 1247 | batch_gt = [] 1248 | batch_input_code_nl_indices = [] 1249 | batch_output_code_nl_indices = [] 1250 | batch_output_code_ctx_indices = [] 1251 | input_dict = {} 1252 | max_word_len = 0 1253 | max_input_code_len = 0 1254 | max_output_code_len = 0 1255 | max_output_code_mask_len = 0 1256 | if not self.copy_mechanism: 1257 | max_output_code_mask_len = self.code_vocab_size + self.max_code_context_len 1258 | for idx in range(start_idx, min(start_idx + batch_size, data_size)): 1259 | cur_sample = data[idx] 1260 | 1261 | batch_word_input.append(cur_sample['input_word_seq']) 1262 | max_word_len = max(max_word_len, len(cur_sample['input_word_seq'])) 1263 | batch_code_input.append(cur_sample['input_code_seq']) 1264 | max_input_code_len = max(max_input_code_len, len(cur_sample['input_code_seq'])) 1265 | batch_output_code_mask.append(cur_sample['output_code_mask']) 1266 | batch_output_df_mask.append(cur_sample['output_df_mask']) 1267 | batch_output_var_mask.append(cur_sample['output_var_mask']) 1268 | batch_output_str_mask.append(cur_sample['output_str_mask']) 1269 | max_output_code_mask_len = max(max_output_code_mask_len, 
len(cur_sample['output_code_mask'])) 1270 | batch_gt.append(cur_sample['output_gt']) 1271 | batch_code_output.append(cur_sample['output_code_seq']) 1272 | max_output_code_len = max(max_output_code_len, len(cur_sample['output_code_seq'])) 1273 | batch_df_output.append(cur_sample['output_code_df_seq']) 1274 | batch_var_output.append(cur_sample['output_code_var_seq']) 1275 | batch_str_output.append(cur_sample['output_code_str_seq']) 1276 | batch_input_code_nl_indices.append(cur_sample['input_code_nl_indices']) 1277 | batch_output_code_nl_indices.append(cur_sample['output_code_nl_indices']) 1278 | batch_output_code_ctx_indices.append(cur_sample['output_code_ctx_indices']) 1279 | batch_labels.append(cur_sample['label']) 1280 | 1281 | batch_labels = np.array(batch_labels) 1282 | batch_labels = np_to_tensor(batch_labels, 'int', self.cuda_flag) 1283 | 1284 | for idx in range(len(batch_word_input)): 1285 | if len(batch_word_input[idx]) < max_word_len: 1286 | batch_word_input[idx] = batch_word_input[idx] + [PAD_ID] * (max_word_len - len(batch_word_input[idx])) 1287 | batch_word_input = np.array(batch_word_input) 1288 | batch_word_input = np_to_tensor(batch_word_input, 'int', self.cuda_flag) 1289 | input_dict['nl'] = batch_word_input 1290 | 1291 | for idx in range(len(batch_code_input)): 1292 | if len(batch_code_input[idx]) < max_input_code_len: 1293 | batch_code_input[idx] = batch_code_input[idx] + [PAD_ID] * (max_input_code_len - len(batch_code_input[idx])) 1294 | batch_code_input = np.array(batch_code_input) 1295 | batch_code_input = np_to_tensor(batch_code_input, 'int', self.cuda_flag) 1296 | input_dict['code_context'] = batch_code_input 1297 | 1298 | for idx in range(len(batch_code_output)): 1299 | if len(batch_code_output[idx]) < max_output_code_len: 1300 | batch_code_output[idx] = batch_code_output[idx] + [PAD_ID] * (max_output_code_len - len(batch_code_output[idx])) 1301 | batch_gt[idx] = batch_gt[idx] + [PAD_ID] * (max_output_code_len - len(batch_gt[idx])) 1302 | batch_df_output[idx] = batch_df_output[idx] + [-1] * (max_output_code_len - len(batch_df_output[idx])) 1303 | batch_var_output[idx] = batch_var_output[idx] + [-1] * (max_output_code_len - len(batch_var_output[idx])) 1304 | batch_str_output[idx] = batch_str_output[idx] + [-1] * (max_output_code_len - len(batch_str_output[idx])) 1305 | for idx in range(len(batch_output_code_mask)): 1306 | if len(batch_output_code_mask[idx]) < max_output_code_mask_len: 1307 | batch_output_code_mask[idx] = batch_output_code_mask[idx] + [0] * (max_output_code_mask_len - len(batch_output_code_mask[idx])) 1308 | batch_output_df_mask[idx] = batch_output_df_mask[idx] + [0] * (max_output_code_mask_len - len(batch_output_df_mask[idx])) 1309 | batch_output_var_mask[idx] = batch_output_var_mask[idx] + [0] * (max_output_code_mask_len - len(batch_output_var_mask[idx])) 1310 | batch_output_str_mask[idx] = batch_output_str_mask[idx] + [0] * (max_output_code_mask_len - len(batch_output_str_mask[idx])) 1311 | for idx in range(len(batch_input_code_nl_indices)): 1312 | if len(batch_input_code_nl_indices[idx]) < max_input_code_len: 1313 | batch_input_code_nl_indices[idx] = batch_input_code_nl_indices[idx] + [[max_word_len - 1, max_word_len - 1] for _ in range(max_input_code_len - len(batch_input_code_nl_indices[idx]))] 1314 | 1315 | batch_gt = np.array(batch_gt) 1316 | batch_gt = np_to_tensor(batch_gt, 'int', self.cuda_flag) 1317 | input_dict['gt'] = batch_gt 1318 | 1319 | batch_code_output = np.array(batch_code_output) 1320 | batch_code_output = 
np_to_tensor(batch_code_output, 'int', self.cuda_flag) 1321 | input_dict['code_output'] = batch_code_output 1322 | 1323 | batch_df_output = np.array(batch_df_output) 1324 | batch_df_output = np_to_tensor(batch_df_output, 'int', self.cuda_flag) 1325 | input_dict['df_output'] = batch_df_output 1326 | 1327 | batch_var_output = np.array(batch_var_output) 1328 | batch_var_output = np_to_tensor(batch_var_output, 'int', self.cuda_flag) 1329 | input_dict['var_output'] = batch_var_output 1330 | 1331 | batch_str_output = np.array(batch_str_output) 1332 | batch_str_output = np_to_tensor(batch_str_output, 'int', self.cuda_flag) 1333 | input_dict['str_output'] = batch_str_output 1334 | 1335 | batch_output_code_mask = np.array(batch_output_code_mask) 1336 | batch_output_code_mask = np_to_tensor(batch_output_code_mask, 'float', self.cuda_flag) 1337 | input_dict['code_output_mask'] = batch_output_code_mask 1338 | 1339 | batch_output_df_mask = np.array(batch_output_df_mask) 1340 | batch_output_df_mask = np_to_tensor(batch_output_df_mask, 'float', self.cuda_flag) 1341 | input_dict['output_df_mask'] = batch_output_df_mask 1342 | 1343 | batch_output_var_mask = np.array(batch_output_var_mask) 1344 | batch_output_var_mask = np_to_tensor(batch_output_var_mask, 'float', self.cuda_flag) 1345 | input_dict['output_var_mask'] = batch_output_var_mask 1346 | 1347 | batch_output_str_mask = np.array(batch_output_str_mask) 1348 | batch_output_str_mask = np_to_tensor(batch_output_str_mask, 'float', self.cuda_flag) 1349 | input_dict['output_str_mask'] = batch_output_str_mask 1350 | 1351 | batch_input_code_nl_indices = np.array(batch_input_code_nl_indices) 1352 | batch_input_code_nl_indices = np_to_tensor(batch_input_code_nl_indices, 'int', self.cuda_flag) 1353 | input_dict['input_code_nl_indices'] = batch_input_code_nl_indices 1354 | batch_output_code_nl_indices = np.array(batch_output_code_nl_indices) 1355 | batch_output_code_nl_indices = np_to_tensor(batch_output_code_nl_indices, 'int', self.cuda_flag) 1356 | input_dict['output_code_nl_indices'] = batch_output_code_nl_indices 1357 | batch_output_code_ctx_indices = np.array(batch_output_code_ctx_indices) 1358 | batch_output_code_ctx_indices = np_to_tensor(batch_output_code_ctx_indices, 'int', self.cuda_flag) 1359 | input_dict['output_code_ctx_indices'] = batch_output_code_ctx_indices 1360 | input_dict['init_data'] = data[start_idx: start_idx + batch_size] 1361 | return input_dict, batch_labels 1362 | 1363 | -------------------------------------------------------------------------------- /models/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.autograd import Variable 4 | from torch import cuda 5 | import torch.optim as optim 6 | from torch.nn.utils import clip_grad_norm 7 | import torch.nn.functional as F 8 | 9 | import numpy as np 10 | from .data_utils import data_utils 11 | from .modules import mlp 12 | 13 | class PlotCodeGenerator(nn.Module): 14 | def __init__(self, args, word_vocab, code_vocab): 15 | super(PlotCodeGenerator, self).__init__() 16 | self.cuda_flag = args.cuda 17 | self.word_vocab_size = args.word_vocab_size 18 | self.code_vocab_size = args.code_vocab_size 19 | self.num_plot_types = args.num_plot_types 20 | self.word_vocab = word_vocab 21 | self.code_vocab = code_vocab 22 | self.batch_size = args.batch_size 23 | self.embedding_size = args.embedding_size 24 | self.LSTM_hidden_size = args.LSTM_hidden_size 25 | self.MLP_hidden_size = args.MLP_hidden_size 26 | 
self.num_LSTM_layers = args.num_LSTM_layers 27 | self.num_MLP_layers = args.num_MLP_layers 28 | self.gradient_clip = args.gradient_clip 29 | self.lr = args.lr 30 | self.dropout_rate = args.dropout_rate 31 | self.nl = args.nl 32 | self.use_comments = args.use_comments 33 | self.code_context = args.code_context 34 | self.hierarchy = args.hierarchy 35 | self.copy_mechanism = args.copy_mechanism 36 | self.nl_code_linking = args.nl_code_linking 37 | self.max_word_len = args.max_word_len 38 | self.max_code_context_len = args.max_code_context_len 39 | self.max_decode_len = args.max_decode_len 40 | self.dropout = nn.Dropout(p=self.dropout_rate) 41 | 42 | self.word_embedding = nn.Embedding(self.word_vocab_size, self.embedding_size) 43 | if self.copy_mechanism: 44 | self.code_embedding = nn.Embedding(self.code_vocab_size, self.embedding_size) 45 | else: 46 | self.code_embedding = nn.Embedding(self.code_vocab_size + self.max_code_context_len, self.embedding_size) 47 | self.code_predictor = nn.Linear(self.embedding_size, self.code_vocab_size + self.max_code_context_len) 48 | self.copy_predictor = nn.Linear(self.embedding_size, self.code_vocab_size + self.max_code_context_len) 49 | self.input_nl_encoder = nn.LSTM(input_size=self.embedding_size, hidden_size=self.LSTM_hidden_size, num_layers=self.num_LSTM_layers, dropout=self.dropout_rate, 50 | batch_first=True, bidirectional=True) 51 | self.input_code_encoder = nn.LSTM(input_size=self.embedding_size, hidden_size=self.LSTM_hidden_size, num_layers=self.num_LSTM_layers, dropout=self.dropout_rate, 52 | batch_first=True, bidirectional=True) 53 | if self.hierarchy: 54 | self.decoder = nn.LSTM(input_size=self.embedding_size * 2, hidden_size=self.LSTM_hidden_size, num_layers=self.num_LSTM_layers, dropout=self.dropout_rate, 55 | batch_first=True, bidirectional=True) 56 | else: 57 | self.decoder = nn.LSTM(input_size=self.embedding_size, hidden_size=self.LSTM_hidden_size, num_layers=self.num_LSTM_layers, dropout=self.dropout_rate, 58 | batch_first=True, bidirectional=True) 59 | self.word_attention = nn.Linear(self.LSTM_hidden_size * 2, self.LSTM_hidden_size * 2) 60 | if not self.nl_code_linking: 61 | self.code_ctx_linear = nn.Linear(self.LSTM_hidden_size * 2 + self.embedding_size, self.embedding_size) 62 | else: 63 | self.code_ctx_word_linear = nn.Linear(self.LSTM_hidden_size * 4 + self.embedding_size, self.embedding_size) 64 | self.code_word_linear = nn.Linear(self.LSTM_hidden_size * 2 + self.embedding_size, self.embedding_size) 65 | self.encoder_code_attention_linear = nn.Linear(self.LSTM_hidden_size * 2, self.embedding_size) 66 | self.decoder_code_attention_linear = nn.Linear(self.LSTM_hidden_size * 2, self.embedding_size) 67 | self.decoder_copy_attention_linear = nn.Linear(self.LSTM_hidden_size * 2, self.embedding_size) 68 | self.encoder_copy_attention_linear = nn.Linear(self.LSTM_hidden_size * 2, self.embedding_size) 69 | self.target_embedding_linear = nn.Linear(self.LSTM_hidden_size * 2, self.embedding_size) 70 | 71 | # training 72 | self.loss = nn.CrossEntropyLoss() 73 | 74 | if args.optimizer == 'adam': 75 | self.optimizer = optim.Adam(self.parameters(), lr=self.lr) 76 | elif args.optimizer == 'sgd': 77 | self.optimizer = optim.SGD(self.parameters(), lr=self.lr) 78 | elif args.optimizer == 'rmsprop': 79 | self.optimizer = optim.RMSprop(self.parameters(), lr=self.lr) 80 | else: 81 | raise ValueError('optimizer undefined: ', args.optimizer) 82 | 83 | def init_weights(self, param_init): 84 | for param in self.parameters(): 85 | nn.init.uniform_(param, 
-param_init, param_init) 86 | 87 | def lr_decay(self, lr_decay_rate): 88 | self.lr *= lr_decay_rate 89 | for param_group in self.optimizer.param_groups: 90 | param_group['lr'] = self.lr 91 | 92 | def train_step(self): 93 | if self.gradient_clip > 0: 94 | clip_grad_norm(self.parameters(), self.gradient_clip) 95 | self.optimizer.step() 96 | 97 | 98 | def forward(self, batch_input, batch_labels, eval_flag=False): 99 | batch_size = batch_labels.size()[0] 100 | batch_init_data = batch_input['init_data'] 101 | batch_nl_input = batch_input['nl'] 102 | batch_nl_embedding = self.word_embedding(batch_nl_input) 103 | encoder_word_mask = (batch_nl_input == data_utils.PAD_ID).float() 104 | encoder_word_mask = torch.max(encoder_word_mask, (batch_nl_input == data_utils.UNK_ID).float()) 105 | encoder_word_mask = torch.max(encoder_word_mask, (batch_nl_input == data_utils.EOS_ID).float()) 106 | if self.cuda_flag: 107 | encoder_word_mask = encoder_word_mask.cuda() 108 | nl_encoder_output, nl_hidden_state = self.input_nl_encoder(batch_nl_embedding) 109 | decoder_hidden_state = nl_hidden_state 110 | 111 | batch_code_context_input = batch_input['code_context'] 112 | batch_code_context_embedding = self.code_embedding(batch_code_context_input) 113 | batch_code_nl_embedding = [] 114 | batch_input_code_nl_indices = batch_input['input_code_nl_indices'] 115 | max_code_len = batch_code_context_input.size()[1] 116 | max_word_len = batch_nl_input.size()[1] 117 | 118 | if self.nl_code_linking: 119 | for batch_idx in range(batch_size): 120 | input_code_nl_indices = batch_input_code_nl_indices[batch_idx, :, :] 121 | cur_code_nl_embedding_0 = nl_encoder_output[batch_idx, input_code_nl_indices[:, 0]] 122 | cur_code_nl_embedding_1 = nl_encoder_output[batch_idx, input_code_nl_indices[:, 1]] 123 | cur_code_nl_embedding = cur_code_nl_embedding_0 + cur_code_nl_embedding_1 124 | batch_code_nl_embedding.append(cur_code_nl_embedding) 125 | batch_code_nl_embedding = torch.stack(batch_code_nl_embedding, dim=0) 126 | code_encoder_input = torch.cat([batch_code_context_embedding, batch_code_nl_embedding], dim=-1) 127 | code_encoder_input = self.code_word_linear(code_encoder_input) 128 | else: 129 | code_encoder_input = batch_code_context_embedding 130 | 131 | encoder_code_mask = (batch_code_context_input == data_utils.PAD_ID).float() 132 | encoder_code_mask = torch.max(encoder_code_mask, (batch_code_context_input == data_utils.UNK_ID).float()) 133 | encoder_code_mask = torch.max(encoder_code_mask, (batch_code_context_input == data_utils.EOS_ID).float()) 134 | if self.cuda_flag: 135 | encoder_code_mask = encoder_code_mask.cuda() 136 | code_encoder_output, code_hidden_state = self.input_code_encoder(code_encoder_input) 137 | decoder_hidden_state = code_hidden_state 138 | 139 | gt_output = batch_input['gt'] 140 | target_code_output = batch_input['code_output'] 141 | target_df_output = batch_input['df_output'] 142 | target_var_output = batch_input['var_output'] 143 | target_str_output = batch_input['str_output'] 144 | code_output_mask = batch_input['code_output_mask'] 145 | output_df_mask = batch_input['output_df_mask'] 146 | output_var_mask = batch_input['output_var_mask'] 147 | output_str_mask = batch_input['output_str_mask'] 148 | 149 | gt_decode_length = target_code_output.size()[1] 150 | if not eval_flag: 151 | decode_length = gt_decode_length 152 | else: 153 | decode_length = self.max_decode_len 154 | 155 | decoder_input_sketch = torch.ones(batch_size, 1, dtype=torch.int64) * data_utils.GO_ID 156 | if self.cuda_flag: 157 | 
decoder_input_sketch = decoder_input_sketch.cuda() 158 | decoder_input_sketch_embedding = self.code_embedding(decoder_input_sketch) 159 | decoder_input = torch.ones(batch_size, 1, dtype=torch.int64) * data_utils.GO_ID 160 | if self.cuda_flag: 161 | decoder_input = decoder_input.cuda() 162 | decoder_input_embedding = self.code_embedding(decoder_input) 163 | 164 | finished = torch.zeros(batch_size, 1, dtype=torch.int64) 165 | 166 | max_code_mask_len = code_output_mask.size()[1] 167 | 168 | pad_mask = torch.zeros(max_code_mask_len) 169 | pad_mask[data_utils.PAD_ID] = 1e9 170 | pad_mask = torch.stack([pad_mask] * batch_size, dim=0) 171 | if self.cuda_flag: 172 | finished = finished.cuda() 173 | pad_mask = pad_mask.cuda() 174 | 175 | batch_code_output_indices = data_utils.np_to_tensor(np.array(list(range(self.code_vocab_size))), 'int', self.cuda_flag) 176 | batch_code_output_embedding = self.code_embedding(batch_code_output_indices) 177 | batch_code_output_embedding = torch.stack([batch_code_output_embedding] * batch_size, dim=0) 178 | 179 | batch_output_code_ctx_embedding = [] 180 | batch_output_code_ctx_indices = batch_input['output_code_ctx_indices'] 181 | for batch_idx in range(batch_size): 182 | output_code_ctx_indices = batch_output_code_ctx_indices[batch_idx] 183 | cur_output_code_ctx_embedding = code_encoder_output[batch_idx, output_code_ctx_indices] 184 | batch_output_code_ctx_embedding.append(cur_output_code_ctx_embedding) 185 | batch_output_code_ctx_embedding = torch.stack(batch_output_code_ctx_embedding, dim=0) 186 | 187 | if self.nl_code_linking: 188 | batch_output_code_nl_embedding = [] 189 | batch_output_code_nl_indices = batch_input['output_code_nl_indices'] 190 | for batch_idx in range(batch_size): 191 | output_code_nl_indices = batch_output_code_nl_indices[batch_idx, :, :] 192 | cur_output_code_nl_embedding_0 = nl_encoder_output[batch_idx, output_code_nl_indices[:, 0]] 193 | cur_output_code_nl_embedding_1 = nl_encoder_output[batch_idx, output_code_nl_indices[:, 1]] 194 | cur_output_code_nl_embedding = cur_output_code_nl_embedding_0 + cur_output_code_nl_embedding_1 195 | batch_output_code_nl_embedding.append(cur_output_code_nl_embedding) 196 | batch_output_code_nl_embedding = torch.stack(batch_output_code_nl_embedding, dim=0) 197 | batch_code_output_embedding = torch.cat([batch_code_output_embedding, batch_output_code_ctx_embedding, batch_output_code_nl_embedding], dim=-1) 198 | batch_code_output_embedding = self.code_ctx_word_linear(batch_code_output_embedding) 199 | else: 200 | batch_code_output_embedding = torch.cat([batch_code_output_embedding, batch_output_code_ctx_embedding], dim=-1) 201 | batch_code_output_embedding = self.code_ctx_linear(batch_code_output_embedding) 202 | 203 | if self.code_context: 204 | batch_code_output_context_embedding = [] 205 | 206 | for batch_idx in range(batch_size): 207 | output_code_indices = batch_init_data[batch_idx]['output_code_indices'] 208 | cur_code_output_context_embedding = [] 209 | for code_idx in output_code_indices: 210 | cur_code_output_context_embedding.append(code_encoder_output[batch_idx, code_idx, :]) 211 | if len(cur_code_output_context_embedding) < max_code_mask_len - self.code_vocab_size: 212 | cur_code_output_context_embedding += [data_utils.np_to_tensor(np.zeros(self.LSTM_hidden_size * 2), 'float', self.cuda_flag)] * (max_code_mask_len - self.code_vocab_size - len(cur_code_output_context_embedding)) 213 | cur_code_output_context_embedding = torch.stack(cur_code_output_context_embedding, dim=0) 214 | 
batch_code_output_context_embedding.append(cur_code_output_context_embedding) 215 | batch_code_output_context_embedding = torch.stack(batch_code_output_context_embedding, dim=0) 216 | batch_code_output_context_embedding = self.target_embedding_linear(batch_code_output_context_embedding) 217 | batch_code_output_embedding = torch.cat([batch_code_output_embedding, batch_code_output_context_embedding], dim=1) 218 | 219 | code_pred_logits = [] 220 | code_predictions = [] 221 | df_pred_logits = [] 222 | df_predictions = [] 223 | var_pred_logits = [] 224 | var_predictions = [] 225 | str_pred_logits = [] 226 | str_predictions = [] 227 | predictions = [] 228 | 229 | for step in range(decode_length): 230 | if self.hierarchy: 231 | decoder_output, decoder_hidden_state = self.decoder( 232 | torch.cat([decoder_input_sketch_embedding, decoder_input_embedding], dim=-1), decoder_hidden_state) 233 | else: 234 | decoder_output, decoder_hidden_state = self.decoder(decoder_input_embedding, decoder_hidden_state) 235 | decoder_output = decoder_output.squeeze(1) 236 | 237 | decoder_nl_attention = self.word_attention(decoder_output) 238 | attention_logits = torch.bmm(nl_encoder_output, decoder_nl_attention.unsqueeze(2)) 239 | attention_logits = attention_logits.squeeze(-1) 240 | attention_logits = attention_logits - encoder_word_mask * 1e9 241 | attention_weights = nn.Softmax(dim=-1)(attention_logits) 242 | attention_weights = self.dropout(attention_weights) 243 | nl_attention_vector = torch.bmm(torch.transpose(nl_encoder_output, 1, 2), attention_weights.unsqueeze(2)) 244 | nl_attention_vector = nl_attention_vector.squeeze(-1) 245 | 246 | input_code_encoding = self.encoder_code_attention_linear(nl_attention_vector) 247 | if self.hierarchy: 248 | input_copy_encoding = self.encoder_copy_attention_linear(nl_attention_vector) 249 | 250 | decoder_code_output = self.decoder_code_attention_linear(decoder_output) 251 | if self.hierarchy: 252 | decoder_copy_output = self.decoder_copy_attention_linear(decoder_output) 253 | 254 | decoder_code_output = decoder_code_output + input_code_encoding 255 | if self.hierarchy: 256 | decoder_copy_output = decoder_copy_output + input_copy_encoding 257 | 258 | if self.copy_mechanism: 259 | cur_code_pred_logits = torch.bmm(batch_code_output_embedding, decoder_code_output.unsqueeze(2)) 260 | cur_code_pred_logits = cur_code_pred_logits.squeeze(-1) 261 | else: 262 | cur_code_pred_logits = self.code_predictor(decoder_code_output) 263 | cur_code_pred_logits = cur_code_pred_logits + finished.float() * pad_mask 264 | cur_code_pred_logits = cur_code_pred_logits - (1.0 - code_output_mask) * 1e9 265 | cur_code_predictions = cur_code_pred_logits.max(1)[1] 266 | 267 | if eval_flag: 268 | sketch_predictions = cur_code_predictions 269 | else: 270 | sketch_predictions = target_code_output[:, step] 271 | 272 | if self.hierarchy: 273 | if self.copy_mechanism: 274 | cur_copy_pred_logits = torch.bmm(batch_code_output_embedding, decoder_copy_output.unsqueeze(2)) 275 | cur_copy_pred_logits = cur_copy_pred_logits.squeeze(-1) 276 | else: 277 | cur_copy_pred_logits = self.copy_predictor(decoder_copy_output) 278 | cur_df_pred_logits = cur_copy_pred_logits - (1.0 - output_df_mask) * 1e9 279 | cur_df_predictions = cur_df_pred_logits.max(1)[1] * ((sketch_predictions == data_utils.DF_ID).long()) 280 | 281 | cur_var_pred_logits = cur_copy_pred_logits - (1.0 - output_var_mask) * 1e9 282 | cur_var_predictions = cur_var_pred_logits.max(1)[1] * ((sketch_predictions == data_utils.VAR_ID).long()) 283 | 284 | 
cur_str_pred_logits = cur_copy_pred_logits - (1.0 - output_str_mask) * 1e9 285 | cur_str_predictions = cur_str_pred_logits.max(1)[1] * ((sketch_predictions == data_utils.STR_ID).long()) 286 | 287 | if eval_flag: 288 | decoder_input_sketch = cur_code_predictions 289 | decoder_input = cur_code_predictions 290 | if self.hierarchy: 291 | decoder_input = torch.max(decoder_input, cur_df_predictions) 292 | decoder_input = torch.max(decoder_input, cur_var_predictions) 293 | decoder_input = torch.max(decoder_input, cur_str_predictions) 294 | else: 295 | decoder_input_sketch = target_code_output[:, step] 296 | decoder_input = gt_output[:, step] 297 | if self.copy_mechanism: 298 | decoder_input_sketch_embedding = [] 299 | for batch_idx in range(batch_size): 300 | decoder_input_sketch_embedding.append(batch_code_output_embedding[batch_idx, decoder_input_sketch[batch_idx], :]) 301 | decoder_input_sketch_embedding = torch.stack(decoder_input_sketch_embedding, dim=0) 302 | 303 | decoder_input_embedding = [] 304 | for batch_idx in range(batch_size): 305 | decoder_input_embedding.append(batch_code_output_embedding[batch_idx, decoder_input[batch_idx], :]) 306 | decoder_input_embedding = torch.stack(decoder_input_embedding, dim=0) 307 | else: 308 | decoder_input_sketch_embedding = self.code_embedding(decoder_input_sketch) 309 | decoder_input_embedding = self.code_embedding(decoder_input) 310 | decoder_input_sketch_embedding = decoder_input_sketch_embedding.unsqueeze(1) 311 | decoder_input_embedding = decoder_input_embedding.unsqueeze(1) 312 | if step < gt_decode_length: 313 | code_pred_logits.append(cur_code_pred_logits) 314 | code_predictions.append(cur_code_predictions) 315 | cur_predictions = cur_code_predictions 316 | if self.hierarchy: 317 | if step < gt_decode_length: 318 | df_pred_logits.append(cur_df_pred_logits) 319 | var_pred_logits.append(cur_var_pred_logits) 320 | str_pred_logits.append(cur_str_pred_logits) 321 | df_predictions.append(cur_df_predictions) 322 | var_predictions.append(cur_var_predictions) 323 | str_predictions.append(cur_str_predictions) 324 | cur_predictions = torch.max(cur_predictions, cur_df_predictions) 325 | cur_predictions = torch.max(cur_predictions, cur_var_predictions) 326 | cur_predictions = torch.max(cur_predictions, cur_str_predictions) 327 | predictions.append(cur_predictions) 328 | 329 | cur_finished = (decoder_input == data_utils.EOS_ID).long().unsqueeze(1) 330 | finished = torch.max(finished, cur_finished) 331 | if torch.sum(finished) == batch_size and step >= gt_decode_length - 1: 332 | break 333 | 334 | total_loss = 0.0 335 | code_pred_logits = torch.stack(code_pred_logits, dim=0) 336 | code_pred_logits = code_pred_logits.permute(1, 2, 0) 337 | code_predictions = torch.stack(code_predictions, dim=0) 338 | code_predictions = code_predictions.permute(1, 0) 339 | 340 | total_loss += F.cross_entropy(code_pred_logits, target_code_output, ignore_index=data_utils.PAD_ID) 341 | 342 | if self.hierarchy: 343 | df_pred_logits = torch.stack(df_pred_logits, dim=0) 344 | df_pred_logits = df_pred_logits.permute(1, 2, 0) 345 | df_predictions = torch.stack(df_predictions, dim=0) 346 | df_predictions = df_predictions.permute(1, 0) 347 | df_loss = F.cross_entropy(df_pred_logits, target_df_output, ignore_index=-1) 348 | 349 | var_pred_logits = torch.stack(var_pred_logits, dim=0) 350 | var_pred_logits = var_pred_logits.permute(1, 2, 0) 351 | var_predictions = torch.stack(var_predictions, dim=0) 352 | var_predictions = var_predictions.permute(1, 0) 353 | var_loss = 
F.cross_entropy(var_pred_logits, target_var_output, ignore_index=-1) 354 | 355 | str_pred_logits = torch.stack(str_pred_logits, dim=0) 356 | str_pred_logits = str_pred_logits.permute(1, 2, 0) 357 | str_predictions = torch.stack(str_predictions, dim=0) 358 | str_predictions = str_predictions.permute(1, 0) 359 | str_loss = F.cross_entropy(str_pred_logits, target_str_output, ignore_index=-1) 360 | total_loss += (df_loss + var_loss + str_loss) / 3.0 361 | 362 | predictions = torch.stack(predictions, dim=0) 363 | predictions = predictions.permute(1, 0) 364 | return total_loss, code_pred_logits, predictions 365 | 366 | -------------------------------------------------------------------------------- /models/model_utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .logger import * 2 | from .supervisor import * -------------------------------------------------------------------------------- /models/model_utils/logger.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import argparse 3 | import sys 4 | import os 5 | import re 6 | import json 7 | import pandas as pd 8 | 9 | class Logger(object): 10 | """ 11 | The class for recording the training process. 12 | """ 13 | def __init__(self, args): 14 | self.log_interval = args.log_interval 15 | self.log_dir = args.log_dir 16 | self.log_name = os.path.join(args.log_dir, args.log_name) 17 | self.best_eval_acc = 0 18 | self.records = [] 19 | 20 | def write_summary(self, summary): 21 | print("global-step: %(global_step)d, train-acc: %(train_acc).3f, train-loss: %(train_loss).3f, eval-label-acc: %(eval_label_acc).3f, eval-data-acc: %(eval_data_acc).3f, eval-acc: %(eval_acc).3f, eval-loss: %(eval_loss).3f" % summary) 22 | self.records.append(summary) 23 | df = pd.DataFrame(self.records) 24 | if not os.path.exists(self.log_dir): 25 | os.makedirs(self.log_dir) 26 | df.to_csv(self.log_name, index=False) 27 | self.best_eval_acc = max(self.best_eval_acc, summary['eval_acc']) -------------------------------------------------------------------------------- /models/model_utils/supervisor.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import argparse 3 | import sys 4 | import os 5 | import torch 6 | import re 7 | import json 8 | import time 9 | 10 | from torch.nn.utils import clip_grad_norm 11 | 12 | from ..data_utils import data_utils 13 | 14 | CKPT_PATTERN = re.compile('^ckpt-(\d+)$') 15 | 16 | 17 | class Supervisor(object): 18 | """ 19 | The base class to manage the high-level model execution processes. The concrete classes for different applications are derived from it. 20 | """ 21 | def __init__(self, model, args): 22 | self.data_processor = data_utils.DataProcessor(args) 23 | self.model = model 24 | self.keep_last_n = args.keep_last_n 25 | self.global_step = 0 26 | self.batch_size = args.batch_size 27 | self.model_dir = args.model_dir 28 | 29 | 30 | def load_pretrained(self, load_model): 31 | print("Read model parameters from %s." 
% load_model) 32 | checkpoint = torch.load(load_model) 33 | self.model.load_state_dict(checkpoint) 34 | 35 | 36 | def save_model(self): 37 | if not os.path.exists(self.model_dir): 38 | os.makedirs(self.model_dir) 39 | global_step_padded = format(self.global_step, '08d') 40 | ckpt_name = 'ckpt-' + global_step_padded 41 | path = os.path.join(self.model_dir, ckpt_name) 42 | ckpt = self.model.state_dict() 43 | torch.save(ckpt, path) 44 | 45 | if self.keep_last_n is not None: 46 | ckpts = [] 47 | for file_name in os.listdir(self.model_dir): 48 | matched_name = CKPT_PATTERN.match(file_name) 49 | if matched_name is None or file_name == ckpt_name: 50 | continue 51 | step = int(matched_name.group(1)) 52 | ckpts.append((step, file_name)) 53 | if len(ckpts) > self.keep_last_n: 54 | ckpts.sort() 55 | os.unlink(os.path.join(self.model_dir, ckpts[0][1])) 56 | 57 | 58 | def train(self, batch_input, batch_labels): 59 | self.model.optimizer.zero_grad() 60 | cur_loss, pred_logits, predictions = self.model(batch_input, batch_labels) 61 | gt_output = batch_input['gt'] 62 | pred_acc = torch.sum(predictions == gt_output) 63 | pred_acc = pred_acc.item() * 1.0 / (gt_output.size()[0] * gt_output.size()[1]) 64 | 65 | self.global_step += 1 66 | cur_loss.backward() 67 | self.model.train_step() 68 | return cur_loss.item(), pred_acc 69 | 70 | def eval(self, data, data_order_invariant=False, max_eval_size=None): 71 | self.model.eval() 72 | data_size = len(data) 73 | if max_eval_size is not None: 74 | data_size = min(data_size, max_eval_size) 75 | eval_data = data[:data_size] 76 | test_loss = 0.0 77 | test_label_acc = 0 78 | test_data_acc = 0 79 | test_acc = 0 80 | 81 | predictions = [] 82 | for batch_idx in range(0, data_size, self.batch_size): 83 | batch_input, batch_labels = self.data_processor.get_batch(eval_data, self.batch_size, batch_idx) 84 | cur_loss, cur_pred_logits, cur_predictions = self.model(batch_input, batch_labels, eval_flag=True) 85 | test_loss += cur_loss.item() * batch_labels.size()[0] 86 | cur_predictions = cur_predictions.data.cpu().numpy().tolist() 87 | for i, sample in enumerate(batch_input['init_data']): 88 | gt_prog = self.data_processor.ids_to_prog(sample, sample['output_gt']) 89 | pred_prog = self.data_processor.ids_to_prog(sample, cur_predictions[i]) 90 | gt_label = sample['label'] 91 | pred_label = self.data_processor.label_extraction(pred_prog) 92 | if gt_label == pred_label: 93 | cur_test_label_acc = 1 94 | else: 95 | cur_test_label_acc = 0 96 | target_dfs, target_strs, target_vars = sample['target_dfs'], sample['target_strs'], sample['target_vars'] 97 | pred_dfs, pred_strs, pred_vars, _ = self.data_processor.data_extraction(pred_prog, 98 | sample['reserved_dfs'], sample['reserved_strs'], sample['reserved_vars']) 99 | 100 | if data_order_invariant: 101 | if (set(target_dfs + target_strs + target_vars) == set(pred_dfs + pred_strs + pred_vars) and 102 | len(target_dfs + target_strs + target_vars) == len(pred_dfs + pred_strs + pred_vars)): 103 | cur_test_data_acc = 1 104 | else: 105 | cur_test_data_acc = 0 106 | else: 107 | if target_dfs + target_strs + target_vars == pred_dfs + pred_strs + pred_vars: 108 | cur_test_data_acc = 1 109 | else: 110 | cur_test_data_acc = 0 111 | cur_test_acc = min(cur_test_label_acc, cur_test_data_acc) 112 | test_label_acc += cur_test_label_acc 113 | test_data_acc += cur_test_data_acc 114 | test_acc += cur_test_acc 115 | print('batch_idx: ', batch_idx, 'test_label_acc: ', test_label_acc, 'test_data_acc', test_data_acc, 'test_acc', test_acc) 116 | predictions += 
cur_predictions 117 | 118 | test_loss /= data_size 119 | test_label_acc = test_label_acc * 1.0 / data_size 120 | test_data_acc = test_data_acc * 1.0 / data_size 121 | test_acc = test_acc * 1.0 / data_size 122 | self.model.train() 123 | return test_loss, test_label_acc, test_data_acc, test_acc, predictions 124 | -------------------------------------------------------------------------------- /models/modules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jungyhuk/plotcoder/4c5fe923dc69227c58d93f55b8a89fd8bb960703/models/modules/__init__.py -------------------------------------------------------------------------------- /models/modules/mlp.py: -------------------------------------------------------------------------------- 1 | import random 2 | import numpy as np 3 | import time 4 | 5 | import torch 6 | import torch.nn as nn 7 | from torch.autograd import Variable 8 | import torch.nn.functional as F 9 | 10 | class MLPModel(nn.Module): 11 | """ 12 | Multi-layer perceptron module. 13 | """ 14 | def __init__(self, num_layers, input_size, hidden_size, output_size, dropout_rate, cuda_flag, activation=None): 15 | super(MLPModel, self).__init__() 16 | self.num_layers = num_layers 17 | self.input_size = input_size 18 | self.hidden_size = hidden_size 19 | self.output_size = output_size 20 | self.cuda_flag = cuda_flag 21 | self.dropout_rate = dropout_rate 22 | self.model = nn.Sequential( 23 | nn.Linear(self.input_size, self.hidden_size), 24 | nn.Dropout(p=self.dropout_rate), 25 | nn.ReLU()) 26 | for _ in range(self.num_layers): 27 | self.model = nn.Sequential( 28 | self.model, 29 | nn.Linear(self.hidden_size, self.hidden_size), 30 | nn.Dropout(p=self.dropout_rate), 31 | nn.ReLU()) 32 | self.model = nn.Sequential( 33 | self.model, 34 | nn.Linear(self.hidden_size, self.output_size)) 35 | if activation is not None: 36 | self.model = nn.Sequential( 37 | self.model, 38 | activation 39 | ) 40 | 41 | def forward(self, inputs): 42 | return self.model(inputs) -------------------------------------------------------------------------------- /plot_sample_extraction.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import math 3 | import random 4 | import sys 5 | import os 6 | import json 7 | import numpy as np 8 | import time 9 | import operator 10 | 11 | 12 | scatter_word_list = ['scatter', "'scatter'", '"scatter"', 'scatter_kws', "'o'", "'bo'", "'r+'", '"o"', '"bo"', '"r+"'] 13 | hist_word_list = ['hist', "'hist'", '"hist"', 'bar', "'bar'", '"bar"', 'countplot', 'barplot'] 14 | pie_word_list = ['pie', "'pie'", '"pie"'] 15 | scatter_plot_word_list = ['lmplot', 'regplot'] 16 | hist_plot_word_list = ['distplot', 'kdeplot', 'contour'] 17 | normal_plot_word_list = ['plot'] 18 | 19 | reserved_words = scatter_word_list + hist_word_list + pie_word_list + scatter_plot_word_list + hist_plot_word_list + normal_plot_word_list 20 | 21 | 22 | arg_parser = argparse.ArgumentParser(description='JuiCe plot data extraction') 23 | arg_parser.add_argument('--data_folder', type=str, default='../data', 24 | help="the folder where the datasets downloaded from the original JuiCe repo are stored. 
We will retrieve 'train.jsonl', 'dev.jsonl' and 'test.jsonl' here.") 25 | arg_parser.add_argument('--init_train_data_name', type=str, default='train.jsonl', 26 | help="the filename of the original training data.") 27 | arg_parser.add_argument('--init_dev_data_name', type=str, default='dev.jsonl', 28 | help="the filename of the original dev data.") 29 | arg_parser.add_argument('--init_test_data_name', type=str, default='test.jsonl', 30 | help="the filename of the original test data.") 31 | arg_parser.add_argument('--prep_train_data_name', type=str, default='train_plot.json', 32 | help="the filename of the preprocessed training data. When set to None, it means that the training data is not preprocessed (this file is the most time-consuming for preprocessing).") 33 | arg_parser.add_argument('--prep_dev_data_name', type=str, default='dev_plot.json', 34 | help="the filename of the preprocessed dev data. When set to None, it means that the dev data is not preprocessed.") 35 | arg_parser.add_argument('--prep_test_data_name', type=str, default='test_plot.json', 36 | help="the filename of the preprocessed test data. When set to None, it means that the test data is not preprocessed.") 37 | arg_parser.add_argument('--prep_dev_hard_data_name', type=str, default='dev_plot_hard.json', 38 | help="the filename of the preprocessed hard split of the dev data. When set to None, it means that the dev data is not preprocessed.") 39 | arg_parser.add_argument('--prep_test_hard_data_name', type=str, default='test_plot_hard.json', 40 | help="the filename of the preprocessed hard split of the test data. When set to None, it means that the test data is not preprocessed.") 41 | arg_parser.add_argument('--build_vocab', action='store_true', default=True, 42 | help="set the flag to be true, so as to build the natural language word and code vocabs from the training set.") 43 | arg_parser.add_argument('--nl_freq_file', type=str, default='nl_freq.json', 44 | help='the file that stores the frequency of each natural language word.') 45 | arg_parser.add_argument('--code_freq_file', type=str, default='code_freq.json', 46 | help='the file that stores the frequency of each code token.') 47 | arg_parser.add_argument('--nl_vocab', type=str, default='nl_vocab.json', 48 | help='the file that stores the natural language word vocabulary.') 49 | arg_parser.add_argument('--code_vocab', type=str, default='code_vocab.json', 50 | help='the file that stores the code token vocabulary.') 51 | arg_parser.add_argument('--min_nl_freq', type=int, default=15, 52 | help='Words with a smaller number of occurrences in the training data than this threshold are excluded from the nl word vocab.') 53 | arg_parser.add_argument('--min_code_freq', type=int, default=1000, 54 | help='Code tokens with a smaller number of occurrences in the training data than this threshold are excluded from the code token vocab.') 55 | 56 | args = arg_parser.parse_args() 57 | 58 | 59 | def preprocess(data_folder, init_data_name, prep_data_name, prep_hard_data_name=None, additional_samples=[], is_train=True): 60 | plot_samples = [] 61 | clean_samples = [] 62 | init_data_name = os.path.join(data_folder, init_data_name) 63 | with open(init_data_name) as fin: 64 | for i, line in enumerate(fin): 65 | sample = json.loads(line) 66 | 67 | # extract code sequence without comments and empty strings 68 | init_code_seq = sample['code_tokens'] 69 | code_seq = [] 70 | for tok in init_code_seq: 71 | if len(tok) == 0 or tok[0] == '#': 72 | continue 73 | code_seq.append(tok) 74 | 75 | # filter 
out samples where 'plt' is not used 76 | while 'plt' in code_seq: 77 | pos = code_seq.index('plt') 78 | if pos < len(code_seq) - 1 and code_seq[pos + 1] == '.': 79 | break 80 | code_seq = code_seq[pos + 1:] 81 | if not ('plt' in code_seq): 82 | continue 83 | 84 | plot_calls = [] 85 | api_seq = sample['api_sequence'] 86 | for api in api_seq: 87 | if api == 'subplot': 88 | continue 89 | if api[-4:] == 'plot' and not ('_' in api): 90 | plot_calls.append(api) 91 | 92 | exist_plot_calls = False 93 | for code_idx, tok in enumerate(code_seq): 94 | if not (tok in reserved_words + plot_calls): 95 | continue 96 | if code_idx == len(code_seq) - 1 or code_seq[code_idx + 1] != '(': 97 | continue 98 | exist_plot_calls = True 99 | break 100 | if not exist_plot_calls: 101 | continue 102 | 103 | url = sample['metadata']['path'] 104 | if 'solution' in url.lower() or 'assignment' in url.lower(): 105 | clean_samples.append(sample) 106 | if not is_train: 107 | plot_samples.append(sample) 108 | else: 109 | plot_samples.append(sample) 110 | 111 | print('number of samples in the original partition: ', len(plot_samples)) 112 | print('number of course-related samples in the partition: ', len(clean_samples)) 113 | json.dump(plot_samples, open(os.path.join(data_folder, prep_data_name), 'w')) 114 | if len(additional_samples) > 0: 115 | print('number of samples in the hard partition: ', len(additional_samples)) 116 | json.dump(additional_samples, open(os.path.join(data_folder, prep_hard_data_name), 'w')) 117 | return plot_samples, clean_samples 118 | 119 | 120 | def add_token_to_dict(seq, vocab_dict, is_code=False): 121 | for tok in seq: 122 | if len(tok) == 0: 123 | continue 124 | if is_code and tok[0] == '#': 125 | continue 126 | if tok in vocab_dict: 127 | vocab_dict[tok] += 1 128 | else: 129 | vocab_dict[tok] = 1 130 | return vocab_dict 131 | 132 | def build_vocab(samples): 133 | 134 | # Compute the frequency of each nl and code token 135 | code_dict = {} 136 | word_dict = {} 137 | 138 | for sample in samples: 139 | context = sample['context'] 140 | for cell in context: 141 | if not 'code_tokens' in cell: 142 | continue 143 | code_context = cell['code_tokens'] 144 | if type(code_context) != list: 145 | continue 146 | code_dict = add_token_to_dict(code_context, code_dict, is_code=True) 147 | code_dict = add_token_to_dict(sample['code_tokens'], code_dict, is_code=True) 148 | word_dict = add_token_to_dict(sample['nl'] + sample['comments'], word_dict, is_code=False) 149 | 150 | sorted_word_list = sorted(word_dict.items(), key=operator.itemgetter(1), reverse=True) 151 | sorted_code_list = sorted(code_dict.items(), key=operator.itemgetter(1), reverse=True) 152 | print('Total number of nl tokens (before filtering): ', len(sorted_word_list)) 153 | print('Total number of code tokens (before filtering): ', len(sorted_code_list)) 154 | json.dump(sorted_word_list, open(os.path.join(args.data_folder, args.nl_freq_file), 'w')) 155 | json.dump(sorted_code_list, open(os.path.join(args.data_folder, args.code_freq_file), 'w')) 156 | 157 | # filter out rare tokens 158 | code_vocab = {} 159 | word_vocab = {} 160 | 161 | for i, item in enumerate(sorted_word_list): 162 | if item[1] < args.min_nl_freq: 163 | break 164 | word_vocab[item[0]] = i 165 | 166 | for i, item in enumerate(sorted_code_list): 167 | if item[1] < args.min_code_freq: 168 | break 169 | code_vocab[item[0]] = i 170 | 171 | print('Total number of nl tokens (after filtering): ', len(word_vocab)) 172 | print('Total number of code tokens (after filtering): ', 
len(code_vocab)) 173 | json.dump(word_vocab, open(os.path.join(args.data_folder, args.nl_vocab), 'w')) 174 | json.dump(code_vocab, open(os.path.join(args.data_folder, args.code_vocab), 'w')) 175 | 176 | 177 | if not os.path.exists(args.data_folder): 178 | os.makedirs(args.data_folder) 179 | 180 | # data preprocessing 181 | if args.prep_train_data_name: 182 | print('preprocessing training data:') 183 | train_plot_samples, train_plot_clean_samples = preprocess(args.data_folder, args.init_train_data_name, args.prep_train_data_name, is_train=True) 184 | cnt_train_clean_samples = len(train_plot_clean_samples) 185 | 186 | if args.prep_dev_data_name: 187 | print('preprocessing dev data:') 188 | dev_plot_samples, dev_plot_clean_samples = preprocess(args.data_folder, args.init_dev_data_name, args.prep_dev_data_name, 189 | prep_hard_data_name=args.prep_dev_hard_data_name, additional_samples=train_plot_clean_samples[:cnt_train_clean_samples // 2], is_train=False) 190 | 191 | if args.prep_test_data_name: 192 | print('preprocessing test data:') 193 | test_plot_samples, test_plot_clean_samples = preprocess(args.data_folder, args.init_test_data_name, args.prep_test_data_name, 194 | prep_hard_data_name=args.prep_test_hard_data_name, additional_samples=train_plot_clean_samples[cnt_train_clean_samples // 2:], is_train=False) 195 | 196 | # build natural language word and code vocabularies 197 | if args.build_vocab: 198 | assert args.init_train_data_name is not None 199 | build_vocab(train_plot_samples) -------------------------------------------------------------------------------- /run.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import math 3 | import random 4 | import sys 5 | import os 6 | import json 7 | import numpy as np 8 | import time 9 | 10 | import torch 11 | 12 | import arguments 13 | import models 14 | import models.data_utils.data_utils as data_utils 15 | import models.model_utils as model_utils 16 | from models.model import PlotCodeGenerator 17 | 18 | def create_model(args, word_vocab, code_vocab): 19 | model = PlotCodeGenerator(args, word_vocab, code_vocab) 20 | if model.cuda_flag: 21 | model = model.cuda() 22 | model_supervisor = model_utils.Supervisor(model, args) 23 | if args.load_model: 24 | model_supervisor.load_pretrained(args.load_model) 25 | else: 26 | print('Created model with fresh parameters.') 27 | model_supervisor.model.init_weights(args.param_init) 28 | return model_supervisor 29 | 30 | 31 | def train(args): 32 | print('Training:') 33 | 34 | data_processor = data_utils.DataProcessor(args) 35 | train_data = data_processor.load_data(args.train_dataset) 36 | train_data, train_indices = data_processor.preprocess(train_data) 37 | dev_data = data_processor.load_data(args.dev_dataset) 38 | dev_data, dev_indices = data_processor.preprocess(dev_data) 39 | 40 | train_data_size = len(train_data) 41 | args.word_vocab_size = data_processor.word_vocab_size 42 | args.code_vocab_size = data_processor.code_vocab_size 43 | model_supervisor = create_model(args, data_processor.word_vocab, data_processor.code_vocab) 44 | 45 | logger = model_utils.Logger(args) 46 | 47 | for epoch in range(args.num_epochs): 48 | random.shuffle(train_data) 49 | for batch_idx in range(0, train_data_size, args.batch_size): 50 | print(epoch, batch_idx) 51 | batch_input, batch_labels = data_processor.get_batch(train_data, args.batch_size, batch_idx) 52 | train_loss, train_acc = model_supervisor.train(batch_input, batch_labels) 53 | print('train loss: %.4f train 
acc: %.4f' % (train_loss, train_acc)) 54 | 55 | if model_supervisor.global_step % args.eval_every_n == 0: 56 | model_supervisor.save_model() 57 | eval_loss, eval_label_acc, eval_data_acc, eval_acc, pred_labels = model_supervisor.eval(dev_data, args.data_order_invariant, args.max_eval_size) 58 | val_summary = {'train_loss': train_loss, 'train_acc': train_acc, 'eval_loss': eval_loss, 59 | 'eval_label_acc': eval_label_acc, 'eval_data_acc': eval_data_acc, 'eval_acc': eval_acc} 60 | val_summary['global_step'] = model_supervisor.global_step 61 | logger.write_summary(val_summary) 62 | 63 | if args.lr_decay_steps is not None and model_supervisor.global_step % args.lr_decay_steps == 0: 64 | model_supervisor.model.lr_decay(args.lr_decay_rate) 65 | 66 | 67 | def evaluate(args): 68 | print('Evaluation') 69 | data_processor = data_utils.DataProcessor(args) 70 | init_test_data = data_processor.load_data(args.test_dataset) 71 | test_data, test_indices = data_processor.preprocess(init_test_data) 72 | 73 | args.word_vocab_size = data_processor.word_vocab_size 74 | args.code_vocab_size = data_processor.code_vocab_size 75 | model_supervisor = create_model(args, data_processor.word_vocab, data_processor.code_vocab) 76 | test_loss, test_label_acc, test_data_acc, test_acc, predictions = model_supervisor.eval(test_data, args.data_order_invariant) 77 | 78 | label_acc_per_category = [0] * args.num_plot_types 79 | data_acc_per_category = [0] * args.num_plot_types 80 | acc_per_category = [0] * args.num_plot_types 81 | cnt_per_category = [0] * args.num_plot_types 82 | 83 | cnt_unpredictable = 0 84 | for i, item in enumerate(test_data): 85 | gt_label = item['label'] 86 | if args.joint_plot_types: 87 | gt_label = data_processor.get_joint_plot_type(gt_label) 88 | cnt_per_category[gt_label] += 1 89 | 90 | gt_prog = data_processor.ids_to_prog(item, item['output_gt']) 91 | 92 | if data_utils._PAD in gt_prog: 93 | cnt_unpredictable += 1 94 | 95 | pred_prog = data_processor.ids_to_prog(item, predictions[i]) 96 | 97 | pred_label = data_processor.label_extraction(pred_prog) 98 | if args.joint_plot_types: 99 | pred_label = data_processor.get_joint_plot_type(pred_label) 100 | if pred_label == gt_label: 101 | label_acc_per_category[gt_label] += 1 102 | 103 | target_dfs, target_strs, target_vars = item['target_dfs'], item['target_strs'], item['target_vars'] 104 | pred_dfs, pred_strs, pred_vars, _ = data_processor.data_extraction(pred_prog, 105 | item['reserved_dfs'], item['reserved_strs'], item['reserved_vars']) 106 | 107 | if args.data_order_invariant: 108 | if (set(target_dfs + target_strs + target_vars) == set(pred_dfs + pred_strs + pred_vars) and 109 | len(target_dfs + target_strs + target_vars) == len(pred_dfs + pred_strs + pred_vars)): 110 | cur_data_acc = 1 111 | else: 112 | cur_data_acc = 0 113 | else: 114 | if target_dfs + target_strs + target_vars == pred_dfs + pred_strs + pred_vars: 115 | cur_data_acc = 1 116 | else: 117 | cur_data_acc = 0 118 | if cur_data_acc == 1: 119 | data_acc_per_category[gt_label] += 1 120 | if pred_label == gt_label: 121 | acc_per_category[gt_label] += 1 122 | 123 | print('test label acc: %.4f test data acc: %.4f test acc: %.4f ' % (test_label_acc, test_data_acc, test_acc)) 124 | print('Unpredictable samples: %d %.4f' % (cnt_unpredictable, cnt_unpredictable * 1.0 / len(test_data))) 125 | print('Upper bound: %.4f' % (1 - cnt_unpredictable * 1.0 / len(test_data))) 126 | for i in range(args.num_plot_types): 127 | print('cnt per category: ', i, cnt_per_category[i]) 128 | if cnt_per_category[i] == 0: 
129 | continue 130 | print('label acc per category: ', i, label_acc_per_category[i], label_acc_per_category[i] * 1.0 / cnt_per_category[i]) 131 | print('data acc per category: ', i, data_acc_per_category[i], data_acc_per_category[i] * 1.0 / cnt_per_category[i]) 132 | print('acc per category: ', i, acc_per_category[i], acc_per_category[i] * 1.0 / cnt_per_category[i]) 133 | 134 | 135 | if __name__ == "__main__": 136 | arg_parser = arguments.get_arg_parser('juice') 137 | args = arg_parser.parse_args() 138 | args.cuda = not args.cpu and torch.cuda.is_available() 139 | random.seed(args.seed) 140 | np.random.seed(args.seed) 141 | if args.eval: 142 | evaluate(args) 143 | else: 144 | train(args) 145 | --------------------------------------------------------------------------------