├── .gitignore ├── README.md ├── dataset ├── database_information.csv ├── dataset_final │ ├── dev.csv │ ├── test.csv │ └── train.csv ├── db_tables_columns.json ├── db_tables_columns_types.json ├── dev.csv ├── test.csv └── train.csv ├── img ├── example.png ├── example0.png ├── inoutput.png └── teaser.png ├── model ├── AttentionForcing.py ├── Decoder.py ├── Encoder.py ├── Model.py ├── SubLayers.py └── VisAwareTranslation.py ├── ncNet-VIS21.pdf ├── ncNet.ipynb ├── ncNet.py ├── preprocessing ├── build_vocab.py └── process_dataset.py ├── requirements.txt ├── save_models └── trained_model.pt ├── test.py ├── test_ncNet.ipynb ├── train.py └── utilities └── vis_rendering.py /.gitignore: -------------------------------------------------------------------------------- 1 | .idea/* 2 | *.idea/ 3 | **/.DS_Store 4 | *.DS_Store 5 | *.idea 6 | /dataset/database/* 7 | *.pyc 8 | .idea 9 | .idea/ 10 | .ipynb_checkpoints 11 | .ipynb_checkpoints/ 12 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ncNet 2 | 3 | Supporting the translation from natural language (NL) query to visualization (NL2VIS) can simplify the creation of data visualizations because if successful, anyone can generate visualizations by their natural language from the tabular data. 4 | 5 | 6 | 7 | We present ncNet, a Transformer-based model for supporting NL2VIS, with several novel visualization-aware optimizations, including using attention-forcing to optimize the learning process, and visualization-aware rendering to produce better visualization results. 
8 | 9 | ## Input and Output 10 | 11 | 12 | Input: 13 | * a tabular dataset (csv, json, or sqlite3) 14 | * a natural language query used for NL2VIS 15 | * an optional chart template 16 | 17 | Output: 18 | * [Vega-Zero](https://github.com/Thanksyy/Vega-Zero): a model-friendly, sequence-based grammar that simplifies Vega-Lite 19 | 20 | 21 | Please refer to our [paper](https://github.com/Thanksyy/Vega-Zero/blob/main/ncNet-VIS21.pdf) at IEEE VIS 2021 for more details. 22 | 23 | 24 | # Environment Setup 25 | 26 | * `Python3.6+` 27 | * `PyTorch 1.7` 28 | * `torchtext 0.8` 29 | * `ipyvega` 30 | 31 | Install the Python dependencies via `pip install -r requirements.txt` once Python and PyTorch are set up. 32 | 33 | 34 | # Running Code 35 | 36 | ## Data preparation 37 | 38 | 39 | 40 | * [Must] Download the Spider data [here](https://drive.google.com/drive/folders/1wmJTcC9R6ah0jBo_ONaZW3ykx5iGMx9j?usp=sharing) and unzip it into the `./dataset/` directory. 41 | 42 | * [Optional] **_Only if_** you change `train/dev/test.csv` in the `./dataset/` folder, you need to run `process_dataset.py` in the `preprocessing` folder. 43 | 44 | ## Running Example 45 | 46 | Open `ncNet.ipynb` to try the running example. 47 | 48 | 49 | 50 | 51 | ## Training 52 | 53 | Run `train.py` to train ncNet. 54 | 55 | 56 | ## Testing 57 | 58 | Run `test.py` to evaluate ncNet. 59 | 60 | 61 | # Citing ncNet 62 | 63 | ```bibTeX 64 | @ARTICLE{ncnet, 65 | author={Luo, Yuyu and Tang, Nan and Li, Guoliang and Tang, Jiawei and Chai, Chengliang and Qin, Xuedi}, 66 | journal={IEEE Transactions on Visualization and Computer Graphics}, 67 | title={Natural Language to Visualization by Neural Machine Translation}, 68 | year={2021}, 69 | volume={}, 70 | number={}, 71 | pages={1-1}, doi={10.1109/TVCG.2021.3114848}} 72 | ``` 73 | 74 | # License 75 | The project is available under the [MIT License](https://github.com/Thanksyy/Vega-Zero/blob/main/README.md).
76 | 77 | # Contact 78 | If you have any questions, feel free to contact Yuyu Luo (yuyuluo [AT] hkust-gz.edu.cn). 79 | -------------------------------------------------------------------------------- /img/example.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HKUSTDial/ncNet/415d9477d424296bc5414e0f6624af23643372d7/img/example.png -------------------------------------------------------------------------------- /img/example0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HKUSTDial/ncNet/415d9477d424296bc5414e0f6624af23643372d7/img/example0.png -------------------------------------------------------------------------------- /img/inoutput.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HKUSTDial/ncNet/415d9477d424296bc5414e0f6624af23643372d7/img/inoutput.png -------------------------------------------------------------------------------- /img/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HKUSTDial/ncNet/415d9477d424296bc5414e0f6624af23643372d7/img/teaser.png -------------------------------------------------------------------------------- /model/AttentionForcing.py: -------------------------------------------------------------------------------- 1 | __author__ = "Yuyu Luo" 2 | 3 | import numpy as np 4 | 5 | 6 | def create_visibility_matrix(SRC, each_src): 7 | each_src = np.array(each_src.to('cpu')) 8 | 9 | # find related index 10 | nl_beg_index = np.where(each_src == SRC.vocab[''])[0][0] 11 | nl_end_index = np.where(each_src == SRC.vocab[''])[0][0] 12 | template_beg_index = np.where(each_src == SRC.vocab[''])[0][0] 13 | template_end_index = np.where(each_src == SRC.vocab[''])[0][0] 14 | col_beg_index = np.where(each_src == SRC.vocab[''])[0][0] 15 | col_end_index = 
np.where(each_src == SRC.vocab[''])[0][0] 16 | value_beg_index = np.where(each_src == SRC.vocab[''])[0][0] 17 | value_end_index = np.where(each_src == SRC.vocab[''])[0][0] 18 | table_name_beg_index = np.where(each_src == SRC.vocab[''])[0][0] 19 | table_name_end_index = np.where(each_src == SRC.vocab[''])[0][0] 20 | 21 | if SRC.vocab['[d]'] in each_src: 22 | table_index = np.where(each_src == SRC.vocab['[d]'])[0][0] 23 | else: 24 | table_index = -1 25 | 26 | if SRC.vocab['[x]'] in each_src: 27 | x_index = np.where(each_src == SRC.vocab['[x]'])[0][0] 28 | else: 29 | # print('x') 30 | x_index = -1 31 | 32 | if SRC.vocab['[y]'] in each_src: 33 | y_index = np.where(each_src == SRC.vocab['[y]'])[0][0] 34 | else: 35 | # print('y') 36 | y_index = -1 37 | 38 | if SRC.vocab['[z]'] in each_src: 39 | color_index = np.where(each_src == SRC.vocab['[z]'])[0][0] 40 | else: 41 | # print('y') 42 | color_index = -1 43 | 44 | if SRC.vocab['[aggfunction]'] in each_src: 45 | agg_y_index = np.where(each_src == SRC.vocab['[aggfunction]'])[0][0] 46 | else: 47 | agg_y_index = -1 48 | # print('agg') 49 | 50 | if SRC.vocab['[g]'] in each_src: 51 | group_index = np.where(each_src == SRC.vocab['[g]'])[0][0] 52 | else: 53 | group_index = -1 54 | # print('xy') 55 | 56 | if SRC.vocab['[b]'] in each_src: 57 | bin_index = np.where(each_src == SRC.vocab['[b]'])[0][0] 58 | else: 59 | bin_index = -1 60 | # print('xy') 61 | 62 | if SRC.vocab['[s]'] in each_src: 63 | sort_index = np.where(each_src == SRC.vocab['[s]'])[0][0] 64 | else: 65 | sort_index = -1 66 | 67 | if SRC.vocab['[f]'] in each_src: 68 | where_index = np.where(each_src == SRC.vocab['[f]'])[0][0] 69 | else: 70 | where_index = -1 71 | # print('w') 72 | 73 | if SRC.vocab['[o]'] in each_src: 74 | other_index = np.where(each_src == SRC.vocab['[o]'])[0][0] 75 | else: 76 | other_index = -1 77 | # print('o') 78 | 79 | if SRC.vocab['[k]'] in each_src: 80 | topk_index = np.where(each_src == SRC.vocab['[k]'])[0][0] 81 | else: 82 | topk_index = -1 83 
| # print('o') 84 | 85 | # init the visibility matrix 86 | v_matrix = np.zeros(each_src.shape * 2, dtype=int) 87 | 88 | # assign 1 to related cells 89 | 90 | # nl - (nl, template, col, value) self-attention 91 | v_matrix[nl_beg_index:nl_end_index, :] = 1 92 | v_matrix[:, nl_beg_index:nl_end_index] = 1 93 | 94 | # col-value self-attention 95 | v_matrix[col_beg_index:value_end_index, col_beg_index:value_end_index] = 1 96 | 97 | # template self-attention 98 | v_matrix[template_beg_index:template_end_index, 99 | template_beg_index:template_end_index] = 1 100 | 101 | # template - col/value self-attention 102 | # [x]/[y]/[agg(y)]/[o]/[w] <---> col 103 | # [w] <---> value 104 | # [c]/[o](order_type)/[i] <---> NL 105 | if table_index != -1: 106 | v_matrix[table_index, table_name_beg_index:table_name_end_index] = 1 107 | v_matrix[table_name_beg_index:table_name_end_index, table_index] = 1 108 | 109 | if x_index != -1: 110 | v_matrix[x_index, col_beg_index:col_end_index] = 1 111 | v_matrix[col_beg_index:col_end_index, x_index] = 1 112 | if y_index != -1: 113 | v_matrix[y_index, col_beg_index:col_end_index] = 1 114 | v_matrix[col_beg_index:col_end_index, y_index] = 1 115 | if color_index != -1: 116 | v_matrix[color_index, col_beg_index:col_end_index] = 1 117 | v_matrix[col_beg_index:col_end_index, color_index] = 1 118 | 119 | if agg_y_index != -1: 120 | v_matrix[agg_y_index, nl_beg_index:nl_end_index] = 1 121 | v_matrix[nl_beg_index:nl_end_index, agg_y_index] = 1 122 | 123 | if other_index != -1: 124 | v_matrix[other_index, col_beg_index:col_end_index] = 1 125 | v_matrix[col_beg_index:col_end_index, other_index] = 1 126 | 127 | if where_index != -1: 128 | v_matrix[where_index, col_beg_index:col_end_index] = 1 129 | v_matrix[where_index, value_beg_index:value_end_index] = 1 130 | 131 | v_matrix[col_beg_index:col_end_index, where_index] = 1 132 | v_matrix[value_beg_index:value_end_index, where_index] = 1 133 | 134 | if group_index != -1: 135 | v_matrix[group_index, 
col_beg_index:col_end_index] = 1 136 | v_matrix[col_beg_index:col_end_index, group_index] = 1 137 | 138 | v_matrix[group_index, nl_beg_index:nl_end_index] = 1 139 | v_matrix[nl_beg_index:nl_end_index, group_index] = 1 140 | 141 | if bin_index != -1: 142 | v_matrix[bin_index, col_beg_index:col_end_index] = 1 143 | v_matrix[col_beg_index:col_end_index, bin_index] = 1 144 | 145 | v_matrix[bin_index, nl_beg_index:nl_end_index] = 1 146 | v_matrix[nl_beg_index:nl_end_index, bin_index] = 1 147 | 148 | if sort_index != -1: 149 | v_matrix[sort_index, col_beg_index:col_end_index] = 1 150 | v_matrix[col_beg_index:col_end_index, sort_index] = 1 151 | if topk_index != -1: 152 | v_matrix[topk_index, nl_beg_index:nl_end_index] = 1 153 | v_matrix[nl_beg_index:nl_end_index, topk_index] = 1 154 | 155 | return v_matrix -------------------------------------------------------------------------------- /model/Decoder.py: -------------------------------------------------------------------------------- 1 | __author__ = "Yuyu Luo" 2 | 3 | ''' 4 | Define the decoder of the model 5 | ''' 6 | 7 | import torch 8 | import torch.nn as nn 9 | from model.SubLayers import MultiHeadAttentionLayer, PositionwiseFeedforwardLayer 10 | 11 | class Decoder(nn.Module): 12 | def __init__(self, 13 | output_dim, 14 | hid_dim, 15 | n_layers, 16 | n_heads, 17 | pf_dim, 18 | dropout, 19 | device, 20 | max_length=128): 21 | super().__init__() 22 | 23 | self.device = device 24 | 25 | self.tok_embedding = nn.Embedding(output_dim, hid_dim) 26 | self.pos_embedding = nn.Embedding(max_length, hid_dim) 27 | 28 | self.layers = nn.ModuleList([DecoderLayer(hid_dim, 29 | n_heads, 30 | pf_dim, 31 | dropout, 32 | device) 33 | for _ in range(n_layers)]) 34 | 35 | self.fc_out = nn.Linear(hid_dim, output_dim) 36 | 37 | self.dropout = nn.Dropout(dropout) 38 | 39 | self.scale = torch.sqrt(torch.FloatTensor([hid_dim])).to(device) 40 | 41 | def forward(self, trg, enc_src, trg_mask, src_mask): 42 | # trg = [batch size, trg len] 43 | # 
enc_src = [batch size, src len, hid dim] 44 | # trg_mask = [batch size, trg len] 45 | # src_mask = [batch size, src len] 46 | 47 | batch_size = trg.shape[0] 48 | trg_len = trg.shape[1] 49 | 50 | pos = torch.arange(0, trg_len).unsqueeze(0).repeat(batch_size, 1).to(self.device) 51 | 52 | # pos = [batch size, trg len] 53 | 54 | # this_seg_list = np.resize(np.array(self.segment_id), (batch_size, trg_len)).tolist() 55 | # seg = torch.tensor(this_seg_list).to(self.device) 56 | 57 | trg = self.dropout((self.tok_embedding(trg) * self.scale) + self.pos_embedding(pos)) 58 | 59 | # trg = [batch size, trg len, hid dim] 60 | 61 | for layer in self.layers: 62 | trg, attention = layer(trg, enc_src, trg_mask, src_mask) 63 | 64 | # trg = [batch size, trg len, hid dim] 65 | # attention = [batch size, n heads, trg len, src len] 66 | 67 | output = self.fc_out(trg) 68 | 69 | # output = [batch size, trg len, output dim] 70 | 71 | return output, attention 72 | 73 | 74 | class DecoderLayer(nn.Module): 75 | def __init__(self, 76 | hid_dim, 77 | n_heads, 78 | pf_dim, 79 | dropout, 80 | device): 81 | super().__init__() 82 | 83 | self.self_attn_layer_norm = nn.LayerNorm(hid_dim) 84 | self.enc_attn_layer_norm = nn.LayerNorm(hid_dim) 85 | self.ff_layer_norm = nn.LayerNorm(hid_dim) 86 | self.self_attention = MultiHeadAttentionLayer(hid_dim, n_heads, dropout, device) 87 | self.encoder_attention = MultiHeadAttentionLayer(hid_dim, n_heads, dropout, device) 88 | self.positionwise_feedforward = PositionwiseFeedforwardLayer(hid_dim, 89 | pf_dim, 90 | dropout) 91 | self.dropout = nn.Dropout(dropout) 92 | 93 | def forward(self, trg, enc_src, trg_mask, src_mask): 94 | # trg = [batch size, trg len, hid dim] 95 | # enc_src = [batch size, src len, hid dim] 96 | # trg_mask = [batch size, trg len] 97 | # src_mask = [batch size, src len] 98 | 99 | # self attention 100 | _trg, _ = self.self_attention(trg, trg, trg, trg_mask) 101 | 102 | # dropout, residual connection and layer norm 103 | trg = 
self.self_attn_layer_norm(trg + self.dropout(_trg)) 104 | 105 | # trg = [batch size, trg len, hid dim] 106 | 107 | # encoder attention 108 | _trg, attention = self.encoder_attention(trg, enc_src, enc_src, src_mask) 109 | 110 | # dropout, residual connection and layer norm 111 | trg = self.enc_attn_layer_norm(trg + self.dropout(_trg)) 112 | 113 | # trg = [batch size, trg len, hid dim] 114 | 115 | # positionwise feedforward 116 | _trg = self.positionwise_feedforward(trg) 117 | 118 | # dropout, residual and layer norm 119 | trg = self.ff_layer_norm(trg + self.dropout(_trg)) 120 | 121 | # trg = [batch size, trg len, hid dim] 122 | # attention = [batch size, n heads, trg len, src len] 123 | 124 | return trg, attention 125 | 126 | 127 | -------------------------------------------------------------------------------- /model/Encoder.py: -------------------------------------------------------------------------------- 1 | __author__ = "Yuyu Luo" 2 | 3 | ''' 4 | Define the Encoder of the model 5 | ''' 6 | 7 | import torch 8 | import torch.nn as nn 9 | from model.SubLayers import MultiHeadAttentionLayer, PositionwiseFeedforwardLayer 10 | 11 | class Encoder(nn.Module): 12 | def __init__(self, 13 | input_dim, 14 | hid_dim, # == d_model 15 | n_layers, 16 | n_heads, 17 | pf_dim, 18 | dropout, 19 | device, 20 | TOK_TYPES, 21 | max_length=128): 22 | super().__init__() 23 | 24 | self.device = device 25 | ''' 26 | nn.Embedding: 27 | 28 | A simple lookup table that stores embeddings of a fixed dictionary and size. 29 | This module is often used to store word embeddings and retrieve them using indices. 30 | The input to the module is a list of indices, and the output is the corresponding word embeddings. 
31 | - num_embeddings (int) – size of the dictionary of embeddings 32 | - embedding_dim (int) – the size of each embedding vector 33 | ''' 34 | self.tok_embedding = nn.Embedding(input_dim, hid_dim) # 初始化Embedding 35 | 36 | self.pos_embedding = nn.Embedding(max_length, hid_dim) 37 | 38 | tok_types_num = len(TOK_TYPES.vocab.itos) 39 | self.tok_types_embedding = nn.Embedding(tok_types_num, hid_dim) 40 | 41 | self.layers = nn.ModuleList([EncoderLayer(hid_dim, 42 | n_heads, 43 | pf_dim, 44 | dropout, 45 | device) 46 | for _ in range(n_layers)]) 47 | 48 | self.dropout = nn.Dropout(dropout) 49 | 50 | self.scale = torch.sqrt(torch.FloatTensor([hid_dim])).to(device) 51 | 52 | def forward(self, src, src_mask, tok_types, batch_matrix): 53 | 54 | batch_size = src.shape[0] 55 | src_len = src.shape[1] 56 | 57 | pos = torch.arange(0, src_len).unsqueeze(0).repeat(batch_size, 1).to(self.device) 58 | 59 | src = self.dropout((self.tok_embedding(src) * self.scale) + self.tok_types_embedding(tok_types) + self.pos_embedding(pos)) 60 | 61 | # src = [batch size, src len, hid dim] 62 | 63 | for layer in self.layers: 64 | src, enc_attention = layer(src, src_mask, batch_matrix) 65 | 66 | # src = [batch size, src len, hid dim] 67 | return src, enc_attention 68 | 69 | 70 | class EncoderLayer(nn.Module): 71 | def __init__(self, 72 | hid_dim, 73 | n_heads, 74 | pf_dim, 75 | dropout, 76 | device): 77 | super().__init__() 78 | 79 | self.self_attn_layer_norm = nn.LayerNorm(hid_dim) 80 | self.ff_layer_norm = nn.LayerNorm(hid_dim) 81 | self.self_attention = MultiHeadAttentionLayer(hid_dim, n_heads, dropout, device) 82 | self.positionwise_feedforward = PositionwiseFeedforwardLayer(hid_dim, 83 | pf_dim, 84 | dropout) 85 | self.dropout = nn.Dropout(dropout) 86 | 87 | def forward(self, src, src_mask, batch_matrix): 88 | # src = [batch size, src len, hid dim] 89 | # src_mask = [batch size, src len] 90 | 91 | # self attention 92 | _src, _attention = self.self_attention(src, src, src, src_mask, 
batch_matrix) 93 | 94 | # dropout, residual connection and layer norm 95 | src = self.self_attn_layer_norm(src + self.dropout(_src)) 96 | 97 | # src = [batch size, src len, hid dim] 98 | 99 | # position-wise feedforward 100 | _src = self.positionwise_feedforward(src) 101 | 102 | # dropout, residual and layer norm 103 | src = self.ff_layer_norm(src + self.dropout(_src)) 104 | 105 | # src = [batch size, src len, hid dim] 106 | # print('EncoderLayer->forward:', src) 107 | return src, _attention 108 | -------------------------------------------------------------------------------- /model/Model.py: -------------------------------------------------------------------------------- 1 | __author__ = "Yuyu Luo" 2 | 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | from model.AttentionForcing import create_visibility_matrix 7 | 8 | 9 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 10 | 11 | class Seq2Seq(nn.Module): 12 | ''' 13 | A transformer-based Seq2Seq model. 14 | ''' 15 | def __init__(self, 16 | encoder, 17 | decoder, 18 | SRC, 19 | src_pad_idx, 20 | trg_pad_idx, 21 | device): 22 | super().__init__() 23 | 24 | self.encoder = encoder 25 | self.decoder = decoder 26 | self.src_pad_idx = src_pad_idx 27 | self.trg_pad_idx = trg_pad_idx 28 | self.device = device 29 | 30 | ''' 31 | The source mask is created by checking where the source sequence is not equal to a token. 32 | It is 1 where the token is not a token and 0 when it is. 33 | It is then unsqueezed so it can be correctly broadcast when applying the mask to the energy, 34 | which of shape [batch size, n heads, seq len, seq len]. 35 | ''' 36 | 37 | def make_visibility_matrix(self, src, SRC): 38 | ''' 39 | building the visibility matrix here 40 | ''' 41 | # src = [batch size, src len] 42 | batch_matrix = [] 43 | for each_src in src: 44 | v_matrix = create_visibility_matrix(SRC, each_src) 45 | n_heads_matrix = [v_matrix] * 8 # TODO: 8 is the number of heads ... 
46 | batch_matrix.append(np.array(n_heads_matrix)) 47 | batch_matrix = np.array(batch_matrix) 48 | 49 | # batch_matrix = [batch size, n_heads, src_len, src_len] 50 | return torch.tensor(batch_matrix).to(device) 51 | 52 | def make_src_mask(self, src): 53 | # src = [batch size, src len] 54 | src_mask = (src != self.src_pad_idx).unsqueeze(1).unsqueeze(2) 55 | 56 | # src_mask = [batch size, 1, 1, src len] 57 | 58 | return src_mask 59 | 60 | def make_trg_mask(self, trg): 61 | # trg = [batch size, trg len] 62 | 63 | trg_pad_mask = (trg != self.trg_pad_idx).unsqueeze(1).unsqueeze(2) 64 | 65 | # trg_pad_mask = [batch size, 1, 1, trg len] 66 | 67 | trg_len = trg.shape[1] 68 | 69 | trg_sub_mask = torch.tril(torch.ones((trg_len, trg_len), device=self.device)).bool() 70 | 71 | # trg_sub_mask = [trg len, trg len] 72 | 73 | trg_mask = trg_pad_mask & trg_sub_mask 74 | 75 | # trg_mask = [batch size, 1, trg len, trg len] 76 | 77 | return trg_mask 78 | 79 | def forward(self, src, trg, tok_types, SRC): 80 | # src = [batch size, src len] 81 | # trg = [batch size, trg len] 82 | 83 | src_mask = self.make_src_mask(src) 84 | trg_mask = self.make_trg_mask(trg) 85 | 86 | batch_visibility_matrix = self.make_visibility_matrix(src, SRC) 87 | 88 | # src_mask = [batch size, 1, 1, src len] 89 | # trg_mask = [batch size, 1, trg len, trg len] 90 | 91 | enc_src, enc_attention = self.encoder(src, src_mask, tok_types, batch_visibility_matrix) 92 | 93 | # enc_src = [batch size, src len, hid dim] 94 | 95 | output, attention = self.decoder(trg, enc_src, trg_mask, src_mask) 96 | 97 | # output = [batch size, trg len, output dim] 98 | # attention = [batch size, n heads, trg len, src len] 99 | 100 | return output, attention -------------------------------------------------------------------------------- /model/SubLayers.py: -------------------------------------------------------------------------------- 1 | __author__ = "Yuyu Luo" 2 | 3 | 4 | ''' Define the sublayers in encoder/decoder layer ''' 5 | 6 | 
import torch 7 | import torch.nn as nn 8 | 9 | class MultiHeadAttentionLayer(nn.Module): 10 | def __init__(self, hid_dim, n_heads, dropout, device): 11 | super().__init__() 12 | 13 | assert hid_dim % n_heads == 0 14 | 15 | self.hid_dim = hid_dim 16 | self.n_heads = n_heads 17 | self.head_dim = hid_dim // n_heads 18 | 19 | self.fc_q = nn.Linear(hid_dim, hid_dim) 20 | self.fc_k = nn.Linear(hid_dim, hid_dim) 21 | self.fc_v = nn.Linear(hid_dim, hid_dim) 22 | 23 | self.fc_o = nn.Linear(hid_dim, hid_dim) 24 | 25 | self.dropout = nn.Dropout(dropout) 26 | 27 | # print(torch.sqrt(torch.FloatTensor([self.head_dim]))) 28 | 29 | self.scale = torch.sqrt(torch.FloatTensor([self.head_dim])).to(device) 30 | 31 | def forward(self, query, key, value, mask=None, batch_matrix=None): 32 | 33 | batch_size = query.shape[0] 34 | 35 | # query = [batch size, query len, hid dim] 36 | # key = [batch size, key len, hid dim] 37 | # value = [batch size, value len, hid dim] 38 | 39 | Q = self.fc_q(query) 40 | K = self.fc_k(key) 41 | V = self.fc_v(value) 42 | 43 | # Q = [batch size, query len, hid dim] 44 | # K = [batch size, key len, hid dim] 45 | # V = [batch size, value len, hid dim] 46 | 47 | Q = Q.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3) 48 | K = K.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3) 49 | V = V.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3) 50 | 51 | # Q = [batch size, n heads, query len, head dim] 52 | # K = [batch size, n heads, key len, head dim] 53 | # V = [batch size, n heads, value len, head dim] 54 | 55 | energy = torch.matmul(Q, K.permute(0, 1, 3, 2)) / self.scale 56 | 57 | # energy = [batch size, n heads, query len, key len] 58 | # batch_matrix = [batch size, n heads, query len, key len] 59 | if mask is not None: 60 | 61 | # mask = [batch size, 1, 1, query len] 62 | 63 | energy = energy.masked_fill(mask == 0, -1e10) 64 | 65 | if batch_matrix is not None: 66 | ''' 67 | apply the visibility matrix 
here 68 | ''' 69 | energy = energy.masked_fill(batch_matrix == 0, -1e10) 70 | 71 | attention = torch.softmax(energy, dim=-1) 72 | 73 | # attention = [batch size, n heads, query len, key len] 74 | 75 | x = torch.matmul(self.dropout(attention), V) 76 | 77 | # x = [batch size, n heads, query len, head dim] 78 | 79 | x = x.permute(0, 2, 1, 3).contiguous() 80 | 81 | # x = [batch size, query len, n heads, head dim] 82 | 83 | x = x.view(batch_size, -1, self.hid_dim) 84 | 85 | # x = [batch size, query len, hid dim] 86 | 87 | x = self.fc_o(x) 88 | 89 | # x = [batch size, query len, hid dim] 90 | 91 | return x, attention 92 | 93 | 94 | class PositionwiseFeedforwardLayer(nn.Module): 95 | def __init__(self, hid_dim, pf_dim, dropout): 96 | super().__init__() 97 | 98 | self.fc_1 = nn.Linear(hid_dim, pf_dim) 99 | self.fc_2 = nn.Linear(pf_dim, hid_dim) 100 | 101 | self.dropout = nn.Dropout(dropout) 102 | 103 | def forward(self, x): 104 | # x = [batch size, seq len, hid dim] 105 | 106 | x = self.dropout(torch.relu(self.fc_1(x))) 107 | 108 | # x = [batch size, seq len, pf dim] 109 | 110 | x = self.fc_2(x) 111 | 112 | # x = [batch size, seq len, hid dim] 113 | 114 | return x 115 | 116 | -------------------------------------------------------------------------------- /model/VisAwareTranslation.py: -------------------------------------------------------------------------------- 1 | __author__ = "Yuyu Luo" 2 | 3 | import re 4 | import json 5 | 6 | import torch 7 | 8 | def get_candidate_columns(src): 9 | col_list = re.findall('.*', src)[0].lower().split(' ') 10 | return col_list[1:-1] # remove 11 | 12 | def get_template(src): 13 | col_list = re.findall('.*', src)[0].lower().split(' ') 14 | return col_list[1:-1] # remove 15 | 16 | 17 | def get_all_table_columns(data_file): 18 | with open(data_file, 'r') as fp: 19 | data = json.load(fp) 20 | ''' 21 | return: 22 | {'chinook_1': {'Album': ['AlbumId', 'Title', 'ArtistId'], 23 | 'Artist': ['ArtistId', 'Name'], 24 | 'Customer': ['CustomerId', 
25 | 'FirstName', 26 | ''' 27 | return data 28 | 29 | 30 | def get_chart_type(pred_tokens_list): 31 | return pred_tokens_list[pred_tokens_list.index('mark') + 1] 32 | 33 | 34 | def get_agg_func(pred_tokens_list): 35 | return pred_tokens_list[pred_tokens_list.index('aggregate') + 1] 36 | 37 | 38 | def get_x(pred_tokens_list): 39 | return pred_tokens_list[pred_tokens_list.index('x') + 1] 40 | 41 | 42 | def get_y(pred_tokens_list): 43 | return pred_tokens_list[pred_tokens_list.index('aggregate') + 2] 44 | 45 | 46 | def guide_decoder_by_candidates(db_id, table_id, trg_field, input_source, table_columns, db_tables_columns_types, topk_ids, topk_tokens, 47 | current_token_type, pred_tokens_list): 48 | ''' 49 | get the current token types (X, Y,...), 50 | we use the topk tokens from the decoder and the candidate columns to inference the "best" pred_token. 51 | table_columns: all columns in this table. 52 | topk_tokens: the top-k candidate predicted tokens 53 | current_token_type = x|y|groupby-axis|bin x| if_template:[orderby-axis, order-type, chart_type] 54 | pred_tokens_list: the predicted tokens list 55 | ''' 56 | # candidate columns mentioned by the NL query 57 | candidate_columns = get_candidate_columns(input_source) 58 | 59 | best_token = topk_tokens[0] 60 | best_id = topk_ids[0] 61 | 62 | if current_token_type == 'x_axis': 63 | mark_type = get_chart_type(pred_tokens_list) 64 | 65 | if best_token not in table_columns and '(' not in best_token: 66 | is_in_topk = False 67 | for tok in topk_tokens: 68 | if tok in candidate_columns and tok in table_columns: 69 | # get column's type 70 | if mark_type in ['bar', 'line'] and db_tables_columns_types!=None and db_tables_columns_types[db_id][table_id][tok] != 'numeric': 71 | best_token = tok 72 | best_id = trg_field.vocab.stoi[best_token] 73 | is_in_topk = True 74 | break 75 | if mark_type == 'point' and db_tables_columns_types!=None and db_tables_columns_types[db_id][table_id][tok] == 'numeric': 76 | best_token = tok 77 | 
best_id = trg_field.vocab.stoi[best_token] 78 | is_in_topk = True 79 | break 80 | if mark_type == 'arc' and db_tables_columns_types!=None and db_tables_columns_types[db_id][table_id][tok] != 'numeric': 81 | best_token = tok 82 | best_id = trg_field.vocab.stoi[best_token] 83 | is_in_topk = True 84 | break 85 | 86 | if is_in_topk == False and len(candidate_columns) > 0: 87 | for col in candidate_columns: 88 | if col != '': 89 | if mark_type in ['bar', 'line'] and db_tables_columns_types != None and db_tables_columns_types[db_id][table_id][col] != 'numeric': 90 | best_token = col 91 | best_id = trg_field.vocab.stoi[best_token] 92 | break 93 | 94 | if mark_type == 'point' and db_tables_columns_types != None and db_tables_columns_types[db_id][table_id][col] == 'numeric': 95 | best_token = col 96 | best_id = trg_field.vocab.stoi[best_token] 97 | break 98 | if mark_type == 'arc' and db_tables_columns_types != None and db_tables_columns_types[db_id][table_id][col] != 'numeric': 99 | best_token = col 100 | best_id = trg_field.vocab.stoi[best_token] 101 | break 102 | 103 | if current_token_type == 'y_axis': 104 | mark_type = get_chart_type(pred_tokens_list) 105 | agg_function = get_agg_func(pred_tokens_list) 106 | selected_x = get_x(pred_tokens_list) 107 | 108 | y = best_token 109 | 110 | if y not in table_columns and y != 'distinct': 111 | is_in_topk = False 112 | for tok in topk_tokens: 113 | if tok in candidate_columns and tok in table_columns: 114 | if mark_type in ['bar', 'arc', 'line'] and agg_function == 'count': 115 | best_token = tok 116 | best_id = trg_field.vocab.stoi[best_token] 117 | is_in_topk = True 118 | break 119 | if mark_type in ['bar', 'arc', 'line'] and agg_function != 'count' and \ 120 | db_tables_columns_types != None and db_tables_columns_types[db_id][table_id][tok] == 'numeric': 121 | best_token = tok 122 | best_id = trg_field.vocab.stoi[best_token] 123 | is_in_topk = True 124 | break 125 | if mark_type == 'point' and tok != selected_x: 126 | 
best_token = tok 127 | best_id = trg_field.vocab.stoi[best_token] 128 | break 129 | 130 | if is_in_topk == False and len(candidate_columns) > 0: 131 | for col in candidate_columns: 132 | if col != '': 133 | if mark_type in ['bar', 'arc', 'line'] and agg_function == 'count': 134 | best_token = col 135 | best_id = trg_field.vocab.stoi[best_token] 136 | is_in_topk = True 137 | break 138 | if mark_type in ['bar', 'arc', 'line'] and agg_function != 'count' and \ 139 | db_tables_columns_types != None and db_tables_columns_types[db_id][table_id][col] == 'numeric': 140 | best_token = col 141 | best_id = trg_field.vocab.stoi[best_token] 142 | break 143 | if mark_type == 'point' and col != selected_x: 144 | best_token = col 145 | best_id = trg_field.vocab.stoi[best_token] 146 | break 147 | 148 | # TODO! 149 | if (y in table_columns and y not in candidate_columns) and ('(' not in y): 150 | for tok in topk_tokens: 151 | if tok in candidate_columns and tok in table_columns: 152 | best_token = tok 153 | best_id = trg_field.vocab.stoi[best_token] 154 | is_in_topk = True 155 | break 156 | 157 | if current_token_type == 'z_axis': 158 | selected_x = get_x(pred_tokens_list) 159 | selected_y = get_y(pred_tokens_list) 160 | 161 | if best_token not in table_columns or best_token == selected_x or best_token == selected_y: 162 | is_in_topk = False 163 | for tok in topk_tokens: 164 | if tok in candidate_columns and tok in table_columns: 165 | # get column's type 166 | if selected_x != tok and selected_y != tok and db_tables_columns_types !=None and db_tables_columns_types[db_id][table_id][tok] == 'categorical': 167 | best_token = tok 168 | best_id = trg_field.vocab.stoi[best_token] 169 | is_in_topk = True 170 | break 171 | 172 | if is_in_topk == False and len(candidate_columns) > 0: 173 | for col in candidate_columns: 174 | if col != selected_x and col != selected_y and db_tables_columns_types!=None and db_tables_columns_types[db_id][table_id][ 175 | col] == 'categorical': 176 | best_token 
= col 177 | best_id = trg_field.vocab.stoi[best_token] 178 | break 179 | 180 | if selected_x == best_token or selected_y == best_token or db_tables_columns_types != None and db_tables_columns_types[db_id][table_id][ 181 | best_token] != 'categorical': 182 | for tok in topk_tokens: 183 | if tok in candidate_columns and tok in table_columns: 184 | # get column's type 185 | if selected_x != tok and selected_y != tok and db_tables_columns_types != None and db_tables_columns_types[db_id][table_id][ 186 | tok] == 'categorical': 187 | best_token = tok 188 | best_id = trg_field.vocab.stoi[best_token] 189 | break 190 | 191 | if current_token_type == 'topk': # bin [x] by .. 192 | is_in_topk = False 193 | if best_token.isdigit() == False: 194 | for tok in topk_tokens: 195 | if tok.isdigit(): 196 | best_token = tok 197 | is_in_topk = True 198 | if is_in_topk == False: 199 | best_token = '3' # default 200 | best_id = trg_field.vocab.stoi[best_token] 201 | 202 | if current_token_type == 'groupby_axis': 203 | if best_token != 'x': 204 | if best_token not in table_columns or db_tables_columns_types != None and db_tables_columns_types[db_id][table_id][best_token] == 'numeric': 205 | is_in_topk = False 206 | for tok in topk_tokens: 207 | if tok in candidate_columns and tok in table_columns: 208 | # get column's type 209 | if db_tables_columns_types != None and db_tables_columns_types[db_id][table_id][tok] == 'categorical': 210 | best_token = tok 211 | best_id = trg_field.vocab.stoi[best_token] 212 | is_in_topk = True 213 | break 214 | 215 | if is_in_topk == False: 216 | best_token = get_x(pred_tokens_list) 217 | best_id = trg_field.vocab.stoi[best_token] 218 | 219 | if current_token_type == 'bin_axis': # bin [x] by .. 
220 | best_token = 'x' 221 | best_id = trg_field.vocab.stoi[best_token] 222 | 223 | template_list = get_template(input_source) 224 | 225 | if '[t]' not in template_list: # have the chart template 226 | if current_token_type == 'chart_type': 227 | best_token = template_list[template_list.index('mark') + 1] 228 | best_id = trg_field.vocab.stoi[best_token] 229 | 230 | if current_token_type == 'orderby_axis': 231 | # print('Case-3') 232 | if template_list[template_list.index('sort') + 1] == '[x]': 233 | best_token = 'x' 234 | best_id = trg_field.vocab.stoi[best_token] 235 | 236 | elif template_list[template_list.index('sort') + 1] == '[y]': 237 | best_token = 'y' 238 | best_id = trg_field.vocab.stoi[best_token] 239 | else: 240 | pass 241 | # print('Let me know this issue!') 242 | 243 | if current_token_type == 'orderby_type': 244 | best_token = template_list[template_list.index('sort') + 2] 245 | best_id = trg_field.vocab.stoi[best_token] 246 | 247 | return best_id, best_token 248 | 249 | 250 | def translate_sentence(sentence, src_field, trg_field, TOK_TYPES, tok_types, model, device, max_len=128): 251 | model.eval() 252 | 253 | # process the tok_type 254 | if isinstance(tok_types, str): 255 | tok_types_ids = tok_types.lower().split(' ') 256 | else: 257 | tok_types_ids = [tok_type.lower() for tok_type in tok_types] 258 | tok_types_ids = [TOK_TYPES.init_token] + tok_types_ids + [TOK_TYPES.eos_token] 259 | tok_types_ids_indexes = [TOK_TYPES.vocab.stoi[tok_types_id] for tok_types_id in tok_types_ids] 260 | tok_types_tensor = torch.LongTensor(tok_types_ids_indexes).unsqueeze(0).to(device) 261 | 262 | if isinstance(sentence, str): 263 | tokens = sentence.lower().split(' ') 264 | else: 265 | tokens = [token.lower() for token in sentence] 266 | 267 | tokens = [src_field.init_token] + tokens + [src_field.eos_token] 268 | 269 | src_indexes = [src_field.vocab.stoi[token] for token in tokens] 270 | 271 | src_tensor = torch.LongTensor(src_indexes).unsqueeze(0).to(device) 272 | 273 
| src_mask = model.make_src_mask(src_tensor) 274 | 275 | # visibility matrix 276 | batch_visibility_matrix = model.make_visibility_matrix(src_tensor, src_field) 277 | 278 | with torch.no_grad(): 279 | enc_src, enc_attention = model.encoder(src_tensor, src_mask, tok_types_tensor, batch_visibility_matrix) 280 | 281 | trg_indexes = [trg_field.vocab.stoi[trg_field.init_token]] 282 | 283 | for i in range(max_len): 284 | 285 | trg_tensor = torch.LongTensor(trg_indexes).unsqueeze(0).to(device) 286 | 287 | trg_mask = model.make_trg_mask(trg_tensor) 288 | 289 | with torch.no_grad(): 290 | output, attention = model.decoder(trg_tensor, enc_src, trg_mask, src_mask) 291 | 292 | pred_token = output.argmax(2)[:, -1].item() 293 | 294 | trg_indexes.append(pred_token) 295 | 296 | if pred_token == trg_field.vocab.stoi[trg_field.eos_token]: 297 | break 298 | 299 | trg_tokens = [trg_field.vocab.itos[i] for i in trg_indexes] 300 | 301 | return trg_tokens[1:], attention, enc_attention 302 | 303 | 304 | def translate_sentence_with_guidance(db_id, table_id, sentence, src_field, trg_field, TOK_TYPES, tok_types, SRC, model, 305 | db_tables_columns, db_tables_columns_types, device, max_len=128, show_progress = False): 306 | model.eval() 307 | # process the tok_type 308 | if isinstance(tok_types, str): 309 | tok_types_ids = tok_types.lower().split(' ') 310 | else: 311 | tok_types_ids = [tok_type.lower() for tok_type in tok_types] 312 | tok_types_ids = [TOK_TYPES.init_token] + \ 313 | tok_types_ids + [TOK_TYPES.eos_token] 314 | tok_types_ids_indexes = [TOK_TYPES.vocab.stoi[tok_types_id] 315 | for tok_types_id in tok_types_ids] 316 | tok_types_tensor = torch.LongTensor( 317 | tok_types_ids_indexes).unsqueeze(0).to(device) 318 | 319 | if isinstance(sentence, str): 320 | tokens = sentence.lower().split(' ') 321 | else: 322 | tokens = [token.lower() for token in sentence] 323 | 324 | tokens = [src_field.init_token] + tokens + [src_field.eos_token] 325 | 326 | src_indexes = 
[src_field.vocab.stoi[token] for token in tokens] 327 | 328 | src_tensor = torch.LongTensor(src_indexes).unsqueeze(0).to(device) 329 | 330 | src_mask = model.make_src_mask(src_tensor) 331 | 332 | # visibility matrix 333 | batch_visibility_matrix = model.make_visibility_matrix(src_tensor, SRC) 334 | 335 | with torch.no_grad(): 336 | enc_src, enc_attention = model.encoder(src_tensor, src_mask, 337 | tok_types_tensor, batch_visibility_matrix) 338 | 339 | trg_indexes = [trg_field.vocab.stoi[trg_field.init_token]] 340 | trg_tokens = [] 341 | 342 | current_token_type = None 343 | if show_progress == True: 344 | print('Show the details in each tokens:') 345 | 346 | for i in range(max_len): 347 | 348 | trg_tensor = torch.LongTensor(trg_indexes).unsqueeze(0).to(device) 349 | 350 | trg_mask = model.make_trg_mask(trg_tensor) 351 | 352 | with torch.no_grad(): 353 | output, attention = model.decoder(trg_tensor, enc_src, trg_mask, src_mask) 354 | 355 | table_columns = [] 356 | try: # get all columns in a table 357 | table_columns = db_tables_columns[db_id][table_id] 358 | except: 359 | print('[Fail] get all columns in a table') 360 | table_columns = [] 361 | 362 | if current_token_type == 'table_name': 363 | ''' 364 | only for single table !!! 
365 | ''' 366 | pred_token = table_id 367 | pred_id = trg_field.vocab.stoi[pred_token] 368 | if show_progress == True: 369 | print('-------------------\nCurrent Token Type: Table Name , top-3 tokens: [{}]'.format( 370 | current_token_type, pred_token)) 371 | 372 | else: 373 | topk_ids = torch.topk(output, k=5, dim=2, sorted=True).indices[:, -1, :].tolist()[0] 374 | topk_tokens = [trg_field.vocab.itos[tok_id] for tok_id in topk_ids] 375 | 376 | ''' 377 | apply guide_decoder_by_candidates 378 | ''' 379 | pred_id, pred_token = guide_decoder_by_candidates( 380 | db_id, table_id, trg_field, sentence, table_columns, db_tables_columns_types, topk_ids, 381 | topk_tokens, current_token_type, trg_tokens 382 | ) 383 | if show_progress == True: 384 | if current_token_type == None: 385 | print('-------------------\nCurrent Token Type: Query Sketch Part , top-3 tokens: [{}]'.format(', '.join(topk_tokens))) 386 | else: 387 | print('-------------------\nCurrent Token Type: {} , original top-3 tokens: [{}] , the final tokens by VisAwareTranslation: {}'.format(current_token_type, ', '.join(topk_tokens), pred_token)) 388 | 389 | current_token_type = None 390 | 391 | trg_indexes.append(pred_id) 392 | trg_tokens.append(pred_token) 393 | 394 | # update the current_token_type and pred_aix here 395 | # mark bar data apartments encoding x apt_type_code y aggregate count apt_type_code transform group x sort y desc 396 | if i == 0: 397 | current_token_type = 'chart_type' 398 | 399 | if i > 1: 400 | if trg_tokens[-1] == 'data' and trg_tokens[-2] in ['bar', 'arc', 'line', 'point']: 401 | current_token_type = 'table_name' 402 | 403 | if i > 2: 404 | if trg_tokens[-1] == 'x' and trg_tokens[-2] == 'encoding': 405 | current_token_type = 'x_axis' 406 | 407 | if trg_tokens[-1] == 'aggregate' and trg_tokens[-2] == 'y': 408 | current_token_type = 'aggFunction' 409 | 410 | if trg_tokens[-2] == 'aggregate' and trg_tokens[-1] in ['count', 'sum', 'mean', 'avg', 'max', 'min']: 411 | current_token_type = 
'y_axis' 412 | 413 | if trg_tokens[-3] == 'aggregate' and trg_tokens[-2] in ['count', 'sum', 'mean', 'avg', 'max', 'min'] and \ 414 | trg_tokens[-1] == 'distinct': 415 | current_token_type = 'y_axis' 416 | 417 | # mark [T] data photos encoding x [X] y aggregate [AggFunction] [Y] color [Z] transform filter [F] group [G] bin [B] sort [S] topk [K] 418 | if trg_tokens[-1] == 'color' and trg_tokens[-4] == 'aggregate': 419 | current_token_type = 'z_axis' 420 | 421 | if trg_tokens[-1] == 'bin': 422 | current_token_type = 'bin_axis' 423 | 424 | if trg_tokens[-1] == 'group': 425 | current_token_type = 'groupby_axis' 426 | 427 | if trg_tokens[-1] == 'sort': 428 | current_token_type = 'orderby_axis' 429 | 430 | if trg_tokens[-2] == 'sort' and trg_tokens[-1] in ['x', 'y']: 431 | current_token_type = 'orderby_type' 432 | 433 | if trg_tokens[-1] == 'topk': 434 | current_token_type = 'topk' 435 | 436 | if pred_id == trg_field.vocab.stoi[trg_field.eos_token]: 437 | break 438 | 439 | return trg_tokens, attention, enc_attention 440 | 441 | 442 | def postprocessing_group(gold_q_tok, pred_q_tok): 443 | # 2. checking (and correct) group-by 444 | 445 | # rule: if other part is the same, and only add group-by part, the result should be the same 446 | if 'group' not in gold_q_tok and 'group' in pred_q_tok: 447 | groupby_x = pred_q_tok[pred_q_tok.index('group') + 1] 448 | if ' '.join(pred_q_tok).replace('group ' + groupby_x, '') == ' '.join(gold_q_tok): 449 | pred_q_tok = gold_q_tok 450 | 451 | return pred_q_tok 452 | 453 | 454 | def postprocessing(gold_query, pred_query, if_template, src_input): 455 | try: 456 | # get the template: 457 | chart_template = re.findall('.*', src_input)[0] 458 | chart_template_tok = chart_template.lower().split(' ') 459 | 460 | gold_q_tok = gold_query.lower().split(' ') 461 | pred_q_tok = pred_query.lower().split(' ') 462 | 463 | # 0. visualize type. if we have the template, the visualization type must be matched. 
464 | if if_template: 465 | pred_q_tok[pred_q_tok.index('mark') + 1] = gold_q_tok[gold_q_tok.index('mark') + 1] 466 | 467 | # 1. Table Checking. If we focus on single table, must match!!! 468 | if 'data' in pred_q_tok and 'data' in gold_q_tok: 469 | pred_q_tok[pred_q_tok.index('data') + 1] = gold_q_tok[gold_q_tok.index('data') + 1] 470 | 471 | pred_q_tok = postprocessing_group(gold_q_tok, pred_q_tok) 472 | 473 | # 3. Order-by. if we have the template, we can checking (and correct) the predicting order-by 474 | # rule 1: if have the template, order by [x]/[y], trust to the select [x]/[y] 475 | if 'sort' in gold_q_tok and 'sort' in pred_q_tok and if_template: 476 | order_by_which_axis = chart_template_tok[chart_template_tok.index('sort') + 1] # [x], [y], or [o] 477 | if order_by_which_axis == '[x]': 478 | pred_q_tok[pred_q_tok.index('sort') + 1] = 'x' 479 | elif order_by_which_axis == '[y]': 480 | pred_q_tok[pred_q_tok.index('sort') + 1] = 'y' 481 | else: 482 | pass 483 | 484 | elif 'sort' in gold_q_tok and 'sort' not in pred_q_tok and if_template: 485 | order_by_which_axis = chart_template_tok[chart_template_tok.index('sort') + 1] # [x], [y], or [o] 486 | order_type = chart_template_tok[chart_template_tok.index('sort') + 2] 487 | 488 | if 'x' == gold_q_tok[gold_q_tok.index('sort') + 1] or 'y' == gold_q_tok[gold_q_tok.index('sort') + 1]: 489 | pred_q_tok += ['sort', gold_q_tok[gold_q_tok.index('sort') + 1]] 490 | if gold_q_tok.index('sort') + 2 < len(gold_q_tok): 491 | pred_q_tok += [gold_q_tok[gold_q_tok.index('sort') + 2]] 492 | else: 493 | pass 494 | 495 | else: 496 | pass 497 | 498 | pred_q_tok = postprocessing_group(gold_q_tok, pred_q_tok) 499 | 500 | # 4. 
checking (and correct) bining 501 | # rule 1: [interval] bin 502 | # rule 2: bin by [x] 503 | if 'bin' in gold_q_tok and 'bin' in pred_q_tok: 504 | # rule 1 505 | if_bin_gold, if_bin_pred = False, False 506 | 507 | for binn in ['by time', 'by year', 'by weekday', 'by month']: 508 | if binn in gold_query: 509 | if_bin_gold = binn 510 | if binn in pred_query: 511 | if_bin_pred = binn 512 | 513 | if (if_bin_gold != False and if_bin_pred != False) and (if_bin_gold != if_bin_pred): 514 | pred_q_tok[pred_q_tok.index('bin') + 3] = if_bin_gold.replace('by ', '') 515 | 516 | if 'bin' in gold_q_tok and 'bin' not in pred_q_tok and 'group' in pred_q_tok: 517 | # rule 3: group-by x and bin x by time in the bar chart should be the same. 518 | bin_x = gold_q_tok[gold_q_tok.index('bin') + 1] 519 | group_x = pred_q_tok[pred_q_tok.index('group') + 1] 520 | if bin_x == group_x: 521 | if ''.join(pred_q_tok).replace('group ' + group_x, '') == ''.join(gold_q_tok).replace( 522 | 'bin ' + bin_x + ' by time', ''): 523 | pred_q_tok = gold_q_tok 524 | 525 | # group x | bin x ... 
count A == count B 526 | if 'count' in gold_q_tok and 'count' in pred_q_tok: 527 | if ('group' in gold_q_tok and 'group' in pred_q_tok) or ('bin' in gold_q_tok and 'bin' in pred_q_tok): 528 | pred_count = pred_q_tok[pred_q_tok.index('count') + 1] 529 | gold_count = gold_q_tok[gold_q_tok.index('count') + 1] 530 | if ' '.join(pred_q_tok).replace('count ' + pred_count, 'count ' + gold_count) == ' '.join(gold_q_tok): 531 | pred_q_tok = gold_q_tok 532 | 533 | except: 534 | print('error at post processing') 535 | return ' '.join(pred_q_tok) -------------------------------------------------------------------------------- /ncNet-VIS21.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HKUSTDial/ncNet/415d9477d424296bc5414e0f6624af23643372d7/ncNet-VIS21.pdf -------------------------------------------------------------------------------- /ncNet.py: -------------------------------------------------------------------------------- 1 | __author__ = "Yuyu Luo" 2 | 3 | import pandas as pd 4 | import sqlite3 5 | import re 6 | import os 7 | 8 | import torch 9 | from model.VisAwareTranslation import translate_sentence_with_guidance, translate_sentence, postprocessing 10 | from model.Model import Seq2Seq 11 | from model.Encoder import Encoder 12 | from model.Decoder import Decoder 13 | from preprocessing.build_vocab import build_vocab 14 | 15 | 16 | from utilities.vis_rendering import VegaZero2VegaLite 17 | from preprocessing.process_dataset import ProcessData4Training 18 | from vega import VegaLite 19 | 20 | os.chdir(os.path.dirname(os.path.abspath(__file__))) 21 | 22 | class ncNet(object): 23 | def __init__(self, trained_model): 24 | self.data = None 25 | self.db_id = '' 26 | self.table_id = '' 27 | self.db_tables_columns = None 28 | self.db_tables_columns_types = None 29 | self.trained_model = trained_model 30 | 31 | self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 32 | 33 | self.SRC, 
self.TRG, self.TOK_TYPES, BATCH_SIZE, train_iterator, valid_iterator, test_iterator, self.my_max_length = build_vocab( 34 | data_dir='./dataset/dataset_final/', 35 | db_info='./dataset/database_information.csv', 36 | batch_size=128, 37 | max_input_length=128 38 | ) 39 | 40 | INPUT_DIM = len(self.SRC.vocab) 41 | OUTPUT_DIM = len(self.TRG.vocab) 42 | HID_DIM = 256 # it equals to embedding dimension 43 | ENC_LAYERS = 3 44 | DEC_LAYERS = 3 45 | ENC_HEADS = 8 46 | DEC_HEADS = 8 47 | ENC_PF_DIM = 512 48 | DEC_PF_DIM = 512 49 | ENC_DROPOUT = 0.1 50 | DEC_DROPOUT = 0.1 51 | 52 | enc = Encoder(INPUT_DIM, 53 | HID_DIM, 54 | ENC_LAYERS, 55 | ENC_HEADS, 56 | ENC_PF_DIM, 57 | ENC_DROPOUT, 58 | self.device, 59 | self.TOK_TYPES, 60 | self.my_max_length 61 | ) 62 | 63 | dec = Decoder(OUTPUT_DIM, 64 | HID_DIM, 65 | DEC_LAYERS, 66 | DEC_HEADS, 67 | DEC_PF_DIM, 68 | DEC_DROPOUT, 69 | self.device, 70 | self.my_max_length 71 | ) 72 | 73 | SRC_PAD_IDX = self.SRC.vocab.stoi[self.SRC.pad_token] 74 | TRG_PAD_IDX = self.TRG.vocab.stoi[self.TRG.pad_token] 75 | 76 | self.ncNet = Seq2Seq(enc, dec, self.SRC, SRC_PAD_IDX, TRG_PAD_IDX, self.device).to(self.device) # define the transformer-based ncNet 77 | self.ncNet.load_state_dict(torch.load(trained_model, map_location=self.device)) 78 | 79 | 80 | def specify_dataset(self, data_type, db_url = None, table_name = None, data = None, data_url = None): 81 | ''' 82 | :param data_type: sqlite3, csv, json 83 | :param db_url: db path for sqlite3 database, e.g., './dataset/database/flight/flight.sqlite' 84 | :param table_name: the table name in a sqlite3 85 | :param data: DataFrame for csv 86 | :param data_url: data path for csv or json 87 | :return: save the DataFrame in the self.data 88 | ''' 89 | self.db_id = 'temp_' + table_name 90 | self.table_id = table_name 91 | 92 | if data_type == 'csv': 93 | if data != None and data_url == None: 94 | self.data = data 95 | elif data == None and data_url != None: 96 | self.data = pd.read_csv(data_url) 97 | else: 
98 | raise ValueError('Please only specify one of the data or data_url') 99 | elif data_type == 'json': 100 | if data == None and data_url != None: 101 | self.data = pd.read_json(data_url) 102 | else: 103 | raise ValueError('Read JSON from the json file, please only specify the "data_type" or "data_url"') 104 | 105 | elif data_type == 'sqlite3': 106 | # Create your connection. 107 | try: 108 | cnx = sqlite3.connect(db_url) 109 | self.data = pd.read_sql_query("SELECT * FROM " + table_name, cnx) 110 | except: 111 | raise ValueError('Errors in read table from sqlite3 database. \ndb_url: {0}\n table_name : {1} '.format(data_url, table_name)) 112 | 113 | else: 114 | if data != None and type(data) == pd.core.frame.DataFrame: 115 | self.data = data 116 | else: 117 | raise ValueError('The data type must be one of the csv, json, sqlite3, or a DataFrame object.') 118 | 119 | self.db_tables_columns_types = dict() 120 | self.db_tables_columns_types[self.db_id] = dict() 121 | self.db_tables_columns_types[self.db_id][table_name] = dict() 122 | for col, _type in self.data.dtypes.items(): 123 | # print(col, _type) 124 | if 'int' in str(_type).lower() or 'float' in str(_type).lower(): 125 | _type = 'numeric' 126 | else: 127 | _type = 'categorical' 128 | self.db_tables_columns_types[self.db_id][table_name][col.lower()] = _type 129 | 130 | # print(self.db_tables_columns_types) 131 | 132 | self.data.columns = self.data.columns.str.lower() # to lowercase 133 | 134 | self.db_tables_columns = { 135 | self.db_id:{ 136 | self.table_id: list(self.data.columns) 137 | } 138 | } 139 | 140 | if data_type == 'json' or data_type == 'sqlite3': 141 | # write to sqlite3 database 142 | if not os.path.exists('./dataset/database/'+self.db_id): 143 | os.makedirs('./dataset/database/'+self.db_id) 144 | 145 | conn = sqlite3.connect('./dataset/database/'+self.db_id+'/'+self.db_id+'.sqlite') 146 | 147 | self.data.to_sql(self.table_id, conn, if_exists='replace', index=False) 148 | 149 | self.DataProcesser = 
ProcessData4Training(db_url='./dataset/database') 150 | self.db_table_col_val_map = dict() 151 | table_cols = self.DataProcesser.get_table_columns(self.db_id) 152 | self.db_table_col_val_map[self.db_id] = dict() 153 | for table, cols in table_cols.items(): 154 | col_val_map = self.DataProcesser.get_values_in_columns(self.db_id, table, cols, conditions='remove') 155 | self.db_table_col_val_map[self.db_id][table] = col_val_map 156 | 157 | def show_dataset(self, top_rows=5): 158 | return self.data[:top_rows] 159 | 160 | 161 | def nl2vis(self, nl_question, chart_template=None, show_progress=None, visualization_aware_translation=True): 162 | # process and the nl_question and the chart template as input. 163 | # call the model to perform prediction 164 | # render the predicted query 165 | query2vl = VegaZero2VegaLite() 166 | 167 | input_src, token_types = self.process_input(nl_question, chart_template) 168 | 169 | if visualization_aware_translation == True: 170 | # print("\nGenerate the visualization by visualization-aware translation:\n") 171 | 172 | pred_query, attention, enc_attention = translate_sentence_with_guidance( 173 | self.db_id, self.table_id, input_src, self.SRC, self.TRG, self.TOK_TYPES, token_types, 174 | self.SRC, self.ncNet, self.db_tables_columns, self.db_tables_columns_types, self.device, self.my_max_length, show_progress 175 | ) 176 | 177 | pred_query = ' '.join(pred_query).replace(' ', '').lower() 178 | if chart_template != None: 179 | pred_query = postprocessing(pred_query, pred_query, True, input_src) 180 | else: 181 | pred_query = postprocessing(pred_query, pred_query, False, input_src) 182 | 183 | pred_query = ' '.join(pred_query.replace('"', "'").split()) 184 | 185 | print('[NL Question]:', nl_question) 186 | print('[Chart Template]:', chart_template) 187 | print('[Predicted VIS Query]:', pred_query) 188 | 189 | # print('[The Predicted VIS Result]:') 190 | return VegaLite(query2vl.to_VegaLite(pred_query, self.data)), 
query2vl.to_VegaLite(pred_query, self.data) 191 | # print('\n') 192 | 193 | else: 194 | # print("\nGenerate the visualization by greedy decoding:\n") 195 | 196 | pred_query, attention, enc_attention = translate_sentence( 197 | input_src, self.SRC, self.TRG, self.TOK_TYPES, token_types, self.ncNet, self.device, self.my_max_length 198 | ) 199 | 200 | pred_query = ' '.join(pred_query).replace(' ', '').lower() 201 | if chart_template != None: 202 | pred_query = postprocessing(pred_query, pred_query, True, input_src) 203 | else: 204 | pred_query = postprocessing(pred_query, pred_query, False, input_src) 205 | 206 | pred_query = ' '.join(pred_query.replace('"', "'").split()) 207 | 208 | print('[NL Question]:', nl_question) 209 | print('[Chart Template]:', chart_template) 210 | print('[Predicted VIS Query]:', pred_query) 211 | 212 | # print('[The Predicted VIS Result]:') 213 | return VegaLite(query2vl.to_VegaLite(pred_query, self.data)), query2vl.to_VegaLite(pred_query, self.data) 214 | 215 | 216 | def process_input(self, nl_question, chart_template): 217 | 218 | def get_token_types(input_source): 219 | # print('input_source:', input_src) 220 | 221 | token_types = '' 222 | 223 | for ele in re.findall('.*', input_source)[0].split(' '): 224 | token_types += ' nl' 225 | 226 | for ele in re.findall('.*', input_source)[0].split(' '): 227 | token_types += ' template' 228 | 229 | token_types += ' table table' 230 | 231 | for ele in re.findall('.*', input_source)[0].split(' '): 232 | token_types += ' col' 233 | 234 | for ele in re.findall('.*', input_source)[0].split(' '): 235 | token_types += ' value' 236 | 237 | token_types += ' table' 238 | 239 | token_types = token_types.strip() 240 | return token_types 241 | 242 | def fix_chart_template(chart_template = None): 243 | query_template = 'mark [T] data [D] encoding x [X] y aggregate [AggFunction] [Y] color [Z] transform filter [F] group [G] bin [B] sort [S] topk [K]' 244 | 245 | if chart_template != None: 246 | try: 247 | 
query_template = query_template.replace('[T]', chart_template['chart']) 248 | except: 249 | raise ValueError('Error at settings of chart type!') 250 | 251 | try: 252 | if 'sorting_options' in chart_template and chart_template['sorting_options'] != None: 253 | order_xy = '[O]' 254 | if 'axis' in chart_template['sorting_options']: 255 | if chart_template['sorting_options']['axis'].lower() == 'x': 256 | order_xy = '[X]' 257 | elif chart_template['sorting_options']['axis'].lower() == 'y': 258 | order_xy = '[Y]' 259 | else: 260 | order_xy = '[O]' 261 | 262 | order_type = 'ASC' 263 | if 'type' in chart_template['sorting_options']: 264 | if chart_template['sorting_options']['type'].lower() == 'desc': 265 | order_type = 'DESC' 266 | elif chart_template['sorting_options']['type'].lower() == 'asc': 267 | order_type = 'ASC' 268 | else: 269 | raise ValueError('Unknown order by settings, the order-type must be "desc", or "asc"') 270 | query_template = query_template.replace('sort [S]', 'sort '+order_xy+' '+order_type) 271 | except: 272 | raise ValueError('Error at settings of sorting!') 273 | 274 | return query_template 275 | else: 276 | return query_template 277 | 278 | query_template = fix_chart_template(chart_template) 279 | # get a list of mentioned values in the NL question 280 | col_names, value_names = self.DataProcesser.get_mentioned_values_in_NL_question( 281 | self.db_id, self.table_id, nl_question, db_table_col_val_map=self.db_table_col_val_map 282 | ) 283 | col_names = ' '.join(str(e) for e in col_names) 284 | value_names = ' '.join(str(e) for e in value_names) 285 | input_src = " {} {} {} {} {} ".format(nl_question, query_template, self.table_id, col_names, value_names).lower() 286 | token_types = get_token_types(input_src) 287 | 288 | return input_src, token_types 289 | 290 | 291 | if __name__ == '__main__': 292 | ncNet = ncNet( 293 | trained_model='./save_models/trained_model.pt' 294 | ) 295 | ncNet.specify_dataset( 296 | data_type='sqlite3', 297 | 
db_url='./dataset/database/car_1/car_1.sqlite', 298 | table_name='cars_data' 299 | ) 300 | ncNet.nl2vis( 301 | nl_question='What is the average weight and year for each year. Plot them as line chart.', 302 | chart_template=None 303 | ) 304 | -------------------------------------------------------------------------------- /preprocessing/build_vocab.py: -------------------------------------------------------------------------------- 1 | __author__ = "Yuyu Luo" 2 | 3 | import torch 4 | from torchtext.data import Field, TabularDataset, BucketIterator 5 | 6 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 7 | 8 | def build_vocab(data_dir, db_info, batch_size, max_input_length): 9 | 10 | def tokenizer(text): 11 | return text.split(' ') 12 | 13 | # def tokenizer_src(text): 14 | # return text.split(' ') 15 | 16 | SRC = Field(tokenize=tokenizer, 17 | init_token='', 18 | eos_token='', 19 | lower=True, 20 | batch_first=True) 21 | 22 | TOK_TYPES = Field(tokenize=tokenizer, 23 | init_token='', 24 | eos_token='', 25 | lower=True, 26 | batch_first=True) 27 | 28 | # TODO data_dir = './Code/dataset/vega_zero/dataset_final/' 29 | train_data, valid_data, test_data = TabularDataset.splits( 30 | path=data_dir, format='csv', skip_header=True, 31 | train='train.csv', validation='dev.csv', test='test.csv', 32 | fields=[ 33 | ('tvBench_id', None), 34 | ('db_id', None), 35 | ('chart', None), 36 | ('hardness', None), 37 | ('query', None), 38 | ('question', None), 39 | ('vega_zero', None), 40 | ('mentioned_columns', None), 41 | ('mentioned_values', None), 42 | ('query_template', None), 43 | ('src', SRC), 44 | ('trg', SRC), 45 | ('tok_types', TOK_TYPES) 46 | ]) 47 | 48 | # TODO db_info = './Code/dataset/database_information.csv', 49 | db_information = TabularDataset( 50 | path=db_info, 51 | format='csv', 52 | skip_header=True, 53 | fields=[ 54 | ('table', SRC), 55 | ('column', SRC), 56 | ('value', SRC) 57 | ] 58 | ) 59 | 60 | SRC.build_vocab(train_data, valid_data, 
test_data, db_information, min_freq=2) 61 | TRG = SRC 62 | TOK_TYPES.build_vocab(train_data, valid_data, test_data, db_information, min_freq=2) 63 | 64 | train_iterator, valid_iterator, test_iterator = BucketIterator.splits( 65 | (train_data, valid_data, test_data), sort=False, 66 | batch_size=batch_size, 67 | device=device) 68 | 69 | return SRC, TRG, TOK_TYPES, batch_size, train_iterator, valid_iterator, test_iterator, max_input_length -------------------------------------------------------------------------------- /preprocessing/process_dataset.py: -------------------------------------------------------------------------------- 1 | __author__ = "Yuyu Luo" 2 | 3 | 4 | import pandas as pd 5 | import sqlite3 6 | from dateutil.parser import parse 7 | import json 8 | import py_stringsimjoin as ssj 9 | import re 10 | import os 11 | import time 12 | 13 | 14 | class ProcessData4Training(object): 15 | def __init__(self, db_url): 16 | self.db_url = db_url 17 | 18 | # def is_date(string, fuzzy=False): 19 | # """ 20 | # Return whether the string can be interpreted as a date. 
21 | # 22 | # :param string: str, string to check for date 23 | # :param fuzzy: bool, ignore unknown tokens in string if True 24 | # """ 25 | # try: 26 | # parse(string, fuzzy=fuzzy) 27 | # return True 28 | # 29 | # except ValueError: 30 | # return False 31 | # 32 | # def levenshteinSimilarity(s1, s2): 33 | # # Edit Similarity 34 | # s1, s2 = s1.lower(), s2.lower() 35 | # if len(s1) > len(s2): 36 | # s1, s2 = s2, s1 37 | # 38 | # distances = range(len(s1) + 1) 39 | # for i2, c2 in enumerate(s2): 40 | # distances_ = [i2 + 1] 41 | # for i1, c1 in enumerate(s1): 42 | # if c1 == c2: 43 | # distances_.append(distances[i1]) 44 | # else: 45 | # distances_.append(1 + min((distances[i1], distances[i1 + 1], distances_[-1]))) 46 | # distances = distances_ 47 | # return 1 - distances[-1] / max(len(s1) + 1, len(s2) + 1) # [0,1] 48 | 49 | def get_table_columns(self, db_id): 50 | table_columns = dict() 51 | ''' 52 | get a list of column names of the tables. 53 | ''' 54 | try: 55 | connection = sqlite3.connect(self.db_url + '/' + db_id + '/' + db_id + '.sqlite') 56 | cursor = connection.execute("SELECT name FROM sqlite_master WHERE type='table';") 57 | 58 | for each_table in cursor.fetchall(): 59 | try: 60 | cursor = connection.execute('select * from ' + each_table[0]) 61 | columns_list = list(map(lambda x: x[0].lower(), cursor.description)) # a list of column names 62 | table_columns[each_table[0].lower()] = columns_list 63 | except: 64 | print('table error: ', each_table[0]) 65 | table_columns[each_table[0].lower()] = [] 66 | except: 67 | print('db error: ', db_id) 68 | 69 | return table_columns 70 | 71 | def get_values_in_columns(self, db_id, table_id, columns_list, conditions=None): 72 | ''' 73 | get values in the column 74 | 75 | arg: 76 | conditions: { 77 | 'numeric_col': 'remove', 78 | 'string_col': { 79 | 'remove': '50' 80 | } 81 | } 82 | ''' 83 | values_in_columns = dict() 84 | 85 | connection = sqlite3.connect(self.db_url + '/' + db_id + '/' + db_id + '.sqlite') 86 | 
cursor = connection.cursor() # get a cursor 87 | for col in columns_list: 88 | try: 89 | if conditions == None: 90 | values_in_columns[col] = list( 91 | set([values[0] for values in cursor.execute("select " + col + " from " + table_id)])) 92 | else: 93 | my_list = list( 94 | set([values[0] for values in cursor.execute("select " + col + " from " + table_id)])) 95 | if all(isinstance(item, int) or isinstance(item, float) or str(item) == '' or str( 96 | item) == 'None' for item in my_list) == False: 97 | # dont consider numeric col 98 | if all(len(str(item)) <= 50 for item in my_list) == True: 99 | # remove string column with value length > 50 100 | values_in_columns[col] = my_list 101 | else: 102 | values_in_columns[col] = [] 103 | else: 104 | values_in_columns[col] = [] 105 | 106 | except: 107 | print('error.') 108 | 109 | return values_in_columns 110 | ''' 111 | {'Team_ID': [1, 2, 3, 4], 112 | 'School_ID': [1, 2, 4, 5] 113 | } 114 | ''' 115 | 116 | def get_mentioned_values_in_NL_question(self, db_id, table_id, NL_question, db_table_col_val_map): 117 | ''' 118 | high recall: to find a set of possible columns/vables mentioned in NL_question 119 | ''' 120 | columns_list = list(db_table_col_val_map[db_id][table_id].keys()) 121 | values = db_table_col_val_map[db_id][table_id] 122 | 123 | # we now only consider 1-gram, 2-gram, and 3-gram of the NL_question. 
124 | NL_tokens = NL_question.split(' ') # 1-gram 125 | two_grams, three_grams = [], [] 126 | # 2-gram 127 | for i in range(len(NL_tokens) - 1): 128 | two_grams.append(NL_tokens[i] + ' ' + NL_tokens[i + 1]) 129 | # 3-gram 130 | for i in range(len(NL_tokens) - 2): 131 | three_grams.append(NL_tokens[i] + ' ' + NL_tokens[i + 1] + ' ' + NL_tokens[i + 2]) 132 | NL_tokens += two_grams 133 | NL_tokens += three_grams 134 | 135 | A = pd.DataFrame(data=NL_tokens, columns=['name']) 136 | A['id'] = list(range(len(A))) 137 | C = pd.DataFrame(data=columns_list, columns=['name']) 138 | C['id'] = list(range(len(C))) 139 | cand_col = ssj.edit_distance_join( 140 | A, C, 'id', 'id', 'name', 'name', 2, l_out_attrs=['name'], r_out_attrs=['name'], show_progress=False 141 | ) 142 | cand_col = cand_col.sort_values(by=['_sim_score']) 143 | cand_col = list(cand_col['r_name']) 144 | candidate_mentioned_col = [] 145 | for i in range(len(cand_col)): 146 | if cand_col[i] not in candidate_mentioned_col: 147 | candidate_mentioned_col.append(cand_col[i]) 148 | if len(candidate_mentioned_col) > 10: 149 | break 150 | 151 | B_value = [] 152 | for k, v in values.items(): 153 | for each_v in v: 154 | B_value.append([k, each_v]) 155 | B = pd.DataFrame(data=B_value, columns=['col', 'name']) 156 | B['id'] = list(range(len(B))) 157 | output_pairs = ssj.edit_distance_join( 158 | A, B, 'id', 'id', 'name', 'name', 2, l_out_attrs=['name'], r_out_attrs=['name', 'col'], show_progress=False 159 | ) 160 | output_pairs = output_pairs.sort_values(by=['_sim_score']) 161 | cand_val = list(zip(output_pairs['r_name'], output_pairs['r_col'])) 162 | candidate_mentioned_val = [] 163 | for i in range(len(cand_val)): 164 | if cand_val[i][0] not in candidate_mentioned_val: 165 | candidate_mentioned_val.append(cand_val[i][0]) 166 | if cand_val[i][1] not in candidate_mentioned_col: 167 | candidate_mentioned_col.append(cand_val[i][1]) 168 | if len(candidate_mentioned_val) > 10: 169 | break 170 | 171 | return 
candidate_mentioned_col, candidate_mentioned_val 172 | 173 | def fill_in_query_template_by_chart_template(self, query): 174 | ''' 175 | mark = {bar, pie, line, scatter} 176 | order = {by: x|y, type: desc|asc} 177 | ''' 178 | 179 | query_template = 'mark [T] data [D] encoding x [X] y aggregate [AggFunction] [Y] color [Z] transform filter [F] group [G] bin [B] sort [S] topk [K]' 180 | query_chart_template = query_template 181 | 182 | query_list = query.lower().split(' ') 183 | 184 | chart_type = query_list[query_list.index('mark') + 1] 185 | table_name = query_list[query_list.index('data') + 1] 186 | 187 | if 'sort' in query_list: 188 | # ORDER by X or BY Y? 189 | xy_axis = query_list[query_list.index('sort') + 1] 190 | order_xy = '[O]' 191 | if xy_axis == 'y': 192 | order_xy = '[Y]' 193 | elif xy_axis == 'x': 194 | order_xy = '[X]' 195 | else: 196 | order_xy = '[O]' # other 197 | 198 | if query_list.index('sort') + 2 < len(query_list): 199 | order_type = query_list[query_list.index('sort') + 2] # asc / desc 200 | query_chart_template = query_chart_template.replace('[S]', order_xy + ' ' + order_type) 201 | 202 | else: 203 | query_chart_template = query_chart_template.replace('[S]', order_xy) 204 | 205 | query_chart_template = query_chart_template.replace('[D]', table_name) 206 | 207 | query_chart_template = query_chart_template.replace('[T]', chart_type) 208 | query_template = query_template.replace('[D]', table_name) 209 | 210 | return query_template, query_chart_template 211 | 212 | def get_token_types(self, input_source): 213 | ''' 214 | get token type id (Segment ID) 215 | ''' 216 | ''' 217 | Draw a bar chart about the distribution of ACC_Road and the average of Team_ID , and group by attribute ACC_Road, and order in asc by the X-axis. 
218 | 219 | Team_ID School_ID Team_Name ACC_Regular_Season ACC_Percent ACC_Home ACC_Road All_Games All_Games_Percent All_Home All_Road All_Neutral 220 | 0 for nl 221 | 1 for template 222 | 2 for col 223 | 3 for val 224 | ''' 225 | token_types = '' 226 | 227 | for ele in re.findall('.*', input_source)[0].split(' '): 228 | token_types += ' nl' 229 | 230 | for ele in re.findall('.*', input_source)[0].split(' '): 231 | token_types += ' template' 232 | 233 | token_types += ' table table' 234 | 235 | for ele in re.findall('.*', input_source)[0].split(' '): 236 | token_types += ' col' 237 | 238 | for ele in re.findall('.*', input_source)[0].split(' '): 239 | token_types += ' value' 240 | 241 | token_types += ' table' 242 | 243 | token_types = token_types.strip() 244 | return token_types 245 | 246 | def process4training(self): 247 | # process for template 248 | for each in ['train.csv', 'dev.csv', 'test.csv']: 249 | df = pd.read_csv('./dataset/' + each) 250 | data = list() 251 | 252 | for index, row in df.iterrows(): 253 | 254 | if str(row['question']) != 'nan': 255 | 256 | new_row1 = list(row) 257 | new_row2 = list(row) 258 | 259 | query_list = row['vega_zero'].lower().split(' ') 260 | table_name = query_list[query_list.index('data') + 1] 261 | 262 | query_template, query_chart_template = self.fill_in_query_template_by_chart_template(row['vega_zero']) 263 | 264 | # get a list of mentioned values in the NL question 265 | 266 | col_names, value_names = self.get_mentioned_values_in_NL_question( 267 | row['db_id'], table_name, row['question'], db_table_col_val_map=finding_map 268 | ) 269 | col_names = ' '.join(str(e) for e in col_names) 270 | value_names = ' '.join(str(e) for e in value_names) 271 | new_row1.append(col_names) 272 | new_row1.append(value_names) 273 | new_row2.append(col_names) 274 | new_row2.append(value_names) 275 | 276 | new_row1.append(query_template) 277 | new_row2.append(query_chart_template) 278 | 279 | input_source1 = ' ' + row[ 280 | 'question'] + ' ' 
+ ' ' + query_template + ' ' + ' ' + table_name + ' ' + col_names + ' ' + ' ' + value_names + ' ' 281 | input_source1 = ' '.join(input_source1.split()) # Replace multiple spaces with single space 282 | 283 | input_source2 = ' ' + row[ 284 | 'question'] + ' ' + ' ' + query_chart_template + ' ' + ' ' + table_name + ' ' + col_names + ' ' + ' ' + value_names + ' ' 285 | input_source2 = ' '.join(input_source2.split()) # Replace multiple spaces with single space 286 | 287 | new_row1.append(input_source1) 288 | new_row1.append(row['vega_zero']) 289 | 290 | new_row2.append(input_source2) 291 | new_row2.append(row['vega_zero']) 292 | 293 | token_types1 = self.get_token_types(input_source1) 294 | token_types2 = self.get_token_types(input_source2) 295 | new_row1.append(token_types1) 296 | new_row2.append(token_types2) 297 | 298 | data.append(new_row1) 299 | data.append(new_row2) 300 | else: 301 | print('nan at ', index) 302 | 303 | if index % 500 == 0: 304 | print(round(index / len(df) * 100, 2)) 305 | 306 | df_template = pd.DataFrame(data=data, columns=list(df.columns) + ['mentioned_columns', 'mentioned_values', 307 | 'query_template', 'source', 'labels', 308 | 'token_types']) 309 | df_template.to_csv('../dataset/dataset_final/' + each, index=False) 310 | 311 | ### Part 2 312 | 313 | def extract_db_information(self): 314 | 315 | def get_values_in_columns(db_id, table_id, columns_list): 316 | ''' 317 | get values in the column 318 | ''' 319 | values_in_columns = dict() 320 | 321 | connection = sqlite3.connect(self.db_url + '/' + db_id + '/' + db_id + '.sqlite') 322 | cursor = connection.cursor() # get a cursor 323 | for col in columns_list: 324 | try: 325 | values_in_columns[col] = list( 326 | set([values[0] for values in cursor.execute("select " + col + " from " + table_id)])) 327 | except: 328 | print('error on {0}'.format(db_id)) 329 | 330 | return values_in_columns 331 | ''' 332 | {'Team_ID': [1, 2, 3, 4], 333 | 'School_ID': [1, 2, 4, 5], 334 | 'Team_Name': ['Duke', 
'Virginia Tech', 'Clemson', 'North Carolina'], 335 | } 336 | ''' 337 | 338 | with open('../dataset/db_tables_columns.json') as f: 339 | data = json.load(f) 340 | 341 | result = [] 342 | 343 | for db, tables_data in data.items(): 344 | for table, cols in tables_data.items(): 345 | try: 346 | col_val_dict = get_values_in_columns(db, table, cols) 347 | except Exception as ex: 348 | template = "An exception of database -- {0} error occurred." 349 | message = template.format(db) 350 | print(message) 351 | 352 | for c, v in col_val_dict.items(): 353 | if len(v) <= 20: 354 | for each_v in v: 355 | result.append([table, c, each_v]) 356 | else: 357 | result.append([table, c, '']) 358 | 359 | df = pd.DataFrame(data=result, columns=['table', 'column', 'value']) 360 | 361 | df.to_csv('../dataset/database_information.csv', index=False) 362 | 363 | if __name__ == "__main__": 364 | start_time = time.time() 365 | 366 | os.chdir(os.path.dirname(os.path.abspath(__file__))) 367 | 368 | print('It needs about 6 minutes for processing the benchmark datasets') 369 | time.sleep(2) 370 | 371 | DataProcesser = ProcessData4Training(db_url='../dataset/database') 372 | 373 | # build db-table-column-distinctValue dictionary 374 | print('build db-table-column-distinctValue dictionary start ... ...') 375 | finding_map = dict() 376 | 377 | db_list = os.listdir('../dataset/database/') 378 | 379 | for db in db_list: 380 | table_cols = DataProcesser.get_table_columns(db) 381 | finding_map[db] = dict() 382 | for table, cols in table_cols.items(): 383 | col_val_map = DataProcesser.get_values_in_columns(db, table, cols, conditions='remove') 384 | finding_map[db][table] = col_val_map 385 | 386 | print('build db-table-column-distinctValue dictionary end ... ...') 387 | 388 | # process the benchmark dataset for training&testing 389 | print('process the benchmark dataset for training&testing start ... 
...') 390 | DataProcesser.process4training() 391 | print('process the benchmark dataset for training&testing end ... ...') 392 | 393 | # build 'database_information.csv' 394 | print("build 'database_information.csv' start ... ...") 395 | DataProcesser.extract_db_information() 396 | print("build 'database_information.csv' end ... ...") 397 | 398 | print("\n {0} minutes for processing the dataset.".format(round((time.time()-start_time)/60,2))) -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | py_stringsimjoin 2 | python-dateutil 3 | torchtext==0.8 4 | torch==1.7 5 | vega -------------------------------------------------------------------------------- /save_models/trained_model.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HKUSTDial/ncNet/415d9477d424296bc5414e0f6624af23643372d7/save_models/trained_model.pt -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | __author__ = "Yuyu Luo" 2 | 3 | ''' 4 | This script handles the testing process. 5 | We evaluate the ncNet on the benchmark dataset. 
6 | ''' 7 | 8 | import torch 9 | import torch.nn as nn 10 | 11 | from model.VisAwareTranslation import translate_sentence_with_guidance, postprocessing, get_all_table_columns 12 | from model.Model import Seq2Seq 13 | from model.Encoder import Encoder 14 | from model.Decoder import Decoder 15 | from preprocessing.build_vocab import build_vocab 16 | 17 | import random 18 | import numpy as np 19 | import pandas as pd 20 | from tqdm import tqdm 21 | import math 22 | import matplotlib.pyplot as plt 23 | 24 | import argparse 25 | 26 | if __name__ == "__main__": 27 | parser = argparse.ArgumentParser(description='test.py') 28 | 29 | parser.add_argument('-model', required=False, default='./save_models/trained_model.pt', 30 | help='Path to model weight file') 31 | parser.add_argument('-data_dir', required=False, default='./dataset/dataset_final/', 32 | help='Path to dataset for building vocab') 33 | parser.add_argument('-db_info', required=False, default='./dataset/database_information.csv', 34 | help='Path to database tables/columns information, for building vocab') 35 | parser.add_argument('-test_data', required=False, default='./dataset/dataset_final/test.csv', 36 | help='Path to testing dataset, formatting as csv') 37 | parser.add_argument('-db_schema', required=False, default='./dataset/db_tables_columns.json', 38 | help='Path to database schema file, formatting as json') 39 | parser.add_argument('-db_tables_columns_types', required=False, default='./dataset/db_tables_columns_types.json', 40 | help='Path to database schema file, formatting as json') 41 | 42 | parser.add_argument('-batch_size', type=int, default=128) 43 | parser.add_argument('-max_input_length', type=int, default=128) 44 | parser.add_argument('-show_progress', required=False, default=False, help='True to show details during decoding') 45 | opt = parser.parse_args() 46 | print("the input parameters: ", opt) 47 | 48 | SEED = 1234 49 | 50 | random.seed(SEED) 51 | np.random.seed(SEED) 52 | 
torch.manual_seed(SEED) 53 | torch.cuda.manual_seed(SEED) 54 | torch.backends.cudnn.deterministic = True 55 | 56 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 57 | 58 | print("------------------------------\n| Build vocab start ... | \n------------------------------") 59 | SRC, TRG, TOK_TYPES, BATCH_SIZE, train_iterator, valid_iterator, test_iterator, my_max_length = build_vocab( 60 | data_dir=opt.data_dir, 61 | db_info=opt.db_info, 62 | batch_size=opt.batch_size, 63 | max_input_length=opt.max_input_length 64 | ) 65 | print("------------------------------\n| Build vocab end ... | \n------------------------------") 66 | 67 | INPUT_DIM = len(SRC.vocab) 68 | OUTPUT_DIM = len(TRG.vocab) 69 | HID_DIM = 256 # it equals to embedding dimension 70 | ENC_LAYERS = 3 71 | DEC_LAYERS = 3 72 | ENC_HEADS = 8 73 | DEC_HEADS = 8 74 | ENC_PF_DIM = 512 75 | DEC_PF_DIM = 512 76 | ENC_DROPOUT = 0.1 77 | DEC_DROPOUT = 0.1 78 | 79 | print("------------------------------\n| Build encoder of the ncNet ... | \n------------------------------") 80 | enc = Encoder(INPUT_DIM, 81 | HID_DIM, 82 | ENC_LAYERS, 83 | ENC_HEADS, 84 | ENC_PF_DIM, 85 | ENC_DROPOUT, 86 | device, 87 | TOK_TYPES, 88 | my_max_length 89 | ) 90 | 91 | print("------------------------------\n| Build decoder of the ncNet ... | \n------------------------------") 92 | dec = Decoder(OUTPUT_DIM, 93 | HID_DIM, 94 | DEC_LAYERS, 95 | DEC_HEADS, 96 | DEC_PF_DIM, 97 | DEC_DROPOUT, 98 | device, 99 | my_max_length 100 | ) 101 | 102 | SRC_PAD_IDX = SRC.vocab.stoi[SRC.pad_token] 103 | TRG_PAD_IDX = TRG.vocab.stoi[TRG.pad_token] 104 | 105 | print("------------------------------\n| Build the ncNet structure... | \n------------------------------") 106 | ncNet = Seq2Seq(enc, dec, SRC, SRC_PAD_IDX, TRG_PAD_IDX, device).to(device) # define the transformer-based ncNet 107 | 108 | print("------------------------------\n| Load the trained ncNet ... 
| \n------------------------------") 109 | ncNet.load_state_dict(torch.load(opt.model, map_location=device)) 110 | 111 | 112 | print("------------------------------\n| Testing ... | \n------------------------------") 113 | 114 | 115 | db_tables_columns = get_all_table_columns(opt.db_schema) 116 | db_tables_columns_types = get_all_table_columns(opt.db_tables_columns_types) 117 | 118 | only_nl_cnt = 0 119 | only_nl_match = 0 120 | nl_template_cnt = 0 121 | nl_template_match = 0 122 | 123 | test_df = pd.read_csv(opt.test_data) 124 | 125 | for index, row in tqdm(test_df.iterrows()): 126 | 127 | try: 128 | gold_query = row['labels'].lower() 129 | 130 | src = row['source'].lower() 131 | 132 | tok_types = row['token_types'] 133 | table_name = gold_query.split(' ')[gold_query.split(' ').index('data') + 1] 134 | translation, attention, enc_attention = translate_sentence_with_guidance( 135 | row['db_id'], table_name, src, SRC, TRG, TOK_TYPES, tok_types, SRC, 136 | ncNet, db_tables_columns, db_tables_columns_types, device, my_max_length, show_progress=opt.show_progress 137 | ) 138 | 139 | pred_query = ' '.join(translation).replace(' ', '').lower() 140 | old_pred_query = pred_query 141 | 142 | if '[t]' not in src: 143 | # with template 144 | pred_query = postprocessing(gold_query, pred_query, True, src) 145 | 146 | nl_template_cnt += 1 147 | 148 | if ' '.join(gold_query.replace('"', "'").split()) == ' '.join(pred_query.replace('"', "'").split()): 149 | nl_template_match += 1 150 | else: 151 | pass 152 | 153 | 154 | if '[t]' in src: 155 | # without template 156 | pred_query = postprocessing(gold_query, pred_query, False, src) 157 | 158 | only_nl_cnt += 1 159 | if ' '.join(gold_query.replace('"', "'").split()) == ' '.join(pred_query.replace('"', "'").split()): 160 | only_nl_match += 1 161 | 162 | else: 163 | pass 164 | 165 | except: 166 | print('error') 167 | 168 | # if index > 100: 169 | # break 170 | 171 | print("========================================================") 172 | 
print('ncNet w/o chart template:', only_nl_match / only_nl_cnt) 173 | print('ncNet with chart template:', nl_template_match / nl_template_cnt) 174 | print('ncNet overall:', (only_nl_match + nl_template_match) / (only_nl_cnt + nl_template_cnt)) 175 | -------------------------------------------------------------------------------- /test_ncNet.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Import the necessary modules" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": 16, 13 | "metadata": { 14 | "ExecuteTime": { 15 | "end_time": "2021-10-27T03:06:08.325335Z", 16 | "start_time": "2021-10-27T03:05:51.509562Z" 17 | } 18 | }, 19 | "outputs": [ 20 | { 21 | "name": "stdout", 22 | "output_type": "stream", 23 | "text": [ 24 | "Found existing installation: torch 1.7.0\n", 25 | "Uninstalling torch-1.7.0:\n", 26 | " Successfully uninstalled torch-1.7.0\n", 27 | "Collecting torch==1.7\n", 28 | " Using cached torch-1.7.0-cp38-none-macosx_10_9_x86_64.whl (108.1 MB)\n", 29 | "Requirement already satisfied: future in /Users/yuyu/anaconda3/lib/python3.8/site-packages (from torch==1.7) (0.18.2)\n", 30 | "Requirement already satisfied: numpy in /Users/yuyu/anaconda3/lib/python3.8/site-packages (from torch==1.7) (1.18.5)\n", 31 | "Requirement already satisfied: typing-extensions in /Users/yuyu/anaconda3/lib/python3.8/site-packages (from torch==1.7) (3.7.4.2)\n", 32 | "Requirement already satisfied: dataclasses in /Users/yuyu/anaconda3/lib/python3.8/site-packages (from torch==1.7) (0.6)\n", 33 | "Installing collected packages: torch\n", 34 | "Successfully installed torch-1.7.0\n" 35 | ] 36 | } 37 | ], 38 | "source": [ 39 | "! pip3 uninstall torch -y\n", 40 | "! 
pip3 install torch==1.7" 41 | ] 42 | }, 43 | { 44 | "cell_type": "code", 45 | "execution_count": 17, 46 | "metadata": { 47 | "ExecuteTime": { 48 | "end_time": "2021-10-27T03:06:11.856577Z", 49 | "start_time": "2021-10-27T03:06:08.327795Z" 50 | } 51 | }, 52 | "outputs": [ 53 | { 54 | "name": "stdout", 55 | "output_type": "stream", 56 | "text": [ 57 | "Found existing installation: torchtext 0.8.0\n", 58 | "Uninstalling torchtext-0.8.0:\n", 59 | " Successfully uninstalled torchtext-0.8.0\n", 60 | "Collecting torchtext==0.8\n", 61 | " Using cached torchtext-0.8.0-cp38-cp38-macosx_10_9_x86_64.whl (1.5 MB)\n", 62 | "Requirement already satisfied, skipping upgrade: tqdm in /Users/yuyu/anaconda3/lib/python3.8/site-packages (from torchtext==0.8) (4.47.0)\n", 63 | "Requirement already satisfied, skipping upgrade: numpy in /Users/yuyu/anaconda3/lib/python3.8/site-packages (from torchtext==0.8) (1.18.5)\n", 64 | "Requirement already satisfied, skipping upgrade: torch in /Users/yuyu/anaconda3/lib/python3.8/site-packages (from torchtext==0.8) (1.7.0)\n", 65 | "Requirement already satisfied, skipping upgrade: requests in /Users/yuyu/anaconda3/lib/python3.8/site-packages (from torchtext==0.8) (2.24.0)\n", 66 | "Requirement already satisfied, skipping upgrade: dataclasses in /Users/yuyu/anaconda3/lib/python3.8/site-packages (from torch->torchtext==0.8) (0.6)\n", 67 | "Requirement already satisfied, skipping upgrade: future in /Users/yuyu/anaconda3/lib/python3.8/site-packages (from torch->torchtext==0.8) (0.18.2)\n", 68 | "Requirement already satisfied, skipping upgrade: typing-extensions in /Users/yuyu/anaconda3/lib/python3.8/site-packages (from torch->torchtext==0.8) (3.7.4.2)\n", 69 | "Requirement already satisfied, skipping upgrade: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /Users/yuyu/anaconda3/lib/python3.8/site-packages (from requests->torchtext==0.8) (1.25.9)\n", 70 | "Requirement already satisfied, skipping upgrade: certifi>=2017.4.17 in 
/Users/yuyu/anaconda3/lib/python3.8/site-packages (from requests->torchtext==0.8) (2020.6.20)\n", 71 | "Requirement already satisfied, skipping upgrade: idna<3,>=2.5 in /Users/yuyu/anaconda3/lib/python3.8/site-packages (from requests->torchtext==0.8) (2.10)\n", 72 | "Requirement already satisfied, skipping upgrade: chardet<4,>=3.0.2 in /Users/yuyu/anaconda3/lib/python3.8/site-packages (from requests->torchtext==0.8) (3.0.4)\n", 73 | "Installing collected packages: torchtext\n", 74 | "Successfully installed torchtext-0.8.0\n" 75 | ] 76 | } 77 | ], 78 | "source": [ 79 | "!pip uninstall torchtext -y\n", 80 | "!pip install -U torchtext==0.8" 81 | ] 82 | }, 83 | { 84 | "cell_type": "code", 85 | "execution_count": 42, 86 | "metadata": { 87 | "ExecuteTime": { 88 | "end_time": "2021-10-27T03:39:27.925145Z", 89 | "start_time": "2021-10-27T03:39:27.921224Z" 90 | } 91 | }, 92 | "outputs": [], 93 | "source": [ 94 | "import torch\n", 95 | "import torch.nn as nn\n", 96 | "\n", 97 | "from model.Model import Seq2Seq\n", 98 | "from model.Encoder import Encoder\n", 99 | "from model.Decoder import Decoder\n", 100 | "from model.VisAwareTranslation import translate_sentence_with_guidance, postprocessing, get_all_table_columns\n", 101 | "from preprocessing.build_vocab import build_vocab\n", 102 | "from utilities.vis_rendering import VegaZero2VegaLite\n", 103 | "\n", 104 | "from vega import VegaLite\n", 105 | "\n", 106 | "import random\n", 107 | "import numpy as np\n", 108 | "import pandas as pd\n", 109 | "import math\n", 110 | "import sqlite3\n", 111 | "from pprint import pprint " 112 | ] 113 | }, 114 | { 115 | "cell_type": "code", 116 | "execution_count": 43, 117 | "metadata": { 118 | "ExecuteTime": { 119 | "end_time": "2021-10-27T03:39:28.557807Z", 120 | "start_time": "2021-10-27T03:39:28.551151Z" 121 | } 122 | }, 123 | "outputs": [], 124 | "source": [ 125 | "SEED = 1234\n", 126 | "\n", 127 | "random.seed(SEED)\n", 128 | "np.random.seed(SEED)\n", 129 | "torch.manual_seed(SEED)\n", 130 
| "torch.cuda.manual_seed(SEED)\n", 131 | "torch.backends.cudnn.deterministic = True" 132 | ] 133 | }, 134 | { 135 | "cell_type": "code", 136 | "execution_count": 44, 137 | "metadata": { 138 | "ExecuteTime": { 139 | "end_time": "2021-10-27T03:39:28.744923Z", 140 | "start_time": "2021-10-27T03:39:28.742876Z" 141 | } 142 | }, 143 | "outputs": [], 144 | "source": [ 145 | "# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n", 146 | "\n", 147 | "device = torch.device('cpu') # cpu or gpu? depend on your computational environment." 148 | ] 149 | }, 150 | { 151 | "cell_type": "code", 152 | "execution_count": 4, 153 | "metadata": { 154 | "ExecuteTime": { 155 | "end_time": "2021-10-27T03:12:11.982399Z", 156 | "start_time": "2021-10-27T03:12:11.973923Z" 157 | } 158 | }, 159 | "outputs": [ 160 | { 161 | "data": { 162 | "text/plain": [ 163 | "device(type='cpu')" 164 | ] 165 | }, 166 | "execution_count": 4, 167 | "metadata": {}, 168 | "output_type": "execute_result" 169 | } 170 | ], 171 | "source": [ 172 | "device" 173 | ] 174 | }, 175 | { 176 | "cell_type": "markdown", 177 | "metadata": {}, 178 | "source": [ 179 | "# Build vocab" 180 | ] 181 | }, 182 | { 183 | "cell_type": "code", 184 | "execution_count": 5, 185 | "metadata": { 186 | "ExecuteTime": { 187 | "end_time": "2021-10-27T03:12:15.602447Z", 188 | "start_time": "2021-10-27T03:12:11.984549Z" 189 | } 190 | }, 191 | "outputs": [ 192 | { 193 | "name": "stdout", 194 | "output_type": "stream", 195 | "text": [ 196 | "------------------------------\n", 197 | "| Build vocab start ... | \n", 198 | "------------------------------\n" 199 | ] 200 | }, 201 | { 202 | "name": "stderr", 203 | "output_type": "stream", 204 | "text": [ 205 | "/Users/yuyu/anaconda3/lib/python3.8/site-packages/torchtext/data/field.py:150: UserWarning: Field class will be retired soon and moved to torchtext.legacy. 
Please see the most recent release notes for further information.\n", 206 | " warnings.warn('{} class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.'.format(self.__class__.__name__), UserWarning)\n", 207 | "/Users/yuyu/anaconda3/lib/python3.8/site-packages/torchtext/data/example.py:68: UserWarning: Example class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.\n", 208 | " warnings.warn('Example class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.', UserWarning)\n", 209 | "/Users/yuyu/anaconda3/lib/python3.8/site-packages/torchtext/data/example.py:78: UserWarning: Example class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.\n", 210 | " warnings.warn('Example class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.', UserWarning)\n" 211 | ] 212 | }, 213 | { 214 | "name": "stdout", 215 | "output_type": "stream", 216 | "text": [ 217 | "------------------------------\n", 218 | "| Build vocab end ... | \n", 219 | "------------------------------\n" 220 | ] 221 | }, 222 | { 223 | "name": "stderr", 224 | "output_type": "stream", 225 | "text": [ 226 | "/Users/yuyu/anaconda3/lib/python3.8/site-packages/torchtext/data/iterator.py:48: UserWarning: BucketIterator class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.\n", 227 | " warnings.warn('{} class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.'.format(self.__class__.__name__), UserWarning)\n" 228 | ] 229 | } 230 | ], 231 | "source": [ 232 | "print(\"------------------------------\\n| Build vocab start ... 
| \\n------------------------------\")\n", 233 | "SRC, TRG, TOK_TYPES, BATCH_SIZE, train_iterator, valid_iterator, test_iterator, my_max_length = build_vocab(\n", 234 | " data_dir='./dataset/dataset_final/',\n", 235 | " db_info='./dataset/database_information.csv',\n", 236 | " batch_size=128,\n", 237 | " max_input_length=128\n", 238 | ")\n", 239 | "print(\"------------------------------\\n| Build vocab end ... | \\n------------------------------\")" 240 | ] 241 | }, 242 | { 243 | "cell_type": "markdown", 244 | "metadata": {}, 245 | "source": [ 246 | "# Construct ncNet model" 247 | ] 248 | }, 249 | { 250 | "cell_type": "code", 251 | "execution_count": 6, 252 | "metadata": { 253 | "ExecuteTime": { 254 | "end_time": "2021-10-27T03:12:15.679017Z", 255 | "start_time": "2021-10-27T03:12:15.604445Z" 256 | } 257 | }, 258 | "outputs": [], 259 | "source": [ 260 | "INPUT_DIM = len(SRC.vocab)\n", 261 | "OUTPUT_DIM = len(TRG.vocab)\n", 262 | "HID_DIM = 256 # it equals to embedding dimension \n", 263 | "ENC_LAYERS = 3 \n", 264 | "DEC_LAYERS = 3 \n", 265 | "ENC_HEADS = 8\n", 266 | "DEC_HEADS = 8\n", 267 | "ENC_PF_DIM = 512\n", 268 | "DEC_PF_DIM = 512\n", 269 | "ENC_DROPOUT = 0.1\n", 270 | "DEC_DROPOUT = 0.1\n", 271 | "\n", 272 | "enc = Encoder(INPUT_DIM,\n", 273 | " HID_DIM,\n", 274 | " ENC_LAYERS,\n", 275 | " ENC_HEADS,\n", 276 | " ENC_PF_DIM,\n", 277 | " ENC_DROPOUT,\n", 278 | " device,\n", 279 | " TOK_TYPES,\n", 280 | " my_max_length\n", 281 | " )\n", 282 | "\n", 283 | "dec = Decoder(OUTPUT_DIM,\n", 284 | " HID_DIM,\n", 285 | " DEC_LAYERS,\n", 286 | " DEC_HEADS,\n", 287 | " DEC_PF_DIM,\n", 288 | " DEC_DROPOUT,\n", 289 | " device,\n", 290 | " my_max_length\n", 291 | " )\n", 292 | "\n", 293 | "SRC_PAD_IDX = SRC.vocab.stoi[SRC.pad_token]\n", 294 | "TRG_PAD_IDX = TRG.vocab.stoi[TRG.pad_token]\n", 295 | "\n", 296 | "model = Seq2Seq(enc, dec, SRC, SRC_PAD_IDX, TRG_PAD_IDX, device).to(device) # define the transformer-based ncNet" 297 | ] 298 | }, 299 | { 300 | "cell_type": 
"markdown", 301 | "metadata": {}, 302 | "source": [ 303 | "# Load the trained ncNet model" 304 | ] 305 | }, 306 | { 307 | "cell_type": "code", 308 | "execution_count": 7, 309 | "metadata": { 310 | "ExecuteTime": { 311 | "end_time": "2021-10-27T03:12:15.741183Z", 312 | "start_time": "2021-10-27T03:12:15.681053Z" 313 | } 314 | }, 315 | "outputs": [ 316 | { 317 | "data": { 318 | "text/plain": [ 319 | "" 320 | ] 321 | }, 322 | "execution_count": 7, 323 | "metadata": {}, 324 | "output_type": "execute_result" 325 | } 326 | ], 327 | "source": [ 328 | "model.load_state_dict(torch.load('./save_models/trained_model.pt', map_location=device))\n" 329 | ] 330 | }, 331 | { 332 | "cell_type": "markdown", 333 | "metadata": {}, 334 | "source": [ 335 | "# Testing" 336 | ] 337 | }, 338 | { 339 | "cell_type": "markdown", 340 | "metadata": {}, 341 | "source": [ 342 | "## read the testing dataset" 343 | ] 344 | }, 345 | { 346 | "cell_type": "code", 347 | "execution_count": 8, 348 | "metadata": { 349 | "ExecuteTime": { 350 | "end_time": "2021-10-27T03:12:15.801061Z", 351 | "start_time": "2021-10-27T03:12:15.742537Z" 352 | } 353 | }, 354 | "outputs": [], 355 | "source": [ 356 | "db_tables_columns = get_all_table_columns('./dataset/db_tables_columns.json')\n", 357 | "db_tables_columns_types = get_all_table_columns('./dataset/db_tables_columns_types.json')\n", 358 | "\n", 359 | "test_df = pd.read_csv('./dataset/dataset_final/test.csv')\n", 360 | "\n", 361 | "# shuffle your dataframe in-place and reset the index\n", 362 | "test_df = test_df.sample(frac=1).reset_index(drop=True)" 363 | ] 364 | }, 365 | { 366 | "cell_type": "code", 367 | "execution_count": 9, 368 | "metadata": { 369 | "ExecuteTime": { 370 | "end_time": "2021-10-27T03:12:15.815438Z", 371 | "start_time": "2021-10-27T03:12:15.802714Z" 372 | } 373 | }, 374 | "outputs": [ 375 | { 376 | "data": { 377 | "text/html": [ 378 | "
\n", 379 | "\n", 392 | "\n", 393 | " \n", 394 | " \n", 395 | " \n", 396 | " \n", 397 | " \n", 398 | " \n", 399 | " \n", 400 | " \n", 401 | " \n", 402 | " \n", 403 | " \n", 404 | " \n", 405 | " \n", 406 | " \n", 407 | " \n", 408 | " \n", 409 | " \n", 410 | " \n", 411 | " \n", 412 | " \n", 413 | " \n", 414 | " \n", 415 | " \n", 416 | " \n", 417 | " \n", 418 | " \n", 419 | " \n", 420 | " \n", 421 | " \n", 422 | " \n", 423 | " \n", 424 | " \n", 425 | " \n", 426 | " \n", 427 | " \n", 428 | " \n", 429 | " \n", 430 | " \n", 431 | " \n", 432 | " \n", 433 | " \n", 434 | " \n", 435 | " \n", 436 | " \n", 437 | " \n", 438 | " \n", 439 | " \n", 440 | " \n", 441 | " \n", 442 | " \n", 443 | " \n", 444 | " \n", 445 | " \n", 446 | " \n", 447 | " \n", 448 | " \n", 449 | " \n", 450 | " \n", 451 | " \n", 452 | " \n", 453 | " \n", 454 | " \n", 455 | " \n", 456 | " \n", 457 | " \n", 458 | " \n", 459 | " \n", 460 | " \n", 461 | " \n", 462 | " \n", 463 | " \n", 464 | " \n", 465 | " \n", 466 | " \n", 467 | " \n", 468 | " \n", 469 | " \n", 470 | " \n", 471 | " \n", 472 | " \n", 473 | " \n", 474 | " \n", 475 | " \n", 476 | " \n", 477 | " \n", 478 | " \n", 479 | " \n", 480 | " \n", 481 | " \n", 482 | " \n", 483 | " \n", 484 | " \n", 485 | " \n", 486 | " \n", 487 | " \n", 488 | " \n", 489 | " \n", 490 | " \n", 491 | " \n", 492 | " \n", 493 | "
tvBench_iddb_idcharthardnessqueryquestionvega_zeromentioned_columnsmentioned_valuesquery_templatesourcelabelstoken_types
02914@y_name@DESCswimmingBarEasyVisualize BAR SELECT name , ID FROM swimmer OR...Draw a bar chart about the distribution of nam...mark bar data swimmer encoding x name y aggreg...id name timeNaNmark bar data swimmer encoding x [X] y aggrega...<N> Draw a bar chart about the distribution of...mark bar data swimmer encoding x name y aggreg...nl nl nl nl nl nl nl nl nl nl nl nl nl nl nl n...
1586college_1BarEasyVisualize BAR SELECT DEPT_CODE , sum(crs_credi...Bar chart of sum crs credit from each dept codemark bar data course encoding x dept_code y ag...dept_code crs_creditNaNmark [T] data course encoding x [X] y aggregat...<N> Bar chart of sum crs credit from each dept...mark bar data course encoding x dept_code y ag...nl nl nl nl nl nl nl nl nl nl nl nl template t...
22798@x_name@ASCsoccer_2BarMediumVisualize BAR SELECT cName , min(enr) FROM col...Return a bar graph for the name of the school ...mark bar data college encoding x cname y aggre...state cname enrNaNmark [T] data college encoding x [X] y aggrega...<N> Return a bar graph for the name of the sch...mark bar data college encoding x cname y aggre...nl nl nl nl nl nl nl nl nl nl nl nl nl nl nl n...
33051train_stationPieEasyVisualize PIE SELECT Location , sum(number_of_...Show the proportion of the total number of pla...mark arc data station encoding x location y ag...location number_of_platformsNaNmark [T] data station encoding x [X] y aggrega...<N> Show the proportion of the total number of...mark arc data station encoding x location y ag...nl nl nl nl nl nl nl nl nl nl nl nl nl nl nl n...
473apartment_rentalsPieEasyVisualize PIE SELECT booking_status_code , COU...How many bookings does each booking status hav...mark arc data apartment_bookings encoding x bo...booking_status_codeNaNmark [T] data apartment_bookings encoding x [X...<N> How many bookings does each booking status...mark arc data apartment_bookings encoding x bo...nl nl nl nl nl nl nl nl nl nl nl nl nl nl nl n...
\n", 494 | "
" 495 | ], 496 | "text/plain": [ 497 | " tvBench_id db_id chart hardness \\\n", 498 | "0 2914@y_name@DESC swimming Bar Easy \n", 499 | "1 586 college_1 Bar Easy \n", 500 | "2 2798@x_name@ASC soccer_2 Bar Medium \n", 501 | "3 3051 train_station Pie Easy \n", 502 | "4 73 apartment_rentals Pie Easy \n", 503 | "\n", 504 | " query \\\n", 505 | "0 Visualize BAR SELECT name , ID FROM swimmer OR... \n", 506 | "1 Visualize BAR SELECT DEPT_CODE , sum(crs_credi... \n", 507 | "2 Visualize BAR SELECT cName , min(enr) FROM col... \n", 508 | "3 Visualize PIE SELECT Location , sum(number_of_... \n", 509 | "4 Visualize PIE SELECT booking_status_code , COU... \n", 510 | "\n", 511 | " question \\\n", 512 | "0 Draw a bar chart about the distribution of nam... \n", 513 | "1 Bar chart of sum crs credit from each dept code \n", 514 | "2 Return a bar graph for the name of the school ... \n", 515 | "3 Show the proportion of the total number of pla... \n", 516 | "4 How many bookings does each booking status hav... \n", 517 | "\n", 518 | " vega_zero \\\n", 519 | "0 mark bar data swimmer encoding x name y aggreg... \n", 520 | "1 mark bar data course encoding x dept_code y ag... \n", 521 | "2 mark bar data college encoding x cname y aggre... \n", 522 | "3 mark arc data station encoding x location y ag... \n", 523 | "4 mark arc data apartment_bookings encoding x bo... \n", 524 | "\n", 525 | " mentioned_columns mentioned_values \\\n", 526 | "0 id name time NaN \n", 527 | "1 dept_code crs_credit NaN \n", 528 | "2 state cname enr NaN \n", 529 | "3 location number_of_platforms NaN \n", 530 | "4 booking_status_code NaN \n", 531 | "\n", 532 | " query_template \\\n", 533 | "0 mark bar data swimmer encoding x [X] y aggrega... \n", 534 | "1 mark [T] data course encoding x [X] y aggregat... \n", 535 | "2 mark [T] data college encoding x [X] y aggrega... \n", 536 | "3 mark [T] data station encoding x [X] y aggrega... \n", 537 | "4 mark [T] data apartment_bookings encoding x [X... 
\n", 538 | "\n", 539 | " source \\\n", 540 | "0 Draw a bar chart about the distribution of... \n", 541 | "1 Bar chart of sum crs credit from each dept... \n", 542 | "2 Return a bar graph for the name of the sch... \n", 543 | "3 Show the proportion of the total number of... \n", 544 | "4 How many bookings does each booking status... \n", 545 | "\n", 546 | " labels \\\n", 547 | "0 mark bar data swimmer encoding x name y aggreg... \n", 548 | "1 mark bar data course encoding x dept_code y ag... \n", 549 | "2 mark bar data college encoding x cname y aggre... \n", 550 | "3 mark arc data station encoding x location y ag... \n", 551 | "4 mark arc data apartment_bookings encoding x bo... \n", 552 | "\n", 553 | " token_types \n", 554 | "0 nl nl nl nl nl nl nl nl nl nl nl nl nl nl nl n... \n", 555 | "1 nl nl nl nl nl nl nl nl nl nl nl nl template t... \n", 556 | "2 nl nl nl nl nl nl nl nl nl nl nl nl nl nl nl n... \n", 557 | "3 nl nl nl nl nl nl nl nl nl nl nl nl nl nl nl n... \n", 558 | "4 nl nl nl nl nl nl nl nl nl nl nl nl nl nl nl n... 
" 559 | ] 560 | }, 561 | "execution_count": 9, 562 | "metadata": {}, 563 | "output_type": "execute_result" 564 | } 565 | ], 566 | "source": [ 567 | "test_df.head()" 568 | ] 569 | }, 570 | { 571 | "cell_type": "markdown", 572 | "metadata": {}, 573 | "source": [ 574 | "## testing and rendering the result using vega-lite" 575 | ] 576 | }, 577 | { 578 | "cell_type": "code", 579 | "execution_count": 10, 580 | "metadata": { 581 | "ExecuteTime": { 582 | "end_time": "2021-10-27T03:12:15.822085Z", 583 | "start_time": "2021-10-27T03:12:15.818726Z" 584 | } 585 | }, 586 | "outputs": [ 587 | { 588 | "name": "stdout", 589 | "output_type": "stream", 590 | "text": [ 591 | "cpu\n" 592 | ] 593 | } 594 | ], 595 | "source": [ 596 | "print(device)" 597 | ] 598 | }, 599 | { 600 | "cell_type": "code", 601 | "execution_count": 50, 602 | "metadata": { 603 | "ExecuteTime": { 604 | "end_time": "2021-10-27T03:40:38.183696Z", 605 | "start_time": "2021-10-27T03:40:37.841123Z" 606 | } 607 | }, 608 | "outputs": [ 609 | { 610 | "name": "stdout", 611 | "output_type": "stream", 612 | "text": [ 613 | "=========================================================\n", 614 | "\n", 615 | "[Database]:\n", 616 | " swimming\n", 617 | "[NL Question]:\n", 618 | " Draw a bar chart about the distribution of name and ID , I want to display from high to low by the Y-axis .\n", 619 | "[Chart Template]:\n", 620 | " mark bar data swimmer encoding x [X] y aggregate [AggFunction] [Y] color [Z] transform filter [F] group [G] bin [B] sort [Y] desc topk [K]\n", 621 | "[Predicted VIS Query]:\n", 622 | " mark bar data swimmer encoding x name y aggregate none id transform sort y desc\n", 623 | "[The Ground Truth VIS Query]:\n", 624 | " mark bar data swimmer encoding x name y aggregate none id transform sort y desc\n", 625 | "The Predicted VIS:\n", 626 | "\n", 627 | "{\"mark\": \"bar\", \"encoding\": {\"x\": {\"field\": \"name\", \"type\": \"nominal\", \"sort\": \"-y\"}, \"y\": {\"field\": \"id\", \"type\": \"quantitative\"}}, 
\"data\": {\"values\": [{\"id\": 7, \"name\": \"Przemys\\u0142aw Sta\\u0144czyk\", \"nationality\": \"Poland\", \"meter_100\": 57.31, \"meter_200\": \"1:57.10\", \"meter_300\": \"2:56.02\", \"meter_400\": \"3:55.36\", \"meter_500\": \"4:54.21\", \"meter_600\": \"5:52.59\", \"meter_700\": \"6:50.91\", \"time\": \"7:47.91\"}, {\"id\": 4, \"name\": \"Craig Stevens\", \"nationality\": \"Australia\", \"meter_100\": 57.35, \"meter_200\": \"1:56.34\", \"meter_300\": \"2:55.90\", \"meter_400\": \"3:55.72\", \"meter_500\": \"4:55.08\", \"meter_600\": \"5:54.45\", \"meter_700\": \"6:52.69\", \"time\": \"7:48.67\"}, {\"id\": 5, \"name\": \"Federico Colbertaldo\", \"nationality\": \"Italy\", \"meter_100\": 57.66, \"meter_200\": \"1:56.77\", \"meter_300\": \"2:56.04\", \"meter_400\": \"3:55.37\", \"meter_500\": \"4:54.48\", \"meter_600\": \"5:53.53\", \"meter_700\": \"6:52.58\", \"time\": \"7:49.98\"}, {\"id\": 8, \"name\": \"S\\u00e9bastien Rouault\", \"nationality\": \"France\", \"meter_100\": 55.67, \"meter_200\": \"1:54.40\", \"meter_300\": \"2:53.46\", \"meter_400\": \"3:52.93\", \"meter_500\": \"4:52.85\", \"meter_600\": \"5:53.03\", \"meter_700\": \"6:53.34\", \"time\": \"7:52.04\"}, {\"id\": 1, \"name\": \"Sergiy Fesenko\", \"nationality\": \"Ukraine\", \"meter_100\": 57.34, \"meter_200\": \"1:57.26\", \"meter_300\": \"2:57.10\", \"meter_400\": \"3:57.12\", \"meter_500\": \"4:57.03\", \"meter_600\": \"5:56.31\", \"meter_700\": \"6:55.07\", \"time\": \"7:53.43\"}, {\"id\": 2, \"name\": \"Grant Hackett\", \"nationality\": \"Australia\", \"meter_100\": 57.34, \"meter_200\": \"1:57.21\", \"meter_300\": \"2:56.95\", \"meter_400\": \"3:57.00\", \"meter_500\": \"4:56.96\", \"meter_600\": \"5:57.10\", \"meter_700\": \"6:57.44\", \"time\": \"7:55.39\"}, {\"id\": 6, \"name\": \"Ryan Cochrane\", \"nationality\": \"Canada\", \"meter_100\": 57.84, \"meter_200\": \"1:57.26\", \"meter_300\": \"2:56.64\", \"meter_400\": \"3:56.34\", \"meter_500\": \"4:56.15\", \"meter_600\": 
\"5:56.99\", \"meter_700\": \"6:57.69\", \"time\": \"7:56.56\"}, {\"id\": 3, \"name\": \"Oussama Mellouli\", \"nationality\": \"Tunisia\", \"meter_100\": 57.31, \"meter_200\": \"1:56.44\", \"meter_300\": \"2:55.94\", \"meter_400\": \"3:55.49\", \"meter_500\": \"4:54.19\", \"meter_600\": \"5:52.92\", \"meter_700\": \"6:50.80\", \"time\": \"7:46.95\"}]}}\n", 628 | "\n", 629 | "\n", 630 | "=========================================================\n", 631 | "\n", 632 | "[Database]:\n", 633 | " college_1\n", 634 | "[NL Question]:\n", 635 | " Bar chart of sum crs credit from each dept code\n", 636 | "[Chart Template]:\n", 637 | " mark [T] data course encoding x [X] y aggregate [AggFunction] [Y] color [Z] transform filter [F] group [G] bin [B] sort [S] topk [K]\n", 638 | "[Predicted VIS Query]:\n", 639 | " mark bar data course encoding x dept_code y aggregate sum crs_credit transform group x\n", 640 | "[The Ground Truth VIS Query]:\n", 641 | " mark bar data course encoding x dept_code y aggregate sum crs_credit transform group x\n", 642 | "The Predicted VIS:\n", 643 | "\n", 644 | "{\"mark\": \"bar\", \"encoding\": {\"x\": {\"field\": \"dept_code\", \"type\": \"nominal\"}, \"y\": {\"field\": \"crs_credit\", \"type\": \"quantitative\", \"aggregate\": \"sum\"}}, \"data\": {\"values\": [{\"crs_code\": \"ACCT-211\", \"dept_code\": \"ACCT\", \"crs_description\": \"Accounting I\", \"crs_credit\": 3.0}, {\"crs_code\": \"ACCT-212\", \"dept_code\": \"ACCT\", \"crs_description\": \"Accounting II\", \"crs_credit\": 3.0}, {\"crs_code\": \"CIS-220\", \"dept_code\": \"CIS\", \"crs_description\": \"Intro. to Microcomputing\", \"crs_credit\": 3.0}, {\"crs_code\": \"CIS-420\", \"dept_code\": \"CIS\", \"crs_description\": \"Database Design and Implementation\", \"crs_credit\": 4.0}, {\"crs_code\": \"QM-261\", \"dept_code\": \"CIS\", \"crs_description\": \"Intro. 
to Statistics\", \"crs_credit\": 3.0}, {\"crs_code\": \"QM-362\", \"dept_code\": \"CIS\", \"crs_description\": \"Statistical Applications\", \"crs_credit\": 4.0}]}}\n", 645 | "\n", 646 | "\n", 647 | "=========================================================\n", 648 | "\n", 649 | "[Database]:\n", 650 | " soccer_2\n", 651 | "[NL Question]:\n", 652 | " Return a bar graph for the name of the school that has the smallest enrollment in each state , could you order by the x axis in asc please ?\n", 653 | "[Chart Template]:\n", 654 | " mark [T] data college encoding x [X] y aggregate [AggFunction] [Y] color [Z] transform filter [F] group [G] bin [B] sort [S] topk [K]\n", 655 | "[Predicted VIS Query]:\n", 656 | " mark bar data college encoding x cname y aggregate min enr transform group state sort x asc\n", 657 | "[The Ground Truth VIS Query]:\n", 658 | " mark bar data college encoding x cname y aggregate min enr transform group state sort x asc\n", 659 | "The Predicted VIS:\n", 660 | "\n", 661 | "{\"mark\": \"bar\", \"encoding\": {\"x\": {\"field\": \"cname\", \"type\": \"nominal\"}, \"y\": {\"field\": \"enr\", \"type\": \"quantitative\", \"aggregate\": \"min\", \"sort\": \"x\"}}, \"data\": {\"values\": [{\"cname\": \"LSU\", \"state\": \"LA\", \"enr\": 18000}, {\"cname\": \"ASU\", \"state\": \"AZ\", \"enr\": 12000}, {\"cname\": \"OU\", \"state\": \"OK\", \"enr\": 22000}, {\"cname\": \"FSU\", \"state\": \"FL\", \"enr\": 19000}]}}\n", 662 | "\n", 663 | "\n", 664 | "========================================================\n", 665 | "ncNet w/o chart template: 1.0\n", 666 | "ncNet with chart template: 1.0\n", 667 | "ncNet overall: 1.0\n" 668 | ] 669 | } 670 | ], 671 | "source": [ 672 | "only_nl_cnt = 0\n", 673 | "only_nl_match = 0\n", 674 | "\n", 675 | "nl_template_cnt = 0\n", 676 | "nl_template_match = 0\n", 677 | "\n", 678 | "query2vl = VegaZero2VegaLite()\n", 679 | "\n", 680 | "for index, row in test_df.iterrows():\n", 681 | " \n", 682 | " gold_query = 
row['labels'].lower()\n", 683 | " src = row['source'].lower()\n", 684 | " \n", 685 | "\n", 686 | " tok_types = row['token_types']\n", 687 | " table_name = gold_query.split(' ')[gold_query.split(' ').index('data')+1]\n", 688 | " \n", 689 | " translation, attention, enc_attention = translate_sentence_with_guidance(\n", 690 | " row['db_id'], table_name, src, SRC, TRG, TOK_TYPES, tok_types, SRC, \n", 691 | " model, db_tables_columns, db_tables_columns_types, device, my_max_length\n", 692 | " )\n", 693 | "\n", 694 | " pred_query = ' '.join(translation).replace(' ', '').lower()\n", 695 | " old_pred_query = pred_query\n", 696 | "\n", 697 | " if '[t]' not in src:\n", 698 | " # with template\n", 699 | "\n", 700 | " nl_template_cnt += 1\n", 701 | "\n", 702 | " if ' '.join(gold_query.replace('\"', \"'\").split()) == ' '.join(pred_query.replace('\"', \"'\").split()):\n", 703 | " nl_template_match += 1\n", 704 | " \n", 705 | " predict_query = ' '.join(pred_query.replace('\"', \"'\").split()) \n", 706 | " print('=========================================================\\n')\n", 707 | " print('[Database]:\\n', row['db_id'])\n", 708 | " print('[NL Question]:\\n', row['question'])\n", 709 | " print('[Chart Template]:\\n', row['query_template'])\n", 710 | " print('[Predicted VIS Query]:\\n', predict_query)\n", 711 | " print('[The Ground Truth VIS Query]:\\n', gold_query)\n", 712 | " ############ Query the VIS Result ############\n", 713 | " print('The Predicted VIS:')\n", 714 | " cnx = sqlite3.connect('./dataset/database/'+row['db_id']+'/'+row['db_id']+'.sqlite')\n", 715 | " table4vis = pd.read_sql_query(\"SELECT * FROM \" + table_name, cnx)\n", 716 | " table4vis.columns = table4vis.columns.str.lower() # to lowercase\n", 717 | " pprint(VegaLite(query2vl.to_VegaLite(predict_query, table4vis)))\n", 718 | " print(json.dumps(query2vl.to_VegaLite(predict_query, table4vis))) # print the vega-lite spec\n", 719 | " print('\\n')\n", 720 | " \n", 721 | " else:\n", 722 | " pass\n", 723 | "\n", 
724 | " if '[t]' in src:\n", 725 | " # without template\n", 726 | "\n", 727 | " only_nl_cnt += 1\n", 728 | " if ' '.join(gold_query.replace('\"', \"'\").split()) == ' '.join(pred_query.replace('\"', \"'\").split()): \n", 729 | " only_nl_match += 1\n", 730 | " \n", 731 | " predict_query = ' '.join(pred_query.replace('\"', \"'\").split()) \n", 732 | " print('=========================================================\\n')\n", 733 | " print('[Database]:\\n', row['db_id'])\n", 734 | " print('[NL Question]:\\n', row['question'])\n", 735 | " print('[Chart Template]:\\n', row['query_template'])\n", 736 | " print('[Predicted VIS Query]:\\n', predict_query)\n", 737 | " print('[The Ground Truth VIS Query]:\\n', gold_query)\n", 738 | " ############ Query the VIS Result ############\n", 739 | " print('The Predicted VIS:')\n", 740 | " cnx = sqlite3.connect('./dataset/database/'+row['db_id']+'/'+row['db_id']+'.sqlite')\n", 741 | " table4vis = pd.read_sql_query(\"SELECT * FROM \" + table_name, cnx)\n", 742 | " table4vis.columns = table4vis.columns.str.lower() # to lowercase\n", 743 | " print(VegaLite(query2vl.to_VegaLite(predict_query, table4vis)))\n", 744 | " print(json.dumps(query2vl.to_VegaLite(predict_query, table4vis))) # print the vega-lite spec\n", 745 | " print('\\n')\n", 746 | " \n", 747 | " else:\n", 748 | " pass\n", 749 | " \n", 750 | " \n", 751 | " #show top-X testing cases.\n", 752 | " if index > 1:\n", 753 | " break\n", 754 | "\n", 755 | "\n", 756 | "print(\"========================================================\")\n", 757 | "print('ncNet w/o chart template:', only_nl_match/only_nl_cnt)\n", 758 | "print('ncNet with chart template:', nl_template_match/nl_template_cnt)\n", 759 | "print('ncNet overall:',(only_nl_match+nl_template_match) / (only_nl_cnt+nl_template_cnt))\n", 760 | "\n" 761 | ] 762 | }, 763 | { 764 | "cell_type": "code", 765 | "execution_count": 51, 766 | "metadata": { 767 | "ExecuteTime": { 768 | "end_time": "2021-10-27T03:40:48.934363Z", 769 | 
"start_time": "2021-10-27T03:40:48.930755Z" 770 | } 771 | }, 772 | "outputs": [ 773 | { 774 | "data": { 775 | "application/javascript": [ 776 | "const spec = {\"mark\": \"bar\", \"encoding\": {\"x\": {\"field\": \"cname\", \"type\": \"nominal\"}, \"y\": {\"field\": \"enr\", \"type\": \"quantitative\", \"aggregate\": \"min\", \"sort\": \"x\"}}, \"data\": {\"values\": [{\"cname\": \"LSU\", \"state\": \"LA\", \"enr\": 18000}, {\"cname\": \"ASU\", \"state\": \"AZ\", \"enr\": 12000}, {\"cname\": \"OU\", \"state\": \"OK\", \"enr\": 22000}, {\"cname\": \"FSU\", \"state\": \"FL\", \"enr\": 19000}]}};\n", 777 | "const opt = {};\n", 778 | "const type = \"vega-lite\";\n", 779 | "const id = \"a8a997d3-53d7-474e-8ed4-3edf93c68558\";\n", 780 | "\n", 781 | "const output_area = this;\n", 782 | "\n", 783 | "require([\"nbextensions/jupyter-vega/index\"], function(vega) {\n", 784 | " const target = document.createElement(\"div\");\n", 785 | " target.id = id;\n", 786 | " target.className = \"vega-embed\";\n", 787 | "\n", 788 | " const style = document.createElement(\"style\");\n", 789 | " style.textContent = [\n", 790 | " \".vega-embed .error p {\",\n", 791 | " \" color: firebrick;\",\n", 792 | " \" font-size: 14px;\",\n", 793 | " \"}\",\n", 794 | " ].join(\"\\\\n\");\n", 795 | "\n", 796 | " // element is a jQuery wrapped DOM element inside the output area\n", 797 | " // see http://ipython.readthedocs.io/en/stable/api/generated/\\\n", 798 | " // IPython.display.html#IPython.display.Javascript.__init__\n", 799 | " element[0].appendChild(target);\n", 800 | " element[0].appendChild(style);\n", 801 | "\n", 802 | " vega.render(\"#\" + id, spec, type, opt, output_area);\n", 803 | "}, function (err) {\n", 804 | " if (err.requireType !== \"scripterror\") {\n", 805 | " throw(err);\n", 806 | " }\n", 807 | "});\n" 808 | ], 809 | "text/plain": [ 810 | "" 811 | ] 812 | }, 813 | "execution_count": 51, 814 | "metadata": { 815 | "jupyter-vega": "#a8a997d3-53d7-474e-8ed4-3edf93c68558" 816 | }, 817 | 
"output_type": "execute_result" 818 | }, 819 | { 820 | "data": { 821 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAJEAAAD9CAYAAAC87pUCAAAAAXNSR0IArs4c6QAAFfFJREFUeF7tnQm0jdX7x59LySyztLiZVlwkFX8ZMoRMUeJimaeVXEIyNiwSVsvqliySKFG4puSapcEqKSFlXEuDocxkzJD81/dxz/2de9zDffd+7znnPb57rbuue7x7v/v97s959n73fvZ+Yq5du3ZNmKiAhQIxhMhCPWZVBQgRQbBWgBBZS8gCCBEZsFbgtoNo2bJlEhcXZy2clwoY8P4mV6o7sXc1Lad06dJpyrvtIHrjjTdk2LBhrojqlUKeHJHkSlWTx7eT3377jRARInOeCFGKdoSIEJkrQIistaMlIkSEyFoBQmQtIS0RISJE1goQImsJaYkIESGyVoAQWUtIS0SICJG1AoTIWkJaIkLkTYjOnj0rOXLkkDvuuCPNAxw/flzuvvvuGz73XQSP3SNHjkixYsXS5HP6eXqqcdnDnKWQWqKrV6/KwIEDZf369fLvv//K4MGDpUePHnL58mVp3769HDx4UA4dOiSzZs2S+vXrp3mqP//8Uxo3biwFCxaUEydOyJo1a+Tee+8Vp58Hk4oQeQSib775Rt5++21JSkqSS5cuSeXKlWX37t2yfPlySU5OlhkzZsiePXskPj5etm3bluapXnjhBalatap07txZZs+eLVu3bpXExERx+jkh+p8CnnQFmTt3rhQuXFgaNmwoFy9elHLlysn27dtl/PjxUqdOHWnevLk+YfXq1dXSoGvzpYcffljWrVunn50+fVoaNWokP/zwgzj9nBB5HCJf9f/55x+1IIUKFZIxY8ZoNzVx4kSpUKGCXgJr06dPH6lVq5b+ja4PXRfGQ75UtGhR2bdvn8TGxmb4c3R9GIfBIn777bc38NS2bVtzm+7BnJ71bER31a1bN+nZs6f+xMTESP/+/eWZZ56RevXqaVNgPDRnzhy55557Upvm/vvv1y4ue/bsasVggXbs2CFOP6cl8rglwqAZoCxZsiTV6uCRFixYIFu2bNFu7fDhw9qtbd68WU6dOqVWCF1gQkKCtGjRQpo2bSorV64U+ERPnjzZ8eeEyOMQjR07VvAWVKpUqdQnWbx4sXZVeDs7efKkdk2TJk3SLm706NE66MbYB11Xy5YttQv8+++/BfnQlTn9nBB5HKJbDRtgefLkyZM6T7R3714dKwEqX8JcEkAKTE4/D8zPV/xbtU7w/w/pPJHTai5dulQH15gbyuxEiMwVjmiIzB/LeU5C5FwzXw5ClKIEISJE5goQImvtaIkyEaJXP/jauoFQwGs96rpSTmAhnlz2yBQlXCo0M7ozNxvJpcdMU4yb9eM2ahGdv3J7L76bjUSIMkMBl8skROaCckyUiWMiWqLb7OBPWiJaInMFaImstWN3RogIkbUChMhaQloiD0GUOP976wZHAS/E/5+W4+bAn/NEHpkncrPRCZEr38e0hXjh7YwQZULDu1kkITJXk2MiD42JaInMQQ9JTloic5lpiWiJzOlJyUmICBEhslaAEFlLSEtEiAiRtQKEyFpCWiJCRIisFSBE1hLSEhEiQmStACGylpCWiBARImsFCJG1hLREhIgQWStAiKwlpCUiRITIWgFCZC0hLREhIkTWChAiawnDaomwUxvnWCMF7tr2fe7/hE4DwQS7Pj3V6NlozlJYILpw4YK88sorsmvXLlmxYoWGYGjTpo2UKFEi9UkQrsEfJKeBYIJdH0wqQuQxiB577DEpX7687N+/X1atWqVRhc6cOSP9+vUL+iROA8EEu54QicByILnp+B/yzYs4zByWAqGqABGsAM6vxg9CMeDU/cBYaE4DwQS7nhBFCURoSMTl8EEEq4HQU/iN8Axr167VIC6+5HaAGI6JosASBUIESLJmzZo6Bqp
du7ZMmzZN4uLiUtvbaSCYYNejwFBFGXI7ik+kl1e6dOk038+Ya5l8yJW/JUIMj0qVKmmkIQy6EUxv48aN2qV5OUCMm2MOt8cwbpcX8jGRzxINGTJE386OHj0qnTp1kitXrsj58+e1m2vXrp3nA8QQoky2ROmNSzDg9o+26PUAMYQoDBAFguX1ADGEKAIgMp/+cp6Tk43ONfPlCMuMtXl1My8nITLXlhClaEeICJG5AoTIWjsrSzR8+HCd38HrudcTLZF5C1pBhIXU++67TxdQvZ4IkXkLWkE0dOhQmTBhgoYZr1GjRmotYKGyZctmXqsw5CRE5qJbQdS3b1/5+usbA8MhBHmuXLnMaxWGnITIXHQriMxvG3k5CZF5m1hBtG/fPpkyZYr6BF28eDG1Fps3b5bcuXOb1yoMOQmRuehWEKE7W7RokXTs2FEKFCiQWguMlTgmctdzEOJG8jKK8Sp+xYoVBSvx3bp1M8c4QnLSEpk3hJUlGjlypOzcuVPeeecdKVy4cGotcuTIYV6jMOUkRObCW0HUuHFjdWUNTKdPn5a8efOa1yoMOQmRuehWEG3YsEGOHTuW5u7Y5tOsWbMbHO3NqxianITIXGcriHDbTZs26b6x4sWL6xta165dJWfOnOY1ClNOQmQuvBVEeLXHbHWRIkWke/fu2rWVKVNG5s+fb16jMOUkRObCW0FUvXp16dy5s/pFw7U1ISFBSpYsKadOnUrj5mpevdDlJETmWltBVKVKFd21euLECYXoueee0wXZQ4cOSbFixcxrFYachMhcdCuIxo4dKy+//LJgv1H27Nl1K3TNmjUlKSnJvEZhykmIzIW3ggh7wpYsWSJz5syRgwcPSpcuXXSrj/+ckXnVQpuTEJnrbQWR+W0jLycgejZhgCsVuzt3di0nkpcp3K6f8bKHK4pHSCGA6Ju/73OlNplx6obbje52eYQoJaQ5ITL7Dhl1ZzjBo2zZsnL27FnJly+f5xzQ0pOKlsgMIOQygqho0aKyYMECGTdunLqBYK7I64kQmbegEUT169eXkydPyu+//y7lypXTWWr/NHPmTM8tfRCiEEOE8xRxfhDmg2JjYxUk/zR9+nTxmjsIIQoxRL7bxcfHS+vWraV9+/bmNYiQnITIvCGMujP/223btk3mzp0rBw4ckA4dOgh8jLzmGovnIURhgujLL7+UBg0aSLVq1XQlf/ny5dKqVSudxfZaIkTmLWZliVq2bCmlSpWSiRMnag3WrVsnDRs2lL/++ktPgfVSIkTmrWUFERz1e/XqJYMGDdIa4GQzDLKxlQguIV5KhMi8tawgwnZp7DsbOHCgFCxYUPfk41R8dmfuH/GLJo7ktTjjZQ+c9IotQwAJB5n37NlTnn/+ee3ivJZoicxbzMoS+W6Lk/muXr3qOed8f9kIUZghMr29f5QhlHH8+HF1qw0MyeAP65EjR27wmnQafSi9+hIi01Y0XDszv931nIFRhi5fvqwTlnBsg2stxlZYWvFPoYgyxFV8s5Z1pTtzeuvAKEMYiCcnJ8uMGTNkz549gplwTGL6p1BEGSJETlvy+vWuQITdHZcuXUqtAVb50wt657sgMMoQ3vLq1KkjzZs310uwiwR72fwPRg9FlCFCFAaIEGhlwIABGhnIP2VkG7V/bA8slWDCskKFCloMXEv69OkjtWrV0r/djjIULECMWxBN7F1N6x3pAV3crp9RgBgsviIuB17z8+fPn8oRGh9Rg26W/CHq37+/BodBnDMkjIfg/O8/6+1mlKH06sWBtZkVsu7O0MVgrALHNKfJHyI4uMGajR8/Xg4fPqzdGg7KQjcZyihDblki+lhfpyFDoapGjBghixcvllGjRukCrC/VrVv3lnNGgMgXZQh7+PF2Bkc3vMJPmjRJvQEQwgqLujgDEkspWKsrVKiQbpTEfeHL5PTzYLDTEjk1A/+73mpg3aRJE1m9evUNd8/ImCi9KsPyYObbN08U6ihDtERmIFlBBFgQo8w
/4a0M62hupFBHGSJEZq1mBBHGQRj8njt3Tv2sAyFC0DuvOaaxOzMDyGhg7XvdxjlEGMPAMS0wYaLQi6fH0hKZgWRkibBM4TVLcyt5aIlupVDw/zeCCG9JsELBEmabvXZaGiEKMURwPMNiKRLmdAK9GBMTE/WoGS8lQmTeWkaWCG9lH3/8sQAWeLRhmQIHo/sHiTGvUnhyEiJz3Y0g8t0Ojmg4txGTg5gvwq4PrKVh61CWLFnMaxWGnITIXHQriPxvO3XqVD1uD8l0stH8MexzEiJzDa0ggvvHwoULtVvD2lejRo3UxxrjpJu5gphXN/NyEiJzbY0gwiIp9tujG8MqPg7/hBWKi4szr0mYcxIi8wYwgghOZ4AHCdYncJkDHop8xY/sLT5oOze3IDneMoQDPrGSHizNmzePELncSG43utvlOYbI3PBFbk52Z+ZtY9Sdmd8ucnMSIvO2IUQp2hEiQmSuACGy1o6WiBARImsFCJG1hLREhIgQWStAiKwlpCUiRITIWgFCZC0hLREhIkTWChAiawlpiQgRIbJWgBBZS0hLRIgIkbUChMhaQloiQkSIrBUgRNYS0hIRIkJkrQAhspaQlogQESJrBQiRtYS0RIQoeiBCkBf/lN42bKeBYIJdn55qdNQ3ZykiLBEOxWrTpo0G3PMlhE33B4kBYtzdsQqdw7oD1pzZ9HMiqtCZM2d0T3+wxAAx7jZ61EGErgTnV+MHoRgQniEw5hkDxBCimxovWJmtW7dqiAccUbN27VpBEBdfcjtADMdE7seoDftefECCgDK+MVDt2rVl2rRpaY6qcTNADKMMZU4UJKMoQ26NjRDDo1KlShppCFEZK1euLBs3btQujQFirgelc3sM43Z5YbdEOOuoU6dOGuLh/PnzghP527VrxwAxKd9SQuTAXOHMI/9oiwwQc108QuQAosBLGSCGEFngE/qsnLE21zwiZqzNq+9eTkJkriUhStGOEBEicwUIkbV2tESEiBBZK0CIrCWkJSJEhMhaAUJkLSEtESEiRNYKECJrCWmJCBEhslaAEFlLSEtEiAiRtQKEyFpCWiJCRIisFSBE1hLSEhEiQmStACGylpCWiBARImsFCJG1hLREhIgQWStAiKwlpCUiRITIWgFCZC0hLREhIkTWChAiawlpiQgRIbJWgBBZS0hLRIgIkbUChMhaQloiQkSIrBUgRNYS0hIRIkJkrQAhspaQlogQESJrBQiRtYRRbYkYZSh9PjLjNNqwn2Nt/VVIp4BgUYmC3YvH7Zm3QtRaomBRiQhR5pyLHZWWKFhUIkJEiDJkc4NFJUIXh3gh6QWIyZUrl4aEYDJToGLFitKiRYs0mWOuBcbTNCs7bLmCRSWyrRDGTsOGDbMtJjV/NJfneYgSEhL0m9G0aVNZuXKlLFu2TCZPnmzd+NHc6OmJY/O8nodo37590rJlSylUqJAg8MzixYslNjaWEDlU4LaGyKfV8ePHFSS3ko2obn/TI708z1sit6AJLAeDckSGdCtFc3mEyC1KbuNyCNFt3PhuPTohEpFff/1VunbtmkbTIkWKSOvWreXJJ5+UfPnyOdL7diuPEIloEOP9+/enAeXYsWPy6aefCn7PmDHDEUS3W3mE6CZ4YB72wQcflO+++05y5szpCKT0Lo7W8gjRLdDAGxrmntC9uZGisTxCFISMS5cuyZo1a3T2e9WqVdb8RHN5hEhEdu/eLXXr1r0BFAyssaxSqVIlRxDdbuURIhHBWOXKlSs3gILP77rrLkcA4WIvlffff//JqVOntLvOmjWr42dFBkKUItu6desEDm6bN2+WHTt2SJ06daRw4cIyduxYad++vWNxI728w4cPy9ChQ2X27NkK0NGjR2XQoEHy2muvSe7cuR09LyESUSuELuujjz6SGjVqSLNmzeTZZ5+VmjVrSsOGDWXbtm2ORI308mB98Jxt27bVL0iJEiVkz549Mm3aNDl58qR8+OGHjp6XEInIxo0bZdasWTJ
lyhSdF2rSpIlaJCS4meDzkiVLZljYSC9v69atamEXLlyY5pl8UxAbNmwQOO9lNBEiEf0WTpgwQaZPn64/R44ckZdeekk1hCffli1bHI2NIr28zz//XNauXSvwVAhMeJkYN26clC9fPqMMcUzkUwqwPProozJ//nzBNzVLliwycuRI/f3JJ59kWFAvlHfgwAF56qmn5Pvvv1c3Yl86c+aMVK1aVXbu3OnoS0NLlKIg/LLhFVm9enUVcv369QKx4fAG0w6YnKRg5XXo0MFxWbhvsPIwrsmWLZuTqum1gwcPlp9++kk6deok2OwAVxUs7+BvDLCdJEIURC2MiT744AN9e8GbjJNlD3Rnc+bMkdGjR+umgPfee0/f/DIjPf744/LZZ585fqPC4BpWF7Px+LKULl1aF6EbN27suJqEyE8yuNcmJSXpQPrnn3/WxsfbS968eR0J+8svv2gZ7777rpw4cUI6d+4sK1ascFRGRi82hSij5WfkOkIkoq+1Q4YMka+++ko6duyoPwMGDNC3F6dzJhCdEGUEvSi75o8//pAHHnhAMF7p1auXVKtWTV/zCVHGGpqWKEUnLJBiYP3+++/rGAGO/5jvKVWqVMaU9Ltq+/btUrly5dSVf8wG+3sBYCuyk3kYjLEw6ZleOnjwoJw9e9bIYjp+sCAZCFE6wqCRMfmItxVYpXnz5jl6Awq2dua7ldO3KbfLcwseXzmE6CaKwkNx9erV+sZy5513uq191JRHiKKmKcP3IIQofNpHzZ0JUdQ0ZfgehBCFT/uouTMhipqmDN+DEKLwaR81dyZEUdOU4XsQQuSy9pighFPXuXPn1M32zTff1JV8LKPAzWTRokW6IXLUqFHSoEED/b9+/fqlLtBid8mrr76qK/P4jcO7sJhbr1493daN09uKFSum92jevLnmxzofVuSx3Rt+0/3793f5qW5eHCFyUW4sl8CNFg39xBNP6DpcfHy8/i5btqyCMGLECHn99dcFrhjw4QFMWO3HOt3evXulR48e6tMNXx+4ZmBh+KGHHtJ1PSydzJw5U3/gwooDvpAXzvX47NChQ9KzZ09dAHa6zclGBkJko15AXmx0BCTwCoDHINa18G+skwEiuKRiDQyeknD+unjxol6HNTq4rMJvaerUqfqTI0cOhQiWBr5MOJsSlicxMVG++OILgQsIyn/66afl9OnTqYdxYu0PUA0cONDFJ6MlCpmY8FlGIwOGmJgYXciF4z+6GUC0adMmeeSRR2Tp0qXSqlUrBQR50OjDhw/XxV7sMkH3BfAAke9c1ipVqqiX5ZgxY9QK1apVS8sHTAUKFNCuzpcAHK4NVaIlclHpXbt2SVxcnG65wXinTZs2un+tb9++QSFCF1e8eHH1hETX9OKLLzqCCFYNW32WL1+u4zBYK2w66NKli4tPRksUMjFxI7jBvvXWW3pPgIRGvnDhgpQpU0Z+/PFH9WdOTk5WSwFLhH/37t1b3TngMYBxDsZJcIYDCD5LhHwYqPtbInRn6A6xMwUD+jx58kj37t11MO/vgJ/ZAtASZYLCgAb+Sfnz589Q6ZcvX1aIChYsmKHr07sIVghbvsPhbUCIjJuNGX0KECKyYK0AIbKWkAUQIjJgrQAhspaQBRAiMmCtwP8DrHrjkL5G7ecAAAAASUVORK5CYII=" 822 | }, 823 | "metadata": { 824 | "jupyter-vega": "#a8a997d3-53d7-474e-8ed4-3edf93c68558" 825 | }, 826 | "output_type": "display_data" 827 | } 828 | ], 829 | "source": [ 830 | "VegaLite(query2vl.to_VegaLite(predict_query, table4vis))" 831 | ] 832 | } 833 | ], 834 | "metadata": { 835 | "kernelspec": { 836 | "display_name": 
"Python 3", 837 | "language": "python", 838 | "name": "python3" 839 | }, 840 | "language_info": { 841 | "codemirror_mode": { 842 | "name": "ipython", 843 | "version": 3 844 | }, 845 | "file_extension": ".py", 846 | "mimetype": "text/x-python", 847 | "name": "python", 848 | "nbconvert_exporter": "python", 849 | "pygments_lexer": "ipython3", 850 | "version": "3.8.3" 851 | }, 852 | "latex_envs": { 853 | "LaTeX_envs_menu_present": true, 854 | "autoclose": false, 855 | "autocomplete": true, 856 | "bibliofile": "biblio.bib", 857 | "cite_by": "apalike", 858 | "current_citInitial": 1, 859 | "eqLabelWithNumbers": true, 860 | "eqNumInitial": 1, 861 | "hotkeys": { 862 | "equation": "Ctrl-E", 863 | "itemize": "Ctrl-I" 864 | }, 865 | "labels_anchors": false, 866 | "latex_user_defs": false, 867 | "report_style_numbering": false, 868 | "user_envs_cfg": false 869 | }, 870 | "toc": { 871 | "base_numbering": 1, 872 | "nav_menu": {}, 873 | "number_sections": true, 874 | "sideBar": true, 875 | "skip_h1_title": false, 876 | "title_cell": "Table of Contents", 877 | "title_sidebar": "Contents", 878 | "toc_cell": false, 879 | "toc_position": {}, 880 | "toc_section_display": true, 881 | "toc_window_display": false 882 | } 883 | }, 884 | "nbformat": 4, 885 | "nbformat_minor": 4 886 | } 887 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | __author__ = "Yuyu Luo" 2 | 3 | ''' 4 | This script handles the training process. 
5 | ''' 6 | 7 | 8 | import torch 9 | import torch.nn as nn 10 | 11 | from model.Model import Seq2Seq 12 | from model.Encoder import Encoder 13 | from model.Decoder import Decoder 14 | from preprocessing.build_vocab import build_vocab 15 | 16 | import numpy as np 17 | import random 18 | import time 19 | import math 20 | import matplotlib.pyplot as plt 21 | 22 | import argparse 23 | 24 | 25 | def count_parameters(model): 26 | return sum(p.numel() for p in model.parameters() if p.requires_grad) 27 | 28 | def initialize_weights(m): 29 | if hasattr(m, 'weight') and m.weight.dim() > 1: 30 | nn.init.xavier_uniform_(m.weight.data) 31 | 32 | 33 | def train(model, iterator, optimizer, criterion, clip): 34 | model.train() 35 | 36 | epoch_loss = 0 37 | 38 | for i, batch in enumerate(iterator): 39 | src = batch.src 40 | trg = batch.trg 41 | tok_types = batch.tok_types 42 | 43 | optimizer.zero_grad() 44 | 45 | output, _ = model(src, trg[:, :-1], tok_types, SRC) 46 | 47 | # output = [batch size, trg len - 1, output dim] 48 | # trg = [batch size, trg len] 49 | 50 | output_dim = output.shape[-1] 51 | 52 | output = output.contiguous().view(-1, output_dim) 53 | trg = trg[:, 1:].contiguous().view(-1) 54 | 55 | # output = [batch size * trg len - 1, output dim] 56 | # trg = [batch size * trg len - 1] 57 | 58 | loss = criterion(output, trg) 59 | 60 | loss.backward() 61 | 62 | torch.nn.utils.clip_grad_norm_(model.parameters(), clip) 63 | 64 | optimizer.step() 65 | 66 | epoch_loss += loss.item() 67 | 68 | return epoch_loss / len(iterator) 69 | 70 | 71 | def evaluate(model, iterator, criterion): 72 | model.eval() 73 | 74 | epoch_loss = 0 75 | 76 | with torch.no_grad(): 77 | for i, batch in enumerate(iterator): 78 | src = batch.src 79 | trg = batch.trg 80 | tok_types = batch.tok_types 81 | 82 | output, _ = model(src, trg[:, :-1], tok_types, SRC) 83 | 84 | # output = [batch size, trg len - 1, output dim] 85 | # trg = [batch size, trg len] 86 | 87 | output_dim = output.shape[-1] 88 | 89 | 
output = output.contiguous().view(-1, output_dim) 90 | trg = trg[:, 1:].contiguous().view(-1) 91 | 92 | # output = [batch size * trg len - 1, output dim] 93 | # trg = [batch size * trg len - 1] 94 | 95 | loss = criterion(output, trg) 96 | 97 | epoch_loss += loss.item() 98 | 99 | return epoch_loss / len(iterator) 100 | 101 | 102 | def epoch_time(start_time, end_time): 103 | elapsed_time = end_time - start_time 104 | elapsed_mins = int(elapsed_time / 60) 105 | elapsed_secs = int(elapsed_time - (elapsed_mins * 60)) 106 | return elapsed_mins, elapsed_secs 107 | 108 | 109 | 110 | if __name__ == '__main__': 111 | parser = argparse.ArgumentParser(description='train.py') 112 | 113 | parser.add_argument('-data_dir', required=False, default='./dataset/dataset_final/', 114 | help='Path to dataset for building vocab') 115 | parser.add_argument('-db_info', required=False, default='./dataset/database_information.csv', 116 | help='Path to database tables/columns information, for building vocab') 117 | parser.add_argument('-output_dir', type=str, default='./save_models/') 118 | 119 | parser.add_argument('-epoch', type=int, default=100, 120 | help='the number of epoch for training') 121 | parser.add_argument('-learning_rate', type=float, default=0.0005) 122 | parser.add_argument('-batch_size', type=int, default=128) 123 | parser.add_argument('-max_input_length', type=int, default=128) 124 | 125 | # parser.add_argument('-n_head', type=int, default=8) 126 | # parser.add_argument('-dropout', type=float, default=0.1) 127 | opt = parser.parse_args() 128 | 129 | ################################### 130 | 131 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 132 | 133 | SEED = 1234 134 | 135 | random.seed(SEED) 136 | np.random.seed(SEED) 137 | torch.manual_seed(SEED) 138 | torch.cuda.manual_seed(SEED) 139 | torch.backends.cudnn.deterministic = True 140 | 141 | 142 | print("------------------------------\n| Build vocab start ... 
| \n------------------------------") 143 | SRC, TRG, TOK_TYPES, BATCH_SIZE, train_iterator, valid_iterator, test_iterator, my_max_length = build_vocab( 144 | data_dir=opt.data_dir, 145 | db_info=opt.db_info, 146 | batch_size=opt.batch_size, 147 | max_input_length=opt.max_input_length 148 | ) 149 | print("------------------------------\n| Build vocab end ... | \n------------------------------") 150 | 151 | INPUT_DIM = len(SRC.vocab) 152 | OUTPUT_DIM = len(TRG.vocab) 153 | HID_DIM = 256 # it equals to embedding dimension 154 | ENC_LAYERS = 3 155 | DEC_LAYERS = 3 156 | ENC_HEADS = 8 157 | DEC_HEADS = 8 158 | ENC_PF_DIM = 512 159 | DEC_PF_DIM = 512 160 | ENC_DROPOUT = 0.1 161 | DEC_DROPOUT = 0.1 162 | 163 | print("------------------------------\n| Build encoder of the ncNet ... | \n------------------------------") 164 | enc = Encoder(INPUT_DIM, 165 | HID_DIM, 166 | ENC_LAYERS, 167 | ENC_HEADS, 168 | ENC_PF_DIM, 169 | ENC_DROPOUT, 170 | device, 171 | TOK_TYPES, 172 | my_max_length 173 | ) 174 | print("------------------------------\n| Build decoder of the ncNet ... | \n------------------------------") 175 | dec = Decoder(OUTPUT_DIM, 176 | HID_DIM, 177 | DEC_LAYERS, 178 | DEC_HEADS, 179 | DEC_PF_DIM, 180 | DEC_DROPOUT, 181 | device, 182 | my_max_length 183 | ) 184 | 185 | SRC_PAD_IDX = SRC.vocab.stoi[SRC.pad_token] 186 | TRG_PAD_IDX = TRG.vocab.stoi[TRG.pad_token] 187 | 188 | print("------------------------------\n| Build the ncNet structure... | \n------------------------------") 189 | ncNet = Seq2Seq(enc, dec, SRC, SRC_PAD_IDX, TRG_PAD_IDX, device).to(device) # define the transformer-based ncNet 190 | 191 | print("------------------------------\n| Init for training ... 
| \n------------------------------") 192 | ncNet.apply(initialize_weights) 193 | 194 | LEARNING_RATE = opt.learning_rate 195 | 196 | optimizer = torch.optim.Adam(ncNet.parameters(), lr=LEARNING_RATE) 197 | 198 | criterion = nn.CrossEntropyLoss(ignore_index=TRG_PAD_IDX) 199 | 200 | N_EPOCHS = opt.epoch 201 | CLIP = 1 202 | 203 | train_loss_list, valid_loss_list = list(), list() 204 | 205 | best_valid_loss = float('inf') 206 | 207 | print("------------------------------\n| Training start ... | \n------------------------------") 208 | 209 | for epoch in range(N_EPOCHS): 210 | 211 | start_time = time.time() 212 | 213 | train_loss = train(ncNet, train_iterator, optimizer, criterion, CLIP) 214 | valid_loss = evaluate(ncNet, valid_iterator, criterion) 215 | 216 | end_time = time.time() 217 | 218 | epoch_mins, epoch_secs = epoch_time(start_time, end_time) 219 | 220 | # save the best trained model 221 | if valid_loss < best_valid_loss: 222 | print('△○△○△○△○△○△○△○△○\nSave the BEST model!\n△○△○△○△○△○△○△○△○△○') 223 | best_valid_loss = valid_loss 224 | torch.save(ncNet.state_dict(), opt.output_dir + 'model_best.pt') 225 | 226 | # save model on each epoch 227 | print('△○△○△○△○△○△○△○△○\nSave ncNet!\n△○△○△○△○△○△○△○△○△○') 228 | torch.save(ncNet.state_dict(), opt.output_dir + 'model_' + str(epoch + 1) + '.pt') 229 | 230 | train_loss_list.append(train_loss) 231 | valid_loss_list.append(valid_loss) 232 | 233 | print(f'Epoch: {epoch + 1:02} | Time: {epoch_mins}m {epoch_secs}s') 234 | print(f'\tTrain Loss: {train_loss:.3f} | Train PPL: {math.exp(train_loss):7.3f}') 235 | print(f'\t Val. Loss: {valid_loss:.3f} | Val. 
PPL: {math.exp(valid_loss):7.3f}') 236 | plt.plot(train_loss_list) 237 | plt.plot(valid_loss_list) 238 | plt.show() 239 | -------------------------------------------------------------------------------- /utilities/vis_rendering.py: -------------------------------------------------------------------------------- 1 | ''' 2 | This file is for convert vis query to the Vega-Lite object 3 | ''' 4 | 5 | 6 | __author__ = "Yuyu Luo" 7 | 8 | import json 9 | import pandas 10 | 11 | class VegaZero2VegaLite(object): 12 | def __init__(self): 13 | pass 14 | 15 | def parse_vegaZero(self, vega_zero): 16 | self.parsed_vegaZero = { 17 | 'mark': '', 18 | 'data': '', 19 | 'encoding': { 20 | 'x': '', 21 | 'y': { 22 | 'aggregate': '', 23 | 'y': '' 24 | }, 25 | 'color': { 26 | 'z': '' 27 | } 28 | }, 29 | 'transform': { 30 | 'filter': '', 31 | 'group': '', 32 | 'bin': { 33 | 'axis': '', 34 | 'type': '' 35 | }, 36 | 'sort': { 37 | 'axis': '', 38 | 'type': '' 39 | }, 40 | 'topk': '' 41 | } 42 | } 43 | vega_zero_keywords = vega_zero.split(' ') 44 | 45 | self.parsed_vegaZero['mark'] = vega_zero_keywords[vega_zero_keywords.index('mark') + 1] 46 | self.parsed_vegaZero['data'] = vega_zero_keywords[vega_zero_keywords.index('data') + 1] 47 | self.parsed_vegaZero['encoding']['x'] = vega_zero_keywords[vega_zero_keywords.index('x') + 1] 48 | self.parsed_vegaZero['encoding']['y']['y'] = vega_zero_keywords[vega_zero_keywords.index('aggregate') + 2] 49 | self.parsed_vegaZero['encoding']['y']['aggregate'] = vega_zero_keywords[vega_zero_keywords.index('aggregate') + 1] 50 | if 'color' in vega_zero_keywords: 51 | self.parsed_vegaZero['encoding']['color']['z'] = vega_zero_keywords[vega_zero_keywords.index('color') + 1] 52 | 53 | if 'topk' in vega_zero_keywords: 54 | self.parsed_vegaZero['transform']['topk'] = vega_zero_keywords[vega_zero_keywords.index('topk') + 1] 55 | 56 | if 'sort' in vega_zero_keywords: 57 | self.parsed_vegaZero['transform']['sort']['axis'] = 
vega_zero_keywords[vega_zero_keywords.index('sort') + 1] 58 | self.parsed_vegaZero['transform']['sort']['type'] = vega_zero_keywords[vega_zero_keywords.index('sort') + 2] 59 | 60 | if 'group' in vega_zero_keywords: 61 | self.parsed_vegaZero['transform']['group'] = vega_zero_keywords[vega_zero_keywords.index('group') + 1] 62 | 63 | if 'bin' in vega_zero_keywords: 64 | self.parsed_vegaZero['transform']['bin']['axis'] = vega_zero_keywords[vega_zero_keywords.index('bin') + 1] 65 | self.parsed_vegaZero['transform']['bin']['type'] = vega_zero_keywords[vega_zero_keywords.index('bin') + 3] 66 | 67 | if 'filter' in vega_zero_keywords: 68 | 69 | filter_part_token = [] 70 | for each in vega_zero_keywords[vega_zero_keywords.index('filter') + 1:]: 71 | if each not in ['group', 'bin', 'sort', 'topk']: 72 | filter_part_token.append(each) 73 | else: 74 | break 75 | 76 | if 'between' in filter_part_token: 77 | filter_part_token[filter_part_token.index('between') + 2] = 'and ' + filter_part_token[ 78 | filter_part_token.index('between') - 1] + ' <=' 79 | filter_part_token[filter_part_token.index('between')] = '>=' 80 | 81 | # replace 'and' -- 'or' 82 | filter_part_token = ' '.join(filter_part_token).split() 83 | filter_part_token = ['&' if x == 'and' else x for x in filter_part_token] 84 | filter_part_token = ['|' if x == 'or' else x for x in filter_part_token] 85 | 86 | if '&' in filter_part_token or '|' in filter_part_token: 87 | final_filter_part = '' 88 | each_conditions = [] 89 | for i in range(len(filter_part_token)): 90 | each = filter_part_token[i] 91 | if each != '&' and each != '|': 92 | # ’=‘ in SQL --to--> ’==‘ in Vega-Lite 93 | if each == '=': 94 | each = '==' 95 | each_conditions.append(each) 96 | if each == '&' or each == '|' or i == len(filter_part_token) - 1: 97 | # each = '&' or '|' 98 | if 'like' == each_conditions[1]: 99 | # only consider this case: '%a%' 100 | if each_conditions[2][1] == '%' and each_conditions[2][len(each_conditions[2]) - 2] == '%': 101 | 
final_filter_part += 'indexof(' + 'datum.' + each_conditions[0] + ',"' + \ 102 | each_conditions[2][2:len(each_conditions[2]) - 2] + '") != -1' 103 | elif 'like' == each_conditions[2] and 'not' == each_conditions[1]: 104 | 105 | if each_conditions[3][1] == '%' and each_conditions[3][len(each_conditions[3]) - 2] == '%': 106 | final_filter_part += 'indexof(' + 'datum.' + each_conditions[0] + ',"' + \ 107 | each_conditions[3][2:len(each_conditions[3]) - 2] + '") == -1' 108 | else: 109 | final_filter_part += 'datum.' + ' '.join(each_conditions) 110 | 111 | if i != len(filter_part_token) - 1: 112 | final_filter_part += ' ' + each + ' ' 113 | each_conditions = [] 114 | 115 | self.parsed_vegaZero['transform']['filter'] = final_filter_part 116 | 117 | else: 118 | # only single filter condition 119 | self.parsed_vegaZero['transform']['filter'] = 'datum.' + ' '.join(filter_part_token).strip() 120 | 121 | return self.parsed_vegaZero 122 | 123 | def to_VegaLite(self, vega_zero, dataframe=None): 124 | self.VegaLiteSpec = { 125 | 'bar': { 126 | "mark": "bar", 127 | "encoding": { 128 | "x": {"field": "x", "type": "nominal"}, 129 | "y": {"field": "y", "type": "quantitative"} 130 | } 131 | }, 132 | 'arc': { 133 | "mark": "arc", 134 | "encoding": { 135 | "color": {"field": "x", "type": "nominal"}, 136 | "theta": {"field": "y", "type": "quantitative"} 137 | } 138 | }, 139 | 'line': { 140 | "mark": "line", 141 | "encoding": { 142 | "x": {"field": "x", "type": "nominal"}, 143 | "y": {"field": "y", "type": "quantitative"} 144 | } 145 | }, 146 | 'point': { 147 | "mark": "point", 148 | "encoding": { 149 | "x": {"field": "x", "type": "quantitative"}, 150 | "y": {"field": "y", "type": "quantitative"} 151 | } 152 | } 153 | } 154 | 155 | VegaZero = self.parse_vegaZero(vega_zero) 156 | 157 | # assign some vega-zero keywords to the VegaLiteSpec object 158 | if isinstance(dataframe, pandas.core.frame.DataFrame): 159 | self.VegaLiteSpec[VegaZero['mark']]['data'] = dict() 160 | 
self.VegaLiteSpec[VegaZero['mark']]['data']['values'] = json.loads(dataframe.to_json(orient='records')) 161 | 162 | if VegaZero['mark'] != 'arc': 163 | self.VegaLiteSpec[VegaZero['mark']]['encoding']['x']['field'] = VegaZero['encoding']['x'] 164 | self.VegaLiteSpec[VegaZero['mark']]['encoding']['y']['field'] = VegaZero['encoding']['y']['y'] 165 | if VegaZero['encoding']['y']['aggregate'] != '' and VegaZero['encoding']['y']['aggregate'] != 'none': 166 | self.VegaLiteSpec[VegaZero['mark']]['encoding']['y']['aggregate'] = VegaZero['encoding']['y']['aggregate'] 167 | else: 168 | self.VegaLiteSpec[VegaZero['mark']]['encoding']['color']['field'] = VegaZero['encoding']['x'] 169 | self.VegaLiteSpec[VegaZero['mark']]['encoding']['theta']['field'] = VegaZero['encoding']['y']['y'] 170 | if VegaZero['encoding']['y']['aggregate'] != '' and VegaZero['encoding']['y']['aggregate'] != 'none': 171 | self.VegaLiteSpec[VegaZero['mark']]['encoding']['theta']['aggregate'] = VegaZero['encoding']['y'][ 172 | 'aggregate'] 173 | 174 | if VegaZero['encoding']['color']['z'] != '': 175 | self.VegaLiteSpec[VegaZero['mark']]['encoding']['color'] = { 176 | 'field': VegaZero['encoding']['color']['z'], 'type': 'nominal' 177 | } 178 | 179 | # it seems that the group will be performed by VegaLite defaultly, in our cases. 
180 | if VegaZero['transform']['group'] != '': 181 | pass 182 | 183 | if VegaZero['transform']['bin']['axis'] != '': 184 | if VegaZero['transform']['bin']['axis'] == 'x': 185 | self.VegaLiteSpec[VegaZero['mark']]['encoding']['x']['type'] = 'temporal' 186 | if VegaZero['transform']['bin']['type'] in ['date', 'year', 'week', 'month']: 187 | self.VegaLiteSpec[VegaZero['mark']]['encoding']['x']['timeUnit'] = VegaZero['transform']['bin']['type'] 188 | elif VegaZero['transform']['bin']['type'] == 'weekday': 189 | self.VegaLiteSpec[VegaZero['mark']]['encoding']['x']['timeUnit'] = 'week' 190 | else: 191 | print('Unknown binning step.') 192 | 193 | if VegaZero['transform']['filter'] != '': 194 | if 'transform' not in self.VegaLiteSpec[VegaZero['mark']]: 195 | self.VegaLiteSpec[VegaZero['mark']]['transform'] = [{ 196 | "filter": VegaZero['transform']['filter'] 197 | }] 198 | elif 'filter' not in self.VegaLiteSpec[VegaZero['mark']]['transform']: 199 | self.VegaLiteSpec[VegaZero['mark']]['transform'].append({ 200 | "filter": VegaZero['transform']['filter'] 201 | }) 202 | else: 203 | self.VegaLiteSpec[VegaZero['mark']]['transform']['filter'] += ' & ' + VegaZero['transform']['filter'] 204 | 205 | if VegaZero['transform']['topk'] != '': 206 | if VegaZero['transform']['sort']['axis'] == 'x': 207 | sort_field = VegaZero['encoding']['x'] 208 | elif VegaZero['transform']['sort']['axis'] == 'y': 209 | sort_field = VegaZero['encoding']['y']['y'] 210 | else: 211 | print('Unknown sorting field: ', VegaZero['transform']['sort']['axis']) 212 | sort_field = VegaZero['transform']['sort']['axis'] 213 | if VegaZero['transform']['sort']['type'] == 'desc': 214 | sort_order = 'descending' 215 | else: 216 | sort_order = 'ascending' 217 | if 'transform' in self.VegaLiteSpec[VegaZero['mark']]: 218 | current_filter = self.VegaLiteSpec[VegaZero['mark']]['transform'][0]['filter'] 219 | self.VegaLiteSpec[VegaZero['mark']]['transform'][0][ 220 | 'filter'] = current_filter + ' & ' + "datum.rank <= " + 
str(VegaZero['transform']['topk']) 221 | self.VegaLiteSpec[VegaZero['mark']]['transform'].insert(0, { 222 | "window": [{ 223 | "field": sort_field, 224 | "op": "dense_rank", 225 | "as": "rank" 226 | }], 227 | "sort": [{"field": sort_field, "order": sort_order}] 228 | }) 229 | else: 230 | self.VegaLiteSpec[VegaZero['mark']]['transform'] = [ 231 | { 232 | "window": [{ 233 | "field": sort_field, 234 | "op": "dense_rank", 235 | "as": "rank" 236 | }], 237 | "sort": [{"field": sort_field, "order": sort_order}] 238 | }, 239 | { 240 | "filter": "datum.rank <= " + str(VegaZero['transform']['topk']) 241 | } 242 | ] 243 | 244 | if VegaZero['transform']['sort']['axis'] != '': 245 | if VegaZero['transform']['sort']['axis'] == 'x': 246 | if VegaZero['transform']['sort']['type'] == 'desc': 247 | self.VegaLiteSpec[VegaZero['mark']]['encoding']['y']['sort'] = '-x' 248 | else: 249 | self.VegaLiteSpec[VegaZero['mark']]['encoding']['y']['sort'] = 'x' 250 | else: 251 | if VegaZero['transform']['sort']['type'] == 'desc': 252 | self.VegaLiteSpec[VegaZero['mark']]['encoding']['x']['sort'] = '-y' 253 | else: 254 | self.VegaLiteSpec[VegaZero['mark']]['encoding']['x']['sort'] = 'y' 255 | 256 | return self.VegaLiteSpec[VegaZero['mark']] 257 | 258 | --------------------------------------------------------------------------------