├── .gitignore
├── README.md
├── appendix.pdf
├── bounds.py
├── checkPrecisionRecall.py
├── cmd.sh
├── discovery.py
├── extractVectors.py
├── hnsw_cmd.sh
├── hnsw_search.py
├── lsh.py
├── lsh_cmd.sh
├── lsh_search.py
├── naive_search.py
├── notebook
│   └── offline.ipynb
├── plotMetrics.py
├── requirements.txt
├── run_all.py
├── run_pretrain.py
├── run_pretrain_all.py
├── run_tus_all.py
├── sdd
│   ├── __init__.py
│   ├── augment.py
│   ├── baselines.py
│   ├── dataset.py
│   ├── model.py
│   ├── preprocessor.py
│   ├── pretrain.py
│   └── utils.py
├── starmie_overall.jpg
├── test_bounds.py
├── test_hnsw_search.py
├── test_lsh.py
└── test_naive_search.py

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | 
2 | .DS_Store

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Semantics-aware Dataset Discovery from Data Lakes with Contextualized Column-based Representation Learning
2 | 
3 | ![The overall architecture of Starmie](starmie_overall.jpg)
4 | 
5 | ### Requirements
6 | 
7 | * Python 3.7.10
8 | * PyTorch 1.9.0+cu111
9 | * Transformers 4.9.2
10 | * NVIDIA Apex (fp16 training)
11 | 
12 | Install requirements:
13 | ```
14 | pip install -r requirements.txt
15 | ```
16 | 
17 | ### Datasets
18 | 
19 | Datasets for table union search:
20 | * Santos: https://zenodo.org/record/7758091
21 | * TUS: https://github.com/RJMillerLab/table-union-search-benchmark
22 | 
23 | WDC web tables:
24 | * See download instructions [here](https://webdatacommons.org/webtables/) for the 50M relational English tables.
25 | 
26 | Viznet: https://github.com/megagonlabs/sato/tree/master/table_data
27 | 
28 | ### Running the offline pre-training pipeline:
29 | 
30 | The main entry point is `run_pretrain.py`. Example command:
31 | 
32 | ```
33 | CUDA_VISIBLE_DEVICES=0 python run_pretrain.py \
34 |   --task viznet \
35 |   --batch_size 64 \
36 |   --lr 5e-5 \
37 |   --lm roberta \
38 |   --n_epochs 3 \
39 |   --max_len 128 \
40 |   --size 10000 \
41 |   --projector 768 \
42 |   --save_model \
43 |   --augment_op drop_col \
44 |   --fp16 \
45 |   --sample_meth head \
46 |   --table_order column \
47 |   --run_id 0
48 | ```
49 | 
50 | Hyperparameters:
51 | 
52 | * `--task`: the tasks that we currently support include "santos", "santosLarge", "tus", "tusLarge", "large", "small", and "viznet". The tasks "large" and "small" are for column matching and "viznet" is for column clustering.
53 | * `--batch_size`, `--lr`, `--n_epochs`, `--max_len`: the standard batch size, learning rate, number of training epochs, and max sequence length
54 | * `--lm`: the language model (we use roberta for all the experiments)
55 | * `--size`: the maximum number of tables/columns used during pre-training
56 | * `--projector`: the dimension of the projector (768 by default, the same in all experiments)
57 | * `--save_model`: if this flag is on, the model checkpoint will be saved to the directory specified in the `--logdir` flag, such as `"results/viznet/model_drop_col_head_column_0.pt"`
58 | * `--augment_op`: the augmentation operator for contrastive learning (see the sketch after this list). It is one of `["drop_col", "sample_row", "sample_row_ordered", "shuffle_col", "drop_cell", "sample_cells", "replace_cells", "drop_head_cells", "drop_num_cells", "swap_cells", "drop_num_col", "drop_nan_col", "shuffle_row"]`:
59 |     1. Column-level: `drop_col` (drops a random column), `shuffle_col` (shuffles the columns), `drop_num_col` (drops random numeric columns), `drop_nan_col` (drops columns with mostly NaNs)
60 |     2. Row-level: `sample_row` (samples rows), `sample_row_ordered` (samples rows but preserves their order), `shuffle_row` (shuffles the order of the rows)
61 |     3. Cell-level: `drop_cell` (drops a random cell), `sample_cells` (samples cells), `replace_cells` (samples random cells and replaces them with the first ordered cells), `drop_head_cells` (drops the first quarter of cells), `drop_num_cells` (drops a sample of numeric cells), `swap_cells` (swaps two cells)
62 | * `--sample_meth`: the table pre-processing operator that preserves order and de-duplicates. It is one of `["head", "alphaHead", "random", "constant", "frequent", "tfidf_token", "tfidf_entity", "tfidf_row", "pmi"]`:
63 |     1. Row-level: `tfidf_row` (takes the rows with the highest average tfidf scores), `pmi` (takes the rows with the highest PMI of column pairs with the topic column)
64 |     2. Entity-level: `tfidf_entity` (takes the entities in a column with the highest average tfidf scores over their tokens)
65 |     3. Token-level: `head` (takes the first N tokens), `alphaHead` (takes the first N sorted tokens), `random` (randomly samples tokens), `constant` (takes every Nth token), `frequent` (takes the most frequently-occurring tokens), `tfidf_token` (takes the tokens with the highest tfidf scores)
66 | * `--fp16`: half-precision training (always turn this on)
67 | * `--table_order`: row or column order for pre-processing, "row" or "column"
68 | * `--single_column`: if this flag is on, then it will run the single-column variant ignoring all the
69 | table context
70 | * `--mlflow_tag`: use this flag to assign any additional tags for mlflow logging
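For intuition, here is a minimal pandas sketch of what two of the column-level augmentation operators do. This is illustrative only; the function names mirror the operator names, and the actual implementations live in `sdd/augment.py`:

```python
# Illustrative sketches of two column-level operators; the real
# implementations are in sdd/augment.py.
import random
import pandas as pd

def drop_col(table: pd.DataFrame) -> pd.DataFrame:
    """Drop one randomly chosen column (cf. the 'drop_col' operator)."""
    col = random.choice(list(table.columns))
    return table.drop(columns=[col])

def shuffle_col(table: pd.DataFrame) -> pd.DataFrame:
    """Shuffle the column order (cf. the 'shuffle_col' operator)."""
    cols = list(table.columns)
    random.shuffle(cols)
    return table[cols]
```

Contrastive learning then treats the original table and its augmented view as a positive pair.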
71 | 
72 | ### Model Inference:
73 | Run `extractVectors.py`. Example command:
74 | 
75 | ```
76 | python extractVectors.py \
77 |   --benchmark santos \
78 |   --table_order column \
79 |   --run_id 0
80 | ```
81 | 
82 | Hyperparameters:
83 | * `--benchmark`: the current benchmark for the experiment. Examples include `santos`, `santosLarge`, `tus`, `tusLarge`, `wdc`
84 | * `--single_column`: if this flag is on, then it will retrieve the single-column variant
85 | * `--run_id`: the run_id of the job (we use 0 for all experiments)
86 | * `--table_order`: column-ordered or row-ordered (always use `column`)
87 | * `--save_model`: whether to save the vectors in a pickle file, which is then used in the online processing
88 | 
89 | 
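With `--save_model` on, the vectors are written as a pickled list of `(table_filename, column_vectors)` pairs (see `extractVectors.py`). A quick sanity check of the output of the santos run above; the file name encodes the benchmark, augmentation operator, sampling method, table order, and run id:

```python
import pickle

# Path produced by extractVectors.py for the santos data lake with its
# default operators (drop_col / tfidf_entity); adjust it to your own run.
path = "data/santos/vectors/cl_datalake_drop_col_tfidf_entity_column_0.pkl"
with open(path, "rb") as f:
    vectors = pickle.load(f)

table_name, column_vectors = vectors[0]
print(table_name, column_vectors.shape)  # one embedding per column
```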
90 | ### Online processing
91 | 
92 | 1. Linear & Bounds: Run `test_naive_search.py`. Some scripts are in `tus_cmd.sh` and `run_tus_all.py` (for slurm scheduling). Example command:
93 | 
94 | ```
95 | python test_naive_search.py \
96 |   --encoder cl \
97 |   --benchmark santos \
98 |   --augment_op drop_col \
99 |   --sample_meth tfidf_entity \
100 |   --matching linear \
101 |   --table_order column \
102 |   --run_id 0 \
103 |   --K 10 \
104 |   --threshold 0.7
105 | ```
106 | 
107 | Hyperparameters:
108 | * `--encoder`: choice of encoder. Options include "cl" (this is for both full Starmie and the
109 | singleCol baseline), "sato", "sherlock"
110 | * `--benchmark`: choice of benchmark for the data lake. Options include "santos", "santosLarge",
111 | "tus", "tusLarge", "wdc"
112 | * `--augment_op`: choice of augmentation operator
113 | * `--sample_meth`: choice of sampling method
114 | * `--matching`: "linear" matching (full, sketched below) or "bounds". If you would like to run "greedy", add the
115 | function call to the code
116 | * `--table_order`: "column" or "row" (just use column)
117 | * `--run_id`: always 0
118 | * `--single_column`: when set to True, run the single-column baseline
119 | * `--K`: the K to use for the top-K results
120 | * `--threshold`: the similarity threshold
121 | 
122 | FOR ERROR ANALYSIS: `bucket` (bucket number between 0 and 5), `analysis` (either "col" for the number of columns, "row" for the number of rows, or "numeric" for the percentage of numerical columns)
123 | 
124 | FOR SCALABILITY EXPERIMENTS: `scal` (the fraction of the data lake to compute the metric scores on – 0.2, 0.4, 0.6, 0.8, or 1.0)
125 | 
126 | 
127 | 
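For reference, the score that "linear" matching assigns to a query/candidate table pair is the maximum-weight bipartite matching between their column embeddings (see `verify()` in `bounds.py` for the actual implementation). A condensed sketch of the same computation:

```python
# Condensed sketch of the "linear" matcher (cf. verify() in bounds.py).
import numpy as np
from munkres import Munkres, make_cost_matrix

def table_score(query_cols, cand_cols, threshold=0.7):
    """Maximum-weight bipartite matching between two sets of column vectors."""
    graph = np.zeros((len(query_cols), len(cand_cols)))
    for i, q in enumerate(query_cols):
        for j, c in enumerate(cand_cols):
            sim = np.dot(q, c) / (np.linalg.norm(q) * np.linalg.norm(c))
            if sim > threshold:  # keep only sufficiently similar column pairs
                graph[i, j] = sim
    # Munkres minimizes cost, so invert the similarities before matching.
    cost = make_cost_matrix(graph, lambda sim: graph.max() - sim)
    return sum(graph[r, c] for r, c in Munkres().compute(cost))
```

"bounds" avoids running this full matching on every candidate: it first computes cheap lower and upper bounds on the matching score (see `bounds.py`) and only verifies tables whose bounds are inconclusive.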
128 | 2. LSH: Run `test_lsh.py` (example script: `lsh_cmd.sh`). Example command:
129 | 
130 | ```
131 | python test_lsh.py \
132 |   --encoder cl \
133 |   --benchmark santosLarge \
134 |   --run_id 0 \
135 |   --num_func 8 \
136 |   --num_table 100 \
137 |   --K 60 \
138 |   --scal 1.0
139 | ```
140 | 
141 | Hyperparameters:
142 | * `--encoder`: choice of encoder. Options include "cl" (this is for both full Starmie and the
143 | singleCol baseline), "sato", "sherlock"
144 | * `--benchmark`: choice of benchmark for the data lake. Options include "santos", "santosLarge",
145 | "tus", "tusLarge", "wdc"
146 | * `--run_id`: always 0
147 | * `--single_column`: when set to True, run the single-column baseline
148 | * `--num_func`: number of hash functions (always use 8 for the "cl" encoder)
149 | * `--num_table`: number of LSH hash tables (always use 100 for the "cl" encoder)
150 | * `--K`: the K to use for the top-K results
151 | 
152 | FOR SCALABILITY EXPERIMENTS: `scal` (the fraction of the data lake to compute
153 | the metric scores on – 0.2, 0.4, 0.6, 0.8, or 1.0)
154 | 
155 | 3. HNSW: Run `test_hnsw_search.py` (example script: `hnsw_cmd.sh`).
156 | Example command:
157 | ```
158 | python test_hnsw_search.py \
159 |   --encoder cl \
160 |   --benchmark santosLarge \
161 |   --run_id 0 \
162 |   --K 60 \
163 |   --scal 1.0
164 | ```
165 | 
166 | Hyperparameters:
167 | 
168 | * `--encoder`: choice of encoder. Options include "cl" (this is for both full Starmie and the
169 | singleCol baseline), "sato", "sherlock"
170 | * `--benchmark`: choice of benchmark for the data lake. Options include "santos", "santosLarge",
171 | "tus", "tusLarge", "wdc"
172 | * `--run_id`: always 0
173 | * `--single_column`: when set to True, run the single-column baseline
174 | * `--K`: the K to use for the top-K results
175 | 
176 | FOR SCALABILITY EXPERIMENTS: `scal` (the fraction of the data lake to compute
177 | the metric scores on – 0.2, 0.4, 0.6, 0.8, or 1.0)
178 | 
179 | 
180 | 
181 | ## Data discovery for ML tasks:
182 | 
183 | Run `discovery.py`. We assume:
184 | 1. A model checkpoint in `results/viznet/model_drop_col_head_column_0.pt`
185 | 2. The viznet dataset in `data/viznet/`
186 | 
187 | Run the script with:
188 | ```
189 | python discovery.py
190 | ```
191 | The code will print out the MSE for NoJoin, contrastive learning (CL), Jaccard, and Overlap. The joined tables will be output to pickled files named `none_joined_tables.pkl`, `cl_joined_tables.pkl`, `jaccard_joined_tables.pkl`, and `overlap_joined_tables.pkl`.
192 | 
193 | ### Column clustering:
194 | 
195 | See Line 273 and Line 128 of the file `sdd/pretrain.py`.
196 | To run column clustering, you can run a sequence of commands (remember to check the file paths):
197 | 
198 | ```
199 | CUDA_VISIBLE_DEVICES=7 python run_pretrain.py \
200 |   --task viznet \
201 |   --batch_size 64 \
202 |   --lr 5e-5 \
203 |   --lm roberta \
204 |   --n_epochs 3 \
205 |   --max_len 128 \
206 |   --size 10000 \
207 |   --projector 768 \
208 |   --save_model \
209 |   --augment_op drop_col \
210 |   --fp16 \
211 |   --sample_meth head \
212 |   --table_order column \
213 |   --run_id 0
214 | ```
215 | 
216 | Copy the clustering results:
217 | ```
218 | cp *.pkl data/viznet/multi_column
219 | ```
220 | 
221 | Each run will pre-train the models on 10k viznet tables and cluster all the columns. The clustering results will be stored at `data/viznet/multi_column/clusters.pkl` and `data/viznet/single_column/`.
222 | 
223 | To view the clusters, you can use the jupyter notebook in `notebook/offline.ipynb`. Running the last cell should print out some clusters like:
224 | 
225 | ```
226 | artist ---- 1. I Don't Give A ...; 2. I'm The Kinda; 3. I U She; 4. Kick It [featuring Iggy Pop]; 5.
227 | Operate
228 | artist ---- 1. Spoken Intro; 2. The Court; 3. Maze; 4. Girl Talk; 5. A La Mode
229 | artist ---- 1. Street Fighting Man; 2. Gimme Shelter; 3. (I Can't Get No) Satisfaction; 4. The
230 | Last Time; 5. Jumpin' Jack Flash
231 | …
232 | ---------------------------------
233 | type ---- Emerson Elementary School; Banneker Elementary School; Silver City Elementary
234 | School; New Stanley Elementary School; Frances Willard Elementary School
235 | type ---- Choctawhatchee Senior High School; Fort Walton Beach High School; Ami Kids
236 | Emerald Coast; Gulf Coast Christian School; Adolescent Substance Abuse
237 | city ---- Chilton; Stoughton
238 | …
239 | ---------------------------------
240 | description ---- Fri Sep 11,2015 3:30 PM (CST); Fri Sep 11,2015 6:00 PM (CST); Sat Sep
241 | 12,2015 10:00 AM (CST); Sat Sep 12,2015 12:00 PM (CST); Sat Sep 12,2015 5:30 PM (CST)
242 | day ---- Sept. 1; Sept. 7; Sept. 22; Sept. 29; Oct. 5
243 | description ---- Fri Sep 11,2015 3:30 PM (CST); Fri Sep 11,2015 6:00 PM (CST); Sat Sep
244 | 12,2015 10:00 AM (CST); Sat Sep 12,2015 12:00 PM (CST); Sat Sep 12,2015 5:30 PM (CST)
245 | ...
246 | address ---- 1721 Papillon St, North Port FL; 4113 Wabasso Ave, North Port FL; 3681
247 | Wayward Ave, North Port FL; 1118 N Salford Blvd, North Port FL; 2057 Bendix Ter, North
248 | Port FL
249 | address ---- 5 Brand Rd, Toms River NJ; 40 12th St, Toms River NJ; 75 Sea Breeze Rd,
250 | Toms River NJ; 98 Oak Tree Ln, Toms River NJ; 67 16th St, Toms River NJ
251 | address ---- 652 Martha St, Montgomery AL; 3184 Lexington Rd, Montgomery AL; 120 S
252 | Lewis St, Montgomery AL; 1812 W 2nd St #OP, Montgomery AL; 3582 Southview Ave,
253 | Montgomery AL
254 | ---------------------------------
255 | ```
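A minimal way to peek at the saved results from a script rather than the notebook. This sketch assumes only that the pickle loads cleanly; the exact structure of `clusters.pkl` is whatever `sdd/pretrain.py` wrote, so inspect it before relying on a schema:

```python
import pickle

with open("data/viznet/multi_column/clusters.pkl", "rb") as f:
    clusters = pickle.load(f)

# Inspect the container before assuming a schema.
print(type(clusters), len(clusters))
```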
256 | ### Citation
257 | If you are using the code in this repo, please cite the following in your work:
258 | ```
259 | @article{DBLP:journals/pvldb/FanWLZM23,
260 |   author    = {Grace Fan and
261 |                Jin Wang and
262 |                Yuliang Li and
263 |                Dan Zhang and
264 |                Ren{\'{e}}e J. Miller},
265 |   title     = {Semantics-aware Dataset Discovery from Data Lakes with Contextualized
266 |                Column-based Representation Learning},
267 |   journal   = {Proc. {VLDB} Endow.},
268 |   volume    = {16},
269 |   number    = {7},
270 |   pages     = {1726--1739},
271 |   year      = {2023}
272 | }
273 | ```
274 | 
275 | ## Disclosure
276 | 
277 | Embedded in, or bundled with, this product are open source software (OSS) components, datasets and other third party components identified below. The license terms respectively governing the datasets and third-party components continue to govern those portions, and you agree to those license terms, which, when applicable, specifically limit any distribution. You may receive a copy of, distribute and/or modify any open source code for the OSS component under the terms of their respective licenses. In the event of conflicts between Megagon Labs, Inc., Recruit Co., Ltd., license conditions and the Open Source Software license conditions, the Open Source Software conditions shall prevail with respect to the Open Source Software portions of the software.
278 | You agree not to, and are not permitted to, distribute actual datasets used with the OSS components listed below. You agree and are limited to distribute only links to datasets from known sources by listing them in the datasets overview table below. You are permitted to distribute derived datasets of data sets from known sources by including links to original dataset source in the datasets overview table below. You agree that any right to modify datasets originating from parties other than Megagon Labs, Inc. are governed by the respective third party's license conditions.
279 | All OSS components and datasets are distributed WITHOUT ANY WARRANTY, without even implied warranty such as for MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE, and without any liability to or claim against any Megagon Labs, Inc. entity other than as explicitly documented in this README document. You agree to cease using any part of the provided materials if you do not agree with the terms or the lack of any warranty herein.
280 | While Megagon Labs, Inc., makes commercially reasonable efforts to ensure that citations in this document are complete and accurate, errors may occur. If you see any error or omission, please help us improve this document by sending information to contact_oss@megagon.ai.
281 | 
282 | ## Contact
283 | 
284 | If you have any questions regarding the code and the paper, please directly contact Grace Fan (fan.gr@northeastern.edu).
285 | 
286 | 

--------------------------------------------------------------------------------
/appendix.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/megagonlabs/starmie/5eb90fe27fb1162d2a62b555ac54908ee8e4c474/appendix.pdf

--------------------------------------------------------------------------------
/bounds.py:
--------------------------------------------------------------------------------
1 | 
2 | import numpy as np
3 | import random
4 | import os
5 | 
6 | from munkres import Munkres, make_cost_matrix, DISALLOWED
7 | from numpy.linalg import norm
8 | 
9 | 
10 | def cosine_sim(vec1, vec2):
11 |     assert vec1.ndim == vec2.ndim
12 |     return np.dot(vec1, vec2) / (norm(vec1)*norm(vec2))
13 | 
14 | 
15 | def verify(table1, table2, threshold=0.6):
16 |     score = 0.0
17 |     nrow = len(table1)
18 |     ncol = len(table2)
19 |     graph = np.zeros(shape=(nrow,ncol), dtype=float)
20 |     for i in range(nrow):
21 |         for j in range(ncol):
22 |             sim = cosine_sim(table1[i], table2[j])
23 |             if sim > threshold:
24 |                 graph[i,j] = sim
25 |     max_graph = make_cost_matrix(graph, lambda cost: (graph.max() - cost) if (cost != DISALLOWED) else DISALLOWED)
26 |     m = Munkres()
27 |     indexes = m.compute(max_graph)
28 |     for row, col in indexes:
29 |         score += graph[row,col]
30 |     return score
31 | 
32 | def upper_bound_bm(edges, nodes1, nodes2):
33 |     '''
34 |     Calculate the upper bound of the bipartite matching score
35 |     Args:
36 |         edges: list of (similarity, node1, node2) edges sorted by descending similarity, as produced by get_edges()
37 |         nodes1/nodes2: the sets of nodes on each side of the bipartite graph (query columns / data-lake-table columns)
38 |     Return:
39 |         The upper bound of the bipartite matching score (no smaller than the true score)
40 |     '''
41 |     score = 0.0
42 |     for e in edges:
43 |         score += e[0]
44 |         nodes1.discard(e[1])
45 |         nodes2.discard(e[2])
46 |         if len(nodes1) == 0 or len(nodes2) == 0:
47 |             return score
48 |     return score
49 | 
50 | def lower_bound_bm(edges, nodes1, nodes2):
51 |     '''
52 |     Output the lower bound of the bipartite matching score (no larger than the true score)
53 |     '''
54 |     score = 0.0
55 |     for e in edges:
56 |         if e[1] in nodes1 and e[2] in nodes2:
57 |             score += e[0]
58 |             nodes1.discard(e[1])
59 |             nodes2.discard(e[2])
60 |             if len(nodes1) == 0 or len(nodes2) == 0:
61 |                 return score
62 |     return score
63 | 
64 | 
65 | def get_edges(table1, table2, threshold):
66 |     '''
67 |     Generate the similarity graph used by lower bounds and upper bounds
68 |     Args:
69 |         table1 (numpy array): the vectors of the query (# rows: # columns in a table, # cols: dimension of embedding)
70 |         table2 (numpy array): similar to table1, set of column vectors of the data lake table
71 |         threshold (float): minimum cosine similarity to include an edge
72 |     Return:
73 |         list of edges and sets of nodes used in lower and upper bounds calculations
74 |     '''
75 |     nrow = len(table1)
76 |     ncol = len(table2)
77 |     edges = []
78 |     nodes1 = set()
79 |     nodes2 = set()
80 |     for i in range(nrow):
81 |         for j in range(ncol):
82 |             sim = cosine_sim(table1[i], table2[j])
83 |             if sim > threshold:
84 |                 edges.append((sim,i,j))
85 |                 nodes1.add(i)
86 |                 nodes2.add(j)
87 |     edges.sort(reverse=True)
88 |     return edges, nodes1, nodes2
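# Usage sketch (illustrative, not part of the original file): how these
# helpers fit together in the search loop -- compute the cheap bounds first
# and only run the expensive verify() when they are inconclusive
# (cf. NaiveSearcher.topk_bounds in naive_search.py). Here kth_best_score
# is a hypothetical name for the current K-th best score in the top-K heap:
#
#     edges, nodes1, nodes2 = get_edges(query_cols, table_cols, threshold)
#     lb = lower_bound_bm(edges, set(nodes1), set(nodes2))
#     ub = upper_bound_bm(edges, set(nodes1), set(nodes2))
#     if lb > kth_best_score or ub >= kth_best_score:
#         score = verify(query_cols, table_cols, threshold)  # exact matching
#     # else: ub < kth_best_score, so the table can be pruned without matching
#
# (set(...) copies are passed because the bound functions mutate the node sets)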
--------------------------------------------------------------------------------
/checkPrecisionRecall.py:
--------------------------------------------------------------------------------
1 | import pickle
2 | import pickle5 as p
3 | import pandas as pd
4 | from matplotlib import *
5 | from matplotlib import pyplot as plt
6 | import numpy as np
7 | import mlflow
8 | 
9 | def loadDictionaryFromPickleFile(dictionaryPath):
10 |     ''' Load the pickle file as a dictionary
11 |     Args:
12 |         dictionaryPath: path to the pickle file
13 |     Return: dictionary from the pickle file
14 |     '''
15 |     filePointer = open(dictionaryPath, 'rb')
16 |     dictionary = p.load(filePointer)
17 |     filePointer.close()
18 |     return dictionary
19 | 
20 | def saveDictionaryAsPickleFile(dictionary, dictionaryPath):
21 |     ''' Save dictionary as a pickle file
22 |     Args:
23 |         dictionary: the dictionary to be saved
24 |         dictionaryPath: filepath to which the dictionary will be saved
25 |     '''
26 |     filePointer = open(dictionaryPath, 'wb')
27 |     pickle.dump(dictionary, filePointer, protocol=pickle.HIGHEST_PROTOCOL)
28 |     filePointer.close()
29 | 
30 | 
31 | def calcMetrics(max_k, k_range, resultFile, gtPath=None, resPath=None, record=True):
32 |     ''' Calculate and log the performance metrics: MAP, Precision@k, Recall@k
33 |     Args:
34 |         max_k: the maximum K value (e.g. for the SANTOS benchmark, max_k = 10; for the TUS benchmark, max_k = 60)
35 |         k_range: step size for the K's up to max_k
36 |         resultFile: dictionary mapping each query table to its ranked list of result tables
37 |         gtPath: file path to the groundtruth; resPath: file path to the raw results (unused: resultFile is passed in directly)
38 |         record (boolean): whether to log in MLFlow
39 |     Return: MAP, P@K, R@K
40 |     '''
41 |     groundtruth = loadDictionaryFromPickleFile(gtPath)
42 |     # resultFile = loadDictionaryFromPickleFile(resPath)
43 | 
44 |     # =============================================================================
45 |     # Precision and recall
46 |     # =============================================================================
47 |     precision_array = []
48 |     recall_array = []
49 |     for k in range(1, max_k+1):
50 |         true_positive = 0
51 |         false_positive = 0
52 |         false_negative = 0
53 |         rec = 0
54 |         ideal_recall = []
55 |         for table in resultFile:
56 |             # t28 tables have fewer than 60 results, so skip them in the analysis.
57 | if table.split("____",1)[0] != "t_28dc8f7610402ea7": 58 | if table in groundtruth: 59 | groundtruth_set = set(groundtruth[table]) 60 | groundtruth_set = {x.split(".")[0] for x in groundtruth_set} 61 | result_set = resultFile[table][:k] 62 | result_set = [x.split(".")[0] for x in result_set] 63 | # find_intersection = true positives 64 | find_intersection = set(result_set).intersection(groundtruth_set) 65 | tp = len(find_intersection) 66 | fp = k - tp 67 | fn = len(groundtruth_set) - tp 68 | if len(groundtruth_set)>=k: 69 | true_positive += tp 70 | false_positive += fp 71 | false_negative += fn 72 | rec += tp / (tp+fn) 73 | ideal_recall.append(k/len(groundtruth[table])) 74 | precision = true_positive / (true_positive + false_positive) 75 | recall = rec/len(resultFile) 76 | precision_array.append(precision) 77 | recall_array.append(recall) 78 | if k % 10 == 0: 79 | print(k, "IDEAL RECALL:", sum(ideal_recall)/len(ideal_recall)) 80 | used_k = [k_range] 81 | if max_k >k_range: 82 | for i in range(k_range * 2, max_k+1, k_range): 83 | used_k.append(i) 84 | print("--------------------------") 85 | for k in used_k: 86 | print("Precision at k = ",k,"=", precision_array[k-1]) 87 | print("Recall at k = ",k,"=", recall_array[k-1]) 88 | print("--------------------------") 89 | 90 | map_sum = 0 91 | for k in range(0, max_k): 92 | map_sum += precision_array[k] 93 | mean_avg_pr = map_sum/max_k 94 | print("The mean average precision is:", mean_avg_pr) 95 | 96 | # logging to mlflow 97 | if record: # if the user would like to log to MLFlow 98 | mlflow.log_metric("mean_avg_precision", mean_avg_pr) 99 | mlflow.log_metric("prec_k", precision_array[max_k-1]) 100 | mlflow.log_metric("recall_k", recall_array[max_k-1]) 101 | 102 | return mean_avg_pr, precision_array[max_k-1], recall_array[max_k-1] -------------------------------------------------------------------------------- /cmd.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=7 python run_pretrain.py \ 2 | --task small \ 3 | --batch_size 64 \ 4 | --lr 5e-5 \ 5 | --lm roberta \ 6 | --n_epochs 3 \ 7 | --max_len 128 \ 8 | --size 10000 \ 9 | --projector 768 \ 10 | --save_model \ 11 | --augment_op drop_col \ 12 | --fp16 \ 13 | --sample_meth head \ 14 | --table_order column \ 15 | --run_id 0 16 | -------------------------------------------------------------------------------- /discovery.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import numpy as np 3 | import pickle 4 | import torch 5 | import os 6 | 7 | from sentence_transformers import SentenceTransformer 8 | from xgboost import XGBRegressor 9 | from sklearn.metrics import mean_squared_error 10 | from tqdm import tqdm 11 | 12 | def clean_table(table, target='Rating'): 13 | """Clean an input table. 14 | """ 15 | if target not in table: 16 | return table 17 | 18 | new_vals = [] 19 | for val in table[target]: 20 | try: 21 | if isinstance(val, str): 22 | val = val.replace(',', '').replace('%', '') 23 | new_vals.append(float(val)) 24 | except: 25 | new_vals.append(float('nan')) 26 | 27 | table[target] = new_vals 28 | return table.dropna(subset=[target]) 29 | 30 | 31 | lm = SentenceTransformer('paraphrase-MiniLM-L6-v2').to('cuda') 32 | lm.eval() 33 | 34 | def featurize(table, target='Rating'): 35 | """Featurize a query table. 
36 | """ 37 | all_vectors = [] 38 | for column in table: 39 | if column == target: 40 | continue 41 | if table[column].dtype.kind in 'if': 42 | all_vectors.append(np.expand_dims(table[column], axis=1)) 43 | else: 44 | with torch.no_grad(): 45 | vectors = lm.encode(list(table[column].astype('str'))) 46 | all_vectors.append(vectors) 47 | 48 | return np.concatenate(all_vectors, axis=1), np.array(table[target]) 49 | 50 | def process_query_tables(query_tables): 51 | """Run ML on a dictionary of query tables. 52 | """ 53 | for table in query_tables.values(): 54 | N = len(table) 55 | table['Rating'] = (table['Rating'] - table['Rating'].min()) / (table['Rating'].max() - table['Rating'].min() + 1e-6) 56 | # table['Rating'] = table['Rating'] / (table['Rating'].max() + 1e-6) 57 | table = table.sample(frac=1.0, random_state=42) 58 | train = table[:N//5*4] 59 | test = table[N//5*4:] 60 | 61 | x, y = featurize(train) 62 | model = XGBRegressor() 63 | model.fit(x, y) 64 | 65 | x, y = featurize(test) 66 | y_pred = model.predict(x) 67 | # print(len(table), mean_squared_error(y, y_pred)) 68 | print(mean_squared_error(y, y_pred)) 69 | 70 | 71 | def check_table_pair(table_a, vectors_a, table_b, vectors_b, method='naive', target='Rating'): 72 | """Check if two tables are joinable. Return the join result and the similarity score 73 | """ 74 | best_pair = None 75 | max_score = -1 76 | target_sim = 0.0 77 | 78 | for col_a, vec_a in zip(table_a, vectors_a): 79 | norm_vec_a = np.linalg.norm(vec_a) 80 | if col_a == target: 81 | if method == 'cl': 82 | for col_b, vec_b in zip(table_b, vectors_b): 83 | if table_a[col_a].dtype != table_b[col_b].dtype: 84 | continue 85 | sim = np.dot(vec_a, vec_b) / norm_vec_a / np.linalg.norm(vec_b) 86 | # if sim > 0: 87 | target_sim += sim 88 | else: 89 | continue 90 | seta = set(table_a[col_a].unique()) 91 | 92 | for col_b, vec_b in zip(table_b, vectors_b): 93 | if table_a[col_a].dtype != table_b[col_b].dtype: 94 | continue 95 | setb = set(table_b[col_b].unique()) 96 | if method == 'jaccard': 97 | score = len(seta.intersection(setb)) / len(seta.union(setb)) 98 | elif method == 'cl': 99 | overlap = len(seta.intersection(setb)) # / len(seta.union(setb)) 100 | score = float(overlap) * (1.0 + np.dot(vec_a, vec_b) / norm_vec_a / np.linalg.norm(vec_b)) 101 | elif method == 'overlap': 102 | score = len(seta.intersection(setb)) / len(seta) 103 | else: 104 | score = 0.0 105 | 106 | if score > max_score: 107 | max_score = score 108 | best_pair = col_a, col_b 109 | 110 | if target_sim > 0: 111 | max_score *= target_sim 112 | 113 | return best_pair, max_score 114 | 115 | 116 | if __name__ == '__main__': 117 | # step 1: load columns and vectors 118 | viznet_columns = pd.read_csv('data/viznet/test.csv.full') 119 | 120 | # step 2: select data lake tables 121 | if os.path.exists('datalake_tables.pkl'): 122 | tables, table_vectors = pickle.load(open('datalake_tables.pkl', 'rb')) 123 | else: 124 | tables = {} 125 | for table_id in tqdm(viznet_columns['table_id'], total=len(viznet_columns)): 126 | # get table length 127 | table = pd.read_csv('data/viznet/tables/table_%d.csv' % table_id) 128 | if len(table) >= 50: 129 | tables[table_id] = table 130 | 131 | from sdd.pretrain import load_checkpoint, inference_on_tables 132 | table_vectors = {} 133 | ckpt_path = "results/viznet/model_drop_col_head_column_0.pt" 134 | # ckpt_path = "results/viznet/model_drop_col_head_128_0.pt" 135 | ckpt = torch.load(ckpt_path) 136 | table_model, table_dataset = load_checkpoint(ckpt) 137 | all_tables = list(tables.values()) 
138 | vectors = inference_on_tables(all_tables, table_model, table_dataset) 139 | for tid, v in zip(tables, vectors): 140 | table_vectors[tid] = v 141 | 142 | pickle.dump((tables, table_vectors), open('datalake_tables.pkl', 'wb')) 143 | 144 | # step 3: select query tables 145 | query_tables = {} 146 | total_rows = 0 147 | for tid, table in tables.items(): 148 | if 'Rating' in table: 149 | table = clean_table(table) 150 | if len(table) >= 200: 151 | query_tables[tid] = table 152 | total_rows += len(table) 153 | 154 | # step 4: run each data discovery method 155 | for method in ['cl']: 156 | # for method in ['none', 'cl', 'jaccard', 'overlap']: 157 | result_tables = {} 158 | 159 | for tid_a in tqdm(query_tables): 160 | best_table = query_tables[tid_a] 161 | if method == 'none': 162 | result_tables[tid_a] = best_table 163 | continue 164 | 165 | best_similarity = -1.0 166 | best_pair = None 167 | table_a = tables[tid_a] 168 | vectors_a = table_vectors[tid_a] 169 | 170 | for tid_b in tables: 171 | if tid_b in query_tables: 172 | continue 173 | if tid_a != tid_b: 174 | table_b = tables[tid_b] 175 | vectors_b = table_vectors[tid_b] 176 | 177 | res, similarity = check_table_pair(table_a, vectors_a, 178 | table_b, vectors_b, method=method) 179 | if res is not None and similarity > best_similarity: 180 | best_similarity = similarity 181 | best_table = table_b 182 | best_pair = res 183 | 184 | if best_similarity >= 0: 185 | table_b_tmp = best_table.drop_duplicates(subset=[best_pair[1]]).set_index(best_pair[1]) 186 | best_table = table_a.join(table_b_tmp, on=best_pair[0], rsuffix='_r') 187 | else: 188 | best_table = table_a 189 | result_tables[tid_a] = best_table 190 | # result_tables.append(best_table) 191 | pickle.dump(result_tables, open('%s_joined_tables.pkl' % method, 'wb')) 192 | process_query_tables(result_tables) 193 | -------------------------------------------------------------------------------- /extractVectors.py: -------------------------------------------------------------------------------- 1 | from sdd.pretrain import load_checkpoint, inference_on_tables 2 | import torch 3 | import pandas as pd 4 | import numpy as np 5 | import glob 6 | import pickle 7 | import time 8 | import sys 9 | import argparse 10 | from tqdm import tqdm 11 | 12 | def extractVectors(dfs, dataFolder, augment, sample, table_order, run_id, singleCol=False): 13 | ''' Get model inference on tables 14 | Args: 15 | dfs (list of DataFrames): tables to get model inference on 16 | dataFolder (str): benchmark folder name 17 | augment (str): augmentation operator used in vector file path (e.g. 'drop_cell') 18 | sample (str): sampling method used in vector file path (e.g. 
'head') 19 | table_order (str): 'column' or 'row' ordered 20 | run_id (int): used in file path 21 | singleCol (boolean): is this for single column baseline 22 | Return: 23 | list of features for the dataframe 24 | ''' 25 | if singleCol: 26 | model_path = "results/%s/model_%s_%s_%s_%dsingleCol.pt" % (dataFolder, augment, sample, table_order,run_id) 27 | else: 28 | model_path = "results/%s/model_%s_%s_%s_%d.pt" % (dataFolder, augment, sample, table_order,run_id) 29 | ckpt = torch.load(model_path, map_location=torch.device('cuda')) 30 | # load_checkpoint from sdd/pretain 31 | model, trainset = load_checkpoint(ckpt) 32 | return inference_on_tables(dfs, model, trainset, batch_size=1024) 33 | 34 | def get_df(dataFolder): 35 | ''' Get the DataFrames of each table in a folder 36 | Args: 37 | dataFolder: filepath to the folder with all tables 38 | Return: 39 | dataDfs (dict): key is the filename, value is the dataframe of that table 40 | ''' 41 | dataFiles = glob.glob(dataFolder+"/*.csv") 42 | dataDFs = {} 43 | for file in dataFiles: 44 | df = pd.read_csv(file,lineterminator='\n') 45 | if len(df) > 1000: 46 | # get first 1000 rows 47 | df = df.head(1000) 48 | filename = file.split("/")[-1] 49 | dataDFs[filename] = df 50 | return dataDFs 51 | 52 | 53 | if __name__ == '__main__': 54 | ''' Get the model features by calling model inference from sdd/pretrain 55 | ''' 56 | parser = argparse.ArgumentParser() 57 | parser.add_argument("--benchmark", type=str, default="santos") # can be 'santos', 'santosLarge', 'tus', 'tusLarge', 'wdc' 58 | # single-column mode without table context 59 | parser.add_argument("--single_column", dest="single_column", action="store_true") 60 | parser.add_argument("--run_id", type=int, default=0) 61 | parser.add_argument("--table_order", type=str, default='column') 62 | parser.add_argument("--save_model", dest="save_model", action="store_true") 63 | 64 | hp = parser.parse_args() 65 | 66 | # START PARAMETER: defining the benchmark (dataFolder), if it is a single column baseline, 67 | # run_id, table_order, and augmentation operators and sampling method if they are different from default 68 | dataFolder = hp.benchmark 69 | isSingleCol = hp.single_column 70 | if 'santos' in dataFolder or dataFolder == 'wdc': 71 | ao = 'drop_col' 72 | sm = 'tfidf_entity' 73 | if isSingleCol: 74 | ao = 'drop_cell' 75 | elif dataFolder == 'tus': 76 | ao = 'drop_cell' 77 | sm = 'alphaHead' 78 | else: # dataFolder = tusLarge 79 | ao = 'drop_cell' 80 | sm = 'tfidf_entity' 81 | 82 | run_id = hp.run_id 83 | table_order = hp.table_order 84 | # END PARAMETER 85 | 86 | # Change the data paths to where the benchmarks are stored 87 | if dataFolder == 'santos': 88 | DATAPATH = "data/santos/" 89 | dataDir = ['query', 'datalake'] 90 | elif dataFolder == 'santosLarge': 91 | DATAPATH = 'data/santos-benchmark/real-benchmark/' 92 | dataDir = ['query', 'datalake'] 93 | elif dataFolder == 'tus': 94 | DATAPATH = 'data/table-union-search-benchmark/small/' 95 | dataDir = ['santos-query', 'benchmark'] 96 | elif dataFolder == 'tusLarge': 97 | DATAPATH = 'data/table-union-search-benchmark/large/' 98 | dataDir = ['query', 'benchmark'] 99 | elif dataFolder == 'wdc': 100 | DATAPATH = {'query': 'data/wdc/query', 'benchmark': 'data/wdc/0/'} 101 | dataDir = ['query', 'benchmark'] 102 | 103 | inference_times = 0 104 | # dataDir is the query and data lake 105 | for dir in dataDir: 106 | print("//==== ", dir) 107 | if dataFolder == 'wdc': 108 | DATAFOLDER = DATAPATH[dir] 109 | else: 110 | DATAFOLDER = DATAPATH+dir 111 | dfs = 
get_df(DATAFOLDER) 112 | print("num dfs:",len(dfs)) 113 | 114 | dataEmbeds = [] 115 | dfs_totalCount = len(dfs) 116 | dfs_count = 0 117 | 118 | # Extract model vectors, and measure model inference time 119 | start_time = time.time() 120 | cl_features = extractVectors(list(dfs.values()), dataFolder, ao, sm, table_order, run_id, singleCol=isSingleCol) 121 | inference_times += time.time() - start_time 122 | print("%s %s inference time: %d seconds" %(dataFolder, dir, time.time() - start_time)) 123 | for i, file in enumerate(dfs): 124 | dfs_count += 1 125 | # get features for this file / dataset 126 | cl_features_file = np.array(cl_features[i]) 127 | dataEmbeds.append((file, cl_features_file)) 128 | if dir == 'santos-query': 129 | saveDir = 'query' 130 | elif dir == 'benchmark': 131 | saveDir = 'datalake' 132 | else: saveDir = dir 133 | 134 | if isSingleCol: 135 | output_path = "data/%s/vectors/cl_%s_%s_%s_%s_%d_singleCol.pkl" % (dataFolder, saveDir, ao, sm, table_order, run_id) 136 | else: 137 | output_path = "data/%s/vectors/cl_%s_%s_%s_%s_%d.pkl" % (dataFolder, saveDir, ao, sm, table_order, run_id) 138 | if hp.save_model: 139 | pickle.dump(dataEmbeds, open(output_path, "wb")) 140 | print("Benchmark: ", dataFolder) 141 | print("--- Total Inference Time: %s seconds ---" % (inference_times)) 142 | -------------------------------------------------------------------------------- /hnsw_cmd.sh: -------------------------------------------------------------------------------- 1 | time python test_hnsw_search.py \ 2 | --encoder cl \ 3 | --benchmark santosLarge \ 4 | --run_id 0 \ 5 | --K 60 \ 6 | --scal 1.0 -------------------------------------------------------------------------------- /hnsw_search.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import random 3 | import pickle 4 | import time 5 | import hnswlib 6 | 7 | from munkres import Munkres, make_cost_matrix, DISALLOWED 8 | from numpy.linalg import norm 9 | 10 | 11 | class HNSWSearcher(object): 12 | def __init__(self, 13 | table_path, 14 | index_path, 15 | scale 16 | ): 17 | tfile = open(table_path,"rb") 18 | tables = pickle.load(tfile) 19 | # For scalability experiments: load a percentage of tables 20 | self.tables = random.sample(tables, int(scale*len(tables))) 21 | print("From %d total data-lake tables, scale down to %d tables" % (len(tables), len(self.tables))) 22 | tfile.close() 23 | self.vec_dim = len(self.tables[1][1][0]) 24 | 25 | index_start_time = time.time() 26 | self.index = hnswlib.Index(space='cosine', dim=self.vec_dim) 27 | self.all_columns, self.col_table_ids = self._preprocess_table_hnsw() 28 | # if not os.path.exists(index_path): 29 | # build index from scratch 30 | # self.index.init_index(max_elements=len(self.all_columns), ef_construction=100, M=16) 31 | self.index.init_index(max_elements=len(self.all_columns), ef_construction=100, M=32) 32 | 33 | self.index.set_ef(10) 34 | self.index.add_items(self.all_columns) 35 | self.index.save_index(index_path) 36 | print("--- Indexing Time: %s seconds ---" % (time.time() - index_start_time)) 37 | # else: 38 | # # load index 39 | # self.index.load_index(index_path, max_elements = len(self.all_columns)) 40 | 41 | def topk(self, enc, query, K, N=5, threshold=0.6): 42 | # Note: N is the number of columns retrieved from the index 43 | query_cols = [] 44 | for col in query[1]: 45 | query_cols.append(col) 46 | candidates = self._find_candidates(query_cols, N) 47 | if enc == 'sato': 48 | scores = [] 49 | querySherlock = 
query[1][:, :1187] 50 | querySato = query[1][0, 1187:] 51 | for table in candidates: 52 | sherlock = table[1][:, :1187] 53 | sato = table[1][0, 1187:] 54 | sScore = self._verify(querySherlock, sherlock, threshold) 55 | sherlockScore = (1/min(len(querySherlock), len(sherlock))) * sScore 56 | satoScore = self._cosine_sim(querySato, sato) 57 | score = sherlockScore + satoScore 58 | scores.append((score, table[0])) 59 | else: # encoder is sherlock 60 | scores = [(self._verify(query[1], table[1], threshold), table[0]) for table in candidates] 61 | scores.sort(reverse=True) 62 | scoreLength = len(scores) 63 | return scores[:K], scoreLength 64 | 65 | def _preprocess_table_hnsw(self): 66 | all_columns = [] 67 | col_table_ids = [] 68 | for idx,table in enumerate(self.tables): 69 | for col in table[1]: 70 | all_columns.append(col) 71 | col_table_ids.append(idx) 72 | return all_columns, col_table_ids 73 | 74 | def _find_candidates(self,query_cols, N): 75 | table_subs = set() 76 | labels, _ = self.index.knn_query(query_cols, k=N) 77 | for result in labels: 78 | # result: list of subscriptions of column vector 79 | for idx in result: 80 | table_subs.add(self.col_table_ids[idx]) 81 | candidates = [] 82 | for tid in table_subs: 83 | candidates.append(self.tables[tid]) 84 | return candidates 85 | 86 | def _cosine_sim(self, vec1, vec2): 87 | assert vec1.ndim == vec2.ndim 88 | return np.dot(vec1, vec2) / (norm(vec1)*norm(vec2)) 89 | 90 | def _verify(self, table1, table2, threshold): 91 | score = 0.0 92 | nrow = len(table1) 93 | ncol = len(table2) 94 | graph = np.zeros(shape=(nrow,ncol),dtype=float) 95 | for i in range(nrow): 96 | for j in range(ncol): 97 | sim = self._cosine_sim(table1[i],table2[j]) 98 | if sim > threshold: 99 | graph[i,j] = sim 100 | 101 | max_graph = make_cost_matrix(graph, lambda cost: (graph.max() - cost) if (cost != DISALLOWED) else DISALLOWED) 102 | m = Munkres() 103 | indexes = m.compute(max_graph) 104 | for row,col in indexes: 105 | score += graph[row,col] 106 | return score -------------------------------------------------------------------------------- /lsh.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | import sys 4 | from tqdm import tqdm 5 | from iteration_utilities import duplicates, flatten, unique_everseen 6 | 7 | class CosineLSH(object): 8 | 9 | def __init__(self, num_funcs, dim, num_tables=100): 10 | self.num_funcs = num_funcs 11 | self.base_vectors = [np.random.randn( 12 | num_funcs, dim) for i in range(num_tables)] 13 | self.base_vector = np.vstack(self.base_vectors) 14 | self.num_tables = num_tables 15 | self.hash_table = np.empty((2**num_funcs * num_tables,), object) 16 | self.hash_table[...] 
= [[] for _ in range(2**num_funcs * num_tables)]
17 |         self.dim = dim
18 |         self.vectors = None
19 |         self.current_idx = 0
20 |         self.names = []
21 | 
22 |     def index_one(self, vector, name):
23 |         for hash_table_idx, base_vector in enumerate(self.base_vectors):
24 |             index = vector.dot(base_vector.T) > 0
25 |             index = (2**np.array(range(self.num_funcs)) * index).sum()
26 |             relative_index = hash_table_idx * 2 ** self.num_funcs + index
27 |             self.hash_table[relative_index].append(self.current_idx)
28 |         self.names.append(name)
29 |         if type(self.vectors) == type(None):
30 |             self.vectors = vector
31 |         else:
32 |             self.vectors = np.vstack([self.vectors, vector])
33 |         self.current_idx += 1  # advance the running index so the next vector gets a fresh id
34 | 
35 |     def index_batch(self, vectors, names):
36 |         idxs = range(self.current_idx, self.current_idx + vectors.shape[0])
37 |         for hash_table_idx, base_vector in tqdm(enumerate(self.base_vectors), total=self.num_tables):
38 |             indices = vectors.dot(base_vector.T) > 0
39 |             indices = indices.dot(2 ** np.array(range(self.num_funcs)))
40 |             for index, idx in zip(indices, idxs):
41 |                 relative_index = hash_table_idx * 2 ** self.num_funcs + index
42 |                 self.hash_table[relative_index].append(idx)
43 |         self.current_idx += vectors.shape[0]
44 |         self.names += names
45 |         if type(self.vectors) == type(None):
46 |             self.vectors = vectors
47 |         else:
48 |             self.vectors = np.vstack([self.vectors, vectors])
49 | 
50 |     def get_size(self):
51 |         # Get the memory size of the vectors
52 |         vector_size = sys.getsizeof(self.vectors)
53 |         return (vector_size)/1000000
54 | 
55 | 
56 |     def query(self, vector, N=10, radius=1):
57 |         res_indices = []
58 |         indices = vector.dot(self.base_vector.T).reshape(self.num_tables, -1) > 0
59 |         if radius == 0:
60 |             res_indices = indices.dot(2**np.arange(self.num_funcs)) + np.arange(self.num_tables) * 2**self.num_funcs
61 |         elif radius == 1:
62 |             clone_indices = indices.repeat(axis=0, repeats=self.num_funcs)
63 |             rel_indices = (np.arange(self.num_tables) * 2**self.num_funcs).repeat(axis=0, repeats=self.num_funcs)
64 |             translate = np.tile(np.eye(self.num_funcs), (self.num_tables, 1))
65 |             res_indices = (np.abs(clone_indices - translate).dot(2**np.arange(self.num_funcs)) + rel_indices).astype(int)
66 |             res_indices = np.concatenate([res_indices, indices.dot(2**np.arange(self.num_funcs)) + np.arange(self.num_tables) * 2**self.num_funcs])
67 | 
68 |         lst = self.hash_table[res_indices].tolist()
69 | 
70 |         res = list(unique_everseen(duplicates(flatten(lst))))
71 |         sim_scores = vector.dot(self.vectors[res].T)
72 | 
73 |         max_sim_indices = sim_scores.argsort()[-N:][::-1]
74 |         max_sim_scores = sim_scores[max_sim_indices]
75 | 
76 |         return [self.names[res[i]] for i in max_sim_indices], [x for x in max_sim_scores]

--------------------------------------------------------------------------------
/lsh_cmd.sh:
--------------------------------------------------------------------------------
1 | time python test_lsh.py \
2 |   --encoder cl \
3 |   --benchmark santosLarge \
4 |   --run_id 0 \
5 |   --num_func 8 \
6 |   --num_table 100 \
7 |   --K 60 \
8 |   --scal 1.0

--------------------------------------------------------------------------------
/lsh_search.py:
--------------------------------------------------------------------------------
1 | 
2 | 
3 | import numpy as np
4 | import random
5 | import pickle
6 | import os
7 | import time
8 | import sys
9 | 
10 | from munkres import Munkres, make_cost_matrix, DISALLOWED
11 | from numpy.linalg import norm
12 | from lsh import CosineLSH
13 | 
14 | 
15 | class LSHSearcher(object):
16 |     def __init__(self,
17 |                  table_path,
18 | hash_func_num, 19 | hash_table_num, 20 | scale 21 | ): 22 | tfile = open(table_path,"rb") 23 | tables = pickle.load(tfile) # load a percentage of tables 24 | self.tables = random.sample(tables, int(scale*len(tables))) 25 | print("From %d total data-lake tables, scale down to %d tables" % (len(tables), len(self.tables))) 26 | tfile.close() 27 | print("hash_func_num: ", hash_func_num, "hash_table_num: ", hash_table_num) 28 | index_start_time = time.time() 29 | self.vec_dim = len(self.tables[1][1][0]) 30 | self.all_columns, self.col_table_ids = self._preprocess_table_lsh() 31 | self.lsh = CosineLSH(hash_func_num, self.vec_dim, hash_table_num) 32 | self.lsh.index_batch(self.all_columns, range(self.all_columns.shape[0])) 33 | print("--- Indexing Time: %s seconds ---" % (time.time() - index_start_time)) 34 | print("--- Size of LSH index %s MB ---" % (self.lsh.get_size())) 35 | # print("--- Size of LSH index %s MB (numpy nbytes) ---" % (self.lsh.nbytes)*1000000) 36 | 37 | def topk(self, enc, query, K, N=5, threshold=0.6): 38 | # Note: N is the number of columns retrieved from the index 39 | query_cols = [] 40 | for col in query[1]: 41 | query_cols.append(col) 42 | candidates = self._find_candidates(query_cols, N) 43 | if enc == 'sato': 44 | scores = [] 45 | querySherlock = query[1][:, :1187] 46 | querySato = query[1][0, 1187:] 47 | for table in candidates: 48 | sherlock = table[1][:, :1187] 49 | sato = table[1][0, 1187:] 50 | sScore = self._verify(querySherlock, sherlock, threshold) 51 | sherlockScore = (1/min(len(querySherlock), len(sherlock))) * sScore 52 | satoScore = self._cosine_sim(querySato, sato) 53 | score = sherlockScore + satoScore 54 | scores.append((score, table[0])) 55 | else: # encoder is sherlock 56 | scores = [(self._verify(query[1], table[1], threshold), table[0]) for table in candidates] 57 | scores.sort(reverse=True) 58 | scoreLength = len(scores) 59 | return scores[:K], scoreLength 60 | 61 | def _preprocess_table_lsh(self): 62 | all_columns = [] 63 | col_table_ids = [] 64 | for idx,table in enumerate(self.tables): 65 | for col in table[1]: 66 | all_columns.append(col) 67 | col_table_ids.append(idx) 68 | all_columns = np.asarray(all_columns) 69 | return all_columns, col_table_ids 70 | 71 | def _find_candidates(self,query_cols, N): 72 | table_subs = set() 73 | for col in query_cols: 74 | result, _ = self.lsh.query(col, N) 75 | for idx in result: 76 | table_subs.add(self.col_table_ids[idx]) 77 | candidates = [] 78 | for tid in table_subs: 79 | candidates.append(self.tables[tid]) 80 | return candidates 81 | 82 | def _cosine_sim(self, vec1, vec2): 83 | assert vec1.ndim == vec2.ndim 84 | return np.dot(vec1, vec2) / (norm(vec1)*norm(vec2)) 85 | 86 | def _verify(self, table1, table2, threshold): 87 | score = 0.0 88 | nrow = len(table1) 89 | ncol = len(table2) 90 | graph = np.zeros(shape=(nrow,ncol),dtype=float) 91 | for i in range(nrow): 92 | for j in range(ncol): 93 | sim = self._cosine_sim(table1[i],table2[j]) 94 | if sim > threshold: 95 | graph[i,j] = sim 96 | 97 | max_graph = make_cost_matrix(graph, lambda cost: (graph.max() - cost) if (cost != DISALLOWED) else DISALLOWED) 98 | m = Munkres() 99 | indexes = m.compute(max_graph) 100 | for row,col in indexes: 101 | score += graph[row,col] 102 | return score -------------------------------------------------------------------------------- /naive_search.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pickle 3 | import random 4 | import heapq 5 | from munkres 
import Munkres, make_cost_matrix, DISALLOWED 6 | from numpy.linalg import norm 7 | from bounds import verify, upper_bound_bm, lower_bound_bm, get_edges 8 | 9 | class NaiveSearcher(object): 10 | def __init__(self, 11 | table_path, 12 | scale, 13 | index_path=None 14 | ): 15 | if index_path != None: 16 | self.index_path = index_path 17 | 18 | # load tables to be queried 19 | tfile = open(table_path,"rb") 20 | tables = pickle.load(tfile) 21 | # For scalability experiments: load a percentage of tables 22 | self.tables = random.sample(tables, int(scale*len(tables))) 23 | print("From %d total data-lake tables, scale down to %d tables" % (len(tables), len(self.tables))) 24 | tfile.close() 25 | 26 | def topk(self, enc, query, K, threshold=0.6): 27 | ''' Exact top-k cosine similarity with full bipartite matching 28 | Args: 29 | enc (str): choice of encoder (e.g. 'sato', 'cl', 'sherlock') -- mainly to check if the encoder is 'sato' 30 | query: the query, where query[0] is the query filename, and query[1] is the set of column vectors 31 | K (int): choice of K 32 | threshold (float): similarity threshold. For small SANTOS benchmark, we use threshold=0.7. For the larger benchmarks, threshold=0.1 33 | Return: 34 | Tables with top-K scores 35 | ''' 36 | if enc == 'sato': 37 | # For SATO encoder, the first 1187 items in the vector are from Sherlock. The rest are from topic modeling 38 | scores = [] 39 | querySherlock = query[1][:, :1187] 40 | querySato = query[1][0, 1187:] 41 | for table in self.tables: 42 | sherlock = table[1][:, :1187] 43 | sato = table[1][0, 1187:] 44 | sScore = self._verify(querySherlock, sherlock, threshold) 45 | sherlockScore = (1/min(len(querySherlock), len(sherlock))) * sScore 46 | satoScore = self._cosine_sim(querySato, sato) 47 | score = sherlockScore + satoScore 48 | scores.append((score, table[0])) 49 | else: 50 | scores = [(self._verify(query[1], table[1], threshold), table[0]) for table in self.tables] 51 | scores.sort(reverse=True) 52 | return scores[:K] 53 | 54 | def topk_bounds(self, enc, query, K, threshold=0.6): 55 | ''' Algorithm: Pruning with Bounds 56 | Bounds Techique: reduce # of verification calls 57 | Args: 58 | enc (str): choice of encoder (e.g. 'sato', 'cl', 'sherlock') -- mainly to check if the encoder is 'sato' 59 | query: the query, where query[0] is the query filename, and query[1] is the set of column vectors 60 | K (int): choice of K 61 | threshold (float): similarity threshold. For small SANTOS benchmark, we use threshold=0.7. 
For the larger benchmarks, threshold=0.1 62 | Return: 63 | Tables with top-K scores 64 | ''' 65 | H = [] 66 | heapq.heapify(H) 67 | if enc == 'sato': 68 | querySherlock = query[1][:, :1187] 69 | querySato = query[1][0, 1187:] 70 | satoScore = 0.0 71 | for table in self.tables: 72 | # get sherlock and sato components if the encoder is 'sato 73 | if enc == 'sato': 74 | tScore = table[1][:, :1187] 75 | qScore = querySherlock 76 | sato = table[1][0, 1187:] 77 | satoScore = self._cosine_sim(querySato, sato) 78 | else: 79 | tScore = table[1] 80 | qScore = query[1] 81 | 82 | # add to heap to get len(H) = K 83 | if len(H) < K: # len(H) = number of elements in H 84 | score = verify(qScore, tScore, threshold) 85 | if enc == 'sato': score = self._combine_sherlock_sato(score, qScore, tScore, satoScore) 86 | heapq.heappush(H, (score, table[0])) 87 | else: 88 | topScore = H[0] 89 | # Helper method from bounds.py for to reduce the cost of the graph 90 | edges, nodes1, nodes2 = get_edges(qScore, tScore, threshold) 91 | lb = lower_bound_bm(edges, nodes1, nodes2) 92 | ub = upper_bound_bm(edges, nodes1, nodes2) 93 | if enc == 'sato': 94 | lb = self._combine_sherlock_sato(lb, qScore, tScore, satoScore) 95 | ub = self._combine_sherlock_sato(ub, qScore, tScore, satoScore) 96 | 97 | if lb > topScore[0]: 98 | heapq.heappop(H) 99 | score = verify(qScore, tScore, threshold) 100 | if enc == 'sato': score = self._combine_sherlock_sato(score, qScore, tScore, satoScore) 101 | heapq.heappush(H, (score, table[0])) 102 | elif ub >= topScore[0]: 103 | score = verify(qScore, tScore, threshold) 104 | if enc == 'sato': score = self._combine_sherlock_sato(score, qScore, tScore, satoScore) 105 | if score > topScore[0]: 106 | heapq.heappop(H) 107 | heapq.heappush(H, (score, table[0])) 108 | scores = [] 109 | while len(H) > 0: 110 | scores.append(heapq.heappop(H)) 111 | scores.sort(reverse=True) 112 | return scores 113 | 114 | 115 | def _combine_sherlock_sato(self, score, qScore, tScore, satoScore): 116 | ''' Helper method for topk_bounds() to calculate sherlock and sato scores, if the encoder is SATO 117 | ''' 118 | sherlockScore = (1/min(len(qScore), len(tScore))) * score 119 | full_satoScore = sherlockScore + satoScore 120 | return full_satoScore 121 | 122 | def topk_greedy(self, enc, query, K, threshold=0.6): 123 | ''' Greedy algorithm for matching 124 | Args: 125 | enc (str): choice of encoder (e.g. 'sato', 'cl', 'sherlock') -- mainly to check if the encoder is 'sato' 126 | query: the query, where query[0] is the query filename, and query[1] is the set of column vectors 127 | K (int): choice of K 128 | threshold (float): similarity threshold. For small SANTOS benchmark, we use threshold=0.7. 
For the larger benchmarks, threshold=0.1 129 | Return: 130 | Tables with top-K scores 131 | ''' 132 | if enc == 'sato': 133 | scores = [] 134 | querySherlock = query[1][:, :1187] 135 | querySato = query[1][0, 1187:] 136 | for table in self.tables: 137 | sherlock = table[1][:, :1187] 138 | sato = table[1][0, 1187:] 139 | sScore = self._verify_greedy(querySherlock, sherlock, threshold) 140 | sherlockScore = (1/min(len(querySherlock), len(sherlock))) * sScore 141 | satoScore = self._cosine_sim(querySato, sato) 142 | score = sherlockScore + satoScore 143 | scores.append((score, table[0])) 144 | else: # encoder is sherlock 145 | scores = [(self._verify_greedy(query[1], table[1], threshold), table[0]) for table in self.tables] 146 | scores.sort(reverse=True) 147 | return scores[:K] 148 | 149 | def _cosine_sim(self, vec1, vec2): 150 | ''' Get the cosine similarity of two input vectors: vec1 and vec2 151 | ''' 152 | assert vec1.ndim == vec2.ndim 153 | return np.dot(vec1, vec2) / (norm(vec1)*norm(vec2)) 154 | 155 | def _verify(self, table1, table2, threshold): 156 | score = 0.0 157 | nrow = len(table1) 158 | ncol = len(table2) 159 | graph = np.zeros(shape=(nrow,ncol),dtype=float) 160 | for i in range(nrow): 161 | for j in range(ncol): 162 | sim = self._cosine_sim(table1[i],table2[j]) 163 | if sim > threshold: 164 | graph[i,j] = sim 165 | 166 | max_graph = make_cost_matrix(graph, lambda cost: (graph.max() - cost) if (cost != DISALLOWED) else DISALLOWED) 167 | m = Munkres() 168 | indexes = m.compute(max_graph) 169 | for row,col in indexes: 170 | score += graph[row,col] 171 | return score 172 | 173 | def _verify_greedy(self, table1, table2, threshold): 174 | nodes1 = set() 175 | nodes2 = set() 176 | score = 0.0 177 | nrow = len(table1) 178 | ncol = len(table2) 179 | edges = [] 180 | for i in range(nrow): 181 | for j in range(ncol): 182 | sim = self._cosine_sim(table1[i],table2[j]) 183 | if sim > threshold: 184 | edges.append((sim,i,j)) 185 | nodes1.add(i) 186 | nodes2.add(j) 187 | edges.sort(reverse=True) 188 | for e in edges: 189 | score += e[0] 190 | nodes1.discard(e[1]) 191 | nodes2.discard(e[2]) 192 | if len(nodes1) == 0 or len(nodes2) == 0: 193 | return score 194 | return score -------------------------------------------------------------------------------- /notebook/offline.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [ 8 | { 9 | "name": "stdout", 10 | "output_type": "stream", 11 | "text": [ 12 | "index_len = 80345\n" 13 | ] 14 | } 15 | ], 16 | "source": [ 17 | "# load tables\n", 18 | "import pandas as pd\n", 19 | "import numpy as np\n", 20 | "import pickle\n", 21 | "import os\n", 22 | "\n", 23 | "from tqdm import tqdm\n", 24 | "\n", 25 | "\n", 26 | "# re-build indices\n", 27 | "sato_path = \"/nfs/users/yuliang/ssl-em/column_type_detection/data\"\n", 28 | "index = {}\n", 29 | "max_len = 64\n", 30 | "\n", 31 | "for sid in range(5):\n", 32 | " path = os.path.join(sato_path, \"sato_cv_%d.csv\" % sid)\n", 33 | " df = pd.read_csv(path)\n", 34 | "\n", 35 | " for table_id, col_idx, data, cls in zip(df['table_id'], df['col_idx'], df['data'], df['class']):\n", 36 | " tokens = data.split(' ')\n", 37 | " data = ' '.join(tokens[:max_len])\n", 38 | " index[data] = (table_id, col_idx, cls)\n", 39 | "\n", 40 | "print('index_len =', len(index))" 41 | ] 42 | }, 43 | { 44 | "cell_type": "code", 45 | "execution_count": 7, 46 | "metadata": {}, 47 | "outputs": [ 
48 | { 49 | "name": "stderr", 50 | "output_type": "stream", 51 | "text": [ 52 | "100%|██████████| 5000/5000 [00:13<00:00, 383.83it/s]\n", 53 | "100%|██████████| 2500/2500 [00:05<00:00, 417.63it/s]\n", 54 | "100%|██████████| 2500/2500 [00:04<00:00, 533.09it/s]\n" 55 | ] 56 | } 57 | ], 58 | "source": [ 59 | "dataset_path = \"/nfs/users/yuliang/ssl-em/column_type_detection\"\n", 60 | "\n", 61 | "datasets = []\n", 62 | "all_tables = {}\n", 63 | "table_lists = []\n", 64 | "\n", 65 | "for fn in ['train.txt', 'valid.txt', 'test.txt']:\n", 66 | " path = os.path.join(dataset_path, fn)\n", 67 | " rows = []\n", 68 | " labels = []\n", 69 | " output_df = {'l_table_id': [],\n", 70 | " 'r_table_id': [],\n", 71 | " 'l_column_id': [],\n", 72 | " 'r_column_id': [],\n", 73 | " 'l_ori_table_id': [],\n", 74 | " 'r_ori_table_id': [],\n", 75 | " 'l_column_type': [],\n", 76 | " 'r_column_type': [],\n", 77 | " 'match': []}\n", 78 | "\n", 79 | " for line in tqdm(open(path).readlines()):\n", 80 | " left, right, label = line.strip().split('\\t')\n", 81 | " labels.append(label)\n", 82 | "\n", 83 | " features = []\n", 84 | " for text, prefix in zip([left, right], [\"l_\", \"r_\"]):\n", 85 | " ori_table_id = index[text][0]\n", 86 | " col_idx = index[text][1]\n", 87 | " cls = index[text][2]\n", 88 | " if ori_table_id not in all_tables:\n", 89 | " table_path = os.path.join(\"/nfs/users/yuliang/table_data/viznet_tables/\", index[text][0])\n", 90 | " df = pd.read_csv(table_path, index_col=[0])\n", 91 | " # new table id and the DataFrame\n", 92 | " all_tables[ori_table_id] = (len(all_tables), df)\n", 93 | " table_lists.append(df)\n", 94 | " \n", 95 | " table_id, df = all_tables[ori_table_id]\n", 96 | " output_df[prefix + \"table_id\"].append(table_id)\n", 97 | " output_df[prefix + \"column_id\"].append(col_idx)\n", 98 | " output_df[prefix + \"ori_table_id\"].append(ori_table_id)\n", 99 | " output_df[prefix + \"column_type\"].append(cls)\n", 100 | " \n", 101 | " output_df['match'].append(1 if output_df['l_column_type'][-1] == output_df['r_column_type'][-1] else 0)\n", 102 | " \n", 103 | " # output\n", 104 | " output_path = os.path.join(dataset_path, fn.replace('.txt', '.csv'))\n", 105 | " output_df = pd.DataFrame(output_df)\n", 106 | " output_df.to_csv(output_path, index=False)\n", 107 | "\n", 108 | "# output all_tables\n", 109 | "for idx, table in enumerate(table_lists):\n", 110 | " output_path = os.path.join(dataset_path, 'table_%d.csv' % idx)\n", 111 | " table.to_csv(output_path, index=False)\n" 112 | ] 113 | }, 114 | { 115 | "cell_type": "markdown", 116 | "metadata": {}, 117 | "source": [ 118 | "## Preprocess the viznet dataset" 119 | ] 120 | }, 121 | { 122 | "cell_type": "code", 123 | "execution_count": 39, 124 | "metadata": {}, 125 | "outputs": [ 126 | { 127 | "name": "stderr", 128 | "output_type": "stream", 129 | "text": [ 130 | "100%|██████████| 23820/23820 [00:20<00:00, 1151.86it/s]\n", 131 | "100%|██████████| 23877/23877 [00:17<00:00, 1346.04it/s]\n", 132 | "100%|██████████| 23893/23893 [00:17<00:00, 1355.23it/s]\n", 133 | "100%|██████████| 23783/23783 [00:15<00:00, 1508.37it/s]\n", 134 | "100%|██████████| 23987/23987 [00:14<00:00, 1677.27it/s]\n" 135 | ] 136 | } 137 | ], 138 | "source": [ 139 | "dataset_path = \"/nfs/users/yuliang/SDD/data/viznet\"\n", 140 | "sato_path = \"/nfs/users/yuliang/ssl-em/column_type_detection/data\"\n", 141 | "\n", 142 | "datasets = []\n", 143 | "all_tables = {}\n", 144 | "table_lists = []\n", 145 | "all_columns = {'table_id': [],\n", 146 | " 'ori_table_id': [],\n", 147 | " 'column_id': 
[],\n", 148 | " 'class': []}\n", 149 | "\n", 150 | "# re-build indices\n", 151 | "index = {}\n", 152 | "max_len = 64\n", 153 | "\n", 154 | "for sid in range(5):\n", 155 | " path = os.path.join(sato_path, \"sato_cv_%d.csv\" % sid)\n", 156 | " df = pd.read_csv(path)\n", 157 | "\n", 158 | " for table_id, col_idx, data, cls in zip(df['table_id'], df['col_idx'], df['data'], df['class']):\n", 159 | " tokens = data.split(' ')\n", 160 | " data = ' '.join(tokens[:max_len])\n", 161 | " index[data] = (table_id, col_idx, cls)\n", 162 | "\n", 163 | "\n", 164 | "for sid in range(5):\n", 165 | " path = os.path.join(sato_path, \"sato_cv_%d.csv\" % sid)\n", 166 | " df = pd.read_csv(path)\n", 167 | "\n", 168 | " for data, cls in tqdm(zip(df['data'], df['class']), total=len(df)):\n", 169 | " tokens = data.split(' ')\n", 170 | " data = ' '.join(tokens[:max_len])\n", 171 | " \n", 172 | " ori_table_id = index[data][0]\n", 173 | " col_idx = index[data][1]\n", 174 | "\n", 175 | " if ori_table_id not in all_tables:\n", 176 | " table_path = os.path.join(\"/nfs/users/yuliang/table_data/viznet_tables/\", ori_table_id)\n", 177 | " df = pd.read_csv(table_path, index_col=[0])\n", 178 | " # new table id and the DataFrame\n", 179 | " table_id = len(all_tables)\n", 180 | " all_tables[ori_table_id] = (len(all_tables), df)\n", 181 | " table_lists.append(df)\n", 182 | " else:\n", 183 | " table_id = all_tables[ori_table_id][0]\n", 184 | " \n", 185 | " all_columns['table_id'].append(table_id)\n", 186 | " all_columns['ori_table_id'].append(ori_table_id)\n", 187 | " all_columns['column_id'].append(col_idx)\n", 188 | " all_columns['class'].append(cls)\n", 189 | "\n", 190 | "all_columns = pd.DataFrame(all_columns) \n", 191 | "all_columns.to_csv(os.path.join(dataset_path, 'test.csv'), index=False)\n", 192 | "\n", 193 | "# output all_tables\n", 194 | "for idx, table in enumerate(table_lists):\n", 195 | " output_path = os.path.join(dataset_path, 'tables', 'table_%d.csv' % idx)\n", 196 | " table.to_csv(output_path, index=False)\n" 197 | ] 198 | }, 199 | { 200 | "cell_type": "markdown", 201 | "metadata": {}, 202 | "source": [ 203 | "## Column clustering" 204 | ] 205 | }, 206 | { 207 | "cell_type": "code", 208 | "execution_count": 2, 209 | "metadata": {}, 210 | "outputs": [], 211 | "source": [ 212 | "import os\n", 213 | "import pickle\n", 214 | "\n", 215 | "path = '../python/'\n", 216 | "\n", 217 | "# load data\n", 218 | "column_vectors, labels = pickle.load(open(\"../data/viznet/multi_column/column_vectors.pkl\", \"rb\"))\n", 219 | "# pairs = pickle.load(open(\"../python/column_pairs.pkl\", \"rb\"))" 220 | ] 221 | }, 222 | { 223 | "cell_type": "code", 224 | "execution_count": 88, 225 | "metadata": {}, 226 | "outputs": [ 227 | { 228 | "name": "stderr", 229 | "output_type": "stream", 230 | "text": [ 231 | " 17%|█▋ | 20035/119360 [00:00<00:01, 96277.52it/s] " 232 | ] 233 | }, 234 | { 235 | "name": "stdout", 236 | "output_type": "stream", 237 | "text": [ 238 | "error: 3 2\n" 239 | ] 240 | }, 241 | { 242 | "name": "stderr", 243 | "output_type": "stream", 244 | "text": [ 245 | " 41%|████ | 48508/119360 [00:00<00:00, 92311.08it/s]" 246 | ] 247 | }, 248 | { 249 | "name": "stdout", 250 | "output_type": "stream", 251 | "text": [ 252 | "error: 3 2\n", 253 | "error: 3 2\n", 254 | "error: 3 3\n" 255 | ] 256 | }, 257 | { 258 | "name": "stderr", 259 | "output_type": "stream", 260 | "text": [ 261 | " 65%|██████▍ | 77231/119360 [00:00<00:00, 94977.64it/s]" 262 | ] 263 | }, 264 | { 265 | "name": "stdout", 266 | "output_type": "stream", 267 | "text": [ 
268 | "error: 3 2\n", 269 | "error: 3 3\n" 270 | ] 271 | }, 272 | { 273 | "name": "stderr", 274 | "output_type": "stream", 275 | "text": [ 276 | " 81%|████████▏ | 97221/119360 [00:01<00:00, 97580.35it/s]" 277 | ] 278 | }, 279 | { 280 | "name": "stdout", 281 | "output_type": "stream", 282 | "text": [ 283 | "error: 3 3\n", 284 | "error: 3 2\n", 285 | "error: 3 2\n" 286 | ] 287 | }, 288 | { 289 | "name": "stderr", 290 | "output_type": "stream", 291 | "text": [ 292 | "100%|██████████| 119360/119360 [00:01<00:00, 95355.72it/s]\n" 293 | ] 294 | } 295 | ], 296 | "source": [ 297 | "# sherlock and sato\n", 298 | "import pandas as pd\n", 299 | "import numpy as np\n", 300 | "from tqdm import tqdm\n", 301 | "\n", 302 | "testset = pd.read_csv(\"../data/viznet/test.csv.full\")\n", 303 | "sherlock_sato_features = pickle.load(open('sato/sato_features.pkl', 'rb'))\n", 304 | "\n", 305 | "sherlock_features = []\n", 306 | "sato_features = []\n", 307 | "idx = 0\n", 308 | "\n", 309 | "# for table_id, column_id in tqdm(zip(testset['table_id'], testset['column_id']), total=len(testset)):\n", 310 | "# num_column = len(sherlock_sato_features[idx][0])\n", 311 | "# real_num_column = len(pd.read_csv(\"../data/viznet/tables/table_%d.csv\" % table_id).columns)\n", 312 | "# print(idx, num_column, real_num_column)\n", 313 | "# idx += 1\n", 314 | "\n", 315 | "\n", 316 | "for table_id, column_id in tqdm(zip(testset['table_id'], testset['column_id']), total=len(testset)):\n", 317 | " if column_id >= len(sherlock_sato_features[idx][0]):\n", 318 | " print(\"error: \", column_id, len(sherlock_sato_features[idx][0]))\n", 319 | " column_id = len(sherlock_sato_features[idx][0]) - 1\n", 320 | "\n", 321 | " sherlock_feature = sherlock_sato_features[idx][0][column_id]\n", 322 | " sato_feature = np.concatenate([sherlock_feature, sherlock_sato_features[idx][1]])\n", 323 | " sherlock_features.append(sherlock_feature)\n", 324 | " sato_features.append(sato_feature)\n", 325 | " idx += 1\n", 326 | "\n", 327 | "pickle.dump(sherlock_features, open(\"../data/viznet/sherlock/column_vectors.pkl\", \"wb\"))\n", 328 | "pickle.dump(sato_features, open(\"../data/viznet/sato/column_vectors.pkl\", \"wb\"))" 329 | ] 330 | }, 331 | { 332 | "cell_type": "code", 333 | "execution_count": 16, 334 | "metadata": {}, 335 | "outputs": [], 336 | "source": [ 337 | "import numpy as np\n", 338 | "import pickle\n", 339 | "\n", 340 | "from tqdm import tqdm\n", 341 | "from collections import deque, Counter\n", 342 | "\n", 343 | "\n", 344 | "def blocked_matmul(mata, matb,\n", 345 | " threshold=None,\n", 346 | " k=None,\n", 347 | " batch_size=512):\n", 348 | " \"\"\"Find the most similar pairs of vectors from two matrices (top-k or threshold)\n", 349 | "\n", 350 | " Args:\n", 351 | " mata (np.ndarray): the first matrix\n", 352 | " matb (np.ndarray): the second matrix\n", 353 | " threshold (float, optional): if set, return all pairs of cosine\n", 354 | " similarity above the threshold\n", 355 | " k (int, optional): if set, return for each row in matb the top-k\n", 356 | " most similar vectors in mata\n", 357 | " batch_size (int, optional): the batch size of each block\n", 358 | " \n", 359 | " Returns:\n", 360 | " list of tuples: the pairs of similar vectors' indices and the similarity\n", 361 | " \"\"\"\n", 362 | " mata = np.array(mata)\n", 363 | " matb = np.array(matb)\n", 364 | " results = []\n", 365 | " for start in tqdm(range(0, len(matb), batch_size)):\n", 366 | " block = matb[start:start+batch_size]\n", 367 | " sim_mat = np.matmul(mata, block.transpose())\n", 368 | 
" if k is not None:\n", 369 | " indices = np.argpartition(-sim_mat, k, axis=0)\n", 370 | " for row in indices[:k]:\n", 371 | " for idx_b, idx_a in enumerate(row):\n", 372 | " idx_b += start\n", 373 | " results.append((idx_a, idx_b, sim_mat[idx_a][idx_b-start]))\n", 374 | " elif threshold is not None:\n", 375 | " indices = np.argwhere(sim_mat >= threshold)\n", 376 | " for idx_a, idx_b in indices:\n", 377 | " idx_b += start\n", 378 | " results.append((idx_a, idx_b, sim_mat[idx_a][idx_b-start]))\n", 379 | " return results\n", 380 | "\n", 381 | "\n", 382 | "def connected_components(pairs, cluster_size=50):\n", 383 | " \"\"\"Helper function for computing the connected components\n", 384 | " \"\"\"\n", 385 | " edges = {}\n", 386 | " pairs.sort(key=lambda x: x[2], reverse=True)\n", 387 | " for left, right, _ in pairs:\n", 388 | " if left not in edges:\n", 389 | " edges[left] = []\n", 390 | " if right not in edges:\n", 391 | " edges[right] = []\n", 392 | " \n", 393 | " edges[left].append(right)\n", 394 | " edges[right].append(left)\n", 395 | " \n", 396 | " # print('num nodes =', len(edges))\n", 397 | " all_ccs = []\n", 398 | " used = set([])\n", 399 | " for start in edges:\n", 400 | " if start in used:\n", 401 | " continue\n", 402 | " used.add(start)\n", 403 | " cc = [start]\n", 404 | " \n", 405 | " queue = deque([start])\n", 406 | " while len(queue) > 0:\n", 407 | " u = queue.popleft()\n", 408 | " for v in edges[u]:\n", 409 | " if v not in used:\n", 410 | " cc.append(v)\n", 411 | " used.add(v)\n", 412 | " queue.append(v)\n", 413 | " if len(cc) >= cluster_size:\n", 414 | " break\n", 415 | " \n", 416 | " if len(cc) >= cluster_size:\n", 417 | " break\n", 418 | " \n", 419 | " all_ccs.append(cc)\n", 420 | " # print(cc)\n", 421 | " return all_ccs\n", 422 | "\n", 423 | "\n", 424 | "def evaluate_clustering(vectors, labels):\n", 425 | " \"\"\"Evaluate column clustering on input column vectors.\n", 426 | " \"\"\"\n", 427 | " # normalize the vectors\n", 428 | " vectors = np.array(vectors)\n", 429 | " vectors /= np.linalg.norm(vectors, axis=-1)[:, np.newaxis]\n", 430 | "\n", 431 | " # top k matching columns\n", 432 | " pairs = blocked_matmul(vectors, vectors,\n", 433 | " k=20,\n", 434 | " batch_size=4096)\n", 435 | "\n", 436 | " # dump the clustering results\n", 437 | " pickle.dump(pairs, open('column_pairs.pkl', 'wb'))\n", 438 | "\n", 439 | " # run column clustering algorithm\n", 440 | " ccs = connected_components(pairs)\n", 441 | "\n", 442 | " # dump the clustering results\n", 443 | " pickle.dump(ccs, open('clusters.pkl', 'wb'))\n", 444 | "\n", 445 | " # compute purity\n", 446 | " purity = []\n", 447 | " total = 0\n", 448 | " for cc in ccs:\n", 449 | " cnt = Counter()\n", 450 | " for column_id in cc:\n", 451 | " label = labels[column_id]\n", 452 | " cnt[label] += 1\n", 453 | " purity.append(cnt.most_common(1)[0][1])\n", 454 | " total += len(cc)\n", 455 | " purity = np.sum(purity) / total\n", 456 | "\n", 457 | " return {\"num_clusters\": len(ccs), \n", 458 | " \"avg_cluster_size\": np.mean([len(cc) for cc in ccs]),\n", 459 | " \"purity\": purity}\n", 460 | "\n", 461 | "# for method in ['sherlock', 'sato', 'single_column', 'multi_column']:\n", 462 | "# column_vectors = pickle.load(open('../data/viznet/%s/column_vectors.pkl' % method, \"rb\"))\n", 463 | "# res = evaluate_clustering(column_vectors, labels)\n", 464 | "# print(res)\n", 465 | "# os.system('mv *.pkl ../data/viznet/%s/' % method)" 466 | ] 467 | }, 468 | { 469 | "cell_type": "code", 470 | "execution_count": 21, 471 | "metadata": {}, 472 | "outputs": 
[ 473 | { 474 | "name": "stdout", 475 | "output_type": "stream", 476 | "text": [ 477 | "{'method': 'sato', 'num_clusters': 2456, 'avg_cluster_size': 48.59934853420195, 'purity': 0.37356735924932977}\n", 478 | "{'method': 'sherlock', 'num_clusters': 2395, 'avg_cluster_size': 49.83716075156576, 'purity': 0.3050351876675603}\n", 479 | "{'method': 'multi_column', 'num_clusters': 2297, 'avg_cluster_size': 51.96343056160209, 'purity': 0.5118800268096515}\n", 480 | "{'method': 'single_column', 'num_clusters': 9252, 'avg_cluster_size': 12.900994379593602, 'purity': 0.20379524128686327}\n" 481 | ] 482 | } 483 | ], 484 | "source": [ 485 | "def compute_purity(ccs):\n", 486 | " purity = []\n", 487 | " total = 0\n", 488 | " for cc in ccs:\n", 489 | " cnt = Counter()\n", 490 | " for column_id in cc:\n", 491 | " label = labels[column_id]\n", 492 | " cnt[label] += 1\n", 493 | " purity.append(cnt.most_common(1)[0][1])\n", 494 | " total += len(cc)\n", 495 | " purity = np.sum(purity) / total\n", 496 | " return purity\n", 497 | "\n", 498 | "def tune_cluster_size(pairs, target=50):\n", 499 | " left = 0\n", 500 | " right = 5000\n", 501 | " min_diff = 1e6\n", 502 | " res_ccs = []\n", 503 | "\n", 504 | " while right - left > 10:\n", 505 | " mid = (left + right) // 2\n", 506 | " ccs = connected_components(pairs, cluster_size=mid)\n", 507 | " avg_size = np.mean([len(cc) for cc in ccs])\n", 508 | " if abs(avg_size - target) < min_diff:\n", 509 | " min_diff = abs(avg_size - target)\n", 510 | " res_ccs = ccs\n", 511 | "\n", 512 | " # print(mid, avg_size)\n", 513 | " if avg_size > target:\n", 514 | " right = mid\n", 515 | " else:\n", 516 | " left = mid\n", 517 | " # purity = compute_purity(ccs)\n", 518 | " \n", 519 | " purity = compute_purity(res_ccs)\n", 520 | " return res_ccs, purity\n", 521 | "\n", 522 | "\n", 523 | "for model in ['sato', 'sherlock', 'multi_column', 'single_column']:\n", 524 | " pairs = pickle.load(open(\"../data/viznet/%s/column_pairs.pkl\" % model, \"rb\"))\n", 525 | " ccs, purity = tune_cluster_size(pairs)\n", 526 | " res = {\"method\": model,\n", 527 | " \"num_clusters\": len(ccs), \n", 528 | " \"avg_cluster_size\": np.mean([len(cc) for cc in ccs]),\n", 529 | " \"purity\": purity}\n", 530 | " print(res)\n", 531 | "\n", 532 | " # for cluster_size in [25, 50, 75, 100, 150, 200]:\n", 533 | " # ccs = connected_components(pairs, cluster_size=cluster_size)\n", 534 | "\n", 535 | " # # compute purity\n", 536 | " # purity = compute_purity(ccs)\n", 537 | " # res = {\"method\": model,\n", 538 | " # \"num_clusters\": len(ccs), \n", 539 | " # \"avg_cluster_size\": np.mean([len(cc) for cc in ccs]),\n", 540 | " # \"purity\": purity}\n", 541 | " # print(res)\n" 542 | ] 543 | }, 544 | { 545 | "cell_type": "code", 546 | "execution_count": 2, 547 | "metadata": {}, 548 | "outputs": [ 549 | { 550 | "name": "stdout", 551 | "output_type": "stream", 552 | "text": [ 553 | "artist ---- 1. I Don't Give A ...; 2. I'm The Kinda; 3. I U She; 4. Kick It [featuring Iggy Pop]; 5. Operate\n", 554 | "artist ---- 1. Spoken Intro; 2. The Court; 3. Maze; 4. Girl Talk; 5. A La Mode\n", 555 | "artist ---- 1. Street Fighting Man; 2. Gimme Shelter; 3. (I Can't Get No) Satisfaction; 4. The Last Time; 5. Jumpin' Jack Flash\n", 556 | "artist ---- 1. Angel of the Morning; 2. Shot Full of Love; 3. Ride 'Em Cowboys; 4. Queen of Hearts; 5. River of Love\n", 557 | "artist ---- 1. New Wave; 2. Up The Cuts; 3. Thrash Unreal; 4. White People For Peace; 5. Stop!\n", 558 | "artist ---- 1. Trigger Happy; 2. Sentimental Fool; 3. 
I Didn't Know That You Cared; 4. Love Ruins Everything; 5. Baby\n", 559 | "artist ---- 1. You; 2. Creep; 3. How Do You?; 4. Stop Whispering; 5. Thinking About You\n", 560 | "artist ---- 1. Buena; 2. Honey White; 3. You Speak My Language; 4. Cure for Pain; 5. Candy\n", 561 | "artist ---- 1. Mr. Grieves; 2. Crackity Jones; 3. La La Love You; 4. No. 13 Baby; 5. There Goes My Gun\n", 562 | "artist ---- 1. Street Fighting Man; 2. Gimme Shelter; 3. (I Can't Get No) Satisfaction; 4. The Last Time; 5. Jumpin' Jack Flash\n", 563 | "---------------------------------\n", 564 | "type ---- Emerson Elementary School; Banneker Elementary School; Silver City Elementary School; New Stanley Elementary School; Frances Willard Elementary School\n", 565 | "type ---- Choctawhatchee Senior High School; Fort Walton Beach High School; Ami Kids Emerald Coast; Gulf Coast Christian School; Adolescent Substance Abuse\n", 566 | "city ---- Chilton; Stoughton\n", 567 | "name ---- Crasnier-Mednansky, Martine; Park, Maxwell; Studley, William; Saier, Milton\n", 568 | "type ---- Roosevelt High School; Karen Wagner High School; Thompson Center\n", 569 | "type ---- Oak Park and River Forest High School; Harbor Academy Reg Safe Sch Prg; Fenwick High School; Trinity High School\n", 570 | "type ---- Emerson Elementary School; Banneker Elementary School; Silver City Elementary School; New Stanley Elementary School; Frances Willard Elementary School\n", 571 | "type ---- Roosevelt School; Gwendolyn Brooks Middle School; Percy Julian Middle School; St. Luke Catholic School; Learning Network\n", 572 | "type ---- Sumner Academy Of Arts and Science; Wyandotte High School; J C Harmon High School; School For Blind High; Bishop Ward High School\n", 573 | "type ---- Emerson Elementary School; Banneker Elementary School; Silver City Elementary School; New Stanley Elementary School; Frances Willard Elementary School\n", 574 | "---------------------------------\n", 575 | "description ---- Fri Sep 11,2015 3:30 PM (CST); Fri Sep 11,2015 6:00 PM (CST); Sat Sep 12,2015 10:00 AM (CST); Sat Sep 12,2015 12:00 PM (CST); Sat Sep 12,2015 5:30 PM (CST)\n", 576 | "day ---- Sept. 1; Sept. 7; Sept. 22; Sept. 29; Oct. 
5\n", 577 | "description ---- Fri Sep 11,2015 3:30 PM (CST); Fri Sep 11,2015 6:00 PM (CST); Sat Sep 12,2015 10:00 AM (CST); Sat Sep 12,2015 12:00 PM (CST); Sat Sep 12,2015 5:30 PM (CST)\n", 578 | "description ---- Fri Sep 11,2015 3:30 PM (CST); Fri Sep 11,2015 6:00 PM (CST); Sat Sep 12,2015 10:00 AM (CST); Sat Sep 12,2015 12:00 PM (CST); Sat Sep 12,2015 5:30 PM (CST)\n", 579 | "description ---- Fri Sep 11,2015 3:30 PM (CST); Fri Sep 11,2015 6:00 PM (CST); Sat Sep 12,2015 10:00 AM (CST); Sat Sep 12,2015 12:00 PM (CST); Sat Sep 12,2015 5:30 PM (CST)\n", 580 | "description ---- Fri Sep 11,2015 3:30 PM (CST); Fri Sep 11,2015 6:00 PM (CST); Sat Sep 12,2015 10:00 AM (CST); Sat Sep 12,2015 12:00 PM (CST); Sat Sep 12,2015 5:30 PM (CST)\n", 581 | "description ---- Fri Sep 11,2015 3:30 PM (CST); Fri Sep 11,2015 6:00 PM (CST); Sat Sep 12,2015 10:00 AM (CST); Sat Sep 12,2015 12:00 PM (CST); Sat Sep 12,2015 5:30 PM (CST)\n", 582 | "description ---- Fri Sep 11,2015 3:30 PM (CST); Fri Sep 11,2015 6:00 PM (CST); Sat Sep 12,2015 10:00 AM (CST); Sat Sep 12,2015 12:00 PM (CST); Sat Sep 12,2015 5:30 PM (CST)\n", 583 | "description ---- Fri Sep 11,2015 3:30 PM (CST); Fri Sep 11,2015 6:00 PM (CST); Sat Sep 12,2015 10:00 AM (CST); Sat Sep 12,2015 12:00 PM (CST); Sat Sep 12,2015 5:30 PM (CST)\n", 584 | "description ---- Fri Sep 11,2015 3:30 PM (CST); Fri Sep 11,2015 6:00 PM (CST); Sat Sep 12,2015 10:00 AM (CST); Sat Sep 12,2015 12:00 PM (CST); Sat Sep 12,2015 5:30 PM (CST)\n", 585 | "---------------------------------\n", 586 | "name ---- People's Grocery Co-op Exchange; Prairieland Market; The Merc (Community Mercantile); Topeka Natural Food Co-op\n", 587 | "name ---- People's Grocery Co-op Exchange; Prairieland Market; The Merc (Community Mercantile); Topeka Natural Food Co-op\n", 588 | "name ---- People's Grocery Co-op Exchange; Prairieland Market; The Merc (Community Mercantile); Topeka Natural Food Co-op\n", 589 | "name ---- People's Grocery Co-op Exchange; Prairieland Market; The Merc (Community Mercantile); Topeka Natural Food Co-op\n", 590 | "name ---- Amazing Grains; BisMan Community Food Cooperative; Bowdon Locker & Grocery; Prairie Roots Food Co-op\n", 591 | "name ---- Apples Street Market; Bexley Natural Market; Clintonville Community Market; Kent Natural Foods Co-op; MOON Co-op Natural Foods Market\n", 592 | "name ---- Fiddleheads Food Co-op; Northwest Corner Co-op; The Local Beet Co-op; Willimantic Food Co-op\n", 593 | "name ---- Fiddleheads Food Co-op; Northwest Corner Co-op; The Local Beet Co-op; Willimantic Food Co-op\n", 594 | "name ---- Bread & Roses Food Cooperative; Citizens Co-op; Community Harvest Market; Ever'man Cooperative Grocery & Cafe; New Leaf Market\n", 595 | "name ---- Bread & Roses Food Cooperative; Citizens Co-op; Community Harvest Market; Ever'man Cooperative Grocery & Cafe; New Leaf Market\n", 596 | "---------------------------------\n", 597 | "address ---- 1930 Lagen St, Dubuque IA; 3860 Short St, Dubuque IA; 3962 Aurora St, Dubuque IA; 2222 Saint Celia St, Dubuque IA; 1676 Amy Ct, Dubuque IA\n", 598 | "address ---- 43 Whitney Rd, Mystic CT; 87 Quaker Farm Rd, Mystic CT; 171 Lambtown Rd, Mystic CT; 1827 Gold Star Hwy, Groton CT; 25 Nantucket Dr, Mystic CT\n", 599 | "address ---- 3418 Magnolia Way, Broadview Heights OH; 3397 Magnolia Way, Broadview Heights OH; 3080 Osage Way, Broadview Heights OH; 13741 Monica Dr, North Royalton OH; 13021 Eagle Chase, North Royalton OH\n", 600 | "address ---- 2762 Riverwood Ln, Jacksonville FL; 1804 Lorimier Rd, Jacksonville FL; 1639 Lorimier 
Rd, Jacksonville FL; 1705 Lorimier Rd, Jacksonville FL; 2757 White Oak Ln, Jacksonville FL\n", 601 | "address ---- 7232 E Mckinley St, Scottsdale AZ; 519 N 73rd Pl, Scottsdale AZ; 7308 E Polk St, Scottsdale AZ; 713 N 74th St, Scottsdale AZ; 24 E Mckinley Cir, Tempe AZ\n", 602 | "address ---- 5725 N Depauw St, Portland OR; 6126 N Superior St, Portland OR; 7129 N Buchanan Ave, Portland OR; 9006 N Ida Ave, Portland OR; 4925 N Princeton St, Portland OR\n", 603 | "address ---- 31 Lawndale Ave, Lebanon OH; 18 Lawndale Ave, Lebanon OH; 908 Hartz Dr, Lebanon OH; 917 Stanwood Dr, Lebanon OH; 923 Birchwood Dr, Lebanon OH\n", 604 | "address ---- 1721 Papillon St, North Port FL; 4113 Wabasso Ave, North Port FL; 3681 Wayward Ave, North Port FL; 1118 N Salford Blvd, North Port FL; 2057 Bendix Ter, North Port FL\n", 605 | "address ---- 5 Brand Rd, Toms River NJ; 40 12th St, Toms River NJ; 75 Sea Breeze Rd, Toms River NJ; 98 Oak Tree Ln, Toms River NJ; 67 16th St, Toms River NJ\n", 606 | "address ---- 652 Martha St, Montgomery AL; 3184 Lexington Rd, Montgomery AL; 120 S Lewis St, Montgomery AL; 1812 W 2nd St #OP, Montgomery AL; 3582 Southview Ave, Montgomery AL\n", 607 | "---------------------------------\n" 608 | ] 609 | } 610 | ], 611 | "source": [ 612 | "# visualize each cluster\n", 613 | "\n", 614 | "import pandas as pd\n", 615 | "import pickle\n", 616 | "\n", 617 | "dataset_path = '/nfs/users/yuliang/SDD/data/viznet/test.csv.full'\n", 618 | "ccs = pickle.load(open(\"../data/viznet/multi_column/clusters.pkl\", \"rb\"))\n", 619 | "testset = pd.read_csv(dataset_path)\n", 620 | "\n", 621 | "def show_cc(ccs, idx):\n", 622 | " for cid in ccs[idx][:10]:\n", 623 | " table_id = testset['table_id'][cid]\n", 624 | " column_id = testset['column_id'][cid]\n", 625 | " label = testset['class'][cid]\n", 626 | "\n", 627 | " table = pd.read_csv('/nfs/users/yuliang/SDD/data/viznet/tables/table_%d.csv' % table_id)\n", 628 | " value = '; '.join(table[table.columns[column_id]][:5].astype(str))\n", 629 | "\n", 630 | " # print(label, cid, table_id, column_id, '----', value)\n", 631 | " print(label, '----', value)\n", 632 | " print(\"---------------------------------\")\n", 633 | "\n", 634 | "show_cc(ccs, 10)\n", 635 | "show_cc(ccs, 35)\n", 636 | "show_cc(ccs, 57)\n", 637 | "show_cc(ccs, 69)\n", 638 | "show_cc(ccs, 78)\n" 639 | ] 640 | }, 641 | { 642 | "cell_type": "code", 643 | "execution_count": null, 644 | "metadata": {}, 645 | "outputs": [], 646 | "source": [] 647 | }, 648 | { 649 | "cell_type": "code", 650 | "execution_count": null, 651 | "metadata": {}, 652 | "outputs": [], 653 | "source": [] 654 | } 655 | ], 656 | "metadata": { 657 | "interpreter": { 658 | "hash": "a4defd7b6bdd6eb378ea4d7dbbc33dd84cc0def60f9189b1f6f216e34c2c378c" 659 | }, 660 | "kernelspec": { 661 | "display_name": "Python 3.7.10 64-bit ('sdd')", 662 | "language": "python", 663 | "name": "python3" 664 | }, 665 | "language_info": { 666 | "codemirror_mode": { 667 | "name": "ipython", 668 | "version": 3 669 | }, 670 | "file_extension": ".py", 671 | "mimetype": "text/x-python", 672 | "name": "python", 673 | "nbconvert_exporter": "python", 674 | "pygments_lexer": "ipython3", 675 | "version": "3.7.10" 676 | }, 677 | "orig_nbformat": 4 678 | }, 679 | "nbformat": 4, 680 | "nbformat_minor": 2 681 | } 682 | -------------------------------------------------------------------------------- /plotMetrics.py: -------------------------------------------------------------------------------- 1 | from matplotlib import * 2 | from matplotlib import pyplot as plt 3 | 
import numpy as np 4 | 5 | 6 | fsize = 20 7 | tsize = 14 8 | # ============================================================================= 9 | # Plotting of Figures for the paper 10 | # ============================================================================= 11 | def plotMapFig(benchmark, map_dict): 12 | ''' Plot the MAP scores as bar chart (NOT USED) 13 | Args: 14 | benchmark (str) 15 | map_dict (dict): stores each method with their associated map scores 16 | ''' 17 | labels = {"d3l":r"$D^{3}L$", 18 | "SANTOS": r"SANTOS", 19 | "Starmie": r"Starmie", 20 | "SingleCol": r"SingleCol", 21 | "SATO":r"SATO", 22 | "Sherlock": r"Sherlock" 23 | } 24 | 25 | # ========== MAP ========== 26 | x = [] 27 | for method in map_dict.keys(): 28 | x.append(labels[method]) 29 | y = list(map_dict.values()) 30 | 31 | x_ticks = np.arange(len(x)) 32 | width = 0.7 33 | fig, ax = plt.subplots() 34 | ax.set_ylabel('MAP@k', fontsize=tsize) 35 | ax.set_xlabel('Method', fontsize=tsize) 36 | ax.set_xticks(x_ticks) 37 | ax.set_xticklabels(x) 38 | if benchmark in ['santos', 'tus_large']: 39 | plt.ylim(0.4, 1.03) 40 | elif benchmark == 'tus_small': 41 | plt.ylim(0.75, 1.01) 42 | # Annotate each bar with their MAP score 43 | pps = ax.bar(x_ticks, y, width, label='MAP@k') 44 | for p in pps: 45 | height = p.get_height() 46 | ax.annotate('{}'.format(height), 47 | xy=(p.get_x() + p.get_width() / 2, height), 48 | xytext=(0, 3), 49 | textcoords="offset points", 50 | ha='center', va='bottom') 51 | # save the figure to a local path 52 | fig.savefig('../../Starmie/%s_map.pdf' % (benchmark)) 53 | plt.show() 54 | 55 | 56 | 57 | def plotJointFig(k, benchmark, precision_list_dict, recall_list_dict, ideal_list): 58 | ''' Plot P@K and R@K figures for a specified benchmark 59 | Saves plot to a local filepath, and shows the path (the legend is hidden) 60 | Args: 61 | k (list): list of k values associated with each score 62 | benchmark (str): e.g. 
'santos', 'tus_small' 63 | precision_list_dict (dict): With each method as key, its value is the list of precision scores for each k 64 | recall_list_dict (dict): With each method as key, its value is the list of recall scores for each k 65 | ideal_list: list of IDEAL recall scores for each k 66 | ''' 67 | # number of methods to compare 68 | col_number = 6 69 | # Formatting / Styling choices 70 | colors = {"d3l":"#e52638", 71 | "SANTOS":"#777777", 72 | "Starmie": "royalblue", 73 | "SingleCol": "#68affc", 74 | "SATO": "#699f3c", 75 | "Sherlock": "darkgoldenrod" 76 | } 77 | linestyles = {"d3l":"dashed", 78 | "SANTOS":"dotted", 79 | "Starmie": "solid", 80 | "SingleCol": "dashdot", 81 | "SATO": (0, (3, 1, 1, 1)), 82 | "Sherlock": (0, (3, 1, 1, 1, 1, 1)) 83 | } 84 | 85 | labels = {"d3l":r"$D^{3}L$", 86 | "SANTOS": r"SANTOS", 87 | "Starmie": r"Starmie", 88 | "SingleCol": r"SingleCol", 89 | "SATO":r"SATO", 90 | "Sherlock": r"Sherlock" 91 | } 92 | 93 | markers = {"d3l":"^", 94 | "SANTOS":"o", 95 | "Starmie": "s", 96 | "SingleCol": "*", 97 | "SATO": "p", 98 | "Sherlock": "+" 99 | } 100 | # ========== PRECISION/RECALL with LEGEND ========== 101 | fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(12,5)) 102 | for ax in axes: 103 | # formatting for both P@K and R@K graphs: x-axis is labeled with "k", add grids, set font sizes 104 | ax.set_xlabel("k", fontsize=fsize) 105 | ax.grid(linestyle = '--', linewidth = 0.5) 106 | for item in (ax.get_xticklabels() + ax.get_yticklabels()): 107 | item.set_fontsize(tsize) 108 | plt.rc('legend', fontsize=tsize) 109 | 110 | # Plot the Precision@K graph 111 | axes[0].set_ylabel("P@k", fontsize=fsize) 112 | for which_method in precision_list_dict: 113 | axes[0].plot(k, precision_list_dict[which_method], color = colors[which_method], linestyle = linestyles[which_method], linewidth = 2, label = labels[which_method], marker=markers[which_method], markersize = 10) 114 | 115 | # Plot the Recall@K graph, along with IDEAL recall 116 | axes[1].set_ylabel("R@k", fontsize=fsize) 117 | for which_method in recall_list_dict: 118 | axes[1].plot(k, recall_list_dict[which_method], color = colors[which_method], linestyle = linestyles[which_method], linewidth = 2, label = labels[which_method], marker=markers[which_method], markersize = 10) 119 | axes[1].plot(k, ideal_list, color = "black", label = "IDEAL", linewidth = 3) 120 | 121 | # Add Legend 122 | handles, labels = axes[1].get_legend_handles_labels() 123 | lgd = fig.legend(handles, labels, bbox_to_anchor=(0.5, 1.1), ncol=col_number+1, loc='upper center') 124 | fig.tight_layout() 125 | 126 | # Save figure to a local path 127 | fig.savefig('../../Starmie/%s_P_R.pdf' % (benchmark), bbox_extra_artists=(lgd,), bbox_inches='tight') 128 | plt.show() 129 | 130 | 131 | def plotScalFig(k, dl_sizes, benchmark, scal_k, scal_size): 132 | ''' Plot the scalability figures for a specified benchmark 133 | Saves plot to a local filepath, and shows the path (the legend is hidden) 134 | Args: 135 | k (list): list of k values associated with each score 136 | dl_sizes (list): list of data lake sizes for x_axis of scalability graph for varying DL size 137 | benchmark (str): e.g. 
'real', 'wdc' 138 | scal_k (dict): With each technique as key, its value is the list of query times (in ms) for each k 139 | scal_size (dict): With each technique as key, its value is the list of query times (in ms) for each data lake size 140 | ''' 141 | # number of methods to compare 142 | col_number = 4 143 | # Formatting / Styling choices 144 | colors = {"Linear":"royalblue", 145 | "Bounds":"green", 146 | "LSH": "red", 147 | "HNSW": "darkgoldenrod" 148 | } 149 | linestyles = {"LSH":"dashed", 150 | "Bounds":"dotted", 151 | "Linear": "solid", 152 | "HNSW": "dashdot" 153 | } 154 | 155 | labels = {"Linear":"Linear", 156 | "Bounds": "Bounds", 157 | "LSH": "LSH Index", 158 | "HNSW": "HNSW Index" 159 | } 160 | 161 | markers = {"LSH":"^", 162 | "Bounds":"o", 163 | "Linear": "s", 164 | "HNSW": "*" 165 | } 166 | 167 | x_axis_labels = ['K', 'Data Lake Size (# tables / # attributes)'] 168 | # ========== SCALABILITY with LEGEND ========== 169 | fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(12,5)) 170 | for ind, ax in enumerate(axes): 171 | # formatting for both graphs: set the per-subplot x-axis label, add grids, set font sizes 172 | ax.set_xlabel(x_axis_labels[ind], fontsize=fsize) 173 | ax.set_ylabel("Average Query Time (sec)", fontsize=fsize) 174 | ax.grid(linestyle = '--', linewidth = 0.5) 175 | for item in (ax.get_xticklabels() + ax.get_yticklabels()): 176 | item.set_fontsize(tsize) 177 | plt.rc('legend', fontsize=tsize) 178 | 179 | # Plot the graph with varying k's as the x-axis 180 | for which_method in scal_k: 181 | # convert the query times from ms to seconds for plotting 182 | scal_method = [int(qt)/1000 for qt in scal_k[which_method]] 183 | axes[0].plot(k, scal_method, color = colors[which_method], linestyle = linestyles[which_method], linewidth = 2, label = labels[which_method], marker=markers[which_method], markersize = 10) 184 | 185 | # Plot the graph for scalability with growing data lake size 186 | for which_method in scal_size: 187 | # convert the query times from ms to seconds for plotting 188 | scal_method = [int(qt)/1000 for qt in scal_size[which_method]] 189 | axes[1].plot(dl_sizes, scal_method, color = colors[which_method], linestyle = linestyles[which_method], linewidth = 2, label = labels[which_method], marker=markers[which_method], markersize = 10) 190 | 191 | # Add legend 192 | handles, labels = axes[1].get_legend_handles_labels() 193 | lgd = fig.legend(handles, labels, bbox_to_anchor=(0.5, 1.1), ncol=col_number+1, loc='upper center') 194 | fig.tight_layout() 195 | 196 | # Save plot to local file path 197 | fig.savefig('../../Starmie/%s_scal_sec.pdf' % (benchmark), bbox_extra_artists=(lgd,), bbox_inches='tight') 198 | plt.show() 199 | 200 | 201 | if __name__ == '__main__': 202 | ''' 203 | Plot the experimental results figures, shown in the paper 204 | ''' 205 | # ========== Metrics Dictionaries ========== 206 | # ========================================== 207 | ''' Plot the Performance metrics for each benchmark: SANTOS, TUS Small, TUS Large ''' 208 | # ---------- 1. 
SANTOS Benchmark ----------- 209 | precision_dict_santos = {'Starmie': [1, 1, 0.992, 0.991, 0.984], 'SingleCol': [1, 0.927, 0.896, 0.869, 0.798], 'SATO': [1, 0.913, 0.872, 0.846, 0.806], 'Sherlock': [1, 0.833, 0.772, 0.726, 0.672], 'SANTOS': [0.98, 0.947, 0.936, 0.926, 0.908], 'd3l': [0.5, 0.467, 0.512, 0.546, 0.576]} 210 | map_dict_santos = {'Starmie': 0.993, 'SingleCol': 0.891, 'SATO': 0.878, 'Sherlock': 0.782, 'SANTOS': 0.93, 'd3l': 0.523} 211 | recall_dict_santos = {'Starmie': [0.075,0.225,0.372,0.52,0.737], 'SingleCol': [0.075,0.208,0.333,0.451,0.588], 'SATO': [0.08,0.203,0.322,0.436,0.594], 'Sherlock': [0.08,0.185,0.284,0.373,0.493], 'SANTOS': [0.074,0.215,0.353,0.49,0.69], 'd3l': [0.037,0.099,0.185,0.278,0.422]} 212 | ideal_santos = [0.08,0.23,0.38,0.53,0.75] 213 | k_santos = [1,3,5,7,10] 214 | 215 | # plotJointFig(k_santos, 'santos', precision_dict_santos, recall_dict_santos, ideal_santos) 216 | # plotMapFig('santos', map_dict_santos) 217 | 218 | # ---------- 2. TUS Small Benchmark ----------- 219 | precision_dict_tus_small = {'Starmie': [0.998,0.995,0.993,0.989,0.984,0.977], 'SingleCol': [0.977,0.97,0.956,0.944,0.927,0.907], 'SATO': [0.972,0.962,0.962,0.961,0.96,0.956], 'Sherlock': [0.998,0.995,0.993,0.985,0.967,0.933], 'SANTOS': [0.934,0.903,0.886,0.873,0.845,0.814], 'd3l': [0.807,0.804,0.8,0.792,0.777,0.765]} 220 | map_dict_tus_small = {'Starmie': 0.991, 'SingleCol': 0.954, 'SATO': 0.966, 'Sherlock': 0.984, 'SANTOS': 0.885, 'd3l': 0.794} 221 | recall_dict_tus_small = {'Starmie': [0.047,0.094,0.14,0.187,0.232,0.277], 'SingleCol': [0.046,0.091,0.135,0.177,0.217,0.255], 'SATO': [0.046,0.091,0.136,0.181,0.227,0.271], 'Sherlock': [0.047,0.094,0.141,0.186,0.229,0.265], 'SANTOS': [0.044,0.084,0.125,0.164,0.199,0.23], 'd3l': [0.038,0.076,0.113,0.149,0.182,0.215]} 222 | ideal_tus_small = [0.057,0.114,0.17,0.227,0.284,0.341] 223 | k_tus = [10,20,30,40,50,60] 224 | 225 | # plotJointFig(k_tus, 'tus_small', precision_dict_tus_small, recall_dict_tus_small, ideal_tus_small) 226 | # plotMapFig('tus_small', map_dict_tus_small) 227 | 228 | 229 | 230 | # ---------- 3. TUS Large Benchmark ----------- 231 | precision_dict_tus_large = {'Starmie': [0.997,0.988,0.967,0.948,0.932,0.915], 'SingleCol': [0.951,0.925,0.903,0.876,0.85,0.824], 'SATO': [0.978,0.96,0.931,0.907,0.886,0.866], 'Sherlock': [0.929,0.83,0.734,0.654,0.581,0.525], 'd3l': [0.495,0.469,0.464,0.464,0.473,0.468]} 232 | map_dict_tus_large = {'Starmie': 0.965, 'SingleCol': 0.902, 'SATO': 0.93, 'Sherlock': 0.744, 'd3l': 0.484} 233 | recall_dict_tus_large = {'Starmie': [0.045,0.088,0.129,0.167,0.204,0.238], 'SingleCol': [0.043,0.082,0.119,0.153,0.183,0.208], 'SATO': [0.044,0.086,0.125,0.161,0.193,0.223], 'Sherlock': [0.041,0.071,0.092,0.105,0.114,0.119], 'd3l': [0.019,0.039,0.06,0.082,0.105,0.124]} 234 | ideal_tus_large = [0.046,0.092,0.138,0.185,0.231,0.277] 235 | 236 | # plotJointFig(k_tus, 'tus_large', precision_dict_tus_large, recall_dict_tus_large, ideal_tus_large) 237 | # plotMapFig('tus_large', map_dict_tus_large) 238 | 239 | 240 | 241 | ''' Plot scalability graphs for SANTOS REAL, WDC benchmarks. In the paper: include tables for indexing time and storage overhead ''' 242 | # ====== Scalability Dictionaries ========== 243 | # ========================================== 244 | # ---------- 1. 
SANTOS Real Benchmark ----------- 245 | scal_real_k = {'Linear': [71880,70620,70460,70540,70680,70580], 'Bounds': [30350,33050,34450,35520,36690,37230], 'LSH': [3470,3510,3460,3470,3420,3560], 'HNSW': [330,330,320,320,330,320]} 246 | scal_real_size = {'Linear': [13630,28120,42220,56920,70580], 'Bounds': [9540,16840,23930,30890,37230], 'LSH': [960,1590,2100,2890,3560], 'HNSW': [500,460,340,320,320]} 247 | k_scal = [10,20,30,40,50,60] 248 | dl_real_sizes = ['2.2K / 24K','4.4K / 48K','6.6K / 72K','8.8K / 96K','11K / 120K'] 249 | plotScalFig(k_scal, dl_real_sizes, 'real', scal_real_k, scal_real_size) 250 | 251 | # ---------- 2. WDC Benchmark ----------- 252 | scal_wdc_k = {'Linear': [865960,847880,818070,874010,874810,819170], 'Bounds': [341650,335850,356130,341370,339990,338580], 'LSH': [94370,101840,104660,94650,97420,106410], 'HNSW': [240,220,230,240,280,300]} 253 | scal_wdc_size = {'Linear': [161910,324000,488310,668493,819170], 'Bounds': [69530,155140,204780,274310,338580], 'LSH': [16350,34450,53850,72670,106410], 'HNSW': [230,290,200,340,300]} 254 | dl_wdc_sizes = ['200K / 1M ','400K / 2M','600K / 3M','800K / 4M','1M / 5M'] 255 | # plotScalFig(k_scal, dl_wdc_sizes, 'wdc', scal_wdc_k, scal_wdc_size) -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy==1.19.2 2 | regex==2019.12.20 3 | scipy==1.3.3 4 | sentencepiece==0.1.85 5 | sklearn==0.0 6 | spacy==2.2.3 7 | tensorboardX==2.0 8 | jsonlines==1.2.0 9 | nltk==3.4.5 10 | torch==1.9.0+cu111 11 | tqdm==4.41.0 12 | transformers==4.9.2 -------------------------------------------------------------------------------- /run_all.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | datasets = ['small', 'large'] 4 | lms = ['distilbert', 'roberta'] 5 | batch_sizes = [32, 64] 6 | 7 | for ds in datasets: 8 | for lm in lms: 9 | for batch_size in batch_sizes: 10 | for run_id in range(5): 11 | cmd = """python train.py \ 12 | --task %s \ 13 | --batch_size %s \ 14 | --lr 5e-5 \ 15 | --lm %s \ 16 | --n_epochs 20 \ 17 | --max_len 128 \ 18 | --fp16 \ 19 | --run_id %d""" % (ds, batch_size, lm, run_id) 20 | print(cmd) 21 | os.system('sbatch -c 1 -G 1 -J my-exp --tasks-per-node=1 --wrap="%s"' % cmd) 22 | -------------------------------------------------------------------------------- /run_pretrain.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import numpy as np 3 | import random 4 | import torch 5 | import mlflow 6 | 7 | from sdd.dataset import PretrainTableDataset 8 | from sdd.pretrain import train 9 | 10 | if __name__ == '__main__': 11 | parser = argparse.ArgumentParser() 12 | parser.add_argument("--task", type=str, default="small") 13 | parser.add_argument("--logdir", type=str, default="results/") 14 | parser.add_argument("--run_id", type=int, default=0) 15 | parser.add_argument("--batch_size", type=int, default=32) 16 | parser.add_argument("--max_len", type=int, default=128) 17 | parser.add_argument("--size", type=int, default=10000) 18 | parser.add_argument("--lr", type=float, default=5e-5) 19 | parser.add_argument("--n_epochs", type=int, default=20) 20 | parser.add_argument("--lm", type=str, default='roberta') 21 | parser.add_argument("--projector", type=int, default=768) 22 | parser.add_argument("--augment_op", type=str, default='drop_col,sample_row') 23 | parser.add_argument("--save_model", 
dest="save_model", action="store_true") 24 | parser.add_argument("--fp16", dest="fp16", action="store_true") 25 | # single-column mode without table context 26 | parser.add_argument("--single_column", dest="single_column", action="store_true") 27 | # row / column-ordered for preprocessing 28 | parser.add_argument("--table_order", type=str, default='column') 29 | # for sampling 30 | parser.add_argument("--sample_meth", type=str, default='head') 31 | # mlflow tag 32 | parser.add_argument("--mlflow_tag", type=str, default=None) 33 | 34 | hp = parser.parse_args() 35 | 36 | # mlflow logging 37 | for variable in ["task", "batch_size", "lr", "n_epochs", "augment_op", "sample_meth", "table_order"]: 38 | mlflow.log_param(variable, getattr(hp, variable)) 39 | 40 | if hp.mlflow_tag: 41 | mlflow.set_tag("tag", hp.mlflow_tag) 42 | 43 | # set seed 44 | seed = hp.run_id 45 | random.seed(seed) 46 | np.random.seed(seed) 47 | torch.manual_seed(seed) 48 | if torch.cuda.is_available(): 49 | torch.cuda.manual_seed_all(seed) 50 | 51 | # Change the data paths to where the benchmarks are stored 52 | if "santos" in hp.task: 53 | path = 'data/%s/datalake' % hp.task 54 | if hp.task == "santosLarge": 55 | path = 'data/santos-benchmark/real-benchmark/datalake' 56 | elif "tus" in hp.task: 57 | path = 'data/table-union-search-benchmark/small/benchmark' 58 | if hp.task == "tusLarge": 59 | path = 'data/table-union-search-benchmark/large/benchmark' 60 | 61 | else: 62 | path = 'data/%s/tables' % hp.task 63 | # trainset = PretrainTableDataset(path, 64 | # augment_op=hp.augment_op, 65 | # lm=hp.lm, 66 | # max_len=hp.max_len, 67 | # size=hp.size, 68 | # single_column=hp.single_column, 69 | # sample_meth=hp.sample_meth) 70 | trainset = PretrainTableDataset.from_hp(path, hp) 71 | 72 | train(trainset, hp) 73 | -------------------------------------------------------------------------------- /run_pretrain_all.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | max_lens = [128, 256] 4 | augment_ops = ['drop_col', 'sample_row', 'sample_row_ordered', 'shuffle_col', 'drop_cell', 'drop_num_col', 'drop_nan_col', 'shuffle_row'] 5 | sampling_methods = ['head', 'random', 'constant', 'frequent', 'tfidf_token', 'tfidf_entity'] 6 | 7 | for ml in max_lens: 8 | for ao in [augment_ops[4]]: 9 | for sm in sampling_methods: 10 | for run_id in range(5): 11 | # add --single_column for baseline 12 | cmd = """python run_pretrain.py \ 13 | --task %s \ 14 | --batch_size 64 \ 15 | --lr 5e-5 \ 16 | --lm roberta \ 17 | --n_epochs 3 \ 18 | --max_len %d \ 19 | --size 10000 \ 20 | --save_model \ 21 | --single_column \ 22 | --augment_op %s \ 23 | --fp16 \ 24 | --sample_meth %s \ 25 | --run_id %d""" % ("small", ml, ao, sm, run_id) 26 | print(cmd) 27 | os.system('sbatch -c 1 -G 1 -J my-exp --tasks-per-node=1 --output=slurm.out --wrap="%s"' % cmd) -------------------------------------------------------------------------------- /run_tus_all.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | 4 | benchmark = 'wdc' 5 | matching = 'exact' 6 | table_order = 'column' 7 | augment_ops = ['drop_col', 'sample_row', 'sample_row_ordered', 'shuffle_col', 'drop_cell', 'drop_num_col', 'drop_nan_col', 'shuffle_row'] 8 | sampling_methods = ['head', 'random', 'constant', 'frequent', 'tfidf_token', 'tfidf_entity', 'tfidf_row'] 9 | 10 | k = 60 11 | threshold = 0.1 12 | enc = 'cl' 13 | ao = 'drop_col' 14 | sm = 'tfidf_entity' 15 | run_id = 0 16 | cmd = 
"""python test_naive_search.py \ 17 | --encoder %s \ 18 | --benchmark %s \ 19 | --augment_op %s \ 20 | --sample_meth %s \ 21 | --matching %s \ 22 | --table_order %s \ 23 | --run_id %d \ 24 | --K %d \ 25 | --threshold %f \ 26 | --scal %f""" % (enc, benchmark, ao, sm, matching, table_order, run_id, k, threshold, scale) 27 | 28 | print(cmd) 29 | os.system('sbatch -c 1 -G 1 -J my-exp --tasks-per-node=1 --output=slurm.out --wrap="%s"' % cmd) 30 | -------------------------------------------------------------------------------- /sdd/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/megagonlabs/starmie/5eb90fe27fb1162d2a62b555ac54908ee8e4c474/sdd/__init__.py -------------------------------------------------------------------------------- /sdd/augment.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import random 3 | 4 | def augment(table: pd.DataFrame, op: str): 5 | """Apply an augmentation operator on a table. 6 | 7 | Args: 8 | table (DataFrame): the input table 9 | op (str): operator name 10 | 11 | Return: 12 | DataFrame: the augmented table 13 | """ 14 | if op == 'drop_col': 15 | # set values of a random column to 0 16 | col = random.choice(table.columns) 17 | table = table.copy() 18 | table[col] = "" 19 | elif op == 'sample_row': 20 | # sample 50% of rows 21 | if len(table) > 0: 22 | table = table.sample(frac=0.5) 23 | elif op == 'sample_row_ordered': 24 | # sample 50% of rows 25 | if len(table) > 0: 26 | table = table.sample(frac=0.5).sort_index() 27 | elif op == 'shuffle_col': 28 | # shuffle the column orders 29 | new_columns = list(table.columns) 30 | random.shuffle(new_columns) 31 | table = table[new_columns] 32 | elif op == 'drop_cell': 33 | # drop a random cell 34 | table = table.copy() 35 | row_idx = random.randint(0, len(table) - 1) 36 | col_idx = random.randint(0, len(table.columns) - 1) 37 | table.iloc[row_idx, col_idx] = "" 38 | elif op == 'sample_cells': 39 | # sample half of the cells randomly 40 | table = table.copy() 41 | col_idx = random.randint(0, len(table.columns) - 1) 42 | sampleRowIdx = [] 43 | for _ in range(len(table) // 2 - 1): 44 | sampleRowIdx.append(random.randint(0, len(table) - 1)) 45 | for ind in sampleRowIdx: 46 | table.iloc[ind, col_idx] = "" 47 | elif op == 'replace_cells': 48 | # replace half of the cells randomly with the first values after sorting 49 | table = table.copy() 50 | col_idx = random.randint(0, len(table.columns) - 1) 51 | sortedCol = table[table.columns[col_idx]].sort_values().tolist() 52 | sampleRowIdx = [] 53 | for _ in range(len(table) // 2 - 1): 54 | sampleRowIdx.append(random.randint(0, len(table) - 1)) 55 | for ind in sampleRowIdx: 56 | table.iloc[ind, col_idx] = sortedCol[ind] 57 | elif op == 'drop_head_cells': 58 | # drop the first quarter of cells 59 | table = table.copy() 60 | col_idx = random.randint(0, len(table.columns) - 1) 61 | sortedCol = table[table.columns[col_idx]].sort_values().tolist() 62 | sortedHead = sortedCol[:len(table)//4] 63 | for ind in range(0,len(table)): 64 | if table.iloc[ind, col_idx] in sortedHead: 65 | table.iloc[ind, col_idx] = "" 66 | elif op == 'drop_num_cells': 67 | # drop numeric cells 68 | table = table.copy() 69 | tableCols = list(table.columns) 70 | numTable = table.select_dtypes(include=['number']) 71 | numCols = numTable.columns.tolist() 72 | if numCols == []: 73 | col_idx = random.randint(0, len(table.columns) - 1) 74 | else: 75 | col = random.choice(numCols) 
76 | col_idx = tableCols.index(col) 77 | sampleRowIdx = [] 78 | for _ in range(len(table) // 2 - 1): 79 | sampleRowIdx.append(random.randint(0, len(table) - 1)) 80 | for ind in sampleRowIdx: 81 | table.iloc[ind, col_idx] = "" 82 | elif op == 'swap_cells': 83 | # randomly swap two cells 84 | table = table.copy() 85 | row_idx = random.randint(0, len(table) - 1) 86 | row2_idx = random.randint(0, len(table) - 1) 87 | while row2_idx == row_idx: 88 | row2_idx = random.randint(0, len(table) - 1) 89 | col_idx = random.randint(0, len(table.columns) - 1) 90 | cell1 = table.iloc[row_idx, col_idx] 91 | cell2 = table.iloc[row2_idx, col_idx] 92 | table.iloc[row_idx, col_idx] = cell2 93 | table.iloc[row2_idx, col_idx] = cell1 94 | elif op == 'drop_num_col': # number of columns is not preserved 95 | # randomly drop about half of the numeric columns (the rest are kept) 96 | numTable = table.select_dtypes(include=['number']) 97 | numCols = numTable.columns.tolist() 98 | textTable = table.select_dtypes(exclude=['number']) 99 | textCols = textTable.columns.tolist() 100 | addedCols = 0 101 | while addedCols <= len(numCols) // 2 and len(numCols) > 0: 102 | numRandCol = numCols.pop(random.randrange(len(numCols))) 103 | textCols.append(numRandCol) 104 | addedCols += 1 105 | textCols = sorted(textCols,key=list(table.columns).index) 106 | table = table[textCols] 107 | elif op == 'drop_nan_col': # number of columns is not preserved 108 | # drop half of the columns that contain NaN values (all NaN-free columns are kept) 109 | newCols, nanSums = [], {} 110 | for column in table.columns: 111 | if table[column].isna().sum() != 0: 112 | nanSums[column] = table[column].isna().sum() 113 | else: 114 | newCols.append(column) 115 | nanSums = {k: v for k, v in sorted(nanSums.items(), key=lambda item: item[1], reverse=True)} 116 | nanCols = list(nanSums.keys()) 117 | newCols += random.sample(nanCols, len(nanCols) // 2) 118 | table = table[newCols] 119 | elif op == 'shuffle_row': 120 | # shuffle the rows 121 | table = table.sample(frac=1) 122 | 123 | return table -------------------------------------------------------------------------------- /sdd/baselines.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torch.optim as optim 5 | import numpy as np 6 | import sklearn.metrics as metrics 7 | import mlflow 8 | 9 | from .utils import evaluate 10 | from .model import TableModel 11 | from .dataset import TableDataset 12 | 13 | from torch.utils import data 14 | from transformers import AdamW, get_linear_schedule_with_warmup 15 | 16 | 17 | def train_step(train_iter, model, optimizer, scheduler, scaler, hp): 18 | """Perform a single training step 19 | 20 | Args: 21 | train_iter (Iterator): the train data loader 22 | model (TableModel): the model 23 | optimizer (Optimizer): the optimizer (Adam or AdamW) 24 | scheduler (LRScheduler): learning rate scheduler 25 | scaler (GradScaler): gradient scaler for fp16 training 26 | hp (Namespace): other hyper-parameters (e.g., fp16) 27 | 28 | Returns: 29 | None 30 | """ 31 | criterion = nn.CrossEntropyLoss() 32 | # criterion = nn.MSELoss() 33 | for i, batch in enumerate(train_iter): 34 | # x1, x2, x12, y = batch 35 | x, y = batch 36 | optimizer.zero_grad() 37 | if hp.fp16: 38 | with torch.cuda.amp.autocast(): 39 | prediction = model(x) 40 | loss = criterion(prediction, y.to(model.device)) 41 | scaler.scale(loss).backward() 42 | scaler.step(optimizer) 43 | scaler.update() 44 | else: 45 | prediction = model(x) 46 | loss = criterion(prediction, 
y.to(model.device)) 47 | loss.backward() 48 | optimizer.step() 49 | 50 | scheduler.step() 51 | if i % 10 == 0: # monitoring 52 | print(f"step: {i}, loss: {loss.item()}") 53 | del loss 54 | 55 | 56 | def train(trainset, validset, testset, hp): 57 | """Train and evaluate the model 58 | 59 | Args: 60 | trainset (TableDataset): the training set 61 | validset (TableDataset): the validation set 62 | testset (TableDataset): the test set 63 | hp (Namespace): Hyper-parameters (e.g., batch_size, 64 | learning rate, fp16) 65 | Returns: 66 | None 67 | """ 68 | padder = trainset.pad 69 | # create the DataLoaders 70 | train_iter = data.DataLoader(dataset=trainset, 71 | batch_size=hp.batch_size, 72 | shuffle=True, 73 | num_workers=0, 74 | collate_fn=padder) 75 | valid_iter = data.DataLoader(dataset=validset, 76 | batch_size=hp.batch_size, 77 | shuffle=False, 78 | num_workers=0, 79 | collate_fn=padder) 80 | test_iter = data.DataLoader(dataset=testset, 81 | batch_size=hp.batch_size, 82 | shuffle=False, 83 | num_workers=0, 84 | collate_fn=padder) 85 | 86 | # initialize model, optimizer, and LR scheduler 87 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 88 | model = TableModel(device=device, lm=hp.lm) 89 | model = model.to(device) # use the device chosen above (was model.cuda(), which fails on CPU-only machines) 90 | optimizer = AdamW(model.parameters(), lr=hp.lr) 91 | if hp.fp16: 92 | scaler = torch.cuda.amp.GradScaler() 93 | else: 94 | scaler = None 95 | 96 | num_steps = (len(trainset) // hp.batch_size) * hp.n_epochs 97 | scheduler = get_linear_schedule_with_warmup(optimizer, 98 | num_warmup_steps=0, 99 | num_training_steps=num_steps) 100 | 101 | 102 | 103 | best_dev_f1 = best_test_f1 = 0.0 104 | for epoch in range(1, hp.n_epochs+1): 105 | # train 106 | model.train() 107 | train_step(train_iter, model, optimizer, scheduler, scaler, hp) 108 | 109 | # eval 110 | model.eval() 111 | dev_f1, th = evaluate(model, valid_iter) 112 | test_f1 = evaluate(model, test_iter, threshold=th) 113 | 114 | if dev_f1 > best_dev_f1: 115 | best_dev_f1 = dev_f1 116 | best_test_f1 = test_f1 117 | print(f"epoch {epoch}: dev_f1={dev_f1}, f1={test_f1}, best_f1={best_test_f1}") 118 | 119 | # logging to mlflow 120 | for variable in ["dev_f1", "test_f1", "best_test_f1"]: 121 | mlflow.log_metric(variable, eval(variable)) -------------------------------------------------------------------------------- /sdd/dataset.py: -------------------------------------------------------------------------------- 1 | from argparse import Namespace 2 | import torch 3 | import random 4 | import pandas as pd 5 | import os 6 | 7 | from torch.utils import data 8 | from transformers import AutoTokenizer 9 | from .augment import augment 10 | from typing import List 11 | from .preprocessor import computeTfIdf, tfidfRowSample, preprocess 12 | 13 | # map lm name to huggingface's pre-trained model names 14 | lm_mp = {'roberta': 'roberta-base', 15 | 'bert': 'bert-base-uncased', 16 | 'distilbert': 'distilbert-base-uncased'} 17 | 18 | 19 | class TableDataset(data.Dataset): 20 | """Table dataset""" 21 | 22 | def __init__(self, 23 | path, 24 | max_len=256, 25 | lm='roberta'): 26 | self.tokenizer = AutoTokenizer.from_pretrained(lm_mp[lm]) 27 | self.pairs = [] 28 | self.max_len = max_len 29 | self.samples = pd.read_csv(path) 30 | self.labels = self.samples['match'] 31 | self.table_path = os.path.join(os.path.split(path)[0], "tables") 32 | self.table_cache = {} 33 | 34 | def _read_table(self, table_id): 35 | """Read a table""" 36 | if table_id in self.table_cache: 37 | table = self.table_cache[table_id] 38 | else: 39 | table = 
pd.read_csv(os.path.join(self.table_path, 40 | "table_%d.csv" % table_id)) 41 | self.table_cache[table_id] = table 42 | 43 | return table 44 | 45 | 46 | def __len__(self): 47 | """Return the size of the dataset.""" 48 | return len(self.samples) 49 | 50 | def __getitem__(self, idx): 51 | """Return a tokenized item of the dataset. 52 | 53 | Args: 54 | idx (int): the index of the item 55 | 56 | Returns: 57 | List of int: token ID's of the two entities combined 58 | int: the label of the pair (0: unmatch, 1: match) 59 | """ 60 | # idx = random.randint(0, len(self.pairs)-1) 61 | l_table_id = self.samples['l_table_id'][idx] 62 | r_table_id = self.samples['r_table_id'][idx] 63 | l_column_id = self.samples['l_column_id'][idx] 64 | r_column_id = self.samples['r_column_id'][idx] 65 | 66 | l_table = self._read_table(l_table_id) 67 | r_table = self._read_table(r_table_id) 68 | 69 | l_column = l_table[l_table.columns[l_column_id]].astype(str) 70 | r_column = r_table[r_table.columns[r_column_id]].astype(str) 71 | 72 | # baseline: simple concatenation 73 | left = ' '.join(l_column) 74 | right = ' '.join(r_column) 75 | 76 | x = self.tokenizer.encode(text=left, 77 | text_pair=right, 78 | max_length=self.max_len, 79 | truncation=True) 80 | return x, self.labels[idx] 81 | 82 | 83 | def pad(self, batch): 84 | """Merge a list of dataset items into a train/test batch 85 | 86 | Args: 87 | batch (list of tuple): a list of dataset items 88 | 89 | Returns: 90 | LongTensor: x1 of shape (batch_size, seq_len) 91 | LongTensor: x2 of shape (batch_size, seq_len). 92 | Elements of x1 and x2 are padded to the same length 93 | LongTensor: x12 of shape (batch_size, seq_len'). 94 | Elements of x12 are padded to the same length 95 | LongTensor: a batch of labels, (batch_size,) 96 | """ 97 | if len(batch[0]) == 4: 98 | # em 99 | x1, x2, x12, y = zip(*batch) 100 | 101 | maxlen = max([len(x) for x in x1+x2]) 102 | 103 | x1 = [xi + [self.tokenizer.pad_token_id]*(maxlen - len(xi)) for xi in x1] 104 | x2 = [xi + [self.tokenizer.pad_token_id]*(maxlen - len(xi)) for xi in x2] 105 | 106 | maxlen = max([len(x) for x in x12]) 107 | x12 = [xi + [self.tokenizer.pad_token_id]*(maxlen - len(xi)) for xi in x12] 108 | 109 | return torch.LongTensor(x1), \ 110 | torch.LongTensor(x2), \ 111 | torch.LongTensor(x12), \ 112 | torch.LongTensor(y) 113 | else: 114 | # cleaning 115 | x1, y = zip(*batch) 116 | maxlen = max([len(x) for x in x1]) 117 | x1 = [xi + [self.tokenizer.pad_token_id]*(maxlen - len(xi)) for xi in x1] 118 | return torch.LongTensor(x1), torch.LongTensor(y) 119 | 120 | 121 | class PretrainTableDataset(data.Dataset): 122 | """Table dataset for pre-training""" 123 | 124 | def __init__(self, 125 | path, 126 | augment_op, 127 | max_len=256, 128 | size=None, 129 | lm='roberta', 130 | single_column=False, 131 | sample_meth='wordProb', 132 | table_order='column'): 133 | self.tokenizer = AutoTokenizer.from_pretrained(lm_mp[lm]) 134 | self.max_len = max_len 135 | self.path = path 136 | 137 | # assuming tables are in csv format 138 | self.tables = [fn for fn in os.listdir(path) if '.csv' in fn] 139 | 140 | # only keep the first n tables 141 | if size is not None: 142 | self.tables = self.tables[:size] 143 | 144 | self.table_cache = {} 145 | 146 | # augmentation operators 147 | self.augment_op = augment_op 148 | 149 | # logging counter 150 | self.log_cnt = 0 151 | 152 | # sampling method 153 | self.sample_meth = sample_meth 154 | 155 | # single-column mode 156 | self.single_column = single_column 157 | 158 | # row or column order for 
preprocessing 159 | self.table_order = table_order 160 | 161 | # tokenizer cache 162 | self.tokenizer_cache = {} 163 | 164 | @staticmethod 165 | def from_hp(path: str, hp: Namespace): 166 | """Construct a PretrainTableDataset from hyperparameters 167 | 168 | Args: 169 | path (str): the path to the table directory 170 | hp (Namespace): the hyperparameters 171 | 172 | Returns: 173 | PretrainTableDataset: the constructed dataset 174 | """ 175 | return PretrainTableDataset(path, 176 | augment_op=hp.augment_op, 177 | lm=hp.lm, 178 | max_len=hp.max_len, 179 | size=hp.size, 180 | single_column=hp.single_column, 181 | sample_meth=hp.sample_meth, 182 | table_order=hp.table_order) 183 | 184 | 185 | def _read_table(self, table_id): 186 | """Read a table""" 187 | if table_id in self.table_cache: 188 | table = self.table_cache[table_id] 189 | else: 190 | fn = os.path.join(self.path, self.tables[table_id]) 191 | table = pd.read_csv(fn, lineterminator='\n') 192 | self.table_cache[table_id] = table 193 | 194 | return table 195 | 196 | 197 | def _tokenize(self, table: pd.DataFrame) -> List[int]: 198 | """Tokenize a DataFrame table 199 | 200 | Args: 201 | table (DataFrame): the input table 202 | 203 | Returns: 204 | List of int: list of token ID's with special tokens inserted 205 | Dictionary: a map from column names to special tokens 206 | """ 207 | res = [] 208 | max_tokens = self.max_len * 2 // len(table.columns) 209 | budget = max(1, self.max_len // len(table.columns) - 1) 210 | tfidfDict = computeTfIdf(table) if "tfidf" in self.sample_meth else None # from preprocessor.py 211 | 212 | # a map from column names to special token indices 213 | column_mp = {} 214 | 215 | # column-ordered preprocessing 216 | if self.table_order == 'column': 217 | if 'row' in self.sample_meth: 218 | table = tfidfRowSample(table, tfidfDict, max_tokens) 219 | for column in table.columns: 220 | tokens = preprocess(table[column], tfidfDict, max_tokens, self.sample_meth) # from preprocessor.py 221 | col_text = self.tokenizer.cls_token + " " + \ 222 | ' '.join(tokens[:max_tokens]) + " " 223 | 224 | column_mp[column] = len(res) 225 | res += self.tokenizer.encode(text=col_text, 226 | max_length=budget, 227 | add_special_tokens=False, 228 | truncation=True) 229 | else: 230 | # row-ordered preprocessing 231 | reached_max_len = False 232 | for rid in range(len(table)): 233 | row = table.iloc[rid:rid+1] 234 | for column in table.columns: 235 | tokens = preprocess(row[column], tfidfDict, max_tokens, self.sample_meth) # from preprocessor.py 236 | if rid == 0: 237 | column_mp[column] = len(res) 238 | col_text = self.tokenizer.cls_token + " " + \ 239 | ' '.join(tokens[:max_tokens]) + " " 240 | else: 241 | col_text = self.tokenizer.pad_token + " " + \ 242 | ' '.join(tokens[:max_tokens]) + " " 243 | 244 | tokenized = self.tokenizer.encode(text=col_text, 245 | max_length=budget, 246 | add_special_tokens=False, 247 | truncation=True) 248 | 249 | if len(tokenized) + len(res) <= self.max_len: 250 | res += tokenized 251 | else: 252 | reached_max_len = True 253 | break 254 | 255 | if reached_max_len: 256 | break 257 | 258 | self.log_cnt += 1 259 | if self.log_cnt % 5000 == 0: 260 | print(self.tokenizer.decode(res)) 261 | 262 | return res, column_mp 263 | 264 | 265 | def __len__(self): 266 | """Return the size of the dataset.""" 267 | return len(self.tables) 268 | 269 | def __getitem__(self, idx): 270 | """Return a tokenized item of the dataset. 
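A minimal illustrative sketch of how this method is typically consumed (the path below is hypothetical, assuming a directory of CSV tables as in run_pretrain.py): each call builds two augmented views of the same table and records which columns survive in both views, so that pre-training can align their per-column CLS embeddings.

    ds = PretrainTableDataset('data/santos/datalake', augment_op='drop_col')
    x_ori, x_aug, cls_indices = ds[0]
    # x_ori, x_aug: lists of token ids with one CLS token per serialized column
    # cls_indices: [(pos_in_x_ori, pos_in_x_aug), ...] for columns shared by both views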
269 | def __getitem__(self, idx):
270 | """Return a tokenized item of the dataset.
271 | 
272 | Args:
273 | idx (int): the index of the item
274 | 
275 | Returns:
276 | List of int: token ID's of the first view
277 | List of int: token ID's of the second view
278 | """
279 | table_ori = self._read_table(idx)
280 | 
281 | # single-column mode: only keep one random column
282 | if self.single_column:
283 | col = random.choice(table_ori.columns)
284 | table_ori = table_ori[[col]]
285 | 
286 | # apply the augmentation operator
287 | if ',' in self.augment_op:
288 | op1, op2 = self.augment_op.split(',')
289 | table_tmp = table_ori
290 | table_ori = augment(table_tmp, op1)
291 | table_aug = augment(table_tmp, op2)
292 | else:
293 | table_aug = augment(table_ori, self.augment_op)
294 | 
295 | # convert table into string
296 | x_ori, mp_ori = self._tokenize(table_ori)
297 | x_aug, mp_aug = self._tokenize(table_aug)
298 | 
299 | # make sure that x_ori and x_aug have the same number of cls tokens
300 | # x_ori_cnt = sum([int(x == self.tokenizer.cls_token_id) for x in x_ori])
301 | # x_aug_cnt = sum([int(x == self.tokenizer.cls_token_id) for x in x_aug])
302 | # assert x_ori_cnt == x_aug_cnt
303 | 
304 | # intersect the two column mappings
305 | cls_indices = []
306 | for col in mp_ori:
307 | if col in mp_aug:
308 | cls_indices.append((mp_ori[col], mp_aug[col]))
309 | 
310 | return x_ori, x_aug, cls_indices
311 | 
312 | 
313 | def pad(self, batch):
314 | """Merge a list of dataset items into a training batch
315 | 
316 | Args:
317 | batch (list of tuple): a list of sequences
318 | 
319 | Returns:
320 | LongTensor: x_ori of shape (batch_size, seq_len)
321 | LongTensor: x_aug of shape (batch_size, seq_len)
322 | tuple of List: the cls indices
323 | """
324 | x_ori, x_aug, cls_indices = zip(*batch)
325 | max_len_ori = max([len(x) for x in x_ori])
326 | max_len_aug = max([len(x) for x in x_aug])
327 | maxlen = max(max_len_ori, max_len_aug)
328 | x_ori_new = [xi + [self.tokenizer.pad_token_id]*(maxlen - len(xi)) for xi in x_ori]
329 | x_aug_new = [xi + [self.tokenizer.pad_token_id]*(maxlen - len(xi)) for xi in x_aug]
330 | 
331 | # decompose the column alignment
332 | cls_ori = []
333 | cls_aug = []
334 | for item in cls_indices:
335 | cls_ori.append([])
336 | cls_aug.append([])
337 | 
338 | for idx1, idx2 in item:
339 | cls_ori[-1].append(idx1)
340 | cls_aug[-1].append(idx2)
341 | 
342 | return torch.LongTensor(x_ori_new), torch.LongTensor(x_aug_new), (cls_ori, cls_aug)
343 | 
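# A minimal usage sketch (hypothetical data path; the flags mirror the
# run_pretrain.py hyperparameters): `pad` is designed to be passed as the
# DataLoader's collate_fn, which is exactly how sdd/pretrain.py wires it up.
if __name__ == '__main__':
    ds = PretrainTableDataset('data/santos/datalake',
                              augment_op='drop_col',
                              sample_meth='head',
                              table_order='column')
    loader = data.DataLoader(ds, batch_size=8, shuffle=True, collate_fn=ds.pad)
    x_ori, x_aug, cls_indices = next(iter(loader))  # two padded views + CLS alignments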
--------------------------------------------------------------------------------
/sdd/model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | 
5 | from transformers import AutoModel, AutoTokenizer
6 | 
7 | lm_mp = {'roberta': 'roberta-base',
8 | 'distilbert': 'distilbert-base-uncased'}
9 | 
10 | class TableModel(nn.Module):
11 | """A baseline model for Table/Column matching"""
12 | 
13 | def __init__(self, device='cuda', lm='roberta'):
14 | super().__init__()
15 | self.bert = AutoModel.from_pretrained(lm_mp[lm])
16 | self.device = device
17 | hidden_size = 768
18 | self.fc = torch.nn.Linear(hidden_size, 2)
19 | # self.fc = torch.nn.Linear(hidden_size, 1)
20 | # self.cosine = nn.CosineSimilarity()
21 | # self.distance = nn.PairwiseDistance()
22 | 
23 | def forward(self, x):
24 | """Encode a batch of serialized left+right pairs and classify them.
25 | 
26 | Args:
27 | x (LongTensor): a batch of token ID's of the serialized left+right pairs
28 | 
29 | Returns:
30 | Tensor: binary prediction logits of shape (batch_size, 2)
31 | """
32 | x = x.to(self.device) # (batch_size, seq_len)
33 | 
34 | # a single BERT pass over the serialized pair; the [CLS] embedding
35 | # serves as the pair representation
36 | enc = self.bert(x)[0][:, 0, :] # (batch_size, emb_size)
37 | 
38 | # legacy two-tower variant, kept for reference:
39 | # enc = self.bert(torch.cat((x1, x2)))[0][:, 0, :]
40 | # enc1 = enc[:batch_size] # (batch_size, emb_size)
41 | # enc2 = enc[batch_size:] # (batch_size, emb_size)
42 | 
43 | # fully connected classification head
44 | return self.fc(enc)
45 | 
46 | 
47 | 
48 | def off_diagonal(x):
49 | """Return a flattened view of the off-diagonal elements of a square matrix.
50 | """
51 | n, m = x.shape
52 | assert n == m
53 | return x.flatten()[:-1].view(n - 1, n + 1)[:, 1:].flatten()
54 | 
55 | 
56 | class BarlowTwinsSimCLR(nn.Module):
57 | """Barlow Twins or SimCLR encoder for contrastive learning.
58 | """
59 | def __init__(self, hp, device='cuda', lm='roberta'):
60 | super().__init__()
61 | self.hp = hp
62 | self.bert = AutoModel.from_pretrained(lm_mp[lm])
63 | self.device = device
64 | hidden_size = 768
65 | 
66 | # projector
67 | self.projector = nn.Linear(hidden_size, hp.projector)
68 | 
69 | # normalization layer for the representations z1 and z2
70 | self.bn = nn.BatchNorm1d(hidden_size, affine=False)
71 | 
72 | # a fully connected layer for fine tuning
73 | self.fc = nn.Linear(hidden_size * 2, 2)
74 | 
75 | # contrastive
76 | self.criterion = nn.CrossEntropyLoss().to(device)
77 | 
78 | # cls token id
79 | self.cls_token_id = AutoTokenizer.from_pretrained(lm_mp[lm]).cls_token_id
80 | 
81 | 
82 | def info_nce_loss(self, features,
83 | batch_size,
84 | n_views,
85 | temperature=0.07):
86 | """Copied from https://github.com/sthalles/SimCLR/blob/master/simclr.py
87 | """
88 | labels = torch.cat([torch.arange(batch_size) for i in range(n_views)], dim=0)
89 | labels = (labels.unsqueeze(0) == labels.unsqueeze(1)).float()
90 | labels = labels.to(self.device)
91 | 
92 | features = F.normalize(features, dim=1)
93 | 
94 | similarity_matrix = torch.matmul(features, features.T)
95 | 
96 | # discard the main diagonal from both: labels and similarities matrix
97 | mask = torch.eye(labels.shape[0], dtype=torch.bool).to(self.device)
98 | labels = labels[~mask].view(labels.shape[0], -1)
99 | similarity_matrix = similarity_matrix[~mask].view(similarity_matrix.shape[0], -1)
100 | # assert similarity_matrix.shape == labels.shape
101 | 
102 | # select and combine multiple positives
103 | positives = similarity_matrix[labels.bool()].view(labels.shape[0], -1)
104 | 
105 | # select only the negatives
106 | negatives = similarity_matrix[~labels.bool()].view(similarity_matrix.shape[0], -1)
107 | 
108 | logits = torch.cat([positives, negatives], dim=1)
109 | labels = torch.zeros(logits.shape[0], dtype=torch.long).to(self.device)
110 | 
111 | logits = logits / temperature
112 | return logits, labels
113 | 
114 | def _extract_columns(self, x, z, cls_indices=None):
115 | """Helper function for extracting column vectors from LM outputs.
116 | """
117 | x_flat = x.view(-1)
118 | column_vectors = z.view((x_flat.shape[0], -1))
119 | 
120 | if cls_indices is None:
121 | indices = [idx for idx, token_id in enumerate(x_flat) \
122 | if token_id == self.cls_token_id]
123 | else:
124 | indices = []
125 | seq_len = x.shape[-1]
126 | for rid in range(len(cls_indices)):
127 | indices += [idx + rid * seq_len for idx in cls_indices[rid]]
128 | 
129 | return column_vectors[indices]
130 | 
131 | 
132 | def inference(self, x):
133 | """Apply the model on a serialized table.
134 | 
135 | Args:
136 | x (LongTensor): a batch of serialized tables
137 | 
138 | Returns:
139 | Tensor: the column vectors for all tables
140 | """
141 | x = x.to(self.device)
142 | z = self.bert(x)[0]
143 | z = self.projector(z) # optional projection to the output dimension
144 | return self._extract_columns(x, z)
145 | 
146 | 
147 | def forward(self, x_ori, x_aug, cls_indices, mode="simclr"):
148 | """Apply the model for contrastive learning.
149 | 
150 | Args:
151 | x_ori (LongTensor): the first views of a batch of tables
152 | x_aug (LongTensor): the second views of a batch of tables
153 | cls_indices (tuple of List): the cls_token alignment
154 | mode (str, optional): the name of the contrastive learning algorithm
155 | 
156 | Returns:
157 | Tensor: the loss
158 | """
159 | if mode in ["simclr", "barlow_twins"]:
160 | # pre-training
161 | # encode
162 | batch_size = len(x_ori)
163 | x_ori = x_ori.to(self.device) # original, (batch_size, seq_len)
164 | x_aug = x_aug.to(self.device) # augment, (batch_size, seq_len)
165 | 
166 | # encode each table (all columns)
167 | x = torch.cat((x_ori, x_aug)) # (2*batch_size, seq_len)
168 | z = self.bert(x)[0] # (2*batch_size, seq_len, hidden_size)
169 | 
170 | # split the output back into the original and the augmented views
171 | z_ori = z[:batch_size] # (batch_size, seq_len, hidden_size)
172 | z_aug = z[batch_size:] # (batch_size, seq_len, hidden_size)
173 | 
174 | cls_ori, cls_aug = cls_indices
175 | 
176 | z_ori = self._extract_columns(x_ori, z_ori, cls_ori) # (total_num_columns, hidden_size)
177 | z_aug = self._extract_columns(x_aug, z_aug, cls_aug) # (total_num_columns, hidden_size)
178 | assert z_ori.shape == z_aug.shape
179 | 
180 | z = torch.cat((z_ori, z_aug))
181 | z = self.projector(z) # (2*total_num_columns, projector_size)
182 | 
183 | if mode == "simclr":
184 | # simclr
185 | logits, labels = self.info_nce_loss(z, len(z) // 2, 2)
186 | loss = self.criterion(logits, labels)
187 | return loss
188 | elif mode == "barlow_twins":
189 | # barlow twins
190 | z1 = z[:len(z) // 2]
191 | z2 = z[len(z) // 2:]
192 | 
193 | # empirical cross-correlation matrix
194 | # NOTE: self.bn is BatchNorm1d(hidden_size), so this mode assumes
195 | # the projector dimension equals hidden_size (768 by default)
196 | c = (self.bn(z1).T @ self.bn(z2)) / (len(z1))
197 | 
198 | # use --scale-loss to multiply the loss by a constant factor
199 | on_diag = ((torch.diagonal(c) - 1) ** 2).sum() * self.hp.scale_loss
200 | off_diag = (off_diagonal(c) ** 2).sum() * self.hp.scale_loss
201 | loss = on_diag + self.hp.lambd * off_diag
202 | return loss
203 | elif mode == "finetune":
204 | pass
205 | # TODO
206 | # x1 = x1.to(self.device) # (batch_size, seq_len)
207 | # x2 = x2.to(self.device) # (batch_size, seq_len)
208 | # x12 = x12.to(self.device) # (batch_size, seq_len)
209 | # # left+right
210 | # enc_pair = self.projector(self.bert(x12)[0][:, 0, :]) # (batch_size, emb_size)
211 | # batch_size = len(x1)
212 | 
213 | # # left and right
214 | # enc = self.projector(self.bert(torch.cat((x1, x2)))[0][:, 0, :])
215 | # #enc = self.bert(torch.cat((x1, x2)))[0][:, 0, :]
216 | # enc1 = enc[:batch_size] # (batch_size, emb_size)
217 | # enc2 = enc[batch_size:] # (batch_size, emb_size)
218 | 
219 | # return self.fc(torch.cat((enc_pair, (enc1 - enc2).abs()), dim=1))
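# A hedged smoke-test sketch (illustrative values; the Namespace fields mirror
# the run_pretrain.py flags): one SimCLR step over two fake tables, each with a
# single aligned column whose CLS token sits at position 0.
if __name__ == '__main__':
    from argparse import Namespace
    hp = Namespace(projector=768, scale_loss=0.1, lambd=3.9)  # illustrative
    model = BarlowTwinsSimCLR(hp, device='cpu')
    x_ori = torch.randint(0, 50000, (2, 32))  # fake token ID's
    x_aug = torch.randint(0, 50000, (2, 32))
    loss = model(x_ori, x_aug, ([[0], [0]], [[0], [0]]), mode='simclr')
    loss.backward()
    print(loss.item())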
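# Worked example of the constant-sampling stride above (hypothetical input):
#   constantSample(['a','b','c','d','e','f','g','h'], max_tokens=3)
#   step = ceil(8 / 3) = 3  ->  tokens = ['a', 'd', 'g']
# The while-loop keeps widening the stride until at most max_tokens remain.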
94 | def frequentSample(colVals, max_tokens):
95 | '''Frequent sampling: take the most frequently occurring tokens
96 | For sampling method 'frequent'
97 | Args:
98 | colVals: list of tokens in each entity in the column
99 | max_tokens: maximum tokens specified in pretrain argument
100 | Return list of tokens
101 | '''
102 | tokens, tokenFreq = [], {}
103 | tokenFreq = collections.Counter(colVals)
104 | tokenFreq = {k: v for k, v in sorted(tokenFreq.items(), key=lambda item: item[1], reverse=True)[:max_tokens]}
105 | for t in colVals:
106 | if t in tokenFreq and t not in tokens:
107 | tokens.append(t)
108 | return tokens
109 | 
110 | def tfidfSample(column, tfidfDict, method, max_tokens):
111 | '''TFIDF sampling: take the tokens with the highest idf scores
112 | For sampling methods 'tfidf_token', 'tfidf_entity'
113 | Args:
114 | column (pandas Series): current column from input table DataFrame
115 | tfidfDict (dict): dict with tfidf scores for each column, created in _tokenize()
116 | method (str): sampling method ('tfidf_token', 'tfidf_entity')
117 | max_tokens: maximum tokens specified in pretrain argument
118 | Return list of tokens
119 | '''
120 | tokens, tokenList, tokenFreq = [], [], {}
121 | if method == "tfidf_token":
122 | # token level
123 | for colVal in column.unique():
124 | for val in str(colVal).split(' '):
125 | idf = tfidfDict[val]
126 | tokenFreq[val] = idf
127 | tokenList.append(val)
128 | tokenFreq = {k: v for k, v in sorted(tokenFreq.items(), key=lambda item: item[1], reverse=True)[:max_tokens]}
129 | for t in tokenList:
130 | if t in tokenFreq and t not in tokens:
131 | tokens.append(t)
132 | 
133 | elif method == "tfidf_entity":
134 | # entity level
135 | for colVal in column.unique():
136 | valIdfs = []
137 | for val in str(colVal).split(' '):
138 | valIdfs.append(tfidfDict[val])
139 | idf = sum(valIdfs)/len(valIdfs)
140 | tokenFreq[colVal] = idf
141 | tokenList.append(colVal)
142 | tokenFreq = {k: v for k, v in sorted(tokenFreq.items(), key=lambda item: item[1], reverse=True)}
143 | valCount, N = 0, 0
144 | for entity in tokenFreq:
145 | valCount += len(str(entity).split(' '))
146 | if valCount < max_tokens: N += 1
147 | tokenFreq = {k: tokenFreq[k] for k in list(tokenFreq)[:N]}
148 | for t in tokenList:
149 | if t in tokenFreq and t not in tokens:
150 | tokens += str(t).split(' ')
151 | return tokens
152 | 
153 | 
154 | def tfidfRowSample(table, tfidfDict, max_tokens):
155 | '''TFIDF sampling: take the rows whose tokens have the highest idf scores
156 | For sampling method 'tfidf_row'
157 | Called in _tokenize() method in dataset.py
158 | Args:
159 | table (DataFrame): input table
160 | tfidfDict (dict): dict with tfidf scores for each column, created in _tokenize()
161 | max_tokens: maximum tokens specified in pretrain argument
162 | Return table with sampled rows using tfidf
163 | '''
164 | tokenFreq = {}
165 | sortedRowInds = []
166 | for row in table.itertuples():
167 | index = row.Index
168 | valIdfs = []
169 | rowVals = [val for entity in list(row[1:]) for val in str(entity).split(' ')]
170 | for val in rowVals:
171 | valIdfs.append(tfidfDict[val])
172 | idf = sum(valIdfs)/len(valIdfs)
173 | tokenFreq[index] = idf
174 | tokenFreq = {k: v for k, v in sorted(tokenFreq.items(), key=lambda item: item[1], reverse=True)}
175 | sortedRowInds = list(tokenFreq.keys())[:max_tokens]
176 | table = table.reindex(sortedRowInds)
177 | return table
178 | 
179 | def preprocess(column: pd.Series, tfidfDict: dict, max_tokens: int, method: str):
180 | '''Preprocess a column into a list of at most max_tokens 
tokens 181 | Possible methods = "head", "alphaHead", "random", "constant", "frequent", "tfidf_token", "tfidf_entity", "tfidf_row" 182 | Args: 183 | column (pandas Series): current column from input table DataFrame 184 | tfidfDict (dict): dict with tfidf scores for each column, created in _tokenize() 185 | max_tokens: maximum tokens specified in pretrain argument 186 | method (str): sampling method from list of possible methods 187 | Returns list of sampled tokens 188 | ''' 189 | tokens = [] 190 | colVals = [val for entity in column for val in str(entity).split(' ')] 191 | if method == "head" or method == "tfidf_row": 192 | for val in colVals: 193 | if val not in tokens: 194 | tokens.append(val) 195 | if len(tokens) >= max_tokens: 196 | break 197 | elif method == "alphaHead": 198 | if 'mixed' in infer_dtype(column): 199 | column = column.astype(str) 200 | sortedCol = column.sort_values() 201 | sortedColVals = [str(val).lower() for entity in sortedCol for val in str(entity).split(' ')] 202 | for val in sortedColVals: 203 | if val not in tokens: 204 | tokens.append(val) 205 | if len(tokens) >= max_tokens: 206 | break 207 | elif method == "random": 208 | tokens = pd.Series(colVals).sample(min(len(colVals),max_tokens)).sort_index().tolist() 209 | elif method == "constant": 210 | tokens = constantSample(colVals, max_tokens) 211 | elif method == "frequent": 212 | tokens = frequentSample(colVals, max_tokens) 213 | elif "tfidf" in method and method != "tfidf_row": 214 | tokens = tfidfSample(column, tfidfDict, method, max_tokens) 215 | return tokens -------------------------------------------------------------------------------- /sdd/pretrain.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torch.optim as optim 5 | import numpy as np 6 | import sklearn.metrics as metrics 7 | import mlflow 8 | import pandas as pd 9 | import os 10 | 11 | from .utils import evaluate_column_matching, evaluate_clustering 12 | from .model import BarlowTwinsSimCLR 13 | from .dataset import PretrainTableDataset 14 | 15 | from tqdm import tqdm 16 | from torch.utils import data 17 | from transformers import AdamW, get_linear_schedule_with_warmup 18 | from typing import List 19 | 20 | 21 | def train_step(train_iter, model, optimizer, scheduler, scaler, hp): 22 | """Perform a single training step 23 | 24 | Args: 25 | train_iter (Iterator): the train data loader 26 | model (BarlowTwinsSimCLR): the model 27 | optimizer (Optimizer): the optimizer (Adam or AdamW) 28 | scheduler (LRScheduler): learning rate scheduler 29 | scaler (GradScaler): gradient scaler for fp16 training 30 | hp (Namespace): other hyper-parameters (e.g., fp16) 31 | 32 | Returns: 33 | None 34 | """ 35 | for i, batch in enumerate(train_iter): 36 | x_ori, x_aug, cls_indices = batch 37 | optimizer.zero_grad() 38 | 39 | if hp.fp16: 40 | with torch.cuda.amp.autocast(): 41 | loss = model(x_ori, x_aug, cls_indices, mode='simclr') 42 | scaler.scale(loss).backward() 43 | scaler.step(optimizer) 44 | scaler.update() 45 | else: 46 | loss = model(x_ori, x_aug, cls_indices, mode='simclr') 47 | loss.backward() 48 | optimizer.step() 49 | 50 | scheduler.step() 51 | if i % 10 == 0: # monitoring 52 | print(f"step: {i}, loss: {loss.item()}") 53 | del loss 54 | 55 | 56 | def train(trainset, hp): 57 | """Train and evaluate the model 58 | 59 | Args: 60 | trainset (PretrainTableDataset): the training set 61 | hp (Namespace): Hyper-parameters (e.g., batch_size, 
62 | learning rate, fp16)
63 | Returns:
64 | The pre-trained table model
65 | """
66 | padder = trainset.pad
67 | # create the DataLoaders
68 | train_iter = data.DataLoader(dataset=trainset,
69 | batch_size=hp.batch_size,
70 | shuffle=True,
71 | num_workers=0,
72 | collate_fn=padder)
73 | 
74 | # initialize model, optimizer, and LR scheduler
75 | device = 'cuda' if torch.cuda.is_available() else 'cpu'
76 | model = BarlowTwinsSimCLR(hp, device=device, lm=hp.lm)
77 | model = model.to(device)
78 | optimizer = AdamW(model.parameters(), lr=hp.lr)
79 | if hp.fp16:
80 | scaler = torch.cuda.amp.GradScaler()
81 | else:
82 | scaler = None
83 | 
84 | num_steps = (len(trainset) // hp.batch_size) * hp.n_epochs
85 | scheduler = get_linear_schedule_with_warmup(optimizer,
86 | num_warmup_steps=0,
87 | num_training_steps=num_steps)
88 | 
89 | for epoch in range(1, hp.n_epochs+1):
90 | # train
91 | model.train()
92 | train_step(train_iter, model, optimizer, scheduler, scaler, hp)
93 | 
94 | # save the last checkpoint
95 | if hp.save_model and epoch == hp.n_epochs:
96 | directory = os.path.join(hp.logdir, hp.task)
97 | if not os.path.exists(directory):
98 | os.makedirs(directory)
99 | 
100 | # save the checkpoints for each component
101 | if hp.single_column:
102 | ckpt_path = os.path.join(hp.logdir, hp.task, 'model_'+str(hp.augment_op)+'_'+str(hp.sample_meth)+'_'+str(hp.table_order)+'_'+str(hp.run_id)+'singleCol.pt')
103 | else:
104 | ckpt_path = os.path.join(hp.logdir, hp.task, 'model_'+str(hp.augment_op)+'_'+str(hp.sample_meth)+'_'+str(hp.table_order)+'_'+str(hp.run_id)+'.pt')
105 | 
106 | ckpt = {'model': model.state_dict(),
107 | 'hp': hp}
108 | torch.save(ckpt, ckpt_path)
109 | 
110 | # test loading checkpoints
111 | # load_checkpoint(ckpt_path)
112 | # intrinsic evaluation with column matching
113 | if hp.task in ['small', 'large']:
114 | # Train column matching models using the learned representations
115 | metrics_dict = evaluate_pretrain(model, trainset)
116 | 
117 | # log metrics
118 | mlflow.log_metrics(metrics_dict)
119 | 
120 | print("epoch %d: " % epoch + ", ".join(["%s=%f" % (k, v) \
121 | for k, v in metrics_dict.items()]))
122 | 
123 | # evaluate on column clustering
124 | if hp.task in ['viznet']:
125 | # Cluster the columns using the learned representations
126 | metrics_dict = evaluate_column_clustering(model, trainset)
127 | 
128 | # log metrics
129 | mlflow.log_metrics(metrics_dict)
130 | print("epoch %d: " % epoch + ", ".join(["%s=%f" % (k, v) \
131 | for k, v in metrics_dict.items()]))
132 | 
133 | 
134 | 
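# Typical offline flow (hypothetical checkpoint path): pre-train with train()
# above, then reload the saved checkpoint and embed a list of DataFrames with
# inference_on_tables() defined below:
#   ckpt = torch.load("results/santos/model_drop_col_head_column_0.pt",
#                     map_location="cpu")
#   model, dataset = load_checkpoint(ckpt)
#   column_vectors = inference_on_tables(tables, model, dataset, batch_size=128)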
135 | def inference_on_tables(tables: List[pd.DataFrame],
136 | model: BarlowTwinsSimCLR,
137 | unlabeled: PretrainTableDataset,
138 | batch_size=128,
139 | total=None):
140 | """Extract column vectors from a list of tables.
141 | 
142 | Args:
143 | tables (List of DataFrame): the list of tables
144 | model (BarlowTwinsSimCLR): the model to be evaluated
145 | unlabeled (PretrainTableDataset): the unlabeled dataset
146 | batch_size (optional): batch size for model inference
147 | total (optional): the total number of tables, for the progress bar
148 | Returns:
149 | List of np.array: the column vectors
150 | """
151 | total = total if total is not None else len(tables)
152 | batch = []
153 | results = []
154 | for tid, table in tqdm(enumerate(tables), total=total):
155 | x, _ = unlabeled._tokenize(table)
156 | 
157 | batch.append((x, x, []))
158 | if tid == total - 1 or len(batch) == batch_size:
159 | # model inference
160 | with torch.no_grad():
161 | x, _, _ = unlabeled.pad(batch)
162 | # all column vectors in the batch
163 | column_vectors = model.inference(x)
164 | ptr = 0
165 | for xi in x:
166 | current = []
167 | for token_id in xi:
168 | if token_id == unlabeled.tokenizer.cls_token_id:
169 | current.append(column_vectors[ptr].cpu().numpy())
170 | ptr += 1
171 | results.append(current)
172 | 
173 | batch.clear()
174 | 
175 | return results
176 | 
177 | 
178 | def load_checkpoint(ckpt):
179 | """Load a model from a checkpoint.
180 | ** If you would like to run your own benchmark, update the ds_path here
181 | Args:
182 | ckpt (dict): the loaded checkpoint (as saved by train(), with 'model' and 'hp' entries)
183 | 
184 | Returns:
185 | BarlowTwinsSimCLR: the pre-trained model
186 | PretrainTableDataset: the dataset for pre-training the model
187 | """
188 | hp = ckpt['hp']
189 | 
190 | device = 'cuda' if torch.cuda.is_available() else 'cpu'
191 | print(device)
192 | model = BarlowTwinsSimCLR(hp, device=device, lm=hp.lm)
193 | model = model.to(device)
194 | model.load_state_dict(ckpt['model'])
195 | 
196 | # dataset paths, depending on benchmark for the current task
197 | ds_path = 'data/santos/datalake'
198 | if hp.task == "santosLarge":
199 | # Change the data paths to where the benchmarks are stored
200 | ds_path = 'data/santos-benchmark/real-benchmark/datalake'
201 | elif hp.task == "tus":
202 | ds_path = 'data/table-union-search-benchmark/small/benchmark'
203 | elif hp.task == "tusLarge":
204 | ds_path = 'data/table-union-search-benchmark/large/benchmark'
205 | elif hp.task == "wdc":
206 | ds_path = 'data/wdc/0'
207 | dataset = PretrainTableDataset.from_hp(ds_path, hp)
208 | 
209 | return model, dataset
210 | 
211 | 
212 | def evaluate_pretrain(model: BarlowTwinsSimCLR,
213 | unlabeled: PretrainTableDataset):
214 | """Evaluate a pre-trained model. 
215 | 216 | Args: 217 | model (BarlowTwinsSimCLR): the model to be evaluated 218 | unlabeled (PretrainTableDataset): the unlabeled dataset 219 | 220 | Returns: 221 | Dict: the dictionary of metrics (e.g., valid_f1) 222 | """ 223 | table_path = 'data/%s/tables' % model.hp.task 224 | 225 | # encode each dataset 226 | featurized_datasets = [] 227 | for dataset in ["train", "valid", "test"]: 228 | ds_path = 'data/%s/%s.csv' % (model.hp.task, dataset) 229 | ds = pd.read_csv(ds_path) 230 | 231 | def encode_tables(table_ids, column_ids): 232 | tables = [] 233 | for table_id, col_id in zip(table_ids, column_ids): 234 | table = pd.read_csv(os.path.join(table_path, \ 235 | "table_%d.csv" % table_id)) 236 | if model.hp.single_column: 237 | table = table[[table.columns[col_id]]] 238 | tables.append(table) 239 | vectors = inference_on_tables(tables, model, unlabeled, 240 | batch_size=128) 241 | 242 | # assert all columns exist 243 | for vec, table in zip(vectors, tables): 244 | assert len(vec) == len(table.columns) 245 | 246 | res = [] 247 | for vec, cid in zip(vectors, column_ids): 248 | if cid < len(vec): 249 | res.append(vec[cid]) 250 | else: 251 | # single column 252 | res.append(vec[-1]) 253 | return res 254 | 255 | # left tables 256 | l_features = encode_tables(ds['l_table_id'], ds['l_column_id']) 257 | 258 | # right tables 259 | r_features = encode_tables(ds['r_table_id'], ds['r_column_id']) 260 | 261 | features = [] 262 | Y = ds['match'] 263 | for l, r in zip(l_features, r_features): 264 | feat = np.concatenate((l, r, np.abs(l - r))) 265 | features.append(feat) 266 | 267 | featurized_datasets.append((features, Y)) 268 | 269 | train, valid, test = featurized_datasets 270 | return evaluate_column_matching(train, valid, test) 271 | 272 | 273 | def evaluate_column_clustering(model: BarlowTwinsSimCLR, 274 | unlabeled: PretrainTableDataset): 275 | """Evaluate pre-trained model on a column clustering dataset. 
276 | 
277 | Args:
278 | model (BarlowTwinsSimCLR): the model to be evaluated
279 | unlabeled (PretrainTableDataset): the unlabeled dataset
280 | 
281 | Returns:
282 | Dict: the dictionary of metrics (e.g., purity, number of clusters)
283 | """
284 | table_path = 'data/%s/tables' % model.hp.task
285 | 
286 | # encode the test split
287 | ds_path = 'data/%s/test.csv' % model.hp.task
288 | ds = pd.read_csv(ds_path)
289 | table_ids, column_ids = ds['table_id'], ds['column_id']
290 | 
291 | # encode all tables
292 | def table_iter():
293 | for table_id, col_id in zip(table_ids, column_ids):
294 | table = pd.read_csv(os.path.join(table_path, \
295 | "table_%d.csv" % table_id))
296 | if model.hp.single_column:
297 | table = table[[table.columns[col_id]]]
298 | yield table
299 | 
300 | vectors = inference_on_tables(table_iter(), model, unlabeled,
301 | batch_size=128, total=len(table_ids))
302 | 
303 | # # assert all columns exist
304 | # for vec, table in zip(vectors, tables):
305 | # assert len(vec) == len(table.columns)
306 | 
307 | column_vectors = []
308 | for vec, cid in zip(vectors, column_ids):
309 | if cid < len(vec):
310 | column_vectors.append(vec[cid])
311 | else:
312 | # single column
313 | column_vectors.append(vec[-1])
314 | 
315 | return evaluate_clustering(column_vectors, ds['class'])
316 | 
--------------------------------------------------------------------------------
/sdd/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import sklearn.metrics as metrics
4 | 
5 | from sklearn.linear_model import LogisticRegression
6 | from sklearn.svm import LinearSVC
7 | from sklearn.ensemble import GradientBoostingClassifier
8 | from sklearn.ensemble import RandomForestClassifier
9 | from sklearn.pipeline import make_pipeline
10 | from sklearn.preprocessing import StandardScaler
11 | from xgboost import XGBClassifier
12 | from tqdm import tqdm
13 | from collections import deque, Counter
14 | 
15 | 
16 | def evaluate(model, iterator, threshold=None):
17 | """Evaluate a model on a validation/test dataset
18 | 
19 | Args:
20 | model (TableModel): the EM model
21 | iterator (Iterator): the valid/test dataset iterator
22 | threshold (float, optional): the threshold on the 0-class
23 | 
24 | Returns:
25 | float: the F1 score
26 | float (optional): if threshold is not provided, the threshold
27 | value that gives the optimal F1
28 | """
29 | all_p = []
30 | all_y = []
31 | all_probs = []
32 | with torch.no_grad():
33 | for batch in iterator:
34 | if len(batch) == 4:
35 | x1, x2, x12, y = batch
36 | logits = model(x1, x2, x12)
37 | else:
38 | x, y = batch
39 | logits = model(x)
40 | 
41 | # print(probs)
42 | probs = logits.softmax(dim=1)[:, 1]
43 | 
44 | # print(logits)
45 | # pred = logits.argmax(dim=1)
46 | all_probs += probs.cpu().numpy().tolist()
47 | # all_p += pred.cpu().numpy().tolist()
48 | all_y += y.cpu().numpy().tolist()
49 | 
50 | if threshold is not None:
51 | pred = [1 if p > threshold else 0 for p in all_probs]
52 | f1 = metrics.f1_score(all_y, pred)
53 | return f1
54 | else:
55 | best_th = 0.5
56 | f1 = 0.0 # metrics.f1_score(all_y, all_p)
57 | 
58 | for th in np.arange(0.0, 1.0, 0.05):
59 | pred = [1 if p > th else 0 for p in all_probs]
60 | new_f1 = metrics.f1_score(all_y, pred)
61 | if new_f1 > f1:
62 | f1 = new_f1
63 | best_th = th
64 | 
65 | return f1, best_th
66 | 
67 | 
68 | 
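# Illustrative input format for evaluate_column_matching below: each split is a
# pair (X, y) where every row of X concatenates two column vectors and their
# absolute difference, [l, r, |l - r|] (3 * 768 = 2304 dims with the default
# RoBERTa encoder), exactly as assembled in sdd/pretrain.py:
#   train = (np.random.rand(100, 2304), np.random.randint(0, 2, 100))
#   valid = (np.random.rand(20, 2304), np.random.randint(0, 2, 20))
#   test = (np.random.rand(20, 2304), np.random.randint(0, 2, 20))
#   evaluate_column_matching(train, valid, test)  # {'valid_f1': ..., 'test_r': ...}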
69 | def evaluate_column_matching(train, valid, test):
70 | """Run classification algorithms on feature vectors.
71 | """
72 | # datasets = pickle.load(open(feature_path, "rb"))
73 | # train, valid, test = datasets
74 | 
75 | ml_models = {
76 | "LR": LogisticRegression,
77 | "SVM": LinearSVC,
78 | "GB": XGBClassifier, # GradientBoostingClassifier,
79 | "RF": RandomForestClassifier
80 | }
81 | 
82 | mname = "GB"
83 | Model = ml_models[mname]
84 | 
85 | # standardization
86 | pipe = make_pipeline(StandardScaler(), Model())
87 | 
88 | # training
89 | pipe.fit(np.nan_to_num(train[0]), train[1])
90 | 
91 | # eval
92 | results = {}
93 | for ds, ds_name in zip([valid, test], ['valid', 'test']):
94 | X, y = ds
95 | y_pred = pipe.predict(np.nan_to_num(X))
96 | f1 = metrics.f1_score(y, y_pred)
97 | p = metrics.precision_score(y, y_pred)
98 | r = metrics.recall_score(y, y_pred)
99 | 
100 | for name, val in [("f1", f1), ("p", p), ("r", r)]:
101 | results[ds_name + "_" + name] = val
102 | 
103 | return results
104 | 
105 | 
106 | def blocked_matmul(mata, matb,
107 | threshold=None,
108 | k=None,
109 | batch_size=512):
110 | """Find the most similar pairs of vectors from two matrices (top-k or threshold)
111 | 
112 | Args:
113 | mata (np.ndarray): the first matrix
114 | matb (np.ndarray): the second matrix
115 | threshold (float, optional): if set, return all pairs of cosine
116 | similarity above the threshold
117 | k (int, optional): if set, return for each row in matb the top-k
118 | most similar vectors in mata
119 | batch_size (int, optional): the batch size of each block
120 | 
121 | Returns:
122 | list of tuples: the pairs of similar vectors' indices and the similarity
123 | """
124 | mata = np.array(mata)
125 | matb = np.array(matb)
126 | results = []
127 | for start in tqdm(range(0, len(matb), batch_size)):
128 | block = matb[start:start+batch_size]
129 | sim_mat = np.matmul(mata, block.transpose())
130 | if k is not None:
131 | indices = np.argpartition(-sim_mat, k, axis=0)
132 | for row in indices[:k]:
133 | for idx_b, idx_a in enumerate(row):
134 | idx_b += start
135 | results.append((idx_a, idx_b, sim_mat[idx_a][idx_b-start]))
136 | elif threshold is not None:
137 | indices = np.argwhere(sim_mat >= threshold)
138 | for idx_a, idx_b in indices:
139 | idx_b += start
140 | results.append((idx_a, idx_b, sim_mat[idx_a][idx_b-start]))
141 | return results
142 | 
143 | 
144 | def connected_components(pairs, cluster_size=50):
145 | """Helper function for computing the connected components
146 | """
147 | edges = {}
148 | for left, right, _ in pairs:
149 | if left not in edges:
150 | edges[left] = []
151 | if right not in edges:
152 | edges[right] = []
153 | 
154 | edges[left].append(right)
155 | edges[right].append(left)
156 | 
157 | print('num nodes =', len(edges))
158 | all_ccs = []
159 | used = set([])
160 | for start in edges:
161 | if start in used:
162 | continue
163 | used.add(start)
164 | cc = [start]
165 | 
166 | queue = deque([start])
167 | while len(queue) > 0:
168 | u = queue.popleft()
169 | for v in edges[u]:
170 | if v not in used:
171 | cc.append(v)
172 | used.add(v)
173 | queue.append(v)
174 | 
175 | if len(cc) >= cluster_size:
176 | break
177 | 
178 | all_ccs.append(cc)
179 | # print(cc)
180 | return all_ccs
181 | 
182 | 
183 | def evaluate_clustering(vectors, labels):
184 | """Evaluate column clustering on input column vectors. 
185 | """ 186 | # top 20 matching columns 187 | pairs = blocked_matmul(vectors, vectors, 188 | k=20, 189 | batch_size=4096) 190 | 191 | # run column clustering algorithm 192 | ccs = connected_components(pairs) 193 | 194 | # compute purity 195 | purity = [] 196 | for cc in ccs: 197 | cnt = Counter() 198 | for column_id in cc: 199 | label = labels[column_id] 200 | cnt[label] += 1 201 | purity.append(cnt.most_common(1)[0][1] / len(cc)) 202 | purity = np.mean(purity) 203 | 204 | return {"num_clusters": len(ccs), 205 | "avg_cluster_size": np.mean([len(cc) for cc in ccs]), 206 | "purity": purity} 207 | -------------------------------------------------------------------------------- /starmie_overall.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/megagonlabs/starmie/5eb90fe27fb1162d2a62b555ac54908ee8e4c474/starmie_overall.jpg -------------------------------------------------------------------------------- /test_bounds.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | 3 | from bounds import verify, upper_bound_bm, lower_bound_bm 4 | 5 | 6 | tables = pickle.load(open("sherlock_datalake.pkl","rb")) 7 | queries = pickle.load(open("sherlock_toyQuery.pkl","rb")) 8 | 9 | query = queries[4] 10 | threshold = 0.6 11 | for table in tables: 12 | lb = lower_bound_bm(table[1], query[1], threshold) 13 | ub = upper_bound_bm(table[1], query[1], threshold) 14 | true = verify(table[1], query[1], threshold) 15 | print("lower bound: ", lb,"upper bound: ", ub, "true value: ", true) -------------------------------------------------------------------------------- /test_hnsw_search.py: -------------------------------------------------------------------------------- 1 | 2 | import pickle 3 | import mlflow 4 | import argparse 5 | import time 6 | import numpy as np 7 | from hnsw_search import HNSWSearcher 8 | from checkPrecisionRecall import saveDictionaryAsPickleFile, calcMetrics 9 | 10 | if __name__ == '__main__': 11 | parser = argparse.ArgumentParser() 12 | parser.add_argument("--encoder", type=str, default="sato", choices=['sherlock', 'sato', 'cl', 'tapex']) 13 | parser.add_argument("--benchmark", type=str, default='santos') 14 | parser.add_argument("--run_id", type=int, default=0) 15 | parser.add_argument("--single_column", dest="single_column", action="store_true") 16 | parser.add_argument("--K", type=int, default=10) 17 | parser.add_argument("--scal", type=float, default=1.00) 18 | # parser.add_argument("--N", type=int, default=10) 19 | # parser.add_argument("--threshold", type=float, default=0.7) 20 | # mlflow tag 21 | parser.add_argument("--mlflow_tag", type=str, default=None) 22 | 23 | hp = parser.parse_args() 24 | 25 | # mlflow logging 26 | for variable in ["encoder", "benchmark", "single_column", "run_id", "K", "scal"]: 27 | mlflow.log_param(variable, getattr(hp, variable)) 28 | 29 | if hp.mlflow_tag: 30 | mlflow.set_tag("tag", hp.mlflow_tag) 31 | 32 | encoder = hp.encoder 33 | singleCol = hp.single_column 34 | 35 | dataFolder = hp.benchmark 36 | # Set augmentation operators, sampling methods, K, and threshold values according to the benchmark 37 | if 'santos' in dataFolder or dataFolder == 'wdc': 38 | sampAug = "drop_col_tfidf_entity" 39 | K = 10 40 | threshold = 0.7 41 | if dataFolder == 'santosLarge' or dataFolder == 'wdc': 42 | K, threshold = hp.K, 0.1 43 | elif "tus" in dataFolder: 44 | sampAug = "drop_cell_alphaHead" 45 | K = 60 46 | threshold = 0.1 47 | singSampAug = 
"drop_cell_tfidf_entity" 48 | 49 | # If we need to change the value of N, or change the filepath to the pkl files (including indexing), change here: 50 | # N: number of returned elements for each query column 51 | if encoder in ['sherlock', 'sato']: 52 | N = 50 53 | query_path = "data/"+dataFolder+"/"+encoder+"_query.pkl" 54 | table_path = "data/"+dataFolder+"/"+encoder+"_datalake.pkl" 55 | index_path = "data/"+dataFolder+"/indexes/hnsw_"+encoder+".bin" 56 | else: 57 | N = 4 58 | table_id = hp.run_id 59 | table_path = "data/"+dataFolder+"/vectors/cl_datalake_"+sampAug+"_column_"+str(table_id)+".pkl" 60 | query_path = "data/"+dataFolder+"/vectors/cl_query_"+sampAug+"_column_"+str(table_id)+".pkl" 61 | index_path = "data/"+dataFolder+"/indexes/hnsw_open_data_"+str(table_id)+"_"+str(hp.scal)+".bin" 62 | if singleCol: 63 | N = 50 64 | table_path = "data/"+dataFolder+"/vectors/cl_datalake_"+singSampAug+"_column_"+str(table_id)+"_singleCol.pkl" 65 | query_path = "data/"+dataFolder+"/vectors/cl_query_"+singSampAug+"_column_"+str(table_id)+"_singleCol.pkl" 66 | index_path = "data/"+dataFolder+"/indexes/hnsw_open_data_"+str(table_id)+"_singleCol.bin" 67 | 68 | # Call HNSWSearcher from hnsw_search.py 69 | searcher = HNSWSearcher(table_path, index_path, hp.scal) 70 | queries = pickle.load(open(query_path,"rb")) 71 | 72 | start_time = time.time() 73 | returnedResults = {} 74 | avgNumResults = [] 75 | query_times = [] 76 | 77 | for q in queries: 78 | query_start_time = time.time() 79 | res, scoreLength = searcher.topk(encoder,q,K, N=N,threshold=threshold) #N=10, 80 | returnedResults[q[0]] = [r[1] for r in res] 81 | avgNumResults.append(scoreLength) 82 | query_times.append(time.time() - query_start_time) 83 | 84 | print("Average number of Results: ", sum(avgNumResults)/len(avgNumResults)) 85 | print("Average QUERY TIME: %s seconds " % (sum(query_times)/len(query_times))) 86 | print("10th percentile: ", np.percentile(query_times, 10), " 90th percentile: ", np.percentile(query_times, 90)) 87 | print("--- Total Query Time: %s seconds ---" % (time.time() - start_time)) 88 | 89 | # santosLarge and WDC benchmarks are used for efficiency 90 | if hp.benchmark == 'santosLarge' or hp.benchmark == 'wdc': 91 | print("No groundtruth for %s benchmark" % (hp.benchmark)) 92 | else: 93 | # Calculating effectiveness scores (Change the paths to where the ground truths are stored) 94 | if 'santos' in hp.benchmark: 95 | k_range = 1 96 | groundTruth = "data/santos/santosUnionBenchmark.pickle" 97 | else: 98 | k_range = 10 99 | if hp.benchmark == 'tus': 100 | groundTruth = 'data/table-union-search-benchmark/small/tus-groundtruth/tusLabeledtusUnionBenchmark' 101 | elif hp.benchmark == 'tusLarge': 102 | groundTruth = 'data/table-union-search-benchmark/large/tus-groundtruth/tusLabeledtusLargeUnionBenchmark' 103 | 104 | calcMetrics(K, k_range, returnedResults, gtPath=groundTruth) 105 | -------------------------------------------------------------------------------- /test_lsh.py: -------------------------------------------------------------------------------- 1 | 2 | import pickle 3 | import time 4 | import argparse 5 | import mlflow 6 | import numpy as np 7 | from lsh_search import LSHSearcher 8 | from checkPrecisionRecall import saveDictionaryAsPickleFile, calcMetrics 9 | 10 | if __name__ == '__main__': 11 | parser = argparse.ArgumentParser() 12 | parser.add_argument("--encoder", type=str, default="sato", choices=['sherlock', 'sato', 'cl', 'tapex']) 13 | parser.add_argument("--benchmark", type=str, default='santos') 14 | 
parser.add_argument("--run_id", type=int, default=0) 15 | parser.add_argument("--single_column", dest="single_column", action="store_true") 16 | parser.add_argument("--num_func", type=int, default=16) 17 | parser.add_argument("--num_table", type=int, default=100) 18 | parser.add_argument("--K", type=int, default=10) 19 | parser.add_argument("--scal", type=float, default=1.00) 20 | # parser.add_argument("--N", type=int, default=10) 21 | # parser.add_argument("--threshold", type=float, default=0.7) 22 | # mlflow tag 23 | parser.add_argument("--mlflow_tag", type=str, default=None) 24 | 25 | hp = parser.parse_args() 26 | 27 | 28 | # mlflow logging 29 | for variable in ["encoder", "num_func", "num_table", "benchmark", "K", "run_id", "scal"]: 30 | mlflow.log_param(variable, getattr(hp, variable)) 31 | 32 | if hp.mlflow_tag: 33 | mlflow.set_tag("tag", hp.mlflow_tag) 34 | 35 | 36 | encoder = hp.encoder 37 | singleCol = hp.single_column 38 | 39 | dataFolder = hp.benchmark 40 | # Set augmentation operators, sampling methods, K, and threshold values according to the benchmark 41 | if 'santos' in dataFolder or dataFolder == 'wdc': 42 | sampAug = "drop_col_tfidf_entity" 43 | K = 10 44 | threshold = 0.7 45 | if dataFolder == 'santosLarge' or dataFolder == 'wdc': 46 | K, threshold = hp.K, 0.1 47 | elif 'tus' in dataFolder: 48 | sampAug = "drop_cell_alphaHead" 49 | K = 60 50 | threshold = 0.1 51 | singSampAug = "drop_cell_tfidf_entity" 52 | 53 | # If we need to change the value of N, or change the filepath to the pkl files (including indexing), change here: 54 | # N: number of returned elements for each query column 55 | if encoder in ['sherlock', 'sato']: 56 | N = 50 57 | query_path = "data/"+dataFolder+"/"+encoder+"_query.pkl" 58 | table_path = "data/"+dataFolder+"/"+encoder+"_datalake.pkl" 59 | index_path = "data/"+dataFolder+"/indexes/lsh_"+encoder+".bin" 60 | else: 61 | N = 4 62 | table_id = hp.run_id 63 | table_path = "data/"+dataFolder+"/vectors/cl_datalake_"+sampAug+"_column_"+str(table_id)+".pkl" 64 | query_path = "data/"+dataFolder+"/vectors/cl_query_"+sampAug+"_column_"+str(table_id)+".pkl" 65 | index_path = "data/"+dataFolder+"/indexes/lsh_open_data_"+str(table_id)+"_"+str(hp.scal)+".bin" 66 | if singleCol: 67 | N = 50 68 | table_path = "data/"+dataFolder+"/vectors/cl_datalake_"+singSampAug+"_column_"+str(table_id)+"_singleCol.pkl" 69 | query_path = "data/"+dataFolder+"/vectors/cl_query_"+singSampAug+"_column_"+str(table_id)+"_singleCol.pkl" 70 | index_path = "data/"+dataFolder+"/indexes/lsh_open_data_"+str(table_id)+"_singleCol.bin" 71 | 72 | num_hash_func = hp.num_func 73 | num_hash_table = hp.num_table 74 | # Call LSHSearcher from lsh_search.py 75 | searcher = LSHSearcher(table_path, num_hash_func, num_hash_table, hp.scal) 76 | # Load the query from the pickle file 77 | queries = pickle.load(open(query_path,"rb")) 78 | start_time = time.time() 79 | returnedResults = {} 80 | avgNumResults = [] 81 | query_times = [] 82 | 83 | for q in queries: 84 | query_start_time = time.time() 85 | res, numCalls = searcher.topk(encoder,q,K,N=N,threshold=threshold) 86 | returnedResults[q[0]] = [r[1] for r in res] 87 | avgNumResults.append(numCalls) 88 | query_times.append(time.time() - query_start_time) 89 | 90 | print("Average number of verification calls: ", sum(avgNumResults)/len(avgNumResults)) 91 | print("Average QUERY TIME: %s seconds " % (sum(query_times)/len(query_times))) 92 | print("10th percentile: ", np.percentile(query_times, 10), " 90th percentile: ", np.percentile(query_times, 90)) 93 | 
print("--- Total Query Time: %s seconds ---" % (time.time() - start_time)) 94 | 95 | # santosLarge and WDC benchmarks are used for efficiency 96 | if hp.benchmark == 'santosLarge' or hp.benchmark == 'wdc': 97 | print("No groundtruth for %s benchmark" % (hp.benchmark)) 98 | else: 99 | # Calculating effectiveness scores (Change the paths to where the ground truths are stored) 100 | if 'santos' in hp.benchmark: 101 | k_range = 1 102 | groundTruth = "data/santos/santosUnionBenchmark.pickle" 103 | else: 104 | k_range = 10 105 | if hp.benchmark == 'tus': 106 | groundTruth = 'data/table-union-search-benchmark/small/tus-groundtruth/tusLabeledtusUnionBenchmark' 107 | elif hp.benchmark == 'tusLarge': 108 | groundTruth = 'data/table-union-search-benchmark/large/tus-groundtruth/tusLabeledtusLargeUnionBenchmark' 109 | 110 | calcMetrics(K, k_range, returnedResults, gtPath=groundTruth) 111 | -------------------------------------------------------------------------------- /test_naive_search.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import random 3 | import pickle 4 | import argparse 5 | import mlflow 6 | from naive_search import NaiveSearcher 7 | from checkPrecisionRecall import saveDictionaryAsPickleFile, calcMetrics 8 | import time 9 | 10 | def generate_random_table(nrow, ncol): 11 | return np.random.rand(nrow, ncol) 12 | 13 | def verify(table1, table2,threshold=0.6): 14 | score = 0.0 15 | nrow = len(table1) 16 | ncol = len(table2) 17 | graph = np.zeros(shape=(nrow,ncol),dtype=float) 18 | for i in range(nrow): 19 | for j in range(ncol): 20 | sim = cosine_sim(table1[i],table2[j]) 21 | if sim > threshold: 22 | graph[i,j] = sim 23 | max_graph = make_cost_matrix(graph, lambda cost: (graph.max() - cost) if (cost != DISALLOWED) else DISALLOWED) 24 | m = Munkres() 25 | indexes = m.compute(max_graph) 26 | for row,col in indexes: 27 | score += graph[row,col] 28 | return score,indexes 29 | 30 | def generate_test_data(num, ndim): 31 | # for test only: randomly generate tables and 2 queries 32 | # num: the number of tables in the dataset; ndim: dimension of column vectors 33 | tables = [] 34 | queries = [] 35 | for i in range(num): 36 | ncol = random.randint(2,9) 37 | tbl = generate_random_table(ncol, ndim) 38 | tables.append((i,tbl)) 39 | for j in range(2): 40 | ncol = random.randint(2,9) 41 | tbl = generate_random_table(ncol, ndim) 42 | queries.append((j+num,tbl)) 43 | return tables, queries 44 | 45 | 46 | if __name__ == '__main__': 47 | parser = argparse.ArgumentParser() 48 | parser.add_argument("--encoder", type=str, default="sato", choices=['sherlock', 'sato', 'cl', 'tapex']) 49 | parser.add_argument("--benchmark", type=str, default='santos') 50 | parser.add_argument("--augment_op", type=str, default="drop_col") 51 | parser.add_argument("--sample_meth", type=str, default="tfidf_entity") 52 | # matching is the type of matching 53 | parser.add_argument("--matching", type=str, default='exact') #exact or bounds (or greedy) 54 | parser.add_argument("--table_order", type=str, default="column") 55 | parser.add_argument("--run_id", type=int, default=0) 56 | parser.add_argument("--single_column", dest="single_column", action="store_true") 57 | # For error analysis 58 | parser.add_argument("--bucket", type=int, default=0) # the error analysis has 5 equally-sized buckets 59 | parser.add_argument("--analysis", type=str, default='col') # 'col', 'row', 'numeric' 60 | parser.add_argument("--K", type=int, default=10) 61 | parser.add_argument("--threshold", 
type=float, default=0.6) 62 | # For Scalability experiments 63 | parser.add_argument("--scal", type=float, default=1.00) 64 | # mlflow tag 65 | parser.add_argument("--mlflow_tag", type=str, default=None) 66 | 67 | hp = parser.parse_args() 68 | 69 | # mlflow logging 70 | for variable in ["encoder", "benchmark", "augment_op", "sample_meth", "matching", "table_order", "run_id", "single_column", "K", "threshold", "scal"]: 71 | mlflow.log_param(variable, getattr(hp, variable)) 72 | 73 | if hp.mlflow_tag: 74 | mlflow.set_tag("tag", hp.mlflow_tag) 75 | 76 | dataFolder = hp.benchmark 77 | 78 | # If the filepath to the pkl files are different, change here: 79 | if hp.encoder == 'cl': 80 | query_path = "data/"+dataFolder+"/vectors/"+hp.encoder+"_query_"+hp.augment_op+"_"+hp.sample_meth+"_"+hp.table_order+"_"+str(hp.run_id) 81 | table_path = "data/"+dataFolder+"/vectors/"+hp.encoder+"_datalake_"+hp.augment_op+"_"+hp.sample_meth+"_"+hp.table_order+"_"+str(hp.run_id) 82 | 83 | if hp.single_column: 84 | query_path += "_singleCol" 85 | table_path += "_singleCol" 86 | query_path += ".pkl" 87 | table_path += ".pkl" 88 | else: 89 | query_path = "data/"+dataFolder+"/"+hp.encoder+"_query.pkl" 90 | table_path = "data/"+dataFolder+"/"+hp.encoder+"_datalake.pkl" 91 | 92 | # Load the query file 93 | qfile = open(query_path,"rb") 94 | queries = pickle.load(qfile) 95 | print("Number of queries: %d" % (len(queries))) 96 | qfile.close() 97 | # Call NaiveSearcher, which has linear search and bounds search, from naive_search.py 98 | searcher = NaiveSearcher(table_path, hp.scal) 99 | returnedResults = {} 100 | start_time = time.time() 101 | # For error analysis of tables 102 | analysis = hp.analysis 103 | # bucketFile = open("data/"+dataFolder+"/buckets/query_"+analysis+"Bucket_"+str(hp.bucket)+".txt", "r") 104 | # bucket = bucketFile.read() 105 | queries.sort(key = lambda x: x[0]) 106 | query_times = [] 107 | qCount = 0 108 | 109 | for query in queries: 110 | qCount += 1 111 | if qCount % 10 == 0: 112 | print("Processing query ",qCount, " of ", len(queries), " total queries.") 113 | # if query[0] in bucket: 114 | query_start_time = time.time() 115 | if hp.matching == 'exact': 116 | qres = searcher.topk(hp.encoder, query, hp.K, threshold=hp.threshold) 117 | else: # Bounds matching 118 | qres = searcher.topk_bounds(hp.encoder, query, hp.K, threshold=hp.threshold) 119 | res = [] 120 | for tpl in qres: 121 | tmp = (tpl[0],tpl[1]) 122 | res.append(tmp) 123 | returnedResults[query[0]] = [r[1] for r in res] 124 | query_times.append(time.time() - query_start_time) 125 | 126 | print("Average QUERY TIME: %s seconds " % (sum(query_times)/len(query_times))) 127 | print("10th percentile: ", np.percentile(query_times, 10), " 90th percentile: ", np.percentile(query_times, 90)) 128 | print("--- Total Query Time: %s seconds ---" % (time.time() - start_time)) 129 | 130 | # santosLarge and WDC benchmarks are used for efficiency 131 | if hp.benchmark == 'santosLarge' or hp.benchmark == 'wdc': 132 | print("No groundtruth for %s benchmark" % (hp.benchmark)) 133 | else: 134 | # Calculating effectiveness scores (Change the paths to where the ground truths are stored) 135 | if 'santos' in hp.benchmark: 136 | k_range = 1 137 | groundTruth = "data/santos/santosUnionBenchmark.pickle" 138 | else: 139 | k_range = 10 140 | if hp.benchmark == 'tus': 141 | groundTruth = 'data/table-union-search-benchmark/small/tus-groundtruth/tusLabeledtusUnionBenchmark' 142 | elif hp.benchmark == 'tusLarge': 143 | groundTruth = 
'data/table-union-search-benchmark/large/tus-groundtruth/tusLabeledtusLargeUnionBenchmark' 144 | 145 | calcMetrics(hp.K, k_range, returnedResults, gtPath=groundTruth) 146 | --------------------------------------------------------------------------------
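For reference, the `verify` function in `test_naive_search.py` scores a pair of tables via a maximum-weight bipartite matching between their column vectors. A minimal standalone sketch of that matching step (toy similarity graph; assumes the `munkres` package that `verify` relies on):

```
import numpy as np
from munkres import Munkres, make_cost_matrix, DISALLOWED

# toy 2x2 cosine-similarity graph between query columns and candidate columns
graph = np.array([[0.9, 0.2],
                  [0.1, 0.8]])
# munkres minimizes cost, so invert the similarities before matching
max_graph = make_cost_matrix(graph, lambda c: graph.max() - c if c != DISALLOWED else DISALLOWED)
indexes = Munkres().compute(max_graph)
score = sum(graph[r][c] for r, c in indexes)
print(score, indexes)  # ~1.7, [(0, 0), (1, 1)]
```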