├── .gitignore ├── get_data_parameters.py ├── readme.md ├── representation_model.py ├── run_meta_learner.py ├── run_multi_task.py ├── tools ├── pytool.py └── utils.py ├── torch_test.py ├── unified-model-v1.yml ├── unified-model.yml └── vector_loader.py /.gitignore: -------------------------------------------------------------------------------- 1 | /.idea 2 | .DS_Store 3 | -------------------------------------------------------------------------------- /get_data_parameters.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import pickle 3 | import os 4 | import json 5 | 6 | 7 | class DataParameters: 8 | def __init__(self, condition_max_num, indexes_id, tables_id, columns_id, physic_ops_id, column_total_num, 9 | table_total_num, index_total_num, physic_op_total_num, condition_op_dim, compare_ops_id, bool_ops_id, 10 | bool_ops_total_num, compare_ops_total_num, data, min_max_column, word_vectors, cost_label_min, 11 | cost_label_max, card_label_min, card_label_max): 12 | self.condition_max_num = condition_max_num 13 | self.indexes_id = indexes_id 14 | self.tables_id = tables_id 15 | self.columns_id = columns_id 16 | self.physic_ops_id = physic_ops_id 17 | self.column_total_num = column_total_num 18 | self.table_total_num = table_total_num 19 | self.index_total_num = index_total_num 20 | self.physic_op_total_num = physic_op_total_num 21 | self.condition_op_dim = condition_op_dim 22 | self.compare_ops_id = compare_ops_id 23 | self.bool_ops_id = bool_ops_id 24 | self.bool_ops_total_num = bool_ops_total_num 25 | self.compare_ops_total_num = compare_ops_total_num 26 | self.data = data 27 | self.min_max_column = min_max_column 28 | self.word_vectors = word_vectors 29 | self.cost_label_min = cost_label_min 30 | self.cost_label_max = cost_label_max 31 | self.card_label_min = card_label_min 32 | self.card_label_max = card_label_max 33 | 34 | 35 | class ModelParameters: 36 | def __init__(self, input_size, 
trans_input_size, cond_hidden_size, head_num, coder_layer_num, plan_pos_size, 37 | pos_flag, attn_flag, max_leaves_num, trans_dim_feedforward, beam_width, table_num): 38 | self.input_size = input_size 39 | self.trans_input_size = trans_input_size 40 | self.cond_hidden_size = cond_hidden_size 41 | self.head_num = head_num 42 | self.coder_layer_num = coder_layer_num 43 | self.plan_pos_size = plan_pos_size 44 | self.pos_flag = pos_flag 45 | self.attn_flag = attn_flag 46 | self.max_leaves_num = max_leaves_num 47 | self.trans_dim_feedforward = trans_dim_feedforward 48 | self.beam_width = beam_width 49 | self.table_num = table_num 50 | 51 | 52 | def load_numeric_min_max(path): 53 | with open(path, 'r') as f: 54 | min_max_column = json.loads(f.read()) 55 | return min_max_column 56 | 57 | 58 | def get_data_parameters(mode): 59 | print("Preparing data...") 60 | dataset = load_dataset('table_data/') 61 | 62 | if os.path.exists('table_data/table_info.txt'): 63 | column2pos, indexes_id, tables_id, columns_id, physic_ops_id, compare_ops_id, bool_ops_id, table_names = \ 64 | pickle.load(open('table_data/table_info.txt', 'rb')) 65 | else: 66 | column2pos, indexes_id, tables_id, columns_id, physic_ops_id, compare_ops_id, bool_ops_id, table_names = \ 67 | prepare_dataset(dataset) 68 | pickle.dump(prepare_dataset(dataset), open('table_data/table_info.txt', 'wb')) 69 | print('Data prepared!') 70 | print("Loading min and max data...") 71 | min_max_column = load_numeric_min_max('table_data/min_max_vals.json') 72 | print('Min max loaded!') 73 | 74 | index_total_num = len(indexes_id) 75 | table_total_num = len(tables_id) 76 | column_total_num = len(columns_id) 77 | physic_op_total_num = len(physic_ops_id) 78 | compare_ops_total_num = len(compare_ops_id) 79 | bool_ops_total_num = len(bool_ops_id) 80 | if mode == "numeric": 81 | condition_op_dim = bool_ops_total_num + compare_ops_total_num + column_total_num + 1 82 | else: 83 | condition_op_dim = bool_ops_total_num + compare_ops_total_num 
+ column_total_num + 600 84 | 85 | # plan_node_max_num, condition_max_num, cost_label_min, cost_label_max, card_label_min, card_label_max \ 86 | # = obtain_upper_bound_query_size('../train_data/', True, mode) 87 | plan_node_max_num, condition_max_num, cost_label_min, cost_label_max, card_label_min, card_label_max = 72, 13, \ 88 | -9.210340371976182, 14.411289201204804, -2.3025850929940455, 19.536825097210908 89 | cost_label_min_test, cost_label_max_test, card_label_min_test, card_label_max_test \ 90 | = float("inf"), -float("inf"), float("inf"), -float("inf") 91 | # plan_node_max_num, condition_max_num, cost_label_min, cost_label_max, card_label_min, card_label_max = obtain_upper_bound_query_size('../test_files_open_source/plans_seq_sample.json') 92 | # plan_node_max_num_test, condition_max_num_test, cost_label_min_test, cost_label_max_test, card_label_min_test, card_label_max_test = obtain_upper_bound_query_size('../test_files_open_source/plans_seq_sample.json') 93 | cost_label_min = min(cost_label_min, cost_label_min_test) 94 | cost_label_max = max(cost_label_max, cost_label_max_test) 95 | card_label_min = min(card_label_min, card_label_min_test) 96 | card_label_max = max(card_label_max, card_label_max_test) 97 | print('query upper size prepared') 98 | 99 | data_parameters = DataParameters(condition_max_num, indexes_id, tables_id, columns_id, physic_ops_id, 100 | column_total_num, table_total_num, index_total_num, physic_op_total_num, 101 | condition_op_dim, compare_ops_id, bool_ops_id, bool_ops_total_num, 102 | compare_ops_total_num, dataset, min_max_column, [], cost_label_min, 103 | cost_label_max, card_label_min, card_label_max) 104 | 105 | return data_parameters 106 | 107 | 108 | def load_dataset(dir_path): 109 | data = dict() 110 | data["aka_name"] = pd.read_csv(dir_path + '/aka_name.csv', header=None, low_memory=False) 111 | data["aka_title"] = pd.read_csv(dir_path + '/aka_title.csv', header=None, low_memory=False) 112 | data["cast_info"] = 
pd.read_csv(dir_path + '/cast_info.csv', header=None, low_memory=False) 113 | data["char_name"] = pd.read_csv(dir_path + '/char_name.csv', header=None, low_memory=False) 114 | data["company_name"] = pd.read_csv(dir_path + '/company_name.csv', header=None, low_memory=False) 115 | data["company_type"] = pd.read_csv(dir_path + '/company_type.csv', header=None, low_memory=False) 116 | data["comp_cast_type"] = pd.read_csv(dir_path + '/comp_cast_type.csv', header=None, low_memory=False) 117 | data["complete_cast"] = pd.read_csv(dir_path + '/complete_cast.csv', header=None, low_memory=False) 118 | data["info_type"] = pd.read_csv(dir_path + '/info_type.csv', header=None, low_memory=False) 119 | data["keyword"] = pd.read_csv(dir_path + '/keyword.csv', header=None, low_memory=False) 120 | data["kind_type"] = pd.read_csv(dir_path + '/kind_type.csv', header=None, low_memory=False) 121 | data["link_type"] = pd.read_csv(dir_path + '/link_type.csv', header=None, low_memory=False) 122 | data["movie_companies"] = pd.read_csv(dir_path + '/movie_companies.csv', header=None, low_memory=False) 123 | data["movie_info"] = pd.read_csv(dir_path + '/movie_info.csv', header=None, low_memory=False) 124 | data["movie_info_idx"] = pd.read_csv(dir_path + '/movie_info_idx.csv', header=None, low_memory=False) 125 | data["movie_keyword"] = pd.read_csv(dir_path + '/movie_keyword.csv', header=None, low_memory=False) 126 | data["movie_link"] = pd.read_csv(dir_path + '/movie_link.csv', header=None, low_memory=False) 127 | data["name"] = pd.read_csv(dir_path + '/name.csv', header=None, low_memory=False) 128 | data["person_info"] = pd.read_csv(dir_path + '/person_info.csv', header=None, low_memory=False) 129 | data["role_type"] = pd.read_csv(dir_path + '/role_type.csv', header=None, low_memory=False) 130 | data["title"] = pd.read_csv(dir_path + '/title.csv', header=None, low_memory=False) 131 | 132 | aka_name_column = { 133 | 'id': 0, 134 | 'person_id': 1, 135 | 'name': 2, 136 | 'imdb_index': 3, 137 | 
'name_pcode_cf': 4, 138 | 'name_pcode_nf': 5, 139 | 'surname_pcode': 6, 140 | 'md5sum': 7 141 | } 142 | 143 | aka_title_column = { 144 | 'id': 0, 145 | 'movie_id': 1, 146 | 'title': 2, 147 | 'imdb_index': 3, 148 | 'kind_id': 4, 149 | 'production_year': 5, 150 | 'phonetic_code': 6, 151 | 'episode_of_id': 7, 152 | 'season_nr': 8, 153 | 'episode_nr': 9, 154 | 'note': 10, 155 | 'md5sum': 11 156 | } 157 | 158 | cast_info_column = { 159 | 'id': 0, 160 | 'person_id': 1, 161 | 'movie_id': 2, 162 | 'person_role_id': 3, 163 | 'note': 4, 164 | 'nr_order': 5, 165 | 'role_id': 6 166 | } 167 | 168 | char_name_column = { 169 | 'id': 0, 170 | 'name': 1, 171 | 'imdb_index': 2, 172 | 'imdb_id': 3, 173 | 'name_pcode_nf': 4, 174 | 'surname_pcode': 5, 175 | 'md5sum': 6 176 | } 177 | 178 | comp_cast_type_column = { 179 | 'id': 0, 180 | 'kind': 1 181 | } 182 | 183 | company_name_column = { 184 | 'id': 0, 185 | 'name': 1, 186 | 'country_code': 2, 187 | 'imdb_id': 3, 188 | 'name_pcode_nf': 4, 189 | 'name_pcode_sf': 5, 190 | 'md5sum': 6 191 | } 192 | 193 | company_type_column = { 194 | 'id': 0, 195 | 'kind': 1 196 | } 197 | 198 | complete_cast_column = { 199 | 'id': 0, 200 | 'movie_id': 1, 201 | 'subject_id': 2, 202 | 'status_id': 3 203 | } 204 | 205 | info_type_column = { 206 | 'id': 0, 207 | 'info': 1 208 | } 209 | 210 | keyword_column = { 211 | 'id': 0, 212 | 'keyword': 1, 213 | 'phonetic_code': 2 214 | } 215 | 216 | kind_type_column = { 217 | 'id': 0, 218 | 'kind': 1 219 | } 220 | 221 | link_type_column = { 222 | 'id': 0, 223 | 'link': 1 224 | } 225 | 226 | movie_companies_column = { 227 | 'id': 0, 228 | 'movie_id': 1, 229 | 'company_id': 2, 230 | 'company_type_id': 3, 231 | 'note': 4 232 | } 233 | 234 | movie_info_idx_column = { 235 | 'id': 0, 236 | 'movie_id': 1, 237 | 'info_type_id': 2, 238 | 'info': 3, 239 | 'note': 4 240 | } 241 | 242 | movie_keyword_column = { 243 | 'id': 0, 244 | 'movie_id': 1, 245 | 'keyword_id': 2 246 | } 247 | 248 | movie_link_column = { 249 | 'id': 0, 250 | 
'movie_id': 1, 251 | 'linked_movie_id': 2, 252 | 'link_type_id': 3 253 | } 254 | 255 | name_column = { 256 | 'id': 0, 257 | 'name': 1, 258 | 'imdb_index': 2, 259 | 'imdb_id': 3, 260 | 'gender': 4, 261 | 'name_pcode_cf': 5, 262 | 'name_pcode_nf': 6, 263 | 'surname_pcode': 7, 264 | 'md5sum': 8 265 | } 266 | 267 | role_type_column = { 268 | 'id': 0, 269 | 'role': 1 270 | } 271 | 272 | title_column = { 273 | 'id': 0, 274 | 'title': 1, 275 | 'imdb_index': 2, 276 | 'kind_id': 3, 277 | 'production_year': 4, 278 | 'imdb_id': 5, 279 | 'phonetic_code': 6, 280 | 'episode_of_id': 7, 281 | 'season_nr': 8, 282 | 'episode_nr': 9, 283 | 'series_years': 10, 284 | 'md5sum': 11 285 | } 286 | 287 | movie_info_column = { 288 | 'id': 0, 289 | 'movie_id': 1, 290 | 'info_type_id': 2, 291 | 'info': 3, 292 | 'note': 4 293 | } 294 | 295 | person_info_column = { 296 | 'id': 0, 297 | 'person_id': 1, 298 | 'info_type_id': 2, 299 | 'info': 3, 300 | 'note': 4 301 | } 302 | data["aka_name"].columns = aka_name_column 303 | data["aka_title"].columns = aka_title_column 304 | data["cast_info"].columns = cast_info_column 305 | data["char_name"].columns = char_name_column 306 | data["company_name"].columns = company_name_column 307 | data["company_type"].columns = company_type_column 308 | data["comp_cast_type"].columns = comp_cast_type_column 309 | data["complete_cast"].columns = complete_cast_column 310 | data["info_type"].columns = info_type_column 311 | data["keyword"].columns = keyword_column 312 | data["kind_type"].columns = kind_type_column 313 | data["link_type"].columns = link_type_column 314 | data["movie_companies"].columns = movie_companies_column 315 | data["movie_info"].columns = movie_info_column 316 | data["movie_info_idx"].columns = movie_info_idx_column 317 | data["movie_keyword"].columns = movie_keyword_column 318 | data["movie_link"].columns = movie_link_column 319 | data["name"].columns = name_column 320 | data["person_info"].columns = person_info_column 321 | 
data["role_type"].columns = role_type_column 322 | data["title"].columns = title_column 323 | return data 324 | 325 | 326 | def prepare_dataset(database): 327 | 328 | column2pos = dict() 329 | 330 | tables = ['aka_name', 'aka_title', 'cast_info', 'char_name', 'company_name', 'company_type', 'comp_cast_type', 'complete_cast', 'info_type', 'keyword', 'kind_type', 'link_type', 'movie_companies', 'movie_info', 'movie_info_idx', 331 | 'movie_keyword', 'movie_link', 'name', 'person_info', 'role_type', 'title'] 332 | 333 | for table_name in tables: 334 | column2pos[table_name] = database[table_name].columns 335 | 336 | indexes = ['aka_name_pkey', 'aka_title_pkey', 'cast_info_pkey', 'char_name_pkey', 337 | 'comp_cast_type_pkey', 'company_name_pkey', 'company_type_pkey', 'complete_cast_pkey', 338 | 'info_type_pkey', 'keyword_pkey', 'kind_type_pkey', 'link_type_pkey', 'movie_companies_pkey', 339 | 'movie_info_idx_pkey', 'movie_keyword_pkey', 'movie_link_pkey', 'name_pkey', 'role_type_pkey', 340 | 'title_pkey', 'movie_info_pkey', 'person_info_pkey', 'company_id_movie_companies', 341 | 'company_type_id_movie_companies', 'info_type_id_movie_info_idx', 'info_type_id_movie_info', 342 | 'info_type_id_person_info', 'keyword_id_movie_keyword', 'kind_id_aka_title', 'kind_id_title', 343 | 'linked_movie_id_movie_link', 'link_type_id_movie_link', 'movie_id_aka_title', 'movie_id_cast_info', 344 | 'movie_id_complete_cast', 'movie_id_movie_ companies', 'movie_id_movie_info_idx', 345 | 'movie_id_movie_keyword', 'movie_id_movie_link', 'movie_id_movie_info', 'person_id_aka_name', 346 | 'person_id_cast_info', 'person_id_person_info', 'person_role_id_cast_info', 'role_id_cast_info'] 347 | indexes_id = dict() 348 | for idx, index in enumerate(indexes): 349 | indexes_id[index] = idx + 1 350 | physic_ops_id = {'Materialize':1, 'Sort':2, 'Hash':3, 'Merge Join':4, 'Bitmap Index Scan':5, 351 | 'Index Only Scan':6, 'BitmapAnd':7, 'Nested Loop':8, 'Aggregate':9, 'Result':10, 352 | 'Hash Join':11, 'Seq 
Scan':12, 'Bitmap Heap Scan':13, 'Index Scan':14, 'BitmapOr':15} 353 | strategy_id = {'Plain':1} 354 | compare_ops_id = {'=':1, '>':2, '<':3, '!=':4, '~~':5, '!~~':6, '!Null': 7, '>=':8, '<=':9} 355 | bool_ops_id = {'AND':1,'OR':2} 356 | tables_id = {} 357 | columns_id = {} 358 | table_id = 1 359 | column_id = 1 360 | for table_name in tables: 361 | tables_id[table_name] = table_id 362 | table_id += 1 363 | for column in column2pos[table_name]: 364 | columns_id[table_name+'.'+column] = column_id 365 | column_id += 1 366 | return column2pos, indexes_id, tables_id, columns_id, physic_ops_id, compare_ops_id, bool_ops_id, tables 367 | -------------------------------------------------------------------------------- /readme.md: -------------------------------------------------------------------------------- 1 | ## A Unified Transferable Model in PyTorch 2 | 3 | Pytorch implementation of a unified model based transformer to explore transferability across databases and across tasks in ML-enhanced DBMS. 4 | 5 | ### Requirment 6 | 7 | PyTorch 1.8.1 8 | 9 | Python 3.8.5 10 | 11 | ### Environment configuration 12 | 13 | If your cuda version is **10.2**, you can run the following code to set up a conda environment: 14 | 15 | ``` 16 | conda env create -f join-opt-env.yml 17 | source activate join-opt-env 18 | ``` 19 | 20 | ### Running experiments 21 | 22 | There are two basic tasks, the first is multi-task learning on query optimization to explore transferability across tasks, and the second is meta learning on join order selection to explore transferability across databases. 23 | 24 | #### Multi-task learning on query optimization 25 | 26 | Run `python run_multi_task.py --help` to see list of tunable knobs. 27 | 28 | Setting `--train_data_path` and `--test_data_path` is neccessary, It represents the path of training set and test set respectively. 29 | 30 | `--mode` knob has two options. 
`--mode="multi-task"` runs our Transformer model, with `--cost`, `--card` and `--join` enabling the corresponding tasks; `--mode="tree-lstm"` runs the Tree-LSTM model. 31 | 32 | ```bash 33 | # Train and test three tasks at the same time 34 | python run_multi_task.py --mode="multi-task" --card=1 --cost=1 --join=1 --train_data_path="/train_data" --test_data_path="/test_data" 35 | 36 | # Train and test tree-lstm model 37 | python run_multi_task.py --mode="tree-lstm" --train_data_path="/train_data" --test_data_path="/test_data" 38 | ``` 39 | 40 | #### Meta learning on join order selection 41 | 42 | Run `python run_meta_learner.py --help` to see the list of tunable knobs. 43 | 44 | Setting `--train_db`, `--test_db` and `--db_path` is necessary. They specify the ids of the training databases, the id of the test database, and the directory containing the databases. 45 | 46 | ```bash 47 | # Train on db[0,1,2,3], test on db[4] 48 | python run_meta_learner.py --train_db 0 1 2 3 --test_db 4 --db_path="/data/meta_learner" 49 | ``` 50 | 51 | ### Contact 52 | If you have any questions about the code, please email yupei.yu@alibaba-inc.com, yangpeilun.ypl@alibaba-inc.com 53 | 54 | ### Reference 55 | 56 | If you find this repository useful in your work, please cite [our paper](https://arxiv.org/pdf/2105.02418.pdf).
57 | 58 | ``` 59 | @article{wu2021unified, 60 | title={A Unified Transferable Model for ML-Enhanced DBMS}, 61 | author={Wu, Ziniu and Yang, Peilun and Yu, Pei and Zhu, Rong and Han, Yuxing and Li, Yaliang and Lian, Defu and Zeng, Kai and Zhou, Jingren}, 62 | journal={Conference on Innovative Data Systems Research}, 63 | year={2022} 64 | } 65 | ``` 66 | 67 | -------------------------------------------------------------------------------- /representation_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from get_data_parameters import * 5 | import math 6 | from copy import deepcopy 7 | 8 | 9 | class TreeNode(object): 10 | def __init__(self, current_vec, parent): 11 | self.item = current_vec 12 | self.parent = parent 13 | self.children = [] 14 | 15 | def get_parent(self): 16 | return self.parent 17 | 18 | def get_item(self): 19 | return self.item 20 | 21 | def get_children(self): 22 | return self.children 23 | 24 | def add_child(self, child): 25 | self.children.append(child) 26 | 27 | 28 | class PositionalEncoding(nn.Module): 29 | def __init__(self, d_model, dropout=0.1, max_len=5000): 30 | super(PositionalEncoding, self).__init__() 31 | self.dropout = nn.Dropout(p=dropout) 32 | pe = torch.zeros(max_len, d_model) 33 | position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) 34 | div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)) 35 | pe[:, 0::2] = torch.sin(position * div_term) 36 | pe[:, 1::2] = torch.cos(position * div_term) 37 | pe = pe.unsqueeze(0).transpose(0, 1) 38 | self.register_buffer('pe', pe) 39 | 40 | def forward(self, x): 41 | x = x + self.pe[:x.size(0), :] 42 | return self.dropout(x) 43 | 44 | 45 | class MultiHeadAttention(nn.Module): 46 | def __init__(self, input_dim, head_num, head_size): 47 | super(MultiHeadAttention, self).__init__() 48 | self.head_num = head_num 49 | 
self.head_size = head_size 50 | 51 | self.input_dim = input_dim 52 | self.out_dim = head_num * head_size 53 | self.WQ = nn.Linear(self.input_dim, self.out_dim) 54 | self.WK = nn.Linear(self.input_dim, self.out_dim) 55 | self.WV = nn.Linear(self.input_dim, self.out_dim) 56 | 57 | def forward(self, seq_input): 58 | batch_size, seq_len, embedding_size = seq_input.size() 59 | Q = self.WQ(seq_input).view(batch_size, seq_len, self.head_num, self.head_size) 60 | Q = Q.permute(0, 2, 1, 3) # bs, hn, sl, hs 61 | K = self.WK(seq_input).reshape(batch_size, seq_len, self.head_num, self.head_size) 62 | K = K.permute(0, 2, 1, 3) 63 | V = self.WV(seq_input).reshape(batch_size, seq_len, self.head_num, self.head_size) 64 | V = V.permute(0, 2, 1, 3) 65 | sim = torch.matmul(Q, K.permute(0, 1, 3, 2)) / self.head_size ** 0.5 # bs, hn, sl, sl 66 | attn = torch.matmul(F.softmax(sim, dim=3), V) # bs, hn, sl, hs 67 | attn = attn.permute(0, 2, 1, 3).reshape(batch_size, seq_len, self.out_dim) 68 | return attn 69 | 70 | 71 | class TreeLSTM(nn.Module): 72 | def __init__(self, input_dim, hidden_dim, hid_dim, middle_result_dim, task_num): 73 | # parameters.condition_op_dim, 128, 256 74 | super(TreeLSTM, self).__init__() 75 | self.hidden_dim = hidden_dim 76 | 77 | self.lstm1 = nn.LSTM(input_dim, hidden_dim, batch_first=True) 78 | 79 | self.batch_norm1 = nn.BatchNorm1d(hid_dim) 80 | # The linear layer that maps from hidden state space to tag space 81 | 82 | self.sample_mlp = nn.Linear(1000, hid_dim) 83 | self.condition_mlp = nn.Linear(hidden_dim, hid_dim) 84 | # self.out_mlp1 = nn.Linear(hidden_dim, middle_result_dim) 85 | # self.hid_mlp1 = nn.Linear(15+108+2*hid_dim, hid_dim) 86 | # self.out_mlp1 = nn.Linear(hid_dim, middle_result_dim) 87 | 88 | self.lstm2 = nn.LSTM(15 + 108 + 2 * hid_dim, hidden_dim, batch_first=True) 89 | 90 | # self.lstm2_binary = nn.LSTM(15+108+2*hid_dim, hidden_dim, batch_first=True) 91 | # self.lstm2_binary = nn.LSTM(15+108+2*hid_dim, hidden_dim, batch_first=True) 92 | 
self.batch_norm2 = nn.BatchNorm1d(hidden_dim) 93 | # The linear layer that maps from hidden state space to tag space 94 | self.hid_mlp2_task1 = nn.Linear(hidden_dim, hid_dim) 95 | self.hid_mlp2_task2 = nn.Linear(hidden_dim, hid_dim) 96 | self.batch_norm3 = nn.BatchNorm1d(hid_dim) 97 | self.hid_mlp3_task1 = nn.Linear(hid_dim, hid_dim) 98 | self.hid_mlp3_task2 = nn.Linear(hid_dim, hid_dim) 99 | self.out_mlp2_task1 = nn.Linear(hid_dim, 1) 100 | self.out_mlp2_task2 = nn.Linear(hid_dim, 1) 101 | 102 | # self.hidden2values2 = nn.Linear(hidden_dim, action_num) 103 | 104 | def init_hidden(self, hidden_dim, batch_size=1): 105 | # Before we've done anything, we dont have any hidden state. 106 | # Refer to the Pytorch documentation to see exactly 107 | # why they have this dimensionality. 108 | # The axes semantics are (num_layers, minibatch_size, hidden_dim) 109 | return (torch.zeros(1, batch_size, hidden_dim).cuda(), 110 | torch.zeros(1, batch_size, hidden_dim).cuda()) 111 | 112 | def predict(self, hidden_vec): 113 | out = self.batch_norm2(hidden_vec) 114 | 115 | out_task1 = F.relu(self.hid_mlp2_task1(out)) 116 | out_task1 = self.batch_norm3(out_task1) 117 | out_task1 = F.relu(self.hid_mlp3_task1(out_task1)) 118 | out_task1 = self.out_mlp2_task1(out_task1) 119 | out_task1 = torch.sigmoid(out_task1) 120 | 121 | out_task2 = F.relu(self.hid_mlp2_task2(out)) 122 | out_task2 = self.batch_norm3(out_task2) 123 | out_task2 = F.relu(self.hid_mlp3_task2(out_task2)) 124 | out_task2 = self.out_mlp2_task2(out_task2) 125 | out_task2 = torch.sigmoid(out_task2) 126 | return out_task1, out_task2 127 | 128 | def forward(self, operators, extra_infos, condition1s, condition2s, samples, condition_masks, mapping): 129 | num_level = condition1s.size()[0] # get number of layer 130 | num_node_per_level = condition1s.size()[1] # get maximum width 131 | num_condition_per_node = condition1s.size()[2] 132 | condition_op_length = condition1s.size()[3] # get the dimension of condition_op 133 | 134 | 
inputs = condition1s.view(num_level * num_node_per_level, num_condition_per_node, condition_op_length) 135 | hidden = self.init_hidden(self.hidden_dim, num_level * num_node_per_level) 136 | 137 | out, hid = self.lstm1(inputs, hidden) 138 | last_output1 = hid[0].view(num_level * num_node_per_level, -1) # (14*133, 128) 139 | 140 | # condition2 141 | num_level = condition2s.size()[0] 142 | num_node_per_level = condition2s.size()[1] 143 | num_condition_per_node = condition2s.size()[2] 144 | condition_op_length = condition2s.size()[3] 145 | 146 | inputs = condition2s.view(num_level * num_node_per_level, num_condition_per_node, condition_op_length) 147 | hidden = self.init_hidden(self.hidden_dim, num_level * num_node_per_level) 148 | 149 | out, hid = self.lstm1(inputs, hidden) 150 | last_output2 = hid[0].view(num_level * num_node_per_level, -1) # (14*133, 128) 151 | 152 | last_output1 = F.relu(self.condition_mlp(last_output1)) 153 | last_output2 = F.relu(self.condition_mlp(last_output2)) 154 | last_output = (last_output1 + last_output2) / 2 155 | last_output = self.batch_norm1(last_output).view(num_level, num_node_per_level, -1) 156 | 157 | # print (last_output.size()) 158 | # torch.Size([14, 133, 256]) 159 | 160 | sample_output = F.relu(self.sample_mlp(samples)) 161 | sample_output = sample_output * condition_masks 162 | 163 | out = torch.cat((operators, extra_infos, last_output, sample_output), 2) 164 | # print (out.size()) 165 | # torch.Size([14, 133, 635]) 166 | # out = out * node_masks 167 | # start = time.time() 168 | hidden = self.init_hidden(self.hidden_dim, num_node_per_level) 169 | last_level = out[num_level - 1].view(num_node_per_level, 1, -1) 170 | # torch.Size([133, 1, 635]) 171 | _, (hid, cid) = self.lstm2(last_level, hidden) 172 | mapping = mapping.long() 173 | cost_pred = torch.zeros(size=(num_level, num_node_per_level, 1)).cuda() 174 | card_pred = torch.zeros(size=(num_level, num_node_per_level, 1)).cuda() 175 | for idx in reversed(range(0, num_level - 
1)): 176 | cur_cost, cur_card = self.predict(hid[0]) 177 | cost_pred[idx+1] = cur_cost 178 | card_pred[idx+1] = cur_card 179 | mapp_left = mapping[idx][:, 0] 180 | mapp_right = mapping[idx][:, 1] 181 | pad = torch.zeros_like(hid)[:, 0].unsqueeze(1).cuda() 182 | next_hid = torch.cat((pad, hid), 1) 183 | pad = torch.zeros_like(cid)[:, 0].unsqueeze(1).cuda() 184 | next_cid = torch.cat((pad, cid), 1) 185 | hid_left = torch.index_select(next_hid, 1, mapp_left) # (133, 1, 128) 186 | cid_left = torch.index_select(next_cid, 1, mapp_left) 187 | hid_right = torch.index_select(next_hid, 1, mapp_right) # (133, 1, 128) 188 | cid_right = torch.index_select(next_cid, 1, mapp_right) 189 | hid = (hid_left + hid_right) / 2 190 | cid = (cid_left + cid_right) / 2 191 | last_level = out[idx].view(num_node_per_level, 1, -1) 192 | _, (hid, cid) = self.lstm2(last_level, (hid, cid)) 193 | cur_cost, cur_card = self.predict(hid[0]) 194 | cost_pred[0] = cur_cost 195 | card_pred[0] = cur_card 196 | return cost_pred, card_pred 197 | 198 | 199 | class UnifiedModel(nn.Module): 200 | # train: beam search with constraint + sequential loss + cost + cardinality 201 | def __init__(self, model_parameters: ModelParameters): 202 | # parameters.condition_op_dim, 128, 256 203 | super(UnifiedModel, self).__init__() 204 | p = model_parameters 205 | self.pos_flag = p.pos_flag 206 | self.attn_flag = p.attn_flag 207 | self.hidden_size = p.trans_input_size 208 | self.trans_dim_feedforward = p.trans_dim_feedforward 209 | self.max_leaves_num = p.max_leaves_num 210 | self.beam_width = p.beam_width 211 | 212 | self.trans_input1 = nn.Linear(p.input_size, p.trans_input_size) 213 | self.trans_encoder1 = nn.TransformerEncoderLayer(d_model=p.trans_input_size, 214 | nhead=p.head_num, 215 | dim_feedforward=self.trans_dim_feedforward) 216 | self.trans_encoder1s = nn.TransformerEncoder(self.trans_encoder1, p.coder_layer_num) 217 | 218 | self.batch_norm1 = nn.BatchNorm1d(p.cond_hidden_size) 219 | 220 | self.start_vec = 
nn.Parameter(torch.FloatTensor(size=(1, self.hidden_size))) 221 | self.start_vec = nn.init.kaiming_normal_(self.start_vec, a=0, mode='fan_in') 222 | self.sample_mlp = nn.Linear(1000, p.cond_hidden_size) 223 | self.condition_mlp = nn.Linear(p.trans_input_size, p.cond_hidden_size) 224 | 225 | self.trans_input = nn.Linear(p.cond_hidden_size + p.cond_hidden_size + 123, p.trans_input_size) 226 | self.plan_tree_pos_embedding = nn.Linear(p.plan_pos_size, p.trans_input_size) 227 | 228 | self.plan_trad_pos_encoding = PositionalEncoding(d_model=p.trans_input_size) 229 | self.trans_encoder2 = nn.TransformerEncoderLayer(d_model=p.trans_input_size, 230 | nhead=p.head_num, 231 | dim_feedforward=self.trans_dim_feedforward) 232 | self.trans_encoder2s = nn.TransformerEncoder(self.trans_encoder2, p.coder_layer_num) 233 | 234 | self.batch_norm2 = nn.BatchNorm1d(p.trans_input_size) 235 | 236 | # attention 237 | self.plan_W0 = nn.Linear(p.trans_input_size, 1) 238 | self.plan_W1 = nn.Linear(p.trans_input_size, p.trans_input_size) 239 | self.plan_W2 = nn.Linear(p.trans_input_size, p.trans_input_size) 240 | self.hid_mlp2_task1 = nn.Linear(p.trans_input_size, p.cond_hidden_size) 241 | self.hid_mlp2_task2 = nn.Linear(p.trans_input_size, p.cond_hidden_size) 242 | self.batch_norm3 = nn.BatchNorm1d(p.cond_hidden_size) 243 | self.hid_mlp3_task1 = nn.Linear(p.cond_hidden_size, p.cond_hidden_size) 244 | self.hid_mlp3_task2 = nn.Linear(p.cond_hidden_size, p.cond_hidden_size) 245 | self.out_mlp2_task1 = nn.Linear(p.cond_hidden_size, 1) 246 | self.out_mlp2_task2 = nn.Linear(p.cond_hidden_size, 1) 247 | 248 | self.trans_decoder = nn.TransformerDecoderLayer(d_model=self.hidden_size, nhead=p.head_num, 249 | dim_feedforward=self.trans_dim_feedforward) 250 | self.trans_decoders = nn.TransformerDecoder(self.trans_decoder, p.coder_layer_num) 251 | self.output = nn.Linear(p.trans_input_size, self.max_leaves_num) 252 | self.embedding = nn.Embedding(self.max_leaves_num, self.hidden_size) 253 | 254 | 255 | def 
encode(self, operators, extra_infos, condition1s, condition2s, samples, condition_masks, plan_pos_encoding, 256 | leaf_node_marker): 257 | batch_size, nodes_num, num_condition_per_node, condition_op_length = condition1s.size() 258 | batch_num = batch_size * nodes_num 259 | 260 | inputs = condition1s.view(batch_num, num_condition_per_node, condition_op_length) 261 | inputs = self.trans_input1(inputs) 262 | 263 | hid = self.trans_encoder1s(inputs.permute(1, 0, 2)) # (10, 64*25, 128) 264 | last_output1 = torch.mean(hid, dim=0).view(batch_num, -1) # (64*25, 128) 265 | 266 | inputs = condition2s.view(batch_num, num_condition_per_node, condition_op_length) 267 | inputs = self.trans_input1(inputs) 268 | hid = self.trans_encoder1s(inputs.permute(1, 0, 2)) 269 | last_output2 = torch.mean(hid, dim=0).view(batch_num, -1) # (64*25, 128) 270 | 271 | last_output1 = F.relu(self.condition_mlp(last_output1)) 272 | last_output2 = F.relu(self.condition_mlp(last_output2)) 273 | last_output = (last_output1 + last_output2) / 2 # ? 
274 | last_output = self.batch_norm1(last_output).view(batch_size, nodes_num, -1) 275 | 276 | # print (last_output.size()) 277 | # torch.Size([64, 25, 256]) 278 | 279 | sample_output = F.relu(self.sample_mlp(samples)) 280 | sample_output = sample_output * condition_masks 281 | 282 | out = torch.cat((operators, extra_infos, last_output, sample_output), 2) 283 | # print (out.size()) 284 | # torch.Size([64, 25, 635]) 285 | # out = out * node_masks 286 | out1 = self.trans_input(out) # (64, 25, 128) 287 | 288 | out_jo = torch.cat((torch.zeros(size=(batch_size, nodes_num, 15)).cuda(), out[:, :, 15:]), dim=2) 289 | out2 = self.trans_input(out_jo) # (64, 25, 128) 290 | 291 | none_padding = torch.zeros(size=(batch_size, 1, self.hidden_size)).cuda() 292 | plan_rep_cat = torch.cat((none_padding, out2), dim=1).view(batch_size * (nodes_num + 1), -1) # (64, 26, 128) 293 | leaf_node_rep = torch.index_select(plan_rep_cat, 0, leaf_node_marker).view(batch_size, self.max_leaves_num, -1) 294 | 295 | if self.pos_flag == 1: 296 | plan_pos_encoding = self.plan_tree_pos_embedding(plan_pos_encoding) # (64, 25, 128) 297 | out1 = out1 + plan_pos_encoding 298 | elif self.pos_flag == 2: 299 | out1 = self.plan_trad_pos_encoding(out1) 300 | 301 | cost_card_rep = self.trans_encoder2(out1.permute(1, 0, 2)).permute(1, 0, 2) 302 | join_order_rep = self.trans_encoder2(leaf_node_rep.permute(1, 0, 2)).permute(1, 0, 2) # (64, 7, 128) 303 | return cost_card_rep, join_order_rep 304 | 305 | def decode(self, join_order_rep, cost_card_rep, test, teacher_forcing, trans_target, res_mask, adj_matrix): 306 | batch_size, nodes_num, _ = cost_card_rep.shape 307 | trans_target = trans_target.permute(1, 0, 2) # (7, 64, 128) 308 | ####################### join order ############################ 309 | if not test: 310 | jo_output = torch.ones(size=(batch_size, 1, self.max_leaves_num)).cuda() 311 | tgt_pre = self.start_vec.expand(batch_size, self.hidden_size).unsqueeze(0) 312 | trans_target = torch.cat((tgt_pre, 
trans_target), dim=0) # (8, 64, 128) 拼了初始向量,一点问题都没有 313 | for seq_id in range(join_order_rep.size()[1]): 314 | if torch.rand(1).cuda()[0] < teacher_forcing.tf: 315 | tgt_pre = self.trans_decoders(tgt=trans_target[seq_id].unsqueeze(0), 316 | memory=join_order_rep.permute(1, 0, 2)) 317 | else: 318 | tgt_pre = self.trans_decoders(tgt=tgt_pre, memory=join_order_rep.permute(1, 0, 2)) 319 | res = F.relu(self.output(tgt_pre.permute(1, 0, 2))) # (64, 1, 7) 320 | max_index = torch.argmax(res, dim=2) # (64, 1) 321 | one_hot = torch.zeros(batch_size, self.max_leaves_num).cuda().scatter_(1, max_index, 1) # (64, 7) 322 | tgt_pre = torch.cat((one_hot, trans_target[seq_id + 1, :, 7:]), dim=1).unsqueeze(0) # (1, 64, 128) 323 | # tgt_pre = torch.cat((one_hot, torch.zeros(batch_size, 121).cuda()), dim=1).unsqueeze(0) 324 | jo_output = torch.cat((jo_output, res), dim=1) 325 | join_order_output = jo_output[:, 1:, :] # (64, 7, 7) 326 | else: 327 | tgt_pre = self.start_vec.expand(batch_size, self.hidden_size).unsqueeze(0) 328 | jo_output = torch.ones(size=(batch_size, 1, self.max_leaves_num)).cuda() 329 | for seq_id in range(join_order_rep.size()[1]): 330 | tgt_pre = self.trans_decoders(tgt=tgt_pre, memory=join_order_rep.permute(1, 0, 2)) 331 | res = F.relu(self.output(tgt_pre.permute(1, 0, 2))) 332 | res += res_mask.unsqueeze(1) 333 | max_index = torch.argmax(res, dim=2) # (64, 1) 334 | for bi in range(batch_size): 335 | res_mask[bi, max_index[bi]] = -float("inf") 336 | one_hot = torch.zeros(batch_size, self.max_leaves_num).cuda().scatter_(1, max_index, 1) # (64, 7) 337 | tgt_pre = torch.cat((one_hot, trans_target[seq_id, :, 7:]), dim=1).unsqueeze(0) # (1, 64, 128) 338 | # tgt_pre = torch.cat((one_hot, torch.zeros(batch_size, 121).cuda()), dim=1).unsqueeze(0) 339 | jo_output = torch.cat((jo_output, res), dim=1) 340 | join_order_output = torch.argmax(jo_output[:, 1:, :], dim=2) # (64, 7, 7) 341 | 342 | # # check_legality 343 | # for batch_idx in range(batch_size): 344 | # if not 
self.check_legality(join_order_output[batch_idx], adj_matrix[batch_idx]): 345 | 346 | #################### cardinality estimation cost estiamtion ##################### 347 | plan_rep = cost_card_rep.reshape(batch_size * nodes_num, -1) 348 | plan_rep = self.batch_norm2(plan_rep) 349 | 350 | out_task1 = F.relu(self.hid_mlp2_task1(plan_rep)) 351 | out_task1 = self.batch_norm3(out_task1) 352 | out_task1 = F.relu(self.hid_mlp3_task1(out_task1)) 353 | out_task1 = self.out_mlp2_task1(out_task1) 354 | cost_output = torch.sigmoid(out_task1) 355 | cost_output = cost_output.reshape(batch_size, nodes_num, -1) 356 | 357 | out_task2 = F.relu(self.hid_mlp2_task2(plan_rep)) 358 | out_task2 = self.batch_norm3(out_task2) 359 | out_task2 = F.relu(self.hid_mlp3_task2(out_task2)) 360 | out_task2 = self.out_mlp2_task2(out_task2) 361 | card_output = torch.sigmoid(out_task2) 362 | card_output = card_output.reshape(batch_size, nodes_num, -1) 363 | 364 | return cost_output, card_output, join_order_output 365 | 366 | def forward(self, operators, extra_infos, condition1s, condition2s, samples, condition_masks, plan_pos_encoding, 367 | leaf_node_marker, test, teacher_forcing, trans_target, res_mask, adj_matrix): 368 | cost_card_rep, join_order_rep = self.encode(operators, extra_infos, condition1s, condition2s, samples, 369 | condition_masks, plan_pos_encoding, leaf_node_marker) 370 | cost_output, card_output, join_order_output = self.decode(join_order_rep, cost_card_rep, test, teacher_forcing, 371 | trans_target, res_mask, adj_matrix) 372 | batch_size = cost_output.shape[0] 373 | res_mask = res_mask.unsqueeze(2).expand(batch_size, self.max_leaves_num, self.max_leaves_num) 374 | if not test: 375 | join_order_output += res_mask 376 | return cost_output, card_output, join_order_output 377 | 378 | def step(self, tgt_input, join_order_rep, res_mask): 379 | tgt_output = self.trans_decoders(tgt=tgt_input, memory=join_order_rep.permute(1, 0, 2)) 380 | tgt_output = 
F.leaky_relu(self.output(tgt_output.permute(1, 0, 2))) 381 | tgt_output = tgt_output.squeeze(1) 382 | tgt_output += res_mask 383 | return tgt_output 384 | 385 | def generate_mask(self, res_mask, top_k): 386 | # (64, 7) 387 | cur_num, k = top_k.shape 388 | res_mask = res_mask.unsqueeze(1).repeat(1, k, 1) 389 | for i in range(cur_num): 390 | for j in range(k): 391 | res_mask[i, j, top_k[i, j]] = -float("inf") 392 | return res_mask.reshape(-1, self.max_leaves_num) 393 | 394 | def beam_search_test(self, operators, extra_infos, condition1s, condition2s, samples, condition_masks, 395 | plan_pos_encoding, leaf_node_marker, trans_target, res_mask_truth, adj_matrix, ground_truth): 396 | # adj_matrix (64, 7, 7) 397 | res_mask = deepcopy(res_mask_truth) 398 | batch_size = res_mask.shape[0] 399 | _, join_order_rep = self.encode(operators, extra_infos, condition1s, condition2s, samples, condition_masks, 400 | plan_pos_encoding, leaf_node_marker) 401 | 402 | first_input = self.start_vec 403 | # join_order_output = torch.ones((1, self.max_leaves_num, self.max_leaves_num)).cuda() 404 | results = [] 405 | joeus = 0 406 | for batch_id in range(batch_size): 407 | res_index = self.generate_results_test(first_input, join_order_rep[batch_id], trans_target[batch_id], 408 | res_mask[batch_id], adj_matrix[batch_id]) 409 | res_index = res_index.long() 410 | # print(res_index) 411 | # if not self.check_legality(res_index, adj_matrix[batch_id]): 412 | # print("???????????????") 413 | table_num = res_index.shape[0] 414 | joeu = self.calculate_joeu(res_index.unsqueeze(0), ground_truth[batch_id]) 415 | joeus += joeu 416 | res_index = torch.cat((res_index, -1 * torch.ones(self.max_leaves_num - table_num).cuda())).unsqueeze(0) 417 | # print(res_index) 418 | results.append(res_index) 419 | results = torch.cat(results) 420 | return results, joeus / batch_size 421 | 422 | def test_mask(self, seen, valid_tables): 423 | # (seen, ) (1, ) (7, 7) 424 | res_mask = torch.zeros(self.max_leaves_num).cuda() 
425 | for idx in range(seen.shape[0]): 426 | res_mask[seen[idx]] = -float("inf") 427 | for idx in range(self.max_leaves_num): 428 | if res_mask[idx] == -float("inf"): 429 | continue 430 | if valid_tables[idx] == 0: 431 | res_mask[idx] = -float("inf") 432 | return res_mask 433 | 434 | def generate_results_test(self, input_vec, join_order_rep, trans_target, res_mask, adj_matrix): 435 | # (1, 128) (7, 128) (7, 128) 436 | res_mask = res_mask.unsqueeze(0) # (1, 7) 437 | 438 | cur_output = self.step(input_vec.unsqueeze(0), join_order_rep.unsqueeze(1), res_mask) # (7, ) 439 | final_output = torch.ones(1, 1).cuda() 440 | cur_prob = torch.ones(1).cuda() 441 | table_num = res_mask[0].shape[0] - torch.isinf(res_mask[0]).sum() 442 | record_table_num = table_num 443 | valid_tables = torch.zeros(1, 7).bool().cuda() 444 | # print(table_num, res_mask) 445 | for i in range(table_num): 446 | # time step 447 | tmp_dis = cur_output.view(-1, self.max_leaves_num) # (1, 7) 448 | tmp_dis = F.softmax(tmp_dis, dim=1) 449 | cur_num = tmp_dis.shape[0] 450 | # print(top_k_value, top_k) 451 | next_final_output = [] 452 | next_cur_output = [] 453 | next_prob = [] 454 | next_res_mask = [] 455 | next_valid_tables = [] 456 | # print(f"time step {i}") 457 | for j in range(cur_num): 458 | # the possible number of last time step 459 | table_rest_num = res_mask[j].shape[0] - torch.isinf(res_mask[j]).sum() 460 | can_expand_num = min(self.beam_width, table_rest_num) 461 | # print(can_expand_num) 462 | if can_expand_num == 0: 463 | continue 464 | top_k_value, top_k = tmp_dis[j].topk(can_expand_num) # (1, 3) 465 | tmp_final_output = [] 466 | for k in range(can_expand_num): 467 | # current expand number 468 | next_prob.append(cur_prob[j]*top_k_value[k]) 469 | tmp_valid_tables = valid_tables[j] | adj_matrix[top_k[k]] 470 | seen = torch.cat((final_output[j], torch.Tensor([top_k[k]]).cuda())) 471 | tmp_final_output.append(seen.unsqueeze(0)) 472 | 473 | tmp_res_mask = self.test_mask(seen[1:].long(), 
tmp_valid_tables) 474 | # print(tmp_res_mask) 475 | next_res_mask.append(tmp_res_mask.unsqueeze(0)) 476 | next_valid_tables.append(tmp_valid_tables.unsqueeze(0)) 477 | 478 | one_hot = F.one_hot(top_k[k], num_classes=self.max_leaves_num).float() # (1, 7) 479 | cur_tgt = torch.cat((one_hot, trans_target[i, self.max_leaves_num:])).unsqueeze(0) # (1, 128) 480 | tmp_cur_output = self.step(cur_tgt.unsqueeze(0), join_order_rep.unsqueeze(0), tmp_res_mask) # (1, 7) 481 | next_cur_output.append(tmp_cur_output) 482 | next_final_output.append(torch.cat(tmp_final_output)) 483 | 484 | res_mask = torch.cat(next_res_mask) 485 | # print(res_mask.shape) 486 | valid_tables = torch.cat(next_valid_tables) 487 | # print(valid_tables.shape) 488 | final_output = torch.cat(next_final_output) 489 | # print(final_output.shape) 490 | cur_prob = torch.Tensor(next_prob) 491 | # print(cur_prob.shape) 492 | cur_output = torch.cat(next_cur_output) 493 | # print(cur_output.shape) 494 | 495 | final_output = final_output[:, 1:] 496 | assert final_output.shape[1] == record_table_num 497 | final_index = torch.argmax(cur_prob) # (prob_num) 498 | 499 | return final_output[final_index] 500 | 501 | def generate_results(self, input_vec, join_order_rep, trans_target, res_mask, adj_matrix, ground_truth): 502 | # (1, 128) (7, 128) (7, 128) 503 | # beam search (train or test) 504 | final_prob = torch.ones(1).cuda() 505 | final_output = torch.ones(size=(1, 1, self.max_leaves_num)).cuda() 506 | final_index = torch.ones((1, 1)).cuda() 507 | res_mask = res_mask.unsqueeze(0) # (1, 7) 508 | 509 | cur_output = self.step(input_vec.unsqueeze(0), join_order_rep.unsqueeze(1), res_mask) # (7, ) 510 | final_output = torch.cat((final_output, cur_output.view(1, 1, self.max_leaves_num)), dim=1) 511 | table_num = res_mask[0].shape[0] - torch.isinf(res_mask[0]).sum() 512 | 513 | for i in range(table_num): 514 | # print(res_mask) 515 | table_rest_num = res_mask[0].shape[0] - torch.isinf(res_mask[0]).sum() 516 | 517 | tmp_dis = 
cur_output.view(-1, self.max_leaves_num) # (1, 7) 518 | tmp_dis = F.softmax(tmp_dis, dim=1) 519 | cur_num = tmp_dis.shape[0] 520 | 521 | can_expand_num = min(self.beam_width, table_rest_num) 522 | top_k_value, top_k = tmp_dis.topk(can_expand_num, dim=1) # (1, 3) 523 | # print(top_k_value, top_k) 524 | 525 | final_prob = final_prob.unsqueeze(1).repeat(1, can_expand_num).reshape(-1) 526 | # print(final_prob, top_k_value.reshape(-1)) 527 | final_index = final_index.unsqueeze(1).repeat(1, can_expand_num, 1).reshape(cur_num * can_expand_num, -1) 528 | final_index = torch.cat((final_index, top_k.view(cur_num * can_expand_num, 1)), dim=1) 529 | final_prob = final_prob * top_k_value.reshape(-1) 530 | 531 | res_mask = self.generate_mask(res_mask, top_k) # (3, 7) 532 | one_hot = F.one_hot(top_k, num_classes=self.max_leaves_num).float() # (1, en, 7) 533 | cur_tgt = trans_target[i, :].unsqueeze(1).repeat(1, cur_num * can_expand_num, 1) 534 | cur_tgt = cur_tgt.reshape(cur_num, can_expand_num, self.hidden_size) 535 | cur_tgt = torch.cat((one_hot, cur_tgt[:, :, self.max_leaves_num:]), dim=2) # (1, en, 128) 536 | cur_tgt = cur_tgt.view(-1, self.hidden_size).unsqueeze(0) 537 | 538 | join_order_rep_tmp = join_order_rep.unsqueeze(1).repeat(1, cur_num * can_expand_num, 1) # (7, ., 128) 539 | cur_output = self.step(cur_tgt, join_order_rep_tmp, res_mask) # (3, 7) 540 | 541 | final_output = final_output.unsqueeze(1).repeat(1, can_expand_num, 1, 1).reshape(cur_num * can_expand_num, 542 | -1, self.max_leaves_num) 543 | cur_output = cur_output.view(-1, 1, self.max_leaves_num) 544 | final_output = torch.cat((final_output, cur_output), dim=1) 545 | final_output = final_output[:, 1:, :] # (prob_num, table_num, 7) 546 | # print(final_output.shape, final_prob.shape) 547 | # print(f"final_prob {final_prob}") 548 | 549 | sort_prob, sort_prob_index = torch.sort(final_prob, descending=True) # (prob_num) 550 | all_prob_index = final_index[:, 1:] # (prob_num, table_num) 551 | illegal_index = 
torch.zeros((1, table_num)).cuda() 552 | illegal_prob = torch.zeros(1).cuda() 553 | 554 | legal_index = torch.zeros((1, table_num)).cuda() 555 | legal_prob = torch.zeros(1).cuda() 556 | 557 | for prob_idx in range(sort_prob_index.shape[0]): 558 | cur_index = all_prob_index[sort_prob_index[prob_idx]].long() 559 | cur_prob = sort_prob[prob_idx] 560 | if self.check_legality(cur_index, adj_matrix): 561 | legal_index = torch.cat((legal_index, cur_index.unsqueeze(0)), dim=0) 562 | legal_prob = torch.cat((legal_prob, cur_prob.unsqueeze(0)), dim=0) 563 | else: 564 | illegal_index = torch.cat((illegal_index, cur_index.unsqueeze(0)), dim=0) 565 | illegal_prob = torch.cat((illegal_prob, cur_prob.unsqueeze(0)), dim=0) 566 | 567 | illegal_index = illegal_index[1:] 568 | illegal_prob = illegal_prob[1:] 569 | legal_index = legal_index[1:] 570 | legal_prob = legal_prob[1:] 571 | 572 | if min(legal_prob.shape) == 0: 573 | optimal_prob = 0 574 | optimal_index = torch.zeros(1).long().cuda() 575 | else: 576 | joeus = self.calculate_joeu(legal_index, ground_truth) 577 | optimal_index = torch.argmax(joeus) 578 | optimal_prob = legal_prob[optimal_index] 579 | 580 | legal_index = self.del_tensor_ele(legal_index, optimal_index.item()) 581 | legal_prob = self.del_tensor_ele(legal_prob, optimal_index.item()) 582 | return illegal_index, illegal_prob, legal_index, legal_prob, optimal_index, optimal_prob 583 | 584 | def del_tensor_ele(self, tensor, index): 585 | return torch.cat((tensor[:index], tensor[index+1:]), dim=0) 586 | 587 | def calculate_sequential_loss(self, operators, extra_infos, condition1s, condition2s, samples, condition_masks, 588 | plan_pos_encoding, leaf_node_marker, trans_target, res_mask_truth, adj_matrix, ground_truth): 589 | # adj_matrix (64, 7, 7) 590 | # beam search train seq loss 591 | # print("Calculating sequential loss...") 592 | res_mask = deepcopy(res_mask_truth) 593 | batch_size = res_mask.shape[0] 594 | _, join_order_rep = self.encode(operators, extra_infos, 
condition1s, condition2s, samples, condition_masks, 595 | plan_pos_encoding, leaf_node_marker) 596 | 597 | first_input = self.start_vec 598 | seq_loss = 0 599 | for batch_id in range(batch_size): 600 | illegal_index, illegal_prob, legal_index, legal_prob, optimal_index, optimal_prob \ 601 | = self.generate_results(first_input, join_order_rep[batch_id], trans_target[batch_id], 602 | res_mask[batch_id], adj_matrix[batch_id], ground_truth[batch_id]) 603 | table_num = illegal_index.shape[0] 604 | if min(illegal_prob.shape) == 0: # it is reasonable that loss is 0 if there are no illeagal outputs 605 | illegal_loss = 0 606 | else: 607 | illegal_loss = torch.log(torch.sum(torch.exp(1/table_num * torch.log(illegal_prob)))) 608 | if optimal_prob == 0: 609 | optimal_loss = 0 610 | else: 611 | optimal_loss = -torch.log(optimal_prob) 612 | if min(legal_prob.shape) == 0: 613 | legal_risk = 0 614 | else: 615 | joeus = self.calculate_joeu(legal_index, ground_truth[batch_id]) 616 | legal_risk = self.calculate_risk(legal_prob, joeus) 617 | optimal_loss = -torch.log(optimal_prob) 618 | seq_loss += (optimal_loss + illegal_loss + legal_risk) 619 | return seq_loss/batch_size 620 | 621 | def calculate_joeu(self, legal_index, gt): 622 | # (prob, table_num) (prob) (7) 623 | if min(legal_index.shape) == 0: 624 | return 0 625 | prob_num, table_num = legal_index.shape 626 | joeus = torch.ones(1).cuda() 627 | for prob_id in range(prob_num): 628 | cnt = 0 629 | cur_index = legal_index[prob_id] 630 | if cur_index[0] == gt[0]: 631 | for t in range(1, table_num): 632 | if cur_index[t] == gt[t]: 633 | cnt += 1 634 | else: 635 | break 636 | elif cur_index[0] == gt[1] and cur_index[1] == gt[0]: 637 | cnt += 1 638 | for t in range(2, table_num): 639 | if cur_index[t] == gt[t]: 640 | cnt += 1 641 | else: 642 | break 643 | joeus = torch.cat((joeus, torch.Tensor([cnt/(table_num-1)]).cuda()), dim=0) 644 | return joeus[1:] 645 | 646 | def calculate_risk(self, legal_prob, joeus): 647 | return torch.sum((1 
- joeus) * legal_prob) 648 | 649 | def check_legality(self, join_order_index, adj_matrix): 650 | # (table_num) (7, 7) 651 | cur_tables_num = join_order_index.shape[0] 652 | # check 653 | seen = deepcopy(adj_matrix[join_order_index[0]]).bool() 654 | for idx in range(1, cur_tables_num): 655 | if seen[join_order_index[idx]] == 0: 656 | return False 657 | seen = seen | adj_matrix[join_order_index[idx]] 658 | return True 659 | 660 | 661 | class MetaLearner(nn.Module): 662 | # transferability across databases 663 | # meta feature 664 | def __init__(self, trans_input_size, head_num, trans_dim_feedforward, coder_layer_num, max_leaves_num, drop_out): 665 | # parameters.condition_op_dim, 128, 256 666 | super(MetaLearner, self).__init__() 667 | self.max_leaves_num = max_leaves_num 668 | self.head_num = head_num 669 | self.hidden_size = trans_input_size 670 | self.trans_encoder1 = nn.TransformerEncoderLayer(d_model=trans_input_size, 671 | nhead=head_num, 672 | dim_feedforward=trans_dim_feedforward, 673 | dropout=drop_out) 674 | self.trans_encoder1s = nn.TransformerEncoder(self.trans_encoder1, coder_layer_num) 675 | 676 | self.start_vec = nn.Parameter(torch.FloatTensor(size=(1, self.hidden_size))) 677 | self.start_vec = nn.init.kaiming_normal_(self.start_vec, a=0, mode='fan_in') 678 | 679 | self.trans_encoder2 = nn.TransformerEncoderLayer(d_model=trans_input_size, 680 | nhead=head_num, 681 | dim_feedforward=trans_dim_feedforward) 682 | self.trans_encoder2s = nn.TransformerEncoder(self.trans_encoder2, coder_layer_num) 683 | 684 | self.trans_decoder = nn.TransformerDecoderLayer(d_model=self.hidden_size, nhead=head_num, 685 | dim_feedforward=trans_dim_feedforward, dropout=drop_out) 686 | self.trans_decoders = nn.TransformerDecoder(self.trans_decoder, coder_layer_num) 687 | self.output = nn.Linear(self.hidden_size, self.max_leaves_num) 688 | 689 | def encode(self, feature_encoding, encoder_mask, agg_matrix): 690 | # (64, 30, 128) (64, 30, 30) (64, 10, 30) 691 | # 
print(feature_encoding.shape, encoder_mask.shape, agg_matrix.shape) 692 | bsz, tgt_len, tgt_len = encoder_mask.shape 693 | encoder_mask = encoder_mask.unsqueeze(1).repeat(1, self.head_num, 1, 1).view(-1, tgt_len, tgt_len) 694 | assert self.head_num*bsz == encoder_mask.shape[0] 695 | join_order_rep = self.trans_encoder1s(feature_encoding.permute(1, 0, 2), mask=encoder_mask) # (20, 64, 128) 696 | join_order_rep = torch.matmul(agg_matrix, join_order_rep.permute(1, 0, 2)) 697 | return join_order_rep 698 | 699 | def decode(self, join_order_rep, test, teacher_forcing, trans_target, res_mask, adj_matrix): 700 | batch_size, nodes_num, _ = join_order_rep.shape 701 | # join_order_rep = torch.zeros_like(join_order_rep) 702 | trans_target = trans_target.squeeze() 703 | trans_target = trans_target.permute(1, 0, 2) # (10, 64, 128) 704 | org_res_mask = deepcopy(res_mask) 705 | ####################### join order ############################ 706 | if not test: 707 | jo_output = torch.ones(size=(batch_size, 1, self.max_leaves_num)).cuda() 708 | tgt_pre = self.start_vec.expand(batch_size, self.hidden_size).unsqueeze(0) 709 | trans_target = torch.cat((tgt_pre, trans_target), dim=0) # (8, 64, 128) 710 | for seq_id in range(join_order_rep.size()[1]): 711 | if torch.rand(1).cuda()[0] < teacher_forcing.tf: 712 | tgt_pre = self.trans_decoders(tgt=trans_target[seq_id].unsqueeze(0), 713 | memory=join_order_rep.permute(1, 0, 2)) 714 | else: 715 | tgt_pre = self.trans_decoders(tgt=tgt_pre, memory=join_order_rep.permute(1, 0, 2)) 716 | res = torch.sigmoid(self.output(tgt_pre.permute(1, 0, 2))) # (64, 1, 7) 717 | max_index = torch.argmax(res, dim=2) # (64, 1) 718 | one_hot = torch.zeros(batch_size, self.max_leaves_num).cuda().scatter_(1, max_index, 1) # (64, 7) 719 | tgt_pre = torch.cat((one_hot, trans_target[seq_id + 1, :, self.max_leaves_num:]), dim=1).unsqueeze(0) # (1, 64, 128) 720 | jo_output = torch.cat((jo_output, res), dim=1) 721 | join_order_output = jo_output[:, 1:, :] # (64, 7, 7) 722 
| else: 723 | seen = torch.zeros((batch_size, 1)).cuda() 724 | valid_tables = torch.zeros((batch_size, self.max_leaves_num)).bool().cuda() 725 | tgt_pre = self.start_vec.expand(batch_size, self.hidden_size).unsqueeze(0) 726 | jo_output = torch.ones(size=(batch_size, 1, self.max_leaves_num)).cuda() 727 | for seq_id in range(join_order_rep.size()[1]): 728 | tgt_pre = self.trans_decoders(tgt=tgt_pre, memory=join_order_rep.permute(1, 0, 2)) 729 | res = torch.sigmoid(self.output(tgt_pre.permute(1, 0, 2))) 730 | res += res_mask.unsqueeze(1) 731 | max_index = torch.argmax(res, dim=2) # (64, 1) 732 | seen = torch.cat((seen, max_index), dim=1) 733 | 734 | for bi in range(batch_size): 735 | valid_tables[bi] = valid_tables[bi] | adj_matrix[bi, max_index[bi]] 736 | res_mask[bi] = org_res_mask[bi] + self.test_mask(seen[bi][1:].long(), valid_tables[bi]) 737 | if seq_id == 0 and res_mask[bi].min() == -float("inf") and res_mask[bi].max() == -float("inf"): 738 | print(max_index[bi], adj_matrix[bi], valid_tables[bi]) 739 | 740 | # tgt_pre = self.embedding(max_index.squeeze()).unsqueeze(0) 741 | one_hot = torch.zeros(batch_size, self.max_leaves_num).cuda().scatter_(1, max_index, 1) # (64, 7) 742 | tgt_pre = torch.cat((one_hot, trans_target[seq_id, :, self.max_leaves_num:]), dim=1).unsqueeze( 743 | 0) # (1, 64, 128) 744 | jo_output = torch.cat((jo_output, res), dim=1) 745 | join_order_output = jo_output[:, 1:, :] 746 | 747 | return join_order_output # (64, 10, 10) 748 | 749 | def forward(self, feature_encoding, agg_matrix, encoder_mask, test, teacher_forcing, trans_target, res_mask, adj_matrix): 750 | join_order_rep = self.encode(feature_encoding, encoder_mask, agg_matrix) 751 | join_order_output = self.decode(join_order_rep, test, teacher_forcing, trans_target, res_mask, adj_matrix) 752 | batch_size = join_order_output.shape[0] 753 | res_mask = res_mask.squeeze().unsqueeze(1).expand(batch_size, self.max_leaves_num, self.max_leaves_num) 754 | if not test: 755 | join_order_output += 
res_mask 756 | return join_order_output 757 | 758 | def test_mask(self, seen, valid_tables): 759 | res_mask = torch.zeros(self.max_leaves_num).cuda() 760 | for idx in range(seen.shape[0]): 761 | res_mask[seen[idx]] = -float("inf") 762 | for idx in range(self.max_leaves_num): 763 | if res_mask[idx] == -float("inf"): 764 | continue 765 | if valid_tables[idx] == 0: 766 | res_mask[idx] = -float("inf") 767 | return res_mask 768 | -------------------------------------------------------------------------------- /run_meta_learner.py: -------------------------------------------------------------------------------- 1 | from representation_model import * 2 | from vector_loader import * 3 | from tools.pytool import * 4 | from tools.utils import * 5 | import torch 6 | import time 7 | import pickle 8 | import tqdm 9 | import os 10 | import argparse 11 | 12 | 13 | def zero_shot_learning(train_start, train_end, validate_start, validate_end, num_epochs, patience, lr, trans_input_size, 14 | head_num, trans_dim_feedforward, coder_layer_num, max_leaves_num, drop_out, train_db, val_db, 15 | save_f, db_path, save_model): 16 | DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu' 17 | decay_rate = 0.9 18 | decay_point = 10 19 | batch_size = 64 20 | seed = 666 21 | model = MetaLearner(trans_input_size, head_num, trans_dim_feedforward, coder_layer_num, max_leaves_num, drop_out).to(DEVICE) 22 | teacher_forcing = TeacherForcing(start_tf=1, decay_rate=decay_rate, decay_point=decay_point, cur_epoch=0, end_epoch=50, 23 | verbose=True) 24 | optimizer = torch.optim.Adam(model.parameters(), lr=lr) 25 | early_stopping = EarlyStopping(patience=patience, verbose=True, path=f'model/{save_model}.pt') 26 | criterion = nn.CrossEntropyLoss() 27 | model.train() 28 | start = time.time() 29 | max_corr_num = 0 30 | 31 | print(f"train_db: {train_db}") 32 | print(f"val_db: {val_db}") 33 | train_batches = 100 34 | val_batches = 20 35 | 36 | for epoch in range(num_epochs): 37 | # db_list, shuffle, seed, suffix, 
batch_num, test, 38 | train_iter = get_batch_meta_learner_iterator(train_db, 1, seed, 3, train_batches, 0, db_path) 39 | val_iter = get_batch_meta_learner_iterator(val_db, 0, seed, 3, val_batches, 0, db_path) 40 | train_loss = [] 41 | print("=====================================") 42 | save_f.write("=====================================\n") 43 | train_start_time = time.time() 44 | for train_data in tqdm.tqdm(train_iter): 45 | test = 0 46 | ground_truth_batch, agg_matrix_batch, attn_mask_batch, trans_target_batch, feature_encoding_batch, \ 47 | res_mask_batch, adj_matrix_batch = train_data 48 | 49 | ground_truth_batch = torch.LongTensor(ground_truth_batch).to(DEVICE) 50 | agg_matrix_batch = torch.Tensor(agg_matrix_batch).to(DEVICE) 51 | trans_target_batch = torch.FloatTensor(trans_target_batch).to(DEVICE) 52 | attn_mask_batch = torch.Tensor(attn_mask_batch).to(DEVICE) 53 | feature_encoding_batch = torch.FloatTensor(feature_encoding_batch).to(DEVICE) 54 | res_mask_batch = torch.FloatTensor(res_mask_batch).to(DEVICE) 55 | adj_matrix_batch = torch.BoolTensor(adj_matrix_batch).to(DEVICE) 56 | 57 | optimizer.zero_grad() 58 | join_order_output = model(feature_encoding_batch, agg_matrix_batch, attn_mask_batch, test, teacher_forcing, 59 | trans_target_batch, res_mask_batch, adj_matrix_batch) 60 | loss = join_order_loss(join_order_output, ground_truth_batch, criterion) 61 | train_loss.append(loss) 62 | loss.backward() 63 | optimizer.step() 64 | 65 | batch_num = train_end - train_start 66 | train_end_time = time.time() 67 | teacher_forcing.check() 68 | print('Training batch time: ', train_end_time - train_start_time) 69 | print("Epoch {}, training all tasks loss: {}".format(epoch, sum(train_loss) / batch_num)) 70 | save_f.write("Epoch {}, training all tasks loss: {}\n".format(epoch, sum(train_loss) / batch_num)) 71 | 72 | cpl_correct_cnt, icpl_correct_cnt, pos_correct_cnt, total_pos_cnt = 0, 0, 0, 0 73 | cpl_correct_cnt_rp, pos_correct_cnt_rp, icpl_correct_cnt_rp = 0, 0, 0 74 
| model.eval() 75 | results = [] 76 | random_res = [] 77 | 78 | for val_data in tqdm.tqdm(val_iter): 79 | test = 1 80 | ground_truth_batch, agg_matrix_batch, attn_mask_batch, trans_target_batch, feature_encoding_batch, \ 81 | res_mask_batch, adj_matrix_batch = val_data 82 | ground_truth_batch = torch.LongTensor(ground_truth_batch).to(DEVICE) 83 | agg_matrix_batch = torch.Tensor(agg_matrix_batch).to(DEVICE) 84 | trans_target_batch = torch.FloatTensor(trans_target_batch).to(DEVICE) 85 | attn_mask_batch = torch.Tensor(attn_mask_batch).to(DEVICE) 86 | feature_encoding_batch = torch.FloatTensor(feature_encoding_batch).to(DEVICE) 87 | res_mask_batch = torch.FloatTensor(res_mask_batch).to(DEVICE) 88 | adj_matrix_batch = torch.BoolTensor(adj_matrix_batch).to(DEVICE) 89 | 90 | join_order_output = model(feature_encoding_batch, agg_matrix_batch, attn_mask_batch, test, teacher_forcing, 91 | trans_target_batch, res_mask_batch, adj_matrix_batch) 92 | results += output_file(join_order_output, ground_truth_batch) 93 | ccc, pcc, tpn, icc = prediction_compare(join_order_output, ground_truth_batch) 94 | ccc_rp, pcc_rp, _, icc_rp, rd_list = random_prediction(ground_truth_batch, adj_matrix_batch) 95 | random_res += rd_list 96 | cpl_correct_cnt += ccc 97 | pos_correct_cnt += pcc 98 | icpl_correct_cnt += icc 99 | total_pos_cnt += tpn 100 | 101 | cpl_correct_cnt_rp += ccc_rp 102 | icpl_correct_cnt_rp += icc_rp 103 | pos_correct_cnt_rp += pcc_rp 104 | 105 | batch_num = validate_end - validate_start 106 | cur_corr = cpl_correct_cnt+icpl_correct_cnt 107 | pickle.dump(results, open("/mnt/train_data/meta_learner/log/prediction_rd.pkl", "wb")) 108 | if cur_corr > max_corr_num: 109 | max_corr_num = cur_corr 110 | pickle.dump(results, open("/mnt/train_data/meta_learner/log/prediction.pkl", "wb")) 111 | early_stopping(-cpl_correct_cnt-icpl_correct_cnt, model) 112 | print(f"Epoch {epoch}: ") 113 | print(f" Complete correct count: {cpl_correct_cnt}/{len(val_db)*val_batches * batch_size}") 114 | 
save_f.write(f" Position correct count: {pos_correct_cnt}/{total_pos_cnt}\n") 115 | print(f" Position correct count: {pos_correct_cnt}/{total_pos_cnt}") 116 | print(f" Incomplete correct count: {icpl_correct_cnt}/{len(val_db)*val_batches * batch_size}") 117 | print(f" Current the maximum of ccc+ic: {max_corr_num}") 118 | save_f.write(f" Current the maximum of ccc+ic: {max_corr_num}\n") 119 | print(f"Random Prediction:") 120 | print(f" Complete correct count: {cpl_correct_cnt_rp}/{len(val_db)*val_batches * batch_size}") 121 | print(f" Incomplete correct count: {icpl_correct_cnt_rp}/{len(val_db)*val_batches * batch_size}") 122 | 123 | if early_stopping.early_stop: 124 | print('Early stopping') 125 | break 126 | end = time.time() 127 | print("======================================") 128 | print(f'total time cost:{end - start}') 129 | return max_corr_num 130 | 131 | 132 | if __name__ == '__main__': 133 | parser = argparse.ArgumentParser() 134 | parser.add_argument('--mode', type=str, default="multi-task", help="multi_task: our model, tree-lstm: previous SOTA") 135 | parser.add_argument('--epochs', type=int, default=200, help="the training epochs") 136 | parser.add_argument('--patience', type=int, default=10, help="the patience of early stopping") 137 | parser.add_argument('--lr', type=float, default=1e-6, help="the learning rate") 138 | parser.add_argument('--bw', type=int, default=3, help="the beam width") 139 | parser.add_argument('--phs', type=int, default=256, help="the hidden size of predicate encoding") 140 | parser.add_argument('--hn', type=int, default=4, help="the head number of transformer") 141 | parser.add_argument('--cln', type=int, default=3, help="the layer number of (en/de)coder") 142 | parser.add_argument('--tdf', type=int, default=64, help="the size of transformer feedforward network") 143 | parser.add_argument('--dp', type=int, default=0.2, help="the drop out of transformer") 144 | parser.add_argument('--train_db', type=int, nargs="+", help="the list 
of training databases") 145 | parser.add_argument('--test_db', type=int, nargs="+", help="the path of test data") 146 | parser.add_argument('--db_path', type=str, default="/mnt/train_data/meta_learner", 147 | help="the directory of databases") 148 | args = parser.parse_args() 149 | 150 | os.environ["CUDA_VISIBLE_DEVICES"] = "1" 151 | SEED = 666 152 | torch.manual_seed(SEED) 153 | 154 | # train setting: 155 | epochs = args.epochs 156 | patience = args.patience 157 | lr = args.lr 158 | 159 | # model set: 160 | trans_input_size = 128 161 | head_num = args.hn 162 | coder_layer_num = args.cln 163 | trans_dim_feedforward = args.tdf 164 | drop_out = args.dp 165 | 166 | teacher_forcing_set = 1 167 | print("******************************************************") 168 | print(f'learning rate: {lr}') 169 | print(f'trans_input_size: {trans_input_size}') 170 | print(f'head_num: {head_num}') 171 | print(f'coder_layer_num: {coder_layer_num}') 172 | print(f'trans_dim_feedforward: {trans_dim_feedforward}') 173 | print(f'drop_out: {drop_out}') 174 | print("******************************************************") 175 | 176 | total_files_num = 120 177 | train_start = 0 178 | train_end = int(total_files_num*0.9) 179 | val_start = int(total_files_num*0.9) 180 | val_end = total_files_num 181 | 182 | db_path = args.db_path 183 | train_db = args.train_db 184 | val_db = args.test_db 185 | max_leaves_num = 10 186 | save_model = "meta_learner" 187 | with open(f"meta_learning_log.txt", "w") as save_f: 188 | save_f.write(f"train_db: {train_db}, val_db: {val_db}\n\n") 189 | 190 | t = zero_shot_learning(train_start, train_end, val_start, val_end, epochs, patience, lr, trans_input_size, 191 | head_num, trans_dim_feedforward, coder_layer_num, max_leaves_num, drop_out, 192 | train_db, val_db, save_f, db_path, save_model) 193 | 194 | -------------------------------------------------------------------------------- /run_multi_task.py: 
# --------------------------------------------------------------------------------
# /run_multi_task.py (body)
# --------------------------------------------------------------------------------
from tools.pytool import *
from tools.utils import *
from get_data_parameters import *
from representation_model import *
import torch
import time
import os
import argparse
import tqdm
from vector_loader import *


def _to_float(value):
    """Return a loss value as a plain Python float.

    Task losses in this module are either tensors (task enabled) or the literals
    ``0``/``0.`` (task disabled).  Converting tensors with ``.item()`` keeps the
    epoch-level loss lists as plain numbers instead of device tensors attached to
    autograd graphs, and makes the printed summaries readable numbers.
    """
    if torch.is_tensor(value):
        return value.detach().cpu().item()
    return float(value)


def _lstm_batch_to_tensors(batch, device):
    """Convert one raw ``get_tree_lstm_batch_data`` batch into float tensors on ``device``.

    The two targets are converted as-is; every per-node input loses its leading
    singleton dimension (``squeeze(0)``) and the condition mask additionally gains a
    trailing dimension (``unsqueeze(2)``).  The result keeps the order of the input
    9-tuple.
    """
    (target_cost, target_cardinality, operatorss, extra_infoss, condition1ss, condition2ss, sampless,
     condition_maskss, mapping) = batch

    def squeezed(array):
        return torch.FloatTensor(array).to(device).squeeze(0)

    return (torch.FloatTensor(target_cost).to(device),
            torch.FloatTensor(target_cardinality).to(device),
            squeezed(operatorss),
            squeezed(extra_infoss),
            squeezed(condition1ss),
            squeezed(condition2ss),
            squeezed(sampless),
            squeezed(condition_maskss).unsqueeze(2),
            squeezed(mapping))


def _tree_lstm_qerrors(model, tensors, data_parameters, each_node):
    """Run the tree-LSTM on one converted batch and score both tasks.

    Returns ``(cost_stats, card_stats)``, each being the ``(mean, median, max)``
    triple returned by ``qerror_loss_each_node``.
    """
    (target_cost, target_cardinality, operatorss, extra_infoss, condition1ss, condition2ss, sampless,
     condition_maskss, mapping) = tensors
    estimate_cost, estimate_cardinality = model(operatorss, extra_infoss, condition1ss, condition2ss, sampless,
                                                condition_maskss, mapping)
    cost_stats = qerror_loss_each_node(estimate_cost, target_cost, data_parameters.cost_label_min,
                                       data_parameters.cost_label_max, mapping, each_node)
    card_stats = qerror_loss_each_node(estimate_cardinality, target_cardinality, data_parameters.card_label_min,
                                       data_parameters.card_label_max, mapping, each_node)
    return cost_stats, card_stats


def train_lstm(train_start, train_end, validate_start, validate_end, num_epochs, patience, lr, data_parameters,
               model_parameters, directory, cuda, each_node, test, save_model):
    """
    Train a tree-LSTM jointly on cardinality estimation and cost estimation.

    Batches ``train_start .. train_end - 1`` are used for training and
    ``validate_start .. validate_end - 1`` for validation; both are read from
    ``directory`` with ``get_tree_lstm_batch_data``.  The model with the lowest
    summed mean q-error (cost + cardinality) is checkpointed to
    ``model/<save_model>.pt`` by ``EarlyStopping``.

    When ``test`` is truthy no training happens: that checkpoint is loaded and
    batch 0 of ``directory`` is evaluated once.

    Args:
        cuda: use the GPU when truthy *and* a GPU is available, otherwise the CPU.
            Model and inputs always share the same device.
        each_node: forwarded to ``qerror_loss_each_node`` as its node-selection flag.

    Returns:
        The model after the last executed epoch.
    """
    device = 'cuda' if cuda and torch.cuda.is_available() else 'cpu'
    checkpoint_path = 'model/' + save_model + ".pt"
    input_dim = data_parameters.condition_op_dim
    hidden_dim = model_parameters.trans_input_size
    hid_dim = model_parameters.cond_hidden_size
    middle_result_dim = 128
    task_num = 2
    model = TreeLSTM(input_dim, hidden_dim, hid_dim, middle_result_dim, task_num).to(device)
    if test:
        # Evaluate the trained checkpoint instead of a freshly initialised network.
        model.load_state_dict(torch.load(checkpoint_path, map_location=device))
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    early_stopping = EarlyStopping(patience=patience, verbose=True, path=checkpoint_path)

    start = time.time()
    vali_card_qerror_list = []
    vali_cost_qerror_list = []

    if test:
        num_epochs = 1
        validate_start, validate_end = 0, 1
    for epoch in range(num_epochs):
        cost_loss_total = 0.
        card_loss_total = 0.
        model.train()
        print("=====================================")
        if not test:
            for batch_idx in tqdm.tqdm(range(train_start, train_end)):
                tensors = _lstm_batch_to_tensors(get_tree_lstm_batch_data(batch_idx, 0, directory=directory), device)
                optimizer.zero_grad()
                (cost_loss, _, _), (card_loss, _, _) = _tree_lstm_qerrors(model, tensors, data_parameters, each_node)
                loss = cost_loss + card_loss
                cost_loss_total += cost_loss.detach().cpu().item()
                card_loss_total += card_loss.detach().cpu().item()
                loss.backward()
                optimizer.step()
        batch_num = max(train_end - train_start, 1)  # guard against an empty training range
        print("Epoch {}, training cost loss: {}, training card loss: {}".format(epoch, cost_loss_total / batch_num,
                                                                                card_loss_total / batch_num))

        # Sums over validation batches of the (mean, median, max) q-error per task.
        cost_sums = [0., 0., 0.]
        card_sums = [0., 0., 0.]
        model.eval()
        if test:
            print("Test:")
        with torch.no_grad():
            for batch_idx in range(validate_start, validate_end):
                tensors = _lstm_batch_to_tensors(get_tree_lstm_batch_data(batch_idx, 0, directory=directory), device)
                cost_stats, card_stats = _tree_lstm_qerrors(model, tensors, data_parameters, each_node)
                for k in range(3):
                    cost_sums[k] += cost_stats[k].detach().cpu().item()
                    card_sums[k] += card_stats[k].detach().cpu().item()

        batch_num = max(validate_end - validate_start, 1)
        early_stopping((cost_sums[0] + card_sums[0]) / batch_num, model)
        if early_stopping.early_stop:
            print('Early stopping')
            break

        cost_mean, cost_median, cost_max = (total / batch_num for total in cost_sums)
        card_mean, card_median, card_max = (total / batch_num for total in card_sums)
        vali_cost_qerror_list.append(cost_mean)
        vali_card_qerror_list.append(card_mean)
        print("=> Epoch {}, Validating results:".format(epoch))
        print("MEAN: cost q-error: {}, card q-error: {}".format(cost_mean, card_mean))
        print("MEDIAN: cost q-error: {}, card q-error: {}".format(cost_median, card_median))
        print("MAX: cost q-error: {}, card q-error: {}".format(cost_max, card_max))
    end = time.time()
    print("===========================================================================")
    print(f'total time cost:{end - start}')
    plot(vali_card_qerror_list, vali_cost_qerror_list, num_epochs)
    return model


def _trans_batch_to_tensors(batch, device):
    """Convert one raw ``get_trans_batch_data`` batch into tensors on ``device``.

    Join-order truth and leaf-node markers become ``LongTensor``, the adjacency
    matrix a ``BoolTensor``, everything else ``FloatTensor``.  The result keeps the
    order of the input 14-tuple.
    """
    (join_order_truth, cost_truth, card_truth, operatorss, extra_infoss, condition1ss, condition2ss, sampless,
     condition_maskss, position_encoding, leaf_node_marker, trans_target, res_mask, adj_matrix) = batch
    return (torch.LongTensor(join_order_truth).to(device),
            torch.FloatTensor(cost_truth).to(device),
            torch.FloatTensor(card_truth).to(device),
            torch.FloatTensor(operatorss).to(device),
            torch.FloatTensor(extra_infoss).to(device),
            torch.FloatTensor(condition1ss).to(device),
            torch.FloatTensor(condition2ss).to(device),
            torch.FloatTensor(sampless).to(device),
            torch.FloatTensor(condition_maskss).to(device),
            torch.FloatTensor(position_encoding).to(device),
            torch.LongTensor(leaf_node_marker).to(device),
            torch.FloatTensor(trans_target).to(device),
            torch.FloatTensor(res_mask).to(device),
            torch.BoolTensor(adj_matrix).to(device))


def _seq_qerror_stats(enabled, pred, truth, label_min, label_max, each_node):
    """Return the (mean, median, max) q-error as floats, or ``(0, 0, 0)`` for a disabled task."""
    if not enabled:
        return 0, 0, 0
    mean, median, maximum = qerror_loss_seq_each_node(pred, truth, label_min, label_max, each_node)
    return mean.item(), median.item(), maximum.item()


def train_all_task(train_start, train_end, validate_start, validate_end, num_epochs, data_parameters, model_parameters,
                   directory, patience, lr, card_task, cost_task, join_task, each_node, test, f, save_model):
    """
    Multi-task training on cardinality estimation, cost estimation and join order selection.

    Training batches ``train_start .. train_end - 1`` and validation batches
    ``validate_start .. validate_end - 1`` are read from ``directory`` with
    ``get_trans_batch_data``.  ``card_task`` / ``cost_task`` / ``join_task`` switch
    the individual task losses on or off.  ``EarlyStopping`` checkpoints the model to
    ``model/<save_model>.pt``; its score is the negated mean JoEU when the join task
    is trained, otherwise the summed mean q-error of the enabled estimation tasks.

    When ``test`` is truthy that checkpoint is loaded, no training happens, batch 0 of
    ``directory`` is evaluated once and summary metrics are written to the open text
    file ``f``.

    Raises:
        ValueError: when training is requested with every task disabled (there would be
            no loss to back-propagate).
    """
    if not test and not (card_task or cost_task or join_task):
        raise ValueError("train_all_task needs at least one of card_task, cost_task or join_task enabled")
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    checkpoint_path = f'model/{save_model}.pt'
    decay_rate = 0.9
    decay_point = 5
    batch_size = 64  # only used to report "correct / total" counts below
    model = UnifiedModel(model_parameters).to(device)
    if test:
        print("Loading model:")
        # Load the checkpoint written by the training run for this configuration (previously a hard-coded,
        # unrelated file was loaded).  Checkpoint keys unknown to the current model are ignored.
        checkpoint = torch.load(checkpoint_path, map_location=device)
        model_dict = model.state_dict()
        model_dict.update({k: v for k, v in checkpoint.items() if k in model_dict})
        model.load_state_dict(model_dict)

    teacher_forcing = TeacherForcing(start_tf=1, decay_rate=decay_rate, decay_point=decay_point, cur_epoch=0,
                                     end_epoch=50, verbose=True)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=0.5, patience=4, verbose=True)
    early_stopping = EarlyStopping(patience=patience, verbose=True, path=checkpoint_path)
    criterion = torch.nn.CrossEntropyLoss()
    start = time.time()
    max_corr_num = 0
    if test:
        num_epochs = 1
        validate_start, validate_end = 0, 1
    for epoch in range(num_epochs):
        # The validation pass at the end of every epoch switches to eval mode; switch back before training.
        model.train()
        train_join_order_token_list = []
        train_join_order_seq_list = []
        train_cost_list = []
        train_card_list = []
        train_all_list = []
        print("=====================================")
        train_start_time = time.time()
        if not test:
            for batch_idx in tqdm.tqdm(range(train_start, train_end)):
                (join_order_truth, cost_truth, card_truth, operatorss, extra_infoss, condition1ss, condition2ss,
                 sampless, condition_maskss, position_encoding, leaf_node_marker, trans_target, res_mask,
                 adj_matrix) = _trans_batch_to_tensors(get_trans_batch_data(batch_idx, 1, directory=directory),
                                                       device)

                optimizer.zero_grad()
                cost_pred, card_pred, join_order_pred = model(operatorss, extra_infoss, condition1ss, condition2ss,
                                                              sampless, condition_maskss, position_encoding,
                                                              leaf_node_marker, 0, teacher_forcing, trans_target,
                                                              res_mask, adj_matrix)

                if join_task:
                    train_join_order_loss_token = join_order_loss(join_order_pred, join_order_truth, criterion)
                    train_join_order_loss_seq = model.calculate_sequential_loss(
                        operatorss, extra_infoss, condition1ss, condition2ss, sampless, condition_maskss,
                        position_encoding, leaf_node_marker, trans_target, res_mask, adj_matrix, join_order_truth)
                else:
                    # The join-order losses would be discarded anyway; skip the extra forward pass.
                    train_join_order_loss_token, train_join_order_loss_seq = 0, 0
                if cost_task:
                    train_cost_loss, _, _ = qerror_loss_seq_each_node(cost_pred, cost_truth,
                                                                      data_parameters.cost_label_min,
                                                                      data_parameters.cost_label_max, each_node)
                else:
                    train_cost_loss = 0.
                if card_task:
                    train_card_loss, _, _ = qerror_loss_seq_each_node(card_pred, card_truth,
                                                                      data_parameters.card_label_min,
                                                                      data_parameters.card_label_max, each_node)
                else:
                    train_card_loss = 0.

                tasks_loss = train_join_order_loss_token + train_card_loss + train_cost_loss \
                    + train_join_order_loss_seq

                train_join_order_token_list.append(_to_float(train_join_order_loss_token))
                train_join_order_seq_list.append(_to_float(train_join_order_loss_seq))
                train_cost_list.append(_to_float(train_cost_loss))
                train_card_list.append(_to_float(train_card_loss))
                train_all_list.append(_to_float(tasks_loss))
                tasks_loss.backward()
                optimizer.step()

        batch_num = max(train_end - train_start, 1)  # guard against an empty training range
        train_end_time = time.time()
        teacher_forcing.check()
        scheduler.step(sum(train_all_list) / batch_num)
        print('Training batch time: ', train_end_time - train_start_time)
        print("Epoch {}, training all tasks loss: {}".format(epoch, sum(train_all_list) / batch_num))
        print(" training join order token loss: {}".format(sum(train_join_order_token_list) / batch_num))
        print(" training join order seq loss: {}".format(sum(train_join_order_seq_list) / batch_num))
        print(" training cost loss: {}".format(sum(train_cost_list) / batch_num))
        print(" training card loss: {}".format(sum(train_card_list) / batch_num))

        val_join_order_list = []
        val_cost_list = []
        val_cost_median_list = []
        val_cost_max_list = []

        val_card_list = []
        val_card_median_list = []
        val_card_max_list = []

        val_total_list = []
        joeu_list = []
        if not test:
            print("Validation: ")
        else:
            print("Test: ")
        cpl_correct_cnt = 0
        icpl_correct_cnt = 0
        pos_correct_cnt = 0
        total_pos_cnt = 0
        ills = 0  # never incremented here; kept because it is reported below
        model.eval()
        with torch.no_grad():
            for batch_idx in tqdm.tqdm(range(validate_start, validate_end)):
                (join_order_truth, cost_truth, card_truth, operatorss, extra_infoss, condition1ss, condition2ss,
                 sampless, condition_maskss, position_encoding, leaf_node_marker, trans_target, res_mask,
                 adj_matrix) = _trans_batch_to_tensors(get_trans_batch_data(batch_idx, 1, directory=directory),
                                                       device)

                cost_pred, card_pred, _ = model(operatorss, extra_infoss, condition1ss, condition2ss, sampless,
                                                condition_maskss, position_encoding, leaf_node_marker, 0,
                                                teacher_forcing, trans_target, res_mask, adj_matrix)

                if join_task:
                    join_order_pred, joeu_mean = model.beam_search_test(operatorss, extra_infoss, condition1ss,
                                                                        condition2ss, sampless, condition_maskss,
                                                                        position_encoding, leaf_node_marker,
                                                                        trans_target, res_mask, adj_matrix,
                                                                        join_order_truth)
                    joeu_mean = joeu_mean.item()
                    c, p, t, ic = beam_search_prediction_compare(join_order_pred, join_order_truth)
                else:
                    joeu_mean = 0
                    c, p, t, ic = 0, 0, 0, 0
                joeu_list.append(joeu_mean)

                cpl_correct_cnt += c
                pos_correct_cnt += p
                total_pos_cnt += t
                icpl_correct_cnt += ic

                val_cost_loss, val_cost_median, val_cost_max = _seq_qerror_stats(
                    cost_task, cost_pred, cost_truth, data_parameters.cost_label_min,
                    data_parameters.cost_label_max, each_node)
                val_card_loss, val_card_median, val_card_max = _seq_qerror_stats(
                    card_task, card_pred, card_truth, data_parameters.card_label_min,
                    data_parameters.card_label_max, each_node)

                # The join-order loss is not evaluated during validation; 0 keeps the report format.
                val_join_order_list.append(0)
                val_cost_list.append(val_cost_loss)
                val_cost_median_list.append(val_cost_median)
                val_cost_max_list.append(val_cost_max)

                val_card_list.append(val_card_loss)
                val_card_median_list.append(val_card_median)
                val_card_max_list.append(val_card_max)

                val_total_list.append(val_card_loss + val_cost_loss)

        batch_num = max(validate_end - validate_start, 1)
        avg_loss = round(sum(val_total_list) / batch_num, 3)
        cur_corr_num = cpl_correct_cnt + icpl_correct_cnt
        if not test and join_task:
            # Higher JoEU is better while EarlyStopping minimises, hence the negation.
            early_stopping(-sum(joeu_list) / batch_num, model)
        else:
            early_stopping(avg_loss, model)
        max_corr_num = max(max_corr_num, cur_corr_num)

        cost_mean = round(sum(val_cost_list) / batch_num, 3)
        cost_median = round(sum(val_cost_median_list) / batch_num, 3)
        cost_max = round(sum(val_cost_max_list) / batch_num, 3)

        card_mean = round(sum(val_card_list) / batch_num, 3)
        card_median = round(sum(val_card_median_list) / batch_num, 3)
        card_max = round(sum(val_card_max_list) / batch_num, 3)
        print("Epoch {}, validation all tasks loss: {}".format(epoch, avg_loss))
        print(" validation join order loss: {}".format(sum(val_join_order_list) / batch_num))
        if test:
            f.write(f"test cost loss: {cost_mean}, median: {cost_median}, max: {cost_max}\n")
            f.write(f"test card loss: {card_mean}, median: {card_median}, max: {card_max}\n")
            f.write(f"Current the maximum of ccc+ic: {max_corr_num}\n")
            f.write(f"JoEU: {sum(joeu_list) / batch_num}\n")
            f.write(f"Illegal num: {ills}\n")
        print("**************************************************************************")
        print(f" validation cost loss: {cost_mean}, median: {cost_median}, max: {cost_max}")
        print(f" validation card loss: {card_mean}, median: {card_median}, max: {card_max}")
        print(f" Complete correct count: {cpl_correct_cnt}/{batch_num * batch_size}")
        print(f" Position correct count: {pos_correct_cnt}/{total_pos_cnt}")
        print(f" Incomplete correct count: {icpl_correct_cnt}/{batch_num * batch_size}")
        print(f" JoEU: {sum(joeu_list) / batch_num}")
        print(f" Illegal num: {ills}")
        print(f" Current the maximum of ccc+ic: {max_corr_num}")

        if not test and early_stopping.early_stop:
            print('Early stopping')
            break
    end = time.time()
    print("======================================")
    print(f'total time cost:{end - start}')
    return model


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--mode', type=str, default="multi-task", choices=["multi-task", "tree-lstm"],
                        help="multi-task: our model, tree-lstm: previous SOTA")
    parser.add_argument('--epochs', type=int, default=200, help="the training epochs")
    parser.add_argument('--patience', type=int, default=10, help="the patience of early stopping")
    parser.add_argument('--lr', type=float, default=1e-4, help="the learning rate")
    parser.add_argument('--bw', type=int, default=3, help="the beam width")
    parser.add_argument('--tis', type=int, default=128, help="the input size of transformer")
    parser.add_argument('--phs', type=int, default=256, help="the hidden size of predicate encoding")
    parser.add_argument('--hn', type=int, default=4, help="the head number of transformer")
    parser.add_argument('--cln', type=int, default=3, help="the layer number of (en/de)coder")
    parser.add_argument('--tdf', type=int, default=64, help="the size of transformer feedforward network")
    parser.add_argument('--card', type=int, default=1, help="the cardinality estimation button in multi-task learning")
    parser.add_argument('--cost', type=int, default=0, help="the cost estimation button in multi-task learning")
    parser.add_argument('--join', type=int, default=0, help="the join order selection button in multi-task learning")
    parser.add_argument('--en', type=int, default=0, help="whether considering the cardinality and cost of every node"
                                                          " in planing tree")
    parser.add_argument('--train_data_path', type=str, default="/mnt/train_data/join_order_each_node",
                        help="the path of train data")
    parser.add_argument('--test_data_path', type=str, default="/mnt/test_data/job",
                        help="the path of test data")
    args = parser.parse_args()
    # Only takes effect if set before CUDA is first initialised in this process.
    os.environ["CUDA_VISIBLE_DEVICES"] = "7"
    mode = args.mode
    SEED = 666
    torch.manual_seed(SEED)

    start_time = time.time()
    data_parameters = get_data_parameters(mode)
    end_time = time.time()
    print(f"Loading data cost time: {end_time-start_time}")

    # train setting:
    epochs = args.epochs
    patience = args.patience
    lr = args.lr
    teacher_forcing_set = 1
    beam_width = args.bw

    # model set:
    input_size = data_parameters.condition_op_dim
    trans_input_size = args.tis
    cond_hidden_size = args.phs
    head_num = args.hn
    coder_layer_num = args.cln
    plan_pos_size = 18  # tree2dfs
    trans_dim_feedforward = args.tdf
    max_leaves_num = 7
    table_num = 21
    pos_flag = 1
    attn_flag = 0

    model_parameters = ModelParameters(input_size, trans_input_size, cond_hidden_size, head_num,
                                       coder_layer_num, plan_pos_size, pos_flag, attn_flag,
                                       max_leaves_num, trans_dim_feedforward, beam_width, table_num)

    print("******************************************************")
    print(f'learning rate: {lr}')
    print(f'trans_input_size: {trans_input_size}')
    print(f'head_num: {head_num}')
    print(f'coder_layer_num: {coder_layer_num}')
    print(f'cond_hidden_size: {cond_hidden_size}')
    print(f'trans_dim_feedforward: {trans_dim_feedforward}')
    print(f'beam_width: {beam_width}')
    print("******************************************************")

    if mode == "multi-task":
        data_path = args.train_data_path
        total_files_num = 100
        train_start = 0
        train_end = int(total_files_num*0.9)
        val_start = int(total_files_num*0.9)
        val_end = total_files_num
        cost_flag = args.cost
        card_flag = args.card
        join_flag = args.join
        each_node = args.en
        print(f"cost: {cost_flag}, cardinality: {card_flag}, join order: {join_flag}")
        save_model = card_flag*"Card" + cost_flag*"Cost" + join_flag*"Join" + "Trans"
        with open("joint_reuslt_log.txt", "w") as f:
            for test in [0, 1]:
                if test:
                    data_path = args.test_data_path
                train_all_task(train_start, train_end, val_start, val_end, epochs, data_parameters, model_parameters,
                               data_path, patience, lr, card_flag, cost_flag, join_flag, each_node, test, f, save_model)
    elif mode == "tree-lstm":
        data_path = args.train_data_path
        total_files_num = 295
        train_start = 0
        train_end = int(total_files_num*0.9)
        val_start = int(total_files_num*0.9)
        val_end = total_files_num
        each_node = 0
        save_model = "TreeLSTM"
        for test in [0, 1]:
            if test:
                data_path = args.test_data_path
            # Pass a real boolean for ``cuda`` (the mode string used to be passed here).
            train_lstm(train_start, train_end, val_start, val_end, epochs, patience, lr, data_parameters,
                       model_parameters, data_path, torch.cuda.is_available(), each_node, test, save_model)

# --------------------------------------------------------------------------------
# /tools/pytool.py
# --------------------------------------------------------------------------------
import numpy as np
import torch
import math
import os


class EarlyStopping:
    """Stop training when the validation loss stops improving for ``patience`` calls.

    A call whose loss is not greater than ``best + delta`` counts as an improvement:
    the model's ``state_dict`` is saved to ``path`` and the counter is reset.
    Otherwise the counter grows, and ``early_stop`` becomes True once it reaches
    ``patience``.
    """

    def __init__(self, patience=7, verbose=False, delta=0, path='model/checkpoint.pt'):
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None  # kept for backward compatibility; not used by this class
        self.early_stop = False
        # ``np.Inf`` was removed in NumPy 2.0; ``np.inf`` works on every NumPy version.
        self.val_loss_min = np.inf
        self.delta = delta
        self.path = path

    def __call__(self, val_loss, model):
        if val_loss > self.val_loss_min + self.delta:
            self.counter += 1
            print(f'EarlyStopping Counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.save_checkpoint(val_loss, model)
            self.counter = 0

    def save_checkpoint(self, val_loss, model):
        """Save ``model.state_dict()`` to ``self.path`` and remember ``val_loss`` as the best so far."""
        if self.verbose:
            print(f'validation loss decrease ({self.val_loss_min:.6f}) ---> ({val_loss:.6f})')
        # Create the checkpoint directory (e.g. ``model/``) on first use instead of failing in torch.save.
        directory = os.path.dirname(self.path)
        if directory:
            os.makedirs(directory, exist_ok=True)
        torch.save(model.state_dict(), self.path)
        self.val_loss_min = val_loss


class TeacherForcing:
    """Teacher-forcing ratio ``tf`` with step decay.

    Each ``check()`` call advances ``cur_epoch`` by one while it is below
    ``end_epoch`` and multiplies ``tf`` by ``decay_rate`` whenever the new
    ``cur_epoch`` is a multiple of ``decay_point``.  Once ``cur_epoch`` has reached
    ``end_epoch``, ``check()`` sets ``tf`` to 0.
    """

    def __init__(self, start_tf, decay_rate, decay_point, cur_epoch, end_epoch, verbose):
        self.decay_rate = decay_rate
        self.decay_point = decay_point
        self.cur_epoch = cur_epoch
        self.end_epoch = end_epoch
        self.tf = start_tf
        self.verbose = verbose

    def check(self):
        if self.cur_epoch < self.end_epoch:
            self.cur_epoch += 1
            if self.cur_epoch % self.decay_point == 0:
                if self.verbose:
                    print(f"Teacher forcing decrease {self.tf:.6f} ----> {self.tf*self.decay_rate:.6f}")
                self.tf *= self.decay_rate
        else:
            self.tf = 0


class DynamicWeight:
    """Per-task loss weights ``w_i = n * softmax(r / temper)_i``.

    ``r_i`` is the ratio of task ``i``'s latest loss to its previous loss.  On the
    first ``update`` call (previous losses all zero) both histories are initialised
    with the given losses, so every ratio is 1.
    NOTE(review): a previous loss of exactly 0 for a single task divides by zero.
    """

    def __init__(self, n_tasks, temper):
        self.n = n_tasks
        self.w = [1 for i in range(self.n)]
        self.temper = temper
        self.loss1 = [0 for i in range(self.n)]  # previous losses
        self.loss2 = [0 for i in range(self.n)]  # current losses

    def update(self, *input_loss):
        if sum(self.loss1) == 0:
            self.loss1 = [loss for loss in input_loss]
            self.loss2 = [loss for loss in input_loss]
        else:
            self.loss1 = [loss for loss in self.loss2]
            self.loss2 = [loss for loss in input_loss]  # current loss

        r = [self.loss2[i]/self.loss1[i] for i in range(self.n)]
        coe = self.n / sum([math.exp(x/self.temper) for x in r])
        self.w = [coe * math.exp(i/self.temper) for i in r]


# --------------------------------------------------------------------------------
# /tools/utils.py (head)
# --------------------------------------------------------------------------------
import torch
import random
import numpy as np
import matplotlib.pyplot as plt


def unnormalize(vecs, mini, maxi):
    """Invert min-max normalisation of log-scaled values: ``exp(vecs * (maxi - mini) + mini)``."""
    return torch.exp(vecs * (maxi - mini) + mini)


def plot(vali_card_qerror_list, vali_cost_qerror_list, num_epochs):
    """Plot per-epoch validation q-errors (cardinality on top, cost below).

    Each point is annotated with its value (at most ``num_epochs`` labels per panel).
    The figure is saved to ``./test_pic.jpg``, shown, and then closed.
    """
    fig = plt.figure(figsize=(20, 10))
    panels = ((211, "card_qerror", vali_card_qerror_list),
              (212, "cost_qerror", vali_cost_qerror_list))
    for position, label, values in panels:
        plt.subplot(position)
        plt.xlabel("epoch")
        plt.ylabel(label)
        plt.grid(axis='y')
        for epoch, value in zip(range(num_epochs), values):
            plt.text(epoch, value, '%.2f' % value, ha='center', va='bottom', fontsize=9)
        plt.plot(values, color='blue', linestyle='-', marker='o')
    plt.savefig('./test_pic.jpg')
    plt.show()
    plt.close(fig)  # release the figure; show() does not free it on non-interactive backends


def _qerror_stats(qerror):
    """Return (mean, median, max) over a list of q-error tensors of any (0-d or 1-d) shape.

    Each item is flattened first because ``torch.cat`` rejects 0-d tensors.
    """
    values = torch.cat([q.reshape(-1) for q in qerror])
    return torch.mean(values), torch.median(values), torch.max(values)


def qerror_loss(preds, targets, mini, maxi):
    """Return (mean, median, max) q-error between normalised predictions and targets.

    Both inputs are de-normalised with ``unnormalize``; the q-error of a pair is
    ``max(p, t) / min(p, t)``.  Rows are paired position by position.
    """
    preds = unnormalize(preds, mini, maxi)  # de-normalise
    targets = unnormalize(targets, mini, maxi)
    qerror = [pred / target if pred > target else target / pred for pred, target in zip(preds, targets)]
    return _qerror_stats(qerror)


def qerror_loss_each_node(preds, targets, mini, maxi, mapping, validation, num_roots=64):
    """Return (mean, median, max) q-error over the nodes of batched plan trees.

    ``preds``/``targets`` are indexed ``[level][node]`` (commented upstream as
    (14, 133)); ``mapping[level][w]`` holds the 1-based ``(left, right)`` child
    indices into level ``level + 1``, where 0 means "no child".

    When ``validation`` is falsy, every child referenced by ``mapping`` is scored.
    The first ``num_roots`` entries of level 0 (the roots) are always scored.
    NOTE(review): callers pass their ``each_node`` flag as ``validation``, so
    ``each_node=0`` scores *all* nodes. This is the opposite of
    ``qerror_loss_seq_each_node``. Confirm intent.

    Args:
        num_roots: number of level-0 entries to score (default 64, the previous
            hard-coded value).  It is clamped to the width of level 0.
    """
    qerror = []
    preds = unnormalize(preds, mini, maxi)  # de-normalise
    targets = unnormalize(targets, mini, maxi)

    def pair_qerror(level, node):
        p, t = preds[level][node], targets[level][node]
        return max(p, t) / min(p, t)

    if not validation:
        for level in range(len(mapping) - 1):
            for w in range(len(mapping[0])):
                # Callers build ``mapping`` as a FloatTensor; tensor indices must be integers.
                left, right = (int(child) for child in mapping[level][w])
                if left != 0:
                    qerror.append(pair_qerror(level + 1, left - 1))
                if right != 0:
                    qerror.append(pair_qerror(level + 1, right - 1))
    for i in range(min(num_roots, len(preds[0]))):
        qerror.append(pair_qerror(0, i))
    return _qerror_stats(qerror)
def _pair_qerror(pred, target):
    """Return ``max(pred, target) / min(pred, target)`` flattened to 1-D (``torch.cat`` rejects 0-d tensors)."""
    return (max(pred, target) / min(pred, target)).reshape(-1)


def qerror_loss_seq_each_node(preds, targets, mini, maxi, each_node):
    """Return (mean, median, max) q-error for per-node sequence predictions.

    ``preds`` are normalised model outputs (commented upstream as (64, 25, 1)).
    ``targets`` are normalised labels (commented as (64, 25)), where ``-1`` marks
    padding; a 1-D ``targets`` is treated as one column.  Padding is mapped to
    ``-inf`` before de-normalisation, so it becomes 0 and is skipped.  The caller's
    ``targets`` tensor is not modified.

    Args:
        each_node: truthy -> score every non-padding node of every row;
            falsy -> score only column 0 of each row (when it is not padding).
    """
    preds = unnormalize(preds, mini, maxi)  # de-normalise
    if len(targets.shape) == 1:
        targets = targets.unsqueeze(1)
    # Out-of-place and vectorised: the previous element loop overwrote the caller's tensor.
    targets = unnormalize(targets.masked_fill(targets == -1, float("-inf")), mini, maxi)

    qerror = []
    if each_node:
        for batch_id in range(len(preds)):
            for i in range(len(preds[0])):
                if targets[batch_id][i] != 0:
                    qerror.append(_pair_qerror(preds[batch_id][i], targets[batch_id][i]))
    else:
        for i in range(len(targets)):
            # Padding is 0 after de-normalisation (it used to be compared with -1, which never matched).
            if targets[i][0] != 0:
                qerror.append(_pair_qerror(preds[i][0], targets[i][0]))
    values = torch.cat(qerror)
    return torch.mean(values), torch.median(values), torch.max(values)


def _valid_length(row):
    """Return the number of leading entries of ``row`` before the first ``-1`` padding marker."""
    length = 0
    while length < len(row) and row[length] != -1:
        length += 1
    return length


def _compare_orders(pred_idx_list, gt_idx_list):
    """Score one predicted join order against the truth (both of length n).

    Returns ``(complete, matches, incomplete)``:
    - ``complete`` is 1 when all n positions match.
    - ``matches`` is the number of matching positions.
    - ``incomplete`` is 1 when n >= 2, ``pred[0] == gt[1]`` and exactly n - 2 positions match.
    """
    n = len(gt_idx_list)
    matches = int((np.asarray(pred_idx_list) == np.asarray(gt_idx_list)).sum())
    # The n >= 2 guard avoids an IndexError on gt[1] for single-table queries.
    incomplete = int(n >= 2 and pred_idx_list[0] == gt_idx_list[1] and matches == n - 2)
    complete = int(matches == n)
    return complete, matches, incomplete


def join_order_loss(pred, ground_truth, cel):
    """Mean over the batch of ``cel`` applied to each query's non-padding positions.

    ``pred[i]`` holds per-position logits; ``ground_truth[i]`` is the -1-padded order.
    """
    loss_list = []
    for i in range(pred.shape[0]):
        valid = _valid_length(ground_truth[i])
        loss_list.append(cel(pred[i][:valid], ground_truth[i][:valid]))
    return sum(loss_list) / len(loss_list)


def beam_search_prediction_compare(pred, ground_truth):
    """Compare beam-search join orders (table indices) with the -1-padded truth.

    Returns ``(complete_correct_cnt, position_correct_cnt, compared_positions,
    incomplete_correct_cnt)`` summed over the batch.
    """
    pred = pred.cpu().numpy()
    ground_truth = ground_truth.cpu().numpy()
    complete_correct_cnt = position_correct_cnt = incomplete_correct_cnt = pd_cnt = 0
    for i in range(pred.shape[0]):
        pd_idx = _valid_length(ground_truth[i])
        pred_idx_list = pred[i, :pd_idx]
        gt_idx_list = ground_truth[i, :pd_idx]
        assert len(set(pred_idx_list)) == pd_idx  # beam search must not repeat a table
        pd_cnt += pd_idx
        complete, matches, incomplete = _compare_orders(pred_idx_list, gt_idx_list)
        complete_correct_cnt += complete
        position_correct_cnt += matches
        incomplete_correct_cnt += incomplete
    return complete_correct_cnt, position_correct_cnt, pd_cnt, incomplete_correct_cnt


def gen_random_seq(num_table, adj_matrix):
    """Return a random join order of tables ``0..num_table-1``.

    After the first table, each table is drawn uniformly from the unseen tables that
    are adjacent (per ``adj_matrix``) to one already chosen.

    Raises:
        ValueError: if the join graph is disconnected; the previous rejection loop
            spun forever in that case.
    """
    if num_table == 0:
        return []
    unseen = list(range(num_table))
    start = random.choice(unseen)
    unseen.remove(start)
    res = [start]
    reachable = adj_matrix[start]
    while unseen:
        candidates = [table for table in unseen if reachable[table]]
        if not candidates:
            raise ValueError("join graph is disconnected; no connected join order exists")
        nxt = random.choice(candidates)
        res.append(nxt)
        unseen.remove(nxt)
        reachable = reachable | adj_matrix[nxt]
    return res


def random_prediction(ground_truth, adj_matrix):
    """Score a random connected join order per query as a baseline.

    Returns ``(complete_correct_cnt, position_correct_cnt, compared_positions,
    incomplete_correct_cnt, res_pred)``; ``res_pred`` holds the random orders
    padded with -1 to the truth width.
    """
    num_tables = len(ground_truth[0])
    ground_truth = ground_truth.detach().cpu().numpy()
    adj_matrix = adj_matrix.detach().cpu().numpy()
    complete_correct_cnt = position_correct_cnt = incomplete_correct_cnt = pd_cnt = 0
    res_pred = []
    for i in range(ground_truth.shape[0]):
        pd_idx = _valid_length(ground_truth[i])
        pred_idx_list = gen_random_seq(pd_idx, adj_matrix[i])
        gt_idx_list = ground_truth[i, :pd_idx]
        assert set(pred_idx_list) == set(gt_idx_list)
        res_pred.append(pred_idx_list + [-1] * (num_tables - pd_idx))
        pd_cnt += pd_idx
        complete, matches, incomplete = _compare_orders(pred_idx_list, gt_idx_list)
        complete_correct_cnt += complete
        position_correct_cnt += matches
        incomplete_correct_cnt += incomplete
    return complete_correct_cnt, position_correct_cnt, pd_cnt, incomplete_correct_cnt, res_pred


def output_file(pred, ground_truth):
    """Return argmax join orders as lists, truncated to each query's length and padded with -1.

    ``pred`` holds per-position logits (commented upstream as (64, 10, 10));
    ``ground_truth`` is the -1-padded order (64, 10).
    """
    pred = np.argmax(pred.detach().cpu().numpy(), axis=2)
    ground_truth = ground_truth.detach().cpu().numpy()
    num_tables = len(ground_truth[0])
    out_list = []
    for i in range(pred.shape[0]):
        pd_idx = _valid_length(ground_truth[i])
        out_list.append(pred[i, :pd_idx].tolist() + [-1] * (num_tables - pd_idx))
    return out_list


def prediction_compare(pred, ground_truth):
    """Compare argmax join orders from per-position logits with the -1-padded truth.

    Returns ``(complete_correct_cnt, position_correct_cnt, compared_positions,
    incomplete_correct_cnt)`` summed over the batch.
    """
    pred = np.argmax(pred.detach().cpu().numpy(), axis=2)
    ground_truth = ground_truth.detach().cpu().numpy()
    complete_correct_cnt = position_correct_cnt = incomplete_correct_cnt = pd_cnt = 0
    for i in range(pred.shape[0]):
        pd_idx = _valid_length(ground_truth[i])
        pred_idx_list = pred[i, :pd_idx]
        gt_idx_list = ground_truth[i, :pd_idx]
        assert set(pred_idx_list) == set(gt_idx_list)
        pd_cnt += pd_idx
        complete, matches, incomplete = _compare_orders(pred_idx_list, gt_idx_list)
        complete_correct_cnt += complete
        position_correct_cnt += matches
        incomplete_correct_cnt += incomplete
    return complete_correct_cnt, position_correct_cnt, pd_cnt, incomplete_correct_cnt

# --------------------------------------------------------------------------------
# /torch_test.py:
# --------------------------------------------------------------------------------
https://raw.githubusercontent.com/wuziniu/unified_model/fe80b4001af9b71b2de2473bcc971028411bca72/torch_test.py -------------------------------------------------------------------------------- /unified-model-v1.yml: -------------------------------------------------------------------------------- 1 | name: unified-model 2 | channels: 3 | - pytorch 4 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main 5 | - conda-forge 6 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main/ 7 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/free/ 8 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/pytorch/ 9 | - defaults 10 | dependencies: 11 | - _libgcc_mutex=0.1 12 | - aiohttp=3.7.3 13 | - async-timeout=3.0.1 14 | - attrs=20.3.0 15 | - blas=1.0 16 | - boto=2.49.0 17 | - boto3=1.16.53 18 | - botocore=1.19.53 19 | - brotlipy=0.7.0 20 | - bz2file=0.98 21 | - bzip2=1.0.8 22 | - c-ares=1.17.1 23 | - ca-certificates=2021.1.19 24 | - cachetools=4.1.1 25 | - certifi=2020.12.5 26 | - cffi=1.14.4 27 | - chardet=3.0.4 28 | - cryptography=3.3.1 29 | - cudatoolkit=10.2.89 30 | - cycler=0.10.0 31 | - dbus=1.13.18 32 | - expat=2.2.10 33 | - ffmpeg=4.3 34 | - fontconfig=2.13.0 35 | - freetype=2.10.4 36 | - gensim=3.8.3 37 | - glib=2.66.1 38 | - gmp=6.2.1 39 | - gnutls=3.6.5 40 | - google-api-core=1.22.2 41 | - google-auth=1.24.0 42 | - google-cloud-core=1.5.0 43 | - google-cloud-storage=1.19.0 44 | - google-resumable-media=1.2.0 45 | - googleapis-common-protos=1.52.0 46 | - grpcio=1.33.2 47 | - gst-plugins-base=1.14.0 48 | - gstreamer=1.14.0 49 | - icu=58.2 50 | - idna=2.10 51 | - intel-openmp=2020.2 52 | - jmespath=0.10.0 53 | - jpeg=9b 54 | - kiwisolver=1.3.0 55 | - lame=3.100 56 | - lcms2=2.11 57 | - ld_impl_linux-64=2.33.1 58 | - libedit=3.1.20191231 59 | - libffi=3.3 60 | - libgcc-ng=9.1.0 61 | - libgfortran-ng=7.5.0 62 | - libgfortran4=7.5.0 63 | - libiconv=1.15 64 | - libpng=1.6.37 65 | - libprotobuf=3.13.0.1 66 | - libstdcxx-ng=9.1.0 67 | - libtiff=4.1.0 68 | - 
libuuid=1.0.3 69 | - libuv=1.40.0 70 | - libxcb=1.14 71 | - libxml2=2.9.10 72 | - lz4-c=1.9.2 73 | - matplotlib=3.3.2 74 | - matplotlib-base=3.3.2 75 | - mkl=2020.2 76 | - mkl-service=2.3.0 77 | - mkl_fft=1.2.0 78 | - mkl_random=1.1.1 79 | - multidict=5.1.0 80 | - ncurses=6.2 81 | - nettle=3.4.1 82 | - ninja=1.10.2 83 | - olefile=0.46 84 | - openh264=2.1.0 85 | - openssl=1.1.1k 86 | - pandas=1.2.0 87 | - pcre=8.44 88 | - pillow=8.1.0 89 | - pip=20.3.3 90 | - ply=3.11 91 | - protobuf=3.13.0.1 92 | - pyasn1=0.4.8 93 | - pyasn1-modules=0.2.7 94 | - pycparser=2.20 95 | - pyopenssl=20.0.1 96 | - pyparsing=2.4.7 97 | - pyqt=5.9.2 98 | - pysocks=1.7.1 99 | - python=3.8.5 100 | - python-dateutil=2.8.1 101 | - python_abi=3.8 102 | - pytorch=1.8.1 103 | - pytz=2020.5 104 | - qt=5.9.7 105 | - readline=8.0 106 | - requests=2.25.1 107 | - rsa=4.7 108 | - s3transfer=0.3.4 109 | - scipy=1.5.2 110 | - setuptools=51.1.2 111 | - sip=4.19.13 112 | - six=1.15.0 113 | - smart_open=4.1.0 114 | - sqlite=3.33.0 115 | - tk=8.6.10 116 | - torchaudio=0.8.1 117 | - torchvision=0.9.1 118 | - tornado=6.1 119 | - tqdm=4.55.1 120 | - typing-extensions=3.7.4.3 121 | - typing_extensions=3.7.4.3 122 | - urllib3=1.26.2 123 | - wheel=0.36.2 124 | - xz=5.2.5 125 | - yarl=1.6.3 126 | - zlib=1.2.11 127 | - zstd=1.4.5 128 | - pip: 129 | - argon2-cffi==20.1.0 130 | - async-generator==1.10 131 | - backcall==0.2.0 132 | - bleach==3.3.0 133 | - decorator==4.4.2 134 | - defusedxml==0.6.0 135 | - entrypoints==0.3 136 | - iniconfig==1.1.1 137 | - ipykernel==5.5.0 138 | - ipython==7.21.0 139 | - ipython-genutils==0.2.0 140 | - jedi==0.18.0 141 | - jinja2==2.11.3 142 | - joblib==1.0.1 143 | - jsonschema==3.2.0 144 | - jupyter-client==6.1.11 145 | - jupyter-core==4.7.1 146 | - jupyterlab-pygments==0.1.2 147 | - littleutils==0.2.2 148 | - llvmlite==0.36.0 149 | - markupsafe==1.1.1 150 | - mistune==0.8.4 151 | - nbclient==0.5.3 152 | - nbconvert==6.0.7 153 | - nbformat==5.1.2 154 | - nest-asyncio==1.5.1 155 | - 
networkx==2.5.1 156 | - notebook==6.2.0 157 | - numba==0.53.1 158 | - numexpr==2.7.3 159 | - numpy==1.20.3 160 | - ogb==1.3.1 161 | - outdated==0.2.1 162 | - packaging==20.9 163 | - pandocfilters==1.4.3 164 | - parso==0.8.1 165 | - patsy==0.5.1 166 | - pexpect==4.8.0 167 | - pgmpy==0.1.14 168 | - pickleshare==0.7.5 169 | - pluggy==0.13.1 170 | - pomegranate==0.14.5 171 | - prometheus-client==0.9.0 172 | - prompt-toolkit==3.0.16 173 | - psycopg2==2.8.6 174 | - ptyprocess==0.7.0 175 | - py==1.10.0 176 | - pygments==2.8.0 177 | - pyrsistent==0.17.3 178 | - pytest==6.2.4 179 | - pyyaml==5.4.1 180 | - pyzmq==22.0.3 181 | - scikit-learn==0.24.1 182 | - send2trash==1.5.0 183 | - sqlparse==0.4.1 184 | - statsmodels==0.12.2 185 | - tables==3.6.1 186 | - terminado==0.9.2 187 | - testpath==0.4.4 188 | - threadpoolctl==2.1.0 189 | - toml==0.10.2 190 | - traitlets==5.0.5 191 | - wcwidth==0.2.5 192 | - webencodings==0.5.1 193 | -------------------------------------------------------------------------------- /unified-model.yml: -------------------------------------------------------------------------------- 1 | name: unified-model 2 | channels: 3 | - pytorch 4 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main 5 | - conda-forge 6 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main/ 7 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/free/ 8 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/pytorch/ 9 | - defaults 10 | dependencies: 11 | - _libgcc_mutex=0.1=main 12 | - aiohttp=3.7.3=py38h25fe258_0 13 | - async-timeout=3.0.1=py_1000 14 | - attrs=20.3.0=pyhd3deb0d_0 15 | - blas=1.0=mkl 16 | - boto=2.49.0=py_0 17 | - boto3=1.16.53=pyhd8ed1ab_0 18 | - botocore=1.19.53=pyhd8ed1ab_0 19 | - brotlipy=0.7.0=py38h8df0ef7_1001 20 | - bz2file=0.98=py_0 21 | - bzip2=1.0.8=h7b6447c_0 22 | - c-ares=1.17.1=h36c2ea0_0 23 | - ca-certificates=2021.1.19=h06a4308_1 24 | - cachetools=4.1.1=py_0 25 | - certifi=2020.12.5=py38h06a4308_0 26 | - cffi=1.14.4=py38h261ae71_0 27 | - 
chardet=3.0.4=py38h924ce5b_1008 28 | - cryptography=3.3.1=py38h3c74f83_0 29 | - cudatoolkit=10.2.89=hfd86e86_1 30 | - cycler=0.10.0=py38_0 31 | - dbus=1.13.18=hb2f20db_0 32 | - expat=2.2.10=he6710b0_2 33 | - ffmpeg=4.3=hf484d3e_0 34 | - fontconfig=2.13.0=h9420a91_0 35 | - freetype=2.10.4=h5ab3b9f_0 36 | - gensim=3.8.3=py38h950e882_2 37 | - glib=2.66.1=h92f7085_0 38 | - gmp=6.2.1=h2531618_2 39 | - gnutls=3.6.5=h71b1129_1002 40 | - google-api-core=1.22.2=py38h32f6830_0 41 | - google-auth=1.24.0=pyhd3deb0d_0 42 | - google-cloud-core=1.5.0=pyhd3deb0d_0 43 | - google-cloud-storage=1.19.0=py_0 44 | - google-resumable-media=1.2.0=pyhd3eb1b0_1 45 | - googleapis-common-protos=1.52.0=py38h578d9bd_1 46 | - grpcio=1.33.2=py38heead2fc_2 47 | - gst-plugins-base=1.14.0=h8213a91_2 48 | - gstreamer=1.14.0=h28cd5cc_2 49 | - icu=58.2=he6710b0_3 50 | - idna=2.10=pyh9f0ad1d_0 51 | - intel-openmp=2020.2=254 52 | - jmespath=0.10.0=pyh9f0ad1d_0 53 | - jpeg=9b=h024ee3a_2 54 | - kiwisolver=1.3.0=py38h2531618_0 55 | - lame=3.100=h7b6447c_0 56 | - lcms2=2.11=h396b838_0 57 | - ld_impl_linux-64=2.33.1=h53a641e_7 58 | - libedit=3.1.20191231=h14c3975_1 59 | - libffi=3.3=he6710b0_2 60 | - libgcc-ng=9.1.0=hdf63c60_0 61 | - libgfortran-ng=7.5.0=hae1eefd_17 62 | - libgfortran4=7.5.0=hae1eefd_17 63 | - libiconv=1.15=h63c8f33_5 64 | - libpng=1.6.37=hbc83047_0 65 | - libprotobuf=3.13.0.1=h8b12597_0 66 | - libstdcxx-ng=9.1.0=hdf63c60_0 67 | - libtiff=4.1.0=h2733197_1 68 | - libuuid=1.0.3=h1bed415_2 69 | - libuv=1.40.0=h7b6447c_0 70 | - libxcb=1.14=h7b6447c_0 71 | - libxml2=2.9.10=hb55368b_3 72 | - lz4-c=1.9.2=heb0550a_3 73 | - matplotlib=3.3.2=h06a4308_0 74 | - matplotlib-base=3.3.2=py38h817c723_0 75 | - mkl=2020.2=256 76 | - mkl-service=2.3.0=py38he904b0f_0 77 | - mkl_fft=1.2.0=py38h23d657b_0 78 | - mkl_random=1.1.1=py38h0573a6f_0 79 | - multidict=5.1.0=py38h27cfd23_2 80 | - ncurses=6.2=he6710b0_1 81 | - nettle=3.4.1=hbb512f6_0 82 | - ninja=1.10.2=py38hff7bd54_0 83 | - olefile=0.46=py_0 84 | - 
openh264=2.1.0=hd408876_0 85 | - openssl=1.1.1k=h27cfd23_0 86 | - pandas=1.2.0=py38ha9443f7_0 87 | - pcre=8.44=he6710b0_0 88 | - pillow=8.1.0=py38he98fc37_0 89 | - pip=20.3.3=py38h06a4308_0 90 | - ply=3.11=py38_0 91 | - protobuf=3.13.0.1=py38he6710b0_1 92 | - pyasn1=0.4.8=py_0 93 | - pyasn1-modules=0.2.7=py_0 94 | - pycparser=2.20=pyh9f0ad1d_2 95 | - pyopenssl=20.0.1=pyhd8ed1ab_0 96 | - pyparsing=2.4.7=py_0 97 | - pyqt=5.9.2=py38h05f1152_4 98 | - pysocks=1.7.1=py38h578d9bd_3 99 | - python=3.8.5=h7579374_1 100 | - python-dateutil=2.8.1=py_0 101 | - python_abi=3.8=1_cp38 102 | - pytorch=1.8.1=py3.8_cuda10.2_cudnn7.6.5_0 103 | - pytz=2020.5=pyhd3eb1b0_0 104 | - qt=5.9.7=h5867ecd_1 105 | - readline=8.0=h7b6447c_0 106 | - requests=2.25.1=pyhd3deb0d_0 107 | - rsa=4.7=pyhd3deb0d_0 108 | - s3transfer=0.3.4=pyhd8ed1ab_0 109 | - scipy=1.5.2=py38h0b6359f_0 110 | - setuptools=51.1.2=py38h06a4308_4 111 | - sip=4.19.13=py38he6710b0_0 112 | - six=1.15.0=py38h06a4308_0 113 | - smart_open=4.1.0=pyhd8ed1ab_0 114 | - sqlite=3.33.0=h62c20be_0 115 | - tk=8.6.10=hbc83047_0 116 | - torchaudio=0.8.1=py38 117 | - torchvision=0.9.1=py38_cu102 118 | - tornado=6.1=py38h27cfd23_0 119 | - tqdm=4.55.1=pyhd3eb1b0_0 120 | - typing-extensions=3.7.4.3=0 121 | - typing_extensions=3.7.4.3=py_0 122 | - urllib3=1.26.2=pyhd8ed1ab_0 123 | - wheel=0.36.2=pyhd3eb1b0_0 124 | - xz=5.2.5=h7b6447c_0 125 | - yarl=1.6.3=py38h25fe258_0 126 | - zlib=1.2.11=h7b6447c_3 127 | - zstd=1.4.5=h9ceee32_0 128 | - pip: 129 | - argon2-cffi==20.1.0 130 | - async-generator==1.10 131 | - backcall==0.2.0 132 | - bleach==3.3.0 133 | - decorator==4.4.2 134 | - defusedxml==0.6.0 135 | - entrypoints==0.3 136 | - iniconfig==1.1.1 137 | - ipykernel==5.5.0 138 | - ipython==7.21.0 139 | - ipython-genutils==0.2.0 140 | - jedi==0.18.0 141 | - jinja2==2.11.3 142 | - joblib==1.0.1 143 | - jsonschema==3.2.0 144 | - jupyter-client==6.1.11 145 | - jupyter-core==4.7.1 146 | - jupyterlab-pygments==0.1.2 147 | - littleutils==0.2.2 148 | - 
llvmlite==0.36.0 149 | - markupsafe==1.1.1 150 | - mistune==0.8.4 151 | - nbclient==0.5.3 152 | - nbconvert==6.0.7 153 | - nbformat==5.1.2 154 | - nest-asyncio==1.5.1 155 | - networkx==2.5.1 156 | - notebook==6.2.0 157 | - numba==0.53.1 158 | - numexpr==2.7.3 159 | - numpy==1.20.3 160 | - ogb==1.3.1 161 | - outdated==0.2.1 162 | - packaging==20.9 163 | - pandocfilters==1.4.3 164 | - parso==0.8.1 165 | - patsy==0.5.1 166 | - pexpect==4.8.0 167 | - pgmpy==0.1.14 168 | - pickleshare==0.7.5 169 | - pluggy==0.13.1 170 | - pomegranate==0.14.5 171 | - prometheus-client==0.9.0 172 | - prompt-toolkit==3.0.16 173 | - psycopg2==2.8.6 174 | - ptyprocess==0.7.0 175 | - py==1.10.0 176 | - pygments==2.8.0 177 | - pyrsistent==0.17.3 178 | - pytest==6.2.4 179 | - pyyaml==5.4.1 180 | - pyzmq==22.0.3 181 | - scikit-learn==0.24.1 182 | - send2trash==1.5.0 183 | - sqlparse==0.4.1 184 | - statsmodels==0.12.2 185 | - tables==3.6.1 186 | - terminado==0.9.2 187 | - testpath==0.4.4 188 | - threadpoolctl==2.1.0 189 | - toml==0.10.2 190 | - traitlets==5.0.5 191 | - wcwidth==0.2.5 192 | - webencodings==0.5.1 193 | -------------------------------------------------------------------------------- /vector_loader.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import random 3 | 4 | 5 | def get_tree_lstm_batch_data(batch_id, seq, directory): 6 | target_cost_batch = np.load(directory+'/target_cost_'+str(batch_id)+'.np.npy') 7 | target_cardinality_batch = np.load(directory+'/target_cardinality_'+str(batch_id)+'.np.npy') 8 | operators_batch = np.load(directory+'/operators_'+str(batch_id)+'.np.npy') 9 | extra_infos_batch = np.load(directory+'/extra_infos_'+str(batch_id)+'.np.npy') 10 | condition1s_batch = np.load(directory+'/condition1s_'+str(batch_id)+'.np.npy') 11 | condition2s_batch = np.load(directory+'/condition2s_'+str(batch_id)+'.np.npy') 12 | samples_batch = np.load(directory+'/samples_'+str(batch_id)+'.np.npy') 13 | 
condition_masks_batch = np.load(directory+'/condition_masks_'+str(batch_id)+'.np.npy') 14 | 15 | if seq: 16 | mapping_batch = np.load(directory+'/position_encoding_'+str(batch_id)+'.np.npy') 17 | else: 18 | mapping_batch = np.load(directory+'/mapping_'+str(batch_id)+'.np.npy') 19 | return target_cost_batch, target_cardinality_batch, operators_batch, extra_infos_batch, condition1s_batch,\ 20 | condition2s_batch, samples_batch, condition_masks_batch, mapping_batch 21 | 22 | 23 | def get_trans_batch_data(batch_id, seq, directory): 24 | target_cost_batch = np.load(directory + '/target_cost_'+ str(batch_id) + '.np.npy') 25 | target_cardinality_batch = np.load(directory + '/target_cardinality_' + str(batch_id) + '.np.npy') 26 | join_order_truth = np.load(directory+'/join_order_'+str(batch_id)+'.np.npy') 27 | trans_target = np.load(directory+'/trans_target_'+str(batch_id)+'.np.npy') 28 | operators_batch = np.load(directory+'/operators_'+str(batch_id)+'.np.npy') 29 | extra_infos_batch = np.load(directory+'/extra_infos_'+str(batch_id)+'.np.npy') 30 | condition1s_batch = np.load(directory+'/condition1s_'+str(batch_id)+'.np.npy') 31 | condition2s_batch = np.load(directory+'/condition2s_'+str(batch_id)+'.np.npy') 32 | samples_batch = np.load(directory+'/samples_'+str(batch_id)+'.np.npy') 33 | condition_masks_batch = np.load(directory+'/condition_masks_'+str(batch_id)+'.np.npy') 34 | leaf_node_marker = np.load(directory+'/leaf_node_marker_'+str(batch_id)+'.np.npy') 35 | res_mask = np.load(directory+'/res_mask_'+str(batch_id)+'.np.npy') 36 | adj_matrix = np.load(directory+'/join_order/adj_matrix_'+str(batch_id)+'.np.npy') 37 | 38 | if seq: 39 | mapping_batch = np.load(directory+'/position_encoding_'+str(batch_id)+'.np.npy') 40 | else: 41 | mapping_batch = np.load(directory+'/mapping_'+str(batch_id)+'.np.npy') 42 | return join_order_truth, target_cost_batch, target_cardinality_batch, operators_batch, extra_infos_batch, \ 43 | condition1s_batch, condition2s_batch, samples_batch, 
condition_masks_batch, mapping_batch, leaf_node_marker, \ 44 | trans_target, res_mask, adj_matrix 45 | 46 | 47 | def get_batch_meta_learner_iterator(db_list, shuffle, seed, suffix, batch_num, test, directory="/mnt/train_data/meta_learner"): 48 | tuples = [] 49 | random.seed(seed) 50 | 51 | for db_id in db_list: 52 | for batch_id in range(batch_num): 53 | tuples.append((db_id, batch_id)) 54 | 55 | if shuffle: 56 | random.shuffle(tuples) 57 | 58 | prefix = "test_data" if test else "train_data" 59 | 60 | for db_id, batch_id in tuples: 61 | ground_truth_batch = np.load(f"{directory}/DB{db_id}/{prefix}{suffix}/ground_truth_{batch_id}.npy", allow_pickle=True) 62 | agg_matrix_batch = np.load(f"{directory}/DB{db_id}/{prefix}{suffix}/agg_matrix_{batch_id}.npy", allow_pickle=True) 63 | attn_mask_batch = np.load(f"{directory}/DB{db_id}/{prefix}{suffix}/attn_mask_{batch_id}.npy", allow_pickle=True) 64 | trans_target_batch = np.load(f"{directory}/DB{db_id}/{prefix}{suffix}/trans_target_{batch_id}.npy", allow_pickle=True) 65 | feature_encoding_batch = np.load(f"{directory}/DB{db_id}/{prefix}{suffix}/feature_encoding_{batch_id}.npy", allow_pickle=True) 66 | res_mask_batch = np.load(f"{directory}/DB{db_id}/{prefix}{suffix}/res_mask_{batch_id}.npy", allow_pickle=True) 67 | adj_matrix_batch = np.load(f"{directory}/DB{db_id}/{prefix}{suffix}/adj_matrix_{batch_id}.npy", allow_pickle=True) 68 | yield (ground_truth_batch, agg_matrix_batch, attn_mask_batch, trans_target_batch, feature_encoding_batch, 69 | res_mask_batch, adj_matrix_batch) --------------------------------------------------------------------------------