├── README.md ├── requirements ├── requirements.txt └── requirements_cloudlab.txt └── src ├── classes ├── classes.py ├── paths.py ├── workload_runs.py └── workloads.py ├── conf ├── postgres │ ├── modified-postgresql10.conf │ ├── modified-postgresql11.conf │ ├── modified-postgresql12.conf │ ├── modified-postgresql13.conf │ ├── modified-postgresql14.conf │ ├── modified-postgresql16.conf │ └── pg_hba.conf └── zeroshot_hyperparameters │ ├── tune_best_config.json │ ├── tune_best_config_ablation_column_feats.json │ ├── tune_best_config_ablation_data_dist_feats.json │ ├── tune_best_config_ablation_operator_feats.json │ ├── tune_best_config_ablation_table_feats.json │ ├── tune_best_config_dec_ablation_all_feats.json │ ├── tune_best_config_dec_ablation_column_feats.json │ ├── tune_best_config_dec_ablation_operator_feats.json │ ├── tune_best_config_dec_ablation_table_feats.json │ ├── tune_deepdb_best_config.json │ └── tune_est_best_config.json ├── cross_db_benchmark ├── __init__.py ├── benchmark_tools │ ├── __init__.py │ ├── column_types.py │ ├── compare_runs.py │ ├── create_fk_indexes.py │ ├── database.py │ ├── drop_db.py │ ├── generate_column_stats.py │ ├── generate_string_statistics.py │ ├── generate_workload.py │ ├── get_table_lengths.py │ ├── join_conditions.py │ ├── load_database.py │ ├── parse_run.py │ ├── postgres │ │ ├── __init__.py │ │ ├── check_valid.py │ │ ├── combine_plans.py │ │ ├── compare_plan.py │ │ ├── database_connection.py │ │ ├── inflate_cardinality_errors.py │ │ ├── json_plan.py │ │ ├── parse_filter.py │ │ ├── parse_plan.py │ │ ├── plan_operator.py │ │ ├── run_workload.py │ │ └── utils.py │ ├── run_workload.py │ ├── tests │ │ └── test_hint_validation.py │ └── utils.py ├── datasets │ ├── accidents │ │ ├── column_statistics.json │ │ ├── schema.json │ │ ├── schema_sql │ │ │ ├── mysql.sql │ │ │ └── postgres.sql │ │ └── string_statistics.json │ ├── airline │ │ ├── column_statistics.json │ │ ├── dataset_documentation │ │ │ ├── README.md │ │ │ └── script.py │ │ ├── schema.json │ │ ├── schema_sql │ │ │ ├── mysql.sql │ │ │ └── postgres.sql │ │ └── string_statistics.json │ ├── baseball │ │ ├── column_statistics.json │ │ ├── schema.json │ │ ├── schema_sql │ │ │ ├── mysql.sql │ │ │ └── postgres.sql │ │ ├── string_statistics.json │ │ └── table_lengths.json │ ├── basketball │ │ ├── column_statistics.json │ │ ├── schema.json │ │ ├── schema_sql │ │ │ ├── mysql.sql │ │ │ └── postgres.sql │ │ └── string_statistics.json │ ├── carcinogenesis │ │ ├── column_statistics.json │ │ ├── schema.json │ │ ├── schema_sql │ │ │ ├── mysql.sql │ │ │ └── postgres.sql │ │ └── string_statistics.json │ ├── consumer │ │ ├── column_statistics.json │ │ ├── schema.json │ │ ├── schema_sql │ │ │ ├── mysql.sql │ │ │ └── postgres.sql │ │ └── string_statistics.json │ ├── credit │ │ ├── column_statistics.json │ │ ├── schema.json │ │ ├── schema_sql │ │ │ ├── mysql.sql │ │ │ └── postgres.sql │ │ └── string_statistics.json │ ├── datasets.py │ ├── employee │ │ ├── column_statistics.json │ │ ├── schema.json │ │ ├── schema_sql │ │ │ ├── mysql.sql │ │ │ └── postgres.sql │ │ └── string_statistics.json │ ├── fhnk │ │ ├── column_statistics.json │ │ ├── schema.json │ │ ├── schema_sql │ │ │ ├── mysql.sql │ │ │ └── postgres.sql │ │ └── string_statistics.json │ ├── financial │ │ ├── column_statistics.json │ │ ├── schema.json │ │ ├── schema_sql │ │ │ ├── mysql.sql │ │ │ └── postgres.sql │ │ └── string_statistics.json │ ├── geneea │ │ ├── column_statistics.json │ │ ├── schema.json │ │ ├── schema_sql │ │ │ ├── mysql.sql │ │ │ └── 
postgres.sql │ │ └── string_statistics.json │ ├── genome │ │ ├── column_statistics.json │ │ ├── schema.json │ │ ├── schema_sql │ │ │ ├── mysql.sql │ │ │ └── postgres.sql │ │ └── string_statistics.json │ ├── hepatitis │ │ ├── column_statistics.json │ │ ├── schema.json │ │ ├── schema_sql │ │ │ ├── mysql.sql │ │ │ └── postgres.sql │ │ └── string_statistics.json │ ├── hockey │ │ ├── column_statistics.json │ │ ├── schema.json │ │ ├── schema_sql │ │ │ ├── mysql.sql │ │ │ └── postgres.sql │ │ └── string_statistics.json │ ├── imdb │ │ ├── column_statistics.json │ │ ├── dataset_documentation │ │ │ ├── README.md │ │ │ └── script.py │ │ ├── schema.json │ │ ├── schema_sql │ │ │ └── postgres.sql │ │ ├── string_statistics.json │ │ └── table_lengths.json │ ├── imdb_full │ │ ├── column_statistics.json │ │ ├── dataset_documentation │ │ │ └── README.md │ │ ├── schema.json │ │ ├── schema_sql │ │ │ └── postgres.sql │ │ └── string_statistics.json │ ├── movielens │ │ ├── column_statistics.json │ │ ├── schema.json │ │ ├── schema_sql │ │ │ ├── mysql.sql │ │ │ └── postgres.sql │ │ └── string_statistics.json │ ├── seznam │ │ ├── column_statistics.json │ │ ├── schema.json │ │ ├── schema_sql │ │ │ ├── mysql.sql │ │ │ └── postgres.sql │ │ └── string_statistics.json │ ├── ssb │ │ ├── column_statistics.json │ │ ├── dataset_documentation │ │ │ ├── README.md │ │ │ └── script.py │ │ ├── schema.json │ │ ├── schema_sql │ │ │ └── postgres.sql │ │ └── string_statistics.json │ ├── tournament │ │ ├── column_statistics.json │ │ ├── schema.json │ │ ├── schema_sql │ │ │ ├── mysql.sql │ │ │ └── postgres.sql │ │ └── string_statistics.json │ ├── tpc_h │ │ ├── column_statistics.json │ │ ├── dataset_documentation │ │ │ ├── README.md │ │ │ └── script.py │ │ ├── schema.json │ │ ├── schema_sql │ │ │ └── postgres.sql │ │ ├── string_statistics.json │ │ └── table_lengths.json │ ├── tpc_h_pk │ │ ├── column_statistics.json │ │ ├── dataset_documentation │ │ │ ├── README.md │ │ │ └── script.py │ │ ├── schema.json │ │ ├── schema_sql │ │ │ └── postgres.sql │ │ ├── string_statistics.json │ │ └── table_lengths.json │ └── walmart │ │ ├── column_statistics.json │ │ ├── schema.json │ │ ├── schema_sql │ │ ├── mysql.sql │ │ └── postgres.sql │ │ └── string_statistics.json └── meta_tools │ ├── __init__.py │ ├── dataset_stats.py │ ├── derive.py │ ├── download_relational_fit.py │ ├── inflate_cardinality_errors.py │ ├── replace_aliases.py │ ├── scale_dataset.py │ └── slice_no_tables.py ├── deprecated ├── dataset_tools.py └── parse_all.py ├── evaluation ├── plots │ ├── 00_motivational_plot.ipynb │ ├── 01_join_order.ipynb │ ├── 02_join_order_pg_act_cards.ipynb │ ├── 03_access_path_selection.ipynb │ ├── 04_physical_plan_selection.ipynb │ ├── 05_analyze_joblight.ipynb │ ├── eval.py │ ├── evaluation_metrics.py │ └── utils.py └── workload_creation │ ├── create_evaluation_workloads.py │ ├── create_indexed_workloads.py │ ├── create_retraining_workloads.py │ └── test_workload_generation.py ├── experiments ├── data │ └── statistics │ │ ├── dataset_stats.csv │ │ ├── postgres_workload_driven_workload_stats.csv │ │ └── postgres_workload_stats.csv ├── evaluation │ └── utils.py ├── evaluation_workloads │ ├── generated │ │ └── workload_defs.py │ ├── imdb │ │ ├── job-light.sql │ │ ├── scale.sql │ │ ├── stripped_job-light.sql │ │ └── synthetic.sql │ ├── imdb_full │ │ └── job_full.sql │ ├── res │ │ ├── fkindexes.sql │ │ ├── job_full.sql │ │ ├── ssb_original.sql │ │ ├── tpc_h_original.sql │ │ └── tpc_h_subqueries_rewritten.sql │ ├── ssb │ │ └── benchmark.sql │ └── tpc_h │ │ └── 
benchmark.sql ├── postgres_workload_stats_with_sql.csv └── setup │ ├── postgres │ ├── run_workload_commands.py │ └── tune_hyperparameters.py │ └── utils.py ├── gather_feature_stats.py ├── main.py ├── models ├── dace │ ├── dace_dataset.py │ ├── dace_model.py │ └── dace_utils.py ├── qppnet │ ├── qppnet_dataloader.py │ └── qppnet_model.py ├── query_former │ ├── dataloader.py │ ├── model.py │ └── utils.py ├── tabular │ └── train_tabular_baseline.py ├── workload_driven │ ├── dataset │ │ ├── dataset_creation.py │ │ ├── mscn_batching.py │ │ └── plan_tree_batching.py │ ├── model │ │ ├── e2e_model.py │ │ ├── mscn_model.py │ │ └── tree_lstm.py │ ├── preprocessing │ │ ├── sample_vectors.py │ │ ├── sentence_creation.py │ │ └── word_embeddings.py │ └── tests │ │ ├── mscn │ │ └── test_message_passing.py │ │ └── plan_model │ │ └── test_message_passing.py └── zeroshot │ ├── message_aggregators │ ├── aggregator.py │ ├── gat.py │ ├── message_aggregators.py │ ├── mscn.py │ └── pooling.py │ ├── postgres_plan_batching.py │ ├── specific_models │ ├── model.py │ └── postgres_zero_shot.py │ ├── utils │ ├── activations.py │ ├── embeddings.py │ ├── fc_out_model.py │ └── node_type_encoder.py │ └── zero_shot_model.py ├── parse_all.py ├── predict.py ├── run_benchmark.py ├── scripts ├── exp_runner │ ├── exp_osf_upload.py │ ├── exp_predict_all.py │ ├── exp_predict_retrained.py │ ├── exp_remove_workload.py │ ├── exp_retrain_model.py │ ├── exp_run_evaluation_workloads.py │ ├── exp_run_training_workloads.py │ ├── exp_runner.py │ ├── exp_setup.py │ └── exp_train_model.py ├── misc │ ├── cloudlab.rspec │ └── parse_cloudlab_manifest.py └── postgres_installation │ ├── install_postgres.sh │ ├── install_postgres_10.sh │ ├── install_postgres_11.sh │ ├── install_postgres_12.sh │ ├── install_postgres_13.sh │ ├── install_postgres_16.sh │ ├── install_tools.sh │ ├── resize_partition.sh │ └── resize_partition_cont.sh ├── setup.py ├── tests ├── __init__.py ├── message_passing │ └── test_model_message_passing.py ├── utils.py └── workload_parsing │ ├── test_filter_parsing.py │ └── test_workload_parsing.py ├── train.py └── training ├── batch_to_funcs.py ├── dataset ├── dataset_creation.py └── plan_dataset.py ├── featurizations.py ├── losses.py ├── preprocessing └── feature_statistics.py └── training ├── checkpoint.py ├── metrics.py ├── train.py └── utils.py /requirements/requirements.txt: -------------------------------------------------------------------------------- 1 | pandas 2 | psycopg2-binary 3 | tqdm 4 | scikit-learn 5 | numpy 6 | torch==2.0.1 7 | torchvision 8 | pydantic 9 | dgl 10 | osfclient 11 | paramiko 12 | scp 13 | dataclasses 14 | pytest 15 | optuna 16 | jupyter 17 | gensim 18 | matplotlib 19 | seaborn 20 | absl-py 21 | glog 22 | networkx 23 | tabulate 24 | scipy==1.10.1 25 | yapf 26 | mako 27 | pyarrow 28 | filelock 29 | lightgbm 30 | func-timeout 31 | joblib 32 | wandb 33 | lightning==2.0.9 34 | loralib==0.1.2 35 | numpy==1.24.3 36 | pytorch_lightning==2.0.9 37 | ray==2.7.0 38 | torch_tb_profiler 39 | tensorflow 40 | tensorrt 41 | -------------------------------------------------------------------------------- /requirements/requirements_cloudlab.txt: -------------------------------------------------------------------------------- 1 | dgl==1.1.2 2 | #mysql-connector-python 3 | networkx 4 | numpy 5 | osfclient 6 | pandas 7 | psycopg2-binary 8 | psutil 9 | sortedcontainers 10 | scikit-learn==1.4.0 11 | tabulate 12 | torch==2.1.2 13 | tqdm 14 | func_timeout 15 | duckdb 16 | gensim 17 | # for octopus 18 | paramiko 19 | scp 20 | 
scipy==1.11.0 21 | seaborn 22 | 23 | # for wandb 24 | wandb 25 | 26 | # for postgres db endpoint 27 | fastapi 28 | uvicorn 29 | 30 | # zs cost baseline 31 | optuna 32 | 33 | lightgbm 34 | attrs 35 | python-dotenv 36 | -------------------------------------------------------------------------------- /src/classes/paths.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | 4 | from dotenv import load_dotenv 5 | 6 | 7 | class Paths: 8 | root: Path 9 | data: Path 10 | code: Path 11 | runs: Path 12 | raw: Path 13 | json: Path 14 | parsed_plans: Path 15 | parsed_plans_baseline: Path 16 | augmented_plans_baseline: Path 17 | workloads: Path 18 | evaluation_workloads: Path 19 | training_workloads: Path 20 | 21 | evaluation: Path 22 | retraining_evaluation: Path 23 | 24 | models: Path 25 | retraining_models: Path 26 | sentences: Path 27 | known_hosts: Path 28 | 29 | def __init__(self, root_path: Path, data_path: Path = None): 30 | self.root = root_path 31 | if data_path: 32 | self.data = data_path 33 | else: 34 | self.data = root_path / 'data' 35 | self.code = root_path / 'src' 36 | self.runs = self.data / 'runs' 37 | self.raw = self.runs / 'raw' 38 | self.json = self.runs / 'json' 39 | self.parsed_plans = self.runs / 'parsed_plans' 40 | self.parsed_plans_baseline = self.runs / 'parsed_plans_baseline' 41 | self.augmented_plans_baseline = self.runs / 'augmented_plans_baseline' 42 | self.workloads = self.data / 'workloads' 43 | self.evaluation_workloads = self.workloads / 'evaluation' 44 | self.training_workloads = self.workloads / 'training' 45 | 46 | self.evaluation = self.data / 'evaluation' 47 | self.retraining_evaluation = self.data / 'retraining_evaluation' 48 | 49 | self.models = self.data / 'models' 50 | self.retraining_models = self.data / 'retraining_models' 51 | 52 | self.sentences = self.runs / 'sentences' 53 | 54 | 55 | class LocalPaths(Paths): 56 | def __init__(self): 57 | load_dotenv() 58 | super().__init__(root_path=Path(os.getenv('LOCAL_ROOT_PATH'))) 59 | self.node_list = self.code / 'scripts/misc/hostnames' 60 | self.requirements = self.root / 'requirements' 61 | self.plotting_path = self.data / 'plots' 62 | self.dataset_path = self.code / 'cross_db_benchmark' / 'datasets' 63 | self.known_hosts = Path(os.getenv('LOCAL_KNOWN_HOSTS_PATH')) 64 | 65 | 66 | class CloudlabPaths(Paths): 67 | 68 | def __init__(self): 69 | load_dotenv() 70 | super().__init__(root_path=Path(os.getenv('CLOUDLAB_ROOT_PATH'))) 71 | 72 | 73 | class ClusterPaths(Paths): 74 | 75 | def __init__(self): 76 | load_dotenv() 77 | super().__init__(root_path=Path(os.getenv('CLUSTER_ROOT_PATH')), 78 | data_path=Path(os.getenv('CLUSTER_STORAGE_PATH'))) 79 | -------------------------------------------------------------------------------- /src/classes/workload_runs.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | from typing import List, Optional 4 | 5 | import attrs as attrs 6 | from attr import field 7 | 8 | from classes.classes import ModelConfig, ModelType 9 | 10 | 11 | @attrs.define(frozen=True, slots=False) 12 | class WorkloadRuns: 13 | train_workload_runs: List[Path] = field(default=None) 14 | test_workload_runs: Optional[List[Path]] = field(default=None) 15 | target_test_csv_paths: List[Path] = [] 16 | 17 | def update_test_workloads(self, target_dir: Path, seed: int) -> None: 18 | if self.test_workload_runs: 19 | for test_path in self.test_workload_runs: 20 | 
test_workload = os.path.basename(test_path).replace('.json', '') 21 | self.target_test_csv_paths.append(Path(target_dir) / f'{test_workload}_{seed}') 22 | else: 23 | # When no test paths are given, this is a workload driven model, 24 | # and we use the training workload as test workload 25 | for test_path in self.train_workload_runs: 26 | test_workload = os.path.basename(test_path).replace('.json', '') 27 | self.target_test_csv_paths.append(Path(target_dir) / f'{test_workload}_{seed}') 28 | 29 | def check_if_done(self, model_name: str) -> bool: 30 | if all([os.path.exists(p) for p in self.target_test_csv_paths]) and self.target_test_csv_paths: 31 | short_paths = {} 32 | for path in self.target_test_csv_paths: 33 | key = (path.parts[-3], path.parts[-2]) 34 | value = path.parts[-1] 35 | short_paths.setdefault(key, []).append(value) 36 | print( 37 | f"{'Model '}{model_name:<16}{' was already trained and evaluated for '}{str(len(self.target_test_csv_paths))}{' queries: '}{str(short_paths)}") 38 | return True 39 | return False 40 | 41 | def check_model_compability(self, model_config: ModelConfig, mode: str): 42 | if mode == "train": 43 | if model_config.type == ModelType.WL_DRIVEN: 44 | assert len(self.train_workload_runs) == 1, "One training workload run is supported for workload driven models" 45 | assert len(self.test_workload_runs) == 0, "No test workload runs are allowed for workload driven models, as the test workload is the same as the training workload" 46 | 47 | if model_config.type == ModelType.WL_AGNOSTIC: 48 | assert len(self.train_workload_runs) >= 1, "Model need to be trained more than 1 database" 49 | assert len(self.test_workload_runs) == 1, "Only one test workload run is supported for workload agnostic models" 50 | return 51 | -------------------------------------------------------------------------------- /src/conf/zeroshot_hyperparameters/tune_best_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "batch_size": 256, 3 | "dropout": false, 4 | "final_layers": 4, 5 | "final_width_factor": 1.5, 6 | "hidden_dim": 128, 7 | "lr": 0.001, 8 | "max_emb_dim": 32, 9 | "message_passing_layers": 2, 10 | "node_layers": 4, 11 | "node_type_width_factor": 1.5, 12 | "p_dropout": 0.0, 13 | "plan_featurization_name": "PostgresTrueCardDetail", 14 | "residual": false, 15 | "tree_layer_name": "MscnConv", 16 | "tree_layer_width_factor": 0.6 17 | } -------------------------------------------------------------------------------- /src/conf/zeroshot_hyperparameters/tune_best_config_ablation_column_feats.json: -------------------------------------------------------------------------------- 1 | { 2 | "batch_size": 256, 3 | "dropout": false, 4 | "final_layers": 4, 5 | "final_width_factor": 1.5, 6 | "hidden_dim": 128, 7 | "lr": 0.001, 8 | "max_emb_dim": 32, 9 | "message_passing_layers": 2, 10 | "node_layers": 4, 11 | "node_type_width_factor": 1.5, 12 | "p_dropout": 0.0, 13 | "plan_featurization_name": "PostgresTrueCardAblateColumnFeats", 14 | "residual": false, 15 | "tree_layer_name": "MscnConv", 16 | "tree_layer_width_factor": 0.6 17 | } -------------------------------------------------------------------------------- /src/conf/zeroshot_hyperparameters/tune_best_config_ablation_data_dist_feats.json: -------------------------------------------------------------------------------- 1 | { 2 | "batch_size": 256, 3 | "dropout": false, 4 | "final_layers": 4, 5 | "final_width_factor": 1.5, 6 | "hidden_dim": 128, 7 | "lr": 0.001, 8 | "max_emb_dim": 32, 9 | 
"message_passing_layers": 2, 10 | "node_layers": 4, 11 | "node_type_width_factor": 1.5, 12 | "p_dropout": 0.0, 13 | "plan_featurization_name": "PostgresTrueCardAblateDataDistributionFeats", 14 | "residual": false, 15 | "tree_layer_name": "MscnConv", 16 | "tree_layer_width_factor": 0.6 17 | } -------------------------------------------------------------------------------- /src/conf/zeroshot_hyperparameters/tune_best_config_ablation_operator_feats.json: -------------------------------------------------------------------------------- 1 | { 2 | "batch_size": 256, 3 | "dropout": false, 4 | "final_layers": 4, 5 | "final_width_factor": 1.5, 6 | "hidden_dim": 128, 7 | "lr": 0.001, 8 | "max_emb_dim": 32, 9 | "message_passing_layers": 2, 10 | "node_layers": 4, 11 | "node_type_width_factor": 1.5, 12 | "p_dropout": 0.0, 13 | "plan_featurization_name": "PostgresTrueCardAblateOperatorFeats", 14 | "residual": false, 15 | "tree_layer_name": "MscnConv", 16 | "tree_layer_width_factor": 0.6 17 | } -------------------------------------------------------------------------------- /src/conf/zeroshot_hyperparameters/tune_best_config_ablation_table_feats.json: -------------------------------------------------------------------------------- 1 | { 2 | "batch_size": 256, 3 | "dropout": false, 4 | "final_layers": 4, 5 | "final_width_factor": 1.5, 6 | "hidden_dim": 128, 7 | "lr": 0.001, 8 | "max_emb_dim": 32, 9 | "message_passing_layers": 2, 10 | "node_layers": 4, 11 | "node_type_width_factor": 1.5, 12 | "p_dropout": 0.0, 13 | "plan_featurization_name": "PostgresTrueCardAblateTableFeats", 14 | "residual": false, 15 | "tree_layer_name": "MscnConv", 16 | "tree_layer_width_factor": 0.6 17 | } -------------------------------------------------------------------------------- /src/conf/zeroshot_hyperparameters/tune_best_config_dec_ablation_all_feats.json: -------------------------------------------------------------------------------- 1 | { 2 | "batch_size": 256, 3 | "dropout": false, 4 | "final_layers": 4, 5 | "final_width_factor": 1.5, 6 | "hidden_dim": 128, 7 | "lr": 0.001, 8 | "max_emb_dim": 32, 9 | "message_passing_layers": 2, 10 | "node_layers": 4, 11 | "node_type_width_factor": 1.5, 12 | "p_dropout": 0.0, 13 | "plan_featurization_name": "PostgresTrueCardDecAblationAllFeats", 14 | "residual": false, 15 | "tree_layer_name": "MscnConv", 16 | "tree_layer_width_factor": 0.6 17 | } -------------------------------------------------------------------------------- /src/conf/zeroshot_hyperparameters/tune_best_config_dec_ablation_column_feats.json: -------------------------------------------------------------------------------- 1 | { 2 | "batch_size": 256, 3 | "dropout": false, 4 | "final_layers": 4, 5 | "final_width_factor": 1.5, 6 | "hidden_dim": 128, 7 | "lr": 0.001, 8 | "max_emb_dim": 32, 9 | "message_passing_layers": 2, 10 | "node_layers": 4, 11 | "node_type_width_factor": 1.5, 12 | "p_dropout": 0.0, 13 | "plan_featurization_name": "PostgresTrueCardDecAblationColumnFeats", 14 | "residual": false, 15 | "tree_layer_name": "MscnConv", 16 | "tree_layer_width_factor": 0.6 17 | } -------------------------------------------------------------------------------- /src/conf/zeroshot_hyperparameters/tune_best_config_dec_ablation_operator_feats.json: -------------------------------------------------------------------------------- 1 | { 2 | "batch_size": 256, 3 | "dropout": false, 4 | "final_layers": 4, 5 | "final_width_factor": 1.5, 6 | "hidden_dim": 128, 7 | "lr": 0.001, 8 | "max_emb_dim": 32, 9 | "message_passing_layers": 2, 10 | 
"node_layers": 4, 11 | "node_type_width_factor": 1.5, 12 | "p_dropout": 0.0, 13 | "plan_featurization_name": "PostgresTrueCardDecAblationOperatorFeats", 14 | "residual": false, 15 | "tree_layer_name": "MscnConv", 16 | "tree_layer_width_factor": 0.6 17 | } -------------------------------------------------------------------------------- /src/conf/zeroshot_hyperparameters/tune_best_config_dec_ablation_table_feats.json: -------------------------------------------------------------------------------- 1 | { 2 | "batch_size": 256, 3 | "dropout": false, 4 | "final_layers": 4, 5 | "final_width_factor": 1.5, 6 | "hidden_dim": 128, 7 | "lr": 0.001, 8 | "max_emb_dim": 32, 9 | "message_passing_layers": 2, 10 | "node_layers": 4, 11 | "node_type_width_factor": 1.5, 12 | "p_dropout": 0.0, 13 | "plan_featurization_name": "PostgresTrueCardDecAblationTableFeats", 14 | "residual": false, 15 | "tree_layer_name": "MscnConv", 16 | "tree_layer_width_factor": 0.6 17 | } -------------------------------------------------------------------------------- /src/conf/zeroshot_hyperparameters/tune_deepdb_best_config.json: -------------------------------------------------------------------------------- 1 | {"batch_size": 256, "dropout": false, "final_layers": 4, "final_width_factor": 1.5, "hidden_dim": 128, "lr": 0.001, "max_emb_dim": 32, "message_passing_layers": 2, "node_layers": 4, "node_type_width_factor": 1.5, "p_dropout": 0.0, "plan_featurization_name": "PostgresDeepDBEstSystemCardDetail", "residual": false, "tree_layer_name": "MscnConv", "tree_layer_width_factor": 0.6} -------------------------------------------------------------------------------- /src/conf/zeroshot_hyperparameters/tune_est_best_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "batch_size": 256, 3 | "dropout": false, 4 | "final_layers": 4, 5 | "final_width_factor": 1.5, 6 | "hidden_dim": 128, 7 | "lr": 0.001, 8 | "max_emb_dim": 32, 9 | "message_passing_layers": 2, 10 | "node_layers": 4, 11 | "node_type_width_factor": 1.5, 12 | "p_dropout": 0.0, 13 | "plan_featurization_name": "PostgresEstSystemCardDetail", 14 | "residual": false, 15 | "tree_layer_name": "MscnConv", 16 | "tree_layer_width_factor": 0.6 17 | } -------------------------------------------------------------------------------- /src/cross_db_benchmark/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DataManagementLab/lcm-eval/8ed11d4c47bae2cb7f0740f566170f3e736e8471/src/cross_db_benchmark/__init__.py -------------------------------------------------------------------------------- /src/cross_db_benchmark/benchmark_tools/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DataManagementLab/lcm-eval/8ed11d4c47bae2cb7f0740f566170f3e736e8471/src/cross_db_benchmark/benchmark_tools/__init__.py -------------------------------------------------------------------------------- /src/cross_db_benchmark/benchmark_tools/column_types.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | 3 | 4 | class Datatype(Enum): 5 | INT = 'int' 6 | FLOAT = 'float' 7 | CATEGORICAL = 'categorical' 8 | STRING = 'string' 9 | MISC = 'misc' 10 | 11 | def __str__(self): 12 | return self.value 13 | -------------------------------------------------------------------------------- /src/cross_db_benchmark/benchmark_tools/compare_runs.py: 
-------------------------------------------------------------------------------- 1 | from cross_db_benchmark.benchmark_tools.database import DatabaseSystem 2 | from cross_db_benchmark.benchmark_tools.postgres.compare_plan import compare_plans 3 | from cross_db_benchmark.benchmark_tools.utils import load_json 4 | 5 | 6 | def compare_runs(source_path, alt_source_path, database, min_query_ms=100): 7 | if database == DatabaseSystem.POSTGRES: 8 | compare_func = compare_plans 9 | else: 10 | raise NotImplementedError(f"Database {database} not yet supported.") 11 | 12 | run_stats = load_json(source_path) 13 | alt_run_stats = load_json(alt_source_path) 14 | 15 | compare_func(run_stats, alt_run_stats, min_runtime=min_query_ms) 16 | -------------------------------------------------------------------------------- /src/cross_db_benchmark/benchmark_tools/create_fk_indexes.py: -------------------------------------------------------------------------------- 1 | from tqdm import tqdm 2 | 3 | from cross_db_benchmark.benchmark_tools.load_database import create_db_conn 4 | from cross_db_benchmark.benchmark_tools.utils import load_schema_json 5 | 6 | 7 | def create_fk_indexes(dataset, database, db_name, database_conn_args, database_kwarg_dict): 8 | # check if tables are a connected acyclic graph 9 | schema = load_schema_json(dataset) 10 | db_conn = create_db_conn(database, db_name, database_conn_args, database_kwarg_dict) 11 | 12 | idx_sql = [] 13 | for r_id, r in enumerate(schema.relationships): 14 | table_left, col_left, table_right, col_right = r 15 | cname = col_left 16 | if isinstance(col_left, tuple) or isinstance(col_left, list): 17 | cname = "_".join(col_left) 18 | col_left = ", ".join([f'"{c}"' for c in col_left]) 19 | 20 | sql = f"create index {cname}_{table_left} on \"{table_left}\"({col_left});" 21 | idx_sql.append(sql) 22 | 23 | idx_sql.append("Vacuum Analyze;") 24 | for sql in tqdm(idx_sql): 25 | try: 26 | db_conn.submit_query(sql, db_created=True) 27 | print(sql) 28 | except: 29 | print(f"Skipping {sql}") 30 | continue 31 | -------------------------------------------------------------------------------- /src/cross_db_benchmark/benchmark_tools/database.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | 3 | 4 | class DatabaseSystem(Enum): 5 | POSTGRES = 'postgres' 6 | MYSQL = 'mysql' 7 | 8 | def __str__(self): 9 | return self.value 10 | 11 | 12 | class ExecutionMode: 13 | RAW_OUTPUT = 'raw' 14 | JSON_OUTPUT = 'json' 15 | 16 | 17 | class DatabaseConnection: 18 | def __init__(self, db_name=None, database_kwargs=None): 19 | self.db_name = db_name 20 | self.database_kwargs = database_kwargs 21 | 22 | def drop(self): 23 | raise NotImplementedError 24 | 25 | def load_database(self, data_dir, dataset, force=False): 26 | raise NotImplementedError 27 | 28 | def replicate_tuples(self, dataset, data_dir, no_prev_replications): 29 | raise NotImplementedError 30 | 31 | def set_statement_timeout(self, timeout_sec): 32 | raise NotImplementedError 33 | 34 | def run_query_collect_statistics(self, sql, repetitions, prefix, hint_validation, 35 | include_hint_notices, explain_only): 36 | raise NotImplementedError 37 | 38 | def collect_db_statistics(self): 39 | raise NotImplementedError 40 | 41 | def transform_dicts(self, column_stats_names, column_stats_rows): 42 | return [{k: v for k, v in zip(column_stats_names, row)} for row in column_stats_rows] 43 | -------------------------------------------------------------------------------- 
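For illustration, a minimal usage sketch (not part of the repository) of how the DatabaseConnection interface defined in database.py above is obtained and used via create_db_conn; the credentials, database name, SQL string, and argument values are placeholders, while the function and method signatures follow the declarations in database.py and load_database.py below.

from cross_db_benchmark.benchmark_tools.database import DatabaseSystem
from cross_db_benchmark.benchmark_tools.load_database import create_db_conn

# Placeholder connection arguments; real values depend on the local Postgres setup.
conn_args = dict(user="postgres", password="<password>", host="localhost")

# create_db_conn dispatches on the DatabaseSystem enum and, for POSTGRES,
# returns a PostgresDatabaseConnection (see load_database.py).
db_conn = create_db_conn(DatabaseSystem.POSTGRES, "imdb", conn_args, dict())

# The methods below are declared on the abstract DatabaseConnection base class.
db_conn.set_statement_timeout(30)  # timeout in seconds
stats = db_conn.run_query_collect_statistics(
    sql="SELECT count(*) FROM title;",  # placeholder query
    repetitions=3, prefix="", hint_validation=False,
    include_hint_notices=False, explain_only=False)
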
/src/cross_db_benchmark/benchmark_tools/drop_db.py: -------------------------------------------------------------------------------- 1 | from cross_db_benchmark.benchmark_tools.load_database import create_db_conn 2 | 3 | 4 | def drop_db(database, db_name, database_conn_args, database_kwarg_dict): 5 | db_conn = create_db_conn(database, db_name, database_conn_args, database_kwarg_dict) 6 | db_conn.drop() 7 | -------------------------------------------------------------------------------- /src/cross_db_benchmark/benchmark_tools/generate_column_stats.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | 4 | import numpy as np 5 | import pandas as pd 6 | 7 | from cross_db_benchmark.benchmark_tools.column_types import Datatype 8 | from cross_db_benchmark.benchmark_tools.utils import load_schema_json 9 | 10 | 11 | class CustomEncoder(json.JSONEncoder): 12 | def default(self, obj): 13 | if isinstance(obj, np.integer): 14 | return int(obj) 15 | elif isinstance(obj, np.floating): 16 | return float(obj) 17 | elif isinstance(obj, np.ndarray): 18 | return obj.tolist() 19 | elif isinstance(obj, Datatype): 20 | return str(obj) 21 | else: 22 | return super(CustomEncoder, self).default(obj) 23 | 24 | 25 | def column_stats(column, categorical_threshold=10000): 26 | """ 27 | Default method for encoding the datasets 28 | """ 29 | nan_ratio = sum(column.isna()) / len(column) 30 | stats = dict(nan_ratio=nan_ratio) 31 | if column.dtype == object: 32 | if len(column.unique()) > categorical_threshold: 33 | stats.update(dict(datatype=Datatype.MISC)) 34 | 35 | else: 36 | vals_sorted_by_occurence = list(column.value_counts().index) 37 | stats.update(dict( 38 | datatype=Datatype.CATEGORICAL, 39 | unique_vals=vals_sorted_by_occurence, 40 | num_unique=len(column.unique()) 41 | )) 42 | 43 | else: 44 | 45 | percentiles = list(column.quantile(q=[0.1 * i for i in range(11)])) 46 | 47 | stats.update(dict( 48 | max=column.max(), 49 | min=column.min(), 50 | mean=column.mean(), 51 | num_unique=len(column.unique()), 52 | percentiles=percentiles, 53 | )) 54 | 55 | if column.dtype == int: 56 | stats.update(dict(datatype=Datatype.INT)) 57 | 58 | else: 59 | stats.update(dict(datatype=Datatype.FLOAT)) 60 | 61 | return stats 62 | 63 | 64 | def generate_column_statistics(data_dir, dataset, force=True): 65 | # read the schema file 66 | column_stats_path = os.path.join('cross_db_benchmark/datasets/', dataset, 'column_statistics.json') 67 | if os.path.exists(column_stats_path) and not force: 68 | print("Column stats already created") 69 | return 70 | 71 | schema = load_schema_json(dataset) 72 | 73 | # read individual table csvs and derive statistics 74 | joint_column_stats = dict() 75 | for t in schema.tables: 76 | 77 | column_stats_table = dict() 78 | table_dir = os.path.join(data_dir, f'{t}.csv') 79 | assert os.path.exists(data_dir), f"Could not find table csv {table_dir}" 80 | print(f"Generating statistics for {t}") 81 | 82 | df_table = pd.read_csv(table_dir, **vars(schema.csv_kwargs)) 83 | 84 | for column in df_table.columns: 85 | column_stats_table[column] = column_stats(df_table[column]) 86 | 87 | joint_column_stats[t] = column_stats_table 88 | 89 | # save to json 90 | with open(column_stats_path, 'w') as outfile: 91 | # workaround for numpy and other custom datatypes 92 | json.dump(joint_column_stats, outfile, cls=CustomEncoder) 93 | -------------------------------------------------------------------------------- 
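For illustration, a small sketch (not part of the repository) of what column_stats() from generate_column_stats.py above returns for toy pandas columns; the toy data and variable names are made up, and CustomEncoder is used as in generate_column_statistics() to serialize numpy and Datatype values.

import json
import pandas as pd
from cross_db_benchmark.benchmark_tools.column_types import Datatype
from cross_db_benchmark.benchmark_tools.generate_column_stats import CustomEncoder, column_stats

# Numeric column: takes the non-object branch and records min/max/mean,
# num_unique and the 0%..100% percentiles.
toy_int = pd.Series([1, 2, 2, 5, 10])
int_stats = column_stats(toy_int)
print(json.dumps(int_stats, cls=CustomEncoder))  # numpy/Datatype values handled by CustomEncoder

# Object column with few distinct values (below categorical_threshold):
# marked CATEGORICAL, with unique values sorted by occurrence.
toy_cat = pd.Series(["a", "b", "a", "c"])
assert column_stats(toy_cat)["datatype"] == Datatype.CATEGORICAL
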
/src/cross_db_benchmark/benchmark_tools/generate_string_statistics.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import json 3 | import os 4 | 5 | import pandas as pd 6 | 7 | from cross_db_benchmark.benchmark_tools.column_types import Datatype 8 | from cross_db_benchmark.benchmark_tools.utils import load_schema_json, load_column_statistics 9 | 10 | 11 | def generate_string_stats(data_dir, dataset, force=True, max_sample_vals=100000, min_str_occ=0.01, 12 | verbose=False): 13 | # read the schema file 14 | string_stats_path = os.path.join('cross_db_benchmark/datasets/', dataset, 'string_statistics.json') 15 | if os.path.exists(string_stats_path) and not force: 16 | print("String stats already created") 17 | return 18 | 19 | schema = load_schema_json(dataset) 20 | column_stats = load_column_statistics(dataset) 21 | 22 | cols_with_freq_words = 0 23 | string_stats = dict() 24 | for table, cols in vars(column_stats).items(): 25 | 26 | string_stats[table] = dict() 27 | table_dir = os.path.join(data_dir, f'{table}.csv') 28 | assert os.path.exists(data_dir), f"Could not find table csv {table_dir}" 29 | if verbose: 30 | print(f"Generating string statistics for {table}") 31 | 32 | df_table = pd.read_csv(table_dir, nrows=max_sample_vals, **vars(schema.csv_kwargs)) 33 | 34 | for c, col_stats in vars(cols).items(): 35 | if col_stats.datatype in {str(Datatype.CATEGORICAL), str(Datatype.MISC)}: 36 | col_vals = df_table[c] 37 | # do not consider too many values 38 | col_vals = col_vals[:max_sample_vals] 39 | len_strs = len(col_vals) 40 | 41 | # check how often a word occurs 42 | word_vals = collections.defaultdict(int) 43 | try: 44 | split_col_vals = col_vals.str.split(' ') 45 | except: 46 | continue 47 | 48 | for scol_vals in split_col_vals: 49 | if not isinstance(scol_vals, list): 50 | continue 51 | for v in scol_vals: 52 | if not isinstance(v, str): 53 | continue 54 | word_vals[v] += 1 55 | 56 | # how often should a word appear 57 | min_expected_occ = max(int(len_strs * min_str_occ), 1) 58 | 59 | freq_str_words = list() 60 | for val, occ in word_vals.items(): 61 | if occ > min_expected_occ: 62 | freq_str_words.append(val) 63 | 64 | if len(freq_str_words) > 0: 65 | if verbose: 66 | print(f"Found {len(freq_str_words)} frequent words for {c} " 67 | f"(expected {min_expected_occ}/{len_strs})") 68 | 69 | cols_with_freq_words += 1 70 | string_stats[table][c] = dict(freq_str_words=freq_str_words) 71 | 72 | # save to json 73 | with open(string_stats_path, 'w') as outfile: 74 | print(f"Found {cols_with_freq_words} string-queryable columns for dataset {dataset}") 75 | # workaround for numpy and other custom datatypes 76 | json.dump(string_stats, outfile) 77 | -------------------------------------------------------------------------------- /src/cross_db_benchmark/benchmark_tools/get_table_lengths.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | 4 | from cross_db_benchmark.benchmark_tools.database import DatabaseSystem 5 | from cross_db_benchmark.benchmark_tools.load_database import create_db_conn 6 | 7 | def get_table_rows(db_name): 8 | db_conn = create_db_conn(database=DatabaseSystem.POSTGRES, 9 | db_name=db_name, 10 | database_conn_args=dict(user="postgres", password="bM2YGRAX*bG_QAilUid§2iD", host="localhost"), 11 | database_kwarg_dict=dict()) 12 | 13 | sql = "SELECT table_name FROM information_schema.tables WHERE table_schema = 'public'" 14 | result = 
db_conn.get_result(sql, db_created=True) 15 | 16 | table_rows = {} 17 | for table in result: 18 | table_name = table[0] 19 | row_count = db_conn.get_result(f"SELECT COUNT(*) FROM {table_name}", db_created=True) 20 | if row_count != 0: 21 | table_rows[table_name] = row_count[0][0] 22 | with open(f'{db_name}.json', "w") as outfile: 23 | json.dump(table_rows, outfile) 24 | 25 | def main(): 26 | parser = argparse.ArgumentParser() 27 | parser.add_argument("--db_name", default=None) 28 | args = parser.parse_args() 29 | # Update these values with your PostgreSQL connection information 30 | get_table_rows(args.db_name) 31 | 32 | """ 33 | try: 34 | table_rows = (db_conn) 35 | for table, rows in table_rows.items(): 36 | print(f"Table: {table}, Rows: {rows}") 37 | except psycopg2.Error as e: 38 | print("Error connecting to PostgreSQL:", e) 39 | finally: 40 | if connection: 41 | connection.close() 42 | """ 43 | if __name__ == "__main__": 44 | main() -------------------------------------------------------------------------------- /src/cross_db_benchmark/benchmark_tools/join_conditions.py: -------------------------------------------------------------------------------- 1 | from cross_db_benchmark.benchmark_tools.load_database import create_db_conn 2 | from cross_db_benchmark.benchmark_tools.utils import load_schema_json 3 | 4 | 5 | def check_schema_graph_recursively(table, visited_tables, visited_relationships, schema): 6 | if table in visited_tables: 7 | raise NotImplementedError("Schema is cyclic") 8 | visited_tables.add(table) 9 | 10 | for r_id, r in enumerate(schema.relationships): 11 | if r_id in visited_relationships: 12 | continue 13 | 14 | table_left, _, table_right, _ = r 15 | if table_left == table or table_right == table: 16 | visited_relationships.add(r_id) 17 | if table_left == table: 18 | check_schema_graph_recursively(table_right, visited_tables, visited_relationships, schema) 19 | elif table_right == table: 20 | check_schema_graph_recursively(table_left, visited_tables, visited_relationships, schema) 21 | 22 | 23 | def check_join_conditions(dataset, database, db_name, database_conn_args, database_kwarg_dict): 24 | db_conn = create_db_conn(database, db_name, database_conn_args, database_kwarg_dict) 25 | 26 | # check if tables are a connected acyclic graph 27 | schema = load_schema_json(dataset) 28 | visited_tables = set() 29 | visited_relationships = set() 30 | check_schema_graph_recursively(schema.tables[0], visited_tables, visited_relationships, schema) 31 | assert len(visited_tables) == len(schema.tables), "Schema graph is not connected" 32 | print("Schema graph is acyclic and connected") 33 | 34 | print("Checking join conditions...") 35 | db_conn.test_join_conditions(dataset) 36 | -------------------------------------------------------------------------------- /src/cross_db_benchmark/benchmark_tools/load_database.py: -------------------------------------------------------------------------------- 1 | from cross_db_benchmark.benchmark_tools.database import DatabaseSystem 2 | from cross_db_benchmark.benchmark_tools.postgres.database_connection import PostgresDatabaseConnection 3 | 4 | 5 | def create_db_conn(database, db_name, database_conn_args, database_kwarg_dict): 6 | if database == DatabaseSystem.POSTGRES: 7 | return PostgresDatabaseConnection(db_name=db_name, database_kwargs=database_conn_args, **database_kwarg_dict) 8 | else: 9 | raise NotImplementedError(f"Database {database} not yet supported.") 10 | 11 | 12 | def load_database(data_dir, dataset, database, db_name, 
database_conn_args, database_kwarg_dict, force=False): 13 | db_conn = create_db_conn(database, db_name, database_conn_args, database_kwarg_dict) 14 | db_conn.load_database(dataset, data_dir, force=force) 15 | -------------------------------------------------------------------------------- /src/cross_db_benchmark/benchmark_tools/parse_run.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | 4 | from cross_db_benchmark.benchmark_tools.postgres.combine_plans import combine_traces 5 | from cross_db_benchmark.benchmark_tools.postgres.parse_plan import parse_plans 6 | from cross_db_benchmark.benchmark_tools.utils import load_json 7 | 8 | 9 | def dumper(obj): 10 | try: 11 | return obj.toJSON() 12 | except: 13 | return obj.__dict__ 14 | 15 | 16 | def parse_run(source_paths, target_path, database, min_query_ms=100, max_query_ms=30000, 17 | parse_baseline=False, cap_queries=None, parse_join_conds=False, include_zero_card=False, 18 | explain_only=False): 19 | os.makedirs(os.path.dirname(target_path), exist_ok=True) 20 | 21 | parse_func = parse_plans 22 | comb_func = combine_traces 23 | 24 | if not isinstance(source_paths, list): 25 | source_paths = [source_paths] 26 | 27 | assert all([os.path.exists(p) for p in source_paths]) 28 | run_stats = [load_json(p) for p in source_paths] 29 | run_stats = comb_func(run_stats) 30 | 31 | parsed_runs, stats = parse_func(run_stats, min_runtime=min_query_ms, max_runtime=max_query_ms, 32 | parse_baseline=parse_baseline, cap_queries=cap_queries, 33 | parse_join_conds=parse_join_conds, 34 | include_zero_card=include_zero_card, explain_only=explain_only) 35 | 36 | with open(target_path, 'w') as outfile: 37 | json.dump(parsed_runs, outfile, default=dumper) 38 | return len(parsed_runs['parsed_plans']), stats 39 | -------------------------------------------------------------------------------- /src/cross_db_benchmark/benchmark_tools/postgres/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DataManagementLab/lcm-eval/8ed11d4c47bae2cb7f0740f566170f3e736e8471/src/cross_db_benchmark/benchmark_tools/postgres/__init__.py -------------------------------------------------------------------------------- /src/cross_db_benchmark/benchmark_tools/postgres/check_valid.py: -------------------------------------------------------------------------------- 1 | from cross_db_benchmark.benchmark_tools.database import ExecutionMode 2 | from cross_db_benchmark.benchmark_tools.postgres.json_plan import operator_tree_from_json 3 | from cross_db_benchmark.benchmark_tools.postgres.parse_plan import parse_raw_plan 4 | import copy 5 | import traceback 6 | 7 | 8 | def check_valid(mode: ExecutionMode, curr_statistics: dict, min_runtime: int = 100, verbose=True) -> bool: 9 | # Timeouts are also a valid signal in learning 10 | if 'timeout' in curr_statistics and curr_statistics['timeout']: 11 | if verbose: 12 | print("Invalid since it ran into a timeout") 13 | return False 14 | 15 | try: 16 | analyze_plans = curr_statistics['analyze_plans'] 17 | 18 | if analyze_plans is None or len(analyze_plans) == 0: 19 | if verbose: 20 | print("Invalid because no analyze plans are available") 21 | return False 22 | 23 | if mode == ExecutionMode.JSON_OUTPUT: 24 | analyze_plan = copy.deepcopy(analyze_plans[0]) 25 | analyze_plan = operator_tree_from_json(analyze_plan) 26 | runtime = analyze_plan.runtime * 1000 27 | cardinality = analyze_plan.min_cardinality() 28 | 29 | else: 30 
| analyze_plan = analyze_plans[0] 31 | analyze_plan, runtime, _ = parse_raw_plan(analyze_plan, analyze=True, parse=True) 32 | analyze_plan.parse_lines_recursively() 33 | cardinality = analyze_plan.min_card() 34 | 35 | if cardinality == 0: 36 | if verbose: 37 | print("Invalid because of zero cardinality") 38 | return False 39 | 40 | if runtime < min_runtime: 41 | if verbose: 42 | print("Invalid because of too short runtime") 43 | return False 44 | 45 | return True 46 | except Exception as e: 47 | if verbose: 48 | print("Invalid due to error" + traceback.format_exc()) 49 | return False 50 | -------------------------------------------------------------------------------- /src/cross_db_benchmark/benchmark_tools/postgres/combine_plans.py: -------------------------------------------------------------------------------- 1 | def combine_traces(runs): 2 | start_plan = runs[0] 3 | for p in runs[1:]: 4 | start_plan.query_list += p.query_list 5 | start_plan.total_time_secs += p.total_time_secs 6 | 7 | return start_plan 8 | -------------------------------------------------------------------------------- /src/cross_db_benchmark/benchmark_tools/postgres/compare_plan.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from tqdm import tqdm 3 | 4 | from cross_db_benchmark.benchmark_tools.postgres.parse_plan import parse_raw_plan 5 | 6 | 7 | def compare_plans(run_stats, alt_run_stats, min_runtime=100): 8 | # parse individual queries 9 | sql_q_id = dict() 10 | for i, q in enumerate(run_stats.query_list): 11 | sql_q_id[q.sql.strip()] = i 12 | 13 | q_errors = [] 14 | for q2 in tqdm(alt_run_stats.query_list): 15 | q_id = sql_q_id.get(q2.sql.strip()) 16 | if q_id is None: 17 | continue 18 | 19 | q = run_stats.query_list[q_id] 20 | 21 | if q.analyze_plans is None or q2.analyze_plans is None: 22 | continue 23 | 24 | if len(q.analyze_plans) == 0 or len(q2.analyze_plans) == 0: 25 | continue 26 | 27 | assert q.sql == q2.sql 28 | 29 | # parse the plan as a tree 30 | analyze_plan, ex_time, _ = parse_raw_plan(q.analyze_plans[0], analyze=True, parse=True) 31 | analyze_plan2, ex_time2, _ = parse_raw_plan(q2.analyze_plans[0], analyze=True, parse=True) 32 | analyze_plan.parse_lines_recursively() 33 | analyze_plan2.parse_lines_recursively() 34 | 35 | if analyze_plan.min_card() == 0: 36 | continue 37 | 38 | if ex_time < min_runtime: 39 | continue 40 | 41 | q_error = max(ex_time2 / ex_time, ex_time / ex_time2) 42 | q_errors.append(q_error) 43 | 44 | # statistics in seconds 45 | q_errors = np.array(q_errors) 46 | print(f"Q-Error/Deviation of both runs: " 47 | f"\n\tmedian: {np.median(q_errors):.2f}" 48 | f"\n\tmax: {np.max(q_errors):.2f}" 49 | f"\n\tmean: {np.mean(q_errors):.2f}") 50 | print(f"Parsed {len(q_errors)} plans ") 51 | -------------------------------------------------------------------------------- /src/cross_db_benchmark/benchmark_tools/postgres/inflate_cardinality_errors.py: -------------------------------------------------------------------------------- 1 | def inflate_card_errors_pg(p, factor): 2 | # inflate the errors (both over- and underestimation) 3 | params = p.plan_parameters 4 | if params.act_card > params.est_card: 5 | q_err = params.act_card / params.est_card 6 | q_err = (q_err - 1) * factor + 1 7 | err_card = params.act_card / q_err 8 | 9 | else: 10 | q_err = params.est_card / params.act_card 11 | q_err = (q_err - 1) * factor + 1 12 | err_card = params.act_card * q_err 13 | 14 | if err_card < 1: 15 | err_card = 1 16 | err_card = 
float(int(err_card)) 17 | 18 | params.est_card = err_card 19 | params.act_card = err_card 20 | 21 | prod = 1 22 | for c in p.children: 23 | prod *= inflate_card_errors_pg(c, factor) 24 | 25 | params.est_children_card = prod 26 | params.act_children_card = prod 27 | 28 | return err_card 29 | -------------------------------------------------------------------------------- /src/cross_db_benchmark/benchmark_tools/postgres/utils.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | 4 | def plan_statistics(plan_op, tables=None, filter_columns=None, operators=None, skip_columns=False, conv_to_dict=False): 5 | if tables is None: 6 | tables = set() 7 | if operators is None: 8 | operators = set() 9 | if filter_columns is None: 10 | filter_columns = set() 11 | 12 | params = plan_op.plan_parameters 13 | 14 | if conv_to_dict: 15 | params = vars(params) 16 | 17 | if 'table' in params: 18 | tables.add(params['table']) 19 | if 'op_name' in params: 20 | operators.add(params['op_name']) 21 | if 'filter_columns' in params and not skip_columns: 22 | list_columns(params['filter_columns'], filter_columns) 23 | 24 | for c in plan_op.children: 25 | plan_statistics(c, tables=tables, filter_columns=filter_columns, operators=operators, skip_columns=skip_columns, 26 | conv_to_dict=conv_to_dict) 27 | 28 | return tables, filter_columns, operators 29 | 30 | 31 | def child_prod(p, feature_name, default=1): 32 | child_feat = [c.plan_parameters.get(feature_name) for c in p.children 33 | if c.plan_parameters.get(feature_name) is not None] 34 | if len(child_feat) == 0: 35 | return default 36 | return math.prod(child_feat) 37 | 38 | 39 | def list_columns(n, columns): 40 | columns.add((n.column, n.operator)) 41 | for c in n.children: 42 | list_columns(c, columns) 43 | -------------------------------------------------------------------------------- /src/cross_db_benchmark/benchmark_tools/run_workload.py: -------------------------------------------------------------------------------- 1 | from cross_db_benchmark.benchmark_tools.database import DatabaseSystem 2 | from cross_db_benchmark.benchmark_tools.postgres.run_workload import run_pg_workload 3 | 4 | 5 | def run_workload(workload_path, database, db_name, database_conn_args, database_kwarg_dict, target_path, run_kwargs, 6 | repetitions_per_query, timeout_sec, mode, hints=None, with_indexes=False, cap_workload=None, explain_only: bool = False, 7 | min_runtime=100): 8 | if database == DatabaseSystem.POSTGRES: 9 | run_pg_workload(workload_path, database, db_name, database_conn_args, database_kwarg_dict, target_path, 10 | run_kwargs, repetitions_per_query, timeout_sec, random_hints=hints, with_indexes=with_indexes, 11 | cap_workload=cap_workload, min_runtime=min_runtime, mode=mode, explain_only=explain_only) 12 | else: 13 | raise NotImplementedError 14 | -------------------------------------------------------------------------------- /src/cross_db_benchmark/benchmark_tools/utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from types import SimpleNamespace 4 | 5 | 6 | def load_schema_json(dataset): 7 | schema_path = os.path.join('cross_db_benchmark/datasets/', dataset, 'schema.json') 8 | assert os.path.exists(schema_path), f"Could not find schema.json ({schema_path})" 9 | return load_json(schema_path) 10 | 11 | 12 | def load_column_statistics(dataset, namespace=True): 13 | path = os.path.join('cross_db_benchmark/datasets/', dataset, 
'column_statistics.json') 14 | assert os.path.exists(path), f"Could not find file ({path})" 15 | return load_json(path, namespace=namespace) 16 | 17 | 18 | def load_string_statistics(dataset, namespace=True): 19 | path = os.path.join('cross_db_benchmark/datasets/', dataset, 'string_statistics.json') 20 | assert os.path.exists(path), f"Could not find file ({path})" 21 | return load_json(path, namespace=namespace) 22 | 23 | 24 | def load_json(path, namespace=True): 25 | with open(path) as json_file: 26 | if namespace: 27 | json_obj = json.load(json_file, object_hook=lambda d: SimpleNamespace(**d)) 28 | else: 29 | json_obj = json.load(json_file) 30 | return json_obj 31 | 32 | 33 | def load_schema_sql(dataset, sql_filename): 34 | sql_path = os.path.join('cross_db_benchmark/datasets/', dataset, 'schema_sql', sql_filename) 35 | assert os.path.exists(sql_path), f"Could not find schema.sql ({sql_path})" 36 | with open(sql_path, 'r') as file: 37 | data = file.read().replace('\n', '') 38 | return data 39 | -------------------------------------------------------------------------------- /src/cross_db_benchmark/datasets/accidents/schema.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "accidents", 3 | "csv_kwargs": { 4 | "sep": "\t" 5 | }, 6 | "db_load_kwargs": { 7 | "postgres": "DELIMITER '\t' QUOTE '\"' ESCAPE '\\' NULL 'NULL' CSV HEADER;" 8 | }, 9 | "tables": [ 10 | "upravna_enota", 11 | "oseba", 12 | "nesreca" 13 | ], 14 | "relationships": [ 15 | [ 16 | "nesreca", 17 | [ 18 | "upravna_enota" 19 | ], 20 | "upravna_enota", 21 | [ 22 | "id_upravna_enota" 23 | ] 24 | ], 25 | [ 26 | "oseba", 27 | [ 28 | "upravna_enota" 29 | ], 30 | "upravna_enota", 31 | [ 32 | "id_upravna_enota" 33 | ] 34 | ] 35 | ] 36 | } -------------------------------------------------------------------------------- /src/cross_db_benchmark/datasets/accidents/schema_sql/postgres.sql: -------------------------------------------------------------------------------- 1 | 2 | 3 | DROP TABLE IF EXISTS "nesreca"; 4 | 5 | CREATE TABLE "nesreca" ( 6 | "id_nesreca" char(6) NOT NULL, 7 | "klas_nesreca" char(1) NOT NULL, 8 | "upravna_enota" char(4) NOT NULL, 9 | "cas_nesreca" varchar(255) NOT NULL, 10 | "naselje_ali_izven" char(1) NOT NULL, 11 | "kategorija_cesta" char(1) DEFAULT NULL, 12 | "oznaka_cesta_ali_naselje" char(5) NOT NULL, 13 | "tekst_cesta_ali_naselje" varchar(25) NOT NULL, 14 | "oznaka_odsek_ali_ulica" char(5) NOT NULL, 15 | "tekst_odsek_ali_ulica" varchar(25) NOT NULL, 16 | "stacionazna_ali_hisna_st" varchar(9) DEFAULT NULL, 17 | "opis_prizorisce" char(1) NOT NULL, 18 | "vzrok_nesreca" char(2) NOT NULL, 19 | "tip_nesreca" char(2) NOT NULL, 20 | "vreme_nesreca" char(1) NOT NULL, 21 | "stanje_promet" char(1) NOT NULL, 22 | "stanje_vozisce" char(2) NOT NULL, 23 | "stanje_povrsina_vozisce" char(2) NOT NULL, 24 | "x" integer DEFAULT NULL, 25 | "y" integer DEFAULT NULL, 26 | "x_wgs84" double precision DEFAULT NULL, 27 | "y_wgs84" double precision DEFAULT NULL, 28 | PRIMARY KEY ("id_nesreca") 29 | ) ; 30 | 31 | DROP TABLE IF EXISTS "oseba"; 32 | 33 | CREATE TABLE "oseba" ( 34 | "id_nesreca" char(6) NOT NULL, 35 | "povzrocitelj_ali_udelezenec" char(1) DEFAULT NULL, 36 | "starost" integer DEFAULT NULL, 37 | "spol" char(1) NOT NULL, 38 | "upravna_enota" char(4) NOT NULL, 39 | "drzavljanstvo" char(3) NOT NULL, 40 | "poskodba" char(1) DEFAULT NULL, 41 | "vrsta_udelezenca" char(2) DEFAULT NULL, 42 | "varnostni_pas_ali_celada" char(1) DEFAULT NULL, 43 | "vozniski_staz_LL" integer DEFAULT 
NULL, 44 | "vozniski_staz_MM" integer DEFAULT NULL, 45 | "alkotest" decimal(3,2) DEFAULT NULL, 46 | "strokovni_pregled" decimal(3,2) DEFAULT NULL, 47 | "starost_d" char(1) DEFAULT NULL, 48 | "vozniski_staz_d" char(1) NOT NULL, 49 | "alkotest_d" char(1) NOT NULL, 50 | "strokovni_pregled_d" char(1) NOT NULL 51 | ) ; 52 | 53 | DROP TABLE IF EXISTS "upravna_enota"; 54 | 55 | CREATE TABLE "upravna_enota" ( 56 | "id_upravna_enota" char(4) NOT NULL, 57 | "ime_upravna_enota" varchar(255) NOT NULL, 58 | "st_prebivalcev" integer DEFAULT NULL, 59 | "povrsina" integer DEFAULT NULL, 60 | PRIMARY KEY ("id_upravna_enota") 61 | ) ; 62 | 63 | -------------------------------------------------------------------------------- /src/cross_db_benchmark/datasets/accidents/string_statistics.json: -------------------------------------------------------------------------------- 1 | {"upravna_enota": {"ime_upravna_enota": {"freq_str_words": ["Bistrica", "Ljubljana", "ob", "pri"]}}, "oseba": {"povzrocitelj_ali_udelezenec": {"freq_str_words": ["D", "N"]}, "poskodba": {"freq_str_words": ["L", "P", "H", "B"]}, "vrsta_udelezenca": {"freq_str_words": ["TV", "OA", "PT", "AV", "PE", "KO", "KM"]}, "varnostni_pas_ali_celada": {"freq_str_words": ["N", "0", "D"]}, "starost_d": {"freq_str_words": ["D", "C", "B", "E", "F", "G", "H"]}, "vozniski_staz_d": {"freq_str_words": ["B", "A", "D", "E", "C", "N", "F"]}, "alkotest_d": {"freq_str_words": ["B", "A", "N", "C", "D", "E"]}, "strokovni_pregled_d": {"freq_str_words": ["A", "N"]}}, "nesreca": {"klas_nesreca": {"freq_str_words": ["H", "B", "L", "P"]}, "naselje_ali_izven": {"freq_str_words": ["D", "N"]}, "kategorija_cesta": {"freq_str_words": ["V", "L", "M", "N", "R", "A"]}, "oznaka_cesta_ali_naselje": {"freq_str_words": ["00003", "25001", "00001", "00010", "00A10", "17042", "010-1", "010-8", "64033", "000A1", "03011", "18035"]}, "tekst_cesta_ali_naselje": {"freq_str_words": ["MEJA", "A-VI\u010c-ORMO\u017d-MEJA", "RH", "-", "LJUBLJANA", "NOVA", "GORICA", "KORENSKO", "SEDLO", "BREGANA", "\u0160ENTILJ", "DEKANI", "KOPER", "NA", "PO\u010cEHOVA", "LENDAVA", "DRAVOGRAD-DOBOVEC-MEJA", "MARIBOR", "PREDOR", "KARAVANKE-BREGANA", "CELJE", "VAS", "PRI", "KRANJ", "OBRE\u017dJE"]}, "oznaka_odsek_ali_ulica": {"freq_str_words": ["00000"]}, "tekst_odsek_ali_ulica": {"freq_str_words": ["NI", "ULIC", "ODSEKOV", "ULICA", "CESTA", "TRG", "NA", "VAS"]}, "opis_prizorisce": {"freq_str_words": ["C", "R", "P"]}, "vzrok_nesreca": {"freq_str_words": ["PR", "HI", "SV", "PD", "VR", "PV", "OS"]}, "tip_nesreca": {"freq_str_words": ["BT", "\u00c8T", "TO", "OS", "NT", "OP", "TV", "PR", "PP"]}, "vreme_nesreca": {"freq_str_words": ["O", "J", "M", "D", "S"]}, "stanje_promet": {"freq_str_words": ["R", "N", "G", "E"]}, "stanje_vozisce": {"freq_str_words": ["MO", "PN", "SP", "SU", "SL", "SN"]}, "stanje_povrsina_vozisce": {"freq_str_words": ["A", "M"]}}} -------------------------------------------------------------------------------- /src/cross_db_benchmark/datasets/airline/dataset_documentation/README.md: -------------------------------------------------------------------------------- 1 | # Dataset Documentation 2 | 3 | Initially, it was taken from a relational fit dataset. 4 | 5 | https://relational.fit.cvut.cz/dataset/Airline 6 | 7 | However, there were only 400K flights included. Hence, we decided to download more flights from this webpage. 8 | 9 | https://www.transtats.bts.gov/DL_SelectFields.asp?Table_ID=236&DB_Short_Name=On-Time 10 | 11 | We extracted the csv files 12 | 13 | find . -name 'On_*.csv' -exec cp {} . 
\; 14 | 15 | ...and merged the csv files using awk 'FNR > 1' *.csv > merged.csv and then used the script.py in this folder. -------------------------------------------------------------------------------- /src/cross_db_benchmark/datasets/baseball/table_lengths.json: -------------------------------------------------------------------------------- 1 | { 2 | "allstarfull": 48310, 3 | "appearances": 923060, 4 | "awardsmanagers": 1560, 5 | "awardsplayers": 57950, 6 | "awardssharemanagers": 3720, 7 | "awardsshareplayers": 62890, 8 | "batting": 923530, 9 | "battingpost": 97980, 10 | "els_teamnames": 3140, 11 | "fielding": 1379750, 12 | "fieldingof": 120270, 13 | "fieldingpost": 103460, 14 | "halloffame": 38830, 15 | "managers": 32920, 16 | "managershalf": 890, 17 | "pitching": 393610, 18 | "pitchingpost": 41970, 19 | "players": 164070, 20 | "salaries": 231110, 21 | "schools": 7490, 22 | "schoolsplayers": 56790, 23 | "seriespost": 2720, 24 | "teams": 27150, 25 | "teamsfranchises": 1200, 26 | "teamshalf": 520 27 | } -------------------------------------------------------------------------------- /src/cross_db_benchmark/datasets/basketball/schema.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "basketball", 3 | "csv_kwargs": { 4 | "sep": "\t" 5 | }, 6 | "db_load_kwargs": { 7 | "postgres": "DELIMITER '\t' QUOTE '\"' ESCAPE '\\' NULL 'NULL' CSV HEADER;" 8 | }, 9 | "tables": [ 10 | "players", 11 | "series_post", 12 | "awards_players", 13 | "draft", 14 | "awards_coaches", 15 | "player_allstar", 16 | "teams", 17 | "players_teams", 18 | "coaches" 19 | ], 20 | "relationships": [ 21 | [ 22 | "awards_coaches", 23 | [ 24 | "coachID", 25 | "year" 26 | ], 27 | "coaches", 28 | [ 29 | "coachID", 30 | "year" 31 | ] 32 | ], 33 | [ 34 | "awards_players", 35 | [ 36 | "playerID" 37 | ], 38 | "players", 39 | [ 40 | "playerID" 41 | ] 42 | ], 43 | [ 44 | "coaches", 45 | [ 46 | "tmID", 47 | "year" 48 | ], 49 | "teams", 50 | [ 51 | "tmID", 52 | "year" 53 | ] 54 | ], 55 | [ 56 | "draft", 57 | [ 58 | "tmID", 59 | "draftYear" 60 | ], 61 | "teams", 62 | [ 63 | "tmID", 64 | "year" 65 | ] 66 | ], 67 | [ 68 | "player_allstar", 69 | [ 70 | "playerID" 71 | ], 72 | "players", 73 | [ 74 | "playerID" 75 | ] 76 | ], 77 | [ 78 | "players_teams", 79 | [ 80 | "playerID" 81 | ], 82 | "players", 83 | [ 84 | "playerID" 85 | ] 86 | ], 87 | [ 88 | "players_teams", 89 | [ 90 | "tmID", 91 | "year" 92 | ], 93 | "teams", 94 | [ 95 | "tmID", 96 | "year" 97 | ] 98 | ], 99 | [ 100 | "series_post", 101 | [ 102 | "tmIDWinner", 103 | "year" 104 | ], 105 | "teams", 106 | [ 107 | "tmID", 108 | "year" 109 | ] 110 | ] 111 | ] 112 | } -------------------------------------------------------------------------------- /src/cross_db_benchmark/datasets/carcinogenesis/schema.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "carcinogenesis", 3 | "csv_kwargs": { 4 | "sep": "\t" 5 | }, 6 | "db_load_kwargs": { 7 | "postgres": "DELIMITER '\t' QUOTE '\"' ESCAPE '\\' NULL 'NULL' CSV HEADER;" 8 | }, 9 | "tables": [ 10 | "sbond_2", 11 | "canc", 12 | "atom", 13 | "sbond_7", 14 | "sbond_3", 15 | "sbond_1" 16 | ], 17 | "relationships": [ 18 | [ 19 | "atom", 20 | [ 21 | "drug" 22 | ], 23 | "canc", 24 | [ 25 | "drug_id" 26 | ] 27 | ], 28 | [ 29 | "sbond_1", 30 | [ 31 | "drug" 32 | ], 33 | "canc", 34 | [ 35 | "drug_id" 36 | ] 37 | ], 38 | [ 39 | "sbond_2", 40 | [ 41 | "drug" 42 | ], 43 | "canc", 44 | [ 45 | "drug_id" 46 | ] 47 | ], 48 | [ 49 | "sbond_3", 50 | [ 51 | 
"drug" 52 | ], 53 | "canc", 54 | [ 55 | "drug_id" 56 | ] 57 | ], 58 | [ 59 | "sbond_7", 60 | [ 61 | "drug" 62 | ], 63 | "canc", 64 | [ 65 | "drug_id" 66 | ] 67 | ] 68 | ] 69 | } -------------------------------------------------------------------------------- /src/cross_db_benchmark/datasets/carcinogenesis/schema_sql/postgres.sql: -------------------------------------------------------------------------------- 1 | 2 | 3 | DROP TABLE IF EXISTS "atom"; 4 | 5 | CREATE TABLE "atom" ( 6 | "atomid" char(13) , 7 | "drug" char(10) DEFAULT NULL, 8 | "atomtype" char(100) DEFAULT NULL, 9 | "charge" char(100) DEFAULT NULL, 10 | "name" char(2) DEFAULT NULL, 11 | PRIMARY KEY ("atomid") 12 | ) ; 13 | 14 | DROP TABLE IF EXISTS "canc"; 15 | 16 | CREATE TABLE "canc" ( 17 | "drug_id" char(10) , 18 | "class" char(1) DEFAULT NULL, 19 | PRIMARY KEY ("drug_id") 20 | ) ; 21 | 22 | DROP TABLE IF EXISTS "sbond_1"; 23 | 24 | CREATE TABLE "sbond_1" ( 25 | "id" integer , 26 | "drug" char(10) DEFAULT NULL, 27 | "atomid" char(100) DEFAULT NULL, 28 | "atomid_2" char(100) DEFAULT NULL, 29 | PRIMARY KEY ("id") 30 | ) ; 31 | 32 | DROP TABLE IF EXISTS "sbond_2"; 33 | 34 | CREATE TABLE "sbond_2" ( 35 | "id" integer , 36 | "drug" char(10) DEFAULT NULL, 37 | "atomid" char(100) DEFAULT NULL, 38 | "atomid_2" char(100) DEFAULT NULL, 39 | PRIMARY KEY ("id") 40 | ) ; 41 | 42 | DROP TABLE IF EXISTS "sbond_3"; 43 | 44 | CREATE TABLE "sbond_3" ( 45 | "id" integer , 46 | "drug" char(8) DEFAULT NULL, 47 | "atomid" char(100) DEFAULT NULL, 48 | "atomid_2" char(100) DEFAULT NULL, 49 | PRIMARY KEY ("id") 50 | ) ; 51 | 52 | DROP TABLE IF EXISTS "sbond_7"; 53 | 54 | CREATE TABLE "sbond_7" ( 55 | "id" integer , 56 | "drug" char(9) DEFAULT NULL, 57 | "atomid" char(100) DEFAULT NULL, 58 | "atomid_2" char(100) DEFAULT NULL, 59 | PRIMARY KEY ("id") 60 | ) ; 61 | 62 | -------------------------------------------------------------------------------- /src/cross_db_benchmark/datasets/carcinogenesis/string_statistics.json: -------------------------------------------------------------------------------- 1 | {"sbond_2": {"drug": {"freq_str_words": ["d117", "d183", "d212", "d221", "d225", "d260", "d27", "d281", "d29", "d318", "d322", "d324", "d329", "d330", "d76", "d77"]}}, "canc": {}, "atom": {"drug": {"freq_str_words": ["d107", "d186", "d231", "d246", "d297"]}, "charge": {"freq_str_words": ["a0=-0_1355 str: 22 | if self._source_dataset is None: 23 | return self.db_name 24 | return self._source_dataset 25 | 26 | @property 27 | def data_folder(self) -> str: 28 | if self._data_folder is None: 29 | return self.db_name 30 | return self._data_folder 31 | 32 | 33 | # datasets that can be downloaded from osf and should be unzipped 34 | source_dataset_list = [ 35 | # original datasets 36 | SourceDataset('airline'), 37 | SourceDataset('imdb'), 38 | SourceDataset('ssb'), 39 | SourceDataset('tpc_h'), 40 | SourceDataset('walmart'), 41 | SourceDataset('financial'), 42 | SourceDataset('basketball'), 43 | SourceDataset('accidents'), 44 | SourceDataset('movielens'), 45 | SourceDataset('baseball'), 46 | SourceDataset('hepatitis'), 47 | SourceDataset('tournament'), 48 | SourceDataset('genome'), 49 | SourceDataset('credit'), 50 | SourceDataset('employee'), 51 | SourceDataset('carcinogenesis'), 52 | SourceDataset('consumer'), 53 | SourceDataset('geneea'), 54 | SourceDataset('seznam'), 55 | SourceDataset('fhnk'), 56 | ] 57 | 58 | database_list = [ 59 | # unscaled 60 | Database('airline', max_no_joins=5), 61 | Database('imdb'), 62 | Database('ssb', max_no_joins=3), 63 | 
#Database('tpc_h', max_no_joins=5), 64 | Database('tpc_h_pk', max_no_joins=5), # The initial TPC-H dataset has no primary keys, which led to bad featurization 65 | Database('walmart', max_no_joins=2), 66 | # scaled batch 1 67 | Database('financial', _data_folder='scaled_financial', scale=4), 68 | Database('basketball', _data_folder='scaled_basketball', scale=200), 69 | Database('accidents', _data_folder='accidents', scale=1, contain_unicode=True), 70 | Database('movielens', _data_folder='scaled_movielens', scale=8), 71 | Database('baseball', _data_folder='scaled_baseball', scale=10), 72 | # scaled batch 2 73 | Database('hepatitis', _data_folder='scaled_hepatitis', scale=2000), 74 | Database('tournament', _data_folder='scaled_tournament', scale=50), 75 | Database('credit', _data_folder='scaled_credit', scale=5), 76 | Database('employee', _data_folder='scaled_employee', scale=3), 77 | Database('consumer', _data_folder='scaled_consumer', scale=6), 78 | Database('geneea', _data_folder='scaled_geneea', scale=23, contain_unicode=True), 79 | Database('genome', _data_folder='scaled_genome', scale=6), 80 | Database('carcinogenesis', _data_folder='scaled_carcinogenesis', scale=674), 81 | Database('seznam', _data_folder='scaled_seznam', scale=2), 82 | Database('fhnk', _data_folder='scaled_fhnk', scale=2) 83 | ] 84 | 85 | ext_database_list = database_list + [Database('imdb_full', _data_folder='imdb')] 86 | -------------------------------------------------------------------------------- /src/cross_db_benchmark/datasets/employee/schema.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "employee", 3 | "csv_kwargs": { 4 | "sep": "\t" 5 | }, 6 | "db_load_kwargs": { 7 | "postgres": "DELIMITER '\t' QUOTE '\"' ESCAPE '\\' NULL 'NULL' CSV HEADER;" 8 | }, 9 | "tables": [ 10 | "dept_emp", 11 | "dept_manager", 12 | "employees", 13 | "departments", 14 | "titles", 15 | "salaries" 16 | ], 17 | "relationships": [ 18 | [ 19 | "dept_emp", 20 | [ 21 | "emp_no" 22 | ], 23 | "employees", 24 | [ 25 | "emp_no" 26 | ] 27 | ], 28 | [ 29 | "dept_emp", 30 | [ 31 | "dept_no" 32 | ], 33 | "departments", 34 | [ 35 | "dept_no" 36 | ] 37 | ], 38 | [ 39 | "dept_manager", 40 | [ 41 | "emp_no" 42 | ], 43 | "employees", 44 | [ 45 | "emp_no" 46 | ] 47 | ], 48 | [ 49 | "salaries", 50 | [ 51 | "emp_no" 52 | ], 53 | "employees", 54 | [ 55 | "emp_no" 56 | ] 57 | ], 58 | [ 59 | "titles", 60 | [ 61 | "emp_no" 62 | ], 63 | "employees", 64 | [ 65 | "emp_no" 66 | ] 67 | ] 68 | ] 69 | } -------------------------------------------------------------------------------- /src/cross_db_benchmark/datasets/employee/schema_sql/postgres.sql: -------------------------------------------------------------------------------- 1 | 2 | 3 | DROP TABLE IF EXISTS "departments"; 4 | 5 | CREATE TABLE "departments" ( 6 | "dept_no" char(6) , 7 | "dept_name" varchar(40) , 8 | PRIMARY KEY ("dept_no") 9 | ) ; 10 | 11 | DROP TABLE IF EXISTS "dept_emp"; 12 | 13 | CREATE TABLE "dept_emp" ( 14 | "emp_no" integer , 15 | "dept_no" char(6) , 16 | "from_date" varchar(255) , 17 | "to_date" varchar(255) , 18 | PRIMARY KEY ("emp_no","dept_no") 19 | ) ; 20 | 21 | DROP TABLE IF EXISTS "dept_manager"; 22 | 23 | CREATE TABLE "dept_manager" ( 24 | "dept_no" char(6) , 25 | "emp_no" integer , 26 | "from_date" varchar(255) , 27 | "to_date" varchar(255) , 28 | PRIMARY KEY ("emp_no","dept_no") 29 | ) ; 30 | 31 | DROP TABLE IF EXISTS "employees"; 32 | 33 | CREATE TABLE "employees" ( 34 | "emp_no" integer , 35 | "birth_date" varchar(255) , 36 |
"first_name" varchar(14) , 37 | "last_name" varchar(16) , 38 | "gender" varchar(255) , 39 | "hire_date" varchar(255) , 40 | PRIMARY KEY ("emp_no") 41 | ) ; 42 | 43 | DROP TABLE IF EXISTS "salaries"; 44 | 45 | CREATE TABLE "salaries" ( 46 | "emp_no" integer , 47 | "salary" integer , 48 | "from_date" varchar(12) , 49 | "to_date" varchar(255) , 50 | PRIMARY KEY ("emp_no","from_date") 51 | ) ; 52 | 53 | DROP TABLE IF EXISTS "titles"; 54 | 55 | CREATE TABLE "titles" ( 56 | "emp_no" integer , 57 | "title" varchar(20) , 58 | "from_date" varchar(12) , 59 | "to_date" varchar(255) DEFAULT NULL, 60 | PRIMARY KEY ("emp_no","title","from_date") 61 | ) ; 62 | 63 | -------------------------------------------------------------------------------- /src/cross_db_benchmark/datasets/employee/string_statistics.json: -------------------------------------------------------------------------------- 1 | {"dept_emp": {"dept_no": {"freq_str_words": ["d005", "d007", "d004", "d003", "d008", "d006", "d009", "d001", "d002"]}, "to_date": {"freq_str_words": ["9999-01-01"]}}, "dept_manager": {"dept_no": {"freq_str_words": ["d001", "d002", "d003", "d004", "d005", "d006", "d007", "d008", "d009"]}, "from_date": {"freq_str_words": ["1985-01-01"]}, "to_date": {"freq_str_words": ["9999-01-01"]}}, "employees": {"gender": {"freq_str_words": ["M", "F"]}}, "departments": {}, "titles": {"title": {"freq_str_words": ["Senior", "Engineer", "Staff", "Assistant", "Technique", "Leader"]}, "to_date": {"freq_str_words": ["9999-01-01"]}}, "salaries": {"to_date": {"freq_str_words": ["9999-01-01"]}}} -------------------------------------------------------------------------------- /src/cross_db_benchmark/datasets/fhnk/schema.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "fhnk", 3 | "csv_kwargs": { 4 | "sep": "\t" 5 | }, 6 | "db_load_kwargs": { 7 | "postgres": "DELIMITER '\t' QUOTE '\"' ESCAPE '\\' NULL 'NULL' CSV HEADER;" 8 | }, 9 | "tables": [ 10 | "pripady", 11 | "zup", 12 | "vykony" 13 | ], 14 | "relationships": [ 15 | [ 16 | "vykony", 17 | [ 18 | "Identifikace_pripadu" 19 | ], 20 | "pripady", 21 | [ 22 | "Identifikace_pripadu" 23 | ] 24 | ], 25 | [ 26 | "zup", 27 | [ 28 | "Identifikace_pripadu" 29 | ], 30 | "pripady", 31 | [ 32 | "Identifikace_pripadu" 33 | ] 34 | ] 35 | ] 36 | } -------------------------------------------------------------------------------- /src/cross_db_benchmark/datasets/fhnk/schema_sql/postgres.sql: -------------------------------------------------------------------------------- 1 | 2 | 3 | DROP TABLE IF EXISTS "pripady"; 4 | 5 | CREATE TABLE "pripady" ( 6 | "Identifikace_pripadu" integer , 7 | "Identifikator_pacienta" integer , 8 | "Kod_zdravotni_pojistovny" integer , 9 | "Datum_prijeti" varchar(255) , 10 | "Datum_propusteni" varchar(255) , 11 | "Delka_hospitalizace" integer , 12 | "Vekovy_Interval_Pacienta" varchar(255) , 13 | "Pohlavi_pacienta" char(1) , 14 | "Zakladni_diagnoza" varchar(255) , 15 | "Seznam_vedlejsich_diagnoz" varchar(255) , 16 | "DRG_skupina" integer , 17 | "PSC" char(5) DEFAULT NULL, 18 | PRIMARY KEY ("Identifikace_pripadu") 19 | ) ; 20 | 21 | DROP TABLE IF EXISTS "vykony"; 22 | 23 | CREATE TABLE "vykony" ( 24 | "Identifikace_pripadu" integer , 25 | "Datum_provedeni_vykonu" varchar(12) , 26 | "Typ_polozky" integer , 27 | "Kod_polozky" integer , 28 | "Pocet" integer , 29 | "Body" integer , 30 | PRIMARY KEY ("Identifikace_pripadu","Datum_provedeni_vykonu","Kod_polozky") 31 | ) ; 32 | 33 | DROP TABLE IF EXISTS "zup"; 34 | 35 | CREATE TABLE "zup" 
( 36 | "Identifikace_pripadu" integer , 37 | "Datum_provedeni_vykonu" varchar(12) , 38 | "Typ_polozky" integer DEFAULT NULL, 39 | "Kod_polozky" integer , 40 | "Pocet" decimal(10,2) DEFAULT NULL, 41 | "Cena" decimal(10,2) DEFAULT NULL, 42 | PRIMARY KEY ("Identifikace_pripadu","Datum_provedeni_vykonu","Kod_polozky") 43 | ) ; 44 | 45 | -------------------------------------------------------------------------------- /src/cross_db_benchmark/datasets/fhnk/string_statistics.json: -------------------------------------------------------------------------------- 1 | {"pripady": {"Vekovy_Interval_Pacienta": {"freq_str_words": ["50-60", "60-70", "80+", "0-10", "30-40", "40-50", "20-30", "10-20", "70-80"]}, "Pohlavi_pacienta": {"freq_str_words": ["M", "F"]}, "Zakladni_diagnoza": {"freq_str_words": ["Z380", "O800", "Z511", "G473", "Z510", "I7020", "J352", "Z508"]}, "Seznam_vedlejsich_diagnoz": {"freq_str_words": ["", "I10", "E118", "O990", "O718", "E780", "Z763", "I252", "E119", "Z290", "N390", "B961", "Z966", "E669", "N40", "Z955", "B962", "Z511", "D695", "E039", "J459", "E785", "D630", "U822", "I481", "Z921", "E789", "I480", "I501", "E038", "E86", "I259", "E782", "T810", "D62", "Z501", "I258", "G819", "Z867", "I251", "E790"]}, "PSC": {"freq_str_words": ["50346", "50009", "50327", "50011", "50303", "50002", "50006", "50351", "50012", "50003", "54101", "50315", "55101", "50601", "50341", "54701", "50008", "51601", "50801", "50401", "53002", "51721", "54401"]}}, "zup": {}, "vykony": {"Datum_provedeni_vykonu": {"freq_str_words": ["2013-12-30", "2014-01-01", "2013-12-29", "2013-12-27", "2013-12-28", "2013-12-20", "2013-12-23", "2014-01-02", "2013-12-25", "2014-01-03", "2014-01-04", "2014-01-05", "2014-01-06", "2014-01-07", "2014-01-08", "2014-01-09", "2014-01-10", "2014-01-11", "2014-01-12", "2014-01-13", "2014-01-14", "2014-01-15", "2014-01-16", "2014-01-17", "2014-01-18", "2014-01-19", "2014-01-20", "2014-01-21", "2014-01-22", "2014-01-23"]}}} -------------------------------------------------------------------------------- /src/cross_db_benchmark/datasets/financial/schema.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "financial", 3 | "csv_kwargs": { 4 | "sep": "\t" 5 | }, 6 | "db_load_kwargs": { 7 | "postgres": "DELIMITER '\t' QUOTE '\"' ESCAPE '\\' NULL 'NULL' CSV HEADER;" 8 | }, 9 | "tables": [ 10 | "card", 11 | "account", 12 | "order", 13 | "trans", 14 | "loan", 15 | "client", 16 | "disp", 17 | "district" 18 | ], 19 | "relationships": [ 20 | [ 21 | "account", 22 | [ 23 | "district_id" 24 | ], 25 | "district", 26 | [ 27 | "district_id" 28 | ] 29 | ], 30 | [ 31 | "card", 32 | [ 33 | "disp_id" 34 | ], 35 | "disp", 36 | [ 37 | "disp_id" 38 | ] 39 | ], 40 | [ 41 | "client", 42 | [ 43 | "district_id" 44 | ], 45 | "district", 46 | [ 47 | "district_id" 48 | ] 49 | ], 50 | [ 51 | "disp", 52 | [ 53 | "client_id" 54 | ], 55 | "client", 56 | [ 57 | "client_id" 58 | ] 59 | ], 60 | [ 61 | "loan", 62 | [ 63 | "account_id" 64 | ], 65 | "account", 66 | [ 67 | "account_id" 68 | ] 69 | ], 70 | [ 71 | "order", 72 | [ 73 | "account_id" 74 | ], 75 | "account", 76 | [ 77 | "account_id" 78 | ] 79 | ], 80 | [ 81 | "trans", 82 | [ 83 | "account_id" 84 | ], 85 | "account", 86 | [ 87 | "account_id" 88 | ] 89 | ] 90 | ] 91 | } -------------------------------------------------------------------------------- /src/cross_db_benchmark/datasets/financial/schema_sql/postgres.sql: -------------------------------------------------------------------------------- 1 | 2 | 3 | DROP 
TABLE IF EXISTS "account"; 4 | 5 | CREATE TABLE "account" ( 6 | "account_id" integer DEFAULT 0, 7 | "district_id" integer DEFAULT 0, 8 | "frequency" varchar(18) , 9 | "date" varchar(255) , 10 | PRIMARY KEY ("account_id") 11 | ) ; 12 | 13 | DROP TABLE IF EXISTS "card"; 14 | 15 | CREATE TABLE "card" ( 16 | "card_id" integer DEFAULT 0, 17 | "disp_id" integer , 18 | "type" varchar(7) , 19 | "issued" varchar(255) , 20 | PRIMARY KEY ("card_id") 21 | ) ; 22 | 23 | DROP TABLE IF EXISTS "client"; 24 | 25 | CREATE TABLE "client" ( 26 | "client_id" integer , 27 | "gender" varchar(1) , 28 | "birth_date" varchar(255) , 29 | "district_id" integer , 30 | PRIMARY KEY ("client_id") 31 | ) ; 32 | 33 | DROP TABLE IF EXISTS "disp"; 34 | 35 | CREATE TABLE "disp" ( 36 | "disp_id" integer , 37 | "client_id" integer , 38 | "account_id" integer , 39 | "type" varchar(9) , 40 | PRIMARY KEY ("disp_id") 41 | ) ; 42 | 43 | DROP TABLE IF EXISTS "district"; 44 | 45 | CREATE TABLE "district" ( 46 | "district_id" integer DEFAULT 0, 47 | "A2" varchar(19) , 48 | "A3" varchar(15) , 49 | "A4" integer , 50 | "A5" integer , 51 | "A6" integer , 52 | "A7" integer , 53 | "A8" integer , 54 | "A9" integer , 55 | "A10" decimal(4,1) , 56 | "A11" integer , 57 | "A12" decimal(4,1) DEFAULT NULL, 58 | "A13" decimal(3,2) , 59 | "A14" integer , 60 | "A15" integer DEFAULT NULL, 61 | "A16" integer , 62 | PRIMARY KEY ("district_id") 63 | ) ; 64 | 65 | DROP TABLE IF EXISTS "loan"; 66 | 67 | CREATE TABLE "loan" ( 68 | "loan_id" integer DEFAULT 0, 69 | "account_id" integer , 70 | "date" varchar(255) , 71 | "amount" integer , 72 | "duration" integer , 73 | "payments" decimal(6,2) , 74 | "status" varchar(1) , 75 | PRIMARY KEY ("loan_id") 76 | ) ; 77 | 78 | DROP TABLE IF EXISTS "order"; 79 | 80 | CREATE TABLE "order" ( 81 | "order_id" integer DEFAULT 0, 82 | "account_id" integer , 83 | "bank_to" varchar(2) , 84 | "account_to" integer , 85 | "amount" decimal(6,1) , 86 | "k_symbol" varchar(8) , 87 | PRIMARY KEY ("order_id") 88 | ) ; 89 | 90 | DROP TABLE IF EXISTS "trans"; 91 | 92 | CREATE TABLE "trans" ( 93 | "trans_id" integer DEFAULT 0, 94 | "account_id" integer DEFAULT 0, 95 | "date" varchar(255) , 96 | "type" varchar(6) , 97 | "operation" varchar(14) DEFAULT NULL, 98 | "amount" integer , 99 | "balance" integer , 100 | "k_symbol" varchar(11) DEFAULT NULL, 101 | "bank" varchar(2) DEFAULT NULL, 102 | "account" integer DEFAULT NULL, 103 | PRIMARY KEY ("trans_id") 104 | ) ; 105 | 106 | -------------------------------------------------------------------------------- /src/cross_db_benchmark/datasets/financial/string_statistics.json: -------------------------------------------------------------------------------- 1 | {"card": {"type": {"freq_str_words": ["gold", "classic", "junior"]}}, "account": {"frequency": {"freq_str_words": ["POPLATEK", "MESICNE", "TYDNE", "PO", "OBRATU"]}}, "order": {"bank_to": {"freq_str_words": ["YZ", "ST", "QR", "WX", "CD", "AB", "UV", "GH", "IJ", "KL", "EF", "MN", "OP"]}, "k_symbol": {"freq_str_words": ["SIPO", "UVER", "POJISTNE", "LEASING"]}}, "trans": {"type": {"freq_str_words": ["PRIJEM", "VYDAJ", "VYBER"]}, "operation": {"freq_str_words": ["VKLAD", "PREVOD", "Z", "UCTU", "NA", "UCET", "VYBER", "KARTOU"]}, "k_symbol": {"freq_str_words": ["SIPO", "SLUZBY", "", "POJISTNE", "DUCHOD"]}, "bank": {"freq_str_words": ["AB", "YZ", "ST", "QR", "WX", "CD", "UV", "KL", "GH", "OP", "IJ", "EF", "MN"]}}, "loan": {"status": {"freq_str_words": ["A", "B", "D", "C"]}}, "client": {"gender": {"freq_str_words": ["F", "M"]}}, "disp": {"type": 
{"freq_str_words": ["OWNER", "DISPONENT"]}}, "district": {"A2": {"freq_str_words": ["Praha", "-", "Hradec", "Plzen", "mesto", "Usti", "nad", "Jicin", "Brno"]}, "A3": {"freq_str_words": ["central", "Bohemia", "south", "west", "north", "east", "Moravia"]}}} -------------------------------------------------------------------------------- /src/cross_db_benchmark/datasets/genome/schema.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "genome", 3 | "csv_kwargs": { 4 | "sep": "\t" 5 | }, 6 | "db_load_kwargs": { 7 | "postgres": "DELIMITER '\t' QUOTE '\"' ESCAPE '\\' NULL 'NULL' CSV HEADER;" 8 | }, 9 | "tables": [ 10 | "OBJ_CLASSES", 11 | "IMG_OBJ", 12 | "IMG_REL", 13 | "PRED_CLASSES", 14 | "ATT_CLASSES", 15 | "IMG_OBJ_ATT" 16 | ], 17 | "relationships": [ 18 | [ 19 | "IMG_OBJ", 20 | [ 21 | "OBJ_CLASS_ID" 22 | ], 23 | "OBJ_CLASSES", 24 | [ 25 | "OBJ_CLASS_ID" 26 | ] 27 | ], 28 | [ 29 | "IMG_OBJ_ATT", 30 | [ 31 | "ATT_CLASS_ID" 32 | ], 33 | "ATT_CLASSES", 34 | [ 35 | "ATT_CLASS_ID" 36 | ] 37 | ], 38 | [ 39 | "IMG_OBJ_ATT", 40 | [ 41 | "IMG_ID", 42 | "OBJ_SAMPLE_ID" 43 | ], 44 | "IMG_OBJ", 45 | [ 46 | "IMG_ID", 47 | "OBJ_SAMPLE_ID" 48 | ] 49 | ], 50 | [ 51 | "IMG_REL", 52 | [ 53 | "PRED_CLASS_ID" 54 | ], 55 | "PRED_CLASSES", 56 | [ 57 | "PRED_CLASS_ID" 58 | ] 59 | ], 60 | [ 61 | "IMG_REL", 62 | [ 63 | "IMG_ID", 64 | "OBJ1_SAMPLE_ID" 65 | ], 66 | "IMG_OBJ", 67 | [ 68 | "IMG_ID", 69 | "OBJ_SAMPLE_ID" 70 | ] 71 | ] 72 | ] 73 | } -------------------------------------------------------------------------------- /src/cross_db_benchmark/datasets/genome/schema_sql/postgres.sql: -------------------------------------------------------------------------------- 1 | 2 | 3 | DROP TABLE IF EXISTS "ATT_CLASSES"; 4 | 5 | CREATE TABLE "ATT_CLASSES" ( 6 | "ATT_CLASS_ID" integer DEFAULT 0, 7 | "ATT_CLASS" char(50) , 8 | PRIMARY KEY ("ATT_CLASS_ID") 9 | ) ; 10 | 11 | DROP TABLE IF EXISTS "IMG_OBJ"; 12 | 13 | CREATE TABLE "IMG_OBJ" ( 14 | "IMG_ID" integer DEFAULT 0, 15 | "OBJ_SAMPLE_ID" integer DEFAULT 0, 16 | "OBJ_CLASS_ID" integer DEFAULT NULL, 17 | "X" integer DEFAULT NULL, 18 | "Y" integer DEFAULT NULL, 19 | "W" integer DEFAULT NULL, 20 | "H" integer DEFAULT NULL, 21 | PRIMARY KEY ("IMG_ID","OBJ_SAMPLE_ID") 22 | ) ; 23 | 24 | DROP TABLE IF EXISTS "IMG_OBJ_ATT"; 25 | 26 | CREATE TABLE "IMG_OBJ_ATT" ( 27 | "IMG_ID" integer DEFAULT 0, 28 | "ATT_CLASS_ID" integer DEFAULT 0, 29 | "OBJ_SAMPLE_ID" integer DEFAULT 0, 30 | PRIMARY KEY ("IMG_ID","ATT_CLASS_ID","OBJ_SAMPLE_ID") 31 | ) ; 32 | 33 | DROP TABLE IF EXISTS "IMG_REL"; 34 | 35 | CREATE TABLE "IMG_REL" ( 36 | "IMG_ID" integer DEFAULT 0, 37 | "PRED_CLASS_ID" integer DEFAULT 0, 38 | "OBJ1_SAMPLE_ID" integer DEFAULT 0, 39 | "OBJ2_SAMPLE_ID" integer DEFAULT 0, 40 | PRIMARY KEY ("IMG_ID","PRED_CLASS_ID","OBJ1_SAMPLE_ID","OBJ2_SAMPLE_ID") 41 | ) ; 42 | 43 | DROP TABLE IF EXISTS "OBJ_CLASSES"; 44 | 45 | CREATE TABLE "OBJ_CLASSES" ( 46 | "OBJ_CLASS_ID" integer DEFAULT 0, 47 | "OBJ_CLASS" char(50) , 48 | PRIMARY KEY ("OBJ_CLASS_ID") 49 | ) ; 50 | 51 | DROP TABLE IF EXISTS "PRED_CLASSES"; 52 | 53 | CREATE TABLE "PRED_CLASSES" ( 54 | "PRED_CLASS_ID" integer DEFAULT 0, 55 | "PRED_CLASS" char(100) , 56 | PRIMARY KEY ("PRED_CLASS_ID") 57 | ) ; 58 | 59 | -------------------------------------------------------------------------------- /src/cross_db_benchmark/datasets/genome/string_statistics.json: -------------------------------------------------------------------------------- 1 | {"OBJ_CLASSES": {}, "IMG_OBJ": {}, "IMG_REL": {}, 
"PRED_CLASSES": {"PRED_CLASS": {"freq_str_words": ["playing", "on", "looking", "a", "to", "of", "are", "driving", "side", "in", "leaning", "against", "lying", "front", "hanging", "over", "top", "covered", "sitting", "standing", "near", "walking", "by", "behind", "next", "with", "riding", "parked", "holding", "growing", "laying", "has", "at", "flying", "from", "inside"]}}, "ATT_CLASSES": {"ATT_CLASS": {"freq_str_words": ["s", "blue", "red", "and", "down", "in", "white", "dark", "light"]}}, "IMG_OBJ_ATT": {}} -------------------------------------------------------------------------------- /src/cross_db_benchmark/datasets/hepatitis/schema.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "hepatitis", 3 | "csv_kwargs": { 4 | "sep": "\t" 5 | }, 6 | "db_load_kwargs": { 7 | "postgres": "DELIMITER '\t' QUOTE '\"' ESCAPE '\\' NULL 'NULL' CSV HEADER;" 8 | }, 9 | "tables": [ 10 | "rel11", 11 | "dispat", 12 | "rel13", 13 | "indis", 14 | "inf", 15 | "Bio", 16 | "rel12" 17 | ], 18 | "relationships": [ 19 | [ 20 | "rel11", 21 | [ 22 | "b_id" 23 | ], 24 | "Bio", 25 | [ 26 | "b_id" 27 | ] 28 | ], 29 | [ 30 | "rel11", 31 | [ 32 | "m_id" 33 | ], 34 | "dispat", 35 | [ 36 | "m_id" 37 | ] 38 | ], 39 | [ 40 | "rel12", 41 | [ 42 | "m_id" 43 | ], 44 | "dispat", 45 | [ 46 | "m_id" 47 | ] 48 | ], 49 | [ 50 | "rel12", 51 | [ 52 | "in_id" 53 | ], 54 | "indis", 55 | [ 56 | "in_id" 57 | ] 58 | ], 59 | [ 60 | "rel13", 61 | [ 62 | "m_id" 63 | ], 64 | "dispat", 65 | [ 66 | "m_id" 67 | ] 68 | ], 69 | [ 70 | "rel13", 71 | [ 72 | "a_id" 73 | ], 74 | "inf", 75 | [ 76 | "a_id" 77 | ] 78 | ] 79 | ] 80 | } -------------------------------------------------------------------------------- /src/cross_db_benchmark/datasets/hepatitis/schema_sql/postgres.sql: -------------------------------------------------------------------------------- 1 | 2 | 3 | DROP TABLE IF EXISTS "Bio"; 4 | 5 | CREATE TABLE "Bio" ( 6 | "fibros" varchar(45) , 7 | "activity" varchar(45) , 8 | "b_id" integer , 9 | PRIMARY KEY ("b_id") 10 | ) ; 11 | 12 | DROP TABLE IF EXISTS "dispat"; 13 | 14 | CREATE TABLE "dispat" ( 15 | "m_id" integer DEFAULT 0, 16 | "sex" varchar(45) DEFAULT NULL, 17 | "age" varchar(45) DEFAULT NULL, 18 | "Type" varchar(45) DEFAULT NULL, 19 | PRIMARY KEY ("m_id") 20 | ) ; 21 | 22 | DROP TABLE IF EXISTS "indis"; 23 | 24 | CREATE TABLE "indis" ( 25 | "got" varchar(10) DEFAULT NULL, 26 | "gpt" varchar(10) DEFAULT NULL, 27 | "alb" varchar(45) DEFAULT NULL, 28 | "tbil" varchar(45) DEFAULT NULL, 29 | "dbil" varchar(45) DEFAULT NULL, 30 | "che" varchar(45) DEFAULT NULL, 31 | "ttt" varchar(45) DEFAULT NULL, 32 | "ztt" varchar(45) DEFAULT NULL, 33 | "tcho" varchar(45) DEFAULT NULL, 34 | "tp" varchar(45) DEFAULT NULL, 35 | "in_id" integer , 36 | PRIMARY KEY ("in_id") 37 | ) ; 38 | 39 | DROP TABLE IF EXISTS "inf"; 40 | 41 | CREATE TABLE "inf" ( 42 | "dur" varchar(45) DEFAULT NULL, 43 | "a_id" integer DEFAULT 0, 44 | PRIMARY KEY ("a_id") 45 | ) ; 46 | 47 | DROP TABLE IF EXISTS "rel11"; 48 | 49 | CREATE TABLE "rel11" ( 50 | "b_id" integer DEFAULT 0, 51 | "m_id" integer DEFAULT 0, 52 | PRIMARY KEY ("b_id","m_id") 53 | ) ; 54 | 55 | DROP TABLE IF EXISTS "rel12"; 56 | 57 | CREATE TABLE "rel12" ( 58 | "in_id" integer DEFAULT 0, 59 | "m_id" integer DEFAULT 0, 60 | PRIMARY KEY ("in_id","m_id") 61 | ) ; 62 | 63 | DROP TABLE IF EXISTS "rel13"; 64 | 65 | CREATE TABLE "rel13" ( 66 | "a_id" integer DEFAULT 0, 67 | "m_id" integer DEFAULT 0, 68 | PRIMARY KEY ("a_id","m_id") 69 | ) ; 70 | 71 | 
-------------------------------------------------------------------------------- /src/cross_db_benchmark/datasets/hepatitis/string_statistics.json: -------------------------------------------------------------------------------- 1 | {"rel11": {}, "dispat": {}, "rel13": {}, "indis": {}, "inf": {}, "Bio": {}, "rel12": {}} -------------------------------------------------------------------------------- /src/cross_db_benchmark/datasets/hockey/schema.json: -------------------------------------------------------------------------------- 1 | {"name": "hockey", "csv_kwargs": {"sep": "\t"}, "db_load_kwargs": {"postgres": "DELIMITER '\t' QUOTE '\"' ESCAPE '\\' NULL 'NULL' CSV HEADER;"}, "tables": ["Master", "Scoring", "SeriesPost", "TeamsSC", "GoaliesShootout", "CombinedShutouts", "AwardsPlayers", "TeamVsTeam", "TeamSplits", "Coaches", "Goalies", "GoaliesSC", "TeamsHalf", "ScoringShootout", "ScoringSC", "Teams", "ScoringSup", "AwardsCoaches", "TeamsPost"], "relationships": [["AwardsCoaches", ["coachID"], "Coaches", ["coachID"]], ["AwardsPlayers", ["playerID"], "Master", ["playerID"]], ["Coaches", ["year", "tmID"], "Teams", ["year", "tmID"]], ["CombinedShutouts", ["IDgoalie1"], "Master", ["playerID"]], ["Goalies", ["playerID"], "Master", ["playerID"]], ["Goalies", ["year", "tmID"], "Teams", ["year", "tmID"]], ["GoaliesSC", ["playerID"], "Master", ["playerID"]], ["GoaliesShootout", ["playerID"], "Master", ["playerID"]], ["Scoring", ["playerID"], "Master", ["playerID"]], ["ScoringSC", ["playerID"], "Master", ["playerID"]], ["ScoringShootout", ["playerID"], "Master", ["playerID"]], ["ScoringSup", ["playerID"], "Master", ["playerID"]], ["SeriesPost", ["year", "tmIDWinner"], "Teams", ["year", "tmID"]], ["TeamSplits", ["year", "tmID"], "Teams", ["year", "tmID"]], ["TeamVsTeam", ["year", "tmID"], "Teams", ["year", "tmID"]], ["TeamsHalf", ["tmID", "year"], "Teams", ["tmID", "year"]], ["TeamsPost", ["year", "tmID"], "Teams", ["year", "tmID"]], ["TeamsSC", ["year", "tmID"], "Teams", ["year", "tmID"]]]} -------------------------------------------------------------------------------- /src/cross_db_benchmark/datasets/imdb/dataset_documentation/README.md: -------------------------------------------------------------------------------- 1 | # Dataset Documentation 2 | 3 | Schema was restricted to be acyclic such that Naru can be applied seamlessly. 
Also, we added the column names in the 4 | script.py -------------------------------------------------------------------------------- /src/cross_db_benchmark/datasets/imdb/dataset_documentation/script.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | 4 | 5 | def extract_column_names(table_defs): 6 | column_names = dict() 7 | 8 | single_table_defs = table_defs.split("CREATE TABLE") 9 | for single_table in single_table_defs: 10 | alphanumeric_sequences = re.findall('\w+', single_table) 11 | if len(alphanumeric_sequences) > 0: 12 | table_name = alphanumeric_sequences[0] 13 | cols = [col.strip() for col in re.findall('\n\s+\w+', single_table)] 14 | if 'DROP' in cols: 15 | cols.remove('DROP') 16 | column_names[table_name] = cols 17 | 18 | return column_names 19 | 20 | 21 | imdb_no_header_path = '../../../../../data/datasets/imdb_no_header' 22 | imdb_path = '../../../../../data/datasets/imdb' 23 | sql_ddl_path = '../schema_sql/postgres.sql' 24 | assert os.path.exists(sql_ddl_path) 25 | assert os.path.exists(imdb_no_header_path) 26 | 27 | with open(sql_ddl_path, 'r') as file: 28 | table_defs = file.read() 29 | # This is a rather improvised function. It does not properly parse the sql but instead assumes that columns 30 | # start with a newline followed by whitespaces and table definitions start with CREATE TABLE ... 31 | column_names = extract_column_names(table_defs) 32 | 33 | print(column_names) 34 | 35 | for table in ["kind_type", "title", "cast_info", "company_name", "company_type", "info_type", "keyword", 36 | "movie_companies", "movie_info_idx", "movie_keyword", "movie_info", "person_info", "char_name", 37 | "aka_name", "name"]: 38 | print(f"Creating headers for {table}") 39 | with open(os.path.join(imdb_path, f'{table}.csv'), 'w') as outfile: 40 | with open(os.path.join(imdb_no_header_path, f'{table}.csv')) as infile: 41 | outfile.write(','.join(column_names[table]) + '\n') 42 | for line in infile: 43 | outfile.write(line) 44 | -------------------------------------------------------------------------------- /src/cross_db_benchmark/datasets/imdb/schema.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "imdb", 3 | "csv_kwargs": { 4 | "escapechar": "\\", 5 | "encoding": "utf-8", 6 | "quotechar": "\"", 7 | "on_bad_lines": "warn" 8 | }, 9 | "db_load_kwargs": { 10 | "postgres": "DELIMITER ',' QUOTE '\"' ESCAPE '\\' NULL '' CSV HEADER;" 11 | }, 12 | "tables": [ 13 | "title", 14 | "cast_info", 15 | "company_name", 16 | "company_type", 17 | "info_type", 18 | "keyword", 19 | "movie_companies", 20 | "movie_info_idx", 21 | "movie_keyword", 22 | "movie_info", 23 | "person_info", 24 | "kind_type", 25 | "char_name", 26 | "aka_name", 27 | "name" 28 | ], 29 | "relationships": [ 30 | [ 31 | "cast_info", 32 | "movie_id", 33 | "title", 34 | "id" 35 | ], 36 | [ 37 | "movie_companies", 38 | "company_id", 39 | "company_name", 40 | "id" 41 | ], 42 | [ 43 | "movie_companies", 44 | "company_type_id", 45 | "company_type", 46 | "id" 47 | ], 48 | [ 49 | "movie_info_idx", 50 | "info_type_id", 51 | "info_type", 52 | "id" 53 | ], 54 | [ 55 | "movie_keyword", 56 | "keyword_id", 57 | "keyword", 58 | "id" 59 | ], 60 | [ 61 | "movie_companies", 62 | "movie_id", 63 | "title", 64 | "id" 65 | ], 66 | [ 67 | "movie_info_idx", 68 | "movie_id", 69 | "title", 70 | "id" 71 | ], 72 | [ 73 | "cast_info", 74 | "person_role_id", 75 | "char_name", 76 | "id" 77 | ], 78 | [ 79 | "movie_keyword", 80 | "movie_id", 81 | 
"title", 82 | "id" 83 | ], 84 | [ 85 | "movie_info", 86 | "movie_id", 87 | "title", 88 | "id" 89 | ], 90 | [ 91 | "person_info", 92 | "person_id", 93 | "name", 94 | "id" 95 | ], 96 | [ 97 | "title", 98 | "kind_id", 99 | "kind_type", 100 | "id" 101 | ], 102 | [ 103 | "cast_info", 104 | "person_id", 105 | "aka_name", 106 | "id" 107 | ], 108 | [ 109 | "aka_name", 110 | "person_id", 111 | "name", 112 | "id" 113 | ] 114 | ] 115 | } -------------------------------------------------------------------------------- /src/cross_db_benchmark/datasets/imdb/table_lengths.json: -------------------------------------------------------------------------------- 1 | { 2 | "aka_name": 901343, 3 | "cast_info": 36244344, 4 | "char_name": 3140339, 5 | "company_name": 234997, 6 | "company_type": 4, 7 | "info_type": 113, 8 | "keyword": 134170, 9 | "kind_type": 7, 10 | "movie_companies": 2609129, 11 | "movie_info_idx": 1380035, 12 | "movie_keyword": 4523930, 13 | "name": 4167491, 14 | "title": 2528312, 15 | "movie_info": 14835720, 16 | "person_info": 2963664 17 | } -------------------------------------------------------------------------------- /src/cross_db_benchmark/datasets/imdb_full/dataset_documentation/README.md: -------------------------------------------------------------------------------- 1 | # Dataset Documentation 2 | 3 | This dataset contains all tables of the imdb schema. -------------------------------------------------------------------------------- /src/cross_db_benchmark/datasets/imdb_full/schema.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "imdb", 3 | "csv_kwargs": { 4 | "escapechar": "\\", 5 | "encoding": "utf-8", 6 | "quotechar": "\"", 7 | "on_bad_lines": "warn" 8 | }, 9 | "db_load_kwargs": { 10 | "postgres": "DELIMITER ',' QUOTE '\"' ESCAPE '\\' NULL '' CSV HEADER;" 11 | }, 12 | "tables": [ 13 | "movie_link", 14 | "title", 15 | "cast_info", 16 | "company_name", 17 | "company_type", 18 | "info_type", 19 | "keyword", 20 | "movie_companies", 21 | "movie_info_idx", 22 | "movie_keyword", 23 | "complete_cast", 24 | "movie_info", 25 | "person_info", 26 | "kind_type", 27 | "link_type", 28 | "char_name", 29 | "comp_cast_type", 30 | "aka_name", 31 | "aka_title", 32 | "name", 33 | "role_type" 34 | ], 35 | "relationships": [ 36 | [ 37 | "cast_info", 38 | "movie_id", 39 | "title", 40 | "id" 41 | ], 42 | [ 43 | "movie_companies", 44 | "company_id", 45 | "company_name", 46 | "id" 47 | ], 48 | [ 49 | "movie_companies", 50 | "company_type_id", 51 | "company_type", 52 | "id" 53 | ], 54 | [ 55 | "movie_info_idx", 56 | "info_type_id", 57 | "info_type", 58 | "id" 59 | ], 60 | [ 61 | "movie_keyword", 62 | "keyword_id", 63 | "keyword", 64 | "id" 65 | ], 66 | [ 67 | "movie_companies", 68 | "movie_id", 69 | "title", 70 | "id" 71 | ], 72 | [ 73 | "movie_info_idx", 74 | "movie_id", 75 | "title", 76 | "id" 77 | ], 78 | [ 79 | "cast_info", 80 | "person_role_id", 81 | "char_name", 82 | "id" 83 | ], 84 | [ 85 | "movie_keyword", 86 | "movie_id", 87 | "title", 88 | "id" 89 | ], 90 | [ 91 | "movie_info", 92 | "movie_id", 93 | "title", 94 | "id" 95 | ], 96 | [ 97 | "person_info", 98 | "person_id", 99 | "name", 100 | "id" 101 | ], 102 | [ 103 | "title", 104 | "kind_id", 105 | "kind_type", 106 | "id" 107 | ], 108 | [ 109 | "cast_info", 110 | "person_id", 111 | "aka_name", 112 | "id" 113 | ], 114 | [ 115 | "aka_name", 116 | "person_id", 117 | "name", 118 | "id" 119 | ] 120 | ] 121 | } -------------------------------------------------------------------------------- 
/src/cross_db_benchmark/datasets/movielens/schema.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "movielens", 3 | "csv_kwargs": { 4 | "sep": "\t" 5 | }, 6 | "db_load_kwargs": { 7 | "postgres": "DELIMITER '\t' QUOTE '\"' ESCAPE '\\' NULL 'NULL' CSV HEADER;" 8 | }, 9 | "tables": [ 10 | "users", 11 | "movies2actors", 12 | "u2base", 13 | "movies2directors", 14 | "directors", 15 | "actors", 16 | "movies" 17 | ], 18 | "relationships": [ 19 | [ 20 | "movies2actors", 21 | [ 22 | "actorid" 23 | ], 24 | "actors", 25 | [ 26 | "actorid" 27 | ] 28 | ], 29 | [ 30 | "movies2actors", 31 | [ 32 | "movieid" 33 | ], 34 | "movies", 35 | [ 36 | "movieid" 37 | ] 38 | ], 39 | [ 40 | "movies2directors", 41 | [ 42 | "directorid" 43 | ], 44 | "directors", 45 | [ 46 | "directorid" 47 | ] 48 | ], 49 | [ 50 | "movies2directors", 51 | [ 52 | "movieid" 53 | ], 54 | "movies", 55 | [ 56 | "movieid" 57 | ] 58 | ], 59 | [ 60 | "u2base", 61 | [ 62 | "movieid" 63 | ], 64 | "movies", 65 | [ 66 | "movieid" 67 | ] 68 | ], 69 | [ 70 | "u2base", 71 | [ 72 | "userid" 73 | ], 74 | "users", 75 | [ 76 | "userid" 77 | ] 78 | ] 79 | ] 80 | } -------------------------------------------------------------------------------- /src/cross_db_benchmark/datasets/movielens/schema_sql/postgres.sql: -------------------------------------------------------------------------------- 1 | 2 | 3 | DROP TABLE IF EXISTS "actors"; 4 | 5 | CREATE TABLE "actors" ( 6 | "actorid" integer , 7 | "a_gender" varchar(255) , 8 | "a_quality" integer , 9 | PRIMARY KEY ("actorid") 10 | ) ; 11 | 12 | DROP TABLE IF EXISTS "directors"; 13 | 14 | CREATE TABLE "directors" ( 15 | "directorid" integer , 16 | "d_quality" integer , 17 | "avg_revenue" integer , 18 | PRIMARY KEY ("directorid") 19 | ) ; 20 | 21 | DROP TABLE IF EXISTS "movies"; 22 | 23 | CREATE TABLE "movies" ( 24 | "movieid" integer DEFAULT 0, 25 | "year" integer , 26 | "isEnglish" varchar(255) , 27 | "country" varchar(50) , 28 | "runningtime" integer , 29 | PRIMARY KEY ("movieid") 30 | ) ; 31 | 32 | DROP TABLE IF EXISTS "movies2actors"; 33 | 34 | CREATE TABLE "movies2actors" ( 35 | "movieid" integer , 36 | "actorid" integer , 37 | "cast_num" integer , 38 | PRIMARY KEY ("movieid","actorid") 39 | ) ; 40 | 41 | DROP TABLE IF EXISTS "movies2directors"; 42 | 43 | CREATE TABLE "movies2directors" ( 44 | "movieid" integer , 45 | "directorid" integer , 46 | "genre" varchar(15) , 47 | PRIMARY KEY ("movieid","directorid") 48 | ) ; 49 | 50 | DROP TABLE IF EXISTS "u2base"; 51 | 52 | CREATE TABLE "u2base" ( 53 | "userid" integer DEFAULT 0, 54 | "movieid" integer , 55 | "rating" varchar(45) , 56 | PRIMARY KEY ("userid","movieid") 57 | ) ; 58 | 59 | DROP TABLE IF EXISTS "users"; 60 | 61 | CREATE TABLE "users" ( 62 | "userid" integer DEFAULT 0, 63 | "age" varchar(5) , 64 | "u_gender" varchar(5) , 65 | "occupation" varchar(45) , 66 | PRIMARY KEY ("userid") 67 | ) ; 68 | 69 | -------------------------------------------------------------------------------- /src/cross_db_benchmark/datasets/movielens/string_statistics.json: -------------------------------------------------------------------------------- 1 | {"users": {"u_gender": {"freq_str_words": ["F", "M"]}}, "movies2actors": {}, "u2base": {}, "movies2directors": {"genre": {"freq_str_words": ["Action", "Adventure", "Animation", "Comedy", "Crime", "Documentary", "Drama", "Horror", "Other"]}}, "directors": {}, "actors": {"a_gender": {"freq_str_words": ["M", "F"]}}, "movies": {"isEnglish": {"freq_str_words": ["T", "F"]}, 
"country": {"freq_str_words": ["other", "USA", "France", "UK"]}}} -------------------------------------------------------------------------------- /src/cross_db_benchmark/datasets/seznam/schema.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "seznam", 3 | "csv_kwargs": { 4 | "sep": "\t" 5 | }, 6 | "db_load_kwargs": { 7 | "postgres": "DELIMITER '\t' QUOTE '\"' ESCAPE '\\' NULL 'NULL' CSV HEADER;" 8 | }, 9 | "tables": [ 10 | "probehnuto", 11 | "dobito", 12 | "client", 13 | "probehnuto_mimo_penezenku" 14 | ], 15 | "relationships": [ 16 | [ 17 | "dobito", 18 | [ 19 | "client_id" 20 | ], 21 | "client", 22 | [ 23 | "client_id" 24 | ] 25 | ], 26 | [ 27 | "probehnuto", 28 | [ 29 | "client_id" 30 | ], 31 | "client", 32 | [ 33 | "client_id" 34 | ] 35 | ], 36 | [ 37 | "probehnuto_mimo_penezenku", 38 | [ 39 | "client_id" 40 | ], 41 | "client", 42 | [ 43 | "client_id" 44 | ] 45 | ] 46 | ] 47 | } -------------------------------------------------------------------------------- /src/cross_db_benchmark/datasets/seznam/schema_sql/postgres.sql: -------------------------------------------------------------------------------- 1 | 2 | 3 | DROP TABLE IF EXISTS "client"; 4 | 5 | CREATE TABLE "client" ( 6 | "client_id" integer , 7 | "kraj" varchar(255) DEFAULT NULL, 8 | "obor" varchar(255) DEFAULT NULL, 9 | PRIMARY KEY ("client_id") 10 | ) ; 11 | 12 | DROP TABLE IF EXISTS "dobito"; 13 | 14 | CREATE TABLE "dobito" ( 15 | "client_id" integer DEFAULT NULL, 16 | "month_year_datum_transakce" varchar(255) , 17 | "sluzba" varchar(255) , 18 | "kc_dobito" decimal(10,2) 19 | ) ; 20 | 21 | DROP TABLE IF EXISTS "probehnuto"; 22 | 23 | CREATE TABLE "probehnuto" ( 24 | "client_id" integer DEFAULT NULL, 25 | "month_year_datum_transakce" varchar(255) , 26 | "sluzba" varchar(255) DEFAULT NULL, 27 | "kc_proklikano" decimal(10,2) 28 | ) ; 29 | 30 | DROP TABLE IF EXISTS "probehnuto_mimo_penezenku"; 31 | 32 | CREATE TABLE "probehnuto_mimo_penezenku" ( 33 | "client_id" integer , 34 | "Month/Year" varchar(12) , 35 | "probehla_inzerce_mimo_penezenku" varchar(255) DEFAULT NULL, 36 | PRIMARY KEY ("client_id","Month/Year") 37 | ) ; 38 | 39 | -------------------------------------------------------------------------------- /src/cross_db_benchmark/datasets/seznam/string_statistics.json: -------------------------------------------------------------------------------- 1 | {"probehnuto": {"month_year_datum_transakce": {"freq_str_words": ["2015-10-01", "2015-01-01", "2015-02-01", "2015-08-01", "2015-05-01", "2015-09-01", "2015-07-01", "2015-06-01", "2015-04-01", "2015-03-01"]}, "sluzba": {"freq_str_words": ["h", "a"]}}, "dobito": {"month_year_datum_transakce": {"freq_str_words": ["2015-10-01", "2015-08-01", "2014-11-01", "2013-05-01", "2014-07-01", "2015-06-01", "2015-05-01", "2015-07-01", "2015-04-01", "2013-12-01", "2013-08-01", "2014-01-01", "2013-11-01", "2013-09-01", "2015-02-01", "2013-04-01", "2013-06-01", "2015-03-01", "2014-09-01", "2014-06-01", "2015-01-01", "2014-05-01", "2014-02-01", "2014-03-01", "2014-12-01", "2013-07-01", "2014-08-01", "2014-04-01", "2014-10-01", "2013-03-01", "2015-09-01", "2013-10-01"]}, "sluzba": {"freq_str_words": ["c", "a", "h"]}}, "client": {"kraj": {"freq_str_words": ["Vyso\u010dina", "Jihomoravsk\u00fd", "kraj", "Zl\u00ednsk\u00fd", "\u00dasteck\u00fd", "Kr\u00e1lov\u00e9hradeck\u00fd", "Praha", "Plze\u0148sk\u00fd", "Libereck\u00fd", "Olomouck\u00fd", "St\u0159edo\u010desk\u00fd", "Pardubick\u00fd", "Moravskoslezsk\u00fd", "Jiho\u010desk\u00fd", 
"Karlovarsk\u00fd"]}, "obor": {"freq_str_words": ["Vilma", "Leona", "Vladan", "Sonja", "Bohdana", "Anezka", "Veronika", "Radek", "Ozzy", "Alice", "Pink", "Herbert", "Matyas", "Gabriel", "Baltazar", "Tony", "Floyd", "Robin", "Otylie", "Zora", "Richard", "Dita", "Bozena", "Erika", "Miroslav", "Gabriela", "Josef", "Hermina", "Blazej", "Zdenek", "Hugo", "Andela", "Eduard"]}}, "probehnuto_mimo_penezenku": {"Month/Year": {"freq_str_words": ["2012-08-01", "2012-09-01", "2012-10-01", "2012-11-01", "2012-12-01", "2013-01-01", "2013-02-01", "2013-03-01", "2013-04-01", "2013-05-01", "2013-06-01", "2013-07-01", "2013-08-01", "2013-09-01", "2013-10-01", "2013-11-01", "2013-12-01", "2014-01-01", "2014-02-01", "2014-03-01"]}, "probehla_inzerce_mimo_penezenku": {"freq_str_words": ["ANO"]}}} -------------------------------------------------------------------------------- /src/cross_db_benchmark/datasets/ssb/dataset_documentation/README.md: -------------------------------------------------------------------------------- 1 | # Dataset Documentation 2 | 3 | Standard SSB Dataset 4 | 5 | In the makefile change to linux 6 | 7 | ``` 8 | 9 | git clone https://github.com/gregrahn/ssb-kit.git 10 | cd ssb-kit/dbgen 11 | sed "s/MACHINE =MAC/MACHINE =LINUX/g" makefile -i 12 | sed "s/-O -DDBNAME/-DDBNAME/g" makefile -i 13 | make 14 | 15 | rm *.tbl 16 | rm *.csv 17 | 18 | # date does not work 19 | SSB_SCALE=2 20 | ./dbgen -s $SSB_SCALE -T lineorder 21 | ./dbgen -s $SSB_SCALE -T customer 22 | ./dbgen -s $SSB_SCALE -T part 23 | ./dbgen -s $SSB_SCALE -T supplier 24 | # stupid hack to circumvent IO error 25 | echo "tmp" > date.tbl 26 | ./dbgen -s $SSB_SCALE -T date 27 | 28 | for i in `ls *.tbl`; do 29 | sed 's/|$//' $i > ${i/tbl/csv} 30 | echo $i; 31 | done 32 | ``` 33 | -------------------------------------------------------------------------------- /src/cross_db_benchmark/datasets/ssb/dataset_documentation/script.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | 4 | 5 | def extract_column_names(table_defs): 6 | column_names = dict() 7 | 8 | single_table_defs = table_defs.split("create table") 9 | for single_table in single_table_defs: 10 | alphanumeric_sequences = re.findall('\w+', single_table) 11 | if len(alphanumeric_sequences) > 0: 12 | table_name = alphanumeric_sequences[0] 13 | cols = [col.strip() for col in re.findall('\n\s+\w+', single_table)] 14 | if 'drop' in cols: 15 | cols.remove('drop') 16 | 17 | if 'primary' in cols: 18 | cols.remove('primary') 19 | 20 | column_names[table_name] = cols 21 | 22 | return column_names 23 | 24 | 25 | source_path = '../../../../../ssb-kit/dbgen' 26 | target = '../../../../../zero-shot-data/datasets/ssb' 27 | os.makedirs(target, exist_ok=True) 28 | sql_ddl_path = '../schema_sql/postgres.sql' 29 | assert os.path.exists(sql_ddl_path) 30 | assert os.path.exists(source_path) 31 | 32 | with open(sql_ddl_path, 'r') as file: 33 | table_defs = file.read() 34 | # This is a rather improvised function. It does not properly parse the sql but instead assumes that columns 35 | # start with a newline followed by whitespaces and table definitions start with CREATE TABLE ... 
36 | column_names = extract_column_names(table_defs) 37 | 38 | print(column_names) 39 | 40 | for table in ["customer", "part", "supplier", "lineorder"]: 41 | print(f"Creating headers for {table}") 42 | with open(os.path.join(target, f'{table}.csv'), 'w') as outfile: 43 | with open(os.path.join(source_path, f'{table}.csv')) as infile: 44 | outfile.write('|'.join(column_names[table]) + '\n') 45 | for line in infile: 46 | outfile.write(line) 47 | -------------------------------------------------------------------------------- /src/cross_db_benchmark/datasets/ssb/schema.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "ssb", 3 | "csv_kwargs": { 4 | "escapechar": "\\", 5 | "encoding": "utf-8", 6 | "quotechar": "\"", 7 | "on_bad_lines": "warn", 8 | "sep": "|" 9 | }, 10 | "db_load_kwargs": { 11 | "postgres": "DELIMITER '|' QUOTE '\"' ESCAPE '\\' NULL '' CSV HEADER;" 12 | }, 13 | "tables": [ 14 | "customer", 15 | "part", 16 | "supplier", 17 | "lineorder", 18 | "dim_date" 19 | ], 20 | "relationships": [ 21 | [ 22 | "lineorder", 23 | "lo_orderdate", 24 | "dim_date", 25 | "d_datekey" 26 | ], 27 | [ 28 | "lineorder", 29 | "lo_custkey", 30 | "customer", 31 | "c_custkey" 32 | ], 33 | [ 34 | "lineorder", 35 | "lo_partkey", 36 | "part", 37 | "p_partkey" 38 | ], 39 | [ 40 | "lineorder", 41 | "lo_suppkey", 42 | "supplier", 43 | "s_suppkey" 44 | ] 45 | ] 46 | } -------------------------------------------------------------------------------- /src/cross_db_benchmark/datasets/ssb/schema_sql/postgres.sql: -------------------------------------------------------------------------------- 1 | drop table if exists customer; 2 | create table customer 3 | ( 4 | c_custkey integer not null, 5 | c_name varchar(25) not null, 6 | c_address varchar(25) not null, 7 | c_city varchar(10) not null, 8 | c_nation varchar(15) not null, 9 | c_region varchar(12) not null, 10 | c_phone varchar(15) not null, 11 | c_mktsegment varchar(10) not null, 12 | primary key (c_custkey) 13 | ); 14 | 15 | drop table if exists part; 16 | create table part 17 | ( 18 | p_partkey integer not null, 19 | p_name varchar(22) not null, 20 | p_mfgr varchar(6) not null, 21 | p_category varchar(7) not null, 22 | p_brand1 varchar(9) not null, 23 | p_color varchar(11) not null, 24 | p_type varchar(25) not null, 25 | p_size integer not null, 26 | p_container varchar(10) not null, 27 | primary key (p_partkey) 28 | ); 29 | 30 | drop table if exists supplier; 31 | create table supplier 32 | ( 33 | s_suppkey integer not null, 34 | s_name varchar(25) not null, 35 | s_address varchar(25) not null, 36 | s_city varchar(10) not null, 37 | s_nation varchar(15) not null, 38 | s_region varchar(12) not null, 39 | s_phone varchar(15) not null, 40 | primary key (s_suppkey) 41 | ); 42 | 43 | drop table if exists lineorder; 44 | create table lineorder 45 | ( 46 | lo_orderkey BIGINT not null, 47 | lo_linenumber BIGINT not null, 48 | lo_custkey integer not null, 49 | lo_partkey integer not null, 50 | lo_suppkey integer not null, 51 | lo_orderdate integer not null, 52 | lo_orderpriority varchar(15) not null, 53 | lo_shippriority integer not null, 54 | lo_quantity integer not null, 55 | lo_extendedprice integer not null, 56 | lo_ordertotalprice integer not null, 57 | lo_discount integer not null, 58 | lo_revenue integer not null, 59 | lo_supplycost integer not null, 60 | lo_tax integer not null, 61 | lo_commitdate integer not null, 62 | lo_shipmode varchar(10) not null 63 | ); 64 | 65 | drop table if exists dim_date; 66 | 
create table dim_date ( 67 | d_datekey integer not null, 68 | d_date varchar(18) not null, 69 | d_dayofweek varchar(9) not null, 70 | d_month varchar(9) not null, 71 | d_year integer not null, 72 | d_yearmonthnum integer not null, 73 | d_yearmonth varchar(7) not null, 74 | d_daynuminweek integer not null, 75 | d_daynuminmonth integer not null, 76 | d_daynuminyear integer not null, 77 | d_monthnuminyear integer not null, 78 | d_weeknuminyear integer not null, 79 | d_sellingseason varchar(12) not null, 80 | d_lastdayinweekfl integer not null, 81 | d_lastdayinmonthfl integer not null, 82 | d_holidayfl integer not null, 83 | d_weekdayfl integer not null, 84 | primary key (d_datekey) 85 | ); 86 | -------------------------------------------------------------------------------- /src/cross_db_benchmark/datasets/tournament/schema.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "tournament", 3 | "csv_kwargs": { 4 | "sep": "\t" 5 | }, 6 | "db_load_kwargs": { 7 | "postgres": "DELIMITER '\t' QUOTE '\"' ESCAPE '\\' NULL 'NULL' CSV HEADER;" 8 | }, 9 | "tables": [ 10 | "regular_season_compact_results", 11 | "tourney_detailed_results", 12 | "seasons", 13 | "target", 14 | "tourney_seeds", 15 | "regular_season_detailed_results", 16 | "tourney_slots", 17 | "tourney_compact_results", 18 | "teams" 19 | ], 20 | "relationships": [ 21 | [ 22 | "regular_season_compact_results", 23 | [ 24 | "season" 25 | ], 26 | "seasons", 27 | [ 28 | "season" 29 | ] 30 | ], 31 | [ 32 | "regular_season_detailed_results", 33 | [ 34 | "season" 35 | ], 36 | "seasons", 37 | [ 38 | "season" 39 | ] 40 | ], 41 | [ 42 | "regular_season_detailed_results", 43 | [ 44 | "wteam" 45 | ], 46 | "teams", 47 | [ 48 | "team_id" 49 | ] 50 | ], 51 | [ 52 | "target", 53 | [ 54 | "team_id1" 55 | ], 56 | "teams", 57 | [ 58 | "team_id" 59 | ] 60 | ], 61 | [ 62 | "tourney_compact_results", 63 | [ 64 | "wteam" 65 | ], 66 | "teams", 67 | [ 68 | "team_id" 69 | ] 70 | ], 71 | [ 72 | "tourney_detailed_results", 73 | [ 74 | "wteam" 75 | ], 76 | "teams", 77 | [ 78 | "team_id" 79 | ] 80 | ], 81 | [ 82 | "tourney_seeds", 83 | [ 84 | "team" 85 | ], 86 | "teams", 87 | [ 88 | "team_id" 89 | ] 90 | ], 91 | [ 92 | "tourney_slots", 93 | [ 94 | "season" 95 | ], 96 | "seasons", 97 | [ 98 | "season" 99 | ] 100 | ] 101 | ] 102 | } -------------------------------------------------------------------------------- /src/cross_db_benchmark/datasets/tournament/string_statistics.json: -------------------------------------------------------------------------------- 1 | {"regular_season_compact_results": {"wloc": {"freq_str_words": ["N", "H", "A"]}}, "tourney_detailed_results": {"wloc": {"freq_str_words": ["N"]}}, "seasons": {"dayzero": {"freq_str_words": ["00:00:00"]}, "regionW": {"freq_str_words": ["East", "Atlanta"]}, "regionX": {"freq_str_words": ["West", "Midwest", "Southeast", "South"]}, "regionY": {"freq_str_words": ["Midwest", "Southeast", "South"]}, "regionZ": {"freq_str_words": ["Southeast", "West", "South"]}}, "target": {}, "tourney_seeds": {"seed": {"freq_str_words": ["W11", "Z13", "X15", "X13", "Z15", "X12", "X07", "Y05", "X02", "Z06", "Z07", "X04", "Z05", "X09", "W05", "Y02", "Y10", "X08", "X05", "X10", "Z16", "W16", "Y15", "W15", "W14", "X14", "Z09", "Z10", "Z01", "X01", "Z02", "Z03", "W09", "X03", "Y08", "Y12", "X06", "Y13", "Y11", "Y03", "W04", "Y01", "W12", "Z04", "X11", "Z11", "Z08", "Y07", "W10", "Y14", "W13", "Y09", "Z12", "Z14", "W03", "Y04", "W07", "W08", "X16", "W06", "Y06", "Y16", "W02", "W01"]}}, 
"regular_season_detailed_results": {"wloc": {"freq_str_words": ["N", "H", "A"]}}, "tourney_slots": {"slot": {"freq_str_words": ["R1W1", "R1W2", "R1W3", "R1W4", "R1W5", "R1W6", "R1W7", "R1W8", "R1X1", "R1X2", "R1X3", "R1X4", "R1X5", "R1X6", "R1X7", "R1X8", "R1Y1", "R1Y2", "R1Y3", "R1Y4", "R1Y5", "R1Y6", "R1Y7", "R1Y8", "R1Z1", "R1Z2", "R1Z3", "R1Z4", "R1Z5", "R1Z6", "R1Z7", "R1Z8", "R2W1", "R2W2", "R2W3", "R2W4", "R2X1", "R2X2", "R2X3", "R2X4", "R2Y1", "R2Y2", "R2Y3", "R2Y4", "R2Z1", "R2Z2", "R2Z3", "R2Z4", "R3W1", "R3W2", "R3X1", "R3X2", "R3Y1", "R3Y2", "R3Z1", "R3Z2", "R4W1", "R4X1", "R4Y1", "R4Z1", "R5WX", "R5YZ", "R6CH"]}, "strongseed": {"freq_str_words": ["W01", "W02", "W03", "W04", "W05", "W06", "W07", "W08", "X01", "X02", "X03", "X04", "X05", "X06", "X07", "X08", "Y01", "Y02", "Y03", "Y04", "Y05", "Y06", "Y07", "Y08", "Z01", "Z02", "Z03", "Z04", "Z05", "Z06", "Z07", "Z08", "R1W1", "R1W2", "R1W3", "R1W4", "R1X1", "R1X2", "R1X3", "R1X4", "R1Y1", "R1Y2", "R1Y3", "R1Y4", "R1Z1", "R1Z2", "R1Z3", "R1Z4", "R2W1", "R2W2", "R2X1", "R2X2", "R2Y1", "R2Y2", "R2Z1", "R2Z2", "R3W1", "R3X1", "R3Y1", "R3Z1", "R4W1", "R4Y1", "R5WX"]}, "weakseed": {"freq_str_words": ["W16", "W15", "W14", "W13", "W12", "W11", "W10", "W09", "X16", "X15", "X14", "X13", "X12", "X11", "X10", "X09", "Y16", "Y15", "Y14", "Y13", "Y12", "Y11", "Y10", "Y09", "Z16", "Z15", "Z14", "Z13", "Z12", "Z11", "Z10", "Z09", "R1W8", "R1W7", "R1W6", "R1W5", "R1X8", "R1X7", "R1X6", "R1X5", "R1Y8", "R1Y7", "R1Y6", "R1Y5", "R1Z8", "R1Z7", "R1Z6", "R1Z5", "R2W4", "R2W3", "R2X4", "R2X3", "R2Y4", "R2Y3", "R2Z4", "R2Z3", "R3W2", "R3X2", "R3Y2", "R3Z2", "R4X1", "R4Z1", "R5YZ"]}}, "tourney_compact_results": {"wloc": {"freq_str_words": ["N"]}}, "teams": {"team_name": {"freq_str_words": ["Alabama", "A&M", "St", "Michigan", "CS", "E", "Illinois", "Kentucky", "Washington", "Carolina", "Florida", "Southern", "Tech", "Missouri", "N", "Dakota", "New", "North", "Texas", "San", "South", "Utah", "W"]}}} -------------------------------------------------------------------------------- /src/cross_db_benchmark/datasets/tpc_h/dataset_documentation/README.md: -------------------------------------------------------------------------------- 1 | # Dataset Documentation 2 | 3 | Standard TPC-H Dataset 4 | 5 | ``` 6 | git clone git@github.com:electrum/tpch-dbgen.git 7 | make 8 | ./dbgen -s 1 9 | 10 | for i in `ls *.tbl`; do 11 | sed 's/|$//' $i > ${i/tbl/csv} 12 | echo $i; 13 | done 14 | ``` 15 | -------------------------------------------------------------------------------- /src/cross_db_benchmark/datasets/tpc_h/dataset_documentation/script.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | 4 | 5 | def extract_column_names(table_defs): 6 | column_names = dict() 7 | 8 | single_table_defs = table_defs.split("create table") 9 | for single_table in single_table_defs: 10 | alphanumeric_sequences = re.findall('\w+', single_table) 11 | if len(alphanumeric_sequences) > 0: 12 | table_name = alphanumeric_sequences[0] 13 | cols = [col.strip() for col in re.findall('\n\s+\w+', single_table)] 14 | if 'drop' in cols: 15 | cols.remove('drop') 16 | column_names[table_name] = cols 17 | 18 | return column_names 19 | 20 | 21 | source_path = '../../../../../tpch-dbgen' 22 | target = '../../../../../zero-shot-data/datasets/tpc_h' 23 | os.makedirs(target, exist_ok=True) 24 | sql_ddl_path = '../schema_sql/postgres.sql' 25 | assert os.path.exists(sql_ddl_path) 26 | assert os.path.exists(source_path) 27 | 28 | with 
open(sql_ddl_path, 'r') as file: 29 | table_defs = file.read() 30 | # This is a rather improvised function. It does not properly parse the sql but instead assumes that columns 31 | # start with a newline followed by whitespaces and table definitions start with CREATE TABLE ... 32 | column_names = extract_column_names(table_defs) 33 | 34 | print(column_names) 35 | 36 | for table in ["nation", "region", "part", "supplier", "partsupp", "customer", "orders", "lineitem"]: 37 | print(f"Creating headers for {table}") 38 | with open(os.path.join(target, f'{table}.csv'), 'w') as outfile: 39 | with open(os.path.join(source_path, f'{table}.csv')) as infile: 40 | outfile.write('|'.join(column_names[table]) + '\n') 41 | for line in infile: 42 | outfile.write(line) 43 | -------------------------------------------------------------------------------- /src/cross_db_benchmark/datasets/tpc_h/schema.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "tpc_h", 3 | "csv_kwargs": { 4 | "escapechar": "\\", 5 | "encoding": "utf-8", 6 | "quotechar": "\"", 7 | "on_bad_lines": "warn", 8 | "sep": "|" 9 | }, 10 | "db_load_kwargs": { 11 | "postgres": "DELIMITER '|' QUOTE '\"' ESCAPE '\\' NULL '' CSV HEADER;" 12 | }, 13 | "tables": [ 14 | "nation", 15 | "region", 16 | "part", 17 | "supplier", 18 | "partsupp", 19 | "customer", 20 | "orders", 21 | "lineitem" 22 | ], 23 | "relationships": [ 24 | [ 25 | "lineitem", 26 | "l_orderkey", 27 | "orders", 28 | "o_orderkey" 29 | ], 30 | [ 31 | "orders", 32 | "o_custkey", 33 | "customer", 34 | "c_custkey" 35 | ], 36 | [ 37 | "lineitem", 38 | [ 39 | "l_partkey", 40 | "l_suppkey" 41 | ], 42 | "partsupp", 43 | [ 44 | "ps_partkey", 45 | "ps_suppkey" 46 | ] 47 | ], 48 | [ 49 | "partsupp", 50 | "ps_partkey", 51 | "part", 52 | "p_partkey" 53 | ], 54 | [ 55 | "partsupp", 56 | "ps_suppkey", 57 | "supplier", 58 | "s_suppkey" 59 | ], 60 | [ 61 | "supplier", 62 | "s_nationkey", 63 | "nation", 64 | "n_nationkey" 65 | ], 66 | [ 67 | "nation", 68 | "n_regionkey", 69 | "region", 70 | "r_regionkey" 71 | ] 72 | ] 73 | } -------------------------------------------------------------------------------- /src/cross_db_benchmark/datasets/tpc_h/schema_sql/postgres.sql: -------------------------------------------------------------------------------- 1 | drop table if exists nation; 2 | create table nation 3 | ( 4 | n_nationkey integer not null, 5 | n_name char(25) not null, 6 | n_regionkey integer not null, 7 | n_comment varchar(152) 8 | ); 9 | 10 | drop table if exists region; 11 | create table region 12 | ( 13 | r_regionkey integer not null, 14 | r_name char(25) not null, 15 | r_comment varchar(152) 16 | ); 17 | 18 | drop table if exists part; 19 | create table part 20 | ( 21 | p_partkey integer not null, 22 | p_name varchar(55) not null, 23 | p_mfgr char(25) not null, 24 | p_brand char(10) not null, 25 | p_type varchar(25) not null, 26 | p_size integer not null, 27 | p_container char(10) not null, 28 | p_retailprice decimal(15, 2) not null, 29 | p_comment varchar(23) not null 30 | ); 31 | 32 | drop table if exists supplier; 33 | create table supplier 34 | ( 35 | s_suppkey integer not null, 36 | s_name char(25) not null, 37 | s_address varchar(40) not null, 38 | s_nationkey integer not null, 39 | s_phone char(15) not null, 40 | s_acctbal decimal(15, 2) not null, 41 | s_comment varchar(101) not null 42 | ); 43 | 44 | drop table if exists partsupp; 45 | create table partsupp 46 | ( 47 | ps_partkey integer not null, 48 | ps_suppkey integer not null, 49 | 
ps_availqty integer not null, 50 | ps_supplycost decimal(15, 2) not null, 51 | ps_comment varchar(199) not null 52 | ); 53 | 54 | drop table if exists customer; 55 | create table customer 56 | ( 57 | c_custkey integer not null, 58 | c_name varchar(25) not null, 59 | c_address varchar(40) not null, 60 | c_nationkey integer not null, 61 | c_phone char(15) not null, 62 | c_acctbal decimal(15, 2) not null, 63 | c_mktsegment char(10) not null, 64 | c_comment varchar(117) not null 65 | ); 66 | 67 | drop table if exists orders; 68 | create table orders 69 | ( 70 | o_orderkey integer not null, 71 | o_custkey integer not null, 72 | o_orderstatus char(1) not null, 73 | o_totalprice decimal(15, 2) not null, 74 | o_orderdate date not null, 75 | o_orderpriority char(15) not null, 76 | o_clerk char(15) not null, 77 | o_shippriority integer not null, 78 | o_comment varchar(79) not null 79 | ); 80 | 81 | drop table if exists lineitem; 82 | create table lineitem 83 | ( 84 | l_orderkey integer not null, 85 | l_partkey integer not null, 86 | l_suppkey integer not null, 87 | l_linenumber integer not null, 88 | l_quantity decimal(15, 2) not null, 89 | l_extendedprice decimal(15, 2) not null, 90 | l_discount decimal(15, 2) not null, 91 | l_tax decimal(15, 2) not null, 92 | l_returnflag char(1) not null, 93 | l_linestatus char(1) not null, 94 | l_shipdate date not null, 95 | l_commitdate date not null, 96 | l_receiptdate date not null, 97 | l_shipinstruct char(25) not null, 98 | l_shipmode char(10) not null, 99 | l_comment varchar(44) not null 100 | ); -------------------------------------------------------------------------------- /src/cross_db_benchmark/datasets/tpc_h/table_lengths.json: -------------------------------------------------------------------------------- 1 | {"nation": 25, "region": 5, "part": 200000, "supplier": 10000, "partsupp": 800000, "customer": 150000, "orders": 1500000, "lineitem": 6001215} -------------------------------------------------------------------------------- /src/cross_db_benchmark/datasets/tpc_h_pk/dataset_documentation/README.md: -------------------------------------------------------------------------------- 1 | # Dataset Documentation 2 | 3 | Standard TPC-H Dataset 4 | 5 | ``` 6 | git clone git@github.com:electrum/tpch-dbgen.git 7 | make 8 | ./dbgen -s 1 9 | 10 | for i in `ls *.tbl`; do 11 | sed 's/|$//' $i > ${i/tbl/csv} 12 | echo $i; 13 | done 14 | ``` 15 | -------------------------------------------------------------------------------- /src/cross_db_benchmark/datasets/tpc_h_pk/dataset_documentation/script.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | 4 | 5 | def extract_column_names(table_defs): 6 | column_names = dict() 7 | 8 | single_table_defs = table_defs.split("create table") 9 | for single_table in single_table_defs: 10 | alphanumeric_sequences = re.findall('\w+', single_table) 11 | if len(alphanumeric_sequences) > 0: 12 | table_name = alphanumeric_sequences[0] 13 | cols = [col.strip() for col in re.findall('\n\s+\w+', single_table)] 14 | if 'drop' in cols: 15 | cols.remove('drop') 16 | column_names[table_name] = cols 17 | 18 | return column_names 19 | 20 | 21 | source_path = '../../../../../tpch-dbgen' 22 | target = '../../../../../zero-shot-data/datasets/tpc_h' 23 | os.makedirs(target, exist_ok=True) 24 | sql_ddl_path = '../schema_sql/postgres.sql' 25 | assert os.path.exists(sql_ddl_path) 26 | assert os.path.exists(source_path) 27 | 28 | with open(sql_ddl_path, 'r') as file: 29 | table_defs = 
file.read() 30 | # This is a rather improvised function. It does not properly parse the sql but instead assumes that columns 31 | # start with a newline followed by whitespaces and table definitions start with CREATE TABLE ... 32 | column_names = extract_column_names(table_defs) 33 | 34 | print(column_names) 35 | 36 | for table in ["nation", "region", "part", "supplier", "partsupp", "customer", "orders", "lineitem"]: 37 | print(f"Creating headers for {table}") 38 | with open(os.path.join(target, f'{table}.csv'), 'w') as outfile: 39 | with open(os.path.join(source_path, f'{table}.csv')) as infile: 40 | outfile.write('|'.join(column_names[table]) + '\n') 41 | for line in infile: 42 | outfile.write(line) 43 | -------------------------------------------------------------------------------- /src/cross_db_benchmark/datasets/tpc_h_pk/schema.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "tpc_h", 3 | "csv_kwargs": { 4 | "escapechar": "\\", 5 | "encoding": "utf-8", 6 | "quotechar": "\"", 7 | "on_bad_lines": "warn", 8 | "sep": "|" 9 | }, 10 | "db_load_kwargs": { 11 | "postgres": "DELIMITER '|' QUOTE '\"' ESCAPE '\\' NULL '' CSV HEADER;" 12 | }, 13 | "tables": [ 14 | "nation", 15 | "region", 16 | "part", 17 | "supplier", 18 | "partsupp", 19 | "customer", 20 | "orders", 21 | "lineitem" 22 | ], 23 | "relationships": [ 24 | [ 25 | "lineitem", 26 | "l_orderkey", 27 | "orders", 28 | "o_orderkey" 29 | ], 30 | [ 31 | "orders", 32 | "o_custkey", 33 | "customer", 34 | "c_custkey" 35 | ], 36 | [ 37 | "lineitem", 38 | [ 39 | "l_partkey", 40 | "l_suppkey" 41 | ], 42 | "partsupp", 43 | [ 44 | "ps_partkey", 45 | "ps_suppkey" 46 | ] 47 | ], 48 | [ 49 | "partsupp", 50 | "ps_partkey", 51 | "part", 52 | "p_partkey" 53 | ], 54 | [ 55 | "partsupp", 56 | "ps_suppkey", 57 | "supplier", 58 | "s_suppkey" 59 | ], 60 | [ 61 | "supplier", 62 | "s_nationkey", 63 | "nation", 64 | "n_nationkey" 65 | ], 66 | [ 67 | "nation", 68 | "n_regionkey", 69 | "region", 70 | "r_regionkey" 71 | ] 72 | ] 73 | } -------------------------------------------------------------------------------- /src/cross_db_benchmark/datasets/tpc_h_pk/schema_sql/postgres.sql: -------------------------------------------------------------------------------- 1 | drop table if exists nation; 2 | create table nation 3 | ( 4 | n_nationkey integer not null, 5 | n_name char(25) not null, 6 | n_regionkey integer not null, 7 | n_comment varchar(152), 8 | PRIMARY KEY (n_nationkey) 9 | ); 10 | 11 | drop table if exists region; 12 | create table region 13 | ( 14 | r_regionkey integer not null, 15 | r_name char(25) not null, 16 | r_comment varchar(152), 17 | PRIMARY KEY (r_regionkey) 18 | ); 19 | 20 | drop table if exists part; 21 | create table part 22 | ( 23 | p_partkey integer not null, 24 | p_name varchar(55) not null, 25 | p_mfgr char(25) not null, 26 | p_brand char(10) not null, 27 | p_type varchar(25) not null, 28 | p_size integer not null, 29 | p_container char(10) not null, 30 | p_retailprice decimal(15, 2) not null, 31 | p_comment varchar(23) not null, 32 | PRIMARY KEY (p_partkey) 33 | ); 34 | 35 | drop table if exists supplier; 36 | create table supplier 37 | ( 38 | s_suppkey integer not null, 39 | s_name char(25) not null, 40 | s_address varchar(40) not null, 41 | s_nationkey integer not null, 42 | s_phone char(15) not null, 43 | s_acctbal decimal(15, 2) not null, 44 | s_comment varchar(101) not null, 45 | PRIMARY KEY (s_suppkey) 46 | ); 47 | 48 | drop table if exists partsupp; 49 | create table partsupp 50 
| ( 51 | ps_partkey integer not null, 52 | ps_suppkey integer not null, 53 | ps_availqty integer not null, 54 | ps_supplycost decimal(15, 2) not null, 55 | ps_comment varchar(199) not null, 56 | PRIMARY KEY (ps_partkey, ps_suppkey) 57 | ); 58 | 59 | drop table if exists customer; 60 | create table customer 61 | ( 62 | c_custkey integer not null, 63 | c_name varchar(25) not null, 64 | c_address varchar(40) not null, 65 | c_nationkey integer not null, 66 | c_phone char(15) not null, 67 | c_acctbal decimal(15, 2) not null, 68 | c_mktsegment char(10) not null, 69 | c_comment varchar(117) not null, 70 | PRIMARY KEY (c_custkey) 71 | ); 72 | 73 | drop table if exists orders; 74 | create table orders 75 | ( 76 | o_orderkey integer not null, 77 | o_custkey integer not null, 78 | o_orderstatus char(1) not null, 79 | o_totalprice decimal(15, 2) not null, 80 | o_orderdate date not null, 81 | o_orderpriority char(15) not null, 82 | o_clerk char(15) not null, 83 | o_shippriority integer not null, 84 | o_comment varchar(79) not null, 85 | PRIMARY KEY (o_orderkey) 86 | ); 87 | 88 | drop table if exists lineitem; 89 | create table lineitem 90 | ( 91 | l_orderkey integer not null, 92 | l_partkey integer not null, 93 | l_suppkey integer not null, 94 | l_linenumber integer not null, 95 | l_quantity decimal(15, 2) not null, 96 | l_extendedprice decimal(15, 2) not null, 97 | l_discount decimal(15, 2) not null, 98 | l_tax decimal(15, 2) not null, 99 | l_returnflag char(1) not null, 100 | l_linestatus char(1) not null, 101 | l_shipdate date not null, 102 | l_commitdate date not null, 103 | l_receiptdate date not null, 104 | l_shipinstruct char(25) not null, 105 | l_shipmode char(10) not null, 106 | l_comment varchar(44) not null, 107 | PRIMARY KEY (l_orderkey, l_linenumber) 108 | ); -------------------------------------------------------------------------------- /src/cross_db_benchmark/datasets/tpc_h_pk/table_lengths.json: -------------------------------------------------------------------------------- 1 | {"nation": 25, "region": 5, "part": 200000, "supplier": 10000, "partsupp": 800000, "customer": 150000, "orders": 1500000, "lineitem": 6001215} -------------------------------------------------------------------------------- /src/cross_db_benchmark/datasets/walmart/schema.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "walmart", 3 | "csv_kwargs": { 4 | "sep": "\t" 5 | }, 6 | "db_load_kwargs": { 7 | "postgres": "DELIMITER '\t' QUOTE '\"' ESCAPE '\\' NULL 'NULL' CSV HEADER;" 8 | }, 9 | "tables": [ 10 | "train", 11 | "station", 12 | "key" 13 | ], 14 | "relationships": [ 15 | [ 16 | "key", 17 | [ 18 | "station_nbr" 19 | ], 20 | "station", 21 | [ 22 | "station_nbr" 23 | ] 24 | ], 25 | [ 26 | "train", 27 | [ 28 | "store_nbr" 29 | ], 30 | "key", 31 | [ 32 | "store_nbr" 33 | ] 34 | ] 35 | ] 36 | } -------------------------------------------------------------------------------- /src/cross_db_benchmark/datasets/walmart/schema_sql/postgres.sql: -------------------------------------------------------------------------------- 1 | DROP TABLE IF EXISTS "key"; 2 | 3 | CREATE TABLE "key" 4 | ( 5 | "store_nbr" integer, 6 | "station_nbr" integer DEFAULT NULL, 7 | PRIMARY KEY ("store_nbr") 8 | ); 9 | 10 | DROP TABLE IF EXISTS "station"; 11 | 12 | CREATE TABLE "station" 13 | ( 14 | "station_nbr" integer, 15 | PRIMARY KEY ("station_nbr") 16 | ); 17 | 18 | DROP TABLE IF EXISTS "train"; 19 | 20 | CREATE TABLE "train" 21 | ( 22 | "date" varchar(12), 23 | "store_nbr" integer, 24 | 
"item_nbr" integer, 25 | "units" integer DEFAULT NULL, 26 | PRIMARY KEY ("store_nbr", "date", "item_nbr") 27 | ); 28 | 29 | -------------------------------------------------------------------------------- /src/cross_db_benchmark/datasets/walmart/string_statistics.json: -------------------------------------------------------------------------------- 1 | {"train": {}, "weather": {"codesum": {"freq_str_words": ["RA", "FZFG", "BR", "", "SN", "FG+", "FG", "UP", "HZ", "TSRA", "TS", "VCTS", "DZ"]}}, "station": {}, "key": {}} -------------------------------------------------------------------------------- /src/cross_db_benchmark/meta_tools/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DataManagementLab/lcm-eval/8ed11d4c47bae2cb7f0740f566170f3e736e8471/src/cross_db_benchmark/meta_tools/__init__.py -------------------------------------------------------------------------------- /src/cross_db_benchmark/meta_tools/dataset_stats.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from cross_db_benchmark.benchmark_tools.utils import load_column_statistics, load_schema_json 4 | from cross_db_benchmark.datasets.datasets import database_list 5 | from cross_db_benchmark.meta_tools.scale_dataset import get_dataset_size 6 | from training.training.checkpoint import save_csv 7 | 8 | 9 | def generate_dataset_statistics(target, data_dir): 10 | dataset_stats = [] 11 | for db in database_list: 12 | dataset = db.db_name 13 | 14 | column_stats = load_column_statistics(dataset, namespace=False) 15 | schema = load_schema_json(dataset) 16 | 17 | size_gb = get_dataset_size(os.path.join(data_dir, db.source_dataset), schema) 18 | size_gb *= db.scale 19 | 20 | curr_stats = dict( 21 | dataset_name=db.db_name, 22 | no_tables=len(schema.tables), 23 | no_relationships=len(schema.relationships), 24 | no_columns=sum([len(column_stats[t]) for t in column_stats.keys()]), 25 | size_gb=size_gb 26 | ) 27 | dataset_stats.append(curr_stats) 28 | 29 | save_csv(dataset_stats, target) 30 | -------------------------------------------------------------------------------- /src/cross_db_benchmark/meta_tools/download_relational_fit.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | 4 | from cross_db_benchmark.benchmark_tools.utils import load_schema_json 5 | 6 | 7 | def download_from_relational_fit(rel_fit_dataset_name, dataset_name, root_data_dir='../zero-shot-data/datasets'): 8 | schema = load_schema_json(dataset_name) 9 | data_dir = os.path.join(root_data_dir, dataset_name) 10 | os.makedirs(data_dir, exist_ok=True) 11 | 12 | for table in schema.tables: 13 | target_table_file = os.path.join(data_dir, f'{table}.csv') 14 | 15 | if not os.path.exists(target_table_file): 16 | # `{table}` 17 | download_cmd = f'echo "select * from \`{table}\`;" | mysql --host=relational.fit.cvut.cz --user=guest --password=relational {rel_fit_dataset_name} > {table}.csv' 18 | print(download_cmd) 19 | os.system(download_cmd) 20 | filesize = os.path.getsize(f"{table}.csv") 21 | if filesize == 0: 22 | print(f"Warning: file {table}.csv is empty") 23 | shutil.move(f'{table}.csv', target_table_file) 24 | else: 25 | print(f"Skipping download for {table}") 26 | -------------------------------------------------------------------------------- /src/cross_db_benchmark/meta_tools/inflate_cardinality_errors.py: 
-------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | 4 | from cross_db_benchmark.benchmark_tools.database import DatabaseSystem 5 | from cross_db_benchmark.benchmark_tools.parse_run import dumper 6 | from cross_db_benchmark.benchmark_tools.postgres.inflate_cardinality_errors import inflate_card_errors_pg 7 | from cross_db_benchmark.benchmark_tools.utils import load_json 8 | 9 | 10 | def inflate_cardinality_errors(source_path, target_path, card_error_factor, database): 11 | assert os.path.exists(source_path) 12 | run_stats = load_json(source_path) 13 | 14 | if database == DatabaseSystem.POSTGRES: 15 | inflate_func = inflate_card_errors_pg 16 | else: 17 | raise NotImplementedError 18 | 19 | for p in run_stats.parsed_plans: 20 | inflate_func(p, card_error_factor) 21 | 22 | os.makedirs(os.path.dirname(target_path), exist_ok=True) 23 | with open(target_path, 'w') as outfile: 24 | json.dump(run_stats, outfile, default=dumper) 25 | -------------------------------------------------------------------------------- /src/cross_db_benchmark/meta_tools/replace_aliases.py: -------------------------------------------------------------------------------- 1 | def replace_workload_alias(dataset, source, target): 2 | with open(source, 'r') as file: 3 | src_workload = file.read() 4 | 5 | tables_aliases = raw_aliases(dataset) 6 | 7 | if len(tables_aliases) == 0: 8 | raise ValueError(f"No aliases defined for dataset {dataset}") 9 | 10 | for table, alias in tables_aliases: 11 | src_workload = src_workload.replace(f'{alias}.', f'{table}.') 12 | src_workload = src_workload.replace(f'{table} {alias}', f'{table}') 13 | 14 | with open(target, 'w') as file: 15 | file.write(src_workload) 16 | 17 | 18 | def raw_aliases(dataset): 19 | tables_aliases = [] 20 | if dataset == 'imdb': 21 | tables_aliases = [ 22 | ('title', 't'), 23 | ('movie_info_idx', 'mi_idx'), 24 | ('cast_info', 'ci'), 25 | ('movie_info', 'mi'), 26 | ('movie_keyword', 'mk'), 27 | ('movie_companies', 'mc'), 28 | ('company_name', 'cn'), 29 | ('role_type', 'rt'), 30 | ('movie_link', 'ml'), 31 | ] 32 | return tables_aliases 33 | 34 | 35 | def alias_dict(dataset): 36 | return {alias: full for full, alias in raw_aliases(dataset)} 37 | -------------------------------------------------------------------------------- /src/cross_db_benchmark/meta_tools/slice_no_tables.py: -------------------------------------------------------------------------------- 1 | import json 2 | import math 3 | import os 4 | 5 | import numpy as np 6 | 7 | from cross_db_benchmark.benchmark_tools.parse_run import dumper 8 | from cross_db_benchmark.benchmark_tools.postgres.utils import plan_statistics 9 | from cross_db_benchmark.benchmark_tools.utils import load_json 10 | from training.training.checkpoint import save_csv 11 | 12 | 13 | def no_tables(p): 14 | tables, _, _ = plan_statistics(p, skip_columns=True, conv_to_dict=True) 15 | return len(tables) 16 | 17 | 18 | def slice_by_table_no(source_path, target_path, min_no_tables, max_no_tables, workload_slice_stats): 19 | assert os.path.exists(source_path) 20 | run_stats = load_json(source_path) 21 | 22 | no_table_stats = [no_tables(p) for p in run_stats.parsed_plans] 23 | no_tabs, counts = np.unique(no_table_stats, return_counts=True) 24 | for no_tab, count in zip(no_tabs, counts): 25 | print(f'No {no_tab} tables: {count}') 26 | 27 | prev_len = len(run_stats.parsed_plans) 28 | if min_no_tables is None: 29 | min_no_tables = 0 30 | if max_no_tables is None: 31 | max_no_tables = 
math.inf 32 | run_stats.parsed_plans = [p for p, no_tab in zip(run_stats.parsed_plans, no_table_stats) 33 | if min_no_tables <= no_tab <= max_no_tables] 34 | slice_idx = [i for i, no_tab in enumerate(no_table_stats) if min_no_tables <= no_tab <= max_no_tables] 35 | print(f'Reduced no of queries from {prev_len} to {len(run_stats.parsed_plans)}') 36 | 37 | if workload_slice_stats is not None: 38 | save_csv([dict(slice_idx=str(slice_idx))], workload_slice_stats) 39 | 40 | with open(target_path, 'w') as outfile: 41 | json.dump(run_stats, outfile, default=dumper) 42 | -------------------------------------------------------------------------------- /src/deprecated/dataset_tools.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | from cross_db_benchmark.meta_tools.dataset_stats import generate_dataset_statistics 4 | from cross_db_benchmark.meta_tools.derive import derive_from_relational_fit 5 | from cross_db_benchmark.meta_tools.replace_aliases import replace_workload_alias 6 | 7 | if __name__ == '__main__': 8 | parser = argparse.ArgumentParser() 9 | parser.add_argument('--dataset_name', default=None) 10 | parser.add_argument('--data_dir', default=None) 11 | parser.add_argument('--source', default=None) 12 | parser.add_argument('--target', default=None) 13 | parser.add_argument('--relational_fit_dataset_name', default=None) 14 | parser.add_argument('--replace_workload_alias', action='store_true') 15 | parser.add_argument('--generate_dataset_statistics', action='store_true') 16 | parser.add_argument('--derive_from_relational_fit', action='store_true') 17 | 18 | args = parser.parse_args() 19 | 20 | if args.derive_from_relational_fit: 21 | derive_from_relational_fit(args.relational_fit_dataset_name, args.dataset_name) 22 | 23 | if args.replace_workload_alias: 24 | replace_workload_alias(args.dataset_name, args.source, args.target) 25 | 26 | if args.generate_dataset_statistics: 27 | generate_dataset_statistics(args.target, args.data_dir) -------------------------------------------------------------------------------- /src/evaluation/plots/02_join_order_pg_act_cards.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "metadata": {}, 5 | "cell_type": "code", 6 | "source": [ 7 | "import matplotlib.pyplot as plt\n", 8 | "import seaborn as sns\n", 9 | "from classes.paths import LocalPaths" 10 | ], 11 | "id": "initial_id", 12 | "outputs": [], 13 | "execution_count": null 14 | }, 15 | { 16 | "metadata": {}, 17 | "cell_type": "code", 18 | "source": [ 19 | "path = LocalPaths().data / \"plots\" / \"optimizer_runtimes.pdf\"\n", 20 | "\n", 21 | "dark_palette = sns.color_palette(\"muted\", 9)\n", 22 | "light_palette = sns.color_palette(\"pastel\", 9)\n", 23 | "\n", 24 | "fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(3, 1), sharey=True)\n", 25 | "x = ['Zero-Shot', 'Zero-Shot\\n(act. cards)', 'Sc. Postgres\\n(V13)', 'Sc. Postgres\\n(V13, act. cards)', 'Sc. Postgres\\n(V16)', 'Sc. Postgres\\n(V16, act. cards)', None, None]\n", 26 | "pg10_optimal = 445.65\n", 27 | "y = [530, 522, 518, 450, 562, 449]\n", 28 | "y = [a/pg10_optimal for a in y]\n", 29 | "\n", 30 | "# Define custom x-tick positions with extra space between the third and fourth bar\n", 31 | "custom_x_positions = [0, 1, 2, 3.5, 4.5, 5.5]\n", 32 | "plt.yticks(fontsize=13)\n", 33 | "\n", 34 | "# Set bar colors for specific bars\n", 35 | "ax.bar(custom_x_positions[0], y[2], color=dark_palette[7], edgecolor='black', label='Sc. 
PG10', hatch='//', width=1)\n", 36 | "ax.bar(custom_x_positions[1], y[4], color=dark_palette[8], edgecolor='black', label='Sc. PG16', hatch = '\\\\\\\\', width=1)\n", 37 | "ax.bar(custom_x_positions[2], y[0], color=dark_palette[6], edgecolor='black', label='Zero-Shot', hatch='o', width=1)\n", 38 | "ax.bar(custom_x_positions[3], y[3], color=light_palette[7], edgecolor='black', label='Sc. PG10 (act. card.)', hatch = '//', width=1)\n", 39 | "ax.bar(custom_x_positions[4], y[5], color=light_palette[8], edgecolor='black', label='Sc. PG16 (act. card.)', hatch = '\\\\\\\\', width=1)\n", 40 | "ax.bar(custom_x_positions[5], y[1], color=light_palette[6], edgecolor='black', label='Zero-Shot (act. card.)', hatch='xx', width=1)\n", 41 | "ax.set_ylabel('Relative\\nSlow-Down', fontsize=13)\n", 42 | "ax.set_xticklabels([], fontsize=13)\n", 43 | "ax.tick_params(axis='x', which='major', pad=-3)\n", 44 | "\n", 45 | "legend = ax.legend(fontsize=10,\n", 46 | " ncol=2,\n", 47 | " loc='center left',\n", 48 | " bbox_to_anchor=(-1.8, 0.45),\n", 49 | " labelspacing=0.2,\n", 50 | " edgecolor='white')\n", 51 | "\n", 52 | "ax.set_ylim(1, 1.28)\n", 53 | "ax.text(1.1, 0.95, 'Est. Cards.', ha='center', va='center', fontsize=12, backgroundcolor='gray', bbox=dict(facecolor='white', edgecolor='white', alpha=0))\n", 54 | "ax.text(4.6, 0.95, 'Act. Cards.', ha='center', va='center', fontsize=12, backgroundcolor='gray', bbox=dict(facecolor='white', edgecolor='white', alpha=0))\n", 55 | "\n", 56 | "plt.grid(True)\n", 57 | "plt.savefig(path, bbox_inches='tight')" 58 | ], 59 | "id": "2b1775c881fba282", 60 | "outputs": [], 61 | "execution_count": null 62 | } 63 | ], 64 | "metadata": { 65 | "kernelspec": { 66 | "display_name": "Python 3", 67 | "language": "python", 68 | "name": "python3" 69 | }, 70 | "language_info": { 71 | "codemirror_mode": { 72 | "name": "ipython", 73 | "version": 2 74 | }, 75 | "file_extension": ".py", 76 | "mimetype": "text/x-python", 77 | "name": "python", 78 | "nbconvert_exporter": "python", 79 | "pygments_lexer": "ipython2", 80 | "version": "2.7.6" 81 | } 82 | }, 83 | "nbformat": 4, 84 | "nbformat_minor": 5 85 | } 86 | -------------------------------------------------------------------------------- /src/evaluation/plots/05_analyze_joblight.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "id": "initial_id", 6 | "metadata": { 7 | "collapsed": true 8 | }, 9 | "source": [ 10 | "from evaluation.workload_creation.create_evaluation_workloads import get_query_tables, catalan_number\n", 11 | "import sqlparse\n", 12 | "from classes.paths import LocalPaths\n", 13 | "\n", 14 | "path = LocalPaths().code / \"experiments\" / \"evaluation_workloads\" / \"imdb\" / \"job-light.sql\"\n", 15 | "with open(str(path), \"r\") as f:\n", 16 | " job_light = f.read()\n", 17 | "\n", 18 | "iterations = 0\n", 19 | "for index, query in enumerate(job_light.split(\";\")):\n", 20 | " parsed_query = sqlparse.parse(query)\n", 21 | " if parsed_query != \"\" and parsed_query != ():\n", 22 | " join_tables = get_query_tables(parsed_query[0])\n", 23 | " iterations += catalan_number(len(join_tables) - 1)\n", 24 | "print(f'{iterations} possible iterations found in total')" 25 | ], 26 | "outputs": [], 27 | "execution_count": null 28 | }, 29 | { 30 | "metadata": {}, 31 | "cell_type": "code", 32 | "source": [ 33 | "from cross_db_benchmark.benchmark_tools.utils import load_json\n", 34 | "import os\n", 35 | "directory = str(LocalPaths().parsed_plans / \"imdb\" / 
\"join_order_full\")\n", 36 | "\n", 37 | "runtime = 0\n", 38 | "for file in os.listdir(directory):\n", 39 | " if file.endswith(\".json\"):\n", 40 | " path = os.path.join(directory, file)\n", 41 | " print(path)\n", 42 | " wl = load_json(path)\n", 43 | " plans = wl.parsed_plans\n", 44 | " \n", 45 | " for plan in plans:\n", 46 | " runtime += plan.plan_runtime\n", 47 | " \n", 48 | "print(f'Runtime: {runtime}')" 49 | ], 50 | "id": "a2707d6b173d72be", 51 | "outputs": [], 52 | "execution_count": null 53 | } 54 | ], 55 | "metadata": { 56 | "kernelspec": { 57 | "display_name": "Python 3", 58 | "language": "python", 59 | "name": "python3" 60 | }, 61 | "language_info": { 62 | "codemirror_mode": { 63 | "name": "ipython", 64 | "version": 2 65 | }, 66 | "file_extension": ".py", 67 | "mimetype": "text/x-python", 68 | "name": "python", 69 | "nbconvert_exporter": "python", 70 | "pygments_lexer": "ipython2", 71 | "version": "2.7.6" 72 | } 73 | }, 74 | "nbformat": 4, 75 | "nbformat_minor": 5 76 | } 77 | -------------------------------------------------------------------------------- /src/evaluation/workload_creation/create_retraining_workloads.py: -------------------------------------------------------------------------------- 1 | import json 2 | from cross_db_benchmark.benchmark_tools.generate_workload import generate_workload 3 | from classes.paths import LocalPaths 4 | 5 | if __name__ == '__main__': 6 | # Read out column statistics to later filter out non-numerical columns. 7 | imdb_path = LocalPaths().dataset_path / "imdb" / "column_statistics.json" 8 | with open(imdb_path, 'r') as f: 9 | imdb_schema = json.load(f) 10 | print(imdb_schema) 11 | for workload in ["imdb"]: 12 | # Some queries will have timeouts, so generating more and delete again 13 | index_target_path = LocalPaths().workloads / 'retraining' / workload / 'index_retraining.sql' 14 | seq_target_path = LocalPaths().workloads / 'retraining' / workload / 'seq_retraining.sql' 15 | index_queries = [] 16 | seq_queries = [] 17 | cap = 1000 18 | 19 | queries = generate_workload(dataset=workload, 20 | target_path=index_target_path, 21 | num_queries=2000, 22 | min_no_predicates=1, 23 | max_no_predicates=2, 24 | max_no_aggregates=0, 25 | max_no_group_by=0, 26 | max_no_joins=0, 27 | max_cols_per_agg=1, 28 | seed=1, 29 | force=True) 30 | 31 | failing_queries = [] 32 | for query in queries: 33 | try: 34 | table = query.split('FROM')[1].split(' ')[1].replace('"', '') 35 | filter_column = query.split('WHERE')[1].split(' ')[1].replace('"', '') 36 | filter_column = filter_column.split('.')[1] 37 | 38 | datatype = imdb_schema[table][filter_column]['datatype'] 39 | if datatype in ['int', 'float']: 40 | index_name = f"idx_{table}_{filter_column}" 41 | seq_queries.append(f"/*+SeqScan({table})*/ " + query) 42 | index_queries.append(f"/*+IndexScan({table} {index_name})*/ " + query) 43 | 44 | except IndexError as e: 45 | print("Erroneous query: ", {query}) 46 | 47 | with open(index_target_path, "w") as text_file: 48 | text_file.write('\n'.join(index_queries[0:cap])) 49 | 50 | with open(seq_target_path, "w") as text_file: 51 | text_file.write('\n'.join(seq_queries[0:cap])) 52 | 53 | print(f'Generated {len(seq_queries[0:cap]) + len(index_queries[0:cap])} queries for {workload}') 54 | -------------------------------------------------------------------------------- /src/experiments/data/statistics/dataset_stats.csv: -------------------------------------------------------------------------------- 1 | dataset_name,no_tables,no_relationships,no_columns,size_gb 2 | 
airline,16,27,113,3.278603225015104 3 | imdb,15,14,82,3.5750512555241585 4 | ssb,4,3,41,1.1496225325390697 5 | tpc_h,8,7,61,1.0170345567166805 6 | walmart,3,2,7,0.08105922862887383 7 | financial,8,7,55,0.25412818044424057 8 | basketball,9,8,195,0.9521948173642159 9 | accidents,3,2,43,0.11672541964799166 10 | movielens,7,6,24,0.13860997557640076 11 | baseball,25,24,353,0.31771184876561165 12 | hepatitis,7,6,26,0.37561357021331787 13 | tournament,9,8,106,0.4658178426325321 14 | credit,8,7,73,0.4229858284816146 15 | employee,6,5,24,0.3953401604667306 16 | consumer,3,2,23,0.47641230560839176 17 | geneea,19,18,128,0.4906308250501752 18 | genome,6,5,20,0.42964373901486397 19 | carcinogenesis,6,5,23,0.49948817677795887 20 | seznam,4,3,14,0.13277561776340008 21 | fhnk,3,2,24,0.13427039235830307 22 | -------------------------------------------------------------------------------- /src/experiments/evaluation_workloads/res/fkindexes.sql: -------------------------------------------------------------------------------- 1 | create index company_id_movie_companies on movie_companies(company_id); 2 | create index company_type_id_movie_companies on movie_companies(company_type_id); 3 | create index info_type_id_movie_info_idx on movie_info_idx(info_type_id); 4 | create index info_type_id_movie_info on movie_info(info_type_id); 5 | create index info_type_id_person_info on person_info(info_type_id); 6 | create index keyword_id_movie_keyword on movie_keyword(keyword_id); 7 | create index kind_id_aka_title on aka_title(kind_id); 8 | create index kind_id_title on title(kind_id); 9 | create index linked_movie_id_movie_link on movie_link(linked_movie_id); 10 | create index link_type_id_movie_link on movie_link(link_type_id); 11 | create index movie_id_aka_title on aka_title(movie_id); 12 | create index movie_id_cast_info on cast_info(movie_id); 13 | create index movie_id_complete_cast on complete_cast(movie_id); 14 | create index movie_id_movie_companies on movie_companies(movie_id); 15 | create index movie_id_movie_info_idx on movie_info_idx(movie_id); 16 | create index movie_id_movie_keyword on movie_keyword(movie_id); 17 | create index movie_id_movie_link on movie_link(movie_id); 18 | create index movie_id_movie_info on movie_info(movie_id); 19 | create index person_id_aka_name on aka_name(person_id); 20 | create index person_id_cast_info on cast_info(person_id); 21 | create index person_id_person_info on person_info(person_id); 22 | create index person_role_id_cast_info on cast_info(person_role_id); 23 | create index role_id_cast_info on cast_info(role_id); 24 | vacuum analyze; -------------------------------------------------------------------------------- /src/experiments/setup/postgres/tune_hyperparameters.py: -------------------------------------------------------------------------------- 1 | from experiments.setup.utils import strip_commands 2 | 3 | def gen_tune_commands(study_name=None, 4 | workload_runs=None, 5 | statistics_file='../zero-shot-data/runs/parsed_plans/statistics_workload_10k_s0_c8220.json', 6 | n_trials=10, 7 | n_workers=16, 8 | db_host='c05.lab', 9 | db_user='postgres', 10 | db_password='postgres', 11 | cardinalities='actual', 12 | database=None, 13 | max_epoch_tuples=10000, 14 | ): 15 | workload_runs = ' '.join(workload_runs) 16 | assert study_name is not None 17 | exp_commands = [f"""python3 tune.py 18 | --workload_runs {workload_runs} 19 | --statistics_file {statistics_file} 20 | --target ../zero-shot-data/tuning/{study_name} 21 | [device_placeholder] 22 | --database {str(database)} 
23 | --num_workers {n_workers} 24 | --db_user {db_user} 25 | --db_password {db_password} 26 | --db_host {db_host} 27 | --study_name {study_name} 28 | --cardinalities {cardinalities} 29 | --max_epoch_tuples {max_epoch_tuples} 30 | --n_trials 1 31 | --setup distributed 32 | --seed {i} 33 | """ for i in range(n_trials)] 34 | 35 | exp_commands = strip_commands(exp_commands) 36 | return exp_commands 37 | -------------------------------------------------------------------------------- /src/experiments/setup/utils.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | 4 | def strip_single_command(cmd): 5 | cmd = cmd.replace('\n', ' ') 6 | regex = re.compile(r"\s+", re.IGNORECASE) 7 | cmd = regex.sub(" ", cmd) 8 | return cmd 9 | 10 | 11 | def strip_commands(exp_commands): 12 | exp_commands = [strip_single_command(cmd) for cmd in exp_commands] 13 | return exp_commands 14 | -------------------------------------------------------------------------------- /src/gather_feature_stats.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import glob 4 | 5 | from training.preprocessing.feature_statistics import gather_feature_statistics 6 | 7 | 8 | def get_all_json_files(directory): 9 | # Use os.path.join to make the path OS-independent 10 | if os.path.isfile(directory) and directory.endswith('.json'): 11 | json_files = [directory] 12 | else: 13 | path = os.path.join(directory, '**', '*.json') 14 | # Use glob.glob with the recursive flag set to True to find all .json files in the directory and its subdirectories 15 | json_files = glob.glob(path, recursive=True) 16 | json_files = [file for file in json_files if "feature_statistics" not in file] 17 | print(json_files) 18 | return json_files 19 | 20 | if __name__ == '__main__': 21 | parser = argparse.ArgumentParser() 22 | parser.add_argument('--database', default=None) 23 | parser.add_argument('--workload', default=None, nargs="*") 24 | parser.add_argument('--target', default=None) 25 | args = parser.parse_args() 26 | json_files = [] 27 | for workload in args.workload: 28 | json_files += get_all_json_files(f'../data/runs/json/{args.database}/{workload}') 29 | print(len(json_files)) 30 | gather_feature_statistics(json_files, args.target) 31 | -------------------------------------------------------------------------------- /src/models/dace/dace_model.py: -------------------------------------------------------------------------------- 1 | import loralib as lora 2 | import torch 3 | import torch.nn as nn 4 | 5 | from classes.classes import DACEModelConfig 6 | from training import losses 7 | 8 | 9 | class DACELora(nn.Module): 10 | """# create DACE model with lora""" 11 | def __init__(self, config: DACEModelConfig): 12 | super(DACELora, self).__init__() 13 | self.label_norm = None 14 | self.device = config.device 15 | self.config = config 16 | self.loss_fxn = losses.__dict__[config.loss_class_name](self, **config.loss_class_kwargs) 17 | 18 | self.transformer_encoder = nn.TransformerEncoder( 19 | nn.TransformerEncoderLayer( 20 | d_model=config.node_length, 21 | dim_feedforward=config.hidden_dim, 22 | nhead=1, 23 | batch_first=True, 24 | activation=config.transformer_activation, 25 | dropout=config.transformer_dropout), 26 | num_layers=1) 27 | 28 | self.node_length = config.node_length 29 | if config.mlp_activation == "ReLU": 30 | self.mlp_activation = nn.ReLU() 31 | elif config.mlp_activation == "GELU": 32 | self.mlp_activation = nn.GELU() 33 | 
elif config.mlp_activation == "LeakyReLU": 34 | self.mlp_activation = nn.LeakyReLU() 35 | self.mlp_hidden_dims = [128, 64, 1] 36 | 37 | self.mlp = nn.Sequential( 38 | *[lora.Linear(self.node_length, self.mlp_hidden_dims[0], r=16), 39 | nn.Dropout(config.mlp_dropout), 40 | self.mlp_activation, 41 | lora.Linear(self.mlp_hidden_dims[0], self.mlp_hidden_dims[1], r=8), 42 | nn.Dropout(config.mlp_dropout), 43 | self.mlp_activation, 44 | lora.Linear(self.mlp_hidden_dims[1], config.output_dim, r=4)]) 45 | 46 | self.sigmoid = nn.Sigmoid() 47 | 48 | def forward_batch(self, x, attn_mask=None) -> torch.Tensor: 49 | # change x shape to (batch, seq_len, input_size) from (batch, len) 50 | # one node is 18 bits 51 | x = x.view(x.shape[0], -1, self.node_length) 52 | out = self.transformer_encoder(x, mask=attn_mask) 53 | out = self.mlp(out) 54 | out = self.sigmoid(out).squeeze(dim=2) 55 | return out 56 | 57 | def forward(self, x, attn_mask=None): 58 | seq_encodings, attention_masks, loss_masks, real_run_times = x 59 | self.loss_fxn.loss_masks = loss_masks 60 | self.loss_fxn.real_run_times = real_run_times 61 | preds = self.forward_batch(seq_encodings, attention_masks) 62 | self.loss_fxn.preds = preds # we append the full prediction to the loss function 63 | predicted_runtimes = preds[:, 0] 64 | predicted_runtimes = predicted_runtimes * self.config.max_runtime / 1000 65 | return predicted_runtimes 66 | -------------------------------------------------------------------------------- /src/models/dace/dace_utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | import numpy as np 3 | import torch 4 | 5 | 6 | def load_json(path): 7 | with open(path) as json_file: 8 | json_obj = json.load(json_file) 9 | return json_obj 10 | 11 | 12 | def set_seed(seed): 13 | torch.manual_seed(seed) 14 | torch.cuda.manual_seed_all(seed) 15 | np.random.seed(seed) 16 | -------------------------------------------------------------------------------- /src/models/workload_driven/model/tree_lstm.py: -------------------------------------------------------------------------------- 1 | import dgl.function as fn 2 | import torch 3 | from torch import nn 4 | 5 | 6 | class LstmConv(nn.Module): 7 | 8 | def __init__(self, hidden_dim=4): 9 | super().__init__() 10 | self.message_dim = hidden_dim // 2 11 | self.lstm_input_dim = hidden_dim 12 | # input size, hidden size 13 | self.lstm = nn.LSTMCell(self.lstm_input_dim, self.message_dim) 14 | 15 | def forward(self, graph=None, etypes=None, in_node_types=None, out_node_types=None, feat_dict=None): 16 | with graph.local_scope(): 17 | graph.ndata['h'] = feat_dict 18 | 19 | x_t = graph.ndata['h'] 20 | if len(etypes) > 0: 21 | # G_t and R_t are left and right halfs of hidden feature vectors 22 | graph.multi_update_all({etype: (fn.copy_u('h', 'm'), fn.mean('m', 'ft')) for etype in etypes}, 23 | cross_reducer='mean') 24 | # nodes without children have initialization zero which is correct 25 | rst = graph.ndata['ft'] 26 | 27 | else: 28 | rst = {n_type: torch.zeros_like(feat_dict[n_type]) for n_type in out_node_types} 29 | 30 | assert len(out_node_types) == 1 31 | out_type = list(out_node_types)[0] 32 | assert rst[out_type].shape[1] == self.message_dim * 2 33 | # slice in half to obtain G_t and R_t 34 | G_t = rst[out_type][:, :self.message_dim] 35 | R_t = rst[out_type][:, self.message_dim:] 36 | x_t = x_t[out_type] 37 | 38 | assert G_t.size(1) == R_t.size(1) == self.message_dim 39 | assert x_t.size(1) == self.lstm_input_dim 40 | G_t_2, R_t_2 = 
self.lstm(x_t, (G_t, R_t)) 41 | new_hidden = torch.cat((G_t_2, R_t_2), dim=1) 42 | 43 | out_dict = {out_type: new_hidden} 44 | 45 | return out_dict 46 | -------------------------------------------------------------------------------- /src/models/workload_driven/preprocessing/word_embeddings.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import multiprocessing 3 | import os 4 | import time 5 | 6 | from gensim.models import Word2Vec 7 | 8 | from cross_db_benchmark.benchmark_tools.utils import load_json 9 | 10 | 11 | def compute_word_embeddings(source, target): 12 | logging.basicConfig(format="%(levelname)s - %(asctime)s: %(message)s", datefmt='%H:%M:%S', level=logging.INFO) 13 | 14 | os.makedirs(os.path.dirname(target), exist_ok=True) 15 | sentences = load_json(source) 16 | print(f"Constructing word embeddings for {len(sentences)} sentences") 17 | 18 | # hyperparameters taken from Sun et al. 19 | cores = multiprocessing.cpu_count() 20 | w2v_model = Word2Vec(min_count=5, 21 | window=5, 22 | vector_size=500, 23 | alpha=0.03, 24 | min_alpha=0.0007, 25 | negative=20, 26 | workers=cores - 2) 27 | 28 | t = time.perf_counter() 29 | w2v_model.build_vocab(sentences, progress_per=10000) 30 | print(f"Time to build vocab: {(time.perf_counter() - t) / 60:.2f} mins") 31 | 32 | t = time.perf_counter() 33 | w2v_model.train(sentences, total_examples=w2v_model.corpus_count, epochs=30, report_delay=1) 34 | print(f"Time to train the model: {(time.perf_counter() - t) / 60:.2f} mins") 35 | 36 | # w2v_model.wv['Carrier_EV'] 37 | w2v_model.wv.save(target) 38 | print('model saved') 39 | -------------------------------------------------------------------------------- /src/models/workload_driven/tests/mscn/test_message_passing.py: -------------------------------------------------------------------------------- 1 | import dgl 2 | import numpy as np 3 | 4 | from models.workload_driven.dataset.plan_tree_batching import tensorize_feats 5 | from models.workload_driven.model.mscn_model import MSCNModel 6 | 7 | 8 | def test_plan_message(): 9 | # averaging of nodes 10 | m = MSCNModel(hidden_dim=2, input_dim_table=2, input_dim_pred=2, input_dim_join=2, 11 | input_dim_agg=2, loss_class_name='QLoss', loss_class_kwargs=dict(), device='cpu') 12 | 13 | num_nodes_dict = { 14 | 'plan': 2, 15 | 'table': 4, 16 | 'pred': 4, 17 | 'agg': 4, 18 | 'join': 4, 19 | } 20 | features = dict() 21 | for i, k in enumerate(m.pool_node_types): 22 | if k == 'plan': 23 | continue 24 | features[k] = [[1 + i, i], [i, 1 + i], [1 + i, i], [i, 1 + i]] 25 | 26 | edge_dict = { 27 | etype: [(0, 0), (1, 0), (2, 1), (3, 1)] 28 | for etype in [('table', 'table_plan', 'plan'), ('pred', 'pred_plan', 'plan'), ('agg', 'agg_plan', 'plan'), 29 | ('join', 'join_plan', 'plan')] 30 | } 31 | 32 | graph = dgl.heterograph(edge_dict, num_nodes_dict=num_nodes_dict) 33 | features = tensorize_feats(features) 34 | 35 | out = m.pool_set_features(graph, features) 36 | exp_out = np.array([[0.5000, 0.5000, 1.5000, 1.5000, 2.5000, 2.5000, 3.5000, 3.5000], 37 | [0.5000, 0.5000, 1.5000, 1.5000, 2.5000, 2.5000, 3.5000, 3.5000]]) 38 | assert np.allclose(exp_out, out.numpy()) 39 | -------------------------------------------------------------------------------- /src/models/zeroshot/message_aggregators/aggregator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from models.zeroshot.utils.fc_out_model import FcOutModel 4 | 5 | 6 | class MessageAggregator(FcOutModel): 7 | 
""" 8 | Abstract message aggregator class. Combines child messages (either using MSCN or GAT) and afterward combines the 9 | hidden state with the aggregated messages using an MLP. 10 | """ 11 | 12 | def __init__(self, test=False, **fc_out_kwargs): 13 | super().__init__(**fc_out_kwargs) 14 | self.test = test 15 | 16 | def forward(self, graph=None, etypes=None, out_node_types=None, feat_dict=None): 17 | raise NotImplementedError 18 | 19 | def combine(self, feat, out_node_types, rst): 20 | out_dict = dict() 21 | for out_type in out_node_types: 22 | if out_type in feat and out_type in rst: 23 | assert feat[out_type].shape[0] == rst[out_type].shape[0] 24 | # send through fully connected layers 25 | if not self.test: 26 | out_dict[out_type] = self.fcout(torch.cat([feat[out_type], rst[out_type]], dim=1)) 27 | # simply add in debug mode for testing 28 | else: 29 | out_dict[out_type] = feat[out_type] + rst[out_type] 30 | return out_dict 31 | -------------------------------------------------------------------------------- /src/models/zeroshot/message_aggregators/message_aggregators.py: -------------------------------------------------------------------------------- 1 | from models.zeroshot.message_aggregators.gat import GATConv 2 | from models.zeroshot.message_aggregators.mscn import MscnConv 3 | 4 | # different models supported for message aggregation 5 | MscnConv 6 | GATConv -------------------------------------------------------------------------------- /src/models/zeroshot/message_aggregators/mscn.py: -------------------------------------------------------------------------------- 1 | import dgl.function as fn 2 | 3 | from models.zeroshot.message_aggregators.aggregator import MessageAggregator 4 | 5 | 6 | class MscnConv(MessageAggregator): 7 | """ 8 | A message aggregator that sums up child messages and afterward combines them with the current hidden state of the 9 | parent node using an MLP 10 | """ 11 | 12 | def __init__(self, hidden_dim=0, **kwargs): 13 | super().__init__(input_dim=2 * hidden_dim, output_dim=hidden_dim, **kwargs) 14 | 15 | def forward(self, graph=None, etypes=None, in_node_types=None, out_node_types=None, feat_dict=None): 16 | if len(etypes) == 0: 17 | return feat_dict 18 | 19 | with graph.local_scope(): 20 | graph.ndata['h'] = feat_dict 21 | 22 | # message passing 23 | graph.multi_update_all({etype: (fn.copy_u('h', 'm'), fn.sum('m', 'ft')) for etype in etypes}, 24 | cross_reducer='sum') 25 | 26 | feat = graph.ndata['h'] 27 | rst = graph.ndata['ft'] 28 | 29 | out_dict = self.combine(feat, out_node_types, rst) 30 | return out_dict 31 | -------------------------------------------------------------------------------- /src/models/zeroshot/message_aggregators/pooling.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | 3 | import dgl.function as fn 4 | from torch import nn 5 | 6 | 7 | class PoolingType(Enum): 8 | MIN = 'min' 9 | MAX = 'max' 10 | MEAN = 'mean' 11 | SUM = 'sum' 12 | 13 | def __str__(self): 14 | return self.value 15 | 16 | 17 | class PoolingConv(nn.Module): 18 | """ 19 | A pooling convolution used to combine multiple messages during message passing. 
20 | """ 21 | 22 | def __init__(self, pooling_type=None): 23 | super().__init__() 24 | self.pooling_type = pooling_type 25 | 26 | # the different pooling operations which are supported are mapped to DGL functions 27 | dgl_mapping = { 28 | PoolingType.MIN: fn.min, 29 | PoolingType.MAX: fn.max, 30 | PoolingType.MEAN: fn.mean, 31 | PoolingType.SUM: fn.sum, 32 | } 33 | self.pooling_fn = dgl_mapping[self.pooling_type] 34 | 35 | def forward(self, graph=None, etypes=None, in_node_types=None, out_node_types=None, feat_dict=None): 36 | with graph.local_scope(): 37 | graph.ndata['h'] = feat_dict 38 | 39 | # message passing 40 | graph.multi_update_all({etype: (fn.copy_u('h', 'm'), self.pooling_fn('m', 'ft')) for etype in etypes}, 41 | cross_reducer=str(self.pooling_type)) 42 | 43 | feat = graph.ndata['h'] 44 | rst = graph.ndata['ft'] 45 | 46 | # simply add to avoid overwriting old values 47 | out_dict = self.combine(feat, out_node_types, rst) 48 | return out_dict 49 | 50 | def combine(self, feat, out_node_types, rst): 51 | out_dict = dict() 52 | for out_type in out_node_types: 53 | if out_type in feat and out_type in rst: 54 | out_dict[out_type] = feat[out_type] + rst[out_type] 55 | return out_dict 56 | -------------------------------------------------------------------------------- /src/models/zeroshot/specific_models/model.py: -------------------------------------------------------------------------------- 1 | from cross_db_benchmark.benchmark_tools.database import DatabaseSystem 2 | from models.zeroshot.specific_models.postgres_zero_shot import PostgresZeroShotModel 3 | 4 | # dictionary with tailored model for each database system (we learn one model per system that generalizes across 5 | # databases (i.e., datasets) but on the same database system) 6 | zero_shot_models = { 7 | DatabaseSystem.POSTGRES: PostgresZeroShotModel, 8 | } 9 | -------------------------------------------------------------------------------- /src/models/zeroshot/specific_models/postgres_zero_shot.py: -------------------------------------------------------------------------------- 1 | from classes.classes import ZeroShotModelConfig 2 | from models.zeroshot.zero_shot_model import ZeroShotModel 3 | 4 | 5 | class PostgresZeroShotModel(ZeroShotModel): 6 | """ 7 | Zero-shot cost estimation model for postgres. 
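It instantiates per-node-type encoders (column, table, output_column, filter_column, plan, logical_pred) from the plan featurization and adds a Postgres-specific prepass from columns to output columns.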
8 | """ 9 | def __init__(self, model_config: ZeroShotModelConfig, feature_statistics: dict, **zero_shot_kwargs): 10 | plan_featurization, encoders = None, None 11 | if model_config.featurization is not None: 12 | plan_featurization = model_config.featurization 13 | 14 | # define the MLPs for the different node types in the graph representation of queries 15 | encoders = [ 16 | ('column', plan_featurization.COLUMN_FEATURES), 17 | ('table', plan_featurization.TABLE_FEATURES), 18 | ('output_column', plan_featurization.OUTPUT_COLUMN_FEATURES), 19 | ('filter_column', plan_featurization.FILTER_FEATURES + plan_featurization.COLUMN_FEATURES), 20 | ('plan', plan_featurization.PLAN_FEATURES), 21 | ('logical_pred', plan_featurization.FILTER_FEATURES), 22 | ] 23 | 24 | # define messages passing which is peculiar for postgres 25 | prepasses = [dict(model_name='column_output_column', e_name='col_output_col')] 26 | tree_model_types = ['column_output_column'] 27 | 28 | super().__init__(plan_featurization=plan_featurization, 29 | encoders=encoders, 30 | prepasses=prepasses, 31 | add_tree_model_types=tree_model_types, 32 | model_config=model_config, 33 | feature_statistics=feature_statistics) -------------------------------------------------------------------------------- /src/models/zeroshot/utils/activations.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | from torch.nn import LeakyReLU, CELU, SELU 3 | from torch.nn import functional as F 4 | 5 | LeakyReLU 6 | CELU 7 | SELU 8 | 9 | 10 | class GELU(nn.Module): 11 | def forward(self, input): 12 | return F.gelu(input) 13 | -------------------------------------------------------------------------------- /src/models/zeroshot/utils/embeddings.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class EmbeddingInitializer(nn.Module): 6 | """ 7 | Wrapper to generate a learnable embedding. Used whenever a categorical variable should be represented in zero-shot 8 | models. 9 | """ 10 | 11 | def __init__(self, num_embeddings, max_emb_dim, p_dropout, minimize_emb_dim=True, drop_whole_embeddings=False, 12 | one_hot=False): 13 | """ 14 | :param minimize_emb_dim: 15 | Whether to set embedding_dim = max_emb_dim or to make embedding_dim smaller is num_embeddings is small 16 | :param drop_whole_embeddings: 17 | If True, dropout pretends the embedding was a missing value. If false, dropout sets embed features to 0 18 | :param drop_whole_embeddings: 19 | If True, one-hot encode variables whose cardinality is < max_emb_dim. 
Also, set reqiures_grad = False 20 | """ 21 | super().__init__() 22 | self.p_dropout = p_dropout 23 | self.drop_whole_embeddings = drop_whole_embeddings 24 | if minimize_emb_dim: 25 | self.emb_dim = min(max_emb_dim, num_embeddings) # Don't use a crazy huge embedding if not needed 26 | else: 27 | self.emb_dim = max_emb_dim 28 | # Note: if you change the name of self.embed, or initialize an embedding elsewhere in a model, 29 | # the function get_optim_no_wd_on_embeddings will not work properly 30 | self.embed = nn.Embedding(num_embeddings=num_embeddings, embedding_dim=self.emb_dim) 31 | self.embed.weight.data.clamp_(-2, 2) # fastai uses truncated normal init 32 | 33 | if one_hot: 34 | if num_embeddings <= max_emb_dim: 35 | torch.eye(self.emb_dim, out=self.embed.weight.data) 36 | self.do = nn.Dropout(p=p_dropout) 37 | 38 | def forward(self, input): 39 | if self.drop_whole_embeddings and self.training: 40 | mask = torch.zeros_like(input).bernoulli_(1 - self.p_dropout) 41 | input = input * mask 42 | out = self.embed(input) 43 | if not self.drop_whole_embeddings: 44 | out = self.do(out) 45 | return out 46 | -------------------------------------------------------------------------------- /src/models/zeroshot/utils/node_type_encoder.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | import numpy as np 4 | import torch 5 | from torch import nn 6 | 7 | from training.preprocessing.feature_statistics import FeatureType 8 | from models.zeroshot.utils.embeddings import EmbeddingInitializer 9 | from models.zeroshot.utils.fc_out_model import FcOutModel 10 | 11 | 12 | class NodeTypeEncoder(FcOutModel): 13 | """ 14 | Model to encode one type of nodes in the graph (with particular features) 15 | """ 16 | 17 | def __init__(self, features: List[str], feature_statistics: dict, max_emb_dim=32, drop_whole_embeddings=False, 18 | one_hot_embeddings=True, **kwargs): 19 | 20 | for f in features: 21 | if f not in feature_statistics: 22 | raise ValueError(f"Did not find {f} in feature statistics") 23 | 24 | self.features = features 25 | self.feature_types = [FeatureType[feature_statistics[feat]['type']] for feat in features] 26 | self.feature_idxs = [] 27 | 28 | # initialize embeddings and input dimension 29 | 30 | self.input_dim = 0 31 | self.input_feature_idx = 0 32 | embeddings = dict() 33 | for i, (feat, type) in enumerate(zip(self.features, self.feature_types)): 34 | if type == FeatureType.numeric: 35 | # a single value is encoded here 36 | self.feature_idxs.append(np.arange(self.input_feature_idx, self.input_feature_idx + 1)) 37 | self.input_feature_idx += 1 38 | 39 | self.input_dim += 1 40 | elif type == FeatureType.categorical: 41 | # similarly, a single value is encoded here 42 | self.feature_idxs.append(np.arange(self.input_feature_idx, self.input_feature_idx + 1)) 43 | self.input_feature_idx += 1 44 | 45 | embd = EmbeddingInitializer(num_embeddings=feature_statistics[feat]['no_vals'], 46 | max_emb_dim=max_emb_dim, 47 | p_dropout=kwargs['p_dropout'], 48 | drop_whole_embeddings=drop_whole_embeddings, 49 | one_hot=one_hot_embeddings) 50 | embeddings[feat] = embd 51 | self.input_dim += embd.emb_dim 52 | else: 53 | raise NotImplementedError 54 | super().__init__(input_dim=self.input_dim, **kwargs) 55 | 56 | self.embeddings = nn.ModuleDict(embeddings) 57 | 58 | def forward(self, input): 59 | if self.no_input_required: 60 | return self.replacement_param.repeat(input.shape[0], 1) 61 | 62 | assert input.shape[1] == self.input_feature_idx 63 | 
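# encode each feature according to its type: numeric values pass through unchanged, categorical values are mapped to learned embeddings, then everything is concatenated and fed through the fully connected output layers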
encoded_input = [] 64 | for feat, feat_type, feat_idxs in zip(self.features, self.feature_types, self.feature_idxs): 65 | feat_data = input[:, feat_idxs] 66 | 67 | if feat_type == FeatureType.numeric: 68 | encoded_input.append(feat_data) 69 | elif feat_type == FeatureType.categorical: 70 | feat_data = torch.reshape(feat_data, (-1,)) 71 | embd_data = self.embeddings[feat](feat_data.long()) 72 | encoded_input.append(embd_data) 73 | else: 74 | raise NotImplementedError 75 | 76 | input_enc = torch.cat(encoded_input, dim=1) 77 | 78 | return self.fcout(input_enc) 79 | -------------------------------------------------------------------------------- /src/parse_all.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import multiprocessing 3 | import multiprocessing as mp 4 | import os 5 | import time 6 | from os.path import relpath 7 | 8 | import pandas as pd 9 | 10 | from cross_db_benchmark.benchmark_tools.database import DatabaseSystem 11 | from cross_db_benchmark.benchmark_tools.parse_run import parse_run 12 | 13 | 14 | def compute(input): 15 | source, target, d, wl, parse_baseline, cap_queries, max_query_ms, include_zero_card, explain_only = input 16 | if parse_baseline: 17 | parse_join_conds = False 18 | else: 19 | parse_join_conds = True 20 | 21 | no_plans, stats = parse_run(source, target, DatabaseSystem.POSTGRES, min_query_ms=0, cap_queries=cap_queries, 22 | parse_baseline=parse_baseline, parse_join_conds=parse_join_conds, max_query_ms=max_query_ms, 23 | include_zero_card=include_zero_card, explain_only=explain_only) 24 | return dict(dataset=d, workload=wl, no_plans=no_plans, **stats) 25 | 26 | 27 | if __name__ == '__main__': 28 | parser = argparse.ArgumentParser() 29 | parser.add_argument('--raw_dir', default=None) 30 | parser.add_argument('--parsed_plan_dir', default=None) 31 | parser.add_argument('--parsed_plan_dir_baseline', default=None) 32 | parser.add_argument('--combine', default=None) 33 | parser.add_argument('--target_stats_path', default=None) 34 | parser.add_argument('--workload_prefix', default='') 35 | parser.add_argument('--workloads', nargs='+', default=None) 36 | parser.add_argument('--min_query_ms', default=100, type=int) 37 | parser.add_argument('--max_query_ms', default=30000, type=int) # 30s 38 | parser.add_argument('--cap_queries', default=5000, type=int) 39 | parser.add_argument('--include_zero_card', action='store_true') 40 | parser.add_argument('--explain_only', action='store_true') 41 | parser.add_argument('--database', default=DatabaseSystem.POSTGRES, type=DatabaseSystem, 42 | choices=list(DatabaseSystem)) 43 | args = parser.parse_args() 44 | 45 | cap_queries = args.cap_queries 46 | if cap_queries == 'None': 47 | cap_queries = None 48 | 49 | explain_only = False 50 | if args.explain_only: 51 | explain_only = True 52 | setups = [] 53 | for path, subdirs, files in os.walk(args.raw_dir): 54 | for workload_name in files: 55 | wl_path = relpath(path, args.raw_dir) 56 | source = os.path.join(args.raw_dir, wl_path, workload_name) #, path, workload_name) 57 | target = os.path.join(args.parsed_plan_dir, wl_path, workload_name) #, path, workload_name) 58 | 59 | setups.append((source, target, "postgres", workload_name, False, cap_queries, args.max_query_ms, args.include_zero_card, explain_only)) 60 | 61 | target = os.path.join(args.parsed_plan_dir_baseline, wl_path, workload_name) 62 | setups.append((source, target, "postgres", workload_name, True, cap_queries, args.max_query_ms, args.include_zero_card, explain_only)) 
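# parse all collected run files in parallel with a multiprocessing pool and aggregate the resulting per-workload plan statistics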
63 | 64 | 65 | start_t = time.perf_counter() 66 | proc = multiprocessing.cpu_count() - 2 67 | p = mp.Pool(initargs=('arg',), processes=proc) 68 | eval = p.map(compute, setups) 69 | 70 | eval = pd.DataFrame(eval) 71 | if args.target_stats_path is not None: 72 | eval.to_csv(args.target_stats_path, index=False) 73 | 74 | print() 75 | print(eval[['dataset', 'workload', 'no_plans']].to_string(index=False)) 76 | 77 | print() 78 | print(eval[['workload', 'no_plans']].groupby('workload').sum().to_string()) 79 | 80 | print() 81 | print(f"Total plans parsed in {time.perf_counter() - start_t:.2f} secs: {eval.no_plans.sum()}") 82 | -------------------------------------------------------------------------------- /src/predict.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import List 3 | 4 | from dgl.dataloading import DataLoader 5 | from torch import nn 6 | from wandb.apis.public import Run 7 | 8 | from classes.classes import ModelConfig 9 | from classes.workload_runs import WorkloadRuns 10 | from training.training.checkpoint import save_csv 11 | from training.training.metrics import Metric 12 | from training.training.train import validate_model 13 | from training.training.utils import find_early_stopping_metric 14 | 15 | 16 | def predict_model(mode: str, 17 | config: ModelConfig, 18 | test_loaders: List[DataLoader], 19 | workload_runs: WorkloadRuns, 20 | model_dir: Path, 21 | metrics: List[Metric], 22 | model: nn.Module, 23 | epoch: int, 24 | run: Run): 25 | 26 | if test_loaders is not None: 27 | if not (model_dir is None or config.name.NAME is None): 28 | for test_path, test_loader in zip(workload_runs.target_test_csv_paths, test_loaders): 29 | print(f"Starting validation for {test_path}") 30 | test_stats = dict() 31 | 32 | # In case of retraining, do not load the totally best model but the latest one 33 | if mode != "retrain": 34 | early_stop_m = find_early_stopping_metric(metrics) 35 | model.load_state_dict(early_stop_m.best_model) 36 | 37 | validate_model(config=config, 38 | val_loader=test_loader, 39 | model=model, 40 | epoch=epoch, 41 | epoch_stats=test_stats, 42 | metrics=metrics, 43 | custom_batch_to=config.batch_to_func, 44 | log_all_queries=True, 45 | run=run, 46 | model_dir=model_dir, 47 | target_path=test_path) 48 | 49 | save_csv([test_stats], str(test_path) + "_test_stats.csv") 50 | else: 51 | print("Skipping saving the test stats") 52 | -------------------------------------------------------------------------------- /src/scripts/exp_runner/exp_osf_upload.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from octopus.step import Rsync, KillAllScreens 4 | from classes.classes import TrainingServers 5 | from classes.paths import LocalPaths, ClusterPaths 6 | 7 | from scripts.exp_runner.exp_runner import ExpRunner 8 | 9 | 10 | # This script is used to upload the submission data to the OSF project 11 | 12 | if __name__ == '__main__': 13 | setup_steps = [Rsync(src=[str(LocalPaths().workloads)], dest=[ClusterPaths().data], update=True, put=True), 14 | Rsync(src=[str(LocalPaths().runs)], dest=[ClusterPaths().data], update=True, put=True)] 15 | 16 | purge_steps = [KillAllScreens()] 17 | pickup_steps = [Rsync(src=[ClusterPaths().evaluation], dest=[LocalPaths().data], put=False, update=True)] 18 | 19 | node = TrainingServers().NODE03 20 | runner = ExpRunner(replicate=False, 21 | node_names=[node["hostname"]], 22 | root_path=ClusterPaths().root, 23 | 
python_version=node["python"]) 24 | 25 | username = os.getenv('OSF_USERNAME') 26 | project = os.getenv('OSF_SUBMISSION_PROJECT') 27 | password = os.getenv('OSF_PASSWORD') 28 | 29 | source_paths = [ClusterPaths().evaluation, ClusterPaths().models, ClusterPaths().runs] 30 | commands = [] 31 | for source in source_paths: 32 | target = str(source).replace(str(ClusterPaths().data), "") 33 | commands.append(f"OSF_PASSWORD={password} osf -p {project} -u {username} upload -r {source}/ {target} ") 34 | 35 | runner.run_exp(node_names=[node["hostname"]], 36 | commands=commands, 37 | set_up_steps=setup_steps, 38 | purge_steps=purge_steps, 39 | pickup_steps=pickup_steps, 40 | screens_per_node=1, 41 | offset=0) 42 | -------------------------------------------------------------------------------- /src/scripts/exp_runner/exp_remove_workload.py: -------------------------------------------------------------------------------- 1 | import shutil 2 | from typing import List 3 | 4 | from octopus.step import Step, Remove 5 | 6 | from cross_db_benchmark.datasets.datasets import Database 7 | from classes.classes import MODEL_CONFIGS, TrainingServers 8 | from classes.paths import LocalPaths, ClusterPaths 9 | from scripts.exp_runner.exp_runner import ExpRunner 10 | 11 | if __name__ == '__main__': 12 | """ This script removes a certain workload from localhost and cloudlab machines""" 13 | 14 | node = TrainingServers.NODE00 15 | 16 | runner = ExpRunner(replicate=True, 17 | node_names=[node['hostname']], 18 | python_version=node['python'], 19 | root_path=ClusterPaths().root) 20 | 21 | databases = [Database("tpc-h"), Database("imdb"), Database("baseball")] 22 | study = "agg_range_filter" 23 | 24 | for paths in [LocalPaths(), ClusterPaths()]: 25 | paths_to_remove = [] 26 | for database in databases: 27 | paths_to_remove += [ 28 | paths.augmented_plans_baseline / database.db_name / study, 29 | paths.parsed_plans / database.db_name / study, 30 | paths.parsed_plans_baseline / database.db_name / study, 31 | paths.json / database.db_name / study, 32 | paths.raw / database.db_name / study] 33 | 34 | for model in MODEL_CONFIGS: 35 | paths_to_remove.append(paths.evaluation / model.name.NAME / database.db_name / study) 36 | print(paths_to_remove) 37 | 38 | if isinstance(paths, LocalPaths): 39 | print("Removing from localhost") 40 | for path in paths_to_remove: 41 | shutil.rmtree(path, ignore_errors=True) 42 | 43 | if isinstance(paths, ClusterPaths): 44 | print("Removing from cluster") 45 | setup_steps: List[Step] = [] 46 | for path in paths_to_remove: 47 | setup_steps.append(Remove(files=str(path), directory=True)) 48 | #setup_steps: List[Step] = [Remove(files=" ".join([str(p) for p in paths_to_remove]), directory=True)] 49 | runner.run_exp(node_names=[node['hostname']], 50 | commands=[], 51 | set_up_steps=setup_steps, 52 | purge_steps=[], 53 | pickup_steps=[]) 54 | -------------------------------------------------------------------------------- /src/scripts/exp_runner/exp_setup.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | from octopus.step import Cmd, KillAllScreens, SetupVenv, LocalCmd, \ 4 | Step, Rsync 5 | 6 | from classes.paths import LocalPaths, CloudlabPaths 7 | from scripts.exp_runner.exp_runner import ExpRunner 8 | 9 | if __name__ == '__main__': 10 | """ 11 | This script sets-up various machines at cloudlab. 12 | Currently this needs to be executed twice as the target machines needs to be rebooted in between. 
13 | """ 14 | # Read nodenames from file 15 | with open(LocalPaths().node_list, 'r') as f: 16 | node_names = f.read().splitlines() 17 | 18 | runner = ExpRunner(replicate=True, node_names=node_names, python_version="3.9") 19 | postgres_version = "16" 20 | 21 | commands = [ 22 | f'python3.9 setup.py ' 23 | f'--data_dir ~/cost-eval/data/datasets/ ' 24 | f'--osf_username {runner.osf_username} ' 25 | f'--osf_password {runner.osf_password} ' 26 | f'--osf_project {runner.osf_project} ' 27 | f'--database_conn {runner.database_conn}' 28 | ] 29 | 30 | setup_steps: List[Step] = [ 31 | # Download server keys 32 | LocalCmd(cmd=' && '.join([f'ssh-keyscan -H {node} ' f'>> {LocalPaths().known_hosts}' for node in node_names])), 33 | # Rsync the current repository 34 | Rsync(src=[str(LocalPaths().code)], dest=[CloudlabPaths().root], update=True, put=True), 35 | Rsync(src=[str(LocalPaths().requirements)], dest=[CloudlabPaths().root], update=True, put=True), 36 | # Resize the disk if it is a cloudlab instance 37 | Cmd(cmd=f'{CloudlabPaths().code}/scripts/postgres_installation/resize_partition.sh'), 38 | # Continue resize the disk if it is a cloudlab instance 39 | Cmd(cmd=f'{CloudlabPaths().code}/scripts/postgres_installation/resize_partition_cont.sh'), 40 | # Install tools 41 | Cmd(cmd=f'{CloudlabPaths().code}/scripts/postgres_installation/install_tools.sh'), 42 | # Install postgres 43 | Cmd(cmd=f'{CloudlabPaths().code}/scripts/postgres_installation/install_postgres_{postgres_version}.sh'), 44 | # Setup virtual environment 45 | SetupVenv(use_requirements_txt=True, 46 | requirements_txt_filename=f'{CloudlabPaths().root}/requirements/requirements_cloudlab.txt', 47 | force=False, 48 | python_cmd=f'python{runner.python_version}', 49 | python_version=runner.python_version) 50 | ] 51 | 52 | purge_steps: List[Step] = [KillAllScreens(), 53 | Cmd(cmd='sudo service postgresql restart')] 54 | 55 | runner.run_exp(node_names=node_names, 56 | commands=commands, 57 | set_up_steps=setup_steps, 58 | purge_steps=purge_steps, 59 | pickup_steps=[]) 60 | -------------------------------------------------------------------------------- /src/scripts/misc/parse_cloudlab_manifest.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from xml.dom.minidom import parse 4 | 5 | 6 | def main(rspec_path: str, output_path: str, keypath: str, identifier: str = "node"): 7 | if not os.path.exists(rspec_path): 8 | raise FileNotFoundError("No rspec found at " + rspec_path) 9 | 10 | # parse rspec document 11 | rspec_document = parse(rspec_path) 12 | rspec_nodes = rspec_document.getElementsByTagName("node") 13 | 14 | # Obtain relevant node information from rspec nodes 15 | nodes_list = [] 16 | for node in rspec_nodes: 17 | node_info = dict() 18 | node_info["hostname"] = node.getElementsByTagName("services")[0].childNodes[1]._attrs['hostname']._value 19 | node_info["id"] = node.getAttributeNode("client_id").value 20 | node_info.update(node.getElementsByTagName("services")[0].getElementsByTagName("login")[0].attributes.items()) 21 | node_info["shortname"] = identifier + node_info["id"].split("_")[1] 22 | print("Found node with settings: " + str(node_info)) 23 | nodes_list.append(node_info) 24 | 25 | generate_sshconf(nodes_list, output_path, keypath) 26 | generate_hostname_list(nodes_list, output_path) 27 | 28 | 29 | def generate_hostname_list(nodes: list, output: str): 30 | path = output + "/hostnames" 31 | open(path, 'w').close() 32 | with open(path, 'a') as file: 33 | for 
node_info in nodes: 34 | file.write(node_info["hostname"] + "\n") 35 | 36 | 37 | def generate_sshconf(nodes: list, output: str, keypath: str): 38 | path = output + "/sshconfig" 39 | open(path, 'w').close() 40 | # write contents to file 41 | with open(path, 'a') as file: 42 | for node_info in nodes: 43 | file.write("Host " + node_info["shortname"] + "\n") 44 | file.write("Hostname " + node_info["hostname"] + "\n") 45 | file.write("User " + node_info["username"] + "\n") 46 | file.write("IdentityFile " + keypath + "\n") 47 | file.write("\n") 48 | 49 | 50 | if __name__ == "__main__": 51 | parser = argparse.ArgumentParser(description='Converting Cloud-Lab rspec to ssh-config file') 52 | parser.add_argument("--input", required=True, type=str, help="Path to manifest file") 53 | parser.add_argument("--output", required=True, type=str, help="Output directory where to store ssh-files") 54 | parser.add_argument("--identifier", required=True, type=str, 55 | help="SSH-identifier to name the cluster and the instances. " 56 | "Example: 'node' will give: node0, node1, ... as instance names") 57 | parser.add_argument("--keypath", required=True, type=str, help="Path to ssh-key") 58 | args = parser.parse_args() 59 | main(args.input, args.output, args.keypath, args.identifier) 60 | print("Successful!") -------------------------------------------------------------------------------- /src/scripts/postgres_installation/install_postgres.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -e 4 | set -x 5 | 6 | if [[ -e FLAG_INSTALL_DONE ]] 7 | then 8 | echo "skip installation" 9 | else 10 | 11 | wget --quiet -O - https://www.postgresql.org/media/keys/ACCC4CF8.asc | sudo apt-key add - 12 | echo "deb http://apt.postgresql.org/pub/repos/apt/ $(lsb_release -cs)-pgdg main" | sudo tee /etc/apt/sources.list.d/postgresql-pgdg.list > /dev/null 13 | sudo apt update 14 | sudo apt install -y postgresql-14 15 | sudo apt install -y postgresql-server-dev-14 16 | sudo -u postgres psql -c "ALTER USER postgres WITH PASSWORD 'bM2YGRAX*bG_QAilUid§2iD';" 17 | sudo service postgresql restart 18 | sudo apt install gcc 19 | sudo apt install make 20 | wget https://github.com/ossc-db/pg_hint_plan/archive/refs/tags/REL14_1_4_0.tar.gz 21 | tar xzvf REL14_1_4_0.tar.gz 22 | cd pg_hint_plan-REL14_1_4_0 23 | make 24 | sudo make install 25 | cd .. 
26 | rm REL14_1_4_0.tar.gz
27 | sudo service postgresql restart
28 | sudo cp cost-eval/src/conf/postgres/modified-postgresql14.conf /etc/postgresql/14/main/postgresql.conf
29 | sudo cp cost-eval/src/conf/postgres/pg_hba.conf /etc/postgresql/14/main/pg_hba.conf
30 | sudo service postgresql restart
31 | touch FLAG_INSTALL_DONE
32 | fi
--------------------------------------------------------------------------------
/src/scripts/postgres_installation/install_postgres_10.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | set -e
4 | set -x
5 |
6 | if [[ -e FLAG_INSTALL_DONE ]]
7 | then
8 | echo "skip installation"
9 | else
10 |
11 | wget --quiet -O - https://www.postgresql.org/media/keys/ACCC4CF8.asc | sudo apt-key add -
12 | echo "deb http://apt.postgresql.org/pub/repos/apt/ $(lsb_release -cs)-pgdg main" | sudo tee /etc/apt/sources.list.d/postgresql-pgdg.list > /dev/null
13 | sudo apt update
14 | sudo apt install -y postgresql-10
15 | sudo apt install -y postgresql-server-dev-10
16 | sudo -u postgres psql -c "ALTER USER postgres WITH PASSWORD 'bM2YGRAX*bG_QAilUid§2iD';"
17 | sudo service postgresql restart
18 | sudo apt install gcc
19 | sudo apt install make
20 | wget https://github.com/ossc-db/pg_hint_plan/archive/refs/tags/REL10_1_3_7.tar.gz
21 | tar xzvf REL10_1_3_7.tar.gz
22 | cd pg_hint_plan-REL10_1_3_7
23 | make
24 | sudo make install
25 | cd ..
26 | rm REL10_1_3_7.tar.gz
27 | sudo service postgresql restart
28 | sudo cp cost-eval/src/conf/postgres/modified-postgresql10.conf /etc/postgresql/10/main/postgresql.conf
29 | sudo cp cost-eval/src/conf/postgres/pg_hba.conf /etc/postgresql/10/main/pg_hba.conf
30 | sudo service postgresql restart
31 | touch FLAG_INSTALL_DONE
32 | fi
--------------------------------------------------------------------------------
/src/scripts/postgres_installation/install_postgres_11.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | set -e
4 | set -x
5 |
6 | if [[ -e FLAG_INSTALL_DONE ]]
7 | then
8 | echo "skip installation"
9 | else
10 |
11 | wget --quiet -O - https://www.postgresql.org/media/keys/ACCC4CF8.asc | sudo apt-key add -
12 | echo "deb http://apt.postgresql.org/pub/repos/apt/ $(lsb_release -cs)-pgdg main" | sudo tee /etc/apt/sources.list.d/postgresql-pgdg.list > /dev/null
13 | sudo apt update
14 | sudo apt install -y postgresql-11
15 | sudo apt install -y postgresql-server-dev-11
16 | sudo -u postgres psql -c "ALTER USER postgres WITH PASSWORD 'bM2YGRAX*bG_QAilUid§2iD';"
17 | sudo service postgresql restart
18 | sudo apt install gcc
19 | sudo apt install make
20 | wget https://github.com/ossc-db/pg_hint_plan/archive/refs/tags/REL11_1_3_9.tar.gz
21 | tar xzvf REL11_1_3_9.tar.gz
22 | cd pg_hint_plan-REL11_1_3_9
23 | make
24 | sudo make install
25 | cd ..
26 | rm REL11_1_3_9.tar.gz 27 | sudo service postgresql restart 28 | sudo cp cost-eval/src/conf/postgres/modified-postgresql11.conf /etc/postgresql/11/main/postgresql.conf 29 | sudo cp cost-eval/src/conf/postgres/pg_hba.conf /etc/postgresql/11/main/pg_hba.conf 30 | sudo service postgresql restart 31 | touch FLAG_INSTALL_DONE 32 | fi -------------------------------------------------------------------------------- /src/scripts/postgres_installation/install_postgres_12.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -e 4 | set -x 5 | 6 | if [[ -e FLAG_INSTALL_DONE ]] 7 | then 8 | echo "skip installation" 9 | else 10 | 11 | wget --quiet -O - https://www.postgresql.org/media/keys/ACCC4CF8.asc | sudo apt-key add - 12 | echo "deb http://apt.postgresql.org/pub/repos/apt/ $(lsb_release -cs)-pgdg main" | sudo tee /etc/apt/sources.list.d/postgresql-pgdg.list > /dev/null 13 | sudo apt update 14 | sudo apt install -y postgresql-12 15 | sudo apt install -y postgresql-server-dev-12 16 | sudo -u postgres psql -c "ALTER USER postgres WITH PASSWORD 'bM2YGRAX*bG_QAilUid§2iD';" 17 | sudo service postgresql restart 18 | sudo apt install gcc 19 | sudo apt install make 20 | wget https://github.com/ossc-db/pg_hint_plan/archive/refs/tags/REL12_1_3_9.tar.gz 21 | tar xzvf REL12_1_3_9.tar.gz 22 | cd pg_hint_plan-REL12_1_3_9 23 | make 24 | sudo make install 25 | cd .. 26 | rm REL12_1_3_9.tar.gz 27 | sudo service postgresql restart 28 | sudo cp cost-eval/src/conf/postgres/modified-postgresql12.conf /etc/postgresql/12/main/postgresql.conf 29 | sudo cp cost-eval/src/conf/postgres/pg_hba.conf /etc/postgresql/12/main/pg_hba.conf 30 | sudo service postgresql restart 31 | touch FLAG_INSTALL_DONE 32 | fi -------------------------------------------------------------------------------- /src/scripts/postgres_installation/install_postgres_13.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -e 4 | set -x 5 | 6 | if [[ -e FLAG_INSTALL_DONE ]] 7 | then 8 | echo "skip installation" 9 | else 10 | 11 | wget --quiet -O - https://www.postgresql.org/media/keys/ACCC4CF8.asc | sudo apt-key add - 12 | echo "deb http://apt.postgresql.org/pub/repos/apt/ $(lsb_release -cs)-pgdg main" | sudo tee /etc/apt/sources.list.d/postgresql-pgdg.list > /dev/null 13 | sudo apt update 14 | sudo apt install -y postgresql-13 15 | sudo apt install -y postgresql-server-dev-13 16 | sudo -u postgres psql -c "ALTER USER postgres WITH PASSWORD 'bM2YGRAX*bG_QAilUid§2iD';" 17 | sudo service postgresql restart 18 | sudo apt install gcc 19 | sudo apt install make 20 | wget https://github.com/ossc-db/pg_hint_plan/archive/refs/tags/REL13_1_3_9.tar.gz 21 | tar xzvf REL13_1_3_9.tar.gz 22 | cd pg_hint_plan-REL13_1_3_9 23 | make 24 | sudo make install 25 | cd .. 
26 | rm REL13_1_3_9.tar.gz
27 | sudo service postgresql restart
28 | sudo cp cost-eval/src/conf/postgres/modified-postgresql13.conf /etc/postgresql/13/main/postgresql.conf
29 | sudo cp cost-eval/src/conf/postgres/pg_hba.conf /etc/postgresql/13/main/pg_hba.conf
30 | sudo service postgresql restart
31 | touch FLAG_INSTALL_DONE
32 | fi
--------------------------------------------------------------------------------
/src/scripts/postgres_installation/install_postgres_16.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | set -e
4 | set -x
5 |
6 | if [[ -e FLAG_INSTALL_DONE ]]
7 | then
8 | echo "skip installation"
9 | else
10 |
11 | wget --quiet -O - https://www.postgresql.org/media/keys/ACCC4CF8.asc | sudo apt-key add -
12 | echo "deb http://apt.postgresql.org/pub/repos/apt/ $(lsb_release -cs)-pgdg main" | sudo tee /etc/apt/sources.list.d/postgresql-pgdg.list > /dev/null
13 | sudo apt update
14 | sudo apt install -y postgresql-16
15 | sudo apt install -y postgresql-server-dev-16
16 | sudo -u postgres psql -c "ALTER USER postgres WITH PASSWORD 'bM2YGRAX*bG_QAilUid§2iD';"
17 | sudo service postgresql restart
18 | sudo apt install gcc
19 | sudo apt install make
20 | wget https://github.com/ossc-db/pg_hint_plan/archive/refs/tags/REL16_1_6_1.tar.gz
21 | tar xzvf REL16_1_6_1.tar.gz
22 | cd pg_hint_plan-REL16_1_6_1
23 | make
24 | sudo make install
25 | cd ..
26 | rm REL16_1_6_1.tar.gz
27 | sudo service postgresql restart
28 | sudo cp cost-eval/src/conf/postgres/modified-postgresql16.conf /etc/postgresql/16/main/postgresql.conf
29 | sudo cp cost-eval/src/conf/postgres/pg_hba.conf /etc/postgresql/16/main/pg_hba.conf
30 | sudo service postgresql restart
31 | touch FLAG_INSTALL_DONE
32 | fi
--------------------------------------------------------------------------------
/src/scripts/postgres_installation/install_tools.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | set -e
4 | set -x
5 |
6 | if [[ -e FLAG_INSTALL_TOOLS_DONE ]]
7 | then
8 | echo "skip tools installation"
9 | else
10 | sudo add-apt-repository ppa:deadsnakes/ppa -y
11 | sudo apt update
12 | sudo apt install python3.9 -y
13 | sudo apt install python3.9-dev -y
14 | sudo apt install python3.9-distutils -y
15 | sudo apt install python3.9-venv -y
16 | ##sudo apt install python3.9-pip -y
17 | sudo apt install mysql-client -y
18 | sudo apt install htop
19 | touch FLAG_INSTALL_TOOLS_DONE
20 | fi
--------------------------------------------------------------------------------
/src/scripts/postgres_installation/resize_partition.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | set -e
3 | set -x
4 | # resize the root partition (400000M, i.e. ~400 GB)
5 | if [[ -e RESIZE_DONE ]]
6 | then
7 | echo "skip disk resize"
8 | else
9 | sudo timedatectl set-timezone Europe/Berlin
10 | sudo swapoff --all
11 | sudo sysctl vm.swappiness=0
12 | echo -e "vm.swappiness = 0" | sudo tee -a /etc/sysctl.conf
13 | sudo sed -i 's/.*swap.*//g' /etc/fstab
14 | sudo update-grub
15 | sudo update-initramfs -u
16 | sudo parted << EOF
17 |
18 | rm 3
19 | EOF
20 | DISK=/dev/$(lsblk | tail -n +2 | grep -Po "^[^└├\s]*" | head -n 1 )
21 | DEVICE=$(df -h / | grep -Po '/dev/([^\s]*)')
22 | printf "yes\n400000M\n" | sudo parted $DISK ---pretend-input-tty resizepart 1 400000M
23 | touch RESIZE_DONE
24 | sudo reboot
25 | fi
26 |
--------------------------------------------------------------------------------
/src/scripts/postgres_installation/resize_partition_cont.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | set -e 3 | set -x 4 | # resize partitions to 900GB 5 | if [[ -e FLAG_DISK_DONE ]] 6 | then 7 | echo "skip update_fstab" 8 | else 9 | DISK=/dev/$(lsblk | tail -n +2 | grep -Po "^[^└├\s]*" | head -n 1 ) 10 | DEVICE=$(df -h / | grep -Po '/dev/([^\s]*)') 11 | sudo resize2fs $DEVICE 12 | FREE=$(df -k . |awk '{print $4}' | tail -n 1) 13 | if [[ "$FREE" -gt 100000000 ]]; # at least 100G must be free 14 | then 15 | echo "disk resize ok ($FREE)" 16 | else 17 | echo "disk resize failed - only $FREE b available" 18 | exit 19 | fi 20 | touch FLAG_DISK_DONE 21 | fi 22 | -------------------------------------------------------------------------------- /src/tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DataManagementLab/lcm-eval/8ed11d4c47bae2cb7f0740f566170f3e736e8471/src/tests/__init__.py -------------------------------------------------------------------------------- /src/tests/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | from models.zeroshot.specific_models.postgres_zero_shot import PostgresZeroShotModel 5 | 6 | 7 | def message_passing(g, model_class=PostgresZeroShotModel): 8 | # define a message passing model 9 | fc_out_kwargs = dict(p_dropout=0.0, activation_class_name='LeakyReLU', activation_class_kwargs={}, 10 | norm_class_name='Identity', norm_class_kwargs={}, residual=False, dropout=True, 11 | activation=True, inplace=True) 12 | final_mlp_kwargs = dict(width_factor=1, n_layers=2) 13 | tree_layer_kwargs = dict(width_factor=1, n_layers=2, test=True) 14 | final_mlp_kwargs.update(**fc_out_kwargs) 15 | tree_layer_kwargs.update(**fc_out_kwargs) 16 | m = model_class(device='cpu', hidden_dim=6, final_mlp_kwargs=final_mlp_kwargs, 17 | tree_layer_name='MscnConv', 18 | tree_layer_kwargs=tree_layer_kwargs, test=True) 19 | # initialize hidden states with one hot encodings 20 | no_nodes_per_type = [] 21 | hidden_dict = dict() 22 | for i, ntype in enumerate(g.ntypes): 23 | hidden = np.zeros((g.number_of_nodes(ntype=ntype), len(g.ntypes)), dtype=np.float32) 24 | hidden[:, i] = 1 25 | no_nodes_per_type.append(g.number_of_nodes(ntype=ntype)) 26 | hidden_dict[ntype] = torch.from_numpy(hidden) 27 | model_out = m.message_passing(g, hidden_dict).numpy() 28 | return no_nodes_per_type, model_out 29 | -------------------------------------------------------------------------------- /src/tests/workload_parsing/test_filter_parsing.py: -------------------------------------------------------------------------------- 1 | from cross_db_benchmark.benchmark_tools.generate_workload import Operator, LogicalOperator 2 | from cross_db_benchmark.benchmark_tools.postgres.parse_filter import parse_filter 3 | from cross_db_benchmark.benchmark_tools.postgres.utils import list_columns 4 | 5 | 6 | def test_nested_condition(): 7 | filter_cond = "((company_type_id >= 2) AND ((note)::text ~~ '%(2009)%'::text) AND (company_id >= 420) AND (company_id <= 1665) AND ((movie_id <= 1034200) OR ((movie_id <= 1793763) AND (movie_id >= 1728789) AND (movie_id <= 1786561))))" 8 | parse_tree = parse_filter(filter_cond) 9 | columns = set() 10 | list_columns(parse_tree, columns) 11 | assert columns == {(('company_type_id',), Operator.GEQ), (('movie_id',), Operator.LEQ), (('note',), Operator.LIKE), 12 | 
(('company_id',), Operator.LEQ), (('movie_id',), Operator.GEQ), (('company_id',), Operator.GEQ), 13 | (None, LogicalOperator.AND), (None, LogicalOperator.OR)} 14 | 15 | 16 | def test_in_conditions(): 17 | filter_cond = '(((name)::text ~~ \'%Michael%\'::text) AND ((name_pcode_cf)::text ~~ \'%A5362%\'::text) AND (((imdb_index)::text = ANY (\'{IV,II,III,I}\'::text[])) OR ((surname_pcode)::text = \'R5\'::text)))' 18 | parse_tree = parse_filter(filter_cond) 19 | columns = set() 20 | list_columns(parse_tree, columns) 21 | assert columns == {(('name_pcode_cf',), Operator.LIKE), (('surname_pcode',), Operator.EQ), 22 | (('imdb_index',), Operator.IN), (('name',), Operator.LIKE), 23 | (None, LogicalOperator.AND), (None, LogicalOperator.OR)} 24 | -------------------------------------------------------------------------------- /src/training/batch_to_funcs.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | from training.training.utils import recursive_to 5 | 6 | 7 | def dace_batch_to(batch, device, label_norm): 8 | seq_encodings, attention_masks, loss_masks, run_times, labels, sample_idxs = batch 9 | recursive_to(seq_encodings, device) 10 | recursive_to(attention_masks, device) 11 | recursive_to(run_times, device) 12 | recursive_to(loss_masks, device) 13 | recursive_to(labels, device) 14 | return (seq_encodings, attention_masks, loss_masks, run_times), labels, sample_idxs 15 | 16 | 17 | def simple_batch_to(batch, device, label_norm): 18 | query_plans, labels, sample_idxs = batch 19 | if label_norm is not None: 20 | labels = label_norm.transform(np.asarray(labels).reshape(-1, 1)) 21 | labels = labels.reshape(-1) 22 | labels = torch.as_tensor(labels, device=device, dtype=torch.float) 23 | recursive_to(query_plans, device) 24 | recursive_to(labels, device) 25 | return query_plans, labels, sample_idxs 26 | 27 | 28 | def batch_to(batch, device, label_norm): 29 | graph, features, label, sample_idxs = batch 30 | 31 | # normalize the labels for training 32 | if label_norm is not None: 33 | label = label_norm.transform(np.asarray(label).reshape(-1, 1)) 34 | label = label.reshape(-1) 35 | 36 | label = torch.as_tensor(label, device=device, dtype=torch.float) 37 | recursive_to(features, device) 38 | recursive_to(label, device) 39 | # recursive_to(graph, device) 40 | graph = graph.to(device, non_blocking=True) 41 | return (graph, features), label, sample_idxs 42 | -------------------------------------------------------------------------------- /src/training/dataset/plan_dataset.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | 3 | from torch.utils.data import Dataset 4 | 5 | 6 | class PlanDataset(Dataset): 7 | def __init__(self, plans, idxs): 8 | self.plans = plans 9 | self.idxs = [int(i) for i in idxs] 10 | assert len(self.plans) == len(self.idxs) 11 | 12 | def __len__(self): 13 | return len(self.plans) 14 | 15 | def __getitem__(self, i: int): 16 | return self.idxs[i], self.plans[i] 17 | 18 | def split(self, ratio: float) -> Tuple: 19 | split_idx = int(len(self) * ratio) 20 | return PlanDataset(self.plans[:split_idx], self.idxs[:split_idx]), PlanDataset(self.plans[split_idx:], self.idxs[split_idx:]) 21 | -------------------------------------------------------------------------------- /src/training/preprocessing/feature_statistics.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import json 3 | import os 4 
| from enum import Enum 5 | 6 | import numpy as np 7 | from sklearn.preprocessing import RobustScaler 8 | from tqdm import tqdm 9 | 10 | 11 | def gather_values_recursively(json_dict, value_dict=None): 12 | if value_dict is None: 13 | value_dict = collections.defaultdict(list) 14 | 15 | if isinstance(json_dict, dict): 16 | for k, v in json_dict.items(): 17 | if not (isinstance(v, list) or isinstance(v, tuple) or isinstance(v, dict)): 18 | value_dict[k].append(v) 19 | elif (isinstance(v, list) or isinstance(v, tuple)) and len(v) > 0 and \ 20 | (isinstance(v[0], int) or isinstance(v[0], float) or isinstance(v[0], str)): 21 | for v_e in v: 22 | value_dict[k].append(v_e) 23 | else: 24 | gather_values_recursively(v, value_dict=value_dict) 25 | elif isinstance(json_dict, tuple) or isinstance(json_dict, list): 26 | for e in json_dict: 27 | gather_values_recursively(e, value_dict=value_dict) 28 | 29 | return value_dict 30 | 31 | 32 | class FeatureType(Enum): 33 | numeric = 'numeric' 34 | categorical = 'categorical' 35 | 36 | def __str__(self): 37 | return self.value 38 | 39 | 40 | def gather_feature_statistics(workload_run_paths, target): 41 | """ 42 | Traverses a JSON object and gathers metadata for each key. Depending on whether the values of the key are 43 | categorical or numerical, different statistics are collected. This is later on used to automate the feature 44 | extraction during the training (e.g., how to consistently map a categorical value to an index). 45 | """ 46 | 47 | run_stats = [] 48 | for source in tqdm(workload_run_paths): 49 | assert os.path.exists(source), f"{source} does not exist" 50 | try: 51 | with open(source) as json_file: 52 | run_stats.append(json.load(json_file)) 53 | except: 54 | raise ValueError(f"Could not read {source}") 55 | value_dict = gather_values_recursively(run_stats) 56 | 57 | print("Saving") 58 | # save unique values for categorical features and scale and center of RobustScaler for numerical ones 59 | statistics_dict = dict() 60 | for k, values in value_dict.items(): 61 | values = [v for v in values if v is not None] 62 | if len(values) == 0: 63 | continue 64 | 65 | if all([isinstance(v, int) or isinstance(v, float) or v is None for v in values]): 66 | scaler = RobustScaler() 67 | np_values = np.array(values, dtype=np.float32).reshape(-1, 1) 68 | scaler.fit(np_values) 69 | 70 | statistics_dict[k] = dict(max=float(np_values.max()), 71 | scale=scaler.scale_.item(), 72 | center=scaler.center_.item(), 73 | type=str(FeatureType.numeric)) 74 | else: 75 | unique_values = set(values) 76 | statistics_dict[k] = dict(value_dict={v: id for id, v in enumerate(unique_values)}, 77 | no_vals=len(unique_values), 78 | type=str(FeatureType.categorical)) 79 | 80 | # save as json 81 | os.makedirs(os.path.dirname(target), exist_ok=True) 82 | with open(target, 'w') as outfile: 83 | json.dump(statistics_dict, outfile) 84 | -------------------------------------------------------------------------------- /src/training/training/metrics.py: -------------------------------------------------------------------------------- 1 | import copy 2 | 3 | import numpy as np 4 | from sklearn.metrics import mean_squared_error 5 | 6 | 7 | class Metric: 8 | """ 9 | Abstract class defining a metric used to evaluate the zero-shot cost model performance (e.g., Q-error) 10 | """ 11 | 12 | def __init__(self, metric_prefix='val_', metric_name='metric', maximize=True, early_stopping_metric=False): 13 | self.maximize = maximize 14 | self.default_value = -np.inf 15 | if not self.maximize: 16 | 
16 |             self.default_value = np.inf
17 |         self.best_seen_value = self.default_value
18 |         self.last_seen_value = self.default_value
19 |         self.metric_name = metric_prefix + metric_name
20 |         self.best_model = None
21 |         self.early_stopping_metric = early_stopping_metric
22 |
23 |     def evaluate(self, model=None, metrics_dict=None, **kwargs):
24 |         metric = self.default_value
25 |         try:
26 |             metric = self.evaluate_metric(**kwargs)
27 |         except ValueError as e:
28 |             print(f"Observed ValueError for {self.metric_name} in calculation: {e}")
29 |         self.last_seen_value = metric
30 |
31 |         metrics_dict[self.metric_name] = metric
32 |         print(f"{self.metric_name:<30}: {metric:<10.4f} [best: {self.best_seen_value:.4f}]")
33 |
34 |         best_seen = False
35 |         if (self.maximize and metric > self.best_seen_value) or (not self.maximize and metric < self.best_seen_value):
36 |             self.best_seen_value = metric
37 |             best_seen = True
38 |             if model is not None:
39 |                 self.best_model = copy.deepcopy(model.state_dict())
40 |         return best_seen
41 |
42 |
43 | class MAPE(Metric):
44 |     def __init__(self, **kwargs):
45 |         super().__init__(metric_name='mape', maximize=False, **kwargs)
46 |
47 |     def evaluate_metric(self, labels=None, preds=None, probs=None):
48 |         mape = np.mean(np.abs((labels - preds) / labels))
49 |         return mape
50 |
51 |
52 | class RMSE(Metric):
53 |     def __init__(self, **kwargs):
54 |         super().__init__(metric_name='mse', maximize=False, **kwargs)
55 |
56 |     def evaluate_metric(self, labels=None, preds=None, probs=None):
57 |         val_mse = np.sqrt(mean_squared_error(labels, preds))
58 |         return val_mse
59 |
60 |
61 | class QError(Metric):
62 |     def __init__(self, percentile=50, min_val=0.1, **kwargs):
63 |         super().__init__(metric_name=f'median_q_error_{percentile}', maximize=False, **kwargs)
64 |         self.percentile = percentile
65 |         self.min_val = min_val
66 |
67 |     def evaluate_metric(self, labels=None, preds=None, probs=None):
68 |         # if not np.all(labels >= self.min_val):
69 |         #     print("WARNING: some labels are smaller than min_val")
70 |         preds = np.abs(preds)
71 |         # preds = np.clip(preds, self.min_val, np.inf)
72 |
73 |         q_errors = np.maximum(labels / preds, preds / labels)
74 |         q_errors = np.nan_to_num(q_errors, nan=np.inf)
75 |         median_q = np.percentile(q_errors, self.percentile)
76 |         return median_q
77 |
--------------------------------------------------------------------------------
/src/training/training/utils.py:
--------------------------------------------------------------------------------
1 | import collections
2 | import hashlib
3 | import json
4 | import os
5 |
6 | import dgl
7 | import torch
8 | from dgl import DGLHeteroGraph
9 |
10 | from cross_db_benchmark.benchmark_tools.postgres.json_plan import OperatorTree
11 |
12 |
13 | def recursive_to(iterable, device):
14 |     if isinstance(iterable, (dgl.DGLGraph, DGLHeteroGraph)):
15 |         iterable.to(device, non_blocking=True)
16 |     if isinstance(iterable, torch.Tensor):
17 |         iterable.data = iterable.data.to(device, non_blocking=True)
18 |     elif isinstance(iterable, collections.abc.Mapping):
19 |         for v in iterable.values():
20 |             recursive_to(v, device)
21 |     elif isinstance(iterable, OperatorTree):
22 |         iterable.encoded_features = iterable.encoded_features.to(device, non_blocking=True)
23 |         for c in iterable.children:
24 |             recursive_to(c, device)
25 |     elif isinstance(iterable, (list, tuple)):
26 |         for v in iterable:
27 |             recursive_to(v, device)
28 |
29 |
30 | def flatten_dict(d, parent_key='', sep='_'):
31 |     """
32 |     https://stackoverflow.com/questions/6027558/flatten-nested-dictionaries-compressing-keys
33 |     """
34 |     items = []
35 |     for k, v in d.items():
36 |         new_key = parent_key + sep + k if parent_key else k
37 |         if isinstance(v, collections.abc.MutableMapping):
38 |             items.extend(flatten_dict(v, new_key, sep=sep).items())
39 |         else:
40 |             items.append((new_key, v))
41 |     return dict(items)
42 |
43 |
44 | def dict_hash(dictionary):
45 |     """MD5 hash of a dictionary."""
46 |     dhash = hashlib.md5()
47 |     # We need to sort arguments so {'a': 1, 'b': 2} is the same as {'b': 2, 'a': 1}
48 |     encoded = json.dumps(dictionary, sort_keys=True).encode()
49 |     dhash.update(encoded)
50 |     return dhash.hexdigest()
51 |
52 |
53 | def find_early_stopping_metric(metrics):
54 |     potential_metrics = [m for m in metrics if m.early_stopping_metric]
55 |     assert len(potential_metrics) == 1
56 |     early_stopping_metric = potential_metrics[0]
57 |     return early_stopping_metric
58 |
59 | def save_config(save_dict, target_dir, json_name):
60 |     os.makedirs(target_dir, exist_ok=True)
61 |     target_params_path = os.path.join(target_dir, json_name)
62 |     print(f"Saving best params to {target_params_path}")
63 |
64 |     with open(target_params_path, 'w') as f:
65 |         json.dump(save_dict, f)
66 |
--------------------------------------------------------------------------------
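To make the metric interface above concrete, here is a minimal usage sketch (not part of the repository; the numeric labels and predictions are made up for illustration, and it assumes src/ is on the PYTHONPATH so that the repository's own import path training.training.metrics resolves):

import numpy as np

from training.training.metrics import QError, RMSE

labels = np.array([120.0, 45.0, 3000.0])   # measured runtimes
preds = np.array([100.0, 90.0, 2500.0])    # model estimates

metrics = [QError(percentile=50, early_stopping_metric=True), QError(percentile=95), RMSE()]

stats = dict()
for metric in metrics:
    # evaluate() stores the value under metric.metric_name (e.g. 'val_median_q_error_50')
    # and keeps track of the best value (and model state, if a model is passed) seen so far
    metric.evaluate(metrics_dict=stats, labels=labels, preds=preds)

print(stats)

With this setup, find_early_stopping_metric(metrics) from training/training/utils.py would return the QError(percentile=50) instance, since it is the only metric flagged with early_stopping_metric=True.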